""" Utils for seq2seq models. """ from collections import Counter import random import json import torch import stanza.models.common.seq2seq_constant as constant # torch utils def get_optimizer(name, parameters, lr): if name == 'sgd': return torch.optim.SGD(parameters, lr=lr) elif name == 'adagrad': return torch.optim.Adagrad(parameters, lr=lr) elif name == 'adam': return torch.optim.Adam(parameters) # use default lr elif name == 'adamax': return torch.optim.Adamax(parameters) # use default lr else: raise Exception("Unsupported optimizer: {}".format(name)) def change_lr(optimizer, new_lr): for param_group in optimizer.param_groups: param_group['lr'] = new_lr def flatten_indices(seq_lens, width): flat = [] for i, l in enumerate(seq_lens): for j in range(l): flat.append(i * width + j) return flat def keep_partial_grad(grad, topk): """ Keep only the topk rows of grads. """ assert topk < grad.size(0) grad.data[topk:].zero_() return grad # other utils def save_config(config, path, verbose=True): with open(path, 'w') as outfile: json.dump(config, outfile, indent=2) if verbose: print("Config saved to file {}".format(path)) return config def load_config(path, verbose=True): with open(path) as f: config = json.load(f) if verbose: print("Config loaded from file {}".format(path)) return config def unmap_with_copy(indices, src_tokens, vocab): """ Unmap a list of list of indices, by optionally copying from src_tokens. """ result = [] for ind, tokens in zip(indices, src_tokens): words = [] for idx in ind: if idx >= 0: words.append(vocab.id2word[idx]) else: idx = -idx - 1 # flip and minus 1 words.append(tokens[idx]) result += [words] return result def prune_decoded_seqs(seqs): """ Prune decoded sequences after EOS token. """ out = [] for s in seqs: if constant.EOS in s: idx = s.index(constant.EOS_TOKEN) out += [s[:idx]] else: out += [s] return out def prune_hyp(hyp): """ Prune a decoded hypothesis """ if constant.EOS_ID in hyp: idx = hyp.index(constant.EOS_ID) return hyp[:idx] else: return hyp def prune(data_list, lens): assert len(data_list) == len(lens) nl = [] for d, l in zip(data_list, lens): nl.append(d[:l]) return nl def sort(packed, ref, reverse=True): """ Sort a series of packed list, according to a ref list. Also return the original index before the sort. """ assert (isinstance(packed, tuple) or isinstance(packed, list)) and isinstance(ref, list) packed = [ref] + [range(len(ref))] + list(packed) sorted_packed = [list(t) for t in zip(*sorted(zip(*packed), reverse=reverse))] return tuple(sorted_packed[1:]) def unsort(sorted_list, oidx): """ Unsort a sorted list, based on the original idx. """ assert len(sorted_list) == len(oidx), "Number of list elements must match with original indices." _, unsorted = [list(t) for t in zip(*sorted(zip(oidx, sorted_list)))] return unsorted