Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- stanza/stanza/tests/classifiers/test_data.py +130 -0
- stanza/stanza/tests/constituency/test_tree_stack.py +50 -0
- stanza/stanza/tests/data/external_server.properties +1 -0
- stanza/stanza/tests/lemma/test_lowercase.py +57 -0
- stanza/stanza/tests/ner/test_bsf_2_beios.py +349 -0
- stanza/stanza/tests/ner/test_ner_training.py +261 -0
- stanza/stanza/tests/pipeline/pipeline_device_tests.py +45 -0
- stanza/stanza/tests/pipeline/test_pipeline_sentiment_processor.py +50 -0
- stanza/stanza/tests/pipeline/test_requirements.py +72 -0
- stanza/stanza/tests/tokenization/__init__.py +0 -0
- stanza/stanza/tests/tokenization/test_tokenize_utils.py +220 -0
- stanza/stanza/utils/charlm/__init__.py +0 -0
- stanza/stanza/utils/charlm/conll17_to_text.py +93 -0
- stanza/stanza/utils/charlm/dump_oscar.py +120 -0
- stanza/stanza/utils/charlm/make_lm_data.py +162 -0
- stanza/stanza/utils/constituency/check_transitions.py +27 -0
- stanza/stanza/utils/constituency/list_tensors.py +16 -0
- stanza/stanza/utils/datasets/__init__.py +0 -0
- stanza/stanza/utils/datasets/contract_mwt.py +46 -0
- stanza/stanza/utils/datasets/coref/__init__.py +0 -0
- stanza/stanza/utils/datasets/coref/convert_ontonotes.py +80 -0
- stanza/stanza/utils/datasets/coref/convert_udcoref.py +276 -0
- stanza/stanza/utils/datasets/coref/utils.py +148 -0
- stanza/stanza/utils/datasets/corenlp_segmenter_dataset.py +78 -0
- stanza/stanza/utils/datasets/ner/convert_bsnlp.py +333 -0
- stanza/stanza/utils/datasets/ner/convert_fire_2013.py +118 -0
- stanza/stanza/utils/datasets/ner/convert_hy_armtdp.py +145 -0
- stanza/stanza/utils/datasets/ner/convert_kk_kazNERD.py +35 -0
- stanza/stanza/utils/datasets/ner/convert_my_ucsy.py +102 -0
- stanza/stanza/utils/datasets/ner/convert_sindhi_siner.py +69 -0
- stanza/stanza/utils/datasets/ner/convert_starlang_ner.py +55 -0
- stanza/stanza/utils/datasets/ner/ontonotes_multitag.py +97 -0
- stanza/stanza/utils/datasets/ner/prepare_ner_file.py +78 -0
- stanza/stanza/utils/datasets/ner/utils.py +417 -0
- stanza/stanza/utils/datasets/vietnamese/__init__.py +0 -0
- stanza/stanza/utils/pretrain/compare_pretrains.py +54 -0
- stanza/stanza/utils/training/common.py +397 -0
- stanza/stanza/utils/training/compose_ete_results.py +100 -0
- stanza/stanza/utils/training/run_charlm.py +86 -0
- stanza/stanza/utils/training/run_constituency.py +130 -0
- stanza/stanza/utils/training/run_depparse.py +133 -0
- stanza/stanza/utils/training/run_lemma.py +179 -0
- stanza/stanza/utils/training/run_lemma_classifier.py +87 -0
- stanza/stanza/utils/training/run_mwt.py +122 -0
- stanza/stanza/utils/training/run_ner.py +159 -0
- stanza/stanza/utils/training/run_sentiment.py +118 -0
- stanza/stanza/utils/training/run_tokenizer.py +124 -0
- stanza/stanza/utils/training/separate_ner_pretrain.py +215 -0
- stanza/stanza/utils/visualization/__init__.py +0 -0
- stanza/stanza/utils/visualization/conll_deprel_visualization.py +83 -0
stanza/stanza/tests/classifiers/test_data.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import pytest
|
| 3 |
+
|
| 4 |
+
import stanza.models.classifiers.data as data
|
| 5 |
+
from stanza.models.classifiers.utils import WVType
|
| 6 |
+
from stanza.models.common.vocab import PAD, UNK
|
| 7 |
+
from stanza.models.constituency.parse_tree import Tree
|
| 8 |
+
|
| 9 |
+
SENTENCES = [
|
| 10 |
+
["I", "hate", "the", "Opal", "banning"],
|
| 11 |
+
["Tell", "my", "wife", "hello"], # obviously this is the neutral result
|
| 12 |
+
["I", "like", "Sh'reyan", "'s", "antennae"],
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
DATASET = [
|
| 16 |
+
{"sentiment": "0", "text": SENTENCES[0]},
|
| 17 |
+
{"sentiment": "1", "text": SENTENCES[1]},
|
| 18 |
+
{"sentiment": "2", "text": SENTENCES[2]},
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
TREES = [
|
| 22 |
+
"(ROOT (S (NP (PRP I)) (VP (VBP hate) (NP (DT the) (NN Opal) (NN banning)))))",
|
| 23 |
+
"(ROOT (S (VP (VB Tell) (NP (PRP$ my) (NN wife)) (NP (UH hello)))))",
|
| 24 |
+
"(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
DATASET_WITH_TREES = [
|
| 28 |
+
{"sentiment": "0", "text": SENTENCES[0], "constituency": TREES[0]},
|
| 29 |
+
{"sentiment": "1", "text": SENTENCES[1], "constituency": TREES[1]},
|
| 30 |
+
{"sentiment": "2", "text": SENTENCES[2], "constituency": TREES[2]},
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
@pytest.fixture(scope="module")
|
| 34 |
+
def train_file(tmp_path_factory):
|
| 35 |
+
train_set = DATASET * 20
|
| 36 |
+
train_filename = tmp_path_factory.mktemp("data") / "train.json"
|
| 37 |
+
with open(train_filename, "w", encoding="utf-8") as fout:
|
| 38 |
+
json.dump(train_set, fout, ensure_ascii=False)
|
| 39 |
+
return train_filename
|
| 40 |
+
|
| 41 |
+
@pytest.fixture(scope="module")
|
| 42 |
+
def dev_file(tmp_path_factory):
|
| 43 |
+
dev_set = DATASET * 2
|
| 44 |
+
dev_filename = tmp_path_factory.mktemp("data") / "dev.json"
|
| 45 |
+
with open(dev_filename, "w", encoding="utf-8") as fout:
|
| 46 |
+
json.dump(dev_set, fout, ensure_ascii=False)
|
| 47 |
+
return dev_filename
|
| 48 |
+
|
| 49 |
+
@pytest.fixture(scope="module")
|
| 50 |
+
def test_file(tmp_path_factory):
|
| 51 |
+
test_set = DATASET
|
| 52 |
+
test_filename = tmp_path_factory.mktemp("data") / "test.json"
|
| 53 |
+
with open(test_filename, "w", encoding="utf-8") as fout:
|
| 54 |
+
json.dump(test_set, fout, ensure_ascii=False)
|
| 55 |
+
return test_filename
|
| 56 |
+
|
| 57 |
+
@pytest.fixture(scope="module")
|
| 58 |
+
def train_file_with_trees(tmp_path_factory):
|
| 59 |
+
train_set = DATASET_WITH_TREES * 20
|
| 60 |
+
train_filename = tmp_path_factory.mktemp("data") / "train_trees.json"
|
| 61 |
+
with open(train_filename, "w", encoding="utf-8") as fout:
|
| 62 |
+
json.dump(train_set, fout, ensure_ascii=False)
|
| 63 |
+
return train_filename
|
| 64 |
+
|
| 65 |
+
@pytest.fixture(scope="module")
|
| 66 |
+
def dev_file_with_trees(tmp_path_factory):
|
| 67 |
+
dev_set = DATASET_WITH_TREES * 2
|
| 68 |
+
dev_filename = tmp_path_factory.mktemp("data") / "dev_trees.json"
|
| 69 |
+
with open(dev_filename, "w", encoding="utf-8") as fout:
|
| 70 |
+
json.dump(dev_set, fout, ensure_ascii=False)
|
| 71 |
+
return dev_filename
|
| 72 |
+
|
| 73 |
+
class TestClassifierData:
|
| 74 |
+
def test_read_data(self, train_file):
|
| 75 |
+
"""
|
| 76 |
+
Test reading of the json format
|
| 77 |
+
"""
|
| 78 |
+
train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
|
| 79 |
+
assert len(train_set) == 60
|
| 80 |
+
|
| 81 |
+
def test_read_data_with_trees(self, train_file, train_file_with_trees):
|
| 82 |
+
"""
|
| 83 |
+
Test reading of the json format
|
| 84 |
+
"""
|
| 85 |
+
train_trees_set = data.read_dataset(str(train_file_with_trees), WVType.OTHER, 1)
|
| 86 |
+
assert len(train_trees_set) == 60
|
| 87 |
+
for idx, x in enumerate(train_trees_set):
|
| 88 |
+
assert isinstance(x.constituency, Tree)
|
| 89 |
+
assert str(x.constituency) == TREES[idx % len(TREES)]
|
| 90 |
+
|
| 91 |
+
train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
|
| 92 |
+
|
| 93 |
+
def test_dataset_vocab(self, train_file):
|
| 94 |
+
"""
|
| 95 |
+
Converting a dataset to vocab should have a specific set of words along with PAD and UNK
|
| 96 |
+
"""
|
| 97 |
+
train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
|
| 98 |
+
vocab = data.dataset_vocab(train_set)
|
| 99 |
+
expected = set([PAD, UNK] + [x.lower() for y in SENTENCES for x in y])
|
| 100 |
+
assert set(vocab) == expected
|
| 101 |
+
|
| 102 |
+
def test_dataset_labels(self, train_file):
|
| 103 |
+
"""
|
| 104 |
+
Test the extraction of labels from a dataset
|
| 105 |
+
"""
|
| 106 |
+
train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
|
| 107 |
+
labels = data.dataset_labels(train_set)
|
| 108 |
+
assert labels == ["0", "1", "2"]
|
| 109 |
+
|
| 110 |
+
def test_sort_by_length(self, train_file):
|
| 111 |
+
"""
|
| 112 |
+
There are two unique lengths in the toy dataset
|
| 113 |
+
"""
|
| 114 |
+
train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
|
| 115 |
+
sorted_dataset = data.sort_dataset_by_len(train_set)
|
| 116 |
+
assert list(sorted_dataset.keys()) == [4, 5]
|
| 117 |
+
assert len(sorted_dataset[4]) == len(train_set) // 3
|
| 118 |
+
assert len(sorted_dataset[5]) == 2 * len(train_set) // 3
|
| 119 |
+
|
| 120 |
+
def test_check_labels(self, train_file):
|
| 121 |
+
"""
|
| 122 |
+
Check that an exception is thrown for an unknown label
|
| 123 |
+
"""
|
| 124 |
+
train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
|
| 125 |
+
labels = sorted(set([x["sentiment"] for x in DATASET]))
|
| 126 |
+
assert len(labels) > 1
|
| 127 |
+
data.check_labels(labels, train_set)
|
| 128 |
+
with pytest.raises(RuntimeError):
|
| 129 |
+
data.check_labels(labels[:1], train_set)
|
| 130 |
+
|
stanza/stanza/tests/constituency/test_tree_stack.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from stanza.models.constituency.tree_stack import TreeStack
|
| 4 |
+
|
| 5 |
+
from stanza.tests import *
|
| 6 |
+
|
| 7 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 8 |
+
|
| 9 |
+
def test_simple():
|
| 10 |
+
stack = TreeStack(value=5, parent=None, length=1)
|
| 11 |
+
stack = stack.push(3)
|
| 12 |
+
stack = stack.push(1)
|
| 13 |
+
|
| 14 |
+
expected_values = [1, 3, 5]
|
| 15 |
+
for value in expected_values:
|
| 16 |
+
assert stack.value == value
|
| 17 |
+
stack = stack.pop()
|
| 18 |
+
assert stack is None
|
| 19 |
+
|
| 20 |
+
def test_iter():
|
| 21 |
+
stack = TreeStack(value=5, parent=None, length=1)
|
| 22 |
+
stack = stack.push(3)
|
| 23 |
+
stack = stack.push(1)
|
| 24 |
+
|
| 25 |
+
stack_list = list(stack)
|
| 26 |
+
assert list(stack) == [1, 3, 5]
|
| 27 |
+
|
| 28 |
+
def test_str():
|
| 29 |
+
stack = TreeStack(value=5, parent=None, length=1)
|
| 30 |
+
stack = stack.push(3)
|
| 31 |
+
stack = stack.push(1)
|
| 32 |
+
|
| 33 |
+
assert str(stack) == "TreeStack(1, 3, 5)"
|
| 34 |
+
|
| 35 |
+
def test_len():
|
| 36 |
+
stack = TreeStack(value=5, parent=None, length=1)
|
| 37 |
+
assert len(stack) == 1
|
| 38 |
+
|
| 39 |
+
stack = stack.push(3)
|
| 40 |
+
stack = stack.push(1)
|
| 41 |
+
assert len(stack) == 3
|
| 42 |
+
|
| 43 |
+
def test_long_len():
|
| 44 |
+
"""
|
| 45 |
+
Original stack had a bug where this took exponential time...
|
| 46 |
+
"""
|
| 47 |
+
stack = TreeStack(value=0, parent=None, length=1)
|
| 48 |
+
for i in range(1, 40):
|
| 49 |
+
stack = stack.push(i)
|
| 50 |
+
assert len(stack) == 40
|
stanza/stanza/tests/data/external_server.properties
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
annotators = tokenize,ssplit,pos
|
stanza/stanza/tests/lemma/test_lowercase.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from stanza.models.lemmatizer import all_lowercase
|
| 4 |
+
from stanza.utils.conll import CoNLL
|
| 5 |
+
|
| 6 |
+
LATIN_CONLLU = """
|
| 7 |
+
# sent_id = train-s1
|
| 8 |
+
# text = unde et philosophus dicit felicitatem esse operationem perfectam.
|
| 9 |
+
# reference = ittb-scg-s4203
|
| 10 |
+
1 unde unde ADV O4 AdvType=Loc|PronType=Rel 4 advmod:lmod _ _
|
| 11 |
+
2 et et CCONJ O4 _ 3 advmod:emph _ _
|
| 12 |
+
3 philosophus philosophus NOUN B1|grn1|casA|gen1 Case=Nom|Gender=Masc|InflClass=IndEurO|Number=Sing 4 nsubj _ _
|
| 13 |
+
4 dicit dico VERB N3|modA|tem1|gen6 Aspect=Imp|InflClass=LatX|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens
|
| 14 |
+
5 felicitatem felicitas NOUN C1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 7 nsubj _ _
|
| 15 |
+
6 esse sum AUX N3|modH|tem1 Aspect=Imp|Tense=Pres|VerbForm=Inf 7 cop _ _
|
| 16 |
+
7 operationem operatio NOUN C1|grn1|casD|gen2|vgr1 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 4 ccomp _ _
|
| 17 |
+
8 perfectam perfectus ADJ A1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurA|Number=Sing 7 amod _ SpaceAfter=No
|
| 18 |
+
9 . . PUNCT Punc _ 4 punct _ _
|
| 19 |
+
|
| 20 |
+
# sent_id = train-s2
|
| 21 |
+
# text = perfectio autem operationis dependet ex quatuor.
|
| 22 |
+
# reference = ittb-scg-s4204
|
| 23 |
+
1 perfectio perfectio NOUN C1|grn1|casA|gen2 Case=Nom|Gender=Fem|InflClass=IndEurX|Number=Sing 4 nsubj _ _
|
| 24 |
+
2 autem autem PART O4 _ 4 discourse _ _
|
| 25 |
+
3 operationis operatio NOUN C1|grn1|casB|gen2|vgr1 Case=Gen|Gender=Fem|InflClass=IndEurX|Number=Sing 1 nmod _ _
|
| 26 |
+
4 dependet dependeo VERB K3|modA|tem1|gen6 Aspect=Imp|InflClass=LatE|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens
|
| 27 |
+
5 ex ex ADP S4|vgr2 _ 6 case _ _
|
| 28 |
+
6 quatuor quattuor NUM G1|gen3|vgr1 NumForm=Word|NumType=Card 4 obl:arg _ SpaceAfter=No
|
| 29 |
+
7 . . PUNCT Punc _ 4 punct _ _
|
| 30 |
+
""".lstrip()
|
| 31 |
+
|
| 32 |
+
ENG_CONLLU = """
|
| 33 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0007
|
| 34 |
+
# text = You wonder if he was manipulating the market with his bombing targets.
|
| 35 |
+
1 You you PRON PRP Case=Nom|Person=2|PronType=Prs 2 nsubj 2:nsubj _
|
| 36 |
+
2 wonder wonder VERB VBP Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin 0 root 0:root _
|
| 37 |
+
3 if if SCONJ IN _ 6 mark 6:mark _
|
| 38 |
+
4 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 6 nsubj 6:nsubj _
|
| 39 |
+
5 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
|
| 40 |
+
6 manipulating manipulate VERB VBG Tense=Pres|VerbForm=Part 2 ccomp 2:ccomp _
|
| 41 |
+
7 the the DET DT Definite=Def|PronType=Art 8 det 8:det _
|
| 42 |
+
8 market market NOUN NN Number=Sing 6 obj 6:obj _
|
| 43 |
+
9 with with ADP IN _ 12 case 12:case _
|
| 44 |
+
10 his his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 12 nmod:poss 12:nmod:poss _
|
| 45 |
+
11 bombing bombing NOUN NN Number=Sing 12 compound 12:compound _
|
| 46 |
+
12 targets target NOUN NNS Number=Plur 6 obl 6:obl:with SpaceAfter=No
|
| 47 |
+
13 . . PUNCT . _ 2 punct 2:punct _
|
| 48 |
+
""".lstrip()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_all_lowercase():
|
| 52 |
+
doc = CoNLL.conll2doc(input_str=LATIN_CONLLU)
|
| 53 |
+
assert all_lowercase(doc)
|
| 54 |
+
|
| 55 |
+
def test_not_all_lowercase():
|
| 56 |
+
doc = CoNLL.conll2doc(input_str=ENG_CONLLU)
|
| 57 |
+
assert not all_lowercase(doc)
|
stanza/stanza/tests/ner/test_bsf_2_beios.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests the conversion code for the lang_uk NER dataset
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import unittest
|
| 6 |
+
from stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 10 |
+
|
| 11 |
+
class TestBsf2Beios(unittest.TestCase):
|
| 12 |
+
|
| 13 |
+
def test_empty_markup(self):
|
| 14 |
+
res = convert_bsf('', '')
|
| 15 |
+
self.assertEqual('', res)
|
| 16 |
+
|
| 17 |
+
def test_1line_markup(self):
|
| 18 |
+
data = 'тележурналіст Василь'
|
| 19 |
+
bsf_markup = 'T1 PERS 14 20 Василь'
|
| 20 |
+
expected = '''тележурналіст O
|
| 21 |
+
Василь S-PERS'''
|
| 22 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup))
|
| 23 |
+
|
| 24 |
+
def test_1line_follow_markup(self):
|
| 25 |
+
data = 'тележурналіст Василь .'
|
| 26 |
+
bsf_markup = 'T1 PERS 14 20 Василь'
|
| 27 |
+
expected = '''тележурналіст O
|
| 28 |
+
Василь S-PERS
|
| 29 |
+
. O'''
|
| 30 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup))
|
| 31 |
+
|
| 32 |
+
def test_1line_2tok_markup(self):
|
| 33 |
+
data = 'тележурналіст Василь Нагірний .'
|
| 34 |
+
bsf_markup = 'T1 PERS 14 29 Василь Нагірний'
|
| 35 |
+
expected = '''тележурналіст O
|
| 36 |
+
Василь B-PERS
|
| 37 |
+
Нагірний E-PERS
|
| 38 |
+
. O'''
|
| 39 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup))
|
| 40 |
+
|
| 41 |
+
def test_1line_Long_tok_markup(self):
|
| 42 |
+
data = 'А в музеї Гуцульщини і Покуття можна '
|
| 43 |
+
bsf_markup = 'T12 ORG 4 30 музеї Гуцульщини і Покуття'
|
| 44 |
+
expected = '''А O
|
| 45 |
+
в O
|
| 46 |
+
музеї B-ORG
|
| 47 |
+
Гуцульщини I-ORG
|
| 48 |
+
і I-ORG
|
| 49 |
+
Покуття E-ORG
|
| 50 |
+
можна O'''
|
| 51 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup))
|
| 52 |
+
|
| 53 |
+
def test_2line_2tok_markup(self):
|
| 54 |
+
data = '''тележурналіст Василь Нагірний .
|
| 55 |
+
В івано-франківському видавництві «Лілея НВ» вийшла друком'''
|
| 56 |
+
bsf_markup = '''T1 PERS 14 29 Василь Нагірний
|
| 57 |
+
T2 ORG 67 75 Лілея НВ'''
|
| 58 |
+
expected = '''тележурналіст O
|
| 59 |
+
Василь B-PERS
|
| 60 |
+
Нагірний E-PERS
|
| 61 |
+
. O
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
В O
|
| 65 |
+
івано-франківському O
|
| 66 |
+
видавництві O
|
| 67 |
+
« O
|
| 68 |
+
Лілея B-ORG
|
| 69 |
+
НВ E-ORG
|
| 70 |
+
» O
|
| 71 |
+
вийшла O
|
| 72 |
+
друком O'''
|
| 73 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup))
|
| 74 |
+
|
| 75 |
+
def test_real_markup(self):
|
| 76 |
+
data = '''Через напіввоєнний стан в Україні та збільшення телефонних терористичних погроз українці купуватимуть sim-карти тільки за паспортами .
|
| 77 |
+
Про це повідомив начальник управління зв'язків зі ЗМІ адміністрації Держспецзв'язку Віталій Кукса .
|
| 78 |
+
Він зауважив , що днями відомство опублікує проект змін до правил надання телекомунікаційних послуг , де будуть прописані норми ідентифікації громадян .
|
| 79 |
+
Абонентів , які на сьогодні вже мають sim-карту , за словами Віталія Кукси , реєструватимуть , коли ті звертатимуться в службу підтримки свого оператора мобільного зв'язку .
|
| 80 |
+
Однак мобільні оператори побоюються , що таке нововведення помітно зменшить продаж стартових пакетів , адже спеціалізовані магазини є лише у містах .
|
| 81 |
+
Відтак купити сімку в невеликих населених пунктах буде неможливо .
|
| 82 |
+
Крім того , нова процедура ідентифікації абонентів вимагатиме від операторів мобільного зв'язку додаткових витрат .
|
| 83 |
+
- Близько 90 % українських абонентів - це абоненти передоплати .
|
| 84 |
+
Якщо мова буде йти навіть про поетапну їх ідентифікацію , зробити це буде складно , довго і дорого .
|
| 85 |
+
Мобільним операторам доведеться йти на чималі витрати , пов'язані з укладанням і зберіганням договорів , веденням баз даних , - розповіла « Економічній правді » начальник відділу зв'язків з громадськістю « МТС-Україна » Вікторія Рубан .
|
| 86 |
+
'''
|
| 87 |
+
bsf_markup = '''T1 LOC 26 33 Україні
|
| 88 |
+
T2 ORG 203 218 Держспецзв'язку
|
| 89 |
+
T3 PERS 219 232 Віталій Кукса
|
| 90 |
+
T4 PERS 449 462 Віталія Кукси
|
| 91 |
+
T5 ORG 1201 1219 Економічній правді
|
| 92 |
+
T6 ORG 1267 1278 МТС-Україна
|
| 93 |
+
T7 PERS 1281 1295 Вікторія Рубан
|
| 94 |
+
'''
|
| 95 |
+
expected = '''Через O
|
| 96 |
+
напіввоєнний O
|
| 97 |
+
стан O
|
| 98 |
+
в O
|
| 99 |
+
Україні S-LOC
|
| 100 |
+
та O
|
| 101 |
+
збільшення O
|
| 102 |
+
телефонних O
|
| 103 |
+
терористичних O
|
| 104 |
+
погроз O
|
| 105 |
+
українці O
|
| 106 |
+
купуватимуть O
|
| 107 |
+
sim-карти O
|
| 108 |
+
тільки O
|
| 109 |
+
за O
|
| 110 |
+
паспортами O
|
| 111 |
+
. O
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
Про O
|
| 115 |
+
це O
|
| 116 |
+
повідомив O
|
| 117 |
+
начальник O
|
| 118 |
+
управління O
|
| 119 |
+
зв'язків O
|
| 120 |
+
зі O
|
| 121 |
+
ЗМІ O
|
| 122 |
+
адміністрації O
|
| 123 |
+
Держспецзв'язку S-ORG
|
| 124 |
+
Віталій B-PERS
|
| 125 |
+
Кукса E-PERS
|
| 126 |
+
. O
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
Він O
|
| 130 |
+
зауважив O
|
| 131 |
+
, O
|
| 132 |
+
що O
|
| 133 |
+
днями O
|
| 134 |
+
відомство O
|
| 135 |
+
опублікує O
|
| 136 |
+
проект O
|
| 137 |
+
змін O
|
| 138 |
+
до O
|
| 139 |
+
правил O
|
| 140 |
+
надання O
|
| 141 |
+
телекомунікаційних O
|
| 142 |
+
послуг O
|
| 143 |
+
, O
|
| 144 |
+
де O
|
| 145 |
+
будуть O
|
| 146 |
+
прописані O
|
| 147 |
+
норми O
|
| 148 |
+
ідентифікації O
|
| 149 |
+
громадян O
|
| 150 |
+
. O
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
Абонентів O
|
| 154 |
+
, O
|
| 155 |
+
які O
|
| 156 |
+
на O
|
| 157 |
+
сьогодні O
|
| 158 |
+
вже O
|
| 159 |
+
мають O
|
| 160 |
+
sim-карту O
|
| 161 |
+
, O
|
| 162 |
+
за O
|
| 163 |
+
словами O
|
| 164 |
+
Віталія B-PERS
|
| 165 |
+
Кукси E-PERS
|
| 166 |
+
, O
|
| 167 |
+
реєструватимуть O
|
| 168 |
+
, O
|
| 169 |
+
коли O
|
| 170 |
+
ті O
|
| 171 |
+
звертатимуться O
|
| 172 |
+
в O
|
| 173 |
+
службу O
|
| 174 |
+
підтримки O
|
| 175 |
+
свого O
|
| 176 |
+
оператора O
|
| 177 |
+
мобільного O
|
| 178 |
+
зв'язку O
|
| 179 |
+
. O
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
Однак O
|
| 183 |
+
мобільні O
|
| 184 |
+
оператори O
|
| 185 |
+
побоюються O
|
| 186 |
+
, O
|
| 187 |
+
що O
|
| 188 |
+
таке O
|
| 189 |
+
нововведення O
|
| 190 |
+
помітно O
|
| 191 |
+
зменшить O
|
| 192 |
+
продаж O
|
| 193 |
+
стартових O
|
| 194 |
+
пакетів O
|
| 195 |
+
, O
|
| 196 |
+
адже O
|
| 197 |
+
спеціалізовані O
|
| 198 |
+
магазини O
|
| 199 |
+
є O
|
| 200 |
+
лише O
|
| 201 |
+
у O
|
| 202 |
+
містах O
|
| 203 |
+
. O
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
Відтак O
|
| 207 |
+
купити O
|
| 208 |
+
сімку O
|
| 209 |
+
в O
|
| 210 |
+
невеликих O
|
| 211 |
+
населених O
|
| 212 |
+
пунктах O
|
| 213 |
+
буде O
|
| 214 |
+
неможливо O
|
| 215 |
+
. O
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
Крім O
|
| 219 |
+
того O
|
| 220 |
+
, O
|
| 221 |
+
нова O
|
| 222 |
+
процедура O
|
| 223 |
+
ідентифікації O
|
| 224 |
+
абонентів O
|
| 225 |
+
вимагатиме O
|
| 226 |
+
від O
|
| 227 |
+
операторів O
|
| 228 |
+
мобільного O
|
| 229 |
+
зв'язку O
|
| 230 |
+
додаткових O
|
| 231 |
+
витрат O
|
| 232 |
+
. O
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
- O
|
| 236 |
+
Близько O
|
| 237 |
+
90 O
|
| 238 |
+
% O
|
| 239 |
+
українських O
|
| 240 |
+
абонентів O
|
| 241 |
+
- O
|
| 242 |
+
це O
|
| 243 |
+
абоненти O
|
| 244 |
+
передоплати O
|
| 245 |
+
. O
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
Якщо O
|
| 249 |
+
мова O
|
| 250 |
+
буде O
|
| 251 |
+
йти O
|
| 252 |
+
навіть O
|
| 253 |
+
про O
|
| 254 |
+
поетапну O
|
| 255 |
+
їх O
|
| 256 |
+
ідентифікацію O
|
| 257 |
+
, O
|
| 258 |
+
зробити O
|
| 259 |
+
це O
|
| 260 |
+
буде O
|
| 261 |
+
складно O
|
| 262 |
+
, O
|
| 263 |
+
довго O
|
| 264 |
+
і O
|
| 265 |
+
дорого O
|
| 266 |
+
. O
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
Мобільним O
|
| 270 |
+
операторам O
|
| 271 |
+
доведеться O
|
| 272 |
+
йти O
|
| 273 |
+
на O
|
| 274 |
+
чималі O
|
| 275 |
+
витрати O
|
| 276 |
+
, O
|
| 277 |
+
пов'язані O
|
| 278 |
+
з O
|
| 279 |
+
укладанням O
|
| 280 |
+
і O
|
| 281 |
+
зберіганням O
|
| 282 |
+
договорів O
|
| 283 |
+
, O
|
| 284 |
+
веденням O
|
| 285 |
+
баз O
|
| 286 |
+
даних O
|
| 287 |
+
, O
|
| 288 |
+
- O
|
| 289 |
+
розповіла O
|
| 290 |
+
« O
|
| 291 |
+
Економічній B-ORG
|
| 292 |
+
правді E-ORG
|
| 293 |
+
» O
|
| 294 |
+
начальник O
|
| 295 |
+
відділу O
|
| 296 |
+
зв'язків O
|
| 297 |
+
з O
|
| 298 |
+
громадськістю O
|
| 299 |
+
« O
|
| 300 |
+
МТС-Україна S-ORG
|
| 301 |
+
» O
|
| 302 |
+
Вікторія B-PERS
|
| 303 |
+
Рубан E-PERS
|
| 304 |
+
. O'''
|
| 305 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup))
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
class TestBsf(unittest.TestCase):
|
| 309 |
+
|
| 310 |
+
def test_empty_bsf(self):
|
| 311 |
+
self.assertEqual(parse_bsf(''), [])
|
| 312 |
+
|
| 313 |
+
def test_empty2_bsf(self):
|
| 314 |
+
self.assertEqual(parse_bsf(' \n \n'), [])
|
| 315 |
+
|
| 316 |
+
def test_1line_bsf(self):
|
| 317 |
+
bsf = 'T1 PERS 103 118 Василь Нагірний'
|
| 318 |
+
res = parse_bsf(bsf)
|
| 319 |
+
expected = BsfInfo('T1', 'PERS', 103, 118, 'Василь Нагірний')
|
| 320 |
+
self.assertEqual(len(res), 1)
|
| 321 |
+
self.assertEqual(res, [expected])
|
| 322 |
+
|
| 323 |
+
def test_2line_bsf(self):
|
| 324 |
+
bsf = '''T9 PERS 778 783 Карла
|
| 325 |
+
T10 MISC 814 819 міста'''
|
| 326 |
+
res = parse_bsf(bsf)
|
| 327 |
+
expected = [BsfInfo('T9', 'PERS', 778, 783, 'Карла'),
|
| 328 |
+
BsfInfo('T10', 'MISC', 814, 819, 'міста')]
|
| 329 |
+
self.assertEqual(len(res), 2)
|
| 330 |
+
self.assertEqual(res, expected)
|
| 331 |
+
|
| 332 |
+
def test_multiline_bsf(self):
|
| 333 |
+
bsf = '''T3 PERS 220 235 Андрієм Кіщуком
|
| 334 |
+
T4 MISC 251 285 А .
|
| 335 |
+
Kubler .
|
| 336 |
+
Світло і тіні маестро
|
| 337 |
+
T5 PERS 363 369 Кіблер'''
|
| 338 |
+
res = parse_bsf(bsf)
|
| 339 |
+
expected = [BsfInfo('T3', 'PERS', 220, 235, 'Андрієм Кіщуком'),
|
| 340 |
+
BsfInfo('T4', 'MISC', 251, 285, '''А .
|
| 341 |
+
Kubler .
|
| 342 |
+
Світло і тіні маестро'''),
|
| 343 |
+
BsfInfo('T5', 'PERS', 363, 369, 'Кіблер')]
|
| 344 |
+
self.assertEqual(len(res), len(expected))
|
| 345 |
+
self.assertEqual(res, expected)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
if __name__ == '__main__':
|
| 349 |
+
unittest.main()
|
stanza/stanza/tests/ner/test_ner_training.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 10 |
+
|
| 11 |
+
from stanza.models import ner_tagger
|
| 12 |
+
from stanza.models.ner.trainer import Trainer
|
| 13 |
+
from stanza.tests import TEST_WORKING_DIR
|
| 14 |
+
from stanza.utils.datasets.ner.prepare_ner_file import process_dataset
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger('stanza')
|
| 17 |
+
|
| 18 |
+
EN_TRAIN_BIO = """
|
| 19 |
+
Chris B-PERSON
|
| 20 |
+
Manning E-PERSON
|
| 21 |
+
is O
|
| 22 |
+
a O
|
| 23 |
+
good O
|
| 24 |
+
man O
|
| 25 |
+
. O
|
| 26 |
+
|
| 27 |
+
He O
|
| 28 |
+
works O
|
| 29 |
+
in O
|
| 30 |
+
Stanford B-ORG
|
| 31 |
+
University E-ORG
|
| 32 |
+
. O
|
| 33 |
+
""".lstrip().replace(" ", "\t")
|
| 34 |
+
|
| 35 |
+
EN_DEV_BIO = """
|
| 36 |
+
Chris B-PERSON
|
| 37 |
+
Manning E-PERSON
|
| 38 |
+
is O
|
| 39 |
+
part O
|
| 40 |
+
of O
|
| 41 |
+
Computer B-ORG
|
| 42 |
+
Science E-ORG
|
| 43 |
+
""".lstrip().replace(" ", "\t")
|
| 44 |
+
|
| 45 |
+
EN_TRAIN_2TAG = """
|
| 46 |
+
Chris B-PERSON B-PER
|
| 47 |
+
Manning E-PERSON E-PER
|
| 48 |
+
is O O
|
| 49 |
+
a O O
|
| 50 |
+
good O O
|
| 51 |
+
man O O
|
| 52 |
+
. O O
|
| 53 |
+
|
| 54 |
+
He O O
|
| 55 |
+
works O O
|
| 56 |
+
in O O
|
| 57 |
+
Stanford B-ORG B-ORG
|
| 58 |
+
University E-ORG B-ORG
|
| 59 |
+
. O O
|
| 60 |
+
""".strip().replace(" ", "\t")
|
| 61 |
+
|
| 62 |
+
EN_TRAIN_2TAG_EMPTY2 = """
|
| 63 |
+
Chris B-PERSON -
|
| 64 |
+
Manning E-PERSON -
|
| 65 |
+
is O -
|
| 66 |
+
a O -
|
| 67 |
+
good O -
|
| 68 |
+
man O -
|
| 69 |
+
. O -
|
| 70 |
+
|
| 71 |
+
He O -
|
| 72 |
+
works O -
|
| 73 |
+
in O -
|
| 74 |
+
Stanford B-ORG -
|
| 75 |
+
University E-ORG -
|
| 76 |
+
. O -
|
| 77 |
+
""".strip().replace(" ", "\t")
|
| 78 |
+
|
| 79 |
+
EN_DEV_2TAG = """
|
| 80 |
+
Chris B-PERSON B-PER
|
| 81 |
+
Manning E-PERSON E-PER
|
| 82 |
+
is O O
|
| 83 |
+
part O O
|
| 84 |
+
of O O
|
| 85 |
+
Computer B-ORG B-ORG
|
| 86 |
+
Science E-ORG E-ORG
|
| 87 |
+
""".strip().replace(" ", "\t")
|
| 88 |
+
|
| 89 |
+
@pytest.fixture(scope="module")
def pretrain_file():
    """Path of a tiny pretrained embedding file shared by the NER training tests."""
    return "{}/in/tiny_emb.pt".format(TEST_WORKING_DIR)
|
| 92 |
+
|
| 93 |
+
def write_temp_file(filename, bio_data):
    """Write bio_data to a sibling .bio file, then convert it to json at filename.

    The intermediate file has the same base name as filename with a .bio suffix.
    """
    base_name, _ = os.path.splitext(filename)
    raw_path = base_name + ".bio"
    with open(raw_path, "w", encoding="utf-8") as fout:
        fout.write(bio_data)
    # project conversion: .bio -> stanza NER json
    process_dataset(raw_path, filename)
|
| 98 |
+
|
| 99 |
+
def write_temp_2tag(filename, bio_data):
    """Convert tab-separated multi-tag BIO text into stanza's NER json and write it.

    bio_data: blank-line separated sentences, one "word\ttag1 tag2 ..." per line.
    Each word becomes {"text": ..., "multi_ner": [tag1, tag2, ...]}.
    """
    document = []
    for block in bio_data.split("\n\n"):
        words = []
        for line in block.split("\n"):
            # exactly one split: everything after the first tab is the tag column(s)
            form, tag_text = line.split("\t", maxsplit=1)
            words.append({"text": form, "multi_ner": tag_text.split()})
        document.append(words)

    with open(filename, "w", encoding="utf-8") as fout:
        json.dump(document, fout)
|
| 113 |
+
|
| 114 |
+
def get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args):
    """Assemble the ner_tagger command line for a short test training run.

    Any extra_args are appended verbatim after the standard options.
    """
    arg_list = [
        "--data_dir", str(tmp_path),
        "--wordvec_pretrain_file", pretrain_file,
        "--train_file", str(train_json),
        "--eval_file", str(dev_json),
        "--shorthand", "en_test",
        # keep the run short: few steps, one mid-run eval
        "--max_steps", "100",
        "--eval_interval", "40",
        "--save_dir", str(tmp_path / "models"),
    ]
    arg_list.extend(extra_args)
    return arg_list
|
| 127 |
+
|
| 128 |
+
def run_two_tag_training(pretrain_file, tmp_path, *extra_args, train_data=EN_TRAIN_2TAG):
    """Write two-tag train/dev json files under tmp_path and run a brief NER training.

    Returns the trainer produced by ner_tagger.main.
    """
    train_path = tmp_path / "en_test.train.json"
    dev_path = tmp_path / "en_test.dev.json"
    write_temp_2tag(train_path, train_data)
    write_temp_2tag(dev_path, EN_DEV_2TAG)

    return ner_tagger.main(get_args(tmp_path, pretrain_file, train_path, dev_path, *extra_args))
|
| 137 |
+
|
| 138 |
+
def test_basic_two_tag_training(pretrain_file, tmp_path):
    """
    Train briefly on the two-tag data and check the model has two tagging heads

    There should be one output layer, one loss function, and one tag
    vocab per tag column in the training data
    """
    trainer = run_two_tag_training(pretrain_file, tmp_path)
    assert len(trainer.model.tag_clfs) == 2
    assert len(trainer.model.crits) == 2
    assert len(trainer.vocab['tag'].lens()) == 2
|
| 143 |
+
|
| 144 |
+
def test_two_tag_training_backprop(pretrain_file, tmp_path):
    """
    Test that the training is backproping both tags

    We can do this by using the "finetune" mechanism and verifying
    that the output tensors are different
    """
    trainer = run_two_tag_training(pretrain_file, tmp_path)

    # first, need to save the final model before restarting
    # (alternatively, could reload the final checkpoint)
    trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name']))
    # --finetune reloads the saved model and trains some more
    new_trainer = run_two_tag_training(pretrain_file, tmp_path, "--finetune")

    assert len(trainer.model.tag_clfs) == 2
    assert len(new_trainer.model.tag_clfs) == 2
    # if both columns are backpropped, every output layer's weights moved
    for old_clf, new_clf in zip(trainer.model.tag_clfs, new_trainer.model.tag_clfs):
        assert not torch.allclose(old_clf.weight, new_clf.weight)
|
| 162 |
+
|
| 163 |
+
def test_two_tag_training_c2_backprop(pretrain_file, tmp_path):
    """
    Test that the training is backproping only one tag if one column is blank

    We can do this by using the "finetune" mechanism and verifying
    that the output tensors are different in just the first column
    """
    trainer = run_two_tag_training(pretrain_file, tmp_path)

    # first, need to save the final model before restarting
    # (alternatively, could reload the final checkpoint)
    trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name']))
    # EN_TRAIN_2TAG_EMPTY2 has "-" (empty) in the second tag column,
    # so only the first head should receive gradient updates
    new_trainer = run_two_tag_training(pretrain_file, tmp_path, "--finetune", train_data=EN_TRAIN_2TAG_EMPTY2)

    assert len(trainer.model.tag_clfs) == 2
    assert len(new_trainer.model.tag_clfs) == 2
    # first column trained: weights moved.  second column blank: weights unchanged
    assert not torch.allclose(trainer.model.tag_clfs[0].weight, new_trainer.model.tag_clfs[0].weight)
    assert torch.allclose(trainer.model.tag_clfs[1].weight, new_trainer.model.tag_clfs[1].weight)
|
| 181 |
+
|
| 182 |
+
def test_connected_two_tag_training(pretrain_file, tmp_path):
    """
    Train with --connect_output_layers and verify the second head sees the first head's output
    """
    trainer = run_two_tag_training(pretrain_file, tmp_path, "--connect_output_layers")
    assert len(trainer.model.tag_clfs) == 2
    assert len(trainer.model.crits) == 2
    assert len(trainer.vocab['tag'].lens()) == 2

    # this checks that with the connected output layers,
    # the second output layer has its size increased
    # by the number of tags known to the first output layer
    assert trainer.model.tag_clfs[1].weight.shape[1] == trainer.vocab['tag'].lens()[0] + trainer.model.tag_clfs[0].weight.shape[1]
|
| 192 |
+
|
| 193 |
+
def run_training(pretrain_file, tmp_path, *extra_args):
    """Write single-tag BIO train/dev files under tmp_path and run a brief NER training.

    Returns the trainer produced by ner_tagger.main.
    """
    train_path = tmp_path / "en_test.train.json"
    dev_path = tmp_path / "en_test.dev.json"
    write_temp_file(train_path, EN_TRAIN_BIO)
    write_temp_file(dev_path, EN_DEV_BIO)

    return ner_tagger.main(get_args(tmp_path, pretrain_file, train_path, dev_path, *extra_args))
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def test_train_model_gpu(pretrain_file, tmp_path):
    """
    Briefly train an NER model (no expectation of correctness) and check that it is on the GPU
    """
    trainer = run_training(pretrain_file, tmp_path)
    # warn-and-return instead of failing when no GPU is present,
    # so the test is a no-op on CPU-only CI machines
    if not torch.cuda.is_available():
        warnings.warn("Cannot check that the NER model is on the GPU, since GPU is not available")
        return

    model = trainer.model
    # the device of any parameter reflects where the whole model lives
    device = next(model.parameters()).device
    assert str(device).startswith("cuda")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def test_train_model_cpu(pretrain_file, tmp_path):
    """
    Briefly train an NER model (no expectation of correctness) and check that it is on the CPU
    """
    trainer = run_training(pretrain_file, tmp_path, "--cpu")

    model = trainer.model
    # the device of any parameter reflects where the whole model lives
    device = next(model.parameters()).device
    assert str(device).startswith("cpu")
|
| 227 |
+
|
| 228 |
+
def model_file_has_bert(filename):
    """Return True if the checkpoint at filename includes saved bert weights.

    Bert weights are identified by state dict keys starting with "bert_model."
    """
    checkpoint = torch.load(filename, map_location=lambda storage, loc: storage, weights_only=True)
    model_state = checkpoint['model']
    return any(key.startswith("bert_model.") for key in model_state)
|
| 231 |
+
|
| 232 |
+
def test_with_bert(pretrain_file, tmp_path):
    """
    Train with a tiny bert but no finetuning: the bert weights should not be saved
    """
    trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert')
    model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])
    # without --bert_finetune the transformer is unchanged, so saving it would be a waste
    assert not model_file_has_bert(model_file)
|
| 236 |
+
|
| 237 |
+
def test_with_bert_finetune(pretrain_file, tmp_path):
    """
    Train with a tiny bert and finetuning: the finetuned bert weights must be saved,
    and must survive an explicit save / reload / save round trip
    """
    trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune')
    model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])
    assert model_file_has_bert(model_file)

    foo_save_filename = os.path.join(tmp_path, "foo_" + trainer.args['save_name'])
    bar_save_filename = os.path.join(tmp_path, "bar_" + trainer.args['save_name'])
    # an explicit save should also include the bert weights
    trainer.save(foo_save_filename)
    assert model_file_has_bert(foo_save_filename)

    # TODO: technically this should still work if we turn off bert finetuning when reloading
    reloaded_trainer = Trainer(args=trainer.args, model_file=foo_save_filename)
    reloaded_trainer.save(bar_save_filename)
    # the reloaded trainer must not lose the finetuned bert weights when re-saving
    assert model_file_has_bert(bar_save_filename)
|
| 251 |
+
|
| 252 |
+
def test_with_peft_finetune(pretrain_file, tmp_path):
    """
    Train with peft (lora) finetuning: the checkpoint should hold the lora
    weights instead of a full copy of the bert weights, and should reload
    """
    # TODO: check that the peft tensors are moving when training?
    trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--use_peft')
    model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])
    checkpoint = torch.load(model_file, lambda storage, loc: storage, weights_only=True)
    # lora adapter weights saved under their own key...
    assert 'bert_lora' in checkpoint
    # ...and the full transformer weights are not duplicated into the model state
    assert not any(x.startswith("bert_model.") for x in checkpoint['model'].keys())

    # test loading
    reloaded_trainer = Trainer(args=trainer.args, model_file=model_file)
|
stanza/stanza/tests/pipeline/pipeline_device_tests.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility methods to check that all processors are on the expected device
|
| 3 |
+
|
| 4 |
+
Refactored since it can be used for multiple pipelines
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
def check_on_gpu(pipeline):
    """
    Check that the processors are all on the GPU and that basic execution works
    """
    if not torch.cuda.is_available():
        warnings.warn("Unable to run the test that checks the pipeline is on the GPU, as there is no GPU available!")
        return

    for name, proc in pipeline.processors.items():
        # processors expose their network either via a trainer or directly as _model
        model = proc.trainer.model if proc.trainer is not None else proc._model
        device = next(model.parameters()).device
        assert str(device).startswith("cuda"), "Processor %s was not on the GPU" % name

    # just check that there are no cpu/cuda tensor conflicts
    # when running on the GPU
    pipeline("This is a small test")
|
| 30 |
+
|
| 31 |
+
def check_on_cpu(pipeline):
    """
    Check that the processors are all on the CPU and that basic execution works
    """
    for name, proc in pipeline.processors.items():
        # processors expose their network either via a trainer or directly as _model
        model = proc.trainer.model if proc.trainer is not None else proc._model
        device = next(model.parameters()).device
        assert str(device).startswith("cpu"), "Processor %s was not on the CPU" % name

    # just check that there are no cpu/cuda tensor conflicts
    # when running on the CPU
    pipeline("This is a small test")
|
stanza/stanza/tests/pipeline/test_pipeline_sentiment_processor.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import stanza
|
| 5 |
+
from stanza.utils.conll import CoNLL
|
| 6 |
+
from stanza.models.common.doc import Document
|
| 7 |
+
|
| 8 |
+
from stanza.tests import *
|
| 9 |
+
|
| 10 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 11 |
+
|
| 12 |
+
# data for testing
|
| 13 |
+
EN_DOCS = ["Ragavan is terrible and should go away.", "Today is okay.", "Urza's Saga is great."]
|
| 14 |
+
|
| 15 |
+
EN_DOC = " ".join(EN_DOCS)
|
| 16 |
+
|
| 17 |
+
EXPECTED = [0, 1, 2]
|
| 18 |
+
|
| 19 |
+
class TestSentimentPipeline:
    @pytest.fixture(scope="class")
    def pipeline(self):
        """
        A reusable pipeline with the sentiment processor
        """
        # free memory from any previously-built pipelines before loading another
        gc.collect()
        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,sentiment")

    def test_simple(self, pipeline):
        """
        Run each test document separately and check the sentiment labels
        """
        results = []
        for text in EN_DOCS:
            doc = pipeline(text)
            assert len(doc.sentences) == 1
            results.append(doc.sentences[0].sentiment)
        assert EXPECTED == results

    def test_multiple_sentences(self, pipeline):
        """
        Run the test documents as one text and check each sentence's sentiment
        """
        doc = pipeline(EN_DOC)
        assert len(doc.sentences) == 3
        results = [sentence.sentiment for sentence in doc.sentences]
        assert EXPECTED == results

    def test_empty_text(self, pipeline):
        """
        Test empty text and a text which might get reduced to empty text by removing dashes
        """
        doc = pipeline("")
        assert len(doc.sentences) == 0

        doc = pipeline("--")
        assert len(doc.sentences) == 1
|
stanza/stanza/tests/pipeline/test_requirements.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test the requirements functionality for processors
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import stanza
|
| 7 |
+
|
| 8 |
+
from stanza.pipeline.core import PipelineRequirementsException
|
| 9 |
+
from stanza.pipeline.processor import ProcessorRequirementsException
|
| 10 |
+
from stanza.tests import *
|
| 11 |
+
|
| 12 |
+
pytestmark = pytest.mark.pipeline
|
| 13 |
+
|
| 14 |
+
def check_exception_vals(req_exception, req_exception_vals):
    """
    Check the values of a ProcessorRequirementsException against a dict of expected values.
    :param req_exception: the ProcessorRequirementsException to evaluate
    :param req_exception_vals: expected values for the ProcessorRequirementsException
    :return: None
    """
    # NOTE(review): the gold dicts also carry a 'provided_reqs' entry which is
    # never compared here - confirm whether the exception exposes it for checking
    assert isinstance(req_exception, ProcessorRequirementsException)
    assert req_exception.processor_type == req_exception_vals['processor_type']
    assert req_exception.processors_list == req_exception_vals['processors_list']
    assert req_exception.err_processor.requires == req_exception_vals['requires']
|
| 26 |
+
|
| 27 |
+
def test_missing_requirements():
    """
    Try to build several pipelines with bad configs and check thrown exceptions against gold exceptions.

    Each bad config must raise a PipelineRequirementsException whose individual
    processor failures match the gold exception dicts.
    :return: None
    """
    # list of (bad configs, list of gold ProcessorRequirementsExceptions that should be thrown) pairs
    bad_config_lists = [
        # missing tokenize
        (
            # input config
            {'processors': 'pos,depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'},
            # 2 expected exceptions
            [
                {'processor_type': 'POSProcessor', 'processors_list': ['pos', 'depparse'], 'provided_reqs': set([]),
                 'requires': set(['tokenize'])},
                {'processor_type': 'DepparseProcessor', 'processors_list': ['pos', 'depparse'],
                 'provided_reqs': set([]), 'requires': set(['tokenize', 'pos', 'lemma'])}
            ]
        ),
        # no pos when lemma_pos set to True; for english mwt should not be included in the loaded processor list
        (
            # input config
            {'processors': 'tokenize,mwt,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_pos': True},
            # 1 expected exception
            [
                {'processor_type': 'LemmaProcessor', 'processors_list': ['tokenize', 'mwt', 'lemma'],
                 'provided_reqs': set(['tokenize', 'mwt']), 'requires': set(['tokenize', 'pos'])}
            ]
        )
    ]
    # try to build each bad config, catch exceptions, check against gold
    pipeline_fails = 0
    for bad_config, gold_exceptions in bad_config_lists:
        try:
            stanza.Pipeline(**bad_config)
        except PipelineRequirementsException as e:
            pipeline_fails += 1
            assert len(e.processor_req_fails) == len(gold_exceptions)
            for processor_req_e, gold_exception in zip(e.processor_req_fails, gold_exceptions):
                # compare the thrown ProcessorRequirementsExceptions against gold
                check_exception_vals(processor_req_e, gold_exception)
    # every bad config must have failed pipeline construction
    # (tied to the config list length instead of a hard-coded count)
    assert pipeline_fails == len(bad_config_lists)
|
| 71 |
+
|
| 72 |
+
|
stanza/stanza/tests/tokenization/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/tokenization/test_tokenize_utils.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Very simple test of the sentence slicing by <PAD> tags
|
| 3 |
+
|
| 4 |
+
TODO: could add a bunch more simple tests for the tokenization utils
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import stanza
|
| 9 |
+
|
| 10 |
+
from stanza import Pipeline
|
| 11 |
+
from stanza.tests import *
|
| 12 |
+
from stanza.models.common import doc
|
| 13 |
+
from stanza.models.tokenization import data
|
| 14 |
+
from stanza.models.tokenization import utils
|
| 15 |
+
|
| 16 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 17 |
+
|
| 18 |
+
def test_find_spans():
    """
    Test various raw -> span manipulations

    <PAD> entries should be dropped from the spans, splitting where they
    appear in the middle of the character sequence
    """
    base = list("unban mox opal")

    # no padding, trailing padding: full span starting at 0
    assert utils.find_spans(base) == [(0, 14)]
    assert utils.find_spans(base + ['<PAD>']) == [(0, 14)]

    # leading padding shifts the span by one
    assert utils.find_spans(['<PAD>'] + base + ['<PAD>']) == [(1, 15)]
    assert utils.find_spans(['<PAD>'] + base) == [(1, 15)]

    # interior padding splits the text into two spans
    assert utils.find_spans(['<PAD>'] + list("unban") + ['<PAD>'] + list("mox opal")) == [(1, 6), (7, 15)]
|
| 36 |
+
|
| 37 |
+
def check_offsets(doc, expected_offsets):
    """
    Compare the start_char and end_char of the tokens in the doc with the given list of list of offsets
    """
    assert len(doc.sentences) == len(expected_offsets)
    for sentence, sent_offsets in zip(doc.sentences, expected_offsets):
        assert len(sentence.tokens) == len(sent_offsets)
        # each expected offset is a (start_char, end_char) pair
        for token, (start, end) in zip(sentence.tokens, sent_offsets):
            assert token.start_char == start
            assert token.end_char == end
|
| 47 |
+
|
| 48 |
+
def test_match_tokens_with_text():
    """
    Test the conversion of pretokenized text to Document

    Tokens must align exactly against the raw text; any surplus, missing,
    or mismatched characters should raise a ValueError
    """
    doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisatest")
    expected_offsets = [[(0, 4), (4, 6), (6, 7), (7, 11)]]
    check_offsets(doc, expected_offsets)

    # whitespace in the raw text is skipped over when aligning tokens
    doc = utils.match_tokens_with_text([["This", "is", "a", "test"], ["unban", "mox", "opal", "!"]], "Thisisatest unban mox opal!")
    expected_offsets = [[(0, 4), (4, 6), (6, 7), (7, 11)],
                        [(13, 18), (19, 22), (24, 28), (28, 29)]]
    check_offsets(doc, expected_offsets)

    # text has extra trailing characters not covered by any token
    with pytest.raises(ValueError):
        doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisatestttt")

    # text is too short for the tokens
    with pytest.raises(ValueError):
        doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisates")

    # token text ("iz") does not match the raw text ("is")
    with pytest.raises(ValueError):
        doc = utils.match_tokens_with_text([["This", "iz", "a", "test"]], "Thisisatest")
|
| 69 |
+
|
| 70 |
+
def test_long_paragraph():
    """
    Test the tokenizer's capacity to break text up into smaller chunks
    """
    pipeline = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize")
    tokenizer = pipeline.processors['tokenize']

    # 100 copies of one sentence, long enough to exceed the 3000 char limit below
    raw_text = "TIL not to ask a date to dress up as Smurfette on a first date. " * 100

    # run a test to make sure the chunk operation is called
    # if not, the test isn't actually testing what we need to test
    batches = data.DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)
    # cripple the loader: if chunking happens, advance_old_batch is called
    # and calling None raises TypeError, proving the chunk path was taken
    batches.advance_old_batch = None
    with pytest.raises(TypeError):
        _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000,
                                                     orig_text=raw_text,
                                                     no_ssplit=tokenizer.config.get('no_ssplit', False))

    # a new DataLoader should not be crippled as the above one was
    batches = data.DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)
    _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000,
                                                 orig_text=raw_text,
                                                 no_ssplit=tokenizer.config.get('no_ssplit', False))

    # one sentence per copy of the source sentence
    document = doc.Document(document, raw_text)
    assert len(document.sentences) == 100
|
| 96 |
+
|
| 97 |
+
def test_postprocessor_application():
    """
    Check that postprocess_doc leaves the document intact when the
    postprocessor returns the tokenization the document already encodes.
    """
    identity_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']]
    raw_text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

    expected_doc = [
        [{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1},
         {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4},
         {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9},
         {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13},
         {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16, 'misc': 'SpaceAfter=No'},
         {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}],
        [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21},
         {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23},
         {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31, 'misc': 'SpaceAfter=No'},
         {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32, 'misc': 'SpaceAfter=No'}]]

    # the postprocessor ignores its input and returns the identical tokenization
    result = utils.postprocess_doc(expected_doc, lambda _: identity_tokenization, raw_text)

    assert result == expected_doc
|
| 113 |
+
|
| 114 |
+
def test_reassembly_indexing():
    """
    Check that the reassembly code counts the indices correctly, including OOV chars.
    """

    good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']]
    # no multi-word tokens and no expansions for any token
    good_mwts = [[False for _ in range(len(i))] for i in good_tokenization]
    good_expansions = [[None for _ in range(len(i))] for i in good_tokenization]

    text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

    # expected output: per-sentence token dicts with exact character offsets,
    # with SpaceAfter=No marked where no space follows the token
    target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16, 'misc': 'SpaceAfter=No'}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31, 'misc': 'SpaceAfter=No'}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32, 'misc': 'SpaceAfter=No'}]]

    res = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text)

    assert res == target_doc
|
| 130 |
+
|
| 131 |
+
def test_reassembly_reference_failures():
    """
    Check that the reassembly code complains correctly when the user adds tokens that don't exist
    """

    # tokenization with an extra token ("Southern") not present in the text
    bad_addition_tokenization = [['Joe', 'Smith', 'lives', 'in', 'Southern', 'California', '.']]
    bad_addition_mwts = [[False for _ in range(len(bad_addition_tokenization[0]))]]
    bad_addition_expansions = [[None for _ in range(len(bad_addition_tokenization[0]))]]

    # tokenization whose token text ("Californiaa") does not match the raw text
    bad_inline_tokenization = [['Joe', 'Smith', 'lives', 'in', 'Californiaa', '.']]
    bad_inline_mwts = [[False for _ in range(len(bad_inline_tokenization[0]))]]
    bad_inline_expansions = [[None for _ in range(len(bad_inline_tokenization[0]))]]

    good_tokenization = [['Joe', 'Smith', 'lives', 'in', 'California', '.']]
    good_mwts = [[False for _ in range(len(good_tokenization[0]))]]
    good_expansions = [[None for _ in range(len(good_tokenization[0]))]]

    text = "Joe Smith lives in California."

    with pytest.raises(ValueError):
        utils.reassemble_doc_from_tokens(bad_addition_tokenization, bad_addition_mwts, bad_addition_expansions, text)

    with pytest.raises(ValueError):
        # fixed: previously passed bad_inline_mwts as the expansions argument;
        # pass the actual expansions so the ValueError comes from the bad token text
        utils.reassemble_doc_from_tokens(bad_inline_tokenization, bad_inline_mwts, bad_inline_expansions, text)

    # the matching tokenization should reassemble without raising
    utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
TRAIN_DATA = """
|
| 161 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
|
| 162 |
+
# text = DPA: Iraqi authorities announced that they'd busted up three terrorist cells operating in Baghdad.
|
| 163 |
+
1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
|
| 164 |
+
2 : : PUNCT : _ 1 punct 1:punct _
|
| 165 |
+
3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
|
| 166 |
+
4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
|
| 167 |
+
5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
|
| 168 |
+
6 that that SCONJ IN _ 9 mark 9:mark _
|
| 169 |
+
7-8 they'd _ _ _ _ _ _ _ _
|
| 170 |
+
7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
|
| 171 |
+
8 'd have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
|
| 172 |
+
9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
|
| 173 |
+
10 up up ADP RP _ 9 compound:prt 9:compound:prt _
|
| 174 |
+
11 three three NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
|
| 175 |
+
12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
|
| 176 |
+
13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
|
| 177 |
+
14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
|
| 178 |
+
15 in in ADP IN _ 16 case 16:case _
|
| 179 |
+
16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
|
| 180 |
+
17 . . PUNCT . _ 1 punct 1:punct _
|
| 181 |
+
|
| 182 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
|
| 183 |
+
# text = Two of them were being run by 2 officials of the Ministry of the Interior!
|
| 184 |
+
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
|
| 185 |
+
2 of of ADP IN _ 3 case 3:case _
|
| 186 |
+
3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
|
| 187 |
+
4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
|
| 188 |
+
5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
|
| 189 |
+
6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
|
| 190 |
+
7 by by ADP IN _ 9 case 9:case _
|
| 191 |
+
8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
|
| 192 |
+
9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
|
| 193 |
+
10 of of ADP IN _ 12 case 12:case _
|
| 194 |
+
11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
|
| 195 |
+
12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
|
| 196 |
+
13 of of ADP IN _ 15 case 15:case _
|
| 197 |
+
14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
|
| 198 |
+
15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
|
| 199 |
+
16 ! ! PUNCT . _ 6 punct 6:punct _
|
| 200 |
+
|
| 201 |
+
""".lstrip()
|
| 202 |
+
|
| 203 |
+
def test_lexicon_from_training_data(tmp_path):
    """
    Test a couple aspects of building a lexicon from training data

    expected number of words eliminated for being too long
    duplicate words counted once
    numbers eliminated
    """
    conllu_file = str(tmp_path / "train.conllu")
    with open(conllu_file, "w", encoding="utf-8") as fout:
        fout.write(TRAIN_DATA)

    lexicon, num_dict_feat = utils.create_lexicon("en_test", conllu_file)
    # sort so the comparison is order-independent
    lexicon = sorted(lexicon)
    # lowercased word forms from TRAIN_DATA, deduplicated, with numbers removed;
    # the mwt surface form "they'd" appears alongside its parts
    expected_lexicon = ["'d", 'announced', 'baghdad', 'being', 'busted', 'by', 'cells', 'dpa', 'in', 'interior', 'iraqi', 'ministry', 'of', 'officials', 'operating', 'run', 'terrorist', 'that', 'the', 'them', 'they', "they'd", 'three', 'two', 'up', 'were']
    assert lexicon == expected_lexicon
    # the dictionary feature count matches the longest word kept in the lexicon
    assert num_dict_feat == max(len(x) for x in lexicon)
|
stanza/stanza/utils/charlm/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/charlm/conll17_to_text.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Turns a directory of conllu files from the conll 2017 shared task to a text file
|
| 3 |
+
|
| 4 |
+
Part of the process for building a charlm dataset
|
| 5 |
+
|
| 6 |
+
python conll17_to_text.py <directory>
|
| 7 |
+
|
| 8 |
+
This is an extension of the original script:
|
| 9 |
+
https://github.com/stanfordnlp/stanza-scripts/blob/master/charlm/conll17/conll2txt.py
|
| 10 |
+
|
| 11 |
+
To build a new charlm for a new language from a conll17 dataset:
|
| 12 |
+
- look for conll17 shared task data, possibly here:
|
| 13 |
+
https://lindat.mff.cuni.cz/repository/xmlui/handle/11234/1-1989
|
| 14 |
+
- python3 stanza/utils/charlm/conll17_to_text.py ~/extern_data/conll17/Bulgarian --output_directory extern_data/charlm_raw/bg/conll17
|
| 15 |
+
- python3 stanza/utils/charlm/make_lm_data.py --langs bg extern_data/charlm_raw extern_data/charlm/
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import lzma
|
| 20 |
+
import sys
|
| 21 |
+
import os
|
| 22 |
+
|
| 23 |
+
def process_file(input_filename, output_directory, compress):
    """Convert a single conllu (optionally .xz compressed) file to plain text.

    Writes one sentence per line, tokens separated by spaces, next to the
    input file (or into output_directory if given).  MWT range lines such
    as "1-2" are skipped so only the expanded words are kept.  The output
    is xz-compressed when compress is True.  Existing output files are
    never overwritten.
    """
    if not input_filename.endswith('.conllu') and not input_filename.endswith(".conllu.xz"):
        print("Skipping {}".format(input_filename))
        return

    # conllu files are UTF-8 by specification, so read/write explicitly as
    # UTF-8 instead of relying on the locale default encoding
    if input_filename.endswith(".xz"):
        open_fn = lambda x: lzma.open(x, mode='rt', encoding='utf-8')
        output_filename = input_filename[:-3].replace(".conllu", ".txt")
    else:
        open_fn = lambda x: open(x, encoding='utf-8')
        output_filename = input_filename.replace('.conllu', '.txt')

    if output_directory:
        output_filename = os.path.join(output_directory, os.path.split(output_filename)[1])

    if compress:
        output_filename = output_filename + ".xz"
        output_fn = lambda x: lzma.open(x, mode='wt', encoding='utf-8')
    else:
        output_fn = lambda x: open(x, mode='w', encoding='utf-8')

    if os.path.exists(output_filename):
        print("Cowardly refusing to overwrite %s" % output_filename)
        return

    print("Converting %s to %s" % (input_filename, output_filename))
    with open_fn(input_filename) as fin:
        sentences = []
        sentence = []
        for line in fin:
            line = line.strip()
            if len(line) == 0: # sentence boundary
                # only record non-empty sentences: consecutive blank lines
                # would otherwise produce spurious empty lines in the output
                if sentence:
                    sentences.append(sentence)
                sentence = []
                continue
            if line[0] == '#': # comment
                continue
            splitline = line.split('\t')
            # raise instead of assert: input validation should survive python -O
            if len(splitline) != 10: # correct conllu always has 10 columns
                raise ValueError("Unexpected number of columns in %s: %s" % (input_filename, line))
            token_id, word = splitline[0], splitline[1]
            if '-' not in token_id: # not an mwt range line
                sentence.append(word)
        if sentence:
            sentences.append(sentence)

    print("  Read in {} sentences".format(len(sentences)))
    with output_fn(output_filename) as fout:
        fout.write('\n'.join([' '.join(sentence) for sentence in sentences]))
| 73 |
+
def parse_args():
    """Parse command line arguments for the conll17 text conversion script."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument("input_directory", help="Root directory with conllu or conllu.xz files.")
    argparser.add_argument("--output_directory", default=None, help="Directory to output to. Will output to input_directory if None")
    argparser.add_argument("--no_xz_output", default=True, dest="xz_output", action="store_false", help="Output compressed xz files")
    return argparser.parse_args()
| 81 |
+
|
| 82 |
+
if __name__ == '__main__':
    # Convert every conllu file found in the input directory to text
    args = parse_args()
    input_dir = args.input_directory
    conllu_files = sorted(os.listdir(input_dir))
    print("Files to process in {}: {}".format(input_dir, conllu_files))
    print("Processing to .xz files: {}".format(args.xz_output))

    if args.output_directory:
        os.makedirs(args.output_directory, exist_ok=True)
    for conllu_file in conllu_files:
        process_file(os.path.join(input_dir, conllu_file), args.output_directory, args.xz_output)
stanza/stanza/utils/charlm/dump_oscar.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script downloads and extracts the text from an Oscar crawl on HuggingFace
|
| 3 |
+
|
| 4 |
+
To use, just run
|
| 5 |
+
|
| 6 |
+
dump_oscar.py <lang>
|
| 7 |
+
|
| 8 |
+
It will download the dataset and output all of the text to the --output directory.
|
| 9 |
+
Files will be broken into pieces to avoid having one giant file.
|
| 10 |
+
By default, files will also be compressed with xz (although this can be turned off)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import lzma
|
| 15 |
+
import math
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from datasets import get_dataset_split_names
|
| 21 |
+
from datasets import load_dataset
|
| 22 |
+
|
| 23 |
+
from stanza.models.common.constant import lang_to_langcode
|
| 24 |
+
|
| 25 |
+
def parse_args():
    """
    A few specific arguments for the dump program

    Uses lang_to_langcode to process args.language, hopefully converting
    a variety of possible formats to the short code used by HuggingFace
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument("language", help="Language to download")
    argparser.add_argument("--output", default="oscar_dump", help="Path for saving files")
    argparser.add_argument("--no_xz", dest="xz", default=True, action='store_false', help="Don't xz the files - default is to compress while writing")
    argparser.add_argument("--prefix", default="oscar_dump", help="Prefix to use for the pieces of the dataset")
    argparser.add_argument("--version", choices=["2019", "2023"], default="2023", help="Which version of the Oscar dataset to download")

    parsed = argparser.parse_args()
    # normalize the language argument to a short HuggingFace language code
    parsed.language = lang_to_langcode(parsed.language)
    return parsed
| 42 |
+
|
| 43 |
+
def download_2023(args):
    # NOTE(review): this helper appears unused and incomplete - it ignores
    # args.language and hard-codes the 'sd' config, computes the split names,
    # and then returns nothing.  main() below contains the working 2023
    # download path; confirm before relying on this function.
    dataset = load_dataset('oscar-corpus/OSCAR-2301', 'sd')
    split_names = list(dataset.keys())
|
| 47 |
+
|
| 48 |
+
def main():
    """Dump one language of the HuggingFace Oscar corpus to sharded text files.

    Depending on --version this reads either the 2019 "oscar" dataset or the
    2023 "oscar-corpus/OSCAR-2301" dataset, then streams each item's text into
    numbered output files of roughly 1e8 characters each, optionally
    xz-compressed.
    """
    args = parse_args()

    # this is the 2019 version. for 2023, you can do
    # dataset = load_dataset('oscar-corpus/OSCAR-2301', 'sd')
    language = args.language
    if args.version == "2019":
        dataset_name = "unshuffled_deduplicated_%s" % language
        try:
            split_names = get_dataset_split_names("oscar", dataset_name)
        except ValueError as e:
            raise ValueError("Language %s not available in HuggingFace Oscar" % language) from e

        # the code below assumes there is exactly one split per language
        if len(split_names) > 1:
            raise ValueError("Unexpected split_names: {}".format(split_names))

        dataset = load_dataset("oscar", dataset_name)
        dataset = dataset[split_names[0]]
        size_in_bytes = dataset.info.size_in_bytes
        process_item = lambda x: x['text']
    elif args.version == "2023":
        dataset = load_dataset("oscar-corpus/OSCAR-2301", language)
        split_names = list(dataset.keys())
        if len(split_names) > 1:
            raise ValueError("Unexpected split_names: {}".format(split_names))
        # it's not clear if some languages don't support size_in_bytes,
        # or if there was an update to datasets which now allows that
        #
        # previously we did:
        # dataset = dataset[split_names[0]]['text']
        # size_in_bytes = sum(len(x) for x in dataset)
        # process_item = lambda x: x
        dataset = dataset[split_names[0]]
        size_in_bytes = dataset.info.size_in_bytes
        process_item = lambda x: x['text']
    else:
        # parse_args restricts --version to 2019/2023, so this is unreachable
        raise AssertionError("Unknown version: %s" % args.version)

    # estimate how many ~1e8-byte shards will be written so the filenames
    # can be zero-padded to a consistent width (at least 3 digits)
    chunks = max(1.0, size_in_bytes // 1e8) # an overestimate
    id_len = max(3, math.floor(math.log10(chunks)) + 1)

    if args.xz:
        format_str = "%s_%%0%dd.txt.xz" % (args.prefix, id_len)
        fopen = lambda file_idx: lzma.open(os.path.join(args.output, format_str % file_idx), "wt")
    else:
        format_str = "%s_%%0%dd.txt" % (args.prefix, id_len)
        fopen = lambda file_idx: open(os.path.join(args.output, format_str % file_idx), "w")

    print("Writing dataset to %s" % args.output)
    print("Dataset length: {}".format(size_in_bytes))
    os.makedirs(args.output, exist_ok=True)

    file_idx = 0
    file_len = 0
    # NOTE(review): total_len is never updated or read after this - vestigial?
    total_len = 0
    fout = fopen(file_idx)

    # write one text per line, rolling over to a new shard after ~1e8 chars
    for item in tqdm(dataset):
        text = process_item(item)
        fout.write(text)
        fout.write("\n")
        file_len += len(text)
        file_len += 1
        if file_len > 1e8:
            file_len = 0
            fout.close()
            file_idx = file_idx + 1
            fout = fopen(file_idx)

    fout.close()

if __name__ == '__main__':
    main()
stanza/stanza/utils/charlm/make_lm_data.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Create Stanza character LM train/dev/test data, by reading from txt files in each source corpus directory,
|
| 3 |
+
shuffling, splitting and saving into multiple smaller files (50MB by default) in a target directory.
|
| 4 |
+
|
| 5 |
+
This script assumes the following source directory structures:
|
| 6 |
+
- {src_dir}/{language}/{corpus}/*.txt
|
| 7 |
+
It will read from all source .txt files and create the following target directory structures:
|
| 8 |
+
- {tgt_dir}/{language}/{corpus}
|
| 9 |
+
and within each target directory, it will create the following files:
|
| 10 |
+
- train/*.txt
|
| 11 |
+
- dev.txt
|
| 12 |
+
- test.txt
|
| 13 |
+
Args:
|
| 14 |
+
- src_root: root directory of the source.
|
| 15 |
+
- tgt_root: root directory of the target.
|
| 16 |
+
- langs: a list of language codes to process; if specified, languages not in this list will be ignored.
|
| 17 |
+
Note: edit the {EXCLUDED_FOLDERS} variable to exclude more folders in the source directory.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import glob
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
import shutil
|
| 25 |
+
import subprocess
|
| 26 |
+
import tempfile
|
| 27 |
+
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
|
| 30 |
+
# source subdirectories which hold raw data rather than per-language corpora
EXCLUDED_FOLDERS = ['raw_corpus']

def main():
    """Walk {src_root}/{lang}/{corpus}/*.txt[.xz|.gz] and build charlm splits.

    For each language/corpus directory found (optionally filtered by --langs
    and --packages), delegates to prepare_lm_data, which shuffles the text
    and writes train/dev(/test) splits under {tgt_root}/{lang}/{corpus}.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("src_root", default="src", help="Root directory with all source files. Expected structure is root dir -> language dirs -> package dirs -> text files to process")
    parser.add_argument("tgt_root", default="tgt", help="Root directory with all target files.")
    parser.add_argument("--langs", default="", help="A list of language codes to process. If not set, all languages under src_root will be processed.")
    parser.add_argument("--packages", default="", help="A list of packages to process. If not set, all packages under the languages found will be processed.")
    parser.add_argument("--no_xz_output", default=True, dest="xz_output", action="store_false", help="Output compressed xz files")
    parser.add_argument("--split_size", default=50, type=int, help="How large to make each split, in MB")
    parser.add_argument("--no_make_test_file", default=True, dest="make_test_file", action="store_false", help="Don't save a test file. Honestly, we never even use it. Best for low resource languages where every bit helps")
    args = parser.parse_args()

    print("Processing files:")
    print(f"source root: {args.src_root}")
    print(f"target root: {args.tgt_root}")
    print("")

    # --langs and --packages are comma-separated lists; empty means "all"
    langs = []
    if len(args.langs) > 0:
        langs = args.langs.split(',')
        print("Only processing the following languages: " + str(langs))

    packages = []
    if len(args.packages) > 0:
        packages = args.packages.split(',')
        print("Only processing the following packages: " + str(packages))

    src_root = Path(args.src_root)
    tgt_root = Path(args.tgt_root)

    lang_dirs = os.listdir(src_root)
    lang_dirs = [l for l in lang_dirs if l not in EXCLUDED_FOLDERS] # skip excluded
    lang_dirs = [l for l in lang_dirs if os.path.isdir(src_root / l)] # skip non-directory
    if len(langs) > 0: # filter languages if specified
        lang_dirs = [l for l in lang_dirs if l in langs]
    print(f"{len(lang_dirs)} total languages found:")
    print(lang_dirs)
    print("")

    # convert --split_size from MB to bytes
    split_size = int(args.split_size * 1024 * 1024)

    for lang in lang_dirs:
        lang_root = src_root / lang
        data_dirs = os.listdir(lang_root)
        if len(packages) > 0:
            data_dirs = [d for d in data_dirs if d in packages]
        data_dirs = [d for d in data_dirs if os.path.isdir(lang_root / d)]
        print(f"{len(data_dirs)} total corpus found for language {lang}.")
        print(data_dirs)
        print("")

        for dataset_name in data_dirs:
            src_dir = lang_root / dataset_name
            tgt_dir = tgt_root / lang / dataset_name

            if not os.path.exists(tgt_dir):
                os.makedirs(tgt_dir)
            print(f"-> Processing {lang}-{dataset_name}")
            prepare_lm_data(src_dir, tgt_dir, lang, dataset_name, args.xz_output, split_size, args.make_test_file)

    print("")
| 93 |
+
def prepare_lm_data(src_dir, tgt_dir, lang, dataset_name, compress, split_size, make_test_file):
    """
    Combine, shuffle and split data into smaller files, following a naming convention.

    Concatenates every .txt/.txt.xz/.txt.gz file under src_dir, shuffles the
    lines, splits the result into train shards of roughly split_size bytes,
    and peels off the first shard(s) as dev.txt (and test.txt when
    make_test_file is True).  When compress is True, all resulting files are
    xz-compressed.  Relies on the cat/xzcat/zcat/shuf/split shell utilities.

    NOTE(review): file paths are interpolated into shell commands unquoted,
    so paths containing spaces or shell metacharacters will break; the
    per-file cat/xzcat/zcat commands also do not check their return codes.
    """
    assert isinstance(src_dir, Path)
    assert isinstance(tgt_dir, Path)
    # the temp dir lives inside tgt_dir so intermediate files are cleaned up
    # automatically (and stay on the same filesystem as the final output)
    with tempfile.TemporaryDirectory(dir=tgt_dir) as tempdir:
        tgt_tmp = os.path.join(tempdir, f"{lang}-{dataset_name}.tmp")
        print(f"--> Copying files into {tgt_tmp}...")
        # TODO: we can do this without the shell commands
        input_files = glob.glob(str(src_dir) + '/*.txt') + glob.glob(str(src_dir) + '/*.txt.xz') + glob.glob(str(src_dir) + '/*.txt.gz')
        for src_fn in tqdm(input_files):
            # append each file (decompressing if needed) to one big temp file
            if src_fn.endswith(".txt"):
                cmd = f"cat {src_fn} >> {tgt_tmp}"
                subprocess.run(cmd, shell=True)
            elif src_fn.endswith(".txt.xz"):
                cmd = f"xzcat {src_fn} >> {tgt_tmp}"
                subprocess.run(cmd, shell=True)
            elif src_fn.endswith(".txt.gz"):
                cmd = f"zcat {src_fn} >> {tgt_tmp}"
                subprocess.run(cmd, shell=True)
            else:
                raise AssertionError("should not have found %s" % src_fn)
        tgt_tmp_shuffled = os.path.join(tempdir, f"{lang}-{dataset_name}.tmp.shuffled")

        # shuffle the concatenated lines so the train/dev/test splits are random
        print(f"--> Shuffling files into {tgt_tmp_shuffled}...")
        cmd = f"cat {tgt_tmp} | shuf > {tgt_tmp_shuffled}"
        result = subprocess.run(cmd, shell=True)
        if result.returncode != 0:
            raise RuntimeError("Failed to shuffle files!")
        size = os.path.getsize(tgt_tmp_shuffled) / 1024 / 1024 / 1024
        print(f"--> Shuffled file size: {size:.4f} GB")
        if size < 0.1:
            raise RuntimeError("Not enough data found to build a charlm. At least 100MB data expected")

        # split -C keeps lines intact while capping each shard at split_size bytes
        print(f"--> Splitting into smaller files of size {split_size} ...")
        train_dir = tgt_dir / 'train'
        if not os.path.exists(train_dir): # make training dir
            os.makedirs(train_dir)
        cmd = f"split -C {split_size} -a 4 -d --additional-suffix .txt {tgt_tmp_shuffled} {train_dir}/{lang}-{dataset_name}-"
        result = subprocess.run(cmd, shell=True)
        if result.returncode != 0:
            raise RuntimeError("Failed to split files!")
        total = len(glob.glob(f'{train_dir}/*.txt'))
        print(f"--> {total} total files generated.")
        # need at least 3 shards: one dev, one test, one (or more) train
        if total < 3:
            raise RuntimeError("Something went wrong! %d file(s) produced by shuffle and split, expected at least 3" % total)

        # the first one or two shards become dev.txt (and test.txt)
        dev_file = f"{tgt_dir}/dev.txt"
        test_file = f"{tgt_dir}/test.txt"
        if make_test_file:
            print("--> Creating dev and test files...")
            shutil.move(f"{train_dir}/{lang}-{dataset_name}-0000.txt", dev_file)
            shutil.move(f"{train_dir}/{lang}-{dataset_name}-0001.txt", test_file)
            txt_files = [dev_file, test_file] + glob.glob(f'{train_dir}/*.txt')
        else:
            print("--> Creating dev file...")
            shutil.move(f"{train_dir}/{lang}-{dataset_name}-0000.txt", dev_file)
            txt_files = [dev_file] + glob.glob(f'{train_dir}/*.txt')

        if compress:
            print("--> Compressing files...")
            for txt_file in tqdm(txt_files):
                subprocess.run(['xz', txt_file])

        print("--> Cleaning up...")
        print(f"--> All done for {lang}-{dataset_name}.\n")

if __name__ == "__main__":
    main()
stanza/stanza/utils/constituency/check_transitions.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from stanza.models.constituency import transition_sequence
|
| 4 |
+
from stanza.models.constituency import tree_reader
|
| 5 |
+
from stanza.models.constituency.parse_transitions import TransitionScheme
|
| 6 |
+
from stanza.models.constituency.parse_tree import Tree
|
| 7 |
+
from stanza.models.constituency.utils import verify_transitions
|
| 8 |
+
|
| 9 |
+
def main():
    """Read a treebank, build its transition sequences, and verify them repeatedly."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', type=str, default="data/constituency/en_ptb3_train.mrg", help='Input file for data loader.')
    parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()],
                        help='Transition scheme to use. {}'.format(", ".join(x.name for x in TransitionScheme)))
    parser.add_argument('--reversed', default=False, action='store_true', help='Do the transition sequence reversed')
    parser.add_argument('--iterations', default=30, type=int, help='How many times to iterate, such as if doing a cProfile')
    options = vars(parser.parse_args())

    trees = tree_reader.read_treebank(options['train_file'])
    unary_limit = max(tree.count_unary_depth() for tree in trees) + 1
    sequences, _ = transition_sequence.convert_trees_to_sequences(trees, "training", options['transition_scheme'], options['reversed'])
    root_labels = Tree.get_root_labels(trees)
    # repeat the verification, e.g. so a profiler can collect enough samples
    for _ in range(options['iterations']):
        verify_transitions(trees, sequences, options['transition_scheme'], unary_limit, options['reversed'], "train", root_labels)

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/constituency/list_tensors.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Lists all the tensors in a constituency model.
|
| 3 |
+
|
| 4 |
+
Currently useful in combination with torchshow for displaying a series of tensors as they change.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
from stanza.models.constituency.trainer import Trainer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Load the serialized constituency trainer named on the command line and
# print every parameter tensor's name plus whether it is trainable.
# NOTE(review): this runs at import time - there is no __main__ guard.
trainer = Trainer.load(sys.argv[1])
model = trainer.model

for name, param in model.named_parameters():
    print(name, param.requires_grad)
stanza/stanza/utils/datasets/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/datasets/contract_mwt.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
def contract_mwt(infile, outfile, ignore_gapping=True):
|
| 4 |
+
"""
|
| 5 |
+
Simplify the gold tokenizer data for use as MWT processor test files
|
| 6 |
+
|
| 7 |
+
The simplifications are to remove the expanded MWTs, and in the
|
| 8 |
+
case of ignore_gapping=True, remove any copy words for the dependencies
|
| 9 |
+
"""
|
| 10 |
+
with open(outfile, 'w') as fout:
|
| 11 |
+
with open(infile, 'r') as fin:
|
| 12 |
+
idx = 0
|
| 13 |
+
mwt_begin = 0
|
| 14 |
+
mwt_end = -1
|
| 15 |
+
for line in fin:
|
| 16 |
+
line = line.strip()
|
| 17 |
+
|
| 18 |
+
if line.startswith('#'):
|
| 19 |
+
print(line, file=fout)
|
| 20 |
+
continue
|
| 21 |
+
elif len(line) <= 0:
|
| 22 |
+
print(line, file=fout)
|
| 23 |
+
idx = 0
|
| 24 |
+
mwt_begin = 0
|
| 25 |
+
mwt_end = -1
|
| 26 |
+
continue
|
| 27 |
+
|
| 28 |
+
line = line.split('\t')
|
| 29 |
+
|
| 30 |
+
# ignore gapping word
|
| 31 |
+
if ignore_gapping and '.' in line[0]:
|
| 32 |
+
continue
|
| 33 |
+
|
| 34 |
+
idx += 1
|
| 35 |
+
if '-' in line[0]:
|
| 36 |
+
mwt_begin, mwt_end = [int(x) for x in line[0].split('-')]
|
| 37 |
+
print("{}\t{}\t{}".format(idx, "\t".join(line[1:-1]), "MWT=Yes" if line[-1] == '_' else line[-1] + "|MWT=Yes"), file=fout)
|
| 38 |
+
idx -= 1
|
| 39 |
+
elif mwt_begin <= idx <= mwt_end:
|
| 40 |
+
continue
|
| 41 |
+
else:
|
| 42 |
+
print("{}\t{}".format(idx, "\t".join(line[1:])), file=fout)
|
| 43 |
+
|
| 44 |
+
if __name__ == '__main__':
|
| 45 |
+
contract_mwt(sys.argv[1], sys.argv[2])
|
| 46 |
+
|
stanza/stanza/utils/datasets/coref/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/datasets/coref/convert_ontonotes.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import stanza
|
| 5 |
+
|
| 6 |
+
from stanza.models.constituency import tree_reader
|
| 7 |
+
from stanza.utils.default_paths import get_default_paths
|
| 8 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 9 |
+
from stanza.utils.datasets.coref.utils import process_document
|
| 10 |
+
|
| 11 |
+
tqdm = get_tqdm()
|
| 12 |
+
|
| 13 |
+
def read_paragraphs(section):
    """Yield (document_id, part_id, sentences) triples, one per contiguous part.

    Sentences within each document are grouped by their 'part_id' field; a new
    triple is emitted every time the part id changes.
    """
    for document in section:
        doc_id = document['document_id']
        current_part = None
        chunk = []
        for sent in document['sentences']:
            sent_part = sent['part_id']
            if current_part is not None and sent_part != current_part:
                yield doc_id, current_part, chunk
                chunk = []
            current_part = sent_part
            chunk.append(sent)
        if chunk:
            yield doc_id, current_part, chunk
| 28 |
+
def convert_dataset_section(pipe, section):
    """Run process_document over every (doc, part) paragraph in one dataset split.

    Returns a list with one processed entry per paragraph.
    """
    paragraphs = list(read_paragraphs(section))
    converted = []
    for doc_id, part_id, paragraph in tqdm(paragraphs):
        words = [sent['words'] for sent in paragraph]
        spans = [sent['coref_spans'] for sent in paragraph]
        speakers = [sent['speaker'] for sent in paragraph]
        converted.append(process_document(pipe, doc_id, part_id, words, spans, speakers))
    return converted
| 41 |
+
# maps our split names to the HuggingFace split names
SECTION_NAMES = {"train": "train",
                 "dev": "validation",
                 "test": "test"}

def process_dataset(short_name, ontonotes_path, coref_output_path):
    """Download OntoNotes coref from HuggingFace and convert each split to json.

    short_name selects the language config; the converted train/dev/test
    files are written to {coref_output_path}/{short_name}.{section}.json
    """
    try:
        from datasets import load_dataset
    except ImportError as e:
        # chain the original error so the real import failure is visible
        raise ImportError("Please install the datasets package to process OntoNotes coref with Stanza") from e

    if short_name == 'en_ontonotes':
        config_name = 'english_v4'
    elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'):
        config_name = 'chinese_v4'
    elif short_name == 'ar_ontonotes':
        config_name = 'arabic_v4'
    else:
        raise ValueError("Unknown short name for downloading ontonotes: %s" % short_name)

    # NOTE(review): the pipeline language is hard-coded to "en" even though
    # zh/ar configs are accepted above - confirm before using non-English data
    pipe = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True)
    dataset = load_dataset("conll2012_ontonotesv5", config_name, cache_dir=ontonotes_path)
    for section, hf_name in SECTION_NAMES.items():
        print("Processing %s" % section)
        converted_section = convert_dataset_section(pipe, dataset[hf_name])
        output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section))
        with open(output_filename, "w", encoding="utf-8") as fout:
            json.dump(converted_section, fout, indent=2)
| 71 |
+
def main():
    """Convert the English OntoNotes coreference data using the default paths."""
    paths = get_default_paths()
    base = paths['COREF_BASE']
    ontonotes_dir = os.path.join(base, "english", "en_ontonotes")
    output_dir = paths['COREF_DATA_DIR']
    process_dataset("en_ontonotes", ontonotes_dir, output_dir)

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/coref/convert_udcoref.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import glob
|
| 6 |
+
|
| 7 |
+
from stanza.utils.default_paths import get_default_paths
|
| 8 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 9 |
+
from stanza.utils.datasets.coref.utils import find_cconj_head
|
| 10 |
+
|
| 11 |
+
from stanza.utils.conll import CoNLL
|
| 12 |
+
|
| 13 |
+
from random import Random
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
# Seeded RNGs so augmentation and any splitting are reproducible across runs
augment_random = Random(7)
split_random = Random(8)

tqdm = get_tqdm()
# When True, entity ids in the MISC "Entity" annotations carry a one-character
# prefix (e.g. "e123"); UDCOREF_ADDN is the number of extra characters the
# index arithmetic below must skip when slicing those ids
IS_UDCOREF_FORMAT = True
UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1
| 23 |
+
|
| 24 |
+
def process_documents(docs, augment=False):
    """Convert parsed documents with MISC Entity annotations into coref dicts.

    docs: iterable of (parsed document, document id, language code) triples
    augment: when True, randomly drops the final token of sentences (10%)
      and lowercases sentence-initial words (50%) using augment_random

    Returns a list of dicts, one per document, holding the words, sentence
    ids, deprels, flattened head indices, and the span/word coref clusters
    recovered from the Entity annotations.
    """
    processed_section = []

    for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)):
        # drop the last token 10% of the time
        if augment:
            for i in doc.sentences:
                if len(i.words) > 1:
                    if augment_random.random() < 0.1:
                        i.tokens = i.tokens[:-1]
                        i.words = i.words[:-1]

        # extract the entities
        # get sentence words and lengths
        sentences = [[j.text for j in i.words]
                     for i in doc.sentences]
        sentence_lens = [len(x.words) for x in doc.sentences]

        cased_words = []
        for x in sentences:
            if augment:
                # modify case of the first word with 50% chance
                if augment_random.random() < 0.5:
                    x[0] = x[0].lower()

            for y in x:
                cased_words.append(y)

        # one sentence index per word, flattened over the whole document
        # (the comprehension's idx only shadows the outer idx inside the
        # comprehension, so part_id below still sees the document index)
        sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]

        # flatten per-sentence 1-based heads into document-wide 0-based
        # indices; sentence roots (head == 0) become the string "null"
        word_total = 0
        heads = []
        # TODO: does SD vs UD matter?
        deprel = []
        for sentence in doc.sentences:
            for word in sentence.words:
                deprel.append(word.deprel)
                if word.head == 0:
                    heads.append("null")
                else:
                    heads.append(word.head - 1 + word_total)
            word_total += len(sentence.words)

        span_clusters = defaultdict(list)
        word_clusters = defaultdict(list)
        head2span = []
        word_total = 0
        # matches either an opening "(e123" style marker or a closing "123)"
        SPANS = re.compile(r"(\(\w+|[%\w]+\))")
        for parsed_sentence in doc.sentences:
            # spans regex
            # parse the misc column, leaving on "Entity" entries
            misc = [[k.split("=")
                     for k in j
                     if k.split("=")[0] == "Entity"]
                    for i in parsed_sentence.words
                    for j in [i.misc.split("|") if i.misc else []]]
            # and extract the Entity entry values
            entities = [i[0][1] if len(i) > 0 else None for i in misc]
            # extract reference information
            refs = [SPANS.findall(i) if i else [] for i in entities]
            # and calculate spans: the basic rule is (e... begins a reference
            # and ) without e before ends the most recent reference
            # every single time we get a closing element, we pop it off
            # the refdict and insert the pair to final_refs
            refdict = defaultdict(list)
            final_refs = defaultdict(list)
            last_ref = None
            for indx, i in enumerate(refs):
                for j in i:
                    # this is the beginning of a reference
                    if j[0] == "(":
                        refdict[j[1+UDCOREF_ADDN:]].append(indx)
                        last_ref = j[1+UDCOREF_ADDN:]
                    # at the end of a reference, if we got exxxxx, that ends
                    # a particular reference; otherwise, it ends the last reference
                    elif j[-1] == ")" and j[UDCOREF_ADDN:-1].isnumeric():
                        if (not UDCOREF_ADDN) or j[0] == "e":
                            try:
                                final_refs[j[UDCOREF_ADDN:-1]].append((refdict[j[UDCOREF_ADDN:-1]].pop(-1), indx))
                            except IndexError:
                                # this is probably zero anaphora
                                continue
                    elif j[-1] == ")":
                        final_refs[last_ref].append((refdict[last_ref].pop(-1), indx))
                        last_ref = None
            final_refs = dict(final_refs)
            # convert it to the right format (specifically, in (ref, start, end) tuples)
            coref_spans = []
            for k, v in final_refs.items():
                for i in v:
                    coref_spans.append([int(k), i[0], i[1]])
            sentence_upos = [x.upos for x in parsed_sentence.words]
            sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]
            for span in coref_spans:
                # input is expected to be start word, end word + 1
                # counting from 0
                # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
                span_start = span[1] + word_total
                span_end = span[2] + word_total + 1
                candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1)
                if candidate_head is None:
                    for candidate_head in range(span[1], span[2] + 1):
                        # stanza uses 0 to mark the head, whereas OntoNotes is counting
                        # words from 0, so we have to subtract 1 from the stanza heads
                        #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)
                        # treat the head of the phrase as the first word that has a head outside the phrase
                        if (parsed_sentence.words[candidate_head].head - 1 < span[1] or
                            parsed_sentence.words[candidate_head].head - 1 > span[2]):
                            break
                    else:
                        # if none have a head outside the phrase (circular??)
                        # then just take the first word
                        candidate_head = span[1]
                #print("----> %d" % candidate_head)
                candidate_head += word_total
                span_clusters[span[0]].append((span_start, span_end))
                word_clusters[span[0]].append(candidate_head)
                head2span.append((candidate_head, span_start, span_end))
            word_total += len(parsed_sentence.words)
        # drop the cluster ids: the output keeps sorted groups of spans/heads
        span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
        word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
        head2span = sorted(head2span)

        processed = {
            "document_id": doc_id,
            "cased_words": cased_words,
            "sent_id": sent_id,
            "part_id": idx,
            # "pos": pos,
            "deprel": deprel,
            "head": heads,
            "span_clusters": span_clusters,
            "word_clusters": word_clusters,
            "head2span": head2span,
            "lang": lang
        }
        processed_section.append(processed)
    return processed_section
| 163 |
+
def process_dataset(short_name, coref_output_path, split_test, train_files, dev_files):
    """Convert train/dev conllu files into coref JSON files, optionally splitting off a test set.

    Args:
        short_name: dataset name used as the prefix of the output .json files
        coref_output_path: directory where the converted .json files are written
        split_test: if truthy, the fraction (0..1) of train documents randomly
            moved into a synthetic test section (the source data has no test split)
        train_files: list of conllu filenames for the train section
        dev_files: list of conllu filenames for the dev section

    Writes <short_name>.<section>.json to coref_output_path for each section.
    """
    section_names = ('train', 'dev')
    section_filenames = [train_files, dev_files]
    sections = []

    # one entry per train file: the documents split off for the test section
    test_sections = []

    for section, filenames in zip(section_names, section_filenames):
        input_file = []
        for load in filenames:
            # filenames are expected to look like .../<lang>_<dataset>...conllu
            lang = load.split("/")[-1].split("_")[0]
            print("Ingesting %s from %s of lang %s" % (section, load, lang))
            docs = CoNLL.conll2multi_docs(load)
            print(" Ingested %d documents" % len(docs))
            if split_test and section == 'train':
                test_section = []
                train_section = []
                for i in docs:
                    # reseed for each doc so that we can attempt to keep things stable in the event
                    # of different file orderings or some change to the number of documents
                    split_random = Random(i.sentences[0].doc_id + i.sentences[0].text)
                    if split_random.random() < split_test:
                        test_section.append((i, i.sentences[0].doc_id, lang))
                    else:
                        train_section.append((i, i.sentences[0].doc_id, lang))
                # guarantee at least one test document per file whenever possible
                if len(test_section) == 0 and len(train_section) >= 2:
                    idx = split_random.randint(0, len(train_section) - 1)
                    test_section = [train_section[idx]]
                    train_section = train_section[:idx] + train_section[idx+1:]
                print(" Splitting %d documents from %s for test" % (len(test_section), load))
                input_file.extend(train_section)
                test_sections.append(test_section)
            else:
                for i in docs:
                    input_file.append((i, i.sentences[0].doc_id, lang))
        print("Ingested %d total documents" % len(input_file))
        sections.append(input_file)

    if split_test:
        section_names = ('train', 'dev', 'test')
        full_test_section = []
        # Bugfix: this used to zip test_sections against the stale loop variable
        # `filenames` (the dev file list), which silently dropped test sections
        # whenever there were more train files than dev files.
        # TODO: could write dataset-specific test sections as well
        for test_section in test_sections:
            full_test_section.extend(test_section)
        sections.append(full_test_section)


    for section_data, section_name in zip(sections, section_names):
        converted_section = process_documents(section_data, augment=(section_name=="train"))

        os.makedirs(coref_output_path, exist_ok=True)
        output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section_name))
        with open(output_filename, "w", encoding="utf-8") as fout:
            json.dump(converted_section, fout, indent=2)
|
| 217 |
+
|
| 218 |
+
def get_dataset_by_language(coref_input_path, langs):
    """Collect train/dev conllu files for the given languages under the CorefUD-1.2 layout.

    Each language matches any dataset directory whose name contains it,
    e.g. "Polish" matches "CorefUD_Polish-PCC".

    Returns:
        (train_filenames, dev_filenames), each a sorted list of paths
    """
    data_root = os.path.join(coref_input_path, "CorefUD-1.2-public", "data")

    def collect(split):
        # gather <split>.conllu files across every matching language directory
        return sorted(path
                      for language in langs
                      for path in glob.glob(os.path.join(data_root, "*%s*" % language, "*%s.conllu" % split)))

    return collect("train"), collect("dev")
|
| 228 |
+
|
| 229 |
+
def main():
    """Parse arguments and convert UDCoref conllu data to coref JSON files.

    Exactly one of --directory (a subfolder of the coref base path containing
    *train.conllu / *dev.conllu files) or --project (a named group of CorefUD
    languages) must be given.
    """
    paths = get_default_paths()
    parser = argparse.ArgumentParser(
        prog='Convert UDCoref Data',
    )
    parser.add_argument('--split_test', default=None, type=float, help='How much of the data to randomly split from train to make a test set')

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--directory', type=str, help="the name of the subfolder for data conversion")
    group.add_argument('--project', type=str, help="Look for and use a set of datasets for data conversion - Slavic or Hungarian")

    args = parser.parse_args()
    coref_input_path = paths['COREF_BASE']
    coref_output_path = paths['COREF_DATA_DIR']

    # known --project groups: output short name + the CorefUD languages to gather.
    # Previously this was a copy/pasted if/elif chain, and an unknown project
    # crashed with an unbound-variable error instead of a clear message.
    known_projects = {
        'slavic':    ("slavic_udcoref",    ('Polish', 'Russian', 'Czech')),
        'hungarian': ("hu_udcoref",        ('Hungarian',)),
        'gerrom':    ("gerrom_udcoref",    ('Catalan', 'English', 'French', 'German', 'Norwegian', 'Spanish')),
        'germanic':  ("germanic_udcoref",  ('English', 'German', 'Norwegian')),
        'norwegian': ("norwegian_udcoref", ('Norwegian',)),
    }

    if args.project:
        if args.project not in known_projects:
            raise ValueError("Unknown --project %s: expected one of %s" % (args.project, ", ".join(sorted(known_projects))))
        project, langs = known_projects[args.project]
        train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
    else:
        project = args.directory
        conll_path = os.path.join(coref_input_path, project)
        # fall back to treating --directory as a literal path if it is not under the coref base
        if not os.path.exists(conll_path) and os.path.exists(project):
            conll_path = args.directory
        train_filenames = sorted(glob.glob(os.path.join(conll_path, "*train.conllu")))
        dev_filenames = sorted(glob.glob(os.path.join(conll_path, "*dev.conllu")))
    process_dataset(project, coref_output_path, args.split_test, train_filenames, dev_filenames)
|
| 273 |
+
|
| 274 |
+
if __name__ == '__main__':
|
| 275 |
+
main()
|
| 276 |
+
|
stanza/stanza/utils/datasets/coref/utils.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
|
| 4 |
+
class DynamicDepth():
    """
    Computes the relative depth of every word in a subphrase, given the head word index
    for every word, using a per-call memo to avoid recomputing shared ancestors.
    """
    def get_parse_depths(self, heads, start, end):
        """Return the relative depth for every word

        Args:
            heads (list): List where each entry is the index of that entry's head word
                in the dependency parse, or None for a root
            start (int): starting index of the heads for the subphrase
            end (int): ending index of the heads for the subphrase

        Returns:
            list: Relative depth in the dependency parse for every word
        """
        self.heads = heads[start:end]
        # map heads to subphrase-relative indices; -100 marks 'None' (root) headwords.
        # note the explicit "is not None": a head index of 0 is valid, and the
        # previous truthiness test wrongly treated it as a root
        self.relative_heads = [h - start if h is not None else -100 for h in self.heads]

        # fresh memo per call: relative_heads changes between calls, so a
        # cross-call cache (such as the previous lru_cache on the method)
        # could return stale depths for a reused instance
        self._depth_cache = {}

        return [self._get_depth_recursive(i) for i in range(len(self.relative_heads))]

    def _get_depth_recursive(self, index):
        """Recursively get the depth of a word, memoized in self._depth_cache

        Args:
            index (int): Index of the word for which to calculate the relative depth

        Returns:
            int: Relative depth of the word at the index
        """
        if index in self._depth_cache:
            return self._depth_cache[index]
        head = self.relative_heads[index]
        # if the head for the current index is outside the scope, this index is a relative root
        if head >= len(self.relative_heads) or head < 0:
            depth = 0
        else:
            depth = self._get_depth_recursive(head) + 1
        self._depth_cache[index] = depth
        return depth
|
| 40 |
+
|
| 41 |
+
def find_cconj_head(heads, upos, start, end):
    """
    Finds how far each word is from the head of a span, then uses the closest CCONJ to the head as the new head

    If no CCONJ is present, returns None
    """
    DEPTH_LIMIT = 2

    # use head information to extract parse depth
    depths = DynamicDepth().get_parse_depths(heads, start, end)

    # return first 'CCONJ' token above depth limit, if exists
    # unlike the original paper, we expect the parses to use UPOS, hence CCONJ instead of CC
    for offset, word_depth in enumerate(depths):
        if word_depth < DEPTH_LIMIT and upos[start + offset] == 'CCONJ':
            return start + offset
    return None
|
| 58 |
+
|
| 59 |
+
def process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=True):
    """
    Convert one document's tokenized sentences + coref annotations into the coref JSON dict format.

    Args:
        pipe: a stanza pipeline used to parse the pretokenized sentences (depparse/upos needed)
        doc_id: document identifier stored in the output
        part_id: part identifier stored in the output (see note at the bottom about None)
        sentences: list of sentences, each a list of word strings
        coref_spans: a list of lists
            one list per sentence
            each sentence has a list of spans, where each span is (span_index, span_start, span_end)
        sentence_speakers: either one speaker per sentence, or a per-word list per sentence
        use_cconj_heads: if True, prefer a near-root CCONJ inside a span as its head word

    Returns a dict with document-level word lists (cased_words, sent_id, speaker,
    deprel, head) and coref structures (span_clusters, word_clusters, head2span),
    all indexed by document-wide word positions counting from 0.
    """
    sentence_lens = [len(x) for x in sentences]
    # speakers may be given per-word (list of lists) or per-sentence; normalize to per-word
    if all(isinstance(x, list) for x in sentence_speakers):
        speaker = [y for x in sentence_speakers for y in x]
    else:
        speaker = [y for x, sent_len in zip(sentence_speakers, sentence_lens) for y in [x] * sent_len]

    cased_words = [y for x in sentences for y in x]
    sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]

    # use the trees to get the xpos tags
    # alternatively, could translate the pos_tags field,
    # but those have numbers, which is annoying
    #tree_text = "\n".join(x['parse_tree'] for x in paragraph)
    #trees = tree_reader.read_trees(tree_text)
    #pos = [x.label for tree in trees for x in tree.yield_preterminals()]
    # actually, the downstream code doesn't use pos at all. maybe we can skip?

    doc = pipe(sentences)
    word_total = 0
    heads = []
    # TODO: does SD vs UD matter?
    deprel = []
    # flatten per-sentence heads into document-wide 0-based indices; "null" marks roots
    for sentence in doc.sentences:
        for word in sentence.words:
            deprel.append(word.deprel)
            if word.head == 0:
                heads.append("null")
            else:
                heads.append(word.head - 1 + word_total)
        word_total += len(sentence.words)

    span_clusters = defaultdict(list)
    word_clusters = defaultdict(list)
    head2span = []
    word_total = 0
    for parsed_sentence, ontonotes_coref, ontonotes_words in zip(doc.sentences, coref_spans, sentences):
        sentence_upos = [x.upos for x in parsed_sentence.words]
        # 0-based head indices within the sentence; None marks the root
        sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]
        for span in ontonotes_coref:
            # input is expected to be start word, end word + 1
            # counting from 0
            # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
            span_start = span[1] + word_total
            span_end = span[2] + word_total + 1
            candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) if use_cconj_heads else None
            if candidate_head is None:
                for candidate_head in range(span[1], span[2] + 1):
                    # stanza uses 0 to mark the head, whereas OntoNotes is counting
                    # words from 0, so we have to subtract 1 from the stanza heads
                    #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)
                    # treat the head of the phrase as the first word that has a head outside the phrase
                    if (parsed_sentence.words[candidate_head].head - 1 < span[1] or
                        parsed_sentence.words[candidate_head].head - 1 > span[2]):
                        break
                else:
                    # if none have a head outside the phrase (circular??)
                    # then just take the first word
                    candidate_head = span[1]
                #print("----> %d" % candidate_head)
            candidate_head += word_total
            span_clusters[span[0]].append((span_start, span_end))
            word_clusters[span[0]].append(candidate_head)
            head2span.append((candidate_head, span_start, span_end))
        word_total += len(ontonotes_words)
    span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
    word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
    head2span = sorted(head2span)

    processed = {
        "document_id": doc_id,
        "part_id": part_id,
        "cased_words": cased_words,
        "sent_id": sent_id,
        "speaker": speaker,
        #"pos": pos,
        "deprel": deprel,
        "head": heads,
        "span_clusters": span_clusters,
        "word_clusters": word_clusters,
        "head2span": head2span,
    }
    # NOTE(review): part_id is already included unconditionally above, so this
    # conditional re-assignment is a no-op; the dict contains part_id=None when
    # part_id is None.  Looks like the intent may have been to omit the key -
    # confirm against downstream consumers before changing.
    if part_id is not None:
        processed["part_id"] = part_id
    return processed
|
stanza/stanza/utils/datasets/corenlp_segmenter_dataset.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Output a treebank's sentences in a form that can be processed by the CoreNLP CRF Segmenter
|
| 3 |
+
|
| 4 |
+
Run it as
|
| 5 |
+
python3 -m stanza.utils.datasets.corenlp_segmenter_dataset <treebank>
|
| 6 |
+
such as
|
| 7 |
+
python3 -m stanza.utils.datasets.corenlp_segmenter_dataset UD_Chinese-GSDSimp --output_dir $CHINESE_SEGMENTER_HOME
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import tempfile
|
| 14 |
+
|
| 15 |
+
import stanza.utils.datasets.common as common
|
| 16 |
+
import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
|
| 17 |
+
import stanza.utils.default_paths as default_paths
|
| 18 |
+
|
| 19 |
+
from stanza.models.common.constant import treebank_to_short_name
|
| 20 |
+
|
| 21 |
+
def build_argparse():
    """Build the command-line parser for the segmenter dataset conversion script."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument('treebanks', type=str, nargs='*', default=["UD_Chinese-GSDSimp"], help='Which treebanks to run on')
    argparser.add_argument('--output_dir', type=str, default='.', help='Where to put the results')
    return argparser
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def write_segmenter_file(output_filename, dataset):
    """Write one sentence per line, words space-separated, for the CoreNLP CRF segmenter.

    Args:
        output_filename: path of the file to write
        dataset: iterable of sentences, each a list of raw conllu lines

    Comment lines, blank lines, and MWT ranges (ids containing "-") are
    dropped; each output line is the surface form column joined with spaces.
    """
    # encoding specified explicitly: the data is typically Chinese, which the
    # platform's default locale encoding may not be able to represent
    with open(output_filename, "w", encoding="utf-8") as fout:
        for sentence in dataset:
            sentence = [x for x in sentence if not x.startswith("#")]
            sentence = [x for x in [y.strip() for y in sentence] if x]
            # eliminate MWE, although Chinese currently doesn't have any
            sentence = [x for x in sentence if x.split("\t")[0].find("-") < 0]

            text = " ".join(x.split("\t")[1] for x in sentence)
            fout.write(text)
            fout.write("\n")
|
| 39 |
+
|
| 40 |
+
def process_treebank(treebank, model_type, paths, output_dir):
    """Prepare a treebank's tokenizer data and dump it as segmenter train/test files.

    The intermediate tokenizer data is built in a temporary directory; only
    the two segmenter text files are left in output_dir.  The dev split is
    merged into the train output.
    """
    with tempfile.TemporaryDirectory() as tokenizer_dir:
        short_name = treebank_to_short_name(treebank)

        # first we process the tokenization data, redirected into the temp dir
        tokenizer_paths = dict(paths)
        tokenizer_paths["TOKENIZE_DATA_DIR"] = tokenizer_dir
        tokenizer_args = argparse.Namespace(augment=False, prepare_labels=False)
        prepare_tokenizer_treebank.process_treebank(treebank, model_type, tokenizer_paths, tokenizer_args)

        # TODO: these names should be refactored
        def read_split(split):
            # intermediate files are named <short_name>.<split>.gold.conllu
            return common.read_sentences_from_conllu(f"{tokenizer_dir}/{short_name}.{split}.gold.conllu")

        train_set = read_split("train")
        dev_set = read_split("dev")
        test_set = read_split("test")

        write_segmenter_file(os.path.join(output_dir, f"{short_name}.train.seg.txt"), train_set + dev_set)
        write_segmenter_file(os.path.join(output_dir, f"{short_name}.test.seg.txt"), test_set)
|
| 67 |
+
|
| 68 |
+
def main():
    """Entry point: convert each requested treebank into CoreNLP segmenter files."""
    args = build_argparse().parse_args()

    path_map = default_paths.get_default_paths()
    for treebank in args.treebanks:
        process_treebank(treebank, common.ModelType.TOKENIZER, path_map, args.output_dir)
|
| 75 |
+
|
| 76 |
+
if __name__ == '__main__':
|
| 77 |
+
main()
|
| 78 |
+
|
stanza/stanza/utils/datasets/ner/convert_bsnlp.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import logging
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import stanza
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger('stanza')
|
| 11 |
+
|
| 12 |
+
AVAILABLE_LANGUAGES = ("bg", "cs", "pl", "ru")
|
| 13 |
+
|
| 14 |
+
def normalize_bg_entity(text, entity, raw):
    """Find the form of an annotated Bulgarian entity which actually occurs in the document text.

    Args:
        text: the raw document text the entity should occur in
        entity: the entity string as written in the annotation file
        raw: the raw document's filename, used only for log messages

    Returns the (possibly repaired) entity string when a match is found in
    text; otherwise logs an error and implicitly returns None (callers skip
    None entities).  Repairs tried, in order: whitespace strip, curly-quote
    and curly-apostrophe substitutions, a case-insensitive search, and finally
    a table of known one-off dataset typos.
    """
    entity = entity.strip()
    # sanity check that the token is in the original text
    if text.find(entity) >= 0:
        return entity

    # some entities have quotes, but the quotes are different from those in the data file
    # for example:
    # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_458.txt
    # 'Съвета "Общи въпроси"'
    # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1002.txt
    # 'Съвет "Общи въпроси"'
    if sum(1 for x in entity if x == '"') == 2:
        # both quotes as right curly quotes
        quote_entity = entity.replace('"', '“')
        if text.find(quote_entity) >= 0:
            logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
            return quote_entity

        # opening low quote + closing curly quote
        quote_entity = entity.replace('"', '„', 1).replace('"', '“')
        if text.find(quote_entity) >= 0:
            logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
            return quote_entity

    if sum(1 for x in entity if x == '"') == 1:
        quote_entity = entity.replace('"', '„', 1)
        if text.find(quote_entity) >= 0:
            logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
            return quote_entity

    if entity.find("'") >= 0:
        quote_entity = entity.replace("'", "’")
        if text.find(quote_entity) >= 0:
            logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
            return quote_entity

    # case-insensitive search: return the text's own casing of the match
    lower_idx = text.lower().find(entity.lower())
    if lower_idx >= 0:
        fixed_entity = text[lower_idx:lower_idx+len(entity)]
        logger.info("lowercase match found. Searching for '%s' instead of '%s' in %s" % (fixed_entity, entity, raw))
        return fixed_entity

    # known one-off mismatches between the annotation files and the raw text
    substitution_pairs = {
        # this exact error happens in:
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_67.txt
        'Съвет по общи въпроси': 'Съвета по общи въпроси',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_214.txt
        'Сумимото Мицуи файненшъл груп': 'Сумитомо Мицуи файненшъл груп',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_245.txt
        'С и Д': 'С&Д',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_348.txt
        'законопроекта за излизане на Великобритания за излизане от Европейския съюз': 'законопроекта за излизане на Великобритания от Европейския съюз',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_771.txt
        'Унивеситета в Есекс': 'Университета в Есекс',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_779.txt
        'Съвет за сигурност на ООН': 'Съвета за сигурност на ООН',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_787.txt
        'Федерика Могерини': 'Федереика Могерини',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_938.txt
        'Уайстейбъл': 'Уайтстейбъл',
        'Партията за независимост на Обединеното кралство': 'Партията на независимостта на Обединеното кралство',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_972.txt
        'Европейска банка за възстановяване и развитие': 'Европейската банка за възстановяване и развитие',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1065.txt
        'Харолд Уилсон': 'Харолд Уилсън',
        'Манчестърски университет': 'Манчестърския университет',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1096.txt
        'Обединеното кралство в променящата се Европа': 'Обединеното кралство в променяща се Европа',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1175.txt
        'The Daily Express': 'Daily Express',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1186.txt
        'демократичната юнионистка партия': 'демократична юнионистка партия',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1192.txt
        'Европейската агенция за безопасността на полетите': 'Европейската агенция за сигурността на полетите',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1219.txt
        'пресцентъра на Външно министертво': 'пресцентъра на Външно министерство',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1281.txt
        'Европейска агенциа за безопасността на полетите': 'Европейската агенция за сигурността на полетите',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1415.txt
        'Хонк Конг': 'Хонг Конг',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1663.txt
        'Лейбъристка партия': 'Лейбъристката партия',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1963.txt
        'Найджъл Фараж': 'Найджъл Фарадж',
        'Фараж': 'Фарадж',

        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1773.txt has an entity which is mixed Cyrillic and Ascii
        'Tescо': 'Tesco',
    }

    if entity in substitution_pairs and text.find(substitution_pairs[entity]) >= 0:
        fixed_entity = substitution_pairs[entity]
        logger.info("searching for '%s' instead of '%s' in %s" % (fixed_entity, entity, raw))
        return fixed_entity

    # oops, can't find it anywhere
    # want to raise ValueError but there are just too many in the train set for BG
    logger.error("Could not find '%s' in %s" % (entity, raw))
|
| 111 |
+
|
| 112 |
+
def fix_bg_typos(text, raw_filename):
    """Apply known per-file corrections to the raw Bulgarian document text.

    Args:
        text: the raw document text
        raw_filename: path of the raw file; only its basename is used to look
            up a (wrong, corrected) replacement pair

    Returns the text with the file's known typo (if any) replaced.  Several of
    the "typos" are actually words typed with a mix of Cyrillic and ASCII
    characters that look identical but fail string matching.
    """
    typo_pairs = {
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_202.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters
        'brexit_bg.txt_file_202.txt': ('Вlооmbеrg', 'Bloomberg'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_261.txt has a typo: Telegaph instead of Telegraph
        'brexit_bg.txt_file_261.txt': ('Telegaph', 'Telegraph'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_574.txt has a typo: politicalskrapbook instead of politicalscrapbook
        'brexit_bg.txt_file_574.txt': ('politicalskrapbook', 'politicalscrapbook'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_861.txt has a mix of cyrillic and ascii
        'brexit_bg.txt_file_861.txt': ('Съвета „Общи въпроси“', 'Съветa "Общи въпроси"'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_992.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters
        'brexit_bg.txt_file_992.txt': ('The Guardiаn', 'The Guardian'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1856.txt has a typo: Southerb instead of Southern
        'brexit_bg.txt_file_1856.txt': ('Southerb', 'Southern'),
    }

    filename = os.path.split(raw_filename)[1]
    if filename in typo_pairs:
        replacement = typo_pairs.get(filename)
        text = text.replace(replacement[0], replacement[1])

    return text
|
| 134 |
+
|
| 135 |
+
def get_sentences(language, pipeline, annotated, raw):
    """Tokenize one raw BSNLP document and project its entity annotations onto the tokens.

    Args:
        language: language code; only 'bg' currently has the required
            normalize/typo helpers
        pipeline: a stanza tokenization pipeline for the language
        annotated: path to the annotation file (entity TAB ... TAB type lines)
        raw: path to the raw text file (4 header lines, then the text)

    Returns the list of tokenized sentences whose tokens carry B-/I- NER tags,
    excluding sentences where an entity match could not be aligned cleanly
    with token boundaries.
    """
    if language == 'bg':
        normalize_entity = normalize_bg_entity
        fix_typos = fix_bg_typos
    else:
        raise AssertionError("Please build a normalize_%s_entity and fix_%s_typos first" % language)

    annotated_sentences = []
    with open(raw) as fin:
        lines = fin.readlines()
        if len(lines) < 5:
            raise ValueError("Unexpected format in %s" % raw)
        # the first 4 lines are document metadata, not text
        text = "\n".join(lines[4:])
        text = fix_typos(text, raw)

    # entity surface form -> entity type
    entities = {}
    with open(annotated) as fin:
        # first line
        header = fin.readline().strip()
        if len(header.split("\t")) > 1:
            raise ValueError("Unexpected missing header line in %s" % annotated)
        for line in fin:
            pieces = line.strip().split("\t")
            if len(pieces) < 3 or len(pieces) > 4:
                raise ValueError("Unexpected annotation format in %s" % annotated)

            entity = normalize_entity(text, pieces[0], raw)
            if not entity:
                # entity could not be located in the text at all; already logged
                continue
            if entity in entities:
                if entities[entity] != pieces[2]:
                    # would like to make this an error, but it actually happens and it's not clear how to fix
                    # annotated/nord_stream/bg/nord_stream_bg.txt_file_119.out
                    logger.warn("found multiple definitions for %s in %s" % (pieces[0], annotated))
                    entities[entity] = pieces[2]
            else:
                entities[entity] = pieces[2]

    tokenized = pipeline(text)
    # The benefit of doing these one at a time, instead of all at once,
    # is that nested entities won't clobber previously labeled entities.
    # For example, the file
    # training_pl_cs_ru_bg_rc1/annotated/bg/brexit_bg.txt_file_994.out
    # has each of:
    # Северна Ирландия
    # Република Ирландия
    # Ирландия
    # By doing the larger ones first, we can detect and skip the ones
    # we already labeled when we reach the shorter one
    regexes = [re.compile(re.escape(x)) for x in sorted(entities.keys(), key=len, reverse=True)]

    # ids of sentences with unalignable matches; dropped from the output
    bad_sentences = set()

    for regex in regexes:
        for match in regex.finditer(text):
            start_char, end_char = match.span()
            # this is inefficient, but for something only run once, it shouldn't matter
            start_token = None
            start_sloppy = False
            end_token = None
            end_sloppy = False
            # find the tokens covering the match's character span;
            # "sloppy" means the match boundary falls inside a token
            for token in tokenized.iter_tokens():
                if token.start_char <= start_char and token.end_char > start_char:
                    start_token = token
                    if token.start_char != start_char:
                        start_sloppy = True
                if token.start_char <= end_char and token.end_char >= end_char:
                    end_token = token
                    if token.end_char != end_char:
                        end_sloppy = True
                    break
            if start_token is None or end_token is None:
                raise RuntimeError("Match %s did not align with any tokens in %s" % (match.group(0), raw))
            if not start_token.sent is end_token.sent:
                bad_sentences.add(start_token.sent.id)
                bad_sentences.add(end_token.sent.id)
                logger.warn("match %s spanned sentences %d and %d in document %s" % (match.group(0), start_token.sent.id, end_token.sent.id, raw))
                continue

            # ids start at 1, not 0, so we have to subtract 1
            # then the end token is included, so we add back the 1
            # TODO: verify that this is correct if the language has MWE - cs, pl, for example
            tokens = start_token.sent.tokens[start_token.id[0]-1:end_token.id[0]]
            if all(token.ner for token in tokens):
                # skip matches which have already been made
                # this has the nice side effect of not complaining if
                # a smaller match is found after a larger match
                # earlier set the NER on those tokens
                continue

            if start_sloppy and end_sloppy:
                bad_sentences.add(start_token.sent.id)
                logger.warn("match %s matched in the middle of a token in %s" % (match.group(0), raw))
                continue
            if start_sloppy:
                bad_sentences.add(end_token.sent.id)
                logger.warn("match %s started matching in the middle of a token in %s" % (match.group(0), raw))
                #print(start_token)
                #print(end_token)
                #print(start_char, end_char)
                continue
            if end_sloppy:
                bad_sentences.add(start_token.sent.id)
                logger.warn("match %s ended matching in the middle of a token in %s" % (match.group(0), raw))
                #print(start_token)
                #print(end_token)
                #print(start_char, end_char)
                continue
            match_text = match.group(0)
            if match_text not in entities:
                raise RuntimeError("Matched %s, which is not in the entities from %s" % (match_text, annotated))
            # tag the aligned tokens in BIO format
            ner_tag = entities[match_text]
            tokens[0].ner = "B-" + ner_tag
            for token in tokens[1:]:
                token.ner = "I-" + ner_tag

    for sentence in tokenized.sentences:
        if not sentence.id in bad_sentences:
            annotated_sentences.append(sentence)

    return annotated_sentences
|
| 256 |
+
|
| 257 |
+
def write_sentences(output_filename, annotated_sentences):
    """Write annotated sentences as a two-column (text TAB ner) file.

    Tokens with no NER annotation are written with the "O" tag; each
    sentence is followed by a blank line.
    """
    logger.info("Writing %d sentences to %s" % (len(annotated_sentences), output_filename))
    with open(output_filename, "w") as fout:
        for sentence in annotated_sentences:
            for token in sentence.tokens:
                # an unset .ner means the token was not part of any match
                fout.write("%s\t%s\n" % (token.text, token.ner or "O"))
            fout.write("\n")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def convert_bsnlp(language, base_input_path, output_filename, split_filename=None):
    """
    Converts the BSNLP dataset for the given language.

    If only one output_filename is provided, all of the output goes to that file.
    If split_filename is provided as well, 15% of the output chosen randomly
    goes there instead. The dataset has no dev set, so this helps
    divide the data into train/dev/test.
    Note that the custom error fixes are only done for BG currently.
    Please manually correct the data as appropriate before using this
    for another language.
    """
    if language not in AVAILABLE_LANGUAGES:
        raise ValueError("The current BSNLP datasets only include the following languages: %s" % ",".join(AVAILABLE_LANGUAGES))
    if language != "bg":
        raise ValueError("There were quite a few data fixes needed to get the data correct for BG. Please work on similar fixes before using the model for %s" % language.upper())
    # tokenize-only pipeline: the raw text must be tokenized before aligning matches
    pipeline = stanza.Pipeline(language, processors="tokenize")
    # fixed seed so the train/split division is reproducible across runs
    random.seed(1234)

    annotated_path = os.path.join(base_input_path, "annotated", "*", language, "*")
    annotated_files = sorted(glob.glob(annotated_path))
    raw_path = os.path.join(base_input_path, "raw", "*", language, "*")
    raw_files = sorted(glob.glob(raw_path))

    # if the instructions for downloading the data from the
    # process_ner_dataset script are followed, there will be two test
    # directories of data and a separate training directory of data.
    if len(annotated_files) == 0 and len(raw_files) == 0:
        # fall back to a flat layout with no intermediate subdirectory
        logger.info("Could not find files in %s" % annotated_path)
        annotated_path = os.path.join(base_input_path, "annotated", language, "*")
        logger.info("Trying %s instead" % annotated_path)
        annotated_files = sorted(glob.glob(annotated_path))
        raw_path = os.path.join(base_input_path, "raw", language, "*")
        raw_files = sorted(glob.glob(raw_path))

    # annotated and raw files must pair up one-to-one
    if len(annotated_files) != len(raw_files):
        raise ValueError("Unexpected differences in the file lists between %s and %s" % (annotated_files, raw_files))

    # ... and the pairing is checked by basename minus the 4-char extension
    for i, j in zip(annotated_files, raw_files):
        if os.path.split(i)[1][:-4] != os.path.split(j)[1][:-4]:
            raise ValueError("Unexpected differences in the file lists: found %s instead of %s" % (i, j))

    annotated_sentences = []
    if split_filename:
        split_sentences = []
    for annotated, raw in zip(annotated_files, raw_files):
        new_sentences = get_sentences(language, pipeline, annotated, raw)
        # ~85% of files' sentences go to the main output, ~15% to the split file
        if not split_filename or random.random() < 0.85:
            annotated_sentences.extend(new_sentences)
        else:
            split_sentences.extend(new_sentences)

    write_sentences(output_filename, annotated_sentences)
    if split_filename:
        write_sentences(split_filename, split_sentences)
|
| 324 |
+
|
| 325 |
+
if __name__ == '__main__':
    # command line interface; defaults point at the author's local paths
    parser = argparse.ArgumentParser()
    parser.add_argument('--language', type=str, default="bg", help="Language to process")
    parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/bsnlp2019", help="Where to find the files")
    parser.add_argument('--output_path', type=str, default="/home/john/stanza/data/ner/bg_bsnlp.test.csv", help="Where to output the results")
    parser.add_argument('--dev_path', type=str, default=None, help="A secondary output path - 15% of the data will go here")
    args = parser.parse_args()

    convert_bsnlp(args.language, args.input_path, args.output_path, args.dev_path)
|
stanza/stanza/utils/datasets/ner/convert_fire_2013.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Converts the FIRE 2013 dataset to TSV
|
| 3 |
+
|
| 4 |
+
http://au-kbc.org/nlp/NER-FIRE2013/index.html
|
| 5 |
+
|
| 6 |
+
The dataset is in six tab separated columns. The columns are
|
| 7 |
+
|
| 8 |
+
word tag chunk ner1 ner2 ner3
|
| 9 |
+
|
| 10 |
+
This script keeps just the word and the ner1. It is quite possible that using the tag would help
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import glob
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
|
| 18 |
+
def normalize(e1, e2, e3):
    """Collapse the three stacked FIRE NER labels into a single label.

    Keeps a curated subset of top-level entity types, merges the
    DATE/TIME/YEAR types into a single DATETIME tag, and maps
    everything else to "O".  Raises ValueError when the stacked
    labels are inconsistent with each other.
    """
    if e1 == 'o':
        return "O"

    # the B-/I- position prefixes of stacked labels must agree
    if e2 != 'o' and e1[:2] != e2[:2]:
        raise ValueError("Found a token with conflicting position tags %s,%s" % (e1, e2))
    if e3 != 'o' and e2 == 'o':
        raise ValueError("Found a token with tertiary label but no secondary label %s,%s,%s" % (e1, e2, e3))
    if e3 != 'o' and (e1[:2] != e2[:2] or e1[:2] != e3[:2]):
        raise ValueError("Found a token with conflicting position tags %s,%s,%s" % (e1, e2, e3))

    prefix, entity = e1[:2], e1[2:]
    secondary = e2[2:]

    # entity types kept as-is, some only in specific secondary-label contexts
    keep = (entity in ('ORGANIZATION', 'FACILITIES')
            or (entity == 'ENTERTAINMENT' and secondary not in ('SPORTS', 'CINEMA'))
            or (entity == 'DISEASE' and e2 == 'o')
            or (entity == 'PLANTS' and secondary != 'PARTS')
            or (entity == 'PERSON' and secondary == 'INDIVIDUAL')
            or (entity == 'LOCATION' and secondary == 'PLACE'))
    if keep:
        return e1
    if entity in ('DATE', 'TIME', 'YEAR'):
        return prefix + 'DATETIME'
    return "O"
|
| 46 |
+
|
| 47 |
+
def read_fileset(filenames):
    """Read raw FIRE data files and return a list of sentences.

    Each sentence is a list of raw (still tab-separated) lines.
    Sentences are separated by blank lines; one-line "sentences",
    which are frequent in this dataset, are discarded.
    """
    sentences = []
    for filename in filenames:
        current = []
        with open(filename) as fin:
            for raw_line in fin:
                raw_line = raw_line.strip()
                if raw_line:
                    current.append(raw_line)
                    continue
                # blank line ends a sentence; drop the single-line ones
                if len(current) > 1:
                    sentences.append(current)
                current = []
        # flush a trailing sentence if the file did not end with a blank line
        if len(current) > 1:
            sentences.append(current)
    return sentences
|
| 66 |
+
|
| 67 |
+
def write_fileset(output_csv_file, sentences):
    """Write sentences as two-column TSV: word and the normalized first NER layer.

    Each input line must have exactly 6 tab-separated columns
    (word, tag, chunk, ner1, ner2, ner3); raises ValueError otherwise.
    """
    with open(output_csv_file, "w") as fout:
        for sentence in sentences:
            for line in sentence:
                pieces = line.split("\t")
                if len(pieces) != 6:
                    raise ValueError("Found %d pieces instead of the expected 6" % len(pieces))
                word, ner1, ner2, ner3 = pieces[0], pieces[3], pieces[4], pieces[5]
                # sanity check: inner layers may only be set when the top layer is
                if ner1 == 'o' and (ner2 != 'o' or ner3 != 'o'):
                    raise ValueError("Inner NER labeled but the top layer was O")
                fout.write("%s\t%s\n" % (word, normalize(ner1, ner2, ner3)))
            fout.write("\n")
|
| 78 |
+
|
| 79 |
+
def convert_fire_2013(input_path, train_csv_file, dev_csv_file, test_csv_file):
    """Convert the FIRE 2013 directory into train/dev/test TSV files.

    Sentences from all files are pooled, shuffled with a fixed seed for
    reproducibility, and split 80/10/10.
    """
    random.seed(1234)

    filenames = glob.glob(os.path.join(input_path, "*"))

    # won't be numerically sorted... shouldn't matter
    # (sorting first makes the shuffle deterministic regardless of glob order)
    filenames = sorted(filenames)
    random.shuffle(filenames)

    sentences = read_fileset(filenames)
    random.shuffle(sentences)

    # 80% train, 10% dev, 10% test
    train_cutoff = int(0.8 * len(sentences))
    dev_cutoff = int(0.9 * len(sentences))

    train_sentences = sentences[:train_cutoff]
    dev_sentences = sentences[train_cutoff:dev_cutoff]
    test_sentences = sentences[dev_cutoff:]

    random.shuffle(train_sentences)
    random.shuffle(dev_sentences)
    random.shuffle(test_sentences)

    # a split this small would mean the input directory was essentially empty
    assert len(train_sentences) > 0
    assert len(dev_sentences) > 0
    assert len(test_sentences) > 0

    write_fileset(train_csv_file, train_sentences)
    write_fileset(dev_csv_file, dev_sentences)
    write_fileset(test_csv_file, test_sentences)
|
| 109 |
+
|
| 110 |
+
if __name__ == '__main__':
    # command line interface; defaults point at the author's local paths
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/FIRE2013/hindi_train", help="Directory with raw files to read")
    parser.add_argument('--train_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.train.csv", help="Where to put the train file")
    parser.add_argument('--dev_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.dev.csv", help="Where to put the dev file")
    parser.add_argument('--test_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.test.csv", help="Where to put the test file")
    args = parser.parse_args()

    convert_fire_2013(args.input_path, args.train_file, args.dev_file, args.test_file)
|
stanza/stanza/utils/datasets/ner/convert_hy_armtdp.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convert a ArmTDP-NER dataset to BIO format
|
| 3 |
+
|
| 4 |
+
The dataset is here:
|
| 5 |
+
|
| 6 |
+
https://github.com/myavrum/ArmTDP-NER.git
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
import json
|
| 12 |
+
import re
|
| 13 |
+
import stanza
|
| 14 |
+
import random
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
from stanza import DownloadMethod, Pipeline
|
| 18 |
+
import stanza.utils.default_paths as default_paths
|
| 19 |
+
|
| 20 |
+
def read_data(path: str) -> list:
    """
    Reads the Armenian named entity recognition dataset

    The file is in json-lines format: one json object per line.
    Returns a list of dictionaries, each describing a paragraph
    (text, labels, etc.)
    """
    with open(path, 'r') as fin:
        return [json.loads(raw_line) for raw_line in fin]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def filter_unicode_broken_characters(text: str) -> str:
    """
    Removes literal backslash-u escape sequences (e.g. "\\u0531")
    left behind in the text by the annotation process
    """
    broken = re.compile(r'\\u[A-Za-z0-9]{4}')
    return broken.sub('', text)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_label(tok_start_char: int, tok_end_char: int, labels: list) -> list:
    """
    Returns the first label [start_char, end_char, tag] which fully
    covers the given token span, or [] when no label covers it
    """
    return next((lbl for lbl in labels
                 if lbl[0] <= tok_start_char and tok_end_char <= lbl[1]),
                [])
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def format_sentences(paragraphs: list, nlp_hy: Pipeline) -> list:
    """
    Takes a list of paragraphs and returns a list of sentences,
    where each sentence is a list of tokens along with their respective entity tags.

    Tokens inside a labeled span are accumulated until the token whose
    end_char matches the span's end offset, then emitted as one entity dict
    {'tokens': [...], 'tag': ...}.  Unlabeled tokens become single-token
    'O' entries.
    """
    sentences = []
    for paragraph in tqdm(paragraphs):
        doc = nlp_hy(filter_unicode_broken_characters(paragraph['text']))
        for sentence in doc.sentences:
            sentence_ents = []
            entity = []  # tokens of the entity span currently being collected
            for token in sentence.tokens:
                label = get_label(token.start_char, token.end_char, paragraph['labels'])
                if label:
                    entity.append(token.text)
                    # the span is complete when this token reaches the label end
                    if token.end_char == label[1]:
                        sentence_ents.append({'tokens': entity,
                                              'tag': label[2]})
                        entity = []
                else:
                    sentence_ents.append({'tokens': [token.text],
                                          'tag': 'O'})
            # NOTE(review): an entity still being collected at the end of a
            # sentence is silently dropped - presumably spans never cross
            # sentence boundaries; confirm against the dataset
            sentences.append(sentence_ents)
    return sentences
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def convert_to_bioes(sentences: list) -> list:
    """
    Returns a list of strings where each string represents a sentence in BIOES format

    Single token entities are tagged S-, longer entities B- / I- / E-,
    and 'O' entries stay O.  Each output string has one "token\ttag\n"
    row per token.
    """
    formatted = []
    for sentence in tqdm(sentences):
        rows = []
        for ent in sentence:
            tag = ent['tag']
            toks = ent['tokens']
            if tag == 'O':
                rows.append(toks[0] + '\tO')
            elif len(toks) == 1:
                rows.append(toks[0] + '\tS-' + tag)
            else:
                rows.append(toks[0] + '\tB-' + tag)
                rows.extend(tok + '\tI-' + tag for tok in toks[1:-1])
                rows.append(toks[-1] + '\tE-' + tag)
        formatted.append(''.join(row + '\n' for row in rows))
    return formatted
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def write_sentences_to_file(sents, filename):
    """Write the BIOES-formatted sentence strings to filename.

    Each sentence string is followed by a blank line separator.
    """
    # bug fix: this used to print the literal text "(unknown)" instead of
    # the destination path, which made the progress message useless
    print(f"Writing {len(sents)} sentences to {filename}")
    with open(filename, 'w') as outfile:
        for sent in sents:
            outfile.write(sent + '\n\n')
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def train_test_dev_split(sents, base_output_path, short_name, train_fraction=0.7, dev_fraction=0.15):
    """
    Splits a list of sentences into training, dev, and test sets,
    and writes each set to a separate file with write_sentences_to_file

    The sentence list is shuffled in place before splitting.

    Raises ValueError if train_fraction + dev_fraction exceed 1.
    """
    # validate before doing any work.  bug fix: the original format string
    # had three placeholders but only two arguments, so reaching this branch
    # raised IndexError instead of the intended ValueError
    if train_fraction + dev_fraction > 1.0:
        raise ValueError(
            "Train and dev fractions added up to more than 1: {} {}".format(train_fraction, dev_fraction))

    num = len(sents)
    train_num = int(num * train_fraction)
    dev_num = int(num * dev_fraction)

    random.shuffle(sents)
    train_sents = sents[:train_num]
    dev_sents = sents[train_num:train_num + dev_num]
    test_sents = sents[train_num + dev_num:]
    batches = [train_sents, dev_sents, test_sents]
    filenames = [f'{short_name}.train.tsv', f'{short_name}.dev.tsv', f'{short_name}.test.tsv']
    for batch, filename in zip(batches, filenames):
        write_sentences_to_file(batch, os.path.join(base_output_path, filename))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def convert_dataset(base_input_path, base_output_path, short_name, download_method=DownloadMethod.DOWNLOAD_RESOURCES):
    """Run the full ArmTDP-NER conversion: read, tokenize, tag, split, and write."""
    # Armenian tokenization pipeline; download_method controls model fetching
    nlp_hy = stanza.Pipeline(lang='hy', processors='tokenize', download_method=download_method)
    paragraphs = read_data(os.path.join(base_input_path, 'ArmNER-HY.json1'))
    tagged_sentences = format_sentences(paragraphs, nlp_hy)
    beios_sentences = convert_to_bioes(tagged_sentences)
    train_test_dev_split(beios_sentences, base_output_path, short_name)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == '__main__':
    paths = default_paths.get_default_paths()

    # command line interface; defaults use the standard stanza data layout
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default=os.path.join(paths["NERBASE"], "armenian", "ArmTDP-NER"), help="Path to input file")
    parser.add_argument('--output_path', type=str, default=paths["NER_DATA_DIR"], help="Path to the output directory")
    parser.add_argument('--short_name', type=str, default="hy_armtdp", help="Name to identify the dataset and the model")
    parser.add_argument('--download_method', type=str, default=DownloadMethod.DOWNLOAD_RESOURCES, help="Download method for initializing the Pipeline. Default downloads the Armenian pipeline, --download_method NONE does not. Options: %s" % DownloadMethod._member_names_)
    args = parser.parse_args()

    convert_dataset(args.input_path, args.output_path, args.short_name, args.download_method)
|
stanza/stanza/utils/datasets/ner/convert_kk_kazNERD.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convert a Kazakh NER dataset to our internal .json format
|
| 3 |
+
The dataset is here:
|
| 4 |
+
|
| 5 |
+
https://github.com/IS2AI/KazNERD/tree/main/KazNERD
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import os
|
| 10 |
+
import shutil
|
| 11 |
+
# import random
|
| 12 |
+
|
| 13 |
+
from stanza.utils.datasets.ner.utils import convert_bio_to_json, SHARDS
|
| 14 |
+
|
| 15 |
+
def convert_dataset(in_directory, out_directory, short_name):
    """
    Reads in train, validation, and test data and converts them to .json file

    The IOB2 files are copied to the output directory under the standard
    shard names, then converted with convert_bio_to_json.
    """
    source_names = ("IOB2_train.txt", "IOB2_valid.txt", "IOB2_test.txt")
    for shard, source_name in zip(SHARDS, source_names):
        src = os.path.join(in_directory, source_name)
        dest = os.path.join(out_directory, "%s.%s.bio" % (short_name, shard))
        shutil.copy(src, dest)
    convert_bio_to_json(out_directory, out_directory, short_name, "bio")
|
| 25 |
+
|
| 26 |
+
if __name__ == '__main__':
    # command line interface; defaults point at the author's cluster paths
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default="/nlp/scr/aaydin/kazNERD/NER", help="Where to find the files")
    parser.add_argument('--output_path', type=str, default="/nlp/scr/aaydin/kazNERD/data/ner", help="Where to output the results")
    args = parser.parse_args()
    # in_path = '/nlp/scr/aaydin/kazNERD/NER'
    # out_path = '/nlp/scr/aaydin/kazNERD/NER/output'
    # convert_dataset(in_path, out_path)
    convert_dataset(args.input_path, args.output_path, "kk_kazNERD")
|
| 35 |
+
|
stanza/stanza/utils/datasets/ner/convert_my_ucsy.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Processes the three pieces of the NER dataset we received from UCSY.
|
| 3 |
+
|
| 4 |
+
Requires the Myanmar tokenizer to exist, since the text is not already tokenized.
|
| 5 |
+
|
| 6 |
+
There are three files sent to us from UCSY, one each for train, dev, test
|
| 7 |
+
This script expects them to be in the ner directory with the names
|
| 8 |
+
$NERBASE/my_ucsy/Myanmar_NER_train.txt
|
| 9 |
+
$NERBASE/my_ucsy/Myanmar_NER_dev.txt
|
| 10 |
+
$NERBASE/my_ucsy/Myanmar_NER_test.txt
|
| 11 |
+
|
| 12 |
+
The files are in the following format:
|
| 13 |
+
unsegmentedtext@LABEL|unsegmentedtext@LABEL|...
|
| 14 |
+
with one sentence per line
|
| 15 |
+
|
| 16 |
+
Solution:
|
| 17 |
+
- break the text up into fragments by splitting on |
|
| 18 |
+
- extract the labels
|
| 19 |
+
- segment each block of text using the MY tokenizer
|
| 20 |
+
|
| 21 |
+
We could take two approaches to breaking up the blocks. One would be
|
| 22 |
+
to combine all chunks, then segment an entire sentence at once. This
|
| 23 |
+
would require some logic to re-chunk the resulting pieces. Instead,
|
| 24 |
+
we resegment each individual chunk by itself. This loses the
|
| 25 |
+
information from the neighboring chunks, but guarantees there are no
|
| 26 |
+
screwups where segmentation crosses segment boundaries and is simpler
|
| 27 |
+
to code.
|
| 28 |
+
|
| 29 |
+
Of course, experimenting with the alternate approach might be better.
|
| 30 |
+
|
| 31 |
+
There is one stray label of SB in the training data, so we throw out
|
| 32 |
+
that entire sentence.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
import os
|
| 37 |
+
|
| 38 |
+
from tqdm import tqdm
|
| 39 |
+
import stanza
|
| 40 |
+
from stanza.utils.datasets.ner.check_for_duplicates import check_for_duplicates
|
| 41 |
+
|
| 42 |
+
SPLITS = ("train", "dev", "test")
|
| 43 |
+
|
| 44 |
+
def convert_file(input_filename, output_filename, pipe):
    """Convert one UCSY file (chunk@LABEL|chunk@LABEL|... per line) to BIO format.

    Each chunk is tokenized with the given pipeline; the chunk's label is
    emitted as B- on its first word and I- on the rest (O stays O).
    Sentences containing the stray SB label are skipped entirely.
    """
    with open(input_filename) as fin:
        lines = fin.readlines()

    all_labels = set()

    with open(output_filename, "w") as fout:
        for line in tqdm(lines):
            pieces = line.split("|")
            texts = []
            labels = []
            skip_sentence = False
            for piece in pieces:
                piece = piece.strip()
                if not piece:
                    continue
                # rsplit: the chunk text itself may contain @, the label cannot
                text, label = piece.rsplit("@", maxsplit=1)
                text = text.strip()
                if not text:
                    continue
                if label == 'SB':
                    # one stray SB label in the training data - drop the sentence
                    skip_sentence = True
                    break

                texts.append(text)
                labels.append(label)

            if skip_sentence:
                continue

            # double newlines separate the chunks; with tokenize_no_ssplit
            # the pipeline treats each chunk as exactly one sentence
            text = "\n\n".join(texts)
            doc = pipe(text)
            assert len(doc.sentences) == len(texts)
            for sentence, label in zip(doc.sentences, labels):
                all_labels.add(label)
                for word_idx, word in enumerate(sentence.words):
                    if label == "O":
                        output_label = "O"
                    elif word_idx == 0:
                        output_label = "B-" + label
                    else:
                        output_label = "I-" + label

                    fout.write("%s\t%s\n" % (word.text, output_label))
            fout.write("\n\n")

    print("Finished processing {} Labels found: {}".format(input_filename, sorted(all_labels)))
|
| 91 |
+
|
| 92 |
+
def convert_my_ucsy(base_input_path, base_output_path):
    """Convert the three UCSY Myanmar NER files (train/dev/test) to .bio files."""
    os.makedirs(base_output_path, exist_ok=True)
    # tokenize_no_ssplit: each text chunk is treated as exactly one sentence,
    # so resegmentation cannot cross chunk boundaries (see convert_file)
    pipe = stanza.Pipeline("my", processors="tokenize", tokenize_no_ssplit=True)
    output_filenames = [os.path.join(base_output_path, "my_ucsy.%s.bio" % split) for split in SPLITS]

    for split, output_filename in zip(SPLITS, output_filenames):
        input_filename = os.path.join(base_input_path, "Myanmar_NER_%s.txt" % split)
        if not os.path.exists(input_filename):
            raise FileNotFoundError("Necessary file for my_ucsy does not exist: %s" % input_filename)

        convert_file(input_filename, output_filename, pipe)
|
stanza/stanza/utils/datasets/ner/convert_sindhi_siner.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Converts the raw data from SiNER to .json for the Stanza NER system
|
| 3 |
+
|
| 4 |
+
https://aclanthology.org/2020.lrec-1.361.pdf
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from stanza.utils.datasets.ner.utils import write_dataset
|
| 8 |
+
|
| 9 |
+
def fix_sentence(sentence):
    """
    Fix some of the mistags in the dataset

    This covers 11 sentences: 1 P-PERSON, 2 with line breaks in the middle of the tag, and 8 with no B- or I-
    """
    fixed = []
    for word in sentence:
        text, tag = word[0], word[1]
        if tag == 'P-PERSON':
            fixed.append((text, 'B-PERSON'))
        elif tag == 'B-OT"':
            fixed.append((text, 'B-OTHERS'))
        elif tag == 'B-T"':
            fixed.append((text, 'B-TITLE'))
        elif tag in ('GPE', 'LOC', 'OTHERS'):
            prev_tag = fixed[-1][1] if fixed else ''
            if prev_tag[:2] in ('B-', 'I-') and prev_tag[2:] == tag:
                # one example... no idea if it should be a break or
                # not, but the last word translates to "Corporation",
                # so probably not: ميٽرو پوليٽن ڪارپوريشن
                fixed.append((text, 'I-' + tag))
            else:
                fixed.append((text, 'B-' + tag))
        else:
            fixed.append(word)
    return fixed
|
| 34 |
+
|
| 35 |
+
def convert_sindhi_siner(in_filename, out_directory, short_name, train_frac=0.8, dev_frac=0.1):
    """
    Read lines from the dataset, crudely separate sentences based on . or !, and write the dataset
    """
    with open(in_filename, encoding="utf-8") as fin:
        lines = fin.readlines()

    # keep only well-formed "word<TAB>tag" lines
    lines = [x.strip().split("\t") for x in lines]
    lines = [(x[0].strip(), x[1].strip()) for x in lines if len(x) == 2]
    print("Read %d words from %s" % (len(lines), in_filename))
    sentences = []
    prev_idx = 0
    for sent_idx, line in enumerate(lines):
        # maybe also handle line[0] == '،', "Arabic comma"?
        if line[0] in ('.', '!'):
            sentences.append(lines[prev_idx:sent_idx+1])
            prev_idx=sent_idx+1

    # in case the file doesn't end with punctuation, grab the last few lines
    if prev_idx < len(lines):
        sentences.append(lines[prev_idx:])

    print("Found %d sentences before splitting" % len(sentences))
    sentences = [fix_sentence(x) for x in sentences]
    # verify fix_sentence removed every malformed tag it is meant to cover
    assert not any('"' in x[1] or x[1].startswith("P-") or x[1] in ("GPE", "LOC", "OTHERS") for sentence in sentences for x in sentence)

    # sequential split: first train_frac train, next dev_frac dev, rest test
    train_len = int(len(sentences) * train_frac)
    dev_len = int(len(sentences) * (train_frac+dev_frac))
    train_sentences = sentences[:train_len]
    dev_sentences = sentences[train_len:dev_len]
    test_sentences = sentences[dev_len:]

    datasets = (train_sentences, dev_sentences, test_sentences)
    write_dataset(datasets, out_directory, short_name, suffix="bio")
|
| 69 |
+
|
stanza/stanza/utils/datasets/ner/convert_starlang_ner.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convert the starlang trees to a NER dataset
|
| 3 |
+
|
| 4 |
+
Has to hide quite a few trees with missing NER labels
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
from stanza.models.constituency import tree_reader
|
| 10 |
+
import stanza.utils.datasets.constituency.convert_starlang as convert_starlang
|
| 11 |
+
|
| 12 |
+
TURKISH_WORD_RE = re.compile(r"[{]turkish=([^}]+)[}]")
|
| 13 |
+
TURKISH_LABEL_RE = re.compile(r"[{]namedEntity=([^}]+)[}]")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def read_tree(text):
    """
    Reads in a tree, then extracts the word and the NER

    One problem is that it is unknown if there are cases of two separate items occurring consecutively

    Note that this is quite similar to the convert_starlang script for constituency.

    Returns a list of (word, tag) pairs, one per leaf of the tree.
    Raises ValueError if the input contains more than one tree or a
    leaf label is missing the word or NER annotation.
    """
    trees = tree_reader.read_trees(text)
    if len(trees) > 1:
        raise ValueError("Tree file had two trees!")
    tree = trees[0]
    words = []
    for label in tree.leaf_labels():
        # leaf labels carry annotations like {turkish=WORD}...{namedEntity=TAG}
        match = TURKISH_WORD_RE.search(label)
        if match is None:
            raise ValueError("Could not find word in |{}|".format(label))
        word = match.group(1)
        # restore braces that were escaped in the treebank
        word = word.replace("-LCB-", "{").replace("-RCB-", "}")

        match = TURKISH_LABEL_RE.search(label)
        if match is None:
            raise ValueError("Could not find ner in |{}|".format(label))
        tag = match.group(1)
        # both NONE and null mark unlabeled tokens
        if tag == 'NONE' or tag == "null":
            tag = 'O'
        words.append((word, tag))

    return words
|
| 46 |
+
|
| 47 |
+
def read_starlang(paths):
    """Read all starlang trees under the given paths, converting each to (word, tag) pairs."""
    return convert_starlang.read_starlang(paths, conversion=read_tree, log=False)
|
| 49 |
+
|
| 50 |
+
def main():
    """Convert the starlang treebank splits to a NER dataset."""
    # the splits and any output are handled by the constituency conversion script
    train, dev, test = convert_starlang.main(conversion=read_tree, log=False)

if __name__ == '__main__':
    main()
|
| 55 |
+
|
stanza/stanza/utils/datasets/ner/ontonotes_multitag.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Combines OntoNotes and WW into a single dataset with OntoNotes used for dev & test
|
| 3 |
+
|
| 4 |
+
The resulting dataset has two layers saved in the multi_ner column.
|
| 5 |
+
|
| 6 |
+
WW is kept as 9 classes, with the tag put in either the first or
|
| 7 |
+
second layer depending on the flags.
|
| 8 |
+
|
| 9 |
+
OntoNotes is converted to one column for 18 and one column for 9 classes.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import shutil
|
| 16 |
+
|
| 17 |
+
from stanza.utils import default_paths
|
| 18 |
+
from stanza.utils.datasets.ner.utils import combine_files
|
| 19 |
+
from stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide
|
| 20 |
+
|
| 21 |
+
def convert_ontonotes_file(filename, simplify, bigger_first):
    """
    Rewrite an en_ontonotes json file with a two layer multi_ner field.

    Each word's 18 class tag is paired with either the simplified 8 class
    Worldwide tag (when simplify is True) or a "-" placeholder.
    bigger_first controls whether the 18 class layer comes first.
    The output is written next to the input with "-multi" in the name.
    """
    assert "en_ontonotes" in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)

    with open(filename) as fin:
        document = json.load(fin)

    for sent in document:
        for token in sent:
            tag = token['ner']
            second = simplify_ontonotes_to_worldwide(tag) if simplify else "-"
            token['multi_ner'] = (tag, second) if bigger_first else (second, tag)

    out_filename = filename.replace("en_ontonotes", "en_ontonotes-multi")
    with open(out_filename, "w") as fout:
        json.dump(document, fout, indent=2)
|
| 44 |
+
|
| 45 |
+
def convert_worldwide_file(filename, bigger_first):
    """
    Rewrite an en_worldwide-9class json file with a two layer multi_ner field.

    The Worldwide data only has the 9 class annotation, so the other
    layer is always a "-" placeholder.  When bigger_first is True the
    (absent) 18 class layer is first.  Output goes next to the input
    with "-multi" added to the name.
    """
    assert "en_worldwide-9class" in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)

    out_filename = filename.replace("en_worldwide-9class", "en_worldwide-9class-multi")

    with open(filename) as fin:
        document = json.load(fin)

    for sent in document:
        for token in sent:
            tag = token['ner']
            token['multi_ner'] = ("-", tag) if bigger_first else (tag, "-")

    with open(out_filename, "w") as fout:
        json.dump(document, fout, indent=2)
|
| 65 |
+
|
| 66 |
+
def build_multitag_dataset(base_output_path, short_name, simplify, bigger_first):
    """
    Build the combined OntoNotes + Worldwide multitag dataset.

    All six source shards are rewritten with multi_ner layers, the train
    shards are concatenated into short_name.train.json, and the dev/test
    shards come from OntoNotes only.
    """
    for shard in ("train", "dev", "test"):
        convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.%s.json" % shard), simplify, bigger_first)

    for shard in ("train", "dev", "test"):
        convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.%s.json" % shard), bigger_first)

    # train combines both sources
    combine_files(os.path.join(base_output_path, "%s.train.json" % short_name),
                  os.path.join(base_output_path, "en_ontonotes-multi.train.json"),
                  os.path.join(base_output_path, "en_worldwide-9class-multi.train.json"))
    # dev & test are scored on OntoNotes alone
    for shard in ("dev", "test"):
        shutil.copyfile(os.path.join(base_output_path, "en_ontonotes-multi.%s.json" % shard),
                        os.path.join(base_output_path, "%s.%s.json" % (short_name, shard)))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def main():
    """Command line entry point: build en_ontonotes-ww-multi in NER_DATA_DIR."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--no_simplify', dest='simplify', action='store_false', help='By default, this script will simplify the OntoNotes 18 classes to the 8 WorldWide classes in a second column. Turning that off will leave that column blank. Initial experiments with that setting were very bad, though')
    parser.add_argument('--no_bigger_first', dest='bigger_first', action='store_false', help='By default, this script will put the 18 class tags in the first column and the 8 in the second. This flips the order')
    args = parser.parse_args()

    paths = default_paths.get_default_paths()
    ner_data_dir = paths["NER_DATA_DIR"]

    build_multitag_dataset(ner_data_dir, "en_ontonotes-ww-multi", args.simplify, args.bigger_first)

if __name__ == '__main__':
    main()
|
| 97 |
+
|
stanza/stanza/utils/datasets/ner/prepare_ner_file.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script converts NER data from the CoNLL03 format to the latest CoNLL-U format. The script assumes that in the
|
| 3 |
+
input column format data, the token is always in the first column, while the NER tag is always in the last column.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
# a usable CoNLL03 line has at least a token and a tag...
MIN_NUM_FIELD = 2
# ...and at most 5 columns (token, extra annotations, NER tag)
MAX_NUM_FIELD = 5

# document separator marker used by the CoNLL03 data
DOC_START_TOKEN = '-DOCSTART-'
|
| 13 |
+
|
| 14 |
+
def parse_args():
    """Build and apply the command line parser for this conversion script."""
    parser = argparse.ArgumentParser(description="Convert the conll03 format data into conllu format.")
    parser.add_argument('input', help='Input conll03 format data filename.')
    parser.add_argument('output', help='Output json filename.')
    return parser.parse_args()
|
| 20 |
+
|
| 21 |
+
def main():
    """Script entry point: convert the input file and write the json output."""
    args = parse_args()
    process_dataset(args.input, args.output)
|
| 24 |
+
|
| 25 |
+
def process_dataset(input_filename, output_filename):
    """
    Convert one CoNLL03 file to the json format used by the NER models.

    Each output sentence is a list of {'text': word, 'ner': tag} dicts.
    """
    sentences = load_conll03(input_filename)
    print("{} examples loaded from {}".format(len(sentences), input_filename))

    document = [[{'text': w, 'ner': t} for w, t in zip(words, tags)]
                for (words, tags) in sentences]

    with open(output_filename, 'w', encoding="utf-8") as outfile:
        json.dump(document, outfile, indent=1)
    print("Generated json file {}".format(output_filename))
|
| 39 |
+
|
| 40 |
+
# TODO: make skip_doc_start an argument
def load_conll03(filename, skip_doc_start=True):
    """
    Read a CoNLL03 file into a list of (tokens, tags) sentences.

    Blank lines separate sentences.  -DOCSTART- markers are dropped when
    skip_doc_start is True, as are lines with fewer than MIN_NUM_FIELD
    columns (tried first with tabs, then with arbitrary whitespace).
    """
    examples = []
    buffered = []
    with open(filename, encoding="utf-8") as infile:
        for raw in infile:
            stripped = raw.strip()
            if skip_doc_start and DOC_START_TOKEN in stripped:
                continue
            if stripped:
                columns = stripped.split("\t")
                if len(columns) < MIN_NUM_FIELD:
                    columns = stripped.split()
                if len(columns) >= MIN_NUM_FIELD:
                    buffered.append(stripped)
            elif buffered:
                examples.append(process_cache(buffered))
                buffered = []
    # flush the final sentence if the file did not end with a blank line
    if buffered:
        examples.append(process_cache(buffered))
    return examples
|
| 64 |
+
|
| 65 |
+
def process_cache(cached_lines):
    """
    Turn one sentence's buffered lines into a (tokens, tags) pair.

    The token comes from the first column and the NER tag from the last;
    tabs are tried first, falling back to arbitrary whitespace.
    """
    tokens = []
    ner_tags = []
    for line in cached_lines:
        fields = line.split("\t")
        if len(fields) < MIN_NUM_FIELD:
            fields = line.split()
        assert MIN_NUM_FIELD <= len(fields) <= MAX_NUM_FIELD, "Got unexpected line length: {}".format(fields)
        tokens.append(fields[0])
        ner_tags.append(fields[-1])
    return (tokens, ner_tags)

if __name__ == '__main__':
    main()
|
stanza/stanza/utils/datasets/ner/utils.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utils for the processing of NER datasets
|
| 3 |
+
|
| 4 |
+
These can be invoked from either the specific dataset scripts
|
| 5 |
+
or the entire prepare_ner_dataset.py script
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
from stanza.models.common.doc import Document
|
| 14 |
+
import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file
|
| 15 |
+
|
| 16 |
+
# canonical order of the dataset splits used throughout these utilities
SHARDS = ('train', 'dev', 'test')
|
| 17 |
+
|
| 18 |
+
def bioes_to_bio(tags):
    """
    Convert a BIOES tag sequence to BIO (not BIO2).

    In BIO, entities are marked I- throughout; B- is only emitted when a
    new entity starts immediately after another entity.
    """
    converted = []
    inside = False
    for tag in tags:
        if tag == 'O':
            converted.append(tag)
            inside = False
            continue
        if inside and tag[:2] in ('B-', 'S-'):
            # TODO: does the tag have to match the previous tag?
            # eg, does B-LOC B-PER in BIOES need a B-PER or is I-PER sufficient?
            converted.append('B-' + tag[2:])
        else:
            converted.append('I-' + tag[2:])
        inside = True
    return converted
|
| 33 |
+
|
| 34 |
+
def convert_bioes_to_bio(base_input_path, base_output_path, short_name):
    """
    Convert BIOES files back to BIO (not BIO2)

    Useful for preparing datasets for CoreNLP, which doesn't do great
    with the more highly split classes
    """
    for shard in SHARDS:
        in_file = os.path.join(base_input_path, '%s.%s.bioes' % (short_name, shard))
        out_file = os.path.join(base_output_path, '%s.%s.bio' % (short_name, shard))

        converted = []
        for sentence in read_tsv(in_file, text_column=0, annotation_column=1):
            bio_tags = bioes_to_bio([pair[1] for pair in sentence])
            converted.append([(pair[0], tag) for pair, tag in zip(sentence, bio_tags)])
        write_sentences(out_file, converted)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def convert_bio_to_json(base_input_path, base_output_path, short_name, suffix="bio", shard_names=SHARDS, shards=SHARDS):
    """
    Convert BIO files to json

    It can often be convenient to put the intermediate BIO files in
    the same directory as the output files, in which case you can pass
    in same path for both base_input_path and base_output_path.

    This also will rewrite a BIOES as json
    """
    for input_shard, output_shard in zip(shard_names, shards):
        input_filename = os.path.join(base_input_path, '%s.%s.%s' % (short_name, input_shard, suffix))
        if not os.path.exists(input_filename):
            # some datasets leave the short name off the intermediate files
            fallback = os.path.join(base_input_path, '%s.%s' % (input_shard, suffix))
            if not os.path.exists(fallback):
                raise FileNotFoundError('Cannot find %s component of %s in %s or %s' % (output_shard, short_name, input_filename, fallback))
            input_filename = fallback
        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, output_shard))
        print("Converting %s to %s" % (input_filename, output_filename))
        prepare_ner_file.process_dataset(input_filename, output_filename)
|
| 75 |
+
|
| 76 |
+
def get_tags(datasets):
    """
    return the set of tags used in these datasets

    datasets is expected to be train, dev, test but could be any list
    of lists of sentences, where a sentence is a list of (word, tag)
    """
    return {tag
            for dataset in datasets
            for sentence in dataset
            for _, tag in sentence}
|
| 88 |
+
|
| 89 |
+
def write_sentences(output_filename, dataset):
    """
    Write exactly one output file worth of dataset.

    dataset is a list of sentences; each sentence is a list of
    (word, tag) pairs (extra fields beyond the first two are dropped).
    Words are written one per line as "word<TAB>tag" with a blank line
    after each sentence.  The output directory is created if needed.

    Raises TypeError (with the original cause chained) if an entry is
    not a usable (word, tag) pair.
    """
    # os.path.split returns '' for a bare filename; os.makedirs('') would
    # raise FileNotFoundError, so only create a directory when there is one
    output_dir = os.path.split(output_filename)[0]
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_filename, "w", encoding="utf-8") as fout:
        for sent_idx, sentence in enumerate(dataset):
            for word_idx, word in enumerate(sentence):
                # some datasets carry extra columns; only word & tag are written
                if len(word) > 2:
                    word = word[:2]
                try:
                    fout.write("%s\t%s\n" % word)
                except TypeError as e:
                    # chain so the underlying formatting error stays visible
                    raise TypeError("Unable to process sentence %d word %d of file %s" % (sent_idx, word_idx, output_filename)) from e
            fout.write("\n")
|
| 104 |
+
|
| 105 |
+
def write_dataset(datasets, output_dir, short_name, suffix="bio", shard_names=SHARDS, shards=SHARDS):
    """
    write all three pieces of a dataset to output_dir

    datasets should be 3 lists: train, dev, test
    each list should be a list of sentences
    each sentence is a list of pairs: word, tag

    after writing to .bio files, the files will be converted to .json
    """
    for shard_name, sentences in zip(shard_names, datasets):
        shard_filename = os.path.join(output_dir, "%s.%s.%s" % (short_name, shard_name, suffix))
        write_sentences(shard_filename, sentences)

    convert_bio_to_json(output_dir, output_dir, short_name, suffix, shard_names=shard_names, shards=shards)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def write_multitag_json(output_filename, dataset):
    """
    Write a dataset of (text, ner, multi_ner) entries as a json NER file.
    """
    converted = [[{'text': entry[0], 'ner': entry[1], 'multi_ner': entry[2]}
                  for entry in sentence]
                 for sentence in dataset]
    with open(output_filename, 'w', encoding='utf-8') as fout:
        json.dump(converted, fout, indent=2)
|
| 134 |
+
|
| 135 |
+
def write_multitag_dataset(datasets, output_dir, short_name, suffix="bio", shard_names=SHARDS, shards=SHARDS):
    """
    Write both .bio and .json versions of a multitag dataset.

    The .bio files keep only (word, tag); the .json files keep the
    multi_ner layers as well.
    """
    for shard_name, sentences in zip(shard_names, datasets):
        write_sentences(os.path.join(output_dir, "%s.%s.%s" % (short_name, shard_name, suffix)), sentences)

    for shard_name, sentences in zip(shard_names, datasets):
        write_multitag_json(os.path.join(output_dir, "%s.%s.json" % (short_name, shard_name)), sentences)
|
| 143 |
+
|
| 144 |
+
def read_tsv(filename, text_column, annotation_column, remap_fn=None, skip_comments=True, keep_broken_tags=False, keep_all_columns=False, separator="\t"):
    """
    Read sentences from a TSV file.

    Returns a list of sentences, each a list of (word, tag) pairs, or of
    full column lists when keep_all_columns is True.  Blank lines mark
    sentence boundaries; lines starting with '#' are skipped when
    skip_comments is True.

    If keep_broken_tags==True, then None is used for a missing tag.
    Otherwise, an IndexError is thrown.
    """
    with open(filename, encoding="utf-8") as fin:
        lines = [line.strip() for line in fin]

    sentences = []
    sentence = []
    for line_idx, line in enumerate(lines):
        if not line:
            if sentence:
                sentences.append(sentence)
                sentence = []
            continue
        if skip_comments and line.startswith("#"):
            continue

        pieces = line.split(separator)
        try:
            word = pieces[text_column]
        except IndexError as e:
            raise IndexError("Could not find word index %d at line %d |%s|" % (text_column, line_idx, line)) from e
        if word == '\x96':
            # this happens in GermEval2014 for some reason
            continue
        try:
            tag = pieces[annotation_column]
        except IndexError as e:
            if not keep_broken_tags:
                raise IndexError("Could not find tag index %d at line %d |%s|" % (annotation_column, line_idx, line)) from e
            tag = None
        if remap_fn:
            tag = remap_fn(tag)

        if keep_all_columns:
            pieces[annotation_column] = tag
            sentence.append(pieces)
        else:
            sentence.append((word, tag))

    # flush the final sentence if the file did not end with a blank line
    if sentence:
        sentences.append(sentence)

    return sentences
|
| 196 |
+
|
| 197 |
+
def random_shuffle_directory(input_dir, output_dir, short_name):
    """
    Shuffle every file in input_dir into train/dev/test splits.

    Files are listed in sorted order so the split is reproducible.
    """
    filenames = sorted(os.listdir(input_dir))
    random_shuffle_files(input_dir, filenames, output_dir, short_name)
|
| 201 |
+
|
| 202 |
+
def random_shuffle_files(input_dir, input_files, output_dir, short_name):
    """
    Shuffle the files into different chunks based on their filename

    The first piece of the filename, split by ".", is used as a random seed.

    This will make it so that adding new files or using a different
    annotation scheme (assuming that's encoding in pieces of the
    filename) won't change the distibution of the files

    Splits are roughly 70% train / 10% dev / 20% test.  Returns the
    number of files that went to (train, dev, test).
    """
    # every file must have a unique prefix, since the prefix alone
    # determines which split the file lands in
    input_keys = {}
    for f in input_files:
        seed = f.split(".")[0]
        if seed in input_keys:
            raise ValueError("Multiple files with the same prefix: %s and %s" % (input_keys[seed], f))
        input_keys[seed] = f
    assert len(input_keys) == len(input_files)

    train_files = []
    dev_files = []
    test_files = []

    for filename in input_files:
        seed = filename.split(".")[0]
        # "salt" the filenames when using as a seed
        # definitely not because of a dumb bug in the original implementation
        # NOTE: changing this salt changes which split every file lands in
        seed = seed + ".txt.4class.tsv"
        # version=2 string seeding is deterministic across runs/platforms
        random.seed(seed, 2)
        location = random.random()
        if location < 0.7:
            train_files.append(filename)
        elif location < 0.8:
            dev_files.append(filename)
        else:
            test_files.append(filename)

    print("Train files: %d Dev files: %d Test files: %d" % (len(train_files), len(dev_files), len(test_files)))
    assert len(train_files) + len(dev_files) + len(test_files) == len(input_files)

    # read each split's files as 2 column TSVs and write them out
    # as one combined dataset (bio + json)
    file_lists = [train_files, dev_files, test_files]
    datasets = []
    for files in file_lists:
        dataset = []
        for filename in files:
            dataset.extend(read_tsv(os.path.join(input_dir, filename), 0, 1))
        datasets.append(dataset)

    write_dataset(datasets, output_dir, short_name)
    return len(train_files), len(dev_files), len(test_files)
|
| 251 |
+
|
| 252 |
+
def random_shuffle_by_prefixes(input_dir, output_dir, short_name, prefix_map):
    """
    Shuffle the files in input_dir into splits, grouped by filename prefix.

    prefix_map maps a division name to a list of filename prefixes; every
    file must match some prefix or a ValueError is raised.  Each division
    is shuffled separately (as "<short_name>-<division>") and then all
    divisions are recombined into a single <short_name> dataset.
    """
    input_files = os.listdir(input_dir)
    input_files = sorted(input_files)

    # assign each file to the first division whose prefix matches;
    # the for/else/break ladder below implements "first match or raise"
    file_divisions = defaultdict(list)
    for filename in input_files:
        for division in prefix_map.keys():
            for prefix in prefix_map[division]:
                if filename.startswith(prefix):
                    break
            else: # for/else is intentional
                continue
            break
        else: # yes, stop asking
            raise ValueError("Could not assign %s to any of the divisions in the prefix_map" % filename)
        #print("Assigning %s to %s because of %s" % (filename, division, prefix))
        file_divisions[division].append(filename)

    # shuffle each division independently, tracking the split sizes
    num_train_files = 0
    num_dev_files = 0
    num_test_files = 0
    for division in file_divisions.keys():
        print()
        print("Processing %d files from %s" % (len(file_divisions[division]), division))
        d_train, d_dev, d_test = random_shuffle_files(input_dir, file_divisions[division], output_dir, "%s-%s" % (short_name, division))
        num_train_files += d_train
        num_dev_files += d_dev
        num_test_files += d_test

    print()
    print("After shuffling: Train files: %d Dev files: %d Test files: %d" % (num_train_files, num_dev_files, num_test_files))
    # merge the per-division datasets back into one
    dataset_divisions = ["%s-%s" % (short_name, division) for division in file_divisions]
    combine_dataset(output_dir, output_dir, dataset_divisions, short_name)
|
| 285 |
+
|
| 286 |
+
def combine_dataset(input_dir, output_dir, input_datasets, output_dataset):
    """
    Concatenate several json NER datasets shard by shard into one dataset.

    Only the text and ner fields are kept from the source files; the
    result is rewritten via write_dataset (bio + json).
    """
    combined_shards = []
    for shard in SHARDS:
        shard_sentences = []
        for dataset_name in input_datasets:
            input_path = os.path.join(input_dir, "%s.%s.json" % (dataset_name, shard))
            with open(input_path, encoding="utf-8") as fin:
                sentences = json.load(fin)
            shard_sentences.extend([(word['text'], word['ner']) for word in sentence]
                                   for sentence in sentences)
        combined_shards.append(shard_sentences)
    write_dataset(combined_shards, output_dir, output_dataset)
|
| 299 |
+
|
| 300 |
+
def read_prefix_file(destination_file):
    """
    Read a prefix file such as the one for the Worldwide dataset

    the format should be

    africa:
    af_
    ...

    asia:
    cn_
    ...

    Returns a dict mapping the normalized division name (lowercase,
    spaces replaced with underscores) to its list of prefixes.
    """
    prefix_map = {}
    seen_prefixes = set()
    current_label = None
    current_prefixes = []

    with open(destination_file, encoding="utf-8") as fin:
        for raw in fin:
            line = raw.strip()
            # skip comments and blank lines
            if not line or line.startswith("#"):
                continue
            if line.endswith(":"):
                # a new division starts; stash the previous one first
                if current_label is not None:
                    prefix_map[current_label] = current_prefixes
                    current_prefixes = []
                current_label = line[:-1].strip().lower().replace(" ", "_")
            else:
                if not current_label:
                    raise RuntimeError("Found a prefix before the first label was assigned when reading %s" % destination_file)
                current_prefixes.append(line)
                if line in seen_prefixes:
                    raise RuntimeError("Found the same prefix twice! %s" % line)
                seen_prefixes.add(line)

    if current_label and current_prefixes:
        prefix_map[current_label] = current_prefixes

    return prefix_map
|
| 343 |
+
|
| 344 |
+
def read_json_entities(filename):
    """
    Read entities from a json NER file, returning a list of (text, label).

    Should work on both BIOES and BIO.
    """
    with open(filename) as fin:
        document = Document(json.load(fin))
    return list_doc_entities(document)
|
| 354 |
+
|
| 355 |
+
def list_doc_entities(doc):
    """
    Return a list of (text, label) entities from a Document.

    Each entity's text is a tuple of its token texts.

    Should work on both BIOES and BIO.

    Raises RuntimeError if a tag is not in BIO(ES) format.
    """
    entities = []
    for sentence in doc.sentences:
        current_entity = []
        previous_label = None
        for token in sentence.tokens:
            if token.ner == 'O' or token.ner.startswith("E-"):
                # E- closes the current entity (after adding this token);
                # O closes it without adding anything
                if token.ner.startswith("E-"):
                    current_entity.append(token.text)
                if current_entity:
                    assert previous_label is not None
                    entities.append((current_entity, previous_label))
                    current_entity = []
                    previous_label = None
            elif token.ner.startswith("I-"):
                # a label change mid-entity implicitly starts a new entity
                if previous_label is not None and previous_label != 'O' and previous_label != token.ner[2:]:
                    if current_entity:
                        assert previous_label is not None
                        entities.append((current_entity, previous_label))
                        current_entity = []
                previous_label = token.ner[2:]
                current_entity.append(token.text)
            elif token.ner.startswith("B-") or token.ner.startswith("S-"):
                # B-/S- always starts a new entity, flushing any open one
                if current_entity:
                    assert previous_label is not None
                    entities.append((current_entity, previous_label))
                    current_entity = []
                    previous_label = None
                current_entity.append(token.text)
                previous_label = token.ner[2:]
                if token.ner.startswith("S-"):
                    assert previous_label is not None
                    # bugfix: previously this appended the bare token list
                    # instead of the (entity, label) pair, which broke the
                    # tuple conversion at the end of the function
                    entities.append((current_entity, previous_label))
                    current_entity = []
                    previous_label = None
            else:
                raise RuntimeError("Expected BIO(ES) format in the json file!")
            # track the most recent tag's label for the next iteration
            previous_label = token.ner[2:]
        if current_entity:
            assert previous_label is not None
            entities.append((current_entity, previous_label))
    entities = [(tuple(x[0]), x[1]) for x in entities]
    return entities
|
| 403 |
+
|
| 404 |
+
def combine_files(output_filename, *input_filenames):
    """
    Combine multiple NER json files into one NER file
    """
    combined = []
    for input_filename in input_filenames:
        with open(input_filename) as fin:
            combined.extend(json.load(fin))

    with open(output_filename, "w") as fout:
        json.dump(combined, fout, indent=2)
|
| 417 |
+
|
stanza/stanza/utils/datasets/vietnamese/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/pretrain/compare_pretrains.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Compare two pretrained embedding files.

Usage: python compare_pretrains.py <pretrain1> <pretrain2>

Prints how many words the two vocabularies share, how many of the
shared vectors are (nearly) identical, the average L2 difference of
the ones that differ, and whether the shared words have the same
vocabulary indices.
"""

import sys
import numpy as np

from stanza.models.common.pretrain import Pretrain

pt1_filename = sys.argv[1]
pt2_filename = sys.argv[2]

pt1 = Pretrain(pt1_filename)
pt2 = Pretrain(pt2_filename)

vocab1 = pt1.vocab
vocab2 = pt2.vocab

# words present in both pretrains
common_words = [x for x in vocab1 if x in vocab2]
print("%d shared words, out of %d in %s and %d in %s" % (len(common_words), len(vocab1), pt1_filename, len(vocab2), pt2_filename))

# vectors closer than eps (L2) count as "the same"
eps = 0.0001
total_norm = 0.0
total_close = 0

# keep the first few differing words as examples
words_different = []

for word, idx in vocab1._unit2id.items():
    if word not in vocab2:
        continue
    v1 = pt1.emb[idx]
    v2 = pt2.emb[pt2.vocab[word]]
    norm = np.linalg.norm(v1 - v2)

    if norm < eps:
        total_close += 1
    else:
        total_norm += norm
        if len(words_different) < 10:
            words_different.append("|%s|" % word)
        #print(word, idx, pt2.vocab[word])
        #print(v1)
        #print(v2)

if total_close < len(common_words):
    # average only over the vectors that actually differed
    avg_norm = total_norm / (len(common_words) - total_close)
    print("%d vectors were close. Average difference of the others: %f" % (total_close, avg_norm))
    print("The first few different words were:\n %s" % "\n ".join(words_different))
else:
    print("All %d vectors were close!" % total_close)

# check whether the shared words also keep the same index in both vocabs;
# the for/else only prints when the loop finishes without breaking
for word, idx in vocab1._unit2id.items():
    if word not in vocab2:
        continue
    if pt2.vocab[word] != idx:
        break
else:
    print("All indices are the same")
|
stanza/stanza/utils/training/common.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import pathlib
|
| 6 |
+
import sys
|
| 7 |
+
import tempfile
|
| 8 |
+
|
| 9 |
+
from enum import Enum
|
| 10 |
+
|
| 11 |
+
from stanza.resources.default_packages import default_charlms, lemma_charlms, pos_charlms, depparse_charlms, TRANSFORMERS, TRANSFORMER_LAYERS
|
| 12 |
+
from stanza.models.common.constant import treebank_to_short_name
|
| 13 |
+
from stanza.models.common.utils import ud_scores
|
| 14 |
+
from stanza.resources.common import download, DEFAULT_MODEL_DIR, UnknownLanguageError
|
| 15 |
+
from stanza.utils.datasets import common
|
| 16 |
+
import stanza.utils.default_paths as default_paths
|
| 17 |
+
from stanza.utils import conll18_ud_eval as ud_eval
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger('stanza')
|
| 20 |
+
|
| 21 |
+
class Mode(Enum):
|
| 22 |
+
TRAIN = 1
|
| 23 |
+
SCORE_DEV = 2
|
| 24 |
+
SCORE_TEST = 3
|
| 25 |
+
SCORE_TRAIN = 4
|
| 26 |
+
|
| 27 |
+
class ArgumentParserWithExtraHelp(argparse.ArgumentParser):
|
| 28 |
+
def __init__(self, sub_argparse, *args, **kwargs):
|
| 29 |
+
super().__init__(*args, **kwargs) # forwards all unused arguments
|
| 30 |
+
|
| 31 |
+
self.sub_argparse = sub_argparse
|
| 32 |
+
|
| 33 |
+
def print_help(self, file=None):
|
| 34 |
+
super().print_help(file=file)
|
| 35 |
+
|
| 36 |
+
def format_help(self):
|
| 37 |
+
help_text = super().format_help()
|
| 38 |
+
if self.sub_argparse is not None:
|
| 39 |
+
sub_text = self.sub_argparse.format_help().split("\n")
|
| 40 |
+
first_line = -1
|
| 41 |
+
for line_idx, line in enumerate(sub_text):
|
| 42 |
+
if line.strip().startswith("usage:"):
|
| 43 |
+
first_line = line_idx
|
| 44 |
+
elif first_line >= 0 and not line.strip():
|
| 45 |
+
first_line = line_idx
|
| 46 |
+
break
|
| 47 |
+
help_text = help_text + "\n\nmodel arguments:" + "\n".join(sub_text[first_line:])
|
| 48 |
+
return help_text
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def build_argparse(sub_argparse=None):
|
| 52 |
+
parser = ArgumentParserWithExtraHelp(sub_argparse=sub_argparse, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 53 |
+
parser.add_argument('--save_output', dest='temp_output', default=True, action='store_false', help="Save output - default is to use a temp directory.")
|
| 54 |
+
|
| 55 |
+
parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. Use all_ud or ud_all for all UD treebanks')
|
| 56 |
+
|
| 57 |
+
parser.add_argument('--train', dest='mode', default=Mode.TRAIN, action='store_const', const=Mode.TRAIN, help='Run in train mode')
|
| 58 |
+
parser.add_argument('--score_dev', dest='mode', action='store_const', const=Mode.SCORE_DEV, help='Score the dev set')
|
| 59 |
+
parser.add_argument('--score_test', dest='mode', action='store_const', const=Mode.SCORE_TEST, help='Score the test set')
|
| 60 |
+
parser.add_argument('--score_train', dest='mode', action='store_const', const=Mode.SCORE_TRAIN, help='Score the train set as a test set. Currently only implemented for some models')
|
| 61 |
+
|
| 62 |
+
# These arguments need to be here so we can identify if the model already exists in the user-specified home
|
| 63 |
+
# TODO: when all of the model scripts handle their own names, can eliminate this argument
|
| 64 |
+
parser.add_argument('--save_dir', type=str, default=None, help="Root dir for saving models. If set, will override the model's default.")
|
| 65 |
+
parser.add_argument('--save_name', type=str, default=None, help="Base name for saving models. If set, will override the model's default.")
|
| 66 |
+
|
| 67 |
+
parser.add_argument('--charlm_only', action='store_true', default=False, help='When asking for ud_all, filter the ones which have charlms')
|
| 68 |
+
parser.add_argument('--transformer_only', action='store_true', default=False, help='When asking for ud_all, filter the ones for languages where we have transformers')
|
| 69 |
+
|
| 70 |
+
parser.add_argument('--force', dest='force', action='store_true', default=False, help='Retrain existing models')
|
| 71 |
+
return parser
|
| 72 |
+
|
| 73 |
+
def add_charlm_args(parser):
|
| 74 |
+
parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')
|
| 75 |
+
parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package")
|
| 76 |
+
|
| 77 |
+
def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argparse=None, build_model_filename=None, choose_charlm_method=None, args=None):
|
| 78 |
+
"""
|
| 79 |
+
A main program for each of the run_xyz scripts
|
| 80 |
+
|
| 81 |
+
It collects the arguments and runs the main method for each dataset provided.
|
| 82 |
+
It also tries to look for an existing model and not overwrite it unless --force is provided
|
| 83 |
+
|
| 84 |
+
model_name can be a callable expecting the args
|
| 85 |
+
- the charlm, for example, needs this feature, since it makes
|
| 86 |
+
both forward and backward models
|
| 87 |
+
"""
|
| 88 |
+
if args is None:
|
| 89 |
+
logger.info("Training program called with:\n" + " ".join(sys.argv))
|
| 90 |
+
args = sys.argv[1:]
|
| 91 |
+
else:
|
| 92 |
+
logger.info("Training program called with:\n" + " ".join(args))
|
| 93 |
+
|
| 94 |
+
paths = default_paths.get_default_paths()
|
| 95 |
+
|
| 96 |
+
parser = build_argparse(sub_argparse)
|
| 97 |
+
if add_specific_args is not None:
|
| 98 |
+
add_specific_args(parser)
|
| 99 |
+
if '--extra_args' in sys.argv:
|
| 100 |
+
idx = sys.argv.index('--extra_args')
|
| 101 |
+
extra_args = sys.argv[idx+1:]
|
| 102 |
+
command_args = parser.parse_args(sys.argv[:idx])
|
| 103 |
+
else:
|
| 104 |
+
command_args, extra_args = parser.parse_known_args(args=args)
|
| 105 |
+
|
| 106 |
+
# Pass this through to the underlying model as well as use it here
|
| 107 |
+
# we don't put --save_name here for the awkward situation of
|
| 108 |
+
# --save_name being specified for an invocation with multiple treebanks
|
| 109 |
+
if command_args.save_dir:
|
| 110 |
+
extra_args.extend(["--save_dir", command_args.save_dir])
|
| 111 |
+
|
| 112 |
+
if callable(model_name):
|
| 113 |
+
model_name = model_name(command_args)
|
| 114 |
+
|
| 115 |
+
mode = command_args.mode
|
| 116 |
+
treebanks = []
|
| 117 |
+
|
| 118 |
+
for treebank in command_args.treebanks:
|
| 119 |
+
# this is a really annoying typo to make if you copy/paste a
|
| 120 |
+
# UD directory name on the cluster and your job dies 30s after
|
| 121 |
+
# being queued for an hour
|
| 122 |
+
if treebank.endswith("/"):
|
| 123 |
+
treebank = treebank[:-1]
|
| 124 |
+
if treebank.lower() in ('ud_all', 'all_ud'):
|
| 125 |
+
ud_treebanks = common.get_ud_treebanks(paths["UDBASE"])
|
| 126 |
+
if choose_charlm_method is not None and command_args.charlm_only:
|
| 127 |
+
logger.info("Filtering ud_all treebanks to only those which can use charlm for this model")
|
| 128 |
+
ud_treebanks = [x for x in ud_treebanks
|
| 129 |
+
if choose_charlm_method(*treebank_to_short_name(x).split("_", 1), 'default') is not None]
|
| 130 |
+
if command_args.transformer_only:
|
| 131 |
+
logger.info("Filtering ud_all treebanks to only those which can use a transformer for this model")
|
| 132 |
+
ud_treebanks = [x for x in ud_treebanks if treebank_to_short_name(x).split("_")[0] in TRANSFORMERS]
|
| 133 |
+
logger.info("Expanding %s to %s", treebank, " ".join(ud_treebanks))
|
| 134 |
+
treebanks.extend(ud_treebanks)
|
| 135 |
+
else:
|
| 136 |
+
treebanks.append(treebank)
|
| 137 |
+
|
| 138 |
+
for treebank_idx, treebank in enumerate(treebanks):
|
| 139 |
+
if treebank_idx > 0:
|
| 140 |
+
logger.info("=========================================")
|
| 141 |
+
|
| 142 |
+
short_name = treebank_to_short_name(treebank)
|
| 143 |
+
logger.debug("%s: %s" % (treebank, short_name))
|
| 144 |
+
|
| 145 |
+
save_name_args = []
|
| 146 |
+
if model_name != 'ete':
|
| 147 |
+
# ete is several models at once, so we don't set --save_name
|
| 148 |
+
# theoretically we could handle a parametrized save_name
|
| 149 |
+
if command_args.save_name:
|
| 150 |
+
save_name = command_args.save_name
|
| 151 |
+
# if there's more than 1 treebank, we can't save them all to this save_name
|
| 152 |
+
# we have to override that value for each treebank
|
| 153 |
+
if len(treebanks) > 1:
|
| 154 |
+
save_name_dir, save_name_filename = os.path.split(save_name)
|
| 155 |
+
save_name_filename = "%s_%s" % (short_name, save_name_filename)
|
| 156 |
+
save_name = os.path.join(save_name_dir, save_name_filename)
|
| 157 |
+
logger.info("Save file for %s model for %s: %s", short_name, treebank, save_name)
|
| 158 |
+
save_name_args = ['--save_name', save_name]
|
| 159 |
+
# some run scripts can build the model filename
|
| 160 |
+
# in order to check for models that are already created
|
| 161 |
+
elif build_model_filename is None:
|
| 162 |
+
save_name = "%s_%s.pt" % (short_name, model_name)
|
| 163 |
+
logger.info("Save file for %s model: %s", short_name, save_name)
|
| 164 |
+
save_name_args = ['--save_name', save_name]
|
| 165 |
+
else:
|
| 166 |
+
save_name_args = []
|
| 167 |
+
|
| 168 |
+
if mode == Mode.TRAIN and not command_args.force:
|
| 169 |
+
if build_model_filename is not None:
|
| 170 |
+
model_path = build_model_filename(paths, short_name, command_args, extra_args)
|
| 171 |
+
elif command_args.save_dir:
|
| 172 |
+
model_path = os.path.join(command_args.save_dir, save_name)
|
| 173 |
+
else:
|
| 174 |
+
save_dir = os.path.join("saved_models", model_dir)
|
| 175 |
+
save_name_args.extend(["--save_dir", save_dir])
|
| 176 |
+
model_path = os.path.join(save_dir, save_name)
|
| 177 |
+
|
| 178 |
+
if model_path is None:
|
| 179 |
+
# this can happen with the identity lemmatizer, for example
|
| 180 |
+
pass
|
| 181 |
+
elif os.path.exists(model_path):
|
| 182 |
+
logger.info("%s: %s exists, skipping!" % (treebank, model_path))
|
| 183 |
+
continue
|
| 184 |
+
else:
|
| 185 |
+
logger.info("%s: %s does not exist, training new model" % (treebank, model_path))
|
| 186 |
+
|
| 187 |
+
if command_args.temp_output and model_name != 'ete':
|
| 188 |
+
with tempfile.NamedTemporaryFile() as temp_output_file:
|
| 189 |
+
run_treebank(mode, paths, treebank, short_name,
|
| 190 |
+
temp_output_file.name, command_args, extra_args + save_name_args)
|
| 191 |
+
else:
|
| 192 |
+
run_treebank(mode, paths, treebank, short_name,
|
| 193 |
+
None, command_args, extra_args + save_name_args)
|
| 194 |
+
|
| 195 |
+
def run_eval_script(gold_conllu_file, system_conllu_file, evals=None):
    """ Wrapper for lemma scorer. """
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)

    # with no metric list requested, return the full evaluation table
    if evals is None:
        return ud_eval.build_evaluation_table(evaluation, verbose=True, counts=False, enhanced=False)

    # otherwise, format a two-line table: metric names, then their F1 scores
    scores = [evaluation[metric].f1 for metric in evals]
    column_width = max(5, max(len(metric) for metric in evals))
    header = " ".join(("{:>%d}" % column_width).format(metric) for metric in evals)
    values = " ".join(("{:%d.2f}" % column_width).format(100 * score) for score in scores)
    return header + "\n" + values
|
| 207 |
+
|
| 208 |
+
def run_eval_script_tokens(eval_gold, eval_pred):
|
| 209 |
+
return run_eval_script(eval_gold, eval_pred, evals=["Tokens", "Sentences", "Words"])
|
| 210 |
+
|
| 211 |
+
def run_eval_script_mwt(eval_gold, eval_pred):
|
| 212 |
+
return run_eval_script(eval_gold, eval_pred, evals=["Words"])
|
| 213 |
+
|
| 214 |
+
def run_eval_script_pos(eval_gold, eval_pred):
|
| 215 |
+
return run_eval_script(eval_gold, eval_pred, evals=["UPOS", "XPOS", "UFeats", "AllTags"])
|
| 216 |
+
|
| 217 |
+
def run_eval_script_depparse(eval_gold, eval_pred):
|
| 218 |
+
return run_eval_script(eval_gold, eval_pred, evals=["UAS", "LAS", "CLAS", "MLAS", "BLEX"])
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def find_wordvec_pretrain(language, default_pretrains, dataset_pretrains=None, dataset=None, model_dir=DEFAULT_MODEL_DIR):
|
| 222 |
+
# try to get the default pretrain for the language,
|
| 223 |
+
# but allow the package specific value to override it if that is set
|
| 224 |
+
default_pt = default_pretrains.get(language, None)
|
| 225 |
+
if dataset is not None and dataset_pretrains is not None:
|
| 226 |
+
default_pt = dataset_pretrains.get(language, {}).get(dataset, default_pt)
|
| 227 |
+
|
| 228 |
+
if default_pt is not None:
|
| 229 |
+
default_pt_path = '{}/{}/pretrain/{}.pt'.format(model_dir, language, default_pt)
|
| 230 |
+
if not os.path.exists(default_pt_path):
|
| 231 |
+
logger.info("Default pretrain should be {} Attempting to download".format(default_pt_path))
|
| 232 |
+
try:
|
| 233 |
+
download(lang=language, package=None, processors={"pretrain": default_pt}, model_dir=model_dir)
|
| 234 |
+
except UnknownLanguageError:
|
| 235 |
+
# if there's a pretrain in the directory, hiding this
|
| 236 |
+
# error will let us find that pretrain later
|
| 237 |
+
pass
|
| 238 |
+
if os.path.exists(default_pt_path):
|
| 239 |
+
if dataset is not None and dataset_pretrains is not None and language in dataset_pretrains and dataset in dataset_pretrains[language]:
|
| 240 |
+
logger.info(f"Using default pretrain for {language}:{dataset}, found in {default_pt_path} To use a different pretrain, specify --wordvec_pretrain_file")
|
| 241 |
+
else:
|
| 242 |
+
logger.info(f"Using default pretrain for language, found in {default_pt_path} To use a different pretrain, specify --wordvec_pretrain_file")
|
| 243 |
+
return default_pt_path
|
| 244 |
+
|
| 245 |
+
pretrain_path = '{}/{}/pretrain/*.pt'.format(model_dir, language)
|
| 246 |
+
pretrains = glob.glob(pretrain_path)
|
| 247 |
+
if len(pretrains) == 0:
|
| 248 |
+
# we already tried to download the default pretrain once
|
| 249 |
+
# and it didn't work. maybe the default language package
|
| 250 |
+
# will have something?
|
| 251 |
+
logger.warning(f"Cannot figure out which pretrain to use for '{language}'. Will download the default package and hope for the best")
|
| 252 |
+
try:
|
| 253 |
+
download(lang=language, model_dir=model_dir)
|
| 254 |
+
except UnknownLanguageError as e:
|
| 255 |
+
# this is a very unusual situation
|
| 256 |
+
# basically, there was a language which we started to add
|
| 257 |
+
# to the resources, but then didn't release the models
|
| 258 |
+
# as part of resources.json
|
| 259 |
+
raise FileNotFoundError(f"Cannot find any pretrains in {pretrain_path} No pretrains in the system for this language. Please prepare an embedding as a .pt and use --wordvec_pretrain_file to specify a .pt file to use") from e
|
| 260 |
+
pretrains = glob.glob(pretrain_path)
|
| 261 |
+
if len(pretrains) == 0:
|
| 262 |
+
raise FileNotFoundError(f"Cannot find any pretrains in {pretrain_path} Try 'stanza.download(\"{language}\")' to get a default pretrain or use --wordvec_pretrain_file to specify a .pt file to use")
|
| 263 |
+
if len(pretrains) > 1:
|
| 264 |
+
raise FileNotFoundError(f"Too many pretrains to choose from in {pretrain_path} Must specify an exact path to a --wordvec_pretrain_file")
|
| 265 |
+
pt = pretrains[0]
|
| 266 |
+
logger.info(f"Using pretrain found in {pt} To use a different pretrain, specify --wordvec_pretrain_file")
|
| 267 |
+
return pt
|
| 268 |
+
|
| 269 |
+
def find_charlm_file(direction, language, charlm, model_dir=DEFAULT_MODEL_DIR):
|
| 270 |
+
"""
|
| 271 |
+
Return the path to the forward or backward charlm if it exists for the given package
|
| 272 |
+
|
| 273 |
+
If we can figure out the package, but can't find it anywhere, we try to download it
|
| 274 |
+
"""
|
| 275 |
+
saved_path = 'saved_models/charlm/{}_{}_{}_charlm.pt'.format(language, charlm, direction)
|
| 276 |
+
if os.path.exists(saved_path):
|
| 277 |
+
logger.info(f'Using model {saved_path} for {direction} charlm')
|
| 278 |
+
return saved_path
|
| 279 |
+
|
| 280 |
+
resource_path = '{}/{}/{}_charlm/{}.pt'.format(model_dir, language, direction, charlm)
|
| 281 |
+
if os.path.exists(resource_path):
|
| 282 |
+
logger.info(f'Using model {resource_path} for {direction} charlm')
|
| 283 |
+
return resource_path
|
| 284 |
+
|
| 285 |
+
try:
|
| 286 |
+
download(lang=language, package=None, processors={f"{direction}_charlm": charlm}, model_dir=model_dir)
|
| 287 |
+
if os.path.exists(resource_path):
|
| 288 |
+
logger.info(f'Downloaded model, using model {resource_path} for {direction} charlm')
|
| 289 |
+
return resource_path
|
| 290 |
+
except ValueError as e:
|
| 291 |
+
raise FileNotFoundError(f"Cannot find {direction} charlm in either {saved_path} or {resource_path} Attempted downloading {charlm} but that did not work") from e
|
| 292 |
+
|
| 293 |
+
raise FileNotFoundError(f"Cannot find {direction} charlm in either {saved_path} or {resource_path} Attempted downloading {charlm} but that did not work")
|
| 294 |
+
|
| 295 |
+
def build_charlm_args(language, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
|
| 296 |
+
"""
|
| 297 |
+
If specified, return forward and backward charlm args
|
| 298 |
+
"""
|
| 299 |
+
if charlm:
|
| 300 |
+
try:
|
| 301 |
+
forward = find_charlm_file('forward', language, charlm, model_dir=model_dir)
|
| 302 |
+
backward = find_charlm_file('backward', language, charlm, model_dir=model_dir)
|
| 303 |
+
except FileNotFoundError as e:
|
| 304 |
+
# if we couldn't find sd_isra when training an SD model,
|
| 305 |
+
# for example, but isra exists, we try to download the
|
| 306 |
+
# shorter model name
|
| 307 |
+
if charlm.startswith(language + "_"):
|
| 308 |
+
short_charlm = charlm[len(language)+1:]
|
| 309 |
+
try:
|
| 310 |
+
forward = find_charlm_file('forward', language, short_charlm, model_dir=model_dir)
|
| 311 |
+
backward = find_charlm_file('backward', language, short_charlm, model_dir=model_dir)
|
| 312 |
+
except FileNotFoundError as e2:
|
| 313 |
+
raise FileNotFoundError("Tried to find charlm %s, which doesn't exist. Also tried %s, but didn't find that either" % (charlm, short_charlm)) from e
|
| 314 |
+
logger.warning("Was asked to find charlm %s, which does not exist. Did find %s though", charlm, short_charlm)
|
| 315 |
+
else:
|
| 316 |
+
raise
|
| 317 |
+
|
| 318 |
+
char_args = ['--charlm_forward_file', forward,
|
| 319 |
+
'--charlm_backward_file', backward]
|
| 320 |
+
if not base_args:
|
| 321 |
+
return char_args
|
| 322 |
+
return ['--charlm',
|
| 323 |
+
'--charlm_shorthand', f'{language}_{charlm}'] + char_args
|
| 324 |
+
|
| 325 |
+
return []
|
| 326 |
+
|
| 327 |
+
def choose_charlm(language, dataset, charlm, language_charlms, dataset_charlms):
    """
    Resolve which charlm package to use for a language/dataset pair.

    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm
    Any other value is returned unchanged.
    """
    if charlm is None:
        return None
    if charlm != "default":
        return charlm

    per_dataset = dataset_charlms.get(language, {})
    if dataset in per_dataset:
        # this way, a "" or None result gets honored
        # thus treating "not in the map" as a way for dataset_charlms to signal to use the default
        return per_dataset.get(dataset, None)

    language_default = language_charlms.get(language, None)
    if language_default:
        return language_default
    return None
|
| 347 |
+
|
| 348 |
+
def choose_pos_charlm(short_language, dataset, charlm):
|
| 349 |
+
"""
|
| 350 |
+
charlm == "default" means the default charlm for this dataset or language
|
| 351 |
+
charlm == None is no charlm
|
| 352 |
+
"""
|
| 353 |
+
return choose_charlm(short_language, dataset, charlm, default_charlms, pos_charlms)
|
| 354 |
+
|
| 355 |
+
def choose_depparse_charlm(short_language, dataset, charlm):
|
| 356 |
+
"""
|
| 357 |
+
charlm == "default" means the default charlm for this dataset or language
|
| 358 |
+
charlm == None is no charlm
|
| 359 |
+
"""
|
| 360 |
+
return choose_charlm(short_language, dataset, charlm, default_charlms, depparse_charlms)
|
| 361 |
+
|
| 362 |
+
def choose_lemma_charlm(short_language, dataset, charlm):
|
| 363 |
+
"""
|
| 364 |
+
charlm == "default" means the default charlm for this dataset or language
|
| 365 |
+
charlm == None is no charlm
|
| 366 |
+
"""
|
| 367 |
+
return choose_charlm(short_language, dataset, charlm, default_charlms, lemma_charlms)
|
| 368 |
+
|
| 369 |
+
def choose_transformer(short_language, command_args, extra_args, warn=True, layers=False):
|
| 370 |
+
"""
|
| 371 |
+
Choose a transformer using the default options for this language
|
| 372 |
+
"""
|
| 373 |
+
bert_args = []
|
| 374 |
+
if command_args is not None and command_args.use_bert and '--bert_model' not in extra_args:
|
| 375 |
+
if short_language in TRANSFORMERS:
|
| 376 |
+
bert_args = ['--bert_model', TRANSFORMERS.get(short_language)]
|
| 377 |
+
if layers and short_language in TRANSFORMER_LAYERS and '--bert_hidden_layers' not in extra_args:
|
| 378 |
+
bert_args.extend(['--bert_hidden_layers', str(TRANSFORMER_LAYERS.get(short_language))])
|
| 379 |
+
elif warn:
|
| 380 |
+
logger.error("Transformer requested, but no default transformer for %s Specify one using --bert_model" % short_language)
|
| 381 |
+
|
| 382 |
+
return bert_args
|
| 383 |
+
|
| 384 |
+
def build_pos_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
|
| 385 |
+
charlm = choose_pos_charlm(short_language, dataset, charlm)
|
| 386 |
+
charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir)
|
| 387 |
+
return charlm_args
|
| 388 |
+
|
| 389 |
+
def build_lemma_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
|
| 390 |
+
charlm = choose_lemma_charlm(short_language, dataset, charlm)
|
| 391 |
+
charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir)
|
| 392 |
+
return charlm_args
|
| 393 |
+
|
| 394 |
+
def build_depparse_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
|
| 395 |
+
charlm = choose_depparse_charlm(short_language, dataset, charlm)
|
| 396 |
+
charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir)
|
| 397 |
+
return charlm_args
|
stanza/stanza/utils/training/compose_ete_results.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Turn the ETE results into markdown
|
| 3 |
+
|
| 4 |
+
Parses blocks like this from the model eval script
|
| 5 |
+
|
| 6 |
+
2022-01-14 01:23:34 INFO: End to end results for af_afribooms models on af_afribooms test data:
|
| 7 |
+
Metric | Precision | Recall | F1 Score | AligndAcc
|
| 8 |
+
-----------+-----------+-----------+-----------+-----------
|
| 9 |
+
Tokens | 99.93 | 99.92 | 99.93 |
|
| 10 |
+
Sentences | 100.00 | 100.00 | 100.00 |
|
| 11 |
+
Words | 99.93 | 99.92 | 99.93 |
|
| 12 |
+
UPOS | 97.97 | 97.96 | 97.97 | 98.04
|
| 13 |
+
XPOS | 93.98 | 93.97 | 93.97 | 94.04
|
| 14 |
+
UFeats | 97.23 | 97.22 | 97.22 | 97.29
|
| 15 |
+
AllTags | 93.89 | 93.88 | 93.88 | 93.95
|
| 16 |
+
Lemmas | 97.40 | 97.39 | 97.39 | 97.46
|
| 17 |
+
UAS | 87.39 | 87.38 | 87.38 | 87.45
|
| 18 |
+
LAS | 83.57 | 83.56 | 83.57 | 83.63
|
| 19 |
+
CLAS | 76.88 | 76.45 | 76.66 | 76.52
|
| 20 |
+
MLAS | 72.28 | 71.87 | 72.07 | 71.94
|
| 21 |
+
BLEX | 73.20 | 72.79 | 73.00 | 72.86
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
Turns them into a markdown table.
|
| 25 |
+
|
| 26 |
+
Included is an attempt to mark the default packages with a green check.
|
| 27 |
+
<i class="fas fa-check" style="color:#33a02c"></i>
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
|
| 32 |
+
from stanza.models.common.constant import pretty_langcode_to_lang
|
| 33 |
+
from stanza.models.common.short_name_to_treebank import short_name_to_treebank
|
| 34 |
+
from stanza.utils.training.run_ete import RESULTS_STRING
|
| 35 |
+
from stanza.resources.default_packages import default_treebanks
|
| 36 |
+
|
| 37 |
+
EXPECTED_ORDER = ["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"]
|
| 38 |
+
|
| 39 |
+
parser = argparse.ArgumentParser()
|
| 40 |
+
parser.add_argument("filenames", type=str, nargs="+", help="Which file(s) to read")
|
| 41 |
+
args = parser.parse_args()
|
| 42 |
+
|
| 43 |
+
lines = []
|
| 44 |
+
for filename in args.filenames:
|
| 45 |
+
with open(filename) as fin:
|
| 46 |
+
lines.extend(fin.readlines())
|
| 47 |
+
|
| 48 |
+
blocks = []
|
| 49 |
+
index = 0
|
| 50 |
+
while index < len(lines):
|
| 51 |
+
line = lines[index]
|
| 52 |
+
if line.find(RESULTS_STRING) < 0:
|
| 53 |
+
index = index + 1
|
| 54 |
+
continue
|
| 55 |
+
|
| 56 |
+
line = line[line.find(RESULTS_STRING) + len(RESULTS_STRING):].strip()
|
| 57 |
+
short_name = line.split()[0]
|
| 58 |
+
|
| 59 |
+
# skip the header of the expected output
|
| 60 |
+
index = index + 1
|
| 61 |
+
line = lines[index]
|
| 62 |
+
pieces = line.split("|")
|
| 63 |
+
assert pieces[0].strip() == 'Metric', "output format changed?"
|
| 64 |
+
assert pieces[3].strip() == 'F1 Score', "output format changed?"
|
| 65 |
+
|
| 66 |
+
index = index + 1
|
| 67 |
+
line = lines[index]
|
| 68 |
+
assert line.startswith("-----"), "output format changed?"
|
| 69 |
+
|
| 70 |
+
index = index + 1
|
| 71 |
+
|
| 72 |
+
block = lines[index:index+13]
|
| 73 |
+
assert len(block) == 13
|
| 74 |
+
index = index + 13
|
| 75 |
+
|
| 76 |
+
block = [x.split("|") for x in block]
|
| 77 |
+
assert all(x[0].strip() == y for x, y in zip(block, EXPECTED_ORDER)), "output format changed?"
|
| 78 |
+
lcode, short_dataset = short_name.split("_", 1)
|
| 79 |
+
language = pretty_langcode_to_lang(lcode)
|
| 80 |
+
treebank = short_name_to_treebank(short_name)
|
| 81 |
+
long_dataset = treebank.split("-")[-1]
|
| 82 |
+
|
| 83 |
+
checkmark = ""
|
| 84 |
+
if default_treebanks[lcode] == short_dataset:
|
| 85 |
+
checkmark = '<i class="fas fa-check" style="color:#33a02c"></i>'
|
| 86 |
+
|
| 87 |
+
block = [language, "[%s](%s)" % (long_dataset, "https://github.com/UniversalDependencies/%s" % treebank), lcode, checkmark] + [x[3].strip() for x in block]
|
| 88 |
+
blocks.append(block)
|
| 89 |
+
|
| 90 |
+
PREFIX = ["​Macro Avg", "​", "​", ""]
|
| 91 |
+
|
| 92 |
+
avg = [sum(float(x[i]) for x in blocks) / len(blocks) for i in range(len(PREFIX), len(EXPECTED_ORDER) + len(PREFIX))]
|
| 93 |
+
avg = PREFIX + ["%.2f" % x for x in avg]
|
| 94 |
+
blocks = sorted(blocks)
|
| 95 |
+
blocks = [avg] + blocks
|
| 96 |
+
|
| 97 |
+
chart = ["|%s|" % " | ".join(x) for x in blocks]
|
| 98 |
+
for line in chart:
|
| 99 |
+
print(line)
|
| 100 |
+
|
stanza/stanza/utils/training/run_charlm.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Trains or scores a charlm model.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
from stanza.models import charlm
|
| 9 |
+
from stanza.utils.training import common
|
| 10 |
+
from stanza.utils.training.common import Mode
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger('stanza')
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def add_charlm_args(parser):
    """
    Extra args for the charlm: forward/backward

    --direction must stay registered first so its 'forward' default wins
    for the shared 'direction' dest.
    """
    parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help="Forward or backward language model")
    for lm_direction in ('forward', 'backward'):
        parser.add_argument('--%s' % lm_direction, action='store_const', dest='direction', const=lm_direction,
                            help="Train a %s language model" % lm_direction)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def run_treebank(mode, paths, treebank, short_name,
|
| 25 |
+
temp_output_file, command_args, extra_args):
|
| 26 |
+
short_language, dataset_name = short_name.split("_", 1)
|
| 27 |
+
|
| 28 |
+
train_dir = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "train")
|
| 29 |
+
|
| 30 |
+
dev_file = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "dev.txt")
|
| 31 |
+
if not os.path.exists(dev_file) and os.path.exists(dev_file + ".xz"):
|
| 32 |
+
dev_file = dev_file + ".xz"
|
| 33 |
+
|
| 34 |
+
test_file = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "test.txt")
|
| 35 |
+
if not os.path.exists(test_file) and os.path.exists(test_file + ".xz"):
|
| 36 |
+
test_file = test_file + ".xz"
|
| 37 |
+
|
| 38 |
+
# python -m stanza.models.charlm --train_dir $train_dir --eval_file $dev_file \
|
| 39 |
+
# --direction $direction --shorthand $short --mode train $args
|
| 40 |
+
# python -m stanza.models.charlm --eval_file $dev_file \
|
| 41 |
+
# --direction $direction --shorthand $short --mode predict $args
|
| 42 |
+
# python -m stanza.models.charlm --eval_file $test_file \
|
| 43 |
+
# --direction $direction --shorthand $short --mode predict $args
|
| 44 |
+
|
| 45 |
+
direction = command_args.direction
|
| 46 |
+
default_args = ['--%s' % direction,
|
| 47 |
+
'--shorthand', short_name]
|
| 48 |
+
if mode == Mode.TRAIN:
|
| 49 |
+
train_args = ['--mode', 'train']
|
| 50 |
+
if '--train_dir' not in extra_args:
|
| 51 |
+
train_args += ['--train_dir', train_dir]
|
| 52 |
+
if '--eval_file' not in extra_args:
|
| 53 |
+
train_args += ['--eval_file', dev_file]
|
| 54 |
+
train_args = train_args + default_args + extra_args
|
| 55 |
+
logger.info("Running train step with args: %s", train_args)
|
| 56 |
+
charlm.main(train_args)
|
| 57 |
+
|
| 58 |
+
if mode == Mode.SCORE_DEV:
|
| 59 |
+
dev_args = ['--mode', 'predict']
|
| 60 |
+
if '--eval_file' not in extra_args:
|
| 61 |
+
dev_args += ['--eval_file', dev_file]
|
| 62 |
+
dev_args = dev_args + default_args + extra_args
|
| 63 |
+
logger.info("Running dev step with args: %s", dev_args)
|
| 64 |
+
charlm.main(dev_args)
|
| 65 |
+
|
| 66 |
+
if mode == Mode.SCORE_TEST:
|
| 67 |
+
test_args = ['--mode', 'predict']
|
| 68 |
+
if '--eval_file' not in extra_args:
|
| 69 |
+
test_args += ['--eval_file', test_file]
|
| 70 |
+
test_args = test_args + default_args + extra_args
|
| 71 |
+
logger.info("Running test step with args: %s", test_args)
|
| 72 |
+
charlm.main(test_args)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_model_name(args):
|
| 76 |
+
"""
|
| 77 |
+
The charlm saves forward and backward charlms to the same dir, but with different filenames
|
| 78 |
+
"""
|
| 79 |
+
return "%s_charlm" % args.direction
|
| 80 |
+
|
| 81 |
+
def main():
|
| 82 |
+
common.main(run_treebank, "charlm", get_model_name, add_charlm_args, charlm.build_argparse())
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
main()
|
| 86 |
+
|
stanza/stanza/utils/training/run_constituency.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Trains or scores a constituency model.
|
| 3 |
+
|
| 4 |
+
Currently a suuuuper preliminary script.
|
| 5 |
+
|
| 6 |
+
Example of how to run on multiple parsers at the same time on the Stanford workqueue:
|
| 7 |
+
|
| 8 |
+
for i in `echo 1000 1001 1002 1003 1004`; do nlprun -d a6000 "python3 stanza/utils/training/run_constituency.py vi_vlsp23 --use_bert --stage1_bert_finetun --save_name vi_vlsp23_$i.pt --seed $i --epochs 200 --force" -o vi_vlsp23_$i.out; done
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
from stanza.models import constituency_parser
|
| 16 |
+
from stanza.models.constituency.retagging import RETAG_METHOD
|
| 17 |
+
from stanza.utils.datasets.constituency import prepare_con_dataset
|
| 18 |
+
from stanza.utils.training import common
|
| 19 |
+
from stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain
|
| 20 |
+
|
| 21 |
+
from stanza.resources.default_packages import default_charlms, default_pretrains
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger('stanza')
|
| 24 |
+
|
| 25 |
+
def add_constituency_args(parser):
|
| 26 |
+
add_charlm_args(parser)
|
| 27 |
+
|
| 28 |
+
parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language')
|
| 29 |
+
|
| 30 |
+
parser.add_argument('--parse_text', dest='mode', action='store_const', const="parse_text", help='Parse a text file')
|
| 31 |
+
|
| 32 |
+
def build_wordvec_args(short_language, dataset, extra_args):
|
| 33 |
+
if '--wordvec_pretrain_file' not in extra_args:
|
| 34 |
+
# will throw an error if the pretrain can't be found
|
| 35 |
+
wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains)
|
| 36 |
+
wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]
|
| 37 |
+
else:
|
| 38 |
+
wordvec_args = []
|
| 39 |
+
|
| 40 |
+
return wordvec_args
|
| 41 |
+
|
| 42 |
+
def build_default_args(paths, short_language, dataset, command_args, extra_args):
|
| 43 |
+
if short_language in RETAG_METHOD:
|
| 44 |
+
retag_args = ["--retag_method", RETAG_METHOD[short_language]]
|
| 45 |
+
else:
|
| 46 |
+
retag_args = []
|
| 47 |
+
|
| 48 |
+
wordvec_args = build_wordvec_args(short_language, dataset, extra_args)
|
| 49 |
+
|
| 50 |
+
charlm = choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {})
|
| 51 |
+
charlm_args = build_charlm_args(short_language, charlm, base_args=False)
|
| 52 |
+
|
| 53 |
+
bert_args = common.choose_transformer(short_language, command_args, extra_args, warn=True, layers=True)
|
| 54 |
+
default_args = retag_args + wordvec_args + charlm_args + bert_args
|
| 55 |
+
|
| 56 |
+
return default_args
|
| 57 |
+
|
| 58 |
+
def build_model_filename(paths, short_name, command_args, extra_args):
|
| 59 |
+
short_language, dataset = short_name.split("_", 1)
|
| 60 |
+
|
| 61 |
+
default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)
|
| 62 |
+
|
| 63 |
+
train_args = ["--shorthand", short_name,
|
| 64 |
+
"--mode", "train"]
|
| 65 |
+
train_args = train_args + default_args
|
| 66 |
+
if command_args.save_name is not None:
|
| 67 |
+
train_args.extend(["--save_name", command_args.save_name])
|
| 68 |
+
if command_args.save_dir is not None:
|
| 69 |
+
train_args.extend(["--save_dir", command_args.save_dir])
|
| 70 |
+
args = constituency_parser.parse_args(train_args)
|
| 71 |
+
save_name = constituency_parser.build_model_filename(args)
|
| 72 |
+
return save_name
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def run_treebank(mode, paths, treebank, short_name, temp_output_file, command_args, extra_args):
|
| 76 |
+
constituency_dir = paths["CONSTITUENCY_DATA_DIR"]
|
| 77 |
+
short_language, dataset = short_name.split("_")
|
| 78 |
+
|
| 79 |
+
train_file = os.path.join(constituency_dir, f"{short_name}_train.mrg")
|
| 80 |
+
dev_file = os.path.join(constituency_dir, f"{short_name}_dev.mrg")
|
| 81 |
+
test_file = os.path.join(constituency_dir, f"{short_name}_test.mrg")
|
| 82 |
+
|
| 83 |
+
if not os.path.exists(train_file) or not os.path.exists(dev_file) or not os.path.exists(test_file):
|
| 84 |
+
logger.warning(f"The data for {short_name} is missing or incomplete. Attempting to rebuild...")
|
| 85 |
+
try:
|
| 86 |
+
prepare_con_dataset.main(short_name)
|
| 87 |
+
except:
|
| 88 |
+
logger.error(f"Unable to build the data. Please correctly build the files in {train_file}, {dev_file}, {test_file} and then try again.")
|
| 89 |
+
raise
|
| 90 |
+
|
| 91 |
+
default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)
|
| 92 |
+
|
| 93 |
+
if mode == Mode.TRAIN:
|
| 94 |
+
train_args = ['--train_file', train_file,
|
| 95 |
+
'--eval_file', dev_file,
|
| 96 |
+
'--shorthand', short_name,
|
| 97 |
+
'--mode', 'train']
|
| 98 |
+
train_args = train_args + default_args + extra_args
|
| 99 |
+
logger.info("Running train step with args: {}".format(train_args))
|
| 100 |
+
constituency_parser.main(train_args)
|
| 101 |
+
|
| 102 |
+
if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
|
| 103 |
+
dev_args = ['--eval_file', dev_file,
|
| 104 |
+
'--shorthand', short_name,
|
| 105 |
+
'--mode', 'predict']
|
| 106 |
+
dev_args = dev_args + default_args + extra_args
|
| 107 |
+
logger.info("Running dev step with args: {}".format(dev_args))
|
| 108 |
+
constituency_parser.main(dev_args)
|
| 109 |
+
|
| 110 |
+
if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
|
| 111 |
+
test_args = ['--eval_file', test_file,
|
| 112 |
+
'--shorthand', short_name,
|
| 113 |
+
'--mode', 'predict']
|
| 114 |
+
test_args = test_args + default_args + extra_args
|
| 115 |
+
logger.info("Running test step with args: {}".format(test_args))
|
| 116 |
+
constituency_parser.main(test_args)
|
| 117 |
+
|
| 118 |
+
if mode == "parse_text":
|
| 119 |
+
text_args = ['--shorthand', short_name,
|
| 120 |
+
'--mode', 'parse_text']
|
| 121 |
+
text_args = text_args + default_args + extra_args
|
| 122 |
+
logger.info("Processing text with args: {}".format(text_args))
|
| 123 |
+
constituency_parser.main(text_args)
|
| 124 |
+
|
| 125 |
+
def main():
|
| 126 |
+
common.main(run_treebank, "constituency", "constituency", add_constituency_args, sub_argparse=constituency_parser.build_argparse(), build_model_filename=build_model_filename)
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
main()
|
| 130 |
+
|
stanza/stanza/utils/training/run_depparse.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
import os

from stanza.models import parser

from stanza.utils.training import common
from stanza.utils.training.common import Mode, add_charlm_args, build_depparse_charlm_args, choose_depparse_charlm, choose_transformer
from stanza.utils.training.run_pos import wordvec_args

from stanza.resources.default_packages import default_charlms, depparse_charlms

logger = logging.getLogger('stanza')

def add_depparse_args(parser):
    """Add depparse-specific flags on top of the common charlm flags."""
    add_charlm_args(parser)

    parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language')

# TODO: refactor with run_pos
def build_model_filename(paths, short_name, command_args, extra_args):
    """
    Figure out the filename the depparse model will be saved to for these settings.
    """
    short_language, dataset = short_name.split("_", 1)

    # TODO: can avoid downloading the charlm at this point, since we
    # might not even be training
    charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm)

    bert_args = choose_transformer(short_language, command_args, extra_args, warn=False)

    train_args = ["--shorthand", short_name,
                  "--mode", "train"]
    # TODO: also, this downloads the wordvec, which we might not want to do yet
    train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args
    if command_args.save_name is not None:
        train_args.extend(["--save_name", command_args.save_name])
    if command_args.save_dir is not None:
        train_args.extend(["--save_dir", command_args.save_dir])
    args = parser.parse_args(train_args)
    save_name = parser.model_file_name(args)
    return save_name


def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """
    Train and/or score the dependency parser for one treebank.

    Dev and test predictions are written to temp_output_file when one is
    given, otherwise to the standard *.pred.conllu files.
    """
    # maxsplit=1 keeps dataset names containing underscores intact,
    # matching build_model_filename above
    short_language, dataset = short_name.split("_", 1)

    # TODO: refactor these blocks?
    depparse_dir = paths["DEPPARSE_DATA_DIR"]
    train_file = f"{depparse_dir}/{short_name}.train.in.conllu"
    dev_in_file = f"{depparse_dir}/{short_name}.dev.in.conllu"
    dev_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.dev.pred.conllu"
    test_in_file = f"{depparse_dir}/{short_name}.test.in.conllu"
    test_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.test.pred.conllu"

    # an explicit --eval_file overrides the default dev/test input files
    eval_file = None
    if '--eval_file' in extra_args:
        eval_file = extra_args[extra_args.index('--eval_file') + 1]

    charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm)

    bert_args = choose_transformer(short_language, command_args, extra_args)

    if mode == Mode.TRAIN:
        if not os.path.exists(train_file):
            logger.error("TRAIN FILE NOT FOUND: %s ... skipping" % train_file)
            return

        # some languages need reduced batch size
        if short_name == 'de_hdt':
            # 'UD_German-HDT'
            batch_size = "1300"
        elif short_name in ('hr_set', 'fi_tdt', 'ru_taiga', 'cs_cltt', 'gl_treegal', 'lv_lvtb', 'ro_simonero'):
            # 'UD_Croatian-SET', 'UD_Finnish-TDT', 'UD_Russian-Taiga',
            # 'UD_Czech-CLTT', 'UD_Galician-TreeGal', 'UD_Latvian-LVTB' 'Romanian-SiMoNERo'
            batch_size = "3000"
        else:
            batch_size = "5000"

        train_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                      "--train_file", train_file,
                      "--eval_file", eval_file if eval_file else dev_in_file,
                      "--output_file", dev_pred_file,
                      "--batch_size", batch_size,
                      "--lang", short_language,
                      "--shorthand", short_name,
                      "--mode", "train"]
        train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
        train_args = train_args + extra_args
        logger.info("Running train depparse for {} with args {}".format(treebank, train_args))
        parser.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                    "--eval_file", eval_file if eval_file else dev_in_file,
                    "--output_file", dev_pred_file,
                    "--lang", short_language,
                    "--shorthand", short_name,
                    "--mode", "predict"]
        dev_args = dev_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
        dev_args = dev_args + extra_args
        logger.info("Running dev depparse for {} with args {}".format(treebank, dev_args))
        parser.main(dev_args)

        if '--no_gold_labels' not in extra_args:
            results = common.run_eval_script_depparse(eval_file if eval_file else dev_in_file, dev_pred_file)
            logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))
            if not temp_output_file:
                logger.info("Output saved to %s", dev_pred_file)

    if mode == Mode.SCORE_TEST:
        test_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                     "--eval_file", eval_file if eval_file else test_in_file,
                     "--output_file", test_pred_file,
                     "--lang", short_language,
                     "--shorthand", short_name,
                     "--mode", "predict"]
        test_args = test_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
        test_args = test_args + extra_args
        logger.info("Running test depparse for {} with args {}".format(treebank, test_args))
        parser.main(test_args)

        if '--no_gold_labels' not in extra_args:
            results = common.run_eval_script_depparse(eval_file if eval_file else test_in_file, test_pred_file)
            logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
            if not temp_output_file:
                logger.info("Output saved to %s", test_pred_file)


def main():
    common.main(run_treebank, "depparse", "parser", add_depparse_args, sub_argparse=parser.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_depparse_charlm)

if __name__ == "__main__":
    main()
|
| 133 |
+
|
stanza/stanza/utils/training/run_lemma.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script allows for training or testing on dev / test of the UD lemmatizer.
|
| 3 |
+
|
| 4 |
+
If run with a single treebank name, it will train or test that treebank.
|
| 5 |
+
If run with ud_all or all_ud, it will iterate over all UD treebanks it can find.
|
| 6 |
+
|
| 7 |
+
Mode can be set to train&dev with --train, to dev set only
|
| 8 |
+
with --score_dev, and to test set only with --score_test.
|
| 9 |
+
|
| 10 |
+
Treebanks are specified as a list. all_ud or ud_all means to look for
|
| 11 |
+
all UD treebanks.
|
| 12 |
+
|
| 13 |
+
Extra arguments are passed to the lemmatizer. In case the run script
|
| 14 |
+
itself is shadowing arguments, you can specify --extra_args as a
|
| 15 |
+
parameter to mark where the lemmatizer arguments start.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
from stanza.models import identity_lemmatizer
|
| 22 |
+
from stanza.models import lemmatizer
|
| 23 |
+
from stanza.models.lemma import attach_lemma_classifier
|
| 24 |
+
|
| 25 |
+
from stanza.utils.training import common
|
| 26 |
+
from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm
|
| 27 |
+
from stanza.utils.training import run_lemma_classifier
|
| 28 |
+
|
| 29 |
+
from stanza.utils.datasets.prepare_lemma_treebank import check_lemmas
|
| 30 |
+
import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger('stanza')
|
| 33 |
+
|
| 34 |
+
def add_lemma_args(parser):
|
| 35 |
+
add_charlm_args(parser)
|
| 36 |
+
|
| 37 |
+
parser.add_argument('--lemma_classifier', dest='lemma_classifier', action='store_true', default=None,
|
| 38 |
+
help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer if the charlm is used")
|
| 39 |
+
parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false',
|
| 40 |
+
help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer if the charlm is used")
|
| 41 |
+
|
| 42 |
+
def build_model_filename(paths, short_name, command_args, extra_args):
|
| 43 |
+
"""
|
| 44 |
+
Figure out what the model savename will be, taking into account the model settings.
|
| 45 |
+
|
| 46 |
+
Useful for figuring out if the model already exists
|
| 47 |
+
|
| 48 |
+
None will represent that there is no expected save_name
|
| 49 |
+
"""
|
| 50 |
+
short_language, dataset = short_name.split("_", 1)
|
| 51 |
+
|
| 52 |
+
lemma_dir = paths["LEMMA_DATA_DIR"]
|
| 53 |
+
train_file = f"{lemma_dir}/{short_name}.train.in.conllu"
|
| 54 |
+
|
| 55 |
+
if not os.path.exists(train_file):
|
| 56 |
+
logger.debug("Treebank %s is not prepared for training the lemmatizer. Could not find any training data at %s Cannot figure out the expected save_name without looking at the data, but a later step in the process will skip the training anyway" % (short_name, train_file))
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
has_lemmas = check_lemmas(train_file)
|
| 60 |
+
if not has_lemmas:
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
# TODO: can avoid downloading the charlm at this point, since we
|
| 64 |
+
# might not even be training
|
| 65 |
+
charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)
|
| 66 |
+
|
| 67 |
+
train_args = ["--train_file", train_file,
|
| 68 |
+
"--shorthand", short_name,
|
| 69 |
+
"--mode", "train"]
|
| 70 |
+
train_args = train_args + charlm_args + extra_args
|
| 71 |
+
args = lemmatizer.parse_args(train_args)
|
| 72 |
+
save_name = lemmatizer.build_model_filename(args)
|
| 73 |
+
return save_name
|
| 74 |
+
|
| 75 |
+
def run_treebank(mode, paths, treebank, short_name,
|
| 76 |
+
temp_output_file, command_args, extra_args):
|
| 77 |
+
short_language, dataset = short_name.split("_", 1)
|
| 78 |
+
|
| 79 |
+
lemma_dir = paths["LEMMA_DATA_DIR"]
|
| 80 |
+
train_file = f"{lemma_dir}/{short_name}.train.in.conllu"
|
| 81 |
+
dev_in_file = f"{lemma_dir}/{short_name}.dev.in.conllu"
|
| 82 |
+
dev_gold_file = f"{lemma_dir}/{short_name}.dev.gold.conllu"
|
| 83 |
+
dev_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.dev.pred.conllu"
|
| 84 |
+
test_in_file = f"{lemma_dir}/{short_name}.test.in.conllu"
|
| 85 |
+
test_gold_file = f"{lemma_dir}/{short_name}.test.gold.conllu"
|
| 86 |
+
test_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.test.pred.conllu"
|
| 87 |
+
|
| 88 |
+
charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)
|
| 89 |
+
|
| 90 |
+
if not os.path.exists(train_file):
|
| 91 |
+
logger.error("Treebank %s is not prepared for training the lemmatizer. Could not find any training data at %s Skipping..." % (treebank, train_file))
|
| 92 |
+
return
|
| 93 |
+
|
| 94 |
+
has_lemmas = check_lemmas(train_file)
|
| 95 |
+
if not has_lemmas:
|
| 96 |
+
logger.info("Treebank " + treebank + " (" + short_name +
|
| 97 |
+
") has no lemmas. Using identity lemmatizer")
|
| 98 |
+
if mode == Mode.TRAIN or mode == Mode.SCORE_DEV:
|
| 99 |
+
train_args = ["--train_file", train_file,
|
| 100 |
+
"--eval_file", dev_in_file,
|
| 101 |
+
"--output_file", dev_pred_file,
|
| 102 |
+
"--gold_file", dev_gold_file,
|
| 103 |
+
"--shorthand", short_name]
|
| 104 |
+
logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
|
| 105 |
+
identity_lemmatizer.main(train_args)
|
| 106 |
+
elif mode == Mode.SCORE_TEST:
|
| 107 |
+
train_args = ["--train_file", train_file,
|
| 108 |
+
"--eval_file", test_in_file,
|
| 109 |
+
"--output_file", test_pred_file,
|
| 110 |
+
"--gold_file", test_gold_file,
|
| 111 |
+
"--shorthand", short_name]
|
| 112 |
+
logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
|
| 113 |
+
identity_lemmatizer.main(train_args)
|
| 114 |
+
else:
|
| 115 |
+
if mode == Mode.TRAIN:
|
| 116 |
+
# ('UD_Czech-PDT', 'UD_Russian-SynTagRus', 'UD_German-HDT')
|
| 117 |
+
if short_name in ('cs_pdt', 'ru_syntagrus', 'de_hdt'):
|
| 118 |
+
num_epochs = "30"
|
| 119 |
+
else:
|
| 120 |
+
num_epochs = "60"
|
| 121 |
+
|
| 122 |
+
train_args = ["--train_file", train_file,
|
| 123 |
+
"--eval_file", dev_in_file,
|
| 124 |
+
"--output_file", dev_pred_file,
|
| 125 |
+
"--gold_file", dev_gold_file,
|
| 126 |
+
"--shorthand", short_name,
|
| 127 |
+
"--num_epoch", num_epochs,
|
| 128 |
+
"--mode", "train"]
|
| 129 |
+
train_args = train_args + charlm_args + extra_args
|
| 130 |
+
logger.info("Running train lemmatizer for {} with args {}".format(treebank, train_args))
|
| 131 |
+
lemmatizer.main(train_args)
|
| 132 |
+
|
| 133 |
+
if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
|
| 134 |
+
dev_args = ["--eval_file", dev_in_file,
|
| 135 |
+
"--output_file", dev_pred_file,
|
| 136 |
+
"--gold_file", dev_gold_file,
|
| 137 |
+
"--shorthand", short_name,
|
| 138 |
+
"--mode", "predict"]
|
| 139 |
+
dev_args = dev_args + charlm_args + extra_args
|
| 140 |
+
logger.info("Running dev lemmatizer for {} with args {}".format(treebank, dev_args))
|
| 141 |
+
lemmatizer.main(dev_args)
|
| 142 |
+
|
| 143 |
+
if mode == Mode.SCORE_TEST:
|
| 144 |
+
test_args = ["--eval_file", test_in_file,
|
| 145 |
+
"--output_file", test_pred_file,
|
| 146 |
+
"--gold_file", test_gold_file,
|
| 147 |
+
"--shorthand", short_name,
|
| 148 |
+
"--mode", "predict"]
|
| 149 |
+
test_args = test_args + charlm_args + extra_args
|
| 150 |
+
logger.info("Running test lemmatizer for {} with args {}".format(treebank, test_args))
|
| 151 |
+
lemmatizer.main(test_args)
|
| 152 |
+
|
| 153 |
+
use_lemma_classifier = command_args.lemma_classifier
|
| 154 |
+
if use_lemma_classifier is None:
|
| 155 |
+
use_lemma_classifier = command_args.charlm is not None
|
| 156 |
+
use_lemma_classifier = use_lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING
|
| 157 |
+
if use_lemma_classifier and mode == Mode.TRAIN:
|
| 158 |
+
lc_charlm_args = ['--no_charlm'] if command_args.charlm is None else ['--charlm', command_args.charlm]
|
| 159 |
+
lemma_classifier_args = [treebank] + lc_charlm_args
|
| 160 |
+
if command_args.force:
|
| 161 |
+
lemma_classifier_args.append('--force')
|
| 162 |
+
run_lemma_classifier.main(lemma_classifier_args)
|
| 163 |
+
|
| 164 |
+
save_name = build_model_filename(paths, short_name, command_args, extra_args)
|
| 165 |
+
# TODO: use a temp path for the lemma_classifier or keep it somewhere
|
| 166 |
+
attach_args = ['--input', save_name,
|
| 167 |
+
'--output', save_name,
|
| 168 |
+
'--classifier', 'saved_models/lemma_classifier/%s_lemma_classifier.pt' % short_name]
|
| 169 |
+
attach_lemma_classifier.main(attach_args)
|
| 170 |
+
|
| 171 |
+
# now we rerun the dev set - the HI in particular demonstrates some good improvement
|
| 172 |
+
lemmatizer.main(dev_args)
|
| 173 |
+
|
| 174 |
+
def main():
|
| 175 |
+
common.main(run_treebank, "lemma", "lemmatizer", add_lemma_args, sub_argparse=lemmatizer.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm)
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
|
| 178 |
+
main()
|
| 179 |
+
|
stanza/stanza/utils/training/run_lemma_classifier.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from stanza.models.lemma_classifier import evaluate_models
|
| 4 |
+
from stanza.models.lemma_classifier import train_lstm_model
|
| 5 |
+
from stanza.models.lemma_classifier import train_transformer_model
|
| 6 |
+
from stanza.models.lemma_classifier.constants import ModelType
|
| 7 |
+
|
| 8 |
+
from stanza.resources.default_packages import default_pretrains, TRANSFORMERS
|
| 9 |
+
from stanza.utils.training import common
|
| 10 |
+
from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm, find_wordvec_pretrain
|
| 11 |
+
|
| 12 |
+
def add_lemma_args(parser):
|
| 13 |
+
add_charlm_args(parser)
|
| 14 |
+
|
| 15 |
+
parser.add_argument('--model_type', default=ModelType.LSTM, type=lambda x: ModelType[x.upper()],
|
| 16 |
+
help='Model type to use. {}'.format(", ".join(x.name for x in ModelType)))
|
| 17 |
+
|
| 18 |
+
def build_model_filename(paths, short_name, command_args, extra_args):
|
| 19 |
+
return os.path.join("saved_models", "lemma_classifier", short_name + "_lemma_classifier.pt")
|
| 20 |
+
|
| 21 |
+
def run_treebank(mode, paths, treebank, short_name,
|
| 22 |
+
temp_output_file, command_args, extra_args):
|
| 23 |
+
short_language, dataset = short_name.split("_", 1)
|
| 24 |
+
|
| 25 |
+
base_args = []
|
| 26 |
+
if '--save_name' not in extra_args:
|
| 27 |
+
base_args += ['--save_name', build_model_filename(paths, short_name, command_args, extra_args)]
|
| 28 |
+
|
| 29 |
+
embedding_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)
|
| 30 |
+
if '--wordvec_pretrain_file' not in extra_args:
|
| 31 |
+
wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, {}, dataset)
|
| 32 |
+
embedding_args += ["--wordvec_pretrain_file", wordvec_pretrain]
|
| 33 |
+
|
| 34 |
+
bert_args = []
|
| 35 |
+
if command_args.model_type is ModelType.TRANSFORMER:
|
| 36 |
+
if '--bert_model' not in extra_args:
|
| 37 |
+
if short_language in TRANSFORMERS:
|
| 38 |
+
bert_args = ['--bert_model', TRANSFORMERS.get(short_language)]
|
| 39 |
+
else:
|
| 40 |
+
raise ValueError("--bert_model not specified, so cannot figure out which transformer to use for language %s" % short_language)
|
| 41 |
+
|
| 42 |
+
extra_train_args = []
|
| 43 |
+
if command_args.force:
|
| 44 |
+
extra_train_args.append('--force')
|
| 45 |
+
|
| 46 |
+
if mode == Mode.TRAIN:
|
| 47 |
+
train_args = []
|
| 48 |
+
if "--train_file" not in extra_args:
|
| 49 |
+
train_file = os.path.join("data", "lemma_classifier", "%s.train.lemma" % short_name)
|
| 50 |
+
train_args += ['--train_file', train_file]
|
| 51 |
+
if "--eval_file" not in extra_args:
|
| 52 |
+
eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name)
|
| 53 |
+
train_args += ['--eval_file', eval_file]
|
| 54 |
+
train_args = base_args + train_args + extra_args + extra_train_args
|
| 55 |
+
|
| 56 |
+
if command_args.model_type == ModelType.LSTM:
|
| 57 |
+
train_args = embedding_args + train_args
|
| 58 |
+
train_lstm_model.main(train_args)
|
| 59 |
+
else:
|
| 60 |
+
model_type_args = ["--model_type", command_args.model_type.name.lower()]
|
| 61 |
+
train_args = bert_args + model_type_args + train_args
|
| 62 |
+
train_transformer_model.main(train_args)
|
| 63 |
+
|
| 64 |
+
if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
|
| 65 |
+
eval_args = []
|
| 66 |
+
if "--eval_file" not in extra_args:
|
| 67 |
+
eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name)
|
| 68 |
+
eval_args += ['--eval_file', eval_file]
|
| 69 |
+
model_type_args = ["--model_type", command_args.model_type.name.lower()]
|
| 70 |
+
eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args
|
| 71 |
+
evaluate_models.main(eval_args)
|
| 72 |
+
|
| 73 |
+
if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
|
| 74 |
+
eval_args = []
|
| 75 |
+
if "--eval_file" not in extra_args:
|
| 76 |
+
eval_file = os.path.join("data", "lemma_classifier", "%s.test.lemma" % short_name)
|
| 77 |
+
eval_args += ['--eval_file', eval_file]
|
| 78 |
+
model_type_args = ["--model_type", command_args.model_type.name.lower()]
|
| 79 |
+
eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args
|
| 80 |
+
evaluate_models.main(eval_args)
|
| 81 |
+
|
| 82 |
+
def main(args=None):
    """Entry point for training / scoring lemma classifiers.

    Delegates to the shared training harness; the LSTM argparse is used as the
    sub-parser even though a transformer model type may be selected at runtime.
    """
    common.main(run_treebank, "lemma_classifier", "lemma_classifier", add_lemma_args, sub_argparse=train_lstm_model.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm, args=args)


if __name__ == '__main__':
    main()
|
stanza/stanza/utils/training/run_mwt.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script allows for training or testing on dev / test of the UD mwt tools.
|
| 3 |
+
|
| 4 |
+
If run with a single treebank name, it will train or test that treebank.
|
| 5 |
+
If run with ud_all or all_ud, it will iterate over all UD treebanks it can find.
|
| 6 |
+
|
| 7 |
+
Mode can be set to train&dev with --train, to dev set only
|
| 8 |
+
with --score_dev, and to test set only with --score_test.
|
| 9 |
+
|
| 10 |
+
Treebanks are specified as a list. all_ud or ud_all means to look for
|
| 11 |
+
all UD treebanks.
|
| 12 |
+
|
| 13 |
+
Extra arguments are passed to mwt. In case the run script
|
| 14 |
+
itself is shadowing arguments, you can specify --extra_args as a
|
| 15 |
+
parameter to mark where the mwt arguments start.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import math
|
| 21 |
+
|
| 22 |
+
from stanza.models import mwt_expander
|
| 23 |
+
from stanza.models.common.doc import Document
|
| 24 |
+
from stanza.utils.conll import CoNLL
|
| 25 |
+
from stanza.utils.training import common
|
| 26 |
+
from stanza.utils.training.common import Mode
|
| 27 |
+
|
| 28 |
+
from stanza.utils.max_mwt_length import max_mwt_length
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger('stanza')
|
| 31 |
+
|
| 32 |
+
def check_mwt(filename):
    """Return True if the conll file at filename contains any MWTs."""
    document = CoNLL.conll2doc(filename)
    expansions = document.get_mwt_expansions(False)
    return len(expansions) > 0
|
| 39 |
+
|
| 40 |
+
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or evaluate the MWT expander for one treebank.

    mode selects train / score_dev / score_test; paths supplies data dirs;
    extra_args are forwarded verbatim to mwt_expander.  Treebanks with no
    MWTs in the training data are skipped entirely.
    """
    short_language = short_name.split("_")[0]

    mwt_dir = paths["MWT_DATA_DIR"]

    train_file = f"{mwt_dir}/{short_name}.train.in.conllu"
    dev_in_file = f"{mwt_dir}/{short_name}.dev.in.conllu"
    dev_gold_file = f"{mwt_dir}/{short_name}.dev.gold.conllu"
    dev_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.dev.pred.conllu"
    test_in_file = f"{mwt_dir}/{short_name}.test.in.conllu"
    test_gold_file = f"{mwt_dir}/{short_name}.test.gold.conllu"
    test_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.test.pred.conllu"

    train_json = f"{mwt_dir}/{short_name}-ud-train-mwt.json"
    dev_json = f"{mwt_dir}/{short_name}-ud-dev-mwt.json"
    test_json = f"{mwt_dir}/{short_name}-ud-test-mwt.json"

    # caller-supplied --eval_file / --gold_file override the defaults
    eval_file = None
    if '--eval_file' in extra_args:
        eval_file = extra_args[extra_args.index('--eval_file') + 1]

    gold_file = None
    if '--gold_file' in extra_args:
        gold_file = extra_args[extra_args.index('--gold_file') + 1]

    if not check_mwt(train_file):
        logger.info("No training MWTS found for %s. Skipping" % treebank)
        return

    if not check_mwt(dev_in_file) and mode == Mode.TRAIN:
        logger.info("No dev MWTS found for %s. Training only the deterministic MWT expander" % treebank)
        # BUGFIX: build a new list instead of mutating the caller's extra_args;
        # the old .append() leaked '--dict_only' into subsequent treebanks
        # when iterating over ud_all
        extra_args = extra_args + ['--dict_only']

    if mode == Mode.TRAIN:
        # decoder length limit scaled from the longest MWT seen in train/dev
        max_mwt_len = math.ceil(max_mwt_length([train_json, dev_json]) * 1.1 + 1)
        logger.info("Max len: %f" % max_mwt_len)
        train_args = ['--train_file', train_file,
                      '--eval_file', eval_file if eval_file else dev_in_file,
                      '--output_file', dev_output_file,
                      '--gold_file', gold_file if gold_file else dev_gold_file,
                      '--lang', short_language,
                      '--shorthand', short_name,
                      '--mode', 'train',
                      '--max_dec_len', str(max_mwt_len)]
        train_args = train_args + extra_args
        logger.info("Running train step with args: {}".format(train_args))
        mwt_expander.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ['--eval_file', eval_file if eval_file else dev_in_file,
                    '--output_file', dev_output_file,
                    '--gold_file', gold_file if gold_file else dev_gold_file,
                    '--lang', short_language,
                    '--shorthand', short_name,
                    '--mode', 'predict']
        dev_args = dev_args + extra_args
        logger.info("Running dev step with args: {}".format(dev_args))
        mwt_expander.main(dev_args)

        results = common.run_eval_script_mwt(gold_file if gold_file else dev_gold_file, dev_output_file)
        logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))

    if mode == Mode.SCORE_TEST:
        test_args = ['--eval_file', eval_file if eval_file else test_in_file,
                     '--output_file', test_output_file,
                     '--gold_file', gold_file if gold_file else test_gold_file,
                     '--lang', short_language,
                     '--shorthand', short_name,
                     '--mode', 'predict']
        test_args = test_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        mwt_expander.main(test_args)

        results = common.run_eval_script_mwt(gold_file if gold_file else test_gold_file, test_output_file)
        logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
|
| 116 |
+
|
| 117 |
+
def main():
    """Command-line entry point for the MWT expander training harness."""
    common.main(run_treebank, "mwt", "mwt_expander",
                sub_argparse=mwt_expander.build_argparse())


if __name__ == "__main__":
    main()
|
| 122 |
+
|
stanza/stanza/utils/training/run_ner.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Trains or scores an NER model.
|
| 3 |
+
|
| 4 |
+
Will attempt to guess the appropriate word vector file if none is
|
| 5 |
+
specified, and will use the charlms specified in the resources
|
| 6 |
+
for a given dataset or language if possible.
|
| 7 |
+
|
| 8 |
+
Example command line:
|
| 9 |
+
python3 -m stanza.utils.training.run_ner.py hu_combined
|
| 10 |
+
|
| 11 |
+
This script expects the prepared data to be in
|
| 12 |
+
data/ner/{lang}_{dataset}.train.json, {lang}_{dataset}.dev.json, {lang}_{dataset}.test.json
|
| 13 |
+
|
| 14 |
+
If those files don't exist, it will make an attempt to rebuild them
|
| 15 |
+
using the prepare_ner_dataset script. However, this will fail if the
|
| 16 |
+
data is not already downloaded. More information on where to find
|
| 17 |
+
most of the datasets online is in that script. Some of the datasets
|
| 18 |
+
have licenses which must be agreed to, so no attempt is made to
|
| 19 |
+
automatically download the data.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import logging
|
| 23 |
+
import os
|
| 24 |
+
|
| 25 |
+
from stanza.models import ner_tagger
|
| 26 |
+
from stanza.resources.common import DEFAULT_MODEL_DIR
|
| 27 |
+
from stanza.utils.datasets.ner import prepare_ner_dataset
|
| 28 |
+
from stanza.utils.training import common
|
| 29 |
+
from stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain
|
| 30 |
+
|
| 31 |
+
from stanza.resources.default_packages import default_charlms, default_pretrains, ner_charlms, ner_pretrains
|
| 32 |
+
|
| 33 |
+
# extra arguments specific to a particular dataset
|
| 34 |
+
DATASET_EXTRA_ARGS = {
|
| 35 |
+
"da_ddt": [ "--dropout", "0.6" ],
|
| 36 |
+
"fa_arman": [ "--dropout", "0.6" ],
|
| 37 |
+
"vi_vlsp": [ "--dropout", "0.6",
|
| 38 |
+
"--word_dropout", "0.1",
|
| 39 |
+
"--locked_dropout", "0.1",
|
| 40 |
+
"--char_dropout", "0.1" ],
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
logger = logging.getLogger('stanza')
|
| 44 |
+
|
| 45 |
+
def add_ner_args(parser):
    """Attach NER-specific command-line options (charlm + transformer) to parser."""
    add_charlm_args(parser)

    parser.add_argument('--use_bert', default=False, action="store_true",
                        help='Use the default transformer for this language')
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def build_pretrain_args(language, dataset, charlm="default", command_args=None, extra_args=None, model_dir=DEFAULT_MODEL_DIR):
    """
    Returns one list with the args for this language & dataset's charlm and pretrained embedding
    """
    chosen_charlm = choose_charlm(language, dataset, charlm, default_charlms, ner_charlms)
    charlm_args = build_charlm_args(language, chosen_charlm, model_dir=model_dir)

    if extra_args is not None and '--wordvec_pretrain_file' in extra_args:
        # caller already chose a pretrain explicitly
        wordvec_args = []
    else:
        # will throw an error if the pretrain can't be found
        pretrain_file = find_wordvec_pretrain(language, default_pretrains, ner_pretrains, dataset, model_dir=model_dir)
        wordvec_args = ['--wordvec_pretrain_file', pretrain_file]

    bert_args = common.choose_transformer(language, command_args, extra_args, warn=False)

    return charlm_args + wordvec_args + bert_args
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# TODO: refactor? tagger and depparse should be pretty similar
|
| 70 |
+
def build_model_filename(paths, short_name, command_args, extra_args):
    """Predict the filename the NER tagger would use to save this model."""
    short_language, dataset = short_name.split("_", 1)

    # TODO: can avoid downloading the charlm at this point, since we
    # might not even be training
    pretrain_args = build_pretrain_args(short_language, dataset, command_args.charlm, command_args, extra_args)
    dataset_args = DATASET_EXTRA_ARGS.get(short_name, [])

    train_args = ["--shorthand", short_name, "--mode", "train"]
    train_args += pretrain_args + dataset_args + extra_args
    if command_args.save_name is not None:
        train_args += ["--save_name", command_args.save_name]
    if command_args.save_dir is not None:
        train_args += ["--save_dir", command_args.save_dir]
    parsed = ner_tagger.parse_args(train_args)
    return ner_tagger.model_file_name(parsed)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Technically NER datasets are not necessarily treebanks
|
| 92 |
+
# (usually not, in fact)
|
| 93 |
+
# However, to keep the naming consistent, we leave the
|
| 94 |
+
# method which does the training as run_treebank
|
| 95 |
+
# TODO: rename treebank -> dataset everywhere
|
| 96 |
+
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or evaluate an NER model for one dataset.

    If the prepared json files are missing, attempts to rebuild them with
    prepare_ner_dataset before giving up.  extra_args are forwarded verbatim
    to the ner_tagger.
    """
    ner_dir = paths["NER_DATA_DIR"]
    # BUGFIX: use maxsplit=1, matching build_model_filename, so dataset names
    # which themselves contain underscores do not crash the unpacking
    language, dataset = short_name.split("_", 1)

    train_file = os.path.join(ner_dir, f"{treebank}.train.json")
    dev_file = os.path.join(ner_dir, f"{treebank}.dev.json")
    test_file = os.path.join(ner_dir, f"{treebank}.test.json")

    # if any files are missing, try to rebuild the dataset
    # if that still doesn't work, we have to throw an error
    missing_file = [x for x in (train_file, dev_file, test_file) if not os.path.exists(x)]
    if len(missing_file) > 0:
        logger.warning(f"The data for {treebank} is missing or incomplete. Cannot find {missing_file} Attempting to rebuild...")
        try:
            prepare_ner_dataset.main(treebank)
        except Exception as e:
            raise FileNotFoundError(f"An exception occurred while trying to build the data for {treebank} At least one portion of the data was missing: {missing_file} Please correctly build these files and then try again.") from e

    pretrain_args = build_pretrain_args(language, dataset, command_args.charlm, command_args, extra_args)

    if mode == Mode.TRAIN:
        # VI example arguments:
        #   --wordvec_pretrain_file ~/stanza_resources/vi/pretrain/vtb.pt
        #   --train_file data/ner/vi_vlsp.train.json
        #   --eval_file data/ner/vi_vlsp.dev.json
        #   --lang vi
        #   --shorthand vi_vlsp
        #   --mode train
        #   --charlm --charlm_shorthand vi_conll17
        #   --dropout 0.6 --word_dropout 0.1 --locked_dropout 0.1 --char_dropout 0.1
        dataset_args = DATASET_EXTRA_ARGS.get(short_name, [])

        train_args = ['--train_file', train_file,
                      '--eval_file', dev_file,
                      '--shorthand', short_name,
                      '--mode', 'train']
        train_args = train_args + pretrain_args + dataset_args + extra_args
        logger.info("Running train step with args: {}".format(train_args))
        ner_tagger.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ['--eval_file', dev_file,
                    '--shorthand', short_name,
                    '--mode', 'predict']
        dev_args = dev_args + pretrain_args + extra_args
        logger.info("Running dev step with args: {}".format(dev_args))
        ner_tagger.main(dev_args)

    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
        test_args = ['--eval_file', test_file,
                     '--shorthand', short_name,
                     '--mode', 'predict']
        test_args = test_args + pretrain_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        ner_tagger.main(test_args)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def main():
    """Command-line entry point for the NER training harness."""
    common.main(run_treebank, "ner", "nertagger", add_ner_args,
                ner_tagger.build_argparse(),
                build_model_filename=build_model_filename)


if __name__ == "__main__":
    main()
|
| 159 |
+
|
stanza/stanza/utils/training/run_sentiment.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Trains or tests a sentiment model using the classifier package
|
| 3 |
+
|
| 4 |
+
The prep script has separate entries for the root-only version of SST,
|
| 5 |
+
which is what people typically use to test. When training a model for
|
| 6 |
+
SST which uses all the data, the root-only version is used for
|
| 7 |
+
dev and test
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
from stanza.models import classifier
|
| 14 |
+
from stanza.utils.training import common
|
| 15 |
+
from stanza.utils.training.common import Mode, build_charlm_args, choose_charlm, find_wordvec_pretrain
|
| 16 |
+
|
| 17 |
+
from stanza.resources.default_packages import default_charlms, default_pretrains
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger('stanza')
|
| 20 |
+
|
| 21 |
+
# TODO: refactor with ner & conparse
|
| 22 |
+
def add_sentiment_args(parser):
    """Attach sentiment-specific command-line options (charlm + transformer) to parser."""
    parser.add_argument('--charlm', default="default", type=str,
                        help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')
    parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None,
                        help="Don't use a charlm, even if one is used by default for this package")

    parser.add_argument('--use_bert', default=False, action="store_true",
                        help='Use the default transformer for this language')
|
| 27 |
+
|
| 28 |
+
ALTERNATE_DATASET = {
|
| 29 |
+
"en_sst2": "en_sst2roots",
|
| 30 |
+
"en_sstplus": "en_sst3roots",
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
def build_default_args(paths, short_language, dataset, command_args, extra_args):
    """Assemble wordvec + charlm + transformer args for the sentiment classifier."""
    if '--wordvec_pretrain_file' in extra_args:
        wordvec_args = []
    else:
        # will throw an error if the pretrain can't be found
        pretrain_file = find_wordvec_pretrain(short_language, default_pretrains)
        wordvec_args = ['--wordvec_pretrain_file', pretrain_file]

    chosen_charlm = choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {})
    charlm_args = build_charlm_args(short_language, chosen_charlm, base_args=False)

    bert_args = common.choose_transformer(short_language, command_args, extra_args)

    return wordvec_args + charlm_args + bert_args
|
| 48 |
+
|
| 49 |
+
def build_model_filename(paths, short_name, command_args, extra_args):
    """Predict the filename the classifier would use to save this model."""
    short_language, dataset = short_name.split("_", 1)

    default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)

    train_args = ["--shorthand", short_name] + default_args
    if command_args.save_name is not None:
        train_args += ["--save_name", command_args.save_name]
    if command_args.save_dir is not None:
        train_args += ["--save_dir", command_args.save_dir]
    parsed = classifier.parse_args(train_args + extra_args)
    return classifier.build_model_filename(parsed)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def run_dataset(mode, paths, treebank, short_name,
                temp_output_file, command_args, extra_args):
    """Train and/or evaluate a sentiment classifier for one dataset.

    Dev/test files may come from an alternate (root-only) dataset per
    ALTERNATE_DATASET; all three data files must already exist.
    """
    sentiment_dir = paths["SENTIMENT_DATA_DIR"]
    short_language, dataset = short_name.split("_", 1)

    train_file = os.path.join(sentiment_dir, f"{short_name}.train.json")

    # e.g. en_sstplus trains on all phrases but scores on root-only data
    other_name = ALTERNATE_DATASET.get(short_name, short_name)
    dev_file = os.path.join(sentiment_dir, f"{other_name}.dev.json")
    test_file = os.path.join(sentiment_dir, f"{other_name}.test.json")

    # unlike NER, no attempt is made to rebuild missing data
    for filename in (train_file, dev_file, test_file):
        if not os.path.exists(filename):
            raise FileNotFoundError("Cannot find %s" % filename)

    default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)

    if mode == Mode.TRAIN:
        train_args = ['--train_file', train_file,
                      '--dev_file', dev_file,
                      '--test_file', test_file,
                      '--shorthand', short_name,
                      '--wordvec_type', 'word2vec', # TODO: chinese is fasttext
                      '--extra_wordvec_method', 'SUM']
        train_args = train_args + default_args + extra_args
        logger.info("Running train step with args: {}".format(train_args))
        classifier.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ['--no_train',
                    '--test_file', dev_file,
                    '--shorthand', short_name,
                    '--wordvec_type', 'word2vec'] # TODO: chinese is fasttext
        dev_args = dev_args + default_args + extra_args
        logger.info("Running dev step with args: {}".format(dev_args))
        classifier.main(dev_args)

    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
        test_args = ['--no_train',
                     '--test_file', test_file,
                     '--shorthand', short_name,
                     '--wordvec_type', 'word2vec'] # TODO: chinese is fasttext
        test_args = test_args + default_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        classifier.main(test_args)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def main():
    """Command-line entry point for the sentiment classifier training harness."""
    common.main(run_dataset, "classifier", "classifier", add_sentiment_args,
                classifier.build_argparse(),
                build_model_filename=build_model_filename)


if __name__ == "__main__":
    main()
|
| 118 |
+
|
stanza/stanza/utils/training/run_tokenizer.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script allows for training or testing on dev / test of the UD tokenizer.
|
| 3 |
+
|
| 4 |
+
If run with a single treebank name, it will train or test that treebank.
|
| 5 |
+
If run with ud_all or all_ud, it will iterate over all UD treebanks it can find.
|
| 6 |
+
|
| 7 |
+
Mode can be set to train&dev with --train, to dev set only
|
| 8 |
+
with --score_dev, and to test set only with --score_test.
|
| 9 |
+
|
| 10 |
+
Treebanks are specified as a list. all_ud or ud_all means to look for
|
| 11 |
+
all UD treebanks.
|
| 12 |
+
|
| 13 |
+
Extra arguments are passed to tokenizer. In case the run script
|
| 14 |
+
itself is shadowing arguments, you can specify --extra_args as a
|
| 15 |
+
parameter to mark where the tokenizer arguments start.
|
| 16 |
+
|
| 17 |
+
Default behavior is to discard the output and just print the results.
|
| 18 |
+
To keep the results instead, use --save_output
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import logging
|
| 22 |
+
import math
|
| 23 |
+
import os
|
| 24 |
+
|
| 25 |
+
from stanza.models import tokenizer
|
| 26 |
+
from stanza.utils.avg_sent_len import avg_sent_len
|
| 27 |
+
from stanza.utils.training import common
|
| 28 |
+
from stanza.utils.training.common import Mode
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger('stanza')
|
| 31 |
+
|
| 32 |
+
def uses_dictionary(short_language):
    """
    Some of the languages (as shown here) have external dictionaries

    We found this helped the overall tokenizer performance
    If these can't be found, they can be extracted from the previous iteration of models
    """
    # idiomatic direct boolean return instead of if/return True/return False
    return short_language in ('ja', 'th', 'zh', 'zh-hans', 'zh-hant')
|
| 42 |
+
|
| 43 |
+
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or evaluate the tokenizer for one treebank.

    Supports TRAIN, SCORE_DEV, SCORE_TEST, and SCORE_TRAIN (scores the
    training text as if it were a test set).  extra_args are forwarded
    verbatim to the tokenizer.
    """
    tokenize_dir = paths["TOKENIZE_DATA_DIR"]

    short_language = short_name.split("_")[0]
    label_type = "--label_file"
    label_file = f"{tokenize_dir}/{short_name}-ud-train.toklabels"
    dev_type = "--txt_file"
    dev_file = f"{tokenize_dir}/{short_name}.dev.txt"
    test_type = "--txt_file"
    test_file = f"{tokenize_dir}/{short_name}.test.txt"
    train_type = "--txt_file"
    train_file = f"{tokenize_dir}/{short_name}.train.txt"
    train_dev_args = ["--dev_txt_file", dev_file, "--dev_label_file", f"{tokenize_dir}/{short_name}-ud-dev.toklabels"]

    # Chinese text generally has no meaningful newline-as-space semantics
    if short_language == "zh" or short_language.startswith("zh-"):
        extra_args = ["--skip_newline"] + extra_args

    train_gold = f"{tokenize_dir}/{short_name}.train.gold.conllu"
    dev_gold = f"{tokenize_dir}/{short_name}.dev.gold.conllu"
    test_gold = f"{tokenize_dir}/{short_name}.test.gold.conllu"

    train_mwt = f"{tokenize_dir}/{short_name}-ud-train-mwt.json"
    dev_mwt = f"{tokenize_dir}/{short_name}-ud-dev-mwt.json"
    test_mwt = f"{tokenize_dir}/{short_name}-ud-test-mwt.json"

    train_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.train.pred.conllu"
    dev_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.dev.pred.conllu"
    test_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.test.pred.conllu"

    if mode == Mode.TRAIN:
        # sequence length scaled to ~3x the average sentence length,
        # rounded up to the nearest 100
        seqlen = str(math.ceil(avg_sent_len(label_file) * 3 / 100) * 100)
        train_args = ([label_type, label_file, train_type, train_file, "--lang", short_language,
                       "--max_seqlen", seqlen, "--mwt_json_file", dev_mwt] +
                      train_dev_args +
                      ["--dev_conll_gold", dev_gold, "--conll_file", dev_pred, "--shorthand", short_name])
        if uses_dictionary(short_language):
            train_args = train_args + ["--use_dictionary"]
        train_args = train_args + extra_args
        logger.info("Running train step with args: {}".format(train_args))
        tokenizer.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ["--mode", "predict", dev_type, dev_file, "--lang", short_language,
                    "--conll_file", dev_pred, "--shorthand", short_name, "--mwt_json_file", dev_mwt]
        dev_args = dev_args + extra_args
        logger.info("Running dev step with args: {}".format(dev_args))
        tokenizer.main(dev_args)

        # TODO: log these results? The original script logged them to
        # echo $results $args >> ${TOKENIZE_DATA_DIR}/${short}.results

        results = common.run_eval_script_tokens(dev_gold, dev_pred)
        logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))

    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
        test_args = ["--mode", "predict", test_type, test_file, "--lang", short_language,
                     "--conll_file", test_pred, "--shorthand", short_name, "--mwt_json_file", test_mwt]
        test_args = test_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        tokenizer.main(test_args)

        results = common.run_eval_script_tokens(test_gold, test_pred)
        logger.info("Finished running test set on\n{}\n{}".format(treebank, results))

    if mode == Mode.SCORE_TRAIN:
        # scores the training text as if it were a held-out set
        test_args = ["--mode", "predict", test_type, train_file, "--lang", short_language,
                     "--conll_file", train_pred, "--shorthand", short_name, "--mwt_json_file", train_mwt]
        test_args = test_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        tokenizer.main(test_args)

        results = common.run_eval_script_tokens(train_gold, train_pred)
        logger.info("Finished running train set as a test on\n{}\n{}".format(treebank, results))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def main():
    """Command-line entry point for the tokenizer training harness."""
    common.main(run_treebank, "tokenize", "tokenizer",
                sub_argparse=tokenizer.build_argparse())


if __name__ == "__main__":
    main()
|
stanza/stanza/utils/training/separate_ner_pretrain.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Loads NER models & separates out the word vectors to base & delta
|
| 3 |
+
|
| 4 |
+
The model will then be resaved without the base word vector,
|
| 5 |
+
greatly reducing the size of the model
|
| 6 |
+
|
| 7 |
+
This may be useful for any external users of stanza who have an NER
|
| 8 |
+
model they wish to reuse without retraining
|
| 9 |
+
|
| 10 |
+
If you know which pretrain was used to build an NER model, you can
|
| 11 |
+
provide that pretrain. Otherwise, you can give a directory of
|
| 12 |
+
pretrains and the script will test each one. In the latter case,
|
| 13 |
+
the name of the pretrain needs to look like lang_dataset_pretrain.pt
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
|
| 25 |
+
from stanza import Pipeline
|
| 26 |
+
from stanza.models.common.constant import lang_to_langcode
|
| 27 |
+
from stanza.models.common.pretrain import Pretrain, PretrainedWordVocab
|
| 28 |
+
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX
|
| 29 |
+
from stanza.models.ner.trainer import Trainer
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger('stanza')
|
| 32 |
+
logger.setLevel(logging.ERROR)
|
| 33 |
+
|
| 34 |
+
DEBUG = False
|
| 35 |
+
EPS = 0.0001
|
| 36 |
+
|
| 37 |
+
def main():
    """Split the base word vectors out of one or more NER models.

    For every NER model found at --input_path, each candidate pretrain
    under --pretrain_path (matched by language code) is tested: if most
    of the model's word vectors are within EPS of the pretrain's
    vectors, that pretrain is accepted as the one the model was trained
    from.  The model is then resaved to --output_path without the base
    embedding, keeping only a small "delta" embedding of whichever
    vectors were actually fine tuned.

    Raises:
        FileNotFoundError: if no NER model / pretrain exists at an
            explicitly given path
        ValueError: if a model filename does not follow the expected
            lang_package_nertagger.pt convention, or the loaded model's
            language disagrees with its filename
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default='saved_models/ner', help='Where to find NER models (dir or filename)')
    parser.add_argument('--output_path', type=str, default='saved_models/shrunk', help='Where to write shrunk NER models (dir)')
    parser.add_argument('--pretrain_path', type=str, default='saved_models/pretrain', help='Where to find pretrains (dir or filename)')
    args = parser.parse_args()

    # get list of NER models to shrink
    if os.path.isdir(args.input_path):
        ner_model_dir = args.input_path
        ners = os.listdir(ner_model_dir)
        if len(ners) == 0:
            raise FileNotFoundError("No ner models found in {}".format(args.input_path))
    else:
        if not os.path.isfile(args.input_path):
            raise FileNotFoundError("No ner model found at path {}".format(args.input_path))
        ner_model_dir, ners = os.path.split(args.input_path)
        ners = [ners]

    # get map from language to candidate pretrains
    if os.path.isdir(args.pretrain_path):
        pt_model_dir = args.pretrain_path
        pretrains = os.listdir(pt_model_dir)
        lang_to_pretrain = defaultdict(list)
        for pt in pretrains:
            # pretrain filenames are expected to look like lang_dataset_pretrain.pt
            lang_to_pretrain[pt.split("_")[0]].append(pt)
    else:
        # bugfix: this previously split pt_model_dir, which is not yet
        # defined on this branch - split the argument instead
        if not os.path.isfile(args.pretrain_path):
            raise FileNotFoundError("No pretrain found at path {}".format(args.pretrain_path))
        pt_model_dir, pretrains = os.path.split(args.pretrain_path)
        pretrains = [pretrains]
        # a single explicitly given pretrain is a candidate for every language
        lang_to_pretrain = defaultdict(lambda: pretrains)

    # shrunk models will all go in this directory
    new_dir = args.output_path
    os.makedirs(new_dir, exist_ok=True)

    final_pretrains = []      # (ner model, pretrain) pairs successfully shrunk
    missing_pretrains = []    # ner models for which no pretrain matched
    no_finetune = []          # ner models whose vectors were never fine tuned

    # for each model, go through the various pretrains
    # until we find one that works or none of them work
    for ner_model in ners:
        ner_path = os.path.join(ner_model_dir, ner_model)

        expected_ending = "_nertagger.pt"
        if not ner_model.endswith(expected_ending):
            raise ValueError("Unexpected name: {}".format(ner_model))
        short_name = ner_model[:-len(expected_ending)]
        lang, package = short_name.split("_", maxsplit=1)
        print("===============================================")
        print("Processing lang %s package %s" % (lang, package))

        # this may look funny - basically, the pipeline has machinery
        # to make sure the model has everything it needs to load,
        # including downloading other pieces if needed
        pipe = Pipeline(lang, processors="tokenize,ner", tokenize_pretokenized=True, package={"ner": package}, ner_model_path=ner_path)
        ner_processor = pipe.processors['ner']
        print("Loaded NER processor: {}".format(ner_processor))
        trainer = ner_processor.trainers[0]
        vocab = trainer.model.vocab
        word_vocab = vocab['word']

        # sanity check, make sure the model loaded matches the
        # language from the model's filename
        lcode = lang_to_langcode(trainer.args['lang'])
        if lang != lcode and not (lcode == 'zh' and lang == 'zh-hans'):
            raise ValueError("lang not as expected: {} vs {} ({})".format(lang, trainer.args['lang'], lcode))

        ner_pretrains = sorted(set(lang_to_pretrain[lang] + lang_to_pretrain[lcode]))
        for pt_model in ner_pretrains:
            pt_path = os.path.join(pt_model_dir, pt_model)
            print("Attempting pretrain: {}".format(pt_path))
            pt = Pretrain(filename=pt_path)
            print(" pretrain shape: {}".format(pt.emb.shape))
            print(" embedding in ner model shape: {}".format(trainer.model.word_emb.weight.shape))
            if pt.emb.shape[1] != trainer.model.word_emb.weight.shape[1]:
                print(" DIMENSION DOES NOT MATCH. SKIPPING")
                continue
            # only the first N rows of the two embeddings can be compared
            N = min(pt.emb.shape[0], trainer.model.word_emb.weight.shape[0])
            if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]:
                # If the vocab was exactly the same, that's a good
                # sign this pretrain was used, just with a different size
                # In such a case, we can reuse the rest of the pretrain
                # Minor issue: some vectors which were trained will be
                # lost in the case of |pt| < |model.word_emb|
                # bugfix: this previously compared word_vocab.id2unit(x)
                # to itself, which was vacuously true for ANY pretrain
                if all(word_vocab.id2unit(x) == pt.vocab.id2unit(x) for x in range(N)):
                    print(" Attempting to use pt vectors to replace ner model's vectors")
                else:
                    print(" NUM VECTORS DO NOT MATCH. WORDS DO NOT MATCH. SKIPPING")
                    continue
            if pt.emb.shape[0] < trainer.model.word_emb.weight.shape[0]:
                print(" WARNING: if any vectors beyond {} were fine tuned, that fine tuning will be lost".format(N))
            # norm of the per-word difference between model and pretrain vectors
            device = next(trainer.model.parameters()).device
            delta = trainer.model.word_emb.weight[:N, :] - pt.emb.to(device)[:N, :]
            delta = delta.detach()
            delta_norms = torch.linalg.norm(delta, dim=1).cpu().numpy()
            if np.sum(delta_norms < 0) > 0:
                raise ValueError("This should not be - a norm was less than 0!")
            num_matching = np.sum(delta_norms < EPS)
            # accept the pretrain if more than half of the vectors are unchanged
            if num_matching > N / 2:
                print(" Accepted! %d of %d vectors match for %s" % (num_matching, N, pt_path))
                if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]:
                    print(" Setting model vocab to match the pretrain")
                    word_vocab = pt.vocab
                    vocab['word'] = word_vocab
                    trainer.args['word_emb_dim'] = pt.emb.shape[1]
                break
            else:
                print(" %d of %d vectors matched for %s - SKIPPING" % (num_matching, N, pt_path))
                vocab_same = sum(x in pt.vocab for x in word_vocab)
                print(" %d words were in both vocabs" % vocab_same)
                # this is expensive, and in practice doesn't happen,
                # but theoretically we might have missed a mostly matching pt
                # if the vocab had been scrambled
                if DEBUG:
                    rearranged_count = 0
                    for x in word_vocab:
                        if x not in pt.vocab:
                            continue
                        x_id = word_vocab.unit2id(x)
                        x_vec = trainer.model.word_emb.weight[x_id, :]
                        pt_id = pt.vocab.unit2id(x)
                        pt_vec = pt.emb[pt_id, :]
                        if (x_vec.detach().cpu() - pt_vec).norm() < EPS:
                            rearranged_count += 1
                    print(" %d vectors were close when ignoring id ordering" % rearranged_count)
        else:
            # for/else: no pretrain produced a break, so none matched
            print("COULD NOT FIND A MATCHING PT: {}".format(ner_processor))
            missing_pretrains.append(ner_model)
            continue

        # build a delta vector & embedding
        # note: delta / delta_norms / pt_model come from the loop iteration which broke
        assert 'delta' not in vocab.keys()
        # the first 4 rows are the VOCAB_PREFIX special tokens; always keep them
        delta_vectors = [delta[i].cpu() for i in range(4)]
        delta_vocab = []
        for i in range(4, len(delta_norms)):
            if delta_norms[i] > 0.0:
                delta_vocab.append(word_vocab.id2unit(i))
                delta_vectors.append(delta[i].cpu())

        # drop the huge base embedding from the saved model
        trainer.model.unsaved_modules.append("word_emb")
        if len(delta_vocab) == 0:
            print("No vectors were changed! Perhaps this model was trained without finetune.")
            no_finetune.append(ner_model)
        else:
            print("%d delta vocab" % len(delta_vocab))
            print("%d vectors in the delta set" % len(delta_vectors))
            delta_vectors = np.stack(delta_vectors)
            delta_vectors = torch.from_numpy(delta_vectors)
            assert delta_vectors.shape[0] == len(delta_vocab) + len(VOCAB_PREFIX)
            print(delta_vectors.shape)

            delta_vocab = PretrainedWordVocab(delta_vocab, lang=word_vocab.lang, lower=word_vocab.lower)
            vocab['delta'] = delta_vocab
            trainer.model.delta_emb = nn.Embedding(delta_vectors.shape[0], delta_vectors.shape[1], PAD_ID)
            trainer.model.delta_emb.weight.data.copy_(delta_vectors)

        new_path = os.path.join(new_dir, ner_model)
        trainer.save(new_path)

        final_pretrains.append((ner_model, pt_model))

    # final report of what happened to each model
    print()
    if len(final_pretrains) > 0:
        print("Final pretrain mappings:")
        for i in final_pretrains:
            print(i)
    if len(missing_pretrains) > 0:
        print("MISSING EMBEDDINGS:")
        for i in missing_pretrains:
            print(i)
    if len(no_finetune) > 0:
        print("NOT FINE TUNED:")
        for i in no_finetune:
            print(i)
|
| 213 |
+
|
| 214 |
+
# script entry point: shrink NER models per the command line arguments
if __name__ == '__main__':
    main()
|
stanza/stanza/utils/visualization/__init__.py
ADDED
|
File without changes
|
stanza/stanza/utils/visualization/conll_deprel_visualization.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stanza.models.common.constant import is_right_to_left
|
| 2 |
+
import spacy
|
| 3 |
+
import argparse
|
| 4 |
+
from spacy import displacy
|
| 5 |
+
from spacy.tokens import Doc
|
| 6 |
+
from stanza.utils import conll
|
| 7 |
+
from stanza.utils.visualization import dependency_visualization as viz
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def conll_to_visual(conll_file, pipeline, sent_count=10, display_all=False):
    """
    Takes in a conll file and visualizes it by converting the conll file to a Stanza Document object
    and visualizing it with the visualize_doc method.

    Input should be a proper conll file.

    The pipeline for the conll file to be processed in must be provided as well
    (a language code such as "en"; used to decide right-to-left rendering).

    Optionally, the sent_count argument can be tweaked to display a different amount of sentences.

    To display all of the sentences in a conll file, the display_all argument can optionally be set to True.
    BEWARE: setting this argument for a large conll file may result in too many renderings, resulting in a crash.
    """
    # convert conll file to doc
    doc = conll.CoNLL.conll2doc(conll_file)

    if display_all:
        # bugfix: previously re-parsed the file with a second conll2doc
        # call even though the document was already parsed above
        viz.visualize_doc(doc, pipeline)
    else:  # visualize a given number of sentences
        visualization_options = {"compact": True, "bg": "#09a3d5", "color": "white", "distance": 100,
                                 "font": "Source Sans Pro", "offset_x": 30,
                                 "arrow_spacing": 20}  # see spaCy visualization settings doc for more options
        # a blank spaCy pipeline just to supply a Vocab for Doc construction
        nlp = spacy.blank("en")
        sentences_to_visualize, rtl, num_sentences = [], is_right_to_left(pipeline), len(doc.sentences)

        for i in range(sent_count):
            if i >= num_sentences:  # case where there are less sentences than amount requested
                break
            sentence = doc.sentences[i]
            words, lemmas, heads, deps, tags = [], [], [], [], []
            sentence_words = sentence.words
            if rtl:  # rtl languages will be visually rendered from right to left as well
                sentence_words = reversed(sentence.words)
            sent_len = len(sentence.words)
            for word in sentence_words:
                words.append(word.text)
                lemmas.append(word.lemma)
                deps.append(word.deprel)
                tags.append(word.upos)
                # Stanza heads are 1-based with 0 for root; spaCy heads are
                # 0-based positions in the sentence, with root pointing at itself
                if rtl and word.head == 0:  # word heads are off-by-1 in spaCy doc inits compared to Stanza
                    heads.append(sent_len - word.id)
                elif rtl and word.head != 0:
                    heads.append(sent_len - word.head)
                elif not rtl and word.head == 0:
                    heads.append(word.id - 1)
                elif not rtl and word.head != 0:
                    heads.append(word.head - 1)

            document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags)
            sentences_to_visualize.append(document_result)

        print(sentences_to_visualize)
        for line in sentences_to_visualize:  # render all sentences through displaCy
            displacy.render(line, style="dep", options=visualization_options)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main():
    """Command line entry point: visualize dependencies from a CoNLL file."""

    def _str_to_bool(value):
        # argparse's type=bool is a trap: bool("False") is True since the
        # string is non-empty, so "--display_all False" would enable the
        # option.  Parse the text explicitly instead.
        return value.strip().lower() in ("true", "1", "yes")

    parser = argparse.ArgumentParser()
    # NOTE(review): default is a developer-specific Windows path - callers
    # will generally want to pass --conll_file explicitly
    parser.add_argument('--conll_file', type=str,
                        default="C:\\Users\\Alex\\stanza\\demo\\en_test.conllu.txt",
                        help="File path of the CoNLL file to visualize dependencies of")
    parser.add_argument('--pipeline', type=str, default="en",
                        help="Language code of the language pipeline to use (ex: 'en' for English)")
    parser.add_argument('--sent_count', type=int, default=10, help="Number of sentences to visualize from CoNLL file")
    parser.add_argument('--display_all', type=_str_to_bool, default=False,
                        help="Whether or not to visualize all of the sentences from the file. Overrides sent_count if set to True")
    args = parser.parse_args()
    conll_to_visual(args.conll_file, args.pipeline, args.sent_count, args.display_all)
    return
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# script entry point: parse arguments and render the requested sentences
if __name__ == "__main__":
    main()
|