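"""Trainer for the tokenizer model.

Builds the Tokenizer network, runs training updates, produces predictions,
and handles saving and loading of model checkpoints.
"""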
import sys
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from stanza.models.common import utils
from stanza.models.common.trainer import Trainer as BaseTrainer
from stanza.models.tokenization.utils import create_dictionary
from .model import Tokenizer
from .vocab import Vocab

logger = logging.getLogger('stanza')

class Trainer(BaseTrainer):
    def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, foundation_cache=None):
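        """Build the trainer from a saved model file if given, otherwise from scratch using args, vocab, lexicon, and dictionary."""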
        # TODO: make a test of the training w/ and w/o charlm
        if model_file is not None:
            # load everything from file
            self.load(model_file, args, foundation_cache)
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab
            self.lexicon = list(lexicon) if lexicon is not None else None
            self.dictionary = dictionary
            self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
        self.model = self.model.to(device)
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
        self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay'])
        self.feat_funcs = self.args.get('feat_funcs', None)
        self.lang = self.args['lang'] # language determines how token normalization is done

    def update(self, inputs):
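        """Run a single training step on one batch and return the loss value."""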
        self.model.train()
        units, labels, features, text = inputs
        lengths = [len(x) for x in text]
        device = next(self.model.parameters()).device
        units = units.to(device)
        labels = labels.to(device)
        features = features.to(device)
        pred = self.model(units, features, lengths, text)
        self.optimizer.zero_grad()
        classes = pred.size(2)
        loss = self.criterion(pred.view(-1, classes), labels.view(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()
        return loss.item()

    def predict(self, inputs):
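        """Run the model in eval mode on one batch and return the raw predictions as a numpy array."""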
        self.model.eval()
        units, _, features, text = inputs
        lengths = [len(x) for x in text]
        device = next(self.model.parameters()).device
        units = units.to(device)
        features = features.to(device)
        pred = self.model(units, features, lengths, text)
        return pred.data.cpu().numpy()

    def save(self, filename, skip_modules=True):
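        """Save the model, vocab, lexicon, and config to filename, optionally skipping modules such as the pretrained charlm."""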
        model_state = None
        if self.model is not None:
            model_state = self.model.state_dict()
            # skip saving modules like the pretrained charlm
            if skip_modules:
                skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
                for k in skipped:
                    del model_state[k]
        params = {
            'model': model_state,
            'vocab': self.vocab.state_dict(),
            # save and load lexicon as list instead of set so
            # we can use weights_only=True
            'lexicon': list(self.lexicon) if self.lexicon is not None else None,
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except BaseException:
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename, args, foundation_cache):
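        """Load model, vocab, lexicon, and config from filename, rebuilding the dictionary from the lexicon if one is present."""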
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if args is not None and args.get('charlm_forward_file', None) is not None:
            if checkpoint['config'].get('charlm_forward_file') is None:
                # if the saved model didn't use a charlm, we skip the charlm here
                # otherwise the loaded model weights won't fit in the newly created model
                self.args['charlm_forward_file'] = None
            else:
                self.args['charlm_forward_file'] = args['charlm_forward_file']
        if self.args.get('use_mwt', None) is None:
            # Default to True as many currently saved models
            # were built with mwt layers
            self.args['use_mwt'] = True
        self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
        self.model.load_state_dict(checkpoint['model'], strict=False)
        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
        self.lexicon = checkpoint['lexicon']
        if self.lexicon is not None:
            self.lexicon = set(self.lexicon)
            self.dictionary = create_dictionary(self.lexicon)
        else:
            self.dictionary = None