Spaces:
Running
Running
| import torch | |
| # bert results | |
| from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator | |
| import sys, yaml, os | |
| # print( os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) | |
| # sys.path.insert(0, 'diffusion_lm/transformers/examples/pytorch/language-modeling') | |
| # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) | |
| # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise | |
def load_models(modality, mode, model_name_or_path, emb_dim, file, extra_args=None):
    """Load a saved embedding table and its id-to-token vocabulary.

    Only the JSON-vocabulary path is implemented. The loaders for the
    ``'synth'`` modality and for the BERT tokenizer are disabled in this
    file, so those combinations raise instead of failing later with an
    ``UnboundLocalError``.

    Args:
        modality: Dataset family. ``'synth'`` and ``'book'`` are rejected.
        mode: Embedding mode; one of ``'random'``, ``'random1'``,
            ``'random_up_proj'`` or ``'glove'``.
        model_name_or_path: Unused; kept for call-site compatibility.
        emb_dim: Width of the embedding table stored in ``random_emb.torch``.
        file: Directory expected to hold ``vocab.json`` and
            ``random_emb.torch``. When an artifact is missing there, the
            legacy hard-coded checkpoint directory is used instead.
        extra_args: Optional object; if its ``use_bert_tokenizer`` attribute
            equals ``'yes'`` the (disabled) BERT-tokenizer path is requested.

    Returns:
        A ``(model, tokenizer)`` tuple: a ``torch.nn.Embedding`` with the
        saved weights loaded, and a dict mapping token id to token string.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
        NotImplementedError: If the requested modality/tokenizer loader is
            disabled.
    """
    import json
    import os

    # Directory this function previously read from unconditionally (it
    # overwrote the paths derived from `file`). Kept only as a fallback so
    # existing setups that relied on it keep working.
    legacy_dir = ('/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/'
                  'diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple'
                  '_h128_s2_d0.1_sd102_xstart_e2e')

    if mode not in ('random', 'random1', 'random_up_proj', 'glove'):
        raise ValueError('load_models: unsupported mode {!r}'.format(mode))
    if modality == 'synth':
        raise NotImplementedError(
            "load_models: the 'synth' modality loader is disabled in this file")
    if modality == 'book' or (extra_args is not None and extra_args.use_bert_tokenizer == 'yes'):
        raise NotImplementedError(
            'load_models: the BERT tokenizer loader is disabled in this file')

    def _resolve(name):
        # Prefer the artifact next to the requested model; fall back to the
        # legacy checkpoint directory when it is not there.
        candidate = '{}/{}'.format(file, name)
        if os.path.isfile(candidate):
            return candidate
        return '{}/{}'.format(legacy_dir, name)

    path_save_tokenizer = _resolve('vocab.json')
    print(f'loading from {path_save_tokenizer}')
    with open(path_save_tokenizer, 'r') as f:
        vocab = json.load(f)
    print(len(vocab))
    # vocab.json maps token -> id; invert it so ids can be decoded.
    tokenizer = {v: k for k, v in vocab.items()}

    model = torch.nn.Embedding(len(tokenizer), emb_dim)
    path_save = _resolve('random_emb.torch')
    model.load_state_dict(torch.load(path_save))
    return model, tokenizer
def load_tokenizer(modality, mode, model_name_or_path):
    """Load the id-to-token mapping (or tokenizer) for a trained model.

    Args:
        modality: ``'synth'`` builds the vocabulary from the synthetic
            dataset config, ``'book'`` loads the ``bert-base-uncased``
            tokenizer, anything else reads ``vocab.json`` from
            ``model_name_or_path``.
        mode: Embedding mode; one of ``'random'``, ``'random1'``,
            ``'random_up_proj'`` or ``'glove'`` (``'random1'`` is accepted
            for consistency with ``load_models``).
        model_name_or_path: Model directory. For ``'synth'`` the substring
            ``'synth128'`` selects the k128 config, otherwise the k32 one.

    Returns:
        A dict mapping token id to token string, or a Hugging Face
        tokenizer when ``modality == 'book'``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if mode not in ('random', 'random1', 'random_up_proj', 'glove'):
        raise ValueError('load_tokenizer: unsupported mode {!r}'.format(mode))

    if modality == 'synth':
        print(model_name_or_path, 'deciding what to load::: ')
        if 'synth128' in model_name_or_path:
            config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k128_trainc20000.yaml'
        else:
            config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k32_trainc20000.yaml'
        import sys
        # The synthetic dataset module lives outside the package; make it importable.
        sys.path.insert(0, 'diffusion_lm/synthetic_data/rnns-stacks')
        from dataset import Dataset as SynthDataset
        # PyYAML >= 6 requires an explicit Loader. FullLoader is what the
        # bare yaml.load(...) call resolved to on PyYAML 5.x.
        # NOTE(review): only safe because the config is a local, trusted file.
        with open(config) as config_file:
            args_synth = yaml.load(config_file, Loader=yaml.FullLoader)
        dataset = SynthDataset(args_synth)
        return {v: k for k, v in dataset.vocab.items()}

    if modality == 'book':
        return AutoTokenizer.from_pretrained('bert-base-uncased')

    import json
    path_save_tokenizer = '{}/vocab.json'.format(model_name_or_path)
    with open(path_save_tokenizer, 'r') as f:
        vocab = json.load(f)
    # vocab.json maps token -> id; invert it so ids can be decoded.
    return {v: k for k, v in vocab.items()}
def rounding_func(mode, text_emb_lst, model, tokenizer, emb_scale_factor=1.0):
    """Decode embeddings to text by nearest-neighbour lookup in ``model.weight``.

    Each position of each entry in ``text_emb_lst`` is mapped to the row of
    ``model.weight`` with the smallest L2 distance, and the matching tokens
    are joined with single spaces.

    Args:
        mode: Embedding mode. Decoding only happens for ``'random'``,
            ``'random_up_proj'`` and ``'glove'``; any other value yields an
            empty list.
        text_emb_lst: Iterable of embeddings (tensors, arrays or nested
            lists). Each entry is flattened to ``(positions, emb_dim)`` using
            its last axis as the embedding axis, so 1-D (single position)
            and batched (>2-D) inputs are accepted.
        model: Object exposing a ``weight`` tensor of shape
            ``(vocab_size, emb_dim)``, e.g. ``torch.nn.Embedding``.
        tokenizer: Mapping indexed by integer token id that returns the
            token string.
        emb_scale_factor: Unused; kept for call-site compatibility.

    Returns:
        A list with one decoded string per entry of ``text_emb_lst``.
    """
    # The original imported torch inside this function; kept local so the
    # function does not depend on the module-level import.
    import torch

    decoded_out_lst = []
    if mode not in ('random', 'random_up_proj', 'glove'):
        return decoded_out_lst

    # Decoding never needs gradients; detach so no autograd graph is built.
    vocab_emb = model.weight.detach()

    for text_emb in text_emb_lst:
        # as_tensor avoids the copy-construct warning torch.tensor() emits
        # when handed an existing tensor.
        text_emb = torch.as_tensor(text_emb).to(vocab_emb.device)
        # Collapse every leading axis into one "positions" axis; reshape
        # (unlike view) also works for non-contiguous and 1-D inputs.
        text_emb = text_emb.reshape(-1, text_emb.size(-1))

        # (vocab, 1, dim) - (1, positions, dim) -> (vocab, positions) distances.
        distances = torch.norm(vocab_emb.unsqueeze(1) - text_emb.unsqueeze(0), dim=-1)
        # Only the closest token was ever used. argmin replaces topk(k=6),
        # which raised for vocabularies with fewer than 6 entries.
        nearest_ids = distances.argmin(dim=0).tolist()

        decoded_out_lst.append(" ".join(tokenizer[i] for i in nearest_ids))
    return decoded_out_lst