bowphs committed on
Commit
986094c
·
verified ·
1 Parent(s): 495f002

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. stanza/stanza/models/classifiers/constituency_classifier.py +96 -0
  2. stanza/stanza/models/classifiers/data.py +169 -0
  3. stanza/stanza/models/coref/bert.py +69 -0
  4. stanza/stanza/models/langid/__init__.py +0 -0
  5. stanza/stanza/models/langid/data.py +134 -0
  6. stanza/stanza/models/langid/model.py +126 -0
  7. stanza/stanza/models/lemma_classifier/__init__.py +0 -0
  8. stanza/stanza/models/ner/model.py +278 -0
  9. stanza/stanza/models/pos/scorer.py +22 -0
  10. stanza/stanza/models/pos/vocab.py +71 -0
  11. stanza/stanza/pipeline/demo/stanza-brat.js +1316 -0
  12. stanza/stanza/pipeline/external/corenlp_converter_depparse.py +29 -0
  13. stanza/stanza/pipeline/external/jieba.py +71 -0
  14. stanza/stanza/pipeline/external/sudachipy.py +84 -0
  15. stanza/stanza/utils/charlm/oscar_to_text.py +78 -0
  16. stanza/stanza/utils/constituency/__init__.py +0 -0
  17. stanza/stanza/utils/constituency/grep_test_logs.py +24 -0
  18. stanza/stanza/utils/datasets/constituency/build_silver_dataset.py +117 -0
  19. stanza/stanza/utils/datasets/constituency/convert_cintil.py +80 -0
  20. stanza/stanza/utils/datasets/constituency/count_common_words.py +12 -0
  21. stanza/stanza/utils/datasets/constituency/prepare_con_dataset.py +594 -0
  22. stanza/stanza/utils/datasets/constituency/silver_variance.py +108 -0
  23. stanza/stanza/utils/datasets/coref/convert_hindi.py +170 -0
  24. stanza/stanza/utils/datasets/ner/compare_entities.py +38 -0
  25. stanza/stanza/utils/datasets/ner/conll_to_iob.py +59 -0
  26. stanza/stanza/utils/datasets/ner/convert_bn_daffodil.py +123 -0
  27. stanza/stanza/utils/datasets/ner/convert_en_conll03.py +42 -0
  28. stanza/stanza/utils/datasets/ner/convert_he_iahlt.py +108 -0
  29. stanza/stanza/utils/datasets/ner/convert_lst20.py +74 -0
  30. stanza/stanza/utils/datasets/ner/convert_mr_l3cube.py +54 -0
  31. stanza/stanza/utils/datasets/ner/convert_nner22.py +70 -0
  32. stanza/stanza/utils/datasets/ner/convert_ontonotes.py +58 -0
  33. stanza/stanza/utils/datasets/ner/json_to_bio.py +43 -0
  34. stanza/stanza/utils/datasets/ner/misc_to_date.py +77 -0
  35. stanza/stanza/utils/datasets/ner/preprocess_wikiner.py +37 -0
  36. stanza/stanza/utils/datasets/ner/simplify_en_worldwide.py +152 -0
  37. stanza/stanza/utils/datasets/ner/simplify_ontonotes_to_worldwide.py +118 -0
  38. stanza/stanza/utils/datasets/ner/split_wikiner.py +104 -0
  39. stanza/stanza/utils/datasets/ner/suc_conll_to_iob.py +72 -0
  40. stanza/stanza/utils/datasets/pos/__init__.py +0 -0
  41. stanza/stanza/utils/datasets/pos/convert_trees_to_pos.py +94 -0
  42. stanza/stanza/utils/datasets/prepare_tokenizer_data.py +151 -0
  43. stanza/stanza/utils/datasets/prepare_tokenizer_treebank.py +1396 -0
  44. stanza/stanza/utils/datasets/pretrain/__init__.py +0 -0
  45. stanza/stanza/utils/datasets/tokenization/__init__.py +0 -0
  46. stanza/stanza/utils/datasets/tokenization/convert_vi_vlsp.py +155 -0
  47. stanza/stanza/utils/ner/spacy_ner_tag_dataset.py +138 -0
  48. stanza/stanza/utils/training/__init__.py +0 -0
  49. stanza/stanza/utils/training/remove_constituency_optimizer.py +77 -0
  50. stanza/stanza/utils/visualization/dependency_visualization.py +108 -0
stanza/stanza/models/classifiers/constituency_classifier.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A classifier that uses a constituency parser for the base embeddings
3
+ """
4
+
5
+ import dataclasses
6
+ import logging
7
+ from types import SimpleNamespace
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from stanza.models.classifiers.base_classifier import BaseClassifier
14
+ from stanza.models.classifiers.config import ConstituencyConfig
15
+ from stanza.models.classifiers.data import SentimentDatum
16
+ from stanza.models.classifiers.utils import ModelType, build_output_layers
17
+
18
+ from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort
19
+
20
+ logger = logging.getLogger('stanza')
21
+ tlogger = logging.getLogger('stanza.classifiers.trainer')
22
+
23
class ConstituencyClassifier(BaseClassifier):
    """
    A classifier which embeds text with a constituency tree embedding,
    max-pools the per-node embeddings, and classifies with FC layers.
    """
    def __init__(self, tree_embedding, labels, args):
        """
        tree_embedding: module exposing output_size, embed_trees(), get_norms(), get_params()
        labels: list of output label names
        args: namespace of training options (see ConstituencyConfig fields)
        """
        super(ConstituencyClassifier, self).__init__()
        self.labels = labels
        # we build a separate config out of the args so that we can easily save it in torch
        self.config = ConstituencyConfig(fc_shapes = args.fc_shapes,
                                         dropout = args.dropout,
                                         num_classes = len(labels),
                                         constituency_backprop = args.constituency_backprop,
                                         constituency_batch_norm = args.constituency_batch_norm,
                                         constituency_node_attn = args.constituency_node_attn,
                                         constituency_top_layer = args.constituency_top_layer,
                                         constituency_all_words = args.constituency_all_words,
                                         model_type = ModelType.CONSTITUENCY)

        self.tree_embedding = tree_embedding

        self.fc_layers = build_output_layers(self.tree_embedding.output_size, self.config.fc_shapes, self.config.num_classes)
        self.dropout = nn.Dropout(self.config.dropout)

    def is_unsaved_module(self, name):
        # nothing here is skipped at save time; the tree embedding
        # handles its own parameter saving via get_params()
        return False

    def log_configuration(self):
        """Log the settings which affect how the model is built / used."""
        tlogger.info("Backprop into parser: %s", self.config.constituency_backprop)
        tlogger.info("Batch norm: %s", self.config.constituency_batch_norm)
        tlogger.info("Word positions used: %s", "all words" if self.config.constituency_all_words else "start and end words")
        tlogger.info("Attention over nodes: %s", self.config.constituency_node_attn)
        tlogger.info("Intermediate layers: %s", self.config.fc_shapes)

    def log_norms(self):
        """Log the norm of each trainable parameter - useful when debugging training."""
        lines = ["NORMS FOR MODEL PARAMETERS"]
        lines.extend(["tree_embedding." + x for x in self.tree_embedding.get_norms()])
        for name, param in self.named_parameters():
            if param.requires_grad and not name.startswith('tree_embedding.'):
                lines.append("%s %.6g" % (name, torch.norm(param).item()))
        logger.info("\n".join(lines))


    def forward(self, inputs):
        """
        Classify a batch of trees (or SentimentDatum objects holding trees).

        Returns unnormalized class scores, one row per input.
        """
        inputs = [x.constituency if isinstance(x, SentimentDatum) else x for x in inputs]

        embedding = self.tree_embedding.embed_trees(inputs)
        # max-pool over the nodes of each tree to get one vector per tree
        previous_layer = torch.stack([torch.max(x, dim=0)[0] for x in embedding], dim=0)
        previous_layer = self.dropout(previous_layer)
        for fc in self.fc_layers[:-1]:
            # gelu instead of relu: relu caused many neurons to die
            previous_layer = self.dropout(F.gelu(fc(previous_layer)))
        out = self.fc_layers[-1](previous_layer)
        return out

    def get_params(self, skip_modules=True):
        """Return a dict of the state needed to rebuild and reload this model."""
        model_state = self.state_dict()
        # skip all of the constituency parameters here -
        # we will add them by calling the model's get_params()
        skipped = [k for k in model_state.keys() if k.startswith("tree_embedding.")]
        for k in skipped:
            del model_state[k]

        tree_embedding = self.tree_embedding.get_params(skip_modules)

        config = dataclasses.asdict(self.config)
        # the enum doesn't round trip through asdict; save the name instead
        config['model_type'] = config['model_type'].name

        params = {
            'model': model_state,
            'tree_embedding': tree_embedding,
            'config': config,
            'labels': self.labels,
        }
        return params

    def extract_sentences(self, doc):
        """Extract the inputs this model needs (the parse trees) from a Document."""
        return [sentence.constituency for sentence in doc.sentences]
stanza/stanza/models/classifiers/data.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stanza models classifier data functions."""
2
+
3
+ import collections
4
+ from collections import namedtuple
5
+ import logging
6
+ import json
7
+ import random
8
+ import re
9
+ from typing import List
10
+
11
+ from stanza.models.classifiers.utils import WVType
12
+ from stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID
13
+ import stanza.models.constituency.tree_reader as tree_reader
14
+
15
+ logger = logging.getLogger('stanza')
16
+
17
class SentimentDatum:
    """A single classification example.

    Holds a sentiment label, the tokenized text, and optionally the
    constituency parse of the text.
    """
    def __init__(self, sentiment, text, constituency=None):
        self.sentiment = sentiment
        self.text = text
        self.constituency = constituency

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, SentimentDatum):
            return False
        return self.sentiment == other.sentiment and self.text == other.text and self.constituency == other.constituency

    def __repr__(self):
        # added for easier debugging of datasets in the REPL / logs
        return "%s(%s)" % (type(self).__name__, self._asdict())

    def __str__(self):
        return str(self._asdict())

    def _asdict(self):
        """Dict view of the datum, mirroring namedtuple._asdict; trees are stringified."""
        if self.constituency is None:
            return {'sentiment': self.sentiment, 'text': self.text}
        else:
            return {'sentiment': self.sentiment, 'text': self.text, 'constituency': str(self.constituency)}
38
+
39
def update_text(sentence: List[str], wordvec_type: WVType) -> List[str]:
    """
    Normalize a whitespace-tokenized sentence into the token forms the
    word vectors expect.
    """
    # the stanford sentiment dataset contains many stray - and / characters;
    # split on them and drop the empty fragments, flattening as we go
    tokens = [piece for token in sentence for piece in token.split("-") if piece]
    tokens = [piece for token in tokens for piece in token.split("/") if piece]
    tokens = [token.strip() for token in tokens]
    tokens = [token for token in tokens if token]
    if not tokens:
        # splitting removed everything - keep a placeholder token
        tokens = ["-"]
    # all of our current word vectors are lowercased
    tokens = [token.lower() for token in tokens]
    if wordvec_type == WVType.WORD2VEC:
        return tokens
    if wordvec_type == WVType.GOOGLE:
        # google vectors replace digits with '#', except for bare 0 and 1
        normalized = []
        for token in tokens:
            if token not in ('0', '1'):
                token = re.sub('[0-9]', '#', token)
            normalized.append(token)
        return normalized
    if wordvec_type == WVType.FASTTEXT:
        return tokens
    if wordvec_type == WVType.OTHER:
        return tokens
    raise ValueError("Unknown wordvec_type {}".format(wordvec_type))
70
+
71
+
72
def read_dataset(dataset, wordvec_type: WVType, min_len: int) -> List[SentimentDatum]:
    """
    Read one or more JSON files into a list of SentimentDatum.

    dataset: a filename, or comma separated list of filenames, each a JSON
      list of {"sentiment": ..., "text": ..., "constituency": ...} objects
      (constituency is optional)
    wordvec_type: passed to update_text to normalize the tokens
    min_len: if truthy, drop items with fewer than min_len tokens
    """
    lines = []
    for filename in str(dataset).split(","):
        with open(filename, encoding="utf-8") as fin:
            new_lines = json.load(fin)
        # sentiment is forced to str so labels compare consistently
        new_lines = [(str(x['sentiment']), x['text'], x.get('constituency', None)) for x in new_lines]
        lines.extend(new_lines)
    # TODO: maybe do this processing later, once the model is built.
    # then move the processing into the model so we can use
    # overloading to potentially make future model types
    lines = [SentimentDatum(x[0], update_text(x[1], wordvec_type), tree_reader.read_trees(x[2])[0] if x[2] else None) for x in lines]
    if min_len:
        lines = [x for x in lines if len(x.text) >= min_len]
    return lines
90
+
91
def dataset_labels(dataset):
    """
    Return the sorted list of label names occurring in the dataset.

    Purely numeric label sets are sorted by numeric value so that,
    for example, "2" sorts before "10".
    """
    label_set = {item.sentiment for item in dataset}
    if all(re.match("^[0-9]+$", label) for label in label_set):
        # integer labels: sort numerically rather than lexically
        return [str(value) for value in sorted(int(label) for label in label_set)]
    return sorted(label_set)
104
+
105
def dataset_vocab(dataset):
    """
    Build a word vocab (list of strings) from the dataset's tokens.

    PAD and UNK are placed first, at the indices PAD_ID and UNK_ID expect.
    NOTE(review): the remaining words come from iterating a set, so their
    order is not deterministic across runs - presumably the vocab is saved
    with any trained model; confirm before relying on index stability.
    """
    vocab = set()
    for line in dataset:
        for word in line.text:
            vocab.add(word)
    vocab = [PAD, UNK] + list(vocab)
    # sanity check that the special tokens landed at their expected ids
    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:
        raise ValueError("Unexpected values for PAD and UNK!")
    return vocab
114
+
115
def sort_dataset_by_len(dataset, keep_index=False):
    """
    Bucket the dataset by text length.

    Returns an OrderedDict mapping length -> list of items of that
    length, with keys in increasing length order.  If keep_index is
    set, each entry is (item, original_index) instead of item.
    """
    buckets = collections.OrderedDict()
    for length in sorted({len(item.text) for item in dataset}):
        buckets[length] = []
    for original_index, item in enumerate(dataset):
        entry = (item, original_index) if keep_index else item
        buckets[len(item.text)].append(entry)
    return buckets
131
+
132
def shuffle_dataset(sorted_dataset, batch_size, batch_single_item):
    """
    Shuffle a length-bucketed dataset into batches of similar lengths.

    Items are shuffled within each length bucket, then grouped into
    batches of at most batch_size items.  Items whose text length is at
    least batch_single_item (when that limit is positive) each get a
    batch of their own.  The batches are returned in random order.
    """
    flattened = []
    for bucket in sorted_dataset.values():
        shuffled = list(bucket)
        random.shuffle(shuffled)
        flattened.extend(shuffled)

    batches = []
    current = []
    for item in flattened:
        if batch_single_item > 0 and len(item.text) >= batch_single_item:
            # overly long items are trained as singleton batches
            batches.append([item])
            continue
        current.append(item)
        if len(current) >= batch_size:
            batches.append(current)
            current = []
    if current:
        batches.append(current)

    random.shuffle(batches)
    return batches
156
+
157
+
158
def check_labels(labels, dataset):
    """
    Verify every label in the dataset is one the model already knows.

    Strictly speaking, unknown labels could be tolerated by treating the
    model as always wrong on them, but failing fast is a useful sanity
    check that the datasets match.
    """
    unknown = [label for label in dataset_labels(dataset) if label not in labels]
    if unknown:
        raise RuntimeError('Dataset contains labels which the model does not know about:' + str(unknown))
169
+
stanza/stanza/models/coref/bert.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Functions related to BERT or similar models"""
2
+
3
+ import logging
4
+ from typing import List, Tuple
5
+
6
+ import numpy as np # type: ignore
7
+ from transformers import AutoModel, AutoTokenizer # type: ignore
8
+
9
+ from stanza.models.coref.config import Config
10
+ from stanza.models.coref.const import Doc
11
+
12
+
13
+ logger = logging.getLogger('stanza')
14
+
15
def get_subwords_batches(doc: Doc,
                         config: Config,
                         tok: AutoTokenizer
                         ) -> np.ndarray:
    """
    Turns a list of subwords to a list of lists of subword indices
    of max length == batch_size (or shorter, as batch boundaries
    should match sentence boundaries). Each batch is enclosed in cls and sep
    special tokens and padded with the pad token up to the window size.

    Returns:
        batches of bert tokens [n_batches, batch_size]
    """
    batch_size = config.bert_window_size - 2  # to save space for CLS and SEP

    subwords: List[str] = doc["subwords"]
    subwords_batches = []
    start, end = 0, 0

    while end < len(subwords):
        # to prevent the case where a batch_size step forward
        # doesn't capture more than 1 sentence, we will just cut
        # that sequence
        prev_end = end
        end = min(end + batch_size, len(subwords))

        # Move back till we hit a sentence end
        if end < len(subwords):
            sent_id = doc["sent_id"][doc["word_id"][end]]
            while end and doc["sent_id"][doc["word_id"][end - 1]] == sent_id:
                end -= 1

        # this occurs IFF there was no sentence end found throughout
        # the forward scan; this means that our sentence was waay too
        # long (i.e. longer than the max length of the transformer.
        #
        # if so, we give up and just chop the sentence off at the max length
        # that was given
        if end == prev_end:
            end = min(end + batch_size, len(subwords))

        length = end - start
        # some tokenizers have no CLS/SEP tokens; fall back to EOS
        # (use "is None" - identity, not equality - for the None check)
        if tok.cls_token is None or tok.sep_token is None:
            batch = [tok.eos_token] + subwords[start:end] + [tok.eos_token]
        else:
            batch = [tok.cls_token] + subwords[start:end] + [tok.sep_token]

        # Padding to desired length
        batch += [tok.pad_token] * (batch_size - length)

        subwords_batches.append([tok.convert_tokens_to_ids(token)
                                 for token in batch])
        start += length

    return np.array(subwords_batches)
stanza/stanza/models/langid/__init__.py ADDED
File without changes
stanza/stanza/models/langid/data.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import torch
4
+
5
+
6
class DataLoader:
    """
    Class for loading language id data and providing batches

    Attempt to recreate data pre-processing from: https://github.com/AU-DIS/LSTM_langid

    Uses methods from: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py

    Data format is same as LSTM_langid
    """

    def __init__(self, device=None):
        # all of these are populated by load_data()
        self.batches = None
        self.batches_iter = None
        self.tag_to_idx = None
        self.idx_to_tag = None
        self.lang_weights = None
        self.device = device

    def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5,20),
                  max_length=None):
        """
        Load sequence data and labels, calculate weights for weighted cross entropy loss.
        Data is stored in a file, 1 example per line
        Example: {"text": "Hello world.", "label": "en"}

        batch_size: maximum number of examples per batch
        data_files: iterable of file paths, one JSON example per line
        char_index: map char -> id, must contain "UNK"
        tag_index: map label -> id; a copy is extended with unseen labels
        randomize: split each text into random-length pieces (see randomize_data)
        max_length: optionally truncate each text
        """

        # set up examples from data files
        examples = []
        for data_file in data_files:
            # use a context manager + explicit encoding: the original code
            # left the file handle open and used the platform default encoding
            with open(data_file, encoding="utf-8") as fin:
                examples += [x for x in fin.read().split("\n") if x.strip()]
        random.shuffle(examples)
        examples = [json.loads(x) for x in examples]

        # add additional labels in this data set to tag index
        tag_index = dict(tag_index)
        new_labels = set([x["label"] for x in examples]) - set(tag_index.keys())
        for new_label in new_labels:
            tag_index[new_label] = len(tag_index)
        self.tag_to_idx = tag_index
        self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]

        # set up lang counts used for weights for cross entropy loss
        lang_counts = [0 for _ in tag_index]

        # optionally limit text to max length
        if max_length is not None:
            examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples]

        # randomize data
        if randomize:
            split_examples = []
            for example in examples:
                sequence = example["text"]
                label = example["label"]
                sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1],
                                                      lower_lim=randomize_range[0])
                split_examples += [{"text": seq, "label": label} for seq in sequences]
            examples = split_examples
            random.shuffle(examples)

        # break into equal length batches so no padding is needed
        batch_lengths = {}
        for example in examples:
            sequence = example["text"]
            label = example["label"]
            if len(sequence) not in batch_lengths:
                batch_lengths[len(sequence)] = []
            sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in list(sequence)]
            batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label]))
            lang_counts[tag_index[label]] += 1
        for length in batch_lengths:
            random.shuffle(batch_lengths[length])

        # create final set of batches
        batches = []
        for length in batch_lengths:
            for sublist in [batch_lengths[length][i:i + batch_size] for i in
                            range(0, len(batch_lengths[length]), batch_size)]:
                batches.append(sublist)

        self.batches = [self.build_batch_tensors(batch) for batch in batches]

        # set up lang weights
        most_frequent = max(lang_counts)
        # set to 0.0 if lang_count is 0 or most_frequent/lang_count otherwise
        lang_counts = [(most_frequent * x)/(max(1, x) ** 2) for x in lang_counts]
        self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float)

        # shuffle batches to mix up lengths
        random.shuffle(self.batches)
        self.batches_iter = iter(self.batches)

    @staticmethod
    def randomize_data(sentences, upper_lim=20, lower_lim=5):
        """
        Takes the original data and creates random length examples with length between upper limit and lower limit
        From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
        """

        new_data = []
        for sentence in sentences:
            remaining = sentence
            while lower_lim < len(remaining):
                lim = random.randint(lower_lim, upper_lim)
                m = min(len(remaining), lim)
                new_sentence = remaining[:m]
                new_data.append(new_sentence)
                # resume at the next word boundary after the cut
                split = remaining[m:].split(" ", 1)
                if len(split) <= 1:
                    break
                remaining = split[1]
        random.shuffle(new_data)
        return new_data

    def build_batch_tensors(self, batch):
        """
        Helper to turn batches into tensors

        batch: list of (sequence_of_char_ids, label_id) pairs, all the same length
        """

        batch_tensors = dict()
        batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long)
        batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long)

        return batch_tensors

    def next(self):
        """Return the next batch; raises StopIteration when exhausted."""
        return next(self.batches_iter)
134
+
stanza/stanza/models/langid/model.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
class LangIDBiLSTM(nn.Module):
    """
    Multi-layer BiLSTM model for language detection. A recreation of "A reproduction of Apple's bi-directional LSTM models
    for language identification in short strings." (Toftrup et al 2021)

    Arxiv: https://arxiv.org/abs/2102.06282
    GitHub: https://github.com/AU-DIS/LSTM_langid

    This class is similar to https://github.com/AU-DIS/LSTM_langid/blob/main/src/LSTMLID.py
    """

    def __init__(self, char_to_idx, tag_to_idx, num_layers, embedding_dim, hidden_dim, batch_size=64, weights=None,
                 dropout=0.0, lang_subset=None):
        """
        char_to_idx: map char -> embedding index; must contain "<PAD>"
        tag_to_idx: map language tag -> output index
        weights: optional per-language weights for the training loss
        lang_subset: optional collection of tags; predictions are restricted to it
        """
        super(LangIDBiLSTM, self).__init__()
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.char_to_idx = char_to_idx
        self.vocab_size = len(char_to_idx)
        self.tag_to_idx = tag_to_idx
        # invert tag_to_idx into a list indexed by tag id
        self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]
        self.lang_subset = lang_subset
        self.padding_idx = char_to_idx["<PAD>"]
        self.tagset_size = len(tag_to_idx)
        self.batch_size = batch_size
        self.loss_train = nn.CrossEntropyLoss(weight=weights)
        self.dropout_prob = dropout

        # embeddings for chars
        self.char_embeds = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.embedding_dim,
            padding_idx=self.padding_idx
        )

        # the bidirectional LSTM
        self.lstm = nn.LSTM(
            self.embedding_dim,
            self.hidden_dim,
            num_layers=self.num_layers,
            bidirectional=True,
            batch_first=True
        )

        # convert output to tag space
        self.hidden_to_tag = nn.Linear(
            self.hidden_dim * 2,
            self.tagset_size
        )

        # dropout layer
        self.dropout = nn.Dropout(p=self.dropout_prob)

    def build_lang_mask(self, device):
        """
        Build language mask if a lang subset is specified (e.g. ["en", "fr"])

        The mask will be added to the results to set the prediction scores of illegal languages to -inf

        NOTE(review): lang_mask is a plain attribute, not a registered buffer,
        so Module.to(device) will not move it - this must be called with the
        final device, as load() does.
        """
        if self.lang_subset:
            lang_mask_list = [0.0 if lang in self.lang_subset else -float('inf') for lang in self.idx_to_tag]
            self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float)
        else:
            self.lang_mask = torch.zeros(len(self.idx_to_tag), device=device, dtype=torch.float)

    def loss(self, Y_hat, Y):
        # weighted cross entropy over the per-language scores
        return self.loss_train(Y_hat, Y)

    def forward(self, x):
        # embed input
        x = self.char_embeds(x)

        # run through LSTM
        x, _ = self.lstm(x)

        # run through linear layer
        x = self.hidden_to_tag(x)

        # sum character outputs for each sequence
        x = torch.sum(x, dim=1)

        return x

    def prediction_scores(self, x):
        """Return the predicted tag index for each sequence in the batch"""
        # these are unnormalized scores rather than probabilities,
        # but the argmax is the same either way
        prediction_probs = self(x)
        if self.lang_subset:
            prediction_batch_size = prediction_probs.size()[0]
            # adding -inf removes languages outside the subset from contention
            batch_mask = torch.stack([self.lang_mask for _ in range(prediction_batch_size)])
            prediction_probs = prediction_probs + batch_mask
        return torch.argmax(prediction_probs, dim=1)

    def save(self, path):
        """ Save a model at path """
        # save the constructor arguments alongside the weights so the
        # model can be rebuilt in load() without external config
        checkpoint = {
            "char_to_idx": self.char_to_idx,
            "tag_to_idx": self.tag_to_idx,
            "num_layers": self.num_layers,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "model_state_dict": self.state_dict()
        }
        torch.save(checkpoint, path)

    @classmethod
    def load(cls, path, device=None, batch_size=64, lang_subset=None):
        """ Load a serialized model located at path """
        if path is None:
            raise FileNotFoundError("Trying to load langid model, but path not specified! Try --load_name")
        if not os.path.exists(path):
            raise FileNotFoundError("Trying to load langid model from path which does not exist: %s" % path)
        # load on CPU first, then move the whole model to the target device
        checkpoint = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
        # the loss weights were serialized as part of the state dict
        weights = checkpoint["model_state_dict"]["loss_train.weight"]
        model = cls(checkpoint["char_to_idx"], checkpoint["tag_to_idx"], checkpoint["num_layers"],
                    checkpoint["embedding_dim"], checkpoint["hidden_dim"], batch_size=batch_size, weights=weights,
                    lang_subset=lang_subset)
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(device)
        # rebuild the mask on the final device (see build_lang_mask)
        model.build_lang_mask(device)
        return model
+
stanza/stanza/models/lemma_classifier/__init__.py ADDED
File without changes
stanza/stanza/models/ner/model.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
9
+
10
+ from stanza.models.common.data import map_to_ids, get_long_tensor
11
+ from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError
12
+ from stanza.models.common.packed_lstm import PackedLSTM
13
+ from stanza.models.common.dropout import WordDropout, LockedDropout
14
+ from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
15
+ from stanza.models.common.crf import CRFLoss
16
+ from stanza.models.common.foundation_cache import load_bert
17
+ from stanza.models.common.utils import attach_bert_model
18
+ from stanza.models.common.vocab import PAD_ID, UNK_ID, EMPTY_ID
19
+ from stanza.models.common.bert_embedding import extract_bert_embeddings
20
+
21
+ logger = logging.getLogger('stanza')
22
+
23
+ # this gets created in two places in trainer
24
+ # in both places, pass in the bert model & tokenizer
25
+ class NERTagger(nn.Module):
26
+ def __init__(self, args, vocab, emb_matrix=None, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
27
+ super().__init__()
28
+
29
+ self.vocab = vocab
30
+ self.args = args
31
+ self.unsaved_modules = []
32
+
33
+ # input layers
34
+ input_size = 0
35
+ if self.args['word_emb_dim'] > 0:
36
+ emb_finetune = self.args.get('emb_finetune', True)
37
+
38
+ # load pretrained embeddings if specified
39
+ word_emb = nn.Embedding(len(self.vocab['word']), self.args['word_emb_dim'], PAD_ID)
40
+ # if a model trained with no 'delta' vocab is loaded, and
41
+ # emb_finetune is off, any resaving of the model will need
42
+ # the updated vectors. this is accounted for in load()
43
+ if not emb_finetune or 'delta' in self.vocab:
44
+ # if emb_finetune is off
45
+ # or if the delta embedding is present
46
+ # then we won't fine tune the original embedding
47
+ self.add_unsaved_module('word_emb', word_emb)
48
+ self.word_emb.weight.detach_()
49
+ else:
50
+ self.word_emb = word_emb
51
+ if emb_matrix is not None:
52
+ self.init_emb(emb_matrix)
53
+
54
+ # TODO: allow for expansion of delta embedding if new
55
+ # training data has new words in it?
56
+ self.delta_emb = None
57
+ if 'delta' in self.vocab:
58
+ # zero inits seems to work better
59
+ # note that the gradient will flow to the bottom and then adjust the 0 weights
60
+ # as opposed to a 0 matrix cutting off the gradient if higher up in the model
61
+ self.delta_emb = nn.Embedding(len(self.vocab['delta']), self.args['word_emb_dim'], PAD_ID)
62
+ nn.init.zeros_(self.delta_emb.weight)
63
+ # if the model was trained with a delta embedding, but emb_finetune is off now,
64
+ # then we will detach the delta embedding
65
+ if not emb_finetune:
66
+ self.delta_emb.weight.detach_()
67
+
68
+ input_size += self.args['word_emb_dim']
69
+
70
+ self.peft_name = peft_name
71
+ attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
72
+ if self.args.get('bert_model', None):
73
+ # TODO: refactor bert_hidden_layers between the different models
74
+ if args.get('bert_hidden_layers', False):
75
+ # The average will be offset by 1/N so that the default zeros
76
+ # represents an average of the N layers
77
+ self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
78
+ nn.init.zeros_(self.bert_layer_mix.weight)
79
+ else:
80
+ # an average of layers 2, 3, 4 will be used
81
+ # (for historic reasons)
82
+ self.bert_layer_mix = None
83
+ input_size += self.bert_model.config.hidden_size
84
+
85
+ if self.args['char'] and self.args['char_emb_dim'] > 0:
86
+ if self.args['charlm']:
87
+ if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
88
+ raise ForwardCharlmNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']), args['charlm_forward_file'])
89
+ if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
90
+ raise BackwardCharlmNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']), args['charlm_backward_file'])
91
+ self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False))
92
+ self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False))
93
+ input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
94
+ else:
95
+ self.charmodel = CharacterModel(args, vocab, bidirectional=True, attention=False)
96
+ input_size += self.args['char_hidden_dim'] * 2
97
+
98
+ # optionally add a input transformation layer
99
+ if self.args.get('input_transform', False):
100
+ self.input_transform = nn.Linear(input_size, input_size)
101
+ else:
102
+ self.input_transform = None
103
+
104
+ # recurrent layers
105
+ self.taggerlstm = PackedLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, \
106
+ bidirectional=True, dropout=0 if self.args['num_layers'] == 1 else self.args['dropout'])
107
+ # self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
108
+ self.drop_replacement = None
109
+ self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)
110
+ self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)
111
+
112
+ # tag classifier
113
+ tag_lengths = self.vocab['tag'].lens()
114
+ self.num_output_layers = len(tag_lengths)
115
+ if self.args.get('connect_output_layers'):
116
+ tag_clfs = [nn.Linear(self.args['hidden_dim']*2, tag_lengths[0])]
117
+ for prev_length, next_length in zip(tag_lengths[:-1], tag_lengths[1:]):
118
+ tag_clfs.append(nn.Linear(self.args['hidden_dim']*2 + prev_length, next_length))
119
+ self.tag_clfs = nn.ModuleList(tag_clfs)
120
+ else:
121
+ self.tag_clfs = nn.ModuleList([nn.Linear(self.args['hidden_dim']*2, num_tag) for num_tag in tag_lengths])
122
+ for tag_clf in self.tag_clfs:
123
+ tag_clf.bias.data.zero_()
124
+ self.crits = nn.ModuleList([CRFLoss(num_tag) for num_tag in tag_lengths])
125
+
126
+ self.drop = nn.Dropout(args['dropout'])
127
+ self.worddrop = WordDropout(args['word_dropout'])
128
+ self.lockeddrop = LockedDropout(args['locked_dropout'])
129
+
130
+ def init_emb(self, emb_matrix):
131
+ if isinstance(emb_matrix, np.ndarray):
132
+ emb_matrix = torch.from_numpy(emb_matrix)
133
+ vocab_size = len(self.vocab['word'])
134
+ dim = self.args['word_emb_dim']
135
+ assert emb_matrix.size() == (vocab_size, dim), \
136
+ "Input embedding matrix must match size: {} x {}, found {}".format(vocab_size, dim, emb_matrix.size())
137
+ self.word_emb.weight.data.copy_(emb_matrix)
138
+
139
+ def add_unsaved_module(self, name, module):
140
+ self.unsaved_modules += [name]
141
+ setattr(self, name, module)
142
+
143
+ def log_norms(self):
144
+ lines = ["NORMS FOR MODEL PARAMTERS"]
145
+ for name, param in self.named_parameters():
146
+ if param.requires_grad and name.split(".")[0] not in ('charmodel_forward', 'charmodel_backward'):
147
+ lines.append(" %s %.6g" % (name, torch.norm(param).item()))
148
+ logger.info("\n".join(lines))
149
+
150
    def forward(self, sentences, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx):
        """Compute the tagging loss and per-tag logits for one batch.

        Builds the per-word input by concatenating every enabled feature
        source (static word embeddings plus optional additive delta
        embeddings, transformer embeddings, character representations),
        runs the bidirectional tagger LSTM, and applies one linear
        classifier + CRF per output tag vocabulary.

        Returns (loss, logits, trans): summed CRF loss, a list of padded
        logit tensors (one per classifier), and the matching list of CRF
        transition scores.

        NOTE(review): `word_mask` and the `pad` helper both depend on the
        word-embedding branch having run (word_emb_dim > 0) -- confirm the
        model is never configured without word embeddings.
        """
        device = next(self.parameters()).device

        # pack along the batch using the per-sentence lengths
        def pack(x):
            return pack_padded_sequence(x, sentlens, batch_first=True)

        inputs = []
        batch_size = len(sentences)

        if self.args['word_emb_dim'] > 0:
            #extract static embeddings
            static_words, word_mask = self.extract_static_embeddings(self.args, sentences, self.vocab['word'])

            word_mask = word_mask.to(device)
            static_words = static_words.to(device)

            word_static_emb = self.word_emb(static_words)

            if 'delta' in self.vocab and self.delta_emb is not None:
                # masks should be the same
                delta_words, _ = self.extract_static_embeddings(self.args, sentences, self.vocab['delta'])
                delta_words = delta_words.to(device)
                # unclear whether to treat words in the main embedding
                # but not in delta as unknown
                # simple heuristic though - treating them as not
                # unknown keeps existing models the same when
                # separating models into the base WV and delta WV
                # also, note that at training time, words like this
                # did not show up in the training data, but are
                # not exactly UNK, so it makes sense
                delta_unk_mask = torch.eq(delta_words, UNK_ID)
                static_unk_mask = torch.not_equal(static_words, UNK_ID)
                # UNK in delta but known in the base vocab -> remap to PAD
                unk_mask = delta_unk_mask * static_unk_mask
                delta_words[unk_mask] = PAD_ID

                # delta embeddings are additive on top of the static ones
                delta_emb = self.delta_emb(delta_words)
                word_static_emb = word_static_emb + delta_emb

            word_emb = pack(word_static_emb)
            inputs += [word_emb]

        if self.bert_model is not None:
            device = next(self.parameters()).device
            processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, sentences, device, keep_endpoints=False,
                                                     num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                     detach=not self.args.get('bert_finetune', False),
                                                     peft_name=self.peft_name)
            if self.bert_layer_mix is not None:
                # use a linear layer to weighted average the embedding dynamically
                processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]

            processed_bert = pad_sequence(processed_bert, batch_first=True)
            inputs += [pack(processed_bert)]

        # unpad by reusing the batch sizes of the packed word embeddings
        def pad(x):
            return pad_packed_sequence(PackedSequence(x, word_emb.batch_sizes), batch_first=True)[0]

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                # pretrained character language models, one per direction
                char_reps_forward = self.charmodel_forward.get_representation(chars[0], charoffsets[0], charlens, char_orig_idx)
                char_reps_forward = PackedSequence(char_reps_forward.data, char_reps_forward.batch_sizes)
                char_reps_backward = self.charmodel_backward.get_representation(chars[1], charoffsets[1], charlens, char_orig_idx)
                char_reps_backward = PackedSequence(char_reps_backward.data, char_reps_backward.batch_sizes)
                inputs += [char_reps_forward, char_reps_backward]
            else:
                # character model trained jointly with the tagger
                char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
                char_reps = PackedSequence(char_reps.data, char_reps.batch_sizes)
                inputs += [char_reps]

        # concatenate all feature sources, then run the dropout stack;
        # locked dropout operates on padded tensors, hence the pad/pack dance
        lstm_inputs = torch.cat([x.data for x in inputs], 1)
        if self.args['word_dropout'] > 0:
            lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
        lstm_inputs = self.drop(lstm_inputs)
        lstm_inputs = pad(lstm_inputs)
        lstm_inputs = self.lockeddrop(lstm_inputs)
        lstm_inputs = pack(lstm_inputs).data

        if self.input_transform:
            lstm_inputs = self.input_transform(lstm_inputs)

        lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)
        lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(\
                self.taggerlstm_h_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous(), \
                self.taggerlstm_c_init.expand(2 * self.args['num_layers'], batch_size, self.args['hidden_dim']).contiguous()))
        lstm_outputs = lstm_outputs.data


        # prediction layer
        lstm_outputs = self.drop(lstm_outputs)
        lstm_outputs = pad(lstm_outputs)
        lstm_outputs = self.lockeddrop(lstm_outputs)
        lstm_outputs = pack(lstm_outputs).data

        loss = 0
        logits = []
        trans = []
        for idx, (tag_clf, crit) in enumerate(zip(self.tag_clfs, self.crits)):
            if not self.args.get('connect_output_layers') or idx == 0:
                next_logits = pad(tag_clf(lstm_outputs)).contiguous()
            else:
                # here we pack the output of the previous round, then append it
                packed_logits = pack(next_logits).data
                input_logits = torch.cat([lstm_outputs, packed_logits], axis=1)
                next_logits = pad(tag_clf(input_logits)).contiguous()
            # the tag_mask lets us avoid backprop on a blank tag
            tag_mask = torch.eq(tags[:, :, idx], EMPTY_ID)
            next_loss, next_trans = crit(next_logits, torch.bitwise_or(tag_mask, word_mask), tags[:, :, idx])
            loss = loss + next_loss
            logits.append(next_logits)
            trans.append(next_trans)

        return loss, logits, trans
262
+
263
+ @staticmethod
264
+ def extract_static_embeddings(args, sents, vocab):
265
+ processed = []
266
+ if args.get('lowercase', True): # handle word case
267
+ case = lambda x: x.lower()
268
+ else:
269
+ case = lambda x: x
270
+ for idx, sent in enumerate(sents):
271
+ processed_sent = [vocab.map([case(w) for w in sent])]
272
+ processed.append(processed_sent[0])
273
+
274
+ words = get_long_tensor(processed, len(sents))
275
+ words_mask = torch.eq(words, PAD_ID)
276
+
277
+ return words, words_mask
278
+
stanza/stanza/models/pos/scorer.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils and wrappers for scoring taggers.
3
+ """
4
+ import logging
5
+
6
+ from stanza.models.common.utils import ud_scores
7
+
8
+ logger = logging.getLogger('stanza')
9
+
10
def score(system_conllu_file, gold_conllu_file, verbose=True, eval_type='AllTags'):
    """Score a system CoNLL-U file against gold with the UD scorer.

    Returns (precision, recall, f1) for `eval_type`; when verbose, also
    logs a tab-separated table of UPOS/XPOS/UFeats/AllTags F1 scores.
    """
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)
    result = evaluation[eval_type]
    precision, recall, f1 = result.precision, result.recall, result.f1
    if verbose:
        columns = ('UPOS', 'XPOS', 'UFeats', 'AllTags')
        scores = [evaluation[column].f1 * 100 for column in columns]
        logger.info("UPOS\tXPOS\tUFeats\tAllTags")
        logger.info("{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}".format(*scores))
    return precision, recall, f1
22
+
stanza/stanza/models/pos/vocab.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter, OrderedDict
2
+
3
+ from stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab
4
+ from stanza.models.common.vocab import CompositeVocab, VOCAB_PREFIX, EMPTY, EMPTY_ID
5
+
6
class WordVocab(BaseVocab):
    """Word-level vocabulary with a configurable set of ignored units.

    Units listed in `ignore` are collapsed onto the EMPTY id both when
    mapping units to ids and when building the vocabulary; EMPTY is then
    rendered back as '_'.
    """
    def __init__(self, data=None, lang="", idx=0, cutoff=0, lower=False, ignore=None):
        # ignore must be in place before super().__init__ builds the vocab
        self.ignore = [] if ignore is None else ignore
        super().__init__(data, lang=lang, idx=idx, cutoff=cutoff, lower=lower)
        self.state_attrs += ['ignore']

    def id2unit(self, id):
        # ignored units were folded into EMPTY, so show EMPTY as '_'
        if len(self.ignore) > 0 and id == EMPTY_ID:
            return '_'
        return super().id2unit(id)

    def unit2id(self, unit):
        if len(self.ignore) > 0 and unit in self.ignore:
            return self._unit2id[EMPTY]
        return super().unit2id(unit)

    def build_vocab(self):
        if self.lower:
            units = (w[self.idx].lower() for sent in self.data for w in sent)
        else:
            units = (w[self.idx] for sent in self.data for w in sent)
        counter = Counter(units)
        # drop rare units and anything explicitly ignored
        for unit in [u for u in counter if counter[u] < self.cutoff or u in self.ignore]:
            del counter[unit]

        # most frequent first; the stable sort keeps first-seen order for ties
        self._id2unit = VOCAB_PREFIX + sorted(counter, key=counter.get, reverse=True)
        self._unit2id = {w: i for i, w in enumerate(self._id2unit)}

    def __str__(self):
        return "<{}: {}>".format(type(self), ",".join("|%s|" % x for x in self._id2unit))
38
+
39
class XPOSVocab(CompositeVocab):
    """Composite vocabulary for language-specific XPOS tags.

    Defaults to sep="" and keyed=False; the splitting semantics live in
    CompositeVocab -- presumably each character of the tag becomes one
    positionally-indexed sub-vocabulary (confirm there).
    """
    def __init__(self, data=None, lang="", idx=0, sep="", keyed=False):
        super().__init__(data, lang, idx=idx, sep=sep, keyed=keyed)
42
+
43
class FeatureVocab(CompositeVocab):
    """Composite vocabulary for morphological feature strings.

    Defaults to sep="|" and keyed=True, matching UFeats syntax such as
    "Case=Nom|Number=Sing"; splitting/keying is implemented in
    CompositeVocab -- presumably parts are indexed by the key before '='
    (confirm there).
    """
    def __init__(self, data=None, lang="", idx=0, sep="|", keyed=True):
        super().__init__(data, lang, idx=idx, sep=sep, keyed=keyed)
46
+
47
class MultiVocab(BaseMultiVocab):
    """Collection of named vocabs that records each vocab's class on save,
    so the right class can be reconstructed at load time."""

    def state_dict(self):
        """ Also save a vocab name to class name mapping in state dict. """
        state = OrderedDict()
        key2class = OrderedDict()
        for k, v in self._vocabs.items():
            state[k] = v.state_dict()
            key2class[k] = type(v).__name__
        state['_key2class'] = key2class
        return state

    @classmethod
    def load_state_dict(cls, state_dict):
        """Rebuild a MultiVocab, reconstructing each vocab with its saved class.

        Fix: no longer pops '_key2class' out of the caller's dict -- the
        input state_dict is left unmodified.
        """
        class_dict = {'CharVocab': CharVocab,
                      'WordVocab': WordVocab,
                      'XPOSVocab': XPOSVocab,
                      'FeatureVocab': FeatureVocab}
        assert '_key2class' in state_dict, "Cannot find class name mapping in state dict!"
        key2class = state_dict['_key2class']
        new = cls()
        for k, v in state_dict.items():
            if k == '_key2class':
                continue
            classname = key2class[k]
            new[k] = class_dict[classname].load_state_dict(v)
        return new
70
+ return new
71
+
stanza/stanza/pipeline/demo/stanza-brat.js ADDED
@@ -0,0 +1,1316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Takes Stanford CoreNLP JSON output (var data = ... in data.js)
2
+ // and uses brat to render everything.
3
+
4
+ //var serverAddress = 'http://localhost:5000';
5
+
6
+ // Load Brat libraries
7
+ var bratLocation = 'https://nlp.stanford.edu/js/brat/';
8
+ head.js(
9
+ // External libraries
10
+ bratLocation + '/client/lib/jquery.svg.min.js',
11
+ bratLocation + '/client/lib/jquery.svgdom.min.js',
12
+
13
+ // brat helper modules
14
+ bratLocation + '/client/src/configuration.js',
15
+ bratLocation + '/client/src/util.js',
16
+ bratLocation + '/client/src/annotation_log.js',
17
+ bratLocation + '/client/lib/webfont.js',
18
+
19
+ // brat modules
20
+ bratLocation + '/client/src/dispatcher.js',
21
+ bratLocation + '/client/src/url_monitor.js',
22
+ bratLocation + '/client/src/visualizer.js',
23
+
24
+ // parse viewer
25
+ './stanza-parseviewer.js'
26
+ );
27
+
28
+ // Uses Dagre (https://github.com/cpettitt/dagre) for constituency parse
29
+ // visualization. It works better than the brat visualization.
30
var useDagre = true;  // prefer the Dagre renderer over brat for constituency parses
var currentQuery = 'The quick brown fox jumped over the lazy dog.';  // default demo input
var currentSentences = '';  // sentences of the last rendered annotation (assigned in render)
var currentText = '';  // reconstructed text of the last rendered annotation (assigned in render)
34
+
35
+ // ----------------------------------------------------------------------------
36
+ // HELPERS
37
+ // ----------------------------------------------------------------------------
38
+
39
/**
 * Polyfill String.prototype.startsWith for browsers that lack it.
 */
if (typeof String.prototype.startsWith !== 'function') {
  String.prototype.startsWith = function (prefix) {
    // prefix at position 0 <=> the string starts with it
    return this.indexOf(prefix) === 0;
  };
}
48
+
49
// True when `value` parses as a number whose float value survives 32-bit
// integer truncation unchanged (the `| 0` trick), i.e. a small integer.
function isInt(value) {
  if (isNaN(value)) {
    return false;
  }
  var parsed = parseFloat(value);
  return (parsed | 0) === parsed;
}
52
+
53
/**
 * A reverse map of PTB escape tokens to their original gloss,
 * e.g. '-LRB-' back to '(' and PTB-style quote tokens back to '"'.
 */
var tokensMap = {
  '-LRB-': '(',
  '-RRB-': ')',
  '-LSB-': '[',
  '-RSB-': ']',
  '-LCB-': '{',
  '-RCB-': '}',
  '``': '"',
  '\'\'': '"',
};
66
+
67
/**
 * Map a PTB part-of-speech tag to its visualization color.
 * Rules are tested in priority order; the first match wins, and
 * anything unmatched (or a null tag) falls back to grey.
 */
function posColor(posTag) {
  if (posTag === null) {
    return '#E3E3E3';
  }
  var rules = [
    [function (t) { return t.startsWith('N'); }, '#A4BCED'],
    [function (t) { return t.startsWith('V') || t.startsWith('M'); }, '#ADF6A2'],
    [function (t) { return t.startsWith('P'); }, '#CCDAF6'],
    [function (t) { return t.startsWith('I'); }, '#FFE8BE'],
    [function (t) { return t.startsWith('R') || t.startsWith('W'); }, '#FFFDA8'],
    [function (t) { return t.startsWith('D') || t === 'CD'; }, '#CCADF6'],
    [function (t) { return t.startsWith('J'); }, '#FFFDA8'],
    [function (t) { return t.startsWith('T'); }, '#FFE8BE'],
    [function (t) { return t.startsWith('E') || t.startsWith('S'); }, '#E4CBF6'],
    [function (t) { return t.startsWith('CC'); }, '#FFFFFF'],
    [function (t) { return t === 'LS' || t === 'FW'; }, '#FFFFFF']
  ];
  for (var i = 0; i < rules.length; i++) {
    if (rules[i][0](posTag)) {
      return rules[i][1];
    }
  }
  return '#E3E3E3';
}
100
+
101
/**
 * Map a Universal POS tag to its visualization color.
 * Rules are tested in priority order; the first match wins, and
 * anything unmatched (or a null tag) falls back to grey.
 */
function uposColor(posTag) {
  if (posTag === null) {
    return '#E3E3E3';
  }
  var rules = [
    [function (t) { return t === 'NOUN' || t === 'PROPN'; }, '#A4BCED'],
    [function (t) { return t.startsWith('V') || t === 'AUX'; }, '#ADF6A2'],
    [function (t) { return t === 'PART'; }, '#CCDAF6'],
    [function (t) { return t === 'ADP'; }, '#FFE8BE'],
    [function (t) { return t === 'ADV' || t.startsWith('PRON'); }, '#FFFDA8'],
    [function (t) { return t === 'NUM' || t === 'DET'; }, '#CCADF6'],
    [function (t) { return t === 'ADJ'; }, '#FFFDA8'],
    [function (t) { return t.startsWith('E') || t.startsWith('S'); }, '#E4CBF6'],
    [function (t) { return t.startsWith('CC'); }, '#FFFFFF'],
    [function (t) { return t === 'X' || t === 'FW'; }, '#FFFFFF']
  ];
  for (var i = 0; i < rules.length; i++) {
    if (rules[i][0](posTag)) {
      return rules[i][1];
    }
  }
  return '#E3E3E3';
}
132
+
133
/**
 * Map a named-entity tag to its visualization color.
 * Fix: the LOC branch used loose equality (==) where every other
 * branch uses strict equality; made consistent with ===.
 */
function nerColor(nerTag) {
  if (nerTag === null) {
    return '#E3E3E3';
  } else if (nerTag === 'PERSON' || nerTag === 'PER') {
    return '#FFCCAA';
  } else if (nerTag === 'ORGANIZATION' || nerTag === 'ORG') {
    return '#8FB2FF';
  } else if (nerTag === 'MISC') {
    return '#F1F447';
  } else if (nerTag === 'LOCATION' || nerTag === 'LOC') {
    return '#95DFFF';
  } else if (nerTag === 'DATE' || nerTag === 'TIME' || nerTag === 'SET') {
    return '#9AFFE6';
  } else if (nerTag === 'MONEY') {
    return '#FFFFFF';
  } else if (nerTag === 'PERCENT') {
    return '#FFA22B';
  } else {
    return '#E3E3E3';
  }
}
158
+
159
+
160
/**
 * Map a sentiment label to its visualization color; unknown labels
 * fall back to grey.
 */
function sentimentColor(sentiment) {
  var colorBySentiment = {
    'VERY POSITIVE': '#00FF00',
    'POSITIVE': '#7FFF00',
    'NEUTRAL': '#FFFF00',
    'NEGATIVE': '#FF7F00',
    'VERY NEGATIVE': '#FF0000'
  };
  // own-property check so prototype keys (e.g. "toString") still
  // fall through to the default, matching the original if-chain
  if (Object.prototype.hasOwnProperty.call(colorBySentiment, sentiment)) {
    return colorBySentiment[sentiment];
  }
  return '#E3E3E3';
}
179
+
180
+
181
/**
 * Get a list of annotators, from the annotator option input.
 * Always starts with "tokenize,ssplit", then appends the value of every
 * selected option of the #annotators element, comma-separated.
 */
function annotators() {
  var annotators = "tokenize,ssplit";
  $('#annotators').find('option:selected').each(function () {
    annotators += "," + $(this).val();
  });
  return annotators;
}
191
+
192
/**
 * Current local time formatted as YYYY-MM-DDThh:mm:ss.
 */
function date() {
  // zero-pad a field to two digits
  var pad2 = function (n) {
    return n < 10 ? '0' + n : n;
  };
  var now = new Date();
  return now.getFullYear() +
    '-' + pad2(now.getMonth() + 1) +
    '-' + pad2(now.getDate()) +
    'T' + pad2(now.getHours()) +
    ':' + pad2(now.getMinutes()) +
    ':' + pad2(now.getSeconds());
}
208
+
209
+
210
+ //-----------------------------------------------------------------------------
211
+ // Constituency parser
212
+ //-----------------------------------------------------------------------------
213
// Converts bracketed PTB-style parse strings ("(S (NP ...) ...)") into
// linked tree objects whose nodes carry token index ranges, so they can
// be correlated with the tokens of the annotated sentence.
function ConstituencyParseProcessor() {
  // Recursively consume the token stream, building nested arrays: one
  // array per parenthesized group. `list` is the group being accumulated;
  // note `input` is mutated (shift) across the recursive calls.
  var parenthesize = function (input, list) {
    if (list === undefined) {
      return parenthesize(input, []);
    } else {
      var token = input.shift();
      if (token === undefined) {
        return list.pop();
      } else if (token === "(") {
        list.push(parenthesize(input, []));
        return parenthesize(input, list);
      } else if (token === ")") {
        return list;
      } else {
        return parenthesize(input, list.concat(token));
      }
    }
  };

  // Turn nested arrays into tree nodes: [label, "word"] becomes a
  // terminal; longer arrays become internal nodes whose children get a
  // backlink to their parent.
  var toTree = function (list) {
    if (list.length === 2 && typeof list[1] === 'string') {
      return {label: list[0], text: list[1], isTerminal: true};
    } else if (list.length >= 2) {
      var label = list.shift();
      var node = {label: label};
      var rest = list.map(function (x) {
        var t = toTree(x);
        if (typeof t === 'object') {
          t.parent = node;
        }
        return t;
      });
      node.children = rest;
      return node;
    } else {
      return list;
    }
  };

  // Depth-first walk assigning each node its covered token span
  // [tokenStart, tokenEnd); terminals also get the token object itself.
  // Returns the index just past the last token consumed.
  var indexTree = function (tree, tokens, index) {
    index = index || 0;
    if (tree.isTerminal) {
      tree.token = tokens[index];
      tree.tokenIndex = index;
      tree.tokenStart = index;
      tree.tokenEnd = index + 1;
      return index + 1;
    } else if (tree.children) {
      tree.tokenStart = index;
      for (var i = 0; i < tree.children.length; i++) {
        var child = tree.children[i];
        index = indexTree(child, tokens, index);
      }
      tree.tokenEnd = index;
    }
    return index;
  };

  // Split a parse string into "(", ")" and label/word tokens, protecting
  // spaces inside double-quoted segments with a placeholder so they
  // survive the whitespace split.
  var tokenize = function (input) {
    return input.split('"')
      .map(function (x, i) {
        if (i % 2 === 0) { // not in string
          return x.replace(/\(/g, ' ( ')
            .replace(/\)/g, ' ) ');
        } else { // in string
          return x.replace(/ /g, "!whitespace!");
        }
      })
      .join('"')
      .trim()
      .split(/\s+/)
      .map(function (x) {
        return x.replace(/!whitespace!/g, " ");
      });
  };

  // Full pipeline: parse string -> token stream -> nested arrays ->
  // indexed tree. Returns undefined if the input did not parse to a list.
  var convertParseStringToTree = function (input, tokens) {
    var p = parenthesize(tokenize(input));
    if (Array.isArray(p)) {
      var tree = toTree(p);
      // Correlate tree with tokens
      indexTree(tree, tokens);
      return tree;
    }
  };

  // Attach a parseTree to every sentence of the annotation that carries
  // a bracketed `parse` string.
  this.process = function(annotation) {
    for (var i = 0; i < annotation.sentences.length; i++) {
      var s = annotation.sentences[i];
      if (s.parse) {
        s.parseTree = convertParseStringToTree(s.parse, s.tokens);
      }
    }
  }
}
308
+
309
+ // ----------------------------------------------------------------------------
310
+ // RENDER
311
+ // ----------------------------------------------------------------------------
312
+
313
+ /**
314
+ * Render a given JSON data structure
315
+ */
316
+ function render(data, reverse) {
317
+ // Tweak arguments
318
+ if (typeof reverse !== 'boolean') {
319
+ reverse = false;
320
+ }
321
+
322
+ // Error checks
323
+ if (typeof data.sentences === 'undefined') { return; }
324
+
325
+ /**
326
+ * Register an entity type (a tag) for Brat
327
+ */
328
+ var entityTypesSet = {};
329
+ var entityTypes = [];
330
+ function addEntityType(name, type, coarseType) {
331
+ if (typeof coarseType === "undefined") {
332
+ coarseType = type;
333
+ }
334
+ // Don't add duplicates
335
+ if (entityTypesSet[type]) return;
336
+ entityTypesSet[type] = true;
337
+ // Get the color of the entity type
338
+ color = '#ffccaa';
339
+ if (name === 'POS') {
340
+ color = posColor(type);
341
+ } else if (name === 'UPOS') {
342
+ color = uposColor(type);
343
+ } else if (name === 'NER') {
344
+ color = nerColor(coarseType);
345
+ } else if (name === 'NNER') {
346
+ color = nerColor(coarseType);
347
+ } else if (name === 'COREF') {
348
+ color = '#FFE000';
349
+ } else if (name === 'ENTITY') {
350
+ color = posColor('NN');
351
+ } else if (name === 'RELATION') {
352
+ color = posColor('VB');
353
+ } else if (name === 'LEMMA') {
354
+ color = '#FFFFFF';
355
+ } else if (name === 'SENTIMENT') {
356
+ color = sentimentColor(type);
357
+ } else if (name === 'LINK') {
358
+ color = '#FFFFFF';
359
+ } else if (name === 'KBP_ENTITY') {
360
+ color = '#FFFFFF';
361
+ }
362
+ // Register the type
363
+ entityTypes.push({
364
+ type: type,
365
+ labels : [type],
366
+ bgColor: color,
367
+ borderColor: 'darken'
368
+ });
369
+ }
370
+
371
+ /**
372
+ * Register a relation type (an arc) for Brat
373
+ */
374
+ var relationTypesSet = {};
375
+ var relationTypes = [];
376
+ function addRelationType(type, symmetricEdge) {
377
+ // Prevent adding duplicates
378
+ if (relationTypesSet[type]) return;
379
+ relationTypesSet[type] = true;
380
+ // Default arguments
381
+ if (typeof symmetricEdge === 'undefined') { symmetricEdge = false; }
382
+ // Add the type
383
+ relationTypes.push({
384
+ type: type,
385
+ labels: [type],
386
+ dashArray: (symmetricEdge ? '3,3' : undefined),
387
+ arrowHead: (symmetricEdge ? 'none' : undefined),
388
+ });
389
+ }
390
+
391
+ //
392
+ // Construct text of annotation
393
+ //
394
+ currentText = []; // GLOBAL
395
+ currentSentences = data.sentences; // GLOBAL
396
+ data.sentences.forEach(function(sentence) {
397
+ for (var i = 0; i < sentence.tokens.length; ++i) {
398
+ var token = sentence.tokens[i];
399
+ var word = token.word;
400
+ if (!(typeof tokensMap[word] === "undefined")) {
401
+ word = tokensMap[word];
402
+ }
403
+ if (i > 0) { currentText.push(' '); }
404
+ token.characterOffsetBegin = currentText.length;
405
+ for (var j = 0; j < word.length; ++j) {
406
+ currentText.push(word[j]);
407
+ }
408
+ token.characterOffsetEnd = currentText.length;
409
+ }
410
+ currentText.push('\n');
411
+ });
412
+ currentText = currentText.join('');
413
+
414
+ //
415
+ // Shared variables
416
+ // These are what we'll render in BRAT
417
+ //
418
+ // (pos)
419
+ var posEntities = [];
420
+ // (upos)
421
+ var uposEntities = [];
422
+ // (lemma)
423
+ var lemmaEntities = [];
424
+ // (ner)
425
+ var nerEntities = [];
426
+ var nerEntitiesNormalized = [];
427
+ // (sentiment)
428
+ var sentimentEntities = [];
429
+ // (entitylinking)
430
+ var linkEntities = [];
431
+ // (dependencies)
432
+ var depsRelations = [];
433
+ var deps2Relations = [];
434
+ // (openie)
435
+ var openieEntities = [];
436
+ var openieEntitiesSet = {};
437
+ var openieRelations = [];
438
+ var openieRelationsSet = {};
439
+ // (kbp)
440
+ var kbpEntities = [];
441
+ var kbpEntitiesSet = [];
442
+ var kbpRelations = [];
443
+ var kbpRelationsSet = [];
444
+
445
+ var cparseEntities = [];
446
+ var cparseRelations = [];
447
+
448
+ //
449
+ // Loop over sentences.
450
+ // This fills in the variables above.
451
+ //
452
+ for (var sentI = 0; sentI < data.sentences.length; ++sentI) {
453
+ var sentence = data.sentences[sentI];
454
+ var index = sentence.index;
455
+ var tokens = sentence.tokens;
456
+ var deps = sentence['basicDependencies'];
457
+ var deps2 = sentence['enhancedPlusPlusDependencies'];
458
+ var parseTree = sentence['parseTree'];
459
+
460
+ // POS tags
461
+ /**
462
+ * Generate a POS tagged token id
463
+ */
464
+ function posID(i) {
465
+ return 'POS_' + sentI + '_' + i;
466
+ }
467
+ var noXPOS = true;
468
+ if (tokens.length > 0 && typeof tokens[0].pos !== 'undefined' && tokens[0].pos !== null) {
469
+ noXPOS = false;
470
+ for (var i = 0; i < tokens.length; i++) {
471
+ var token = tokens[i];
472
+ var pos = token.pos;
473
+ var begin = parseInt(token.characterOffsetBegin);
474
+ var end = parseInt(token.characterOffsetEnd);
475
+ addEntityType('POS', pos);
476
+ posEntities.push([posID(i), pos, [[begin, end]]]);
477
+ }
478
+ }
479
+
480
+ // Universal POS tags
481
+ /**
482
+ * Generate a POS tagged token id
483
+ */
484
+ function uposID(i) {
485
+ return 'UPOS_' + sentI + '_' + i;
486
+ }
487
+ if (tokens.length > 0 && typeof tokens[0].upos !== 'undefined') {
488
+ for (var i = 0; i < tokens.length; i++) {
489
+ var token = tokens[i];
490
+ var upos = token.upos;
491
+ var begin = parseInt(token.characterOffsetBegin);
492
+ var end = parseInt(token.characterOffsetEnd);
493
+ addEntityType('UPOS', upos);
494
+ uposEntities.push([uposID(i), upos, [[begin, end]]]);
495
+ }
496
+ }
497
+
498
+ // Constituency parse
499
+ // Carries the same assumption as NER
500
+ if (parseTree && !useDagre) {
501
+ var parseEntities = [];
502
+ var parseRels = [];
503
+ function processParseTree(tree, index) {
504
+ tree.visitIndex = index;
505
+ index++;
506
+ if (tree.isTerminal) {
507
+ parseEntities[tree.visitIndex] = uposEntities[tree.tokenIndex];
508
+ return index;
509
+ } else if (tree.children) {
510
+ addEntityType('PARSENODE', tree.label);
511
+ parseEntities[tree.visitIndex] =
512
+ ['PARSENODE_' + sentI + '_' + tree.visitIndex, tree.label,
513
+ [[tokens[tree.tokenStart].characterOffsetBegin, tokens[tree.tokenEnd-1].characterOffsetEnd]]];
514
+ var parentEnt = parseEntities[tree.visitIndex];
515
+ for (var i = 0; i < tree.children.length; i++) {
516
+ var child = tree.children[i];
517
+ index = processParseTree(child, index);
518
+ var childEnt = parseEntities[child.visitIndex];
519
+ addRelationType('pc');
520
+ parseRels.push(['PARSEEDGE_' + sentI + '_' + parseRels.length, 'pc', [['parent', parentEnt[0]], ['child', childEnt[0]]]]);
521
+ }
522
+ }
523
+ return index;
524
+ }
525
+ processParseTree(parseTree, 0);
526
+ cparseEntities = cparseEntities.concat(cparseEntities, parseEntities);
527
+ cparseRelations = cparseRelations.concat(parseRels);
528
+ }
529
+
530
// Dependency parsing
/**
 * Convert one dependency list (CoreNLP JSON format) into Brat relation
 * tuples.  Root edges (governor index 0 in the JSON) are skipped because
 * they have no source entity.
 */
function processDeps(name, deps) {
  var rels = [];
  // Relation format: [${ID}, ${TYPE}, [[${ARGNAME}, ${TARGET}], ...]]
  deps.forEach(function(dep, depI) {
    var gov = dep.governor - 1;   // 1-based in the JSON, 0-based here
    var dpd = dep.dependent - 1;
    if (gov === -1) return;       // skip the root edge
    addRelationType(dep.dep);
    rels.push([name + '_' + sentI + '_' + depI, dep.dep,
               [['governor', uposID(gov)], ['dependent', uposID(dpd)]]]);
  });
  return rels;
}
// Actually add the dependencies
if (typeof deps !== 'undefined') {
  depsRelations = depsRelations.concat(processDeps('dep', deps));
}
if (typeof deps2 !== 'undefined') {
  deps2Relations = deps2Relations.concat(processDeps('dep2', deps2));
}
554
+
555
// Lemmas
if (tokens.length > 0 && typeof tokens[0].lemma !== 'undefined') {
  for (var i = 0; i < tokens.length; i++) {
    // NOTE(review): keep the name `token` — later sections in this function
    // (apparently accidentally) still reference it.
    var token = tokens[i];
    addEntityType('LEMMA', token.lemma);
    lemmaEntities.push(['LEMMA_' + sentI + '_' + i, token.lemma,
                        [[parseInt(token.characterOffsetBegin),
                          parseInt(token.characterOffsetEnd)]]]);
  }
}
566
+
567
// NER tags
// Assumption: a contiguous run of the same non-O tag is a single entity
// NOTE(review): `noNER` is reset to true for every sentence, so only the
// last sentence decides the "NER not available" message — confirm intent.
var noNER = true;
if (tokens.some(function(token) { return token.ner; })) {
  noNER = false;
  for (var i = 0; i < tokens.length; i++) {
    var ner = tokens[i].ner || 'O';
    var normalizedNER = tokens[i].normalizedNER;
    if (typeof normalizedNER === "undefined") {
      normalizedNER = ner;  // no normalization available; reuse the raw tag
    }
    if (ner === 'O') continue;
    // Extend j to the end of the run of identical tags.
    var j = i;
    while (j < tokens.length - 1 && tokens[j+1].ner === ner) j++;
    addEntityType('NER', ner, ner);
    nerEntities.push(['NER_' + sentI + '_' + i, ner,
                      [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);
    if (ner !== normalizedNER) {
      // Also show the normalized form (e.g. resolved dates) as a second tag.
      addEntityType('NNER', normalizedNER, ner);
      nerEntities.push(['NNER_' + sentI + '_' + i, normalizedNER,
                        [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);
    }
    i = j;  // skip past the run we just consumed
  }
}
591
+
592
// Sentiment
if (typeof sentence.sentiment !== "undefined") {
  // e.g. "Verypositive" -> "VERY POSITIVE"
  var sentimentLabel = sentence.sentiment.toUpperCase().replace("VERY", "VERY ");
  addEntityType('SENTIMENT', sentimentLabel);
  // One entity spanning the whole sentence.
  sentimentEntities.push(['SENTIMENT_' + sentI, sentimentLabel,
      [[tokens[0].characterOffsetBegin, tokens[tokens.length - 1].characterOffsetEnd]]]);
}
599
+
600
// Entity Links
// Same contiguity assumption as for NER above
if (tokens.length > 0) {
  for (var i = 0; i < tokens.length; i++) {
    var link = tokens[i].entitylink;
    if (typeof link === 'undefined' || link === 'O') continue;
    // Extend j to the end of the run sharing this link.
    var j = i;
    while (j < tokens.length - 1 && tokens[j+1].entitylink === link) j++;
    addEntityType('LINK', link);
    linkEntities.push(['LINK_' + sentI + '_' + i, link,
                       [[tokens[i].characterOffsetBegin, tokens[j].characterOffsetEnd]]]);
    i = j;
  }
}
613
+
614
// Open IE
// Helper Functions
/** Stable entity id for a token span within this sentence. */
function openieID(span) {
  return 'OPENIEENTITY' + '_' + sentI + '_' + span[0] + '_' + span[1];
}
/** Register a span as an OpenIE entity (deduplicated per sentence). */
function addEntity(span, role) {
  if (openieEntitiesSet[[sentI, span, role]]) return;
  openieEntitiesSet[[sentI, span, role]] = true;
  openieEntities.push([openieID(span), role,
                       [[tokens[span[0]].characterOffsetBegin,
                         tokens[span[1] - 1].characterOffsetEnd ]] ]);
}
/** Register a relation edge between two OpenIE spans (deduplicated). */
function addRelation(gov, dep, role) {
  if (openieRelationsSet[[sentI, gov, dep, role]]) return;
  openieRelationsSet[[sentI, gov, dep, role]] = true;
  openieRelations.push(['OPENIESUBJREL_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1],
                        role,
                        [['governor', openieID(gov)],
                         ['dependent', openieID(dep)] ] ]);
}
// Render OpenIE
if (typeof sentence.openie !== 'undefined') {
  // Register the entities + relations we'll need
  addEntityType('ENTITY', 'Entity');
  addEntityType('RELATION', 'Relation');
  addRelationType('subject');
  addRelationType('object');
  // Loop over triples
  for (var i = 0; i < sentence.openie.length; ++i) {
    var subjectSpan = sentence.openie[i].subjectSpan;
    var relationSpan = sentence.openie[i].relationSpan;
    var objectSpan = sentence.openie[i].objectSpan;
    if (parseInt(relationSpan[0]) < 0 || parseInt(relationSpan[1]) < 0) {
      continue; // This is a phantom relation
    }
    // FIX: removed a dead `var begin = parseInt(token.characterOffsetBegin);`
    // line here — `begin` was never used, and `token` was a stale leftover
    // from the lemma loop (could throw when no lemmas were rendered).
    // Add the entities
    addEntity(subjectSpan, 'Entity');
    addEntity(relationSpan, 'Relation');
    addEntity(objectSpan, 'Entity');
    // Add the relations
    addRelation(relationSpan, subjectSpan, 'subject');
    addRelation(relationSpan, objectSpan, 'object');
  }
} // End OpenIE block
663
+
664
+
665
//
// KBP
//
// Helper Functions
/** Stable entity id for a KBP token span within this sentence. */
function kbpEntity(span) {
  return 'KBPENTITY' + '_' + sentI + '_' + span[0] + '_' + span[1];
}
/** Register a span as a KBP entity (deduplicated per sentence). */
function addKBPEntity(span, role) {
  if (kbpEntitiesSet[[sentI, span, role]]) return;
  kbpEntitiesSet[[sentI, span, role]] = true;
  kbpEntities.push([kbpEntity(span), role,
                    [[tokens[span[0]].characterOffsetBegin,
                      tokens[span[1] - 1].characterOffsetEnd ]] ]);
}
/** Register a KBP relation edge between two spans (deduplicated). */
function addKBPRelation(gov, dep, role) {
  if (kbpRelationsSet[[sentI, gov, dep, role]]) return;
  kbpRelationsSet[[sentI, gov, dep, role]] = true;
  kbpRelations.push(['KBPRELATION_' + sentI + '_' + gov[0] + '_' + gov[1] + '_' + dep[0] + '_' + dep[1],
                     role,
                     [['governor', kbpEntity(gov)],
                      ['dependent', kbpEntity(dep)] ] ]);
}
/**
 * First wikidict entity link found inside a token span, or the generic
 * 'Entity' label if none of the tokens carries a link.
 * (Factored out of two identical subject/object loops.)
 */
function entityLinkForSpan(span) {
  for (var k = span[0]; k < span[1]; ++k) {
    if (typeof tokens[k] !== 'undefined' &&
        typeof tokens[k].entitylink !== 'undefined' &&
        tokens[k].entitylink != 'O') {
      return tokens[k].entitylink;
    }
  }
  return 'Entity';
}
if (typeof sentence.kbp !== 'undefined') {
  // Register the relation types we'll need
  addRelationType('subject');
  addRelationType('object');
  // Loop over triples
  for (var i = 0; i < sentence.kbp.length; ++i) {
    var subjectSpan = sentence.kbp[i].subjectSpan;
    var subjectLink = entityLinkForSpan(subjectSpan);
    addEntityType('KBP_ENTITY', subjectLink);
    var objectSpan = sentence.kbp[i].objectSpan;
    var objectLink = entityLinkForSpan(objectSpan);
    addEntityType('KBP_ENTITY', objectLink);
    var relation = sentence.kbp[i].relation;
    // FIX: removed a dead `var begin = parseInt(token.characterOffsetBegin);`
    // line — unused, and `token` was a stale leftover from the lemma loop.
    // Add the entities
    addKBPEntity(subjectSpan, subjectLink);
    addKBPEntity(objectSpan, objectLink);
    // Add the relations
    addKBPRelation(subjectSpan, objectSpan, relation);
  }
} // End KBP block
728
+
729
+ } // End sentence loop
730
+
731
//
// Coreference
//
var corefEntities = [];
var corefRelations = [];
if (typeof data.corefs !== 'undefined') {
  addRelationType('coref', true);
  addEntityType('COREF', 'Mention');
  Object.keys(data.corefs).forEach(function(clusterId) {
    var chain = data.corefs[clusterId];
    if (chain.length <= 1) return;  // singleton clusters are not drawn
    var prevId = null;
    for (var i = 0; i < chain.length; ++i) {
      var mention = chain[i];
      var id = 'COREF' + mention.id;
      var sentTokens = data.sentences[mention.sentNum - 1].tokens;
      corefEntities.push([id, 'Mention',
                          [[sentTokens[mention.startIndex - 1].characterOffsetBegin,
                            sentTokens[mention.endIndex - 2].characterOffsetEnd ]] ]);
      if (prevId !== null) {
        // Chain each mention to the previous one in the cluster.
        corefRelations.push([prevId + '_' + mention.id,
                             'coref',
                             [['governor', prevId],
                              ['dependent', id] ] ]);
      }
      prevId = id;
    }
  });
} // End coreference block
761
+
762
+ //
763
+ // Actually render the elements
764
+ //
765
+
766
/**
 * Render a set of entities / relations into the div with the given id,
 * if that div exists on the page.  For right-to-left languages the
 * character offsets (and the text itself) are mirrored so Brat displays
 * them in logical order.
 */
function embed(container, entities, relations, reverse) {
  var text = currentText;
  if (reverse) {
    var len = currentText.length;
    // Mirror each [begin, end) offset pair around the text midpoint.
    entities.forEach(function(ent) {
      var offsets = ent[2][0];
      var mirroredEnd = len - offsets[0];
      offsets[0] = len - offsets[1];
      offsets[1] = mirroredEnd;
    });
    text = text.split("").reverse().join("");
  }
  if ($('#' + container).length > 0) {
    Util.embed(container,
               {entity_types: entityTypes, relation_types: relationTypes},
               {text: text, entities: entities, relations: relations}
    );
  }
}

/** Show a plain-text "not available" message in the given div. */
function reportna(container, text) {
  $('#' + container).text(text);
}
793
+
794
// Render each annotation
head.ready(function() {
  if (noXPOS) {
    reportna('pos', 'XPOS is not available for this language at this time.')
  } else {
    embed('pos', posEntities);
  }
  embed('upos', uposEntities);
  embed('lemma', lemmaEntities);
  if (noNER) {
    reportna('ner', 'NER is not available for this language at this time.')
  } else {
    embed('ner', nerEntities);
  }
  embed('entities', linkEntities);
  if (!useDagre) {
    // Brat-based constituency rendering (fallback when dagre is disabled).
    embed('parse', cparseEntities, cparseRelations);
  }
  embed('deps', uposEntities, depsRelations);
  embed('deps2', posEntities, deps2Relations);
  embed('coref', corefEntities, corefRelations);
  embed('openie', openieEntities, openieRelations);
  embed('kbp', kbpEntities, kbpRelations);
  embed('sentiment', sentimentEntities);

  // Constituency parse rendered with d3/dagre-d3 rather than Brat.
  if ($('#parse').length > 0 && useDagre) {
    var viewer = new ParseViewer({ selector: '#parse' });
    viewer.showAnnotation(data);
    $('#parse').addClass('svg').css('display', 'block');
  }
});
827
+
828
+ } // End render function
829
+
830
+
831
/**
 * Render a TokensRegex response into the '#tokensregex' div as Brat
 * entities: one 'match' entity per match, plus one entity per named
 * ($name) or numbered capture group.
 */
function renderTokensregex(data) {
  var entityTypesSet = {};
  var entityTypes = [];
  /** Register an entity type (a tag) for Brat, deduplicated by name. */
  function addEntityType(type, color) {
    if (entityTypesSet[type]) return;
    entityTypesSet[type] = true;
    if (typeof color === 'undefined') {
      color = '#ADF6A2';
    }
    entityTypes.push({
      type: type,
      labels : [type],
      bgColor: color,
      borderColor: 'darken'
    });
  }

  var entities = [];
  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {
    var tokens = currentSentences[sentI].tokens;
    for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) {
      var match = data.sentences[sentI][matchI];
      // Add capture groups ($name or numeric keys of the match object)
      for (var groupName in match) {  // FIX: `groupName` was an implicit global
        if (groupName.startsWith("$") || isInt(groupName)) {
          addEntityType(groupName, '#FFFDA8');
          var begin = parseInt(tokens[match[groupName].begin].characterOffsetBegin);
          var end = parseInt(tokens[match[groupName].end - 1].characterOffsetEnd);
          entities.push(['TOK_' + sentI + '_' + matchI + '_' + groupName,
                         groupName,
                         [[begin, end]]]);
        }
      }
      // Add the overall match
      addEntityType('match', '#ADF6A2');
      var mBegin = parseInt(tokens[match.begin].characterOffsetBegin);
      var mEnd = parseInt(tokens[match.end - 1].characterOffsetEnd);
      entities.push(['TOK_' + sentI + '_' + matchI + '_match',
                     'match',
                     [[mBegin, mEnd]]]);
    }
  }

  Util.embed('tokensregex',
             {entity_types: entityTypes, relation_types: []},
             {text: currentText, entities: entities, relations: []}
  );
} // END renderTokensregex()
888
+
889
+
890
/**
 * Render a Semgrex response: one 'match' entity per match, one entity per
 * named node, and a dashed 'semgrex' relation from the match to each node.
 */
function renderSemgrex(data) {
  var entityTypesSet = {};
  var entityTypes = [];
  /** Register an entity type (a tag) for Brat, deduplicated by name. */
  function addEntityType(type, color) {
    if (entityTypesSet[type]) return;
    entityTypesSet[type] = true;
    if (typeof color === 'undefined') {
      color = '#ADF6A2';
    }
    entityTypes.push({
      type: type,
      labels : [type],
      bgColor: color,
      borderColor: 'darken'
    });
  }

  // NOTE(review): assigned without `var` in the original too — this
  // overwrites an outer/global `relationTypes`.  Confirm that is intended
  // before localizing it.
  relationTypes = [{
    type: 'semgrex',
    labels: ['-'],
    dashArray: '3,3',
    arrowHead: 'none',
  }];

  var entities = [];
  var relations = [];

  for (var sentI = 0; sentI < data.sentences.length; ++sentI) {
    var tokens = currentSentences[sentI].tokens;
    for (var matchI = 0; matchI < data.sentences[sentI].length; ++matchI) {
      var match = data.sentences[sentI][matchI];
      // Add the overall match
      addEntityType('match', '#ADF6A2');
      var begin = parseInt(tokens[match.begin].characterOffsetBegin);
      var end = parseInt(tokens[match.end - 1].characterOffsetEnd);
      entities.push(['SEM_' + sentI + '_' + matchI + '_match',
                     'match',
                     [[begin, end]]]);

      // Add named groups
      for (var groupName in match) {  // FIX: `groupName` was an implicit global
        if (groupName.startsWith("$") || isInt(groupName)) {
          // (add node)
          var group = match[groupName];  // FIX: `group` was an implicit global
          groupName = groupName.substring(1);  // strip the leading '$'
          addEntityType(groupName, '#FFFDA8');
          var gBegin = parseInt(tokens[group.begin].characterOffsetBegin);
          var gEnd = parseInt(tokens[group.end - 1].characterOffsetEnd);
          entities.push(['SEM_' + sentI + '_' + matchI + '_' + groupName,
                         groupName,
                         [[gBegin, gEnd]]]);

          // (add relation linking the node back to its match)
          relations.push(['SEMGREX_' + sentI + '_' + matchI + '_' + groupName,
                          'semgrex',
                          [['governor', 'SEM_' + sentI + '_' + matchI + '_match'],
                           ['dependent', 'SEM_' + sentI + '_' + matchI + '_' + groupName] ] ]);
        }
      }
    }
  }

  Util.embed('semgrex',
             {entity_types: entityTypes, relation_types: relationTypes},
             {text: currentText, entities: entities, relations: relations}
  );
} // END renderSemgrex
967
+
968
/**
 * Render a Tregex response as pretty-printed JSON.
 */
function renderTregex(data) {
  var container = $('#tregex');
  container.empty();
  // FIX: render via .text() so server-controlled content is HTML-escaped
  // instead of being concatenated into raw markup (XSS hardening).
  $('<pre/>').text(JSON.stringify(data, null, 4)).appendTo(container);
} // END renderTregex
975
+
976
+ // ----------------------------------------------------------------------------
977
+ // MAIN
978
+ // ----------------------------------------------------------------------------
979
+
980
+ /**
981
+ * MAIN()
982
+ *
983
+ * The entry point of the page
984
+ */
985
+ $(document).ready(function() {
986
+ // Some initial styling
987
+ $('.chosen-select').chosen();
988
+ $('.chosen-container').css('width', '100%');
989
+
990
+
991
// Language-specific changes
$('#language').on('change', function() {
  var lang = $('#language').val();
  // Right-to-left scripts get dir=rtl on the text box.
  $('#text').attr('dir', '');
  if (lang === 'ar' || lang === 'fa' || lang === 'he' || lang === 'ur') {
    $('#text').attr('dir', 'rtl');
  }
  // Per-language example placeholder text.
  var placeholders = {
    'ar': 'على سبيل المثال، قفز الثعلب البني السريع فوق الكلب الكسول.',
    'en': 'e.g., The quick brown fox jumped over the lazy dog.',
    'zh': '例如,快速的棕色狐狸跳过了懒惰的狗。',
    'zh-Hant': '例如,快速的棕色狐狸跳過了懶惰的狗。',
    'fr': 'Par exemple, le renard brun rapide a sauté sur le chien paresseux.',
    'de': 'Z. B. sprang der schnelle braune Fuchs über den faulen Hund.',
    'es': 'Por ejemplo, el rápido zorro marrón saltó sobre el perro perezoso.',
    'ur': 'میرا نام علی ہے'
  };
  var placeholder = placeholders.hasOwnProperty(lang)
      ? placeholders[lang]
      : 'Unknown language for placeholder query: ' + lang;
  $('#text').attr('placeholder', placeholder);
});
1020
+
1021
// Submit on shift-enter
$('#text').keydown(function (event) {
  // Swallow shift+enter on keydown so no newline is inserted ...
  if (event.keyCode == 13 && event.shiftKey) {
    event.preventDefault();
    return false;
  }
});
$('#text').keyup(function (event) {
  // ... and trigger the submit once the key is released.
  if (event.keyCode == 13 && event.shiftKey) {
    $('#submit').click();
    event.stopPropagation();
    return false;
  }
});
1039
+
1040
// Submit on clicking the 'submit' button
// FIX: the handler now declares `event` as a parameter; the original took
// no parameter and relied on the non-standard global `window.event` for the
// preventDefault/stopPropagation calls at the bottom (broken outside
// IE/Chrome).
$('#submit').click(function(event) {
  // Default queries per language, used when the text box is empty.
  var defaultQueries = {
    'ar': 'قفز الثعلب البني السريع فوق الكلب الكسول.',
    'en': 'The quick brown fox jumped over the lazy dog.',
    'zh': '快速的棕色狐狸跳过了懒惰的狗。',
    'zh-Hant': '快速的棕色狐狸跳過了懶惰的狗。',
    'fr': 'Le renard brun rapide a sauté sur le chien paresseux.',
    'de': 'Sprang der schnelle braune Fuchs über den faulen Hund.',
    'es': 'El rápido zorro marrón saltó sobre el perro perezoso.',
    'ur': 'میرا نام علی ہے'
  };
  // Get the text to annotate
  currentQuery = $('#text').val();
  if (currentQuery.trim() == '') {
    var lang = $('#language').val();
    currentQuery = defaultQueries.hasOwnProperty(lang)
        ? defaultQueries[lang]
        : 'Unknown language for default query: ' + lang;
    $('#text').val(currentQuery);
  }
  // Update the UI
  $('#submit').prop('disabled', true);
  $('#annotations').hide();
  $('#patterns_row').hide();
  $('#loading').show();

  // Run query
  $.ajax({
    type: 'POST',
    url: serverAddress + '?properties=' + encodeURIComponent(
        '{"annotators": "' + annotators() + '", "date": "' + date() + '"}') +
        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),
    data: encodeURIComponent(currentQuery), // jQuery doesn't automatically URI encode strings
    dataType: 'json',
    contentType: "application/x-www-form-urlencoded;charset=UTF-8",
    responseType: "application/json",
    success: function(data) {
      $('#submit').prop('disabled', false);
      if (typeof data === 'undefined' || data.sentences == undefined) {
        alert("Failed to reach server!");
      } else {
        // Process constituency parse
        var constituencyParseProcessor = new ConstituencyParseProcessor();
        constituencyParseProcessor.process(data);
        // Empty divs
        $('#annotations').empty();
        /**
         * Append an annotation div for `annotator` iff it was requested and
         * the response actually contains data under `selector` (top-level,
         * per-sentence, or on at least one token of some sentence).
         */
        function createAnnotationDiv(id, annotator, selector, label) {
          // (make sure we requested that element)
          if (annotators().split(",").indexOf(annotator) < 0) {
            return;
          }
          // (make sure the data contains that element)
          var ok = false;  // FIX: `ok` was an implicit global
          if (typeof data[selector] !== 'undefined') {
            ok = true;
          } else if (typeof data.sentences !== 'undefined' && data.sentences.length > 0) {
            if (typeof data.sentences[0][selector] !== 'undefined') {
              ok = true;
            } else if (typeof data.sentences[0].tokens != 'undefined' && data.sentences[0].tokens.length > 0) {
              // at least one token of any sentence carries the field
              ok = data.sentences.some(function(sentence) {
                return sentence.tokens.some(function(token) {
                  return typeof token[selector] !== 'undefined';
                });
              });
            }
          }
          // (render the element)
          if (ok) {
            $('#annotations').append('<h4 class="red">' + label + ':</h4> <div id="' + id + '"></div>');
          }
        }
        // (create the divs)
        //                   div id   annotator   field_in_data        label
        createAnnotationDiv('pos',   'pos',      'pos',               'Part-of-Speech (XPOS)'   );
        createAnnotationDiv('upos',  'upos',     'upos',              'Universal Part-of-Speech');
        createAnnotationDiv('lemma', 'lemma',    'lemma',             'Lemmas'                  );
        createAnnotationDiv('ner',   'ner',      'ner',               'Named Entity Recognition');
        createAnnotationDiv('deps',  'depparse', 'basicDependencies', 'Universal Dependencies'  );
        createAnnotationDiv('parse', 'parse',    'parseTree',         'Constituency Parse'      );
        // (deps2 / openie / coref / entitylink / kbp / sentiment divs remain
        // disabled, as in the original source)
        // Update UI
        $('#loading').hide();
        $('.corenlp_error').remove(); // Clear error messages
        $('#annotations').show();
        // Render (RTL languages are mirrored for display)
        var reverse = ($('#language').val() === 'ar' || $('#language').val() === 'fa' || $('#language').val() === 'he' || $('#language').val() === 'ur');
        render(data, reverse);
      }
    },
    error: function(data) {
      DATA = data;  // NOTE(review): global debug hook, kept intentionally
      var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('corenlp_error').attr('role', 'alert');
      var button = $('<button type="button" class="close" data-dismiss="alert" aria-label="Close"><span aria-hidden="true">&times;</span></button>');
      var message = $('<span/>').text(data.responseText);
      button.appendTo(alertDiv);
      message.appendTo(alertDiv);
      $('#loading').hide();
      alertDiv.appendTo($('#errors'));
      $('#submit').prop('disabled', false);
    }
  });
  event.preventDefault();
  event.stopPropagation();
  return false;
});
1161
+
1162
+
1163
// Support passing parameters on page launch, via window.location.hash parameters.
// Example: http://localhost:9000/#text=foo%20bar&annotators=pos,lemma,ner
(function() {
  var params = {};
  window.location.hash.slice(1).split("&").forEach(function(pair) {
    var kv = pair.split("=");
    if (kv.length === 2) {
      params[kv[0]] = kv[1];
    }
  });
  if (params.text) {
    $('#text').val(decodeURIComponent(params.text));
  }
  if (params.annotators) {
    // De-select everything, then select exactly the requested annotators.
    $('#annotators').find('option').each(function() {
      $(this).prop('selected', false);
    });
    params.annotators.split(",").forEach(function(a) {
      $('#annotators').find('option[value="' + a + '"]').prop('selected', true);
    });
    // Refresh the Chosen widget to reflect the new selection.
    $('#annotators').trigger('chosen:updated');
  }
  if (params.text || params.annotators) {
    // Finally, auto-submit when launched with parameters.
    $('#submit').click();
  }
})();
1198
+
1199
+
1200
$('#form_tokensregex').submit(function (e) {
  // Don't actually submit the form
  e.preventDefault();
  // Fall back to an example pattern when the search box is empty.
  if ($('#tokensregex_search').val().trim() == '') {
    $('#tokensregex_search').val('(?$foxtype [{pos:JJ}]+ ) fox');
  }
  var pattern = $('#tokensregex_search').val();
  // Remove existing annotation
  $('#tokensregex').remove();
  // Make ajax call
  $.ajax({
    type: 'POST',
    // NOTE(review): String.replace with a string pattern only escapes the
    // FIRST '&'/'+' occurrence — confirm whether that is sufficient here.
    url: serverAddress + '/tokensregex?pattern=' + encodeURIComponent(
        pattern.replace("&", "\\&").replace('+', '\\+')) +
        '&properties=' + encodeURIComponent(
        '{"annotators": "' + annotators() + '", "date": "' + date() + '"}') +
        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),
    data: encodeURIComponent(currentQuery),
    success: function(data) {
      $('.tokensregex_error').remove(); // Clear error messages
      $('<div id="tokensregex" class="pattern_brat"/>').appendTo($('#div_tokensregex'));
      renderTokensregex(data);
    },
    error: function(data) {
      var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tokensregex_error').attr('role', 'alert');
      $('<button type="button" class="close" data-dismiss="alert" aria-label="Close"><span aria-hidden="true">&times;</span></button>').appendTo(alertDiv);
      $('<span/>').text(data.responseText).appendTo(alertDiv);
      alertDiv.appendTo($('#div_tokensregex'));
    }
  });
});
1234
+
1235
+
1236
$('#form_semgrex').submit(function (e) {
  // Don't actually submit the form
  e.preventDefault();
  // Fall back to an example pattern when the search box is empty.
  if ($('#semgrex_search').val().trim() == '') {
    $('#semgrex_search').val('{pos:/VB.*/} >nsubj {}=subject >/nmod:.*/ {}=prep_phrase');
  }
  var pattern = $('#semgrex_search').val();
  // Remove existing annotation
  $('#semgrex').remove();
  // Semgrex needs dependencies, so make sure depparse is requested.
  var requiredAnnotators = annotators().split(',');
  if (requiredAnnotators.indexOf('depparse') < 0) {
    requiredAnnotators.push('depparse');
  }
  // Make ajax call
  $.ajax({
    type: 'POST',
    // NOTE(review): String.replace with a string pattern only escapes the
    // FIRST '&'/'+' occurrence — confirm whether that is sufficient here.
    url: serverAddress + '/semgrex?pattern=' + encodeURIComponent(
        pattern.replace("&", "\\&").replace('+', '\\+')) +
        '&properties=' + encodeURIComponent(
        '{"annotators": "' + requiredAnnotators.join(',') + '", "date": "' + date() + '"}') +
        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),
    data: encodeURIComponent(currentQuery),
    success: function(data) {
      $('.semgrex_error').remove(); // Clear error messages
      $('<div id="semgrex" class="pattern_brat"/>').appendTo($('#div_semgrex'));
      renderSemgrex(data);
    },
    error: function(data) {
      var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('semgrex_error').attr('role', 'alert');
      $('<button type="button" class="close" data-dismiss="alert" aria-label="Close"><span aria-hidden="true">&times;</span></button>').appendTo(alertDiv);
      $('<span/>').text(data.responseText).appendTo(alertDiv);
      alertDiv.appendTo($('#div_semgrex'));
    }
  });
});
1275
+
1276
$('#form_tregex').submit(function (e) {
  // Don't actually submit the form
  e.preventDefault();
  // Fall back to an example pattern when the search box is empty.
  if ($('#tregex_search').val().trim() == '') {
    $('#tregex_search').val('NP < NN=animal');
  }
  var pattern = $('#tregex_search').val();
  // Remove existing annotation
  $('#tregex').remove();
  // Tregex needs constituency trees, so make sure parse is requested.
  var requiredAnnotators = annotators().split(',');
  if (requiredAnnotators.indexOf('parse') < 0) {
    requiredAnnotators.push('parse');
  }
  // Make ajax call
  $.ajax({
    type: 'POST',
    // NOTE(review): String.replace with a string pattern only escapes the
    // FIRST '&'/'+' occurrence — confirm whether that is sufficient here.
    url: serverAddress + '/tregex?pattern=' + encodeURIComponent(
        pattern.replace("&", "\\&").replace('+', '\\+')) +
        '&properties=' + encodeURIComponent(
        '{"annotators": "' + requiredAnnotators.join(',') + '", "date": "' + date() + '"}') +
        '&pipelineLanguage=' + encodeURIComponent($('#language').val()),
    data: encodeURIComponent(currentQuery),
    success: function(data) {
      $('.tregex_error').remove(); // Clear error messages
      $('<div id="tregex" class="pattern_brat"/>').appendTo($('#div_tregex'));
      renderTregex(data);
    },
    error: function(data) {
      var alertDiv = $('<div/>').addClass('alert').addClass('alert-danger').addClass('alert-dismissible').addClass('tregex_error').attr('role', 'alert');
      $('<button type="button" class="close" data-dismiss="alert" aria-label="Close"><span aria-hidden="true">&times;</span></button>').appendTo(alertDiv);
      $('<span/>').text(data.responseText).appendTo(alertDiv);
      alertDiv.appendTo($('#div_tregex'));
    }
  });
});
1315
+
1316
+ });
stanza/stanza/pipeline/external/corenlp_converter_depparse.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A depparse processor which converts constituency trees using CoreNLP
3
+ """
4
+
5
+ from stanza.pipeline._constants import TOKENIZE, CONSTITUENCY, DEPPARSE
6
+ from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
7
+ from stanza.server.dependency_converter import DependencyConverter
8
+
9
@register_processor_variant(DEPPARSE, 'converter')
class ConverterDepparse(ProcessorVariant):
    """Depparse variant that derives dependencies from constituency trees
    using CoreNLP's converter.  English only.
    """

    # Processors which must run before this variant.
    REQUIRES_DEFAULT = {TOKENIZE, CONSTITUENCY}

    def __init__(self, config):
        if config['lang'] != 'en':
            raise ValueError("Constituency to dependency converter only works for English")

        # TODO: get classpath from config
        # TODO: close this when finished?
        # a more involved approach would be to turn the Pipeline into
        # a context with __enter__ and __exit__
        # __exit__ would try to free all resources, although some
        # might linger such as GPU allocations
        # maybe it isn't worth even trying to clean things up on account of that
        self.converter = DependencyConverter(classpath="$CLASSPATH")
        self.converter.open_pipe()

    def process(self, document):
        # Delegate wholesale to the CoreNLP converter pipe.
        return self.converter.process(document)
stanza/stanza/pipeline/external/jieba.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processors related to Jieba in the pipeline.
3
+ """
4
+
5
+ import re
6
+
7
+ from stanza.models.common import doc
8
+ from stanza.pipeline._constants import TOKENIZE
9
+ from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
10
+
11
def check_jieba():
    """Verify that the Jieba tokenizer package is importable.

    Returns True when the import succeeds; otherwise raises ImportError
    with installation instructions.
    """
    try:
        import jieba  # noqa: F401 -- availability probe only
    except ImportError:
        raise ImportError(
            "Jieba is used but not installed on your machine. Go to https://pypi.org/project/jieba/ for installation instructions."
        )
    return True
22
+
23
@register_processor_variant(TOKENIZE, 'jieba')
class JiebaTokenizer(ProcessorVariant):
    def __init__(self, config):
        """Build a Jieba-backed tokenizer variant.

        Jieba only does word segmentation, so sentence splitting here is
        rule-based on terminal punctuation.
        """
        if config['lang'] not in ['zh', 'zh-hans', 'zh-hant']:
            raise Exception("Jieba tokenizer is currently only allowed in Chinese (simplified or traditional) pipelines.")

        check_jieba()
        import jieba
        self.nlp = jieba
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """Segment the document with Jieba and wrap the result in a Doc."""
        text = document.text if isinstance(document, doc.Document) else document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the Jieba tokenizer.")

        sentences = []
        sentence = []
        char_pos = 0
        for piece in self.nlp.cut(text, cut_all=False):
            if re.match(r'\s+', piece):
                # Whitespace pieces only advance the character offset.
                char_pos += len(piece)
                continue

            sentence.append({
                doc.TEXT: piece,
                doc.MISC: f"{doc.START_CHAR}={char_pos}|{doc.END_CHAR}={char_pos + len(piece)}",
            })
            char_pos += len(piece)

            # End the sentence at CJK / ASCII terminal punctuation.
            if not self.no_ssplit and piece in ['。', '!', '?', '!', '?']:
                sentences.append(sentence)
                sentence = []

        if len(sentence) > 0:
            sentences.append(sentence)

        return doc.Document(sentences, text)
stanza/stanza/pipeline/external/sudachipy.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processors related to SudachiPy in the pipeline.
3
+
4
+ GitHub Home: https://github.com/WorksApplications/SudachiPy
5
+ """
6
+
7
+ import re
8
+
9
+ from stanza.models.common import doc
10
+ from stanza.pipeline._constants import TOKENIZE
11
+ from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
12
+
13
def check_sudachipy():
    """Verify that the SudachiPy tokenizer dependencies can be imported.

    Returns True on success.  Raises an ImportError carrying installation
    instructions when either of the two required packages is missing.
    """
    try:
        import sudachipy  # noqa: F401
        import sudachidict_core  # noqa: F401
    except ImportError:
        message = (
            "Both sudachipy and sudachidict_core libraries are required. "
            "Try install them with `pip install sudachipy sudachidict_core`. "
            "Go to https://github.com/WorksApplications/SudachiPy for more information."
        )
        raise ImportError(message)
    return True
27
+
28
@register_processor_variant(TOKENIZE, 'sudachipy')
class SudachiPyTokenizer(ProcessorVariant):
    """Tokenizer variant which delegates Japanese word segmentation to SudachiPy."""

    def __init__(self, config):
        """ Construct a SudachiPy-based tokenizer.

        Sentence splitting is done afterwards by matching end-of-sentence
        punctuation tokens, not by a statistical segmenter.

        Raises if the pipeline language is not Japanese.
        """
        if config['lang'] != 'ja':
            raise Exception("SudachiPy tokenizer is only allowed in Japanese pipelines.")

        # raises a descriptive ImportError if sudachipy is not installed
        check_sudachipy()
        from sudachipy import tokenizer
        from sudachipy import dictionary

        # builds a tokenizer from the default (core) dictionary
        self.tokenizer = dictionary.Dictionary().create()
        # when True, the whole document is kept as a single sentence
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """ Tokenize a document with the SudachiPy tokenizer and wrap the results into a Doc object.

        Accepts either a Stanza Document (its .text is used) or a plain string.
        Returns a new Document whose tokens carry start/end character offsets
        into the original text in their MISC field.
        """
        if isinstance(document, doc.Document):
            text = document.text
        else:
            text = document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the SudachiPy tokenizer.")

        # we use the default sudachipy tokenization mode (i.e., mode C)
        # more config needs to be added to support other modes

        tokens = self.tokenizer.tokenize(text)

        sentences = []
        current_sentence = []
        for token in tokens:
            token_text = token.surface()
            # by default sudachipy will output whitespace as a token
            # we need to skip these tokens to be consistent with other tokenizers
            if token_text.isspace():
                continue
            # SudachiPy reports character offsets directly, so no manual
            # offset accounting is needed here
            start = token.begin()
            end = token.end()

            token_entry = {
                doc.TEXT: token_text,
                doc.MISC: f"{doc.START_CHAR}={start}|{doc.END_CHAR}={end}"
            }
            current_sentence.append(token_entry)

            # end the sentence on Japanese or ASCII sentence-final punctuation
            if not self.no_ssplit and token_text in ['。', '!', '?', '!', '?']:
                sentences.append(current_sentence)
                current_sentence = []

        # flush any trailing partial sentence
        if len(current_sentence) > 0:
            sentences.append(current_sentence)

        return doc.Document(sentences, text)
stanza/stanza/utils/charlm/oscar_to_text.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Turns an Oscar 2022 jsonl file to text
3
+
4
+ YOU DO NOT NEED THIS if you use the oscar extractor which reads from
5
+ HuggingFace, dump_oscar.py
6
+
7
+ to run:
8
+ python3 -m stanza.utils.charlm.oscar_to_text <path> ...
9
+
10
+ each path can be a file or a directory with multiple .jsonl files in it
11
+ """
12
+
13
+ import argparse
14
+ import glob
15
+ import json
16
+ import lzma
17
+ import os
18
+ import sys
19
+ from stanza.models.common.utils import open_read_text
20
+
21
def extract_file(output_directory, input_filename, use_xz):
    """Convert one Oscar .jsonl file to a plain text file.

    output_directory: where to put the result; None writes next to the input
    input_filename: the .jsonl (possibly compressed) file to read
    use_xz: when True the output is xz-compressed and named .txt.xz
    """
    print("Extracting %s" % input_filename)
    source_dir, base_name = os.path.split(input_filename)
    if output_directory is None:
        output_directory = source_dir

    # replace the trailing ".jsonl..." portion of the name with ".txt"
    suffix_pos = base_name.rfind(".jsonl")
    if suffix_pos >= 0:
        base_name = base_name[:suffix_pos]
    base_name = base_name + ".txt"

    if use_xz:
        base_name += ".xz"
        open_file = lambda path: lzma.open(path, "wt", encoding="utf-8")
    else:
        open_file = lambda path: open(path, "w", encoding="utf-8")

    destination = os.path.join(output_directory, base_name)
    print("Writing content to %s" % destination)
    # one json record per input line; documents are separated by blank lines
    with open_read_text(input_filename) as fin, open_file(destination) as fout:
        for line in fin:
            record = json.loads(line)
            fout.write(record['content'])
            fout.write("\n\n")
49
+
50
def parse_args():
    """Build the command line parser for the Oscar extractor and parse sys.argv."""
    parser = argparse.ArgumentParser()
    parser.add_argument("filenames", nargs="+", help="Filenames or directories to process")
    parser.add_argument("--output", default=None, help="Output directory for saving files. If None, will write to the original directory")
    parser.add_argument("--no_xz", default=True, dest="xz", action="store_false", help="Don't use xz to compress the output files")
    return parser.parse_args()
57
+
58
def main():
    """
    Go through each of the given filenames or directories, convert json to .txt.xz
    """
    args = parse_args()
    if args.output is not None:
        os.makedirs(args.output, exist_ok=True)

    for path in args.filenames:
        if os.path.isfile(path):
            extract_file(args.output, path, args.xz)
        elif os.path.isdir(path):
            # directories are expanded to every jsonl-looking file inside
            candidates = glob.glob(os.path.join(path, "*jsonl*"))
            found = sorted(x for x in candidates if os.path.isfile(x))
            print("Found %d files:" % len(found))
            if found:
                print(" %s" % "\n ".join(found))
            for json_filename in found:
                extract_file(args.output, json_filename, args.xz)

if __name__ == "__main__":
    main()
stanza/stanza/utils/constituency/__init__.py ADDED
File without changes
stanza/stanza/utils/constituency/grep_test_logs.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Average the test F1 scores reported in a set of parser training logs."""

import subprocess
import sys

log_files = sys.argv[1:]

score_sum = 0.0
score_count = 0

for log_file in log_files:
    # ask grep for the line which reports the test F1 score
    completed = subprocess.run(["grep", "F1 score.*test.*", log_file],
                               stdout=subprocess.PIPE, encoding="utf-8")
    matched = completed.stdout.strip()
    if not matched:
        print("{}: no result".format(log_file))
        continue

    # the score is the final whitespace-separated token of the matched text
    f1 = float(matched.split()[-1])
    print("{}: {}".format(log_file, f1))
    score_sum += f1
    score_count += 1

if score_count > 0:
    print("Avg: {}".format(score_sum / score_count))
stanza/stanza/utils/datasets/constituency/build_silver_dataset.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Given two ensembles and a tokenized file, output the trees for which those ensembles agree and report how many of the sub-models agree on those trees.
3
+
4
+ For example:
5
+
6
+ python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_AA.txt --lang it --output_file asdf.out --e1 saved_models/constituency/it_vit_electra_100?_top_constituency.pt --e2 saved_models/constituency/it_vit_electra_100?_constituency.pt
7
+
8
+ for i in `echo f g h i j k l m n o p q r s t`; do nlprun -d a6000 "python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tok_6M_a$i.txt --lang it --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a$i.trees --e1 saved_models/constituency/it_vit_electra_100?_top_constituency.pt --e2 saved_models/constituency/it_vit_electra_100?_constituency.pt" -o /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a$i.out; done
9
+
10
+ for i in `echo a b c d`; do nlprun -d a6000 "python3 -m stanza.utils.datasets.constituency.build_silver_dataset --tokenized_file /u/nlp/data/constituency-parser/english/en_wiki_2023/shuf_1M.a$i --lang en --output_file /u/nlp/data/constituency-parser/english/2024_en_ptb3_electra/forward_a$i.trees --e1 saved_models/constituency/en_ptb3_electra-large_100?_in_constituency.pt --e2 saved_models/constituency/en_ptb3_electra-large_100?_top_constituency.pt" -o /u/nlp/data/constituency-parser/english/2024_en_ptb3_electra/forward_a$i.out; done
11
+ """
12
+
13
+ import argparse
14
+ import json
15
+
16
+ import logging
17
+
18
+ from stanza.models.common import utils
19
+ from stanza.models.common.foundation_cache import FoundationCache
20
+ from stanza.models.constituency import retagging
21
+ from stanza.models.constituency import text_processing
22
+ from stanza.models.constituency import tree_reader
23
+ from stanza.models.constituency.ensemble import Ensemble
24
+ from stanza.utils.get_tqdm import get_tqdm
25
+
26
+ tqdm = get_tqdm()
27
+
28
+ logger = logging.getLogger('stanza.constituency.trainer')
29
+
30
def parse_args(args=None):
    """Build and process the command line arguments for the silver dataset builder.

    args: optional list of argument strings to parse instead of sys.argv;
      previously this parameter was accepted but silently ignored.

    Returns the parsed arguments as a dict, after retagging postprocessing.
    """
    parser = argparse.ArgumentParser(description="Script that uses multiple ensembles to find trees where both ensembles agree")

    # exactly one of the two input sources must be supplied
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
    input_group.add_argument('--tree_file', type=str, default=None, help='Input file of already parsed text for reparsing with parse_text.')
    parser.add_argument('--output_file', type=str, default=None, help='Where to put the output file')

    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')

    utils.add_device_args(parser)

    parser.add_argument('--lang', default='en', help='Language to use')

    parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
    parser.add_argument('--e1', type=str, nargs='+', default=None, help="Which model(s) to load in the first ensemble")
    parser.add_argument('--e2', type=str, nargs='+', default=None, help="Which model(s) to load in the second ensemble")

    parser.add_argument('--mode', default='predict', choices=['parse_text', 'predict'])

    # another option would be to include the tree idx in each entry in an existing saved file
    # the processing could then pick up at exactly the last known idx
    parser.add_argument('--start_tree', type=int, default=0, help='Where to start... most useful if the previous incarnation crashed')
    parser.add_argument('--end_tree', type=int, default=None, help='Where to end. If unset, will process to the end of the file')

    retagging.add_retag_args(parser)

    # fixed: the `args` parameter used to be dropped on the floor, so callers
    # could not parse a programmatic argument list.  parse_args(None) still
    # reads sys.argv, so default behavior is unchanged.
    args = vars(parser.parse_args(args))

    retagging.postprocess_args(args)
    # this script never generates extra candidate parses
    args['num_generate'] = 0

    return args
65
+
66
def main():
    """Parse text with two ensembles and write out the trees they agree on.

    For each chunk of input sentences, both ensembles parse the chunk and the
    trees where the two ensembles' top predictions are identical are kept.
    Each kept tree is then re-parsed by every individual sub-model, and the
    number of sub-models which reproduce it exactly is recorded.  The output
    is one JSON object per line: {"tree": ..., "count": ...}
    """
    args = parse_args()
    utils.log_training_args(args, logger, name="ensemble")

    retag_pipeline = retagging.build_retag_pipeline(args)
    # reuse the retag pipeline's cache when available so pretrains and
    # charlms are only loaded once for both ensembles
    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()

    logger.info("Building ensemble #1 out of %s", args['e1'])
    e1 = Ensemble(args, filenames=args['e1'], foundation_cache=foundation_cache)
    e1.to(args.get('device', None))
    logger.info("Building ensemble #2 out of %s", args['e2'])
    e2 = Ensemble(args, filenames=args['e2'], foundation_cache=foundation_cache)
    e2.to(args.get('device', None))

    # exactly one of tokenized_file / tree_file is set - they form a
    # required mutually exclusive group in parse_args
    if args['tokenized_file']:
        tokenized_sentences = text_processing.read_tokenized_file(args['tokenized_file'])
    elif args['tree_file']:
        treebank = tree_reader.read_treebank(args['tree_file'])
        tokenized_sentences = [x.leaf_labels() for x in treebank]
        if args['lang'] == 'vi':
            # Vietnamese trees join multisyllabic words with _; undo that here
            tokenized_sentences = [[x.replace("_", " ") for x in sentence] for sentence in tokenized_sentences]
    logger.info("Read %d tokenized sentences", len(tokenized_sentences))

    all_models = e1.models + e2.models

    # process in fixed-size chunks so partial results are flushed as we go
    chunk_size = 1000
    with open(args['output_file'], 'w', encoding='utf-8') as fout:
        end_tree = len(tokenized_sentences) if args['end_tree'] is None else args['end_tree']
        for chunk_start in tqdm(range(args['start_tree'], end_tree, chunk_size)):
            chunk = tokenized_sentences[chunk_start:chunk_start+chunk_size]
            logger.info("Processing trees %d to %d", chunk_start, chunk_start+len(chunk))
            parsed1 = text_processing.parse_tokenized_sentences(args, e1, retag_pipeline, chunk)
            parsed1 = [x.predictions[0].tree for x in parsed1]
            parsed2 = text_processing.parse_tokenized_sentences(args, e2, retag_pipeline, chunk)
            parsed2 = [x.predictions[0].tree for x in parsed2]
            # keep only the trees where the two ensembles' best parses agree
            matching = [t for t, t2 in zip(parsed1, parsed2) if t == t2]
            logger.info("%d trees matched", len(matching))
            # count how many individual sub-models reproduce each kept tree
            model_counts = [0] * len(matching)
            for model in all_models:
                model_chunk = model.parse_sentences_no_grad(iter(matching), model.build_batch_from_trees, args['eval_batch_size'], model.predict)
                model_chunk = [x.predictions[0].tree for x in model_chunk]
                for idx, (t1, t2) in enumerate(zip(matching, model_chunk)):
                    if t1 == t2:
                        model_counts[idx] += 1
            for count, tree in zip(model_counts, matching):
                line = {"tree": "%s" % tree, "count": count}
                fout.write(json.dumps(line))
                fout.write("\n")


if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/convert_cintil.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xml.etree.ElementTree as ET
2
+
3
+ from stanza.models.constituency import tree_reader
4
+ from stanza.utils.datasets.constituency import utils
5
+
6
def read_xml_file(input_filename):
    """
    Convert the CINTIL xml file to id & text

    Returns a list of tuples: (id, text)

    Raises ValueError when the corpus node is missing, when an unexpected
    tag is encountered, or when a sentence lacks its id, raw, or tree node.
    """
    NS = "{http://nlx.di.fc.ul.pt}"
    with open(input_filename, encoding="utf-8") as fin:
        dataset = ET.parse(fin)
    dataset = dataset.getroot()
    corpus = dataset.find(NS + "corpus")
    # fixed: the check used to be `if not corpus:`, but Element truthiness
    # is False for an element with no children, so a present-but-empty
    # corpus would be misreported as missing.  find() returns None when
    # the tag really is absent.
    if corpus is None:
        raise ValueError("Unexpected dataset structure : no 'corpus'")
    trees = []
    for sentence in corpus:
        if sentence.tag != NS + "sentence":
            raise ValueError("Unexpected sentence tag: {}".format(sentence.tag))
        id_node = None
        raw_node = None
        # fixed: this was previously misspelled "tree_nodde", so a sentence
        # missing its <tree> either crashed with UnboundLocalError (first
        # sentence) or silently reused the previous sentence's tree
        tree_node = None
        for node in sentence:
            if node.tag == NS + 'id':
                id_node = node
            elif node.tag == NS + 'raw':
                raw_node = node
            elif node.tag == NS + 'tree':
                tree_node = node
            else:
                raise ValueError("Unexpected tag in sentence {}: {}".format(sentence, node.tag))
        if id_node is None or raw_node is None or tree_node is None:
            raise ValueError("Missing node in sentence {}".format(sentence))
        tree_id = "".join(id_node.itertext())
        tree_text = "".join(tree_node.itertext())
        trees.append((tree_id, tree_text))
    return trees
40
+
41
def convert_cintil_treebank(input_filename, train_size=0.8, dev_size=0.1):
    """
    Convert the CINTIL trees into train/dev/test treebanks.

    dev_size is the size for splitting train & dev
    """
    raw_trees = read_xml_file(input_filename)

    synthetic_trees = []
    natural_trees = []
    for tree_id, tree_text in raw_trees:
        if tree_text.find(" _") >= 0:
            raise ValueError("Unexpected underscore")
        # normalize the bracket text before handing it to the tree reader
        tree_text = tree_text.replace("_)", ")")
        tree_text = tree_text.replace("(A (", "(A' (")
        # trees don't have ROOT, but we typically use a ROOT label at the top
        tree_text = "(ROOT %s)" % tree_text
        parsed = tree_reader.read_trees(tree_text)
        if len(parsed) != 1:
            raise ValueError("Unexpectedly found %d trees in %s" % (len(parsed), tree_id))
        tree = parsed[0]
        # ids starting with aTSTS mark machine generated sentences
        if tree_id.startswith("aTSTS"):
            synthetic_trees.append(tree)
        elif tree_id.find("TSTS") >= 0:
            raise ValueError("Unexpected TSTS")
        else:
            natural_trees.append(tree)

    print("Read %d synthetic trees" % len(synthetic_trees))
    print("Read %d natural trees" % len(natural_trees))
    # only the natural trees are split; the synthetic trees all go to train
    train_trees, dev_trees, test_trees = utils.split_treebank(natural_trees, train_size, dev_size)
    print("Split %d trees into %d train %d dev %d test" % (len(natural_trees), len(train_trees), len(dev_trees), len(test_trees)))
    train_trees = synthetic_trees + train_trees
    print("Total lengths %d train %d dev %d test" % (len(train_trees), len(dev_trees), len(test_trees)))
    return train_trees, dev_trees, test_trees
74
+
75
+
76
def main():
    """Convert the CINTIL treebank from its default checkout location."""
    convert_cintil_treebank("extern_data/constituency/portuguese/CINTIL/CINTIL-Treebank.xml")

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/count_common_words.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Print the 100 most common leaf words in a constituency tree file.

Usage: python3 count_common_words.py treebank_file
"""

import sys

from collections import Counter

from stanza.models.constituency import parse_tree
from stanza.models.constituency import tree_reader

word_counter = Counter()
# accumulate the leaf (word) counts from each tree as it is read
count_words = lambda x: word_counter.update(x.leaf_labels())

tree_reader.read_tree_file(sys.argv[1], tree_callback=count_words)
# most_common(100) returns the same list as most_common()[:100] without
# building and slicing the fully sorted list
print(word_counter.most_common(100))
stanza/stanza/utils/datasets/constituency/prepare_con_dataset.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Converts raw data files from their original format (dataset dependent) into PTB trees.
2
+
3
+ The operation of this script depends heavily on the dataset in question.
4
+ The common result is that the data files go to data/constituency and are in PTB format.
5
+
6
+ da_arboretum
7
+ Ekhard Bick
8
+ Arboretum, a Hybrid Treebank for Danish
9
+ https://www.researchgate.net/publication/251202293_Arboretum_a_Hybrid_Treebank_for_Danish
10
+ Available here for a license fee:
11
+ http://catalog.elra.info/en-us/repository/browse/ELRA-W0084/
12
+ Internal to Stanford, please contact Chris Manning and/or John Bauer
13
+ The file processed is the tiger xml, although there are some edits
14
+ needed in order to make it functional for our parser
15
+ The treebank comes as a tar.gz file, W0084.tar.gz
16
+ untar this file in $CONSTITUENCY_BASE/danish
17
+ then move the extracted folder to "arboretum"
18
+ $CONSTITUENCY_BASE/danish/W0084/... becomes
19
+ $CONSTITUENCY_BASE/danish/arboretum/...
20
+
21
+ en_ptb3-revised is an updated version of PTB with NML and stuff
22
+ put LDC2015T13 in $CONSTITUENCY_BASE/english
23
+ the directory name may look like LDC2015T13_eng_news_txt_tbnk-ptb_revised
24
+ python3 -m stanza.utils.datasets.constituency.prepare_con_dataset en_ptb3-revised
25
+
26
+ All this needs to do is concatenate the various pieces
27
+
28
+ @article{ptb_revised,
29
+ title= {Penn Treebank Revised: English News Text Treebank LDC2015T13},
30
+ journal= {},
31
+ author= {Ann Bies and Justin Mott and Colin Warner},
32
+ year= {2015},
33
+ url= {https://doi.org/10.35111/xpjy-at91},
34
+ doi= {10.35111/xpjy-at91},
35
+ isbn= {1-58563-724-6},
36
+ dcmi= {text},
37
+ languages= {english},
38
+ language= {english},
39
+ ldc= {LDC2015T13},
40
+ }
41
+
42
+ id_icon
43
+ ICON: Building a Large-Scale Benchmark Constituency Treebank
44
+ for the Indonesian Language
45
+ Ee Suan Lim, Wei Qi Leong, Ngan Thanh Nguyen, Dea Adhista,
46
+ Wei Ming Kng, William Chandra Tjhi, Ayu Purwarianti
47
+ https://aclanthology.org/2023.tlt-1.5.pdf
48
+ Available at https://github.com/aisingapore/seacorenlp-data
49
+ git clone the repo in $CONSTITUENCY_BASE/seacorenlp
50
+ so there is now a directory
51
+ $CONSTITUENCY_BASE/seacorenlp/seacorenlp-data
52
+ python3 -m stanza.utils.datasets.constituency.prepare_con_dataset id_icon
53
+
54
+ it_turin
55
+ A combination of Evalita competition from 2011 and the ParTUT trees
56
+ More information is available in convert_it_turin
57
+
58
+ it_vit
59
+ The original for the VIT UD Dataset
60
+ The UD version has a lot of corrections, so we try to apply those as much as possible
61
+ In fact, we applied some corrections of our own back to UD based on this treebank.
62
+ The first version which had those corrections is UD 2.10
63
+ Versions of UD before that won't work
64
+ Hopefully versions after that work
65
+ Set UDBASE to a path such that $UDBASE/UD_Italian-VIT is the UD version
66
+ The constituency labels are generally not very understandable, unfortunately
67
+ Some documentation is available here:
68
+ https://core.ac.uk/download/pdf/223148096.pdf
69
+ https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.423.5538&rep=rep1&type=pdf
70
+ Available from ELRA:
71
+ http://catalog.elra.info/en-us/repository/browse/ELRA-W0040/
72
+
73
+ ja_alt
74
+ Asian Language Treebank produced a treebank for Japanese:
75
+ Ye Kyaw Thu, Win Pa Pa, Masao Utiyama, Andrew Finch, Eiichiro Sumita
76
+ Introducing the Asian Language Treebank
77
+ http://www.lrec-conf.org/proceedings/lrec2016/pdf/435_Paper.pdf
78
+ Download
79
+ https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/Japanese-ALT-20210218.zip
80
+ unzip this in $CONSTITUENCY_BASE/japanese
81
+ this should create a directory $CONSTITUENCY_BASE/japanese/Japanese-ALT-20210218
82
+ In this directory, also download the following:
83
+ https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-train.txt
84
+ https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-dev.txt
85
+ https://www2.nict.go.jp/astrec-att/member/mutiyama/ALT/URL-test.txt
86
+ In particular, there are two files with a bunch of bracketed parses,
87
+ Japanese-ALT-Draft.txt and Japanese-ALT-Reviewed.txt
88
+ The first word of each of these lines is SNT.80188.1 or something like that
89
+ This correlates with the three URL-... files, telling us whether the
90
+ sentence belongs in train/dev/test
91
+ python3 -m stanza.utils.datasets.constituency.prepare_con_dataset ja_alt
92
+
93
+ pt_cintil
94
+ CINTIL treebank for Portuguese, available at ELRA:
95
+ https://catalogue.elra.info/en-us/repository/browse/ELRA-W0055/
96
+ It can also be obtained from here:
97
+ https://hdl.handle.net/21.11129/0000-000B-D2FE-A
98
+ Produced at U Lisbon
99
+ António Branco; João Silva; Francisco Costa; Sérgio Castro
100
+ CINTIL TreeBank Handbook: Design options for the representation of syntactic constituency
101
+ Silva, João; António Branco; Sérgio Castro; Ruben Reis
102
+ Out-of-the-Box Robust Parsing of Portuguese
103
+ https://portulanclarin.net/repository/extradocs/CINTIL-Treebank.pdf
104
+ http://www.di.fc.ul.pt/~ahb/pubs/2011bBrancoSilvaCostaEtAl.pdf
105
+ If at Stanford, ask John Bauer or Chris Manning for the data
106
+ Otherwise, purchase it from ELRA or find it elsewhere if possible
107
+ Either way, unzip it in
108
+ $CONSTITUENCY_BASE/portuguese to the CINTIL directory
109
+ so for example, the final result might be
110
+ extern_data/constituency/portuguese/CINTIL/CINTIL-Treebank.xml
111
+ python3 -m stanza.utils.datasets.constituency.prepare_con_dataset pt_cintil
112
+
113
+ tr_starlang
114
+ A dataset in three parts from the Starlang group in Turkey:
115
+ Neslihan Kara, Büşra Marşan, et al
116
+ Creating A Syntactically Felicitous Constituency Treebank For Turkish
117
+ https://ieeexplore.ieee.org/document/9259873
118
+ git clone the following three repos
119
+ https://github.com/olcaytaner/TurkishAnnotatedTreeBank-15
120
+ https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-15
121
+ https://github.com/olcaytaner/TurkishAnnotatedTreeBank2-20
122
+ Put them in
123
+ $CONSTITUENCY_BASE/turkish
124
+ python3 -m stanza.utils.datasets.constituency.prepare_con_dataset tr_starlang
125
+
126
+ vlsp09 is the 2009 constituency treebank:
127
+ Nguyen Phuong Thai, Vu Xuan Luong, Nguyen Thi Minh Huyen, Nguyen Van Hiep, Le Hong Phuong
128
+ Building a Large Syntactically-Annotated Corpus of Vietnamese
129
+ Proceedings of The Third Linguistic Annotation Workshop
130
+ In conjunction with ACL-IJCNLP 2009, Suntec City, Singapore, 2009
131
+ This can be obtained by contacting vlsp.resources@gmail.com
132
+
133
+ vlsp22 is the 2022 constituency treebank from the VLSP bakeoff
134
+ there is an official test set as well
135
+ you may be able to obtain both of these by contacting vlsp.resources@gmail.com
136
+ NGUYEN Thi Minh Huyen, HA My Linh, VU Xuan Luong, PHAN Thi Hue,
137
+ LE Van Cuong, NGUYEN Thi Luong, NGO The Quyen
138
+ VLSP 2022 Challenge: Vietnamese Constituency Parsing
139
+ to appear in Journal of Computer Science and Cybernetics.
140
+
141
+ vlsp23 is the 2023 update to the constituency treebank from the VLSP bakeoff
142
+ the vlsp22 code also works for the new dataset,
143
+ although some effort may be needed to update the tags
144
+ As of late 2024, the test set is available on request at vlsp.resources@gmail.com
145
+ Organize the directory
146
+ $CONSTITUENCY_BASE/vietnamese/VLSP_2023
147
+ $CONSTITUENCY_BASE/vietnamese/VLSP_2023/Trainingset
148
+ $CONSTITUENCY_BASE/vietnamese/VLSP_2023/test
149
+
150
+ zh_ctb-51 is the 5.1 version of CTB
151
+ put LDC2005T01U01_ChineseTreebank5.1 in $CONSTITUENCY_BASE/chinese
152
+ python3 -m stanza.utils.datasets.constituency.prepare_con_dataset zh_ctb-51
153
+
154
+ @article{xue_xia_chiou_palmer_2005,
155
+ title={The Penn Chinese TreeBank: Phrase structure annotation of a large corpus},
156
+ volume={11},
157
+ DOI={10.1017/S135132490400364X},
158
+ number={2},
159
+ journal={Natural Language Engineering},
160
+ publisher={Cambridge University Press},
161
+ author={XUE, NAIWEN and XIA, FEI and CHIOU, FU-DONG and PALMER, MARTA},
162
+ year={2005},
163
+ pages={207–238}}
164
+
165
+ zh_ctb-51b is the same dataset, but using a smaller dev/test set
166
+ in our experiments, this is substantially easier
167
+
168
+ zh_ctb-90 is the 9.0 version of CTB
169
+ put LDC2016T13 in $CONSTITUENCY_BASE/chinese
170
+ python3 -m stanza.utils.datasets.constituency.prepare_con_dataset zh_ctb-90
171
+
172
+ the splits used are the ones from the file docs/ctb9.0-file-list.txt
173
+ included in the CTB 9.0 release
174
+
175
+ SPMRL adds several treebanks
176
+ https://www.spmrl.org/
177
+ https://www.spmrl.org/sancl-posters2014.html
178
+ Currently only German is converted, the German version being a
179
+ version of the Tiger Treebank
180
+ python3 -m stanza.utils.datasets.constituency.prepare_con_dataset de_spmrl
181
+
182
+ en_mctb is a multidomain test set covering five domains other than newswire
183
+ https://github.com/RingoS/multi-domain-parsing-analysis
184
+ Challenges to Open-Domain Constituency Parsing
185
+
186
+ @inproceedings{yang-etal-2022-challenges,
187
+ title = "Challenges to Open-Domain Constituency Parsing",
188
+ author = "Yang, Sen and
189
+ Cui, Leyang and
190
+ Ning, Ruoxi and
191
+ Wu, Di and
192
+ Zhang, Yue",
193
+ booktitle = "Findings of the Association for Computational Linguistics: ACL 2022",
194
+ month = may,
195
+ year = "2022",
196
+ address = "Dublin, Ireland",
197
+ publisher = "Association for Computational Linguistics",
198
+ url = "https://aclanthology.org/2022.findings-acl.11",
199
+ doi = "10.18653/v1/2022.findings-acl.11",
200
+ pages = "112--127",
201
+ }
202
+
203
+ This conversion replaces the top bracket from top -> ROOT and puts an extra S
204
+ bracket on any roots with more than one node.
205
+ """
206
+
207
+ import argparse
208
+ import os
209
+ import random
210
+ import sys
211
+ import tempfile
212
+
213
+ from tqdm import tqdm
214
+
215
+ from stanza.models.constituency import parse_tree
216
+ import stanza.utils.default_paths as default_paths
217
+ from stanza.models.constituency import tree_reader
218
+ from stanza.models.constituency.parse_tree import Tree
219
+ from stanza.server import tsurgeon
220
+ from stanza.utils.datasets.common import UnknownDatasetError
221
+ from stanza.utils.datasets.constituency import utils
222
+ from stanza.utils.datasets.constituency.convert_alt import convert_alt
223
+ from stanza.utils.datasets.constituency.convert_arboretum import convert_tiger_treebank
224
+ from stanza.utils.datasets.constituency.convert_cintil import convert_cintil_treebank
225
+ import stanza.utils.datasets.constituency.convert_ctb as convert_ctb
226
+ from stanza.utils.datasets.constituency.convert_it_turin import convert_it_turin
227
+ from stanza.utils.datasets.constituency.convert_it_vit import convert_it_vit
228
+ from stanza.utils.datasets.constituency.convert_spmrl import convert_spmrl
229
+ from stanza.utils.datasets.constituency.convert_starlang import read_starlang
230
+ from stanza.utils.datasets.constituency.utils import SHARDS, write_dataset
231
+ import stanza.utils.datasets.constituency.vtb_convert as vtb_convert
232
+ import stanza.utils.datasets.constituency.vtb_split as vtb_split
233
+
234
def process_it_turin(paths, dataset_name, *args):
    """
    Convert the it_turin dataset.

    Reads from the italian subdirectory of CONSTITUENCY_BASE and writes
    the converted files to CONSTITUENCY_DATA_DIR.
    """
    assert dataset_name == 'it_turin'
    convert_it_turin(os.path.join(paths["CONSTITUENCY_BASE"], "italian"),
                     paths["CONSTITUENCY_DATA_DIR"])
242
+
243
def process_it_vit(paths, dataset_name, *args):
    """Convert the it_vit dataset.

    Needs at least UD 2.11 or this will not work; in the meantime,
    the git version of VIT will suffice.
    """
    assert dataset_name == 'it_vit'
    convert_it_vit(paths, dataset_name)
248
+
249
def process_vlsp09(paths, dataset_name, *args):
    """
    Processes the VLSP 2009 dataset, discarding or fixing trees when needed
    """
    assert dataset_name == 'vi_vlsp09'
    raw_dir = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese",
                           "VietTreebank_VLSP_SP73", "Kho ngu lieu 10000 cay cu phap")
    # convert into a scratch directory, then split into train/dev/test
    with tempfile.TemporaryDirectory() as scratch_dir:
        vtb_convert.convert_dir(raw_dir, scratch_dir)
        vtb_split.split_files(scratch_dir, paths["CONSTITUENCY_DATA_DIR"], dataset_name)
258
+
259
def process_vlsp21(paths, dataset_name, *args):
    """
    Processes the VLSP 2021 dataset, which is just a single file
    """
    assert dataset_name == 'vi_vlsp21'
    vlsp_file = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese", "VLSP_2021", "VTB_VLSP21_tree.txt")
    if not os.path.exists(vlsp_file):
        raise FileNotFoundError("Could not find the 2021 dataset in the expected location of {} - CONSTITUENCY_BASE == {}".format(vlsp_file, paths["CONSTITUENCY_BASE"]))

    with tempfile.TemporaryDirectory() as scratch_dir:
        vtb_convert.convert_files([vlsp_file], scratch_dir)
        # This produces a 0 length test set, just as a placeholder until the actual test set is released
        vtb_split.split_files(scratch_dir, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=0.9, dev_size=0.1)

    _, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], dataset_name)
    # create an empty test file - currently we don't have actual test data for VLSP 21
    with open(test_file, "w"):
        pass
275
+
276
def process_vlsp22(paths, dataset_name, *args):
    """
    Processes the VLSP 2022 dataset, which is four separate files for some reason

    Also handles vi_vlsp23, which shares the same layout but uses an
    updated tagset and keeps its test files in a sibling "test" directory.
    """
    assert dataset_name == 'vi_vlsp22' or dataset_name == 'vi_vlsp23'

    # per-dataset defaults: where the data lives and which tagset to expect
    if dataset_name == 'vi_vlsp22':
        default_subdir = 'VLSP_2022'
        default_make_test_split = False
        updated_tagset = False
    elif dataset_name == 'vi_vlsp23':
        default_subdir = os.path.join('VLSP_2023', 'Trainingdataset')
        default_make_test_split = False
        updated_tagset = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--subdir', default=default_subdir, type=str, help='Where to find the data - allows for using previous versions, if needed')
    parser.add_argument('--no_convert_brackets', default=True, action='store_false', dest='convert_brackets', help="Don't convert the VLSP parens RKBT & LKBT to PTB parens")
    parser.add_argument('--n_splits', default=None, type=int, help='Split the data into this many pieces. Relevant as there is no set training/dev split, so this allows for N models on N different dev sets')
    parser.add_argument('--test_split', default=default_make_test_split, action='store_true', help='Split 1/10th of the data as a test split as well. Useful for experimental results. Less relevant since there is now an official test set')
    parser.add_argument('--no_test_split', dest='test_split', action='store_false', help='Split 1/10th of the data as a test split as well. Useful for experimental results. Less relevant since there is now an official test set')
    parser.add_argument('--seed', default=1234, type=int, help='Random seed to use when splitting')
    # *args arrives as a single tuple wrapping the remaining argv list
    args = parser.parse_args(args=list(*args))

    # --subdir may be a full path, or a directory name under CONSTITUENCY_BASE
    if os.path.exists(args.subdir):
        vlsp_dir = args.subdir
    else:
        vlsp_dir = os.path.join(paths["CONSTITUENCY_BASE"], "vietnamese", args.subdir)
        if not os.path.exists(vlsp_dir):
            raise FileNotFoundError("Could not find the {} dataset in the expected location of {} - CONSTITUENCY_BASE == {}".format(dataset_name, vlsp_dir, paths["CONSTITUENCY_BASE"]))
    # training files are the ones named "file..." (skipping any zips)
    vlsp_files = os.listdir(vlsp_dir)
    vlsp_train_files = [os.path.join(vlsp_dir, x) for x in vlsp_files if x.startswith("file") and not x.endswith(".zip")]
    vlsp_train_files.sort()

    # the test files live in different places for the two releases
    if dataset_name == 'vi_vlsp22':
        vlsp_test_files = [os.path.join(vlsp_dir, x) for x in vlsp_files if x.startswith("private") and not x.endswith(".zip")]
    elif dataset_name == 'vi_vlsp23':
        vlsp_test_dir = os.path.abspath(os.path.join(vlsp_dir, os.pardir, "test"))
        vlsp_test_files = os.listdir(vlsp_test_dir)
        vlsp_test_files = [os.path.join(vlsp_test_dir, x) for x in vlsp_test_files if x.endswith(".csv")]

    if len(vlsp_train_files) == 0:
        raise FileNotFoundError("No train files (files starting with 'file') found in {}".format(vlsp_dir))
    if not args.test_split and len(vlsp_test_files) == 0:
        raise FileNotFoundError("No test files found in {}".format(vlsp_dir))
    print("Loading training files from {}".format(vlsp_dir))
    print("Procesing training files:\n {}".format("\n ".join(vlsp_train_files)))
    with tempfile.TemporaryDirectory() as train_output_path:
        vtb_convert.convert_files(vlsp_train_files, train_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset)
        # This produces a 0 length test set, just as a placeholder until the actual test set is released
        if args.n_splits:
            # N rotations: each rotation uses a different 1/N slice as its dev set
            test_size = 0.1 if args.test_split else 0.0
            dev_size = (1.0 - test_size) / args.n_splits
            train_size = 1.0 - test_size - dev_size
            for rotation in range(args.n_splits):
                # there is a shuffle inside the split routine,
                # so we need to reset the random seed each time
                random.seed(args.seed)
                rotation_name = "%s-%d-%d" % (dataset_name, rotation, args.n_splits)
                if args.test_split:
                    rotation_name = rotation_name + "t"
                vtb_split.split_files(train_output_path, paths["CONSTITUENCY_DATA_DIR"], rotation_name, train_size=train_size, dev_size=dev_size, rotation=(rotation, args.n_splits))
        else:
            test_size = 0.1 if args.test_split else 0.0
            dev_size = 0.1
            train_size = 1.0 - test_size - dev_size
            if args.test_split:
                # "t" suffix marks a dataset carved with a held-out test slice
                dataset_name = dataset_name + "t"
            vtb_split.split_files(train_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=train_size, dev_size=dev_size)

    if not args.test_split:
        # convert the official test files; train_size=0, dev_size=0 routes
        # everything into the test shard
        print("Procesing test files:\n {}".format("\n ".join(vlsp_test_files)))
        with tempfile.TemporaryDirectory() as test_output_path:
            vtb_convert.convert_files(vlsp_test_files, test_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset)
            if args.n_splits:
                for rotation in range(args.n_splits):
                    rotation_name = "%s-%d-%d" % (dataset_name, rotation, args.n_splits)
                    vtb_split.split_files(test_output_path, paths["CONSTITUENCY_DATA_DIR"], rotation_name, train_size=0, dev_size=0)
            else:
                vtb_split.split_files(test_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=0, dev_size=0)
    if not args.test_split and not args.n_splits and dataset_name == 'vi_vlsp23':
        # extra copy of the 2023 test set which keeps the sentence ids
        print("Procesing test files and keeping ids:\n {}".format("\n ".join(vlsp_test_files)))
        with tempfile.TemporaryDirectory() as test_output_path:
            vtb_convert.convert_files(vlsp_test_files, test_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset, write_ids=True)
            vtb_split.split_files(test_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name + "-ids", train_size=0, dev_size=0)
361
+
362
def process_arboretum(paths, dataset_name, *args):
    """
    Processes the Danish dataset, Arboretum
    """
    assert dataset_name == 'da_arboretum'

    tiger_file = os.path.join(paths["CONSTITUENCY_BASE"], "danish", "arboretum", "arboretum.tiger", "arboretum.tiger")
    if not os.path.exists(tiger_file):
        raise FileNotFoundError("Unable to find input file for Arboretum. Expected in {}".format(tiger_file))

    treebank = convert_tiger_treebank(tiger_file)
    splits = utils.split_treebank(treebank, 0.8, 0.1)
    output_dir = paths["CONSTITUENCY_DATA_DIR"]

    # keep a copy of the whole treebank alongside the train/dev/test splits
    full_output = os.path.join(output_dir, "%s.mrg" % dataset_name)
    print("Writing {} trees to {}".format(len(treebank), full_output))
    parse_tree.Tree.write_treebank(treebank, full_output)

    write_dataset(splits, output_dir, dataset_name)
381
+
382
+
383
def process_starlang(paths, dataset_name, *args):
    """
    Convert the Turkish Starlang dataset to brackets
    """
    assert dataset_name == 'tr_starlang'

    # the three pieces the treebank is distributed in
    pieces = ["TurkishAnnotatedTreeBank-15",
              "TurkishAnnotatedTreeBank2-15",
              "TurkishAnnotatedTreeBank2-20"]
    chunk_paths = [os.path.join(paths["CONSTITUENCY_BASE"], "turkish", piece) for piece in pieces]

    write_dataset(read_starlang(chunk_paths), paths["CONSTITUENCY_DATA_DIR"], dataset_name)
398
+
399
def process_ja_alt(paths, dataset_name, *args):
    """
    Convert and split the ALT dataset

    TODO: could theoretically extend this to MY or any other similar dataset from ALT
    """
    lang, source = dataset_name.split("_", 1)
    assert lang == 'ja'
    assert source == 'alt'

    alt_dir = os.path.join(paths["CONSTITUENCY_BASE"], "japanese", "Japanese-ALT-20210218")
    input_files = [os.path.join(alt_dir, piece)
                   for piece in ("Japanese-ALT-Draft.txt", "Japanese-ALT-Reviewed.txt")]
    # the URL-* files define which sentences go to which shard
    split_files = [os.path.join(alt_dir, "URL-%s.txt" % shard) for shard in SHARDS]
    output_files = [os.path.join(paths["CONSTITUENCY_DATA_DIR"], "%s_%s.mrg" % (dataset_name, shard))
                    for shard in SHARDS]
    convert_alt(input_files, split_files, output_files)
416
+
417
def process_pt_cintil(paths, dataset_name, *args):
    """
    Convert and split the PT Cintil dataset
    """
    lang, source = dataset_name.split("_", 1)
    assert lang == 'pt'
    assert source == 'cintil'

    xml_file = os.path.join(paths["CONSTITUENCY_BASE"], "portuguese", "CINTIL", "CINTIL-Treebank.xml")
    datasets = convert_cintil_treebank(xml_file)
    write_dataset(datasets, paths["CONSTITUENCY_DATA_DIR"], dataset_name)
430
+
431
def process_id_icon(paths, dataset_name, *args):
    """Convert the Indonesian ICON constituency data, which is already split into shards."""
    lang, source = dataset_name.split("_", 1)
    assert lang == 'id'
    assert source == 'icon'

    icon_dir = os.path.join(paths["CONSTITUENCY_BASE"], "seacorenlp", "seacorenlp-data", "id", "constituency")
    datasets = []
    for shard_file in ("train.txt", "dev.txt", "test.txt"):
        shard_trees = tree_reader.read_tree_file(os.path.join(icon_dir, shard_file))
        # wrap each tree in a ROOT node
        datasets.append([Tree("ROOT", tree) for tree in shard_trees])

    write_dataset(datasets, paths["CONSTITUENCY_DATA_DIR"], dataset_name)
446
+
447
def process_ctb_51(paths, dataset_name, *args):
    """Convert the Chinese Treebank 5.1 dataset.

    Reads the bracketed trees from the LDC2005T01U01 release and writes
    the converted splits to CONSTITUENCY_DATA_DIR.
    """
    lang, source = dataset_name.split("_", 1)
    assert lang == 'zh-hans'
    assert source == 'ctb-51'

    input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese", "LDC2005T01U01_ChineseTreebank5.1", "bracketed")
    output_dir = paths["CONSTITUENCY_DATA_DIR"]
    # fail early with a clear message, consistent with process_ctb_51b
    if not os.path.exists(input_dir):
        raise FileNotFoundError("CTB 5.1 location not found: %s" % input_dir)
    convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V51)
455
+
456
def process_ctb_51b(paths, dataset_name, *args):
    """Convert the Chinese Treebank 5.1 dataset using convert_ctb.Version.V51b."""
    lang, source = dataset_name.split("_", 1)
    assert lang == 'zh-hans'
    assert source == 'ctb-51b'

    input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese", "LDC2005T01U01_ChineseTreebank5.1", "bracketed")
    if not os.path.exists(input_dir):
        raise FileNotFoundError("CTB 5.1 location not found: %s" % input_dir)
    print("Loading trees from %s" % input_dir)
    convert_ctb.convert_ctb(input_dir, paths["CONSTITUENCY_DATA_DIR"], dataset_name, convert_ctb.Version.V51b)
467
+
468
def process_ctb_90(paths, dataset_name, *args):
    """Convert the Chinese Treebank 9.0 dataset.

    Reads the bracketed trees from the LDC2016T13 release and writes
    the converted splits to CONSTITUENCY_DATA_DIR.
    """
    lang, source = dataset_name.split("_", 1)
    assert lang == 'zh-hans'
    assert source == 'ctb-90'

    input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "chinese", "LDC2016T13", "ctb9.0", "data", "bracketed")
    output_dir = paths["CONSTITUENCY_DATA_DIR"]
    # fail early with a clear message, consistent with process_ctb_51b
    if not os.path.exists(input_dir):
        raise FileNotFoundError("CTB 9.0 location not found: %s" % input_dir)
    convert_ctb.convert_ctb(input_dir, output_dir, dataset_name, convert_ctb.Version.V90)
476
+
477
+
478
def process_ptb3_revised(paths, dataset_name, *args):
    """
    Convert the LDC2015T13 revised PTB release.

    Sections 02-21 become train, 22 dev, 23 test.
    """
    input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "english", "LDC2015T13_eng_news_txt_tbnk-ptb_revised")
    if not os.path.exists(input_dir):
        # fall back to the bare LDC catalog number as the directory name
        backup_input_dir = os.path.join(paths["CONSTITUENCY_BASE"], "english", "LDC2015T13")
        if not os.path.exists(backup_input_dir):
            raise FileNotFoundError("Could not find ptb3-revised in either %s or %s" % (input_dir, backup_input_dir))
        input_dir = backup_input_dir

    bracket_dir = os.path.join(input_dir, "data", "penntree")
    output_dir = paths["CONSTITUENCY_DATA_DIR"]

    # compensate for a weird mislabeling in the original dataset
    label_map = {"ADJ-PRD": "ADJP-PRD"}

    # sections 02..21 are the training data
    train_trees = []
    for i in tqdm(range(2, 22)):
        new_trees = tree_reader.read_directory(os.path.join(bracket_dir, "%02d" % i))
        new_trees = [t.remap_constituent_labels(label_map) for t in new_trees]
        train_trees.extend(new_trees)

    # tsurgeon pattern which relocates a sentence-final punctuation node
    # that is a direct child of the root into the root's first child
    move_tregex = "_ROOT_ <1 __=home <2 /^[.]$/=move"
    move_tsurgeon = "move move >-1 home"

    print("Moving sentence final punctuation if necessary")
    with tsurgeon.Tsurgeon() as tsurgeon_processor:
        train_trees = [tsurgeon_processor.process(tree, move_tregex, move_tsurgeon)[0] for tree in tqdm(train_trees)]

    # section 22 is dev
    dev_trees = tree_reader.read_directory(os.path.join(bracket_dir, "22"))
    dev_trees = [t.remap_constituent_labels(label_map) for t in dev_trees]

    # section 23 is test
    test_trees = tree_reader.read_directory(os.path.join(bracket_dir, "23"))
    test_trees = [t.remap_constituent_labels(label_map) for t in test_trees]
    print("Read %d train trees, %d dev trees, and %d test trees" % (len(train_trees), len(dev_trees), len(test_trees)))
    datasets = [train_trees, dev_trees, test_trees]
    write_dataset(datasets, output_dir, dataset_name)
513
+
514
def process_en_mctb(paths, dataset_name, *args):
    """
    Converts the following blocks:

    dialogue.cleaned.txt forum.cleaned.txt law.cleaned.txt literature.cleaned.txt review.cleaned.txt
    """
    base_path = os.path.join(paths["CONSTITUENCY_BASE"], "english", "multi-domain-parsing-analysis", "data", "MCTB_en")
    if not os.path.exists(base_path):
        raise FileNotFoundError("Please download multi-domain-parsing-analysis to %s" % base_path)

    def tree_callback(tree):
        # put a ROOT node on top; trees with multiple children
        # get an intermediate S layer first
        if len(tree.children) > 1:
            return parse_tree.Tree("ROOT", [parse_tree.Tree("S", tree.children)])
        return parse_tree.Tree("ROOT", tree.children)

    for domain_file in ("dialogue.cleaned.txt", "forum.cleaned.txt", "law.cleaned.txt", "literature.cleaned.txt", "review.cleaned.txt"):
        trees = tree_reader.read_tree_file(os.path.join(base_path, domain_file), tree_callback=tree_callback)
        print("%d trees in %s" % (len(trees), domain_file))
        # each domain becomes its own test file
        out_name = "%s-%s_test.mrg" % (dataset_name, domain_file.split(".")[0])
        out_name = os.path.join(paths["CONSTITUENCY_DATA_DIR"], out_name)
        print("Writing trees to %s" % out_name)
        parse_tree.Tree.write_treebank(trees, out_name)
537
+
538
def process_spmrl(paths, dataset_name, *args):
    """Convert an SPMRL dataset - currently only the German treebank is handled."""
    if dataset_name != 'de_spmrl':
        raise ValueError("SPMRL dataset %s currently not supported" % dataset_name)

    spmrl_dir = os.path.join(paths["CONSTITUENCY_BASE"], "spmrl", "SPMRL_SHARED_2014", "GERMAN_SPMRL", "gold", "ptb")
    convert_spmrl(spmrl_dir, paths["CONSTITUENCY_DATA_DIR"], dataset_name)
546
+
547
# Dispatch table from dataset name (<language>_<source>) to the function
# which knows how to convert it.  main() looks handlers up here.
DATASET_MAPPING = {
    'da_arboretum': process_arboretum,

    'de_spmrl': process_spmrl,

    'en_ptb3-revised': process_ptb3_revised,
    'en_mctb': process_en_mctb,

    'id_icon': process_id_icon,

    'it_turin': process_it_turin,
    'it_vit': process_it_vit,

    'ja_alt': process_ja_alt,

    'pt_cintil': process_pt_cintil,

    'tr_starlang': process_starlang,

    'vi_vlsp09': process_vlsp09,
    'vi_vlsp21': process_vlsp21,
    'vi_vlsp22': process_vlsp22,
    'vi_vlsp23': process_vlsp22, # options allow for this

    'zh-hans_ctb-51': process_ctb_51,
    'zh-hans_ctb-51b': process_ctb_51b,
    'zh-hans_ctb-90': process_ctb_90,
}
575
+
576
def main(dataset_name, *args):
    """Convert a single constituency dataset, dispatching on its name."""
    paths = default_paths.get_default_paths()

    random.seed(1234)

    handler = DATASET_MAPPING.get(dataset_name)
    if handler is None:
        raise UnknownDatasetError(dataset_name, f"dataset {dataset_name} currently not handled by prepare_con_dataset")
    handler(paths, dataset_name, *args)
585
+
586
if __name__ == '__main__':
    # with no arguments, just list which datasets this script can prepare
    if len(sys.argv) == 1:
        print("Known datasets:")
        for key in DATASET_MAPPING:
            print(" %s" % key)
    else:
        # first argument picks the dataset; the rest are passed through
        # to its handler as a single list
        main(sys.argv[1], sys.argv[2:])
593
+
594
+
stanza/stanza/utils/datasets/constituency/silver_variance.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Use the concepts in "Dataset Cartography" and "Mind Your Outliers" to find trees with the least variance over a training run
3
+
4
+ https://arxiv.org/pdf/2009.10795.pdf
5
+ https://arxiv.org/abs/2107.02331
6
+
7
+ The idea here is that high variance trees are more likely to be wrong in the first place. Using this will filter a silver dataset to have better trees.
8
+
9
+ for example:
10
+
11
+ nlprun -d a6000 -p high "export CLASSPATH=/sailhome/horatio/CoreNLP/classes:/sailhome/horatio/CoreNLP/lib/*:$CLASSPATH; python3 stanza/utils/datasets/constituency/silver_variance.py --eval_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_0.mrg saved_models/constituency/it_vit.top.each.silver0.constituency_0*0.pt --output_file filtered_silver0.mrg" -o filter.out
12
+ """
13
+
14
+ import argparse
15
+
16
+ import logging
17
+
18
+ import numpy
19
+
20
+ from stanza.models.common import utils
21
+ from stanza.models.common.foundation_cache import FoundationCache
22
+ from stanza.models.constituency import retagging
23
+ from stanza.models.constituency import tree_reader
24
+ from stanza.models.constituency.parser_training import run_dev_set
25
+ from stanza.models.constituency.trainer import Trainer
26
+ from stanza.models.constituency.utils import retag_trees
27
+ from stanza.server.parser_eval import EvaluateParser
28
+ from stanza.utils.get_tqdm import get_tqdm
29
+
30
+ tqdm = get_tqdm()
31
+
32
+ logger = logging.getLogger('stanza.constituency.trainer')
33
+
34
def parse_args(args=None):
    """Build and parse the command line options for the variance filter.

    Returns the options as a dict (via vars()), after retagging
    postprocessing has been applied.
    """
    parser = argparse.ArgumentParser(description="Script to filter trees by how much variance they show over multiple checkpoints of a parser training run.")

    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--output_file', type=str, default=None, help='Output file after sorting by variance.')

    # model resources - exact paths rather than auto-discovered
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')

    utils.add_device_args(parser)

    # TODO: use the training scripts to pick the charlm & pretrain if needed
    parser.add_argument('--lang', default='it', help='Language to use')

    parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
    parser.add_argument('models', type=str, nargs='+', default=None, help="Which model(s) to load")

    parser.add_argument('--keep', type=float, default=0.5, help="How many trees to keep after sorting by variance")
    parser.add_argument('--reverse', default=False, action='store_true', help='Actually, keep the high variance trees')

    retagging.add_retag_args(parser)

    args = vars(parser.parse_args())

    retagging.postprocess_args(args)

    return args
62
+
63
def main():
    """Score each tree with every model checkpoint, then write out the
    trees whose per-tree F1 varies the least across checkpoints
    (or the most, with --reverse)."""
    args = parse_args()
    retag_pipeline = retagging.build_retag_pipeline(args)
    # reuse the retag pipeline's cache, if any, so shared resources are loaded once
    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()

    print("Analyzing with the following models:\n " + "\n ".join(args['models']))

    treebank = tree_reader.read_treebank(args['eval_file'])
    logger.info("Read %d trees for analysis", len(treebank))

    # one row per model: list of per-tree F1 scores
    f1_history = []
    retagged_treebank = None

    # evaluate in fixed-size chunks rather than the whole treebank at once
    chunk_size = 5000
    with EvaluateParser() as evaluator:
        for model_filename in args['models']:
            print("Starting processing with %s" % model_filename)
            trainer = Trainer.load(model_filename, args=args, foundation_cache=foundation_cache)
            # retag only once, using the settings stored in the first model loaded
            if retag_pipeline is not None and retagged_treebank is None:
                retag_method = trainer.model.args['retag_method']
                retag_xpos = trainer.model.args['retag_xpos']
                logger.info("Retagging trees using the %s tags from the %s package...", retag_method, args['retag_package'])
                retagged_treebank = retag_trees(treebank, retag_pipeline, retag_xpos)
                logger.info("Retagging finished")

            current_history = []
            for chunk_start in range(0, len(treebank), chunk_size):
                chunk = treebank[chunk_start:chunk_start+chunk_size]
                retagged_chunk = retagged_treebank[chunk_start:chunk_start+chunk_size] if retagged_treebank else None
                f1, kbestF1, treeF1 = run_dev_set(trainer.model, retagged_chunk, chunk, args, evaluator)
                current_history.extend(treeF1)

            f1_history.append(current_history)

    # rows = models, columns = trees; variance of each tree's F1 across models
    f1_history = numpy.array(f1_history)
    f1_variance = numpy.var(f1_history, axis=0)
    # ascending by default, so the lowest-variance trees come first
    f1_sorted = sorted([(x, idx) for idx, x in enumerate(f1_variance)], reverse=args['reverse'])

    num_keep = int(len(f1_sorted) * args['keep'])
    with open(args['output_file'], "w", encoding="utf-8") as fout:
        for _, idx in f1_sorted[:num_keep]:
            fout.write(str(treebank[idx]))
            fout.write("\n")

if __name__ == "__main__":
    main()
stanza/stanza/utils/datasets/coref/convert_hindi.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from operator import itemgetter
4
+ import os
5
+
6
+ import stanza
7
+
8
+ from stanza.utils.default_paths import get_default_paths
9
+ from stanza.utils.get_tqdm import get_tqdm
10
+ from stanza.utils.datasets.coref.utils import process_document
11
+
12
+ tqdm = get_tqdm()
13
+
14
def flatten_spans(coref_spans):
    """Tag each span with its cluster index and flatten into one sorted list.

    Input: a list of clusters, each a list of [start, end] word spans, e.g.
      [[[38, 39], [42, 43], ...], [[60, 68], ...]]
    Output: a flat list of [cluster_idx, start, end] triples ordered by the
    start word index, e.g.
      [[0, 38, 39], [0, 42, 43], ..., [1, 60, 68], ...]
    """
    tagged = [[cluster_idx, start, end]
              for cluster_idx, cluster in enumerate(coref_spans)
              for start, end in cluster]
    tagged.sort(key=itemgetter(1))
    return tagged
32
+
33
def remove_nulls(coref_spans, sentences):
    """Drop the "" and "NULL" words from the sentences, reindexing the spans.

    Every word index inside the spans is shifted down by the number of
    removed words that precede it; a removed word itself maps to the index
    of the next kept word.

    So, we might get something like
    [[0, 2], [31, 33], [134, 136], [161, 162]]
    ->
    [[0, 2], [30, 32], [129, 131], [155, 156]]
    """
    # old_to_new[i] = new index for old global word index i
    old_to_new = []
    kept_count = 0
    new_sentences = []
    for sentence in sentences:
        kept_words = []
        for word in sentence:
            old_to_new.append(kept_count)
            if word not in ('', 'NULL'):
                kept_words.append(word)
                kept_count += 1
        new_sentences.append(kept_words)

    new_spans = [[[old_to_new[idx] for idx in span] for span in mention]
                 for mention in coref_spans]
    return new_spans, new_sentences
65
+
66
def arrange_spans_by_sentence(coref_spans, sentences):
    """Regroup a flat, start-sorted list of [cluster, start, end] spans into
    one list per sentence, converting word indices to be sentence-local.

    Assumes coref_spans is sorted by start index (see flatten_spans).
    """
    per_sentence = []
    sentence_start = 0
    next_span = 0
    for sentence in sentences:
        sentence_end = sentence_start + len(sentence)
        spans_here = []
        # consume every span which starts inside this sentence
        while next_span < len(coref_spans) and coref_spans[next_span][1] < sentence_end:
            cluster, start, end = coref_spans[next_span]
            spans_here.append([cluster, start - sentence_start, end - sentence_start])
            next_span += 1
        per_sentence.append(spans_here)
        sentence_start = sentence_end
    return per_sentence
81
+
82
def convert_dataset_section(pipe, section, use_cconj_heads):
    """
    Reprocess the original data into a format compatible with previous conversion utilities

    - remove blank and NULL words
    - rearrange the spans into spans per sentence instead of a list of indices for each span
    - process the document using a Hindi pipeline
    """
    converted = []
    for doc in tqdm(section):
        spans, sentences = remove_nulls(doc['clusters'], doc['sentences'])
        spans = arrange_spans_by_sentence(flatten_spans(spans), sentences)
        # this dataset has no part ids, so pass an empty one
        converted.append(process_document(pipe, doc['doc_key'], "", sentences, spans,
                                          doc['speakers'], use_cconj_heads=use_cconj_heads))
    return converted
106
+
107
def remove_nulls_dataset_section(section):
    """Strip the blank/"NULL" words from each document, updating its
    sentences and clusters in place, and return the updated documents."""
    updated = []
    for doc in section:
        clusters, sentences = remove_nulls(doc['clusters'], doc['sentences'])
        doc['sentences'] = sentences
        doc['clusters'] = clusters
        updated.append(doc)
    return updated
117
+
118
+
119
def read_json_file(filename):
    """Read a jsonlines file: one JSON document per non-blank line."""
    docs = []
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                continue
            docs.append(json.loads(line))
    return docs
128
+
129
def write_json_file(output_filename, converted_section):
    """Write the whole section to output_filename as indented JSON."""
    text = json.dumps(converted_section, indent=2)
    with open(output_filename, "w", encoding="utf-8") as fout:
        fout.write(text)
132
+
133
def main():
    """Convert the Hindi coref data.

    By default, runs the full conversion using a Hindi stanza pipeline.
    With --remove_nulls, only rewrites the input files with blank and
    NULL tokens removed.
    """
    parser = argparse.ArgumentParser(
        prog='Convert Hindi Coref Data',
    )
    parser.add_argument('--no_use_cconj_heads', dest='use_cconj_heads', action='store_false', help="Don't use the conjunction-aware transformation")
    parser.add_argument('--remove_nulls', action='store_true', help="The only action is to remove the NULLs and blank tokens")
    args = parser.parse_args()

    paths = get_default_paths()
    coref_input_path = paths["COREF_BASE"]
    hindi_base_path = os.path.join(coref_input_path, "hindi", "dataset")

    sections = ("train", "dev", "test")
    if args.remove_nulls:
        # write the cleaned docs next to the originals, one JSON doc per line
        for section in sections:
            input_filename = os.path.join(hindi_base_path, "%s.hindi.jsonlines" % section)
            dataset = read_json_file(input_filename)
            dataset = remove_nulls_dataset_section(dataset)
            output_filename = os.path.join(hindi_base_path, "hi_deeph.%s.nonulls.json" % section)
            with open(output_filename, "w", encoding="utf-8") as fout:
                for doc in dataset:
                    json.dump(doc, fout, ensure_ascii=False)
                    fout.write("\n")
    else:
        # full conversion, reprocessing each doc with a Hindi pipeline
        pipe = stanza.Pipeline("hi", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True)

        os.makedirs(paths["COREF_DATA_DIR"], exist_ok=True)

        for section in sections:
            input_filename = os.path.join(hindi_base_path, "%s.hindi.jsonlines" % section)
            dataset = read_json_file(input_filename)

            output_filename = os.path.join(paths["COREF_DATA_DIR"], "hi_deeph.%s.json" % section)
            converted_section = convert_dataset_section(pipe, dataset, use_cconj_heads=args.use_cconj_heads)
            write_json_file(output_filename, converted_section)

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/ner/compare_entities.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Report the fraction of NER entities in one file which are present in another.
3
+
4
+ Purpose: show the coverage of one file on another, such as reporting
5
+ the number of entities in one dataset on another
6
+ """
7
+
8
+
9
+ import argparse
10
+
11
+ from stanza.utils.datasets.ner.utils import read_json_entities
12
+
13
def parse_args():
    """Parse the command line: lists of "train" and "test" NER files."""
    parser = argparse.ArgumentParser(description="Report the coverage of one NER file on another.")
    parser.add_argument('--train', type=str, nargs="+", required=True, help='File to use to collect the known entities (not necessarily train).')
    parser.add_argument('--test', type=str, nargs="+", required=True, help='File for which we want to know the ratio of known entities')
    return parser.parse_args()
19
+
20
def report_known_entities(train_file, test_file):
    """Print the fraction of test_file entities whose text occurs in train_file."""
    known = {entity[0] for entity in read_json_entities(train_file)}
    test_entities = read_json_entities(test_file)
    covered = sum(1 for entity in test_entities if entity[0] in known)
    print(train_file, test_file, covered / len(test_entities))
27
+
28
def main():
    """Report coverage for every (train, test) file combination."""
    args = parse_args()

    for train_idx, train_file in enumerate(args.train):
        # blank line between the result groups for different train files
        if train_idx:
            print()
        for test_file in args.test:
            report_known_entities(train_file, test_file)

if __name__ == '__main__':
    main()
+ main()
stanza/stanza/utils/datasets/ner/conll_to_iob.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Process a conll file into BIO
3
+
4
+ Includes the ability to process a file from a text file
5
+ or a text file within a zip
6
+
7
+ Main program extracts a piece of the zip file from the Danish DDT dataset
8
+ """
9
+
10
+ import io
11
+ import zipfile
12
+ from zipfile import ZipFile
13
+ from stanza.utils.conll import CoNLL
14
+
15
def process_conll(input_file, output_file, zip_file=None, conversion=None, attr_prefix="name", allow_empty=False):
    """Convert a conllu file into two-column word/tag BIO format.

    input_file: which conllu file (or piece of the zip) to read
    output_file: where to write the result
    zip_file: optional zip archive which contains input_file
    conversion: None, a dict remapping entity labels, or a callable applied to the whole tag
    attr_prefix: which attribute to get from the misc field
    allow_empty: if True, tokens missing the attribute become O instead of raising
    """
    if not attr_prefix.endswith("="):
        attr_prefix = attr_prefix + "="

    doc = CoNLL.conll2doc(input_file=input_file, zip_file=zip_file)

    with open(output_file, "w", encoding="utf-8") as fout:
        for sentence_idx, sentence in enumerate(doc.sentences):
            for token_idx, token in enumerate(sentence.tokens):
                # pull the NER tag out of the | separated misc attributes
                ner = None
                for attr in token.misc.split("|"):
                    if attr.startswith(attr_prefix):
                        ner = attr.split("=", 1)[1]
                        break
                if ner is None:
                    if not allow_empty:
                        raise ValueError("Could not find ner tag in document {}, sentence {}, token {}".format(input_file, sentence_idx, token_idx))
                    ner = "O"

                if ner != "O" and conversion is not None:
                    if isinstance(conversion, dict):
                        # dicts remap just the label, keeping the B-/I- part
                        bio, label = ner.split("-", 1)
                        label = conversion.get(label, label)
                        ner = "%s-%s" % (bio, label)
                    else:
                        ner = conversion(ner)
                fout.write("%s\t%s\n" % (token.text, ner))
            fout.write("\n")
+
55
def main():
    """Extract the training portion of the Danish DDT zip into BIO format."""
    process_conll(zip_file="extern_data/ner/da_ddt/ddt.zip", input_file="ddt.train.conllu", output_file="data/ner/da_ddt.train.bio")

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/ner/convert_bn_daffodil.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert a Bengali NER dataset to our internal .json format
3
+
4
+ The dataset is here:
5
+
6
+ https://github.com/Rifat1493/Bengali-NER/tree/master/Input
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import random
12
+ import tempfile
13
+
14
+ from stanza.utils.datasets.ner.utils import read_tsv, write_dataset
15
+
16
def redo_time_tags(sentences):
    """
    Replace runs of bare TIM tags with B-TIM, I-TIM.

    A brief use of Google Translate suggests the time phrases are
    generally one phrase, so a consecutive run becomes one entity
    rather than a sequence of B-TIM tags.
    """
    converted = []
    for sentence in sentences:
        rebuilt = []
        in_time = False
        for word, tag in sentence:
            if tag != 'TIM':
                in_time = False
                rebuilt.append((word, tag))
            else:
                rebuilt.append((word, "I-TIM" if in_time else "B-TIM"))
                in_time = True
        converted.append(rebuilt)
    return converted
41
+
42
def strip_words(dataset):
    """Strip surrounding whitespace and BOM characters from every word."""
    cleaned = []
    for sentence in dataset:
        cleaned.append([(word.strip().replace('\ufeff', ''), tag) for word, tag in sentence])
    return cleaned
44
+
45
def filter_blank_words(train_file, train_filtered_file):
    """
    Copy train_file to train_filtered_file, dropping blank words.

    As of July 2022, this dataset has blank words with O labels, so any
    line consisting of nothing but the tag 'O' is removed.
    """
    with open(train_file, encoding="utf-8") as fin, \
         open(train_filtered_file, "w", encoding="utf-8") as fout:
        for line in fin:
            if line.strip() != 'O':
                fout.write(line)
57
+
58
def filter_broken_tags(train_sentences):
    """Drop every sentence which contains a word with a missing (None) tag."""
    return [sentence for sentence in train_sentences
            if all(tag is not None for _, tag in sentence)]
63
+
64
def filter_bad_words(train_sentences):
    """
    Not bad words like poop, but characters that don't exist

    These characters look like n and l in emacs, but they are really
    0xF06C and 0xF06E
    """
    # NOTE(review): the tuple below should hold the private-use characters
    # U+F06C and U+F06E (they may render as blank); confirm they were not
    # lost in a copy/paste, or this filter silently becomes a no-op
    return [[x for x in sentence if not x[0] in ("", "")] for sentence in train_sentences]
72
+
73
def read_datasets(in_directory):
    """
    Reads & splits the train data, reads the test data

    There is no validation data, so we split the training data into
    two pieces and use the smaller piece as the dev set

    Also performed is a conversion of TIM -> B-TIM, I-TIM
    """
    # fixed seed so we always get the same shuffle & split
    random.seed(1234)

    def clean(sentences):
        # shared postprocessing pipeline for both splits
        sentences = filter_broken_tags(sentences)
        sentences = filter_bad_words(sentences)
        sentences = redo_time_tags(sentences)
        return strip_words(sentences)

    train_file = os.path.join(in_directory, "Input", "train_data.txt")
    with tempfile.TemporaryDirectory() as tempdir:
        train_filtered_file = os.path.join(tempdir, "train.txt")
        filter_blank_words(train_file, train_filtered_file)
        train_sentences = clean(read_tsv(train_filtered_file, text_column=0, annotation_column=1, keep_broken_tags=True))

    test_file = os.path.join(in_directory, "Input", "test_data.txt")
    test_sentences = clean(read_tsv(test_file, text_column=0, annotation_column=1, keep_broken_tags=True))

    # carve the last 10% of a shuffled train set off as the dev set
    random.shuffle(train_sentences)
    split_len = len(train_sentences) * 9 // 10
    dev_sentences = train_sentences[split_len:]
    train_sentences = train_sentences[:split_len]

    return (train_sentences, dev_sentences, test_sentences)
109
+
110
def convert_dataset(in_directory, out_directory):
    """Read the daffodil splits with read_datasets, then write them back out."""
    write_dataset(read_datasets(in_directory), out_directory, "bn_daffodil")
116
+
117
if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/bangla/Bengali-NER", help="Where to find the files")
    argparser.add_argument('--output_path', type=str, default="/home/john/stanza/data/ner", help="Where to output the results")
    cli_args = argparser.parse_args()

    convert_dataset(cli_args.input_path, cli_args.output_path)
stanza/stanza/utils/datasets/ner/convert_en_conll03.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Downloads (if necessary) conll03 from Huggingface, then converts it to Stanza .json
3
+
4
+ Some online sources for CoNLL 2003 require multiple pieces, but it is currently hosted on HF:
5
+ https://huggingface.co/datasets/conll2003
6
+ """
7
+
8
+ import os
9
+
10
+ from stanza.utils.default_paths import get_default_paths
11
+ from stanza.utils.datasets.ner.utils import write_dataset
12
+
13
TAG_TO_ID = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
# inverse map: numeric HF label id -> BIO tag string
ID_TO_TAG = {tag_id: tag for tag, tag_id in TAG_TO_ID.items()}

def convert_dataset_section(section):
    """Turn one HF split into a list of sentences of (word, tag) pairs."""
    return [list(zip(item['tokens'], (ID_TO_TAG[tag_id] for tag_id in item['ner_tags'])))
            for item in section]
23
+
24
def process_dataset(short_name, conll_path, ner_output_path):
    """Download conll2003 from HF (cached in conll_path) and write Stanza .json files."""
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("Please install the datasets package to process CoNLL03 with Stanza")

    dataset = load_dataset('conll2003', cache_dir=conll_path)
    sections = (dataset['train'], dataset['validation'], dataset['test'])
    converted = [convert_dataset_section(section) for section in sections]
    write_dataset(converted, ner_output_path, short_name)
33
+
34
def main():
    """Convert en_conll03 using the default Stanza paths."""
    paths = get_default_paths()
    cache_dir = os.path.join(paths['NERBASE'], "english", "en_conll03")
    process_dataset("en_conll03", cache_dir, paths['NER_DATA_DIR'])

if __name__ == '__main__':
    main()
+ main()
stanza/stanza/utils/datasets/ner/convert_he_iahlt.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import os
3
+ import re
4
+
5
+ from stanza.utils.conll import CoNLL
6
+ import stanza.utils.default_paths as default_paths
7
+ from stanza.utils.datasets.ner.utils import write_dataset
8
+
9
def output_entities(sentence):
    """Print the raw Entity= annotation (if any) for each word of a sentence."""
    for word in sentence.words:
        if word.misc is None:
            continue

        for attribute in word.misc.split("|"):
            if attribute.startswith("Entity="):
                print(" " + attribute.split("=", maxsplit=1)[1])
                break
21
+
22
def extract_single_sentence(sentence):
    """
    Convert one conllu sentence to a list of (text, BIO tag) pairs.

    Entities are marked in the MISC column with Entity= annotations in a
    bracket notation, e.g. Entity=(Person...); a stack of currently open
    entity types is maintained while walking the words.  Nested entities
    are flattened to the outermost one.  Raises AssertionError when the
    annotation is inconsistent (mismatched open/close brackets).
    """
    current_entity = []   # stack of currently open entity types, outermost first
    words = []
    for word in sentence.words:
        text = word.text
        misc = word.misc
        if misc is None:
            pieces = []
        else:
            pieces = misc.split("|")

        closes = []           # entity types closed at this word
        first_entity = False  # True if the outermost entity opened exactly here
        for piece in pieces:
            if piece.startswith("Entity="):
                entity = piece.split("=", maxsplit=1)[1]
                # tokenize into '(', ')' and entity-type strings
                entity_pieces = re.split(r"([()])", entity)
                entity_pieces = [x for x in entity_pieces if x] # remove blanks from re.split
                entity_idx = 0
                while entity_idx < len(entity_pieces):
                    if entity_pieces[entity_idx] == '(':
                        assert len(entity_pieces) > entity_idx + 1, "Opening an unspecified entity"
                        if len(current_entity) == 0:
                            first_entity = True
                        current_entity.append(entity_pieces[entity_idx + 1])
                        entity_idx += 2
                    elif entity_pieces[entity_idx] == ')':
                        assert entity_idx != 0, "Closing an unspecified entity"
                        closes.append(entity_pieces[entity_idx-1])
                        entity_idx += 1
                    else:
                        # the entities themselves get added or removed via the ()
                        entity_idx += 1

        # tag the word with the outermost open entity, if any
        if len(current_entity) == 0:
            entity = 'O'
        else:
            entity = current_entity[0]
            entity = "B-" + entity if first_entity else "I-" + entity
        words.append((text, entity))

        # closes apply after the word is tagged, so a one-word entity works
        assert len(current_entity) >= len(closes), "Too many closes for the current open entities"
        for close_entity in closes:
            # TODO: check the close is closing the right thing
            assert close_entity == current_entity[-1], "Closed the wrong entity: %s vs %s" % (close_entity, current_entity[-1])
            current_entity = current_entity[:-1]
    return words
69
+
70
def extract_sentences(doc):
    """
    Convert every sentence of a conllu doc to (text, tag) pairs.

    Sentences whose entity annotation is inconsistent (an assert fires in
    extract_single_sentence) are reported and skipped.
    """
    converted = []
    for sentence in doc.sentences:
        try:
            converted.append(extract_single_sentence(sentence))
        except AssertionError as e:
            print("Skipping sentence %s ... %s" % (sentence.sent_id, str(e)))
            output_entities(sentence)

    return converted
81
+
82
def convert_iahlt(udbase, output_dir, short_name):
    """Read both Hebrew IAHLT UD treebanks and write a combined NER dataset."""
    shards = ("train", "dev", "test")
    sources = [("UD_Hebrew-IAHLTwiki", "he_iahltwiki-ud-%s.conllu"),
               ("UD_Hebrew-IAHLTknesset", "he_iahltknesset-ud-%s.conllu")]
    datasets = defaultdict(list)

    for ud_dataset, base_filename in sources:
        for shard in shards:
            filename = os.path.join(udbase, ud_dataset, base_filename % shard)
            sentences = extract_sentences(CoNLL.conll2doc(filename))
            print("Read %d sentences from %s" % (len(sentences), filename))
            datasets[shard].extend(sentences)

    write_dataset([datasets[shard] for shard in shards], output_dir, short_name)
99
+
100
def main():
    """Convert he_iahlt using the default Stanza paths."""
    paths = default_paths.get_default_paths()
    convert_iahlt(paths["UDBASE_GIT"], paths["NER_DATA_DIR"], "he_iahlt")

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/ner/convert_lst20.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Converts the Thai LST20 dataset to a format usable by Stanza's NER model
3
+
4
+ The dataset in the original format has a few tag errors which we
5
+ automatically fix (or at worst cover up)
6
+ """
7
+
8
+ import os
9
+
10
+ from stanza.utils.datasets.ner.utils import convert_bio_to_json
11
+
12
def convert_lst20(paths, short_name, include_space_char=True):
    """
    Convert the Thai LST20 corpus to .bio files, then to Stanza .json.

    paths: the Stanza default paths dict (uses NERBASE and NER_DATA_DIR)
    short_name: must be "th_lst20"
    include_space_char: if False, '_' whitespace tokens are dropped and
      the dataset is written under th_lst20_no_ws instead
    """
    assert short_name == "th_lst20"
    SHARDS = ("train", "eval", "test")
    BASE_OUTPUT_PATH = paths["NER_DATA_DIR"]

    # each shard lives in its own subdirectory of LST20_Corpus
    input_split = [(os.path.join(paths["NERBASE"], "thai", "LST20_Corpus", x), x) for x in SHARDS]

    if not include_space_char:
        short_name = short_name + "_no_ws"

    for input_folder, split_type in input_split:
        # the data files all start with 'T'
        text_list = [text for text in os.listdir(input_folder) if text[0] == 'T']

        # stanza's name for the eval shard is "dev"
        if split_type == "eval":
            split_type = "dev"

        output_path = os.path.join(BASE_OUTPUT_PATH, "%s.%s.bio" % (short_name, split_type))
        print(output_path)

        with open(output_path, 'w', encoding='utf-8') as fout:
            for text in text_list:
                # NOTE(review): lst is never used - looks like leftover code
                lst = []
                with open(os.path.join(input_folder, text), 'r', encoding='utf-8') as fin:
                    lines = fin.readlines()

                for line_idx, line in enumerate(lines):
                    x = line.strip().split('\t')
                    if len(x) > 1:
                        # '_' tokens represent whitespace in LST20
                        if x[0] == '_' and not include_space_char:
                            continue
                        else:
                            word, tag = x[0], x[2]

                            # repair a handful of typo tags present in the raw data
                            if tag == "MEA_BI":
                                tag = "B_MEA"
                            if tag == "OBRN_B":
                                tag = "B_BRN"
                            if tag == "ORG_I":
                                tag = "I_ORG"
                            if tag == "PER_I":
                                tag = "I_PER"
                            if tag == "LOC_I":
                                tag = "I_LOC"
                            # a bare "B" tag: recover the class from the next
                            # line when it continues the same entity
                            if tag == "B" and line_idx + 1 < len(lines):
                                x_next = lines[line_idx+1].strip().split('\t')
                                if len(x_next) > 1:
                                    tag_next = x_next[2]
                                    if "I_" in tag_next or "E_" in tag_next:
                                        tag = tag + tag_next[1:]
                                    else:
                                        tag = "O"
                                else:
                                    tag = "O"
                            # normalize B_XXX -> B-XXX
                            if "_" in tag:
                                tag = tag.replace("_", "-")
                            # drop tags which are not part of the standard tag set
                            # NOTE(review): tag == "__" can never be true here,
                            # since the replace above already turned "__" into "--"
                            if "ABB" in tag or tag == "DDEM" or tag == "I" or tag == "__":
                                tag = "O"

                            fout.write('{}\t{}'.format(word, tag))
                            fout.write('\n')
                    else:
                        fout.write('\n')
    convert_bio_to_json(BASE_OUTPUT_PATH, BASE_OUTPUT_PATH, short_name)
stanza/stanza/utils/datasets/ner/convert_mr_l3cube.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reads one piece of the MR L3Cube dataset
3
+
4
+ The dataset is structured as a long list of words already in IOB format
5
+ The sentences have an ID which changes when a new sentence starts
6
+ The tags are labeled BNEM instead of B-NEM, so we update that.
7
+ (Could theoretically remap the tags to names more typical of other datasets as well)
8
+ """
9
+
10
def convert(input_file):
    """
    Converts one file of the dataset

    Return: a list of list of pairs, (text, tag)
    """
    with open(input_file, encoding="utf-8") as fin:
        lines = fin.readlines()

    sentences = []
    current_sentence = []
    prev_sent_id = None
    for idx, line in enumerate(lines):
        # first line of each of the segments is the header
        if idx == 0:
            continue

        stripped = line.strip()
        if not stripped:
            continue
        pieces = stripped.split("\t")
        if len(pieces) != 3:
            raise ValueError("Unexpected number of pieces at line %d of %s" % (idx, input_file))

        text, ner, sent_id = pieces
        # ner symbols are written as BNEM, BNED, etc in this dataset
        if ner != 'O':
            ner = "%s-%s" % (ner[0], ner[1:])

        if not prev_sent_id:
            prev_sent_id = sent_id
        if sent_id != prev_sent_id:
            # the sentence id changed, so the previous sentence is complete
            prev_sent_id = sent_id
            if not current_sentence:
                raise ValueError("This should not happen!")
            sentences.append(current_sentence)
            current_sentence = []

        current_sentence.append((text, ner))

    if current_sentence:
        sentences.append(current_sentence)

    print("Read %d sentences in %d lines from %s" % (len(sentences), len(lines), input_file))
    return sentences
stanza/stanza/utils/datasets/ner/convert_nner22.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Converts the Thai NNER22 dataset to a format usable by Stanza's NER model
3
+
4
+ The dataset is already written in json format, so we will convert into a compatible json format.
5
+
6
+ The dataset in the original format has nested NER format which we will only extract the first layer
7
+ of NER tag and write it in the format accepted by current Stanza model
8
+ """
9
+
10
+ import os
11
+ import logging
12
+ import json
13
+
14
def convert_nner22(paths, short_name, include_space_char=True):
    """
    Convert the Thai NNER22 (scb-nner-th-2022) json files to Stanza NER json.

    Only the outermost layer of the nested annotations is kept.

    paths: the Stanza default paths dict (uses NERBASE and NER_DATA_DIR)
    short_name: must be "th_nner22"
    include_space_char: only changes the output short name to th_nner22_no_ws;
      the tokens themselves are copied unchanged
    """
    assert short_name == "th_nner22"
    SHARDS = ("train", "dev", "test")
    BASE_INPUT_PATH = os.path.join(paths["NERBASE"], "thai", "Thai-NNER", "data", "scb-nner-th-2022", "postproc")

    if not include_space_char:
        short_name = short_name + "_no_ws"

    for shard in SHARDS:
        input_path = os.path.join(BASE_INPUT_PATH, "%s.json" % (shard))
        output_path = os.path.join(paths["NER_DATA_DIR"], "%s.%s.json" % (short_name, shard))

        logging.info("Output path for %s split at %s" % (shard, output_path))

        # bugfix: use a context manager so the input handle is closed
        # (previously json.load(open(input_path)) leaked the file object)
        with open(input_path, encoding="utf-8") as fin:
            data = json.load(fin)

        documents = []

        for i in range(len(data)):
            token, entities = data[i]["tokens"], data[i]["entities"]

            token_length, sofar = len(token), 0
            document, ner_dict = [], {}

            for entity in entities:
                start, stop = entity["span"]

                # keep only entities extending past everything seen so far,
                # ie the outermost layer of the nested annotation
                if stop > sofar:
                    ner = entity["entity_type"].upper()
                    sofar = stop

                    for j in range(start, stop):
                        if j == start:
                            ner_tag = "B-" + ner
                        elif j == stop - 1:
                            ner_tag = "E-" + ner
                        else:
                            ner_tag = "I-" + ner

                        ner_dict[j] = (ner_tag, token[j])

            for k in range(token_length):
                dict_add = {}

                if k not in ner_dict:
                    dict_add["ner"], dict_add["text"] = "O", token[k]
                else:
                    dict_add["ner"], dict_add["text"] = ner_dict[k]

                document.append(dict_add)

            documents.append(document)

        # bugfix: explicit encoding for the output file (json.dump writes
        # ASCII by default, but an explicit encoding is still safer)
        with open(output_path, "w", encoding="utf-8") as outfile:
            json.dump(documents, outfile, indent=1)

        logging.info("%s.%s.json file successfully created" % (short_name, shard))
stanza/stanza/utils/datasets/ner/convert_ontonotes.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Downloads (if necessary) conll03 from Huggingface, then converts it to Stanza .json
3
+
4
+ Some online sources for CoNLL 2003 require multiple pieces, but it is currently hosted on HF:
5
+ https://huggingface.co/datasets/conll2003
6
+ """
7
+
8
+ import os
9
+
10
+ from stanza.utils.default_paths import get_default_paths
11
+ from stanza.utils.datasets.ner.utils import write_dataset
12
+
13
# index in this list is the HF numeric label id
ID_TO_TAG = ["O", "B-PERSON", "I-PERSON", "B-NORP", "I-NORP", "B-FAC", "I-FAC", "B-ORG", "I-ORG", "B-GPE", "I-GPE", "B-LOC", "I-LOC", "B-PRODUCT", "I-PRODUCT", "B-DATE", "I-DATE", "B-TIME", "I-TIME", "B-PERCENT", "I-PERCENT", "B-MONEY", "I-MONEY", "B-QUANTITY", "I-QUANTITY", "B-ORDINAL", "I-ORDINAL", "B-CARDINAL", "I-CARDINAL", "B-EVENT", "I-EVENT", "B-WORK_OF_ART", "I-WORK_OF_ART", "B-LAW", "I-LAW", "B-LANGUAGE", "I-LANGUAGE",]

def convert_dataset_section(config_name, section):
    """Flatten one HF OntoNotes split into sentences of (word, tag) pairs."""
    sentences = []
    for doc in section:
        # the nt_ sentences (New Testament) in the HF version of OntoNotes
        # have blank named_entities, even though there was no original .name file
        # that corresponded with these annotations
        if config_name.startswith("english") and doc['document_id'].startswith("pt/nt"):
            continue
        sentences.extend(list(zip(sent['words'], (ID_TO_TAG[i] for i in sent['named_entities'])))
                         for sent in doc['sentences'])
    return sentences
28
+
29
def process_dataset(short_name, conll_path, ner_output_path):
    """
    Download (if needed) conll2012_ontonotesv5 from HF and convert to Stanza .json

    short_name: one of en_ontonotes, zh_ontonotes, zh-hans_ontonotes, ar_ontonotes
    conll_path: HF cache directory for the download
    ner_output_path: directory where the .json files are written
    """
    try:
        from datasets import load_dataset
    except ImportError as e:
        # bugfix: the message previously said "CoNLL03",
        # copied verbatim from convert_en_conll03
        raise ImportError("Please install the datasets package to process OntoNotes with Stanza")

    if short_name == 'en_ontonotes':
        # there is an english_v12, but it is filled with junk annotations
        # for example, near the end:
        # And John_O, I realize
        config_name = 'english_v4'
    elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'):
        config_name = 'chinese_v4'
    elif short_name == 'ar_ontonotes':
        config_name = 'arabic_v4'
    else:
        raise ValueError("Unknown short name for downloading ontonotes: %s" % short_name)
    dataset = load_dataset("conll2012_ontonotesv5", config_name, cache_dir=conll_path)
    datasets = [convert_dataset_section(config_name, x) for x in [dataset['train'], dataset['validation'], dataset['test']]]
    write_dataset(datasets, ner_output_path, short_name)
49
+
50
def main():
    """Convert en_ontonotes using the default Stanza paths."""
    paths = get_default_paths()
    cache_dir = os.path.join(paths['NERBASE'], "english", "en_ontonotes")
    process_dataset("en_ontonotes", cache_dir, paths['NER_DATA_DIR'])

if __name__ == '__main__':
    main()
+ main()
stanza/stanza/utils/datasets/ner/json_to_bio.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ If you want to convert .json back to .bio for some reason, this will do it for you
3
+ """
4
+
5
+ import argparse
6
+ import json
7
+ import os
8
+ from stanza.models.common.doc import Document
9
+ from stanza.models.ner.utils import process_tags
10
+ from stanza.utils.default_paths import get_default_paths
11
+
12
def convert_json_to_bio(input_filename, output_filename):
    """Read one Stanza NER .json file and write it back out as a .bio file (bioes tags)."""
    with open(input_filename, encoding="utf-8") as fin:
        doc = Document(json.load(fin))
    sentences = process_tags([[(token.text, token.ner) for token in sentence.tokens]
                              for sentence in doc.sentences],
                             "bioes")
    with open(output_filename, "w", encoding="utf-8") as fout:
        for sentence in sentences:
            for pair in sentence:
                fout.write("%s\t%s\n" % pair)
            fout.write("\n")
22
+
23
def main(args=None):
    """Convert a single .json file, or all three shards of a dataset, to .bio."""
    ner_data_dir = get_default_paths()['NER_DATA_DIR']
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_filename', type=str, default="data/ner/en_foreign-4class.test.json", help='Convert an individual file')
    parser.add_argument('--input_dir', type=str, default=ner_data_dir, help='Which directory to find the dataset, if using --input_dataset')
    parser.add_argument('--input_dataset', type=str, help='Convert an entire dataset')
    parser.add_argument('--output_suffix', type=str, default='bioes', help='suffix for output filenames')
    args = parser.parse_args(args)

    if args.input_dataset:
        inputs = [os.path.join(args.input_dir, "%s.%s.json" % (args.input_dataset, shard))
                  for shard in ("train", "dev", "test")]
    else:
        inputs = [args.input_filename]

    for input_filename in inputs:
        output_filename = "%s.%s" % (os.path.splitext(input_filename)[0], args.output_suffix)
        print("%s -> %s" % (input_filename, output_filename))
        convert_json_to_bio(input_filename, output_filename)

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/ner/misc_to_date.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # for the Worldwide dataset, automatically switch the Misc tags to Date when Stanza Ontonotes thinks it's a Date
2
+ # this keeps our annotation scheme for dates (eg, not "3 months ago") while hopefully switching them all to Date
3
+ #
4
+ # maybe some got missed
5
+ # also, there are a few with some nested entities. printed out warnings and edited those by hand
6
+ #
7
+ # just need to run this with the Worldwide dataset in the ner path
8
+ # it will automatically convert as many as it can
9
+
10
+ import os
11
+
12
+ from tqdm import tqdm
13
+
14
+ import stanza
15
+ from stanza.utils.datasets.ner.utils import read_tsv
16
+ from stanza.utils.default_paths import get_default_paths
17
+
18
paths = get_default_paths()
BASE_PATH = os.path.join(paths["NERBASE"], "en_foreign")
input_dir = os.path.join(BASE_PATH, "en-foreign-newswire")

# pretokenized so the tags returned by the tagger line up 1-1 with the tsv tokens
pipe = stanza.Pipeline("en", processors="tokenize,ner", tokenize_pretokenized=True, package={"ner": "ontonotes_bert"})

filenames = []

def ner_tags(pipe, sentence):
    # tag one pretokenized sentence, returning one NER tag per token
    doc = pipe([sentence])
    tags = [token.ner for sentence in doc.sentences for token in sentence.tokens]
    return tags

# collect every file under the ...REVIEW subdirectories
for root, dirs, files in os.walk(input_dir):
    if root[-6:] == "REVIEW":
        batch_files = os.listdir(root)
        for filename in batch_files:
            file_path = os.path.join(root, filename)
            filenames.append(file_path)

# rewrite each file in place, switching Misc (or other non-O) labels to Date
# wherever the OntoNotes tagger believes there is a DATE entity
for filename in tqdm(filenames):
    try:
        data = read_tsv(filename, text_column=0, annotation_column=1, skip_comments=False, keep_all_columns=True)

        with open(filename, 'w', encoding='utf-8') as fout:
            warned_file = False
            for sentence in data: # segments delimited by spaces, effectively sentences
                tokens = [x[0] for x in sentence]
                labels = [x[1] for x in sentence]

                # only sentences containing a Misc label need retagging
                if any(x.endswith("Misc") for x in labels):
                    stanza_tags = ner_tags(pipe, tokens)
                    in_date = False
                    for i, stanza_tag in enumerate(stanza_tags):
                        if stanza_tag[2:] == "DATE" and labels[i] != "O":
                            if len(sentence[i]) > 2:
                                # extra columns indicate a nested entity annotation
                                if not warned_file:
                                    print("Warning: file %s has nested tags being altered" % filename)
                                    warned_file = True
                            # put DATE tags where Stanza thinks there are DATEs
                            # as long as we already had a MISC (or something else, I suppose)
                            if in_date and not stanza_tag[0].startswith("B") and not stanza_tag[0].startswith("S"):
                                sentence[i][1] = "I-Date"
                            else:
                                sentence[i][1] = "B-Date"
                            in_date = True
                        elif in_date:
                            # make sure new tags start with B- instead of I-
                            # honestly it's not clear if, in these cases,
                            # we should be switching the following tags to
                            # DATE as well. will have to experiment some
                            in_date = False
                            if labels[i].startswith("I-"):
                                sentence[i][1] = "B-" + labels[i][2:]
                # write the (possibly edited) sentence back out
                for word in sentence:
                    fout.write("\t".join(word))
                    fout.write("\n")
                fout.write("\n")
    except AssertionError:
        # read_tsv asserts on malformed files; report and move on
        print("Could not process %s" % filename)
stanza/stanza/utils/datasets/ner/preprocess_wikiner.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Converts the WikiNER data format to a format usable by our processing tools
3
+
4
+ python preprocess_wikiner input output
5
+ """
6
+
7
+ import sys
8
+
9
+ def preprocess_wikiner(input_file, output_file, encoding="utf-8"):
10
+ with open(input_file, encoding=encoding) as fin:
11
+ with open(output_file, "w", encoding="utf-8") as fout:
12
+ for line in fin:
13
+ line = line.strip()
14
+ if not line:
15
+ fout.write("-DOCSTART- O\n")
16
+ fout.write("\n")
17
+ continue
18
+
19
+ words = line.split()
20
+ for word in words:
21
+ pieces = word.split("|")
22
+ text = pieces[0]
23
+ tag = pieces[-1]
24
+ # some words look like Daniel_Bernoulli|I-PER
25
+ # but the original .pl conversion script didn't take that into account
26
+ subtext = text.split("_")
27
+ if tag.startswith("B-") and len(subtext) > 1:
28
+ fout.write("{} {}\n".format(subtext[0], tag))
29
+ for chunk in subtext[1:]:
30
+ fout.write("{} I-{}\n".format(chunk, tag[2:]))
31
+ else:
32
+ for chunk in subtext:
33
+ fout.write("{} {}\n".format(chunk, tag))
34
+ fout.write("\n")
35
+
36
+ if __name__ == '__main__':
37
+ preprocess_wikiner(sys.argv[1], sys.argv[2])
stanza/stanza/utils/datasets/ner/simplify_en_worldwide.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import tempfile
4
+
5
+ import stanza
6
+ from stanza.utils.default_paths import get_default_paths
7
+ from stanza.utils.datasets.ner.utils import read_tsv
8
+ from stanza.utils.get_tqdm import get_tqdm
9
+
10
+ tqdm = get_tqdm()
11
+
12
PUNCTUATION = """!"#%&'()*+, -./:;<=>?@[\\]^_`{|}~"""
MONEY_WORDS = {"million", "billion", "trillion", "millions", "billions", "trillions", "hundred", "hundreds",
               "lakh", "crore", # south asian english
               "tens", "of", "ten", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "couple"}

# Doesn't include Money but this case is handled explicitly for processing
LABEL_TRANSLATION = {
    "Date": None,
    "Misc": "MISC",
    "Product": "MISC",
    "NORP": "MISC",
    "Facility": "LOC",
    "Location": "LOC",
    "Person": "PER",
    "Organization": "ORG",
}

def isfloat(num):
    """Return True if num parses as a float."""
    try:
        float(num)
    except ValueError:
        return False
    return True


def process_label(line, is_start=False):
    """
    Convert one (token, label, ...) line from the Worldwide scheme to 4-class conll labels.

    Product / NORP / Misc collapse into MISC, Facility into LOC, and
    Date labels are deleted entirely.  Money keeps only the currency
    symbols as MISC: numbers, magnitude words and punctuation inside a
    Money span are dropped to O.

    is_start signals that this line begins a new tag; it is needed when
    the previous line was the start of a MONEY tag but was removed as a
    non-symbol, so the first surviving symbol token must be promoted to
    a B- tag even if it was originally I-MONEY.

    Returns [] for an empty line, otherwise [token, new_label, is_start].
    """
    if not line:
        return []
    token, full_label = line[0], line[1]
    position = full_label[:2]
    label_name = full_label[2:]

    if label_name == "Money":
        if token.lower() in MONEY_WORDS or token in PUNCTUATION or isfloat(token):
            # remove this tag; the next surviving token restarts the span
            label_name = "O"
            is_start = True
            position = ""
        else:
            # keep the money tag, as MISC
            label_name = "MISC"
            if is_start:
                position = "B-"
                is_start = False

    elif not label_name or label_name == "O":
        pass
    elif label_name in LABEL_TRANSLATION:
        label_name = LABEL_TRANSLATION[label_name]
        if label_name is None:
            # Dates are deleted entirely
            position = ""
            label_name = "O"
            is_start = False
    else:
        raise ValueError("Oops, missed a label: %s" % label_name)
    return [token, position + label_name, is_start]
92
+
93
+
94
def write_new_file(save_dir, input_path, old_file, simplify):
    """
    Rewrite one raw tsv file into save_dir.

    If simplify is True, every label line is run through process_label
    (4-class output, filename gets a .4class.tsv suffix); otherwise the
    first two columns are copied through unchanged.

    starts_b threads the "next kept Money token starts a new span" state
    from one line to the next via process_label's third return value.
    """
    starts_b = False
    with open(input_path, "r+", encoding="utf-8") as iob:
        new_filename = (os.path.splitext(old_file)[0] + ".4class.tsv") if simplify else old_file
        with open(os.path.join(save_dir, new_filename), 'w', encoding='utf-8') as fout:
            for i, line in enumerate(iob):
                if i == 0 or i == 1: # skip over the URL and subsequent space line.
                    continue
                line = line.strip()
                if not line:
                    # sentence separator: preserved as a blank line
                    fout.write("\n")
                    continue
                label = line.split("\t")
                if simplify:
                    try:
                        edited = process_label(label, is_start=starts_b) # processed label line labels
                    except ValueError as e:
                        raise ValueError("Error in %s at line %d" % (input_path, i)) from e
                    assert edited
                    # last element of edited is the carried-over is_start flag
                    starts_b = edited[-1]
                    fout.write("\t".join(edited[:-1]))
                    fout.write("\n")
                else:
                    fout.write("%s\t%s\n" % (label[0], label[1]))
118
+
119
+
120
def copy_and_simplify(base_path, simplify):
    """
    Walk the raw en-worldwide-newswire tree and rewrite each REVIEW file.

    simplify=True maps the Worldwide labels down to 4 conll classes
    (output under <base_path>/4class); otherwise the labels are copied
    as-is (output under <base_path>/9class).
    """
    # bugfix: removed a TemporaryDirectory which was created (and
    # redundantly makedirs'd) but never actually used for anything
    input_dir = os.path.join(base_path, "en-worldwide-newswire")
    final_dir = os.path.join(base_path, "4class" if simplify else "9class")
    os.makedirs(final_dir, exist_ok=True)
    for root, dirs, files in os.walk(input_dir):
        if root[-6:] == "REVIEW":
            for filename in os.listdir(root):
                file_path = os.path.join(root, filename)
                write_new_file(final_dir, file_path, filename, simplify)
133
+
134
def main(args=None):
    """Command line driver for simplifying the Worldwide dataset."""
    BASE_PATH = "C:\\Users\\SystemAdmin\\PycharmProjects\\General Code\\stanza source code"
    if not os.path.exists(BASE_PATH):
        BASE_PATH = os.path.join(get_default_paths()["NERBASE"], "en_worldwide")

    parser = argparse.ArgumentParser()
    parser.add_argument('--base_path', type=str, default=BASE_PATH, help="Where to find the raw data")
    parser.add_argument('--simplify', default=False, action='store_true', help='Simplify to 4 classes... otherwise, keep all classes')
    parser.add_argument('--no_simplify', dest='simplify', action='store_false', help="Don't simplify to 4 classes")
    args = parser.parse_args(args=args)

    copy_and_simplify(args.base_path, args.simplify)

if __name__ == '__main__':
    main()
+ main()
150
+
151
+
152
+
stanza/stanza/utils/datasets/ner/simplify_ontonotes_to_worldwide.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simplify an existing ner json with the OntoNotes 18 class scheme to the Worldwide scheme
3
+
4
+ Simplified classes used in the Worldwide dataset are:
5
+
6
+ Date
7
+ Facility
8
+ Location
9
+ Misc
10
+ Money
11
+ NORP
12
+ Organization
13
+ Person
14
+ Product
15
+
16
+ vs OntoNotes classes:
17
+
18
+ CARDINAL
19
+ DATE
20
+ EVENT
21
+ FAC
22
+ GPE
23
+ LANGUAGE
24
+ LAW
25
+ LOC
26
+ MONEY
27
+ NORP
28
+ ORDINAL
29
+ ORG
30
+ PERCENT
31
+ PERSON
32
+ PRODUCT
33
+ QUANTITY
34
+ TIME
35
+ WORK_OF_ART
36
+ """
37
+
38
+ import argparse
39
+ import glob
40
+ import json
41
+ import os
42
+
43
+ from stanza.utils.default_paths import get_default_paths
44
+
45
WORLDWIDE_ENTITY_MAPPING = {
    # OntoNotes classes with no Worldwide equivalent are dropped entirely
    "CARDINAL": None,
    "ORDINAL": None,
    "PERCENT": None,
    "QUANTITY": None,
    "TIME": None,

    # OntoNotes class -> Worldwide class
    "DATE": "Date",
    "EVENT": "Misc",
    "FAC": "Facility",
    "GPE": "Location",
    "LANGUAGE": "NORP",
    "LAW": "Misc",
    "LOC": "Location",
    "MONEY": "Money",
    "NORP": "NORP",
    "ORG": "Organization",
    "PERSON": "Person",
    "PRODUCT": "Product",
    "WORK_OF_ART": "Misc",

    # identity map in case this is called on the Worldwide half of the tags
    "Date": "Date",
    "Facility": "Facility",
    "Location": "Location",
    "Misc": "Misc",
    "Money": "Money",
    "Organization": "Organization",
    "Person": "Person",
    "Product": "Product",
}

def simplify_ontonotes_to_worldwide(entity):
    """Map a single BIO tag such as B-GPE to the Worldwide scheme (e.g. B-Location).

    Empty or O tags are returned as "O"; classes that the Worldwide scheme
    drops (CARDINAL etc) also become "O".  Raises ValueError on a class
    missing from WORLDWIDE_ENTITY_MAPPING.
    """
    if not entity or entity == "O":
        return "O"

    prefix, entity_class = entity.split("-", maxsplit=1)

    if entity_class not in WORLDWIDE_ENTITY_MAPPING:
        raise ValueError("Unhandled entity: %s" % entity_class)
    mapped = WORLDWIDE_ENTITY_MAPPING[entity_class]
    if not mapped:
        return "O"
    return prefix + "-" + mapped
+
89
def convert_file(in_file, out_file):
    """Read a stanza NER json file, remap its tags to the Worldwide scheme, write it back out.

    in_file: path to the json file (list of sentences, each a list of word dicts)
    out_file: path to write the converted json
    Words with no 'ner' key are left untouched.
    """
    # bugfix: the read previously used the platform default encoding while the
    # write forced utf-8; read with utf-8 as well for consistency
    with open(in_file, encoding="utf-8") as fin:
        gold_doc = json.load(fin)

    for sentence in gold_doc:
        for word in sentence:
            if 'ner' not in word:
                continue
            word['ner'] = simplify_ontonotes_to_worldwide(word['ner'])

    with open(out_file, "w", encoding="utf-8") as fout:
        json.dump(gold_doc, fout, indent=2)
+
102
def main():
    """Convert every shard of one NER dataset to the Worldwide label scheme."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dataset', type=str, default='en_ontonotes', help='which files to convert')
    parser.add_argument('--output_dataset', type=str, default='en_ontonotes-8class', help='which files to write out')
    parser.add_argument('--ner_data_dir', type=str, default=get_default_paths()["NER_DATA_DIR"], help='which directory has the data')
    args = parser.parse_args()

    # each input file keeps its extension but gets the output dataset prefix
    for input_file in glob.glob(os.path.join(args.ner_data_dir, args.input_dataset + ".*")):
        extension = os.path.split(input_file)[1][len(args.input_dataset):]
        output_file = os.path.join(args.ner_data_dir, args.output_dataset + extension)
        print("Converting %s to %s" % (input_file, output_file))
        convert_file(input_file, output_file)


if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/ner/split_wikiner.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Preprocess the WikiNER dataset, by
3
+ 1) normalizing tags;
4
+ 2) split into train (70%), dev (15%), test (15%) datasets.
5
+ """
6
+
7
+ import os
8
+ import random
9
+ import warnings
10
+ from collections import Counter
11
+
12
+ def read_sentences(filename, encoding):
13
+ sents = []
14
+ cache = []
15
+ skipped = 0
16
+ skip = False
17
+ with open(filename, encoding=encoding) as infile:
18
+ for i, line in enumerate(infile):
19
+ line = line.rstrip()
20
+ if len(line) == 0:
21
+ if len(cache) > 0:
22
+ if not skip:
23
+ sents.append(cache)
24
+ else:
25
+ skipped += 1
26
+ skip = False
27
+ cache = []
28
+ continue
29
+ array = line.split()
30
+ if len(array) != 2:
31
+ skip = True
32
+ warnings.warn("Format error at line {}: {}".format(i+1, line))
33
+ continue
34
+ w, t = array
35
+ cache.append([w, t])
36
+ if len(cache) > 0:
37
+ if not skip:
38
+ sents.append(cache)
39
+ else:
40
+ skipped += 1
41
+ cache = []
42
+ print("Skipped {} examples due to formatting issues.".format(skipped))
43
+ return sents
44
+
45
def write_sentences_to_file(sents, filename):
    """Write sentences (lists of [word, tag] pairs) to filename.

    One token per line, word and tag separated by a tab, with a blank line
    after each sentence.
    """
    # bugfix: the status message previously did not interpolate the filename
    print(f"Writing {len(sents)} sentences to {filename}")
    with open(filename, 'w') as outfile:
        for sent in sents:
            for pair in sent:
                print(f"{pair[0]}\t{pair[1]}", file=outfile)
            print("", file=outfile)
+
53
def remap_labels(sents, remap):
    """Return a copy of sents with each tag replaced via the remap dict.

    Tags missing from remap are kept unchanged; words are never altered.
    The input structure (list of sentences of [word, tag] pairs) is preserved.
    """
    return [[[pair[0], remap.get(pair[1], pair[1])] for pair in sentence]
            for sentence in sents]
61
+
62
def split_wikiner(directory, *in_filenames, encoding="utf-8", prefix="", suffix="bio", remap=None, shuffle=True, train_fraction=0.7, dev_fraction=0.15, test_section=True):
    """Split one or more word/tag files into train/dev(/test) sections.

    directory: where to write the output files
    in_filenames: input files, concatenated before splitting
    encoding: encoding used to read the input files
    prefix/suffix: output files are named [prefix.]{train,dev,test}.suffix
    remap: optional dict of tag -> replacement tag
    shuffle: shuffle sentences (fixed seed, reproducible) before splitting
    test_section: if False, everything after train goes into dev and no test file is written
    Raises ValueError if train_fraction + dev_fraction exceeds 1.
    """
    # fixed seed so results do not depend on how many datasets are processed at once
    random.seed(1234)

    sents = []
    for filename in in_filenames:
        new_sents = read_sentences(filename, encoding)
        # bugfix: the per-file message previously did not interpolate the filename
        print(f"{len(new_sents)} sentences read from {filename}.")
        sents.extend(new_sents)

    if remap:
        sents = remap_labels(sents, remap)

    # split
    num = len(sents)
    train_num = int(num*train_fraction)
    if test_section:
        dev_num = int(num*dev_fraction)
        if train_fraction + dev_fraction > 1.0:
            # bugfix: this format string previously had three placeholders for two arguments,
            # which raised IndexError instead of the intended ValueError
            raise ValueError("Train and dev fractions added up to more than 1: {} {}".format(train_fraction, dev_fraction))
    else:
        dev_num = num - train_num

    if shuffle:
        random.shuffle(sents)
    train_sents = sents[:train_num]
    dev_sents = sents[train_num:train_num+dev_num]
    if test_section:
        test_sents = sents[train_num+dev_num:]
        batches = [train_sents, dev_sents, test_sents]
        filenames = [f'train.{suffix}', f'dev.{suffix}', f'test.{suffix}']
    else:
        batches = [train_sents, dev_sents]
        filenames = [f'train.{suffix}', f'dev.{suffix}']

    if prefix:
        filenames = ['%s.%s' % (prefix, f) for f in filenames]
    for batch, filename in zip(batches, filenames):
        write_sentences_to_file(batch, os.path.join(directory, filename))
100
+
101
if __name__ == "__main__":
    # simple manual invocation: split the raw wp2 file, writing to the current directory
    split_wikiner(".", 'raw/wp2.txt')
stanza/stanza/utils/datasets/ner/suc_conll_to_iob.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Process the licensed version of SUC3 to BIO
3
+
4
+ The main program processes the expected location, or you can pass in a
5
+ specific zip or filename to read
6
+ """
7
+
8
+ from io import TextIOWrapper
9
+ from zipfile import ZipFile
10
+
11
def extract(infile, outfile):
    """
    Convert an open SUC3 conll file to two-column BIO and write it to outfile.

    Assumes both files are already open (this allows you to pass in a
    zipfile reader, for example).

    The SUC3 format is like conll, but with the NER tag pieces in columns
    10 and 11.  Returns the number of sentences written.
    """
    sentences = []
    current = []
    for idx, raw in enumerate(infile.readlines()):
        raw = raw.strip()
        if not raw:
            # blank line closes the sentence currently being read
            if current:
                sentences.append(current)
                current = []
            continue

        pieces = raw.split("\t")
        if len(pieces) < 12:
            raise ValueError("Unexpected line length in the SUC3 dataset at %d" % idx)
        word = pieces[1]
        # column 10 is the BIO prefix (or O), column 11 the entity class
        if pieces[10] == 'O':
            current.append((word, "O"))
        else:
            current.append((word, "%s-%s" % (pieces[10], pieces[11])))
    if current:
        sentences.append(current)

    for sentence in sentences:
        for token in sentence:
            outfile.write("%s\t%s\n" % token)
        outfile.write("\n")

    return len(sentences)

def extract_from_zip(zip_filename, in_filename, out_filename):
    """
    Process a single file from SUC3

    zip_filename: path to SUC3.0.zip
    in_filename: which piece to read
    out_filename: where to write the result
    Returns the number of sentences processed.
    """
    with ZipFile(zip_filename) as zin, zin.open(in_filename) as fin, open(out_filename, "w") as fout:
        num = extract(TextIOWrapper(fin, encoding="utf-8"), fout)
    print("Processed %d sentences from %s:%s to %s" % (num, zip_filename, in_filename, out_filename))
    return num
+
63
def process_suc3(zip_filename, short_name, out_dir):
    """Extract the train/dev/test conll pieces of SUC3.0 into out_dir as BIO files."""
    for shard in ("train", "dev", "test"):
        extract_from_zip(zip_filename,
                         "SUC3.0/corpus/conll/suc-%s.conll" % shard,
                         "%s/%s.%s.bio" % (out_dir, short_name, shard))
+
68
def main():
    """Process the SUC3 zip from its default location into data/ner."""
    # bugfix: process_suc3 takes (zip_filename, short_name, out_dir), but the
    # short_name argument was previously missing, making this call a TypeError
    process_suc3("extern_data/ner/sv_suc3/SUC3.0.zip", "sv_suc3", "data/ner")

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/pos/__init__.py ADDED
File without changes
stanza/stanza/utils/datasets/pos/convert_trees_to_pos.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Turns a constituency treebank into a POS dataset with the tags as the upos column
3
+
4
+ The constituency treebank first has to be converted from the original
5
+ data to PTB style trees. This script converts trees from the
6
+ CONSTITUENCY_DATA_DIR folder to a conllu dataset in the POS_DATA_DIR folder.
7
+
8
+ Note that this doesn't pay any attention to whether or not the tags actually are upos.
9
+ Also not possible: using this for tokenization.
10
+
11
+ TODO: upgrade the POS model to handle xpos datasets with no upos, then make upos/xpos an option here
12
+
13
+ To run this:
14
+ python3 stanza/utils/training/run_pos.py vi_vlsp22
15
+
16
+ """
17
+
18
+ import argparse
19
+ import os
20
+ import shutil
21
+ import sys
22
+
23
+ from stanza.models.constituency import tree_reader
24
+ import stanza.utils.default_paths as default_paths
25
+ from stanza.utils.get_tqdm import get_tqdm
26
+
27
+ tqdm = get_tqdm()
28
+
29
+ SHARDS = ("train", "dev", "test")
30
+
31
def convert_file(in_file, out_file, upos):
    """Convert one file of constituency trees to a fake-dependency conllu file.

    Each preterminal tag is written as the upos column (if upos is True) or
    the xpos column otherwise.  Dependencies are faked: word 1 attaches to
    root, every other word to the previous word with relation "dep".
    """
    print("Reading %s" % in_file)
    trees = tree_reader.read_tree_file(in_file)
    print("Writing %s" % out_file)
    with open(out_file, "w") as fout:
        for tree in tqdm(trees):
            tree = tree.simplify_labels()
            fout.write("# text = %s\n" % " ".join(tree.leaf_labels()))

            for word_idx, preterminal in enumerate(tree.yield_preterminals()):
                tag = preterminal.label
                columns = [
                    "%d" % (word_idx + 1),          # 1-based word index
                    preterminal.children[0].label,  # word
                    "_",                            # lemma unknown
                    tag if upos else "_",           # upos: tag goes here if requested
                    "_" if upos else tag,           # xpos otherwise
                    "_",                            # no features
                    "%d" % word_idx,                # fake head: previous word (0 = root)
                    "root" if word_idx == 0 else "dep",
                    "_",                            # no enhanced deps
                    "_",                            # no misc
                ]
                fout.write("\t".join(columns))
                fout.write("\n")
            fout.write("\n")
63
+
64
def convert_treebank(short_name, upos, output_name, paths):
    """Convert the train/dev/test .mrg files of a treebank into POS conllu files.

    Reads from CONSTITUENCY_DATA_DIR, writes .in.conllu files to POS_DATA_DIR
    and copies each to an identical .gold.conllu.  Raises FileNotFoundError
    if any expected input file is missing.
    """
    con_dir = paths["CONSTITUENCY_DATA_DIR"]
    in_files = [os.path.join(con_dir, "%s_%s.mrg" % (short_name, shard)) for shard in SHARDS]
    for in_file in in_files:
        if not os.path.exists(in_file):
            raise FileNotFoundError("Cannot find expected datafile %s" % in_file)

    pos_dir = paths["POS_DATA_DIR"]
    os.makedirs(pos_dir, exist_ok=True)
    if output_name is None:
        output_name = short_name
    out_files = [os.path.join(pos_dir, "%s.%s.in.conllu" % (output_name, shard)) for shard in SHARDS]
    gold_files = [os.path.join(pos_dir, "%s.%s.gold.conllu" % (output_name, shard)) for shard in SHARDS]

    for in_file, out_file in zip(in_files, out_files):
        convert_file(in_file, out_file, upos)
    # gold files are plain copies of the converted output
    for out_file, gold_file in zip(out_files, gold_files):
        shutil.copy2(out_file, gold_file)
83
+
84
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("dataset", help="Which dataset to process from trees to POS")
    arg_parser.add_argument("--upos", action="store_true", default=False, help="Store tags on the UPOS")
    arg_parser.add_argument("--xpos", dest="upos", action="store_false", help="Store tags on the XPOS")
    arg_parser.add_argument("--output_name", default=None, help="What name to give the output dataset. If blank, will use the dataset arg")
    parsed_args = arg_parser.parse_args()

    convert_treebank(parsed_args.dataset, parsed_args.upos, parsed_args.output_name,
                     default_paths.get_default_paths())
stanza/stanza/utils/datasets/prepare_tokenizer_data.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import sys
6
+
7
+ from collections import Counter
8
+
9
+ """
10
+ Data is output in 4 files:
11
+
12
+ a file containing the mwt information
13
+ a file containing the words and sentences in conllu format
14
+ a file containing the raw text of each paragraph
15
+ a file of 0,1,2 indicating word break or sentence break on a character level for the raw text
16
+ 1: end of word
17
+ 2: end of sentence
18
+ """
19
+
20
# a paragraph break is two newlines separated only by whitespace
PARAGRAPH_BREAK = re.compile(r'\n\s*\n')

def is_para_break(index, text):
    """Detect a paragraph break starting at text[index].

    Returns (found, length) where length is the number of characters the
    break sequence spans (0 when no break is found).
    """
    if text[index] != '\n':
        return False, 0
    matched = PARAGRAPH_BREAK.match(text, index)
    if matched:
        return True, len(matched.group(0))
    return False, 0
30
+
31
def find_next_word(index, text, word, output):
    """
    Locate the next word in the text. In case a paragraph break is found, also write paragraph break to labels.

    index: character offset into text where the scan starts
    text: the raw document text
    word: the next token (from the conllu file) expected to appear in text
    output: file-like object the tokenizer labels are written to

    Returns (new_index, word_sofar): the offset just past the consumed
    characters, and the raw text consumed while matching word (possibly
    including leading whitespace).

    Raises AssertionError if the raw text and the token disagree.
    """
    idx = 0            # position within word
    word_sofar = ''    # raw text consumed so far
    while index < len(text) and idx < len(word):
        para_break, break_len = is_para_break(index, text)
        if para_break:
            # multiple newlines found, paragraph break
            if len(word_sofar) > 0:
                # only whitespace may be left unconsumed at a paragraph boundary
                assert re.match(r'^\s+$', word_sofar), 'Found non-empty string at the end of a paragraph that doesn\'t match any token: |{}|'.format(word_sofar)
                word_sofar = ''

            output.write('\n\n')
            # -1 compensates for the unconditional index += 1 at the bottom of the loop
            index += break_len - 1
        elif re.match(r'^\s$', text[index]) and not re.match(r'^\s$', word[idx]):
            # whitespace found, and whitespace is not part of a word
            word_sofar += text[index]
        else:
            # non-whitespace char, or a whitespace char that's part of a word
            word_sofar += text[index]
            # newlines in the raw text correspond to spaces in the conllu token
            assert text[index].replace('\n', ' ') == word[idx], "Character mismatch: raw text contains |%s| but the next word is |%s|." % (word_sofar, word)
            idx += 1
        index += 1
    return index, word_sofar
57
+
58
def main(args):
    """Convert a plaintext file plus its matching conllu file into tokenizer labels.

    Writes one character-level label per character of the raw text:
      0 = inside a token, 1 = end of token, 2 = end of sentence,
      3 = end of a multi-word token, 4 = end of an MWT that also ends a sentence
    (paragraph breaks are emitted as blank lines).  MWT expansions are
    collected and written as json (or printed when no -m file is given).

    args: list of command line arguments (plaintext_file, conllu_file, -o, -m)
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('plaintext_file', type=str, help="Plaintext file containing the raw input")
    parser.add_argument('conllu_file', type=str, help="CoNLL-U file containing tokens and sentence breaks")
    parser.add_argument('-o', '--output', default=None, type=str, help="Output file name; output to the console if not specified (the default)")
    parser.add_argument('-m', '--mwt_output', default=None, type=str, help="Output file name for MWT expansions; output to the console if not specified (the default)")

    args = parser.parse_args(args=args)

    with open(args.plaintext_file, 'r', encoding='utf-8') as f:
        text = ''.join(f.readlines())

    if args.output is None:
        output = sys.stdout
    else:
        outdir = os.path.split(args.output)[0]
        # bugfix: os.makedirs("") raises FileNotFoundError when the output
        # path has no directory component, so only create a dir if one exists
        if outdir:
            os.makedirs(outdir, exist_ok=True)
        output = open(args.output, 'w')

    index = 0  # character offset in rawtext

    mwt_expansions = []
    with open(args.conllu_file, 'r', encoding='utf-8') as f:
        buf = ''           # pending labels for the most recent token
        mwtbegin = 0       # word id range of the MWT currently being expanded
        mwtend = -1
        expanded = []      # expansion words collected for the current MWT
        lastmwt = None     # surface form of the current MWT (robustness: start defined)
        last_comments = ""
        for line in f:
            line = line.strip()
            if len(line):
                if line[0] == "#":
                    # comment; just remember the first one for error reporting
                    if len(last_comments) == 0:
                        last_comments = line
                    continue

                line = line.split('\t')
                if '.' in line[0]:
                    # the tokenizer doesn't deal with ellipsis
                    continue

                word = line[1]
                if '-' in line[0]:
                    # multiword token
                    mwtbegin, mwtend = [int(x) for x in line[0].split('-')]
                    lastmwt = word
                    expanded = []
                elif mwtbegin <= int(line[0]) < mwtend:
                    # interior word of an MWT: collect, emit nothing
                    expanded += [word]
                    continue
                elif int(line[0]) == mwtend:
                    # final word of an MWT: record the full expansion
                    expanded += [word]
                    expanded = [x.lower() for x in expanded] # evaluation doesn't care about case
                    mwt_expansions += [(lastmwt, tuple(expanded))]
                    if lastmwt[0].islower() and not expanded[0][0].islower():
                        print('Sentence ID with potential wrong MWT expansion: ', last_comments, file=sys.stderr)
                    mwtbegin = 0
                    mwtend = -1
                    lastmwt = None
                    continue

                if len(buf):
                    output.write(buf)
                index, word_found = find_next_word(index, text, word, output)
                # '3' marks the end of a multi-word token, '1' a plain token end
                buf = '0' * (len(word_found)-1) + ('1' if '-' not in line[0] else '3')
            else:
                # sentence break found
                if len(buf):
                    assert int(buf[-1]) >= 1
                    # bump the final label (1->2, 3->4) to also mark the sentence end
                    output.write(buf[:-1] + '{}'.format(int(buf[-1]) + 1))
                    buf = ''

                last_comments = ''

    status_line = ""
    if args.output:
        output.close()
        status_line = 'Tokenizer labels written to %s\n ' % args.output

    mwts = Counter(mwt_expansions)
    if args.mwt_output is None:
        print('MWTs:', mwts)
    else:
        with open(args.mwt_output, 'w') as f:
            json.dump(list(mwts.items()), f, indent=2)

        status_line = status_line + '{} unique MWTs found in data. MWTs written to {}'.format(len(mwts), args.mwt_output)
    print(status_line)
stanza/stanza/utils/datasets/prepare_tokenizer_treebank.py ADDED
@@ -0,0 +1,1396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prepares train, dev, test for a treebank
3
+
4
+ For example, do
5
+ python -m stanza.utils.datasets.prepare_tokenizer_treebank TREEBANK
6
+ such as
7
+ python -m stanza.utils.datasets.prepare_tokenizer_treebank UD_English-EWT
8
+
9
+ and it will prepare each of train, dev, test
10
+
11
+ There are macros for preparing all of the UD treebanks at once:
12
+ python -m stanza.utils.datasets.prepare_tokenizer_treebank ud_all
13
+ python -m stanza.utils.datasets.prepare_tokenizer_treebank all_ud
14
+ Both are present because I kept forgetting which was the correct one
15
+
16
+ There are a few special case handlings of treebanks in this file:
17
+ - all Vietnamese treebanks have special post-processing to handle
18
+ some of the difficult spacing issues in Vietnamese text
19
+ - treebanks with train and test but no dev split have the
20
+ train data randomly split into two pieces
21
+ - however, instead of splitting very tiny treebanks, we skip those
22
+ """
23
+
24
+ import argparse
25
+ import glob
26
+ import io
27
+ import os
28
+ import random
29
+ import re
30
+ import tempfile
31
+ import zipfile
32
+
33
+ from collections import Counter
34
+
35
+ from stanza.models.common.constant import treebank_to_short_name
36
+ import stanza.utils.datasets.common as common
37
+ from stanza.utils.datasets.common import read_sentences_from_conllu, write_sentences_to_conllu, write_sentences_to_file, INT_RE, MWT_RE, MWT_OR_COPY_RE
38
+ import stanza.utils.datasets.tokenization.convert_ml_cochin as convert_ml_cochin
39
+ import stanza.utils.datasets.tokenization.convert_my_alt as convert_my_alt
40
+ import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp
41
+ import stanza.utils.datasets.tokenization.convert_th_best as convert_th_best
42
+ import stanza.utils.datasets.tokenization.convert_th_lst20 as convert_th_lst20
43
+ import stanza.utils.datasets.tokenization.convert_th_orchid as convert_th_orchid
44
+
45
def copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name):
    """Copy one conllu file from the tokenizer directory to dest_dir.

    Goes through a read/write cycle instead of shutil.copyfile so that
    sentence-level manipulations can be slotted in here if ever needed
    (for example, adding fake dependencies - TODO: still needed?).
    """
    source_path = f"{tokenizer_dir}/{short_name}.{tokenizer_file}.conllu"
    dest_path = f"{dest_dir}/{short_name}.{dest_file}.conllu"

    print("Copying from %s to %s" % (source_path, dest_path))
    sentences = read_sentences_from_conllu(source_path)
    write_sentences_to_conllu(dest_path, sentences)
+
55
def copy_conllu_treebank(treebank, model_type, paths, dest_dir, postprocess=None, augment=True):
    """
    This utility method copies only the conllu files to the given destination directory.

    Both POS, lemma, and depparse annotators need this.

    postprocess: optional callable with the same signature as copy_conllu_file,
    applied when copying the train/dev/test gold files (defaults to a plain copy)
    """
    os.makedirs(dest_dir, exist_ok=True)

    short_name = treebank_to_short_name(treebank)

    with tempfile.TemporaryDirectory() as tokenizer_dir:
        # shallow copy so the caller's paths dict is not mutated
        paths = dict(paths)
        paths["TOKENIZE_DATA_DIR"] = tokenizer_dir

        # first we process the tokenization data
        args = argparse.Namespace()
        args.augment = augment
        args.prepare_labels = False
        process_treebank(treebank, model_type, paths, args)

        # (dest_dir was already created above; the second makedirs was redundant)

        if postprocess is None:
            postprocess = copy_conllu_file

        # now we copy the processed conllu data files
        postprocess(tokenizer_dir, "train.gold", dest_dir, "train.in", short_name)
        postprocess(tokenizer_dir, "dev.gold", dest_dir, "dev.in", short_name)
        postprocess(tokenizer_dir, "test.gold", dest_dir, "test.in", short_name)
        # POS and depparse build their own gold files elsewhere
        if model_type is not common.ModelType.POS and model_type is not common.ModelType.DEPPARSE:
            copy_conllu_file(dest_dir, "dev.in", dest_dir, "dev.gold", short_name)
            copy_conllu_file(dest_dir, "test.in", dest_dir, "test.gold", short_name)
+
89
def split_train_file(treebank, train_input_conllu, train_output_conllu, dev_output_conllu):
    """Randomly split a train conllu file into train and dev pieces.

    The dev fraction is XV_RATIO; the shuffle uses a fixed seed so the split
    is reproducible no matter how many treebanks are processed at once.
    Returns True on success; asserts that at least one dev sentence exists.
    """
    # set the seed for each data file so that the results are the same
    # regardless of how many treebanks are processed at once
    random.seed(1234)

    # read and shuffle conllu data
    sents = read_sentences_from_conllu(train_input_conllu)
    random.shuffle(sents)
    n_dev = int(len(sents) * XV_RATIO)
    assert n_dev >= 1, "Dev sentence number less than one."
    n_train = len(sents) - n_dev

    # split conllu data: the first n_dev shuffled sentences become dev
    dev_sents, train_sents = sents[:n_dev], sents[n_dev:]
    print("Train/dev split not present. Randomly splitting train file from %s to %s and %s" % (train_input_conllu, train_output_conllu, dev_output_conllu))
    print(f"{len(sents)} total sentences found: {n_train} in train, {n_dev} in dev")

    # write conllu
    write_sentences_to_conllu(train_output_conllu, train_sents)
    write_sentences_to_conllu(dev_output_conllu, dev_sents)

    return True
+
113
+
114
def has_space_after_no(piece):
    """Return True if this misc-column value carries a SpaceAfter=No annotation."""
    if not piece or piece == "_":
        return False
    # annotations are |-separated; an exact-match check avoids false hits
    # on values like SpaceAfter=Nope
    return "SpaceAfter=No" in piece.split("|")
+
122
+
123
def remove_space_after_no(piece, fail_if_missing=True):
    """
    Removes a SpaceAfter=No annotation from a single piece of a single word.
    In other words, given a list of conll lines, first call split("\t"), then call this on the -1 column

    Raises ValueError if the annotation is absent and fail_if_missing is set;
    otherwise the piece is returned unchanged in that case.
    """
    # |SpaceAfter is in UD_Romanian-Nonstandard... seems fitting
    if piece in ("SpaceAfter=No", "|SpaceAfter=No"):
        return "_"
    if piece.startswith("SpaceAfter=No|"):
        return piece.replace("SpaceAfter=No|", "")
    if piece.find("|SpaceAfter=No") > 0:
        return piece.replace("|SpaceAfter=No", "")
    if fail_if_missing:
        raise ValueError("Could not find SpaceAfter=No in the given notes field")
    return piece
+
139
def add_space_after_no(piece, fail_if_found=True):
    """Add a SpaceAfter=No annotation to one misc-column value.

    Raises ValueError if the annotation is already present and fail_if_found is set.
    """
    if piece == '_':
        return "SpaceAfter=No"
    if fail_if_found and has_space_after_no(piece):
        raise ValueError("Given notes field already contained SpaceAfter=No")
    return piece + "|SpaceAfter=No"
+
148
+
149
def augment_arabic_padt(sents, ratio=0.05):
    """
    Basic Arabic tokenizer gets the trailing punctuation wrong if there is a blank space.

    Reason seems to be that there are almost no examples of "text ." in the dataset.
    This function augments the Arabic-PADT dataset with a few such examples.
    TODO: it may very well be that a lot of tokeners have this problem.

    Also, there are a few examples in UD2.7 which are apparently
    headlines where there is a ' . ' in the middle of the text.
    According to an Arabic speaking labmate, the sentences are
    headlines which could be reasonably split into two items. Having
    them as one item is quite confusing and possibly incorrect, but
    such is life.

    sents: list of sentences, each a list of conllu lines (comments first)
    ratio: fraction of eligible sentences to augment
    Returns the original sentences plus the augmented copies.
    """
    new_sents = []
    for sentence in sents:
        if len(sentence) < 4:
            raise ValueError("Read a surprisingly short sentence")
        # figure out which comment line holds "# text", based on which
        # header comments precede it
        text_line = None
        if sentence[0].startswith("# newdoc") and sentence[3].startswith("# text"):
            text_line = 3
        elif sentence[0].startswith("# newpar") and sentence[2].startswith("# text"):
            text_line = 2
        elif sentence[0].startswith("# sent_id") and sentence[1].startswith("# text"):
            text_line = 1
        else:
            raise ValueError("Could not find text line in %s" % sentence[0].split()[-1])

        # for some reason performance starts dropping quickly at higher numbers
        if random.random() > ratio:
            continue

        # only augment when: the text ends in one sentence-final punct char,
        # the char before it is neither punct nor a space, the second-to-last
        # conllu word has SpaceAfter=No, and the final token is that single char
        if (sentence[text_line][-1] in ('.', '؟', '?', '!') and
            sentence[text_line][-2] not in ('.', '؟', '?', '!', ' ') and
            has_space_after_no(sentence[-2].split()[-1]) and
            len(sentence[-1].split()[1]) == 1):
            new_sent = list(sentence)
            # insert a space before the final punctuation in the text line
            new_sent[text_line] = new_sent[text_line][:-1] + ' ' + new_sent[text_line][-1]
            # and drop SpaceAfter=No from the word before the punctuation
            pieces = sentence[-2].split("\t")
            pieces[-1] = remove_space_after_no(pieces[-1])
            new_sent[-2] = "\t".join(pieces)
            assert new_sent != sentence
            new_sents.append(new_sent)
    return sents + new_sents
194
+
195
+
196
def augment_telugu(sents):
    """
    Add a few sentences with modified punctuation to Telugu_MTG

    The Telugu-MTG dataset has punctuation separated from the text in
    almost all cases, which makes the tokenizer not learn how to
    process that correctly.

    All of the Telugu sentences end with their sentence final
    punctuation being separated. Furthermore, all commas are
    separated. We change that on some subset of the sentences to
    make the tools more generalizable on wild text.

    sents: list of sentences, each a list of conllu lines where line 1 is
    "# text" and line 2 is "# translit"
    Returns the original sentences plus the augmented copies.
    """
    new_sents = []
    for sentence in sents:
        if not sentence[1].startswith("# text"):
            raise ValueError("Expected the second line of %s to start with # text" % sentence[0])
        if not sentence[2].startswith("# translit"):
            raise ValueError("Expected the second line of %s to start with # translit" % sentence[0])
        # skip ellipses and sentences without final punctuation
        if sentence[1].endswith(". . .") or sentence[1][-1] not in ('.', '?', '!'):
            continue
        # sanity check: all te_mtg sentences are expected to end with
        # " <punct>"; fail loudly if the data stops looking like that
        if sentence[1][-1] in ('.', '?', '!') and sentence[1][-2] != ' ' and sentence[1][-3:] != ' ..' and sentence[1][-4:] != ' ...':
            raise ValueError("Sentence %s does not end with space-punctuation, which is against our assumptions for the te_mtg treebank. Please check the augment method to see if it is still needed" % sentence[0])
        # ~10% of sentences: glue the final punctuation onto the last word
        if random.random() < 0.1:
            new_sentence = list(sentence)
            # drop the space before the final punct in text and translit lines
            new_sentence[1] = new_sentence[1][:-2] + new_sentence[1][-1]
            new_sentence[2] = new_sentence[2][:-2] + new_sentence[2][-1]
            # and mark the second-to-last word as having no trailing space
            new_sentence[-2] = new_sentence[-2] + "|SpaceAfter=No"
            new_sents.append(new_sentence)
        # ~10% of sentences containing a comma: glue the comma to the word before it
        if sentence[1].find(",") > 1 and random.random() < 0.1:
            new_sentence = list(sentence)
            index = sentence[1].find(",")
            new_sentence[1] = sentence[1][:index-1] + sentence[1][index:]
            index = sentence[1].find(",")
            new_sentence[2] = sentence[2][:index-1] + sentence[2][index:]
            for idx, word in enumerate(new_sentence):
                if idx < 4:
                    # skip sent_id, text, transliteration, and the first word
                    continue
                if word.split("\t")[1] == ',':
                    # mark the word before the comma as SpaceAfter=No
                    new_sentence[idx-1] = new_sentence[idx-1] + "|SpaceAfter=No"
                    break
            new_sents.append(new_sentence)
    return sents + new_sents
+
241
COMMA_SEPARATED_RE = re.compile(" ([a-zA-Z]+)[,] ([a-zA-Z]+) ")
def augment_comma_separations(sents, ratio=0.03):
    """Find some fraction of the sentences which match "asdf, zzzz" and squish them to "asdf,zzzz"

    This leaves the tokens and all of the other data the same.  The
    only change made is to change SpaceAfter=No for the "," token and
    adjust the #text line, with the assumption that the conllu->txt
    conversion will correctly handle this change.

    This was particularly an issue for Spanish-AnCora, but it's
    reasonable to think it could happen to any dataset.  Currently
    this just operates on commas and ascii letters to avoid
    accidentally squishing anything that shouldn't be squished.

    UD_Spanish-AnCora 2.7 had a problem is with this sentence:
    # orig_file_sentence 143#5
    In this sentence, there was a comma smashed next to a token.

    Fixing just this one sentence is not sufficient to tokenize
    "asdf,zzzz" as desired, so we also augment by some fraction where
    we have squished "asdf, zzzz" into "asdf,zzzz".

    This exact example was later fixed in UD 2.8, but it should still
    potentially be useful for compensating for typos.

    :param sents: list of sentences, each a list of conllu lines
    :param ratio: fraction of matching sentences to squish
    :return: sents plus the newly squished copies
    """
    new_sents = []
    for sentence in sents:
        for text_idx, text_line in enumerate(sentence):
            # look for the line that starts with "# text".
            # keep going until we find it, or silently ignore it
            # if the dataset isn't in that format
            if text_line.startswith("# text"):
                break
        else:
            continue

        match = COMMA_SEPARATED_RE.search(sentence[text_idx])
        if match and random.random() < ratio:
            # locate the three consecutive tokens "w1 , w2" that
            # produced the regex match on the text line
            for idx, word in enumerate(sentence):
                if word.startswith("#"):
                    continue
                # find() doesn't work because we wind up finding substrings
                if word.split("\t")[1] != match.group(1):
                    continue
                if sentence[idx+1].split("\t")[1] != ',':
                    continue
                if sentence[idx+2].split("\t")[1] != match.group(2):
                    continue
                break
            if idx == len(sentence) - 1:
                # this can happen with MWTs.  we may actually just
                # want to skip MWTs anyway, so no big deal
                continue
            # now idx+1 should be the line with the comma in it
            comma = sentence[idx+1]
            pieces = comma.split("\t")
            assert pieces[1] == ','
            # mark the comma as attached to the following word
            pieces[-1] = add_space_after_no(pieces[-1])
            comma = "\t".join(pieces)
            new_sent = sentence[:idx+1] + [comma] + sentence[idx+2:]

            # rewrite "w1, w2" as "w1,w2" in the # text line
            text_offset = sentence[text_idx].find(match.group(1) + ", " + match.group(2))
            text_len = len(match.group(1) + ", " + match.group(2))
            new_text = sentence[text_idx][:text_offset] + match.group(1) + "," + match.group(2) + sentence[text_idx][text_offset+text_len:]
            new_sent[text_idx] = new_text

            new_sents.append(new_sent)

    print("Added %d new sentences with asdf, zzzz -> asdf,zzzz" % len(new_sents))

    return sents + new_sents
312
+
313
def augment_move_comma(sents, ratio=0.02):
    """
    Move the comma from after a word to before the next word some fraction of the time

    We look for this exact pattern:
      w1, w2
    and replace it with
      w1 ,w2

    The idea is that this is a relatively common typo, but the tool
    won't learn how to tokenize it without some help.

    Note that this modification replaces the original text.

    :param sents: list of sentences, each a list of conllu lines
    :param ratio: fraction of sentences considered for the swap
    :return: a new list of sentences, some of them modified
    """
    new_sents = []
    num_operations = 0
    for sentence in sents:
        # most sentences are passed through untouched
        if random.random() > ratio:
            new_sents.append(sentence)
            continue

        # hunt for a comma attached to the previous word
        # (prev word has SpaceAfter=No, comma itself does not)
        found = False
        for word_idx, word in enumerate(sentence):
            if word.startswith("#"):
                continue
            if word_idx == 0 or word_idx >= len(sentence) - 2:
                continue
            pieces = word.split("\t")
            if pieces[1] == ',' and not has_space_after_no(pieces[-1]):
                # found a comma with a space after it
                prev_word = sentence[word_idx-1]
                if not has_space_after_no(prev_word.split("\t")[-1]):
                    # unfortunately, the previous word also had a
                    # space after it.  does not fit what we are
                    # looking for
                    continue
                # also, want to skip instances near MWT or copy nodes,
                # since those are harder to rearrange
                next_word = sentence[word_idx+1]
                if MWT_OR_COPY_RE.match(next_word.split("\t")[0]):
                    continue
                if MWT_OR_COPY_RE.match(prev_word.split("\t")[0]):
                    continue
                # at this point, the previous word has no space and the comma does
                found = True
                break

        if not found:
            new_sents.append(sentence)
            continue

        new_sentence = list(sentence)

        # attach the comma to the following word...
        pieces = new_sentence[word_idx].split("\t")
        pieces[-1] = add_space_after_no(pieces[-1])
        new_sentence[word_idx] = "\t".join(pieces)

        # ...and detach it from the previous word
        pieces = new_sentence[word_idx-1].split("\t")
        prev_word = pieces[1]
        pieces[-1] = remove_space_after_no(pieces[-1])
        new_sentence[word_idx-1] = "\t".join(pieces)

        next_word = new_sentence[word_idx+1].split("\t")[1]

        for text_idx, text_line in enumerate(sentence):
            # look for the line that starts with "# text".
            # keep going until we find it, or silently ignore it
            # if the dataset isn't in that format
            if text_line.startswith("# text"):
                old_chunk = prev_word + ", " + next_word
                new_chunk = prev_word + " ," + next_word
                # NOTE: word_idx is reused here as a character offset
                word_idx = text_line.find(old_chunk)
                if word_idx < 0:
                    raise RuntimeError("Unexpected #text line which did not contain the original text to be modified. Looking for\n" + old_chunk + "\n" + text_line)
                new_text_line = text_line[:word_idx] + new_chunk + text_line[word_idx+len(old_chunk):]
                new_sentence[text_idx] = new_text_line
                break

        new_sents.append(new_sentence)
        num_operations = num_operations + 1

    print("Swapped 'w1, w2' for 'w1 ,w2' %d times" % num_operations)
    return new_sents
396
+
397
def augment_apos(sents):
    """
    Teach the tokenizer about the unicode apostrophe.

    If there are no instances of ’ in the dataset, but there are instances of ',
    we replace some fraction of ' with ’ so that the tokenizer will recognize it.

    # TODO: we could do it the other way around as well
    """
    saw_unicode_apos = False
    saw_ascii_apos = False
    for sent_idx, sent in enumerate(sents):
        if len(sent) == 0:
            raise AssertionError("Got a blank sentence in position %d!" % sent_idx)
        # only the first "# text" comment of each sentence is inspected
        text_line = next((line for line in sent if line.startswith("# text")), None)
        if text_line is None:
            raise ValueError("Cannot find '# text' in sentences %d. First line: %s" % (sent_idx, sent[0]))
        if "'" in text_line:
            saw_ascii_apos = True
        if "’" in text_line:
            saw_unicode_apos = True

    # nothing to do if ’ already occurs, or if there are no ' at all
    if saw_unicode_apos or not saw_ascii_apos:
        return sents

    updated_sents = []
    for sent in sents:
        # only ~5% of the sentences are converted
        if random.random() > 0.05:
            updated_sents.append(sent)
            continue
        converted = []
        for line in sent:
            if line.startswith("# text"):
                converted.append(line.replace("'", "’"))
            elif line.startswith("#"):
                converted.append(line)
            else:
                fields = line.split("\t")
                fields[1] = fields[1].replace("'", "’")
                converted.append("\t".join(fields))
        updated_sents.append(converted)

    return updated_sents
441
+
442
def augment_ellipses(sents):
    """
    Replaces a fraction of '...' with '…'

    Only fires when the corpus contains ascii ellipses but no unicode
    ones; otherwise the input is returned unchanged.
    """
    tokens = [line.split("\t")[1]
              for sent in sents
              for line in sent
              if not line.startswith("#")]
    saw_ascii_ellipses = '...' in tokens
    saw_unicode_ellipses = '…' in tokens

    if saw_unicode_ellipses or not saw_ascii_ellipses:
        return sents

    updated_sents = []
    num_updated = 0
    for sent in sents:
        # only ~10% of the sentences are converted
        if random.random() > 0.1:
            updated_sents.append(sent)
            continue
        converted = []
        changed = False
        for line in sent:
            if line.startswith("#"):
                converted.append(line)
                continue
            fields = line.split("\t")
            if fields[1] == '...':
                fields[1] = '…'
                changed = True
            converted.append("\t".join(fields))
        updated_sents.append(converted)
        if changed:
            num_updated = num_updated + 1

    print("Changed %d sentences to use fancy unicode ellipses" % num_updated)
    return updated_sents
485
+
486
# https://en.wikipedia.org/wiki/Quotation_mark
QUOTES = ['"', '“', '”', '«', '»', '「', '」', '《', '》', '„', '″']
QUOTES_RE = re.compile("(.?)[" + "".join(QUOTES) + "](.+)[" + "".join(QUOTES) + "](.?)")
# Danish does '«' the other way around from most European languages
START_QUOTES = ['"', '“', '”', '«', '»', '「', '《', '„', '„', '″']
END_QUOTES = ['"', '“', '”', '»', '«', '」', '》', '”', '“', '″']

def augment_quotes(sents, ratio=0.15):
    """
    Go through the sentences and replace a fraction of sentences with alternate quotes

    Only sentences containing exactly two quote tokens are rewritten,
    which avoids having to work out which quotes pair up.

    TODO: for certain languages we may want to make some language-specific changes
    eg Danish, don't add «...»
    """
    assert len(START_QUOTES) == len(END_QUOTES)

    counts = Counter()
    updated_sents = []
    for sent in sents:
        if random.random() > ratio:
            updated_sents.append(sent)
            continue

        # only sentences with exactly 2 quotes are rewritten - this is
        # for convenience, so we never need to match up quote pairs
        quote_tokens = [x for x in sent
                        if not x.startswith("#") and x.split("\t")[1] in QUOTES]
        if len(quote_tokens) != 2:
            updated_sents.append(sent)
            continue

        # choose a replacement pair of quotes from the candidates
        choice = random.choice(range(len(START_QUOTES)))
        open_quote = START_QUOTES[choice]
        close_quote = END_QUOTES[choice]
        counts[open_quote + close_quote] = counts[open_quote + close_quote] + 1

        rewritten = []
        seen_open = False
        for line in sent:
            if line.startswith("#"):
                rewritten.append(line)
                continue
            fields = line.split("\t")
            if fields[1] in QUOTES:
                # Note that we don't change the lemma.  Presumably it's
                # set to the correct lemma for a quote for this treebank
                fields[1] = close_quote if seen_open else open_quote
                seen_open = True
                rewritten.append("\t".join(fields))
            else:
                rewritten.append(line)

        # mirror the change in the "# text" comment, if there is one
        for text_idx, text_line in enumerate(rewritten):
            if text_line.startswith("# text"):
                replacement = "\\1%s\\2%s\\3" % (open_quote, close_quote)
                rewritten[text_idx] = QUOTES_RE.sub(replacement, text_line)

        updated_sents.append(rewritten)

    print("Augmented {} quotes: {}".format(sum(counts.values()), counts))
    return updated_sents
556
+
557
def find_text_idx(sentence):
    """
    Return the index of the # text line or -1
    """
    matches = (position for position, line in enumerate(sentence)
               if line.startswith("# text"))
    return next(matches, -1)
565
+
566
DIGIT_RE = re.compile("[0-9]")

def change_indices(line, delta):
    """
    Adjust all indices in the given sentence by delta.  Useful when removing a word, for example

    Handles MWT range ids ("3-4"), copy-node ids ("3.1"), and plain
    integer ids; also shifts the head column (6) and the first entry
    of the enhanced dependencies column (8).

    :param line: one conllu line (comments are returned unchanged)
    :param delta: integer offset added to every index
    :return: the adjusted line
    """
    # comment lines carry no indices
    if line.startswith("#"):
        return line

    pieces = line.split("\t")
    # MWT range line such as "3-4": shift both endpoints; range lines
    # have no head or deps columns to update, so return immediately
    if MWT_RE.match(pieces[0]):
        indices = pieces[0].split("-")
        pieces[0] = "%d-%d" % (int(indices[0]) + delta, int(indices[1]) + delta)
        line = "\t".join(pieces)
        return line

    # copy node such as "3.1": shift only the integer part
    if MWT_OR_COPY_RE.match(pieces[0]):
        index_pieces = pieces[0].split(".", maxsplit=1)
        pieces[0] = "%d.%s" % (int(index_pieces[0]) + delta, index_pieces[1])
    elif not INT_RE.match(pieces[0]):
        raise NotImplementedError("Unknown index type: %s" % pieces[0])
    else:
        pieces[0] = str(int(pieces[0]) + delta)
    if pieces[6] != '_':
        # copy nodes don't have basic dependencies in the es_ancora treebank
        dep = int(pieces[6])
        # head 0 is the root and must stay 0
        if dep != 0:
            pieces[6] = str(int(dep) + delta)
    if pieces[8] != '_':
        # enhanced deps of the form "head:rel"; only a single dep is supported
        dep_pieces = pieces[8].split(":", maxsplit=1)
        if DIGIT_RE.search(dep_pieces[1]):
            raise NotImplementedError("Need to handle multiple additional deps:\n%s" % line)
        if int(dep_pieces[0]) != 0:
            pieces[8] = str(int(dep_pieces[0]) + delta) + ":" + dep_pieces[1]
    line = "\t".join(pieces)
    return line
602
+
603
def augment_initial_punct(sents, ratio=0.20):
    """
    If a sentence starts with certain punct marks, occasionally use the same sentence without the initial punct.

    Currently this just handles ¿
    This helps languages such as CA and ES where the models go awry when the initial ¿ is missing.

    :param sents: list of sentences, each a list of conllu lines
    :param ratio: fraction of sentences considered for augmentation
    :return: sents plus the copies with the leading ¿ removed
    """
    new_sents = []
    for sent in sents:
        if random.random() > ratio:
            continue

        text_idx = find_text_idx(sent)
        text_line = sent[text_idx]
        if text_line.count("¿") != 1:
            # only handle sentences with exactly one ¿
            continue

        # find the first line with actual text
        for idx, line in enumerate(sent):
            if line.startswith("#"):
                continue
            break
        if idx >= len(sent) - 1:
            raise ValueError("Unexpectedly an entire sentence is comments")
        pieces = line.split("\t")
        if pieces[1] != '¿':
            continue
        # remove the following space as well, unless ¿ was glued to the next word
        if has_space_after_no(pieces[-1]):
            replace_text = "¿"
        else:
            replace_text = "¿ "

        # drop the ¿ token and fix the # text line accordingly
        new_sent = sent[:idx] + sent[idx+1:]
        new_sent[text_idx] = text_line.replace(replace_text, "")

        # now need to update all indices
        new_sent = [change_indices(x, -1) for x in new_sent]
        new_sents.append(new_sent)

    if len(new_sents) > 0:
        print("Added %d sentences with the leading ¿ removed" % len(new_sents))

    return sents + new_sents
647
+
648
+
649
def augment_brackets(sents, ratio=0.1):
    """
    If there are no sentences with [], transform some () into []
    """
    # a single square bracket anywhere means no augmentation is needed
    for sent in sents:
        text_line = sent[find_text_idx(sent)]
        if "[" in text_line or "]" in text_line:
            # found a square bracket, so, never mind
            return sents

    new_sents = []
    for sent in sents:
        if random.random() > ratio:
            continue

        text_idx = find_text_idx(sent)
        text_line = sent[text_idx]
        if "(" not in text_line and ")" not in text_line:
            continue

        converted = list(sent)
        converted[text_idx] = text_line.replace("(", "[").replace(")", "]")
        for idx, line in enumerate(converted):
            if line.startswith("#"):
                continue
            fields = line.split("\t")
            if fields[1] == '(':
                fields[1] = '['
            elif fields[1] == ')':
                fields[1] = ']'
            converted[idx] = "\t".join(fields)
        new_sents.append(converted)

    if len(new_sents) > 0:
        print("Added %d sentences with parens replaced with square brackets" % len(new_sents))

    return sents + new_sents
688
+
689
+
690
def augment_punct(sents):
    """
    Run the full chain of punctuation augmentations.

    If there are no instances of ’ in the dataset, but there are instances of ',
    we replace some fraction of ' with ’ so that the tokenizer will recognize it.

    Also augments with ... / …
    """
    augmented = augment_apos(sents)
    for augmentation in (augment_quotes,
                         augment_move_comma,
                         augment_comma_separations,
                         augment_initial_punct,
                         augment_ellipses,
                         augment_brackets):
        augmented = augmentation(augmented)
    return augmented
706
+
707
+
708
+
709
def write_augmented_dataset(input_conllu, output_conllu, augment_function):
    """
    Read a conllu file, augment its sentences, and write the result.

    The RNG is re-seeded per data file so that the results are the
    same regardless of how many treebanks are processed at once.
    """
    random.seed(1234)

    # read conllu data, produce new sentences, and write them out
    sentences = read_sentences_from_conllu(input_conllu)
    augmented = augment_function(sentences)
    write_sentences_to_conllu(output_conllu, augmented)
721
+
722
def remove_spaces_from_sentences(sents):
    """
    Makes sure every word in the list of sentences has SpaceAfter=No.

    Returns a new list of sentences

    :param sents: list of sentences, each a list of conllu lines
    :raises ValueError: if a word has a MISC column which is neither
        "_" nor already contains SpaceAfter=No
    """
    new_sents = []
    for sentence in sents:
        new_sentence = []
        for word in sentence:
            if word.startswith("#"):
                new_sentence.append(word)
                continue
            pieces = word.split("\t")
            if pieces[-1] == "_":
                pieces[-1] = "SpaceAfter=No"
            elif pieces[-1].find("SpaceAfter=No") >= 0:
                pass
            else:
                # bug fix: this used to raise ValueError("oops"), which gave
                # no hint about which word caused the problem
                raise ValueError("Cannot set SpaceAfter=No on word with unexpected MISC column: %s" % word)
            word = "\t".join(pieces)
            new_sentence.append(word)
        new_sents.append(new_sentence)
    return new_sents
746
+
747
def remove_spaces(input_conllu, output_conllu):
    """
    Turns a dataset into something appropriate for building a segmenter.

    Every word is rewritten with SpaceAfter=No.  For example, this
    works well on the Korean datasets.
    """
    sentences = read_sentences_from_conllu(input_conllu)
    write_sentences_to_conllu(output_conllu, remove_spaces_from_sentences(sentences))
758
+
759
+
760
def build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset, output_conllu):
    """
    Builds a combined dataset out of multiple Korean datasets.

    Currently this uses GSD and Kaist.  If a segmenter-appropriate
    dataset was requested (short_name ends with _seg), spaces are removed.

    TODO: we need to handle the difference in xpos tags somehow.
    """
    sents = []
    for treebank in ("UD_Korean-GSD", "UD_Korean-Kaist"):
        conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu")
        sents.extend(read_sentences_from_conllu(conllu))

    if short_name.endswith("_seg"):
        sents = remove_spaces_from_sentences(sents)

    write_sentences_to_conllu(output_conllu, sents)
778
+
779
def build_combined_korean(udbase_dir, tokenizer_dir, short_name):
    """Build the train/dev/test splits of the combined Korean dataset."""
    for split in ("train", "dev", "test"):
        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, split)
        build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, split, output_conllu)
783
+
784
def build_combined_italian_dataset(paths, model_type, dataset):
    """
    it_combined: ISDT and VIT for training (plus the Twitter treebanks
    for non-tokenizer models); dev and test come from ISDT alone.
    """
    udbase_dir = paths["UDBASE"]
    if dataset != 'train':
        istd_conllu = common.find_treebank_dataset_file("UD_Italian-ISDT", udbase_dir, dataset, "conllu")
        return read_sentences_from_conllu(istd_conllu)

    # could maybe add ParTUT, but that dataset has a slightly different xpos set
    # (no DE or I)
    # and I didn't feel like sorting through the differences
    # Note: currently these each have small changes compared with
    # the UD2.11 release.  See the issues (possibly closed by now)
    # filed by AngledLuffa on each of the treebanks for more info.
    treebanks = ["UD_Italian-ISDT", "UD_Italian-VIT"]
    if model_type is not common.ModelType.TOKENIZER:
        treebanks.extend(["UD_Italian-TWITTIRO", "UD_Italian-PoSTWITA"])
    print("Building {} dataset out of {}".format(model_type, " ".join(treebanks)))

    sents = []
    for treebank in treebanks:
        conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
        sents.extend(read_sentences_from_conllu(conllu_file))
    return sents
812
+
813
def check_gum_ready(udbase_dir):
    """Raise if UD_English-GUMReddit still has its proprietary text redacted."""
    gum_conllu = common.find_treebank_dataset_file("UD_English-GUMReddit", udbase_dir, "train", "conllu")
    if not common.mostly_underscores(gum_conllu):
        return
    raise ValueError("Cannot process UD_English-GUMReddit in its current form. There should be a download script available in the directory which will help integrate the missing proprietary values. Please run that script to update the data, then try again.")
817
+
818
def build_combined_english_dataset(paths, model_type, dataset):
    """
    en_combined is currently EWT, GUM, PUD, Pronouns, and handparsed
    """
    udbase_dir = paths["UDBASE"]
    check_gum_ready(udbase_dir)

    if dataset != 'train':
        # dev/test are EWT only
        ewt_conllu = common.find_treebank_dataset_file("UD_English-EWT", udbase_dir, dataset, "conllu")
        return read_sentences_from_conllu(ewt_conllu)

    # TODO: include more UD treebanks, possibly with xpos removed
    #   UD_English-ParTUT - xpos are different
    # also include "external" treebanks such as PTB
    # NOTE: in order to get the best results, make sure each of these treebanks have the latest edits applied
    sources = [("UD_English-EWT", "train"),
               ("UD_English-GUM", "train"),
               ("UD_English-GUMReddit", "train"),
               # PUD and Pronouns have no train split; their test data is used
               ("UD_English-PUD", "test"),
               ("UD_English-Pronouns", "test")]
    sents = []
    for treebank, split in sources:
        conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, split, "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        print("Read %d sentences from %s" % (len(new_sents), conllu_file))
        sents.extend(new_sents)
    return sents
848
+
849
def add_english_sentence_final_punctuation(handparsed_sentences):
    """
    Add a period to the end of a sentence with no punct at the end.

    The next-to-last word has SpaceAfter=No added as well.

    Possibly English-specific because of the xpos.  Could be upgraded
    to handle multiple languages by passing in the xpos as an argument

    :param handparsed_sentences: list of sentences, each a list of conllu lines
    :return: a new list of sentences, with "." appended where needed
    """
    new_sents = []
    for sent in handparsed_sentences:
        # track the root id and the highest word id so the new "." token
        # can be attached to the root with the next id
        root_id = None
        max_id = None
        last_punct = False
        for line in sent:
            if line.startswith("#"):
                continue
            pieces = line.split("\t")
            # skip MWT ranges and copy nodes when scanning word ids
            if MWT_OR_COPY_RE.match(pieces[0]):
                continue
            if pieces[6] == '0':
                root_id = pieces[0]
            max_id = int(pieces[0])
            # after the loop this reflects whether the LAST word is punct
            last_punct = pieces[3] == 'PUNCT'
        if not last_punct:
            new_sent = list(sent)
            # the previously-final word now has the "." attached to it
            pieces = new_sent[-1].split("\t")
            pieces[-1] = add_space_after_no(pieces[-1])
            new_sent[-1] = "\t".join(pieces)
            new_sent.append("%d\t.\t.\tPUNCT\t.\t_\t%s\tpunct\t%s:punct\t_" % (max_id+1, root_id, root_id))
            new_sents.append(new_sent)
        else:
            new_sents.append(sent)
    return new_sents
883
+
884
def build_extra_combined_french_dataset(paths, model_type, dataset):
    """
    Extra sentences we don't want augmented for French - currently, handparsed lemmas
    """
    handparsed_dir = paths["HANDPARSED_DIR"]
    # only the lemmatizer's training data gets the extra sentences
    if dataset != 'train' or model_type is not common.ModelType.LEMMA:
        return []

    handparsed_path = os.path.join(handparsed_dir, "french-lemmas", "fr_lemmas.conllu")
    handparsed_sentences = read_sentences_from_conllu(handparsed_path)
    print("Loaded %d sentences from %s" % (len(handparsed_sentences), handparsed_path))
    return list(handparsed_sentences)
897
+
898
+
899
def build_extra_combined_english_dataset(paths, model_type, dataset):
    """
    Extra sentences we don't want augmented

    For training, returns the handparsed English sentences (with
    sentence-final punctuation ensured), plus the extra lemma data
    when building a lemmatizer.  Returns [] otherwise.
    """
    handparsed_dir = paths["HANDPARSED_DIR"]
    sents = []
    if dataset == 'train':
        handparsed_path = os.path.join(handparsed_dir, "english-handparsed", "english.conll")
        handparsed_sentences = read_sentences_from_conllu(handparsed_path)
        # some handparsed sentences lack final punctuation; add it
        handparsed_sentences = add_english_sentence_final_punctuation(handparsed_sentences)
        sents.extend(handparsed_sentences)
        print("Loaded %d sentences from %s" % (len(sents), handparsed_path))

        if model_type is common.ModelType.LEMMA:
            handparsed_path = os.path.join(handparsed_dir, "english-lemmas", "en_lemmas.conllu")
            handparsed_sentences = read_sentences_from_conllu(handparsed_path)
            print("Loaded %d sentences from %s" % (len(handparsed_sentences), handparsed_path))
            sents.extend(handparsed_sentences)
    return sents
918
+
919
def build_extra_combined_italian_dataset(paths, model_type, dataset):
    """
    Extra data - the MWT data for Italian

    Loads the hand-retokenized multi-word file and marks each MWT line
    with SpaceAfter=No.  Returns [] for non-train datasets.

    :raises FileNotFoundError: if the italian.mwt file is missing
    :raises AssertionError: if the file's third lines are not blank-MISC MWT lines
    """
    handparsed_dir = paths["HANDPARSED_DIR"]
    if dataset != 'train':
        return []

    extra_italian = os.path.join(handparsed_dir, "italian-mwt", "italian.mwt")
    if not os.path.exists(extra_italian):
        raise FileNotFoundError("Cannot find the extra dataset 'italian.mwt' which includes various multi-words retokenized, expected {}".format(extra_italian))

    extra_sents = read_sentences_from_conllu(extra_italian)
    for sentence in extra_sents:
        # line 2 of each sentence is expected to be an MWT line with a
        # blank ("_") MISC column, which is rewritten to SpaceAfter=No
        if not sentence[2].endswith("_") or not MWT_RE.match(sentence[2]):
            # bug fix: message previously read "Has it already be modified"
            raise AssertionError("Unexpected format of the italian.mwt file. Has it already been modified to have SpaceAfter=No everywhere?")
        sentence[2] = sentence[2][:-1] + "SpaceAfter=No"
    print("Loaded %d sentences from %s" % (len(extra_sents), extra_italian))
    return extra_sents
938
+
939
def replace_semicolons(sentences):
    """
    Spanish GSD and AnCora have different standards for semicolons.

    GSD has semicolons at the end of sentences, AnCora has them in the middle as clause separators.
    Consecutive sentences in GSD do not seem to be related, so there is no combining that can be done.
    The easiest solution is to replace sentence final semicolons with "." in GSD

    :param sentences: list of sentences, each a list of conllu lines
    :return: a new list of sentences with sentence-final ";" replaced by "."
    :raises ValueError: if a sentence has no "# text" line
    """
    new_sents = []
    count = 0
    for sentence in sentences:
        for text_idx, text_line in enumerate(sentence):
            if text_line.startswith("# text"):
                break
        else:
            raise ValueError("Expected every sentence in GSD to have a # text field")
        if not text_line.endswith(";"):
            new_sents.append(sentence)
            continue
        new_sent = list(sentence)
        new_sent[text_idx] = text_line[:-1] + "."
        # assumes the final conllu line holds the ";" token
        new_sent[-1] = new_sent[-1].replace(";", ".")
        # bug fix: count was previously incremented twice per modified
        # sentence, doubling the reported number of updates
        count = count + 1
        new_sents.append(new_sent)
    print("Updated %d sentences to replace sentence-final ; with ." % count)
    return new_sents
966
+
967
def strip_column(sents, column):
    """
    Removes a specified column from the given dataset

    Particularly useful when mixing two different POS formalisms in the same tagger
    """
    def blank_field(word):
        # comment lines have no columns to blank
        if word.startswith("#"):
            return word
        fields = word.split("\t")
        fields[column] = "_"
        return "\t".join(fields)

    return [[blank_field(word) for word in sentence] for sentence in sents]
985
+
986
def strip_xpos(sents):
    """
    Blank out the xpos column (index 4) for every word in the dataset.

    Particularly useful when mixing two different POS formalisms in the same tagger
    """
    return strip_column(sents, 4)
993
+
994
def strip_feats(sents):
    """
    Blank out the feats column (index 5) for every word in the dataset.

    Particularly useful when mixing two different POS formalisms in the same tagger
    """
    return strip_column(sents, 5)
1001
+
1002
def build_combined_albanian_dataset(paths, model_type, dataset):
    """
    sq_combined is STAF as the base, with TSA added for some things

    For POS training the two treebanks are returned as separate
    documents, with TSA's xpos and feats stripped (its annotations do
    not match STAF).  For other non-depparse models the sentences are
    concatenated.  Dev/test (and depparse train) come from STAF alone.
    """
    udbase_dir = paths["UDBASE"]
    udbase_git_dir = paths["UDBASE_GIT"]
    handparsed_dir = paths["HANDPARSED_DIR"]

    treebanks = ["UD_Albanian-STAF", "UD_Albanian-TSA"]

    if dataset == 'train' and model_type == common.ModelType.POS:
        # keep the treebanks as separate documents so the tagger can
        # treat TSA's stripped columns as missing data
        documents = {}

        conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, "train", "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        documents[treebanks[0]] = new_sents

        # we use udbase_git_dir for TSA because of an updated MWT scheme
        conllu_file = common.find_treebank_dataset_file(treebanks[1], udbase_git_dir, "test", "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        new_sents = strip_xpos(new_sents)
        new_sents = strip_feats(new_sents)
        documents[treebanks[1]] = new_sents

        return documents

    if dataset == 'train' and model_type is not common.ModelType.DEPPARSE:
        sents = []

        conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, "train", "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        print("Read %d sentences from %s" % (len(new_sents), conllu_file))
        sents.extend(new_sents)

        # TSA has no train split; its test data is folded into training
        conllu_file = common.find_treebank_dataset_file(treebanks[1], udbase_git_dir, "test", "conllu", fail=True)
        new_sents = read_sentences_from_conllu(conllu_file)
        print("Read %d sentences from %s" % (len(new_sents), conllu_file))
        sents.extend(new_sents)

        return sents

    # dev/test (and depparse training) use STAF only
    conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, dataset, "conllu", fail=True)
    sents = read_sentences_from_conllu(conllu_file)
    return sents
1046
+
1047
+ def build_combined_spanish_dataset(paths, model_type, dataset):
1048
+ """
1049
+ es_combined is AnCora and GSD put together
1050
+
1051
+ For POS training, we put the different datasets into a zip file so
1052
+ that we can keep the conllu files separate and remove the xpos
1053
+ from the non-AnCora training files. It is necessary to remove the
1054
+ xpos because GSD and PUD both use different xpos schemes from
1055
+ AnCora, and the tagger can use additional data files as training
1056
+ data without a specific column if that column is entirely blank
1057
+
1058
+ TODO: consider mixing in PUD?
1059
+ """
1060
+ udbase_dir = paths["UDBASE"]
1061
+ handparsed_dir = paths["HANDPARSED_DIR"]
1062
+
1063
+ treebanks = ["UD_Spanish-AnCora", "UD_Spanish-GSD"]
1064
+
1065
+ if dataset == 'train' and model_type == common.ModelType.POS:
1066
+ documents = {}
1067
+ for treebank in treebanks:
1068
+ conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
1069
+ new_sents = read_sentences_from_conllu(conllu_file)
1070
+ if not treebank.endswith("AnCora"):
1071
+ new_sents = strip_xpos(new_sents)
1072
+ documents[treebank] = new_sents
1073
+
1074
+ return documents
1075
+
1076
+ if dataset == 'train':
1077
+ sents = []
1078
+ for treebank in treebanks:
1079
+ conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
1080
+ new_sents = read_sentences_from_conllu(conllu_file)
1081
+ print("Read %d sentences from %s" % (len(new_sents), conllu_file))
1082
+ if treebank.endswith("GSD"):
1083
+ new_sents = replace_semicolons(new_sents)
1084
+ sents.extend(new_sents)
1085
+
1086
+ if model_type in (common.ModelType.TOKENIZER, common.ModelType.MWT, common.ModelType.LEMMA):
1087
+ extra_spanish = os.path.join(handparsed_dir, "spanish-mwt", "adjectives.conllu")
1088
+ if not os.path.exists(extra_spanish):
1089
+ raise FileNotFoundError("Cannot find the extra dataset 'adjectives.conllu' which includes various multi-words retokenized, expected {}".format(extra_spanish))
1090
+ extra_sents = read_sentences_from_conllu(extra_spanish)
1091
+ print("Read %d sentences from %s" % (len(extra_sents), extra_spanish))
1092
+ sents.extend(extra_sents)
1093
+ else:
1094
+ conllu_file = common.find_treebank_dataset_file("UD_Spanish-AnCora", udbase_dir, dataset, "conllu", fail=True)
1095
+ sents = read_sentences_from_conllu(conllu_file)
1096
+
1097
+ return sents
1098
+
1099
+ def build_combined_french_dataset(paths, model_type, dataset):
1100
+ udbase_dir = paths["UDBASE"]
1101
+ handparsed_dir = paths["HANDPARSED_DIR"]
1102
+ if dataset == 'train':
1103
+ train_treebanks = ["UD_French-GSD", "UD_French-ParisStories", "UD_French-Rhapsodie", "UD_French-Sequoia"]
1104
+ sents = []
1105
+ for treebank in train_treebanks:
1106
+ conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu", fail=True)
1107
+ new_sents = read_sentences_from_conllu(conllu_file)
1108
+ print("Read %d sentences from %s" % (len(new_sents), conllu_file))
1109
+ sents.extend(new_sents)
1110
+
1111
+ extra_french = os.path.join(handparsed_dir, "french-handparsed", "handparsed_deps.conllu")
1112
+ if not os.path.exists(extra_french):
1113
+ raise FileNotFoundError("Cannot find the extra dataset 'handparsed_deps.conllu' which includes various dependency fixes, expected {}".format(extra_french))
1114
+ extra_sents = read_sentences_from_conllu(extra_french)
1115
+ print("Read %d sentences from %s" % (len(extra_sents), extra_french))
1116
+ sents.extend(extra_sents)
1117
+ else:
1118
+ gsd_conllu = common.find_treebank_dataset_file("UD_French-GSD", udbase_dir, dataset, "conllu")
1119
+ sents = read_sentences_from_conllu(gsd_conllu)
1120
+
1121
+ return sents
1122
+
1123
+ def build_combined_hebrew_dataset(paths, model_type, dataset):
1124
+ """
1125
+ Combines the IAHLT treebank with an updated form of HTB where the annotation style more closely matches IAHLT
1126
+
1127
+ Currently the updated HTB is not in UD, so you will need to clone
1128
+ git@github.com:IAHLT/UD_Hebrew.git to $UDBASE_GIT
1129
+
1130
+ dev and test sets will be those from IAHLT
1131
+ """
1132
+ udbase_dir = paths["UDBASE"]
1133
+ udbase_git_dir = paths["UDBASE_GIT"]
1134
+
1135
+ treebanks = ["UD_Hebrew-IAHLTwiki", "UD_Hebrew-IAHLTknesset"]
1136
+ if dataset == 'train':
1137
+ sents = []
1138
+ for treebank in treebanks:
1139
+ conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
1140
+ new_sents = read_sentences_from_conllu(conllu_file)
1141
+ print("Read %d sentences from %s" % (len(new_sents), conllu_file))
1142
+ sents.extend(new_sents)
1143
+
1144
+ # if/when this gets ported back to UD, switch to getting both datasets from UD
1145
+ hebrew_git_dir = os.path.join(udbase_git_dir, "UD_Hebrew")
1146
+ if not os.path.exists(hebrew_git_dir):
1147
+ raise FileNotFoundError("Please download git@github.com:IAHLT/UD_Hebrew.git to %s (based on $UDBASE_GIT)" % hebrew_git_dir)
1148
+ conllu_file = os.path.join(hebrew_git_dir, "he_htb-ud-train.conllu")
1149
+ if not os.path.exists(conllu_file):
1150
+ raise FileNotFoundError("Found %s but inexplicably there was no %s" % (hebrew_git_dir, conllu_file))
1151
+ new_sents = read_sentences_from_conllu(conllu_file)
1152
+ print("Read %d sentences from %s" % (len(new_sents), conllu_file))
1153
+ sents.extend(new_sents)
1154
+ else:
1155
+ conllu_file = common.find_treebank_dataset_file(treebanks[0], udbase_dir, dataset, "conllu", fail=True)
1156
+ sents = read_sentences_from_conllu(conllu_file)
1157
+
1158
+ return sents
1159
+
1160
+ COMBINED_FNS = {
1161
+ "en_combined": build_combined_english_dataset,
1162
+ "es_combined": build_combined_spanish_dataset,
1163
+ "fr_combined": build_combined_french_dataset,
1164
+ "he_combined": build_combined_hebrew_dataset,
1165
+ "it_combined": build_combined_italian_dataset,
1166
+ "sq_combined": build_combined_albanian_dataset,
1167
+ }
1168
+
1169
+ # some extra data for the combined models without augmenting
1170
+ COMBINED_EXTRA_FNS = {
1171
+ "en_combined": build_extra_combined_english_dataset,
1172
+ "fr_combined": build_extra_combined_french_dataset,
1173
+ "it_combined": build_extra_combined_italian_dataset,
1174
+ }
1175
+
1176
+ def build_combined_dataset(paths, short_name, model_type, augment):
1177
+ random.seed(1234)
1178
+ tokenizer_dir = paths["TOKENIZE_DATA_DIR"]
1179
+ build_fn = COMBINED_FNS[short_name]
1180
+ extra_fn = COMBINED_EXTRA_FNS.get(short_name, None)
1181
+ for dataset in ("train", "dev", "test"):
1182
+ output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
1183
+ sents = build_fn(paths, model_type, dataset)
1184
+ if isinstance(sents, dict):
1185
+ if dataset == 'train' and augment:
1186
+ for filename in list(sents.keys()):
1187
+ sents[filename] = augment_punct(sents[filename])
1188
+ output_zip = os.path.splitext(output_conllu)[0] + ".zip"
1189
+ with zipfile.ZipFile(output_zip, "w") as zout:
1190
+ for filename in list(sents.keys()):
1191
+ with zout.open(filename + ".conllu", "w") as zfout:
1192
+ with io.TextIOWrapper(zfout, encoding='utf-8', newline='') as fout:
1193
+ write_sentences_to_file(fout, sents[filename])
1194
+ else:
1195
+ if dataset == 'train' and augment:
1196
+ sents = augment_punct(sents)
1197
+ if extra_fn is not None:
1198
+ sents.extend(extra_fn(paths, model_type, dataset))
1199
+ write_sentences_to_conllu(output_conllu, sents)
1200
+
1201
+ BIO_DATASETS = ("en_craft", "en_genia", "en_mimic")
1202
+
1203
+ def build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_name, model_type, augment):
1204
+ """
1205
+ Process the en bio datasets
1206
+
1207
+ Creates a dataset by combining the en_combined data with one of the bio sets
1208
+ """
1209
+ random.seed(1234)
1210
+ name, bio_dataset = short_name.split("_")
1211
+ assert name == 'en'
1212
+ for dataset in ("train", "dev", "test"):
1213
+ output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
1214
+ if dataset == 'train':
1215
+ sents = build_combined_english_dataset(paths, model_type, dataset)
1216
+ if dataset == 'train' and augment:
1217
+ sents = augment_punct(sents)
1218
+ else:
1219
+ sents = []
1220
+ bio_file = os.path.join(paths["BIO_UD_DIR"], "UD_English-%s" % bio_dataset.upper(), "en_%s-ud-%s.conllu" % (bio_dataset.lower(), dataset))
1221
+ sents.extend(read_sentences_from_conllu(bio_file))
1222
+ write_sentences_to_conllu(output_conllu, sents)
1223
+
1224
+ def build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, dataset, augment):
1225
+ """
1226
+ Build the GUM dataset by combining GUMReddit
1227
+
1228
+ It checks to make sure GUMReddit is filled out using the included script
1229
+ """
1230
+ check_gum_ready(udbase_dir)
1231
+ random.seed(1234)
1232
+
1233
+ output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
1234
+
1235
+ treebanks = ["UD_English-GUM", "UD_English-GUMReddit"]
1236
+ sents = []
1237
+ for treebank in treebanks:
1238
+ conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
1239
+ sents.extend(read_sentences_from_conllu(conllu_file))
1240
+
1241
+ if dataset == 'train' and augment:
1242
+ sents = augment_punct(sents)
1243
+
1244
+ write_sentences_to_conllu(output_conllu, sents)
1245
+
1246
+ def build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, augment):
1247
+ for dataset in ("train", "dev", "test"):
1248
+ build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, dataset, augment)
1249
+
1250
+ def prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, dataset, augment=True, input_conllu=None, output_conllu=None):
1251
+ if input_conllu is None:
1252
+ input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
1253
+ if output_conllu is None:
1254
+ output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
1255
+ print("Reading from %s and writing to %s" % (input_conllu, output_conllu))
1256
+
1257
+ if short_name == "te_mtg" and dataset == 'train' and augment:
1258
+ write_augmented_dataset(input_conllu, output_conllu, augment_telugu)
1259
+ elif short_name == "ar_padt" and dataset == 'train' and augment:
1260
+ write_augmented_dataset(input_conllu, output_conllu, augment_arabic_padt)
1261
+ elif short_name.startswith("ko_") and short_name.endswith("_seg"):
1262
+ remove_spaces(input_conllu, output_conllu)
1263
+ elif dataset == 'train' and augment:
1264
+ write_augmented_dataset(input_conllu, output_conllu, augment_punct)
1265
+ else:
1266
+ sents = read_sentences_from_conllu(input_conllu)
1267
+ write_sentences_to_conllu(output_conllu, sents)
1268
+
1269
+ def process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, augment=True):
1270
+ """
1271
+ Process a normal UD treebank with train/dev/test splits
1272
+
1273
+ SL-SSJ and other datasets with inline modifications all use this code path as well.
1274
+ """
1275
+ prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, "train", augment)
1276
+ prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, "dev", augment)
1277
+ prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, "test", augment)
1278
+
1279
+
1280
+ XV_RATIO = 0.2
1281
+
1282
+ def process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language):
1283
+ """
1284
+ Process a UD treebank with only train/test splits
1285
+
1286
+ For example, in UD 2.7:
1287
+ UD_Buryat-BDT
1288
+ UD_Galician-TreeGal
1289
+ UD_Indonesian-CSUI
1290
+ UD_Kazakh-KTB
1291
+ UD_Kurmanji-MG
1292
+ UD_Latin-Perseus
1293
+ UD_Livvi-KKPP
1294
+ UD_North_Sami-Giella
1295
+ UD_Old_Russian-RNC
1296
+ UD_Sanskrit-Vedic
1297
+ UD_Slovenian-SST
1298
+ UD_Upper_Sorbian-UFAL
1299
+ UD_Welsh-CCG
1300
+ """
1301
+ train_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu")
1302
+ test_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "test", "conllu")
1303
+
1304
+ train_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "train")
1305
+ dev_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "dev")
1306
+ test_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "test")
1307
+
1308
+ if (common.num_words_in_file(train_input_conllu) <= 1000 and
1309
+ common.num_words_in_file(test_input_conllu) > 5000):
1310
+ train_input_conllu, test_input_conllu = test_input_conllu, train_input_conllu
1311
+
1312
+ if not split_train_file(treebank=treebank,
1313
+ train_input_conllu=train_input_conllu,
1314
+ train_output_conllu=train_output_conllu,
1315
+ dev_output_conllu=dev_output_conllu):
1316
+ return
1317
+
1318
+ # the test set is already fine
1319
+ # currently we do not do any augmentation of these partial treebanks
1320
+ prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, "test", augment=False, input_conllu=test_input_conllu, output_conllu=test_output_conllu)
1321
+
1322
+ def add_specific_args(parser):
1323
+ parser.add_argument('--no_augment', action='store_false', dest='augment', default=True,
1324
+ help='Augment the dataset in various ways')
1325
+ parser.add_argument('--no_prepare_labels', action='store_false', dest='prepare_labels', default=True,
1326
+ help='Prepare tokenizer and MWT labels. Expensive, but obviously necessary for training those models.')
1327
+ convert_th_lst20.add_lst20_args(parser)
1328
+
1329
+ convert_vi_vlsp.add_vlsp_args(parser)
1330
+
1331
+ def process_treebank(treebank, model_type, paths, args):
1332
+ """
1333
+ Processes a single treebank into train, dev, test parts
1334
+
1335
+ Includes processing for a few external tokenization datasets:
1336
+ vi_vlsp, th_orchid, th_best
1337
+
1338
+ Also, there is no specific mechanism for UD_Arabic-NYUAD or
1339
+ similar treebanks, which need integration with LDC datasets
1340
+ """
1341
+ udbase_dir = paths["UDBASE"]
1342
+ tokenizer_dir = paths["TOKENIZE_DATA_DIR"]
1343
+ handparsed_dir = paths["HANDPARSED_DIR"]
1344
+
1345
+ short_name = treebank_to_short_name(treebank)
1346
+ short_language = short_name.split("_")[0]
1347
+
1348
+ os.makedirs(tokenizer_dir, exist_ok=True)
1349
+
1350
+ if short_name == "my_alt":
1351
+ convert_my_alt.convert_my_alt(paths["CONSTITUENCY_BASE"], tokenizer_dir)
1352
+ elif short_name == "vi_vlsp":
1353
+ convert_vi_vlsp.convert_vi_vlsp(paths["STANZA_EXTERN_DIR"], tokenizer_dir, args)
1354
+ elif short_name == "th_orchid":
1355
+ convert_th_orchid.main(paths["STANZA_EXTERN_DIR"], tokenizer_dir)
1356
+ elif short_name == "th_lst20":
1357
+ convert_th_lst20.convert(paths["STANZA_EXTERN_DIR"], tokenizer_dir, args)
1358
+ elif short_name == "th_best":
1359
+ convert_th_best.main(paths["STANZA_EXTERN_DIR"], tokenizer_dir)
1360
+ elif short_name == "ml_cochin":
1361
+ convert_ml_cochin.main(paths["STANZA_EXTERN_DIR"], tokenizer_dir)
1362
+ elif short_name.startswith("ko_combined"):
1363
+ build_combined_korean(udbase_dir, tokenizer_dir, short_name)
1364
+ elif short_name in COMBINED_FNS: # eg "it_combined", "en_combined", etc
1365
+ build_combined_dataset(paths, short_name, model_type, args.augment)
1366
+ elif short_name in BIO_DATASETS:
1367
+ build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_name, model_type, args.augment)
1368
+ elif short_name.startswith("en_gum"):
1369
+ # we special case GUM because it should include a filled-out GUMReddit
1370
+ print("Preparing data for %s: %s, %s" % (treebank, short_name, short_language))
1371
+ build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, args.augment)
1372
+ else:
1373
+ # check that we can find the train file where we expect it
1374
+ train_conllu_file = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu", fail=True)
1375
+
1376
+ print("Preparing data for %s: %s, %s" % (treebank, short_name, short_language))
1377
+
1378
+ if not common.find_treebank_dataset_file(treebank, udbase_dir, "dev", "conllu", fail=False):
1379
+ process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language)
1380
+ else:
1381
+ process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, args.augment)
1382
+
1383
+ if model_type is common.ModelType.TOKENIZER or model_type is common.ModelType.MWT:
1384
+ if not short_name in ('th_orchid', 'th_lst20'):
1385
+ common.convert_conllu_to_txt(tokenizer_dir, short_name)
1386
+
1387
+ if args.prepare_labels:
1388
+ common.prepare_tokenizer_treebank_labels(tokenizer_dir, short_name)
1389
+
1390
+
1391
+ def main():
1392
+ common.main(process_treebank, common.ModelType.TOKENIZER, add_specific_args)
1393
+
1394
+ if __name__ == '__main__':
1395
+ main()
1396
+
stanza/stanza/utils/datasets/pretrain/__init__.py ADDED
File without changes
stanza/stanza/utils/datasets/tokenization/__init__.py ADDED
File without changes
stanza/stanza/utils/datasets/tokenization/convert_vi_vlsp.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+
4
+ punctuation_set = (',', '.', '!', '?', ')', ':', ';', '”', '…', '...')
5
+
6
+ def find_spaces(sentence):
7
+ # TODO: there are some sentences where there is only one quote,
8
+ # and some of them should be attached to the previous word instead
9
+ # of the next word. Training should work this way, though
10
+ odd_quotes = False
11
+
12
+ spaces = []
13
+ for word_idx, word in enumerate(sentence):
14
+ space = True
15
+ # Quote period at the end of a sentence needs to be attached
16
+ # to the rest of the text. Some sentences have `"... text`
17
+ # in the middle, though, so look for that
18
+ if word_idx < len(sentence) - 2 and sentence[word_idx+1] == '"':
19
+ if sentence[word_idx+2] == '.':
20
+ space = False
21
+ elif word_idx == len(sentence) - 3 and sentence[word_idx+2] == '...':
22
+ space = False
23
+ if word_idx < len(sentence) - 1:
24
+ if sentence[word_idx+1] in (',', '.', '!', '?', ')', ':', ';', '”', '…', '...','/', '%'):
25
+ space = False
26
+ if word in ('(', '“', '/'):
27
+ space = False
28
+ if word == '"':
29
+ if odd_quotes:
30
+ # already saw one quote. put this one at the end of the PREVIOUS word
31
+ # note that we know there must be at least one word already
32
+ odd_quotes = False
33
+ spaces[word_idx-1] = False
34
+ else:
35
+ odd_quotes = True
36
+ space = False
37
+ spaces.append(space)
38
+ return spaces
39
+
40
+ def add_vlsp_args(parser):
41
+ parser.add_argument('--include_pos_data', action='store_true', default=False, help='To include or not POS training dataset for tokenization training. The path to POS dataset is expected to be in the same dir with WS path. For example, extern_dir/vietnamese/VLSP2013-POS-data')
42
+ parser.add_argument('--vlsp_include_spaces', action='store_true', default=False, help='When processing vi_vlsp tokenization, include all of the spaces. Otherwise, we try to turn the text back into standard text')
43
+
44
+
45
+ def write_file(vlsp_include_spaces, output_filename, sentences, shard):
46
+ with open(output_filename, "w") as fout:
47
+ check_headlines = False
48
+ for sent_idx, sentence in enumerate(sentences):
49
+ fout.write("# sent_id = %s.%d\n" % (shard, sent_idx))
50
+ orig_text = " ".join(sentence)
51
+ #check if the previous line is a headline (no ending mark at the end) then make this sentence a new par
52
+ if check_headlines:
53
+ fout.write("# newpar id =%s.%d.1\n" % (shard, sent_idx))
54
+ check_headlines = False
55
+ if sentence[len(sentence) - 1] not in punctuation_set:
56
+ check_headlines = True
57
+
58
+ if vlsp_include_spaces:
59
+ fout.write("# text = %s\n" % orig_text)
60
+ else:
61
+ spaces = find_spaces(sentence)
62
+ full_text = ""
63
+ for word, space in zip(sentence, spaces):
64
+ # could be made more efficient, but shouldn't matter
65
+ full_text = full_text + word
66
+ if space:
67
+ full_text = full_text + " "
68
+ fout.write("# text = %s\n" % full_text)
69
+ fout.write("# orig_text = %s\n" % orig_text)
70
+ for word_idx, word in enumerate(sentence):
71
+ fake_dep = "root" if word_idx == 0 else "dep"
72
+ fout.write("%d\t%s\t%s" % ((word_idx+1), word, word))
73
+ fout.write("\t_\t_\t_")
74
+ fout.write("\t%d\t%s" % (word_idx, fake_dep))
75
+ fout.write("\t_\t")
76
+ if vlsp_include_spaces or spaces[word_idx]:
77
+ fout.write("_")
78
+ else:
79
+ fout.write("SpaceAfter=No")
80
+ fout.write("\n")
81
+ fout.write("\n")
82
+
83
+ def convert_pos_dataset(file_path):
84
+ """
85
+ This function is to process the pos dataset
86
+ """
87
+
88
+ file = open(file_path, "r")
89
+ document = file.readlines()
90
+ sentences = []
91
+ sent = []
92
+ for line in document:
93
+ if line == "\n" and len(sent)>1:
94
+ if sent not in sentences:
95
+ sentences.append(sent)
96
+ sent = []
97
+ elif line != "\n":
98
+ sent.append(line.split("\t")[0].replace("_"," ").strip())
99
+ return sentences
100
+
101
+ def convert_file(vlsp_include_spaces, input_filename, output_filename, shard, split_filename=None, split_shard=None, pos_data = None):
102
+ with open(input_filename) as fin:
103
+ lines = fin.readlines()
104
+
105
+ sentences = []
106
+ set_sentences = set()
107
+ for line in lines:
108
+ if len(line.replace("_", " ").split())>1:
109
+ words = line.split()
110
+ #one syllable lines are eliminated
111
+ if len(words) == 1 and len(words[0].split("_")) == 1:
112
+ continue
113
+ else:
114
+ words = [w.replace("_", " ") for w in words]
115
+ #only add sentences that hasn't been added before
116
+ if words not in sentences:
117
+ sentences.append(words)
118
+ set_sentences.add(' '.join(words))
119
+
120
+ if split_filename is not None:
121
+ # even this is a larger dev set than the train set
122
+ split_point = int(len(sentences) * 0.95)
123
+ #check pos_data that aren't overlapping with current VLSP WS dataset
124
+ sentences_pos = [] if pos_data is None else [sent for sent in pos_data if ' '.join(sent) not in set_sentences]
125
+ print("Added ", len(sentences_pos), " sentences from POS dataset.")
126
+ write_file(vlsp_include_spaces, output_filename, sentences[:split_point]+sentences_pos, shard)
127
+ write_file(vlsp_include_spaces, split_filename, sentences[split_point:], split_shard)
128
+ else:
129
+ write_file(vlsp_include_spaces, output_filename, sentences, shard)
130
+
131
+ def convert_vi_vlsp(extern_dir, tokenizer_dir, args):
132
+ input_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-WS-data")
133
+ input_pos_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-POS-data")
134
+ input_train_filename = os.path.join(input_path, "VLSP2013_WS_train_gold.txt")
135
+ input_test_filename = os.path.join(input_path, "VLSP2013_WS_test_gold.txt")
136
+
137
+ input_pos_filename = os.path.join(input_pos_path, "VLSP2013_POS_train_BI_POS_Column.txt.goldSeg")
138
+ if not os.path.exists(input_train_filename):
139
+ raise FileNotFoundError("Cannot find train set for VLSP at %s" % input_train_filename)
140
+ if not os.path.exists(input_test_filename):
141
+ raise FileNotFoundError("Cannot find test set for VLSP at %s" % input_test_filename)
142
+ pos_data = None
143
+ if args.include_pos_data:
144
+ if not os.path.exists(input_pos_filename):
145
+ raise FileNotFoundError("Cannot find pos dataset for VLSP at %s" % input_pos_filename)
146
+ else:
147
+ pos_data = convert_pos_dataset(input_pos_filename)
148
+
149
+ output_train_filename = os.path.join(tokenizer_dir, "vi_vlsp.train.gold.conllu")
150
+ output_dev_filename = os.path.join(tokenizer_dir, "vi_vlsp.dev.gold.conllu")
151
+ output_test_filename = os.path.join(tokenizer_dir, "vi_vlsp.test.gold.conllu")
152
+
153
+ convert_file(args.vlsp_include_spaces, input_train_filename, output_train_filename, "train", output_dev_filename, "dev", pos_data)
154
+ convert_file(args.vlsp_include_spaces, input_test_filename, output_test_filename, "test")
155
+
stanza/stanza/utils/ner/spacy_ner_tag_dataset.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test a spacy model on a 4 class dataset
3
+ """
4
+
5
+ import argparse
6
+ import json
7
+
8
+ import spacy
9
+ from spacy.tokens import Doc
10
+
11
+ from stanza.models.ner.utils import process_tags
12
+ from stanza.models.ner.scorer import score_by_entity, score_by_token
13
+
14
+ from stanza.utils.confusion import format_confusion
15
+ from stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide
16
+
17
+ from stanza.utils.get_tqdm import get_tqdm
18
+ tqdm = get_tqdm()
19
+
20
+ """
21
+ Simplified classes used in the Worldwide dataset are:
22
+
23
+ Date
24
+ Facility
25
+ Location
26
+ Misc
27
+ Money
28
+ NORP
29
+ Organization
30
+ Person
31
+ Product
32
+
33
+ vs OntoNotes classes:
34
+
35
+ CARDINAL
36
+ DATE
37
+ EVENT
38
+ FAC
39
+ GPE
40
+ LANGUAGE
41
+ LAW
42
+ LOC
43
+ MONEY
44
+ NORP
45
+ ORDINAL
46
+ ORG
47
+ PERCENT
48
+ PERSON
49
+ PRODUCT
50
+ QUANTITY
51
+ TIME
52
+ WORK_OF_ART
53
+ """
54
+
55
+ def test_file(eval_file, tagger, simplify):
56
+ with open(eval_file) as fin:
57
+ gold_doc = json.load(fin)
58
+ gold_doc = [[(x['text'], x['ner']) for x in sentence] for sentence in gold_doc]
59
+ gold_doc = process_tags(gold_doc, 'bioes')
60
+
61
+ if simplify:
62
+ for doc in gold_doc:
63
+ for idx, word in enumerate(doc):
64
+ if word[1] != "O":
65
+ word = [word[0], simplify_ontonotes_to_worldwide(word[1])]
66
+ doc[idx] = word
67
+
68
+ ignore_tags = "Date,DATE" if simplify else None
69
+
70
+ original_text = [[x[0] for x in gold_sentence] for gold_sentence in gold_doc]
71
+ pred_doc = []
72
+ for sentence in tqdm(original_text):
73
+ spacy_sentence = Doc(tagger.vocab, sentence)
74
+ spacy_sentence = tagger(spacy_sentence)
75
+ entities = ["O" if not token.ent_type_ else "%s-%s" % (token.ent_iob_, token.ent_type_) for token in spacy_sentence]
76
+ if simplify:
77
+ entities = [simplify_ontonotes_to_worldwide(x) for x in entities]
78
+ pred_sentence = [[token.text, entity] for token, entity in zip(spacy_sentence, entities)]
79
+ pred_doc.append(pred_sentence)
80
+
81
+ pred_doc = process_tags(pred_doc, 'bioes')
82
+ pred_tags = [[x[1] for x in sentence] for sentence in pred_doc]
83
+ gold_tags = [[x[1] for x in sentence] for sentence in gold_doc]
84
+ print("RESULTS ON: %s" % eval_file)
85
+ _, _, f_micro, _ = score_by_entity(pred_tags, gold_tags, ignore_tags=ignore_tags)
86
+ _, _, _, confusion = score_by_token(pred_tags, gold_tags, ignore_tags=ignore_tags)
87
+ print("NER token confusion matrix:\n{}".format(format_confusion(confusion, hide_blank=True, transpose=True)))
88
+ return f_micro
89
+
90
+ def main():
91
+ parser = argparse.ArgumentParser()
92
+ parser.add_argument('--ner_model', type=str, default=None, help='Which spacy model to test')
93
+ parser.add_argument('filename', type=str, nargs='*', help='which files to test')
94
+ parser.add_argument('--simplify', default=False, action='store_true', help='Simplify classes to the 8 class Worldwide model')
95
+ args = parser.parse_args()
96
+
97
+ if args.ner_model is None:
98
+ ner_models = ['en_core_web_sm', 'en_core_web_trf']
99
+ else:
100
+ ner_models = [args.ner_model]
101
+
102
+ if not args.filename:
103
+ args.filename = ["data/ner/en_ontonotes-8class.test.json",
104
+ "data/ner/en_worldwide-8class.test.json",
105
+ "data/ner/en_worldwide-8class-africa.test.json",
106
+ "data/ner/en_worldwide-8class-asia.test.json",
107
+ "data/ner/en_worldwide-8class-indigenous.test.json",
108
+ "data/ner/en_worldwide-8class-latam.test.json",
109
+ "data/ner/en_worldwide-8class-middle_east.test.json"]
110
+
111
+ print("Processing the files: %s" % ",".join(args.filename))
112
+
113
+ results = []
114
+ model_results = {}
115
+
116
+ for ner_model in ner_models:
117
+ model_results[ner_model] = []
118
+ # load tagger
119
+ print("-----------------------------")
120
+ print("Running %s" % ner_model)
121
+ print("-----------------------------")
122
+ tagger = spacy.load(ner_model, disable=["tagger", "parser", "attribute_ruler", "lemmatizer"])
123
+
124
+ for filename in args.filename:
125
+ f_micro = test_file(filename, tagger, args.simplify)
126
+ f_micro = "%.2f" % (f_micro * 100)
127
+ results.append((ner_model, filename, f_micro))
128
+ model_results[ner_model].append(f_micro)
129
+
130
+ for result in results:
131
+ print(result)
132
+
133
+ for model in model_results.keys():
134
+ result = [model] + model_results[model]
135
+ print(" & ".join(result))
136
+
137
+ if __name__ == '__main__':
138
+ main()
stanza/stanza/utils/training/__init__.py ADDED
File without changes
stanza/stanza/utils/training/remove_constituency_optimizer.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Saved a huge, bloated model with an optimizer? Use this to remove it, greatly shrinking the model size
2
+
3
+ This tries to find reasonable defaults for word vectors and charlm
4
+ (which need to be loaded so that the model knows the matrix sizes)
5
+
6
+ so ideally all that needs to be run is
7
+
8
+ python3 stanza/utils/training/remove_constituency_optimizer.py <treebanks>
9
+ python3 stanza/utils/training/remove_constituency_optimizer.py da_arboretum ...
10
+
11
+ This can also be used to load and save models as part of an update
12
+ to the serialized format
13
+ """
14
+
15
+ import argparse
16
+ import logging
17
+ import os
18
+
19
+ from stanza.models import constituency_parser
20
+ from stanza.models.common.constant import treebank_to_short_name
21
+ from stanza.resources.default_packages import default_charlms, default_pretrains
22
+ from stanza.utils.training import common
23
+
24
+ logger = logging.getLogger('stanza')
25
+
26
+ def parse_args():
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
29
+ parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')
30
+ parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package")
31
+
32
+ parser.add_argument('--load_dir', type=str, default="saved_models/constituency", help="Root dir for getting the models to resave.")
33
+ parser.add_argument('--save_dir', type=str, default="resaved_models/constituency", help="Root dir for resaving the models.")
34
+
35
+ parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. Use all_ud or ud_all for all UD treebanks')
36
+
37
+ args = parser.parse_args()
38
+ return args
39
+
40
+ def main():
41
+ """
42
+ For each of the models specified, load and resave the model
43
+
44
+ The resaved model will have the optimizer removed
45
+ """
46
+ args = parse_args()
47
+ os.makedirs(args.save_dir, exist_ok=True)
48
+
49
+ for treebank in args.treebanks:
50
+ logger.info("PROCESSING %s", treebank)
51
+ short_name = treebank_to_short_name(treebank)
52
+ language, dataset = short_name.split("_", maxsplit=1)
53
+ logger.info("%s: %s %s", short_name, language, dataset)
54
+
55
+ if not args.wordvec_pretrain_file:
56
+ # will throw an error if the pretrain can't be found
57
+ wordvec_pretrain = common.find_wordvec_pretrain(language, default_pretrains)
58
+ wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]
59
+ else:
60
+ wordvec_args = []
61
+
62
+ charlm = common.choose_charlm(language, dataset, args.charlm, default_charlms, {})
63
+ charlm_args = common.build_charlm_args(language, charlm, base_args=False)
64
+
65
+ base_name = '{}_constituency.pt'.format(short_name)
66
+ load_name = os.path.join(args.load_dir, base_name)
67
+ save_name = os.path.join(args.save_dir, base_name)
68
+ resave_args = ['--mode', 'remove_optimizer',
69
+ '--load_name', load_name,
70
+ '--save_name', save_name,
71
+ '--save_dir', ".",
72
+ '--shorthand', short_name]
73
+ resave_args = resave_args + wordvec_args + charlm_args
74
+ constituency_parser.main(resave_args)
75
+
76
+ if __name__ == '__main__':
77
+ main()
stanza/stanza/utils/visualization/dependency_visualization.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions to visualize dependency relations in texts and Stanza documents
3
+ """
4
+
5
+ from stanza.models.common.constant import is_right_to_left
6
+ import stanza
7
+ import spacy
8
+ from spacy import displacy
9
+ from spacy.tokens import Doc
10
+
11
+
12
def visualize_doc(doc, language):
    """
    Takes in a Document and visualizes it using displacy.

    The document to visualize must be from the stanza pipeline (run with
    depparse, so that each word has a head and a deprel).

    right-to-left languages such as Arabic are displayed right-to-left
    based on the language code: the display order of the words is
    reversed while the dependency arcs stay attached to the same words.
    """
    visualization_options = {"compact": True, "bg": "#09a3d5", "color": "white", "distance": 90,
                             "font": "Source Sans Pro", "arrow_spacing": 25}
    # blank model - we don't use any of the model features, just the viz
    nlp = spacy.blank("en")
    rtl = is_right_to_left(language)
    sentences_to_visualize = []
    for sentence in doc.sentences:
        sent_len = len(sentence.words)
        # for RTL languages the display order of words is reversed;
        # dependency arcs remain intact
        display_words = list(reversed(sentence.words)) if rtl else list(sentence.words)

        def display_index(stanza_id):
            # map a 1-based Stanza word id to its 0-based display position
            # (positions count from the right for RTL rendering)
            return sent_len - stanza_id if rtl else stanza_id - 1

        words = [word.text for word in display_words]
        lemmas = [word.lemma for word in display_words]
        deps = [word.deprel for word in display_words]
        tags = [word.upos for word in display_words]
        # spaCy head indexes are formatted differently than Stanza's:
        # Stanza marks the root with head == 0, while spaCy roots point
        # at their own position
        heads = [display_index(word.id if word.head == 0 else word.head)
                 for word in display_words]
        sentences_to_visualize.append(Doc(nlp.vocab, words=words, lemmas=lemmas,
                                          heads=heads, deps=deps, pos=tags))

    for line in sentences_to_visualize:  # render all sentences through displaCy
        # If this program is NOT being run in a Jupyter notebook, replace displacy.render with displacy.serve
        # and the visualization will be hosted locally, link being provided in the program output.
        displacy.render(line, style="dep", options=visualization_options)
55
+
56
+
57
def visualize_str(text, pipeline_code, pipe):
    """
    Takes a string and visualizes it using displacy.

    The string is processed using the stanza pipeline and its
    dependencies are formatted into a spaCy doc object for easy
    visualization. Accepts valid stanza (UD) pipelines as the pipeline
    argument. Must supply the stanza pipeline code (the two-letter
    abbreviation of the language, such as 'en' for English. Must also
    supply the stanza pipeline object as the third argument.
    """
    processed = pipe(text)
    visualize_doc(processed, pipeline_code)
70
+
71
+
72
def visualize_docs(docs, lang_code):
    """
    Visualize the dependency relationships in a list of Stanza documents.

    lang_code is a language code such as 'en' for English.  Rendering is
    delegated to visualize_doc (spaCy visualizations); see that function
    for more details.
    """
    for document in docs:
        visualize_doc(document, lang_code)
81
+
82
+
83
def visualize_strings(texts, lang_code):
    """
    Visualize the dependency relationships of several raw strings.

    Loads the Stanza pipeline for the given language code (ex: 'en' for
    English) once, then runs each string through visualize_str.
    """
    pipe = stanza.Pipeline(lang_code, processors="tokenize,pos,lemma,depparse")
    for item in texts:
        visualize_str(item, lang_code, pipe)
93
+
94
+
95
def main():
    """Demo: visualize sample sentences in RTL (Arabic) and LTR (English, Chinese) languages."""
    ar_strings = ['برلين ترفض حصول شركة اميركية على رخصة تصنيع دبابة "ليوبارد" الالمانية', "هل بإمكاني مساعدتك؟",
                  "أراك في مابعد", "لحظة من فضلك"]
    en_strings = ["This is a sentence.",
                  "Barack Obama was born in Hawaii. He was elected President of the United States in 2008."]
    zh_strings = ["中国是一个很有意思的国家。"]
    # Testing with right to left language
    visualize_strings(ar_strings, "ar")
    # Testing with left to right languages
    visualize_strings(en_strings, "en")
    visualize_strings(zh_strings, "zh")

if __name__ == '__main__':
    main()