import os
import sys

parentdir = os.path.dirname(__file__)
parentdir = os.path.dirname(parentdir)
parentdir = os.path.dirname(parentdir)
sys.path.append(parentdir)

import logging
import argparse
import os

from typing import Any, List, Tuple, Mapping
from collections import defaultdict
from numpy import random

import torch
import torch.nn as nn

import stanza

from stanza.models.common.utils import default_device
from stanza.models.lemma_classifier import utils
from stanza.models.lemma_classifier.base_model import LemmaClassifier
from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM
from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer
from stanza.utils.confusion import format_confusion
from stanza.utils.get_tqdm import get_tqdm

tqdm = get_tqdm()

logger = logging.getLogger('stanza.lemmaclassifier')


def get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: Mapping[int, Mapping[int, int]]) -> float:
    """
    Computes the weighted F1 score across an evaluation set.

    Each class's F1 score is weighted by the number of examples of that class in the
    evaluation set, so classes with more examples have a larger impact on the weighted F1.
    """
    num_total_examples = 0
    weighted_f1 = 0

    for class_id in mcc_results:
        class_f1 = mcc_results.get(class_id).get("f1")
        num_class_examples = sum(confusion.get(class_id).values())
        weighted_f1 += class_f1 * num_class_examples
        num_total_examples += num_class_examples

    return weighted_f1 / num_total_examples

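# Worked example (hypothetical numbers): if class 0 has F1 0.50 over 10 examples
# and class 1 has F1 0.90 over 30 examples, the weighted F1 is
# (0.50 * 10 + 0.90 * 30) / 40 = 0.80.
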
def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[Any], label_decoder: Mapping, verbose=True):
    """
    Evaluates a model's predicted tags against a set of gold tags. Computes precision, recall, and F1 for all classes.

    Precision = true positives / (true positives + false positives)
    Recall = true positives / (true positives + false negatives)
    F1 = 2 * (Precision * Recall) / (Precision + Recall)

    Returns:
        1. Multi-class result dictionary, where each class is a key and maps to another map of its F1, precision, and recall scores.
           e.g. multiclass_results[0]["precision"] would give class 0's precision.
        2. Confusion matrix, where each key is a gold tag and its value is another map keyed by the predicted tag, with a value of that (gold, pred) count.
           e.g. confusion[0][1] = 6 would mean that for gold tag 0, the model predicted tag 1 a total of 6 times.
        3. Weighted F1 (float) across all classes, as computed by get_weighted_f1().
    """
    assert len(gold_tag_sequences) == len(pred_tag_sequences), \
        f"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}"

    confusion = defaultdict(lambda: defaultdict(int))

    reverse_label_decoder = {y: x for x, y in label_decoder.items()}
    for gold, pred in zip(gold_tag_sequences, pred_tag_sequences):
        confusion[reverse_label_decoder[gold]][reverse_label_decoder[pred]] += 1

    multi_class_result = defaultdict(lambda: defaultdict(float))

    for gold_tag in confusion.keys():

        try:
            prec = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum([confusion.get(k, {}).get(gold_tag, 0) for k in confusion.keys()])
        except ZeroDivisionError:
            prec = 0.0

        try:
            recall = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum(confusion.get(gold_tag, {}).values())
        except ZeroDivisionError:
            recall = 0.0

        try:
            f1 = 2 * (prec * recall) / (prec + recall)
        except ZeroDivisionError:
            f1 = 0.0

        multi_class_result[gold_tag] = {
            "precision": prec,
            "recall": recall,
            "f1": f1
        }

    if verbose:
        for lemma in multi_class_result:
            logger.info(f"Lemma '{lemma}' had precision {100 * multi_class_result[lemma]['precision']}, recall {100 * multi_class_result[lemma]['recall']} and F1 score of {100 * multi_class_result[lemma]['f1']}")

    weighted_f1 = get_weighted_f1(multi_class_result, confusion)

    return multi_class_result, confusion, weighted_f1

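# Illustrative (hypothetical) values: with label_decoder {"be": 0, "have": 1},
# gold tags [0, 0, 1] and predictions [0, 1, 1] yield the confusion matrix
# {"be": {"be": 1, "have": 1}, "have": {"have": 1}}, giving "be" precision 1.0,
# recall 0.5, F1 ~0.67 and "have" precision 0.5, recall 1.0, F1 ~0.67.
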
def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[int]]=[]) -> torch.Tensor:
    """
    Runs a LemmaClassifierLSTM or LemmaClassifierWithTransformer over a batch of examples, given the position index of the target token in each sentence.

    Args:
        model (LemmaClassifier): A trained LemmaClassifier that is able to predict on a target token.
        position_indices (Tensor[int]): A tensor of the (zero-indexed) position of the target token in its sentence, for each example in the batch.
        sentences (List[List[str]]): A list of lists of the tokenized strings of the input sentences.
        upos_tags (List[List[int]], optional): The UPOS tag ids of the tokens in each input sentence, for models that use them.

    Returns:
        (torch.Tensor): The index of the predicted class in `model`'s output, for each example in the batch.
    """
    with torch.no_grad():
        logits = model(position_indices, sentences, upos_tags)
        predicted_class = torch.argmax(logits, dim=1)

    return predicted_class

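# Shape sketch (hypothetical values): for a batch of two sentences, position_indices
# might be tensor([3, 1]); the model returns logits of shape (2, num_classes), so the
# returned tensor has shape (2,), e.g. tensor([0, 2]), one class id per example.
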
def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_training: bool = False) -> Tuple[Mapping, Mapping, float, float]:
    """
    Helper function for model evaluation

    Args:
        model (LemmaClassifierLSTM or LemmaClassifierWithTransformer): A trained instance of the LemmaClassifier class to evaluate.
        eval_path (str): Path to the saved evaluation dataset.
        verbose (bool, optional): True if `evaluate_sequences()` should print the F1, Precision, and Recall for each class. Defaults to True.
        is_training (bool, optional): Whether the model is in training mode. If the model is training, we do not change it to eval mode.

    Returns:
        1. Multi-class results (Mapping[int, Mapping[str, float]]): first map has keys as the classes (lemma indices) and value is
           another map with key of "f1", "precision", or "recall" with corresponding values.
        2. Confusion Matrix (Mapping[int, Mapping[int, int]]): A confusion matrix with keys equal to the index of the gold tag, and a value of the
           map with the key as the predicted tag and corresponding count of that (gold, pred) pair.
        3. Accuracy (float): the total accuracy (num correct / total examples) across the evaluation set.
        4. Weighted F1 (float): the class-frequency-weighted F1 across the evaluation set.
    """
    device = default_device()
    model.to(device)

    if not is_training:
        # switch to eval mode unless evaluation is happening mid-training
        model.eval()

    dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False)

    logger.info(f"Evaluating on evaluation file {eval_path}")

    correct, total = 0, 0
    gold_tags, pred_tags = dataset.labels, []

    for sentences, pos_indices, upos_tags, labels in tqdm(dataset, "Evaluating examples from data file"):
        pred = model_predict(model, pos_indices, sentences, upos_tags)
        correct_preds = pred == labels.to(device)
        correct += torch.sum(correct_preds)
        total += len(correct_preds)
        pred_tags += pred.tolist()

    logger.info("Finished evaluating on dataset. Computing scores...")
    accuracy = correct / total

    mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, dataset.label_decoder, verbose=verbose)

    if verbose:
        logger.info(f"Accuracy: {accuracy} ({correct}/{total})")
        logger.info(f"Label decoder: {dataset.label_decoder}")

    return mc_results, confusion, accuracy, weighted_f1

def main(args=None, predefined_args=None):

    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab")
    parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)")
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', default=False, help="Whether or not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which model type to use ('bert', 'roberta', or 'lstm')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--eval_file", type=str, help="path to evaluation file")

    args = parser.parse_args(args) if not predefined_args else predefined_args

    logger.info("Running evaluation script with the following args:")
    args = vars(args)
    for arg in args:
        logger.info(f"{arg}: {args[arg]}")
    logger.info("------------------------------------------------------------")

    logger.info(f"Attempting evaluation of model from {args['save_name']} on file {args['eval_file']}")
    model = LemmaClassifier.load(args['save_name'], args)

    mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, args['eval_file'])

    logger.info(f"MCC Results: {dict(mcc_results)}")
    logger.info("______________________________________________")
    logger.info("Confusion:\n%s", format_confusion(confusion))
    logger.info("______________________________________________")
    logger.info(f"Accuracy: {acc}")
    logger.info("______________________________________________")
    logger.info(f"Weighted f1: {weighted_f1}")

    return mcc_results, confusion, acc, weighted_f1

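# Example invocation (the file name and data paths here are hypothetical):
#   python evaluate_models.py --save_name saved_models/lemma_classifier_model.pt --eval_file data/en_ewt_test.txt
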
if __name__ == "__main__":
    main()