|
|
import argparse |
|
|
import ast |
|
|
import logging |
|
|
import os |
|
|
import random |
|
|
import re |
|
|
from enum import Enum |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from stanza.models.common import loss |
|
|
from stanza.models.common import utils |
|
|
from stanza.models.pos.vocab import CharVocab |
|
|
|
|
|
import stanza.models.classifiers.data as data |
|
|
from stanza.models.classifiers.trainer import Trainer |
|
|
from stanza.models.classifiers.utils import WVType, ExtraVectors, ModelType |
|
|
from stanza.models.common.peft_config import add_peft_args, resolve_peft_args |
|
|
|
|
|
from stanza.utils.confusion import format_confusion, confusion_to_accuracy, confusion_to_macro_f1 |
|
|
|
|
|
|
|
|
class Loss(Enum):
    """Loss function choices for training (selected with --loss).

    See train_model() for how each value maps to a loss object.
    """
    # plain nn.CrossEntropyLoss
    CROSS = 1
    # loss.weighted_cross_entropy_loss(..., log_dampened=False)
    WEIGHTED_CROSS = 2
    # loss.weighted_cross_entropy_loss(..., log_dampened=True)
    LOG_CROSS = 3
    # FocalLoss from the optional focal_loss_torch package
    FOCAL = 4
|
|
|
|
|
class DevScoring(Enum):
    """Metric used to pick the best model on the dev set (--dev_eval_scoring).

    The string values are embedded in intermediate checkpoint filenames
    by intermediate_name().
    """
    # use raw accuracy as the selection score
    ACCURACY = 'ACC'
    # use macro f1 as the selection score (see score_dev_set)
    WEIGHTED_F1 = 'WF'
|
|
|
|
|
# main logger for user-facing progress messages
logger = logging.getLogger('stanza')
# separate, more verbose logger for training internals
tlogger = logging.getLogger('stanza.classifiers.trainer')

# silence the chatty elmoformanylangs package
logging.getLogger('elmoformanylangs').setLevel(logging.WARNING)

# default SST sentiment data files used when no --train_file / --dev_file /
# --test_file is given on the command line
DEFAULT_TRAIN='data/sentiment/en_sstplus.train.txt'
DEFAULT_DEV='data/sentiment/en_sst3roots.dev.txt'
DEFAULT_TEST='data/sentiment/en_sst3roots.test.txt'
|
|
|
|
|
"""A script for training and testing classifier models, especially on the SST. |
|
|
|
|
|
If you run the script with no arguments, it will start trying to train |
|
|
a sentiment model. |
|
|
|
|
|
python3 -m stanza.models.classifier |
|
|
|
|
|
This requires the sentiment dataset to be in an `extern_data` |
|
|
directory, such as by symlinking it from somewhere else. |
|
|
|
|
|
The default model is a CNN where the word vectors are first mapped to |
|
|
channels with filters of a few different widths, those channels are |
|
|
maxpooled over the entire sentence, and then the resulting pools have |
|
|
fully connected layers until they reach the number of classes in the |
|
|
training data. You can see the defaults in the options below. |
|
|
|
|
|
https://arxiv.org/abs/1408.5882 |
|
|
|
|
|
(Currently the CNN is the only sentence classifier implemented.) |
|
|
|
|
|
To train with a more complicated CNN arch: |
|
|
|
|
|
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 > FC41.out 2>&1 & |
|
|
|
|
|
You can train models with word vectors other than the default word2vec. For example: |
|
|
|
|
|
nohup python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir extern_data/google --max_epochs 200 --filter_channels 1000 --fc_shapes 200,100 --base_name FC21_google > FC21_google.out 2>&1 & |
|
|
|
|
|
A model trained on the 5 class dataset can be tested on the 2 class dataset with a command line like this: |
|
|
|
|
|
python3 -u -m stanza.models.classifier --no_train --load_name saved_models/classifier/sst_en_ewt_FS_3_4_5_C_1000_FC_400_100_classifier.E0165-ACC41.87.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels "{0:0, 1:0, 3:1, 4:1}" |
|
|
|
|
|
python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir extern_data/google --no_train --load_name saved_models/classifier/FC21_google_en_ewt_FS_3_4_5_C_1000_FC_200_100_classifier.E0189-ACC45.87.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels "{0:0, 1:0, 3:1, 4:1}" |
|
|
|
|
|
A model trained on the 3 class dataset can be tested on the 2 class dataset with a command line like this: |
|
|
|
|
|
python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir extern_data/google --no_train --load_name saved_models/classifier/FC21_3C_google_en_ewt_FS_3_4_5_C_1000_FC_200_100_classifier.E0101-ACC68.94.pt --test_file data/sentiment/en_sst2roots.test.txt --test_remap_labels "{0:0, 2:1}" |
|
|
|
|
|
To train models on combined 3 class datasets: |
|
|
|
|
|
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_3class --extra_wordvec_method CONCAT --extra_wordvec_dim 200 --train_file data/sentiment/en_sstplus.train.txt --dev_file data/sentiment/en_sst3roots.dev.txt --test_file data/sentiment/en_sst3roots.test.txt > FC41_3class.out 2>&1 & |
|
|
|
|
|
This tests that model: |
|
|
|
|
|
python3 -u -m stanza.models.classifier --no_train --load_name en_sstplus.pt --test_file data/sentiment/en_sst3roots.test.txt |
|
|
|
|
|
Here is an example for training a model in a different language: |
|
|
|
|
|
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_german --train_file data/sentiment/de_sb10k.train.txt --dev_file data/sentiment/de_sb10k.dev.txt --test_file data/sentiment/de_sb10k.test.txt --shorthand de_sb10k --min_train_len 3 --extra_wordvec_method CONCAT --extra_wordvec_dim 100 > de_sb10k.out 2>&1 & |
|
|
|
|
|
This uses more data, although that wound up being worse for the German model: |
|
|
|
|
|
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_german --train_file data/sentiment/de_sb10k.train.txt,data/sentiment/de_scare.train.txt,data/sentiment/de_usage.train.txt --dev_file data/sentiment/de_sb10k.dev.txt --test_file data/sentiment/de_sb10k.test.txt --shorthand de_sb10k --min_train_len 3 --extra_wordvec_method CONCAT --extra_wordvec_dim 100 > de_sb10k.out 2>&1 & |
|
|
|
|
|
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_chinese --train_file data/sentiment/zh_ren.train.txt --dev_file data/sentiment/zh_ren.dev.txt --test_file data/sentiment/zh_ren.test.txt --shorthand zh_ren --wordvec_type fasttext --extra_wordvec_method SUM --wordvec_pretrain_file ../stanza_resources/zh-hans/pretrain/gsdsimp.pt > zh_ren.out 2>&1 & |
|
|
|
|
|
nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --save_name vi_vsfc.pt --train_file data/sentiment/vi_vsfc.train.json --dev_file data/sentiment/vi_vsfc.dev.json --test_file data/sentiment/vi_vsfc.test.json --shorthand vi_vsfc --wordvec_pretrain_file ../stanza_resources/vi/pretrain/vtb.pt --wordvec_type word2vec --extra_wordvec_method SUM --dev_eval_scoring WEIGHTED_F1 > vi_vsfc.out 2>&1 & |
|
|
|
|
|
python3 -u -m stanza.models.classifier --no_train --test_file extern_data/sentiment/vietnamese/_UIT-VSFC/test.txt --shorthand vi_vsfc --wordvec_pretrain_file ../stanza_resources/vi/pretrain/vtb.pt --wordvec_type word2vec --load_name vi_vsfc.pt |
|
|
""" |
|
|
|
|
|
def convert_fc_shapes(arg):
    """
    Parse a string describing fully connected layer sizes into a tuple of ints.

    For example, converts "100" -> (100,)
                 "100,200" -> (100, 200)
    An empty or whitespace-only string becomes the empty tuple.
    """
    text = arg.strip()
    if not text:
        return ()
    # literal_eval handles "100", "100,200", "(100, 200)", "[100, 200]", etc
    parsed = ast.literal_eval(text)
    if isinstance(parsed, int):
        return (parsed,)
    return parsed if isinstance(parsed, tuple) else tuple(parsed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-optimizer hyperparameter defaults.  parse_args() fills these in when the
# corresponding command line flag (--learning_rate, --momentum, --weight_decay)
# is left unset.  Fixed: the original DEFAULT_LEARNING_RATES had a duplicate
# "sgd" key (same value, so behavior is unchanged).
DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0005 }
# NOTE(review): EPS and RHO are not referenced in this file - presumably
# consumed by the Trainer when building the optimizer; verify there.
DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 }
DEFAULT_LEARNING_RHO = 0.9
DEFAULT_MOMENTUM = { "madgrad": 0.9, "sgd": 0.9 }
DEFAULT_WEIGHT_DECAY = { "adamw": 0.05, "adadelta": 0.0001, "sgd": 0.01, "adabelief": 1.2e-6, "madgrad": 2e-6 }
|
|
|
|
|
def build_argparse():
    """
    Build the argparse for the classifier.

    Refactored so that other utility scripts can use the same parser if needed.

    Returns:
        argparse.ArgumentParser with all classifier options registered.
    """
    parser = argparse.ArgumentParser()

    # --train / --no_train toggle whether training happens at all;
    # with --no_train a saved model is loaded and only evaluated
    parser.add_argument('--train', dest='train', default=True, action='store_true', help='Train the model (default)')
    parser.add_argument('--no_train', dest='train', action='store_false', help="Don't train the model")

    parser.add_argument('--shorthand', type=str, default='en_ewt', help="Treebank shorthand, eg 'en' for English")

    # model load/save locations; save_name is a template expanded by
    # utils.standard_model_file_name in build_model_filename
    parser.add_argument('--load_name', type=str, default=None, help='Name for loading an existing model')
    parser.add_argument('--save_dir', type=str, default='saved_models/classifier', help='Root dir for saving models.')
    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{bert_finetuning}_{classifier_type}_classifier.pt", help='Name for saving the model')

    # checkpointing of the most recent epoch for resuming interrupted training
    parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")

    parser.add_argument('--save_intermediate_models', default=False, action='store_true',
                        help='Save all intermediate models - this can be a lot!')

    # dataset files
    parser.add_argument('--train_file', type=str, default=DEFAULT_TRAIN, help='Input file(s) to train a model from.  Each line is an example.  Should go <label> <tokenized sentence>.  Comma separated list.')
    parser.add_argument('--dev_file', type=str, default=DEFAULT_DEV, help='Input file(s) to use as the dev set.')
    parser.add_argument('--test_file', type=str, default=DEFAULT_TEST, help='Input file(s) to use as the test set.')
    parser.add_argument('--output_predictions', default=False, action='store_true', help='Output predictions when running the test set')
    parser.add_argument('--max_epochs', type=int, default=100)
    # how often (in batches) to log the running average loss
    parser.add_argument('--tick', type=int, default=50)

    parser.add_argument('--model_type', type=lambda x: ModelType[x.upper()], default=ModelType.CNN,
                        help='Model type to use.  Options: %s' % " ".join(x.name for x in ModelType))

    # CNN architecture options
    parser.add_argument('--filter_sizes', default=(3,4,5), type=ast.literal_eval, help='Filter sizes for the layer after the word vectors')
    parser.add_argument('--filter_channels', default=1000, type=ast.literal_eval, help='Number of channels for layers after the word vectors.  Int for same number of channels (scaled by width) for each filter, or tuple/list for exact lengths for each filter')
    parser.add_argument('--fc_shapes', default="400,100", type=convert_fc_shapes, help='Extra fully connected layers to put after the initial filters.  If set to blank, will FC directly from the max pooling to the output layer.')
    parser.add_argument('--dropout', default=0.5, type=float, help='Dropout value to use')

    # batching and dev evaluation cadence
    parser.add_argument('--batch_size', default=50, type=int, help='Batch size when training')
    parser.add_argument('--batch_single_item', default=200, type=int, help='Items of this size go in their own batch')
    parser.add_argument('--dev_eval_batches', default=2000, type=int, help='Run the dev set after this many train batches.  Set to 0 to only do it once per epoch')
    parser.add_argument('--dev_eval_scoring', type=lambda x: DevScoring[x.upper()], default=DevScoring.WEIGHTED_F1,
                        help=('Scoring method to use for choosing the best model.  Options: %s' %
                              " ".join(x.name for x in DevScoring)))

    # optimizer options; None means "use the per-optimizer default"
    # filled in by parse_args()
    parser.add_argument('--weight_decay', default=None, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer')
    parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate to use in the optimizer')
    parser.add_argument('--momentum', default=None, type=float, help='Momentum to use in the optimizer')

    parser.add_argument('--optim', default='adadelta', choices=['adadelta', 'madgrad', 'sgd'], help='Optimizer type: SGD, Adadelta, or madgrad.  Highly recommend to install madgrad and use that')

    # testing a model on a dataset with a coarser label set
    parser.add_argument('--test_remap_labels', default=None, type=ast.literal_eval,
                        help='Map of which label each classifier label should map to.  For example, "{0:0, 1:0, 3:1, 4:1}" to map a 5 class sentiment test to a 2 class.  Any labels not mapped will be considered wrong')
    parser.add_argument('--forgive_unmapped_labels', dest='forgive_unmapped_labels', default=True, action='store_true',
                        help='When remapping labels, such as from 5 class to 2 class, pick a different label if the first guess is not remapped.')
    parser.add_argument('--no_forgive_unmapped_labels', dest='forgive_unmapped_labels', action='store_false',
                        help="When remapping labels, such as from 5 class to 2 class, DON'T pick a different label if the first guess is not remapped.")

    parser.add_argument('--loss', type=lambda x: Loss[x.upper()], default=Loss.CROSS,
                        help="Whether to use regular cross entropy or scale it by 1/log(quantity)")
    parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')
    parser.add_argument('--min_train_len', type=int, default=0,
                        help="Filter sentences less than this length")

    # word vector options; note argparse applies the type= lambda to the
    # string defaults as well, so e.g. 'word2vec' becomes WVType.WORD2VEC
    parser.add_argument('--pretrain_max_vocab', type=int, default=-1)
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument('--wordvec_raw_file', type=str, default=None, help='Exact name of the raw wordvec file to read')
    parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors')
    parser.add_argument('--wordvec_type', type=lambda x: WVType[x.upper()], default='word2vec', help='Different vector types have different options, such as google 300d replacing numbers with #')
    parser.add_argument('--extra_wordvec_dim', type=int, default=0, help="Extra dim of word vectors - will be trained")
    parser.add_argument('--extra_wordvec_method', type=lambda x: ExtraVectors[x.upper()], default='sum', help='How to train extra dimensions of word vectors, if at all')
    parser.add_argument('--extra_wordvec_max_norm', type=float, default=None, help="Max norm for initializing the extra vectors")

    # character language model options
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--charlm_projection', type=int, default=None, help="Project the charlm values to this dimension")
    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")

    # elmo options
    parser.add_argument('--elmo_model', default='extern_data/manyelmo/english', help='Directory with elmo model')
    parser.add_argument('--use_elmo', dest='use_elmo', default=False, action='store_true', help='Use an elmo model as a source of parameters')
    parser.add_argument('--elmo_projection', type=int, default=None, help='Project elmo to this many dimensions')

    # transformer (bert-like) options
    parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
    parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
    parser.add_argument('--bert_finetune', default=False, action='store_true', help="Finetune the Bert model")
    parser.add_argument('--bert_learning_rate', default=0.01, type=float, help='Scale the learning rate for transformer finetuning by this much')
    parser.add_argument('--bert_weight_decay', default=0.0001, type=float, help='Scale the weight decay for transformer finetuning by this much')
    parser.add_argument('--bert_hidden_layers', type=int, default=4, help="How many layers of hidden state to use from the transformer")
    parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding')

    # bilstm between the inputs and the conv filters
    parser.add_argument('--bilstm', dest='bilstm', action='store_true', default=True, help="Use a bilstm after the inputs, before the convs.  Using bilstm is about as accurate and significantly faster (because of dim reduction) than going straight to the filters")
    parser.add_argument('--no_bilstm', dest='bilstm', action='store_false', help="Don't use a bilstm after the inputs, before the convs.")

    parser.add_argument('--bilstm_hidden_dim', type=int, default=300, help="Dimension of the bilstm to use")

    parser.add_argument('--maxpool_width', type=int, default=1, help="Width of the maxpool kernel to use")

    # constituency-parser-based classifier options
    parser.add_argument('--no_constituency_backprop', dest='constituency_backprop', default=True, action='store_false', help="When using a constituency parser, backprop into the parser's weights if True")
    parser.add_argument('--constituency_model', type=str, default="/home/john/stanza_resources/it/constituency/vit_bert.pt", help="Which constituency model to use.  TODO: make this more user friendly")
    parser.add_argument('--constituency_batch_norm', default=False, action='store_true', help='Add a LayerNorm between the output of the parser and the classifier layers')
    parser.add_argument('--constituency_node_attn', default=False, action='store_true', help='True means to make an attn layer out of the tree, with the words as key and nodes as query')
    parser.add_argument('--no_constituency_node_attn', dest='constituency_node_attn', action='store_false', help='True means to make an attn layer out of the tree, with the words as key and nodes as query')
    parser.add_argument('--constituency_top_layer', dest='constituency_top_layer', default=False, action='store_true', help='True means use the top (ROOT) layer of the constituents.  Otherwise, the next layer down (S, usually) will be used')
    parser.add_argument('--no_constituency_top_layer', dest='constituency_top_layer', action='store_false', help='True means use the top (ROOT) layer of the constituents.  Otherwise, the next layer down (S, usually) will be used')
    parser.add_argument('--constituency_all_words', default=False, action='store_true', help='Use all word positions in the constituency classifier')
    parser.add_argument('--no_constituency_all_words', dest='constituency_all_words', default=False, action='store_false', help='Use the start and end word embeddings as inputs to the constituency classifier')

    parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training.  A very noisy option')

    # wandb logging (training only)
    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training.  Only applies to training.  Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training.  Will default to the dataset short name')

    parser.add_argument('--seed', default=None, type=int, help='Random seed for model')

    add_peft_args(parser)
    utils.add_device_args(parser)

    return parser
|
|
|
|
|
def build_model_filename(args):
    """
    Expand the --save_name template into a concrete model filename.

    Encodes the filter sizes, channel count, and FC shapes into a "shape"
    string, then delegates to utils.standard_model_file_name.
    """
    pieces = ["FS_%s" % "_".join(str(fs) for fs in args.filter_sizes)]
    pieces.append("_C_%d_" % args.filter_channels)
    if args.fc_shapes:
        pieces.append("_FC_%s_" % "_".join(str(fc) for fc in args.fc_shapes))
    shape = "".join(pieces)

    model_save_file = utils.standard_model_file_name(vars(args), "classifier", shape=shape, classifier_type=args.model_type.name)
    logger.info("Expanded save_name: %s", model_save_file)
    return model_save_file
|
|
|
|
|
def parse_args(args=None):
    """
    Parse command line arguments for the classifier and return the result.

    Also resolves peft args, turns on wandb if a wandb name was given, and
    fills in optimizer-specific defaults for weight decay, momentum, and
    learning rate when those flags were left unset.
    """
    parser = build_argparse()
    args = parser.parse_args(args)
    resolve_peft_args(args, tlogger)

    # naming a wandb session implies wanting wandb at all
    if args.wandb_name:
        args.wandb = True

    args.optim = args.optim.lower()
    # unset hyperparameters fall back to per-optimizer defaults
    optim_defaults = (("weight_decay", DEFAULT_WEIGHT_DECAY),
                      ("momentum", DEFAULT_MOMENTUM),
                      ("learning_rate", DEFAULT_LEARNING_RATES))
    for attr, defaults in optim_defaults:
        if getattr(args, attr) is None:
            setattr(args, attr, defaults.get(args.optim, None))

    return args
|
|
|
|
|
|
|
|
def dataset_predictions(model, dataset):
    """
    Run the model over a dataset and return the predicted labels.

    Items are batched by length for efficiency, then the predictions are
    unsorted back into the dataset's original order.
    """
    model.eval()
    index_label_map = dict(enumerate(model.labels))

    dataset_lengths = data.sort_dataset_by_len(dataset, keep_index=True)

    predictions = []
    original_indices = []
    for batch in dataset_lengths.values():
        # each batch entry is (datum, original_index)
        output = model([item[0] for item in batch])
        for i, item in enumerate(batch):
            best = torch.argmax(output[i]).item()
            predictions.append(index_label_map[best])
            original_indices.append(item[1])

    return utils.unsort(predictions, original_indices)
|
|
|
|
|
def confusion_dataset(predictions, dataset, labels):
    """
    Returns a confusion matrix built from parallel predictions and gold data.

    First key: gold
    Second key: predicted
    so: confusion_matrix[gold][predicted]
    """
    confusion_matrix = {label: {} for label in labels}

    for predicted_label, datum in zip(predictions, dataset):
        row = confusion_matrix[datum.sentiment]
        row[predicted_label] = row.get(predicted_label, 0) + 1

    return confusion_matrix
|
|
|
|
|
|
|
|
def score_dataset(model, dataset, label_map=None,
                  remap_labels=None, forgive_unmapped_labels=False):
    """
    Score a model on a dataset, returning the number of correct predictions.

    label_map: dict from label string to model output index; defaults to
      the model's own label ordering.

    remap_labels: a dict from old label to new label to use when
    testing a classifier on a dataset with a simpler label set.
    For example, a model trained on 5 class sentiment can be tested
    on a binary distribution with {"0": "0", "1": "0", "3": "1", "4": "1"}

    forgive_unmapped_labels says the following: in the case that the
    model predicts "2" in the above example for remap_labels, instead
    treat the model's prediction as whichever label it gave the
    highest score
    """
    model.eval()
    if label_map is None:
        label_map = {x: y for (y, x) in enumerate(model.labels)}
    correct = 0
    # batch by sentence length for efficiency
    dataset_lengths = data.sort_dataset_by_len(dataset)

    for length in dataset_lengths.keys():
        batch = dataset_lengths[length]
        expected_labels = [label_map[x.sentiment] for x in batch]

        output = model(batch)

        for i in range(len(expected_labels)):
            predicted = torch.argmax(output[i])
            predicted_label = predicted.item()
            if remap_labels:
                if predicted_label in remap_labels:
                    predicted_label = remap_labels[predicted_label]
                else:
                    # the top prediction has no remapping; optionally fall
                    # back to the highest-scoring label that does
                    found = False
                    if forgive_unmapped_labels:
                        # rank all output indices by descending score
                        items = []
                        for j in range(len(output[i])):
                            items.append((output[i][j].item(), j))
                        items.sort(key=lambda x: -x[0])
                        for _, item in items:
                            if item in remap_labels:
                                predicted_label = remap_labels[item]
                                found = True
                                break

                    # no usable prediction: counts as wrong (skipped)
                    if not found:
                        continue

            if predicted_label == expected_labels[i]:
                correct = correct + 1
    return correct
|
|
|
|
|
def score_dev_set(model, dev_set, dev_eval_scoring):
    """
    Evaluate the model on the dev set, logging confusion matrix and metrics.

    Returns (selection_score, accuracy, macro_f1) where selection_score is
    whichever metric dev_eval_scoring requests.
    """
    predictions = dataset_predictions(model, dev_set)
    confusion_matrix = confusion_dataset(predictions, dev_set, model.labels)
    logger.info("Dev set confusion matrix:\n{}".format(format_confusion(confusion_matrix, model.labels)))
    correct, total = confusion_to_accuracy(confusion_matrix)
    macro_f1 = confusion_to_macro_f1(confusion_matrix)
    logger.info("Dev set: %d correct of %d examples.  Accuracy: %f" %
                (correct, len(dev_set), correct / len(dev_set)))
    logger.info("Macro f1: {}".format(macro_f1))

    accuracy = correct / total
    if dev_eval_scoring is DevScoring.ACCURACY:
        return accuracy, accuracy, macro_f1
    if dev_eval_scoring is DevScoring.WEIGHTED_F1:
        return macro_f1, accuracy, macro_f1
    raise ValueError("Unknown scoring method {}".format(dev_eval_scoring))
|
|
|
|
|
def intermediate_name(filename, epoch, dev_scoring, score):
    """
    Build an informative intermediate checkpoint name from a base name, epoch #, and accuracy

    e.g. "model.pt" -> "model.E0005-WF41.87.pt"
    """
    root, ext = os.path.splitext(filename)
    suffix = ".E{:04d}-{}{:05.2f}".format(epoch, dev_scoring.value, score * 100)
    return root + suffix + ext
|
|
|
|
|
def log_param_sizes(model):
    """Debug-log the byte size of every model parameter plus the total."""
    logger.debug("--- Model parameter sizes ---")
    total_size = 0
    for name, param in model.named_parameters():
        size_bytes = param.element_size() * param.nelement()
        total_size += size_bytes
        logger.debug("  %s  %d  %d  %d", name, param.element_size(), param.nelement(), size_bytes)
    logger.debug("  Total size: %d", total_size)
|
|
|
|
|
def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set, labels):
    """
    Train the classifier held in trainer, saving the best model to model_file.

    trainer: a Trainer holding the model and its optimizer(s); its
      epochs_trained / global_step / best_score fields are mutated in place.
    model_file: where to save the best-scoring model.
    checkpoint_file: where to save the per-epoch checkpoint, or falsy to skip.
    labels: label set used to map sentiments to output indices.
    """
    tlogger.setLevel(logging.DEBUG)

    model = trainer.model
    optimizer = trainer.optimizer

    device = next(model.parameters()).device
    logger.info("Current device: %s" % device)

    # map each label to its output index, and to a device tensor of that index
    label_map = {x: y for (y, x) in enumerate(labels)}
    label_tensors = {x: torch.tensor(y, requires_grad=False, device=device)
                     for (y, x) in enumerate(labels)}

    # identity unless the loss needs normalized outputs (focal loss below)
    process_outputs = lambda x: x
    if args.loss == Loss.CROSS:
        logger.info("Creating CrossEntropyLoss")
        loss_function = nn.CrossEntropyLoss()
    elif args.loss == Loss.WEIGHTED_CROSS:
        logger.info("Creating weighted cross entropy loss w/o log")
        loss_function = loss.weighted_cross_entropy_loss([label_map[x[0]] for x in train_set], log_dampened=False)
    elif args.loss == Loss.LOG_CROSS:
        logger.info("Creating weighted cross entropy loss w/ log")
        loss_function = loss.weighted_cross_entropy_loss([label_map[x[0]] for x in train_set], log_dampened=True)
    elif args.loss == Loss.FOCAL:
        # optional third-party dependency, only needed for --loss=focal
        try:
            from focal_loss.focal_loss import FocalLoss
        except ImportError:
            raise ImportError("focal_loss not installed.  Must `pip install focal_loss_torch` to use the --loss=focal feature")
        logger.info("Creating FocalLoss with loss %f", args.loss_focal_gamma)
        # FocalLoss expects probabilities, so softmax the logits first
        process_outputs = lambda x: torch.softmax(x, dim=1)
        loss_function = FocalLoss(gamma=args.loss_focal_gamma)
    else:
        raise ValueError("Unknown loss function {}".format(args.loss))
    loss_function.to(device)

    train_set_by_len = data.sort_dataset_by_len(train_set)

    # global_step > 0 means we resumed from a checkpoint; rescore the dev
    # set so logs show where we are picking up from
    if trainer.global_step > 0:
        _ = score_dev_set(model, dev_set, args.dev_eval_scoring)
        logger.info("Reloaded model for continued training.")
        if trainer.best_score is not None:
            logger.info("Previous best score: %.5f", trainer.best_score)

    log_param_sizes(model)

    if args.wandb:
        import wandb
        wandb_name = args.wandb_name if args.wandb_name else "%s_classifier" % args.shorthand
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('accuracy', summary='max')
        wandb.run.define_metric('macro_f1', summary='max')
        wandb.run.define_metric('epoch_loss', summary='min')

    # there can be multiple optimizers (e.g. separate transformer optimizer)
    for opt_name, opt in optimizer.items():
        current_lr = opt.param_groups[0]['lr']
        logger.info("optimizer %s learning rate: %s", opt_name, current_lr)

    # optionally save an epoch-0 snapshot before any training
    if args.save_intermediate_models and trainer.epochs_trained == 0:
        intermediate_file = intermediate_name(model_file, trainer.epochs_trained, args.dev_eval_scoring, 0.0)
        trainer.save(intermediate_file, save_optimizer=False)
    # loop variable is trainer.epochs_trained itself so the count survives
    # checkpointing and resuming
    for trainer.epochs_trained in range(trainer.epochs_trained, args.max_epochs):
        running_loss = 0.0
        epoch_loss = 0.0
        shuffled_batches = data.shuffle_dataset(train_set_by_len, args.batch_size, args.batch_single_item)

        model.train()
        logger.info("Starting epoch %d", trainer.epochs_trained)
        if args.log_norms:
            model.log_norms()

        for batch_num, batch in enumerate(shuffled_batches):
            trainer.global_step += 1
            logger.debug("Starting batch: %d step %d", batch_num, trainer.global_step)

            batch_labels = torch.stack([label_tensors[x.sentiment] for x in batch])

            # standard zero-grad / forward / backward / step cycle
            for opt in optimizer.values():
                opt.zero_grad()

            outputs = model(batch)
            outputs = process_outputs(outputs)
            batch_loss = loss_function(outputs, batch_labels)
            batch_loss.backward()
            for opt in optimizer.values():
                opt.step()

            running_loss += batch_loss.item()
            # every args.tick batches: log average loss, maybe run dev eval
            if (batch_num + 1) % args.tick == 0:
                train_loss = running_loss / args.tick
                logger.info('[%d, %5d] Average loss: %.3f', trainer.epochs_trained + 1, batch_num + 1, train_loss)
                if args.wandb:
                    wandb.log({'train_loss': train_loss}, step=trainer.global_step)
                if args.dev_eval_batches > 0 and (batch_num + 1) % args.dev_eval_batches == 0:
                    logger.info('---- Interim analysis ----')
                    dev_score, accuracy, macro_f1 = score_dev_set(model, dev_set, args.dev_eval_scoring)
                    if args.wandb:
                        wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1}, step=trainer.global_step)
                    if trainer.best_score is None or dev_score > trainer.best_score:
                        trainer.best_score = dev_score
                        trainer.save(model_file, save_optimizer=False)
                        logger.info("Saved new best score model!  Accuracy %.5f  Macro F1 %.5f  Epoch %5d  Batch %d" % (accuracy, macro_f1, trainer.epochs_trained+1, batch_num+1))
                    # dev eval put the model in eval mode; restore train mode
                    model.train()
                if args.log_norms:
                    trainer.model.log_norms()
                epoch_loss += running_loss
                running_loss = 0.0

        # fold in whatever remained after the last tick
        epoch_loss += running_loss

        logger.info("Finished epoch %d  Total loss %.3f" % (trainer.epochs_trained + 1, epoch_loss))
        dev_score, accuracy, macro_f1 = score_dev_set(model, dev_set, args.dev_eval_scoring)
        if args.wandb:
            wandb.log({'accuracy': accuracy, 'macro_f1': macro_f1, 'epoch_loss': epoch_loss}, step=trainer.global_step)
        if checkpoint_file:
            trainer.save(checkpoint_file, epochs_trained = trainer.epochs_trained + 1)
        if args.save_intermediate_models:
            intermediate_file = intermediate_name(model_file, trainer.epochs_trained + 1, args.dev_eval_scoring, dev_score)
            trainer.save(intermediate_file, save_optimizer=False)
        if trainer.best_score is None or dev_score > trainer.best_score:
            trainer.best_score = dev_score
            trainer.save(model_file, save_optimizer=False)
            logger.info("Saved new best score model!  Accuracy %.5f  Macro F1 %.5f  Epoch %5d" % (accuracy, macro_f1, trainer.epochs_trained+1))

    if args.wandb:
        wandb.finish()
|
|
|
|
|
def main(args=None):
    """
    Entry point: parse args, build or load a classifier, optionally train it,
    then score it on the test set.
    """
    args = parse_args(args)
    seed = utils.set_random_seed(args.seed)
    logger.info("Using random seed: %d" % seed)

    utils.ensure_dir(args.save_dir)

    save_name = build_model_filename(args)

    checkpoint_file = None
    if args.train:
        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
        logger.info("Using training set: %s" % args.train_file)
        logger.info("Training set has %d labels" % len(data.dataset_labels(train_set)))
        tlogger.setLevel(logging.DEBUG)

        tlogger.info("Saving checkpoints: %s", args.checkpoint)
        if args.checkpoint:
            checkpoint_file = utils.checkpoint_name(args.save_dir, save_name, args.checkpoint_save_name)
            tlogger.info("Checkpoint filename: %s", checkpoint_file)
    elif not args.load_name:
        # eval-only run with no explicit model: fall back to the expanded
        # save_name, which should point at a previously trained model
        if save_name:
            args.load_name = save_name
        else:
            raise ValueError("No model provided and not asked to train a model.  This makes no sense")
    else:
        train_set = None

    # resume from checkpoint if one exists; otherwise load a named model;
    # otherwise build a new model from the training data
    if args.train and checkpoint_file is not None and os.path.exists(checkpoint_file):
        trainer = Trainer.load(checkpoint_file, args, load_optimizer=args.train)
    elif args.load_name:
        trainer = Trainer.load(args.load_name, args, load_optimizer=args.train)
    else:
        trainer = Trainer.build_new_model(args, train_set)

    trainer.model.log_configuration()

    if args.train:
        utils.log_training_args(args, logger)

        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, min_len=None)
        logger.info("Using dev set: %s", args.dev_file)
        logger.info("Training set has %d items", len(train_set))
        logger.info("Dev set has %d items", len(dev_set))
        data.check_labels(trainer.model.labels, dev_set)

        train_model(trainer, save_name, checkpoint_file, args, train_set, dev_set, trainer.model.labels)

    # final evaluation on the test set
    if args.log_norms:
        trainer.model.log_norms()
    test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)
    logger.info("Using test set: %s" % args.test_file)
    data.check_labels(trainer.model.labels, test_set)

    if args.test_remap_labels is None:
        # labels match the model: report a full confusion matrix
        predictions = dataset_predictions(trainer.model, test_set)
        confusion_matrix = confusion_dataset(predictions, test_set, trainer.model.labels)
        if args.output_predictions:
            logger.info("List of predictions: %s", predictions)
        logger.info("Confusion matrix:\n{}".format(format_confusion(confusion_matrix, trainer.model.labels)))
        correct, total = confusion_to_accuracy(confusion_matrix)
        logger.info("Macro f1: {}".format(confusion_to_macro_f1(confusion_matrix)))
    else:
        # testing against a coarser label set: only an accuracy is available
        correct = score_dataset(trainer.model, test_set,
                                remap_labels=args.test_remap_labels,
                                forgive_unmapped_labels=args.forgive_unmapped_labels)
        total = len(test_set)
    logger.info("Test set: %d correct of %d examples.  Accuracy: %f" %
                (correct, total, correct / total))
|
|
|
|
|
# allow running as a script: python3 -m stanza.models.classifier ...
if __name__ == '__main__':
    main()
|
|
|