# ConceptNet generation data loading utilities for COMET.
import random

import torch
from tqdm import tqdm

import comet.src.data.atomic as adata
import comet.src.data.config as cfg
import comet.src.data.utils as data_utils
def map_name(name, opt):
    """Map a split name ("train"/"dev"/"test") to its raw data file name.

    Train and dev file names are parameterized by opt.trainsize and
    opt.devversion respectively; the test file name is fixed.
    """
    if name == "test":
        return "test.txt"
    if name == "train":
        return "train{}k.txt".format(opt.trainsize)
    return "dev{}.txt".format(opt.devversion)
# The 34 ConceptNet relation types present in the raw tuple files, in the
# canonical order used throughout this module.
conceptnet_relations = [
    "AtLocation", "CapableOf", "Causes", "CausesDesire", "CreatedBy",
    "DefinedAs", "DesireOf", "Desires", "HasA", "HasFirstSubevent",
    "HasLastSubevent", "HasPainCharacter", "HasPainIntensity",
    "HasPrerequisite", "HasProperty", "HasSubevent", "InheritsFrom",
    "InstanceOf", "IsA", "LocatedNear", "LocationOfAction", "MadeOf",
    "MotivatedByGoal", "NotCapableOf", "NotDesires", "NotHasA",
    "NotHasProperty", "NotIsA", "NotMadeOf", "PartOf", "ReceivesAction",
    "RelatedTo", "SymbolOf", "UsedFor",
]
# Natural-language surface forms for each ConceptNet relation, used when
# opt.data.rel == "language" to render a relation as words for the encoder.
# NOTE(review): 'HasPrerequisite' maps to "has prequisite" (sic) — almost
# certainly a typo for "has prerequisite", but the exact string is preserved
# because any vocabulary/model built from this table depends on it byte-for-
# byte; confirm against the released pretrained artifacts before fixing.
split_into_words = {
    'AtLocation': "at location",
    'CapableOf': "capable of",
    'Causes': "causes",
    'CausesDesire': "causes desire",
    'CreatedBy': "created by",
    'DefinedAs': "defined as",
    'DesireOf': "desire of",
    'Desires': "desires",
    'HasA': "has a",
    'HasFirstSubevent': "has first subevent",
    'HasLastSubevent': "has last subevent",
    'HasPainCharacter': "has pain character",
    'HasPainIntensity': "has pain intensity",
    'HasPrerequisite': "has prequisite",
    'HasProperty': "has property",
    'HasSubevent': "has subevent",
    'InheritsFrom': "inherits from",
    'InstanceOf': 'instance of',
    'IsA': "is a",
    'LocatedNear': "located near",
    'LocationOfAction': "location of action",
    'MadeOf': "made of",
    'MotivatedByGoal': "motivated by goal",
    'NotCapableOf': "not capable of",
    'NotDesires': "not desires",
    'NotHasA': "not has a",
    'NotHasProperty': "not has property",
    'NotIsA': "not is a",
    'NotMadeOf': "not made of",
    'PartOf': "part of",
    'ReceivesAction': "receives action",
    'RelatedTo': "related to",
    'SymbolOf': "symbol of",
    'UsedFor': "used for"
}
class GenerationDataLoader(adata.DataLoader):
    """Loader that turns ConceptNet (relation, e1, e2, label) tuples into
    fixed-width token-id tensors laid out per row as [e1 | relation | e2].

    Splits are "train", "dev" and "test"; dev/test additionally expose
    "positive"/"negative" categories keyed on the integer label column.
    """

    def __init__(self, opt, categories=None):
        # NOTE(review): `categories` is unused here; presumably kept for
        # signature compatibility with sibling loaders — confirm at callers.
        super(GenerationDataLoader, self).__init__(opt)
        self.opt = opt

        # Every split starts with a single "total" bucket; dev/test later
        # also get "positive"/"negative" buckets in load_data().
        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        # All of the following are populated by make_tensors().
        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_e1 = None  # widest encoded e1 across all splits
        self.max_e2 = None  # widest encoded e2 (incl. <END>) across splits
        self.max_r = None   # widest encoded relation across all splits

    def offset_summary(self, split):
        """Sum of the sampling offsets over every category of `split`."""
        return sum(self.offsets[split].values())

    def load_data(self, path):
        """Populate self.data from `path`.

        If `path` contains ".pickle", restore a previously serialized
        loader and return True. Otherwise read the raw tab-separated split
        files into self.data and return False (tensors are built later by
        make_tensors()).
        """
        if ".pickle" in path:
            print("Loading data from: {}".format(path))
            data_utils.load_existing_data_loader(self, path)
            return True

        for split in self.data:
            file_name = map_name(split, self.opt.data)

            # devversion == "12" means: use the union of dev1 and dev2.
            if split != "dev" or self.opt.data.devversion != "12":
                string_tuples = open("{}/{}".format(
                    path, file_name), "r").read().split("\n")
                tuples = [x.split("\t") for x in string_tuples if x]
            else:
                string_tuples = open("{}/{}".format(
                    path, "dev1.txt"), "r").read().split("\n")
                tuples = [x.split("\t") for x in string_tuples if x]
                string_tuples = open("{}/{}".format(
                    path, "dev2.txt"), "r").read().split("\n")
                tuples += [x.split("\t") for x in string_tuples if x]

            # Raw tuple column layout: [relation, e1, e2, label].
            # opt.data.rel selects how the relation is rendered:
            #   "language" -> natural-language phrase (split_into_words)
            #   "relation" -> special vocab token, e.g. "<UsedFor>"
            if split in ["dev", "test"]:
                if self.opt.data.rel == "language":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples]
                    self.data[split]["positive"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples if int(i[3])]
                    self.data[split]["negative"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples if not int(i[3])]
                elif self.opt.data.rel == "relation":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), int(i[3])) for i in tuples]
                    self.data[split]["positive"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), int(i[3])) for i in tuples if int(i[3])]
                    self.data[split]["negative"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), int(i[3])) for i in tuples if not int(i[3])]
            else:
                # Train tuples keep the raw label string (no int() cast).
                if self.opt.data.rel == "language":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), i[3]) for i in tuples]
                elif self.opt.data.rel == "relation":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), i[3]) for i in tuples]
        return False

    def make_tensors(self, text_encoder, special,
                     splits=["train", "dev", "test"], test=False):
        """BPE-encode all loaded tuples and pack them into padded tensors.

        Fills self.sequences[split]["total"] with a LongTensor of shape
        (n, max_e1 + max_r + max_e2), zero-padded, plus positive/negative
        row selections for dev/test.
        NOTE(review): mutable default for `splits` — harmless here because
        the list is never mutated, but worth keeping in mind.
        """
        self.vocab_encoder = text_encoder.encoder
        self.vocab_decoder = text_encoder.decoder
        self.special_chars = special

        sequences = {}
        for split in splits:
            sequences[split], discarded = get_generation_sequences(
                self.data, split, text_encoder, test, self.opt.data.maxe1,
                self.opt.data.maxe2)

            # Keep data[...]["total"] index-aligned with the surviving
            # sequences (only train examples are discarded for length).
            if split == "train":
                self.data[split]["total"] = [j for i, j in enumerate(
                    self.data[split]["total"]) if i not in set(discarded)]
            self.masks[split]["total"] = [(len(i[0]), len(i[1]), len(i[2])) for
                                          i in sequences[split]]

        # Global (cross-split) max widths; the comprehension variable
        # `split` deliberately shadows the loop variable above.
        self.max_e1 = max([max([l[0] for l in self.masks[split]["total"]])
                           for split in self.masks])
        self.max_r = max([max([l[1] for l in self.masks[split]["total"]])
                          for split in self.masks])
        self.max_e2 = max([max([l[2] for l in self.masks[split]["total"]])
                           for split in self.masks])

        print(self.max_e1)
        print(self.max_r)
        print(self.max_e2)

        for split in splits:
            num_elements = len(sequences[split])
            self.sequences[split]["total"] = torch.LongTensor(
                num_elements, self.max_e1 + self.max_e2 + self.max_r).fill_(0)

            # Row layout: [e1 | r | e2], each field left-aligned within its
            # max_e1 / max_r / max_e2 slot and zero-padded on the right.
            for i, seq in enumerate(sequences[split]):
                # print(self.sequences[split]["total"][i, :len(seq[0])].size())
                # print(torch.FloatTensor(seq[0]).size())
                self.sequences[split]["total"][i, :len(seq[0])] = \
                    torch.LongTensor(seq[0])
                start_r = self.max_e1
                end_r = self.max_e1 + len(seq[1])
                self.sequences[split]["total"][i, start_r:end_r] = \
                    torch.LongTensor(seq[1])
                start_e2 = self.max_e1 + self.max_r
                end_e2 = self.max_e1 + self.max_r + len(seq[2])
                self.sequences[split]["total"][i, start_e2:end_e2] = \
                    torch.LongTensor(seq[2])

            if split in ["test", "dev"]:
                print(split)
                # Rows are selected positionally, so data[...]["total"]
                # must be index-aligned with the "total" tensor here.
                self.sequences[split]["negative"] = \
                    self.sequences[split]["total"].index_select(
                        0, torch.LongTensor([i for i, j in enumerate(
                            self.data[split]['total']) if not j[3]]))
                # self.data[split]['total'][:self.sequences[split]["total"].size(0)]) if not j[3]]))
                self.sequences[split]["positive"] = \
                    self.sequences[split]["total"].index_select(
                        0, torch.LongTensor([i for i, j in enumerate(
                            self.data[split]['total']) if j[3]]))
                # self.data[split]['total'][:self.sequences[split]["total"].size(0)]) if j[3]]))

    def sample_batch(self, split, bs, cat="total", idxs=None):
        """Return (batch, done) for the next `bs` rows of `split`/`cat`.

        `batch` holds "sequences", "attention_mask", "loss_mask" and a
        "key" tuple (cat, start, end); `done` is True when the epoch should
        end. If `idxs` is given, those explicit rows are selected instead
        of the rolling offset window (the offset still advances).
        """
        offset = self.offsets[split][cat]

        batch = {}

        # Decided not to reduce computation on here because it's all parallel
        # anyway and we don't want to run out of memory in cases where we
        # don't see the longest version quickly enough

        if idxs:
            seqs = self.sequences[split][cat].index_select(
                0, torch.LongTensor(idxs).to(
                    self.sequences[split][cat].device))
        else:
            seqs = self.sequences[split][cat][offset:offset + bs]
        batch["sequences"] = seqs.to(cfg.device)
        batch["attention_mask"] = make_attention_mask(seqs)
        batch["loss_mask"] = make_loss_mask(seqs, self.max_e1 + self.max_r)
        batch["key"] = (cat, offset, offset + bs)

        offset += seqs.size(0)

        self.offsets[split][cat] = offset

        # Train ends one batch early so a short final batch is never seen;
        # eval splits end only once every row has been consumed.
        if split == "train" and offset + bs > len(self.sequences[split][cat]):
            return batch, True
        elif offset >= len(self.sequences[split][cat]):
            return batch, True
        else:
            return batch, False

    def reset_offsets(self, splits=["train", "test", "dev"],
                      shuffle=True, keys=None):
        """Zero the sampling offsets for `splits` and optionally reshuffle.

        NOTE(review): when `keys` is None it is assigned inside the split
        loop and then reused for the remaining splits — same end result,
        just a slightly surprising place for the default.
        """
        if isinstance(splits, str):
            splits = [splits]

        for split in splits:
            if keys is None:
                keys = ["total", "positive", "negative"]

            for key in keys:
                self.offsets[split][key] = 0

            if shuffle:
                self.shuffle_sequences(split, keys)

    def shuffle_sequences(self, split="train", keys=None):
        """Apply one random permutation jointly to sequences/data/masks.

        "positive"/"negative" categories are skipped: they are row
        selections of "total" and are not re-derived here.
        """
        if keys is None:
            # print(type(self.data))
            # print(type(self.data.keys()))
            keys = self.data[split].keys()

        for key in keys:
            if key in ["positive", "negative"]:
                continue
            idxs = list(range(len(self.data[split][key])))

            random.shuffle(idxs)

            # The same permutation is applied to all three parallel
            # structures so they stay index-aligned.
            self.sequences[split][key] = \
                self.sequences[split][key].index_select(
                    0, torch.LongTensor(idxs))

            temp = [self.data[split][key][i] for i in idxs]
            self.data[split][key] = temp

            temp = [self.masks[split][key][i] for i in idxs]
            self.masks[split][key] = temp
def make_attention_mask(sequences):
    """Float mask over `sequences`: 1.0 at real tokens, 0.0 at zero padding."""
    nonpad = sequences.ne(0)
    return nonpad.float().to(cfg.device)
def make_loss_mask(sequences, max_event):
    """Loss mask for `sequences`: zero over padding AND over the first
    `max_event` columns (the e1+relation prefix is context, not a target),
    then shifted left by one to line up with next-token predictions.
    """
    # print(sequences.size())
    target_mask = sequences.ne(0).float()
    target_mask[:, :max_event] = 0
    shifted = target_mask[:, 1:]
    return shifted.to(cfg.device)
def get_generation_sequences(data, split, text_encoder, test,
                             max_e1=10, max_e2=15):
    """Encode every (e1, relation, e2, label) tuple of `split`.

    Returns (sequences, discarded): `sequences` is a list of
    [e1_ids, r_ids, e2_ids + <END>] triples; `discarded` lists the indices
    of training examples dropped for exceeding max_e1/max_e2.
    If `test` is truthy, stops after 11 examples (quick smoke-test mode).
    """
    sequences = []
    discarded = []
    count = 0
    for event1, relation, event2, _ in tqdm(data[split]["total"]):
        e1, r, e2 = do_example(text_encoder, event1, relation, event2)

        # Over-length filtering applies to the training split only.
        # BUG FIX: the original condition
        #     split == "train" and len(e1) > max_e1 or len(e2) > max_e2
        # parsed as (split=="train" and e1-too-long) OR e2-too-long, so
        # long-e2 dev/test examples were dropped from `sequences` while
        # staying in data[split]["total"], misaligning the positional
        # positive/negative selection in make_tensors().
        if split == "train" and (len(e1) > max_e1 or len(e2) > max_e2):
            discarded.append(count)
            count += 1
            continue

        final = compile_final_sequence(e1, e2, r, text_encoder)
        sequences.append(final)

        count += 1

        if count > 10 and test:
            break

    return sequences, discarded
def do_example(text_encoder, event1, relation, event2):
    """BPE-encode one (event1, relation, event2) triple.

    Returns (e1_ids, r_ids, e2_ids); e2_ids is None when event2 is None.
    A relation containing any uppercase character is treated as a special
    vocab token (e.g. "<UsedFor>") and looked up directly; otherwise it is
    encoded as ordinary text.
    """
    enc_e1 = text_encoder.encode([event1], verbose=False)[0]

    if relation == relation.lower():
        enc_r = text_encoder.encode([relation], verbose=False)[0]
    else:
        enc_r = [text_encoder.encoder[relation]]

    if event2 is None:
        enc_e2 = None
    else:
        enc_e2 = text_encoder.encode([event2], verbose=False)[0]

    return enc_e1, enc_r, enc_e2
def compile_final_sequence(final_event1, final_event2, final_relation, text_encoder):
    """Assemble the model input triple [e1_ids, r_ids, e2_ids + <END>].

    Note the argument order (e1, e2, r) differs from the output order
    (e1, r, e2) — callers pass do_example's results as (e1, e2, r).

    BUG FIX: the original appended the <END> id onto `final_event2`
    in place, mutating the caller's list; a copy is appended to instead.
    """
    end_id = text_encoder.encoder["<END>"]
    return [final_event1, final_relation, final_event2 + [end_id]]