diff --git a/stanza/stanza/models/classifiers/base_classifier.py b/stanza/stanza/models/classifiers/base_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..3679e79091d93941d70c720831e26c82b182caf4
--- /dev/null
+++ b/stanza/stanza/models/classifiers/base_classifier.py
@@ -0,0 +1,65 @@
+from abc import ABC, abstractmethod
+
+import logging
+
+import torch
+import torch.nn as nn
+
+from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort
+
+"""
+A base classifier type
+
+Currently, has the ability to process text or other inputs in a manner
+suitable for the particular model type.
+In other words, the CNNClassifier processes lists of words,
+and the ConstituencyClassifier processes trees
+"""
+
+logger = logging.getLogger('stanza')
+
+class BaseClassifier(ABC, nn.Module):
+ @abstractmethod
+ def extract_sentences(self, doc):
+ """
+ Extract the sentences or the relevant information in the sentences from a document
+ """
+
+ def preprocess_sentences(self, sentences):
+ """
+ By default, don't do anything
+ """
+ return sentences
+
+ def label_sentences(self, sentences, batch_size=None):
+ """
+ Given a list of sentences, return the model's results on that text.
+ """
+ self.eval()
+
+ sentences = self.preprocess_sentences(sentences)
+
+ if batch_size is None:
+ intervals = [(0, len(sentences))]
+ orig_idx = None
+ else:
+ sentences, orig_idx = sort_with_indices(sentences, key=len, reverse=True)
+ intervals = split_into_batches(sentences, batch_size)
+ labels = []
+ for interval in intervals:
+ if interval[1] - interval[0] == 0:
+ # this can happen for empty text
+ continue
+ output = self(sentences[interval[0]:interval[1]])
+ predicted = torch.argmax(output, dim=1)
+ labels.extend(predicted.tolist())
+
+ if orig_idx:
+ sentences = unsort(sentences, orig_idx)
+ labels = unsort(labels, orig_idx)
+
+ logger.debug("Found labels")
+ for (label, sentence) in zip(labels, sentences):
+ logger.debug((label, sentence))
+
+ return labels
diff --git a/stanza/stanza/models/classifiers/cnn_classifier.py b/stanza/stanza/models/classifiers/cnn_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..e48e1f5955e2e994c259fe1de23dd8f18a727a59
--- /dev/null
+++ b/stanza/stanza/models/classifiers/cnn_classifier.py
@@ -0,0 +1,547 @@
+import dataclasses
+import logging
+import math
+import os
+import random
+import re
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import stanza.models.classifiers.data as data
+from stanza.models.classifiers.base_classifier import BaseClassifier
+from stanza.models.classifiers.config import CNNConfig
+from stanza.models.classifiers.data import SentimentDatum
+from stanza.models.classifiers.utils import ExtraVectors, ModelType, build_output_layers
+from stanza.models.common.bert_embedding import extract_bert_embeddings
+from stanza.models.common.data import get_long_tensor, sort_all
+from stanza.models.common.utils import attach_bert_model
+from stanza.models.common.vocab import PAD_ID, UNK_ID
+
+"""
+The CNN classifier is based on Yoon Kim's work:
+
+https://arxiv.org/abs/1408.5882
+
+Also included are maxpool 2d, conv 2d, and a bilstm, as in
+
+Text Classification Improved by Integrating Bidirectional LSTM
+with Two-dimensional Max Pooling
+https://aclanthology.org/C16-1329.pdf
+
+The architecture is simple:
+
+- Embedding at the bottom layer
+ - separate learnable entry for UNK, since many of the embeddings we have use 0 for UNK
+- maybe a bilstm layer, as per a command line flag
+- Some number of conv2d layers over the embedding
+- Maxpool layers over small windows, window size being a parameter
+- FC layer to the classification layer
+
+One experiment which was run and found to be a bit of a negative was
+putting a layer on top of the pretrain. You would think that might
+help, but dev performance went down for each variation of
+ - trans(emb)
+ - relu(trans(emb))
+ - dropout(trans(emb))
+ - dropout(relu(trans(emb)))
+"""
+
+logger = logging.getLogger('stanza')
+tlogger = logging.getLogger('stanza.classifiers.trainer')
+
+class CNNClassifier(BaseClassifier):
+ def __init__(self, pretrain, extra_vocab, labels,
+ charmodel_forward, charmodel_backward, elmo_model, bert_model, bert_tokenizer, force_bert_saved, peft_name,
+ args):
+ """
+ pretrain is a pretrained word embedding. should have .emb and .vocab
+
+ extra_vocab is a collection of words in the training data to
+ be used for the delta word embedding, if used. can be set to
+ None if delta word embedding is not used.
+
+ labels is the list of labels we expect in the training data.
+ Used to derive the number of classes. Saving it in the model
+ will let us check that test data has the same labels
+
+ args is either the complete arguments when training, or the
+ subset of arguments stored in the model save file
+ """
+ super(CNNClassifier, self).__init__()
+ self.labels = labels
+ bert_finetune = args.bert_finetune
+ use_peft = args.use_peft
+ force_bert_saved = force_bert_saved or bert_finetune
+ logger.debug("bert_finetune %s / force_bert_saved %s", bert_finetune, force_bert_saved)
+
+ # this may change when loaded in a new Pipeline, so it's not part of the config
+ self.peft_name = peft_name
+
+ # we build a separate config out of the args so that we can easily save it in torch
+ self.config = CNNConfig(filter_channels = args.filter_channels,
+ filter_sizes = args.filter_sizes,
+ fc_shapes = args.fc_shapes,
+ dropout = args.dropout,
+ num_classes = len(labels),
+ wordvec_type = args.wordvec_type,
+ extra_wordvec_method = args.extra_wordvec_method,
+ extra_wordvec_dim = args.extra_wordvec_dim,
+ extra_wordvec_max_norm = args.extra_wordvec_max_norm,
+ char_lowercase = args.char_lowercase,
+ charlm_projection = args.charlm_projection,
+ has_charlm_forward = charmodel_forward is not None,
+ has_charlm_backward = charmodel_backward is not None,
+ use_elmo = args.use_elmo,
+ elmo_projection = args.elmo_projection,
+ bert_model = args.bert_model,
+ bert_finetune = bert_finetune,
+ bert_hidden_layers = args.bert_hidden_layers,
+ force_bert_saved = force_bert_saved,
+
+ use_peft = use_peft,
+ lora_rank = args.lora_rank,
+ lora_alpha = args.lora_alpha,
+ lora_dropout = args.lora_dropout,
+ lora_modules_to_save = args.lora_modules_to_save,
+ lora_target_modules = args.lora_target_modules,
+
+ bilstm = args.bilstm,
+ bilstm_hidden_dim = args.bilstm_hidden_dim,
+ maxpool_width = args.maxpool_width,
+ model_type = ModelType.CNN)
+
+ self.char_lowercase = args.char_lowercase
+
+ self.unsaved_modules = []
+
+ emb_matrix = pretrain.emb
+ self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
+ self.add_unsaved_module('elmo_model', elmo_model)
+ self.vocab_size = emb_matrix.shape[0]
+ self.embedding_dim = emb_matrix.shape[1]
+
+ self.add_unsaved_module('forward_charlm', charmodel_forward)
+ if charmodel_forward is not None:
+ tlogger.debug("Got forward char model of dimension {}".format(charmodel_forward.hidden_dim()))
+ if not charmodel_forward.is_forward_lm:
+ raise ValueError("Got a backward charlm as a forward charlm!")
+ self.add_unsaved_module('backward_charlm', charmodel_backward)
+ if charmodel_backward is not None:
+ tlogger.debug("Got backward char model of dimension {}".format(charmodel_backward.hidden_dim()))
+ if charmodel_backward.is_forward_lm:
+ raise ValueError("Got a forward charlm as a backward charlm!")
+
+ attach_bert_model(self, bert_model, bert_tokenizer, self.config.use_peft, force_bert_saved)
+
+ # The Pretrain has PAD and UNK already (indices 0 and 1), but we
+ # possibly want to train UNK while freezing the rest of the embedding
+ # note that the /10.0 operation has to be inside nn.Parameter unless
+ # you want to spend a long time debugging this
+ self.unk = nn.Parameter(torch.randn(self.embedding_dim) / np.sqrt(self.embedding_dim) / 10.0)
+
+ # replacing NBSP picks up a whole bunch of words for VI
+ self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pretrain.vocab) }
+
+ if self.config.extra_wordvec_method is not ExtraVectors.NONE:
+ if not extra_vocab:
+ raise ValueError("Should have had extra_vocab set for extra_wordvec_method {}".format(self.config.extra_wordvec_method))
+ if not args.extra_wordvec_dim:
+ self.config.extra_wordvec_dim = self.embedding_dim
+ if self.config.extra_wordvec_method is ExtraVectors.SUM:
+ if self.config.extra_wordvec_dim != self.embedding_dim:
+ raise ValueError("extra_wordvec_dim must equal embedding_dim for {}".format(self.config.extra_wordvec_method))
+
+ self.extra_vocab = list(extra_vocab)
+ self.extra_vocab_map = { word: i for i, word in enumerate(self.extra_vocab) }
+ # TODO: possibly add regularization specifically on the extra embedding?
+ # note: it looks like a bug that this doesn't add UNK or PAD, but actually
+ # those are expected to already be the first two entries
+ self.extra_embedding = nn.Embedding(num_embeddings = len(extra_vocab),
+ embedding_dim = self.config.extra_wordvec_dim,
+ max_norm = self.config.extra_wordvec_max_norm,
+ padding_idx = 0)
+ tlogger.debug("Extra embedding size: {}".format(self.extra_embedding.weight.shape))
+ else:
+ self.extra_vocab = None
+ self.extra_vocab_map = None
+ self.config.extra_wordvec_dim = 0
+ self.extra_embedding = None
+
+ # Pytorch is "aware" of the existence of the nn.Modules inside
+ # an nn.ModuleList in terms of parameters() etc
+ if self.config.extra_wordvec_method is ExtraVectors.NONE:
+ total_embedding_dim = self.embedding_dim
+ elif self.config.extra_wordvec_method is ExtraVectors.SUM:
+ total_embedding_dim = self.embedding_dim
+ elif self.config.extra_wordvec_method is ExtraVectors.CONCAT:
+ total_embedding_dim = self.embedding_dim + self.config.extra_wordvec_dim
+ else:
+ raise ValueError("unable to handle {}".format(self.config.extra_wordvec_method))
+
+ if charmodel_forward is not None:
+ if args.charlm_projection:
+ self.charmodel_forward_projection = nn.Linear(charmodel_forward.hidden_dim(), args.charlm_projection)
+ total_embedding_dim += args.charlm_projection
+ else:
+ self.charmodel_forward_projection = None
+ total_embedding_dim += charmodel_forward.hidden_dim()
+
+ if charmodel_backward is not None:
+ if args.charlm_projection:
+ self.charmodel_backward_projection = nn.Linear(charmodel_backward.hidden_dim(), args.charlm_projection)
+ total_embedding_dim += args.charlm_projection
+ else:
+ self.charmodel_backward_projection = None
+ total_embedding_dim += charmodel_backward.hidden_dim()
+
+ if self.config.use_elmo:
+ if elmo_model is None:
+ raise ValueError("Model requires elmo, but elmo_model not passed in")
+ elmo_dim = elmo_model.sents2elmo([["Test"]])[0].shape[1]
+
+ # this mapping will combine 3 layers of elmo to 1 layer of features
+ self.elmo_combine_layers = nn.Linear(in_features=3, out_features=1, bias=False)
+ if self.config.elmo_projection:
+ self.elmo_projection = nn.Linear(in_features=elmo_dim, out_features=self.config.elmo_projection)
+ total_embedding_dim = total_embedding_dim + self.config.elmo_projection
+ else:
+ total_embedding_dim = total_embedding_dim + elmo_dim
+
+ if bert_model is not None:
+ if self.config.bert_hidden_layers:
+ # The average will be offset by 1/N so that the default zeros
+            # represents an average of the N layers
+ if self.config.bert_hidden_layers > bert_model.config.num_hidden_layers:
+ # limit ourselves to the number of layers actually available
+ # note that we can +1 because of the initial embedding layer
+ self.config.bert_hidden_layers = bert_model.config.num_hidden_layers + 1
+ self.bert_layer_mix = nn.Linear(self.config.bert_hidden_layers, 1, bias=False)
+ nn.init.zeros_(self.bert_layer_mix.weight)
+ else:
+ # an average of layers 2, 3, 4 will be used
+ # (for historic reasons)
+ self.bert_layer_mix = None
+
+ if bert_tokenizer is None:
+ raise ValueError("Cannot have a bert model without a tokenizer")
+ self.bert_dim = self.bert_model.config.hidden_size
+ total_embedding_dim += self.bert_dim
+
+ if self.config.bilstm:
+ conv_input_dim = self.config.bilstm_hidden_dim * 2
+ self.bilstm = nn.LSTM(batch_first=True,
+ input_size=total_embedding_dim,
+ hidden_size=self.config.bilstm_hidden_dim,
+ num_layers=2,
+ bidirectional=True,
+ dropout=0.2)
+ else:
+ conv_input_dim = total_embedding_dim
+ self.bilstm = None
+
+ self.fc_input_size = 0
+ self.conv_layers = nn.ModuleList()
+ self.max_window = 0
+ for filter_idx, filter_size in enumerate(self.config.filter_sizes):
+ if isinstance(filter_size, int):
+ self.max_window = max(self.max_window, filter_size)
+ if isinstance(self.config.filter_channels, int):
+ filter_channels = self.config.filter_channels
+ else:
+ filter_channels = self.config.filter_channels[filter_idx]
+ fc_delta = filter_channels // self.config.maxpool_width
+ tlogger.debug("Adding full width filter %d. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
+ self.fc_input_size += fc_delta
+ self.conv_layers.append(nn.Conv2d(in_channels=1,
+ out_channels=filter_channels,
+ kernel_size=(filter_size, conv_input_dim)))
+ elif isinstance(filter_size, tuple) and len(filter_size) == 2:
+ filter_height, filter_width = filter_size
+ self.max_window = max(self.max_window, filter_width)
+ if isinstance(self.config.filter_channels, int):
+ filter_channels = max(1, self.config.filter_channels // (conv_input_dim // filter_width))
+ else:
+ filter_channels = self.config.filter_channels[filter_idx]
+ fc_delta = filter_channels * (conv_input_dim // filter_width) // self.config.maxpool_width
+ tlogger.debug("Adding filter %s. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
+ self.fc_input_size += fc_delta
+ self.conv_layers.append(nn.Conv2d(in_channels=1,
+ out_channels=filter_channels,
+ stride=(1, filter_width),
+ kernel_size=(filter_height, filter_width)))
+ else:
+ raise ValueError("Expected int or 2d tuple for conv size")
+
+ tlogger.debug("Input dim to FC layers: %d", self.fc_input_size)
+ self.fc_layers = build_output_layers(self.fc_input_size, self.config.fc_shapes, self.config.num_classes)
+
+ self.dropout = nn.Dropout(self.config.dropout)
+
+ def add_unsaved_module(self, name, module):
+ self.unsaved_modules += [name]
+ setattr(self, name, module)
+
+ if module is not None and (name in ('forward_charlm', 'backward_charlm') or
+ (name == 'bert_model' and not self.config.use_peft)):
+ # if we are using peft, we should not save the transformer directly
+ # instead, the peft parameters only will be saved later
+ for _, parameter in module.named_parameters():
+ parameter.requires_grad = False
+
+ def is_unsaved_module(self, name):
+ return name.split('.')[0] in self.unsaved_modules
+
+ def log_configuration(self):
+ """
+ Log some essential information about the model configuration to the training logger
+ """
+ tlogger.info("Filter sizes: %s" % str(self.config.filter_sizes))
+ tlogger.info("Filter channels: %s" % str(self.config.filter_channels))
+ tlogger.info("Intermediate layers: %s" % str(self.config.fc_shapes))
+
+ def log_norms(self):
+        lines = ["NORMS FOR MODEL PARAMETERS"]
+ for name, param in self.named_parameters():
+ if param.requires_grad and name.split(".")[0] not in ('forward_charlm', 'backward_charlm'):
+ lines.append("%s %.6g" % (name, torch.norm(param).item()))
+ logger.info("\n".join(lines))
+
+ def build_char_reps(self, inputs, max_phrase_len, charlm, projection, begin_paddings, device):
+ char_reps = charlm.build_char_representation(inputs)
+ if projection is not None:
+ char_reps = [projection(x) for x in char_reps]
+ char_inputs = torch.zeros((len(inputs), max_phrase_len, char_reps[0].shape[-1]), device=device)
+ for idx, rep in enumerate(char_reps):
+ start = begin_paddings[idx]
+ end = start + rep.shape[0]
+ char_inputs[idx, start:end, :] = rep
+ return char_inputs
+
+ def extract_bert_embeddings(self, inputs, max_phrase_len, begin_paddings, device):
+ bert_embeddings = extract_bert_embeddings(self.config.bert_model, self.bert_tokenizer, self.bert_model, inputs, device,
+ keep_endpoints=False,
+ num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
+ detach=not self.config.bert_finetune,
+ peft_name=self.peft_name)
+ if self.bert_layer_mix is not None:
+ # add the average so that the default behavior is to
+ # take an average of the N layers, and anything else
+ # other than that needs to be learned
+ bert_embeddings = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in bert_embeddings]
+ bert_inputs = torch.zeros((len(inputs), max_phrase_len, bert_embeddings[0].shape[-1]), device=device)
+ for idx, rep in enumerate(bert_embeddings):
+ start = begin_paddings[idx]
+ end = start + rep.shape[0]
+ bert_inputs[idx, start:end, :] = rep
+ return bert_inputs
+
+ def forward(self, inputs):
+ # assume all pieces are on the same device
+ device = next(self.parameters()).device
+
+ vocab_map = self.vocab_map
+ def map_word(word):
+ idx = vocab_map.get(word, None)
+ if idx is not None:
+ return idx
+ if word[-1] == "'":
+ idx = vocab_map.get(word[:-1], None)
+ if idx is not None:
+ return idx
+ return vocab_map.get(word.lower(), UNK_ID)
+
+ inputs = [x.text if isinstance(x, SentimentDatum) else x for x in inputs]
+ # we will pad each phrase so either it matches the longest
+ # conv or the longest phrase in the input, whichever is longer
+ max_phrase_len = max(len(x) for x in inputs)
+ if self.max_window > max_phrase_len:
+ max_phrase_len = self.max_window
+
+ batch_indices = []
+ batch_unknowns = []
+ extra_batch_indices = []
+ begin_paddings = []
+ end_paddings = []
+
+ elmo_batch_words = []
+
+ for phrase in inputs:
+ # we use random at training time to try to learn different
+ # positions of padding. at test time, though, we want to
+            # have consistent results, so we use a begin padding of 0
+ if self.training:
+ begin_pad_width = random.randint(0, max_phrase_len - len(phrase))
+ else:
+ begin_pad_width = 0
+ end_pad_width = max_phrase_len - begin_pad_width - len(phrase)
+
+ begin_paddings.append(begin_pad_width)
+ end_paddings.append(end_pad_width)
+
+ # the initial lists are the length of the begin padding
+ sentence_indices = [PAD_ID] * begin_pad_width
+ sentence_indices.extend([map_word(x) for x in phrase])
+ sentence_indices.extend([PAD_ID] * end_pad_width)
+
+ # the "unknowns" will be the locations of the unknown words.
+ # these locations will get the specially trained unknown vector
+ # TODO: split UNK based on part of speech? might be an interesting experiment
+ sentence_unknowns = [idx for idx, word in enumerate(sentence_indices) if word == UNK_ID]
+
+ batch_indices.append(sentence_indices)
+ batch_unknowns.append(sentence_unknowns)
+
+ if self.extra_vocab:
+ extra_sentence_indices = [PAD_ID] * begin_pad_width
+ for word in phrase:
+ if word in self.extra_vocab_map:
+ # the extra vocab is initialized from the
+ # words in the training set, which means there
+ # would be no unknown words. to occasionally
+ # train the extra vocab's unknown words, we
+ # replace 1% of the words with UNK
+ # we don't do that for the original embedding
+ # on the assumption that there may be some
+ # unknown words in the training set anyway
+ # TODO: maybe train unk for the original embedding?
+ if self.training and random.random() < 0.01:
+ extra_sentence_indices.append(UNK_ID)
+ else:
+ extra_sentence_indices.append(self.extra_vocab_map[word])
+ else:
+ extra_sentence_indices.append(UNK_ID)
+ extra_sentence_indices.extend([PAD_ID] * end_pad_width)
+ extra_batch_indices.append(extra_sentence_indices)
+
+ if self.config.use_elmo:
+ elmo_phrase_words = [""] * begin_pad_width
+ for word in phrase:
+ elmo_phrase_words.append(word)
+ elmo_phrase_words.extend([""] * end_pad_width)
+ elmo_batch_words.append(elmo_phrase_words)
+
+ # creating a single large list with all the indices lets us
+ # create a single tensor, which is much faster than creating
+ # many tiny tensors
+ # we can convert this to the input to the CNN
+ # it is padded at one or both ends so that it is now num_phrases x max_len x emb_size
+ # there are two ways in which this padding is suboptimal
+ # the first is that for short sentences, smaller windows will
+ # be padded to the point that some windows are entirely pad
+ # the second is that a sentence S will have more or less padding
+ # depending on what other sentences are in its batch
+ # we assume these effects are pretty minimal
+ batch_indices = torch.tensor(batch_indices, requires_grad=False, device=device)
+ input_vectors = self.embedding(batch_indices)
+ # we use the random unk so that we are not necessarily
+ # learning to match 0s for unk
+ for phrase_num, sentence_unknowns in enumerate(batch_unknowns):
+ input_vectors[phrase_num][sentence_unknowns] = self.unk
+
+ if self.extra_vocab:
+ extra_batch_indices = torch.tensor(extra_batch_indices, requires_grad=False, device=device)
+ extra_input_vectors = self.extra_embedding(extra_batch_indices)
+ if self.config.extra_wordvec_method is ExtraVectors.CONCAT:
+ all_inputs = [input_vectors, extra_input_vectors]
+ elif self.config.extra_wordvec_method is ExtraVectors.SUM:
+ all_inputs = [input_vectors + extra_input_vectors]
+ else:
+ raise ValueError("unable to handle {}".format(self.config.extra_wordvec_method))
+ else:
+ all_inputs = [input_vectors]
+
+ if self.forward_charlm is not None:
+ char_reps_forward = self.build_char_reps(inputs, max_phrase_len, self.forward_charlm, self.charmodel_forward_projection, begin_paddings, device)
+ all_inputs.append(char_reps_forward)
+
+ if self.backward_charlm is not None:
+ char_reps_backward = self.build_char_reps(inputs, max_phrase_len, self.backward_charlm, self.charmodel_backward_projection, begin_paddings, device)
+ all_inputs.append(char_reps_backward)
+
+ if self.config.use_elmo:
+ # this will be N arrays of 3xMx1024 where M is the number of words
+ # and N is the number of sentences (and 1024 is actually the number of weights)
+ elmo_arrays = self.elmo_model.sents2elmo(elmo_batch_words, output_layer=-2)
+ elmo_tensors = [torch.tensor(x).to(device=device) for x in elmo_arrays]
+ # elmo_tensor will now be Nx3xMx1024
+ elmo_tensor = torch.stack(elmo_tensors)
+ # Nx1024xMx3
+ elmo_tensor = torch.transpose(elmo_tensor, 1, 3)
+ # NxMx1024x3
+ elmo_tensor = torch.transpose(elmo_tensor, 1, 2)
+ # NxMx1024x1
+ elmo_tensor = self.elmo_combine_layers(elmo_tensor)
+ # NxMx1024
+ elmo_tensor = elmo_tensor.squeeze(3)
+ if self.config.elmo_projection:
+ elmo_tensor = self.elmo_projection(elmo_tensor)
+ all_inputs.append(elmo_tensor)
+
+ if self.bert_model is not None:
+ bert_embeddings = self.extract_bert_embeddings(inputs, max_phrase_len, begin_paddings, device)
+ all_inputs.append(bert_embeddings)
+
+ # still works even if there's just one item
+ input_vectors = torch.cat(all_inputs, dim=2)
+
+ if self.config.bilstm:
+ input_vectors, _ = self.bilstm(self.dropout(input_vectors))
+
+ # reshape to fit the input tensors
+ x = input_vectors.unsqueeze(1)
+
+ conv_outs = []
+ for conv, filter_size in zip(self.conv_layers, self.config.filter_sizes):
+ if isinstance(filter_size, int):
+ conv_out = self.dropout(F.relu(conv(x).squeeze(3)))
+ conv_outs.append(conv_out)
+ else:
+ conv_out = conv(x).transpose(2, 3).flatten(1, 2)
+ conv_out = self.dropout(F.relu(conv_out))
+ conv_outs.append(conv_out)
+ pool_outs = [F.max_pool2d(out, (self.config.maxpool_width, out.shape[2])).squeeze(2) for out in conv_outs]
+ pooled = torch.cat(pool_outs, dim=1)
+
+ previous_layer = pooled
+ for fc in self.fc_layers[:-1]:
+ previous_layer = self.dropout(F.relu(fc(previous_layer)))
+ out = self.fc_layers[-1](previous_layer)
+ # note that we return the raw logits rather than use a softmax
+ # https://discuss.pytorch.org/t/multi-class-cross-entropy-loss-and-softmax-in-pytorch/24920/4
+ return out
+
+ def get_params(self, skip_modules=True):
+ model_state = self.state_dict()
+ # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
+ if skip_modules:
+ skipped = [k for k in model_state.keys() if self.is_unsaved_module(k)]
+ for k in skipped:
+ del model_state[k]
+
+ config = dataclasses.asdict(self.config)
+ config['wordvec_type'] = config['wordvec_type'].name
+ config['extra_wordvec_method'] = config['extra_wordvec_method'].name
+ config['model_type'] = config['model_type'].name
+
+ params = {
+ 'model': model_state,
+ 'config': config,
+ 'labels': self.labels,
+ 'extra_vocab': self.extra_vocab,
+ }
+ if self.config.use_peft:
+ # Hide import so that peft dependency is optional
+ from peft import get_peft_model_state_dict
+ params["bert_lora"] = get_peft_model_state_dict(self.bert_model, adapter_name=self.peft_name)
+ return params
+
+    def preprocess_sentences(self, sentences):
+ sentences = [data.update_text(s, self.config.wordvec_type) for s in sentences]
+ return sentences
+
+ def extract_sentences(self, doc):
+ # TODO: tokens or words better here?
+ return [[token.text for token in sentence.tokens] for sentence in doc.sentences]
diff --git a/stanza/stanza/models/classifiers/iterate_test.py b/stanza/stanza/models/classifiers/iterate_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..01ee75c012232b08a1dc5258baf2f578eacdafb7
--- /dev/null
+++ b/stanza/stanza/models/classifiers/iterate_test.py
@@ -0,0 +1,64 @@
+"""Iterate test."""
+import argparse
+import glob
+import logging
+
+import stanza.models.classifier as classifier
+import stanza.models.classifiers.cnn_classifier as cnn_classifier
+from stanza.models.common import utils
+import stanza.models.classifiers.data as data
+from stanza.utils.confusion import format_confusion, confusion_to_accuracy
+
+"""
+A script for running the same test file on several different classifiers.
+
+For each one, it will output the accuracy and, if possible, the confusion matrix.
+
+Includes the arguments for pretrain, which allows for passing in a
+different directory for the pretrain file.
+
+Example command line:
+ python3 -m stanza.models.classifiers.iterate_test --test_file extern_data/sentiment/sst-processed/threeclass/test-threeclass-roots.txt --glob "saved_models/classifier/FC41_3class_en_ewt_FS*ACC66*"
+"""
+
+logger = logging.getLogger('stanza')
+
+
+def parse_args():
+ """Add and parse arguments."""
+ parser = classifier.build_parser()
+
+ parser.add_argument('--glob', type=str, default='saved_models/classifier/*classifier*pt', help='Model file(s) to test.')
+
+ args = parser.parse_args()
+ return args
+
+args = parse_args()
+seed = utils.set_random_seed(args.seed)
+
+model_files = []
+for glob_piece in args.glob.split():
+ model_files.extend(glob.glob(glob_piece))
+model_files = sorted(set(model_files))
+
+test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)
+logger.info("Using test set: %s" % args.test_file)
+
+device = None
+for load_name in model_files:
+ args.load_name = load_name
+ model = classifier.load_model(args)
+
+ logger.info("Testing %s" % load_name)
+    # model was already loaded above via classifier.load_model(args)
+ if device is None:
+ device = next(model.parameters()).device
+ logger.info("Current device: %s" % device)
+
+ labels = model.labels
+ classifier.check_labels(labels, test_set)
+
+ confusion = classifier.confusion_dataset(model, test_set, device=device)
+ correct, total = confusion_to_accuracy(confusion)
+ logger.info(" Results: %d correct of %d examples. Accuracy: %f" % (correct, total, correct / total))
+ logger.info("Confusion matrix:\n{}".format(format_confusion(confusion, model.labels)))
diff --git a/stanza/stanza/models/classifiers/trainer.py b/stanza/stanza/models/classifiers/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..296e47bb65f7af7fcbb70ea14ba2155a2af15277
--- /dev/null
+++ b/stanza/stanza/models/classifiers/trainer.py
@@ -0,0 +1,304 @@
+"""
+Organizes the model itself and its optimizer in one place
+
+Saving the optimizer allows for easy restarting of training
+"""
+
+import logging
+import os
+import warnings
+from pickle import UnpicklingError
+from types import SimpleNamespace
+
+import torch
+import torch.optim as optim
+
+import stanza.models.classifiers.data as data
+import stanza.models.classifiers.cnn_classifier as cnn_classifier
+import stanza.models.classifiers.constituency_classifier as constituency_classifier
+from stanza.models.classifiers.config import CNNConfig, ConstituencyConfig
+from stanza.models.classifiers.utils import ModelType, WVType, ExtraVectors
+from stanza.models.common import utils
+from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain
+from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
+from stanza.models.common.pretrain import Pretrain
+from stanza.models.common.utils import get_split_optimizer
+from stanza.models.constituency.tree_embedding import TreeEmbedding
+
+logger = logging.getLogger('stanza')
+
+class Trainer:
+ """
+ Stores a constituency model and its optimizer
+ """
+
+ def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0, best_score=None):
+ self.model = model
+ self.optimizer = optimizer
+ # we keep track of position in the learning so that we can
+ # checkpoint & restart if needed without restarting the epoch count
+ self.epochs_trained = epochs_trained
+ self.global_step = global_step
+ # save the best dev score so that when reloading a checkpoint
+ # of a model, we know how far we got
+ self.best_score = best_score
+
+ def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):
+ """
+ save the current model, optimizer, and other state to filename
+
+ epochs_trained can be passed as a parameter to handle saving at the end of an epoch
+ """
+ if epochs_trained is None:
+ epochs_trained = self.epochs_trained
+ save_dir = os.path.split(filename)[0]
+ os.makedirs(save_dir, exist_ok=True)
+ model_params = self.model.get_params(skip_modules)
+ params = {
+ 'params': model_params,
+ 'epochs_trained': epochs_trained,
+ 'global_step': self.global_step,
+ 'best_score': self.best_score,
+ }
+ if save_optimizer and self.optimizer is not None:
+ params['optimizer_state_dict'] = {opt_name: opt.state_dict() for opt_name, opt in self.optimizer.items()}
+ torch.save(params, filename, _use_new_zipfile_serialization=False)
+ logger.info("Model saved to {}".format(filename))
+
+ @staticmethod
+ def load(filename, args, foundation_cache=None, load_optimizer=False):
+ if not os.path.exists(filename):
+ if args.save_dir is None:
+ raise FileNotFoundError("Cannot find model in {} and args.save_dir is None".format(filename))
+ elif os.path.exists(os.path.join(args.save_dir, filename)):
+ filename = os.path.join(args.save_dir, filename)
+ else:
+ raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
+ try:
+ # TODO: can remove the try/except once the new version is out
+ #checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
+ try:
+ checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
+ except UnpicklingError as e:
+ checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
+ warnings.warn("The saved classifier has an old format using SimpleNamespace and/or Enum instead of a dict to store config. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the pretrained classifier using this version ASAP.")
+ except BaseException:
+ logger.exception("Cannot load model from {}".format(filename))
+ raise
+ logger.debug("Loaded model {}".format(filename))
+
+ epochs_trained = checkpoint.get('epochs_trained', 0)
+ global_step = checkpoint.get('global_step', 0)
+ best_score = checkpoint.get('best_score', None)
+
+ # TODO: can remove this block once all models are retrained
+ if 'params' not in checkpoint:
+ model_params = {
+ 'model': checkpoint['model'],
+ 'config': checkpoint['config'],
+ 'labels': checkpoint['labels'],
+ 'extra_vocab': checkpoint['extra_vocab'],
+ }
+ else:
+ model_params = checkpoint['params']
+ # TODO: this can be removed once v1.10.0 is out
+ if isinstance(model_params['config'], SimpleNamespace):
+ model_params['config'] = vars(model_params['config'])
+ # TODO: these isinstance can go away after 1.10.0
+ model_type = model_params['config']['model_type']
+ if isinstance(model_type, str):
+ model_type = ModelType[model_type]
+ model_params['config']['model_type'] = model_type
+
+ if model_type == ModelType.CNN:
+ # TODO: these updates are only necessary during the
+ # transition to the @dataclass version of the config
+ # Once those are all saved, it is no longer necessary
+ # to patch existing models (since they will all be patched)
+ if 'has_charlm_forward' not in model_params['config']:
+ model_params['config']['has_charlm_forward'] = args.charlm_forward_file is not None
+ if 'has_charlm_backward' not in model_params['config']:
+ model_params['config']['has_charlm_backward'] = args.charlm_backward_file is not None
+ for argname in ['bert_hidden_layers', 'bert_finetune', 'force_bert_saved', 'use_peft',
+ 'lora_rank', 'lora_alpha', 'lora_dropout', 'lora_modules_to_save', 'lora_target_modules']:
+ model_params['config'][argname] = model_params['config'].get(argname, None)
+ # TODO: these isinstance can go away after 1.10.0
+ if isinstance(model_params['config']['wordvec_type'], str):
+ model_params['config']['wordvec_type'] = WVType[model_params['config']['wordvec_type']]
+ if isinstance(model_params['config']['extra_wordvec_method'], str):
+ model_params['config']['extra_wordvec_method'] = ExtraVectors[model_params['config']['extra_wordvec_method']]
+ model_params['config'] = CNNConfig(**model_params['config'])
+
+ pretrain = Trainer.load_pretrain(args, foundation_cache)
+ elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
+
+ if model_params['config'].has_charlm_forward:
+ charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)
+ else:
+ charmodel_forward = None
+ if model_params['config'].has_charlm_backward:
+ charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)
+ else:
+ charmodel_backward = None
+
+ bert_model = model_params['config'].bert_model
+ # TODO: can get rid of the getattr after rebuilding all models
+ use_peft = getattr(model_params['config'], 'use_peft', False)
+ force_bert_saved = getattr(model_params['config'], 'force_bert_saved', False)
+ peft_name = None
+ if use_peft:
+ # if loading a peft model, we first load the base transformer
+ # the CNNClassifier code wraps the transformer in peft
+ # after creating the CNNClassifier with the peft wrapper,
+ # we *then* load the weights
+ bert_model, bert_tokenizer, peft_name = load_bert_with_peft(bert_model, "classifier", foundation_cache)
+ bert_model = load_peft_wrapper(bert_model, model_params['bert_lora'], vars(model_params['config']), logger, peft_name)
+ elif force_bert_saved:
+ bert_model, bert_tokenizer = load_bert(bert_model)
+ else:
+ bert_model, bert_tokenizer = load_bert(bert_model, foundation_cache)
+ model = cnn_classifier.CNNClassifier(pretrain=pretrain,
+ extra_vocab=model_params['extra_vocab'],
+ labels=model_params['labels'],
+ charmodel_forward=charmodel_forward,
+ charmodel_backward=charmodel_backward,
+ elmo_model=elmo_model,
+ bert_model=bert_model,
+ bert_tokenizer=bert_tokenizer,
+ force_bert_saved=force_bert_saved,
+ peft_name=peft_name,
+ args=model_params['config'])
+ elif model_type == ModelType.CONSTITUENCY:
+ # the constituency version doesn't have a peft feature yet
+ use_peft = False
+ pretrain_args = {
+ 'wordvec_pretrain_file': args.wordvec_pretrain_file,
+ 'charlm_forward_file': args.charlm_forward_file,
+ 'charlm_backward_file': args.charlm_backward_file,
+ }
+ # TODO: integrate with peft for the constituency version
+ tree_embedding = TreeEmbedding.model_from_params(model_params['tree_embedding'], pretrain_args, foundation_cache)
+ model_params['config'] = ConstituencyConfig(**model_params['config'])
+ model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
+ labels=model_params['labels'],
+ args=model_params['config'])
+ else:
+ raise ValueError("Unknown model type {}".format(model_type))
+ model.load_state_dict(model_params['model'], strict=False)
+ model = model.to(args.device)
+
+ logger.debug("-- MODEL CONFIG --")
+ for k in model.config.__dict__:
+ logger.debug(" --{}: {}".format(k, model.config.__dict__[k]))
+
+ logger.debug("-- MODEL LABELS --")
+ logger.debug(" {}".format(" ".join(model.labels)))
+
+ optimizer = None
+ if load_optimizer:
+ optimizer = Trainer.build_optimizer(model, args)
+ if checkpoint.get('optimizer_state_dict', None) is not None:
+ for opt_name, opt_state_dict in checkpoint['optimizer_state_dict'].items():
+ optimizer[opt_name].load_state_dict(opt_state_dict)
+ else:
+ logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
+
+ trainer = Trainer(model, optimizer, epochs_trained, global_step, best_score)
+
+ return trainer
+
+
+ def load_pretrain(args, foundation_cache):
+ if args.wordvec_pretrain_file:
+ pretrain_file = args.wordvec_pretrain_file
+ elif args.wordvec_type:
+ pretrain_file = '{}/{}.{}.pretrain.pt'.format(args.save_dir, args.shorthand, args.wordvec_type.name.lower())
+ else:
+ raise RuntimeError("TODO: need to get the wv type back from get_wordvec_file")
+
+ logger.debug("Looking for pretrained vectors in {}".format(pretrain_file))
+ if os.path.exists(pretrain_file):
+ return load_pretrain(pretrain_file, foundation_cache)
+ elif args.wordvec_raw_file:
+ vec_file = args.wordvec_raw_file
+ logger.debug("Pretrain not found. Looking in {}".format(vec_file))
+ else:
+ vec_file = utils.get_wordvec_file(args.wordvec_dir, args.shorthand, args.wordvec_type.name.lower())
+ logger.debug("Pretrain not found. Looking in {}".format(vec_file))
+ pretrain = Pretrain(pretrain_file, vec_file, args.pretrain_max_vocab)
+ logger.debug("Embedding shape: %s" % str(pretrain.emb.shape))
+ return pretrain
+
+
+ @staticmethod
+ def build_new_model(args, train_set):
+ """
+ Load pretrained pieces and then build a new model
+ """
+ if train_set is None:
+ raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")
+
+ labels = data.dataset_labels(train_set)
+
+ if args.model_type == ModelType.CNN:
+ pretrain = Trainer.load_pretrain(args, foundation_cache=None)
+ elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
+ charmodel_forward = load_charlm(args.charlm_forward_file)
+ charmodel_backward = load_charlm(args.charlm_backward_file)
+ peft_name = None
+ bert_model, bert_tokenizer = load_bert(args.bert_model)
+
+ use_peft = getattr(args, "use_peft", False)
+ if use_peft:
+ peft_name = "sentiment"
+ bert_model = build_peft_wrapper(bert_model, vars(args), logger, adapter_name=peft_name)
+
+ extra_vocab = data.dataset_vocab(train_set)
+ force_bert_saved = args.bert_finetune
+ model = cnn_classifier.CNNClassifier(pretrain=pretrain,
+ extra_vocab=extra_vocab,
+ labels=labels,
+ charmodel_forward=charmodel_forward,
+ charmodel_backward=charmodel_backward,
+ elmo_model=elmo_model,
+ bert_model=bert_model,
+ bert_tokenizer=bert_tokenizer,
+ force_bert_saved=force_bert_saved,
+ peft_name=peft_name,
+ args=args)
+ model = model.to(args.device)
+ elif args.model_type == ModelType.CONSTITUENCY:
+ # this passes flags such as "constituency_backprop" from
+ # the classifier to the TreeEmbedding as the "backprop" flag
+ parser_args = { x[len("constituency_"):]: y for x, y in vars(args).items() if x.startswith("constituency_") }
+ parser_args.update({
+ "wordvec_pretrain_file": args.wordvec_pretrain_file,
+ "charlm_forward_file": args.charlm_forward_file,
+ "charlm_backward_file": args.charlm_backward_file,
+ "bert_model": args.bert_model,
+ # we found that finetuning from the classifier output
+ # all the way to the bert layers caused the bert model
+ # to go astray
+ # could make this an option... but it is much less accurate
+ # with the Bert finetuning
+ # noting that the constituency parser itself works better
+ # after finetuning, of course
+ "bert_finetune": False,
+ "stage1_bert_finetune": False,
+ })
+ logger.info("Building constituency classifier using %s as the base model" % args.constituency_model)
+ tree_embedding = TreeEmbedding.from_parser_file(parser_args)
+ model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
+ labels=labels,
+ args=args)
+ model = model.to(args.device)
+ else:
+ raise ValueError("Unhandled model type {}".format(args.model_type))
+
+ optimizer = Trainer.build_optimizer(model, args)
+
+ return Trainer(model, optimizer)
+
+
+ @staticmethod
+ def build_optimizer(model, args):
+ return get_split_optimizer(args.optim.lower(), model, args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, bert_learning_rate=args.bert_learning_rate, bert_weight_decay=args.weight_decay * args.bert_weight_decay, is_peft=args.use_peft)
diff --git a/stanza/stanza/models/constituency/__init__.py b/stanza/stanza/models/constituency/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/stanza/stanza/models/constituency/evaluate_treebanks.py b/stanza/stanza/models/constituency/evaluate_treebanks.py
new file mode 100644
index 0000000000000000000000000000000000000000..11f3084b3413a8f82eef0949f0a8023a1ec187dd
--- /dev/null
+++ b/stanza/stanza/models/constituency/evaluate_treebanks.py
@@ -0,0 +1,36 @@
+"""
+Read multiple treebanks, score the results.
+
+Reports the k-best score if multiple predicted treebanks are given.
+"""
+
+import argparse
+
+from stanza.models.constituency import tree_reader
+from stanza.server.parser_eval import EvaluateParser, ParseResult
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Get scores for one or more treebanks against the gold')
+ parser.add_argument('gold', type=str, help='Which file to load as the gold trees')
+ parser.add_argument('pred', type=str, nargs='+', help='Which file(s) are the predictions. If more than one is given, the evaluation will be "k-best" with the first prediction treated as the canonical')
+ args = parser.parse_args()
+
+ print("Loading gold treebank: " + args.gold)
+ gold = tree_reader.read_treebank(args.gold)
+ print("Loading predicted treebanks: " + args.pred)
+ pred = [tree_reader.read_treebank(x) for x in args.pred]
+
+ full_results = [ParseResult(parses[0], [*parses[1:]])
+ for parses in zip(gold, *pred)]
+
+ if len(pred) <= 1:
+ kbest = None
+ else:
+ kbest = len(pred)
+
+ with EvaluateParser(kbest=kbest) as evaluator:
+ response = evaluator.process(full_results)
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/stanza/models/constituency/label_attention.py b/stanza/stanza/models/constituency/label_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fee6585548da307597f5dd08974a341bb8187df
--- /dev/null
+++ b/stanza/stanza/models/constituency/label_attention.py
@@ -0,0 +1,726 @@
+import numpy as np
+import functools
+import sys
+import torch
+from torch.autograd import Variable
+import torch.nn as nn
+import torch.nn.init as init
+
+# publicly available versions alternate between torch.uint8 and torch.bool,
+# but that is for older versions of torch anyway
+DTYPE = torch.bool
+
+class BatchIndices:
+ """
+ Batch indices container class (used to implement packed batches)
+ """
+ def __init__(self, batch_idxs_np, device):
+ self.batch_idxs_np = batch_idxs_np
+ self.batch_idxs_torch = torch.as_tensor(batch_idxs_np, dtype=torch.long, device=device)
+
+ self.batch_size = int(1 + np.max(batch_idxs_np))
+
+ batch_idxs_np_extra = np.concatenate([[-1], batch_idxs_np, [-1]])
+ self.boundaries_np = np.nonzero(batch_idxs_np_extra[1:] != batch_idxs_np_extra[:-1])[0]
+
+ #print(f"boundaries_np: {self.boundaries_np}")
+ #print(f"boundaries_np[1:]: {self.boundaries_np[1:]}")
+ #print(f"boundaries_np[:-1]: {self.boundaries_np[:-1]}")
+ self.seq_lens_np = self.boundaries_np[1:] - self.boundaries_np[:-1]
+ #print(f"seq_lens_np: {self.seq_lens_np}")
+ #print(f"batch_size: {self.batch_size}")
+ assert len(self.seq_lens_np) == self.batch_size
+ self.max_len = int(np.max(self.boundaries_np[1:] - self.boundaries_np[:-1]))
+
+
+class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
+    """
+    Autograd function implementing feature-level dropout over a packed batch.
+
+    One Bernoulli mask row is drawn per batch element (not per position), then
+    expanded via batch_idxs so that every packed position belonging to the
+    same batch element shares the same dropped features.
+    """
+    @classmethod
+    def forward(cls, ctx, input, batch_idxs, p=0.5, train=False, inplace=False):
+        if p < 0 or p > 1:
+            raise ValueError("dropout probability has to be between 0 and 1, "
+                             "but got {}".format(p))
+
+        ctx.p = p
+        ctx.train = train
+        ctx.inplace = inplace
+
+        if ctx.inplace:
+            ctx.mark_dirty(input)
+            output = input
+        else:
+            output = input.clone()
+
+        if ctx.p > 0 and ctx.train:
+            # one mask row per batch element; kept features are scaled by
+            # 1/(1-p) (inverted dropout) so eval needs no rescaling
+            ctx.noise = input.new().resize_(batch_idxs.batch_size, input.size(1))
+            if ctx.p == 1:
+                ctx.noise.fill_(0)
+            else:
+                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
+            # broadcast the per-element mask rows to every packed position
+            ctx.noise = ctx.noise[batch_idxs.batch_idxs_torch, :]
+            output.mul_(ctx.noise)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # gradient only flows through features that survived the forward mask;
+        # the None entries match the non-tensor forward arguments
+        if ctx.p > 0 and ctx.train:
+            return grad_output.mul(ctx.noise), None, None, None, None
+        else:
+            return grad_output, None, None, None, None
+
+#
+class FeatureDropout(nn.Module):
+    """
+    Feature-level dropout: takes an input of size len x num_features and drops
+    each feature with probability p. A feature is dropped across the full
+    portion of the input that corresponds to a single batch element.
+    """
+    def __init__(self, p=0.5, inplace=False):
+        super().__init__()
+        if p < 0 or p > 1:
+            raise ValueError("dropout probability has to be between 0 and 1, "
+                             "but got {}".format(p))
+        self.p = p
+        self.inplace = inplace
+
+    def forward(self, input, batch_idxs):
+        # self.training comes from nn.Module; dropout is a no-op in eval mode
+        return FeatureDropoutFunction.apply(input, batch_idxs, self.p, self.training, self.inplace)
+
+
+
+class LayerNormalization(nn.Module):
+ def __init__(self, d_hid, eps=1e-3, affine=True):
+ super(LayerNormalization, self).__init__()
+
+ self.eps = eps
+ self.affine = affine
+ if self.affine:
+ self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
+ self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
+
+ def forward(self, z):
+ if z.size(-1) == 1:
+ return z
+
+ mu = torch.mean(z, keepdim=True, dim=-1)
+ sigma = torch.std(z, keepdim=True, dim=-1)
+ ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
+ if self.affine:
+ ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
+
+ return ln_out
+
+
+
+class ScaledDotProductAttention(nn.Module):
+    """
+    Scaled dot-product attention: softmax(q k^T / sqrt(d_model)) v, with
+    optional masking and dropout applied to the attention weights.
+    """
+    def __init__(self, d_model, attention_dropout=0.1):
+        super(ScaledDotProductAttention, self).__init__()
+        # scaling temperature sqrt(d_model)
+        self.temper = d_model ** 0.5
+        self.dropout = nn.Dropout(attention_dropout)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, q, k, v, attn_mask=None):
+        # q: [batch, slot, feat] or (batch * d_l) x max_len x d_k
+        # k: [batch, slot, feat] or (batch * d_l) x max_len x d_k
+        # v: [batch, slot, feat] or (batch * d_l) x max_len x d_v
+        # q in LAL is (batch * d_l) x 1 x d_k
+
+        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper # (batch * d_l) x max_len x max_len
+        # in LAL, gives: (batch * d_l) x 1 x max_len
+        # attention weights from each word to each word, for each label
+        # in best model (repeated q): attention weights from label (as vector weights) to each word
+
+        if attn_mask is not None:
+            assert attn_mask.size() == attn.size(), \
+                    'Attention mask shape {} mismatch ' \
+                    'with Attention logit tensor shape ' \
+                    '{}.'.format(attn_mask.size(), attn.size())
+
+            # masked positions get -inf, so they receive zero weight after softmax
+            attn.data.masked_fill_(attn_mask, -float('inf'))
+
+        attn = self.softmax(attn)
+        # Note that this makes the distribution not sum to 1. At some point it
+        # may be worth researching whether this is the right way to apply
+        # dropout to the attention.
+        # Note that the t2t code also applies dropout in this manner
+        attn = self.dropout(attn)
+        output = torch.bmm(attn, v) # (batch * d_l) x max_len x d_v
+        # in LAL, gives: (batch * d_l) x 1 x d_v
+
+        return output, attn
+
+
+class MultiHeadAttention(nn.Module):
+    """
+    Multi-head attention module
+
+    Operates on a packed representation (total_len x d_model): the packed
+    input is projected per head, padded to a rectangular per-batch-element
+    layout for the attention itself, then unpacked again.
+    """
+
+    def __init__(self, n_head, d_model, d_k, d_v, residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
+        super(MultiHeadAttention, self).__init__()
+
+        self.n_head = n_head
+        self.d_k = d_k
+        self.d_v = d_v
+
+        # d_positional set means the representation is partitioned into
+        # content and positional halves, each with its own projections
+        if not d_positional:
+            self.partitioned = False
+        else:
+            self.partitioned = True
+
+        if self.partitioned:
+            self.d_content = d_model - d_positional
+            self.d_positional = d_positional
+
+            # q/k/v projections for the content partition
+            self.w_qs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
+            self.w_ks1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
+            self.w_vs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_v // 2))
+
+            # q/k/v projections for the positional partition
+            self.w_qs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
+            self.w_ks2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
+            self.w_vs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_v // 2))
+
+            init.xavier_normal_(self.w_qs1)
+            init.xavier_normal_(self.w_ks1)
+            init.xavier_normal_(self.w_vs1)
+
+            init.xavier_normal_(self.w_qs2)
+            init.xavier_normal_(self.w_ks2)
+            init.xavier_normal_(self.w_vs2)
+        else:
+            self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
+            self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
+            self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))
+
+            init.xavier_normal_(self.w_qs)
+            init.xavier_normal_(self.w_ks)
+            init.xavier_normal_(self.w_vs)
+
+        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
+        self.layer_norm = LayerNormalization(d_model)
+
+        if not self.partitioned:
+            # The lack of a bias term here is consistent with the t2t code, though
+            # in my experiments I have never observed this making a difference.
+            self.proj = nn.Linear(n_head*d_v, d_model, bias=False)
+        else:
+            self.proj1 = nn.Linear(n_head*(d_v//2), self.d_content, bias=False)
+            self.proj2 = nn.Linear(n_head*(d_v//2), self.d_positional, bias=False)
+
+        self.residual_dropout = FeatureDropout(residual_dropout)
+
+    def split_qkv_packed(self, inp, qk_inp=None):
+        """Project the packed input into per-head q/k/v tensors (n_head x len_inp x d)."""
+        v_inp_repeated = inp.repeat(self.n_head, 1).view(self.n_head, -1, inp.size(-1)) # n_head x len_inp x d_model
+        if qk_inp is None:
+            qk_inp_repeated = v_inp_repeated
+        else:
+            # queries/keys may come from a different source than the values
+            qk_inp_repeated = qk_inp.repeat(self.n_head, 1).view(self.n_head, -1, qk_inp.size(-1))
+
+        if not self.partitioned:
+            q_s = torch.bmm(qk_inp_repeated, self.w_qs) # n_head x len_inp x d_k
+            k_s = torch.bmm(qk_inp_repeated, self.w_ks) # n_head x len_inp x d_k
+            v_s = torch.bmm(v_inp_repeated, self.w_vs) # n_head x len_inp x d_v
+        else:
+            # content and positional halves are projected separately, then concatenated
+            q_s = torch.cat([
+                torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_qs1),
+                torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_qs2),
+                ], -1)
+            k_s = torch.cat([
+                torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_ks1),
+                torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_ks2),
+                ], -1)
+            v_s = torch.cat([
+                torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),
+                torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),
+                ], -1)
+        return q_s, k_s, v_s
+
+    def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
+        # Input is padded representation: n_head x len_inp x d
+        # Output is packed representation: (n_head * mb_size) x len_padded x d
+        # (along with masks for the attention and output)
+        n_head = self.n_head
+        d_k, d_v = self.d_k, self.d_v
+
+        len_padded = batch_idxs.max_len
+        mb_size = batch_idxs.batch_size
+        q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))
+        k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
+        v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
+        invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)
+
+        # copy each batch element's packed span into its own padded row,
+        # marking the copied (valid) positions False in the invalid mask
+        for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
+            q_padded[:,i,:end-start,:] = q_s[:,start:end,:]
+            k_padded[:,i,:end-start,:] = k_s[:,start:end,:]
+            v_padded[:,i,:end-start,:] = v_s[:,start:end,:]
+            invalid_mask[i, :end-start].fill_(False)
+
+        return(
+            q_padded.view(-1, len_padded, d_k),
+            k_padded.view(-1, len_padded, d_k),
+            v_padded.view(-1, len_padded, d_v),
+            invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1),
+            (~invalid_mask).repeat(n_head, 1),
+            )
+
+    def combine_v(self, outputs):
+        # Combine attention information from the different heads
+        n_head = self.n_head
+        outputs = outputs.view(n_head, -1, self.d_v) # n_head x len_inp x d_kv
+
+        if not self.partitioned:
+            # Switch from n_head x len_inp x d_v to len_inp x (n_head * d_v)
+            outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, n_head * self.d_v)
+
+            # Project back to residual size
+            outputs = self.proj(outputs)
+        else:
+            # project the content and positional halves back separately
+            d_v1 = self.d_v // 2
+            outputs1 = outputs[:,:,:d_v1]
+            outputs2 = outputs[:,:,d_v1:]
+            outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, n_head * d_v1)
+            outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, n_head * d_v1)
+            outputs = torch.cat([
+                self.proj1(outputs1),
+                self.proj2(outputs2),
+                ], -1)
+
+        return outputs
+
+    def forward(self, inp, batch_idxs, qk_inp=None):
+        """Run multi-head attention on packed input; returns (normed output, padded attention weights)."""
+        residual = inp
+
+        # While still using a packed representation, project to obtain the
+        # query/key/value for each head
+        q_s, k_s, v_s = self.split_qkv_packed(inp, qk_inp=qk_inp)
+        # n_head x len_inp x d_kv
+
+        # Switch to padded representation, perform attention, then switch back
+        q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)
+        # (n_head * batch) x len_padded x d_kv
+
+        outputs_padded, attns_padded = self.attention(
+            q_padded, k_padded, v_padded,
+            attn_mask=attn_mask,
+            )
+        outputs = outputs_padded[output_mask]
+        # (n_head * len_inp) x d_kv
+        outputs = self.combine_v(outputs)
+        # len_inp x d_model
+
+        outputs = self.residual_dropout(outputs, batch_idxs)
+
+        return self.layer_norm(outputs + residual), attns_padded
+
+#
+class PositionwiseFeedForward(nn.Module):
+ """
+ A position-wise feed forward module.
+
+ Projects to a higher-dimensional space before applying ReLU, then projects
+ back.
+ """
+
+ def __init__(self, d_hid, d_ff, relu_dropout=0.1, residual_dropout=0.1):
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = nn.Linear(d_hid, d_ff)
+ self.w_2 = nn.Linear(d_ff, d_hid)
+
+ self.layer_norm = LayerNormalization(d_hid)
+ self.relu_dropout = FeatureDropout(relu_dropout)
+ self.residual_dropout = FeatureDropout(residual_dropout)
+ self.relu = nn.ReLU()
+
+
+ def forward(self, x, batch_idxs):
+ residual = x
+
+ output = self.w_1(x)
+ output = self.relu_dropout(self.relu(output), batch_idxs)
+ output = self.w_2(output)
+
+ output = self.residual_dropout(output, batch_idxs)
+ return self.layer_norm(output + residual)
+
+#
+class PartitionedPositionwiseFeedForward(nn.Module):
+ def __init__(self, d_hid, d_ff, d_positional, relu_dropout=0.1, residual_dropout=0.1):
+ super().__init__()
+ self.d_content = d_hid - d_positional
+ self.w_1c = nn.Linear(self.d_content, d_ff//2)
+ self.w_1p = nn.Linear(d_positional, d_ff//2)
+ self.w_2c = nn.Linear(d_ff//2, self.d_content)
+ self.w_2p = nn.Linear(d_ff//2, d_positional)
+ self.layer_norm = LayerNormalization(d_hid)
+ self.relu_dropout = FeatureDropout(relu_dropout)
+ self.residual_dropout = FeatureDropout(residual_dropout)
+ self.relu = nn.ReLU()
+
+ def forward(self, x, batch_idxs):
+ residual = x
+ xc = x[:, :self.d_content]
+ xp = x[:, self.d_content:]
+
+ outputc = self.w_1c(xc)
+ outputc = self.relu_dropout(self.relu(outputc), batch_idxs)
+ outputc = self.w_2c(outputc)
+
+ outputp = self.w_1p(xp)
+ outputp = self.relu_dropout(self.relu(outputp), batch_idxs)
+ outputp = self.w_2p(outputp)
+
+ output = torch.cat([outputc, outputp], -1)
+
+ output = self.residual_dropout(output, batch_idxs)
+ return self.layer_norm(output + residual)
+
+class LabelAttention(nn.Module):
+ """
+ Single-head Attention layer for label-specific representations
+ """
+
+ def __init__(self, d_model, d_k, d_v, d_l, d_proj, combine_as_self, use_resdrop=True, q_as_matrix=False, residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
+ super(LabelAttention, self).__init__()
+ self.d_k = d_k
+ self.d_v = d_v
+ self.d_l = d_l # Number of Labels
+ self.d_model = d_model # Model Dimensionality
+ self.d_proj = d_proj # Projection dimension of each label output
+ self.use_resdrop = use_resdrop # Using Residual Dropout?
+ self.q_as_matrix = q_as_matrix # Using a Matrix of Q to be multiplied with input instead of learned q vectors
+ self.combine_as_self = combine_as_self # Using the Combination Method of Self-Attention
+
+ if not d_positional:
+ self.partitioned = False
+ else:
+ self.partitioned = True
+
+ if self.partitioned:
+ if d_model <= d_positional:
+ raise ValueError("Unable to build LabelAttention. d_model %d <= d_positional %d" % (d_model, d_positional))
+ self.d_content = d_model - d_positional
+ self.d_positional = d_positional
+
+ if self.q_as_matrix:
+ self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
+ else:
+ self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
+ self.w_ks1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
+ self.w_vs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_v // 2), requires_grad=True)
+
+ if self.q_as_matrix:
+ self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)
+ else:
+ self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
+ self.w_ks2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)
+ self.w_vs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_v // 2), requires_grad=True)
+
+ init.xavier_normal_(self.w_qs1)
+ init.xavier_normal_(self.w_ks1)
+ init.xavier_normal_(self.w_vs1)
+
+ init.xavier_normal_(self.w_qs2)
+ init.xavier_normal_(self.w_ks2)
+ init.xavier_normal_(self.w_vs2)
+ else:
+ if self.q_as_matrix:
+ self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
+ else:
+ self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_k), requires_grad=True)
+ self.w_ks = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
+ self.w_vs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_v), requires_grad=True)
+
+ init.xavier_normal_(self.w_qs)
+ init.xavier_normal_(self.w_ks)
+ init.xavier_normal_(self.w_vs)
+
+ self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
+ if self.combine_as_self:
+ self.layer_norm = LayerNormalization(d_model)
+ else:
+ self.layer_norm = LayerNormalization(self.d_proj)
+
+ if not self.partitioned:
+ # The lack of a bias term here is consistent with the t2t code, though
+ # in my experiments I have never observed this making a difference.
+ if self.combine_as_self:
+ self.proj = nn.Linear(self.d_l * d_v, d_model, bias=False)
+ else:
+ self.proj = nn.Linear(d_v, d_model, bias=False) # input dimension does not match, should be d_l * d_v
+ else:
+ if self.combine_as_self:
+ self.proj1 = nn.Linear(self.d_l*(d_v//2), self.d_content, bias=False)
+ self.proj2 = nn.Linear(self.d_l*(d_v//2), self.d_positional, bias=False)
+ else:
+ self.proj1 = nn.Linear(d_v//2, self.d_content, bias=False)
+ self.proj2 = nn.Linear(d_v//2, self.d_positional, bias=False)
+ if not self.combine_as_self:
+ self.reduce_proj = nn.Linear(d_model, self.d_proj, bias=False)
+
+ self.residual_dropout = FeatureDropout(residual_dropout)
+
    def split_qkv_packed(self, inp, k_inp=None):
        """
        Project the packed input into per-label queries, keys, and values.

        inp: packed representation, len_inp x d_model, used for the values
        k_inp: optional separate input for the keys; defaults to inp

        Returns (q_s, k_s, v_s):
          k_s: d_l x len_inp x d_k
          v_s: d_l x len_inp x d_v
          q_s: d_l x len_inp x d_k when q_as_matrix, otherwise a learned
               per-label query of shape d_l x 1 x d_k independent of inp
        """
        # NOTE(review): len_inp is never used in this method
        len_inp = inp.size(0)
        # replicate the input once per label so each label head gets its own copy
        v_inp_repeated = inp.repeat(self.d_l, 1).view(self.d_l, -1, inp.size(-1)) # d_l x len_inp x d_model
        if k_inp is None:
            k_inp_repeated = v_inp_repeated
        else:
            k_inp_repeated = k_inp.repeat(self.d_l, 1).view(self.d_l, -1, k_inp.size(-1)) # d_l x len_inp x d_model

        if not self.partitioned:
            if self.q_as_matrix:
                q_s = torch.bmm(k_inp_repeated, self.w_qs) # d_l x len_inp x d_k
            else:
                # the query is a learned vector per label, not a projection of the input
                q_s = self.w_qs.unsqueeze(1) # d_l x 1 x d_k
            k_s = torch.bmm(k_inp_repeated, self.w_ks) # d_l x len_inp x d_k
            v_s = torch.bmm(v_inp_repeated, self.w_vs) # d_l x len_inp x d_v
        else:
            # partitioned: the first d_content dims and the remaining
            # (positional) dims are projected separately, then concatenated
            if self.q_as_matrix:
                q_s = torch.cat([
                    torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_qs1),
                    torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_qs2),
                ], -1)
            else:
                q_s = torch.cat([
                    self.w_qs1.unsqueeze(1),
                    self.w_qs2.unsqueeze(1),
                ], -1)
            k_s = torch.cat([
                torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_ks1),
                torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_ks2),
            ], -1)
            v_s = torch.cat([
                torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),
                torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),
            ], -1)
        return q_s, k_s, v_s
+
    def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
        # Input is packed representation: n_head x len_inp x d
        # Output is padded representation: (n_head * mb_size) x len_padded x d
        # (along with masks for the attention and output)
        # n_head here is d_l: one attention head per label
        n_head = self.d_l
        d_k, d_v = self.d_k, self.d_v

        len_padded = batch_idxs.max_len
        mb_size = batch_idxs.batch_size
        if self.q_as_matrix:
            q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))
        else:
            # with a learned per-label query there is nothing to pad;
            # just replicate the single query once per batch element
            q_padded = q_s.repeat(mb_size, 1, 1) # (d_l * mb_size) x 1 x d_k
        k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
        v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
        # NOTE(review): DTYPE is defined elsewhere in this module; given the
        # fill_(False) below it is presumably torch.bool - confirm
        invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)

        # copy each sentence into its own row of the padded tensors,
        # marking the positions holding real data as valid
        for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
            if self.q_as_matrix:
                q_padded[:,i,:end-start,:] = q_s[:,start:end,:]
            k_padded[:,i,:end-start,:] = k_s[:,start:end,:]
            v_padded[:,i,:end-start,:] = v_s[:,start:end,:]
            invalid_mask[i, :end-start].fill_(False)

        if self.q_as_matrix:
            q_padded = q_padded.view(-1, len_padded, d_k)
            attn_mask = invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1)
        else:
            attn_mask = invalid_mask.unsqueeze(1).repeat(n_head, 1, 1)

        # output_mask selects the valid (non-padding) positions
        output_mask = (~invalid_mask).repeat(n_head, 1)

        return(
            q_padded,
            k_padded.view(-1, len_padded, d_k),
            v_padded.view(-1, len_padded, d_v),
            attn_mask,
            output_mask,
        )
+
    def combine_v(self, outputs):
        # Combine attention information from the different labels
        #
        # outputs arrives as (d_l * len_inp) x d_v; it is reshaped to
        # len_inp x d_l x ... and projected back towards the residual size
        d_l = self.d_l
        outputs = outputs.view(d_l, -1, self.d_v) # d_l x len_inp x d_v

        if not self.partitioned:
            # Switch from d_l x len_inp x d_v to len_inp x d_l x d_v
            if self.combine_as_self:
                outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, d_l * self.d_v)
            else:
                outputs = torch.transpose(outputs, 0, 1)#.contiguous() #.view(-1, d_l * self.d_v)
            # Project back to residual size
            outputs = self.proj(outputs) # Becomes len_inp x d_l x d_model
        else:
            # partitioned: the two halves of d_v are projected separately
            # (content part and positional part), then concatenated
            d_v1 = self.d_v // 2
            outputs1 = outputs[:,:,:d_v1]
            outputs2 = outputs[:,:,d_v1:]
            if self.combine_as_self:
                outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, d_l * d_v1)
                outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, d_l * d_v1)
            else:
                outputs1 = torch.transpose(outputs1, 0, 1)#.contiguous() #.view(-1, d_l * d_v1)
                outputs2 = torch.transpose(outputs2, 0, 1)#.contiguous() #.view(-1, d_l * d_v1)
            outputs = torch.cat([
                self.proj1(outputs1),
                self.proj2(outputs2),
            ], -1)#.contiguous()

        return outputs
+
    def forward(self, inp, batch_idxs, k_inp=None):
        """
        Run label attention over a packed batch of word representations.

        inp: packed representation, len_inp x d_model
        batch_idxs: BatchIndices describing the sentence boundaries
        k_inp: optional separate input used for the keys

        Returns (outputs, attns_padded).  When combine_as_self, outputs
        is len_inp x d_model; otherwise len_inp x (d_l * d_proj).
        """
        # NOTE(review): residual is never used below; the residual
        # connection is applied with inp directly
        residual = inp # len_inp x d_model
        #print()
        #print(f"inp.shape: {inp.shape}")
        len_inp = inp.size(0)
        #print(f"len_inp: {len_inp}")

        # While still using a packed representation, project to obtain the
        # query/key/value for each head
        q_s, k_s, v_s = self.split_qkv_packed(inp, k_inp=k_inp)
        # d_l x len_inp x d_k
        # q_s is d_l x 1 x d_k

        # Switch to padded representation, perform attention, then switch back
        q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)
        # q_padded, k_padded, v_padded: (d_l * batch_size) x max_len x d_kv
        # q_s is (d_l * batch_size) x 1 x d_kv

        outputs_padded, attns_padded = self.attention(
            q_padded, k_padded, v_padded,
            attn_mask=attn_mask,
        )
        # outputs_padded: (d_l * batch_size) x max_len x d_kv
        # in LAL: (d_l * batch_size) x 1 x d_kv
        # on the best model, this is one value vector per label that is repeated max_len times
        if not self.q_as_matrix:
            # broadcast the single per-label vector across all positions
            outputs_padded = outputs_padded.repeat(1,output_mask.size(-1),1)
        outputs = outputs_padded[output_mask]
        # outputs: (d_l * len_inp) x d_kv or LAL: (d_l * len_inp) x d_kv
        # output_mask: (d_l * batch_size) x max_len
        outputs = self.combine_v(outputs)
        #print(f"outputs shape: {outputs.shape}")
        # outputs: len_inp x d_l x d_model, whereas a normal self-attention layer gets len_inp x d_model
        if self.use_resdrop:
            if self.combine_as_self:
                outputs = self.residual_dropout(outputs, batch_idxs)
            else:
                # apply the residual dropout to each label's slice separately
                outputs = torch.cat([self.residual_dropout(outputs[:,i,:], batch_idxs).unsqueeze(1) for i in range(self.d_l)], 1)
        if self.combine_as_self:
            outputs = self.layer_norm(outputs + inp)
        else:
            # add the residual to every label's representation in place
            for l in range(self.d_l):
                outputs[:, l, :] = outputs[:, l, :] + inp

            outputs = self.reduce_proj(outputs) # len_inp x d_l x d_proj
            outputs = self.layer_norm(outputs) # len_inp x d_l x d_proj
            outputs = outputs.view(len_inp, -1).contiguous() # len_inp x (d_l * d_proj)

        return outputs, attns_padded
+
+
+#
class LabelAttentionModule(nn.Module):
    """
    Label Attention Module for label-specific representations

    The module can be used right after the Partitioned Attention, or it
    can be experimented with for the transition stack.
    """
    def __init__(self,
                 d_model,
                 d_input_proj,
                 d_k,
                 d_v,
                 d_l,
                 d_proj,
                 combine_as_self,
                 use_resdrop=True,
                 q_as_matrix=False,
                 residual_dropout=0.1,
                 attention_dropout=0.1,
                 d_positional=None,
                 d_ff=2048,
                 relu_dropout=0.2,
                 lattn_partitioned=True):
        """
        d_model: dimension of the incoming word embeddings
        d_input_proj: if set, project the inputs down to this size first
        d_k, d_v: per-label key/query and value dimensions
        d_l: number of labels (one attention head per label)
        d_proj: per-label output dimension
        combine_as_self: combine label outputs the way self-attention would
        d_positional: size of the positional part of a partitioned embedding
        lattn_partitioned: whether the attention and feed forward layers
          treat the content and positional halves separately
        """
        super().__init__()
        self.ff_dim = d_proj * d_l

        if not lattn_partitioned:
            self.d_positional = 0
        else:
            self.d_positional = d_positional if d_positional else 0

        if d_input_proj:
            if d_input_proj <= self.d_positional:
                raise ValueError("Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d" % (d_input_proj, self.d_positional))
            # only the content part of the embedding is projected;
            # the positional part is passed through unchanged
            self.input_projection = nn.Linear(d_model - self.d_positional, d_input_proj - self.d_positional, bias=False)
            d_input = d_input_proj
        else:
            self.input_projection = None
            d_input = d_model

        self.label_attention = LabelAttention(d_input,
                                              d_k,
                                              d_v,
                                              d_l,
                                              d_proj,
                                              combine_as_self,
                                              use_resdrop,
                                              q_as_matrix,
                                              residual_dropout,
                                              attention_dropout,
                                              self.d_positional)

        if not lattn_partitioned:
            self.lal_ff = PositionwiseFeedForward(self.ff_dim,
                                                  d_ff,
                                                  relu_dropout,
                                                  residual_dropout)
        else:
            self.lal_ff = PartitionedPositionwiseFeedForward(self.ff_dim,
                                                             d_ff,
                                                             self.d_positional,
                                                             relu_dropout,
                                                             residual_dropout)

    def forward(self, word_embeddings, tagged_word_lists):
        """
        Apply label attention to a list of per-sentence word embeddings.

        word_embeddings: a list of tensors, one per sentence, each of
          shape (num_words, dim)
        tagged_word_lists: unused here; kept for interface compatibility

        Returns a list of per-sentence labeled representations.
        """
        if self.input_projection:
            if self.d_positional > 0:
                word_embeddings = [torch.cat((self.input_projection(sentence[:, :-self.d_positional]),
                                              sentence[:, -self.d_positional:]), dim=1)
                                   for sentence in word_embeddings]
            else:
                word_embeddings = [self.input_projection(sentence) for sentence in word_embeddings]

        # Extract Labeled Representation
        batch_size = len(word_embeddings)
        sentence_lengths = [sentence.shape[0] for sentence in word_embeddings]

        # map each packed word position to the index of its sentence
        batch_indices = np.repeat(np.arange(batch_size, dtype=int), sentence_lengths)
        batch_idxs = BatchIndices(batch_indices, word_embeddings[0].device)

        # pack all of the sentences into a single (total_words x dim) tensor
        new_word_embeddings = torch.cat(word_embeddings, dim=0)

        labeled_representations, _ = self.label_attention(new_word_embeddings, batch_idxs)
        labeled_representations = self.lal_ff(labeled_representations, batch_idxs)

        # unpack the flat representation back into one tensor per sentence
        final_labeled_representations = [[] for _ in range(batch_size)]
        for idx, embed in enumerate(labeled_representations):
            final_labeled_representations[batch_indices[idx]].append(embed)
        for idx, representation in enumerate(final_labeled_representations):
            final_labeled_representations[idx] = torch.stack(representation)

        return final_labeled_representations
+
diff --git a/stanza/stanza/models/constituency/lstm_tree_stack.py b/stanza/stanza/models/constituency/lstm_tree_stack.py
new file mode 100644
index 0000000000000000000000000000000000000000..0846304c388e724d36894c827f40f9289f5f8a48
--- /dev/null
+++ b/stanza/stanza/models/constituency/lstm_tree_stack.py
@@ -0,0 +1,91 @@
+"""
+Keeps an LSTM in TreeStack form.
+
+The TreeStack nodes keep the hx and cx for the LSTM, along with a
+"value" which represents whatever the user needs to store.
+
+The TreeStacks can be popped to get back to the previous LSTM state.
+
+The module itself implements three methods: initial_state, push_states, output
+"""
+
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+
+from stanza.models.constituency.tree_stack import TreeStack
+
+Node = namedtuple("Node", ['value', 'lstm_hx', 'lstm_cx'])
+
class LSTMTreeStack(nn.Module):
    """
    Keeps an LSTM in TreeStack form.

    Each TreeStack node stores the LSTM state (hx, cx) along with a
    user-supplied "value", so popping the stack rewinds the LSTM to its
    previous state.

    The module implements three methods: initial_state, push_states, output
    """
    def __init__(self, input_size, hidden_size, num_lstm_layers, dropout, uses_boundary_vector, input_dropout):
        """
        Prepare LSTM and parameters

        input_size: dimension of the inputs to the LSTM
        hidden_size: LSTM internal & output dimension
        num_lstm_layers: how many layers of LSTM to use
        dropout: value of the LSTM dropout
        uses_boundary_vector: if set, learn a start_embedding parameter. otherwise, use zeros
        input_dropout: an nn.Module to dropout inputs. TODO: allow a float parameter as well
        """
        super().__init__()

        self.uses_boundary_vector = uses_boundary_vector

        # The start embedding needs to be input_size as we put it through the LSTM
        if uses_boundary_vector:
            self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
        else:
            # input_zeros stays registered even though initial_state no longer
            # reads it, so that previously saved state_dicts still load
            self.register_buffer('input_zeros', torch.zeros(num_lstm_layers, 1, input_size))
            self.register_buffer('hidden_zeros', torch.zeros(num_lstm_layers, 1, hidden_size))

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_lstm_layers, dropout=dropout)
        self.input_dropout = input_dropout


    def initial_state(self, initial_value=None):
        """
        Return an initial state, either based on zeros or based on the initial embedding and LSTM

        Note that LSTM start operation is already batched, in a sense
        The subsequent batch built this way will be used for batch_size trees

        Returns a stack with initial_value as the value, hx & cx either
        derived from the start_embedding or zeros, and no parent.
        """
        if self.uses_boundary_vector:
            start = self.start_embedding.unsqueeze(0).unsqueeze(0)
            # only the final (hx, cx) is needed; the output sequence is discarded
            _, (hx, cx) = self.lstm(start)
        else:
            hx = self.hidden_zeros
            cx = self.hidden_zeros
        return TreeStack(value=Node(initial_value, hx, cx), parent=None, length=1)

    def push_states(self, stacks, values, inputs):
        """
        Starting from a list of current stacks, put the inputs through the LSTM and build new stack nodes.

        B = stacks.len() = values.len()

        inputs must be of shape 1 x B x input_size
        """
        inputs = self.input_dropout(inputs)

        # batch the states of all the stacks into a single LSTM call
        hx = torch.cat([t.value.lstm_hx for t in stacks], axis=1)
        cx = torch.cat([t.value.lstm_cx for t in stacks], axis=1)
        _, (hx, cx) = self.lstm(inputs, (hx, cx))
        # each new node keeps its own 1-wide slice of the updated states
        new_stacks = [stack.push(Node(transition, hx[:, i:i+1, :], cx[:, i:i+1, :]))
                      for i, (stack, transition) in enumerate(zip(stacks, values))]
        return new_stacks

    def output(self, stack):
        """
        Return the last layer of the lstm_hx as the output from a stack

        Refactored so that alternate structures have an easy way of getting the output
        """
        return stack.value.lstm_hx[-1, 0, :]
diff --git a/stanza/stanza/models/constituency/score_converted_dependencies.py b/stanza/stanza/models/constituency/score_converted_dependencies.py
new file mode 100644
index 0000000000000000000000000000000000000000..b75fa76519c27dfea60b52847f77bf2591f62903
--- /dev/null
+++ b/stanza/stanza/models/constituency/score_converted_dependencies.py
@@ -0,0 +1,65 @@
+"""
+Script which processes a dependency file by using the constituency parser, then converting with the CoreNLP converter
+
+Currently this does not have the constituency parser as an option,
+although that is easy to add.
+
+Only English is supported, as only English is available in the CoreNLP converter
+"""
+
+import argparse
+import os
+import tempfile
+
+import stanza
+from stanza.models.constituency import retagging
+from stanza.models.depparse import scorer
+from stanza.utils.conll import CoNLL
+
def score_converted_dependencies(args):
    """
    Parse args['eval_file'] with a constituency parser, convert to dependencies, and score the result.

    The pipeline retags the text, parses constituency trees, and uses
    the CoreNLP converter ('depparse': 'converter') to get dependencies,
    which are then compared against the gold file.
    """
    if args['lang'] != 'en':
        raise ValueError("Converting and scoring dependencies is currently only supported for English")

    constituency_package = args['constituency_package']
    pipeline = stanza.Pipeline(lang=args['lang'],
                               tokenize_pretokenized=True,
                               package={'pos': args['retag_package'], 'depparse': 'converter', 'constituency': constituency_package},
                               processors='tokenize, pos, constituency, depparse')

    predicted_doc = pipeline(CoNLL.conll2doc(args['eval_file']))
    print("Processed %d sentences" % len(predicted_doc.sentences))
    # reload - the pipeline clobbered the gold values
    gold_doc = CoNLL.conll2doc(args['eval_file'])

    scorer.score_named_dependencies(predicted_doc, gold_doc)
    with tempfile.TemporaryDirectory() as tempdir:
        predicted_path = os.path.join(tempdir, "converted.conll")

        CoNLL.write_doc2conll(predicted_doc, predicted_path)

        _, _, score = scorer.score(predicted_path, args['eval_file'])

    print("Parser score:")
    print("{} {:.2f}".format(constituency_package, score*100))
+
+
def main():
    """Parse the command line arguments and score the converted dependencies."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--lang', default='en', type=str, help='Language')
    parser.add_argument('--eval_file', default="extern_data/ud2/ud-treebanks-v2.13/UD_English-EWT/en_ewt-ud-test.conllu", help='Input file for data loader.')
    parser.add_argument('--constituency_package', default="ptb3-revised_electra-large", help='Which constituency parser to use for converting')
    retagging.add_retag_args(parser)

    args = vars(parser.parse_args())
    retagging.postprocess_args(args)

    score_converted_dependencies(args)

if __name__ == '__main__':
    main()
+
diff --git a/stanza/stanza/models/constituency/text_processing.py b/stanza/stanza/models/constituency/text_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..90d66d6a9d29b364fe436ca0b6b2460033b03fe2
--- /dev/null
+++ b/stanza/stanza/models/constituency/text_processing.py
@@ -0,0 +1,166 @@
+import os
+
+import logging
+
+from stanza.models.common import utils
+from stanza.models.constituency.utils import retag_tags
+from stanza.models.constituency.trainer import Trainer
+from stanza.models.constituency.tree_reader import read_trees
+from stanza.utils.get_tqdm import get_tqdm
+
+logger = logging.getLogger('stanza')
+tqdm = get_tqdm()
+
def read_tokenized_file(tokenized_file):
    """
    Read sentences from a tokenized file, potentially replacing _ with space for languages such as VI

    Blank lines are skipped.  A token made up entirely of underscores is
    kept as-is; otherwise each underscore marks a space inside the token.

    Returns (docs, ids) where ids is a list of None placeholders, one per sentence.
    """
    with open(tokenized_file, encoding='utf-8') as fin:
        sentences = [line.strip() for line in fin if line.strip()]
    docs = []
    for sentence in sentences:
        words = [token if set(token) == {'_'} else token.replace("_", " ")
                 for token in sentence.split()]
        docs.append(words)
    return docs, [None] * len(docs)
+
+def read_xml_tree_file(tree_file):
+ """
+ Read sentences from a file of the format unique to VLSP test sets
+
+ in particular, it should be multiple blocks of
+
+
+ (tree ...)
+
+ """
+ with open(tree_file, encoding='utf-8') as fin:
+ lines = fin.readlines()
+ lines = [x.strip() for x in lines]
+ lines = [x for x in lines if x]
+ docs = []
+ ids = []
+ tree_id = None
+ tree_text = []
+ for line in lines:
+ if line.startswith(" 1:
+ tree_id = tree_id[1]
+ if tree_id.endswith(">"):
+ tree_id = tree_id[:-1]
+ tree_id = int(tree_id)
+ else:
+ tree_id = None
+ elif line.startswith(" 1000 and use_tqdm:
+ self.line_iterator = iter(tqdm(self.lines))
+ else:
+ self.line_iterator = iter(self.lines)
+
+
class FileTokenIterator(TokenIterator):
    """
    A TokenIterator which reads its lines from a file

    The file is opened in __enter__ and closed in __exit__; a progress
    bar is shown when the file has more than 1000 lines.
    """
    def __init__(self, filename):
        super().__init__()
        self.filename = filename
        # set in __enter__; initialized here so that __exit__ is safe
        # even if __enter__ was never called or failed partway through
        self.file_obj = None

    def __enter__(self):
        # TODO: use the file_size instead of counting the lines
        # file_size = Path(self.filename).stat().st_size
        # utf-8 is specified explicitly, matching the other file reads in this module
        with open(self.filename, encoding='utf-8') as fin:
            num_lines = sum(1 for _ in fin)

        self.file_obj = open(self.filename, encoding='utf-8')
        if num_lines > 1000:
            self.line_iterator = iter(tqdm(self.file_obj, total=num_lines))
        else:
            self.line_iterator = iter(self.file_obj)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if self.file_obj:
            self.file_obj.close()
            self.file_obj = None
+
def read_token_iterator(token_iterator, broken_ok, tree_callback):
    """
    Read all of the trees available from the given token iterator.

    broken_ok is passed through to read_single_tree.
    tree_callback, if not None, may transform each tree; a callback
    which returns None drops that tree from the results.
    """
    trees = []
    token = next(token_iterator, None)
    while token:
        if token == CLOSE_PAREN:
            raise ExtraCloseTreeError(token_iterator.line_num)
        if token != OPEN_PAREN:
            raise ValueError("Tree document had text between trees! Line number %d" % token_iterator.line_num)
        tree = read_single_tree(token_iterator, broken_ok=broken_ok)
        if tree is None:
            raise ValueError("Tree reader somehow created a None tree! Line number %d" % token_iterator.line_num)
        if tree_callback is not None:
            tree = tree_callback(tree)
        if tree is not None:
            trees.append(tree)
        token = next(token_iterator, None)

    return trees
+
+
def read_trees(text, broken_ok=False, tree_callback=None, use_tqdm=True):
    """
    Reads multiple trees from the text

    use_tqdm controls whether a progress bar may be shown for large inputs

    TODO: some of the error cases we hit can be recovered from
    """
    return read_token_iterator(TextTokenIterator(text, use_tqdm),
                               broken_ok=broken_ok,
                               tree_callback=tree_callback)
+
def read_tree_file(filename, broken_ok=False, tree_callback=None):
    """
    Read all of the trees in the given file
    """
    # the with block guarantees the file is closed before returning
    with FileTokenIterator(filename) as token_iterator:
        return read_token_iterator(token_iterator, broken_ok=broken_ok, tree_callback=tree_callback)
+
def read_directory(dirname, broken_ok=False, tree_callback=None):
    """
    Read all of the trees in all of the files in a directory

    Files are processed in sorted order for reproducibility.
    """
    filenames = sorted(os.listdir(dirname))
    return [tree
            for filename in filenames
            for tree in read_tree_file(os.path.join(dirname, filename), broken_ok, tree_callback)]
+
def read_treebank(filename, tree_callback=None):
    """
    Read a treebank and alter the trees to be a simpler format for learning to parse

    Raises ValueError if any tree has more than one child at the ROOT.
    """
    logger.info("Reading trees from %s", filename)
    trees = [tree.prune_none().simplify_labels()
             for tree in read_tree_file(filename, tree_callback=tree_callback)]

    illegal_trees = [tree for tree in trees if len(tree.children) > 1]
    if illegal_trees:
        raise ValueError("Found {} tree(s) which had non-unary transitions at the ROOT. First illegal tree: {:P}".format(len(illegal_trees), illegal_trees[0]))

    return trees
+
def main():
    """
    Reads a sample tree
    """
    text = "( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    print(read_trees(text))

if __name__ == '__main__':
    main()
diff --git a/stanza/stanza/models/constituency/tree_stack.py b/stanza/stanza/models/constituency/tree_stack.py
new file mode 100644
index 0000000000000000000000000000000000000000..b44108b018299b4cc75963718a49714e65f94e9a
--- /dev/null
+++ b/stanza/stanza/models/constituency/tree_stack.py
@@ -0,0 +1,57 @@
+"""
+A utility class for keeping track of intermediate parse states
+"""
+
+from collections import namedtuple
+
class TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])):
    """
    An immutable, branchable stack, represented as a linked list of nodes.

    Because each node only points at its parent, many different stack
    "heads" can share the same tail.  That matters in two situations:

    - beam search: several candidate parser states can branch off the
      same base stack without copying it.  For example, when K
      constituents are combined into a new constituent, the tracking
      LSTM restarts from the Kth previous output with the new value.
    - oracle training: from a state X we may want to continue with both
      the gold transition G and an incorrect transition T, keeping X+G
      and X+T alive at the same time without copying everything.

    A node's value can be a transition, a word, or a partially built
    constituent.  Implemented as a namedtuple to keep it lightweight.
    """
    def pop(self):
        """Return the stack with the top node removed."""
        return self.parent

    def push(self, value):
        """Return a new stack node holding value, stacked on top of this one."""
        return TreeStack(value, self, self.length + 1)

    def __iter__(self):
        """Yield the stored values from the top of the stack down."""
        node = self
        while node is not None:
            yield node.value
            node = node.parent

    def __reversed__(self):
        """Yield the stored values from the bottom of the stack up."""
        yield from reversed(list(self))

    def __str__(self):
        return "TreeStack(%s)" % ", ".join(str(value) for value in self)

    def __len__(self):
        return self.length
diff --git a/stanza/stanza/models/constituency/utils.py b/stanza/stanza/models/constituency/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..87f411b888c67bc57ede0ef491c7a83979bad916
--- /dev/null
+++ b/stanza/stanza/models/constituency/utils.py
@@ -0,0 +1,375 @@
+"""
+Collects a few of the conparser utility methods which don't belong elsewhere
+"""
+
+from collections import Counter
+import logging
+import warnings
+
+import torch.nn as nn
+from torch import optim
+
+from stanza.models.common.doc import TEXT, Document
+from stanza.models.common.utils import get_optimizer
+from stanza.models.constituency.base_model import SimpleModel
+from stanza.models.constituency.parse_transitions import TransitionScheme
+from stanza.models.constituency.parse_tree import Tree
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+
# Default hyperparameters per optimizer name, used when the
# corresponding command line argument is not supplied
DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0000007 , "mirror_madgrad": 0.00005 }
DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 }
DEFAULT_LEARNING_RHO = 0.9
DEFAULT_MOMENTUM = { "madgrad": 0.9, "mirror_madgrad": 0.9, "sgd": 0.9 }

tlogger = logging.getLogger('stanza.constituency.trainer')

# madgrad experiment for weight decay
# with learning_rate set to 0.0000007 and momentum 0.9
# on en_wsj, with a baseline model trained on adadelta for 200,
# then madgrad used to further improve that model
# 0.00000002.out: 0.9590347746438835
# 0.00000005.out: 0.9591378819960182
# 0.0000001.out: 0.9595450596319405
# 0.0000002.out: 0.9594603134479271
# 0.0000005.out: 0.9591317672706594
# 0.000001.out: 0.9592548741021389
# 0.000002.out: 0.9598395477013945
# 0.000003.out: 0.9594974271553495
# 0.000004.out: 0.9596665982603754
# 0.000005.out: 0.9591620720706487
DEFAULT_WEIGHT_DECAY = { "adamw": 0.05, "adadelta": 0.02, "sgd": 0.01, "adabelief": 1.2e-6, "madgrad": 2e-6, "mirror_madgrad": 2e-6 }
+
def retag_tags(doc, pipelines, xpos):
    """
    Returns a list of list of tags for the items in doc

    doc can be anything which feeds into the pipeline(s)
    pipelines are a list of 1 or more retag pipelines
    if multiple pipelines are given, majority vote wins
    (ties go to the first tag encountered, via Counter's ordering)
    """
    tag_attr = 'xpos' if xpos else 'upos'
    # run each pipeline in turn, collecting one list of tag lists per pipeline
    predictions = []
    for pipeline in pipelines:
        doc = pipeline(doc)
        predictions.append([[getattr(word, tag_attr) for word in sentence.words]
                            for sentence in doc.sentences])
    # predictions: N pipelines x S sentences x |s| tags
    # transpose twice so each word's N candidate tags are voted on together
    return [[Counter(candidates).most_common(1)[0][0] for candidates in zip(*sentence_votes)]
            for sentence_votes in zip(*predictions)]
+
def retag_trees(trees, pipelines, xpos=True):
    """
    Retag all of the trees using the given processor

    trees: a list of parse trees with preterminals to retag
    pipelines: retag pipelines passed through to retag_tags
    xpos: use xpos tags when True, upos otherwise

    Returns a list of new trees

    Raises ValueError if a tree cannot be processed or retagged, and
    RuntimeError if a pipeline produces a None tag.
    """
    if len(trees) == 0:
        return trees

    new_trees = []
    # trees are sent through the pipelines in chunks, presumably to bound
    # memory usage for large treebanks - TODO confirm
    chunk_size = 1000
    with tqdm(total=len(trees)) as pbar:
        for chunk_start in range(0, len(trees), chunk_size):
            chunk_end = min(chunk_start + chunk_size, len(trees))
            chunk = trees[chunk_start:chunk_end]
            sentences = []
            try:
                for idx, tree in enumerate(chunk):
                    # one token dict per preterminal's leaf word
                    tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()]
                    sentences.append(tokens)
            except ValueError as e:
                # idx deliberately leaks from the for loop above so the
                # error can report which tree failed
                raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e

            doc = Document(sentences)
            tag_lists = retag_tags(doc, pipelines, xpos)

            for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)):
                try:
                    if any(tag is None for tag in tags):
                        raise RuntimeError("Tagged tree #{} with a None tag!\n{}\n{}".format(tree_idx, tree, tags))
                    new_tree = tree.replace_tags(tags)
                    new_trees.append(new_tree)
                    pbar.update(1)
                except ValueError as e:
                    # NOTE(review): tree_idx in these messages is relative to
                    # the current chunk, not the whole treebank
                    raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e
    if len(new_trees) != len(trees):
        raise AssertionError("Retagged tree counts did not match: {} vs {}".format(len(new_trees), len(trees)))
    return new_trees
+
+
+# experimental results on nonlinearities
+# this is on a VI dataset, VLSP_22, using 1/10th of the data as a dev set
+# (no released test set at the time of the experiment)
+# original non-Bert tagger, with 1 iteration each instead of averaged over 5
+# considering the number of experiments and the length of time they would take
+#
+# Gelu had the highest score, which tracks with other experiments run.
+# Note that publicly released models have typically used Relu
+# on account of the runtime speed improvement
+#
+# Anyway, a larger experiment of 5x models on gelu or relu, using the
+# Roberta POS tagger and a corpus of silver trees, resulted in 0.8270
+# for relu and 0.8248 for gelu. So it is not even clear that
+# switching to gelu would be an accuracy improvement.
+#
+# Gelu: 82.32
+# Relu: 82.14
+# Mish: 81.95
+# Relu6: 81.91
+# Silu: 81.90
+# ELU: 81.73
+# Hardswish: 81.67
+# Softsign: 81.63
+# Hardtanh: 81.44
+# Celu: 81.43
+# Selu: 81.17
+# TODO: need to redo the prelu experiment with
+# possibly different numbers of parameters
+# and proper weight decay
+# Prelu: 80.95 (terminated early)
+# Softplus: 80.94
+# Logsigmoid: 80.91
+# Hardsigmoid: 79.03
+# RReLU: 77.00
+# Hardshrink: failed
+# Softshrink: failed
# map from the command line nonlinearity name to the torch module which builds it
NONLINEARITY = {
    'celu': nn.CELU,
    'elu': nn.ELU,
    'gelu': nn.GELU,
    'hardshrink': nn.Hardshrink,
    'hardtanh': nn.Hardtanh,
    'leaky_relu': nn.LeakyReLU,
    'logsigmoid': nn.LogSigmoid,
    'prelu': nn.PReLU,
    'relu': nn.ReLU,
    'relu6': nn.ReLU6,
    'rrelu': nn.RReLU,
    'selu': nn.SELU,
    'softplus': nn.Softplus,
    'softshrink': nn.Softshrink,
    'softsign': nn.Softsign,
    'tanhshrink': nn.Tanhshrink,
    'tanh': nn.Tanh,
}

# separating these out allows for backwards compatibility with earlier versions of pytorch
# NOTE torch compatibility: if we ever *release* models with these
# activation functions, we will need to break that compatibility

# these names are only registered below if the running torch provides them
nonlinearity_list = [
    'GLU',
    'Hardsigmoid',
    'Hardswish',
    'Mish',
    'SiLU',
]

for nonlinearity in nonlinearity_list:
    if hasattr(nn, nonlinearity):
        NONLINEARITY[nonlinearity.lower()] = getattr(nn, nonlinearity)
+
def build_nonlinearity(nonlinearity):
    """
    Look up "nonlinearity" in a map from function name to function, build the appropriate layer.
    """
    builder = NONLINEARITY.get(nonlinearity)
    if builder is None:
        raise ValueError('Chosen value of nonlinearity, "%s", not handled' % nonlinearity)
    return builder()
+
def build_optimizer(args, model, build_simple_adadelta=False):
    """
    Build an optimizer based on the arguments given

    If we are "multistage" training and epochs_trained < epochs // 2,
    we build an AdaDelta optimizer instead of whatever was requested
    The build_simple_adadelta parameter controls this
    """
    bert_weight_decay = args['bert_weight_decay']
    if build_simple_adadelta:
        # stage 1 of multistage training always uses AdaDelta with mostly
        # fixed hyperparameters
        optim_type = 'adadelta'
        learning_rate = args['stage1_learning_rate']
        learning_beta2 = 0.999    # doesn't matter for AdaDelta
        learning_eps = DEFAULT_LEARNING_EPS['adadelta']
        learning_rho = DEFAULT_LEARNING_RHO
        momentum = None           # also doesn't matter for AdaDelta
        weight_decay = DEFAULT_WEIGHT_DECAY['adadelta']
        bert_finetune = args.get('stage1_bert_finetune', False)
        bert_learning_rate = args['stage1_bert_learning_rate'] if bert_finetune else 0.0
    else:
        optim_type = args['optim'].lower()
        learning_rate = args['learning_rate']
        learning_beta2 = args['learning_beta2']
        learning_eps = args['learning_eps']
        learning_rate = args['learning_rate']
        learning_rho = args['learning_rho']
        momentum = args['learning_momentum']
        weight_decay = args['learning_weight_decay']
        bert_finetune = args.get('bert_finetune', False)
        bert_learning_rate = args['bert_learning_rate'] if bert_finetune else 0.0

    # TODO: allow rho as an arg for AdaDelta
    return get_optimizer(name=optim_type,
                         model=model,
                         lr=learning_rate,
                         betas=(0.9, learning_beta2),
                         eps=learning_eps,
                         momentum=momentum,
                         weight_decay=weight_decay,
                         bert_learning_rate=bert_learning_rate,
                         bert_weight_decay=weight_decay*bert_weight_decay,
                         is_peft=args.get('use_peft', False),
                         bert_finetune_layers=args['bert_finetune_layers'],
                         opt_logger=tlogger)
+
def build_scheduler(args, optimizer, first_optimizer=False):
    """
    Build the scheduler for the conparser based on its args

    Used to use a warmup for learning rate, but that wasn't working very well
    Now, we just use a ReduceLROnPlateau, which does quite well

    first_optimizer: when True, this is the first stage of a multistage
    training run, which uses its own min_lr setting
    """
    # the two cases only differ in which min_lr argument they read
    min_lr = args['stage1_learning_rate_min_lr'] if first_optimizer else args['learning_rate_min_lr']
    return optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                mode='max',
                                                factor=args['learning_rate_factor'],
                                                patience=args['learning_rate_patience'],
                                                cooldown=args['learning_rate_cooldown'],
                                                min_lr=min_lr)
+
def initialize_linear(linear, nonlinearity, bias):
    """
    Initializes the bias to a positive value, hopefully preventing dead neurons

    Only applies to relu-style nonlinearities; any other nonlinearity
    leaves the layer with the default pytorch initialization.
    """
    if nonlinearity not in ('relu', 'leaky_relu'):
        return
    nn.init.kaiming_normal_(linear.weight, nonlinearity=nonlinearity)
    # bias drawn uniformly from [0, 1/sqrt(2 * bias))
    nn.init.uniform_(linear.bias, 0, 1 / (bias * 2) ** 0.5)
+
def add_predict_output_args(parser):
    """
    Args specifically for the output location of data
    """
    parser.add_argument('--predict_dir', type=str, default=".",
                        help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging')
    parser.add_argument('--predict_file', type=str, default=None,
                        help='Base name for writing predictions')
    # default format writes trees with the _O tree format spec
    parser.add_argument('--predict_format', type=str, default="{:_O}",
                        help='Format to use when writing predictions')

    parser.add_argument('--predict_output_gold_tags', default=False, action='store_true',
                        help='Output gold tags as part of the evaluation - useful for putting the trees through EvalB')
+
def postprocess_predict_output_args(args):
    """
    Wrap a bare tree format spec (such as "_O" or "_OVi") in "{:...}"
    so it can be used directly as a Python format string.
    """
    spec = args['predict_format']
    is_short = len(spec) <= 2
    is_short_vi = len(spec) <= 3 and spec.endswith("Vi")
    if is_short or is_short_vi:
        args['predict_format'] = "{:%s}" % spec
+
+
def get_open_nodes(trees, transition_scheme):
    """
    Return a list of all open nodes in the given dataset.
    Depending on the parameters, may be single or compound open transitions.

    Compound schemes group multiple constituent labels into one transition;
    otherwise each unique label becomes its own single-label tuple.
    """
    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:
        return Tree.get_compound_constituents(trees)
    if transition_scheme is TransitionScheme.IN_ORDER_COMPOUND:
        # in-order scheme treats the root label separately
        return Tree.get_compound_constituents(trees, separate_root=True)
    return [(label,) for label in Tree.get_unique_constituent_labels(trees)]
+
+
def verify_transitions(trees, sequences, transition_scheme, unary_limit, reverse, name, root_labels):
    """
    Given a list of trees and their transition sequences, verify that the sequences rebuild the trees

    Raises RuntimeError (reporting the tree index and the offending
    transition) if a transition is illegal in its state or if the rebuilt
    tree does not equal the original.

    name: dataset name, used only in the error messages
    reverse: if True, rebuilt trees are reversed before comparison
    """
    model = SimpleModel(transition_scheme, unary_limit, reverse, root_labels)
    tlogger.info("Verifying the transition sequences for %d trees", len(trees))

    data = zip(trees, sequences)
    if tlogger.getEffectiveLevel() <= logging.INFO:
        # only pay for a progress bar when the log level would show it
        data = tqdm(zip(trees, sequences), total=len(trees))

    for tree_idx, (tree, sequence) in enumerate(data):
        # TODO: make the SimpleModel have a parse operation?
        state = model.initial_state_from_gold_trees([tree])[0]
        for idx, trans in enumerate(sequence):
            if not trans.is_legal(state, model):
                raise RuntimeError("Tree {} of {} failed: transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(tree_idx, name, idx, trans, tree, sequence))
            state = trans.apply(state, model)
        # the final state's top constituent should be the original tree
        result = model.get_top_constituent(state.constituents)
        if reverse:
            result = result.reverse()
        if tree != result:
            raise RuntimeError("Tree {} of {} failed: transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree_idx, name, tree, sequence, result))
+
def check_constituents(train_constituents, trees, treebank_name, fail=True):
    """
    Check that all the constituents in the other dataset are known in the train set

    train_constituents: collection of labels known at train time
    trees: the dev/test dataset to check
    treebank_name: used in the error message
    fail: raise RuntimeError on an unknown label if True, otherwise warn
    """
    constituents = Tree.get_unique_constituent_labels(trees)
    for con in constituents:
        if con not in train_constituents:
            # count how many trees contain the unknown label, and remember
            # the first one, to make the error message actionable
            first_error = None
            num_errors = 0
            for tree_idx, tree in enumerate(trees):
                # NOTE(review): this rebinds 'constituents', shadowing the
                # dataset-level list above; harmless because the outer for
                # already holds its iterator, but a distinct name would be clearer
                constituents = Tree.get_unique_constituent_labels(tree)
                if con in constituents:
                    num_errors += 1
                    if first_error is None:
                        first_error = tree_idx
            # first_error is guaranteed set here: con came from these trees
            error = "Found constituent label {} in the {} set which don't exist in the train set. This constituent label occured in {} trees, with the first tree index at {} counting from 1\nThe error tree (which may have POS tags changed from the retagger and may be missing functional tags or empty nodes) is:\n{:P}".format(con, treebank_name, num_errors, (first_error+1), trees[first_error])
            if fail:
                raise RuntimeError(error)
            else:
                warnings.warn(error)
+
def check_root_labels(root_labels, other_trees, treebank_name):
    """
    Check that all the root states in the other dataset are known in the train set

    Raises RuntimeError naming the unknown root state and the dataset.
    """
    for label in Tree.get_root_labels(other_trees):
        if label in root_labels:
            continue
        raise RuntimeError("Found root state {} in the {} set which is not a ROOT state in the train set".format(label, treebank_name))
+
def remove_duplicate_trees(trees, treebank_name):
    """
    Filter duplicates from the given dataset

    Trees are compared by their string representation; the first occurrence
    of each tree is kept, preserving the original order.
    """
    kept = []
    seen = set()
    for tree in trees:
        key = "{}".format(tree)
        if key not in seen:
            seen.add(key)
            kept.append(tree)
    dropped = len(trees) - len(kept)
    if dropped > 0:
        tlogger.info("Filtered %d duplicates from %s dataset", dropped, treebank_name)
    return kept
+
+def remove_singleton_trees(trees):
+ """
+ remove trees which are just a root and a single word
+
+ TODO: remove these trees in the conversion instead of here
+ """
+ new_trees = [x for x in trees if
+ len(x.children) > 1 or
+ (len(x.children) == 1 and len(x.children[0].children) > 1) or
+ (len(x.children) == 1 and len(x.children[0].children) == 1 and len(x.children[0].children[0].children) >= 1)]
+ if len(trees) - len(new_trees) > 0:
+ tlogger.info("Eliminated %d trees with missing structure", (len(trees) - len(new_trees)))
+ return new_trees
+
diff --git a/stanza/stanza/models/coref/predict.py b/stanza/stanza/models/coref/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a035e66a2bb7beaaa5ac8f94b6dc981d5b53459
--- /dev/null
+++ b/stanza/stanza/models/coref/predict.py
@@ -0,0 +1,55 @@
+import argparse
+
+import json
+import torch
+from tqdm import tqdm
+
+from stanza.models.coref.model import CorefModel
+
+
+if __name__ == "__main__":
+ argparser = argparse.ArgumentParser()
+ argparser.add_argument("experiment")
+ argparser.add_argument("input_file")
+ argparser.add_argument("output_file")
+ argparser.add_argument("--config-file", default="config.toml")
+ argparser.add_argument("--batch-size", type=int,
+ help="Adjust to override the config value if you're"
+ " experiencing out-of-memory issues")
+ argparser.add_argument("--weights",
+ help="Path to file with weights to load."
+ " If not supplied, in the latest"
+ " weights of the experiment will be loaded;"
+ " if there aren't any, an error is raised.")
+ args = argparser.parse_args()
+
+ model = CorefModel.load_model(path=args.weights,
+ map_location="cpu",
+ ignore={"bert_optimizer", "general_optimizer",
+ "bert_scheduler", "general_scheduler"})
+ if args.batch_size:
+ model.config.a_scoring_batch_size = args.batch_size
+ model.training = False
+
+ try:
+ with open(args.input_file, encoding="utf-8") as fin:
+ input_data = json.load(fin)
+ except json.decoder.JSONDecodeError:
+ # read the old jsonlines format if necessary
+ with open(args.input_file, encoding="utf-8") as fin:
+ text = "[" + ",\n".join(fin) + "]"
+ input_data = json.loads(text)
+ docs = [model.build_doc(doc) for doc in input_data]
+
+ with torch.no_grad():
+ for doc in tqdm(docs, unit="docs"):
+ result = model.run(doc)
+ doc["span_clusters"] = result.span_clusters
+ doc["word_clusters"] = result.word_clusters
+
+ for key in ("word2subword", "subwords", "word_id", "head2span"):
+ del doc[key]
+
+ with open(args.output_file, mode="w") as fout:
+ for doc in docs:
+ json.dump(doc, fout)
diff --git a/stanza/stanza/models/coref/span_predictor.py b/stanza/stanza/models/coref/span_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..44c5719df54e8217b34cd1cc481ac15ead7cc3c2
--- /dev/null
+++ b/stanza/stanza/models/coref/span_predictor.py
@@ -0,0 +1,146 @@
+""" Describes SpanPredictor which aims to predict spans by taking as input
+head word and context embeddings.
+"""
+
+from typing import List, Optional, Tuple
+
+from stanza.models.coref.const import Doc, Span
+import torch
+
+
class SpanPredictor(torch.nn.Module):
    """ Predicts the start and end of the span around each head word,
    scoring every candidate word in the head's sentence. """

    def __init__(self, input_size: int, distance_emb_size: int):
        # ffnn scores each (head, candidate, distance) triple down to 64
        # features; the conv stack then smooths scores across neighboring
        # candidates and projects to 2 outputs (start score, end score).
        # NOTE(review): the first Linear hardcodes "+ 64" for the distance
        # embedding, so this assumes distance_emb_size == 64 -- confirm
        # against the configs that construct this module
        super().__init__()
        self.ffnn = torch.nn.Sequential(
            torch.nn.Linear(input_size * 2 + 64, input_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(input_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 64),
        )
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1),
            torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        self.emb = torch.nn.Embedding(128, distance_emb_size)  # [-63, 63] + too_far

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.ffnn.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                doc: Doc,
                words: torch.Tensor,
                heads_ids: torch.Tensor) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids

        Args:
            doc (Doc): the document data
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]
            heads_ids (torch.Tensor): word indices of span heads

        Returns:
            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
        """
        # Obtain distance embedding indices, [n_heads, n_words]
        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
        emb_ids = relative_positions + 63  # make all valid distances positive
        # "+" on boolean masks acts as logical OR here
        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127  # "too_far"

        # Obtain "same sentence" boolean mask, [n_heads, n_words]
        sent_id = torch.tensor(doc["sent_id"], device=words.device)
        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))

        # To save memory, only pass candidates from one sentence for each head
        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
        # for each candidate among the words in the same sentence as span_head
        # [n_heads, input_size * 2 + distance_emb_size]
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat((
            words[heads_ids[rows]],
            words[cols],
            self.emb(emb_ids[rows, cols]),
        ), dim=1)

        # pad the flat pair list back into per-head rows of equal length
        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]

        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
        # This is necessary to allow the convolution layer to look at several
        # word scores
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1)  # [n_heads, n_candidates, 2]

        # scatter scores back to document positions; words outside the
        # head's sentence keep -inf
        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
        scores[rows, cols] = res[padding_mask]

        # Make sure that start <= head <= end during inference
        if not self.training:
            # log of a {0,1} mask gives 0 for valid positions, -inf otherwise
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
            valid_ends = torch.log((relative_positions <= 0).to(torch.float))
            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
            return scores + valid_positions
        return scores

    def get_training_data(self,
                          doc: Doc,
                          words: torch.Tensor
                          ) -> Tuple[Optional[torch.Tensor],
                                     Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """ Returns span starts/ends for gold mentions in the document.

        Returns (None, None) when the document has no gold head2span data.
        The ends are shifted by -1 to index the last word of the span. """
        head2span = sorted(doc["head2span"])
        if not head2span:
            return None, None
        heads, starts, ends = zip(*head2span)
        heads = torch.tensor(heads, device=self.device)
        starts = torch.tensor(starts, device=self.device)
        ends = torch.tensor(ends, device=self.device) - 1
        return self(doc, words, heads), (starts, ends)

    def predict(self,
                doc: Doc,
                words: torch.Tensor,
                clusters: List[List[int]]) -> List[List[Span]]:
        """
        Predicts span clusters based on the word clusters.

        Args:
            doc (Doc): the document data
            words (torch.Tensor): [n_words, emb_size] matrix containing
                embeddings for each of the words in the text
            clusters (List[List[int]]): a list of clusters where each cluster
                is a list of word indices

        Returns:
            List[List[Span]]: span clusters
        """
        if not clusters:
            return []

        # score all heads in one batch; dedupe/sort so indexing is stable
        heads_ids = torch.tensor(
            sorted(i for cluster in clusters for i in cluster),
            device=self.device
        )

        scores = self(doc, words, heads_ids)
        starts = scores[:, :, 0].argmax(dim=1).tolist()
        # end index is exclusive, hence the +1
        ends = (scores[:, :, 1].argmax(dim=1) + 1).tolist()

        head2span = {
            head: (start, end)
            for head, start, end in zip(heads_ids.tolist(), starts, ends)
        }

        return [[head2span[head] for head in cluster]
                for cluster in clusters]
diff --git a/stanza/stanza/models/coref/tokenizer_customization.py b/stanza/stanza/models/coref/tokenizer_customization.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2a8a1b8357825db18911161b6393b71cfbb89a1
--- /dev/null
+++ b/stanza/stanza/models/coref/tokenizer_customization.py
@@ -0,0 +1,18 @@
+""" This file defines functions used to modify the default behaviour
+of transformers.AutoTokenizer. These changes are necessary, because some
+tokenizers are meant to be used with raw text, while the OntoNotes documents
+have already been split into words.
+All the functions are used in coref_model.CorefModel._get_docs. """
+
+
+# Filters out unwanted tokens produced by the tokenizer
+TOKENIZER_FILTERS = {
+ "albert-xxlarge-v2": (lambda token: token != "▁"), # U+2581, not just "_"
+ "albert-large-v2": (lambda token: token != "▁"),
+}
+
+# Maps some words to tokens directly, without a tokenizer
+TOKENIZER_MAPS = {
+ "roberta-large": {".": ["."], ",": [","], "!": ["!"], "?": ["?"],
+ ":":[":"], ";":[";"], "'s": ["'s"]}
+}
diff --git a/stanza/stanza/models/coref/word_encoder.py b/stanza/stanza/models/coref/word_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d20abc3458957f02b942b81eeb103080bf977f4
--- /dev/null
+++ b/stanza/stanza/models/coref/word_encoder.py
@@ -0,0 +1,108 @@
+""" Describes WordEncoder. Extracts mention vectors from bert-encoded text.
+"""
+
+from typing import Tuple
+
+import torch
+
+from stanza.models.coref.config import Config
+from stanza.models.coref.const import Doc
+
+
class WordEncoder(torch.nn.Module):  # pylint: disable=too-many-instance-attributes
    """ Receives bert contextual embeddings of a text, extracts all the
    possible mentions in that text.

    Each word's representation is an attention-weighted sum of the bert
    embeddings of its own subtokens. """

    def __init__(self, features: int, config: Config):
        """
        Args:
            features (int): the number of featues in the input embeddings
            config (Config): the configuration of the current session
        """
        super().__init__()
        # one scalar attention score per subtoken
        self.attn = torch.nn.Linear(in_features=features, out_features=1)
        self.dropout = torch.nn.Dropout(config.dropout_rate)

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.attn.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                doc: Doc,
                x: torch.Tensor,
                ) -> Tuple[torch.Tensor, ...]:
        """
        Extracts word representations from text.

        Args:
            doc: the document data
            x: a tensor containing bert output, shape (n_subtokens, bert_dim)

        Returns:
            words: a Tensor of shape [n_words, mention_emb];
                mention representations
            cluster_ids: tensor of shape [n_words], containing cluster indices
                for each word. Non-coreferent words have cluster id of zero.
        """
        # doc["word2subword"] holds (start, end) subtoken indices per word
        word_boundaries = torch.tensor(doc["word2subword"], device=self.device)
        starts = word_boundaries[:, 0]
        ends = word_boundaries[:, 1]

        # [n_mentions, features]
        # matrix product of [n_words, n_subtokens] attention weights with
        # the subtoken embeddings = weighted sum per word
        words = self._attn_scores(x, starts, ends).mm(x)

        words = self.dropout(words)

        return (words, self._cluster_ids(doc))

    def _attn_scores(self,
                     bert_out: torch.Tensor,
                     word_starts: torch.Tensor,
                     word_ends: torch.Tensor) -> torch.Tensor:
        """ Calculates attention scores for each of the mentions.

        Args:
            bert_out (torch.Tensor): [n_subwords, bert_emb], bert embeddings
                for each of the subwords in the document
            word_starts (torch.Tensor): [n_words], start indices of words
            word_ends (torch.Tensor): [n_words], end indices of words (exclusive)

        Returns:
            torch.Tensor: [n_words, n_subtokens] softmaxed attention weights;
                each row sums to 1 over the word's own subtokens
        """
        n_subtokens = len(bert_out)
        n_words = len(word_starts)

        # [n_mentions, n_subtokens]
        # with 0 at positions belonging to the words and -inf elsewhere
        attn_mask = torch.arange(0, n_subtokens, device=self.device).expand((n_words, n_subtokens))
        attn_mask = ((attn_mask >= word_starts.unsqueeze(1))
                     * (attn_mask < word_ends.unsqueeze(1)))
        # log of the 0/1 mask: 0 stays for in-word positions, -inf elsewhere,
        # so the later softmax ignores subtokens of other words
        attn_mask = torch.log(attn_mask.to(torch.float))

        attn_scores = self.attn(bert_out).T  # [1, n_subtokens]
        attn_scores = attn_scores.expand((n_words, n_subtokens))
        attn_scores = attn_mask + attn_scores
        del attn_mask
        return torch.softmax(attn_scores, dim=1)  # [n_words, n_subtokens]

    def _cluster_ids(self, doc: Doc) -> torch.Tensor:
        """
        Args:
            doc: document information

        Returns:
            torch.Tensor of shape [n_word], containing cluster indices for
                each word. Non-coreferent words have cluster id of zero.
        """
        # cluster ids start at 1; 0 is reserved for "not in any cluster"
        word2cluster = {word_i: i
                        for i, cluster in enumerate(doc["word_clusters"], start=1)
                        for word_i in cluster}

        return torch.tensor(
            [word2cluster.get(word_i, 0)
             for word_i in range(len(doc["cased_words"]))],
            device=self.device
        )
diff --git a/stanza/stanza/models/depparse/data.py b/stanza/stanza/models/depparse/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..70949ba6942cf8ab179da223951ddc7e1af9922c
--- /dev/null
+++ b/stanza/stanza/models/depparse/data.py
@@ -0,0 +1,233 @@
+import random
+import logging
+import torch
+
+from stanza.models.common.bert_embedding import filter_data, needs_length_filter
+from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
+from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, ROOT_ID, CompositeVocab, CharVocab
+from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab
+from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory
+from stanza.models.common.doc import *
+
+logger = logging.getLogger('stanza')
+
def data_to_batches(data, batch_size, eval_mode, sort_during_eval, min_length_to_batch_separately):
    """
    Given a list of lists, where the first element of each sublist
    represents the sentence, group the sentences into batches.

    During training mode (not eval_mode) the sentences are sorted by
    length with a bit of random shuffling. During eval mode, the
    sentences are sorted by length if sort_during_eval is true.

    Refactored from the data structure in case other models could use
    it and for ease of testing.

    Returns (batches, original_order), where original_order is None
    when in train mode or when unsorted and represents the original
    location of each sentence in the sort
    """
    data_orig_idx = None
    if not eval_mode:
        # rough length sorting for memory utilization; the random sort
        # direction provides a little variety between epochs
        data = sorted(data, key=lambda item: len(item[0]), reverse=random.random() > .5)
    elif sort_during_eval:
        (data, ), data_orig_idx = sort_all([data], [len(item[0]) for item in data])

    batches = []
    current_batch = []
    current_len = 0

    def flush():
        # close off the batch under construction, if any
        nonlocal current_batch, current_len
        if current_len > 0:
            batches.append(current_batch)
            current_batch = []
            current_len = 0

    for item in data:
        sent_len = len(item[0])
        if min_length_to_batch_separately is not None and sent_len > min_length_to_batch_separately:
            # very long sentences always get a batch of their own
            flush()
            batches.append([item])
        else:
            if sent_len + current_len > batch_size and current_len > 0:
                flush()
            current_batch.append(item)
            current_len += sent_len

    flush()
    return batches, data_orig_idx
+
+
class DataLoader:
    """
    Converts a depparse Document into id tensors and chunks it into batches.

    Iterating yields the 16-element tuple produced by __getitem__.
    """

    def __init__(self, doc, batch_size, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, min_length_to_batch_separately=None, bert_tokenizer=None):
        """
        doc: the Document to load
        batch_size: max total words per batch (see data_to_batches)
        pretrain: pretrained embeddings; its vocab is used when args['pretrain'] is set
        vocab: reuse an existing MultiVocab; must be provided at eval time
        evaluation: eval mode disables shuffling and subsampling
        sort_during_eval: sort sentences by length at eval time
            (self.data_orig_idx then allows unsorting the predictions)
        min_length_to_batch_separately: sentences longer than this get
            their own batch
        bert_tokenizer: used to filter over-long sentences for some bert models
        """
        self.batch_size = batch_size
        self.min_length_to_batch_separately=min_length_to_batch_separately
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.sort_during_eval = sort_during_eval
        self.doc = doc
        data = self.load_doc(doc)

        # handle vocab
        if vocab is None:
            self.vocab = self.init_vocab(data)
        else:
            self.vocab = vocab

        # filter out the long sentences if bert is used
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None
        self.pretrain_vocab = None
        if pretrain is not None and args['pretrain']:
            self.pretrain_vocab = pretrain.vocab

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)
        # shuffle for training
        if self.shuffled:
            random.shuffle(data)
        self.num_examples = len(data)

        # chunk into batches
        self.data = self.chunk_batches(data)
        logger.debug("{} batches created.".format(len(self.data)))

    def init_vocab(self, data):
        """Build a fresh MultiVocab (char/word/upos/xpos/feats/lemma/deprel) from training data"""
        assert self.eval == False # for eval vocab must exist
        charvocab = CharVocab(data, self.args['shorthand'])
        # cutoff=7: rare words (and lemmas) fall back to UNK
        wordvocab = WordVocab(data, self.args['shorthand'], cutoff=7, lower=True)
        uposvocab = WordVocab(data, self.args['shorthand'], idx=1)
        xposvocab = xpos_vocab_factory(data, self.args['shorthand'])
        featsvocab = FeatureVocab(data, self.args['shorthand'], idx=3)
        lemmavocab = WordVocab(data, self.args['shorthand'], cutoff=7, idx=4, lower=True)
        deprelvocab = WordVocab(data, self.args['shorthand'], idx=6)
        vocab = MultiVocab({'char': charvocab,
                            'word': wordvocab,
                            'upos': uposvocab,
                            'xpos': xposvocab,
                            'feats': featsvocab,
                            'lemma': lemmavocab,
                            'deprel': deprelvocab})
        return vocab

    def preprocess(self, data, vocab, pretrain_vocab, args):
        """
        Map each sentence to ids, producing 10 parallel fields per sentence:
        word, char, upos, xpos, feats, pretrained, lemma, head, deprel, raw text.
        An artificial ROOT entry is prepended to every field except head/deprel.
        """
        processed = []
        # a CompositeVocab xpos has one id per sub-column, so the ROOT
        # placeholder must match that shape
        xpos_replacement = [[ROOT_ID] * len(vocab['xpos'])] if isinstance(vocab['xpos'], CompositeVocab) else [ROOT_ID]
        feats_replacement = [[ROOT_ID] * len(vocab['feats'])]
        for sent in data:
            processed_sent = [[ROOT_ID] + vocab['word'].map([w[0] for w in sent])]
            processed_sent += [[[ROOT_ID]] + [vocab['char'].map([x for x in w[0]]) for w in sent]]
            processed_sent += [[ROOT_ID] + vocab['upos'].map([w[1] for w in sent])]
            processed_sent += [xpos_replacement + vocab['xpos'].map([w[2] for w in sent])]
            processed_sent += [feats_replacement + vocab['feats'].map([w[3] for w in sent])]
            if pretrain_vocab is not None:
                # always use lowercase lookup in pretrained vocab
                processed_sent += [[ROOT_ID] + pretrain_vocab.map([w[0].lower() for w in sent])]
            else:
                processed_sent += [[ROOT_ID] + [PAD_ID] * len(sent)]
            processed_sent += [[ROOT_ID] + vocab['lemma'].map([w[4] for w in sent])]
            # heads: at eval time a missing/unparseable head becomes 0 instead of raising
            processed_sent += [[to_int(w[5], ignore_error=self.eval) for w in sent]]
            processed_sent += [vocab['deprel'].map([w[6] for w in sent])]
            processed_sent.append([w[0] for w in sent])
            processed.append(processed_sent)
        return processed

    def __len__(self):
        # number of batches, not number of sentences
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        # transpose: batch of sentences -> tuple of 10 parallel fields
        batch = list(zip(*batch))
        assert len(batch) == 10

        # sort sentences by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # sort words by lens for easy char-RNN operations
        batch_words = [w for sent in batch[1] for w in sent]
        word_lens = [len(x) for x in batch_words]
        batch_words, word_orig_idx = sort_all([batch_words], word_lens)
        batch_words = batch_words[0]
        word_lens = [len(x) for x in batch_words]

        # convert to tensors (padded to the longest sentence/word)
        words = batch[0]
        words = get_long_tensor(words, batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(batch_words, len(word_lens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        upos = get_long_tensor(batch[2], batch_size)
        xpos = get_long_tensor(batch[3], batch_size)
        ufeats = get_long_tensor(batch[4], batch_size)
        pretrained = get_long_tensor(batch[5], batch_size)
        sentlens = [len(x) for x in batch[0]]
        lemma = get_long_tensor(batch[6], batch_size)
        head = get_long_tensor(batch[7], batch_size)
        deprel = get_long_tensor(batch[8], batch_size)
        text = batch[9]
        return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, orig_idx, word_orig_idx, sentlens, word_lens, text

    def load_doc(self, doc):
        """Extract the 7 columns used by the parser as per-sentence lists"""
        data = doc.get([TEXT, UPOS, XPOS, FEATS, LEMMA, HEAD, DEPREL], as_sentences=True)
        data = self.resolve_none(data)
        return data

    def resolve_none(self, data):
        # replace None to '_'
        for sent_idx in range(len(data)):
            for tok_idx in range(len(data[sent_idx])):
                for feat_idx in range(len(data[sent_idx][tok_idx])):
                    if data[sent_idx][tok_idx][feat_idx] is None:
                        data[sent_idx][tok_idx][feat_idx] = '_'
        return data

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def set_batch_size(self, batch_size):
        # takes effect at the next reshuffle/chunk_batches call
        self.batch_size = batch_size

    def reshuffle(self):
        """Flatten the current batches, rebatch, and shuffle batch order"""
        data = [y for x in self.data for y in x]
        self.data = self.chunk_batches(data)
        random.shuffle(self.data)

    def chunk_batches(self, data):
        """Group preprocessed sentences into batches of at most batch_size words"""
        batches, data_orig_idx = data_to_batches(data=data, batch_size=self.batch_size,
                                                 eval_mode=self.eval, sort_during_eval=self.sort_during_eval,
                                                 min_length_to_batch_separately=self.min_length_to_batch_separately)
        # data_orig_idx might be None at train time, since we don't anticipate unsorting
        self.data_orig_idx = data_orig_idx
        return batches
+
def to_int(string, ignore_error=False):
    """
    Convert a string to an int.

    If ignore_error is True, unparseable values become 0 (used at eval
    time, when the HEAD column may be missing); otherwise the ValueError
    propagates to the caller.
    """
    try:
        return int(string)
    except ValueError:
        if ignore_error:
            return 0
        # bare raise re-raises with the original traceback intact,
        # instead of re-raising a captured reference
        raise
+
diff --git a/stanza/stanza/models/lemma/attach_lemma_classifier.py b/stanza/stanza/models/lemma/attach_lemma_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f59782c891aa4806c3a40f6bcbda31363bd3b5f
--- /dev/null
+++ b/stanza/stanza/models/lemma/attach_lemma_classifier.py
@@ -0,0 +1,25 @@
+import argparse
+
+from stanza.models.lemma.trainer import Trainer
+from stanza.models.lemma_classifier.base_model import LemmaClassifier
+
def attach_classifier(input_filename, output_filename, classifiers):
    """
    Load a saved lemmatizer, attach the given contextual lemma classifiers, and save it again.

    input_filename: the lemmatizer model file to start from
    output_filename: where to write the combined model
    classifiers: paths of LemmaClassifier files to load and attach, in order
    """
    trainer = Trainer(model_file=input_filename)
    loaded = (LemmaClassifier.load(path) for path in classifiers)
    trainer.contextual_lemmatizers.extend(loaded)
    trainer.save(output_filename)
+
def main(args=None):
    """
    Command line interface: attach one or more lemma classifiers to a saved lemmatizer.

    args: list of command line arguments, or None to use sys.argv
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, required=True, help='Which lemmatizer to start from')
    parser.add_argument('--output', type=str, required=True, help='Where to save the lemmatizer')
    parser.add_argument('--classifier', type=str, required=True, nargs='+', help='Lemma classifier to attach')
    args = parser.parse_args(args)

    attach_classifier(args.input, args.output, args.classifier)

# script entry point
if __name__ == '__main__':
    main()
diff --git a/stanza/stanza/models/lemma/scorer.py b/stanza/stanza/models/lemma/scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f34f088dab1dd9a32e422f21397a81b89d281f53
--- /dev/null
+++ b/stanza/stanza/models/lemma/scorer.py
@@ -0,0 +1,13 @@
+"""
+Utils and wrappers for scoring lemmatizers.
+"""
+
+from stanza.models.common.utils import ud_scores
+
def score(system_conllu_file, gold_conllu_file):
    """ Wrapper for lemma scorer.

    Returns (precision, recall, f1) of the "Lemmas" metric from ud_scores.
    """
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)
    lemma_eval = evaluation["Lemmas"]
    return lemma_eval.precision, lemma_eval.recall, lemma_eval.f1
+
diff --git a/stanza/stanza/models/lemma/vocab.py b/stanza/stanza/models/lemma/vocab.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a2cca0bc954326c61fb1b4fe22975be9b776435
--- /dev/null
+++ b/stanza/stanza/models/lemma/vocab.py
@@ -0,0 +1,18 @@
+from collections import Counter
+
+from stanza.models.common.vocab import BaseVocab, BaseMultiVocab
+from stanza.models.common.seq2seq_constant import VOCAB_PREFIX
+
class Vocab(BaseVocab):
    """Lemmatizer vocab: units indexed by descending frequency after the standard prefix symbols."""

    def build_vocab(self):
        """Build id2unit/unit2id from self.data, most frequent units first."""
        counter = Counter(self.data)
        # most_common() sorts by count descending; ties keep first-seen
        # order, matching a stable reverse sort on counts
        self._id2unit = VOCAB_PREFIX + [unit for unit, _ in counter.most_common()]
        self._unit2id = {unit: idx for idx, unit in enumerate(self._id2unit)}
+
class MultiVocab(BaseMultiVocab):
    """A BaseMultiVocab whose values are lemma Vocab objects."""
    @classmethod
    def load_state_dict(cls, state_dict):
        """Rebuild a MultiVocab from a state dict, restoring one Vocab per key."""
        new = cls()
        for k,v in state_dict.items():
            new[k] = Vocab.load_state_dict(v)
        return new
diff --git a/stanza/stanza/models/lemma_classifier/base_trainer.py b/stanza/stanza/models/lemma_classifier/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..160301072fcd3e2dee3205e2bc89cc947c9417a1
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/base_trainer.py
@@ -0,0 +1,114 @@
+
+from abc import ABC, abstractmethod
+import logging
+import os
+from typing import List, Tuple, Any, Mapping
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from stanza.models.common.utils import default_device
+from stanza.models.lemma_classifier import utils
+from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
+from stanza.models.lemma_classifier.evaluate_models import evaluate_model
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+logger = logging.getLogger('stanza.lemmaclassifier')
+
class BaseLemmaClassifierTrainer(ABC):
    """
    Shared training loop for lemma classifiers.

    NOTE(review): subclasses are expected to set self.weighted_loss,
    self.lr, and self.criterion before train() is called (all three are
    read below but never assigned here) -- confirm in the subclasses.
    """
    def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
        """
        If applicable, this function will update the loss function of the LemmaClassifierLSTM model to become BCEWithLogitsLoss.
        The weights are determined by the counts of the classes in the dataset. The weights are inversely proportional to the
        frequency of the class in the set. E.g. classes with lower frequency will have higher weight.
        """
        weights = [0 for _ in label_decoder.keys()]  # each key in the label decoder is one class, we have one weight per class
        total_samples = sum(counts.values())
        for class_idx in counts:
            weights[class_idx] = total_samples / (counts[class_idx] * len(counts))  # weight_i = total / (# examples in class i * num classes)
        weights = torch.tensor(weights)
        logger.info(f"Using weights {weights} for weighted loss.")
        self.criterion = nn.BCEWithLogitsLoss(weight=weights)

    @abstractmethod
    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        """
        Build a model using pieces of the dataset to determine some of the model shape
        """

    def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, train_file: str) -> None:
        """
        Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.

        Args:
            num_epochs (int): Number of training epochs
            save_name (str): Path to file where trained model should be saved.
            args (Mapping): extra options; reads 'batch_size' and 'force'
            eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.
            train_file (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.

        Raises:
            ValueError: if no train_file is supplied
            FileExistsError: if save_name exists and args['force'] is not set
        """
        # Put model on GPU (if possible)
        device = default_device()

        if not train_file:
            raise ValueError("Cannot train model - no train_file supplied!")

        dataset = utils.Dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
        label_decoder = dataset.label_decoder
        upos_to_id = dataset.upos_to_id
        self.output_dim = len(label_decoder)
        logger.info(f"Loaded dataset successfully from {train_file}")
        logger.info(f"Using label decoder: {label_decoder}  Output dimension: {self.output_dim}")
        logger.info(f"Target words: {dataset.target_words}")

        self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words, set(dataset.target_upos))
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.model.to(device)
        logger.info(f"Training model on device: {device}. {next(self.model.parameters()).device}")

        # refuse to clobber an existing model unless 'force' is requested
        if os.path.exists(save_name) and not args.get('force', False):
            raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. Aborting...")

        if self.weighted_loss:
            self.configure_weighted_loss(label_decoder, dataset.counts)

        # Put the criterion on GPU too
        logger.debug(f"Criterion on {next(self.model.parameters()).device}")
        self.criterion = self.criterion.to(next(self.model.parameters()).device)

        # NOTE(review): best_model is never assigned after this; only
        # best_f1 is used to decide when to save a checkpoint
        best_model, best_f1 = None, float("-inf")  # Used for saving checkpoints of the model
        for epoch in range(num_epochs):
            # go over entire dataset with each epoch
            for sentences, positions, upos_tags, labels in tqdm(dataset):
                assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})"

                self.optimizer.zero_grad()
                outputs = self.model(positions, sentences, upos_tags)

                # Compute loss, which is different if using CE or BCEWithLogitsLoss
                if self.weighted_loss:  # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
                    # TODO: three classes?
                    targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device)
                    # should be shape size (batch_size, 2)
                else:  # CELoss accepts target as just raw label
                    targets = labels.to(device)

                loss = self.criterion(outputs, targets)

                loss.backward()
                self.optimizer.step()

            logger.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
            if eval_file:
                # Evaluate model on dev set to see if it should be saved.
                _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)
                logger.info(f"Weighted f1 for model: {f1}")
                if f1 > best_f1:
                    best_f1 = f1
                    self.model.save(save_name)
                    logger.info(f"New best model: weighted f1 score of {f1}.")
            else:
                # no dev set: just save every epoch
                self.model.save(save_name)
+
diff --git a/stanza/stanza/models/lemma_classifier/constants.py b/stanza/stanza/models/lemma_classifier/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..09fa9044cdaf27e61f9bdf419c115286a65909a1
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/constants.py
@@ -0,0 +1,14 @@
from enum import Enum

UNKNOWN_TOKEN = "unk"  # token name for unknown tokens
UNKNOWN_TOKEN_IDX = -1  # custom index we apply to unknown tokens

# TODO: ModelType could just be LSTM and TRANSFORMER
# and then the transformer baseline would have the transformer as another argument
class ModelType(Enum):
    """Which architecture a lemma classifier model uses."""
    LSTM = 1
    TRANSFORMER = 2
    BERT = 3
    ROBERTA = 4

# Default number of examples per batch during training/evaluation
DEFAULT_BATCH_SIZE = 16
\ No newline at end of file
diff --git a/stanza/stanza/models/lemma_classifier/evaluate_many.py b/stanza/stanza/models/lemma_classifier/evaluate_many.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ab2c662c016e711cd42a745ec85d95a1bf7583
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/evaluate_many.py
@@ -0,0 +1,68 @@
+"""
+Utils to evaluate many models of the same type at once
+"""
+import argparse
+import os
+import logging
+
+from stanza.models.lemma_classifier.evaluate_models import main as evaluate_main
+
+
+logger = logging.getLogger('stanza.lemmaclassifier')
+
def evaluate_n_models(path_to_models_dir, args):
    """
    Evaluate every model file found in path_to_models_dir and report averaged metrics.

    Args:
        path_to_models_dir (str): directory containing saved model files
        args: argparse Namespace forwarded to evaluate_models.main; args.save_name
            is overwritten in place for each model file

    Returns:
        dict: averaged per-lemma F1 (as percentages), plus "accuracy" and
        "weighted_f1" (as fractions), averaged over all evaluated models
    """
    # Per-lemma F1 totals are created on demand; previously only "be"/"have"
    # were initialized, so any other lemma in the results raised KeyError
    total_results = {
        "be": 0.0,
        "have": 0.0,
        "accuracy": 0.0,
        "weighted_f1": 0.0
    }
    paths = os.listdir(path_to_models_dir)
    num_models = len(paths)
    for model_path in paths:
        full_path = os.path.join(path_to_models_dir, model_path)
        args.save_name = full_path
        mcc_results, confusion, acc, weighted_f1 = evaluate_main(predefined_args=args)

        for lemma, scores in mcc_results.items():
            # scale per-class F1 to a percentage before accumulating
            total_results[lemma] = total_results.get(lemma, 0.0) + scores["f1"] * 100

        total_results["accuracy"] += acc
        total_results["weighted_f1"] += weighted_f1

    # average every accumulated metric over the number of models
    # (raises ZeroDivisionError if the directory is empty, as before)
    for key in total_results:
        total_results[key] /= num_models

    logger.info(f"Models in {path_to_models_dir} had average weighted f1 of {100 * total_results['weighted_f1']}.\nLemma 'be' had f1: {total_results['be']}\nLemma 'have' had f1: {total_results['have']}.\nAccuracy: {100 * total_results['accuracy']}.\n ({num_models} models evaluated).")
    return total_results
+
+
def main():
    """Parse command-line options and average evaluation metrics over a directory of models."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab")
    arg_parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)")
    arg_parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    arg_parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    arg_parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings")
    arg_parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    arg_parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    arg_parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    arg_parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file")
    arg_parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')")
    arg_parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    arg_parser.add_argument("--eval_file", type=str, help="path to evaluation file")

    # Args specific to several model eval
    arg_parser.add_argument("--base_path", type=str, default=None, help="path to dir for eval")

    options = arg_parser.parse_args()
    evaluate_n_models(options.base_path, options)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/stanza/stanza/models/lemma_classifier/evaluate_models.py b/stanza/stanza/models/lemma_classifier/evaluate_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..9deb98fdf430a9f9d56e7af25dd677f5f4162043
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/evaluate_models.py
@@ -0,0 +1,228 @@
+import os
+import sys
+
+parentdir = os.path.dirname(__file__)
+parentdir = os.path.dirname(parentdir)
+parentdir = os.path.dirname(parentdir)
+sys.path.append(parentdir)
+
+import logging
+import argparse
+import os
+
+from typing import Any, List, Tuple, Mapping
+from collections import defaultdict
+from numpy import random
+
+import torch
+import torch.nn as nn
+
+import stanza
+
+from stanza.models.common.utils import default_device
+from stanza.models.lemma_classifier import utils
+from stanza.models.lemma_classifier.base_model import LemmaClassifier
+from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM
+from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer
+from stanza.utils.confusion import format_confusion
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+
+logger = logging.getLogger('stanza.lemmaclassifier')
+
+
def get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: Mapping[int, Mapping[int, int]]) -> float:
    """
    Compute the support-weighted F1 score over an evaluation set.

    Each class's F1 is weighted by its number of gold examples (the row sum of
    the confusion matrix), so classes with more evaluation examples contribute
    more to the overall score.
    """
    weighted_sum = 0.0
    support_total = 0

    for class_id, scores in mcc_results.items():
        # support = number of gold examples for this class
        support = sum(confusion.get(class_id).values())
        weighted_sum += scores.get("f1") * support
        support_total += support

    return weighted_sum / support_total
+
+
def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[Any], label_decoder: Mapping, verbose=True):
    """
    Score predicted tags against gold tags, computing precision, recall, and F1 per class.

    Precision = true positives / (true positives + false positives)
    Recall = true positives / (true positives + false negatives)
    F1 = 2 * (Precision * Recall) / (Precision + Recall)

    Returns:
        1. Multi class result dictionary, where each class is a key and maps to another map of its F1, precision, and recall scores.
           e.g. multiclass_results[0]["precision"] would give class 0's precision.
        2. Confusion matrix, where each key is a gold tag and its value is another map with a key of the predicted tag with value of that (gold, pred) count.
           e.g. confusion[0][1] = 6 would mean that for gold tag 0, the model predicted tag 1 a total of 6 times.
        3. The support-weighted F1 over all classes.
    """
    assert len(gold_tag_sequences) == len(pred_tag_sequences), \
        f"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}"

    # tally (gold, pred) pairs, translating class ids back to their labels
    confusion = defaultdict(lambda: defaultdict(int))
    reverse_label_decoder = {idx: label for label, idx in label_decoder.items()}
    for gold, pred in zip(gold_tag_sequences, pred_tag_sequences):
        confusion[reverse_label_decoder[gold]][reverse_label_decoder[pred]] += 1

    multi_class_result = defaultdict(lambda: defaultdict(float))
    for gold_tag in confusion:
        true_positives = confusion.get(gold_tag, {}).get(gold_tag, 0)
        # column sum: how often gold_tag was predicted, over all gold rows
        predicted_count = sum(confusion.get(other, {}).get(gold_tag, 0) for other in confusion)
        # row sum: how often gold_tag actually occurred
        gold_count = sum(confusion.get(gold_tag, {}).values())

        prec = true_positives / predicted_count if predicted_count else 0.0
        recall = true_positives / gold_count if gold_count else 0.0
        f1 = 2 * (prec * recall) / (prec + recall) if (prec + recall) else 0.0

        multi_class_result[gold_tag] = {
            "precision": prec,
            "recall": recall,
            "f1": f1
        }

    if verbose:
        for lemma in multi_class_result:
            logger.info(f"Lemma '{lemma}' had precision {100 * multi_class_result[lemma]['precision']}, recall {100 * multi_class_result[lemma]['recall']} and F1 score of {100 * multi_class_result[lemma]['f1']}")

    weighted_f1 = get_weighted_f1(multi_class_result, confusion)

    return multi_class_result, confusion, weighted_f1
+
+
def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: Optional[List[List[int]]] = None) -> torch.Tensor:
    """
    Run a LemmaClassifierLSTM or LemmaClassifierWithTransformer on a batch of examples.

    Args:
        model (LemmaClassifier): A trained LemmaClassifier that predicts on a target token.
        position_indices (Tensor[int]): A tensor of the (zero-indexed) position of the target token in each sentence of the batch.
        sentences (List[List[str]]): A list of lists of the tokenized strings of the input sentences.
        upos_tags (List[List[int]], optional): UPOS tag ids per sentence; defaults to an empty list.

    Returns:
        torch.Tensor: shape (batch_size,), the predicted class index for each example.
    """
    # avoid the mutable-default-argument trap; the effective default is still []
    if upos_tags is None:
        upos_tags = []
    with torch.no_grad():
        logits = model(position_indices, sentences, upos_tags)  # (batch_size, output_size)
        predicted_class = torch.argmax(logits, dim=1)  # (batch_size,)

    return predicted_class
+
+
def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_training: bool = False) -> Tuple[Mapping, Mapping, float, float]:
    """
    Helper function for model evaluation

    Args:
        model (LemmaClassifierLSTM or LemmaClassifierWithTransformer): A trained LemmaClassifier instance to evaluate.
        eval_path (str): Path to the saved evaluation dataset.
        verbose (bool, optional): True if `evaluate_sequences()` should print the F1, Precision, and Recall for each class. Defaults to True.
        is_training (bool, optional): Whether the model is in training mode. If the model is training, we do not change it to eval mode.

    Returns:
        1. Multi-class results (Mapping[int, Mapping[str, float]]): first map has keys as the classes (lemma indices) and value is
           another map with key of "f1", "precision", or "recall" with corresponding values.
        2. Confusion Matrix (Mapping[int, Mapping[int, int]]): A confusion matrix with keys equal to the index of the gold tag, and a value of the
           map with the key as the predicted tag and corresponding count of that (gold, pred) pair.
        3. Accuracy (float): the total accuracy (num correct / total examples) across the evaluation set.
           NOTE(review): `correct` is accumulated with torch.sum, so this is actually a 0-dim tensor — confirm callers are fine with that.
        4. Weighted F1 (float): per-class F1 averaged with each class weighted by its number of gold examples.
    """
    # move the model to the evaluation device
    device = default_device()
    model.to(device)

    if not is_training:
        model.eval()  # set to eval mode

    # load in eval data; reuse the model's label_decoder so class indices line up
    dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False)

    logger.info(f"Evaluating on evaluation file {eval_path}")

    correct, total = 0, 0
    gold_tags, pred_tags = dataset.labels, []

    # run eval on each batch from the dataset, flattening predictions as we go
    for sentences, pos_indices, upos_tags, labels in tqdm(dataset, "Evaluating examples from data file"):
        pred = model_predict(model, pos_indices, sentences, upos_tags)  # Pred should be size (batch_size, )
        correct_preds = pred == labels.to(device)
        correct += torch.sum(correct_preds)
        total += len(correct_preds)
        pred_tags += pred.tolist()

    logger.info("Finished evaluating on dataset. Computing scores...")
    accuracy = correct / total

    # gold_tags is taken directly from the dataset; pred_tags was flattened batch by batch above
    mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, dataset.label_decoder, verbose=verbose)
    if verbose:
        logger.info(f"Accuracy: {accuracy} ({correct}/{total})")
        logger.info(f"Label decoder: {dataset.label_decoder}")

    return mc_results, confusion, accuracy, weighted_f1
+
+
def main(args=None, predefined_args=None):
    """
    Entry point for evaluating a saved LemmaClassifier on an evaluation file.

    Args:
        args: list of command-line strings for argparse (None -> sys.argv)
        predefined_args: an already-parsed Namespace; when given, it takes
            precedence and no command-line parsing happens

    Returns:
        (mcc_results, confusion, accuracy, weighted_f1) as produced by evaluate_model
    """
    # TODO: can unify this script with train_lstm_model.py?
    # TODO: can save the model type in the model .pt, then
    # automatically figure out what type of model we are using by
    # looking in the file
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab")
    parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)")
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--eval_file", type=str, help="path to evaluation file")

    args = parser.parse_args(args) if not predefined_args else predefined_args

    # fixed copy-paste: this is the evaluation script, not the training script
    logger.info("Running evaluation script with the following args:")
    args = vars(args)
    for arg in args:
        logger.info(f"{arg}: {args[arg]}")
    logger.info("------------------------------------------------------------")

    logger.info(f"Attempting evaluation of model from {args['save_name']} on file {args['eval_file']}")
    model = LemmaClassifier.load(args['save_name'], args)

    mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, args['eval_file'])

    logger.info(f"MCC Results: {dict(mcc_results)}")
    logger.info("______________________________________________")
    # plain string with a lazy % argument; the previous f-string prefix was
    # unnecessary and misleading combined with %-style formatting
    logger.info("Confusion:\n%s", format_confusion(confusion))
    logger.info("______________________________________________")
    logger.info(f"Accuracy: {acc}")
    logger.info("______________________________________________")
    logger.info(f"Weighted f1: {weighted_f1}")

    return mcc_results, confusion, acc, weighted_f1
+
+
+if __name__ == "__main__":
+ main()
diff --git a/stanza/stanza/models/lemma_classifier/prepare_dataset.py b/stanza/stanza/models/lemma_classifier/prepare_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a49c0c13185d8abbcadd3a2a0cd0eb3d7163b5
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/prepare_dataset.py
@@ -0,0 +1,125 @@
+import argparse
+import json
+import os
+import re
+
+import stanza
+from stanza.models.lemma_classifier import utils
+
+from typing import List, Tuple, Any
+
+"""
+The code in this file processes a CoNLL dataset by taking its sentences and filtering out all sentences that do not contain the target token.
+Furthermore, it will store tuples of the Stanza document object, the position index of the target token, and its lemma.
+"""
+
+
def load_doc_from_conll_file(path: str):
    """
    Load a Stanza Document object from a path to a CoNLL file containing annotated sentences.
    """
    # fixed a four-quote typo ("""" ) that left a stray quote at the start of the docstring
    # NOTE(review): relies on stanza.utils.conll being reachable from the top-level
    # `import stanza` — confirm the package exposes its utils submodule
    return stanza.utils.conll.CoNLL.conll2doc(path)
+
+
class DataProcessor():
    """
    Filters sentences of a CoNLL document down to examples containing a target
    token, recording each example's tokens, UPOS tags, target index, and lemma.
    """

    def __init__(self, target_word: str, target_upos: List[str], allowed_lemmas: str):
        """
        Args:
            target_word (str): regex matched (fullmatch) against each token's text
            target_upos (List[str]): UPOS tags the target token must have
            allowed_lemmas (str): regex for lemmas to keep; ".*" keeps all
        """
        self.target_word = target_word
        self.target_word_regex = re.compile(target_word)
        self.target_upos = target_upos
        self.allowed_lemmas = re.compile(allowed_lemmas)

    def keep_sentence(self, sentence):
        """Return True if any word matches the target word regex and target UPOS."""
        for word in sentence.words:
            if self.target_word_regex.fullmatch(word.text) and word.upos in self.target_upos:
                return True
        return False

    def find_all_occurrences(self, sentence) -> List[int]:
        """
        Finds all occurrences of self.target_word in tokens and returns the index(es) of such occurrences.
        """
        occurrences = []
        for idx, token in enumerate(sentence.words):
            if self.target_word_regex.fullmatch(token.text) and token.upos in self.target_upos:
                occurrences.append(idx)
        return occurrences

    @staticmethod
    def write_output_file(save_name, target_upos, sentences):
        """Write the collected examples to save_name as JSON, one sentence per line."""
        with open(save_name, "w+", encoding="utf-8") as output_f:
            output_f.write("{\n")
            output_f.write(' "upos": %s,\n' % json.dumps(target_upos))
            output_f.write(' "sentences": [')
            wrote_sentence = False
            for sentence in sentences:
                if not wrote_sentence:
                    output_f.write("\n ")
                    wrote_sentence = True
                else:
                    output_f.write(",\n ")
                output_f.write(json.dumps(sentence))
            output_f.write("\n ]\n}\n")

    def process_document(self, doc, save_name: str) -> List[dict]:
        """
        Takes any sentence from `doc` that meets the condition of `keep_sentence` and writes its tokens, index of target word, and lemma to `save_name`

        Sentences that meet `keep_sentence` and contain `self.target_word` multiple times have each instance in a different example in the output file.

        Args:
            doc (Stanza.doc): Document object that represents the file to be analyzed
            save_name (str): Path to the file for storing output; falsy values skip writing

        Returns:
            List[dict]: the collected examples (also written to save_name when given)
        """
        sentences = []
        for sentence in doc.sentences:
            # a sentence is kept when it contains the target word with the target upos;
            # each matching occurrence becomes its own example
            if self.keep_sentence(sentence):
                tokens = [token.text for token in sentence.words]
                for idx in self.find_all_occurrences(sentence):
                    if self.allowed_lemmas.fullmatch(sentence.words[idx].lemma):
                        # record the tokens, their upos tags, the target token
                        # index, and the target lemma
                        sentences.append({
                            "words": tokens,
                            "upos_tags": [word.upos for word in sentence.words],
                            "index": idx,
                            "lemma": sentence.words[idx].lemma
                        })

        if save_name:
            self.write_output_file(save_name, self.target_upos, sentences)
        return sentences
+
def main(args=None):
    """Command-line entry point: extract target-token examples from a CoNLL file."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--conll_path", type=str, default=os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu"), help="path to the conll file to translate")
    parser.add_argument("--target_word", type=str, default="'s", help="Token to classify on, e.g. 's.")
    parser.add_argument("--target_upos", type=str, default="AUX", help="upos on target token")
    parser.add_argument("--output_path", type=str, default="test_output.txt", help="Path for output file")
    parser.add_argument("--allowed_lemmas", type=str, default=".*", help="A regex for allowed lemmas. If not set, all lemmas are allowed")

    parsed = parser.parse_args(args)

    # echo every option for visibility
    for name, value in vars(parsed).items():
        print(f"{name}: {value}")

    doc = load_doc_from_conll_file(parsed.conll_path)
    processor = DataProcessor(target_word=parsed.target_word, target_upos=[parsed.target_upos], allowed_lemmas=parsed.allowed_lemmas)

    return processor.process_document(doc, parsed.output_path)
+
+if __name__ == "__main__":
+ main()
diff --git a/stanza/stanza/models/lemma_classifier/train_lstm_model.py b/stanza/stanza/models/lemma_classifier/train_lstm_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1432a7dba28c917596f03957183ce48dbe13dc1e
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/train_lstm_model.py
@@ -0,0 +1,147 @@
+"""
+The code in this file works to train a lemma classifier for 's
+"""
+
+import argparse
+import logging
+import os
+
+import torch
+import torch.nn as nn
+
+from stanza.models.common.foundation_cache import load_pretrain
+from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer
+from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
+from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM
+
+logger = logging.getLogger('stanza.lemmaclassifier')
+
class LemmaClassifierTrainer(BaseLemmaClassifierTrainer):
    """
    Class to assist with training a LemmaClassifierLSTM
    """

    def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = False, charlm_forward_file: str = None, charlm_backward_file: str = None, lr: float = 0.001, loss_func: str = None):
        """
        Initializes the LemmaClassifierTrainer class.

        Args:
            model_args (dict): Various model shape parameters
            embedding_file (str): What word embeddings file to use. Use a Stanza pretrain .pt
            use_charlm (bool, optional): Whether to use charlm embeddings as well. Defaults to False.
            charlm_forward_file (str): Path to the forward pass embeddings for the charlm
            charlm_backward_file (str): Path to the backward pass embeddings for the charlm
            lr (float): Learning rate, defaults to 0.001.
            loss_func (str): Which loss function to use (either 'ce' or 'weighted_bce')

        Raises:
            FileNotFoundError: If the forward charlm file is not present
            FileNotFoundError: If the backward charlm file is not present
            ValueError: If loss_func is not one of 'ce' or 'weighted_bce'
        """
        super().__init__()

        self.model_args = model_args

        # Load word embeddings
        pt = load_pretrain(embedding_file)
        self.pt_embedding = pt

        # Charlm files are only validated here; the model itself loads them later
        if use_charlm and charlm_forward_file is not None and not os.path.exists(charlm_forward_file):
            raise FileNotFoundError(f"Could not find forward charlm file: {charlm_forward_file}")
        if use_charlm and charlm_backward_file is not None and not os.path.exists(charlm_backward_file):
            raise FileNotFoundError(f"Could not find backward charlm file: {charlm_backward_file}")

        # TODO: just pass around the args instead
        self.use_charlm = use_charlm
        self.charlm_forward_file = charlm_forward_file
        self.charlm_backward_file = charlm_backward_file
        self.lr = lr

        # Find loss function
        if loss_func == "ce":
            self.criterion = nn.CrossEntropyLoss()
            self.weighted_loss = False
            logger.debug("Using CE loss")
        elif loss_func == "weighted_bce":
            self.criterion = nn.BCEWithLogitsLoss()
            self.weighted_loss = True  # used to add weights during train time.
            logger.debug("Using Weighted BCE loss")
        else:
            raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        # Hook called by BaseLemmaClassifierTrainer to construct the concrete model.
        # NOTE(review): assumes the base trainer sets self.output_dim before calling this — confirm
        return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos,
                                   use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file)
+
def build_argparse():
    """
    Build the argument parser for the LSTM lemma classifier training script.

    Returns:
        argparse.ArgumentParser with all training options registered
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.")
    parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.')
    parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    # num_epochs must be an int: it is used as a range() bound in the training loop
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file')
    return parser
+
def main(args=None, predefined_args=None):
    """
    Command-line entry point for training a LemmaClassifierLSTM.

    Args:
        args: list of command-line strings for argparse (None -> sys.argv)
        predefined_args: an already-parsed Namespace; when given, argparse is skipped

    Returns:
        the LemmaClassifierTrainer after training completes
    """
    parser = build_argparse()
    args = parser.parse_args(args) if predefined_args is None else predefined_args

    # unpack the options used directly below; the full args dict is also
    # handed to the trainer as model_args
    wordvec_pretrain_file = args.wordvec_pretrain_file
    use_charlm = args.use_charlm
    charlm_forward_file = args.charlm_forward_file
    charlm_backward_file = args.charlm_backward_file
    # NOTE(review): upos_emb_dim, use_attention and num_heads are unpacked but
    # never used in this function; presumably the model reads them from args — confirm
    upos_emb_dim = args.upos_emb_dim
    use_attention = args.attn
    num_heads = args.num_heads
    save_name = args.save_name
    lr = args.lr
    num_epochs = args.num_epochs
    train_file = args.train_file
    weighted_loss = args.weighted_loss
    eval_file = args.eval_file

    args = vars(args)

    # refuse to clobber an existing model unless --force was given
    if os.path.exists(save_name) and not args.get('force', False):
        raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...")
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.")

    logger.info("Running training script with the following args:")
    for arg in args:
        logger.info(f"{arg}: {args[arg]}")
    logger.info("------------------------------------------------------------")

    trainer = LemmaClassifierTrainer(model_args=args,
                                     embedding_file=wordvec_pretrain_file,
                                     use_charlm=use_charlm,
                                     charlm_forward_file=charlm_forward_file,
                                     charlm_backward_file=charlm_backward_file,
                                     lr=lr,
                                     loss_func="weighted_bce" if weighted_loss else "ce",
                                     )

    trainer.train(
        num_epochs=num_epochs, save_name=save_name, args=args, eval_file=eval_file, train_file=train_file
    )

    return trainer
+
+if __name__ == "__main__":
+ main()
+
diff --git a/stanza/stanza/models/lemma_classifier/train_many.py b/stanza/stanza/models/lemma_classifier/train_many.py
new file mode 100644
index 0000000000000000000000000000000000000000..cefe7b93f6c18a531c2154daa375a7d3155d3da3
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/train_many.py
@@ -0,0 +1,155 @@
+"""
+Utils for training and evaluating multiple models simultaneously
+"""
+
+import argparse
+import os
+
+from stanza.models.lemma_classifier.train_lstm_model import main as train_lstm_main
+from stanza.models.lemma_classifier.train_transformer_model import main as train_tfmr_main
+from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
+
+
# Hyperparameter sweep values, keyed by the --change_param option name used
# in train_n_models below
change_params_map = {
    "lstm_layer": [16, 32, 64, 128, 256, 512],
    "upos_emb_dim": [5, 10, 20, 30],
    "training_size": [150, 300, 450, 600, 'full'],
}  # TODO: Add attention
+
def train_n_models(num_models: int, base_path: str, args):
    """
    Train `num_models` LSTM lemma classifiers for the configuration chosen by args.change_param.

    Depending on args.change_param this either sweeps one hyperparameter over
    the values in `change_params_map` (training num_models copies per value),
    or trains num_models copies of a fixed configuration. Saved models go under
    base_path; args is mutated in place (save_name etc.) before each run.
    """
    if args.change_param == "lstm_layer":
        for num_layers in change_params_map.get("lstm_layer", None):
            for i in range(num_models):
                new_save_name = os.path.join(base_path, f"{num_layers}_{i}.pt")
                args.save_name = new_save_name
                args.hidden_dim = num_layers
                train_lstm_main(predefined_args=args)

    if args.change_param == "upos_emb_dim":
        # bug fix: change_params_map is a dict and was previously *called* here
        # (change_params_map("upos_emb_dim", None)), which raised TypeError as
        # soon as this branch ran
        for upos_dim in change_params_map.get("upos_emb_dim", None):
            for i in range(num_models):
                new_save_name = os.path.join(base_path, f"dim_{upos_dim}_{i}.pt")
                args.save_name = new_save_name
                args.upos_emb_dim = upos_dim
                train_lstm_main(predefined_args=args)

    if args.change_param == "training_size":
        for size in change_params_map.get("training_size", None):
            for i in range(num_models):
                new_save_name = os.path.join(base_path, f"{size}_examples_{i}.pt")
                # NOTE(review): `size` is not used to pick the training file, so every
                # sweep value trains on the same combined_train.txt — confirm intent
                new_train_file = os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt")
                args.save_name = new_save_name
                args.train_file = new_train_file
                train_lstm_main(predefined_args=args)

    if args.change_param == "base":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"lstm_model_{i}.pt")
            args.save_name = new_save_name
            args.weighted_loss = False
            train_lstm_main(predefined_args=args)

            # train a weighted-loss twin of each base model
            if not args.weighted_loss:
                args.weighted_loss = True
                new_save_name = os.path.join(base_path, f"lstm_model_wloss_{i}.pt")
                args.save_name = new_save_name
                train_lstm_main(predefined_args=args)

    if args.change_param == "base_charlm":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"lstm_charlm_{i}.pt")
            args.save_name = new_save_name
            train_lstm_main(predefined_args=args)

    if args.change_param == "base_charlm_upos":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"lstm_charlm_upos_{i}.pt")
            args.save_name = new_save_name
            train_lstm_main(predefined_args=args)

    if args.change_param == "base_upos":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"lstm_upos_{i}.pt")
            args.save_name = new_save_name
            train_lstm_main(predefined_args=args)

    if args.change_param == "attn_model":
        for i in range(num_models):
            new_save_name = os.path.join(base_path, f"attn_model_{args.num_heads}_heads_{i}.pt")
            args.save_name = new_save_name
            train_lstm_main(predefined_args=args)
+
def train_n_tfmrs(num_models: int, base_path: str, args):
    """
    Train `num_models` transformer baselines (bert or roberta), each in a CE
    and a weighted-BCE variant, saving the models under base_path.
    """
    if args.multi_train_type != "tfmr":
        return
    if args.change_param not in ("bert", "roberta"):
        return

    # the saved file names are prefixed by the transformer family being swept
    prefix = args.change_param
    for i in range(num_models):
        # cross-entropy variant
        args.save_name = os.path.join(base_path, f"{prefix}_{i}.pt")
        args.loss_fn = "ce"
        train_tfmr_main(predefined_args=args)

        # weighted-BCE variant
        args.save_name = os.path.join(base_path, f"{prefix}_wloss_{i}.pt")
        args.loss_fn = "weighted_bce"
        train_tfmr_main(predefined_args=args)
+
+
def main():
    """
    Entry point: dispatch multi-model training to the LSTM or transformer helper
    based on --multi_train_type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.")
    parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.')
    parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    # num_epochs must be an int: the downstream training loop uses it as a range() bound
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    # Tfmr-specific args
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')")
    # Multi-model train args
    parser.add_argument("--multi_train_type", type=str, default="lstm", help="Whether you are attempting to multi-train an LSTM or transformer")
    parser.add_argument("--multi_train_count", type=int, default=5, help="Number of each model to build")
    parser.add_argument("--base_path", type=str, default=None, help="Path to start generating model type for.")
    parser.add_argument("--change_param", type=str, default=None, help="Which hyperparameter to change when training")

    args = parser.parse_args()

    if args.multi_train_type == "lstm":
        train_n_models(num_models=args.multi_train_count,
                       base_path=args.base_path,
                       args=args)
    elif args.multi_train_type == "tfmr":
        train_n_tfmrs(num_models=args.multi_train_count,
                      base_path=args.base_path,
                      args=args)
    else:
        raise ValueError(f"Improper input {args.multi_train_type}")
+
+if __name__ == "__main__":
+ main()
diff --git a/stanza/stanza/models/lemma_classifier/train_transformer_model.py b/stanza/stanza/models/lemma_classifier/train_transformer_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f6e2adb8803acb273aa9f760ea17d8c756fff1d
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/train_transformer_model.py
@@ -0,0 +1,130 @@
+"""
+This file contains code used to train a baseline transformer model to classify on a lemma of a particular token.
+"""
+
+import argparse
+import os
+import sys
+import logging
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer
+from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
+from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer
+from stanza.models.common.utils import default_device
+
+logger = logging.getLogger('stanza.lemmaclassifier')
+
class TransformerBaselineTrainer(BaseLemmaClassifierTrainer):
    """
    Class to assist with training a baseline transformer model to classify on token lemmas.
    To find the model spec, refer to `model.py` in this directory.
    """

    def __init__(self, model_args: dict, transformer_name: str = "roberta", loss_func: str = "ce", lr: float = 0.001):
        """
        Creates the Trainer object

        Args:
            model_args (dict): arguments passed through to the underlying model when it is built
            transformer_name (str, optional): What kind of transformer to use for embeddings. Defaults to "roberta".
            loss_func (str, optional): Which loss function to use (either 'ce' or 'weighted_bce'). Defaults to "ce".
            lr (float, optional): learning rate for the optimizer. Defaults to 0.001.
                (annotation fixed: the learning rate is a float, not an int)

        Raises:
            ValueError: if loss_func is not a recognized loss function name
        """
        super().__init__()

        self.model_args = model_args

        # Find loss function
        if loss_func == "ce":
            self.criterion = nn.CrossEntropyLoss()
            self.weighted_loss = False
        elif loss_func == "weighted_bce":
            self.criterion = nn.BCEWithLogitsLoss()
            self.weighted_loss = True  # used to add weights during train time.
        else:
            raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

        self.transformer_name = transformer_name
        self.lr = lr

    def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> optim.Optimizer:
        """
        Sets learning rates for each layer of the model.
        Currently, the model has the transformer layer and the MLP layer, so these are tweakable.

        Returns (optim.Optimizer): An Adam optimizer with the learning rates adjusted per layer.
            (annotation fixed: the previous annotation `torch.optim` was the module itself, not a type)

        Currently unused - could be refactored into the parent class's train method,
        or the parent class could call a build_optimizer and this subclass would use the optimizer
        """
        # NOTE: parameters whose names mention neither 'transformer' nor 'mlp'
        # are silently left out of the optimizer
        transformer_params, mlp_params = [], []
        for name, param in self.model.named_parameters():
            if 'transformer' in name:
                transformer_params.append(param)
            elif 'mlp' in name:
                mlp_params.append(param)
        optimizer = optim.Adam([
            {"params": transformer_params, "lr": transformer_lr},
            {"params": mlp_params, "lr": mlp_lr}
        ])
        return optimizer

    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        """Construct the transformer-based classifier; upos_to_id and known_words are unused by this model type."""
        return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder, target_words=target_words, target_upos=target_upos)
+
+
def main(args=None, predefined_args=None):
    """
    Command line entry point: parse args and train a transformer baseline lemma classifier.

    Args:
        args: list of command line tokens to parse (None means use sys.argv)
        predefined_args: an already-parsed namespace; when provided, parsing is skipped

    Returns:
        TransformerBaselineTrainer: the trainer after running training

    Raises:
        FileExistsError: if the save file already exists and --force was not given
        FileNotFoundError: if the training file does not exist
        ValueError: for an unknown --model_type, or a missing --bert_model with model_type 'transformer'
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "saved_models", "big_model_roberta_weighted_loss.pt"), help="Path to model save file")
    # NOTE(review): num_epochs is parsed as float - presumably to allow fractional epochs; confirm
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the optimizer.")
    parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file')

    args = parser.parse_args(args) if predefined_args is None else predefined_args

    # pull out the values used directly below before converting args to a dict
    save_name = args.save_name
    num_epochs = args.num_epochs
    train_file = args.train_file
    loss_fn = args.loss_fn
    eval_file = args.eval_file
    lr = args.lr

    args = vars(args)

    # resolve the shorthand model_type to a concrete HF model name
    if args['model_type'] == 'bert':
        args['bert_model'] = 'bert-base-uncased'
    elif args['model_type'] == 'roberta':
        args['bert_model'] = 'roberta-base'
    elif args['model_type'] == 'transformer':
        if args['bert_model'] is None:
            raise ValueError("Need to specify a bert_model for model_type transformer!")
    else:
        raise ValueError("Unknown model type " + args['model_type'])

    # refuse to clobber an existing model unless --force was given
    if os.path.exists(save_name) and not args.get('force', False):
        raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...")
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.")

    logger.info("Running training script with the following args:")
    for arg in args:
        logger.info(f"{arg}: {args[arg]}")
    logger.info("------------------------------------------------------------")

    trainer = TransformerBaselineTrainer(model_args=args, transformer_name=args['bert_model'], loss_func=loss_fn, lr=lr)

    trainer.train(num_epochs=num_epochs, save_name=save_name, train_file=train_file, args=args, eval_file=eval_file)
    return trainer

if __name__ == "__main__":
    main()
diff --git a/stanza/stanza/models/lemma_classifier/transformer_model.py b/stanza/stanza/models/lemma_classifier/transformer_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e6f09ec91e004477a849cadf2d5c0536308c0b7
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/transformer_model.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import os
+import sys
+import logging
+
+from transformers import AutoTokenizer, AutoModel
+from typing import Mapping, List, Tuple, Any
+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence
+from stanza.models.common.bert_embedding import extract_bert_embeddings
+from stanza.models.lemma_classifier.base_model import LemmaClassifier
+from stanza.models.lemma_classifier.constants import ModelType
+
+logger = logging.getLogger('stanza.lemmaclassifier')
+
class LemmaClassifierWithTransformer(LemmaClassifier):
    """
    Lemma classifier backed by a HuggingFace transformer encoder plus a small MLP head.
    """
    def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping, target_words: set, target_upos: set):
        """
        Model architecture:

        Use a transformer (BERT or RoBERTa) to extract contextual embedding over a sentence.
        Get the embedding for the word that is to be classified on, and feed the embedding
        as input to an MLP classifier that has 2 linear layers, and a prediction head.

        Args:
            model_args (dict): args for the model
            output_dim (int): Dimension of the output from the MLP
            transformer_name (str): name of the HF transformer to use
            label_decoder (dict): a map of the labels available to the model
            target_words (set(str)): a set of the words which might need lemmatization
            target_upos (set): the upos tags for which lemmatization applies
        """
        super(LemmaClassifierWithTransformer, self).__init__(label_decoder, target_words, target_upos)
        self.model_args = model_args

        # Choose transformer
        self.transformer_name = transformer_name
        # add_prefix_space lets the (Ro)BERTa-style BPE tokenizer accept
        # pre-tokenized input (we pass lists of words, not raw strings)
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True)
        # registered as an unsaved module: the transformer weights are dropped
        # from checkpoints (see get_save_dict) and reloaded from HF instead
        self.add_unsaved_module("transformer", AutoModel.from_pretrained(transformer_name))
        config = self.transformer.config

        embedding_size = config.hidden_size

        # define an MLP layer
        self.mlp = nn.Sequential(
            nn.Linear(embedding_size, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def get_save_dict(self):
        """Build a checkpoint dict, dropping parameters of unsaved modules (the transformer)."""
        save_dict = {
            "params": self.state_dict(),
            "label_decoder": self.label_decoder,
            "target_words": list(self.target_words),
            "target_upos": list(self.target_upos),
            "model_type": self.model_type().name,
            "args": self.model_args,
        }
        # drop the (large) transformer weights - they can be re-downloaded at load time
        skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)]
        for k in skipped:
            del save_dict["params"][k]
        return save_dict

    def convert_tags(self, upos_tags: List[List[str]]):
        # this model ignores upos tags, so there is nothing to convert
        return None

    def forward(self, idx_positions: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):
        """
        Computes the forward pass of the transformer baselines

        Args:
            idx_positions (List[int]): A list of the position index of the target token for lemmatization classification in each sentence.
            sentences (List[List[str]]): A list of the token-split sentences of the input data.
            upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence - not used in this model, here for compatibility

        Returns:
            torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences.
        """
        device = next(self.transformer.parameters()).device
        bert_embeddings = extract_bert_embeddings(self.transformer_name, self.tokenizer, self.transformer, sentences, device,
                                                  keep_endpoints=False, num_layers=1, detach=True)
        # pick out the embedding of the target word in each sentence
        embeddings = [emb[idx] for idx, emb in zip(idx_positions, bert_embeddings)]
        # NOTE(review): [:, :, 0] appears to select the single requested layer
        # (num_layers=1) out of the stacked embedding - confirm against the
        # shape returned by extract_bert_embeddings
        embeddings = torch.stack(embeddings, dim=0)[:, :, 0]
        # pass to the MLP
        output = self.mlp(embeddings)
        return output

    def model_type(self):
        return ModelType.TRANSFORMER
diff --git a/stanza/stanza/models/lemma_classifier/utils.py b/stanza/stanza/models/lemma_classifier/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb95bde7c2fc41e24a684340e269d32c6c4a00cc
--- /dev/null
+++ b/stanza/stanza/models/lemma_classifier/utils.py
@@ -0,0 +1,173 @@
+from collections import Counter
+import json
+import logging
+import os
+import random
+from typing import List, Tuple, Any, Mapping
+
+import stanza
+import torch
+
+from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
+
+logger = logging.getLogger('stanza.lemmaclassifier')
+
class Dataset:
    """
    Loads and batches lemma classifier data from a JSON file.
    """
    def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True):
        """
        Loads a JSON data file and prepares batches of tokenized sentences,
        target token indices, UPOS tag IDs, and gold labels.

        Args:
            data_path (str): Path to a JSON data file with a 'sentences' list (each entry
                holding "words", "index", "upos_tags", "lemma") and a 'upos' entry.
            batch_size (int): Size of each batch of examples
            get_counts (optional, bool): Whether to accumulate a Counter of label ID -> frequency
            label_decoder (optional, dict): An existing label -> ID map to extend; labels in
                this dataset missing from the map get fresh IDs appended
            shuffle (optional, bool): Whether iteration order is reshuffled on each pass

        Attributes built:
            sentences (List[List[str]]): tokenized sentences
            indices (List[int]): index of the target token in each sentence
            upos_ids (List[List[int]]): UPOS tag IDs per token (unpadded; padded later)
            labels (List[int]): label ID of the target token's lemma
            counts (Counter): label ID -> count (populated only when get_counts is True)
            label_decoder (Mapping[str, int]): label -> label ID
            upos_to_id (Mapping[str, int]): UPOS tag -> UPOS ID
            target_upos: the 'upos' entry of the data file
            known_words (List[str]): sorted, lowercased list of every word seen
            target_words (set): lowercased words occurring at target indices

        Raises:
            FileNotFoundError: if data_path is None or does not exist
            ValueError: if a sentence entry is missing any required field
        """

        if data_path is None or not os.path.exists(data_path):
            raise FileNotFoundError(f"Data file {data_path} could not be found.")

        if label_decoder is None:
            label_decoder = {}
        else:
            # if labels in the test set aren't in the original model,
            # the model will never predict those labels,
            # but we can still use those labels in a confusion matrix
            # (copy so the caller's mapping is not mutated)
            label_decoder = dict(label_decoder)

        logger.debug("Final label decoder: %s Should be strings to ints", label_decoder)

        # words which we are analyzing
        target_words = set()

        # all known words in the dataset, not just target words
        known_words = set()

        # NOTE(review): "r+" opens the file read/write although it is only read;
        # plain "r" would suffice - confirm nothing relies on write access
        with open(data_path, "r+", encoding="utf-8") as fin:
            sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), {}

            input_json = json.load(fin)
            sentences_data = input_json['sentences']
            self.target_upos = input_json['upos']

            for idx, sentence in enumerate(sentences_data):
                # TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatibility reasons
                words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
                if None in [words, target_idx, upos_tags, label]:
                    raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")

                label_id = label_decoder.get(label, None)
                if label_id is None:
                    label_decoder[label] = len(label_decoder) # create a new ID for the unknown label

                converted_upos_tags = [] # convert upos tags to upos IDs
                for upos_tag in upos_tags:
                    if upos_tag not in upos_to_id:
                        upos_to_id[upos_tag] = len(upos_to_id) # create a new ID for the unknown UPOS tag
                    converted_upos_tags.append(upos_to_id[upos_tag])

                sentences.append(words)
                indices.append(target_idx)
                upos_ids.append(converted_upos_tags)
                labels.append(label_decoder[label])

                if get_counts:
                    counts[label_decoder[label]] += 1

                target_words.add(words[target_idx])
                known_words.update(words)

        self.sentences = sentences
        self.indices = indices
        self.upos_ids = upos_ids
        self.labels = labels

        self.counts = counts
        self.label_decoder = label_decoder
        self.upos_to_id = upos_to_id

        self.batch_size = batch_size
        self.shuffle = shuffle

        # sorted before lowercasing so the ordering is deterministic
        self.known_words = [x.lower() for x in sorted(known_words)]
        self.target_words = set(x.lower() for x in target_words)

    def __len__(self):
        """
        Number of batches, rounded up to nearest batch
        """
        return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0)

    def __iter__(self):
        """
        Yield batches of (sentences, target index tensor, upos ID lists, label tensor).

        When shuffle is True, a fresh random permutation is drawn on each pass.
        The upos IDs stay as plain lists since they are padded downstream.
        """
        num_sentences = len(self.sentences)
        indices = list(range(num_sentences))
        if self.shuffle:
            random.shuffle(indices)
        for i in range(self.__len__()):
            batch_start = self.batch_size * i
            # the final batch may be smaller than batch_size
            batch_end = min(batch_start + self.batch_size, num_sentences)

            batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]]
            batch_indices = torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]])
            batch_upos_ids = [self.upos_ids[x] for x in indices[batch_start:batch_end]]
            batch_labels = torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]])
            yield batch_sentences, batch_indices, batch_upos_ids, batch_labels
+
def extract_unknown_token_indices(tokenized_indices: torch.tensor, unknown_token_idx: int) -> List[int]:
    """
    Extracts the indices within `tokenized_indices` which match `unknown_token_idx`

    Args:
        tokenized_indices (torch.tensor): A tensor filled with tokenized indices of words that have been mapped to vector indices.
        unknown_token_idx (int): The special index for which unknown tokens are marked in the word vectors.

    Returns:
        List[int]: A list of indices in `tokenized_indices` which match `unknown_token_index`
    """
    matches = []
    for position, token_index in enumerate(tokenized_indices):
        if token_index == unknown_token_idx:
            matches.append(position)
    return matches
+
+
def get_device():
    """
    Get the best available device to run computations on.

    Prefers CUDA, then Apple's MPS backend, falling back to CPU.

    Returns:
        torch.device: the selected device
    """
    # Bug fix: the previous version never called torch.cuda.is_available (the
    # bare function object is always truthy) and then fell through to an
    # unconditional mps/cpu branch which overwrote the cuda choice, so
    # "cuda" could never actually be returned.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    return device
+
+
def round_up_to_multiple(number, multiple):
    """
    Round `number` up to the nearest multiple of `multiple`.

    Args:
        number (int): the value to round
        multiple (int): the step to round up to; must be non-zero

    Returns:
        int: the smallest multiple of `multiple` >= `number` (for a positive
        `multiple`; negative multiples follow Python modulo semantics)

    Raises:
        ValueError: if `multiple` is zero
    """
    if multiple == 0:
        # Bug fix: this used to *return* an error string, which would silently
        # poison any arithmetic the caller did with the result.  Fail loudly.
        raise ValueError("The second number (multiple) cannot be zero.")

    # Calculate the remainder when dividing the number by the multiple
    remainder = number % multiple

    # If remainder is non-zero, round up to the next multiple
    if remainder != 0:
        rounded_number = number + (multiple - remainder)
    else:
        rounded_number = number  # No rounding needed

    return rounded_number
+
+
def main():
    """Smoke test: load the default dev set (the GUM stuff) through the Dataset loader."""
    default_test_path = os.path.join(os.path.dirname(__file__), "test_sets", "processed_ud_en", "combined_dev.txt")
    # Bug fix: this previously called a nonexistent `load_dataset` function
    # (and unpacked 7 return values); the loader is the Dataset class above.
    dataset = Dataset(default_test_path, get_counts=True)
    logger.info("Loaded %d batches of data from %s", len(dataset), default_test_path)

if __name__ == "__main__":
    main()
diff --git a/stanza/stanza/models/mwt/character_classifier.py b/stanza/stanza/models/mwt/character_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ae3f1023a11a228ffb2b26a93e1d6a6496b9f1d
--- /dev/null
+++ b/stanza/stanza/models/mwt/character_classifier.py
@@ -0,0 +1,65 @@
+"""
+Classify characters based on an LSTM with learned character representations
+"""
+
+import logging
+
+import torch
+from torch import nn
+
+import stanza.models.common.seq2seq_constant as constant
+
+logger = logging.getLogger('stanza')
+
class CharacterClassifier(nn.Module):
    """
    Per-character binary classifier: learned character embeddings feed a
    bidirectional LSTM, and a 2-layer MLP maps each position to 2 logits
    (used by the MWT expander to decide, per character, whether to split).
    """
    def __init__(self, args):
        super().__init__()

        self.vocab_size = args['vocab_size']  # number of known characters
        self.emb_dim = args['emb_dim']        # character embedding size
        self.hidden_dim = args['hidden_dim']  # total BiLSTM output size
        self.nlayers = args['num_layers'] # lstm encoder layers
        self.pad_token = constant.PAD_ID
        self.enc_hidden_dim = self.hidden_dim // 2 # since it is bidirectional

        # binary decision per character position
        self.num_outputs = 2

        self.args = args

        self.emb_dropout = args.get('emb_dropout', 0.0)
        self.emb_drop = nn.Dropout(self.emb_dropout)
        self.dropout = args['dropout']

        # third positional argument of nn.Embedding is padding_idx
        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
        self.input_dim = self.emb_dim
        # inter-layer dropout only applies with more than one LSTM layer
        self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \
            bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)

        self.output_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.num_outputs))

    def encode(self, enc_inputs, lens):
        """ Encode source sequence. """
        packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
        packed_h_in, (hn, cn) = self.encoder(packed_inputs)
        # return the packed per-position hidden states; the final (hn, cn) is unused
        return packed_h_in

    def embed(self, src, src_mask):
        """Look up dropout-regularized embeddings and compute sequence lengths."""
        # the input data could have characters outside the known range
        # of characters in cases where the vocabulary was temporarily
        # expanded (note that this model does nothing with those chars)
        embed_src = src.clone()
        embed_src[embed_src >= self.vocab_size] = constant.UNK_ID
        enc_inputs = self.emb_drop(self.embedding(embed_src))
        batch_size = enc_inputs.size(0)
        # NOTE(review): lengths count positions where the mask equals PAD_ID;
        # assumes the mask convention marks real (unpadded) positions with
        # PAD_ID - confirm against the data loader
        src_lens = list(src_mask.data.eq(self.pad_token).long().sum(1))
        return enc_inputs, batch_size, src_lens, src_mask

    def forward(self, src, src_mask):
        """Return per-character logits of shape (batch, padded_len, 2)."""
        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask)
        encoded = self.encode(enc_inputs, src_lens)
        encoded, _ = nn.utils.rnn.pad_packed_sequence(encoded, batch_first=True)
        logits = self.output_layer(encoded)
        return logits
diff --git a/stanza/stanza/models/mwt/trainer.py b/stanza/stanza/models/mwt/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..090df806da631af092c6d3ec6fddf99b9fb0a96a
--- /dev/null
+++ b/stanza/stanza/models/mwt/trainer.py
@@ -0,0 +1,218 @@
+"""
+A trainer class to handle training and testing of models.
+"""
+
+import sys
+import numpy as np
+from collections import Counter
+import logging
+import torch
+from torch import nn
+import torch.nn.init as init
+
+import stanza.models.common.seq2seq_constant as constant
+from stanza.models.common.trainer import Trainer as BaseTrainer
+from stanza.models.common.seq2seq_model import Seq2SeqModel
+from stanza.models.common import utils, loss
+from stanza.models.mwt.character_classifier import CharacterClassifier
+from stanza.models.mwt.vocab import Vocab
+
+logger = logging.getLogger('stanza')
+
def unpack_batch(batch, device):
    """ Unpack a batch from the data loader.

    Moves the first four (tensor or None) entries to the given device and
    returns them along with the original text and the original sort order.
    """
    moved = []
    for tensor in batch[:4]:
        moved.append(tensor.to(device) if tensor is not None else None)
    return moved, batch[4], batch[5]
+
class Trainer(BaseTrainer):
    """ A trainer for training models.

    Wraps one of three MWT expansion strategies:
      dict_only          -> no neural model, expansion dictionary only
      force_exact_pieces -> CharacterClassifier (per-character split/no-split)
      otherwise          -> Seq2SeqModel character-level expander
    """
    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None):
        if model_file is not None:
            # load from file
            self.load(model_file)
        else:
            self.args = args
            if args['dict_only']:
                self.model = None
            elif args.get('force_exact_pieces', False):
                self.model = CharacterClassifier(args)
            else:
                self.model = Seq2SeqModel(args, emb_matrix=emb_matrix)
            self.vocab = vocab
            self.expansion_dict = dict()
        if not self.args['dict_only']:
            self.model = self.model.to(device)
            # the character classifier is trained as per-position 2-way
            # classification; the seq2seq model uses a sequence loss over the vocab
            if self.args.get('force_exact_pieces', False):
                self.crit = nn.CrossEntropyLoss()
            else:
                self.crit = loss.SequenceLoss(self.vocab.size).to(device)
            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'])

    def update(self, batch, eval=False):
        """Run one step on a batch and return the loss value.

        When eval is True, only computes the loss (no gradient step).
        NOTE: the `eval` parameter shadows the builtin; kept for API compatibility.
        """
        device = next(self.model.parameters()).device
        # ignore the original text when training
        # can try to learn the correct values, even if we eventually
        # copy directly from the original text
        inputs, _, orig_idx = unpack_batch(batch, device)
        src, src_mask, tgt_in, tgt_out = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer.zero_grad()
        if self.args.get('force_exact_pieces', False):
            log_probs = self.model(src, src_mask)
            # pack predictions and targets so padding positions are excluded
            # from the per-character classification loss
            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
            packed_output = nn.utils.rnn.pack_padded_sequence(log_probs, src_lens, batch_first=True)
            packed_tgt = nn.utils.rnn.pack_padded_sequence(tgt_in, src_lens, batch_first=True)
            loss = self.crit(packed_output.data, packed_tgt.data)
        else:
            log_probs, _ = self.model(src, src_mask, tgt_in)
            loss = self.crit(log_probs.view(-1, self.vocab.size), tgt_out.view(-1))
        loss_val = loss.data.item()
        if eval:
            return loss_val

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()
        return loss_val

    def predict(self, batch, unsort=True, never_decode_unk=False, vocab=None):
        """Predict token expansions for a batch; returns one string per token."""
        if vocab is None:
            vocab = self.vocab

        device = next(self.model.parameters()).device
        inputs, orig_text, orig_idx = unpack_batch(batch, device)
        src, src_mask, tgt, tgt_mask = inputs

        self.model.eval()
        batch_size = src.size(0)
        if self.args.get('force_exact_pieces', False):
            # character classifier: a position is a split point when the
            # "cut" logit outscores the "no cut" logit
            log_probs = self.model(src, src_mask)
            cuts = log_probs[:, :, 1] > log_probs[:, :, 0]
            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
            pred_tokens = []
            for src_ids, cut, src_len in zip(src, cuts, src_lens):
                src_chars = vocab.unmap(src_ids)
                pred_seq = []
                # range(1, src_len-1) skips the first and last positions
                # (presumably start/end markers - confirm with the data pipeline)
                for char_idx in range(1, src_len-1):
                    if cut[char_idx]:
                        pred_seq.append(' ')
                    pred_seq.append(src_chars[char_idx])
                pred_seq = "".join(pred_seq).strip()
                pred_tokens.append(pred_seq)
        else:
            preds, _ = self.model.predict(src, src_mask, self.args['beam_size'], never_decode_unk=never_decode_unk)
            pred_seqs = [vocab.unmap(ids) for ids in preds] # unmap to tokens
            pred_seqs = utils.prune_decoded_seqs(pred_seqs)

            pred_tokens = ["".join(seq) for seq in pred_seqs] # join chars to be tokens
            # if any tokens are predicted to expand to blank,
            # that is likely an error. use the original text
            # this originally came up with the Spanish model turning 's' into a blank
            # furthermore, if there are no spaces predicted by the seq2seq,
            # might as well use the original in case the seq2seq went crazy
            # this particular error came up training a Hebrew MWT
            pred_tokens = [x if x and ' ' in x else y for x, y in zip(pred_tokens, orig_text)]
        if unsort:
            pred_tokens = utils.unsort(pred_tokens, orig_idx)
        return pred_tokens

    def train_dict(self, pairs):
        """ Train a MWT expander given training word-expansion pairs. """
        # accumulate counter
        ctr = Counter()
        ctr.update([(p[0], p[1]) for p in pairs])
        seen = set()
        # find the most frequent mappings
        # (most_common order means the first mapping kept per word is its
        # most frequent expansion; identity expansions are skipped)
        for p, _ in ctr.most_common():
            w, l = p
            if w not in seen and w != l:
                self.expansion_dict[w] = l
            seen.add(w)
        return

    def dict_expansion(self, word):
        """
        Check the expansion dictionary for the word along with a couple common lowercasings of the word

        (Leadingcase and UPPERCASE)
        """
        expansion = self.expansion_dict.get(word)
        if expansion is not None:
            return expansion

        # ALL-CAPS word: look up the lowercased form and re-uppercase the result
        if word.isupper():
            expansion = self.expansion_dict.get(word.lower())
            if expansion is not None:
                return expansion.upper()

        # Leadingcase word: look up the lowercased form and re-capitalize
        # NOTE(review): word is assumed non-empty - word[0] raises IndexError on ""
        if word[0].isupper() and word[1:].islower():
            expansion = self.expansion_dict.get(word.lower())
            if expansion is not None:
                return expansion[0].upper() + expansion[1:]

        # could build a truecasing model of some kind to handle cRaZyCaSe...
        # but that's probably too much effort
        return None

    def predict_dict(self, words):
        """ Predict a list of expansions given words. """
        # words with no dictionary entry are passed through unchanged
        expansions = []
        for w in words:
            expansion = self.dict_expansion(w)
            if expansion is not None:
                expansions.append(expansion)
            else:
                expansions.append(w)
        return expansions

    def ensemble(self, cands, other_preds):
        """ Ensemble the dict with statistical model predictions. """
        # the dictionary takes precedence; the model prediction is the fallback
        expansions = []
        assert len(cands) == len(other_preds)
        for c, pred in zip(cands, other_preds):
            expansion = self.dict_expansion(c)
            if expansion is not None:
                expansions.append(expansion)
            else:
                expansions.append(pred)
        return expansions

    def save(self, filename):
        """Serialize model weights, expansion dict, vocab, and config to filename."""
        params = {
            'model': self.model.state_dict() if self.model is not None else None,
            'dict': self.expansion_dict,
            'vocab': self.vocab.state_dict(),
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except BaseException:
            # deliberately best-effort: a failed save should not kill training
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename):
        """Restore args, expansion dict, model weights, and vocab from a checkpoint."""
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        self.expansion_dict = checkpoint['dict']
        if not self.args['dict_only']:
            if self.args.get('force_exact_pieces', False):
                self.model = CharacterClassifier(self.args)
            else:
                self.model = Seq2SeqModel(self.args)
            # could remove strict=False after rebuilding all models,
            # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False
            self.model.load_state_dict(checkpoint['model'], strict=False)
        else:
            self.model = None
        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
+
diff --git a/stanza/stanza/models/mwt/vocab.py b/stanza/stanza/models/mwt/vocab.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c861e7a49a0aba7fd7b9ef8717edb39ca19a2a9
--- /dev/null
+++ b/stanza/stanza/models/mwt/vocab.py
@@ -0,0 +1,19 @@
+from collections import Counter
+
+from stanza.models.common.vocab import BaseVocab
+import stanza.models.common.seq2seq_constant as constant
+
class Vocab(BaseVocab):
    def build_vocab(self):
        """Build the character vocabulary from (src, tgt) string pairs, most frequent first."""
        counter = Counter()
        for src, tgt in self.data:
            counter.update(src)
            counter.update(tgt)

        # most frequent characters first; ties keep first-occurrence order
        units = sorted(counter, key=lambda ch: counter[ch], reverse=True)
        self._id2unit = constant.VOCAB_PREFIX + units
        self._unit2id = {unit: idx for idx, unit in enumerate(self._id2unit)}

    def add_unit(self, unit):
        """Append a previously unseen unit to the vocab; no-op if already present."""
        if unit in self._unit2id:
            return
        self._unit2id[unit] = len(self._id2unit)
        self._id2unit.append(unit)
diff --git a/stanza/stanza/models/ner/vocab.py b/stanza/stanza/models/ner/vocab.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d7ad11c585999a0dadd10eca630b10ef12313db
--- /dev/null
+++ b/stanza/stanza/models/ner/vocab.py
@@ -0,0 +1,56 @@
+from collections import Counter, OrderedDict
+
+from stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab, CompositeVocab
+from stanza.models.common.vocab import VOCAB_PREFIX
+from stanza.models.common.pretrain import PretrainedWordVocab
+from stanza.models.pos.vocab import WordVocab
+
class TagVocab(BaseVocab):
    """ A vocab for the output tag sequence. """
    def build_vocab(self):
        """Count tag occurrences over all sentences and index tags most-frequent-first."""
        counter = Counter(w[self.idx] for sent in self.data for w in sent)

        units = sorted(counter, key=lambda tag: counter[tag], reverse=True)
        self._id2unit = VOCAB_PREFIX + units
        self._unit2id = {unit: idx for idx, unit in enumerate(self._id2unit)}
+
def convert_tag_vocab(state_dict):
    """
    Convert a saved TagVocab state dict into an equivalent CompositeVocab.

    The rebuilt vocab must contain exactly the same units in exactly the
    same order as the original, which is verified after construction.

    Raises:
        AssertionError: if the saved vocab had 'lower' set, or the rebuilt
            vocab does not match the original's length or ordering
    """
    if state_dict['lower']:
        raise AssertionError("Did not expect an NER vocab with 'lower' set to True")
    # skip the reserved prefix entries - the new vocab adds its own
    items = state_dict['_id2unit'][len(VOCAB_PREFIX):]
    # this looks silly, but the vocab builder treats this as words with multiple fields
    # (we set it to look for field 0 with idx=0)
    # and then the label field is expected to be a list or tuple of items
    items = [[[[x]]] for x in items]
    vocab = CompositeVocab(data=items, lang=state_dict['lang'], idx=0, sep=None)
    # sanity check: the conversion must round-trip the original exactly
    if len(vocab._id2unit[0]) != len(state_dict['_id2unit']):
        raise AssertionError("Failed to construct a new vocab of the same length as the original")
    if vocab._id2unit[0] != state_dict['_id2unit']:
        raise AssertionError("Failed to construct a new vocab in the same order as the original")
    return vocab
+
class MultiVocab(BaseMultiVocab):
    def state_dict(self):
        """ Also save a vocab name to class name mapping in state dict. """
        state = OrderedDict()
        key2class = OrderedDict()
        for k, v in self._vocabs.items():
            state[k] = v.state_dict()
            # remember each vocab's concrete class so loading can dispatch
            key2class[k] = type(v).__name__
        state['_key2class'] = key2class
        return state

    @classmethod
    def load_state_dict(cls, state_dict):
        """
        Rebuild a MultiVocab by dispatching each saved vocab to its class's loader.

        Legacy TagVocab entries are upgraded to CompositeVocab via convert_tag_vocab.
        NOTE: pops '_key2class' from the given state_dict (mutates the argument).
        """
        class_dict = {'CharVocab': CharVocab.load_state_dict,
                      'PretrainedWordVocab': PretrainedWordVocab.load_state_dict,
                      'TagVocab': convert_tag_vocab,
                      'CompositeVocab': CompositeVocab.load_state_dict,
                      'WordVocab': WordVocab.load_state_dict}
        new = cls()
        assert '_key2class' in state_dict, "Cannot find class name mapping in state dict!"
        key2class = state_dict.pop('_key2class')
        for k,v in state_dict.items():
            classname = key2class[k]
            new[k] = class_dict[classname](v)
        return new
+
diff --git a/stanza/stanza/models/pos/__init__.py b/stanza/stanza/models/pos/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/stanza/stanza/models/pos/build_xpos_vocab_factory.py b/stanza/stanza/models/pos/build_xpos_vocab_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c09bce4ac34ca49f4ccd2e82656dbed8ff20a1e
--- /dev/null
+++ b/stanza/stanza/models/pos/build_xpos_vocab_factory.py
@@ -0,0 +1,144 @@
+import argparse
+from collections import defaultdict
+import logging
+import os
+import re
+import sys
+from zipfile import ZipFile
+
+from stanza.models.common.constant import treebank_to_short_name
+from stanza.models.pos.xpos_vocab_utils import DEFAULT_KEY, choose_simplest_factory, XPOSType
+from stanza.models.common.doc import *
+from stanza.utils.conll import CoNLL
+from stanza.utils import default_paths
+
# dataset shorthand names look like "<lang>_<corpus>", e.g. "en_ewt"
SHORTNAME_RE = re.compile("[a-z-]+_[a-z0-9]+")
# directory containing the prepared POS training files (*.train.in.conllu / *.train.in.zip)
DATA_DIR = default_paths.get_default_paths()['POS_DATA_DIR']

logger = logging.getLogger('stanza')
+
def get_xpos_factory(shorthand, fn):
    """Pick the simplest XPOS vocab description for one treebank.

    Loads the training data for the dataset `shorthand` from DATA_DIR,
    either a plain .train.in.conllu file or a .train.in.zip archive (in
    which case the first member containing any xpos tags is used), then
    delegates to choose_simplest_factory.

    fn is the treebank's full name, used only for the error message.

    Raises FileNotFoundError if no training data exists, and ValueError if
    a zip was found but none of its files had xpos tags.
    """
    logger.info('Resolving vocab option for {}...'.format(shorthand))
    doc = None
    train_file = os.path.join(DATA_DIR, '{}.train.in.conllu'.format(shorthand))
    if os.path.exists(train_file):
        doc = CoNLL.conll2doc(input_file=train_file)
    else:
        zip_file = os.path.join(DATA_DIR, '{}.train.in.zip'.format(shorthand))
        if os.path.exists(zip_file):
            with ZipFile(zip_file) as zin:
                # use the first member which has at least one xpos tag
                for train_file in zin.namelist():
                    doc = CoNLL.conll2doc(input_file=train_file, zip_file=zip_file)
                    if any(word.xpos for sentence in doc.sentences for word in sentence.words):
                        break
                else:
                    raise ValueError('Found training data in {}, but none of the files contained had xpos'.format(zip_file))

    if doc is None:
        # without the training file, there's not much we can do.
        # Note: the original version had unreachable code after this raise
        # (returning DEFAULT_KEY); it has been removed.
        raise FileNotFoundError('Training data for {} not found. To generate the XPOS vocabulary '
                                'for this treebank properly, please run the following command first:\n'
                                ' python3 stanza/utils/datasets/prepare_pos_treebank.py {}'.format(fn, fn))

    data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)
    return choose_simplest_factory(data, shorthand)
+
def main():
    """Regenerate stanza/models/pos/xpos_vocab_factory.py from prepared treebanks.

    For each treebank, finds the XPOS vocab description which represents its
    training data with the fewest classes, then writes out a generated module
    mapping dataset shorthand -> XPOSDescription.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--treebanks', type=str, default=DATA_DIR, help="Treebanks to process - directory with processed datasets or a file with a list")
    parser.add_argument('--output_file', type=str, default="stanza/models/pos/xpos_vocab_factory.py", help="Where to write the results")
    args = parser.parse_args()

    output_file = args.output_file
    if os.path.isdir(args.treebanks):
        # if the path is a directory of datasets (which is the default if --treebanks is not set)
        # we use those datasets to prepare the xpos factories
        treebanks = os.listdir(args.treebanks)
        treebanks = [x.split(".", maxsplit=1)[0] for x in treebanks]
        treebanks = sorted(set(treebanks))
    elif os.path.exists(args.treebanks):
        # maybe it's a file with a list of names
        with open(args.treebanks) as fin:
            treebanks = sorted(set([x.strip() for x in fin.readlines() if x.strip()]))
    else:
        raise ValueError("Cannot figure out which treebanks to use. Please set the --treebanks parameter")

    logger.info("Processing the following treebanks: %s" % " ".join(treebanks))

    # treebanks may be given either as shorthand (en_ewt) or as full UD names
    shorthands = []
    fullnames = []
    for treebank in treebanks:
        fullnames.append(treebank)
        if SHORTNAME_RE.match(treebank):
            shorthands.append(treebank)
        else:
            shorthands.append(treebank_to_short_name(treebank))

    # For each treebank, we would like to find the XPOS Vocab configuration that minimizes
    # the number of total classes needed to predict by all tagger classifiers. This is
    # achieved by enumerating different options of separators that different treebanks might
    # use, and comparing that to treating the XPOS tags as separate categories (using a
    # WordVocab).
    mapping = defaultdict(list)
    for sh, fn in zip(shorthands, fullnames):
        factory = get_xpos_factory(sh, fn)
        mapping[factory].append(sh)
        # a couple datasets are also registered under alternate language codes
        if sh == 'zh-hans_gsdsimp':
            mapping[factory].append('zh_gsdsimp')
        elif sh == 'no_bokmaal':
            mapping[factory].append('nb_bokmaal')

    mapping[DEFAULT_KEY].append('en_test')

    # Generate code. This takes the XPOS vocabulary classes selected above, and generates the
    # actual factory class as seen in models.pos.xpos_vocab_factory.
    # (an unused "first = True" flag was removed here)
    with open(output_file, 'w') as f:
        max_len = max(max(len(x) for x in mapping[key]) for key in mapping)
        print('''# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.
# Please don't edit it!

import logging

from stanza.models.pos.vocab import WordVocab, XPOSVocab
from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory

# using a sublogger makes it easier to test in the unittests
logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')

XPOS_DESCRIPTIONS = {''', file=f)

        for key_idx, key in enumerate(mapping):
            if key_idx > 0:
                print(file=f)
            for shorthand in sorted(mapping[key]):
                # +2 to max_len for the ''
                # this format string is left justified (either would be okay, probably)
                if key.sep is None:
                    sep = 'None'
                else:
                    sep = "'%s'" % key.sep
                print(("    {:%ds}: XPOSDescription({}, {})," % (max_len+2)).format("'%s'" % shorthand, key.xpos_type, sep), file=f)

        print('''}

def xpos_vocab_factory(data, shorthand):
    if shorthand not in XPOS_DESCRIPTIONS:
        logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand)
    desc = choose_simplest_factory(data, shorthand)
    if shorthand in XPOS_DESCRIPTIONS:
        if XPOS_DESCRIPTIONS[shorthand] != desc:
            # log instead of throw
            # otherwise, updating datasets would be unpleasant
            logger.error("XPOS tagset in %s has apparently changed! Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)
    else:
        logger.warning("Chose %s for the xpos factory for %s", desc, shorthand)
    return build_xpos_vocab(desc, data, shorthand)
''', file=f)

    logger.info('Done!')

if __name__ == "__main__":
    main()
diff --git a/stanza/stanza/models/pos/data.py b/stanza/stanza/models/pos/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae12a1bfe6a29f1c6085c880e7c861d858073d42
--- /dev/null
+++ b/stanza/stanza/models/pos/data.py
@@ -0,0 +1,387 @@
+import random
+import logging
+import copy
+import torch
+from collections import namedtuple
+
+from torch.utils.data import DataLoader as DL
+from torch.utils.data.sampler import Sampler
+from torch.nn.utils.rnn import pad_sequence
+
+from stanza.models.common.bert_embedding import filter_data, needs_length_filter
+from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
+from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, CharVocab
+from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab
+from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory
+from stanza.models.common.doc import *
+
+logger = logging.getLogger('stanza')
+
# A single preprocessed sentence: parallel per-word vocab-id lists
# (word/char/upos/xpos/feats/pretrain) plus the raw token text.
DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain text")
# One collated minibatch: padded tensors with their padding masks, gold tag
# tensors (or None), unsort indices, lengths, raw text, and original indices.
DataBatch = namedtuple("DataBatch", "words words_mask wordchars wordchars_mask upos xpos ufeats pretrained orig_idx word_orig_idx lens word_lens text idx")
+
class Dataset:
    """Holds the preprocessed sentences of one POS document.

    Wraps a stanza Document: extracts (text, upos, xpos, feats) tuples, maps
    them to vocab ids, and exposes them with the PyTorch Dataset protocol
    (__len__ / __getitem__).  Use to_loader() to wrap this in a DataLoader
    which collates samples into padded DataBatch tuples.
    """

    def __init__(self, doc, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None, **kwargs):
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.sort_during_eval = sort_during_eval
        self.doc = doc

        # build a fresh vocab from this document unless one was provided
        if vocab is None:
            self.vocab = Dataset.init_vocab([doc], args)
        else:
            self.vocab = vocab

        # a tag column counts as present if at least one value is not blank ('_')
        self.has_upos = not all(x is None or x == '_' for x in doc.get(UPOS, as_sentences=False))
        self.has_xpos = not all(x is None or x == '_' for x in doc.get(XPOS, as_sentences=False))
        self.has_feats = not all(x is None or x == '_' for x in doc.get(FEATS, as_sentences=False))

        data = self.load_doc(self.doc)
        # filter out the long sentences if bert is used
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None
        self.pretrain_vocab = None
        if pretrain is not None and args['pretrain']:
            self.pretrain_vocab = pretrain.vocab

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)

        self.data = data

        self.num_examples = len(data)
        # upos id(s) for PUNCT, used by the punctuation-dropping augmentation
        self.__punct_tags = self.vocab["upos"].map(["PUNCT"])
        # probability of masking out a trailing punctuation token per sample
        self.augment_nopunct = self.args.get("augment_nopunct", 0.0)

    @staticmethod
    def init_vocab(docs, args):
        """Build a MultiVocab (char/word/upos/xpos/feats) over all given docs."""
        data = [x for doc in docs for x in Dataset.load_doc(doc)]
        charvocab = CharVocab(data, args['shorthand'])
        wordvocab = WordVocab(data, args['shorthand'], cutoff=args['word_cutoff'], lower=True)
        uposvocab = WordVocab(data, args['shorthand'], idx=1)
        xposvocab = xpos_vocab_factory(data, args['shorthand'])
        try:
            featsvocab = FeatureVocab(data, args['shorthand'], idx=3)
        except ValueError as e:
            raise ValueError("Unable to build features vocab. Please check the Features column of your data for an error which may match the following description.") from e
        vocab = MultiVocab({'char': charvocab,
                            'word': wordvocab,
                            'upos': uposvocab,
                            'xpos': xposvocab,
                            'feats': featsvocab})
        return vocab

    def preprocess(self, data, vocab, pretrain_vocab, args):
        """Map each (text, upos, xpos, feats) sentence to vocab ids.

        Returns a list of DataSample tuples; each id field is wrapped in an
        extra singleton list.  When no pretrain vocab is available, the
        pretrain field is filled with PAD_ID so downstream code can rely on
        its presence.
        """
        processed = []
        for sent in data:
            processed_sent = DataSample(
                word = [vocab['word'].map([w[0] for w in sent])],
                char = [[vocab['char'].map([x for x in w[0]]) for w in sent]],
                upos = [vocab['upos'].map([w[1] for w in sent])],
                xpos = [vocab['xpos'].map([w[2] for w in sent])],
                feats = [vocab['feats'].map([w[3] for w in sent])],
                pretrain = ([pretrain_vocab.map([w[0].lower() for w in sent])]
                            if pretrain_vocab is not None
                            else [[PAD_ID] * len(sent)]),
                text = [w[0] for w in sent]
            )
            processed.append(processed_sent)

        return processed

    def __len__(self):
        return len(self.data)

    def __mask(self, upos):
        """Returns a torch boolean about which elements should be masked out"""

        # creates all false mask
        mask = torch.zeros_like(upos, dtype=torch.bool)

        ### augmentation 1: punctuation augmentation ###
        # tags that needs to be checked, currently only PUNCT
        # NOTE: only the *last* token of the sentence is ever considered
        # for masking, and only with probability augment_nopunct
        if random.uniform(0,1) < self.augment_nopunct:
            for i in self.__punct_tags:
                # generate a mask for the last element
                last_element = torch.zeros_like(upos, dtype=torch.bool)
                last_element[..., -1] = True
                # we or the bitmask against the existing mask
                # if it satisfies, we remove the word by masking it
                # to true
                #
                # if your input is just a lone punctuation, we perform
                # no masking
                if not torch.all(upos.eq(torch.tensor([[i]]))):
                    mask |= ((upos == i) & (last_element))

        return mask

    def __getitem__(self, key):
        """Retrieves a sample from the dataset.

        Retrieves a sample from the dataset. This function, for the
        most part, is spent performing ad-hoc data augmentation and
        restoration. It receives a DataSample object from the storage,
        and returns an almost-identical DataSample object that may
        have been augmented with /possibly/ (depending on augment_punct
        settings) PUNCT chopped.

        **Important Note**
        ------------------
        If you would like to load the data into a model, please convert
        this Dataset object into a DataLoader via self.to_loader(). Then,
        you can use the resulting object like any other PyTorch data
        loader. As masks are calculated ad-hoc given the batch, the samples
        returned from this object don't have the appropriate masking.

        Motivation
        ----------
        Why is this here? Every time you call next(iter(dataloader)), it calls
        this function. Therefore, if we augmented each sample on each iteration,
        the model will see dynamically generated augmentation.
        Furthermore, PyTorch dataloader handles shuffling natively.

        Parameters
        ----------
        key : int
            the integer ID to from which to retrieve the key.

        Returns
        -------
        DataSample
            The sample of data you requested, with augmentation.
        """
        # get a sample of the input data
        sample = self.data[key]

        # some data augmentation requires constructing a mask based on upos.
        # For instance, sometimes we'd like to mask out ending sentence punctuation.
        # We copy the other items here so that any edits made because
        # of the mask don't clobber the version owned by the Dataset
        # convert to tensors
        # TODO: only store single lists per data entry?
        words = torch.tensor(sample.word[0])
        # convert the rest to tensors
        upos = torch.tensor(sample.upos[0]) if self.has_upos else None
        xpos = torch.tensor(sample.xpos[0]) if self.has_xpos else None
        ufeats = torch.tensor(sample.feats[0]) if self.has_feats else None
        pretrained = torch.tensor(sample.pretrain[0])

        # and deal with char & raw_text
        char = sample.char[0]
        raw_text = sample.text

        # some data augmentation requires constructing a mask based on
        # which upos. For instance, sometimes we'd like to mask out ending
        # sentence punctuation. The mask is True if we want to remove the element
        if self.has_upos and upos is not None and not self.eval:
            # perform actual masking
            mask = self.__mask(upos)
        else:
            # dummy mask that's all false
            mask = None
        if mask is not None:
            mask_index = mask.nonzero()

            # mask out the elements that we need to mask out
            # NOTE(review): the loop variable reuses the name `mask`,
            # shadowing the boolean mask tensor above; tensors are padded
            # in place while the list fields (char, raw_text) are shortened
            for mask in mask_index:
                mask = mask.item()
                words[mask] = PAD_ID
                if upos is not None:
                    upos[mask] = PAD_ID
                if xpos is not None:
                    # TODO: test the multi-dimension xpos
                    xpos[mask, ...] = PAD_ID
                if ufeats is not None:
                    ufeats[mask, ...] = PAD_ID
                pretrained[mask] = PAD_ID
                char = char[:mask] + char[mask+1:]
                raw_text = raw_text[:mask] + raw_text[mask+1:]

        # get each character from the input sentence
        # chars = [w for sent in char for w in sent]

        return DataSample(words, char, upos, xpos, ufeats, pretrained, raw_text), key

    def __iter__(self):
        # yields (DataSample, index) pairs, same as __getitem__
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def to_loader(self, **kwargs):
        """Converts self to a DataLoader """

        return DL(self,
                  collate_fn=Dataset.__collate_fn,
                  **kwargs)

    def to_length_limited_loader(self, batch_size, maximum_tokens):
        """DataLoader whose batches are capped at batch_size AND maximum_tokens."""
        sampler = LengthLimitedBatchSampler(self, batch_size, maximum_tokens)
        return DL(self,
                  collate_fn=Dataset.__collate_fn,
                  batch_sampler = sampler)

    @staticmethod
    def __collate_fn(data):
        """Function used by DataLoader to pack data"""
        # each element of data is a (DataSample, idx) pair from __getitem__
        (data, idx) = zip(*data)
        (words, wordchars, upos, xpos, ufeats, pretrained, text) = zip(*data)

        # collate_fn is given a list of length batch size
        batch_size = len(data)

        # sort sentences by lens for easy RNN operations
        lens = [torch.sum(x != PAD_ID) for x in words]
        (words, wordchars, upos, xpos,
         ufeats, pretrained, text), orig_idx = sort_all((words, wordchars, upos, xpos,
                                                         ufeats, pretrained, text), lens)
        lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN

        # combine all words into one large list, and sort for easy charRNN ops
        wordchars = [w for sent in wordchars for w in sent]
        word_lens = [len(x) for x in wordchars]
        (wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
        word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN

        # We now pad everything
        words = pad_sequence(words, True, PAD_ID)
        # a tag tensor is only padded if every sample in the batch has it;
        # otherwise the whole column is reported as None
        if None not in upos:
            upos = pad_sequence(upos, True, PAD_ID)
        else:
            upos = None
        if None not in xpos:
            xpos = pad_sequence(xpos, True, PAD_ID)
        else:
            xpos = None
        if None not in ufeats:
            ufeats = pad_sequence(ufeats, True, PAD_ID)
        else:
            ufeats = None
        pretrained = pad_sequence(pretrained, True, PAD_ID)
        wordchars = get_long_tensor(wordchars, len(word_lens))

        # and finally create masks for the padding indices
        words_mask = torch.eq(words, PAD_ID)
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
                         pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)

    @staticmethod
    def load_doc(doc):
        """Extract (text, upos, xpos, feats) tuples per sentence from a Document."""
        data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)
        data = Dataset.resolve_none(data)
        return data

    @staticmethod
    def resolve_none(data):
        # replace None to '_'
        for sent_idx in range(len(data)):
            for tok_idx in range(len(data[sent_idx])):
                for feat_idx in range(len(data[sent_idx][tok_idx])):
                    if data[sent_idx][tok_idx][feat_idx] is None:
                        data[sent_idx][tok_idx][feat_idx] = '_'
        return data
+
class LengthLimitedBatchSampler(Sampler):
    """
    Batches up the text in batches of batch_size, but cuts off each time a batch reaches maximum_tokens

    Intent is to avoid GPU OOM in situations where one sentence is significantly longer than expected,
    leaving a batch too large to fit in the GPU

    Sentences which are longer than maximum_tokens by themselves are put in their own batches
    """
    def __init__(self, data, batch_size, maximum_tokens):
        """
        Precalculate the batches, making it so len and iter just read off the precalculated batches
        """
        self.data = data
        self.batch_size = batch_size
        self.maximum_tokens = maximum_tokens

        self.batches = []
        pending = []          # indices accumulated for the batch being built
        pending_tokens = 0    # token count of the pending batch

        for sample, sample_idx in data:
            num_tokens = len(sample.word)
            # an oversized sentence gets its own singleton batch
            if maximum_tokens and num_tokens > maximum_tokens:
                if pending:
                    self.batches.append(pending)
                    pending = []
                    pending_tokens = 0
                self.batches.append([sample_idx])
                continue
            # flush the pending batch if adding this sentence would
            # exceed either the sentence or the token budget
            if len(pending) >= batch_size or (maximum_tokens and pending_tokens + num_tokens > maximum_tokens):
                self.batches.append(pending)
                pending = []
                pending_tokens = 0
            pending.append(sample_idx)
            pending_tokens += num_tokens

        if pending:
            self.batches.append(pending)

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        # hand out copies so callers cannot corrupt the precalculated batches
        for batch in self.batches:
            yield list(batch)
+
+
class ShuffledDataset:
    """A wrapper around one or more datasets which shuffles the data in batch_size chunks

    This means that if multiple datasets are passed in, the batches
    from each dataset are shuffled together, with one batch being
    entirely members of the same dataset.

    The main use case of this is that in the tagger, there are cases
    where batches from different datasets will have different
    properties, such as having or not having UPOS tags. We found that
    it is actually somewhat tricky to make the model's loss function
    (in model.py) properly represent batches with mixed w/ and w/o
    property, whereas keeping one entire batch together makes it a lot
    easier to process.

    The mechanism for the shuffling is that the iterator first makes a
    list long enough to represent each batch from each dataset,
    tracking the index of the dataset it is coming from, then shuffles
    that list. Another alternative would be to use a weighted
    randomization approach, but this is very simple and the memory
    requirements are not too onerous.

    Note that the batch indices are wasteful in the case of only one
    underlying dataset, which is actually the most common use case,
    but the overhead is small enough that it probably isn't worth
    special casing the one dataset version.
    """
    def __init__(self, datasets, batch_size):
        self.batch_size = batch_size
        self.datasets = datasets
        self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets]

    def __iter__(self):
        """Yield every batch from every loader, in a shuffled interleaving."""
        iterators = [iter(x) for x in self.loaders]
        lengths = [len(x) for x in self.loaders]
        # one entry per batch, holding the index of the loader it comes from
        indices = [[x] * y for x, y in enumerate(lengths)]
        indices = [idx for inner in indices for idx in inner]
        random.shuffle(indices)

        for idx in indices:
            yield next(iterators[idx])

    def __len__(self):
        # Bug fix: this must agree with the number of items produced by
        # __iter__, which is one per *batch* of each loader, not one per
        # underlying example; the previous version summed len(dataset)
        return sum(len(x) for x in self.loaders)
diff --git a/stanza/stanza/models/pos/model.py b/stanza/stanza/models/pos/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d620698d75564a177a3e4a07dae30638431267
--- /dev/null
+++ b/stanza/stanza/models/pos/model.py
@@ -0,0 +1,256 @@
+import logging
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
+
+from stanza.models.common.bert_embedding import extract_bert_embeddings
+from stanza.models.common.biaffine import BiaffineScorer
+from stanza.models.common.foundation_cache import load_bert, load_charlm
+from stanza.models.common.hlstm import HighwayLSTM
+from stanza.models.common.dropout import WordDropout
+from stanza.models.common.utils import attach_bert_model
+from stanza.models.common.vocab import CompositeVocab
+from stanza.models.common.char_model import CharacterModel
+from stanza.models.common import utils
+
+logger = logging.getLogger('stanza')
+
class Tagger(nn.Module):
    """Highway-LSTM POS tagger predicting UPOS, XPOS and UFeats per word.

    The per-word input representation concatenates, depending on args:
    trainable word embeddings, projected pretrained embeddings, either a
    character model or forward/backward pretrained character LMs, and
    transformer (bert) embeddings.  A bidirectional HighwayLSTM encodes each
    sentence; linear (or biaffine) classifiers predict each tag column.

    When share_hid is True, the xpos/ufeats classifiers reuse the upos
    hidden layer; otherwise they get their own hidden layers and biaffine
    scorers conditioned on the upos embedding.
    """
    def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
        super().__init__()

        self.vocab = vocab
        self.args = args
        self.share_hid = share_hid
        # names of modules excluded from the saved checkpoint (charlm, pretrain)
        self.unsaved_modules = []

        # input layers
        # input_size accumulates the width of every enabled input feature
        input_size = 0
        if self.args['word_emb_dim'] > 0:
            # frequent word embeddings
            self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)
            input_size += self.args['word_emb_dim']

        if not share_hid:
            # upos embeddings
            self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                # pretrained character language models (forward and backward)
                if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
                    raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']))
                if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
                    raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))
                logger.debug("POS model loading charmodels: %s and %s", args['charlm_forward_file'], args['charlm_backward_file'])
                self.add_unsaved_module('charmodel_forward', load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache))
                self.add_unsaved_module('charmodel_backward', load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache))
                # optionally add a input transformation layer
                if self.args.get('charlm_transform_dim', 0):
                    self.charmodel_forward_transform = nn.Linear(self.charmodel_forward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)
                    self.charmodel_backward_transform = nn.Linear(self.charmodel_backward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)
                    input_size += self.args['charlm_transform_dim'] * 2
                else:
                    self.charmodel_forward_transform = None
                    self.charmodel_backward_transform = None
                    input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
            else:
                # trainable character model built from the char vocab
                bidirectional = args.get('char_bidirectional', False)
                self.charmodel = CharacterModel(args, vocab, bidirectional=bidirectional)
                if bidirectional:
                    self.trans_char = nn.Linear(self.args['char_hidden_dim'] * 2, self.args['transformed_dim'], bias=False)
                else:
                    self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)
                input_size += self.args['transformed_dim']

        self.peft_name = peft_name
        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
        if self.args.get('bert_model', None):
            # TODO: refactor bert_hidden_layers between the different models
            if args.get('bert_hidden_layers', False):
                # The average will be offset by 1/N so that the default zeros
                # represents an average of the N layers
                self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
                nn.init.zeros_(self.bert_layer_mix.weight)
            else:
                # an average of layers 2, 3, 4 will be used
                # (for historic reasons)
                self.bert_layer_mix = None
            input_size += self.bert_model.config.hidden_size

        if self.args['pretrain']:
            # pretrained embeddings, by default this won't be saved into model file
            self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
            self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
            input_size += self.args['transformed_dim']

        # recurrent layers
        self.taggerlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)
        # learned replacement vector used by WordDropout, and learned LSTM initial states
        self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
        self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))
        self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))

        # classifiers
        self.upos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'])
        self.upos_clf = nn.Linear(self.args['deep_biaff_hidden_dim'], len(vocab['upos']))
        # zero-initialized classifier weights
        self.upos_clf.weight.data.zero_()
        self.upos_clf.bias.data.zero_()

        if share_hid:
            clf_constructor = lambda insize, outsize: nn.Linear(insize, outsize)
        else:
            # composite vocabs get one classifier per tag component, so they
            # use the (smaller) composite hidden dim
            self.xpos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'] if not isinstance(vocab['xpos'], CompositeVocab) else self.args['composite_deep_biaff_hidden_dim'])
            self.ufeats_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['composite_deep_biaff_hidden_dim'])
            clf_constructor = lambda insize, outsize: BiaffineScorer(insize, self.args['tag_emb_dim'], outsize)

        if isinstance(vocab['xpos'], CompositeVocab):
            # one classifier per component of the composite xpos tag
            self.xpos_clf = nn.ModuleList()
            for l in vocab['xpos'].lens():
                self.xpos_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))
        else:
            self.xpos_clf = clf_constructor(self.args['deep_biaff_hidden_dim'], len(vocab['xpos']))
            if share_hid:
                self.xpos_clf.weight.data.zero_()
                self.xpos_clf.bias.data.zero_()

        # one classifier per universal feature
        self.ufeats_clf = nn.ModuleList()
        for l in vocab['feats'].lens():
            if share_hid:
                self.ufeats_clf.append(clf_constructor(self.args['deep_biaff_hidden_dim'], l))
                self.ufeats_clf[-1].weight.data.zero_()
                self.ufeats_clf[-1].bias.data.zero_()
            else:
                self.ufeats_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))

        # criterion
        self.crit = nn.CrossEntropyLoss(ignore_index=0) # ignore padding

        self.drop = nn.Dropout(args['dropout'])
        self.worddrop = WordDropout(args['word_dropout'])

    def add_unsaved_module(self, name, module):
        """Attach a module whose parameters are excluded from the checkpoint."""
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def log_norms(self):
        """Log parameter norms via common.utils.log_norms (for debugging)."""
        utils.log_norms(self)

    def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text):
        """Run the tagger on one collated batch.

        Returns (loss, preds): loss is the summed cross entropy over
        upos/xpos/feats (columns with no gold tags contribute nothing;
        it stays 0.0 if no gold column was given at all), and preds is a
        list of padded argmax predictions [upos, xpos, ufeats].
        """

        def pack(x):
            return pack_padded_sequence(x, sentlens, batch_first=True)

        inputs = []
        if self.args['word_emb_dim'] > 0:
            word_emb = self.word_emb(word)
            word_emb = pack(word_emb)
            inputs += [word_emb]

        if self.args['pretrain']:
            pretrained_emb = self.pretrained_emb(pretrained)
            pretrained_emb = self.trans_pretrained(pretrained_emb)
            pretrained_emb = pack(pretrained_emb)
            inputs += [pretrained_emb]

        def pad(x):
            # re-wrap flat packed data using the batch sizes of the first input
            return pad_packed_sequence(PackedSequence(x, inputs[0].batch_sizes), batch_first=True)[0]

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                all_forward_chars = self.charmodel_forward.build_char_representation(text)
                assert isinstance(all_forward_chars, list)
                if self.charmodel_forward_transform is not None:
                    all_forward_chars = [self.charmodel_forward_transform(x) for x in all_forward_chars]
                all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))

                all_backward_chars = self.charmodel_backward.build_char_representation(text)
                if self.charmodel_backward_transform is not None:
                    all_backward_chars = [self.charmodel_backward_transform(x) for x in all_backward_chars]
                all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))

                inputs += [all_forward_chars, all_backward_chars]
            else:
                char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
                char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)
                inputs += [char_reps]

        if self.bert_model is not None:
            device = next(self.parameters()).device
            # embeddings are detached unless we are finetuning the transformer
            processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=False,
                                                     num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                     detach=not self.args.get('bert_finetune', False) or not self.training,
                                                     peft_name=self.peft_name)

            if self.bert_layer_mix is not None:
                # add the average so that the default behavior is to
                # take an average of the N layers, and anything else
                # other than that needs to be learned
                # TODO: refactor this
                processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]

            processed_bert = pad_sequence(processed_bert, batch_first=True)
            inputs += [pack(processed_bert)]

        # concatenate all packed features along the feature dimension,
        # apply word dropout (replacing whole words) then regular dropout
        lstm_inputs = torch.cat([x.data for x in inputs], 1)
        lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
        lstm_inputs = self.drop(lstm_inputs)
        lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)

        lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(self.taggerlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.taggerlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))
        lstm_outputs = lstm_outputs.data

        upos_hid = F.relu(self.upos_hid(self.drop(lstm_outputs)))
        upos_pred = self.upos_clf(self.drop(upos_hid))

        preds = [pad(upos_pred).max(2)[1]]

        if upos is not None:
            upos = pack(upos).data
            loss = self.crit(upos_pred.view(-1, upos_pred.size(-1)), upos.view(-1))
        else:
            loss = 0.0

        if self.share_hid:
            xpos_hid = upos_hid
            ufeats_hid = upos_hid

            clffunc = lambda clf, hid: clf(self.drop(hid))
        else:
            xpos_hid = F.relu(self.xpos_hid(self.drop(lstm_outputs)))
            ufeats_hid = F.relu(self.ufeats_hid(self.drop(lstm_outputs)))

            # condition xpos/ufeats on gold upos at training time,
            # on predicted upos otherwise
            if self.training and upos is not None:
                upos_emb = self.upos_emb(upos)
            else:
                upos_emb = self.upos_emb(upos_pred.max(1)[1])

            clffunc = lambda clf, hid: clf(self.drop(hid), self.drop(upos_emb))

        if xpos is not None: xpos = pack(xpos).data
        if isinstance(self.vocab['xpos'], CompositeVocab):
            # composite xpos: predict each component separately
            xpos_preds = []
            for i in range(len(self.vocab['xpos'])):
                xpos_pred = clffunc(self.xpos_clf[i], xpos_hid)
                if xpos is not None:
                    loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos[:, i].view(-1))
                xpos_preds.append(pad(xpos_pred).max(2, keepdim=True)[1])
            preds.append(torch.cat(xpos_preds, 2))
        else:
            xpos_pred = clffunc(self.xpos_clf, xpos_hid)
            if xpos is not None:
                loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos.view(-1))
            preds.append(pad(xpos_pred).max(2)[1])

        ufeats_preds = []
        if ufeats is not None: ufeats = pack(ufeats).data
        for i in range(len(self.vocab['feats'])):
            ufeats_pred = clffunc(self.ufeats_clf[i], ufeats_hid)
            if ufeats is not None:
                loss += self.crit(ufeats_pred.view(-1, ufeats_pred.size(-1)), ufeats[:, i].view(-1))
            ufeats_preds.append(pad(ufeats_pred).max(2, keepdim=True)[1])
        preds.append(torch.cat(ufeats_preds, 2))

        return loss, preds
diff --git a/stanza/stanza/models/pos/trainer.py b/stanza/stanza/models/pos/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ec8fbf161152a73fb1bcdb2542bf8d7255685d5
--- /dev/null
+++ b/stanza/stanza/models/pos/trainer.py
@@ -0,0 +1,179 @@
+"""
+A trainer class to handle training and testing of models.
+"""
+
+import sys
+import logging
+import torch
+from torch import nn
+
+from stanza.models.common.trainer import Trainer as BaseTrainer
+from stanza.models.common import utils, loss
+from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
+from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
+from stanza.models.pos.model import Tagger
+from stanza.models.pos.vocab import MultiVocab
+
+logger = logging.getLogger('stanza')
+
def unpack_batch(batch, device):
    """Unpack a batch from the data loader.

    The first 8 entries are (possibly None) tensors, which are moved to
    the given device; the remaining entries are bookkeeping values
    (sort indices, lengths, raw text) passed through unchanged.
    """
    inputs = [item if item is None else item.to(device) for item in batch[:8]]
    orig_idx, word_orig_idx, sentlens, wordlens, text = batch[8:13]
    return inputs, orig_idx, word_orig_idx, sentlens, wordlens, text
+
class Trainer(BaseTrainer):
    """ A trainer for training models.

    Wraps a pos Tagger together with its optimizers and (optional) LR
    schedulers, and knows how to save/load the whole bundle, including an
    optionally finetuned or peft-adapted transformer.
    """
    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, foundation_cache=None):
        """Either load a saved tagger from model_file or build one from args/vocab.

        args, vocab, and pretrain are only used to build a new model; when
        model_file is given, config and vocab come from the checkpoint
        (self.load merges any extra args over the saved config).
        """
        if model_file is not None:
            # load everything from file
            self.load(model_file, pretrain, args=args, foundation_cache=foundation_cache)
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab

            bert_model, bert_tokenizer = load_bert(self.args['bert_model'])
            peft_name = None
            if self.args['use_peft']:
                # fine tune the bert if we're using peft
                self.args['bert_finetune'] = True
                peft_name = "pos"
                bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)

            self.model = Tagger(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, share_hid=args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)

        self.model = self.model.to(device)
        # split optimizers: the transformer (if finetuned) gets its own
        # optimizer with its own learning rate
        # NOTE(review): the key used here is "peft", but everywhere else in
        # this file the flag is "use_peft" -- confirm which key
        # get_split_optimizer actually expects
        self.optimizers = utils.get_split_optimizer(self.args['optim'], self.model, self.args['lr'], betas=(0.9, self.args['beta2']), eps=1e-6, weight_decay=self.args.get('initial_weight_decay', None), bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get("peft", False))

        self.schedulers = {}

        if self.args.get('bert_finetune', None):
            # delayed import: transformers is only needed when finetuning
            import transformers
            warmup_scheduler = transformers.get_linear_schedule_with_warmup(
                self.optimizers["bert_optimizer"],
                # todo late starting?
                0, self.args["max_steps"])
            self.schedulers["bert_scheduler"] = warmup_scheduler

    def update(self, batch, eval=False):
        """Run one batch through the model; backprop and step unless eval.

        Returns the batch loss as a python float, or 0.0 if the model
        produced no loss for this batch (in which case nothing is updated).
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            for optimizer in self.optimizers.values():
                optimizer.zero_grad()
        loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)
        if loss == 0.0:
            # model returned no loss (e.g. no gold tags); nothing to backprop
            return loss

        loss_val = loss.data.item()
        if eval:
            return loss_val

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])

        for optimizer in self.optimizers.values():
            optimizer.step()
        for scheduler in self.schedulers.values():
            scheduler.step()
        return loss_val

    def predict(self, batch, unsort=True):
        """Predict upos/xpos/feats strings for one batch.

        Returns, per sentence, a list of [upos, xpos, feats] triples
        (one per word), unsorted back to the original sentence order
        unless unsort=False.
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs

        self.model.eval()
        batch_size = word.size(0)
        _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)
        # preds[0]: upos ids, preds[1]: xpos ids, preds[2]: feats ids
        upos_seqs = [self.vocab['upos'].unmap(sent) for sent in preds[0].tolist()]
        xpos_seqs = [self.vocab['xpos'].unmap(sent) for sent in preds[1].tolist()]
        feats_seqs = [self.vocab['feats'].unmap(sent) for sent in preds[2].tolist()]

        pred_tokens = [[[upos_seqs[i][j], xpos_seqs[i][j], feats_seqs[i][j]] for j in range(sentlens[i])] for i in range(batch_size)]
        if unsort:
            pred_tokens = utils.unsort(pred_tokens, orig_idx)
        return pred_tokens

    def save(self, filename, skip_modules=True):
        """Save model weights, vocab, and config (plus peft weights if any)."""
        model_state = self.model.state_dict()
        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
        if skip_modules:
            skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
            for k in skipped:
                del model_state[k]
        params = {
                'model': model_state,
                'vocab': self.vocab.state_dict(),
                'config': self.args
                }
        if self.args.get('use_peft', False):
            # Hide import so that peft dependency is optional
            from peft import get_peft_model_state_dict
            params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)

        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception as e:
            # best effort: a failed save should not kill a long training run
            logger.warning(f"Saving failed... {e} continuing anyway.")

    def load(self, filename, pretrain, args=None, foundation_cache=None):
        """
        Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input,
        and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        # caller-supplied args override the saved config
        if args is not None: self.args.update(args)

        # preserve old models which were created before transformers were added
        if 'bert_model' not in self.args:
            self.args['bert_model'] = None

        lora_weights = checkpoint.get('bert_lora')
        if lora_weights:
            logger.debug("Found peft weights for POS; loading a peft adapter")
            self.args["use_peft"] = True

        # TODO: refactor this common block of code with NER
        force_bert_saved = False
        peft_name = None
        if self.args.get('use_peft', False):
            force_bert_saved = True
            bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "pos", foundation_cache)
            bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)
            logger.debug("Loaded peft with name %s", peft_name)
        else:
            if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
                logger.debug("Model %s has a finetuned transformer.  Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
                foundation_cache = NoTransformerFoundationCache(foundation_cache)
                force_bert_saved = True
            bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)

        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
        # load model
        emb_matrix = None
        if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None
            emb_matrix = pretrain.emb
        # NOTE(review): this finetuned-transformer check repeats the one in
        # the else branch above; it looks redundant here -- confirm
        if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
            logger.debug("Model %s has a finetuned transformer.  Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
            foundation_cache = NoTransformerFoundationCache(foundation_cache)
        self.model = Tagger(self.args, self.vocab, emb_matrix=emb_matrix, share_hid=self.args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)
        # strict=False so modules skipped at save time (e.g. pretrained
        # embeddings) do not cause a load failure
        self.model.load_state_dict(checkpoint['model'], strict=False)
diff --git a/stanza/stanza/models/pos/xpos_vocab_factory.py b/stanza/stanza/models/pos/xpos_vocab_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..864e5abda5ebaaf81817ad183487f34920033660
--- /dev/null
+++ b/stanza/stanza/models/pos/xpos_vocab_factory.py
@@ -0,0 +1,200 @@
+# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.
+# Please don't edit it!
+
+import logging
+
+from stanza.models.pos.vocab import WordVocab, XPOSVocab
+from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory
+
+# using a sublogger makes it easier to test in the unittests
+logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')
+
+XPOS_DESCRIPTIONS = {
+ 'af_afribooms' : XPOSDescription(XPOSType.XPOS, ''),
+ 'ar_padt' : XPOSDescription(XPOSType.XPOS, ''),
+ 'bg_btb' : XPOSDescription(XPOSType.XPOS, ''),
+ 'ca_ancora' : XPOSDescription(XPOSType.XPOS, ''),
+ 'cs_cac' : XPOSDescription(XPOSType.XPOS, ''),
+ 'cs_cltt' : XPOSDescription(XPOSType.XPOS, ''),
+ 'cs_fictree' : XPOSDescription(XPOSType.XPOS, ''),
+ 'cs_pdt' : XPOSDescription(XPOSType.XPOS, ''),
+ 'en_partut' : XPOSDescription(XPOSType.XPOS, ''),
+ 'es_ancora' : XPOSDescription(XPOSType.XPOS, ''),
+ 'es_combined' : XPOSDescription(XPOSType.XPOS, ''),
+ 'fr_partut' : XPOSDescription(XPOSType.XPOS, ''),
+ 'gd_arcosg' : XPOSDescription(XPOSType.XPOS, ''),
+ 'gl_ctg' : XPOSDescription(XPOSType.XPOS, ''),
+ 'gl_treegal' : XPOSDescription(XPOSType.XPOS, ''),
+ 'grc_perseus' : XPOSDescription(XPOSType.XPOS, ''),
+ 'hr_set' : XPOSDescription(XPOSType.XPOS, ''),
+ 'is_gc' : XPOSDescription(XPOSType.XPOS, ''),
+ 'is_icepahc' : XPOSDescription(XPOSType.XPOS, ''),
+ 'is_modern' : XPOSDescription(XPOSType.XPOS, ''),
+ 'it_combined' : XPOSDescription(XPOSType.XPOS, ''),
+ 'it_isdt' : XPOSDescription(XPOSType.XPOS, ''),
+ 'it_markit' : XPOSDescription(XPOSType.XPOS, ''),
+ 'it_parlamint' : XPOSDescription(XPOSType.XPOS, ''),
+ 'it_partut' : XPOSDescription(XPOSType.XPOS, ''),
+ 'it_postwita' : XPOSDescription(XPOSType.XPOS, ''),
+ 'it_twittiro' : XPOSDescription(XPOSType.XPOS, ''),
+ 'it_vit' : XPOSDescription(XPOSType.XPOS, ''),
+ 'la_perseus' : XPOSDescription(XPOSType.XPOS, ''),
+ 'la_udante' : XPOSDescription(XPOSType.XPOS, ''),
+ 'lt_alksnis' : XPOSDescription(XPOSType.XPOS, ''),
+ 'lv_lvtb' : XPOSDescription(XPOSType.XPOS, ''),
+ 'ro_nonstandard' : XPOSDescription(XPOSType.XPOS, ''),
+ 'ro_rrt' : XPOSDescription(XPOSType.XPOS, ''),
+ 'ro_simonero' : XPOSDescription(XPOSType.XPOS, ''),
+ 'sk_snk' : XPOSDescription(XPOSType.XPOS, ''),
+ 'sl_ssj' : XPOSDescription(XPOSType.XPOS, ''),
+ 'sl_sst' : XPOSDescription(XPOSType.XPOS, ''),
+ 'sr_set' : XPOSDescription(XPOSType.XPOS, ''),
+ 'ta_ttb' : XPOSDescription(XPOSType.XPOS, ''),
+ 'uk_iu' : XPOSDescription(XPOSType.XPOS, ''),
+
+ 'be_hse' : XPOSDescription(XPOSType.WORD, None),
+ 'bxr_bdt' : XPOSDescription(XPOSType.WORD, None),
+ 'cop_scriptorium': XPOSDescription(XPOSType.WORD, None),
+ 'cu_proiel' : XPOSDescription(XPOSType.WORD, None),
+ 'cy_ccg' : XPOSDescription(XPOSType.WORD, None),
+ 'da_ddt' : XPOSDescription(XPOSType.WORD, None),
+ 'de_gsd' : XPOSDescription(XPOSType.WORD, None),
+ 'de_hdt' : XPOSDescription(XPOSType.WORD, None),
+ 'el_gdt' : XPOSDescription(XPOSType.WORD, None),
+ 'el_gud' : XPOSDescription(XPOSType.WORD, None),
+ 'en_atis' : XPOSDescription(XPOSType.WORD, None),
+ 'en_combined' : XPOSDescription(XPOSType.WORD, None),
+ 'en_craft' : XPOSDescription(XPOSType.WORD, None),
+ 'en_eslspok' : XPOSDescription(XPOSType.WORD, None),
+ 'en_ewt' : XPOSDescription(XPOSType.WORD, None),
+ 'en_genia' : XPOSDescription(XPOSType.WORD, None),
+ 'en_gum' : XPOSDescription(XPOSType.WORD, None),
+ 'en_gumreddit' : XPOSDescription(XPOSType.WORD, None),
+ 'en_mimic' : XPOSDescription(XPOSType.WORD, None),
+ 'en_test' : XPOSDescription(XPOSType.WORD, None),
+ 'es_gsd' : XPOSDescription(XPOSType.WORD, None),
+ 'et_edt' : XPOSDescription(XPOSType.WORD, None),
+ 'et_ewt' : XPOSDescription(XPOSType.WORD, None),
+ 'eu_bdt' : XPOSDescription(XPOSType.WORD, None),
+ 'fa_perdt' : XPOSDescription(XPOSType.WORD, None),
+ 'fa_seraji' : XPOSDescription(XPOSType.WORD, None),
+ 'fi_tdt' : XPOSDescription(XPOSType.WORD, None),
+ 'fr_combined' : XPOSDescription(XPOSType.WORD, None),
+ 'fr_gsd' : XPOSDescription(XPOSType.WORD, None),
+ 'fr_parisstories': XPOSDescription(XPOSType.WORD, None),
+ 'fr_rhapsodie' : XPOSDescription(XPOSType.WORD, None),
+ 'fr_sequoia' : XPOSDescription(XPOSType.WORD, None),
+ 'fro_profiterole': XPOSDescription(XPOSType.WORD, None),
+ 'ga_idt' : XPOSDescription(XPOSType.WORD, None),
+ 'ga_twittirish' : XPOSDescription(XPOSType.WORD, None),
+ 'got_proiel' : XPOSDescription(XPOSType.WORD, None),
+ 'grc_proiel' : XPOSDescription(XPOSType.WORD, None),
+ 'grc_ptnk' : XPOSDescription(XPOSType.WORD, None),
+ 'gv_cadhan' : XPOSDescription(XPOSType.WORD, None),
+ 'hbo_ptnk' : XPOSDescription(XPOSType.WORD, None),
+ 'he_combined' : XPOSDescription(XPOSType.WORD, None),
+ 'he_htb' : XPOSDescription(XPOSType.WORD, None),
+ 'he_iahltknesset': XPOSDescription(XPOSType.WORD, None),
+ 'he_iahltwiki' : XPOSDescription(XPOSType.WORD, None),
+ 'hi_hdtb' : XPOSDescription(XPOSType.WORD, None),
+ 'hsb_ufal' : XPOSDescription(XPOSType.WORD, None),
+ 'hu_szeged' : XPOSDescription(XPOSType.WORD, None),
+ 'hy_armtdp' : XPOSDescription(XPOSType.WORD, None),
+ 'hy_bsut' : XPOSDescription(XPOSType.WORD, None),
+ 'hyw_armtdp' : XPOSDescription(XPOSType.WORD, None),
+ 'id_csui' : XPOSDescription(XPOSType.WORD, None),
+ 'it_old' : XPOSDescription(XPOSType.WORD, None),
+ 'ka_glc' : XPOSDescription(XPOSType.WORD, None),
+ 'kk_ktb' : XPOSDescription(XPOSType.WORD, None),
+ 'kmr_mg' : XPOSDescription(XPOSType.WORD, None),
+ 'kpv_lattice' : XPOSDescription(XPOSType.WORD, None),
+ 'ky_ktmu' : XPOSDescription(XPOSType.WORD, None),
+ 'la_proiel' : XPOSDescription(XPOSType.WORD, None),
+ 'lij_glt' : XPOSDescription(XPOSType.WORD, None),
+ 'lt_hse' : XPOSDescription(XPOSType.WORD, None),
+ 'lzh_kyoto' : XPOSDescription(XPOSType.WORD, None),
+ 'mr_ufal' : XPOSDescription(XPOSType.WORD, None),
+ 'mt_mudt' : XPOSDescription(XPOSType.WORD, None),
+ 'myv_jr' : XPOSDescription(XPOSType.WORD, None),
+ 'nb_bokmaal' : XPOSDescription(XPOSType.WORD, None),
+ 'nds_lsdc' : XPOSDescription(XPOSType.WORD, None),
+ 'nn_nynorsk' : XPOSDescription(XPOSType.WORD, None),
+ 'nn_nynorsklia' : XPOSDescription(XPOSType.WORD, None),
+ 'no_bokmaal' : XPOSDescription(XPOSType.WORD, None),
+ 'orv_birchbark' : XPOSDescription(XPOSType.WORD, None),
+ 'orv_rnc' : XPOSDescription(XPOSType.WORD, None),
+ 'orv_torot' : XPOSDescription(XPOSType.WORD, None),
+ 'ota_boun' : XPOSDescription(XPOSType.WORD, None),
+ 'pcm_nsc' : XPOSDescription(XPOSType.WORD, None),
+ 'pt_bosque' : XPOSDescription(XPOSType.WORD, None),
+ 'pt_cintil' : XPOSDescription(XPOSType.WORD, None),
+ 'pt_dantestocks' : XPOSDescription(XPOSType.WORD, None),
+ 'pt_gsd' : XPOSDescription(XPOSType.WORD, None),
+ 'pt_petrogold' : XPOSDescription(XPOSType.WORD, None),
+ 'pt_porttinari' : XPOSDescription(XPOSType.WORD, None),
+ 'qpm_philotis' : XPOSDescription(XPOSType.WORD, None),
+ 'qtd_sagt' : XPOSDescription(XPOSType.WORD, None),
+ 'ru_gsd' : XPOSDescription(XPOSType.WORD, None),
+ 'ru_poetry' : XPOSDescription(XPOSType.WORD, None),
+ 'ru_syntagrus' : XPOSDescription(XPOSType.WORD, None),
+ 'ru_taiga' : XPOSDescription(XPOSType.WORD, None),
+ 'sa_vedic' : XPOSDescription(XPOSType.WORD, None),
+ 'sme_giella' : XPOSDescription(XPOSType.WORD, None),
+ 'swl_sslc' : XPOSDescription(XPOSType.WORD, None),
+ 'sq_staf' : XPOSDescription(XPOSType.WORD, None),
+ 'te_mtg' : XPOSDescription(XPOSType.WORD, None),
+ 'tr_atis' : XPOSDescription(XPOSType.WORD, None),
+ 'tr_boun' : XPOSDescription(XPOSType.WORD, None),
+ 'tr_framenet' : XPOSDescription(XPOSType.WORD, None),
+ 'tr_imst' : XPOSDescription(XPOSType.WORD, None),
+ 'tr_kenet' : XPOSDescription(XPOSType.WORD, None),
+ 'tr_penn' : XPOSDescription(XPOSType.WORD, None),
+ 'tr_tourism' : XPOSDescription(XPOSType.WORD, None),
+ 'ug_udt' : XPOSDescription(XPOSType.WORD, None),
+ 'uk_parlamint' : XPOSDescription(XPOSType.WORD, None),
+ 'vi_vtb' : XPOSDescription(XPOSType.WORD, None),
+ 'wo_wtb' : XPOSDescription(XPOSType.WORD, None),
+ 'xcl_caval' : XPOSDescription(XPOSType.WORD, None),
+ 'zh-hans_gsdsimp': XPOSDescription(XPOSType.WORD, None),
+ 'zh-hant_gsd' : XPOSDescription(XPOSType.WORD, None),
+ 'zh_gsdsimp' : XPOSDescription(XPOSType.WORD, None),
+
+ 'en_lines' : XPOSDescription(XPOSType.XPOS, '-'),
+ 'fo_farpahc' : XPOSDescription(XPOSType.XPOS, '-'),
+ 'ja_gsd' : XPOSDescription(XPOSType.XPOS, '-'),
+ 'ja_gsdluw' : XPOSDescription(XPOSType.XPOS, '-'),
+ 'sv_lines' : XPOSDescription(XPOSType.XPOS, '-'),
+ 'ur_udtb' : XPOSDescription(XPOSType.XPOS, '-'),
+
+ 'fi_ftb' : XPOSDescription(XPOSType.XPOS, ','),
+ 'orv_ruthenian' : XPOSDescription(XPOSType.XPOS, ','),
+
+ 'id_gsd' : XPOSDescription(XPOSType.XPOS, '+'),
+ 'ko_gsd' : XPOSDescription(XPOSType.XPOS, '+'),
+ 'ko_kaist' : XPOSDescription(XPOSType.XPOS, '+'),
+ 'ko_ksl' : XPOSDescription(XPOSType.XPOS, '+'),
+ 'qaf_arabizi' : XPOSDescription(XPOSType.XPOS, '+'),
+
+ 'la_ittb' : XPOSDescription(XPOSType.XPOS, '|'),
+ 'la_llct' : XPOSDescription(XPOSType.XPOS, '|'),
+ 'nl_alpino' : XPOSDescription(XPOSType.XPOS, '|'),
+ 'nl_lassysmall' : XPOSDescription(XPOSType.XPOS, '|'),
+ 'sv_talbanken' : XPOSDescription(XPOSType.XPOS, '|'),
+
+ 'pl_lfg' : XPOSDescription(XPOSType.XPOS, ':'),
+ 'pl_pdb' : XPOSDescription(XPOSType.XPOS, ':'),
+}
+
def xpos_vocab_factory(data, shorthand):
    """Build the xpos vocab for the given dataset.

    The description is always recomputed from the data itself; for datasets
    listed in XPOS_DESCRIPTIONS, this lets us detect (and log) when the
    dataset's xpos tagset has changed since this file was generated.
    """
    if shorthand not in XPOS_DESCRIPTIONS:
        logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand)
    # BUG FIX: desc was previously only assigned inside the branch above,
    # so every KNOWN shorthand raised NameError on the comparison below.
    # Compute it unconditionally.
    desc = choose_simplest_factory(data, shorthand)
    if shorthand in XPOS_DESCRIPTIONS:
        if XPOS_DESCRIPTIONS[shorthand] != desc:
            # log instead of throw
            # otherwise, updating datasets would be unpleasant
            logger.error("XPOS tagset in %s has apparently changed! Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)
    else:
        logger.warning("Chose %s for the xpos factory for %s", desc, shorthand)
    return build_xpos_vocab(desc, data, shorthand)
+
diff --git a/stanza/stanza/models/pos/xpos_vocab_utils.py b/stanza/stanza/models/pos/xpos_vocab_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8cd79501bb365b04c322d8fd9e631d7a69c3297
--- /dev/null
+++ b/stanza/stanza/models/pos/xpos_vocab_utils.py
@@ -0,0 +1,48 @@
+from collections import namedtuple
+from enum import Enum
+import logging
+import os
+
+from stanza.models.common.vocab import VOCAB_PREFIX
+from stanza.models.pos.vocab import XPOSVocab, WordVocab
+
class XPOSType(Enum):
    """How a dataset's xpos tags are modeled (see build_xpos_vocab)."""
    XPOS = 1   # composite tags split on a separator -> XPOSVocab
    WORD = 2   # atomic tags, one unit each -> WordVocab
+
# (xpos_type, sep): which vocab class to use and, for XPOSType.XPOS,
# the separator character used to split composite tags
XPOSDescription = namedtuple('XPOSDescription', ['xpos_type', 'sep'])
# fallback description: treat each xpos tag as an atomic word
DEFAULT_KEY = XPOSDescription(XPOSType.WORD, None)
+
+logger = logging.getLogger('stanza')
+
def filter_data(data, idx):
    """Keep only the sentences in which every token has a value at column idx.

    A sentence with any token whose idx-th field is None is dropped entirely;
    empty sentences are kept (vacuously complete).
    """
    return [sentence for sentence in data
            if all(token[idx] is not None for token in sentence)]
+
def choose_simplest_factory(data, shorthand):
    """Pick the xpos description yielding the smallest vocabulary.

    Starts from a flat WordVocab over column 2; if that vocab is large
    (more than 20 non-prefix entries), also tries splitting the tags on
    each candidate separator and keeps whichever scheme has the fewest
    total units.
    """
    logger.info(f'Original length = {len(data)}')
    data = filter_data(data, idx=2)
    logger.info(f'Filtered length = {len(data)}')
    best_desc = DEFAULT_KEY
    word_vocab = WordVocab(data, shorthand, idx=2, ignore=["_"])
    best_size = len(word_vocab) - len(VOCAB_PREFIX)
    if best_size > 20:
        for candidate_sep in ['', '-', '+', '|', ',', ':']: # separators
            candidate = XPOSVocab(data, shorthand, idx=2, sep=candidate_sep)
            candidate_size = sum(len(units) - len(VOCAB_PREFIX) for units in candidate._id2unit.values())
            if candidate_size < best_size:
                best_desc = XPOSDescription(XPOSType.XPOS, candidate_sep)
                best_size = candidate_size
    return best_desc
+
def build_xpos_vocab(description, data, shorthand):
    """Construct the vocab named by an XPOSDescription over column 2 of data."""
    if description.xpos_type is XPOSType.XPOS:
        return XPOSVocab(data, shorthand, idx=2, sep=description.sep)
    return WordVocab(data, shorthand, idx=2, ignore=["_"])
diff --git a/stanza/stanza/models/tokenization/__init__.py b/stanza/stanza/models/tokenization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/stanza/stanza/models/tokenization/data.py b/stanza/stanza/models/tokenization/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff919b0ba96cef77d85f145de2263f4e078a67e
--- /dev/null
+++ b/stanza/stanza/models/tokenization/data.py
@@ -0,0 +1,432 @@
+from bisect import bisect_right
+from copy import copy
+import numpy as np
+import random
+import logging
+import re
+import torch
+from torch.utils.data import Dataset
+from .vocab import Vocab
+
+from stanza.models.common.utils import sort_with_indices, unsort
+
+logger = logging.getLogger('stanza')
+
def filter_consecutive_whitespaces(para):
    """Collapse runs of spaces in a paragraph of (char, label) pairs.

    Drops any pair whose character is a space immediately following another
    space in the original sequence, so a run of spaces keeps only its first.
    """
    return [(char, label)
            for i, (char, label) in enumerate(para)
            if not (char == ' ' and i > 0 and para[i-1][0] == ' ')]
+
+NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
+# this was (r'^([\d]+[,\.]*)+$')
+# but the runtime on that can explode exponentially
+# for example, on 111111111111111111111111a
+NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$')
+WHITESPACE_RE = re.compile(r'\s')
+
class TokenizationDataset:
    """Holds a tokenizer's (character, label) data for training or eval.

    The raw text is split into paragraphs on blank lines; each paragraph
    becomes a list of (char, label) pairs.  Labels come from a parallel
    label file when one is given, otherwise every character gets label 0.
    """
    # NOTE(review): the mutable default dict below is shared across calls;
    # it is only ever read here, so this is safe, but a None default would
    # be more conventional
    def __init__(self, tokenizer_args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs):
        super().__init__(*args, **kwargs) # forwards all unused arguments
        self.args = tokenizer_args
        self.eval = evaluation
        self.dictionary = dictionary
        self.vocab = vocab

        # get input files
        txt_file = input_files['txt']
        label_file = input_files['label']

        # Load data and process it
        # set up text from file or input string
        assert txt_file is not None or input_text is not None
        if input_text is None:
            with open(txt_file) as f:
                text = ''.join(f.readlines()).rstrip()
        else:
            text = input_text

        # paragraphs are separated by blank lines
        text_chunks = NEWLINE_WHITESPACE_RE.split(text)
        text_chunks = [pt.rstrip() for pt in text_chunks]
        text_chunks = [pt for pt in text_chunks if pt]
        if label_file is not None:
            with open(label_file) as f:
                labels = ''.join(f.readlines()).rstrip()
            labels = NEWLINE_WHITESPACE_RE.split(labels)
            labels = [pt.rstrip() for pt in labels]
            # one int label per character; these map objects are consumed
            # exactly once by the zip in the comprehension below
            labels = [map(int, pt) for pt in labels if pt]
        else:
            # no label file: every character is labeled 0
            labels = [[0 for _ in pt] for pt in text_chunks]

        skip_newline = self.args.get('skip_newline', False)
        self.data = [[(WHITESPACE_RE.sub(' ', char), label) # substitute special whitespaces
                      for char, label in zip(pt, pc) if not (skip_newline and char == '\n')] # check if newline needs to be eaten
                     for pt, pc in zip(text_chunks, labels)]

        # remove consecutive whitespaces
        self.data = [filter_consecutive_whitespaces(x) for x in self.data]

    def labels(self):
        """
        Returns a list of the labels for all of the sentences in this DataLoader

        Used at eval time to compare to the results, for example
        """
        return [np.array(list(x[1] for x in sent)) for sent in self.data]

    def extract_dict_feat(self, para, idx):
        """
        This function is to extract dictionary features for each character

        For the character at para[idx], builds num_dict_feat forward flags
        (is the growing forward string a dictionary word?) and
        num_dict_feat backward flags (same, extending to the left),
        stopping early once no dictionary prefix/suffix can match.
        """
        length = len(para)

        dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])]
        dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])]
        # NOTE(review): the seed character is not lowercased while the
        # extension characters below are -- confirm this is intended
        forward_word = para[idx][0]
        backward_word = para[idx][0]
        prefix = True
        suffix = True
        for window in range(1,self.args['num_dict_feat']+1):
            # concatenate each character and check if words found in dict not, stop if prefix not found
            #check if idx+t is out of bound and if the prefix is already not found
            if (idx + window) <= length-1 and prefix:
                forward_word += para[idx+window][0].lower()
                #check in json file if the word is present as prefix or word or None.
                feat = 1 if forward_word in self.dictionary["words"] else 0
                #if the return value is not 2 or 3 then the checking word is not a valid word in dict.
                dict_forward_feats[window-1] = feat
                #if the dict return 0 means no prefixes found, thus, stop looking for forward.
                if forward_word not in self.dictionary["prefixes"]:
                    prefix = False
            #backward check: similar to forward
            if (idx - window) >= 0 and suffix:
                backward_word = para[idx-window][0].lower() + backward_word
                feat = 1 if backward_word in self.dictionary["words"] else 0
                dict_backward_feats[window-1] = feat
                if backward_word not in self.dictionary["suffixes"]:
                    suffix = False
            #if cannot find both prefix and suffix, then exit the loop
            if not prefix and not suffix:
                break

        return dict_forward_feats + dict_backward_feats

    def para_to_sentences(self, para):
        """ Convert a paragraph to a list of processed sentences.

        Each processed sentence is a tuple of
        (unit id array, label array, feature array, raw unit list).
        At training time, sentences are split on end-of-sentence labels
        (2 or 4) and overly long sentences are dropped; at eval time the
        whole paragraph is kept together.
        """
        res = []
        funcs = []
        # build the per-unit feature functions requested by the config
        for feat_func in self.args['feat_funcs']:
            if feat_func == 'end_of_para' or feat_func == 'start_of_para':
                # skip for position-dependent features
                continue
            if feat_func == 'space_before':
                func = lambda x: 1 if x.startswith(' ') else 0
            elif feat_func == 'capitalized':
                func = lambda x: 1 if x[0].isupper() else 0
            elif feat_func == 'numeric':
                func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
            else:
                raise ValueError('Feature function "{}" is undefined.'.format(feat_func))

            funcs.append(func)

        # stacking all featurize functions
        composite_func = lambda x: [f(x) for f in funcs]

        def process_sentence(sent_units, sent_labels, sent_feats):
            # freeze one sentence's accumulated units/labels/features
            return (np.array([self.vocab.unit2id(y) for y in sent_units]),
                    np.array(sent_labels),
                    np.array(sent_feats),
                    list(sent_units))

        use_end_of_para = 'end_of_para' in self.args['feat_funcs']
        use_start_of_para = 'start_of_para' in self.args['feat_funcs']
        use_dictionary = self.args['use_dictionary']
        current_units = []
        current_labels = []
        current_feats = []
        for i, (unit, label) in enumerate(para):
            feats = composite_func(unit)
            # position-dependent features
            if use_end_of_para:
                f = 1 if i == len(para)-1 else 0
                feats.append(f)
            if use_start_of_para:
                f = 1 if i == 0 else 0
                feats.append(f)

            #if dictionary feature is selected
            if use_dictionary:
                dict_feats = self.extract_dict_feat(para, i)
                feats = feats + dict_feats

            current_units.append(unit)
            current_labels.append(label)
            current_feats.append(feats)
            if not self.eval and (label == 2 or label == 4): # end of sentence
                if len(current_units) <= self.args['max_seqlen']:
                    # get rid of sentences that are too long during training of the tokenizer
                    res.append(process_sentence(current_units, current_labels, current_feats))
                current_units.clear()
                current_labels.clear()
                current_feats.clear()

        # flush whatever remains after the last sentence break
        if len(current_units) > 0:
            if self.eval or len(current_units) <= self.args['max_seqlen']:
                res.append(process_sentence(current_units, current_labels, current_feats))

        return res

    def advance_old_batch(self, eval_offsets, old_batch):
        """
        Advance to a new position in a batch where we have partially processed the batch

        If we have previously built a batch of data and made predictions on them, then when we are trying to make
        prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
        and just (essentially) advance the indices/offsets from where we read converted data in this old batch.
        In this case, eval_offsets index within the old_batch to advance the strings to process.
        """
        # NOTE(review): unkid is unused below; also, unit2id('') here (and
        # the '' padding string further down) looks like it may originally
        # have been '<UNK>' / '<PAD>' tokens -- confirm against the Vocab
        # implementation
        unkid = self.vocab.unit2id('')
        padid = self.vocab.unit2id('')

        ounits, olabels, ofeatures, oraw = old_batch
        feat_size = ofeatures.shape[-1]
        # actual (unpadded) length of each row in the old batch
        lens = (ounits != padid).sum(1).tolist()
        pad_len = max(l-i for i, l in zip(eval_offsets, lens))

        units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32)
        raw_units = []

        # shift each row left by its offset, re-padding on the right
        for i in range(len(ounits)):
            eval_offsets[i] = min(eval_offsets[i], lens[i])
            units[i, :(lens[i] - eval_offsets[i])] = ounits[i, eval_offsets[i]:lens[i]]
            labels[i, :(lens[i] - eval_offsets[i])] = olabels[i, eval_offsets[i]:lens[i]]
            features[i, :(lens[i] - eval_offsets[i])] = ofeatures[i, eval_offsets[i]:lens[i]]
            raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + [''] * (pad_len - lens[i] + eval_offsets[i]))

        return units, labels, features, raw_units
+
+class DataLoader(TokenizationDataset):
+ """
+ This is the training version of the dataset.
+ """
    def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None):
        """Load the data (via TokenizationDataset) and convert paragraphs to sentences.

        If no vocab is supplied, a fresh one is built from the loaded data.
        """
        super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)

        self.vocab = vocab if vocab is not None else self.init_vocab()

        # data comes in a list of paragraphs, where each paragraph is a list of units with unit-level labels.
        # At evaluation time, each paragraph is treated as single "sentence" as we don't know a priori where
        # sentence breaks occur. We make prediction from left to right for each paragraph and move forward to
        # the last predicted sentence break to start afresh.
        self.sentences = [self.para_to_sentences(para) for para in self.data]

        self.init_sent_ids()
        logger.debug(f"{len(self.sentence_ids)} sentences loaded.")
+
    def __len__(self):
        """Number of sentences across all paragraphs."""
        return len(self.sentence_ids)
+
    def init_vocab(self):
        """Build a fresh character Vocab from the loaded data."""
        vocab = Vocab(self.data, self.args['lang'])
        return vocab
+
+ def init_sent_ids(self):
+ self.sentence_ids = []
+ self.cumlen = [0]
+ for i, para in enumerate(self.sentences):
+ for j in range(len(para)):
+ self.sentence_ids += [(i, j)]
+ self.cumlen += [self.cumlen[-1] + len(self.sentences[i][j][0])]
+
+ def has_mwt(self):
+ # presumably this only needs to be called either 0 or 1 times,
+ # 1 when training and 0 any other time, so no effort is put
+ # into caching the result
+ for sentence in self.data:
+ for word in sentence:
+ if word[1] > 2:
+ return True
+ return False
+
    def shuffle(self):
        """Shuffle sentence order within each paragraph, then rebuild the indices."""
        for para in self.sentences:
            random.shuffle(para)
        self.init_sent_ids()
+
def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0):
    ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction.

    eval_offsets: character offsets into the flattened corpus.  When given,
      one batch item is built starting at each offset (prediction mode);
      when None, sentences are sampled at random (training mode).
    unit_dropout: probability of replacing a unit with UNK during training.
    feat_unit_dropout: probability of zeroing a unit's dictionary feature
      vector during training.

    Returns (units, labels, features, raw_units): three tensors shaped
    (batch, pad_len[, feat_size]) plus a list of padded raw unit strings.
    '''
    feat_size = len(self.sentences[0][0][2][0])
    # NOTE(review): the vocab keys below look like they were '<UNK>' and
    # '<PAD>' before markup stripping mangled this file -- confirm upstream.
    unkid = self.vocab.unit2id('')
    padid = self.vocab.unit2id('')

    def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
        # At eval time, this combines sentences in paragraph (indexed by id_pair[0]) starting sentence (indexed
        # by id_pair[1]) into a long string for evaluation. At training time, we just select random sentences
        # from the entire dataset until we reach max_seqlen.
        pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)
        # copy so the offset slicing below cannot mutate the stored data
        sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]

        # sentence/char dropout are training-time augmentations only
        drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0))
        drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))
        total_len = len(sentences[0][0])

        assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! {}'.format(self.args['max_seqlen'], total_len, ' '.join(["{}/{}".format(*x) for x in zip(self.sentences[pid][sid])]))
        if self.eval:
            # append the following sentences of the same paragraph
            for sid1 in range(sid+1, len(self.sentences[pid])):
                total_len += len(self.sentences[pid][sid1][0])
                sentences.append(self.sentences[pid][sid1])

                if total_len >= self.args['max_seqlen']:
                    break
        else:
            # append randomly chosen sentences until max_seqlen is filled
            while True:
                pid1, sid1 = random.choice(self.sentence_ids)
                total_len += len(self.sentences[pid1][sid1][0])
                sentences.append(self.sentences[pid1][sid1])

                if total_len >= self.args['max_seqlen']:
                    break

        if drop_sents and len(sentences) > 1:
            if total_len > self.args['max_seqlen']:
                # the last sentence would be truncated anyway; drop it first
                sentences = sentences[:-1]
            if len(sentences) > 1:
                p = [.5 ** i for i in range(1, len(sentences) + 1)] # drop a large number of sentences with smaller probability
                cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0]
                sentences = sentences[:cutoff+1]

        units = np.concatenate([s[0] for s in sentences])
        labels = np.concatenate([s[1] for s in sentences])
        feats = np.concatenate([s[2] for s in sentences])
        raw_units = [x for s in sentences for x in s[3]]

        if not self.eval:
            # truncate to exactly max_seqlen for training
            cutoff = self.args['max_seqlen']
            units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff]

        if drop_last_char: # can only happen in non-eval mode
            if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3):
                # training text ended with a sentence end position
                # and that word was a single character
                # and the previous character ended the word
                units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1]
                # word end -> sentence end, mwt end -> sentence mwt end
                labels[-1] = labels[-1] + 1

        return units, labels, feats, raw_units

    if eval_offsets is not None:
        # find max padding length
        pad_len = 0
        for eval_offset in eval_offsets:
            if eval_offset < self.cumlen[-1]:
                pair_id = bisect_right(self.cumlen, eval_offset) - 1
                pair = self.sentence_ids[pair_id]
                pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset-self.cumlen[pair_id])[0]))

        # +1 so every batch item ends with at least one pad unit
        pad_len += 1
        id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets]
        pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs]
        offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)]

        offsets_pairs = list(zip(offsets, pairs))
    else:
        # training: sample a random batch of sentence ids
        id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size']))
        offsets_pairs = [(0, x) for x in id_pairs]
        pad_len = self.args['max_seqlen']

    # put everything into padded and nicely shaped NumPy arrays and eventually convert to PyTorch tensors
    units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64)
    labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64)
    features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32)
    raw_units = []
    for i, (offset, pair) in enumerate(offsets_pairs):
        u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len)
        units[i, :len(u_)] = u_
        labels[i, :len(l_)] = l_
        features[i, :len(f_), :] = f_
        # NOTE(review): the pad string below was probably '<PAD>' originally
        raw_units.append(r_ + [''] * (pad_len - len(r_)))

    if unit_dropout > 0 and not self.eval:
        # dropout characters/units at training time and replace them with UNKs
        mask = np.random.random_sample(units.shape) < unit_dropout
        mask[units == padid] = 0
        units[mask] = unkid
        for i in range(len(raw_units)):
            for j in range(len(raw_units[i])):
                if mask[i, j]:
                    # NOTE(review): probably '<UNK>' originally
                    raw_units[i][j] = ''

    # dropout unit feature vector in addition to only torch.dropout in the model.
    # experiments showed that only torch.dropout hurts the model
    # we believe it is because the dict feature vector is mostly sparse so it makes
    # more sense to drop out the whole vector instead of only single element.
    if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
        mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
        mask_feat[units == padid] = 0
        for i in range(len(raw_units)):
            for j in range(len(raw_units[i])):
                if mask_feat[i,j]:
                    features[i,j,:] = 0

    units = torch.from_numpy(units)
    labels = torch.from_numpy(labels)
    features = torch.from_numpy(features)

    return units, labels, features, raw_units
+
class SortedDataset(Dataset):
    """
    Holds a TokenizationDataset for use in a torch DataLoader

    The torch DataLoader is different from the DataLoader defined here
    and allows for cpu & gpu parallelism. Updating output_predictions
    to use this class as a wrapper to a TokenizationDataset means the
    calculation of features can happen in parallel, saving quite a
    bit of time.
    """
    def __init__(self, dataset):
        super().__init__()

        self.dataset = dataset
        # sort paragraphs by length so batches contain similarly sized items;
        # keep the permutation so outputs can be restored to input order
        self.data, self.indices = sort_with_indices(self.dataset.data, key=len)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # features are computed here, per item, which is what lets a torch
        # DataLoader parallelize the feature extraction across workers
        return self.dataset.para_to_sentences(self.data[index])

    def unsort(self, arr):
        """Restore a per-item result list to the original paragraph order."""
        return unsort(arr, self.indices)

    def collate(self, samples):
        """Pad a list of single-sentence samples into batch tensors.

        Returns (units, labels, features, raw_units) with everything padded
        to the longest sample plus one trailing pad.
        """
        if any(len(x) > 1 for x in samples):
            raise ValueError("Expected all paragraphs to have no preset sentence splits!")
        feat_size = samples[0][0][2].shape[-1]
        # NOTE(review): this vocab key looks like it was '<PAD>' before
        # markup stripping mangled this file -- confirm upstream.
        padid = self.dataset.vocab.unit2id('')

        # +1 so that all samples end with at least one pad
        pad_len = max(len(x[0][3]) for x in samples) + 1

        units = torch.full((len(samples), pad_len), padid, dtype=torch.int64)
        # NOTE(review): labels are int32 here but int64 in the next() batcher
        # above -- confirm the difference is intentional
        labels = torch.full((len(samples), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(samples), pad_len, feat_size), dtype=torch.float32)
        raw_units = []
        for i, sample in enumerate(samples):
            u_, l_, f_, r_ = sample[0]
            units[i, :len(u_)] = torch.from_numpy(u_)
            labels[i, :len(l_)] = torch.from_numpy(l_)
            features[i, :len(f_), :] = torch.from_numpy(f_)
            # NOTE(review): the pad string was probably '<PAD>' originally
            raw_units.append(r_ + [''] * (pad_len - len(r_)))

        return units, labels, features, raw_units
+
diff --git a/stanza/stanza/models/tokenization/model.py b/stanza/stanza/models/tokenization/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f6098712633ab52462c47b55ffb0da44b7370e1
--- /dev/null
+++ b/stanza/stanza/models/tokenization/model.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+
class Tokenizer(nn.Module):
    """BiLSTM (optionally with residual convolutions and a hierarchical
    second LSTM) that scores each input unit for token / sentence / MWT
    boundaries.

    The forward pass returns per-unit log-probabilities over the joint
    boundary classes (see forward for the exact layout).
    """
    def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
        """
        args: dict of model options (rnn_layers, conv_res, hierarchical, ...)
        nchars: vocabulary size for the unit embedding
        emb_dim: unit embedding dimension
        hidden_dim: LSTM hidden size (outputs are 2*hidden_dim, bidirectional)
        dropout: dropout used on embeddings, RNN outputs and between layers
        feat_dropout: dropout applied to the extra feature inputs
        """
        super().__init__()

        self.args = args
        feat_dim = args['feat_dim']

        # index 0 is the padding unit
        self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)

        # inter-layer dropout only applies when the LSTM is stacked
        self.rnn = nn.LSTM(emb_dim + feat_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0)

        if self.args['conv_res'] is not None:
            # residual 1d convolutions over the input, one per kernel size
            self.conv_res = nn.ModuleList()
            self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')]

            for si, size in enumerate(self.conv_sizes):
                # only the first conv (or all, in hierarchical conv mode) gets a bias
                l = nn.Conv1d(emb_dim + feat_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0))
                self.conv_res.append(l)

            if self.args.get('hier_conv_res', False):
                # 1x1 conv combining the concatenated conv outputs
                self.conv_res2 = nn.Conv1d(hidden_dim * 2 * len(self.conv_sizes), hidden_dim * 2, 1)

        # per-unit binary classifiers for token / sentence / mwt boundaries
        self.tok_clf = nn.Linear(hidden_dim * 2, 1)
        self.sent_clf = nn.Linear(hidden_dim * 2, 1)
        if self.args['use_mwt']:
            self.mwt_clf = nn.Linear(hidden_dim * 2, 1)

        if args['hierarchical']:
            # second LSTM refines the first pass; its classifiers are
            # bias-free additive corrections to the first-stage scores
            in_dim = hidden_dim * 2
            self.rnn2 = nn.LSTM(in_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
            self.tok_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            self.sent_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            if self.args['use_mwt']:
                self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.dropout_feat = nn.Dropout(feat_dropout)

        # noise applied to the token gate feeding the hierarchical LSTM
        self.toknoise = nn.Dropout(self.args['tok_noise'])

    def forward(self, x, feats):
        """Score boundary classes for each unit.

        x: (batch, seq) unit ids
        feats: extra per-unit features, concatenated to the embeddings
          -- assumed (batch, seq, feat_dim); TODO confirm against caller

        Returns (batch, seq, C) log-probabilities where the columns are
        [non-token, token-end, token+sentence-end] and, with use_mwt, the
        five combinations with mwt as built below.
        """
        emb = self.embeddings(x)
        emb = self.dropout(emb)
        feats = self.dropout_feat(feats)

        emb = torch.cat([emb, feats], 2)

        inp, _ = self.rnn(emb)

        if self.args['conv_res'] is not None:
            # Conv1d wants (batch, channels, seq)
            conv_input = emb.transpose(1, 2).contiguous()
            if not self.args.get('hier_conv_res', False):
                # add each conv output directly to the LSTM output
                for l in self.conv_res:
                    inp = inp + l(conv_input).transpose(1, 2).contiguous()
            else:
                # combine all conv outputs through relu + a 1x1 conv first
                hid = []
                for l in self.conv_res:
                    hid += [l(conv_input)]
                hid = torch.cat(hid, 1)
                hid = F.relu(hid)
                hid = self.dropout(hid)
                inp = inp + self.conv_res2(hid).transpose(1, 2).contiguous()

        inp = self.dropout(inp)

        # first-stage boundary logits
        tok0 = self.tok_clf(inp)
        sent0 = self.sent_clf(inp)
        if self.args['use_mwt']:
            mwt0 = self.mwt_clf(inp)

        if self.args['hierarchical']:
            if self.args['hier_invtemp'] > 0:
                # gate the second LSTM's input by (noisy) non-token probability
                inp2, _ = self.rnn2(inp * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp']))))
            else:
                inp2, _ = self.rnn2(inp)

            inp2 = self.dropout(inp2)

            # second stage refines the first-stage logits additively
            tok0 = tok0 + self.tok_clf2(inp2)
            sent0 = sent0 + self.sent_clf2(inp2)
            if self.args['use_mwt']:
                mwt0 = mwt0 + self.mwt_clf2(inp2)

        # log-probabilities of each independent binary decision
        nontok = F.logsigmoid(-tok0)
        tok = F.logsigmoid(tok0)
        nonsent = F.logsigmoid(-sent0)
        sent = F.logsigmoid(sent0)
        if self.args['use_mwt']:
            nonmwt = F.logsigmoid(-mwt0)
            mwt = F.logsigmoid(mwt0)

        # joint classes are sums of the independent log-probs
        if self.args['use_mwt']:
            pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2)
        else:
            pred = torch.cat([nontok, tok+nonsent, tok+sent], 2)

        return pred
diff --git a/stanza/stanza/models/tokenization/tokenize_files.py b/stanza/stanza/models/tokenization/tokenize_files.py
new file mode 100644
index 0000000000000000000000000000000000000000..257d82218b06f9eea9beb0e74b049abfc6dca9d0
--- /dev/null
+++ b/stanza/stanza/models/tokenization/tokenize_files.py
@@ -0,0 +1,83 @@
+"""Use a Stanza tokenizer to turn a text file into one tokenized paragraph per line
+
+For example, the output of this script is suitable for Glove
+
+Currently this *only* supports tokenization, no MWT splitting.
+It also would be beneficial to have an option to convert spaces into
+NBSP, underscore, or some other marker to make it easier to process
+languages such as VI which have spaces in them
+"""
+
+
+import argparse
+import io
+import os
+import time
+import re
+import zipfile
+
+import torch
+
+import stanza
+from stanza.models.common.utils import open_read_text, default_device
+from stanza.models.tokenization.data import TokenizationDataset
+from stanza.models.tokenization.utils import output_predictions
+from stanza.pipeline.tokenize_processor import TokenizeProcessor
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+
+NEWLINE_SPLIT_RE = re.compile(r"\n\s*\n")
+
def tokenize_to_file(tokenizer, fin, fout, chunk_size=500):
    """Tokenize the text in fin and write one tokenized paragraph per line to fout.

    Paragraphs are split on blank lines.  Documents are sent to the
    tokenizer in chunks of chunk_size to bound peak memory usage.
    """
    paragraphs = NEWLINE_SPLIT_RE.split(fin.read())
    for start in tqdm(range(0, len(paragraphs), chunk_size), leave=False):
        batch = paragraphs[start:start + chunk_size]
        docs = tokenizer.bulk_process([stanza.Document([], text=text) for text in batch])
        for doc in docs:
            # sentences are joined with single spaces, one paragraph per line
            pieces = [" ".join(token.text for token in sentence.tokens) for sentence in doc.sentences]
            fout.write(" ".join(pieces))
            fout.write("\n")
+
def main(args=None):
    """Tokenize one or more input files (optionally zip archives of files)
    into a single output file with one tokenized paragraph per line.

    Bug fix: when reading a zip archive, the loop previously opened
    input_names[0] on every iteration, so only the first member was ever
    tokenized.  Also pass the parsed --chunk_size through to
    tokenize_to_file; it was previously parsed but unused.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, default="sd", help="Which language to use for tokenization")
    parser.add_argument("--tokenize_model_path", type=str, default=None, help="Specific tokenizer model to use")
    parser.add_argument("input_files", type=str, nargs="+", help="Which input files to tokenize")
    parser.add_argument("--output_file", type=str, default="glove.txt", help="Where to write the tokenized output")
    parser.add_argument("--model_dir", type=str, default=None, help="Where to get models for a Pipeline (None => default models dir)")
    parser.add_argument("--chunk_size", type=int, default=500, help="How many 'documents' to use in a chunk when tokenizing. This is separate from the tokenizer batching - this limits how much memory gets used at once, since we don't need to store an entire file in memory at once")
    args = parser.parse_args(args=args)

    if os.path.exists(args.output_file):
        # refuse to clobber previous output
        print("Cowardly refusing to overwrite existing output file %s" % args.output_file)
        return

    if args.tokenize_model_path:
        config = { "model_path": args.tokenize_model_path,
                   "check_requirements": False }
        tokenizer = TokenizeProcessor(config, pipeline=None, device=default_device())
    else:
        pipe = stanza.Pipeline(lang=args.lang, processors="tokenize", model_dir=args.model_dir)
        tokenizer = pipe.processors["tokenize"]

    with open(args.output_file, "w", encoding="utf-8") as fout:
        for filename in tqdm(args.input_files):
            if filename.endswith(".zip"):
                with zipfile.ZipFile(filename) as zin:
                    input_names = zin.namelist()
                    for input_name in tqdm(input_names, leave=False):
                        # was: zin.open(input_names[0]) -- only ever read the first member
                        with zin.open(input_name) as fin:
                            fin = io.TextIOWrapper(fin, encoding='utf-8')
                            tokenize_to_file(tokenizer, fin, fout, chunk_size=args.chunk_size)
            else:
                with open_read_text(filename, encoding="utf-8") as fin:
                    tokenize_to_file(tokenizer, fin, fout, chunk_size=args.chunk_size)
diff --git a/stanza/stanza/models/tokenization/trainer.py b/stanza/stanza/models/tokenization/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc5aa99c06ca5009d062db6f827b6241953d2f02
--- /dev/null
+++ b/stanza/stanza/models/tokenization/trainer.py
@@ -0,0 +1,102 @@
+import sys
+import logging
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from stanza.models.common import utils
+from stanza.models.common.trainer import Trainer as BaseTrainer
+from stanza.models.tokenization.utils import create_dictionary
+
+from .model import Tokenizer
+from .vocab import Vocab
+
+logger = logging.getLogger('stanza')
+
class Trainer(BaseTrainer):
    """Trainer for the tokenizer model.

    Wraps the Tokenizer module together with its vocab, lexicon/dictionary,
    loss, optimizer and (de)serialization.  Either builds a fresh model from
    args or restores everything from model_file.
    """
    def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None):
        if model_file is not None:
            # load everything from file
            self.load(model_file)
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab
            self.lexicon = list(lexicon) if lexicon is not None else None
            self.dictionary = dictionary
            self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
        self.model = self.model.to(device)
        # padded positions are labeled -1 and excluded from the loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
        self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay'])
        self.feat_funcs = self.args.get('feat_funcs', None)
        self.lang = self.args['lang'] # language determines how token normalization is done

    def update(self, inputs):
        """Run one training step on a batch; returns the scalar loss value."""
        self.model.train()
        units, labels, features, _ = inputs

        # move the batch to wherever the model's parameters live
        device = next(self.model.parameters()).device
        units = units.to(device)
        labels = labels.to(device)
        features = features.to(device)

        pred = self.model(units, features)

        self.optimizer.zero_grad()
        # flatten (batch, seq, classes) for the cross entropy loss
        classes = pred.size(2)
        loss = self.criterion(pred.view(-1, classes), labels.view(-1))

        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()

        return loss.item()

    def predict(self, inputs):
        """Score a batch without gradient updates; returns a numpy array."""
        self.model.eval()
        units, _, features, _ = inputs

        device = next(self.model.parameters()).device
        units = units.to(device)
        features = features.to(device)

        pred = self.model(units, features)

        return pred.data.cpu().numpy()

    def save(self, filename):
        """Save model weights, vocab, lexicon and config to filename.

        Failure to save is logged but not fatal, so training can continue.
        """
        params = {
            'model': self.model.state_dict() if self.model is not None else None,
            'vocab': self.vocab.state_dict(),
            # save and load lexicon as list instead of set so
            # we can use weights_only=True
            'lexicon': list(self.lexicon) if self.lexicon is not None else None,
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except BaseException:
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename):
        """Restore model, vocab, lexicon and config from filename.

        Raises whatever torch.load raises if the file cannot be read.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if self.args.get('use_mwt', None) is None:
            # Default to True as many currently saved models
            # were built with mwt layers
            self.args['use_mwt'] = True
        self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
        self.model.load_state_dict(checkpoint['model'])
        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
        self.lexicon = checkpoint['lexicon']

        if self.lexicon is not None:
            # stored as a list for weights_only serialization; rebuild the set
            self.lexicon = set(self.lexicon)
            self.dictionary = create_dictionary(self.lexicon)
        else:
            self.dictionary = None
diff --git a/stanza/stanza/utils/datasets/constituency/convert_ctb.py b/stanza/stanza/utils/datasets/constituency/convert_ctb.py
new file mode 100644
index 0000000000000000000000000000000000000000..03857f86dba813bfb00ac1fb675f55f4feff9948
--- /dev/null
+++ b/stanza/stanza/utils/datasets/constituency/convert_ctb.py
@@ -0,0 +1,224 @@
+from enum import Enum
+import glob
+import os
+import re
+
+import xml.etree.ElementTree as ET
+
+from stanza.models.constituency import tree_reader
+from stanza.utils.datasets.constituency.utils import write_dataset
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+
class Version(Enum):
    """Which release/split of the Chinese Treebank is being converted.

    V51b appears to be an alternate ("basic") split of CTB 5.1 -- see the
    filenum_to_shard_51_basic mapping below.
    """
    V51 = 1
    V51b = 2
    V90 = 3
+
def filenum_to_shard_51(filenum):
    """Map a CTB 5.1 file number to its shard: 0=train, 1=dev, 2=test.

    Raises ValueError for file numbers outside the known splits.
    """
    # (shard, low, high) inclusive ranges for the standard 5.1 split
    SPLITS = ((0,    1,  815),
              (0, 1001, 1136),
              (1,  886,  931),
              (1, 1148, 1151),
              (2,  816,  885),
              (2, 1137, 1147))
    for shard, low, high in SPLITS:
        if low <= filenum <= high:
            return shard

    raise ValueError("Unhandled filenum %d" % filenum)
+
def filenum_to_shard_51_basic(filenum):
    """Map a CTB 5.1 file number to its shard under the "basic" split.

    Returns 0=train, 1=dev, 2=test, or None for 400-439 (files the caller
    should skip entirely).  Raises ValueError for anything else unknown.
    """
    # (shard, low, high) inclusive ranges; None marks skipped files
    SPLITS = ((0,       1,  270),
              (0,     440, 1151),
              (1,     301,  325),
              (2,     271,  300),
              (None,  400,  439))
    for shard, low, high in SPLITS:
        if low <= filenum <= high:
            return shard

    raise ValueError("Unhandled filenum %d" % filenum)
+
def filenum_to_shard_90(filenum):
    """Map a CTB 9.0 file number to its shard: 0=train, 1=dev, 2=test.

    Unlike the 5.1 mappings, unlisted file numbers fall off the end and
    return None implicitly; the caller treats None as "skip this file".
    NOTE(review): confirm the silent fallthrough (vs raising) is intended.
    """
    # test shard
    if filenum >= 1 and filenum <= 40:
        return 2
    if filenum >= 900 and filenum <= 931:
        return 2
    if filenum in (1018, 1020, 1036, 1044, 1060, 1061, 1072, 1118, 1119, 1132, 1141, 1142, 1148):
        return 2
    if filenum >= 2165 and filenum <= 2180:
        return 2
    if filenum >= 2295 and filenum <= 2310:
        return 2
    if filenum >= 2570 and filenum <= 2602:
        return 2
    if filenum >= 2800 and filenum <= 2819:
        return 2
    if filenum >= 3110 and filenum <= 3145:
        return 2

    # dev shard
    if filenum >= 41 and filenum <= 80:
        return 1
    if filenum >= 1120 and filenum <= 1129:
        return 1
    if filenum >= 2140 and filenum <= 2159:
        return 1
    if filenum >= 2280 and filenum <= 2294:
        return 1
    if filenum >= 2550 and filenum <= 2569:
        return 1
    if filenum >= 2775 and filenum <= 2799:
        return 1
    if filenum >= 3080 and filenum <= 3109:
        return 1

    # train shard (checked last, so the test/dev carve-outs above win)
    if filenum >= 81 and filenum <= 900:
        return 0
    if filenum >= 1001 and filenum <= 1017:
        return 0
    if filenum in (1019, 1130, 1131):
        return 0
    if filenum >= 1021 and filenum <= 1035:
        return 0
    if filenum >= 1037 and filenum <= 1043:
        return 0
    if filenum >= 1045 and filenum <= 1059:
        return 0
    if filenum >= 1062 and filenum <= 1071:
        return 0
    if filenum >= 1073 and filenum <= 1117:
        return 0
    if filenum >= 1133 and filenum <= 1140:
        return 0
    if filenum >= 1143 and filenum <= 1147:
        return 0
    if filenum >= 1149 and filenum <= 2139:
        return 0
    if filenum >= 2160 and filenum <= 2164:
        return 0
    if filenum >= 2181 and filenum <= 2279:
        return 0
    if filenum >= 2311 and filenum <= 2549:
        return 0
    if filenum >= 2603 and filenum <= 2774:
        return 0
    if filenum >= 2820 and filenum <= 3079:
        return 0
    if filenum >= 4000 and filenum <= 7017:
        return 0
+
def collect_trees_s(root):
    """Recursively yield (text, ID) for each S element under root (root included)."""
    if root.tag == 'S':
        yield root.text, root.attrib['ID']

    for child in root:
        yield from collect_trees_s(child)
+
def collect_trees_text(root):
    """Recursively yield (text, None) for each TEXT or TURN element whose
    text is non-empty after stripping whitespace (root included)."""
    if root.tag in ('TEXT', 'TURN') and len(root.text.strip()) > 0:
        yield root.text, None

    for child in root:
        yield from collect_trees_text(child)
+
+
+id_re = re.compile("")
+su_re = re.compile("<(su|msg) id=([0-9a-zA-Z_=]+)>")
+
def convert_ctb(input_dir, output_dir, dataset_name, version):
    """Convert a directory of CTB bracketed files into train/dev/test splits.

    input_dir: directory of CTB files named like xxx_NNNN.ext
    output_dir/dataset_name: passed through to write_dataset
    version: a Version enum value selecting encoding, cleanup and shard map

    NOTE(review): several string literals in the V90 cleanup below appear to
    have been mangled by markup stripping (empty replace() arguments, a
    syntactically broken find() around filenum 4000-4411).  They are kept
    verbatim here; restore them from the upstream source before running.
    """
    input_files = glob.glob(os.path.join(input_dir, "*"))

    # train, dev, test
    datasets = [[], [], []]

    # sort by the numeric part of the filename so processing is deterministic
    sorted_filenames = []
    for input_filename in input_files:
        base_filename = os.path.split(input_filename)[1]
        filenum = int(os.path.splitext(base_filename)[0].split("_")[1])
        sorted_filenames.append((filenum, input_filename))
    sorted_filenames.sort()

    for filenum, filename in tqdm(sorted_filenames):
        if version in (Version.V51, Version.V51b):
            # 5.1 files are GB2312 with occasional bad bytes
            with open(filename, errors='ignore', encoding="gb2312") as fin:
                text = fin.read()
        elif version is Version.V90:
            with open(filename, encoding="utf-8") as fin:
                text = fin.read()
            # per-file cleanups for known-broken markup in CTB 9.0
            if text.find("") >= 0 and text.find("") < 0:
                text = text.replace("", "")
            if filenum in (4205, 4208, 4289):
                text = text.replace("<)", "<)").replace(">)", ">)")
            if filenum >= 4000 and filenum <= 4411:
                if text.find("= 0:
                    text = text.replace("", "")
                elif text.find("", "")
                text = "\n%s\n" % text
            if filenum >= 5000 and filenum <= 5558 or filenum >= 6000 and filenum <= 6700 or filenum >= 7000 and filenum <= 7017:
                text = su_re.sub("", text)
                if filenum in (6066, 6453):
                    text = text.replace("<", "<").replace(">", ">")
                text = "\n%s\n" % text
        else:
            raise ValueError("Unknown CTB version %s" % version)
        text = id_re.sub(r'', text)
        text = text.replace("&", "&")

        try:
            xml_root = ET.fromstring(text)
        except Exception as e:
            # dump the start of the text to make the broken file findable
            print(text[:1000])
            raise RuntimeError("Cannot xml process %s" % filename) from e
        trees = [x for x in collect_trees_s(xml_root)]
        if version is Version.V90 and len(trees) == 0:
            # some 9.0 files have no S elements; fall back to TEXT/TURN blocks
            trees = [x for x in collect_trees_text(xml_root)]

        if version in (Version.V51, Version.V51b):
            # sentence 4366 of file 414 is skipped in the 5.1 conversions
            trees = [x[0] for x in trees if filenum != 414 or x[1] != "4366"]
        else:
            trees = [x[0] for x in trees]

        trees = "\n".join(trees)
        try:
            trees = tree_reader.read_trees(trees, use_tqdm=False)
        except ValueError as e:
            print(text[:300])
            # NOTE(review): consider "raise ... from e" to keep the cause chain
            raise RuntimeError("Could not process the tree text in %s" % filename)
        trees = [t.prune_none().simplify_labels() for t in trees]

        assert len(trees) > 0, "No trees in %s" % filename

        if version is Version.V51:
            shard = filenum_to_shard_51(filenum)
        elif version is Version.V51b:
            shard = filenum_to_shard_51_basic(filenum)
        else:
            shard = filenum_to_shard_90(filenum)
        if shard is None:
            # the shard maps return None for files that should be skipped
            continue
        datasets[shard].extend(trees)

    write_dataset(datasets, output_dir, dataset_name)
diff --git a/stanza/stanza/utils/datasets/constituency/extract_silver_dataset.py b/stanza/stanza/utils/datasets/constituency/extract_silver_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a89fb96d6d01a71ec2e539cd713b49d89134433
--- /dev/null
+++ b/stanza/stanza/utils/datasets/constituency/extract_silver_dataset.py
@@ -0,0 +1,47 @@
+"""
+After running build_silver_dataset.py, this extracts the trees of a certain match level
+
+For example
+
+python3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score 0 --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_0.mrg
+
+for i in `echo 0 1 2 3 4 5 6 7 8 9 10`; do python3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score $i --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_$i.mrg; done
+"""
+
+import argparse
+import json
+
def parse_args():
    """Build the argument parser for this script and parse the command line."""
    parser = argparse.ArgumentParser(description="After finding common trees using build_silver_dataset, this extracts them all or just the ones from a particular level of accuracy")
    parser.add_argument('--parsed_trees', type=str, nargs='+', help='Input file(s) of trees parsed into the build_silver_dataset json format.')
    parser.add_argument('--keep_score', type=int, default=None, help='Which agreement level to keep. None keeps all')
    parser.add_argument('--output_file', type=str, default=None, help='Where to put the output file')
    return parser.parse_args()
+
+
def main():
    """Extract trees at the requested agreement level from parsed-tree json
    files and print them or write them to --output_file, one per line.
    """
    args = parse_args()

    trees = []
    for filename in args.parsed_trees:
        with open(filename, encoding='utf-8') as fin:
            # iterate the file directly instead of materializing readlines()
            for line in fin:
                record = json.loads(line)
                # keep everything when no score was requested
                if args.keep_score is None or record['count'] == args.keep_score:
                    trees.append(record['tree'])

    if args.output_file is None:
        for tree in trees:
            print(tree)
    else:
        with open(args.output_file, 'w', encoding='utf-8') as fout:
            for tree in trees:
                fout.write(tree)
                fout.write('\n')
+
+if __name__ == '__main__':
+ main()
+
diff --git a/stanza/stanza/utils/datasets/coref/balance_languages.py b/stanza/stanza/utils/datasets/coref/balance_languages.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aede0eb5e75bcd888356dc7e66917bc90c40933
--- /dev/null
+++ b/stanza/stanza/utils/datasets/coref/balance_languages.py
@@ -0,0 +1,60 @@
+"""
+balance_concat.py
+create a test set from a dev set which is language balanced
+"""
+
+import json
+from collections import defaultdict
+
+from random import Random
+
+# fix random seed for reproducability
+R = Random(42)
+
+with open("./corefud_concat_v1_0_langid.train.json", 'r') as df:
+ raw = json.load(df)
+
+# calculate type of each class; then, we will select the one
+# which has the LOWEST counts as the sample rate
+lang_counts = defaultdict(int)
+for i in raw:
+ lang_counts[i["lang"]] += 1
+
+min_lang_count = min(lang_counts.values())
+
+# sample 20% of the smallest amount for test set
+# this will look like an absurdly small number, but
+# remember this is DOCUMENTS not TOKENS or UTTERANCES
+# so its actually decent
+# also its per language
+test_set_size = int(0.1*min_lang_count)
+
+# sampling input by language
+raw_by_language = defaultdict(list)
+for i in raw:
+ raw_by_language[i["lang"]].append(i)
+languages = list(set(raw_by_language.keys()))
+
+train_set = []
+test_set = []
+for i in languages:
+ length = list(range(len(raw_by_language[i])))
+ choices = R.sample(length, test_set_size)
+
+ for indx,i in enumerate(raw_by_language[i]):
+ if indx in choices:
+ test_set.append(i)
+ else:
+ train_set.append(i)
+
+with open("./corefud_concat_v1_0_langid-bal.train.json", 'w') as df:
+ json.dump(train_set, df, indent=2)
+
+with open("./corefud_concat_v1_0_langid-bal.test.json", 'w') as df:
+ json.dump(test_set, df, indent=2)
+
+
+
+# raw_by_language["en"]
+
+
diff --git a/stanza/stanza/utils/datasets/ner/__init__.py b/stanza/stanza/utils/datasets/ner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/stanza/stanza/utils/datasets/ner/build_en_combined.py b/stanza/stanza/utils/datasets/ner/build_en_combined.py
new file mode 100644
index 0000000000000000000000000000000000000000..07811c4729b50415b64c77f91b94b4894b3232e0
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/build_en_combined.py
@@ -0,0 +1,105 @@
+"""
+Builds a combined model out of OntoNotes, WW, and CoNLL.
+
+This is done with three layers in the multi_ner column:
+
+First layer is OntoNotes only. Other datasets have that left as blank.
+
+Second layer is the 9 class WW dataset. OntoNotes is reduced to 9 classes for this column.
+
+Third column is the CoNLL dataset. OntoNotes and WW are both projected to this.
+"""
+
+import json
+import os
+import shutil
+
+from stanza.utils import default_paths
+from stanza.utils.datasets.ner.simplify_en_worldwide import process_label
+from stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide
+from stanza.utils.datasets.ner.utils import combine_files
+
def convert_ontonotes_file(filename, short_name):
    """Rewrite an en_ontonotes json file with a 3-layer multi_ner column.

    The layers are: the original OntoNotes tags, the 9 class Worldwide
    simplification, and the 4 class CoNLL projection.  The converted file
    is written next to the original with the short_name prefix.
    """
    assert "en_ontonotes." in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)
    new_filename = filename.replace("en_ontonotes.", short_name + ".ontonotes.")

    with open(filename) as fin:
        doc = json.load(fin)

    for sentence in doc:
        is_start = False
        for word in sentence:
            simplified = simplify_ontonotes_to_worldwide(word['ner'])
            _, conll, is_start = process_label((word['text'], simplified), is_start)
            word['multi_ner'] = (word['ner'], simplified, conll)

    with open(new_filename, "w") as fout:
        json.dump(doc, fout, indent=2)
+
def convert_worldwide_file(filename, short_name):
    """Rewrite an en_worldwide-9class json file with a 3-layer multi_ner column.

    The OntoNotes layer is left blank ("-"); the second layer is the
    original 9 class tag and the third is its 4 class CoNLL projection.
    """
    assert "en_worldwide-9class." in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)
    new_filename = filename.replace("en_worldwide-9class.", short_name + ".worldwide-9class.")

    with open(filename) as fin:
        doc = json.load(fin)

    for sentence in doc:
        is_start = False
        for word in sentence:
            _, conll, is_start = process_label((word['text'], word['ner']), is_start)
            word['multi_ner'] = ("-", word['ner'], conll)

    with open(new_filename, "w") as fout:
        json.dump(doc, fout, indent=2)
+
def convert_conll03_file(filename, short_name):
    """Rewrite an en_conll03 json file with a 3-layer multi_ner column.

    The OntoNotes and Worldwide layers are left blank ("-"); the third
    layer is the original CoNLL tag.
    """
    assert "en_conll03." in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)
    new_filename = filename.replace("en_conll03.", short_name + ".conll03.")

    with open(filename) as fin:
        doc = json.load(fin)

    for sentence in doc:
        for word in sentence:
            word['multi_ner'] = ("-", "-", word['ner'])

    with open(new_filename, "w") as fout:
        json.dump(doc, fout, indent=2)
+
def build_combined_dataset(base_output_path, short_name):
    """Build the combined NER dataset from OntoNotes, Worldwide and CoNLL03.

    The converters write files prefixed with short_name; the combining and
    copying below previously hard-coded the "en_combined" prefix, which
    broke any other short_name.  All intermediate filenames now use
    short_name consistently (identical behavior for short_name="en_combined").
    """
    convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.train.json"), short_name)
    convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.dev.json"), short_name)
    convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.test.json"), short_name)

    convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.train.json"), short_name)
    convert_conll03_file(os.path.join(base_output_path, "en_conll03.train.json"), short_name)

    combine_files(os.path.join(base_output_path, "%s.train.json" % short_name),
                  os.path.join(base_output_path, "%s.ontonotes.train.json" % short_name),
                  os.path.join(base_output_path, "%s.worldwide-9class.train.json" % short_name),
                  os.path.join(base_output_path, "%s.conll03.train.json" % short_name))
    # dev and test are OntoNotes only
    shutil.copyfile(os.path.join(base_output_path, "%s.ontonotes.dev.json" % short_name),
                    os.path.join(base_output_path, "%s.dev.json" % short_name))
    shutil.copyfile(os.path.join(base_output_path, "%s.ontonotes.test.json" % short_name),
                    os.path.join(base_output_path, "%s.test.json" % short_name))
+
+
def main():
    """Build the en_combined NER dataset in the default NER data directory."""
    ner_data_dir = default_paths.get_default_paths()["NER_DATA_DIR"]
    build_combined_dataset(ner_data_dir, "en_combined")
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/stanza/utils/datasets/ner/check_for_duplicates.py b/stanza/stanza/utils/datasets/ner/check_for_duplicates.py
new file mode 100644
index 0000000000000000000000000000000000000000..f91c057945f4b0b112d1f28accc6342916836a6d
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/check_for_duplicates.py
@@ -0,0 +1,49 @@
+"""
+A simple tool to check if there are duplicates in a set of NER files
+
+It's surprising how many datasets have a bunch of duplicates...
+"""
+
def read_sentences(filename):
    """
    Read the sentences (without tags) from a BIO file

    Returns a list of tuples of words; blank lines separate sentences,
    and only the first tab-separated column (the word) of each line is kept.
    """
    sentences = []
    current = []
    with open(filename) as fin:
        for raw_line in fin:
            stripped = raw_line.strip()
            if stripped:
                current.append(stripped.split("\t")[0])
            elif current:
                sentences.append(tuple(current))
                current = []
    # a file need not end with a blank line; flush the last sentence
    if current:
        sentences.append(tuple(current))
    return sentences
+
def check_for_duplicates(output_filenames, fail=False, check_self=False, print_all=False):
    """
    Checks for exact duplicates in a list of NER files

    :param output_filenames: BIO files to scan, in order
    :param fail: raise ValueError on the first duplicate found instead of printing
    :param check_self: also report duplicates occurring within a single file
    :param print_all: print every duplicate, not just the first one per file
    """
    # maps a sentence (tuple of words) to the first file it was seen in
    sentence_map = {}
    for output_filename in output_filenames:
        # per-file duplicate counter; the map itself persists across files
        duplicates = 0
        sentences = read_sentences(output_filename)
        for sentence in sentences:
            other_file = sentence_map.get(sentence, None)
            if other_file is not None and (check_self or other_file != output_filename):
                if fail:
                    raise ValueError("Duplicate sentence '{}', first in {}, also in {}".format("".join(sentence), sentence_map[sentence], output_filename))
                else:
                    if duplicates == 0 and not print_all:
                        print("First duplicate:")
                    if duplicates == 0 or print_all:
                        print("{}\nFound in {} and {}".format(sentence, other_file, output_filename))
                    duplicates = duplicates + 1
            # later files override earlier ones as the "first" location
            sentence_map[sentence] = output_filename
        if duplicates > 0:
            print("%d duplicates found in %s" % (duplicates, output_filename))
diff --git a/stanza/stanza/utils/datasets/ner/convert_ar_aqmar.py b/stanza/stanza/utils/datasets/ner/convert_ar_aqmar.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b6d1a780cba65f248a1412584df9f0e0588386d
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/convert_ar_aqmar.py
@@ -0,0 +1,126 @@
+"""
+A script to randomly shuffle the input files in the AQMAR dataset and produce train/dev/test for stanza
+
+The sentences themselves are shuffled, not the data files
+
+This script reads the input files directly from the .zip
+"""
+
+
+from collections import Counter
+import random
+import zipfile
+
+from stanza.utils.datasets.ner.utils import write_dataset
+
def read_sentences(infile):
    """
    Read sentences from an open file

    Each sentence is a list of [word, tag] pairs; blank lines separate
    sentences.  Lines may arrive as bytes (e.g. when read from a zipfile)
    and are decoded first.
    """
    sentences = []
    current = []
    for raw in infile:
        text = raw.decode() if isinstance(raw, bytes) else raw
        text = text.rstrip()
        if not text:
            if current:
                sentences.append(current)
                current = []
            continue
        pieces = text.split()
        assert len(pieces) == 2
        current.append(pieces)
    # flush a final sentence not followed by a blank line
    if current:
        sentences.append(current)
    return sentences
+
+
def normalize_tags(sents):
    """Map the raw AQMAR tags onto a clean BIO tagset.

    MIS* variants collapse to MISC, the malformed I--X form becomes I-X,
    B-ENGLISH*/B-SPANISH* and anything unrecognized become O.
    Returns new sentences whose items are (word, tag) tuples.
    """
    def normalize_one(tag):
        # per-tag normalization; mirrors the quirks of the raw annotation
        if tag.startswith('O'):
            return 'O'
        if tag.startswith('I-'):
            entity = tag[2:]
            if entity.startswith('MIS'):
                return 'I-MISC'
            if entity.startswith('-'):  # handle I--ORG
                return 'I-' + entity[1:]
            return tag
        if tag.startswith('B-'):
            entity = tag[2:]
            if entity.startswith('MIS'):
                return 'B-MISC'
            if entity.startswith('ENGLISH') or entity.startswith('SPANISH'):
                return 'O'
            return tag
        return 'O'

    return [[(word, normalize_one(tag)) for word, tag in sent] for sent in sents]
+
+
def convert_shuffle(base_input_path, base_output_path, short_name):
    """
    Convert AQMAR to a randomly shuffled dataset

    base_input_path is the zip file. base_output_path is the output directory

    Sentences are pooled across all annotation files, tag-normalized,
    shuffled with a fixed seed, and split 70/15/15 into train/dev/test.
    """
    if not zipfile.is_zipfile(base_input_path):
        raise FileNotFoundError("Expected %s to be the zipfile with AQMAR in it" % base_input_path)

    with zipfile.ZipFile(base_input_path) as zin:
        # only top-level .txt entries are annotation files
        namelist = zin.namelist()
        annotation_files = [x for x in namelist if x.endswith(".txt") and not "/" in x]
        annotation_files = sorted(annotation_files)

        # although not necessary for good results, this does put
        # things in the same order the shell was alphabetizing files
        # when the original models were created for Stanza
        assert annotation_files[2] == 'Computer.txt'
        assert annotation_files[3] == 'Computer_Software.txt'
        annotation_files[2], annotation_files[3] = annotation_files[3], annotation_files[2]

        if len(annotation_files) != 28:
            raise RuntimeError("Expected exactly 28 labeled .txt files in %s but got %d" % (base_input_path, len(annotation_files)))

        sentences = []
        for in_filename in annotation_files:
            with zin.open(in_filename) as infile:
                new_sentences = read_sentences(infile)
                print(f"{len(new_sentences)} sentences read from {in_filename}")

            new_sentences = normalize_tags(new_sentences)
            sentences.extend(new_sentences)

    # report the final tagset as a sanity check
    all_tags = Counter([p[1] for sent in sentences for p in sent])
    print("All tags after normalization:")
    print(list(all_tags.keys()))

    # 70/15/15 split; remainder after train+dev goes to test
    num = len(sentences)
    train_num = int(num*0.7)
    dev_num = int(num*0.15)

    # fixed seed keeps the shuffle (and therefore the split) reproducible
    random.seed(1234)

    random.shuffle(sentences)

    train_sents = sentences[:train_num]
    dev_sents = sentences[train_num:train_num+dev_num]
    test_sents = sentences[train_num+dev_num:]

    shuffled_dataset = [train_sents, dev_sents, test_sents]

    write_dataset(shuffled_dataset, base_output_path, short_name)
+
diff --git a/stanza/stanza/utils/datasets/ner/convert_bsf_to_beios.py b/stanza/stanza/utils/datasets/ner/convert_bsf_to_beios.py
new file mode 100644
index 0000000000000000000000000000000000000000..60ddcf8c31686866c8ddb1d17820607e370d3cb8
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/convert_bsf_to_beios.py
@@ -0,0 +1,227 @@
+import argparse
+import logging
+import os
+import glob
+from collections import namedtuple
+import re
+from typing import Tuple
+from tqdm import tqdm
+from random import choices, shuffle
+
+BsfInfo = namedtuple('BsfInfo', 'id, tag, start_idx, end_idx, token')
+
+log = logging.getLogger(__name__)
+log.setLevel(logging.INFO)
+
+
def format_token_as_beios(token: str, tag: str) -> list:
    """Render a (possibly multi-word) token as BEIOS-labelled lines.

    A single word gets S-tag; otherwise the first word gets B-tag,
    the last gets E-tag, and any words in between get I-tag.
    """
    words = token.split()
    if len(words) == 1:
        return [token + ' S-' + tag]
    labelled = [words[0] + ' B-' + tag]
    labelled.extend(word + ' I-' + tag for word in words[1:-1])
    labelled.append(words[-1] + ' E-' + tag)
    return labelled
+
+
def format_token_as_iob(token: str, tag: str) -> list:
    """Render a (possibly multi-word) token as IOB-labelled lines.

    The first word gets B-tag and every following word gets I-tag.
    """
    words = token.split()
    if len(words) == 1:
        return [token + ' B-' + tag]
    return [words[0] + ' B-' + tag] + [word + ' I-' + tag for word in words[1:]]
+
+
def convert_bsf(data: str, bsf_markup: str, converter: str = 'beios') -> str:
    """
    Convert data file with NER markup in Brat Standoff Format to BEIOS or IOB format.

    :param converter: iob or beios converter to use for document
    :param data: tokenized data to be converted. Each token separated with a space
    :param bsf_markup: Brat Standoff Format markup
    :return: data in BEIOS or IOB format https://en.wikipedia.org/wiki/Inside–outside–beginning_(tagging)
    """

    def join_simple_chunk(chunk: str) -> list:
        # label a run of untagged text as O, one token per line
        if len(chunk.strip()) == 0:
            return []
        # keep the newlines, but discard the non-newline whitespace
        tokens = re.split(r'(\n)|\s', chunk.strip())
        # the re will return None for splits which were not caught in a group
        tokens = [x for x in tokens if x is not None]
        return [token + ' O' if len(token.strip()) > 0 else token for token in tokens]

    converters = {'beios': format_token_as_beios, 'iob': format_token_as_iob}
    # look the converter up once, before the loop; this also surfaces an
    # unknown converter name even when the markup happens to be empty
    convert_f = converters[converter]
    res = []
    markup = parse_bsf(bsf_markup)

    prev_idx = 0
    m_ln: BsfInfo
    for m_ln in markup:
        # plain text between the previous entity and this one
        res += join_simple_chunk(data[prev_idx:m_ln.start_idx])
        res.extend(convert_f(m_ln.token, m_ln.tag))
        prev_idx = m_ln.end_idx

    # trailing text after the last entity.
    # (this was `prev_idx < len(data) - 1`, an off-by-one that silently
    # dropped a final one-character chunk such as a closing period)
    if prev_idx < len(data):
        res += join_simple_chunk(data[prev_idx:])

    return '\n'.join(res)
+
+
def parse_bsf(bsf_data: str) -> list:
    """
    Convert textual bsf representation to a list of named entities.

    :param bsf_data: data in the format 'T9 PERS 778 783 токен'
    :return: list of named tuples for each line of the data representing a single named entity token
    """
    stripped = bsf_data.strip()
    if not stripped:
        return []

    # lookahead keeps multi-line tokens attached to their own record
    ln_ptrn = re.compile(r'(T\d+)\s(\w+)\s(\d+)\s(\d+)\s(.+?)(?=T\d+\s\w+\s\d+\s\d+|$)', flags=re.DOTALL)
    return [BsfInfo(m.group(1), m.group(2), int(m.group(3)), int(m.group(4)), m.group(5).strip())
            for m in ln_ptrn.finditer(stripped)]
+
+
+CORPUS_NAME = 'Ukrainian-languk'
+
+
def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str = 'beios',
                          doc_delim: str = '\n', train_test_split_file: str = None) -> None:
    """
    Convert every *.tok.txt / *.tok.ann pair under src_dir_path and write
    train/dev/test .bio files under dst_dir_path/Ukrainian-languk.

    :param doc_delim: delimiter to be used between documents
    :param src_dir_path: path to directory with BSF marked files
    :param dst_dir_path: where to save output data
    :param converter: `beios` or `iob` output formats
    :param train_test_split_file: path to file containing train/test lists of file names
    :return:
    """
    ann_path = os.path.join(src_dir_path, '*.tok.ann')
    ann_files = glob.glob(ann_path)
    ann_files.sort()

    tok_path = os.path.join(src_dir_path, '*.tok.txt')
    tok_files = glob.glob(tok_path)
    tok_files.sort()

    corpus_folder = os.path.join(dst_dir_path, CORPUS_NAME)
    if not os.path.exists(corpus_folder):
        os.makedirs(corpus_folder)

    if len(ann_files) == 0 or len(tok_files) == 0:
        raise FileNotFoundError(f'Token and annotation files are not found at specified path {ann_path}')
    if len(ann_files) != len(tok_files):
        raise RuntimeError(f'Mismatch between Annotation and Token files. Ann files: {len(ann_files)}, token files: {len(tok_files)}')

    train_set = []
    dev_set = []
    test_set = []

    data_sets = [train_set, dev_set, test_set]
    # 8/1/1 random assignment, used only when no explicit split file is given
    split_weights = (8, 1, 1)

    if train_test_split_file is not None:
        train_names, dev_names, test_names = read_languk_train_test_split(train_test_split_file)

    log.info(f'Found {len(tok_files)} files in data folder "{src_dir_path}"')
    for (tok_fname, ann_fname) in tqdm(zip(tok_files, ann_files), total=len(tok_files), unit='file'):
        # [:-3] strips the differing "txt" / "ann" suffixes before comparing stems
        if tok_fname[:-3] != ann_fname[:-3]:
            tqdm.write(f'Token and Annotation file names do not match ann={ann_fname}, tok={tok_fname}')
            continue

        with open(tok_fname) as tok_file, open(ann_fname) as ann_file:
            token_data = tok_file.read()
            ann_data = ann_file.read()
            out_data = convert_bsf(token_data, ann_data, converter)

        if train_test_split_file is None:
            # weighted random choice of destination shard
            target_dataset = choices(data_sets, split_weights)[0]
        else:
            target_dataset = train_set
            # fkey is the basename without ".txt", e.g. "doc1.tok"
            # NOTE(review): assumes the split file lists names in exactly that
            # form - confirm against doc/dev-test-split.txt
            fkey = os.path.basename(tok_fname)[:-4]
            if fkey in dev_names:
                target_dataset = dev_set
            elif fkey in test_names:
                target_dataset = test_set

        target_dataset.append(out_data)
    log.info(f'Data is split as following: train={len(train_set)}, dev={len(dev_set)}, test={len(test_set)}')

    # writing data to {train/dev/test}.bio files
    names = ['train', 'dev', 'test']
    if doc_delim != '\n':
        doc_delim = '\n' + doc_delim + '\n'
    for idx, name in enumerate(names):
        fname = os.path.join(corpus_folder, name + '.bio')
        with open(fname, 'w') as f:
            f.write(doc_delim.join(data_sets[idx]))
        log.info('Writing to ' + fname)

    log.info('All done')
+
+
def read_languk_train_test_split(file_path: str, dev_split: float = 0.1) -> Tuple:
    """
    Read predefined split of train and test files in data set.
    Originally located under doc/dev-test-split.txt
    :param file_path: path to dev-test-split.txt file (should include file name with extension)
    :param dev_split: 0 to 1 float value defining how much to allocate to dev split
    :return: tuple of (train, dev, test) each containing list of files to be used for respective data sets
    """
    log.info(f'Trying to read train/dev/test split from file "{file_path}". Dev allocation = {dev_split}')
    train_files, test_files, dev_files = [], [], []
    # lines before any section header fall into test by default
    container = test_files
    with open(file_path, 'r') as f:
        for ln in f:
            ln = ln.strip()
            if ln == 'DEV':
                # NOTE(review): names listed under the DEV header feed the
                # *train* pool; dev is re-carved from that pool below.
                # Confirm this matches the layout of dev-test-split.txt.
                container = train_files
            elif ln == 'TEST':
                container = test_files
            elif ln == '':
                pass
            else:
                container.append(ln)

    # split in file only contains train and test split.
    # For Stanza training we need train, dev, test
    # We will take part of train as dev set
    # This way anyone using test set outside of this code base can be sure that there was no data set polution
    # NOTE(review): shuffle() is not seeded here, so the train/dev carve-out
    # differs between runs.
    shuffle(train_files)
    dev_files = train_files[: int(len(train_files) * dev_split)]
    train_files = train_files[int(len(train_files) * dev_split):]

    assert len(set(train_files).intersection(set(dev_files))) == 0

    log.info(f'Files in each set: train={len(train_files)}, dev={len(dev_files)}, test={len(test_files)}')
    return train_files, dev_files, test_files
+
+
if __name__ == '__main__':
    logging.basicConfig()

    parser = argparse.ArgumentParser(description='Convert lang-uk NER data set from BSF format to BEIOS format compatible with Stanza NER model training requirements.\n'
                                     'Original data set should be downloaded from https://github.com/lang-uk/ner-uk\n'
                                     'For example, create a directory extern_data/lang_uk, then run "git clone git@github.com:lang-uk/ner-uk.git')
    parser.add_argument('--src_dataset', type=str, default='extern_data/ner/lang-uk/ner-uk/data', help='Dir with lang-uk dataset "data" folder (https://github.com/lang-uk/ner-uk)')
    parser.add_argument('--dst', type=str, default='data/ner', help='Where to store the converted dataset')
    parser.add_argument('-c', type=str, default='beios', help='`beios` or `iob` formats to be used for output')
    parser.add_argument('--doc_delim', type=str, default='\n', help='Delimiter to be used to separate documents in the output data')
    parser.add_argument('--split_file', type=str, help='Name of a file containing Train/Test split (files in train and test set)')
    # NOTE(review): print_help() runs unconditionally, so the usage text (which
    # doubles as download instructions) is shown on every invocation - presumably
    # deliberate; confirm before removing.
    parser.print_help()
    args = parser.parse_args()

    convert_bsf_in_folder(args.src_dataset, args.dst, args.c, args.doc_delim, train_test_split_file=args.split_file)
diff --git a/stanza/stanza/utils/datasets/ner/convert_ijc.py b/stanza/stanza/utils/datasets/ner/convert_ijc.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc6caa8b6a4debd1880876a4b8e59ba5dcc06fcd
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/convert_ijc.py
@@ -0,0 +1,146 @@
+import argparse
+import random
+import sys
+
+"""
+Converts IJC data to a TSV format.
+
+So far, tested on Hindi. Not checked on any of the other languages.
+"""
+
def convert_tag(tag):
    """Project the IJC entity classes onto 4 human-readable classes.

    NEP/NEO/NEL become PER/ORG/LOC, any other non-empty class becomes MISC,
    and a missing/empty tag becomes O.
    """
    if not tag:
        return "O"
    return {"NEP": "PER", "NEO": "ORG", "NEL": "LOC"}.get(tag, "MISC")
+
def read_single_file(input_file, bio_format=True):
    """
    Reads an IJC NER file (SSF markup) and returns a list of sentences,
    each sentence a list of (word, tag) pairs.

    NOTE(review): the two sentence-delimiter string literals below were
    empty (startswith("")) in the previous version.  startswith("") is true
    for every line, so every line was skipped and the function always
    returned [].  The literals look like stripped SSF markup tags and have
    been restored as the </Sentence> close tag plus a generic "<" check for
    opening/markup tags - confirm against the raw IJC data files.
    """
    sentences = []
    lineno = 0
    with open(input_file) as fin:
        current_sentence = []
        in_ner = False
        in_sentence = False
        printed_first = False
        nesting = 0
        for line in fin:
            lineno = lineno + 1
            line = line.strip()
            if not line:
                continue

            if line.startswith("</Sentence"):
                # Would like to assert that empty sentences don't exist, but alas, they do
                # assert current_sentence, "File %s has an empty sentence at %d" % (input_file, lineno)
                # AssertionError: File .../hi_ijc/training-hindi/193.naval.utf8 has an empty sentence at 74
                if current_sentence:
                    sentences.append(current_sentence)
                current_sentence = []
                continue

            if line.startswith("<"):
                # opening <Sentence ...> tags and any document-level markup:
                # no tokens should be pending when one of these shows up
                assert not current_sentence, "File %s had an unexpected tag" % input_file
                continue

            if line == "))":
                # close of a bracketed chunk; depth 0 closes the NER chunk,
                # below 0 closes the sentence-level bracket itself
                assert in_sentence, "File %s closed a sentence when there was no open sentence at %d" % (input_file, lineno)
                nesting = nesting - 1
                if nesting < 0:
                    in_sentence = False
                    nesting = 0
                elif nesting == 0:
                    in_ner = False
                continue

            pieces = line.split("\t")
            if pieces[0] == '0':
                # '0	((' opens the sentence-level bracket
                assert pieces[1] == '((', "File %s has an unexpected first line at %d" % (input_file, lineno)
                in_sentence = True
                continue

            if pieces[1] == '((':
                # nested chunk; at depth 1 it may carry a <ne=...> annotation
                nesting = nesting + 1
                if nesting == 1:
                    if len(pieces) < 4:
                        tag = None
                    else:
                        assert pieces[3][0] == '<' and pieces[3][-1] == '>', "File %s has an unexpected tag format at %d: %s" % (input_file, lineno, pieces[3])
                        ne, tag = pieces[3][1:-1].split('=', 1)
                        assert pieces[3] == "<%s=%s>" % (ne, tag), "File %s has an unexpected tag format at %d: %s" % (input_file, lineno, pieces[3])
                        in_ner = True
                        printed_first = False
                        tag = convert_tag(tag)
            elif in_ner and tag:
                # token inside a tagged chunk
                if bio_format:
                    if printed_first:
                        current_sentence.append((pieces[1], "I-" + tag))
                    else:
                        current_sentence.append((pieces[1], "B-" + tag))
                        printed_first = True
                else:
                    current_sentence.append((pieces[1], tag))
            else:
                current_sentence.append((pieces[1], "O"))
    assert not current_sentence, "File %s is unclosed!" % input_file
    return sentences
+
def read_ijc_files(input_files, bio_format=True):
    """Read and concatenate the sentences from a list of IJC files."""
    all_sentences = []
    for path in input_files:
        all_sentences += read_single_file(path, bio_format)
    return all_sentences
+
def convert_ijc(input_files, csv_file, bio_format=True):
    """Convert IJC files to a two-column TSV file, blank lines between sentences."""
    with open(csv_file, "w") as fout:
        for sentence in read_ijc_files(input_files, bio_format):
            for word, tag in sentence:
                fout.write("%s\t%s\n" % (word, tag))
            fout.write("\n")
+
def convert_split_ijc(input_files, train_csv, dev_csv):
    """
    Randomly splits the given list of input files into a train/dev with 85/15 split

    The original datasets only have train & test
    """
    # fixed seed keeps the file assignment reproducible
    random.seed(1234)
    buckets = {True: [], False: []}
    for filename in input_files:
        buckets[random.random() < 0.85].append(filename)
    train_files, dev_files = buckets[True], buckets[False]

    if not train_files or not dev_files:
        raise RuntimeError("Not enough files to split into train & dev")

    convert_ijc(train_files, train_csv)
    convert_ijc(dev_files, dev_csv)
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NOTE(review): the default output path is a developer-specific absolute
    # path; callers are expected to pass --output_path explicitly.
    parser.add_argument('--output_path', type=str, default="/home/john/stanza/data/ner/hi_ijc.test.csv", help="Where to output the results")
    parser.add_argument('input_files', metavar='N', nargs='+', help='input files to process')
    args = parser.parse_args()

    # bio_format=False: writes the bare class (PER/ORG/...) instead of B-/I- tags
    convert_ijc(args.input_files, args.output_path, False)
diff --git a/stanza/stanza/utils/datasets/ner/convert_nkjp.py b/stanza/stanza/utils/datasets/ner/convert_nkjp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0de125a5767bfc79dc3d703f7b014fa2bf6d55d
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/convert_nkjp.py
@@ -0,0 +1,266 @@
+import argparse
+import json
+import os
+import random
+import tarfile
+import tempfile
+from tqdm import tqdm
+# could import lxml here, but that would involve adding lxml as a
+# dependency to the stanza package
+# another alternative would be to try & catch ImportError
+try:
+ from lxml import etree
+except ImportError:
+ import xml.etree.ElementTree as etree
+
+
+NAMESPACE = "http://www.tei-c.org/ns/1.0"
+MORPH_FILE = "ann_morphosyntax.xml"
+NER_FILE = "ann_named.xml"
+SEGMENTATION_FILE = "ann_segmentation.xml"
+
def parse_xml(path):
    """Parse an XML file and return its root element, or None if the file is absent."""
    if not os.path.exists(path):
        return None
    return etree.parse(path).getroot()
+
+
def get_node_id(node):
    """Return the xml:id attribute of an XML node (None if it has no id)."""
    xml_ns = "http://www.w3.org/XML/1998/namespace"
    return node.get("{%s}id" % xml_ns)
+
+
def extract_entities_from_subfolder(subfolder, nkjp_dir):
    """Read the NER annotation for one subfolder and attach it to its paragraphs."""
    entities = extract_unassigned_subfolder_entities(subfolder, nkjp_dir)
    return assign_entities(subfolder, entities, nkjp_dir)
+
+
def extract_unassigned_subfolder_entities(subfolder, nkjp_dir):
    """
    Build and return a map from par_id to extracted entities

    Returns None when the subfolder has no ann_named.xml file.
    The result maps par_id -> {ner_sent_id -> entities} where the sentence
    ids come from the corresp links back to the morphosyntax layer.
    """
    ner_path = os.path.join(nkjp_dir, subfolder, NER_FILE)
    rt = parse_xml(ner_path)
    if rt is None:
        return None
    subfolder_entities = {}
    ner_pars = rt.findall("{%s}TEI/{%s}text/{%s}body/{%s}p" % (NAMESPACE, NAMESPACE, NAMESPACE, NAMESPACE))
    for par in ner_pars:
        par_entities = {}
        # the xml:id has the form "<prefix>_<par_id>"; keep only the par_id
        _, par_id = get_node_id(par).split("_")
        ner_sents = par.findall("{%s}s" % NAMESPACE)
        for ner_sent in ner_sents:
            # corresp ties this NER sentence to a morphosyntax sentence id
            corresp = ner_sent.get("corresp")
            _, ner_sent_id = corresp.split("#morph_")
            par_entities[ner_sent_id] = extract_entities_from_sentence(ner_sent)
        subfolder_entities[par_id] = par_entities
    return subfolder_entities
+
def extract_entities_from_sentence(ner_sent):
    # extracts all the entity dicts from the sentence
    # we assume that an entity cannot span across sentences
    segs = ner_sent.findall("./{%s}seg" % NAMESPACE)
    sent_entities = {}
    for i, seg in enumerate(segs):
        ent_id = get_node_id(seg)
        # targets point either at morph-level segments or at other entities
        targets = [ptr.get("target") for ptr in seg.findall("./{%s}ptr" % NAMESPACE)]
        # surface form of the entity
        orth = seg.findall("./{%s}fs/{%s}f[@name='orth']/{%s}string" % (NAMESPACE, NAMESPACE, NAMESPACE))[0].text
        ner_type = seg.findall("./{%s}fs/{%s}f[@name='type']/{%s}symbol" % (NAMESPACE, NAMESPACE, NAMESPACE))[0].get("value")
        # the subtype annotation is optional
        ner_subtype_node = seg.findall("./{%s}fs/{%s}f[@name='subtype']/{%s}symbol" % (NAMESPACE, NAMESPACE, NAMESPACE))
        if ner_subtype_node:
            ner_subtype = ner_subtype_node[0].get("value")
        else:
            ner_subtype = None
        # "index" records document order so overlap removal can prefer earlier entities
        entity = {"ent_id": ent_id,
                  "index": i,
                  "orth": orth,
                  "ner_type": ner_type,
                  "ner_subtype": ner_subtype,
                  "targets": targets}
        sent_entities[ent_id] = entity
    cleared_entities = clear_entities(sent_entities)
    return cleared_entities
+
+
def clear_entities(entities):
    """Resolve targets to morph ids, drop overlapping entities, strip prefixes."""
    resolve_entities(entities)
    ordered = sorted(entities.values(), key=lambda ent: ent["index"])
    kept = eliminate_overlapping_entities(ordered)
    for entity in kept:
        entity["targets"] = [t.split("morph_")[1] for t in entity["targets"]]
    return kept
+
+
def resolve_entities(entities):
    """Rewrite every entity's targets down to morphological-level ids, in place."""
    flattened = {eid: resolve_entity(ent, entities) for eid, ent in entities.items()}
    for eid, targets in flattened.items():
        entities[eid]["targets"] = targets
+
+
def resolve_entity(entity, entities):
    """Recursively expand entity-level targets into morphological unit ids.

    Targets of the form "named_*" reference other entities and are expanded;
    anything else is already a morphological unit and is kept as-is.
    """
    resolved = []
    for target in entity["targets"]:
        if target.startswith("named_"):
            resolved += resolve_entity(entities[target], entities)
        else:
            resolved.append(target)
    return resolved
+
+
def eliminate_overlapping_entities(entities_list):
    """Drop any entity that shares a target with an entity occurring before it.

    This removes (partial) overlap by always preferring the earlier entity.
    """
    subsumed = set()
    for i, candidate in enumerate(entities_list):
        if any(t in prior["targets"] for prior in entities_list[:i] for t in candidate["targets"]):
            subsumed.add(candidate["ent_id"])
    return [entity for entity in entities_list if entity["ent_id"] not in subsumed]
+
+
def assign_entities(subfolder, subfolder_entities, nkjp_dir):
    # recovers all the segments from a subfolder, and annotates it with NER
    # returns par_id -> sent_id -> seg_id -> token dict, or None if there
    # were no NER annotations for this subfolder
    morph_path = os.path.join(nkjp_dir, subfolder, MORPH_FILE)
    rt = parse_xml(morph_path)
    morph_pars = rt.findall("{%s}TEI/{%s}text/{%s}body/{%s}p" % (NAMESPACE, NAMESPACE, NAMESPACE, NAMESPACE))
    par_id_to_segs = {}
    for par in morph_pars:
        _, par_id = get_node_id(par).split("_")
        morph_sents = par.findall("{%s}s" % NAMESPACE)
        sent_id_to_segs = {}
        for morph_sent in morph_sents:
            _, sent_id = get_node_id(morph_sent).split("_")
            segs = morph_sent.findall("{%s}seg" % NAMESPACE)
            sent_segs = {}
            for i, seg in enumerate(segs):
                _, seg_id = get_node_id(seg).split("morph_")
                orth = seg.findall("{%s}fs/{%s}f[@name='orth']/{%s}string" % (NAMESPACE, NAMESPACE, NAMESPACE))[0].text
                # "i" preserves token order within the sentence for later sorting
                token = {"seg_id": seg_id,
                         "i": i,
                         "orth": orth,
                         "text": orth,
                         "tag": "_",
                         "ner": "O",  # This will be overwritten
                         "ner_subtype": None,
                         }
                sent_segs[seg_id] = token
            sent_id_to_segs[sent_id] = sent_segs
        par_id_to_segs[par_id] = sent_id_to_segs

    # no NER layer for this subfolder: nothing to assign
    if subfolder_entities is None:
        return None

    # overlay BIO labels onto the tokens referenced by each entity
    for par_key in subfolder_entities:
        par_ents = subfolder_entities[par_key]
        for sent_key in par_ents:
            sent_entities = par_ents[sent_key]
            for entity in sent_entities:
                targets = entity["targets"]
                iob = "B"
                ner_label = entity["ner_type"]
                matching_tokens = sorted([par_id_to_segs[par_key][sent_key][target] for target in targets], key=lambda x:x["i"])
                for token in matching_tokens:
                    full_label = f"{iob}-{ner_label}"
                    token["ner"] = full_label
                    token["ner_subtype"] = entity["ner_subtype"]
                    # first token of the entity gets B-, the rest I-
                    iob = "I"
    return par_id_to_segs
+
+
def load_xml_nkjp(nkjp_dir):
    """Walk every subfolder of the extracted NKJP corpus and collect annotations.

    Returns a map from subfolder name to its annotated paragraphs; subfolders
    without an ann_named.xml file are reported and skipped.
    """
    subfolder_to_annotations = {}
    subfolders = sorted(os.listdir(nkjp_dir))
    for subfolder in tqdm([name for name in subfolders if os.path.isdir(os.path.join(nkjp_dir, name))]):
        out = extract_entities_from_subfolder(subfolder, nkjp_dir)
        if out:
            subfolder_to_annotations[subfolder] = out
        else:
            print(subfolder, "has no ann_named.xml file")

    return subfolder_to_annotations
+
+
def split_dataset(dataset, shuffle=True, train_fraction=0.9, dev_fraction=0.05, test_section=True):
    """Split a dataset into train/dev/test portions.

    :param dataset: list of items; shuffled in place when shuffle=True
    :param shuffle: shuffle (with a fixed seed, for reproducibility) before splitting
    :param train_fraction: fraction of items assigned to train
    :param dev_fraction: fraction assigned to dev (ignored when test_section=False)
    :param test_section: when False, everything after the train portion goes
        to dev and test is empty
    :return: dict with 'train', 'dev', 'test' lists
    """
    random.seed(987654321)
    if shuffle:
        random.shuffle(dataset)

    train_size = int(train_fraction * len(dataset))
    if test_section:
        dev_size = int(dev_fraction * len(dataset))
    else:
        # previously dev_size was computed as int((1 - train_fraction) * len),
        # so int truncation could leak a few leftover items into test even
        # though no test section was requested; give dev the exact remainder
        dev_size = len(dataset) - train_size

    train = dataset[:train_size]
    dev = dataset[train_size: train_size + dev_size]
    test = dataset[train_size + dev_size:]

    return {
        'train': train,
        'dev': dev,
        'test': test
    }
+
+
def convert_nkjp(nkjp_path, output_dir):
    """Converts NKJP NER data into IOB json format.

    nkjp_dir is the path to directory where NKJP files are located.

    nkjp_path may also be a .tar.gz/.tgz archive, which is unpacked to a
    temporary directory first.  Writes pl_nkjp.{train,dev,test}.json files
    to output_dir.
    """
    # Load XML NKJP
    print("Reading data from %s" % nkjp_path)
    if os.path.isfile(nkjp_path) and (nkjp_path.endswith(".tar.gz") or nkjp_path.endswith(".tgz")):
        with tempfile.TemporaryDirectory() as nkjp_dir:
            print("Temporarily extracting %s to %s" % (nkjp_path, nkjp_dir))
            # NOTE(review): extractall on an untrusted archive can write outside
            # the target directory (path traversal); the NKJP tarball is assumed
            # to come from a trusted source.
            with tarfile.open(nkjp_path, "r:gz") as tar_in:
                tar_in.extractall(nkjp_dir)

            subfolder_to_entities = load_xml_nkjp(nkjp_dir)
    elif os.path.isdir(nkjp_path):
        subfolder_to_entities = load_xml_nkjp(nkjp_path)
    else:
        raise FileNotFoundError("Cannot find either unpacked dataset or gzipped file")
    # flatten the nested subfolder/paragraph/sentence maps into one token
    # list per paragraph; paragraph identity is kept on the first token
    converted = []
    for subfolder_name, pars in subfolder_to_entities.items():
        for par_id, par in pars.items():
            paragraph_identifier = f"{subfolder_name}|{par_id}"
            par_tokens = []
            for _, sent in par.items():
                tokens = sent.values()
                srt = sorted(tokens, key=lambda tok:tok["i"])
                for token in srt:
                    # drop the bookkeeping fields before serializing
                    _ = token.pop("i")
                    _ = token.pop("seg_id")
                    par_tokens.append(token)
            par_tokens[0]["paragraph_id"] = paragraph_identifier
            converted.append(par_tokens)

    split = split_dataset(converted)

    # NOTE: the loop variable rebinds the name "split"; this is safe only
    # because items() is evaluated once before the loop starts
    for split_name, split in split.items():
        if split:
            with open(os.path.join(output_dir, f"pl_nkjp.{split_name}.json"), "w", encoding="utf-8") as f:
                json.dump(split, f, ensure_ascii=False, indent=2)
+
+
def main():
    """Command line entry point: convert NKJP at --input_path to --output_path."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default="/u/nlp/data/ner/stanza/polish/NKJP-PodkorpusMilionowy-1.2.tar.gz", help="Where to find the files")
    parser.add_argument('--output_path', type=str, default="data/ner", help="Where to output the results")
    args = parser.parse_args()

    convert_nkjp(args.input_path, args.output_path)


if __name__ == '__main__':
    main()
diff --git a/stanza/stanza/utils/datasets/ner/convert_nytk.py b/stanza/stanza/utils/datasets/ner/convert_nytk.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ae5f9d228d4dd51be0914bf43fbf575dcef1955
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/convert_nytk.py
@@ -0,0 +1,32 @@
+
+import glob
+import os
+
def convert_nytk(base_input_path, base_output_path, short_name):
    """Convert the NYTK corpus layout into .bio files for Stanza.

    Reads the no-morph conllup files for each shard (the dataset names the
    dev shard "devel") and writes <short_name>.<shard>.bio with word<TAB>tag
    lines; blank lines separate sentences and input documents.
    """
    for shard in ('train', 'dev', 'test'):
        if shard == 'dev':
            # the dataset calls this shard "devel"
            base_input_subdir = os.path.join(base_input_path, "data/train-devel-test/devel")
        else:
            base_input_subdir = os.path.join(base_input_path, "data/train-devel-test", shard)

        shard_lines = []
        base_input_glob = base_input_subdir + "/*/no-morph/*"
        subpaths = glob.glob(base_input_glob)
        print("Reading %d input files from %s" % (len(subpaths), base_input_glob))
        for input_filename in subpaths:
            # blank line between documents
            if len(shard_lines) > 0:
                shard_lines.append("")
            with open(input_filename) as fin:
                lines = fin.readlines()
            # require the expected conllup column header before trusting column 5
            if lines[0].strip() != '# global.columns = FORM LEMMA UPOS XPOS FEATS CONLL:NER':
                raise ValueError("Unexpected format in %s" % input_filename)
            # keep only FORM and CONLL:NER; empty rows mark sentence breaks
            lines = [x.strip().split("\t") for x in lines[1:]]
            lines = ["%s\t%s" % (x[0], x[5]) if len(x) > 1 else "" for x in lines]
            shard_lines.extend(lines)

        bio_filename = os.path.join(base_output_path, '%s.%s.bio' % (short_name, shard))
        with open(bio_filename, "w") as fout:
            print("Writing %d lines to %s" % (len(shard_lines), bio_filename))
            for line in shard_lines:
                fout.write(line)
                fout.write("\n")
diff --git a/stanza/stanza/utils/datasets/ner/convert_rgai.py b/stanza/stanza/utils/datasets/ner/convert_rgai.py
new file mode 100644
index 0000000000000000000000000000000000000000..86f65fec1d26136d301940d2f90ad687ac79bce9
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/convert_rgai.py
@@ -0,0 +1,62 @@
+"""
+This script converts the Hungarian files available at u-szeged
+ https://rgai.inf.u-szeged.hu/node/130
+"""
+
+import os
+import tempfile
+
+# we reuse this to split the data randomly
+from stanza.utils.datasets.ner.split_wikiner import split_wikiner
+
def read_rgai_file(filename, separator):
    """Read one u-szeged data file, returning tab-separated word/tag lines.

    The files are latin-1 encoded.  Each non-blank line must split into
    exactly two pieces on the given separator; '0' (the digit) tags are
    normalized to 'O' (the letter), and blank lines are kept as-is.
    """
    with open(filename, encoding="latin-1") as fin:
        lines = [raw.strip() for raw in fin.readlines()]

    for idx, line in enumerate(lines):
        if not line:
            continue
        pieces = line.split(separator)
        if len(pieces) != 2:
            raise ValueError("Line %d is in an unexpected format! Expected exactly two pieces when split on %s" % (idx, separator))
        # some of the data has '0' (the digit) instead of 'O' (the letter)
        if pieces[-1] == '0':
            pieces[-1] = "O"
        lines[idx] = "\t".join(pieces)
    print("Read %d lines from %s" % (len(lines), filename))
    return lines
+
def get_rgai_data(base_input_path, use_business, use_criminal):
    """Read the requested section(s) of the u-szeged NER data.

    :param base_input_path: directory containing hun_ner_corpus.txt and HVGJavNENoContext
    :param use_business: include the business-news section (tab separated)
    :param use_criminal: include the criminal-court section (space separated)
    :return: a single list of tab-separated word/tag lines
    """
    assert use_business or use_criminal, "Must specify one or more sections of the dataset to use"

    dataset_lines = []
    if use_business:
        business_file = os.path.join(base_input_path, "hun_ner_corpus.txt")

        lines = read_rgai_file(business_file, "\t")
        dataset_lines.extend(lines)

    if use_criminal:
        # There are two different annotation schemes, Context and
        # NoContext. NoContext seems to fit better with the
        # business_file's annotation scheme, since the scores are much
        # higher when NoContext and hun_ner are combined
        criminal_file = os.path.join(base_input_path, "HVGJavNENoContext")

        lines = read_rgai_file(criminal_file, " ")
        dataset_lines.extend(lines)

    return dataset_lines
+
def convert_rgai(base_input_path, base_output_path, short_name, use_business, use_criminal):
    """Convert the u-szeged data and randomly split it into train/dev/test.

    The combined raw lines are written to a closed temporary file so that
    split_wikiner can read it by name; the temp file is removed afterwards.
    """
    all_data_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        raw_data = get_rgai_data(base_input_path, use_business, use_criminal)
        for line in raw_data:
            all_data_file.write(line.encode())
            all_data_file.write("\n".encode())
        # close before handing the file to split_wikiner (required on Windows)
        all_data_file.close()
        split_wikiner(base_output_path, all_data_file.name, prefix=short_name)
    finally:
        os.unlink(all_data_file.name)
diff --git a/stanza/stanza/utils/datasets/ner/count_entities.py b/stanza/stanza/utils/datasets/ner/count_entities.py
new file mode 100644
index 0000000000000000000000000000000000000000..c75cea2cacc058625df1610031164cf9061dabcb
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/count_entities.py
@@ -0,0 +1,39 @@
+
+import argparse
+from collections import defaultdict
+import json
+
+from stanza.models.common.doc import Document
+from stanza.utils.datasets.ner.utils import list_doc_entities
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Report the coverage of one NER file on another.")
+ parser.add_argument('filename', type=str, nargs='+', help='File(s) to count')
+ args = parser.parse_args()
+ return args
+
+
+def count_entities(*filenames):
+ entity_collection = defaultdict(list)
+
+ for filename in filenames:
+ with open(filename) as fin:
+ doc = Document(json.load(fin))
+ num_tokens = sum(1 for sentence in doc.sentences for token in sentence.tokens)
+ print("Number of tokens in %s: %d" % (filename, num_tokens))
+ entities = list_doc_entities(doc)
+
+ for ent in entities:
+ entity_collection[ent[1]].append(ent[0])
+
+ keys = sorted(entity_collection.keys())
+ for k in keys:
+ print(k, len(entity_collection[k]))
+
+def main():
+ args = parse_args()
+
+ count_entities(*args.filename)
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/stanza/utils/datasets/ner/prepare_ner_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5af5d7fce1462b10e41932c7beaf95eb9f15e8f8
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/prepare_ner_dataset.py
@@ -0,0 +1,1449 @@
+"""Converts raw data files into json files usable by the training script.
+
+Currently it supports converting WikiNER datasets, available here:
+ https://figshare.com/articles/dataset/Learning_multilingual_named_entity_recognition_from_Wikipedia/5462500
+ - download the language of interest to {Language}-WikiNER
+ - then run
+ prepare_ner_dataset.py French-WikiNER
+
+A gold re-edit of WikiNER for French is here:
+ - https://huggingface.co/datasets/danrun/WikiNER-fr-gold/tree/main
+ - https://arxiv.org/abs/2411.00030
+ Danrun Cao, Nicolas Béchet, Pierre-François Marteau
+ - download to $NERBASE/wikiner-fr-gold/wikiner-fr-gold.conll
+ prepare_ner_dataset.py fr_wikinergold
+
+French WikiNER and its gold re-edit can be mixed together with
+ prepare_ner_dataset.py fr_wikinermixed
+ - the data for both WikiNER and WikiNER-fr-gold needs to be in the right place first
+
+Also, Finnish Turku dataset, available here:
+ - https://turkunlp.org/fin-ner.html
+ - https://github.com/TurkuNLP/turku-ner-corpus
+ git clone the repo into $NERBASE/finnish
+ you will now have a directory
+ $NERBASE/finnish/turku-ner-corpus
+ - prepare_ner_dataset.py fi_turku
+
+FBK in Italy produced an Italian dataset.
+ - KIND: an Italian Multi-Domain Dataset for Named Entity Recognition
+ Paccosi T. and Palmero Aprosio A.
+ LREC 2022
+ - https://arxiv.org/abs/2112.15099
+ The processing here is for a combined .tsv file they sent us.
+ - prepare_ner_dataset.py it_fbk
+ There is a newer version of the data available here:
+ https://github.com/dhfbk/KIND
+ TODO: update to the newer version of the data
+
+IJCNLP 2008 produced a few Indian language NER datasets.
+ description:
+ http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=3
+ download:
+ http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=5
+ The models produced from these datasets have extremely low recall, unfortunately.
+ - prepare_ner_dataset.py hi_ijc
+
+FIRE 2013 also produced NER datasets for Indian languages.
+ http://au-kbc.org/nlp/NER-FIRE2013/index.html
+ The datasets are password locked.
+ For Stanford users, contact Chris Manning for license details.
+ For external users, please contact the organizers for more information.
+ - prepare_ner_dataset.py hi-fire2013
+
+HiNER is another Hindi dataset option
+ https://github.com/cfiltnlp/HiNER
+ - HiNER: A Large Hindi Named Entity Recognition Dataset
+ Murthy, Rudra and Bhattacharjee, Pallab and Sharnagat, Rahul and
+ Khatri, Jyotsana and Kanojia, Diptesh and Bhattacharyya, Pushpak
+ There are two versions:
+ hi_hinercollapsed and hi_hiner
+ The collapsed version has just PER, LOC, ORG
+ - convert data as follows:
+ cd $NERBASE
+ mkdir hindi
+ cd hindi
+ git clone git@github.com:cfiltnlp/HiNER.git
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset hi_hiner
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset hi_hinercollapsed
+
+Ukrainian NER is provided by lang-uk, available here:
+ https://github.com/lang-uk/ner-uk
+ git clone the repo to $NERBASE/lang-uk
+ There should be a subdirectory $NERBASE/lang-uk/ner-uk/data at that point
+ Conversion script graciously provided by Andrii Garkavyi @gawy
+ - prepare_ner_dataset.py uk_languk
+
+There are two Hungarian datasets available here:
+ https://rgai.inf.u-szeged.hu/node/130
+ http://www.lrec-conf.org/proceedings/lrec2006/pdf/365_pdf.pdf
+ We combined them and give them the label hu_rgai
+ You can also build individual pieces with hu_rgai_business or hu_rgai_criminal
+ Create a subdirectory of $NERBASE, $NERBASE/hu_rgai, and download both of
+ the pieces and unzip them in that directory.
+ - prepare_ner_dataset.py hu_rgai
+
+Another Hungarian dataset is here:
+ - https://github.com/nytud/NYTK-NerKor
+ - git clone the entire thing in your $NERBASE directory to operate on it
+ - prepare_ner_dataset.py hu_nytk
+
+The two Hungarian datasets can be combined with hu_combined
+ TODO: verify that there is no overlap in text
+ - prepare_ner_dataset.py hu_combined
+
+BSNLP publishes NER datasets for Eastern European languages.
+ - In 2019 they published BG, CS, PL, RU.
+ - http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html
+ - In 2021 they added some more data, but the test sets
+ were not publicly available as of April 2021.
+ Therefore, currently the model is made from 2019.
+ In 2021, the link to the 2021 task is here:
+ http://bsnlp.cs.helsinki.fi/shared-task.html
+ - The below method processes the 2019 version of the corpus.
+ It has specific adjustments for the BG section, which has
+ quite a few typos or mis-annotations in it. Other languages
+ probably need similar work in order to function optimally.
+ - make a directory $NERBASE/bsnlp2019
+ - download the "training data are available HERE" and
+ "test data are available HERE" to this subdirectory
+ - unzip those files in that directory
+ - we use the code name "bg_bsnlp19". Other languages from
+ bsnlp 2019 can be supported by adding the appropriate
+ functionality in convert_bsnlp.py.
+ - prepare_ner_dataset.py bg_bsnlp19
+
+NCHLT produced NER datasets for many African languages.
+ Unfortunately, it is difficult to make use of many of these,
+ as there is no corresponding UD data from which to build a
+ tokenizer or other tools.
+ - Afrikaans: https://repo.sadilar.org/handle/20.500.12185/299
+ - isiNdebele: https://repo.sadilar.org/handle/20.500.12185/306
+ - isiXhosa: https://repo.sadilar.org/handle/20.500.12185/312
+ - isiZulu: https://repo.sadilar.org/handle/20.500.12185/319
+ - Sepedi: https://repo.sadilar.org/handle/20.500.12185/328
+ - Sesotho: https://repo.sadilar.org/handle/20.500.12185/334
+ - Setswana: https://repo.sadilar.org/handle/20.500.12185/341
+ - Siswati: https://repo.sadilar.org/handle/20.500.12185/346
+ - Tsivenda: https://repo.sadilar.org/handle/20.500.12185/355
+ - Xitsonga: https://repo.sadilar.org/handle/20.500.12185/362
+ Agree to the license, download the zip, and unzip it in
+ $NERBASE/NCHLT
+
+UCSY built a Myanmar dataset. They have not made it publicly
+ available, but they did make it available to Stanford for research
+ purposes. Contact Chris Manning or John Bauer for the data files if
+ you are Stanford affiliated.
+ - https://arxiv.org/abs/1903.04739
+ - Syllable-based Neural Named Entity Recognition for Myanmar Language
+ by Hsu Myat Mo and Khin Mar Soe
+
+Hanieh Poostchi et al produced a Persian NER dataset:
+ - git@github.com:HaniehP/PersianNER.git
+ - https://github.com/HaniehP/PersianNER
+ - Hanieh Poostchi, Ehsan Zare Borzeshi, Mohammad Abdous, and Massimo Piccardi,
+ "PersoNER: Persian Named-Entity Recognition"
+ - Hanieh Poostchi, Ehsan Zare Borzeshi, and Massimo Piccardi,
+ "BiLSTM-CRF for Persian Named-Entity Recognition; ArmanPersoNERCorpus: the First Entity-Annotated Persian Dataset"
+ - Conveniently, this dataset is already in BIO format. It does not have a dev split, though.
+ git clone the above repo, unzip ArmanPersoNERCorpus.zip, and this script will split the
+ first train fold into a dev section.
+
+SUC3 is a Swedish NER dataset provided by Språkbanken
+ - https://spraakbanken.gu.se/en/resources/suc3
+ - The splitting tool is generously provided by
+ Emil Stenstrom
+ https://github.com/EmilStenstrom/suc_to_iob
+ - Download the .bz2 file at this URL and put it in $NERBASE/sv_suc3shuffle
+ It is not necessary to unzip it.
+ - Gustafson-Capková, Sophia and Britt Hartmann, 2006,
+ Manual of the Stockholm Umeå Corpus version 2.0.
+ Stockholm University.
+ - Östling, Robert, 2013, Stagger
+ an Open-Source Part of Speech Tagger for Swedish
+ Northern European Journal of Language Technology 3: 1–18
+ DOI 10.3384/nejlt.2000-1533.1331
+ - The shuffled dataset can be converted with dataset code
+ prepare_ner_dataset.py sv_suc3shuffle
+ - If you fill out the license form and get the official data,
+ you can get the official splits by putting the provided zip file
+ in $NERBASE/sv_suc3licensed. Again, not necessary to unzip it
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset sv_suc3licensed
+
+DDT is a reformulation of the Danish Dependency Treebank as an NER dataset
+ - https://danlp-alexandra.readthedocs.io/en/latest/docs/datasets.html#dane
+ - direct download link as of late 2021: https://danlp.alexandra.dk/304bd159d5de/datasets/ddt.zip
+ - https://aclanthology.org/2020.lrec-1.565.pdf
+ DaNE: A Named Entity Resource for Danish
+ Rasmus Hvingelby, Amalie Brogaard Pauli, Maria Barrett,
+ Christina Rosted, Lasse Malm Lidegaard, Anders Søgaard
+ - place ddt.zip in $NERBASE/da_ddt/ddt.zip
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset da_ddt
+
+NorNE is the Norwegian Dependency Treebank with NER labels
+ - LREC 2020
+ NorNE: Annotating Named Entities for Norwegian
+ Fredrik Jørgensen, Tobias Aasmoe, Anne-Stine Ruud Husevåg,
+ Lilja Øvrelid, and Erik Velldal
+ - both Bokmål and Nynorsk
+ - This dataset is in a git repo:
+ https://github.com/ltgoslo/norne
+ Clone it into $NERBASE
+ git clone git@github.com:ltgoslo/norne.git
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset nb_norne
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset nn_norne
+
+tr_starlang is a set of constituency trees for Turkish
+ The words in this dataset (usually) have NER labels as well
+
+ A dataset in three parts from the Starlang group in Turkey:
+ Neslihan Kara, Büşra Marşan, et al
+ Creating A Syntactically Felicitous Constituency Treebank For Turkish
+ https://ieeexplore.ieee.org/document/9259873
+ git clone the following three repos
+ https://github.com/olcaytaner/TurkishAnnotatedTreeBank-15
+ https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-15
+ https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-20
+ Put them in
+ $CONSTITUENCY_HOME/turkish (yes, the constituency home)
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset tr_starlang
+
+GermEval2014 is a German NER dataset
+ https://sites.google.com/site/germeval2014ner/data
+ https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J
+ Download the files in that directory
+ NER-de-train.tsv NER-de-dev.tsv NER-de-test.tsv
+ put them in
+ $NERBASE/germeval2014
+ then run
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset de_germeval2014
+
+The UD Japanese GSD dataset has a conversion by Megagon Labs
+ https://github.com/megagonlabs/UD_Japanese-GSD
+ https://github.com/megagonlabs/UD_Japanese-GSD/tags
+ - r2.9-NE has the NE tagged files inside a "spacy"
+ folder in the download
+ - expected directory for this data:
+ unzip the .zip of the release into
+ $NERBASE/ja_gsd
+ so it should wind up in
+ $NERBASE/ja_gsd/UD_Japanese-GSD-r2.9-NE
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset ja_gsd
+
+L3Cube is a Marathi dataset
+ - https://arxiv.org/abs/2204.06029
+ https://arxiv.org/pdf/2204.06029.pdf
+ https://github.com/l3cube-pune/MarathiNLP
+ - L3Cube-MahaNER: A Marathi Named Entity Recognition Dataset and BERT models
+ Parth Patil, Aparna Ranade, Maithili Sabane, Onkar Litake, Raviraj Joshi
+
+ Clone the repo into $NERBASE/marathi
+ git clone git@github.com:l3cube-pune/MarathiNLP.git
+ Then run
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset mr_l3cube
+
+Daffodil University produced a Bangla NER dataset
+ - https://github.com/Rifat1493/Bengali-NER
+ - https://ieeexplore.ieee.org/document/8944804
+ - Bengali Named Entity Recognition:
+ A survey with deep learning benchmark
+ Md Jamiur Rahman Rifat, Sheikh Abujar, Sheak Rashed Haider Noori,
+ Syed Akhter Hossain
+
+ Clone the repo into a "bangla" subdirectory of $NERBASE
+ cd $NERBASE/bangla
+ git clone git@github.com:Rifat1493/Bengali-NER.git
+ Then run
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset bn_daffodil
+
+LST20 is a Thai NER dataset from 2020
+ - https://arxiv.org/abs/2008.05055
+ The Annotation Guideline of LST20 Corpus
+ Prachya Boonkwan, Vorapon Luantangsrisuk, Sitthaa Phaholphinyo,
+ Kanyanat Kriengket, Dhanon Leenoi, Charun Phrombut,
+ Monthika Boriboon, Krit Kosawat, Thepchai Supnithi
+ - This script processes a version which can be downloaded here after registration:
+ https://aiforthai.in.th/index.php
+ - There is another version downloadable from HuggingFace
+ The script will likely need some modification to be compatible
+ with the HuggingFace version
+ - Download the data in $NERBASE/thai/LST20_Corpus
+ There should be "train", "eval", "test" directories after downloading
+ - Then run
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset th_lst20
+
+Thai-NNER is another Thai NER dataset, from 2022
+ - https://github.com/vistec-AI/Thai-NNER
+ - https://aclanthology.org/2022.findings-acl.116/
+ Thai Nested Named Entity Recognition Corpus
+ Weerayut Buaphet, Can Udomcharoenchaikit, Peerat Limkonchotiwat,
+ Attapol Rutherford, and Sarana Nutanong
+ - git clone the data to $NERBASE/thai
+ - On the git repo, there should be a link to a more complete version
+ of the dataset. For example, in Sep. 2023 it is here:
+ https://github.com/vistec-AI/Thai-NNER#dataset
+ The Google drive it goes to has "postproc".
+ Put the train.json, dev.json, and test.json in
+ $NERBASE/thai/Thai-NNER/data/scb-nner-th-2022/postproc/
+ - Then run
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset th_nner22
+
+
+NKJP is a Polish NER dataset
+ - http://nkjp.pl/index.php?page=0&lang=1
+ About the Project
+ - http://zil.ipipan.waw.pl/DistrNKJP
+ Wikipedia subcorpus used to train charlm model
+ - http://clip.ipipan.waw.pl/NationalCorpusOfPolish?action=AttachFile&do=view&target=NKJP-PodkorpusMilionowy-1.2.tar.gz
+ Annotated subcorpus to train NER model.
+ Download and extract to $NERBASE/Polish-NKJP or leave the gzip in $NERBASE/polish/...
+
+kk_kazNERD is a Kazakh dataset published in 2021
+ - https://github.com/IS2AI/KazNERD
+ - https://arxiv.org/abs/2111.13419
+ KazNERD: Kazakh Named Entity Recognition Dataset
+ Rustem Yeshpanov, Yerbolat Khassanov, Huseyin Atakan Varol
+ - in $NERBASE, make a "kazakh" directory, then git clone the repo there
+ mkdir -p $NERBASE/kazakh
+ cd $NERBASE/kazakh
+ git clone git@github.com:IS2AI/KazNERD.git
+ - Then run
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset kk_kazNERD
+
+Masakhane NER is a set of NER datasets for African languages
+ - MasakhaNER: Named Entity Recognition for African Languages
+ Adelani, David Ifeoluwa; Abbott, Jade; Neubig, Graham;
+ D’souza, Daniel; Kreutzer, Julia; Lignos, Constantine;
+ Palen-Michel, Chester; Buzaaba, Happy; Rijhwani, Shruti;
+ Ruder, Sebastian; Mayhew, Stephen; Azime, Israel Abebe;
+ Muhammad, Shamsuddeen H.; Emezue, Chris Chinenye;
+ Nakatumba-Nabende, Joyce; Ogayo, Perez; Anuoluwapo, Aremu;
+ Gitau, Catherine; Mbaye, Derguene; Alabi, Jesujoba;
+ Yimam, Seid Muhie; Gwadabe, Tajuddeen Rabiu; Ezeani, Ignatius;
+ Niyongabo, Rubungo Andre; Mukiibi, Jonathan; Otiende, Verrah;
+ Orife, Iroro; David, Davis; Ngom, Samba; Adewumi, Tosin;
+ Rayson, Paul; Adeyemi, Mofetoluwa; Muriuki, Gerald;
+ Anebi, Emmanuel; Chukwuneke, Chiamaka; Odu, Nkiruka;
+ Wairagala, Eric Peter; Oyerinde, Samuel; Siro, Clemencia;
+ Bateesa, Tobius Saul; Oloyede, Temilola; Wambui, Yvonne;
+ Akinode, Victor; Nabagereka, Deborah; Katusiime, Maurice;
+ Awokoya, Ayodele; MBOUP, Mouhamadane; Gebreyohannes, Dibora;
+ Tilaye, Henok; Nwaike, Kelechi; Wolde, Degaga; Faye, Abdoulaye;
+ Sibanda, Blessing; Ahia, Orevaoghene; Dossou, Bonaventure F. P.;
+ Ogueji, Kelechi; DIOP, Thierno Ibrahima; Diallo, Abdoulaye;
+ Akinfaderin, Adewale; Marengereke, Tendai; Osei, Salomey
+ - https://github.com/masakhane-io/masakhane-ner
+ - git clone the repo to $NERBASE
+ - Then run
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset lcode_masakhane
+ - You can use the full language name, the 3 letter language code,
+ or in the case of languages with a 2 letter language code,
+ the 2 letter code for lcode. The tool will throw an error
+ if the language is not supported in Masakhane.
+
+SiNER is a Sindhi NER dataset
+ - https://aclanthology.org/2020.lrec-1.361/
+ SiNER: A Large Dataset for Sindhi Named Entity Recognition
+ Wazir Ali, Junyu Lu, Zenglin Xu
+ - It is available via git repository
+ https://github.com/AliWazir/SiNER-dataset
+ As of Nov. 2022, there were a few changes to the dataset
+ to update a couple instances of broken tags & tokenization
+ - Clone the repo to $NERBASE/sindhi
+ mkdir $NERBASE/sindhi
+ cd $NERBASE/sindhi
+ git clone git@github.com:AliWazir/SiNER-dataset.git
+ - Then, prepare the dataset with this script:
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset sd_siner
+
+en_sample is the toy dataset included with stanza-train
+ https://github.com/stanfordnlp/stanza-train
+ this is not meant for any kind of actual NER use
+
+ArmTDP-NER is an Armenian NER dataset
+ - https://github.com/myavrum/ArmTDP-NER.git
+ ArmTDP-NER: The corpus was developed by the ArmTDP team led by Marat M. Yavrumyan
+ at the Yerevan State University by the collaboration of "Armenia National SDG Innovation Lab"
+ and "UC Berkley's Armenian Linguists' network".
+ - in $NERBASE, make a "armenian" directory, then git clone the repo there
+ mkdir -p $NERBASE/armenian
+ cd $NERBASE/armenian
+ git clone https://github.com/myavrum/ArmTDP-NER.git
+ - Then run
+ python3 -m stanza.utils.datasets.ner.prepare_ner_dataset hy_armtdp
+
+en_conll03 is the classic 2003 4 class CoNLL dataset
+ - The version we use is posted on HuggingFace
+ - https://huggingface.co/datasets/conll2003
+ - The prepare script will download from HF
+ using the datasets package, then convert to json
+ - Introduction to the CoNLL-2003 Shared Task:
+ Language-Independent Named Entity Recognition
+ Tjong Kim Sang, Erik F. and De Meulder, Fien
+ - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_conll03
+
+en_conll03ww is CoNLL 03 with Worldwide added to the training data.
+ - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_conll03ww
+
+en_conllpp is a test set from 2020 newswire
+ - https://arxiv.org/abs/2212.09747
+ - https://github.com/ShuhengL/acl2023_conllpp
+ - Do CoNLL-2003 Named Entity Taggers Still Work Well in 2023?
+ Shuheng Liu, Alan Ritter
+ - git clone the repo in $NERBASE
+ - then run
+ python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_conllpp
+
+en_ontonotes is the OntoNotes 5 on HuggingFace
+ - https://huggingface.co/datasets/conll2012_ontonotesv5
+ - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py en_ontonotes
+ - this downloads the "v12" version of the data
+
+en_worldwide-4class is an English non-US newswire dataset
+ - annotated by MLTwist and Aya Data, with help from Datasaur,
+ collected at Stanford
+ - work to be published at EMNLP Findings
+ - the 4 class version is converted to the 4 classes in conll,
+ then split into train/dev/test
+ - clone https://github.com/stanfordnlp/en-worldwide-newswire
+ into $NERBASE/en_worldwide
+
+en_worldwide-9class is an English non-US newswire dataset
+ - annotated by MLTwist and Aya Data, with help from Datasaur,
+ collected at Stanford
+ - work to be published at EMNLP Findings
+ - the 9 class version is not edited
+ - clone https://github.com/stanfordnlp/en-worldwide-newswire
+ into $NERBASE/en_worldwide
+
+zh-hans_ontonotes is the ZH split of the OntoNotes dataset
+ - https://catalog.ldc.upenn.edu/LDC2013T19
+ - https://huggingface.co/datasets/conll2012_ontonotesv5
+ - python3 stanza/utils/datasets/ner/prepare_ner_dataset.py zh-hans_ontonotes
+ - this downloads the "v4" version of the data
+
+
+AQMAR is a small dataset of Arabic Wikipedia articles
+ - http://www.cs.cmu.edu/~ark/ArabicNER/
+ - Recall-Oriented Learning of Named Entities in Arabic Wikipedia
+ Behrang Mohit, Nathan Schneider, Rishav Bhowmick, Kemal Oflazer, and Noah A. Smith.
+ In Proceedings of the 13th Conference of the European Chapter of
+ the Association for Computational Linguistics, Avignon, France,
+ April 2012.
+ - download the .zip file there and put it in
+ $NERBASE/arabic/AQMAR
+ - there is a challenge for it here:
+ https://www.topcoder.com/challenges/f3cf483e-a95c-4a7e-83e8-6bdd83174d38
+ - alternatively, we just randomly split it ourselves
+ - currently, running the following reproduces the random split:
+ python3 stanza/utils/datasets/ner/prepare_ner_dataset.py ar_aqmar
+
+IAHLT contains NER for Hebrew in the knesset treebank
+ - as of UD 2.14, it is only in the git repo
+ - download that git repo to $UDBASE_GIT:
+ https://github.com/UniversalDependencies/UD_Hebrew-IAHLTknesset
+ - change to the dev branch in that repo
+ python3 stanza/utils/datasets/ner/prepare_ner_dataset.py he_iahlt
+"""
+
+import glob
+import os
+import json
+import random
+import re
+import shutil
+import sys
+import tempfile
+
+from stanza.models.common.constant import treebank_to_short_name, lcode2lang, lang_to_langcode, two_to_three_letters
+from stanza.models.ner.utils import to_bio2, bio2_to_bioes
+import stanza.utils.default_paths as default_paths
+
+from stanza.utils.datasets.common import UnknownDatasetError
+from stanza.utils.datasets.ner.preprocess_wikiner import preprocess_wikiner
+from stanza.utils.datasets.ner.split_wikiner import split_wikiner
+import stanza.utils.datasets.ner.build_en_combined as build_en_combined
+import stanza.utils.datasets.ner.conll_to_iob as conll_to_iob
+import stanza.utils.datasets.ner.convert_ar_aqmar as convert_ar_aqmar
+import stanza.utils.datasets.ner.convert_bn_daffodil as convert_bn_daffodil
+import stanza.utils.datasets.ner.convert_bsf_to_beios as convert_bsf_to_beios
+import stanza.utils.datasets.ner.convert_bsnlp as convert_bsnlp
+import stanza.utils.datasets.ner.convert_en_conll03 as convert_en_conll03
+import stanza.utils.datasets.ner.convert_fire_2013 as convert_fire_2013
+import stanza.utils.datasets.ner.convert_he_iahlt as convert_he_iahlt
+import stanza.utils.datasets.ner.convert_ijc as convert_ijc
+import stanza.utils.datasets.ner.convert_kk_kazNERD as convert_kk_kazNERD
+import stanza.utils.datasets.ner.convert_lst20 as convert_lst20
+import stanza.utils.datasets.ner.convert_nner22 as convert_nner22
+import stanza.utils.datasets.ner.convert_mr_l3cube as convert_mr_l3cube
+import stanza.utils.datasets.ner.convert_my_ucsy as convert_my_ucsy
+import stanza.utils.datasets.ner.convert_ontonotes as convert_ontonotes
+import stanza.utils.datasets.ner.convert_rgai as convert_rgai
+import stanza.utils.datasets.ner.convert_nytk as convert_nytk
+import stanza.utils.datasets.ner.convert_starlang_ner as convert_starlang_ner
+import stanza.utils.datasets.ner.convert_nkjp as convert_nkjp
+import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file
+import stanza.utils.datasets.ner.convert_sindhi_siner as convert_sindhi_siner
+import stanza.utils.datasets.ner.ontonotes_multitag as ontonotes_multitag
+import stanza.utils.datasets.ner.simplify_en_worldwide as simplify_en_worldwide
+import stanza.utils.datasets.ner.suc_to_iob as suc_to_iob
+import stanza.utils.datasets.ner.suc_conll_to_iob as suc_conll_to_iob
+import stanza.utils.datasets.ner.convert_hy_armtdp as convert_hy_armtdp
+from stanza.utils.datasets.ner.utils import convert_bioes_to_bio, convert_bio_to_json, get_tags, read_tsv, write_sentences, write_dataset, random_shuffle_by_prefixes, read_prefix_file, combine_files
+
+SHARDS = ('train', 'dev', 'test')
+
+def process_turku(paths, short_name):
+ assert short_name == 'fi_turku'
+ base_input_path = os.path.join(paths["NERBASE"], "finnish", "turku-ner-corpus", "data", "conll")
+ base_output_path = paths["NER_DATA_DIR"]
+ for shard in SHARDS:
+ input_filename = os.path.join(base_input_path, '%s.tsv' % shard)
+ if not os.path.exists(input_filename):
+ raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, input_filename))
+ output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))
+ prepare_ner_file.process_dataset(input_filename, output_filename)
+
+def process_it_fbk(paths, short_name):
+ assert short_name == "it_fbk"
+ base_input_path = os.path.join(paths["NERBASE"], short_name)
+ csv_file = os.path.join(base_input_path, "all-wiki-split.tsv")
+ if not os.path.exists(csv_file):
+ raise FileNotFoundError("Cannot find the FBK dataset in its expected location: {}".format(csv_file))
+ base_output_path = paths["NER_DATA_DIR"]
+ split_wikiner(base_output_path, csv_file, prefix=short_name, suffix="io", shuffle=False, train_fraction=0.8, dev_fraction=0.1)
+ convert_bio_to_json(base_output_path, base_output_path, short_name, suffix="io")
+
+
+def process_languk(paths, short_name):
+ assert short_name == 'uk_languk'
+ base_input_path = os.path.join(paths["NERBASE"], 'lang-uk', 'ner-uk', 'data')
+ base_output_path = paths["NER_DATA_DIR"]
+ train_test_split_fname = os.path.join(paths["NERBASE"], 'lang-uk', 'ner-uk', 'doc', 'dev-test-split.txt')
+ convert_bsf_to_beios.convert_bsf_in_folder(base_input_path, base_output_path, train_test_split_file=train_test_split_fname)
+ for shard in SHARDS:
+ input_filename = os.path.join(base_output_path, convert_bsf_to_beios.CORPUS_NAME, "%s.bio" % shard)
+ if not os.path.exists(input_filename):
+ raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, input_filename))
+ output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))
+ prepare_ner_file.process_dataset(input_filename, output_filename)
+
+
+def process_ijc(paths, short_name):
+ """
+ Splits the ijc Hindi dataset in train, dev, test
+
+ The original data had train & test splits, so we randomly divide
+ the files in train to make a dev set.
+
+ The expected location of the IJC data is hi_ijc. This method
+ should be possible to use for other languages, but we have very
+ little support for the other languages of IJC at the moment.
+ """
+ base_input_path = os.path.join(paths["NERBASE"], short_name)
+ base_output_path = paths["NER_DATA_DIR"]
+
+ test_files = [os.path.join(base_input_path, "test-data-hindi.txt")]
+ test_csv_file = os.path.join(base_output_path, short_name + ".test.csv")
+ print("Converting test input %s to space separated file in %s" % (test_files[0], test_csv_file))
+ convert_ijc.convert_ijc(test_files, test_csv_file)
+
+ train_input_path = os.path.join(base_input_path, "training-hindi", "*utf8")
+ train_files = glob.glob(train_input_path)
+ train_csv_file = os.path.join(base_output_path, short_name + ".train.csv")
+ dev_csv_file = os.path.join(base_output_path, short_name + ".dev.csv")
+ print("Converting training input from %s to space separated files in %s and %s" % (train_input_path, train_csv_file, dev_csv_file))
+ convert_ijc.convert_split_ijc(train_files, train_csv_file, dev_csv_file)
+
+ for csv_file, shard in zip((train_csv_file, dev_csv_file, test_csv_file), SHARDS):
+ output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))
+ prepare_ner_file.process_dataset(csv_file, output_filename)
+
+
+def process_fire_2013(paths, dataset):
+ """
+ Splits the FIRE 2013 dataset into train, dev, test
+
+ The provided datasets are all mixed together at this point, so it
+ is not possible to recreate the original test conditions used in
+ the bakeoff
+ """
+ short_name = treebank_to_short_name(dataset)
+ langcode, _ = short_name.split("_")
+ short_name = "%s_fire2013" % langcode
+ if not langcode in ("hi", "en", "ta", "bn", "mal"):
+ raise UnkonwnDatasetError(dataset, "Language %s not one of the FIRE 2013 languages" % langcode)
+ language = lcode2lang[langcode].lower()
+
+ # for example, FIRE2013/hindi_train
+ base_input_path = os.path.join(paths["NERBASE"], "FIRE2013", "%s_train" % language)
+ base_output_path = paths["NER_DATA_DIR"]
+
+ train_csv_file = os.path.join(base_output_path, "%s.train.csv" % short_name)
+ dev_csv_file = os.path.join(base_output_path, "%s.dev.csv" % short_name)
+ test_csv_file = os.path.join(base_output_path, "%s.test.csv" % short_name)
+
+ convert_fire_2013.convert_fire_2013(base_input_path, train_csv_file, dev_csv_file, test_csv_file)
+
+ for csv_file, shard in zip((train_csv_file, dev_csv_file, test_csv_file), SHARDS):
+ output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))
+ prepare_ner_file.process_dataset(csv_file, output_filename)
+
+def process_wikiner(paths, dataset):
+ short_name = treebank_to_short_name(dataset)
+
+ base_input_path = os.path.join(paths["NERBASE"], dataset)
+ base_output_path = paths["NER_DATA_DIR"]
+
+ expected_filename = "aij*wikiner*"
+ input_files = [x for x in glob.glob(os.path.join(base_input_path, expected_filename)) if not x.endswith("bz2")]
+ if len(input_files) == 0:
+ raw_input_path = os.path.join(base_input_path, "raw")
+ input_files = [x for x in glob.glob(os.path.join(raw_input_path, expected_filename)) if not x.endswith("bz2")]
+ if len(input_files) > 1:
+ raise FileNotFoundError("Found too many raw wikiner files in %s: %s" % (raw_input_path, ", ".join(input_files)))
+ elif len(input_files) > 1:
+ raise FileNotFoundError("Found too many raw wikiner files in %s: %s" % (base_input_path, ", ".join(input_files)))
+
+ if len(input_files) == 0:
+ raise FileNotFoundError("Could not find any raw wikiner files in %s or %s" % (base_input_path, raw_input_path))
+
+ csv_file = os.path.join(base_output_path, short_name + "_csv")
+ print("Converting raw input %s to space separated file in %s" % (input_files[0], csv_file))
+ try:
+ preprocess_wikiner(input_files[0], csv_file)
+ except UnicodeDecodeError:
+ preprocess_wikiner(input_files[0], csv_file, encoding="iso8859-1")
+
+ # this should create train.bio, dev.bio, and test.bio
+ print("Splitting %s to %s" % (csv_file, base_output_path))
+ split_wikiner(base_output_path, csv_file, prefix=short_name)
+ convert_bio_to_json(base_output_path, base_output_path, short_name)
+
+def process_french_wikiner_gold(paths, dataset):
+    """Convert the gold re-edit of French WikiNER (wikiner-fr-gold) to json.
+
+    Reads the single space-separated .conll file, verifies that the tag
+    inventory is exactly the expected BIOES 4-class set, writes the whole
+    dataset to one .bioes file, randomly splits it, and converts the
+    splits to json.
+    """
+    short_name = treebank_to_short_name(dataset)
+
+    base_input_path = os.path.join(paths["NERBASE"], "wikiner-fr-gold")
+    base_output_path = paths["NER_DATA_DIR"]
+
+    input_filename = os.path.join(base_input_path, "wikiner-fr-gold.conll")
+    if not os.path.exists(input_filename):
+        raise FileNotFoundError("Could not find the expected input file %s for dataset %s" % (input_filename, base_input_path))
+
+    print("Reading %s" % input_filename)
+    # column 0 is the word, column 2 the NER tag
+    sentences = read_tsv(input_filename, text_column=0, annotation_column=2, separator=" ")
+    print("Read %d sentences" % len(sentences))
+
+    # sanity check: the annotations must be BIOES over LOC/MISC/ORG/PER
+    tags = [y for sentence in sentences for x, y in sentence]
+    tags = sorted(set(tags))
+    print("Found the following tags:\n%s" % tags)
+    expected_tags = ['B-LOC', 'B-MISC', 'B-ORG', 'B-PER',
+                     'E-LOC', 'E-MISC', 'E-ORG', 'E-PER',
+                     'I-LOC', 'I-MISC', 'I-ORG', 'I-PER',
+                     'O',
+                     'S-LOC', 'S-MISC', 'S-ORG', 'S-PER']
+    assert tags == expected_tags
+
+    output_filename = os.path.join(base_output_path, "%s.full.bioes" % short_name)
+    print("Writing BIOES to %s" % output_filename)
+    write_sentences(output_filename, sentences)
+
+    # this should create train.bioes, dev.bioes, and test.bioes
+    print("Splitting %s to %s" % (output_filename, base_output_path))
+    split_wikiner(base_output_path, output_filename, prefix=short_name, suffix="bioes")
+    # NOTE(review): the bioes files are first converted to .bio, yet the
+    # json conversion below is passed suffix="bioes" — presumably the json
+    # is built from the original bioes files; confirm against the helpers
+    # in stanza.utils.datasets.ner.utils
+    convert_bioes_to_bio(base_output_path, base_output_path, short_name)
+    convert_bio_to_json(base_output_path, base_output_path, short_name, suffix="bioes")
+
def process_french_wikiner_mixed(paths, dataset):
    """
    Build both the original and gold edited versions of WikiNER, then mix them

    First we eliminate any duplicates (with one exception), then we combine the data

    There are two main ways we could have done this:
    - mix it together without any restrictions
    - use the multi_ner mechanism to build a dataset which represents two prediction heads

    The second method seems to give slightly better results than the first method,
    but neither beat just using a transformer on the gold set alone

    On the randomly selected test set, using WV and charlm but not a transformer
    (this was on a previously published version of the dataset):

    one prediction head:
    INFO: Score by entity:
    Prec.	Rec.	F1
    89.32	89.26	89.29
    INFO: Score by token:
    Prec.	Rec.	F1
    89.43	86.88	88.14
    INFO: Weighted f1 for non-O tokens: 0.878855

    two prediction heads:
    INFO: Score by entity:
    Prec.	Rec.	F1
    89.83	89.76	89.79
    INFO: Score by token:
    Prec.	Rec.	F1
    89.17	88.15	88.66
    INFO: Weighted f1 for non-O tokens: 0.885675

    On a randomly selected dev set, using transformer:

    gold:
    INFO: Score by entity:
    Prec.	Rec.	F1
    93.63	93.98	93.81
    INFO: Score by token:
    Prec.	Rec.	F1
    92.80	92.79	92.80
    INFO: Weighted f1 for non-O tokens: 0.927548

    mixed:
    INFO: Score by entity:
    Prec.	Rec.	F1
    93.54	93.82	93.68
    INFO: Score by token:
    Prec.	Rec.	F1
    92.99	92.51	92.75
    INFO: Weighted f1 for non-O tokens: 0.926964
    """
    short_name = treebank_to_short_name(dataset)

    # build both source datasets first; each writes its splits into NER_DATA_DIR
    process_french_wikiner_gold(paths, "fr_wikinergold")
    process_wikiner(paths, "French-WikiNER")
    base_output_path = paths["NER_DATA_DIR"]

    # reload the three gold splits just written by process_french_wikiner_gold
    with open(os.path.join(base_output_path, "fr_wikinergold.train.json")) as fin:
        gold_train = json.load(fin)
    with open(os.path.join(base_output_path, "fr_wikinergold.dev.json")) as fin:
        gold_dev = json.load(fin)
    with open(os.path.join(base_output_path, "fr_wikinergold.test.json")) as fin:
        gold_test = json.load(fin)

    gold = gold_train + gold_dev + gold_test
    print("%d total sentences in the gold relabeled dataset (randomly split)" % len(gold))
    # dedup the gold sentences, keyed by the tuple of their words
    gold = {tuple([x["text"] for x in sentence]): sentence for sentence in gold}
    print("  (%d after dedup)" % len(gold))

    # the original (silver) data, as word/tag pairs from the .bio splits
    original = (read_tsv(os.path.join(base_output_path, "fr_wikiner.train.bio"), text_column=0, annotation_column=1) +
                read_tsv(os.path.join(base_output_path, "fr_wikiner.dev.bio"), text_column=0, annotation_column=1) +
                read_tsv(os.path.join(base_output_path, "fr_wikiner.test.bio"), text_column=0, annotation_column=1))
    print("%d total sentences in the original wiki" % len(original))
    original_words = {tuple([x[0] for x in sentence]) for sentence in original}
    print("  (%d after dedup)" % len(original_words))

    # gold sentences which don't appear verbatim in the original data
    missing = [sentence for sentence in gold if sentence not in original_words]
    for sentence in missing:
        # the capitalization of WisiGoths and OstroGoths is different
        # between the original and the new in some cases
        goths = tuple([x.replace("Goth", "goth") for x in sentence])
        if goths != sentence and goths in original_words:
            original_words.add(sentence)
    missing = [sentence for sentence in gold if sentence not in original_words]
    # currently this dataset doesn't find two sentences
    # one was dropped by the filter for incompletely tagged lines
    # the other is probably not a huge deal to have one duplicate
    print("Missing %d sentences" % len(missing))
    assert len(missing) <= 2
    for sent in missing:
        print(sent)

    # keep original sentences which are not already covered by gold,
    # converting their tags to BIOES and marking the gold channel "-"
    skipped = 0
    silver = []
    silver_used = set()
    for sentence in original:
        words = tuple([x[0] for x in sentence])
        tags = tuple([x[1] for x in sentence])
        if words in gold or words in silver_used:
            skipped += 1
            continue
        tags = to_bio2(tags)
        tags = bio2_to_bioes(tags)
        sentence = [{"text": x, "ner": y, "multi_ner": ["-", y]} for x, y in zip(words, tags)]
        silver.append(sentence)
        silver_used.add(words)
    print("Using %d sentences from the original wikiner alongside the gold annotated train set" % len(silver))
    print("Skipped %d sentences" % skipped)

    # gold sentences get the opposite channel order: (gold tag, "-")
    gold_train = [[{"text": x["text"], "ner": x["ner"], "multi_ner": [x["ner"], "-"]} for x in sentence]
                  for sentence in gold_train]
    gold_dev = [[{"text": x["text"], "ner": x["ner"], "multi_ner": [x["ner"], "-"]} for x in sentence]
                for sentence in gold_dev]
    gold_test = [[{"text": x["text"], "ner": x["ner"], "multi_ner": [x["ner"], "-"]} for x in sentence]
                 for sentence in gold_test]

    # only the training set is mixed; dev and test stay pure gold
    mixed_train = gold_train + silver
    print("Total sentences in the mixed training set: %d" % len(mixed_train))
    output_filename = os.path.join(base_output_path, "%s.train.json" % short_name)
    with open(output_filename, 'w', encoding='utf-8') as fout:
        json.dump(mixed_train, fout, indent=1)

    output_filename = os.path.join(base_output_path, "%s.dev.json" % short_name)
    with open(output_filename, 'w', encoding='utf-8') as fout:
        json.dump(gold_dev, fout, indent=1)
    output_filename = os.path.join(base_output_path, "%s.test.json" % short_name)
    with open(output_filename, 'w', encoding='utf-8') as fout:
        json.dump(gold_test, fout, indent=1)
+
+
def get_rgai_input_path(paths):
    """Return the directory where the raw hu_rgai data is expected to live."""
    return os.path.join(paths["NERBASE"], "hu_rgai")
+
def process_rgai(paths, short_name):
    """
    Convert one of the Hungarian RGAI variants (full, business only, or criminal only).
    """
    base_output_path = paths["NER_DATA_DIR"]
    base_input_path = get_rgai_input_path(paths)

    # which sections of the corpus to include: (business, criminal)
    subsets = {
        'hu_rgai':          (True, True),
        'hu_rgai_business': (True, False),
        'hu_rgai_criminal': (False, True),
    }
    if short_name not in subsets:
        raise UnknownDatasetError(short_name, "Unknown subset of hu_rgai data: %s" % short_name)
    use_business, use_criminal = subsets[short_name]

    convert_rgai.convert_rgai(base_input_path, base_output_path, short_name, use_business, use_criminal)
    convert_bio_to_json(base_output_path, base_output_path, short_name)
+
def get_nytk_input_path(paths):
    """Return the directory of the NYTK-NerKor checkout under NERBASE."""
    return os.path.join(paths["NERBASE"], "NYTK-NerKor")
+
def process_nytk(paths, short_name):
    """
    Process the Hungarian NYTK-NerKor dataset: convert to .bio, then to .json
    """
    assert short_name == "hu_nytk"
    output_path = paths["NER_DATA_DIR"]
    input_path = get_nytk_input_path(paths)

    convert_nytk.convert_nytk(input_path, output_path, short_name)
    convert_bio_to_json(output_path, output_path, short_name)
+
def concat_files(output_file, *input_files):
    """
    Concatenate several BIO files into output_file.

    Each input is guaranteed to end with a blank line so that sentences
    from consecutive files are not accidentally merged.

    Raises ValueError if any input file is empty.
    """
    collected = []
    for filename in input_files:
        with open(filename) as fin:
            file_lines = fin.readlines()
        if not file_lines:
            raise ValueError("Empty input file: %s" % filename)
        if not file_lines[-1]:
            file_lines[-1] = "\n"
        elif file_lines[-1].strip():
            # last line has content: add a blank separator line
            file_lines.append("\n")
        collected.extend(file_lines)
    with open(output_file, "w") as fout:
        fout.writelines(collected)
+
+
def process_hu_combined(paths, short_name):
    """
    Build the combined Hungarian dataset from the RGAI and NYTK sources.

    Both datasets are converted into a temporary directory, then the
    train/dev/test .bio files are concatenated shard by shard.
    """
    assert short_name == "hu_combined"

    out_dir = paths["NER_DATA_DIR"]

    with tempfile.TemporaryDirectory() as tmp_dir:
        convert_rgai.convert_rgai(get_rgai_input_path(paths), tmp_dir, "hu_rgai", True, True)
        convert_nytk.convert_nytk(get_nytk_input_path(paths), tmp_dir, "hu_nytk")

        for shard in SHARDS:
            pieces = [os.path.join(tmp_dir, "%s.%s.bio" % (name, shard))
                      for name in ("hu_rgai", "hu_nytk")]
            combined = os.path.join(out_dir, "hu_combined.%s.bio" % shard)
            concat_files(combined, *pieces)

    convert_bio_to_json(out_dir, out_dir, short_name)
+
def process_bsnlp(paths, short_name):
    """
    Process files downloaded from http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html

    If you download the training and test data zip files and unzip
    them without rearranging in any way, the layout is somewhat weird.
    Training data goes into a specific subdirectory, but the test data
    goes into the top level directory.
    """
    base_input_path = os.path.join(paths["NERBASE"], "bsnlp2019")
    train_path = os.path.join(base_input_path, "training_pl_cs_ru_bg_rc1")
    test_path = base_input_path

    out_dir = paths["NER_DATA_DIR"]
    csv_files = {shard: os.path.join(out_dir, "%s.%s.csv" % (short_name, shard))
                 for shard in SHARDS}

    # the language is the first part of the short name, eg pl_bsnlp19
    language = short_name.split("_")[0]

    convert_bsnlp.convert_bsnlp(language, test_path, csv_files["test"])
    convert_bsnlp.convert_bsnlp(language, train_path, csv_files["train"], csv_files["dev"])

    for shard in SHARDS:
        json_file = os.path.join(out_dir, '%s.%s.json' % (short_name, shard))
        prepare_ner_file.process_dataset(csv_files[shard], json_file)
+
# Maps a language code to the directory name the NCHLT corpus for that
# language is distributed under (see process_nchlt below)
NCHLT_LANGUAGE_MAP = {
    "af": "NCHLT Afrikaans Named Entity Annotated Corpus",
    # none of the following have UD datasets as of 2.8.  Until they
    # exist, we assume the language codes NCHTL are sufficient
    "nr": "NCHLT isiNdebele Named Entity Annotated Corpus",
    "nso": "NCHLT Sepedi Named Entity Annotated Corpus",
    "ss": "NCHLT Siswati Named Entity Annotated Corpus",
    "st": "NCHLT Sesotho Named Entity Annotated Corpus",
    "tn": "NCHLT Setswana Named Entity Annotated Corpus",
    "ts": "NCHLT Xitsonga Named Entity Annotated Corpus",
    "ve": "NCHLT Tshivenda Named Entity Annotated Corpus",
    "xh": "NCHLT isiXhosa Named Entity Annotated Corpus",
    "zu": "NCHLT isiZulu Named Entity Annotated Corpus",
}
+
def process_nchlt(paths, short_name):
    """
    Convert one of the NCHLT South African language datasets.

    There are no official splits, so the single *Full.txt file is
    randomly split into train/dev/test.
    """
    language = short_name.split("_")[0]
    if language not in NCHLT_LANGUAGE_MAP:
        raise UnknownDatasetError(short_name, "Language %s not part of NCHLT" % language)
    short_name = "%s_nchlt" % language

    search_path = os.path.join(paths["NERBASE"], "NCHLT", NCHLT_LANGUAGE_MAP[language], "*Full.txt")
    matches = glob.glob(search_path)
    if not matches:
        raise FileNotFoundError("Cannot find NCHLT dataset in '%s'  Did you remember to download the file?" % search_path)
    if len(matches) > 1:
        raise ValueError("Unexpected number of files matched '%s'  There should only be one" % search_path)

    out_dir = paths["NER_DATA_DIR"]
    split_wikiner(out_dir, matches[0], prefix=short_name, remap={"OUT": "O"})
    convert_bio_to_json(out_dir, out_dir, short_name)
+
def process_my_ucsy(paths, short_name):
    """
    Convert the Myanmar UCSY dataset to .bio, then json.
    """
    assert short_name == "my_ucsy"
    language = "my"

    in_dir = os.path.join(paths["NERBASE"], short_name)
    out_dir = paths["NER_DATA_DIR"]
    convert_my_ucsy.convert_my_ucsy(in_dir, out_dir)
    convert_bio_to_json(out_dir, out_dir, short_name)
+
def process_fa_arman(paths, short_name):
    """
    Converts fa_arman dataset

    The conversion is quite simple, actually.
    Just need to split the train file and then convert bio -> json
    """
    assert short_name == "fa_arman"
    language = "fa"
    in_dir = os.path.join(paths["NERBASE"], "PersianNER")
    train_file = os.path.join(in_dir, "train_fold1.txt")
    test_file = os.path.join(in_dir, "test_fold1.txt")
    if not (os.path.exists(train_file) and os.path.exists(test_file)):
        # give a more helpful error if the zip is there but not unpacked
        zip_path = os.path.join(in_dir, "ArmanPersoNERCorpus.zip")
        if os.path.exists(zip_path):
            raise FileNotFoundError("Please unzip the file {}".format(zip_path))
        raise FileNotFoundError("Cannot find the arman corpus in the expected directory: {}".format(in_dir))

    out_dir = paths["NER_DATA_DIR"]
    # train is split 80/20 into train/dev; the test fold is copied as-is
    split_wikiner(out_dir, train_file, prefix=short_name, train_fraction=0.8, test_section=False)
    shutil.copy2(test_file, os.path.join(out_dir, "%s.test.bio" % short_name))
    convert_bio_to_json(out_dir, out_dir, short_name)
+
def process_sv_suc3licensed(paths, short_name):
    """
    The .zip provided for SUC3 includes train/dev/test splits already

    This extracts those splits without needing to unzip the original file
    """
    assert short_name == "sv_suc3licensed"
    language = "sv"
    zip_path = os.path.join(paths["NERBASE"], short_name, "SUC3.0.zip")
    if not os.path.exists(zip_path):
        raise FileNotFoundError("Cannot find the officially licensed SUC3 dataset in %s" % zip_path)

    out_dir = paths["NER_DATA_DIR"]
    suc_conll_to_iob.process_suc3(zip_path, short_name, out_dir)
    convert_bio_to_json(out_dir, out_dir, short_name)
+
def process_sv_suc3shuffle(paths, short_name):
    """
    Uses an externally provided script to read the SUC3 XML file, then splits it

    Accepts either suc3.xml.bz2 or an already decompressed suc3.xml.
    """
    assert short_name == "sv_suc3shuffle"
    language = "sv"
    in_file = os.path.join(paths["NERBASE"], short_name, "suc3.xml.bz2")
    if not os.path.exists(in_file):
        # fall back to the decompressed filename
        in_file = in_file[:-4]
        if not os.path.exists(in_file):
            raise FileNotFoundError("Unable to find the SUC3 dataset in {}.bz2".format(in_file))

    out_dir = paths["NER_DATA_DIR"]
    bio_file = os.path.join(out_dir, "sv_suc3shuffle.bio")
    suc_to_iob.main([in_file, bio_file])
    split_wikiner(out_dir, bio_file, prefix=short_name)
    convert_bio_to_json(out_dir, out_dir, short_name)
+
def process_da_ddt(paths, short_name):
    """
    Processes Danish DDT dataset

    This dataset is in a conll file with the "name" attribute in the
    misc column for the NER tag.  This function uses a script to
    convert such CoNLL files to .bio

    The data is read either directly from ddt.zip, if present, or from
    the individual unzipped .conllu files.
    """
    assert short_name == "da_ddt"
    language = "da"
    IN_FILES = ("ddt.train.conllu", "ddt.dev.conllu", "ddt.test.conllu")

    base_output_path = paths["NER_DATA_DIR"]
    OUT_FILES = [os.path.join(base_output_path, "%s.%s.bio" % (short_name, shard)) for shard in SHARDS]

    zip_file = os.path.join(paths["NERBASE"], "da_ddt", "ddt.zip")
    if os.path.exists(zip_file):
        # the converter can read the conllu files directly out of the zip
        for in_filename, out_filename, shard in zip(IN_FILES, OUT_FILES, SHARDS):
            conll_to_iob.process_conll(in_filename, out_filename, zip_file)
    else:
        for in_filename, out_filename, shard in zip(IN_FILES, OUT_FILES, SHARDS):
            in_filename = os.path.join(paths["NERBASE"], "da_ddt", in_filename)
            if not os.path.exists(in_filename):
                # fixed a typo in this error message ("could not file")
                raise FileNotFoundError("Could not find zip in expected location %s and could not find %s file in %s" % (zip_file, shard, in_filename))

            conll_to_iob.process_conll(in_filename, out_filename)
    convert_bio_to_json(base_output_path, base_output_path, short_name)
+
+
def process_norne(paths, short_name):
    """
    Processes Norwegian NorNE

    Can handle either Bokmål or Nynorsk

    Converts GPE_LOC and GPE_ORG to GPE
    """
    language, name = short_name.split("_", 1)
    assert language in ('nb', 'nn')
    assert name == 'norne'

    # the UD directory layout differs between the two variants
    subdir, treebank = ("nob", "no_bokmaal") if language == 'nb' else ("nno", "no_nynorsk")
    in_files = ["%s/%s-ud-%s.conllu" % (subdir, treebank, shard) for shard in SHARDS]

    base_output_path = paths["NER_DATA_DIR"]
    out_files = [os.path.join(base_output_path, "%s.%s.bio" % (short_name, shard)) for shard in SHARDS]

    # collapse the GPE subtypes into a single GPE tag
    conversion = { "GPE_LOC": "GPE", "GPE_ORG": "GPE" }

    for in_filename, out_filename, shard in zip(in_files, out_files, SHARDS):
        in_filename = os.path.join(paths["NERBASE"], "norne", "ud", in_filename)
        if not os.path.exists(in_filename):
            raise FileNotFoundError("Could not find %s file in %s" % (shard, in_filename))

        conll_to_iob.process_conll(in_filename, out_filename, conversion=conversion)

    convert_bio_to_json(base_output_path, base_output_path, short_name)
+
def process_ja_gsd(paths, short_name):
    """
    Convert ja_gsd from MegagonLabs

    for example, can download from https://github.com/megagonlabs/UD_Japanese-GSD/releases/tag/r2.9-NE

    Finds the most recent release (highest r2.N) under $NERBASE/ja_gsd,
    remaps the BILOU tags to BIOES, and converts the files to json.
    """
    language, name = short_name.split("_", 1)
    assert language == 'ja'
    assert name == 'gsd'

    base_output_path = paths["NER_DATA_DIR"]
    output_files = [os.path.join(base_output_path, "%s.%s.bio" % (short_name, shard)) for shard in SHARDS]

    search_path = os.path.join(paths["NERBASE"], "ja_gsd", "UD_Japanese-GSD-r2.*-NE")
    versions = glob.glob(search_path)
    max_version = None
    base_input_path = None
    # the "." is now escaped: previously the pattern "GSD-r2.([0-9]+)-NE$"
    # would match any character in that position
    version_re = re.compile(r"GSD-r2\.([0-9]+)-NE$")

    # pick the directory with the highest version number
    for ver in versions:
        match = version_re.search(ver)
        if not match:
            continue
        ver_num = int(match.group(1))
        if max_version is None or ver_num > max_version:
            max_version = ver_num
            base_input_path = ver

    if base_input_path is None:
        raise FileNotFoundError("Could not find any copies of the NE conversion of ja_gsd here: {}".format(search_path))
    print("Most recent version found: {}".format(base_input_path))

    input_files = ["ja_gsd-ud-train.ne.conllu", "ja_gsd-ud-dev.ne.conllu", "ja_gsd-ud-test.ne.conllu"]

    def conversion(x):
        """Remap BILOU to BIOES: L -> E, U -> S, everything else unchanged"""
        if x[0] == 'L':
            return 'E' + x[1:]
        if x[0] == 'U':
            return 'S' + x[1:]
        # B, I unchanged
        return x

    for in_filename, out_filename, shard in zip(input_files, output_files, SHARDS):
        in_path = os.path.join(base_input_path, in_filename)
        if not os.path.exists(in_path):
            # some releases keep the conllu files in a "spacy" subdirectory
            in_spacy = os.path.join(base_input_path, "spacy", in_filename)
            if not os.path.exists(in_spacy):
                raise FileNotFoundError("Could not find %s file in %s or %s" % (shard, in_path, in_spacy))
            in_path = in_spacy

        conll_to_iob.process_conll(in_path, out_filename, conversion=conversion, allow_empty=True, attr_prefix="NE")

    convert_bio_to_json(base_output_path, base_output_path, short_name)
+
def process_starlang(paths, short_name):
    """
    Process a Turkish dataset from Starlang

    Reads the three annotated treebank chunks and writes the usual splits.
    """
    assert short_name == 'tr_starlang'

    pieces = ["TurkishAnnotatedTreeBank-15",
              "TurkishAnnotatedTreeBank2-15",
              "TurkishAnnotatedTreeBank2-20"]
    chunk_paths = [os.path.join(paths["CONSTITUENCY_BASE"], "turkish", piece) for piece in pieces]

    datasets = convert_starlang_ner.read_starlang(chunk_paths)
    write_dataset(datasets, paths["NER_DATA_DIR"], short_name)
+
def remap_germeval_tag(tag):
    """
    Simplify tags for GermEval2014 using a simple rubric

    Subtyped tags (eg B-LOCderiv) collapse to their parent tag, and OTH
    becomes MISC.  "O" passes through unchanged.

    Raises ValueError for any other tag.
    """
    if tag == "O":
        return tag
    category = tag[1:5]
    if category in ("-LOC", "-PER", "-ORG"):
        # truncate any subtype suffix
        return tag[:5]
    if category == "-OTH":
        return tag[0] + "-MISC"
    raise ValueError("Unexpected tag: %s" % tag)
+
def process_de_germeval2014(paths, short_name):
    """
    Process the TSV of the GermEval2014 dataset, remapping tags with remap_germeval_tag
    """
    in_directory = os.path.join(paths["NERBASE"], "germeval2014")
    out_dir = paths["NER_DATA_DIR"]
    datasets = [read_tsv(os.path.join(in_directory, "NER-de-%s.tsv" % shard), 1, 2, remap_fn=remap_germeval_tag)
                for shard in SHARDS]
    print("Found the following tags: {}".format(sorted(get_tags(datasets))))
    write_dataset(datasets, out_dir, short_name)
+
def process_hiner(paths, short_name):
    """Convert the original (uncollapsed) HiNER conll files to json."""
    hiner_dir = os.path.join(paths["NERBASE"], "hindi", "HiNER", "data", "original")
    convert_bio_to_json(hiner_dir, paths["NER_DATA_DIR"], short_name, suffix="conll", shard_names=("train", "validation", "test"))
+
def process_hinercollapsed(paths, short_name):
    """Convert the collapsed-tagset HiNER conll files to json."""
    hiner_dir = os.path.join(paths["NERBASE"], "hindi", "HiNER", "data", "collapsed")
    convert_bio_to_json(hiner_dir, paths["NER_DATA_DIR"], short_name, suffix="conll", shard_names=("train", "validation", "test"))
+
def process_lst20(paths, short_name, include_space_char=True):
    """Convert the Thai LST20 dataset, delegating to convert_lst20."""
    convert_lst20.convert_lst20(paths, short_name, include_space_char)
+
def process_nner22(paths, short_name, include_space_char=True):
    """Convert the Thai NNER22 dataset, delegating to convert_nner22."""
    convert_nner22.convert_nner22(paths, short_name, include_space_char)
+
def process_mr_l3cube(paths, short_name):
    """
    Convert the Marathi L3Cube MahaNER IOB files to json.
    """
    out_dir = paths["NER_DATA_DIR"]
    in_dir = os.path.join(paths["NERBASE"], "marathi", "MarathiNLP", "L3Cube-MahaNER", "IOB")
    filenames = [os.path.join(in_dir, name)
                 for name in ("train_iob.txt", "valid_iob.txt", "test_iob.txt")]
    missing = [f for f in filenames if not os.path.exists(f)]
    if missing:
        raise FileNotFoundError("Could not find the expected piece of the l3cube dataset %s" % missing[0])

    write_dataset([convert_mr_l3cube.convert(f) for f in filenames], out_dir, short_name)
+
def process_bn_daffodil(paths, short_name):
    """Convert the Bengali Daffodil NER dataset."""
    in_dir = os.path.join(paths["NERBASE"], "bangla", "Bengali-NER")
    convert_bn_daffodil.convert_dataset(in_dir, paths["NER_DATA_DIR"])
+
def process_pl_nkjp(paths, short_name):
    """
    Convert the Polish NKJP dataset, checking several candidate locations.
    """
    out_directory = paths["NER_DATA_DIR"]
    candidates = [os.path.join(paths["NERBASE"], "Polish-NKJP"),
                  os.path.join(paths["NERBASE"], "polish", "Polish-NKJP"),
                  os.path.join(paths["NERBASE"], "polish", "NKJP-PodkorpusMilionowy-1.2.tar.gz"),]
    # take the first candidate that exists on disk
    in_path = next((candidate for candidate in candidates if os.path.exists(candidate)), None)
    if in_path is None:
        raise FileNotFoundError("Could not find %s  Looked in %s" % (short_name, " ".join(candidates)))
    convert_nkjp.convert_nkjp(in_path, out_directory)
+
def process_kk_kazNERD(paths, short_name):
    """Convert the Kazakh KazNERD dataset."""
    in_dir = os.path.join(paths["NERBASE"], "kazakh", "KazNERD", "KazNERD")
    convert_kk_kazNERD.convert_dataset(in_dir, paths["NER_DATA_DIR"], short_name)
+
def process_masakhane(paths, dataset_name):
    """
    Converts Masakhane NER datasets to Stanza's .json format

    The data files are simple word/tag pairs, one pair per line, with a
    blank line between sentences.

    Once the dataset is git cloned in $NERBASE, the directory structure is

    $NERBASE/masakhane-ner/MasakhaNER2.0/data/$lcode/{train,dev,test}.txt

    The only tricky thing here is that for some languages, we treat
    the 2 letter lcode as canonical thanks to UD, but Masakhane NER
    uses 3 letter lcodes for all languages.
    """
    language, dataset = dataset_name.split("_")
    lcode = lang_to_langcode(language)
    # Masakhane always uses the 3 letter language codes
    masakhane_lcode = two_to_three_letters[lcode] if lcode in two_to_three_letters else lcode

    mn_directory = os.path.join(paths["NERBASE"], "masakhane-ner")
    if not os.path.exists(mn_directory):
        raise FileNotFoundError("Cannot find Masakhane NER repo.  Please check the setting of NERBASE or clone the repo to %s" % mn_directory)
    data_directory = os.path.join(mn_directory, "MasakhaNER2.0", "data")
    if not os.path.exists(data_directory):
        raise FileNotFoundError("Apparently found the repo at %s but the expected directory structure is not there - was looking for %s" % (mn_directory, data_directory))

    in_directory = os.path.join(data_directory, masakhane_lcode)
    if not os.path.exists(in_directory):
        raise UnknownDatasetError(dataset_name, "Found the Masakhane repo, but there was no %s in the repo at path %s" % (dataset_name, in_directory))
    convert_bio_to_json(in_directory, paths["NER_DATA_DIR"], "%s_masakhane" % lcode, "txt")
+
def process_sd_siner(paths, short_name):
    """
    Convert the Sindhi SiNER dataset, accepting either known dataset filename.
    """
    in_directory = os.path.join(paths["NERBASE"], "sindhi", "SiNER-dataset")
    if not os.path.exists(in_directory):
        raise FileNotFoundError("Cannot find SiNER checkout in $NERBASE/sindhi  Please git clone to repo in that directory")
    # the dataset has shipped under two slightly different filenames
    for name in ("SiNER-dataset.txt", "SiNER dataset.txt"):
        in_filename = os.path.join(in_directory, name)
        if os.path.exists(in_filename):
            break
    else:
        raise FileNotFoundError("Found an SiNER directory at %s but the directory did not contain the dataset" % in_directory)
    convert_sindhi_siner.convert_sindhi_siner(in_filename, paths["NER_DATA_DIR"], short_name)
+
def process_en_worldwide_4class(paths, short_name):
    """
    Prepare the simplified 4 class version of the English worldwide newswire data.

    The sentences are shuffled into splits grouped by region prefix.
    """
    simplify_en_worldwide.main(args=['--simplify'])

    in_directory = os.path.join(paths["NERBASE"], "en_worldwide", "4class")
    out_directory = paths["NER_DATA_DIR"]

    regions_file = os.path.join(paths["NERBASE"], "en_worldwide", "en-worldwide-newswire", "regions.txt")
    random_shuffle_by_prefixes(in_directory, out_directory, short_name, read_prefix_file(regions_file))
+
def process_en_worldwide_9class(paths, short_name):
    """
    Prepare the full 9 class version of the English worldwide newswire data.

    The sentences are shuffled into splits grouped by region prefix.
    """
    simplify_en_worldwide.main(args=['--no_simplify'])

    in_directory = os.path.join(paths["NERBASE"], "en_worldwide", "9class")
    out_directory = paths["NER_DATA_DIR"]

    regions_file = os.path.join(paths["NERBASE"], "en_worldwide", "en-worldwide-newswire", "regions.txt")
    random_shuffle_by_prefixes(in_directory, out_directory, short_name, read_prefix_file(regions_file))
+
def process_en_ontonotes(paths, short_name):
    """
    Convert the English OntoNotes dataset.

    short_name is expected to be "en_ontonotes"; it is now passed through to
    the converter (instead of being hardcoded) for consistency with
    process_zh_ontonotes.  All current callers pass "en_ontonotes".
    """
    ner_input_path = paths['NERBASE']
    ontonotes_path = os.path.join(ner_input_path, "english", "en_ontonotes")
    ner_output_path = paths['NER_DATA_DIR']
    convert_ontonotes.process_dataset(short_name, ontonotes_path, ner_output_path)
+
def process_zh_ontonotes(paths, short_name):
    """Convert the Chinese OntoNotes dataset."""
    ontonotes_path = os.path.join(paths['NERBASE'], "chinese", "zh_ontonotes")
    convert_ontonotes.process_dataset(short_name, ontonotes_path, paths['NER_DATA_DIR'])
+
def process_en_conll03(paths, short_name):
    """
    Convert the English CoNLL 2003 dataset.

    short_name is expected to be "en_conll03"; it is now passed through to
    the converter (instead of being hardcoded) for consistency with the
    other dataset processors.  All current callers pass "en_conll03".
    """
    ner_input_path = paths['NERBASE']
    conll_path = os.path.join(ner_input_path, "english", "en_conll03")
    ner_output_path = paths['NER_DATA_DIR']
    convert_en_conll03.process_dataset(short_name, conll_path, ner_output_path)
+
def process_en_conll03_worldwide(paths, short_name):
    """
    Adds the training data for conll03 and worldwide together

    dev and test are taken from the conll03 splits unchanged.
    """
    print("============== Preparing CoNLL 2003 ===================")
    process_en_conll03(paths, "en_conll03")
    print("========== Preparing 4 Class Worldwide ================")
    process_en_worldwide_4class(paths, "en_worldwide-4class")
    print("============== Combined Train Data ====================")
    data_dir = paths['NER_DATA_DIR']
    combine_files(os.path.join(data_dir, "%s.train.json" % short_name),
                  os.path.join(data_dir, "en_conll03.train.json"),
                  os.path.join(data_dir, "en_worldwide-4class.train.json"))
    for shard in ("dev", "test"):
        shutil.copyfile(os.path.join(data_dir, "en_conll03.%s.json" % shard),
                        os.path.join(data_dir, "%s.%s.json" % (short_name, shard)))
+
def process_en_ontonotes_ww_multi(paths, short_name):
    """
    Combine the worldwide data with the OntoNotes data in a multi channel format
    """
    steps = (("=============== Preparing OntoNotes ===============", process_en_ontonotes, "en_ontonotes"),
             ("========== Preparing 9 Class Worldwide ================", process_en_worldwide_9class, "en_worldwide-9class"))
    for banner, builder, name in steps:
        print(banner)
        builder(paths, name)
    # TODO: pass in options?
    ontonotes_multitag.build_multitag_dataset(paths['NER_DATA_DIR'], short_name, True, True)
+
def process_en_combined(paths, short_name):
    """
    Combine WW, OntoNotes, and CoNLL into a 3 channel dataset
    """
    steps = (("================= Preparing OntoNotes =================", process_en_ontonotes, "en_ontonotes"),
             ("========== Preparing 9 Class Worldwide ================", process_en_worldwide_9class, "en_worldwide-9class"),
             ("=============== Preparing CoNLL 03 ====================", process_en_conll03, "en_conll03"))
    for banner, builder, name in steps:
        print(banner)
        builder(paths, name)
    build_en_combined.build_combined_dataset(paths['NER_DATA_DIR'], short_name)
+
+
def process_en_conllpp(paths, short_name):
    """
    This is ONLY a test set

    the test set has entities start with I- instead of B- unless they
    are in the middle of a sentence, but that should be fine, as
    process_tags in the NER model converts those to B- in a BIOES
    conversion
    """
    in_file = os.path.join(paths["NERBASE"], "acl2023_conllpp", "dataset", "conllpp.txt")
    out_dir = paths["NER_DATA_DIR"]
    raw = read_tsv(in_file, 0, 3, separator=None)
    # drop the -DOCSTART- marker "sentences"
    filtered = [sent for sent in raw if len(sent) > 1 or sent[0][0] != '-DOCSTART-']
    write_dataset([filtered], out_dir, short_name, shard_names=["test"], shards=["test"])
+
def process_armtdp(paths, short_name):
    """
    Convert the Armenian ArmTDP NER dataset, then turn each .tsv shard into json.
    """
    assert short_name == 'hy_armtdp'
    in_dir = os.path.join(paths["NERBASE"], "armenian", "ArmTDP-NER")
    out_dir = paths["NER_DATA_DIR"]
    convert_hy_armtdp.convert_dataset(in_dir, out_dir, short_name)
    for shard in SHARDS:
        tsv_file = os.path.join(out_dir, f'{short_name}.{shard}.tsv')
        if not os.path.exists(tsv_file):
            raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, tsv_file))
        json_file = os.path.join(out_dir, '%s.%s.json' % (short_name, shard))
        prepare_ner_file.process_dataset(tsv_file, json_file)
+
def process_toy_dataset(paths, short_name):
    """Convert the English-SAMPLE toy dataset."""
    sample_dir = os.path.join(paths["NERBASE"], "English-SAMPLE")
    convert_bio_to_json(sample_dir, paths["NER_DATA_DIR"], short_name)
+
def process_ar_aqmar(paths, short_name):
    """Convert and randomly shuffle the Arabic AQMAR dataset, read directly from its zip."""
    zip_path = os.path.join(paths["NERBASE"], "arabic", "AQMAR", "AQMAR_Arabic_NER_corpus-1.0.zip")
    convert_ar_aqmar.convert_shuffle(zip_path, paths["NER_DATA_DIR"], short_name)
+
def process_he_iahlt(paths, short_name):
    """Convert the Hebrew IAHLT Knesset dataset."""
    assert short_name == 'he_iahlt'
    # for now, need to use UDBASE_GIT until IAHLTknesset is added to UD
    udbase = paths["UDBASE_GIT"]
    convert_he_iahlt.convert_iahlt(udbase, paths["NER_DATA_DIR"], "he_iahlt")
+
+
# Datasets handled by a single dedicated function, keyed by short name.
# Datasets dispatched by name pattern (eg *WikiNER, hu_rgai*, *_masakhane)
# are handled in main() below instead.
DATASET_MAPPING = {
    "ar_aqmar": process_ar_aqmar,
    "bn_daffodil": process_bn_daffodil,
    "da_ddt": process_da_ddt,
    "de_germeval2014": process_de_germeval2014,
    "en_conll03": process_en_conll03,
    "en_conll03ww": process_en_conll03_worldwide,
    "en_conllpp": process_en_conllpp,
    "en_ontonotes": process_en_ontonotes,
    "en_ontonotes-ww-multi": process_en_ontonotes_ww_multi,
    "en_combined": process_en_combined,
    "en_worldwide-4class": process_en_worldwide_4class,
    "en_worldwide-9class": process_en_worldwide_9class,
    "fa_arman": process_fa_arman,
    "fi_turku": process_turku,
    "fr_wikinergold": process_french_wikiner_gold,
    "fr_wikinermixed": process_french_wikiner_mixed,
    "hi_hiner": process_hiner,
    "hi_hinercollapsed": process_hinercollapsed,
    "hi_ijc": process_ijc,
    "he_iahlt": process_he_iahlt,
    "hu_nytk": process_nytk,
    "hu_combined": process_hu_combined,
    "hy_armtdp": process_armtdp,
    "it_fbk": process_it_fbk,
    "ja_gsd": process_ja_gsd,
    "kk_kazNERD": process_kk_kazNERD,
    "mr_l3cube": process_mr_l3cube,
    "my_ucsy": process_my_ucsy,
    "pl_nkjp": process_pl_nkjp,
    "sd_siner": process_sd_siner,
    "sv_suc3licensed": process_sv_suc3licensed,
    "sv_suc3shuffle": process_sv_suc3shuffle,
    "tr_starlang": process_starlang,
    "th_lst20": process_lst20,
    "th_nner22": process_nner22,
    "zh-hans_ontonotes": process_zh_ontonotes,
}
+
def main(dataset_name):
    """
    Prepare a single NER dataset, dispatching on the dataset name.

    Exact names are looked up in DATASET_MAPPING; everything else is
    dispatched by name pattern.  Raises UnknownDatasetError for names
    not handled either way.
    """
    paths = default_paths.get_default_paths()
    print("Processing %s" % dataset_name)

    # fixed seed so random splits are reproducible
    random.seed(1234)

    handler = DATASET_MAPPING.get(dataset_name)
    if handler is not None:
        handler(paths, dataset_name)
    elif dataset_name in ('uk_languk', 'Ukranian_languk', 'Ukranian-languk'):
        process_languk(paths, dataset_name)
    elif dataset_name.endswith(("FIRE2013", "fire2013")):
        process_fire_2013(paths, dataset_name)
    elif dataset_name.endswith('WikiNER'):
        process_wikiner(paths, dataset_name)
    elif dataset_name.startswith('hu_rgai'):
        process_rgai(paths, dataset_name)
    elif dataset_name.endswith("_bsnlp19"):
        process_bsnlp(paths, dataset_name)
    elif dataset_name.endswith("_nchlt"):
        process_nchlt(paths, dataset_name)
    elif dataset_name in ("nb_norne", "nn_norne"):
        process_norne(paths, dataset_name)
    elif dataset_name == 'en_sample':
        process_toy_dataset(paths, dataset_name)
    elif dataset_name.lower().endswith("_masakhane"):
        process_masakhane(paths, dataset_name)
    else:
        raise UnknownDatasetError(dataset_name, f"dataset {dataset_name} currently not handled by prepare_ner_dataset")
    print("Done processing %s" % dataset_name)
+
if __name__ == '__main__':
    # usage: python prepare_ner_dataset.py <dataset_name>
    main(sys.argv[1])
diff --git a/stanza/stanza/utils/datasets/ner/suc_to_iob.py b/stanza/stanza/utils/datasets/ner/suc_to_iob.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec66c6b89de2babdee20eb43ee283102973a360
--- /dev/null
+++ b/stanza/stanza/utils/datasets/ner/suc_to_iob.py
@@ -0,0 +1,181 @@
+"""
+Conversion tool to transform SUC3's xml format to IOB
+
+Copyright 2017-2022, Emil Stenström
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+"""
+
+from bz2 import BZ2File
+from xml.etree.ElementTree import iterparse
+import argparse
+from collections import Counter
+import sys
+
def parse(fp, skiptypes=[]):
    """Stream (word, prefix, label) triples from a SUC3 XML file.

    fp is an open (binary) file containing the SUC3 XML.  skiptypes is a
    collection of labels to suppress in the output.  (NOTE(review): the
    mutable default argument is harmless here because it is only read,
    never mutated.)

    The generator yields:
      - a bare None at the end of each <sentence>, as a sentence separator
      - (token_text, "B-"/"I-"/"", label) for each <w> token

    Manually annotated <name> spans take priority over the automatically
    tagged <ne> spans, except that OTH names are emitted as O and an ORG
    name inside a LOC ne is ignored.
    """
    root = None
    # current prefix/type from the automatic <ne> annotation layer
    ne_prefix = ""
    ne_type = "O"
    # current prefix/type from the manual <name> annotation layer
    name_prefix = ""
    name_type = "O"

    for event, elem in iterparse(fp, events=("start", "end")):
        if root is None:
            root = elem

        if event == "start":
            if elem.tag == "name":
                _type = name_type_to_label(elem.attrib["type"])
                # do not open a name span for skipped types, and do not let
                # an ORG name override an enclosing LOC ne
                if (
                    _type not in skiptypes and
                    not (_type == "ORG" and ne_type == "LOC")
                ):
                    name_type = _type
                    name_prefix = "B-"

            elif elem.tag == "ne":
                _type = ne_type_to_label(elem.attrib["type"])
                # compound types such as "X/Y": keep the part after the slash
                if "/" in _type:
                    _type = ne_type_to_label(_type[_type.index("/") + 1:])

                if _type not in skiptypes:
                    ne_type = _type
                    ne_prefix = "B-"

            elif elem.tag == "w":
                # a common noun (pos NN) cannot start or continue a person name
                if name_type == "PER" and elem.attrib["pos"] == "NN":
                    name_type = "O"
                    name_prefix = ""

        elif event == "end":
            if elem.tag == "sentence":
                # bare yield (None) marks a sentence boundary for the caller
                yield

            elif elem.tag == "name":
                name_type = "O"
                name_prefix = ""

            elif elem.tag == "ne":
                ne_type = "O"
                ne_prefix = ""

            elif elem.tag == "w":
                # manual <name> labels win over automatic <ne> labels;
                # OTH names are demoted to O
                if name_type != "O" and name_type != "OTH":
                    yield elem.text, name_prefix, name_type
                elif ne_type != "O":
                    yield elem.text, ne_prefix, ne_type
                else:
                    yield elem.text, "", "O"

                # after the first token of a span, switch B- to I-
                if ne_type != "O":
                    ne_prefix = "I-"

                if name_type != "O":
                    name_prefix = "I-"

        # free already-processed elements to keep memory flat on large files
        root.clear()
+
def ne_type_to_label(ne_type):
    """Normalize an automatic <ne> annotation type to an output label.

    SUC3 tags persons as PRS; that becomes PER.  Every other type is
    passed through unchanged.
    """
    if ne_type == "PRS":
        return "PER"
    return ne_type
+
def name_type_to_label(name_type):
    """Map a manual <name> annotation type to an output label.

    The mapping covers the name types observed in SUC3.  Previously an
    unknown type silently mapped to None (mapping.get with no default),
    which downstream would have been written out as the literal string
    "None"; now, consistent with ne_type_to_label, an unknown type falls
    back to itself.
    """
    mapping = {
        "inst": "ORG",
        "product": "OBJ",
        "other": "OTH",
        "place": "LOC",
        "myth": "PER",
        "person": "PER",
        "event": "EVN",
        "work": "WRK",
        "animal": "PER",
    }
    return mapping.get(name_type, name_type)
+
def main(args=None):
    """Convert a SUC3 XML file (optionally bz2 compressed) to IOB format.

    Writes one "word<TAB>prefix+label" line per token, with a blank line
    between sentences, to args.outfile or to stdout.  With --stats_only,
    only a Counter of label frequencies is printed instead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "infile",
        help="""
        Input file that contains the full SUC 3.0 XML.
        Can be the bz2-zipped version or the xml version.
        """
    )
    parser.add_argument(
        "outfile",
        nargs="?",
        help="""
        Output file for IOB format.
        Optional - will print to stdout otherwise
        """
    )
    parser.add_argument(
        "--skiptypes",
        help="Entity types that should be skipped in output.",
        nargs="+",
        default=[]
    )
    parser.add_argument(
        "--stats_only",
        help="Show statistics of found labels at the end of output.",
        action='store_true',
        default=False
    )
    args = parser.parse_args(args)

    # Sniff the first bytes for the bz2 magic number to decide how to open
    # the file.  Use a dedicated, immediately-closed handle for the sniff;
    # previously this handle was leaked when fp was reassigned below.
    MAGIC_BZ2_FILE_START = b"\x42\x5a\x68"
    with open(args.infile, "rb") as sniff_fp:
        is_bz2 = (sniff_fp.read(len(MAGIC_BZ2_FILE_START)) == MAGIC_BZ2_FILE_START)

    if is_bz2:
        fp = BZ2File(args.infile, "rb")
    else:
        fp = open(args.infile, "rb")

    if args.outfile is not None:
        fout = open(args.outfile, "w", encoding="utf-8")
    else:
        fout = sys.stdout

    type_stats = Counter()
    for token in parse(fp, skiptypes=args.skiptypes):
        if not token:
            # parse yields None at sentence boundaries
            if not args.stats_only:
                fout.write("\n")
        else:
            word, prefix, label = token
            if args.stats_only:
                type_stats[label] += 1
            else:
                fout.write("%s\t%s%s\n" % (word, prefix, label))

    if args.stats_only:
        fout.write(str(type_stats) + "\n")

    fp.close()
    # only close fout if we opened it ourselves - never close stdout
    if args.outfile is not None:
        fout.close()


if __name__ == '__main__':
    main()
diff --git a/stanza/stanza/utils/datasets/sentiment/prepare_sentiment_dataset.py b/stanza/stanza/utils/datasets/sentiment/prepare_sentiment_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..95e50a65b827a05d2c9a331556c1e61b34aef62f
--- /dev/null
+++ b/stanza/stanza/utils/datasets/sentiment/prepare_sentiment_dataset.py
@@ -0,0 +1,441 @@
+"""Prepare a single dataset or a combination dataset for the sentiment project
+
+Manipulates various downloads from their original form to a form
+usable by the classifier model
+
+Explanations for the existing datasets are below.
+After processing the dataset, you can train with
+the run_sentiment script
+
+python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset
+python3 -m stanza.utils.training.run_sentiment
+
+English
+-------
+
+SST (Stanford Sentiment Treebank)
+ https://nlp.stanford.edu/sentiment/
+ https://github.com/stanfordnlp/sentiment-treebank
+ The git repo includes fixed tokenization and sentence splits, along
+ with a partial conversion to updated PTB tokenization standards.
+
+ The first step is to git clone the SST to here:
+ $SENTIMENT_BASE/sentiment-treebank
+ eg:
+ cd $SENTIMENT_BASE
+ git clone git@github.com:stanfordnlp/sentiment-treebank.git
+
+ There are a few different usages of SST.
+
+ The scores most commonly reported are for SST-2,
+ positive and negative only.
+ To get a version of this:
+
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_sst2
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_sst2roots
+
+ The model we distribute is a three class model (+, 0, -)
+ with some smaller datasets added for better coverage.
+ See "sstplus" below.
+
+MELD
+ https://github.com/SenticNet/MELD/tree/master/data/MELD
+ https://github.com/SenticNet/MELD
+ https://arxiv.org/pdf/1810.02508.pdf
+
+ MELD: A Multimodal Multi-Party Dataset for Emotion Recognition in Conversation. ACL 2019.
+ S. Poria, D. Hazarika, N. Majumder, G. Naik, E. Cambria, R. Mihalcea.
+
+ An Emotion Corpus of Multi-Party Conversations.
+ Chen, S.Y., Hsu, C.C., Kuo, C.C. and Ku, L.W.
+
+ Copy the three files in the repo into
+ $SENTIMENT_BASE/MELD
+ TODO: make it so you git clone the repo instead
+
+ There are train/dev/test splits, so you can build a model
+ out of just this corpus. The first step is to convert
+ to the classifier data format:
+
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_meld
+
+ However, in general we simply include this in the sstplus model
+ rather than releasing a separate model.
+
+Arguana
+ http://argumentation.bplaced.net/arguana/data
+ http://argumentation.bplaced.net/arguana-data/arguana-tripadvisor-annotated-v2.zip
+
+ http://argumentation.bplaced.net/arguana-publications/papers/wachsmuth14a-cicling.pdf
+ A Review Corpus for Argumentation Analysis. CICLing 2014
+ Henning Wachsmuth, Martin Trenkmann, Benno Stein, Gregor Engels, Tsvetomira Palarkarska
+
+ Download the zip file and unzip it in
+ $SENTIMENT_BASE/arguana
+
+ This is included in the sstplus model.
+
+airline
+ A Kaggle corpus for sentiment detection on airline tweets.
+ We include this in sstplus as well.
+
+ https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment
+
+ Download Tweets.csv and put it in
+ $SENTIMENT_BASE/airline
+
+SLSD
+ https://archive.ics.uci.edu/ml/datasets/Sentiment+Labelled+Sentences
+
+ From Group to Individual Labels using Deep Features. KDD 2015
+ Kotzias et. al
+
+ Put the contents of the zip file in
+ $SENTIMENT_BASE/slsd
+
+ The sstplus model includes this as training data
+
+en_sstplus
+ This is a three class model built from SST, along with the additional
+ English data sources above for coverage of additional domains.
+
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_sstplus
+
+en_corona
+ A kaggle covid-19 text classification dataset
+ https://www.kaggle.com/datasets/datatattle/covid-19-nlp-text-classification
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset en_corona
+
+German
+------
+
+de_sb10k
+ This used to be here:
+ https://www.spinningbytes.com/resources/germansentiment/
+ Now it appears to have moved here?
+ https://github.com/oliverguhr/german-sentiment
+
+ https://dl.acm.org/doi/pdf/10.1145/3038912.3052611
+ Leveraging Large Amounts of Weakly Supervised Data for Multi-Language Sentiment Classification
+ WWW '17: Proceedings of the 26th International Conference on World Wide Web
+ Jan Deriu, Aurelien Lucchi, Valeria De Luca, Aliaksei Severyn,
+ Simon Müller, Mark Cieliebak, Thomas Hofmann, Martin Jaggi
+
+ The current prep script works on the old version of the data.
+ TODO: update to work on the git repo
+
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset de_sb10k
+
+de_scare
+ http://romanklinger.de/scare/
+
+ The Sentiment Corpus of App Reviews with Fine-grained Annotations in German
+ LREC 2016
+ Mario Sänger, Ulf Leser, Steffen Kemmerer, Peter Adolphs, and Roman Klinger
+
+ Download the data and put it in
+ $SENTIMENT_BASE/german/scare
+ There should be two subdirectories once you are done:
+ scare_v1.0.0
+ scare_v1.0.0_text
+
+ We wound up not including this in the default German model.
+ It might be worth revisiting in the future.
+
+de_usage
+ https://www.romanklinger.de/usagecorpus/
+
+ http://www.lrec-conf.org/proceedings/lrec2014/summaries/85.html
+ The USAGE Review Corpus for Fine Grained Multi Lingual Opinion Analysis
+ Roman Klinger and Philipp Cimiano
+
+ Again, not included in the default German model
+
+Chinese
+-------
+
+zh-hans_ren
+ This used to be here:
+ http://a1-www.is.tokushima-u.ac.jp/member/ren/Ren-CECps1.0/Ren-CECps1.0.html
+
+ That page doesn't seem to respond as of 2022, and I can't find it elsewhere.
+
+The following will be available starting in 1.4.1:
+
+Spanish
+-------
+
+tass2020
+ - http://tass.sepln.org/2020/?page_id=74
+ - Download the following 5 files:
+ task1.2-test-gold.tsv
+ Task1-train-dev.zip
+ tass2020-test-gold.zip
+ Test1.1.zip
+ test1.2.zip
+ Put them in a directory
+ $SENTIMENT_BASE/spanish/tass2020
+
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset es_tass2020
+
+
+Vietnamese
+----------
+
+vi_vsfc
+ I found a corpus labeled VSFC here:
+ https://drive.google.com/drive/folders/1xclbjHHK58zk2X6iqbvMPS2rcy9y9E0X
+ It doesn't seem to have a license or paper associated with it,
+ but happy to put those details here if relevant.
+
+ Download the files to
+ $SENTIMENT_BASE/vietnamese/_UIT-VSFC
+
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset vi_vsfc
+
+Marathi
+-------
+
+mr_l3cube
+ https://github.com/l3cube-pune/MarathiNLP
+
+ https://arxiv.org/abs/2103.11408
+ L3CubeMahaSent: A Marathi Tweet-based Sentiment Analysis Dataset
+ Atharva Kulkarni, Meet Mandhane, Manali Likhitkar, Gayatri Kshirsagar, Raviraj Joshi
+
+ git clone the repo in
+ $SENTIMENT_BASE
+
+ cd $SENTIMENT_BASE
+ git clone git@github.com:l3cube-pune/MarathiNLP.git
+
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset mr_l3cube
+
+
+Italian
+-------
+
+it_sentipolc16
+ from here:
+ http://www.di.unito.it/~tutreeb/sentipolc-evalita16/data.html
+ paper describing the evaluation and the results:
+ http://ceur-ws.org/Vol-1749/paper_026.pdf
+
+ download the training and test zip files to $SENTIMENT_BASE/italian/sentipolc16
+ unzip them there
+
+ so you should have
+ $SENTIMENT_BASE/italian/sentipolc16/test_set_sentipolc16_gold2000.csv
+ $SENTIMENT_BASE/italian/sentipolc16/training_set_sentipolc16.csv
+
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset it_sentipolc16
+
+ this script splits the training data into dev & train, keeps the test the same
+
+ The conversion allows for 4 ways of handling the "mixed" class:
+ treat it as the same as neutral, treat it as a separate class,
+ only distinguish positive or not positive,
+ only distinguish negative or not negative
+ for more details:
+ python3 -m stanza.utils.datasets.sentiment.prepare_sentiment_dataset it_sentipolc16 --help
+
+another option not implemented yet: absita18
+ http://sag.art.uniroma2.it/absita/data/
+"""
+
+import os
+import random
+import sys
+
+import stanza.utils.default_paths as default_paths
+
+from stanza.utils.datasets.sentiment import process_airline
+from stanza.utils.datasets.sentiment import process_arguana_xml
+from stanza.utils.datasets.sentiment import process_corona
+from stanza.utils.datasets.sentiment import process_es_tass2020
+from stanza.utils.datasets.sentiment import process_it_sentipolc16
+from stanza.utils.datasets.sentiment import process_MELD
+from stanza.utils.datasets.sentiment import process_ren_chinese
+from stanza.utils.datasets.sentiment import process_sb10k
+from stanza.utils.datasets.sentiment import process_scare
+from stanza.utils.datasets.sentiment import process_slsd
+from stanza.utils.datasets.sentiment import process_sst
+from stanza.utils.datasets.sentiment import process_usage_german
+from stanza.utils.datasets.sentiment import process_vsfc_vietnamese
+
+from stanza.utils.datasets.sentiment import process_utils
+
def convert_sst_general(paths, dataset_name, version):
    """Build train/dev/test from SST and write them out as one dataset.

    version selects the label scheme / root handling understood by
    process_sst.get_phrases, e.g. "binary" or "threeclassroot".
    """
    sst_dir = os.path.join(paths['SENTIMENT_BASE'], "sentiment-treebank")
    dataset = [process_sst.get_phrases(version, "%s.txt" % split, sst_dir)
               for split in ("train", "dev", "test")]
    process_utils.write_dataset(dataset, paths['SENTIMENT_DATA_DIR'], dataset_name)
+
def convert_sst2(paths, dataset_name, *args):
    """
    Create a 2 class SST dataset (neutral items are dropped)
    """
    convert_sst_general(paths, dataset_name, "binary")

def convert_sst2roots(paths, dataset_name, *args):
    """
    Create a 2 class SST dataset using only the roots
    """
    convert_sst_general(paths, dataset_name, "binaryroot")

def convert_sst3(paths, dataset_name, *args):
    """
    Create a 3 class SST dataset (all phrases, not just the roots)
    """
    convert_sst_general(paths, dataset_name, "threeclass")

def convert_sst3roots(paths, dataset_name, *args):
    """
    Create a 3 class SST dataset using only the roots
    """
    convert_sst_general(paths, dataset_name, "threeclassroot")
+
def convert_sstplus(paths, dataset_name, *args):
    """
    Create a 3 class SST dataset with a few other small datasets added
    """
    train_phrases = []
    in_directory = paths['SENTIMENT_BASE']
    # extra training data from the other English corpora described in the
    # module docstring (arguana, MELD, SLSD, airline tweets)
    train_phrases.extend(process_arguana_xml.get_tokenized_phrases(os.path.join(in_directory, "arguana")))
    train_phrases.extend(process_MELD.get_tokenized_phrases("train", os.path.join(in_directory, "MELD")))
    train_phrases.extend(process_slsd.get_tokenized_phrases(os.path.join(in_directory, "slsd")))
    train_phrases.extend(process_airline.get_tokenized_phrases(os.path.join(in_directory, "airline")))

    # three class SST phrases, including the extra annotated phrase files
    # shipped in the sentiment-treebank repo
    sst_dir = os.path.join(in_directory, "sentiment-treebank")
    train_phrases.extend(process_sst.get_phrases("threeclass", "train.txt", sst_dir))
    train_phrases.extend(process_sst.get_phrases("threeclass", "extra-train.txt", sst_dir))
    train_phrases.extend(process_sst.get_phrases("threeclass", "checked-extra-train.txt", sst_dir))

    # dev and test stay pure SST, so scores remain comparable to SST-only models
    dev_phrases = process_sst.get_phrases("threeclass", "dev.txt", sst_dir)
    test_phrases = process_sst.get_phrases("threeclass", "test.txt", sst_dir)

    out_directory = paths['SENTIMENT_DATA_DIR']
    dataset = [train_phrases, dev_phrases, test_phrases]
    process_utils.write_dataset(dataset, out_directory, dataset_name)
+
def convert_meld(paths, dataset_name, *args):
    """
    Convert the MELD dataset to train/dev/test files
    """
    in_directory = os.path.join(paths['SENTIMENT_BASE'], "MELD")
    out_directory = paths['SENTIMENT_DATA_DIR']
    process_MELD.main(in_directory, out_directory, dataset_name)

def convert_corona(paths, dataset_name, *args):
    """
    Convert the kaggle covid dataset to train/dev/test files

    NOTE(review): paths and dataset_name are ignored here -
    process_corona.main looks up the default paths itself and parses the
    forwarded command line arguments (args is expected to contain the
    raw remaining-argument list).
    """
    process_corona.main(*args)

def convert_scare(paths, dataset_name, *args):
    """
    Convert the German SCARE app review corpus to train/dev/test files
    """
    in_directory = os.path.join(paths['SENTIMENT_BASE'], "german", "scare")
    out_directory = paths['SENTIMENT_DATA_DIR']
    process_scare.main(in_directory, out_directory, dataset_name)


def convert_de_usage(paths, dataset_name, *args):
    """
    Convert the German USAGE review corpus to train/dev/test files
    """
    in_directory = os.path.join(paths['SENTIMENT_BASE'], "USAGE")
    out_directory = paths['SENTIMENT_DATA_DIR']
    process_usage_german.main(in_directory, out_directory, dataset_name)
+
def convert_sb10k(paths, dataset_name, *args):
    """
    Essentially runs the sb10k script twice with different arguments to produce the de_sb10k dataset

    stanza.utils.datasets.sentiment.process_sb10k --csv_filename extern_data/sentiment/german/sb-10k/de_full/de_test.tsv --out_dir $SENTIMENT_DATA_DIR --short_name de_sb10k --split test --sentiment_column 2 --text_column 3
    stanza.utils.datasets.sentiment.process_sb10k --csv_filename extern_data/sentiment/german/sb-10k/de_full/de_train.tsv --out_dir $SENTIMENT_DATA_DIR --short_name de_sb10k --split train_dev --sentiment_column 2 --text_column 3
    """
    # sentiment label is in column 2, tweet text in column 3 of the tsv files
    column_args = ["--sentiment_column", "2", "--text_column", "3"]

    # de_test.tsv is written out unchanged as the test split
    process_sb10k.main(["--csv_filename", os.path.join(paths['SENTIMENT_BASE'], "german", "sb-10k", "de_full", "de_test.tsv"),
                        "--out_dir", paths['SENTIMENT_DATA_DIR'],
                        "--short_name", dataset_name,
                        "--split", "test",
                        *column_args])
    # de_train.tsv is split by process_sb10k into train and dev
    process_sb10k.main(["--csv_filename", os.path.join(paths['SENTIMENT_BASE'], "german", "sb-10k", "de_full", "de_train.tsv"),
                        "--out_dir", paths['SENTIMENT_DATA_DIR'],
                        "--short_name", dataset_name,
                        "--split", "train_dev",
                        *column_args])

def convert_vi_vsfc(paths, dataset_name, *args):
    """
    Convert the Vietnamese VSFC dataset to train/dev/test files
    """
    in_directory = os.path.join(paths['SENTIMENT_BASE'], "vietnamese", "_UIT-VSFC")
    out_directory = paths['SENTIMENT_DATA_DIR']
    process_vsfc_vietnamese.main(in_directory, out_directory, dataset_name)
+
def convert_mr_l3cube(paths, dataset_name, *args):
    """
    Convert the L3Cube MahaSent Marathi tweet dataset to train/dev/test files
    """
    # csv_filename = 'extern_data/sentiment/MarathiNLP/L3CubeMahaSent Dataset/tweets-train.csv'
    # remap the corpus labels (-1, 0, 1) to the classifier's (0, 1, 2)
    MAPPING = {"-1": "0", "0": "1", "1": "2"}

    out_directory = paths['SENTIMENT_DATA_DIR']
    os.makedirs(out_directory, exist_ok=True)

    # the repo ships predefined train/valid/test splits, used as-is
    in_directory = os.path.join(paths['SENTIMENT_BASE'], "MarathiNLP", "L3CubeMahaSent Dataset")
    input_files = ['tweets-train.csv', 'tweets-valid.csv', 'tweets-test.csv']
    input_files = [os.path.join(in_directory, x) for x in input_files]
    datasets = [process_utils.read_snippets(csv_filename, sentiment_column=1, text_column=0, tokenizer_language="mr", mapping=MAPPING, delimiter=',', quotechar='"', skip_first_line=True)
                for csv_filename in input_files]

    process_utils.write_dataset(datasets, out_directory, dataset_name)

def convert_es_tass2020(paths, dataset_name, *args):
    """
    Convert the Spanish TASS 2020 dataset to train/dev/test files
    """
    process_es_tass2020.convert_tass2020(paths['SENTIMENT_BASE'], paths['SENTIMENT_DATA_DIR'], dataset_name)

def convert_it_sentipolc16(paths, dataset_name, *args):
    """
    Convert the Italian sentipolc16 dataset to train/dev/test files;
    extra command line arguments (such as --mode) are forwarded on
    """
    in_directory = os.path.join(paths['SENTIMENT_BASE'], "italian", "sentipolc16")
    out_directory = paths['SENTIMENT_DATA_DIR']
    process_it_sentipolc16.main(in_directory, out_directory, dataset_name, *args)


def convert_ren(paths, dataset_name, *args):
    """
    Convert the Ren-CECps Chinese dataset to train/dev/test files
    """
    in_directory = os.path.join(paths['SENTIMENT_BASE'], "chinese", "RenCECps")
    out_directory = paths['SENTIMENT_DATA_DIR']
    process_ren_chinese.main(in_directory, out_directory, dataset_name)
+
# maps each known dataset short name to its conversion function.
# every converter takes (paths, dataset_name, *args), where args carries
# any extra command line arguments forwarded by main
DATASET_MAPPING = {
    "de_sb10k": convert_sb10k,
    "de_scare": convert_scare,
    "de_usage": convert_de_usage,

    "en_corona": convert_corona,
    "en_sst2": convert_sst2,
    "en_sst2roots": convert_sst2roots,
    "en_sst3": convert_sst3,
    "en_sst3roots": convert_sst3roots,
    "en_sstplus": convert_sstplus,
    "en_meld": convert_meld,

    "es_tass2020": convert_es_tass2020,

    "it_sentipolc16": convert_it_sentipolc16,

    "mr_l3cube": convert_mr_l3cube,

    "vi_vsfc": convert_vi_vsfc,

    "zh-hans_ren": convert_ren,
}
+
def main(dataset_name, *args):
    """Prepare a single sentiment dataset by its short name.

    NOTE(review): when run as a script, args receives a single element -
    the list of remaining command line arguments - and that list is
    forwarded untouched to the converter.  Some converters (e.g. the
    sentipolc16 one) unwrap that list themselves, so do not "fix" this
    to splat sys.argv[2:] without changing those converters too.
    """
    paths = default_paths.get_default_paths()

    # fixed seed so random train/dev splits are reproducible
    random.seed(1234)

    if dataset_name in DATASET_MAPPING:
        DATASET_MAPPING[dataset_name](paths, dataset_name, *args)
    else:
        raise ValueError(f"dataset {dataset_name} currently not handled")

if __name__ == '__main__':
    main(sys.argv[1], sys.argv[2:])
+
diff --git a/stanza/stanza/utils/datasets/sentiment/process_corona.py b/stanza/stanza/utils/datasets/sentiment/process_corona.py
new file mode 100644
index 0000000000000000000000000000000000000000..35cb6c35614d92986c84b20804b60132a600bb4a
--- /dev/null
+++ b/stanza/stanza/utils/datasets/sentiment/process_corona.py
@@ -0,0 +1,69 @@
+"""
+Processes a kaggle covid-19 text classification dataset
+
+The original description of the dataset is here:
+
+https://www.kaggle.com/datasets/datatattle/covid-19-nlp-text-classification
+
+There are two files in the archive, Corona_NLP_train.csv and Corona_NLP_test.csv
+Unzip the files in archive.zip to $SENTIMENT_BASE/english/corona/Corona_NLP_train.csv
+
+There is no dedicated dev set, so we randomly split train/dev
+(using a specific seed, so that the split always comes out the same)
+"""
+
+import argparse
+import os
+import random
+
+import stanza
+
+import stanza.utils.datasets.sentiment.process_utils as process_utils
+from stanza.utils.default_paths import get_default_paths
+
# TODO: could give an option to keep the 'extremely'
# collapse the five kaggle labels to the classifier's three classes:
# 0 = negative, 1 = neutral, 2 = positive
MAPPING = {'extremely positive': "2",
           'positive': "2",
           'neutral': "1",
           'negative': "0",
           'extremely negative': "0"}
+
def main(args=None):
    """Convert the kaggle Corona_NLP csv files to classifier json files.

    Reads Corona_NLP_train.csv / Corona_NLP_test.csv from --in_dir,
    tokenizes the tweets with the English pipeline, and writes
    train/dev/test json files to --out_dir.  The train file is split
    90/10 into train/dev with a fixed seed; test is kept as-is.
    """
    default_paths = get_default_paths()
    sentiment_base_dir = default_paths["SENTIMENT_BASE"]
    default_in_dir = os.path.join(sentiment_base_dir, "english", "corona")
    default_out_dir = default_paths["SENTIMENT_DATA_DIR"]

    parser = argparse.ArgumentParser()
    parser.add_argument('--in_dir', type=str, default=default_in_dir, help='Where to get the input files')
    parser.add_argument('--out_dir', type=str, default=default_out_dir, help='Where to write the output files')
    parser.add_argument('--short_name', type=str, default="en_corona", help='short name to use when writing files')
    args = parser.parse_args(args=args)

    # column layout of the kaggle csv: tweet text in column 4, label in column 5
    TEXT_COLUMN = 4
    SENTIMENT_COLUMN = 5

    train_csv = os.path.join(args.in_dir, "Corona_NLP_train.csv")
    test_csv = os.path.join(args.in_dir, "Corona_NLP_test.csv")

    nlp = stanza.Pipeline("en", processors='tokenize')

    # the kaggle files are latin1 encoded, not utf-8
    train_snippets = process_utils.read_snippets(train_csv, SENTIMENT_COLUMN, TEXT_COLUMN, 'en', MAPPING, delimiter=",", quotechar='"', skip_first_line=True, nlp=nlp, encoding="latin1")
    test_snippets = process_utils.read_snippets(test_csv, SENTIMENT_COLUMN, TEXT_COLUMN, 'en', MAPPING, delimiter=",", quotechar='"', skip_first_line=True, nlp=nlp, encoding="latin1")

    print("Read %d train snippets" % len(train_snippets))
    print("Read %d test snippets" % len(test_snippets))

    # seed here (not only under __main__) so the dev split is stable even
    # when called from prepare_sentiment_dataset
    random.seed(1234)
    random.shuffle(train_snippets)

    os.makedirs(args.out_dir, exist_ok=True)
    process_utils.write_splits(args.out_dir,
                               train_snippets,
                               (process_utils.Split("%s.train.json" % args.short_name, 0.9),
                                process_utils.Split("%s.dev.json" % args.short_name, 0.1)))
    process_utils.write_list(os.path.join(args.out_dir, "%s.test.json" % args.short_name), test_snippets)

if __name__ == '__main__':
    main()
+
diff --git a/stanza/stanza/utils/datasets/sentiment/process_it_sentipolc16.py b/stanza/stanza/utils/datasets/sentiment/process_it_sentipolc16.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ba7c6169d3ed3c902817627ad31e5033311132
--- /dev/null
+++ b/stanza/stanza/utils/datasets/sentiment/process_it_sentipolc16.py
@@ -0,0 +1,92 @@
+"""
+Process the SentiPolc dataset from Evalita
+
+Can be run as a standalone script or as a module from
+prepare_sentiment_dataset
+
+An option controls how to split up the positive/negative/neutral/mixed classes
+"""
+
+import argparse
+from enum import Enum
+import os
+import random
+import sys
+
+import stanza
+from stanza.utils.datasets.sentiment import process_utils
+import stanza.utils.default_paths as default_paths
+
class Mode(Enum):
    """How to map sentipolc's (positive, negative) flag pair to classes.

    COMBINED: 3 classes, "mixed" folded into neutral
    SEPARATE: 4 classes, "mixed" kept as its own class
    POSITIVE: binary positive vs not-positive
    NEGATIVE: binary negative vs not-negative
    """
    COMBINED = 1
    SEPARATE = 2
    POSITIVE = 3
    NEGATIVE = 4
+
def main(in_dir, out_dir, short_name, *args):
    """Convert the sentipolc16 csv files into classifier json files.

    NOTE(review): args is expected to contain a single element - the
    list of remaining command line arguments (see the __main__ block and
    prepare_sentiment_dataset).  list(*args) unwraps that list for
    argparse.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default=Mode.COMBINED, type=lambda x: Mode[x.upper()],
                        help='How to handle mixed vs neutral. {}'.format(", ".join(x.name for x in Mode)))
    parser.add_argument('--name', default=None, type=str,
                        help='Use a different name to save the dataset. Useful for keeping POSITIVE & NEGATIVE separate')
    args = parser.parse_args(args=list(*args))

    if args.name is not None:
        short_name = args.name

    nlp = stanza.Pipeline("it", processors='tokenize')

    # the csv marks each tweet with two flags, (positive, negative);
    # each mode maps those pairs to class labels differently
    if args.mode == Mode.COMBINED:
        mapping = {
            ('0', '0'): "1", # neither negative nor positive: neutral
            ('1', '0'): "2", # positive, not negative: positive
            ('0', '1'): "0", # negative, not positive: negative
            ('1', '1'): "1", # mixed combined with neutral
        }
    elif args.mode == Mode.SEPARATE:
        mapping = {
            ('0', '0'): "1", # neither negative nor positive: neutral
            ('1', '0'): "2", # positive, not negative: positive
            ('0', '1'): "0", # negative, not positive: negative
            ('1', '1'): "3", # mixed as a different class
        }
    elif args.mode == Mode.POSITIVE:
        mapping = {
            ('0', '0'): "0", # neutral -> not positive
            ('1', '0'): "1", # positive -> positive
            ('0', '1'): "0", # negative -> not positive
            ('1', '1'): "1", # mixed -> positive
        }
    elif args.mode == Mode.NEGATIVE:
        mapping = {
            ('0', '0'): "0", # neutral -> not negative
            ('1', '0'): "0", # positive -> not negative
            ('0', '1'): "1", # negative -> negative
            ('1', '1'): "1", # mixed -> negative
        }

    print("Using {} scheme to handle the 4 values. Mapping: {}".format(args.mode, mapping))
    print("Saving to {} using the short name {}".format(out_dir, short_name))

    # the gold test file has no header line; the training file does
    test_filename = os.path.join(in_dir, "test_set_sentipolc16_gold2000.csv")
    test_snippets = process_utils.read_snippets(test_filename, (2,3), 8, "it", mapping, delimiter=",", skip_first_line=False, quotechar='"', nlp=nlp)

    train_filename = os.path.join(in_dir, "training_set_sentipolc16.csv")
    train_snippets = process_utils.read_snippets(train_filename, (2,3), 8, "it", mapping, delimiter=",", skip_first_line=True, quotechar='"', nlp=nlp)

    # hold out 10% of train as dev; the caller is expected to have seeded
    # random for reproducibility
    random.shuffle(train_snippets)
    dev_len = len(train_snippets) // 10
    dev_snippets = train_snippets[:dev_len]
    train_snippets = train_snippets[dev_len:]

    dataset = (train_snippets, dev_snippets, test_snippets)

    process_utils.write_dataset(dataset, out_dir, short_name)

if __name__ == '__main__':
    paths = default_paths.get_default_paths()
    random.seed(1234)

    in_directory = os.path.join(paths['SENTIMENT_BASE'], "italian", "sentipolc16")
    out_directory = paths['SENTIMENT_DATA_DIR']
    main(in_directory, out_directory, "it_sentipolc16", sys.argv[1:])
diff --git a/stanza/stanza/utils/datasets/sentiment/process_sb10k.py b/stanza/stanza/utils/datasets/sentiment/process_sb10k.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cdabac56fe90dc33683bb76b39a5e38b68eba9f
--- /dev/null
+++ b/stanza/stanza/utils/datasets/sentiment/process_sb10k.py
@@ -0,0 +1,76 @@
+"""
+Processes the SB10k dataset
+
+The original description of the dataset and corpus_v1.0.tsv is here:
+
+https://www.spinningbytes.com/resources/germansentiment/
+
+Download script is here:
+
+https://github.com/aritter/twitter_download
+
+The problem with this file is that many of the tweets with labels no
+longer exist. Roughly 1/3 as of June 2020.
+
+You can contact the authors for the complete dataset.
+
+There is a paper describing some experiments run on the dataset here:
+https://dl.acm.org/doi/pdf/10.1145/3038912.3052611
+"""
+
+import argparse
+import os
+import random
+
+from enum import Enum
+
+import stanza.utils.datasets.sentiment.process_utils as process_utils
+
class Split(Enum):
    """Which output splits to produce from the single input file.

    NOTE: distinct from process_utils.Split (a filename/weight pair).
    """
    TRAIN_DEV_TEST = 1
    TRAIN_DEV = 2
    TEST = 3

# sb10k text labels mapped to the classifier's three class scheme:
# 0 = negative, 1 = neutral, 2 = positive
MAPPING = {'positive': "2",
           'neutral': "1",
           'negative': "0"}
+
def main(args=None):
    """Convert one sb10k tsv file into classifier json split files.

    Depending on --split, the snippets are divided 80/10/10 into
    train/dev/test, 90/10 into train/dev, or written solely as test.

    Callers wanting a reproducible shuffle must seed random themselves
    (prepare_sentiment_dataset and the __main__ block below both do).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--csv_filename', type=str, default=None, help='CSV file to read in')
    parser.add_argument('--out_dir', type=str, default=None, help='Where to write the output files')
    parser.add_argument('--sentiment_column', type=int, default=2, help='Column with the sentiment')
    parser.add_argument('--text_column', type=int, default=3, help='Column with the text')
    parser.add_argument('--short_name', type=str, default="sb10k", help='short name to use when writing files')

    parser.add_argument('--split', type=lambda x: Split[x.upper()], default=Split.TRAIN_DEV_TEST,
                        help="How to split the resulting data")

    args = parser.parse_args(args=args)

    snippets = process_utils.read_snippets(args.csv_filename, args.sentiment_column, args.text_column, 'de', MAPPING)

    # previously this printed a bare, unlabeled number; label it the same
    # way process_corona reports its snippet counts
    print("Read %d snippets from %s" % (len(snippets), args.csv_filename))
    random.shuffle(snippets)

    os.makedirs(args.out_dir, exist_ok=True)
    if args.split is Split.TRAIN_DEV_TEST:
        process_utils.write_splits(args.out_dir,
                                   snippets,
                                   (process_utils.Split("%s.train.json" % args.short_name, 0.8),
                                    process_utils.Split("%s.dev.json" % args.short_name, 0.1),
                                    process_utils.Split("%s.test.json" % args.short_name, 0.1)))
    elif args.split is Split.TRAIN_DEV:
        process_utils.write_splits(args.out_dir,
                                   snippets,
                                   (process_utils.Split("%s.train.json" % args.short_name, 0.9),
                                    process_utils.Split("%s.dev.json" % args.short_name, 0.1)))
    elif args.split is Split.TEST:
        process_utils.write_list(os.path.join(args.out_dir, "%s.test.json" % args.short_name), snippets)
    else:
        raise ValueError("Unknown split method {}".format(args.split))

if __name__ == '__main__':
    random.seed(1234)
    main()
+
diff --git a/stanza/stanza/utils/datasets/tokenization/convert_ml_cochin.py b/stanza/stanza/utils/datasets/tokenization/convert_ml_cochin.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7523e68c867093d9abb6b86e01f702a8b606f22
--- /dev/null
+++ b/stanza/stanza/utils/datasets/tokenization/convert_ml_cochin.py
@@ -0,0 +1,229 @@
+"""
+Convert a Malayalam NER dataset to a tokenization dataset using
+the additional labeling provided by TTec's Indian partners
+
+This is still WIP - ongoing discussion with TTec and the team at UFAL
+doing the UD Malayalam dataset - but if someone wants the data to
+recreate it, feel free to contact Prof. Manning or John Bauer
+
+Data was annotated through Datasaur by TTec - possibly another team
+involved, will double check with the annotators.
+
+#1 current issue with the data is a difference in annotation style
+observed by the UFAL group. I believe TTec is working on reannotating
+this.
+
+Discussing the first sentence in the first split file:
+
+> I am not sure about the guidelines that the annotators followed, but
+> I would not have split നാമജപത്തോടുകൂടി as നാമ --- ജപത്തോടുകൂടി. Because
+> they are not multiple syntactic words. I would have done it like
+> നാമജപത്തോടു --- കൂടി as കൂടി ('with') can be tagged as ADP. I agree with
+> the second MWT വ്യത്യസ്തം --- കൂടാതെ.
+>
+> In Malayalam, we do have many words which potentially can be treated
+> as compounds and split but sometimes it becomes difficult to make
+> that decision as the etymology or the word formation process is
+> unclear. So for the Malayalam UD annotations I stayed away from
+> doing it because I didn't find it necessary and moreover the
+> guidelines say that the words should be split into syntactic words
+> and not into morphemes.
+
+As for using this script, create a directory extern_data/malayalam/cochin_ner/
+The original NER dataset from Cochin University going there:
+extern_data/malayalam/cochin_ner/final_ner.txt
+The relabeled data from TTEC goes in
+extern_data/malayalam/cochin_ner/relabeled_tsv/malayalam_File_1.txt.tsv etc etc
+
+This can be invoked from the command line, or it can be used as part of
+stanza/utils/datasets/prepare_tokenizer_treebank.py ml_cochin
+in which case the conll splits will be turned into tokenizer labels as well
+"""
+
+from difflib import SequenceMatcher
+import os
+import random
+import sys
+
+import stanza.utils.default_paths as default_paths
+
def read_words(filename):
    """
    Read one token per line from filename.

    Each non-blank line contributes its first whitespace-separated field;
    blank (or whitespace-only) lines contribute an empty string.
    """
    words = []
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            stripped = line.strip()
            words.append(stripped.split()[0] if stripped else "")
    return words
+
def read_original_text(input_dir):
    """Read the original Cochin NER token list from final_ner.txt in input_dir."""
    return read_words(os.path.join(input_dir, "final_ner.txt"))
+
def list_relabeled_files(relabeled_dir):
    """
    List the relabeled tsv files in relabeled_dir, sorted by file number.

    Every file is expected to be named malayalam_File_<N>.txt.tsv;
    the returned list is ordered by the numeric <N>, not lexicographically.
    """
    filenames = os.listdir(relabeled_dir)
    for name in filenames:
        assert name.startswith("malayalam_File_") and name.endswith(".txt.tsv")

    def file_number(name):
        # malayalam_File_12.txt.tsv -> 12
        return int(name.split(".")[0].split("_")[2])

    return sorted(filenames, key=file_number)
+
def find_word(original_text, target, start_index, end_index):
    """
    Return True if target occurs in original_text[start_index:end_index].

    original_text: list of words
    target: the word to look for
    start_index, end_index: slice bounds (end exclusive); out-of-range
      values are safe, as with any slice

    The manual loop was replaced with the idiomatic `in` membership test
    on the slice, which does the same linear scan.
    """
    return target in original_text[start_index:end_index]
+
def scan_file(original_text, current_index, tsv_file):
    """
    Extract sentences from one relabeled tsv file, aligned against the original text.

    original_text: the full original NER token list (from read_original_text)
    current_index: returned unchanged.  NOTE(review): this value is threaded
      through but never read or advanced inside this function - confirm
      whether it can be removed
    tsv_file: path to one relabeled malayalam_File_*.txt.tsv file

    Returns (current_index, sentences) where each sentence is a list whose
    items are either a plain word (str) or an MWT tuple (text, [pieces]).
    Sentences that hit any alignment problem are flagged bad and dropped.
    """
    relabeled_text = read_words(tsv_file)
    # for now, at least, we ignore these markers
    # relabeled_indices maps positions in the filtered list back to the
    # original tsv line numbers, for error messages
    relabeled_indices = [idx for idx, x in enumerate(relabeled_text) if x != '$' and x != '^']
    relabeled_text = [x for x in relabeled_text if x != '$' and x != '^']
    # autojunk=False: keep SequenceMatcher from discarding frequent tokens
    diffs = SequenceMatcher(None, original_text, relabeled_text, False)

    blocks = diffs.get_matching_blocks()
    # get_matching_blocks always ends with a zero-size sentinel block
    assert blocks[-1].size == 0
    if len(blocks) == 1:
        raise ValueError("Could not find a match between %s and the original text" % tsv_file)

    sentences = []
    current_sentence = []

    in_mwt = False
    bad_sentence = False
    current_mwt = []
    block_index = 0
    current_block = blocks[0]
    for tsv_index, next_word in enumerate(relabeled_text):
        # a blank line ends the current sentence
        if not next_word:
            if in_mwt:
                current_mwt = []
                in_mwt = False
                bad_sentence = True
                # NOTE(review): this message uses tsv_index, while other
                # messages use relabeled_indices[tsv_index] - confirm which
                # line numbering is intended
                print("Unclosed MWT found at %s line %d" % (tsv_file, tsv_index))
            if current_sentence:
                if not bad_sentence:
                    sentences.append(current_sentence)
                bad_sentence = False
                current_sentence = []
            continue

        # tsv_index will now be inside the current block or before the current block
        while tsv_index >= blocks[block_index].b + current_block.size:
            block_index += 1
            current_block = blocks[block_index]
        #print(tsv_index, current_block.b, current_block.size)

        if next_word == ',' or next_word == '.':
            # many of these punctuations were added by the relabelers
            current_sentence.append(next_word)
            continue
        if tsv_index >= current_block.b and tsv_index < current_block.b + current_block.size:
            # ideal case: in a matching block
            current_sentence.append(next_word)
            continue

        # in between blocks... need to handle re-spelled words and MWTs
        # MWT spans are delimited by '@' markers in the relabeled files
        if not in_mwt and next_word == '@':
            in_mwt = True
            continue
        if not in_mwt:
            current_sentence.append(next_word)
            continue
        if in_mwt and next_word == '@' and (tsv_index + 1 < len(relabeled_text) and relabeled_text[tsv_index+1] == '@'):
            # we'll stop the MWT next time around
            continue
        if in_mwt and next_word == '@':
            # closing marker: validate the collected MWT pieces
            if block_index > 0 and (len(current_mwt) == 2 or len(current_mwt) == 3):
                mwt = "".join(current_mwt)
                # search for the concatenated form in the gap between the
                # previous matching block and the current one
                start_original = blocks[block_index-1].a + blocks[block_index-1].size
                end_original = current_block.a
                if find_word(original_text, mwt, start_original, end_original):
                    current_sentence.append((mwt, current_mwt))
                else:
                    print("%d word MWT %s at %s %d. Should be somewhere in %d %d" % (len(current_mwt), mwt, tsv_file, relabeled_indices[tsv_index], start_original, end_original))
                    bad_sentence = True
            elif len(current_mwt) > 6:
                raise ValueError("Unreasonably long MWT span in %s at line %d" % (tsv_file, relabeled_indices[tsv_index]))
            elif len(current_mwt) > 3:
                print("%d word sequence, stop being lazy - %s %d" % (len(current_mwt), tsv_file, relabeled_indices[tsv_index]))
                bad_sentence = True
            else:
                # short MWT, but it was at the start of a file, and we don't want to search the whole file for the item
                # TODO, could maybe search the 10 words or so before the start of the block?
                bad_sentence = True
            current_mwt = []
            in_mwt = False
            continue
        # now we know we are in an MWT... TODO
        current_mwt.append(next_word)

    # flush a sentence that ran to the end of the file without a blank line
    if len(current_sentence) > 0 and not bad_sentence:
        sentences.append(current_sentence)

    return current_index, sentences
+
def split_sentences(sentences):
    """
    Randomly partition sentences into train/dev/test, roughly 80/10/10.

    One random draw is made per sentence, so the split is reproducible
    for a fixed random seed.  Returns (train, dev, test).
    """
    train, dev, test = [], [], []

    for sentence in sentences:
        draw = random.random()
        if draw < 0.8:
            bucket = train
        elif draw < 0.9:
            bucket = dev
        else:
            bucket = test
        bucket.append(sentence)

    return train, dev, test
+
def main(input_dir, tokenizer_dir, relabeled_dir="relabeled_tsv", split_data=True):
    """
    Convert the relabeled Cochin NER data into tokenizer .conllu files.

    input_dir: STANZA_EXTERN_DIR; the data is expected under malayalam/cochin_ner
    tokenizer_dir: directory where ml_cochin.<shard>.gold.conllu files are written
    relabeled_dir: subdirectory (under the cochin_ner dir) with the relabeled tsvs
    split_data: if True, write train/dev/test splits; otherwise a single train file
    """
    # fixed seed so the random train/dev/test split is reproducible
    random.seed(1006)

    input_dir = os.path.join(input_dir, "malayalam", "cochin_ner")
    relabeled_dir = os.path.join(input_dir, relabeled_dir)
    tsv_files = list_relabeled_files(relabeled_dir)

    original_text = read_original_text(input_dir)
    print("Original text len: %d" %len(original_text))
    # NOTE(review): current_index is threaded through scan_file but scan_file
    # returns it unchanged - confirm whether it can be removed
    current_index = 0
    sentences = []
    for tsv_file in tsv_files:
        print(tsv_file)
        current_index, new_sentences = scan_file(original_text, current_index, os.path.join(relabeled_dir, tsv_file))
        sentences.extend(new_sentences)

    print("Found %d sentences" % len(sentences))

    if split_data:
        splits = split_sentences(sentences)
        SHARDS = ("train", "dev", "test")
    else:
        splits = [sentences]
        SHARDS = ["train"]

    for split, shard in zip(splits, SHARDS):
        output_filename = os.path.join(tokenizer_dir, "ml_cochin.%s.gold.conllu" % shard)
        print("Writing %d sentences to %s" % (len(split), output_filename))
        with open(output_filename, "w", encoding="utf-8") as fout:
            for sentence in split:
                word_idx = 1
                for token in sentence:
                    if isinstance(token, str):
                        # fake dependencies: the conllu reading code needs a
                        # tree even though only the tokenization is used
                        fake_dep = "\t0\troot" if word_idx == 1 else "\t1\tdep"
                        fout.write("%d\t%s" % (word_idx, token) + "\t_" * 4 + fake_dep + "\t_\t_\n")
                        word_idx += 1
                    else:
                        # an MWT from scan_file: (surface_text, [pieces]).
                        # write the id-range line, then one line per piece
                        text = token[0]
                        mwt = token[1]
                        fout.write("%d-%d\t%s" % (word_idx, word_idx + len(mwt) - 1, text) + "\t_" * 8 + "\n")
                        for piece in mwt:
                            fake_dep = "\t0\troot" if word_idx == 1 else "\t1\tdep"
                            fout.write("%d\t%s" % (word_idx, piece) + "\t_" * 4 + fake_dep + "\t_\t_\n")
                            word_idx += 1
                # blank line terminates each conllu sentence
                fout.write("\n")
+
if __name__ == '__main__':
    # make sure Malayalam text prints cleanly even on consoles that
    # default to a non-utf-8 encoding
    sys.stdout.reconfigure(encoding='utf-8')
    paths = default_paths.get_default_paths()
    tokenizer_dir = paths["TOKENIZE_DATA_DIR"]
    input_dir = paths["STANZA_EXTERN_DIR"]
    # currently runs on the v2 relabeling, writing everything as train
    main(input_dir, tokenizer_dir, "relabeled_tsv_v2", False)
+
diff --git a/stanza/stanza/utils/datasets/tokenization/convert_th_orchid.py b/stanza/stanza/utils/datasets/tokenization/convert_th_orchid.py
new file mode 100644
index 0000000000000000000000000000000000000000..871e87d1dfc657cc77f87d778a0258aa11a09349
--- /dev/null
+++ b/stanza/stanza/utils/datasets/tokenization/convert_th_orchid.py
@@ -0,0 +1,163 @@
+"""Parses the xml conversion of orchid
+
+https://github.com/korakot/thainlp/blob/master/xmlchid.xml
+
+For example, if you put the data file in the above link in
+extern_data/thai/orchid/xmlchid.xml
+you would then run
+python3 -m stanza.utils.datasets.tokenization.convert_th_orchid extern_data/thai/orchid/xmlchid.xml data/tokenize
+
+Because there is no definitive train/dev/test split that we have found
+so far, we randomly shuffle the data on a paragraph level and split it
+80/10/10. A random seed is chosen so that the splits are reproducible.
+
+The datasets produced have a similar format to the UD datasets, so we
+give it a fake UD name to make life easier for the downstream tools.
+
+Training on this dataset seems to work best with low dropout numbers.
+For example:
+
+python3 -m stanza.utils.training.run_tokenizer th_orchid --dropout 0.05 --unit_dropout 0.05
+
+This results in a model with dev set scores:
+ th_orchid 87.98 70.94
+test set scores:
+ 91.60 72.43
+
+Apparently the random split produced a test set easier than the dev set.
+"""
+
+import os
+import random
+import sys
+import xml.etree.ElementTree as ET
+
+from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset
+
+# line "122819" has some error in the tokenization of the musical notation
+# line "209380" is also messed up
+# others have @ followed by a part of speech, which is clearly wrong
+
# values are strings because they are compared against the xml
# line_num attribute, which ElementTree returns as a string
skipped_lines = {
    "122819",
    "209380",
    "227769",
    "245992",
    "347163",
    "409708",
    "431227",
}
+
+escape_sequences = {
+ '': '(',
+ '': ')',
+ '': '^',
+ '': '.',
+ '': '-',
+ '': '*',
+ '': '"',
+ '': '/',
+ '': ':',
+ '': '=',
+ '': ',',
+ '': ';',
+ '': '<',
+ '': '>',
+ '': '&',
+ '': '{',
+ '': '}',
+ '': "'",
+ '': '+',
+ '': '#',
+ '': '$',
+ '': '@',
+ '': '?',
+ '': '!',
+ 'appances': 'appliances',
+ 'intelgence': 'intelligence',
+ "'": "/'",
+ '<100>': '100',
+}
+
+allowed_sequences = {
+ '',
+ '',
+ '',
+ '',
+ '',
+ '',
+ '<---vp',
+ '<---',
+ '<----',
+}
+
def read_data(input_filename):
    """
    Parse the orchid xml file and return its documents.

    Prints basic statistics (document and paragraph counts) along the way.
    """
    print("Reading {}".format(input_filename))
    xml_tree = ET.parse(input_filename)
    docs = parse_xml(xml_tree)
    paragraph_count = sum(len(doc) for doc in docs)
    print("Number of documents: {}".format(len(docs)))
    print("Number of paragraphs: {}".format(paragraph_count))
    return docs
+
+def parse_xml(tree):
+ # we will put each paragraph in a separate block in the output file
+ # we won't pay any attention to the document boundaries unless we
+ # later find out it was necessary
+ # a paragraph will be a list of sentences
+ # a sentence is a list of words, where each word is a string
+ documents = []
+
+ root = tree.getroot()
+ for document in root:
+ # these should all be documents
+ if document.tag != 'document':
+ raise ValueError("Unexpected orchid xml layout: {}".format(document.tag))
+ paragraphs = []
+ for paragraph in document:
+ if paragraph.tag != 'paragraph':
+ raise ValueError("Unexpected orchid xml layout: {} under {}".format(paragraph.tag, document.tag))
+ sentences = []
+ for sentence in paragraph:
+ if sentence.tag != 'sentence':
+ raise ValueError("Unexpected orchid xml layout: {} under {}".format(sentence.tag, document.tag))
+ if sentence.attrib['line_num'] in skipped_lines:
+ continue
+ words = []
+ for word_idx, word in enumerate(sentence):
+ if word.tag != 'word':
+ raise ValueError("Unexpected orchid xml layout: {} under {}".format(word.tag, sentence.tag))
+ word = word.attrib['surface']
+ word = escape_sequences.get(word, word)
+ if word == '':
+ if word_idx == 0:
+ raise ValueError("Space character was the first token in a sentence: {}".format(sentence.attrib['line_num']))
+ else:
+ words[-1] = (words[-1][0], True)
+ continue
+ if len(word) > 1 and word[0] == '<' and word not in allowed_sequences:
+ raise ValueError("Unknown escape sequence {}".format(word))
+ words.append((word, False))
+ if len(words) == 0:
+ continue
+ words[-1] = (words[-1][0], True)
+ sentences.append(words)
+ paragraphs.append(sentences)
+ documents.append(paragraphs)
+
+ return documents
+
+
def main(*args):
    """
    Convert the orchid xml file into shuffled train/dev/test tokenizer files.

    args: [input_path, output_dir]; falls back to sys.argv when not given.
    A directory input_path is assumed to hold the standard extern_data layout.
    """
    random.seed(1007)
    if not args:
        args = sys.argv[1:]
    source = args[0]
    destination = args[1]
    if os.path.isdir(source):
        source = os.path.join(source, "thai", "orchid", "xmlchid.xml")
    write_dataset(read_data(source), destination, "orchid")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/stanza/utils/datasets/tokenization/process_thai_tokenization.py b/stanza/stanza/utils/datasets/tokenization/process_thai_tokenization.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ef0e3d5b273694cb1d99eecc9c3ac3896b9f907
--- /dev/null
+++ b/stanza/stanza/utils/datasets/tokenization/process_thai_tokenization.py
@@ -0,0 +1,187 @@
+import os
+import random
+
+try:
+ from pythainlp import sent_tokenize
+except ImportError:
+ pass
+
def write_section(output_dir, dataset_name, section, documents):
    """
    Writes a list of documents for tokenization, including a file in conll format

    The Thai datasets generally have no MWT (apparently not relevant for Thai)

    output_dir: the destination directory for the output files
    dataset_name: orchid, BEST, lst20, etc
    section: train/dev/test
    documents: a nested list of documents, paragraphs, sentences, words
      words is a list of (word, space_follows)

    Produces four files:
      th_<dataset>-ud-<section>-mwt.json   empty MWT expansion list
      th_<dataset>.<section>.txt           raw text, paragraphs separated by blank lines
      th_<dataset>-ud-<section>.toklabels  one label per character:
                                           0 = inside token, 1 = token end, 2 = sentence end
      th_<dataset>.<section>.gold.conllu   tokenization gold file with faked dependencies

    All files are written as utf-8 explicitly: the text is Thai, so relying
    on the platform default encoding would break on e.g. Windows.
    """
    with open(os.path.join(output_dir, 'th_%s-ud-%s-mwt.json' % (dataset_name, section)), 'w', encoding='utf-8') as fout:
        fout.write("[]\n")

    text_filename = os.path.join(output_dir, 'th_%s.%s.txt' % (dataset_name, section))
    label_filename = os.path.join(output_dir, 'th_%s-ud-%s.toklabels' % (dataset_name, section))
    with open(text_filename, 'w', encoding='utf-8') as text_out, \
         open(label_filename, 'w', encoding='utf-8') as label_out:
        for document in documents:
            for paragraph in document:
                for sentence_idx, sentence in enumerate(paragraph):
                    for word_idx, word in enumerate(sentence):
                        # TODO: split with newlines to make it more readable?
                        text_out.write(word[0])
                        # every character inside a token is labeled 0
                        for i in range(len(word[0]) - 1):
                            label_out.write("0")
                        # final character: 2 ends a sentence, 1 ends a token
                        if word_idx == len(sentence) - 1:
                            label_out.write("2")
                        else:
                            label_out.write("1")
                        # spaces between tokens are characters too (label 0),
                        # but the trailing space of a paragraph is dropped
                        if word[1] and (sentence_idx != len(paragraph) - 1 or word_idx != len(sentence) - 1):
                            text_out.write(' ')
                            label_out.write('0')

                text_out.write("\n\n")
                label_out.write("\n\n")

    with open(os.path.join(output_dir, 'th_%s.%s.gold.conllu' % (dataset_name, section)), 'w', encoding='utf-8') as fout:
        for document in documents:
            for paragraph in document:
                new_par = True
                for sentence in paragraph:
                    for word_idx, word in enumerate(sentence):
                        # SpaceAfter is left blank if there is space after the word
                        if word[1] and new_par:
                            space = 'NewPar=Yes'
                        elif word[1]:
                            space = '_'
                        elif new_par:
                            space = 'SpaceAfter=No|NewPar=Yes'
                        else:
                            space = 'SpaceAfter=No'
                        new_par = False

                        # Note the faked dependency structure: the conll reading code
                        # needs it even if it isn't being used in any way
                        fake_dep = 'root' if word_idx == 0 else 'dep'
                        fout.write('{}\t{}\t_\t_\t_\t_\t{}\t{}\t{}:{}\t{}\n'.format(word_idx+1, word[0], word_idx, fake_dep, word_idx, fake_dep, space))
                    # each conllu sentence must be terminated by its own blank
                    # line: word ids restart at 1 for every sentence, so writing
                    # this only once per paragraph would produce an invalid file
                    # with duplicate ids inside one sentence block
                    fout.write('\n')
+
+ with open(os.path.join(output_dir, 'th_%s.%s.gold.conllu' % (dataset_name, section)), 'w') as fout:
+ for document in documents:
+ for paragraph in document:
+ new_par = True
+ for sentence in paragraph:
+ for word_idx, word in enumerate(sentence):
+ # SpaceAfter is left blank if there is space after the word
+ if word[1] and new_par:
+ space = 'NewPar=Yes'
+ elif word[1]:
+ space = '_'
+ elif new_par:
+ space = 'SpaceAfter=No|NewPar=Yes'
+ else:
+ space = 'SpaceAfter=No'
+ new_par = False
+
+ # Note the faked dependency structure: the conll reading code
+ # needs it even if it isn't being used in any way
+ fake_dep = 'root' if word_idx == 0 else 'dep'
+ fout.write('{}\t{}\t_\t_\t_\t_\t{}\t{}\t{}:{}\t{}\n'.format(word_idx+1, word[0], word_idx, fake_dep, word_idx, fake_dep, space))
+ fout.write('\n')
+
def write_dataset(documents, output_dir, dataset_name):
    """
    Shuffle the documents in place, then write train/dev/test sections
    with an 80/10/10 split.
    """
    random.shuffle(documents)
    os.makedirs(output_dir, exist_ok=True)
    train_end = int(len(documents) * 0.8)
    dev_end = train_end + int(len(documents) * 0.1)
    sections = (('train', documents[:train_end]),
                ('dev', documents[train_end:dev_end]),
                ('test', documents[dev_end:]))
    for section, docs in sections:
        write_section(output_dir, dataset_name, section, docs)
+
def write_dataset_best(documents, test_documents, output_dir, dataset_name):
    """
    Shuffle the training documents in place and write an 85/15 train/dev
    split; the test section comes from a separate, pre-determined list.
    """
    random.shuffle(documents)
    os.makedirs(output_dir, exist_ok=True)
    train_end = int(len(documents) * 0.85)
    dev_end = train_end + int(len(documents) * 0.15)
    write_section(output_dir, dataset_name, 'train', documents[:train_end])
    write_section(output_dir, dataset_name, 'dev', documents[train_end:dev_end])
    write_section(output_dir, dataset_name, 'test', test_documents)
+
+
def reprocess_lines(processed_lines):
    """
    Reprocesses lines using pythainlp to cut up sentences into shorter sentences.

    Many of the lines in BEST seem to be multiple Thai sentences concatenated, according to native Thai speakers.

    Input: a list of lines, where each line is a list of words. Space characters can be included as words
    Output: a new list of lines, resplit using pythainlp

    Raises NameError if pythainlp was not installed (the module-level
    import is wrapped in try/except), ValueError if the tokenizer does
    not return exactly the original characters.
    """
    reprocessed_lines = []
    for line in processed_lines:
        text = "".join(line)
        try:
            chunks = sent_tokenize(text)
        except NameError as e:
            # sent_tokenize only exists if the optional pythainlp import succeeded
            raise NameError("Sentences cannot be reprocessed without first installing pythainlp") from e
        # Check that the total text back is the same as the text in
        if sum(len(x) for x in chunks) != len(text):
            raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks))

        chunk_lengths = [len(x) for x in chunks]

        # walk the original words, starting a new output line every time the
        # accumulated characters reach the next chunk boundary
        current_length = 0
        new_line = []
        for word in line:
            if len(word) + current_length < chunk_lengths[0]:
                # word fits entirely inside the current chunk
                new_line.append(word)
                current_length = current_length + len(word)
            elif len(word) + current_length == chunk_lengths[0]:
                # word exactly finishes the current chunk
                new_line.append(word)
                reprocessed_lines.append(new_line)
                new_line = []
                chunk_lengths = chunk_lengths[1:]
                current_length = 0
            else:
                # word straddles a chunk boundary: split it, possibly across
                # several chunks
                remaining_len = chunk_lengths[0] - current_length
                new_line.append(word[:remaining_len])
                reprocessed_lines.append(new_line)
                word = word[remaining_len:]
                chunk_lengths = chunk_lengths[1:]
                while len(word) > chunk_lengths[0]:
                    new_line = [word[:chunk_lengths[0]]]
                    reprocessed_lines.append(new_line)
                    word = word[chunk_lengths[0]:]
                    chunk_lengths = chunk_lengths[1:]
                new_line = [word]
                current_length = len(word)
        # NOTE(review): if the final word exactly finished a chunk, new_line is
        # empty here and an empty line is appended - confirm downstream
        # consumers handle that (convert_processed_lines treats an empty line
        # as a paragraph break)
        reprocessed_lines.append(new_line)
    return reprocessed_lines
+
def convert_processed_lines(processed_lines):
    """
    Convert a list of sentences into documents suitable for the output methods in this module.

    Input: a list of lines, including space words
    Output: a list of documents, each document containing a list of sentences
      Each sentence is a list of words: (text, space_follows)
    Space words will be eliminated.

    Raises ValueError if a sentence starts with a space word.  (This
    previously raised a NameError instead: the error message referenced an
    undefined variable `filename`.)
    """
    paragraphs = []
    sentences = []
    for words in processed_lines:
        # drop a single leading space token; a line that is nothing but one
        # space becomes an empty line (a paragraph break, handled below)
        if len(words) > 1 and " " == words[0]:
            words = words[1:]
        elif len(words) == 1 and " " == words[0]:
            words = []

        sentence = []
        for word in words:
            word = word.strip()
            if not word:
                # a whitespace token marks "space follows" on the previous word
                if len(sentence) == 0:
                    raise ValueError("Unexpected space at start of sentence in line {}".format(words))
                sentence[-1] = (sentence[-1][0], True)
            else:
                sentence.append((word, False))
        # blank lines are very rare in best, but why not treat them as a paragraph break
        if len(sentence) == 0:
            paragraphs.append([sentences])
            sentences = []
            continue
        # the last word of a sentence is always treated as space-followed
        sentence[-1] = (sentence[-1][0], True)
        sentences.append(sentence)
    paragraphs.append([sentences])
    return paragraphs
+
+
+
+
+
diff --git a/stanza/stanza/utils/datasets/vietnamese/renormalize.py b/stanza/stanza/utils/datasets/vietnamese/renormalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..08fcfab2c31f1755b40815ba1649c00044c89318
--- /dev/null
+++ b/stanza/stanza/utils/datasets/vietnamese/renormalize.py
@@ -0,0 +1,141 @@
+"""
+Script to renormalize diacritics for Vietnamese text
+
+from BARTpho
+https://github.com/VinAIResearch/BARTpho/blob/main/VietnameseToneNormalization.md
+https://github.com/VinAIResearch/BARTpho/blob/main/LICENSE
+
+MIT License
+
+Copyright (c) 2021 VinAI Research
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+import argparse
+import os
+
+DICT_MAP = {
+ "òa": "oà",
+ "Òa": "Oà",
+ "ÒA": "OÀ",
+ "óa": "oá",
+ "Óa": "Oá",
+ "ÓA": "OÁ",
+ "ỏa": "oả",
+ "Ỏa": "Oả",
+ "ỎA": "OẢ",
+ "õa": "oã",
+ "Õa": "Oã",
+ "ÕA": "OÃ",
+ "ọa": "oạ",
+ "Ọa": "Oạ",
+ "ỌA": "OẠ",
+ "òe": "oè",
+ "Òe": "Oè",
+ "ÒE": "OÈ",
+ "óe": "oé",
+ "Óe": "Oé",
+ "ÓE": "OÉ",
+ "ỏe": "oẻ",
+ "Ỏe": "Oẻ",
+ "ỎE": "OẺ",
+ "õe": "oẽ",
+ "Õe": "Oẽ",
+ "ÕE": "OẼ",
+ "ọe": "oẹ",
+ "Ọe": "Oẹ",
+ "ỌE": "OẸ",
+ "ùy": "uỳ",
+ "Ùy": "Uỳ",
+ "ÙY": "UỲ",
+ "úy": "uý",
+ "Úy": "Uý",
+ "ÚY": "UÝ",
+ "ủy": "uỷ",
+ "Ủy": "Uỷ",
+ "ỦY": "UỶ",
+ "ũy": "uỹ",
+ "Ũy": "Uỹ",
+ "ŨY": "UỸ",
+ "ụy": "uỵ",
+ "Ụy": "Uỵ",
+ "ỤY": "UỴ",
+}
+
+
def replace_all(text):
    """Apply every diacritic renormalization in DICT_MAP to text and return the result."""
    result = text
    for old_form, new_form in DICT_MAP.items():
        result = result.replace(old_form, new_form)
    return result
+
def convert_file(org_file, new_file):
    """
    Renormalize the diacritics of one file.

    org_file: path of the original file
    new_file: path the converted copy is written to

    Streams line by line rather than calling readlines(), so arbitrarily
    large files are converted without holding the whole text in memory.
    """
    with open(org_file, 'r', encoding='utf-8') as reader, open(new_file, 'w', encoding='utf-8') as writer:
        for line in reader:
            writer.write(replace_all(line))
+
def convert_files(file_list, new_dir):
    """Convert each file in file_list, writing a same-named copy under new_dir."""
    for file_name in file_list:
        base_name = os.path.basename(file_name)
        convert_file(file_name, os.path.join(new_dir, base_name))
+
+
def convert_dir(org_dir, new_dir, suffix):
    """Convert every file in org_dir whose extension equals suffix, writing into new_dir."""
    os.makedirs(new_dir, exist_ok=True)
    matching = [os.path.join(org_dir, name)
                for name in os.listdir(org_dir)
                if os.path.splitext(name)[1] == suffix]
    convert_files(matching, new_dir)
+
+
def main():
    """
    Command line entry point: renormalize a single file, or every file
    with the given suffix in a directory.
    """
    arg_parser = argparse.ArgumentParser(
        description='Script that renormalizes diacritics'
    )
    arg_parser.add_argument(
        'orig',
        help='Location of the original directory'
    )
    arg_parser.add_argument(
        'converted',
        help='The location of new directory'
    )
    arg_parser.add_argument(
        '--suffix',
        type=str,
        default='.txt',
        help='Which suffix to look for when renormalizing a directory'
    )
    args = arg_parser.parse_args()

    # a plain file is converted directly; anything else is treated as a directory
    if os.path.isfile(args.orig):
        convert_file(args.orig, args.converted)
    else:
        convert_dir(args.orig, args.converted, args.suffix)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/stanza/utils/training/run_ete.py b/stanza/stanza/utils/training/run_ete.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5d0e26b91ed7beb32906f8230752bddc9d896eb
--- /dev/null
+++ b/stanza/stanza/utils/training/run_ete.py
@@ -0,0 +1,194 @@
+"""
+Runs a pipeline end-to-end, reports conll scores.
+
+For example, you can do
+ python3 stanza/utils/training/run_ete.py it_isdt --score_test
+You can run on all models at once:
+ python3 stanza/utils/training/run_ete.py ud_all --score_test
+
+You can also run one model on a different model's data:
+ python3 stanza/utils/training/run_ete.py it_isdt --score_dev --test_data it_vit
+ python3 stanza/utils/training/run_ete.py it_isdt --score_test --test_data it_vit
+
+Running multiple models with a --test_data flag will run them all on the same data:
+ python3 stanza/utils/training/run_ete.py it_combined it_isdt it_vit --score_test --test_data it_vit
+
+If run with no dataset arguments, then the dataset used is the train
+data, which may or may not be useful.
+"""
+
+import logging
+import os
+import tempfile
+
+from stanza.models import identity_lemmatizer
+from stanza.models import lemmatizer
+from stanza.models import mwt_expander
+from stanza.models import parser
+from stanza.models import tagger
+from stanza.models import tokenizer
+
+from stanza.models.common.constant import treebank_to_short_name
+
+from stanza.utils.training import common
+from stanza.utils.training.common import Mode, build_pos_charlm_args, build_lemma_charlm_args, build_depparse_charlm_args
+from stanza.utils.training.run_lemma import check_lemmas
+from stanza.utils.training.run_mwt import check_mwt
+from stanza.utils.training.run_pos import wordvec_args
+
+logger = logging.getLogger('stanza')
+
+# a constant so that the script which looks for these results knows what to look for
+RESULTS_STRING = "End to end results for"
+
def add_args(parser):
    """Add the ete-specific arguments (test data override) plus the shared charlm flags."""
    parser.add_argument('--test_data', default=None, type=str, help='Which data to test on, if not using the default data for this model')
    common.add_charlm_args(parser)
+
def run_ete(paths, dataset, short_name, command_args, extra_args):
    """
    Run tokenize, mwt, pos, lemma, depparse in sequence, then score the result.

    paths: the data directory map (TOKENIZE_DATA_DIR, ETE_DATA_DIR, etc)
    dataset: which split's text to process - train/dev/test
    short_name: language_package shorthand for the models, eg it_isdt
    command_args: this script's parsed args; test_data optionally redirects
      which treebank's text is processed and scored against
    extra_args: passed through verbatim to each model's main()

    Each step writes its conllu output under ETE_DATA_DIR and feeds it to
    the next step; the final depparse output is scored against the gold file.
    """
    short_language, package = short_name.split("_", 1)

    tokenize_dir = paths["TOKENIZE_DATA_DIR"]
    mwt_dir = paths["MWT_DATA_DIR"]
    lemma_dir = paths["LEMMA_DATA_DIR"]
    ete_dir = paths["ETE_DATA_DIR"]
    wordvec_dir = paths["WORDVEC_DIR"]

    # run models in the following order:
    # tokenize
    # mwt, if exists
    # pos
    # lemma, if exists
    # depparse
    # the output of each step is either kept or discarded based on the
    # value of command_args.save_output

    if command_args and command_args.test_data:
        test_short_name = treebank_to_short_name(command_args.test_data)
    else:
        test_short_name = short_name

    # TOKENIZE step
    # the raw data to process starts in tokenize_dir
    # retokenize it using the saved model
    tokenizer_type = "--txt_file"
    tokenizer_file = f"{tokenize_dir}/{test_short_name}.{dataset}.txt"

    tokenizer_output = f"{ete_dir}/{short_name}.{dataset}.tokenizer.conllu"

    tokenizer_args = ["--mode", "predict", tokenizer_type, tokenizer_file, "--lang", short_language,
                      "--conll_file", tokenizer_output, "--shorthand", short_name]
    tokenizer_args = tokenizer_args + extra_args
    logger.info("----- TOKENIZER ----------")
    logger.info("Running tokenizer step with args: {}".format(tokenizer_args))
    tokenizer.main(tokenizer_args)

    # If the data has any MWT in it, there should be an MWT model
    # trained, so run that. Otherwise, we skip MWT
    mwt_train_file = f"{mwt_dir}/{short_name}.train.in.conllu"
    logger.info("----- MWT ----------")
    if check_mwt(mwt_train_file):
        mwt_output = f"{ete_dir}/{short_name}.{dataset}.mwt.conllu"
        mwt_args = ['--eval_file', tokenizer_output,
                    '--output_file', mwt_output,
                    '--lang', short_language,
                    '--shorthand', short_name,
                    '--mode', 'predict']
        mwt_args = mwt_args + extra_args
        logger.info("Running mwt step with args: {}".format(mwt_args))
        mwt_expander.main(mwt_args)
    else:
        logger.info("No MWT in training data. Skipping")
        mwt_output = tokenizer_output

    # Run the POS step
    # TODO: add batch args
    # TODO: add transformer args
    logger.info("----- POS ----------")
    pos_output = f"{ete_dir}/{short_name}.{dataset}.pos.conllu"
    pos_args = ['--wordvec_dir', wordvec_dir,
                '--eval_file', mwt_output,
                '--output_file', pos_output,
                '--lang', short_language,
                '--shorthand', short_name,
                '--mode', 'predict',
                # the MWT is not preserving the tags,
                # so we don't ask the tagger to report a score
                # the ETE will score the whole thing at the end
                '--no_gold_labels']

    # NOTE(review): command_args.charlm is read unconditionally here, although
    # the test_data check above guards against command_args being None -
    # confirm command_args is always provided by the callers
    pos_charlm_args = build_pos_charlm_args(short_language, package, command_args.charlm)

    pos_args = pos_args + wordvec_args(short_language, package, extra_args) + pos_charlm_args + extra_args
    logger.info("Running pos step with args: {}".format(pos_args))
    tagger.main(pos_args)

    # Run the LEMMA step. If there are no lemmas in the training
    # data, use the identity lemmatizer.
    logger.info("----- LEMMA ----------")
    lemma_train_file = f"{lemma_dir}/{short_name}.train.in.conllu"
    lemma_output = f"{ete_dir}/{short_name}.{dataset}.lemma.conllu"
    lemma_args = ['--eval_file', pos_output,
                  '--output_file', lemma_output,
                  '--shorthand', short_name,
                  '--mode', 'predict']
    if check_lemmas(lemma_train_file):
        lemma_charlm_args = build_lemma_charlm_args(short_language, package, command_args.charlm)
        lemma_args = lemma_args + lemma_charlm_args + extra_args
        logger.info("Running lemmatizer step with args: {}".format(lemma_args))
        lemmatizer.main(lemma_args)
    else:
        lemma_args = lemma_args + extra_args
        logger.info("No lemmas in training data")
        logger.info("Running identity lemmatizer step with args: {}".format(lemma_args))
        identity_lemmatizer.main(lemma_args)

    # Run the DEPPARSE step. This is the last step
    # Note that we do NOT use the depparse directory's data. That is
    # because it has either gold tags, or predicted tags based on
    # retagging using gold tokenization, and we aren't sure which at
    # this point in the process.
    # TODO: add batch args
    logger.info("----- DEPPARSE ----------")
    depparse_output = f"{ete_dir}/{short_name}.{dataset}.depparse.conllu"
    depparse_args = ['--wordvec_dir', wordvec_dir,
                     '--eval_file', lemma_output,
                     '--output_file', depparse_output,
                     # NOTE(review): the other steps pass short_language to
                     # --lang, but this passes short_name - confirm intentional
                     '--lang', short_name,
                     '--shorthand', short_name,
                     '--mode', 'predict']
    depparse_charlm_args = build_depparse_charlm_args(short_language, package, command_args.charlm)
    depparse_args = depparse_args + wordvec_args(short_language, package, extra_args) + depparse_charlm_args + extra_args
    logger.info("Running depparse step with args: {}".format(depparse_args))
    parser.main(depparse_args)

    # score the final depparse output against the gold annotations
    logger.info("----- EVALUATION ----------")
    gold_file = f"{tokenize_dir}/{test_short_name}.{dataset}.gold.conllu"
    ete_file = depparse_output
    results = common.run_eval_script(gold_file, ete_file)
    logger.info("{} {} models on {} {} data:\n{}".format(RESULTS_STRING, short_name, test_short_name, dataset, results))
+
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """
    Run the end-to-end pipeline for one treebank.

    mode: Mode.TRAIN / SCORE_DEV / SCORE_TEST selects which split to process
    paths: the data directory map; copied (not mutated) when redirected
    treebank, temp_output_file: part of the common.main callback signature;
      not used directly here
    command_args, extra_args: forwarded to run_ete

    With --temp_output the intermediate conllu files go to a temporary
    directory which is removed afterwards; otherwise they are kept in
    ETE_DATA_DIR.
    """
    if mode == Mode.TRAIN:
        dataset = 'train'
    elif mode == Mode.SCORE_DEV:
        dataset = 'dev'
    elif mode == Mode.SCORE_TEST:
        dataset = 'test'
    else:
        # previously an unknown mode fell through and crashed later with an
        # unrelated NameError on 'dataset'; fail fast with a clear message
        raise ValueError("Unhandled mode: {}".format(mode))

    if command_args.temp_output:
        with tempfile.TemporaryDirectory() as ete_dir:
            # copy the dict so the caller's paths are not clobbered
            paths = dict(paths)
            paths["ETE_DATA_DIR"] = ete_dir
            run_ete(paths, dataset, short_name, command_args, extra_args)
    else:
        os.makedirs(paths["ETE_DATA_DIR"], exist_ok=True)
        run_ete(paths, dataset, short_name, command_args, extra_args)
+
def main():
    # "ete" is used both as the save name and the dataset dir key prefix
    common.main(run_treebank, "ete", "ete", add_args)

if __name__ == "__main__":
    main()
+
diff --git a/stanza/stanza/utils/training/run_pos.py b/stanza/stanza/utils/training/run_pos.py
new file mode 100644
index 0000000000000000000000000000000000000000..54dc2eee0cf86ea2aed35ccc0166d47b9f519ca8
--- /dev/null
+++ b/stanza/stanza/utils/training/run_pos.py
@@ -0,0 +1,147 @@
+
+
+import logging
+import os
+
+from stanza.models import tagger
+
+from stanza.resources.default_packages import no_pretrain_languages, pos_pretrains, default_pretrains
+from stanza.utils.training import common
+from stanza.utils.training.common import Mode, add_charlm_args, build_pos_charlm_args, choose_pos_charlm, find_wordvec_pretrain
+
+logger = logging.getLogger('stanza')
+
def add_pos_args(parser):
    """Register POS-specific command line options on the given parser.

    Installs the shared charlm options, plus a flag for enabling the
    language's default transformer.
    """
    add_charlm_args(parser)
    parser.add_argument('--use_bert',
                        default=False,
                        action="store_true",
                        help='Use the default transformer for this language')
+
+# TODO: move this somewhere common
def wordvec_args(short_language, dataset, extra_args):
    """Pick the pretrained word vector arguments for a tagger invocation.

    Returns [] when the caller already chose a pretrain (or opted out),
    --no_pretrain for languages with no known word vectors, and otherwise
    a --wordvec_pretrain_file argument naming the pretrain selected for
    this language / dataset combination.
    """
    # the caller already decided; nothing for us to add
    if '--wordvec_pretrain_file' in extra_args or '--no_pretrain' in extra_args:
        return []

    if short_language in no_pretrain_languages:
        # we couldn't find word vectors for a few languages...:
        # coptic, naija, old russian, turkish german, swedish sign language
        logger.warning("No known word vectors for language {} If those vectors can be found, please update the training scripts.".format(short_language))
        return ["--no_pretrain"]

    # only consult the POS-specific pretrain table when it actually has an
    # entry for this language & dataset; otherwise fall back to defaults
    dataset_pretrains = {}
    if short_language in pos_pretrains and dataset in pos_pretrains[short_language]:
        dataset_pretrains = pos_pretrains
    wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, dataset_pretrains, dataset)
    return ["--wordvec_pretrain_file", wordvec_pretrain]
+
def build_model_filename(paths, short_name, command_args, extra_args):
    """Predict the save path the tagger would use for this treebank.

    Reconstructs the argument list a training run would be launched with,
    then asks the tagger itself for the resulting model file name,
    without actually training anything.
    """
    language, dataset = short_name.split("_", 1)

    # TODO: can avoid downloading the charlm at this point, since we
    # might not even be training
    charlm = build_pos_charlm_args(language, dataset, command_args.charlm)
    bert = common.choose_transformer(language, command_args, extra_args, warn=False)

    full_args = ["--shorthand", short_name, "--mode", "train"]
    # TODO: also, this downloads the wordvec, which we might not want to do yet
    full_args.extend(wordvec_args(language, dataset, extra_args))
    full_args.extend(charlm)
    full_args.extend(bert)
    full_args.extend(extra_args)
    if command_args.save_name is not None:
        full_args.extend(["--save_name", command_args.save_name])
    if command_args.save_dir is not None:
        full_args.extend(["--save_dir", command_args.save_dir])

    return tagger.model_file_name(tagger.parse_args(full_args))
+
+
+
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or score a POS tagger for one treebank.

    mode selects the phases: TRAIN trains and then scores the dev set,
    SCORE_DEV scores dev only, SCORE_TEST scores test only.  Predictions
    go to temp_output_file when given, otherwise to files next to the
    prepared data in POS_DATA_DIR.
    """
    short_language, dataset = short_name.split("_", 1)

    pos_dir = paths["POS_DATA_DIR"]
    train_file = f"{pos_dir}/{short_name}.train.in.conllu"
    # vi_vlsp22 trains on its own data plus vi_vtb; multiple training
    # files are joined with ";" and split again below
    if short_name == 'vi_vlsp22':
        train_file += f";{pos_dir}/vi_vtb.train.in.conllu"
    dev_in_file = f"{pos_dir}/{short_name}.dev.in.conllu"
    dev_pred_file = temp_output_file if temp_output_file else f"{pos_dir}/{short_name}.dev.pred.conllu"
    test_in_file = f"{pos_dir}/{short_name}.test.in.conllu"
    test_pred_file = temp_output_file if temp_output_file else f"{pos_dir}/{short_name}.test.pred.conllu"

    charlm_args = build_pos_charlm_args(short_language, dataset, command_args.charlm)
    bert_args = common.choose_transformer(short_language, command_args, extra_args)

    # an explicit --eval_file in extra_args overrides the default dev/test
    # input file for both prediction and scoring
    eval_file = None
    if '--eval_file' in extra_args:
        eval_file = extra_args[extra_args.index('--eval_file') + 1]

    if mode == Mode.TRAIN:
        # resolve each training piece to either the .conllu file or a .zip
        # sibling; having both, or neither, aborts the whole treebank
        train_pieces = []
        for train_piece in train_file.split(";"):
            zip_piece = os.path.splitext(train_piece)[0] + ".zip"
            if os.path.exists(train_piece) and os.path.exists(zip_piece):
                logger.error("POS TRAIN FILE %s and %s both exist... this is very confusing, skipping %s" % (train_piece, zip_piece, short_name))
                return
            if os.path.exists(train_piece):
                train_pieces.append(train_piece)
            else: # not os.path.exists(train_piece):
                if os.path.exists(zip_piece):
                    train_pieces.append(zip_piece)
                    continue
                logger.error("TRAIN FILE NOT FOUND: %s ... skipping" % train_piece)
                return
        train_file = ";".join(train_pieces)

        train_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                      "--train_file", train_file,
                      "--output_file", dev_pred_file,
                      "--lang", short_language,
                      "--shorthand", short_name,
                      "--mode", "train"]
        if eval_file is None:
            train_args += ['--eval_file', dev_in_file]
        # extra_args last so user-supplied flags win over computed ones
        train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
        train_args = train_args + extra_args
        logger.info("Running train POS for {} with args {}".format(treebank, train_args))
        tagger.main(train_args)

    # after training we also score the dev set
    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                    "--output_file", dev_pred_file,
                    "--lang", short_language,
                    "--shorthand", short_name,
                    "--mode", "predict"]
        if eval_file is None:
            dev_args += ['--eval_file', dev_in_file]
        dev_args = dev_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
        dev_args = dev_args + extra_args
        logger.info("Running dev POS for {} with args {}".format(treebank, dev_args))
        tagger.main(dev_args)

        # compare predictions against the gold (or overridden) dev file
        results = common.run_eval_script_pos(eval_file if eval_file else dev_in_file, dev_pred_file)
        logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))

    if mode == Mode.SCORE_TEST:
        test_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                     "--output_file", test_pred_file,
                     "--lang", short_language,
                     "--shorthand", short_name,
                     "--mode", "predict"]
        if eval_file is None:
            test_args += ['--eval_file', test_in_file]
        test_args = test_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
        test_args = test_args + extra_args
        logger.info("Running test POS for {} with args {}".format(treebank, test_args))
        tagger.main(test_args)

        results = common.run_eval_script_pos(eval_file if eval_file else test_in_file, test_pred_file)
        logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
+
+
def main():
    # Dispatch through the shared training harness, handing it the tagger's
    # own argparse so tagger-specific flags are recognized, plus hooks for
    # computing the model filename and choosing a charlm.
    common.main(run_treebank, "pos", "tagger", add_pos_args, tagger.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_pos_charlm)
+
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
+