Add files using upload-large-folder tool
Browse files
- stanza/stanza/models/classifiers/constituency_classifier.py +96 -0
- stanza/stanza/models/classifiers/data.py +169 -0
- stanza/stanza/models/coref/bert.py +69 -0
- stanza/stanza/models/langid/__init__.py +0 -0
- stanza/stanza/models/langid/data.py +134 -0
- stanza/stanza/models/langid/model.py +126 -0
- stanza/stanza/models/lemma_classifier/__init__.py +0 -0
- stanza/stanza/models/ner/model.py +278 -0
- stanza/stanza/models/pos/scorer.py +22 -0
- stanza/stanza/models/pos/vocab.py +71 -0
- stanza/stanza/pipeline/demo/stanza-brat.js +1316 -0
- stanza/stanza/pipeline/external/corenlp_converter_depparse.py +29 -0
- stanza/stanza/pipeline/external/jieba.py +71 -0
- stanza/stanza/pipeline/external/sudachipy.py +84 -0
- stanza/stanza/utils/charlm/oscar_to_text.py +78 -0
- stanza/stanza/utils/constituency/__init__.py +0 -0
- stanza/stanza/utils/constituency/grep_test_logs.py +24 -0
- stanza/stanza/utils/datasets/constituency/build_silver_dataset.py +117 -0
- stanza/stanza/utils/datasets/constituency/convert_cintil.py +80 -0
- stanza/stanza/utils/datasets/constituency/count_common_words.py +12 -0
- stanza/stanza/utils/datasets/constituency/prepare_con_dataset.py +594 -0
- stanza/stanza/utils/datasets/constituency/silver_variance.py +108 -0
- stanza/stanza/utils/datasets/coref/convert_hindi.py +170 -0
- stanza/stanza/utils/datasets/ner/compare_entities.py +38 -0
- stanza/stanza/utils/datasets/ner/conll_to_iob.py +59 -0
- stanza/stanza/utils/datasets/ner/convert_bn_daffodil.py +123 -0
- stanza/stanza/utils/datasets/ner/convert_en_conll03.py +42 -0
- stanza/stanza/utils/datasets/ner/convert_he_iahlt.py +108 -0
- stanza/stanza/utils/datasets/ner/convert_lst20.py +74 -0
- stanza/stanza/utils/datasets/ner/convert_mr_l3cube.py +54 -0
- stanza/stanza/utils/datasets/ner/convert_nner22.py +70 -0
- stanza/stanza/utils/datasets/ner/convert_ontonotes.py +58 -0
- stanza/stanza/utils/datasets/ner/json_to_bio.py +43 -0
- stanza/stanza/utils/datasets/ner/misc_to_date.py +77 -0
- stanza/stanza/utils/datasets/ner/preprocess_wikiner.py +37 -0
- stanza/stanza/utils/datasets/ner/simplify_en_worldwide.py +152 -0
- stanza/stanza/utils/datasets/ner/simplify_ontonotes_to_worldwide.py +118 -0
- stanza/stanza/utils/datasets/ner/split_wikiner.py +104 -0
- stanza/stanza/utils/datasets/ner/suc_conll_to_iob.py +72 -0
- stanza/stanza/utils/datasets/pos/__init__.py +0 -0
- stanza/stanza/utils/datasets/pos/convert_trees_to_pos.py +94 -0
- stanza/stanza/utils/datasets/prepare_tokenizer_data.py +151 -0
- stanza/stanza/utils/datasets/prepare_tokenizer_treebank.py +1396 -0
- stanza/stanza/utils/datasets/pretrain/__init__.py +0 -0
- stanza/stanza/utils/datasets/tokenization/__init__.py +0 -0
- stanza/stanza/utils/datasets/tokenization/convert_vi_vlsp.py +155 -0
- stanza/stanza/utils/ner/spacy_ner_tag_dataset.py +138 -0
- stanza/stanza/utils/training/__init__.py +0 -0
- stanza/stanza/utils/training/remove_constituency_optimizer.py +77 -0
- stanza/stanza/utils/visualization/dependency_visualization.py +108 -0
stanza/stanza/models/classifiers/constituency_classifier.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A classifier that uses a constituency parser for the base embeddings
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import dataclasses
|
| 6 |
+
import logging
|
| 7 |
+
from types import SimpleNamespace
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from stanza.models.classifiers.base_classifier import BaseClassifier
|
| 14 |
+
from stanza.models.classifiers.config import ConstituencyConfig
|
| 15 |
+
from stanza.models.classifiers.data import SentimentDatum
|
| 16 |
+
from stanza.models.classifiers.utils import ModelType, build_output_layers
|
| 17 |
+
|
| 18 |
+
from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger('stanza')
|
| 21 |
+
tlogger = logging.getLogger('stanza.classifiers.trainer')
|
| 22 |
+
|
| 23 |
+
class ConstituencyClassifier(BaseClassifier):
    """
    A classifier which uses a constituency parser for the base embeddings.

    Each input tree is embedded by the tree_embedding module, max-pooled
    over its per-tree vectors, and then passed through fully connected
    output layers to produce class scores.
    """
    def __init__(self, tree_embedding, labels, args):
        """
        tree_embedding: module mapping parse trees to embeddings
        labels: list of class label names
        args: namespace of hyperparameters (fc_shapes, dropout, constituency_* flags)
        """
        super(ConstituencyClassifier, self).__init__()
        self.labels = labels
        # we build a separate config out of the args so that we can easily save it in torch
        self.config = ConstituencyConfig(fc_shapes = args.fc_shapes,
                                         dropout = args.dropout,
                                         num_classes = len(labels),
                                         constituency_backprop = args.constituency_backprop,
                                         constituency_batch_norm = args.constituency_batch_norm,
                                         constituency_node_attn = args.constituency_node_attn,
                                         constituency_top_layer = args.constituency_top_layer,
                                         constituency_all_words = args.constituency_all_words,
                                         model_type = ModelType.CONSTITUENCY)

        self.tree_embedding = tree_embedding

        self.fc_layers = build_output_layers(self.tree_embedding.output_size, self.config.fc_shapes, self.config.num_classes)
        self.dropout = nn.Dropout(self.config.dropout)

    def is_unsaved_module(self, name):
        """No submodules of this model are skipped when saving."""
        return False

    def log_configuration(self):
        """Log the configuration flags relevant to the constituency classifier."""
        tlogger.info("Backprop into parser: %s", self.config.constituency_backprop)
        tlogger.info("Batch norm: %s", self.config.constituency_batch_norm)
        tlogger.info("Word positions used: %s", "all words" if self.config.constituency_all_words else "start and end words")
        tlogger.info("Attention over nodes: %s", self.config.constituency_node_attn)
        tlogger.info("Intermediate layers: %s", self.config.fc_shapes)

    def log_norms(self):
        """Log the norm of each trainable parameter, including the tree embedding's."""
        # fixed typo in the log header ("PARAMTERS")
        lines = ["NORMS FOR MODEL PARAMETERS"]
        lines.extend(["tree_embedding." + x for x in self.tree_embedding.get_norms()])
        for name, param in self.named_parameters():
            if param.requires_grad and not name.startswith('tree_embedding.'):
                lines.append("%s %.6g" % (name, torch.norm(param).item()))
        logger.info("\n".join(lines))

    def forward(self, inputs):
        """
        Compute class scores for a batch of inputs.

        inputs: a list of parse trees, or SentimentDatum objects
        carrying a .constituency tree
        """
        inputs = [x.constituency if isinstance(x, SentimentDatum) else x for x in inputs]

        embedding = self.tree_embedding.embed_trees(inputs)
        # max-pool each tree's embedding vectors into a single vector
        previous_layer = torch.stack([torch.max(x, dim=0)[0] for x in embedding], dim=0)
        previous_layer = self.dropout(previous_layer)
        for fc in self.fc_layers[:-1]:
            # gelu rather than relu: relu caused many neurons to die
            previous_layer = self.dropout(F.gelu(fc(previous_layer)))
        out = self.fc_layers[-1](previous_layer)
        return out

    def get_params(self, skip_modules=True):
        """
        Return a dict of parameters, config, and labels suitable for torch.save.
        """
        model_state = self.state_dict()
        # skip all of the constituency parameters here -
        # we will add them by calling the model's get_params()
        skipped = [k for k in model_state.keys() if k.startswith("tree_embedding.")]
        for k in skipped:
            del model_state[k]

        tree_embedding = self.tree_embedding.get_params(skip_modules)

        config = dataclasses.asdict(self.config)
        # store the ModelType enum by name (a plain string) for serialization
        config['model_type'] = config['model_type'].name

        params = {
            'model': model_state,
            'tree_embedding': tree_embedding,
            'config': config,
            'labels': self.labels,
        }
        return params

    def extract_sentences(self, doc):
        """Extract the constituency tree of each sentence of a stanza Document."""
        return [sentence.constituency for sentence in doc.sentences]
|
stanza/stanza/models/classifiers/data.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stanza models classifier data functions."""
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
from collections import namedtuple
|
| 5 |
+
import logging
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
from stanza.models.classifiers.utils import WVType
|
| 12 |
+
from stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID
|
| 13 |
+
import stanza.models.constituency.tree_reader as tree_reader
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger('stanza')
|
| 16 |
+
|
| 17 |
+
class SentimentDatum:
    """A single sentiment example: a label, tokenized text, and an optional parse tree."""
    def __init__(self, sentiment, text, constituency=None):
        self.sentiment = sentiment        # class label (stored as given, often a string)
        self.text = text                  # tokenized text
        self.constituency = constituency  # optional constituency parse of the text

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, SentimentDatum):
            return False
        return self.sentiment == other.sentiment and self.text == other.text and self.constituency == other.constituency

    def __repr__(self):
        # added for easier debugging of datasets in the REPL / logs
        return "%s(%r)" % (type(self).__name__, self._asdict())

    def __str__(self):
        return str(self._asdict())

    def _asdict(self):
        """Return a dict view of this datum (namedtuple-style _asdict)."""
        if self.constituency is None:
            return {'sentiment': self.sentiment, 'text': self.text}
        else:
            return {'sentiment': self.sentiment, 'text': self.text, 'constituency': str(self.constituency)}
|
| 38 |
+
|
| 39 |
+
def update_text(sentence: List[str], wordvec_type: WVType) -> List[str]:
    """
    Process a line of text (with tokenization provided as whitespace)
    into a list of strings.
    """
    # the stanford sentiment dataset has a lot of random - and /
    # split on them and flatten, dropping any empty pieces
    pieces = []
    for token in sentence:
        pieces.extend(part for part in token.split("-") if part)
    tokens = []
    for token in pieces:
        tokens.extend(part for part in token.split("/") if part)
    tokens = [tok.strip() for tok in tokens]
    tokens = [tok for tok in tokens if tok]
    if not tokens:
        # removed too much; keep a placeholder token
        tokens = ["-"]
    # our current word vectors are all entirely lowercased
    tokens = [tok.lower() for tok in tokens]
    if wordvec_type == WVType.WORD2VEC:
        return tokens
    if wordvec_type == WVType.GOOGLE:
        # replace digits with '#', except for the bare tokens 0 and 1
        converted = []
        for tok in tokens:
            if tok not in ('0', '1'):
                tok = re.sub('[0-9]', '#', tok)
            converted.append(tok)
        return converted
    if wordvec_type == WVType.FASTTEXT:
        return tokens
    if wordvec_type == WVType.OTHER:
        return tokens
    raise ValueError("Unknown wordvec_type {}".format(wordvec_type))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def read_dataset(dataset, wordvec_type: WVType, min_len: int) -> List[SentimentDatum]:
    """
    Read one or more comma-separated json files into a list of SentimentDatum.

    Each file holds a list of {'sentiment': ..., 'text': ...} entries,
    optionally including a 'constituency' parse.
    """
    raw = []
    for filename in str(dataset).split(","):
        with open(filename, encoding="utf-8") as fin:
            for entry in json.load(fin):
                raw.append((str(entry['sentiment']), entry['text'], entry.get('constituency', None)))
    # TODO: maybe do this processing later, once the model is built.
    # then move the processing into the model so we can use
    # overloading to potentially make future model types
    data = [SentimentDatum(sentiment,
                           update_text(text, wordvec_type),
                           tree_reader.read_trees(tree)[0] if tree else None)
            for sentiment, text, tree in raw]
    if min_len:
        data = [datum for datum in data if len(datum.text) >= min_len]
    return data
|
| 90 |
+
|
| 91 |
+
def dataset_labels(dataset):
    """
    Return a sorted list of the label names occurring in the dataset.

    If every label is an integer string, sort numerically instead of
    lexicographically - nicer than having 10 appear before 2.
    """
    labels = {datum.sentiment for datum in dataset}
    if all(re.match("^[0-9]+$", label) for label in labels):
        return [str(value) for value in sorted(int(label) for label in labels)]
    return sorted(labels)
|
| 104 |
+
|
| 105 |
+
def dataset_vocab(dataset):
    """
    Build a vocab list containing every token in the dataset,
    with PAD and UNK placed at their expected indices.
    """
    words = set()
    for datum in dataset:
        words.update(datum.text)
    vocab = [PAD, UNK] + list(words)
    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:
        raise ValueError("Unexpected values for PAD and UNK!")
    return vocab
|
| 114 |
+
|
| 115 |
+
def sort_dataset_by_len(dataset, keep_index=False):
    """
    Return a dict mapping length -> list of items of that length.

    An OrderedDict is used so that the mapping is sorted from smallest
    to largest length.  When keep_index is set, each entry is an
    (item, original_index) pair instead of the bare item.
    """
    by_length = collections.OrderedDict()
    for length in sorted({len(item.text) for item in dataset}):
        by_length[length] = []
    for idx, item in enumerate(dataset):
        entry = (item, idx) if keep_index else item
        by_length[len(item.text)].append(entry)
    return by_length
|
| 131 |
+
|
| 132 |
+
def shuffle_dataset(sorted_dataset, batch_size, batch_single_item):
    """
    Given a dataset sorted by len, shuffle within each length bucket and
    group the items into batches; the batches themselves are then shuffled.

    Items whose text length is at least batch_single_item (when that
    threshold is positive) each get a batch of their own; everything else
    is grouped into batches of up to batch_size items.
    """
    flattened = []
    for bucket in sorted_dataset.values():
        shuffled = list(bucket)
        random.shuffle(shuffled)
        flattened.extend(shuffled)

    batches = []
    current = []
    for item in flattened:
        if batch_single_item > 0 and len(item.text) >= batch_single_item:
            batches.append([item])
            continue
        current.append(item)
        if len(current) >= batch_size:
            batches.append(current)
            current = []
    if current:
        batches.append(current)
    random.shuffle(batches)
    return batches
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def check_labels(labels, dataset):
    """
    Check that all of the labels in the dataset are in the known labels.

    Unknown labels could arguably be tolerated by treating the model as
    always wrong on them, but failing fast is a useful sanity check that
    the datasets match.
    """
    unknown = [label for label in dataset_labels(dataset) if label not in labels]
    if unknown:
        raise RuntimeError('Dataset contains labels which the model does not know about:' + str(unknown))
|
| 169 |
+
|
stanza/stanza/models/coref/bert.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Functions related to BERT or similar models"""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import List, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np # type: ignore
|
| 7 |
+
from transformers import AutoModel, AutoTokenizer # type: ignore
|
| 8 |
+
|
| 9 |
+
from stanza.models.coref.config import Config
|
| 10 |
+
from stanza.models.coref.const import Doc
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger('stanza')
|
| 14 |
+
|
| 15 |
+
def get_subwords_batches(doc: Doc,
                         config: Config,
                         tok: AutoTokenizer
                         ) -> np.ndarray:
    """
    Turns a list of subwords to a list of lists of subword indices
    of max length == batch_size (or shorter, as batch boundaries
    should match sentence boundaries). Each batch is enclosed in cls and sep
    special tokens.

    Returns:
        batches of bert tokens [n_batches, batch_size]
    """
    batch_size = config.bert_window_size - 2  # to save space for CLS and SEP

    subwords: List[str] = doc["subwords"]
    subwords_batches = []
    start, end = 0, 0

    while end < len(subwords):
        # to prevent the case where a batch_size step forward
        # doesn't capture more than 1 sentence, we will just cut
        # that sequence
        prev_end = end
        end = min(end + batch_size, len(subwords))

        # Move back till we hit a sentence end
        if end < len(subwords):
            sent_id = doc["sent_id"][doc["word_id"][end]]
            while end and doc["sent_id"][doc["word_id"][end - 1]] == sent_id:
                end -= 1

        # this occurs IFF there was no sentence end found throughout
        # the forward scan; this means that our sentence was waay too
        # long (i.e. longer than the max length of the transformer).
        #
        # if so, we give up and just chop the sentence off at the max length
        # that was given
        if end == prev_end:
            end = min(end + batch_size, len(subwords))

        length = end - start
        # "is None" rather than "== None" - identity check is the correct
        # idiom and avoids surprises from custom __eq__ implementations.
        # Tokenizers without CLS/SEP fall back to the EOS token.
        if tok.cls_token is None or tok.sep_token is None:
            batch = [tok.eos_token] + subwords[start:end] + [tok.eos_token]
        else:
            batch = [tok.cls_token] + subwords[start:end] + [tok.sep_token]

        # Padding to desired length
        batch += [tok.pad_token] * (batch_size - length)

        subwords_batches.append([tok.convert_tokens_to_ids(token)
                                 for token in batch])
        start += length

    return np.array(subwords_batches)
|
stanza/stanza/models/langid/__init__.py
ADDED
|
File without changes
|
stanza/stanza/models/langid/data.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DataLoader:
    """
    Class for loading language id data and providing batches

    Attempt to recreate data pre-processing from: https://github.com/AU-DIS/LSTM_langid

    Uses methods from: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py

    Data format is same as LSTM_langid
    """

    def __init__(self, device=None):
        self.batches = None        # list of batch tensor dicts, set by load_data
        self.batches_iter = None   # iterator over self.batches
        self.tag_to_idx = None     # label -> index mapping
        self.idx_to_tag = None     # index -> label list
        self.lang_weights = None   # per-language weights for cross entropy loss
        self.device = device

    def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5,20),
                  max_length=None):
        """
        Load sequence data and labels, calculate weights for weighted cross entropy loss.
        Data is stored in a file, 1 example per line
        Example: {"text": "Hello world.", "label": "en"}
        """

        # set up examples from data files
        examples = []
        for data_file in data_files:
            # use a context manager with explicit encoding: the original
            # bare open() leaked the file handle and read with the locale
            # default encoding
            with open(data_file, encoding="utf-8") as fin:
                examples += [x for x in fin.read().split("\n") if x.strip()]
        random.shuffle(examples)
        examples = [json.loads(x) for x in examples]

        # add additional labels in this data set to tag index
        tag_index = dict(tag_index)
        new_labels = set([x["label"] for x in examples]) - set(tag_index.keys())
        for new_label in new_labels:
            tag_index[new_label] = len(tag_index)
        self.tag_to_idx = tag_index
        self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]

        # set up lang counts used for weights for cross entropy loss
        lang_counts = [0 for _ in tag_index]

        # optionally limit text to max length
        if max_length is not None:
            examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples]

        # randomize data: split each example into randomly sized chunks
        if randomize:
            split_examples = []
            for example in examples:
                sequence = example["text"]
                label = example["label"]
                sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1],
                                                      lower_lim=randomize_range[0])
                split_examples += [{"text": seq, "label": label} for seq in sequences]
            examples = split_examples
            random.shuffle(examples)

        # break into equal length batches
        batch_lengths = {}
        for example in examples:
            sequence = example["text"]
            label = example["label"]
            if len(sequence) not in batch_lengths:
                batch_lengths[len(sequence)] = []
            sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in list(sequence)]
            batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label]))
            lang_counts[tag_index[label]] += 1
        for length in batch_lengths:
            random.shuffle(batch_lengths[length])

        # create final set of batches
        batches = []
        for length in batch_lengths:
            for sublist in [batch_lengths[length][i:i + batch_size] for i in
                            range(0, len(batch_lengths[length]), batch_size)]:
                batches.append(sublist)

        self.batches = [self.build_batch_tensors(batch) for batch in batches]

        # set up lang weights
        most_frequent = max(lang_counts)
        # set to 0.0 if lang_count is 0 or most_frequent/lang_count otherwise
        lang_counts = [(most_frequent * x)/(max(1, x) ** 2) for x in lang_counts]
        self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float)

        # shuffle batches to mix up lengths
        random.shuffle(self.batches)
        self.batches_iter = iter(self.batches)

    @staticmethod
    def randomize_data(sentences, upper_lim=20, lower_lim=5):
        """
        Takes the original data and creates random length examples with length between upper limit and lower limit
        From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
        """

        new_data = []
        for sentence in sentences:
            remaining = sentence
            while lower_lim < len(remaining):
                lim = random.randint(lower_lim, upper_lim)
                m = min(len(remaining), lim)
                new_sentence = remaining[:m]
                new_data.append(new_sentence)
                # skip ahead to the next word boundary before starting the next chunk
                split = remaining[m:].split(" ", 1)
                if len(split) <= 1:
                    break
                remaining = split[1]
        random.shuffle(new_data)
        return new_data

    def build_batch_tensors(self, batch):
        """
        Helper to turn batches into tensors

        batch is a list of (sequence_as_index_list, tag_index) pairs,
        all sequences of the same length.
        """

        batch_tensors = dict()
        batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long)
        batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long)

        return batch_tensors

    def next(self):
        """Return the next batch of tensors; raises StopIteration when exhausted."""
        return next(self.batches_iter)
|
| 134 |
+
|
stanza/stanza/models/langid/model.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LangIDBiLSTM(nn.Module):
    """
    Multi-layer BiLSTM model for language detecting. A recreation of "A reproduction of Apple's bi-directional LSTM models
    for language identification in short strings." (Toftrup et al 2021)

    Arxiv: https://arxiv.org/abs/2102.06282
    GitHub: https://github.com/AU-DIS/LSTM_langid

    This class is similar to https://github.com/AU-DIS/LSTM_langid/blob/main/src/LSTMLID.py
    """

    def __init__(self, char_to_idx, tag_to_idx, num_layers, embedding_dim, hidden_dim, batch_size=64, weights=None,
                 dropout=0.0, lang_subset=None):
        super(LangIDBiLSTM, self).__init__()
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.char_to_idx = char_to_idx
        self.vocab_size = len(char_to_idx)
        self.tag_to_idx = tag_to_idx
        # invert tag_to_idx: index -> tag name
        self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]
        self.lang_subset = lang_subset
        self.padding_idx = char_to_idx["<PAD>"]
        self.tagset_size = len(tag_to_idx)
        self.batch_size = batch_size
        self.loss_train = nn.CrossEntropyLoss(weight=weights)
        self.dropout_prob = dropout

        # embeddings for chars
        self.char_embeds = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.embedding_dim,
            padding_idx=self.padding_idx
        )

        # the bidirectional LSTM
        self.lstm = nn.LSTM(
            self.embedding_dim,
            self.hidden_dim,
            num_layers=self.num_layers,
            bidirectional=True,
            batch_first=True
        )

        # convert output to tag space (2 * hidden because of bidirectionality)
        self.hidden_to_tag = nn.Linear(
            self.hidden_dim * 2,
            self.tagset_size
        )

        # dropout layer
        self.dropout = nn.Dropout(p=self.dropout_prob)

    def build_lang_mask(self, device):
        """
        Build language mask if a lang subset is specified (e.g. ["en", "fr"])

        The mask will be added to the results to set the prediction scores of illegal languages to -inf

        NOTE: this must be called before prediction_scores when lang_subset
        is set; load() does so automatically.
        """
        if self.lang_subset:
            lang_mask_list = [0.0 if lang in self.lang_subset else -float('inf') for lang in self.idx_to_tag]
            self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float)
        else:
            self.lang_mask = torch.zeros(len(self.idx_to_tag), device=device, dtype=torch.float)

    def loss(self, Y_hat, Y):
        """Cross entropy loss (weighted if weights were given at construction)."""
        return self.loss_train(Y_hat, Y)

    def forward(self, x):
        """Return per-language scores of shape [batch, tagset_size] for char index input x."""
        # embed input
        x = self.char_embeds(x)

        # run through LSTM
        x, _ = self.lstm(x)

        # run through linear layer
        x = self.hidden_to_tag(x)

        # sum character outputs for each sequence
        x = torch.sum(x, dim=1)

        return x

    def prediction_scores(self, x):
        """Return the predicted language index per batch item, restricted to lang_subset if set."""
        prediction_probs = self(x)
        if self.lang_subset:
            # broadcasting adds the mask to every row; no need to stack
            # batch_size copies of it as before
            prediction_probs = prediction_probs + self.lang_mask
        return torch.argmax(prediction_probs, dim=1)

    def save(self, path):
        """ Save a model at path """
        checkpoint = {
            "char_to_idx": self.char_to_idx,
            "tag_to_idx": self.tag_to_idx,
            "num_layers": self.num_layers,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "model_state_dict": self.state_dict()
        }
        torch.save(checkpoint, path)

    @classmethod
    def load(cls, path, device=None, batch_size=64, lang_subset=None):
        """ Load a serialized model located at path """
        if path is None:
            raise FileNotFoundError("Trying to load langid model, but path not specified!  Try --load_name")
        if not os.path.exists(path):
            raise FileNotFoundError("Trying to load langid model from path which does not exist: %s" % path)
        checkpoint = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
        # the loss weights are part of the state dict, so recover them first
        weights = checkpoint["model_state_dict"]["loss_train.weight"]
        model = cls(checkpoint["char_to_idx"], checkpoint["tag_to_idx"], checkpoint["num_layers"],
                    checkpoint["embedding_dim"], checkpoint["hidden_dim"], batch_size=batch_size, weights=weights,
                    lang_subset=lang_subset)
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(device)
        model.build_lang_mask(device)
        return model
|
| 126 |
+
|
stanza/stanza/models/lemma_classifier/__init__.py
ADDED
|
File without changes
|
stanza/stanza/models/ner/model.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
|
| 9 |
+
|
| 10 |
+
from stanza.models.common.data import map_to_ids, get_long_tensor
|
| 11 |
+
from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError
|
| 12 |
+
from stanza.models.common.packed_lstm import PackedLSTM
|
| 13 |
+
from stanza.models.common.dropout import WordDropout, LockedDropout
|
| 14 |
+
from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
|
| 15 |
+
from stanza.models.common.crf import CRFLoss
|
| 16 |
+
from stanza.models.common.foundation_cache import load_bert
|
| 17 |
+
from stanza.models.common.utils import attach_bert_model
|
| 18 |
+
from stanza.models.common.vocab import PAD_ID, UNK_ID, EMPTY_ID
|
| 19 |
+
from stanza.models.common.bert_embedding import extract_bert_embeddings
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger('stanza')
|
| 22 |
+
|
| 23 |
+
# this gets created in two places in trainer
|
| 24 |
+
# in both places, pass in the bert model & tokenizer
|
| 25 |
+
class NERTagger(nn.Module):
|
| 26 |
+
def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
    """Build the NER tagger.

    Assembles the input representation (static word embeddings, optional
    delta embeddings, optional bert features, optional character model
    features), a bidirectional LSTM, and one linear+CRF output layer per
    tag column in the vocab.

    Args:
        args: dict of model hyperparameters (word_emb_dim, hidden_dim, ...)
        vocab: dict of vocabularies; keys used here: 'word', optionally
            'delta', 'tag' (and 'char' vocab inside CharacterModel)
        emb_matrix: optional pretrained embedding matrix copied into word_emb
        foundation_cache: accepted for interface compatibility — not read here
        bert_model / bert_tokenizer: transformer pieces attached via
            attach_bert_model
        force_bert_saved: forwarded to attach_bert_model
        peft_name: adapter name used when extracting bert embeddings
    """
    super().__init__()

    self.vocab = vocab
    self.args = args
    # names of sub-modules that should be excluded from checkpoints
    self.unsaved_modules = []

    # input layers
    # input_size accumulates the width of the concatenated LSTM input
    input_size = 0
    if self.args['word_emb_dim'] > 0:
        emb_finetune = self.args.get('emb_finetune', True)

        # load pretrained embeddings if specified
        word_emb = nn.Embedding(len(self.vocab['word']), self.args['word_emb_dim'], PAD_ID)
        # if a model trained with no 'delta' vocab is loaded, and
        # emb_finetune is off, any resaving of the model will need
        # the updated vectors. this is accounted for in load()
        if not emb_finetune or 'delta' in self.vocab:
            # if emb_finetune is off
            # or if the delta embedding is present
            # then we won't fine tune the original embedding
            self.add_unsaved_module('word_emb', word_emb)
            self.word_emb.weight.detach_()
        else:
            self.word_emb = word_emb
        if emb_matrix is not None:
            self.init_emb(emb_matrix)

        # TODO: allow for expansion of delta embedding if new
        # training data has new words in it?
        self.delta_emb = None
        if 'delta' in self.vocab:
            # zero inits seems to work better
            # note that the gradient will flow to the bottom and then adjust the 0 weights
            # as opposed to a 0 matrix cutting off the gradient if higher up in the model
            self.delta_emb = nn.Embedding(len(self.vocab['delta']), self.args['word_emb_dim'], PAD_ID)
            nn.init.zeros_(self.delta_emb.weight)
            # if the model was trained with a delta embedding, but emb_finetune is off now,
            # then we will detach the delta embedding
            if not emb_finetune:
                self.delta_emb.weight.detach_()

        input_size += self.args['word_emb_dim']

    self.peft_name = peft_name
    attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
    if self.args.get('bert_model', None):
        # TODO: refactor bert_hidden_layers between the different models
        if args.get('bert_hidden_layers', False):
            # The average will be offset by 1/N so that the default zeros
            # represents an average of the N layers
            self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
            nn.init.zeros_(self.bert_layer_mix.weight)
        else:
            # an average of layers 2, 3, 4 will be used
            # (for historic reasons)
            self.bert_layer_mix = None
        input_size += self.bert_model.config.hidden_size

    if self.args['char'] and self.args['char_emb_dim'] > 0:
        if self.args['charlm']:
            # pretrained character language models are mandatory when charlm
            # is requested; fail fast with a pointer to the relevant flag
            if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
                raise ForwardCharlmNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']), args['charlm_forward_file'])
            if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
                raise BackwardCharlmNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']), args['charlm_backward_file'])
            # charlms are frozen (finetune=False) and excluded from checkpoints
            self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False))
            self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False))
            input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
        else:
            # trainable bi-directional character model instead of a charlm
            self.charmodel = CharacterModel(args, vocab, bidirectional=True, attention=False)
            input_size += self.args['char_hidden_dim'] * 2

    # optionally add a input transformation layer
    if self.args.get('input_transform', False):
        self.input_transform = nn.Linear(input_size, input_size)
    else:
        self.input_transform = None

    # recurrent layers
    self.taggerlstm = PackedLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, \
            bidirectional=True, dropout=0 if self.args['num_layers'] == 1 else self.args['dropout'])
    # self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
    self.drop_replacement = None
    # learned-nothing (requires_grad=False) zero initial states, expanded to
    # the batch size in forward()
    self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)
    self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)

    # tag classifier
    # one output head (and one CRF) per tag column
    tag_lengths = self.vocab['tag'].lens()
    self.num_output_layers = len(tag_lengths)
    if self.args.get('connect_output_layers'):
        # each later head also sees the logits of the previous head
        tag_clfs = [nn.Linear(self.args['hidden_dim']*2, tag_lengths[0])]
        for prev_length, next_length in zip(tag_lengths[:-1], tag_lengths[1:]):
            tag_clfs.append(nn.Linear(self.args['hidden_dim']*2 + prev_length, next_length))
        self.tag_clfs = nn.ModuleList(tag_clfs)
    else:
        self.tag_clfs = nn.ModuleList([nn.Linear(self.args['hidden_dim']*2, num_tag) for num_tag in tag_lengths])
    for tag_clf in self.tag_clfs:
        tag_clf.bias.data.zero_()
    self.crits = nn.ModuleList([CRFLoss(num_tag) for num_tag in tag_lengths])

    self.drop = nn.Dropout(args['dropout'])
    self.worddrop = WordDropout(args['word_dropout'])
    self.lockeddrop = LockedDropout(args['locked_dropout'])
|
| 129 |
+
|
| 130 |
+
def init_emb(self, emb_matrix):
|
| 131 |
+
if isinstance(emb_matrix, np.ndarray):
|
| 132 |
+
emb_matrix = torch.from_numpy(emb_matrix)
|
| 133 |
+
vocab_size = len(self.vocab['word'])
|
| 134 |
+
dim = self.args['word_emb_dim']
|
| 135 |
+
assert emb_matrix.size() == (vocab_size, dim), \
|
| 136 |
+
"Input embedding matrix must match size: {} x {}, found {}".format(vocab_size, dim, emb_matrix.size())
|
| 137 |
+
self.word_emb.weight.data.copy_(emb_matrix)
|
| 138 |
+
|
| 139 |
+
def add_unsaved_module(self, name, module):
|
| 140 |
+
self.unsaved_modules += [name]
|
| 141 |
+
setattr(self, name, module)
|
| 142 |
+
|
| 143 |
+
def log_norms(self):
    """Log the norm of each trainable parameter, one per line.

    Parameters belonging to the frozen character language models are
    skipped to keep the report focused on what is actually training.
    """
    # fix: header previously read "PARAMTERS"
    lines = ["NORMS FOR MODEL PARAMETERS"]
    for name, param in self.named_parameters():
        if param.requires_grad and name.split(".")[0] not in ('charmodel_forward', 'charmodel_backward'):
            lines.append(" %s %.6g" % (name, torch.norm(param).item()))
    logger.info("\n".join(lines))
|
| 149 |
+
|
| 150 |
+
def forward(self, sentences, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx):
    """Compute the CRF loss and logits for a batch of sentences.

    Concatenates all enabled input features (static/delta word embeddings,
    bert embeddings, character representations), runs them through dropout
    and the bidirectional tagger LSTM, then applies one linear+CRF head per
    tag column.

    Returns:
        loss: summed CRF loss over all output heads
        logits: list of per-head padded logit tensors
        trans: list of per-head CRF transition scores

    NOTE(review): `word_emb` (used by pad()) and `word_mask` (used in the
    loss) are only bound when word_emb_dim > 0 — presumably callers always
    enable word embeddings; confirm upstream.
    """
    device = next(self.parameters()).device

    def pack(x):
        # pack a padded (batch, seq, feat) tensor using the sentence lengths
        return pack_padded_sequence(x, sentlens, batch_first=True)

    inputs = []
    batch_size = len(sentences)

    if self.args['word_emb_dim'] > 0:
        #extract static embeddings
        static_words, word_mask = self.extract_static_embeddings(self.args, sentences, self.vocab['word'])

        word_mask = word_mask.to(device)
        static_words = static_words.to(device)

        word_static_emb = self.word_emb(static_words)

        if 'delta' in self.vocab and self.delta_emb is not None:
            # masks should be the same
            delta_words, _ = self.extract_static_embeddings(self.args, sentences, self.vocab['delta'])
            delta_words = delta_words.to(device)
            # unclear whether to treat words in the main embedding
            # but not in delta as unknown
            # simple heuristic though - treating them as not
            # unknown keeps existing models the same when
            # separating models into the base WV and delta WV
            # also, note that at training time, words like this
            # did not show up in the training data, but are
            # not exactly UNK, so it makes sense
            delta_unk_mask = torch.eq(delta_words, UNK_ID)
            static_unk_mask = torch.not_equal(static_words, UNK_ID)
            unk_mask = delta_unk_mask * static_unk_mask
            delta_words[unk_mask] = PAD_ID

            # delta embeddings are additive corrections on top of the
            # (possibly frozen) static embeddings
            delta_emb = self.delta_emb(delta_words)
            word_static_emb = word_static_emb + delta_emb

        word_emb = pack(word_static_emb)
        inputs += [word_emb]

    if self.bert_model is not None:
        device = next(self.parameters()).device
        processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, sentences, device, keep_endpoints=False,
                                                 num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                 detach=not self.args.get('bert_finetune', False),
                                                 peft_name=self.peft_name)
        if self.bert_layer_mix is not None:
            # use a linear layer to weighted average the embedding dynamically
            # the plain average term keeps zero-initialized mix weights
            # equivalent to a uniform layer average
            processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]

        processed_bert = pad_sequence(processed_bert, batch_first=True)
        inputs += [pack(processed_bert)]

    def pad(x):
        # inverse of pack(): rebuild a padded tensor from packed data,
        # reusing the batch_sizes computed for the word embeddings
        return pad_packed_sequence(PackedSequence(x, word_emb.batch_sizes), batch_first=True)[0]

    if self.args['char'] and self.args['char_emb_dim'] > 0:
        if self.args.get('charlm', None):
            # frozen forward/backward character language model features
            char_reps_forward = self.charmodel_forward.get_representation(chars[0], charoffsets[0], charlens, char_orig_idx)
            char_reps_forward = PackedSequence(char_reps_forward.data, char_reps_forward.batch_sizes)
            char_reps_backward = self.charmodel_backward.get_representation(chars[1], charoffsets[1], charlens, char_orig_idx)
            char_reps_backward = PackedSequence(char_reps_backward.data, char_reps_backward.batch_sizes)
            inputs += [char_reps_forward, char_reps_backward]
        else:
            char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
            char_reps = PackedSequence(char_reps.data, char_reps.batch_sizes)
            inputs += [char_reps]

    # concatenate all packed features along the feature dimension
    lstm_inputs = torch.cat([x.data for x in inputs], 1)
    if self.args['word_dropout'] > 0:
        lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
    lstm_inputs = self.drop(lstm_inputs)
    # locked (variational) dropout needs the padded layout, so round-trip
    lstm_inputs = pad(lstm_inputs)
    lstm_inputs = self.lockeddrop(lstm_inputs)
    lstm_inputs = pack(lstm_inputs).data

    if self.input_transform:
        lstm_inputs = self.input_transform(lstm_inputs)

    lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)
    # initial hidden/cell states are the stored zero parameters expanded
    # over the batch
    lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(\
            self.taggerlstm_h_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous(), \
            self.taggerlstm_c_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous()))
    lstm_outputs = lstm_outputs.data


    # prediction layer
    lstm_outputs = self.drop(lstm_outputs)
    lstm_outputs = pad(lstm_outputs)
    lstm_outputs = self.lockeddrop(lstm_outputs)
    lstm_outputs = pack(lstm_outputs).data

    loss = 0
    logits = []
    trans = []
    for idx, (tag_clf, crit) in enumerate(zip(self.tag_clfs, self.crits)):
        if not self.args.get('connect_output_layers') or idx == 0:
            next_logits = pad(tag_clf(lstm_outputs)).contiguous()
        else:
            # here we pack the output of the previous round, then append it
            packed_logits = pack(next_logits).data
            input_logits = torch.cat([lstm_outputs, packed_logits], axis=1)
            next_logits = pad(tag_clf(input_logits)).contiguous()
        # the tag_mask lets us avoid backprop on a blank tag
        tag_mask = torch.eq(tags[:, :, idx], EMPTY_ID)
        next_loss, next_trans = crit(next_logits, torch.bitwise_or(tag_mask, word_mask), tags[:, :, idx])
        loss = loss + next_loss
        logits.append(next_logits)
        trans.append(next_trans)

    return loss, logits, trans
|
| 262 |
+
|
| 263 |
+
@staticmethod
def extract_static_embeddings(args, sents, vocab):
    """Map tokenized sentences to a padded tensor of word ids.

    Args:
        args: model args; 'lowercase' (default True) controls case folding
        sents: list of sentences, each a list of word strings
        vocab: vocabulary with a map() method from words to ids

    Returns:
        words: LongTensor of ids, padded to the longest sentence
        words_mask: BoolTensor, True at padding positions
    """
    if args.get('lowercase', True): # handle word case
        case = lambda x: x.lower()
    else:
        case = lambda x: x
    # previously each mapped sentence was wrapped in a throwaway
    # single-element list and immediately unwrapped; a comprehension
    # expresses the same mapping directly
    processed = [vocab.map([case(w) for w in sent]) for sent in sents]

    words = get_long_tensor(processed, len(sents))
    words_mask = torch.eq(words, PAD_ID)

    return words, words_mask
|
| 278 |
+
|
stanza/stanza/models/pos/scorer.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utils and wrappers for scoring taggers.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
from stanza.models.common.utils import ud_scores
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger('stanza')
|
| 9 |
+
|
| 10 |
+
def score(system_conllu_file, gold_conllu_file, verbose=True, eval_type='AllTags'):
    """Run the UD tagger evaluation on a system/gold conllu pair.

    Returns (precision, recall, f1) for the requested eval_type; when
    verbose, also logs the F1 of each tag category as a percentage.
    """
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)
    summary = evaluation[eval_type]
    if verbose:
        percentages = [evaluation[k].f1 * 100 for k in ('UPOS', 'XPOS', 'UFeats', 'AllTags')]
        logger.info("UPOS\tXPOS\tUFeats\tAllTags")
        logger.info("{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}".format(*percentages))
    return summary.precision, summary.recall, summary.f1
|
| 22 |
+
|
stanza/stanza/models/pos/vocab.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter, OrderedDict
|
| 2 |
+
|
| 3 |
+
from stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab
|
| 4 |
+
from stanza.models.common.vocab import CompositeVocab, VOCAB_PREFIX, EMPTY, EMPTY_ID
|
| 5 |
+
|
| 6 |
+
class WordVocab(BaseVocab):
    """Vocabulary over surface words for the tagger.

    Supports an optional `ignore` list: ignored units are excluded when the
    vocab is built and are collapsed onto the EMPTY entry at lookup time.
    """
    def __init__(self, data=None, lang="", idx=0, cutoff=0, lower=False, ignore=None):
        self.ignore = [] if ignore is None else ignore
        super().__init__(data, lang=lang, idx=idx, cutoff=cutoff, lower=lower)
        # persist the ignore list when the vocab is serialized
        self.state_attrs += ['ignore']

    def id2unit(self, id):
        # ignored units were collapsed into EMPTY, so render them as '_'
        if self.ignore and id == EMPTY_ID:
            return '_'
        return super().id2unit(id)

    def unit2id(self, unit):
        # route every ignored unit to the shared EMPTY id
        if self.ignore and unit in self.ignore:
            return self._unit2id[EMPTY]
        return super().unit2id(unit)

    def build_vocab(self):
        extract = (lambda w: w[self.idx].lower()) if self.lower else (lambda w: w[self.idx])
        counter = Counter(extract(w) for sent in self.data for w in sent)
        # drop rare units and anything explicitly ignored
        for unit in list(counter.keys()):
            if counter[unit] < self.cutoff or unit in self.ignore:
                del counter[unit]

        # most frequent first; ties keep first-seen order (stable sort)
        self._id2unit = VOCAB_PREFIX + sorted(counter.keys(), key=lambda u: counter[u], reverse=True)
        self._unit2id = {unit: i for i, unit in enumerate(self._id2unit)}

    def __str__(self):
        units = ",".join("|%s|" % x for x in self._id2unit)
        return "<{}: {}>".format(type(self), units)
|
| 38 |
+
|
| 39 |
+
class XPOSVocab(CompositeVocab):
    """Composite vocabulary for treebank-specific XPOS tags.

    Defaults to sep="" and keyed=False; the exact splitting semantics come
    from CompositeVocab (presumably per-character, positional parts —
    confirm there).
    """
    def __init__(self, data=None, lang="", idx=0, sep="", keyed=False):
        super().__init__(data, lang, idx=idx, sep=sep, keyed=keyed)
|
| 42 |
+
|
| 43 |
+
class FeatureVocab(CompositeVocab):
    """Composite vocabulary for UFeats morphological features.

    Defaults to sep="|" and keyed=True, matching the conllu
    "Key=Value|Key=Value" feature format.
    """
    def __init__(self, data=None, lang="", idx=0, sep="|", keyed=True):
        super().__init__(data, lang, idx=idx, sep=sep, keyed=keyed)
|
| 46 |
+
|
| 47 |
+
class MultiVocab(BaseMultiVocab):
    """Collection of named vocabularies serialized and restored together."""

    def state_dict(self):
        """ Also save a vocab name to class name mapping in state dict. """
        state = OrderedDict()
        key2class = OrderedDict()
        for k, v in self._vocabs.items():
            state[k] = v.state_dict()
            # record the concrete class so load_state_dict can dispatch
            key2class[k] = type(v).__name__
        state['_key2class'] = key2class
        return state

    @classmethod
    def load_state_dict(cls, state_dict):
        """Rebuild a MultiVocab from a state dict produced by state_dict().

        Each entry is dispatched to the vocab class recorded under
        '_key2class'.  NOTE(review): this pops '_key2class' and therefore
        mutates the caller's dict — confirm callers treat it as disposable.
        """
        class_dict = {'CharVocab': CharVocab,
                      'WordVocab': WordVocab,
                      'XPOSVocab': XPOSVocab,
                      'FeatureVocab': FeatureVocab}
        new = cls()
        assert '_key2class' in state_dict, "Cannot find class name mapping in state dict!"
        key2class = state_dict.pop('_key2class')
        for k,v in state_dict.items():
            classname = key2class[k]
            new[k] = class_dict[classname].load_state_dict(v)
        return new
|
| 71 |
+
|
stanza/stanza/pipeline/demo/stanza-brat.js
ADDED
|
@@ -0,0 +1,1316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Takes Stanford CoreNLP JSON output (var data = ... in data.js)
// and uses brat to render everything.

//var serverAddress = 'http://localhost:5000';

// Load Brat libraries
// head.js loads the listed scripts asynchronously before the page renders
var bratLocation = 'https://nlp.stanford.edu/js/brat/';
head.js(
    // External libraries
    bratLocation + '/client/lib/jquery.svg.min.js',
    bratLocation + '/client/lib/jquery.svgdom.min.js',

    // brat helper modules
    bratLocation + '/client/src/configuration.js',
    bratLocation + '/client/src/util.js',
    bratLocation + '/client/src/annotation_log.js',
    bratLocation + '/client/lib/webfont.js',

    // brat modules
    bratLocation + '/client/src/dispatcher.js',
    bratLocation + '/client/src/url_monitor.js',
    bratLocation + '/client/src/visualizer.js',

    // parse viewer
    './stanza-parseviewer.js'
);

// Uses Dagre (https://github.com/cpettitt/dagre) for constituency parse
// visualization. It works better than the brat visualization.
var useDagre = true;
// Mutable UI state: the current input text and the last rendered results.
var currentQuery = 'The quick brown fox jumped over the lazy dog.';
var currentSentences = '';
var currentText = '';
|
| 34 |
+
|
| 35 |
+
// ----------------------------------------------------------------------------
|
| 36 |
+
// HELPERS
|
| 37 |
+
// ----------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
/**
|
| 40 |
+
* Add the startsWith function to the String class
|
| 41 |
+
*/
|
| 42 |
+
if (typeof String.prototype.startsWith !== 'function') {
  // Older browsers lack the native method; fall back to an indexOf check
  // (a match at position 0 means the string starts with the prefix).
  String.prototype.startsWith = function (prefix) {
    return this.indexOf(prefix) === 0;
  };
}
|
| 48 |
+
|
| 49 |
+
// True when value parses as a number whose 32-bit integer truncation
// equals itself (so large values beyond 2^31 are deliberately rejected,
// matching the original `| 0` check).
function isInt(value) {
  if (isNaN(value)) {
    return false;
  }
  var parsed = parseFloat(value);
  return (parsed | 0) === parsed;
}
|
| 52 |
+
|
| 53 |
+
/**
|
| 54 |
+
* A reverse map of PTB tokens to their original gloss
|
| 55 |
+
*/
|
| 56 |
+
// PTB escapes brackets, braces, and quotes in token text; map each escaped
// token back to its literal character for display.
var tokensMap = {
  '-LRB-': '(',
  '-RRB-': ')',
  '-LSB-': '[',
  '-RSB-': ']',
  '-LCB-': '{',
  '-RCB-': '}',
  '``': '"',
  '\'\'': '"',
};
|
| 66 |
+
|
| 67 |
+
/**
 * Map a Penn Treebank part-of-speech tag to its visualization color.
 * A null tag, or any tag not matched by a rule, falls back to light gray.
 */
function posColor(posTag) {
  if (posTag === null) {
    return '#E3E3E3';
  }
  // Ordered rules: the first predicate that matches wins, mirroring the
  // original if/else-if cascade (ordering matters, e.g. 'CD' vs 'CC').
  var rules = [
    [function (t) { return t.startsWith('N'); },                        '#A4BCED'],
    [function (t) { return t.startsWith('V') || t.startsWith('M'); },   '#ADF6A2'],
    [function (t) { return t.startsWith('P'); },                        '#CCDAF6'],
    [function (t) { return t.startsWith('I'); },                        '#FFE8BE'],
    [function (t) { return t.startsWith('R') || t.startsWith('W'); },   '#FFFDA8'],
    [function (t) { return t.startsWith('D') || t === 'CD'; },          '#CCADF6'],
    [function (t) { return t.startsWith('J'); },                        '#FFFDA8'],
    [function (t) { return t.startsWith('T'); },                        '#FFE8BE'],
    [function (t) { return t.startsWith('E') || t.startsWith('S'); },   '#E4CBF6'],
    [function (t) { return t.startsWith('CC'); },                       '#FFFFFF'],
    [function (t) { return t === 'LS' || t === 'FW'; },                 '#FFFFFF']
  ];
  for (var r = 0; r < rules.length; r++) {
    if (rules[r][0](posTag)) {
      return rules[r][1];
    }
  }
  return '#E3E3E3';
}
|
| 100 |
+
|
| 101 |
+
/**
 * Map a Universal POS tag to its visualization color; the palette
 * parallels posColor() so XPOS and UPOS views look consistent.
 * A null tag, or any tag not matched by a rule, falls back to light gray.
 */
function uposColor(posTag) {
  if (posTag === null) {
    return '#E3E3E3';
  }
  // Ordered rules: first matching predicate wins, exactly as the
  // original if/else-if chain did.
  var rules = [
    [function (t) { return t === 'NOUN' || t === 'PROPN'; },          '#A4BCED'],
    [function (t) { return t.startsWith('V') || t === 'AUX'; },       '#ADF6A2'],
    [function (t) { return t === 'PART'; },                           '#CCDAF6'],
    [function (t) { return t === 'ADP'; },                            '#FFE8BE'],
    [function (t) { return t === 'ADV' || t.startsWith('PRON'); },    '#FFFDA8'],
    [function (t) { return t === 'NUM' || t === 'DET'; },             '#CCADF6'],
    [function (t) { return t === 'ADJ'; },                            '#FFFDA8'],
    [function (t) { return t.startsWith('E') || t.startsWith('S'); }, '#E4CBF6'],
    [function (t) { return t.startsWith('CC'); },                     '#FFFFFF'],
    [function (t) { return t === 'X' || t === 'FW'; },                '#FFFFFF']
  ];
  for (var r = 0; r < rules.length; r++) {
    if (rules[r][0](posTag)) {
      return rules[r][1];
    }
  }
  return '#E3E3E3';
}
|
| 132 |
+
|
| 133 |
+
/**
 * Map a named entity tag to its visualization color.  Accepts both the
 * long CoreNLP names (PERSON/ORGANIZATION/LOCATION) and the short
 * CoNLL-style names (PER/ORG/LOC).  Unknown or null tags fall back to
 * light gray.
 *
 * Fix: the 'LOC' comparison used loose equality (==) while every other
 * branch used strict equality; normalized to === for consistency
 * (behavior is identical for string operands).
 */
function nerColor(nerTag) {
  if (nerTag === null) {
    return '#E3E3E3';
  } else if (nerTag === 'PERSON' || nerTag === 'PER') {
    return '#FFCCAA';
  } else if (nerTag === 'ORGANIZATION' || nerTag === 'ORG') {
    return '#8FB2FF';
  } else if (nerTag === 'MISC') {
    return '#F1F447';
  } else if (nerTag === 'LOCATION' || nerTag === 'LOC') {
    return '#95DFFF';
  } else if (nerTag === 'DATE' || nerTag === 'TIME' || nerTag === 'SET') {
    return '#9AFFE6';
  } else if (nerTag === 'MONEY') {
    return '#FFFFFF';
  } else if (nerTag === 'PERCENT') {
    return '#FFA22B';
  } else {
    return '#E3E3E3';
  }
}
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
/**
 * Map a sentiment label (as produced by render(): upper-cased, with
 * "VERY " split into two words) to a green-to-red visualization color.
 * Any unrecognized value falls back to light gray.
 */
function sentimentColor(sentiment) {
  var palette = {
    'VERY POSITIVE': '#00FF00',
    'POSITIVE':      '#7FFF00',
    'NEUTRAL':       '#FFFF00',
    'NEGATIVE':      '#FF7F00',
    'VERY NEGATIVE': '#FF0000'
  };
  // hasOwnProperty guards against inherited Object.prototype keys
  // (e.g. 'toString') accidentally matching.
  return palette.hasOwnProperty(sentiment) ? palette[sentiment] : '#E3E3E3';
}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
/**
 * Build the comma-separated annotator list from the '#annotators'
 * multi-select input.  'tokenize' and 'ssplit' are always included
 * first; each selected option's value is appended after them.
 */
function annotators() {
  var list = 'tokenize,ssplit';
  $('#annotators').find('option:selected').each(function () {
    list = list + ',' + $(this).val();
  });
  return list;
}
|
| 191 |
+
|
| 192 |
+
/**
 * Return the current local time as an ISO-8601-like string
 * "YYYY-MM-DDTHH:MM:SS", zero-padding every two-digit field.
 */
function date() {
  // Zero-pad a field to two digits.
  var pad = function (n) {
    return n < 10 ? '0' + n : n;
  };
  var now = new Date();
  return now.getFullYear() +
      '-' + pad(now.getMonth() + 1) +
      '-' + pad(now.getDate()) +
      'T' + pad(now.getHours()) +
      ':' + pad(now.getMinutes()) +
      ':' + pad(now.getSeconds());
}
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
//-----------------------------------------------------------------------------
// Constituency parser
//-----------------------------------------------------------------------------
/**
 * Parses CoreNLP's Penn-Treebank-style parse strings, e.g.
 * "(ROOT (NP (DT the) (NN dog)))", into a JS tree whose nodes carry
 * label / children / parent links plus token-span indices, and attaches
 * the tree to each sentence as `sentence.parseTree`.
 */
function ConstituencyParseProcessor() {
  // Recursively consume the flat token stream (destructively, via
  // shift()) and build nested arrays mirroring the parenthesization:
  // "( A ( B ) )" -> [ 'A', [ 'B' ] ].  The top call pops the single
  // outermost list off the accumulator.
  var parenthesize = function (input, list) {
    if (list === undefined) {
      return parenthesize(input, []);
    } else {
      var token = input.shift();
      if (token === undefined) {
        return list.pop();
      } else if (token === "(") {
        list.push(parenthesize(input, []));
        return parenthesize(input, list);
      } else if (token === ")") {
        return list;
      } else {
        return parenthesize(input, list.concat(token));
      }
    }
  };

  // Convert a nested array into a tree of objects.
  //   ['NN', 'dog']          -> {label: 'NN', text: 'dog', isTerminal: true}
  //   ['NP', [...], [...]]   -> {label: 'NP', children: [...]} with
  //                             parent back-links set on each child.
  // Anything else (length < 2) is returned as-is.
  var toTree = function (list) {
    if (list.length === 2 && typeof list[1] === 'string') {
      return {label: list[0], text: list[1], isTerminal: true};
    } else if (list.length >= 2) {
      var label = list.shift();
      var node = {label: label};
      var rest = list.map(function (x) {
        var t = toTree(x);
        if (typeof t === 'object') {
          t.parent = node;
        }
        return t;
      });
      node.children = rest;
      return node;
    } else {
      return list;
    }
  };

  // Walk the tree left-to-right assigning each terminal a token index
  // and each node its [tokenStart, tokenEnd) span; returns the next
  // unconsumed index.  `tokens` are the sentence's token objects.
  var indexTree = function (tree, tokens, index) {
    index = index || 0;
    if (tree.isTerminal) {
      tree.token = tokens[index];
      tree.tokenIndex = index;
      tree.tokenStart = index;
      tree.tokenEnd = index + 1;
      return index + 1;
    } else if (tree.children) {
      tree.tokenStart = index;
      for (var i = 0; i < tree.children.length; i++) {
        var child = tree.children[i];
        index = indexTree(child, tokens, index);
      }
      tree.tokenEnd = index;
    }
    return index;
  };

  // Split a parse string into a token stream of '(' / ')' / labels.
  // Spaces inside double-quoted segments are protected with a sentinel
  // so quoted material survives the whitespace split intact.
  var tokenize = function (input) {
    return input.split('"')
      .map(function (x, i) {
        if (i % 2 === 0) { // not in string
          return x.replace(/\(/g, ' ( ')
            .replace(/\)/g, ' ) ');
        } else { // in string
          return x.replace(/ /g, "!whitespace!");
        }
      })
      .join('"')
      .trim()
      .split(/\s+/)
      .map(function (x) {
        return x.replace(/!whitespace!/g, " ");
      });
  };

  // Full pipeline: parse string -> nested arrays -> tree -> indexed
  // tree.  Returns undefined when the input does not parse to a list.
  var convertParseStringToTree = function (input, tokens) {
    var p = parenthesize(tokenize(input));
    if (Array.isArray(p)) {
      var tree = toTree(p);
      // Correlate tree with tokens
      indexTree(tree, tokens);
      return tree;
    }
  };

  // Public entry point: annotate every sentence that has a `parse`
  // string with the corresponding `parseTree` object.
  this.process = function(annotation) {
    for (var i = 0; i < annotation.sentences.length; i++) {
      var s = annotation.sentences[i];
      if (s.parse) {
        s.parseTree = convertParseStringToTree(s.parse, s.tokens);
      }
    }
  }
}
|
| 308 |
+
|
| 309 |
+
// ----------------------------------------------------------------------------
|
| 310 |
+
// RENDER
|
| 311 |
+
// ----------------------------------------------------------------------------
|
| 312 |
+
|
| 313 |
+
/**
 * Render a parsed CoreNLP/Stanza annotation (JSON) into the demo's brat
 * visualizations: POS, UPOS, lemmas, NER, entity links, dependencies,
 * constituency parse, coreference, OpenIE and KBP panes.
 *
 * Side effects: overwrites the globals currentText and currentSentences,
 * and writes token.characterOffsetBegin/End onto the input tokens.
 *
 * @param data    annotation JSON; must contain a `sentences` array.
 * @param reverse when true, mirror entity offsets for RTL display.
 *
 * Fixes relative to the previous revision:
 *  - `color` in addEntityType was an implicit global (missing `var`).
 *  - cparseEntities was concatenated with itself each sentence,
 *    duplicating earlier sentences' parse entities.
 *  - kbpEntitiesSet / kbpRelationsSet were arrays used as maps; now
 *    plain objects, matching the OpenIE dedup sets.
 *  - removed two dead `var begin = parseInt(token...)` statements in the
 *    OpenIE and KBP loops that read a stale `token` variable.
 *  - removed the unused local nerEntitiesNormalized.
 */
function render(data, reverse) {
  // Tweak arguments
  if (typeof reverse !== 'boolean') {
    reverse = false;
  }

  // Error checks
  if (typeof data.sentences === 'undefined') { return; }

  /**
   * Register an entity type (a tag) for Brat, choosing a background
   * color appropriate to the annotation layer it belongs to.
   */
  var entityTypesSet = {};
  var entityTypes = [];
  function addEntityType(name, type, coarseType) {
    if (typeof coarseType === "undefined") {
      coarseType = type;
    }
    // Don't add duplicates
    if (entityTypesSet[type]) return;
    entityTypesSet[type] = true;
    // Get the color of the entity type
    var color = '#ffccaa';
    if (name === 'POS') {
      color = posColor(type);
    } else if (name === 'UPOS') {
      color = uposColor(type);
    } else if (name === 'NER') {
      color = nerColor(coarseType);
    } else if (name === 'NNER') {
      color = nerColor(coarseType);
    } else if (name === 'COREF') {
      color = '#FFE000';
    } else if (name === 'ENTITY') {
      color = posColor('NN');
    } else if (name === 'RELATION') {
      color = posColor('VB');
    } else if (name === 'LEMMA') {
      color = '#FFFFFF';
    } else if (name === 'SENTIMENT') {
      color = sentimentColor(type);
    } else if (name === 'LINK') {
      color = '#FFFFFF';
    } else if (name === 'KBP_ENTITY') {
      color = '#FFFFFF';
    }
    // Register the type
    entityTypes.push({
      type: type,
      labels : [type],
      bgColor: color,
      borderColor: 'darken'
    });
  }

  /**
   * Register a relation type (an arc) for Brat.  Symmetric edges are
   * drawn dashed and without an arrowhead.
   */
  var relationTypesSet = {};
  var relationTypes = [];
  function addRelationType(type, symmetricEdge) {
    // Prevent adding duplicates
    if (relationTypesSet[type]) return;
    relationTypesSet[type] = true;
    // Default arguments
    if (typeof symmetricEdge === 'undefined') { symmetricEdge = false; }
    // Add the type
    relationTypes.push({
      type: type,
      labels: [type],
      dashArray: (symmetricEdge ? '3,3' : undefined),
      arrowHead: (symmetricEdge ? 'none' : undefined),
    });
  }

  //
  // Construct text of annotation, un-escaping PTB tokens and recording
  // character offsets back onto each token as we go.
  //
  currentText = []; // GLOBAL
  currentSentences = data.sentences; // GLOBAL
  data.sentences.forEach(function(sentence) {
    for (var i = 0; i < sentence.tokens.length; ++i) {
      var token = sentence.tokens[i];
      var word = token.word;
      if (!(typeof tokensMap[word] === "undefined")) {
        word = tokensMap[word];
      }
      if (i > 0) { currentText.push(' '); }
      token.characterOffsetBegin = currentText.length;
      for (var j = 0; j < word.length; ++j) {
        currentText.push(word[j]);
      }
      token.characterOffsetEnd = currentText.length;
    }
    currentText.push('\n');
  });
  currentText = currentText.join('');

  //
  // Shared variables
  // These are what we'll render in BRAT
  //
  // (pos)
  var posEntities = [];
  // (upos)
  var uposEntities = [];
  // (lemma)
  var lemmaEntities = [];
  // (ner)
  var nerEntities = [];
  // (sentiment)
  var sentimentEntities = [];
  // (entitylinking)
  var linkEntities = [];
  // (dependencies)
  var depsRelations = [];
  var deps2Relations = [];
  // (openie)
  var openieEntities = [];
  var openieEntitiesSet = {};
  var openieRelations = [];
  var openieRelationsSet = {};
  // (kbp)
  var kbpEntities = [];
  var kbpEntitiesSet = {};
  var kbpRelations = [];
  var kbpRelationsSet = {};
  // (constituency parse, brat path only)
  var cparseEntities = [];
  var cparseRelations = [];

  //
  // Loop over sentences.
  // This fills in the variables above.
  //
  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {
    var sentence = data.sentences[sentI];
    var index = sentence.index;
    var tokens = sentence.tokens;
    var deps = sentence['basicDependencies'];
    var deps2 = sentence['enhancedPlusPlusDependencies'];
    var parseTree = sentence['parseTree'];

    // POS tags
    /**
     * Generate a POS tagged token id
     */
    function posID(i) {
      return 'POS_' + sentI + '_' + i;
    }
    var noXPOS = true;
    if (tokens.length > 0 && typeof tokens[0].pos !== 'undefined' && tokens[0].pos !== null) {
      noXPOS = false;
      for (var i = 0; i < tokens.length; i++) {
        var token = tokens[i];
        var pos = token.pos;
        var begin = parseInt(token.characterOffsetBegin);
        var end = parseInt(token.characterOffsetEnd);
        addEntityType('POS', pos);
        posEntities.push([posID(i), pos, [[begin, end]]]);
      }
    }

    // Universal POS tags
    /**
     * Generate a UPOS tagged token id
     */
    function uposID(i) {
      return 'UPOS_' + sentI + '_' + i;
    }
    if (tokens.length > 0 && typeof tokens[0].upos !== 'undefined') {
      for (var i = 0; i < tokens.length; i++) {
        var token = tokens[i];
        var upos = token.upos;
        var begin = parseInt(token.characterOffsetBegin);
        var end = parseInt(token.characterOffsetEnd);
        addEntityType('UPOS', upos);
        uposEntities.push([uposID(i), upos, [[begin, end]]]);
      }
    }

    // Constituency parse (brat rendering; only used when useDagre is off)
    if (parseTree && !useDagre) {
      var parseEntities = [];
      var parseRels = [];
      // NOTE(review): terminals reuse uposEntities[tree.tokenIndex],
      // which indexes the cross-sentence array with a sentence-local
      // index — only correct for the first sentence; verify if the
      // non-Dagre path is revived for multi-sentence input.
      function processParseTree(tree, index) {
        tree.visitIndex = index;
        index++;
        if (tree.isTerminal) {
          parseEntities[tree.visitIndex] = uposEntities[tree.tokenIndex];
          return index;
        } else if (tree.children) {
          addEntityType('PARSENODE', tree.label);
          parseEntities[tree.visitIndex] =
            ['PARSENODE_' + sentI + '_' + tree.visitIndex, tree.label,
             [[tokens[tree.tokenStart].characterOffsetBegin, tokens[tree.tokenEnd-1].characterOffsetEnd]]];
          var parentEnt = parseEntities[tree.visitIndex];
          for (var i = 0; i < tree.children.length; i++) {
            var child = tree.children[i];
            index = processParseTree(child, index);
            var childEnt = parseEntities[child.visitIndex];
            addRelationType('pc');
            parseRels.push(['PARSEEDGE_' + sentI + '_' + parseRels.length, 'pc', [['parent', parentEnt[0]], ['child', childEnt[0]]]]);
          }
        }
        return index;
      }
      processParseTree(parseTree, 0);
      // Append only this sentence's entities (previously the accumulated
      // list was concatenated with itself, duplicating entries).
      cparseEntities = cparseEntities.concat(parseEntities);
      cparseRelations = cparseRelations.concat(parseRels);
    }

    // Dependency parsing
    /**
     * Process a dependency tree from JSON to Brat relations.
     * Skips arcs from the artificial ROOT (governor index 0).
     */
    function processDeps(name, deps) {
      var relations = [];
      // Format: [${ID}, ${TYPE}, [[${ARGNAME}, ${TARGET}], [${ARGNAME}, ${TARGET}]]]
      for (var i = 0; i < deps.length; i++) {
        var dep = deps[i];
        var governor = dep.governor - 1;
        var dependent = dep.dependent - 1;
        if (governor == -1) continue;
        addRelationType(dep.dep);
        relations.push([name + '_' + sentI + '_' + i, dep.dep, [['governor', uposID(governor)], ['dependent', uposID(dependent)]]]);
      }
      return relations;
    }
    // Actually add the dependencies
    if (typeof deps !== 'undefined') {
      depsRelations = depsRelations.concat(processDeps('dep', deps));
    }
    if (typeof deps2 !== 'undefined') {
      deps2Relations = deps2Relations.concat(processDeps('dep2', deps2));
    }

    // Lemmas
    if (tokens.length > 0 && typeof tokens[0].lemma !== 'undefined') {
      for (var i = 0; i < tokens.length; i++) {
        var token = tokens[i];
        var lemma = token.lemma;
        var begin = parseInt(token.characterOffsetBegin);
        var end = parseInt(token.characterOffsetEnd);
        addEntityType('LEMMA', lemma);
        lemmaEntities.push(['LEMMA_' + sentI + '_' + i, lemma, [[begin, end]]]);
      }
    }

    // NER tags
    // Assumption: contiguous occurrence of one non-O is a single entity
    var noNER = true;
    if (tokens.some(function(token) { return token.ner; })) {
      noNER = false;
      for (var i = 0; i < tokens.length; i++) {
        var ner = tokens[i].ner || 'O';
        var normalizedNER = tokens[i].normalizedNER;
        if (typeof normalizedNER === "undefined") {
          normalizedNER = ner;
        }
        if (ner == 'O') continue;
        // Extend j to the end of the contiguous run of this tag.
        var j = i;
        while (j < tokens.length - 1 && tokens[j+1].ner == ner) j++;
        addEntityType('NER', ner, ner);
        nerEntities.push(['NER_' + sentI + '_' + i, ner, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);
        if (ner != normalizedNER) {
          addEntityType('NNER', normalizedNER, ner);
          nerEntities.push(['NNER_' + sentI + '_' + i, normalizedNER, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);
        }
        i = j;
      }
    }

    // Sentiment (one label spanning the whole sentence)
    if (typeof sentence.sentiment !== "undefined") {
      var sentiment = sentence.sentiment.toUpperCase().replace("VERY", "VERY ");
      addEntityType('SENTIMENT', sentiment);
      sentimentEntities.push(['SENTIMENT_' + sentI, sentiment,
          [[tokens[0].characterOffsetBegin, tokens[tokens.length - 1].characterOffsetEnd]]]);
    }

    // Entity Links
    // Carries the same assumption as NER
    if (tokens.length > 0) {
      for (var i = 0; i < tokens.length; i++) {
        var link = tokens[i].entitylink;
        if (link == 'O' || typeof link === 'undefined') continue;
        var j = i;
        while (j < tokens.length - 1 && tokens[j+1].entitylink == link) j++;
        addEntityType('LINK', link);
        linkEntities.push(['LINK_' + sentI + '_' + i, link, [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);
        i = j;
      }
    }

    // Open IE
    // Helper Functions
    function openieID(span) {
      return 'OPENIEENTITY' + '_' + sentI + '_' + span[0] + '_' + span[1];
    }
    function addEntity(span, role) {
      // Don't add duplicate entities
      if (openieEntitiesSet[[sentI, span, role]]) return;
      openieEntitiesSet[[sentI, span, role]] = true;
      // Add the entity
      openieEntities.push([openieID(span), role,
          [[tokens[span[0]].characterOffsetBegin,
            tokens[span[1] - 1].characterOffsetEnd ]] ]);
    }
    function addRelation(gov, dep, role) {
      // Don't add duplicate relations
      if (openieRelationsSet[[sentI, gov, dep, role]]) return;
      openieRelationsSet[[sentI, gov, dep, role]] = true;
      // Add the relation
      openieRelations.push(['OPENIESUBJREL_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1],
                            role,
                            [['governor', openieID(gov)],
                             ['dependent', openieID(dep)] ] ]);
    }
    // Render OpenIE
    if (typeof sentence.openie !== 'undefined') {
      // Register the entities + relations we'll need
      addEntityType('ENTITY', 'Entity');
      addEntityType('RELATION', 'Relation');
      addRelationType('subject');
      addRelationType('object');
      // Loop over triples
      for (var i = 0; i < sentence.openie.length; ++i) {
        var subjectSpan = sentence.openie[i].subjectSpan;
        var relationSpan = sentence.openie[i].relationSpan;
        var objectSpan = sentence.openie[i].objectSpan;
        if (parseInt(relationSpan[0]) < 0 || parseInt(relationSpan[1]) < 0) {
          continue; // This is a phantom relation
        }
        // Add the entities
        addEntity(subjectSpan, 'Entity');
        addEntity(relationSpan, 'Relation');
        addEntity(objectSpan, 'Entity');
        // Add the relations
        addRelation(relationSpan, subjectSpan, 'subject');
        addRelation(relationSpan, objectSpan, 'object');
      }
    } // End OpenIE block

    //
    // KBP
    //
    // Helper Functions
    function kbpEntity(span) {
      return 'KBPENTITY' + '_' + sentI + '_' + span[0] + '_' + span[1];
    }
    function addKBPEntity(span, role) {
      // Don't add duplicate entities
      if (kbpEntitiesSet[[sentI, span, role]]) return;
      kbpEntitiesSet[[sentI, span, role]] = true;
      // Add the entity
      kbpEntities.push([kbpEntity(span), role,
          [[tokens[span[0]].characterOffsetBegin,
            tokens[span[1] - 1].characterOffsetEnd ]] ]);
    }
    function addKBPRelation(gov, dep, role) {
      // Don't add duplicate relations
      if (kbpRelationsSet[[sentI, gov, dep, role]]) return;
      kbpRelationsSet[[sentI, gov, dep, role]] = true;
      // Add the relation
      kbpRelations.push(['KBPRELATION_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1],
                         role,
                         [['governor', kbpEntity(gov)],
                          ['dependent', kbpEntity(dep)] ] ]);
    }
    if (typeof sentence.kbp !== 'undefined') {
      // Register the entities + relations we'll need
      addRelationType('subject');
      addRelationType('object');
      // Loop over triples
      for (var i = 0; i < sentence.kbp.length; ++i) {
        // Prefer a linked entity name over the generic 'Entity' label
        // when any token in the span carries an entity link.
        var subjectSpan = sentence.kbp[i].subjectSpan;
        var subjectLink = 'Entity';
        for (var k = subjectSpan[0]; k < subjectSpan[1]; ++k) {
          if (subjectLink == 'Entity' &&
              typeof tokens[k] !== 'undefined' &&
              tokens[k].entitylink != 'O' &&
              typeof tokens[k].entitylink !== 'undefined') {
            subjectLink = tokens[k].entitylink
          }
        }
        addEntityType('KBP_ENTITY', subjectLink);
        var objectSpan = sentence.kbp[i].objectSpan;
        var objectLink = 'Entity';
        for (var k = objectSpan[0]; k < objectSpan[1]; ++k) {
          if (objectLink == 'Entity' &&
              typeof tokens[k] !== 'undefined' &&
              tokens[k].entitylink != 'O' &&
              typeof tokens[k].entitylink !== 'undefined') {
            objectLink = tokens[k].entitylink
          }
        }
        addEntityType('KBP_ENTITY', objectLink);
        var relation = sentence.kbp[i].relation;
        // Add the entities
        addKBPEntity(subjectSpan, subjectLink);
        addKBPEntity(objectSpan, objectLink);
        // Add the relations
        addKBPRelation(subjectSpan, objectSpan, relation);
      }
    } // End KBP block

  } // End sentence loop

  //
  // Coreference: one 'Mention' entity per mention, chained by dashed
  // symmetric 'coref' arcs between consecutive mentions of a cluster.
  //
  var corefEntities = [];
  var corefRelations = [];
  if (typeof data.corefs !== 'undefined') {
    addRelationType('coref', true);
    addEntityType('COREF', 'Mention');
    var clusters = Object.keys(data.corefs);
    clusters.forEach( function (clusterId) {
      var chain = data.corefs[clusterId];
      if (chain.length > 1) {
        for (var i = 0; i < chain.length; ++i) {
          var mention = chain[i];
          var id = 'COREF' + mention.id;
          var tokens = data.sentences[mention.sentNum - 1].tokens;
          corefEntities.push([id, 'Mention',
              [[tokens[mention.startIndex - 1].characterOffsetBegin,
                tokens[mention.endIndex - 2].characterOffsetEnd ]] ]);
          if (i > 0) {
            var lastId = 'COREF' + chain[i - 1].id;
            corefRelations.push(['COREF' + chain[i-1].id + '_' + chain[i].id,
                                 'coref',
                                 [['governor', lastId],
                                  ['dependent', id] ] ]);
          }
        }
      }
    });
  } // End coreference block

  //
  // Actually render the elements
  //

  /**
   * Helper function to render a given set of entities / relations
   * to a Div, if it exists.  When `reverse` is set, mirrors both the
   * text and every entity's character offsets (RTL display).
   */
  function embed(container, entities, relations, reverse) {
    var text = currentText;
    if (reverse) {
      var length = currentText.length;
      for (var i = 0; i < entities.length; ++i) {
        var offsets = entities[i][2][0];
        var tmp = length - offsets[0];
        offsets[0] = length - offsets[1];
        offsets[1] = tmp;
      }
      text = text.split("").reverse().join("");
    }
    if ($('#' + container).length > 0) {
      Util.embed(container,
                 {entity_types: entityTypes, relation_types: relationTypes},
                 {text: text, entities: entities, relations: relations}
                );
    }
  }

  // Show a "not available" message in place of an annotation pane.
  function reportna(container, text) {
    $('#' + container).text(text);
  }

  // Render each annotation once brat's resources have loaded.
  head.ready(function() {
    if (!noXPOS) {
      embed('pos', posEntities);
    } else {
      reportna('pos', 'XPOS is not available for this language at this time.')
    }
    embed('upos', uposEntities);
    embed('lemma', lemmaEntities);
    if (!noNER) {
      embed('ner', nerEntities);
    } else {
      reportna('ner', 'NER is not available for this language at this time.')
    }
    embed('entities', linkEntities);
    if (!useDagre) {
      embed('parse', cparseEntities, cparseRelations);
    }
    embed('deps', uposEntities, depsRelations);
    embed('deps2', posEntities, deps2Relations);
    embed('coref', corefEntities, corefRelations);
    embed('openie', openieEntities, openieRelations);
    embed('kbp', kbpEntities, kbpRelations);
    embed('sentiment', sentimentEntities);

    // Constituency parse
    // Uses d3 and dagre-d3 (not brat)
    if ($('#parse').length > 0 && useDagre) {
      var parseViewer = new ParseViewer({ selector: '#parse' });
      parseViewer.showAnnotation(data);
      $('#parse').addClass('svg').css('display', 'block');
    }
  });

} // End render function
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
/**
 * Render a TokensRegex response into the '#tokensregex' brat pane.
 *
 * Relies on the globals currentSentences / currentText that render()
 * populated for the same input: data.sentences[i] here is an array of
 * match objects whose token offsets index into currentSentences[i].
 * Each match gets a 'match' entity; each capture group (keys starting
 * with '$', or purely numeric keys) gets its own entity type.
 *
 * Fix: the group-name loop previously iterated `for (groupName in
 * match)` without `var`, leaking an implicit global; now declared.
 */
function renderTokensregex(data) {
  /**
   * Register an entity type (a tag) for Brat
   */
  var entityTypesSet = {};
  var entityTypes = [];
  function addEntityType(type, color) {
    // Don't add duplicates
    if (entityTypesSet[type]) return;
    entityTypesSet[type] = true;
    // Set the color
    if (typeof color === 'undefined') {
      color = '#ADF6A2';
    }
    // Register the type
    entityTypes.push({
      type: type,
      labels : [type],
      bgColor: color,
      borderColor: 'darken'
    });
  }

  var entities = [];
  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {
    var tokens = currentSentences[sentI].tokens;
    for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) {
      var match = data.sentences[sentI][matchI];
      // Add groups
      for (var groupName in match) {
        if (groupName.startsWith("$") || isInt(groupName)) {
          addEntityType(groupName, '#FFFDA8');
          var begin = parseInt(tokens[match[groupName].begin].characterOffsetBegin);
          var end = parseInt(tokens[match[groupName].end - 1].characterOffsetEnd);
          entities.push(['TOK_' + sentI + '_' + matchI + '_' + groupName,
                         groupName,
                         [[begin, end]]]);
        }
      }
      // Add match
      addEntityType('match', '#ADF6A2');
      var begin = parseInt(tokens[match.begin].characterOffsetBegin);
      var end = parseInt(tokens[match.end - 1].characterOffsetEnd);
      entities.push(['TOK_' + sentI + '_' + matchI + '_match',
                     'match',
                     [[begin, end]]]);
    }
  }

  Util.embed('tokensregex',
             {entity_types: entityTypes, relation_types: []},
             {text: currentText, entities: entities, relations: []}
            );
} // END renderTokensregex()
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
/**
 * Render a Semgrex response.
 *
 * Draws one 'match' entity per semgrex match, plus one entity and a dashed
 * 'semgrex' relation for every named ($foo) or numbered capture group.
 * Character offsets come from the previously annotated sentences
 * (currentSentences / currentText globals).
 */
function renderSemgrex(data) {
  // Entity types registered so far, keyed by type name (duplicate guard).
  var entityTypesSet = {};
  var entityTypes = [];

  /**
   * Register an entity type (a tag) for Brat, ignoring duplicates.
   */
  function addEntityType(type, color) {
    // Don't add duplicates
    if (entityTypesSet[type]) return;
    entityTypesSet[type] = true;
    // Set the color
    if (typeof color === 'undefined') {
      color = '#ADF6A2';
    }
    // Register the type
    entityTypes.push({
      type: type,
      labels : [type],
      bgColor: color,
      borderColor: 'darken'
    });
  }

  // NOTE(fix): declared with 'var' so we don't leak an implicit global.
  var relationTypes = [{
    type: 'semgrex',
    labels: ['-'],
    dashArray: '3,3',
    arrowHead: 'none',
  }];

  var entities = [];
  var relations = [];

  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {
    var tokens = currentSentences[sentI].tokens;
    for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) {
      var match = data.sentences[sentI][matchI];
      // Add the whole-match entity
      addEntityType('match', '#ADF6A2');
      var begin = parseInt(tokens[match.begin].characterOffsetBegin);
      var end = parseInt(tokens[match.end - 1].characterOffsetEnd);
      entities.push(['SEM_' + sentI + '_' + matchI + '_match',
                     'match',
                     [[begin, end]]]);

      // Add capture groups ($named or numbered)
      for (var rawGroupName in match) {
        if (rawGroupName.startsWith("$") || isInt(rawGroupName)) {
          // (add node)
          var group = match[rawGroupName];
          // NOTE(fix): strip only a leading '$'; the previous unconditional
          // substring(1) also ate the first digit of numbered groups.
          var groupName = rawGroupName.startsWith("$") ? rawGroupName.substring(1) : rawGroupName;
          addEntityType(groupName, '#FFFDA8');
          var groupBegin = parseInt(tokens[group.begin].characterOffsetBegin);
          var groupEnd = parseInt(tokens[group.end - 1].characterOffsetEnd);
          entities.push(['SEM_' + sentI + '_' + matchI + '_' + groupName,
                         groupName,
                         [[groupBegin, groupEnd]]]);

          // (add relation from the match node to the group node)
          relations.push(['SEMGREX_' + sentI + '_' + matchI + '_' + groupName,
                          'semgrex',
                          [['governor', 'SEM_' + sentI + '_' + matchI + '_match'],
                           ['dependent', 'SEM_' + sentI + '_' + matchI + '_' + groupName] ] ]);
        }
      }
    }
  }

  Util.embed('semgrex',
             {entity_types: entityTypes, relation_types: relationTypes},
             {text: currentText, entities: entities, relations: relations}
            );
}  // END renderSemgrex
|
| 967 |
+
|
| 968 |
+
/**
 * Render a Tregex response by pretty-printing the raw JSON payload
 * into the #tregex container.
 */
function renderTregex(data) {
  var container = $('#tregex');
  container.empty();
  var pretty = JSON.stringify(data, null, 4);
  container.append('<pre>' + pretty + '</pre>');
}  // END renderTregex
|
| 975 |
+
|
| 976 |
+
// ----------------------------------------------------------------------------
|
| 977 |
+
// MAIN
|
| 978 |
+
// ----------------------------------------------------------------------------
|
| 979 |
+
|
| 980 |
+
/**
|
| 981 |
+
* MAIN()
|
| 982 |
+
*
|
| 983 |
+
* The entry point of the page
|
| 984 |
+
*/
|
| 985 |
+
// Page entry point: wire up widgets and event handlers once the DOM is ready.
$(document).ready(function() {
  // Some initial styling
  $('.chosen-select').chosen();
  $('.chosen-container').css('width', '100%');


  // Language-specific changes: set text direction and a localized placeholder
  $('#language').on('change', function() {
    $('#text').attr('dir', '');
    // Right-to-left scripts
    if ($('#language').val() === 'ar' ||
        $('#language').val() === 'fa' ||
        $('#language').val() === 'he' ||
        $('#language').val() === 'ur') {
      $('#text').attr('dir', 'rtl');
    }
    if ($('#language').val() === 'ar') {
      $('#text').attr('placeholder', 'على سبيل المثال، قفز الثعلب البني السريع فوق الكلب الكسول.');
    } else if ($('#language').val() === 'en') {
      $('#text').attr('placeholder', 'e.g., The quick brown fox jumped over the lazy dog.');
    } else if ($('#language').val() === 'zh') {
      $('#text').attr('placeholder', '例如,快速的棕色狐狸跳过了懒惰的狗。');
    } else if ($('#language').val() === 'zh-Hant') {
      $('#text').attr('placeholder', '例如,快速的棕色狐狸跳過了懶惰的狗。');
    } else if ($('#language').val() === 'fr') {
      $('#text').attr('placeholder', 'Par exemple, le renard brun rapide a sauté sur le chien paresseux.');
    } else if ($('#language').val() === 'de') {
      $('#text').attr('placeholder', 'Z. B. sprang der schnelle braune Fuchs über den faulen Hund.');
    } else if ($('#language').val() === 'es') {
      $('#text').attr('placeholder', 'Por ejemplo, el rápido zorro marrón saltó sobre el perro perezoso.');
    } else if ($('#language').val() === 'ur') {
      $('#text').attr('placeholder', 'میرا نام علی ہے');
    } else {
      $('#text').attr('placeholder', 'Unknown language for placeholder query: ' + $('#language').val());
    }
  });
|
| 1020 |
+
|
| 1021 |
+
  // Submit on shift-enter: swallow the keydown so no newline is inserted,
  // then trigger the submit button when the key is released.
  $('#text').keydown(function (event) {
    if (event.keyCode == 13) {
      if(event.shiftKey){
        event.preventDefault(); // don't register the enter key when pressed
        return false;
      }
    }
  });
  $('#text').keyup(function (event) {
    if (event.keyCode == 13) {
      if(event.shiftKey){
        $('#submit').click(); // submit the form when the enter key is released
        event.stopPropagation();
        return false;
      }
    }
  });
|
| 1039 |
+
|
| 1040 |
+
  // Submit on clicking the 'submit' button: annotate the text (or a
  // per-language default sentence) via the server and render the results.
  $('#submit').click(function() {
    // Get the text to annotate; fall back to a canned example per language
    currentQuery = $('#text').val();
    if (currentQuery.trim() == '') {
      if ($('#language').val() === 'ar') {
        currentQuery = 'قفز الثعلب البني السريع فوق الكلب الكسول.';
      } else if ($('#language').val() === 'en') {
        currentQuery = 'The quick brown fox jumped over the lazy dog.';
      } else if ($('#language').val() === 'zh') {
        currentQuery = '快速的棕色狐狸跳过了懒惰的狗。';
      } else if ($('#language').val() === 'zh-Hant') {
        currentQuery = '快速的棕色狐狸跳過了懶惰的狗。';
      } else if ($('#language').val() === 'fr') {
        currentQuery = 'Le renard brun rapide a sauté sur le chien paresseux.';
      } else if ($('#language').val() === 'de') {
        currentQuery = 'Sprang der schnelle braune Fuchs über den faulen Hund.';
      } else if ($('#language').val() === 'es') {
        currentQuery = 'El rápido zorro marrón saltó sobre el perro perezoso.';
      } else if ($('#language').val() === 'ur') {
        currentQuery = 'میرا نام علی ہے';
      } else {
        currentQuery = 'Unknown language for default query: ' + $('#language').val();
      }
      $('#text').val(currentQuery);
    }
    // Update the UI
    $('#submit').prop('disabled', true);
    $('#annotations').hide();
    $('#patterns_row').hide();
    $('#loading').show();

    // Run query
    $.ajax({
      type: 'POST',
      url: serverAddress + '?properties=' + encodeURIComponent(
        '{"annotators": "' + annotators() + '", "date": "' + date() + '"}') +
        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),
      data: encodeURIComponent(currentQuery), //jQuery doesn't automatically URI encode strings
      dataType: 'json',
      contentType: "application/x-www-form-urlencoded;charset=UTF-8",
      responseType: "application/json",
      success: function(data) {
        $('#submit').prop('disabled', false);
        if (typeof data === 'undefined' || data.sentences == undefined) {
          alert("Failed to reach server!");
        } else {
          // Process constituency parse
          var constituencyParseProcessor = new ConstituencyParseProcessor();
          constituencyParseProcessor.process(data);
          // Empty divs
          $('#annotations').empty();
          // Re-render divs: create one annotation div per requested+present annotator
          function createAnnotationDiv(id, annotator, selector, label) {
            // (make sure we requested that element)
            if (annotators().split(",").indexOf(annotator) < 0) {
              return;
            }
            // (make sure the data contains that element)
            // NOTE(review): 'ok' is assigned without var and leaks as a global.
            ok = false;
            if (typeof data[selector] !== 'undefined') {
              ok = true;
            } else if (typeof data.sentences !== 'undefined' && data.sentences.length > 0) {
              if (typeof data.sentences[0][selector] !== 'undefined') {
                ok = true;
              } else if (typeof data.sentences[0].tokens != 'undefined' && data.sentences[0].tokens.length > 0) {
                // (make sure the annotator select is in at least one of the tokens of any sentence)
                ok = data.sentences.some(function(sentence) {
                  return sentence.tokens.some(function(token) {
                    return typeof token[selector] !== 'undefined';
                  });
                });
              }
            }
            // (render the element)
            if (ok) {
              $('#annotations').append('<h4 class="red">' + label + ':</h4> <div id="' + id + '"></div>');
            }
          }
          // (create the divs)
          //                  div id     annotator   field_in_data           label
          createAnnotationDiv('pos',     'pos',      'pos',                  'Part-of-Speech (XPOS)'   );
          createAnnotationDiv('upos',    'upos',     'upos',                 'Universal Part-of-Speech');
          createAnnotationDiv('lemma',   'lemma',    'lemma',                'Lemmas'                  );
          createAnnotationDiv('ner',     'ner',      'ner',                  'Named Entity Recognition');
          createAnnotationDiv('deps',    'depparse', 'basicDependencies',    'Universal Dependencies'  );
          createAnnotationDiv('parse',   'parse',    'parseTree',            'Constituency Parse'      );
          //createAnnotationDiv('deps2', 'depparse', 'enhancedPlusPlusDependencies', 'Enhanced++ Dependencies' );
          //createAnnotationDiv('openie', 'openie', 'openie', 'Open IE' );
          //createAnnotationDiv('coref', 'coref', 'corefs', 'Coreference' );
          //createAnnotationDiv('entities', 'entitylink', 'entitylink', 'Wikidict Entities' );
          //createAnnotationDiv('kbp', 'kbp', 'kbp', 'KBP Relations' );
          //createAnnotationDiv('sentiment','sentiment', 'sentiment', 'Sentiment' );
          // Update UI
          $('#loading').hide();
          $('.corenlp_error').remove(); // Clear error messages
          $('#annotations').show();
          // Render (right-to-left languages get reversed token order)
          var reverse = ($('#language').val() === 'ar' || $('#language').val() === 'fa' || $('#language').val() === 'he' || $('#language').val() === 'ur');
          render(data, reverse);
          // Render patterns
          //$('#annotations').append('<h4 class="red" style="margin-top: 4ex;">CoreNLP Tools:</h4>'); // TODO(gabor) a strange place to add this header to
          //$('#patterns_row').show();
        }
      },
      error: function(data) {
        // NOTE(review): DATA is an implicit global, apparently kept for debugging.
        DATA = data;
        var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('corenlp_error').attr('role', 'alert')
        var button = $('<button type="button" class="close" data-dismiss="alert" aria-label="Close"><span aria-hidden="true">×</span></button>');
        var message = $('<span/>').text(data.responseText);
        button.appendTo(alertDiv);
        message.appendTo(alertDiv);
        $('#loading').hide();
        alertDiv.appendTo($('#errors'));
        $('#submit').prop('disabled', false);
      }
    });
    // NOTE(review): the click callback takes no parameter, so this relies on
    // the non-standard global window.event — confirm intended.
    event.preventDefault();
    event.stopPropagation();
    return false;
  });
|
| 1161 |
+
|
| 1162 |
+
|
| 1163 |
+
  // Support passing parameters on page launch, via window.location.hash parameters.
  // Example: http://localhost:9000/#text=foo%20bar&annotators=pos,lemma,ner
  (function() {
    // Parse "key=value" pairs out of the URL hash
    var rawParams = window.location.hash.slice(1).split("&");
    var params = {};
    rawParams.forEach(function(paramKV) {
      paramKV = paramKV.split("=");
      if (paramKV.length === 2) {
        var key = paramKV[0];
        var value = paramKV[1];
        params[key] = value;
      }
    });
    // Pre-fill the query text
    if (params.text) {
      var text = decodeURIComponent(params.text);
      $('#text').val(text);
    }
    // Pre-select the requested annotators
    if (params.annotators) {
      var annotators = params.annotators.split(",");
      // De-select everything
      $('#annotators').find('option').each(function() {
        $(this).prop('selected', false);
      });
      // Select the specified ones.
      annotators.forEach(function(a) {
        $('#annotators').find('option[value="'+a+'"]').prop('selected', true);
      });
      // Refresh Chosen
      $('#annotators').trigger('chosen:updated');
    }
    if (params.text || params.annotators) {
      // Finally, let's auto-submit.
      $('#submit').click();
    }
  })();
|
| 1198 |
+
|
| 1199 |
+
|
| 1200 |
+
  // TokensRegex form: search the annotated text for a token pattern.
  $('#form_tokensregex').submit( function (e) {
    // Don't actually submit the form
    e.preventDefault();
    // Get text (use an example pattern when the field is empty)
    if ($('#tokensregex_search').val().trim() == '') {
      $('#tokensregex_search').val('(?$foxtype [{pos:JJ}]+ ) fox');
    }
    var pattern = $('#tokensregex_search').val();
    // Remove existing annotation
    $('#tokensregex').remove();
    // Make ajax call
    $.ajax({
      type: 'POST',
      // '&' and '+' are escaped so the pattern survives URI decoding
      url: serverAddress + '/tokensregex?pattern=' + encodeURIComponent(
        pattern.replace("&", "\\&").replace('+', '\\+')) +
        '&properties=' + encodeURIComponent(
        '{"annotators": "' + annotators() + '", "date": "' + date() + '"}') +
        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),
      data: encodeURIComponent(currentQuery),
      success: function(data) {
        $('.tokensregex_error').remove(); // Clear error messages
        $('<div id="tokensregex" class="pattern_brat"/>').appendTo($('#div_tokensregex'));
        renderTokensregex(data);
      },
      error: function(data) {
        // Show a dismissible bootstrap alert with the server's message
        var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tokensregex_error').attr('role', 'alert')
        var button = $('<button type="button" class="close" data-dismiss="alert" aria-label="Close"><span aria-hidden="true">×</span></button>');
        var message = $('<span/>').text(data.responseText);
        button.appendTo(alertDiv);
        message.appendTo(alertDiv);
        alertDiv.appendTo($('#div_tokensregex'));
      }
    });
  });
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
  // Semgrex form: search the dependency graph for a semgrex pattern.
  $('#form_semgrex').submit( function (e) {
    // Don't actually submit the form
    e.preventDefault();
    // Get text (use an example pattern when the field is empty)
    if ($('#semgrex_search').val().trim() == '') {
      $('#semgrex_search').val('{pos:/VB.*/} >nsubj {}=subject >/nmod:.*/ {}=prep_phrase');
    }
    var pattern = $('#semgrex_search').val();
    // Remove existing annotation
    $('#semgrex').remove();
    // Add missing required annotators (semgrex needs a dependency parse)
    var requiredAnnotators = annotators().split(',');
    if (requiredAnnotators.indexOf('depparse') < 0) {
      requiredAnnotators.push('depparse');
    }
    // Make ajax call
    $.ajax({
      type: 'POST',
      // '&' and '+' are escaped so the pattern survives URI decoding
      url: serverAddress + '/semgrex?pattern=' + encodeURIComponent(
        pattern.replace("&", "\\&").replace('+', '\\+')) +
        '&properties=' + encodeURIComponent(
        '{"annotators": "' + requiredAnnotators.join(',') + '", "date": "' + date() + '"}') +
        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),
      data: encodeURIComponent(currentQuery),
      success: function(data) {
        $('.semgrex_error').remove(); // Clear error messages
        $('<div id="semgrex" class="pattern_brat"/>').appendTo($('#div_semgrex'));
        renderSemgrex(data);
      },
      error: function(data) {
        // Show a dismissible bootstrap alert with the server's message
        var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('semgrex_error').attr('role', 'alert')
        var button = $('<button type="button" class="close" data-dismiss="alert" aria-label="Close"><span aria-hidden="true">×</span></button>');
        var message = $('<span/>').text(data.responseText);
        button.appendTo(alertDiv);
        message.appendTo(alertDiv);
        alertDiv.appendTo($('#div_semgrex'));
      }
    });
  });
|
| 1275 |
+
|
| 1276 |
+
  // Tregex form: search the constituency tree for a tregex pattern.
  $('#form_tregex').submit( function (e) {
    // Don't actually submit the form
    e.preventDefault();
    // Get text (use an example pattern when the field is empty)
    if ($('#tregex_search').val().trim() == '') {
      $('#tregex_search').val('NP < NN=animal');
    }
    var pattern = $('#tregex_search').val();
    // Remove existing annotation
    $('#tregex').remove();
    // Add missing required annotators (tregex needs a constituency parse)
    var requiredAnnotators = annotators().split(',');
    if (requiredAnnotators.indexOf('parse') < 0) {
      requiredAnnotators.push('parse');
    }
    // Make ajax call
    $.ajax({
      type: 'POST',
      // '&' and '+' are escaped so the pattern survives URI decoding
      url: serverAddress + '/tregex?pattern=' + encodeURIComponent(
        pattern.replace("&", "\\&").replace('+', '\\+')) +
        '&properties=' + encodeURIComponent(
        '{"annotators": "' + requiredAnnotators.join(',') + '", "date": "' + date() + '"}') +
        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),
      data: encodeURIComponent(currentQuery),
      success: function(data) {
        $('.tregex_error').remove(); // Clear error messages
        $('<div id="tregex" class="pattern_brat"/>').appendTo($('#div_tregex'));
        renderTregex(data);
      },
      error: function(data) {
        // Show a dismissible bootstrap alert with the server's message
        var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tregex_error').attr('role', 'alert')
        var button = $('<button type="button" class="close" data-dismiss="alert" aria-label="Close"><span aria-hidden="true">×</span></button>');
        var message = $('<span/>').text(data.responseText);
        button.appendTo(alertDiv);
        message.appendTo(alertDiv);
        alertDiv.appendTo($('#div_tregex'));
      }
    });
  });

});
|
stanza/stanza/pipeline/external/corenlp_converter_depparse.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A depparse processor which converts constituency trees using CoreNLP
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from stanza.pipeline._constants import TOKENIZE, CONSTITUENCY, DEPPARSE
|
| 6 |
+
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
|
| 7 |
+
from stanza.server.dependency_converter import DependencyConverter
|
| 8 |
+
|
| 9 |
+
@register_processor_variant(DEPPARSE, 'converter')
class ConverterDepparse(ProcessorVariant):
    """Depparse variant converting constituency trees to dependencies via CoreNLP."""

    # Processors that must run before this variant can operate.
    REQUIRES_DEFAULT = {TOKENIZE, CONSTITUENCY}

    def __init__(self, config):
        """Start a CoreNLP dependency converter pipe.

        Raises ValueError for non-English pipelines, since the CoreNLP
        converter only supports English.
        """
        lang = config['lang']
        if lang != 'en':
            raise ValueError("Constituency to dependency converter only works for English")

        # TODO: get classpath from config
        # TODO: close this when finished?
        # a more involved approach would be to turn the Pipeline into
        # a context with __enter__ and __exit__
        # __exit__ would try to free all resources, although some
        # might linger such as GPU allocations
        # maybe it isn't worth even trying to clean things up on account of that
        self.converter = DependencyConverter(classpath="$CLASSPATH")
        self.converter.open_pipe()

    def process(self, document):
        """Run the CoreNLP converter over the document and return its result."""
        return self.converter.process(document)
|
stanza/stanza/pipeline/external/jieba.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Processors related to Jieba in the pipeline.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
from stanza.models.common import doc
|
| 8 |
+
from stanza.pipeline._constants import TOKENIZE
|
| 9 |
+
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
|
| 10 |
+
|
| 11 |
+
def check_jieba():
    """
    Import necessary components from Jieba to perform tokenization.

    Returns True when jieba is importable; otherwise raises ImportError
    with installation instructions, chained to the original failure so
    the real cause stays visible in the traceback.
    """
    try:
        import jieba
    except ImportError as err:
        raise ImportError(
            "Jieba is used but not installed on your machine. Go to https://pypi.org/project/jieba/ for installation instructions."
        ) from err
    return True
|
| 22 |
+
|
| 23 |
+
@register_processor_variant(TOKENIZE, 'jieba')
class JiebaTokenizer(ProcessorVariant):
    # Sentence-final punctuation that triggers a sentence break.
    _SENT_FINAL = ('。', '!', '?', '!', '?')

    def __init__(self, config):
        """ Construct a Jieba-based tokenizer by loading the Jieba pipeline.

        Note that this tokenizer uses regex for sentence segmentation.
        """
        if config['lang'] not in ['zh', 'zh-hans', 'zh-hant']:
            raise Exception("Jieba tokenizer is currently only allowed in Chinese (simplified or traditional) pipelines.")

        check_jieba()
        import jieba
        self.nlp = jieba
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the Jieba tokenizer and wrap the results into a Doc object.
        """
        text = document.text if isinstance(document, doc.Document) else document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the Jieba tokenizer.")

        sentences = []
        current_sentence = []
        offset = 0
        for token in self.nlp.cut(text, cut_all=False):
            if re.match(r'\s+', token):
                # Whitespace is not a token; just advance the character offset
                offset += len(token)
                continue

            end_offset = offset + len(token)
            current_sentence.append({
                doc.TEXT: token,
                doc.MISC: f"{doc.START_CHAR}={offset}|{doc.END_CHAR}={end_offset}"
            })
            offset = end_offset

            # Break the sentence on sentence-final punctuation unless disabled
            if not self.no_ssplit and token in self._SENT_FINAL:
                sentences.append(current_sentence)
                current_sentence = []

        if current_sentence:
            sentences.append(current_sentence)

        return doc.Document(sentences, text)
|
stanza/stanza/pipeline/external/sudachipy.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Processors related to SudachiPy in the pipeline.
|
| 3 |
+
|
| 4 |
+
GitHub Home: https://github.com/WorksApplications/SudachiPy
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
from stanza.models.common import doc
|
| 10 |
+
from stanza.pipeline._constants import TOKENIZE
|
| 11 |
+
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
|
| 12 |
+
|
| 13 |
+
def check_sudachipy():
    """
    Import necessary components from SudachiPy to perform tokenization.

    Returns True when both sudachipy and sudachidict_core are importable;
    otherwise raises ImportError with installation instructions, chained to
    the original failure so the real cause stays visible in the traceback.
    """
    try:
        import sudachipy
        import sudachidict_core
    except ImportError as err:
        raise ImportError(
            "Both sudachipy and sudachidict_core libraries are required. "
            "Try install them with `pip install sudachipy sudachidict_core`. "
            "Go to https://github.com/WorksApplications/SudachiPy for more information."
        ) from err
    return True
|
| 27 |
+
|
| 28 |
+
@register_processor_variant(TOKENIZE, 'sudachipy')
class SudachiPyTokenizer(ProcessorVariant):
    def __init__(self, config):
        """ Construct a SudachiPy-based tokenizer.

        Note that this tokenizer uses regex for sentence segmentation.
        """
        if config['lang'] != 'ja':
            raise Exception("SudachiPy tokenizer is only allowed in Japanese pipelines.")

        check_sudachipy()
        from sudachipy import tokenizer
        from sudachipy import dictionary

        self.tokenizer = dictionary.Dictionary().create()
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the SudachiPy tokenizer and wrap the results into a Doc object.
        """
        text = document.text if isinstance(document, doc.Document) else document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the SudachiPy tokenizer.")

        # we use the default sudachipy tokenization mode (i.e., mode C)
        # more config needs to be added to support other modes
        sentences = []
        current_sentence = []
        for token in self.tokenizer.tokenize(text):
            surface = token.surface()
            # by default sudachipy will output whitespace as a token
            # we need to skip these tokens to be consistent with other tokenizers
            if surface.isspace():
                continue

            current_sentence.append({
                doc.TEXT: surface,
                doc.MISC: f"{doc.START_CHAR}={token.begin()}|{doc.END_CHAR}={token.end()}"
            })

            # Break the sentence on sentence-final punctuation unless disabled
            if not self.no_ssplit and surface in ('。', '!', '?', '!', '?'):
                sentences.append(current_sentence)
                current_sentence = []

        if current_sentence:
            sentences.append(current_sentence)

        return doc.Document(sentences, text)
|
stanza/stanza/utils/charlm/oscar_to_text.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Turns an Oscar 2022 jsonl file to text
|
| 3 |
+
|
| 4 |
+
YOU DO NOT NEED THIS if you use the oscar extractor which reads from
|
| 5 |
+
HuggingFace, dump_oscar.py
|
| 6 |
+
|
| 7 |
+
to run:
|
| 8 |
+
python3 -m stanza.utils.charlm.oscar_to_text <path> ...
|
| 9 |
+
|
| 10 |
+
each path can be a file or a directory with multiple .jsonl files in it
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import glob
|
| 15 |
+
import json
|
| 16 |
+
import lzma
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
from stanza.models.common.utils import open_read_text
|
| 20 |
+
|
| 21 |
+
def extract_file(output_directory, input_filename, use_xz):
    """Convert one Oscar jsonl file to a text file (optionally xz-compressed).

    Each json line's 'content' field is written out followed by a blank line.
    When output_directory is None, the output lands next to the input file.
    """
    print("Extracting %s" % input_filename)
    input_dir, output_filename = os.path.split(input_filename)
    if output_directory is None:
        output_directory = input_dir

    # foo.jsonl(.*) -> foo.txt ; anything else just gets .txt appended
    json_idx = output_filename.rfind(".jsonl")
    if json_idx >= 0:
        output_filename = output_filename[:json_idx]
    output_filename = output_filename + ".txt"
    if use_xz:
        output_filename += ".xz"
        open_file = lambda x: lzma.open(x, "wt", encoding="utf-8")
    else:
        open_file = lambda x: open(x, "w", encoding="utf-8")

    output_filename = os.path.join(output_directory, output_filename)
    print("Writing content to %s" % output_filename)
    with open_read_text(input_filename) as fin:
        with open_file(output_filename) as fout:
            for line in fin:
                content = json.loads(line)['content']
                fout.write(content)
                fout.write("\n\n")
|
| 49 |
+
|
| 50 |
+
def parse_args():
    """Parse the command line arguments for this script."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--output", default=None, help="Output directory for saving files. If None, will write to the original directory")
    argparser.add_argument("--no_xz", default=True, dest="xz", action="store_false", help="Don't use xz to compress the output files")
    argparser.add_argument("filenames", nargs="+", help="Filenames or directories to process")
    return argparser.parse_args()
|
| 57 |
+
|
| 58 |
+
def main():
    """
    Go through each of the given filenames or directories, convert json to .txt.xz
    """
    args = parse_args()
    if args.output is not None:
        os.makedirs(args.output, exist_ok=True)
    for filename in args.filenames:
        if os.path.isfile(filename):
            # Single file: convert it directly
            extract_file(args.output, filename, args.xz)
        elif os.path.isdir(filename):
            # Directory: convert every jsonl-ish file inside it
            found = glob.glob(os.path.join(filename, "*jsonl*"))
            found = sorted(x for x in found if os.path.isfile(x))
            print("Found %d files:" % len(found))
            if len(found) > 0:
                print(" %s" % "\n ".join(found))
            for json_filename in found:
                extract_file(args.output, json_filename, args.xz)

if __name__ == "__main__":
    main()
|
stanza/stanza/utils/constituency/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/constituency/grep_test_logs.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scan constituency test logs for an F1 score line and report per-file scores plus the average."""

import subprocess
import sys

log_files = sys.argv[1:]

scores = []
for log_file in log_files:
    # use grep rather than reading the file in python - matches the original tooling
    result = subprocess.run(["grep", "F1 score.*test.*", log_file],
                            stdout=subprocess.PIPE, encoding="utf-8")
    matched = result.stdout.strip()
    if not matched:
        print("{}: no result".format(log_file))
        continue

    # the score is the last whitespace-separated token of the matched line(s)
    score = float(matched.split()[-1])
    print("{}: {}".format(log_file, score))
    scores.append(score)

if scores:
    print("Avg: {}".format(sum(scores) / len(scores)))
|
stanza/stanza/utils/datasets/constituency/build_silver_dataset.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Given two ensembles and a tokenized file, output the trees for which those ensembles agree and report how many of the sub-models agree on those trees.
|
| 3 |
+
|
| 4 |
+
For example:
|
| 5 |
+
|
| 6 |
+
python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_AA.txt --lang it --output_file asdf.out --e1 saved_models/constituency/it_vit_electra_100?_top_constituency.pt --e2 saved_models/constituency/it_vit_electra_100?_constituency.pt
|
| 7 |
+
|
| 8 |
+
for i in `echo f g h i j k l m n o p q r s t`; do nlprun -d a6000 "python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tok_6M_a$i.txt --lang it --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a$i.trees --e1 saved_models/constituency/it_vit_electra_100?_top_constituency.pt --e2 saved_models/constituency/it_vit_electra_100?_constituency.pt" -o /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a$i.out; done
|
| 9 |
+
|
| 10 |
+
for i in `echo a b c d`; do nlprun -d a6000 "python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/english/en_wiki_2023/shuf_1M.a$i --lang en --output_file /u/nlp/data/constituency-parser/english/2024_en_ptb3_electra/forward_a$i.trees --e1 saved_models/constituency/en_ptb3_electra-large_100?_in_constituency.pt --e2 saved_models/constituency/en_ptb3_electra-large_100?_top_constituency.pt" -o /u/nlp/data/constituency-parser/english/2024_en_ptb3_electra/forward_a$i.out; done
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
from stanza.models.common import utils
|
| 19 |
+
from stanza.models.common.foundation_cache import FoundationCache
|
| 20 |
+
from stanza.models.constituency import retagging
|
| 21 |
+
from stanza.models.constituency import text_processing
|
| 22 |
+
from stanza.models.constituency import tree_reader
|
| 23 |
+
from stanza.models.constituency.ensemble import Ensemble
|
| 24 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 25 |
+
|
| 26 |
+
tqdm = get_tqdm()
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger('stanza.constituency.trainer')
|
| 29 |
+
|
| 30 |
+
def parse_args(args=None):
    """Build the argument dict for the silver dataset builder.

    args: an optional list of command line tokens; when None, sys.argv is used.
          (Bug fix: previously this parameter was accepted but ignored, so
          programmatic argument lists were silently dropped.)

    Returns a dict of options, postprocessed for the retagging module, with
    num_generate forced to 0.
    """
    parser = argparse.ArgumentParser(description="Script that uses multiple ensembles to find trees where both ensembles agree")

    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
    input_group.add_argument('--tree_file', type=str, default=None, help='Input file of already parsed text for reparsing with parse_text.')
    parser.add_argument('--output_file', type=str, default=None, help='Where to put the output file')

    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')

    utils.add_device_args(parser)

    parser.add_argument('--lang', default='en', help='Language to use')

    parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
    parser.add_argument('--e1', type=str, nargs='+', default=None, help="Which model(s) to load in the first ensemble")
    parser.add_argument('--e2', type=str, nargs='+', default=None, help="Which model(s) to load in the second ensemble")

    parser.add_argument('--mode', default='predict', choices=['parse_text', 'predict'])

    # another option would be to include the tree idx in each entry in an existing saved file
    # the processing could then pick up at exactly the last known idx
    parser.add_argument('--start_tree', type=int, default=0, help='Where to start... most useful if the previous incarnation crashed')
    parser.add_argument('--end_tree', type=int, default=None, help='Where to end. If unset, will process to the end of the file')

    retagging.add_retag_args(parser)

    # pass the caller-provided token list through; argparse falls back to
    # sys.argv[1:] when args is None, so the default behavior is unchanged
    args = vars(parser.parse_args(args))

    retagging.postprocess_args(args)
    args['num_generate'] = 0

    return args
|
| 65 |
+
|
| 66 |
+
def main():
    """Parse text with two ensembles and save the trees on which both agree.

    Each output line is json containing the tree text and the number of
    individual sub-models (from both ensembles combined) which reproduce
    that exact tree.
    """
    args = parse_args()
    utils.log_training_args(args, logger, name="ensemble")

    retag_pipeline = retagging.build_retag_pipeline(args)
    if retag_pipeline:
        foundation_cache = retag_pipeline[0].foundation_cache
    else:
        foundation_cache = FoundationCache()

    logger.info("Building ensemble #1 out of %s", args['e1'])
    first_ensemble = Ensemble(args, filenames=args['e1'], foundation_cache=foundation_cache)
    first_ensemble.to(args.get('device', None))
    logger.info("Building ensemble #2 out of %s", args['e2'])
    second_ensemble = Ensemble(args, filenames=args['e2'], foundation_cache=foundation_cache)
    second_ensemble.to(args.get('device', None))

    if args['tokenized_file']:
        sentences = text_processing.read_tokenized_file(args['tokenized_file'])
    elif args['tree_file']:
        treebank = tree_reader.read_treebank(args['tree_file'])
        sentences = [tree.leaf_labels() for tree in treebank]
        if args['lang'] == 'vi':
            # Vietnamese leaves use _ to join multi-syllable words; undo that
            sentences = [[word.replace("_", " ") for word in sentence] for sentence in sentences]
    logger.info("Read %d tokenized sentences", len(sentences))

    all_models = first_ensemble.models + second_ensemble.models

    chunk_size = 1000
    with open(args['output_file'], 'w', encoding='utf-8') as fout:
        last_tree = len(sentences) if args['end_tree'] is None else args['end_tree']
        for chunk_start in tqdm(range(args['start_tree'], last_tree, chunk_size)):
            chunk = sentences[chunk_start:chunk_start+chunk_size]
            logger.info("Processing trees %d to %d", chunk_start, chunk_start+len(chunk))
            # parse the chunk with each ensemble, keeping only the top prediction
            first_trees = text_processing.parse_tokenized_sentences(args, first_ensemble, retag_pipeline, chunk)
            first_trees = [parsed.predictions[0].tree for parsed in first_trees]
            second_trees = text_processing.parse_tokenized_sentences(args, second_ensemble, retag_pipeline, chunk)
            second_trees = [parsed.predictions[0].tree for parsed in second_trees]
            matching = [t1 for t1, t2 in zip(first_trees, second_trees) if t1 == t2]
            logger.info("%d trees matched", len(matching))
            # count how many of the individual sub-models reproduce each matched tree
            model_counts = [0 for _ in matching]
            for model in all_models:
                reparsed = model.parse_sentences_no_grad(iter(matching), model.build_batch_from_trees, args['eval_batch_size'], model.predict)
                reparsed = [parsed.predictions[0].tree for parsed in reparsed]
                for idx, (agreed, hypothesis) in enumerate(zip(matching, reparsed)):
                    if agreed == hypothesis:
                        model_counts[idx] += 1
            for count, tree in zip(model_counts, matching):
                line = {"tree": "%s" % tree, "count": count}
                fout.write(json.dumps(line))
                fout.write("\n")


if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/constituency/convert_cintil.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import xml.etree.ElementTree as ET
|
| 2 |
+
|
| 3 |
+
from stanza.models.constituency import tree_reader
|
| 4 |
+
from stanza.utils.datasets.constituency import utils
|
| 5 |
+
|
| 6 |
+
def read_xml_file(input_filename):
    """
    Convert the CINTIL xml file to id & text

    Returns a list of tuples: (id, text)

    Raises ValueError if the file does not have the expected structure,
    including when any sentence is missing its id, raw, or tree node.
    """
    with open(input_filename, encoding="utf-8") as fin:
        dataset = ET.parse(fin)
    dataset = dataset.getroot()
    corpus = dataset.find("{http://nlx.di.fc.ul.pt}corpus")
    # fix: compare explicitly against None.  An Element with no children is
    # falsy, so "if not corpus" would wrongly report an empty corpus as missing
    if corpus is None:
        raise ValueError("Unexpected dataset structure : no 'corpus'")
    trees = []
    for sentence in corpus:
        if sentence.tag != "{http://nlx.di.fc.ul.pt}sentence":
            raise ValueError("Unexpected sentence tag: {}".format(sentence.tag))
        id_node = None
        raw_node = None
        # fix: this was previously assigned to a misspelled name (tree_nodde),
        # which caused an UnboundLocalError instead of the intended ValueError
        # for a sentence with no tree node
        tree_node = None
        for node in sentence:
            if node.tag == '{http://nlx.di.fc.ul.pt}id':
                id_node = node
            elif node.tag == '{http://nlx.di.fc.ul.pt}raw':
                raw_node = node
            elif node.tag == '{http://nlx.di.fc.ul.pt}tree':
                tree_node = node
            else:
                raise ValueError("Unexpected tag in sentence {}: {}".format(sentence, node.tag))
        if id_node is None or raw_node is None or tree_node is None:
            raise ValueError("Missing node in sentence {}".format(sentence))
        tree_id = "".join(id_node.itertext())
        tree_text = "".join(tree_node.itertext())
        trees.append((tree_id, tree_text))
    return trees
|
| 40 |
+
|
| 41 |
+
def convert_cintil_treebank(input_filename, train_size=0.8, dev_size=0.1):
    """
    dev_size is the size for splitting train & dev
    """
    raw_trees = read_xml_file(input_filename)

    synthetic_trees = []
    natural_trees = []
    for tree_id, tree_text in raw_trees:
        if tree_text.find(" _") >= 0:
            raise ValueError("Unexpected underscore")
        cleaned = tree_text.replace("_)", ")")
        cleaned = cleaned.replace("(A (", "(A' (")
        # trees don't have ROOT, but we typically use a ROOT label at the top
        cleaned = "(ROOT %s)" % cleaned
        parsed = tree_reader.read_trees(cleaned)
        if len(parsed) != 1:
            raise ValueError("Unexpectedly found %d trees in %s" % (len(parsed), tree_id))
        tree = parsed[0]
        # "aTSTS" ids mark synthetic sentences; any other TSTS is unexpected
        if tree_id.startswith("aTSTS"):
            synthetic_trees.append(tree)
        elif tree_id.find("TSTS") >= 0:
            raise ValueError("Unexpected TSTS")
        else:
            natural_trees.append(tree)

    print("Read %d synthetic trees" % len(synthetic_trees))
    print("Read %d natural trees" % len(natural_trees))
    train_trees, dev_trees, test_trees = utils.split_treebank(natural_trees, train_size, dev_size)
    print("Split %d trees into %d train %d dev %d test" % (len(natural_trees), len(train_trees), len(dev_trees), len(test_trees)))
    # synthetic trees only go to the training split
    train_trees = synthetic_trees + train_trees
    print("Total lengths %d train %d dev %d test" % (len(train_trees), len(dev_trees), len(test_trees)))
    return train_trees, dev_trees, test_trees
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main():
    """Convert the CINTIL treebank from its default (Stanford-internal) location.

    The returned splits were previously bound to an unused local; the call is
    kept purely for its conversion side effects and console output.
    """
    convert_cintil_treebank("extern_data/constituency/portuguese/CINTIL/CINTIL-Treebank.xml")

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/constituency/count_common_words.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
from collections import Counter
|
| 4 |
+
|
| 5 |
+
from stanza.models.constituency import parse_tree
|
| 6 |
+
from stanza.models.constituency import tree_reader
|
| 7 |
+
|
| 8 |
+
word_counter = Counter()
|
| 9 |
+
count_words = lambda x: word_counter.update(x.leaf_labels())
|
| 10 |
+
|
| 11 |
+
tree_reader.read_tree_file(sys.argv[1], tree_callback=count_words)
|
| 12 |
+
print(word_counter.most_common()[:100])
|
stanza/stanza/utils/datasets/constituency/prepare_con_dataset.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Converts raw data files from their original format (dataset dependent) into PTB trees.
|
| 2 |
+
|
| 3 |
+
The operation of this script depends heavily on the dataset in question.
|
| 4 |
+
The common result is that the data files go to data/constituency and are in PTB format.
|
| 5 |
+
|
| 6 |
+
da_arboretum
|
| 7 |
+
Ekhard Bick
|
| 8 |
+
Arboretum, a Hybrid Treebank for Danish
|
| 9 |
+
https://www.researchgate.net/publication/251202293_Arboretum_a_Hybrid_Treebank_for_Danish
|
| 10 |
+
Available here for a license fee:
|
| 11 |
+
http://catalog.elra.info/en-us/repository/browse/ELRA-W0084/
|
| 12 |
+
Internal to Stanford, please contact Chris Manning and/or John Bauer
|
| 13 |
+
The file processed is the tiger xml, although there are some edits
|
| 14 |
+
needed in order to make it functional for our parser
|
| 15 |
+
The treebank comes as a tar.gz file, W0084.tar.gz
|
| 16 |
+
untar this file in $CONSTITUENCY_BASE/danish
|
| 17 |
+
then move the extracted folder to "arboretum"
|
| 18 |
+
$CONSTITUENCY_BASE/danish/W0084/... becomes
|
| 19 |
+
$CONSTITUENCY_BASE/danish/arboretum/...
|
| 20 |
+
|
| 21 |
+
en_ptb3-revised is an updated version of PTB with NML and stuff
|
| 22 |
+
put LDC2015T13 in $CONSTITUENCY_BASE/english
|
| 23 |
+
the directory name may look like LDC2015T13_eng_news_txt_tbnk-ptb_revised
|
| 24 |
+
python3 -m stanza.utils.datasets.constituency.prepare_con_dataset en_ptb3-revised
|
| 25 |
+
|
| 26 |
+
All this needs to do is concatenate the various pieces
|
| 27 |
+
|
| 28 |
+
@article{ptb_revised,
|
| 29 |
+
title= {Penn Treebank Revised: English News Text Treebank LDC2015T13},
|
| 30 |
+
journal= {},
|
| 31 |
+
author= {Ann Bies and Justin Mott and Colin Warner},
|
| 32 |
+
year= {2015},
|
| 33 |
+
url= {https://doi.org/10.35111/xpjy-at91},
|
| 34 |
+
doi= {10.35111/xpjy-at91},
|
| 35 |
+
isbn= {1-58563-724-6},
|
| 36 |
+
dcmi= {text},
|
| 37 |
+
languages= {english},
|
| 38 |
+
language= {english},
|
| 39 |
+
ldc= {LDC2015T13},
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
id_icon
|
| 43 |
+
ICON: Building a Large-Scale Benchmark Constituency Treebank
|
| 44 |
+
for the Indonesian Language
|
| 45 |
+
Ee Suan Lim, Wei Qi Leong, Ngan Thanh Nguyen, Dea Adhista,
|
| 46 |
+
Wei Ming Kng, William Chandra Tjhi, Ayu Purwarianti
|
| 47 |
+
https://aclanthology.org/2023.tlt-1.5.pdf
|
| 48 |
+
Available at https://github.com/aisingapore/seacorenlp-data
|
| 49 |
+
git clone the repo in $CONSTITUENCY_BASE/seacorenlp
|
| 50 |
+
so there is now a directory
|
| 51 |
+
$CONSTITUENCY_BASE/seacorenlp/seacorenlp-data
|
| 52 |
+
python3 -m stanza.utils.datasets.constituency.prepare_con_dataset id_icon
|
| 53 |
+
|
| 54 |
+
it_turin
|
| 55 |
+
A combination of Evalita competition from 2011 and the ParTUT trees
|
| 56 |
+
More information is available in convert_it_turin
|
| 57 |
+
|
| 58 |
+
it_vit
|
| 59 |
+
The original for the VIT UD Dataset
|
| 60 |
+
The UD version has a lot of corrections, so we try to apply those as much as possible
|
| 61 |
+
In fact, we applied some corrections of our own back to UD based on this treebank.
|
| 62 |
+
The first version which had those corrections is UD 2.10
|
| 63 |
+
Versions of UD before that won't work
|
| 64 |
+
Hopefully versions after that work
|
| 65 |
+
Set UDBASE to a path such that $UDBASE/UD_Italian-VIT is the UD version
|
| 66 |
+
The constituency labels are generally not very understandable, unfortunately
|
| 67 |
+
Some documentation is available here:
|
| 68 |
+
https://core.ac.uk/download/pdf/223148096.pdf
|
| 69 |
+
https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.423.5538&rep=rep1&type=pdf
|
| 70 |
+
Available from ELRA:
|
| 71 |
+
http://catalog.elra.info/en-us/repository/browse/ELRA-W0040/
|
| 72 |
+
|
| 73 |
+
ja_alt
|
| 74 |
+
Asian Language Treebank produced a treebank for Japanese:
|
| 75 |
+
Ye Kyaw Thu, Win Pa Pa, Masao Utiyama, Andrew Finch, Eiichiro Sumita
|
| 76 |
+
Introducing the Asian Language Treebank
|
| 77 |
+
http://www.lrec-conf.org/proceedings/lrec2016/pdf/435_Paper.pdf
|
| 78 |
+
Download
|
| 79 |
+
https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/Japanese-ALT-20210218.zip
|
| 80 |
+
unzip this in $CONSTITUENCY_BASE/japanese
|
| 81 |
+
this should create a directory $CONSTITUENCY_BASE/japanese/Japanese-ALT-20210218
|
| 82 |
+
In this directory, also download the following:
|
| 83 |
+
https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-train.txt
|
| 84 |
+
https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-dev.txt
|
| 85 |
+
https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-test.txt
|
| 86 |
+
In particular, there are two files with a bunch of bracketed parses,
|
| 87 |
+
Japanese-ALT-Draft.txt and Japanese-ALT-Reviewed.txt
|
| 88 |
+
The first word of each of these lines is SNT.80188.1 or something like that
|
| 89 |
+
This correlates with the three URL-... files, telling us whether the
|
| 90 |
+
sentence belongs in train/dev/test
|
| 91 |
+
python3 -m stanza.utils.datasets.constituency.prepare_con_dataset ja_alt
|
| 92 |
+
|
| 93 |
+
pt_cintil
|
| 94 |
+
CINTIL treebank for Portuguese, available at ELRA:
|
| 95 |
+
https://catalogue.elra.info/en-us/repository/browse/ELRA-W0055/
|
| 96 |
+
It can also be obtained from here:
|
| 97 |
+
https://hdl.handle.net/21.11129/0000-000B-D2FE-A
|
| 98 |
+
Produced at U Lisbon
|
| 99 |
+
António Branco; João Silva; Francisco Costa; Sérgio Castro
|
| 100 |
+
CINTIL TreeBank Handbook: Design options for the representation of syntactic constituency
|
| 101 |
+
Silva, João; António Branco; Sérgio Castro; Ruben Reis
|
| 102 |
+
Out-of-the-Box Robust Parsing of Portuguese
|
| 103 |
+
https://portulanclarin.net/repository/extradocs/CINTIL-Treebank.pdf
|
| 104 |
+
http://www.di.fc.ul.pt/~ahb/pubs/2011bBrancoSilvaCostaEtAl.pdf
|
| 105 |
+
If at Stanford, ask John Bauer or Chris Manning for the data
|
| 106 |
+
Otherwise, purchase it from ELRA or find it elsewhere if possible
|
| 107 |
+
Either way, unzip it in
|
| 108 |
+
$CONSTITUENCY_BASE/portuguese to the CINTIL directory
|
| 109 |
+
so for example, the final result might be
|
| 110 |
+
extern_data/constituency/portuguese/CINTIL/CINTIL-Treebank.xml
|
| 111 |
+
python3 -m stanza.utils.datasets.constituency.prepare_con_dataset pt_cintil
|
| 112 |
+
|
| 113 |
+
tr_starlang
|
| 114 |
+
A dataset in three parts from the Starlang group in Turkey:
|
| 115 |
+
Neslihan Kara, Büşra Marşan, et al
|
| 116 |
+
Creating A Syntactically Felicitous Constituency Treebank For Turkish
|
| 117 |
+
https://ieeexplore.ieee.org/document/9259873
|
| 118 |
+
git clone the following three repos
|
| 119 |
+
https://github.com/olcaytaner/TurkishAnnotatedTreeBank-15
|
| 120 |
+
https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-15
|
| 121 |
+
https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-20
|
| 122 |
+
Put them in
|
| 123 |
+
$CONSTITUENCY_BASE/turkish
|
| 124 |
+
python3 -m stanza.utils.datasets.constituency.prepare_con_dataset tr_starlang
|
| 125 |
+
|
| 126 |
+
vlsp09 is the 2009 constituency treebank:
|
| 127 |
+
Nguyen Phuong Thai, Vu Xuan Luong, Nguyen Thi Minh Huyen, Nguyen Van Hiep, Le Hong Phuong
|
| 128 |
+
Building a Large Syntactically-Annotated Corpus of Vietnamese
|
| 129 |
+
Proceedings of The Third Linguistic Annotation Workshop
|
| 130 |
+
In conjunction with ACL-IJCNLP 2009, Suntec City, Singapore, 2009
|
| 131 |
+
This can be obtained by contacting vlsp.resources@gmail.com
|
| 132 |
+
|
| 133 |
+
vlsp22 is the 2022 constituency treebank from the VLSP bakeoff
|
| 134 |
+
there is an official test set as well
|
| 135 |
+
you may be able to obtain both of these by contacting vlsp.resources@gmail.com
|
| 136 |
+
NGUYEN Thi Minh Huyen, HA My Linh, VU Xuan Luong, PHAN Thi Hue,
|
| 137 |
+
LE Van Cuong, NGUYEN Thi Luong, NGO The Quyen
|
| 138 |
+
VLSP 2022 Challenge: Vietnamese Constituency Parsing
|
| 139 |
+
to appear in Journal of Computer Science and Cybernetics.
|
| 140 |
+
|
| 141 |
+
vlsp23 is the 2023 update to the constituency treebank from the VLSP bakeoff
|
| 142 |
+
the vlsp22 code also works for the new dataset,
|
| 143 |
+
although some effort may be needed to update the tags
|
| 144 |
+
As of late 2024, the test set is available on request at vlsp.resources@gmail.com
|
| 145 |
+
Organize the directory
|
| 146 |
+
$CONSTITUENCY_BASE/vietnamese/VLSP_2023
|
| 147 |
+
$CONSTITUENCY_BASE/vietnamese/VLSP_2023/Trainingset
|
| 148 |
+
$CONSTITUENCY_BASE/vietnamese/VLSP_2023/test
|
| 149 |
+
|
| 150 |
+
zh_ctb-51 is the 5.1 version of CTB
|
| 151 |
+
put LDC2005T01U01_ChineseTreebank5.1 in $CONSTITUENCY_BASE/chinese
|
| 152 |
+
python3 -m stanza.utils.datasets.constituency.prepare_con_dataset zh_ctb-51
|
| 153 |
+
|
| 154 |
+
@article{xue_xia_chiou_palmer_2005,
|
| 155 |
+
title={The Penn Chinese TreeBank: Phrase structure annotation of a large corpus},
|
| 156 |
+
volume={11},
|
| 157 |
+
DOI={10.1017/S135132490400364X},
|
| 158 |
+
number={2},
|
| 159 |
+
journal={Natural Language Engineering},
|
| 160 |
+
publisher={Cambridge University Press},
|
| 161 |
+
author={XUE, NAIWEN and XIA, FEI and CHIOU, FU-DONG and PALMER, MARTA},
|
| 162 |
+
year={2005},
|
| 163 |
+
pages={207–238}}
|
| 164 |
+
|
| 165 |
+
zh_ctb-51b is the same dataset, but using a smaller dev/test set
|
| 166 |
+
in our experiments, this is substantially easier
|
| 167 |
+
|
| 168 |
+
zh_ctb-90 is the 9.0 version of CTB
|
| 169 |
+
put LDC2016T13 in $CONSTITUENCY_BASE/chinese
|
| 170 |
+
python3 -m stanza.utils.datasets.constituency.prepare_con_dataset zh_ctb-90
|
| 171 |
+
|
| 172 |
+
the splits used are the ones from the file docs/ctb9.0-file-list.txt
|
| 173 |
+
included in the CTB 9.0 release
|
| 174 |
+
|
| 175 |
+
SPMRL adds several treebanks
|
| 176 |
+
https://www.spmrl.org/
|
| 177 |
+
https://www.spmrl.org/sancl-posters2014.html
|
| 178 |
+
Currently only German is converted, the German version being a
|
| 179 |
+
version of the Tiger Treebank
|
| 180 |
+
python3 -m stanza.utils.datasets.constituency.prepare_con_dataset de_spmrl
|
| 181 |
+
|
| 182 |
+
en_mctb is a multidomain test set covering five domains other than newswire
|
| 183 |
+
https://github.com/RingoS/multi-domain-parsing-analysis
|
| 184 |
+
Challenges to Open-Domain Constituency Parsing
|
| 185 |
+
|
| 186 |
+
@inproceedings{yang-etal-2022-challenges,
|
| 187 |
+
title = "Challenges to Open-Domain Constituency Parsing",
|
| 188 |
+
author = "Yang, Sen and
|
| 189 |
+
Cui, Leyang and
|
| 190 |
+
Ning, Ruoxi and
|
| 191 |
+
Wu, Di and
|
| 192 |
+
Zhang, Yue",
|
| 193 |
+
booktitle = "Findings of the Association for Computational Linguistics: ACL 2022",
|
| 194 |
+
month = may,
|
| 195 |
+
year = "2022",
|
| 196 |
+
address = "Dublin, Ireland",
|
| 197 |
+
publisher = "Association for Computational Linguistics",
|
| 198 |
+
url = "https://aclanthology.org/2022.findings-acl.11",
|
| 199 |
+
doi = "10.18653/v1/2022.findings-acl.11",
|
| 200 |
+
pages = "112--127",
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
This conversion replaces the top bracket from top -> ROOT and puts an extra S
|
| 204 |
+
bracket on any roots with more than one node.
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
import argparse
|
| 208 |
+
import os
|
| 209 |
+
import random
|
| 210 |
+
import sys
|
| 211 |
+
import tempfile
|
| 212 |
+
|
| 213 |
+
from tqdm import tqdm
|
| 214 |
+
|
| 215 |
+
from stanza.models.constituency import parse_tree
|
| 216 |
+
import stanza.utils.default_paths as default_paths
|
| 217 |
+
from stanza.models.constituency import tree_reader
|
| 218 |
+
from stanza.models.constituency.parse_tree import Tree
|
| 219 |
+
from stanza.server import tsurgeon
|
| 220 |
+
from stanza.utils.datasets.common import UnknownDatasetError
|
| 221 |
+
from stanza.utils.datasets.constituency import utils
|
| 222 |
+
from stanza.utils.datasets.constituency.convert_alt import convert_alt
|
| 223 |
+
from stanza.utils.datasets.constituency.convert_arboretum import convert_tiger_treebank
|
| 224 |
+
from stanza.utils.datasets.constituency.convert_cintil import convert_cintil_treebank
|
| 225 |
+
import stanza.utils.datasets.constituency.convert_ctb as convert_ctb
|
| 226 |
+
from stanza.utils.datasets.constituency.convert_it_turin import convert_it_turin
|
| 227 |
+
from stanza.utils.datasets.constituency.convert_it_vit import convert_it_vit
|
| 228 |
+
from stanza.utils.datasets.constituency.convert_spmrl import convert_spmrl
|
| 229 |
+
from stanza.utils.datasets.constituency.convert_starlang import read_starlang
|
| 230 |
+
from stanza.utils.datasets.constituency.utils import SHARDS, write_dataset
|
| 231 |
+
import stanza.utils.datasets.constituency.vtb_convert as vtb_convert
|
| 232 |
+
import stanza.utils.datasets.constituency.vtb_split as vtb_split
|
| 233 |
+
|
| 234 |
+
def process_it_turin(paths, dataset_name, *args):
    """
    Convert the it_turin dataset
    """
    assert dataset_name == 'it_turin'
    source_dir = os.path.join(paths["CONSTITUENCY_BASE"], "italian")
    convert_it_turin(source_dir, paths["CONSTITUENCY_DATA_DIR"])
|
| 242 |
+
|
| 243 |
+
def process_it_vit(paths, dataset_name, *args):
    """Convert the it_vit dataset by delegating to convert_it_vit."""
    # needs at least UD 2.11 or this will not work
    # in the meantime, the git version of VIT will suffice
    assert dataset_name == 'it_vit'
    convert_it_vit(paths, dataset_name)
|
| 248 |
+
|
| 249 |
+
def process_vlsp09(paths, dataset_name, *args):
    """
    Processes the VLSP 2009 dataset, discarding or fixing trees when needed
    """
    assert dataset_name == 'vi_vlsp09'
    vlsp_path = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese", "VietTreebank_VLSP_SP73", "Kho ngu lieu 10000 cay cu phap")
    # convert into a scratch directory, then split into train/dev/test files
    with tempfile.TemporaryDirectory() as scratch_dir:
        vtb_convert.convert_dir(vlsp_path, scratch_dir)
        vtb_split.split_files(scratch_dir, paths["CONSTITUENCY_DATA_DIR"], dataset_name)
|
| 258 |
+
|
| 259 |
+
def process_vlsp21(paths, dataset_name, *args):
    """Convert the VLSP 2021 treebank, which ships as a single file.

    The data is split 90/10 into train/dev; an empty test file is written
    as a placeholder because no official test data exists for VLSP 21.
    """
    assert dataset_name == 'vi_vlsp21'
    vlsp_file = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese", "VLSP_2021", "VTB_VLSP21_tree.txt")
    if not os.path.exists(vlsp_file):
        raise FileNotFoundError("Could not find the 2021 dataset in the expected location of {} - CONSTITUENCY_BASE == {}".format(vlsp_file, paths["CONSTITUENCY_BASE"]))
    out_dir = paths["CONSTITUENCY_DATA_DIR"]
    with tempfile.TemporaryDirectory() as scratch_dir:
        vtb_convert.convert_files([vlsp_file], scratch_dir)
        # everything goes to train/dev; the test split produced here is empty
        vtb_split.split_files(scratch_dir, out_dir, dataset_name, train_size=0.9, dev_size=0.1)
        _, _, test_file = vtb_split.create_paths(out_dir, dataset_name)
        # truncate/create the placeholder test file
        with open(test_file, "w"):
            pass
|
| 275 |
+
|
| 276 |
+
def process_vlsp22(paths, dataset_name, *args):
    """
    Processes the VLSP 2022 dataset, which is four separate files for some reason

    Also handles VLSP 2023 (dispatched here via DATASET_MAPPING); the two
    releases differ in directory layout, tagset, and test file discovery.
    Extra command line arguments (subdir, splitting options, seed) are
    parsed with argparse from *args.
    """
    assert dataset_name == 'vi_vlsp22' or dataset_name == 'vi_vlsp23'

    if dataset_name == 'vi_vlsp22':
        default_subdir = 'VLSP_2022'
        default_make_test_split = False
        updated_tagset = False
    elif dataset_name == 'vi_vlsp23':
        default_subdir = os.path.join('VLSP_2023', 'Trainingdataset')
        default_make_test_split = False
        updated_tagset = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--subdir', default=default_subdir, type=str, help='Where to find the data - allows for using previous versions, if needed')
    parser.add_argument('--no_convert_brackets', default=True, action='store_false', dest='convert_brackets', help="Don't convert the VLSP parens RKBT & LKBT to PTB parens")
    parser.add_argument('--n_splits', default=None, type=int, help='Split the data into this many pieces. Relevant as there is no set training/dev split, so this allows for N models on N different dev sets')
    parser.add_argument('--test_split', default=default_make_test_split, action='store_true', help='Split 1/10th of the data as a test split as well. Useful for experimental results. Less relevant since there is now an official test set')
    parser.add_argument('--no_test_split', dest='test_split', action='store_false', help='Split 1/10th of the data as a test split as well. Useful for experimental results. Less relevant since there is now an official test set')
    parser.add_argument('--seed', default=1234, type=int, help='Random seed to use when splitting')
    args = parser.parse_args(args=list(*args))

    # --subdir may be an absolute/relative path, or a name under CONSTITUENCY_BASE/vietnamese
    if os.path.exists(args.subdir):
        vlsp_dir = args.subdir
    else:
        vlsp_dir = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese", args.subdir)
    if not os.path.exists(vlsp_dir):
        raise FileNotFoundError("Could not find the {} dataset in the expected location of {} - CONSTITUENCY_BASE == {}".format(dataset_name, vlsp_dir, paths["CONSTITUENCY_BASE"]))
    vlsp_files = os.listdir(vlsp_dir)
    vlsp_train_files = [os.path.join(vlsp_dir, x) for x in vlsp_files if x.startswith("file") and not x.endswith(".zip")]
    vlsp_train_files.sort()

    # test data lives in different places for the two releases
    if dataset_name == 'vi_vlsp22':
        vlsp_test_files = [os.path.join(vlsp_dir, x) for x in vlsp_files if x.startswith("private") and not x.endswith(".zip")]
    elif dataset_name == 'vi_vlsp23':
        vlsp_test_dir = os.path.abspath(os.path.join(vlsp_dir, os.pardir, "test"))
        vlsp_test_files = os.listdir(vlsp_test_dir)
        vlsp_test_files = [os.path.join(vlsp_test_dir, x) for x in vlsp_test_files if x.endswith(".csv")]

    if len(vlsp_train_files) == 0:
        raise FileNotFoundError("No train files (files starting with 'file') found in {}".format(vlsp_dir))
    if not args.test_split and len(vlsp_test_files) == 0:
        raise FileNotFoundError("No test files found in {}".format(vlsp_dir))
    print("Loading training files from {}".format(vlsp_dir))
    # typo fix: "Procesing" -> "Processing"
    print("Processing training files:\n {}".format("\n ".join(vlsp_train_files)))
    with tempfile.TemporaryDirectory() as train_output_path:
        vtb_convert.convert_files(vlsp_train_files, train_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset)
        # This produces a 0 length test set, just as a placeholder until the actual test set is released
        if args.n_splits:
            # N rotations: each gets 1/N of the data as dev, rest as train
            test_size = 0.1 if args.test_split else 0.0
            dev_size = (1.0 - test_size) / args.n_splits
            train_size = 1.0 - test_size - dev_size
            for rotation in range(args.n_splits):
                # there is a shuffle inside the split routine,
                # so we need to reset the random seed each time
                random.seed(args.seed)
                rotation_name = "%s-%d-%d" % (dataset_name, rotation, args.n_splits)
                if args.test_split:
                    rotation_name = rotation_name + "t"
                vtb_split.split_files(train_output_path, paths["CONSTITUENCY_DATA_DIR"], rotation_name, train_size=train_size, dev_size=dev_size, rotation=(rotation, args.n_splits))
        else:
            test_size = 0.1 if args.test_split else 0.0
            dev_size = 0.1
            train_size = 1.0 - test_size - dev_size
            if args.test_split:
                # "t" suffix marks datasets carved with a held-out test split
                dataset_name = dataset_name + "t"
            vtb_split.split_files(train_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=train_size, dev_size=dev_size)

    # when we did not carve our own test split, convert the official test files
    if not args.test_split:
        print("Processing test files:\n {}".format("\n ".join(vlsp_test_files)))
        with tempfile.TemporaryDirectory() as test_output_path:
            vtb_convert.convert_files(vlsp_test_files, test_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset)
            if args.n_splits:
                for rotation in range(args.n_splits):
                    rotation_name = "%s-%d-%d" % (dataset_name, rotation, args.n_splits)
                    vtb_split.split_files(test_output_path, paths["CONSTITUENCY_DATA_DIR"], rotation_name, train_size=0, dev_size=0)
            else:
                vtb_split.split_files(test_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=0, dev_size=0)
    # for vlsp23, additionally write a test set that keeps the original tree ids
    if not args.test_split and not args.n_splits and dataset_name == 'vi_vlsp23':
        print("Processing test files and keeping ids:\n {}".format("\n ".join(vlsp_test_files)))
        with tempfile.TemporaryDirectory() as test_output_path:
            vtb_convert.convert_files(vlsp_test_files, test_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset, write_ids=True)
            vtb_split.split_files(test_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name + "-ids", train_size=0, dev_size=0)
|
| 361 |
+
|
| 362 |
+
def process_arboretum(paths, dataset_name, *args):
    """Convert the Danish Arboretum treebank from TIGER format.

    Writes the complete treebank to <dataset>.mrg and an 80/10/10
    train/dev/test split via write_dataset.
    """
    assert dataset_name == 'da_arboretum'

    source_file = os.path.join(paths["CONSTITUENCY_BASE"], "danish", "arboretum",
                               "arboretum.tiger", "arboretum.tiger")
    if not os.path.exists(source_file):
        raise FileNotFoundError("Unable to find input file for Arboretum. Expected in {}".format(source_file))

    treebank = convert_tiger_treebank(source_file)
    splits = utils.split_treebank(treebank, 0.8, 0.1)

    out_dir = paths["CONSTITUENCY_DATA_DIR"]
    combined_file = os.path.join(out_dir, "%s.mrg" % dataset_name)
    print("Writing {} trees to {}".format(len(treebank), combined_file))
    parse_tree.Tree.write_treebank(treebank, combined_file)

    write_dataset(splits, out_dir, dataset_name)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def process_starlang(paths, dataset_name, *args):
    """Convert the Turkish Starlang treebank pieces to bracketed trees."""
    assert dataset_name == 'tr_starlang'

    pieces = ("TurkishAnnotatedTreeBank-15",
              "TurkishAnnotatedTreeBank2-15",
              "TurkishAnnotatedTreeBank2-20")

    chunk_dirs = [os.path.join(paths["CONSTITUENCY_BASE"], "turkish", piece)
                  for piece in pieces]
    splits = read_starlang(chunk_dirs)
    write_dataset(splits, paths["CONSTITUENCY_DATA_DIR"], dataset_name)
|
| 398 |
+
|
| 399 |
+
def process_ja_alt(paths, dataset_name, *args):
    """Convert and split the Japanese portion of the ALT corpus.

    TODO: could theoretically extend this to MY or any other similar dataset from ALT
    """
    lang, source = dataset_name.split("_", 1)
    assert lang == 'ja'
    assert source == 'alt'

    src_dir = os.path.join(paths["CONSTITUENCY_BASE"], "japanese", "Japanese-ALT-20210218")
    tree_files = [os.path.join(src_dir, name)
                  for name in ("Japanese-ALT-Draft.txt", "Japanese-ALT-Reviewed.txt")]
    # one URL list per shard defines the train/dev/test membership
    shard_files = [os.path.join(src_dir, "URL-%s.txt" % shard) for shard in SHARDS]
    out_files = [os.path.join(paths["CONSTITUENCY_DATA_DIR"], "%s_%s.mrg" % (dataset_name, shard))
                 for shard in SHARDS]
    convert_alt(tree_files, shard_files, out_files)
|
| 416 |
+
|
| 417 |
+
def process_pt_cintil(paths, dataset_name, *args):
    """Convert and split the Portuguese CINTIL treebank."""
    lang, source = dataset_name.split("_", 1)
    assert lang == 'pt'
    assert source == 'cintil'

    xml_file = os.path.join(paths["CONSTITUENCY_BASE"], "portuguese", "CINTIL", "CINTIL-Treebank.xml")
    splits = convert_cintil_treebank(xml_file)
    write_dataset(splits, paths["CONSTITUENCY_DATA_DIR"], dataset_name)
|
| 430 |
+
|
| 431 |
+
def process_id_icon(paths, dataset_name, *args):
    """Convert the Indonesian ICON treebank from the seacorenlp data release.

    The released data is already split into train/dev/test; every tree is
    wrapped in a ROOT node before writing.
    """
    lang, source = dataset_name.split("_", 1)
    assert lang == 'id'
    assert source == 'icon'

    src_dir = os.path.join(paths["CONSTITUENCY_BASE"], "seacorenlp", "seacorenlp-data", "id", "constituency")
    splits = []
    for split_file in ("train.txt", "dev.txt", "test.txt"):
        raw_trees = tree_reader.read_tree_file(os.path.join(src_dir, split_file))
        splits.append([Tree("ROOT", tree) for tree in raw_trees])

    write_dataset(splits, paths["CONSTITUENCY_DATA_DIR"], dataset_name)
|
| 446 |
+
|
| 447 |
+
def process_ctb_51(paths, dataset_name, *args):
    """
    Convert Chinese Treebank 5.1 (LDC2005T01U01) using the V51 split.
    """
    lang, source = dataset_name.split("_", 1)
    assert lang == 'zh-hans'
    assert source == 'ctb-51'

    input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese", "LDC2005T01U01_ChineseTreebank5.1", "bracketed")
    output_dir = paths["CONSTITUENCY_DATA_DIR"]
    # fail early with a clear message, matching process_ctb_51b
    if not os.path.exists(input_dir):
        raise FileNotFoundError("CTB 5.1 location not found: %s" % input_dir)
    convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V51)
|
| 455 |
+
|
| 456 |
+
def process_ctb_51b(paths, dataset_name, *args):
    """Convert Chinese Treebank 5.1 using the alternate V51b split."""
    lang, source = dataset_name.split("_", 1)
    assert lang == 'zh-hans'
    assert source == 'ctb-51b'

    bracketed_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese",
                                 "LDC2005T01U01_ChineseTreebank5.1", "bracketed")
    out_dir = paths["CONSTITUENCY_DATA_DIR"]
    if not os.path.exists(bracketed_dir):
        raise FileNotFoundError("CTB 5.1 location not found: %s" % bracketed_dir)
    print("Loading trees from %s" % bracketed_dir)
    convert_ctb.convert_ctb(bracketed_dir, out_dir, dataset_name, convert_ctb.Version.V51b)
|
| 467 |
+
|
| 468 |
+
def process_ctb_90(paths, dataset_name, *args):
    """
    Convert Chinese Treebank 9.0 (LDC2016T13) using the V90 split.
    """
    lang, source = dataset_name.split("_", 1)
    assert lang == 'zh-hans'
    assert source == 'ctb-90'

    input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese", "LDC2016T13", "ctb9.0", "data", "bracketed")
    output_dir = paths["CONSTITUENCY_DATA_DIR"]
    # fail early with a clear message, matching process_ctb_51b
    if not os.path.exists(input_dir):
        raise FileNotFoundError("CTB 9.0 location not found: %s" % input_dir)
    convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V90)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def process_ptb3_revised(paths, dataset_name, *args):
    """
    Convert the revised English PTB (LDC2015T13) into train/dev/test files.

    Uses the standard WSJ split: sections 02-21 train, 22 dev, 23 test.
    """
    input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "english", "LDC2015T13_eng_news_txt_tbnk-ptb_revised")
    if not os.path.exists(input_dir):
        # fall back to a directory named after the bare LDC catalog number
        backup_input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "english", "LDC2015T13")
        if not os.path.exists(backup_input_dir):
            raise FileNotFoundError("Could not find ptb3-revised in either %s or %s" % (input_dir, backup_input_dir))
        input_dir = backup_input_dir

    bracket_dir = os.path.join(input_dir, "data", "penntree")
    output_dir = paths["CONSTITUENCY_DATA_DIR"]

    # compensate for a weird mislabeling in the original dataset
    label_map = {"ADJ-PRD": "ADJP-PRD"}

    # sections 02..21 form the training split
    train_trees = []
    for i in tqdm(range(2, 22)):
        new_trees = tree_reader.read_directory(os.path.join(bracket_dir, "%02d" % i))
        new_trees = [t.remap_constituent_labels(label_map) for t in new_trees]
        train_trees.extend(new_trees)

    # NOTE(review): this tregex/tsurgeon pair appears to relocate a
    # sentence-final punctuation child of the root into the preceding
    # subtree — confirm against the tregex/tsurgeon documentation
    move_tregex = "_ROOT_ <1 __=home <2 /^[.]$/=move"
    move_tsurgeon = "move move >-1 home"

    print("Moving sentence final punctuation if necessary")
    with tsurgeon.Tsurgeon() as tsurgeon_processor:
        train_trees = [tsurgeon_processor.process(tree, move_tregex, move_tsurgeon)[0] for tree in tqdm(train_trees)]

    # section 22: dev
    dev_trees = tree_reader.read_directory(os.path.join(bracket_dir, "22"))
    dev_trees = [t.remap_constituent_labels(label_map) for t in dev_trees]

    # section 23: test
    test_trees = tree_reader.read_directory(os.path.join(bracket_dir, "23"))
    test_trees = [t.remap_constituent_labels(label_map) for t in test_trees]
    print("Read %d train trees, %d dev trees, and %d test trees" % (len(train_trees), len(dev_trees), len(test_trees)))
    datasets = [train_trees, dev_trees, test_trees]
    write_dataset(datasets, output_dir, dataset_name)
|
| 513 |
+
|
| 514 |
+
def process_en_mctb(paths, dataset_name, *args):
    """
    Converts the English multi-domain MCTB test sets.

    Converts the following blocks:

    dialogue.cleaned.txt forum.cleaned.txt law.cleaned.txt literature.cleaned.txt review.cleaned.txt

    Each domain is written as its own <dataset>-<domain>_test.mrg file.
    """
    base_path = os.path.join(paths["CONSTITUENCY_BASE"], "english", "multi-domain-parsing-analysis", "data", "MCTB_en")
    if not os.path.exists(base_path):
        raise FileNotFoundError("Please download multi-domain-parsing-analysis to %s" % base_path)
    def tree_callback(tree):
        # a multi-child top gets wrapped in S then ROOT; otherwise the
        # children are reattached directly under a new ROOT node
        if len(tree.children) > 1:
            tree = parse_tree.Tree("S", tree.children)
            return parse_tree.Tree("ROOT", [tree])
        return parse_tree.Tree("ROOT", tree.children)

    filenames = ["dialogue.cleaned.txt", "forum.cleaned.txt", "law.cleaned.txt", "literature.cleaned.txt", "review.cleaned.txt"]
    for filename in filenames:
        trees = tree_reader.read_tree_file(os.path.join(base_path, filename), tree_callback=tree_callback)
        print("%d trees in %s" % (len(trees), filename))
        # e.g. en_mctb-dialogue_test.mrg
        output_filename = "%s-%s_test.mrg" % (dataset_name, filename.split(".")[0])
        output_filename = os.path.join(paths["CONSTITUENCY_DATA_DIR"], output_filename)
        print("Writing trees to %s" % output_filename)
        parse_tree.Tree.write_treebank(trees, output_filename)
|
| 537 |
+
|
| 538 |
+
def process_spmrl(paths, dataset_name, *args):
    """Convert the German SPMRL 2014 shared-task treebank.

    Only de_spmrl is supported; any other SPMRL dataset raises ValueError.
    """
    if dataset_name != 'de_spmrl':
        raise ValueError("SPMRL dataset %s currently not supported" % dataset_name)

    spmrl_dir = os.path.join(paths["CONSTITUENCY_BASE"], "spmrl",
                             "SPMRL_SHARED_2014", "GERMAN_SPMRL", "gold", "ptb")
    convert_spmrl(spmrl_dir, paths["CONSTITUENCY_DATA_DIR"], dataset_name)
|
| 546 |
+
|
| 547 |
+
# Maps dataset name -> conversion handler.
# Every handler takes (paths, dataset_name, *args); any extra command line
# arguments are forwarded unchanged.  Entries are grouped by language.
DATASET_MAPPING = {
    'da_arboretum': process_arboretum,

    'de_spmrl': process_spmrl,

    'en_ptb3-revised': process_ptb3_revised,
    'en_mctb': process_en_mctb,

    'id_icon': process_id_icon,

    'it_turin': process_it_turin,
    'it_vit': process_it_vit,

    'ja_alt': process_ja_alt,

    'pt_cintil': process_pt_cintil,

    'tr_starlang': process_starlang,

    'vi_vlsp09': process_vlsp09,
    'vi_vlsp21': process_vlsp21,
    'vi_vlsp22': process_vlsp22,
    'vi_vlsp23': process_vlsp22, # options allow for this

    'zh-hans_ctb-51': process_ctb_51,
    'zh-hans_ctb-51b': process_ctb_51b,
    'zh-hans_ctb-90': process_ctb_90,
}
|
| 575 |
+
|
| 576 |
+
def main(dataset_name, *args):
    """Dispatch conversion of one dataset by name, forwarding extra args."""
    paths = default_paths.get_default_paths()
    random.seed(1234)

    handler = DATASET_MAPPING.get(dataset_name)
    if handler is None:
        raise UnknownDatasetError(dataset_name, f"dataset {dataset_name} currently not handled by prepare_con_dataset")
    handler(paths, dataset_name, *args)
|
| 585 |
+
|
| 586 |
+
if __name__ == '__main__':
    # with no arguments, list the known dataset names instead of failing
    if len(sys.argv) == 1:
        print("Known datasets:")
        for key in DATASET_MAPPING:
            print(" %s" % key)
    else:
        # NOTE: sys.argv[2:] is passed as a single list; main forwards it
        # via *args, and handlers such as process_vlsp22 unwrap it with
        # list(*args) before argparse parsing
        main(sys.argv[1], sys.argv[2:])
|
| 593 |
+
|
| 594 |
+
|
stanza/stanza/utils/datasets/constituency/silver_variance.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Use the concepts in "Dataset Cartography" and "Mind Your Outliers" to find trees with the least variance over a training run
|
| 3 |
+
|
| 4 |
+
https://arxiv.org/pdf/2009.10795.pdf
|
| 5 |
+
https://arxiv.org/abs/2107.02331
|
| 6 |
+
|
| 7 |
+
The idea here is that high variance trees are more likely to be wrong in the first place. Using this will filter a silver dataset to have better trees.
|
| 8 |
+
|
| 9 |
+
for example:
|
| 10 |
+
|
| 11 |
+
nlprun -d a6000 -p high "export CLASSPATH=/sailhome/horatio/CoreNLP/classes:/sailhome/horatio/CoreNLP/lib/*:$CLASSPATH; python3 stanza/utils/datasets/constituency/silver_variance.py --eval_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_0.mrg saved_models/constituency/it_vit.top.each.silver0.constituency_0*0.pt --output_file filtered_silver0.mrg" -o filter.out
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
import numpy
|
| 19 |
+
|
| 20 |
+
from stanza.models.common import utils
|
| 21 |
+
from stanza.models.common.foundation_cache import FoundationCache
|
| 22 |
+
from stanza.models.constituency import retagging
|
| 23 |
+
from stanza.models.constituency import tree_reader
|
| 24 |
+
from stanza.models.constituency.parser_training import run_dev_set
|
| 25 |
+
from stanza.models.constituency.trainer import Trainer
|
| 26 |
+
from stanza.models.constituency.utils import retag_trees
|
| 27 |
+
from stanza.server.parser_eval import EvaluateParser
|
| 28 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 29 |
+
|
| 30 |
+
tqdm = get_tqdm()
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger('stanza.constituency.trainer')
|
| 33 |
+
|
| 34 |
+
def parse_args(args=None):
    """Build and parse command line options for the variance filter.

    Args:
        args: optional list of argument strings; None means sys.argv.
    Returns:
        dict of parsed options, postprocessed by the retagging module.
    """
    parser = argparse.ArgumentParser(description="Script to filter trees by how much variance they show over multiple checkpoints of a parser training run.")

    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--output_file', type=str, default=None, help='Output file after sorting by variance.')

    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')

    utils.add_device_args(parser)

    # TODO: use the training scripts to pick the charlm & pretrain if needed
    parser.add_argument('--lang', default='it', help='Language to use')

    parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
    parser.add_argument('models', type=str, nargs='+', default=None, help="Which model(s) to load")

    parser.add_argument('--keep', type=float, default=0.5, help="How many trees to keep after sorting by variance")
    parser.add_argument('--reverse', default=False, action='store_true', help='Actually, keep the high variance trees')

    retagging.add_retag_args(parser)

    # bug fix: the `args` parameter was previously ignored and sys.argv was
    # always parsed; passing it through preserves the None/sys.argv default
    args = vars(parser.parse_args(args))

    retagging.postprocess_args(args)

    return args
|
| 62 |
+
|
| 63 |
+
def main():
    """Score every tree with each model checkpoint, then keep the trees
    whose per-tree F1 has the lowest variance across checkpoints (or the
    highest, with --reverse), writing them to --output_file.
    """
    args = parse_args()
    retag_pipeline = retagging.build_retag_pipeline(args)
    # reuse the retagger's foundation cache when one exists so the models
    # loaded below share embeddings / charlms
    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()

    print("Analyzing with the following models:\n " + "\n ".join(args['models']))

    treebank = tree_reader.read_treebank(args['eval_file'])
    logger.info("Read %d trees for analysis", len(treebank))

    # one row per model checkpoint, one column per tree
    f1_history = []
    retagged_treebank = None

    # evaluate in chunks to bound memory on large silver treebanks
    chunk_size = 5000
    with EvaluateParser() as evaluator:
        for model_filename in args['models']:
            print("Starting processing with %s" % model_filename)
            trainer = Trainer.load(model_filename, args=args, foundation_cache=foundation_cache)
            # retag once, lazily, using the settings stored in the first model
            if retag_pipeline is not None and retagged_treebank is None:
                retag_method = trainer.model.args['retag_method']
                retag_xpos = trainer.model.args['retag_xpos']
                logger.info("Retagging trees using the %s tags from the %s package...", retag_method, args['retag_package'])
                retagged_treebank = retag_trees(treebank, retag_pipeline, retag_xpos)
                logger.info("Retagging finished")

            current_history = []
            for chunk_start in range(0, len(treebank), chunk_size):
                chunk = treebank[chunk_start:chunk_start+chunk_size]
                retagged_chunk = retagged_treebank[chunk_start:chunk_start+chunk_size] if retagged_treebank else None
                # treeF1 holds one score per tree in the chunk
                f1, kbestF1, treeF1 = run_dev_set(trainer.model, retagged_chunk, chunk, args, evaluator)
                current_history.extend(treeF1)

            f1_history.append(current_history)

    # variance across checkpoints for each tree; sort (variance, index)
    # ascending so the most stable trees come first (--reverse flips this)
    f1_history = numpy.array(f1_history)
    f1_variance = numpy.var(f1_history, axis=0)
    f1_sorted = sorted([(x, idx) for idx, x in enumerate(f1_variance)], reverse=args['reverse'])

    # keep the requested fraction of trees, in variance order
    num_keep = int(len(f1_sorted) * args['keep'])
    with open(args['output_file'], "w", encoding="utf-8") as fout:
        for _, idx in f1_sorted[:num_keep]:
            fout.write(str(treebank[idx]))
            fout.write("\n")

if __name__ == "__main__":
    main()
|
stanza/stanza/utils/datasets/coref/convert_hindi.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from operator import itemgetter
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import stanza
|
| 7 |
+
|
| 8 |
+
from stanza.utils.default_paths import get_default_paths
|
| 9 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 10 |
+
from stanza.utils.datasets.coref.utils import process_document
|
| 11 |
+
|
| 12 |
+
tqdm = get_tqdm()
|
| 13 |
+
|
| 14 |
+
def flatten_spans(coref_spans):
    """Tag each mention with its cluster id and flatten into one sorted list.

    Input: a list of clusters, each cluster a list of [start, end] word spans,
    e.g. [[[38, 39], [42, 43]], [[60, 68], ...]].
    Output: a single list of [cluster_id, start, end] triples ordered by the
    start word index, e.g. [[0, 38, 39], [0, 42, 43], [1, 60, 68], ...].
    """
    tagged = []
    for cluster_id, cluster in enumerate(coref_spans):
        for start, end in cluster:
            tagged.append([cluster_id, start, end])
    # stable sort on the document-level start index
    tagged.sort(key=itemgetter(1))
    return tagged
|
| 32 |
+
|
| 33 |
+
def remove_nulls(coref_spans, sentences):
    """Drop empty-string and "NULL" tokens from the sentences.

    Also, reindex the spans by the number of words removed.
    So, we might get something like
    [[0, 2], [31, 33], [134, 136], [161, 162]]
    ->
    [[0, 2], [30, 32], [129, 131], [155, 156]]

    Returns (new_spans, new_sentences).
    """
    # old_to_new[i] == number of kept words strictly before position i,
    # i.e. the new index of word i when that word is kept
    old_to_new = []
    kept = 0
    filtered_sentences = []
    for sentence in sentences:
        kept_words = []
        for word in sentence:
            old_to_new.append(kept)
            if word not in ('', 'NULL'):
                kept_words.append(word)
                kept += 1
        filtered_sentences.append(kept_words)

    remapped = [[[old_to_new[idx] for idx in span] for span in mention]
                for mention in coref_spans]
    return remapped, filtered_sentences
|
| 65 |
+
|
| 66 |
+
def arrange_spans_by_sentence(coref_spans, sentences):
    """Bucket document-level spans by sentence with sentence-relative indices.

    coref_spans must be [cluster_id, start, end] triples sorted by start
    (document-level word numbering).  Returns one list per sentence, each
    span reindexed relative to that sentence's first word.
    """
    per_sentence = []
    offset = 0
    pending = iter(coref_spans)
    nxt = next(pending, None)
    for sentence in sentences:
        limit = offset + len(sentence)
        bucket = []
        # consume every span that starts before this sentence ends
        while nxt is not None and nxt[1] < limit:
            cluster_id, start, end = nxt
            bucket.append([cluster_id, start - offset, end - offset])
            nxt = next(pending, None)
        per_sentence.append(bucket)
        offset = limit
    return per_sentence
|
| 81 |
+
|
| 82 |
+
def convert_dataset_section(pipe, section, use_cconj_heads):
    """
    Reprocess the original data into a format compatible with previous conversion utilities

    - remove blank and NULL words
    - rearrange the spans into spans per sentence instead of a list of indices for each span
    - process the document using a Hindi pipeline

    Args:
        pipe: a stanza Pipeline used to annotate each document
        section: list of raw documents with doc_key/sentences/speakers/clusters
        use_cconj_heads: forwarded to process_document
    Returns:
        list of processed documents
    """
    processed_section = []

    # cleanup: dropped an unused enumerate() index from the original loop
    for doc in tqdm(section):
        doc_id = doc['doc_key']
        part_id = ""
        sentences = doc['sentences']
        sentence_speakers = doc['speakers']

        coref_spans = doc['clusters']
        coref_spans, sentences = remove_nulls(coref_spans, sentences)
        coref_spans = flatten_spans(coref_spans)
        coref_spans = arrange_spans_by_sentence(coref_spans, sentences)

        processed = process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=use_cconj_heads)
        processed_section.append(processed)
    return processed_section
|
| 106 |
+
|
| 107 |
+
def remove_nulls_dataset_section(section):
    """Strip NULL/blank tokens from every document in a dataset section.

    Each document's 'sentences' and 'clusters' entries are rewritten in
    place; the same document objects are returned as a new list.
    """
    cleaned = []
    for document in section:
        clusters, sentences = remove_nulls(document['clusters'], document['sentences'])
        document['sentences'] = sentences
        document['clusters'] = clusters
        cleaned.append(document)
    return cleaned
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def read_json_file(filename):
    """Read a jsonlines file, skipping blank lines; return a list of objects."""
    docs = []
    with open(filename, encoding="utf-8") as fin:
        for raw_line in fin:
            raw_line = raw_line.strip()
            if raw_line:
                docs.append(json.loads(raw_line))
    return docs
|
| 128 |
+
|
| 129 |
+
def write_json_file(output_filename, converted_section):
    """Serialize the converted section to output_filename as indented JSON."""
    with open(output_filename, "w", encoding="utf-8") as fout:
        fout.write(json.dumps(converted_section, indent=2))
|
| 132 |
+
|
| 133 |
+
def main():
    """Convert the Hindi coref dataset to the format used by Stanza's coref model.

    Two modes:
      default: run each section through a Hindi stanza pipeline and write the
        fully converted .json files to COREF_DATA_DIR
      --remove_nulls: only strip NULL/blank tokens and rewrite the jsonlines
        next to the original dataset (no pipeline needed)
    """
    parser = argparse.ArgumentParser(
        prog='Convert Hindi Coref Data',
    )
    parser.add_argument('--no_use_cconj_heads', dest='use_cconj_heads', action='store_false', help="Don't use the conjunction-aware transformation")
    parser.add_argument('--remove_nulls', action='store_true', help="The only action is to remove the NULLs and blank tokens")
    args = parser.parse_args()

    paths = get_default_paths()
    coref_input_path = paths["COREF_BASE"]
    hindi_base_path = os.path.join(coref_input_path, "hindi", "dataset")

    sections = ("train", "dev", "test")
    if args.remove_nulls:
        # NULL-stripping only: output stays next to the raw dataset, one JSON doc per line
        for section in sections:
            input_filename = os.path.join(hindi_base_path, "%s.hindi.jsonlines" % section)
            dataset = read_json_file(input_filename)
            dataset = remove_nulls_dataset_section(dataset)
            output_filename = os.path.join(hindi_base_path, "hi_deeph.%s.nonulls.json" % section)
            with open(output_filename, "w", encoding="utf-8") as fout:
                for doc in dataset:
                    # ensure_ascii=False keeps the Devanagari text readable
                    json.dump(doc, fout, ensure_ascii=False)
                    fout.write("\n")
    else:
        # pretokenized pipeline: the dataset already provides the tokenization
        pipe = stanza.Pipeline("hi", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True)

        os.makedirs(paths["COREF_DATA_DIR"], exist_ok=True)

        for section in sections:
            input_filename = os.path.join(hindi_base_path, "%s.hindi.jsonlines" % section)
            dataset = read_json_file(input_filename)

            output_filename = os.path.join(paths["COREF_DATA_DIR"], "hi_deeph.%s.json" % section)
            converted_section = convert_dataset_section(pipe, dataset, use_cconj_heads=args.use_cconj_heads)
            write_json_file(output_filename, converted_section)

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/ner/compare_entities.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Report the fraction of NER entities in one file which are present in another.
|
| 3 |
+
|
| 4 |
+
Purpose: show the coverage of one file on another, such as reporting
|
| 5 |
+
the number of entities in one dataset on another
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
|
| 11 |
+
from stanza.utils.datasets.ner.utils import read_json_entities
|
| 12 |
+
|
| 13 |
+
def parse_args():
    """Parse the --train / --test file lists for the coverage report."""
    parser = argparse.ArgumentParser(description="Report the coverage of one NER file on another.")
    parser.add_argument('--train', type=str, nargs="+", required=True, help='File to use to collect the known entities (not necessarily train).')
    parser.add_argument('--test', type=str, nargs="+", required=True, help='File for which we want to know the ratio of known entities')
    return parser.parse_args()
|
| 19 |
+
|
| 20 |
+
def report_known_entities(train_file, test_file):
    """Print the fraction of test-file entity texts also present in train_file."""
    known_texts = {entity[0] for entity in read_json_entities(train_file)}
    test_entities = read_json_entities(test_file)

    matched = sum(entity[0] in known_texts for entity in test_entities)
    print(train_file, test_file, matched / len(test_entities))
|
| 27 |
+
|
| 28 |
+
def main():
    """Report coverage for each (train, test) file pair."""
    args = parse_args()

    for train_idx, train_file in enumerate(args.train):
        # blank line between reports for successive train files
        if train_idx:
            print()
        for test_file in args.test:
            report_known_entities(train_file, test_file)

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/ner/conll_to_iob.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Process a conll file into BIO
|
| 3 |
+
|
| 4 |
+
Includes the ability to process a file from a text file
|
| 5 |
+
or a text file within a zip
|
| 6 |
+
|
| 7 |
+
Main program extracts a piece of the zip file from the Danish DDT dataset
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import io
|
| 11 |
+
import zipfile
|
| 12 |
+
from zipfile import ZipFile
|
| 13 |
+
from stanza.utils.conll import CoNLL
|
| 14 |
+
|
| 15 |
+
def process_conll(input_file, output_file, zip_file=None, conversion=None, attr_prefix="name", allow_empty=False):
    """
    Convert a CoNLL-U file with NER annotations in the MISC column into two-column BIO.

    input_file: which piece to read (possibly a member of zip_file)
    output_file: where to write the result
    zip_file: optional zip archive containing input_file
    conversion: optional dict (entity label remap) or callable applied to non-O tags
    attr_prefix: which attribute of the MISC field holds the NER tag
    allow_empty: if True, tokens without the attribute become "O" instead of raising

    Raises ValueError when a token has no NER attribute and allow_empty is False.
    """
    if not attr_prefix.endswith("="):
        attr_prefix = attr_prefix + "="

    doc = CoNLL.conll2doc(input_file=input_file, zip_file=zip_file)

    with open(output_file, "w", encoding="utf-8") as fout:
        for sentence_idx, sentence in enumerate(doc.sentences):
            for token_idx, token in enumerate(sentence.tokens):
                # token.misc can be None when a token has no MISC annotations at all;
                # previously this crashed with AttributeError.  Treat it the same as a
                # MISC field that simply lacks the NER attribute.
                misc = token.misc.split("|") if token.misc else []
                for attr in misc:
                    if attr.startswith(attr_prefix):
                        ner = attr.split("=", 1)[1]
                        break
                else: # name= not found
                    if allow_empty:
                        ner = "O"
                    else:
                        raise ValueError("Could not find ner tag in document {}, sentence {}, token {}".format(input_file, sentence_idx, token_idx))

                if ner != "O" and conversion is not None:
                    if isinstance(conversion, dict):
                        # remap only the label, keeping the B-/I- prefix
                        bio, label = ner.split("-", 1)
                        if label in conversion:
                            label = conversion[label]
                            ner = "%s-%s" % (bio, label)
                    else:
                        ner = conversion(ner)
                fout.write("%s\t%s\n" % (token.text, ner))
            # blank line separates sentences in BIO format
            fout.write("\n")
|
| 54 |
+
|
| 55 |
+
def main():
    """Extract the training split of the Danish DDT zip and write it as BIO."""
    process_conll(input_file="ddt.train.conllu",
                  output_file="data/ner/da_ddt.train.bio",
                  zip_file="extern_data/ner/da_ddt/ddt.zip")

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/ner/convert_bn_daffodil.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convert a Bengali NER dataset to our internal .json format
|
| 3 |
+
|
| 4 |
+
The dataset is here:
|
| 5 |
+
|
| 6 |
+
https://github.com/Rifat1493/Bengali-NER/tree/master/Input
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
import random
|
| 12 |
+
import tempfile
|
| 13 |
+
|
| 14 |
+
from stanza.utils.datasets.ner.utils import read_tsv, write_dataset
|
| 15 |
+
|
| 16 |
+
def redo_time_tags(sentences):
    """
    Replace runs of bare TIM tags with B-TIM, I-TIM, I-TIM...

    A brief use of Google Translate suggests the time phrases are
    generally one phrase, so a whole run becomes one span rather than
    repeated B-TIM tags.
    """
    relabeled = []
    for sentence in sentences:
        updated = []
        in_time_run = False
        for word, tag in sentence:
            if tag != 'TIM':
                in_time_run = False
                updated.append((word, tag))
            elif in_time_run:
                updated.append((word, "I-TIM"))
            else:
                in_time_run = True
                updated.append((word, "B-TIM"))
        relabeled.append(updated)
    return relabeled
|
| 41 |
+
|
| 42 |
+
def strip_words(dataset):
    """Trim surrounding whitespace and drop BOM characters from every word."""
    cleaned = []
    for sentence in dataset:
        cleaned.append([(word.strip().replace('\ufeff', ''), tag) for word, tag in sentence])
    return cleaned
|
| 44 |
+
|
| 45 |
+
def filter_blank_words(train_file, train_filtered_file):
    """
    As of July 2022, this dataset has blank words with O labels, which is not ideal

    Copies train_file to train_filtered_file, dropping those lines.
    A line whose stripped content is exactly 'O' had a blank word.
    """
    with open(train_file, encoding="utf-8") as fin:
        with open(train_filtered_file, "w", encoding="utf-8") as fout:
            fout.writelines(line for line in fin if line.strip() != 'O')
|
| 57 |
+
|
| 58 |
+
def filter_broken_tags(train_sentences):
    """Drop every sentence containing a word whose tag is None (empty tag)."""
    kept = []
    for sentence in train_sentences:
        if all(tag is not None for _, tag in sentence):
            kept.append(sentence)
    return kept
|
| 63 |
+
|
| 64 |
+
def filter_bad_words(train_sentences):
    """
    Remove words made of characters that don't exist.

    Not bad words like poop: the raw data contains Private Use Area
    characters U+F06C and U+F06E, which look like 'l' and 'n' in some
    editors (e.g. emacs) but are not real text.  They are written here as
    escapes so this source file itself contains no invisible characters.
    """
    bad_words = ("\uf06c", "\uf06e")
    return [[pair for pair in sentence if pair[0] not in bad_words] for sentence in train_sentences]
|
| 72 |
+
|
| 73 |
+
def _clean_sentences(sentences):
    """Shared cleanup chain: drop broken tags and junk chars, fix TIM tags, strip words."""
    sentences = filter_broken_tags(sentences)
    sentences = filter_bad_words(sentences)
    sentences = redo_time_tags(sentences)
    return strip_words(sentences)

def read_datasets(in_directory):
    """
    Reads & splits the train data, reads the test data

    There is no validation data, so we split the training data into
    two pieces and use the smaller piece as the dev set

    Also performed is a conversion of TIM -> B-TIM, I-TIM
    """
    # make sure we always get the same shuffle & split
    random.seed(1234)

    train_file = os.path.join(in_directory, "Input", "train_data.txt")
    with tempfile.TemporaryDirectory() as tempdir:
        # the raw train file has blank words labeled O; filter them before reading
        train_filtered_file = os.path.join(tempdir, "train.txt")
        filter_blank_words(train_file, train_filtered_file)
        train_sentences = read_tsv(train_filtered_file, text_column=0, annotation_column=1, keep_broken_tags=True)
        train_sentences = _clean_sentences(train_sentences)

    test_file = os.path.join(in_directory, "Input", "test_data.txt")
    test_sentences = read_tsv(test_file, text_column=0, annotation_column=1, keep_broken_tags=True)
    test_sentences = _clean_sentences(test_sentences)

    random.shuffle(train_sentences)
    split_len = len(train_sentences) * 9 // 10
    dev_sentences = train_sentences[split_len:]
    train_sentences = train_sentences[:split_len]

    datasets = (train_sentences, dev_sentences, test_sentences)
    return datasets
|
| 109 |
+
|
| 110 |
+
def convert_dataset(in_directory, out_directory):
    """Read the daffodil splits with read_datasets, then write them back out."""
    write_dataset(read_datasets(in_directory), out_directory, "bn_daffodil")
|
| 116 |
+
|
| 117 |
+
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/bangla/Bengali-NER", help="Where to find the files")
    arg_parser.add_argument('--output_path', type=str, default="/home/john/stanza/data/ner", help="Where to output the results")
    options = arg_parser.parse_args()
    convert_dataset(options.input_path, options.output_path)
|
stanza/stanza/utils/datasets/ner/convert_en_conll03.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Downloads (if necessary) conll03 from Huggingface, then converts it to Stanza .json
|
| 3 |
+
|
| 4 |
+
Some online sources for CoNLL 2003 require multiple pieces, but it is currently hosted on HF:
|
| 5 |
+
https://huggingface.co/datasets/conll2003
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from stanza.utils.default_paths import get_default_paths
|
| 11 |
+
from stanza.utils.datasets.ner.utils import write_dataset
|
| 12 |
+
|
| 13 |
+
# CoNLL 2003 NER tag -> numeric id, matching the HF `conll2003` dataset's label ordering
TAG_TO_ID = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
# inverse mapping: numeric id -> BIO tag string
ID_TO_TAG = {y: x for x, y in TAG_TO_ID.items()}
|
| 15 |
+
|
| 16 |
+
def convert_dataset_section(section):
    """Turn one HF conll2003 split into a list of [(word, tag), ...] sentences."""
    converted = []
    for item in section:
        tag_names = [ID_TO_TAG[tag_id] for tag_id in item['ner_tags']]
        converted.append(list(zip(item['tokens'], tag_names)))
    return converted
|
| 23 |
+
|
| 24 |
+
def process_dataset(short_name, conll_path, ner_output_path):
    """Download (cached under conll_path) CoNLL03 from HF and write Stanza .json files.

    Raises ImportError (chained to the original error) if the HF datasets
    package is not installed.
    """
    try:
        from datasets import load_dataset
    except ImportError as e:
        # chain the original error so the user can see why the import failed
        raise ImportError("Please install the datasets package to process CoNLL03 with Stanza") from e

    dataset = load_dataset('conll2003', cache_dir=conll_path)
    datasets = [convert_dataset_section(x) for x in [dataset['train'], dataset['validation'], dataset['test']]]
    write_dataset(datasets, ner_output_path, short_name)
|
| 33 |
+
|
| 34 |
+
def main():
    """Convert en_conll03 using the default stanza paths."""
    paths = get_default_paths()
    conll_path = os.path.join(paths['NERBASE'], "english", "en_conll03")
    process_dataset("en_conll03", conll_path, paths['NER_DATA_DIR'])

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/ner/convert_he_iahlt.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
from stanza.utils.conll import CoNLL
|
| 6 |
+
import stanza.utils.default_paths as default_paths
|
| 7 |
+
from stanza.utils.datasets.ner.utils import write_dataset
|
| 8 |
+
|
| 9 |
+
def output_entities(sentence):
    """Print the raw Entity= annotation for each word that has one (debugging aid)."""
    for word in sentence.words:
        if word.misc is None:
            continue
        for piece in word.misc.split("|"):
            if piece.startswith("Entity="):
                print(" " + piece.split("=", maxsplit=1)[1])
                break
|
| 21 |
+
|
| 22 |
+
def extract_single_sentence(sentence):
    """Convert one UD sentence with Entity= MISC annotations into (text, BIO-tag) pairs.

    The Entity annotation uses a bracketed scheme, e.g. "(ORG" opens an ORG
    span, "ORG)" closes it; several opens/closes can appear on one word.
    Only the outermost open entity is used for the tag.  Raises AssertionError
    on malformed bracketing (caller catches and skips such sentences).
    """
    # stack of currently open entity labels, outermost first
    current_entity = []
    words = []
    for word in sentence.words:
        text = word.text
        misc = word.misc
        if misc is None:
            pieces = []
        else:
            pieces = misc.split("|")

        # labels closed at this word; applied AFTER tagging, so the closing
        # word itself still carries the entity tag
        closes = []
        # True when this word opened the outermost entity (=> B- prefix)
        first_entity = False
        for piece in pieces:
            if piece.startswith("Entity="):
                entity = piece.split("=", maxsplit=1)[1]
                # split keeping the parens as their own tokens: "(ORG)" -> ['(', 'ORG', ')']
                entity_pieces = re.split(r"([()])", entity)
                entity_pieces = [x for x in entity_pieces if x] # remove blanks from re.split
                entity_idx = 0
                while entity_idx < len(entity_pieces):
                    if entity_pieces[entity_idx] == '(':
                        # '(' must be followed by the label being opened
                        assert len(entity_pieces) > entity_idx + 1, "Opening an unspecified entity"
                        if len(current_entity) == 0:
                            first_entity = True
                        current_entity.append(entity_pieces[entity_idx + 1])
                        entity_idx += 2
                    elif entity_pieces[entity_idx] == ')':
                        # ')' closes the label immediately preceding it
                        assert entity_idx != 0, "Closing an unspecified entity"
                        closes.append(entity_pieces[entity_idx-1])
                        entity_idx += 1
                    else:
                        # the entities themselves get added or removed via the ()
                        entity_idx += 1

        if len(current_entity) == 0:
            entity = 'O'
        else:
            # outermost entity wins; B- only on the word that opened it
            entity = current_entity[0]
            entity = "B-" + entity if first_entity else "I-" + entity
        words.append((text, entity))

        assert len(current_entity) >= len(closes), "Too many closes for the current open entities"
        for close_entity in closes:
            # TODO: check the close is closing the right thing
            assert close_entity == current_entity[-1], "Closed the wrong entity: %s vs %s" % (close_entity, current_entity[-1])
            current_entity = current_entity[:-1]
    return words
|
| 69 |
+
|
| 70 |
+
def extract_sentences(doc):
    """Convert every sentence of doc to (text, tag) pairs, skipping broken ones."""
    sentences = []
    for sentence in doc.sentences:
        try:
            sentences.append(extract_single_sentence(sentence))
        except AssertionError as e:
            # malformed Entity bracketing: report and move on
            print("Skipping sentence %s ... %s" % (sentence.sent_id, str(e)))
            output_entities(sentence)

    return sentences
|
| 81 |
+
|
| 82 |
+
def convert_iahlt(udbase, output_dir, short_name):
    """Merge the two Hebrew IAHLT UD treebanks into a single NER dataset."""
    shards = ("train", "dev", "test")
    sources = [("UD_Hebrew-IAHLTwiki", "he_iahltwiki-ud-%s.conllu"),
               ("UD_Hebrew-IAHLTknesset", "he_iahltknesset-ud-%s.conllu")]
    datasets = defaultdict(list)

    for ud_dataset, base_filename in sources:
        for shard in shards:
            filename = os.path.join(udbase, ud_dataset, base_filename % shard)
            sentences = extract_sentences(CoNLL.conll2doc(filename))
            print("Read %d sentences from %s" % (len(sentences), filename))
            datasets[shard].extend(sentences)

    write_dataset([datasets[x] for x in shards], output_dir, short_name)
|
| 99 |
+
|
| 100 |
+
def main():
    """Convert he_iahlt using the default stanza paths."""
    paths = default_paths.get_default_paths()
    convert_iahlt(paths["UDBASE_GIT"], paths["NER_DATA_DIR"], "he_iahlt")

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/ner/convert_lst20.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Converts the Thai LST20 dataset to a format usable by Stanza's NER model
|
| 3 |
+
|
| 4 |
+
The dataset in the original format has a few tag errors which we
|
| 5 |
+
automatically fix (or at worst cover up)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from stanza.utils.datasets.ner.utils import convert_bio_to_json
|
| 11 |
+
|
| 12 |
+
def convert_lst20(paths, short_name, include_space_char=True):
    """
    Convert the LST20 corpus (train/eval/test directories) into .bio and .json.

    paths: the usual stanza paths dict (uses NERBASE and NER_DATA_DIR)
    include_space_char: if False, '_' whitespace tokens are dropped and the
      output dataset name gets a _no_ws suffix

    The raw data has a few known tag errors (MEA_BI, OBRN_B, ORG_I, PER_I,
    LOC_I, bare B/I, ABB*, DDEM, __) which are fixed or covered up here.
    """
    assert short_name == "th_lst20"
    SHARDS = ("train", "eval", "test")
    BASE_OUTPUT_PATH = paths["NER_DATA_DIR"]

    input_split = [(os.path.join(paths["NERBASE"], "thai", "LST20_Corpus", x), x) for x in SHARDS]

    if not include_space_char:
        short_name = short_name + "_no_ws"

    for input_folder, split_type in input_split:
        # data files all start with 'T'; skip any other directory entries
        text_list = [text for text in os.listdir(input_folder) if text[0] == 'T']

        if split_type == "eval":
            # stanza calls this split "dev"
            split_type = "dev"

        output_path = os.path.join(BASE_OUTPUT_PATH, "%s.%s.bio" % (short_name, split_type))
        print(output_path)

        with open(output_path, 'w', encoding='utf-8') as fout:
            for text in text_list:
                with open(os.path.join(input_folder, text), 'r', encoding='utf-8') as fin:
                    lines = fin.readlines()

                for line_idx, line in enumerate(lines):
                    x = line.strip().split('\t')
                    if len(x) > 1:
                        # optionally skip the '_' whitespace pseudo-tokens
                        if x[0] == '_' and not include_space_char:
                            continue
                        word, tag = x[0], x[2]

                        # known mangled tags in the raw corpus
                        if tag == "MEA_BI":
                            tag = "B_MEA"
                        if tag == "OBRN_B":
                            tag = "B_BRN"
                        if tag == "ORG_I":
                            tag = "I_ORG"
                        if tag == "PER_I":
                            tag = "I_PER"
                        if tag == "LOC_I":
                            tag = "I_LOC"
                        # a bare "B" inherits its entity type from the following I_/E_ tag
                        if tag == "B" and line_idx + 1 < len(lines):
                            x_next = lines[line_idx+1].strip().split('\t')
                            if len(x_next) > 1:
                                tag_next = x_next[2]
                                if "I_" in tag_next or "E_" in tag_next:
                                    tag = tag + tag_next[1:]
                                else:
                                    tag = "O"
                            else:
                                tag = "O"
                        if "_" in tag:
                            tag = tag.replace("_", "-")
                        # tags with no sensible mapping are covered up as O
                        if "ABB" in tag or tag == "DDEM" or tag == "I" or tag == "__":
                            tag = "O"

                        fout.write('{}\t{}'.format(word, tag))
                        fout.write('\n')
                    else:
                        fout.write('\n')
    convert_bio_to_json(BASE_OUTPUT_PATH, BASE_OUTPUT_PATH, short_name)
|
stanza/stanza/utils/datasets/ner/convert_mr_l3cube.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reads one piece of the MR L3Cube dataset
|
| 3 |
+
|
| 4 |
+
The dataset is structured as a long list of words already in IOB format
|
| 5 |
+
The sentences have an ID which changes when a new sentence starts
|
| 6 |
+
The tags are labeled BNEM instead of B-NEM, so we update that.
|
| 7 |
+
(Could theoretically remap the tags to names more typical of other datasets as well)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
def convert(input_file):
    """
    Converts one file of the dataset

    Return: a list of list of pairs, (text, tag)
    """
    with open(input_file, encoding="utf-8") as fin:
        lines = fin.readlines()

    sentences = []
    current_sentence = []
    prev_sent_id = None
    for idx, line in enumerate(lines):
        # first line of each of the segments is the header
        if idx == 0:
            continue

        stripped = line.strip()
        if not stripped:
            continue
        pieces = stripped.split("\t")
        if len(pieces) != 3:
            raise ValueError("Unexpected number of pieces at line %d of %s" % (idx, input_file))

        text, ner, sent_id = pieces
        if ner != 'O':
            # ner symbols are written as BNEM, BNED, etc in this dataset
            ner = "%s-%s" % (ner[0], ner[1:])

        # a change of sentence id means the previous sentence is finished
        if prev_sent_id and sent_id != prev_sent_id:
            if not current_sentence:
                raise ValueError("This should not happen!")
            sentences.append(current_sentence)
            current_sentence = []
        prev_sent_id = sent_id

        current_sentence.append((text, ner))

    if current_sentence:
        sentences.append(current_sentence)

    print("Read %d sentences in %d lines from %s" % (len(sentences), len(lines), input_file))
    return sentences
|
stanza/stanza/utils/datasets/ner/convert_nner22.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Converts the Thai NNER22 dataset to a format usable by Stanza's NER model
|
| 3 |
+
|
| 4 |
+
The dataset is already written in json format, so we will convert into a compatible json format.
|
| 5 |
+
|
| 6 |
+
The dataset in the original format has nested NER format which we will only extract the first layer
|
| 7 |
+
of NER tag and write it in the format accepted by current Stanza model
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import logging
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
def convert_nner22(paths, short_name, include_space_char=True):
    """
    Converts the Thai NNER22 dataset to a format usable by Stanza's NER model.

    Only the first (outermost) layer of the nested NER annotation is kept.
    Tags are emitted in BIE form: B- on the first token of a span, E- on the
    last, I- in between (for a one-token span, B- wins over E-).

    paths: the usual stanza paths dict (uses NERBASE and NER_DATA_DIR)
    include_space_char: if False the output dataset name gets a _no_ws suffix
    """
    assert short_name == "th_nner22"
    SHARDS = ("train", "dev", "test")
    BASE_INPUT_PATH = os.path.join(paths["NERBASE"], "thai", "Thai-NNER", "data", "scb-nner-th-2022", "postproc")

    if not include_space_char:
        short_name = short_name + "_no_ws"

    for shard in SHARDS:
        input_path = os.path.join(BASE_INPUT_PATH, "%s.json" % (shard))
        output_path = os.path.join(paths["NER_DATA_DIR"], "%s.%s.json" % (short_name, shard))

        logging.info("Output path for %s split at %s" % (shard, output_path))

        # use a context manager so the input file handle is not leaked
        with open(input_path, encoding="utf-8") as fin:
            data = json.load(fin)

        documents = []

        for item in data:
            token, entities = item["tokens"], item["entities"]

            token_length, sofar = len(token), 0
            document, ner_dict = [], {}

            for entity in entities:
                start, stop = entity["span"]

                # only pick up a new label when this span extends past the
                # rightmost position already covered (first-layer extraction)
                if stop > sofar:
                    ner = entity["entity_type"].upper()
                    sofar = stop

                for j in range(start, stop):
                    if j == start:
                        ner_tag = "B-" + ner
                    elif j == stop - 1:
                        ner_tag = "E-" + ner
                    else:
                        ner_tag = "I-" + ner

                    ner_dict[j] = (ner_tag, token[j])

            # untagged positions become O
            for k in range(token_length):
                dict_add = {}

                if k not in ner_dict:
                    dict_add["ner"], dict_add["text"] = "O", token[k]
                else:
                    dict_add["ner"], dict_add["text"] = ner_dict[k]

                document.append(dict_add)

            documents.append(document)

        with open(output_path, "w", encoding="utf-8") as outfile:
            json.dump(documents, outfile, indent=1)

        logging.info("%s.%s.json file successfully created" % (short_name, shard))
|
stanza/stanza/utils/datasets/ner/convert_ontonotes.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Downloads (if necessary) OntoNotes v5 from Huggingface, then converts it to Stanza .json

The dataset is hosted on HF as conll2012_ontonotesv5:
https://huggingface.co/datasets/conll2012_ontonotesv5
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from stanza.utils.default_paths import get_default_paths
|
| 11 |
+
from stanza.utils.datasets.ner.utils import write_dataset
|
| 12 |
+
|
| 13 |
+
# OntoNotes v5 named-entity label ids in the HF conll2012_ontonotesv5 ordering: index -> BIO tag
ID_TO_TAG = ["O", "B-PERSON", "I-PERSON", "B-NORP", "I-NORP", "B-FAC", "I-FAC", "B-ORG", "I-ORG", "B-GPE", "I-GPE", "B-LOC", "I-LOC", "B-PRODUCT", "I-PRODUCT", "B-DATE", "I-DATE", "B-TIME", "I-TIME", "B-PERCENT", "I-PERCENT", "B-MONEY", "I-MONEY", "B-QUANTITY", "I-QUANTITY", "B-ORDINAL", "I-ORDINAL", "B-CARDINAL", "I-CARDINAL", "B-EVENT", "I-EVENT", "B-WORK_OF_ART", "I-WORK_OF_ART", "B-LAW", "I-LAW", "B-LANGUAGE", "I-LANGUAGE",]
|
| 14 |
+
|
| 15 |
+
def convert_dataset_section(config_name, section):
    """Extract [(word, tag), ...] sentences from one HF OntoNotes split."""
    sentences = []
    for doc in section:
        # the nt_ sentences (New Testament) in the HF version of OntoNotes
        # have blank named_entities, even though there was no original .name file
        # that corresponded with these annotations
        if config_name.startswith("english") and doc['document_id'].startswith("pt/nt"):
            continue
        for sentence in doc['sentences']:
            tag_names = [ID_TO_TAG[tag_id] for tag_id in sentence['named_entities']]
            sentences.append(list(zip(sentence['words'], tag_names)))
    return sentences
|
| 28 |
+
|
| 29 |
+
def process_dataset(short_name, conll_path, ner_output_path):
    """Download (cached under conll_path) OntoNotes v5 from HF and write Stanza .json.

    Raises ImportError (chained) without the HF datasets package, and
    ValueError for an unrecognized short_name.
    """
    try:
        from datasets import load_dataset
    except ImportError as e:
        # previous message said CoNLL03, copy-pasted from convert_en_conll03.py;
        # this module processes OntoNotes
        raise ImportError("Please install the datasets package to process OntoNotes with Stanza") from e

    if short_name == 'en_ontonotes':
        # there is an english_v12, but it is filled with junk annotations
        # for example, near the end:
        # And John_O, I realize
        config_name = 'english_v4'
    elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'):
        config_name = 'chinese_v4'
    elif short_name == 'ar_ontonotes':
        config_name = 'arabic_v4'
    else:
        raise ValueError("Unknown short name for downloading ontonotes: %s" % short_name)
    dataset = load_dataset("conll2012_ontonotesv5", config_name, cache_dir=conll_path)
    datasets = [convert_dataset_section(config_name, x) for x in [dataset['train'], dataset['validation'], dataset['test']]]
    write_dataset(datasets, ner_output_path, short_name)
|
| 49 |
+
|
| 50 |
+
def main():
    """Convert en_ontonotes using the default stanza data paths."""
    paths = get_default_paths()
    conll_path = os.path.join(paths['NERBASE'], "english", "en_ontonotes")
    process_dataset("en_ontonotes", conll_path, paths['NER_DATA_DIR'])

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/ner/json_to_bio.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
If you want to convert .json back to .bio for some reason, this will do it for you
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from stanza.models.common.doc import Document
|
| 9 |
+
from stanza.models.ner.utils import process_tags
|
| 10 |
+
from stanza.utils.default_paths import get_default_paths
|
| 11 |
+
|
| 12 |
+
def convert_json_to_bio(input_filename, output_filename):
    """Read a stanza NER .json file and write it back out as a two-column .bio file.

    Tags are normalized with process_tags to the bioes scheme before writing.
    """
    with open(input_filename, encoding="utf-8") as fin:
        doc = Document(json.load(fin))
    raw_sentences = []
    for sentence in doc.sentences:
        raw_sentences.append([(token.text, token.ner) for token in sentence.tokens])
    tagged = process_tags(raw_sentences, "bioes")
    with open(output_filename, "w", encoding="utf-8") as fout:
        for sentence in tagged:
            fout.write("".join("%s\t%s\n" % pair for pair in sentence))
            fout.write("\n")
|
| 22 |
+
|
| 23 |
+
def main(args=None):
    """Convert one json file, or a whole train/dev/test dataset, back to .bio format."""
    ner_data_dir = get_default_paths()['NER_DATA_DIR']
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_filename', type=str, default="data/ner/en_foreign-4class.test.json", help='Convert an individual file')
    parser.add_argument('--input_dir', type=str, default=ner_data_dir, help='Which directory to find the dataset, if using --input_dataset')
    parser.add_argument('--input_dataset', type=str, help='Convert an entire dataset')
    parser.add_argument('--output_suffix', type=str, default='bioes', help='suffix for output filenames')
    args = parser.parse_args(args)

    if args.input_dataset:
        # dataset mode: convert all three standard shards
        targets = [os.path.join(args.input_dir, "%s.%s.json" % (args.input_dataset, shard))
                   for shard in ("train", "dev", "test")]
    else:
        targets = [args.input_filename]

    for input_filename in targets:
        output_filename = os.path.splitext(input_filename)[0] + "." + args.output_suffix
        print("%s -> %s" % (input_filename, output_filename))
        convert_json_to_bio(input_filename, output_filename)

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/ner/misc_to_date.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# for the Worldwide dataset, automatically switch the Misc tags to Date when Stanza Ontonotes thinks it's a Date
|
| 2 |
+
# this keeps our annotation scheme for dates (eg, not "3 months ago") while hopefully switching them all to Date
|
| 3 |
+
#
|
| 4 |
+
# maybe some got missed
|
| 5 |
+
# also, there are a few with some nested entities. printed out warnings and edited those by hand
|
| 6 |
+
#
|
| 7 |
+
# just need to run this with the Worldwide dataset in the ner path
|
| 8 |
+
# it will automatically convert as many as it can
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
import stanza
|
| 15 |
+
from stanza.utils.datasets.ner.utils import read_tsv
|
| 16 |
+
from stanza.utils.default_paths import get_default_paths
|
| 17 |
+
|
| 18 |
+
# Input data layout: $NERBASE/en_foreign/en-foreign-newswire/...
paths = get_default_paths()
BASE_PATH = os.path.join(paths["NERBASE"], "en_foreign")
input_dir = os.path.join(BASE_PATH, "en-foreign-newswire")

# English NER pipeline used to suggest DATE spans.  tokenize_pretokenized
# because the files are already split into tokens, one per line.
pipe = stanza.Pipeline("en", processors="tokenize,ner", tokenize_pretokenized=True, package={"ner": "ontonotes_bert"})

# accumulates the paths of all annotation files found under REVIEW directories
filenames = []
|
| 25 |
+
|
| 26 |
+
def ner_tags(pipe, sentence):
    """Run the pipeline on one pretokenized sentence; return the per-token NER tags."""
    annotated = pipe([sentence])
    return [token.ner for sent in annotated.sentences for token in sent.tokens]
|
| 30 |
+
|
| 31 |
+
# Collect every file that lives directly under a directory whose name
# ends in "REVIEW" -- those are the annotation batches to process.
for root, dirs, files in os.walk(input_dir):
    if root[-6:] == "REVIEW":
        batch_files = os.listdir(root)
        for filename in batch_files:
            file_path = os.path.join(root, filename)
            filenames.append(file_path)
|
| 37 |
+
|
| 38 |
+
# Rewrite each annotation file in place, switching Misc (and other non-O)
# tags to Date wherever the Stanza OntoNotes model predicts a DATE span.
for filename in tqdm(filenames):
    try:
        # keep_all_columns so any extra (nested-tag) columns survive the rewrite
        data = read_tsv(filename, text_column=0, annotation_column=1, skip_comments=False, keep_all_columns=True)

        with open(filename, 'w', encoding='utf-8') as fout:
            warned_file = False
            for sentence in data: # segments delimited by spaces, effectively sentences
                tokens = [x[0] for x in sentence]
                labels = [x[1] for x in sentence]

                # only run the (expensive) NER model on segments that have a Misc tag
                if any(x.endswith("Misc") for x in labels):
                    stanza_tags = ner_tags(pipe, tokens)
                    # in_date tracks whether the previous token was relabeled to Date,
                    # so continuations become I-Date instead of starting a new span
                    in_date = False
                    for i, stanza_tag in enumerate(stanza_tags):
                        if stanza_tag[2:] == "DATE" and labels[i] != "O":
                            # rows longer than 2 columns carry nested entity
                            # annotations; those need a manual check after editing
                            if len(sentence[i]) > 2:
                                if not warned_file:
                                    print("Warning: file %s has nested tags being altered" % filename)
                                    warned_file = True
                            # put DATE tags where Stanza thinks there are DATEs
                            # as long as we already had a MISC (or something else, I suppose)
                            # stanza_tag[0] is the BIOES prefix char: B/S start a new span
                            if in_date and not stanza_tag[0].startswith("B") and not stanza_tag[0].startswith("S"):
                                sentence[i][1] = "I-Date"
                            else:
                                sentence[i][1] = "B-Date"
                            in_date = True
                        elif in_date:
                            # make sure new tags start with B- instead of I-
                            # honestly it's not clear if, in these cases,
                            # we should be switching the following tags to
                            # DATE as well. will have to experiment some
                            in_date = False
                            if labels[i].startswith("I-"):
                                sentence[i][1] = "B-" + labels[i][2:]
                for word in sentence:
                    fout.write("\t".join(word))
                    fout.write("\n")
                fout.write("\n")
    except AssertionError:
        # read_tsv asserts on malformed files; skip those rather than abort the run
        print("Could not process %s" % filename)
|
stanza/stanza/utils/datasets/ner/preprocess_wikiner.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Converts the WikiNER data format to a format usable by our processing tools
|
| 3 |
+
|
| 4 |
+
python preprocess_wikiner input output
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
def preprocess_wikiner(input_file, output_file, encoding="utf-8"):
    """Convert a raw WikiNER file into two-column word/tag format.

    Each input line holds a whole sentence as space-separated tokens of the
    form text|pos|tag; blank lines are document boundaries and become a
    -DOCSTART- marker.  Tokens whose text contains underscores (for example
    Daniel_Bernoulli|I-PER) are split into one output line per piece.
    """
    with open(input_file, encoding=encoding) as fin:
        with open(output_file, "w", encoding="utf-8") as fout:
            for raw_line in fin:
                sentence = raw_line.strip()
                if not sentence:
                    # document boundary marker
                    fout.write("-DOCSTART- O\n")
                    fout.write("\n")
                    continue

                for token in sentence.split():
                    fields = token.split("|")
                    text, tag = fields[0], fields[-1]
                    # some words look like Daniel_Bernoulli|I-PER
                    # but the original .pl conversion script didn't take that into account
                    pieces = text.split("_")
                    if tag.startswith("B-") and len(pieces) > 1:
                        # only the first piece keeps B-; the rest continue the entity
                        fout.write("{} {}\n".format(pieces[0], tag))
                        continuation = "I-" + tag[2:]
                        for piece in pieces[1:]:
                            fout.write("{} {}\n".format(piece, continuation))
                    else:
                        for piece in pieces:
                            fout.write("{} {}\n".format(piece, tag))
                fout.write("\n")
|
| 35 |
+
|
| 36 |
+
if __name__ == '__main__':
|
| 37 |
+
preprocess_wikiner(sys.argv[1], sys.argv[2])
|
stanza/stanza/utils/datasets/ner/simplify_en_worldwide.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import tempfile
|
| 4 |
+
|
| 5 |
+
import stanza
|
| 6 |
+
from stanza.utils.default_paths import get_default_paths
|
| 7 |
+
from stanza.utils.datasets.ner.utils import read_tsv
|
| 8 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 9 |
+
|
| 10 |
+
tqdm = get_tqdm()
|
| 11 |
+
|
| 12 |
+
# Characters treated as punctuation when stripping Money spans.
# NOTE: includes a literal space and deliberately omits currency symbols
# such as $, which are the part of a Money span we want to keep.
PUNCTUATION = """!"#%&'()*+, -./:;<=>?@[\\]^_`{|}~"""
# Lowercased number/quantity words removed from Money spans, leaving only
# the currency symbol tagged.
MONEY_WORDS = {"million", "billion", "trillion", "millions", "billions", "trillions", "hundred", "hundreds",
               "lakh", "crore", # south asian english
               "tens", "of", "ten", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "couple"}

# Doesn't include Money but this case is handled explicitly for processing
# Maps Worldwide label names to 4-class conll names; None means drop the tag.
LABEL_TRANSLATION = {
    "Date": None,
    "Misc": "MISC",
    "Product": "MISC",
    "NORP": "MISC",
    "Facility": "LOC",
    "Location": "LOC",
    "Person": "PER",
    "Organization": "ORG",
}
|
| 28 |
+
|
| 29 |
+
def isfloat(num):
    """Return True if num parses as a float, False otherwise."""
    try:
        float(num)
    except ValueError:
        return False
    return True
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def process_label(line, is_start=False):
    """
    Converts our stuff to conll labels

    event, product, work of art, norp -> MISC
    take out dates - can use Stanza to identify them as dates and eliminate them
    money requires some special care
    facility -> location (there are examples of Bridge and Hospital in the data)
    the version of conll we used to train CoreNLP NER is here:

    Overall plan:
    Collapse Product, NORP, Money (extract only the symbols), into misc.
    Collapse Facilities into LOC
    Deletes Dates

    Rule for currency is that we take out labels for the numbers that return True for isfloat()
    Take out words that categorize money (Million, Billion, Trillion, Thousand, Hundred, Ten, Nine, Eight, Seven, Six, Five,
    Four, Three, Two, One)
    Take out punctuation characters

    If we remove the 'B' tag, then move it to the first remaining tag.

    Replace tags with 'O'
    is_start parameter signals whether or not this current line is the new start of a tag. Needed for when
    the previous line analyzed is the start of a MONEY tag but is removed because it is a non symbol- need to
    set the starting token that is a symbol to the B-MONEY tag when it might have previously been I-MONEY

    Returns [token, new_label, is_start] where is_start is the state to pass
    to the next call, or [] for an empty input line.
    """
    if not line:
        return []
    token = line[0]
    biggest_label = line[1]
    # position is the "B-"/"I-" prefix; for a bare "O" label the split yields
    # position="O", label_name="" and the elif branch below passes it through
    position, label_name = biggest_label[:2], biggest_label[2:]

    if label_name == "Money":
        # NOTE(review): `token in PUNCTUATION` is a substring test, so only
        # single-character tokens (or runs appearing verbatim in PUNCTUATION)
        # match -- presumably intended for lone punctuation tokens
        if token.lower() in MONEY_WORDS or token in PUNCTUATION or isfloat(token): # remove this tag
            label_name = "O"
            is_start = True
            position = ""
        else: # keep money tag
            label_name = "MISC"
            if is_start:
                # a kept symbol after removed tokens restarts the span
                position = "B-"
                is_start = False

    elif not label_name or label_name == "O":
        pass
    elif label_name in LABEL_TRANSLATION:
        label_name = LABEL_TRANSLATION[label_name]
        if label_name is None:
            # Dates map to None: drop the tag entirely
            position = ""
            label_name = "O"
            is_start = False
    else:
        raise ValueError("Oops, missed a label: %s" % label_name)
    return [token, position + label_name, is_start]
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def write_new_file(save_dir, input_path, old_file, simplify):
    """Copy one annotation file into save_dir, optionally simplifying its labels.

    input_path: full path of the source file
    old_file: its bare filename, used to build the output name
    simplify: when True, labels go through process_label and the output is
      renamed to <name>.4class.tsv; otherwise the first two columns are copied
    """
    # carries process_label's is_start state from one line to the next
    starts_b = False
    with open(input_path, "r+", encoding="utf-8") as iob:
        new_filename = (os.path.splitext(old_file)[0] + ".4class.tsv") if simplify else old_file
        with open(os.path.join(save_dir, new_filename), 'w', encoding='utf-8') as fout:
            for i, line in enumerate(iob):
                if i == 0 or i == 1: # skip over the URL and subsequent space line.
                    continue
                line = line.strip()
                if not line:
                    # preserve sentence boundaries
                    fout.write("\n")
                    continue
                label = line.split("\t")
                if simplify:
                    try:
                        edited = process_label(label, is_start=starts_b) # processed label line labels
                    except ValueError as e:
                        raise ValueError("Error in %s at line %d" % (input_path, i)) from e
                    assert edited
                    # last element is the is_start state for the next line
                    starts_b = edited[-1]
                    fout.write("\t".join(edited[:-1]))
                    fout.write("\n")
                else:
                    fout.write("%s\t%s\n" % (label[0], label[1]))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def copy_and_simplify(base_path, simplify):
    """Walk the raw en-worldwide-newswire tree and rewrite each reviewed file.

    Output goes to <base_path>/4class when simplifying, otherwise to
    <base_path>/9class.  Only files directly under directories whose names
    end in REVIEW are processed.

    The previous version wrapped this in a TemporaryDirectory that was never
    written to; that dead scratch directory has been removed.
    """
    input_dir = os.path.join(base_path, "en-worldwide-newswire")
    final_dir = os.path.join(base_path, "4class" if simplify else "9class")
    os.makedirs(final_dir, exist_ok=True)
    for root, dirs, files in os.walk(input_dir):
        if root[-6:] == "REVIEW":
            for filename in os.listdir(root):
                file_path = os.path.join(root, filename)
                write_new_file(final_dir, file_path, filename, simplify)
|
| 133 |
+
|
| 134 |
+
def main(args=None):
    """Copy the raw en_worldwide data, optionally simplifying the labels to 4 classes.

    The previous default was a hard-coded developer machine path with the
    stanza layout as a fallback; the leftover personal path has been removed
    and the stanza layout is now the only default.  --base_path still
    overrides it.
    """
    paths = get_default_paths()
    base_path = os.path.join(paths["NERBASE"], "en_worldwide")

    parser = argparse.ArgumentParser()
    parser.add_argument('--base_path', type=str, default=base_path, help="Where to find the raw data")
    parser.add_argument('--simplify', default=False, action='store_true', help='Simplify to 4 classes... otherwise, keep all classes')
    parser.add_argument('--no_simplify', dest='simplify', action='store_false', help="Don't simplify to 4 classes")
    args = parser.parse_args(args=args)

    copy_and_simplify(args.base_path, args.simplify)

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/ner/simplify_ontonotes_to_worldwide.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simplify an existing ner json with the OntoNotes 18 class scheme to the Worldwide scheme
|
| 3 |
+
|
| 4 |
+
Simplified classes used in the Worldwide dataset are:
|
| 5 |
+
|
| 6 |
+
Date
|
| 7 |
+
Facility
|
| 8 |
+
Location
|
| 9 |
+
Misc
|
| 10 |
+
Money
|
| 11 |
+
NORP
|
| 12 |
+
Organization
|
| 13 |
+
Person
|
| 14 |
+
Product
|
| 15 |
+
|
| 16 |
+
vs OntoNotes classes:
|
| 17 |
+
|
| 18 |
+
CARDINAL
|
| 19 |
+
DATE
|
| 20 |
+
EVENT
|
| 21 |
+
FAC
|
| 22 |
+
GPE
|
| 23 |
+
LANGUAGE
|
| 24 |
+
LAW
|
| 25 |
+
LOC
|
| 26 |
+
MONEY
|
| 27 |
+
NORP
|
| 28 |
+
ORDINAL
|
| 29 |
+
ORG
|
| 30 |
+
PERCENT
|
| 31 |
+
PERSON
|
| 32 |
+
PRODUCT
|
| 33 |
+
QUANTITY
|
| 34 |
+
TIME
|
| 35 |
+
WORK_OF_ART
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import argparse
|
| 39 |
+
import glob
|
| 40 |
+
import json
|
| 41 |
+
import os
|
| 42 |
+
|
| 43 |
+
from stanza.utils.default_paths import get_default_paths
|
| 44 |
+
|
| 45 |
+
WORLDWIDE_ENTITY_MAPPING = {
|
| 46 |
+
"CARDINAL": None,
|
| 47 |
+
"ORDINAL": None,
|
| 48 |
+
"PERCENT": None,
|
| 49 |
+
"QUANTITY": None,
|
| 50 |
+
"TIME": None,
|
| 51 |
+
|
| 52 |
+
"DATE": "Date",
|
| 53 |
+
"EVENT": "Misc",
|
| 54 |
+
"FAC": "Facility",
|
| 55 |
+
"GPE": "Location",
|
| 56 |
+
"LANGUAGE": "NORP",
|
| 57 |
+
"LAW": "Misc",
|
| 58 |
+
"LOC": "Location",
|
| 59 |
+
"MONEY": "Money",
|
| 60 |
+
"NORP": "NORP",
|
| 61 |
+
"ORG": "Organization",
|
| 62 |
+
"PERSON": "Person",
|
| 63 |
+
"PRODUCT": "Product",
|
| 64 |
+
"WORK_OF_ART": "Misc",
|
| 65 |
+
|
| 66 |
+
# identity map in case this is called on the Worldwide half of the tags
|
| 67 |
+
"Date": "Date",
|
| 68 |
+
"Facility": "Facility",
|
| 69 |
+
"Location": "Location",
|
| 70 |
+
"Misc": "Misc",
|
| 71 |
+
"Money": "Money",
|
| 72 |
+
"Organization":"Organization",
|
| 73 |
+
"Person": "Person",
|
| 74 |
+
"Product": "Product",
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
def simplify_ontonotes_to_worldwide(entity):
    """Map a single BIO tag from the OntoNotes scheme to the Worldwide scheme.

    Empty / "O" tags pass through as "O"; entity types that map to None are
    dropped (become "O"); unknown types raise ValueError.
    """
    if not entity or entity == "O":
        return "O"

    prefix, ent_type = entity.split("-", maxsplit=1)

    if ent_type not in WORLDWIDE_ENTITY_MAPPING:
        raise ValueError("Unhandled entity: %s" % ent_type)
    target = WORLDWIDE_ENTITY_MAPPING[ent_type]
    # None / empty target means this class is removed in the Worldwide scheme
    if not target:
        return "O"
    return "%s-%s" % (prefix, target)
|
| 88 |
+
|
| 89 |
+
def convert_file(in_file, out_file):
    """Rewrite one NER .json file with its tags simplified to the Worldwide scheme.

    in_file: json document in stanza's list-of-sentences format
    out_file: destination path (overwritten)
    """
    # explicit encoding to match the output side (the input open previously
    # relied on the platform default encoding)
    with open(in_file, encoding="utf-8") as fin:
        gold_doc = json.load(fin)

    for sentence in gold_doc:
        for word in sentence:
            # some tokens (e.g. MWT parts) may not carry an ner field
            if 'ner' not in word:
                continue
            word['ner'] = simplify_ontonotes_to_worldwide(word['ner'])

    with open(out_file, "w", encoding="utf-8") as fout:
        json.dump(gold_doc, fout, indent=2)
|
| 101 |
+
|
| 102 |
+
def main():
    """Simplify every file of --input_dataset in the NER data dir to the Worldwide scheme."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dataset', type=str, default='en_ontonotes', help='which files to convert')
    parser.add_argument('--output_dataset', type=str, default='en_ontonotes-8class', help='which files to write out')
    parser.add_argument('--ner_data_dir', type=str, default=get_default_paths()["NER_DATA_DIR"], help='which directory has the data')
    args = parser.parse_args()

    pattern = os.path.join(args.ner_data_dir, args.input_dataset + ".*")
    for input_file in glob.glob(pattern):
        # keep the shard suffix (.train.json etc) and swap the dataset name
        shard_suffix = os.path.split(input_file)[1][len(args.input_dataset):]
        output_file = os.path.join(args.ner_data_dir, args.output_dataset + shard_suffix)
        print("Converting %s to %s" % (input_file, output_file))
        convert_file(input_file, output_file)


if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/ner/split_wikiner.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Preprocess the WikiNER dataset, by
|
| 3 |
+
1) normalizing tags;
|
| 4 |
+
2) split into train (70%), dev (15%), test (15%) datasets.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import warnings
|
| 10 |
+
from collections import Counter
|
| 11 |
+
|
| 12 |
+
def read_sentences(filename, encoding):
    """Read a two-column (word, tag) file with blank-line sentence boundaries.

    Any sentence containing a malformed line (not exactly two fields) is
    dropped with a warning.  Returns a list of sentences, each a list of
    [word, tag] pairs.
    """
    sentences = []
    current = []
    dropped = 0
    malformed = False
    with open(filename, encoding=encoding) as infile:
        for line_no, raw in enumerate(infile):
            stripped = raw.rstrip()
            if not stripped:
                if current:
                    if malformed:
                        dropped += 1
                    else:
                        sentences.append(current)
                    malformed = False
                    current = []
                continue
            fields = stripped.split()
            if len(fields) != 2:
                # poison the whole sentence, but keep reading its lines
                malformed = True
                warnings.warn("Format error at line {}: {}".format(line_no + 1, stripped))
                continue
            current.append(list(fields))
    # flush a trailing sentence with no final blank line
    if current:
        if malformed:
            dropped += 1
        else:
            sentences.append(current)
    print("Skipped {} examples due to formatting issues.".format(dropped))
    return sentences
|
| 44 |
+
|
| 45 |
+
def write_sentences_to_file(sents, filename):
    """Write sentences (lists of [word, tag] pairs) tab-separated, one pair per line.

    Sentences are separated by blank lines.  Output is always utf-8.
    """
    # previously this printed a literal "(unknown)" placeholder instead of the filename
    print(f"Writing {len(sents)} sentences to {filename}")
    with open(filename, 'w', encoding="utf-8") as outfile:
        for sent in sents:
            for pair in sent:
                print(f"{pair[0]}\t{pair[1]}", file=outfile)
            print("", file=outfile)
|
| 52 |
+
|
| 53 |
+
def remap_labels(sents, remap):
    """Return a new copy of sents with each tag translated through remap.

    Tags missing from remap are kept unchanged; the input is not mutated.
    """
    return [[[pair[0], remap.get(pair[1], pair[1])] for pair in sentence]
            for sentence in sents]
|
| 61 |
+
|
| 62 |
+
def split_wikiner(directory, *in_filenames, encoding="utf-8", prefix="", suffix="bio", remap=None, shuffle=True, train_fraction=0.7, dev_fraction=0.15, test_section=True):
    """Read NER files, optionally remap labels, and write train/dev(/test) splits.

    directory: where the output files are written
    in_filenames: one or more two-column input files
    remap: optional {old_tag: new_tag} mapping applied before splitting
    shuffle: shuffle sentences (seeded, so deterministic) before splitting
    test_section: when False, everything after train goes into dev and no
      test file is written

    Fixes vs the previous version:
    - the fraction sanity check's format string had three placeholders but
      only two arguments, so it raised IndexError instead of ValueError
    - the per-file read message printed a literal "(unknown)" placeholder
      instead of the filename
    """
    # fixed seed so repeated runs produce the same split
    random.seed(1234)

    sents = []
    for filename in in_filenames:
        new_sents = read_sentences(filename, encoding)
        print(f"{len(new_sents)} sentences read from {filename}.")
        sents.extend(new_sents)

    if remap:
        sents = remap_labels(sents, remap)

    # split
    num = len(sents)
    train_num = int(num * train_fraction)
    if test_section:
        dev_num = int(num * dev_fraction)
        if train_fraction + dev_fraction > 1.0:
            raise ValueError("Train and dev fractions added up to more than 1: {} {}".format(train_fraction, dev_fraction))
    else:
        # no test shard: dev absorbs everything after train
        dev_num = num - train_num

    if shuffle:
        random.shuffle(sents)
    train_sents = sents[:train_num]
    dev_sents = sents[train_num:train_num + dev_num]
    if test_section:
        test_sents = sents[train_num + dev_num:]
        batches = [train_sents, dev_sents, test_sents]
        filenames = [f'train.{suffix}', f'dev.{suffix}', f'test.{suffix}']
    else:
        batches = [train_sents, dev_sents]
        filenames = [f'train.{suffix}', f'dev.{suffix}']

    if prefix:
        filenames = ['%s.%s' % (prefix, f) for f in filenames]
    for batch, filename in zip(batches, filenames):
        write_sentences_to_file(batch, os.path.join(directory, filename))
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
in_filename = 'raw/wp2.txt'
|
| 103 |
+
directory = "."
|
| 104 |
+
split_wikiner(directory, in_filename)
|
stanza/stanza/utils/datasets/ner/suc_conll_to_iob.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Process the licensed version of SUC3 to BIO
|
| 3 |
+
|
| 4 |
+
The main program processes the expected location, or you can pass in a
|
| 5 |
+
specific zip or filename to read
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from io import TextIOWrapper
|
| 9 |
+
from zipfile import ZipFile
|
| 10 |
+
|
| 11 |
+
def extract(infile, outfile):
|
| 12 |
+
"""
|
| 13 |
+
Convert the infile to an outfile
|
| 14 |
+
|
| 15 |
+
Assumes the files are already open (this allows you to pass in a zipfile reader, for example)
|
| 16 |
+
|
| 17 |
+
The SUC3 format is like conll, but with the tags in tabs 10 and 11
|
| 18 |
+
"""
|
| 19 |
+
lines = infile.readlines()
|
| 20 |
+
sentences = []
|
| 21 |
+
cur_sentence = []
|
| 22 |
+
for idx, line in enumerate(lines):
|
| 23 |
+
line = line.strip()
|
| 24 |
+
if not line:
|
| 25 |
+
# if we're currently reading a sentence, append it to the list
|
| 26 |
+
if cur_sentence:
|
| 27 |
+
sentences.append(cur_sentence)
|
| 28 |
+
cur_sentence = []
|
| 29 |
+
continue
|
| 30 |
+
|
| 31 |
+
pieces = line.split("\t")
|
| 32 |
+
if len(pieces) < 12:
|
| 33 |
+
raise ValueError("Unexpected line length in the SUC3 dataset at %d" % idx)
|
| 34 |
+
if pieces[10] == 'O':
|
| 35 |
+
cur_sentence.append((pieces[1], "O"))
|
| 36 |
+
else:
|
| 37 |
+
cur_sentence.append((pieces[1], "%s-%s" % (pieces[10], pieces[11])))
|
| 38 |
+
if cur_sentence:
|
| 39 |
+
sentences.append(cur_sentence)
|
| 40 |
+
|
| 41 |
+
for sentence in sentences:
|
| 42 |
+
for word in sentence:
|
| 43 |
+
outfile.write("%s\t%s\n" % word)
|
| 44 |
+
outfile.write("\n")
|
| 45 |
+
|
| 46 |
+
return len(sentences)
|
| 47 |
+
|
| 48 |
+
def extract_from_zip(zip_filename, in_filename, out_filename):
    """
    Process a single file from SUC3

    zip_filename: path to SUC3.0.zip
    in_filename: which piece to read
    out_filename: where to write the result

    Returns the number of sentences extracted.
    """
    with ZipFile(zip_filename) as zin:
        with zin.open(in_filename) as fin:
            with open(out_filename, "w") as fout:
                # zin.open yields bytes; wrap to decode the utf-8 conll text
                num = extract(TextIOWrapper(fin, encoding="utf-8"), fout)
                print("Processed %d sentences from %s:%s to %s" % (num, zip_filename, in_filename, out_filename))
    return num
|
| 62 |
+
|
| 63 |
+
def process_suc3(zip_filename, short_name, out_dir):
    """Extract the train/dev/test conll shards of SUC3 into <out_dir>/<short_name>.<shard>.bio."""
    extract_from_zip(zip_filename, "SUC3.0/corpus/conll/suc-train.conll", "%s/%s.train.bio" % (out_dir, short_name))
    extract_from_zip(zip_filename, "SUC3.0/corpus/conll/suc-dev.conll", "%s/%s.dev.bio" % (out_dir, short_name))
    extract_from_zip(zip_filename, "SUC3.0/corpus/conll/suc-test.conll", "%s/%s.test.bio" % (out_dir, short_name))
|
| 67 |
+
|
| 68 |
+
def main():
    """Process the SUC3 zip from its expected location into data/ner."""
    # process_suc3 takes (zip_filename, short_name, out_dir); the previous
    # call passed only two arguments and raised a TypeError
    process_suc3("extern_data/ner/sv_suc3/SUC3.0.zip", "sv_suc3", "data/ner")

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/pos/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/datasets/pos/convert_trees_to_pos.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Turns a constituency treebank into a POS dataset with the tags as the upos column
|
| 3 |
+
|
| 4 |
+
The constituency treebank first has to be converted from the original
|
| 5 |
+
data to PTB style trees. This script converts trees from the
|
| 6 |
+
CONSTITUENCY_DATA_DIR folder to a conllu dataset in the POS_DATA_DIR folder.
|
| 7 |
+
|
| 8 |
+
Note that this doesn't pay any attention to whether or not the tags actually are upos.
|
| 9 |
+
Also not possible: using this for tokenization.
|
| 10 |
+
|
| 11 |
+
TODO: upgrade the POS model to handle xpos datasets with no upos, then make upos/xpos an option here
|
| 12 |
+
|
| 13 |
+
To run this:
|
| 14 |
+
python3 stanza/utils/training/run_pos.py vi_vlsp22
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import os
|
| 20 |
+
import shutil
|
| 21 |
+
import sys
|
| 22 |
+
|
| 23 |
+
from stanza.models.constituency import tree_reader
|
| 24 |
+
import stanza.utils.default_paths as default_paths
|
| 25 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 26 |
+
|
| 27 |
+
tqdm = get_tqdm()
|
| 28 |
+
|
| 29 |
+
SHARDS = ("train", "dev", "test")
|
| 30 |
+
|
| 31 |
+
def convert_file(in_file, out_file, upos):
    """Convert one PTB-style tree file into a conllu file with the tags as POS.

    in_file: tree file readable by tree_reader
    out_file: conllu output path (overwritten)
    upos: if True the preterminal tag goes in the UPOS column, else in XPOS

    The dependency columns are filled with a fake linear tree (word 1 on
    root, every other word on its predecessor) since only POS matters here.
    """
    print("Reading %s" % in_file)
    trees = tree_reader.read_tree_file(in_file)
    print("Writing %s" % out_file)
    with open(out_file, "w") as fout:
        for tree in tqdm(trees):
            tree = tree.simplify_labels()
            # the sentence text is the concatenation of the leaves
            text = " ".join(tree.leaf_labels())
            fout.write("# text = %s\n" % text)

            for pt_idx, pt in enumerate(tree.yield_preterminals()):
                # word index
                fout.write("%d\t" % (pt_idx+1))
                # word
                fout.write("%s\t" % pt.children[0].label)
                # don't know the lemma
                fout.write("_\t")
                # always put the tag, whatever it is, in the upos (for now)
                if upos:
                    fout.write("%s\t_\t" % pt.label)
                else:
                    fout.write("_\t%s\t" % pt.label)
                # don't have any features
                fout.write("_\t")
                # so word 0 fake dep on root, everyone else fake dep on previous word
                fout.write("%d\t" % pt_idx)
                if pt_idx == 0:
                    fout.write("root")
                else:
                    fout.write("dep")
                fout.write("\t_\t_\n")
            fout.write("\n")
|
| 63 |
+
|
| 64 |
+
def convert_treebank(short_name, upos, output_name, paths):
    """Convert the train/dev/test .mrg files of one dataset to POS conllu files.

    short_name: dataset id such as vi_vlsp22; input shards are expected
      at CONSTITUENCY_DATA_DIR/<short_name>_<shard>.mrg
    upos: if True, tags are written to the UPOS column, otherwise XPOS
    output_name: basename for the output files; defaults to short_name
    paths: dict providing CONSTITUENCY_DATA_DIR and POS_DATA_DIR

    Raises FileNotFoundError if any expected input shard is missing.
    """
    in_dir = paths["CONSTITUENCY_DATA_DIR"]
    in_files = [os.path.join(in_dir, "%s_%s.mrg" % (short_name, shard)) for shard in SHARDS]
    # check all inputs up front so we fail before writing anything
    for in_file in in_files:
        if not os.path.exists(in_file):
            raise FileNotFoundError("Cannot find expected datafile %s" % in_file)

    out_dir = paths["POS_DATA_DIR"]
    # exist_ok=True: avoids a crash/race when the directory already exists
    # or is created concurrently (the previous exists()+makedirs() pattern
    # could raise FileExistsError)
    os.makedirs(out_dir, exist_ok=True)
    if output_name is None:
        output_name = short_name
    out_files = [os.path.join(out_dir, "%s.%s.in.conllu" % (output_name, shard)) for shard in SHARDS]
    gold_files = [os.path.join(out_dir, "%s.%s.gold.conllu" % (output_name, shard)) for shard in SHARDS]

    for in_file, out_file in zip(in_files, out_files):
        convert_file(in_file, out_file, upos)
    # the .in files already carry the gold tags, so gold is a straight copy
    for out_file, gold_file in zip(out_files, gold_files):
        shutil.copy2(out_file, gold_file)
|
| 83 |
+
|
| 84 |
+
if __name__ == '__main__':
    # Command-line driver: convert one constituency dataset to POS data files.
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", help="Which dataset to process from trees to POS")
    # --upos and --xpos are two spellings of a single boolean flag:
    # --upos sets args.upos True, --xpos resets it to False (the default)
    parser.add_argument("--upos", action="store_true", default=False, help="Store tags on the UPOS")
    parser.add_argument("--xpos", dest="upos", action="store_false", help="Store tags on the XPOS")
    parser.add_argument("--output_name", default=None, help="What name to give the output dataset. If blank, will use the dataset arg")
    args = parser.parse_args()

    # standard stanza data directory layout (CONSTITUENCY_DATA_DIR, POS_DATA_DIR, ...)
    paths = default_paths.get_default_paths()

    convert_treebank(args.dataset, args.upos, args.output_name, paths)
|
stanza/stanza/utils/datasets/prepare_tokenizer_data.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
from collections import Counter
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
Data is output in 4 files:
|
| 11 |
+
|
| 12 |
+
a file containing the mwt information
|
| 13 |
+
a file containing the words and sentences in conllu format
|
| 14 |
+
a file containing the raw text of each paragraph
|
| 15 |
+
a file of 0,1,2 indicating word break or sentence break on a character level for the raw text
|
| 16 |
+
1: end of word
|
| 17 |
+
2: end of sentence
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
PARAGRAPH_BREAK = re.compile(r'\n\s*\n')

def is_para_break(index, text):
    """Check for a paragraph break (blank line) starting at index.

    Returns a (found, length) pair, where length is the number of
    characters in the break sequence, or 0 when no break is present.
    """
    if text[index] != '\n':
        return False, 0
    match = PARAGRAPH_BREAK.match(text, index)
    if match is None:
        return False, 0
    return True, len(match.group(0))

def find_next_word(index, text, word, output):
    """Consume raw text starting at index until word has been matched.

    Paragraph breaks encountered along the way are emitted to output as
    a blank line.  Returns the new text index and the raw characters
    consumed (which may include leading whitespace).
    """
    word_pos = 0
    consumed = ''
    text_len = len(text)
    word_len = len(word)
    while index < text_len and word_pos < word_len:
        at_break, break_len = is_para_break(index, text)
        if at_break:
            # multiple newlines found, paragraph break
            if consumed:
                assert re.match(r'^\s+$', consumed), 'Found non-empty string at the end of a paragraph that doesn\'t match any token: |{}|'.format(consumed)
                consumed = ''

            output.write('\n\n')
            # -1 because the shared index += 1 below finishes the skip
            index += break_len - 1
        else:
            ch = text[index]
            consumed += ch
            if re.match(r'^\s$', ch) and not re.match(r'^\s$', word[word_pos]):
                # inter-token whitespace: swallow it without advancing in word
                pass
            else:
                # non-whitespace char, or a whitespace char that's part of the word
                assert ch.replace('\n', ' ') == word[word_pos], "Character mismatch: raw text contains |%s| but the next word is |%s|." % (consumed, word)
                word_pos += 1
        index += 1
    return index, consumed
|
| 57 |
+
|
| 58 |
+
def main(args):
    """Align a plaintext file against its CoNLL-U file and emit tokenizer labels.

    For each character of the raw text that belongs to a token, one label
    character is written: '0' inside a token, '1' at the end of a token
    ('3' at the end of a multi-word token's surface form).  At a sentence
    boundary the pending final label is bumped by one (1 -> 2, 3 -> 4).
    MWT expansions are collected and written as JSON (or printed when no
    --mwt_output is given).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('plaintext_file', type=str, help="Plaintext file containing the raw input")
    parser.add_argument('conllu_file', type=str, help="CoNLL-U file containing tokens and sentence breaks")
    parser.add_argument('-o', '--output', default=None, type=str, help="Output file name; output to the console if not specified (the default)")
    parser.add_argument('-m', '--mwt_output', default=None, type=str, help="Output file name for MWT expansions; output to the console if not specified (the default)")

    args = parser.parse_args(args=args)

    with open(args.plaintext_file, 'r', encoding='utf-8') as f:
        text = ''.join(f.readlines())
        textlen = len(text)  # NOTE(review): computed but not otherwise used here

    if args.output is None:
        output = sys.stdout
    else:
        outdir = os.path.split(args.output)[0]
        os.makedirs(outdir, exist_ok=True)
        output = open(args.output, 'w')

    index = 0 # character offset in rawtext

    mwt_expansions = []
    with open(args.conllu_file, 'r', encoding='utf-8') as f:
        buf = ''             # labels for the most recent token, flushed lazily so
                             # a sentence break can bump the final label
        mwtbegin = 0         # token-id range of the MWT currently being read
        mwtend = -1
        expanded = []        # word forms inside the current MWT
        last_comments = ""   # first comment line of the current sentence (diagnostics)
        for line in f:
            line = line.strip()
            if len(line):
                if line[0] == "#":
                    # comment, don't do anything
                    if len(last_comments) == 0:
                        last_comments = line
                    continue

                line = line.split('\t')
                if '.' in line[0]:
                    # the tokenizer doesn't deal with ellipsis
                    continue

                word = line[1]
                if '-' in line[0]:
                    # multiword token
                    mwtbegin, mwtend = [int(x) for x in line[0].split('-')]
                    lastmwt = word
                    expanded = []
                elif mwtbegin <= int(line[0]) < mwtend:
                    # word inside an MWT: collect it, emit no labels of its own
                    expanded += [word]
                    continue
                elif int(line[0]) == mwtend:
                    # last word of the MWT: record the full expansion
                    expanded += [word]
                    expanded = [x.lower() for x in expanded] # evaluation doesn't care about case
                    mwt_expansions += [(lastmwt, tuple(expanded))]
                    if lastmwt[0].islower() and not expanded[0][0].islower():
                        print('Sentence ID with potential wrong MWT expansion: ', last_comments, file=sys.stderr)
                    mwtbegin = 0
                    mwtend = -1
                    lastmwt = None
                    continue

                # flush the previous token's labels, then locate this token in
                # the raw text; '3' marks the end of an MWT surface form
                if len(buf):
                    output.write(buf)
                index, word_found = find_next_word(index, text, word, output)
                buf = '0' * (len(word_found)-1) + ('1' if '-' not in line[0] else '3')
            else:
                # sentence break found
                if len(buf):
                    # bump the pending end-of-token label to end-of-sentence
                    # (1 -> 2, 3 -> 4)
                    assert int(buf[-1]) >= 1
                    output.write(buf[:-1] + '{}'.format(int(buf[-1]) + 1))
                    buf = ''

                last_comments = ''

    status_line = ""
    if args.output:
        output.close()
        status_line = 'Tokenizer labels written to %s\n ' % args.output

    mwts = Counter(mwt_expansions)
    if args.mwt_output is None:
        print('MWTs:', mwts)
    else:
        with open(args.mwt_output, 'w') as f:
            json.dump(list(mwts.items()), f, indent=2)

        status_line = status_line + '{} unique MWTs found in data. MWTs written to {}'.format(len(mwts), args.mwt_output)
    print(status_line)
|
stanza/stanza/utils/datasets/prepare_tokenizer_treebank.py
ADDED
|
@@ -0,0 +1,1396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prepares train, dev, test for a treebank
|
| 3 |
+
|
| 4 |
+
For example, do
|
| 5 |
+
python -m stanza.utils.datasets.prepare_tokenizer_treebank TREEBANK
|
| 6 |
+
such as
|
| 7 |
+
python -m stanza.utils.datasets.prepare_tokenizer_treebank UD_English-EWT
|
| 8 |
+
|
| 9 |
+
and it will prepare each of train, dev, test
|
| 10 |
+
|
| 11 |
+
There are macros for preparing all of the UD treebanks at once:
|
| 12 |
+
python -m stanza.utils.datasets.prepare_tokenizer_treebank ud_all
|
| 13 |
+
python -m stanza.utils.datasets.prepare_tokenizer_treebank all_ud
|
| 14 |
+
Both are present because I kept forgetting which was the correct one
|
| 15 |
+
|
| 16 |
+
There are a few special case handlings of treebanks in this file:
|
| 17 |
+
- all Vietnamese treebanks have special post-processing to handle
|
| 18 |
+
some of the difficult spacing issues in Vietnamese text
|
| 19 |
+
- treebanks with train and test but no dev split have the
|
| 20 |
+
train data randomly split into two pieces
|
| 21 |
+
- however, instead of splitting very tiny treebanks, we skip those
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import glob
|
| 26 |
+
import io
|
| 27 |
+
import os
|
| 28 |
+
import random
|
| 29 |
+
import re
|
| 30 |
+
import tempfile
|
| 31 |
+
import zipfile
|
| 32 |
+
|
| 33 |
+
from collections import Counter
|
| 34 |
+
|
| 35 |
+
from stanza.models.common.constant import treebank_to_short_name
|
| 36 |
+
import stanza.utils.datasets.common as common
|
| 37 |
+
from stanza.utils.datasets.common import read_sentences_from_conllu, write_sentences_to_conllu, write_sentences_to_file, INT_RE, MWT_RE, MWT_OR_COPY_RE
|
| 38 |
+
import stanza.utils.datasets.tokenization.convert_ml_cochin as convert_ml_cochin
|
| 39 |
+
import stanza.utils.datasets.tokenization.convert_my_alt as convert_my_alt
|
| 40 |
+
import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp
|
| 41 |
+
import stanza.utils.datasets.tokenization.convert_th_best as convert_th_best
|
| 42 |
+
import stanza.utils.datasets.tokenization.convert_th_lst20 as convert_th_lst20
|
| 43 |
+
import stanza.utils.datasets.tokenization.convert_th_orchid as convert_th_orchid
|
| 44 |
+
|
| 45 |
+
def copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name):
    """Copy one conllu file by round-tripping it through the sentence reader.

    Reading and rewriting (instead of shutil.copyfile) leaves a hook for
    any per-sentence manipulation that may be needed in the future, for
    example adding fake dependencies (TODO: still needed?).
    """
    source_path = f"{tokenizer_dir}/{short_name}.{tokenizer_file}.conllu"
    dest_path = f"{dest_dir}/{short_name}.{dest_file}.conllu"

    print("Copying from %s to %s" % (source_path, dest_path))
    write_sentences_to_conllu(dest_path, read_sentences_from_conllu(source_path))
|
| 54 |
+
|
| 55 |
+
def copy_conllu_treebank(treebank, model_type, paths, dest_dir, postprocess=None, augment=True):
    """
    This utility method copies only the conllu files to the given destination directory.

    Both POS, lemma, and depparse annotators need this.

    The tokenization pipeline is run in a temporary directory; only the
    resulting .gold conllu files are copied (optionally via a custom
    postprocess callable with the same signature as copy_conllu_file).
    """
    os.makedirs(dest_dir, exist_ok=True)

    short_name = treebank_to_short_name(treebank)
    short_language = short_name.split("_")[0]  # NOTE(review): currently unused here

    with tempfile.TemporaryDirectory() as tokenizer_dir:
        # redirect the tokenizer output into the temp dir without
        # mutating the caller's paths dict
        paths = dict(paths)
        paths["TOKENIZE_DATA_DIR"] = tokenizer_dir

        # first we process the tokenization data
        args = argparse.Namespace()
        args.augment = augment
        args.prepare_labels = False
        process_treebank(treebank, model_type, paths, args)

        os.makedirs(dest_dir, exist_ok=True)

        if postprocess is None:
            postprocess = copy_conllu_file

        # now we copy the processed conllu data files
        postprocess(tokenizer_dir, "train.gold", dest_dir, "train.in", short_name)
        postprocess(tokenizer_dir, "dev.gold", dest_dir, "dev.in", short_name)
        postprocess(tokenizer_dir, "test.gold", dest_dir, "test.in", short_name)
        # for non-POS/depparse models the postprocessed .in files also
        # serve as the gold dev/test files
        if model_type is not common.ModelType.POS and model_type is not common.ModelType.DEPPARSE:
            copy_conllu_file(dest_dir, "dev.in", dest_dir, "dev.gold", short_name)
            copy_conllu_file(dest_dir, "test.in", dest_dir, "test.gold", short_name)
|
| 88 |
+
|
| 89 |
+
def split_train_file(treebank, train_input_conllu, train_output_conllu, dev_output_conllu):
    """Randomly split a train conllu file into train and dev pieces.

    Used for treebanks which ship train and test data but no dev split.
    The dev fraction is XV_RATIO (a module-level constant defined
    elsewhere in this file).  Returns True on success; asserts that the
    split yields at least one dev sentence.
    """
    # set the seed for each data file so that the results are the same
    # regardless of how many treebanks are processed at once
    random.seed(1234)

    # read and shuffle conllu data
    sents = read_sentences_from_conllu(train_input_conllu)
    random.shuffle(sents)
    n_dev = int(len(sents) * XV_RATIO)
    assert n_dev >= 1, "Dev sentence number less than one."
    n_train = len(sents) - n_dev

    # split conllu data
    dev_sents = sents[:n_dev]
    train_sents = sents[n_dev:]
    print("Train/dev split not present. Randomly splitting train file from %s to %s and %s" % (train_input_conllu, train_output_conllu, dev_output_conllu))
    print(f"{len(sents)} total sentences found: {n_train} in train, {n_dev} in dev")

    # write conllu
    write_sentences_to_conllu(train_output_conllu, train_sents)
    write_sentences_to_conllu(dev_output_conllu, dev_sents)

    return True
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def has_space_after_no(piece):
    """Return True if a misc/notes field contains a SpaceAfter=No annotation.

    Empty/None fields and the placeholder "_" count as not annotated.
    """
    if not piece or piece == "_":
        return False
    # annotations are |-separated; require an exact tag match so that we
    # do not match substrings of other annotations
    return "SpaceAfter=No" in piece.split("|")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def remove_space_after_no(piece, fail_if_missing=True):
    """
    Removes a SpaceAfter=No annotation from a single piece of a single word.
    In other words, given a list of conll lines, first call split("\t"), then call this on the -1 column

    If fail_if_missing is set, a ValueError is raised when no
    SpaceAfter=No annotation is present; otherwise the field is
    returned unchanged.
    """
    # |SpaceAfter is in UD_Romanian-Nonstandard... seems fitting
    if piece in ("SpaceAfter=No", "|SpaceAfter=No"):
        return "_"
    if piece.startswith("SpaceAfter=No|"):
        return piece.replace("SpaceAfter=No|", "")
    if piece.find("|SpaceAfter=No") > 0:
        return piece.replace("|SpaceAfter=No", "")
    if fail_if_missing:
        raise ValueError("Could not find SpaceAfter=No in the given notes field")
    return piece
|
| 138 |
+
|
| 139 |
+
def add_space_after_no(piece, fail_if_found=True):
    """Append a SpaceAfter=No annotation to a misc/notes field.

    A placeholder "_" becomes just "SpaceAfter=No".  If fail_if_found is
    set and the annotation is already present, a ValueError is raised.
    """
    if piece == '_':
        return "SpaceAfter=No"
    if fail_if_found and has_space_after_no(piece):
        raise ValueError("Given notes field already contained SpaceAfter=No")
    return piece + "|SpaceAfter=No"
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def augment_arabic_padt(sents, ratio=0.05):
    """
    Basic Arabic tokenizer gets the trailing punctuation wrong if there is a blank space.

    Reason seems to be that there are almost no examples of "text ." in the dataset.
    This function augments the Arabic-PADT dataset with a few such examples.
    TODO: it may very well be that a lot of tokenizers have this problem.

    Also, there are a few examples in UD2.7 which are apparently
    headlines where there is a ' . ' in the middle of the text.
    According to an Arabic speaking labmate, the sentences are
    headlines which could be reasonably split into two items. Having
    them as one item is quite confusing and possibly incorrect, but
    such is life.

    Returns the original sentences plus the augmented copies.
    """
    new_sents = []
    for sentence in sents:
        if len(sentence) < 4:
            raise ValueError("Read a surprisingly short sentence")
        # locate the "# text" comment line; its position depends on which
        # other comment lines (# newdoc / # newpar / # sent_id) precede it
        text_line = None
        if sentence[0].startswith("# newdoc") and sentence[3].startswith("# text"):
            text_line = 3
        elif sentence[0].startswith("# newpar") and sentence[2].startswith("# text"):
            text_line = 2
        elif sentence[0].startswith("# sent_id") and sentence[1].startswith("# text"):
            text_line = 1
        else:
            raise ValueError("Could not find text line in %s" % sentence[0].split()[-1])

        # for some reason performance starts dropping quickly at higher numbers
        if random.random() > ratio:
            continue

        # only augment sentences which end in a single punctuation token
        # glued to the previous word (SpaceAfter=No on the second-to-last row)
        if (sentence[text_line][-1] in ('.', '؟', '?', '!') and
            sentence[text_line][-2] not in ('.', '؟', '?', '!', ' ') and
            has_space_after_no(sentence[-2].split()[-1]) and
            len(sentence[-1].split()[1]) == 1):
            new_sent = list(sentence)
            # insert a space before the final punctuation in the text line...
            new_sent[text_line] = new_sent[text_line][:-1] + ' ' + new_sent[text_line][-1]
            # ...and drop SpaceAfter=No from the second-to-last row to match
            pieces = sentence[-2].split("\t")
            pieces[-1] = remove_space_after_no(pieces[-1])
            new_sent[-2] = "\t".join(pieces)
            assert new_sent != sentence
            new_sents.append(new_sent)
    return sents + new_sents
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def augment_telugu(sents, ratio=0.1):
    """
    Add a few sentences with modified punctuation to Telugu_MTG

    The Telugu-MTG dataset has punctuation separated from the text in
    almost all cases, which makes the tokenizer not learn how to
    process that correctly.

    All of the Telugu sentences end with their sentence final
    punctuation being separated. Furthermore, all commas are
    separated. We change that on some subset of the sentences to
    make the tools more generalizable on wild text.

    ratio: fraction of eligible sentences to augment for each of the
      two transformations (previously a hard-coded 0.1); backward
      compatible default.

    Returns the original sentences plus the augmented copies.
    """
    new_sents = []
    for sentence in sents:
        if not sentence[1].startswith("# text"):
            raise ValueError("Expected the second line of %s to start with # text" % sentence[0])
        if not sentence[2].startswith("# translit"):
            # fixed message: this checks the THIRD line, not the second
            raise ValueError("Expected the third line of %s to start with # translit" % sentence[0])
        if sentence[1].endswith(". . .") or sentence[1][-1] not in ('.', '?', '!'):
            continue
        if sentence[1][-1] in ('.', '?', '!') and sentence[1][-2] != ' ' and sentence[1][-3:] != ' ..' and sentence[1][-4:] != ' ...':
            raise ValueError("Sentence %s does not end with space-punctuation, which is against our assumptions for the te_mtg treebank. Please check the augment method to see if it is still needed" % sentence[0])
        if random.random() < ratio:
            # glue the final punctuation onto the preceding word in both
            # the text and transliteration lines, and mark SpaceAfter=No
            # on the second-to-last token row
            new_sentence = list(sentence)
            new_sentence[1] = new_sentence[1][:-2] + new_sentence[1][-1]
            new_sentence[2] = new_sentence[2][:-2] + new_sentence[2][-1]
            new_sentence[-2] = new_sentence[-2] + "|SpaceAfter=No"
            new_sents.append(new_sentence)
        if sentence[1].find(",") > 1 and random.random() < ratio:
            # squish the first ", " into "," in text and transliteration
            # NOTE(review): the comma index is recomputed from sentence[1]
            # (the text line) and applied to sentence[2] (the translit
            # line) — this assumes the comma sits at the same offset in
            # both lines; verify against the data
            new_sentence = list(sentence)
            index = sentence[1].find(",")
            new_sentence[1] = sentence[1][:index-1] + sentence[1][index:]
            index = sentence[1].find(",")
            new_sentence[2] = sentence[2][:index-1] + sentence[2][index:]
            # mark SpaceAfter=No on the token preceding the first comma
            for idx, word in enumerate(new_sentence):
                if idx < 4:
                    # skip sent_id, text, transliteration, and the first word
                    continue
                if word.split("\t")[1] == ',':
                    new_sentence[idx-1] = new_sentence[idx-1] + "|SpaceAfter=No"
                    break
            new_sents.append(new_sentence)
    return sents + new_sents
|
| 240 |
+
|
| 241 |
+
# matches " word, word " with plain ascii words, the pattern we squish
COMMA_SEPARATED_RE = re.compile(" ([a-zA-Z]+)[,] ([a-zA-Z]+) ")
def augment_comma_separations(sents, ratio=0.03):
    """Find some fraction of the sentences which match "asdf, zzzz" and squish them to "asdf,zzzz"

    This leaves the tokens and all of the other data the same. The
    only change made is to change SpaceAfter=No for the "," token and
    adjust the #text line, with the assumption that the conllu->txt
    conversion will correctly handle this change.

    This was particularly an issue for Spanish-AnCora, but it's
    reasonable to think it could happen to any dataset. Currently
    this just operates on commas and ascii letters to avoid
    accidentally squishing anything that shouldn't be squished.

    UD_Spanish-AnCora 2.7 had a problem is with this sentence:
    # orig_file_sentence 143#5
    In this sentence, there was a comma smashed next to a token.

    Fixing just this one sentence is not sufficient to tokenize
    "asdf,zzzz" as desired, so we also augment by some fraction where
    we have squished "asdf, zzzz" into "asdf,zzzz".

    This exact example was later fixed in UD 2.8, but it should still
    potentially be useful for compensating for typos.

    Returns the original sentences plus the augmented copies.
    """
    new_sents = []
    for sentence in sents:
        for text_idx, text_line in enumerate(sentence):
            # look for the line that starts with "# text".
            # keep going until we find it, or silently ignore it
            # if the dataset isn't in that format
            if text_line.startswith("# text"):
                break
        else:
            # for/else: no "# text" line found, skip this sentence
            continue

        match = COMMA_SEPARATED_RE.search(sentence[text_idx])
        if match and random.random() < ratio:
            # locate the token row of the word before the comma, verifying
            # that the comma and the following word are the next two rows
            for idx, word in enumerate(sentence):
                if word.startswith("#"):
                    continue
                # find() doesn't work because we wind up finding substrings
                if word.split("\t")[1] != match.group(1):
                    continue
                if sentence[idx+1].split("\t")[1] != ',':
                    continue
                if sentence[idx+2].split("\t")[1] != match.group(2):
                    continue
                break
            # idx retains its value from the loop above
            if idx == len(sentence) - 1:
                # this can happen with MWTs. we may actually just
                # want to skip MWTs anyway, so no big deal
                continue
            # now idx+1 should be the line with the comma in it
            comma = sentence[idx+1]
            pieces = comma.split("\t")
            assert pieces[1] == ','
            pieces[-1] = add_space_after_no(pieces[-1])
            comma = "\t".join(pieces)
            new_sent = sentence[:idx+1] + [comma] + sentence[idx+2:]

            # rewrite the "# text" line with the space before the comma removed
            text_offset = sentence[text_idx].find(match.group(1) + ", " + match.group(2))
            text_len = len(match.group(1) + ", " + match.group(2))
            new_text = sentence[text_idx][:text_offset] + match.group(1) + "," + match.group(2) + sentence[text_idx][text_offset+text_len:]
            new_sent[text_idx] = new_text

            new_sents.append(new_sent)

    print("Added %d new sentences with asdf, zzzz -> asdf,zzzz" % len(new_sents))

    return sents + new_sents
|
| 312 |
+
|
| 313 |
+
def augment_move_comma(sents, ratio=0.02):
    """
    Move the comma from after a word to before the next word some fraction of the time

    We look for this exact pattern:
      w1, w2
    and replace it with
      w1 ,w2

    The idea is that this is a relatively common typo, but the tool
    won't learn how to tokenize it without some help.

    Note that this modification replaces the original text.

    sents: list of sentences, each a list of conllu lines (comment and token rows)
    ratio: fraction of sentences considered for the transformation
    Returns a new list of sentences, same length as the input.
    """
    new_sents = []
    num_operations = 0
    for sentence in sents:
        # most sentences pass through unchanged
        if random.random() > ratio:
            new_sents.append(sentence)
            continue

        # scan for a ',' token that has a space after it but whose
        # previous word does NOT have a space after it
        found = False
        for word_idx, word in enumerate(sentence):
            if word.startswith("#"):
                continue
            # skip the first token and the last two tokens so that
            # word_idx-1 and word_idx+1 are always valid token rows
            if word_idx == 0 or word_idx >= len(sentence) - 2:
                continue
            pieces = word.split("\t")
            if pieces[1] == ',' and not has_space_after_no(pieces[-1]):
                # found a comma with a space after it
                prev_word = sentence[word_idx-1]
                if not has_space_after_no(prev_word.split("\t")[-1]):
                    # unfortunately, the previous word also had a
                    # space after it. does not fit what we are
                    # looking for
                    continue
                # also, want to skip instances near MWT or copy nodes,
                # since those are harder to rearrange
                next_word = sentence[word_idx+1]
                if MWT_OR_COPY_RE.match(next_word.split("\t")[0]):
                    continue
                if MWT_OR_COPY_RE.match(prev_word.split("\t")[0]):
                    continue
                # at this point, the previous word has no space and the comma does
                found = True
                break

        if not found:
            new_sents.append(sentence)
            continue

        # NOTE: word_idx (from the loop above) still points at the comma row here
        new_sentence = list(sentence)

        # give the comma a SpaceAfter=No ("w1 ,w2" - comma attaches to w2)
        pieces = new_sentence[word_idx].split("\t")
        pieces[-1] = add_space_after_no(pieces[-1])
        new_sentence[word_idx] = "\t".join(pieces)

        # and remove SpaceAfter=No from the previous word ("w1 ,")
        pieces = new_sentence[word_idx-1].split("\t")
        prev_word = pieces[1]
        pieces[-1] = remove_space_after_no(pieces[-1])
        new_sentence[word_idx-1] = "\t".join(pieces)

        next_word = new_sentence[word_idx+1].split("\t")[1]

        for text_idx, text_line in enumerate(sentence):
            # look for the line that starts with "# text".
            # keep going until we find it, or silently ignore it
            # if the dataset isn't in that format
            if text_line.startswith("# text"):
                old_chunk = prev_word + ", " + next_word
                new_chunk = prev_word + " ," + next_word
                # NOTE: word_idx is reused here as a string offset into the
                # text line, not a token index - it is no longer needed above
                word_idx = text_line.find(old_chunk)
                if word_idx < 0:
                    raise RuntimeError("Unexpected #text line which did not contain the original text to be modified. Looking for\n" + old_chunk + "\n" + text_line)
                new_text_line = text_line[:word_idx] + new_chunk + text_line[word_idx+len(old_chunk):]
                new_sentence[text_idx] = new_text_line
                break

        new_sents.append(new_sentence)
        num_operations = num_operations + 1

    print("Swapped 'w1, w2' for 'w1 ,w2' %d times" % num_operations)
    return new_sents
|
| 396 |
+
|
| 397 |
+
def augment_apos(sents, ratio=0.05):
    """
    Replace ascii apostrophes with unicode apostrophes in a fraction of sentences.

    If there are no instances of ’ in the dataset, but there are instances of ',
    we replace some fraction of ' with ’ so that the tokenizer will recognize it.
    If the dataset already contains ’, or contains no ' at all, the input list
    is returned unchanged.

    sents: list of sentences, each a list of conllu lines (comment and token rows)
    ratio: fraction of sentences to convert (previously hard-coded at 0.05)
    Raises AssertionError on an empty sentence and ValueError when a sentence
    has no "# text" comment line.

    # TODO: we could do it the other way around as well
    """
    has_unicode_apos = False
    has_ascii_apos = False
    for sent_idx, sent in enumerate(sents):
        if len(sent) == 0:
            raise AssertionError("Got a blank sentence in position %d!" % sent_idx)
        for line in sent:
            if line.startswith("# text"):
                if line.find("'") >= 0:
                    has_ascii_apos = True
                if line.find("’") >= 0:
                    has_unicode_apos = True
                break
        else:
            raise ValueError("Cannot find '# text' in sentences %d. First line: %s" % (sent_idx, sent[0]))

    # nothing to do: either ’ already occurs or there is no ' at all
    if has_unicode_apos or not has_ascii_apos:
        return sents

    new_sents = []
    for sent in sents:
        if random.random() > ratio:
            new_sents.append(sent)
            continue
        new_sent = []
        for line in sent:
            if line.startswith("# text"):
                new_sent.append(line.replace("'", "’"))
            elif line.startswith("#"):
                # other comments pass through untouched
                new_sent.append(line)
            else:
                pieces = line.split("\t")
                # only the word form (column 1) is changed; the lemma keeps '
                pieces[1] = pieces[1].replace("'", "’")
                new_sent.append("\t".join(pieces))
        new_sents.append(new_sent)

    return new_sents
|
| 441 |
+
|
| 442 |
+
def augment_ellipses(sents, ratio=0.1):
    """
    Replaces a fraction of '...' tokens with the unicode ellipsis '…'

    Only triggers when the dataset contains '...' tokens and no '…' tokens;
    otherwise the input list is returned unchanged.

    sents: list of sentences, each a list of conllu lines (comment and token rows)
    ratio: fraction of sentences to convert (previously hard-coded at 0.1)

    NOTE(review): unlike augment_apos, the "# text" comment line is NOT
    updated to match the new token - confirm whether that mismatch is
    intentional for downstream consumers.
    """
    has_ellipses = False
    has_unicode_ellipses = False
    for sent in sents:
        for line in sent:
            if line.startswith("#"):
                continue
            pieces = line.split("\t")
            if pieces[1] == '...':
                has_ellipses = True
            elif pieces[1] == '…':
                has_unicode_ellipses = True

    if has_unicode_ellipses or not has_ellipses:
        return sents

    new_sents = []

    num_updated = 0
    for sent in sents:
        if random.random() > ratio:
            new_sents.append(sent)
            continue
        found = False
        new_sent = []
        for line in sent:
            if line.startswith("#"):
                new_sent.append(line)
            else:
                pieces = line.split("\t")
                if pieces[1] == '...':
                    pieces[1] = '…'
                    found = True
                new_sent.append("\t".join(pieces))
        new_sents.append(new_sent)
        if found:
            num_updated = num_updated + 1

    print("Changed %d sentences to use fancy unicode ellipses" % num_updated)
    return new_sents
|
| 485 |
+
|
| 486 |
+
# https://en.wikipedia.org/wiki/Quotation_mark
# all quote characters that may appear as tokens in a treebank
QUOTES = ['"', '“', '”', '«', '»', '「', '」', '《', '》', '„', '″']
# matches: (optional char) quote (content) quote (optional char)
# used to rewrite the "# text" line when swapping quote styles
QUOTES_RE = re.compile("(.?)[" + "".join(QUOTES) + "](.+)[" + "".join(QUOTES) + "](.?)")
# Danish does '«' the other way around from most European languages
# entries at the same index in these two lists form an open/close pair;
# note '„' appears twice, paired with both '”' and '“'
START_QUOTES = ['"', '“', '”', '«', '»', '「', '《', '„', '„', '″']
END_QUOTES = ['"', '“', '”', '»', '«', '」', '》', '”', '“', '″']
|
| 492 |
+
|
| 493 |
+
def augment_quotes(sents, ratio=0.15):
    """
    Go through the sentences and replace a fraction of sentences with alternate quotes

    Only sentences containing exactly two quote tokens are eligible; the first
    quote token becomes the chosen start quote and the second the end quote.

    sents: list of sentences, each a list of conllu lines (comment and token rows)
    ratio: fraction of sentences considered for the swap
    Returns a new list of sentences, same length as the input.

    TODO: for certain languages we may want to make some language-specific changes
    eg Danish, don't add «...»
    """
    assert len(START_QUOTES) == len(END_QUOTES)

    counts = Counter()
    new_sents = []
    for sent in sents:
        if random.random() > ratio:
            new_sents.append(sent)
            continue

        # count if there are exactly 2 quotes in this sentence
        # this is for convenience - otherwise we need to figure out which pairs go together
        count_quotes = sum(1 for x in sent
                           if (not x.startswith("#") and
                               x.split("\t")[1] in QUOTES))
        if count_quotes != 2:
            new_sents.append(sent)
            continue

        # choose a pair of quotes from the candidates
        quote_idx = random.choice(range(len(START_QUOTES)))
        start_quote = START_QUOTES[quote_idx]
        end_quote = END_QUOTES[quote_idx]
        # track how many times each pair was used, for the summary print below
        counts[start_quote + end_quote] = counts[start_quote + end_quote] + 1

        new_sent = []
        saw_start = False
        for line in sent:
            if line.startswith("#"):
                new_sent.append(line)
                continue
            pieces = line.split("\t")
            if pieces[1] in QUOTES:
                if saw_start:
                    # Note that we don't change the lemma. Presumably it's
                    # set to the correct lemma for a quote for this treebank
                    pieces[1] = end_quote
                else:
                    pieces[1] = start_quote
                    saw_start = True
                new_sent.append("\t".join(pieces))
            else:
                new_sent.append(line)

        for text_idx, text_line in enumerate(new_sent):
            # look for the line that starts with "# text".
            # keep going until we find it, or silently ignore it
            # if the dataset isn't in that format
            if text_line.startswith("# text"):
                # rewrite the quoted span in the raw text to match the new tokens
                replacement = "\\1%s\\2%s\\3" % (start_quote, end_quote)
                new_text_line = QUOTES_RE.sub(replacement, text_line)
                new_sent[text_idx] = new_text_line

        new_sents.append(new_sent)

    print("Augmented {} quotes: {}".format(sum(counts.values()), counts))
    return new_sents
|
| 556 |
+
|
| 557 |
+
def find_text_idx(sentence):
    """
    Return the index of the # text line or -1
    """
    return next((position for position, comment in enumerate(sentence)
                 if comment.startswith("# text")), -1)
|
| 565 |
+
|
| 566 |
+
# used to detect extra enhanced-dependency heads in the DEPS column
DIGIT_RE = re.compile("[0-9]")

def change_indices(line, delta):
    """
    Adjust all indices in the given sentence by delta. Useful when removing a word, for example

    line: a single conllu line (comment or token row)
    delta: integer offset added to the token id, the head (column 6), and the
        enhanced dependency head (column 8)
    Returns the adjusted line; comment lines are returned untouched.
    Raises NotImplementedError for index formats or DEPS values it cannot handle.
    """
    if line.startswith("#"):
        return line

    pieces = line.split("\t")
    # MWT range lines ("3-4") only carry the range itself - shift both ends and stop
    if MWT_RE.match(pieces[0]):
        indices = pieces[0].split("-")
        pieces[0] = "%d-%d" % (int(indices[0]) + delta, int(indices[1]) + delta)
        line = "\t".join(pieces)
        return line

    if MWT_OR_COPY_RE.match(pieces[0]):
        # copy node id such as "3.1": shift the integer part, keep the suffix
        index_pieces = pieces[0].split(".", maxsplit=1)
        pieces[0] = "%d.%s" % (int(index_pieces[0]) + delta, index_pieces[1])
    elif not INT_RE.match(pieces[0]):
        raise NotImplementedError("Unknown index type: %s" % pieces[0])
    else:
        pieces[0] = str(int(pieces[0]) + delta)
    if pieces[6] != '_':
        # copy nodes don't have basic dependencies in the es_ancora treebank
        dep = int(pieces[6])
        # head 0 is the root marker, not a token index - leave it alone
        if dep != 0:
            pieces[6] = str(int(dep) + delta)
    if pieces[8] != '_':
        # DEPS (enhanced dependencies), format "head:relation"
        dep_pieces = pieces[8].split(":", maxsplit=1)
        if DIGIT_RE.search(dep_pieces[1]):
            # a digit after the first ':' suggests a second head:rel pair
            raise NotImplementedError("Need to handle multiple additional deps:\n%s" % line)
        if int(dep_pieces[0]) != 0:
            pieces[8] = str(int(dep_pieces[0]) + delta) + ":" + dep_pieces[1]
    line = "\t".join(pieces)
    return line
|
| 602 |
+
|
| 603 |
+
def augment_initial_punct(sents, ratio=0.20):
    """
    If a sentence starts with certain punct marks, occasionally use the same sentence without the initial punct.

    Currently this just handles ¿
    This helps languages such as CA and ES where the models go awry when the initial ¿ is missing.

    Returns the original sentences plus the new ¿-stripped copies appended at the end.
    """
    new_sents = []
    for sent in sents:
        if random.random() > ratio:
            continue

        text_idx = find_text_idx(sent)
        text_line = sent[text_idx]
        if text_line.count("¿") != 1:
            # only handle sentences with exactly one ¿
            continue

        # find the first line with actual text
        for idx, line in enumerate(sent):
            if line.startswith("#"):
                continue
            break
        if idx >= len(sent) - 1:
            raise ValueError("Unexpectedly an entire sentence is comments")
        # NOTE: idx/line deliberately leak out of the loop above -
        # they refer to the first token row
        pieces = line.split("\t")
        if pieces[1] != '¿':
            continue
        # choose the exact substring to delete from the raw text, depending
        # on whether the ¿ token was followed by a space
        if has_space_after_no(pieces[-1]):
            replace_text = "¿"
        else:
            replace_text = "¿ "

        # drop the ¿ token row
        new_sent = sent[:idx] + sent[idx+1:]
        new_sent[text_idx] = text_line.replace(replace_text, "")

        # now need to update all indices
        new_sent = [change_indices(x, -1) for x in new_sent]
        new_sents.append(new_sent)

    if len(new_sents) > 0:
        print("Added %d sentences with the leading ¿ removed" % len(new_sents))

    return sents + new_sents
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def augment_brackets(sents, ratio=0.1):
    """
    If there are no sentences with [], transform some () into []

    sents: list of sentences, each a list of conllu lines (comment and token rows)
    ratio: fraction of paren-containing sentences to copy with square brackets
    Returns the original sentences plus the converted copies appended at the end;
    returns the input unchanged if any sentence already contains [ or ].
    """
    new_sents = []
    # first pass: bail out entirely if square brackets already occur
    for sent in sents:
        text_idx = find_text_idx(sent)
        text_line = sent[text_idx]
        if text_line.count("[") > 0 or text_line.count("]") > 0:
            # found a square bracket, so, never mind
            return sents

    # second pass: copy a fraction of the () sentences with [] instead
    for sent in sents:
        if random.random() > ratio:
            continue

        text_idx = find_text_idx(sent)
        text_line = sent[text_idx]
        if text_line.count("(") == 0 and text_line.count(")") == 0:
            continue

        # rewrite the raw text and each ( or ) token
        text_line = text_line.replace("(", "[").replace(")", "]")
        new_sent = list(sent)
        new_sent[text_idx] = text_line
        for idx, line in enumerate(new_sent):
            if line.startswith("#"):
                continue
            pieces = line.split("\t")
            if pieces[1] == '(':
                pieces[1] = '['
            elif pieces[1] == ')':
                pieces[1] = ']'
            new_sent[idx] = "\t".join(pieces)
        new_sents.append(new_sent)

    if len(new_sents) > 0:
        print("Added %d sentences with parens replaced with square brackets" % len(new_sents))

    return sents + new_sents
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def augment_punct(sents):
    """
    Run the full chain of punctuation augmentations over the dataset.

    Applies, in order: apostrophe replacement, quote style swapping,
    comma movement, comma separation, leading-¿ removal, unicode
    ellipses, and square-bracket substitution.
    """
    pipeline = (augment_apos,
                augment_quotes,
                augment_move_comma,
                augment_comma_separations,
                augment_initial_punct,
                augment_ellipses,
                augment_brackets)
    augmented = sents
    for stage in pipeline:
        augmented = stage(augmented)
    return augmented
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
def write_augmented_dataset(input_conllu, output_conllu, augment_function):
    """
    Read input_conllu, run augment_function over its sentences, write output_conllu.
    """
    # set the seed for each data file so that the results are the same
    # regardless of how many treebanks are processed at once
    random.seed(1234)

    original_sents = read_sentences_from_conllu(input_conllu)
    # the actual meat of the function - produce new sentences
    augmented_sents = augment_function(original_sents)
    write_sentences_to_conllu(output_conllu, augmented_sents)
|
| 721 |
+
|
| 722 |
+
def remove_spaces_from_sentences(sents):
    """
    Makes sure every word in the list of sentences has SpaceAfter=No.

    sents: list of sentences, each a list of conllu lines (comment and token rows)
    Returns a new list of sentences.
    Raises ValueError when a token's misc column already has content other
    than "_" or SpaceAfter=No, since merging arbitrary misc values is not
    implemented here.
    """
    new_sents = []
    for sentence in sents:
        new_sentence = []
        for word in sentence:
            if word.startswith("#"):
                new_sentence.append(word)
                continue
            pieces = word.split("\t")
            if pieces[-1] == "_":
                pieces[-1] = "SpaceAfter=No"
            elif pieces[-1].find("SpaceAfter=No") >= 0:
                # already marked - nothing to do
                pass
            else:
                # was previously a bare ValueError("oops"), which gave the
                # caller no clue what went wrong or where
                raise ValueError("Cannot set SpaceAfter=No: unexpected misc value in line %r" % word)
            word = "\t".join(pieces)
            new_sentence.append(word)
        new_sents.append(new_sentence)
    return new_sents
|
| 746 |
+
|
| 747 |
+
def remove_spaces(input_conllu, output_conllu):
    """
    Turns a dataset into something appropriate for building a segmenter.

    For example, this works well on the Korean datasets.
    """
    converted = remove_spaces_from_sentences(read_sentences_from_conllu(input_conllu))
    write_sentences_to_conllu(output_conllu, converted)
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset, output_conllu):
    """
    Builds a combined dataset out of multiple Korean datasets.

    Currently this uses GSD and Kaist. If a segmenter-appropriate
    dataset was requested, spaces are removed.

    TODO: we need to handle the difference in xpos tags somehow.
    """
    combined = []
    for treebank in ("UD_Korean-GSD", "UD_Korean-Kaist"):
        conllu_path = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu")
        combined.extend(read_sentences_from_conllu(conllu_path))

    # a short_name ending in "_seg" requests a segmenter dataset,
    # which wants every token marked SpaceAfter=No
    if short_name.endswith("_seg"):
        combined = remove_spaces_from_sentences(combined)

    write_sentences_to_conllu(output_conllu, combined)
|
| 778 |
+
|
| 779 |
+
def build_combined_korean(udbase_dir, tokenizer_dir, short_name):
    """Build train/dev/test splits of the combined Korean dataset."""
    for split in ("train", "dev", "test"):
        target = common.tokenizer_conllu_name(tokenizer_dir, short_name, split)
        build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, split, target)
|
| 783 |
+
|
| 784 |
+
def build_combined_italian_dataset(paths, model_type, dataset):
    """
    Build the it_combined dataset from several UD Italian treebanks.

    For the train split, combines ISDT and VIT, plus TWITTIRO and
    PoSTWITA for non-tokenizer models.  Dev/test come from ISDT alone.
    """
    udbase_dir = paths["UDBASE"]
    if dataset == 'train':
        # could maybe add ParTUT, but that dataset has a slightly different xpos set
        # (no DE or I)
        # and I didn't feel like sorting through the differences
        # Note: currently these each have small changes compared with
        # the UD2.11 release. See the issues (possibly closed by now)
        # filed by AngledLuffa on each of the treebanks for more info.
        treebanks = [
            "UD_Italian-ISDT",
            "UD_Italian-VIT",
        ]
        if model_type is not common.ModelType.TOKENIZER:
            treebanks.extend([
                "UD_Italian-TWITTIRO",
                "UD_Italian-PoSTWITA"
            ])
        print("Building {} dataset out of {}".format(model_type, " ".join(treebanks)))
        sents = []
        for treebank in treebanks:
            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
            sents.extend(read_sentences_from_conllu(conllu_file))
    else:
        istd_conllu = common.find_treebank_dataset_file("UD_Italian-ISDT", udbase_dir, dataset, "conllu")
        sents = read_sentences_from_conllu(istd_conllu)

    return sents
|
| 812 |
+
|
| 813 |
+
def check_gum_ready(udbase_dir):
    """
    Raise ValueError if UD_English-GUMReddit still has its text redacted.
    """
    gum_conllu = common.find_treebank_dataset_file("UD_English-GUMReddit", udbase_dir, "train", "conllu")
    if not common.mostly_underscores(gum_conllu):
        return
    raise ValueError("Cannot process UD_English-GUMReddit in its current form. There should be a download script available in the directory which will help integrate the missing proprietary values. Please run that script to update the data, then try again.")
|
| 817 |
+
|
| 818 |
+
def build_combined_english_dataset(paths, model_type, dataset):
    """
    en_combined is currently EWT, GUM, PUD, Pronouns, and handparsed
    """
    udbase_dir = paths["UDBASE"]
    check_gum_ready(udbase_dir)

    if dataset != 'train':
        # dev/test come straight from EWT
        ewt_conllu = common.find_treebank_dataset_file("UD_English-EWT", udbase_dir, dataset, "conllu")
        return read_sentences_from_conllu(ewt_conllu)

    # TODO: include more UD treebanks, possibly with xpos removed
    #  UD_English-ParTUT - xpos are different
    # also include "external" treebanks such as PTB
    # NOTE: in order to get the best results, make sure each of these treebanks have the latest edits applied
    # (treebank, which section of it to read): the last two contribute
    # their test sections to the combined training data
    sources = [("UD_English-EWT", "train"),
               ("UD_English-GUM", "train"),
               ("UD_English-GUMReddit", "train"),
               ("UD_English-PUD", "test"),
               ("UD_English-Pronouns", "test")]
    sents = []
    for treebank, section in sources:
        conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, section, "conllu", fail=True)
        loaded = read_sentences_from_conllu(conllu_file)
        print("Read %d sentences from %s" % (len(loaded), conllu_file))
        sents.extend(loaded)
    return sents
|
| 848 |
+
|
| 849 |
+
def add_english_sentence_final_punctuation(handparsed_sentences):
    """
    Add a period to the end of a sentence with no punct at the end.

    The next-to-last word has SpaceAfter=No added as well.

    Possibly English-specific because of the xpos. Could be upgraded
    to handle multiple languages by passing in the xpos as an argument

    Returns a new list of sentences; unmodified sentences are reused as-is.
    """
    new_sents = []
    for sent in handparsed_sentences:
        root_id = None
        max_id = None
        last_punct = False
        # after this loop: root_id = id of the token attached to head 0,
        # max_id = id of the last regular token, last_punct = whether the
        # last token (ignoring MWT ranges and copy nodes) was PUNCT
        for line in sent:
            if line.startswith("#"):
                continue
            pieces = line.split("\t")
            if MWT_OR_COPY_RE.match(pieces[0]):
                continue
            if pieces[6] == '0':
                root_id = pieces[0]
            max_id = int(pieces[0])
            last_punct = pieces[3] == 'PUNCT'
        if not last_punct:
            new_sent = list(sent)
            # glue the new "." to the previous word with SpaceAfter=No
            pieces = new_sent[-1].split("\t")
            pieces[-1] = add_space_after_no(pieces[-1])
            new_sent[-1] = "\t".join(pieces)
            # new period token depends on the root word via punct
            new_sent.append("%d\t.\t.\tPUNCT\t.\t_\t%s\tpunct\t%s:punct\t_" % (max_id+1, root_id, root_id))
            new_sents.append(new_sent)
        else:
            new_sents.append(sent)
    return new_sents
|
| 883 |
+
|
| 884 |
+
def build_extra_combined_french_dataset(paths, model_type, dataset):
    """
    Extra sentences we don't want augmented for French - currently, handparsed lemmas
    """
    extra = []
    if dataset == 'train' and model_type is common.ModelType.LEMMA:
        lemmas_path = os.path.join(paths["HANDPARSED_DIR"], "french-lemmas", "fr_lemmas.conllu")
        extra = read_sentences_from_conllu(lemmas_path)
        print("Loaded %d sentences from %s" % (len(extra), lemmas_path))
    return extra
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
def build_extra_combined_english_dataset(paths, model_type, dataset):
    """
    Extra sentences we don't want augmented
    """
    if dataset != 'train':
        return []

    handparsed_dir = paths["HANDPARSED_DIR"]
    corpus_path = os.path.join(handparsed_dir, "english-handparsed", "english.conll")
    # the handparsed data needs final periods added before use
    sents = add_english_sentence_final_punctuation(read_sentences_from_conllu(corpus_path))
    print("Loaded %d sentences from %s" % (len(sents), corpus_path))

    if model_type is common.ModelType.LEMMA:
        lemmas_path = os.path.join(handparsed_dir, "english-lemmas", "en_lemmas.conllu")
        lemma_sents = read_sentences_from_conllu(lemmas_path)
        print("Loaded %d sentences from %s" % (len(lemma_sents), lemmas_path))
        sents.extend(lemma_sents)
    return sents
|
| 918 |
+
|
| 919 |
+
def build_extra_combined_italian_dataset(paths, model_type, dataset):
    """
    Extra data - the MWT data for Italian

    Only contributes to the train split; other splits get an empty list.
    """
    handparsed_dir = paths["HANDPARSED_DIR"]
    if dataset != 'train':
        return []

    extra_italian = os.path.join(handparsed_dir, "italian-mwt", "italian.mwt")
    if not os.path.exists(extra_italian):
        raise FileNotFoundError("Cannot find the extra dataset 'italian.mwt' which includes various multi-words retokenized, expected {}".format(extra_italian))

    extra_sents = read_sentences_from_conllu(extra_italian)
    # assumes each sentence's line at index 2 is the MWT range line with an
    # empty ("_") misc column - TODO confirm against the italian.mwt format
    for sentence in extra_sents:
        if not sentence[2].endswith("_") or not MWT_RE.match(sentence[2]):
            raise AssertionError("Unexpected format of the italian.mwt file. Has it already be modified to have SpaceAfter=No everywhere?")
        # replace the trailing "_" misc value with SpaceAfter=No
        sentence[2] = sentence[2][:-1] + "SpaceAfter=No"
    print("Loaded %d sentences from %s" % (len(extra_sents), extra_italian))
    return extra_sents
|
| 938 |
+
|
| 939 |
+
def replace_semicolons(sentences):
    """
    Spanish GSD and AnCora have different standards for semicolons.

    GSD has semicolons at the end of sentences, AnCora has them in the middle as clause separators.
    Consecutive sentences in GSD do not seem to be related, so there is no combining that can be done.
    The easiest solution is to replace sentence final semicolons with "." in GSD

    sentences: list of sentences, each a list of conllu lines; every sentence
        must contain a "# text" comment line or ValueError is raised.
    Returns a new list of sentences, same length as the input.
    """
    new_sents = []
    count = 0
    for sentence in sentences:
        for text_idx, text_line in enumerate(sentence):
            if text_line.startswith("# text"):
                break
        else:
            raise ValueError("Expected every sentence in GSD to have a # text field")
        if not text_line.endswith(";"):
            new_sents.append(sentence)
            continue
        # fix: count was previously incremented twice per modified sentence,
        # so the summary below reported double the true number
        count = count + 1
        new_sent = list(sentence)
        # swap the final ; for . in both the raw text and the last token row
        new_sent[text_idx] = text_line[:-1] + "."
        new_sent[-1] = new_sent[-1].replace(";", ".")
        new_sents.append(new_sent)
    print("Updated %d sentences to replace sentence-final ; with ." % count)
    return new_sents
|
| 966 |
+
|
| 967 |
+
def strip_column(sents, column):
    """
    Removes a specified column from the given dataset

    Particularly useful when mixing two different POS formalisms in the same tagger
    """
    def blanked(line):
        # comment lines pass through untouched
        if line.startswith("#"):
            return line
        fields = line.split("\t")
        fields[column] = "_"
        return "\t".join(fields)

    return [[blanked(line) for line in sentence] for sentence in sents]
|
| 985 |
+
|
| 986 |
+
def strip_xpos(sents):
    """
    Removes all xpos from the given dataset

    Particularly useful when mixing two different POS formalisms in the same tagger
    """
    # xpos lives in column 4 of a conllu token row
    return strip_column(sents, column=4)
|
| 993 |
+
|
| 994 |
+
def strip_feats(sents):
    """
    Removes all features from the given dataset

    Particularly useful when mixing two different POS formalisms in the same tagger
    """
    # feats lives in column 5 of a conllu token row
    return strip_column(sents, column=5)
|
| 1001 |
+
|
| 1002 |
+
def build_combined_albanian_dataset(paths, model_type, dataset):
    """
    sq_combined is STAF as the base, with TSA added for some things

    Returns either a dict mapping treebank name -> sentences (POS training,
    so the two tagsets can be kept in separate documents) or a flat list of
    sentences for all other cases.
    """
    udbase_dir = paths["UDBASE"]
    udbase_git_dir = paths["UDBASE_GIT"]
    # NOTE(review): handparsed_dir is read but never used below
    handparsed_dir = paths["HANDPARSED_DIR"]

    treebanks = ["UD_Albanian-STAF", "UD_Albanian-TSA"]

    if dataset == 'train' and model_type == common.ModelType.POS:
        # keep the treebanks as separate documents; TSA's xpos and feats are
        # stripped since they do not match STAF's scheme
        documents = {}

        conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, "train", "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        documents[treebanks[0]] = new_sents

        # we use udbase_git_dir for TSA because of an updated MWT scheme
        conllu_file = common.find_treebank_dataset_file(treebanks[1], udbase_git_dir, "test", "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        new_sents = strip_xpos(new_sents)
        new_sents = strip_feats(new_sents)
        documents[treebanks[1]] = new_sents

        return documents

    if dataset == 'train' and model_type is not common.ModelType.DEPPARSE:
        # non-POS, non-depparse training data: STAF train plus TSA test
        sents = []

        conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, "train", "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        print("Read %d sentences from %s" % (len(new_sents), conllu_file))
        sents.extend(new_sents)

        conllu_file = common.find_treebank_dataset_file(treebanks[1], udbase_git_dir, "test", "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        print("Read %d sentences from %s" % (len(new_sents), conllu_file))
        sents.extend(new_sents)

        return sents

    # depparse training, and all dev/test splits: STAF only
    conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, dataset, "conllu", fail=True)
    sents = read_sentences_from_conllu(conllu_file)
    return sents
|
| 1046 |
+
|
| 1047 |
+
def build_combined_spanish_dataset(paths, model_type, dataset):
    """
    es_combined is AnCora and GSD put together

    For POS training, we put the different datasets into a zip file so
    that we can keep the conllu files separate and remove the xpos
    from the non-AnCora training files.  It is necessary to remove the
    xpos because GSD and PUD both use different xpos schemes from
    AnCora, and the tagger can use additional data files as training
    data without a specific column if that column is entirely blank

    TODO: consider mixing in PUD?

    :param paths: dict of directory names, must contain UDBASE and HANDPARSED_DIR
    :param model_type: common.ModelType controlling model-specific processing
    :param dataset: one of "train", "dev", "test"
    :return: list of sentences, or dict of treebank -> sentences for POS training
    """
    udbase_dir = paths["UDBASE"]
    handparsed_dir = paths["HANDPARSED_DIR"]

    treebanks = ["UD_Spanish-AnCora", "UD_Spanish-GSD"]

    if dataset == 'train' and model_type == common.ModelType.POS:
        # return a dict so that the POS datasets are later written as
        # separate conllu files inside a zip (see build_combined_dataset)
        documents = {}
        for treebank in treebanks:
            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
            new_sents = read_sentences_from_conllu(conllu_file)
            if not treebank.endswith("AnCora"):
                # GSD uses a different xpos scheme from AnCora
                new_sents = strip_xpos(new_sents)
            documents[treebank] = new_sents

        return documents

    if dataset == 'train':
        sents = []
        for treebank in treebanks:
            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
            new_sents = read_sentences_from_conllu(conllu_file)
            print("Read %d sentences from %s" % (len(new_sents), conllu_file))
            if treebank.endswith("GSD"):
                new_sents = replace_semicolons(new_sents)
            sents.extend(new_sents)

        if model_type in (common.ModelType.TOKENIZER, common.ModelType.MWT, common.ModelType.LEMMA):
            extra_spanish = os.path.join(handparsed_dir, "spanish-mwt", "adjectives.conllu")
            if not os.path.exists(extra_spanish):
                # bug fix: the message used to be formatted with the undefined
                # name extra_italian, which raised NameError instead of the
                # intended FileNotFoundError
                raise FileNotFoundError("Cannot find the extra dataset 'handpicked.mwt' which includes various multi-words retokenized, expected {}".format(extra_spanish))
            extra_sents = read_sentences_from_conllu(extra_spanish)
            print("Read %d sentences from %s" % (len(extra_sents), extra_spanish))
            sents.extend(extra_sents)
    else:
        # dev and test come from AnCora only
        conllu_file = common.find_treebank_dataset_file("UD_Spanish-AnCora", udbase_dir, dataset, "conllu", fail=True)
        sents = read_sentences_from_conllu(conllu_file)

    return sents
|
| 1098 |
+
|
| 1099 |
+
def build_combined_french_dataset(paths, model_type, dataset):
    """
    fr_combined is GSD, ParisStories, Rhapsodie, and Sequoia, plus hand-parsed fixes.

    dev and test are taken from GSD only.

    :param paths: dict of directory names, must contain UDBASE and HANDPARSED_DIR
    :param model_type: common.ModelType (not used here, kept for the common interface)
    :param dataset: one of "train", "dev", "test"
    :return: list of sentences
    """
    udbase_dir = paths["UDBASE"]
    handparsed_dir = paths["HANDPARSED_DIR"]
    if dataset == 'train':
        train_treebanks = ["UD_French-GSD", "UD_French-ParisStories", "UD_French-Rhapsodie", "UD_French-Sequoia"]
        sents = []
        for treebank in train_treebanks:
            conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu", fail=True)
            new_sents = read_sentences_from_conllu(conllu_file)
            print("Read %d sentences from %s" % (len(new_sents), conllu_file))
            sents.extend(new_sents)

        extra_french = os.path.join(handparsed_dir, "french-handparsed", "handparsed_deps.conllu")
        if not os.path.exists(extra_french):
            # bug fix: the message used to be formatted with the undefined
            # name extra_italian, which raised NameError instead of the
            # intended FileNotFoundError
            raise FileNotFoundError("Cannot find the extra dataset 'handparsed_deps.conllu' which includes various dependency fixes, expected {}".format(extra_french))
        extra_sents = read_sentences_from_conllu(extra_french)
        print("Read %d sentences from %s" % (len(extra_sents), extra_french))
        sents.extend(extra_sents)
    else:
        gsd_conllu = common.find_treebank_dataset_file("UD_French-GSD", udbase_dir, dataset, "conllu")
        sents = read_sentences_from_conllu(gsd_conllu)

    return sents
|
| 1122 |
+
|
| 1123 |
+
def build_combined_hebrew_dataset(paths, model_type, dataset):
    """
    Combine the IAHLT treebanks with an updated HTB whose annotation style more closely matches IAHLT.

    The updated HTB is not currently in UD, so you will need to clone
    git@github.com:IAHLT/UD_Hebrew.git to $UDBASE_GIT

    dev and test sets are those from IAHLT
    """
    udbase_dir = paths["UDBASE"]
    udbase_git_dir = paths["UDBASE_GIT"]

    treebanks = ["UD_Hebrew-IAHLTwiki", "UD_Hebrew-IAHLTknesset"]

    if dataset != 'train':
        # dev and test come straight from the first IAHLT treebank
        conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, dataset, "conllu", fail=True)
        return read_sentences_from_conllu(conllu_file)

    sents = []
    for treebank in treebanks:
        conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        print("Read %d sentences from %s" % (len(new_sents), conllu_file))
        sents.extend(new_sents)

    # if/when this gets ported back to UD, switch to getting both datasets from UD
    hebrew_git_dir = os.path.join(udbase_git_dir, "UD_Hebrew")
    if not os.path.exists(hebrew_git_dir):
        raise FileNotFoundError("Please download git@github.com:IAHLT/UD_Hebrew.git to %s (based on $UDBASE_GIT)" % hebrew_git_dir)
    conllu_file = os.path.join(hebrew_git_dir, "he_htb-ud-train.conllu")
    if not os.path.exists(conllu_file):
        raise FileNotFoundError("Found %s but inexplicably there was no %s" % (hebrew_git_dir, conllu_file))
    new_sents = read_sentences_from_conllu(conllu_file)
    print("Read %d sentences from %s" % (len(new_sents), conllu_file))
    sents.extend(new_sents)

    return sents
|
| 1159 |
+
|
| 1160 |
+
# Maps each combined dataset short name to the function which collects
# its sentences; consumed by build_combined_dataset
COMBINED_FNS = {
    "en_combined": build_combined_english_dataset,
    "es_combined": build_combined_spanish_dataset,
    "fr_combined": build_combined_french_dataset,
    "he_combined": build_combined_hebrew_dataset,
    "it_combined": build_combined_italian_dataset,
    "sq_combined": build_combined_albanian_dataset,
}
|
| 1168 |
+
|
| 1169 |
+
# some extra data for the combined models without augmenting
# (applied after augment_punct in build_combined_dataset, for the
# list-of-sentences case only)
COMBINED_EXTRA_FNS = {
    "en_combined": build_extra_combined_english_dataset,
    "fr_combined": build_extra_combined_french_dataset,
    "it_combined": build_extra_combined_italian_dataset,
}
|
| 1175 |
+
|
| 1176 |
+
def build_combined_dataset(paths, short_name, model_type, augment):
    """
    Build the train/dev/test files for one of the combined datasets.

    The builder function may return either a list of sentences, written
    as a single conllu file, or a dict of treebank -> sentences, in
    which case each treebank is written as a separate conllu entry
    inside a zip file (used for POS training, where xpos columns differ
    between treebanks).
    """
    random.seed(1234)
    tokenizer_dir = paths["TOKENIZE_DATA_DIR"]
    build_fn = COMBINED_FNS[short_name]
    extra_fn = COMBINED_EXTRA_FNS.get(short_name, None)
    for dataset in ("train", "dev", "test"):
        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
        sents = build_fn(paths, model_type, dataset)
        if isinstance(sents, dict):
            # dict case: one conllu entry per treebank, inside a zip
            if dataset == 'train' and augment:
                for filename in list(sents.keys()):
                    sents[filename] = augment_punct(sents[filename])
            output_zip = os.path.splitext(output_conllu)[0] + ".zip"
            with zipfile.ZipFile(output_zip, "w") as zout:
                for filename in list(sents.keys()):
                    with zout.open(filename + ".conllu", "w") as zfout:
                        # zipfile entries are binary; wrap to write utf-8 text
                        with io.TextIOWrapper(zfout, encoding='utf-8', newline='') as fout:
                            write_sentences_to_file(fout, sents[filename])
        else:
            if dataset == 'train' and augment:
                sents = augment_punct(sents)
            if extra_fn is not None:
                sents.extend(extra_fn(paths, model_type, dataset))
            write_sentences_to_conllu(output_conllu, sents)
|
| 1200 |
+
|
| 1201 |
+
# English biomedical datasets handled by build_bio_dataset
BIO_DATASETS = ("en_craft", "en_genia", "en_mimic")
|
| 1202 |
+
|
| 1203 |
+
def build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_name, model_type, augment):
    """
    Process the en bio datasets

    Creates a dataset by combining the en_combined data with one of the bio sets

    For train, the combined English data is prepended to the bio data;
    dev and test use the bio data alone.
    """
    random.seed(1234)
    name, bio_dataset = short_name.split("_")
    assert name == 'en'
    for dataset in ("train", "dev", "test"):
        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
        if dataset == 'train':
            sents = build_combined_english_dataset(paths, model_type, dataset)
            # simplified: the old inner check repeated dataset == 'train',
            # which is always true on this branch
            if augment:
                sents = augment_punct(sents)
        else:
            sents = []
        bio_file = os.path.join(paths["BIO_UD_DIR"], "UD_English-%s" % bio_dataset.upper(), "en_%s-ud-%s.conllu" % (bio_dataset.lower(), dataset))
        sents.extend(read_sentences_from_conllu(bio_file))
        write_sentences_to_conllu(output_conllu, sents)
|
| 1223 |
+
|
| 1224 |
+
def build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, dataset, augment):
    """
    Write one split of the combined GUM dataset (GUM + GUMReddit).

    First verifies that GUMReddit has been filled out using the included script.
    """
    check_gum_ready(udbase_dir)
    random.seed(1234)

    output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)

    sents = []
    for treebank in ("UD_English-GUM", "UD_English-GUMReddit"):
        conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
        sents.extend(read_sentences_from_conllu(conllu_file))

    if dataset == 'train' and augment:
        sents = augment_punct(sents)

    write_sentences_to_conllu(output_conllu, sents)
|
| 1245 |
+
|
| 1246 |
+
def build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, augment):
    """Build all three splits of the combined GUM dataset."""
    for split in ("train", "dev", "test"):
        build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, split, augment)
|
| 1249 |
+
|
| 1250 |
+
def prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, dataset, augment=True, input_conllu=None, output_conllu=None):
    """
    Convert one split of a UD treebank into the tokenizer data directory,
    applying dataset-specific augmentation for the train split.
    """
    if input_conllu is None:
        input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
    if output_conllu is None:
        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
    print("Reading from %s and writing to %s" % (input_conllu, output_conllu))

    augmenting = dataset == 'train' and augment
    if short_name == "te_mtg" and augmenting:
        write_augmented_dataset(input_conllu, output_conllu, augment_telugu)
    elif short_name == "ar_padt" and augmenting:
        write_augmented_dataset(input_conllu, output_conllu, augment_arabic_padt)
    elif short_name.startswith("ko_") and short_name.endswith("_seg"):
        # Korean "seg" datasets get their spaces removed instead of augmented
        remove_spaces(input_conllu, output_conllu)
    elif augmenting:
        write_augmented_dataset(input_conllu, output_conllu, augment_punct)
    else:
        # dev/test (or unaugmented train) are copied through unchanged
        write_sentences_to_conllu(output_conllu, read_sentences_from_conllu(input_conllu))
|
| 1268 |
+
|
| 1269 |
+
def process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, augment=True):
    """
    Process a normal UD treebank with train/dev/test splits

    SL-SSJ and other datasets with inline modifications all use this code path as well.
    """
    for split in ("train", "dev", "test"):
        prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, split, augment)
|
| 1278 |
+
|
| 1279 |
+
|
| 1280 |
+
# fraction of a train-only treebank reserved for the dev split
# (presumably consumed by split_train_file, defined elsewhere -- TODO confirm)
XV_RATIO = 0.2
|
| 1281 |
+
|
| 1282 |
+
def process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language):
    """
    Process a UD treebank with only train/test splits

    A dev split is carved out of the train data with split_train_file.

    For example, in UD 2.7:
      UD_Buryat-BDT
      UD_Galician-TreeGal
      UD_Indonesian-CSUI
      UD_Kazakh-KTB
      UD_Kurmanji-MG
      UD_Latin-Perseus
      UD_Livvi-KKPP
      UD_North_Sami-Giella
      UD_Old_Russian-RNC
      UD_Sanskrit-Vedic
      UD_Slovenian-SST
      UD_Upper_Sorbian-UFAL
      UD_Welsh-CCG
    """
    train_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu")
    test_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "test", "conllu")

    train_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "train")
    dev_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "dev")
    test_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "test")

    # some treebanks put the bulk of their data in "test"; swap the two
    # so that the larger file is used for training
    if (common.num_words_in_file(train_input_conllu) <= 1000 and
        common.num_words_in_file(test_input_conllu) > 5000):
        train_input_conllu, test_input_conllu = test_input_conllu, train_input_conllu

    if not split_train_file(treebank=treebank,
                            train_input_conllu=train_input_conllu,
                            train_output_conllu=train_output_conllu,
                            dev_output_conllu=dev_output_conllu):
        return

    # the test set is already fine
    # currently we do not do any augmentation of these partial treebanks
    prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, "test", augment=False, input_conllu=test_input_conllu, output_conllu=test_output_conllu)
|
| 1321 |
+
|
| 1322 |
+
def add_specific_args(parser):
    """Register tokenizer-preparation flags, plus dataset-specific converter flags."""
    parser.add_argument('--no_augment', action='store_false', dest='augment', default=True,
                        help='Augment the dataset in various ways')
    parser.add_argument('--no_prepare_labels', action='store_false', dest='prepare_labels', default=True,
                        help='Prepare tokenizer and MWT labels. Expensive, but obviously necessary for training those models.')
    # external dataset converters register their own flags
    convert_th_lst20.add_lst20_args(parser)

    convert_vi_vlsp.add_vlsp_args(parser)
|
| 1330 |
+
|
| 1331 |
+
def process_treebank(treebank, model_type, paths, args):
    """
    Processes a single treebank into train, dev, test parts

    Includes processing for a few external tokenization datasets:
      vi_vlsp, th_orchid, th_best

    Also, there is no specific mechanism for UD_Arabic-NYUAD or
    similar treebanks, which need integration with LDC datsets
    """
    udbase_dir = paths["UDBASE"]
    tokenizer_dir = paths["TOKENIZE_DATA_DIR"]
    handparsed_dir = paths["HANDPARSED_DIR"]

    short_name = treebank_to_short_name(treebank)
    short_language = short_name.split("_")[0]

    os.makedirs(tokenizer_dir, exist_ok=True)

    # dispatch: external (non-UD) datasets first, then the combined
    # and bio special cases, then plain UD treebanks
    if short_name == "my_alt":
        convert_my_alt.convert_my_alt(paths["CONSTITUENCY_BASE"], tokenizer_dir)
    elif short_name == "vi_vlsp":
        convert_vi_vlsp.convert_vi_vlsp(paths["STANZA_EXTERN_DIR"], tokenizer_dir, args)
    elif short_name == "th_orchid":
        convert_th_orchid.main(paths["STANZA_EXTERN_DIR"], tokenizer_dir)
    elif short_name == "th_lst20":
        convert_th_lst20.convert(paths["STANZA_EXTERN_DIR"], tokenizer_dir, args)
    elif short_name == "th_best":
        convert_th_best.main(paths["STANZA_EXTERN_DIR"], tokenizer_dir)
    elif short_name == "ml_cochin":
        convert_ml_cochin.main(paths["STANZA_EXTERN_DIR"], tokenizer_dir)
    elif short_name.startswith("ko_combined"):
        build_combined_korean(udbase_dir, tokenizer_dir, short_name)
    elif short_name in COMBINED_FNS: # eg "it_combined", "en_combined", etc
        build_combined_dataset(paths, short_name, model_type, args.augment)
    elif short_name in BIO_DATASETS:
        build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_name, model_type, args.augment)
    elif short_name.startswith("en_gum"):
        # we special case GUM because it should include a filled-out GUMReddit
        print("Preparing data for %s: %s, %s" % (treebank, short_name, short_language))
        build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, args.augment)
    else:
        # check that we can find the train file where we expect it
        train_conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu", fail=True)

        print("Preparing data for %s: %s, %s" % (treebank, short_name, short_language))

        if not common.find_treebank_dataset_file(treebank, udbase_dir, "dev", "conllu", fail=False):
            # train/test only treebank: a dev split is carved from train
            process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language)
        else:
            process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, args.augment)

    if model_type is common.ModelType.TOKENIZER or model_type is common.ModelType.MWT:
        # th_orchid / th_lst20 converters produce their own txt files
        if not short_name in ('th_orchid', 'th_lst20'):
            common.convert_conllu_to_txt(tokenizer_dir, short_name)

        if args.prepare_labels:
            common.prepare_tokenizer_treebank_labels(tokenizer_dir, short_name)
|
| 1389 |
+
|
| 1390 |
+
|
| 1391 |
+
def main():
    """Entry point: prepare tokenizer data for the treebank(s) named on the command line."""
    common.main(process_treebank, common.ModelType.TOKENIZER, add_specific_args)

if __name__ == '__main__':
    main()
|
| 1396 |
+
|
stanza/stanza/utils/datasets/pretrain/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/datasets/tokenization/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/datasets/tokenization/convert_vi_vlsp.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# sentence-final punctuation: a sentence NOT ending with one of these is
# treated as a headline by write_file
punctuation_set = (',', '.', '!', '?', ')', ':', ';', '”', '…', '...')
|
| 5 |
+
|
| 6 |
+
def find_spaces(sentence):
    """
    Return a list of booleans, one per token, where True means the token
    should be followed by a space when detokenizing.

    Opening brackets/quotes attach to the following token, closing
    punctuation attaches to the preceding token, and plain double quotes
    alternate between opening and closing.
    """
    # TODO: there are some sentences where there is only one quote,
    # and some of them should be attached to the previous word instead
    # of the next word.  Training should work this way, though
    saw_open_quote = False

    spaces = []
    n = len(sentence)
    for idx, token in enumerate(sentence):
        follow_with_space = True
        # Quote period at the end of a sentence needs to be attached
        # to the rest of the text.  Some sentences have `"... text`
        # in the middle, though, so look for that
        if idx < n - 2 and sentence[idx + 1] == '"':
            after_quote = sentence[idx + 2]
            if after_quote == '.':
                follow_with_space = False
            elif idx == n - 3 and after_quote == '...':
                follow_with_space = False
        if idx < n - 1 and sentence[idx + 1] in (',', '.', '!', '?', ')', ':', ';', '”', '…', '...', '/', '%'):
            follow_with_space = False
        if token in ('(', '“', '/'):
            follow_with_space = False
        if token == '"':
            if saw_open_quote:
                # already saw one quote. put this one at the end of the PREVIOUS word
                # note that we know there must be at least one word already
                saw_open_quote = False
                spaces[idx - 1] = False
            else:
                saw_open_quote = True
                follow_with_space = False
        spaces.append(follow_with_space)
    return spaces
|
| 39 |
+
|
| 40 |
+
def add_vlsp_args(parser):
    """Register the vi_vlsp-specific command line flags on the given argparse parser."""
    parser.add_argument('--include_pos_data', action='store_true', default=False, help='To include or not POS training dataset for tokenization training. The path to POS dataset is expected to be in the same dir with WS path. For example, extern_dir/vietnamese/VLSP2013-POS-data')
    parser.add_argument('--vlsp_include_spaces', action='store_true', default=False, help='When processing vi_vlsp tokenization, include all of the spaces. Otherwise, we try to turn the text back into standard text')
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def write_file(vlsp_include_spaces, output_filename, sentences, shard):
    """
    Write the tokenized sentences to output_filename as a conllu file
    with fake dependencies (word 0 is root, everything else is dep).

    :param vlsp_include_spaces: if True, keep the raw space-separated text;
        otherwise reconstruct natural spacing with find_spaces
    :param sentences: list of sentences, each a list of words
    :param shard: string used to build sent_id values, eg "train"
    """
    # bug fix: open explicitly as utf-8 -- the data is Vietnamese text,
    # and the platform default encoding is not guaranteed to handle it
    with open(output_filename, "w", encoding="utf-8") as fout:
        check_headlines = False
        for sent_idx, sentence in enumerate(sentences):
            fout.write("# sent_id = %s.%d\n" % (shard, sent_idx))
            orig_text = " ".join(sentence)
            # check if the previous line is a headline (no ending mark at the end)
            # then make this sentence a new par
            if check_headlines:
                fout.write("# newpar id =%s.%d.1\n" % (shard, sent_idx))
                check_headlines = False
            if sentence[len(sentence) - 1] not in punctuation_set:
                check_headlines = True

            if vlsp_include_spaces:
                fout.write("# text = %s\n" % orig_text)
            else:
                spaces = find_spaces(sentence)
                full_text = ""
                for word, space in zip(sentence, spaces):
                    # could be made more efficient, but shouldn't matter
                    full_text = full_text + word
                    if space:
                        full_text = full_text + " "
                fout.write("# text = %s\n" % full_text)
                fout.write("# orig_text = %s\n" % orig_text)
            for word_idx, word in enumerate(sentence):
                # fake dependency structure: first word is the root
                fake_dep = "root" if word_idx == 0 else "dep"
                fout.write("%d\t%s\t%s" % ((word_idx+1), word, word))
                fout.write("\t_\t_\t_")
                fout.write("\t%d\t%s" % (word_idx, fake_dep))
                fout.write("\t_\t")
                # note: `spaces` is only defined on the not-include-spaces
                # path, but short-circuiting keeps this safe
                if vlsp_include_spaces or spaces[word_idx]:
                    fout.write("_")
                else:
                    fout.write("SpaceAfter=No")
                fout.write("\n")
            fout.write("\n")
|
| 82 |
+
|
| 83 |
+
def convert_pos_dataset(file_path):
    """
    Read the VLSP POS dataset and return its sentences as lists of words.

    Sentences are separated by blank lines; each non-blank line holds a
    word in the first tab-separated column, with "_" standing for spaces
    inside multi-syllable words.  Duplicate sentences and sentences of a
    single word are dropped.

    Fixes over the original: the file handle is now closed (it used to
    leak), the encoding is explicit utf-8, and duplicate detection is
    O(1) per sentence instead of a linear scan.
    """
    sentences = []
    seen = set()
    sent = []
    with open(file_path, encoding="utf-8") as fin:
        for line in fin:
            if line == "\n" and len(sent) > 1:
                # tuple(sent) membership is equivalent to the original
                # `sent not in sentences` list scan, but O(1)
                key = tuple(sent)
                if key not in seen:
                    seen.add(key)
                    sentences.append(sent)
                sent = []
            elif line != "\n":
                sent.append(line.split("\t")[0].replace("_", " ").strip())
    # NOTE(review): as in the original, a one-word sentence is not reset at
    # its blank line and merges into the following sentence -- confirm intended
    return sentences
|
| 100 |
+
|
| 101 |
+
def convert_file(vlsp_include_spaces, input_filename, output_filename, shard, split_filename=None, split_shard=None, pos_data=None):
    """
    Convert one VLSP word-segmentation file to conllu.

    If split_filename is given, the last 5% of sentences are written
    there (as a dev split) and any non-overlapping pos_data sentences
    are appended to the main output.

    Fixes over the original: explicit utf-8 encoding, and duplicate
    detection via a set of tuples instead of an O(n^2) list scan
    (behavior is identical, since list equality matches tuple equality
    element by element).
    """
    with open(input_filename, encoding="utf-8") as fin:
        lines = fin.readlines()

    sentences = []
    seen = set()           # tuple(words) already collected
    set_sentences = set()  # joined text, used below to filter overlapping POS data
    for line in lines:
        if len(line.replace("_", " ").split()) > 1:
            words = line.split()
            # one syllable lines are eliminated
            # NOTE(review): this branch appears unreachable given the outer
            # check, but is kept to preserve the original behavior exactly
            if len(words) == 1 and len(words[0].split("_")) == 1:
                continue
            words = [w.replace("_", " ") for w in words]
            # only add sentences that haven't been added before
            key = tuple(words)
            if key not in seen:
                seen.add(key)
                sentences.append(words)
                set_sentences.add(' '.join(words))

    if split_filename is not None:
        # even this is a larger dev set than the train set
        split_point = int(len(sentences) * 0.95)
        # check pos_data that aren't overlapping with current VLSP WS dataset
        sentences_pos = [] if pos_data is None else [sent for sent in pos_data if ' '.join(sent) not in set_sentences]
        print("Added ", len(sentences_pos), " sentences from POS dataset.")
        write_file(vlsp_include_spaces, output_filename, sentences[:split_point]+sentences_pos, shard)
        write_file(vlsp_include_spaces, split_filename, sentences[split_point:], split_shard)
    else:
        write_file(vlsp_include_spaces, output_filename, sentences, shard)
|
| 130 |
+
|
| 131 |
+
def convert_vi_vlsp(extern_dir, tokenizer_dir, args):
    """
    Convert the VLSP 2013 word segmentation dataset (optionally plus the
    POS dataset) into train/dev/test conllu files for the tokenizer.

    :param extern_dir: directory containing vietnamese/VLSP2013-WS-data
        (and VLSP2013-POS-data when --include_pos_data is set)
    :param tokenizer_dir: output directory for the conllu files
    :param args: parsed command line args (see add_vlsp_args)
    """
    input_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-WS-data")
    input_pos_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-POS-data")
    input_train_filename = os.path.join(input_path, "VLSP2013_WS_train_gold.txt")
    input_test_filename = os.path.join(input_path, "VLSP2013_WS_test_gold.txt")

    input_pos_filename = os.path.join(input_pos_path, "VLSP2013_POS_train_BI_POS_Column.txt.goldSeg")
    if not os.path.exists(input_train_filename):
        raise FileNotFoundError("Cannot find train set for VLSP at %s" % input_train_filename)
    if not os.path.exists(input_test_filename):
        raise FileNotFoundError("Cannot find test set for VLSP at %s" % input_test_filename)
    pos_data = None
    if args.include_pos_data:
        if not os.path.exists(input_pos_filename):
            # bug fix: the format string was a bare "%", which raised
            # ValueError at format time instead of this FileNotFoundError
            raise FileNotFoundError("Cannot find pos dataset for VLSP at %s" % input_pos_filename)
        else:
            pos_data = convert_pos_dataset(input_pos_filename)

    output_train_filename = os.path.join(tokenizer_dir, "vi_vlsp.train.gold.conllu")
    output_dev_filename = os.path.join(tokenizer_dir, "vi_vlsp.dev.gold.conllu")
    output_test_filename = os.path.join(tokenizer_dir, "vi_vlsp.test.gold.conllu")

    # the train file is split 95/5 into train/dev; the POS data (if any) goes to train
    convert_file(args.vlsp_include_spaces, input_train_filename, output_train_filename, "train", output_dev_filename, "dev", pos_data)
    convert_file(args.vlsp_include_spaces, input_test_filename, output_test_filename, "test")
|
| 155 |
+
|
stanza/stanza/utils/ner/spacy_ner_tag_dataset.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test a spacy model on a 4 class dataset
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
import spacy
|
| 9 |
+
from spacy.tokens import Doc
|
| 10 |
+
|
| 11 |
+
from stanza.models.ner.utils import process_tags
|
| 12 |
+
from stanza.models.ner.scorer import score_by_entity, score_by_token
|
| 13 |
+
|
| 14 |
+
from stanza.utils.confusion import format_confusion
|
| 15 |
+
from stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide
|
| 16 |
+
|
| 17 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 18 |
+
tqdm = get_tqdm()
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
Simplified classes used in the Worldwide dataset are:
|
| 22 |
+
|
| 23 |
+
Date
|
| 24 |
+
Facility
|
| 25 |
+
Location
|
| 26 |
+
Misc
|
| 27 |
+
Money
|
| 28 |
+
NORP
|
| 29 |
+
Organization
|
| 30 |
+
Person
|
| 31 |
+
Product
|
| 32 |
+
|
| 33 |
+
vs OntoNotes classes:
|
| 34 |
+
|
| 35 |
+
CARDINAL
|
| 36 |
+
DATE
|
| 37 |
+
EVENT
|
| 38 |
+
FAC
|
| 39 |
+
GPE
|
| 40 |
+
LANGUAGE
|
| 41 |
+
LAW
|
| 42 |
+
LOC
|
| 43 |
+
MONEY
|
| 44 |
+
NORP
|
| 45 |
+
ORDINAL
|
| 46 |
+
ORG
|
| 47 |
+
PERCENT
|
| 48 |
+
PERSON
|
| 49 |
+
PRODUCT
|
| 50 |
+
QUANTITY
|
| 51 |
+
TIME
|
| 52 |
+
WORK_OF_ART
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def test_file(eval_file, tagger, simplify):
    """
    Score a spacy NER tagger against one stanza-format json eval file.

    :param eval_file: json file of sentences, each a list of {'text', 'ner'} tokens
    :param tagger: a loaded spacy pipeline
    :param simplify: if True, map OntoNotes tags down to the Worldwide tag set
        on both gold and predicted tags
    :return: micro F1 over entities
    """
    with open(eval_file) as fin:
        gold_doc = json.load(fin)
    gold_doc = [[(x['text'], x['ner']) for x in sentence] for sentence in gold_doc]
    gold_doc = process_tags(gold_doc, 'bioes')

    if simplify:
        # rewrite the gold tags in place with the simplified tag set
        for doc in gold_doc:
            for idx, word in enumerate(doc):
                if word[1] != "O":
                    word = [word[0], simplify_ontonotes_to_worldwide(word[1])]
                    doc[idx] = word

    # DATE tags are not scored in the simplified setting
    ignore_tags = "Date,DATE" if simplify else None

    original_text = [[x[0] for x in gold_sentence] for gold_sentence in gold_doc]
    pred_doc = []
    for sentence in tqdm(original_text):
        # build the Doc from the gold tokenization so tokens line up for scoring
        spacy_sentence = Doc(tagger.vocab, sentence)
        spacy_sentence = tagger(spacy_sentence)
        entities = ["O" if not token.ent_type_ else "%s-%s" % (token.ent_iob_, token.ent_type_) for token in spacy_sentence]
        if simplify:
            entities = [simplify_ontonotes_to_worldwide(x) for x in entities]
        pred_sentence = [[token.text, entity] for token, entity in zip(spacy_sentence, entities)]
        pred_doc.append(pred_sentence)

    pred_doc = process_tags(pred_doc, 'bioes')
    pred_tags = [[x[1] for x in sentence] for sentence in pred_doc]
    gold_tags = [[x[1] for x in sentence] for sentence in gold_doc]
    print("RESULTS ON: %s" % eval_file)
    _, _, f_micro, _ = score_by_entity(pred_tags, gold_tags, ignore_tags=ignore_tags)
    _, _, _, confusion = score_by_token(pred_tags, gold_tags, ignore_tags=ignore_tags)
    print("NER token confusion matrix:\n{}".format(format_confusion(confusion, hide_blank=True, transpose=True)))
    return f_micro
|
| 89 |
+
|
| 90 |
+
def main():
    """Evaluate one or more spacy NER models over the requested test files.

    Prints a per-file score for each model plus a latex-style summary row per model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ner_model', type=str, default=None, help='Which spacy model to test')
    parser.add_argument('filename', type=str, nargs='*', help='which files to test')
    parser.add_argument('--simplify', default=False, action='store_true', help='Simplify classes to the 8 class Worldwide model')
    args = parser.parse_args()

    # default to testing both the small and transformer English pipelines
    ner_models = [args.ner_model] if args.ner_model is not None else ['en_core_web_sm', 'en_core_web_trf']

    if not args.filename:
        # default test files: OntoNotes plus the Worldwide test splits
        args.filename = ["data/ner/en_ontonotes-8class.test.json",
                         "data/ner/en_worldwide-8class.test.json",
                         "data/ner/en_worldwide-8class-africa.test.json",
                         "data/ner/en_worldwide-8class-asia.test.json",
                         "data/ner/en_worldwide-8class-indigenous.test.json",
                         "data/ner/en_worldwide-8class-latam.test.json",
                         "data/ner/en_worldwide-8class-middle_east.test.json"]

    print("Processing the files: %s" % ",".join(args.filename))

    results = []
    model_results = {}

    for ner_model in ner_models:
        model_scores = []
        model_results[ner_model] = model_scores
        # load tagger
        print("-----------------------------")
        print("Running %s" % ner_model)
        print("-----------------------------")
        tagger = spacy.load(ner_model, disable=["tagger", "parser", "attribute_ruler", "lemmatizer"])

        for filename in args.filename:
            f_micro = "%.2f" % (test_file(filename, tagger, args.simplify) * 100)
            results.append((ner_model, filename, f_micro))
            model_scores.append(f_micro)

    for result in results:
        print(result)

    # one summary row per model, suitable for pasting into a latex table
    for model, scores in model_results.items():
        print(" & ".join([model] + scores))

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/training/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/training/remove_constituency_optimizer.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Saved a huge, bloated model with an optimizer? Use this to remove it, greatly shrinking the model size
|
| 2 |
+
|
| 3 |
+
This tries to find reasonable defaults for word vectors and charlm
|
| 4 |
+
(which need to be loaded so that the model knows the matrix sizes)
|
| 5 |
+
|
| 6 |
+
so ideally all that needs to be run is
|
| 7 |
+
|
| 8 |
+
python3 stanza/utils/training/remove_constituency_optimizer.py <treebanks>
|
| 9 |
+
python3 stanza/utils/training/remove_constituency_optimizer.py da_arboretum ...
|
| 10 |
+
|
| 11 |
+
This can also be used to load and save models as part of an update
|
| 12 |
+
to the serialized format
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from stanza.models import constituency_parser
|
| 20 |
+
from stanza.models.common.constant import treebank_to_short_name
|
| 21 |
+
from stanza.resources.default_packages import default_charlms, default_pretrains
|
| 22 |
+
from stanza.utils.training import common
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger('stanza')
|
| 25 |
+
|
| 26 |
+
def parse_args():
    """Build and parse the command line arguments for resaving constituency models."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')
    parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package")

    parser.add_argument('--load_dir', type=str, default="saved_models/constituency", help="Root dir for getting the models to resave.")
    parser.add_argument('--save_dir', type=str, default="resaved_models/constituency", help="Root dir for resaving the models.")

    parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. Use all_ud or ud_all for all UD treebanks')

    return parser.parse_args()
|
| 39 |
+
|
| 40 |
+
def main():
    """
    For each of the models specified, load and resave the model

    The resaved model will have the optimizer removed
    """
    args = parse_args()
    os.makedirs(args.save_dir, exist_ok=True)

    for treebank in args.treebanks:
        logger.info("PROCESSING %s", treebank)
        short_name = treebank_to_short_name(treebank)
        # short names look like en_wsj: language code, then dataset
        language, dataset = short_name.split("_", maxsplit=1)
        logger.info("%s: %s %s", short_name, language, dataset)

        if not args.wordvec_pretrain_file:
            # will throw an error if the pretrain can't be found
            wordvec_pretrain = common.find_wordvec_pretrain(language, default_pretrains)
            wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]
        else:
            # bug fix: an explicitly requested pretrain file was previously
            # dropped (wordvec_args was left empty) instead of being
            # forwarded to the constituency parser
            wordvec_args = ['--wordvec_pretrain_file', args.wordvec_pretrain_file]

        charlm = common.choose_charlm(language, dataset, args.charlm, default_charlms, {})
        charlm_args = common.build_charlm_args(language, charlm, base_args=False)

        base_name = '{}_constituency.pt'.format(short_name)
        load_name = os.path.join(args.load_dir, base_name)
        save_name = os.path.join(args.save_dir, base_name)
        # remove_optimizer mode loads the model, strips the optimizer, and saves it back
        resave_args = ['--mode', 'remove_optimizer',
                       '--load_name', load_name,
                       '--save_name', save_name,
                       '--save_dir', ".",
                       '--shorthand', short_name]
        resave_args = resave_args + wordvec_args + charlm_args
        constituency_parser.main(resave_args)

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/visualization/dependency_visualization.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Functions to visualize dependency relations in texts and Stanza documents
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from stanza.models.common.constant import is_right_to_left
|
| 6 |
+
import stanza
|
| 7 |
+
import spacy
|
| 8 |
+
from spacy import displacy
|
| 9 |
+
from spacy.tokens import Doc
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def visualize_doc(doc, language):
    """Render the dependency parse of a stanza Document with displacy.

    The document must come from a stanza pipeline with depparse annotations.

    Right-to-left languages such as Arabic are displayed right-to-left,
    chosen based on the language code.
    """
    visualization_options = {"compact": True, "bg": "#09a3d5", "color": "white", "distance": 90,
                             "font": "Source Sans Pro", "arrow_spacing": 25}
    # blank model - we don't use any of the model features, just the viz
    nlp = spacy.blank("en")
    rtl = is_right_to_left(language)
    sentences_to_visualize = []
    for sentence in doc.sentences:
        sent_len = len(sentence.words)
        # for RTL languages the display order of words is reversed;
        # the dependency arcs are remapped to stay intact
        word_list = list(reversed(sentence.words)) if rtl else sentence.words
        words = [word.text for word in word_list]
        lemmas = [word.lemma for word in word_list]
        deps = [word.deprel for word in word_list]
        tags = [word.upos for word in word_list]
        heads = []
        for word in word_list:
            # spaCy heads are 0-based display positions, with the root
            # pointing at itself, unlike Stanza's 1-based head-of-0-is-root
            if rtl:
                heads.append(sent_len - word.id if word.head == 0 else sent_len - word.head)
            else:
                heads.append(word.id - 1 if word.head == 0 else word.head - 1)
        sentences_to_visualize.append(Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags))

    for line in sentences_to_visualize:  # render all sentences through displaCy
        # If this program is NOT being run in a Jupyter notebook, replace displacy.render with displacy.serve
        # and the visualization will be hosted locally, link being provided in the program output.
        displacy.render(line, style="dep", options=visualization_options)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def visualize_str(text, pipeline_code, pipe):
    """Process a string with a stanza pipeline and visualize it using displacy.

    The stanza annotations are converted to a spaCy doc for rendering.
    pipeline_code is the two-letter stanza language code (e.g. 'en');
    pipe is an already-constructed stanza pipeline used to annotate the text.
    """
    visualize_doc(pipe(text), pipeline_code)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def visualize_docs(docs, lang_code):
    """Visualize the dependency relationships of each stanza Document in docs.

    lang_code is a language code such as 'en' for English.
    Rendering is delegated to visualize_doc; see that function for details.
    """
    for document in docs:
        visualize_doc(document, lang_code)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def visualize_strings(texts, lang_code):
    """Visualize the dependency relationships in each of the given strings.

    Loads a stanza pipeline for lang_code (e.g. 'en' for English) once,
    then annotates and renders every string in texts with it.
    """
    pipe = stanza.Pipeline(lang_code, processors="tokenize,pos,lemma,depparse")
    for text in texts:
        visualize_str(text, lang_code, pipe)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def main():
    """Demo the visualizations on a right-to-left language (Arabic) and two left-to-right ones."""
    demos = [
        # Testing with right to left language
        ("ar", ['برلين ترفض حصول شركة اميركية على رخصة تصنيع دبابة "ليوبارد" الالمانية', "هل بإمكاني مساعدتك؟",
                "أراك في مابعد", "لحظة من فضلك"]),
        # Testing with left to right languages
        ("en", ["This is a sentence.",
                "Barack Obama was born in Hawaii. He was elected President of the United States in 2008."]),
        ("zh", ["中国是一个很有意思的国家。"]),
    ]
    for lang_code, texts in demos:
        visualize_strings(texts, lang_code)

if __name__ == '__main__':
    main()
|