Add files using upload-large-folder tool
Browse files
stanza/stanza/models/common/beam.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import division
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 5 |
+
|
| 6 |
+
r"""
|
| 7 |
+
Adapted and modified from the OpenNMT project.
|
| 8 |
+
|
| 9 |
+
Class for managing the internals of the beam search process.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
hyp1-hyp1---hyp1 -hyp1
|
| 13 |
+
\ /
|
| 14 |
+
hyp2 \-hyp2 /-hyp2hyp2
|
| 15 |
+
/ \
|
| 16 |
+
hyp3-hyp3---hyp3 -hyp3
|
| 17 |
+
========================
|
| 18 |
+
|
| 19 |
+
Takes care of beams, back pointers, and scores.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# TORCH COMPATIBILITY
#
# Here we special case trunc division
# torch < 1.8.0 has no rounding_mode='trunc' argument for torch.div
# however, there were several versions in a row where // would loudly
# proclaim it was buggy, and users complained about that
# this hopefully maintains compatibility for torch
try:
    # Probe whether this torch supports the rounding_mode argument.
    a = torch.tensor([1.])
    b = torch.tensor([2.])
    c = torch.div(a, b, rounding_mode='trunc')
    def trunc_division(a, b):
        """Elementwise integer (truncating) division using torch >= 1.8 semantics."""
        return torch.div(a, b, rounding_mode='trunc')
except TypeError:
    def trunc_division(a, b):
        """Elementwise integer division fallback for torch < 1.8 (uses //)."""
        return a // b
|
| 39 |
+
|
| 40 |
+
class Beam(object):
    """Manages the internals of a single beam search.

    Tracks, per time step: cumulative hypothesis scores, backpointers into
    the previous beam, the predicted token ids, and optional copy indices.
    """

    def __init__(self, size, device=None):
        self.size = size
        self.done = False

        # The score for each translation on the beam.
        self.scores = torch.zeros(size, dtype=torch.float32, device=device)
        self.allScores = []

        # The backpointers at each time-step.
        self.prevKs = []

        # The outputs at each time-step.  Every beam entry starts as PAD,
        # with SOS in slot 0 so the first expansion comes from position 0.
        self.nextYs = [torch.zeros(size, dtype=torch.int64, device=device).fill_(constant.PAD_ID)]
        self.nextYs[0][0] = constant.SOS_ID

        # The copy indices for each time
        self.copy = []

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.nextYs[-1]

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk, copy_indices=None):
        """
        Given prob over words for every last beam `wordLk` and attention
        `attnOut`: Compute and update the beam search.

        Parameters:

        * `wordLk`- probs of advancing from the last step (K x words)
        * `copy_indices` - copy indices (K x ctx_len)

        Returns: True if beam search is complete.
        """
        if self.done:
            return True
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
        else:
            # first step, expand from the first position
            beamLk = wordLk[0]

        flatBeamLk = beamLk.view(-1)

        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
        self.allScores.append(self.scores)
        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        # bestScoreId is the integer ids, and numWords is the integer length.
        # Need to do integer division
        prevK = trunc_division(bestScoresId, numWords)
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)
        if copy_indices is not None:
            # reorder the copy indices to follow the surviving beams
            self.copy.append(copy_indices.index_select(0, prevK))

        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == constant.EOS_ID:
            self.done = True
            self.allScores.append(self.scores)

        return self.done

    def sort_best(self):
        "Sort the beam scores in decreasing order; returns (scores, indices)."
        return torch.sort(self.scores, 0, True)

    def get_best(self):
        "Get the score of the best in the beam."
        # BUG FIX: was `self.sortBest()`, a method that does not exist
        # (it was renamed to sort_best), so this always raised AttributeError.
        scores, ids = self.sort_best()
        # NOTE(review): indexing [1] rather than [0] looks like a holdover
        # from the 1-based (Lua) OpenNMT port -- confirm before changing.
        return scores[1], ids[1]

    def get_hyp(self, k):
        """
        Walk back to construct the full hypothesis.

        Parameters:

        * `k` - the position in the beam to construct.

        Returns: The hypothesis
        """
        hyp = []
        cpy = []
        # follow the backpointers from the last step to the first
        for j in range(len(self.prevKs) - 1, -1, -1):
            hyp.append(self.nextYs[j+1][k])
            if len(self.copy) > 0:
                cpy.append(self.copy[j][k])
            k = self.prevKs[j][k]

        hyp = hyp[::-1]
        cpy = cpy[::-1]
        # postprocess: if cpy index is not -1, use cpy index instead of hyp word
        for i, cidx in enumerate(cpy):
            if cidx >= 0:
                hyp[i] = -(cidx+1) # make index 1-based and flip it for token generation

        return hyp
|
stanza/stanza/models/common/convert_pretrain.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A utility script to load a word embedding file from a text file and save it as a .pt
|
| 3 |
+
|
| 4 |
+
Run it as follows:
|
| 5 |
+
python stanza/models/common/convert_pretrain.py <.pt file> <text file> <# vectors>
|
| 6 |
+
|
| 7 |
+
Note that -1 for # of vectors will keep all the vectors.
|
| 8 |
+
You probably want to keep fewer than that for most publicly released
|
| 9 |
+
embeddings, though, as they can get quite large.
|
| 10 |
+
|
| 11 |
+
As a concrete example, you can convert a newly downloaded Faroese WV file as follows:
|
| 12 |
+
python3 stanza/models/common/convert_pretrain.py ~/stanza/saved_models/pos/fo_farpahc.pretrain.pt ~/extern_data/wordvec/fasttext/faroese.txt -1
|
| 13 |
+
or save part of an Icelandic WV file:
|
| 14 |
+
python3 stanza/models/common/convert_pretrain.py ~/stanza/saved_models/pos/is_icepahc.pretrain.pt ~/extern_data/wordvec/fasttext/icelandic.cc.is.300.vec 150000
|
| 15 |
+
Note that if the pretrain already exists, nothing will be changed. It will not overwrite an existing .pt file.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
|
| 23 |
+
from stanza.models.common import pretrain
|
| 24 |
+
|
| 25 |
+
def main():
    """Convert a raw word-vector file into a Stanza .pt pretrain file."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("output_pt", default=None, help="Where to write the converted PT file")
    arg_parser.add_argument("input_vec", default=None, help="Unconverted vectors file")
    arg_parser.add_argument("max_vocab", type=int, default=-1, nargs="?", help="How many vectors to convert. -1 means convert them all")
    cli_args = arg_parser.parse_args()

    # Pretrain never clobbers an existing .pt file; warn the user up front.
    if os.path.exists(cli_args.output_pt):
        print("Not overwriting existing pretrain file in %s" % cli_args.output_pt)

    # A .csv input goes through the csv_filename code path; everything else
    # is treated as a plain text vector file.
    if cli_args.input_vec.endswith(".csv"):
        loaded = pretrain.Pretrain(cli_args.output_pt, max_vocab=cli_args.max_vocab, csv_filename=cli_args.input_vec)
    else:
        loaded = pretrain.Pretrain(cli_args.output_pt, cli_args.input_vec, max_vocab=cli_args.max_vocab)
    # Touching .vocab forces the lazy conversion (and the save, when needed).
    print("Pretrain is of size {}".format(len(loaded.vocab)))

if __name__ == '__main__':
    main()
|
stanza/stanza/models/common/pretrain.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Supports for pretrained data.
|
| 3 |
+
"""
|
| 4 |
+
import csv
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import lzma
|
| 9 |
+
import logging
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from .vocab import BaseVocab, VOCAB_PREFIX, UNK_ID
|
| 14 |
+
|
| 15 |
+
from stanza.models.common.utils import open_read_binary, open_read_text
|
| 16 |
+
from stanza.resources.common import DEFAULT_MODEL_DIR
|
| 17 |
+
|
| 18 |
+
from pickle import UnpicklingError
|
| 19 |
+
import warnings
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger('stanza')
|
| 22 |
+
|
| 23 |
+
class PretrainedWordVocab(BaseVocab):
    """Vocabulary built from the words of a pretrained embedding file."""

    def build_vocab(self):
        # Reserved special tokens first, then the embedding words in order.
        units = VOCAB_PREFIX + self.data
        self._id2unit = units
        self._unit2id = {}
        for index, word in enumerate(units):
            self._unit2id[word] = index

    def normalize_unit(self, unit):
        # Spaces inside a token are stored as NBSP in the pretrain vocab.
        normalized = super().normalize_unit(unit)
        if not normalized:
            return normalized
        return normalized.replace(" ", "\xa0")
|
| 33 |
+
|
| 34 |
+
class Pretrain:
    """ A loader and saver for pretrained embeddings.

    Loading is lazy: the first access to .vocab or .emb calls load(), which
    prefers an existing .pt file and otherwise falls back to converting the
    raw text/csv vector file (optionally resaving the result as a .pt).
    """

    def __init__(self, filename=None, vec_filename=None, max_vocab=-1, save_to_file=True, csv_filename=None):
        # path of the converted .pt file (read from and/or written to)
        self.filename = filename
        # raw text vector file, the fallback when the .pt is missing/unreadable
        self._vec_filename = vec_filename
        # alternative csv source for the vectors
        self._csv_filename = csv_filename
        # cap on the number of word vectors kept; -1 keeps all of them
        self._max_vocab = max_vocab
        # whether to write a .pt after converting from text/csv
        self._save_to_file = save_to_file

    @property
    def vocab(self):
        # lazily load on first access
        if not hasattr(self, '_vocab'):
            self.load()
        return self._vocab

    @property
    def emb(self):
        # lazily load on first access
        if not hasattr(self, '_emb'):
            self.load()
        return self._emb

    def load(self):
        """Populate self._vocab and self._emb.

        Tries the saved .pt first; on failure (or absence) falls back to
        reading the raw vector file, if one was provided.
        """
        if self.filename is not None and os.path.exists(self.filename):
            try:
                # TODO: after making the next release, remove the weights_only=False version
                try:
                    data = torch.load(self.filename, lambda storage, loc: storage, weights_only=True)
                except UnpicklingError:
                    # older files stored numpy arrays, which weights_only=True refuses to unpickle
                    data = torch.load(self.filename, lambda storage, loc: storage, weights_only=False)
                    warnings.warn("The saved pretrain has an old format using numpy.ndarray instead of torch to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the pretrained embedding using this version ASAP.")
                logger.debug("Loaded pretrain from {}".format(self.filename))
                if not isinstance(data, dict):
                    raise RuntimeError("File {} exists but is not a stanza pretrain file. It is not a dict, whereas a Stanza pretrain should have a dict with 'emb' and 'vocab'".format(self.filename))
                if 'emb' not in data or 'vocab' not in data:
                    raise RuntimeError("File {} exists but is not a stanza pretrain file. A Stanza pretrain file should have 'emb' and 'vocab' fields in its state dict".format(self.filename))
                self._vocab = PretrainedWordVocab.load_state_dict(data['vocab'])
                self._emb = data['emb']
                # convert legacy numpy storage to a torch tensor
                if isinstance(self._emb, np.ndarray):
                    self._emb = torch.from_numpy(self._emb)
                return
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as e:
                # with no raw vector file to fall back on, re-raise the load error
                if not self._vec_filename and not self._csv_filename:
                    raise
                logger.warning("Pretrained file exists but cannot be loaded from {}, due to the following exception:\n\t{}".format(self.filename, e))
                vocab, emb = self.read_pretrain()
        else:
            if not self._vec_filename and not self._csv_filename:
                raise FileNotFoundError("Pretrained file {} does not exist, and no text/xz file was provided".format(self.filename))
            if self.filename is not None:
                logger.info("Pretrained filename %s specified, but file does not exist. Attempting to load from text file" % self.filename)
            vocab, emb = self.read_pretrain()

        self._vocab = vocab
        self._emb = emb

        if self._save_to_file:
            # save to file
            assert self.filename is not None, "Filename must be provided to save pretrained vector to file."
            self.save(self.filename)

    def save(self, filename):
        """Write vocab + embedding to filename; failures are logged, not raised."""
        directory, _ = os.path.split(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # should not infinite loop since the load function sets _vocab and _emb before trying to save
        data = {'vocab': self.vocab.state_dict(), 'emb': self.emb}
        try:
            torch.save(data, filename, _use_new_zipfile_serialization=False)
            logger.info("Saved pretrained vocab and vectors to {}".format(filename))
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as e:
            # saving is best-effort: e.g. a read-only filesystem should not kill the run
            logger.warning("Saving pretrained data failed due to the following exception... continuing anyway.\n\t{}".format(e))


    def write_text(self, filename, header=False):
        """
        Write the vocab & values to a text file

        If header is True, the first line is "<num words> <dim>".
        """
        with open(filename, "w") as fout:
            if header:
                word_dim = self.emb[0].shape[0]
                fout.write("%d %d\n" % (len(self.vocab), word_dim))
            for word_idx, word in enumerate(self.vocab):
                row = self.emb[word_idx].to("cpu")
                fout.write(word)
                fout.write(" ")
                fout.write(" ".join(["%.6f" % x.item() for x in row]))
                fout.write("\n")


    def read_pretrain(self):
        """Read vectors from the raw text/csv source and build (vocab, emb)."""
        # load from pretrained filename
        if self._vec_filename is not None:
            words, emb, failed = self.read_from_file(self._vec_filename, self._max_vocab)
        elif self._csv_filename is not None:
            words, emb = self.read_from_csv(self._csv_filename)
        else:
            raise RuntimeError("Vector file is not provided.")

        # the first len(VOCAB_PREFIX) rows of emb are reserved for special tokens
        if len(emb) - len(VOCAB_PREFIX) != len(words):
            raise RuntimeError("Loaded number of vectors does not match number of words.")

        # Use a fixed vocab size
        if self._max_vocab > len(VOCAB_PREFIX) and self._max_vocab < len(words) + len(VOCAB_PREFIX):
            words = words[:self._max_vocab - len(VOCAB_PREFIX)]
            emb = emb[:self._max_vocab]

        vocab = PretrainedWordVocab(words)

        return vocab, emb

    @staticmethod
    def read_from_csv(filename):
        """
        Read vectors from CSV

        Skips the first row
        """
        logger.info("Reading pretrained vectors from csv file %s ...", filename)
        with open_read_text(filename) as fin:
            csv_reader = csv.reader(fin)
            # the header of the thai csv vector file we have is just the number of columns
            # so we read past the first line
            for line in csv_reader:
                break
            lines = [line for line in csv_reader]

        rows = len(lines)
        # first column is the word; the rest are the vector values
        cols = len(lines[0]) - 1

        # leave the first len(VOCAB_PREFIX) rows zeroed for the special tokens
        emb = torch.zeros((rows + len(VOCAB_PREFIX), cols), dtype=torch.float32)
        for i, line in enumerate(lines):
            emb[i+len(VOCAB_PREFIX)] = torch.tensor([float(x) for x in line[-cols:]], dtype=torch.float32)
        # spaces inside words are stored as nbsp (see PretrainedWordVocab.normalize_unit)
        words = [line[0].replace(' ', '\xa0') for line in lines]
        return words, emb

    @staticmethod
    def read_from_file(filename, max_vocab=None):
        """
        Open a vector file using the provided function and read from it.

        Returns (words, emb, failed) where failed counts undecodable lines.
        A '<unk>' row, if present, is stored at UNK_ID instead of the word list.
        """
        logger.info("Reading pretrained vectors from %s ...", filename)

        # some vector files, such as Google News, use tabs
        tab_space_pattern = re.compile(r"[ \t]+")
        first = True
        cols = None
        lines = []
        failed = 0
        unk_line = None
        with open_read_binary(filename) as f:
            for i, line in enumerate(f):
                try:
                    line = line.decode()
                except UnicodeDecodeError:
                    # count and skip lines that are not valid utf-8
                    failed += 1
                    continue
                line = line.rstrip()
                if not line:
                    continue
                pieces = tab_space_pattern.split(line)
                if first:
                    # the first line contains the number of word vectors and the dimensionality
                    # note that a 1d embedding with a number as the first entry
                    # will fail to read properly. we ignore that case
                    first = False
                    if len(pieces) == 2:
                        cols = int(pieces[1])
                        continue

                if pieces[0] == '<unk>':
                    if unk_line is not None:
                        logger.error("More than one <unk> line in the pretrain! Keeping the most recent one")
                    else:
                        logger.debug("Found an unk line while reading the pretrain")
                    unk_line = pieces
                else:
                    # keep appending until the (optional) max_vocab cap is hit
                    if not max_vocab or max_vocab < 0 or len(lines) < max_vocab:
                        lines.append(pieces)

        if cols is None:
            # another failure case: all words have spaces in them
            cols = min(len(x) for x in lines) - 1
        rows = len(lines)
        emb = torch.zeros((rows + len(VOCAB_PREFIX), cols), dtype=torch.float32)
        if unk_line is not None:
            emb[UNK_ID] = torch.tensor([float(x) for x in unk_line[-cols:]], dtype=torch.float32)
        for i, line in enumerate(lines):
            emb[i+len(VOCAB_PREFIX)] = torch.tensor([float(x) for x in line[-cols:]], dtype=torch.float32)

        # if there were word pieces separated with spaces, rejoin them with nbsp instead
        # this way, the normalize_unit method in vocab.py can find the word at test time
        words = ['\xa0'.join(line[:-cols]) for line in lines]
        if failed > 0:
            logger.info("Failed to read %d lines from embedding", failed)
        return words, emb, failed
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def find_pretrain_file(wordvec_pretrain_file, save_dir, shorthand, lang):
    """
    When training a model, look in a few different places for a .pt file

    If a specific argument was passed in, prefer that location
    Otherwise, check in a few places:
      saved_models/{model}/{shorthand}.pretrain.pt
      saved_models/{model}/{shorthand}_pretrain.pt
      ~/stanza_resources/{language}/pretrain/{shorthand}_pretrain.pt

    Returns the first existing candidate, or the first candidate path if
    none exist (callers may still find the original .txt file there).
    """
    if wordvec_pretrain_file:
        return wordvec_pretrain_file

    # candidate locations, in priority order (refactored from three
    # copy-pasted exists/log/return stanzas into one loop)
    candidates = [os.path.join(save_dir, '{}.pretrain.pt'.format(shorthand)),
                  os.path.join(save_dir, '{}_pretrain.pt'.format(shorthand))]
    if shorthand.find("_") >= 0:
        # try to assemble /home/user/stanza_resources/vi/pretrain/vtb.pt for example
        candidates.append(os.path.join(DEFAULT_MODEL_DIR, lang, 'pretrain', '{}.pt'.format(shorthand.split('_', 1)[1])))

    for pretrain_file in candidates:
        if os.path.exists(pretrain_file):
            logger.debug("Found existing .pt file in %s" % pretrain_file)
            return pretrain_file
        logger.debug("Cannot find pretrained vectors in %s" % pretrain_file)

    # if we can't find it anywhere, just return the first location searched...
    # maybe we'll get lucky and the original .txt file can be found
    return candidates[0]
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
if __name__ == '__main__':
    # Tiny self-test: build a 3-word, 2-dim vector file on disk.
    with open('test.txt', 'w') as fout:
        fout.write('3 2\na 1 1\nb -1 -1\nc 0 0\n')
    # First load converts the text file and saves test.pt.
    pt = Pretrain('test.pt', 'test.txt')
    print(pt.emb)
    # Confirm the .pt on disk round-trips through torch.load.
    saved = torch.load('test.pt', weights_only=True)
    print(saved)
    # Second load should come straight from the saved .pt file.
    pt = Pretrain('test.pt', 'test.txt')
    print(pt.emb)
|
| 289 |
+
|
stanza/stanza/models/common/seq2seq_modules.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pytorch implementation of basic sequence to Sequence modules.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import math
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger('stanza')
|
| 14 |
+
|
| 15 |
+
class BasicAttention(nn.Module):
    """
    A basic MLP (additive) attention layer:
    score = v . tanh(W_in q + W_c k), softmax over source positions.
    """
    def __init__(self, dim):
        super(BasicAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.linear_c = nn.Linear(dim, dim)
        self.linear_v = nn.Linear(dim, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.tanh = nn.Tanh()
        self.sm = nn.Softmax(dim=1)

    def forward(self, input, context, mask=None, attn_only=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        """
        n_batch = context.size(0)
        src_len = context.size(1)
        dim = context.size(2)
        # Project the query and every context position into the scoring space.
        query = self.linear_in(input) # batch x dim
        keys = self.linear_c(context.contiguous().view(-1, dim)).view(n_batch, src_len, dim)
        # Additive combination through tanh, then collapse to one score
        # per source position.
        scores = self.tanh(query.unsqueeze(1).expand_as(context) + keys) # batch x sourceL x dim
        scores = self.linear_v(scores.view(-1, dim)).view(n_batch, src_len)

        if mask is not None:
            # Padding positions get -inf so softmax assigns them zero weight.
            scores.masked_fill_(mask, -constant.INFINITY_NUMBER)

        scores = self.sm(scores)
        if attn_only:
            return scores

        # Attention-weighted context summary, merged back with the query.
        weighted = torch.bmm(scores.unsqueeze(1), context).squeeze(1)
        merged = self.tanh(self.linear_out(torch.cat((weighted, input), 1)))

        return merged, scores
|
| 54 |
+
|
| 55 |
+
class SoftDotAttention(nn.Module):
    """Soft Dot Attention.

    Ref: http://www.aclweb.org/anthology/D15-1166
    Adapted from PyTorch OPEN NMT.

    Scores are the plain dot product of the projected query with each
    context position.
    """

    def __init__(self, dim):
        """Initialize layer."""
        super(SoftDotAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.sm = nn.Softmax(dim=1)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False, return_logattn=False):
        """Propagate input through the network.

        input: batch x dim
        context: batch x sourceL x dim

        Returns (h_tilde, attn); attn is log-probabilities when
        return_logattn is True, otherwise probabilities.  With attn_only,
        only attn is returned.
        """
        target = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

        # Get attention
        attn = torch.bmm(context, target).squeeze(2)  # batch x sourceL

        if mask is not None:
            # set the padding attention logits to -inf
            assert mask.size() == attn.size(), "Mask size must match the attention size!"
            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)

        if return_logattn:
            # keep log-probs for the caller, but exponentiate to get the
            # weights actually used for the context combination
            attn = torch.log_softmax(attn, 1)
            attn_w = torch.exp(attn)
        else:
            attn = self.sm(attn)
            attn_w = attn
        if attn_only:
            return attn

        attn3 = attn_w.view(attn_w.size(0), 1, attn_w.size(1))  # batch x 1 x sourceL

        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        h_tilde = torch.cat((weighted_context, input), 1)

        h_tilde = self.tanh(self.linear_out(h_tilde))

        return h_tilde, attn
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class LinearAttention(nn.Module):
    """ A linear attention form, inspired by BiDAF:
    a = W (u; v; u o v)
    """

    def __init__(self, dim):
        super(LinearAttention, self).__init__()
        self.linear = nn.Linear(dim*3, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        """
        n_batch = context.size(0)
        src_len = context.size(1)
        d = context.size(2)
        # Tile the query across source positions, then flatten both the
        # query and the context to (batch*sourceL) x dim.
        q_flat = input.unsqueeze(1).expand_as(context).contiguous().view(-1, d) # batch*sourceL x dim
        c_flat = context.contiguous().view(-1, d)
        # BiDAF-style feature vector [u; v; u*v] scored by a single linear map.
        scores = self.linear(torch.cat((q_flat, c_flat, q_flat.mul(c_flat)), 1)).view(n_batch, src_len)

        if mask is not None:
            # set the padding attention logits to -inf
            assert mask.size() == scores.size(), "Mask size must match the attention size!"
            scores.masked_fill_(mask, -constant.INFINITY_NUMBER)

        scores = self.sm(scores)
        if attn_only:
            return scores

        # Weighted context summary, then merge with the query.
        weighted = torch.bmm(scores.view(n_batch, 1, src_len), context).squeeze(1) # batch x dim
        merged = self.tanh(self.linear_out(torch.cat((weighted, input), 1)))
        return merged, scores
|
| 147 |
+
|
| 148 |
+
class DeepAttention(nn.Module):
    """ A deep attention form, invented by Robert:
    u = ReLU(Wx)
    v = ReLU(Wy)
    a = V.(u o v)
    """

    def __init__(self, dim):
        super(DeepAttention, self).__init__()
        # NOTE: linear_in is applied to BOTH the query and the context below,
        # so the projection weights are shared between the two.
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.linear_v = nn.Linear(dim, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.relu = nn.ReLU()
        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False):
        """
        input: batch x dim
        context: batch x sourceL x dim

        Returns (h_tilde, attn), or just attn when attn_only is True.
        """
        batch_size = context.size(0)
        source_len = context.size(1)
        dim = context.size(2)
        # tile the query over every source position, then flatten
        u = input.unsqueeze(1).expand_as(context).contiguous().view(-1, dim) # batch*sourceL x dim
        u = self.relu(self.linear_in(u))
        # same (shared) projection applied to the flattened context
        v = self.relu(self.linear_in(context.contiguous().view(-1, dim)))
        # elementwise product scored down to one logit per source position
        attn = self.linear_v(u.mul(v)).view(batch_size, source_len)

        if mask is not None:
            # set the padding attention logits to -inf
            assert mask.size() == attn.size(), "Mask size must match the attention size!"
            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)

        attn = self.sm(attn)
        if attn_only:
            return attn

        attn3 = attn.view(batch_size, 1, source_len) # batch x 1 x sourceL

        weighted_context = torch.bmm(attn3, context).squeeze(1) # batch x dim
        h_tilde = torch.cat((weighted_context, input), 1)
        h_tilde = self.tanh(self.linear_out(h_tilde))
        return h_tilde, attn
|
| 193 |
+
|
| 194 |
+
class LSTMAttention(nn.Module):
    r"""A long short-term memory (LSTM) cell with attention."""

    def __init__(self, input_size, hidden_size, batch_first=True, attn_type='soft'):
        """Initialize params."""
        super(LSTMAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first

        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

        # dispatch table from attention-type name to the layer class
        attn_classes = {
            'soft': SoftDotAttention,
            'mlp': BasicAttention,
            'linear': LinearAttention,
            'deep': DeepAttention,
        }
        if attn_type not in attn_classes:
            raise Exception("Unsupported LSTM attention type: {}".format(attn_type))
        self.attention_layer = attn_classes[attn_type](hidden_size)
        logger.debug("Using {} attention for LSTM.".format(attn_type))

    def forward(self, input, hidden, ctx, ctx_mask=None, return_logattn=False):
        """Propagate input through the network."""
        if self.batch_first:
            input = input.transpose(0, 1)

        outputs = []
        attns = []
        # unroll the cell one time step at a time, attending after each step
        for step in range(input.size(0)):
            hidden = self.lstm_cell(input[step], hidden)
            hy, cy = hidden
            h_tilde, alpha = self.attention_layer(hy, ctx, mask=ctx_mask, return_logattn=return_logattn)
            outputs.append(h_tilde)
            attns.append(alpha)
        output = torch.stack(outputs, 0)

        if self.batch_first:
            output = output.transpose(0, 1)

        if return_logattn:
            attn = torch.stack(attns, 0)
            if self.batch_first:
                attn = attn.transpose(0, 1)
            return output, hidden, attn

        return output, hidden
|
| 244 |
+
|
stanza/stanza/models/common/stanza_object.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def _readonly_setter(self, name):
|
| 2 |
+
full_classname = self.__class__.__module__
|
| 3 |
+
if full_classname is None:
|
| 4 |
+
full_classname = self.__class__.__qualname__
|
| 5 |
+
else:
|
| 6 |
+
full_classname += '.' + self.__class__.__qualname__
|
| 7 |
+
raise ValueError(f'Property "{name}" of "{full_classname}" is read-only.')
|
| 8 |
+
|
| 9 |
+
class StanzaObject(object):
    """
    Base class for all Stanza data objects that allows for some flexibility handling annotations
    """

    @classmethod
    def add_property(cls, name, default=None, getter=None, setter=None):
        """
        Add a property accessible through self.{name} with underlying variable self._{name}.
        Optionally setup a setter as well.
        """

        if hasattr(cls, name):
            raise ValueError(f'Property by the name of {name} already exists in {cls}. Maybe you want to find another name?')

        # the underscore-prefixed class attribute holds the default value
        backing = f'_{name}'
        setattr(cls, backing, default)
        # default getter reads the backing attribute; default setter refuses writes
        read = getter if getter is not None else (lambda self: getattr(self, backing))
        write = setter if setter is not None else (lambda self, value: _readonly_setter(self, name))

        setattr(cls, name, property(read, write))
|
| 31 |
+
|