Spaces:
Running
Running
| import torch as th | |
| import numpy as np | |
def compute_logp(args, model, x, input_ids, sigma=0.1):
    """Return the per-token cross-entropy of ``input_ids`` given embeddings ``x``.

    Every position of ``x`` is scored against every row of ``model.weight``
    with the logit ``-||x - e_v||^2 / (2 * sigma^2)``, i.e. an isotropic
    Gaussian centred on each word embedding. The logits are then passed
    through a softmax cross-entropy against ``input_ids``.

    Args:
        args: Object exposing ``model_arch``. When it equals ``'1d-unet'`` the
            last two axes of ``x`` are swapped before scoring (presumably the
            U-Net emits channels-first tensors -- confirm against the caller).
        model: Object whose ``weight`` attribute is the ``(vocab, dim)``
            embedding table.
        x: Rank-3 tensor, ``(bsz, seqlen, dim)`` after the optional permute.
        input_ids: Integer tensor holding ``bsz * seqlen`` target token ids.
        sigma: Standard deviation of the Gaussian; defaults to the previously
            hard-coded ``0.1``.

    Returns:
        Tensor of shape ``(bsz, seqlen)`` with the negative log-probability of
        each target token.
    """
    word_emb = model.weight
    if args.model_arch == '1d-unet':
        x = x.permute(0, 2, 1)

    bsz, seqlen, dim = x.shape
    x_flat = x.reshape(-1, dim).unsqueeze(0)  # 1, bsz*seqlen, dim
    word_emb_flat = word_emb.unsqueeze(1)  # vocab, 1, dim

    # Broadcasts to (vocab, bsz*seqlen, dim) before reducing over dim.
    sq_dist = ((x_flat - word_emb_flat) ** 2).sum(dim=-1)  # vocab, bsz*seqlen
    logits = (-sq_dist / (2 * sigma ** 2)).permute(1, 0)  # bsz*seqlen, vocab

    ce = th.nn.CrossEntropyLoss(reduction='none')
    # reshape (not view): the ids may be a non-contiguous slice or transpose.
    return ce(logits, input_ids.reshape(-1)).view(bsz, seqlen)
def get_weights(model, args):
    """Return a frozen embedding module to use as the rounding vocabulary.

    Args:
        model: Either an object with a ``transformer.wte`` embedding and a
            ``down_proj`` callable, or an object that already has a ``weight``
            attribute (returned as-is after freezing).
        args: Object exposing ``emb_scale_factor``; only read on the
            ``transformer`` path.

    Returns:
        On the ``transformer`` path, a new ``th.nn.Embedding`` whose weight is
        ``down_proj(wte.weight) * args.emb_scale_factor``. Otherwise ``model``
        itself. In both cases ``weight.requires_grad`` is set to ``False``.

    Raises:
        NotImplementedError: If ``model`` has neither a ``transformer`` nor a
            ``weight`` attribute.
    """
    if hasattr(model, 'transformer'):
        input_embs = model.transformer.wte
        # The result only seeds a frozen table, so skip building an autograd graph.
        with th.no_grad():
            down_proj_emb = model.down_proj(input_embs.weight)
            scaled_emb = down_proj_emb * args.emb_scale_factor
        print(down_proj_emb.shape)
        embedding = th.nn.Embedding(down_proj_emb.size(0), down_proj_emb.size(1))
        print(args.emb_scale_factor)
        embedding.weight.data = scaled_emb
    elif hasattr(model, 'weight'):
        embedding = model
    else:
        # The original `assert NotImplementedError` was always truthy and never fired.
        raise NotImplementedError(
            "get_weights expects a model with a 'transformer' or 'weight' attribute"
        )

    embedding.weight.requires_grad = False
    return embedding
def denoised_fn_round(args, model, text_emb, t):
    """Snap each embedding in ``text_emb`` to its nearest row of ``model``.

    Rounding is skipped, and ``text_emb`` is returned untouched, when the
    first entry of ``t`` is greater than 350.

    Args:
        args: Unused by this function.
        model: 2-D tensor of candidate embeddings; it is indexed directly, so
            it is a tensor rather than a module.
        text_emb: Tensor whose last axis is the embedding dimension.
        t: Indexable of timesteps; only ``t[0]`` is inspected.

    Returns:
        ``text_emb`` itself when rounding is skipped, otherwise a tensor with
        the shape and device of ``text_emb`` whose vectors are rows of
        ``model``.
    """
    round_at_or_below = 350
    if t[0] > round_at_or_below:
        return text_emb

    vocab_emb = model
    orig_shape, orig_device = text_emb.shape, text_emb.device

    # Collapse every leading axis so each row is one query vector.
    queries = text_emb
    if queries.dim() > 2:
        queries = queries.reshape(-1, queries.size(-1))
    queries = queries.to(vocab_emb.device)

    # Squared L2 distance via ||e||^2 + ||q||^2 - 2 e.q, laid out (vocab, n_queries).
    vocab_sq = vocab_emb.pow(2).sum(-1).view(-1, 1)
    query_sq = queries.pow(2).sum(-1).view(1, -1)
    cross = th.mm(vocab_emb, queries.view(-1, queries.size(-1)).t())
    sq_dist = vocab_sq + query_sq - 2.0 * cross
    # Rounding error can push tiny distances below zero.
    sq_dist = th.clamp(sq_dist, min=0.0)

    nearest = th.topk(-sq_dist, k=1, dim=0).indices[0]
    return model[nearest].view(orig_shape).to(orig_device)
def load_results(json_path, load_dict):
    """Serialize ``load_dict`` as JSON to ``json_path``.

    Despite the name, this function writes rather than reads: the target file
    is opened in ``'w'`` mode (truncating any existing content) and the
    dictionary is dumped with a two-space indent.

    Args:
        json_path: Destination file path.
        load_dict: JSON-serializable object to write.
    """
    import json

    with open(json_path, 'w') as out_file:
        json.dump(load_dict, out_file, indent=2)