import json
import os

import numpy as np
import torch
from fairseq.data import (
    Dictionary,
    IdDataset,
    ListDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    data_utils,
    encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task


@register_task("commonsense_qa")
class CommonsenseQATask(LegacyFairseqTask):
    """Task to finetune RoBERTa for Commonsense QA."""
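
    # Rough invocation sketch (flag values here are illustrative; the
    # commonsense_qa example README has the exact recipe):
    #
    #   fairseq-train <data-dir> --task commonsense_qa \
    #       --criterion sentence_ranking --num-classes 5 --init-token 0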

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )
        parser.add_argument("--num-classes", type=int, default=5)

    def __init__(self, args, vocab):
        super().__init__(args)
        self.vocab = vocab
        # reserve a <mask> symbol so the dictionary matches checkpoints
        # pretrained with masked LM
        self.mask = vocab.add_symbol("<mask>")

        self.bpe = encoders.build_bpe(args)

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename.

        Args:
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert (
            args.criterion == "sentence_ranking"
        ), "Must set --criterion=sentence_ranking"

        # load the data dictionary
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        def binarize(s, append_bos=False):
            # BPE-encode the raw string, then map the BPE tokens to
            # dictionary indices (without growing the vocabulary)
            if self.bpe is not None:
                s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s,
                append_eos=True,
                add_if_not_exist=False,
            ).long()
            if append_bos and self.args.init_token is not None:
                tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
            return tokens
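
        # For reference (assuming a GPT-2 BPE and --init-token 0, i.e.
        # RoBERTa's <s>): binarize("Q: ...", append_bos=True) yields a
        # LongTensor of the form [0, <BPE ids>, eos], since encode_line
        # appends the dictionary's </s>.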

        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))
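
        # Each line of <split>.jsonl is one CommonsenseQA example, roughly
        # (field names per the public dataset; "answerKey" is absent from
        # the unlabeled test split):
        #
        #   {"answerKey": "A",
        #    "question": {"stem": "...",
        #                 "choices": [{"label": "A", "text": "..."}, ...]}}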

        src_tokens = [[] for i in range(self.args.num_classes)]
        src_lengths = [[] for i in range(self.args.num_classes)]
        labels = []

        with open(data_path) as h:
            for line in h:
                example = json.loads(line.strip())
                if "answerKey" in example:
                    label = ord(example["answerKey"]) - ord("A")
                    labels.append(label)
                question = example["question"]["stem"]
                assert len(example["question"]["choices"]) == self.args.num_classes
                # each candidate is encoded as:
                # `<s> Q: <question> </s> A: <choice text> </s>`
                question = "Q: " + question
                question_toks = binarize(question, append_bos=True)
                for i, choice in enumerate(example["question"]["choices"]):
                    src = "A: " + choice["text"]
                    src_bin = torch.cat([question_toks, binarize(src)])
                    src_tokens[i].append(src_bin)
                    src_lengths[i].append(len(src_bin))
        assert all(
            len(src_tokens[0]) == len(src_tokens[i])
            for i in range(self.args.num_classes)
        )
        assert len(src_tokens[0]) == len(src_lengths[0])
        assert len(labels) == 0 or len(labels) == len(src_tokens[0])

        for i in range(self.args.num_classes):
            src_lengths[i] = np.array(src_lengths[i])
            src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
            src_lengths[i] = ListDataset(src_lengths[i])
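
        # ListDataset wraps the plain python lists; its second argument
        # supplies per-item sizes used downstream for padding and batching.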

        dataset = {
            "id": IdDataset(),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens[0], reduce=True),
        }

        for i in range(self.args.num_classes):
            dataset.update(
                {
                    "net_input{}".format(i + 1): {
                        "src_tokens": RightPadDataset(
                            src_tokens[i],
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        "src_lengths": src_lengths[i],
                    }
                }
            )
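
        # The sentence_ranking criterion forwards the model once per
        # net_input{i} and treats the resulting per-choice scores as logits
        # over the num_classes candidates; "target" (added below when labels
        # exist) is the index of the gold choice.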

        if len(labels) > 0:
            dataset.update({"target": RawLabelDataset(labels)})

        dataset = NestedDictionaryDataset(
            dataset,
            sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
        )
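
        # sizes here is the per-example max length across the num_classes
        # inputs, so batching by --max-tokens accounts for the longest choice.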

        with data_utils.numpy_seed(self.args.seed):
            dataset = SortDataset(
                dataset,
                # shuffle examples (order is fixed by the seed above)
                sort_order=[np.random.permutation(len(dataset))],
            )

        print("| Loaded {} with {} samples".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_model(self, args):
        from fairseq import models

        model = models.build_model(args, self)

        model.register_classification_head(
            "sentence_classification_head",
            num_classes=1,
        )
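
        # num_classes=1 makes the head emit a single scalar score per
        # (question, choice) encoding; the sentence_ranking criterion stacks
        # these scores across the num_classes forward passes to rank the
        # candidates.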

        return model

    @property
    def source_dictionary(self):
        return self.vocab

    @property
    def target_dictionary(self):
        return self.vocab