# NOTE(review): the three lines below were non-code artifacts left by an
# extraction/grading tool ("Spaces:", "Runtime error"); kept as comments so
# the file parses as Python.
# Spaces:
# Runtime error
# Runtime error
#!/usr/bin/env python3
# coding=utf-8
import pickle

import torch

from data.parser.from_mrp.node_centric_parser import NodeCentricParser
from data.parser.from_mrp.labeled_edge_parser import LabeledEdgeParser
from data.parser.from_mrp.sequential_parser import SequentialParser
from data.parser.from_mrp.evaluation_parser import EvaluationParser
from data.parser.from_mrp.request_parser import RequestParser
from data.field.edge_field import EdgeField
from data.field.edge_label_field import EdgeLabelField
from data.field.field import Field
from data.field.mini_torchtext.field import Field as TorchTextField
from data.field.label_field import LabelField
from data.field.anchored_label_field import AnchoredLabelField
from data.field.nested_field import NestedField
from data.field.basic_field import BasicField
from data.field.bert_field import BertField
from data.field.anchor_field import AnchorField
from data.batch import Batch
def char_tokenize(word):
    """Return *word* split into a list of its individual characters."""
    # A per-word character budget was once considered here but is disabled,
    # so this is a plain character split.
    return list(word)
class Collate:
    """Collate function for DataLoaders: batches a list of parsed examples.

    Examples are ordered longest-first (by the length of their
    ``every_input`` tensor) before being merged into a single batch.
    """

    def __call__(self, batch):
        def example_length(example):
            return example["every_input"][0].size(0)

        # Sort in place, longest sentence first, then delegate the actual
        # tensor padding/stacking to Batch.build.
        batch.sort(key=example_length, reverse=True)
        return Batch.build(batch)
class Dataset:
    """Holds every data field, builds vocabularies, and exposes DataLoaders.

    ``__init__`` constructs the torchtext-style fields and immediately calls
    ``load_dataset``, which parses the training/validation/test splits,
    builds vocabularies and label/edge statistics, and wraps each split in a
    ``torch.utils.data.DataLoader`` (``self.train`` / ``self.val`` /
    ``self.test``).
    """

    def __init__(self, args, verbose=True):
        self.verbose = verbose

        # Special tokens shared by the word- and character-level vocabularies.
        self.sos, self.eos, self.pad, self.unk = "<sos>", "<eos>", "<pad>", "<unk>"

        self.bert_input_field = BertField()
        self.scatter_field = BasicField()
        self.every_word_input_field = Field(lower=True, init_token=self.sos, eos_token=self.eos, batch_first=True, include_lengths=True)

        # Character-level view of each word, nested inside the word sequence.
        char_form_nesting = TorchTextField(tokenize=char_tokenize, init_token=self.sos, eos_token=self.eos, batch_first=True)
        self.char_form_field = NestedField(char_form_nesting, include_lengths=True)

        # Node labels: keep only each node's raw "label" string.
        self.label_field = LabelField(preprocessing=lambda nodes: [n["label"] for n in nodes])
        self.anchored_label_field = AnchoredLabelField()

        # Sentence ids are treated as single-token "sentences".
        self.id_field = Field(batch_first=True, tokenize=lambda x: [x])
        self.edge_presence_field = EdgeField()
        self.edge_label_field = EdgeLabelField()
        self.anchor_field = AnchorField()
        self.source_anchor_field = AnchorField()
        self.target_anchor_field = AnchorField()
        self.token_interval_field = BasicField()

        self.load_dataset(args)

    def log(self, text):
        """Print *text* immediately (flushed), unless verbosity is off."""
        if not self.verbose:
            return
        print(text, flush=True)

    def load_state_dict(self, args, d):
        """Restore every field's vocabulary from a checkpoint dictionary *d*.

        NOTE(review): uses ``pickle.loads`` on checkpoint bytes — only load
        checkpoints from a trusted source.
        """
        for key, value in d["vocabs"].items():
            getattr(self, key).vocab = pickle.loads(value)

    def state_dict(self):
        """Serialize the vocabulary of every attribute that has one (for checkpointing)."""
        return {
            "vocabs": {key: pickle.dumps(value.vocab) for key, value in self.__dict__.items() if hasattr(value, "vocab")}
        }

    def load_sentences(self, sentences, args):
        """Parse raw *sentences* for inference and return a batch-size-1 DataLoader."""
        dataset = RequestParser(
            sentences, args,
            fields={
                "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "to scatter": ("input_scatter", self.scatter_field),
                "token anchors": ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
        )
        # Rebuild the request-local vocabularies so unseen words/ids are indexable.
        self.every_word_input_field.build_vocab(dataset, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        self.id_field.build_vocab(dataset, min_freq=1, specials=[])
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=Collate())

    def load_dataset(self, args):
        """Parse all three splits, build vocabularies and statistics, create DataLoaders."""
        # Select the parser implementation matching the requested graph formalism.
        parser = {
            "sequential": SequentialParser,
            "node-centric": NodeCentricParser,
            "labeled-edge": LabeledEdgeParser
        }[args.graph_mode]

        train = parser(
            args, "training",
            fields={
                "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "to scatter": ("input_scatter", self.scatter_field),
                "nodes": ("labels", self.label_field),
                "anchored labels": ("anchored_labels", self.anchored_label_field),
                "edge presence": ("edge_presence", self.edge_presence_field),
                "edge labels": ("edge_labels", self.edge_label_field),
                "anchor edges": ("anchor", self.anchor_field),
                "source anchor edges": ("source_anchor", self.source_anchor_field),
                "target anchor edges": ("target_anchor", self.target_anchor_field),
                "token anchors": ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
            # Drop overly long training sentences (no such filter on val/test).
            filter_pred=lambda example: len(example.input) <= 256,
        )
        val = parser(
            args, "validation",
            fields={
                "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "to scatter": ("input_scatter", self.scatter_field),
                "nodes": ("labels", self.label_field),
                "anchored labels": ("anchored_labels", self.anchored_label_field),
                "edge presence": ("edge_presence", self.edge_presence_field),
                "edge labels": ("edge_labels", self.edge_label_field),
                "anchor edges": ("anchor", self.anchor_field),
                "source anchor edges": ("source_anchor", self.source_anchor_field),
                "target anchor edges": ("target_anchor", self.target_anchor_field),
                "token anchors": ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
        )
        # The evaluation split carries no gold graphs, hence only input fields.
        test = EvaluationParser(
            args,
            fields={
                "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "to scatter": ("input_scatter", self.scatter_field),
                "token anchors": ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
        )

        del train.data, val.data, test.data  # TODO: why?
        # Preprocessing closures are dropped — presumably so the fields can be
        # pickled into DataLoader worker processes; confirm.  # TODO: why?
        for f in list(train.fields.values()) + list(val.fields.values()) + list(test.fields.values()):  # TODO: why?
            if hasattr(f, "preprocessing"):
                del f.preprocessing

        self.train_size = len(train)
        self.val_size = len(val)
        self.test_size = len(test)

        self.log(f"\n{self.train_size} sentences in the train split")
        self.log(f"{self.val_size} sentences in the validation split")
        self.log(f"{self.test_size} sentences in the test split")

        # Corpus statistics gathered by the training parser.
        self.node_count = train.node_counter
        self.token_count = train.input_count
        self.edge_count = train.edge_counter
        self.no_edge_count = train.no_edge_counter
        self.anchor_freq = train.anchor_freq

        # Fall back to 0.5 when the parser does not track per-side anchor frequencies.
        self.source_anchor_freq = train.source_anchor_freq if hasattr(train, "source_anchor_freq") else 0.5
        self.target_anchor_freq = train.target_anchor_freq if hasattr(train, "target_anchor_freq") else 0.5

        self.log(f"{self.node_count} nodes in the train split")

        # NOTE(review): the word vocabulary is built from val+test only —
        # presumably train tokens are covered elsewhere (e.g. BERT subwords);
        # confirm this is intentional.
        self.every_word_input_field.build_vocab(val, test, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        self.char_form_field.build_vocab(train, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        # The nested (per-character) field must share the outer field's vocab.
        self.char_form_field.nesting_field.vocab = self.char_form_field.vocab
        self.id_field.build_vocab(train, val, test, min_freq=1, specials=[])
        self.label_field.build_vocab(train)
        # Anchored labels reuse the plain label vocabulary.
        self.anchored_label_field.vocab = self.label_field.vocab
        self.edge_label_field.build_vocab(train)
        # Debug dump of the observed edge-label inventory.
        print(list(self.edge_label_field.vocab.freqs.keys()), flush=True)
        self.char_form_vocab_size = len(self.char_form_field.vocab)

        # Precompute label/edge frequency tensors (used e.g. for loss weighting).
        self.create_label_freqs(args)
        self.create_edge_freqs(args)

        self.log(f"Edge frequency: {self.edge_presence_freq*100:.2f} %")
        self.log(f"{len(self.label_field.vocab)} words in the label vocabulary")
        self.log(f"{len(self.anchored_label_field.vocab)} words in the anchored label vocabulary")
        self.log(f"{len(self.edge_label_field.vocab)} words in the edge label vocabulary")
        self.log(f"{len(self.char_form_field.vocab)} characters in the vocabulary")
        self.log(self.label_field.vocab.freqs)
        self.log(self.anchored_label_field.vocab.freqs)

        # Wrap the splits in DataLoaders; only training is shuffled and
        # drops the last (incomplete) batch.
        self.train = torch.utils.data.DataLoader(
            train,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
            drop_last=True
        )
        self.train_size = len(self.train.dataset)
        self.val = torch.utils.data.DataLoader(
            val,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
        )
        self.val_size = len(self.val.dataset)
        self.test = torch.utils.data.DataLoader(
            test,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
        )
        self.test_size = len(self.test.dataset)

        if self.verbose:
            # Show one sample batch so the tensor layout is visible in the log.
            batch = next(iter(self.train))
            print(f"\nBatch content: {Batch.to_str(batch)}\n")
            print(flush=True)

    def create_label_freqs(self, args):
        """Compute relative frequencies of node labels, including the blank slot.

        Each input token yields ``args.query_length`` query slots; slots not
        matched to any node are "blank", so the blank count is
        ``query_length * token_count - node_count``.  Index 0 of
        ``self.label_freqs`` is the blank frequency; the rest follow the
        label vocabulary order.
        """
        n_rules = len(self.label_field.vocab)
        blank_count = (args.query_length * self.token_count - self.node_count)
        label_counts = [blank_count] + [
            self.label_field.vocab.freqs[self.label_field.vocab.itos[i]]
            for i in range(n_rules)
        ]
        label_counts = torch.FloatTensor(label_counts)
        self.label_freqs = label_counts / (self.node_count + blank_count)
        self.log(f"Label frequency: {self.label_freqs}")

    def create_edge_freqs(self, args):
        """Compute per-label edge frequencies and the overall edge-presence rate."""
        edge_counter = [
            self.edge_label_field.vocab.freqs[self.edge_label_field.vocab.itos[i]] for i in range(len(self.edge_label_field.vocab))
        ]
        edge_counter = torch.FloatTensor(edge_counter)
        # Frequency of each edge label among existing edges.
        self.edge_label_freqs = edge_counter / self.edge_count
        # Probability that any candidate node pair is connected at all.
        self.edge_presence_freq = self.edge_count / (self.edge_count + self.no_edge_count)