Albin Thörn Cleland
Clean initial commit with LFS
19b8775
"""
Organizes the model itself and its optimizer in one place
Saving the optimizer allows for easy restarting of training
"""
import logging
import os
import warnings
from pickle import UnpicklingError
from types import SimpleNamespace

import torch
import torch.optim as optim

import stanza.models.classifiers.data as data
import stanza.models.classifiers.cnn_classifier as cnn_classifier
import stanza.models.classifiers.constituency_classifier as constituency_classifier
from stanza.models.classifiers.config import CNNConfig, ConstituencyConfig
from stanza.models.classifiers.utils import ModelType, WVType, ExtraVectors
from stanza.models.common import utils
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
from stanza.models.common.pretrain import Pretrain
from stanza.models.common.utils import get_split_optimizer
from stanza.models.constituency.tree_embedding import TreeEmbedding
logger = logging.getLogger('stanza')
class Trainer:
    """
    Stores a classifier model and its optimizer

    The model is either a CNNClassifier or a ConstituencyClassifier.
    The optimizer, when present, is treated as a dict of
    optimizer name -> optimizer, which is how it is saved and restored.
    """

    def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0, best_score=None):
        self.model = model
        self.optimizer = optimizer
        # we keep track of position in the learning so that we can
        # checkpoint & restart if needed without restarting the epoch count
        self.epochs_trained = epochs_trained
        self.global_step = global_step
        # save the best dev score so that when reloading a checkpoint
        # of a model, we know how far we got
        self.best_score = best_score

    def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):
        """
        save the current model, optimizer, and other state to filename

        epochs_trained can be passed as a parameter to handle saving at the end of an epoch
        skip_modules is passed through to the model's get_params
        save_optimizer=False leaves the optimizer state out of the checkpoint
        """
        if epochs_trained is None:
            epochs_trained = self.epochs_trained
        save_dir = os.path.split(filename)[0]
        # a bare filename has no directory component,
        # and os.makedirs('') would raise FileNotFoundError
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        model_params = self.model.get_params(skip_modules)
        params = {
            'params': model_params,
            'epochs_trained': epochs_trained,
            'global_step': self.global_step,
            'best_score': self.best_score,
        }
        if save_optimizer and self.optimizer is not None:
            params['optimizer_state_dict'] = {opt_name: opt.state_dict() for opt_name, opt in self.optimizer.items()}
        torch.save(params, filename, _use_new_zipfile_serialization=False)
        logger.info("Model saved to {}".format(filename))

    @staticmethod
    def load(filename, args, foundation_cache=None, load_optimizer=False):
        """
        Load a saved classifier and wrap it in a Trainer

        filename is tried as given, then relative to args.save_dir.
        args supplies the pieces which are not stored in the checkpoint,
        such as the pretrain / charlm file locations and the device.
        If load_optimizer is set, an optimizer is built from args and its
        state is restored from the checkpoint when the checkpoint has one.

        Raises FileNotFoundError if the model cannot be found in either
        location and ValueError if the saved model type is not handled.
        """
        filename = Trainer._find_model_file(filename, args)
        checkpoint = Trainer._read_checkpoint(filename)
        logger.debug("Loaded model {}".format(filename))

        epochs_trained = checkpoint.get('epochs_trained', 0)
        global_step = checkpoint.get('global_step', 0)
        best_score = checkpoint.get('best_score', None)

        model_params = Trainer._extract_model_params(checkpoint)
        model_type = model_params['config']['model_type']

        if model_type == ModelType.CNN:
            model = Trainer._load_cnn_model(model_params, args, foundation_cache)
        elif model_type == ModelType.CONSTITUENCY:
            model = Trainer._load_constituency_model(model_params, args, foundation_cache)
        else:
            raise ValueError("Unknown model type {}".format(model_type))

        # strict=False: the checkpoint is allowed to differ from the
        # module's own set of keys (see skip_modules in save)
        model.load_state_dict(model_params['model'], strict=False)
        model = model.to(args.device)

        logger.debug("-- MODEL CONFIG --")
        for k, v in vars(model.config).items():
            logger.debug(" --{}: {}".format(k, v))

        logger.debug("-- MODEL LABELS --")
        logger.debug(" {}".format(" ".join(model.labels)))

        optimizer = None
        if load_optimizer:
            optimizer = Trainer.build_optimizer(model, args)
            if checkpoint.get('optimizer_state_dict', None) is not None:
                for opt_name, opt_state_dict in checkpoint['optimizer_state_dict'].items():
                    optimizer[opt_name].load_state_dict(opt_state_dict)
            else:
                logger.info("Attempted to load optimizer to resume training, but optimizer not saved.  Creating new optimizer")

        return Trainer(model, optimizer, epochs_trained, global_step, best_score)

    @staticmethod
    def _find_model_file(filename, args):
        """
        Return filename if it exists, otherwise the same name under args.save_dir
        """
        if os.path.exists(filename):
            return filename
        if args.save_dir is None:
            raise FileNotFoundError("Cannot find model in {} and args.save_dir is None".format(filename))
        in_save_dir = os.path.join(args.save_dir, filename)
        if os.path.exists(in_save_dir):
            return in_save_dir
        raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, in_save_dir))

    @staticmethod
    def _read_checkpoint(filename):
        """
        Read the raw checkpoint dict from disk, mapping all tensors to the CPU
        """
        try:
            # TODO: can remove the try/except once the new version is out
            try:
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
            except UnpicklingError:
                # NOTE(review): weights_only=False unpickles arbitrary
                # objects, so this fallback should only see trusted files
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
                warnings.warn("The saved classifier has an old format using SimpleNamespace and/or Enum instead of a dict to store config.  This version of Stanza can support reading both the new and the old formats.  Future versions will only allow loading with weights_only=True.  Please resave the pretrained classifier using this version ASAP.")
        except BaseException:
            # log where the failure happened, then let the caller see it
            logger.exception("Cannot load model from {}".format(filename))
            raise
        return checkpoint

    @staticmethod
    def _extract_model_params(checkpoint):
        """
        Pull the model params out of a checkpoint, converting older formats

        Afterwards, params['config'] is a dict and its model_type is a ModelType
        """
        # TODO: can remove this block once all models are retrained
        if 'params' not in checkpoint:
            model_params = {
                'model': checkpoint['model'],
                'config': checkpoint['config'],
                'labels': checkpoint['labels'],
                'extra_vocab': checkpoint['extra_vocab'],
            }
        else:
            model_params = checkpoint['params']

        # TODO: this can be removed once v1.10.0 is out
        if isinstance(model_params['config'], SimpleNamespace):
            model_params['config'] = vars(model_params['config'])

        # TODO: these isinstance can go away after 1.10.0
        model_type = model_params['config']['model_type']
        if isinstance(model_type, str):
            model_type = ModelType[model_type]
            model_params['config']['model_type'] = model_type

        return model_params

    @staticmethod
    def _load_cnn_model(model_params, args, foundation_cache):
        """
        Build a CNNClassifier (weights not yet loaded) from saved params

        Replaces model_params['config'] with a CNNConfig as a side effect
        """
        config = model_params['config']
        # TODO: these updates are only necessary during the
        # transition to the @dataclass version of the config
        # Once those are all saved, it is no longer necessary
        # to patch existing models (since they will all be patched)
        if 'has_charlm_forward' not in config:
            config['has_charlm_forward'] = args.charlm_forward_file is not None
        if 'has_charlm_backward' not in config:
            config['has_charlm_backward'] = args.charlm_backward_file is not None
        for argname in ['bert_hidden_layers', 'bert_finetune', 'force_bert_saved', 'use_peft',
                        'lora_rank', 'lora_alpha', 'lora_dropout', 'lora_modules_to_save', 'lora_target_modules']:
            config[argname] = config.get(argname, None)
        # TODO: these isinstance can go away after 1.10.0
        if isinstance(config['wordvec_type'], str):
            config['wordvec_type'] = WVType[config['wordvec_type']]
        if isinstance(config['extra_wordvec_method'], str):
            config['extra_wordvec_method'] = ExtraVectors[config['extra_wordvec_method']]
        config = CNNConfig(**config)
        model_params['config'] = config

        pretrain = Trainer.load_pretrain(args, foundation_cache)
        elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None

        if config.has_charlm_forward:
            charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)
        else:
            charmodel_forward = None
        if config.has_charlm_backward:
            charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)
        else:
            charmodel_backward = None

        bert_model = config.bert_model
        # TODO: can get rid of the getattr after rebuilding all models
        use_peft = getattr(config, 'use_peft', False)
        force_bert_saved = getattr(config, 'force_bert_saved', False)
        peft_name = None
        if use_peft:
            # if loading a peft model, we first load the base transformer
            # the CNNClassifier code wraps the transformer in peft
            # after creating the CNNClassifier with the peft wrapper,
            # we *then* load the weights
            bert_model, bert_tokenizer, peft_name = load_bert_with_peft(bert_model, "classifier", foundation_cache)
            bert_model = load_peft_wrapper(bert_model, model_params['bert_lora'], vars(config), logger, peft_name)
        elif force_bert_saved:
            # deliberately not using the foundation_cache here
            bert_model, bert_tokenizer = load_bert(bert_model)
        else:
            bert_model, bert_tokenizer = load_bert(bert_model, foundation_cache)

        return cnn_classifier.CNNClassifier(pretrain=pretrain,
                                            extra_vocab=model_params['extra_vocab'],
                                            labels=model_params['labels'],
                                            charmodel_forward=charmodel_forward,
                                            charmodel_backward=charmodel_backward,
                                            elmo_model=elmo_model,
                                            bert_model=bert_model,
                                            bert_tokenizer=bert_tokenizer,
                                            force_bert_saved=force_bert_saved,
                                            peft_name=peft_name,
                                            args=config)

    @staticmethod
    def _load_constituency_model(model_params, args, foundation_cache):
        """
        Build a ConstituencyClassifier (weights not yet loaded) from saved params

        Replaces model_params['config'] with a ConstituencyConfig as a side effect
        """
        # the constituency version doesn't have a peft feature yet
        # TODO: integrate with peft for the constituency version
        pretrain_args = {
            'wordvec_pretrain_file': args.wordvec_pretrain_file,
            'charlm_forward_file': args.charlm_forward_file,
            'charlm_backward_file': args.charlm_backward_file,
        }
        tree_embedding = TreeEmbedding.model_from_params(model_params['tree_embedding'], pretrain_args, foundation_cache)
        model_params['config'] = ConstituencyConfig(**model_params['config'])
        return constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
                                                              labels=model_params['labels'],
                                                              args=model_params['config'])

    @staticmethod
    def load_pretrain(args, foundation_cache):
        """
        Load the word vectors described by args

        Uses args.wordvec_pretrain_file if set, otherwise a filename built
        from args.save_dir, args.shorthand, and args.wordvec_type.
        If that file does not exist, a Pretrain is created from the raw
        vectors in args.wordvec_raw_file or the file found under args.wordvec_dir.

        Raises RuntimeError if neither a pretrain file nor a wordvec type is given.
        """
        if args.wordvec_pretrain_file:
            pretrain_file = args.wordvec_pretrain_file
        elif args.wordvec_type:
            pretrain_file = '{}/{}.{}.pretrain.pt'.format(args.save_dir, args.shorthand, args.wordvec_type.name.lower())
        else:
            raise RuntimeError("TODO: need to get the wv type back from get_wordvec_file")

        logger.debug("Looking for pretrained vectors in {}".format(pretrain_file))
        if os.path.exists(pretrain_file):
            # this resolves to the module level load_pretrain import,
            # not this method: class attributes are not in scope here
            return load_pretrain(pretrain_file, foundation_cache)

        if args.wordvec_raw_file:
            vec_file = args.wordvec_raw_file
        else:
            vec_file = utils.get_wordvec_file(args.wordvec_dir, args.shorthand, args.wordvec_type.name.lower())
        logger.debug("Pretrain not found.  Looking in {}".format(vec_file))

        pretrain = Pretrain(pretrain_file, vec_file, args.pretrain_max_vocab)
        logger.debug("Embedding shape: %s" % str(pretrain.emb.shape))
        return pretrain

    @staticmethod
    def build_new_model(args, train_set):
        """
        Load pretrained pieces and then build a new model

        Returns a Trainer holding the new model and a freshly built optimizer.
        Raises ValueError if there is no train set or args.model_type is not handled.
        """
        if train_set is None:
            raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")

        labels = data.dataset_labels(train_set)

        if args.model_type == ModelType.CNN:
            pretrain = Trainer.load_pretrain(args, foundation_cache=None)
            elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
            charmodel_forward = load_charlm(args.charlm_forward_file)
            charmodel_backward = load_charlm(args.charlm_backward_file)
            peft_name = None
            bert_model, bert_tokenizer = load_bert(args.bert_model)

            use_peft = getattr(args, "use_peft", False)
            if use_peft:
                peft_name = "sentiment"
                bert_model = build_peft_wrapper(bert_model, vars(args), logger, adapter_name=peft_name)

            extra_vocab = data.dataset_vocab(train_set)
            force_bert_saved = args.bert_finetune
            model = cnn_classifier.CNNClassifier(pretrain=pretrain,
                                                 extra_vocab=extra_vocab,
                                                 labels=labels,
                                                 charmodel_forward=charmodel_forward,
                                                 charmodel_backward=charmodel_backward,
                                                 elmo_model=elmo_model,
                                                 bert_model=bert_model,
                                                 bert_tokenizer=bert_tokenizer,
                                                 force_bert_saved=force_bert_saved,
                                                 peft_name=peft_name,
                                                 args=args)
            model = model.to(args.device)
        elif args.model_type == ModelType.CONSTITUENCY:
            # this passes flags such as "constituency_backprop" from
            # the classifier to the TreeEmbedding as the "backprop" flag
            parser_args = { x[len("constituency_"):]: y for x, y in vars(args).items() if x.startswith("constituency_") }
            parser_args.update({
                "wordvec_pretrain_file": args.wordvec_pretrain_file,
                "charlm_forward_file": args.charlm_forward_file,
                "charlm_backward_file": args.charlm_backward_file,
                "bert_model": args.bert_model,
                # we found that finetuning from the classifier output
                # all the way to the bert layers caused the bert model
                # to go astray
                # could make this an option... but it is much less accurate
                # with the Bert finetuning
                # noting that the constituency parser itself works better
                # after finetuning, of course
                "bert_finetune": False,
                "stage1_bert_finetune": False,
            })
            logger.info("Building constituency classifier using %s as the base model" % args.constituency_model)
            tree_embedding = TreeEmbedding.from_parser_file(parser_args)
            model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
                                                                   labels=labels,
                                                                   args=args)
            model = model.to(args.device)
        else:
            raise ValueError("Unhandled model type {}".format(args.model_type))

        optimizer = Trainer.build_optimizer(model, args)

        return Trainer(model, optimizer)

    @staticmethod
    def build_optimizer(model, args):
        """
        Build the optimizer(s) for model from the optimizer settings in args

        The bert weight decay is args.weight_decay scaled by args.bert_weight_decay
        """
        return get_split_optimizer(args.optim.lower(), model, args.learning_rate,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay,
                                   bert_learning_rate=args.bert_learning_rate,
                                   bert_weight_decay=args.weight_decay * args.bert_weight_decay,
                                   # args built without the peft options are treated as not using peft,
                                   # the same as in build_new_model
                                   is_peft=getattr(args, "use_peft", False))