bowphs committed on
Commit
3f3b639
·
verified ·
1 Parent(s): fff2f9b

Add files using upload-large-folder tool

Browse files
stanza/stanza/models/common/beam.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ import torch
3
+
4
+ import stanza.models.common.seq2seq_constant as constant
5
+
6
+ r"""
7
+ Adapted and modified from the OpenNMT project.
8
+
9
+ Class for managing the internals of the beam search process.
10
+
11
+
12
+ hyp1-hyp1---hyp1 -hyp1
13
+ \ /
14
+ hyp2 \-hyp2 /-hyp2hyp2
15
+ / \
16
+ hyp3-hyp3---hyp3 -hyp3
17
+ ========================
18
+
19
+ Takes care of beams, back pointers, and scores.
20
+ """
21
+
22
+
23
# TORCH COMPATIBILITY
#
# Here we special case trunc division.
# torch < 1.8.0 has no rounding_mode='trunc' argument for torch.div;
# however, there were several versions in a row where // would loudly
# proclaim it was buggy, and users complained about that.
# This hopefully maintains compatibility for torch: probe once at import
# time and pick whichever implementation this torch supports.
try:
    # Probe inline instead of binding throwaway module-level globals
    # (the original left a, b, c polluting the module namespace).
    torch.div(torch.tensor([1.]), torch.tensor([2.]), rounding_mode='trunc')
    def trunc_division(a, b):
        """Integer (truncating) division of a by b via torch.div."""
        return torch.div(a, b, rounding_mode='trunc')
except TypeError:
    def trunc_division(a, b):
        """Integer (truncating) division of a by b for old torch versions."""
        return a // b
40
class Beam(object):
    """Manages one beam during beam search decoding.

    Tracks the per-hypothesis scores, the backpointers into the previous
    step, the emitted token ids, and optional copy indices.  Call
    `advance` once per decoding step; after it returns True, use
    `sort_best`/`get_hyp` to recover the finished hypotheses.
    """

    def __init__(self, size, device=None):
        # size: beam width K
        self.size = size
        self.done = False

        # The score for each translation on the beam.
        self.scores = torch.zeros(size, dtype=torch.float32, device=device)
        self.allScores = []

        # The backpointers at each time-step.
        self.prevKs = []

        # The outputs at each time-step.  Only slot 0 starts with SOS, so
        # the first expansion comes from a single hypothesis.
        self.nextYs = [torch.zeros(size, dtype=torch.int64, device=device).fill_(constant.PAD_ID)]
        self.nextYs[0][0] = constant.SOS_ID

        # The copy indices for each time
        self.copy = []

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.nextYs[-1]

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk, copy_indices=None):
        """
        Given prob over words for every last beam `wordLk`:
        compute and update the beam search.

        Parameters:

        * `wordLk` - probs of advancing from the last step (K x words)
        * `copy_indices` - copy indices (K x ctx_len)

        Returns: True if beam search is complete.
        """
        if self.done:
            return True
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
        else:
            # first step, expand from the first position
            beamLk = wordLk[0]

        flatBeamLk = beamLk.view(-1)

        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
        self.allScores.append(self.scores)
        self.scores = bestScores

        # bestScoresId is a flattened beam x word array, so calculate which
        # word and beam each score came from.
        # bestScoresId holds integer ids and numWords is an integer length,
        # so this needs integer (trunc) division.
        prevK = trunc_division(bestScoresId, numWords)
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)
        if copy_indices is not None:
            self.copy.append(copy_indices.index_select(0, prevK))

        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == constant.EOS_ID:
            self.done = True
            self.allScores.append(self.scores)

        return self.done

    def sort_best(self):
        """Sort beam scores in descending order; returns (scores, indices)."""
        return torch.sort(self.scores, 0, True)

    def get_best(self):
        "Get the score and index of the best hypothesis in the beam."
        # BUG FIX: this used to call the nonexistent self.sortBest() (the
        # method is named sort_best, so the call always raised
        # AttributeError) and returned index 1 — a leftover from a
        # 1-indexed port.  After a descending sort the best is index 0.
        scores, ids = self.sort_best()
        return scores[0], ids[0]

    def get_hyp(self, k):
        """
        Walk back to construct the full hypothesis.

        Parameters:

        * `k` - the position in the beam to construct.

        Returns: The hypothesis
        """
        hyp = []
        cpy = []
        # follow backpointers from the last step to the first
        for j in range(len(self.prevKs) - 1, -1, -1):
            hyp.append(self.nextYs[j+1][k])
            if len(self.copy) > 0:
                cpy.append(self.copy[j][k])
            k = self.prevKs[j][k]

        hyp = hyp[::-1]
        cpy = cpy[::-1]
        # postprocess: if cpy index is not -1, use cpy index instead of hyp word
        for i, cidx in enumerate(cpy):
            if cidx >= 0:
                hyp[i] = -(cidx+1) # make index 1-based and flip it for token generation

        return hyp
stanza/stanza/models/common/convert_pretrain.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A utility script to load a word embedding file from a text file and save it as a .pt
3
+
4
+ Run it as follows:
5
+ python stanza/models/common/convert_pretrain.py <.pt file> <text file> <# vectors>
6
+
7
+ Note that -1 for # of vectors will keep all the vectors.
8
+ You probably want to keep fewer than that for most publicly released
9
+ embeddings, though, as they can get quite large.
10
+
11
+ As a concrete example, you can convert a newly downloaded Faroese WV file as follows:
12
+ python3 stanza/models/common/convert_pretrain.py ~/stanza/saved_models/pos/fo_farpahc.pretrain.pt ~/extern_data/wordvec/fasttext/faroese.txt -1
13
+ or save part of an Icelandic WV file:
14
+ python3 stanza/models/common/convert_pretrain.py ~/stanza/saved_models/pos/is_icepahc.pretrain.pt ~/extern_data/wordvec/fasttext/icelandic.cc.is.300.vec 150000
15
+ Note that if the pretrain already exists, nothing will be changed. It will not overwrite an existing .pt file.
16
+
17
+ """
18
+
19
+ import argparse
20
+ import os
21
+ import sys
22
+
23
+ from stanza.models.common import pretrain
24
+
25
def main():
    """Command-line entry point: convert a raw vector file to a .pt pretrain.

    If the output file already exists it is loaded rather than rebuilt, so
    nothing is overwritten (the Pretrain class only reads the raw vectors
    when no converted file is present).
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument("output_pt", default=None, help="Where to write the converted PT file")
    argparser.add_argument("input_vec", default=None, help="Unconverted vectors file")
    argparser.add_argument("max_vocab", type=int, default=-1, nargs="?", help="How many vectors to convert. -1 means convert them all")
    opts = argparser.parse_args()

    if os.path.exists(opts.output_pt):
        print("Not overwriting existing pretrain file in %s" % opts.output_pt)

    # csv sources go through the csv_filename path; everything else is
    # treated as a whitespace-separated text vector file
    if not opts.input_vec.endswith(".csv"):
        pt = pretrain.Pretrain(opts.output_pt, opts.input_vec, max_vocab=opts.max_vocab)
    else:
        pt = pretrain.Pretrain(opts.output_pt, max_vocab=opts.max_vocab, csv_filename=opts.input_vec)
    print("Pretrain is of size {}".format(len(pt.vocab)))

if __name__ == '__main__':
    main()
stanza/stanza/models/common/pretrain.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Supports for pretrained data.
3
+ """
4
+ import csv
5
+ import os
6
+ import re
7
+
8
+ import lzma
9
+ import logging
10
+ import numpy as np
11
+ import torch
12
+
13
+ from .vocab import BaseVocab, VOCAB_PREFIX, UNK_ID
14
+
15
+ from stanza.models.common.utils import open_read_binary, open_read_text
16
+ from stanza.resources.common import DEFAULT_MODEL_DIR
17
+
18
+ from pickle import UnpicklingError
19
+ import warnings
20
+
21
+ logger = logging.getLogger('stanza')
22
+
23
+ class PretrainedWordVocab(BaseVocab):
24
+ def build_vocab(self):
25
+ self._id2unit = VOCAB_PREFIX + self.data
26
+ self._unit2id = {w:i for i, w in enumerate(self._id2unit)}
27
+
28
+ def normalize_unit(self, unit):
29
+ unit = super().normalize_unit(unit)
30
+ if unit:
31
+ unit = unit.replace(" ","\xa0")
32
+ return unit
33
+
34
+ class Pretrain:
35
+ """ A loader and saver for pretrained embeddings. """
36
+
37
+ def __init__(self, filename=None, vec_filename=None, max_vocab=-1, save_to_file=True, csv_filename=None):
38
+ self.filename = filename
39
+ self._vec_filename = vec_filename
40
+ self._csv_filename = csv_filename
41
+ self._max_vocab = max_vocab
42
+ self._save_to_file = save_to_file
43
+
44
+ @property
45
+ def vocab(self):
46
+ if not hasattr(self, '_vocab'):
47
+ self.load()
48
+ return self._vocab
49
+
50
+ @property
51
+ def emb(self):
52
+ if not hasattr(self, '_emb'):
53
+ self.load()
54
+ return self._emb
55
+
56
+ def load(self):
57
+ if self.filename is not None and os.path.exists(self.filename):
58
+ try:
59
+ # TODO: after making the next release, remove the weights_only=False version
60
+ try:
61
+ data = torch.load(self.filename, lambda storage, loc: storage, weights_only=True)
62
+ except UnpicklingError:
63
+ data = torch.load(self.filename, lambda storage, loc: storage, weights_only=False)
64
+ warnings.warn("The saved pretrain has an old format using numpy.ndarray instead of torch to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the pretrained embedding using this version ASAP.")
65
+ logger.debug("Loaded pretrain from {}".format(self.filename))
66
+ if not isinstance(data, dict):
67
+ raise RuntimeError("File {} exists but is not a stanza pretrain file. It is not a dict, whereas a Stanza pretrain should have a dict with 'emb' and 'vocab'".format(self.filename))
68
+ if 'emb' not in data or 'vocab' not in data:
69
+ raise RuntimeError("File {} exists but is not a stanza pretrain file. A Stanza pretrain file should have 'emb' and 'vocab' fields in its state dict".format(self.filename))
70
+ self._vocab = PretrainedWordVocab.load_state_dict(data['vocab'])
71
+ self._emb = data['emb']
72
+ if isinstance(self._emb, np.ndarray):
73
+ self._emb = torch.from_numpy(self._emb)
74
+ return
75
+ except (KeyboardInterrupt, SystemExit):
76
+ raise
77
+ except BaseException as e:
78
+ if not self._vec_filename and not self._csv_filename:
79
+ raise
80
+ logger.warning("Pretrained file exists but cannot be loaded from {}, due to the following exception:\n\t{}".format(self.filename, e))
81
+ vocab, emb = self.read_pretrain()
82
+ else:
83
+ if not self._vec_filename and not self._csv_filename:
84
+ raise FileNotFoundError("Pretrained file {} does not exist, and no text/xz file was provided".format(self.filename))
85
+ if self.filename is not None:
86
+ logger.info("Pretrained filename %s specified, but file does not exist. Attempting to load from text file" % self.filename)
87
+ vocab, emb = self.read_pretrain()
88
+
89
+ self._vocab = vocab
90
+ self._emb = emb
91
+
92
+ if self._save_to_file:
93
+ # save to file
94
+ assert self.filename is not None, "Filename must be provided to save pretrained vector to file."
95
+ self.save(self.filename)
96
+
97
+ def save(self, filename):
98
+ directory, _ = os.path.split(filename)
99
+ if directory:
100
+ os.makedirs(directory, exist_ok=True)
101
+ # should not infinite loop since the load function sets _vocab and _emb before trying to save
102
+ data = {'vocab': self.vocab.state_dict(), 'emb': self.emb}
103
+ try:
104
+ torch.save(data, filename, _use_new_zipfile_serialization=False)
105
+ logger.info("Saved pretrained vocab and vectors to {}".format(filename))
106
+ except (KeyboardInterrupt, SystemExit):
107
+ raise
108
+ except BaseException as e:
109
+ logger.warning("Saving pretrained data failed due to the following exception... continuing anyway.\n\t{}".format(e))
110
+
111
+
112
+ def write_text(self, filename, header=False):
113
+ """
114
+ Write the vocab & values to a text file
115
+ """
116
+ with open(filename, "w") as fout:
117
+ if header:
118
+ word_dim = self.emb[0].shape[0]
119
+ fout.write("%d %d\n" % (len(self.vocab), word_dim))
120
+ for word_idx, word in enumerate(self.vocab):
121
+ row = self.emb[word_idx].to("cpu")
122
+ fout.write(word)
123
+ fout.write(" ")
124
+ fout.write(" ".join(["%.6f" % x.item() for x in row]))
125
+ fout.write("\n")
126
+
127
+
128
+ def read_pretrain(self):
129
+ # load from pretrained filename
130
+ if self._vec_filename is not None:
131
+ words, emb, failed = self.read_from_file(self._vec_filename, self._max_vocab)
132
+ elif self._csv_filename is not None:
133
+ words, emb = self.read_from_csv(self._csv_filename)
134
+ else:
135
+ raise RuntimeError("Vector file is not provided.")
136
+
137
+ if len(emb) - len(VOCAB_PREFIX) != len(words):
138
+ raise RuntimeError("Loaded number of vectors does not match number of words.")
139
+
140
+ # Use a fixed vocab size
141
+ if self._max_vocab > len(VOCAB_PREFIX) and self._max_vocab < len(words) + len(VOCAB_PREFIX):
142
+ words = words[:self._max_vocab - len(VOCAB_PREFIX)]
143
+ emb = emb[:self._max_vocab]
144
+
145
+ vocab = PretrainedWordVocab(words)
146
+
147
+ return vocab, emb
148
+
149
+ @staticmethod
150
+ def read_from_csv(filename):
151
+ """
152
+ Read vectors from CSV
153
+
154
+ Skips the first row
155
+ """
156
+ logger.info("Reading pretrained vectors from csv file %s ...", filename)
157
+ with open_read_text(filename) as fin:
158
+ csv_reader = csv.reader(fin)
159
+ # the header of the thai csv vector file we have is just the number of columns
160
+ # so we read past the first line
161
+ for line in csv_reader:
162
+ break
163
+ lines = [line for line in csv_reader]
164
+
165
+ rows = len(lines)
166
+ cols = len(lines[0]) - 1
167
+
168
+ emb = torch.zeros((rows + len(VOCAB_PREFIX), cols), dtype=torch.float32)
169
+ for i, line in enumerate(lines):
170
+ emb[i+len(VOCAB_PREFIX)] = torch.tensor([float(x) for x in line[-cols:]], dtype=torch.float32)
171
+ words = [line[0].replace(' ', '\xa0') for line in lines]
172
+ return words, emb
173
+
174
+ @staticmethod
175
+ def read_from_file(filename, max_vocab=None):
176
+ """
177
+ Open a vector file using the provided function and read from it.
178
+ """
179
+ logger.info("Reading pretrained vectors from %s ...", filename)
180
+
181
+ # some vector files, such as Google News, use tabs
182
+ tab_space_pattern = re.compile(r"[ \t]+")
183
+ first = True
184
+ cols = None
185
+ lines = []
186
+ failed = 0
187
+ unk_line = None
188
+ with open_read_binary(filename) as f:
189
+ for i, line in enumerate(f):
190
+ try:
191
+ line = line.decode()
192
+ except UnicodeDecodeError:
193
+ failed += 1
194
+ continue
195
+ line = line.rstrip()
196
+ if not line:
197
+ continue
198
+ pieces = tab_space_pattern.split(line)
199
+ if first:
200
+ # the first line contains the number of word vectors and the dimensionality
201
+ # note that a 1d embedding with a number as the first entry
202
+ # will fail to read properly. we ignore that case
203
+ first = False
204
+ if len(pieces) == 2:
205
+ cols = int(pieces[1])
206
+ continue
207
+
208
+ if pieces[0] == '<unk>':
209
+ if unk_line is not None:
210
+ logger.error("More than one <unk> line in the pretrain! Keeping the most recent one")
211
+ else:
212
+ logger.debug("Found an unk line while reading the pretrain")
213
+ unk_line = pieces
214
+ else:
215
+ if not max_vocab or max_vocab < 0 or len(lines) < max_vocab:
216
+ lines.append(pieces)
217
+
218
+ if cols is None:
219
+ # another failure case: all words have spaces in them
220
+ cols = min(len(x) for x in lines) - 1
221
+ rows = len(lines)
222
+ emb = torch.zeros((rows + len(VOCAB_PREFIX), cols), dtype=torch.float32)
223
+ if unk_line is not None:
224
+ emb[UNK_ID] = torch.tensor([float(x) for x in unk_line[-cols:]], dtype=torch.float32)
225
+ for i, line in enumerate(lines):
226
+ emb[i+len(VOCAB_PREFIX)] = torch.tensor([float(x) for x in line[-cols:]], dtype=torch.float32)
227
+
228
+ # if there were word pieces separated with spaces, rejoin them with nbsp instead
229
+ # this way, the normalize_unit method in vocab.py can find the word at test time
230
+ words = ['\xa0'.join(line[:-cols]) for line in lines]
231
+ if failed > 0:
232
+ logger.info("Failed to read %d lines from embedding", failed)
233
+ return words, emb, failed
234
+
235
+
236
+ def find_pretrain_file(wordvec_pretrain_file, save_dir, shorthand, lang):
237
+ """
238
+ When training a model, look in a few different places for a .pt file
239
+
240
+ If a specific argument was passsed in, prefer that location
241
+ Otherwise, check in a few places:
242
+ saved_models/{model}/{shorthand}.pretrain.pt
243
+ saved_models/{model}/{shorthand}_pretrain.pt
244
+ ~/stanza_resources/{language}/pretrain/{shorthand}_pretrain.pt
245
+ """
246
+ if wordvec_pretrain_file:
247
+ return wordvec_pretrain_file
248
+
249
+ default_pretrain_file = os.path.join(save_dir, '{}.pretrain.pt'.format(shorthand))
250
+ if os.path.exists(default_pretrain_file):
251
+ logger.debug("Found existing .pt file in %s" % default_pretrain_file)
252
+ return default_pretrain_file
253
+ else:
254
+ logger.debug("Cannot find pretrained vectors in %s" % default_pretrain_file)
255
+
256
+ pretrain_file = os.path.join(save_dir, '{}_pretrain.pt'.format(shorthand))
257
+ if os.path.exists(pretrain_file):
258
+ logger.debug("Found existing .pt file in %s" % pretrain_file)
259
+ return pretrain_file
260
+ else:
261
+ logger.debug("Cannot find pretrained vectors in %s" % pretrain_file)
262
+
263
+ if shorthand.find("_") >= 0:
264
+ # try to assemble /home/user/stanza_resources/vi/pretrain/vtb.pt for example
265
+ pretrain_file = os.path.join(DEFAULT_MODEL_DIR, lang, 'pretrain', '{}.pt'.format(shorthand.split('_', 1)[1]))
266
+ if os.path.exists(pretrain_file):
267
+ logger.debug("Found existing .pt file in %s" % pretrain_file)
268
+ return pretrain_file
269
+ else:
270
+ logger.debug("Cannot find pretrained vectors in %s" % pretrain_file)
271
+
272
+ # if we can't find it anywhere, just return the first location searched...
273
+ # maybe we'll get lucky and the original .txt file can be found
274
+ return default_pretrain_file
275
+
276
+
277
+ if __name__ == '__main__':
278
+ with open('test.txt', 'w') as fout:
279
+ fout.write('3 2\na 1 1\nb -1 -1\nc 0 0\n')
280
+ # 1st load: save to pt file
281
+ pretrain = Pretrain('test.pt', 'test.txt')
282
+ print(pretrain.emb)
283
+ # verify pt file
284
+ x = torch.load('test.pt', weights_only=True)
285
+ print(x)
286
+ # 2nd load: load saved pt file
287
+ pretrain = Pretrain('test.pt', 'test.txt')
288
+ print(pretrain.emb)
289
+
stanza/stanza/models/common/seq2seq_modules.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pytorch implementation of basic sequence to Sequence modules.
3
+ """
4
+
5
+ import logging
6
+ import torch
7
+ import torch.nn as nn
8
+ import math
9
+ import numpy as np
10
+
11
+ import stanza.models.common.seq2seq_constant as constant
12
+
13
+ logger = logging.getLogger('stanza')
14
+
15
class BasicAttention(nn.Module):
    """
    A basic MLP attention layer: scores are a tanh-MLP over the projected
    query added to each projected context position.
    """
    def __init__(self, dim):
        super(BasicAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.linear_c = nn.Linear(dim, dim)
        self.linear_v = nn.Linear(dim, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.tanh = nn.Tanh()
        self.sm = nn.Softmax(dim=1)

    def forward(self, input, context, mask=None, attn_only=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        """
        batch_size, source_len, dim = context.size(0), context.size(1), context.size(2)

        # project the query and every context position
        query = self.linear_in(input)  # batch x dim
        keys = self.linear_c(context.contiguous().view(-1, dim)).view(batch_size, source_len, dim)
        # single-head tanh MLP produces one score per source position
        scores = self.tanh(query.unsqueeze(1).expand_as(context) + keys)  # batch x sourceL x dim
        scores = self.linear_v(scores.view(-1, dim)).view(batch_size, source_len)

        if mask is not None:
            # padding positions get -inf so softmax zeroes them out
            scores.masked_fill_(mask, -constant.INFINITY_NUMBER)

        scores = self.sm(scores)
        if attn_only:
            return scores

        # blend the context by the attention weights and combine with input
        weighted_context = torch.bmm(scores.unsqueeze(1), context).squeeze(1)
        h_tilde = self.tanh(self.linear_out(torch.cat((weighted_context, input), 1)))

        return h_tilde, scores
54
+
55
class SoftDotAttention(nn.Module):
    """Soft Dot Attention.

    Ref: http://www.aclweb.org/anthology/D15-1166
    Adapted from PyTorch OPEN NMT.
    """

    def __init__(self, dim):
        """Initialize layer."""
        super(SoftDotAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.sm = nn.Softmax(dim=1)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False, return_logattn=False):
        """Score `input` against each `context` position by dot product.

        input: batch x dim
        context: batch x sourceL x dim

        With return_logattn=True the returned attention is in log space.
        """
        # batch x dim x 1 projected query
        query = self.linear_in(input).unsqueeze(2)

        # dot product against every source position: batch x sourceL
        scores = torch.bmm(context, query).squeeze(2)

        if mask is not None:
            # set the padding attention logits to -inf
            assert mask.size() == scores.size(), "Mask size must match the attention size!"
            scores.masked_fill_(mask, -constant.INFINITY_NUMBER)

        if return_logattn:
            # keep log-probabilities for the caller; mixing weights are exp
            scores = torch.log_softmax(scores, 1)
            weights = torch.exp(scores)
        else:
            scores = self.sm(scores)
            weights = scores
        if attn_only:
            return scores

        # batch x 1 x sourceL times batch x sourceL x dim -> batch x dim
        weighted_context = torch.bmm(weights.view(weights.size(0), 1, weights.size(1)), context).squeeze(1)
        h_tilde = self.tanh(self.linear_out(torch.cat((weighted_context, input), 1)))

        return h_tilde, scores
104
+
105
+
106
class LinearAttention(nn.Module):
    """ A linear attention form, inspired by BiDAF:
    a = W (u; v; u o v)
    """

    def __init__(self, dim):
        super(LinearAttention, self).__init__()
        self.linear = nn.Linear(dim*3, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        """
        batch_size, source_len, dim = context.size(0), context.size(1), context.size(2)

        # tile the query so every source position sees (u; v; u*v)
        u = input.unsqueeze(1).expand_as(context).contiguous().view(-1, dim)  # batch*sourceL x dim
        v = context.contiguous().view(-1, dim)
        scores = self.linear(torch.cat((u, v, u.mul(v)), 1)).view(batch_size, source_len)

        if mask is not None:
            # set the padding attention logits to -inf
            assert mask.size() == scores.size(), "Mask size must match the attention size!"
            scores.masked_fill_(mask, -constant.INFINITY_NUMBER)

        scores = self.sm(scores)
        if attn_only:
            return scores

        # blend the context by the attention weights and combine with input
        weighted_context = torch.bmm(scores.view(batch_size, 1, source_len), context).squeeze(1)  # batch x dim
        h_tilde = self.tanh(self.linear_out(torch.cat((weighted_context, input), 1)))
        return h_tilde, scores
147
+
148
class DeepAttention(nn.Module):
    """ A deep attention form, invented by Robert:
    u = ReLU(Wx)
    v = ReLU(Wy)
    a = V.(u o v)
    """

    def __init__(self, dim):
        super(DeepAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.linear_v = nn.Linear(dim, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.relu = nn.ReLU()
        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        """
        batch_size, source_len, dim = context.size(0), context.size(1), context.size(2)

        # shared projection for query and context, scored elementwise
        u = self.relu(self.linear_in(input.unsqueeze(1).expand_as(context).contiguous().view(-1, dim)))
        v = self.relu(self.linear_in(context.contiguous().view(-1, dim)))
        scores = self.linear_v(u.mul(v)).view(batch_size, source_len)

        if mask is not None:
            # set the padding attention logits to -inf
            assert mask.size() == scores.size(), "Mask size must match the attention size!"
            scores.masked_fill_(mask, -constant.INFINITY_NUMBER)

        scores = self.sm(scores)
        if attn_only:
            return scores

        # blend the context by the attention weights and combine with input
        weighted_context = torch.bmm(scores.view(batch_size, 1, source_len), context).squeeze(1)  # batch x dim
        h_tilde = self.tanh(self.linear_out(torch.cat((weighted_context, input), 1)))
        return h_tilde, scores
193
+
194
class LSTMAttention(nn.Module):
    r"""A long short-term memory (LSTM) cell with attention over a context."""

    def __init__(self, input_size, hidden_size, batch_first=True, attn_type='soft'):
        """Initialize params.

        attn_type selects the attention layer: 'soft', 'mlp', 'linear', or 'deep'.
        """
        super(LSTMAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first

        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

        if attn_type == 'soft':
            self.attention_layer = SoftDotAttention(hidden_size)
        elif attn_type == 'mlp':
            self.attention_layer = BasicAttention(hidden_size)
        elif attn_type == 'linear':
            self.attention_layer = LinearAttention(hidden_size)
        elif attn_type == 'deep':
            self.attention_layer = DeepAttention(hidden_size)
        else:
            raise Exception("Unsupported LSTM attention type: {}".format(attn_type))
        logger.debug("Using {} attention for LSTM.".format(attn_type))

    def forward(self, input, hidden, ctx, ctx_mask=None, return_logattn=False):
        """Propagate input through the network.

        input: (batch x seq x input_size) when batch_first, else seq-major.
        hidden: (h, c) LSTMCell state.
        ctx: context the attention layer attends over.
        return_logattn: also return stacked log-attention (soft attention only).
        """
        if self.batch_first:
            input = input.transpose(0,1)

        output = []
        attn = []
        steps = range(input.size(0))
        for i in steps:
            hidden = self.lstm_cell(input[i], hidden)
            hy, cy = hidden
            # BUG FIX: only SoftDotAttention accepts return_logattn; passing
            # the keyword unconditionally raised TypeError for the 'mlp',
            # 'linear', and 'deep' layers even when it was the default False.
            if return_logattn:
                h_tilde, alpha = self.attention_layer(hy, ctx, mask=ctx_mask, return_logattn=True)
            else:
                h_tilde, alpha = self.attention_layer(hy, ctx, mask=ctx_mask)
            output.append(h_tilde)
            attn.append(alpha)
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        if self.batch_first:
            output = output.transpose(0,1)

        if return_logattn:
            attn = torch.stack(attn, 0)
            if self.batch_first:
                attn = attn.transpose(0, 1)
            return output, hidden, attn

        return output, hidden
244
+
stanza/stanza/models/common/stanza_object.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def _readonly_setter(self, name):
2
+ full_classname = self.__class__.__module__
3
+ if full_classname is None:
4
+ full_classname = self.__class__.__qualname__
5
+ else:
6
+ full_classname += '.' + self.__class__.__qualname__
7
+ raise ValueError(f'Property "{name}" of "{full_classname}" is read-only.')
8
+
9
class StanzaObject(object):
    """
    Base class for all Stanza data objects that allows for some flexibility handling annotations
    """

    @classmethod
    def add_property(cls, name, default=None, getter=None, setter=None):
        """
        Add a property accessible through self.{name} with underlying variable self._{name}.
        Optionally setup a setter as well; by default the property is read-only.
        """

        if hasattr(cls, name):
            raise ValueError(f'Property by the name of {name} already exists in {cls}. Maybe you want to find another name?')

        private_name = f'_{name}'
        # the class-level attribute acts as the default value
        setattr(cls, private_name, default)
        if getter is None:
            def getter(self):
                return getattr(self, private_name)
        if setter is None:
            # default setter rejects assignment with a descriptive error
            def setter(self, value):
                return _readonly_setter(self, name)

        setattr(cls, name, property(getter, setter))
31
+