event-argument-api / Nested /trainers /BertNestedTrainer.py
aaljabari's picture
Upload 38 files
4014eb8 verified
Raw
History Blame Contribute Delete
8.42 kB
import os
import logging
import torch
import numpy as np
from Nested.trainers import BaseTrainer
from Nested.utils.metrics import compute_nested_metrics
# Module-level logger, named after this module so records can be filtered by origin.
logger = logging.getLogger(__name__)
class BertNestedTrainer(BaseTrainer):
    """Trainer for a tagger that emits one classification output per tag type.

    The model is expected to return a 4-D ``logits`` tensor; the third axis
    selects the tag type and the last axis holds the class scores (this is how
    ``train``/``eval`` index it). One loss is computed per tag type and the
    losses are back-propagated together.

    Attributes such as ``model``, ``optimizer``, ``scheduler``, ``loss``,
    ``clip``, ``patience``, ``max_epochs``, ``log_interval``, the dataloaders,
    ``summary_writer`` and ``output_path`` are read from ``self`` and are
    presumably set up by ``BaseTrainer`` from ``kwargs`` (not visible here).
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``BaseTrainer``."""
        super().__init__(**kwargs)

    def _nested_losses(self, logits, gold_tags, num_labels):
        """Compute one loss per tag type.

        :param logits: torch.Tensor indexed as ``logits[:, :, i, 0:l]``
                       (B x T x NUM_TAG_TYPES x C per the training loop)
        :param gold_tags: torch.Tensor indexed as ``gold_tags[:, i, :]``
                          (presumably B x NUM_TAG_TYPES x T - TODO confirm)
        :param num_labels: List[int] - number of classes of each tag type
        :return: List - one ``self.loss(...)`` result per tag type
        """
        losses = []
        for type_index, num_classes in enumerate(num_labels):
            # Tag types may have fewer classes than the last logits axis;
            # only the first ``num_classes`` scores belong to this tag type.
            type_logits = logits[:, :, type_index, 0:num_classes]
            type_gold = torch.reshape(gold_tags[:, type_index, :], (-1,)).long()
            losses.append(
                self.loss(type_logits.view(-1, type_logits.shape[-1]), type_gold)
            )
        return losses

    def train(self):
        """Run the training loop with validation-based early stopping.

        For each epoch: train on ``self.train_dataloader``, evaluate on
        ``self.val_dataloader``, and, whenever the validation loss improves,
        evaluate on ``self.test_dataloader``, write the test predictions to
        ``predictions.txt`` under ``self.output_path`` and call ``self.save()``.
        Losses and metrics are written to ``self.summary_writer`` at the end
        of every epoch that does not trigger early termination.
        """
        best_val_loss = np.inf
        num_train_batch = len(self.train_dataloader)
        num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            batches = self.tag(self.train_dataloader, is_train=True)
            for batch_index, (_, gold_tags, _, _, logits) in enumerate(batches, 1):
                self.current_timestep += 1

                # Compute losses for each output
                # logits = B x T x L x C
                losses = self._nested_losses(logits, gold_tags, num_labels)
                torch.autograd.backward(losses)

                # Avoid exploding gradient by doing gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                batch_loss = sum(loss.item() for loss in losses)
                train_loss += batch_loss

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss
                    )

            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            _, segments, _, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_nested_metrics(
                segments, self.val_dataloader.dataset.transform.vocab.tags[1:]
            )

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                # Improvement: reset the early-stopping budget
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                _, segments, _, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_nested_metrics(
                    segments, self.test_dataloader.dataset.transform.vocab.tags[1:]
                )

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # No improvements, terminating early.
            # NOTE: the break happens before the summary writes below, so the
            # terminating epoch is not written to the summary writer.
            if patience == 0:
                logger.info("Early termination triggered")
                break

            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

    def tag(self, dataloader, is_train=True):
        """
        Given a dataloader containing segments, predict the tags
        :param dataloader: torch.utils.data.DataLoader
        :param is_train: boolean - True for training model, False for evaluation
        :return: Iterator
                    subwords - torch.Tensor - BERT subword ID
                    gold_tags - torch.Tensor - ground truth tags IDs, indexed by
                                callers as gold_tags[:, i, :]
                                (presumably B x NUM_LABELS x T - TODO confirm)
                    tokens - List[Nested.data.dataset.Token] - list of tokens
                    valid_len - valid length of each sequence
                    logits (B x T x NUM_LABELS x C) - logits for each token,
                                each tag type and each class
        """
        for subwords, gold_tags, tokens, mask, valid_len in dataloader:
            # Set the mode on every batch: the generator may be resumed after
            # other code has switched the model mode.
            self.model.train(is_train)

            if torch.cuda.is_available():
                subwords = subwords.cuda()
                gold_tags = gold_tags.cuda()

            if is_train:
                # Gradients are cleared here; the caller runs backward/step
                self.optimizer.zero_grad()
                logits = self.model(subwords)
            else:
                with torch.no_grad():
                    logits = self.model(subwords)

            yield subwords, gold_tags, tokens, valid_len, logits

    def eval(self, dataloader):
        """Evaluate the model on ``dataloader``.

        :param dataloader: torch.utils.data.DataLoader
        :return: tuple of
                    preds - List[torch.Tensor] - predicted tag IDs, one entry per sequence
                    segments - tokens with predictions attached (see ``to_segments``)
                    valid_lens - list - valid length of each sequence
                    loss - sum of the per-tag-type losses, averaged over the
                           number of batches in ``dataloader``
        """
        preds, segments, valid_lens = [], [], []
        num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(dataloader, is_train=False):
            losses = self._nested_losses(logits, gold_tags, num_labels)
            loss += sum(losses)
            # ``+=`` with a tensor extends the list with one entry per sequence
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        loss /= len(dataloader)

        # Update segments, attach predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss

    def infer(self, dataloader):
        """Predict tags for ``dataloader`` without computing any loss.

        :param dataloader: torch.utils.data.DataLoader
        :return: segments with predictions attached (see ``to_segments``)
        """
        preds, segments, valid_lens = [], [], []

        for _, _, tokens, valid_len, logits in self.tag(dataloader, is_train=False):
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
        return segments

    def to_segments(self, segments, preds, valid_lens, vocab):
        """Attach the predicted tags to the tokens of each segment.

        Each kept token gets a ``pred_tag`` attribute: a list with one
        ``{"tag": <label>}`` dict per tag vocabulary in ``vocab.tags[1:]``.

        :param segments: list of token lists, one per sequence
        :param preds: list of predicted tag-ID tensors, one per sequence
        :param valid_lens: list of valid lengths, one per sequence
        :param vocab: vocabulary with ``tokens`` and ``tags``; falls back to
                      ``self.vocab`` when None
        :return: List[List[token]] - per segment, the tokens whose text is not "UNK"
        """
        if vocab is None:
            vocab = self.vocab

        tokens_stoi = vocab.tokens.get_stoi()
        unk_id = tokens_stoi["UNK"]
        # Build the ID -> label tables once instead of once per token
        tag_itos = [tag_vocab.get_itos() for tag_vocab in vocab.tags[1:]]

        tagged_segments = []
        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            tagged_segment = []

            # Skip the token at 0th index [CLS] and the token at nth index [SEP],
            # and pair the remaining tokens with their corresponding predictions
            for token, token_pred in zip(segment[1:valid_len - 1], pred[1:valid_len - 1]):
                # Ignore the sub-tokens/subwords, which are identified with text being UNK
                if tokens_stoi[token.text] == unk_id:
                    continue

                # Attach the predicted tags to the token
                token.pred_tag = [
                    {"tag": itos[tag_id]}
                    for tag_id, itos in zip(token_pred.int().tolist(), tag_itos)
                ]

                # Only the tagged tokens are kept; raw model predictions are dropped
                tagged_segment.append(token)

            tagged_segments.append(tagged_segment)

        return tagged_segments