"""
Organizes a classifier model and its optimizer in one place.

Saving the optimizer state along with the model makes it easy to restart training from a checkpoint.
"""

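# Typical usage, a sketch based on the methods below (the exact contents of
# `args` come from the classifier's argument parser and are not spelled out here):
#   trainer = Trainer.build_new_model(args, train_set)
#   ... run a training loop, stepping each optimizer in trainer.optimizer ...
#   trainer.save(filename)
#   trainer = Trainer.load(filename, args, load_optimizer=True)
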
import logging
import os
import warnings
from pickle import UnpicklingError
from types import SimpleNamespace

import torch
import torch.optim as optim

import stanza.models.classifiers.data as data
import stanza.models.classifiers.cnn_classifier as cnn_classifier
import stanza.models.classifiers.constituency_classifier as constituency_classifier
from stanza.models.classifiers.config import CNNConfig, ConstituencyConfig
from stanza.models.classifiers.utils import ModelType, WVType, ExtraVectors
# utils provides load_elmo and get_wordvec_file, used below
from stanza.models.common import utils
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
from stanza.models.common.pretrain import Pretrain
from stanza.models.common.utils import get_split_optimizer
from stanza.models.constituency.tree_embedding import TreeEmbedding

logger = logging.getLogger('stanza')

class Trainer:
    """
    Stores a classifier model and its optimizer
    """

    def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0, best_score=None):
        self.model = model
        self.optimizer = optimizer
        # training progress is kept on the trainer, saved and restored with
        # the checkpoint so training can resume where it left off
        self.epochs_trained = epochs_trained
        self.global_step = global_step
        # best dev score of the model so far, saved and restored with the checkpoint
        self.best_score = best_score

    def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):
        """
        Save the current model, optimizer, and other state to filename

        epochs_trained can be passed as a parameter to handle saving at the end of an epoch
        """
        if epochs_trained is None:
            epochs_trained = self.epochs_trained
        save_dir = os.path.split(filename)[0]
        os.makedirs(save_dir, exist_ok=True)
        model_params = self.model.get_params(skip_modules)
        params = {
            'params': model_params,
            'epochs_trained': epochs_trained,
            'global_step': self.global_step,
            'best_score': self.best_score,
        }
        if save_optimizer and self.optimizer is not None:
            # self.optimizer is a dict of name -> optimizer (see build_optimizer),
            # so each optimizer's state_dict is saved under its name
            params['optimizer_state_dict'] = {opt_name: opt.state_dict() for opt_name, opt in self.optimizer.items()}
        torch.save(params, filename, _use_new_zipfile_serialization=False)
        logger.info("Model saved to {}".format(filename))

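    # The checkpoint written by save() is a plain dict, so it can be inspected
    # outside of Stanza if needed; a sketch, not part of this module's API
    # (weights_only=True applies to checkpoints saved by this version):
    #   checkpoint = torch.load(filename, map_location="cpu", weights_only=True)
    #   checkpoint.keys()  # 'params', 'epochs_trained', 'global_step', 'best_score', ...
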
    @staticmethod
    def load(filename, args, foundation_cache=None, load_optimizer=False):
        if not os.path.exists(filename):
            if args.save_dir is None:
                raise FileNotFoundError("Cannot find model in {} and args.save_dir is None".format(filename))
            elif os.path.exists(os.path.join(args.save_dir, filename)):
                filename = os.path.join(args.save_dir, filename)
            else:
                raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
        try:
            # first try the new format, which stores the config as plain dicts
            # and can be loaded with weights_only=True; fall back to the old
            # format, which pickled SimpleNamespace and Enum objects
            try:
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
            except UnpicklingError:
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
                warnings.warn("The saved classifier has an old format which uses SimpleNamespace and/or Enum instead of a dict to store its config. This version of Stanza can read both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the pretrained classifier with this version ASAP.")
        except BaseException:
            logger.exception("Cannot load model from {}".format(filename))
            raise
        logger.debug("Loaded model {}".format(filename))

        epochs_trained = checkpoint.get('epochs_trained', 0)
        global_step = checkpoint.get('global_step', 0)
        best_score = checkpoint.get('best_score', None)

        # old checkpoints stored the model pieces at the top level
        # rather than under a single 'params' key
        if 'params' not in checkpoint:
            model_params = {
                'model': checkpoint['model'],
                'config': checkpoint['config'],
                'labels': checkpoint['labels'],
                'extra_vocab': checkpoint['extra_vocab'],
            }
        else:
            model_params = checkpoint['params']

        # old checkpoints stored the config as a SimpleNamespace instead of a dict
        if isinstance(model_params['config'], SimpleNamespace):
            model_params['config'] = vars(model_params['config'])

        # the model type may have been saved as a string; convert it back to the enum
        model_type = model_params['config']['model_type']
        if isinstance(model_type, str):
            model_type = ModelType[model_type]
        model_params['config']['model_type'] = model_type

        if model_type == ModelType.CNN:
            # fill in defaults for config fields added after older models were
            # saved, so the saved config can be turned into a CNNConfig
            if 'has_charlm_forward' not in model_params['config']:
                model_params['config']['has_charlm_forward'] = args.charlm_forward_file is not None
            if 'has_charlm_backward' not in model_params['config']:
                model_params['config']['has_charlm_backward'] = args.charlm_backward_file is not None
            for argname in ['bert_hidden_layers', 'bert_finetune', 'force_bert_saved', 'use_peft',
                            'lora_rank', 'lora_alpha', 'lora_dropout', 'lora_modules_to_save', 'lora_target_modules']:
                model_params['config'][argname] = model_params['config'].get(argname, None)

            # enum fields may also have been saved as strings
            if isinstance(model_params['config']['wordvec_type'], str):
                model_params['config']['wordvec_type'] = WVType[model_params['config']['wordvec_type']]
            if isinstance(model_params['config']['extra_wordvec_method'], str):
                model_params['config']['extra_wordvec_method'] = ExtraVectors[model_params['config']['extra_wordvec_method']]
            model_params['config'] = CNNConfig(**model_params['config'])

            pretrain = Trainer.load_pretrain(args, foundation_cache)
            elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None

            if model_params['config'].has_charlm_forward:
                charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)
            else:
                charmodel_forward = None
            if model_params['config'].has_charlm_backward:
                charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)
            else:
                charmodel_backward = None

            bert_model = model_params['config'].bert_model

            use_peft = getattr(model_params['config'], 'use_peft', False)
            force_bert_saved = getattr(model_params['config'], 'force_bert_saved', False)
            peft_name = None
            if use_peft:
                # with peft, first load the base transformer (possibly from the
                # cache), then wrap it with the lora weights from the checkpoint
                bert_model, bert_tokenizer, peft_name = load_bert_with_peft(bert_model, "classifier", foundation_cache)
                bert_model = load_peft_wrapper(bert_model, model_params['bert_lora'], vars(model_params['config']), logger, peft_name)
            elif force_bert_saved:
                # the transformer was finetuned, so its weights come from the
                # checkpoint itself; don't reuse a cached copy
                bert_model, bert_tokenizer = load_bert(bert_model)
            else:
                bert_model, bert_tokenizer = load_bert(bert_model, foundation_cache)
            model = cnn_classifier.CNNClassifier(pretrain=pretrain,
                                                 extra_vocab=model_params['extra_vocab'],
                                                 labels=model_params['labels'],
                                                 charmodel_forward=charmodel_forward,
                                                 charmodel_backward=charmodel_backward,
                                                 elmo_model=elmo_model,
                                                 bert_model=bert_model,
                                                 bert_tokenizer=bert_tokenizer,
                                                 force_bert_saved=force_bert_saved,
                                                 peft_name=peft_name,
                                                 args=model_params['config'])
        elif model_type == ModelType.CONSTITUENCY:
            use_peft = False
            pretrain_args = {
                'wordvec_pretrain_file': args.wordvec_pretrain_file,
                'charlm_forward_file': args.charlm_forward_file,
                'charlm_backward_file': args.charlm_backward_file,
            }

            tree_embedding = TreeEmbedding.model_from_params(model_params['tree_embedding'], pretrain_args, foundation_cache)
            model_params['config'] = ConstituencyConfig(**model_params['config'])
            model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
                                                                   labels=model_params['labels'],
                                                                   args=model_params['config'])
        else:
            raise ValueError("Unknown model type {}".format(model_type))
        model.load_state_dict(model_params['model'], strict=False)
        model = model.to(args.device)

        logger.debug("-- MODEL CONFIG --")
        for k in model.config.__dict__:
            logger.debug("  --{}: {}".format(k, model.config.__dict__[k]))

        logger.debug("-- MODEL LABELS --")
        logger.debug("  {}".format(" ".join(model.labels)))

        optimizer = None
        if load_optimizer:
            optimizer = Trainer.build_optimizer(model, args)
            if checkpoint.get('optimizer_state_dict', None) is not None:
                for opt_name, opt_state_dict in checkpoint['optimizer_state_dict'].items():
                    optimizer[opt_name].load_state_dict(opt_state_dict)
            else:
                logger.info("Attempted to load optimizer to resume training, but optimizer not saved.  Creating new optimizer")

        trainer = Trainer(model, optimizer, epochs_trained, global_step, best_score)

        return trainer

    @staticmethod
    def load_pretrain(args, foundation_cache):
        if args.wordvec_pretrain_file:
            pretrain_file = args.wordvec_pretrain_file
        elif args.wordvec_type:
            pretrain_file = '{}/{}.{}.pretrain.pt'.format(args.save_dir, args.shorthand, args.wordvec_type.name.lower())
        else:
            raise RuntimeError("TODO: need to get the wv type back from get_wordvec_file")

        logger.debug("Looking for pretrained vectors in {}".format(pretrain_file))
        if os.path.exists(pretrain_file):
            return load_pretrain(pretrain_file, foundation_cache)
        elif args.wordvec_raw_file:
            vec_file = args.wordvec_raw_file
            logger.debug("Pretrain not found.  Looking in {}".format(vec_file))
        else:
            vec_file = utils.get_wordvec_file(args.wordvec_dir, args.shorthand, args.wordvec_type.name.lower())
            logger.debug("Pretrain not found.  Looking in {}".format(vec_file))
        # fall back to building a Pretrain from the raw word vector file
        pretrain = Pretrain(pretrain_file, vec_file, args.pretrain_max_vocab)
        logger.debug("Embedding shape: %s" % str(pretrain.emb.shape))
        return pretrain

    @staticmethod
    def build_new_model(args, train_set):
        """
        Load pretrained pieces and then build a new model
        """
        if train_set is None:
            raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")

        labels = data.dataset_labels(train_set)

        if args.model_type == ModelType.CNN:
            pretrain = Trainer.load_pretrain(args, foundation_cache=None)
            elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
            charmodel_forward = load_charlm(args.charlm_forward_file)
            charmodel_backward = load_charlm(args.charlm_backward_file)
            peft_name = None
            bert_model, bert_tokenizer = load_bert(args.bert_model)

            use_peft = getattr(args, "use_peft", False)
            if use_peft:
                peft_name = "sentiment"
                bert_model = build_peft_wrapper(bert_model, vars(args), logger, adapter_name=peft_name)

            extra_vocab = data.dataset_vocab(train_set)
            force_bert_saved = args.bert_finetune
            model = cnn_classifier.CNNClassifier(pretrain=pretrain,
                                                 extra_vocab=extra_vocab,
                                                 labels=labels,
                                                 charmodel_forward=charmodel_forward,
                                                 charmodel_backward=charmodel_backward,
                                                 elmo_model=elmo_model,
                                                 bert_model=bert_model,
                                                 bert_tokenizer=bert_tokenizer,
                                                 force_bert_saved=force_bert_saved,
                                                 peft_name=peft_name,
                                                 args=args)
            model = model.to(args.device)
        elif args.model_type == ModelType.CONSTITUENCY:
            # collect the constituency parser's own arguments, which are kept
            # on the classifier args with a "constituency_" prefix
            parser_args = { x[len("constituency_"):]: y for x, y in vars(args).items() if x.startswith("constituency_") }
            parser_args.update({
                "wordvec_pretrain_file": args.wordvec_pretrain_file,
                "charlm_forward_file": args.charlm_forward_file,
                "charlm_backward_file": args.charlm_backward_file,
                "bert_model": args.bert_model,
                # the transformer is not finetuned when the parser is loaded
                # here, so the finetune flags are explicitly turned off
                "bert_finetune": False,
                "stage1_bert_finetune": False,
            })
            logger.info("Building constituency classifier using %s as the base model" % args.constituency_model)
            tree_embedding = TreeEmbedding.from_parser_file(parser_args)
            model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
                                                                   labels=labels,
                                                                   args=args)
            model = model.to(args.device)
        else:
            raise ValueError("Unhandled model type {}".format(args.model_type))

        optimizer = Trainer.build_optimizer(model, args)

        return Trainer(model, optimizer)

    @staticmethod
    def build_optimizer(model, args):
        # returns a dict of name -> optimizer, splitting out the transformer
        # parameters so they can use their own learning rate and weight decay
        return get_split_optimizer(args.optim.lower(), model, args.learning_rate,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay,
                                   bert_learning_rate=args.bert_learning_rate,
                                   bert_weight_decay=args.weight_decay * args.bert_weight_decay,
                                   is_peft=args.use_peft)
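
    # A sketch of how the optimizer dict from build_optimizer is used in a
    # training step (illustrative only; the real loop lives in the training script):
    #   for opt in trainer.optimizer.values():
    #       opt.zero_grad()
    #   loss.backward()
    #   for opt in trainer.optimizer.values():
    #       opt.step()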