Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| # Copyright 2017-present, Facebook, Inc. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Functions for putting examples into torch format.""" | |
| from collections import Counter | |
| import torch | |
def vectorize(ex, model, single_answer=False):
    """Torchify a single example.

    Args:
        ex: dict with keys 'document', 'question', 'id' and optionally
            'answers' (list of (start, end) spans), plus 'pos', 'ner',
            'lemma', 'qlemma' token-aligned lists when the corresponding
            features are enabled.
        model: provides .args (feature flags), .word_dict (token -> index)
            and .feature_dict (feature name -> feature column index).
        single_answer: if True, require at least one answer and return only
            the first span as 1-element LongTensors; otherwise return all
            answer starts/ends as plain Python lists.

    Returns:
        (document, features, question, id) when 'answers' is absent, else
        (document, features, question, start, end, id). `features` is None
        when the feature dict is empty.
    """
    args = model.args
    word_dict = model.word_dict
    feature_dict = model.feature_dict
    doc_tokens = ex['document']

    # Index words
    document = torch.LongTensor([word_dict[w] for w in doc_tokens])
    question = torch.LongTensor([word_dict[w] for w in ex['question']])

    # Create extra features vector: one row per document token, one column
    # per feature; None when no features are enabled.
    if len(feature_dict) > 0:
        features = torch.zeros(len(doc_tokens), len(feature_dict))
    else:
        features = None

    # f_{exact_match}: does this document token appear in the question
    # (cased / uncased / lemma-matched)?
    if args.use_in_question:
        q_words_cased = set(ex['question'])
        q_words_uncased = {w.lower() for w in ex['question']}
        q_lemma = set(ex['qlemma']) if args.use_lemma else None
        for i, w in enumerate(doc_tokens):
            if w in q_words_cased:
                features[i][feature_dict['in_question']] = 1.0
            if w.lower() in q_words_uncased:
                features[i][feature_dict['in_question_uncased']] = 1.0
            if q_lemma and ex['lemma'][i] in q_lemma:
                features[i][feature_dict['in_question_lemma']] = 1.0

    # f_{token} (POS): one-hot part-of-speech tag. Tags unseen at
    # feature-dict build time have no column and are skipped.
    if args.use_pos:
        for i, w in enumerate(ex['pos']):
            f = 'pos=%s' % w
            if f in feature_dict:
                features[i][feature_dict[f]] = 1.0

    # f_{token} (NER): one-hot named-entity tag, same skip rule as POS.
    if args.use_ner:
        for i, w in enumerate(ex['ner']):
            f = 'ner=%s' % w
            if f in feature_dict:
                features[i][feature_dict[f]] = 1.0

    # f_{token} (TF): uncased term frequency normalized by document length.
    if args.use_tf:
        counter = Counter(w.lower() for w in doc_tokens)
        doc_len = len(doc_tokens)
        for i, w in enumerate(doc_tokens):
            features[i][feature_dict['tf']] = counter[w.lower()] * 1.0 / doc_len

    # Maybe return without target
    if 'answers' not in ex:
        return document, features, question, ex['id']

    # ...or with target(s) (might still be empty if answers is empty)
    if single_answer:
        assert len(ex['answers']) > 0
        start = torch.LongTensor(1).fill_(ex['answers'][0][0])
        end = torch.LongTensor(1).fill_(ex['answers'][0][1])
    else:
        start = [a[0] for a in ex['answers']]
        end = [a[1] for a in ex['answers']]

    return document, features, question, start, end, ex['id']
def batchify(batch):
    """Gather a batch of individual examples into one batch.

    Each example is either (document, features, question, id) or
    (document, features, question, start, end, id). Documents and
    questions are zero-padded to the longest example in the batch; the
    returned masks hold 1 at padding positions and 0 at real tokens.
    """
    NUM_INPUTS = 3
    NUM_TARGETS = 2
    NUM_EXTRA = 1

    ids = [example[-1] for example in batch]
    docs = [example[0] for example in batch]
    features = [example[1] for example in batch]
    questions = [example[2] for example in batch]
    batch_size = len(docs)

    # Batch documents and features.
    doc_max = max(d.size(0) for d in docs)
    x1 = torch.LongTensor(batch_size, doc_max).zero_()
    x1_mask = torch.ByteTensor(batch_size, doc_max).fill_(1)
    x1_f = None
    if features[0] is not None:
        x1_f = torch.zeros(batch_size, doc_max, features[0].size(1))
    for row, doc in enumerate(docs):
        n_tokens = doc.size(0)
        x1[row, :n_tokens].copy_(doc)
        x1_mask[row, :n_tokens].fill_(0)
        if x1_f is not None:
            x1_f[row, :n_tokens].copy_(features[row])

    # Batch questions the same way.
    q_max = max(q.size(0) for q in questions)
    x2 = torch.LongTensor(batch_size, q_max).zero_()
    x2_mask = torch.ByteTensor(batch_size, q_max).fill_(1)
    for row, question in enumerate(questions):
        n_tokens = question.size(0)
        x2[row, :n_tokens].copy_(question)
        x2_mask[row, :n_tokens].fill_(0)

    # Maybe return without targets...
    n_fields = len(batch[0])
    if n_fields == NUM_INPUTS + NUM_EXTRA:
        return x1, x1_f, x1_mask, x2, x2_mask, ids

    # ...otherwise add targets. They are either 1-element tensors
    # (single-answer mode) or plain lists of span endpoints.
    if n_fields == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS:
        if torch.is_tensor(batch[0][3]):
            y_s = torch.cat([example[3] for example in batch])
            y_e = torch.cat([example[4] for example in batch])
        else:
            y_s = [example[3] for example in batch]
            y_e = [example[4] for example in batch]
        return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids

    raise RuntimeError('Incorrect number of inputs per example.')