Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| # Copyright 2017-present, Facebook, Inc. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """DrQA Document Reader model""" | |
| import torch | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import logging | |
| import copy | |
| from .config import override_model_args | |
| from .rnn_reader import RnnDocReader | |
| logger = logging.getLogger(__name__) | |
class DocReader(object):
    """High level model that handles initializing the underlying network
    architecture, saving, updating examples, and predicting examples.
    """

    # --------------------------------------------------------------------------
    # Initialization
    # --------------------------------------------------------------------------
    def __init__(self, args, word_dict, feature_dict,
                 state_dict=None, normalize=True):
        """Set up bookkeeping state and build the underlying network.

        Args:
            args: experiment configuration namespace; `vocab_size` and
                `num_features` are overwritten here from the dictionaries.
            word_dict: word <-> index dictionary (its len() is the vocab size).
            feature_dict: feature name -> index dictionary.
            state_dict: optional saved network parameters to restore.
            normalize: if False, start/end scores are not normalized
                0-1 per paragraph (no softmax).
        Raises:
            RuntimeError: if args.model_type is not 'rnn'.
        """
        # Book-keeping.
        self.args = args
        self.word_dict = word_dict
        self.args.vocab_size = len(word_dict)
        self.feature_dict = feature_dict
        self.args.num_features = len(feature_dict)
        self.updates = 0
        self.use_cuda = False
        self.parallel = False

        # Building network. If normalize is False, scores are not normalized
        # 0-1 per paragraph (no softmax).
        if args.model_type == 'rnn':
            self.network = RnnDocReader(args, normalize)
        else:
            raise RuntimeError('Unsupported model: %s' % args.model_type)

        # Load saved state.
        if state_dict:
            # 'fixed_embedding' is a buffer (created by tune_embeddings, not a
            # registered parameter at construction), so it must be popped out
            # of the state dict and re-registered separately.
            if 'fixed_embedding' in state_dict:
                fixed_embedding = state_dict.pop('fixed_embedding')
                self.network.load_state_dict(state_dict)
                self.network.register_buffer('fixed_embedding', fixed_embedding)
            else:
                self.network.load_state_dict(state_dict)
| def expand_dictionary(self, words): | |
| """Add words to the DocReader dictionary if they do not exist. The | |
| underlying embedding matrix is also expanded (with random embeddings). | |
| Args: | |
| words: iterable of tokens to add to the dictionary. | |
| Output: | |
| added: set of tokens that were added. | |
| """ | |
| to_add = {self.word_dict.normalize(w) for w in words | |
| if w not in self.word_dict} | |
| # Add words to dictionary and expand embedding layer | |
| if len(to_add) > 0: | |
| logger.info('Adding %d new words to dictionary...' % len(to_add)) | |
| for w in to_add: | |
| self.word_dict.add(w) | |
| self.args.vocab_size = len(self.word_dict) | |
| logger.info('New vocab size: %d' % len(self.word_dict)) | |
| old_embedding = self.network.embedding.weight.data | |
| self.network.embedding = torch.nn.Embedding(self.args.vocab_size, | |
| self.args.embedding_dim, | |
| padding_idx=0) | |
| new_embedding = self.network.embedding.weight.data | |
| new_embedding[:old_embedding.size(0)] = old_embedding | |
| # Return added words | |
| return to_add | |
| def load_embeddings(self, words, embedding_file): | |
| """Load pretrained embeddings for a given list of words, if they exist. | |
| Args: | |
| words: iterable of tokens. Only those that are indexed in the | |
| dictionary are kept. | |
| embedding_file: path to text file of embeddings, space separated. | |
| """ | |
| words = {w for w in words if w in self.word_dict} | |
| logger.info('Loading pre-trained embeddings for %d words from %s' % | |
| (len(words), embedding_file)) | |
| embedding = self.network.embedding.weight.data | |
| # When normalized, some words are duplicated. (Average the embeddings). | |
| vec_counts = {} | |
| with open(embedding_file) as f: | |
| # Skip first line if of form count/dim. | |
| line = f.readline().rstrip().split(' ') | |
| if len(line) != 2: | |
| f.seek(0) | |
| for line in f: | |
| parsed = line.rstrip().split(' ') | |
| assert(len(parsed) == embedding.size(1) + 1) | |
| w = self.word_dict.normalize(parsed[0]) | |
| if w in words: | |
| vec = torch.Tensor([float(i) for i in parsed[1:]]) | |
| if w not in vec_counts: | |
| vec_counts[w] = 1 | |
| embedding[self.word_dict[w]].copy_(vec) | |
| else: | |
| logging.warning( | |
| 'WARN: Duplicate embedding found for %s' % w | |
| ) | |
| vec_counts[w] = vec_counts[w] + 1 | |
| embedding[self.word_dict[w]].add_(vec) | |
| for w, c in vec_counts.items(): | |
| embedding[self.word_dict[w]].div_(c) | |
| logger.info('Loaded %d embeddings (%.2f%%)' % | |
| (len(vec_counts), 100 * len(vec_counts) / len(words))) | |
    def tune_embeddings(self, words):
        """Unfix the embeddings of a list of words. This is only relevant if
        only some of the embeddings are being tuned (tune_partial = N).

        Shuffles the N specified words to the front of the dictionary, and
        saves the original vectors of the remaining words in the
        'fixed_embedding' buffer (restored after each update by
        reset_parameters).

        Args:
            words: iterable of tokens contained in dictionary.
        """
        words = {w for w in words if w in self.word_dict}

        if len(words) == 0:
            logger.warning('Tried to tune embeddings, but no words given!')
            return

        if len(words) == len(self.word_dict):
            logger.warning('Tuning ALL embeddings in dictionary')
            return

        # Shuffle words and vectors
        embedding = self.network.embedding.weight.data
        # NOTE(review): word_dict is indexed both by word and by integer index
        # below (bidirectional mapping), and START is assumed to be the first
        # tunable index — confirm against the dictionary implementation.
        for idx, swap_word in enumerate(words, self.word_dict.START):
            # Get current word + embedding for this index
            curr_word = self.word_dict[idx]
            curr_emb = embedding[idx].clone()
            old_idx = self.word_dict[swap_word]

            # Swap embeddings + dictionary indices
            embedding[idx].copy_(embedding[old_idx])
            embedding[old_idx].copy_(curr_emb)
            self.word_dict[swap_word] = idx
            self.word_dict[idx] = swap_word
            self.word_dict[curr_word] = old_idx
            self.word_dict[old_idx] = curr_word

        # Save the original, fixed embeddings: everything past the last tuned
        # index (`idx` is left over from the loop above).
        self.network.register_buffer(
            'fixed_embedding', embedding[idx + 1:].clone()
        )
| def init_optimizer(self, state_dict=None): | |
| """Initialize an optimizer for the free parameters of the network. | |
| Args: | |
| state_dict: network parameters | |
| """ | |
| if self.args.fix_embeddings: | |
| for p in self.network.embedding.parameters(): | |
| p.requires_grad = False | |
| parameters = [p for p in self.network.parameters() if p.requires_grad] | |
| if self.args.optimizer == 'sgd': | |
| self.optimizer = optim.SGD(parameters, self.args.learning_rate, | |
| momentum=self.args.momentum, | |
| weight_decay=self.args.weight_decay) | |
| elif self.args.optimizer == 'adamax': | |
| self.optimizer = optim.Adamax(parameters, | |
| weight_decay=self.args.weight_decay) | |
| else: | |
| raise RuntimeError('Unsupported optimizer: %s' % | |
| self.args.optimizer) | |
| # -------------------------------------------------------------------------- | |
| # Learning | |
| # -------------------------------------------------------------------------- | |
| def update(self, ex): | |
| """Forward a batch of examples; step the optimizer to update weights.""" | |
| if not self.optimizer: | |
| raise RuntimeError('No optimizer set.') | |
| # Train mode | |
| self.network.train() | |
| # Transfer to GPU | |
| if self.use_cuda: | |
| inputs = [e if e is None else e.cuda(non_blocking=True) | |
| for e in ex[:5]] | |
| target_s = ex[5].cuda(non_blocking=True) | |
| target_e = ex[6].cuda(non_blocking=True) | |
| else: | |
| inputs = [e if e is None else e for e in ex[:5]] | |
| target_s = ex[5] | |
| target_e = ex[6] | |
| # Run forward | |
| score_s, score_e = self.network(*inputs) | |
| # Compute loss and accuracies | |
| loss = F.nll_loss(score_s, target_s) + F.nll_loss(score_e, target_e) | |
| # Clear gradients and run backward | |
| self.optimizer.zero_grad() | |
| loss.backward() | |
| # Clip gradients | |
| torch.nn.utils.clip_grad_norm_(self.network.parameters(), | |
| self.args.grad_clipping) | |
| # Update parameters | |
| self.optimizer.step() | |
| self.updates += 1 | |
| # Reset any partially fixed parameters (e.g. rare words) | |
| self.reset_parameters() | |
| return loss.item(), ex[0].size(0) | |
| def reset_parameters(self): | |
| """Reset any partially fixed parameters to original states.""" | |
| # Reset fixed embeddings to original value | |
| if self.args.tune_partial > 0: | |
| if self.parallel: | |
| embedding = self.network.module.embedding.weight.data | |
| fixed_embedding = self.network.module.fixed_embedding | |
| else: | |
| embedding = self.network.embedding.weight.data | |
| fixed_embedding = self.network.fixed_embedding | |
| # Embeddings to fix are the last indices | |
| offset = embedding.size(0) - fixed_embedding.size(0) | |
| if offset >= 0: | |
| embedding[offset:] = fixed_embedding | |
| # -------------------------------------------------------------------------- | |
| # Prediction | |
| # -------------------------------------------------------------------------- | |
| def predict(self, ex, candidates=None, top_n=1, async_pool=None): | |
| """Forward a batch of examples only to get predictions. | |
| Args: | |
| ex: the batch | |
| candidates: batch * variable length list of string answer options. | |
| The model will only consider exact spans contained in this list. | |
| top_n: Number of predictions to return per batch element. | |
| async_pool: If provided, non-gpu post-processing will be offloaded | |
| to this CPU process pool. | |
| Output: | |
| pred_s: batch * top_n predicted start indices | |
| pred_e: batch * top_n predicted end indices | |
| pred_score: batch * top_n prediction scores | |
| If async_pool is given, these will be AsyncResult handles. | |
| """ | |
| # Eval mode | |
| self.network.eval() | |
| # Transfer to GPU | |
| if self.use_cuda: | |
| inputs = [e if e is None else e.cuda(non_blocking=True) | |
| for e in ex[:5]] | |
| else: | |
| inputs = [e for e in ex[:5]] | |
| # Run forward | |
| with torch.no_grad(): | |
| score_s, score_e = self.network(*inputs) | |
| # Decode predictions | |
| score_s = score_s.data.cpu() | |
| score_e = score_e.data.cpu() | |
| if candidates: | |
| args = (score_s, score_e, candidates, top_n, self.args.max_len) | |
| if async_pool: | |
| return async_pool.apply_async(self.decode_candidates, args) | |
| else: | |
| return self.decode_candidates(*args) | |
| else: | |
| args = (score_s, score_e, top_n, self.args.max_len) | |
| if async_pool: | |
| return async_pool.apply_async(self.decode, args) | |
| else: | |
| return self.decode(*args) | |
| def decode(score_s, score_e, top_n=1, max_len=None): | |
| """Take argmax of constrained score_s * score_e. | |
| Args: | |
| score_s: independent start predictions | |
| score_e: independent end predictions | |
| top_n: number of top scored pairs to take | |
| max_len: max span length to consider | |
| """ | |
| pred_s = [] | |
| pred_e = [] | |
| pred_score = [] | |
| max_len = max_len or score_s.size(1) | |
| for i in range(score_s.size(0)): | |
| # Outer product of scores to get full p_s * p_e matrix | |
| scores = torch.ger(score_s[i], score_e[i]) | |
| # Zero out negative length and over-length span scores | |
| scores.triu_().tril_(max_len - 1) | |
| # Take argmax or top n | |
| scores = scores.numpy() | |
| scores_flat = scores.flatten() | |
| if top_n == 1: | |
| idx_sort = [np.argmax(scores_flat)] | |
| elif len(scores_flat) < top_n: | |
| idx_sort = np.argsort(-scores_flat) | |
| else: | |
| idx = np.argpartition(-scores_flat, top_n)[0:top_n] | |
| idx_sort = idx[np.argsort(-scores_flat[idx])] | |
| s_idx, e_idx = np.unravel_index(idx_sort, scores.shape) | |
| pred_s.append(s_idx) | |
| pred_e.append(e_idx) | |
| pred_score.append(scores_flat[idx_sort]) | |
| return pred_s, pred_e, pred_score | |
| def decode_candidates(score_s, score_e, candidates, top_n=1, max_len=None): | |
| """Take argmax of constrained score_s * score_e. Except only consider | |
| spans that are in the candidates list. | |
| """ | |
| pred_s = [] | |
| pred_e = [] | |
| pred_score = [] | |
| for i in range(score_s.size(0)): | |
| # Extract original tokens stored with candidates | |
| tokens = candidates[i]['input'] | |
| cands = candidates[i]['cands'] | |
| if not cands: | |
| # try getting from globals? (multiprocessing in pipeline mode) | |
| from ..pipeline.drqa import PROCESS_CANDS | |
| cands = PROCESS_CANDS | |
| if not cands: | |
| raise RuntimeError('No candidates given.') | |
| # Score all valid candidates found in text. | |
| # Brute force get all ngrams and compare against the candidate list. | |
| max_len = max_len or len(tokens) | |
| scores, s_idx, e_idx = [], [], [] | |
| for s, e in tokens.ngrams(n=max_len, as_strings=False): | |
| span = tokens.slice(s, e).untokenize() | |
| if span in cands or span.lower() in cands: | |
| # Match! Record its score. | |
| scores.append(score_s[i][s] * score_e[i][e - 1]) | |
| s_idx.append(s) | |
| e_idx.append(e - 1) | |
| if len(scores) == 0: | |
| # No candidates present | |
| pred_s.append([]) | |
| pred_e.append([]) | |
| pred_score.append([]) | |
| else: | |
| # Rank found candidates | |
| scores = np.array(scores) | |
| s_idx = np.array(s_idx) | |
| e_idx = np.array(e_idx) | |
| idx_sort = np.argsort(-scores)[0:top_n] | |
| pred_s.append(s_idx[idx_sort]) | |
| pred_e.append(e_idx[idx_sort]) | |
| pred_score.append(scores[idx_sort]) | |
| return pred_s, pred_e, pred_score | |
| # -------------------------------------------------------------------------- | |
| # Saving and loading | |
| # -------------------------------------------------------------------------- | |
| def save(self, filename): | |
| if self.parallel: | |
| network = self.network.module | |
| else: | |
| network = self.network | |
| state_dict = copy.copy(network.state_dict()) | |
| if 'fixed_embedding' in state_dict: | |
| state_dict.pop('fixed_embedding') | |
| params = { | |
| 'state_dict': state_dict, | |
| 'word_dict': self.word_dict, | |
| 'feature_dict': self.feature_dict, | |
| 'args': self.args, | |
| } | |
| try: | |
| torch.save(params, filename) | |
| except BaseException: | |
| logger.warning('WARN: Saving failed... continuing anyway.') | |
| def checkpoint(self, filename, epoch): | |
| if self.parallel: | |
| network = self.network.module | |
| else: | |
| network = self.network | |
| params = { | |
| 'state_dict': network.state_dict(), | |
| 'word_dict': self.word_dict, | |
| 'feature_dict': self.feature_dict, | |
| 'args': self.args, | |
| 'epoch': epoch, | |
| 'optimizer': self.optimizer.state_dict(), | |
| } | |
| try: | |
| torch.save(params, filename) | |
| except BaseException: | |
| logger.warning('WARN: Saving failed... continuing anyway.') | |
| def load(filename, new_args=None, normalize=True): | |
| logger.info('Loading model %s' % filename) | |
| saved_params = torch.load( | |
| filename, map_location=lambda storage, loc: storage | |
| ) | |
| word_dict = saved_params['word_dict'] | |
| feature_dict = saved_params['feature_dict'] | |
| state_dict = saved_params['state_dict'] | |
| args = saved_params['args'] | |
| if new_args: | |
| args = override_model_args(args, new_args) | |
| return DocReader(args, word_dict, feature_dict, state_dict, normalize) | |
| def load_checkpoint(filename, normalize=True): | |
| logger.info('Loading model %s' % filename) | |
| saved_params = torch.load( | |
| filename, map_location=lambda storage, loc: storage | |
| ) | |
| word_dict = saved_params['word_dict'] | |
| feature_dict = saved_params['feature_dict'] | |
| state_dict = saved_params['state_dict'] | |
| epoch = saved_params['epoch'] | |
| optimizer = saved_params['optimizer'] | |
| args = saved_params['args'] | |
| model = DocReader(args, word_dict, feature_dict, state_dict, normalize) | |
| model.init_optimizer(optimizer) | |
| return model, epoch | |
| # -------------------------------------------------------------------------- | |
| # Runtime | |
| # -------------------------------------------------------------------------- | |
| def cuda(self): | |
| self.use_cuda = True | |
| self.network = self.network.cuda() | |
| def cpu(self): | |
| self.use_cuda = False | |
| self.network = self.network.cpu() | |
| def parallelize(self): | |
| """Use data parallel to copy the model across several gpus. | |
| This will take all gpus visible with CUDA_VISIBLE_DEVICES. | |
| """ | |
| self.parallel = True | |
| self.network = torch.nn.DataParallel(self.network) | |