Spaces:
Running
Running
| # from PIL import Image | |
| # import blobfile as bf | |
| from mpi4py import MPI | |
| import numpy as np | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator, PreTrainedTokenizerFast, \ | |
| PreTrainedTokenizer | |
| # from datasets import load_dataset | |
| import sys, os | |
| import torch | |
| # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) | |
| # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise | |
| from collections import Counter, defaultdict | |
| from functools import partial | |
| from itertools import chain | |
def load_data_text(
    *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, data_args=None,
    task_mode='roc', model=None, padding_mode='block', split='train', load_vocab=None,
):
    """
    Create an endless generator over batches of an encoded text dataset.

    The corpus is loaded through ``get_corpus_rocstory``, wrapped in a
    ``TextDataset`` and served by a ``DataLoader``. The loader is restarted
    every time it is exhausted, so iteration never stops by itself.

    :param data_dir: not used by this function.
    :param batch_size: the batch size handed to the DataLoader; an incomplete
                       final batch is dropped.
    :param image_size: forwarded to the corpus loader and used as the
                       dataset resolution.
    :param class_cond: not used by this function.
    :param deterministic: if True, batches are drawn without shuffling.
    :param data_args: options object; ``modality``, ``cache_mode`` and
                      ``model_arch`` are read here and the whole object is
                      forwarded to the corpus loader and the dataset.
    :param task_mode: which corpus to load. Only ``'e2e-tgt'`` is enabled.
    :param model: optional embedding model forwarded to the corpus loader.
    :param padding_mode: forwarded to the corpus loader.
    :param split: forwarded to the corpus loader.
    :param load_vocab: forwarded to the corpus loader.
    :raises NotImplementedError: for a task mode other than ``'e2e-tgt'`` or
                                 for the disabled no-cache dataset path. The
                                 error surfaces on the first ``next()`` call
                                 because this function is a generator.
    """
    print('hello loading text data. ')

    # Every other task mode has its loader disabled; fail with a clear message
    # instead of an unbound-variable error further down.
    if task_mode != 'e2e-tgt':
        raise NotImplementedError(
            f"task_mode {task_mode!r} is not supported by this loader; only 'e2e-tgt' is enabled"
        )

    # The streaming (no-cache) dataset is disabled as well; check before the
    # corpus is loaded so nothing is written to disk for a run that cannot start.
    no_cache_modalities = ['roc-aug', 'roc', 'book', 'yelp', 'commonGen', 'commonGen-aug']
    if data_args.modality in no_cache_modalities and data_args.cache_mode == 'no':
        raise NotImplementedError(
            f"cache_mode 'no' with modality {data_args.modality!r} is not supported by this loader"
        )

    print('hello loading e2e-tgt. ')
    training_data, model = get_corpus_rocstory(data_args, model, image_size,
                                               padding_mode=padding_mode, split=split,
                                               load_vocab=load_vocab)

    dataset = TextDataset(
        training_data,
        image_size,
        data_args,
        model_arch=data_args.model_arch,
    )

    # The deterministic and the shuffled loader differ only in `shuffle`.
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=not deterministic,
        num_workers=1,
    )
    while True:
        yield from data_loader
def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
    """
    Encode (source, target) token pairs for conditional generation.

    Target tokens are mapped to ids (unknown tokens use ``vocab_dict['UNK']``),
    wrapped with id 0 at the front and id 1 at the end, and padded to
    ``seqlen`` with ``_collate_batch_helper``. Source tokens are mapped the
    same way without the wrapping ids and padded to the longest source,
    capped at ``seqlen``; a mask is requested for them.

    :param sentence_lst: iterable of ``(source_tokens, target_tokens)`` pairs.
    :param vocab_dict: token -> id mapping containing ``'UNK'`` and ``'PAD'``.
    :param model: embedding model selected by ``data_args.experiment``.
    :param seqlen: padded length of the target sequences.
    :param data_args: options object; ``experiment`` is read, and
                      ``emb_scale_factor`` for ``'gpt2_pre_compress'``.
    :return: list of dicts with keys ``input_ids``, ``hidden_states``,
             ``src_ids`` and ``src_mask``.
    """
    encoded = []
    target_rows = []
    source_rows = []
    with torch.no_grad():
        for source_tokens, target_tokens in sentence_lst:
            target = [vocab_dict.get(tok, vocab_dict['UNK']) for tok in target_tokens]
            source = [vocab_dict.get(tok, vocab_dict['UNK']) for tok in source_tokens]
            target_rows.append([0] + target + [1])
            source_rows.append(source)

        print(target_rows[:2])
        print('padding mode is pad')

        target_rows = _collate_batch_helper(target_rows, vocab_dict['PAD'], seqlen)

        longest_source = max([len(row) for row in source_rows])
        print(longest_source, seqlen)
        source_width = min(seqlen, longest_source)
        source_rows, source_masks = _collate_batch_helper(
            source_rows, vocab_dict['PAD'], source_width, return_mask=True
        )

        for ids, src, mask in zip(target_rows, source_rows, source_masks):
            if data_args.experiment.startswith('random'):
                hidden_state = model(torch.tensor(ids))
            elif data_args.experiment == 'gpt2_pre_compress':
                ids_on_device = torch.tensor(ids).to(model.device)
                token_embs = model.transformer.wte(ids_on_device)
                hidden_state = model.down_proj(token_embs) * data_args.emb_scale_factor
            record = {
                'input_ids': ids,
                'hidden_states': hidden_state.cpu().tolist(),
                'src_ids': src,
                'src_mask': mask,
            }
            encoded.append(record)
    return encoded
def helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode, ):
    """
    Tokenize a corpus with the ``datasets`` library and return a DatasetDict.

    The token lists are wrapped in a ``datasets.Dataset``, mapped to ids and
    then either concatenated and cut into blocks of ``seqlen`` ids
    (``padding_mode == 'block'``, which also adds a ``labels`` column) or
    padded to ``seqlen`` with ``_collate_batch_helper`` (any other mode).
    Resident memory is printed at several points along the way.

    :param sentence_lst: list of token lists.
    :param vocab_dict: either a token -> id dict (with ``'UNK'`` and, for the
                       padding path, ``'PAD'``) or a ``PreTrainedTokenizerFast``.
    :param model: not used by this function.
    :param seqlen: block size or padded length.
    :param data_args: options object; ``preprocessing_num_workers`` and
                      ``overwrite_cache`` are read on the block path.
    :param padding_mode: ``'block'`` selects chunking, anything else padding.
    :return: a ``datasets.DatasetDict`` whose ``'train'`` entry is the result.
    """
    import psutil
    import datasets
    from datasets import Dataset as Dataset2

    def _report_ram():
        # Process.memory_info is expressed in bytes, so convert to megabytes
        print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    _report_ram()
    raw_datasets = Dataset2.from_dict({'text': sentence_lst})
    print(raw_datasets)
    _report_ram()

    def tokenize_function(examples):
        if isinstance(vocab_dict, dict):
            input_ids = [
                [0] + [vocab_dict.get(tok, vocab_dict['UNK']) for tok in seq] + [1]
                for seq in examples['text']
            ]
        elif isinstance(vocab_dict, PreTrainedTokenizerFast):
            examples['text'] = [" ".join(seq) for seq in examples['text']]
            input_ids = vocab_dict(examples['text'], add_special_tokens=True)['input_ids']
        return {'input_ids': input_ids}

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=['text'],
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )
    print(tokenized_datasets)
    _report_ram()

    if padding_mode == 'block':
        block_size = seqlen

        def group_texts(examples):
            # Concatenate every column, then keep only whole blocks.
            flattened = {name: list(chain(*examples[name])) for name in examples.keys()}
            total_length = len(flattened[list(examples.keys())[0]])
            if total_length >= block_size:
                total_length = (total_length // block_size) * block_size
            chunks = {
                name: [ids[start: start + block_size] for start in range(0, total_length, block_size)]
                for name, ids in flattened.items()
            }
            chunks["labels"] = chunks["input_ids"].copy()
            return chunks

        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )
    else:
        def pad_function(group_lst):
            if isinstance(vocab_dict, dict):
                pad_id = vocab_dict['PAD']
            else:
                pad_id = vocab_dict.pad_token_id
            group_lst['input_ids'] = _collate_batch_helper(group_lst['input_ids'], pad_id, seqlen)
            return group_lst

        _report_ram()
        lm_datasets = tokenized_datasets.map(
            pad_function,
            batched=True,
            num_proc=1,
            desc="padding",
        )

    print(lm_datasets, 'padded dataset')
    _report_ram()
    raw_datasets = datasets.DatasetDict()
    raw_datasets['train'] = lm_datasets
    _report_ram()
    return raw_datasets
def helper_tokenize_encode(sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode, ):
    """
    Map token lists to ids, shape them to ``seqlen`` and embed them.

    Every sentence is converted to ids (unknown tokens use
    ``vocab_dict['UNK']``) and wrapped with id 0 at the front and id 1 at the
    end. With ``padding_mode == 'block'`` all ids are concatenated and cut
    into consecutive blocks of ``seqlen`` ids, dropping the incomplete tail;
    with ``padding_mode == 'pad'`` each sentence is padded by
    ``_collate_batch_helper``. Any other mode leaves the sequences as they are.

    :param sentence_lst: iterable of token lists.
    :param vocab_dict: token -> id mapping containing ``'UNK'`` (and ``'PAD'``
                       when padding).
    :param model: embedding model selected by ``data_args.experiment``.
    :param seqlen: block size or padded length.
    :param data_args: options object; ``experiment`` is read, and
                      ``emb_scale_factor`` for ``'gpt2_pre_compress'``.
    :param padding_mode: ``'block'`` or ``'pad'``.
    :return: list of dicts with keys ``input_ids`` and ``hidden_states``.
    :raises ValueError: if ``data_args.experiment`` selects no known encoder.
    """
    result_train_lst = []
    with torch.no_grad():
        word_ids = []
        for tokens in sentence_lst:
            ids = [vocab_dict.get(tok, vocab_dict['UNK']) for tok in tokens]
            word_ids.append([0] + ids + [1])
        print(word_ids[:2])

        if padding_mode == 'block':
            print('padding mode is block')
            # chain.from_iterable flattens in linear time; summing lists
            # re-copies the accumulator for every sentence.
            flat_ids = list(chain.from_iterable(word_ids))
            block_size = seqlen
            total_length = (len(flat_ids) // block_size) * block_size
            # Split by chunks of max_len.
            word_ids = [flat_ids[i: i + block_size] for i in range(0, total_length, block_size)]
        elif padding_mode == 'pad':
            print('padding mode is pad')
            word_ids = _collate_batch_helper(word_ids, vocab_dict['PAD'], seqlen)

        for input_ids in word_ids:
            if data_args.experiment.startswith('random') or data_args.experiment == 'glove':
                hidden_state = model(torch.tensor(input_ids))
            elif data_args.experiment == 'gpt2_pre_compress':
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor
            else:
                # Previously this surfaced as an unbound `hidden_state`.
                raise ValueError(f"unsupported experiment {data_args.experiment!r}")
            result_train_lst.append({'input_ids': input_ids, 'hidden_states': hidden_state.cpu().tolist()})
    return result_train_lst
def load_glove_model(File):
    """
    Read a GloVe-style text file into a ``word -> tensor`` dictionary.

    Each non-empty line is split on whitespace: the first field is the word
    and the remaining fields are parsed as a float64 vector.

    :param File: path of the embedding text file.
    :return: dict mapping each word to a 1-D ``torch.float64`` tensor.
    """
    print("Loading Glove Model")
    glove_model = {}
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default encoding.
    with open(File, 'r', encoding='utf-8') as f:
        for line in f:
            split_line = line.split()
            if not split_line:
                # A blank line has no word field; skip it instead of failing.
                continue
            word = split_line[0]
            glove_model[word] = torch.tensor(np.array(split_line[1:], dtype=np.float64))
    print(f"{len(glove_model)} words loaded!")
    return glove_model
def load_glove(vocab):
    """
    Build a 50-dimensional embedding layer initialised from GloVe vectors.

    Vectors are read from ``predictability/glove/glove.6B.50d.txt``. Rows are
    stacked in the iteration order of ``vocab``; the index values stored in
    ``vocab`` are not consulted. Words missing from the GloVe file get a
    random normal vector, and the number of such words is printed.

    :param vocab: mapping whose keys are the vocabulary words.
    :return: a ``torch.nn.Embedding`` whose weight data is the stacked matrix.
    """
    model = torch.nn.Embedding(len(vocab), 50)
    pretrained = load_glove_model('predictability/glove/glove.6B.50d.txt')

    rows = []
    missing = 0
    for word, _ in vocab.items():
        if word not in pretrained:
            missing += 1
            rows.append(torch.randn(50))
            continue
        rows.append(pretrained[word])

    print(f'{missing} out of {len(vocab)} is initialized. ')
    weights = torch.stack(rows)
    print(torch.norm(weights, dim=-1).mean())
    model.weight.data = weights
    return model
def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
                        split='train', load_vocab=None):
    """
    Load a text corpus, build or reuse its vocabulary, and encode it.

    In ``'lm'`` mode the modalities ``'e2e-tgt'``, ``'yelp'``, ``'commonGen'``
    and ``'commonGen-aug'`` are read and tokenized with the spaCy English
    tokenizer; the ``'roc'``, ``'roc-aug'`` and ``'simple-wiki'`` loaders are
    disabled in this copy. In ``'conditional_gen'`` mode only ``'e2e'`` is
    read, as (source, target) token pairs.

    When ``load_vocab`` is None a vocabulary is built from tokens seen more
    than 10 times and written to disk as JSON; otherwise ``load_vocab`` is
    used, and saved only if the vocabulary file does not exist yet. When
    ``model`` is None and ``data_args.experiment == 'random'`` a randomly
    initialised embedding is created and its state dict is saved.

    :param data_args: options object; which attributes are read depends on
                      the mode, modality and split.
    :param model: embedding model, or None.
    :param image_size: the sequence length handed to the encoders is
                       ``image_size ** 2``.
    :param padding_mode: forwarded to the ``'lm'`` encoders.
    :param split: ``'train'``, ``'valid'``, ``'test'``, or ``'debug'``
                  (``'e2e-tgt'`` only). Any other value yields an empty
                  sentence list for the enabled ``'lm'`` modalities.
    :param load_vocab: an existing vocabulary (dict or
                       ``PreTrainedTokenizerFast``), or None to build one.
    :return: ``(dataset, model)``. ``dataset`` is the result of
             ``helper_tokenize_stream`` on the no-cache path and
             ``{'train': list_of_encoded_examples}`` otherwise.
    :raises NotImplementedError: for a modality whose loader is disabled.
    :raises ValueError: for an unknown modality or experiment mode.
    """
    import csv
    import json
    import torch
    from spacy.lang.en import English

    # NOTE(review): machine-specific absolute paths kept exactly as they were;
    # they presumably should come from data_args (e2e_train / checkpoint_path).
    e2e_dir = '/data0/gonghaisong/Diffusion-LM/datasets/e2e_data'
    run_dir = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e'
    standard_splits = ['train', 'valid', 'test']

    if data_args.experiment_mode == 'lm':
        modality = data_args.modality
        if modality in ('roc', 'roc-aug', 'simple-wiki'):
            # These loaders are commented out; without this check the
            # function failed later on an unbound `sentence_lst`.
            raise NotImplementedError(f"modality {modality!r} is disabled in this loader")

        elif modality == 'e2e-tgt':
            print('loading dataset from simple e2e dataset')
            sentence_lst = []
            tokenizer = English().tokenizer
            if split == 'train':
                print('loading form the TRAIN set')
                path = f'{e2e_dir}/src1_train.txt'
            elif split == 'valid':
                print('loading form the VALID set')
                path = f'{e2e_dir}/src1_valid.txt'
            elif split == 'test':
                print('loading form the TEST set')
                path = f'{e2e_dir}/src1_test.txt'
            elif split == 'debug':
                print('loading form the DEBUG set')
                path = data_args.debug_path
                with open(path, 'r') as ff:
                    for line in ff:
                        sentence_lst.append(json.loads(line)[0].split(' '))
                # The debug set is used twice over.
                sentence_lst = sentence_lst + sentence_lst
            if split in standard_splits:
                with open(path, 'r') as ff:
                    for row in ff:
                        # Only the part after '||' is used in 'lm' mode.
                        sentence_lst.append(_spacy_words(tokenizer, row.split('||')[1]))
            print(sentence_lst[:2])

        elif modality == 'yelp':
            print('loading dataset from simple YelpNLG dataset')
            sentence_lst = []
            tokenizer = English().tokenizer
            if split == 'train':
                print('loading form the TRAIN set')
                path = f'{data_args.yelp_train}/yelpnlg-train.csv'
            elif split == 'valid':
                print('loading form the VALID set')
                path = f'{data_args.yelp_train}/yelpnlg-dev.csv'
            elif split == 'test':
                print('loading form the TEST set')
                path = f'{data_args.yelp_train}/yelpnlg-test.csv'
            if split in standard_splits:
                with open(path, 'r') as csvfile:
                    for row in csv.reader(csvfile):
                        sentence_lst.append(_spacy_words(tokenizer, row[1]))
                # Drop the first parsed row (presumably the CSV header).
                sentence_lst = sentence_lst[1:]
            print(sentence_lst[:2])

        elif modality == 'commonGen':
            print('loading dataset from simple YelpNLG dataset')
            sentence_lst = []
            tokenizer = English().tokenizer
            if split == 'train':
                print('loading form the TRAIN set')
                path = f'{data_args.commonGen_train}/commongen.train.jsonl'
            elif split == 'valid':
                print('loading form the VALID set')
                path = f'{data_args.commonGen_train}/commongen.dev.jsonl'
            elif split == 'test':
                print('loading form the TEST set')
                path = f'{data_args.commonGen_train}/commongen.test.jsonl'
            if split in standard_splits:
                sentence_lst.extend(_read_commongen_scenes(path, tokenizer))
            print(sentence_lst[:2])

        elif modality == 'commonGen-aug':
            print('loading dataset from simple YelpNLG dataset')
            sentence_lst = []
            tokenizer = English().tokenizer
            # Default to no augmentation files so an unrecognised split gives
            # an empty corpus, as it does for the other modalities.
            path_lst = []
            if split == 'train':
                print('loading form the TRAIN set')
                path = f'{data_args.commonGen_train}/commongen.train.jsonl'
                path_lst = [f'{data_args.roc_train}/roc_train.json',
                            'diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt']
            elif split == 'valid':
                print('loading form the VALID set')
                path = f'{data_args.commonGen_train}/commongen.dev.jsonl'
            elif split == 'test':
                print('loading form the TEST set')
                path = f'{data_args.commonGen_train}/commongen.test.jsonl'
            if split in standard_splits:
                sentence_lst.extend(_read_commongen_scenes(path, tokenizer))
            print(sentence_lst[:2])
            for aug_path in path_lst:
                is_plain_text = aug_path.endswith('txt')
                with open(aug_path, 'r') as roc_reader:
                    for row in roc_reader:
                        # Plain-text files hold one story per line; the other
                        # files hold a JSON list whose first item is the story.
                        story = row.strip() if is_plain_text else json.loads(row)[0].strip()
                        words = _spacy_words(tokenizer, story)
                        sentence_lst.extend(_split_sentences_on_period(words))
            print(sentence_lst[-2:])

        else:
            raise ValueError(f"unknown modality {modality!r} for experiment_mode 'lm'")

        # get tokenizer.
        if load_vocab is None:
            counter = Counter()
            for input_ids in sentence_lst:
                counter.update(input_ids)

    elif data_args.experiment_mode == 'conditional_gen':
        if data_args.modality != 'e2e':
            raise ValueError(
                f"unknown modality {data_args.modality!r} for experiment_mode 'conditional_gen'"
            )
        print('loading dataset from simple e2e dataset')
        sentence_lst = []
        tokenizer = English().tokenizer
        if split == 'train':
            path = f'{data_args.e2e_train}/src1_train.txt'
            with open(path, 'r') as ff:
                for row in ff:
                    src_lst, word_lst = row.split('||')
                    word_lst = _spacy_words(tokenizer, word_lst)
                    src_lst = _spacy_words(tokenizer, src_lst)
                    sentence_lst.append((src_lst, word_lst))
        elif split == 'valid':
            path = f'{data_args.e2e_train}/src1_valid.txt'
            sentence_lst = read_e2e_files(path, data_args, tokenizer)
        print(sentence_lst[:2])

        # get tokenizer.
        if load_vocab is None:
            counter = Counter()
            for (src_ids, input_ids) in sentence_lst:
                # Targets are counted before sources; the order determines
                # the ids assigned below.
                counter.update(input_ids)
                counter.update(src_ids)

    else:
        raise ValueError(f"unknown experiment_mode {data_args.experiment_mode!r}")

    path_save_vocab = f'{run_dir}/vocab.json'
    if load_vocab is None:
        vocab_dict = {'START': 0, 'END': 1, 'UNK': 2, 'PAD': 3}
        for k, v in counter.items():
            if v > 10:
                vocab_dict[k] = len(vocab_dict)
        print(len(counter), len(vocab_dict))
        print(f'save the vocab to {path_save_vocab}')
        with open(path_save_vocab, 'w') as f:
            json.dump(vocab_dict, f)
    else:
        vocab_dict = load_vocab
        if not os.path.exists(path_save_vocab):
            print(f'save the vocab to {path_save_vocab}')
            if isinstance(vocab_dict, dict):
                with open(path_save_vocab, 'w') as f:
                    json.dump(vocab_dict, f)
                assert vocab_dict['START'] == 0
            elif isinstance(vocab_dict, PreTrainedTokenizerFast):
                vocab_dict.save_pretrained(data_args.checkpoint_path)
            else:
                assert False, "invalid type of vocab_dict"

    if model is None and data_args.experiment == 'random':
        model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
        print('initializing the random embeddings', model)
        torch.nn.init.normal_(model.weight)
        path_save = f'{run_dir}/random_emb.torch'
        # Report the location actually written; the old message named
        # data_args.checkpoint_path, which is not where the file goes.
        print(f'save the random encoder to {path_save}')
        torch.save(model.state_dict(), path_save)

    seqlen = image_size ** 2
    if data_args.experiment_mode == 'lm' and data_args.modality in ['roc-aug', 'roc', 'yelp', 'commonGen', 'commonGen-aug'] \
            and data_args.cache_mode == 'no':
        train_dataset = helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode)
        return train_dataset, model
    elif data_args.experiment_mode == 'lm':
        result_train_lst = helper_tokenize_encode(sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode)
    else:
        result_train_lst = helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args)
    return {'train': result_train_lst}, model


def _spacy_words(tokenizer, text):
    """Return the ``.text`` of every token ``tokenizer`` produces for ``text``."""
    return [tok.text for tok in tokenizer(text)]


def _read_commongen_scenes(path, tokenizer):
    """Tokenize every entry of the ``'scene'`` list of each JSON line in ``path``."""
    import json
    sentences = []
    with open(path, 'r') as ff:
        for line in ff:
            record = json.loads(line)
            for scene in record['scene']:
                sentences.append(_spacy_words(tokenizer, scene))
    return sentences


def _split_sentences_on_period(word_lst):
    """Split tokens into sentences that each end with their '.' token(s).

    Tokens after the last '.' are dropped.
    """
    import itertools
    pieces = [[]]
    for is_period, group in itertools.groupby(word_lst, lambda z: z == '.'):
        pieces[-1].extend(group)
        if is_period:
            pieces.append([])
    return pieces[:-1]
def write_e2e_corr(prompt_lst, file_dict, corr_path):
    """
    Write the references of every prompt to ``corr_path``.

    The number of prompts is printed first. For each prompt, every reference
    in ``file_dict[prompt]`` is written as one space-joined line, and each
    prompt's group is followed by an empty line.

    :param prompt_lst: prompts, in output order; each must be a key of ``file_dict``.
    :param file_dict: mapping from prompt to a list of token sequences.
    :param corr_path: path of the file to (over)write.
    """
    print(len(prompt_lst))
    with open(corr_path, 'w') as out:
        for prompt in prompt_lst:
            out.writelines(" ".join(reference) + "\n" for reference in file_dict[prompt])
            out.write("\n")
def write_e2e_src(prompt_lst, corr_path):
    """
    Write every prompt to ``corr_path`` as one space-joined line.

    :param prompt_lst: iterable of token sequences.
    :param corr_path: path of the file to (over)write.
    """
    with open(corr_path, 'w') as out:
        out.writelines(" ".join(prompt) + "\n" for prompt in prompt_lst)
def read_e2e_files(path, args, tokenizer):
    """
    Read ``source||target`` lines and group the targets by source.

    Each stripped line is split on ``'||'`` into exactly two parts, both of
    which are tokenized. All targets that share a source are collected, the
    grouped references are written to ``<args.out_dir>/1_<args.split>_gold``
    and the distinct sources to ``<args.out_dir>/1_<args.split>_src``.

    :param path: path of the input file.
    :param args: options object; ``out_dir`` and ``split`` are read.
    :param tokenizer: callable returning tokens that expose ``.text``.
    :return: list of ``(source_tokens, first_target_tokens)`` tuples, one per
             distinct source, in order of first appearance.
    """
    targets_by_source = {}
    with open(path, 'r') as f:
        for line in f:
            source_text, target_text = line.strip().split('||')
            target = tuple(tok.text for tok in tokenizer(target_text))
            source = tuple(tok.text for tok in tokenizer(source_text))
            targets_by_source.setdefault(source, []).append(target)

    sources = list(targets_by_source)

    gold_dir = os.path.join(args.out_dir, f'1_{args.split}_gold')
    print("gold dir", gold_dir)
    write_e2e_corr(sources, targets_by_source, gold_dir)

    src_dir = os.path.join(args.out_dir, f'1_{args.split}_src')
    write_e2e_src(sources, src_dir)

    return [(source, targets_by_source[source][0]) for source in sources]
def get_corpus_book(data_args, tokenizer, model, image_size, padding_mode='block', split='train',):
    """
    Load BookCorpus, tokenize it and cut it into fixed-size blocks.

    If the loaded dataset has no ``'validation'`` split, the first 1% of
    ``'train'`` becomes the validation split and the remaining 99% the
    training split. The text is tokenized without special tokens, then
    concatenated and cut into blocks of ``image_size ** 2`` ids. When
    ``model`` is None a randomly initialised embedding is created and its
    state dict is saved to ``<data_args.checkpoint_path>/random_emb.torch``.

    :param data_args: options object; ``preprocessing_num_workers`` is read,
                      and ``training_mode``, ``in_channel`` and
                      ``checkpoint_path`` when ``model`` is None.
    :param tokenizer: tokenizer called on batches of text; its ``len`` sets
                      the embedding size when ``model`` is None.
    :param model: embedding model, or None to create one.
    :param image_size: the block size is ``image_size ** 2``.
    :param padding_mode: must be ``'block'``.
    :param split: ``'train'`` returns the datasets as they are; any other
                  value first replaces the ``'train'`` entry with the
                  ``'validation'`` entry.
    :return: ``(lm_datasets, model)``.
    """
    # `load_dataset` is not imported at module level in this file (that import
    # is commented out), so bring it into scope here.
    from datasets import load_dataset

    max_length = image_size ** 2
    assert padding_mode == 'block'
    raw_datasets = load_dataset('bookcorpus')
    if "validation" not in raw_datasets.keys():
        raw_datasets["validation"] = load_dataset(
            'bookcorpus',
            split="train[:1%]",
        )
        raw_datasets["train"] = load_dataset(
            'bookcorpus',
            split="train[1%:]",
        )
    print(raw_datasets)
    column_names = raw_datasets["train"].column_names

    def tokenize_function(examples):
        output = tokenizer(examples['text'], add_special_tokens=False)
        return output

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
    )
    print(tokenized_datasets)

    block_size = max_length

    # Main data processing function that will concatenate all texts from our
    # dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Keep only whole blocks once there is at least one.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=4,
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )
    print(lm_datasets)

    if model is None:
        if data_args.training_mode.startswith('e2e'):
            print('since its e2e, initialize a dummy embedding' )
            model = torch.nn.Embedding(len(tokenizer), 1)
        else:
            model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
        print('initializing the random embeddings', model)
        torch.nn.init.normal_(model.weight)
        path_save = f'{data_args.checkpoint_path}/random_emb.torch'
        print(f'save the random encoder to {path_save}')
        torch.save(model.state_dict(), path_save)

    if split != 'train':
        lm_datasets['train'] = lm_datasets['validation']
    return lm_datasets, model
class TextDataset(Dataset):
    """
    Map-style dataset over pre-encoded text examples.

    ``text_datasets['train']`` is indexed by position; every example must
    provide ``'hidden_states'`` and ``'input_ids'`` (plus ``'src_ids'`` and
    ``'src_mask'`` in conditional generation with a non-UNet architecture).
    Each item is ``(array, out_dict)``, where ``array`` is the float32 hidden
    state laid out according to ``model_arch``.
    """

    def __init__(self, text_datasets, resolution, data_args, model_arch='conv-unet',
                 classes=None, shard=0, num_shards=1, eigen_transform=None,
                 mapping_func=None, model_emb=None):
        """
        :param text_datasets: mapping with a ``'train'`` entry supporting
                              ``len`` and integer indexing.
        :param resolution: side length used to reshape hidden states for
                           ``'conv-unet'``.
        :param data_args: options object; ``experiment_mode`` and the
                          optional ``noise_level`` are read per item.
        :param model_arch: ``'conv-unet'``, ``'1d-unet'`` or anything else
                           for the untransposed layout.
        :param classes: accepted but not used.
        :param shard: accepted but not used.
        :param num_shards: accepted but not used.
        :param eigen_transform: optional mapping with ``'mean'`` and ``'map'``
                                applied to the flattened hidden state.
        :param mapping_func: stored on the instance; not used by this class.
        :param model_emb: stored on the instance; not used by this class.
        """
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets['train'])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb

    def __len__(self):
        """Return the number of examples in ``text_datasets['train']``."""
        return self.length

    def __getitem__(self, idx):
        """
        Return ``(array, out_dict)`` for the example at ``idx``.

        ``out_dict`` always holds ``'input_ids'``. For ``'conv-unet'`` the
        hidden state is reshaped to ``(resolution, resolution, -1)`` and
        returned with its last axis moved first; for ``'1d-unet'`` its two
        axes are swapped; otherwise it is returned as stored, and
        ``'src_ids'`` / ``'src_mask'`` are added in ``'conditional_gen'`` mode.
        """
        example = self.text_datasets['train'][idx]
        arr = np.array(example['hidden_states'], dtype=np.float32)
        if self.model_arch == 'conv-unet':
            arr = arr.reshape(self.resolution, self.resolution, -1)

        if self.eigen_transform is not None:
            # Centre and project the flattened state, then restore its shape.
            old_shape = arr.shape
            arr = arr.reshape(1, -1) - self.eigen_transform['mean']
            arr = arr @ self.eigen_transform['map']
            arr = arr.reshape(old_shape)

        if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
            arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)

        out_dict = {'input_ids': np.array(example['input_ids'])}

        # These two layouts previously fell through and returned None; they
        # follow the implementation that was left commented out in this class.
        if self.model_arch == 'conv-unet':
            return np.transpose(arr, [2, 0, 1]), out_dict
        if self.model_arch == '1d-unet':
            return np.transpose(arr, [1, 0]), out_dict

        if self.data_args.experiment_mode == 'conditional_gen':
            out_dict['src_ids'] = np.array(example['src_ids'])
            out_dict['src_mask'] = np.array(example['src_mask'])
        return arr, out_dict
class TextDataset_NoCache(Dataset):
    """Dataset that computes embeddings on the fly from token ids.

    Instead of reading pre-computed ``hidden_states`` from the dataset rows,
    every ``__getitem__`` call pushes the row's ``input_ids`` through
    ``model_emb`` and post-processes the result.

    :param text_datasets: mapping with a ``'train'`` entry that supports
        ``len()`` and integer indexing; each row must provide ``'input_ids'``
        (and ``'src_ids'`` / ``'src_mask'`` for conditional generation).
    :param resolution: side length used to fold the sequence into a square
        grid when ``model_arch == 'conv-unet'``.
    :param data_args: namespace read for ``experiment``, ``experiment_mode``,
        ``emb_scale_factor`` and (optionally) ``noise_level``.
    :param model_arch: ``'conv-unet'``, ``'1d-unet'`` or anything else (the
        latter returns the embeddings without transposition).
    :param classes: unused; kept for interface compatibility.
    :param shard: unused; kept for interface compatibility.
    :param num_shards: unused; kept for interface compatibility.
    :param eigen_transform: optional dict with ``'mean'`` and ``'map'``
        applied to the flattened embedding.
    :param mapping_func: stored on the instance; not used by this class.
    :param model_emb: embedding model used to turn token ids into vectors.
    """

    def __init__(self, text_datasets, resolution, data_args, model_arch='conv-unet',
                 classes=None, shard=0, num_shards=1, eigen_transform=None,
                 mapping_func=None, model_emb=None):
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets['train'])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb

    def __len__(self):
        """Return the number of rows in the ``'train'`` split."""
        return self.length

    def _embed(self, input_ids):
        """Embed ``input_ids`` with ``model_emb`` and return a float32 array.

        :raises ValueError: if ``data_args.experiment`` is neither
            ``'random*'`` nor ``'gpt2_pre_compress'``.
        """
        model = self.model_emb
        experiment = self.data_args.experiment
        with torch.no_grad():
            if experiment.startswith('random'):
                hidden_state = model(torch.tensor(input_ids))
            elif experiment == 'gpt2_pre_compress':
                ids = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(ids)
                hidden_state = model.down_proj(input_embs)
                # Must go through ``self``: a bare ``data_args`` is not in
                # scope here and raised NameError.
                hidden_state = hidden_state * self.data_args.emb_scale_factor
            else:
                raise ValueError(
                    'unsupported experiment for on-the-fly embedding: %r' % (experiment,))
        if isinstance(hidden_state, torch.Tensor):
            # NumPy cannot read accelerator memory; move to host first.
            hidden_state = hidden_state.detach().cpu().numpy()
        return np.array(hidden_state, dtype=np.float32)

    def _postprocess(self, arr):
        """Apply the optional eigen transform and additive Gaussian noise."""
        if self.eigen_transform is not None:
            old_shape = arr.shape
            arr = arr.reshape(1, -1) - self.eigen_transform['mean']
            arr = arr @ self.eigen_transform['map']
            arr = arr.reshape(old_shape)
        noise_level = getattr(self.data_args, 'noise_level', 0)
        if noise_level > 0:
            arr = arr + noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
        return arr

    def __getitem__(self, idx):
        """Return ``(embedding_array, out_dict)`` for row ``idx``.

        ``out_dict`` always holds ``'input_ids'``; for architectures other
        than ``'conv-unet'`` / ``'1d-unet'`` it additionally holds
        ``'src_ids'`` and ``'src_mask'`` when
        ``data_args.experiment_mode == 'conditional_gen'``.
        """
        # Fetch the row once; indexing the dataset repeatedly is wasted work.
        example = self.text_datasets['train'][idx]
        arr = self._embed(example['input_ids'])
        out_dict = {'input_ids': np.array(example['input_ids'])}

        if self.model_arch == 'conv-unet':
            # Fold the sequence into a square grid, then move channels first.
            arr = arr.reshape(self.resolution, self.resolution, -1)
            arr = self._postprocess(arr)
            return np.transpose(arr, [2, 0, 1]), out_dict

        arr = self._postprocess(arr)
        if self.model_arch == '1d-unet':
            # (seqlen, dim) -> (dim, seqlen)
            return np.transpose(arr, [1, 0]), out_dict

        if self.data_args.experiment_mode == 'conditional_gen':
            out_dict['src_ids'] = np.array(example['src_ids'])
            out_dict['src_mask'] = np.array(example['src_mask'])
        return arr, out_dict
def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
    """Pad or truncate every example to ``max_length`` and return nested lists.

    :param examples: sequence of token-id sequences (anything sliceable).
    :param pad_token_id: id written into the positions past each example's end.
    :param max_length: length of every returned row; longer examples are cut.
    :param return_mask: if True, also return a 0/1 mask of the same shape with
        1 on real tokens and 0 on padding.
    :return: ``result`` (list of lists), or ``(result, mask)`` when
        ``return_mask`` is True.
    """
    pad = int(pad_token_id)
    result = []
    mask_ = []
    for example in examples:
        kept = list(example[:max_length])
        n_pad = max_length - len(kept)
        result.append(kept + [pad] * n_pad)
        # Padding must be 0 in the mask; it was previously filled with
        # ``pad_token_id``, which marks padding as visible for any non-zero id.
        mask_.append([1] * len(kept) + [0] * n_pad)
    if return_mask:
        return result, mask_
    return result
def _torch_collate_batch(examples, pad_token_id, max_length):
    """Collate ``examples`` into one right-padded ``(len(examples), max_length)`` tensor.

    :param examples: list of 1-D sequences (lists, tuples, NumPy arrays or
        tensors). Non-tensor inputs are converted to ``torch.long`` tensors.
    :param pad_token_id: value used to fill the positions after each example.
    :param max_length: number of columns in the result; examples longer than
        this are truncated on the right.
    :return: tensor with the dtype/device of the first example (``torch.long``
        for an empty batch).
    """
    import numpy as np
    import torch

    if len(examples) == 0:
        return torch.full([0, max_length], pad_token_id, dtype=torch.long)

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    result = examples[0].new_full([len(examples), max_length], pad_token_id)
    for i, example in enumerate(examples):
        # Clip so that over-long examples are truncated instead of raising a
        # shape-mismatch error on assignment.
        length = min(example.shape[0], max_length)
        result[i, :length] = example[:length]
    return result