bowphs commited on
Commit
495f002
·
verified ·
1 Parent(s): bf62052

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50) hide show
  1. stanza/stanza/tests/classifiers/test_data.py +130 -0
  2. stanza/stanza/tests/constituency/test_tree_stack.py +50 -0
  3. stanza/stanza/tests/data/external_server.properties +1 -0
  4. stanza/stanza/tests/lemma/test_lowercase.py +57 -0
  5. stanza/stanza/tests/ner/test_bsf_2_beios.py +349 -0
  6. stanza/stanza/tests/ner/test_ner_training.py +261 -0
  7. stanza/stanza/tests/pipeline/pipeline_device_tests.py +45 -0
  8. stanza/stanza/tests/pipeline/test_pipeline_sentiment_processor.py +50 -0
  9. stanza/stanza/tests/pipeline/test_requirements.py +72 -0
  10. stanza/stanza/tests/tokenization/__init__.py +0 -0
  11. stanza/stanza/tests/tokenization/test_tokenize_utils.py +220 -0
  12. stanza/stanza/utils/charlm/__init__.py +0 -0
  13. stanza/stanza/utils/charlm/conll17_to_text.py +93 -0
  14. stanza/stanza/utils/charlm/dump_oscar.py +120 -0
  15. stanza/stanza/utils/charlm/make_lm_data.py +162 -0
  16. stanza/stanza/utils/constituency/check_transitions.py +27 -0
  17. stanza/stanza/utils/constituency/list_tensors.py +16 -0
  18. stanza/stanza/utils/datasets/__init__.py +0 -0
  19. stanza/stanza/utils/datasets/contract_mwt.py +46 -0
  20. stanza/stanza/utils/datasets/coref/__init__.py +0 -0
  21. stanza/stanza/utils/datasets/coref/convert_ontonotes.py +80 -0
  22. stanza/stanza/utils/datasets/coref/convert_udcoref.py +276 -0
  23. stanza/stanza/utils/datasets/coref/utils.py +148 -0
  24. stanza/stanza/utils/datasets/corenlp_segmenter_dataset.py +78 -0
  25. stanza/stanza/utils/datasets/ner/convert_bsnlp.py +333 -0
  26. stanza/stanza/utils/datasets/ner/convert_fire_2013.py +118 -0
  27. stanza/stanza/utils/datasets/ner/convert_hy_armtdp.py +145 -0
  28. stanza/stanza/utils/datasets/ner/convert_kk_kazNERD.py +35 -0
  29. stanza/stanza/utils/datasets/ner/convert_my_ucsy.py +102 -0
  30. stanza/stanza/utils/datasets/ner/convert_sindhi_siner.py +69 -0
  31. stanza/stanza/utils/datasets/ner/convert_starlang_ner.py +55 -0
  32. stanza/stanza/utils/datasets/ner/ontonotes_multitag.py +97 -0
  33. stanza/stanza/utils/datasets/ner/prepare_ner_file.py +78 -0
  34. stanza/stanza/utils/datasets/ner/utils.py +417 -0
  35. stanza/stanza/utils/datasets/vietnamese/__init__.py +0 -0
  36. stanza/stanza/utils/pretrain/compare_pretrains.py +54 -0
  37. stanza/stanza/utils/training/common.py +397 -0
  38. stanza/stanza/utils/training/compose_ete_results.py +100 -0
  39. stanza/stanza/utils/training/run_charlm.py +86 -0
  40. stanza/stanza/utils/training/run_constituency.py +130 -0
  41. stanza/stanza/utils/training/run_depparse.py +133 -0
  42. stanza/stanza/utils/training/run_lemma.py +179 -0
  43. stanza/stanza/utils/training/run_lemma_classifier.py +87 -0
  44. stanza/stanza/utils/training/run_mwt.py +122 -0
  45. stanza/stanza/utils/training/run_ner.py +159 -0
  46. stanza/stanza/utils/training/run_sentiment.py +118 -0
  47. stanza/stanza/utils/training/run_tokenizer.py +124 -0
  48. stanza/stanza/utils/training/separate_ner_pretrain.py +215 -0
  49. stanza/stanza/utils/visualization/__init__.py +0 -0
  50. stanza/stanza/utils/visualization/conll_deprel_visualization.py +83 -0
stanza/stanza/tests/classifiers/test_data.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pytest
3
+
4
+ import stanza.models.classifiers.data as data
5
+ from stanza.models.classifiers.utils import WVType
6
+ from stanza.models.common.vocab import PAD, UNK
7
+ from stanza.models.constituency.parse_tree import Tree
8
+
9
+ SENTENCES = [
10
+ ["I", "hate", "the", "Opal", "banning"],
11
+ ["Tell", "my", "wife", "hello"], # obviously this is the neutral result
12
+ ["I", "like", "Sh'reyan", "'s", "antennae"],
13
+ ]
14
+
15
+ DATASET = [
16
+ {"sentiment": "0", "text": SENTENCES[0]},
17
+ {"sentiment": "1", "text": SENTENCES[1]},
18
+ {"sentiment": "2", "text": SENTENCES[2]},
19
+ ]
20
+
21
+ TREES = [
22
+ "(ROOT (S (NP (PRP I)) (VP (VBP hate) (NP (DT the) (NN Opal) (NN banning)))))",
23
+ "(ROOT (S (VP (VB Tell) (NP (PRP$ my) (NN wife)) (NP (UH hello)))))",
24
+ "(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))",
25
+ ]
26
+
27
+ DATASET_WITH_TREES = [
28
+ {"sentiment": "0", "text": SENTENCES[0], "constituency": TREES[0]},
29
+ {"sentiment": "1", "text": SENTENCES[1], "constituency": TREES[1]},
30
+ {"sentiment": "2", "text": SENTENCES[2], "constituency": TREES[2]},
31
+ ]
32
+
33
+ @pytest.fixture(scope="module")
34
+ def train_file(tmp_path_factory):
35
+ train_set = DATASET * 20
36
+ train_filename = tmp_path_factory.mktemp("data") / "train.json"
37
+ with open(train_filename, "w", encoding="utf-8") as fout:
38
+ json.dump(train_set, fout, ensure_ascii=False)
39
+ return train_filename
40
+
41
+ @pytest.fixture(scope="module")
42
+ def dev_file(tmp_path_factory):
43
+ dev_set = DATASET * 2
44
+ dev_filename = tmp_path_factory.mktemp("data") / "dev.json"
45
+ with open(dev_filename, "w", encoding="utf-8") as fout:
46
+ json.dump(dev_set, fout, ensure_ascii=False)
47
+ return dev_filename
48
+
49
+ @pytest.fixture(scope="module")
50
+ def test_file(tmp_path_factory):
51
+ test_set = DATASET
52
+ test_filename = tmp_path_factory.mktemp("data") / "test.json"
53
+ with open(test_filename, "w", encoding="utf-8") as fout:
54
+ json.dump(test_set, fout, ensure_ascii=False)
55
+ return test_filename
56
+
57
+ @pytest.fixture(scope="module")
58
+ def train_file_with_trees(tmp_path_factory):
59
+ train_set = DATASET_WITH_TREES * 20
60
+ train_filename = tmp_path_factory.mktemp("data") / "train_trees.json"
61
+ with open(train_filename, "w", encoding="utf-8") as fout:
62
+ json.dump(train_set, fout, ensure_ascii=False)
63
+ return train_filename
64
+
65
+ @pytest.fixture(scope="module")
66
+ def dev_file_with_trees(tmp_path_factory):
67
+ dev_set = DATASET_WITH_TREES * 2
68
+ dev_filename = tmp_path_factory.mktemp("data") / "dev_trees.json"
69
+ with open(dev_filename, "w", encoding="utf-8") as fout:
70
+ json.dump(dev_set, fout, ensure_ascii=False)
71
+ return dev_filename
72
+
73
+ class TestClassifierData:
74
+ def test_read_data(self, train_file):
75
+ """
76
+ Test reading of the json format
77
+ """
78
+ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
79
+ assert len(train_set) == 60
80
+
81
+ def test_read_data_with_trees(self, train_file, train_file_with_trees):
82
+ """
83
+ Test reading of the json format
84
+ """
85
+ train_trees_set = data.read_dataset(str(train_file_with_trees), WVType.OTHER, 1)
86
+ assert len(train_trees_set) == 60
87
+ for idx, x in enumerate(train_trees_set):
88
+ assert isinstance(x.constituency, Tree)
89
+ assert str(x.constituency) == TREES[idx % len(TREES)]
90
+
91
+ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
92
+
93
+ def test_dataset_vocab(self, train_file):
94
+ """
95
+ Converting a dataset to vocab should have a specific set of words along with PAD and UNK
96
+ """
97
+ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
98
+ vocab = data.dataset_vocab(train_set)
99
+ expected = set([PAD, UNK] + [x.lower() for y in SENTENCES for x in y])
100
+ assert set(vocab) == expected
101
+
102
+ def test_dataset_labels(self, train_file):
103
+ """
104
+ Test the extraction of labels from a dataset
105
+ """
106
+ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
107
+ labels = data.dataset_labels(train_set)
108
+ assert labels == ["0", "1", "2"]
109
+
110
+ def test_sort_by_length(self, train_file):
111
+ """
112
+ There are two unique lengths in the toy dataset
113
+ """
114
+ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
115
+ sorted_dataset = data.sort_dataset_by_len(train_set)
116
+ assert list(sorted_dataset.keys()) == [4, 5]
117
+ assert len(sorted_dataset[4]) == len(train_set) // 3
118
+ assert len(sorted_dataset[5]) == 2 * len(train_set) // 3
119
+
120
+ def test_check_labels(self, train_file):
121
+ """
122
+ Check that an exception is thrown for an unknown label
123
+ """
124
+ train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
125
+ labels = sorted(set([x["sentiment"] for x in DATASET]))
126
+ assert len(labels) > 1
127
+ data.check_labels(labels, train_set)
128
+ with pytest.raises(RuntimeError):
129
+ data.check_labels(labels[:1], train_set)
130
+
stanza/stanza/tests/constituency/test_tree_stack.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from stanza.models.constituency.tree_stack import TreeStack
4
+
5
+ from stanza.tests import *
6
+
7
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
8
+
9
+ def test_simple():
10
+ stack = TreeStack(value=5, parent=None, length=1)
11
+ stack = stack.push(3)
12
+ stack = stack.push(1)
13
+
14
+ expected_values = [1, 3, 5]
15
+ for value in expected_values:
16
+ assert stack.value == value
17
+ stack = stack.pop()
18
+ assert stack is None
19
+
20
+ def test_iter():
21
+ stack = TreeStack(value=5, parent=None, length=1)
22
+ stack = stack.push(3)
23
+ stack = stack.push(1)
24
+
25
+ stack_list = list(stack)
26
+ assert list(stack) == [1, 3, 5]
27
+
28
+ def test_str():
29
+ stack = TreeStack(value=5, parent=None, length=1)
30
+ stack = stack.push(3)
31
+ stack = stack.push(1)
32
+
33
+ assert str(stack) == "TreeStack(1, 3, 5)"
34
+
35
+ def test_len():
36
+ stack = TreeStack(value=5, parent=None, length=1)
37
+ assert len(stack) == 1
38
+
39
+ stack = stack.push(3)
40
+ stack = stack.push(1)
41
+ assert len(stack) == 3
42
+
43
+ def test_long_len():
44
+ """
45
+ Original stack had a bug where this took exponential time...
46
+ """
47
+ stack = TreeStack(value=0, parent=None, length=1)
48
+ for i in range(1, 40):
49
+ stack = stack.push(i)
50
+ assert len(stack) == 40
stanza/stanza/tests/data/external_server.properties ADDED
@@ -0,0 +1 @@
 
 
1
+ annotators = tokenize,ssplit,pos
stanza/stanza/tests/lemma/test_lowercase.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from stanza.models.lemmatizer import all_lowercase
4
+ from stanza.utils.conll import CoNLL
5
+
6
+ LATIN_CONLLU = """
7
+ # sent_id = train-s1
8
+ # text = unde et philosophus dicit felicitatem esse operationem perfectam.
9
+ # reference = ittb-scg-s4203
10
+ 1 unde unde ADV O4 AdvType=Loc|PronType=Rel 4 advmod:lmod _ _
11
+ 2 et et CCONJ O4 _ 3 advmod:emph _ _
12
+ 3 philosophus philosophus NOUN B1|grn1|casA|gen1 Case=Nom|Gender=Masc|InflClass=IndEurO|Number=Sing 4 nsubj _ _
13
+ 4 dicit dico VERB N3|modA|tem1|gen6 Aspect=Imp|InflClass=LatX|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens
14
+ 5 felicitatem felicitas NOUN C1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 7 nsubj _ _
15
+ 6 esse sum AUX N3|modH|tem1 Aspect=Imp|Tense=Pres|VerbForm=Inf 7 cop _ _
16
+ 7 operationem operatio NOUN C1|grn1|casD|gen2|vgr1 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 4 ccomp _ _
17
+ 8 perfectam perfectus ADJ A1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurA|Number=Sing 7 amod _ SpaceAfter=No
18
+ 9 . . PUNCT Punc _ 4 punct _ _
19
+
20
+ # sent_id = train-s2
21
+ # text = perfectio autem operationis dependet ex quatuor.
22
+ # reference = ittb-scg-s4204
23
+ 1 perfectio perfectio NOUN C1|grn1|casA|gen2 Case=Nom|Gender=Fem|InflClass=IndEurX|Number=Sing 4 nsubj _ _
24
+ 2 autem autem PART O4 _ 4 discourse _ _
25
+ 3 operationis operatio NOUN C1|grn1|casB|gen2|vgr1 Case=Gen|Gender=Fem|InflClass=IndEurX|Number=Sing 1 nmod _ _
26
+ 4 dependet dependeo VERB K3|modA|tem1|gen6 Aspect=Imp|InflClass=LatE|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens
27
+ 5 ex ex ADP S4|vgr2 _ 6 case _ _
28
+ 6 quatuor quattuor NUM G1|gen3|vgr1 NumForm=Word|NumType=Card 4 obl:arg _ SpaceAfter=No
29
+ 7 . . PUNCT Punc _ 4 punct _ _
30
+ """.lstrip()
31
+
32
+ ENG_CONLLU = """
33
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0007
34
+ # text = You wonder if he was manipulating the market with his bombing targets.
35
+ 1 You you PRON PRP Case=Nom|Person=2|PronType=Prs 2 nsubj 2:nsubj _
36
+ 2 wonder wonder VERB VBP Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin 0 root 0:root _
37
+ 3 if if SCONJ IN _ 6 mark 6:mark _
38
+ 4 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 6 nsubj 6:nsubj _
39
+ 5 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
40
+ 6 manipulating manipulate VERB VBG Tense=Pres|VerbForm=Part 2 ccomp 2:ccomp _
41
+ 7 the the DET DT Definite=Def|PronType=Art 8 det 8:det _
42
+ 8 market market NOUN NN Number=Sing 6 obj 6:obj _
43
+ 9 with with ADP IN _ 12 case 12:case _
44
+ 10 his his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 12 nmod:poss 12:nmod:poss _
45
+ 11 bombing bombing NOUN NN Number=Sing 12 compound 12:compound _
46
+ 12 targets target NOUN NNS Number=Plur 6 obl 6:obl:with SpaceAfter=No
47
+ 13 . . PUNCT . _ 2 punct 2:punct _
48
+ """.lstrip()
49
+
50
+
51
+ def test_all_lowercase():
52
+ doc = CoNLL.conll2doc(input_str=LATIN_CONLLU)
53
+ assert all_lowercase(doc)
54
+
55
+ def test_not_all_lowercase():
56
+ doc = CoNLL.conll2doc(input_str=ENG_CONLLU)
57
+ assert not all_lowercase(doc)
stanza/stanza/tests/ner/test_bsf_2_beios.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests the conversion code for the lang_uk NER dataset
3
+ """
4
+
5
+ import unittest
6
+ from stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo
7
+
8
+ import pytest
9
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
10
+
11
+ class TestBsf2Beios(unittest.TestCase):
12
+
13
+ def test_empty_markup(self):
14
+ res = convert_bsf('', '')
15
+ self.assertEqual('', res)
16
+
17
+ def test_1line_markup(self):
18
+ data = 'тележурналіст Василь'
19
+ bsf_markup = 'T1 PERS 14 20 Василь'
20
+ expected = '''тележурналіст O
21
+ Василь S-PERS'''
22
+ self.assertEqual(expected, convert_bsf(data, bsf_markup))
23
+
24
+ def test_1line_follow_markup(self):
25
+ data = 'тележурналіст Василь .'
26
+ bsf_markup = 'T1 PERS 14 20 Василь'
27
+ expected = '''тележурналіст O
28
+ Василь S-PERS
29
+ . O'''
30
+ self.assertEqual(expected, convert_bsf(data, bsf_markup))
31
+
32
+ def test_1line_2tok_markup(self):
33
+ data = 'тележурналіст Василь Нагірний .'
34
+ bsf_markup = 'T1 PERS 14 29 Василь Нагірний'
35
+ expected = '''тележурналіст O
36
+ Василь B-PERS
37
+ Нагірний E-PERS
38
+ . O'''
39
+ self.assertEqual(expected, convert_bsf(data, bsf_markup))
40
+
41
+ def test_1line_Long_tok_markup(self):
42
+ data = 'А в музеї Гуцульщини і Покуття можна '
43
+ bsf_markup = 'T12 ORG 4 30 музеї Гуцульщини і Покуття'
44
+ expected = '''А O
45
+ в O
46
+ музеї B-ORG
47
+ Гуцульщини I-ORG
48
+ і I-ORG
49
+ Покуття E-ORG
50
+ можна O'''
51
+ self.assertEqual(expected, convert_bsf(data, bsf_markup))
52
+
53
+ def test_2line_2tok_markup(self):
54
+ data = '''тележурналіст Василь Нагірний .
55
+ В івано-франківському видавництві «Лілея НВ» вийшла друком'''
56
+ bsf_markup = '''T1 PERS 14 29 Василь Нагірний
57
+ T2 ORG 67 75 Лілея НВ'''
58
+ expected = '''тележурналіст O
59
+ Василь B-PERS
60
+ Нагірний E-PERS
61
+ . O
62
+
63
+
64
+ В O
65
+ івано-франківському O
66
+ видавництві O
67
+ « O
68
+ Лілея B-ORG
69
+ НВ E-ORG
70
+ » O
71
+ вийшла O
72
+ друком O'''
73
+ self.assertEqual(expected, convert_bsf(data, bsf_markup))
74
+
75
+ def test_real_markup(self):
76
+ data = '''Через напіввоєнний стан в Україні та збільшення телефонних терористичних погроз українці купуватимуть sim-карти тільки за паспортами .
77
+ Про це повідомив начальник управління зв'язків зі ЗМІ адміністрації Держспецзв'язку Віталій Кукса .
78
+ Він зауважив , що днями відомство опублікує проект змін до правил надання телекомунікаційних послуг , де будуть прописані норми ідентифікації громадян .
79
+ Абонентів , які на сьогодні вже мають sim-карту , за словами Віталія Кукси , реєструватимуть , коли ті звертатимуться в службу підтримки свого оператора мобільного зв'язку .
80
+ Однак мобільні оператори побоюються , що таке нововведення помітно зменшить продаж стартових пакетів , адже спеціалізовані магазини є лише у містах .
81
+ Відтак купити сімку в невеликих населених пунктах буде неможливо .
82
+ Крім того , нова процедура ідентифікації абонентів вимагатиме від операторів мобільного зв'язку додаткових витрат .
83
+ - Близько 90 % українських абонентів - це абоненти передоплати .
84
+ Якщо мова буде йти навіть про поетапну їх ідентифікацію , зробити це буде складно , довго і дорого .
85
+ Мобільним операторам доведеться йти на чималі витрати , пов'язані з укладанням і зберіганням договорів , веденням баз даних , - розповіла « Економічній правді » начальник відділу зв'язків з громадськістю « МТС-Україна » Вікторія Рубан .
86
+ '''
87
+ bsf_markup = '''T1 LOC 26 33 Україні
88
+ T2 ORG 203 218 Держспецзв'язку
89
+ T3 PERS 219 232 Віталій Кукса
90
+ T4 PERS 449 462 Віталія Кукси
91
+ T5 ORG 1201 1219 Економічній правді
92
+ T6 ORG 1267 1278 МТС-Україна
93
+ T7 PERS 1281 1295 Вікторія Рубан
94
+ '''
95
+ expected = '''Через O
96
+ напіввоєнний O
97
+ стан O
98
+ в O
99
+ Україні S-LOC
100
+ та O
101
+ збільшення O
102
+ телефонних O
103
+ терористичних O
104
+ погроз O
105
+ українці O
106
+ купуватимуть O
107
+ sim-карти O
108
+ тільки O
109
+ за O
110
+ паспортами O
111
+ . O
112
+
113
+
114
+ Про O
115
+ це O
116
+ повідомив O
117
+ начальник O
118
+ управління O
119
+ зв'язків O
120
+ зі O
121
+ ЗМІ O
122
+ адміністрації O
123
+ Держспецзв'язку S-ORG
124
+ Віталій B-PERS
125
+ Кукса E-PERS
126
+ . O
127
+
128
+
129
+ Він O
130
+ зауважив O
131
+ , O
132
+ що O
133
+ днями O
134
+ відомство O
135
+ опублікує O
136
+ проект O
137
+ змін O
138
+ до O
139
+ правил O
140
+ надання O
141
+ телекомунікаційних O
142
+ послуг O
143
+ , O
144
+ де O
145
+ будуть O
146
+ прописані O
147
+ норми O
148
+ ідентифікації O
149
+ громадян O
150
+ . O
151
+
152
+
153
+ Абонентів O
154
+ , O
155
+ які O
156
+ на O
157
+ сьогодні O
158
+ вже O
159
+ мають O
160
+ sim-карту O
161
+ , O
162
+ за O
163
+ словами O
164
+ Віталія B-PERS
165
+ Кукси E-PERS
166
+ , O
167
+ реєструватимуть O
168
+ , O
169
+ коли O
170
+ ті O
171
+ звертатимуться O
172
+ в O
173
+ службу O
174
+ підтримки O
175
+ свого O
176
+ оператора O
177
+ мобільного O
178
+ зв'язку O
179
+ . O
180
+
181
+
182
+ Однак O
183
+ мобільні O
184
+ оператори O
185
+ побоюються O
186
+ , O
187
+ що O
188
+ таке O
189
+ нововведення O
190
+ помітно O
191
+ зменшить O
192
+ продаж O
193
+ стартових O
194
+ пакетів O
195
+ , O
196
+ адже O
197
+ спеціалізовані O
198
+ магазини O
199
+ є O
200
+ лише O
201
+ у O
202
+ містах O
203
+ . O
204
+
205
+
206
+ Відтак O
207
+ купити O
208
+ сімку O
209
+ в O
210
+ невеликих O
211
+ населених O
212
+ пунктах O
213
+ буде O
214
+ неможливо O
215
+ . O
216
+
217
+
218
+ Крім O
219
+ того O
220
+ , O
221
+ нова O
222
+ процедура O
223
+ ідентифікації O
224
+ абонентів O
225
+ вимагатиме O
226
+ від O
227
+ операторів O
228
+ мобільного O
229
+ зв'язку O
230
+ додаткових O
231
+ витрат O
232
+ . O
233
+
234
+
235
+ - O
236
+ Близько O
237
+ 90 O
238
+ % O
239
+ українських O
240
+ абонентів O
241
+ - O
242
+ це O
243
+ абоненти O
244
+ передоплати O
245
+ . O
246
+
247
+
248
+ Якщо O
249
+ мова O
250
+ буде O
251
+ йти O
252
+ навіть O
253
+ про O
254
+ поетапну O
255
+ їх O
256
+ ідентифікацію O
257
+ , O
258
+ зробити O
259
+ це O
260
+ буде O
261
+ складно O
262
+ , O
263
+ довго O
264
+ і O
265
+ дорого O
266
+ . O
267
+
268
+
269
+ Мобільним O
270
+ операторам O
271
+ доведеться O
272
+ йти O
273
+ на O
274
+ чималі O
275
+ витрати O
276
+ , O
277
+ пов'язані O
278
+ з O
279
+ укладанням O
280
+ і O
281
+ зберіганням O
282
+ договорів O
283
+ , O
284
+ веденням O
285
+ баз O
286
+ даних O
287
+ , O
288
+ - O
289
+ розповіла O
290
+ « O
291
+ Економічній B-ORG
292
+ правді E-ORG
293
+ » O
294
+ начальник O
295
+ відділу O
296
+ зв'язків O
297
+ з O
298
+ громадськістю O
299
+ « O
300
+ МТС-Україна S-ORG
301
+ » O
302
+ Вікторія B-PERS
303
+ Рубан E-PERS
304
+ . O'''
305
+ self.assertEqual(expected, convert_bsf(data, bsf_markup))
306
+
307
+
308
+ class TestBsf(unittest.TestCase):
309
+
310
+ def test_empty_bsf(self):
311
+ self.assertEqual(parse_bsf(''), [])
312
+
313
+ def test_empty2_bsf(self):
314
+ self.assertEqual(parse_bsf(' \n \n'), [])
315
+
316
+ def test_1line_bsf(self):
317
+ bsf = 'T1 PERS 103 118 Василь Нагірний'
318
+ res = parse_bsf(bsf)
319
+ expected = BsfInfo('T1', 'PERS', 103, 118, 'Василь Нагірний')
320
+ self.assertEqual(len(res), 1)
321
+ self.assertEqual(res, [expected])
322
+
323
+ def test_2line_bsf(self):
324
+ bsf = '''T9 PERS 778 783 Карла
325
+ T10 MISC 814 819 міста'''
326
+ res = parse_bsf(bsf)
327
+ expected = [BsfInfo('T9', 'PERS', 778, 783, 'Карла'),
328
+ BsfInfo('T10', 'MISC', 814, 819, 'міста')]
329
+ self.assertEqual(len(res), 2)
330
+ self.assertEqual(res, expected)
331
+
332
+ def test_multiline_bsf(self):
333
+ bsf = '''T3 PERS 220 235 Андрієм Кіщуком
334
+ T4 MISC 251 285 А .
335
+ Kubler .
336
+ Світло і тіні маестро
337
+ T5 PERS 363 369 Кіблер'''
338
+ res = parse_bsf(bsf)
339
+ expected = [BsfInfo('T3', 'PERS', 220, 235, 'Андрієм Кіщуком'),
340
+ BsfInfo('T4', 'MISC', 251, 285, '''А .
341
+ Kubler .
342
+ Світло і тіні маестро'''),
343
+ BsfInfo('T5', 'PERS', 363, 369, 'Кіблер')]
344
+ self.assertEqual(len(res), len(expected))
345
+ self.assertEqual(res, expected)
346
+
347
+
348
+ if __name__ == '__main__':
349
+ unittest.main()
stanza/stanza/tests/ner/test_ner_training.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import warnings
5
+
6
+ import pytest
7
+ import torch
8
+
9
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
10
+
11
+ from stanza.models import ner_tagger
12
+ from stanza.models.ner.trainer import Trainer
13
+ from stanza.tests import TEST_WORKING_DIR
14
+ from stanza.utils.datasets.ner.prepare_ner_file import process_dataset
15
+
16
+ logger = logging.getLogger('stanza')
17
+
18
+ EN_TRAIN_BIO = """
19
+ Chris B-PERSON
20
+ Manning E-PERSON
21
+ is O
22
+ a O
23
+ good O
24
+ man O
25
+ . O
26
+
27
+ He O
28
+ works O
29
+ in O
30
+ Stanford B-ORG
31
+ University E-ORG
32
+ . O
33
+ """.lstrip().replace(" ", "\t")
34
+
35
+ EN_DEV_BIO = """
36
+ Chris B-PERSON
37
+ Manning E-PERSON
38
+ is O
39
+ part O
40
+ of O
41
+ Computer B-ORG
42
+ Science E-ORG
43
+ """.lstrip().replace(" ", "\t")
44
+
45
+ EN_TRAIN_2TAG = """
46
+ Chris B-PERSON B-PER
47
+ Manning E-PERSON E-PER
48
+ is O O
49
+ a O O
50
+ good O O
51
+ man O O
52
+ . O O
53
+
54
+ He O O
55
+ works O O
56
+ in O O
57
+ Stanford B-ORG B-ORG
58
+ University E-ORG B-ORG
59
+ . O O
60
+ """.strip().replace(" ", "\t")
61
+
62
+ EN_TRAIN_2TAG_EMPTY2 = """
63
+ Chris B-PERSON -
64
+ Manning E-PERSON -
65
+ is O -
66
+ a O -
67
+ good O -
68
+ man O -
69
+ . O -
70
+
71
+ He O -
72
+ works O -
73
+ in O -
74
+ Stanford B-ORG -
75
+ University E-ORG -
76
+ . O -
77
+ """.strip().replace(" ", "\t")
78
+
79
+ EN_DEV_2TAG = """
80
+ Chris B-PERSON B-PER
81
+ Manning E-PERSON E-PER
82
+ is O O
83
+ part O O
84
+ of O O
85
+ Computer B-ORG B-ORG
86
+ Science E-ORG E-ORG
87
+ """.strip().replace(" ", "\t")
88
+
89
+ @pytest.fixture(scope="module")
90
+ def pretrain_file():
91
+ return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
92
+
93
+ def write_temp_file(filename, bio_data):
94
+ bio_filename = os.path.splitext(filename)[0] + ".bio"
95
+ with open(bio_filename, "w", encoding="utf-8") as fout:
96
+ fout.write(bio_data)
97
+ process_dataset(bio_filename, filename)
98
+
99
+ def write_temp_2tag(filename, bio_data):
100
+ doc = []
101
+ sentences = bio_data.split("\n\n")
102
+ for sentence in sentences:
103
+ doc.append([])
104
+ for word in sentence.split("\n"):
105
+ text, tags = word.split("\t", maxsplit=1)
106
+ doc[-1].append({
107
+ "text": text,
108
+ "multi_ner": tags.split()
109
+ })
110
+
111
+ with open(filename, "w", encoding="utf-8") as fout:
112
+ json.dump(doc, fout)
113
+
114
+ def get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args):
115
+ save_dir = tmp_path / "models"
116
+ args = ["--data_dir", str(tmp_path),
117
+ "--wordvec_pretrain_file", pretrain_file,
118
+ "--train_file", str(train_json),
119
+ "--eval_file", str(dev_json),
120
+ "--shorthand", "en_test",
121
+ "--max_steps", "100",
122
+ "--eval_interval", "40",
123
+ "--save_dir", str(save_dir)]
124
+
125
+ args = args + list(extra_args)
126
+ return args
127
+
128
+ def run_two_tag_training(pretrain_file, tmp_path, *extra_args, train_data=EN_TRAIN_2TAG):
129
+ train_json = tmp_path / "en_test.train.json"
130
+ write_temp_2tag(train_json, train_data)
131
+
132
+ dev_json = tmp_path / "en_test.dev.json"
133
+ write_temp_2tag(dev_json, EN_DEV_2TAG)
134
+
135
+ args = get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args)
136
+ return ner_tagger.main(args)
137
+
138
+ def test_basic_two_tag_training(pretrain_file, tmp_path):
139
+ trainer = run_two_tag_training(pretrain_file, tmp_path)
140
+ assert len(trainer.model.tag_clfs) == 2
141
+ assert len(trainer.model.crits) == 2
142
+ assert len(trainer.vocab['tag'].lens()) == 2
143
+
144
+ def test_two_tag_training_backprop(pretrain_file, tmp_path):
145
+ """
146
+ Test that the training is backproping both tags
147
+
148
+ We can do this by using the "finetune" mechanism and verifying
149
+ that the output tensors are different
150
+ """
151
+ trainer = run_two_tag_training(pretrain_file, tmp_path)
152
+
153
+ # first, need to save the final model before restarting
154
+ # (alternatively, could reload the final checkpoint)
155
+ trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name']))
156
+ new_trainer = run_two_tag_training(pretrain_file, tmp_path, "--finetune")
157
+
158
+ assert len(trainer.model.tag_clfs) == 2
159
+ assert len(new_trainer.model.tag_clfs) == 2
160
+ for old_clf, new_clf in zip(trainer.model.tag_clfs, new_trainer.model.tag_clfs):
161
+ assert not torch.allclose(old_clf.weight, new_clf.weight)
162
+
163
+ def test_two_tag_training_c2_backprop(pretrain_file, tmp_path):
164
+ """
165
+ Test that the training is backproping only one tag if one column is blank
166
+
167
+ We can do this by using the "finetune" mechanism and verifying
168
+ that the output tensors are different in just the first column
169
+ """
170
+ trainer = run_two_tag_training(pretrain_file, tmp_path)
171
+
172
+ # first, need to save the final model before restarting
173
+ # (alternatively, could reload the final checkpoint)
174
+ trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name']))
175
+ new_trainer = run_two_tag_training(pretrain_file, tmp_path, "--finetune", train_data=EN_TRAIN_2TAG_EMPTY2)
176
+
177
+ assert len(trainer.model.tag_clfs) == 2
178
+ assert len(new_trainer.model.tag_clfs) == 2
179
+ assert not torch.allclose(trainer.model.tag_clfs[0].weight, new_trainer.model.tag_clfs[0].weight)
180
+ assert torch.allclose(trainer.model.tag_clfs[1].weight, new_trainer.model.tag_clfs[1].weight)
181
+
182
+ def test_connected_two_tag_training(pretrain_file, tmp_path):
183
+ trainer = run_two_tag_training(pretrain_file, tmp_path, "--connect_output_layers")
184
+ assert len(trainer.model.tag_clfs) == 2
185
+ assert len(trainer.model.crits) == 2
186
+ assert len(trainer.vocab['tag'].lens()) == 2
187
+
188
+ # this checks that with the connected output layers,
189
+ # the second output layer has its size increased
190
+ # by the number of tags known to the first output layer
191
+ assert trainer.model.tag_clfs[1].weight.shape[1] == trainer.vocab['tag'].lens()[0] + trainer.model.tag_clfs[0].weight.shape[1]
192
+
193
+ def run_training(pretrain_file, tmp_path, *extra_args):
194
+ train_json = tmp_path / "en_test.train.json"
195
+ write_temp_file(train_json, EN_TRAIN_BIO)
196
+
197
+ dev_json = tmp_path / "en_test.dev.json"
198
+ write_temp_file(dev_json, EN_DEV_BIO)
199
+
200
+ args = get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args)
201
+ return ner_tagger.main(args)
202
+
203
+
204
+ def test_train_model_gpu(pretrain_file, tmp_path):
205
+ """
206
+ Briefly train an NER model (no expectation of correctness) and check that it is on the GPU
207
+ """
208
+ trainer = run_training(pretrain_file, tmp_path)
209
+ if not torch.cuda.is_available():
210
+ warnings.warn("Cannot check that the NER model is on the GPU, since GPU is not available")
211
+ return
212
+
213
+ model = trainer.model
214
+ device = next(model.parameters()).device
215
+ assert str(device).startswith("cuda")
216
+
217
+
218
+ def test_train_model_cpu(pretrain_file, tmp_path):
219
+ """
220
+ Briefly train an NER model (no expectation of correctness) and check that it is on the GPU
221
+ """
222
+ trainer = run_training(pretrain_file, tmp_path, "--cpu")
223
+
224
+ model = trainer.model
225
+ device = next(model.parameters()).device
226
+ assert str(device).startswith("cpu")
227
+
228
+ def model_file_has_bert(filename):
229
+ checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
230
+ return any(x.startswith("bert_model.") for x in checkpoint['model'].keys())
231
+
232
+ def test_with_bert(pretrain_file, tmp_path):
233
+ trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert')
234
+ model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])
235
+ assert not model_file_has_bert(model_file)
236
+
237
+ def test_with_bert_finetune(pretrain_file, tmp_path):
238
+ trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune')
239
+ model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])
240
+ assert model_file_has_bert(model_file)
241
+
242
+ foo_save_filename = os.path.join(tmp_path, "foo_" + trainer.args['save_name'])
243
+ bar_save_filename = os.path.join(tmp_path, "bar_" + trainer.args['save_name'])
244
+ trainer.save(foo_save_filename)
245
+ assert model_file_has_bert(foo_save_filename)
246
+
247
+ # TODO: technically this should still work if we turn off bert finetuning when reloading
248
+ reloaded_trainer = Trainer(args=trainer.args, model_file=foo_save_filename)
249
+ reloaded_trainer.save(bar_save_filename)
250
+ assert model_file_has_bert(bar_save_filename)
251
+
252
+ def test_with_peft_finetune(pretrain_file, tmp_path):
253
+ # TODO: check that the peft tensors are moving when training?
254
+ trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--use_peft')
255
+ model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])
256
+ checkpoint = torch.load(model_file, lambda storage, loc: storage, weights_only=True)
257
+ assert 'bert_lora' in checkpoint
258
+ assert not any(x.startswith("bert_model.") for x in checkpoint['model'].keys())
259
+
260
+ # test loading
261
+ reloaded_trainer = Trainer(args=trainer.args, model_file=model_file)
stanza/stanza/tests/pipeline/pipeline_device_tests.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility methods to check that all processors are on the expected device
3
+
4
+ Refactored since it can be used for multiple pipelines
5
+ """
6
+
7
+ import warnings
8
+
9
+ import torch
10
+
11
def check_on_gpu(pipeline):
    """
    Check that the processors are all on the GPU and that basic execution works
    """
    if not torch.cuda.is_available():
        warnings.warn("Unable to run the test that checks the pipeline is on the GPU, as there is no GPU available!")
        return

    for name, proc in pipeline.processors.items():
        # a processor exposes its network either through a trainer or directly
        model = proc.trainer.model if proc.trainer is not None else proc._model
        device = next(model.parameters()).device
        assert str(device).startswith("cuda"), "Processor %s was not on the GPU" % name

    # just check that there are no cpu/cuda tensor conflicts
    # when running on the GPU
    pipeline("This is a small test")
30
+
31
def check_on_cpu(pipeline):
    """
    Check that the processors are all on the CPU and that basic execution works
    """
    for name, proc in pipeline.processors.items():
        # a processor exposes its network either through a trainer or directly
        model = proc.trainer.model if proc.trainer is not None else proc._model
        device = next(model.parameters()).device
        assert str(device).startswith("cpu"), "Processor %s was not on the CPU" % name

    # just check that there are no cpu/cuda tensor conflicts
    # when running on the CPU
    pipeline("This is a small test")
stanza/stanza/tests/pipeline/test_pipeline_sentiment_processor.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+
3
+ import pytest
4
+ import stanza
5
+ from stanza.utils.conll import CoNLL
6
+ from stanza.models.common.doc import Document
7
+
8
+ from stanza.tests import *
9
+
10
# run these tests under the "pipeline" and "travis" pytest marks
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]

# data for testing: one negative, one neutral, one positive sentence
EN_DOCS = ["Ragavan is terrible and should go away.", "Today is okay.", "Urza's Saga is great."]

# the three sentences joined into one multi-sentence document
EN_DOC = " ".join(EN_DOCS)

# expected sentiment class per sentence, matching EN_DOCS order
EXPECTED = [0, 1, 2]
18
+
19
class TestSentimentPipeline:
    @pytest.fixture(scope="class")
    def pipeline(self):
        """
        A tokenize+sentiment pipeline, built once and shared by the whole class
        """
        gc.collect()
        return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,sentiment")

    def test_simple(self, pipeline):
        # each text is a single sentence; collect one sentiment per doc
        docs = [pipeline(text) for text in EN_DOCS]
        for processed in docs:
            assert len(processed.sentences) == 1
        assert EXPECTED == [processed.sentences[0].sentiment for processed in docs]

    def test_multiple_sentences(self, pipeline):
        # the concatenated doc should split back into three sentences
        processed = pipeline(EN_DOC)
        assert len(processed.sentences) == 3
        assert EXPECTED == [sentence.sentiment for sentence in processed.sentences]

    def test_empty_text(self, pipeline):
        """
        Test empty text and a text which might get reduced to empty text by removing dashes
        """
        assert len(pipeline("").sentences) == 0

        assert len(pipeline("--").sentences) == 1
stanza/stanza/tests/pipeline/test_requirements.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test the requirements functionality for processors
3
+ """
4
+
5
+ import pytest
6
+ import stanza
7
+
8
+ from stanza.pipeline.core import PipelineRequirementsException
9
+ from stanza.pipeline.processor import ProcessorRequirementsException
10
+ from stanza.tests import *
11
+
12
+ pytestmark = pytest.mark.pipeline
13
+
14
def check_exception_vals(req_exception, req_exception_vals):
    """
    Verify a ProcessorRequirementsException against a dict of expected values.

    :param req_exception: the ProcessorRequirementsException to evaluate
    :param req_exception_vals: dict of expected values; the keys checked are
        'processor_type', 'processors_list' and 'requires'
    :return: None

    NOTE(review): some callers also pass a 'provided_reqs' key, which this
    function currently ignores - consider checking it as well.
    """
    assert isinstance(req_exception, ProcessorRequirementsException)
    for attr in ('processor_type', 'processors_list'):
        assert getattr(req_exception, attr) == req_exception_vals[attr]
    assert req_exception.err_processor.requires == req_exception_vals['requires']
25
+
26
+
27
def test_missing_requirements():
    """
    Try to build several pipelines with bad configs and check thrown exceptions against gold exceptions.

    Each config is missing a processor that a requested processor depends on,
    so Pipeline construction must raise a PipelineRequirementsException
    wrapping one ProcessorRequirementsException per failing processor.
    :return: None
    """
    # list of (bad configs, list of gold ProcessorRequirementsExceptions that should be thrown) pairs
    bad_config_lists = [
        # missing tokenize
        (
            # input config
            {'processors': 'pos,depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'},
            # 2 expected exceptions
            [
                {'processor_type': 'POSProcessor', 'processors_list': ['pos', 'depparse'], 'provided_reqs': set([]),
                 'requires': set(['tokenize'])},
                {'processor_type': 'DepparseProcessor', 'processors_list': ['pos', 'depparse'],
                 'provided_reqs': set([]), 'requires': set(['tokenize','pos', 'lemma'])}
            ]
        ),
        # no pos when lemma_pos set to True; for english mwt should not be included in the loaded processor list
        (
            # input config
            {'processors': 'tokenize,mwt,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_pos': True},
            # 1 expected exception
            [
                {'processor_type': 'LemmaProcessor', 'processors_list': ['tokenize', 'mwt', 'lemma'],
                 'provided_reqs': set(['tokenize', 'mwt']), 'requires': set(['tokenize', 'pos'])}
            ]
        )
    ]
    # try to build each bad config, catch exceptions, check against gold
    pipeline_fails = 0
    for bad_config, gold_exceptions in bad_config_lists:
        try:
            stanza.Pipeline(**bad_config)
        except PipelineRequirementsException as e:
            pipeline_fails += 1
            # the isinstance check is redundant with the except clause, kept for clarity
            assert isinstance(e, PipelineRequirementsException)
            assert len(e.processor_req_fails) == len(gold_exceptions)
            for processor_req_e, gold_exception in zip(e.processor_req_fails,gold_exceptions):
                # compare the thrown ProcessorRequirementsExceptions against gold
                check_exception_vals(processor_req_e, gold_exception)
    # check pipeline building failed twice
    assert pipeline_fails == 2
71
+
72
+
stanza/stanza/tests/tokenization/__init__.py ADDED
File without changes
stanza/stanza/tests/tokenization/test_tokenize_utils.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Very simple test of the sentence slicing by <PAD> tags
3
+
4
+ TODO: could add a bunch more simple tests for the tokenization utils
5
+ """
6
+
7
+ import pytest
8
+ import stanza
9
+
10
+ from stanza import Pipeline
11
+ from stanza.tests import *
12
+ from stanza.models.common import doc
13
+ from stanza.models.tokenization import data
14
+ from stanza.models.tokenization import utils
15
+
16
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
17
+
18
def test_find_spans():
    """
    Test various raw -> span manipulations

    <PAD> entries should be excluded from spans, splitting them when interior.
    """
    chars = list('unban mox opal')
    assert utils.find_spans(chars) == [(0, 14)]

    # trailing pad is dropped
    assert utils.find_spans(chars + ['<PAD>']) == [(0, 14)]

    # leading and trailing pads are dropped
    assert utils.find_spans(['<PAD>'] + chars + ['<PAD>']) == [(1, 15)]

    # leading pad only
    assert utils.find_spans(['<PAD>'] + chars) == [(1, 15)]

    # an interior pad splits the text into two spans
    split_chars = ['<PAD>'] + list('unban') + ['<PAD>'] + list('mox opal')
    assert utils.find_spans(split_chars) == [(1, 6), (7, 15)]
36
+
37
def check_offsets(doc, expected_offsets):
    """
    Compare the start_char and end_char of the tokens in the doc with the given list of list of offsets

    expected_offsets is one list of (start, end) pairs per sentence.
    """
    assert len(doc.sentences) == len(expected_offsets)
    for sentence, offsets in zip(doc.sentences, expected_offsets):
        assert len(sentence.tokens) == len(offsets)
        for token, (start, end) in zip(sentence.tokens, offsets):
            assert token.start_char == start
            assert token.end_char == end
47
+
48
def test_match_tokens_with_text():
    """
    Test the conversion of pretokenized text to Document
    """
    doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisatest")
    check_offsets(doc, [[(0, 4), (4, 6), (6, 7), (7, 11)]])

    doc = utils.match_tokens_with_text([["This", "is", "a", "test"], ["unban", "mox", "opal", "!"]], "Thisisatest unban mox opal!")
    check_offsets(doc, [[(0, 4), (4, 6), (6, 7), (7, 11)],
                        [(13, 18), (19, 22), (24, 28), (28, 29)]])

    # too much text, too little text, and a token/text mismatch all raise
    failure_cases = [([["This", "is", "a", "test"]], "Thisisatestttt"),
                     ([["This", "is", "a", "test"]], "Thisisates"),
                     ([["This", "iz", "a", "test"]], "Thisisatest")]
    for tokens, text in failure_cases:
        with pytest.raises(ValueError):
            utils.match_tokens_with_text(tokens, text)
69
+
70
def test_long_paragraph():
    """
    Test the tokenizer's capacity to break text up into smaller chunks

    First proves that output_predictions with the 3000 char limit actually
    exercises the chunking path, then checks the chunked run still produces
    the right number of sentences.
    """
    pipeline = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize")
    tokenizer = pipeline.processors['tokenize']

    # 100 copies of one sentence => 100 expected sentences at the end
    raw_text = "TIL not to ask a date to dress up as Smurfette on a first date. " * 100

    # run a test to make sure the chunk operation is called
    # if not, the test isn't actually testing what we need to test
    # (clobbering advance_old_batch makes the chunking path blow up with TypeError,
    # so the raise below proves that path is reached)
    batches = data.DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)
    batches.advance_old_batch = None
    with pytest.raises(TypeError):
        _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000,
                                                     orig_text=raw_text,
                                                     no_ssplit=tokenizer.config.get('no_ssplit', False))

    # a new DataLoader should not be crippled as the above one was
    batches = data.DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary)
    _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000,
                                                 orig_text=raw_text,
                                                 no_ssplit=tokenizer.config.get('no_ssplit', False))

    document = doc.Document(document, raw_text)
    assert len(document.sentences) == 100
96
+
97
def test_postprocessor_application():
    """
    Check that the postprocessor machinery behaves: an identity postprocessor
    (one returning the already-correct tokenization) must leave the doc unchanged.
    """
    good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']]
    text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

    target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16, 'misc': 'SpaceAfter=No'}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31, 'misc': 'SpaceAfter=No'}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32, 'misc': 'SpaceAfter=No'}]]

    # identity postprocessor: ignore the input, return the fixed tokenization
    def identity_postprocessor(_):
        return good_tokenization

    assert utils.postprocess_doc(target_doc, identity_postprocessor, text) == target_doc
113
+
114
def test_reassembly_indexing():
    """
    Check that the reassembly code counts the indices correctly, OOV chars included.
    """
    good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']]
    # no multi-word tokens and no expansions anywhere in this input
    good_mwts = [[False] * len(sentence) for sentence in good_tokenization]
    good_expansions = [[None] * len(sentence) for sentence in good_tokenization]

    text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

    target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16, 'misc': 'SpaceAfter=No'}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31, 'misc': 'SpaceAfter=No'}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32, 'misc': 'SpaceAfter=No'}]]

    reassembled = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text)
    assert reassembled == target_doc
130
+
131
def test_reassembly_reference_failures():
    """
    Check that the reassembly code complains correctly when the user adds tokens that don't exist.

    Bug fix: the inline-mismatch case previously passed bad_inline_mwts twice
    (once in place of the expansions argument), leaving bad_inline_expansions
    unused.  It now passes the expansions list in the expansions slot.
    """

    # a token ("Southern") that does not appear in the text at all
    bad_addition_tokenization = [['Joe', 'Smith', 'lives', 'in', 'Southern', 'California', '.']]
    bad_addition_mwts = [[False for _ in range(len(bad_addition_tokenization[0]))]]
    bad_addition_expansions = [[None for _ in range(len(bad_addition_tokenization[0]))]]

    # a token ("Californiaa") that mismatches the text inline
    bad_inline_tokenization = [['Joe', 'Smith', 'lives', 'in', 'Californiaa', '.']]
    bad_inline_mwts = [[False for _ in range(len(bad_inline_tokenization[0]))]]
    bad_inline_expansions = [[None for _ in range(len(bad_inline_tokenization[0]))]]

    # a tokenization that exactly matches the text
    good_tokenization = [['Joe', 'Smith', 'lives', 'in', 'California', '.']]
    good_mwts = [[False for _ in range(len(good_tokenization[0]))]]
    good_expansions = [[None for _ in range(len(good_tokenization[0]))]]

    text = "Joe Smith lives in California."

    with pytest.raises(ValueError):
        utils.reassemble_doc_from_tokens(bad_addition_tokenization, bad_addition_mwts, bad_addition_expansions, text)

    with pytest.raises(ValueError):
        # previously bad_inline_mwts was (incorrectly) passed here as the expansions
        utils.reassemble_doc_from_tokens(bad_inline_tokenization, bad_inline_mwts, bad_inline_expansions, text)

    # the well-formed input must not raise
    utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text)
157
+
158
+
159
+
160
+ TRAIN_DATA = """
161
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
162
+ # text = DPA: Iraqi authorities announced that they'd busted up three terrorist cells operating in Baghdad.
163
+ 1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
164
+ 2 : : PUNCT : _ 1 punct 1:punct _
165
+ 3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
166
+ 4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
167
+ 5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
168
+ 6 that that SCONJ IN _ 9 mark 9:mark _
169
+ 7-8 they'd _ _ _ _ _ _ _ _
170
+ 7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
171
+ 8 'd have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
172
+ 9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
173
+ 10 up up ADP RP _ 9 compound:prt 9:compound:prt _
174
+ 11 three three NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
175
+ 12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
176
+ 13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
177
+ 14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
178
+ 15 in in ADP IN _ 16 case 16:case _
179
+ 16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
180
+ 17 . . PUNCT . _ 1 punct 1:punct _
181
+
182
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
183
+ # text = Two of them were being run by 2 officials of the Ministry of the Interior!
184
+ 1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
185
+ 2 of of ADP IN _ 3 case 3:case _
186
+ 3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
187
+ 4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
188
+ 5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
189
+ 6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
190
+ 7 by by ADP IN _ 9 case 9:case _
191
+ 8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
192
+ 9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
193
+ 10 of of ADP IN _ 12 case 12:case _
194
+ 11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
195
+ 12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
196
+ 13 of of ADP IN _ 15 case 15:case _
197
+ 14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
198
+ 15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
199
+ 16 ! ! PUNCT . _ 6 punct 6:punct _
200
+
201
+ """.lstrip()
202
+
203
def test_lexicon_from_training_data(tmp_path):
    """
    Test a couple aspects of building a lexicon from training data

    expected number of words eliminated for being too long
    duplicate words counted once
    numbers eliminated
    """
    conllu_file = str(tmp_path / "train.conllu")
    with open(conllu_file, "w", encoding="utf-8") as fout:
        fout.write(TRAIN_DATA)

    lexicon, num_dict_feat = utils.create_lexicon("en_test", conllu_file)

    expected_lexicon = ["'d", 'announced', 'baghdad', 'being', 'busted', 'by', 'cells', 'dpa', 'in', 'interior', 'iraqi', 'ministry', 'of', 'officials', 'operating', 'run', 'terrorist', 'that', 'the', 'them', 'they', "they'd", 'three', 'two', 'up', 'were']
    assert sorted(lexicon) == expected_lexicon
    # the feature count matches the longest word kept in the lexicon
    assert num_dict_feat == max(len(word) for word in lexicon)
220
+
stanza/stanza/utils/charlm/__init__.py ADDED
File without changes
stanza/stanza/utils/charlm/conll17_to_text.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Turns a directory of conllu files from the conll 2017 shared task to a text file
3
+
4
+ Part of the process for building a charlm dataset
5
+
6
+ python conll17_to_text.py <directory>
7
+
8
+ This is an extension of the original script:
9
+ https://github.com/stanfordnlp/stanza-scripts/blob/master/charlm/conll17/conll2txt.py
10
+
11
+ To build a new charlm for a new language from a conll17 dataset:
12
+ - look for conll17 shared task data, possibly here:
13
+ https://lindat.mff.cuni.cz/repository/xmlui/handle/11234/1-1989
14
+ - python3 stanza/utils/charlm/conll17_to_text.py ~/extern_data/conll17/Bulgarian --output_directory extern_data/charlm_raw/bg/conll17
15
+ - python3 stanza/utils/charlm/make_lm_data.py --langs bg extern_data/charlm_raw extern_data/charlm/
16
+ """
17
+
18
+ import argparse
19
+ import lzma
20
+ import sys
21
+ import os
22
+
23
def process_file(input_filename, output_directory, compress):
    """Convert one conllu (or conllu.xz) file to plain text, one sentence per line.

    Comment lines and multi-word-token range lines (ids like "3-4") are
    skipped; the surface words of each sentence are joined with spaces.

    :param input_filename: path to a .conllu or .conllu.xz file; anything else is skipped
    :param output_directory: where to write the .txt; None keeps it beside the input
    :param compress: if True, write .txt.xz instead of plain .txt
    """
    if not input_filename.endswith('.conllu') and not input_filename.endswith(".conllu.xz"):
        print("Skipping {}".format(input_filename))
        return

    # conllu files are utf-8 by specification, so read and write them as
    # utf-8 explicitly rather than relying on the locale default encoding
    if input_filename.endswith(".xz"):
        open_fn = lambda x: lzma.open(x, mode='rt', encoding='utf-8')
        output_filename = input_filename[:-3].replace(".conllu", ".txt")
    else:
        open_fn = lambda x: open(x, encoding='utf-8')
        output_filename = input_filename.replace('.conllu', '.txt')

    if output_directory:
        output_filename = os.path.join(output_directory, os.path.split(output_filename)[1])

    if compress:
        output_filename = output_filename + ".xz"
        output_fn = lambda x: lzma.open(x, mode='wt', encoding='utf-8')
    else:
        output_fn = lambda x: open(x, mode='w', encoding='utf-8')

    # never clobber an existing conversion
    if os.path.exists(output_filename):
        print("Cowardly refusing to overwrite %s" % output_filename)
        return

    print("Converting %s to %s" % (input_filename, output_filename))
    with open_fn(input_filename) as fin:
        sentences = []
        sentence = []
        for line in fin:
            line = line.strip()
            if len(line) == 0: # blank line terminates a sentence
                sentences.append(sentence)
                sentence = []
                continue
            if line[0] == '#': # comment
                continue
            splitline = line.split('\t')
            assert len(splitline) == 10, "Unexpected number of conllu columns in %s: %s" % (input_filename, line)
            # renamed from 'id' to avoid shadowing the builtin
            token_id, word = splitline[0], splitline[1]
            if '-' not in token_id: # not an mwt range line
                sentence.append(word)

        # flush a final sentence not followed by a blank line
        if sentence:
            sentences.append(sentence)

    print(" Read in {} sentences".format(len(sentences)))
    with output_fn(output_filename) as fout:
        fout.write('\n'.join([' '.join(sentence) for sentence in sentences]))
72
+
73
def parse_args():
    """Parse the command line options for the conll17 -> text conversion.

    Returns the argparse namespace; see the individual help strings for the
    meaning of each option.  xz output is on by default and disabled with
    --no_xz_output.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_directory", help="Root directory with conllu or conllu.xz files.")
    parser.add_argument("--output_directory", default=None, help="Directory to output to. Will output to input_directory if None")
    parser.add_argument("--no_xz_output", default=True, dest="xz_output", action="store_false", help="Output compressed xz files")
    args = parser.parse_args()
    return args
80
+
81
+
82
if __name__ == '__main__':
    args = parse_args()
    directory = args.input_directory
    # sort for a deterministic processing order
    filenames = sorted(os.listdir(directory))
    print("Files to process in {}: {}".format(directory, filenames))
    print("Processing to .xz files: {}".format(args.xz_output))

    if args.output_directory:
        os.makedirs(args.output_directory, exist_ok=True)
    # process_file itself skips anything that is not .conllu / .conllu.xz
    for filename in filenames:
        process_file(os.path.join(directory, filename), args.output_directory, args.xz_output)
93
+
stanza/stanza/utils/charlm/dump_oscar.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script downloads and extracts the text from an Oscar crawl on HuggingFace
3
+
4
+ To use, just run
5
+
6
+ dump_oscar.py <lang>
7
+
8
+ It will download the dataset and output all of the text to the --output directory.
9
+ Files will be broken into pieces to avoid having one giant file.
10
+ By default, files will also be compressed with xz (although this can be turned off)
11
+ """
12
+
13
+ import argparse
14
+ import lzma
15
+ import math
16
+ import os
17
+
18
+ from tqdm import tqdm
19
+
20
+ from datasets import get_dataset_split_names
21
+ from datasets import load_dataset
22
+
23
+ from stanza.models.common.constant import lang_to_langcode
24
+
25
def parse_args():
    """
    A few specific arguments for the dump program

    Uses lang_to_langcode to process args.language, hopefully converting
    a variety of possible formats to the short code used by HuggingFace
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("language", help="Language to download")
    parser.add_argument("--output", default="oscar_dump", help="Path for saving files")
    parser.add_argument("--no_xz", dest="xz", default=True, action='store_false', help="Don't xz the files - default is to compress while writing")
    parser.add_argument("--prefix", default="oscar_dump", help="Prefix to use for the pieces of the dataset")
    parser.add_argument("--version", choices=["2019", "2023"], default="2023", help="Which version of the Oscar dataset to download")

    args = parser.parse_args()
    # normalize e.g. a full language name down to the short HuggingFace code
    args.language = lang_to_langcode(args.language)
    return args
42
+
43
def download_2023(args):
    """Load the OSCAR-2301 subset for the requested language.

    Bug fix: this previously hardcoded the 'sd' (Sindhi) subset, ignoring
    args.language entirely.  It now honors the parsed language and returns
    the dataset together with its split names so a caller can use it
    (previously nothing was returned).
    """
    dataset = load_dataset('oscar-corpus/OSCAR-2301', args.language)
    split_names = list(dataset.keys())
    return dataset, split_names
46
+
47
+
48
def main():
    """Download the requested Oscar crawl and dump its text to --output.

    Output is split into ~100MB pieces named with --prefix and a zero-padded
    index, optionally xz-compressed.
    """
    args = parse_args()

    # this is the 2019 version. for 2023, you can do
    # dataset = load_dataset('oscar-corpus/OSCAR-2301', 'sd')
    language = args.language
    if args.version == "2019":
        dataset_name = "unshuffled_deduplicated_%s" % language
        try:
            split_names = get_dataset_split_names("oscar", dataset_name)
        except ValueError as e:
            raise ValueError("Language %s not available in HuggingFace Oscar" % language) from e

        if len(split_names) > 1:
            raise ValueError("Unexpected split_names: {}".format(split_names))

        dataset = load_dataset("oscar", dataset_name)
        dataset = dataset[split_names[0]]
        size_in_bytes = dataset.info.size_in_bytes
        process_item = lambda x: x['text']
    elif args.version == "2023":
        dataset = load_dataset("oscar-corpus/OSCAR-2301", language)
        split_names = list(dataset.keys())
        if len(split_names) > 1:
            raise ValueError("Unexpected split_names: {}".format(split_names))
        # it's not clear if some languages don't support size_in_bytes,
        # or if there was an update to datasets which now allows that
        #
        # previously we did:
        # dataset = dataset[split_names[0]]['text']
        # size_in_bytes = sum(len(x) for x in dataset)
        # process_item = lambda x: x
        dataset = dataset[split_names[0]]
        size_in_bytes = dataset.info.size_in_bytes
        process_item = lambda x: x['text']
    else:
        raise AssertionError("Unknown version: %s" % args.version)

    # estimate how many ~100MB chunks we will produce, to size the filename index
    chunks = max(1.0, size_in_bytes // 1e8) # an overestimate
    id_len = max(3, math.floor(math.log10(chunks)) + 1)

    if args.xz:
        format_str = "%s_%%0%dd.txt.xz" % (args.prefix, id_len)
        fopen = lambda file_idx: lzma.open(os.path.join(args.output, format_str % file_idx), "wt")
    else:
        format_str = "%s_%%0%dd.txt" % (args.prefix, id_len)
        fopen = lambda file_idx: open(os.path.join(args.output, format_str % file_idx), "w")

    print("Writing dataset to %s" % args.output)
    print("Dataset length: {}".format(size_in_bytes))
    os.makedirs(args.output, exist_ok=True)

    file_idx = 0
    file_len = 0
    total_len = 0
    fout = fopen(file_idx)

    # stream each record, rolling over to a new output file past ~100MB of text
    for item in tqdm(dataset):
        text = process_item(item)
        fout.write(text)
        fout.write("\n")
        file_len += len(text)
        file_len += 1 # account for the newline
        if file_len > 1e8:
            file_len = 0
            fout.close()
            file_idx = file_idx + 1
            fout = fopen(file_idx)

    fout.close()
+
119
+ if __name__ == '__main__':
120
+ main()
stanza/stanza/utils/charlm/make_lm_data.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Create Stanza character LM train/dev/test data, by reading from txt files in each source corpus directory,
3
+ shuffling, splitting and saving into multiple smaller files (50MB by default) in a target directory.
4
+
5
+ This script assumes the following source directory structures:
6
+ - {src_dir}/{language}/{corpus}/*.txt
7
+ It will read from all source .txt files and create the following target directory structures:
8
+ - {tgt_dir}/{language}/{corpus}
9
+ and within each target directory, it will create the following files:
10
+ - train/*.txt
11
+ - dev.txt
12
+ - test.txt
13
+ Args:
14
+ - src_root: root directory of the source.
15
+ - tgt_root: root directory of the target.
16
+ - langs: a list of language codes to process; if specified, languages not in this list will be ignored.
17
+ Note: edit the {EXCLUDED_FOLDERS} variable to exclude more folders in the source directory.
18
+ """
19
+
20
+ import argparse
21
+ import glob
22
+ import os
23
+ from pathlib import Path
24
+ import shutil
25
+ import subprocess
26
+ import tempfile
27
+
28
+ from tqdm import tqdm
29
+
30
+ EXCLUDED_FOLDERS = ['raw_corpus']
31
+
32
def main():
    """Walk {src_root}/{lang}/{corpus} directories and build charlm splits.

    For every (language, corpus) pair that survives the --langs / --packages
    filters, hands the source directory to prepare_lm_data to produce the
    train/dev(/test) files under {tgt_root}/{lang}/{corpus}.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("src_root", default="src", help="Root directory with all source files. Expected structure is root dir -> language dirs -> package dirs -> text files to process")
    parser.add_argument("tgt_root", default="tgt", help="Root directory with all target files.")
    parser.add_argument("--langs", default="", help="A list of language codes to process. If not set, all languages under src_root will be processed.")
    parser.add_argument("--packages", default="", help="A list of packages to process. If not set, all packages under the languages found will be processed.")
    parser.add_argument("--no_xz_output", default=True, dest="xz_output", action="store_false", help="Output compressed xz files")
    parser.add_argument("--split_size", default=50, type=int, help="How large to make each split, in MB")
    parser.add_argument("--no_make_test_file", default=True, dest="make_test_file", action="store_false", help="Don't save a test file. Honestly, we never even use it. Best for low resource languages where every bit helps")
    args = parser.parse_args()

    print("Processing files:")
    print(f"source root: {args.src_root}")
    print(f"target root: {args.tgt_root}")
    print("")

    # --langs and --packages are comma separated allow-lists
    langs = []
    if len(args.langs) > 0:
        langs = args.langs.split(',')
        print("Only processing the following languages: " + str(langs))

    packages = []
    if len(args.packages) > 0:
        packages = args.packages.split(',')
        print("Only processing the following packages: " + str(packages))

    src_root = Path(args.src_root)
    tgt_root = Path(args.tgt_root)

    lang_dirs = os.listdir(src_root)
    lang_dirs = [l for l in lang_dirs if l not in EXCLUDED_FOLDERS] # skip excluded
    lang_dirs = [l for l in lang_dirs if os.path.isdir(src_root / l)] # skip non-directory
    if len(langs) > 0: # filter languages if specified
        lang_dirs = [l for l in lang_dirs if l in langs]
    print(f"{len(lang_dirs)} total languages found:")
    print(lang_dirs)
    print("")

    # --split_size is given in MB; convert to bytes for split -C
    split_size = int(args.split_size * 1024 * 1024)

    for lang in lang_dirs:
        lang_root = src_root / lang
        data_dirs = os.listdir(lang_root)
        if len(packages) > 0:
            data_dirs = [d for d in data_dirs if d in packages]
        data_dirs = [d for d in data_dirs if os.path.isdir(lang_root / d)]
        print(f"{len(data_dirs)} total corpus found for language {lang}.")
        print(data_dirs)
        print("")

        for dataset_name in data_dirs:
            src_dir = lang_root / dataset_name
            tgt_dir = tgt_root / lang / dataset_name

            if not os.path.exists(tgt_dir):
                os.makedirs(tgt_dir)
            print(f"-> Processing {lang}-{dataset_name}")
            prepare_lm_data(src_dir, tgt_dir, lang, dataset_name, args.xz_output, split_size, args.make_test_file)

    print("")
92
+
93
def prepare_lm_data(src_dir, tgt_dir, lang, dataset_name, compress, split_size, make_test_file):
    """
    Combine, shuffle and split data into smaller files, following a naming convention.

    Concatenates all *.txt(.xz|.gz) in src_dir, shuffles the lines, splits the
    result into split_size-byte pieces under tgt_dir/train, then peels off the
    first piece(s) as dev.txt (and test.txt when make_test_file).

    NOTE(review): the cat/xzcat/shuf/split steps run through shell=True with
    paths interpolated into the command string, so paths containing spaces or
    shell metacharacters would break - confirm inputs are well-behaved.
    """
    assert isinstance(src_dir, Path)
    assert isinstance(tgt_dir, Path)
    with tempfile.TemporaryDirectory(dir=tgt_dir) as tempdir:
        tgt_tmp = os.path.join(tempdir, f"{lang}-{dataset_name}.tmp")
        print(f"--> Copying files into {tgt_tmp}...")
        # TODO: we can do this without the shell commands
        input_files = glob.glob(str(src_dir) + '/*.txt') + glob.glob(str(src_dir) + '/*.txt.xz') + glob.glob(str(src_dir) + '/*.txt.gz')
        for src_fn in tqdm(input_files):
            # decompress on the fly where needed, appending everything to one tmp file
            if src_fn.endswith(".txt"):
                cmd = f"cat {src_fn} >> {tgt_tmp}"
                subprocess.run(cmd, shell=True)
            elif src_fn.endswith(".txt.xz"):
                cmd = f"xzcat {src_fn} >> {tgt_tmp}"
                subprocess.run(cmd, shell=True)
            elif src_fn.endswith(".txt.gz"):
                cmd = f"zcat {src_fn} >> {tgt_tmp}"
                subprocess.run(cmd, shell=True)
            else:
                raise AssertionError("should not have found %s" % src_fn)
        tgt_tmp_shuffled = os.path.join(tempdir, f"{lang}-{dataset_name}.tmp.shuffled")

        # shuffle at the line level so the dev/test splits are representative
        print(f"--> Shuffling files into {tgt_tmp_shuffled}...")
        cmd = f"cat {tgt_tmp} | shuf > {tgt_tmp_shuffled}"
        result = subprocess.run(cmd, shell=True)
        if result.returncode != 0:
            raise RuntimeError("Failed to shuffle files!")
        size = os.path.getsize(tgt_tmp_shuffled) / 1024 / 1024 / 1024
        print(f"--> Shuffled file size: {size:.4f} GB")
        if size < 0.1:
            raise RuntimeError("Not enough data found to build a charlm. At least 100MB data expected")

        # split -C keeps whole lines while capping each piece at split_size bytes
        print(f"--> Splitting into smaller files of size {split_size} ...")
        train_dir = tgt_dir / 'train'
        if not os.path.exists(train_dir): # make training dir
            os.makedirs(train_dir)
        cmd = f"split -C {split_size} -a 4 -d --additional-suffix .txt {tgt_tmp_shuffled} {train_dir}/{lang}-{dataset_name}-"
        result = subprocess.run(cmd, shell=True)
        if result.returncode != 0:
            raise RuntimeError("Failed to split files!")
        total = len(glob.glob(f'{train_dir}/*.txt'))
        print(f"--> {total} total files generated.")
        # need at least 3 pieces: dev + test + at least one training shard
        if total < 3:
            raise RuntimeError("Something went wrong! %d file(s) produced by shuffle and split, expected at least 3" % total)

        dev_file = f"{tgt_dir}/dev.txt"
        test_file = f"{tgt_dir}/test.txt"
        if make_test_file:
            # the first two shards become dev and test
            print("--> Creating dev and test files...")
            shutil.move(f"{train_dir}/{lang}-{dataset_name}-0000.txt", dev_file)
            shutil.move(f"{train_dir}/{lang}-{dataset_name}-0001.txt", test_file)
            txt_files = [dev_file, test_file] + glob.glob(f'{train_dir}/*.txt')
        else:
            # only the first shard is peeled off, as dev
            print("--> Creating dev file...")
            shutil.move(f"{train_dir}/{lang}-{dataset_name}-0000.txt", dev_file)
            txt_files = [dev_file] + glob.glob(f'{train_dir}/*.txt')

        if compress:
            print("--> Compressing files...")
            for txt_file in tqdm(txt_files):
                subprocess.run(['xz', txt_file])

        # the TemporaryDirectory context removes the tmp and shuffled files
        print("--> Cleaning up...")
        print(f"--> All done for {lang}-{dataset_name}.\n")
160
+
161
+ if __name__ == "__main__":
162
+ main()
stanza/stanza/utils/constituency/check_transitions.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from stanza.models.constituency import transition_sequence
4
+ from stanza.models.constituency import tree_reader
5
+ from stanza.models.constituency.parse_transitions import TransitionScheme
6
+ from stanza.models.constituency.parse_tree import Tree
7
+ from stanza.models.constituency.utils import verify_transitions
8
+
9
def main():
    """Read a treebank, build its transition sequences, and verify them repeatedly.

    The repeated verification loop makes this a convenient target for
    profiling (e.g. with cProfile).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', type=str, default="data/constituency/en_ptb3_train.mrg",
                        help='Input file for data loader.')
    parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER,
                        type=lambda x: TransitionScheme[x.upper()],
                        help='Transition scheme to use. {}'.format(", ".join(x.name for x in TransitionScheme)))
    parser.add_argument('--reversed', default=False, action='store_true',
                        help='Do the transition sequence reversed')
    parser.add_argument('--iterations', default=30, type=int,
                        help='How many times to iterate, such as if doing a cProfile')
    args = vars(parser.parse_args())

    train_trees = tree_reader.read_treebank(args['train_file'])
    # +1 leaves room for one more unary transition than the deepest chain seen
    unary_limit = max(t.count_unary_depth() for t in train_trees) + 1
    train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(
        train_trees, "training", args['transition_scheme'], args['reversed'])
    root_labels = Tree.get_root_labels(train_trees)
    for _ in range(args['iterations']):
        verify_transitions(train_trees, train_sequences, args['transition_scheme'],
                           unary_limit, args['reversed'], "train", root_labels)

if __name__ == '__main__':
    main()
stanza/stanza/utils/constituency/list_tensors.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Lists all the tensors in a constituency model.
3
+
4
+ Currently useful in combination with torchshow for displaying a series of tensors as they change.
5
+ """
6
+
7
+ import sys
8
+
9
+ from stanza.models.constituency.trainer import Trainer
10
+
11
+
12
+ trainer = Trainer.load(sys.argv[1])
13
+ model = trainer.model
14
+
15
+ for name, param in model.named_parameters():
16
+ print(name, param.requires_grad)
stanza/stanza/utils/datasets/__init__.py ADDED
File without changes
stanza/stanza/utils/datasets/contract_mwt.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
def contract_mwt(infile, outfile, ignore_gapping=True):
    """
    Simplify the gold tokenizer data for use as MWT processor test files.

    Multi-word token range lines are kept as single surface tokens (their
    misc column is tagged with MWT=Yes) and the expanded word lines inside
    the range are dropped.  When ignore_gapping is True, copy nodes used
    for gapping (ids such as 8.1) are removed as well.
    """
    with open(infile, 'r') as fin, open(outfile, 'w') as fout:
        word_idx = 0
        span_start, span_end = 0, -1
        for raw in fin:
            raw = raw.strip()

            # comment lines pass through unchanged
            if raw.startswith('#'):
                fout.write(raw + "\n")
                continue
            # blank line: sentence boundary, reset all counters
            if not raw:
                fout.write(raw + "\n")
                word_idx = 0
                span_start, span_end = 0, -1
                continue

            fields = raw.split('\t')

            # drop gapping copy nodes (ids like "8.1")
            if ignore_gapping and '.' in fields[0]:
                continue

            word_idx += 1
            if '-' in fields[0]:
                # an MWT range line: keep the surface token, mark it MWT=Yes
                span_start, span_end = (int(x) for x in fields[0].split('-'))
                misc = "MWT=Yes" if fields[-1] == '_' else fields[-1] + "|MWT=Yes"
                fout.write("{}\t{}\t{}\n".format(word_idx, "\t".join(fields[1:-1]), misc))
                # the range line itself does not consume a word index
                word_idx -= 1
            elif span_start <= word_idx <= span_end:
                # expanded words inside the MWT span are already represented
                # by the surface token
                continue
            else:
                fout.write("{}\t{}\n".format(word_idx, "\t".join(fields[1:])))

if __name__ == '__main__':
    contract_mwt(sys.argv[1], sys.argv[2])
46
+
stanza/stanza/utils/datasets/coref/__init__.py ADDED
File without changes
stanza/stanza/utils/datasets/coref/convert_ontonotes.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import stanza
5
+
6
+ from stanza.models.constituency import tree_reader
7
+ from stanza.utils.default_paths import get_default_paths
8
+ from stanza.utils.get_tqdm import get_tqdm
9
+ from stanza.utils.datasets.coref.utils import process_document
10
+
11
+ tqdm = get_tqdm()
12
+
13
def read_paragraphs(section):
    """Group each document's sentences into paragraphs by part_id.

    Yields (document_id, part_id, sentences) triples, one per contiguous
    run of sentences sharing the same part_id.  Documents with no
    sentences yield nothing.
    """
    for document in section:
        current_part = None
        paragraph = []
        for sentence in document['sentences']:
            part = sentence['part_id']
            if current_part is not None and part != current_part:
                # part changed: flush the finished paragraph
                yield document['document_id'], current_part, paragraph
                paragraph = []
            current_part = part
            paragraph.append(sentence)
        if paragraph:
            yield document['document_id'], current_part, paragraph
27
+
28
def convert_dataset_section(pipe, section):
    """Convert one split of the OntoNotes dataset into the coref json format.

    pipe: a stanza Pipeline used by process_document to produce dependency parses
    section: an iterable of OntoNotes documents (dicts with a 'sentences' list)

    Returns a list of processed document dicts, one per (document, part) pair.
    """
    processed_section = []
    # materialize so tqdm can show a total; list(gen) simplified to list(...)
    section = list(read_paragraphs(section))

    for idx, (doc_id, part_id, paragraph) in enumerate(tqdm(section)):
        sentences = [x['words'] for x in paragraph]
        coref_spans = [x['coref_spans'] for x in paragraph]
        sentence_speakers = [x['speaker'] for x in paragraph]

        processed = process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers)
        processed_section.append(processed)
    return processed_section
40
+
41
# Mapping from our split names to the HuggingFace dataset split names
SECTION_NAMES = {"train": "train",
                 "dev": "validation",
                 "test": "test"}
44
+
45
def process_dataset(short_name, ontonotes_path, coref_output_path):
    """Download (via HF datasets) and convert one OntoNotes coref dataset.

    short_name: dataset identifier such as "en_ontonotes", "zh_ontonotes",
      or "ar_ontonotes"; determines both the HF config and the pipeline language
    ontonotes_path: cache directory for the HF download
    coref_output_path: directory where the json files are written

    Raises ImportError when the datasets package is missing and ValueError
    for an unknown short_name.
    """
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("Please install the datasets package to process OntoNotes coref with Stanza") from e

    if short_name == 'en_ontonotes':
        config_name = 'english_v4'
    elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'):
        config_name = 'chinese_v4'
    elif short_name == 'ar_ontonotes':
        config_name = 'arabic_v4'
    else:
        raise ValueError("Unknown short name for downloading ontonotes: %s" % short_name)

    # the pipeline language was previously hardcoded to "en" even for the
    # Chinese and Arabic configs; derive it from the dataset short name instead
    lang = short_name.rsplit("_", 1)[0]
    pipe = stanza.Pipeline(lang, processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True)
    dataset = load_dataset("conll2012_ontonotesv5", config_name, cache_dir=ontonotes_path)
    for section, hf_name in SECTION_NAMES.items():
        #for section, hf_name in [("test", "test")]:
        print("Processing %s" % section)
        converted_section = convert_dataset_section(pipe, dataset[hf_name])
        output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section))
        with open(output_filename, "w", encoding="utf-8") as fout:
            json.dump(converted_section, fout, indent=2)
+ json.dump(converted_section, fout, indent=2)
69
+
70
+
71
+ def main():
72
+ paths = get_default_paths()
73
+ coref_input_path = paths['COREF_BASE']
74
+ ontonotes_path = os.path.join(coref_input_path, "english", "en_ontonotes")
75
+ coref_output_path = paths['COREF_DATA_DIR']
76
+ process_dataset("en_ontonotes", ontonotes_path, coref_output_path)
77
+
78
+ if __name__ == '__main__':
79
+ main()
80
+
stanza/stanza/utils/datasets/coref/convert_udcoref.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import json
3
+ import os
4
+ import re
5
+ import glob
6
+
7
+ from stanza.utils.default_paths import get_default_paths
8
+ from stanza.utils.get_tqdm import get_tqdm
9
+ from stanza.utils.datasets.coref.utils import find_cconj_head
10
+
11
+ from stanza.utils.conll import CoNLL
12
+
13
+ from random import Random
14
+
15
+ import argparse
16
+
17
# fixed seeds so the conversion (augmentation and test split) is reproducible
augment_random = Random(7)
split_random = Random(8)

tqdm = get_tqdm()
# when True, the Entity annotations use the CorefUD "eNNN" id scheme, which
# shifts the character offsets used when parsing entity ids below
IS_UDCOREF_FORMAT = True
UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1
23
+
24
def process_documents(docs, augment=False):
    """Convert ingested CorefUD documents to the json format used for coref training.

    Args:
      docs: iterable of (stanza Document, document id, language code) triples
      augment (bool): when True, apply random data augmentation (drop final
        tokens, lowercase sentence-initial words) — used for the train split

    Returns:
      list of dicts, one per document, with the words, sentence ids,
      dependency info, and the coref clusters parsed from the "Entity"
      annotations in the misc column.
    """
    processed_section = []

    for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)):
        # drop the last token 10% of the time
        if augment:
            for i in doc.sentences:
                if len(i.words) > 1:
                    if augment_random.random() < 0.1:
                        i.tokens = i.tokens[:-1]
                        i.words = i.words[:-1]

        # extract the entities
        # get sentence words and lengths
        sentences = [[j.text for j in i.words]
                     for i in doc.sentences]
        sentence_lens = [len(x.words) for x in doc.sentences]

        cased_words = []
        for x in sentences:
            if augment:
                # modify case of the first word with 50% chance
                if augment_random.random() < 0.5:
                    x[0] = x[0].lower()

            for y in x:
                cased_words.append(y)

        # sentence index of each word, flattened over the whole document
        sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]

        word_total = 0
        heads = []
        # TODO: does SD vs UD matter?
        deprel = []
        for sentence in doc.sentences:
            for word in sentence.words:
                deprel.append(word.deprel)
                if word.head == 0:
                    # the sentence root has no head; marked with the string "null"
                    heads.append("null")
                else:
                    # document-level 0-based index of the head word
                    heads.append(word.head - 1 + word_total)
            word_total += len(sentence.words)

        span_clusters = defaultdict(list)
        word_clusters = defaultdict(list)
        head2span = []
        word_total = 0
        # matches either an opening tag like "(e123" or a closing tag like "123)"
        SPANS = re.compile(r"(\(\w+|[%\w]+\))")
        for parsed_sentence in doc.sentences:
            # spans regex
            # parse the misc column, leaving on "Entity" entries
            misc = [[k.split("=")
                     for k in j
                     if k.split("=")[0] == "Entity"]
                    for i in parsed_sentence.words
                    for j in [i.misc.split("|") if i.misc else []]]
            # and extract the Entity entry values
            entities = [i[0][1] if len(i) > 0 else None for i in misc]
            # extract reference information
            refs = [SPANS.findall(i) if i else [] for i in entities]
            # and calculate spans: the basic rule is (e... begins a reference
            # and ) without e before ends the most recent reference
            # every single time we get a closing element, we pop it off
            # the refdict and insert the pair to final_refs
            refdict = defaultdict(list)
            final_refs = defaultdict(list)
            last_ref = None
            for indx, i in enumerate(refs):
                for j in i:
                    # this is the beginning of a reference
                    if j[0] == "(":
                        refdict[j[1+UDCOREF_ADDN:]].append(indx)
                        last_ref = j[1+UDCOREF_ADDN:]
                    # at the end of a reference, if we got exxxxx, that ends
                    # a particular refereenc; otherwise, it ends the last reference
                    elif j[-1] == ")" and j[UDCOREF_ADDN:-1].isnumeric():
                        if (not UDCOREF_ADDN) or j[0] == "e":
                            try:
                                final_refs[j[UDCOREF_ADDN:-1]].append((refdict[j[UDCOREF_ADDN:-1]].pop(-1), indx))
                            except IndexError:
                                # this is probably zero anaphora
                                continue
                    elif j[-1] == ")":
                        final_refs[last_ref].append((refdict[last_ref].pop(-1), indx))
                        last_ref = None
            final_refs = dict(final_refs)
            # convert it to the right format (specifically, in (ref, start, end) tuples)
            coref_spans = []
            for k, v in final_refs.items():
                for i in v:
                    coref_spans.append([int(k), i[0], i[1]])
            sentence_upos = [x.upos for x in parsed_sentence.words]
            sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]
            for span in coref_spans:
                # input is expected to be start word, end word + 1
                # counting from 0
                # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
                span_start = span[1] + word_total
                span_end = span[2] + word_total + 1
                candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1)
                if candidate_head is None:
                    for candidate_head in range(span[1], span[2] + 1):
                        # stanza uses 0 to mark the head, whereas OntoNotes is counting
                        # words from 0, so we have to subtract 1 from the stanza heads
                        #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)
                        # treat the head of the phrase as the first word that has a head outside the phrase
                        if (parsed_sentence.words[candidate_head].head - 1 < span[1] or
                            parsed_sentence.words[candidate_head].head - 1 > span[2]):
                            break
                    else:
                        # if none have a head outside the phrase (circular??)
                        # then just take the first word
                        candidate_head = span[1]
                #print("----> %d" % candidate_head)
                candidate_head += word_total
                span_clusters[span[0]].append((span_start, span_end))
                word_clusters[span[0]].append(candidate_head)
                head2span.append((candidate_head, span_start, span_end))
            word_total += len(parsed_sentence.words)
        span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
        word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
        head2span = sorted(head2span)

        processed = {
            "document_id": doc_id,
            "cased_words": cased_words,
            "sent_id": sent_id,
            "part_id": idx,
            # "pos": pos,
            "deprel": deprel,
            "head": heads,
            "span_clusters": span_clusters,
            "word_clusters": word_clusters,
            "head2span": head2span,
            "lang": lang
        }
        processed_section.append(processed)
    return processed_section
162
+
163
def process_dataset(short_name, coref_output_path, split_test, train_files, dev_files):
    """Ingest conllu train/dev files and write json train/dev (and maybe test) splits.

    Args:
      short_name: dataset name used to build the output filenames
      coref_output_path: directory where the json files are written
      split_test: fraction of train documents to hold out as a test split
        (None/0 disables the extra split)
      train_files, dev_files: lists of conllu filenames to ingest
    """
    section_names = ('train', 'dev')
    section_filenames = [train_files, dev_files]
    sections = []

    test_sections = []

    for section, filenames in zip(section_names, section_filenames):
        input_file = []
        for load in filenames:
            # the language code is the filename prefix, eg "en" in en_gum-...conllu
            lang = load.split("/")[-1].split("_")[0]
            print("Ingesting %s from %s of lang %s" % (section, load, lang))
            docs = CoNLL.conll2multi_docs(load)
            print(" Ingested %d documents" % len(docs))
            if split_test and section == 'train':
                test_section = []
                train_section = []
                for i in docs:
                    # reseed for each doc so that we can attempt to keep things stable in the event
                    # of different file orderings or some change to the number of documents
                    # (note: this assignment makes split_random a local variable,
                    # shadowing the module-level split_random inside this function)
                    split_random = Random(i.sentences[0].doc_id + i.sentences[0].text)
                    if split_random.random() < split_test:
                        test_section.append((i, i.sentences[0].doc_id, lang))
                    else:
                        train_section.append((i, i.sentences[0].doc_id, lang))
                # guarantee at least one test document per input file when possible
                if len(test_section) == 0 and len(train_section) >= 2:
                    idx = split_random.randint(0, len(train_section) - 1)
                    test_section = [train_section[idx]]
                    train_section = train_section[:idx] + train_section[idx+1:]
                print(" Splitting %d documents from %s for test" % (len(test_section), load))
                input_file.extend(train_section)
                test_sections.append(test_section)
            else:
                for i in docs:
                    input_file.append((i, i.sentences[0].doc_id, lang))
        print("Ingested %d total documents" % len(input_file))
        sections.append(input_file)

    if split_test:
        section_names = ('train', 'dev', 'test')
        full_test_section = []
        # NOTE(review): `filenames` is left over from the loop above (the dev
        # file list); only `test_section` is used here, but if there are more
        # train files than dev files, zip truncates and some held-out test
        # documents would be silently dropped — confirm intent
        for filename, test_section in zip(filenames, test_sections):
            # TODO: could write dataset-specific test sections as well
            full_test_section.extend(test_section)
        sections.append(full_test_section)


    for section_data, section_name in zip(sections, section_names):
        # only the train split is augmented
        converted_section = process_documents(section_data, augment=(section_name=="train"))

        os.makedirs(coref_output_path, exist_ok=True)
        output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section_name))
        with open(output_filename, "w", encoding="utf-8") as fout:
            json.dump(converted_section, fout, indent=2)
217
+
218
def get_dataset_by_language(coref_input_path, langs):
    """Collect the CorefUD train/dev conllu files for the given languages.

    Searches the CorefUD-1.2-public data directory for dataset folders whose
    names contain any of the given language names.  Returns two sorted
    lists: (train filenames, dev filenames).
    """
    data_root = os.path.join(coref_input_path, "CorefUD-1.2-public", "data")

    def _collect(split):
        # one glob per language, eg */*Polish*/*train.conllu
        found = []
        for language in langs:
            pattern = os.path.join(data_root, "*%s*" % language, "*%s.conllu" % split)
            found.extend(glob.glob(pattern))
        return sorted(found)

    return _collect("train"), _collect("dev")
228
+
229
def main():
    """Command line entry point for converting CorefUD data to the json format.

    Exactly one of --directory (a folder of *train.conllu / *dev.conllu
    files) or --project (a named group of CorefUD languages) is required.
    """
    paths = get_default_paths()
    parser = argparse.ArgumentParser(
        prog='Convert UDCoref Data',
    )
    parser.add_argument('--split_test', default=None, type=float, help='How much of the data to randomly split from train to make a test set')

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--directory', type=str, help="the name of the subfolder for data conversion")
    group.add_argument('--project', type=str, help="Look for and use a set of datasets for data conversion - Slavic or Hungarian")

    args = parser.parse_args()
    coref_input_path = paths['COREF_BASE']
    coref_output_path = paths['COREF_DATA_DIR']

    if args.project:
        if args.project == 'slavic':
            project = "slavic_udcoref"
            langs = ('Polish', 'Russian', 'Czech')
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
        elif args.project == 'hungarian':
            project = "hu_udcoref"
            langs = ('Hungarian',)
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
        elif args.project == 'gerrom':
            project = "gerrom_udcoref"
            langs = ('Catalan', 'English', 'French', 'German', 'Norwegian', 'Spanish')
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
        elif args.project == 'germanic':
            project = "germanic_udcoref"
            langs = ('English', 'German', 'Norwegian')
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
        elif args.project == 'norwegian':
            project = "norwegian_udcoref"
            langs = ('Norwegian',)
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
        # NOTE(review): an unrecognized --project value leaves `project` and
        # the filename lists unbound and fails below with NameError;
        # presumably an explicit ValueError was intended — confirm
    else:
        project = args.directory
        conll_path = os.path.join(coref_input_path, project)
        # fall back to treating --directory as a direct path when it does not
        # exist under the coref input path
        if not os.path.exists(conll_path) and os.path.exists(project):
            conll_path = args.directory
        train_filenames = sorted(glob.glob(os.path.join(conll_path, f"*train.conllu")))
        dev_filenames = sorted(glob.glob(os.path.join(conll_path, f"*dev.conllu")))
    process_dataset(project, coref_output_path, args.split_test, train_filenames, dev_filenames)

if __name__ == '__main__':
    main()
276
+
stanza/stanza/utils/datasets/coref/utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from functools import lru_cache
3
+
4
class DynamicDepth():
    """
    Computes, for every word in a subphrase, its depth relative to the local
    root(s) of the dependency parse restricted to that subphrase.

    A word whose head lies outside [start, end) — or which has no head at
    all — counts as a local root with depth 0.
    """
    def get_parse_depths(self, heads, start, end):
        """Return the relative depth for every word

        Args:
            heads (list): List where each entry is the index of that entry's head word in the dependency parse (or None for the root)
            start (int): starting index of the heads for the subphrase
            end (int): ending index (exclusive) of the heads for the subphrase

        Returns:
            list: Relative depth in the dependency parse for every word in [start, end)
        """
        self.heads = heads[start:end]
        # -100 marks words with no head at all.  Note the explicit
        # "is not None" check: a head of 0 is a real head (the sentence's
        # first word), not a missing one — the previous truthiness test
        # misclassified it
        self.relative_heads = [h - start if h is not None else -100 for h in self.heads]

        # per-call memo table: the cache must not survive between calls,
        # since relative_heads changes every time (the previous lru_cache on
        # the method cached across calls and also kept instances alive)
        self._depth_cache = {}
        depths = [self._get_depth_recursive(h) for h in range(len(self.relative_heads))]

        return depths

    def _get_depth_recursive(self, index):
        """Recursively get the depth of the word at index, memoized per call

        Args:
            index (int): Index of the word for which to calculate the relative depth

        Returns:
            int: Relative depth of the word at the index
        """
        cached = self._depth_cache.get(index)
        if cached is not None:
            return cached
        # if the head for the current index is outside the scope, this index is a relative root
        if self.relative_heads[index] >= len(self.relative_heads) or self.relative_heads[index] < 0:
            depth = 0
        else:
            depth = self._get_depth_recursive(self.relative_heads[index]) + 1
        self._depth_cache[index] = depth
        return depth
40
+
41
def find_cconj_head(heads, upos, start, end):
    """
    Pick a coordinating conjunction near the top of a span to act as its head.

    Computes the depth of every word in the span relative to the span's local
    root, then returns the sentence index of the first CCONJ shallower than
    the depth limit.  Returns None when no such conjunction exists.
    """
    # use head information to extract parse depth
    depths = DynamicDepth().get_parse_depths(heads, start, end)
    depth_limit = 2

    # return the first 'CCONJ' token above the depth limit, if one exists
    # unlike the original paper, we expect the parses to use UPOS, hence CCONJ instead of CC
    for offset, depth in enumerate(depths):
        if depth < depth_limit and upos[start + offset] == 'CCONJ':
            return start + offset
    return None
58
+
59
def process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=True):
    """
    Parse one document with the given pipeline and package it for coref training.

    coref_spans: a list of lists
    one list per sentence
    each sentence has a list of spans, where each span is (span_index, span_start, span_end)
    sentence_speakers: either one speaker per sentence, or one list of
      per-word speakers per sentence
    use_cconj_heads: prefer a near-root CCONJ as a span's head word
      (see find_cconj_head)

    Returns a dict with the document's words, speakers, dependency info,
    and the coref clusters remapped to document-level word indices.
    """
    sentence_lens = [len(x) for x in sentences]
    # speakers may come one-per-word (list of lists) or one-per-sentence
    if all(isinstance(x, list) for x in sentence_speakers):
        speaker = [y for x in sentence_speakers for y in x]
    else:
        speaker = [y for x, sent_len in zip(sentence_speakers, sentence_lens) for y in [x] * sent_len]

    cased_words = [y for x in sentences for y in x]
    sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]

    # use the trees to get the xpos tags
    # alternatively, could translate the pos_tags field,
    # but those have numbers, which is annoying
    #tree_text = "\n".join(x['parse_tree'] for x in paragraph)
    #trees = tree_reader.read_trees(tree_text)
    #pos = [x.label for tree in trees for x in tree.yield_preterminals()]
    # actually, the downstream code doesn't use pos at all. maybe we can skip?

    doc = pipe(sentences)
    word_total = 0
    heads = []
    # TODO: does SD vs UD matter?
    deprel = []
    for sentence in doc.sentences:
        for word in sentence.words:
            deprel.append(word.deprel)
            if word.head == 0:
                # the sentence root has no head; marked with the string "null"
                heads.append("null")
            else:
                # document-level 0-based index of the head word
                heads.append(word.head - 1 + word_total)
        word_total += len(sentence.words)

    span_clusters = defaultdict(list)
    word_clusters = defaultdict(list)
    head2span = []
    word_total = 0
    for parsed_sentence, ontonotes_coref, ontonotes_words in zip(doc.sentences, coref_spans, sentences):
        sentence_upos = [x.upos for x in parsed_sentence.words]
        sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]
        for span in ontonotes_coref:
            # input is expected to be start word, end word + 1
            # counting from 0
            # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
            span_start = span[1] + word_total
            span_end = span[2] + word_total + 1
            candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) if use_cconj_heads else None
            if candidate_head is None:
                for candidate_head in range(span[1], span[2] + 1):
                    # stanza uses 0 to mark the head, whereas OntoNotes is counting
                    # words from 0, so we have to subtract 1 from the stanza heads
                    #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)
                    # treat the head of the phrase as the first word that has a head outside the phrase
                    if (parsed_sentence.words[candidate_head].head - 1 < span[1] or
                        parsed_sentence.words[candidate_head].head - 1 > span[2]):
                        break
                else:
                    # if none have a head outside the phrase (circular??)
                    # then just take the first word
                    candidate_head = span[1]
            #print("----> %d" % candidate_head)
            candidate_head += word_total
            span_clusters[span[0]].append((span_start, span_end))
            word_clusters[span[0]].append(candidate_head)
            head2span.append((candidate_head, span_start, span_end))
        word_total += len(ontonotes_words)
    span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
    word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
    head2span = sorted(head2span)

    processed = {
        "document_id": doc_id,
        "part_id": part_id,
        "cased_words": cased_words,
        "sent_id": sent_id,
        "speaker": speaker,
        #"pos": pos,
        "deprel": deprel,
        "head": heads,
        "span_clusters": span_clusters,
        "word_clusters": word_clusters,
        "head2span": head2span,
    }
    if part_id is not None:
        # NOTE(review): "part_id" was already set unconditionally above,
        # so this reassignment is a no-op — confirm whether it was meant
        # to delete the key when part_id is None
        processed["part_id"] = part_id
    return processed
stanza/stanza/utils/datasets/corenlp_segmenter_dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Output a treebank's sentences in a form that can be processed by the CoreNLP CRF Segmenter
3
+
4
+ Run it as
5
+ python3 -m stanza.utils.datasets.corenlp_segmenter_dataset <treebank>
6
+ such as
7
+ python3 -m stanza.utils.datasets.corenlp_segmenter_dataset UD_Chinese-GSDSimp --output_dir $CHINESE_SEGMENTER_HOME
8
+ """
9
+
10
+ import argparse
11
+ import os
12
+ import sys
13
+ import tempfile
14
+
15
+ import stanza.utils.datasets.common as common
16
+ import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
17
+ import stanza.utils.default_paths as default_paths
18
+
19
+ from stanza.models.common.constant import treebank_to_short_name
20
+
21
def build_argparse():
    """Build the argument parser for the segmenter dataset conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, default='.',
                        help='Where to put the results')
    parser.add_argument('treebanks', type=str, nargs='*', default=["UD_Chinese-GSDSimp"],
                        help='Which treebanks to run on')
    return parser
26
+
27
+
28
def write_segmenter_file(output_filename, dataset):
    """Write sentences in the whitespace-separated format the CoreNLP segmenter expects.

    dataset: an iterable of sentences, each a list of conllu lines.
    Comment lines, blank lines, and multi-word token range lines are
    dropped; the remaining word forms are joined with spaces, one
    sentence per output line.
    """
    # write UTF-8 explicitly: this data is typically Chinese and must not
    # depend on the platform default encoding
    with open(output_filename, "w", encoding="utf-8") as fout:
        for sentence in dataset:
            lines = [x.strip() for x in sentence if not x.startswith("#")]
            lines = [x for x in lines if x]
            # eliminate MWT, although Chinese currently doesn't have any
            lines = [x for x in lines if "-" not in x.split("\t")[0]]

            # column 2 of each conllu line is the word form
            text = " ".join(x.split("\t")[1] for x in lines)
            fout.write(text)
            fout.write("\n")
39
+
40
def process_treebank(treebank, model_type, paths, output_dir):
    """Convert one UD treebank into train/test files for the CoreNLP segmenter.

    The tokenizer data is first prepared into a temporary directory, then the
    train+dev and test splits are rewritten in the segmenter's format.
    """
    with tempfile.TemporaryDirectory() as tokenizer_dir:
        local_paths = dict(paths)
        local_paths["TOKENIZE_DATA_DIR"] = tokenizer_dir

        short_name = treebank_to_short_name(treebank)

        # first we process the tokenization data
        prepare_args = argparse.Namespace()
        prepare_args.augment = False
        prepare_args.prepare_labels = False
        prepare_tokenizer_treebank.process_treebank(treebank, model_type, local_paths, prepare_args)

        # TODO: these names should be refactored
        splits = {
            split: common.read_sentences_from_conllu(
                os.path.join(tokenizer_dir, "%s.%s.gold.conllu" % (short_name, split)))
            for split in ("train", "dev", "test")
        }

        train_out = os.path.join(output_dir, "%s.train.seg.txt" % short_name)
        test_out = os.path.join(output_dir, "%s.test.seg.txt" % short_name)

        # dev is folded into the training data; test stays separate
        write_segmenter_file(train_out, splits["train"] + splits["dev"])
        write_segmenter_file(test_out, splits["test"])
67
+
68
def main():
    """Entry point: prepare segmenter data for each requested treebank."""
    args = build_argparse().parse_args()

    paths = default_paths.get_default_paths()
    for treebank in args.treebanks:
        process_treebank(treebank, common.ModelType.TOKENIZER, paths, args.output_dir)

if __name__ == '__main__':
    main()
78
+
stanza/stanza/utils/datasets/ner/convert_bsnlp.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import logging
5
+ import random
6
+ import re
7
+
8
+ import stanza
9
+
10
logger = logging.getLogger('stanza')

# language codes this converter knows how to handle
AVAILABLE_LANGUAGES = ("bg", "cs", "pl", "ru")
13
+
14
+ def normalize_bg_entity(text, entity, raw):
15
+ entity = entity.strip()
16
+ # sanity check that the token is in the original text
17
+ if text.find(entity) >= 0:
18
+ return entity
19
+
20
+ # some entities have quotes, but the quotes are different from those in the data file
21
+ # for example:
22
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_458.txt
23
+ # 'Съвета "Общи въпроси"'
24
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1002.txt
25
+ # 'Съвет "Общи въпроси"'
26
+ if sum(1 for x in entity if x == '"') == 2:
27
+ quote_entity = entity.replace('"', '“')
28
+ if text.find(quote_entity) >= 0:
29
+ logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
30
+ return quote_entity
31
+
32
+ quote_entity = entity.replace('"', '„', 1).replace('"', '“')
33
+ if text.find(quote_entity) >= 0:
34
+ logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
35
+ return quote_entity
36
+
37
+ if sum(1 for x in entity if x == '"') == 1:
38
+ quote_entity = entity.replace('"', '„', 1)
39
+ if text.find(quote_entity) >= 0:
40
+ logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
41
+ return quote_entity
42
+
43
+ if entity.find("'") >= 0:
44
+ quote_entity = entity.replace("'", "’")
45
+ if text.find(quote_entity) >= 0:
46
+ logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
47
+ return quote_entity
48
+
49
+ lower_idx = text.lower().find(entity.lower())
50
+ if lower_idx >= 0:
51
+ fixed_entity = text[lower_idx:lower_idx+len(entity)]
52
+ logger.info("lowercase match found. Searching for '%s' instead of '%s' in %s" % (fixed_entity, entity, raw))
53
+ return fixed_entity
54
+
55
+ substitution_pairs = {
56
+ # this exact error happens in:
57
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_67.txt
58
+ 'Съвет по общи въпроси': 'Съвета по общи въпроси',
59
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_214.txt
60
+ 'Сумимото Мицуи файненшъл груп': 'Сумитомо Мицуи файненшъл груп',
61
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_245.txt
62
+ 'С и Д': 'С&Д',
63
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_348.txt
64
+ 'законопроекта за излизане на Великобритания за излизане от Европейския съюз': 'законопроекта за излизане на Великобритания от Европейския съюз',
65
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_771.txt
66
+ 'Унивеситета в Есекс': 'Университета в Есекс',
67
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_779.txt
68
+ 'Съвет за сигурност на ООН': 'Съвета за сигурност на ООН',
69
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_787.txt
70
+ 'Федерика Могерини': 'Федереика Могерини',
71
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_938.txt
72
+ 'Уайстейбъл': 'Уайтстейбъл',
73
+ 'Партията за независимост на Обединеното кралство': 'Партията на независимостта на Обединеното кралство',
74
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_972.txt
75
+ 'Европейска банка за възстановяване и развитие': 'Европейската банка за възстановяване и развитие',
76
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1065.txt
77
+ 'Харолд Уилсон': 'Харолд Уилсън',
78
+ 'Манчестърски университет': 'Манчестърския университет',
79
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1096.txt
80
+ 'Обединеното кралство в променящата се Европа': 'Обединеното кралство в променяща се Европа',
81
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1175.txt
82
+ 'The Daily Express': 'Daily Express',
83
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1186.txt
84
+ 'демократичната юнионистка партия': 'демократична юнионистка партия',
85
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1192.txt
86
+ 'Европейската агенция за безопасността на полетите': 'Европейската агенция за сигурността на полетите',
87
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1219.txt
88
+ 'пресцентъра на Външно министертво': 'пресцентъра на Външно министерство',
89
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1281.txt
90
+ 'Европейска агенциа за безопасността на полетите': 'Европейската агенция за сигурността на полетите',
91
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1415.txt
92
+ 'Хонк Конг': 'Хонг Конг',
93
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1663.txt
94
+ 'Лейбъристка партия': 'Лейбъристката партия',
95
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1963.txt
96
+ 'Найджъл Фараж': 'Найджъл Фарадж',
97
+ 'Фараж': 'Фарадж',
98
+
99
+ # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1773.txt has an entity which is mixed Cyrillic and Ascii
100
+ 'Tescо': 'Tesco',
101
+ }
102
+
103
+ if entity in substitution_pairs and text.find(substitution_pairs[entity]) >= 0:
104
+ fixed_entity = substitution_pairs[entity]
105
+ logger.info("searching for '%s' instead of '%s' in %s" % (fixed_entity, entity, raw))
106
+ return fixed_entity
107
+
108
+ # oops, can't find it anywhere
109
+ # want to raise ValueError but there are just too many in the train set for BG
110
+ logger.error("Could not find '%s' in %s" % (entity, raw))
111
+
112
def fix_bg_typos(text, raw_filename):
    """Repair known typos / mixed-alphabet artifacts in one raw BG file.

    The corrections are keyed by the basename of raw_filename, so only the
    specific file a problem was observed in gets patched.  Unknown files
    pass through unchanged.
    """
    corrections = {
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_202.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters
        'brexit_bg.txt_file_202.txt': ('Вlооmbеrg', 'Bloomberg'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_261.txt has a typo: Telegaph instead of Telegraph
        'brexit_bg.txt_file_261.txt': ('Telegaph', 'Telegraph'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_574.txt has a typo: politicalskrapbook instead of politicalscrapbook
        'brexit_bg.txt_file_574.txt': ('politicalskrapbook', 'politicalscrapbook'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_861.txt has a mix of cyrillic and ascii
        'brexit_bg.txt_file_861.txt': ('Съвета „Общи въпроси“', 'Съветa "Общи въпроси"'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_992.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters
        'brexit_bg.txt_file_992.txt': ('The Guardiаn', 'The Guardian'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1856.txt has a typo: Southerb instead of Southern
        'brexit_bg.txt_file_1856.txt': ('Southerb', 'Southern'),
    }

    basename = os.path.split(raw_filename)[1]
    correction = corrections.get(basename)
    if correction is not None:
        broken, fixed = correction
        text = text.replace(broken, fixed)
    return text
134
+
135
def get_sentences(language, pipeline, annotated, raw):
    """
    Read one (annotated, raw) BSNLP file pair and return NER-labeled sentences.

    The raw file's text (everything after the 4 header lines) is tokenized
    with the given pipeline, each entity listed in the annotated file is
    located in the text, and the covering tokens get B-/I- NER tags.
    Sentences where a match crosses a sentence boundary or cuts through the
    middle of a token are recorded as bad and dropped from the result.
    """
    if language == 'bg':
        normalize_entity = normalize_bg_entity
        fix_typos = fix_bg_typos
    else:
        # fixed: the original passed a single string to a format with two %s
        # placeholders, which raised TypeError instead of this intended error
        raise AssertionError("Please build a normalize_%s_entity and fix_%s_typos first" % (language, language))

    annotated_sentences = []
    with open(raw) as fin:
        lines = fin.readlines()
        if len(lines) < 5:
            raise ValueError("Unexpected format in %s" % raw)
        # the first 4 lines are metadata; the rest is the document text
        text = "\n".join(lines[4:])
        text = fix_typos(text, raw)

    # entity text -> NER tag, as read from the annotation file
    entities = {}
    with open(annotated) as fin:
        # the first line is a header and should contain no tabs
        header = fin.readline().strip()
        if len(header.split("\t")) > 1:
            raise ValueError("Unexpected missing header line in %s" % annotated)
        for line in fin:
            pieces = line.strip().split("\t")
            if len(pieces) < 3 or len(pieces) > 4:
                raise ValueError("Unexpected annotation format in %s" % annotated)

            entity = normalize_entity(text, pieces[0], raw)
            if not entity:
                continue
            if entity in entities:
                if entities[entity] != pieces[2]:
                    # would like to make this an error, but it actually happens and it's not clear how to fix
                    # annotated/nord_stream/bg/nord_stream_bg.txt_file_119.out
                    logger.warning("found multiple definitions for %s in %s" % (pieces[0], annotated))
                    entities[entity] = pieces[2]
            else:
                entities[entity] = pieces[2]

    tokenized = pipeline(text)
    # The benefit of doing these one at a time, instead of all at once,
    # is that nested entities won't clobber previously labeled entities.
    # For example, the file
    # training_pl_cs_ru_bg_rc1/annotated/bg/brexit_bg.txt_file_994.out
    # has each of:
    #   Северна Ирландия
    #   Република Ирландия
    #   Ирландия
    # By doing the larger ones first, we can detect and skip the ones
    # we already labeled when we reach the shorter one
    regexes = [re.compile(re.escape(x)) for x in sorted(entities.keys(), key=len, reverse=True)]

    bad_sentences = set()

    for regex in regexes:
        for match in regex.finditer(text):
            start_char, end_char = match.span()
            # this is inefficient, but for something only run once, it shouldn't matter
            start_token = None
            start_sloppy = False
            end_token = None
            end_sloppy = False
            for token in tokenized.iter_tokens():
                if token.start_char <= start_char and token.end_char > start_char:
                    start_token = token
                    if token.start_char != start_char:
                        start_sloppy = True
                if token.start_char <= end_char and token.end_char >= end_char:
                    end_token = token
                    if token.end_char != end_char:
                        end_sloppy = True
                    break
            if start_token is None or end_token is None:
                raise RuntimeError("Match %s did not align with any tokens in %s" % (match.group(0), raw))
            if not start_token.sent is end_token.sent:
                bad_sentences.add(start_token.sent.id)
                bad_sentences.add(end_token.sent.id)
                logger.warning("match %s spanned sentences %d and %d in document %s" % (match.group(0), start_token.sent.id, end_token.sent.id, raw))
                continue

            # ids start at 1, not 0, so we have to subtract 1
            # then the end token is included, so we add back the 1
            # TODO: verify that this is correct if the language has MWE - cs, pl, for example
            tokens = start_token.sent.tokens[start_token.id[0]-1:end_token.id[0]]
            if all(token.ner for token in tokens):
                # skip matches which have already been made
                # this has the nice side effect of not complaining if
                # a smaller match is found after a larger match
                # earlier set the NER on those tokens
                continue

            if start_sloppy and end_sloppy:
                bad_sentences.add(start_token.sent.id)
                logger.warning("match %s matched in the middle of a token in %s" % (match.group(0), raw))
                continue
            if start_sloppy:
                bad_sentences.add(end_token.sent.id)
                logger.warning("match %s started matching in the middle of a token in %s" % (match.group(0), raw))
                continue
            if end_sloppy:
                bad_sentences.add(start_token.sent.id)
                logger.warning("match %s ended matching in the middle of a token in %s" % (match.group(0), raw))
                continue
            match_text = match.group(0)
            if match_text not in entities:
                raise RuntimeError("Matched %s, which is not in the entities from %s" % (match_text, annotated))
            ner_tag = entities[match_text]
            tokens[0].ner = "B-" + ner_tag
            for token in tokens[1:]:
                token.ner = "I-" + ner_tag

    for sentence in tokenized.sentences:
        if not sentence.id in bad_sentences:
            annotated_sentences.append(sentence)

    return annotated_sentences
256
+
257
def write_sentences(output_filename, annotated_sentences):
    """Write sentences as token<TAB>tag lines, blank line between sentences.

    Tokens with no NER tag set are written as O.
    """
    logger.info("Writing %d sentences to %s" % (len(annotated_sentences), output_filename))
    with open(output_filename, "w") as fout:
        for sentence in annotated_sentences:
            rows = ["%s\t%s\n" % (token.text, token.ner if token.ner else "O")
                    for token in sentence.tokens]
            fout.writelines(rows)
            fout.write("\n")
267
+
268
+
269
def convert_bsnlp(language, base_input_path, output_filename, split_filename=None):
    """
    Converts the BSNLP dataset for the given language.

    If only one output_filename is provided, all of the output goes to that file.
    If split_filename is provided as well, 15% of the output chosen randomly
    goes there instead.  The dataset has no dev set, so this helps
    divide the data into train/dev/test.
    Note that the custom error fixes are only done for BG currently.
    Please manually correct the data as appropriate before using this
    for another language.
    """
    if language not in AVAILABLE_LANGUAGES:
        raise ValueError("The current BSNLP datasets only include the following languages: %s" % ",".join(AVAILABLE_LANGUAGES))
    if language != "bg":
        raise ValueError("There were quite a few data fixes needed to get the data correct for BG. Please work on similar fixes before using the model for %s" % language.upper())
    pipeline = stanza.Pipeline(language, processors="tokenize")
    # fixed seed so the 85/15 train/dev split below is reproducible
    random.seed(1234)

    annotated_path = os.path.join(base_input_path, "annotated", "*", language, "*")
    annotated_files = sorted(glob.glob(annotated_path))
    raw_path = os.path.join(base_input_path, "raw", "*", language, "*")
    raw_files = sorted(glob.glob(raw_path))

    # if the instructions for downloading the data from the
    # process_ner_dataset script are followed, there will be two test
    # directories of data and a separate training directory of data.
    if len(annotated_files) == 0 and len(raw_files) == 0:
        logger.info("Could not find files in %s" % annotated_path)
        # fall back to a flatter layout without the intermediate dataset directory
        annotated_path = os.path.join(base_input_path, "annotated", language, "*")
        logger.info("Trying %s instead" % annotated_path)
        annotated_files = sorted(glob.glob(annotated_path))
        raw_path = os.path.join(base_input_path, "raw", language, "*")
        raw_files = sorted(glob.glob(raw_path))

    if len(annotated_files) != len(raw_files):
        raise ValueError("Unexpected differences in the file lists between %s and %s" % (annotated_files, raw_files))

    # annotated/raw files must pair up: compare basenames with the
    # 4-character extensions (.out / .txt) stripped
    for i, j in zip(annotated_files, raw_files):
        if os.path.split(i)[1][:-4] != os.path.split(j)[1][:-4]:
            raise ValueError("Unexpected differences in the file lists: found %s instead of %s" % (i, j))

    annotated_sentences = []
    if split_filename:
        split_sentences = []
    for annotated, raw in zip(annotated_files, raw_files):
        new_sentences = get_sentences(language, pipeline, annotated, raw)
        # the split is per-file, not per-sentence: all sentences from one
        # document land in the same output file
        if not split_filename or random.random() < 0.85:
            annotated_sentences.extend(new_sentences)
        else:
            split_sentences.extend(new_sentences)

    write_sentences(output_filename, annotated_sentences)
    if split_filename:
        write_sentences(split_filename, split_sentences)
324
+
325
if __name__ == '__main__':
    # command line entry point; --dev_path is optional and, when given,
    # receives roughly 15% of the converted data
    parser = argparse.ArgumentParser()
    parser.add_argument('--language', type=str, default="bg", help="Language to process")
    parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/bsnlp2019", help="Where to find the files")
    parser.add_argument('--output_path', type=str, default="/home/john/stanza/data/ner/bg_bsnlp.test.csv", help="Where to output the results")
    parser.add_argument('--dev_path', type=str, default=None, help="A secondary output path - 15% of the data will go here")
    args = parser.parse_args()

    convert_bsnlp(args.language, args.input_path, args.output_path, args.dev_path)
stanza/stanza/utils/datasets/ner/convert_fire_2013.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Converts the FIRE 2013 dataset to TSV
3
+
4
+ http://au-kbc.org/nlp/NER-FIRE2013/index.html
5
+
6
+ The dataset is in six tab separated columns. The columns are
7
+
8
+ word tag chunk ner1 ner2 ner3
9
+
10
+ This script keeps just the word and the ner1. It is quite possible that using the tag would help
11
+ """
12
+
13
+ import argparse
14
+ import glob
15
+ import os
16
+ import random
17
+
18
def normalize(e1, e2, e3):
    """Collapse the three nested FIRE 2013 NER columns into one tag.

    Returns "O" for untagged tokens, keeps e1 for a whitelist of category /
    sub-category combinations, merges DATE/TIME/YEAR into DATETIME, and
    raises ValueError when the B-/I- prefixes of the columns disagree.
    """
    if e1 == 'o':
        return "O"

    prefix = e1[:2]
    if e2 != 'o' and prefix != e2[:2]:
        raise ValueError("Found a token with conflicting position tags %s,%s" % (e1, e2))
    if e3 != 'o' and e2 == 'o':
        raise ValueError("Found a token with tertiary label but no secondary label %s,%s,%s" % (e1, e2, e3))
    if e3 != 'o' and (prefix != e2[:2] or prefix != e3[:2]):
        raise ValueError("Found a token with conflicting position tags %s,%s,%s" % (e1, e2, e3))

    category = e1[2:]
    secondary = e2[2:]
    if category in ('ORGANIZATION', 'FACILITIES'):
        return e1
    if category == 'ENTERTAINMENT' and secondary not in ('SPORTS', 'CINEMA'):
        return e1
    if category == 'DISEASE' and e2 == 'o':
        return e1
    if category == 'PLANTS' and secondary != 'PARTS':
        return e1
    if category == 'PERSON' and secondary == 'INDIVIDUAL':
        return e1
    if category == 'LOCATION' and secondary == 'PLACE':
        return e1
    if category in ('DATE', 'TIME', 'YEAR'):
        return prefix + 'DATETIME'

    return "O"
46
+
47
def read_fileset(filenames):
    """Read the raw FIRE files and return sentences as lists of non-blank lines.

    Sentences are separated by blank lines; single-line "sentences" (of
    which the dataset has many) are discarded.
    """
    sentences = []
    for filename in filenames:
        current = []
        with open(filename) as fin:
            for raw_line in fin:
                stripped = raw_line.strip()
                if stripped:
                    current.append(stripped)
                    continue
                # blank line: flush the sentence, dropping one-liners
                if len(current) > 1:
                    sentences.append(current)
                current = []
        # flush whatever was left at end of file
        if len(current) > 1:
            sentences.append(current)
    return sentences
66
+
67
def write_fileset(output_csv_file, sentences):
    """Write sentences as two-column TSV: token and normalized top-level NER tag.

    Each input line must have exactly 6 tab-separated fields; raises
    ValueError on malformed lines or when an inner NER layer is labeled
    while the top layer is 'o'.
    """
    with open(output_csv_file, "w") as fout:
        for sentence in sentences:
            for line in sentence:
                pieces = line.split("\t")
                if len(pieces) != 6:
                    raise ValueError("Found %d pieces instead of the expected 6" % len(pieces))
                word, ner1, ner2, ner3 = pieces[0], pieces[3], pieces[4], pieces[5]
                if ner1 == 'o' and (ner2 != 'o' or ner3 != 'o'):
                    raise ValueError("Inner NER labeled but the top layer was O")
                fout.write("%s\t%s\n" % (word, normalize(ner1, ner2, ner3)))
            fout.write("\n")
78
+
79
def convert_fire_2013(input_path, train_csv_file, dev_csv_file, test_csv_file):
    """Read the FIRE 2013 files under input_path and write an 80/10/10 split."""
    # fixed seed so the shuffles (and therefore the split) are reproducible
    random.seed(1234)

    # won't be numerically sorted... shouldn't matter.  Sorting first makes
    # the shuffle depend only on the seed, not on the glob order
    filenames = sorted(glob.glob(os.path.join(input_path, "*")))
    random.shuffle(filenames)

    sentences = read_fileset(filenames)
    random.shuffle(sentences)

    train_cutoff = int(0.8 * len(sentences))
    dev_cutoff = int(0.9 * len(sentences))

    splits = [sentences[:train_cutoff],
              sentences[train_cutoff:dev_cutoff],
              sentences[dev_cutoff:]]

    for split in splits:
        random.shuffle(split)

    assert all(len(split) > 0 for split in splits)

    for filename, split in zip((train_csv_file, dev_csv_file, test_csv_file), splits):
        write_fileset(filename, split)
109
+
110
if __name__ == '__main__':
    # command line entry point; defaults point at a local copy of the
    # FIRE 2013 Hindi training data
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/FIRE2013/hindi_train", help="Directory with raw files to read")
    parser.add_argument('--train_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.train.csv", help="Where to put the train file")
    parser.add_argument('--dev_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.dev.csv", help="Where to put the dev file")
    parser.add_argument('--test_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.test.csv", help="Where to put the test file")
    args = parser.parse_args()

    convert_fire_2013(args.input_path, args.train_file, args.dev_file, args.test_file)
stanza/stanza/utils/datasets/ner/convert_hy_armtdp.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert a ArmTDP-NER dataset to BIO format
3
+
4
+ The dataset is here:
5
+
6
+ https://github.com/myavrum/ArmTDP-NER.git
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import json
12
+ import re
13
+ import stanza
14
+ import random
15
+ from tqdm import tqdm
16
+
17
+ from stanza import DownloadMethod, Pipeline
18
+ import stanza.utils.default_paths as default_paths
19
+
20
def read_data(path: str) -> list:
    """
    Reads the Armenian named entity recognition dataset.

    The file has one JSON object per line; each object describes a
    paragraph (text, labels, etc.).  Returns them as a list of dicts.
    """
    paragraphs = []
    with open(path, 'r') as fin:
        for line in fin:
            paragraphs.append(json.loads(line))
    return paragraphs
31
+
32
+
33
def filter_unicode_broken_characters(text: str) -> str:
    """Strip literal backslash-u escape sequences (e.g. "\\u0561") from text."""
    broken_escape = re.compile(r'\\u[A-Za-z0-9]{4}')
    return broken_escape.sub('', text)
38
+
39
+
40
def get_label(tok_start_char: int, tok_end_char: int, labels: list) -> list:
    """Return the first label span fully covering the token, or [] if none.

    Each label is a (start_char, end_char, tag) triple; a label matches when
    its span contains [tok_start_char, tok_end_char].
    """
    covering = [label for label in labels
                if label[0] <= tok_start_char and tok_end_char <= label[1]]
    return covering[0] if covering else []
48
+
49
+
50
def format_sentences(paragraphs: list, nlp_hy: Pipeline) -> list:
    """
    Takes a list of paragraphs and returns a list of sentences,
    where each sentence is a list of tokens along with their respective entity tags.

    Each paragraph dict provides 'text' and 'labels', where each label is a
    (start_char, end_char, tag) triple over the text.  Consecutive tokens
    covered by the same label are accumulated into one entity, which is
    closed when a token ends exactly at the label's end offset.

    NOTE(review): the labels are offsets into the original paragraph text,
    but tokenization runs on the filtered text - this assumes the \\uXXXX
    filtering does not shift offsets for labeled spans; verify on the data.
    """
    sentences = []
    for paragraph in tqdm(paragraphs):
        doc = nlp_hy(filter_unicode_broken_characters(paragraph['text']))
        for sentence in doc.sentences:
            sentence_ents = []
            entity = []
            for token in sentence.tokens:
                label = get_label(token.start_char, token.end_char, paragraph['labels'])
                if label:
                    entity.append(token.text)
                    # the label ends exactly at this token: flush the entity
                    if token.end_char == label[1]:
                        sentence_ents.append({'tokens': entity,
                                              'tag': label[2]})
                        entity = []
                else:
                    sentence_ents.append({'tokens': [token.text],
                                          'tag': 'O'})
            sentences.append(sentence_ents)
    return sentences
74
+
75
+
76
def convert_to_bioes(sentences: list) -> list:
    """
    Render each sentence as a BIOES-format string.

    Each input sentence is a list of {'tokens': [...], 'tag': ...} entities;
    the output string has one token<TAB>tag line per token.
    """
    formatted = []
    for sentence in tqdm(sentences):
        rows = []
        for ent in sentence:
            tokens = ent['tokens']
            tag = ent['tag']
            if tag == 'O':
                rows.append(tokens[0] + '\tO')
            elif len(tokens) == 1:
                # single-token entity
                rows.append(tokens[0] + '\tS-' + tag)
            else:
                rows.append(tokens[0] + '\tB-' + tag)
                rows.extend(token + '\tI-' + tag for token in tokens[1:-1])
                rows.append(tokens[-1] + '\tE-' + tag)
        formatted.append(''.join(row + '\n' for row in rows))
    return formatted
96
+
97
+
98
def write_sentences_to_file(sents, filename):
    """Write pre-formatted sentence blocks to filename, blank-line separated.

    Each element of sents is a string of token<TAB>tag lines; sentences are
    separated by a blank line.
    """
    # fixed: the log line printed a literal "(unknown)" placeholder instead
    # of the actual destination filename
    print(f"Writing {len(sents)} sentences to {filename}")
    with open(filename, 'w') as outfile:
        for sent in sents:
            outfile.write(sent + '\n\n')
103
+
104
+
105
def train_test_dev_split(sents, base_output_path, short_name, train_fraction=0.7, dev_fraction=0.15):
    """
    Splits a list of sentences into training, dev, and test sets,
    and writes each set to a separate file with write_sentences_to_file.

    Shuffles sents in place, then splits train_fraction / dev_fraction /
    remainder.  Raises ValueError if the fractions sum to more than 1.
    """
    # validate before doing any work; fixed: the original format string had
    # three placeholders but only two arguments, so the intended ValueError
    # itself crashed with an IndexError
    if train_fraction + dev_fraction > 1.0:
        raise ValueError(
            "Train and dev fractions added up to more than 1: {} {}".format(train_fraction, dev_fraction))

    num = len(sents)
    train_num = int(num * train_fraction)
    dev_num = int(num * dev_fraction)

    random.shuffle(sents)
    train_sents = sents[:train_num]
    dev_sents = sents[train_num:train_num + dev_num]
    test_sents = sents[train_num + dev_num:]
    batches = [train_sents, dev_sents, test_sents]
    filenames = [f'{short_name}.train.tsv', f'{short_name}.dev.tsv', f'{short_name}.test.tsv']
    for batch, filename in zip(batches, filenames):
        write_sentences_to_file(batch, os.path.join(base_output_path, filename))
125
+
126
+
127
def convert_dataset(base_input_path, base_output_path, short_name, download_method=DownloadMethod.DOWNLOAD_RESOURCES):
    """Convert the ArmTDP-NER json dataset to BIOES train/dev/test files.

    Builds an Armenian tokenization pipeline, reads ArmNER-HY.json1 from
    base_input_path, tags the tokenized sentences, renders them in BIOES
    format, and writes the three splits under base_output_path.
    """
    nlp_hy = stanza.Pipeline(lang='hy', processors='tokenize', download_method=download_method)
    paragraphs = read_data(os.path.join(base_input_path, 'ArmNER-HY.json1'))
    tagged_sentences = format_sentences(paragraphs, nlp_hy)
    beios_sentences = convert_to_bioes(tagged_sentences)
    train_test_dev_split(beios_sentences, base_output_path, short_name)
133
+
134
+
135
if __name__ == '__main__':
    # command line entry point; paths default to the standard stanza data layout
    paths = default_paths.get_default_paths()

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default=os.path.join(paths["NERBASE"], "armenian", "ArmTDP-NER"), help="Path to input file")
    parser.add_argument('--output_path', type=str, default=paths["NER_DATA_DIR"], help="Path to the output directory")
    parser.add_argument('--short_name', type=str, default="hy_armtdp", help="Name to identify the dataset and the model")
    parser.add_argument('--download_method', type=str, default=DownloadMethod.DOWNLOAD_RESOURCES, help="Download method for initializing the Pipeline. Default downloads the Armenian pipeline, --download_method NONE does not. Options: %s" % DownloadMethod._member_names_)
    args = parser.parse_args()

    convert_dataset(args.input_path, args.output_path, args.short_name, args.download_method)
stanza/stanza/utils/datasets/ner/convert_kk_kazNERD.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert a Kazakh NER dataset to our internal .json format
3
+ The dataset is here:
4
+
5
+ https://github.com/IS2AI/KazNERD/tree/main/KazNERD
6
+ """
7
+
8
+ import argparse
9
+ import os
10
+ import shutil
11
+ # import random
12
+
13
+ from stanza.utils.datasets.ner.utils import convert_bio_to_json, SHARDS
14
+
15
def convert_dataset(in_directory, out_directory, short_name):
    """
    Copy KazNERD's IOB2 train/valid/test files into our naming scheme,
    then convert the copies to the internal .json format.
    """
    source_names = ("IOB2_train.txt", "IOB2_valid.txt", "IOB2_test.txt")
    for shard, source_name in zip(SHARDS, source_names):
        source_path = os.path.join(in_directory, source_name)
        dest_path = os.path.join(out_directory, "%s.%s.bio" % (short_name, shard))
        shutil.copy(source_path, dest_path)
    convert_bio_to_json(out_directory, out_directory, short_name, "bio")
25
+
26
if __name__ == '__main__':
    # command line entry point; the dataset short name is fixed to kk_kazNERD
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default="/nlp/scr/aaydin/kazNERD/NER", help="Where to find the files")
    parser.add_argument('--output_path', type=str, default="/nlp/scr/aaydin/kazNERD/data/ner", help="Where to output the results")
    args = parser.parse_args()
    # in_path = '/nlp/scr/aaydin/kazNERD/NER'
    # out_path = '/nlp/scr/aaydin/kazNERD/NER/output'
    # convert_dataset(in_path, out_path)
    convert_dataset(args.input_path, args.output_path, "kk_kazNERD")
35
+
stanza/stanza/utils/datasets/ner/convert_my_ucsy.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processes the three pieces of the NER dataset we received from UCSY.
3
+
4
+ Requires the Myanmar tokenizer to exist, since the text is not already tokenized.
5
+
6
+ There are three files sent to us from UCSY, one each for train, dev, test
7
+ This script expects them to be in the ner directory with the names
8
+ $NERBASE/my_ucsy/Myanmar_NER_train.txt
9
+ $NERBASE/my_ucsy/Myanmar_NER_dev.txt
10
+ $NERBASE/my_ucsy/Myanmar_NER_test.txt
11
+
12
+ The files are in the following format:
13
+ unsegmentedtext@LABEL|unsegmentedtext@LABEL|...
14
+ with one sentence per line
15
+
16
+ Solution:
17
+ - break the text up into fragments by splitting on |
18
+ - extract the labels
19
+ - segment each block of text using the MY tokenizer
20
+
21
+ We could take two approaches to breaking up the blocks. One would be
22
+ to combine all chunks, then segment an entire sentence at once. This
23
+ would require some logic to re-chunk the resulting pieces. Instead,
24
+ we resegment each individual chunk by itself. This loses the
25
+ information from the neighboring chunks, but guarantees there are no
26
+ screwups where segmentation crosses segment boundaries and is simpler
27
+ to code.
28
+
29
+ Of course, experimenting with the alternate approach might be better.
30
+
31
+ There is one stray label of SB in the training data, so we throw out
32
+ that entire sentence.
33
+ """
34
+
35
+
36
+ import os
37
+
38
+ from tqdm import tqdm
39
+ import stanza
40
+ from stanza.utils.datasets.ner.check_for_duplicates import check_for_duplicates
41
+
42
+ SPLITS = ("train", "dev", "test")
43
+
44
def convert_file(input_filename, output_filename, pipe):
    """Convert one UCSY file of "text@LABEL|text@LABEL|..." lines to BIO TSV.

    Each line of the input is one sentence; each | chunk is a span of
    unsegmented text plus its label.  Each chunk is tokenized separately
    (joined with \n\n so the no-ssplit pipeline treats it as one sentence),
    then written out with B-/I- prefixes.  Sentences containing the stray
    SB label are skipped entirely.
    """
    with open(input_filename) as fin:
        lines = fin.readlines()

    # every label observed in the file, reported at the end as a sanity check
    all_labels = set()

    with open(output_filename, "w") as fout:
        for line in tqdm(lines):
            pieces = line.split("|")
            texts = []
            labels = []
            skip_sentence = False
            for piece in pieces:
                piece = piece.strip()
                if not piece:
                    continue
                # rsplit so @ characters inside the text don't break parsing
                text, label = piece.rsplit("@", maxsplit=1)
                text = text.strip()
                if not text:
                    continue
                if label == 'SB':
                    # one stray SB label exists in the training data;
                    # drop the whole sentence
                    skip_sentence = True
                    break

                texts.append(text)
                labels.append(label)

            if skip_sentence:
                continue

            # \n\n makes each chunk its own "sentence" for the
            # tokenize_no_ssplit pipeline, so chunks and labels stay aligned
            text = "\n\n".join(texts)
            doc = pipe(text)
            assert len(doc.sentences) == len(texts)
            for sentence, label in zip(doc.sentences, labels):
                all_labels.add(label)
                for word_idx, word in enumerate(sentence.words):
                    if label == "O":
                        output_label = "O"
                    elif word_idx == 0:
                        output_label = "B-" + label
                    else:
                        output_label = "I-" + label

                    fout.write("%s\t%s\n" % (word.text, output_label))
                # NOTE(review): this writes a double blank line between
                # chunks - confirm downstream readers accept that
                fout.write("\n\n")

    print("Finished processing {} Labels found: {}".format(input_filename, sorted(all_labels)))
91
+
92
def convert_my_ucsy(base_input_path, base_output_path):
    """Convert the three UCSY Myanmar NER files to .bio files.

    Expects Myanmar_NER_train.txt / _dev.txt / _test.txt under
    base_input_path and raises FileNotFoundError if any is missing.
    Requires a Myanmar tokenization model, since the text is unsegmented.
    """
    os.makedirs(base_output_path, exist_ok=True)
    # tokenize_no_ssplit: each \n\n-separated chunk becomes exactly one sentence
    pipe = stanza.Pipeline("my", processors="tokenize", tokenize_no_ssplit=True)
    output_filenames = [os.path.join(base_output_path, "my_ucsy.%s.bio" % split) for split in SPLITS]

    for split, output_filename in zip(SPLITS, output_filenames):
        input_filename = os.path.join(base_input_path, "Myanmar_NER_%s.txt" % split)
        if not os.path.exists(input_filename):
            raise FileNotFoundError("Necessary file for my_ucsy does not exist: %s" % input_filename)

        convert_file(input_filename, output_filename, pipe)
stanza/stanza/utils/datasets/ner/convert_sindhi_siner.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Converts the raw data from SiNER to .json for the Stanza NER system
3
+
4
+ https://aclanthology.org/2020.lrec-1.361.pdf
5
+ """
6
+
7
+ from stanza.utils.datasets.ner.utils import write_dataset
8
+
9
def fix_sentence(sentence):
    """
    Fix some of the mistags in the SiNER dataset.

    This covers 11 sentences: 1 P-PERSON, 2 with line breaks in the middle
    of the tag, and 8 bare tags with no B- or I- prefix.
    """
    repaired = []
    for word, tag in sentence:
        if tag == 'P-PERSON':
            repaired.append((word, 'B-PERSON'))
        elif tag == 'B-OT"':
            repaired.append((word, 'B-OTHERS'))
        elif tag == 'B-T"':
            repaired.append((word, 'B-TITLE'))
        elif tag in ('GPE', 'LOC', 'OTHERS'):
            previous_tag = repaired[-1][1] if repaired else ''
            if previous_tag[:2] in ('B-', 'I-') and previous_tag[2:] == tag:
                # one example... no idea if it should be a break or
                # not, but the last word translates to "Corporation",
                # so probably not: ميٽرو پوليٽن ڪارپوريشن
                repaired.append((word, 'I-' + tag))
            else:
                repaired.append((word, 'B-' + tag))
        else:
            repaired.append((word, tag))
    return repaired
34
+
35
def convert_sindhi_siner(in_filename, out_directory, short_name, train_frac=0.8, dev_frac=0.1):
    """
    Read lines from the SiNER dataset, crudely separate sentences on . or !,
    repair known mistags, and write the train/dev/test dataset.
    """
    with open(in_filename, encoding="utf-8") as fin:
        raw_lines = fin.readlines()

    pairs = []
    for raw in raw_lines:
        fields = raw.strip().split("\t")
        # keep only well-formed two-column lines
        if len(fields) == 2:
            pairs.append((fields[0].strip(), fields[1].strip()))
    print("Read %d words from %s" % (len(pairs), in_filename))

    sentences = []
    start = 0
    for idx, pair in enumerate(pairs):
        # maybe also handle pair[0] == '،', "Arabic comma"?
        if pair[0] in ('.', '!'):
            sentences.append(pairs[start:idx + 1])
            start = idx + 1

    # in case the file doesn't end with punctuation, grab the last few lines
    if start < len(pairs):
        sentences.append(pairs[start:])

    print("Found %d sentences before splitting" % len(sentences))
    sentences = [fix_sentence(sentence) for sentence in sentences]
    assert not any('"' in tag or tag.startswith("P-") or tag in ("GPE", "LOC", "OTHERS")
                   for sentence in sentences for _, tag in sentence)

    train_cutoff = int(len(sentences) * train_frac)
    dev_cutoff = int(len(sentences) * (train_frac + dev_frac))
    datasets = (sentences[:train_cutoff],
                sentences[train_cutoff:dev_cutoff],
                sentences[dev_cutoff:])
    write_dataset(datasets, out_directory, short_name, suffix="bio")
69
+
stanza/stanza/utils/datasets/ner/convert_starlang_ner.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert the starlang trees to a NER dataset
3
+
4
+ Has to hide quite a few trees with missing NER labels
5
+ """
6
+
7
+ import re
8
+
9
+ from stanza.models.constituency import tree_reader
10
+ import stanza.utils.datasets.constituency.convert_starlang as convert_starlang
11
+
12
# Extract the surface form from a leaf label such as {turkish=...}
TURKISH_WORD_RE = re.compile(r"[{]turkish=([^}]+)[}]")
# Extract the NER tag from a leaf label such as {namedEntity=...}
TURKISH_LABEL_RE = re.compile(r"[{]namedEntity=([^}]+)[}]")
14
+
15
+
16
+
17
def read_tree(text):
    """
    Parse a single tree and return its (word, NER tag) pairs.

    One problem is that it is unknown if there are cases of two
    separate items occurring consecutively.

    Note that this is quite similar to the convert_starlang script for
    constituency.  Raises ValueError if the text contains more than one
    tree or a leaf is missing its word or NER annotation; NONE/null
    tags are normalized to 'O'.
    """
    trees = tree_reader.read_trees(text)
    if len(trees) > 1:
        raise ValueError("Tree file had two trees!")

    words = []
    for label in trees[0].leaf_labels():
        word_match = TURKISH_WORD_RE.search(label)
        if word_match is None:
            raise ValueError("Could not find word in |{}|".format(label))
        # unescape the brace tokens used by the treebank
        text_form = word_match.group(1).replace("-LCB-", "{").replace("-RCB-", "}")

        tag_match = TURKISH_LABEL_RE.search(label)
        if tag_match is None:
            raise ValueError("Could not find ner in |{}|".format(label))
        tag = tag_match.group(1)
        words.append((text_form, 'O' if tag in ('NONE', "null") else tag))

    return words
46
+
47
def read_starlang(paths):
    """Read all Starlang trees under the given paths, converting each tree to (word, tag) pairs via read_tree."""
    return convert_starlang.read_starlang(paths, conversion=read_tree, log=False)
49
+
50
def main():
    """Run the constituency Starlang conversion with the NER read_tree conversion."""
    # NOTE(review): the three splits are not written or used here -
    # presumably convert_starlang.main performs the side effects; confirm
    train, dev, test = convert_starlang.main(conversion=read_tree, log=False)

if __name__ == '__main__':
    main()
55
+
stanza/stanza/utils/datasets/ner/ontonotes_multitag.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Combines OntoNotes and WW into a single dataset with OntoNotes used for dev & test
3
+
4
+ The resulting dataset has two layers saved in the multi_ner column.
5
+
6
+ WW is kept as 9 classes, with the tag put in either the first or
7
+ second layer depending on the flags.
8
+
9
+ OntoNotes is converted to one column for 18 and one column for 9 classes.
10
+ """
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ import shutil
16
+
17
+ from stanza.utils import default_paths
18
+ from stanza.utils.datasets.ner.utils import combine_files
19
+ from stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide
20
+
21
def convert_ontonotes_file(filename, simplify, bigger_first):
    """
    Rewrite an OntoNotes json shard with a two layer multi_ner annotation.

    The simplified (WorldWide) layer is the converted tag when simplify
    is True, otherwise "-".  bigger_first controls whether the original
    18 class tag is the first or second layer.  Output goes to the same
    path with en_ontonotes replaced by en_ontonotes-multi.
    """
    assert "en_ontonotes" in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)

    with open(filename) as fin:
        doc = json.load(fin)

    for sentence in doc:
        for word in sentence:
            tag = word['ner']
            simplified = simplify_ontonotes_to_worldwide(tag) if simplify else "-"
            word['multi_ner'] = (tag, simplified) if bigger_first else (simplified, tag)

    with open(filename.replace("en_ontonotes", "en_ontonotes-multi"), "w") as fout:
        json.dump(doc, fout, indent=2)
44
+
45
def convert_worldwide_file(filename, bigger_first):
    """
    Rewrite a WorldWide 9 class json shard with a two layer multi_ner annotation.

    The tag goes in the second layer when bigger_first is True (leaving
    "-" for the missing 18 class layer), otherwise the first.  Output
    goes to the same path with en_worldwide-9class replaced by
    en_worldwide-9class-multi.
    """
    assert "en_worldwide-9class" in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)

    with open(filename) as fin:
        doc = json.load(fin)

    for sentence in doc:
        for word in sentence:
            tag = word['ner']
            word['multi_ner'] = ("-", tag) if bigger_first else (tag, "-")

    with open(filename.replace("en_worldwide-9class", "en_worldwide-9class-multi"), "w") as fout:
        json.dump(doc, fout, indent=2)
65
+
66
def build_multitag_dataset(base_output_path, short_name, simplify, bigger_first):
    """
    Build the combined OntoNotes + WorldWide multitag dataset.

    Both source datasets get a multi_ner column, then the train shards
    are combined while dev/test come from OntoNotes alone.
    """
    # add the multi_ner column to each OntoNotes shard
    convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.train.json"), simplify, bigger_first)
    convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.dev.json"), simplify, bigger_first)
    convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.test.json"), simplify, bigger_first)

    # add the multi_ner column to each WorldWide shard
    convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.train.json"), bigger_first)
    convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.dev.json"), bigger_first)
    convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.test.json"), bigger_first)

    # train = OntoNotes + WorldWide; dev & test = OntoNotes only
    combine_files(os.path.join(base_output_path, "%s.train.json" % short_name),
                  os.path.join(base_output_path, "en_ontonotes-multi.train.json"),
                  os.path.join(base_output_path, "en_worldwide-9class-multi.train.json"))
    shutil.copyfile(os.path.join(base_output_path, "en_ontonotes-multi.dev.json"),
                    os.path.join(base_output_path, "%s.dev.json" % short_name))
    shutil.copyfile(os.path.join(base_output_path, "en_ontonotes-multi.test.json"),
                    os.path.join(base_output_path, "%s.test.json" % short_name))
82
+
83
+
84
def main():
    """Build the en_ontonotes-ww-multi dataset in the default NER data directory."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--no_simplify', dest='simplify', action='store_false', help='By default, this script will simplify the OntoNotes 18 classes to the 8 WorldWide classes in a second column. Turning that off will leave that column blank. Initial experiments with that setting were very bad, though')
    parser.add_argument('--no_bigger_first', dest='bigger_first', action='store_false', help='By default, this script will put the 18 class tags in the first column and the 8 in the second. This flips the order')
    args = parser.parse_args()

    paths = default_paths.get_default_paths()
    base_output_path = paths["NER_DATA_DIR"]

    build_multitag_dataset(base_output_path, "en_ontonotes-ww-multi", args.simplify, args.bigger_first)

if __name__ == '__main__':
    main()
97
+
stanza/stanza/utils/datasets/ner/prepare_ner_file.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script converts NER data from the CoNLL03 format to the latest CoNLL-U format. The script assumes that in the
3
+ input column format data, the token is always in the first column, while the NER tag is always in the last column.
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+
9
# number of fields expected on a data line: the token is the first
# field and the NER tag the last
MIN_NUM_FIELD = 2
MAX_NUM_FIELD = 5

# conll03 document separator; lines containing it are skipped by default
DOC_START_TOKEN = '-DOCSTART-'

def parse_args():
    """Parse the input and output filenames from the command line."""
    parser = argparse.ArgumentParser(description="Convert the conll03 format data into conllu format.")
    parser.add_argument('input', help='Input conll03 format data filename.')
    parser.add_argument('output', help='Output json filename.')
    return parser.parse_args()

def main():
    args = parse_args()
    process_dataset(args.input, args.output)

def process_dataset(input_filename, output_filename):
    """
    Convert one conll03 style file into a json file of sentences.

    Each output sentence is a list of {'text': ..., 'ner': ...} dicts.
    """
    sentences = load_conll03(input_filename)
    print("{} examples loaded from {}".format(len(sentences), input_filename))

    document = [[{'text': word, 'ner': tag} for word, tag in zip(words, tags)]
                for words, tags in sentences]

    with open(output_filename, 'w', encoding="utf-8") as outfile:
        json.dump(document, outfile, indent=1)
    print("Generated json file {}".format(output_filename))

# TODO: make skip_doc_start an argument
def load_conll03(filename, skip_doc_start=True):
    """
    Read conll03 format sentences from the given file.

    Returns a list of (tokens, tags) pairs, one per sentence.
    -DOCSTART- lines are skipped when skip_doc_start is True, and lines
    with fewer than MIN_NUM_FIELD fields (tab or whitespace separated)
    are ignored.  Sentences are separated by blank lines.
    """
    examples = []
    cached_lines = []
    with open(filename, encoding="utf-8") as infile:
        for raw_line in infile:
            line = raw_line.strip()
            if skip_doc_start and DOC_START_TOKEN in line:
                continue
            if line:
                # prefer tab separation, fall back to whitespace
                fields = line.split("\t")
                if len(fields) < MIN_NUM_FIELD:
                    fields = line.split()
                if len(fields) >= MIN_NUM_FIELD:
                    cached_lines.append(line)
            elif cached_lines:
                examples.append(process_cache(cached_lines))
                cached_lines = []
    if cached_lines:
        examples.append(process_cache(cached_lines))
    return examples

def process_cache(cached_lines):
    """Turn the cached lines of one sentence into a (tokens, tags) pair."""
    tokens = []
    ner_tags = []
    for line in cached_lines:
        fields = line.split("\t")
        if len(fields) < MIN_NUM_FIELD:
            fields = line.split()
        assert MIN_NUM_FIELD <= len(fields) <= MAX_NUM_FIELD, "Got unexpected line length: {}".format(fields)
        tokens.append(fields[0])
        ner_tags.append(fields[-1])
    return (tokens, ner_tags)

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/ner/utils.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils for the processing of NER datasets
3
+
4
+ These can be invoked from either the specific dataset scripts
5
+ or the entire prepare_ner_dataset.py script
6
+ """
7
+
8
+ from collections import defaultdict
9
+ import json
10
+ import os
11
+ import random
12
+
13
+ from stanza.models.common.doc import Document
14
+ import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file
15
+
16
# canonical shard names, in the order datasets are passed around
SHARDS = ('train', 'dev', 'test')
17
+
18
def bioes_to_bio(tags):
    """
    Convert one sentence of BIOES tags to BIO (not BIO2).

    In the original BIO scheme, B- is only used to separate two
    adjacent entities, so a tag is demoted to I- unless the previous
    token was part of an entity.
    """
    converted = []
    inside = False
    for tag in tags:
        if tag == 'O':
            inside = False
            converted.append('O')
            continue
        # TODO: does the tag have to match the previous tag?
        # eg, does B-LOC B-PER in BIOES need a B-PER or is I-PER sufficient?
        prefix = 'B-' if inside and tag[:2] in ('B-', 'S-') else 'I-'
        converted.append(prefix + tag[2:])
        inside = True
    return converted
33
+
34
def convert_bioes_to_bio(base_input_path, base_output_path, short_name):
    """
    Convert a dataset's BIOES files back to BIO (not BIO2)

    Useful for preparing datasets for CoreNLP, which doesn't do great with the more highly split classes
    """
    for shard in SHARDS:
        bioes_filename = os.path.join(base_input_path, '%s.%s.bioes' % (short_name, shard))
        bio_filename = os.path.join(base_output_path, '%s.%s.bio' % (short_name, shard))

        converted = []
        for sentence in read_tsv(bioes_filename, text_column=0, annotation_column=1):
            new_tags = bioes_to_bio([tag for _, tag in sentence])
            converted.append([(pair[0], tag) for pair, tag in zip(sentence, new_tags)])
        write_sentences(bio_filename, converted)
52
+
53
+
54
def convert_bio_to_json(base_input_path, base_output_path, short_name, suffix="bio", shard_names=SHARDS, shards=SHARDS):
    """
    Convert BIO files to json

    It can often be convenient to put the intermediate BIO files in
    the same directory as the output files, in which case you can pass
    in same path for both base_input_path and base_output_path.

    This also will rewrite a BIOES as json
    """
    for input_shard, output_shard in zip(shard_names, shards):
        input_filename = os.path.join(base_input_path, '%s.%s.%s' % (short_name, input_shard, suffix))
        if not os.path.exists(input_filename):
            # fall back to a filename without the dataset prefix
            alt_filename = os.path.join(base_input_path, '%s.%s' % (input_shard, suffix))
            if not os.path.exists(alt_filename):
                raise FileNotFoundError('Cannot find %s component of %s in %s or %s' % (output_shard, short_name, input_filename, alt_filename))
            input_filename = alt_filename
        output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, output_shard))
        print("Converting %s to %s" % (input_filename, output_filename))
        prepare_ner_file.process_dataset(input_filename, output_filename)
75
+
76
def get_tags(datasets):
    """
    Return the set of tags used anywhere in these datasets.

    datasets is expected to be train, dev, test but could be any list
    of lists of sentences, each sentence a list of (word, tag).
    """
    return {tag for dataset in datasets for sentence in dataset for _, tag in sentence}
88
+
89
def write_sentences(output_filename, dataset):
    """
    Write exactly one output file worth of dataset.

    Each word is written as word<TAB>tag, with a blank line after each
    sentence.  Any extra fields past the first two are dropped.
    """
    os.makedirs(os.path.split(output_filename)[0], exist_ok=True)
    with open(output_filename, "w", encoding="utf-8") as fout:
        for sent_idx, sentence in enumerate(dataset):
            for word_idx, word in enumerate(sentence):
                try:
                    fout.write("%s\t%s\n" % word[:2])
                except TypeError:
                    # give a more useful location than the bare format error
                    raise TypeError("Unable to process sentence %d word %d of file %s" % (sent_idx, word_idx, output_filename))
            fout.write("\n")
104
+
105
def write_dataset(datasets, output_dir, short_name, suffix="bio", shard_names=SHARDS, shards=SHARDS):
    """
    write all three pieces of a dataset to output_dir

    datasets should be 3 lists: train, dev, test
    each list should be a list of sentences
    each sentence is a list of pairs: word, tag

    after writing to .bio files, the files will be converted to .json
    """
    # write the intermediate .bio (or other suffix) file for each shard
    for shard, dataset in zip(shard_names, datasets):
        output_filename = os.path.join(output_dir, "%s.%s.%s" % (short_name, shard, suffix))
        write_sentences(output_filename, dataset)

    # then rewrite each shard in the json format the NER model reads
    convert_bio_to_json(output_dir, output_dir, short_name, suffix, shard_names=shard_names, shards=shards)
120
+
121
+
122
def write_multitag_json(output_filename, dataset):
    """
    Write one json file of sentences whose words are
    (text, ner, multi_ner) triples.
    """
    converted = [[{'text': word[0], 'ner': word[1], 'multi_ner': word[2]}
                  for word in sentence]
                 for sentence in dataset]
    with open(output_filename, 'w', encoding='utf-8') as fout:
        json.dump(converted, fout, indent=2)
134
+
135
def write_multitag_dataset(datasets, output_dir, short_name, suffix="bio", shard_names=SHARDS, shards=SHARDS):
    """
    Write a multitag dataset: plain tsv files with just (word, ner),
    plus json files which keep the multi_ner layers.
    """
    # write_sentences truncates each word to its first two fields
    for shard, dataset in zip(shard_names, datasets):
        output_filename = os.path.join(output_dir, "%s.%s.%s" % (short_name, shard, suffix))
        write_sentences(output_filename, dataset)

    # the json files keep the multi_ner column as well
    for shard, dataset in zip(shard_names, datasets):
        output_filename = os.path.join(output_dir, "%s.%s.json" % (short_name, shard))
        write_multitag_json(output_filename, dataset)
143
+
144
def read_tsv(filename, text_column, annotation_column, remap_fn=None, skip_comments=True, keep_broken_tags=False, keep_all_columns=False, separator="\t"):
    """
    Read sentences from a TSV file.

    Returns a list of sentences; each sentence is a list of
    (word, tag) pairs, or the full column list when
    keep_all_columns=True.  Sentences are separated by blank lines.

    If keep_broken_tags==True, then None is used for a missing tag.
    Otherwise, an IndexError is thrown.
    """
    with open(filename, encoding="utf-8") as fin:
        lines = [x.strip() for x in fin.readlines()]

    sentences = []
    sentence = []
    for line_idx, line in enumerate(lines):
        if not line:
            # blank line terminates the current sentence, if any
            if sentence:
                sentences.append(sentence)
                sentence = []
            continue
        if skip_comments and line.startswith("#"):
            continue

        pieces = line.split(separator)
        try:
            word = pieces[text_column]
        except IndexError as e:
            raise IndexError("Could not find word index %d at line %d |%s|" % (text_column, line_idx, line)) from e
        if word == '\x96':
            # this happens in GermEval2014 for some reason
            continue
        try:
            tag = pieces[annotation_column]
        except IndexError as e:
            if not keep_broken_tags:
                raise IndexError("Could not find tag index %d at line %d |%s|" % (annotation_column, line_idx, line)) from e
            tag = None
        if remap_fn:
            # note: when keep_broken_tags is set, the remap function may see None
            tag = remap_fn(tag)

        if keep_all_columns:
            pieces[annotation_column] = tag
            sentence.append(pieces)
        else:
            sentence.append((word, tag))

    if sentence:
        sentences.append(sentence)

    return sentences
196
+
197
def random_shuffle_directory(input_dir, output_dir, short_name):
    """Shuffle every file in input_dir into train/dev/test splits (see random_shuffle_files)."""
    input_files = os.listdir(input_dir)
    # sort so the result is deterministic regardless of listdir order
    input_files = sorted(input_files)
    random_shuffle_files(input_dir, input_files, output_dir, short_name)
201
+
202
def random_shuffle_files(input_dir, input_files, output_dir, short_name):
    """
    Shuffle the files into different chunks based on their filename

    The first piece of the filename, split by ".", is used as a random seed.

    This will make it so that adding new files or using a different
    annotation scheme (assuming that's encoding in pieces of the
    filename) won't change the distibution of the files

    Returns (num_train_files, num_dev_files, num_test_files).
    """
    # the prefixes double as random seeds, so they must be unique
    input_keys = {}
    for f in input_files:
        seed = f.split(".")[0]
        if seed in input_keys:
            raise ValueError("Multiple files with the same prefix: %s and %s" % (input_keys[seed], f))
        input_keys[seed] = f
    assert len(input_keys) == len(input_files)

    train_files = []
    dev_files = []
    test_files = []

    # deterministic 70/10/20 assignment keyed on each file's prefix
    for filename in input_files:
        seed = filename.split(".")[0]
        # "salt" the filenames when using as a seed
        # definitely not because of a dumb bug in the original implementation
        seed = seed + ".txt.4class.tsv"
        random.seed(seed, 2)
        location = random.random()
        if location < 0.7:
            train_files.append(filename)
        elif location < 0.8:
            dev_files.append(filename)
        else:
            test_files.append(filename)

    print("Train files: %d Dev files: %d Test files: %d" % (len(train_files), len(dev_files), len(test_files)))
    assert len(train_files) + len(dev_files) + len(test_files) == len(input_files)

    # read each split's files and write the combined dataset
    file_lists = [train_files, dev_files, test_files]
    datasets = []
    for files in file_lists:
        dataset = []
        for filename in files:
            dataset.extend(read_tsv(os.path.join(input_dir, filename), 0, 1))
        datasets.append(dataset)

    write_dataset(datasets, output_dir, short_name)
    return len(train_files), len(dev_files), len(test_files)
251
+
252
def random_shuffle_by_prefixes(input_dir, output_dir, short_name, prefix_map):
    """
    Split the files in input_dir into divisions by filename prefix,
    shuffle each division into train/dev/test, then combine the
    per-division datasets into a single dataset named short_name.

    prefix_map: {division_name: [filename prefixes belonging to it]}
    Raises ValueError if a file matches no division.
    """
    input_files = os.listdir(input_dir)
    input_files = sorted(input_files)

    # assign each file to the first division containing a matching prefix
    file_divisions = defaultdict(list)
    for filename in input_files:
        for division in prefix_map.keys():
            for prefix in prefix_map[division]:
                if filename.startswith(prefix):
                    break
            else: # for/else is intentional
                continue
            break
        else: # yes, stop asking
            raise ValueError("Could not assign %s to any of the divisions in the prefix_map" % filename)
        #print("Assigning %s to %s because of %s" % (filename, division, prefix))
        file_divisions[division].append(filename)

    # shuffle each division separately, tracking the split sizes
    num_train_files = 0
    num_dev_files = 0
    num_test_files = 0
    for division in file_divisions.keys():
        print()
        print("Processing %d files from %s" % (len(file_divisions[division]), division))
        d_train, d_dev, d_test = random_shuffle_files(input_dir, file_divisions[division], output_dir, "%s-%s" % (short_name, division))
        num_train_files += d_train
        num_dev_files += d_dev
        num_test_files += d_test

    print()
    print("After shuffling: Train files: %d Dev files: %d Test files: %d" % (num_train_files, num_dev_files, num_test_files))
    # merge the per-division json files into the final combined dataset
    dataset_divisions = ["%s-%s" % (short_name, division) for division in file_divisions]
    combine_dataset(output_dir, output_dir, dataset_divisions, short_name)
285
+
286
def combine_dataset(input_dir, output_dir, input_datasets, output_dataset):
    """
    Concatenate several json datasets into one, shard by shard.

    Reads <name>.<shard>.json for each input dataset, keeps only the
    (text, ner) columns, and rewrites the merged sentences as
    output_dataset via write_dataset.
    """
    datasets = []
    for shard in SHARDS:
        full_dataset = []
        for input_dataset in input_datasets:
            input_filename = "%s.%s.json" % (input_dataset, shard)
            input_path = os.path.join(input_dir, input_filename)
            with open(input_path, encoding="utf-8") as fin:
                dataset = json.load(fin)
            converted = [[(word['text'], word['ner']) for word in sentence] for sentence in dataset]
            full_dataset.extend(converted)
        datasets.append(full_dataset)
    write_dataset(datasets, output_dir, output_dataset)
299
+
300
def read_prefix_file(destination_file):
    """
    Read a prefix file such as the one for the Worldwide dataset

    the format should be

    africa:
    af_
    ...

    asia:
    cn_
    ...

    Returns {division: [prefixes]}.  Raises RuntimeError for a prefix
    before the first label or for a duplicated prefix.
    """
    prefix_map = {}
    current_division = None
    current_prefixes = []
    seen = set()

    with open(destination_file, encoding="utf-8") as fin:
        for raw in fin:
            line = raw.strip()
            # skip comments and blank lines
            if not line or line.startswith("#"):
                continue
            if line.endswith(":"):
                # a new division: flush the previous one
                if current_division is not None:
                    prefix_map[current_division] = current_prefixes
                    current_prefixes = []
                current_division = line[:-1].strip().lower().replace(" ", "_")
            else:
                if not current_division:
                    raise RuntimeError("Found a prefix before the first label was assigned when reading %s" % destination_file)
                current_prefixes.append(line)
                if line in seen:
                    raise RuntimeError("Found the same prefix twice! %s" % line)
                seen.add(line)

    if current_division and current_prefixes:
        prefix_map[current_division] = current_prefixes

    return prefix_map
343
+
344
def read_json_entities(filename):
    """
    Read entities from a json file, return a list of (text, label)

    Should work on both BIOES and BIO
    """
    with open(filename) as fin:
        # the json file is in the stanza Document serialization format
        doc = Document(json.load(fin))

    return list_doc_entities(doc)
354
+
355
def list_doc_entities(doc):
    """
    Return a list of (text, label) entities found in a Document.

    Each entity is ((token, token, ...), label).
    Should work on both BIOES and BIO.
    Raises RuntimeError if a tag is not in BIO(ES) format.
    """
    entities = []
    for sentence in doc.sentences:
        current_entity = []
        previous_label = None
        for token in sentence.tokens:
            if token.ner == 'O' or token.ner.startswith("E-"):
                # E- closes the current entity; O just ends any open one
                if token.ner.startswith("E-"):
                    current_entity.append(token.text)
                if current_entity:
                    assert previous_label is not None
                    entities.append((current_entity, previous_label))
                current_entity = []
                previous_label = None
            elif token.ner.startswith("I-"):
                # a label change without a B- still starts a new entity
                if previous_label is not None and previous_label != 'O' and previous_label != token.ner[2:]:
                    if current_entity:
                        assert previous_label is not None
                        entities.append((current_entity, previous_label))
                    current_entity = []
                previous_label = token.ner[2:]
                current_entity.append(token.text)
            elif token.ner.startswith("B-") or token.ner.startswith("S-"):
                if current_entity:
                    assert previous_label is not None
                    entities.append((current_entity, previous_label))
                    current_entity = []
                    previous_label = None
                current_entity.append(token.text)
                previous_label = token.ner[2:]
                if token.ner.startswith("S-"):
                    assert previous_label is not None
                    # bugfix: this previously appended the bare token list
                    # without its label, which broke the (tuple, label)
                    # conversion below for any S- tagged entity
                    entities.append((current_entity, previous_label))
                    current_entity = []
                    previous_label = None
            else:
                raise RuntimeError("Expected BIO(ES) format in the json file!")
            # NOTE(review): runs for every token, so 'O' leaves '' (not
            # None) here; harmless for valid BIO(ES) input - confirm intent
            previous_label = token.ner[2:]
        if current_entity:
            assert previous_label is not None
            entities.append((current_entity, previous_label))
    entities = [(tuple(x[0]), x[1]) for x in entities]
    return entities
403
+
404
def combine_files(output_filename, *input_filenames):
    """
    Combine multiple NER json files into one NER file.

    Sentences are concatenated in the order the files are given.
    """
    combined = []
    for input_filename in input_filenames:
        with open(input_filename) as fin:
            combined.extend(json.load(fin))

    with open(output_filename, "w") as fout:
        json.dump(combined, fout, indent=2)
417
+
stanza/stanza/utils/datasets/vietnamese/__init__.py ADDED
File without changes
stanza/stanza/utils/pretrain/compare_pretrains.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import numpy as np

from stanza.models.common.pretrain import Pretrain

# Compare two pretrained embedding files: report how many words they
# share, how far apart the shared vectors are, and whether shared
# words occupy the same vocab indices.
# usage: python compare_pretrains.py <pretrain1> <pretrain2>
pt1_filename = sys.argv[1]
pt2_filename = sys.argv[2]

pt1 = Pretrain(pt1_filename)
pt2 = Pretrain(pt2_filename)

vocab1 = pt1.vocab
vocab2 = pt2.vocab

common_words = [x for x in vocab1 if x in vocab2]
print("%d shared words, out of %d in %s and %d in %s" % (len(common_words), len(vocab1), pt1_filename, len(vocab2), pt2_filename))

# vectors with L2 distance under eps count as "the same"
eps = 0.0001
total_norm = 0.0
total_close = 0

words_different = []

for word, idx in vocab1._unit2id.items():
    if word not in vocab2:
        continue
    v1 = pt1.emb[idx]
    v2 = pt2.emb[pt2.vocab[word]]
    norm = np.linalg.norm(v1 - v2)

    if norm < eps:
        total_close += 1
    else:
        total_norm += norm
        # remember the first few mismatches for the report
        if len(words_different) < 10:
            words_different.append("|%s|" % word)
        #print(word, idx, pt2.vocab[word])
        #print(v1)
        #print(v2)

if total_close < len(common_words):
    avg_norm = total_norm / (len(common_words) - total_close)
    print("%d vectors were close. Average difference of the others: %f" % (total_close, avg_norm))
    print("The first few different words were:\n %s" % "\n ".join(words_different))
else:
    print("All %d vectors were close!" % total_close)

# for/else: the message only prints if no shared word has a different index
for word, idx in vocab1._unit2id.items():
    if word not in vocab2:
        continue
    if pt2.vocab[word] != idx:
        break
else:
    print("All indices are the same")
stanza/stanza/utils/training/common.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import logging
4
+ import os
5
+ import pathlib
6
+ import sys
7
+ import tempfile
8
+
9
+ from enum import Enum
10
+
11
+ from stanza.resources.default_packages import default_charlms, lemma_charlms, pos_charlms, depparse_charlms, TRANSFORMERS, TRANSFORMER_LAYERS
12
+ from stanza.models.common.constant import treebank_to_short_name
13
+ from stanza.models.common.utils import ud_scores
14
+ from stanza.resources.common import download, DEFAULT_MODEL_DIR, UnknownLanguageError
15
+ from stanza.utils.datasets import common
16
+ import stanza.utils.default_paths as default_paths
17
+ from stanza.utils import conll18_ud_eval as ud_eval
18
+
19
logger = logging.getLogger('stanza')

class Mode(Enum):
    """What a run_xyz training script should do when invoked."""
    TRAIN = 1
    SCORE_DEV = 2
    SCORE_TEST = 3
    SCORE_TRAIN = 4
26
+
27
class ArgumentParserWithExtraHelp(argparse.ArgumentParser):
    """An ArgumentParser which appends a second ("model") parser's help to its own help text."""
    def __init__(self, sub_argparse, *args, **kwargs):
        super().__init__(*args, **kwargs) # forwards all unused arguments
        # the model's own argparse, whose options are shown under "model arguments"
        self.sub_argparse = sub_argparse

    def print_help(self, file=None):
        super().print_help(file=file)

    def format_help(self):
        help_text = super().format_help()
        if self.sub_argparse is not None:
            sub_text = self.sub_argparse.format_help().split("\n")
            # find where the sub-parser's usage section ends: first_line
            # tracks the last "usage:" line seen, then advances to the
            # first blank line after it before stopping
            first_line = -1
            for line_idx, line in enumerate(sub_text):
                if line.strip().startswith("usage:"):
                    first_line = line_idx
                elif first_line >= 0 and not line.strip():
                    first_line = line_idx
                    break
            help_text = help_text + "\n\nmodel arguments:" + "\n".join(sub_text[first_line:])
        return help_text
49
+
50
+
51
def build_argparse(sub_argparse=None):
    """Build the argparse shared by the run_xyz training scripts, optionally wrapping a model-specific sub-parser for help display."""
    parser = ArgumentParserWithExtraHelp(sub_argparse=sub_argparse, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--save_output', dest='temp_output', default=True, action='store_false', help="Save output - default is to use a temp directory.")

    parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. Use all_ud or ud_all for all UD treebanks')

    # the mutually-overwriting mode flags all share dest='mode'
    parser.add_argument('--train', dest='mode', default=Mode.TRAIN, action='store_const', const=Mode.TRAIN, help='Run in train mode')
    parser.add_argument('--score_dev', dest='mode', action='store_const', const=Mode.SCORE_DEV, help='Score the dev set')
    parser.add_argument('--score_test', dest='mode', action='store_const', const=Mode.SCORE_TEST, help='Score the test set')
    parser.add_argument('--score_train', dest='mode', action='store_const', const=Mode.SCORE_TRAIN, help='Score the train set as a test set. Currently only implemented for some models')

    # These arguments need to be here so we can identify if the model already exists in the user-specified home
    # TODO: when all of the model scripts handle their own names, can eliminate this argument
    parser.add_argument('--save_dir', type=str, default=None, help="Root dir for saving models. If set, will override the model's default.")
    parser.add_argument('--save_name', type=str, default=None, help="Base name for saving models. If set, will override the model's default.")

    parser.add_argument('--charlm_only', action='store_true', default=False, help='When asking for ud_all, filter the ones which have charlms')
    parser.add_argument('--transformer_only', action='store_true', default=False, help='When asking for ud_all, filter the ones for languages where we have transformers')

    parser.add_argument('--force', dest='force', action='store_true', default=False, help='Retrain existing models')
    return parser
72
+
73
def add_charlm_args(parser):
    """Register the charlm selection flags; both write to dest 'charlm' (registration order matters for the default)."""
    parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')
    parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package")
76
+
77
def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argparse=None, build_model_filename=None, choose_charlm_method=None, args=None):
    """
    A main program for each of the run_xyz scripts

    It collects the arguments and runs the main method for each dataset provided.
    It also tries to look for an existing model and not overwrite it unless --force is provided

    run_treebank: callback (mode, paths, treebank, short_name, temp_output_file,
      command_args, extra_args) which trains or scores one treebank
    model_dir: subdirectory of saved_models used when no --save_dir is given
    model_name: short model type used to build default save filenames.
      model_name can be a callable expecting the args
      - the charlm, for example, needs this feature, since it makes
        both forward and backward models
    add_specific_args: optional callback to add model-specific flags to the parser
    sub_argparse: optional argparse of the underlying model, used for --help
    build_model_filename: optional callback which predicts the model's save path,
      letting us skip training models which already exist
    choose_charlm_method: optional callback used to filter ud_all to charlm-capable treebanks
    args: list of command line args; defaults to sys.argv[1:]
    """
    if args is None:
        logger.info("Training program called with:\n" + " ".join(sys.argv))
        args = sys.argv[1:]
    else:
        logger.info("Training program called with:\n" + " ".join(args))

    paths = default_paths.get_default_paths()

    parser = build_argparse(sub_argparse)
    if add_specific_args is not None:
        add_specific_args(parser)
    # everything after an explicit --extra_args is passed to the model untouched.
    # NOTE: this previously searched sys.argv even when an explicit args list
    # was supplied, and passed sys.argv[:idx] (which includes the program name
    # at sys.argv[0]) to parse_args; search the effective args list instead
    if '--extra_args' in args:
        idx = args.index('--extra_args')
        extra_args = args[idx+1:]
        command_args = parser.parse_args(args[:idx])
    else:
        command_args, extra_args = parser.parse_known_args(args=args)

    # Pass this through to the underlying model as well as use it here
    # we don't put --save_name here for the awkward situation of
    # --save_name being specified for an invocation with multiple treebanks
    if command_args.save_dir:
        extra_args.extend(["--save_dir", command_args.save_dir])

    if callable(model_name):
        model_name = model_name(command_args)

    mode = command_args.mode
    treebanks = []

    for treebank in command_args.treebanks:
        # this is a really annoying typo to make if you copy/paste a
        # UD directory name on the cluster and your job dies 30s after
        # being queued for an hour
        if treebank.endswith("/"):
            treebank = treebank[:-1]
        if treebank.lower() in ('ud_all', 'all_ud'):
            ud_treebanks = common.get_ud_treebanks(paths["UDBASE"])
            if choose_charlm_method is not None and command_args.charlm_only:
                logger.info("Filtering ud_all treebanks to only those which can use charlm for this model")
                ud_treebanks = [x for x in ud_treebanks
                                if choose_charlm_method(*treebank_to_short_name(x).split("_", 1), 'default') is not None]
            if command_args.transformer_only:
                logger.info("Filtering ud_all treebanks to only those which can use a transformer for this model")
                ud_treebanks = [x for x in ud_treebanks if treebank_to_short_name(x).split("_")[0] in TRANSFORMERS]
            logger.info("Expanding %s to %s", treebank, " ".join(ud_treebanks))
            treebanks.extend(ud_treebanks)
        else:
            treebanks.append(treebank)

    for treebank_idx, treebank in enumerate(treebanks):
        if treebank_idx > 0:
            logger.info("=========================================")

        short_name = treebank_to_short_name(treebank)
        logger.debug("%s: %s" % (treebank, short_name))

        save_name_args = []
        if model_name != 'ete':
            # ete is several models at once, so we don't set --save_name
            # theoretically we could handle a parametrized save_name
            if command_args.save_name:
                save_name = command_args.save_name
                # if there's more than 1 treebank, we can't save them all to this save_name
                # we have to override that value for each treebank
                if len(treebanks) > 1:
                    save_name_dir, save_name_filename = os.path.split(save_name)
                    save_name_filename = "%s_%s" % (short_name, save_name_filename)
                    save_name = os.path.join(save_name_dir, save_name_filename)
                    logger.info("Save file for %s model for %s: %s", short_name, treebank, save_name)
                save_name_args = ['--save_name', save_name]
            # some run scripts can build the model filename
            # in order to check for models that are already created
            elif build_model_filename is None:
                save_name = "%s_%s.pt" % (short_name, model_name)
                logger.info("Save file for %s model: %s", short_name, save_name)
                save_name_args = ['--save_name', save_name]
            else:
                save_name_args = []

        if mode == Mode.TRAIN and not command_args.force:
            if build_model_filename is not None:
                model_path = build_model_filename(paths, short_name, command_args, extra_args)
            elif command_args.save_dir:
                model_path = os.path.join(command_args.save_dir, save_name)
            else:
                save_dir = os.path.join("saved_models", model_dir)
                save_name_args.extend(["--save_dir", save_dir])
                model_path = os.path.join(save_dir, save_name)

            if model_path is None:
                # this can happen with the identity lemmatizer, for example
                pass
            elif os.path.exists(model_path):
                logger.info("%s: %s exists, skipping!" % (treebank, model_path))
                continue
            else:
                logger.info("%s: %s does not exist, training new model" % (treebank, model_path))

        if command_args.temp_output and model_name != 'ete':
            # route the model's prediction output to a throwaway file
            with tempfile.NamedTemporaryFile() as temp_output_file:
                run_treebank(mode, paths, treebank, short_name,
                             temp_output_file.name, command_args, extra_args + save_name_args)
        else:
            run_treebank(mode, paths, treebank, short_name,
                         None, command_args, extra_args + save_name_args)
194
+
195
def run_eval_script(gold_conllu_file, system_conllu_file, evals=None):
    """
    Score a system CoNLL-U file against gold using the UD eval script.

    With evals=None, returns the full evaluation table as a string.
    Otherwise returns a two line summary: a header row of the requested
    metric names and a row of the corresponding F1 scores (as percents).
    """
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)

    if evals is None:
        return ud_eval.build_evaluation_table(evaluation, verbose=True, counts=False, enhanced=False)

    # column width: at least 5, wide enough for the longest metric name
    width = max(5, max(len(name) for name in evals))
    header = " ".join(("{:>%d}" % width).format(name) for name in evals)
    scores = " ".join(("{:%d.2f}" % width).format(100 * evaluation[name].f1) for name in evals)
    return "\n".join([header, scores])
207
+
208
def run_eval_script_tokens(eval_gold, eval_pred):
    """Score tokenization: token, sentence, and word F1."""
    return run_eval_script(eval_gold, eval_pred, evals=["Tokens", "Sentences", "Words"])

def run_eval_script_mwt(eval_gold, eval_pred):
    """Score multi-word token expansion: word F1 only."""
    return run_eval_script(eval_gold, eval_pred, evals=["Words"])

def run_eval_script_pos(eval_gold, eval_pred):
    """Score tagging: UPOS, XPOS, UFeats, and the AllTags combination."""
    return run_eval_script(eval_gold, eval_pred, evals=["UPOS", "XPOS", "UFeats", "AllTags"])

def run_eval_script_depparse(eval_gold, eval_pred):
    """Score dependency parsing: attachment scores plus content-word variants."""
    return run_eval_script(eval_gold, eval_pred, evals=["UAS", "LAS", "CLAS", "MLAS", "BLEX"])
219
+
220
+
221
def find_wordvec_pretrain(language, default_pretrains, dataset_pretrains=None, dataset=None, model_dir=DEFAULT_MODEL_DIR):
    """
    Locate (downloading if necessary) a word vector pretrain .pt file.

    Preference order:
      1) the dataset-specific default from dataset_pretrains, if given
      2) the language-wide default from default_pretrains
      3) any single .pt file already present under model_dir for the language
    Raises FileNotFoundError if nothing can be found or the choice is ambiguous.
    """
    # try to get the default pretrain for the language,
    # but allow the package specific value to override it if that is set
    default_pt = default_pretrains.get(language, None)
    if dataset is not None and dataset_pretrains is not None:
        default_pt = dataset_pretrains.get(language, {}).get(dataset, default_pt)

    if default_pt is not None:
        default_pt_path = '{}/{}/pretrain/{}.pt'.format(model_dir, language, default_pt)
        if not os.path.exists(default_pt_path):
            logger.info("Default pretrain should be {} Attempting to download".format(default_pt_path))
            try:
                download(lang=language, package=None, processors={"pretrain": default_pt}, model_dir=model_dir)
            except UnknownLanguageError:
                # if there's a pretrain in the directory, hiding this
                # error will let us find that pretrain later
                pass
        if os.path.exists(default_pt_path):
            if dataset is not None and dataset_pretrains is not None and language in dataset_pretrains and dataset in dataset_pretrains[language]:
                logger.info(f"Using default pretrain for {language}:{dataset}, found in {default_pt_path} To use a different pretrain, specify --wordvec_pretrain_file")
            else:
                logger.info(f"Using default pretrain for language, found in {default_pt_path} To use a different pretrain, specify --wordvec_pretrain_file")
            return default_pt_path

    # no usable default - fall back to whatever single pretrain is on disk
    pretrain_path = '{}/{}/pretrain/*.pt'.format(model_dir, language)
    pretrains = glob.glob(pretrain_path)
    if len(pretrains) == 0:
        # we already tried to download the default pretrain once
        # and it didn't work. maybe the default language package
        # will have something?
        logger.warning(f"Cannot figure out which pretrain to use for '{language}'. Will download the default package and hope for the best")
        try:
            download(lang=language, model_dir=model_dir)
        except UnknownLanguageError as e:
            # this is a very unusual situation
            # basically, there was a language which we started to add
            # to the resources, but then didn't release the models
            # as part of resources.json
            raise FileNotFoundError(f"Cannot find any pretrains in {pretrain_path} No pretrains in the system for this language. Please prepare an embedding as a .pt and use --wordvec_pretrain_file to specify a .pt file to use") from e
        pretrains = glob.glob(pretrain_path)
    if len(pretrains) == 0:
        raise FileNotFoundError(f"Cannot find any pretrains in {pretrain_path} Try 'stanza.download(\"{language}\")' to get a default pretrain or use --wordvec_pretrain_file to specify a .pt file to use")
    if len(pretrains) > 1:
        # ambiguous: refuse to guess between multiple candidate files
        raise FileNotFoundError(f"Too many pretrains to choose from in {pretrain_path} Must specify an exact path to a --wordvec_pretrain_file")
    pt = pretrains[0]
    logger.info(f"Using pretrain found in {pt} To use a different pretrain, specify --wordvec_pretrain_file")
    return pt
268
+
269
def find_charlm_file(direction, language, charlm, model_dir=DEFAULT_MODEL_DIR):
    """
    Return the path to the forward or backward charlm if it exists for the given package

    If we can figure out the package, but can't find it anywhere, we try to download it
    """
    # a freshly trained model in the local save directory wins over resources
    saved_path = 'saved_models/charlm/{}_{}_{}_charlm.pt'.format(language, charlm, direction)
    resource_path = '{}/{}/{}_charlm/{}.pt'.format(model_dir, language, direction, charlm)

    for candidate in (saved_path, resource_path):
        if os.path.exists(candidate):
            logger.info(f'Using model {candidate} for {direction} charlm')
            return candidate

    not_found = f"Cannot find {direction} charlm in either {saved_path} or {resource_path} Attempted downloading {charlm} but that did not work"
    try:
        # last resort: try to fetch the charlm into the resources dir
        download(lang=language, package=None, processors={f"{direction}_charlm": charlm}, model_dir=model_dir)
        if os.path.exists(resource_path):
            logger.info(f'Downloaded model, using model {resource_path} for {direction} charlm')
            return resource_path
    except ValueError as e:
        raise FileNotFoundError(not_found) from e

    raise FileNotFoundError(not_found)
294
+
295
def build_charlm_args(language, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
    """
    If specified, return forward and backward charlm args

    Returns ['--charlm_forward_file', ..., '--charlm_backward_file', ...],
    prefixed with '--charlm' and '--charlm_shorthand' when base_args is True.
    Returns [] when charlm is falsy.  Raises FileNotFoundError if the charlm
    cannot be located or downloaded.
    """
    if charlm:
        try:
            forward = find_charlm_file('forward', language, charlm, model_dir=model_dir)
            backward = find_charlm_file('backward', language, charlm, model_dir=model_dir)
        except FileNotFoundError as e:
            # if we couldn't find sd_isra when training an SD model,
            # for example, but isra exists, we try to download the
            # shorter model name
            if charlm.startswith(language + "_"):
                short_charlm = charlm[len(language)+1:]
                try:
                    forward = find_charlm_file('forward', language, short_charlm, model_dir=model_dir)
                    backward = find_charlm_file('backward', language, short_charlm, model_dir=model_dir)
                except FileNotFoundError:
                    # chain the original error - the full name is what was asked for
                    raise FileNotFoundError("Tried to find charlm %s, which doesn't exist. Also tried %s, but didn't find that either" % (charlm, short_charlm)) from e
                logger.warning("Was asked to find charlm %s, which does not exist. Did find %s though", charlm, short_charlm)
            else:
                raise

        char_args = ['--charlm_forward_file', forward,
                     '--charlm_backward_file', backward]
        if not base_args:
            return char_args
        return ['--charlm',
                '--charlm_shorthand', f'{language}_{charlm}'] + char_args

    return []
326
+
327
def choose_charlm(language, dataset, charlm, language_charlms, dataset_charlms):
    """
    Resolve which charlm package to use.

    charlm is None -> no charlm.  Any concrete value other than "default"
    is used verbatim.  "default" picks the dataset-specific entry when the
    dataset appears in dataset_charlms (even when mapped to None/"", which
    explicitly disables the charlm), otherwise the language-wide default,
    otherwise None.
    """
    if charlm is None:
        return None
    if charlm != "default":
        return charlm

    per_dataset = dataset_charlms.get(language, {})
    if dataset in per_dataset:
        # an explicit entry (possibly None or "") overrides the language default;
        # "not in the map" is how dataset_charlms signals to use the default
        return per_dataset[dataset]

    fallback = language_charlms.get(language, None)
    return fallback if fallback else None
347
+
348
def choose_pos_charlm(short_language, dataset, charlm):
    """
    Resolve the charlm package for the POS tagger.

    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm
    """
    return choose_charlm(short_language, dataset, charlm, default_charlms, pos_charlms)

def choose_depparse_charlm(short_language, dataset, charlm):
    """
    Resolve the charlm package for the dependency parser.

    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm
    """
    return choose_charlm(short_language, dataset, charlm, default_charlms, depparse_charlms)

def choose_lemma_charlm(short_language, dataset, charlm):
    """
    Resolve the charlm package for the lemmatizer.

    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm
    """
    return choose_charlm(short_language, dataset, charlm, default_charlms, lemma_charlms)
368
+
369
def choose_transformer(short_language, command_args, extra_args, warn=True, layers=False):
    """
    Build --bert_model (and optionally --bert_hidden_layers) args using the
    default transformer for this language.

    Returns [] when no transformer was requested, when one was already
    supplied in extra_args, or when the language has no known default.
    """
    wants_bert = command_args is not None and command_args.use_bert
    if not wants_bert or '--bert_model' in extra_args:
        return []

    if short_language not in TRANSFORMERS:
        if warn:
            logger.error("Transformer requested, but no default transformer for %s Specify one using --bert_model" % short_language)
        return []

    bert_args = ['--bert_model', TRANSFORMERS.get(short_language)]
    if layers and short_language in TRANSFORMER_LAYERS and '--bert_hidden_layers' not in extra_args:
        bert_args.extend(['--bert_hidden_layers', str(TRANSFORMER_LAYERS.get(short_language))])
    return bert_args
383
+
384
def build_pos_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
    """Charlm args for the POS tagger: resolve the package, then build file args."""
    return build_charlm_args(short_language, choose_pos_charlm(short_language, dataset, charlm), base_args, model_dir)

def build_lemma_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
    """Charlm args for the lemmatizer: resolve the package, then build file args."""
    return build_charlm_args(short_language, choose_lemma_charlm(short_language, dataset, charlm), base_args, model_dir)

def build_depparse_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR):
    """Charlm args for the dependency parser: resolve the package, then build file args."""
    return build_charlm_args(short_language, choose_depparse_charlm(short_language, dataset, charlm), base_args, model_dir)
stanza/stanza/utils/training/compose_ete_results.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Turn the ETE results into markdown
3
+
4
+ Parses blocks like this from the model eval script
5
+
6
+ 2022-01-14 01:23:34 INFO: End to end results for af_afribooms models on af_afribooms test data:
7
+ Metric | Precision | Recall | F1 Score | AligndAcc
8
+ -----------+-----------+-----------+-----------+-----------
9
+ Tokens | 99.93 | 99.92 | 99.93 |
10
+ Sentences | 100.00 | 100.00 | 100.00 |
11
+ Words | 99.93 | 99.92 | 99.93 |
12
+ UPOS | 97.97 | 97.96 | 97.97 | 98.04
13
+ XPOS | 93.98 | 93.97 | 93.97 | 94.04
14
+ UFeats | 97.23 | 97.22 | 97.22 | 97.29
15
+ AllTags | 93.89 | 93.88 | 93.88 | 93.95
16
+ Lemmas | 97.40 | 97.39 | 97.39 | 97.46
17
+ UAS | 87.39 | 87.38 | 87.38 | 87.45
18
+ LAS | 83.57 | 83.56 | 83.57 | 83.63
19
+ CLAS | 76.88 | 76.45 | 76.66 | 76.52
20
+ MLAS | 72.28 | 71.87 | 72.07 | 71.94
21
+ BLEX | 73.20 | 72.79 | 73.00 | 72.86
22
+
23
+
24
+ Turns them into a markdown table.
25
+
26
+ Included is an attempt to mark the default packages with a green check.
27
+ <i class="fas fa-check" style="color:#33a02c"></i>
28
+ """
29
+
30
+ import argparse
31
+
32
+ from stanza.models.common.constant import pretty_langcode_to_lang
33
+ from stanza.models.common.short_name_to_treebank import short_name_to_treebank
34
+ from stanza.utils.training.run_ete import RESULTS_STRING
35
+ from stanza.resources.default_packages import default_treebanks
36
+
37
+ EXPECTED_ORDER = ["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"]
38
+
39
+ parser = argparse.ArgumentParser()
40
+ parser.add_argument("filenames", type=str, nargs="+", help="Which file(s) to read")
41
+ args = parser.parse_args()
42
+
43
+ lines = []
44
+ for filename in args.filenames:
45
+ with open(filename) as fin:
46
+ lines.extend(fin.readlines())
47
+
48
+ blocks = []
49
+ index = 0
50
+ while index < len(lines):
51
+ line = lines[index]
52
+ if line.find(RESULTS_STRING) < 0:
53
+ index = index + 1
54
+ continue
55
+
56
+ line = line[line.find(RESULTS_STRING) + len(RESULTS_STRING):].strip()
57
+ short_name = line.split()[0]
58
+
59
+ # skip the header of the expected output
60
+ index = index + 1
61
+ line = lines[index]
62
+ pieces = line.split("|")
63
+ assert pieces[0].strip() == 'Metric', "output format changed?"
64
+ assert pieces[3].strip() == 'F1 Score', "output format changed?"
65
+
66
+ index = index + 1
67
+ line = lines[index]
68
+ assert line.startswith("-----"), "output format changed?"
69
+
70
+ index = index + 1
71
+
72
+ block = lines[index:index+13]
73
+ assert len(block) == 13
74
+ index = index + 13
75
+
76
+ block = [x.split("|") for x in block]
77
+ assert all(x[0].strip() == y for x, y in zip(block, EXPECTED_ORDER)), "output format changed?"
78
+ lcode, short_dataset = short_name.split("_", 1)
79
+ language = pretty_langcode_to_lang(lcode)
80
+ treebank = short_name_to_treebank(short_name)
81
+ long_dataset = treebank.split("-")[-1]
82
+
83
+ checkmark = ""
84
+ if default_treebanks[lcode] == short_dataset:
85
+ checkmark = '<i class="fas fa-check" style="color:#33a02c"></i>'
86
+
87
+ block = [language, "[%s](%s)" % (long_dataset, "https://github.com/UniversalDependencies/%s" % treebank), lcode, checkmark] + [x[3].strip() for x in block]
88
+ blocks.append(block)
89
+
90
+ PREFIX = ["&#8203;Macro Avg", "&#8203;", "&#8203;", ""]
91
+
92
+ avg = [sum(float(x[i]) for x in blocks) / len(blocks) for i in range(len(PREFIX), len(EXPECTED_ORDER) + len(PREFIX))]
93
+ avg = PREFIX + ["%.2f" % x for x in avg]
94
+ blocks = sorted(blocks)
95
+ blocks = [avg] + blocks
96
+
97
+ chart = ["|%s|" % " | ".join(x) for x in blocks]
98
+ for line in chart:
99
+ print(line)
100
+
stanza/stanza/utils/training/run_charlm.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Trains or scores a charlm model.
3
+ """
4
+
5
+ import logging
6
+ import os
7
+
8
+ from stanza.models import charlm
9
+ from stanza.utils.training import common
10
+ from stanza.utils.training.common import Mode
11
+
12
+ logger = logging.getLogger('stanza')
13
+
14
+
15
def add_charlm_args(parser):
    """
    Charlm-specific flags: which direction of language model to build.

    --direction takes the value directly; --forward / --backward are
    shortcuts that store the corresponding constant.
    """
    parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help="Forward or backward language model")
    for direction in ('forward', 'backward'):
        parser.add_argument('--%s' % direction, action='store_const', dest='direction', const=direction,
                            help="Train a %s language model" % direction)
22
+
23
+
24
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """
    Train or score a single charlm dataset.

    Builds the train/dev/test paths under CHARLM_DATA_DIR for this
    language/dataset and dispatches to stanza.models.charlm with the
    appropriate --mode, honoring any overrides already in extra_args.
    """
    short_language, dataset_name = short_name.split("_", 1)

    train_dir = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "train")

    # dev/test may be stored xz-compressed; fall back to the .xz file
    # when the plain text version is absent
    dev_file = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "dev.txt")
    if not os.path.exists(dev_file) and os.path.exists(dev_file + ".xz"):
        dev_file = dev_file + ".xz"

    test_file = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "test.txt")
    if not os.path.exists(test_file) and os.path.exists(test_file + ".xz"):
        test_file = test_file + ".xz"

    # equivalent shell invocations:
    # python -m stanza.models.charlm --train_dir $train_dir --eval_file $dev_file \
    #     --direction $direction --shorthand $short --mode train $args
    # python -m stanza.models.charlm --eval_file $dev_file \
    #     --direction $direction --shorthand $short --mode predict $args
    # python -m stanza.models.charlm --eval_file $test_file \
    #     --direction $direction --shorthand $short --mode predict $args

    direction = command_args.direction
    default_args = ['--%s' % direction,
                    '--shorthand', short_name]
    if mode == Mode.TRAIN:
        train_args = ['--mode', 'train']
        # only fill in defaults the user did not explicitly override
        if '--train_dir' not in extra_args:
            train_args += ['--train_dir', train_dir]
        if '--eval_file' not in extra_args:
            train_args += ['--eval_file', dev_file]
        train_args = train_args + default_args + extra_args
        logger.info("Running train step with args: %s", train_args)
        charlm.main(train_args)

    if mode == Mode.SCORE_DEV:
        dev_args = ['--mode', 'predict']
        if '--eval_file' not in extra_args:
            dev_args += ['--eval_file', dev_file]
        dev_args = dev_args + default_args + extra_args
        logger.info("Running dev step with args: %s", dev_args)
        charlm.main(dev_args)

    if mode == Mode.SCORE_TEST:
        test_args = ['--mode', 'predict']
        if '--eval_file' not in extra_args:
            test_args += ['--eval_file', test_file]
        test_args = test_args + default_args + extra_args
        logger.info("Running test step with args: %s", test_args)
        charlm.main(test_args)
73
+
74
+
75
def get_model_name(args):
    """
    The charlm saves forward and backward charlms to the same dir, but with different filenames
    """
    return "{}_charlm".format(args.direction)
80
+
81
def main():
    """Run the charlm train/score driver through the shared training harness."""
    common.main(run_treebank, "charlm", get_model_name, add_charlm_args, charlm.build_argparse())

if __name__ == "__main__":
    main()
86
+
stanza/stanza/utils/training/run_constituency.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Trains or scores a constituency model.
3
+
4
+ Currently a suuuuper preliminary script.
5
+
6
+ Example of how to run on multiple parsers at the same time on the Stanford workqueue:
7
+
8
+ for i in `echo 1000 1001 1002 1003 1004`; do nlprun -d a6000 "python3 stanza/utils/training/run_constituency.py vi_vlsp23 --use_bert --stage1_bert_finetun --save_name vi_vlsp23_$i.pt --seed $i --epochs 200 --force" -o vi_vlsp23_$i.out; done
9
+
10
+ """
11
+
12
+ import logging
13
+ import os
14
+
15
+ from stanza.models import constituency_parser
16
+ from stanza.models.constituency.retagging import RETAG_METHOD
17
+ from stanza.utils.datasets.constituency import prepare_con_dataset
18
+ from stanza.utils.training import common
19
+ from stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain
20
+
21
+ from stanza.resources.default_packages import default_charlms, default_pretrains
22
+
23
+ logger = logging.getLogger('stanza')
24
+
25
+ def add_constituency_args(parser):
26
+ add_charlm_args(parser)
27
+
28
+ parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language')
29
+
30
+ parser.add_argument('--parse_text', dest='mode', action='store_const', const="parse_text", help='Parse a text file')
31
+
32
+ def build_wordvec_args(short_language, dataset, extra_args):
33
+ if '--wordvec_pretrain_file' not in extra_args:
34
+ # will throw an error if the pretrain can't be found
35
+ wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains)
36
+ wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]
37
+ else:
38
+ wordvec_args = []
39
+
40
+ return wordvec_args
41
+
42
+ def build_default_args(paths, short_language, dataset, command_args, extra_args):
43
+ if short_language in RETAG_METHOD:
44
+ retag_args = ["--retag_method", RETAG_METHOD[short_language]]
45
+ else:
46
+ retag_args = []
47
+
48
+ wordvec_args = build_wordvec_args(short_language, dataset, extra_args)
49
+
50
+ charlm = choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {})
51
+ charlm_args = build_charlm_args(short_language, charlm, base_args=False)
52
+
53
+ bert_args = common.choose_transformer(short_language, command_args, extra_args, warn=True, layers=True)
54
+ default_args = retag_args + wordvec_args + charlm_args + bert_args
55
+
56
+ return default_args
57
+
58
+ def build_model_filename(paths, short_name, command_args, extra_args):
59
+ short_language, dataset = short_name.split("_", 1)
60
+
61
+ default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)
62
+
63
+ train_args = ["--shorthand", short_name,
64
+ "--mode", "train"]
65
+ train_args = train_args + default_args
66
+ if command_args.save_name is not None:
67
+ train_args.extend(["--save_name", command_args.save_name])
68
+ if command_args.save_dir is not None:
69
+ train_args.extend(["--save_dir", command_args.save_dir])
70
+ args = constituency_parser.parse_args(train_args)
71
+ save_name = constituency_parser.build_model_filename(args)
72
+ return save_name
73
+
74
+
75
+ def run_treebank(mode, paths, treebank, short_name, temp_output_file, command_args, extra_args):
76
+ constituency_dir = paths["CONSTITUENCY_DATA_DIR"]
77
+ short_language, dataset = short_name.split("_")
78
+
79
+ train_file = os.path.join(constituency_dir, f"{short_name}_train.mrg")
80
+ dev_file = os.path.join(constituency_dir, f"{short_name}_dev.mrg")
81
+ test_file = os.path.join(constituency_dir, f"{short_name}_test.mrg")
82
+
83
+ if not os.path.exists(train_file) or not os.path.exists(dev_file) or not os.path.exists(test_file):
84
+ logger.warning(f"The data for {short_name} is missing or incomplete. Attempting to rebuild...")
85
+ try:
86
+ prepare_con_dataset.main(short_name)
87
+ except:
88
+ logger.error(f"Unable to build the data. Please correctly build the files in {train_file}, {dev_file}, {test_file} and then try again.")
89
+ raise
90
+
91
+ default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)
92
+
93
+ if mode == Mode.TRAIN:
94
+ train_args = ['--train_file', train_file,
95
+ '--eval_file', dev_file,
96
+ '--shorthand', short_name,
97
+ '--mode', 'train']
98
+ train_args = train_args + default_args + extra_args
99
+ logger.info("Running train step with args: {}".format(train_args))
100
+ constituency_parser.main(train_args)
101
+
102
+ if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
103
+ dev_args = ['--eval_file', dev_file,
104
+ '--shorthand', short_name,
105
+ '--mode', 'predict']
106
+ dev_args = dev_args + default_args + extra_args
107
+ logger.info("Running dev step with args: {}".format(dev_args))
108
+ constituency_parser.main(dev_args)
109
+
110
+ if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
111
+ test_args = ['--eval_file', test_file,
112
+ '--shorthand', short_name,
113
+ '--mode', 'predict']
114
+ test_args = test_args + default_args + extra_args
115
+ logger.info("Running test step with args: {}".format(test_args))
116
+ constituency_parser.main(test_args)
117
+
118
+ if mode == "parse_text":
119
+ text_args = ['--shorthand', short_name,
120
+ '--mode', 'parse_text']
121
+ text_args = text_args + default_args + extra_args
122
+ logger.info("Processing text with args: {}".format(text_args))
123
+ constituency_parser.main(text_args)
124
+
125
+ def main():
126
+ common.main(run_treebank, "constituency", "constituency", add_constituency_args, sub_argparse=constituency_parser.build_argparse(), build_model_filename=build_model_filename)
127
+
128
+ if __name__ == "__main__":
129
+ main()
130
+
stanza/stanza/utils/training/run_depparse.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ from stanza.models import parser
5
+
6
+ from stanza.utils.training import common
7
+ from stanza.utils.training.common import Mode, add_charlm_args, build_depparse_charlm_args, choose_depparse_charlm, choose_transformer
8
+ from stanza.utils.training.run_pos import wordvec_args
9
+
10
+ from stanza.resources.default_packages import default_charlms, depparse_charlms
11
+
12
+ logger = logging.getLogger('stanza')
13
+
14
+ def add_depparse_args(parser):
15
+ add_charlm_args(parser)
16
+
17
+ parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language')
18
+
19
+ # TODO: refactor with run_pos
20
+ def build_model_filename(paths, short_name, command_args, extra_args):
21
+ short_language, dataset = short_name.split("_", 1)
22
+
23
+ # TODO: can avoid downloading the charlm at this point, since we
24
+ # might not even be training
25
+ charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm)
26
+
27
+ bert_args = choose_transformer(short_language, command_args, extra_args, warn=False)
28
+
29
+ train_args = ["--shorthand", short_name,
30
+ "--mode", "train"]
31
+ # TODO: also, this downloads the wordvec, which we might not want to do yet
32
+ train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args
33
+ if command_args.save_name is not None:
34
+ train_args.extend(["--save_name", command_args.save_name])
35
+ if command_args.save_dir is not None:
36
+ train_args.extend(["--save_dir", command_args.save_dir])
37
+ args = parser.parse_args(train_args)
38
+ save_name = parser.model_file_name(args)
39
+ return save_name
40
+
41
+
42
+ def run_treebank(mode, paths, treebank, short_name,
43
+ temp_output_file, command_args, extra_args):
44
+ short_language, dataset = short_name.split("_")
45
+
46
+ # TODO: refactor these blocks?
47
+ depparse_dir = paths["DEPPARSE_DATA_DIR"]
48
+ train_file = f"{depparse_dir}/{short_name}.train.in.conllu"
49
+ dev_in_file = f"{depparse_dir}/{short_name}.dev.in.conllu"
50
+ dev_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.dev.pred.conllu"
51
+ test_in_file = f"{depparse_dir}/{short_name}.test.in.conllu"
52
+ test_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.test.pred.conllu"
53
+
54
+ eval_file = None
55
+ if '--eval_file' in extra_args:
56
+ eval_file = extra_args[extra_args.index('--eval_file') + 1]
57
+
58
+ charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm)
59
+
60
+ bert_args = choose_transformer(short_language, command_args, extra_args)
61
+
62
+ if mode == Mode.TRAIN:
63
+ if not os.path.exists(train_file):
64
+ logger.error("TRAIN FILE NOT FOUND: %s ... skipping" % train_file)
65
+ return
66
+
67
+ # some languages need reduced batch size
68
+ if short_name == 'de_hdt':
69
+ # 'UD_German-HDT'
70
+ batch_size = "1300"
71
+ elif short_name in ('hr_set', 'fi_tdt', 'ru_taiga', 'cs_cltt', 'gl_treegal', 'lv_lvtb', 'ro_simonero'):
72
+ # 'UD_Croatian-SET', 'UD_Finnish-TDT', 'UD_Russian-Taiga',
73
+ # 'UD_Czech-CLTT', 'UD_Galician-TreeGal', 'UD_Latvian-LVTB' 'Romanian-SiMoNERo'
74
+ batch_size = "3000"
75
+ else:
76
+ batch_size = "5000"
77
+
78
+ train_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
79
+ "--train_file", train_file,
80
+ "--eval_file", eval_file if eval_file else dev_in_file,
81
+ "--output_file", dev_pred_file,
82
+ "--batch_size", batch_size,
83
+ "--lang", short_language,
84
+ "--shorthand", short_name,
85
+ "--mode", "train"]
86
+ train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
87
+ train_args = train_args + extra_args
88
+ logger.info("Running train depparse for {} with args {}".format(treebank, train_args))
89
+ parser.main(train_args)
90
+
91
+ if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
92
+ dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
93
+ "--eval_file", eval_file if eval_file else dev_in_file,
94
+ "--output_file", dev_pred_file,
95
+ "--lang", short_language,
96
+ "--shorthand", short_name,
97
+ "--mode", "predict"]
98
+ dev_args = dev_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
99
+ dev_args = dev_args + extra_args
100
+ logger.info("Running dev depparse for {} with args {}".format(treebank, dev_args))
101
+ parser.main(dev_args)
102
+
103
+ if '--no_gold_labels' not in extra_args:
104
+ results = common.run_eval_script_depparse(eval_file if eval_file else dev_in_file, dev_pred_file)
105
+ logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))
106
+ if not temp_output_file:
107
+ logger.info("Output saved to %s", dev_pred_file)
108
+
109
+ if mode == Mode.SCORE_TEST:
110
+ test_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
111
+ "--eval_file", eval_file if eval_file else test_in_file,
112
+ "--output_file", test_pred_file,
113
+ "--lang", short_language,
114
+ "--shorthand", short_name,
115
+ "--mode", "predict"]
116
+ test_args = test_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
117
+ test_args = test_args + extra_args
118
+ logger.info("Running test depparse for {} with args {}".format(treebank, test_args))
119
+ parser.main(test_args)
120
+
121
+ if '--no_gold_labels' not in extra_args:
122
+ results = common.run_eval_script_depparse(eval_file if eval_file else test_in_file, test_pred_file)
123
+ logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
124
+ if not temp_output_file:
125
+ logger.info("Output saved to %s", test_pred_file)
126
+
127
+
128
+ def main():
129
+ common.main(run_treebank, "depparse", "parser", add_depparse_args, sub_argparse=parser.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_depparse_charlm)
130
+
131
+ if __name__ == "__main__":
132
+ main()
133
+
stanza/stanza/utils/training/run_lemma.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script allows for training or testing on dev / test of the UD lemmatizer.
3
+
4
+ If run with a single treebank name, it will train or test that treebank.
5
+ If run with ud_all or all_ud, it will iterate over all UD treebanks it can find.
6
+
7
+ Mode can be set to train&dev with --train, to dev set only
8
+ with --score_dev, and to test set only with --score_test.
9
+
10
+ Treebanks are specified as a list. all_ud or ud_all means to look for
11
+ all UD treebanks.
12
+
13
+ Extra arguments are passed to the lemmatizer. In case the run script
14
+ itself is shadowing arguments, you can specify --extra_args as a
15
+ parameter to mark where the lemmatizer arguments start.
16
+ """
17
+
18
+ import logging
19
+ import os
20
+
21
+ from stanza.models import identity_lemmatizer
22
+ from stanza.models import lemmatizer
23
+ from stanza.models.lemma import attach_lemma_classifier
24
+
25
+ from stanza.utils.training import common
26
+ from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm
27
+ from stanza.utils.training import run_lemma_classifier
28
+
29
+ from stanza.utils.datasets.prepare_lemma_treebank import check_lemmas
30
+ import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier
31
+
32
+ logger = logging.getLogger('stanza')
33
+
34
def add_lemma_args(parser):
    """Add the lemmatizer-specific flags (charlm + lemma classifier toggle) to the parser."""
    add_charlm_args(parser)

    # default=None (rather than False) lets run_treebank distinguish "flag not
    # given" from an explicit --no_lemma_classifier: when unset, the classifier
    # is attached iff a charlm is in use and the dataset supports it
    # (fixed: the help text here was a copy-paste of the --no_lemma_classifier help)
    parser.add_argument('--lemma_classifier', dest='lemma_classifier', action='store_true', default=None,
                        help="Use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer if the charlm is used")
    parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false',
                        help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer if the charlm is used")
41
+
42
def build_model_filename(paths, short_name, command_args, extra_args):
    """
    Figure out what the model savename will be, taking into account the model settings.

    Useful for figuring out if the model already exists

    None will represent that there is no expected save_name

    Args:
      paths: dict of data directories; only LEMMA_DATA_DIR is read here
      short_name: "{lang}_{dataset}" treebank shorthand
      command_args: parsed run-script args; only .charlm is read here
      extra_args: extra command line args forwarded to the lemmatizer
    """
    short_language, dataset = short_name.split("_", 1)

    lemma_dir = paths["LEMMA_DATA_DIR"]
    train_file = f"{lemma_dir}/{short_name}.train.in.conllu"

    # without prepared training data we cannot inspect it, so no save_name
    if not os.path.exists(train_file):
        logger.debug("Treebank %s is not prepared for training the lemmatizer. Could not find any training data at %s Cannot figure out the expected save_name without looking at the data, but a later step in the process will skip the training anyway" % (short_name, train_file))
        return None

    # no lemmas -> run_treebank uses the identity lemmatizer, which has no saved model
    has_lemmas = check_lemmas(train_file)
    if not has_lemmas:
        return None

    # TODO: can avoid downloading the charlm at this point, since we
    # might not even be training
    charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)

    # build the same arg list training would use, then ask the lemmatizer
    # itself where it would save the model
    train_args = ["--train_file", train_file,
                  "--shorthand", short_name,
                  "--mode", "train"]
    train_args = train_args + charlm_args + extra_args
    args = lemmatizer.parse_args(train_args)
    save_name = lemmatizer.build_model_filename(args)
    return save_name
74
+
75
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or score the lemmatizer for one treebank.

    Falls back to the identity lemmatizer when the training data has no
    lemmas.  In TRAIN mode, optionally trains and attaches a lemma
    classifier afterwards (see use_lemma_classifier below).

    Args:
      mode: Mode.TRAIN, Mode.SCORE_DEV, or Mode.SCORE_TEST
      paths: dict of data directories; only LEMMA_DATA_DIR is read here
      treebank: long treebank name, used for logging and the classifier run
      short_name: "{lang}_{dataset}" shorthand
      temp_output_file: if set, predictions go here instead of the data dir
      command_args: parsed run-script args (.charlm, .lemma_classifier, .force)
      extra_args: extra args forwarded to the underlying lemmatizer
    """
    short_language, dataset = short_name.split("_", 1)

    lemma_dir = paths["LEMMA_DATA_DIR"]
    train_file = f"{lemma_dir}/{short_name}.train.in.conllu"
    dev_in_file = f"{lemma_dir}/{short_name}.dev.in.conllu"
    dev_gold_file = f"{lemma_dir}/{short_name}.dev.gold.conllu"
    dev_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.dev.pred.conllu"
    test_in_file = f"{lemma_dir}/{short_name}.test.in.conllu"
    test_gold_file = f"{lemma_dir}/{short_name}.test.gold.conllu"
    test_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.test.pred.conllu"

    charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)

    if not os.path.exists(train_file):
        logger.error("Treebank %s is not prepared for training the lemmatizer. Could not find any training data at %s Skipping..." % (treebank, train_file))
        return

    has_lemmas = check_lemmas(train_file)
    if not has_lemmas:
        # identity lemmatizer: lemma == word; it is "trained" and scored in one call
        logger.info("Treebank " + treebank + " (" + short_name +
                    ") has no lemmas. Using identity lemmatizer")
        if mode == Mode.TRAIN or mode == Mode.SCORE_DEV:
            train_args = ["--train_file", train_file,
                         "--eval_file", dev_in_file,
                         "--output_file", dev_pred_file,
                         "--gold_file", dev_gold_file,
                         "--shorthand", short_name]
            logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
            identity_lemmatizer.main(train_args)
        elif mode == Mode.SCORE_TEST:
            train_args = ["--train_file", train_file,
                         "--eval_file", test_in_file,
                         "--output_file", test_pred_file,
                         "--gold_file", test_gold_file,
                         "--shorthand", short_name]
            logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
            identity_lemmatizer.main(train_args)
    else:
        if mode == Mode.TRAIN:
            # the very large treebanks get fewer epochs
            # ('UD_Czech-PDT', 'UD_Russian-SynTagRus', 'UD_German-HDT')
            if short_name in ('cs_pdt', 'ru_syntagrus', 'de_hdt'):
                num_epochs = "30"
            else:
                num_epochs = "60"

            train_args = ["--train_file", train_file,
                          "--eval_file", dev_in_file,
                          "--output_file", dev_pred_file,
                          "--gold_file", dev_gold_file,
                          "--shorthand", short_name,
                          "--num_epoch", num_epochs,
                          "--mode", "train"]
            train_args = train_args + charlm_args + extra_args
            logger.info("Running train lemmatizer for {} with args {}".format(treebank, train_args))
            lemmatizer.main(train_args)

        if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
            dev_args = ["--eval_file", dev_in_file,
                        "--output_file", dev_pred_file,
                        "--gold_file", dev_gold_file,
                        "--shorthand", short_name,
                        "--mode", "predict"]
            dev_args = dev_args + charlm_args + extra_args
            logger.info("Running dev lemmatizer for {} with args {}".format(treebank, dev_args))
            lemmatizer.main(dev_args)

        if mode == Mode.SCORE_TEST:
            test_args = ["--eval_file", test_in_file,
                         "--output_file", test_pred_file,
                         "--gold_file", test_gold_file,
                         "--shorthand", short_name,
                         "--mode", "predict"]
            test_args = test_args + charlm_args + extra_args
            logger.info("Running test lemmatizer for {} with args {}".format(treebank, test_args))
            lemmatizer.main(test_args)

        # when --lemma_classifier is not given, default to building it iff a
        # charlm is in use; either way the dataset must have classifier data
        use_lemma_classifier = command_args.lemma_classifier
        if use_lemma_classifier is None:
            use_lemma_classifier = command_args.charlm is not None
        use_lemma_classifier = use_lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING
        if use_lemma_classifier and mode == Mode.TRAIN:
            lc_charlm_args = ['--no_charlm'] if command_args.charlm is None else ['--charlm', command_args.charlm]
            lemma_classifier_args = [treebank] + lc_charlm_args
            if command_args.force:
                lemma_classifier_args.append('--force')
            run_lemma_classifier.main(lemma_classifier_args)

            save_name = build_model_filename(paths, short_name, command_args, extra_args)
            # TODO: use a temp path for the lemma_classifier or keep it somewhere
            attach_args = ['--input', save_name,
                           '--output', save_name,
                           '--classifier', 'saved_models/lemma_classifier/%s_lemma_classifier.pt' % short_name]
            attach_lemma_classifier.main(attach_args)

            # now we rerun the dev set - the HI in particular demonstrates some good improvement
            lemmatizer.main(dev_args)
173
+
174
def main():
    """Entry point: run the lemmatizer train/score workflow through the shared harness."""
    common.main(run_treebank, "lemma", "lemmatizer", add_lemma_args, sub_argparse=lemmatizer.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm)

if __name__ == "__main__":
    main()
179
+
stanza/stanza/utils/training/run_lemma_classifier.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from stanza.models.lemma_classifier import evaluate_models
4
+ from stanza.models.lemma_classifier import train_lstm_model
5
+ from stanza.models.lemma_classifier import train_transformer_model
6
+ from stanza.models.lemma_classifier.constants import ModelType
7
+
8
+ from stanza.resources.default_packages import default_pretrains, TRANSFORMERS
9
+ from stanza.utils.training import common
10
+ from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm, find_wordvec_pretrain
11
+
12
def add_lemma_args(parser):
    """Add charlm flags plus the lemma classifier's --model_type flag."""
    add_charlm_args(parser)

    # command line takes e.g. "lstm" / "transformer" (case-insensitive),
    # mapped to the ModelType enum
    parser.add_argument('--model_type', default=ModelType.LSTM, type=lambda x: ModelType[x.upper()],
                        help='Model type to use. {}'.format(", ".join(x.name for x in ModelType)))
17
+
18
def build_model_filename(paths, short_name, command_args, extra_args):
    """Return the fixed save path for this dataset's lemma classifier model.

    Only short_name is actually used; the other arguments exist to match the
    signature the common training harness expects.
    """
    model_file = "%s_lemma_classifier.pt" % short_name
    return os.path.join("saved_models", "lemma_classifier", model_file)
20
+
21
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or evaluate a lemma classifier for one dataset.

    LSTM models train via train_lstm_model; other model types go through
    train_transformer_model.  Evaluation always goes through evaluate_models.
    Data files are looked up under data/lemma_classifier unless overridden
    in extra_args.

    Args:
      mode: Mode.TRAIN, Mode.SCORE_DEV, or Mode.SCORE_TEST
      paths: dict of data directories (passed through to build_model_filename)
      treebank: long dataset name (unused here aside from the harness contract)
      short_name: "{lang}_{dataset}" shorthand
      temp_output_file: unused by this model type
      command_args: parsed run-script args (.charlm, .model_type, .force)
      extra_args: extra args forwarded to the training / eval scripts
    """
    short_language, dataset = short_name.split("_", 1)

    base_args = []
    if '--save_name' not in extra_args:
        base_args += ['--save_name', build_model_filename(paths, short_name, command_args, extra_args)]

    embedding_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)
    if '--wordvec_pretrain_file' not in extra_args:
        # will raise if no pretrain can be found for this language
        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, {}, dataset)
        embedding_args += ["--wordvec_pretrain_file", wordvec_pretrain]

    bert_args = []
    if command_args.model_type is ModelType.TRANSFORMER:
        if '--bert_model' not in extra_args:
            if short_language in TRANSFORMERS:
                bert_args = ['--bert_model', TRANSFORMERS.get(short_language)]
            else:
                raise ValueError("--bert_model not specified, so cannot figure out which transformer to use for language %s" % short_language)

    extra_train_args = []
    if command_args.force:
        extra_train_args.append('--force')

    if mode == Mode.TRAIN:
        train_args = []
        if "--train_file" not in extra_args:
            train_file = os.path.join("data", "lemma_classifier", "%s.train.lemma" % short_name)
            train_args += ['--train_file', train_file]
        if "--eval_file" not in extra_args:
            eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name)
            train_args += ['--eval_file', eval_file]
        train_args = base_args + train_args + extra_args + extra_train_args

        if command_args.model_type == ModelType.LSTM:
            train_args = embedding_args + train_args
            train_lstm_model.main(train_args)
        else:
            # non-LSTM types pass their own --model_type to the transformer trainer
            model_type_args = ["--model_type", command_args.model_type.name.lower()]
            train_args = bert_args + model_type_args + train_args
            train_transformer_model.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        eval_args = []
        if "--eval_file" not in extra_args:
            eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name)
            eval_args += ['--eval_file', eval_file]
        model_type_args = ["--model_type", command_args.model_type.name.lower()]
        eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args
        evaluate_models.main(eval_args)

    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
        eval_args = []
        if "--eval_file" not in extra_args:
            eval_file = os.path.join("data", "lemma_classifier", "%s.test.lemma" % short_name)
            eval_args += ['--eval_file', eval_file]
        model_type_args = ["--model_type", command_args.model_type.name.lower()]
        eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args
        evaluate_models.main(eval_args)
81
+
82
def main(args=None):
    """Entry point: run the lemma classifier workflow through the shared harness.

    The args parameter allows other scripts (e.g. the lemmatizer run script)
    to invoke this programmatically with a prebuilt argument list.
    """
    common.main(run_treebank, "lemma_classifier", "lemma_classifier", add_lemma_args, sub_argparse=train_lstm_model.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm, args=args)


if __name__ == '__main__':
    main()
stanza/stanza/utils/training/run_mwt.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script allows for training or testing on dev / test of the UD mwt tools.
3
+
4
+ If run with a single treebank name, it will train or test that treebank.
5
+ If run with ud_all or all_ud, it will iterate over all UD treebanks it can find.
6
+
7
+ Mode can be set to train&dev with --train, to dev set only
8
+ with --score_dev, and to test set only with --score_test.
9
+
10
+ Treebanks are specified as a list. all_ud or ud_all means to look for
11
+ all UD treebanks.
12
+
13
+ Extra arguments are passed to mwt. In case the run script
14
+ itself is shadowing arguments, you can specify --extra_args as a
15
+ parameter to mark where the mwt arguments start.
16
+ """
17
+
18
+
19
+ import logging
20
+ import math
21
+
22
+ from stanza.models import mwt_expander
23
+ from stanza.models.common.doc import Document
24
+ from stanza.utils.conll import CoNLL
25
+ from stanza.utils.training import common
26
+ from stanza.utils.training.common import Mode
27
+
28
+ from stanza.utils.max_mwt_length import max_mwt_length
29
+
30
+ logger = logging.getLogger('stanza')
31
+
32
def check_mwt(filename):
    """
    Return True when the conll file at filename contains at least one MWT
    """
    document = CoNLL.conll2doc(filename)
    expansions = document.get_mwt_expansions(False)
    return len(expansions) > 0
39
+
40
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or score the MWT expander for one treebank.

    Skips entirely when the training data has no MWTs; trains only the
    deterministic dictionary expander when the dev set has none.

    Args:
      mode: Mode.TRAIN, Mode.SCORE_DEV, or Mode.SCORE_TEST
      paths: dict of data directories; only MWT_DATA_DIR is read here
      treebank: long treebank name, used for logging
      short_name: "{lang}_{dataset}" shorthand
      temp_output_file: if set, predictions go here instead of the data dir
      command_args: parsed run-script args (unused here)
      extra_args: extra args forwarded to mwt_expander; --eval_file /
        --gold_file entries override the default file locations
    """
    short_language = short_name.split("_")[0]

    mwt_dir = paths["MWT_DATA_DIR"]

    train_file = f"{mwt_dir}/{short_name}.train.in.conllu"
    dev_in_file = f"{mwt_dir}/{short_name}.dev.in.conllu"
    dev_gold_file = f"{mwt_dir}/{short_name}.dev.gold.conllu"
    dev_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.dev.pred.conllu"
    test_in_file = f"{mwt_dir}/{short_name}.test.in.conllu"
    test_gold_file = f"{mwt_dir}/{short_name}.test.gold.conllu"
    test_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.test.pred.conllu"

    train_json = f"{mwt_dir}/{short_name}-ud-train-mwt.json"
    dev_json = f"{mwt_dir}/{short_name}-ud-dev-mwt.json"
    test_json = f"{mwt_dir}/{short_name}-ud-test-mwt.json"

    # allow the caller to redirect eval / gold files via extra_args
    eval_file = None
    if '--eval_file' in extra_args:
        eval_file = extra_args[extra_args.index('--eval_file') + 1]

    gold_file = None
    if '--gold_file' in extra_args:
        gold_file = extra_args[extra_args.index('--gold_file') + 1]

    if not check_mwt(train_file):
        logger.info("No training MWTS found for %s. Skipping" % treebank)
        return

    if not check_mwt(dev_in_file) and mode == Mode.TRAIN:
        logger.info("No dev MWTS found for %s. Training only the deterministic MWT expander" % treebank)
        # NOTE(review): this mutates the caller's extra_args list in place -
        # confirm no caller reuses the list across treebanks
        extra_args.append('--dict_only')

    if mode == Mode.TRAIN:
        # cap the decoder length slightly above the longest observed MWT
        max_mwt_len = math.ceil(max_mwt_length([train_json, dev_json]) * 1.1 + 1)
        logger.info("Max len: %f" % max_mwt_len)
        train_args = ['--train_file', train_file,
                      '--eval_file', eval_file if eval_file else dev_in_file,
                      '--output_file', dev_output_file,
                      '--gold_file', gold_file if gold_file else dev_gold_file,
                      '--lang', short_language,
                      '--shorthand', short_name,
                      '--mode', 'train',
                      '--max_dec_len', str(max_mwt_len)]
        train_args = train_args + extra_args
        logger.info("Running train step with args: {}".format(train_args))
        mwt_expander.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ['--eval_file', eval_file if eval_file else dev_in_file,
                    '--output_file', dev_output_file,
                    '--gold_file', gold_file if gold_file else dev_gold_file,
                    '--lang', short_language,
                    '--shorthand', short_name,
                    '--mode', 'predict']
        dev_args = dev_args + extra_args
        logger.info("Running dev step with args: {}".format(dev_args))
        mwt_expander.main(dev_args)

        results = common.run_eval_script_mwt(gold_file if gold_file else dev_gold_file, dev_output_file)
        logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))

    if mode == Mode.SCORE_TEST:
        test_args = ['--eval_file', eval_file if eval_file else test_in_file,
                     '--output_file', test_output_file,
                     '--gold_file', gold_file if gold_file else test_gold_file,
                     '--lang', short_language,
                     '--shorthand', short_name,
                     '--mode', 'predict']
        test_args = test_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        mwt_expander.main(test_args)

        results = common.run_eval_script_mwt(gold_file if gold_file else test_gold_file, test_output_file)
        logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
116
+
117
def main():
    """Entry point: run the MWT expander workflow through the shared harness.

    NOTE(review): unlike the sibling run scripts, no add-args function or
    build_model_filename is passed here - confirm that is intentional.
    """
    common.main(run_treebank, "mwt", "mwt_expander", sub_argparse=mwt_expander.build_argparse())

if __name__ == "__main__":
    main()
122
+
stanza/stanza/utils/training/run_ner.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Trains or scores an NER model.
3
+
4
+ Will attempt to guess the appropriate word vector file if none is
5
+ specified, and will use the charlms specified in the resources
6
+ for a given dataset or language if possible.
7
+
8
+ Example command line:
9
+ python3 -m stanza.utils.training.run_ner.py hu_combined
10
+
11
+ This script expects the prepared data to be in
12
+ data/ner/{lang}_{dataset}.train.json, {lang}_{dataset}.dev.json, {lang}_{dataset}.test.json
13
+
14
+ If those files don't exist, it will make an attempt to rebuild them
15
+ using the prepare_ner_dataset script. However, this will fail if the
16
+ data is not already downloaded. More information on where to find
17
+ most of the datasets online is in that script. Some of the datasets
18
+ have licenses which must be agreed to, so no attempt is made to
19
+ automatically download the data.
20
+ """
21
+
22
+ import logging
23
+ import os
24
+
25
+ from stanza.models import ner_tagger
26
+ from stanza.resources.common import DEFAULT_MODEL_DIR
27
+ from stanza.utils.datasets.ner import prepare_ner_dataset
28
+ from stanza.utils.training import common
29
+ from stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain
30
+
31
+ from stanza.resources.default_packages import default_charlms, default_pretrains, ner_charlms, ner_pretrains
32
+
33
# extra arguments specific to a particular dataset
# these are appended to the ner_tagger command line when training that dataset
DATASET_EXTRA_ARGS = {
    "da_ddt": [ "--dropout", "0.6" ],
    "fa_arman": [ "--dropout", "0.6" ],
    "vi_vlsp": [ "--dropout", "0.6",
                 "--word_dropout", "0.1",
                 "--locked_dropout", "0.1",
                 "--char_dropout", "0.1" ],
}
42
+
43
+ logger = logging.getLogger('stanza')
44
+
45
def add_ner_args(parser):
    """Add the NER-specific flags: the shared charlm controls and --use_bert."""
    add_charlm_args(parser)

    parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language')
49
+
50
+
51
def build_pretrain_args(language, dataset, charlm="default", command_args=None, extra_args=None, model_dir=DEFAULT_MODEL_DIR):
    """
    Returns one list with the args for this language & dataset's charlm and pretrained embedding

    Args:
      language: short language code, e.g. "vi"
      dataset: NER dataset name, e.g. "vlsp"
      charlm: "default" to pick from the resources tables, or a specific charlm name
      command_args / extra_args: used to decide on a transformer and to detect
        an explicit --wordvec_pretrain_file override
      model_dir: where stanza resources live
    """
    charlm = choose_charlm(language, dataset, charlm, default_charlms, ner_charlms)
    charlm_args = build_charlm_args(language, charlm, model_dir=model_dir)

    wordvec_args = []
    if extra_args is None or '--wordvec_pretrain_file' not in extra_args:
        # will throw an error if the pretrain can't be found
        wordvec_pretrain = find_wordvec_pretrain(language, default_pretrains, ner_pretrains, dataset, model_dir=model_dir)
        wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]

    bert_args = common.choose_transformer(language, command_args, extra_args, warn=False)

    return charlm_args + wordvec_args + bert_args
67
+
68
+
69
+ # TODO: refactor? tagger and depparse should be pretty similar
70
def build_model_filename(paths, short_name, command_args, extra_args):
    """Predict where the NER model for this dataset would be saved.

    Builds the same argument list training would use, then asks the tagger
    itself for the resulting model path.  Useful for checking whether the
    model already exists before training.
    """
    short_language, dataset = short_name.split("_", 1)

    # TODO: can avoid downloading the charlm at this point, since we
    # might not even be training
    pretrain_args = build_pretrain_args(short_language, dataset, command_args.charlm, command_args, extra_args)

    dataset_args = DATASET_EXTRA_ARGS.get(short_name, [])

    train_args = ["--shorthand", short_name,
                  "--mode", "train"]
    train_args = train_args + pretrain_args + dataset_args + extra_args
    if command_args.save_name is not None:
        train_args.extend(["--save_name", command_args.save_name])
    if command_args.save_dir is not None:
        train_args.extend(["--save_dir", command_args.save_dir])
    args = ner_tagger.parse_args(train_args)
    save_name = ner_tagger.model_file_name(args)
    return save_name
89
+
90
+
91
# Technically NER datasets are not necessarily treebanks
# (usually not, in fact)
# However, to keep the naming consistent, we leave the
# method which does the training as run_treebank
# TODO: rename treebank -> dataset everywhere
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or score an NER model for one dataset.

    Missing prepared data triggers one rebuild attempt via
    prepare_ner_dataset; if that fails a FileNotFoundError is raised.

    Args:
      mode: Mode.TRAIN, Mode.SCORE_DEV, or Mode.SCORE_TEST
      paths: dict of data directories; only NER_DATA_DIR is read here
      treebank: dataset name used for the data file names
      short_name: "{lang}_{dataset}" shorthand
      temp_output_file: unused by the NER tagger
      command_args: parsed run-script args (.charlm, transformer flags)
      extra_args: extra args forwarded to ner_tagger
    """
    ner_dir = paths["NER_DATA_DIR"]
    # maxsplit=1 so dataset names which themselves contain underscores
    # survive; this matches build_model_filename above and the other
    # run scripts (previously split("_") would crash on such names)
    language, dataset = short_name.split("_", 1)

    train_file = os.path.join(ner_dir, f"{treebank}.train.json")
    dev_file = os.path.join(ner_dir, f"{treebank}.dev.json")
    test_file = os.path.join(ner_dir, f"{treebank}.test.json")

    # if any files are missing, try to rebuild the dataset
    # if that still doesn't work, we have to throw an error
    missing_file = [x for x in (train_file, dev_file, test_file) if not os.path.exists(x)]
    if len(missing_file) > 0:
        logger.warning(f"The data for {treebank} is missing or incomplete. Cannot find {missing_file} Attempting to rebuild...")
        try:
            prepare_ner_dataset.main(treebank)
        except Exception as e:
            raise FileNotFoundError(f"An exception occurred while trying to build the data for {treebank} At least one portion of the data was missing: {missing_file} Please correctly build these files and then try again.") from e

    pretrain_args = build_pretrain_args(language, dataset, command_args.charlm, command_args, extra_args)

    if mode == Mode.TRAIN:
        # VI example arguments:
        # --wordvec_pretrain_file ~/stanza_resources/vi/pretrain/vtb.pt
        # --train_file data/ner/vi_vlsp.train.json
        # --eval_file data/ner/vi_vlsp.dev.json
        # --lang vi
        # --shorthand vi_vlsp
        # --mode train
        # --charlm --charlm_shorthand vi_conll17
        # --dropout 0.6 --word_dropout 0.1 --locked_dropout 0.1 --char_dropout 0.1
        dataset_args = DATASET_EXTRA_ARGS.get(short_name, [])

        train_args = ['--train_file', train_file,
                      '--eval_file', dev_file,
                      '--shorthand', short_name,
                      '--mode', 'train']
        train_args = train_args + pretrain_args + dataset_args + extra_args
        logger.info("Running train step with args: {}".format(train_args))
        ner_tagger.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ['--eval_file', dev_file,
                    '--shorthand', short_name,
                    '--mode', 'predict']
        dev_args = dev_args + pretrain_args + extra_args
        logger.info("Running dev step with args: {}".format(dev_args))
        ner_tagger.main(dev_args)

    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
        test_args = ['--eval_file', test_file,
                     '--shorthand', short_name,
                     '--mode', 'predict']
        test_args = test_args + pretrain_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        ner_tagger.main(test_args)
152
+
153
+
154
def main():
    """Entry point: run the NER train/score workflow through the shared harness."""
    common.main(run_treebank, "ner", "nertagger", add_ner_args, ner_tagger.build_argparse(), build_model_filename=build_model_filename)

if __name__ == "__main__":
    main()
159
+
stanza/stanza/utils/training/run_sentiment.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Trains or tests a sentiment model using the classifier package
3
+
4
+ The prep script has separate entries for the root-only version of SST,
5
+ which is what people typically use to test. When training a model for
6
+ SST which uses all the data, the root-only version is used for
7
+ dev and test
8
+ """
9
+
10
+ import logging
11
+ import os
12
+
13
+ from stanza.models import classifier
14
+ from stanza.utils.training import common
15
+ from stanza.utils.training.common import Mode, build_charlm_args, choose_charlm, find_wordvec_pretrain
16
+
17
+ from stanza.resources.default_packages import default_charlms, default_pretrains
18
+
19
+ logger = logging.getLogger('stanza')
20
+
21
# TODO: refactor with ner & conparse
def add_sentiment_args(parser):
    """Register the sentiment-specific command line flags on the given parser."""
    parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')
    # --no_charlm writes None into the same dest, overriding the default
    parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package")
    parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language')
27
+
28
# datasets whose dev & test files are taken from a different (root-only)
# dataset; see run_dataset below and the module docstring
ALTERNATE_DATASET = {
    "en_sst2": "en_sst2roots",
    "en_sstplus": "en_sst3roots",
}
32
+
33
def build_default_args(paths, short_language, dataset, command_args, extra_args):
    """Build the word vector / charlm / transformer args shared by every mode.

    Respects an explicit --wordvec_pretrain_file in extra_args; otherwise
    looks the pretrain up from the default resources.
    """
    if '--wordvec_pretrain_file' not in extra_args:
        # will throw an error if the pretrain can't be found
        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains)
        wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]
    else:
        wordvec_args = []

    charlm = choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {})
    charlm_args = build_charlm_args(short_language, charlm, base_args=False)

    bert_args = common.choose_transformer(short_language, command_args, extra_args)
    default_args = wordvec_args + charlm_args + bert_args

    return default_args
48
+
49
def build_model_filename(paths, short_name, command_args, extra_args):
    """Predict where the sentiment classifier for this dataset would be saved.

    Builds the same argument list training would use, then asks the
    classifier itself for the resulting model path.
    """
    short_language, dataset = short_name.split("_", 1)

    default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)

    train_args = ["--shorthand", short_name]
    train_args = train_args + default_args
    if command_args.save_name is not None:
        train_args.extend(["--save_name", command_args.save_name])
    if command_args.save_dir is not None:
        train_args.extend(["--save_dir", command_args.save_dir])
    args = classifier.parse_args(train_args + extra_args)
    save_name = classifier.build_model_filename(args)
    return save_name
63
+
64
+
65
def run_dataset(mode, paths, treebank, short_name,
                temp_output_file, command_args, extra_args):
    """Train and/or score a sentiment classifier for one dataset.

    For the datasets in ALTERNATE_DATASET, dev and test files come from the
    root-only variant of the dataset.  Raises FileNotFoundError if any of
    the prepared train/dev/test files is missing.

    Args:
      mode: Mode.TRAIN, Mode.SCORE_DEV, or Mode.SCORE_TEST
      paths: dict of data directories; only SENTIMENT_DATA_DIR is read here
      treebank: dataset name (unused here aside from the harness contract)
      short_name: "{lang}_{dataset}" shorthand
      temp_output_file: unused by the classifier
      command_args: parsed run-script args (.charlm, transformer flags)
      extra_args: extra args forwarded to the classifier
    """
    sentiment_dir = paths["SENTIMENT_DATA_DIR"]
    short_language, dataset = short_name.split("_", 1)

    train_file = os.path.join(sentiment_dir, f"{short_name}.train.json")

    other_name = ALTERNATE_DATASET.get(short_name, short_name)
    dev_file = os.path.join(sentiment_dir, f"{other_name}.dev.json")
    test_file = os.path.join(sentiment_dir, f"{other_name}.test.json")

    for filename in (train_file, dev_file, test_file):
        if not os.path.exists(filename):
            raise FileNotFoundError("Cannot find %s" % filename)

    default_args = build_default_args(paths, short_language, dataset, command_args, extra_args)

    if mode == Mode.TRAIN:
        train_args = ['--train_file', train_file,
                      '--dev_file', dev_file,
                      '--test_file', test_file,
                      '--shorthand', short_name,
                      '--wordvec_type', 'word2vec', # TODO: chinese is fasttext
                      '--extra_wordvec_method', 'SUM']
        train_args = train_args + default_args + extra_args
        logger.info("Running train step with args: {}".format(train_args))
        classifier.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        # --no_train: load the saved model and just score the dev set
        dev_args = ['--no_train',
                    '--test_file', dev_file,
                    '--shorthand', short_name,
                    '--wordvec_type', 'word2vec'] # TODO: chinese is fasttext
        dev_args = dev_args + default_args + extra_args
        logger.info("Running dev step with args: {}".format(dev_args))
        classifier.main(dev_args)

    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
        test_args = ['--no_train',
                     '--test_file', test_file,
                     '--shorthand', short_name,
                     '--wordvec_type', 'word2vec'] # TODO: chinese is fasttext
        test_args = test_args + default_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        classifier.main(test_args)
110
+
111
+
112
+
113
def main():
    """Entry point: run the sentiment classifier workflow through the shared harness."""
    common.main(run_dataset, "classifier", "classifier", add_sentiment_args, classifier.build_argparse(), build_model_filename=build_model_filename)

if __name__ == "__main__":
    main()
118
+
stanza/stanza/utils/training/run_tokenizer.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script allows for training or testing on dev / test of the UD tokenizer.
3
+
4
+ If run with a single treebank name, it will train or test that treebank.
5
+ If run with ud_all or all_ud, it will iterate over all UD treebanks it can find.
6
+
7
+ Mode can be set to train&dev with --train, to dev set only
8
+ with --score_dev, and to test set only with --score_test.
9
+
10
+ Treebanks are specified as a list. all_ud or ud_all means to look for
11
+ all UD treebanks.
12
+
13
+ Extra arguments are passed to tokenizer. In case the run script
14
+ itself is shadowing arguments, you can specify --extra_args as a
15
+ parameter to mark where the tokenizer arguments start.
16
+
17
+ Default behavior is to discard the output and just print the results.
18
+ To keep the results instead, use --save_output
19
+ """
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+
25
+ from stanza.models import tokenizer
26
+ from stanza.utils.avg_sent_len import avg_sent_len
27
+ from stanza.utils.training import common
28
+ from stanza.utils.training.common import Mode
29
+
30
+ logger = logging.getLogger('stanza')
31
+
32
def uses_dictionary(short_language):
    """Report whether this language's tokenizer should use an external dictionary.

    We found this helped the overall tokenizer performance.
    If these can't be found, they can be extracted from the previous iteration of models.
    """
    dictionary_languages = frozenset(('ja', 'th', 'zh', 'zh-hans', 'zh-hant'))
    return short_language in dictionary_languages
42
+
43
def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or evaluate the tokenizer for a single UD treebank.

    mode selects what runs: TRAIN trains, then also scores dev and test;
    SCORE_DEV / SCORE_TEST / SCORE_TRAIN only predict and evaluate that split.
    Predictions go to temp_output_file if given, otherwise to
    {short_name}.{split}.pred.conllu in the tokenize data dir.
    extra_args are passed through to the underlying tokenizer.
    """
    tokenize_dir = paths["TOKENIZE_DATA_DIR"]

    # short_name looks like lang_dataset, e.g. en_ewt
    short_language = short_name.split("_")[0]
    label_type = "--label_file"
    label_file = f"{tokenize_dir}/{short_name}-ud-train.toklabels"
    dev_type = "--txt_file"
    dev_file = f"{tokenize_dir}/{short_name}.dev.txt"
    test_type = "--txt_file"
    test_file = f"{tokenize_dir}/{short_name}.test.txt"
    train_type = "--txt_file"
    train_file = f"{tokenize_dir}/{short_name}.train.txt"
    train_dev_args = ["--dev_txt_file", dev_file, "--dev_label_file", f"{tokenize_dir}/{short_name}-ud-dev.toklabels"]

    # Chinese text keeps newlines as part of the raw data
    if short_language == "zh" or short_language.startswith("zh-"):
        extra_args = ["--skip_newline"] + extra_args

    train_gold = f"{tokenize_dir}/{short_name}.train.gold.conllu"
    dev_gold = f"{tokenize_dir}/{short_name}.dev.gold.conllu"
    test_gold = f"{tokenize_dir}/{short_name}.test.gold.conllu"

    train_mwt = f"{tokenize_dir}/{short_name}-ud-train-mwt.json"
    dev_mwt = f"{tokenize_dir}/{short_name}-ud-dev-mwt.json"
    test_mwt = f"{tokenize_dir}/{short_name}-ud-test-mwt.json"

    # predictions are discarded (temp file) unless the caller asked to keep them
    train_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.train.pred.conllu"
    dev_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.dev.pred.conllu"
    test_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.test.pred.conllu"

    if mode == Mode.TRAIN:
        # max sequence length: roughly 3x the average sentence length,
        # rounded up to the nearest 100
        seqlen = str(math.ceil(avg_sent_len(label_file) * 3 / 100) * 100)
        train_args = ([label_type, label_file, train_type, train_file, "--lang", short_language,
                       "--max_seqlen", seqlen, "--mwt_json_file", dev_mwt] +
                      train_dev_args +
                      ["--dev_conll_gold", dev_gold, "--conll_file", dev_pred, "--shorthand", short_name])
        if uses_dictionary(short_language):
            train_args = train_args + ["--use_dictionary"]
        train_args = train_args + extra_args
        logger.info("Running train step with args: {}".format(train_args))
        tokenizer.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ["--mode", "predict", dev_type, dev_file, "--lang", short_language,
                    "--conll_file", dev_pred, "--shorthand", short_name, "--mwt_json_file", dev_mwt]
        dev_args = dev_args + extra_args
        logger.info("Running dev step with args: {}".format(dev_args))
        tokenizer.main(dev_args)

        # TODO: log these results? The original script logged them to
        # echo $results $args >> ${TOKENIZE_DATA_DIR}/${short}.results

        results = common.run_eval_script_tokens(dev_gold, dev_pred)
        logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))

    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
        test_args = ["--mode", "predict", test_type, test_file, "--lang", short_language,
                     "--conll_file", test_pred, "--shorthand", short_name, "--mwt_json_file", test_mwt]
        test_args = test_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        tokenizer.main(test_args)

        results = common.run_eval_script_tokens(test_gold, test_pred)
        logger.info("Finished running test set on\n{}\n{}".format(treebank, results))

    if mode == Mode.SCORE_TRAIN:
        # evaluate the tokenizer on its own training text as a sanity check
        test_args = ["--mode", "predict", test_type, train_file, "--lang", short_language,
                     "--conll_file", train_pred, "--shorthand", short_name, "--mwt_json_file", train_mwt]
        test_args = test_args + extra_args
        logger.info("Running test step with args: {}".format(test_args))
        tokenizer.main(test_args)

        results = common.run_eval_script_tokens(train_gold, train_pred)
        logger.info("Finished running train set as a test on\n{}\n{}".format(treebank, results))
117
+
118
+
119
+
120
def main():
    """Entry point: run the tokenizer over the requested treebanks via the shared common.main driver."""
    common.main(run_treebank, "tokenize", "tokenizer", sub_argparse=tokenizer.build_argparse())

if __name__ == "__main__":
    main()
stanza/stanza/utils/training/separate_ner_pretrain.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loads NER models & separates out the word vectors to base & delta
3
+
4
+ The model will then be resaved without the base word vector,
5
+ greatly reducing the size of the model
6
+
7
+ This may be useful for any external users of stanza who have an NER
8
+ model they wish to reuse without retraining
9
+
10
+ If you know which pretrain was used to build an NER model, you can
11
+ provide that pretrain. Otherwise, you can give a directory of
12
+ pretrains and the script will test each one. In the latter case,
13
+ the name of the pretrain needs to look like lang_dataset_pretrain.pt
14
+ """
15
+
16
+ import argparse
17
+ from collections import defaultdict
18
+ import logging
19
+ import os
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from stanza import Pipeline
26
+ from stanza.models.common.constant import lang_to_langcode
27
+ from stanza.models.common.pretrain import Pretrain, PretrainedWordVocab
28
+ from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX
29
+ from stanza.models.ner.trainer import Trainer
30
+
31
+ logger = logging.getLogger('stanza')
32
+ logger.setLevel(logging.ERROR)
33
+
34
+ DEBUG = False
35
+ EPS = 0.0001
36
+
37
def main():
    """Shrink NER models by splitting their embedding into base pretrain + delta.

    For each NER model found under --input_path, candidate pretrains under
    --pretrain_path are tried until one matches (most embedding rows equal).
    The base embedding is then marked unsaved, and only the finetuned delta
    vectors are stored, greatly reducing the size of the saved model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default='saved_models/ner', help='Where to find NER models (dir or filename)')
    parser.add_argument('--output_path', type=str, default='saved_models/shrunk', help='Where to write shrunk NER models (dir)')
    parser.add_argument('--pretrain_path', type=str, default='saved_models/pretrain', help='Where to find pretrains (dir or filename)')
    args = parser.parse_args()

    # get list of NER models to shrink
    if os.path.isdir(args.input_path):
        ner_model_dir = args.input_path
        ners = os.listdir(ner_model_dir)
        if len(ners) == 0:
            raise FileNotFoundError("No ner models found in {}".format(args.input_path))
    else:
        if not os.path.isfile(args.input_path):
            raise FileNotFoundError("No ner model found at path {}".format(args.input_path))
        ner_model_dir, ners = os.path.split(args.input_path)
        ners = [ners]

    # get map from language to candidate pretrains
    if os.path.isdir(args.pretrain_path):
        pt_model_dir = args.pretrain_path
        pretrains = os.listdir(pt_model_dir)
        lang_to_pretrain = defaultdict(list)
        for pt in pretrains:
            lang_to_pretrain[pt.split("_")[0]].append(pt)
    else:
        # BUGFIX: previously split pt_model_dir, which is unbound in this
        # branch (NameError); the single-file path comes from the argument
        pt_model_dir, pretrains = os.path.split(args.pretrain_path)
        pretrains = [pretrains]
        lang_to_pretrain = defaultdict(lambda: pretrains)

    # shrunk models will all go in this directory
    new_dir = args.output_path
    os.makedirs(new_dir, exist_ok=True)

    final_pretrains = []       # (ner model, matching pretrain) pairs
    missing_pretrains = []     # models where no candidate pretrain matched
    no_finetune = []           # models whose embedding was never finetuned

    # for each model, go through the various pretrains
    # until we find one that works or none of them work
    for ner_model in ners:
        ner_path = os.path.join(ner_model_dir, ner_model)

        expected_ending = "_nertagger.pt"
        if not ner_model.endswith(expected_ending):
            raise ValueError("Unexpected name: {}".format(ner_model))
        short_name = ner_model[:-len(expected_ending)]
        lang, package = short_name.split("_", maxsplit=1)
        print("===============================================")
        print("Processing lang %s package %s" % (lang, package))

        # this may look funny - basically, the pipeline has machinery
        # to make sure the model has everything it needs to load,
        # including downloading other pieces if needed
        pipe = Pipeline(lang, processors="tokenize,ner", tokenize_pretokenized=True, package={"ner": package}, ner_model_path=ner_path)
        ner_processor = pipe.processors['ner']
        print("Loaded NER processor: {}".format(ner_processor))
        trainer = ner_processor.trainers[0]
        vocab = trainer.model.vocab
        word_vocab = vocab['word']
        num_vectors = trainer.model.word_emb.weight.shape[0]

        # sanity check, make sure the model loaded matches the
        # language from the model's filename
        lcode = lang_to_langcode(trainer.args['lang'])
        if lang != lcode and not (lcode == 'zh' and lang == 'zh-hans'):
            raise ValueError("lang not as expected: {} vs {} ({})".format(lang, trainer.args['lang'], lcode))

        ner_pretrains = sorted(set(lang_to_pretrain[lang] + lang_to_pretrain[lcode]))
        for pt_model in ner_pretrains:
            pt_path = os.path.join(pt_model_dir, pt_model)
            print("Attempting pretrain: {}".format(pt_path))
            pt = Pretrain(filename=pt_path)
            print(" pretrain shape: {}".format(pt.emb.shape))
            print(" embedding in ner model shape: {}".format(trainer.model.word_emb.weight.shape))
            if pt.emb.shape[1] != trainer.model.word_emb.weight.shape[1]:
                print(" DIMENSION DOES NOT MATCH. SKIPPING")
                continue
            N = min(pt.emb.shape[0], trainer.model.word_emb.weight.shape[0])
            if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]:
                # If the vocab was exactly the same, that's a good
                # sign this pretrain was used, just with a different size
                # In such a case, we can reuse the rest of the pretrain
                # Minor issue: some vectors which were trained will be
                # lost in the case of |pt| < |model.word_emb|
                # BUGFIX: previously compared word_vocab.id2unit(x) against
                # itself (always True); compare model vocab to pretrain vocab
                if all(word_vocab.id2unit(x) == pt.vocab.id2unit(x) for x in range(N)):
                    print(" Attempting to use pt vectors to replace ner model's vectors")
                else:
                    print(" NUM VECTORS DO NOT MATCH. WORDS DO NOT MATCH. SKIPPING")
                    continue
            if pt.emb.shape[0] < trainer.model.word_emb.weight.shape[0]:
                print(" WARNING: if any vectors beyond {} were fine tuned, that fine tuning will be lost".format(N))
            device = next(trainer.model.parameters()).device
            # delta = how far each row of the model embedding drifted from the pretrain
            delta = trainer.model.word_emb.weight[:N, :] - pt.emb.to(device)[:N, :]
            delta = delta.detach()
            delta_norms = torch.linalg.norm(delta, dim=1).cpu().numpy()
            if np.sum(delta_norms < 0) > 0:
                raise ValueError("This should not be - a norm was less than 0!")
            num_matching = np.sum(delta_norms < EPS)
            if num_matching > N / 2:
                # majority of rows unchanged -> this pretrain was the base
                print(" Accepted! %d of %d vectors match for %s" % (num_matching, N, pt_path))
                if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]:
                    print(" Setting model vocab to match the pretrain")
                    word_vocab = pt.vocab
                    vocab['word'] = word_vocab
                    trainer.args['word_emb_dim'] = pt.emb.shape[1]
                break
            else:
                print(" %d of %d vectors matched for %s - SKIPPING" % (num_matching, N, pt_path))
                vocab_same = sum(x in pt.vocab for x in word_vocab)
                print(" %d words were in both vocabs" % vocab_same)
                # this is expensive, and in practice doesn't happen,
                # but theoretically we might have missed a mostly matching pt
                # if the vocab had been scrambled
                if DEBUG:
                    rearranged_count = 0
                    for x in word_vocab:
                        if x not in pt.vocab:
                            continue
                        x_id = word_vocab.unit2id(x)
                        x_vec = trainer.model.word_emb.weight[x_id, :]
                        pt_id = pt.vocab.unit2id(x)
                        pt_vec = pt.emb[pt_id, :]
                        if (x_vec.detach().cpu() - pt_vec).norm() < EPS:
                            rearranged_count += 1
                    print(" %d vectors were close when ignoring id ordering" % rearranged_count)
        else:
            # for/else: no pretrain matched (or none were available)
            print("COULD NOT FIND A MATCHING PT: {}".format(ner_processor))
            missing_pretrains.append(ner_model)
            continue

        # build a delta vector & embedding
        # the first len(VOCAB_PREFIX) rows (special tokens) are always kept
        assert 'delta' not in vocab.keys()
        delta_vectors = [delta[i].cpu() for i in range(4)]
        delta_vocab = []
        for i in range(4, len(delta_norms)):
            if delta_norms[i] > 0.0:
                delta_vocab.append(word_vocab.id2unit(i))
                delta_vectors.append(delta[i].cpu())

        # base embedding will no longer be written to disk with the model
        trainer.model.unsaved_modules.append("word_emb")
        if len(delta_vocab) == 0:
            print("No vectors were changed! Perhaps this model was trained without finetune.")
            no_finetune.append(ner_model)
        else:
            print("%d delta vocab" % len(delta_vocab))
            print("%d vectors in the delta set" % len(delta_vectors))
            delta_vectors = np.stack(delta_vectors)
            delta_vectors = torch.from_numpy(delta_vectors)
            assert delta_vectors.shape[0] == len(delta_vocab) + len(VOCAB_PREFIX)
            print(delta_vectors.shape)

            delta_vocab = PretrainedWordVocab(delta_vocab, lang=word_vocab.lang, lower=word_vocab.lower)
            vocab['delta'] = delta_vocab
            trainer.model.delta_emb = nn.Embedding(delta_vectors.shape[0], delta_vectors.shape[1], PAD_ID)
            trainer.model.delta_emb.weight.data.copy_(delta_vectors)

        new_path = os.path.join(new_dir, ner_model)
        trainer.save(new_path)

        final_pretrains.append((ner_model, pt_model))

    # summary report
    print()
    if len(final_pretrains) > 0:
        print("Final pretrain mappings:")
        for i in final_pretrains:
            print(i)
    if len(missing_pretrains) > 0:
        print("MISSING EMBEDDINGS:")
        for i in missing_pretrains:
            print(i)
    if len(no_finetune) > 0:
        print("NOT FINE TUNED:")
        for i in no_finetune:
            print(i)

if __name__ == '__main__':
    main()
stanza/stanza/utils/visualization/__init__.py ADDED
File without changes
stanza/stanza/utils/visualization/conll_deprel_visualization.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stanza.models.common.constant import is_right_to_left
2
+ import spacy
3
+ import argparse
4
+ from spacy import displacy
5
+ from spacy.tokens import Doc
6
+ from stanza.utils import conll
7
+ from stanza.utils.visualization import dependency_visualization as viz
8
+
9
+
10
def conll_to_visual(conll_file, pipeline, sent_count=10, display_all=False):
    """
    Takes in a conll file and visualizes it by converting the conll file to a Stanza Document object
    and visualizing it with the visualize_doc method.

    Input should be a proper conll file.

    The pipeline for the conll file to be processed in must be provided as well.

    Optionally, the sent_count argument can be tweaked to display a different amount of sentences.

    To display all of the sentences in a conll file, the display_all argument can optionally be set to True.
    BEWARE: setting this argument for a large conll file may result in too many renderings, resulting in a crash.
    """
    # convert conll file to doc
    doc = conll.CoNLL.conll2doc(conll_file)

    if display_all:
        # reuse the already-parsed doc rather than re-reading the file
        # (previously the file was parsed a second time here)
        viz.visualize_doc(doc, pipeline)
        return

    # visualize a given number of sentences
    visualization_options = {"compact": True, "bg": "#09a3d5", "color": "white", "distance": 100,
                             "font": "Source Sans Pro", "offset_x": 30,
                             "arrow_spacing": 20}  # see spaCy visualization settings doc for more options
    nlp = spacy.blank("en")
    rtl = is_right_to_left(pipeline)
    sentences_to_visualize = []

    # slicing handles the case where there are fewer sentences than requested
    for sentence in doc.sentences[:sent_count]:
        words, lemmas, heads, deps, tags = [], [], [], [], []
        sentence_words = sentence.words
        if rtl:  # rtl languages will be visually rendered from right to left as well
            sentence_words = reversed(sentence.words)
        sent_len = len(sentence.words)
        for word in sentence_words:
            words.append(word.text)
            lemmas.append(word.lemma)
            deps.append(word.deprel)
            tags.append(word.upos)
            # word heads are off-by-1 in spaCy doc inits compared to Stanza;
            # a root (head == 0) points at itself in spaCy
            if rtl:
                heads.append(sent_len - word.id if word.head == 0 else sent_len - word.head)
            else:
                heads.append(word.id - 1 if word.head == 0 else word.head - 1)

        document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags)
        sentences_to_visualize.append(document_result)

    # render all sentences through displaCy
    # (removed a leftover debug print of the Doc list)
    for line in sentences_to_visualize:
        displacy.render(line, style="dep", options=visualization_options)
65
+
66
+
67
+ def main():
68
+ parser = argparse.ArgumentParser()
69
+ parser.add_argument('--conll_file', type=str,
70
+ default="C:\\Users\\Alex\\stanza\\demo\\en_test.conllu.txt",
71
+ help="File path of the CoNLL file to visualize dependencies of")
72
+ parser.add_argument('--pipeline', type=str, default="en",
73
+ help="Language code of the language pipeline to use (ex: 'en' for English)")
74
+ parser.add_argument('--sent_count', type=int, default=10, help="Number of sentences to visualize from CoNLL file")
75
+ parser.add_argument('--display_all', type=bool, default=False,
76
+ help="Whether or not to visualize all of the sentences from the file. Overrides sent_count if set to True")
77
+ args = parser.parse_args()
78
+ conll_to_visual(args.conll_file, args.pipeline, args.sent_count, args.display_all)
79
+ return
80
+
81
+
82
+ if __name__ == "__main__":
83
+ main()