Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- stanza/stanza/models/classifiers/base_classifier.py +65 -0
- stanza/stanza/models/classifiers/cnn_classifier.py +547 -0
- stanza/stanza/models/classifiers/iterate_test.py +64 -0
- stanza/stanza/models/classifiers/trainer.py +304 -0
- stanza/stanza/models/constituency/__init__.py +0 -0
- stanza/stanza/models/constituency/evaluate_treebanks.py +36 -0
- stanza/stanza/models/constituency/label_attention.py +726 -0
- stanza/stanza/models/constituency/lstm_tree_stack.py +91 -0
- stanza/stanza/models/constituency/score_converted_dependencies.py +65 -0
- stanza/stanza/models/constituency/text_processing.py +166 -0
- stanza/stanza/models/constituency/tree_reader.py +274 -0
- stanza/stanza/models/constituency/tree_stack.py +57 -0
- stanza/stanza/models/constituency/utils.py +375 -0
- stanza/stanza/models/coref/predict.py +55 -0
- stanza/stanza/models/coref/span_predictor.py +146 -0
- stanza/stanza/models/coref/tokenizer_customization.py +18 -0
- stanza/stanza/models/coref/word_encoder.py +108 -0
- stanza/stanza/models/depparse/data.py +233 -0
- stanza/stanza/models/lemma/attach_lemma_classifier.py +25 -0
- stanza/stanza/models/lemma/scorer.py +13 -0
- stanza/stanza/models/lemma/vocab.py +18 -0
- stanza/stanza/models/lemma_classifier/base_trainer.py +114 -0
- stanza/stanza/models/lemma_classifier/constants.py +14 -0
- stanza/stanza/models/lemma_classifier/evaluate_many.py +68 -0
- stanza/stanza/models/lemma_classifier/evaluate_models.py +228 -0
- stanza/stanza/models/lemma_classifier/prepare_dataset.py +125 -0
- stanza/stanza/models/lemma_classifier/train_lstm_model.py +147 -0
- stanza/stanza/models/lemma_classifier/train_many.py +155 -0
- stanza/stanza/models/lemma_classifier/train_transformer_model.py +130 -0
- stanza/stanza/models/lemma_classifier/transformer_model.py +89 -0
- stanza/stanza/models/lemma_classifier/utils.py +173 -0
- stanza/stanza/models/mwt/character_classifier.py +65 -0
- stanza/stanza/models/mwt/trainer.py +218 -0
- stanza/stanza/models/mwt/vocab.py +19 -0
- stanza/stanza/models/ner/vocab.py +56 -0
- stanza/stanza/models/pos/__init__.py +0 -0
- stanza/stanza/models/pos/build_xpos_vocab_factory.py +144 -0
- stanza/stanza/models/pos/data.py +387 -0
- stanza/stanza/models/pos/model.py +256 -0
- stanza/stanza/models/pos/trainer.py +179 -0
- stanza/stanza/models/pos/xpos_vocab_factory.py +200 -0
- stanza/stanza/models/pos/xpos_vocab_utils.py +48 -0
- stanza/stanza/models/tokenization/__init__.py +0 -0
- stanza/stanza/models/tokenization/data.py +432 -0
- stanza/stanza/models/tokenization/model.py +101 -0
- stanza/stanza/models/tokenization/tokenize_files.py +83 -0
- stanza/stanza/models/tokenization/trainer.py +102 -0
- stanza/stanza/utils/datasets/constituency/convert_ctb.py +224 -0
- stanza/stanza/utils/datasets/constituency/extract_silver_dataset.py +47 -0
- stanza/stanza/utils/datasets/coref/balance_languages.py +60 -0
stanza/stanza/models/classifiers/base_classifier.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
A base classifier type
|
| 12 |
+
|
| 13 |
+
Currently, has the ability to process text or other inputs in a manner
|
| 14 |
+
suitable for the particular model type.
|
| 15 |
+
In other words, the CNNClassifier processes lists of words,
|
| 16 |
+
and the ConstituencyClassifier processes trees
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger('stanza')
|
| 20 |
+
|
| 21 |
+
class BaseClassifier(ABC, nn.Module):
    """
    Abstract base type for the sentence classifiers.

    Subclasses (eg the CNN classifier) implement extract_sentences to pull
    the representation they need out of a Document, and may override
    preprocess_sentences to normalize text before labeling.
    """

    @abstractmethod
    def extract_sentences(self, doc):
        """
        Extract the sentences or the relevant information in the sentences from a document
        """

    def preprocess_sentences(self, sentences):
        """
        By default, don't do anything

        Subclasses may override this to normalize text before forward()
        """
        return sentences

    def label_sentences(self, sentences, batch_size=None):
        """
        Given a list of sentences, return the model's results on that text.

        sentences: a list of sentences, in whatever form the subclass's
          forward() accepts
        batch_size: if None, run all sentences in one batch; otherwise,
          sort by length and process in batches of this size, then
          unsort the results back to the caller's order

        Returns a list of predicted label indices, one per input sentence
        (the argmax over the model's output logits).
        """
        self.eval()

        sentences = self.preprocess_sentences(sentences)

        if batch_size is None:
            intervals = [(0, len(sentences))]
            orig_idx = None
        else:
            # sorting by length improves batching efficiency;
            # orig_idx lets us restore the original order afterwards
            sentences, orig_idx = sort_with_indices(sentences, key=len, reverse=True)
            intervals = split_into_batches(sentences, batch_size)
        labels = []
        for interval in intervals:
            if interval[1] - interval[0] == 0:
                # this can happen for empty text
                continue
            output = self(sentences[interval[0]:interval[1]])
            predicted = torch.argmax(output, dim=1)
            labels.extend(predicted.tolist())

        # explicit None check: orig_idx is a sentinel, and an empty index
        # list (empty input with a batch_size) should still take this path
        if orig_idx is not None:
            sentences = unsort(sentences, orig_idx)
            labels = unsort(labels, orig_idx)

        logger.debug("Found labels")
        for (label, sentence) in zip(labels, sentences):
            logger.debug((label, sentence))

        return labels
|
stanza/stanza/models/classifiers/cnn_classifier.py
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
import stanza.models.classifiers.data as data
|
| 14 |
+
from stanza.models.classifiers.base_classifier import BaseClassifier
|
| 15 |
+
from stanza.models.classifiers.config import CNNConfig
|
| 16 |
+
from stanza.models.classifiers.data import SentimentDatum
|
| 17 |
+
from stanza.models.classifiers.utils import ExtraVectors, ModelType, build_output_layers
|
| 18 |
+
from stanza.models.common.bert_embedding import extract_bert_embeddings
|
| 19 |
+
from stanza.models.common.data import get_long_tensor, sort_all
|
| 20 |
+
from stanza.models.common.utils import attach_bert_model
|
| 21 |
+
from stanza.models.common.vocab import PAD_ID, UNK_ID
|
| 22 |
+
|
| 23 |
+
"""
|
| 24 |
+
The CNN classifier is based on Yoon Kim's work:
|
| 25 |
+
|
| 26 |
+
https://arxiv.org/abs/1408.5882
|
| 27 |
+
|
| 28 |
+
Also included are maxpool 2d, conv 2d, and a bilstm, as in
|
| 29 |
+
|
| 30 |
+
Text Classification Improved by Integrating Bidirectional LSTM
|
| 31 |
+
with Two-dimensional Max Pooling
|
| 32 |
+
https://aclanthology.org/C16-1329.pdf
|
| 33 |
+
|
| 34 |
+
The architecture is simple:
|
| 35 |
+
|
| 36 |
+
- Embedding at the bottom layer
|
| 37 |
+
- separate learnable entry for UNK, since many of the embeddings we have use 0 for UNK
|
| 38 |
+
- maybe a bilstm layer, as per a command line flag
|
| 39 |
+
- Some number of conv2d layers over the embedding
|
| 40 |
+
- Maxpool layers over small windows, window size being a parameter
|
| 41 |
+
- FC layer to the classification layer
|
| 42 |
+
|
| 43 |
+
One experiment which was run and found to be a bit of a negative was
|
| 44 |
+
putting a layer on top of the pretrain. You would think that might
|
| 45 |
+
help, but dev performance went down for each variation of
|
| 46 |
+
- trans(emb)
|
| 47 |
+
- relu(trans(emb))
|
| 48 |
+
- dropout(trans(emb))
|
| 49 |
+
- dropout(relu(trans(emb)))
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
logger = logging.getLogger('stanza')
|
| 53 |
+
tlogger = logging.getLogger('stanza.classifiers.trainer')
|
| 54 |
+
|
| 55 |
+
class CNNClassifier(BaseClassifier):
|
| 56 |
+
def __init__(self, pretrain, extra_vocab, labels,
|
| 57 |
+
charmodel_forward, charmodel_backward, elmo_model, bert_model, bert_tokenizer, force_bert_saved, peft_name,
|
| 58 |
+
args):
|
| 59 |
+
"""
|
| 60 |
+
pretrain is a pretrained word embedding. should have .emb and .vocab
|
| 61 |
+
|
| 62 |
+
extra_vocab is a collection of words in the training data to
|
| 63 |
+
be used for the delta word embedding, if used. can be set to
|
| 64 |
+
None if delta word embedding is not used.
|
| 65 |
+
|
| 66 |
+
labels is the list of labels we expect in the training data.
|
| 67 |
+
Used to derive the number of classes. Saving it in the model
|
| 68 |
+
will let us check that test data has the same labels
|
| 69 |
+
|
| 70 |
+
args is either the complete arguments when training, or the
|
| 71 |
+
subset of arguments stored in the model save file
|
| 72 |
+
"""
|
| 73 |
+
super(CNNClassifier, self).__init__()
|
| 74 |
+
self.labels = labels
|
| 75 |
+
bert_finetune = args.bert_finetune
|
| 76 |
+
use_peft = args.use_peft
|
| 77 |
+
force_bert_saved = force_bert_saved or bert_finetune
|
| 78 |
+
logger.debug("bert_finetune %s / force_bert_saved %s", bert_finetune, force_bert_saved)
|
| 79 |
+
|
| 80 |
+
# this may change when loaded in a new Pipeline, so it's not part of the config
|
| 81 |
+
self.peft_name = peft_name
|
| 82 |
+
|
| 83 |
+
# we build a separate config out of the args so that we can easily save it in torch
|
| 84 |
+
self.config = CNNConfig(filter_channels = args.filter_channels,
|
| 85 |
+
filter_sizes = args.filter_sizes,
|
| 86 |
+
fc_shapes = args.fc_shapes,
|
| 87 |
+
dropout = args.dropout,
|
| 88 |
+
num_classes = len(labels),
|
| 89 |
+
wordvec_type = args.wordvec_type,
|
| 90 |
+
extra_wordvec_method = args.extra_wordvec_method,
|
| 91 |
+
extra_wordvec_dim = args.extra_wordvec_dim,
|
| 92 |
+
extra_wordvec_max_norm = args.extra_wordvec_max_norm,
|
| 93 |
+
char_lowercase = args.char_lowercase,
|
| 94 |
+
charlm_projection = args.charlm_projection,
|
| 95 |
+
has_charlm_forward = charmodel_forward is not None,
|
| 96 |
+
has_charlm_backward = charmodel_backward is not None,
|
| 97 |
+
use_elmo = args.use_elmo,
|
| 98 |
+
elmo_projection = args.elmo_projection,
|
| 99 |
+
bert_model = args.bert_model,
|
| 100 |
+
bert_finetune = bert_finetune,
|
| 101 |
+
bert_hidden_layers = args.bert_hidden_layers,
|
| 102 |
+
force_bert_saved = force_bert_saved,
|
| 103 |
+
|
| 104 |
+
use_peft = use_peft,
|
| 105 |
+
lora_rank = args.lora_rank,
|
| 106 |
+
lora_alpha = args.lora_alpha,
|
| 107 |
+
lora_dropout = args.lora_dropout,
|
| 108 |
+
lora_modules_to_save = args.lora_modules_to_save,
|
| 109 |
+
lora_target_modules = args.lora_target_modules,
|
| 110 |
+
|
| 111 |
+
bilstm = args.bilstm,
|
| 112 |
+
bilstm_hidden_dim = args.bilstm_hidden_dim,
|
| 113 |
+
maxpool_width = args.maxpool_width,
|
| 114 |
+
model_type = ModelType.CNN)
|
| 115 |
+
|
| 116 |
+
self.char_lowercase = args.char_lowercase
|
| 117 |
+
|
| 118 |
+
self.unsaved_modules = []
|
| 119 |
+
|
| 120 |
+
emb_matrix = pretrain.emb
|
| 121 |
+
self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
|
| 122 |
+
self.add_unsaved_module('elmo_model', elmo_model)
|
| 123 |
+
self.vocab_size = emb_matrix.shape[0]
|
| 124 |
+
self.embedding_dim = emb_matrix.shape[1]
|
| 125 |
+
|
| 126 |
+
self.add_unsaved_module('forward_charlm', charmodel_forward)
|
| 127 |
+
if charmodel_forward is not None:
|
| 128 |
+
tlogger.debug("Got forward char model of dimension {}".format(charmodel_forward.hidden_dim()))
|
| 129 |
+
if not charmodel_forward.is_forward_lm:
|
| 130 |
+
raise ValueError("Got a backward charlm as a forward charlm!")
|
| 131 |
+
self.add_unsaved_module('backward_charlm', charmodel_backward)
|
| 132 |
+
if charmodel_backward is not None:
|
| 133 |
+
tlogger.debug("Got backward char model of dimension {}".format(charmodel_backward.hidden_dim()))
|
| 134 |
+
if charmodel_backward.is_forward_lm:
|
| 135 |
+
raise ValueError("Got a forward charlm as a backward charlm!")
|
| 136 |
+
|
| 137 |
+
attach_bert_model(self, bert_model, bert_tokenizer, self.config.use_peft, force_bert_saved)
|
| 138 |
+
|
| 139 |
+
# The Pretrain has PAD and UNK already (indices 0 and 1), but we
|
| 140 |
+
# possibly want to train UNK while freezing the rest of the embedding
|
| 141 |
+
# note that the /10.0 operation has to be inside nn.Parameter unless
|
| 142 |
+
# you want to spend a long time debugging this
|
| 143 |
+
self.unk = nn.Parameter(torch.randn(self.embedding_dim) / np.sqrt(self.embedding_dim) / 10.0)
|
| 144 |
+
|
| 145 |
+
# replacing NBSP picks up a whole bunch of words for VI
|
| 146 |
+
self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pretrain.vocab) }
|
| 147 |
+
|
| 148 |
+
if self.config.extra_wordvec_method is not ExtraVectors.NONE:
|
| 149 |
+
if not extra_vocab:
|
| 150 |
+
raise ValueError("Should have had extra_vocab set for extra_wordvec_method {}".format(self.config.extra_wordvec_method))
|
| 151 |
+
if not args.extra_wordvec_dim:
|
| 152 |
+
self.config.extra_wordvec_dim = self.embedding_dim
|
| 153 |
+
if self.config.extra_wordvec_method is ExtraVectors.SUM:
|
| 154 |
+
if self.config.extra_wordvec_dim != self.embedding_dim:
|
| 155 |
+
raise ValueError("extra_wordvec_dim must equal embedding_dim for {}".format(self.config.extra_wordvec_method))
|
| 156 |
+
|
| 157 |
+
self.extra_vocab = list(extra_vocab)
|
| 158 |
+
self.extra_vocab_map = { word: i for i, word in enumerate(self.extra_vocab) }
|
| 159 |
+
# TODO: possibly add regularization specifically on the extra embedding?
|
| 160 |
+
# note: it looks like a bug that this doesn't add UNK or PAD, but actually
|
| 161 |
+
# those are expected to already be the first two entries
|
| 162 |
+
self.extra_embedding = nn.Embedding(num_embeddings = len(extra_vocab),
|
| 163 |
+
embedding_dim = self.config.extra_wordvec_dim,
|
| 164 |
+
max_norm = self.config.extra_wordvec_max_norm,
|
| 165 |
+
padding_idx = 0)
|
| 166 |
+
tlogger.debug("Extra embedding size: {}".format(self.extra_embedding.weight.shape))
|
| 167 |
+
else:
|
| 168 |
+
self.extra_vocab = None
|
| 169 |
+
self.extra_vocab_map = None
|
| 170 |
+
self.config.extra_wordvec_dim = 0
|
| 171 |
+
self.extra_embedding = None
|
| 172 |
+
|
| 173 |
+
# Pytorch is "aware" of the existence of the nn.Modules inside
|
| 174 |
+
# an nn.ModuleList in terms of parameters() etc
|
| 175 |
+
if self.config.extra_wordvec_method is ExtraVectors.NONE:
|
| 176 |
+
total_embedding_dim = self.embedding_dim
|
| 177 |
+
elif self.config.extra_wordvec_method is ExtraVectors.SUM:
|
| 178 |
+
total_embedding_dim = self.embedding_dim
|
| 179 |
+
elif self.config.extra_wordvec_method is ExtraVectors.CONCAT:
|
| 180 |
+
total_embedding_dim = self.embedding_dim + self.config.extra_wordvec_dim
|
| 181 |
+
else:
|
| 182 |
+
raise ValueError("unable to handle {}".format(self.config.extra_wordvec_method))
|
| 183 |
+
|
| 184 |
+
if charmodel_forward is not None:
|
| 185 |
+
if args.charlm_projection:
|
| 186 |
+
self.charmodel_forward_projection = nn.Linear(charmodel_forward.hidden_dim(), args.charlm_projection)
|
| 187 |
+
total_embedding_dim += args.charlm_projection
|
| 188 |
+
else:
|
| 189 |
+
self.charmodel_forward_projection = None
|
| 190 |
+
total_embedding_dim += charmodel_forward.hidden_dim()
|
| 191 |
+
|
| 192 |
+
if charmodel_backward is not None:
|
| 193 |
+
if args.charlm_projection:
|
| 194 |
+
self.charmodel_backward_projection = nn.Linear(charmodel_backward.hidden_dim(), args.charlm_projection)
|
| 195 |
+
total_embedding_dim += args.charlm_projection
|
| 196 |
+
else:
|
| 197 |
+
self.charmodel_backward_projection = None
|
| 198 |
+
total_embedding_dim += charmodel_backward.hidden_dim()
|
| 199 |
+
|
| 200 |
+
if self.config.use_elmo:
|
| 201 |
+
if elmo_model is None:
|
| 202 |
+
raise ValueError("Model requires elmo, but elmo_model not passed in")
|
| 203 |
+
elmo_dim = elmo_model.sents2elmo([["Test"]])[0].shape[1]
|
| 204 |
+
|
| 205 |
+
# this mapping will combine 3 layers of elmo to 1 layer of features
|
| 206 |
+
self.elmo_combine_layers = nn.Linear(in_features=3, out_features=1, bias=False)
|
| 207 |
+
if self.config.elmo_projection:
|
| 208 |
+
self.elmo_projection = nn.Linear(in_features=elmo_dim, out_features=self.config.elmo_projection)
|
| 209 |
+
total_embedding_dim = total_embedding_dim + self.config.elmo_projection
|
| 210 |
+
else:
|
| 211 |
+
total_embedding_dim = total_embedding_dim + elmo_dim
|
| 212 |
+
|
| 213 |
+
if bert_model is not None:
|
| 214 |
+
if self.config.bert_hidden_layers:
|
| 215 |
+
# The average will be offset by 1/N so that the default zeros
|
| 216 |
+
# repressents an average of the N layers
|
| 217 |
+
if self.config.bert_hidden_layers > bert_model.config.num_hidden_layers:
|
| 218 |
+
# limit ourselves to the number of layers actually available
|
| 219 |
+
# note that we can +1 because of the initial embedding layer
|
| 220 |
+
self.config.bert_hidden_layers = bert_model.config.num_hidden_layers + 1
|
| 221 |
+
self.bert_layer_mix = nn.Linear(self.config.bert_hidden_layers, 1, bias=False)
|
| 222 |
+
nn.init.zeros_(self.bert_layer_mix.weight)
|
| 223 |
+
else:
|
| 224 |
+
# an average of layers 2, 3, 4 will be used
|
| 225 |
+
# (for historic reasons)
|
| 226 |
+
self.bert_layer_mix = None
|
| 227 |
+
|
| 228 |
+
if bert_tokenizer is None:
|
| 229 |
+
raise ValueError("Cannot have a bert model without a tokenizer")
|
| 230 |
+
self.bert_dim = self.bert_model.config.hidden_size
|
| 231 |
+
total_embedding_dim += self.bert_dim
|
| 232 |
+
|
| 233 |
+
if self.config.bilstm:
|
| 234 |
+
conv_input_dim = self.config.bilstm_hidden_dim * 2
|
| 235 |
+
self.bilstm = nn.LSTM(batch_first=True,
|
| 236 |
+
input_size=total_embedding_dim,
|
| 237 |
+
hidden_size=self.config.bilstm_hidden_dim,
|
| 238 |
+
num_layers=2,
|
| 239 |
+
bidirectional=True,
|
| 240 |
+
dropout=0.2)
|
| 241 |
+
else:
|
| 242 |
+
conv_input_dim = total_embedding_dim
|
| 243 |
+
self.bilstm = None
|
| 244 |
+
|
| 245 |
+
self.fc_input_size = 0
|
| 246 |
+
self.conv_layers = nn.ModuleList()
|
| 247 |
+
self.max_window = 0
|
| 248 |
+
for filter_idx, filter_size in enumerate(self.config.filter_sizes):
|
| 249 |
+
if isinstance(filter_size, int):
|
| 250 |
+
self.max_window = max(self.max_window, filter_size)
|
| 251 |
+
if isinstance(self.config.filter_channels, int):
|
| 252 |
+
filter_channels = self.config.filter_channels
|
| 253 |
+
else:
|
| 254 |
+
filter_channels = self.config.filter_channels[filter_idx]
|
| 255 |
+
fc_delta = filter_channels // self.config.maxpool_width
|
| 256 |
+
tlogger.debug("Adding full width filter %d. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
|
| 257 |
+
self.fc_input_size += fc_delta
|
| 258 |
+
self.conv_layers.append(nn.Conv2d(in_channels=1,
|
| 259 |
+
out_channels=filter_channels,
|
| 260 |
+
kernel_size=(filter_size, conv_input_dim)))
|
| 261 |
+
elif isinstance(filter_size, tuple) and len(filter_size) == 2:
|
| 262 |
+
filter_height, filter_width = filter_size
|
| 263 |
+
self.max_window = max(self.max_window, filter_width)
|
| 264 |
+
if isinstance(self.config.filter_channels, int):
|
| 265 |
+
filter_channels = max(1, self.config.filter_channels // (conv_input_dim // filter_width))
|
| 266 |
+
else:
|
| 267 |
+
filter_channels = self.config.filter_channels[filter_idx]
|
| 268 |
+
fc_delta = filter_channels * (conv_input_dim // filter_width) // self.config.maxpool_width
|
| 269 |
+
tlogger.debug("Adding filter %s. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
|
| 270 |
+
self.fc_input_size += fc_delta
|
| 271 |
+
self.conv_layers.append(nn.Conv2d(in_channels=1,
|
| 272 |
+
out_channels=filter_channels,
|
| 273 |
+
stride=(1, filter_width),
|
| 274 |
+
kernel_size=(filter_height, filter_width)))
|
| 275 |
+
else:
|
| 276 |
+
raise ValueError("Expected int or 2d tuple for conv size")
|
| 277 |
+
|
| 278 |
+
tlogger.debug("Input dim to FC layers: %d", self.fc_input_size)
|
| 279 |
+
self.fc_layers = build_output_layers(self.fc_input_size, self.config.fc_shapes, self.config.num_classes)
|
| 280 |
+
|
| 281 |
+
self.dropout = nn.Dropout(self.config.dropout)
|
| 282 |
+
|
| 283 |
+
def add_unsaved_module(self, name, module):
|
| 284 |
+
self.unsaved_modules += [name]
|
| 285 |
+
setattr(self, name, module)
|
| 286 |
+
|
| 287 |
+
if module is not None and (name in ('forward_charlm', 'backward_charlm') or
|
| 288 |
+
(name == 'bert_model' and not self.config.use_peft)):
|
| 289 |
+
# if we are using peft, we should not save the transformer directly
|
| 290 |
+
# instead, the peft parameters only will be saved later
|
| 291 |
+
for _, parameter in module.named_parameters():
|
| 292 |
+
parameter.requires_grad = False
|
| 293 |
+
|
| 294 |
+
def is_unsaved_module(self, name):
|
| 295 |
+
return name.split('.')[0] in self.unsaved_modules
|
| 296 |
+
|
| 297 |
+
def log_configuration(self):
|
| 298 |
+
"""
|
| 299 |
+
Log some essential information about the model configuration to the training logger
|
| 300 |
+
"""
|
| 301 |
+
tlogger.info("Filter sizes: %s" % str(self.config.filter_sizes))
|
| 302 |
+
tlogger.info("Filter channels: %s" % str(self.config.filter_channels))
|
| 303 |
+
tlogger.info("Intermediate layers: %s" % str(self.config.fc_shapes))
|
| 304 |
+
|
| 305 |
+
def log_norms(self):
|
| 306 |
+
lines = ["NORMS FOR MODEL PARAMTERS"]
|
| 307 |
+
for name, param in self.named_parameters():
|
| 308 |
+
if param.requires_grad and name.split(".")[0] not in ('forward_charlm', 'backward_charlm'):
|
| 309 |
+
lines.append("%s %.6g" % (name, torch.norm(param).item()))
|
| 310 |
+
logger.info("\n".join(lines))
|
| 311 |
+
|
| 312 |
+
def build_char_reps(self, inputs, max_phrase_len, charlm, projection, begin_paddings, device):
|
| 313 |
+
char_reps = charlm.build_char_representation(inputs)
|
| 314 |
+
if projection is not None:
|
| 315 |
+
char_reps = [projection(x) for x in char_reps]
|
| 316 |
+
char_inputs = torch.zeros((len(inputs), max_phrase_len, char_reps[0].shape[-1]), device=device)
|
| 317 |
+
for idx, rep in enumerate(char_reps):
|
| 318 |
+
start = begin_paddings[idx]
|
| 319 |
+
end = start + rep.shape[0]
|
| 320 |
+
char_inputs[idx, start:end, :] = rep
|
| 321 |
+
return char_inputs
|
| 322 |
+
|
| 323 |
+
def extract_bert_embeddings(self, inputs, max_phrase_len, begin_paddings, device):
|
| 324 |
+
bert_embeddings = extract_bert_embeddings(self.config.bert_model, self.bert_tokenizer, self.bert_model, inputs, device,
|
| 325 |
+
keep_endpoints=False,
|
| 326 |
+
num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
|
| 327 |
+
detach=not self.config.bert_finetune,
|
| 328 |
+
peft_name=self.peft_name)
|
| 329 |
+
if self.bert_layer_mix is not None:
|
| 330 |
+
# add the average so that the default behavior is to
|
| 331 |
+
# take an average of the N layers, and anything else
|
| 332 |
+
# other than that needs to be learned
|
| 333 |
+
bert_embeddings = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in bert_embeddings]
|
| 334 |
+
bert_inputs = torch.zeros((len(inputs), max_phrase_len, bert_embeddings[0].shape[-1]), device=device)
|
| 335 |
+
for idx, rep in enumerate(bert_embeddings):
|
| 336 |
+
start = begin_paddings[idx]
|
| 337 |
+
end = start + rep.shape[0]
|
| 338 |
+
bert_inputs[idx, start:end, :] = rep
|
| 339 |
+
return bert_inputs
|
| 340 |
+
|
| 341 |
+
def forward(self, inputs):
|
| 342 |
+
# assume all pieces are on the same device
|
| 343 |
+
device = next(self.parameters()).device
|
| 344 |
+
|
| 345 |
+
vocab_map = self.vocab_map
|
| 346 |
+
def map_word(word):
|
| 347 |
+
idx = vocab_map.get(word, None)
|
| 348 |
+
if idx is not None:
|
| 349 |
+
return idx
|
| 350 |
+
if word[-1] == "'":
|
| 351 |
+
idx = vocab_map.get(word[:-1], None)
|
| 352 |
+
if idx is not None:
|
| 353 |
+
return idx
|
| 354 |
+
return vocab_map.get(word.lower(), UNK_ID)
|
| 355 |
+
|
| 356 |
+
inputs = [x.text if isinstance(x, SentimentDatum) else x for x in inputs]
|
| 357 |
+
# we will pad each phrase so either it matches the longest
|
| 358 |
+
# conv or the longest phrase in the input, whichever is longer
|
| 359 |
+
max_phrase_len = max(len(x) for x in inputs)
|
| 360 |
+
if self.max_window > max_phrase_len:
|
| 361 |
+
max_phrase_len = self.max_window
|
| 362 |
+
|
| 363 |
+
batch_indices = []
|
| 364 |
+
batch_unknowns = []
|
| 365 |
+
extra_batch_indices = []
|
| 366 |
+
begin_paddings = []
|
| 367 |
+
end_paddings = []
|
| 368 |
+
|
| 369 |
+
elmo_batch_words = []
|
| 370 |
+
|
| 371 |
+
for phrase in inputs:
|
| 372 |
+
# we use random at training time to try to learn different
|
| 373 |
+
# positions of padding. at test time, though, we want to
|
| 374 |
+
# have consistent results, so we set that to 0 begin_pad
|
| 375 |
+
if self.training:
|
| 376 |
+
begin_pad_width = random.randint(0, max_phrase_len - len(phrase))
|
| 377 |
+
else:
|
| 378 |
+
begin_pad_width = 0
|
| 379 |
+
end_pad_width = max_phrase_len - begin_pad_width - len(phrase)
|
| 380 |
+
|
| 381 |
+
begin_paddings.append(begin_pad_width)
|
| 382 |
+
end_paddings.append(end_pad_width)
|
| 383 |
+
|
| 384 |
+
# the initial lists are the length of the begin padding
|
| 385 |
+
sentence_indices = [PAD_ID] * begin_pad_width
|
| 386 |
+
sentence_indices.extend([map_word(x) for x in phrase])
|
| 387 |
+
sentence_indices.extend([PAD_ID] * end_pad_width)
|
| 388 |
+
|
| 389 |
+
# the "unknowns" will be the locations of the unknown words.
|
| 390 |
+
# these locations will get the specially trained unknown vector
|
| 391 |
+
# TODO: split UNK based on part of speech? might be an interesting experiment
|
| 392 |
+
sentence_unknowns = [idx for idx, word in enumerate(sentence_indices) if word == UNK_ID]
|
| 393 |
+
|
| 394 |
+
batch_indices.append(sentence_indices)
|
| 395 |
+
batch_unknowns.append(sentence_unknowns)
|
| 396 |
+
|
| 397 |
+
if self.extra_vocab:
|
| 398 |
+
extra_sentence_indices = [PAD_ID] * begin_pad_width
|
| 399 |
+
for word in phrase:
|
| 400 |
+
if word in self.extra_vocab_map:
|
| 401 |
+
# the extra vocab is initialized from the
|
| 402 |
+
# words in the training set, which means there
|
| 403 |
+
# would be no unknown words. to occasionally
|
| 404 |
+
# train the extra vocab's unknown words, we
|
| 405 |
+
# replace 1% of the words with UNK
|
| 406 |
+
# we don't do that for the original embedding
|
| 407 |
+
# on the assumption that there may be some
|
| 408 |
+
# unknown words in the training set anyway
|
| 409 |
+
# TODO: maybe train unk for the original embedding?
|
| 410 |
+
if self.training and random.random() < 0.01:
|
| 411 |
+
extra_sentence_indices.append(UNK_ID)
|
| 412 |
+
else:
|
| 413 |
+
extra_sentence_indices.append(self.extra_vocab_map[word])
|
| 414 |
+
else:
|
| 415 |
+
extra_sentence_indices.append(UNK_ID)
|
| 416 |
+
extra_sentence_indices.extend([PAD_ID] * end_pad_width)
|
| 417 |
+
extra_batch_indices.append(extra_sentence_indices)
|
| 418 |
+
|
| 419 |
+
if self.config.use_elmo:
|
| 420 |
+
elmo_phrase_words = [""] * begin_pad_width
|
| 421 |
+
for word in phrase:
|
| 422 |
+
elmo_phrase_words.append(word)
|
| 423 |
+
elmo_phrase_words.extend([""] * end_pad_width)
|
| 424 |
+
elmo_batch_words.append(elmo_phrase_words)
|
| 425 |
+
|
| 426 |
+
# creating a single large list with all the indices lets us
|
| 427 |
+
# create a single tensor, which is much faster than creating
|
| 428 |
+
# many tiny tensors
|
| 429 |
+
# we can convert this to the input to the CNN
|
| 430 |
+
# it is padded at one or both ends so that it is now num_phrases x max_len x emb_size
|
| 431 |
+
# there are two ways in which this padding is suboptimal
|
| 432 |
+
# the first is that for short sentences, smaller windows will
|
| 433 |
+
# be padded to the point that some windows are entirely pad
|
| 434 |
+
# the second is that a sentence S will have more or less padding
|
| 435 |
+
# depending on what other sentences are in its batch
|
| 436 |
+
# we assume these effects are pretty minimal
|
| 437 |
+
batch_indices = torch.tensor(batch_indices, requires_grad=False, device=device)
|
| 438 |
+
input_vectors = self.embedding(batch_indices)
|
| 439 |
+
# we use the random unk so that we are not necessarily
|
| 440 |
+
# learning to match 0s for unk
|
| 441 |
+
for phrase_num, sentence_unknowns in enumerate(batch_unknowns):
|
| 442 |
+
input_vectors[phrase_num][sentence_unknowns] = self.unk
|
| 443 |
+
|
| 444 |
+
if self.extra_vocab:
|
| 445 |
+
extra_batch_indices = torch.tensor(extra_batch_indices, requires_grad=False, device=device)
|
| 446 |
+
extra_input_vectors = self.extra_embedding(extra_batch_indices)
|
| 447 |
+
if self.config.extra_wordvec_method is ExtraVectors.CONCAT:
|
| 448 |
+
all_inputs = [input_vectors, extra_input_vectors]
|
| 449 |
+
elif self.config.extra_wordvec_method is ExtraVectors.SUM:
|
| 450 |
+
all_inputs = [input_vectors + extra_input_vectors]
|
| 451 |
+
else:
|
| 452 |
+
raise ValueError("unable to handle {}".format(self.config.extra_wordvec_method))
|
| 453 |
+
else:
|
| 454 |
+
all_inputs = [input_vectors]
|
| 455 |
+
|
| 456 |
+
if self.forward_charlm is not None:
|
| 457 |
+
char_reps_forward = self.build_char_reps(inputs, max_phrase_len, self.forward_charlm, self.charmodel_forward_projection, begin_paddings, device)
|
| 458 |
+
all_inputs.append(char_reps_forward)
|
| 459 |
+
|
| 460 |
+
if self.backward_charlm is not None:
|
| 461 |
+
char_reps_backward = self.build_char_reps(inputs, max_phrase_len, self.backward_charlm, self.charmodel_backward_projection, begin_paddings, device)
|
| 462 |
+
all_inputs.append(char_reps_backward)
|
| 463 |
+
|
| 464 |
+
if self.config.use_elmo:
|
| 465 |
+
# this will be N arrays of 3xMx1024 where M is the number of words
|
| 466 |
+
# and N is the number of sentences (and 1024 is actually the number of weights)
|
| 467 |
+
elmo_arrays = self.elmo_model.sents2elmo(elmo_batch_words, output_layer=-2)
|
| 468 |
+
elmo_tensors = [torch.tensor(x).to(device=device) for x in elmo_arrays]
|
| 469 |
+
# elmo_tensor will now be Nx3xMx1024
|
| 470 |
+
elmo_tensor = torch.stack(elmo_tensors)
|
| 471 |
+
# Nx1024xMx3
|
| 472 |
+
elmo_tensor = torch.transpose(elmo_tensor, 1, 3)
|
| 473 |
+
# NxMx1024x3
|
| 474 |
+
elmo_tensor = torch.transpose(elmo_tensor, 1, 2)
|
| 475 |
+
# NxMx1024x1
|
| 476 |
+
elmo_tensor = self.elmo_combine_layers(elmo_tensor)
|
| 477 |
+
# NxMx1024
|
| 478 |
+
elmo_tensor = elmo_tensor.squeeze(3)
|
| 479 |
+
if self.config.elmo_projection:
|
| 480 |
+
elmo_tensor = self.elmo_projection(elmo_tensor)
|
| 481 |
+
all_inputs.append(elmo_tensor)
|
| 482 |
+
|
| 483 |
+
if self.bert_model is not None:
|
| 484 |
+
bert_embeddings = self.extract_bert_embeddings(inputs, max_phrase_len, begin_paddings, device)
|
| 485 |
+
all_inputs.append(bert_embeddings)
|
| 486 |
+
|
| 487 |
+
# still works even if there's just one item
|
| 488 |
+
input_vectors = torch.cat(all_inputs, dim=2)
|
| 489 |
+
|
| 490 |
+
if self.config.bilstm:
|
| 491 |
+
input_vectors, _ = self.bilstm(self.dropout(input_vectors))
|
| 492 |
+
|
| 493 |
+
# reshape to fit the input tensors
|
| 494 |
+
x = input_vectors.unsqueeze(1)
|
| 495 |
+
|
| 496 |
+
conv_outs = []
|
| 497 |
+
for conv, filter_size in zip(self.conv_layers, self.config.filter_sizes):
|
| 498 |
+
if isinstance(filter_size, int):
|
| 499 |
+
conv_out = self.dropout(F.relu(conv(x).squeeze(3)))
|
| 500 |
+
conv_outs.append(conv_out)
|
| 501 |
+
else:
|
| 502 |
+
conv_out = conv(x).transpose(2, 3).flatten(1, 2)
|
| 503 |
+
conv_out = self.dropout(F.relu(conv_out))
|
| 504 |
+
conv_outs.append(conv_out)
|
| 505 |
+
pool_outs = [F.max_pool2d(out, (self.config.maxpool_width, out.shape[2])).squeeze(2) for out in conv_outs]
|
| 506 |
+
pooled = torch.cat(pool_outs, dim=1)
|
| 507 |
+
|
| 508 |
+
previous_layer = pooled
|
| 509 |
+
for fc in self.fc_layers[:-1]:
|
| 510 |
+
previous_layer = self.dropout(F.relu(fc(previous_layer)))
|
| 511 |
+
out = self.fc_layers[-1](previous_layer)
|
| 512 |
+
# note that we return the raw logits rather than use a softmax
|
| 513 |
+
# https://discuss.pytorch.org/t/multi-class-cross-entropy-loss-and-softmax-in-pytorch/24920/4
|
| 514 |
+
return out
|
| 515 |
+
|
| 516 |
+
def get_params(self, skip_modules=True):
    """
    Collect everything needed to serialize the classifier into a dict.

    skip_modules: if True, weights belonging to large pretrained modules
    (as identified by self.is_unsaved_module) are dropped from the state
    dict, because those modules are saved in a separate file.

    Returns a dict with keys 'model', 'config', 'labels', 'extra_vocab',
    plus 'bert_lora' when peft is in use.
    """
    model_state = self.state_dict()
    if skip_modules:
        # materialize the key list first so deletion doesn't disturb iteration
        for key in [k for k in model_state.keys() if self.is_unsaved_module(k)]:
            del model_state[key]

    # store enum fields by name so the config serializes as plain data
    config = dataclasses.asdict(self.config)
    for enum_field in ('wordvec_type', 'extra_wordvec_method', 'model_type'):
        config[enum_field] = config[enum_field].name

    params = {
        'model': model_state,
        'config': config,
        'labels': self.labels,
        'extra_vocab': self.extra_vocab,
    }
    if self.config.use_peft:
        # Hide import so that peft dependency is optional
        from peft import get_peft_model_state_dict
        params["bert_lora"] = get_peft_model_state_dict(self.bert_model, adapter_name=self.peft_name)
    return params
|
| 540 |
+
|
| 541 |
+
def preprocess_data(self, sentences):
    """Apply data.update_text to every sentence using the configured wordvec type."""
    return [data.update_text(sentence, self.config.wordvec_type) for sentence in sentences]
|
| 544 |
+
|
| 545 |
+
def extract_sentences(self, doc):
    """
    Pull the raw token text out of a stanza document.

    Returns one list of token strings per sentence in the document.

    TODO (from original author): tokens or words better here?
    """
    sentences = []
    for sentence in doc.sentences:
        sentences.append([token.text for token in sentence.tokens])
    return sentences
|
stanza/stanza/models/classifiers/iterate_test.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Iterate test."""
|
| 2 |
+
import argparse
|
| 3 |
+
import glob
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
import stanza.models.classifier as classifier
|
| 7 |
+
import stanza.models.classifiers.cnn_classifier as cnn_classifier
|
| 8 |
+
from stanza.models.common import utils
|
| 9 |
+
|
| 10 |
+
from stanza.utils.confusion import format_confusion, confusion_to_accuracy
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
A script for running the same test file on several different classifiers.
|
| 14 |
+
|
| 15 |
+
For each one, it will output the accuracy and, if possible, the confusion matrix.
|
| 16 |
+
|
| 17 |
+
Includes the arguments for pretrain, which allows for passing in a
|
| 18 |
+
different directory for the pretrain file.
|
| 19 |
+
|
| 20 |
+
Example command line:
|
| 21 |
+
python3 -m stanza.models.classifiers.iterate_test --test_file extern_data/sentiment/sst-processed/threeclass/test-threeclass-roots.txt --glob "saved_models/classifier/FC41_3class_en_ewt_FS*ACC66*"
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger('stanza')
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def parse_args():
    """Build the shared classifier argument parser, add --glob, and parse argv."""
    parser = classifier.build_parser()

    # one or more whitespace-separated glob patterns selecting model files
    parser.add_argument('--glob', type=str, default='saved_models/classifier/*classifier*pt', help='Model file(s) to test.')

    return parser.parse_args()
|
| 35 |
+
|
| 36 |
+
# needed for data.read_dataset below; the original script never imported it
import stanza.models.classifiers.data as data

args = parse_args()
seed = utils.set_random_seed(args.seed)

# expand every whitespace-separated glob pattern, then dedupe and sort
# so the models are tested in a stable order
model_files = []
for glob_piece in args.glob.split():
    model_files.extend(glob.glob(glob_piece))
model_files = sorted(set(model_files))

test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)
logger.info("Using test set: %s" % args.test_file)

device = None
for load_name in model_files:
    args.load_name = load_name
    model = classifier.load_model(args)
    # NOTE: the original also called cnn_classifier.load(load_name, pretrain)
    # here, clobbering the model just loaded; `pretrain` was never defined, so
    # that line could only raise NameError.  It has been removed.

    logger.info("Testing %s" % load_name)
    if device is None:
        # use whichever device the first loaded model landed on
        device = next(model.parameters()).device
        logger.info("Current device: %s" % device)

    labels = model.labels
    classifier.check_labels(labels, test_set)

    confusion = classifier.confusion_dataset(model, test_set, device=device)
    correct, total = confusion_to_accuracy(confusion)
    logger.info(" Results: %d correct of %d examples. Accuracy: %f" % (correct, total, correct / total))
    logger.info("Confusion matrix:\n{}".format(format_confusion(confusion, model.labels)))
|
stanza/stanza/models/classifiers/trainer.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Organizes the model itself and its optimizer in one place
|
| 3 |
+
|
| 4 |
+
Saving the optimizer allows for easy restarting of training
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
import os
import warnings
from pickle import UnpicklingError
from types import SimpleNamespace

import torch
import torch.optim as optim

import stanza.models.classifiers.data as data
import stanza.models.classifiers.cnn_classifier as cnn_classifier
import stanza.models.classifiers.constituency_classifier as constituency_classifier
from stanza.models.classifiers.config import CNNConfig, ConstituencyConfig
from stanza.models.classifiers.utils import ModelType, WVType, ExtraVectors
from stanza.models.common import utils
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
from stanza.models.common.pretrain import Pretrain
from stanza.models.common.utils import get_split_optimizer
from stanza.models.constituency.tree_embedding import TreeEmbedding
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger('stanza')
|
| 28 |
+
|
| 29 |
+
class Trainer:
    """
    Stores a constituency model and its optimizer

    Also tracks training progress (epochs, steps, best dev score) so that
    training can be checkpointed and resumed.
    """

    def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0, best_score=None):
        # model: a CNNClassifier or ConstituencyClassifier
        self.model = model
        # optimizer: a dict of name -> torch optimizer (see build_optimizer), or None
        self.optimizer = optimizer
        # we keep track of position in the learning so that we can
        # checkpoint & restart if needed without restarting the epoch count
        self.epochs_trained = epochs_trained
        self.global_step = global_step
        # save the best dev score so that when reloading a checkpoint
        # of a model, we know how far we got
        self.best_score = best_score

    def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):
        """
        save the current model, optimizer, and other state to filename

        epochs_trained can be passed as a parameter to handle saving at the end of an epoch
        """
        if epochs_trained is None:
            epochs_trained = self.epochs_trained
        save_dir = os.path.split(filename)[0]
        os.makedirs(save_dir, exist_ok=True)
        # the model contributes its own weights/config/labels bundle
        model_params = self.model.get_params(skip_modules)
        params = {
            'params': model_params,
            'epochs_trained': epochs_trained,
            'global_step': self.global_step,
            'best_score': self.best_score,
        }
        if save_optimizer and self.optimizer is not None:
            # one state dict per named optimizer piece
            params['optimizer_state_dict'] = {opt_name: opt.state_dict() for opt_name, opt in self.optimizer.items()}
        torch.save(params, filename, _use_new_zipfile_serialization=False)
        logger.info("Model saved to {}".format(filename))

    @staticmethod
    def load(filename, args, foundation_cache=None, load_optimizer=False):
        """
        Load a Trainer from a checkpoint file.

        Falls back to args.save_dir when the file is not found directly.
        Handles several legacy checkpoint formats (see the TODO comments),
        then rebuilds the model of the appropriate type and, optionally,
        its optimizer.
        """
        if not os.path.exists(filename):
            if args.save_dir is None:
                raise FileNotFoundError("Cannot find model in {} and args.save_dir is None".format(filename))
            elif os.path.exists(os.path.join(args.save_dir, filename)):
                filename = os.path.join(args.save_dir, filename)
            else:
                raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
        try:
            # TODO: can remove the try/except once the new version is out
            #checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
            try:
                # prefer the safe weights_only load; old checkpoints pickle
                # SimpleNamespace/Enum objects and will fail here
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
            except UnpicklingError as e:
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
                warnings.warn("The saved classifier has an old format using SimpleNamespace and/or Enum instead of a dict to store config. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the pretrained classifier using this version ASAP.")
        except BaseException:
            logger.exception("Cannot load model from {}".format(filename))
            raise
        logger.debug("Loaded model {}".format(filename))

        # training-progress fields default to "fresh run" values for old checkpoints
        epochs_trained = checkpoint.get('epochs_trained', 0)
        global_step = checkpoint.get('global_step', 0)
        best_score = checkpoint.get('best_score', None)

        # TODO: can remove this block once all models are retrained
        if 'params' not in checkpoint:
            # oldest layout: model pieces stored at the top level of the checkpoint
            model_params = {
                'model': checkpoint['model'],
                'config': checkpoint['config'],
                'labels': checkpoint['labels'],
                'extra_vocab': checkpoint['extra_vocab'],
            }
        else:
            model_params = checkpoint['params']
        # TODO: this can be removed once v1.10.0 is out
        if isinstance(model_params['config'], SimpleNamespace):
            model_params['config'] = vars(model_params['config'])
        # TODO: these isinstance can go away after 1.10.0
        model_type = model_params['config']['model_type']
        if isinstance(model_type, str):
            model_type = ModelType[model_type]
        model_params['config']['model_type'] = model_type

        if model_type == ModelType.CNN:
            # TODO: these updates are only necessary during the
            # transition to the @dataclass version of the config
            # Once those are all saved, it is no longer necessary
            # to patch existing models (since they will all be patched)
            if 'has_charlm_forward' not in model_params['config']:
                model_params['config']['has_charlm_forward'] = args.charlm_forward_file is not None
            if 'has_charlm_backward' not in model_params['config']:
                model_params['config']['has_charlm_backward'] = args.charlm_backward_file is not None
            # backfill fields added after the checkpoint was written
            for argname in ['bert_hidden_layers', 'bert_finetune', 'force_bert_saved', 'use_peft',
                            'lora_rank', 'lora_alpha', 'lora_dropout', 'lora_modules_to_save', 'lora_target_modules']:
                model_params['config'][argname] = model_params['config'].get(argname, None)
            # TODO: these isinstance can go away after 1.10.0
            if isinstance(model_params['config']['wordvec_type'], str):
                model_params['config']['wordvec_type'] = WVType[model_params['config']['wordvec_type']]
            if isinstance(model_params['config']['extra_wordvec_method'], str):
                model_params['config']['extra_wordvec_method'] = ExtraVectors[model_params['config']['extra_wordvec_method']]
            model_params['config'] = CNNConfig(**model_params['config'])

            pretrain = Trainer.load_pretrain(args, foundation_cache)
            # NOTE(review): `utils` is not among this file's visible imports --
            # confirm `from stanza.models.common import utils` exists at module level
            elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None

            if model_params['config'].has_charlm_forward:
                charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)
            else:
                charmodel_forward = None
            if model_params['config'].has_charlm_backward:
                charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)
            else:
                charmodel_backward = None

            bert_model = model_params['config'].bert_model
            # TODO: can get rid of the getattr after rebuilding all models
            use_peft = getattr(model_params['config'], 'use_peft', False)
            force_bert_saved = getattr(model_params['config'], 'force_bert_saved', False)
            peft_name = None
            if use_peft:
                # if loading a peft model, we first load the base transformer
                # the CNNClassifier code wraps the transformer in peft
                # after creating the CNNClassifier with the peft wrapper,
                # we *then* load the weights
                bert_model, bert_tokenizer, peft_name = load_bert_with_peft(bert_model, "classifier", foundation_cache)
                bert_model = load_peft_wrapper(bert_model, model_params['bert_lora'], vars(model_params['config']), logger, peft_name)
            elif force_bert_saved:
                # finetuned transformer weights will be overwritten below by
                # load_state_dict, so don't reuse a cached shared copy
                bert_model, bert_tokenizer = load_bert(bert_model)
            else:
                bert_model, bert_tokenizer = load_bert(bert_model, foundation_cache)
            model = cnn_classifier.CNNClassifier(pretrain=pretrain,
                                                 extra_vocab=model_params['extra_vocab'],
                                                 labels=model_params['labels'],
                                                 charmodel_forward=charmodel_forward,
                                                 charmodel_backward=charmodel_backward,
                                                 elmo_model=elmo_model,
                                                 bert_model=bert_model,
                                                 bert_tokenizer=bert_tokenizer,
                                                 force_bert_saved=force_bert_saved,
                                                 peft_name=peft_name,
                                                 args=model_params['config'])
        elif model_type == ModelType.CONSTITUENCY:
            # the constituency version doesn't have a peft feature yet
            # NOTE(review): this `use_peft` is not read again in this branch
            use_peft = False
            pretrain_args = {
                'wordvec_pretrain_file': args.wordvec_pretrain_file,
                'charlm_forward_file': args.charlm_forward_file,
                'charlm_backward_file': args.charlm_backward_file,
            }
            # TODO: integrate with peft for the constituency version
            tree_embedding = TreeEmbedding.model_from_params(model_params['tree_embedding'], pretrain_args, foundation_cache)
            model_params['config'] = ConstituencyConfig(**model_params['config'])
            model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
                                                                   labels=model_params['labels'],
                                                                   args=model_params['config'])
        else:
            raise ValueError("Unknown model type {}".format(model_type))
        # strict=False because modules skipped at save time (see get_params)
        # are intentionally absent from the saved state dict
        model.load_state_dict(model_params['model'], strict=False)
        model = model.to(args.device)

        logger.debug("-- MODEL CONFIG --")
        for k in model.config.__dict__:
            logger.debug(" --{}: {}".format(k, model.config.__dict__[k]))

        logger.debug("-- MODEL LABELS --")
        logger.debug(" {}".format(" ".join(model.labels)))

        optimizer = None
        if load_optimizer:
            optimizer = Trainer.build_optimizer(model, args)
            if checkpoint.get('optimizer_state_dict', None) is not None:
                for opt_name, opt_state_dict in checkpoint['optimizer_state_dict'].items():
                    optimizer[opt_name].load_state_dict(opt_state_dict)
            else:
                logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")

        trainer = Trainer(model, optimizer, epochs_trained, global_step, best_score)

        return trainer


    # NOTE(review): defined without @staticmethod but always called as
    # Trainer.load_pretrain(args, foundation_cache), which works in Python 3;
    # consider adding @staticmethod for clarity
    def load_pretrain(args, foundation_cache):
        """
        Locate and load the pretrained word vectors described by args.

        Prefers an existing .pretrain.pt file; otherwise builds a Pretrain
        from a raw word-vector file.
        """
        if args.wordvec_pretrain_file:
            pretrain_file = args.wordvec_pretrain_file
        elif args.wordvec_type:
            pretrain_file = '{}/{}.{}.pretrain.pt'.format(args.save_dir, args.shorthand, args.wordvec_type.name.lower())
        else:
            raise RuntimeError("TODO: need to get the wv type back from get_wordvec_file")

        logger.debug("Looking for pretrained vectors in {}".format(pretrain_file))
        if os.path.exists(pretrain_file):
            # this calls the module-level load_pretrain imported from
            # foundation_cache, not this method
            return load_pretrain(pretrain_file, foundation_cache)
        elif args.wordvec_raw_file:
            vec_file = args.wordvec_raw_file
            logger.debug("Pretrain not found. Looking in {}".format(vec_file))
        else:
            vec_file = utils.get_wordvec_file(args.wordvec_dir, args.shorthand, args.wordvec_type.name.lower())
            logger.debug("Pretrain not found. Looking in {}".format(vec_file))
        # building a Pretrain here will also write pretrain_file for next time
        pretrain = Pretrain(pretrain_file, vec_file, args.pretrain_max_vocab)
        logger.debug("Embedding shape: %s" % str(pretrain.emb.shape))
        return pretrain


    @staticmethod
    def build_new_model(args, train_set):
        """
        Load pretrained pieces and then build a new model
        """
        if train_set is None:
            raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")

        labels = data.dataset_labels(train_set)

        if args.model_type == ModelType.CNN:
            pretrain = Trainer.load_pretrain(args, foundation_cache=None)
            elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
            charmodel_forward = load_charlm(args.charlm_forward_file)
            charmodel_backward = load_charlm(args.charlm_backward_file)
            peft_name = None
            bert_model, bert_tokenizer = load_bert(args.bert_model)

            use_peft = getattr(args, "use_peft", False)
            if use_peft:
                peft_name = "sentiment"
                bert_model = build_peft_wrapper(bert_model, vars(args), logger, adapter_name=peft_name)

            # words seen in training but not in the pretrain get delta vectors
            extra_vocab = data.dataset_vocab(train_set)
            force_bert_saved = args.bert_finetune
            model = cnn_classifier.CNNClassifier(pretrain=pretrain,
                                                 extra_vocab=extra_vocab,
                                                 labels=labels,
                                                 charmodel_forward=charmodel_forward,
                                                 charmodel_backward=charmodel_backward,
                                                 elmo_model=elmo_model,
                                                 bert_model=bert_model,
                                                 bert_tokenizer=bert_tokenizer,
                                                 force_bert_saved=force_bert_saved,
                                                 peft_name=peft_name,
                                                 args=args)
            model = model.to(args.device)
        elif args.model_type == ModelType.CONSTITUENCY:
            # this passes flags such as "constituency_backprop" from
            # the classifier to the TreeEmbedding as the "backprop" flag
            parser_args = { x[len("constituency_"):]: y for x, y in vars(args).items() if x.startswith("constituency_") }
            parser_args.update({
                "wordvec_pretrain_file": args.wordvec_pretrain_file,
                "charlm_forward_file": args.charlm_forward_file,
                "charlm_backward_file": args.charlm_backward_file,
                "bert_model": args.bert_model,
                # we found that finetuning from the classifier output
                # all the way to the bert layers caused the bert model
                # to go astray
                # could make this an option... but it is much less accurate
                # with the Bert finetuning
                # noting that the constituency parser itself works better
                # after finetuning, of course
                "bert_finetune": False,
                "stage1_bert_finetune": False,
            })
            logger.info("Building constituency classifier using %s as the base model" % args.constituency_model)
            tree_embedding = TreeEmbedding.from_parser_file(parser_args)
            model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
                                                                   labels=labels,
                                                                   args=args)
            model = model.to(args.device)
        else:
            raise ValueError("Unhandled model type {}".format(args.model_type))

        optimizer = Trainer.build_optimizer(model, args)

        return Trainer(model, optimizer)


    @staticmethod
    def build_optimizer(model, args):
        """Build the (possibly split bert/non-bert) optimizer dict for the model."""
        return get_split_optimizer(args.optim.lower(), model, args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, bert_learning_rate=args.bert_learning_rate, bert_weight_decay=args.weight_decay * args.bert_weight_decay, is_peft=args.use_peft)
|
stanza/stanza/models/constituency/__init__.py
ADDED
|
File without changes
|
stanza/stanza/models/constituency/evaluate_treebanks.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Read multiple treebanks, score the results.
|
| 3 |
+
|
| 4 |
+
Reports the k-best score if multiple predicted treebanks are given.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
from stanza.models.constituency import tree_reader
|
| 10 |
+
from stanza.server.parser_eval import EvaluateParser, ParseResult
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def main():
    """
    Score one or more predicted treebanks against a gold treebank.

    With more than one predicted treebank, runs a k-best evaluation where
    the first prediction is treated as the canonical parse.
    """
    parser = argparse.ArgumentParser(description='Get scores for one or more treebanks against the gold')
    parser.add_argument('gold', type=str, help='Which file to load as the gold trees')
    parser.add_argument('pred', type=str, nargs='+', help='Which file(s) are the predictions. If more than one is given, the evaluation will be "k-best" with the first prediction treated as the canonical')
    args = parser.parse_args()

    print("Loading gold treebank: " + args.gold)
    gold = tree_reader.read_treebank(args.gold)
    # args.pred is a list (nargs='+'); the original concatenated str + list,
    # which raises TypeError, so join the filenames instead
    print("Loading predicted treebanks: " + " ".join(args.pred))
    pred = [tree_reader.read_treebank(x) for x in args.pred]

    # pair each gold tree with its candidate parses from every predicted file
    full_results = [ParseResult(parses[0], [*parses[1:]])
                    for parses in zip(gold, *pred)]

    if len(pred) <= 1:
        kbest = None
    else:
        kbest = len(pred)

    with EvaluateParser(kbest=kbest) as evaluator:
        response = evaluator.process(full_results)

if __name__ == '__main__':
    main()
|
stanza/stanza/models/constituency/label_attention.py
ADDED
|
@@ -0,0 +1,726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import functools
|
| 3 |
+
import sys
|
| 4 |
+
import torch
|
| 5 |
+
from torch.autograd import Variable
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.init as init
|
| 8 |
+
|
| 9 |
+
# publicly available versions alternate between torch.uint8 and torch.bool,
|
| 10 |
+
# but that is for older versions of torch anyway
|
| 11 |
+
DTYPE = torch.bool
|
| 12 |
+
|
| 13 |
+
class BatchIndices:
    """
    Bookkeeping for a packed batch: maps each position of a flattened
    sequence back to the batch element it belongs to.

    Keeps the numpy index array alongside a long tensor copy on the given
    device, plus per-element boundaries, sequence lengths, and the maximum
    sequence length.
    """
    def __init__(self, batch_idxs_np, device):
        self.batch_idxs_np = batch_idxs_np
        self.batch_idxs_torch = torch.as_tensor(batch_idxs_np, dtype=torch.long, device=device)

        self.batch_size = int(1 + np.max(batch_idxs_np))

        # pad with a -1 sentinel at both ends so every boundary between
        # consecutive batch elements (including the outer edges) appears
        # as a change point in the comparison below
        padded = np.concatenate([[-1], batch_idxs_np, [-1]])
        self.boundaries_np = np.nonzero(padded[1:] != padded[:-1])[0]

        self.seq_lens_np = self.boundaries_np[1:] - self.boundaries_np[:-1]
        assert len(self.seq_lens_np) == self.batch_size
        self.max_len = int(np.max(self.seq_lens_np))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
    """
    Autograd implementation of feature-level dropout over a packed batch.

    A single mask of shape (batch_size, num_features) is sampled per call,
    then expanded by indexing so every position belonging to the same batch
    element shares the same mask row.
    """
    # NOTE(review): forward as a @classmethod taking an explicit ctx is a
    # legacy autograd.Function pattern; newer torch prefers @staticmethod --
    # confirm against the torch versions this must support
    @classmethod
    def forward(cls, ctx, input, batch_idxs, p=0.5, train=False, inplace=False):
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))

        # stash the settings on ctx so backward can see them
        ctx.p = p
        ctx.train = train
        ctx.inplace = inplace

        if ctx.inplace:
            # in-place: mark the input dirty and mutate it directly
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if ctx.p > 0 and ctx.train:
            # sample one mask row per batch element, scaled by 1/(1-p)
            # (inverted dropout) so eval needs no rescaling
            ctx.noise = input.new().resize_(batch_idxs.batch_size, input.size(1))
            if ctx.p == 1:
                ctx.noise.fill_(0)
            else:
                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
            # expand to one row per position via the packed batch indices
            ctx.noise = ctx.noise[batch_idxs.batch_idxs_torch, :]
            output.mul_(ctx.noise)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # gradient flows only through surviving features; the four trailing
        # Nones match the non-tensor forward args (batch_idxs, p, train, inplace)
        if ctx.p > 0 and ctx.train:
            return grad_output.mul(ctx.noise), None, None, None, None
        else:
            return grad_output, None, None, None, None
|
| 70 |
+
|
| 71 |
+
#
|
| 72 |
+
class FeatureDropout(nn.Module):
    """
    Feature-level dropout module.

    Takes an input of size len x num_features and drops each feature with
    probability p; a dropped feature is zeroed across the full portion of the
    input that corresponds to a single batch element.
    """

    def __init__(self, p=0.5, inplace=False):
        super().__init__()
        if p < 0 or p > 1:
            message = "dropout probability has to be between 0 and 1, but got {}".format(p)
            raise ValueError(message)
        self.p = p
        self.inplace = inplace

    def forward(self, input, batch_idxs):
        # Delegate to the autograd function; self.training decides whether a
        # mask is actually sampled.
        return FeatureDropoutFunction.apply(input, batch_idxs, self.p, self.training, self.inplace)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class LayerNormalization(nn.Module):
    """
    Layer normalization over the last dimension.

    Normalizes by the (unbiased) standard deviation with eps added to the
    denominator (not to the variance), optionally followed by a learned
    per-feature gain and bias.
    """

    def __init__(self, d_hid, eps=1e-3, affine=True):
        super(LayerNormalization, self).__init__()

        self.eps = eps
        self.affine = affine
        if affine:
            # gain/bias; names kept for checkpoint compatibility
            self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
            self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)

    def forward(self, z):
        # A single feature cannot be meaningfully normalized (std would be
        # undefined), so pass it through unchanged.
        if z.size(-1) == 1:
            return z

        centered = z - torch.mean(z, dim=-1, keepdim=True)
        denom = torch.std(z, dim=-1, keepdim=True) + self.eps
        normed = centered / denom
        if self.affine:
            # broadcasting applies the per-feature gain and bias
            normed = normed * self.a_2 + self.b_2
        return normed
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention: softmax(q k^T / sqrt(d_model)) v, with
    dropout applied to the attention weights.
    """

    def __init__(self, d_model, attention_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temper = d_model ** 0.5
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, attn_mask=None):
        # q: [batch, slot, feat] or (batch * d_l) x max_len x d_k
        # k: [batch, slot, feat] or (batch * d_l) x max_len x d_k
        # v: [batch, slot, feat] or (batch * d_l) x max_len x d_v
        # q in LAL is (batch * d_l) x 1 x d_k

        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper # (batch * d_l) x max_len x max_len
        # in LAL, gives: (batch * d_l) x 1 x max_len
        # attention weights from each word to each word, for each label
        # in best model (repeated q): attention weights from label (as vector weights) to each word

        if attn_mask is not None:
            assert attn_mask.size() == attn.size(), \
                    'Attention mask shape {} mismatch ' \
                    'with Attention logit tensor shape ' \
                    '{}.'.format(attn_mask.size(), attn.size())

            # FIX: use an out-of-place masked_fill instead of mutating through
            # attn.data: writing via .data bypasses autograd tracking (a
            # deprecated pattern that can silently corrupt gradients).  Masked
            # positions get -inf so they become 0 after the softmax.
            attn = attn.masked_fill(attn_mask, -float('inf'))

        attn = self.softmax(attn)
        # Note that dropout here makes the distribution not sum to 1. At some
        # point it may be worth researching whether this is the right way to
        # apply dropout to the attention.
        # Note that the t2t code also applies dropout in this manner
        attn = self.dropout(attn)
        output = torch.bmm(attn, v) # (batch * d_l) x max_len x d_v
        # in LAL, gives: (batch * d_l) x 1 x d_v

        return output, attn
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention over a packed (concatenated-sentences) representation.

    Inputs are len_inp x d_model, where len_inp is the total number of words
    across the batch; batch_idxs (a BatchIndices) describes sentence
    boundaries.  Supports an optionally "partitioned" representation in which
    the first d_content features and the last d_positional features get
    separate projection matrices (as in the Kitaev & Klein partitioned
    transformer).
    """

    def __init__(self, n_head, d_model, d_k, d_v, residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # d_positional of None (or 0) means a single unpartitioned projection
        if not d_positional:
            self.partitioned = False
        else:
            self.partitioned = True

        if self.partitioned:
            self.d_content = d_model - d_positional
            self.d_positional = d_positional

            # separate half-width projections for the content features...
            self.w_qs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
            self.w_ks1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
            self.w_vs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_v // 2))

            # ...and for the positional features
            self.w_qs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
            self.w_ks2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
            self.w_vs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_v // 2))

            init.xavier_normal_(self.w_qs1)
            init.xavier_normal_(self.w_ks1)
            init.xavier_normal_(self.w_vs1)

            init.xavier_normal_(self.w_qs2)
            init.xavier_normal_(self.w_ks2)
            init.xavier_normal_(self.w_vs2)
        else:
            self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
            self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
            self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))

            init.xavier_normal_(self.w_qs)
            init.xavier_normal_(self.w_ks)
            init.xavier_normal_(self.w_vs)

        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
        self.layer_norm = LayerNormalization(d_model)

        if not self.partitioned:
            # The lack of a bias term here is consistent with the t2t code, though
            # in my experiments I have never observed this making a difference.
            self.proj = nn.Linear(n_head*d_v, d_model, bias=False)
        else:
            self.proj1 = nn.Linear(n_head*(d_v//2), self.d_content, bias=False)
            self.proj2 = nn.Linear(n_head*(d_v//2), self.d_positional, bias=False)

        self.residual_dropout = FeatureDropout(residual_dropout)

    def split_qkv_packed(self, inp, qk_inp=None):
        """Project the packed input into per-head q/k/v (n_head x len_inp x d).

        qk_inp, if given, is a separate source for queries and keys; values
        always come from inp.
        """
        v_inp_repeated = inp.repeat(self.n_head, 1).view(self.n_head, -1, inp.size(-1)) # n_head x len_inp x d_model
        if qk_inp is None:
            qk_inp_repeated = v_inp_repeated
        else:
            qk_inp_repeated = qk_inp.repeat(self.n_head, 1).view(self.n_head, -1, qk_inp.size(-1))

        if not self.partitioned:
            q_s = torch.bmm(qk_inp_repeated, self.w_qs) # n_head x len_inp x d_k
            k_s = torch.bmm(qk_inp_repeated, self.w_ks) # n_head x len_inp x d_k
            v_s = torch.bmm(v_inp_repeated, self.w_vs) # n_head x len_inp x d_v
        else:
            # project content and positional halves separately, then
            # concatenate along the feature dimension
            q_s = torch.cat([
                torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_qs1),
                torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_qs2),
                ], -1)
            k_s = torch.cat([
                torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_ks1),
                torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_ks2),
                ], -1)
            v_s = torch.cat([
                torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),
                torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),
                ], -1)
        return q_s, k_s, v_s

    def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
        # Input is padded representation: n_head x len_inp x d
        # Output is packed representation: (n_head * mb_size) x len_padded x d
        # (along with masks for the attention and output)
        n_head = self.n_head
        d_k, d_v = self.d_k, self.d_v

        len_padded = batch_idxs.max_len
        mb_size = batch_idxs.batch_size
        q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))
        k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
        v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
        # invalid_mask marks padding positions; DTYPE is a module-level
        # constant defined above this chunk (presumably torch.bool — verify)
        invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)

        # copy each sentence's slice of the packed representation into its
        # padded row, and mark its real positions as valid
        for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
            q_padded[:,i,:end-start,:] = q_s[:,start:end,:]
            k_padded[:,i,:end-start,:] = k_s[:,start:end,:]
            v_padded[:,i,:end-start,:] = v_s[:,start:end,:]
            invalid_mask[i, :end-start].fill_(False)

        return(
            q_padded.view(-1, len_padded, d_k),
            k_padded.view(-1, len_padded, d_k),
            v_padded.view(-1, len_padded, d_v),
            # attention mask: every query position masks the padding keys
            invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1),
            # output mask: which padded positions are real words
            (~invalid_mask).repeat(n_head, 1),
            )

    def combine_v(self, outputs):
        # Combine attention information from the different heads
        n_head = self.n_head
        outputs = outputs.view(n_head, -1, self.d_v) # n_head x len_inp x d_kv

        if not self.partitioned:
            # Switch from n_head x len_inp x d_v to len_inp x (n_head * d_v)
            outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, n_head * self.d_v)

            # Project back to residual size
            outputs = self.proj(outputs)
        else:
            # the first d_v//2 value features are "content", the rest
            # "positional"; each half gets its own output projection
            d_v1 = self.d_v // 2
            outputs1 = outputs[:,:,:d_v1]
            outputs2 = outputs[:,:,d_v1:]
            outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, n_head * d_v1)
            outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, n_head * d_v1)
            outputs = torch.cat([
                self.proj1(outputs1),
                self.proj2(outputs2),
                ], -1)

        return outputs

    def forward(self, inp, batch_idxs, qk_inp=None):
        """Returns (layer_norm(attention output + inp), padded attention weights)."""
        residual = inp

        # While still using a packed representation, project to obtain the
        # query/key/value for each head
        q_s, k_s, v_s = self.split_qkv_packed(inp, qk_inp=qk_inp)
        # n_head x len_inp x d_kv

        # Switch to padded representation, perform attention, then switch back
        q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)
        # (n_head * batch) x len_padded x d_kv

        outputs_padded, attns_padded = self.attention(
            q_padded, k_padded, v_padded,
            attn_mask=attn_mask,
            )
        # boolean indexing drops the padding positions again
        outputs = outputs_padded[output_mask]
        # (n_head * len_inp) x d_kv
        outputs = self.combine_v(outputs)
        # len_inp x d_model

        outputs = self.residual_dropout(outputs, batch_idxs)

        return self.layer_norm(outputs + residual), attns_padded
|
| 313 |
+
|
| 314 |
+
#
|
| 315 |
+
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward block.

    Expands the representation to d_ff, applies ReLU (with feature dropout),
    projects back down to d_hid, then adds the residual and layer-normalizes.
    """

    def __init__(self, d_hid, d_ff, relu_dropout=0.1, residual_dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        # names w_1 / w_2 kept for checkpoint compatibility
        self.w_1 = nn.Linear(d_hid, d_ff)
        self.w_2 = nn.Linear(d_ff, d_hid)

        self.layer_norm = LayerNormalization(d_hid)
        self.relu_dropout = FeatureDropout(relu_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)
        self.relu = nn.ReLU()


    def forward(self, x, batch_idxs):
        hidden = self.relu_dropout(self.relu(self.w_1(x)), batch_idxs)
        projected = self.residual_dropout(self.w_2(hidden), batch_idxs)
        # residual connection followed by layer norm
        return self.layer_norm(projected + x)
|
| 343 |
+
|
| 344 |
+
#
|
| 345 |
+
class PartitionedPositionwiseFeedForward(nn.Module):
    """
    Feed-forward block for the partitioned transformer.

    The content features (first d_hid - d_positional columns) and positional
    features (last d_positional columns) are sent through separate
    linear/ReLU/linear branches, re-concatenated, residual-added, and
    layer-normalized.
    """

    def __init__(self, d_hid, d_ff, d_positional, relu_dropout=0.1, residual_dropout=0.1):
        super().__init__()
        self.d_content = d_hid - d_positional
        # parameter names kept for checkpoint compatibility
        self.w_1c = nn.Linear(self.d_content, d_ff//2)
        self.w_1p = nn.Linear(d_positional, d_ff//2)
        self.w_2c = nn.Linear(d_ff//2, self.d_content)
        self.w_2p = nn.Linear(d_ff//2, d_positional)
        self.layer_norm = LayerNormalization(d_hid)
        self.relu_dropout = FeatureDropout(relu_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)
        self.relu = nn.ReLU()

    def _branch(self, part, w_in, w_out, batch_idxs):
        # one half of the partitioned feed-forward: linear, ReLU + feature
        # dropout, linear
        return w_out(self.relu_dropout(self.relu(w_in(part)), batch_idxs))

    def forward(self, x, batch_idxs):
        content = self._branch(x[:, :self.d_content], self.w_1c, self.w_2c, batch_idxs)
        positional = self._branch(x[:, self.d_content:], self.w_1p, self.w_2p, batch_idxs)

        combined = torch.cat([content, positional], -1)
        # residual connection followed by layer norm
        return self.layer_norm(self.residual_dropout(combined, batch_idxs) + x)
|
| 375 |
+
|
| 376 |
+
class LabelAttention(nn.Module):
    """
    Single-head Attention layer for label-specific representations.

    Instead of n_head attention heads, there is one attention computation per
    label (d_l of them).  Queries can either be learned vectors (one per
    label, the default) or matrices multiplied with the input (q_as_matrix).
    The per-label outputs are either combined as in self-attention
    (combine_as_self) or kept separate and reduced to d_proj dims each.
    """

    def __init__(self, d_model, d_k, d_v, d_l, d_proj, combine_as_self, use_resdrop=True, q_as_matrix=False, residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
        super(LabelAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.d_l = d_l # Number of Labels
        self.d_model = d_model # Model Dimensionality
        self.d_proj = d_proj # Projection dimension of each label output
        self.use_resdrop = use_resdrop # Using Residual Dropout?
        self.q_as_matrix = q_as_matrix # Using a Matrix of Q to be multiplied with input instead of learned q vectors
        self.combine_as_self = combine_as_self # Using the Combination Method of Self-Attention

        # d_positional of None (or 0) means unpartitioned projections
        if not d_positional:
            self.partitioned = False
        else:
            self.partitioned = True

        if self.partitioned:
            if d_model <= d_positional:
                raise ValueError("Unable to build LabelAttention. d_model %d <= d_positional %d" % (d_model, d_positional))
            self.d_content = d_model - d_positional
            self.d_positional = d_positional

            # content-half projections; queries are either per-label matrices
            # or per-label learned vectors depending on q_as_matrix
            if self.q_as_matrix:
                self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
            else:
                self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
            self.w_ks1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
            self.w_vs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_v // 2), requires_grad=True)

            # positional-half projections
            if self.q_as_matrix:
                self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)
            else:
                self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
            self.w_ks2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)
            self.w_vs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_v // 2), requires_grad=True)

            init.xavier_normal_(self.w_qs1)
            init.xavier_normal_(self.w_ks1)
            init.xavier_normal_(self.w_vs1)

            init.xavier_normal_(self.w_qs2)
            init.xavier_normal_(self.w_ks2)
            init.xavier_normal_(self.w_vs2)
        else:
            if self.q_as_matrix:
                self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
            else:
                self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_k), requires_grad=True)
            self.w_ks = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
            self.w_vs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_v), requires_grad=True)

            init.xavier_normal_(self.w_qs)
            init.xavier_normal_(self.w_ks)
            init.xavier_normal_(self.w_vs)

        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
        if self.combine_as_self:
            self.layer_norm = LayerNormalization(d_model)
        else:
            # in the per-label path, the norm is applied per label after
            # reduce_proj, so it operates over d_proj features
            self.layer_norm = LayerNormalization(self.d_proj)

        if not self.partitioned:
            # The lack of a bias term here is consistent with the t2t code, though
            # in my experiments I have never observed this making a difference.
            if self.combine_as_self:
                self.proj = nn.Linear(self.d_l * d_v, d_model, bias=False)
            else:
                self.proj = nn.Linear(d_v, d_model, bias=False) # input dimension does not match, should be d_l * d_v
        else:
            if self.combine_as_self:
                self.proj1 = nn.Linear(self.d_l*(d_v//2), self.d_content, bias=False)
                self.proj2 = nn.Linear(self.d_l*(d_v//2), self.d_positional, bias=False)
            else:
                self.proj1 = nn.Linear(d_v//2, self.d_content, bias=False)
                self.proj2 = nn.Linear(d_v//2, self.d_positional, bias=False)
        if not self.combine_as_self:
            # reduces each label's d_model-dim representation to d_proj dims
            self.reduce_proj = nn.Linear(d_model, self.d_proj, bias=False)

        self.residual_dropout = FeatureDropout(residual_dropout)

    def split_qkv_packed(self, inp, k_inp=None):
        """Project the packed input into per-label q/k/v.

        k_inp, if given, is a separate source for queries and keys; values
        always come from inp.
        """
        # NOTE(review): len_inp is computed but never used in this method
        len_inp = inp.size(0)
        v_inp_repeated = inp.repeat(self.d_l, 1).view(self.d_l, -1, inp.size(-1)) # d_l x len_inp x d_model
        if k_inp is None:
            k_inp_repeated = v_inp_repeated
        else:
            k_inp_repeated = k_inp.repeat(self.d_l, 1).view(self.d_l, -1, k_inp.size(-1)) # d_l x len_inp x d_model

        if not self.partitioned:
            if self.q_as_matrix:
                q_s = torch.bmm(k_inp_repeated, self.w_qs) # d_l x len_inp x d_k
            else:
                q_s = self.w_qs.unsqueeze(1) # d_l x 1 x d_k
            k_s = torch.bmm(k_inp_repeated, self.w_ks) # d_l x len_inp x d_k
            v_s = torch.bmm(v_inp_repeated, self.w_vs) # d_l x len_inp x d_v
        else:
            # content and positional halves are projected separately and
            # concatenated along the feature dimension
            if self.q_as_matrix:
                q_s = torch.cat([
                    torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_qs1),
                    torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_qs2),
                    ], -1)
            else:
                q_s = torch.cat([
                    self.w_qs1.unsqueeze(1),
                    self.w_qs2.unsqueeze(1),
                    ], -1)
            k_s = torch.cat([
                torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_ks1),
                torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_ks2),
                ], -1)
            v_s = torch.cat([
                torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),
                torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),
                ], -1)
        return q_s, k_s, v_s

    def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
        # Input is padded representation: n_head x len_inp x d
        # Output is packed representation: (n_head * mb_size) x len_padded x d
        # (along with masks for the attention and output)
        n_head = self.d_l
        d_k, d_v = self.d_k, self.d_v

        len_padded = batch_idxs.max_len
        mb_size = batch_idxs.batch_size
        if self.q_as_matrix:
            q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))
        else:
            # learned query vectors are the same for every sentence, so just
            # replicate them per batch element
            q_padded = q_s.repeat(mb_size, 1, 1) # (d_l * mb_size) x 1 x d_k
        k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
        v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
        # invalid_mask marks padding positions; DTYPE is a module-level
        # constant defined above this chunk (presumably torch.bool — verify)
        invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)

        for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
            if self.q_as_matrix:
                q_padded[:,i,:end-start,:] = q_s[:,start:end,:]
            k_padded[:,i,:end-start,:] = k_s[:,start:end,:]
            v_padded[:,i,:end-start,:] = v_s[:,start:end,:]
            invalid_mask[i, :end-start].fill_(False)

        if self.q_as_matrix:
            q_padded = q_padded.view(-1, len_padded, d_k)
            attn_mask = invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1)
        else:
            # only one query row per label, so the mask has a single query slot
            attn_mask = invalid_mask.unsqueeze(1).repeat(n_head, 1, 1)

        output_mask = (~invalid_mask).repeat(n_head, 1)

        return(
            q_padded,
            k_padded.view(-1, len_padded, d_k),
            v_padded.view(-1, len_padded, d_v),
            attn_mask,
            output_mask,
            )

    def combine_v(self, outputs):
        # Combine attention information from the different labels
        d_l = self.d_l
        outputs = outputs.view(d_l, -1, self.d_v) # d_l x len_inp x d_v

        if not self.partitioned:
            # Switch from d_l x len_inp x d_v to len_inp x d_l x d_v
            if self.combine_as_self:
                outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, d_l * self.d_v)
            else:
                outputs = torch.transpose(outputs, 0, 1)#.contiguous() #.view(-1, d_l * self.d_v)
            # Project back to residual size
            outputs = self.proj(outputs) # Becomes len_inp x d_l x d_model
        else:
            d_v1 = self.d_v // 2
            outputs1 = outputs[:,:,:d_v1]
            outputs2 = outputs[:,:,d_v1:]
            if self.combine_as_self:
                outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, d_l * d_v1)
                outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, d_l * d_v1)
            else:
                outputs1 = torch.transpose(outputs1, 0, 1)#.contiguous() #.view(-1, d_l * d_v1)
                outputs2 = torch.transpose(outputs2, 0, 1)#.contiguous() #.view(-1, d_l * d_v1)
            outputs = torch.cat([
                self.proj1(outputs1),
                self.proj2(outputs2),
                ], -1)#.contiguous()

        return outputs

    def forward(self, inp, batch_idxs, k_inp=None):
        """Returns (label-specific representations, padded attention weights).

        With combine_as_self the output is len_inp x d_model; otherwise it is
        len_inp x (d_l * d_proj).
        """
        residual = inp # len_inp x d_model
        len_inp = inp.size(0)

        # While still using a packed representation, project to obtain the
        # query/key/value for each head
        q_s, k_s, v_s = self.split_qkv_packed(inp, k_inp=k_inp)
        # d_l x len_inp x d_k
        # q_s is d_l x 1 x d_k

        # Switch to padded representation, perform attention, then switch back
        q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)
        # q_padded, k_padded, v_padded: (d_l * batch_size) x max_len x d_kv
        # q_s is (d_l * batch_size) x 1 x d_kv

        outputs_padded, attns_padded = self.attention(
            q_padded, k_padded, v_padded,
            attn_mask=attn_mask,
            )
        # outputs_padded: (d_l * batch_size) x max_len x d_kv
        # in LAL: (d_l * batch_size) x 1 x d_kv
        # on the best model, this is one value vector per label that is repeated max_len times
        if not self.q_as_matrix:
            # replicate the single per-label output to every word position so
            # the boolean indexing below works the same for both query modes
            outputs_padded = outputs_padded.repeat(1,output_mask.size(-1),1)
        outputs = outputs_padded[output_mask]
        # outputs: (d_l * len_inp) x d_kv or LAL: (d_l * len_inp) x d_kv
        # output_mask: (d_l * batch_size) x max_len
        outputs = self.combine_v(outputs)
        # outputs: len_inp x d_l x d_model, whereas a normal self-attention layer gets len_inp x d_model
        if self.use_resdrop:
            if self.combine_as_self:
                outputs = self.residual_dropout(outputs, batch_idxs)
            else:
                # dropout applied per label slice, then re-stacked
                outputs = torch.cat([self.residual_dropout(outputs[:,i,:], batch_idxs).unsqueeze(1) for i in range(self.d_l)], 1)
        if self.combine_as_self:
            outputs = self.layer_norm(outputs + inp)
        else:
            # residual added per label
            # NOTE(review): this writes into outputs in place slice by slice;
            # presumably equivalent to outputs + inp.unsqueeze(1) — verify
            for l in range(self.d_l):
                outputs[:, l, :] = outputs[:, l, :] + inp

            outputs = self.reduce_proj(outputs) # len_inp x d_l x d_proj
            outputs = self.layer_norm(outputs) # len_inp x d_l x d_proj
            outputs = outputs.view(len_inp, -1).contiguous() # len_inp x (d_l * d_proj)

        return outputs, attns_padded
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
#
|
| 619 |
+
class LabelAttentionModule(nn.Module):
    """
    Label Attention Module for label-specific representations.

    The module can be used right after the Partitioned Attention, or it can be
    experimented with for the transition stack.  It packs a list of per-sentence
    word embeddings, runs LabelAttention followed by a feed-forward layer, and
    unpacks the result back into one tensor per sentence.
    """
    def __init__(self,
                 d_model,
                 d_input_proj,
                 d_k,
                 d_v,
                 d_l,
                 d_proj,
                 combine_as_self,
                 use_resdrop=True,
                 q_as_matrix=False,
                 residual_dropout=0.1,
                 attention_dropout=0.1,
                 d_positional=None,
                 d_ff=2048,
                 relu_dropout=0.2,
                 lattn_partitioned=True):
        super().__init__()
        # the feed-forward layer operates on the concatenated label outputs
        self.ff_dim = d_proj * d_l

        if not lattn_partitioned:
            self.d_positional = 0
        else:
            self.d_positional = d_positional if d_positional else 0

        if d_input_proj:
            if d_input_proj <= self.d_positional:
                raise ValueError("Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d" % (d_input_proj, self.d_positional))
            # only the content features are projected; positional features are
            # passed through unchanged (see forward)
            self.input_projection = nn.Linear(d_model - self.d_positional, d_input_proj - self.d_positional, bias=False)
            d_input = d_input_proj
        else:
            self.input_projection = None
            d_input = d_model

        self.label_attention = LabelAttention(d_input,
                                              d_k,
                                              d_v,
                                              d_l,
                                              d_proj,
                                              combine_as_self,
                                              use_resdrop,
                                              q_as_matrix,
                                              residual_dropout,
                                              attention_dropout,
                                              self.d_positional)

        if not lattn_partitioned:
            self.lal_ff = PositionwiseFeedForward(self.ff_dim,
                                                  d_ff,
                                                  relu_dropout,
                                                  residual_dropout)
        else:
            self.lal_ff = PartitionedPositionwiseFeedForward(self.ff_dim,
                                                             d_ff,
                                                             self.d_positional,
                                                             relu_dropout,
                                                             residual_dropout)

    def forward(self, word_embeddings, tagged_word_lists):
        """word_embeddings: list of per-sentence tensors (len x d_model).

        Returns a list of per-sentence tensors of labeled representations.
        NOTE(review): tagged_word_lists is accepted but never used here —
        presumably kept for interface compatibility with sibling modules.
        """
        if self.input_projection:
            if self.d_positional > 0:
                # project content features, keep trailing positional features
                word_embeddings = [torch.cat((self.input_projection(sentence[:, :-self.d_positional]),
                                              sentence[:, -self.d_positional:]), dim=1)
                                   for sentence in word_embeddings]
            else:
                word_embeddings = [self.input_projection(sentence) for sentence in word_embeddings]
        # Extract Labeled Representation
        packed_len = sum(sentence.shape[0] for sentence in word_embeddings)
        batch_idxs = np.zeros(packed_len, dtype=int)

        batch_size = len(word_embeddings)
        i = 0

        # map each packed position to the index of its sentence
        # NOTE(review): this per-word loop could be np.repeat over
        # sentence lengths — kept as-is to preserve behavior
        sentence_lengths = [0] * batch_size
        for sentence_idx, sentence in enumerate(word_embeddings):
            sentence_lengths[sentence_idx] = len(sentence)
            for word in sentence:
                batch_idxs[i] = sentence_idx
                i += 1

        # keep the raw numpy index array for unpacking later, then wrap it
        batch_indices = batch_idxs
        batch_idxs = BatchIndices(batch_idxs, word_embeddings[0].device)

        # flatten the list of sentences into one packed tensor
        # NOTE(review): the word_idx < sentence_lengths[...] check is always
        # true, and this loop is equivalent to torch.cat(word_embeddings, 0)
        new_embeds = []
        for sentence_idx, batch in enumerate(word_embeddings):
            for word_idx, embed in enumerate(batch):
                if word_idx < sentence_lengths[sentence_idx]:
                    new_embeds.append(embed)

        new_word_embeddings = torch.stack(new_embeds)

        labeled_representations, _ = self.label_attention(new_word_embeddings, batch_idxs)
        labeled_representations = self.lal_ff(labeled_representations, batch_idxs)
        final_labeled_representations = [[] for i in range(batch_size)]

        # unpack: route each packed row back to its sentence
        for idx, embed in enumerate(labeled_representations):
            final_labeled_representations[batch_indices[idx]].append(embed)

        for idx, representation in enumerate(final_labeled_representations):
            final_labeled_representations[idx] = torch.stack(representation)

        return final_labeled_representations
|
| 726 |
+
|
stanza/stanza/models/constituency/lstm_tree_stack.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Keeps an LSTM in TreeStack form.
|
| 3 |
+
|
| 4 |
+
The TreeStack nodes keep the hx and cx for the LSTM, along with a
|
| 5 |
+
"value" which represents whatever the user needs to store.
|
| 6 |
+
|
| 7 |
+
The TreeStacks can be ppped to get back to the previous LSTM state.
|
| 8 |
+
|
| 9 |
+
The module itself implements three methods: initial_state, push_states, output
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from collections import namedtuple
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
from stanza.models.constituency.tree_stack import TreeStack
|
| 18 |
+
|
| 19 |
+
Node = namedtuple("Node", ['value', 'lstm_hx', 'lstm_cx'])
|
| 20 |
+
|
| 21 |
+
class LSTMTreeStack(nn.Module):
    """
    Keeps an LSTM's rolling state in TreeStack form.

    Each stack node stores the LSTM (hx, cx) after processing that node's
    input, so restoring a previous LSTM state is just a stack pop.
    """
    def __init__(self, input_size, hidden_size, num_lstm_layers, dropout, uses_boundary_vector, input_dropout):
        """
        Prepare LSTM and parameters

        input_size: dimension of the inputs to the LSTM
        hidden_size: LSTM internal & output dimension
        num_lstm_layers: how many layers of LSTM to use
        dropout: value of the LSTM dropout
        uses_boundary_vector: if set, learn a start_embedding parameter. otherwise, use zeros
        input_dropout: an nn.Module to dropout inputs. TODO: allow a float parameter as well
        """
        super().__init__()

        self.uses_boundary_vector = uses_boundary_vector

        # The start embedding needs to be input_size as we put it through the LSTM
        if uses_boundary_vector:
            self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
        else:
            # no learned boundary: registered as buffers so they follow the
            # module across devices but are not trained
            self.register_buffer('input_zeros', torch.zeros(num_lstm_layers, 1, input_size))
            self.register_buffer('hidden_zeros', torch.zeros(num_lstm_layers, 1, hidden_size))

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_lstm_layers, dropout=dropout)
        self.input_dropout = input_dropout


    def initial_state(self, initial_value=None):
        """
        Return an initial state, either based on zeros or based on the initial embedding and LSTM

        Note that LSTM start operation is already batched, in a sense
        The subsequent batch built this way will be used for batch_size trees

        Returns a stack with None value, hx & cx either based on the
        start_embedding or zeros, and no parent.
        """
        if self.uses_boundary_vector:
            # reshape the learned boundary to (1, 1, input_size) and run it
            # through the LSTM once to obtain the initial (hx, cx)
            start = self.start_embedding.unsqueeze(0).unsqueeze(0)
            output, (hx, cx) = self.lstm(start)
            start = output[0, 0, :]
        else:
            start = self.input_zeros
            hx = self.hidden_zeros
            cx = self.hidden_zeros
        return TreeStack(value=Node(initial_value, hx, cx), parent=None, length=1)

    def push_states(self, stacks, values, inputs):
        """
        Starting from a list of current stacks, put the inputs through the LSTM and build new stack nodes.

        B = stacks.len() = values.len()

        inputs must be of shape 1 x B x input_size
        """
        inputs = self.input_dropout(inputs)

        # gather the per-stack hidden/cell states into one batch along dim 1
        hx = torch.cat([t.value.lstm_hx for t in stacks], axis=1)
        cx = torch.cat([t.value.lstm_cx for t in stacks], axis=1)
        output, (hx, cx) = self.lstm(inputs, (hx, cx))
        # slice each stack's column back out of the batched state; the
        # i:i+1 slicing keeps the batch dimension so the node can be
        # concatenated again on a later push
        new_stacks = [stack.push(Node(transition, hx[:, i:i+1, :], cx[:, i:i+1, :]))
                      for i, (stack, transition) in enumerate(zip(stacks, values))]
        return new_stacks

    def output(self, stack):
        """
        Return the last layer of the lstm_hx as the output from a stack

        Refactored so that alternate structures have an easy way of getting the output
        """
        return stack.value.lstm_hx[-1, 0, :]
|
stanza/stanza/models/constituency/score_converted_dependencies.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script which processes a dependency file by using the constituency parser, then converting with the CoreNLP converter
|
| 3 |
+
|
| 4 |
+
Currently this does not have the constituency parser as an option,
|
| 5 |
+
although that is easy to add.
|
| 6 |
+
|
| 7 |
+
Only English is supported, as only English is available in the CoreNLP converter
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import os
|
| 12 |
+
import tempfile
|
| 13 |
+
|
| 14 |
+
import stanza
|
| 15 |
+
from stanza.models.constituency import retagging
|
| 16 |
+
from stanza.models.depparse import scorer
|
| 17 |
+
from stanza.utils.conll import CoNLL
|
| 18 |
+
|
| 19 |
+
def score_converted_dependencies(args):
    """
    Parse args['eval_file'] with a constituency parser, convert the trees
    to dependencies via the 'converter' depparse package, and score the
    result against the gold file.

    Only English is supported (the converter is English-only).
    """
    if args['lang'] != 'en':
        raise ValueError("Converting and scoring dependencies is currently only supported for English")

    constituency_package = args['constituency_package']
    # tokenize_pretokenized: the gold conllu already defines the tokenization
    pipeline_args = {'lang': args['lang'],
                     'tokenize_pretokenized': True,
                     'package': {'pos': args['retag_package'], 'depparse': 'converter', 'constituency': constituency_package},
                     'processors': 'tokenize, pos, constituency, depparse'}
    pipeline = stanza.Pipeline(**pipeline_args)

    input_doc = CoNLL.conll2doc(args['eval_file'])
    output_doc = pipeline(input_doc)
    print("Processed %d sentences" % len(output_doc.sentences))
    # reload - the pipeline clobbered the gold values
    input_doc = CoNLL.conll2doc(args['eval_file'])

    scorer.score_named_dependencies(output_doc, input_doc)
    # write the converted doc to a temp file so the conll scorer can read it
    with tempfile.TemporaryDirectory() as tempdir:
        output_path = os.path.join(tempdir, "converted.conll")

        CoNLL.write_doc2conll(output_doc, output_path)

        _, _, score = scorer.score(output_path, args['eval_file'])

    print("Parser score:")
    print("{} {:.2f}".format(constituency_package, score*100))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def main():
    """
    Command line entry: build the arg parser, then convert & score dependencies.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--lang', default='en', type=str, help='Language')
    parser.add_argument('--eval_file', default="extern_data/ud2/ud-treebanks-v2.13/UD_English-EWT/en_ewt-ud-test.conllu", help='Input file for data loader.')
    parser.add_argument('--constituency_package', default="ptb3-revised_electra-large", help='Which constituency parser to use for converting')
    retagging.add_retag_args(parser)

    # work with a plain dict, as the rest of the module expects
    parsed_args = vars(parser.parse_args())
    retagging.postprocess_args(parsed_args)

    score_converted_dependencies(parsed_args)

if __name__ == '__main__':
    main()
|
| 65 |
+
|
stanza/stanza/models/constituency/text_processing.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
from stanza.models.common import utils
|
| 6 |
+
from stanza.models.constituency.utils import retag_tags
|
| 7 |
+
from stanza.models.constituency.trainer import Trainer
|
| 8 |
+
from stanza.models.constituency.tree_reader import read_trees
|
| 9 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger('stanza')
|
| 12 |
+
tqdm = get_tqdm()
|
| 13 |
+
|
| 14 |
+
def read_tokenized_file(tokenized_file):
    """
    Read sentences from a tokenized file, potentially replacing _ with space for languages such as VI
    """
    def expand_token(token):
        # a token made entirely of underscores is kept verbatim;
        # otherwise underscores stand in for spaces inside a word
        if all(ch == '_' for ch in token):
            return token
        return token.replace("_", " ")

    with open(tokenized_file, encoding='utf-8') as fin:
        sentences = [line.strip() for line in fin if line.strip()]

    docs = []
    for sentence in sentences:
        docs.append([expand_token(token) for token in sentence.split()])
    # tokenized files carry no sentence ids
    ids = [None for _ in docs]
    return docs, ids
|
| 25 |
+
|
| 26 |
+
def read_xml_tree_file(tree_file):
    """
    Read sentences from a file of the format unique to VLSP test sets

    in particular, it should be multiple blocks of

    <s id=1>
    (tree ...)
    </s>

    Returns (docs, ids): docs is a list of word lists (one per tree),
    ids the corresponding <s id=...> values (or None when no id was given).
    """
    with open(tree_file, encoding='utf-8') as fin:
        lines = fin.readlines()
    lines = [x.strip() for x in lines]
    lines = [x for x in lines if x]
    docs = []
    ids = []
    tree_id = None
    tree_text = []
    for line in lines:
        if line.startswith("<s"):
            # opening tag: try to pull the integer id after "="
            tree_id = line.split("=")
            if len(tree_id) > 1:
                tree_id = tree_id[1]
                if tree_id.endswith(">"):
                    tree_id = tree_id[:-1]
                tree_id = int(tree_id)
            else:
                tree_id = None
        elif line.startswith("</s"):
            # closing tag: parse the accumulated lines as exactly one tree
            if len(tree_text) == 0:
                raise ValueError("Found a blank tree in %s" % tree_file)
            ids.append(tree_id)
            tree_text = "\n".join(tree_text)
            trees = read_trees(tree_text)
            # TODO: perhaps the processing can be put into read_trees instead
            trees = [t.prune_none().simplify_labels() for t in trees]
            if len(trees) != 1:
                raise ValueError("Found a tree with %d trees in %s" % (len(trees), tree_file))
            tree = trees[0]
            text = tree.leaf_labels()
            # as in read_tokenized_file, underscores inside a word stand for
            # spaces unless the word is all underscores
            text = [word if all(x == '_' for x in word) else word.replace("_", " ") for word in text]
            docs.append(text)
            tree_text = []
            tree_id = None
        else:
            tree_text.append(line)

    return docs, ids
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def parse_tokenized_sentences(args, model, retag_pipeline, sentences):
    """
    Parse the given sentences, return a list of ParseResult objects

    sentences: list of sentences, each a list of word strings
    """
    # tag with xpos or upos depending on what the model was trained with
    tags = retag_tags(sentences, retag_pipeline, model.uses_xpos())
    # pair each word with its predicted tag, per sentence
    words = [[(word, tag) for word, tag in zip(s_words, s_tags)] for s_words, s_tags in zip(sentences, tags)]
    logger.info("Retagging finished. Parsing tagged text")

    assert len(words) == len(sentences)
    treebank = model.parse_sentences_no_grad(iter(tqdm(words)), model.build_batch_from_tagged_words, args['eval_batch_size'], model.predict, keep_scores=False)
    return treebank
|
| 87 |
+
|
| 88 |
+
def parse_text(args, model, retag_pipeline, tokenized_file=None, predict_file=None):
    """
    Use the given model to parse text and write it

    refactored so it can be used elsewhere, such as Ensemble

    tokenized_file / predict_file override args['tokenized_file'] /
    args['predict_file'] when given.  Sentences come from either a
    tokenized file or args['xml_tree_file'].
    """
    model.eval()

    if predict_file is None:
        if args['predict_file']:
            predict_file = args['predict_file']
            if args['predict_dir']:
                predict_file = os.path.join(args['predict_dir'], predict_file)

    if tokenized_file is None:
        tokenized_file = args['tokenized_file']

    docs, ids = None, None
    if tokenized_file is not None:
        docs, ids = read_tokenized_file(tokenized_file)
    elif args['xml_tree_file']:
        logger.info("Reading trees from %s" % args['xml_tree_file'])
        docs, ids = read_xml_tree_file(args['xml_tree_file'])

    if not docs:
        logger.error("No sentences to process!")
        return

    logger.info("Processing %d sentences", len(docs))

    # output_stream presumably falls back to stdout when predict_file is
    # None - TODO confirm against stanza.models.common.utils
    with utils.output_stream(predict_file) as fout:
        # parse in chunks to keep memory bounded on large inputs
        chunk_size = 10000
        for chunk_start in range(0, len(docs), chunk_size):
            chunk = docs[chunk_start:chunk_start+chunk_size]
            ids_chunk = ids[chunk_start:chunk_start+chunk_size]
            logger.info("Processing trees %d to %d", chunk_start, chunk_start+len(chunk))
            treebank = parse_tokenized_sentences(args, model, retag_pipeline, chunk)

            for result, tree_id in zip(treebank, ids_chunk):
                # take the top-scoring prediction for each sentence
                tree = result.predictions[0].tree
                if tree_id is not None:
                    tree.tree_id = tree_id
                fout.write(args['predict_format'].format(tree))
                fout.write("\n")
|
| 132 |
+
|
| 133 |
+
def parse_dir(args, model, retag_pipeline, tokenized_dir, predict_dir):
    """
    Parse every file in tokenized_dir, writing one .mrg output per input file into predict_dir
    """
    os.makedirs(predict_dir, exist_ok=True)
    for entry in os.listdir(tokenized_dir):
        source_path = os.path.join(tokenized_dir, entry)
        base_name = os.path.splitext(entry)[0]
        target_path = os.path.join(predict_dir, base_name + ".mrg")
        logger.info("Processing %s to %s", source_path, target_path)
        parse_text(args, model, retag_pipeline, tokenized_file=source_path, predict_file=target_path)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def load_model_parse_text(args, model_file, retag_pipeline):
    """
    Load a model, then parse text and write it to stdout or args['predict_file']

    retag_pipeline: a list of Pipeline meant to use for retagging
    """
    # BUGFIX: FoundationCache was referenced but never imported in this
    # module, so an empty retag_pipeline raised NameError.  Imported at
    # function scope to keep this fix self-contained.
    from stanza.models.common.foundation_cache import FoundationCache
    # reuse the retag pipeline's foundation cache when available so charlm /
    # pretrain files are only loaded once
    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
    load_args = {
        'wordvec_pretrain_file': args['wordvec_pretrain_file'],
        'charlm_forward_file': args['charlm_forward_file'],
        'charlm_backward_file': args['charlm_backward_file'],
        'device': args['device'],
    }
    trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)
    model = trainer.model
    model.eval()
    logger.info("Loaded model from %s", model_file)

    if args['tokenized_dir']:
        if not args['predict_dir']:
            # message typo fixed: "specific" -> "specify"
            raise ValueError("Must specify --predict_dir to go with --tokenized_dir")
        parse_dir(args, model, retag_pipeline, args['tokenized_dir'], args['predict_dir'])
    else:
        parse_text(args, model, retag_pipeline)
|
| 166 |
+
|
stanza/stanza/models/constituency/tree_reader.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reads ParseTree objects from a file, string, or similar input
|
| 3 |
+
|
| 4 |
+
Works by first splitting the input into (, ), and all other tokens,
|
| 5 |
+
then recursively processing those tokens into trees.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from collections import deque
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
from stanza.models.constituency.parse_tree import Tree
|
| 14 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 15 |
+
|
| 16 |
+
tqdm = get_tqdm()
|
| 17 |
+
|
| 18 |
+
OPEN_PAREN = "("
|
| 19 |
+
CLOSE_PAREN = ")"
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger('stanza.constituency')
|
| 22 |
+
|
| 23 |
+
# A few specific exception types to clarify parsing errors
|
| 24 |
+
# They store the line number where the error occurred
|
| 25 |
+
|
| 26 |
+
class UnclosedTreeError(ValueError):
    """
    A tree looked like (Foo
    """
    def __init__(self, line_num):
        # keep the starting line so callers can report where the tree began
        message = "Found an unfinished tree (missing close brackets). Tree started on line %d" % line_num
        super().__init__(message)
        self.line_num = line_num
|
| 33 |
+
|
| 34 |
+
class ExtraCloseTreeError(ValueError):
    """
    A tree looked like (Foo))
    """
    def __init__(self, line_num):
        # keep the starting line so callers can report where the tree began
        message = "Found a broken tree (extra close brackets). Tree started on line %d" % line_num
        super().__init__(message)
        self.line_num = line_num
|
| 41 |
+
|
| 42 |
+
class UnlabeledTreeError(ValueError):
    """
    A tree had no label, such as ((Foo) (Bar))

    This does not actually happen at the root, btw, as ROOT is silently added
    """
    def __init__(self, line_num):
        # keep the offending line number for error reporting
        message = "Found a tree with no label on a node! Line number %d" % line_num
        super().__init__(message)
        self.line_num = line_num
|
| 51 |
+
|
| 52 |
+
class MixedTreeError(ValueError):
    """
    Leaf and constituent children are mixed in the same node
    """
    def __init__(self, line_num, child_label, children):
        template = "Found a tree with both text children and bracketed children! Line number {} Child label {} Children {}"
        super().__init__(template.format(line_num, child_label, children))
        # keep the pieces so callers can inspect the offending node
        self.line_num = line_num
        self.child_label = child_label
        self.children = children
|
| 61 |
+
|
| 62 |
+
def normalize(text):
    """Convert PTB-escaped brackets (-LRB- / -RRB-) back to literal parens."""
    for escaped, literal in (("-LRB-", "("), ("-RRB-", ")")):
        text = text.replace(escaped, literal)
    return text
|
| 64 |
+
|
| 65 |
+
def read_single_tree(token_iterator, broken_ok):
    """
    Build a tree from the tokens in the token_iterator

    Called with the iterator positioned just after a tree's first open
    paren.  Returns the finished Tree once its brackets balance; raises
    UnclosedTreeError if the tokens run out first.  broken_ok relaxes
    the mixed-children and unlabeled-node errors.
    """
    # we were called here at a open paren, so start the stack of
    # children with one empty list already on it
    children_stack = deque()
    children_stack.append([])
    text_stack = deque()
    text_stack.append([])

    token = next(token_iterator, None)
    # remember the line this tree started on for error reporting
    token_iterator.set_mark()
    while token is not None:
        if token == OPEN_PAREN:
            # new constituent: fresh child list and fresh text buffer
            children_stack.append([])
            text_stack.append([])
        elif token == CLOSE_PAREN:
            # finish the current constituent
            text = text_stack.pop()
            children = children_stack.pop()
            if text:
                pieces = " ".join(text).split()
                if len(pieces) == 1:
                    # single piece: the constituent's label
                    child = Tree(pieces[0], children)
                else:
                    # the assumption here is that a language such as VI may
                    # have spaces in the words, but it still represents
                    # just one child
                    label = pieces[0]
                    child_label = " ".join(pieces[1:])
                    if children:
                        if broken_ok:
                            child = Tree(label, children + [Tree(normalize(child_label))])
                        else:
                            raise MixedTreeError(token_iterator.line_num, child_label, children)
                    else:
                        child = Tree(label, Tree(normalize(child_label)))
                if not children_stack:
                    # matched the outermost paren: tree is complete
                    return child
            else:
                if not children_stack:
                    # unlabeled outermost bracket: ROOT is silently added
                    return Tree("ROOT", children)
                elif broken_ok:
                    child = Tree(None, children)
                else:
                    raise UnlabeledTreeError(token_iterator.line_num)
            children_stack[-1].append(child)
        else:
            # plain text: accumulate into the current node's buffer
            text_stack[-1].append(token)
        token = next(token_iterator, None)
    # ran out of tokens with brackets still open
    raise UnclosedTreeError(token_iterator.get_mark())
|
| 116 |
+
|
| 117 |
+
LINE_SPLIT_RE = re.compile(r"([()])")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class TokenIterator:
    """
    A specific iterator for reading trees from a tree file

    The idea is that this will keep track of which line
    we are processing, so that an error can be logged
    from the correct line

    Subclasses must provide self.line_iterator (an iterator of lines).
    """
    def __init__(self):
        # empty token iterator so the first __next__ advances line_iterator
        self.token_iterator = iter([])
        self.line_num = -1
        self.mark = None

    def set_mark(self):
        """
        The mark is used for determining where the start of a tree occurs for an error
        """
        self.mark = self.line_num

    def get_mark(self):
        if self.mark is None:
            raise ValueError("No mark set!")
        return self.mark

    def __iter__(self):
        return self

    def __next__(self):
        # drain the current line's tokens; when exhausted, pull the next
        # non-blank line and split it into (, ), and other tokens
        n = next(self.token_iterator, None)
        while n is None:
            self.line_num = self.line_num + 1
            line = next(self.line_iterator)
            if line is None:
                raise StopIteration
            line = line.strip()
            if not line:
                continue

            pieces = LINE_SPLIT_RE.split(line)
            pieces = [x.strip() for x in pieces]
            pieces = [x for x in pieces if x]
            self.token_iterator = iter(pieces)
            n = next(self.token_iterator, None)

        return n
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class TextTokenIterator(TokenIterator):
    """TokenIterator backed by an in-memory string, with a progress bar for large inputs"""
    def __init__(self, text, use_tqdm=True):
        super().__init__()

        self.lines = text.split("\n")
        self.num_lines = len(self.lines)
        # only wrap in tqdm when there is enough input to be worth a bar
        show_progress = use_tqdm and self.num_lines > 1000
        source = tqdm(self.lines) if show_progress else self.lines
        self.line_iterator = iter(source)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class FileTokenIterator(TokenIterator):
    """
    TokenIterator which reads lines from a file

    Must be used as a context manager so the file gets closed:
        with FileTokenIterator(filename) as token_iterator: ...
    """
    def __init__(self, filename):
        super().__init__()
        self.filename = filename
        # BUGFIX: initialize here so __exit__ does not raise AttributeError
        # when __enter__ was never called or failed partway through
        self.file_obj = None

    def __enter__(self):
        # TODO: use the file_size instead of counting the lines
        # file_size = Path(self.filename).stat().st_size
        # explicit utf-8, consistent with the other file reads in this package
        with open(self.filename, encoding='utf-8') as fin:
            num_lines = sum(1 for _ in fin)

        self.file_obj = open(self.filename, encoding='utf-8')
        if num_lines > 1000:
            self.line_iterator = iter(tqdm(self.file_obj, total=num_lines))
        else:
            self.line_iterator = iter(self.file_obj)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if self.file_obj:
            self.file_obj.close()
|
| 200 |
+
|
| 201 |
+
def read_token_iterator(token_iterator, broken_ok, tree_callback):
    """
    Read every tree from the token_iterator, returning a list of Tree

    broken_ok: passed through to read_single_tree to relax error checks
    tree_callback: if given, applied to each tree; a non-None return value
      replaces the tree, a None return value drops the tree entirely
    """
    trees = []
    token = next(token_iterator, None)
    while token:
        if token == OPEN_PAREN:
            next_tree = read_single_tree(token_iterator, broken_ok=broken_ok)
            if next_tree is None:
                raise ValueError("Tree reader somehow created a None tree! Line number %d" % token_iterator.line_num)
            if tree_callback is not None:
                transformed = tree_callback(next_tree)
                if transformed is not None:
                    trees.append(transformed)
            else:
                trees.append(next_tree)
            token = next(token_iterator, None)
        elif token == CLOSE_PAREN:
            # a close paren at the top level means unbalanced brackets
            raise ExtraCloseTreeError(token_iterator.line_num)
        else:
            raise ValueError("Tree document had text between trees! Line number %d" % token_iterator.line_num)

    return trees
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def read_trees(text, broken_ok=False, tree_callback=None, use_tqdm=True):
    """
    Read all of the trees found in the given text

    TODO: some of the error cases we hit can be recovered from
    """
    tokens = TextTokenIterator(text, use_tqdm)
    trees = read_token_iterator(tokens, broken_ok=broken_ok, tree_callback=tree_callback)
    return trees
|
| 232 |
+
|
| 233 |
+
def read_tree_file(filename, broken_ok=False, tree_callback=None):
    """
    Read every tree contained in the given file
    """
    with FileTokenIterator(filename) as tokens:
        # returning from inside the with still closes the file
        return read_token_iterator(tokens, broken_ok=broken_ok, tree_callback=tree_callback)
|
| 240 |
+
|
| 241 |
+
def read_directory(dirname, broken_ok=False, tree_callback=None):
    """
    Read all of the trees in all of the files in a directory

    Files are visited in sorted order so results are reproducible
    """
    collected = []
    for entry in sorted(os.listdir(dirname)):
        path = os.path.join(dirname, entry)
        collected.extend(read_tree_file(path, broken_ok, tree_callback))
    return collected
|
| 250 |
+
|
| 251 |
+
def read_treebank(filename, tree_callback=None):
    """
    Read a treebank and alter the trees to be a simpler format for learning to parse
    """
    logger.info("Reading trees from %s", filename)
    trees = read_tree_file(filename, tree_callback=tree_callback)
    # normalize: prune_none / simplify_labels clean up the raw treebank trees
    trees = [t.prune_none().simplify_labels() for t in trees]

    # trees with more than one child at the root are rejected
    illegal_trees = [t for t in trees if len(t.children) > 1]
    if len(illegal_trees) > 0:
        raise ValueError("Found {} tree(s) which had non-unary transitions at the ROOT. First illegal tree: {:P}".format(len(illegal_trees), illegal_trees[0]))

    return trees
|
| 264 |
+
|
| 265 |
+
def main():
    """
    Reads a sample tree
    """
    sample = "( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    print(read_trees(sample))

if __name__ == '__main__':
    main()
|
stanza/stanza/models/constituency/tree_stack.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A utility class for keeping track of intermediate parse states
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import namedtuple
|
| 6 |
+
|
| 7 |
+
class TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])):
    """
    A stack which can branch in several directions, as long as you
    keep track of the branching heads

    An example usage is when K constituents are removed at once
    to create a new constituent, and then the LSTM which tracks the
    values of the constituents is updated starting from the Kth
    output of the LSTM with the new value.

    We don't simply keep track of a single stack object using a deque
    because versions of the parser which use a beam will want to be
    able to branch in different directions from the same base stack

    Another possible usage is if an oracle is used for training
    in a manner where some fraction of steps are non-gold steps,
    but we also want to take a gold step from the same state.
    Eg, parser gets to state X, wants to make incorrect transition T
    instead of gold transition G, and so we continue training both
    X+G and X+T.  If we only represent the state X with standard
    python stacks, it would not be possible to track both of these
    states at the same time without copying the entire thing.

    Value can be as transition, a word, or a partially built constituent

    Implemented as a namedtuple to make it a bit more efficient
    """
    def pop(self):
        # the parent node IS the stack with the top removed - no mutation
        return self.parent

    def push(self, value):
        """Return a new head node on top of this stack; self is unchanged."""
        return TreeStack(value, self, self.length + 1)

    def __iter__(self):
        # walk from the head down to the root, yielding each value
        node = self
        while node is not None:
            yield node.value
            node = node.parent

    def __reversed__(self):
        # materialize top-down, then walk back up
        values = list(self)
        for item in values[::-1]:
            yield item

    def __str__(self):
        return "TreeStack(%s)" % ", ".join([str(node_value) for node_value in self])

    def __len__(self):
        return self.length
|
stanza/stanza/models/constituency/utils.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Collects a few of the conparser utility methods which don't belong elsewhere
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import Counter
|
| 6 |
+
import logging
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch import optim
|
| 11 |
+
|
| 12 |
+
from stanza.models.common.doc import TEXT, Document
|
| 13 |
+
from stanza.models.common.utils import get_optimizer
|
| 14 |
+
from stanza.models.constituency.base_model import SimpleModel
|
| 15 |
+
from stanza.models.constituency.parse_transitions import TransitionScheme
|
| 16 |
+
from stanza.models.constituency.parse_tree import Tree
|
| 17 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 18 |
+
|
| 19 |
+
tqdm = get_tqdm()
|
| 20 |
+
|
| 21 |
+
# Empirically chosen per-optimizer defaults for the constituency parser.
# Keys are lowercase optimizer names as accepted by build_optimizer.
DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0000007 , "mirror_madgrad": 0.00005 }
DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 }
DEFAULT_LEARNING_RHO = 0.9
DEFAULT_MOMENTUM = { "madgrad": 0.9, "mirror_madgrad": 0.9, "sgd": 0.9 }

tlogger = logging.getLogger('stanza.constituency.trainer')

# madgrad experiment for weight decay
# with learning_rate set to 0.0000007 and momentum 0.9
# on en_wsj, with a baseline model trained on adadela for 200,
# then madgrad used to further improve that model
# 0.00000002.out: 0.9590347746438835
# 0.00000005.out: 0.9591378819960182
# 0.0000001.out: 0.9595450596319405
# 0.0000002.out: 0.9594603134479271
# 0.0000005.out: 0.9591317672706594
# 0.000001.out: 0.9592548741021389
# 0.000002.out: 0.9598395477013945
# 0.000003.out: 0.9594974271553495
# 0.000004.out: 0.9596665982603754
# 0.000005.out: 0.9591620720706487
DEFAULT_WEIGHT_DECAY = { "adamw": 0.05, "adadelta": 0.02, "sgd": 0.01, "adabelief": 1.2e-6, "madgrad": 2e-6, "mirror_madgrad": 2e-6 }
|
| 43 |
+
|
| 44 |
+
def retag_tags(doc, pipelines, xpos):
    """
    Returns a list of list of tags for the items in doc

    doc can be anything which feeds into the pipeline(s)
    pipelines are a list of 1 or more retag pipelines
    if multiple pipelines are given, majority vote wins
    """
    per_pipeline = []
    for pipe in pipelines:
        # each pipeline processes the (possibly already annotated) document in turn
        doc = pipe(doc)
        extract = (lambda word: word.xpos) if xpos else (lambda word: word.upos)
        per_pipeline.append([[extract(word) for word in sentence.words]
                             for sentence in doc.sentences])
    # per_pipeline holds N tag lists (one per pipeline), each covering S sentences.
    # Zipping the pipelines, then the sentences, lines up the N candidate tags
    # for every single token; the most common candidate wins the vote.
    voted = []
    for sentence_votes in zip(*per_pipeline):
        voted.append([Counter(candidates).most_common(1)[0][0]
                      for candidates in zip(*sentence_votes)])
    return voted
|
| 64 |
+
|
| 65 |
+
def retag_trees(trees, pipelines, xpos=True):
    """
    Retag all of the trees using the given processor

    trees: a list of parse trees with yield_preterminals() / replace_tags()
    pipelines: one or more retagging pipelines (majority vote via retag_tags)
    xpos: if True use xpos tags, otherwise upos

    Returns a list of new trees
    Raises ValueError if a tree cannot be converted or retagged
    """
    if len(trees) == 0:
        return trees

    new_trees = []
    # process in chunks so a single huge Document is never built at once
    chunk_size = 1000
    with tqdm(total=len(trees)) as pbar:
        for chunk_start in range(0, len(trees), chunk_size):
            chunk_end = min(chunk_start + chunk_size, len(trees))
            chunk = trees[chunk_start:chunk_end]
            sentences = []
            try:
                for idx, tree in enumerate(chunk):
                    # one token dict per preterminal's word, pretokenized input
                    tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()]
                    sentences.append(tokens)
            except ValueError as e:
                # idx is still bound to the offending tree within the chunk
                raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e

            doc = Document(sentences)
            tag_lists = retag_tags(doc, pipelines, xpos)

            for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)):
                try:
                    if any(tag is None for tag in tags):
                        raise RuntimeError("Tagged tree #{} with a None tag!\n{}\n{}".format(tree_idx, tree, tags))
                    new_tree = tree.replace_tags(tags)
                    new_trees.append(new_tree)
                    pbar.update(1)
                except ValueError as e:
                    raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e
    # sanity check: every input tree must have produced exactly one output tree
    if len(new_trees) != len(trees):
        raise AssertionError("Retagged tree counts did not match: {} vs {}".format(len(new_trees), len(trees)))
    return new_trees
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# experimental results on nonlinearities
|
| 106 |
+
# this is on a VI dataset, VLSP_22, using 1/10th of the data as a dev set
|
| 107 |
+
# (no released test set at the time of the experiment)
|
| 108 |
+
# original non-Bert tagger, with 1 iteration each instead of averaged over 5
|
| 109 |
+
# considering the number of experiments and the length of time they would take
|
| 110 |
+
#
|
| 111 |
+
# Gelu had the highest score, which tracks with other experiments run.
|
| 112 |
+
# Note that publicly released models have typically used Relu
|
| 113 |
+
# on account of the runtime speed improvement
|
| 114 |
+
#
|
| 115 |
+
# Anyway, a larger experiment of 5x models on gelu or relu, using the
|
| 116 |
+
# Roberta POS tagger and a corpus of silver trees, resulted in 0.8270
|
| 117 |
+
# for relu and 0.8248 for gelu. So it is not even clear that
|
| 118 |
+
# switching to gelu would be an accuracy improvement.
|
| 119 |
+
#
|
| 120 |
+
# Gelu: 82.32
|
| 121 |
+
# Relu: 82.14
|
| 122 |
+
# Mish: 81.95
|
| 123 |
+
# Relu6: 81.91
|
| 124 |
+
# Silu: 81.90
|
| 125 |
+
# ELU: 81.73
|
| 126 |
+
# Hardswish: 81.67
|
| 127 |
+
# Softsign: 81.63
|
| 128 |
+
# Hardtanh: 81.44
|
| 129 |
+
# Celu: 81.43
|
| 130 |
+
# Selu: 81.17
|
| 131 |
+
# TODO: need to redo the prelu experiment with
|
| 132 |
+
# possibly different numbers of parameters
|
| 133 |
+
# and proper weight decay
|
| 134 |
+
# Prelu: 80.95 (terminated early)
|
| 135 |
+
# Softplus: 80.94
|
| 136 |
+
# Logsigmoid: 80.91
|
| 137 |
+
# Hardsigmoid: 79.03
|
| 138 |
+
# RReLU: 77.00
|
| 139 |
+
# Hardshrink: failed
|
| 140 |
+
# Softshrink: failed
|
| 141 |
+
# maps a lowercase nonlinearity name to the corresponding torch.nn module class
NONLINEARITY = {
    'celu': nn.CELU,
    'elu': nn.ELU,
    'gelu': nn.GELU,
    'hardshrink': nn.Hardshrink,
    'hardtanh': nn.Hardtanh,
    'leaky_relu': nn.LeakyReLU,
    'logsigmoid': nn.LogSigmoid,
    'prelu': nn.PReLU,
    'relu': nn.ReLU,
    'relu6': nn.ReLU6,
    'rrelu': nn.RReLU,
    'selu': nn.SELU,
    'softplus': nn.Softplus,
    'softshrink': nn.Softshrink,
    'softsign': nn.Softsign,
    'tanhshrink': nn.Tanhshrink,
    'tanh': nn.Tanh,
}

# separating these out allows for backwards compatibility with earlier versions of pytorch
# NOTE torch compatibility: if we ever *release* models with these
# activation functions, we will need to break that compatibility

nonlinearity_list = [
    'GLU',
    'Hardsigmoid',
    'Hardswish',
    'Mish',
    'SiLU',
]

# only register the activations this pytorch version actually provides
NONLINEARITY.update({name.lower(): getattr(nn, name)
                     for name in nonlinearity_list
                     if hasattr(nn, name)})

def build_nonlinearity(nonlinearity):
    """
    Look up "nonlinearity" in a map from function name to function, build the appropriate layer.

    Raises ValueError if the name is unknown.
    """
    builder = NONLINEARITY.get(nonlinearity)
    if builder is None:
        raise ValueError('Chosen value of nonlinearity, "%s", not handled' % nonlinearity)
    return builder()
|
| 184 |
+
|
| 185 |
+
def build_optimizer(args, model, build_simple_adadelta=False):
    """
    Build an optimizer based on the arguments given

    If we are "multistage" training and epochs_trained < epochs // 2,
    we build an AdaDelta optimizer instead of whatever was requested
    The build_simple_adadelta parameter controls this
    """
    bert_weight_decay = args['bert_weight_decay']
    if build_simple_adadelta:
        # stage 1 of multistage training: fixed AdaDelta settings
        optim_type = 'adadelta'
        finetune = args.get('stage1_bert_finetune', False)
        bert_learning_rate = args['stage1_bert_learning_rate'] if finetune else 0.0
        learning_rate = args['stage1_learning_rate']
        learning_beta2 = 0.999               # doesn't matter for AdaDelta
        learning_eps = DEFAULT_LEARNING_EPS['adadelta']
        learning_rho = DEFAULT_LEARNING_RHO
        momentum = None                      # also doesn't matter for AdaDelta
        weight_decay = DEFAULT_WEIGHT_DECAY['adadelta']
    else:
        # regular training: everything comes from the user-supplied args
        optim_type = args['optim'].lower()
        finetune = args.get('bert_finetune', False)
        bert_learning_rate = args['bert_learning_rate'] if finetune else 0.0
        learning_rate = args['learning_rate']
        learning_beta2 = args['learning_beta2']
        learning_eps = args['learning_eps']
        learning_rho = args['learning_rho']
        momentum = args['learning_momentum']
        weight_decay = args['learning_weight_decay']

    # TODO: allow rho as an arg for AdaDelta
    return get_optimizer(name=optim_type,
                         model=model,
                         lr=learning_rate,
                         betas=(0.9, learning_beta2),
                         eps=learning_eps,
                         momentum=momentum,
                         weight_decay=weight_decay,
                         bert_learning_rate=bert_learning_rate,
                         # bert weight decay is expressed relative to the base decay
                         bert_weight_decay=weight_decay * bert_weight_decay,
                         is_peft=args.get('use_peft', False),
                         bert_finetune_layers=args['bert_finetune_layers'],
                         opt_logger=tlogger)
|
| 231 |
+
|
| 232 |
+
def build_scheduler(args, optimizer, first_optimizer=False):
    """
    Build the scheduler for the conparser based on its args

    Used to use a warmup for learning rate, but that wasn't working very well
    Now, we just use a ReduceLROnPlateau, which does quite well
    """
    # the stage 1 optimizer has its own floor on the learning rate
    if first_optimizer:
        min_lr = args['stage1_learning_rate_min_lr']
    else:
        min_lr = args['learning_rate_min_lr']
    return optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                mode='max',
                                                factor=args['learning_rate_factor'],
                                                patience=args['learning_rate_patience'],
                                                cooldown=args['learning_rate_cooldown'],
                                                min_lr=min_lr)
|
| 256 |
+
|
| 257 |
+
def initialize_linear(linear, nonlinearity, bias):
    """
    Initializes the bias to a positive value, hopefully preventing dead neurons

    Only relu-style nonlinearities are touched; all other layers are
    left with pytorch's default initialization.
    """
    if nonlinearity not in ('relu', 'leaky_relu'):
        return
    nn.init.kaiming_normal_(linear.weight, nonlinearity=nonlinearity)
    # small positive uniform bias; the upper bound shrinks with fan-in
    upper = 1 / (bias * 2) ** 0.5
    nn.init.uniform_(linear.bias, 0, upper)
|
| 264 |
+
|
| 265 |
+
def add_predict_output_args(parser):
    """
    Args specifically for the output location of data
    """
    parser.add_argument('--predict_dir', type=str, default=".",
                        help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging')
    parser.add_argument('--predict_file', type=str, default=None,
                        help='Base name for writing predictions')
    parser.add_argument('--predict_format', type=str, default="{:_O}",
                        help='Format to use when writing predictions')
    parser.add_argument('--predict_output_gold_tags', default=False, action='store_true',
                        help='Output gold tags as part of the evaluation - useful for putting the trees through EvalB')
|
| 274 |
+
|
| 275 |
+
def postprocess_predict_output_args(args):
    """
    Wrap a bare format spec (such as "_O" or "LVi") in "{:...}" braces.

    Longer values are assumed to already be complete format strings
    and are left untouched.
    """
    spec = args['predict_format']
    bare = len(spec) <= 2 or (len(spec) <= 3 and spec.endswith("Vi"))
    if bare:
        args['predict_format'] = "{:" + spec + "}"
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def get_open_nodes(trees, transition_scheme):
    """
    Return a list of all open nodes in the given dataset.
    Depending on the parameters, may be single or compound open transitions.
    """
    if transition_scheme is TransitionScheme.IN_ORDER_COMPOUND:
        return Tree.get_compound_constituents(trees, separate_root=True)
    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:
        return Tree.get_compound_constituents(trees)
    # non-compound schemes: each label is its own single-element transition
    return [(label,) for label in Tree.get_unique_constituent_labels(trees)]
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def verify_transitions(trees, sequences, transition_scheme, unary_limit, reverse, name, root_labels):
    """
    Given a list of trees and their transition sequences, verify that the sequences rebuild the trees

    trees / sequences: parallel lists; name is used in error messages
    reverse: if True, the rebuilt constituent is reversed before comparing
    Raises RuntimeError on the first illegal transition or mismatched rebuild
    """
    # a minimal model which can apply transitions without any NN machinery
    model = SimpleModel(transition_scheme, unary_limit, reverse, root_labels)
    tlogger.info("Verifying the transition sequences for %d trees", len(trees))

    data = zip(trees, sequences)
    # only pay for a progress bar when INFO logging is actually visible
    if tlogger.getEffectiveLevel() <= logging.INFO:
        data = tqdm(zip(trees, sequences), total=len(trees))

    for tree_idx, (tree, sequence) in enumerate(data):
        # TODO: make the SimpleModel have a parse operation?
        state = model.initial_state_from_gold_trees([tree])[0]
        for idx, trans in enumerate(sequence):
            # legality must be checked before applying, as apply assumes legality
            if not trans.is_legal(state, model):
                raise RuntimeError("Tree {} of {} failed: transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(tree_idx, name, idx, trans, tree, sequence))
            state = trans.apply(state, model)
        result = model.get_top_constituent(state.constituents)
        if reverse:
            result = result.reverse()
        if tree != result:
            raise RuntimeError("Tree {} of {} failed: transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree_idx, name, tree, sequence, result))
|
| 316 |
+
|
| 317 |
+
def check_constituents(train_constituents, trees, treebank_name, fail=True):
    """
    Check that all the constituents in the other dataset are known in the train set

    If fail is True, an unknown constituent raises RuntimeError;
    otherwise a warning is issued instead.
    """
    for con in Tree.get_unique_constituent_labels(trees):
        if con in train_constituents:
            continue
        # count how many trees use the unknown label and remember the first one
        first_error = None
        num_errors = 0
        for tree_idx, tree in enumerate(trees):
            tree_labels = Tree.get_unique_constituent_labels(tree)
            if con in tree_labels:
                num_errors += 1
                if first_error is None:
                    first_error = tree_idx
        error = "Found constituent label {} in the {} set which don't exist in the train set. This constituent label occured in {} trees, with the first tree index at {} counting from 1\nThe error tree (which may have POS tags changed from the retagger and may be missing functional tags or empty nodes) is:\n{:P}".format(con, treebank_name, num_errors, (first_error+1), trees[first_error])
        if fail:
            raise RuntimeError(error)
        warnings.warn(error)
|
| 337 |
+
|
| 338 |
+
def check_root_labels(root_labels, other_trees, treebank_name):
    """
    Check that all the root states in the other dataset are known in the train set

    Raises RuntimeError on the first unknown root state.
    """
    unknown = [state for state in Tree.get_root_labels(other_trees)
               if state not in root_labels]
    if unknown:
        raise RuntimeError("Found root state {} in the {} set which is not a ROOT state in the train set".format(unknown[0], treebank_name))
|
| 345 |
+
|
| 346 |
+
def remove_duplicate_trees(trees, treebank_name):
    """
    Filter duplicates from the given dataset

    Trees are compared by their string representation; the first
    occurrence of each tree is kept, preserving the original order.
    """
    seen = set()
    unique = []
    for tree in trees:
        key = "{}".format(tree)
        if key in seen:
            continue
        seen.add(key)
        unique.append(tree)
    if len(unique) < len(trees):
        tlogger.info("Filtered %d duplicates from %s dataset", (len(trees) - len(unique)), treebank_name)
    return unique
|
| 361 |
+
|
| 362 |
+
def remove_singleton_trees(trees):
    """
    remove trees which are just a root and a single word

    A tree is kept if it branches within the top two levels, or if its
    unary chain from the root is at least three nodes deep.

    TODO: remove these trees in the conversion instead of here
    """
    def _has_structure(tree):
        # multiple children anywhere in the top two levels means real structure
        if len(tree.children) > 1:
            return True
        if len(tree.children) != 1:
            return False
        child = tree.children[0]
        if len(child.children) > 1:
            return True
        return len(child.children) == 1 and len(child.children[0].children) >= 1

    kept = [tree for tree in trees if _has_structure(tree)]
    if len(trees) - len(kept) > 0:
        tlogger.info("Eliminated %d trees with missing structure", (len(trees) - len(kept)))
    return kept
|
| 375 |
+
|
stanza/stanza/models/coref/predict.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from stanza.models.coref.model import CorefModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("experiment")
    argparser.add_argument("input_file")
    argparser.add_argument("output_file")
    argparser.add_argument("--config-file", default="config.toml")
    argparser.add_argument("--batch-size", type=int,
                           help="Adjust to override the config value if you're"
                                " experiencing out-of-memory issues")
    argparser.add_argument("--weights",
                           help="Path to file with weights to load."
                                " If not supplied, the latest"
                                " weights of the experiment will be loaded;"
                                " if there aren't any, an error is raised.")
    args = argparser.parse_args()

    # load on CPU and skip the training-only state (optimizers/schedulers)
    model = CorefModel.load_model(path=args.weights,
                                  map_location="cpu",
                                  ignore={"bert_optimizer", "general_optimizer",
                                          "bert_scheduler", "general_scheduler"})
    if args.batch_size:
        model.config.a_scoring_batch_size = args.batch_size
    model.training = False

    try:
        with open(args.input_file, encoding="utf-8") as fin:
            input_data = json.load(fin)
    except json.decoder.JSONDecodeError:
        # read the old jsonlines format if necessary
        with open(args.input_file, encoding="utf-8") as fin:
            text = "[" + ",\n".join(fin) + "]"
            input_data = json.loads(text)
    docs = [model.build_doc(doc) for doc in input_data]

    with torch.no_grad():
        for doc in tqdm(docs, unit="docs"):
            result = model.run(doc)
            doc["span_clusters"] = result.span_clusters
            doc["word_clusters"] = result.word_clusters

            # drop intermediate tokenization artifacts before writing
            for key in ("word2subword", "subwords", "word_id", "head2span"):
                del doc[key]

    with open(args.output_file, mode="w", encoding="utf-8") as fout:
        for doc in docs:
            json.dump(doc, fout)
            # BUG FIX: without a separator this wrote a single stream of
            # concatenated JSON objects which no parser could read.  One
            # document per line (jsonlines) matches the fallback input
            # reader above, which joins lines with ",\n".
            fout.write("\n")
|
stanza/stanza/models/coref/span_predictor.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Describes SpanPredictor which aims to predict spans by taking as input
|
| 2 |
+
head word and context embeddings.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from stanza.models.coref.const import Doc, Span
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SpanPredictor(torch.nn.Module):
    """Predicts start/end span boundaries for each mention head word.

    Scores every word of the head's sentence as a candidate start and
    end of the span headed by that word.
    """
    def __init__(self, input_size: int, distance_emb_size: int):
        # NOTE(review): the "+ 64" in the first Linear assumes
        # distance_emb_size == 64 — confirm with callers
        super().__init__()
        # pairwise scorer over [head_emb; candidate_emb; distance_emb]
        self.ffnn = torch.nn.Sequential(
            torch.nn.Linear(input_size * 2 + 64, input_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(input_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 64),
        )
        # 1d convolutions over the sentence let neighboring candidates
        # influence each other; final 2 channels = (start, end) scores
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1),
            torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        self.emb = torch.nn.Embedding(128, distance_emb_size) # [-63, 63] + too_far

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.ffnn.parameters()).device

    def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
                doc: Doc,
                words: torch.Tensor,
                heads_ids: torch.Tensor) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids

        Args:
            doc (Doc): the document data
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]
            heads_ids (torch.Tensor): word indices of span heads

        Returns:
            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
        """
        # Obtain distance embedding indices, [n_heads, n_words]
        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
        emb_ids = relative_positions + 63   # make all valid distances positive
        # "+" on bool tensors acts as logical OR here
        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127 # "too_far"

        # Obtain "same sentence" boolean mask, [n_heads, n_words]
        sent_id = torch.tensor(doc["sent_id"], device=words.device)
        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))

        # To save memory, only pass candidates from one sentence for each head
        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
        # for each candidate among the words in the same sentence as span_head
        # [n_heads, input_size * 2 + distance_emb_size]
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat((
            words[heads_ids[rows]],
            words[cols],
            self.emb(emb_ids[rows, cols]),
        ), dim=1)

        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
        padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]

        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
        # This is necessary to allow the convolution layer to look at several
        # word scores
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]

        # scatter sentence-local scores back to full-document positions;
        # words outside the head's sentence keep -inf
        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
        scores[rows, cols] = res[padding_mask]

        # Make sure that start <= head <= end during inference
        if not self.training:
            # log(0) = -inf masks out invalid positions, log(1) = 0 keeps them
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
            valid_ends = torch.log((relative_positions <= 0).to(torch.float))
            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
            return scores + valid_positions
        return scores

    def get_training_data(self,
                          doc: Doc,
                          words: torch.Tensor
                          ) -> Tuple[Optional[torch.Tensor],
                                     Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """ Returns span starts/ends for gold mentions in the document.

        Returns (None, None) if the document has no gold head-to-span data.
        """
        head2span = sorted(doc["head2span"])
        if not head2span:
            return None, None
        heads, starts, ends = zip(*head2span)
        heads = torch.tensor(heads, device=self.device)
        starts = torch.tensor(starts, device=self.device)
        # stored ends are exclusive; shift to inclusive word indices
        # (predict() adds the 1 back when decoding)
        ends = torch.tensor(ends, device=self.device) - 1
        return self(doc, words, heads), (starts, ends)

    def predict(self,
                doc: Doc,
                words: torch.Tensor,
                clusters: List[List[int]]) -> List[List[Span]]:
        """
        Predicts span clusters based on the word clusters.

        Args:
            doc (Doc): the document data
            words (torch.Tensor): [n_words, emb_size] matrix containing
                embeddings for each of the words in the text
            clusters (List[List[int]]): a list of clusters where each cluster
                is a list of word indices

        Returns:
            List[List[Span]]: span clusters
        """
        if not clusters:
            return []

        heads_ids = torch.tensor(
            sorted(i for cluster in clusters for i in cluster),
            device=self.device
        )

        scores = self(doc, words, heads_ids)
        # best start index, and best (exclusive) end index, per head
        starts = scores[:, :, 0].argmax(dim=1).tolist()
        ends = (scores[:, :, 1].argmax(dim=1) + 1).tolist()

        head2span = {
            head: (start, end)
            for head, start, end in zip(heads_ids.tolist(), starts, ends)
        }

        return [[head2span[head] for head in cluster]
                for cluster in clusters]
|
stanza/stanza/models/coref/tokenizer_customization.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" This file defines functions used to modify the default behaviour
|
| 2 |
+
of transformers.AutoTokenizer. These changes are necessary, because some
|
| 3 |
+
tokenizers are meant to be used with raw text, while the OntoNotes documents
|
| 4 |
+
have already been split into words.
|
| 5 |
+
All the functions are used in coref_model.CorefModel._get_docs. """
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Filters out unwanted tokens produced by the tokenizer
# (keyed by huggingface model name; the value is a predicate a token
# must pass to be kept)
TOKENIZER_FILTERS = {
    "albert-xxlarge-v2": (lambda token: token != "▁"), # U+2581, not just "_"
    "albert-large-v2": (lambda token: token != "▁"),
}

# Maps some words to tokens directly, without a tokenizer
TOKENIZER_MAPS = {
    "roberta-large": {".": ["."], ",": [","], "!": ["!"], "?": ["?"],
                      ":":[":"], ";":[";"], "'s": ["'s"]}
}
|
stanza/stanza/models/coref/word_encoder.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Describes WordEncoder. Extracts mention vectors from bert-encoded text.
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from stanza.models.coref.config import Config
|
| 9 |
+
from stanza.models.coref.const import Doc
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class WordEncoder(torch.nn.Module): # pylint: disable=too-many-instance-attributes
|
| 13 |
+
""" Receives bert contextual embeddings of a text, extracts all the
|
| 14 |
+
possible mentions in that text. """
|
| 15 |
+
|
| 16 |
+
    def __init__(self, features: int, config: Config):
        """
        Args:
            features (int): the number of features in the input embeddings
            config (Config): the configuration of the current session
        """
        super().__init__()
        # one attention score per subword, used to pool subwords into a word
        self.attn = torch.nn.Linear(in_features=features, out_features=1)
        self.dropout = torch.nn.Dropout(config.dropout_rate)
|
| 25 |
+
|
| 26 |
+
    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.attn.parameters()).device
|
| 31 |
+
|
| 32 |
+
    def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
                doc: Doc,
                x: torch.Tensor,
                ) -> Tuple[torch.Tensor, ...]:
        """
        Extracts word representations from text.

        Args:
            doc: the document data
            x: a tensor containing bert output, shape (n_subtokens, bert_dim)

        Returns:
            words: a Tensor of shape [n_words, mention_emb];
                mention representations
            cluster_ids: tensor of shape [n_words], containing cluster indices
                for each word. Non-coreferent words have cluster id of zero.
        """
        # [n_words, 2]: (start, end) subword indices for each word
        word_boundaries = torch.tensor(doc["word2subword"], device=self.device)
        starts = word_boundaries[:, 0]
        ends = word_boundaries[:, 1]

        # [n_mentions, features]
        # attention-weighted average of each word's subword embeddings
        words = self._attn_scores(x, starts, ends).mm(x)

        words = self.dropout(words)

        return (words, self._cluster_ids(doc))
|
| 59 |
+
|
| 60 |
+
def _attn_scores(self,
|
| 61 |
+
bert_out: torch.Tensor,
|
| 62 |
+
word_starts: torch.Tensor,
|
| 63 |
+
word_ends: torch.Tensor) -> torch.Tensor:
|
| 64 |
+
""" Calculates attention scores for each of the mentions.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
bert_out (torch.Tensor): [n_subwords, bert_emb], bert embeddings
|
| 68 |
+
for each of the subwords in the document
|
| 69 |
+
word_starts (torch.Tensor): [n_words], start indices of words
|
| 70 |
+
word_ends (torch.Tensor): [n_words], end indices of words
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
torch.Tensor: [description]
|
| 74 |
+
"""
|
| 75 |
+
n_subtokens = len(bert_out)
|
| 76 |
+
n_words = len(word_starts)
|
| 77 |
+
|
| 78 |
+
# [n_mentions, n_subtokens]
|
| 79 |
+
# with 0 at positions belonging to the words and -inf elsewhere
|
| 80 |
+
attn_mask = torch.arange(0, n_subtokens, device=self.device).expand((n_words, n_subtokens))
|
| 81 |
+
attn_mask = ((attn_mask >= word_starts.unsqueeze(1))
|
| 82 |
+
* (attn_mask < word_ends.unsqueeze(1)))
|
| 83 |
+
attn_mask = torch.log(attn_mask.to(torch.float))
|
| 84 |
+
|
| 85 |
+
attn_scores = self.attn(bert_out).T # [1, n_subtokens]
|
| 86 |
+
attn_scores = attn_scores.expand((n_words, n_subtokens))
|
| 87 |
+
attn_scores = attn_mask + attn_scores
|
| 88 |
+
del attn_mask
|
| 89 |
+
return torch.softmax(attn_scores, dim=1) # [n_words, n_subtokens]
|
| 90 |
+
|
| 91 |
+
def _cluster_ids(self, doc: Doc) -> torch.Tensor:
|
| 92 |
+
"""
|
| 93 |
+
Args:
|
| 94 |
+
doc: document information
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
torch.Tensor of shape [n_word], containing cluster indices for
|
| 98 |
+
each word. Non-coreferent words have cluster id of zero.
|
| 99 |
+
"""
|
| 100 |
+
word2cluster = {word_i: i
|
| 101 |
+
for i, cluster in enumerate(doc["word_clusters"], start=1)
|
| 102 |
+
for word_i in cluster}
|
| 103 |
+
|
| 104 |
+
return torch.tensor(
|
| 105 |
+
[word2cluster.get(word_i, 0)
|
| 106 |
+
for word_i in range(len(doc["cased_words"]))],
|
| 107 |
+
device=self.device
|
| 108 |
+
)
|
stanza/stanza/models/depparse/data.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from stanza.models.common.bert_embedding import filter_data, needs_length_filter
|
| 6 |
+
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
|
| 7 |
+
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, ROOT_ID, CompositeVocab, CharVocab
|
| 8 |
+
from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab
|
| 9 |
+
from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory
|
| 10 |
+
from stanza.models.common.doc import *
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger('stanza')
|
| 13 |
+
|
| 14 |
+
def data_to_batches(data, batch_size, eval_mode, sort_during_eval, min_length_to_batch_separately):
    """
    Group sentences (sublists whose first element is the sentence) into batches.

    In training mode (not eval_mode) the sentences are sorted (roughly) by
    length, with the whole ordering randomly reversed half the time.  In eval
    mode they are sorted by length only when sort_during_eval is set.

    Refactored from the data structure in case other models could use
    it and for ease of testing.

    Returns (batches, original_order), where original_order is None when in
    train mode or when unsorted, and otherwise records where each sentence
    was before the sort.
    """
    if not eval_mode:
        # rough length sort for better memory utilization; coin-flip the
        # direction so training does not always see the same order
        data = sorted(data, key=lambda sent: len(sent[0]), reverse=random.random() > .5)
        data_orig_idx = None
    elif sort_during_eval:
        (data, ), data_orig_idx = sort_all([data], [len(sent[0]) for sent in data])
    else:
        data_orig_idx = None

    batches = []
    bucket = []
    bucket_len = 0

    def flush():
        # close out the current batch, if it has anything in it
        nonlocal bucket, bucket_len
        if bucket_len > 0:
            batches.append(bucket)
            bucket = []
            bucket_len = 0

    for sent in data:
        sent_len = len(sent[0])
        if min_length_to_batch_separately is not None and sent_len > min_length_to_batch_separately:
            # overly long sentences each go in a batch of their own
            flush()
            batches.append([sent])
        else:
            if sent_len + bucket_len > batch_size:
                flush()
            bucket.append(sent)
            bucket_len += sent_len

    flush()
    return batches, data_orig_idx
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DataLoader:
    """Batching data loader for the dependency parser.

    Loads sentences from a Doc, maps the raw fields to vocab ids, and serves
    batches of padded tensors via indexing / iteration.

    NOTE(review): `args` is used like a dict ('shorthand', 'pretrain',
    optional 'bert_model', 'sample_train', ...) -- confirm against the
    parser's argument setup.
    """

    def __init__(self, doc, batch_size, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, min_length_to_batch_separately=None, bert_tokenizer=None):
        self.batch_size = batch_size
        self.min_length_to_batch_separately=min_length_to_batch_separately
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.sort_during_eval = sort_during_eval
        self.doc = doc
        data = self.load_doc(doc)

        # handle vocab
        if vocab is None:
            # training: build the vocabs from this data
            self.vocab = self.init_vocab(data)
        else:
            self.vocab = vocab

        # filter out the long sentences if bert is used
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None
        self.pretrain_vocab = None
        if pretrain is not None and args['pretrain']:
            self.pretrain_vocab = pretrain.vocab

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)
        # shuffle for training
        if self.shuffled:
            random.shuffle(data)
        self.num_examples = len(data)

        # chunk into batches
        self.data = self.chunk_batches(data)
        logger.debug("{} batches created.".format(len(self.data)))

    def init_vocab(self, data):
        """Build the MultiVocab (char/word/upos/xpos/feats/lemma/deprel) from training data."""
        assert self.eval == False # for eval vocab must exist
        charvocab = CharVocab(data, self.args['shorthand'])
        # cutoff=7: rare words fall back to the unknown token
        wordvocab = WordVocab(data, self.args['shorthand'], cutoff=7, lower=True)
        uposvocab = WordVocab(data, self.args['shorthand'], idx=1)
        xposvocab = xpos_vocab_factory(data, self.args['shorthand'])
        featsvocab = FeatureVocab(data, self.args['shorthand'], idx=3)
        lemmavocab = WordVocab(data, self.args['shorthand'], cutoff=7, idx=4, lower=True)
        deprelvocab = WordVocab(data, self.args['shorthand'], idx=6)
        vocab = MultiVocab({'char': charvocab,
                            'word': wordvocab,
                            'upos': uposvocab,
                            'xpos': xposvocab,
                            'feats': featsvocab,
                            'lemma': lemmavocab,
                            'deprel': deprelvocab})
        return vocab

    def preprocess(self, data, vocab, pretrain_vocab, args):
        """Map each sentence's raw fields to id sequences, prepending a ROOT token."""
        processed = []
        # ROOT placeholder must match the arity of a composite xpos/feats vocab
        xpos_replacement = [[ROOT_ID] * len(vocab['xpos'])] if isinstance(vocab['xpos'], CompositeVocab) else [ROOT_ID]
        feats_replacement = [[ROOT_ID] * len(vocab['feats'])]
        for sent in data:
            processed_sent = [[ROOT_ID] + vocab['word'].map([w[0] for w in sent])]
            processed_sent += [[[ROOT_ID]] + [vocab['char'].map([x for x in w[0]]) for w in sent]]
            processed_sent += [[ROOT_ID] + vocab['upos'].map([w[1] for w in sent])]
            processed_sent += [xpos_replacement + vocab['xpos'].map([w[2] for w in sent])]
            processed_sent += [feats_replacement + vocab['feats'].map([w[3] for w in sent])]
            if pretrain_vocab is not None:
                # always use lowercase lookup in pretrained vocab
                processed_sent += [[ROOT_ID] + pretrain_vocab.map([w[0].lower() for w in sent])]
            else:
                processed_sent += [[ROOT_ID] + [PAD_ID] * len(sent)]
            processed_sent += [[ROOT_ID] + vocab['lemma'].map([w[4] for w in sent])]
            # head indices; unparseable heads are tolerated (as 0) at eval time
            processed_sent += [[to_int(w[5], ignore_error=self.eval) for w in sent]]
            processed_sent += [vocab['deprel'].map([w[6] for w in sent])]
            processed_sent.append([w[0] for w in sent])
            processed.append(processed_sent)
        return processed

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        batch = list(zip(*batch))
        # 10 fields per sentence, as produced by preprocess()
        assert len(batch) == 10

        # sort sentences by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # sort words by lens for easy char-RNN operations
        batch_words = [w for sent in batch[1] for w in sent]
        word_lens = [len(x) for x in batch_words]
        batch_words, word_orig_idx = sort_all([batch_words], word_lens)
        batch_words = batch_words[0]
        word_lens = [len(x) for x in batch_words]

        # convert to tensors
        words = batch[0]
        words = get_long_tensor(words, batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(batch_words, len(word_lens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        upos = get_long_tensor(batch[2], batch_size)
        xpos = get_long_tensor(batch[3], batch_size)
        ufeats = get_long_tensor(batch[4], batch_size)
        pretrained = get_long_tensor(batch[5], batch_size)
        sentlens = [len(x) for x in batch[0]]
        lemma = get_long_tensor(batch[6], batch_size)
        head = get_long_tensor(batch[7], batch_size)
        deprel = get_long_tensor(batch[8], batch_size)
        text = batch[9]
        return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, orig_idx, word_orig_idx, sentlens, word_lens, text

    def load_doc(self, doc):
        """Extract raw per-word fields from the Doc, one list per sentence."""
        data = doc.get([TEXT, UPOS, XPOS, FEATS, LEMMA, HEAD, DEPREL], as_sentences=True)
        data = self.resolve_none(data)
        return data

    def resolve_none(self, data):
        # replace None to '_'
        for sent_idx in range(len(data)):
            for tok_idx in range(len(data[sent_idx])):
                for feat_idx in range(len(data[sent_idx][tok_idx])):
                    if data[sent_idx][tok_idx][feat_idx] is None:
                        data[sent_idx][tok_idx][feat_idx] = '_'
        return data

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def set_batch_size(self, batch_size):
        # takes effect on the next reshuffle / chunk_batches call
        self.batch_size = batch_size

    def reshuffle(self):
        """Flatten the current batches, rebatch, and shuffle the batch order."""
        data = [y for x in self.data for y in x]
        self.data = self.chunk_batches(data)
        random.shuffle(self.data)

    def chunk_batches(self, data):
        """Group preprocessed sentences into batches; records data_orig_idx for unsorting."""
        batches, data_orig_idx = data_to_batches(data=data, batch_size=self.batch_size,
                                                 eval_mode=self.eval, sort_during_eval=self.sort_during_eval,
                                                 min_length_to_batch_separately=self.min_length_to_batch_separately)
        # data_orig_idx might be None at train time, since we don't anticipate unsorting
        self.data_orig_idx = data_orig_idx
        return batches
|
| 223 |
+
|
| 224 |
+
def to_int(string, ignore_error=False):
    """Parse string as an int; if ignore_error is set, return 0 on a bad value
    instead of raising ValueError."""
    try:
        return int(string)
    except ValueError:
        if not ignore_error:
            raise
        return 0
|
| 233 |
+
|
stanza/stanza/models/lemma/attach_lemma_classifier.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from stanza.models.lemma.trainer import Trainer
|
| 4 |
+
from stanza.models.lemma_classifier.base_model import LemmaClassifier
|
| 5 |
+
|
| 6 |
+
def attach_classifier(input_filename, output_filename, classifiers):
    """Load a lemmatizer, append the given lemma classifiers, and save the result.

    Args:
        input_filename: path of the base lemmatizer model to load
        output_filename: path to write the augmented lemmatizer to
        classifiers: iterable of lemma classifier model filenames to attach
    """
    trainer = Trainer(model_file=input_filename)

    for classifier_filename in classifiers:
        trainer.contextual_lemmatizers.append(LemmaClassifier.load(classifier_filename))

    trainer.save(output_filename)
|
| 14 |
+
|
| 15 |
+
def main(args=None):
    """Command-line entry point: attach lemma classifiers to a saved lemmatizer."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, required=True, help='Which lemmatizer to start from')
    parser.add_argument('--output', type=str, required=True, help='Where to save the lemmatizer')
    parser.add_argument('--classifier', type=str, required=True, nargs='+', help='Lemma classifier to attach')
    parsed = parser.parse_args(args)

    attach_classifier(parsed.input, parsed.output, parsed.classifier)

if __name__ == '__main__':
    main()
|
stanza/stanza/models/lemma/scorer.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utils and wrappers for scoring lemmatizers.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from stanza.models.common.utils import ud_scores
|
| 6 |
+
|
| 7 |
+
def score(system_conllu_file, gold_conllu_file):
    """ Wrapper for lemma scorer; returns (precision, recall, f1) on Lemmas. """
    lemma_eval = ud_scores(gold_conllu_file, system_conllu_file)["Lemmas"]
    return lemma_eval.precision, lemma_eval.recall, lemma_eval.f1
|
| 13 |
+
|
stanza/stanza/models/lemma/vocab.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
|
| 3 |
+
from stanza.models.common.vocab import BaseVocab, BaseMultiVocab
|
| 4 |
+
from stanza.models.common.seq2seq_constant import VOCAB_PREFIX
|
| 5 |
+
|
| 6 |
+
class Vocab(BaseVocab):
    """Frequency-ordered lemmatizer vocabulary: most common units get the lowest ids."""
    def build_vocab(self):
        freq = Counter(self.data)
        by_freq = sorted(freq.keys(), key=lambda unit: freq[unit], reverse=True)
        self._id2unit = VOCAB_PREFIX + by_freq
        self._unit2id = {unit: idx for idx, unit in enumerate(self._id2unit)}
|
| 11 |
+
|
| 12 |
+
class MultiVocab(BaseMultiVocab):
    @classmethod
    def load_state_dict(cls, state_dict):
        """Rebuild a MultiVocab from a saved state dict, one Vocab per entry."""
        loaded = cls()
        for name, vocab_state in state_dict.items():
            loaded[name] = Vocab.load_state_dict(vocab_state)
        return loaded
|
stanza/stanza/models/lemma_classifier/base_trainer.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from typing import List, Tuple, Any, Mapping
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
|
| 11 |
+
from stanza.models.common.utils import default_device
|
| 12 |
+
from stanza.models.lemma_classifier import utils
|
| 13 |
+
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
|
| 14 |
+
from stanza.models.lemma_classifier.evaluate_models import evaluate_model
|
| 15 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 16 |
+
|
| 17 |
+
tqdm = get_tqdm()
|
| 18 |
+
logger = logging.getLogger('stanza.lemmaclassifier')
|
| 19 |
+
|
| 20 |
+
class BaseLemmaClassifierTrainer(ABC):
    """Shared training loop for lemma classifier models.

    NOTE(review): subclasses are expected to set self.weighted_loss, self.lr,
    and self.criterion before train() is called -- they are read here but
    never initialized in this class.
    """

    def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
        """
        If applicable, this function will update the loss function of the LemmaClassifierLSTM model to become BCEWithLogitsLoss.
        The weights are determined by the counts of the classes in the dataset. The weights are inversely proportional to the
        frequency of the class in the set. E.g. classes with lower frequency will have higher weight.
        """
        weights = [0 for _ in label_decoder.keys()] # each key in the label decoder is one class, we have one weight per class
        total_samples = sum(counts.values())
        for class_idx in counts:
            weights[class_idx] = total_samples / (counts[class_idx] * len(counts)) # weight_i = total / (# examples in class i * num classes)
        weights = torch.tensor(weights)
        logger.info(f"Using weights {weights} for weighted loss.")
        self.criterion = nn.BCEWithLogitsLoss(weight=weights)

    @abstractmethod
    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        """
        Build a model using pieces of the dataset to determine some of the model shape
        """

    def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, train_file: str) -> None:
        """
        Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.

        Args:
            num_epochs (int): Number of training epochs
            save_name (str): Path to file where trained model should be saved.
            args (Mapping): extra options; reads "batch_size" and "force" here.
            eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.
            train_file (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.

        Raises:
            ValueError: if no train_file is given
            FileExistsError: if save_name exists and args["force"] is not set
        """
        # Put model on GPU (if possible)
        device = default_device()

        if not train_file:
            raise ValueError("Cannot train model - no train_file supplied!")

        dataset = utils.Dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
        label_decoder = dataset.label_decoder
        upos_to_id = dataset.upos_to_id
        self.output_dim = len(label_decoder)
        logger.info(f"Loaded dataset successfully from {train_file}")
        logger.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")
        logger.info(f"Target words: {dataset.target_words}")

        self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words, set(dataset.target_upos))
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.model.to(device)
        logger.info(f"Training model on device: {device}. {next(self.model.parameters()).device}")

        # refuse to clobber an existing model unless explicitly forced
        if os.path.exists(save_name) and not args.get('force', False):
            raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. Aborting...")

        if self.weighted_loss:
            self.configure_weighted_loss(label_decoder, dataset.counts)

        # Put the criterion on GPU too
        logger.debug(f"Criterion on {next(self.model.parameters()).device}")
        self.criterion = self.criterion.to(next(self.model.parameters()).device)

        best_model, best_f1 = None, float("-inf") # Used for saving checkpoints of the model
        for epoch in range(num_epochs):
            # go over entire dataset with each epoch
            for sentences, positions, upos_tags, labels in tqdm(dataset):
                assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})"

                self.optimizer.zero_grad()
                outputs = self.model(positions, sentences, upos_tags)

                # Compute loss, which is different if using CE or BCEWithLogitsLoss
                if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
                    # TODO: three classes?
                    targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device)
                    # should be shape size (batch_size, 2)
                else: # CELoss accepts target as just raw label
                    targets = labels.to(device)

                loss = self.criterion(outputs, targets)

                loss.backward()
                self.optimizer.step()

            logger.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
            if eval_file:
                # Evaluate model on dev set to see if it should be saved.
                _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)
                logger.info(f"Weighted f1 for model: {f1}")
                if f1 > best_f1:
                    best_f1 = f1
                    self.model.save(save_name)
                    logger.info(f"New best model: weighted f1 score of {f1}.")
            else:
                # no dev set: save unconditionally every epoch
                self.model.save(save_name)
|
| 114 |
+
|
stanza/stanza/models/lemma_classifier/constants.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
UNKNOWN_TOKEN = "unk" # token name for unknown tokens
UNKNOWN_TOKEN_IDX = -1 # custom index we apply to unknown tokens

# TODO: ModelType could just be LSTM and TRANSFORMER
# and then the transformer baseline would have the transformer as another argument
class ModelType(Enum):
    """Kinds of lemma classifier architectures that can be built or loaded."""
    LSTM = 1
    TRANSFORMER = 2
    BERT = 3
    ROBERTA = 4

# default number of examples per training / evaluation batch
DEFAULT_BATCH_SIZE = 16
|
stanza/stanza/models/lemma_classifier/evaluate_many.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utils to evaluate many models of the same type at once
|
| 3 |
+
"""
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from stanza.models.lemma_classifier.evaluate_models import main as evaluate_main
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger('stanza.lemmaclassifier')
|
| 12 |
+
|
| 13 |
+
def evaluate_n_models(path_to_models_dir, args):
    """Evaluate every model file in a directory and average their scores.

    Previously the accumulator hard-coded the lemmas "be" and "have", so any
    other lemma reported by a model raised KeyError; this version accumulates
    whatever lemmas the evaluations return.  Also guards against an empty
    model directory (which used to divide by zero).

    Args:
        path_to_models_dir: directory containing saved model files
        args: parsed argparse namespace forwarded to evaluate_models.main;
              args.save_name is overwritten with each model's path in turn.

    Returns:
        dict mapping each lemma to its average f1 (scaled by 100), plus
        "accuracy" and "weighted_f1" averages over all evaluated models.

    Raises:
        ValueError: if the directory contains no model files
    """
    paths = os.listdir(path_to_models_dir)
    num_models = len(paths)
    if num_models == 0:
        raise ValueError(f"No models found in {path_to_models_dir}")

    total_results = {
        "accuracy": 0.0,
        "weighted_f1": 0.0
    }
    for model_path in paths:
        args.save_name = os.path.join(path_to_models_dir, model_path)
        mcc_results, confusion, acc, weighted_f1 = evaluate_main(predefined_args=args)

        # accumulate every lemma the model reports, not just a fixed pair
        for lemma, scores in mcc_results.items():
            total_results[lemma] = total_results.get(lemma, 0.0) + scores.get("f1") * 100

        total_results["accuracy"] += acc
        total_results["weighted_f1"] += weighted_f1

    for key in total_results:
        total_results[key] /= num_models

    lemma_summary = "\n".join(f"Lemma '{lemma}' had f1: {score}"
                              for lemma, score in total_results.items()
                              if lemma not in ("accuracy", "weighted_f1"))
    logger.info(f"Models in {path_to_models_dir} had average weighted f1 of {100 * total_results['weighted_f1']}.\n{lemma_summary}\nAccuracy: {100 * total_results['accuracy']}.\n ({num_models} models evaluated).")
    return total_results
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def main():
    """Command-line entry: average evaluation scores over every model in --base_path."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab")
    parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)")
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--eval_file", type=str, help="path to evaluation file")

    # Args specific to several model eval
    parser.add_argument("--base_path", type=str, default=None, help="path to dir for eval")

    args = parser.parse_args()
    # save_name is ignored as an input here; evaluate_n_models overwrites it per model
    evaluate_n_models(args.base_path, args)


if __name__ == "__main__":
    main()
|
stanza/stanza/models/lemma_classifier/evaluate_models.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
parentdir = os.path.dirname(__file__)
|
| 5 |
+
parentdir = os.path.dirname(parentdir)
|
| 6 |
+
parentdir = os.path.dirname(parentdir)
|
| 7 |
+
sys.path.append(parentdir)
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import argparse
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
from typing import Any, List, Tuple, Mapping
|
| 14 |
+
from collections import defaultdict
|
| 15 |
+
from numpy import random
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
import stanza
|
| 21 |
+
|
| 22 |
+
from stanza.models.common.utils import default_device
|
| 23 |
+
from stanza.models.lemma_classifier import utils
|
| 24 |
+
from stanza.models.lemma_classifier.base_model import LemmaClassifier
|
| 25 |
+
from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM
|
| 26 |
+
from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer
|
| 27 |
+
from stanza.utils.confusion import format_confusion
|
| 28 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 29 |
+
|
| 30 |
+
tqdm = get_tqdm()
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger('stanza.lemmaclassifier')
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: Mapping[int, Mapping[int, int]]) -> float:
    """
    Computes the weighted F1 score across an evaluation set.

    The weight of a class's F1 score is equal to the number of examples of
    that class in the evaluation set. This makes classes that have more
    examples in the evaluation more impactful to the weighted f1.

    Args:
        mcc_results: map from class id to a map containing (at least) an "f1" entry
        confusion: confusion matrix; maps a gold class id to {predicted class id: count}

    Returns:
        The weighted average F1 over all classes, or 0.0 if the evaluation
        set contains no examples at all.
    """
    num_total_examples = 0
    weighted_f1 = 0.0

    for class_id in mcc_results:
        class_f1 = mcc_results.get(class_id).get("f1")
        num_class_examples = sum(confusion.get(class_id).values())
        weighted_f1 += class_f1 * num_class_examples
        num_total_examples += num_class_examples

    # Guard against an empty evaluation set: previously this raised
    # ZeroDivisionError when mcc_results was empty or all counts were zero.
    if num_total_examples == 0:
        return 0.0
    return weighted_f1 / num_total_examples
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[Any], label_decoder: Mapping, verbose=True):
    """
    Scores predicted tags against gold tags, computing per-class precision,
    recall, and F1.

    Precision = TP / (TP + FP), Recall = TP / (TP + FN),
    F1 = 2 * P * R / (P + R); any division by zero yields 0.0.

    Returns:
        1. multi_class_result: maps each class to {"precision", "recall", "f1"},
           e.g. multi_class_result[0]["precision"] is class 0's precision.
        2. confusion: maps each gold tag to {predicted tag: count}, e.g.
           confusion[0][1] == 6 means gold tag 0 was predicted as tag 1 six times.
        3. weighted_f1: the example-count-weighted average F1.
    """
    assert len(gold_tag_sequences) == len(pred_tag_sequences), \
        f"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}"

    confusion = defaultdict(lambda: defaultdict(int))

    # invert label -> id into id -> label so the confusion matrix is keyed by label
    id_to_label = {idx: label for label, idx in label_decoder.items()}
    for gold_id, pred_id in zip(gold_tag_sequences, pred_tag_sequences):
        confusion[id_to_label[gold_id]][id_to_label[pred_id]] += 1

    multi_class_result = defaultdict(lambda: defaultdict(float))
    # precision, recall, f1 per class, with 0.0 standing in for 0/0
    for gold_tag in confusion:
        true_positives = confusion.get(gold_tag, {}).get(gold_tag, 0)
        predicted_count = sum(confusion.get(other, {}).get(gold_tag, 0) for other in confusion)
        gold_count = sum(confusion.get(gold_tag, {}).values())

        prec = true_positives / predicted_count if predicted_count else 0.0
        recall = true_positives / gold_count if gold_count else 0.0
        f1 = 2 * (prec * recall) / (prec + recall) if (prec + recall) else 0.0

        multi_class_result[gold_tag] = {
            "precision": prec,
            "recall": recall,
            "f1": f1
        }

    if verbose:
        for lemma in multi_class_result:
            logger.info(f"Lemma '{lemma}' had precision {100 * multi_class_result[lemma]['precision']}, recall {100 * multi_class_result[lemma]['recall']} and F1 score of {100 * multi_class_result[lemma]['f1']}")

    weighted_f1 = get_weighted_f1(multi_class_result, confusion)

    return multi_class_result, confusion, weighted_f1
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[int]] = None) -> torch.Tensor:
    """
    Runs a trained classifier on a batch of examples and returns the argmax class.

    Args:
        model (nn.Module): A trained LemmaClassifier (LSTM or transformer variant).
        position_indices (Tensor[int]): zero-indexed position of the target token
            in each sentence of the batch.
        sentences (List[List[str]]): tokenized input sentences.
        upos_tags (List[List[int]], optional): UPOS tag ids per sentence; defaults
            to an empty list when omitted.

    Returns:
        torch.Tensor: the predicted class index for each example, shape (batch_size,).
    """
    # Bug fix: the default used to be a mutable `[]` literal, which is shared
    # across calls. Use None as the sentinel and substitute a fresh list.
    if upos_tags is None:
        upos_tags = []
    with torch.no_grad():
        logits = model(position_indices, sentences, upos_tags)  # (batch_size, output_size)
        predicted_class = torch.argmax(logits, dim=1)  # (batch_size,)

    return predicted_class
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_training: bool = False) -> Tuple[Mapping, Mapping, float, float]:
    """
    Helper function for model evaluation

    Args:
        model (LemmaClassifierLSTM or LemmaClassifierWithTransformer): The trained model to evaluate; it is moved to the default device here.
        eval_path (str): Path to the saved evaluation dataset.
        verbose (bool, optional): True if `evaluate_sequences()` should print the F1, Precision, and Recall for each class. Defaults to True.
        is_training (bool, optional): Whether the model is in training mode. If the model is training, we do not change it to eval mode.

    Returns:
        1. Multi-class results (Mapping[int, Mapping[str, float]]): first map has keys as the classes (lemma indices) and value is
           another map with key of "f1", "precision", or "recall" with corresponding values.
        2. Confusion Matrix (Mapping[int, Mapping[int, int]]): A confusion matrix with keys equal to the index of the gold tag, and a value of the
           map with the key as the predicted tag and corresponding count of that (gold, pred) pair.
        3. Accuracy (float): the total accuracy (num correct / total examples) across the evaluation set.
        4. Weighted F1 (float): example-count-weighted average F1 from `evaluate_sequences`.
    """
    # move the model to the evaluation device before predicting
    device = default_device()
    model.to(device)

    if not is_training:
        model.eval() # set to eval mode

    # load in eval data; shuffle=False so dataset.labels stays aligned with the
    # prediction order accumulated in the loop below
    dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False)

    logger.info(f"Evaluating on evaluation file {eval_path}")

    correct, total = 0, 0
    # gold labels come straight from the dataset; predictions are collected per batch
    # NOTE(review): assumes dataset.labels is a flat list aligned with iteration order — confirm in utils.Dataset
    gold_tags, pred_tags = dataset.labels, []

    # run eval on each example from dataset
    for sentences, pos_indices, upos_tags, labels in tqdm(dataset, "Evaluating examples from data file"):
        pred = model_predict(model, pos_indices, sentences, upos_tags) # Pred should be size (batch_size, )
        correct_preds = pred == labels.to(device)
        # `correct` becomes a 0-dim tensor after the first +=, so `accuracy`
        # below is a tensor as well
        correct += torch.sum(correct_preds)
        total += len(correct_preds)
        pred_tags += pred.tolist()

    logger.info("Finished evaluating on dataset. Computing scores...")
    accuracy = correct / total

    mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, dataset.label_decoder, verbose=verbose)
    if verbose:
        logger.info(f"Accuracy: {accuracy} ({correct}/{total})")
        logger.info(f"Label decoder: {dataset.label_decoder}")

    return mc_results, confusion, accuracy, weighted_f1
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def main(args=None, predefined_args=None):
    """
    Evaluate a saved lemma classifier model on an evaluation file.

    Args:
        args: command line argument list to parse (None means sys.argv)
        predefined_args: an already-parsed argparse.Namespace; when given,
            it takes precedence over `args`

    Returns:
        The multiclass results, confusion matrix, accuracy, and weighted F1
        as returned by evaluate_model.
    """
    # TODO: can unify this script with train_lstm_model.py?
    # TODO: can save the model type in the model .pt, then
    # automatically figure out what type of model we are using by
    # looking in the file
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab")
    parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)")
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--eval_file", type=str, help="path to evaluation file")

    args = parser.parse_args(args) if not predefined_args else predefined_args

    # Bug fix: this message used to say "training script" even though this is
    # the evaluation entry point.
    logger.info("Running evaluation script with the following args:")
    args = vars(args)
    for arg in args:
        logger.info(f"{arg}: {args[arg]}")
    logger.info("------------------------------------------------------------")

    logger.info(f"Attempting evaluation of model from {args['save_name']} on file {args['eval_file']}")
    model = LemmaClassifier.load(args['save_name'], args)

    mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, args['eval_file'])

    logger.info(f"MCC Results: {dict(mcc_results)}")
    logger.info("______________________________________________")
    # use lazy %-formatting; the previous f-string prefix was a no-op
    logger.info("Confusion:\n%s", format_confusion(confusion))
    logger.info("______________________________________________")
    logger.info(f"Accuracy: {acc}")
    logger.info("______________________________________________")
    logger.info(f"Weighted f1: {weighted_f1}")

    return mcc_results, confusion, acc, weighted_f1


if __name__ == "__main__":
    main()
|
stanza/stanza/models/lemma_classifier/prepare_dataset.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
import stanza
|
| 7 |
+
from stanza.models.lemma_classifier import utils
|
| 8 |
+
|
| 9 |
+
from typing import List, Tuple, Any
|
| 10 |
+
|
| 11 |
+
"""
|
| 12 |
+
The code in this file processes a CoNLL dataset by taking its sentences and filtering out all sentences that do not contain the target token.
|
| 13 |
+
Furthermore, it will store tuples of the Stanza document object, the position index of the target token, and its lemma.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_doc_from_conll_file(path: str):
    """
    Loads a Stanza Document from a path to a CoNLL file of annotated sentences.
    """
    # the original docstring opened with four quotes (`""""`) — a typo, fixed here
    return stanza.utils.conll.CoNLL.conll2doc(path)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DataProcessor():
    """
    Filters CoNLL sentences down to those containing a target token with a
    target UPOS tag and records each occurrence with its tokens, UPOS tags,
    position, and lemma.
    """

    def __init__(self, target_word: str, target_upos: List[str], allowed_lemmas: str):
        """
        Args:
            target_word (str): regex matched (fullmatch) against token text, e.g. "'s"
            target_upos (List[str]): UPOS tags the target token must carry
            allowed_lemmas (str): regex of lemmas to keep; ".*" keeps all
        """
        self.target_word = target_word
        self.target_word_regex = re.compile(target_word)
        self.target_upos = target_upos
        self.allowed_lemmas = re.compile(allowed_lemmas)

    def keep_sentence(self, sentence):
        """Return True if any word matches the target regex with a target UPOS."""
        for word in sentence.words:
            if self.target_word_regex.fullmatch(word.text) and word.upos in self.target_upos:
                return True
        return False

    def find_all_occurrences(self, sentence) -> List[int]:
        """
        Finds all occurrences of self.target_word in tokens and returns the index(es) of such occurrences.
        """
        occurrences = []
        for idx, token in enumerate(sentence.words):
            if self.target_word_regex.fullmatch(token.text) and token.upos in self.target_upos:
                occurrences.append(idx)
        return occurrences

    @staticmethod
    def write_output_file(save_name, target_upos, sentences):
        """Write the collected examples as a JSON object with "upos" and "sentences" keys."""
        with open(save_name, "w+", encoding="utf-8") as output_f:
            output_f.write("{\n")
            output_f.write(' "upos": %s,\n' % json.dumps(target_upos))
            output_f.write(' "sentences": [')
            wrote_sentence = False
            for sentence in sentences:
                # commas go before every sentence except the first
                if not wrote_sentence:
                    output_f.write("\n ")
                    wrote_sentence = True
                else:
                    output_f.write(",\n ")
                output_f.write(json.dumps(sentence))
            output_f.write("\n ]\n}\n")

    def process_document(self, doc, save_name: str) -> List[dict]:
        """
        Takes any sentence from `doc` that meets the condition of `keep_sentence` and writes its tokens, index of target word, and lemma to `save_name`

        Sentences that meet `keep_sentence` and contain `self.target_word` multiple times have each instance in a different example in the output file.

        Args:
            doc (Stanza.doc): Document object that represents the file to be analyzed
            save_name (str): Path to the file for storing output; falsy skips writing

        Returns:
            The list of collected example dicts (also written to `save_name` if given).
        """
        sentences = []
        for sentence in doc.sentences:
            # if the sentence fulfills keep_sentence, save each occurrence of the
            # target word along with its index and corresponding lemma
            if self.keep_sentence(sentence):
                tokens = [word.text for word in sentence.words]
                # hoisted out of the per-occurrence loop: the tag list is the
                # same for every occurrence within a sentence
                upos_tags = [word.upos for word in sentence.words]
                for idx in self.find_all_occurrences(sentence):
                    if self.allowed_lemmas.fullmatch(sentence.words[idx].lemma):
                        sentences.append({
                            "words": tokens,
                            "upos_tags": upos_tags,
                            "index": idx,
                            "lemma": sentence.words[idx].lemma
                        })

        if save_name:
            self.write_output_file(save_name, self.target_upos, sentences)
        return sentences
|
| 98 |
+
|
| 99 |
+
def main(args=None):
    """Parse command line options, then extract matching examples from a CoNLL file."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--conll_path", type=str, default=os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu"), help="path to the conll file to translate")
    parser.add_argument("--target_word", type=str, default="'s", help="Token to classify on, e.g. 's.")
    parser.add_argument("--target_upos", type=str, default="AUX", help="upos on target token")
    parser.add_argument("--output_path", type=str, default="test_output.txt", help="Path for output file")
    parser.add_argument("--allowed_lemmas", type=str, default=".*", help="A regex for allowed lemmas. If not set, all lemmas are allowed")

    parsed = parser.parse_args(args)
    options = vars(parsed)

    # echo the configuration before doing any work
    for name, value in options.items():
        print(f"{name}: {value}")

    doc = load_doc_from_conll_file(options['conll_path'])
    processor = DataProcessor(target_word=options['target_word'],
                              target_upos=[options['target_upos']],
                              allowed_lemmas=options['allowed_lemmas'])

    return processor.process_document(doc, options['output_path'])


if __name__ == "__main__":
    main()
|
stanza/stanza/models/lemma_classifier/train_lstm_model.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The code in this file works to train a lemma classifier for 's
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
from stanza.models.common.foundation_cache import load_pretrain
|
| 13 |
+
from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer
|
| 14 |
+
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
|
| 15 |
+
from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger('stanza.lemmaclassifier')
|
| 18 |
+
|
| 19 |
+
class LemmaClassifierTrainer(BaseLemmaClassifierTrainer):
    """
    Class to assist with training a LemmaClassifierLSTM
    """

    def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = False, charlm_forward_file: str = None, charlm_backward_file: str = None, lr: float = 0.001, loss_func: str = None):
        """
        Initializes the LemmaClassifierTrainer class.

        Args:
            model_args (dict): Various model shape parameters
            embedding_file (str): What word embeddings file to use. Use a Stanza pretrain .pt
            use_charlm (bool, optional): Whether to use charlm embeddings as well. Defaults to False.
            charlm_forward_file (str): Path to the forward pass embeddings for the charlm
            charlm_backward_file (str): Path to the backward pass embeddings for the charlm
            lr (float): Learning rate, defaults to 0.001.
            loss_func (str): Which loss function to use (either 'ce' or 'weighted_bce')

        Raises:
            FileNotFoundError: If the forward charlm file is not present
            FileNotFoundError: If the backward charlm file is not present
            ValueError: If loss_func is neither 'ce' nor 'weighted_bce'
        """
        super().__init__()

        self.model_args = model_args

        # Load word embeddings
        pt = load_pretrain(embedding_file)
        self.pt_embedding = pt

        # Load CharLM embeddings: only validate the paths when charlm use is requested
        if use_charlm and charlm_forward_file is not None and not os.path.exists(charlm_forward_file):
            raise FileNotFoundError(f"Could not find forward charlm file: {charlm_forward_file}")
        if use_charlm and charlm_backward_file is not None and not os.path.exists(charlm_backward_file):
            raise FileNotFoundError(f"Could not find backward charlm file: {charlm_backward_file}")

        # TODO: just pass around the args instead
        self.use_charlm = use_charlm
        self.charlm_forward_file = charlm_forward_file
        self.charlm_backward_file = charlm_backward_file
        self.lr = lr

        # Find loss function
        if loss_func == "ce":
            self.criterion = nn.CrossEntropyLoss()
            self.weighted_loss = False
            logger.debug("Using CE loss")
        elif loss_func == "weighted_bce":
            self.criterion = nn.BCEWithLogitsLoss()
            self.weighted_loss = True # used to add weights during train time.
            logger.debug("Using Weighted BCE loss")
        else:
            raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        # Construct the LSTM classifier from the stored configuration.
        # NOTE(review): self.output_dim is presumably set by the base trainer
        # before build_model is called — confirm in BaseLemmaClassifierTrainer.
        return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos,
                                   use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file)
|
| 78 |
+
|
| 79 |
+
def build_argparse():
    """Build the argparse parser for training a LemmaClassifierLSTM."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.")
    parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.')
    parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    # NOTE(review): num_epochs is parsed as float, not int — confirm the base
    # trainer converts it before using it as a loop bound
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file')
    return parser
|
| 99 |
+
|
| 100 |
+
def main(args=None, predefined_args=None):
    """Train a LemmaClassifierLSTM according to the command line options."""
    if predefined_args is None:
        parsed = build_argparse().parse_args(args)
    else:
        parsed = predefined_args
    options = vars(parsed)

    save_name = options['save_name']
    train_file = options['train_file']

    # refuse to clobber an existing model unless --force is given
    if os.path.exists(save_name) and not options.get('force', False):
        raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...")
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.")

    logger.info("Running training script with the following args:")
    for name in options:
        logger.info(f"{name}: {options[name]}")
    logger.info("------------------------------------------------------------")

    loss_func = "weighted_bce" if options['weighted_loss'] else "ce"
    trainer = LemmaClassifierTrainer(model_args=options,
                                     embedding_file=options['wordvec_pretrain_file'],
                                     use_charlm=options['use_charlm'],
                                     charlm_forward_file=options['charlm_forward_file'],
                                     charlm_backward_file=options['charlm_backward_file'],
                                     lr=options['lr'],
                                     loss_func=loss_func,
                                     )

    trainer.train(
        num_epochs=options['num_epochs'], save_name=save_name, args=options, eval_file=options['eval_file'], train_file=train_file
    )

    return trainer


if __name__ == "__main__":
    main()
|
| 147 |
+
|
stanza/stanza/models/lemma_classifier/train_many.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utils for training and evaluating multiple models simultaneously
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
from stanza.models.lemma_classifier.train_lstm_model import main as train_lstm_main
|
| 9 |
+
from stanza.models.lemma_classifier.train_transformer_model import main as train_tfmr_main
|
| 10 |
+
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Hyperparameter sweep values, keyed by the --change_param option consumed by
# train_n_models(). 'full' in training_size means the complete training file.
change_params_map = {
    "lstm_layer": [16, 32, 64, 128, 256, 512],
    "upos_emb_dim": [5, 10, 20, 30],
    "training_size": [150, 300, 450, 600, 'full'],
} # TODO: Add attention
|
| 18 |
+
|
| 19 |
+
def train_n_models(num_models: int, base_path: str, args):
    """
    Train `num_models` LSTM lemma classifiers for the configuration selected
    by args.change_param, saving each run under base_path.

    Args:
        num_models: how many runs to train per configuration value
        base_path: directory where the model save files are written
        args: parsed training arguments; mutated in place (save_name, etc.) per run
    """
    if args.change_param == "lstm_layer":
        for num_layers in change_params_map.get("lstm_layer", None):
            for i in range(num_models):
                new_save_name = os.path.join(base_path, f"{num_layers}_{i}.pt")
                args.save_name = new_save_name
                args.hidden_dim = num_layers
                train_lstm_main(predefined_args=args)

    if args.change_param == "upos_emb_dim":
        # Bug fix: this previously called the dict itself
        # (change_params_map("upos_emb_dim", None)), which raises TypeError;
        # .get() was clearly intended, matching the other branches.
        for upos_dim in change_params_map.get("upos_emb_dim", None):
            for i in range(num_models):
                new_save_name = os.path.join(base_path, f"dim_{upos_dim}_{i}.pt")
                args.save_name = new_save_name
                args.upos_emb_dim = upos_dim
                train_lstm_main(predefined_args=args)

    if args.change_param == "training_size":
        for size in change_params_map.get("training_size", None):
            for i in range(num_models):
                new_save_name = os.path.join(base_path, f"{size}_examples_{i}.pt")
                # NOTE(review): `size` does not influence which training file is
                # used — every run trains on combined_train.txt. Confirm whether
                # per-size training files were intended here.
                new_train_file = os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt")
                args.save_name = new_save_name
                args.train_file = new_train_file
                train_lstm_main(predefined_args=args)

    if args.change_param == "base":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"lstm_model_{i}.pt")
            args.save_name = new_save_name
            args.weighted_loss = False
            train_lstm_main(predefined_args=args)

            # also train a weighted-loss variant for the same run index
            if not args.weighted_loss:
                args.weighted_loss = True
                new_save_name = os.path.join(base_path, f"lstm_model_wloss_{i}.pt")
                args.save_name = new_save_name
                train_lstm_main(predefined_args=args)

    if args.change_param == "base_charlm":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"lstm_charlm_{i}.pt")
            args.save_name = new_save_name
            train_lstm_main(predefined_args=args)

    if args.change_param == "base_charlm_upos":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"lstm_charlm_upos_{i}.pt")
            args.save_name = new_save_name
            train_lstm_main(predefined_args=args)

    if args.change_param == "base_upos":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"lstm_upos_{i}.pt")
            args.save_name = new_save_name
            train_lstm_main(predefined_args=args)

    if args.change_param == "attn_model":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"attn_model_{args.num_heads}_heads_{i}.pt")
            args.save_name = new_save_name
            train_lstm_main(predefined_args=args)
|
| 82 |
+
|
| 83 |
+
def train_n_tfmrs(num_models: int, base_path: str, args):
    """
    Train `num_models` transformer baselines, saving each under `base_path`.

    For each model index, trains two variants of the transformer selected by
    args.change_param ("bert" or "roberta"): one with cross-entropy loss and
    one with weighted BCE loss.  Does nothing unless args.multi_train_type
    is "tfmr".  Mutates args.save_name / args.loss_fn before each run.
    """
    if args.multi_train_type != "tfmr":
        return

    for model_idx in range(num_models):
        if args.change_param == "bert":
            runs = [("ce", f"bert_{model_idx}.pt"),
                    ("weighted_bce", f"bert_wloss_{model_idx}.pt")]
        elif args.change_param == "roberta":
            runs = [("ce", f"roberta_{model_idx}.pt"),
                    ("weighted_bce", f"roberta_wloss_{model_idx}.pt")]
        else:
            continue

        for loss_fn, filename in runs:
            args.save_name = os.path.join(base_path, filename)
            args.loss_fn = loss_fn
            train_tfmr_main(predefined_args=args)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def main():
    """
    Command-line entry point for training batches of lemma classifier models.

    Builds one argument parser covering both the LSTM and the transformer
    trainers, then dispatches to train_n_models (--multi_train_type lstm)
    or train_n_tfmrs (--multi_train_type tfmr).
    """
    parser = argparse.ArgumentParser()
    # LSTM model hyperparameters and embedding files
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.")
    parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.')
    parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    # NOTE(review): type=float for an epoch count looks unintended — confirm the
    # training loop handles a non-integer num_epochs before changing it to int
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    # Tfmr-specific args
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')")
    # Multi-model train args
    parser.add_argument("--multi_train_type", type=str, default="lstm", help="Whether you are attempting to multi-train an LSTM or transformer")
    parser.add_argument("--multi_train_count", type=int, default=5, help="Number of each model to build")
    parser.add_argument("--base_path", type=str, default=None, help="Path to start generating model type for.")
    parser.add_argument("--change_param", type=str, default=None, help="Which hyperparameter to change when training")


    args = parser.parse_args()

    # Dispatch to the right multi-train helper; the full args namespace is
    # forwarded so each helper can mutate save_name / loss_fn per run.
    if args.multi_train_type == "lstm":
        train_n_models(num_models=args.multi_train_count,
                       base_path=args.base_path,
                       args=args)
    elif args.multi_train_type == "tfmr":
        train_n_tfmrs(num_models=args.multi_train_count,
                      base_path=args.base_path,
                      args=args)
    else:
        raise ValueError(f"Improper input {args.multi_train_type}")
|
| 153 |
+
|
| 154 |
+
if __name__ == "__main__":
|
| 155 |
+
main()
|
stanza/stanza/models/lemma_classifier/train_transformer_model.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file contains code used to train a baseline transformer model to classify on a lemma of a particular token.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
|
| 14 |
+
from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer
|
| 15 |
+
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
|
| 16 |
+
from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer
|
| 17 |
+
from stanza.models.common.utils import default_device
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger('stanza.lemmaclassifier')
|
| 20 |
+
|
| 21 |
+
class TransformerBaselineTrainer(BaseLemmaClassifierTrainer):
    """
    Class to assist with training a baseline transformer model to classify on token lemmas.

    To find the model spec, refer to `model.py` in this directory.
    """

    def __init__(self, model_args: dict, transformer_name: str = "roberta", loss_func: str = "ce", lr: float = 0.001):
        """
        Creates the Trainer object

        Args:
            model_args (dict): arguments forwarded to the underlying model
            transformer_name (str, optional): What kind of transformer to use for embeddings. Defaults to "roberta".
            loss_func (str, optional): Which loss function to use (either 'ce' or 'weighted_bce'). Defaults to "ce".
            lr (float, optional): learning rate for the optimizer. Defaults to 0.001.

        Raises:
            ValueError: if loss_func is not a supported loss name
        """
        super().__init__()

        self.model_args = model_args

        # Find loss function
        if loss_func == "ce":
            self.criterion = nn.CrossEntropyLoss()
            self.weighted_loss = False
        elif loss_func == "weighted_bce":
            self.criterion = nn.BCEWithLogitsLoss()
            self.weighted_loss = True  # used to add weights during train time.
        else:
            raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

        self.transformer_name = transformer_name
        self.lr = lr

    def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> optim.Optimizer:
        """
        Sets learning rates for each layer of the model.
        Currently, the model has the transformer layer and the MLP layer, so these are tweakable.

        Returns (optim.Optimizer): An Adam optimizer with the learning rates adjusted per layer.

        Currently unused - could be refactored into the parent class's train method,
        or the parent class could call a build_optimizer and this subclass would use the optimizer

        NOTE(review): parameters whose names contain neither 'transformer' nor
        'mlp' are silently excluded from the optimizer; acceptable while the
        model only has those two components, but fragile if it grows.
        """
        transformer_params, mlp_params = [], []
        for name, param in self.model.named_parameters():
            if 'transformer' in name:
                transformer_params.append(param)
            elif 'mlp' in name:
                mlp_params.append(param)
        optimizer = optim.Adam([
            {"params": transformer_params, "lr": transformer_lr},
            {"params": mlp_params, "lr": mlp_lr}
        ])
        return optimizer

    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        # upos_to_id and known_words are part of the base-class interface but
        # unused by the transformer model; self.output_dim is expected to be
        # set by the base trainer before this is called
        return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder, target_words=target_words, target_upos=target_upos)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main(args=None, predefined_args=None):
    """
    Train a single transformer baseline lemma classifier from the command line.

    Args:
        args: optional list of CLI strings to parse instead of sys.argv
        predefined_args: an already-parsed namespace; when given, CLI parsing
            is skipped entirely (used by the multi-train driver)

    Returns:
        TransformerBaselineTrainer: the trainer after training completes

    Raises:
        ValueError: for an unknown --model_type, or model_type 'transformer'
            without an explicit --bert_model
        FileExistsError: if the save file exists and --force is not given
        FileNotFoundError: if the training file does not exist
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "saved_models", "big_model_roberta_weighted_loss.pt"), help="Path to model save file")
    # NOTE(review): type=float for an epoch count looks unintended — confirm
    # the training loop handles a non-integer value
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the optimizer.")
    parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file')

    args = parser.parse_args(args) if predefined_args is None else predefined_args

    # Capture the scalar settings before `args` is rebound to a plain dict below
    save_name = args.save_name
    num_epochs = args.num_epochs
    train_file = args.train_file
    loss_fn = args.loss_fn
    eval_file = args.eval_file
    lr = args.lr

    args = vars(args)

    # Resolve the shorthand model type to a concrete HF checkpoint name
    if args['model_type'] == 'bert':
        args['bert_model'] = 'bert-base-uncased'
    elif args['model_type'] == 'roberta':
        args['bert_model'] = 'roberta-base'
    elif args['model_type'] == 'transformer':
        if args['bert_model'] is None:
            raise ValueError("Need to specify a bert_model for model_type transformer!")
    else:
        raise ValueError("Unknown model type " + args['model_type'])

    # Refuse to clobber an existing model unless --force was passed
    if os.path.exists(save_name) and not args.get('force', False):
        raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...")
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.")

    logger.info("Running training script with the following args:")
    for arg in args:
        logger.info(f"{arg}: {args[arg]}")
    logger.info("------------------------------------------------------------")

    trainer = TransformerBaselineTrainer(model_args=args, transformer_name=args['bert_model'], loss_func=loss_fn, lr=lr)

    trainer.train(num_epochs=num_epochs, save_name=save_name, train_file=train_file, args=args, eval_file=eval_file)
    return trainer
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
main()
|
stanza/stanza/models/lemma_classifier/transformer_model.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from transformers import AutoTokenizer, AutoModel
|
| 8 |
+
from typing import Mapping, List, Tuple, Any
|
| 9 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence
|
| 10 |
+
from stanza.models.common.bert_embedding import extract_bert_embeddings
|
| 11 |
+
from stanza.models.lemma_classifier.base_model import LemmaClassifier
|
| 12 |
+
from stanza.models.lemma_classifier.constants import ModelType
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger('stanza.lemmaclassifier')
|
| 15 |
+
|
| 16 |
+
class LemmaClassifierWithTransformer(LemmaClassifier):
    def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping, target_words: set, target_upos: set):
        """
        Model architecture:

        Use a transformer (BERT or RoBERTa) to extract contextual embedding over a sentence.
        Get the embedding for the word that is to be classified on, and feed the embedding
        as input to an MLP classifier that has 2 linear layers, and a prediction head.

        Args:
            model_args (dict): args for the model
            output_dim (int): Dimension of the output from the MLP
            transformer_name (str): name of the HF transformer to use
            label_decoder (dict): a map of the labels available to the model
            target_words (set(str)): a set of the words which might need lemmatization
            target_upos (set(str)): UPOS tags of the target words, stored by the base class
        """
        super(LemmaClassifierWithTransformer, self).__init__(label_decoder, target_words, target_upos)
        self.model_args = model_args

        # Choose transformer
        self.transformer_name = transformer_name
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True)
        # registered as an unsaved module: the (large) pretrained transformer
        # weights are excluded from this model's checkpoint (see get_save_dict)
        self.add_unsaved_module("transformer", AutoModel.from_pretrained(transformer_name))
        config = self.transformer.config

        embedding_size = config.hidden_size

        # define an MLP layer
        self.mlp = nn.Sequential(
            nn.Linear(embedding_size, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def get_save_dict(self):
        """Build a checkpoint dict, dropping parameters of unsaved modules (the transformer)."""
        save_dict = {
            "params": self.state_dict(),
            "label_decoder": self.label_decoder,
            "target_words": list(self.target_words),
            "target_upos": list(self.target_upos),
            "model_type": self.model_type().name,
            "args": self.model_args,
        }
        # remove the pretrained transformer weights so the save file stays small
        skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)]
        for k in skipped:
            del save_dict["params"][k]
        return save_dict

    def convert_tags(self, upos_tags: List[List[str]]):
        # this model ignores UPOS tags; present only for interface compatibility
        return None

    def forward(self, idx_positions: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):
        """
        Computes the forward pass of the transformer baselines

        Args:
            idx_positions (List[int]): A list of the position index of the target token for lemmatization classification in each sentence.
            sentences (List[List[str]]): A list of the token-split sentences of the input data.
            upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence - not used in this model, here for compatibility

        Returns:
            torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences.
        """
        device = next(self.transformer.parameters()).device
        bert_embeddings = extract_bert_embeddings(self.transformer_name, self.tokenizer, self.transformer, sentences, device,
                                                  keep_endpoints=False, num_layers=1, detach=True)
        # pick the embedding of the target token in each sentence
        embeddings = [emb[idx] for idx, emb in zip(idx_positions, bert_embeddings)]
        # NOTE(review): the [:, :, 0] assumes each per-token embedding carries a
        # trailing num_layers axis of size 1 (num_layers=1 above) — confirm
        # against extract_bert_embeddings' output layout
        embeddings = torch.stack(embeddings, dim=0)[:, :, 0]
        # pass to the MLP
        output = self.mlp(embeddings)
        return output

    def model_type(self):
        """Identify this architecture in saved checkpoints."""
        return ModelType.TRANSFORMER
|
stanza/stanza/models/lemma_classifier/utils.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
from typing import List, Tuple, Any, Mapping
|
| 7 |
+
|
| 8 |
+
import stanza
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger('stanza.lemmaclassifier')
|
| 14 |
+
|
| 15 |
+
class Dataset:
    """
    In-memory dataset of lemma-classification examples with batched iteration.

    Each example is a tokenized sentence, the index of the target token, the
    sentence's UPOS tag IDs, and the integer label for the target token's lemma.
    """

    def __init__(self, data_path: str, batch_size: int = DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True):
        """
        Loads a JSON data file and indexes it for batched iteration.

        Args:
            data_path (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.
            batch_size (int): Size of each batch of examples
            get_counts (optional, bool): Whether there should be a map of the label index to counts
            label_decoder (optional, dict): Existing label-to-ID map to extend; it is copied, not mutated.
            shuffle (bool): Whether to reshuffle example order on each iteration pass.

        Attributes set:
            sentences (List[List[str]]): token-split sentences
            indices (List[int]): target token index per sentence
            upos_ids (List[List[int]]): UPOS tag IDs per sentence
            labels (List[int]): label ID per sentence
            counts (Counter): label ID -> frequency (populated only if get_counts)
            label_decoder (Mapping[str, int]): label -> ID
            upos_to_id (Mapping[str, int]): UPOS tag -> ID
            known_words (List[str]): all lowercased words seen, sorted
            target_words (set[str]): lowercased words that are classification targets
            target_upos: the 'upos' entry of the input file

        Raises:
            FileNotFoundError: if data_path is missing or None
            ValueError: if any sentence record lacks a required field
        """
        if data_path is None or not os.path.exists(data_path):
            raise FileNotFoundError(f"Data file {data_path} could not be found.")

        if label_decoder is None:
            label_decoder = {}
        else:
            # if labels in the test set aren't in the original model,
            # the model will never predict those labels,
            # but we can still use those labels in a confusion matrix.
            # Copy so the caller's mapping is not mutated.
            label_decoder = dict(label_decoder)

        logger.debug("Final label decoder: %s Should be strings to ints", label_decoder)

        # words which we are analyzing
        target_words = set()

        # all known words in the dataset, not just target words
        known_words = set()

        # "r" rather than "r+": the file is only read, and read-write mode
        # fails on read-only data directories
        with open(data_path, "r", encoding="utf-8") as fin:
            sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), {}

            input_json = json.load(fin)
            sentences_data = input_json['sentences']
            self.target_upos = input_json['upos']

            for idx, sentence in enumerate(sentences_data):
                # TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatability reasons
                words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
                if None in [words, target_idx, upos_tags, label]:
                    raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")

                if label not in label_decoder:
                    label_decoder[label] = len(label_decoder)  # create a new ID for the unknown label

                converted_upos_tags = []  # convert upos tags to upos IDs
                for upos_tag in upos_tags:
                    if upos_tag not in upos_to_id:
                        upos_to_id[upos_tag] = len(upos_to_id)  # create a new ID for the unknown UPOS tag
                    converted_upos_tags.append(upos_to_id[upos_tag])

                sentences.append(words)
                indices.append(target_idx)
                upos_ids.append(converted_upos_tags)
                labels.append(label_decoder[label])

                if get_counts:
                    counts[label_decoder[label]] += 1

                target_words.add(words[target_idx])
                known_words.update(words)

        self.sentences = sentences
        self.indices = indices
        self.upos_ids = upos_ids
        self.labels = labels

        self.counts = counts
        self.label_decoder = label_decoder
        self.upos_to_id = upos_to_id

        self.batch_size = batch_size
        self.shuffle = shuffle

        self.known_words = [x.lower() for x in sorted(known_words)]
        self.target_words = set(x.lower() for x in target_words)

    def __len__(self):
        """
        Number of batches, rounded up to nearest batch
        """
        return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0)

    def __iter__(self):
        """
        Yield (sentences, index tensor, upos id lists, label tensor) batches.

        upos_ids stay as ragged Python lists; callers pad them as needed.
        """
        num_sentences = len(self.sentences)
        indices = list(range(num_sentences))
        if self.shuffle:
            random.shuffle(indices)
        for i in range(self.__len__()):
            batch_start = self.batch_size * i
            batch_end = min(batch_start + self.batch_size, num_sentences)

            batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]]
            batch_indices = torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]])
            batch_upos_ids = [self.upos_ids[x] for x in indices[batch_start:batch_end]]
            batch_labels = torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]])
            yield batch_sentences, batch_indices, batch_upos_ids, batch_labels
|
| 123 |
+
|
| 124 |
+
def extract_unknown_token_indices(tokenized_indices: torch.tensor, unknown_token_idx: int) -> List[int]:
    """
    Extracts the indices within `tokenized_indices` which match `unknown_token_idx`

    Args:
        tokenized_indices (torch.tensor): A tensor filled with tokenized indices of words that have been mapped to vector indices.
        unknown_token_idx (int): The special index for which unknown tokens are marked in the word vectors.

    Returns:
        List[int]: A list of indices in `tokenized_indices` which match `unknown_token_index`
    """
    matching_positions = []
    for position, token_value in enumerate(tokenized_indices):
        if token_value == unknown_token_idx:
            matching_positions.append(position)
    return matching_positions
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_device():
    """
    Get the device to run computations on.

    Prefers CUDA, then Apple MPS, falling back to CPU.

    Bug fixed here: the original tested `torch.cuda.is_available` without
    calling it (always truthy), and the following independent if/else always
    overwrote the result, so "cuda" could never actually be returned.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def round_up_to_multiple(number, multiple):
    """
    Round `number` up to the nearest multiple of `multiple`.

    Args:
        number (int): value to round
        multiple (int): the step to round up to; must be nonzero

    Returns:
        int: the smallest multiple of `multiple` that is >= `number`

    Raises:
        ValueError: if multiple is zero (the original returned an error
        *string* here, which silently corrupted arithmetic at call sites)
    """
    if multiple == 0:
        raise ValueError("The second number (multiple) cannot be zero.")

    # Calculate the remainder when dividing the number by the multiple
    remainder = number % multiple

    # If remainder is non-zero, round up to the next multiple
    if remainder != 0:
        return number + (multiple - remainder)
    return number  # No rounding needed
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def main():
    """
    Smoke-test: load the default dev set and its label counts.

    Bug fixed here: the original called `load_dataset(...)`, a function that
    no longer exists in this module (replaced by the Dataset class), so this
    entry point raised NameError at runtime.
    """
    default_test_path = os.path.join(os.path.dirname(__file__), "test_sets", "processed_ud_en", "combined_dev.txt") # get the GUM stuff
    dataset = Dataset(default_test_path, get_counts=True)
|
| 171 |
+
|
| 172 |
+
if __name__ == "__main__":
|
| 173 |
+
main()
|
stanza/stanza/models/mwt/character_classifier.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Classify characters based on an LSTM with learned character representations
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger('stanza')
|
| 13 |
+
|
| 14 |
+
class CharacterClassifier(nn.Module):
    """
    Per-character binary classifier: learned character embeddings fed through
    a bidirectional LSTM encoder, then a 2-layer MLP producing 2 logits per
    character position.
    """
    def __init__(self, args):
        super().__init__()

        # args: a dict with 'vocab_size', 'emb_dim', 'hidden_dim',
        # 'num_layers', 'dropout', and optionally 'emb_dropout'
        self.vocab_size = args['vocab_size']
        self.emb_dim = args['emb_dim']
        self.hidden_dim = args['hidden_dim']
        self.nlayers = args['num_layers'] # lstm encoder layers
        self.pad_token = constant.PAD_ID
        self.enc_hidden_dim = self.hidden_dim // 2 # since it is bidirectional

        # two output classes per character position
        self.num_outputs = 2

        self.args = args

        self.emb_dropout = args.get('emb_dropout', 0.0)
        self.emb_drop = nn.Dropout(self.emb_dropout)
        self.dropout = args['dropout']

        # pad_token as padding_idx: the PAD embedding stays zero
        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
        self.input_dim = self.emb_dim
        # inter-layer dropout only applies when the LSTM is stacked
        self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \
                bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)

        self.output_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.num_outputs))

    def encode(self, enc_inputs, lens):
        """ Encode source sequence. """
        # NOTE(review): pack_padded_sequence defaults to enforce_sorted=True,
        # so this assumes batches arrive sorted by decreasing length — confirm
        # against the data loader
        packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
        packed_h_in, (hn, cn) = self.encoder(packed_inputs)
        return packed_h_in

    def embed(self, src, src_mask):
        # the input data could have characters outside the known range
        # of characters in cases where the vocabulary was temporarily
        # expanded (note that this model does nothing with those chars)
        embed_src = src.clone()
        embed_src[embed_src >= self.vocab_size] = constant.UNK_ID
        enc_inputs = self.emb_drop(self.embedding(embed_src))
        batch_size = enc_inputs.size(0)
        # NOTE(review): lengths are derived from positions where the mask
        # equals PAD_ID — assumes the mask convention marks real tokens with
        # PAD_ID-valued entries; confirm against the batching code
        src_lens = list(src_mask.data.eq(self.pad_token).long().sum(1))
        return enc_inputs, batch_size, src_lens, src_mask

    def forward(self, src, src_mask):
        """Return per-character logits of shape (batch, seq_len, 2)."""
        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask)
        encoded = self.encode(enc_inputs, src_lens)
        encoded, _ = nn.utils.rnn.pad_packed_sequence(encoded, batch_first=True)
        logits = self.output_layer(encoded)
        return logits
|
stanza/stanza/models/mwt/trainer.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A trainer class to handle training and testing of models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import numpy as np
|
| 7 |
+
from collections import Counter
|
| 8 |
+
import logging
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
import torch.nn.init as init
|
| 12 |
+
|
| 13 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 14 |
+
from stanza.models.common.trainer import Trainer as BaseTrainer
|
| 15 |
+
from stanza.models.common.seq2seq_model import Seq2SeqModel
|
| 16 |
+
from stanza.models.common import utils, loss
|
| 17 |
+
from stanza.models.mwt.character_classifier import CharacterClassifier
|
| 18 |
+
from stanza.models.mwt.vocab import Vocab
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger('stanza')
|
| 21 |
+
|
| 22 |
+
def unpack_batch(batch, device):
    """ Unpack a batch from the data loader.

    Returns (model_inputs, original_text, original_order): the first four
    entries moved to ``device`` (None entries preserved), followed by the
    raw text and the pre-sorting index order.
    """
    model_inputs = []
    for tensor in batch[:4]:
        model_inputs.append(tensor.to(device) if tensor is not None else None)
    return model_inputs, batch[4], batch[5]
|
| 28 |
+
|
| 29 |
+
class Trainer(BaseTrainer):
    """A trainer for MWT (multi-word token) expansion models.

    Wraps one of three expansion strategies:
      - dictionary only (``args['dict_only']``): no neural model, expansions
        come from the most-frequent mapping seen in training
      - character classifier (``args['force_exact_pieces']``): predicts a
        split/no-split decision per character of the original token
      - seq2seq: generates the expansion character by character

    The dictionary is always trained and used to override/ensemble with the
    neural predictions.
    """
    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None):
        if model_file is not None:
            # load from file; this restores args, vocab, model, and expansion dict
            self.load(model_file)
        else:
            self.args = args
            if args['dict_only']:
                # dictionary-based expansion only: no neural model at all
                self.model = None
            elif args.get('force_exact_pieces', False):
                # per-character split classifier instead of seq2seq generation
                self.model = CharacterClassifier(args)
            else:
                self.model = Seq2SeqModel(args, emb_matrix=emb_matrix)
            self.vocab = vocab
            self.expansion_dict = dict()
        if not self.args['dict_only']:
            self.model = self.model.to(device)
            if self.args.get('force_exact_pieces', False):
                # plain cross-entropy over the two split/no-split classes
                self.crit = nn.CrossEntropyLoss()
            else:
                # sequence loss over the character vocabulary
                self.crit = loss.SequenceLoss(self.vocab.size).to(device)
            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'])

    def update(self, batch, eval=False):
        """Run one training (or eval) step on a batch and return the loss value.

        When ``eval`` is True the loss is computed under eval mode and no
        parameter update is performed.  (NOTE: the parameter name shadows the
        ``eval`` builtin; kept for interface compatibility.)
        """
        device = next(self.model.parameters()).device
        # ignore the original text when training
        # can try to learn the correct values, even if we eventually
        # copy directly from the original text
        inputs, _, orig_idx = unpack_batch(batch, device)
        src, src_mask, tgt_in, tgt_out = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer.zero_grad()
        if self.args.get('force_exact_pieces', False):
            log_probs = self.model(src, src_mask)
            # sequence lengths derived from the padding mask
            # NOTE(review): eq(constant.PAD_ID) on a mask tensor presumably counts
            # the non-padded positions — confirm mask encoding against the data loader
            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
            # pack both predictions and targets so padding is excluded from the loss
            packed_output = nn.utils.rnn.pack_padded_sequence(log_probs, src_lens, batch_first=True)
            packed_tgt = nn.utils.rnn.pack_padded_sequence(tgt_in, src_lens, batch_first=True)
            loss = self.crit(packed_output.data, packed_tgt.data)
        else:
            log_probs, _ = self.model(src, src_mask, tgt_in)
            loss = self.crit(log_probs.view(-1, self.vocab.size), tgt_out.view(-1))
        loss_val = loss.data.item()
        if eval:
            return loss_val

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()
        return loss_val

    def predict(self, batch, unsort=True, never_decode_unk=False, vocab=None):
        """Predict expansions for a batch; returns one string per input token.

        ``unsort`` restores the original (pre-sorting) order of the batch.
        ``vocab`` may override the trainer's vocab (used e.g. when ensembling).
        """
        if vocab is None:
            vocab = self.vocab

        device = next(self.model.parameters()).device
        inputs, orig_text, orig_idx = unpack_batch(batch, device)
        src, src_mask, tgt, tgt_mask = inputs

        self.model.eval()
        batch_size = src.size(0)
        if self.args.get('force_exact_pieces', False):
            log_probs = self.model(src, src_mask)
            # a "cut" at position i means a space is inserted before character i
            cuts = log_probs[:, :, 1] > log_probs[:, :, 0]
            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
            pred_tokens = []
            for src_ids, cut, src_len in zip(src, cuts, src_lens):
                src_chars = vocab.unmap(src_ids)
                pred_seq = []
                # range skips position 0 and the last position
                # (presumably SOS/EOS markers — confirm against the data prep)
                for char_idx in range(1, src_len-1):
                    if cut[char_idx]:
                        pred_seq.append(' ')
                    pred_seq.append(src_chars[char_idx])
                pred_seq = "".join(pred_seq).strip()
                pred_tokens.append(pred_seq)
        else:
            preds, _ = self.model.predict(src, src_mask, self.args['beam_size'], never_decode_unk=never_decode_unk)
            pred_seqs = [vocab.unmap(ids) for ids in preds] # unmap to tokens
            pred_seqs = utils.prune_decoded_seqs(pred_seqs)

            pred_tokens = ["".join(seq) for seq in pred_seqs] # join chars to be tokens
            # if any tokens are predicted to expand to blank,
            # that is likely an error. use the original text
            # this originally came up with the Spanish model turning 's' into a blank
            # furthermore, if there are no spaces predicted by the seq2seq,
            # might as well use the original in case the seq2seq went crazy
            # this particular error came up training a Hebrew MWT
            pred_tokens = [x if x and ' ' in x else y for x, y in zip(pred_tokens, orig_text)]
        if unsort:
            pred_tokens = utils.unsort(pred_tokens, orig_idx)
        return pred_tokens

    def train_dict(self, pairs):
        """ Train a MWT expander given training (word, expansion) pairs.

        For each surface word, keeps the single most frequent expansion that
        differs from the word itself.
        """
        # accumulate counter
        ctr = Counter()
        ctr.update([(p[0], p[1]) for p in pairs])
        seen = set()
        # find the most frequent mappings; most_common() ordering means the
        # first expansion stored for a word is its most frequent one
        for p, _ in ctr.most_common():
            w, l = p
            if w not in seen and w != l:
                self.expansion_dict[w] = l
                seen.add(w)
        return

    def dict_expansion(self, word):
        """
        Check the expansion dictionary for the word along with a couple common lowercasings of the word

        (Leadingcase and UPPERCASE)

        Returns the expansion string, recased to match the input, or None if
        no dictionary entry applies.
        """
        expansion = self.expansion_dict.get(word)
        if expansion is not None:
            return expansion

        # ALL-CAPS word: look up the lowercase form and re-uppercase the result
        if word.isupper():
            expansion = self.expansion_dict.get(word.lower())
            if expansion is not None:
                return expansion.upper()

        # Titlecase word: look up lowercase and re-capitalize the first char
        if word[0].isupper() and word[1:].islower():
            expansion = self.expansion_dict.get(word.lower())
            if expansion is not None:
                return expansion[0].upper() + expansion[1:]

        # could build a truecasing model of some kind to handle cRaZyCaSe...
        # but that's probably too much effort
        return None

    def predict_dict(self, words):
        """ Predict a list of expansions given words.

        Words with no dictionary entry are returned unchanged.
        """
        expansions = []
        for w in words:
            expansion = self.dict_expansion(w)
            if expansion is not None:
                expansions.append(expansion)
            else:
                expansions.append(w)
        return expansions

    def ensemble(self, cands, other_preds):
        """ Ensemble the dict with statistical model predictions.

        The dictionary expansion wins whenever one exists; otherwise the
        model's prediction is used.
        """
        expansions = []
        assert len(cands) == len(other_preds)
        for c, pred in zip(cands, other_preds):
            expansion = self.dict_expansion(c)
            if expansion is not None:
                expansions.append(expansion)
            else:
                expansions.append(pred)
        return expansions

    def save(self, filename):
        """Serialize model weights, expansion dict, vocab, and config to filename.

        Failures are logged and swallowed so a failed save does not abort training.
        """
        params = {
            'model': self.model.state_dict() if self.model is not None else None,
            'dict': self.expansion_dict,
            'vocab': self.vocab.state_dict(),
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except BaseException:
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename):
        """Restore trainer state from a checkpoint written by save().

        Raises whatever torch.load raises if the file cannot be read.
        """
        try:
            # map_location lambda keeps tensors on CPU regardless of where they were saved
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        self.expansion_dict = checkpoint['dict']
        if not self.args['dict_only']:
            if self.args.get('force_exact_pieces', False):
                self.model = CharacterClassifier(self.args)
            else:
                self.model = Seq2SeqModel(self.args)
            # could remove strict=False after rebuilding all models,
            # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False
            self.model.load_state_dict(checkpoint['model'], strict=False)
        else:
            self.model = None
        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
|
| 218 |
+
|
stanza/stanza/models/mwt/vocab.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
|
| 3 |
+
from stanza.models.common.vocab import BaseVocab
|
| 4 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 5 |
+
|
| 6 |
+
class Vocab(BaseVocab):
    """Character vocabulary for the MWT expander, built from (src, tgt) pairs."""

    def build_vocab(self):
        """Index every character seen in the training pairs, most frequent first."""
        char_counts = Counter()
        for src, tgt in self.data:
            char_counts.update(src)
            char_counts.update(tgt)
        # stable sort keeps first-seen order among equally frequent characters
        by_frequency = sorted(char_counts.keys(), key=lambda ch: char_counts[ch], reverse=True)
        self._id2unit = constant.VOCAB_PREFIX + by_frequency
        self._unit2id = {ch: idx for idx, ch in enumerate(self._id2unit)}

    def add_unit(self, unit):
        """Append a previously unseen unit to the vocabulary; no-op if present."""
        if unit not in self._unit2id:
            self._unit2id[unit] = len(self._id2unit)
            self._id2unit.append(unit)
|
stanza/stanza/models/ner/vocab.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter, OrderedDict
|
| 2 |
+
|
| 3 |
+
from stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab, CompositeVocab
|
| 4 |
+
from stanza.models.common.vocab import VOCAB_PREFIX
|
| 5 |
+
from stanza.models.common.pretrain import PretrainedWordVocab
|
| 6 |
+
from stanza.models.pos.vocab import WordVocab
|
| 7 |
+
|
| 8 |
+
class TagVocab(BaseVocab):
    """ A vocab for the output tag sequence. """

    def build_vocab(self):
        """Count tag occurrences and index tags from most to least common."""
        tag_counter = Counter(word[self.idx] for sent in self.data for word in sent)
        # stable sort: ties keep first-occurrence order
        ordered = sorted(tag_counter, key=tag_counter.get, reverse=True)
        self._id2unit = VOCAB_PREFIX + ordered
        self._unit2id = {tag: i for i, tag in enumerate(self._id2unit)}
| 15 |
+
|
| 16 |
+
def convert_tag_vocab(state_dict):
    """Rebuild a legacy TagVocab state dict as a CompositeVocab with identical contents.

    Raises AssertionError if the state dict had 'lower' set or if the rebuilt
    vocab does not exactly reproduce the original id ordering.
    """
    if state_dict['lower']:
        raise AssertionError("Did not expect an NER vocab with 'lower' set to True")
    tags = state_dict['_id2unit'][len(VOCAB_PREFIX):]
    # this looks silly, but the vocab builder treats this as words with multiple fields
    # (we set it to look for field 0 with idx=0)
    # and then the label field is expected to be a list or tuple of items
    wrapped = [[[[tag]]] for tag in tags]
    vocab = CompositeVocab(data=wrapped, lang=state_dict['lang'], idx=0, sep=None)
    rebuilt = vocab._id2unit[0]
    if len(rebuilt) != len(state_dict['_id2unit']):
        raise AssertionError("Failed to construct a new vocab of the same length as the original")
    if rebuilt != state_dict['_id2unit']:
        raise AssertionError("Failed to construct a new vocab in the same order as the original")
    return vocab
|
| 30 |
+
|
| 31 |
+
class MultiVocab(BaseMultiVocab):
    def state_dict(self):
        """ Also save a vocab name to class name mapping in state dict. """
        state = OrderedDict()
        key2class = OrderedDict()
        for name, vocab in self._vocabs.items():
            state[name] = vocab.state_dict()
            key2class[name] = type(vocab).__name__
        state['_key2class'] = key2class
        return state

    @classmethod
    def load_state_dict(cls, state_dict):
        """Restore a MultiVocab, dispatching each entry to its recorded class's loader.

        Legacy TagVocab entries are converted to CompositeVocab on the fly.
        NOTE: this pops '_key2class' from the caller's dict, as the original did.
        """
        loaders = {'CharVocab': CharVocab.load_state_dict,
                   'PretrainedWordVocab': PretrainedWordVocab.load_state_dict,
                   'TagVocab': convert_tag_vocab,
                   'CompositeVocab': CompositeVocab.load_state_dict,
                   'WordVocab': WordVocab.load_state_dict}
        new = cls()
        assert '_key2class' in state_dict, "Cannot find class name mapping in state dict!"
        key2class = state_dict.pop('_key2class')
        for name, vocab_state in state_dict.items():
            new[name] = loaders[key2class[name]](vocab_state)
        return new
|
| 56 |
+
|
stanza/stanza/models/pos/__init__.py
ADDED
|
File without changes
|
stanza/stanza/models/pos/build_xpos_vocab_factory.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import sys
|
| 7 |
+
from zipfile import ZipFile
|
| 8 |
+
|
| 9 |
+
from stanza.models.common.constant import treebank_to_short_name
|
| 10 |
+
from stanza.models.pos.xpos_vocab_utils import DEFAULT_KEY, choose_simplest_factory, XPOSType
|
| 11 |
+
from stanza.models.common.doc import *
|
| 12 |
+
from stanza.utils.conll import CoNLL
|
| 13 |
+
from stanza.utils import default_paths
|
| 14 |
+
|
| 15 |
+
SHORTNAME_RE = re.compile("[a-z-]+_[a-z0-9]+")
|
| 16 |
+
DATA_DIR = default_paths.get_default_paths()['POS_DATA_DIR']
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger('stanza')
|
| 19 |
+
|
| 20 |
+
def get_xpos_factory(shorthand, fn):
    """Choose the simplest XPOS vocab description for one treebank.

    Reads the treebank's training data — either a plain .conllu file or the
    first xpos-bearing file inside a .train.in.zip — and delegates to
    choose_simplest_factory to pick the cheapest vocab representation.

    Parameters:
      shorthand: dataset short name used to locate the training files
      fn: full treebank name, used only in the error message

    Raises:
      FileNotFoundError: if no training data exists for the treebank
      ValueError: if a zip exists but none of its files contain xpos tags
    """
    logger.info('Resolving vocab option for {}...'.format(shorthand))
    doc = None
    train_file = os.path.join(DATA_DIR, '{}.train.in.conllu'.format(shorthand))
    if os.path.exists(train_file):
        doc = CoNLL.conll2doc(input_file=train_file)
    else:
        zip_file = os.path.join(DATA_DIR, '{}.train.in.zip'.format(shorthand))
        if os.path.exists(zip_file):
            with ZipFile(zip_file) as zin:
                # keep the first file that actually has xpos annotations
                for train_file in zin.namelist():
                    doc = CoNLL.conll2doc(input_file=train_file, zip_file=zip_file)
                    if any(word.xpos for sentence in doc.sentences for word in sentence.words):
                        break
                else:
                    raise ValueError('Found training data in {}, but none of the files contained had xpos'.format(zip_file))

    if doc is None:
        # without the training file, there's not much we can do
        # (removed unreachable fallback "key = DEFAULT_KEY; return key" that
        # followed this raise — dead code from before the raise was added)
        raise FileNotFoundError('Training data for {} not found. To generate the XPOS vocabulary '
                                'for this treebank properly, please run the following command first:\n'
                                ' python3 stanza/utils/datasets/prepare_pos_treebank.py {}'.format(fn, fn))

    data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)
    return choose_simplest_factory(data, shorthand)
|
| 47 |
+
|
| 48 |
+
def main():
    """Regenerate stanza/models/pos/xpos_vocab_factory.py from the prepared treebanks.

    For every treebank found in --treebanks (a directory of prepared datasets,
    or a file listing treebank names), determines the simplest XPOS vocab
    description and writes a generated Python module mapping shorthand ->
    XPOSDescription, plus the xpos_vocab_factory entry point.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--treebanks', type=str, default=DATA_DIR, help="Treebanks to process - directory with processed datasets or a file with a list")
    parser.add_argument('--output_file', type=str, default="stanza/models/pos/xpos_vocab_factory.py", help="Where to write the results")
    args = parser.parse_args()

    output_file = args.output_file
    if os.path.isdir(args.treebanks):
        # if the path is a directory of datasets (which is the default if --treebanks is not set)
        # we use those datasets to prepare the xpos factories
        treebanks = os.listdir(args.treebanks)
        treebanks = [x.split(".", maxsplit=1)[0] for x in treebanks]
        treebanks = sorted(set(treebanks))
    elif os.path.exists(args.treebanks):
        # maybe it's a file with a list of names
        with open(args.treebanks) as fin:
            treebanks = sorted(set([x.strip() for x in fin.readlines() if x.strip()]))
    else:
        raise ValueError("Cannot figure out which treebanks to use. Please set the --treebanks parameter")

    logger.info("Processing the following treebanks: %s" % " ".join(treebanks))

    # keep both the full treebank names (for error messages) and the
    # short names (for the generated lookup table)
    shorthands = []
    fullnames = []
    for treebank in treebanks:
        fullnames.append(treebank)
        if SHORTNAME_RE.match(treebank):
            shorthands.append(treebank)
        else:
            shorthands.append(treebank_to_short_name(treebank))

    # For each treebank, we would like to find the XPOS Vocab configuration that minimizes
    # the number of total classes needed to predict by all tagger classifiers. This is
    # achieved by enumerating different options of separators that different treebanks might
    # use, and comparing that to treating the XPOS tags as separate categories (using a
    # WordVocab).
    mapping = defaultdict(list)
    for sh, fn in zip(shorthands, fullnames):
        factory = get_xpos_factory(sh, fn)
        mapping[factory].append(sh)
        # alias shorthands that exist under two names
        if sh == 'zh-hans_gsdsimp':
            mapping[factory].append('zh_gsdsimp')
        elif sh == 'no_bokmaal':
            mapping[factory].append('nb_bokmaal')

    mapping[DEFAULT_KEY].append('en_test')

    # Generate code. This takes the XPOS vocabulary classes selected above, and generates the
    # actual factory class as seen in models.pos.xpos_vocab_factory.
    # NOTE(review): 'first' is never used below — leftover from an earlier version
    first = True
    with open(output_file, 'w') as f:
        # widest shorthand, used to align the generated dict entries
        max_len = max(max(len(x) for x in mapping[key]) for key in mapping)
        print('''# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.
# Please don't edit it!

import logging

from stanza.models.pos.vocab import WordVocab, XPOSVocab
from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory

# using a sublogger makes it easier to test in the unittests
logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')

XPOS_DESCRIPTIONS = {''', file=f)

        for key_idx, key in enumerate(mapping):
            if key_idx > 0:
                print(file=f)
            for shorthand in sorted(mapping[key]):
                # +2 to max_len for the ''
                # this format string is left justified (either would be okay, probably)
                if key.sep is None:
                    sep = 'None'
                else:
                    sep = "'%s'" % key.sep
                print(("    {:%ds}: XPOSDescription({}, {})," % (max_len+2)).format("'%s'" % shorthand, key.xpos_type, sep), file=f)

        print('''}

def xpos_vocab_factory(data, shorthand):
    if shorthand not in XPOS_DESCRIPTIONS:
        logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand)
    desc = choose_simplest_factory(data, shorthand)
    if shorthand in XPOS_DESCRIPTIONS:
        if XPOS_DESCRIPTIONS[shorthand] != desc:
            # log instead of throw
            # otherwise, updating datasets would be unpleasant
            logger.error("XPOS tagset in %s has apparently changed! Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)
    else:
        logger.warning("Chose %s for the xpos factory for %s", desc, shorthand)
    return build_xpos_vocab(desc, data, shorthand)
''', file=f)

    logger.info('Done!')

if __name__ == "__main__":
    main()
|
stanza/stanza/models/pos/data.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import logging
|
| 3 |
+
import copy
|
| 4 |
+
import torch
|
| 5 |
+
from collections import namedtuple
|
| 6 |
+
|
| 7 |
+
from torch.utils.data import DataLoader as DL
|
| 8 |
+
from torch.utils.data.sampler import Sampler
|
| 9 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
+
|
| 11 |
+
from stanza.models.common.bert_embedding import filter_data, needs_length_filter
|
| 12 |
+
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
|
| 13 |
+
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, CharVocab
|
| 14 |
+
from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab
|
| 15 |
+
from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory
|
| 16 |
+
from stanza.models.common.doc import *
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger('stanza')
|
| 19 |
+
|
| 20 |
+
DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain text")
|
| 21 |
+
DataBatch = namedtuple("DataBatch", "words words_mask wordchars wordchars_mask upos xpos ufeats pretrained orig_idx word_orig_idx lens word_lens text idx")
|
| 22 |
+
|
| 23 |
+
class Dataset:
|
| 24 |
+
    def __init__(self, doc, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None, **kwargs):
        """Build a POS Dataset from a stanza Document.

        doc: the Document to read sentences from
        args: configuration dict (shorthand, cutoffs, bert settings, ...)
        pretrain: pretrained embedding object; its vocab is used when
            args['pretrain'] is true
        vocab: optional prebuilt MultiVocab; built from doc if None
        evaluation: disables shuffling, subsampling, and augmentation
        sort_during_eval: whether eval batches are length-sorted
        bert_tokenizer: used only to length-filter sentences for bert models
        """
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.sort_during_eval = sort_during_eval
        self.doc = doc

        if vocab is None:
            self.vocab = Dataset.init_vocab([doc], args)
        else:
            self.vocab = vocab

        # a column counts as present unless every value is missing (None or '_')
        self.has_upos = not all(x is None or x == '_' for x in doc.get(UPOS, as_sentences=False))
        self.has_xpos = not all(x is None or x == '_' for x in doc.get(XPOS, as_sentences=False))
        self.has_feats = not all(x is None or x == '_' for x in doc.get(FEATS, as_sentences=False))

        data = self.load_doc(self.doc)
        # filter out the long sentences if bert is used
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None
        self.pretrain_vocab = None
        if pretrain is not None and args['pretrain']:
            self.pretrain_vocab = pretrain.vocab

        # filter and sample data (training only)
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)

        self.data = data

        self.num_examples = len(data)
        # upos ids treated as punctuation by the no-punct augmentation in __mask
        self.__punct_tags = self.vocab["upos"].map(["PUNCT"])
        # probability of dropping sentence-final punctuation per sample
        self.augment_nopunct = self.args.get("augment_nopunct", 0.0)
|
| 63 |
+
|
| 64 |
+
@staticmethod
|
| 65 |
+
def init_vocab(docs, args):
|
| 66 |
+
data = [x for doc in docs for x in Dataset.load_doc(doc)]
|
| 67 |
+
charvocab = CharVocab(data, args['shorthand'])
|
| 68 |
+
wordvocab = WordVocab(data, args['shorthand'], cutoff=args['word_cutoff'], lower=True)
|
| 69 |
+
uposvocab = WordVocab(data, args['shorthand'], idx=1)
|
| 70 |
+
xposvocab = xpos_vocab_factory(data, args['shorthand'])
|
| 71 |
+
try:
|
| 72 |
+
featsvocab = FeatureVocab(data, args['shorthand'], idx=3)
|
| 73 |
+
except ValueError as e:
|
| 74 |
+
raise ValueError("Unable to build features vocab. Please check the Features column of your data for an error which may match the following description.") from e
|
| 75 |
+
vocab = MultiVocab({'char': charvocab,
|
| 76 |
+
'word': wordvocab,
|
| 77 |
+
'upos': uposvocab,
|
| 78 |
+
'xpos': xposvocab,
|
| 79 |
+
'feats': featsvocab})
|
| 80 |
+
return vocab
|
| 81 |
+
|
| 82 |
+
def preprocess(self, data, vocab, pretrain_vocab, args):
|
| 83 |
+
processed = []
|
| 84 |
+
for sent in data:
|
| 85 |
+
processed_sent = DataSample(
|
| 86 |
+
word = [vocab['word'].map([w[0] for w in sent])],
|
| 87 |
+
char = [[vocab['char'].map([x for x in w[0]]) for w in sent]],
|
| 88 |
+
upos = [vocab['upos'].map([w[1] for w in sent])],
|
| 89 |
+
xpos = [vocab['xpos'].map([w[2] for w in sent])],
|
| 90 |
+
feats = [vocab['feats'].map([w[3] for w in sent])],
|
| 91 |
+
pretrain = ([pretrain_vocab.map([w[0].lower() for w in sent])]
|
| 92 |
+
if pretrain_vocab is not None
|
| 93 |
+
else [[PAD_ID] * len(sent)]),
|
| 94 |
+
text = [w[0] for w in sent]
|
| 95 |
+
)
|
| 96 |
+
processed.append(processed_sent)
|
| 97 |
+
|
| 98 |
+
return processed
|
| 99 |
+
|
| 100 |
+
    def __len__(self):
        """Number of (preprocessed) sentences in the dataset."""
        return len(self.data)
|
| 102 |
+
|
| 103 |
+
    def __mask(self, upos):
        """Returns a torch boolean about which elements should be masked out.

        True marks a position to be removed. Currently the only augmentation
        is dropping sentence-final punctuation, applied stochastically with
        probability self.augment_nopunct per sample.
        """

        # creates all false mask
        mask = torch.zeros_like(upos, dtype=torch.bool)

        ### augmentation 1: punctuation augmentation ###
        # tags that needs to be checked, currently only PUNCT
        if random.uniform(0,1) < self.augment_nopunct:
            for i in self.__punct_tags:
                # generate a mask for the last element
                last_element = torch.zeros_like(upos, dtype=torch.bool)
                last_element[..., -1] = True
                # we or the bitmask against the existing mask
                # if it satisfies, we remove the word by masking it
                # to true
                #
                # if your input is just a lone punctuation, we perform
                # no masking
                if not torch.all(upos.eq(torch.tensor([[i]]))):
                    mask |= ((upos == i) & (last_element))

        return mask
|
| 126 |
+
|
| 127 |
+
    def __getitem__(self, key):
        """Retrieves a sample from the dataset.

        Retrieves a sample from the dataset. This function, for the
        most part, is spent performing ad-hoc data augmentation and
        restoration. It receives a DataSample object from the storage,
        and returns an almost-identical DataSample object that may
        have been augmented with /possibly/ (depending on augment_punct
        settings) PUNCT chopped.

        **Important Note**
        ------------------
        If you would like to load the data into a model, please convert
        this Dataset object into a DataLoader via self.to_loader(). Then,
        you can use the resulting object like any other PyTorch data
        loader. As masks are calculated ad-hoc given the batch, the samples
        returned from this object don't have the appropriate masking.

        Motivation
        ----------
        Why is this here? Every time you call next(iter(dataloader)), it calls
        this function. Therefore, if we augmented each sample on each iteration,
        the model will see dynamically generated augmentation.
        Furthermore, PyTorch dataloader handles shuffling natively.

        Parameters
        ----------
        key : int
            the integer ID to from which to retrieve the key.

        Returns
        -------
        (DataSample, int)
            The sample of data you requested, with augmentation, plus the key.
        """
        # get a sample of the input data
        sample = self.data[key]

        # some data augmentation requires constructing a mask based on upos.
        # For instance, sometimes we'd like to mask out ending sentence punctuation.
        # We copy the other items here so that any edits made because
        # of the mask don't clobber the version owned by the Dataset
        # convert to tensors
        # TODO: only store single lists per data entry?
        words = torch.tensor(sample.word[0])
        # convert the rest to tensors; absent columns stay None
        upos = torch.tensor(sample.upos[0]) if self.has_upos else None
        xpos = torch.tensor(sample.xpos[0]) if self.has_xpos else None
        ufeats = torch.tensor(sample.feats[0]) if self.has_feats else None
        pretrained = torch.tensor(sample.pretrain[0])

        # and deal with char & raw_text (kept as plain Python lists)
        char = sample.char[0]
        raw_text = sample.text

        # some data augmentation requires constructing a mask based on
        # which upos. For instance, sometimes we'd like to mask out ending
        # sentence punctuation. The mask is True if we want to remove the element
        if self.has_upos and upos is not None and not self.eval:
            # perform actual masking
            mask = self.__mask(upos)
        else:
            # dummy mask that's all false
            mask = None
        if mask is not None:
            mask_index = mask.nonzero()

            # mask out the elements that we need to mask out
            # NOTE(review): the loop variable reuses the name 'mask' (shadowing
            # the tensor); with the current augmentation only the final
            # position can be masked, so the index-shift after the list slices
            # below never matters — confirm if __mask ever grows
            for mask in mask_index:
                mask = mask.item()
                words[mask] = PAD_ID
                if upos is not None:
                    upos[mask] = PAD_ID
                if xpos is not None:
                    # TODO: test the multi-dimension xpos
                    xpos[mask, ...] = PAD_ID
                if ufeats is not None:
                    ufeats[mask, ...] = PAD_ID
                pretrained[mask] = PAD_ID
                char = char[:mask] + char[mask+1:]
                raw_text = raw_text[:mask] + raw_text[mask+1:]

        return DataSample(words, char, upos, xpos, ufeats, pretrained, raw_text), key
|
| 213 |
+
|
| 214 |
+
def __iter__(self):
|
| 215 |
+
for i in range(self.__len__()):
|
| 216 |
+
yield self.__getitem__(i)
|
| 217 |
+
|
| 218 |
+
def to_loader(self, **kwargs):
|
| 219 |
+
"""Converts self to a DataLoader """
|
| 220 |
+
|
| 221 |
+
return DL(self,
|
| 222 |
+
collate_fn=Dataset.__collate_fn,
|
| 223 |
+
**kwargs)
|
| 224 |
+
|
| 225 |
+
def to_length_limited_loader(self, batch_size, maximum_tokens):
|
| 226 |
+
sampler = LengthLimitedBatchSampler(self, batch_size, maximum_tokens)
|
| 227 |
+
return DL(self,
|
| 228 |
+
collate_fn=Dataset.__collate_fn,
|
| 229 |
+
batch_sampler = sampler)
|
| 230 |
+
|
| 231 |
+
    @staticmethod
    def __collate_fn(data):
        """Function used by DataLoader to pack data

        Takes a list of (DataSample, index) pairs, sorts the sentences by
        descending length for the RNN, flattens and sorts the per-word
        character sequences for the char RNN, pads everything to rectangular
        tensors, and returns a DataBatch together with the permutation
        indices needed to undo both sorts.
        """
        # split the (sample, index) pairs, then split each sample into fields
        (data, idx) = zip(*data)
        (words, wordchars, upos, xpos, ufeats, pretrained, text) = zip(*data)

        # collate_fn is given a list of length batch size
        batch_size = len(data)

        # sort sentences by lens for easy RNN operations
        lens = [torch.sum(x != PAD_ID) for x in words]
        (words, wordchars, upos, xpos,
         ufeats, pretrained, text), orig_idx = sort_all((words, wordchars, upos, xpos,
                                                         ufeats, pretrained, text), lens)
        lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN

        # combine all words into one large list, and sort for easy charRNN ops
        wordchars = [w for sent in wordchars for w in sent]
        word_lens = [len(x) for x in wordchars]
        (wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
        word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN

        # We now pad everything
        # upos/xpos/ufeats may be None per-sample (dataset lacks that
        # annotation); a batch is only padded when every sample has it
        words = pad_sequence(words, True, PAD_ID)
        if None not in upos:
            upos = pad_sequence(upos, True, PAD_ID)
        else:
            upos = None
        if None not in xpos:
            xpos = pad_sequence(xpos, True, PAD_ID)
        else:
            xpos = None
        if None not in ufeats:
            ufeats = pad_sequence(ufeats, True, PAD_ID)
        else:
            ufeats = None
        pretrained = pad_sequence(pretrained, True, PAD_ID)
        wordchars = get_long_tensor(wordchars, len(word_lens))

        # and finally create masks for the padding indices
        words_mask = torch.eq(words, PAD_ID)
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
                         pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)
|
| 276 |
+
|
| 277 |
+
@staticmethod
|
| 278 |
+
def load_doc(doc):
|
| 279 |
+
data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)
|
| 280 |
+
data = Dataset.resolve_none(data)
|
| 281 |
+
return data
|
| 282 |
+
|
| 283 |
+
@staticmethod
|
| 284 |
+
def resolve_none(data):
|
| 285 |
+
# replace None to '_'
|
| 286 |
+
for sent_idx in range(len(data)):
|
| 287 |
+
for tok_idx in range(len(data[sent_idx])):
|
| 288 |
+
for feat_idx in range(len(data[sent_idx][tok_idx])):
|
| 289 |
+
if data[sent_idx][tok_idx][feat_idx] is None:
|
| 290 |
+
data[sent_idx][tok_idx][feat_idx] = '_'
|
| 291 |
+
return data
|
| 292 |
+
|
| 293 |
+
class LengthLimitedBatchSampler(Sampler):
    """
    Batches up the text in batches of batch_size, but cuts off each time a batch reaches maximum_tokens

    Intent is to avoid GPU OOM in situations where one sentence is significantly longer than expected,
    leaving a batch too large to fit in the GPU

    Sentences which are longer than maximum_tokens by themselves are put in their own batches
    """
    def __init__(self, data, batch_size, maximum_tokens):
        """
        Precalculate the batches, making it so len and iter just read off the precalculated batches
        """
        self.data = data
        self.batch_size = batch_size
        self.maximum_tokens = maximum_tokens

        self.batches = []
        pending = []          # indices accumulated for the next batch
        pending_tokens = 0    # running token total of `pending`

        for item, item_idx in data:
            n_tokens = len(item.word)
            # an oversized sentence gets a batch all to itself
            if maximum_tokens and n_tokens > maximum_tokens:
                if pending:
                    self.batches.append(pending)
                    pending = []
                    pending_tokens = 0
                self.batches.append([item_idx])
                continue
            # flush the pending batch if adding this sentence would
            # overflow either the sentence or the token budget
            if len(pending) >= batch_size or (maximum_tokens and pending_tokens + n_tokens > maximum_tokens):
                self.batches.append(pending)
                pending = []
                pending_tokens = 0
            pending.append(item_idx)
            pending_tokens += n_tokens

        if pending:
            self.batches.append(pending)

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        # yield a fresh copy of each precomputed batch
        for batch in self.batches:
            yield list(batch)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class ShuffledDataset:
    """A wrapper around one or more datasets which shuffles the data in batch_size chunks

    Each yielded batch comes entirely from a single underlying dataset;
    only the order of batches across datasets is shuffled.  This matters
    in the tagger, where batches from different datasets can differ (for
    example, having or not having UPOS tags) and the loss function is much
    simpler when a batch is homogeneous.

    Mechanism: build a schedule list with one entry per batch per loader,
    shuffle it, then draw the next batch from whichever loader the
    schedule names.  A weighted-random scheme would also work, but this
    is simple and the memory cost is small.  The schedule is wasteful for
    the common single-dataset case, but the overhead is negligible.
    """
    def __init__(self, datasets, batch_size):
        self.batch_size = batch_size
        self.datasets = datasets
        self.loaders = [dataset.to_loader(batch_size=self.batch_size, shuffle=True)
                        for dataset in self.datasets]

    def __iter__(self):
        iterators = [iter(loader) for loader in self.loaders]
        # one schedule entry per batch, tagged with its loader's index
        schedule = [loader_idx
                    for loader_idx, loader in enumerate(self.loaders)
                    for _ in range(len(loader))]
        random.shuffle(schedule)

        for loader_idx in schedule:
            yield next(iterators[loader_idx])

    def __len__(self):
        # NOTE(review): this counts total *samples* across datasets, while
        # __iter__ yields *batches* -- confirm callers expect sample counts
        return sum(len(dataset) for dataset in self.datasets)
|
stanza/stanza/models/pos/model.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
|
| 9 |
+
|
| 10 |
+
from stanza.models.common.bert_embedding import extract_bert_embeddings
|
| 11 |
+
from stanza.models.common.biaffine import BiaffineScorer
|
| 12 |
+
from stanza.models.common.foundation_cache import load_bert, load_charlm
|
| 13 |
+
from stanza.models.common.hlstm import HighwayLSTM
|
| 14 |
+
from stanza.models.common.dropout import WordDropout
|
| 15 |
+
from stanza.models.common.utils import attach_bert_model
|
| 16 |
+
from stanza.models.common.vocab import CompositeVocab
|
| 17 |
+
from stanza.models.common.char_model import CharacterModel
|
| 18 |
+
from stanza.models.common import utils
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger('stanza')
|
| 21 |
+
|
| 22 |
+
class Tagger(nn.Module):
    """Bi-LSTM POS tagger predicting UPOS, XPOS, and UFeats.

    Word embeddings, pretrained embeddings, character representations
    (either a charlm or a CharacterModel), and optional transformer
    embeddings are concatenated, fed through a HighwayLSTM, and scored by
    per-task classifiers.  When share_hid is False, XPOS/UFeats use
    biaffine scorers conditioned on a UPOS embedding.
    """
    def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
        """Build all input, recurrent, and classifier layers.

        args: dict of model hyperparameters (dims, dropout, file paths).
        vocab: MultiVocab-style mapping with 'word', 'upos', 'xpos', 'feats' entries.
        emb_matrix: pretrained embedding weights, used when args['pretrain'].
        share_hid: share the UPOS hidden layer for XPOS/UFeats classification.
        bert_model / bert_tokenizer: optional transformer kept unsaved via attach_bert_model.
        peft_name: adapter name when the transformer is PEFT-wrapped.
        """
        super().__init__()

        self.vocab = vocab
        self.args = args
        self.share_hid = share_hid
        # names of submodules excluded from the saved checkpoint
        self.unsaved_modules = []

        # input layers
        # input_size accumulates the width of every enabled input feature
        input_size = 0
        if self.args['word_emb_dim'] > 0:
            # frequent word embeddings
            self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)
            input_size += self.args['word_emb_dim']

        if not share_hid:
            # upos embeddings
            self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
                    raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']))
                if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
                    raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))
                logger.debug("POS model loading charmodels: %s and %s", args['charlm_forward_file'], args['charlm_backward_file'])
                # charlms are large and pretrained, so keep them unsaved
                self.add_unsaved_module('charmodel_forward', load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache))
                self.add_unsaved_module('charmodel_backward', load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache))
                # optionally add a input transformation layer
                if self.args.get('charlm_transform_dim', 0):
                    self.charmodel_forward_transform = nn.Linear(self.charmodel_forward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)
                    self.charmodel_backward_transform = nn.Linear(self.charmodel_backward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)
                    input_size += self.args['charlm_transform_dim'] * 2
                else:
                    self.charmodel_forward_transform = None
                    self.charmodel_backward_transform = None
                    input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
            else:
                bidirectional = args.get('char_bidirectional', False)
                self.charmodel = CharacterModel(args, vocab, bidirectional=bidirectional)
                if bidirectional:
                    self.trans_char = nn.Linear(self.args['char_hidden_dim'] * 2, self.args['transformed_dim'], bias=False)
                else:
                    self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)
                input_size += self.args['transformed_dim']

        self.peft_name = peft_name
        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
        if self.args.get('bert_model', None):
            # TODO: refactor bert_hidden_layers between the different models
            if args.get('bert_hidden_layers', False):
                # The average will be offset by 1/N so that the default zeros
                # represents an average of the N layers
                self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
                nn.init.zeros_(self.bert_layer_mix.weight)
            else:
                # an average of layers 2, 3, 4 will be used
                # (for historic reasons)
                self.bert_layer_mix = None
            input_size += self.bert_model.config.hidden_size

        if self.args['pretrain']:
            # pretrained embeddings, by default this won't be saved into model file
            self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
            self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
            input_size += self.args['transformed_dim']

        # recurrent layers
        self.taggerlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)
        # learned replacement vector used by WordDropout
        self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
        # learned initial hidden/cell states, expanded per batch in forward()
        self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))
        self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))

        # classifiers
        self.upos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'])
        self.upos_clf = nn.Linear(self.args['deep_biaff_hidden_dim'], len(vocab['upos']))
        self.upos_clf.weight.data.zero_()
        self.upos_clf.bias.data.zero_()

        if share_hid:
            clf_constructor = lambda insize, outsize: nn.Linear(insize, outsize)
        else:
            self.xpos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'] if not isinstance(vocab['xpos'], CompositeVocab) else self.args['composite_deep_biaff_hidden_dim'])
            self.ufeats_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['composite_deep_biaff_hidden_dim'])
            # biaffine scorers condition on the upos embedding
            clf_constructor = lambda insize, outsize: BiaffineScorer(insize, self.args['tag_emb_dim'], outsize)

        if isinstance(vocab['xpos'], CompositeVocab):
            # one classifier per xpos component
            self.xpos_clf = nn.ModuleList()
            for l in vocab['xpos'].lens():
                self.xpos_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))
        else:
            self.xpos_clf = clf_constructor(self.args['deep_biaff_hidden_dim'], len(vocab['xpos']))
            if share_hid:
                self.xpos_clf.weight.data.zero_()
                self.xpos_clf.bias.data.zero_()

        # one classifier per ufeats component
        self.ufeats_clf = nn.ModuleList()
        for l in vocab['feats'].lens():
            if share_hid:
                self.ufeats_clf.append(clf_constructor(self.args['deep_biaff_hidden_dim'], l))
                self.ufeats_clf[-1].weight.data.zero_()
                self.ufeats_clf[-1].bias.data.zero_()
            else:
                self.ufeats_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))

        # criterion
        self.crit = nn.CrossEntropyLoss(ignore_index=0) # ignore padding

        self.drop = nn.Dropout(args['dropout'])
        self.worddrop = WordDropout(args['word_dropout'])

    def add_unsaved_module(self, name, module):
        """Attach a submodule while marking it to be skipped when saving."""
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def log_norms(self):
        """Log parameter norms for debugging via the common utils helper."""
        utils.log_norms(self)

    def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text):
        """Run the tagger on one batch.

        Inputs are as produced by the data module's collate function;
        upos/xpos/ufeats may be None (no gold tags), in which case no loss
        is accumulated for that task.  Returns (loss, preds) where preds is
        [upos predictions, xpos predictions, ufeats predictions] as padded
        index tensors.
        """

        def pack(x):
            # pack along the pre-sorted sentence lengths
            return pack_padded_sequence(x, sentlens, batch_first=True)

        inputs = []
        if self.args['word_emb_dim'] > 0:
            word_emb = self.word_emb(word)
            word_emb = pack(word_emb)
            inputs += [word_emb]

        if self.args['pretrain']:
            pretrained_emb = self.pretrained_emb(pretrained)
            pretrained_emb = self.trans_pretrained(pretrained_emb)
            pretrained_emb = pack(pretrained_emb)
            inputs += [pretrained_emb]

        def pad(x):
            # un-pack flat scores back to (batch, seq, ...) using the first
            # input's batch_sizes
            return pad_packed_sequence(PackedSequence(x, inputs[0].batch_sizes), batch_first=True)[0]

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                all_forward_chars = self.charmodel_forward.build_char_representation(text)
                assert isinstance(all_forward_chars, list)
                if self.charmodel_forward_transform is not None:
                    all_forward_chars = [self.charmodel_forward_transform(x) for x in all_forward_chars]
                all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))

                all_backward_chars = self.charmodel_backward.build_char_representation(text)
                if self.charmodel_backward_transform is not None:
                    all_backward_chars = [self.charmodel_backward_transform(x) for x in all_backward_chars]
                all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))

                inputs += [all_forward_chars, all_backward_chars]
            else:
                char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
                char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)
                inputs += [char_reps]

        if self.bert_model is not None:
            device = next(self.parameters()).device
            processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=False,
                                                     num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                     detach=not self.args.get('bert_finetune', False) or not self.training,
                                                     peft_name=self.peft_name)

            if self.bert_layer_mix is not None:
                # add the average so that the default behavior is to
                # take an average of the N layers, and anything else
                # other than that needs to be learned
                # TODO: refactor this
                processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]

            processed_bert = pad_sequence(processed_bert, batch_first=True)
            inputs += [pack(processed_bert)]

        # concatenate all feature streams along the embedding dimension,
        # apply word-level and element-level dropout, then repack
        lstm_inputs = torch.cat([x.data for x in inputs], 1)
        lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
        lstm_inputs = self.drop(lstm_inputs)
        lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)

        lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(self.taggerlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.taggerlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))
        lstm_outputs = lstm_outputs.data

        upos_hid = F.relu(self.upos_hid(self.drop(lstm_outputs)))
        upos_pred = self.upos_clf(self.drop(upos_hid))

        preds = [pad(upos_pred).max(2)[1]]

        if upos is not None:
            upos = pack(upos).data
            loss = self.crit(upos_pred.view(-1, upos_pred.size(-1)), upos.view(-1))
        else:
            loss = 0.0

        if self.share_hid:
            xpos_hid = upos_hid
            ufeats_hid = upos_hid

            clffunc = lambda clf, hid: clf(self.drop(hid))
        else:
            xpos_hid = F.relu(self.xpos_hid(self.drop(lstm_outputs)))
            ufeats_hid = F.relu(self.ufeats_hid(self.drop(lstm_outputs)))

            # teacher-force gold upos at train time, otherwise use predictions
            if self.training and upos is not None:
                upos_emb = self.upos_emb(upos)
            else:
                upos_emb = self.upos_emb(upos_pred.max(1)[1])

            clffunc = lambda clf, hid: clf(self.drop(hid), self.drop(upos_emb))

        if xpos is not None: xpos = pack(xpos).data
        if isinstance(self.vocab['xpos'], CompositeVocab):
            xpos_preds = []
            for i in range(len(self.vocab['xpos'])):
                xpos_pred = clffunc(self.xpos_clf[i], xpos_hid)
                if xpos is not None:
                    loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos[:, i].view(-1))
                xpos_preds.append(pad(xpos_pred).max(2, keepdim=True)[1])
            preds.append(torch.cat(xpos_preds, 2))
        else:
            xpos_pred = clffunc(self.xpos_clf, xpos_hid)
            if xpos is not None:
                loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos.view(-1))
            preds.append(pad(xpos_pred).max(2)[1])

        ufeats_preds = []
        if ufeats is not None: ufeats = pack(ufeats).data
        for i in range(len(self.vocab['feats'])):
            ufeats_pred = clffunc(self.ufeats_clf[i], ufeats_hid)
            if ufeats is not None:
                loss += self.crit(ufeats_pred.view(-1, ufeats_pred.size(-1)), ufeats[:, i].view(-1))
            ufeats_preds.append(pad(ufeats_pred).max(2, keepdim=True)[1])
        preds.append(torch.cat(ufeats_preds, 2))

        return loss, preds
|
stanza/stanza/models/pos/trainer.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A trainer class to handle training and testing of models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import logging
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from stanza.models.common.trainer import Trainer as BaseTrainer
|
| 11 |
+
from stanza.models.common import utils, loss
|
| 12 |
+
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
|
| 13 |
+
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
|
| 14 |
+
from stanza.models.pos.model import Tagger
|
| 15 |
+
from stanza.models.pos.vocab import MultiVocab
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger('stanza')
|
| 18 |
+
|
| 19 |
+
def unpack_batch(batch, device):
    """Split a collated batch into model inputs and CPU-side metadata.

    The first eight entries are tensors (or None) which are moved to
    `device` and fed to the model; entries 8-12 are bookkeeping values
    (sort indices, lengths, raw text) left untouched.
    """
    tensors = [None if item is None else item.to(device) for item in batch[:8]]
    orig_idx, word_orig_idx, sentlens, wordlens, text = batch[8:13]
    return tensors, orig_idx, word_orig_idx, sentlens, wordlens, text
|
| 28 |
+
|
| 29 |
+
class Trainer(BaseTrainer):
|
| 30 |
+
""" A trainer for training models. """
|
| 31 |
+
    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, foundation_cache=None):
        """Create a POS trainer.

        If model_file is given, everything (args, vocab, model weights) is
        restored from that checkpoint; otherwise a fresh Tagger is built
        from args/vocab/pretrain.  In both cases the model is moved to
        `device` and optimizers (plus a bert warmup scheduler when
        fine-tuning) are created.
        """
        if model_file is not None:
            # load everything from file
            self.load(model_file, pretrain, args=args, foundation_cache=foundation_cache)
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab

            bert_model, bert_tokenizer = load_bert(self.args['bert_model'])
            peft_name = None
            if self.args['use_peft']:
                # fine tune the bert if we're using peft
                self.args['bert_finetune'] = True
                peft_name = "pos"
                bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)

            self.model = Tagger(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, share_hid=args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)

        self.model = self.model.to(device)
        # NOTE(review): is_peft reads args key "peft" while the rest of this
        # file uses "use_peft" -- confirm which key get_split_optimizer expects
        self.optimizers = utils.get_split_optimizer(self.args['optim'], self.model, self.args['lr'], betas=(0.9, self.args['beta2']), eps=1e-6, weight_decay=self.args.get('initial_weight_decay', None), bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get("peft", False))

        self.schedulers = {}

        if self.args.get('bert_finetune', None):
            # import here so transformers stays an optional dependency
            import transformers
            warmup_scheduler = transformers.get_linear_schedule_with_warmup(
                self.optimizers["bert_optimizer"],
                # todo late starting?
                0, self.args["max_steps"])
            self.schedulers["bert_scheduler"] = warmup_scheduler
|
| 62 |
+
|
| 63 |
+
    def update(self, batch, eval=False):
        """Run one batch through the model; take an optimizer step unless eval.

        Returns the scalar loss value.  When the model produces a loss of
        exactly 0.0 (no gold annotations in the batch), returns it without
        backprop.
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            for optimizer in self.optimizers.values():
                optimizer.zero_grad()
        # NOTE: this local `loss` shadows the `loss` module imported at the
        # top of the file; the module is not used in this method
        loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)
        if loss == 0.0:
            return loss

        loss_val = loss.data.item()
        if eval:
            return loss_val

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])

        for optimizer in self.optimizers.values():
            optimizer.step()
        for scheduler in self.schedulers.values():
            scheduler.step()
        return loss_val
|
| 90 |
+
|
| 91 |
+
    def predict(self, batch, unsort=True):
        """Predict tags for one batch.

        Returns, per sentence, a list of [upos, xpos, feats] string triples
        (unmapped through the vocab).  When unsort is True the sentences are
        restored to their pre-collation order via orig_idx.
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs

        self.model.eval()
        batch_size = word.size(0)
        # loss is discarded here; only the predictions matter
        _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)
        upos_seqs = [self.vocab['upos'].unmap(sent) for sent in preds[0].tolist()]
        xpos_seqs = [self.vocab['xpos'].unmap(sent) for sent in preds[1].tolist()]
        feats_seqs = [self.vocab['feats'].unmap(sent) for sent in preds[2].tolist()]

        # trim each sentence's predictions to its true (unpadded) length
        pred_tokens = [[[upos_seqs[i][j], xpos_seqs[i][j], feats_seqs[i][j]] for j in range(sentlens[i])] for i in range(batch_size)]
        if unsort:
            pred_tokens = utils.unsort(pred_tokens, orig_idx)
        return pred_tokens
|
| 107 |
+
|
| 108 |
+
    def save(self, filename, skip_modules=True):
        """Save model weights, vocab, and config to `filename`.

        skip_modules drops parameters belonging to modules listed in
        model.unsaved_modules (e.g. pretrained embeddings, charlms) since
        those are stored separately.  PEFT adapter weights, if any, are
        saved under the "bert_lora" key.  Save failures other than
        KeyboardInterrupt/SystemExit are logged and swallowed so training
        can continue.
        """
        model_state = self.model.state_dict()
        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
        if skip_modules:
            skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
            for k in skipped:
                del model_state[k]
        params = {
            'model': model_state,
            'vocab': self.vocab.state_dict(),
            'config': self.args
        }
        if self.args.get('use_peft', False):
            # Hide import so that peft dependency is optional
            from peft import get_peft_model_state_dict
            params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)

        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception as e:
            logger.warning(f"Saving failed... {e} continuing anyway.")
|
| 132 |
+
|
| 133 |
+
def load(self, filename, pretrain, args=None, foundation_cache=None):
    """
    Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input,
    and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
    """
    try:
        # weights_only=True restricts unpickling to tensors/primitives for safety
        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
    except BaseException:
        logger.error("Cannot load model from {}".format(filename))
        raise
    # the checkpoint's saved config becomes the live args, with caller overrides applied on top
    self.args = checkpoint['config']
    if args is not None: self.args.update(args)

    # preserve old models which were created before transformers were added
    if 'bert_model' not in self.args:
        self.args['bert_model'] = None

    # presence of saved lora weights implies the model was trained with peft
    lora_weights = checkpoint.get('bert_lora')
    if lora_weights:
        logger.debug("Found peft weights for POS; loading a peft adapter")
        self.args["use_peft"] = True

    # TODO: refactor this common block of code with NER
    force_bert_saved = False
    peft_name = None
    if self.args.get('use_peft', False):
        # peft path: build the transformer with the saved adapter attached
        force_bert_saved = True
        bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "pos", foundation_cache)
        bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)
        logger.debug("Loaded peft with name %s", peft_name)
    else:
        # non-peft path: a fully finetuned transformer must bypass the shared cache
        if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
            logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
            foundation_cache = NoTransformerFoundationCache(foundation_cache)
            force_bert_saved = True
        bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)

    self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
    # load model
    emb_matrix = None
    if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None
        emb_matrix = pretrain.emb
    # NOTE(review): this repeats the finetuned-transformer check from the non-peft
    # branch above; the second wrap of foundation_cache appears redundant but harmless
    if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
        logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
        foundation_cache = NoTransformerFoundationCache(foundation_cache)
    self.model = Tagger(self.args, self.vocab, emb_matrix=emb_matrix, share_hid=self.args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)
    # strict=False because skipped modules (e.g. pretrained embeddings) are absent from the checkpoint
    self.model.load_state_dict(checkpoint['model'], strict=False)
|
stanza/stanza/models/pos/xpos_vocab_factory.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.
|
| 2 |
+
# Please don't edit it!
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
from stanza.models.pos.vocab import WordVocab, XPOSVocab
|
| 7 |
+
from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory
|
| 8 |
+
|
| 9 |
+
# using a sublogger makes it easier to test in the unittests
|
| 10 |
+
logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')
|
| 11 |
+
|
| 12 |
+
XPOS_DESCRIPTIONS = {
|
| 13 |
+
'af_afribooms' : XPOSDescription(XPOSType.XPOS, ''),
|
| 14 |
+
'ar_padt' : XPOSDescription(XPOSType.XPOS, ''),
|
| 15 |
+
'bg_btb' : XPOSDescription(XPOSType.XPOS, ''),
|
| 16 |
+
'ca_ancora' : XPOSDescription(XPOSType.XPOS, ''),
|
| 17 |
+
'cs_cac' : XPOSDescription(XPOSType.XPOS, ''),
|
| 18 |
+
'cs_cltt' : XPOSDescription(XPOSType.XPOS, ''),
|
| 19 |
+
'cs_fictree' : XPOSDescription(XPOSType.XPOS, ''),
|
| 20 |
+
'cs_pdt' : XPOSDescription(XPOSType.XPOS, ''),
|
| 21 |
+
'en_partut' : XPOSDescription(XPOSType.XPOS, ''),
|
| 22 |
+
'es_ancora' : XPOSDescription(XPOSType.XPOS, ''),
|
| 23 |
+
'es_combined' : XPOSDescription(XPOSType.XPOS, ''),
|
| 24 |
+
'fr_partut' : XPOSDescription(XPOSType.XPOS, ''),
|
| 25 |
+
'gd_arcosg' : XPOSDescription(XPOSType.XPOS, ''),
|
| 26 |
+
'gl_ctg' : XPOSDescription(XPOSType.XPOS, ''),
|
| 27 |
+
'gl_treegal' : XPOSDescription(XPOSType.XPOS, ''),
|
| 28 |
+
'grc_perseus' : XPOSDescription(XPOSType.XPOS, ''),
|
| 29 |
+
'hr_set' : XPOSDescription(XPOSType.XPOS, ''),
|
| 30 |
+
'is_gc' : XPOSDescription(XPOSType.XPOS, ''),
|
| 31 |
+
'is_icepahc' : XPOSDescription(XPOSType.XPOS, ''),
|
| 32 |
+
'is_modern' : XPOSDescription(XPOSType.XPOS, ''),
|
| 33 |
+
'it_combined' : XPOSDescription(XPOSType.XPOS, ''),
|
| 34 |
+
'it_isdt' : XPOSDescription(XPOSType.XPOS, ''),
|
| 35 |
+
'it_markit' : XPOSDescription(XPOSType.XPOS, ''),
|
| 36 |
+
'it_parlamint' : XPOSDescription(XPOSType.XPOS, ''),
|
| 37 |
+
'it_partut' : XPOSDescription(XPOSType.XPOS, ''),
|
| 38 |
+
'it_postwita' : XPOSDescription(XPOSType.XPOS, ''),
|
| 39 |
+
'it_twittiro' : XPOSDescription(XPOSType.XPOS, ''),
|
| 40 |
+
'it_vit' : XPOSDescription(XPOSType.XPOS, ''),
|
| 41 |
+
'la_perseus' : XPOSDescription(XPOSType.XPOS, ''),
|
| 42 |
+
'la_udante' : XPOSDescription(XPOSType.XPOS, ''),
|
| 43 |
+
'lt_alksnis' : XPOSDescription(XPOSType.XPOS, ''),
|
| 44 |
+
'lv_lvtb' : XPOSDescription(XPOSType.XPOS, ''),
|
| 45 |
+
'ro_nonstandard' : XPOSDescription(XPOSType.XPOS, ''),
|
| 46 |
+
'ro_rrt' : XPOSDescription(XPOSType.XPOS, ''),
|
| 47 |
+
'ro_simonero' : XPOSDescription(XPOSType.XPOS, ''),
|
| 48 |
+
'sk_snk' : XPOSDescription(XPOSType.XPOS, ''),
|
| 49 |
+
'sl_ssj' : XPOSDescription(XPOSType.XPOS, ''),
|
| 50 |
+
'sl_sst' : XPOSDescription(XPOSType.XPOS, ''),
|
| 51 |
+
'sr_set' : XPOSDescription(XPOSType.XPOS, ''),
|
| 52 |
+
'ta_ttb' : XPOSDescription(XPOSType.XPOS, ''),
|
| 53 |
+
'uk_iu' : XPOSDescription(XPOSType.XPOS, ''),
|
| 54 |
+
|
| 55 |
+
'be_hse' : XPOSDescription(XPOSType.WORD, None),
|
| 56 |
+
'bxr_bdt' : XPOSDescription(XPOSType.WORD, None),
|
| 57 |
+
'cop_scriptorium': XPOSDescription(XPOSType.WORD, None),
|
| 58 |
+
'cu_proiel' : XPOSDescription(XPOSType.WORD, None),
|
| 59 |
+
'cy_ccg' : XPOSDescription(XPOSType.WORD, None),
|
| 60 |
+
'da_ddt' : XPOSDescription(XPOSType.WORD, None),
|
| 61 |
+
'de_gsd' : XPOSDescription(XPOSType.WORD, None),
|
| 62 |
+
'de_hdt' : XPOSDescription(XPOSType.WORD, None),
|
| 63 |
+
'el_gdt' : XPOSDescription(XPOSType.WORD, None),
|
| 64 |
+
'el_gud' : XPOSDescription(XPOSType.WORD, None),
|
| 65 |
+
'en_atis' : XPOSDescription(XPOSType.WORD, None),
|
| 66 |
+
'en_combined' : XPOSDescription(XPOSType.WORD, None),
|
| 67 |
+
'en_craft' : XPOSDescription(XPOSType.WORD, None),
|
| 68 |
+
'en_eslspok' : XPOSDescription(XPOSType.WORD, None),
|
| 69 |
+
'en_ewt' : XPOSDescription(XPOSType.WORD, None),
|
| 70 |
+
'en_genia' : XPOSDescription(XPOSType.WORD, None),
|
| 71 |
+
'en_gum' : XPOSDescription(XPOSType.WORD, None),
|
| 72 |
+
'en_gumreddit' : XPOSDescription(XPOSType.WORD, None),
|
| 73 |
+
'en_mimic' : XPOSDescription(XPOSType.WORD, None),
|
| 74 |
+
'en_test' : XPOSDescription(XPOSType.WORD, None),
|
| 75 |
+
'es_gsd' : XPOSDescription(XPOSType.WORD, None),
|
| 76 |
+
'et_edt' : XPOSDescription(XPOSType.WORD, None),
|
| 77 |
+
'et_ewt' : XPOSDescription(XPOSType.WORD, None),
|
| 78 |
+
'eu_bdt' : XPOSDescription(XPOSType.WORD, None),
|
| 79 |
+
'fa_perdt' : XPOSDescription(XPOSType.WORD, None),
|
| 80 |
+
'fa_seraji' : XPOSDescription(XPOSType.WORD, None),
|
| 81 |
+
'fi_tdt' : XPOSDescription(XPOSType.WORD, None),
|
| 82 |
+
'fr_combined' : XPOSDescription(XPOSType.WORD, None),
|
| 83 |
+
'fr_gsd' : XPOSDescription(XPOSType.WORD, None),
|
| 84 |
+
'fr_parisstories': XPOSDescription(XPOSType.WORD, None),
|
| 85 |
+
'fr_rhapsodie' : XPOSDescription(XPOSType.WORD, None),
|
| 86 |
+
'fr_sequoia' : XPOSDescription(XPOSType.WORD, None),
|
| 87 |
+
'fro_profiterole': XPOSDescription(XPOSType.WORD, None),
|
| 88 |
+
'ga_idt' : XPOSDescription(XPOSType.WORD, None),
|
| 89 |
+
'ga_twittirish' : XPOSDescription(XPOSType.WORD, None),
|
| 90 |
+
'got_proiel' : XPOSDescription(XPOSType.WORD, None),
|
| 91 |
+
'grc_proiel' : XPOSDescription(XPOSType.WORD, None),
|
| 92 |
+
'grc_ptnk' : XPOSDescription(XPOSType.WORD, None),
|
| 93 |
+
'gv_cadhan' : XPOSDescription(XPOSType.WORD, None),
|
| 94 |
+
'hbo_ptnk' : XPOSDescription(XPOSType.WORD, None),
|
| 95 |
+
'he_combined' : XPOSDescription(XPOSType.WORD, None),
|
| 96 |
+
'he_htb' : XPOSDescription(XPOSType.WORD, None),
|
| 97 |
+
'he_iahltknesset': XPOSDescription(XPOSType.WORD, None),
|
| 98 |
+
'he_iahltwiki' : XPOSDescription(XPOSType.WORD, None),
|
| 99 |
+
'hi_hdtb' : XPOSDescription(XPOSType.WORD, None),
|
| 100 |
+
'hsb_ufal' : XPOSDescription(XPOSType.WORD, None),
|
| 101 |
+
'hu_szeged' : XPOSDescription(XPOSType.WORD, None),
|
| 102 |
+
'hy_armtdp' : XPOSDescription(XPOSType.WORD, None),
|
| 103 |
+
'hy_bsut' : XPOSDescription(XPOSType.WORD, None),
|
| 104 |
+
'hyw_armtdp' : XPOSDescription(XPOSType.WORD, None),
|
| 105 |
+
'id_csui' : XPOSDescription(XPOSType.WORD, None),
|
| 106 |
+
'it_old' : XPOSDescription(XPOSType.WORD, None),
|
| 107 |
+
'ka_glc' : XPOSDescription(XPOSType.WORD, None),
|
| 108 |
+
'kk_ktb' : XPOSDescription(XPOSType.WORD, None),
|
| 109 |
+
'kmr_mg' : XPOSDescription(XPOSType.WORD, None),
|
| 110 |
+
'kpv_lattice' : XPOSDescription(XPOSType.WORD, None),
|
| 111 |
+
'ky_ktmu' : XPOSDescription(XPOSType.WORD, None),
|
| 112 |
+
'la_proiel' : XPOSDescription(XPOSType.WORD, None),
|
| 113 |
+
'lij_glt' : XPOSDescription(XPOSType.WORD, None),
|
| 114 |
+
'lt_hse' : XPOSDescription(XPOSType.WORD, None),
|
| 115 |
+
'lzh_kyoto' : XPOSDescription(XPOSType.WORD, None),
|
| 116 |
+
'mr_ufal' : XPOSDescription(XPOSType.WORD, None),
|
| 117 |
+
'mt_mudt' : XPOSDescription(XPOSType.WORD, None),
|
| 118 |
+
'myv_jr' : XPOSDescription(XPOSType.WORD, None),
|
| 119 |
+
'nb_bokmaal' : XPOSDescription(XPOSType.WORD, None),
|
| 120 |
+
'nds_lsdc' : XPOSDescription(XPOSType.WORD, None),
|
| 121 |
+
'nn_nynorsk' : XPOSDescription(XPOSType.WORD, None),
|
| 122 |
+
'nn_nynorsklia' : XPOSDescription(XPOSType.WORD, None),
|
| 123 |
+
'no_bokmaal' : XPOSDescription(XPOSType.WORD, None),
|
| 124 |
+
'orv_birchbark' : XPOSDescription(XPOSType.WORD, None),
|
| 125 |
+
'orv_rnc' : XPOSDescription(XPOSType.WORD, None),
|
| 126 |
+
'orv_torot' : XPOSDescription(XPOSType.WORD, None),
|
| 127 |
+
'ota_boun' : XPOSDescription(XPOSType.WORD, None),
|
| 128 |
+
'pcm_nsc' : XPOSDescription(XPOSType.WORD, None),
|
| 129 |
+
'pt_bosque' : XPOSDescription(XPOSType.WORD, None),
|
| 130 |
+
'pt_cintil' : XPOSDescription(XPOSType.WORD, None),
|
| 131 |
+
'pt_dantestocks' : XPOSDescription(XPOSType.WORD, None),
|
| 132 |
+
'pt_gsd' : XPOSDescription(XPOSType.WORD, None),
|
| 133 |
+
'pt_petrogold' : XPOSDescription(XPOSType.WORD, None),
|
| 134 |
+
'pt_porttinari' : XPOSDescription(XPOSType.WORD, None),
|
| 135 |
+
'qpm_philotis' : XPOSDescription(XPOSType.WORD, None),
|
| 136 |
+
'qtd_sagt' : XPOSDescription(XPOSType.WORD, None),
|
| 137 |
+
'ru_gsd' : XPOSDescription(XPOSType.WORD, None),
|
| 138 |
+
'ru_poetry' : XPOSDescription(XPOSType.WORD, None),
|
| 139 |
+
'ru_syntagrus' : XPOSDescription(XPOSType.WORD, None),
|
| 140 |
+
'ru_taiga' : XPOSDescription(XPOSType.WORD, None),
|
| 141 |
+
'sa_vedic' : XPOSDescription(XPOSType.WORD, None),
|
| 142 |
+
'sme_giella' : XPOSDescription(XPOSType.WORD, None),
|
| 143 |
+
'swl_sslc' : XPOSDescription(XPOSType.WORD, None),
|
| 144 |
+
'sq_staf' : XPOSDescription(XPOSType.WORD, None),
|
| 145 |
+
'te_mtg' : XPOSDescription(XPOSType.WORD, None),
|
| 146 |
+
'tr_atis' : XPOSDescription(XPOSType.WORD, None),
|
| 147 |
+
'tr_boun' : XPOSDescription(XPOSType.WORD, None),
|
| 148 |
+
'tr_framenet' : XPOSDescription(XPOSType.WORD, None),
|
| 149 |
+
'tr_imst' : XPOSDescription(XPOSType.WORD, None),
|
| 150 |
+
'tr_kenet' : XPOSDescription(XPOSType.WORD, None),
|
| 151 |
+
'tr_penn' : XPOSDescription(XPOSType.WORD, None),
|
| 152 |
+
'tr_tourism' : XPOSDescription(XPOSType.WORD, None),
|
| 153 |
+
'ug_udt' : XPOSDescription(XPOSType.WORD, None),
|
| 154 |
+
'uk_parlamint' : XPOSDescription(XPOSType.WORD, None),
|
| 155 |
+
'vi_vtb' : XPOSDescription(XPOSType.WORD, None),
|
| 156 |
+
'wo_wtb' : XPOSDescription(XPOSType.WORD, None),
|
| 157 |
+
'xcl_caval' : XPOSDescription(XPOSType.WORD, None),
|
| 158 |
+
'zh-hans_gsdsimp': XPOSDescription(XPOSType.WORD, None),
|
| 159 |
+
'zh-hant_gsd' : XPOSDescription(XPOSType.WORD, None),
|
| 160 |
+
'zh_gsdsimp' : XPOSDescription(XPOSType.WORD, None),
|
| 161 |
+
|
| 162 |
+
'en_lines' : XPOSDescription(XPOSType.XPOS, '-'),
|
| 163 |
+
'fo_farpahc' : XPOSDescription(XPOSType.XPOS, '-'),
|
| 164 |
+
'ja_gsd' : XPOSDescription(XPOSType.XPOS, '-'),
|
| 165 |
+
'ja_gsdluw' : XPOSDescription(XPOSType.XPOS, '-'),
|
| 166 |
+
'sv_lines' : XPOSDescription(XPOSType.XPOS, '-'),
|
| 167 |
+
'ur_udtb' : XPOSDescription(XPOSType.XPOS, '-'),
|
| 168 |
+
|
| 169 |
+
'fi_ftb' : XPOSDescription(XPOSType.XPOS, ','),
|
| 170 |
+
'orv_ruthenian' : XPOSDescription(XPOSType.XPOS, ','),
|
| 171 |
+
|
| 172 |
+
'id_gsd' : XPOSDescription(XPOSType.XPOS, '+'),
|
| 173 |
+
'ko_gsd' : XPOSDescription(XPOSType.XPOS, '+'),
|
| 174 |
+
'ko_kaist' : XPOSDescription(XPOSType.XPOS, '+'),
|
| 175 |
+
'ko_ksl' : XPOSDescription(XPOSType.XPOS, '+'),
|
| 176 |
+
'qaf_arabizi' : XPOSDescription(XPOSType.XPOS, '+'),
|
| 177 |
+
|
| 178 |
+
'la_ittb' : XPOSDescription(XPOSType.XPOS, '|'),
|
| 179 |
+
'la_llct' : XPOSDescription(XPOSType.XPOS, '|'),
|
| 180 |
+
'nl_alpino' : XPOSDescription(XPOSType.XPOS, '|'),
|
| 181 |
+
'nl_lassysmall' : XPOSDescription(XPOSType.XPOS, '|'),
|
| 182 |
+
'sv_talbanken' : XPOSDescription(XPOSType.XPOS, '|'),
|
| 183 |
+
|
| 184 |
+
'pl_lfg' : XPOSDescription(XPOSType.XPOS, ':'),
|
| 185 |
+
'pl_pdb' : XPOSDescription(XPOSType.XPOS, ':'),
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
def xpos_vocab_factory(data, shorthand):
    """Build the xpos vocab for a dataset, checking against the precomputed table.

    data: training sentences used to construct the vocab
    shorthand: dataset id such as 'en_ewt'
    Re-derives the simplest factory from the data, then warns (unknown dataset)
    or errors (description changed) when it disagrees with XPOS_DESCRIPTIONS.
    """
    if shorthand not in XPOS_DESCRIPTIONS:
        logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand)
    desc = choose_simplest_factory(data, shorthand)
    if shorthand in XPOS_DESCRIPTIONS:
        if XPOS_DESCRIPTIONS[shorthand] != desc:
            # log instead of throw
            # otherwise, updating datasets would be unpleasant
            logger.error("XPOS tagset in %s has apparently changed! Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)
    else:
        logger.warning("Chose %s for the xpos factory for %s", desc, shorthand)
    return build_xpos_vocab(desc, data, shorthand)
|
| 200 |
+
|
stanza/stanza/models/pos/xpos_vocab_utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import namedtuple
|
| 2 |
+
from enum import Enum
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from stanza.models.common.vocab import VOCAB_PREFIX
|
| 7 |
+
from stanza.models.pos.vocab import XPOSVocab, WordVocab
|
| 8 |
+
|
| 9 |
+
class XPOSType(Enum):
    """How a dataset's xpos tags are represented in its vocab."""
    # tag is decomposed into sub-tags by a separator (see XPOSDescription.sep)
    XPOS = 1
    # tag is treated as a single atomic word
    WORD = 2
|
| 12 |
+
|
| 13 |
+
# A vocab description: which vocab class to build (XPOS or WORD) and, for
# XPOS vocabs, the separator character used to split a tag into sub-tags.
XPOSDescription = namedtuple('XPOSDescription', ['xpos_type', 'sep'])
# Fallback description: model the whole xpos tag as one atomic word
DEFAULT_KEY = XPOSDescription(XPOSType.WORD, None)

logger = logging.getLogger('stanza')
|
| 17 |
+
|
| 18 |
+
def filter_data(data, idx):
    """Return only the sentences where every token has a value in field idx.

    data: list of sentences, each a list of token tuples/lists
    idx: index of the field to check within each token
    A sentence is kept iff no token has None at position idx; an empty
    sentence is kept (vacuously complete), matching the original flag loop.
    """
    # all() short-circuits on the first None, unlike the original flag loop
    # which always scanned the whole sentence; the result is identical
    return [sentence for sentence in data
            if all(token[idx] is not None for token in sentence)]
|
| 27 |
+
|
| 28 |
+
def choose_simplest_factory(data, shorthand):
    """Pick the smallest vocab representation for a dataset's xpos column.

    Starts from the whole-tag WordVocab; if that vocab is large (> 20 tags),
    tries splitting tags on each candidate separator and keeps whichever
    XPOSVocab yields the fewest total sub-tag entries.
    Returns an XPOSDescription describing the winning choice.
    """
    logger.info(f'Original length = {len(data)}')
    # drop sentences with a missing xpos (field 2) so vocab counts are clean
    data = filter_data(data, idx=2)
    logger.info(f'Filtered length = {len(data)}')
    vocab = WordVocab(data, shorthand, idx=2, ignore=["_"])
    key = DEFAULT_KEY
    # vocab size excluding the shared special symbols (PAD/UNK/...)
    best_size = len(vocab) - len(VOCAB_PREFIX)
    if best_size > 20:
        for sep in ['', '-', '+', '|', ',', ':']: # separators
            vocab = XPOSVocab(data, shorthand, idx=2, sep=sep)
            # NOTE(review): reaches into the private _id2unit mapping of
            # XPOSVocab (one sub-vocab per tag position) -- confirm layout
            length = sum(len(x) - len(VOCAB_PREFIX) for x in vocab._id2unit.values())
            if length < best_size:
                key = XPOSDescription(XPOSType.XPOS, sep)
                best_size = length
    return key
|
| 43 |
+
|
| 44 |
+
def build_xpos_vocab(description, data, shorthand):
    """Construct the vocab described by an XPOSDescription over the data.

    WORD descriptions model each xpos tag as one atomic unit; XPOS
    descriptions split every tag on the recorded separator.
    """
    make_word_vocab = description.xpos_type is XPOSType.WORD
    return (WordVocab(data, shorthand, idx=2, ignore=["_"]) if make_word_vocab
            else XPOSVocab(data, shorthand, idx=2, sep=description.sep))
|
stanza/stanza/models/tokenization/__init__.py
ADDED
|
File without changes
|
stanza/stanza/models/tokenization/data.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from bisect import bisect_right
|
| 2 |
+
from copy import copy
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
from .vocab import Vocab
|
| 10 |
+
|
| 11 |
+
from stanza.models.common.utils import sort_with_indices, unsort
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger('stanza')
|
| 14 |
+
|
| 15 |
+
def filter_consecutive_whitespaces(para):
    """Collapse runs of spaces in a paragraph of (char, label) pairs.

    A pair whose char is ' ' is dropped when the previous original char was
    also ' '; the first pair of a run (and its label) is kept.
    """
    filtered = []
    prev_char = None
    for char, label in para:
        # compare against the previous ORIGINAL character, matching the
        # index-based lookback of the classic implementation
        if not (char == ' ' and prev_char == ' '):
            filtered.append((char, label))
        prev_char = char
    return filtered
|
| 25 |
+
|
| 26 |
+
# Paragraph boundary: a blank (possibly whitespace-only) line between text
NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
# this was (r'^([\d]+[,\.]*)+$')
# but the runtime on that can explode exponentially
# for example, on 111111111111111111111111a
NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$')
# Any single whitespace character (used to normalize whitespace to ' ')
WHITESPACE_RE = re.compile(r'\s')
|
| 32 |
+
|
| 33 |
+
class TokenizationDataset:
|
| 34 |
+
def __init__(self, tokenizer_args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs):
    """Load and preprocess raw text (and optional char labels) into paragraphs.

    tokenizer_args: dict of tokenizer settings; 'skip_newline' is read here
    input_files: dict with 'txt' and 'label' paths (either may be None)
    input_text: raw text to use instead of reading input_files['txt']
    vocab: pre-built vocab, or None
    evaluation: True at predict/eval time
    dictionary: lexicon dict for dictionary features, or None

    NOTE(review): the mutable default for input_files is shared across calls;
    it is only read, never mutated, so it is safe in practice.
    """
    super().__init__(*args, **kwargs) # forwards all unused arguments
    self.args = tokenizer_args
    self.eval = evaluation
    self.dictionary = dictionary
    self.vocab = vocab

    # get input files
    txt_file = input_files['txt']
    label_file = input_files['label']

    # Load data and process it
    # set up text from file or input string
    assert txt_file is not None or input_text is not None
    if input_text is None:
        with open(txt_file) as f:
            text = ''.join(f.readlines()).rstrip()
    else:
        text = input_text

    # paragraphs are separated by blank lines; drop empty chunks
    text_chunks = NEWLINE_WHITESPACE_RE.split(text)
    text_chunks = [pt.rstrip() for pt in text_chunks]
    text_chunks = [pt for pt in text_chunks if pt]
    if label_file is not None:
        # one digit label per character, paragraph-aligned with the text
        with open(label_file) as f:
            labels = ''.join(f.readlines()).rstrip()
        labels = NEWLINE_WHITESPACE_RE.split(labels)
        labels = [pt.rstrip() for pt in labels]
        # NOTE: single-use map iterators, consumed once by the zip below
        labels = [map(int, pt) for pt in labels if pt]
    else:
        # no label file: dummy all-zero labels (e.g. at prediction time)
        labels = [[0 for _ in pt] for pt in text_chunks]

    skip_newline = self.args.get('skip_newline', False)
    # pair each character with its label, normalizing all whitespace to ' '
    self.data = [[(WHITESPACE_RE.sub(' ', char), label) # substitute special whitespaces
                  for char, label in zip(pt, pc) if not (skip_newline and char == '\n')] # check if newline needs to be eaten
                 for pt, pc in zip(text_chunks, labels)]

    # remove consecutive whitespaces
    self.data = [filter_consecutive_whitespaces(x) for x in self.data]
|
| 73 |
+
|
| 74 |
+
def labels(self):
|
| 75 |
+
"""
|
| 76 |
+
Returns a list of the labels for all of the sentences in this DataLoader
|
| 77 |
+
|
| 78 |
+
Used at eval time to compare to the results, for example
|
| 79 |
+
"""
|
| 80 |
+
return [np.array(list(x[1] for x in sent)) for sent in self.data]
|
| 81 |
+
|
| 82 |
+
def extract_dict_feat(self, para, idx):
|
| 83 |
+
"""
|
| 84 |
+
This function is to extract dictionary features for each character
|
| 85 |
+
"""
|
| 86 |
+
length = len(para)
|
| 87 |
+
|
| 88 |
+
dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])]
|
| 89 |
+
dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])]
|
| 90 |
+
forward_word = para[idx][0]
|
| 91 |
+
backward_word = para[idx][0]
|
| 92 |
+
prefix = True
|
| 93 |
+
suffix = True
|
| 94 |
+
for window in range(1,self.args['num_dict_feat']+1):
|
| 95 |
+
# concatenate each character and check if words found in dict not, stop if prefix not found
|
| 96 |
+
#check if idx+t is out of bound and if the prefix is already not found
|
| 97 |
+
if (idx + window) <= length-1 and prefix:
|
| 98 |
+
forward_word += para[idx+window][0].lower()
|
| 99 |
+
#check in json file if the word is present as prefix or word or None.
|
| 100 |
+
feat = 1 if forward_word in self.dictionary["words"] else 0
|
| 101 |
+
#if the return value is not 2 or 3 then the checking word is not a valid word in dict.
|
| 102 |
+
dict_forward_feats[window-1] = feat
|
| 103 |
+
#if the dict return 0 means no prefixes found, thus, stop looking for forward.
|
| 104 |
+
if forward_word not in self.dictionary["prefixes"]:
|
| 105 |
+
prefix = False
|
| 106 |
+
#backward check: similar to forward
|
| 107 |
+
if (idx - window) >= 0 and suffix:
|
| 108 |
+
backward_word = para[idx-window][0].lower() + backward_word
|
| 109 |
+
feat = 1 if backward_word in self.dictionary["words"] else 0
|
| 110 |
+
dict_backward_feats[window-1] = feat
|
| 111 |
+
if backward_word not in self.dictionary["suffixes"]:
|
| 112 |
+
suffix = False
|
| 113 |
+
#if cannot find both prefix and suffix, then exit the loop
|
| 114 |
+
if not prefix and not suffix:
|
| 115 |
+
break
|
| 116 |
+
|
| 117 |
+
return dict_forward_feats + dict_backward_feats
|
| 118 |
+
|
| 119 |
+
def para_to_sentences(self, para):
    """ Convert a paragraph to a list of processed sentences.

    para: list of (unit, label) pairs for one paragraph.
    Returns a list of tuples (unit_ids, labels, features, raw_units):
    numpy arrays of vocab ids, labels, and per-unit feature vectors,
    plus the raw unit strings.  During training (self.eval False) the
    paragraph is split at sentence-final labels (2 or 4) and over-long
    sentences are dropped; at eval time no splitting is done and the
    whole paragraph becomes one "sentence" regardless of length.
    """
    res = []
    funcs = []
    # build per-unit feature functions requested in the config
    for feat_func in self.args['feat_funcs']:
        if feat_func == 'end_of_para' or feat_func == 'start_of_para':
            # skip for position-dependent features
            continue
        if feat_func == 'space_before':
            func = lambda x: 1 if x.startswith(' ') else 0
        elif feat_func == 'capitalized':
            func = lambda x: 1 if x[0].isupper() else 0
        elif feat_func == 'numeric':
            func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
        else:
            raise ValueError('Feature function "{}" is undefined.'.format(feat_func))

        funcs.append(func)

    # stacking all featurize functions
    composite_func = lambda x: [f(x) for f in funcs]

    def process_sentence(sent_units, sent_labels, sent_feats):
        # snapshot the accumulators (np.array / list copy) so the caller
        # can safely clear them afterwards
        return (np.array([self.vocab.unit2id(y) for y in sent_units]),
                np.array(sent_labels),
                np.array(sent_feats),
                list(sent_units))

    use_end_of_para = 'end_of_para' in self.args['feat_funcs']
    use_start_of_para = 'start_of_para' in self.args['feat_funcs']
    use_dictionary = self.args['use_dictionary']
    current_units = []
    current_labels = []
    current_feats = []
    for i, (unit, label) in enumerate(para):
        feats = composite_func(unit)
        # position-dependent features
        if use_end_of_para:
            f = 1 if i == len(para)-1 else 0
            feats.append(f)
        if use_start_of_para:
            f = 1 if i == 0 else 0
            feats.append(f)

        #if dictionary feature is selected
        if use_dictionary:
            dict_feats = self.extract_dict_feat(para, i)
            feats = feats + dict_feats

        current_units.append(unit)
        current_labels.append(label)
        current_feats.append(feats)
        if not self.eval and (label == 2 or label == 4): # end of sentence
            if len(current_units) <= self.args['max_seqlen']:
                # get rid of sentences that are too long during training of the tokenizer
                res.append(process_sentence(current_units, current_labels, current_feats))
            # reset accumulators at every sentence boundary, kept or not
            current_units.clear()
            current_labels.clear()
            current_feats.clear()

    # flush any trailing partial sentence (or the whole paragraph at eval time)
    if len(current_units) > 0:
        if self.eval or len(current_units) <= self.args['max_seqlen']:
            res.append(process_sentence(current_units, current_labels, current_feats))

    return res
|
| 184 |
+
|
| 185 |
+
def advance_old_batch(self, eval_offsets, old_batch):
|
| 186 |
+
"""
|
| 187 |
+
Advance to a new position in a batch where we have partially processed the batch
|
| 188 |
+
|
| 189 |
+
If we have previously built a batch of data and made predictions on them, then when we are trying to make
|
| 190 |
+
prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
|
| 191 |
+
and just (essentially) advance the indices/offsets from where we read converted data in this old batch.
|
| 192 |
+
In this case, eval_offsets index within the old_batch to advance the strings to process.
|
| 193 |
+
"""
|
| 194 |
+
unkid = self.vocab.unit2id('<UNK>')
|
| 195 |
+
padid = self.vocab.unit2id('<PAD>')
|
| 196 |
+
|
| 197 |
+
ounits, olabels, ofeatures, oraw = old_batch
|
| 198 |
+
feat_size = ofeatures.shape[-1]
|
| 199 |
+
lens = (ounits != padid).sum(1).tolist()
|
| 200 |
+
pad_len = max(l-i for i, l in zip(eval_offsets, lens))
|
| 201 |
+
|
| 202 |
+
units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64)
|
| 203 |
+
labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32)
|
| 204 |
+
features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32)
|
| 205 |
+
raw_units = []
|
| 206 |
+
|
| 207 |
+
for i in range(len(ounits)):
|
| 208 |
+
eval_offsets[i] = min(eval_offsets[i], lens[i])
|
| 209 |
+
units[i, :(lens[i] - eval_offsets[i])] = ounits[i, eval_offsets[i]:lens[i]]
|
| 210 |
+
labels[i, :(lens[i] - eval_offsets[i])] = olabels[i, eval_offsets[i]:lens[i]]
|
| 211 |
+
features[i, :(lens[i] - eval_offsets[i])] = ofeatures[i, eval_offsets[i]:lens[i]]
|
| 212 |
+
raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + ['<PAD>'] * (pad_len - lens[i] + eval_offsets[i]))
|
| 213 |
+
|
| 214 |
+
return units, labels, features, raw_units
|
| 215 |
+
|
| 216 |
+
class DataLoader(TokenizationDataset):
    """
    This is the training version of the dataset.

    It flattens the loaded paragraphs into a list of (paragraph, sentence)
    ids plus a cumulative-length table so that next() can build either
    random training batches or offset-addressed evaluation batches.
    """
    def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None):
        # NOTE(review): input_files has a mutable default; it is only passed
        # through to the superclass here, but confirm the superclass never
        # mutates it
        super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)

        self.vocab = vocab if vocab is not None else self.init_vocab()

        # data comes in a list of paragraphs, where each paragraph is a list of units with unit-level labels.
        # At evaluation time, each paragraph is treated as single "sentence" as we don't know a priori where
        # sentence breaks occur. We make prediction from left to right for each paragraph and move forward to
        # the last predicted sentence break to start afresh.
        self.sentences = [self.para_to_sentences(para) for para in self.data]

        self.init_sent_ids()
        logger.debug(f"{len(self.sentence_ids)} sentences loaded.")

    def __len__(self):
        # total number of sentences across all paragraphs
        return len(self.sentence_ids)

    def init_vocab(self):
        """Build a fresh Vocab from the loaded data."""
        vocab = Vocab(self.data, self.args['lang'])
        return vocab

    def init_sent_ids(self):
        """Rebuild the flat sentence index and the cumulative length table.

        sentence_ids[k] is the (paragraph, sentence) pair of the k-th
        sentence; cumlen[k] is the total number of units in all sentences
        before it, so global offsets can be mapped back with bisect.
        """
        self.sentence_ids = []
        self.cumlen = [0]
        for i, para in enumerate(self.sentences):
            for j in range(len(para)):
                self.sentence_ids += [(i, j)]
                self.cumlen += [self.cumlen[-1] + len(self.sentences[i][j][0])]

    def has_mwt(self):
        """Return True if any unit in the data carries a label > 2 (an MWT label)."""
        # presumably this only needs to be called either 0 or 1 times,
        # 1 when training and 0 any other time, so no effort is put
        # into caching the result
        for sentence in self.data:
            for word in sentence:
                if word[1] > 2:
                    return True
        return False

    def shuffle(self):
        """Shuffle sentences within each paragraph, then rebuild the index."""
        for para in self.sentences:
            random.shuffle(para)
        self.init_sent_ids()

    def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0):
        ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction.

        eval_offsets: unit offsets into the concatenated corpus; when given,
            one row is built starting at each offset (evaluation).  When None,
            batch_size random sentences are sampled (training).
        unit_dropout: probability of replacing a unit with <UNK> (training only).
        feat_unit_dropout: probability of zeroing a unit's whole feature vector
            (training only, and only when a dictionary is in use).

        Returns (units, labels, features, raw_units): three padded tensors
        plus the corresponding unit strings.
        '''
        feat_size = len(self.sentences[0][0][2][0])
        unkid = self.vocab.unit2id('<UNK>')
        padid = self.vocab.unit2id('<PAD>')

        def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
            # At eval time, this combines sentences in paragraph (indexed by id_pair[0]) starting sentence (indexed
            # by id_pair[1]) into a long string for evaluation. At training time, we just select random sentences
            # from the entire dataset until we reach max_seqlen.
            # (note: the pad_len parameter is accepted but currently unused here)
            pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)
            sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]

            # training-time augmentation flags: occasionally drop trailing
            # sentences, or drop the final character of the sequence
            drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0))
            drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))
            total_len = len(sentences[0][0])

            assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! {}'.format(self.args['max_seqlen'], total_len, ' '.join(["{}/{}".format(*x) for x in zip(self.sentences[pid][sid])]))
            if self.eval:
                # extend with the following sentences of the same paragraph
                for sid1 in range(sid+1, len(self.sentences[pid])):
                    total_len += len(self.sentences[pid][sid1][0])
                    sentences.append(self.sentences[pid][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break
            else:
                # extend with random sentences from anywhere in the data
                while True:
                    pid1, sid1 = random.choice(self.sentence_ids)
                    total_len += len(self.sentences[pid1][sid1][0])
                    sentences.append(self.sentences[pid1][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break

            if drop_sents and len(sentences) > 1:
                if total_len > self.args['max_seqlen']:
                    sentences = sentences[:-1]
                if len(sentences) > 1:
                    p = [.5 ** i for i in range(1, len(sentences) + 1)] # drop a large number of sentences with smaller probability
                    cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0]
                    sentences = sentences[:cutoff+1]

            units = np.concatenate([s[0] for s in sentences])
            labels = np.concatenate([s[1] for s in sentences])
            feats = np.concatenate([s[2] for s in sentences])
            raw_units = [x for s in sentences for x in s[3]]

            if not self.eval:
                # truncate to exactly max_seqlen for training
                cutoff = self.args['max_seqlen']
                units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff]

            if drop_last_char: # can only happen in non-eval mode
                if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3):
                    # training text ended with a sentence end position
                    # and that word was a single character
                    # and the previous character ended the word
                    units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1]
                    # word end -> sentence end, mwt end -> sentence mwt end
                    labels[-1] = labels[-1] + 1

            return units, labels, feats, raw_units

        if eval_offsets is not None:
            # find max padding length
            pad_len = 0
            for eval_offset in eval_offsets:
                if eval_offset < self.cumlen[-1]:
                    pair_id = bisect_right(self.cumlen, eval_offset) - 1
                    pair = self.sentence_ids[pair_id]
                    pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset-self.cumlen[pair_id])[0]))

            # +1 guarantees at least one trailing pad unit per row
            pad_len += 1
            id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets]
            pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs]
            offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)]

            offsets_pairs = list(zip(offsets, pairs))
        else:
            # training mode: sample whole sentences at random
            id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size']))
            offsets_pairs = [(0, x) for x in id_pairs]
            pad_len = self.args['max_seqlen']

        # put everything into padded and nicely shaped NumPy arrays and eventually convert to PyTorch tensors
        units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64)
        labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64)
        features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32)
        raw_units = []
        for i, (offset, pair) in enumerate(offsets_pairs):
            u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len)
            units[i, :len(u_)] = u_
            labels[i, :len(l_)] = l_
            features[i, :len(f_), :] = f_
            raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))

        if unit_dropout > 0 and not self.eval:
            # dropout characters/units at training time and replace them with UNKs
            mask = np.random.random_sample(units.shape) < unit_dropout
            mask[units == padid] = 0
            units[mask] = unkid
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask[i, j]:
                        raw_units[i][j] = '<UNK>'

        # dropout unit feature vector in addition to only torch.dropout in the model.
        # experiments showed that only torch.dropout hurts the model
        # we believe it is because the dict feature vector is mostly sparse so it makes
        # more sense to drop out the whole vector instead of only single element.
        if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
            mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
            mask_feat[units == padid] = 0
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask_feat[i,j]:
                        features[i,j,:] = 0

        units = torch.from_numpy(units)
        labels = torch.from_numpy(labels)
        features = torch.from_numpy(features)

        return units, labels, features, raw_units
|
| 385 |
+
|
| 386 |
+
class SortedDataset(Dataset):
    """
    Holds a TokenizationDataset for use in a torch DataLoader

    The torch DataLoader is different from the DataLoader defined here
    and allows for cpu & gpu parallelism.  Wrapping a TokenizationDataset
    in this class lets the feature computation for output_predictions
    happen in parallel worker processes, saving quite a bit of time.
    """
    def __init__(self, dataset):
        super().__init__()

        self.dataset = dataset
        # sort paragraphs by length, remembering the permutation so
        # results can later be restored to the original order
        self.data, self.indices = sort_with_indices(self.dataset.data, key=len)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.dataset.para_to_sentences(self.data[index])

    def unsort(self, arr):
        # invert the length-sort applied in __init__
        return unsort(arr, self.indices)

    def collate(self, samples):
        """Stack a list of single-paragraph samples into one padded batch."""
        if any(len(sample) > 1 for sample in samples):
            raise ValueError("Expected all paragraphs to have no preset sentence splits!")
        pad_id = self.dataset.vocab.unit2id('<PAD>')
        num_feats = samples[0][0][2].shape[-1]

        # +1 so that all samples end with at least one pad
        batch_width = 1 + max(len(sample[0][3]) for sample in samples)

        unit_batch = torch.full((len(samples), batch_width), pad_id, dtype=torch.int64)
        label_batch = torch.full((len(samples), batch_width), -1, dtype=torch.int32)
        feature_batch = torch.zeros((len(samples), batch_width, num_feats), dtype=torch.float32)
        raw_batch = []
        for row, sample in enumerate(samples):
            sample_units, sample_labels, sample_feats, sample_raw = sample[0]
            unit_batch[row, :len(sample_units)] = torch.from_numpy(sample_units)
            label_batch[row, :len(sample_labels)] = torch.from_numpy(sample_labels)
            feature_batch[row, :len(sample_feats), :] = torch.from_numpy(sample_feats)
            raw_batch.append(sample_raw + ['<PAD>'] * (batch_width - len(sample_raw)))

        return unit_batch, label_batch, feature_batch, raw_batch
|
| 432 |
+
|
stanza/stanza/models/tokenization/model.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
class Tokenizer(nn.Module):
    """BiLSTM tokenizer model that jointly scores token, sentence and MWT breaks.

    For each input position the model emits joint log-probabilities over the
    possible break decisions; see forward() for the exact class layout.
    """
    def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
        """
        args: configuration dict (reads feat_dim, rnn_layers, conv_res,
            hier_conv_res, use_mwt, hierarchical, tok_noise, hier_invtemp)
        nchars: vocabulary size for the character embedding
        emb_dim: character embedding dimension
        hidden_dim: LSTM hidden size per direction
        dropout: dropout applied to embeddings and RNN outputs
        feat_dropout: dropout applied to the auxiliary feature vectors
        """
        super().__init__()

        self.args = args
        feat_dim = args['feat_dim']

        # index 0 is reserved for padding
        self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)

        self.rnn = nn.LSTM(emb_dim + feat_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0)

        if self.args['conv_res'] is not None:
            # residual 1d convolutions over the input, added to the RNN output;
            # conv_res is a comma-separated list of kernel sizes
            self.conv_res = nn.ModuleList()
            self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')]

            for si, size in enumerate(self.conv_sizes):
                l = nn.Conv1d(emb_dim + feat_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0))
                self.conv_res.append(l)

            if self.args.get('hier_conv_res', False):
                # 1x1 conv that mixes the concatenated conv branches into one residual
                self.conv_res2 = nn.Conv1d(hidden_dim * 2 * len(self.conv_sizes), hidden_dim * 2, 1)

        # per-position binary classifiers for token / sentence / MWT ends
        self.tok_clf = nn.Linear(hidden_dim * 2, 1)
        self.sent_clf = nn.Linear(hidden_dim * 2, 1)
        if self.args['use_mwt']:
            self.mwt_clf = nn.Linear(hidden_dim * 2, 1)

        if args['hierarchical']:
            # optional second-stage LSTM whose classifiers refine the
            # first stage's logits additively
            in_dim = hidden_dim * 2
            self.rnn2 = nn.LSTM(in_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
            self.tok_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            self.sent_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            if self.args['use_mwt']:
                self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.dropout_feat = nn.Dropout(feat_dropout)

        # noise applied to the soft token-boundary gate of the hierarchical stage
        self.toknoise = nn.Dropout(self.args['tok_noise'])

    def forward(self, x, feats):
        """
        x: character/unit id tensor fed to the embedding layer
        feats: auxiliary feature tensor concatenated onto the embeddings
            (assumed (batch, seq, feat_dim) given batch_first LSTMs --
            confirm against the data loader)

        Returns joint log-probabilities concatenated on dim 2: with use_mwt
        the 5 classes are [no break, token, token+sent, token+mwt,
        token+sent+mwt]; otherwise the 3 classes are
        [no break, token, token+sent].
        """
        emb = self.embeddings(x)
        emb = self.dropout(emb)
        feats = self.dropout_feat(feats)

        emb = torch.cat([emb, feats], 2)

        inp, _ = self.rnn(emb)

        if self.args['conv_res'] is not None:
            # Conv1d expects (batch, channels, seq)
            conv_input = emb.transpose(1, 2).contiguous()
            if not self.args.get('hier_conv_res', False):
                # add each conv branch directly to the RNN output
                for l in self.conv_res:
                    inp = inp + l(conv_input).transpose(1, 2).contiguous()
            else:
                # concatenate all branches, then mix with the 1x1 conv
                hid = []
                for l in self.conv_res:
                    hid += [l(conv_input)]
                hid = torch.cat(hid, 1)
                hid = F.relu(hid)
                hid = self.dropout(hid)

                inp = inp + self.conv_res2(hid).transpose(1, 2).contiguous()

        inp = self.dropout(inp)

        # first-stage logits
        tok0 = self.tok_clf(inp)
        sent0 = self.sent_clf(inp)
        if self.args['use_mwt']:
            mwt0 = self.mwt_clf(inp)

        if self.args['hierarchical']:
            if self.args['hier_invtemp'] > 0:
                # soft-gate the second stage's input by the (noisy) probability
                # of NOT being a token boundary, sharpened by hier_invtemp
                inp2, _ = self.rnn2(inp * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp']))))
            else:
                inp2, _ = self.rnn2(inp)

            inp2 = self.dropout(inp2)

            # second-stage refinements are added to the first-stage logits
            tok0 = tok0 + self.tok_clf2(inp2)
            sent0 = sent0 + self.sent_clf2(inp2)
            if self.args['use_mwt']:
                mwt0 = mwt0 + self.mwt_clf2(inp2)

        # log-probabilities of each independent binary decision
        nontok = F.logsigmoid(-tok0)
        tok = F.logsigmoid(tok0)
        nonsent = F.logsigmoid(-sent0)
        sent = F.logsigmoid(sent0)
        if self.args['use_mwt']:
            nonmwt = F.logsigmoid(-mwt0)
            mwt = F.logsigmoid(mwt0)

        # combine into joint log-probs over the output classes
        if self.args['use_mwt']:
            pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2)
        else:
            pred = torch.cat([nontok, tok+nonsent, tok+sent], 2)

        return pred
|
stanza/stanza/models/tokenization/tokenize_files.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Use a Stanza tokenizer to turn a text file into one tokenized paragraph per line
|
| 2 |
+
|
| 3 |
+
For example, the output of this script is suitable for Glove
|
| 4 |
+
|
| 5 |
+
Currently this *only* supports tokenization, no MWT splitting.
|
| 6 |
+
It also would be beneficial to have an option to convert spaces into
|
| 7 |
+
NBSP, underscore, or some other marker to make it easier to process
|
| 8 |
+
languages such as VI which have spaces in them
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import io
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
import re
|
| 17 |
+
import zipfile
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
import stanza
|
| 22 |
+
from stanza.models.common.utils import open_read_text, default_device
|
| 23 |
+
from stanza.models.tokenization.data import TokenizationDataset
|
| 24 |
+
from stanza.models.tokenization.utils import output_predictions
|
| 25 |
+
from stanza.pipeline.tokenize_processor import TokenizeProcessor
|
| 26 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 27 |
+
|
| 28 |
+
tqdm = get_tqdm()
|
| 29 |
+
|
| 30 |
+
NEWLINE_SPLIT_RE = re.compile(r"\n\s*\n")
|
| 31 |
+
|
| 32 |
+
def tokenize_to_file(tokenizer, fin, fout, chunk_size=500):
    """Tokenize the text read from fin and write one space-separated paragraph per line to fout.

    The input is split into paragraphs on blank lines, then run through the
    tokenizer chunk_size paragraphs at a time so that the whole file's worth
    of processed documents is never held in memory at once.
    """
    paragraphs = NEWLINE_SPLIT_RE.split(fin.read())
    num_paragraphs = len(paragraphs)
    for start in tqdm(range(0, num_paragraphs, chunk_size), leave=False):
        batch = paragraphs[start:min(start + chunk_size, num_paragraphs)]
        processed = tokenizer.bulk_process([stanza.Document([], text=text) for text in batch])
        for doc in processed:
            # one line per paragraph: sentences separated by a single space,
            # tokens within a sentence separated by single spaces
            pieces = [" ".join(token.text for token in sentence.tokens) for sentence in doc.sentences]
            fout.write(" ".join(pieces))
            fout.write("\n")
|
| 46 |
+
|
| 47 |
+
def main(args=None):
    """Tokenize the given input files (plain text or zip archives of text)
    and write the results, one paragraph per line, to --output_file."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, default="sd", help="Which language to use for tokenization")
    parser.add_argument("--tokenize_model_path", type=str, default=None, help="Specific tokenizer model to use")
    parser.add_argument("input_files", type=str, nargs="+", help="Which input files to tokenize")
    parser.add_argument("--output_file", type=str, default="glove.txt", help="Where to write the tokenized output")
    parser.add_argument("--model_dir", type=str, default=None, help="Where to get models for a Pipeline (None => default models dir)")
    parser.add_argument("--chunk_size", type=int, default=500, help="How many 'documents' to use in a chunk when tokenizing. This is separate from the tokenizer batching - this limits how much memory gets used at once, since we don't need to store an entire file in memory at once")
    args = parser.parse_args(args=args)

    if os.path.exists(args.output_file):
        print("Cowardly refusing to overwrite existing output file %s" % args.output_file)
        return

    if args.tokenize_model_path:
        # build a standalone tokenize processor from a specific model file
        config = { "model_path": args.tokenize_model_path,
                   "check_requirements": False }
        tokenizer = TokenizeProcessor(config, pipeline=None, device=default_device())
    else:
        # use the default downloaded model for the language
        pipe = stanza.Pipeline(lang=args.lang, processors="tokenize", model_dir=args.model_dir)
        tokenizer = pipe.processors["tokenize"]

    with open(args.output_file, "w", encoding="utf-8") as fout:
        for filename in tqdm(args.input_files):
            if filename.endswith(".zip"):
                with zipfile.ZipFile(filename) as zin:
                    input_names = zin.namelist()
                    for input_name in tqdm(input_names, leave=False):
                        # bug fix: previously this opened input_names[0] on
                        # every iteration, so only the first member of the
                        # zip was ever tokenized
                        with zin.open(input_name) as fin:
                            fin = io.TextIOWrapper(fin, encoding='utf-8')
                            tokenize_to_file(tokenizer, fin, fout, chunk_size=args.chunk_size)
            else:
                with open_read_text(filename, encoding="utf-8") as fin:
                    # bug fix: pass the parsed --chunk_size through instead of
                    # silently ignoring it
                    tokenize_to_file(tokenizer, fin, fout, chunk_size=args.chunk_size)
|
| 81 |
+
|
| 82 |
+
if __name__ == '__main__':
|
| 83 |
+
main()
|
stanza/stanza/models/tokenization/trainer.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
|
| 7 |
+
from stanza.models.common import utils
|
| 8 |
+
from stanza.models.common.trainer import Trainer as BaseTrainer
|
| 9 |
+
from stanza.models.tokenization.utils import create_dictionary
|
| 10 |
+
|
| 11 |
+
from .model import Tokenizer
|
| 12 |
+
from .vocab import Vocab
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger('stanza')
|
| 15 |
+
|
| 16 |
+
class Trainer(BaseTrainer):
    """Trainer for the tokenizer: wraps the model, loss, optimizer and (de)serialization."""
    def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None):
        if model_file is not None:
            # load everything from file
            self.load(model_file)
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab
            # keep the lexicon as a list; see save() for why
            self.lexicon = list(lexicon) if lexicon is not None else None
            self.dictionary = dictionary
            self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
        self.model = self.model.to(device)
        # label -1 marks padding positions; they contribute no loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
        self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay'])
        self.feat_funcs = self.args.get('feat_funcs', None)
        self.lang = self.args['lang'] # language determines how token normalization is done

    def update(self, inputs):
        """Run one training step on a (units, labels, features, raw) batch; return the loss value."""
        self.model.train()
        units, labels, features, _ = inputs

        device = next(self.model.parameters()).device
        units = units.to(device)
        labels = labels.to(device)
        features = features.to(device)

        pred = self.model(units, features)

        self.optimizer.zero_grad()
        classes = pred.size(2)
        # flatten (batch, seq, classes) -> (batch*seq, classes) for the loss
        loss = self.criterion(pred.view(-1, classes), labels.view(-1))

        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()

        return loss.item()

    def predict(self, inputs):
        """Return the model's predictions for a batch as a numpy array."""
        self.model.eval()
        units, _, features, _ = inputs

        device = next(self.model.parameters()).device
        units = units.to(device)
        features = features.to(device)

        pred = self.model(units, features)

        return pred.data.cpu().numpy()

    def save(self, filename):
        """Serialize model weights, vocab, lexicon and config to filename (best effort)."""
        params = {
            'model': self.model.state_dict() if self.model is not None else None,
            'vocab': self.vocab.state_dict(),
            # save and load lexicon as list instead of set so
            # we can use weights_only=True
            'lexicon': list(self.lexicon) if self.lexicon is not None else None,
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except BaseException:
            # deliberately broad: a failed checkpoint save should not kill training
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename):
        """Load model, vocab, lexicon and config from a checkpoint file.

        Raises whatever torch.load raises if the file cannot be read.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if self.args.get('use_mwt', None) is None:
            # Default to True as many currently saved models
            # were built with mwt layers
            self.args['use_mwt'] = True
        self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
        self.model.load_state_dict(checkpoint['model'])
        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
        self.lexicon = checkpoint['lexicon']

        if self.lexicon is not None:
            # stored as a list (see save); rebuild the set and derived dictionary
            self.lexicon = set(self.lexicon)
            self.dictionary = create_dictionary(self.lexicon)
        else:
            self.dictionary = None
|
stanza/stanza/utils/datasets/constituency/convert_ctb.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
import xml.etree.ElementTree as ET
|
| 7 |
+
|
| 8 |
+
from stanza.models.constituency import tree_reader
|
| 9 |
+
from stanza.utils.datasets.constituency.utils import write_dataset
|
| 10 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 11 |
+
|
| 12 |
+
tqdm = get_tqdm()
|
| 13 |
+
|
| 14 |
+
class Version(Enum):
    """Which edition of the Chinese Treebank the conversion should expect."""
    # CTB 5.1
    V51 = 1
    # CTB 5.1 with an alternate split (presumably the one implemented by
    # filenum_to_shard_51_basic -- confirm with the caller)
    V51b = 2
    # CTB 9.0
    V90 = 3
|
| 19 |
+
def filenum_to_shard_51(filenum):
    """Map a CTB 5.1 file number to its shard: 0=train, 1=dev, 2=test.

    Raises ValueError for file numbers outside the known ranges.
    """
    # inclusive (low, high, shard) spans
    spans = (
        (1, 815, 0),
        (1001, 1136, 0),
        (886, 931, 1),
        (1148, 1151, 1),
        (816, 885, 2),
        (1137, 1147, 2),
    )
    for low, high, shard in spans:
        if low <= filenum <= high:
            return shard

    raise ValueError("Unhandled filenum %d" % filenum)
|
| 36 |
+
|
| 37 |
+
def filenum_to_shard_51_basic(filenum):
    """Map a CTB 5.1 file number to a shard using the alternate split.

    Returns 0/1/2 for train/dev/test, None for files which are skipped
    (400-439), and raises ValueError for unknown file numbers.
    """
    # inclusive (low, high, shard) spans; shard None means "skip this file"
    spans = (
        (1, 270, 0),
        (440, 1151, 0),
        (301, 325, 1),
        (271, 300, 2),
        (400, 439, None),
    )
    for low, high, shard in spans:
        if low <= filenum <= high:
            return shard

    raise ValueError("Unhandled filenum %d" % filenum)
|
| 53 |
+
|
| 54 |
+
def filenum_to_shard_90(filenum):
    """Map a CTB 9.0 file number to its shard: 0=train, 1=dev, 2=test.

    Checks the shards in the same order as the original chain of
    comparisons (test first, then dev, then train), which matters for the
    one overlapping boundary (900 is test, not train).  Unknown file
    numbers fall through and return None.
    """
    test_spans = ((1, 40), (900, 931), (2165, 2180), (2295, 2310),
                  (2570, 2602), (2800, 2819), (3110, 3145))
    test_extra = (1018, 1020, 1036, 1044, 1060, 1061, 1072, 1118, 1119,
                  1132, 1141, 1142, 1148)

    dev_spans = ((41, 80), (1120, 1129), (2140, 2159), (2280, 2294),
                 (2550, 2569), (2775, 2799), (3080, 3109))

    train_spans = ((81, 900), (1001, 1017), (1021, 1035), (1037, 1043),
                   (1045, 1059), (1062, 1071), (1073, 1117), (1133, 1140),
                   (1143, 1147), (1149, 2139), (2160, 2164), (2181, 2279),
                   (2311, 2549), (2603, 2774), (2820, 3079), (4000, 7017))
    train_extra = (1019, 1130, 1131)

    if filenum in test_extra or any(low <= filenum <= high for low, high in test_spans):
        return 2

    if any(low <= filenum <= high for low, high in dev_spans):
        return 1

    if filenum in train_extra or any(low <= filenum <= high for low, high in train_spans):
        return 0

    # unhandled file numbers implicitly return None, matching the original
    # fall-through behavior
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def collect_trees_s(root):
    """Recursively yield (text, id) pairs for every <S> element under root."""
    if root.tag == 'S':
        yield root.text, root.attrib['ID']

    for node in root:
        yield from collect_trees_s(node)
|
| 131 |
+
|
| 132 |
+
def collect_trees_text(root):
    """Recursively yield (text, None) for every non-empty <TEXT> or <TURN> element under root."""
    # a node has exactly one tag, so merging the two original checks into a
    # single membership test is equivalent
    if root.tag in ('TEXT', 'TURN') and len(root.text.strip()) > 0:
        yield root.text, None

    for node in root:
        yield from collect_trees_text(node)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
id_re = re.compile("<S ID=([0-9a-z]+)>")
|
| 145 |
+
su_re = re.compile("<(su|msg) id=([0-9a-zA-Z_=]+)>")
|
| 146 |
+
|
| 147 |
+
def convert_ctb(input_dir, output_dir, dataset_name, version):
|
| 148 |
+
input_files = glob.glob(os.path.join(input_dir, "*"))
|
| 149 |
+
|
| 150 |
+
# train, dev, test
|
| 151 |
+
datasets = [[], [], []]
|
| 152 |
+
|
| 153 |
+
sorted_filenames = []
|
| 154 |
+
for input_filename in input_files:
|
| 155 |
+
base_filename = os.path.split(input_filename)[1]
|
| 156 |
+
filenum = int(os.path.splitext(base_filename)[0].split("_")[1])
|
| 157 |
+
sorted_filenames.append((filenum, input_filename))
|
| 158 |
+
sorted_filenames.sort()
|
| 159 |
+
|
| 160 |
+
for filenum, filename in tqdm(sorted_filenames):
|
| 161 |
+
if version in (Version.V51, Version.V51b):
|
| 162 |
+
with open(filename, errors='ignore', encoding="gb2312") as fin:
|
| 163 |
+
text = fin.read()
|
| 164 |
+
elif version is Version.V90:
|
| 165 |
+
with open(filename, encoding="utf-8") as fin:
|
| 166 |
+
text = fin.read()
|
| 167 |
+
if text.find("<TURN>") >= 0 and text.find("</TURN>") < 0:
|
| 168 |
+
text = text.replace("<TURN>", "")
|
| 169 |
+
if filenum in (4205, 4208, 4289):
|
| 170 |
+
text = text.replace("<)", "<)").replace(">)", ">)")
|
| 171 |
+
if filenum >= 4000 and filenum <= 4411:
|
| 172 |
+
if text.find("<segment") >= 0:
|
| 173 |
+
text = text.replace("<segment id=", "<S ID=").replace("</segment>", "</S>")
|
| 174 |
+
elif text.find("<seg") < 0:
|
| 175 |
+
text = "<TEXT>\n%s</TEXT>\n" % text
|
| 176 |
+
else:
|
| 177 |
+
text = text.replace("<seg id=", "<S ID=").replace("</seg>", "</S>")
|
| 178 |
+
text = "<foo>\n%s</foo>\n" % text
|
| 179 |
+
if filenum >= 5000 and filenum <= 5558 or filenum >= 6000 and filenum <= 6700 or filenum >= 7000 and filenum <= 7017:
|
| 180 |
+
text = su_re.sub("", text)
|
| 181 |
+
if filenum in (6066, 6453):
|
| 182 |
+
text = text.replace("<", "<").replace(">", ">")
|
| 183 |
+
text = "<foo><TEXT>\n%s</TEXT></foo>\n" % text
|
| 184 |
+
else:
|
| 185 |
+
raise ValueError("Unknown CTB version %s" % version)
|
| 186 |
+
text = id_re.sub(r'<S ID="\1">', text)
|
| 187 |
+
text = text.replace("&", "&")
|
| 188 |
+
|
| 189 |
+
try:
|
| 190 |
+
xml_root = ET.fromstring(text)
|
| 191 |
+
except Exception as e:
|
| 192 |
+
print(text[:1000])
|
| 193 |
+
raise RuntimeError("Cannot xml process %s" % filename) from e
|
| 194 |
+
trees = [x for x in collect_trees_s(xml_root)]
|
| 195 |
+
if version is Version.V90 and len(trees) == 0:
|
| 196 |
+
trees = [x for x in collect_trees_text(xml_root)]
|
| 197 |
+
|
| 198 |
+
if version in (Version.V51, Version.V51b):
|
| 199 |
+
trees = [x[0] for x in trees if filenum != 414 or x[1] != "4366"]
|
| 200 |
+
else:
|
| 201 |
+
trees = [x[0] for x in trees]
|
| 202 |
+
|
| 203 |
+
trees = "\n".join(trees)
|
| 204 |
+
try:
|
| 205 |
+
trees = tree_reader.read_trees(trees, use_tqdm=False)
|
| 206 |
+
except ValueError as e:
|
| 207 |
+
print(text[:300])
|
| 208 |
+
raise RuntimeError("Could not process the tree text in %s" % filename)
|
| 209 |
+
trees = [t.prune_none().simplify_labels() for t in trees]
|
| 210 |
+
|
| 211 |
+
assert len(trees) > 0, "No trees in %s" % filename
|
| 212 |
+
|
| 213 |
+
if version is Version.V51:
|
| 214 |
+
shard = filenum_to_shard_51(filenum)
|
| 215 |
+
elif version is Version.V51b:
|
| 216 |
+
shard = filenum_to_shard_51_basic(filenum)
|
| 217 |
+
else:
|
| 218 |
+
shard = filenum_to_shard_90(filenum)
|
| 219 |
+
if shard is None:
|
| 220 |
+
continue
|
| 221 |
+
datasets[shard].extend(trees)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
write_dataset(datasets, output_dir, dataset_name)
|
stanza/stanza/utils/datasets/constituency/extract_silver_dataset.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
After running build_silver_dataset.py, this extracts the trees of a certain match level
|
| 3 |
+
|
| 4 |
+
For example
|
| 5 |
+
|
| 6 |
+
python3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score 0 --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_0.mrg
|
| 7 |
+
|
| 8 |
+
for i in `echo 0 1 2 3 4 5 6 7 8 9 10`; do python3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score $i --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_$i.mrg; done
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
def parse_args():
|
| 15 |
+
parser = argparse.ArgumentParser(description="After finding common trees using build_silver_dataset, this extracts them all or just the ones from a particular level of accuracy")
|
| 16 |
+
parser.add_argument('--parsed_trees', type=str, nargs='+', help='Input file(s) of trees parsed into the build_silver_dataset json format.')
|
| 17 |
+
parser.add_argument('--keep_score', type=int, default=None, help='Which agreement level to keep. None keeps all')
|
| 18 |
+
parser.add_argument('--output_file', type=str, default=None, help='Where to put the output file')
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
|
| 21 |
+
return args
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main():
|
| 25 |
+
args = parse_args()
|
| 26 |
+
|
| 27 |
+
trees = []
|
| 28 |
+
for filename in args.parsed_trees:
|
| 29 |
+
with open(filename, encoding='utf-8') as fin:
|
| 30 |
+
for line in fin.readlines():
|
| 31 |
+
tree = json.loads(line)
|
| 32 |
+
if args.keep_score is None or tree['count'] == args.keep_score:
|
| 33 |
+
tree = tree['tree']
|
| 34 |
+
trees.append(tree)
|
| 35 |
+
|
| 36 |
+
if args.output_file is None:
|
| 37 |
+
for tree in trees:
|
| 38 |
+
print(tree)
|
| 39 |
+
else:
|
| 40 |
+
with open(args.output_file, 'w', encoding='utf-8') as fout:
|
| 41 |
+
for tree in trees:
|
| 42 |
+
fout.write(tree)
|
| 43 |
+
fout.write('\n')
|
| 44 |
+
|
| 45 |
+
if __name__ == '__main__':
|
| 46 |
+
main()
|
| 47 |
+
|
stanza/stanza/utils/datasets/coref/balance_languages.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
balance_concat.py
|
| 3 |
+
create a test set from a dev set which is language balanced
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
from random import Random
|
| 10 |
+
|
| 11 |
+
# fix random seed for reproducability
|
| 12 |
+
R = Random(42)
|
| 13 |
+
|
| 14 |
+
with open("./corefud_concat_v1_0_langid.train.json", 'r') as df:
|
| 15 |
+
raw = json.load(df)
|
| 16 |
+
|
| 17 |
+
# calculate type of each class; then, we will select the one
|
| 18 |
+
# which has the LOWEST counts as the sample rate
|
| 19 |
+
lang_counts = defaultdict(int)
|
| 20 |
+
for i in raw:
|
| 21 |
+
lang_counts[i["lang"]] += 1
|
| 22 |
+
|
| 23 |
+
min_lang_count = min(lang_counts.values())
|
| 24 |
+
|
| 25 |
+
# sample 20% of the smallest amount for test set
|
| 26 |
+
# this will look like an absurdly small number, but
|
| 27 |
+
# remember this is DOCUMENTS not TOKENS or UTTERANCES
|
| 28 |
+
# so its actually decent
|
| 29 |
+
# also its per language
|
| 30 |
+
test_set_size = int(0.1*min_lang_count)
|
| 31 |
+
|
| 32 |
+
# sampling input by language
|
| 33 |
+
raw_by_language = defaultdict(list)
|
| 34 |
+
for i in raw:
|
| 35 |
+
raw_by_language[i["lang"]].append(i)
|
| 36 |
+
languages = list(set(raw_by_language.keys()))
|
| 37 |
+
|
| 38 |
+
train_set = []
|
| 39 |
+
test_set = []
|
| 40 |
+
for i in languages:
|
| 41 |
+
length = list(range(len(raw_by_language[i])))
|
| 42 |
+
choices = R.sample(length, test_set_size)
|
| 43 |
+
|
| 44 |
+
for indx,i in enumerate(raw_by_language[i]):
|
| 45 |
+
if indx in choices:
|
| 46 |
+
test_set.append(i)
|
| 47 |
+
else:
|
| 48 |
+
train_set.append(i)
|
| 49 |
+
|
| 50 |
+
with open("./corefud_concat_v1_0_langid-bal.train.json", 'w') as df:
|
| 51 |
+
json.dump(train_set, df, indent=2)
|
| 52 |
+
|
| 53 |
+
with open("./corefud_concat_v1_0_langid-bal.test.json", 'w') as df:
|
| 54 |
+
json.dump(test_set, df, indent=2)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# raw_by_language["en"]
|
| 59 |
+
|
| 60 |
+
|