Spaces:
Running
Running
| # from PIL import Image | |
| # import blobfile as bf | |
| # from mpi4py import MPI | |
| import numpy as np | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoConfig, | |
| AutoTokenizer, | |
| default_data_collator, | |
| PreTrainedTokenizerFast, | |
| PreTrainedTokenizer, | |
| ) | |
| # from datasets import load_dataset | |
| import sys, os | |
| import torch | |
| # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) | |
| # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise | |
| from collections import Counter, defaultdict | |
| from functools import partial | |
| from itertools import chain | |
def load_data_text(
    *,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    data_args=None,
    task_mode="roc",
    model=None,
    padding_mode="block",
    split="train",
    load_vocab=None,
):
    """
    Create an infinite generator over batches of text data.

    Each yielded item is an (arr, kwargs) pair as produced by TextDataset:
    ``arr`` is a float array of precomputed hidden states and ``kwargs``
    holds at least "input_ids" (plus "src_ids"/"src_mask" for conditional
    generation).

    :param data_dir: a dataset directory (unused by the active code paths).
    :param batch_size: the batch size of each returned pair.
    :param image_size: sequence-length parameter forwarded to the corpus
        loader (the loader uses image_size**2 as seqlen).
    :param class_cond: kept for interface compatibility; unused here.
    :param deterministic: if True, yield results in a deterministic order
        (no shuffling).
    :param data_args: experiment configuration; ``experiment``, ``modality``,
        ``cache_mode`` and ``model_arch`` are read here.
    :param task_mode: which dataset family to load; only "e2e-tgt" is active
        in this copy.
    :param model: optional embedding model forwarded to the corpus loader.
    :param padding_mode: "block" or "pad", forwarded to the corpus loader.
    :param split: dataset split ("train"/"valid"/"test"/"debug").
    :param load_vocab: optional pre-built vocabulary to reuse.
    """
    print("hello loading text data. ")

    if data_args.experiment.startswith("random") and model is None:
        # random-embedding experiments build their embedding inside the loader
        model = None

    if task_mode == "roc" or task_mode == "roc-aug":
        pass  # ROC loading is disabled in this copy
    elif task_mode == "simple-wiki":
        pass  # simple-wiki loading is disabled in this copy
    elif task_mode == "e2e-tgt":
        print("hello loading e2e-tgt. ")
        training_data, model = get_corpus_rocstory(
            data_args,
            model,
            image_size,
            padding_mode=padding_mode,
            split=split,
            load_vocab=load_vocab,
        )

    if (
        data_args.modality
        in ["roc-aug", "roc", "book", "yelp", "commonGen", "commonGen-aug"]
        and data_args.cache_mode == "no"
    ):
        # BUG FIX: this branch used to `pass` (TextDataset_NoCache is disabled),
        # leaving `dataset` unbound and raising an obscure NameError below.
        # Fail loudly with an explicit error instead.
        raise NotImplementedError(
            "no-cache dataset loading (TextDataset_NoCache) is disabled in this build"
        )
    else:
        dataset = TextDataset(
            training_data,
            image_size,
            data_args,
            model_arch=data_args.model_arch,
        )

    # BUG FIX: the deterministic branch previously left `data_loader` unbound
    # (its DataLoader construction was commented out), so deterministic=True
    # crashed with NameError. Both orders now share one loader; only the
    # shuffle flag differs.
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=not deterministic,
        num_workers=1,
    )
    while True:
        yield from data_loader
def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
    """Tokenize (src, tgt) word-sequence pairs and precompute target embeddings.

    :param sentence_lst: iterable of (src_ids, input_ids) pairs, each a list
        of word strings.
    :param vocab_dict: dict mapping word -> integer id; must contain "UNK"
        and "PAD" entries (ids 0/1 are used as START/END markers).
    :param model: embedding module used to produce hidden states.
    :param seqlen: fixed pad length for target sequences (and cap for sources).
    :param data_args: config; ``experiment`` selects the embedding path and
        ``emb_scale_factor`` scales gpt2_pre_compress embeddings.
    :return: list of dicts with keys "input_ids", "hidden_states", "src_ids",
        "src_mask".
    """
    result_train_lst = []
    group_lst = defaultdict(list)
    with torch.no_grad():
        for src_ids, input_ids in sentence_lst:
            # map words to ids, falling back to UNK for out-of-vocab tokens
            tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
            tokenized_src = [vocab_dict.get(x, vocab_dict["UNK"]) for x in src_ids]
            input_ids = [0] + tokenized_ + [1]  # wrap with START(0)/END(1)
            group_lst["word_ids"].append(input_ids)
            group_lst["src_ids"].append(tokenized_src)

        print(group_lst["word_ids"][:2])
        print("padding mode is pad")
        max_length = seqlen
        # pad targets to the fixed seqlen (_collate_batch_helper is defined
        # elsewhere in this file)
        group_lst["word_ids"] = _collate_batch_helper(
            group_lst["word_ids"], vocab_dict["PAD"], max_length
        )
        max_src_length = max([len(xx) for xx in group_lst["src_ids"]])
        print(max_src_length, seqlen)
        # sources are padded only up to the longest source, capped at seqlen
        max_src_length = min(seqlen, max_src_length)
        group_lst["src_ids"], group_lst["src_mask"] = _collate_batch_helper(
            group_lst["src_ids"], vocab_dict["PAD"], max_src_length, return_mask=True
        )

        for input_ids, src_ids, src_mask in zip(
            group_lst["word_ids"], group_lst["src_ids"], group_lst["src_mask"]
        ):
            if data_args.experiment.startswith("random"):
                hidden_state = model(torch.tensor(input_ids))
            elif data_args.experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor
            # NOTE(review): for any other `experiment` value, hidden_state is
            # unbound (first iteration) or stale (later iterations) -- confirm
            # callers only use "random*" or "gpt2_pre_compress" here.
            result_train_lst.append(
                {
                    "input_ids": input_ids,
                    "hidden_states": hidden_state.cpu().tolist(),
                    "src_ids": src_ids,
                    "src_mask": src_mask,
                }
            )
    return result_train_lst
def helper_tokenize_stream(
    sentence_lst,
    vocab_dict,
    model,
    seqlen,
    data_args,
    padding_mode,
):
    """Tokenize word sequences into a ``datasets.DatasetDict`` ("train" split).

    Tokenization runs through ``datasets.Dataset.map`` so the corpus is not
    duplicated in Python lists; RAM usage is printed at each stage.

    :param sentence_lst: list of sentences, each a list of word strings.
    :param vocab_dict: either a plain word->id dict (must contain "UNK"/"PAD";
        ids 0/1 are used as START/END) or a PreTrainedTokenizerFast.
    :param model: unused here; kept for signature parity with the sibling
        helpers.
    :param seqlen: block size (padding_mode "block") or pad length otherwise.
    :param data_args: config providing preprocessing_num_workers and
        overwrite_cache.
    :param padding_mode: "block" to concatenate-and-chunk, anything else pads.
    :return: DatasetDict whose "train" split holds "input_ids" (and "labels"
        in block mode).
    """
    import psutil

    # Process.memory_info is expressed in bytes, so convert to megabytes
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
    from datasets import Dataset as Dataset2

    raw_datasets = Dataset2.from_dict({"text": sentence_lst})
    print(raw_datasets)
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    def tokenize_function(examples):
        # Map a batch of word lists to id lists.
        if isinstance(vocab_dict, dict):
            # word-level vocab: map each token, wrap with START(0)/END(1)
            input_ids = [
                [0] + [vocab_dict.get(x, vocab_dict["UNK"]) for x in seq] + [1]
                for seq in examples["text"]
            ]
        elif isinstance(vocab_dict, PreTrainedTokenizerFast):
            # subword tokenizer path: re-join words into plain strings first
            examples["text"] = [" ".join(seq) for seq in examples["text"]]
            input_ids = vocab_dict(examples["text"], add_special_tokens=True)[
                "input_ids"
            ]
        result_dict = {"input_ids": input_ids}
        # clm input could be much much longer than block_size
        return result_dict

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=["text"],
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )
    print(tokenized_datasets)
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    if padding_mode == "block":
        block_size = seqlen

        def group_texts(examples):
            # concatenate every column, then re-split into block_size chunks
            concatenated_examples = {
                k: list(chain(*examples[k])) for k in examples.keys()
            }
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            if total_length >= block_size:
                # drop the trailing remainder so every chunk is full length
                total_length = (total_length // block_size) * block_size
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result

        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )
    else:

        def pad_function(group_lst):
            # pad every sequence to seqlen with the vocab's PAD id
            # (_collate_batch_helper is defined elsewhere in this file)
            max_length = seqlen
            if isinstance(vocab_dict, dict):
                group_lst["input_ids"] = _collate_batch_helper(
                    group_lst["input_ids"], vocab_dict["PAD"], max_length
                )
            else:
                group_lst["input_ids"] = _collate_batch_helper(
                    group_lst["input_ids"], vocab_dict.pad_token_id, max_length
                )
            return group_lst

        # Process.memory_info is expressed in bytes, so convert to megabytes
        print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
        lm_datasets = tokenized_datasets.map(
            pad_function,
            batched=True,
            num_proc=1,
            desc=f"padding",
        )

    print(lm_datasets, "padded dataset")
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
    import datasets

    raw_datasets = datasets.DatasetDict()
    raw_datasets["train"] = lm_datasets
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
    return raw_datasets
def helper_tokenize_encode(
    sentence_lst,
    vocab_dict,
    model,
    seqlen,
    data_args,
    padding_mode,
):
    """Tokenize word sequences and precompute their embedding hidden states.

    :param sentence_lst: iterable of sentences, each a list of word strings.
    :param vocab_dict: word->id dict containing "UNK" and "PAD"; ids 0/1 are
        used as START/END markers.
    :param model: embedding module used to produce hidden states.
    :param seqlen: block size (padding_mode "block") or pad length ("pad").
    :param data_args: config; ``experiment`` selects the embedding path and
        ``emb_scale_factor`` scales gpt2_pre_compress embeddings.
    :param padding_mode: "block" to concatenate-and-chunk, "pad" to pad.
    :return: list of dicts with keys "input_ids" and "hidden_states".
    """
    result_train_lst = []
    group_lst = defaultdict(list)
    with torch.no_grad():
        for input_ids in sentence_lst:
            # map words to ids, falling back to UNK; wrap with START(0)/END(1)
            tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
            input_ids = [0] + tokenized_ + [1]
            group_lst["word_ids"].append(input_ids)
        print(group_lst["word_ids"][:2])

        if padding_mode == "block":
            print("padding mode is block")
            # concatenate everything, then re-split into seqlen-sized chunks
            concatenated_examples = {k: sum(group_lst[k], []) for k in group_lst.keys()}
            total_length = len(concatenated_examples[list(group_lst.keys())[0]])
            block_size = seqlen
            total_length = (total_length // block_size) * block_size
            # Split by chunks of max_len.
            group_lst = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
        elif padding_mode == "pad":
            print("padding mode is pad")
            max_length = seqlen
            # _collate_batch_helper is defined elsewhere in this file
            group_lst["word_ids"] = _collate_batch_helper(
                group_lst["word_ids"], vocab_dict["PAD"], max_length
            )

        for input_ids in group_lst["word_ids"]:
            if data_args.experiment.startswith("random"):
                hidden_state = model(torch.tensor(input_ids))
            elif data_args.experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor
            elif data_args.experiment == "glove":
                hidden_state = model(torch.tensor(input_ids))
            # NOTE(review): other `experiment` values leave hidden_state
            # unbound or stale -- confirm only the three branches above occur.
            result_train_lst.append(
                {"input_ids": input_ids, "hidden_states": hidden_state.cpu().tolist()}
            )
    return result_train_lst
def load_glove_model(File):
    """Parse a GloVe text file into a dict of word -> 1-D float64 torch tensor.

    Each line of the file is expected to be ``word v1 v2 ... vN`` separated
    by whitespace.

    :param File: path to the GloVe embeddings text file.
    :return: dict mapping each word to its embedding vector.
    """
    print("Loading Glove Model")
    glove_model = {}
    with open(File, "r") as fin:
        for raw_line in fin:
            parts = raw_line.split()
            # first field is the word, the rest are its vector components
            vector = torch.tensor(np.array(parts[1:], dtype=np.float64))
            glove_model[parts[0]] = vector
    print(f"{len(glove_model)} words loaded!")
    return glove_model
def load_glove(vocab):
    """Build a 50-d nn.Embedding for `vocab`, initialized from GloVe 6B.50d.

    Words missing from the GloVe file are initialized with random normal
    vectors; the number of such words is printed.

    :param vocab: dict mapping word -> index (insertion order defines rows).
    :return: torch.nn.Embedding of shape (len(vocab), 50).
    """
    model = torch.nn.Embedding(len(vocab), 50)
    glove_model = load_glove_model("predictability/glove/glove.6B.50d.txt")
    vectors = []
    count_ = 0
    for word, idx in vocab.items():
        vec = glove_model.get(word)
        if vec is None:
            # out-of-GloVe word: fall back to a random vector
            count_ += 1
            vec = torch.randn(50)
        vectors.append(vec)
    print(f"{count_} out of {len(vocab)} is initialized. ")
    array_lst = torch.stack(vectors)
    print(torch.norm(array_lst, dim=-1).mean())
    model.weight.data = array_lst
    return model
def get_corpus_rocstory(
    data_args, model, image_size, padding_mode="block", split="train", load_vocab=None
):
    """Load a word-level text corpus, build/restore its vocab, and encode it.

    Despite the name, this handles several modalities (e2e-tgt, yelp,
    commonGen, commonGen-aug, ...) selected via ``data_args.modality`` and
    ``data_args.experiment_mode`` ("lm" or "conditional_gen").

    :param data_args: experiment config (modality, experiment_mode, paths,
        cache_mode, in_channel, checkpoint_path, ...).
    :param model: optional embedding model; created here for the "random"
        experiment when None.
    :param image_size: the encoded sequence length is image_size**2.
    :param padding_mode: "block" or "pad", forwarded to the encode helpers.
    :param split: "train"/"valid"/"test"/"debug".
    :param load_vocab: optional pre-built vocab (dict or tokenizer); when
        None a vocab is built from token counts (> 10) and saved to disk.
    :return: (dataset, model) -- dataset is a DatasetDict for the streaming
        path, else a {"train": [...]} dict of encoded examples.
    """
    import csv, torch, json
    from spacy.lang.en import English

    if data_args.experiment_mode == "lm":
        if data_args.modality == "roc":
            pass
            # (ROCStory loading disabled in this copy.)
        if data_args.modality == "roc-aug":
            pass
            # (Augmented ROCStory loading disabled in this copy.)
        elif data_args.modality == "simple-wiki":
            pass
            # (Simple-wikipedia loading disabled in this copy.)
        elif data_args.modality == "e2e-tgt":
            print("loading dataset from simple e2e dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                # NOTE(review): hard-coded absolute path overrides the
                # data_args.e2e_train-relative path (kept commented below).
                path = (
                    "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt"
                )
                # path = f'../{data_args.e2e_train}/src1_train.txt'
            elif split == "valid":
                print("loading form the VALID set")
                # NOTE(review): the relative path is assigned and then
                # immediately overwritten by the absolute one.
                path = f"../{data_args.e2e_train}/src1_valid.txt"
                path = (
                    "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt"
                )
            elif split == "test":
                print("loading form the TEST set")
                path = f"../{data_args.e2e_train}/src1_test.txt"
                path = "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt"
            elif split == "debug":
                print("loading form the DEBUG set")
                path = data_args.debug_path
                import json

                with open(path, "r") as ff:
                    for line in ff:
                        sentence_lst.append(json.loads(line)[0].split(" "))
                # duplicate the debug set so it is large enough to batch
                sentence_lst = sentence_lst + sentence_lst
            if split in ["train", "valid", "test"]:
                with open(path, "r") as ff:
                    for row in ff:
                        # each line is "attributes||target"; keep the target
                        word_lst = row.split("||")[1]
                        word_lst = [x.text for x in tokenizer(word_lst)]
                        sentence_lst.append(word_lst)
            print(sentence_lst[:2])
        elif data_args.modality == "yelp":
            print("loading dataset from simple YelpNLG dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = f"{data_args.yelp_train}/yelpnlg-train.csv"
            elif split == "valid":
                print("loading form the VALID set")
                path = f"{data_args.yelp_train}/yelpnlg-dev.csv"
            elif split == "test":
                print("loading form the TEST set")
                path = f"{data_args.yelp_train}/yelpnlg-test.csv"
            if split in ["train", "valid", "test"]:
                with open(path, "r") as csvfile:
                    yelp_reader = csv.reader(csvfile)  # delimiter=' ', quotechar='|')
                    for row in yelp_reader:
                        sentences = row[1]
                        word_lst = [x.text for x in tokenizer(sentences)]
                        sentence_lst.append(word_lst)
                # drop the CSV header row
                sentence_lst = sentence_lst[1:]
            print(sentence_lst[:2])
        elif data_args.modality == "commonGen":
            print("loading dataset from simple YelpNLG dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = f"{data_args.commonGen_train}/commongen.train.jsonl"
            elif split == "valid":
                print("loading form the VALID set")
                path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
            elif split == "test":
                print("loading form the TEST set")
                path = f"{data_args.commonGen_train}/commongen.test.jsonl"
            if split in ["train", "valid", "test"]:
                with open(path, "r") as ff:
                    for line in ff:
                        line = json.loads(line)
                        # every reference scene of the example becomes a sentence
                        for sentences in line["scene"]:
                            word_lst = [x.text for x in tokenizer(sentences)]
                            sentence_lst.append(word_lst)
            print(sentence_lst[:2])
        elif data_args.modality == "commonGen-aug":
            print("loading dataset from simple YelpNLG dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = f"{data_args.commonGen_train}/commongen.train.jsonl"
                # training is augmented with ROCStory data loaded below
                path_lst = [f"{data_args.roc_train}/roc_train.json"]
                path_lst.append(
                    "diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt"
                )
            elif split == "valid":
                print("loading form the VALID set")
                path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
                path_lst = []
            elif split == "test":
                print("loading form the TEST set")
                path = f"{data_args.commonGen_train}/commongen.test.jsonl"
                path_lst = []
            if split in ["train", "valid", "test"]:
                with open(path, "r") as ff:
                    for line in ff:
                        line = json.loads(line)
                        for sentences in line["scene"]:
                            word_lst = [x.text for x in tokenizer(sentences)]
                            sentence_lst.append(word_lst)
                print(sentence_lst[:2])
                import itertools

                for path in path_lst:
                    if path.endswith("txt"):
                        with open(path, "r") as roc_reader:
                            for row in roc_reader:
                                sentences = row.strip()
                                word_lst = [x.text for x in tokenizer(sentences)]
                                # split the story into sentences at "." tokens,
                                # keeping the "." with the preceding sentence
                                spl = [[]]
                                for x, y in itertools.groupby(word_lst, lambda z: z == "."):
                                    spl[-1].extend(y)
                                    if x:
                                        spl.append([])
                                sentence_lst.extend(spl[:-1])
                    else:
                        with open(path, "r") as roc_reader:
                            for row in roc_reader:
                                sentences = json.loads(row)[0].strip()
                                word_lst = [x.text for x in tokenizer(sentences)]
                                spl = [[]]
                                for x, y in itertools.groupby(word_lst, lambda z: z == "."):
                                    spl[-1].extend(y)
                                    if x:
                                        spl.append([])
                                sentence_lst.extend(spl[:-1])
                print(sentence_lst[-2:])

        # get tokenizer.
        # NOTE(review): sentence_lst is unbound here for the disabled
        # modalities ("roc", "roc-aug", "simple-wiki") -- confirm they are
        # never reached with experiment_mode == "lm".
        if load_vocab is None:
            counter = Counter()
            for input_ids in sentence_lst:
                counter.update(input_ids)

    if data_args.experiment_mode == "conditional_gen":
        if data_args.modality == "e2e":
            print("loading dataset from simple e2e dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                path = f"{data_args.e2e_train}/src1_train.txt"
                with open(path, "r") as ff:
                    for row in ff:
                        # each line is "attributes||target"
                        src_lst, word_lst = row.split("||")
                        word_lst = [x.text for x in tokenizer(word_lst)]
                        src_lst = [x.text for x in tokenizer(src_lst)]
                        sentence_lst.append((src_lst, word_lst))
            elif split == "valid":
                path = f"{data_args.e2e_train}/src1_valid.txt"
                sentence_lst = read_e2e_files(path, data_args, tokenizer)
            print(sentence_lst[:2])
        # get tokenizer.
        if load_vocab is None:
            counter = Counter()
            for src_ids, input_ids in sentence_lst:
                counter.update(input_ids)
                counter.update(src_ids)

    if load_vocab is None:
        # build the vocab from token counts; rare tokens (<= 10) map to UNK
        vocab_dict = {"START": 0, "END": 1, "UNK": 2, "PAD": 3}
        for k, v in counter.items():
            if v > 10:
                vocab_dict[k] = len(vocab_dict)
        print(len(counter), len(vocab_dict))
        # NOTE(review): hard-coded save location (ignores checkpoint_path).
        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
        print(f"save the vocab to {path_save_vocab}")
        with open(path_save_vocab, "w") as f:
            json.dump(vocab_dict, f)
    else:
        vocab_dict = load_vocab
        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
        if not os.path.exists(path_save_vocab):
            print(f"save the vocab to {path_save_vocab}")
            if isinstance(vocab_dict, dict):
                with open(path_save_vocab, "w") as f:
                    json.dump(vocab_dict, f)
                assert vocab_dict["START"] == 0
            elif isinstance(vocab_dict, PreTrainedTokenizerFast):
                vocab_dict.save_pretrained(data_args.checkpoint_path)
            else:
                assert False, "invalid type of vocab_dict"

    if model is None and data_args.experiment == "random":
        # fresh random embedding table, saved so sampling can reload it
        model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
        print("initializing the random embeddings", model)
        torch.nn.init.normal_(model.weight)
        # NOTE(review): saved to a hard-coded path although the log message
        # mentions data_args.checkpoint_path.
        path_save = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch"
        print(
            f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
        )
        torch.save(model.state_dict(), path_save)
        # path_save = f'{data_args.checkpoint_path}/random_emb.torch'
        # if not os.path.exists(path_save) and data_args.experiment == 'random':
        #     torch.save(model.state_dict(), path_save)

    if (
        data_args.experiment_mode == "lm"
        and data_args.modality
        in ["roc-aug", "roc", "yelp", "commonGen", "commonGen-aug"]
        and data_args.cache_mode == "no"
    ):
        # streaming path: returns a DatasetDict instead of in-memory lists
        train_dataset = helper_tokenize_stream(
            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
        )
        return train_dataset, model
    elif data_args.experiment_mode == "lm":
        result_train_lst = helper_tokenize_encode(
            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
        )
    elif data_args.experiment_mode == "conditional_gen":
        result_train_lst = helper_tokenize_encode_cond(
            sentence_lst, vocab_dict, model, image_size**2, data_args
        )
    return {"train": result_train_lst}, model
def write_e2e_corr(prompt_lst, file_dict, corr_path):
    """Write each prompt's reference sentences to `corr_path`.

    References belonging to one prompt are written one per line (tokens
    joined by spaces), with a blank line separating prompts -- the layout
    expected by the e2e evaluation scripts.

    :param prompt_lst: ordered prompts (keys of file_dict).
    :param file_dict: prompt -> list of token-list references.
    :param corr_path: output file path.
    """
    print(len(prompt_lst))
    with open(corr_path, "w") as out:
        for prompt in prompt_lst:
            for reference in file_dict[prompt]:
                out.write(" ".join(reference) + "\n")
            out.write("\n")
def write_e2e_src(prompt_lst, corr_path):
    """Write each prompt (its tokens joined by spaces) on its own line.

    :param prompt_lst: iterable of token sequences.
    :param corr_path: output file path.
    """
    lines = (" ".join(tokens) + "\n" for tokens in prompt_lst)
    with open(corr_path, "w") as out:
        out.writelines(lines)
    return
def read_e2e_files(path, args, tokenizer):
    """Read an e2e "src||tgt" file and dump gold/src files for evaluation.

    Groups all targets by their source, writes the grouped gold references
    and the source list under ``args.out_dir``, and returns one
    (src, first_tgt) pair per distinct source.

    :param path: input file, one "attributes||target" pair per line.
    :param args: needs ``out_dir`` and ``split`` attributes.
    :param tokenizer: spaCy-style tokenizer yielding objects with ``.text``.
    :return: list of (src_tokens, tgt_tokens) tuples, one per source.
    """
    file_dict = {}
    with open(path, "r") as fin:
        for raw_line in fin:
            src_txt, tgt_txt = raw_line.strip().split("||")
            tgt = tuple(tok.text for tok in tokenizer(tgt_txt))
            src = tuple(tok.text for tok in tokenizer(src_txt))
            file_dict.setdefault(src, []).append(tgt)
    temp = "1"
    prompt_text_lst = list(file_dict.keys())
    gold_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "gold"))
    print("gold dir", gold_dir)
    write_e2e_corr(prompt_text_lst, file_dict, gold_dir)
    src_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "src"))
    write_e2e_src(prompt_text_lst, src_dir)
    # keep only the first reference per source for the returned pairs
    return [(src, file_dict[src][0]) for src in prompt_text_lst]
def get_corpus_book(
    data_args,
    tokenizer,
    model,
    image_size,
    padding_mode="block",
    split="train",
):
    """Load the `bookcorpus` dataset, tokenize it, and chunk into blocks.

    NOTE(review): `load_dataset` is NOT in scope here -- the
    `from datasets import load_dataset` import at the top of this file is
    commented out, so calling this function raises NameError unless that
    import is restored elsewhere.

    :param data_args: config providing preprocessing_num_workers,
        training_mode, in_channel, checkpoint_path.
    :param tokenizer: HuggingFace tokenizer used for subword tokenization.
    :param model: optional embedding; created (and saved) here when None.
    :param image_size: block size is image_size**2.
    :param padding_mode: must be "block" (asserted).
    :param split: "train" returns the dataset as-is; any other value swaps
        the validation split in as "train".
    :return: (lm_datasets, model) tuple.
    """
    max_length = image_size**2
    import os

    assert padding_mode == "block"
    raw_datasets = load_dataset("bookcorpus")
    if "validation" not in raw_datasets.keys():
        # carve 1% of the train split off as a validation split
        raw_datasets["validation"] = load_dataset(
            "bookcorpus",
            split=f"train[:1%]",
        )
        raw_datasets["train"] = load_dataset(
            "bookcorpus",
            split=f"train[1%:]",
        )
    print(raw_datasets)
    column_names = raw_datasets["train"].column_names

    def tokenize_function(examples):
        # no special tokens: sequences are concatenated and re-chunked below
        output = tokenizer(examples["text"], add_special_tokens=False)
        return output

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
    )
    print(tokenized_datasets)
    block_size = max_length

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        if total_length >= block_size:
            # drop the trailing remainder so every chunk is full length
            total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=4,
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )
    print(lm_datasets)

    if model is None:
        if data_args.training_mode.startswith("e2e"):
            # e2e training learns its own embedding; a 1-dim dummy suffices here
            print("since its e2e, initialize a dummy embedding")
            model = torch.nn.Embedding(len(tokenizer), 1)
        else:
            model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
            print("initializing the random embeddings", model)
            torch.nn.init.normal_(model.weight)
            path_save = f"{data_args.checkpoint_path}/random_emb.torch"
            print(
                f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
            )
            torch.save(model.state_dict(), path_save)

    if split == "train":
        return lm_datasets, model
    else:
        lm_datasets["train"] = lm_datasets["validation"]
        return lm_datasets, model
class TextDataset(Dataset):
    """Map-style dataset over precomputed (hidden_states, input_ids) examples.

    ``text_datasets`` must be dict-like with a "train" split whose items
    carry "hidden_states" and "input_ids" (plus "src_ids"/"src_mask" for
    conditional generation). ``__getitem__`` returns
    (hidden_states ndarray, dict of auxiliary arrays).
    """

    def __init__(
        self,
        text_datasets,
        resolution,
        data_args,
        model_arch="conv-unet",
        classes=None,
        shard=0,
        num_shards=1,
        eigen_transform=None,
        mapping_func=None,
        model_emb=None,
    ):
        """Store the dataset and per-item transform configuration.

        :param text_datasets: dict-like container with a "train" split.
        :param resolution: sequence-length parameter (printed; used by the
            disabled conv-unet reshape path).
        :param data_args: config; ``noise_level`` and ``experiment_mode``
            are read in __getitem__.
        :param model_arch: selects the item layout; only the default
            (non-unet) branch is active in this copy.
        :param classes, shard, num_shards: accepted for interface parity
            with the image datasets; unused here.
        :param eigen_transform: optional dict with "mean" and "map" applied
            to flattened hidden states in __getitem__.
        :param mapping_func, model_emb: stored but unused in this copy.
        """
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets["train"])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb
        # self.local_images = image_paths[shard:][::num_shards]
        # self.local_classes = None if classes is None else classes[shard:][::num_shards]

    def __len__(self):
        # number of examples in the "train" split
        return self.length

    def __getitem__(self, idx):
        # We are not on a new enough PIL to support the `reducing_gap`
        # argument, which uses BOX downsampling at powers of two first.
        # Thus, we do it by hand to improve downsample quality.
        if self.model_arch == "conv-unet":
            # NOTE(review): branch disabled -- implicitly returns None.
            pass
        elif self.model_arch == "1d-unet":
            # NOTE(review): branch disabled -- implicitly returns None.
            pass
        else:
            arr = np.array(
                self.text_datasets["train"][idx]["hidden_states"], dtype=np.float32
            )
            if self.eigen_transform is not None:
                # center then project the flattened hidden states, keeping shape
                old_shape = arr.shape
                # arr = arr.reshape(1, -1) @ self.eigen_transform
                arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
                arr = arr @ self.eigen_transform["map"]
                arr = arr.reshape(old_shape)
            if (
                hasattr(self.data_args, "noise_level")
                and self.data_args.noise_level > 0
            ):
                # additive Gaussian noise augmentation on the hidden states
                arr = arr + self.data_args.noise_level * np.random.randn(
                    *arr.shape
                ).astype(arr.dtype)
            out_dict = {}
            out_dict["input_ids"] = np.array(
                self.text_datasets["train"][idx]["input_ids"]
            )
            # out_dict['mapping_func'] = self.mapping_func
            if self.data_args.experiment_mode == "conditional_gen":
                out_dict["src_ids"] = np.array(
                    self.text_datasets["train"][idx]["src_ids"]
                )
                out_dict["src_mask"] = np.array(
                    self.text_datasets["train"][idx]["src_mask"]
                )
            return arr, out_dict
class TextDataset_NoCache(Dataset):
    """Dataset that re-embeds token ids on every ``__getitem__`` call instead
    of reading precomputed hidden states.

    Each item is a ``(features, out_dict)`` pair where ``features`` is a
    float32 numpy array of embedded tokens (layout depends on ``model_arch``)
    and ``out_dict`` carries the raw ``input_ids`` (plus ``src_ids`` /
    ``src_mask`` when ``experiment_mode == "conditional_gen"``).
    """

    def __init__(
        self,
        text_datasets,
        resolution,
        data_args,
        model_arch="conv-unet",
        classes=None,
        shard=0,
        num_shards=1,
        eigen_transform=None,
        mapping_func=None,
        model_emb=None,
    ):
        """
        :param text_datasets: mapping with a ``"train"`` split; each entry is a
            dict holding at least ``"input_ids"``.
        :param resolution: side length used to reshape sequences for the
            conv-unet architecture.
        :param data_args: experiment configuration (``experiment``,
            ``experiment_mode``, optional ``noise_level`` /
            ``emb_scale_factor`` are read here).
        :param model_arch: one of ``"conv-unet"``, ``"1d-unet"``, or anything
            else for the flat (seqlen, dim) layout.
        :param eigen_transform: optional dict with ``"mean"`` and ``"map"``
            used to linearly project the embeddings.
        :param model_emb: embedding module used to map token ids to vectors.
        """
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets["train"])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb
        # classes / shard / num_shards are currently unused (kept for
        # signature compatibility with the cached variant).

    def __len__(self):
        return self.length

    def _embed(self, input_ids):
        """Map token ids to embedding vectors according to the experiment type."""
        model = self.model_emb
        if self.data_args.experiment.startswith("random"):
            return model(torch.tensor(input_ids))
        elif self.data_args.experiment == "gpt2_pre_compress":
            ids = torch.tensor(input_ids).to(model.device)
            embs = model.transformer.wte(ids)  # token embeddings
            hidden = model.down_proj(embs)
            # BUG FIX: original referenced the bare name ``data_args`` here
            # (a NameError at runtime); the scale factor lives on
            # self.data_args.
            return hidden * self.data_args.emb_scale_factor
        # Original fell through with ``hidden_state`` unbound, which crashed
        # later with UnboundLocalError; fail explicitly instead.
        raise ValueError(
            f"unsupported experiment: {self.data_args.experiment!r}"
        )

    def _apply_eigen_and_noise(self, arr):
        """Optionally project ``arr`` with the eigen transform, then add
        Gaussian noise scaled by ``data_args.noise_level`` (if set)."""
        if self.eigen_transform is not None:
            old_shape = arr.shape
            arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
            arr = arr @ self.eigen_transform["map"]
            arr = arr.reshape(old_shape)
        if (
            hasattr(self.data_args, "noise_level")
            and self.data_args.noise_level > 0
        ):
            arr = arr + self.data_args.noise_level * np.random.randn(
                *arr.shape
            ).astype(arr.dtype)
        return arr

    def __getitem__(self, idx):
        # Pure data preparation: embed fresh each call, with autograd off.
        with torch.no_grad():
            input_ids = self.text_datasets["train"][idx]["input_ids"]
            hidden_state = self._embed(input_ids)
            if self.model_arch == "conv-unet":
                # (res, res, dim) image layout, returned channels-first.
                arr = np.array(hidden_state, dtype=np.float32).reshape(
                    self.resolution, self.resolution, -1
                )
                arr = self._apply_eigen_and_noise(arr)
                out_dict = {"input_ids": np.array(input_ids)}
                return np.transpose(arr, [2, 0, 1]), out_dict
            elif self.model_arch == "1d-unet":
                arr = np.array(hidden_state, dtype=np.float32)  # (seqlen, dim)
                arr = self._apply_eigen_and_noise(arr)
                arr = np.transpose(arr, [1, 0])  # channels-first: (dim, seqlen)
                out_dict = {"input_ids": np.array(input_ids)}
                return arr, out_dict
            else:
                # Transformer-style flat (seqlen, dim) layout.
                arr = np.array(hidden_state, dtype=np.float32)
                arr = self._apply_eigen_and_noise(arr)
                out_dict = {"input_ids": np.array(input_ids)}
                if self.data_args.experiment_mode == "conditional_gen":
                    out_dict["src_ids"] = np.array(
                        self.text_datasets["train"][idx]["src_ids"]
                    )
                    out_dict["src_mask"] = np.array(
                        self.text_datasets["train"][idx]["src_mask"]
                    )
                return arr, out_dict
| def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False): | |
| result = torch.full( | |
| [len(examples), max_length], pad_token_id, dtype=torch.int64 | |
| ).tolist() | |
| mask_ = torch.full( | |
| [len(examples), max_length], pad_token_id, dtype=torch.int64 | |
| ).tolist() | |
| for i, example in enumerate(examples): | |
| curr_len = min(len(example), max_length) | |
| result[i][:curr_len] = example[:curr_len] | |
| mask_[i][:curr_len] = [1] * curr_len | |
| if return_mask: | |
| return result, mask_ | |
| return result | |
| def _torch_collate_batch(examples, pad_token_id, max_length): | |
| """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" | |
| import numpy as np | |
| import torch | |
| # Tensorize if necessary. | |
| if isinstance(examples[0], (list, tuple, np.ndarray)): | |
| examples = [torch.tensor(e, dtype=torch.long) for e in examples] | |
| # length_of_first = examples[0].size(0) | |
| # Check if padding is necessary. | |
| # are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) | |
| # if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): | |
| # return torch.stack(examples, dim=0) | |
| # Creating the full tensor and filling it with our data. | |
| # max_length = max(x.size(0) for x in examples) | |
| # if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): | |
| # max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of | |
| result = examples[0].new_full([len(examples), max_length], pad_token_id) | |
| for i, example in enumerate(examples): | |
| if True: | |
| result[i, : example.shape[0]] = example | |
| else: | |
| result[i, -example.shape[0] :] = example | |
| return result | |