aaljabari committed on
Commit
4462f0f
·
verified ·
1 Parent(s): 9762f2c

Create BertNestedTrainer.py

Browse files
Files changed (1) hide show
  1. Nested/trainers/BertNestedTrainer.py +203 -0
Nested/trainers/BertNestedTrainer.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import torch
4
+ import numpy as np
5
+ from Nested.trainers import BaseTrainer
6
+ from Nested.utils.metrics import compute_nested_metrics
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class BertNestedTrainer(BaseTrainer):
    """
    Trainer for a model that predicts several tag types per token (nested tagging).

    The model output is indexed throughout this class as ``B x T x L x C``
    (batch, token, tag type, class). One loss is computed per tag type and the
    per-type losses are back-propagated together.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _compute_losses(self, logits, gold_tags, num_labels):
        """
        Compute one loss per tag type.
        :param logits: torch.Tensor - indexed as B x T x L x C
        :param gold_tags: torch.Tensor - indexed as ``gold_tags[:, i, :]`` for the
                          i-th tag type (presumably B x L x T - confirm against the dataset)
        :param num_labels: List[int] - number of classes of each tag type
        :return: List - one ``self.loss`` result per tag type
        """
        losses = []

        for type_index, num_classes in enumerate(num_labels):
            # Tag types may have fewer classes than the last logits dimension,
            # only the first num_classes entries are used for this tag type
            type_logits = logits[:, :, type_index, 0:num_classes]
            type_golds = torch.reshape(gold_tags[:, type_index, :], (-1,)).long()
            losses.append(self.loss(type_logits.reshape(-1, type_logits.shape[-1]), type_golds))

        return losses

    def train(self):
        """
        Run the training loop for at most ``self.max_epochs`` epochs.

        After every epoch the validation set is evaluated. Whenever the validation
        loss improves, the test set is evaluated, its predictions are written to
        ``predictions.txt`` under ``self.output_path`` and ``self.save()`` is called.
        Training stops early once ``self.patience`` epochs in a row did not improve
        the validation loss.
        """
        best_val_loss = np.inf
        num_train_batch = len(self.train_dataloader)
        num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            for batch_index, (_, gold_tags, _, _, logits) in enumerate(
                self.tag(self.train_dataloader, is_train=True), 1
            ):
                self.current_timestep += 1

                # Compute losses for each output
                # logits = B x T x L x C
                losses = self._compute_losses(logits, gold_tags, num_labels)

                # Back-propagate all per-type losses in a single call
                torch.autograd.backward(losses)

                # Avoid exploding gradient by doing gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                batch_loss = sum(loss.item() for loss in losses)
                train_loss += batch_loss

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss
                    )

            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            _, segments, _, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                # Improvement resets the early-stopping counter
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                _, segments, _, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # No improvements, terminating early
            if patience == 0:
                logger.info("Early termination triggered")
                break

            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

    def tag(self, dataloader, is_train=True):
        """
        Given a dataloader containing segments, predict the tags
        :param dataloader: torch.utils.data.DataLoader
        :param is_train: boolean - True for training model, False for evaluation
        :return: Iterator
                    subwords - torch.Tensor - BERT subword IDs, as given by the dataloader
                    gold_tags - torch.Tensor - ground truth tag IDs, callers index it
                                as ``gold_tags[:, i, :]`` for the i-th tag type
                    tokens - List[Nested.data.dataset.Token] - list of tokens
                    valid_len - valid length of each sequence, as given by the dataloader
                    logits (B x T x L x C) - logits for each token, tag type and class
        """
        for subwords, gold_tags, tokens, _mask, valid_len in dataloader:
            # Set the mode on every batch, the consumer may have run the model
            # in the other mode between two iterations of this generator
            self.model.train(is_train)

            if torch.cuda.is_available():
                subwords = subwords.cuda()
                gold_tags = gold_tags.cuda()

            if is_train:
                self.optimizer.zero_grad()
                logits = self.model(subwords)
            else:
                with torch.no_grad():
                    logits = self.model(subwords)

            yield subwords, gold_tags, tokens, valid_len, logits

    def eval(self, dataloader):
        """
        Run the model over a dataloader without gradient tracking and compute the loss
        :param dataloader: torch.utils.data.DataLoader
        :return: tuple
                    preds - List[torch.Tensor] - predicted class IDs, one entry per segment
                    segments - List[List[Token]] - tokens with ``pred_tag`` attached
                    valid_lens - List - valid length of each segment
                    loss - sum of the per-type losses, averaged over the number of batches
        """
        preds, segments, valid_lens = [], [], []
        num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            loss += sum(self._compute_losses(logits, gold_tags, num_labels))
            # Iterating the tensor splits it along the batch dimension
            preds.extend(torch.argmax(logits, dim=3))
            segments += tokens
            valid_lens += list(valid_len)

        loss /= len(dataloader)

        # Update segments, attach predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss

    def infer(self, dataloader):
        """
        Predict the tags of every segment in a dataloader, no loss is computed
        :param dataloader: torch.utils.data.DataLoader
        :return: List[List[Token]] - tokens with ``pred_tag`` attached
        """
        preds, segments, valid_lens = [], [], []

        for _, _, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            preds.extend(torch.argmax(logits, dim=3))
            segments += tokens
            valid_lens += list(valid_len)

        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
        return segments

    def to_segments(self, segments, preds, valid_lens, vocab):
        """
        Attach the predicted tags to the tokens of every segment
        :param segments: List[List[Token]] - tokens of each segment
        :param preds: List[torch.Tensor] - predicted class IDs of each segment
        :param valid_lens: List - valid length of each segment
        :param vocab: vocabulary with ``tokens`` and ``tags``, ``self.vocab`` is used when None
        :return: List[List[Token]] - per segment, the tokens whose text is not UNK,
                 each with ``pred_tag`` set to a list of ``{"tag": ...}``, one per tag type
        """
        if vocab is None:
            vocab = self.vocab

        tagged_segments = []
        tokens_stoi = vocab.tokens.get_stoi()
        unk_id = tokens_stoi["UNK"]

        # Build the ID-to-tag tables once instead of once per token
        # (assumes get_itos() returns the same table on every call)
        tags_itos = [tag_vocab.get_itos() for tag_vocab in vocab.tags[1:]]

        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            # Skip the first and the last valid positions ([CLS] and [SEP]) and
            # combine the remaining tokens with their corresponding predictions
            segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])

            tagged_segment = []
            for token, token_pred in segment_pred:
                # Ignore the sub-tokens/subwords, which are identified with text being UNK
                if tokens_stoi[token.text] == unk_id:
                    continue

                # Attach the predicted tags to the token, one entry per tag type
                token.pred_tag = [{"tag": itos[tag_id]}
                                  for tag_id, itos in zip(token_pred.int().tolist(), tags_itos)]

                # We are only interested in the tagged tokens, we do no longer need raw model predictions
                tagged_segment.append(token)

            tagged_segments.append(tagged_segment)

        return tagged_segments