tyfsadik committed on
Commit
e38d58c
·
verified ·
1 Parent(s): dbe470d

Upload 7 files

Browse files
gector/bert_token_embedder.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tweaked version of corresponding AllenNLP file"""
2
+ import logging
3
+ from copy import deepcopy
4
+ from typing import Dict
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
9
+ from allennlp.nn import util
10
+ from transformers import AutoModel, PreTrainedModel
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class PretrainedBertModel:
    """
    Factory with a process-wide cache for pretrained transformer models.

    Loading the same BERT model twice (e.g. once as a token embedder and
    once as a pooling layer) is wasteful; this factory hands back the
    already-loaded instance instead of loading it again.
    """

    # model name -> loaded model; shared by every caller of ``load``
    _cache: Dict[str, PreTrainedModel] = {}

    @classmethod
    def load(cls, model_name: str, cache_model: bool = True) -> PreTrainedModel:
        """Return the model for ``model_name``, reusing a cached copy when present."""
        try:
            return cls._cache[model_name]
        except KeyError:
            pass

        loaded = AutoModel.from_pretrained(model_name)
        if cache_model:
            cls._cache[model_name] = loaded
        return loaded
34
+
35
+
36
class BertEmbedder(TokenEmbedder):
    """
    A ``TokenEmbedder`` that produces BERT embeddings for your tokens.
    Should be paired with a ``BertIndexer``, which produces wordpiece ids.
    Most likely you probably want to use ``PretrainedBertEmbedder``
    for one of the named pretrained models, not this base class.

    Parameters
    ----------
    bert_model: ``BertModel``
        The BERT model being wrapped.
    top_layer_only: ``bool``, optional (default = ``False``)
        If ``True``, then only return the top layer instead of apply the scalar mix.
    max_pieces : int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Assuming the inputs are windowed
        and padded appropriately by this length, the embedder will split them into a
        large batch, feed them into BERT, and recombine the output as if it was a
        longer sequence.
    num_start_tokens : int, optional (default: 1)
        The number of starting special tokens input to BERT (usually 1, i.e., [CLS])
    num_end_tokens : int, optional (default: 1)
        The number of ending tokens input to BERT (usually 1, i.e., [SEP])
    scalar_mix_parameters: ``List[float]``, optional, (default = None)
        If not ``None``, use these scalar mix parameters to weight the representations
        produced by different layers. These mixing weights are not updated during
        training.
    """

    def __init__(
        self,
        bert_model: PreTrainedModel,
        top_layer_only: bool = False,
        max_pieces: int = 512,
        num_start_tokens: int = 1,
        num_end_tokens: int = 1
    ) -> None:
        super().__init__()
        # Deep-copy so that later mutation of this embedder's weights (e.g.
        # freezing or resizing token embeddings) cannot affect other users of
        # the same model object.
        self.bert_model = deepcopy(bert_model)
        self.output_dim = bert_model.config.hidden_size
        self.max_pieces = max_pieces
        self.num_start_tokens = num_start_tokens
        self.num_end_tokens = num_end_tokens
        # NOTE(review): ``top_layer_only`` is accepted but never stored, and the
        # scalar mix is always disabled here — confirm this is intended.
        self._scalar_mix = None

    def set_weights(self, freeze):
        # Freeze (freeze=True) or unfreeze (freeze=False) every parameter of
        # the wrapped BERT model.
        for param in self.bert_model.parameters():
            param.requires_grad = not freeze
        return

    def get_output_dim(self) -> int:
        # Hidden size of the wrapped transformer (set in __init__).
        return self.output_dim

    def forward(
        self,
        input_ids: torch.LongTensor,
        offsets: torch.LongTensor = None
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.
            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        """

        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive an input tensor that has a sequence length longer than can
        # be fit. In that case, we should expect the wordpiece indexer to create padded windows
        # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
        # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
        # We can then split the sequence into sub-sequences of that length, and concatenate them
        # along the batch dimension so we effectively have one huge batch of partial sentences.
        # This can then be fed into BERT without any sentence length issues. Keep in mind
        # that the memory consumption can dramatically increase for large batches with extremely
        # long sentences.
        needs_split = full_seq_len > self.max_pieces
        last_window_size = 0
        if needs_split:
            # Split the flattened list by the window size, `max_pieces`
            split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))

            # We want all sequences to be the same length, so pad the last sequence
            last_window_size = split_input_ids[-1].size(-1)
            padding_amount = self.max_pieces - last_window_size
            split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0)

            # Now combine the sequences along the batch dimension
            input_ids = torch.cat(split_input_ids, dim=0)

        # Attention mask: non-zero ids are real wordpieces, zeros are padding.
        input_mask = (input_ids != 0).long()
        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            attention_mask=util.combine_initial_dims(input_mask),
        )[0]
        # Normalize the output to a (layers, batch, seq, dim)-style stack so the
        # code below can index layers uniformly; the shape of element [0] tells
        # us whether we got a tuple of per-layer tensors or a single tensor.
        if len(all_encoder_layers[0].shape) == 3:
            all_encoder_layers = torch.stack(all_encoder_layers)
        elif len(all_encoder_layers[0].shape) == 2:
            all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0)

        if needs_split:
            # First, unpack the output embeddings into one long sequence again
            unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
            unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

            # Next, select indices of the sequence such that it will result in embeddings representing the original
            # sentence. To capture maximal context, the indices will be the middle part of each embedded window
            # sub-sequence (plus any leftover start and final edge windows), e.g.,
            #  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
            # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
            # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
            # and final windows with indices [0, 1] and [14, 15] respectively.

            # Find the stride as half the max pieces, ignoring the special start and end tokens
            # Calculate an offset to extract the centermost embeddings of each window
            stride = (self.max_pieces - self.num_start_tokens - self.num_end_tokens) // 2
            stride_offset = stride // 2 + self.num_start_tokens

            first_window = list(range(stride_offset))

            max_context_windows = [
                i
                for i in range(full_seq_len)
                if stride_offset - 1 < i % self.max_pieces < stride_offset + stride
            ]

            # Lookback what's left, unless it's the whole self.max_pieces window
            if full_seq_len % self.max_pieces == 0:
                lookback = self.max_pieces
            else:
                lookback = full_seq_len % self.max_pieces

            final_window_start = full_seq_len - lookback + stride_offset + stride
            final_window = list(range(final_window_start, full_seq_len))

            select_indices = first_window + max_context_windows + final_window

            initial_dims.append(len(select_indices))

            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers

        # Recombine the outputs of all layers
        # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        # recombined = torch.cat(combined, dim=2)
        # NOTE(review): this mask is computed from the *embedding values*, not
        # from padding ids — it marks positions whose embedding is exactly 0.
        # Confirm this is intended; it is only consumed by the scalar mix below.
        input_mask = (recombined_embeddings != 0).long()

        if self._scalar_mix is not None:
            mix = self._scalar_mix(recombined_embeddings, input_mask)
        else:
            # No scalar mix configured (always the case per __init__): take the
            # top layer of the stacked encoder outputs.
            mix = recombined_embeddings[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            dims = initial_dims if needs_split else input_ids.size()
            return util.uncombine_initial_dims(mix, dims)
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            range_vector = util.get_range_vector(
                offsets2d.size(0), device=util.get_device_of(mix)
            ).unsqueeze(1)
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
            selected_embeddings = mix[range_vector, offsets2d]

            return util.uncombine_initial_dims(selected_embeddings, offsets.size())
222
+
223
+
224
# @TokenEmbedder.register("bert-pretrained")
class PretrainedBertEmbedder(BertEmbedder):

    """
    A ``BertEmbedder`` constructed from a named pretrained checkpoint.

    Parameters
    ----------
    pretrained_model: ``str``
        Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
        or the path to the .tar.gz file with the model weights.
        If the name is a key in the list of pretrained models at
        https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L41
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    requires_grad : ``bool``, optional (default = False)
        If True, compute gradient of BERT parameters for fine tuning.
    top_layer_only: ``bool``, optional (default = ``False``)
        If ``True``, then only return the top layer instead of apply the scalar mix.
    special_tokens_fix: ``int``, optional (default = 0)
        When non-zero, grow the wordpiece embedding matrix by one extra row so
        that an added special token has an embedding slot.
    """

    def __init__(
        self,
        pretrained_model: str,
        requires_grad: bool = False,
        top_layer_only: bool = False,
        special_tokens_fix: int = 0,
    ) -> None:
        loaded = PretrainedBertModel.load(pretrained_model)

        # Freeze (or unfreeze) every transformer weight up front.
        for parameter in loaded.parameters():
            parameter.requires_grad = requires_grad

        super().__init__(
            bert_model=loaded,
            top_layer_only=top_layer_only
        )

        if special_tokens_fix:
            try:
                current_vocab = self.bert_model.embeddings.word_embeddings.num_embeddings
            except AttributeError:
                # Model without a BERT-style ``embeddings`` module:
                # reserve more space
                current_vocab = self.bert_model.word_embedding.num_embeddings + 5
            self.bert_model.resize_token_embeddings(current_vocab + 1)
gector/datareader.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tweaked AllenNLP dataset reader."""
2
+ import logging
3
+ import re
4
+ from random import random
5
+ from typing import Dict, List
6
+
7
+ from allennlp.common.file_utils import cached_path
8
+ from allennlp.data.dataset_readers.dataset_reader import DatasetReader
9
+ from allennlp.data.fields import TextField, SequenceLabelField, MetadataField, Field
10
+ from allennlp.data.instance import Instance
11
+ from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
12
+ from allennlp.data.tokenizers import Token
13
+ from overrides import overrides
14
+
15
+ from utils.helpers import SEQ_DELIMETERS, START_TOKEN
16
+
17
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
@DatasetReader.register("seq2labels_datareader")
class Seq2LabelsDatasetReader(DatasetReader):
    """
    Reads instances from a pretokenised file where each line is in the following format:

    WORD###TAG [TAB] WORD###TAG [TAB] ..... \n

    and converts it into a ``Dataset`` suitable for sequence tagging. You can also specify
    alternative delimiters in the constructor.

    Parameters
    ----------
    delimeters: ``dict``
        The dictionary with all delimiters.
    token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
        We use this to define the input representation for the text. See :class:`TokenIndexer`.
        Note that the `output` tags will always correspond to single token IDs based on how they
        are pre-tokenised in the data file.
    max_len: if set then long sentences will be truncated
    """
    # fix broken sentences mostly in Lang8: a period immediately followed by a
    # letter other than uppercase 'S' (the regexp skips 'S' so abbreviations
    # like ".S" survive) marks a line as broken
    BROKEN_SENTENCES_REGEXP = re.compile(r'\.[a-zA-RT-Z]')

    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 delimeters: dict = SEQ_DELIMETERS,
                 skip_correct: bool = False,
                 skip_complex: int = 0,
                 lazy: bool = False,
                 max_len: int = None,
                 test_mode: bool = False,
                 tag_strategy: str = "keep_one",
                 tn_prob: float = 0,
                 tp_prob: float = 0,
                 broken_dot_strategy: str = "keep") -> None:
        super().__init__(lazy)
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
        self._delimeters = delimeters
        self._max_len = max_len
        # probabilistic subsampling of correct (TN) / incorrect (TP) sentences
        # is controlled by _tn_prob / _tp_prob in text_to_instance
        self._skip_correct = skip_correct
        self._skip_complex = skip_complex
        self._tag_strategy = tag_strategy
        self._broken_dot_strategy = broken_dot_strategy
        self._test_mode = test_mode
        self._tn_prob = tn_prob
        self._tp_prob = tp_prob

    @overrides
    def _read(self, file_path):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)
        with open(file_path, "r") as data_file:
            logger.info("Reading instances from lines in file at: %s", file_path)
            for line in data_file:
                line = line.strip("\n")
                # skip blank lines, and (outside test mode, with the 'skip'
                # strategy) lines that look like broken sentences
                if not line or (not self._test_mode and self._broken_dot_strategy == 'skip'
                                and self.BROKEN_SENTENCES_REGEXP.search(line) is not None):
                    continue

                # split each WORD###TAG pair on the *last* label delimiter,
                # so the word itself may contain the delimiter characters
                tokens_and_tags = [pair.rsplit(self._delimeters['labels'], 1)
                                   for pair in line.split(self._delimeters['tokens'])]
                try:
                    tokens = [Token(token) for token, tag in tokens_and_tags]
                    tags = [tag for token, tag in tokens_and_tags]
                except ValueError:
                    # a pair had no label delimiter (untagged input): keep the
                    # bare words and carry no tags
                    tokens = [Token(token[0]) for token in tokens_and_tags]
                    tags = None

                # ensure every sequence starts with the special START token
                if tokens and tokens[0] != Token(START_TOKEN):
                    tokens = [Token(START_TOKEN)] + tokens

                # NOTE(review): ``words`` is captured *before* truncation, so the
                # metadata may be longer than the indexed tokens — presumably
                # intentional (restores full sentences downstream); confirm.
                words = [x.text for x in tokens]
                if self._max_len is not None:
                    tokens = tokens[:self._max_len]
                    tags = None if tags is None else tags[:self._max_len]
                instance = self.text_to_instance(tokens, tags, words)
                if instance:
                    yield instance

    def extract_tags(self, tags: List[str]):
        """Split raw tag strings into labels, CORRECT/INCORRECT detect tags,
        and a dict counting how many tokens carry more than ``idx`` candidate
        operations (used by the ``skip_complex`` filter)."""
        op_del = self._delimeters['operations']

        # each tag may hold several candidate operations joined by ``op_del``
        labels = [x.split(op_del) for x in tags]

        comlex_flag_dict = {}
        # get flags: comlex_flag_dict[idx] = number of tokens with > idx candidates
        for i in range(5):
            idx = i + 1
            comlex_flag_dict[idx] = sum([len(x) > idx for x in labels])

        if self._tag_strategy == "keep_one":
            # get only first candidates for r_tags in right and the last for left
            labels = [x[0] for x in labels]
        elif self._tag_strategy == "merge_all":
            # consider phrases as a words
            pass
        else:
            raise Exception("Incorrect tag strategy")

        detect_tags = ["CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels]
        return labels, detect_tags, comlex_flag_dict

    def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
                         words: List[str] = None) -> Instance:  # type: ignore
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Returns ``None`` (instead of an ``Instance``) when the sentence is
        filtered out by the complexity or TN/TP subsampling rules.
        """
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}
        sequence = TextField(tokens, self._token_indexers)
        fields["tokens"] = sequence
        fields["metadata"] = MetadataField({"words": words})
        if tags is not None:
            labels, detect_tags, complex_flag_dict = self.extract_tags(tags)
            # drop sentences with too many multi-operation tokens
            if self._skip_complex and complex_flag_dict[self._skip_complex] > 0:
                return None
            rnd = random()
            # skip TN: fully-correct sentences are kept with probability _tn_prob
            if self._skip_correct and all(x == "CORRECT" for x in detect_tags):
                if rnd > self._tn_prob:
                    return None
            # skip TP: sentences with errors are kept with probability _tp_prob
            else:
                if rnd > self._tp_prob:
                    return None

            fields["labels"] = SequenceLabelField(labels, sequence,
                                                  label_namespace="labels")
            fields["d_tags"] = SequenceLabelField(detect_tags, sequence,
                                                  label_namespace="d_tags")
        return Instance(fields)
gector/gec_model.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wrapper of AllenNLP model. Fixes errors based on model predictions"""
2
+ import logging
3
+ import os
4
+ import sys
5
+ from time import time
6
+
7
+ import torch
8
+ from allennlp.data.dataset import Batch
9
+ from allennlp.data.fields import TextField
10
+ from allennlp.data.instance import Instance
11
+ from allennlp.data.tokenizers import Token
12
+ from allennlp.data.vocabulary import Vocabulary
13
+ from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
14
+ from allennlp.nn import util
15
+
16
+ from gector.bert_token_embedder import PretrainedBertEmbedder
17
+ from gector.seq2labels_model import Seq2Labels
18
+ from gector.tokenizer_indexer import PretrainedBertIndexer
19
+ from utils.helpers import PAD, UNK, get_target_sent_by_edits, START_TOKEN
20
+ from utils.helpers import get_weights_name
21
+
22
+ logging.getLogger("werkzeug").setLevel(logging.ERROR)
23
+ logger = logging.getLogger(__file__)
24
+
25
+
26
class GecBERTModel(object):
    """Wrapper around one or more AllenNLP ``Seq2Labels`` models.

    Iteratively predicts edit tags for batches of tokenised sentences and
    applies the suggested edits, repeating up to ``iterations`` times until no
    sentence changes. Supports ensembles: per-model ``weigths`` (sic) are used
    to average class probabilities in :meth:`_convert`.
    """

    def __init__(self, vocab_path=None, model_paths=None,
                 weigths=None,
                 max_len=50,
                 min_len=3,
                 lowercase_tokens=False,
                 log=False,
                 iterations=3,
                 model_name='roberta',
                 special_tokens_fix=1,
                 is_ensemble=True,
                 min_error_probability=0.0,
                 confidence=0,
                 del_confidence=0,
                 resolve_cycles=False,
                 ):
        # per-model ensemble weights; defaults to equal weights for all models
        self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = max_len
        self.min_len = min_len
        self.lowercase_tokens = lowercase_tokens
        self.min_error_probability = min_error_probability
        self.vocab = Vocabulary.from_files(vocab_path)
        self.log = log
        self.iterations = iterations
        self.confidence = confidence
        self.del_conf = del_confidence
        self.resolve_cycles = resolve_cycles
        # set training parameters and operations

        self.indexers = []
        self.models = []
        for model_path in model_paths:
            if is_ensemble:
                # in ensemble mode the transformer name and special-tokens flag
                # are encoded in the checkpoint filename (see _get_model_data)
                model_name, special_tokens_fix = self._get_model_data(model_path)
            weights_name = get_weights_name(model_name, lowercase_tokens)
            self.indexers.append(self._get_indexer(weights_name, special_tokens_fix))
            model = Seq2Labels(vocab=self.vocab,
                               text_field_embedder=self._get_embbeder(weights_name, special_tokens_fix),
                               confidence=self.confidence,
                               del_confidence=self.del_conf,
                               ).to(self.device)
            # strict=False lets checkpoints with extra/missing keys still load
            if torch.cuda.is_available():
                model.load_state_dict(torch.load(model_path), strict=False)
            else:
                model.load_state_dict(torch.load(model_path,
                                                 map_location=torch.device('cpu')),
                                      strict=False)
            model.eval()
            self.models.append(model)

    @staticmethod
    def _get_model_data(model_path):
        """Parse ``<transformer>_<special_tokens_fix>_...`` from the checkpoint filename."""
        model_name = model_path.split('/')[-1]
        tr_model, stf = model_name.split('_')[:2]
        return tr_model, int(stf)

    def _restore_model(self, input_path):
        # NOTE(review): this method references ``self.model``, which __init__
        # never assigns, and the bare ``except`` below can leave
        # ``loaded_model`` unbound — it looks like unused/legacy code; verify
        # before calling it.
        if os.path.isdir(input_path):
            print("Model could not be restored from directory", file=sys.stderr)
            filenames = []
        else:
            filenames = [input_path]
        for model_path in filenames:
            try:
                if torch.cuda.is_available():
                    loaded_model = torch.load(model_path)
                else:
                    loaded_model = torch.load(model_path,
                                              map_location=lambda storage,
                                                                  loc: storage)
            except:
                print(f"{model_path} is not valid model", file=sys.stderr)
            own_state = self.model.state_dict()
            for name, weights in loaded_model.items():
                if name not in own_state:
                    continue
                try:
                    if len(filenames) == 1:
                        # single checkpoint: copy weights in place
                        own_state[name].copy_(weights)
                    else:
                        # multiple checkpoints: accumulate (summed) weights
                        own_state[name] += weights
                except RuntimeError:
                    continue
        print("Model is restored", file=sys.stderr)

    def predict(self, batches):
        """Run each model on its batch and merge the ensemble predictions."""
        t11 = time()
        predictions = []
        for batch, model in zip(batches, self.models):
            # device index 0 = first GPU, -1 = CPU
            batch = util.move_to_device(batch.as_tensor_dict(), 0 if torch.cuda.is_available() else -1)
            with torch.no_grad():
                prediction = model.forward(**batch)
            predictions.append(prediction)

        preds, idx, error_probs = self._convert(predictions)
        t55 = time()
        if self.log:
            print(f"Inference time {t55 - t11}")
        return preds, idx, error_probs

    def get_token_action(self, token, index, prob, sugg_token):
        """Get list of suggested actions for token.

        Returns ``None`` when no edit is needed, otherwise a tuple
        ``(start_pos, end_pos, replacement_text, prob)`` with positions shifted
        by one to account for the $START token prepended in ``preprocess``.
        """
        # cases when we don't need to do anything
        if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
            return None

        if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
            # edit applies to the token itself
            start_pos = index
            end_pos = index + 1
        elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
            # zero-width edit after the token
            start_pos = index + 1
            end_pos = index + 1
        # NOTE(review): any other label would leave start_pos/end_pos unbound
        # (UnboundLocalError); presumably the label vocabulary only contains
        # the forms handled above — confirm.

        if sugg_token == "$DELETE":
            sugg_token_clear = ""
        elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
            # transform/merge labels are passed through whole
            sugg_token_clear = sugg_token[:]
        else:
            # $REPLACE_xxx / $APPEND_xxx: keep only the payload after '_'
            sugg_token_clear = sugg_token[sugg_token.index('_') + 1:]

        return start_pos - 1, end_pos - 1, sugg_token_clear, prob

    def _get_embbeder(self, weigths_name, special_tokens_fix):
        """Build the text-field embedder wrapping a pretrained transformer."""
        embedders = {'bert': PretrainedBertEmbedder(
            pretrained_model=weigths_name,
            requires_grad=False,
            top_layer_only=True,
            special_tokens_fix=special_tokens_fix)
        }
        text_field_embedder = BasicTextFieldEmbedder(
            token_embedders=embedders,
            embedder_to_indexer_map={"bert": ["bert", "bert-offsets"]},
            allow_unmatched_keys=True)
        return text_field_embedder

    def _get_indexer(self, weights_name, special_tokens_fix):
        """Build the token indexer dict matching the embedder above."""
        bert_token_indexer = PretrainedBertIndexer(
            pretrained_model=weights_name,
            do_lowercase=self.lowercase_tokens,
            max_pieces_per_token=5,
            special_tokens_fix=special_tokens_fix
        )
        return {'bert': bert_token_indexer}

    def preprocess(self, token_batch):
        """Turn raw token lists into one indexed AllenNLP ``Batch`` per model."""
        seq_lens = [len(sequence) for sequence in token_batch if sequence]
        if not seq_lens:
            return []
        max_len = min(max(seq_lens), self.max_len)
        batches = []
        # one Batch per model, each indexed with that model's own indexer
        for indexer in self.indexers:
            batch = []
            for sequence in token_batch:
                tokens = sequence[:max_len]
                # prepend the special $START token expected by the models
                tokens = [Token(token) for token in ['$START'] + tokens]
                batch.append(Instance({'tokens': TextField(tokens, indexer)}))
            batch = Batch(batch)
            batch.index_instances(self.vocab)
            batches.append(batch)

        return batches

    def _convert(self, data):
        """Weighted-average the per-model outputs into final probabilities."""
        all_class_probs = torch.zeros_like(data[0]['class_probabilities_labels'])
        error_probs = torch.zeros_like(data[0]['max_error_probability'])
        for output, weight in zip(data, self.model_weights):
            # normalize by the total weight so the result is a weighted mean
            all_class_probs += weight * output['class_probabilities_labels'] / sum(self.model_weights)
            error_probs += weight * output['max_error_probability'] / sum(self.model_weights)

        max_vals = torch.max(all_class_probs, dim=-1)
        probs = max_vals[0].tolist()
        idx = max_vals[1].tolist()
        return probs, idx, error_probs.tolist()

    def update_final_batch(self, final_batch, pred_ids, pred_batch,
                           prev_preds_dict):
        """Fold this iteration's predictions into ``final_batch``.

        Returns the updated batch, the ids that should be predicted again next
        iteration, and the number of sentences changed. A sentence whose new
        prediction was already seen before (a cycle) is updated once more but
        dropped from further iterations.
        """
        new_pred_ids = []
        total_updated = 0
        for i, orig_id in enumerate(pred_ids):
            orig = final_batch[orig_id]
            pred = pred_batch[i]
            prev_preds = prev_preds_dict[orig_id]
            if orig != pred and pred not in prev_preds:
                # genuinely new correction: keep iterating on this sentence
                final_batch[orig_id] = pred
                new_pred_ids.append(orig_id)
                prev_preds_dict[orig_id].append(pred)
                total_updated += 1
            elif orig != pred and pred in prev_preds:
                # update final batch, but stop iterations
                final_batch[orig_id] = pred
                total_updated += 1
            else:
                # unchanged: nothing to do, drop from next iteration
                continue
        return final_batch, new_pred_ids, total_updated

    def postprocess_batch(self, batch, all_probabilities, all_idxs,
                          error_probs):
        """Apply predicted edit tags to each sentence of ``batch``."""
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(batch,
                                                           all_probabilities,
                                                           all_idxs,
                                                           error_probs):
            length = min(len(tokens), self.max_len)
            edits = []

            # skip whole sentence if there are no errors
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # skip whole sentence if probability of correctness is not high
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            # length + 1 positions: index 0 is the $START token
            for i in range(length + 1):
                # because of START token
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]
                # skip if there is no error
                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(idxs[i],
                                                             namespace='labels')
                action = self.get_token_action(token, i, probabilities[i],
                                               sugg_token)
                if not action:
                    continue

                edits.append(action)
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results

    def handle_batch(self, full_batch):
        """
        Handle batch of requests.

        Runs up to ``self.iterations`` rounds of predict-and-edit over the
        batch, returning the corrected sentences and the total number of
        sentence updates performed.
        """
        final_batch = full_batch[:]
        batch_size = len(full_batch)
        # every sentence starts with itself as its only "previous prediction"
        prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
        # sentences shorter than min_len are never corrected
        short_ids = [i for i in range(len(full_batch))
                     if len(full_batch[i]) < self.min_len]
        pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
        total_updates = 0

        for n_iter in range(self.iterations):
            orig_batch = [final_batch[i] for i in pred_ids]

            sequences = self.preprocess(orig_batch)

            if not sequences:
                break
            probabilities, idxs, error_probs = self.predict(sequences)

            pred_batch = self.postprocess_batch(orig_batch, probabilities,
                                                idxs, error_probs)
            if self.log:
                print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")

            final_batch, pred_ids, cnt = \
                self.update_final_batch(final_batch, pred_ids, pred_batch,
                                        prev_preds_dict)
            total_updates += cnt

            # stop early once no sentence produced a new prediction
            if not pred_ids:
                break

        return final_batch, total_updates
gector/seq2labels_model.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Basic model. Predicts tags for every token"""
2
+ from typing import Dict, Optional, List, Any
3
+
4
+ import numpy
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from allennlp.data import Vocabulary
8
+ from allennlp.models.model import Model
9
+ from allennlp.modules import TimeDistributed, TextFieldEmbedder
10
+ from allennlp.nn import InitializerApplicator, RegularizerApplicator
11
+ from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
12
+ from allennlp.training.metrics import CategoricalAccuracy
13
+ from overrides import overrides
14
+ from torch.nn.modules.linear import Linear
15
+
16
+
17
@Model.register("seq2labels")
class Seq2Labels(Model):
    """
    This ``Seq2Labels`` simply encodes a sequence of text with a stacked ``Seq2SeqEncoder``, then
    predicts a tag (or couple tags) for each token in the sequence.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and predicting output tags.
    calculate_span_f1 : ``bool``, optional (default=``None``)
        Calculate span-level F1 metrics during training. If this is ``True``, then
        ``label_encoding`` is required. If ``None`` and
        label_encoding is specified, this is set to ``True``.
        If ``None`` and label_encoding is not specified, it defaults
        to ``False``.
    label_encoding : ``str``, optional (default=``None``)
        Label encoding to use when calculating span f1.
        Valid options are "BIO", "BIOUL", "IOB1", "BMES".
        Required if ``calculate_span_f1`` is true.
    labels_namespace : ``str``, optional (default=``labels``)
        This is needed to compute the SpanBasedF1Measure metric, if desired.
        Unless you did something unusual, the default value should be what you want.
    verbose_metrics : ``bool``, optional (default = False)
        If true, metrics will be returned per label class in addition
        to the overall statistics.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 predictor_dropout: float = 0.0,
                 labels_namespace: str = "labels",
                 detect_namespace: str = "d_tags",
                 verbose_metrics: bool = False,
                 label_smoothing: float = 0.0,
                 confidence: float = 0.0,
                 del_confidence: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(Seq2Labels, self).__init__(vocab, regularizer)

        # Two tag namespaces: correction labels and binary error-detection tags.
        self.label_namespaces = [labels_namespace,
                                 detect_namespace]
        self.text_field_embedder = text_field_embedder
        self.num_labels_classes = self.vocab.get_vocab_size(labels_namespace)
        self.num_detect_classes = self.vocab.get_vocab_size(detect_namespace)
        self.label_smoothing = label_smoothing
        # Inference-time probability biases added in `forward` (see
        # `probability_change` there).
        self.confidence = confidence
        self.del_conf = del_confidence
        # Index of the INCORRECT detection tag; used to derive a
        # sentence-level error probability.
        self.incorr_index = self.vocab.get_token_index("INCORRECT",
                                                       namespace=detect_namespace)

        self._verbose_metrics = verbose_metrics
        self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout))

        # Linear projections from the 'bert' embedder's output dim onto each
        # tag space, applied per time step.
        self.tag_labels_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_labels_classes))

        self.tag_detect_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_detect_classes))

        self.metrics = {"accuracy": CategoricalAccuracy()}

        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                labels: torch.LongTensor = None,
                d_tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        labels : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        d_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        encoded_text = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = encoded_text.size()
        mask = get_text_field_mask(tokens)
        # Only the label head gets dropout; the detection head sees the raw
        # encoder output.
        logits_labels = self.tag_labels_projection_layer(self.predictor_dropout(encoded_text))
        logits_d = self.tag_detect_projection_layer(encoded_text)

        class_probabilities_labels = F.softmax(logits_labels, dim=-1).view(
            [batch_size, sequence_length, self.num_labels_classes])
        class_probabilities_d = F.softmax(logits_d, dim=-1).view(
            [batch_size, sequence_length, self.num_detect_classes])
        # Per-token probability of the INCORRECT tag (padding masked out);
        # max over tokens gives a sentence-level error probability.
        error_probs = class_probabilities_d[:, :, self.incorr_index] * mask
        incorr_prob = torch.max(error_probs, dim=-1)[0]

        # Bias the probabilities of the first two label classes by the
        # configured confidences.
        # NOTE(review): this assumes labels-vocab indices 0 and 1 are the
        # classes these biases target — confirm against the vocabulary built
        # at training time.
        probability_change = [self.confidence, self.del_conf] + [0] * (self.num_labels_classes - 2)
        class_probabilities_labels += torch.FloatTensor(probability_change).repeat(
            (batch_size, sequence_length, 1)).to(class_probabilities_labels.device)

        output_dict = {"logits_labels": logits_labels,
                       "logits_d_tags": logits_d,
                       "class_probabilities_labels": class_probabilities_labels,
                       "class_probabilities_d_tags": class_probabilities_d,
                       "max_error_probability": incorr_prob}
        if labels is not None and d_tags is not None:
            loss_labels = sequence_cross_entropy_with_logits(logits_labels, labels, mask,
                                                             label_smoothing=self.label_smoothing)
            loss_d = sequence_cross_entropy_with_logits(logits_d, d_tags, mask)
            # The single accuracy metric accumulates over both heads.
            for metric in self.metrics.values():
                metric(logits_labels, labels, mask.float())
                metric(logits_d, d_tags, mask.float())
            output_dict["loss"] = loss_labels + loss_d

        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple position-wise argmax over each token, converts indices to string labels, and
        adds a ``"tags"`` key to the dictionary with the result.
        """
        for label_namespace in self.label_namespaces:
            all_predictions = output_dict[f'class_probabilities_{label_namespace}']
            all_predictions = all_predictions.cpu().data.numpy()
            # 3-D tensor => batched predictions; otherwise a single sequence.
            if all_predictions.ndim == 3:
                predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
            else:
                predictions_list = [all_predictions]
            all_tags = []

            for predictions in predictions_list:
                argmax_indices = numpy.argmax(predictions, axis=-1)
                tags = [self.vocab.get_token_from_index(x, namespace=label_namespace)
                        for x in argmax_indices]
                all_tags.append(tags)
            output_dict[f'{label_namespace}'] = all_tags
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return all tracked metrics, optionally resetting their state."""
        metrics_to_return = {metric_name: metric.get_metric(reset) for
                             metric_name, metric in self.metrics.items()}
        return metrics_to_return
gector/tokenization.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from time import time
3
+
4
+
5
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6
+
7
+
8
def get_bpe_groups(token_offsets, bpe_offsets, input_ids, max_bpe_pieces=5):
    """Group wordpiece indices under the word-level token that produced them.

    A wordpiece belongs to a token when its character span lies inside the
    token's span. At most ``max_bpe_pieces`` pieces are kept per token; the
    rest are dropped from ``saved_ids``.

    Returns ``(bpe_groups, saved_ids)`` where ``bpe_groups[t]`` lists the
    wordpiece indices for token ``t`` and ``saved_ids`` are the surviving
    positions into ``input_ids``.
    """
    # A (0, 0) entry marks padding: everything from there on is ignored.
    try:
        bpe_size = bpe_offsets.index((0, 0))
    except ValueError:
        bpe_size = len(bpe_offsets)

    bpe_groups = []
    redundant_ids = []
    cursor = 0  # wordpieces are consumed left to right, never revisited
    for start_token, end_token in token_offsets:
        group = []
        found = False
        for i in range(cursor, bpe_size):
            start_bpe, end_bpe = bpe_offsets[i]
            if start_token <= start_bpe and end_bpe <= end_token:
                if len(group) < max_bpe_pieces:
                    group.append(i)
                else:
                    # over the per-token limit: mark the piece for removal
                    redundant_ids.append(i)
                cursor = i + 1
                found = True
            elif found:
                # past the current token's span: stop scanning early
                break
        bpe_groups.append(group)

    dropped = set(redundant_ids)
    saved_ids = [i for i in range(len(input_ids)) if i not in dropped]
    return bpe_groups, saved_ids
41
+
42
+
43
def reduce_input_ids(input_ids, bpe_groups, saved_ids,
                     max_bpe_length=80, max_bpe_pieces=5):
    """Shrink a wordpiece sequence until it fits ``max_bpe_length``.

    Repeatedly lowers the per-token piece budget and truncates oversized
    groups (``bpe_groups`` is modified in place) until the number of kept
    pieces is within the limit.

    Returns ``(reduced_ids, correct_offsets)``: the surviving input ids and,
    for each token, the position of its first piece within ``reduced_ids``.
    """
    while len(saved_ids) > max_bpe_length:
        max_bpe_pieces -= 1
        for token_id, group in enumerate(bpe_groups):
            if len(group) > max_bpe_pieces:
                dropped = group[max_bpe_pieces:]
                bpe_groups[token_id] = group[:max_bpe_pieces]
                saved_ids = [i for i in saved_ids if i not in dropped]

    reduced_ids = [input_ids[i] for i in saved_ids]

    # First-piece offset of every token, clamped into the reduced range.
    correct_offsets = []
    position = 0
    last = len(reduced_ids) - 1
    for group in bpe_groups:
        correct_offsets.append(min(position, last))
        position += len(group)

    return reduced_ids, correct_offsets
64
+
65
+
66
def get_offsets_and_reduce_input_ids(tokenizer_output, token_offset_list,
                                     index_name="bert", max_bpe_length=80,
                                     max_bpe_pieces=5):
    """Align wordpiece ids with word-level tokens and enforce length limits.

    For every sentence, groups the tokenizer's wordpieces under the tokens
    that produced them (via character offsets), trims the sequence to the
    ``max_bpe_length`` / ``max_bpe_pieces`` budgets, and builds a mask with
    one entry per word-level token.

    Fix: removed the dead timing instrumentation — the ``timings`` dict was
    accumulated on every call but its reporting was commented out, so it was
    pure overhead.

    Parameters
    ----------
    tokenizer_output : mapping with ``'input_ids'`` and ``'offset_mapping'``
        lists, one entry per sentence (as returned by
        ``batch_encode_plus(..., return_offsets_mapping=True)``).
    token_offset_list : per-sentence lists of (start, end) character offsets
        of the word-level tokens.
    index_name : key prefix for the returned dictionary.

    Returns
    -------
    dict with keys ``index_name``, ``f"{index_name}-offsets"`` and ``"mask"``.
    """
    output_ids, output_offsets, output_masks = [], [], []
    for i, token_offsets in enumerate(token_offset_list):
        input_ids = tokenizer_output['input_ids'][i]
        bpe_offsets = tokenizer_output['offset_mapping'][i]

        # group wordpieces under the token that produced them
        bpe_groups, saved_ids = get_bpe_groups(token_offsets, bpe_offsets,
                                               input_ids,
                                               max_bpe_pieces=max_bpe_pieces)

        # drop pieces until the sequence satisfies max_bpe_length
        reduced_ids, correct_offsets = reduce_input_ids(
            input_ids, bpe_groups, saved_ids,
            max_bpe_length=max_bpe_length,
            max_bpe_pieces=max_bpe_pieces)

        output_ids.append(reduced_ids)
        output_offsets.append(correct_offsets)
        # one mask entry per word-level token
        output_masks.append([1 for _ in correct_offsets])

    return {index_name: output_ids,
            f"{index_name}-offsets": output_offsets,
            "mask": output_masks}
109
+
110
+
111
def get_offset_for_tokens(tokens):
    """Return (start, end) character offsets of each token within
    ``" ".join(tokens)``."""
    sentence = " ".join(tokens)
    offsets = []
    search_from = 0
    for token in tokens:
        # locate the token at or after the end of the previous match
        start = sentence.index(token, search_from)
        search_from = start + len(token)
        offsets.append((start, search_from))
    return offsets
121
+
122
+
123
def get_token_offsets(batch):
    """Compute token-level character offsets for every sentence in the batch."""
    return [get_offset_for_tokens(tokens) for tokens in batch]
129
+
130
+
131
def pad_output(output, pad_idx=0):
    """Right-pad every index list in ``output`` to its key's maximum length.

    Each value of ``output`` is a batch (list of lists); shorter rows are
    extended with ``pad_idx`` so all rows under a key have equal length.
    """
    padded_output = {}
    for input_key, rows in output.items():
        max_len = max(len(row) for row in rows)
        padded_output[input_key] = [row + [pad_idx] * (max_len - len(row))
                                    for row in rows]
    return padded_output
143
+
144
+
145
def tokenize_batch(tokenizer, batch_tokens, index_name="bert",
                   max_bpe_length=80, max_bpe_pieces=5):
    """Tokenize a batch of pre-split sentences into padded indexer output.

    Joins each token list into a sentence, runs the (fast) huggingface
    tokenizer with offset mapping, aligns wordpieces back to the original
    tokens under the length budgets, and pads everything to equal length.

    Fix: removed the dead timing instrumentation — the ``timings`` dict
    (including the misspelled ``"pading_time"`` key) was filled on every call
    while its reporting was commented out; deleted along with the
    commented-out debug code.

    Returns the dict produced by ``get_offsets_and_reduce_input_ids`` with
    every list padded via ``pad_output``.
    """
    # get batch with sentences
    batch_sentences = [" ".join(x) for x in batch_tokens]
    # get token level offsets
    token_offset_list = get_token_offsets(batch_tokens)

    # tokenize batch; no special tokens so offsets line up with the raw text
    tokenizer_output = tokenizer.batch_encode_plus(batch_sentences,
                                                   pad_to_max_length=False,
                                                   return_offsets_mapping=True,
                                                   add_special_tokens=False)

    # postprocess batch: align wordpieces with tokens and trim to budget
    output = get_offsets_and_reduce_input_ids(tokenizer_output,
                                              token_offset_list,
                                              index_name=index_name,
                                              max_bpe_length=max_bpe_length,
                                              max_bpe_pieces=max_bpe_pieces)

    # pad output
    return pad_output(output)
gector/tokenizer_indexer.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tweaked version of corresponding AllenNLP file"""
2
+ import logging
3
+ from collections import defaultdict
4
+ from typing import Dict, List, Callable
5
+
6
+ from allennlp.common.util import pad_sequence_to_length
7
+ from allennlp.data.token_indexers.token_indexer import TokenIndexer
8
+ from allennlp.data.tokenizers.token import Token
9
+ from allennlp.data.vocabulary import Vocabulary
10
+ from overrides import overrides
11
+ from transformers import AutoTokenizer
12
+
13
+ from utils.helpers import START_TOKEN
14
+
15
+ from gector.tokenization import tokenize_batch
16
+ import copy
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ # TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.
22
+
23
+
24
class TokenizerIndexer(TokenIndexer[int]):
    """
    A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings).
    If you are using one of the pretrained BERT models, you'll want to use the
    ``PretrainedBertIndexer`` subclass rather than this base class.

    Parameters
    ----------
    tokenizer : ``Callable[[str], List[str]]``
        A function that does the actual tokenization.
    max_pieces : int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Any inputs longer than this will
        either be truncated (default), or be split apart and batched using a
        sliding window.
    token_min_padding_length : ``int``, optional (default=``0``)
        See :class:`TokenIndexer`.
    """

    def __init__(self,
                 tokenizer: Callable[[str], List[str]],
                 max_pieces: int = 512,
                 max_pieces_per_token: int = 3,
                 token_min_padding_length: int = 0) -> None:
        super().__init__(token_min_padding_length)

        # BERT tokenization is two-step (sentence -> words -> wordpieces);
        # the word splitting happens elsewhere, this indexer handles the
        # wordpiece step.
        self.tokenizer = tokenizer
        self.max_pieces_per_token = max_pieces_per_token
        self.max_pieces = max_pieces
        self.max_pieces_per_sentence = 80

    @overrides
    def tokens_to_indices(self, tokens: List[Token],
                          vocabulary: Vocabulary,
                          index_name: str) -> Dict[str, List[int]]:
        texts = [token.text for token in tokens]
        # tokenize_batch operates on batches; wrap and unwrap a single sentence
        batched = tokenize_batch(self.tokenizer,
                                 [texts],
                                 max_bpe_length=self.max_pieces,
                                 max_bpe_pieces=self.max_pieces_per_token)
        return {key: rows[0] for key, rows in batched.items()}

    @overrides
    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        # Pretrained vocabularies are fixed, so there is nothing to count.
        pass

    @overrides
    def get_padding_token(self) -> int:
        return 0

    @overrides
    def get_padding_lengths(self, token: int) -> Dict[str, int]:  # pylint: disable=unused-argument
        return {}

    @overrides
    def pad_token_sequence(self,
                           tokens: Dict[str, List[int]],
                           desired_num_tokens: Dict[str, int],
                           padding_lengths: Dict[str, int]) -> Dict[str, List[int]]:  # pylint: disable=unused-argument
        padded = {}
        for key, val in tokens.items():
            padded[key] = pad_sequence_to_length(val, desired_num_tokens[key])
        return padded

    @overrides
    def get_keys(self, index_name: str) -> List[str]:
        """
        We need to override this because the indexer generates multiple keys.
        """
        # pylint: disable=no-self-use
        return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"]
102
+
103
+
104
class PretrainedBertIndexer(TokenizerIndexer):
    # pylint: disable=line-too-long
    """
    A ``TokenIndexer`` corresponding to a pretrained BERT model.

    Parameters
    ----------
    pretrained_model: ``str``
        Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
        or the path to the .txt file with its vocabulary.
        If the name is a key in the list of pretrained models at
        https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py#L33
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    do_lowercase: ``bool``, optional (default = True)
        Whether to lowercase the tokens before converting to wordpiece ids.
    max_pieces: int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Any inputs longer than this will
        either be truncated (default), or be split apart and batched using a
        sliding window.
    """

    def __init__(self,
                 pretrained_model: str,
                 do_lowercase: bool = True,
                 max_pieces: int = 512,
                 max_pieces_per_token: int = 5,
                 special_tokens_fix: int = 0) -> None:

        # Heuristic sanity check: warn when the casing implied by the model
        # name disagrees with the requested lowercasing behaviour.
        if pretrained_model.endswith("-cased") and do_lowercase:
            logger.warning("Your BERT model appears to be cased, "
                           "but your indexer is lowercasing tokens.")
        elif pretrained_model.endswith("-uncased") and not do_lowercase:
            logger.warning("Your BERT model appears to be uncased, "
                           "but your indexer is not lowercasing tokens.")

        # Fix: dropped the pointless ``copy.deepcopy`` of the model name —
        # strings are immutable, so the copy was a no-op.
        model_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model, do_lower_case=do_lowercase, do_basic_tokenize=False, use_fast=True)

        # Normalize the ``vocab`` attribute across tokenizer implementations:
        # BPE-style tokenizers keep the mapping in ``encoder``; sentencepiece
        # tokenizers expose it through ``sp_model``.
        if hasattr(model_tokenizer, 'encoder'):
            model_tokenizer.vocab = model_tokenizer.encoder
        if hasattr(model_tokenizer, 'sp_model'):
            # unknown pieces default to id 1
            model_tokenizer.vocab = defaultdict(lambda: 1)
            for i in range(model_tokenizer.sp_model.get_piece_size()):
                model_tokenizer.vocab[model_tokenizer.sp_model.id_to_piece(i)] = i

        if special_tokens_fix:
            # Register the START token and map it to the newly added last id.
            model_tokenizer.add_tokens([START_TOKEN])
            model_tokenizer.vocab[START_TOKEN] = len(model_tokenizer) - 1

        super().__init__(tokenizer=model_tokenizer,
                         max_pieces=max_pieces,
                         max_pieces_per_token=max_pieces_per_token)
161
+
gector/trainer.py ADDED
@@ -0,0 +1,845 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tweaked version of corresponding AllenNLP file"""
2
+ import datetime
3
+ import logging
4
+ import math
5
+ import os
6
+ import time
7
+ import traceback
8
+ from typing import Dict, Optional, List, Tuple, Union, Iterable, Any
9
+
10
+ import torch
11
+ import torch.optim.lr_scheduler
12
+ from allennlp.common import Params
13
+ from allennlp.common.checks import ConfigurationError, parse_cuda_device
14
+ from allennlp.common.tqdm import Tqdm
15
+ from allennlp.common.util import dump_metrics, gpu_memory_mb, peak_memory_mb, lazy_groups_of
16
+ from allennlp.data.instance import Instance
17
+ from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
18
+ from allennlp.models.model import Model
19
+ from allennlp.nn import util as nn_util
20
+ from allennlp.training import util as training_util
21
+ from allennlp.training.checkpointer import Checkpointer
22
+ from allennlp.training.learning_rate_schedulers import LearningRateScheduler
23
+ from allennlp.training.metric_tracker import MetricTracker
24
+ from allennlp.training.momentum_schedulers import MomentumScheduler
25
+ from allennlp.training.moving_average import MovingAverage
26
+ from allennlp.training.optimizers import Optimizer
27
+ from allennlp.training.tensorboard_writer import TensorboardWriter
28
+ from allennlp.training.trainer_base import TrainerBase
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ class Trainer(TrainerBase):
34
+ def __init__(
35
+ self,
36
+ model: Model,
37
+ optimizer: torch.optim.Optimizer,
38
+ scheduler: torch.optim.lr_scheduler,
39
+ iterator: DataIterator,
40
+ train_dataset: Iterable[Instance],
41
+ validation_dataset: Optional[Iterable[Instance]] = None,
42
+ patience: Optional[int] = None,
43
+ validation_metric: str = "-loss",
44
+ validation_iterator: DataIterator = None,
45
+ shuffle: bool = True,
46
+ num_epochs: int = 20,
47
+ accumulated_batch_count: int = 1,
48
+ serialization_dir: Optional[str] = None,
49
+ num_serialized_models_to_keep: int = 20,
50
+ keep_serialized_model_every_num_seconds: int = None,
51
+ checkpointer: Checkpointer = None,
52
+ model_save_interval: float = None,
53
+ cuda_device: Union[int, List] = -1,
54
+ grad_norm: Optional[float] = None,
55
+ grad_clipping: Optional[float] = None,
56
+ learning_rate_scheduler: Optional[LearningRateScheduler] = None,
57
+ momentum_scheduler: Optional[MomentumScheduler] = None,
58
+ summary_interval: int = 100,
59
+ histogram_interval: int = None,
60
+ should_log_parameter_statistics: bool = True,
61
+ should_log_learning_rate: bool = False,
62
+ log_batch_size_period: Optional[int] = None,
63
+ moving_average: Optional[MovingAverage] = None,
64
+ cold_step_count: int = 0,
65
+ cold_lr: float = 1e-3,
66
+ cuda_verbose_step=None,
67
+ ) -> None:
68
+ """
69
+ A trainer for doing supervised learning. It just takes a labeled dataset
70
+ and a ``DataIterator``, and uses the supplied ``Optimizer`` to learn the weights
71
+ for your model over some fixed number of epochs. You can also pass in a validation
72
+ dataset and enable early stopping. There are many other bells and whistles as well.
73
+
74
+ Parameters
75
+ ----------
76
+ model : ``Model``, required.
77
+ An AllenNLP model to be optimized. Pytorch Modules can also be optimized if
78
+ their ``forward`` method returns a dictionary with a "loss" key, containing a
79
+ scalar tensor representing the loss function to be optimized.
80
+
81
+ If you are training your model using GPUs, your model should already be
82
+ on the correct device. (If you use `Trainer.from_params` this will be
83
+ handled for you.)
84
+ optimizer : ``torch.nn.Optimizer``, required.
85
+ An instance of a Pytorch Optimizer, instantiated with the parameters of the
86
+ model to be optimized.
87
+ iterator : ``DataIterator``, required.
88
+ A method for iterating over a ``Dataset``, yielding padded indexed batches.
89
+ train_dataset : ``Dataset``, required.
90
+ A ``Dataset`` to train on. The dataset should have already been indexed.
91
+ validation_dataset : ``Dataset``, optional, (default = None).
92
+ A ``Dataset`` to evaluate on. The dataset should have already been indexed.
93
+ patience : Optional[int] > 0, optional (default=None)
94
+ Number of epochs to be patient before early stopping: the training is stopped
95
+ after ``patience`` epochs with no improvement. If given, it must be ``> 0``.
96
+ If None, early stopping is disabled.
97
+ validation_metric : str, optional (default="loss")
98
+ Validation metric to measure for whether to stop training using patience
99
+ and whether to serialize an ``is_best`` model each epoch. The metric name
100
+ must be prepended with either "+" or "-", which specifies whether the metric
101
+ is an increasing or decreasing function.
102
+ validation_iterator : ``DataIterator``, optional (default=None)
103
+ An iterator to use for the validation set. If ``None``, then
104
+ use the training `iterator`.
105
+ shuffle: ``bool``, optional (default=True)
106
+ Whether to shuffle the instances in the iterator or not.
107
+ num_epochs : int, optional (default = 20)
108
+ Number of training epochs.
109
+ serialization_dir : str, optional (default=None)
110
+ Path to directory for saving and loading model files. Models will not be saved if
111
+ this parameter is not passed.
112
+ num_serialized_models_to_keep : ``int``, optional (default=20)
113
+ Number of previous model checkpoints to retain. Default is to keep 20 checkpoints.
114
+ A value of None or -1 means all checkpoints will be kept.
115
+ keep_serialized_model_every_num_seconds : ``int``, optional (default=None)
116
+ If num_serialized_models_to_keep is not None, then occasionally it's useful to
117
+ save models at a given interval in addition to the last num_serialized_models_to_keep.
118
+ To do so, specify keep_serialized_model_every_num_seconds as the number of seconds
119
+ between permanently saved checkpoints. Note that this option is only used if
120
+ num_serialized_models_to_keep is not None, otherwise all checkpoints are kept.
121
+ checkpointer : ``Checkpointer``, optional (default=None)
122
+ An instance of class Checkpointer to use instead of the default. If a checkpointer is specified,
123
+ the arguments num_serialized_models_to_keep and keep_serialized_model_every_num_seconds should
124
+ not be specified. The caller is responsible for initializing the checkpointer so that it is
125
+ consistent with serialization_dir.
126
+ model_save_interval : ``float``, optional (default=None)
127
+ If provided, then serialize models every ``model_save_interval``
128
+ seconds within single epochs. In all cases, models are also saved
129
+ at the end of every epoch if ``serialization_dir`` is provided.
130
+ cuda_device : ``Union[int, List[int]]``, optional (default = -1)
131
+ An integer or list of integers specifying the CUDA device(s) to use. If -1, the CPU is used.
132
+ grad_norm : ``float``, optional, (default = None).
133
+ If provided, gradient norms will be rescaled to have a maximum of this value.
134
+ grad_clipping : ``float``, optional (default = ``None``).
135
+ If provided, gradients will be clipped `during the backward pass` to have an (absolute)
136
+ maximum of this value. If you are getting ``NaNs`` in your gradients during training
137
+ that are not solved by using ``grad_norm``, you may need this.
138
+ learning_rate_scheduler : ``LearningRateScheduler``, optional (default = None)
139
+ If specified, the learning rate will be decayed with respect to
140
+ this schedule at the end of each epoch (or batch, if the scheduler implements
141
+ the ``step_batch`` method). If you use :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`,
142
+ this will use the ``validation_metric`` provided to determine if learning has plateaued.
143
+ To support updating the learning rate on every batch, this can optionally implement
144
+ ``step_batch(batch_num_total)`` which updates the learning rate given the batch number.
145
+ momentum_scheduler : ``MomentumScheduler``, optional (default = None)
146
+ If specified, the momentum will be updated at the end of each batch or epoch
147
+ according to the schedule.
148
+ summary_interval: ``int``, optional, (default = 100)
149
+ Number of batches between logging scalars to tensorboard
150
+ histogram_interval : ``int``, optional, (default = ``None``)
151
+ If not None, then log histograms to tensorboard every ``histogram_interval`` batches.
152
+ When this parameter is specified, the following additional logging is enabled:
153
+ * Histograms of model parameters
154
+ * The ratio of parameter update norm to parameter norm
155
+ * Histogram of layer activations
156
+ We log histograms of the parameters returned by
157
+ ``model.get_parameters_for_histogram_tensorboard_logging``.
158
+ The layer activations are logged for any modules in the ``Model`` that have
159
+ the attribute ``should_log_activations`` set to ``True``. Logging
160
+ histograms requires a number of GPU-CPU copies during training and is typically
161
+ slow, so we recommend logging histograms relatively infrequently.
162
+ Note: only Modules that return tensors, tuples of tensors or dicts
163
+ with tensors as values currently support activation logging.
164
+ should_log_parameter_statistics : ``bool``, optional, (default = True)
165
+ Whether to send parameter statistics (mean and standard deviation
166
+ of parameters and gradients) to tensorboard.
167
+ should_log_learning_rate : ``bool``, optional, (default = False)
168
+ Whether to send parameter specific learning rate to tensorboard.
169
+ log_batch_size_period : ``int``, optional, (default = ``None``)
170
+ If defined, how often to log the average batch size.
171
+ moving_average: ``MovingAverage``, optional, (default = None)
172
+ If provided, we will maintain moving averages for all parameters. During training, we
173
+ employ a shadow variable for each parameter, which maintains the moving average. During
174
+ evaluation, we backup the original parameters and assign the moving averages to corresponding
175
+ parameters. Be careful that when saving the checkpoint, we will save the moving averages of
176
+ parameters. This is necessary because we want the saved model to perform as well as the validated
177
+ model if we load it later. But this may cause problems if you restart the training from checkpoint.
178
+ """
179
+ super().__init__(serialization_dir, cuda_device)
180
+
181
+ # I am not calling move_to_gpu here, because if the model is
182
+ # not already on the GPU then the optimizer is going to be wrong.
183
+ self.model = model
184
+
185
+ self.iterator = iterator
186
+ self._validation_iterator = validation_iterator
187
+ self.shuffle = shuffle
188
+ self.optimizer = optimizer
189
+ self.scheduler = scheduler
190
+ self.train_data = train_dataset
191
+ self._validation_data = validation_dataset
192
+ self.accumulated_batch_count = accumulated_batch_count
193
+ self.cold_step_count = cold_step_count
194
+ self.cold_lr = cold_lr
195
+ self.cuda_verbose_step = cuda_verbose_step
196
+
197
+ if patience is None: # no early stopping
198
+ if validation_dataset:
199
+ logger.warning(
200
+ "You provided a validation dataset but patience was set to None, "
201
+ "meaning that early stopping is disabled"
202
+ )
203
+ elif (not isinstance(patience, int)) or patience <= 0:
204
+ raise ConfigurationError(
205
+ '{} is an invalid value for "patience": it must be a positive integer '
206
+ "or None (if you want to disable early stopping)".format(patience)
207
+ )
208
+
209
+ # For tracking is_best_so_far and should_stop_early
210
+ self._metric_tracker = MetricTracker(patience, validation_metric)
211
+ # Get rid of + or -
212
+ self._validation_metric = validation_metric[1:]
213
+
214
+ self._num_epochs = num_epochs
215
+
216
+ if checkpointer is not None:
217
+ # We can't easily check if these parameters were passed in, so check against their default values.
218
+ # We don't check against serialization_dir since it is also used by the parent class.
219
+ if num_serialized_models_to_keep != 20 \
220
+ or keep_serialized_model_every_num_seconds is not None:
221
+ raise ConfigurationError(
222
+ "When passing a custom Checkpointer, you may not also pass in separate checkpointer "
223
+ "args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'."
224
+ )
225
+ self._checkpointer = checkpointer
226
+ else:
227
+ self._checkpointer = Checkpointer(
228
+ serialization_dir,
229
+ keep_serialized_model_every_num_seconds,
230
+ num_serialized_models_to_keep,
231
+ )
232
+
233
+ self._model_save_interval = model_save_interval
234
+
235
+ self._grad_norm = grad_norm
236
+ self._grad_clipping = grad_clipping
237
+
238
+ self._learning_rate_scheduler = learning_rate_scheduler
239
+ self._momentum_scheduler = momentum_scheduler
240
+ self._moving_average = moving_average
241
+
242
+ # We keep the total batch number as an instance variable because it
243
+ # is used inside a closure for the hook which logs activations in
244
+ # ``_enable_activation_logging``.
245
+ self._batch_num_total = 0
246
+
247
+ self._tensorboard = TensorboardWriter(
248
+ get_batch_num_total=lambda: self._batch_num_total,
249
+ serialization_dir=serialization_dir,
250
+ summary_interval=summary_interval,
251
+ histogram_interval=histogram_interval,
252
+ should_log_parameter_statistics=should_log_parameter_statistics,
253
+ should_log_learning_rate=should_log_learning_rate,
254
+ )
255
+
256
+ self._log_batch_size_period = log_batch_size_period
257
+
258
+ self._last_log = 0.0 # time of last logging
259
+
260
+ # Enable activation logging.
261
+ if histogram_interval is not None:
262
+ self._tensorboard.enable_activation_logging(self.model)
263
+
264
+ def rescale_gradients(self) -> Optional[float]:
265
+ return training_util.rescale_gradients(self.model, self._grad_norm)
266
+
267
+ def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor:
268
+ """
269
+ Does a forward pass on the given batches and returns the ``loss`` value in the result.
270
+ If ``for_training`` is `True` also applies regularization penalty.
271
+ """
272
+ if self._multiple_gpu:
273
+ output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices)
274
+ else:
275
+ assert len(batch_group) == 1
276
+ batch = batch_group[0]
277
+ batch = nn_util.move_to_device(batch, self._cuda_devices[0])
278
+ output_dict = self.model(**batch)
279
+
280
+ try:
281
+ loss = output_dict["loss"]
282
+ if for_training:
283
+ loss += self.model.get_regularization_penalty()
284
+ except KeyError:
285
+ if for_training:
286
+ raise RuntimeError(
287
+ "The model you are trying to optimize does not contain a"
288
+ " 'loss' key in the output of model.forward(inputs)."
289
+ )
290
+ loss = None
291
+
292
+ return loss
293
+
294
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics.

        Runs gradient accumulation over ``self.accumulated_batch_count`` batches: the
        loss of each batch is divided by the accumulation window size (``iter_len``)
        and the optimizer steps only once per window (or at the end of the epoch).
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
        peak_cpu_usage = peak_memory_mb()
        logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
        gpu_usage = []
        for gpu, memory in gpu_memory_mb().items():
            gpu_usage.append((gpu, memory))
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        train_loss = 0.0
        # Set the model to "train" mode.
        self.model.train()

        num_gpus = len(self._cuda_devices)

        # Get tqdm for the training batches
        raw_train_generator = self.iterator(self.train_data, num_epochs=1, shuffle=self.shuffle)
        train_generator = lazy_groups_of(raw_train_generator, num_gpus)
        num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data) / num_gpus)
        # Batches left over after the last full accumulation window; the final,
        # shorter window divides its losses by this residue instead.
        residue = num_training_batches % self.accumulated_batch_count
        self._last_log = time.time()
        last_save_time = time.time()

        batches_this_epoch = 0
        if self._batch_num_total is None:
            self._batch_num_total = 0

        histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())

        logger.info("Training")
        train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_training_batches)
        cumulative_batch_size = 0
        self.optimizer.zero_grad()
        for batch_group in train_generator_tqdm:
            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            # Size of the current accumulation window: full-sized until the residue
            # batches at the end of the epoch.
            iter_len = self.accumulated_batch_count \
                if batches_this_epoch <= (num_training_batches - residue) else residue

            # Optional CUDA memory tracing every ``cuda_verbose_step`` batches.
            if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
                print(f'Before forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
                print(f'Before forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
            try:
                # Divide by the window size so the accumulated gradient equals the
                # mean gradient over the window.
                loss = self.batch_loss(batch_group, for_training=True) / iter_len
            except RuntimeError as e:
                # Dump batch statistics (typically to diagnose CUDA OOMs) and re-raise.
                print(e)
                for x in batch_group:
                    all_words = [len(y['words']) for y in x['metadata']]
                    print(f"Total sents: {len(all_words)}. "
                          f"Min {min(all_words)}. Max {max(all_words)}")
                    for elem in ['labels', 'd_tags']:
                        tt = x[elem]
                        print(
                            f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}")
                    for elem in ["bert", "mask", "bert-offsets"]:
                        tt = x['tokens'][elem]
                        print(
                            f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}")
                raise e

            if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
                print(f'After forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
                print(f'After forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')

            if torch.isnan(loss):
                raise ValueError("nan loss encountered")

            loss.backward()

            if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
                print(f'After backprop - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
                print(f'After backprop - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')

            # Undo the window-size division when accumulating the reported loss.
            train_loss += loss.item() * iter_len

            # NOTE(review): ``batch_group`` is deleted here but referenced again below
            # in the ``_log_batch_size_period`` block — enabling that option would
            # raise a NameError. Confirm before turning batch-size logging on.
            del batch_group, loss
            torch.cuda.empty_cache()

            if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
                print(f'After collecting garbage - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
                print(f'After collecting garbage - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')

            batch_grad_norm = self.rescale_gradients()

            # This does nothing if batch_num_total is None or you are using a
            # scheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)
            if self._momentum_scheduler:
                self._momentum_scheduler.step_batch(batch_num_total)

            if self._tensorboard.should_log_histograms_this_batch():
                # get the magnitude of parameter updates for logging
                # We need a copy of current parameters to compute magnitude of updates,
                # and copy them to CPU so large models won't go OOM on the GPU.
                param_updates = {
                    name: param.detach().cpu().clone()
                    for name, param in self.model.named_parameters()
                }
                # Step only at accumulation-window boundaries (or the final batch).
                if batches_this_epoch % self.accumulated_batch_count == 0 or \
                        batches_this_epoch == num_training_batches:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                for name, param in self.model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
                    update_norm = torch.norm(param_updates[name].view(-1))
                    param_norm = torch.norm(param.view(-1)).cpu()
                    self._tensorboard.add_train_scalar(
                        "gradient_update/" + name, update_norm / (param_norm + 1e-7)
                    )
            else:
                # Same stepping rule as above, without the histogram bookkeeping.
                if batches_this_epoch % self.accumulated_batch_count == 0 or \
                        batches_this_epoch == num_training_batches:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

            # Update moving averages
            if self._moving_average is not None:
                self._moving_average.apply(batch_num_total)

            # Update the description with the latest metrics
            metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch)
            description = training_util.description_from_metrics(metrics)

            train_generator_tqdm.set_description(description, refresh=False)

            # Log parameter values to Tensorboard
            if self._tensorboard.should_log_this_batch():
                self._tensorboard.log_parameter_and_gradient_statistics(self.model, batch_grad_norm)
                self._tensorboard.log_learning_rates(self.model, self.optimizer)

                self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"])
                self._tensorboard.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()})

            if self._tensorboard.should_log_histograms_this_batch():
                self._tensorboard.log_histograms(self.model, histogram_parameters)

            if self._log_batch_size_period:
                cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
                cumulative_batch_size += cur_batch
                if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                    average = cumulative_batch_size / batches_this_epoch
                    logger.info(f"current batch size: {cur_batch} mean batch size: {average}")
                    self._tensorboard.add_train_scalar("current_batch_size", cur_batch)
                    self._tensorboard.add_train_scalar("mean_batch_size", average)

            # Save model if needed.
            if self._model_save_interval is not None and (
                time.time() - last_save_time > self._model_save_interval
            ):
                last_save_time = time.time()
                self._save_checkpoint(
                    "{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time)))
                )

        # Final epoch metrics (reset=True clears the model's accumulated metric state),
        # annotated with the memory snapshots captured at the start of the epoch.
        metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch, reset=True)
        metrics["cpu_memory_MB"] = peak_cpu_usage
        for (gpu_num, memory) in gpu_usage:
            metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
        return metrics
460
+ def _validation_loss(self) -> Tuple[float, int]:
461
+ """
462
+ Computes the validation loss. Returns it and the number of batches.
463
+ """
464
+ logger.info("Validating")
465
+
466
+ self.model.eval()
467
+
468
+ # Replace parameter values with the shadow values from the moving averages.
469
+ if self._moving_average is not None:
470
+ self._moving_average.assign_average_value()
471
+
472
+ if self._validation_iterator is not None:
473
+ val_iterator = self._validation_iterator
474
+ else:
475
+ val_iterator = self.iterator
476
+
477
+ num_gpus = len(self._cuda_devices)
478
+
479
+ raw_val_generator = val_iterator(self._validation_data, num_epochs=1, shuffle=False)
480
+ val_generator = lazy_groups_of(raw_val_generator, num_gpus)
481
+ num_validation_batches = math.ceil(
482
+ val_iterator.get_num_batches(self._validation_data) / num_gpus
483
+ )
484
+ val_generator_tqdm = Tqdm.tqdm(val_generator, total=num_validation_batches)
485
+ batches_this_epoch = 0
486
+ val_loss = 0
487
+ for batch_group in val_generator_tqdm:
488
+
489
+ loss = self.batch_loss(batch_group, for_training=False)
490
+ if loss is not None:
491
+ # You shouldn't necessarily have to compute a loss for validation, so we allow for
492
+ # `loss` to be None. We need to be careful, though - `batches_this_epoch` is
493
+ # currently only used as the divisor for the loss function, so we can safely only
494
+ # count those batches for which we actually have a loss. If this variable ever
495
+ # gets used for something else, we might need to change things around a bit.
496
+ batches_this_epoch += 1
497
+ val_loss += loss.detach().cpu().numpy()
498
+
499
+ # Update the description with the latest metrics
500
+ val_metrics = training_util.get_metrics(self.model, val_loss, batches_this_epoch)
501
+ description = training_util.description_from_metrics(val_metrics)
502
+ val_generator_tqdm.set_description(description, refresh=False)
503
+
504
+ # Now restore the original parameter values.
505
+ if self._moving_average is not None:
506
+ self._moving_average.restore()
507
+
508
+ return val_loss, batches_this_epoch
509
+
510
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.

        Resumes from the latest checkpoint if one exists, optionally runs
        ``cold_step_count`` initial epochs with the BERT embedder frozen at
        ``cold_lr``, validates after each epoch, and restores the best
        checkpointed weights before returning the collected metrics.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint. Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?"
            )

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: float = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        # Cold-start phase: freeze the BERT embedder and train the rest at cold_lr.
        # ``base_lr`` is remembered so it can be restored when the cold phase ends.
        if self.cold_step_count > 0:
            base_lr = self.optimizer.param_groups[0]['lr']
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.cold_lr
            self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=True)

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        for epoch in range(epoch_counter, self._num_epochs):
            # End of the cold phase: restore base_lr and unfreeze BERT.
            # NOTE(review): ``base_lr`` only exists when cold_step_count > 0; this
            # branch is unreachable otherwise (epoch != 0 guard), but that coupling
            # is implicit — confirm when changing the cold-start logic.
            if epoch == self.cold_step_count and epoch != 0:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = base_lr
                self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=False)

            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if "cpu_memory_MB" in train_metrics:
                metrics["peak_cpu_memory_MB"] = max(
                    metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"]
                )
            for key, value in train_metrics.items():
                if key.startswith("gpu_"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

            # clear cache before validation
            torch.cuda.empty_cache()
            if self._validation_data is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(
                        self.model, val_loss, num_batches, reset=True
                    )

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience. Stopping training.")
                        break

            self._tensorboard.log_metrics(
                train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
            )  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            # if self.cold_step_count <= epoch:
            # NOTE(review): this assumes validation data exists — without it
            # 'validation_loss' is absent and this raises KeyError. Confirm that
            # training without a validation set is unsupported here.
            self.scheduler.step(metrics['validation_loss'])

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir:
                dump_metrics(
                    os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
                )

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric, epoch)

            self._save_checkpoint(epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1
                )
                formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s", formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        # self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
646
+ def _save_checkpoint(self, epoch: Union[int, str]) -> None:
647
+ """
648
+ Saves a checkpoint of the model to self._serialization_dir.
649
+ Is a no-op if self._serialization_dir is None.
650
+
651
+ Parameters
652
+ ----------
653
+ epoch : Union[int, str], required.
654
+ The epoch of training. If the checkpoint is saved in the middle
655
+ of an epoch, the parameter is a string with the epoch and timestamp.
656
+ """
657
+ # If moving averages are used for parameters, we save
658
+ # the moving average values into checkpoint, instead of the current values.
659
+ if self._moving_average is not None:
660
+ self._moving_average.assign_average_value()
661
+
662
+ # These are the training states we need to persist.
663
+ training_states = {
664
+ "metric_tracker": self._metric_tracker.state_dict(),
665
+ "optimizer": self.optimizer.state_dict(),
666
+ "batch_num_total": self._batch_num_total,
667
+ }
668
+
669
+ # If we have a learning rate or momentum scheduler, we should persist them too.
670
+ if self._learning_rate_scheduler is not None:
671
+ training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict()
672
+ if self._momentum_scheduler is not None:
673
+ training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict()
674
+
675
+ self._checkpointer.save_checkpoint(
676
+ model_state=self.model.state_dict(),
677
+ epoch=epoch,
678
+ training_states=training_states,
679
+ is_best_so_far=self._metric_tracker.is_best_so_far(),
680
+ )
681
+
682
+ # Restore the original values for parameters so that training will not be affected.
683
+ if self._moving_average is not None:
684
+ self._moving_average.restore()
685
+
686
+ def _restore_checkpoint(self) -> int:
687
+ """
688
+ Restores the model and training state from the last saved checkpoint.
689
+ This includes an epoch count and optimizer state, which is serialized separately
690
+ from model parameters. This function should only be used to continue training -
691
+ if you wish to load a model for inference/load parts of a model into a new
692
+ computation graph, you should use the native Pytorch functions:
693
+ `` model.load_state_dict(torch.load("/path/to/model/weights.th"))``
694
+
695
+ If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights,
696
+ this function will do nothing and return 0.
697
+
698
+ Returns
699
+ -------
700
+ epoch: int
701
+ The epoch at which to resume training, which should be one after the epoch
702
+ in the saved training state.
703
+ """
704
+ model_state, training_state = self._checkpointer.restore_checkpoint()
705
+
706
+ if not training_state:
707
+ # No checkpoint to restore, start at 0
708
+ return 0
709
+
710
+ self.model.load_state_dict(model_state)
711
+ self.optimizer.load_state_dict(training_state["optimizer"])
712
+ if self._learning_rate_scheduler is not None \
713
+ and "learning_rate_scheduler" in training_state:
714
+ self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"])
715
+ if self._momentum_scheduler is not None and "momentum_scheduler" in training_state:
716
+ self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"])
717
+ training_util.move_optimizer_to_cuda(self.optimizer)
718
+
719
+ # Currently the ``training_state`` contains a serialized ``MetricTracker``.
720
+ if "metric_tracker" in training_state:
721
+ self._metric_tracker.load_state_dict(training_state["metric_tracker"])
722
+ # It used to be the case that we tracked ``val_metric_per_epoch``.
723
+ elif "val_metric_per_epoch" in training_state:
724
+ self._metric_tracker.clear()
725
+ self._metric_tracker.add_metrics(training_state["val_metric_per_epoch"])
726
+ # And before that we didn't track anything.
727
+ else:
728
+ self._metric_tracker.clear()
729
+
730
+ if isinstance(training_state["epoch"], int):
731
+ epoch_to_return = training_state["epoch"] + 1
732
+ else:
733
+ epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1
734
+
735
+ # For older checkpoints with batch_num_total missing, default to old behavior where
736
+ # it is unchanged.
737
+ batch_num_total = training_state.get("batch_num_total")
738
+ if batch_num_total is not None:
739
+ self._batch_num_total = batch_num_total
740
+
741
+ return epoch_to_return
742
+
743
    # Requires custom from_params.
    @classmethod
    def from_params(  # type: ignore
        cls,
        model: Model,
        serialization_dir: str,
        iterator: DataIterator,
        train_data: Iterable[Instance],
        validation_data: Optional[Iterable[Instance]],
        params: Params,
        validation_iterator: DataIterator = None,
    ) -> "Trainer":
        """
        Construct a Trainer from a ``Params`` configuration object.

        Pops every recognised key out of ``params`` (the pop order matters: the
        ``checkpointer``-conflict check below relies on the checkpointer-related
        keys not having been popped yet) and asserts that nothing unrecognised
        remains before building the instance.
        """
        patience = params.pop_int("patience", None)
        validation_metric = params.pop("validation_metric", "-loss")
        shuffle = params.pop_bool("shuffle", True)
        num_epochs = params.pop_int("num_epochs", 20)
        cuda_device = parse_cuda_device(params.pop("cuda_device", -1))
        grad_norm = params.pop_float("grad_norm", None)
        grad_clipping = params.pop_float("grad_clipping", None)
        lr_scheduler_params = params.pop("learning_rate_scheduler", None)
        momentum_scheduler_params = params.pop("momentum_scheduler", None)

        # With multiple devices, the model lives on the first one.
        if isinstance(cuda_device, list):
            model_device = cuda_device[0]
        else:
            model_device = cuda_device
        if model_device >= 0:
            # Moving model to GPU here so that the optimizer state gets constructed on
            # the right device.
            model = model.cuda(model_device)

        parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
        optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))
        if "moving_average" in params:
            moving_average = MovingAverage.from_params(
                params.pop("moving_average"), parameters=parameters
            )
        else:
            moving_average = None

        if lr_scheduler_params:
            lr_scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params)
        else:
            lr_scheduler = None
        if momentum_scheduler_params:
            momentum_scheduler = MomentumScheduler.from_params(optimizer, momentum_scheduler_params)
        else:
            momentum_scheduler = None

        # A custom checkpointer and the individual checkpointing knobs are
        # mutually exclusive ways to configure checkpointing.
        if "checkpointer" in params:
            if "keep_serialized_model_every_num_seconds" in params \
                    or "num_serialized_models_to_keep" in params:
                raise ConfigurationError(
                    "Checkpointer may be initialized either from the 'checkpointer' key or from the "
                    "keys 'num_serialized_models_to_keep' and 'keep_serialized_model_every_num_seconds'"
                    " but the passed config uses both methods."
                )
            checkpointer = Checkpointer.from_params(params.pop("checkpointer"))
        else:
            num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20)
            keep_serialized_model_every_num_seconds = params.pop_int(
                "keep_serialized_model_every_num_seconds", None
            )
            checkpointer = Checkpointer(
                serialization_dir=serialization_dir,
                num_serialized_models_to_keep=num_serialized_models_to_keep,
                keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds,
            )
        model_save_interval = params.pop_float("model_save_interval", None)
        summary_interval = params.pop_int("summary_interval", 100)
        histogram_interval = params.pop_int("histogram_interval", None)
        should_log_parameter_statistics = params.pop_bool("should_log_parameter_statistics", True)
        should_log_learning_rate = params.pop_bool("should_log_learning_rate", False)
        log_batch_size_period = params.pop_int("log_batch_size_period", None)

        # Fail loudly on any config keys this Trainer does not understand.
        params.assert_empty(cls.__name__)
        return cls(
            model,
            optimizer,
            iterator,
            train_data,
            validation_data,
            patience=patience,
            validation_metric=validation_metric,
            validation_iterator=validation_iterator,
            shuffle=shuffle,
            num_epochs=num_epochs,
            serialization_dir=serialization_dir,
            cuda_device=cuda_device,
            grad_norm=grad_norm,
            grad_clipping=grad_clipping,
            learning_rate_scheduler=lr_scheduler,
            momentum_scheduler=momentum_scheduler,
            checkpointer=checkpointer,
            model_save_interval=model_save_interval,
            summary_interval=summary_interval,
            histogram_interval=histogram_interval,
            should_log_parameter_statistics=should_log_parameter_statistics,
            should_log_learning_rate=should_log_learning_rate,
            log_batch_size_period=log_batch_size_period,
            moving_average=moving_average,
        )