"""
Utils for seq2seq models.
"""
from collections import Counter
import random
import json
import torch
import stanza.models.common.seq2seq_constant as constant

# torch utils
def get_optimizer(name, parameters, lr):
    if name == 'sgd':
        return torch.optim.SGD(parameters, lr=lr)
    elif name == 'adagrad':
        return torch.optim.Adagrad(parameters, lr=lr)
    elif name == 'adam':
        return torch.optim.Adam(parameters) # use default lr
    elif name == 'adamax':
        return torch.optim.Adamax(parameters) # use default lr
    else:
        raise Exception("Unsupported optimizer: {}".format(name))
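
# Illustrative usage (a sketch, not part of the original file; `model` and the
# learning rate value are assumptions): pick an optimizer by its config name.
#   optimizer = get_optimizer('sgd', model.parameters(), lr=0.1)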

def change_lr(optimizer, new_lr):
    """
    Set the learning rate of every parameter group of an optimizer to new_lr.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

def flatten_indices(seq_lens, width):
    """
    Map per-sequence lengths to flat indices into a padded (batch_size x width)
    layout, keeping only the non-pad positions of each row.
    """
    flat = []
    for i, l in enumerate(seq_lens):
        for j in range(l):
            flat.append(i * width + j)
    return flat
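
# Illustrative example (not part of the original file): with two sequences of
# length 2 and 3 padded to width 4, the kept positions in the flattened batch are
#   flatten_indices([2, 3], 4)  ->  [0, 1, 4, 5, 6]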

def keep_partial_grad(grad, topk):
    """
    Keep only the topk rows of grads.
    """
    assert topk < grad.size(0)
    grad.data[topk:].zero_()
    return grad
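
# One possible way to apply this (an assumption, not shown in this file): register
# it as a gradient hook so that only the first `topk` embedding rows get updated.
#   embedding.weight.register_hook(lambda g: keep_partial_grad(g, topk))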

# other utils
def save_config(config, path, verbose=True):
    with open(path, 'w') as outfile:
        json.dump(config, outfile, indent=2)
    if verbose:
        print("Config saved to file {}".format(path))
    return config

def load_config(path, verbose=True):
    with open(path) as f:
        config = json.load(f)
    if verbose:
        print("Config loaded from file {}".format(path))
    return config
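
# Illustrative round trip (the filename and keys are assumptions for the example):
#   save_config({'lr': 1e-3, 'optim': 'adam'}, 'config.json', verbose=False)
#   cfg = load_config('config.json', verbose=False)   # {'lr': 0.001, 'optim': 'adam'}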

def unmap_with_copy(indices, src_tokens, vocab):
    """
    Unmap a list of lists of indices back to words, copying from src_tokens
    when an index is negative (i.e., it points into the source sequence).
    """
    result = []
    for ind, tokens in zip(indices, src_tokens):
        words = []
        for idx in ind:
            if idx >= 0:
                words.append(vocab.id2word[idx])
            else:
                idx = -idx - 1 # flip the sign and subtract 1 to recover the source position
                words.append(tokens[idx])
        result += [words]
    return result
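
# Illustrative example (the ids and tokens are assumptions): a negative index -k
# copies the (k-1)-th source token instead of looking up the vocabulary.
#   unmap_with_copy([[12, -1, -2]], [['foo', 'bar']], vocab)
#   # -> [[vocab.id2word[12], 'foo', 'bar']]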

def prune_decoded_seqs(seqs):
    """
    Prune decoded sequences after the EOS token.
    """
    out = []
    for s in seqs:
        if constant.EOS in s:
            idx = s.index(constant.EOS) # look up the same EOS symbol checked above
            out += [s[:idx]]
        else:
            out += [s]
    return out

def prune_hyp(hyp):
    """
    Prune a decoded hypothesis after the EOS id.
    """
    if constant.EOS_ID in hyp:
        idx = hyp.index(constant.EOS_ID)
        return hyp[:idx]
    else:
        return hyp
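
# Illustrative example (assumes constant.EOS_ID == 3): everything from the EOS id
# onward is dropped.
#   prune_hyp([7, 15, 3, 0, 0])  ->  [7, 15]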

def prune(data_list, lens):
    """
    Truncate each sequence in data_list to its corresponding length in lens.
    """
    assert len(data_list) == len(lens)
    nl = []
    for d, l in zip(data_list, lens):
        nl.append(d[:l])
    return nl

def sort(packed, ref, reverse=True):
    """
    Sort a series of packed lists according to a ref list.
    Also return the original indices before the sort.
    """
    assert (isinstance(packed, tuple) or isinstance(packed, list)) and isinstance(ref, list)
    packed = [ref] + [range(len(ref))] + list(packed)
    sorted_packed = [list(t) for t in zip(*sorted(zip(*packed), reverse=reverse))]
    return tuple(sorted_packed[1:])
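
# Illustrative example (the inputs are assumptions): sort words by descending
# length; the first returned element is the pre-sort index of each item.
#   oidx, words = sort((['to', 'apple', 'cat'],), [2, 5, 3])
#   # oidx  == [1, 2, 0]
#   # words == ['apple', 'cat', 'to']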

def unsort(sorted_list, oidx):
    """
    Unsort a sorted list, based on the original idx.
    """
    assert len(sorted_list) == len(oidx), "Number of list elements must match with original indices."
    _, unsorted = [list(t) for t in zip(*sorted(zip(oidx, sorted_list)))]
    return unsorted
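
# Continuing the sort() example above (still an illustration): feeding the sorted
# list and the returned indices back in restores the original order.
#   unsort(['apple', 'cat', 'to'], [1, 2, 0])  ->  ['to', 'apple', 'cat']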