Upload 7 files
Browse files- gector/bert_token_embedder.py +269 -0
- gector/datareader.py +151 -0
- gector/gec_model.py +298 -0
- gector/seq2labels_model.py +194 -0
- gector/tokenization.py +181 -0
- gector/tokenizer_indexer.py +161 -0
- gector/trainer.py +845 -0
gector/bert_token_embedder.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tweaked version of corresponding AllenNLP file"""
|
| 2 |
+
import logging
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from typing import Dict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
|
| 9 |
+
from allennlp.nn import util
|
| 10 |
+
from transformers import AutoModel, PreTrainedModel
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PretrainedBertModel:
    """
    Caching factory for pretrained transformer models.

    In some instances you may want to load the same BERT model twice
    (e.g. to use as a token embedder and also as a pooling layer).
    This factory provides a cache so that you don't actually have to
    load the model twice.
    """

    # Maps model name -> already-loaded model instance.
    _cache: Dict[str, PreTrainedModel] = {}

    @classmethod
    def load(cls, model_name: str, cache_model: bool = True) -> PreTrainedModel:
        """Return a pretrained model for ``model_name``, reusing a cached copy if present."""
        cached = cls._cache.get(model_name)
        if cached is not None:
            return cached

        model = AutoModel.from_pretrained(model_name)
        if cache_model:
            cls._cache[model_name] = model

        return model
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class BertEmbedder(TokenEmbedder):
    """
    A ``TokenEmbedder`` that produces BERT embeddings for your tokens.
    Should be paired with a ``BertIndexer``, which produces wordpiece ids.
    Most likely you probably want to use ``PretrainedBertEmbedder``
    for one of the named pretrained models, not this base class.
    Parameters
    ----------
    bert_model: ``BertModel``
        The BERT model being wrapped.
    top_layer_only: ``bool``, optional (default = ``False``)
        If ``True``, then only return the top layer instead of apply the scalar mix.
    max_pieces : int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Assuming the inputs are windowed
        and padded appropriately by this length, the embedder will split them into a
        large batch, feed them into BERT, and recombine the output as if it was a
        longer sequence.
    num_start_tokens : int, optional (default: 1)
        The number of starting special tokens input to BERT (usually 1, i.e., [CLS])
    num_end_tokens : int, optional (default: 1)
        The number of ending tokens input to BERT (usually 1, i.e., [SEP])
    scalar_mix_parameters: ``List[float]``, optional, (default = None)
        If not ``None``, use these scalar mix parameters to weight the representations
        produced by different layers. These mixing weights are not updated during
        training.
    """

    def __init__(
        self,
        bert_model: PreTrainedModel,
        top_layer_only: bool = False,
        max_pieces: int = 512,
        num_start_tokens: int = 1,
        num_end_tokens: int = 1
    ) -> None:
        super().__init__()
        # Deep copy so later in-place changes (freezing, resizing embeddings)
        # don't mutate the shared cached model in PretrainedBertModel._cache.
        self.bert_model = deepcopy(bert_model)
        self.output_dim = bert_model.config.hidden_size
        self.max_pieces = max_pieces
        self.num_start_tokens = num_start_tokens
        self.num_end_tokens = num_end_tokens
        # NOTE(review): ``top_layer_only`` is accepted but never stored or used;
        # ``_scalar_mix`` stays None here, so ``forward`` always takes the last
        # stacked layer. Confirm whether scalar mixing was meant to be wired up.
        self._scalar_mix = None

    def set_weights(self, freeze):
        # Freeze (or unfreeze) every transformer parameter in one pass.
        for param in self.bert_model.parameters():
            param.requires_grad = not freeze
        return

    def get_output_dim(self) -> int:
        # Hidden size of the wrapped transformer.
        return self.output_dim

    def forward(
        self,
        input_ids: torch.LongTensor,
        offsets: torch.LongTensor = None
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.
            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        """

        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive an input tensor that has a sequence length longer than can
        # be fit. In that case, we should expect the wordpiece indexer to create padded windows
        # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
        # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
        # We can then split the sequence into sub-sequences of that length, and concatenate them
        # along the batch dimension so we effectively have one huge batch of partial sentences.
        # This can then be fed into BERT without any sentence length issues. Keep in mind
        # that the memory consumption can dramatically increase for large batches with extremely
        # long sentences.
        needs_split = full_seq_len > self.max_pieces
        last_window_size = 0
        if needs_split:
            # Split the flattened list by the window size, `max_pieces`
            split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))

            # We want all sequences to be the same length, so pad the last sequence
            last_window_size = split_input_ids[-1].size(-1)
            padding_amount = self.max_pieces - last_window_size
            split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0)

            # Now combine the sequences along the batch dimension
            input_ids = torch.cat(split_input_ids, dim=0)

        # Attention mask: padding id is assumed to be 0 here.
        input_mask = (input_ids != 0).long()
        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            attention_mask=util.combine_initial_dims(input_mask),
        )[0]
        # Normalise the transformers output to a (layers, batch, seq, dim) stack:
        # a list/tuple of 3-d layer tensors gets stacked; a single 2-d output is
        # given a leading layer dimension.
        if len(all_encoder_layers[0].shape) == 3:
            all_encoder_layers = torch.stack(all_encoder_layers)
        elif len(all_encoder_layers[0].shape) == 2:
            all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0)

        if needs_split:
            # First, unpack the output embeddings into one long sequence again
            unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
            unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

            # Next, select indices of the sequence such that it will result in embeddings representing the original
            # sentence. To capture maximal context, the indices will be the middle part of each embedded window
            # sub-sequence (plus any leftover start and final edge windows), e.g.,
            #  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
            # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
            # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
            # and final windows with indices [0, 1] and [14, 15] respectively.

            # Find the stride as half the max pieces, ignoring the special start and end tokens
            # Calculate an offset to extract the centermost embeddings of each window
            stride = (self.max_pieces - self.num_start_tokens - self.num_end_tokens) // 2
            stride_offset = stride // 2 + self.num_start_tokens

            first_window = list(range(stride_offset))

            max_context_windows = [
                i
                for i in range(full_seq_len)
                if stride_offset - 1 < i % self.max_pieces < stride_offset + stride
            ]

            # Lookback what's left, unless it's the whole self.max_pieces window
            if full_seq_len % self.max_pieces == 0:
                lookback = self.max_pieces
            else:
                lookback = full_seq_len % self.max_pieces

            final_window_start = full_seq_len - lookback + stride_offset + stride
            final_window = list(range(final_window_start, full_seq_len))

            select_indices = first_window + max_context_windows + final_window

            initial_dims.append(len(select_indices))

            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers

        # Recombine the outputs of all layers
        # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        # recombined = torch.cat(combined, dim=2)
        # NOTE(review): this mask is recomputed from embedding *values* (not ids),
        # so a position counts as masked only where the embedding is exactly 0 —
        # confirm this is intended; it differs from the id-based mask above.
        input_mask = (recombined_embeddings != 0).long()

        if self._scalar_mix is not None:
            mix = self._scalar_mix(recombined_embeddings, input_mask)
        else:
            # _scalar_mix is always None in this tweaked version (see __init__),
            # so this takes the last layer of the stacked output.
            mix = recombined_embeddings[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            dims = initial_dims if needs_split else input_ids.size()
            return util.uncombine_initial_dims(mix, dims)
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            range_vector = util.get_range_vector(
                offsets2d.size(0), device=util.get_device_of(mix)
            ).unsqueeze(1)
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
            selected_embeddings = mix[range_vector, offsets2d]

            return util.uncombine_initial_dims(selected_embeddings, offsets.size())
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# @TokenEmbedder.register("bert-pretrained")
class PretrainedBertEmbedder(BertEmbedder):

    """
    A ``BertEmbedder`` built from a named pretrained transformer.

    Parameters
    ----------
    pretrained_model: ``str``
        Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
        or the path to the .tar.gz file with the model weights.
        If the name is a key in the list of pretrained models at
        https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L41
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    requires_grad : ``bool``, optional (default = False)
        If True, compute gradient of BERT parameters for fine tuning.
    top_layer_only: ``bool``, optional (default = ``False``)
        If ``True``, then only return the top layer instead of apply the scalar mix.
    special_tokens_fix: ``int``, optional (default = 0)
        If non-zero, grow the wordpiece embedding matrix by one extra row so an
        additional special token id can be embedded.
    """

    def __init__(
        self,
        pretrained_model: str,
        requires_grad: bool = False,
        top_layer_only: bool = False,
        special_tokens_fix: int = 0,
    ) -> None:
        model = PretrainedBertModel.load(pretrained_model)

        # Freeze or unfreeze the entire transformer up front.
        for param in model.parameters():
            param.requires_grad = requires_grad

        super().__init__(bert_model=model, top_layer_only=top_layer_only)

        if special_tokens_fix:
            try:
                # BERT-style models expose the matrix under ``embeddings.word_embeddings``.
                vocab_size = self.bert_model.embeddings.word_embeddings.num_embeddings
            except AttributeError:
                # Other architectures (e.g. XLNet) name it ``word_embedding``;
                # reserve more space.
                vocab_size = self.bert_model.word_embedding.num_embeddings + 5
            self.bert_model.resize_token_embeddings(vocab_size + 1)
gector/datareader.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tweaked AllenNLP dataset reader."""
|
| 2 |
+
import logging
|
| 3 |
+
import re
|
| 4 |
+
from random import random
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
|
| 7 |
+
from allennlp.common.file_utils import cached_path
|
| 8 |
+
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
|
| 9 |
+
from allennlp.data.fields import TextField, SequenceLabelField, MetadataField, Field
|
| 10 |
+
from allennlp.data.instance import Instance
|
| 11 |
+
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
|
| 12 |
+
from allennlp.data.tokenizers import Token
|
| 13 |
+
from overrides import overrides
|
| 14 |
+
|
| 15 |
+
from utils.helpers import SEQ_DELIMETERS, START_TOKEN
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@DatasetReader.register("seq2labels_datareader")
class Seq2LabelsDatasetReader(DatasetReader):
    """
    Reads instances from a pretokenised file where each line is in the following format:

    WORD###TAG [TAB] WORD###TAG [TAB] ..... \n

    and converts it into a ``Dataset`` suitable for sequence tagging. You can also specify
    alternative delimiters in the constructor.

    Parameters
    ----------
    delimiters: ``dict``
        The dictionary with all delimiters.
    token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
        We use this to define the input representation for the text. See :class:`TokenIndexer`.
        Note that the `output` tags will always correspond to single token IDs based on how they
        are pre-tokenised in the data file.
    max_len: if set than will truncate long sentences
    """
    # fix broken sentences mostly in Lang8: a period immediately followed by a
    # letter (except S, which would match abbreviations like ".S") marks a
    # likely mis-segmented line.
    BROKEN_SENTENCES_REGEXP = re.compile(r'\.[a-zA-RT-Z]')

    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 delimeters: dict = SEQ_DELIMETERS,
                 skip_correct: bool = False,
                 skip_complex: int = 0,
                 lazy: bool = False,
                 max_len: int = None,
                 test_mode: bool = False,
                 tag_strategy: str = "keep_one",
                 tn_prob: float = 0,
                 tp_prob: float = 0,
                 broken_dot_strategy: str = "keep") -> None:
        super().__init__(lazy)
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
        self._delimeters = delimeters
        self._max_len = max_len
        # Sampling/filtering knobs: tn_prob/tp_prob are keep-probabilities for
        # fully-correct / error-containing sentences respectively (see
        # text_to_instance).
        self._skip_correct = skip_correct
        self._skip_complex = skip_complex
        self._tag_strategy = tag_strategy
        self._broken_dot_strategy = broken_dot_strategy
        self._test_mode = test_mode
        self._tn_prob = tn_prob
        self._tp_prob = tp_prob

    @overrides
    def _read(self, file_path):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)
        with open(file_path, "r") as data_file:
            logger.info("Reading instances from lines in file at: %s", file_path)
            for line in data_file:
                line = line.strip("\n")
                # skip blank and broken lines (broken-line filtering only
                # applies during training, never in test mode)
                if not line or (not self._test_mode and self._broken_dot_strategy == 'skip'
                                and self.BROKEN_SENTENCES_REGEXP.search(line) is not None):
                    continue

                # Each token carries its tag(s): "word###tag"; rsplit keeps any
                # '###' occurring inside the word itself intact.
                tokens_and_tags = [pair.rsplit(self._delimeters['labels'], 1)
                                   for pair in line.split(self._delimeters['tokens'])]
                try:
                    tokens = [Token(token) for token, tag in tokens_and_tags]
                    tags = [tag for token, tag in tokens_and_tags]
                except ValueError:
                    # Untagged input: each pair is a one-element list.
                    tokens = [Token(token[0]) for token in tokens_and_tags]
                    tags = None

                # Ensure every sequence starts with the $START sentinel token.
                # NOTE(review): relies on allennlp Token equality comparing
                # token text — confirm for the pinned allennlp version.
                if tokens and tokens[0] != Token(START_TOKEN):
                    tokens = [Token(START_TOKEN)] + tokens

                # Full (untruncated) surface words are kept as metadata;
                # presumably used downstream for reconstructing/evaluating the
                # original sentence — truncation below affects only model input.
                words = [x.text for x in tokens]
                if self._max_len is not None:
                    tokens = tokens[:self._max_len]
                    tags = None if tags is None else tags[:self._max_len]
                instance = self.text_to_instance(tokens, tags, words)
                if instance:
                    yield instance

    def extract_tags(self, tags: List[str]):
        """Split raw tag strings into labels, binary detect tags, and complexity counts."""
        op_del = self._delimeters['operations']

        # Each tag string may hold several operations joined by op_del.
        labels = [x.split(op_del) for x in tags]

        comlex_flag_dict = {}
        # get flags: comlex_flag_dict[k] counts tokens carrying more than k
        # operations, used by the skip_complex filter in text_to_instance.
        for i in range(5):
            idx = i + 1
            comlex_flag_dict[idx] = sum([len(x) > idx for x in labels])

        if self._tag_strategy == "keep_one":
            # get only first candidates for r_tags in right and the last for left
            labels = [x[0] for x in labels]
        elif self._tag_strategy == "merge_all":
            # consider phrases as a words
            pass
        else:
            raise Exception("Incorrect tag strategy")

        # Binary error-detection labels parallel to `labels`.
        detect_tags = ["CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels]
        return labels, detect_tags, comlex_flag_dict

    def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
                         words: List[str] = None) -> Instance:  # type: ignore
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Returns None (instance is dropped) when complexity/sampling filters
        reject the sentence.
        """
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}
        sequence = TextField(tokens, self._token_indexers)
        fields["tokens"] = sequence
        fields["metadata"] = MetadataField({"words": words})
        if tags is not None:
            labels, detect_tags, complex_flag_dict = self.extract_tags(tags)
            if self._skip_complex and complex_flag_dict[self._skip_complex] > 0:
                return None
            rnd = random()
            # skip TN: fully-correct sentences are kept with prob tn_prob
            if self._skip_correct and all(x == "CORRECT" for x in detect_tags):
                if rnd > self._tn_prob:
                    return None
            # skip TP: sentences with errors are kept with prob tp_prob.
            # NOTE(review): with skip_correct=False every sentence goes through
            # this branch, so tp_prob must be set (e.g. 1) by the caller or
            # everything is dropped — confirm against the training config.
            else:
                if rnd > self._tp_prob:
                    return None

            fields["labels"] = SequenceLabelField(labels, sequence,
                                                  label_namespace="labels")
            fields["d_tags"] = SequenceLabelField(detect_tags, sequence,
                                                  label_namespace="d_tags")
        return Instance(fields)
gector/gec_model.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Wrapper of AllenNLP model. Fixes errors based on model predictions"""
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from time import time
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from allennlp.data.dataset import Batch
|
| 9 |
+
from allennlp.data.fields import TextField
|
| 10 |
+
from allennlp.data.instance import Instance
|
| 11 |
+
from allennlp.data.tokenizers import Token
|
| 12 |
+
from allennlp.data.vocabulary import Vocabulary
|
| 13 |
+
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
|
| 14 |
+
from allennlp.nn import util
|
| 15 |
+
|
| 16 |
+
from gector.bert_token_embedder import PretrainedBertEmbedder
|
| 17 |
+
from gector.seq2labels_model import Seq2Labels
|
| 18 |
+
from gector.tokenizer_indexer import PretrainedBertIndexer
|
| 19 |
+
from utils.helpers import PAD, UNK, get_target_sent_by_edits, START_TOKEN
|
| 20 |
+
from utils.helpers import get_weights_name
|
| 21 |
+
|
| 22 |
+
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
| 23 |
+
logger = logging.getLogger(__file__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class GecBERTModel(object):
|
| 27 |
+
    def __init__(self, vocab_path=None, model_paths=None,
                 weigths=None,
                 max_len=50,
                 min_len=3,
                 lowercase_tokens=False,
                 log=False,
                 iterations=3,
                 model_name='roberta',
                 special_tokens_fix=1,
                 is_ensemble=True,
                 min_error_probability=0.0,
                 confidence=0,
                 del_confidence=0,
                 resolve_cycles=False,
                 ):
        """Build the (possibly ensembled) GEC tagger.

        ``weigths`` (sic — misspelled but part of the public interface) holds
        per-model ensemble weights; defaults to uniform weights of 1.
        """
        self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = max_len
        self.min_len = min_len
        self.lowercase_tokens = lowercase_tokens
        self.min_error_probability = min_error_probability
        self.vocab = Vocabulary.from_files(vocab_path)
        self.log = log
        self.iterations = iterations
        self.confidence = confidence
        self.del_conf = del_confidence
        self.resolve_cycles = resolve_cycles
        # set training parameters and operations

        self.indexers = []
        self.models = []
        for model_path in model_paths:
            if is_ensemble:
                # In ensemble mode, each model's transformer name and
                # special-tokens-fix flag are encoded in its file name.
                model_name, special_tokens_fix = self._get_model_data(model_path)
            weights_name = get_weights_name(model_name, lowercase_tokens)
            self.indexers.append(self._get_indexer(weights_name, special_tokens_fix))
            model = Seq2Labels(vocab=self.vocab,
                               text_field_embedder=self._get_embbeder(weights_name, special_tokens_fix),
                               confidence=self.confidence,
                               del_confidence=self.del_conf,
                               ).to(self.device)
            # strict=False tolerates checkpoints with missing/extra keys.
            if torch.cuda.is_available():
                model.load_state_dict(torch.load(model_path), strict=False)
            else:
                model.load_state_dict(torch.load(model_path,
                                                 map_location=torch.device('cpu')),
                                      strict=False)
            model.eval()
            self.models.append(model)
+
|
| 77 |
+
@staticmethod
|
| 78 |
+
def _get_model_data(model_path):
|
| 79 |
+
model_name = model_path.split('/')[-1]
|
| 80 |
+
tr_model, stf = model_name.split('_')[:2]
|
| 81 |
+
return tr_model, int(stf)
|
| 82 |
+
|
| 83 |
+
def _restore_model(self, input_path):
|
| 84 |
+
if os.path.isdir(input_path):
|
| 85 |
+
print("Model could not be restored from directory", file=sys.stderr)
|
| 86 |
+
filenames = []
|
| 87 |
+
else:
|
| 88 |
+
filenames = [input_path]
|
| 89 |
+
for model_path in filenames:
|
| 90 |
+
try:
|
| 91 |
+
if torch.cuda.is_available():
|
| 92 |
+
loaded_model = torch.load(model_path)
|
| 93 |
+
else:
|
| 94 |
+
loaded_model = torch.load(model_path,
|
| 95 |
+
map_location=lambda storage,
|
| 96 |
+
loc: storage)
|
| 97 |
+
except:
|
| 98 |
+
print(f"{model_path} is not valid model", file=sys.stderr)
|
| 99 |
+
own_state = self.model.state_dict()
|
| 100 |
+
for name, weights in loaded_model.items():
|
| 101 |
+
if name not in own_state:
|
| 102 |
+
continue
|
| 103 |
+
try:
|
| 104 |
+
if len(filenames) == 1:
|
| 105 |
+
own_state[name].copy_(weights)
|
| 106 |
+
else:
|
| 107 |
+
own_state[name] += weights
|
| 108 |
+
except RuntimeError:
|
| 109 |
+
continue
|
| 110 |
+
print("Model is restored", file=sys.stderr)
|
| 111 |
+
|
| 112 |
+
def predict(self, batches):
|
| 113 |
+
t11 = time()
|
| 114 |
+
predictions = []
|
| 115 |
+
for batch, model in zip(batches, self.models):
|
| 116 |
+
batch = util.move_to_device(batch.as_tensor_dict(), 0 if torch.cuda.is_available() else -1)
|
| 117 |
+
with torch.no_grad():
|
| 118 |
+
prediction = model.forward(**batch)
|
| 119 |
+
predictions.append(prediction)
|
| 120 |
+
|
| 121 |
+
preds, idx, error_probs = self._convert(predictions)
|
| 122 |
+
t55 = time()
|
| 123 |
+
if self.log:
|
| 124 |
+
print(f"Inference time {t55 - t11}")
|
| 125 |
+
return preds, idx, error_probs
|
| 126 |
+
|
| 127 |
+
def get_token_action(self, token, index, prob, sugg_token):
|
| 128 |
+
"""Get lost of suggested actions for token."""
|
| 129 |
+
# cases when we don't need to do anything
|
| 130 |
+
if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
|
| 131 |
+
return None
|
| 132 |
+
|
| 133 |
+
if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
|
| 134 |
+
start_pos = index
|
| 135 |
+
end_pos = index + 1
|
| 136 |
+
elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
|
| 137 |
+
start_pos = index + 1
|
| 138 |
+
end_pos = index + 1
|
| 139 |
+
|
| 140 |
+
if sugg_token == "$DELETE":
|
| 141 |
+
sugg_token_clear = ""
|
| 142 |
+
elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
|
| 143 |
+
sugg_token_clear = sugg_token[:]
|
| 144 |
+
else:
|
| 145 |
+
sugg_token_clear = sugg_token[sugg_token.index('_') + 1:]
|
| 146 |
+
|
| 147 |
+
return start_pos - 1, end_pos - 1, sugg_token_clear, prob
|
| 148 |
+
|
| 149 |
+
def _get_embbeder(self, weigths_name, special_tokens_fix):
|
| 150 |
+
embedders = {'bert': PretrainedBertEmbedder(
|
| 151 |
+
pretrained_model=weigths_name,
|
| 152 |
+
requires_grad=False,
|
| 153 |
+
top_layer_only=True,
|
| 154 |
+
special_tokens_fix=special_tokens_fix)
|
| 155 |
+
}
|
| 156 |
+
text_field_embedder = BasicTextFieldEmbedder(
|
| 157 |
+
token_embedders=embedders,
|
| 158 |
+
embedder_to_indexer_map={"bert": ["bert", "bert-offsets"]},
|
| 159 |
+
allow_unmatched_keys=True)
|
| 160 |
+
return text_field_embedder
|
| 161 |
+
|
| 162 |
+
def _get_indexer(self, weights_name, special_tokens_fix):
|
| 163 |
+
bert_token_indexer = PretrainedBertIndexer(
|
| 164 |
+
pretrained_model=weights_name,
|
| 165 |
+
do_lowercase=self.lowercase_tokens,
|
| 166 |
+
max_pieces_per_token=5,
|
| 167 |
+
special_tokens_fix=special_tokens_fix
|
| 168 |
+
)
|
| 169 |
+
return {'bert': bert_token_indexer}
|
| 170 |
+
|
| 171 |
+
def preprocess(self, token_batch):
    """Convert raw token sequences into indexed AllenNLP batches, one per indexer.

    Sequences are truncated to ``self.max_len`` and prefixed with the
    ``$START`` token. Returns ``[]`` when every sequence is empty.
    """
    lengths = [len(seq) for seq in token_batch if seq]
    if not lengths:
        return []
    max_len = min(max(lengths), self.max_len)

    batches = []
    for indexer in self.indexers:
        instances = []
        for seq in token_batch:
            wrapped = [Token(word) for word in ['$START'] + seq[:max_len]]
            instances.append(Instance({'tokens': TextField(wrapped, indexer)}))
        indexed = Batch(instances)
        indexed.index_instances(self.vocab)
        batches.append(indexed)

    return batches
def _convert(self, data):
|
| 190 |
+
all_class_probs = torch.zeros_like(data[0]['class_probabilities_labels'])
|
| 191 |
+
error_probs = torch.zeros_like(data[0]['max_error_probability'])
|
| 192 |
+
for output, weight in zip(data, self.model_weights):
|
| 193 |
+
all_class_probs += weight * output['class_probabilities_labels'] / sum(self.model_weights)
|
| 194 |
+
error_probs += weight * output['max_error_probability'] / sum(self.model_weights)
|
| 195 |
+
|
| 196 |
+
max_vals = torch.max(all_class_probs, dim=-1)
|
| 197 |
+
probs = max_vals[0].tolist()
|
| 198 |
+
idx = max_vals[1].tolist()
|
| 199 |
+
return probs, idx, error_probs.tolist()
|
| 200 |
+
|
| 201 |
+
def update_final_batch(self, final_batch, pred_ids, pred_batch,
                       prev_preds_dict):
    """Merge fresh predictions into the running batch.

    A sentence whose prediction changed to something never seen before is
    kept for another correction round; one that changed back to an earlier
    prediction is updated but dropped (stops oscillation). Returns the
    updated batch, the ids to keep iterating on, and the update count.
    """
    ids_to_continue = []
    updated = 0
    for pos, orig_id in enumerate(pred_ids):
        previous = final_batch[orig_id]
        fresh = pred_batch[pos]
        if fresh == previous:
            continue
        final_batch[orig_id] = fresh
        updated += 1
        if fresh not in prev_preds_dict[orig_id]:
            ids_to_continue.append(orig_id)
            prev_preds_dict[orig_id].append(fresh)
    return final_batch, ids_to_continue, updated
def postprocess_batch(self, batch, all_probabilities, all_idxs,
                      error_probs):
    """Turn per-token label predictions into corrected sentences.

    For each sentence, collects edit actions for every non-$KEEP label and
    applies them via ``get_target_sent_by_edits``. Sentences with no
    predicted errors, or whose sentence-level error probability is below
    ``self.min_error_probability``, are returned unchanged.
    """
    all_results = []
    noop_index = self.vocab.get_token_index("$KEEP", "labels")
    for tokens, probabilities, idxs, error_prob in zip(batch,
                                                       all_probabilities,
                                                       all_idxs,
                                                       error_probs):
        # predictions include the $START position, hence length + 1 below
        length = min(len(tokens), self.max_len)
        edits = []

        # skip whole sentences if there no errors
        if max(idxs) == 0:
            all_results.append(tokens)
            continue

        # skip whole sentence if probability of correctness is not high
        if error_prob < self.min_error_probability:
            all_results.append(tokens)
            continue

        for i in range(length + 1):
            # because of START token
            if i == 0:
                token = START_TOKEN
            else:
                token = tokens[i - 1]
            # skip if there is no error
            if idxs[i] == noop_index:
                continue

            sugg_token = self.vocab.get_token_from_index(idxs[i],
                                                         namespace='labels')
            action = self.get_token_action(token, i, probabilities[i],
                                           sugg_token)
            # get_token_action returns None for low-probability / no-op labels
            if not action:
                continue

            edits.append(action)
        all_results.append(get_target_sent_by_edits(tokens, edits))
    return all_results
def handle_batch(self, full_batch):
    """
    Handle batch of requests.

    Iteratively corrects the batch up to ``self.iterations`` times,
    re-predicting only sentences that changed in the previous round.
    Returns the corrected batch and the total number of sentence updates.
    """
    final_batch = full_batch[:]
    batch_size = len(full_batch)
    # per-sentence history of predictions, used to detect oscillation
    prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
    # sentences shorter than min_len are never corrected
    short_ids = [i for i in range(len(full_batch))
                 if len(full_batch[i]) < self.min_len]
    pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
    total_updates = 0

    for n_iter in range(self.iterations):
        # only re-run sentences that are still candidates for correction
        orig_batch = [final_batch[i] for i in pred_ids]

        sequences = self.preprocess(orig_batch)

        if not sequences:
            break
        probabilities, idxs, error_probs = self.predict(sequences)

        pred_batch = self.postprocess_batch(orig_batch, probabilities,
                                            idxs, error_probs)
        if self.log:
            print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")

        # keep only sentences whose prediction changed to something new
        final_batch, pred_ids, cnt = \
            self.update_final_batch(final_batch, pred_ids, pred_batch,
                                    prev_preds_dict)
        total_updates += cnt

        if not pred_ids:
            break

    return final_batch, total_updates
|
gector/seq2labels_model.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Basic model. Predicts tags for every token"""
|
| 2 |
+
from typing import Dict, Optional, List, Any
|
| 3 |
+
|
| 4 |
+
import numpy
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from allennlp.data import Vocabulary
|
| 8 |
+
from allennlp.models.model import Model
|
| 9 |
+
from allennlp.modules import TimeDistributed, TextFieldEmbedder
|
| 10 |
+
from allennlp.nn import InitializerApplicator, RegularizerApplicator
|
| 11 |
+
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
|
| 12 |
+
from allennlp.training.metrics import CategoricalAccuracy
|
| 13 |
+
from overrides import overrides
|
| 14 |
+
from torch.nn.modules.linear import Linear
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@Model.register("seq2labels")
class Seq2Labels(Model):
    """
    This ``Seq2Labels`` simply encodes a sequence of text with a stacked ``Seq2SeqEncoder``, then
    predicts a tag (or couple tags) for each token in the sequence.

    The model carries two heads over a shared text-field embedder: one over
    the correction-label vocabulary (``labels_namespace``) and one over the
    binary error-detection tags (``detect_namespace``).

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    predictor_dropout : ``float``, optional (default = 0.0)
        Dropout applied before the label projection layer (not the detection one).
    labels_namespace : ``str``, optional (default=``labels``)
        Vocabulary namespace for the correction labels.
    detect_namespace : ``str``, optional (default=``d_tags``)
        Vocabulary namespace for the CORRECT/INCORRECT detection tags.
    verbose_metrics : ``bool``, optional (default = False)
        If true, metrics will be returned per label class in addition
        to the overall statistics.
    label_smoothing : ``float``, optional (default = 0.0)
        Label smoothing applied to the label-head cross entropy.
    confidence : ``float``, optional (default = 0.0)
        Inference-time bias added to the first label class.
        NOTE(review): assumes index 0 of the labels namespace is $KEEP — confirm.
    del_confidence : ``float``, optional (default = 0.0)
        Inference-time bias added to the second label class.
        NOTE(review): assumes index 1 is $DELETE — confirm.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 predictor_dropout=0.0,
                 labels_namespace: str = "labels",
                 detect_namespace: str = "d_tags",
                 verbose_metrics: bool = False,
                 label_smoothing: float = 0.0,
                 confidence: float = 0.0,
                 del_confidence: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(Seq2Labels, self).__init__(vocab, regularizer)

        # Two tagging heads share one encoder output: labels and d_tags.
        self.label_namespaces = [labels_namespace,
                                 detect_namespace]
        self.text_field_embedder = text_field_embedder
        self.num_labels_classes = self.vocab.get_vocab_size(labels_namespace)
        self.num_detect_classes = self.vocab.get_vocab_size(detect_namespace)
        self.label_smoothing = label_smoothing
        self.confidence = confidence
        self.del_conf = del_confidence
        # index of the INCORRECT tag, used to read off error probabilities
        self.incorr_index = self.vocab.get_token_index("INCORRECT",
                                                       namespace=detect_namespace)

        self._verbose_metrics = verbose_metrics
        self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout))

        # projection from the BERT hidden size to each tag vocabulary
        self.tag_labels_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_labels_classes))

        self.tag_detect_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_detect_classes))

        self.metrics = {"accuracy": CategoricalAccuracy()}

        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                labels: torch.LongTensor = None,
                d_tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        labels : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        d_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        encoded_text = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = encoded_text.size()
        mask = get_text_field_mask(tokens)
        logits_labels = self.tag_labels_projection_layer(self.predictor_dropout(encoded_text))
        logits_d = self.tag_detect_projection_layer(encoded_text)

        class_probabilities_labels = F.softmax(logits_labels, dim=-1).view(
            [batch_size, sequence_length, self.num_labels_classes])
        class_probabilities_d = F.softmax(logits_d, dim=-1).view(
            [batch_size, sequence_length, self.num_detect_classes])
        # sentence-level error probability = max over tokens of P(INCORRECT)
        error_probs = class_probabilities_d[:, :, self.incorr_index] * mask
        incorr_prob = torch.max(error_probs, dim=-1)[0]

        # Inference-time bias: boost the first two label classes by
        # confidence / del_confidence. NOTE(review): makes the resulting
        # "probabilities" unnormalized — downstream only takes argmax/max.
        probability_change = [self.confidence, self.del_conf] + [0] * (self.num_labels_classes - 2)
        class_probabilities_labels += torch.FloatTensor(probability_change).repeat(
            (batch_size, sequence_length, 1)).to(class_probabilities_labels.device)

        output_dict = {"logits_labels": logits_labels,
                       "logits_d_tags": logits_d,
                       "class_probabilities_labels": class_probabilities_labels,
                       "class_probabilities_d_tags": class_probabilities_d,
                       "max_error_probability": incorr_prob}
        if labels is not None and d_tags is not None:
            loss_labels = sequence_cross_entropy_with_logits(logits_labels, labels, mask,
                                                             label_smoothing=self.label_smoothing)
            loss_d = sequence_cross_entropy_with_logits(logits_d, d_tags, mask)
            # NOTE(review): both heads feed the same accuracy metric, so the
            # reported "accuracy" mixes label and detection predictions.
            for metric in self.metrics.values():
                metric(logits_labels, labels, mask.float())
                metric(logits_d, d_tags, mask.float())
            output_dict["loss"] = loss_labels + loss_d

        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple position-wise argmax over each token, converts indices to string labels, and
        adds a ``"tags"`` key to the dictionary with the result.
        """
        for label_namespace in self.label_namespaces:
            all_predictions = output_dict[f'class_probabilities_{label_namespace}']
            all_predictions = all_predictions.cpu().data.numpy()
            if all_predictions.ndim == 3:
                predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
            else:
                predictions_list = [all_predictions]
            all_tags = []

            for predictions in predictions_list:
                argmax_indices = numpy.argmax(predictions, axis=-1)
                tags = [self.vocab.get_token_from_index(x, namespace=label_namespace)
                        for x in argmax_indices]
                all_tags.append(tags)
            output_dict[f'{label_namespace}'] = all_tags
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return accumulated metrics, optionally resetting their state."""
        metrics_to_return = {metric_name: metric.get_metric(reset) for
                             metric_name, metric in self.metrics.items()}
        return metrics_to_return
|
gector/tokenization.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from time import time
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# NOTE(review): presumably disables HuggingFace tokenizers' thread
# parallelism (avoids fork warnings/deadlocks in workers) — confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_bpe_groups(token_offsets, bpe_offsets, input_ids, max_bpe_pieces=5):
    """Group BPE-piece indices by the original token they belong to.

    A piece belongs to a token when its character span lies inside the
    token's span. Each token keeps at most ``max_bpe_pieces`` pieces; the
    overflow piece ids are excluded from the returned ``saved`` id list.

    Returns ``(groups, saved)`` where ``groups[t]`` is the list of kept
    piece indices for token ``t`` and ``saved`` is every input-id position
    that survived the per-token cap.
    """
    # a (0, 0) entry marks the start of padding in the offset mapping
    try:
        usable = bpe_offsets.index((0, 0))
    except ValueError:
        usable = len(bpe_offsets)

    groups = []
    dropped = []
    cursor = 0
    for start_token, end_token in token_offsets:
        group = []
        found = False
        for piece_idx in range(cursor, usable):
            start_bpe, end_bpe = bpe_offsets[piece_idx]
            if start_bpe >= start_token and end_bpe <= end_token:
                if len(group) < max_bpe_pieces:
                    group.append(piece_idx)
                else:
                    dropped.append(piece_idx)
                cursor = piece_idx + 1
                found = True
            elif found:
                # pieces are ordered; once past this token, stop scanning
                break
        groups.append(group)

    saved = [i for i in range(len(input_ids)) if i not in dropped]
    return groups, saved
def reduce_input_ids(input_ids, bpe_groups, saved_ids,
                     max_bpe_length=80, max_bpe_pieces=5):
    """Shrink a BPE sequence until it fits within ``max_bpe_length`` pieces.

    Repeatedly lowers the per-token piece budget and drops each token's
    overflow pieces (mutating ``bpe_groups`` in place) until the kept
    sequence is short enough. Returns the kept input ids and, per token,
    the index of its first kept piece within the reduced sequence.

    NOTE(review): if the sequence cannot shrink below ``max_bpe_length``
    even at one piece per token, this loop keeps decrementing the budget
    toward zero — verify inputs are bounded upstream.
    """
    # check if sequence is satisfy max_bpe_length constraint
    while len(saved_ids) > max_bpe_length:
        max_bpe_pieces -= 1
        for token_id in range(len(bpe_groups)):
            if len(bpe_groups[token_id]) > max_bpe_pieces:
                redundant_ids = bpe_groups[token_id][max_bpe_pieces:]
                bpe_groups[token_id] = bpe_groups[token_id][:max_bpe_pieces]
                saved_ids = [i for i in saved_ids if i not in redundant_ids]

    # get offsets
    reduced_ids = [input_ids[i] for i in saved_ids]
    correct_offsets = []
    idx = 0
    for i, bpe_group in enumerate(bpe_groups):
        # clamp so trailing tokens with no kept pieces still map in-bounds
        norm_idx = min(idx, len(reduced_ids) - 1)
        correct_offsets.append(norm_idx)
        idx += len(bpe_group)

    return reduced_ids, correct_offsets
def get_offsets_and_reduce_input_ids(tokenizer_output, token_offset_list,
                                     index_name="bert", max_bpe_length=80,
                                     max_bpe_pieces=5):
    """Align tokenizer output with token offsets and trim over-long sequences.

    For each sentence, maps tokens onto their BPE pieces, enforces the
    per-token and per-sentence piece budgets, and builds a matching mask.
    Returns a dict with keys ``index_name``, ``f"{index_name}-offsets"``
    and ``"mask"``, each holding one list per sentence.
    """
    timings = {"bpe": 0, "reduce": 0, "mask": 0}
    all_ids, all_offsets, all_masks = [], [], []
    for sent_idx, token_offsets in enumerate(token_offset_list):
        input_ids = tokenizer_output['input_ids'][sent_idx]

        stamp = time()
        # map each token onto its group of BPE pieces
        bpe_offsets = tokenizer_output['offset_mapping'][sent_idx]
        bpe_groups, saved_ids = get_bpe_groups(token_offsets, bpe_offsets,
                                               input_ids,
                                               max_bpe_pieces=max_bpe_pieces)
        now = time()
        timings["bpe"] += now - stamp
        stamp = now

        # enforce the sequence-length budget
        reduced_ids, correct_offsets = reduce_input_ids(
            input_ids, bpe_groups, saved_ids,
            max_bpe_length=max_bpe_length,
            max_bpe_pieces=max_bpe_pieces)
        now = time()
        timings["reduce"] += now - stamp
        stamp = now

        all_ids.append(reduced_ids)
        all_offsets.append(correct_offsets)
        # one mask entry per token kept
        all_masks.append([1] * len(correct_offsets))
        timings["mask"] += time() - stamp

    return {index_name: all_ids,
            f"{index_name}-offsets": all_offsets,
            "mask": all_masks}
def get_offset_for_tokens(tokens):
    """Character (start, end) offsets of each token within ``" ".join(tokens)``."""
    joined = " ".join(tokens)
    offsets = []
    cursor = 0
    for token in tokens:
        begin = joined.index(token, cursor)
        cursor = begin + len(token)
        offsets.append((begin, cursor))
    return offsets
def get_token_offsets(batch):
    """Per-sentence token character offsets for a batch of token lists."""
    return [get_offset_for_tokens(tokens) for tokens in batch]
def pad_output(output, pad_idx=0):
    """Right-pad every index list to its key's max length with ``pad_idx``."""
    padded = {}
    for key, rows in output.items():
        width = max(len(row) for row in rows)
        padded[key] = [row + [pad_idx] * (width - len(row)) for row in rows]
    return padded
def tokenize_batch(tokenizer, batch_tokens, index_name="bert",
                   max_bpe_length=80, max_bpe_pieces=5):
    """Tokenize a batch of pre-split sentences into padded BPE index dicts.

    Joins each token list back into a sentence, runs the fast tokenizer
    with offset mapping, re-aligns pieces to the original tokens (trimming
    to the piece budgets), and pads every output list.
    """
    timings = {}
    checkpoint = time()

    # join tokens back into raw sentences for the tokenizer
    batch_sentences = [" ".join(tokens) for tokens in batch_tokens]
    # character offsets of every original token
    token_offset_list = get_token_offsets(batch_tokens)
    timings["offset_time"] = time() - checkpoint
    checkpoint = time()

    tokenizer_output = tokenizer.batch_encode_plus(batch_sentences,
                                                   pad_to_max_length=False,
                                                   return_offsets_mapping=True,
                                                   add_special_tokens=False)
    timings["tokenize_time"] = time() - checkpoint
    checkpoint = time()

    # re-align BPE pieces with tokens and enforce length budgets
    output = get_offsets_and_reduce_input_ids(tokenizer_output,
                                              token_offset_list,
                                              index_name=index_name,
                                              max_bpe_length=max_bpe_length,
                                              max_bpe_pieces=max_bpe_pieces)
    timings["reduce_time"] = time() - checkpoint
    checkpoint = time()

    output = pad_output(output)
    timings["pading_time"] = time() - checkpoint

    return output
|
gector/tokenizer_indexer.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tweaked version of corresponding AllenNLP file"""
|
| 2 |
+
import logging
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import Dict, List, Callable
|
| 5 |
+
|
| 6 |
+
from allennlp.common.util import pad_sequence_to_length
|
| 7 |
+
from allennlp.data.token_indexers.token_indexer import TokenIndexer
|
| 8 |
+
from allennlp.data.tokenizers.token import Token
|
| 9 |
+
from allennlp.data.vocabulary import Vocabulary
|
| 10 |
+
from overrides import overrides
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
|
| 13 |
+
from utils.helpers import START_TOKEN
|
| 14 |
+
|
| 15 |
+
from gector.tokenization import tokenize_batch
|
| 16 |
+
import copy
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TokenizerIndexer(TokenIndexer[int]):
    """
    A token indexer that performs wordpiece tokenization (e.g. for BERT).

    Prefer the ``PretrainedBertIndexer`` subclass when working with one of
    the pretrained models; this base class only wires a tokenizer callable
    into the AllenNLP ``TokenIndexer`` interface.

    Parameters
    ----------
    tokenizer : ``Callable[[str], List[str]]``
        A function that does the actual tokenization.
    max_pieces : int, optional (default: 512)
        Upper bound on BPE pieces per sequence (positional-embedding limit
        of BERT-style models); longer inputs are reduced.
    max_pieces_per_token : int, optional (default: 3)
        Upper bound on BPE pieces kept per original token.
    token_min_padding_length : ``int``, optional (default=``0``)
        See :class:`TokenIndexer`.
    """

    def __init__(self,
                 tokenizer: Callable[[str], List[str]],
                 max_pieces: int = 512,
                 max_pieces_per_token: int = 3,
                 token_min_padding_length: int = 0) -> None:
        super().__init__(token_min_padding_length)
        # BERT tokenizes in two steps (sentence -> words -> wordpieces);
        # this indexer is responsible for the word -> wordpieces step.
        self.tokenizer = tokenizer
        self.max_pieces = max_pieces
        self.max_pieces_per_token = max_pieces_per_token
        self.max_pieces_per_sentence = 80

    @overrides
    def tokens_to_indices(self, tokens: List[Token],
                          vocabulary: Vocabulary,
                          index_name: str) -> Dict[str, List[int]]:
        words = [token.text for token in tokens]
        batched = tokenize_batch(self.tokenizer,
                                 [words],
                                 max_bpe_length=self.max_pieces,
                                 max_bpe_pieces=self.max_pieces_per_token)
        # tokenize_batch works on batches; unwrap the single sentence
        return {key: rows[0] for key, rows in batched.items()}

    @overrides
    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        # Pretrained vocabularies are fixed, so there is nothing to count.
        pass

    @overrides
    def get_padding_token(self) -> int:
        return 0

    @overrides
    def get_padding_lengths(self, token: int) -> Dict[str, int]:  # pylint: disable=unused-argument
        return {}

    @overrides
    def pad_token_sequence(self,
                           tokens: Dict[str, List[int]],
                           desired_num_tokens: Dict[str, int],
                           padding_lengths: Dict[str, int]) -> Dict[str, List[int]]:  # pylint: disable=unused-argument
        padded = {}
        for key, val in tokens.items():
            padded[key] = pad_sequence_to_length(val, desired_num_tokens[key])
        return padded

    @overrides
    def get_keys(self, index_name: str) -> List[str]:
        """
        We need to override this because the indexer generates multiple keys.
        """
        # pylint: disable=no-self-use
        return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"]
class PretrainedBertIndexer(TokenizerIndexer):
    # pylint: disable=line-too-long
    """
    A ``TokenIndexer`` corresponding to a pretrained BERT model.

    Loads a fast HuggingFace tokenizer for ``pretrained_model``, normalises
    its vocab attribute across implementations, and optionally registers the
    ``$START`` token when ``special_tokens_fix`` is set.

    Parameters
    ----------
    pretrained_model: ``str``
        Name of the pretrained model (e.g. 'bert-base-uncased') or a path
        understood by ``AutoTokenizer.from_pretrained``.
    do_lowercase: ``bool``, optional (default = True)
        Whether to lowercase the tokens before converting to wordpiece ids.
    max_pieces: int, optional (default: 512)
        Upper bound on BPE pieces per sequence (positional-embedding limit).
    max_pieces_per_token: int, optional (default: 5)
        Upper bound on BPE pieces kept per original token.
    special_tokens_fix: int, optional (default: 0)
        When truthy, adds ``$START`` to the tokenizer vocabulary.
    """

    def __init__(self,
                 pretrained_model: str,
                 do_lowercase: bool = True,
                 max_pieces: int = 512,
                 max_pieces_per_token: int = 5,
                 special_tokens_fix: int = 0) -> None:

        # Warn about a likely casing mismatch between model and indexer.
        if pretrained_model.endswith("-cased") and do_lowercase:
            logger.warning("Your BERT model appears to be cased, "
                           "but your indexer is lowercasing tokens.")
        elif pretrained_model.endswith("-uncased") and not do_lowercase:
            logger.warning("Your BERT model appears to be uncased, "
                           "but your indexer is not lowercasing tokens.")

        model_name = copy.deepcopy(pretrained_model)

        model_tokenizer = AutoTokenizer.from_pretrained(
            model_name, do_lower_case=do_lowercase, do_basic_tokenize=False, use_fast=True)

        # Normalise the vocab attribute across tokenizer implementations:
        # GPT-style tokenizers expose ``encoder`` ...
        if hasattr(model_tokenizer, 'encoder'):
            model_tokenizer.vocab = model_tokenizer.encoder
        # ... sentencepiece-backed ones expose ``sp_model`` (unknown pieces
        # default to id 1 via the defaultdict).
        if hasattr(model_tokenizer, 'sp_model'):
            sp_vocab = defaultdict(lambda: 1)
            for piece_id in range(model_tokenizer.sp_model.get_piece_size()):
                sp_vocab[model_tokenizer.sp_model.id_to_piece(piece_id)] = piece_id
            model_tokenizer.vocab = sp_vocab

        if special_tokens_fix:
            model_tokenizer.add_tokens([START_TOKEN])
            model_tokenizer.vocab[START_TOKEN] = len(model_tokenizer) - 1

        super().__init__(tokenizer=model_tokenizer,
                         max_pieces=max_pieces,
                         max_pieces_per_token=max_pieces_per_token)
gector/trainer.py
ADDED
|
@@ -0,0 +1,845 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tweaked version of corresponding AllenNLP file"""
|
| 2 |
+
import datetime
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import traceback
|
| 8 |
+
from typing import Dict, Optional, List, Tuple, Union, Iterable, Any
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.optim.lr_scheduler
|
| 12 |
+
from allennlp.common import Params
|
| 13 |
+
from allennlp.common.checks import ConfigurationError, parse_cuda_device
|
| 14 |
+
from allennlp.common.tqdm import Tqdm
|
| 15 |
+
from allennlp.common.util import dump_metrics, gpu_memory_mb, peak_memory_mb, lazy_groups_of
|
| 16 |
+
from allennlp.data.instance import Instance
|
| 17 |
+
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
|
| 18 |
+
from allennlp.models.model import Model
|
| 19 |
+
from allennlp.nn import util as nn_util
|
| 20 |
+
from allennlp.training import util as training_util
|
| 21 |
+
from allennlp.training.checkpointer import Checkpointer
|
| 22 |
+
from allennlp.training.learning_rate_schedulers import LearningRateScheduler
|
| 23 |
+
from allennlp.training.metric_tracker import MetricTracker
|
| 24 |
+
from allennlp.training.momentum_schedulers import MomentumScheduler
|
| 25 |
+
from allennlp.training.moving_average import MovingAverage
|
| 26 |
+
from allennlp.training.optimizers import Optimizer
|
| 27 |
+
from allennlp.training.tensorboard_writer import TensorboardWriter
|
| 28 |
+
from allennlp.training.trainer_base import TrainerBase
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Trainer(TrainerBase):
|
| 34 |
+
    def __init__(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        iterator: DataIterator,
        train_dataset: Iterable[Instance],
        validation_dataset: Optional[Iterable[Instance]] = None,
        patience: Optional[int] = None,
        validation_metric: str = "-loss",
        validation_iterator: DataIterator = None,
        shuffle: bool = True,
        num_epochs: int = 20,
        accumulated_batch_count: int = 1,
        serialization_dir: Optional[str] = None,
        num_serialized_models_to_keep: int = 20,
        keep_serialized_model_every_num_seconds: int = None,
        checkpointer: Checkpointer = None,
        model_save_interval: float = None,
        cuda_device: Union[int, List] = -1,
        grad_norm: Optional[float] = None,
        grad_clipping: Optional[float] = None,
        learning_rate_scheduler: Optional[LearningRateScheduler] = None,
        momentum_scheduler: Optional[MomentumScheduler] = None,
        summary_interval: int = 100,
        histogram_interval: int = None,
        should_log_parameter_statistics: bool = True,
        should_log_learning_rate: bool = False,
        log_batch_size_period: Optional[int] = None,
        moving_average: Optional[MovingAverage] = None,
        cold_step_count: int = 0,
        cold_lr: float = 1e-3,
        cuda_verbose_step=None,
    ) -> None:
        """
        A trainer for doing supervised learning. It just takes a labeled dataset
        and a ``DataIterator``, and uses the supplied ``Optimizer`` to learn the weights
        for your model over some fixed number of epochs. You can also pass in a validation
        dataset and enable early stopping. There are many other bells and whistles as well.

        Parameters
        ----------
        model : ``Model``, required.
            An AllenNLP model to be optimized. Pytorch Modules can also be optimized if
            their ``forward`` method returns a dictionary with a "loss" key, containing a
            scalar tensor representing the loss function to be optimized.

            If you are training your model using GPUs, your model should already be
            on the correct device. (If you use `Trainer.from_params` this will be
            handled for you.)
        optimizer : ``torch.nn.Optimizer``, required.
            An instance of a Pytorch Optimizer, instantiated with the parameters of the
            model to be optimized.
        scheduler : ``torch.optim.lr_scheduler``, required.
            A PyTorch learning-rate scheduler instance, stored as ``self.scheduler``.
            NOTE(review): it is not stepped anywhere in this method — confirm how the
            training entry point uses it.
        iterator : ``DataIterator``, required.
            A method for iterating over a ``Dataset``, yielding padded indexed batches.
        train_dataset : ``Dataset``, required.
            A ``Dataset`` to train on. The dataset should have already been indexed.
        validation_dataset : ``Dataset``, optional, (default = None).
            A ``Dataset`` to evaluate on. The dataset should have already been indexed.
        patience : Optional[int] > 0, optional (default=None)
            Number of epochs to be patient before early stopping: the training is stopped
            after ``patience`` epochs with no improvement. If given, it must be ``> 0``.
            If None, early stopping is disabled.
        validation_metric : str, optional (default="loss")
            Validation metric to measure for whether to stop training using patience
            and whether to serialize an ``is_best`` model each epoch. The metric name
            must be prepended with either "+" or "-", which specifies whether the metric
            is an increasing or decreasing function.
        validation_iterator : ``DataIterator``, optional (default=None)
            An iterator to use for the validation set. If ``None``, then
            use the training `iterator`.
        shuffle: ``bool``, optional (default=True)
            Whether to shuffle the instances in the iterator or not.
        num_epochs : int, optional (default = 20)
            Number of training epochs.
        accumulated_batch_count : int, optional (default = 1)
            Number of consecutive batches over which gradients are accumulated
            before each optimizer step.
        serialization_dir : str, optional (default=None)
            Path to directory for saving and loading model files. Models will not be saved if
            this parameter is not passed.
        num_serialized_models_to_keep : ``int``, optional (default=20)
            Number of previous model checkpoints to retain. Default is to keep 20 checkpoints.
            A value of None or -1 means all checkpoints will be kept.
        keep_serialized_model_every_num_seconds : ``int``, optional (default=None)
            If num_serialized_models_to_keep is not None, then occasionally it's useful to
            save models at a given interval in addition to the last num_serialized_models_to_keep.
            To do so, specify keep_serialized_model_every_num_seconds as the number of seconds
            between permanently saved checkpoints. Note that this option is only used if
            num_serialized_models_to_keep is not None, otherwise all checkpoints are kept.
        checkpointer : ``Checkpointer``, optional (default=None)
            An instance of class Checkpointer to use instead of the default. If a checkpointer is specified,
            the arguments num_serialized_models_to_keep and keep_serialized_model_every_num_seconds should
            not be specified. The caller is responsible for initializing the checkpointer so that it is
            consistent with serialization_dir.
        model_save_interval : ``float``, optional (default=None)
            If provided, then serialize models every ``model_save_interval``
            seconds within single epochs. In all cases, models are also saved
            at the end of every epoch if ``serialization_dir`` is provided.
        cuda_device : ``Union[int, List[int]]``, optional (default = -1)
            An integer or list of integers specifying the CUDA device(s) to use. If -1, the CPU is used.
        grad_norm : ``float``, optional, (default = None).
            If provided, gradient norms will be rescaled to have a maximum of this value.
        grad_clipping : ``float``, optional (default = ``None``).
            If provided, gradients will be clipped `during the backward pass` to have an (absolute)
            maximum of this value. If you are getting ``NaNs`` in your gradients during training
            that are not solved by using ``grad_norm``, you may need this.
        learning_rate_scheduler : ``LearningRateScheduler``, optional (default = None)
            If specified, the learning rate will be decayed with respect to
            this schedule at the end of each epoch (or batch, if the scheduler implements
            the ``step_batch`` method). If you use :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`,
            this will use the ``validation_metric`` provided to determine if learning has plateaued.
            To support updating the learning rate on every batch, this can optionally implement
            ``step_batch(batch_num_total)`` which updates the learning rate given the batch number.
        momentum_scheduler : ``MomentumScheduler``, optional (default = None)
            If specified, the momentum will be updated at the end of each batch or epoch
            according to the schedule.
        summary_interval: ``int``, optional, (default = 100)
            Number of batches between logging scalars to tensorboard
        histogram_interval : ``int``, optional, (default = ``None``)
            If not None, then log histograms to tensorboard every ``histogram_interval`` batches.
            When this parameter is specified, the following additional logging is enabled:
                * Histograms of model parameters
                * The ratio of parameter update norm to parameter norm
                * Histogram of layer activations
            We log histograms of the parameters returned by
            ``model.get_parameters_for_histogram_tensorboard_logging``.
            The layer activations are logged for any modules in the ``Model`` that have
            the attribute ``should_log_activations`` set to ``True``.  Logging
            histograms requires a number of GPU-CPU copies during training and is typically
            slow, so we recommend logging histograms relatively infrequently.
            Note: only Modules that return tensors, tuples of tensors or dicts
            with tensors as values currently support activation logging.
        should_log_parameter_statistics : ``bool``, optional, (default = True)
            Whether to send parameter statistics (mean and standard deviation
            of parameters and gradients) to tensorboard.
        should_log_learning_rate : ``bool``, optional, (default = False)
            Whether to send parameter specific learning rate to tensorboard.
        log_batch_size_period : ``int``, optional, (default = ``None``)
            If defined, how often to log the average batch size.
        moving_average: ``MovingAverage``, optional, (default = None)
            If provided, we will maintain moving averages for all parameters. During training, we
            employ a shadow variable for each parameter, which maintains the moving average. During
            evaluation, we backup the original parameters and assign the moving averages to corresponding
            parameters. Be careful that when saving the checkpoint, we will save the moving averages of
            parameters. This is necessary because we want the saved model to perform as well as the validated
            model if we load it later. But this may cause problems if you restart the training from checkpoint.
        cold_step_count : int, optional (default = 0)
        cold_lr : float, optional (default = 1e-3)
            Stored on the trainer; presumably the length and learning rate of an
            initial "cold" warm-up phase — confirm against the training entry
            point, as they are not read inside this method.
        cuda_verbose_step : int, optional (default = None)
            If set, CUDA memory statistics are printed every
            ``cuda_verbose_step`` batches during training.
        """
        super().__init__(serialization_dir, cuda_device)

        # I am not calling move_to_gpu here, because if the model is
        # not already on the GPU then the optimizer is going to be wrong.
        self.model = model

        self.iterator = iterator
        self._validation_iterator = validation_iterator
        self.shuffle = shuffle
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_data = train_dataset
        self._validation_data = validation_dataset
        self.accumulated_batch_count = accumulated_batch_count
        self.cold_step_count = cold_step_count
        self.cold_lr = cold_lr
        self.cuda_verbose_step = cuda_verbose_step

        if patience is None:  # no early stopping
            if validation_dataset:
                logger.warning(
                    "You provided a validation dataset but patience was set to None, "
                    "meaning that early stopping is disabled"
                )
        elif (not isinstance(patience, int)) or patience <= 0:
            raise ConfigurationError(
                '{} is an invalid value for "patience": it must be a positive integer '
                "or None (if you want to disable early stopping)".format(patience)
            )

        # For tracking is_best_so_far and should_stop_early
        self._metric_tracker = MetricTracker(patience, validation_metric)
        # Get rid of + or -
        self._validation_metric = validation_metric[1:]

        self._num_epochs = num_epochs

        if checkpointer is not None:
            # We can't easily check if these parameters were passed in, so check against their default values.
            # We don't check against serialization_dir since it is also used by the parent class.
            if num_serialized_models_to_keep != 20 \
                    or keep_serialized_model_every_num_seconds is not None:
                raise ConfigurationError(
                    "When passing a custom Checkpointer, you may not also pass in separate checkpointer "
                    "args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'."
                )
            self._checkpointer = checkpointer
        else:
            self._checkpointer = Checkpointer(
                serialization_dir,
                keep_serialized_model_every_num_seconds,
                num_serialized_models_to_keep,
            )

        self._model_save_interval = model_save_interval

        self._grad_norm = grad_norm
        self._grad_clipping = grad_clipping

        self._learning_rate_scheduler = learning_rate_scheduler
        self._momentum_scheduler = momentum_scheduler
        self._moving_average = moving_average

        # We keep the total batch number as an instance variable because it
        # is used inside a closure for the hook which logs activations in
        # ``_enable_activation_logging``.
        self._batch_num_total = 0

        self._tensorboard = TensorboardWriter(
            get_batch_num_total=lambda: self._batch_num_total,
            serialization_dir=serialization_dir,
            summary_interval=summary_interval,
            histogram_interval=histogram_interval,
            should_log_parameter_statistics=should_log_parameter_statistics,
            should_log_learning_rate=should_log_learning_rate,
        )

        self._log_batch_size_period = log_batch_size_period

        self._last_log = 0.0  # time of last logging

        # Enable activation logging.
        if histogram_interval is not None:
            self._tensorboard.enable_activation_logging(self.model)
|
| 263 |
+
|
| 264 |
+
def rescale_gradients(self) -> Optional[float]:
|
| 265 |
+
return training_util.rescale_gradients(self.model, self._grad_norm)
|
| 266 |
+
|
| 267 |
+
def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor:
|
| 268 |
+
"""
|
| 269 |
+
Does a forward pass on the given batches and returns the ``loss`` value in the result.
|
| 270 |
+
If ``for_training`` is `True` also applies regularization penalty.
|
| 271 |
+
"""
|
| 272 |
+
if self._multiple_gpu:
|
| 273 |
+
output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices)
|
| 274 |
+
else:
|
| 275 |
+
assert len(batch_group) == 1
|
| 276 |
+
batch = batch_group[0]
|
| 277 |
+
batch = nn_util.move_to_device(batch, self._cuda_devices[0])
|
| 278 |
+
output_dict = self.model(**batch)
|
| 279 |
+
|
| 280 |
+
try:
|
| 281 |
+
loss = output_dict["loss"]
|
| 282 |
+
if for_training:
|
| 283 |
+
loss += self.model.get_regularization_penalty()
|
| 284 |
+
except KeyError:
|
| 285 |
+
if for_training:
|
| 286 |
+
raise RuntimeError(
|
| 287 |
+
"The model you are trying to optimize does not contain a"
|
| 288 |
+
" 'loss' key in the output of model.forward(inputs)."
|
| 289 |
+
)
|
| 290 |
+
loss = None
|
| 291 |
+
|
| 292 |
+
return loss
|
| 293 |
+
|
| 294 |
+
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics.

        Gradient accumulation: each batch's loss is divided by the size of its
        accumulation window (``self.accumulated_batch_count``, or the smaller
        trailing remainder) and ``optimizer.step()`` is only called at window
        boundaries or on the epoch's final batch.
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
        peak_cpu_usage = peak_memory_mb()
        logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
        gpu_usage = []
        for gpu, memory in gpu_memory_mb().items():
            gpu_usage.append((gpu, memory))
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        train_loss = 0.0
        # Set the model to "train" mode.
        self.model.train()

        num_gpus = len(self._cuda_devices)

        # Get tqdm for the training batches
        raw_train_generator = self.iterator(self.train_data, num_epochs=1, shuffle=self.shuffle)
        train_generator = lazy_groups_of(raw_train_generator, num_gpus)
        num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data) / num_gpus)
        # Batches left over after full accumulation windows; the epoch's last
        # (partial) window has this many batches.
        residue = num_training_batches % self.accumulated_batch_count
        self._last_log = time.time()
        last_save_time = time.time()

        batches_this_epoch = 0
        # NOTE(review): __init__ initializes _batch_num_total to 0, so this
        # None-guard looks unreachable; kept as a defensive check.
        if self._batch_num_total is None:
            self._batch_num_total = 0

        histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())

        logger.info("Training")
        train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_training_batches)
        cumulative_batch_size = 0
        self.optimizer.zero_grad()
        for batch_group in train_generator_tqdm:
            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            # Size of this batch's accumulation window: full-sized except for
            # the trailing remainder at the end of the epoch.
            iter_len = self.accumulated_batch_count \
                if batches_this_epoch <= (num_training_batches - residue) else residue

            if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
                print(f'Before forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
                print(f'Before forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
            try:
                # Divide by the window size so accumulated gradients average
                # (rather than sum) over the window.
                loss = self.batch_loss(batch_group, for_training=True) / iter_len
            except RuntimeError as e:
                # Dump per-batch tensor shapes/ranges to help diagnose CUDA
                # OOMs and shape mismatches, then re-raise.
                print(e)
                for x in batch_group:
                    all_words = [len(y['words']) for y in x['metadata']]
                    print(f"Total sents: {len(all_words)}. "
                          f"Min {min(all_words)}. Max {max(all_words)}")
                    for elem in ['labels', 'd_tags']:
                        tt = x[elem]
                        print(
                            f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}")
                    for elem in ["bert", "mask", "bert-offsets"]:
                        tt = x['tokens'][elem]
                        print(
                            f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}")
                raise e

            if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
                print(f'After forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
                print(f'After forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')

            if torch.isnan(loss):
                raise ValueError("nan loss encountered")

            loss.backward()

            if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
                print(f'After backprop - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
                print(f'After backprop - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')

            # Undo the /iter_len normalization so the running total reflects
            # the raw per-batch loss.
            train_loss += loss.item() * iter_len

            del batch_group, loss
            torch.cuda.empty_cache()

            if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
                print(f'After collecting garbage - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
                print(f'After collecting garbage - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')

            batch_grad_norm = self.rescale_gradients()

            # This does nothing if batch_num_total is None or you are using a
            # scheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)
            if self._momentum_scheduler:
                self._momentum_scheduler.step_batch(batch_num_total)

            if self._tensorboard.should_log_histograms_this_batch():
                # get the magnitude of parameter updates for logging
                # We need a copy of current parameters to compute magnitude of updates,
                # and copy them to CPU so large models won't go OOM on the GPU.
                param_updates = {
                    name: param.detach().cpu().clone()
                    for name, param in self.model.named_parameters()
                }
                # Step only at accumulation-window boundaries or on the final
                # batch of the epoch.
                if batches_this_epoch % self.accumulated_batch_count == 0 or \
                        batches_this_epoch == num_training_batches:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                for name, param in self.model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
                    update_norm = torch.norm(param_updates[name].view(-1))
                    param_norm = torch.norm(param.view(-1)).cpu()
                    self._tensorboard.add_train_scalar(
                        "gradient_update/" + name, update_norm / (param_norm + 1e-7)
                    )
            else:
                # Same stepping rule as above, without the histogram bookkeeping.
                if batches_this_epoch % self.accumulated_batch_count == 0 or \
                        batches_this_epoch == num_training_batches:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

            # Update moving averages
            if self._moving_average is not None:
                self._moving_average.apply(batch_num_total)

            # Update the description with the latest metrics
            metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch)
            description = training_util.description_from_metrics(metrics)

            train_generator_tqdm.set_description(description, refresh=False)

            # Log parameter values to Tensorboard
            if self._tensorboard.should_log_this_batch():
                self._tensorboard.log_parameter_and_gradient_statistics(self.model, batch_grad_norm)
                self._tensorboard.log_learning_rates(self.model, self.optimizer)

                self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"])
                self._tensorboard.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()})

            if self._tensorboard.should_log_histograms_this_batch():
                self._tensorboard.log_histograms(self.model, histogram_parameters)

            if self._log_batch_size_period:
                cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
                cumulative_batch_size += cur_batch
                if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                    average = cumulative_batch_size / batches_this_epoch
                    logger.info(f"current batch size: {cur_batch} mean batch size: {average}")
                    self._tensorboard.add_train_scalar("current_batch_size", cur_batch)
                    self._tensorboard.add_train_scalar("mean_batch_size", average)

            # Save model if needed.
            if self._model_save_interval is not None and (
                time.time() - last_save_time > self._model_save_interval
            ):
                last_save_time = time.time()
                self._save_checkpoint(
                    "{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time)))
                )

        metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch, reset=True)
        metrics["cpu_memory_MB"] = peak_cpu_usage
        for (gpu_num, memory) in gpu_usage:
            metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
        return metrics
|
| 459 |
+
|
| 460 |
+
def _validation_loss(self) -> Tuple[float, int]:
|
| 461 |
+
"""
|
| 462 |
+
Computes the validation loss. Returns it and the number of batches.
|
| 463 |
+
"""
|
| 464 |
+
logger.info("Validating")
|
| 465 |
+
|
| 466 |
+
self.model.eval()
|
| 467 |
+
|
| 468 |
+
# Replace parameter values with the shadow values from the moving averages.
|
| 469 |
+
if self._moving_average is not None:
|
| 470 |
+
self._moving_average.assign_average_value()
|
| 471 |
+
|
| 472 |
+
if self._validation_iterator is not None:
|
| 473 |
+
val_iterator = self._validation_iterator
|
| 474 |
+
else:
|
| 475 |
+
val_iterator = self.iterator
|
| 476 |
+
|
| 477 |
+
num_gpus = len(self._cuda_devices)
|
| 478 |
+
|
| 479 |
+
raw_val_generator = val_iterator(self._validation_data, num_epochs=1, shuffle=False)
|
| 480 |
+
val_generator = lazy_groups_of(raw_val_generator, num_gpus)
|
| 481 |
+
num_validation_batches = math.ceil(
|
| 482 |
+
val_iterator.get_num_batches(self._validation_data) / num_gpus
|
| 483 |
+
)
|
| 484 |
+
val_generator_tqdm = Tqdm.tqdm(val_generator, total=num_validation_batches)
|
| 485 |
+
batches_this_epoch = 0
|
| 486 |
+
val_loss = 0
|
| 487 |
+
for batch_group in val_generator_tqdm:
|
| 488 |
+
|
| 489 |
+
loss = self.batch_loss(batch_group, for_training=False)
|
| 490 |
+
if loss is not None:
|
| 491 |
+
# You shouldn't necessarily have to compute a loss for validation, so we allow for
|
| 492 |
+
# `loss` to be None. We need to be careful, though - `batches_this_epoch` is
|
| 493 |
+
# currently only used as the divisor for the loss function, so we can safely only
|
| 494 |
+
# count those batches for which we actually have a loss. If this variable ever
|
| 495 |
+
# gets used for something else, we might need to change things around a bit.
|
| 496 |
+
batches_this_epoch += 1
|
| 497 |
+
val_loss += loss.detach().cpu().numpy()
|
| 498 |
+
|
| 499 |
+
# Update the description with the latest metrics
|
| 500 |
+
val_metrics = training_util.get_metrics(self.model, val_loss, batches_this_epoch)
|
| 501 |
+
description = training_util.description_from_metrics(val_metrics)
|
| 502 |
+
val_generator_tqdm.set_description(description, refresh=False)
|
| 503 |
+
|
| 504 |
+
# Now restore the original parameter values.
|
| 505 |
+
if self._moving_average is not None:
|
| 506 |
+
self._moving_average.restore()
|
| 507 |
+
|
| 508 |
+
return val_loss, batches_this_epoch
|
| 509 |
+
|
| 510 |
+
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.

        Restores any existing checkpoint, optionally runs ``cold_step_count``
        "cold" epochs (BERT embedder frozen, learning rate set to ``cold_lr``),
        then alternates training and validation epochs, tracking the best
        validation metric for early stopping and checkpointing every epoch.
        Finally reloads the best model weights before returning.

        Returns
        -------
        A dictionary of training and validation metrics, including the best
        epoch and the metrics achieved at it.

        Raises
        ------
        ConfigurationError
            If an existing checkpoint cannot be restored.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint. Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?"
            )

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        # Cold phase: drop to the cold learning rate and freeze the BERT
        # embedder weights; ``base_lr`` remembers the configured rate so it
        # can be restored when the cold phase ends.
        if self.cold_step_count > 0:
            base_lr = self.optimizer.param_groups[0]['lr']
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.cold_lr
            self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=True)

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        for epoch in range(epoch_counter, self._num_epochs):
            # End of the cold phase: restore the base learning rate and
            # unfreeze the BERT embedder.
            if epoch == self.cold_step_count and epoch != 0:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = base_lr
                self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=False)

            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if "cpu_memory_MB" in train_metrics:
                metrics["peak_cpu_memory_MB"] = max(
                    metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"]
                )
            for key, value in train_metrics.items():
                if key.startswith("gpu_"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

            # clear cache before validation
            torch.cuda.empty_cache()
            if self._validation_data is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(
                        self.model, val_loss, num_batches, reset=True
                    )

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience. Stopping training.")
                        break

            self._tensorboard.log_metrics(
                train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
            )  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            # if self.cold_step_count <= epoch:
            # NOTE(review): this assumes validation data exists every epoch —
            # with no validation set, 'validation_loss' is absent and this
            # raises KeyError. Confirm callers always supply validation data.
            self.scheduler.step(metrics['validation_loss'])

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir:
                dump_metrics(
                    os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
                )

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric, epoch)

            self._save_checkpoint(epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                # Linear extrapolation of remaining wall-clock time from the
                # average time per completed epoch.
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1
                )
                formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s", formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        # self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
|
| 645 |
+
|
| 646 |
+
def _save_checkpoint(self, epoch: Union[int, str]) -> None:
|
| 647 |
+
"""
|
| 648 |
+
Saves a checkpoint of the model to self._serialization_dir.
|
| 649 |
+
Is a no-op if self._serialization_dir is None.
|
| 650 |
+
|
| 651 |
+
Parameters
|
| 652 |
+
----------
|
| 653 |
+
epoch : Union[int, str], required.
|
| 654 |
+
The epoch of training. If the checkpoint is saved in the middle
|
| 655 |
+
of an epoch, the parameter is a string with the epoch and timestamp.
|
| 656 |
+
"""
|
| 657 |
+
# If moving averages are used for parameters, we save
|
| 658 |
+
# the moving average values into checkpoint, instead of the current values.
|
| 659 |
+
if self._moving_average is not None:
|
| 660 |
+
self._moving_average.assign_average_value()
|
| 661 |
+
|
| 662 |
+
# These are the training states we need to persist.
|
| 663 |
+
training_states = {
|
| 664 |
+
"metric_tracker": self._metric_tracker.state_dict(),
|
| 665 |
+
"optimizer": self.optimizer.state_dict(),
|
| 666 |
+
"batch_num_total": self._batch_num_total,
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
# If we have a learning rate or momentum scheduler, we should persist them too.
|
| 670 |
+
if self._learning_rate_scheduler is not None:
|
| 671 |
+
training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict()
|
| 672 |
+
if self._momentum_scheduler is not None:
|
| 673 |
+
training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict()
|
| 674 |
+
|
| 675 |
+
self._checkpointer.save_checkpoint(
|
| 676 |
+
model_state=self.model.state_dict(),
|
| 677 |
+
epoch=epoch,
|
| 678 |
+
training_states=training_states,
|
| 679 |
+
is_best_so_far=self._metric_tracker.is_best_so_far(),
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
# Restore the original values for parameters so that training will not be affected.
|
| 683 |
+
if self._moving_average is not None:
|
| 684 |
+
self._moving_average.restore()
|
| 685 |
+
|
| 686 |
+
def _restore_checkpoint(self) -> int:
|
| 687 |
+
"""
|
| 688 |
+
Restores the model and training state from the last saved checkpoint.
|
| 689 |
+
This includes an epoch count and optimizer state, which is serialized separately
|
| 690 |
+
from model parameters. This function should only be used to continue training -
|
| 691 |
+
if you wish to load a model for inference/load parts of a model into a new
|
| 692 |
+
computation graph, you should use the native Pytorch functions:
|
| 693 |
+
`` model.load_state_dict(torch.load("/path/to/model/weights.th"))``
|
| 694 |
+
|
| 695 |
+
If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights,
|
| 696 |
+
this function will do nothing and return 0.
|
| 697 |
+
|
| 698 |
+
Returns
|
| 699 |
+
-------
|
| 700 |
+
epoch: int
|
| 701 |
+
The epoch at which to resume training, which should be one after the epoch
|
| 702 |
+
in the saved training state.
|
| 703 |
+
"""
|
| 704 |
+
model_state, training_state = self._checkpointer.restore_checkpoint()
|
| 705 |
+
|
| 706 |
+
if not training_state:
|
| 707 |
+
# No checkpoint to restore, start at 0
|
| 708 |
+
return 0
|
| 709 |
+
|
| 710 |
+
self.model.load_state_dict(model_state)
|
| 711 |
+
self.optimizer.load_state_dict(training_state["optimizer"])
|
| 712 |
+
if self._learning_rate_scheduler is not None \
|
| 713 |
+
and "learning_rate_scheduler" in training_state:
|
| 714 |
+
self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"])
|
| 715 |
+
if self._momentum_scheduler is not None and "momentum_scheduler" in training_state:
|
| 716 |
+
self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"])
|
| 717 |
+
training_util.move_optimizer_to_cuda(self.optimizer)
|
| 718 |
+
|
| 719 |
+
# Currently the ``training_state`` contains a serialized ``MetricTracker``.
|
| 720 |
+
if "metric_tracker" in training_state:
|
| 721 |
+
self._metric_tracker.load_state_dict(training_state["metric_tracker"])
|
| 722 |
+
# It used to be the case that we tracked ``val_metric_per_epoch``.
|
| 723 |
+
elif "val_metric_per_epoch" in training_state:
|
| 724 |
+
self._metric_tracker.clear()
|
| 725 |
+
self._metric_tracker.add_metrics(training_state["val_metric_per_epoch"])
|
| 726 |
+
# And before that we didn't track anything.
|
| 727 |
+
else:
|
| 728 |
+
self._metric_tracker.clear()
|
| 729 |
+
|
| 730 |
+
if isinstance(training_state["epoch"], int):
|
| 731 |
+
epoch_to_return = training_state["epoch"] + 1
|
| 732 |
+
else:
|
| 733 |
+
epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1
|
| 734 |
+
|
| 735 |
+
# For older checkpoints with batch_num_total missing, default to old behavior where
|
| 736 |
+
# it is unchanged.
|
| 737 |
+
batch_num_total = training_state.get("batch_num_total")
|
| 738 |
+
if batch_num_total is not None:
|
| 739 |
+
self._batch_num_total = batch_num_total
|
| 740 |
+
|
| 741 |
+
return epoch_to_return
|
| 742 |
+
|
| 743 |
+
# Requires custom from_params.
|
| 744 |
+
@classmethod
|
| 745 |
+
def from_params( # type: ignore
|
| 746 |
+
cls,
|
| 747 |
+
model: Model,
|
| 748 |
+
serialization_dir: str,
|
| 749 |
+
iterator: DataIterator,
|
| 750 |
+
train_data: Iterable[Instance],
|
| 751 |
+
validation_data: Optional[Iterable[Instance]],
|
| 752 |
+
params: Params,
|
| 753 |
+
validation_iterator: DataIterator = None,
|
| 754 |
+
) -> "Trainer":
|
| 755 |
+
|
| 756 |
+
patience = params.pop_int("patience", None)
|
| 757 |
+
validation_metric = params.pop("validation_metric", "-loss")
|
| 758 |
+
shuffle = params.pop_bool("shuffle", True)
|
| 759 |
+
num_epochs = params.pop_int("num_epochs", 20)
|
| 760 |
+
cuda_device = parse_cuda_device(params.pop("cuda_device", -1))
|
| 761 |
+
grad_norm = params.pop_float("grad_norm", None)
|
| 762 |
+
grad_clipping = params.pop_float("grad_clipping", None)
|
| 763 |
+
lr_scheduler_params = params.pop("learning_rate_scheduler", None)
|
| 764 |
+
momentum_scheduler_params = params.pop("momentum_scheduler", None)
|
| 765 |
+
|
| 766 |
+
if isinstance(cuda_device, list):
|
| 767 |
+
model_device = cuda_device[0]
|
| 768 |
+
else:
|
| 769 |
+
model_device = cuda_device
|
| 770 |
+
if model_device >= 0:
|
| 771 |
+
# Moving model to GPU here so that the optimizer state gets constructed on
|
| 772 |
+
# the right device.
|
| 773 |
+
model = model.cuda(model_device)
|
| 774 |
+
|
| 775 |
+
parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
|
| 776 |
+
optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))
|
| 777 |
+
if "moving_average" in params:
|
| 778 |
+
moving_average = MovingAverage.from_params(
|
| 779 |
+
params.pop("moving_average"), parameters=parameters
|
| 780 |
+
)
|
| 781 |
+
else:
|
| 782 |
+
moving_average = None
|
| 783 |
+
|
| 784 |
+
if lr_scheduler_params:
|
| 785 |
+
lr_scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params)
|
| 786 |
+
else:
|
| 787 |
+
lr_scheduler = None
|
| 788 |
+
if momentum_scheduler_params:
|
| 789 |
+
momentum_scheduler = MomentumScheduler.from_params(optimizer, momentum_scheduler_params)
|
| 790 |
+
else:
|
| 791 |
+
momentum_scheduler = None
|
| 792 |
+
|
| 793 |
+
if "checkpointer" in params:
|
| 794 |
+
if "keep_serialized_model_every_num_seconds" in params \
|
| 795 |
+
or "num_serialized_models_to_keep" in params:
|
| 796 |
+
raise ConfigurationError(
|
| 797 |
+
"Checkpointer may be initialized either from the 'checkpointer' key or from the "
|
| 798 |
+
"keys 'num_serialized_models_to_keep' and 'keep_serialized_model_every_num_seconds'"
|
| 799 |
+
" but the passed config uses both methods."
|
| 800 |
+
)
|
| 801 |
+
checkpointer = Checkpointer.from_params(params.pop("checkpointer"))
|
| 802 |
+
else:
|
| 803 |
+
num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20)
|
| 804 |
+
keep_serialized_model_every_num_seconds = params.pop_int(
|
| 805 |
+
"keep_serialized_model_every_num_seconds", None
|
| 806 |
+
)
|
| 807 |
+
checkpointer = Checkpointer(
|
| 808 |
+
serialization_dir=serialization_dir,
|
| 809 |
+
num_serialized_models_to_keep=num_serialized_models_to_keep,
|
| 810 |
+
keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds,
|
| 811 |
+
)
|
| 812 |
+
model_save_interval = params.pop_float("model_save_interval", None)
|
| 813 |
+
summary_interval = params.pop_int("summary_interval", 100)
|
| 814 |
+
histogram_interval = params.pop_int("histogram_interval", None)
|
| 815 |
+
should_log_parameter_statistics = params.pop_bool("should_log_parameter_statistics", True)
|
| 816 |
+
should_log_learning_rate = params.pop_bool("should_log_learning_rate", False)
|
| 817 |
+
log_batch_size_period = params.pop_int("log_batch_size_period", None)
|
| 818 |
+
|
| 819 |
+
params.assert_empty(cls.__name__)
|
| 820 |
+
return cls(
|
| 821 |
+
model,
|
| 822 |
+
optimizer,
|
| 823 |
+
iterator,
|
| 824 |
+
train_data,
|
| 825 |
+
validation_data,
|
| 826 |
+
patience=patience,
|
| 827 |
+
validation_metric=validation_metric,
|
| 828 |
+
validation_iterator=validation_iterator,
|
| 829 |
+
shuffle=shuffle,
|
| 830 |
+
num_epochs=num_epochs,
|
| 831 |
+
serialization_dir=serialization_dir,
|
| 832 |
+
cuda_device=cuda_device,
|
| 833 |
+
grad_norm=grad_norm,
|
| 834 |
+
grad_clipping=grad_clipping,
|
| 835 |
+
learning_rate_scheduler=lr_scheduler,
|
| 836 |
+
momentum_scheduler=momentum_scheduler,
|
| 837 |
+
checkpointer=checkpointer,
|
| 838 |
+
model_save_interval=model_save_interval,
|
| 839 |
+
summary_interval=summary_interval,
|
| 840 |
+
histogram_interval=histogram_interval,
|
| 841 |
+
should_log_parameter_statistics=should_log_parameter_statistics,
|
| 842 |
+
should_log_learning_rate=should_log_learning_rate,
|
| 843 |
+
log_batch_size_period=log_batch_size_period,
|
| 844 |
+
moving_average=moving_average,
|
| 845 |
+
)
|