diff --git a/stanza/stanza/tests/classifiers/test_data.py b/stanza/stanza/tests/classifiers/test_data.py new file mode 100644 index 0000000000000000000000000000000000000000..e74f2a8c3faa7ecb70fec846f52c482c288d3ce6 --- /dev/null +++ b/stanza/stanza/tests/classifiers/test_data.py @@ -0,0 +1,130 @@ +import json +import pytest + +import stanza.models.classifiers.data as data +from stanza.models.classifiers.utils import WVType +from stanza.models.common.vocab import PAD, UNK +from stanza.models.constituency.parse_tree import Tree + +SENTENCES = [ + ["I", "hate", "the", "Opal", "banning"], + ["Tell", "my", "wife", "hello"], # obviously this is the neutral result + ["I", "like", "Sh'reyan", "'s", "antennae"], +] + +DATASET = [ + {"sentiment": "0", "text": SENTENCES[0]}, + {"sentiment": "1", "text": SENTENCES[1]}, + {"sentiment": "2", "text": SENTENCES[2]}, +] + +TREES = [ + "(ROOT (S (NP (PRP I)) (VP (VBP hate) (NP (DT the) (NN Opal) (NN banning)))))", + "(ROOT (S (VP (VB Tell) (NP (PRP$ my) (NN wife)) (NP (UH hello)))))", + "(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))", +] + +DATASET_WITH_TREES = [ + {"sentiment": "0", "text": SENTENCES[0], "constituency": TREES[0]}, + {"sentiment": "1", "text": SENTENCES[1], "constituency": TREES[1]}, + {"sentiment": "2", "text": SENTENCES[2], "constituency": TREES[2]}, +] + +@pytest.fixture(scope="module") +def train_file(tmp_path_factory): + train_set = DATASET * 20 + train_filename = tmp_path_factory.mktemp("data") / "train.json" + with open(train_filename, "w", encoding="utf-8") as fout: + json.dump(train_set, fout, ensure_ascii=False) + return train_filename + +@pytest.fixture(scope="module") +def dev_file(tmp_path_factory): + dev_set = DATASET * 2 + dev_filename = tmp_path_factory.mktemp("data") / "dev.json" + with open(dev_filename, "w", encoding="utf-8") as fout: + json.dump(dev_set, fout, ensure_ascii=False) + return dev_filename + +@pytest.fixture(scope="module") +def 
test_file(tmp_path_factory): + test_set = DATASET + test_filename = tmp_path_factory.mktemp("data") / "test.json" + with open(test_filename, "w", encoding="utf-8") as fout: + json.dump(test_set, fout, ensure_ascii=False) + return test_filename + +@pytest.fixture(scope="module") +def train_file_with_trees(tmp_path_factory): + train_set = DATASET_WITH_TREES * 20 + train_filename = tmp_path_factory.mktemp("data") / "train_trees.json" + with open(train_filename, "w", encoding="utf-8") as fout: + json.dump(train_set, fout, ensure_ascii=False) + return train_filename + +@pytest.fixture(scope="module") +def dev_file_with_trees(tmp_path_factory): + dev_set = DATASET_WITH_TREES * 2 + dev_filename = tmp_path_factory.mktemp("data") / "dev_trees.json" + with open(dev_filename, "w", encoding="utf-8") as fout: + json.dump(dev_set, fout, ensure_ascii=False) + return dev_filename + +class TestClassifierData: + def test_read_data(self, train_file): + """ + Test reading of the json format + """ + train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) + assert len(train_set) == 60 + + def test_read_data_with_trees(self, train_file, train_file_with_trees): + """ + Test reading of the json format + """ + train_trees_set = data.read_dataset(str(train_file_with_trees), WVType.OTHER, 1) + assert len(train_trees_set) == 60 + for idx, x in enumerate(train_trees_set): + assert isinstance(x.constituency, Tree) + assert str(x.constituency) == TREES[idx % len(TREES)] + + train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) + + def test_dataset_vocab(self, train_file): + """ + Converting a dataset to vocab should have a specific set of words along with PAD and UNK + """ + train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) + vocab = data.dataset_vocab(train_set) + expected = set([PAD, UNK] + [x.lower() for y in SENTENCES for x in y]) + assert set(vocab) == expected + + def test_dataset_labels(self, train_file): + """ + Test the extraction of labels from a dataset 
+ """ + train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) + labels = data.dataset_labels(train_set) + assert labels == ["0", "1", "2"] + + def test_sort_by_length(self, train_file): + """ + There are two unique lengths in the toy dataset + """ + train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) + sorted_dataset = data.sort_dataset_by_len(train_set) + assert list(sorted_dataset.keys()) == [4, 5] + assert len(sorted_dataset[4]) == len(train_set) // 3 + assert len(sorted_dataset[5]) == 2 * len(train_set) // 3 + + def test_check_labels(self, train_file): + """ + Check that an exception is thrown for an unknown label + """ + train_set = data.read_dataset(str(train_file), WVType.OTHER, 1) + labels = sorted(set([x["sentiment"] for x in DATASET])) + assert len(labels) > 1 + data.check_labels(labels, train_set) + with pytest.raises(RuntimeError): + data.check_labels(labels[:1], train_set) + diff --git a/stanza/stanza/tests/constituency/test_tree_stack.py b/stanza/stanza/tests/constituency/test_tree_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..e7859a3b4f887ee596ca9ef4138296fab358cda8 --- /dev/null +++ b/stanza/stanza/tests/constituency/test_tree_stack.py @@ -0,0 +1,50 @@ +import pytest + +from stanza.models.constituency.tree_stack import TreeStack + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def test_simple(): + stack = TreeStack(value=5, parent=None, length=1) + stack = stack.push(3) + stack = stack.push(1) + + expected_values = [1, 3, 5] + for value in expected_values: + assert stack.value == value + stack = stack.pop() + assert stack is None + +def test_iter(): + stack = TreeStack(value=5, parent=None, length=1) + stack = stack.push(3) + stack = stack.push(1) + + stack_list = list(stack) + assert list(stack) == [1, 3, 5] + +def test_str(): + stack = TreeStack(value=5, parent=None, length=1) + stack = stack.push(3) + stack = stack.push(1) + + assert str(stack) == 
"TreeStack(1, 3, 5)" + +def test_len(): + stack = TreeStack(value=5, parent=None, length=1) + assert len(stack) == 1 + + stack = stack.push(3) + stack = stack.push(1) + assert len(stack) == 3 + +def test_long_len(): + """ + Original stack had a bug where this took exponential time... + """ + stack = TreeStack(value=0, parent=None, length=1) + for i in range(1, 40): + stack = stack.push(i) + assert len(stack) == 40 diff --git a/stanza/stanza/tests/data/external_server.properties b/stanza/stanza/tests/data/external_server.properties new file mode 100644 index 0000000000000000000000000000000000000000..408853472c0370eb5849ef3249280547315db481 --- /dev/null +++ b/stanza/stanza/tests/data/external_server.properties @@ -0,0 +1 @@ +annotators = tokenize,ssplit,pos diff --git a/stanza/stanza/tests/lemma/test_lowercase.py b/stanza/stanza/tests/lemma/test_lowercase.py new file mode 100644 index 0000000000000000000000000000000000000000..6692dafa6c4888d2b303843519417a230dc6f588 --- /dev/null +++ b/stanza/stanza/tests/lemma/test_lowercase.py @@ -0,0 +1,57 @@ +import pytest + +from stanza.models.lemmatizer import all_lowercase +from stanza.utils.conll import CoNLL + +LATIN_CONLLU = """ +# sent_id = train-s1 +# text = unde et philosophus dicit felicitatem esse operationem perfectam. 
+# reference = ittb-scg-s4203 +1 unde unde ADV O4 AdvType=Loc|PronType=Rel 4 advmod:lmod _ _ +2 et et CCONJ O4 _ 3 advmod:emph _ _ +3 philosophus philosophus NOUN B1|grn1|casA|gen1 Case=Nom|Gender=Masc|InflClass=IndEurO|Number=Sing 4 nsubj _ _ +4 dicit dico VERB N3|modA|tem1|gen6 Aspect=Imp|InflClass=LatX|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens +5 felicitatem felicitas NOUN C1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 7 nsubj _ _ +6 esse sum AUX N3|modH|tem1 Aspect=Imp|Tense=Pres|VerbForm=Inf 7 cop _ _ +7 operationem operatio NOUN C1|grn1|casD|gen2|vgr1 Case=Acc|Gender=Fem|InflClass=IndEurX|Number=Sing 4 ccomp _ _ +8 perfectam perfectus ADJ A1|grn1|casD|gen2 Case=Acc|Gender=Fem|InflClass=IndEurA|Number=Sing 7 amod _ SpaceAfter=No +9 . . PUNCT Punc _ 4 punct _ _ + +# sent_id = train-s2 +# text = perfectio autem operationis dependet ex quatuor. +# reference = ittb-scg-s4204 +1 perfectio perfectio NOUN C1|grn1|casA|gen2 Case=Nom|Gender=Fem|InflClass=IndEurX|Number=Sing 4 nsubj _ _ +2 autem autem PART O4 _ 4 discourse _ _ +3 operationis operatio NOUN C1|grn1|casB|gen2|vgr1 Case=Gen|Gender=Fem|InflClass=IndEurX|Number=Sing 1 nmod _ _ +4 dependet dependeo VERB K3|modA|tem1|gen6 Aspect=Imp|InflClass=LatE|Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ TraditionalMood=Indicativus|TraditionalTense=Praesens +5 ex ex ADP S4|vgr2 _ 6 case _ _ +6 quatuor quattuor NUM G1|gen3|vgr1 NumForm=Word|NumType=Card 4 obl:arg _ SpaceAfter=No +7 . . PUNCT Punc _ 4 punct _ _ +""".lstrip() + +ENG_CONLLU = """ +# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0007 +# text = You wonder if he was manipulating the market with his bombing targets. 
+1 You you PRON PRP Case=Nom|Person=2|PronType=Prs 2 nsubj 2:nsubj _ +2 wonder wonder VERB VBP Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin 0 root 0:root _ +3 if if SCONJ IN _ 6 mark 6:mark _ +4 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 6 nsubj 6:nsubj _ +5 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _ +6 manipulating manipulate VERB VBG Tense=Pres|VerbForm=Part 2 ccomp 2:ccomp _ +7 the the DET DT Definite=Def|PronType=Art 8 det 8:det _ +8 market market NOUN NN Number=Sing 6 obj 6:obj _ +9 with with ADP IN _ 12 case 12:case _ +10 his his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 12 nmod:poss 12:nmod:poss _ +11 bombing bombing NOUN NN Number=Sing 12 compound 12:compound _ +12 targets target NOUN NNS Number=Plur 6 obl 6:obl:with SpaceAfter=No +13 . . PUNCT . _ 2 punct 2:punct _ +""".lstrip() + + +def test_all_lowercase(): + doc = CoNLL.conll2doc(input_str=LATIN_CONLLU) + assert all_lowercase(doc) + +def test_not_all_lowercase(): + doc = CoNLL.conll2doc(input_str=ENG_CONLLU) + assert not all_lowercase(doc) diff --git a/stanza/stanza/tests/ner/test_bsf_2_beios.py b/stanza/stanza/tests/ner/test_bsf_2_beios.py new file mode 100644 index 0000000000000000000000000000000000000000..16162b59fff421bd114690cca615e3d177972965 --- /dev/null +++ b/stanza/stanza/tests/ner/test_bsf_2_beios.py @@ -0,0 +1,349 @@ +""" +Tests the conversion code for the lang_uk NER dataset +""" + +import unittest +from stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo + +import pytest +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +class TestBsf2Beios(unittest.TestCase): + + def test_empty_markup(self): + res = convert_bsf('', '') + self.assertEqual('', res) + + def test_1line_markup(self): + data = 'тележурналіст Василь' + bsf_markup = 'T1 PERS 14 20 Василь' + expected = '''тележурналіст O +Василь S-PERS''' + self.assertEqual(expected, 
convert_bsf(data, bsf_markup)) + + def test_1line_follow_markup(self): + data = 'тележурналіст Василь .' + bsf_markup = 'T1 PERS 14 20 Василь' + expected = '''тележурналіст O +Василь S-PERS +. O''' + self.assertEqual(expected, convert_bsf(data, bsf_markup)) + + def test_1line_2tok_markup(self): + data = 'тележурналіст Василь Нагірний .' + bsf_markup = 'T1 PERS 14 29 Василь Нагірний' + expected = '''тележурналіст O +Василь B-PERS +Нагірний E-PERS +. O''' + self.assertEqual(expected, convert_bsf(data, bsf_markup)) + + def test_1line_Long_tok_markup(self): + data = 'А в музеї Гуцульщини і Покуття можна ' + bsf_markup = 'T12 ORG 4 30 музеї Гуцульщини і Покуття' + expected = '''А O +в O +музеї B-ORG +Гуцульщини I-ORG +і I-ORG +Покуття E-ORG +можна O''' + self.assertEqual(expected, convert_bsf(data, bsf_markup)) + + def test_2line_2tok_markup(self): + data = '''тележурналіст Василь Нагірний . +В івано-франківському видавництві «Лілея НВ» вийшла друком''' + bsf_markup = '''T1 PERS 14 29 Василь Нагірний +T2 ORG 67 75 Лілея НВ''' + expected = '''тележурналіст O +Василь B-PERS +Нагірний E-PERS +. O + + +В O +івано-франківському O +видавництві O +« O +Лілея B-ORG +НВ E-ORG +» O +вийшла O +друком O''' + self.assertEqual(expected, convert_bsf(data, bsf_markup)) + + def test_real_markup(self): + data = '''Через напіввоєнний стан в Україні та збільшення телефонних терористичних погроз українці купуватимуть sim-карти тільки за паспортами . +Про це повідомив начальник управління зв'язків зі ЗМІ адміністрації Держспецзв'язку Віталій Кукса . +Він зауважив , що днями відомство опублікує проект змін до правил надання телекомунікаційних послуг , де будуть прописані норми ідентифікації громадян . +Абонентів , які на сьогодні вже мають sim-карту , за словами Віталія Кукси , реєструватимуть , коли ті звертатимуться в службу підтримки свого оператора мобільного зв'язку . 
+Однак мобільні оператори побоюються , що таке нововведення помітно зменшить продаж стартових пакетів , адже спеціалізовані магазини є лише у містах . +Відтак купити сімку в невеликих населених пунктах буде неможливо . +Крім того , нова процедура ідентифікації абонентів вимагатиме від операторів мобільного зв'язку додаткових витрат . +- Близько 90 % українських абонентів - це абоненти передоплати . +Якщо мова буде йти навіть про поетапну їх ідентифікацію , зробити це буде складно , довго і дорого . +Мобільним операторам доведеться йти на чималі витрати , пов'язані з укладанням і зберіганням договорів , веденням баз даних , - розповіла « Економічній правді » начальник відділу зв'язків з громадськістю « МТС-Україна » Вікторія Рубан . +''' + bsf_markup = '''T1 LOC 26 33 Україні +T2 ORG 203 218 Держспецзв'язку +T3 PERS 219 232 Віталій Кукса +T4 PERS 449 462 Віталія Кукси +T5 ORG 1201 1219 Економічній правді +T6 ORG 1267 1278 МТС-Україна +T7 PERS 1281 1295 Вікторія Рубан +''' + expected = '''Через O +напіввоєнний O +стан O +в O +Україні S-LOC +та O +збільшення O +телефонних O +терористичних O +погроз O +українці O +купуватимуть O +sim-карти O +тільки O +за O +паспортами O +. O + + +Про O +це O +повідомив O +начальник O +управління O +зв'язків O +зі O +ЗМІ O +адміністрації O +Держспецзв'язку S-ORG +Віталій B-PERS +Кукса E-PERS +. O + + +Він O +зауважив O +, O +що O +днями O +відомство O +опублікує O +проект O +змін O +до O +правил O +надання O +телекомунікаційних O +послуг O +, O +де O +будуть O +прописані O +норми O +ідентифікації O +громадян O +. O + + +Абонентів O +, O +які O +на O +сьогодні O +вже O +мають O +sim-карту O +, O +за O +словами O +Віталія B-PERS +Кукси E-PERS +, O +реєструватимуть O +, O +коли O +ті O +звертатимуться O +в O +службу O +підтримки O +свого O +оператора O +мобільного O +зв'язку O +. 
O + + +Однак O +мобільні O +оператори O +побоюються O +, O +що O +таке O +нововведення O +помітно O +зменшить O +продаж O +стартових O +пакетів O +, O +адже O +спеціалізовані O +магазини O +є O +лише O +у O +містах O +. O + + +Відтак O +купити O +сімку O +в O +невеликих O +населених O +пунктах O +буде O +неможливо O +. O + + +Крім O +того O +, O +нова O +процедура O +ідентифікації O +абонентів O +вимагатиме O +від O +операторів O +мобільного O +зв'язку O +додаткових O +витрат O +. O + + +- O +Близько O +90 O +% O +українських O +абонентів O +- O +це O +абоненти O +передоплати O +. O + + +Якщо O +мова O +буде O +йти O +навіть O +про O +поетапну O +їх O +ідентифікацію O +, O +зробити O +це O +буде O +складно O +, O +довго O +і O +дорого O +. O + + +Мобільним O +операторам O +доведеться O +йти O +на O +чималі O +витрати O +, O +пов'язані O +з O +укладанням O +і O +зберіганням O +договорів O +, O +веденням O +баз O +даних O +, O +- O +розповіла O +« O +Економічній B-ORG +правді E-ORG +» O +начальник O +відділу O +зв'язків O +з O +громадськістю O +« O +МТС-Україна S-ORG +» O +Вікторія B-PERS +Рубан E-PERS +. O''' + self.assertEqual(expected, convert_bsf(data, bsf_markup)) + + +class TestBsf(unittest.TestCase): + + def test_empty_bsf(self): + self.assertEqual(parse_bsf(''), []) + + def test_empty2_bsf(self): + self.assertEqual(parse_bsf(' \n \n'), []) + + def test_1line_bsf(self): + bsf = 'T1 PERS 103 118 Василь Нагірний' + res = parse_bsf(bsf) + expected = BsfInfo('T1', 'PERS', 103, 118, 'Василь Нагірний') + self.assertEqual(len(res), 1) + self.assertEqual(res, [expected]) + + def test_2line_bsf(self): + bsf = '''T9 PERS 778 783 Карла +T10 MISC 814 819 міста''' + res = parse_bsf(bsf) + expected = [BsfInfo('T9', 'PERS', 778, 783, 'Карла'), + BsfInfo('T10', 'MISC', 814, 819, 'міста')] + self.assertEqual(len(res), 2) + self.assertEqual(res, expected) + + def test_multiline_bsf(self): + bsf = '''T3 PERS 220 235 Андрієм Кіщуком +T4 MISC 251 285 А . +Kubler . 
+Світло і тіні маестро +T5 PERS 363 369 Кіблер''' + res = parse_bsf(bsf) + expected = [BsfInfo('T3', 'PERS', 220, 235, 'Андрієм Кіщуком'), + BsfInfo('T4', 'MISC', 251, 285, '''А . +Kubler . +Світло і тіні маестро'''), + BsfInfo('T5', 'PERS', 363, 369, 'Кіблер')] + self.assertEqual(len(res), len(expected)) + self.assertEqual(res, expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/stanza/stanza/tests/ner/test_ner_training.py b/stanza/stanza/tests/ner/test_ner_training.py new file mode 100644 index 0000000000000000000000000000000000000000..e9b4d69ae313adbb4492f23fd8fa516e2ad3e6cd --- /dev/null +++ b/stanza/stanza/tests/ner/test_ner_training.py @@ -0,0 +1,261 @@ +import json +import logging +import os +import warnings + +import pytest +import torch + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +from stanza.models import ner_tagger +from stanza.models.ner.trainer import Trainer +from stanza.tests import TEST_WORKING_DIR +from stanza.utils.datasets.ner.prepare_ner_file import process_dataset + +logger = logging.getLogger('stanza') + +EN_TRAIN_BIO = """ +Chris B-PERSON +Manning E-PERSON +is O +a O +good O +man O +. O + +He O +works O +in O +Stanford B-ORG +University E-ORG +. O +""".lstrip().replace(" ", "\t") + +EN_DEV_BIO = """ +Chris B-PERSON +Manning E-PERSON +is O +part O +of O +Computer B-ORG +Science E-ORG +""".lstrip().replace(" ", "\t") + +EN_TRAIN_2TAG = """ +Chris B-PERSON B-PER +Manning E-PERSON E-PER +is O O +a O O +good O O +man O O +. O O + +He O O +works O O +in O O +Stanford B-ORG B-ORG +University E-ORG B-ORG +. O O +""".strip().replace(" ", "\t") + +EN_TRAIN_2TAG_EMPTY2 = """ +Chris B-PERSON - +Manning E-PERSON - +is O - +a O - +good O - +man O - +. O - + +He O - +works O - +in O - +Stanford B-ORG - +University E-ORG - +. 
O - +""".strip().replace(" ", "\t") + +EN_DEV_2TAG = """ +Chris B-PERSON B-PER +Manning E-PERSON E-PER +is O O +part O O +of O O +Computer B-ORG B-ORG +Science E-ORG E-ORG +""".strip().replace(" ", "\t") + +@pytest.fixture(scope="module") +def pretrain_file(): + return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' + +def write_temp_file(filename, bio_data): + bio_filename = os.path.splitext(filename)[0] + ".bio" + with open(bio_filename, "w", encoding="utf-8") as fout: + fout.write(bio_data) + process_dataset(bio_filename, filename) + +def write_temp_2tag(filename, bio_data): + doc = [] + sentences = bio_data.split("\n\n") + for sentence in sentences: + doc.append([]) + for word in sentence.split("\n"): + text, tags = word.split("\t", maxsplit=1) + doc[-1].append({ + "text": text, + "multi_ner": tags.split() + }) + + with open(filename, "w", encoding="utf-8") as fout: + json.dump(doc, fout) + +def get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args): + save_dir = tmp_path / "models" + args = ["--data_dir", str(tmp_path), + "--wordvec_pretrain_file", pretrain_file, + "--train_file", str(train_json), + "--eval_file", str(dev_json), + "--shorthand", "en_test", + "--max_steps", "100", + "--eval_interval", "40", + "--save_dir", str(save_dir)] + + args = args + list(extra_args) + return args + +def run_two_tag_training(pretrain_file, tmp_path, *extra_args, train_data=EN_TRAIN_2TAG): + train_json = tmp_path / "en_test.train.json" + write_temp_2tag(train_json, train_data) + + dev_json = tmp_path / "en_test.dev.json" + write_temp_2tag(dev_json, EN_DEV_2TAG) + + args = get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args) + return ner_tagger.main(args) + +def test_basic_two_tag_training(pretrain_file, tmp_path): + trainer = run_two_tag_training(pretrain_file, tmp_path) + assert len(trainer.model.tag_clfs) == 2 + assert len(trainer.model.crits) == 2 + assert len(trainer.vocab['tag'].lens()) == 2 + +def test_two_tag_training_backprop(pretrain_file, 
tmp_path): + """ + Test that the training is backproping both tags + + We can do this by using the "finetune" mechanism and verifying + that the output tensors are different + """ + trainer = run_two_tag_training(pretrain_file, tmp_path) + + # first, need to save the final model before restarting + # (alternatively, could reload the final checkpoint) + trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name'])) + new_trainer = run_two_tag_training(pretrain_file, tmp_path, "--finetune") + + assert len(trainer.model.tag_clfs) == 2 + assert len(new_trainer.model.tag_clfs) == 2 + for old_clf, new_clf in zip(trainer.model.tag_clfs, new_trainer.model.tag_clfs): + assert not torch.allclose(old_clf.weight, new_clf.weight) + +def test_two_tag_training_c2_backprop(pretrain_file, tmp_path): + """ + Test that the training is backproping only one tag if one column is blank + + We can do this by using the "finetune" mechanism and verifying + that the output tensors are different in just the first column + """ + trainer = run_two_tag_training(pretrain_file, tmp_path) + + # first, need to save the final model before restarting + # (alternatively, could reload the final checkpoint) + trainer.save(os.path.join(trainer.args['save_dir'], trainer.args['save_name'])) + new_trainer = run_two_tag_training(pretrain_file, tmp_path, "--finetune", train_data=EN_TRAIN_2TAG_EMPTY2) + + assert len(trainer.model.tag_clfs) == 2 + assert len(new_trainer.model.tag_clfs) == 2 + assert not torch.allclose(trainer.model.tag_clfs[0].weight, new_trainer.model.tag_clfs[0].weight) + assert torch.allclose(trainer.model.tag_clfs[1].weight, new_trainer.model.tag_clfs[1].weight) + +def test_connected_two_tag_training(pretrain_file, tmp_path): + trainer = run_two_tag_training(pretrain_file, tmp_path, "--connect_output_layers") + assert len(trainer.model.tag_clfs) == 2 + assert len(trainer.model.crits) == 2 + assert len(trainer.vocab['tag'].lens()) == 2 + + # this checks that with the 
connected output layers, + # the second output layer has its size increased + # by the number of tags known to the first output layer + assert trainer.model.tag_clfs[1].weight.shape[1] == trainer.vocab['tag'].lens()[0] + trainer.model.tag_clfs[0].weight.shape[1] + +def run_training(pretrain_file, tmp_path, *extra_args): + train_json = tmp_path / "en_test.train.json" + write_temp_file(train_json, EN_TRAIN_BIO) + + dev_json = tmp_path / "en_test.dev.json" + write_temp_file(dev_json, EN_DEV_BIO) + + args = get_args(tmp_path, pretrain_file, train_json, dev_json, *extra_args) + return ner_tagger.main(args) + + +def test_train_model_gpu(pretrain_file, tmp_path): + """ + Briefly train an NER model (no expectation of correctness) and check that it is on the GPU + """ + trainer = run_training(pretrain_file, tmp_path) + if not torch.cuda.is_available(): + warnings.warn("Cannot check that the NER model is on the GPU, since GPU is not available") + return + + model = trainer.model + device = next(model.parameters()).device + assert str(device).startswith("cuda") + + +def test_train_model_cpu(pretrain_file, tmp_path): + """ + Briefly train an NER model (no expectation of correctness) and check that it is on the GPU + """ + trainer = run_training(pretrain_file, tmp_path, "--cpu") + + model = trainer.model + device = next(model.parameters()).device + assert str(device).startswith("cpu") + +def model_file_has_bert(filename): + checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) + return any(x.startswith("bert_model.") for x in checkpoint['model'].keys()) + +def test_with_bert(pretrain_file, tmp_path): + trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert') + model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name']) + assert not model_file_has_bert(model_file) + +def test_with_bert_finetune(pretrain_file, tmp_path): + trainer = run_training(pretrain_file, tmp_path, '--bert_model', 
'hf-internal-testing/tiny-bert', '--bert_finetune') + model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name']) + assert model_file_has_bert(model_file) + + foo_save_filename = os.path.join(tmp_path, "foo_" + trainer.args['save_name']) + bar_save_filename = os.path.join(tmp_path, "bar_" + trainer.args['save_name']) + trainer.save(foo_save_filename) + assert model_file_has_bert(foo_save_filename) + + # TODO: technically this should still work if we turn off bert finetuning when reloading + reloaded_trainer = Trainer(args=trainer.args, model_file=foo_save_filename) + reloaded_trainer.save(bar_save_filename) + assert model_file_has_bert(bar_save_filename) + +def test_with_peft_finetune(pretrain_file, tmp_path): + # TODO: check that the peft tensors are moving when training? + trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--use_peft') + model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name']) + checkpoint = torch.load(model_file, lambda storage, loc: storage, weights_only=True) + assert 'bert_lora' in checkpoint + assert not any(x.startswith("bert_model.") for x in checkpoint['model'].keys()) + + # test loading + reloaded_trainer = Trainer(args=trainer.args, model_file=model_file) diff --git a/stanza/stanza/tests/pipeline/pipeline_device_tests.py b/stanza/stanza/tests/pipeline/pipeline_device_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..0f0722b850885659f9b9f569ed3598f327297755 --- /dev/null +++ b/stanza/stanza/tests/pipeline/pipeline_device_tests.py @@ -0,0 +1,45 @@ +""" +Utility methods to check that all processors are on the expected device + +Refactored since it can be used for multiple pipelines +""" + +import warnings + +import torch + +def check_on_gpu(pipeline): + """ + Check that the processors are all on the GPU and that basic execution works + """ + if not torch.cuda.is_available(): + warnings.warn("Unable to run the test that checks 
the pipeline is on the GPU, as there is no GPU available!") + return + + for name, proc in pipeline.processors.items(): + if proc.trainer is not None: + device = next(proc.trainer.model.parameters()).device + else: + device = next(proc._model.parameters()).device + + assert str(device).startswith("cuda"), "Processor %s was not on the GPU" % name + + # just check that there are no cpu/cuda tensor conflicts + # when running on the GPU + pipeline("This is a small test") + +def check_on_cpu(pipeline): + """ + Check that the processors are all on the CPU and that basic execution works + """ + for name, proc in pipeline.processors.items(): + if proc.trainer is not None: + device = next(proc.trainer.model.parameters()).device + else: + device = next(proc._model.parameters()).device + + assert str(device).startswith("cpu"), "Processor %s was not on the CPU" % name + + # just check that there are no cpu/cuda tensor conflicts + # when running on the CPU + pipeline("This is a small test") diff --git a/stanza/stanza/tests/pipeline/test_pipeline_sentiment_processor.py b/stanza/stanza/tests/pipeline/test_pipeline_sentiment_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f8420ac0b01c6a3002f624f7bece055b249e00 --- /dev/null +++ b/stanza/stanza/tests/pipeline/test_pipeline_sentiment_processor.py @@ -0,0 +1,50 @@ +import gc + +import pytest +import stanza +from stanza.utils.conll import CoNLL +from stanza.models.common.doc import Document + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +# data for testing +EN_DOCS = ["Ragavan is terrible and should go away.", "Today is okay.", "Urza's Saga is great."] + +EN_DOC = " ".join(EN_DOCS) + +EXPECTED = [0, 1, 2] + +class TestSentimentPipeline: + @pytest.fixture(scope="class") + def pipeline(self): + """ + A reusable pipeline with the NER module + """ + gc.collect() + return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,sentiment") + + def 
test_simple(self, pipeline): + results = [] + for text in EN_DOCS: + doc = pipeline(text) + assert len(doc.sentences) == 1 + results.append(doc.sentences[0].sentiment) + assert EXPECTED == results + + def test_multiple_sentences(self, pipeline): + doc = pipeline(EN_DOC) + assert len(doc.sentences) == 3 + results = [sentence.sentiment for sentence in doc.sentences] + assert EXPECTED == results + + def test_empty_text(self, pipeline): + """ + Test empty text and a text which might get reduced to empty text by removing dashes + """ + doc = pipeline("") + assert len(doc.sentences) == 0 + + doc = pipeline("--") + assert len(doc.sentences) == 1 diff --git a/stanza/stanza/tests/pipeline/test_requirements.py b/stanza/stanza/tests/pipeline/test_requirements.py new file mode 100644 index 0000000000000000000000000000000000000000..92e04ec6b172f6bba6fb90533b70a16d648be76b --- /dev/null +++ b/stanza/stanza/tests/pipeline/test_requirements.py @@ -0,0 +1,72 @@ +""" +Test the requirements functionality for processors +""" + +import pytest +import stanza + +from stanza.pipeline.core import PipelineRequirementsException +from stanza.pipeline.processor import ProcessorRequirementsException +from stanza.tests import * + +pytestmark = pytest.mark.pipeline + +def check_exception_vals(req_exception, req_exception_vals): + """ + Check the values of a ProcessorRequirementsException against a dict of expected values. 
+ :param req_exception: the ProcessorRequirementsException to evaluate + :param req_exception_vals: expected values for the ProcessorRequirementsException + :return: None + """ + assert isinstance(req_exception, ProcessorRequirementsException) + assert req_exception.processor_type == req_exception_vals['processor_type'] + assert req_exception.processors_list == req_exception_vals['processors_list'] + assert req_exception.err_processor.requires == req_exception_vals['requires'] + + +def test_missing_requirements(): + """ + Try to build several pipelines with bad configs and check thrown exceptions against gold exceptions. + :return: None + """ + # list of (bad configs, list of gold ProcessorRequirementsExceptions that should be thrown) pairs + bad_config_lists = [ + # missing tokenize + ( + # input config + {'processors': 'pos,depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, + # 2 expected exceptions + [ + {'processor_type': 'POSProcessor', 'processors_list': ['pos', 'depparse'], 'provided_reqs': set([]), + 'requires': set(['tokenize'])}, + {'processor_type': 'DepparseProcessor', 'processors_list': ['pos', 'depparse'], + 'provided_reqs': set([]), 'requires': set(['tokenize','pos', 'lemma'])} + ] + ), + # no pos when lemma_pos set to True; for english mwt should not be included in the loaded processor list + ( + # input config + {'processors': 'tokenize,mwt,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_pos': True}, + # 1 expected exception + [ + {'processor_type': 'LemmaProcessor', 'processors_list': ['tokenize', 'mwt', 'lemma'], + 'provided_reqs': set(['tokenize', 'mwt']), 'requires': set(['tokenize', 'pos'])} + ] + ) + ] + # try to build each bad config, catch exceptions, check against gold + pipeline_fails = 0 + for bad_config, gold_exceptions in bad_config_lists: + try: + stanza.Pipeline(**bad_config) + except PipelineRequirementsException as e: + pipeline_fails += 1 + assert isinstance(e, PipelineRequirementsException) + assert len(e.processor_req_fails) 
== len(gold_exceptions) + for processor_req_e, gold_exception in zip(e.processor_req_fails,gold_exceptions): + # compare the thrown ProcessorRequirementsExceptions against gold + check_exception_vals(processor_req_e, gold_exception) + # check pipeline building failed twice + assert pipeline_fails == 2 + + diff --git a/stanza/stanza/tests/tokenization/__init__.py b/stanza/stanza/tests/tokenization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stanza/stanza/tests/tokenization/test_tokenize_utils.py b/stanza/stanza/tests/tokenization/test_tokenize_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3218771d20e7d5417636a234d6005edf49e085bc --- /dev/null +++ b/stanza/stanza/tests/tokenization/test_tokenize_utils.py @@ -0,0 +1,220 @@ +""" +Very simple test of the sentence slicing by tags + +TODO: could add a bunch more simple tests for the tokenization utils +""" + +import pytest +import stanza + +from stanza import Pipeline +from stanza.tests import * +from stanza.models.common import doc +from stanza.models.tokenization import data +from stanza.models.tokenization import utils + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +def test_find_spans(): + """ + Test various raw -> span manipulations + """ + raw = ['u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l'] + assert utils.find_spans(raw) == [(0, 14)] + + raw = ['u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l', ''] + assert utils.find_spans(raw) == [(0, 14)] + + raw = ['', 'u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l', ''] + assert utils.find_spans(raw) == [(1, 15)] + + raw = ['', 'u', 'n', 'b', 'a', 'n', ' ', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l'] + assert utils.find_spans(raw) == [(1, 15)] + + raw = ['', 'u', 'n', 'b', 'a', 'n', '', 'm', 'o', 'x', ' ', 'o', 'p', 'a', 'l'] + assert utils.find_spans(raw) == [(1, 6), (7, 15)] + +def 
check_offsets(doc, expected_offsets): + """ + Compare the start_char and end_char of the tokens in the doc with the given list of list of offsets + """ + assert len(doc.sentences) == len(expected_offsets) + for sentence, offsets in zip(doc.sentences, expected_offsets): + assert len(sentence.tokens) == len(offsets) + for token, offset in zip(sentence.tokens, offsets): + assert token.start_char == offset[0] + assert token.end_char == offset[1] + +def test_match_tokens_with_text(): + """ + Test the conversion of pretokenized text to Document + """ + doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisatest") + expected_offsets = [[(0, 4), (4, 6), (6, 7), (7, 11)]] + check_offsets(doc, expected_offsets) + + doc = utils.match_tokens_with_text([["This", "is", "a", "test"], ["unban", "mox", "opal", "!"]], "Thisisatest unban mox opal!") + expected_offsets = [[(0, 4), (4, 6), (6, 7), (7, 11)], + [(13, 18), (19, 22), (24, 28), (28, 29)]] + check_offsets(doc, expected_offsets) + + with pytest.raises(ValueError): + doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisatestttt") + + with pytest.raises(ValueError): + doc = utils.match_tokens_with_text([["This", "is", "a", "test"]], "Thisisates") + + with pytest.raises(ValueError): + doc = utils.match_tokens_with_text([["This", "iz", "a", "test"]], "Thisisatest") + +def test_long_paragraph(): + """ + Test the tokenizer's capacity to break text up into smaller chunks + """ + pipeline = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize") + tokenizer = pipeline.processors['tokenize'] + + raw_text = "TIL not to ask a date to dress up as Smurfette on a first date. 
" * 100 + + # run a test to make sure the chunk operation is called + # if not, the test isn't actually testing what we need to test + batches = data.DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary) + batches.advance_old_batch = None + with pytest.raises(TypeError): + _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000, + orig_text=raw_text, + no_ssplit=tokenizer.config.get('no_ssplit', False)) + + # a new DataLoader should not be crippled as the above one was + batches = data.DataLoader(tokenizer.config, input_text=raw_text, vocab=tokenizer.vocab, evaluation=True, dictionary=tokenizer.trainer.dictionary) + _, _, _, document = utils.output_predictions(None, tokenizer.trainer, batches, tokenizer.vocab, None, 3000, + orig_text=raw_text, + no_ssplit=tokenizer.config.get('no_ssplit', False)) + + document = doc.Document(document, raw_text) + assert len(document.sentences) == 100 + +def test_postprocessor_application(): + """ + Check that the postprocessor behaves correctly by applying the identity postprocessor and hoping that it does indeed return correctly. + """ + + good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']] + text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken." 
def test_reassembly_indexing():
    """
    Check that the reassembly code counts the indicies correctly, and including OOV chars.
    """

    # two sentences, no multi-word tokens and no expansions
    good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']]
    good_mwts = [[False for _ in range(len(i))] for i in good_tokenization]
    good_expansions = [[None for _ in range(len(i))] for i in good_tokenization]

    text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

    # expected result: start_char/end_char line up with the raw text above,
    # including the unusual (out-of-vocabulary) characters, and SpaceAfter=No
    # is set wherever the next token follows with no separating space
    target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16, 'misc': 'SpaceAfter=No'}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31, 'misc': 'SpaceAfter=No'}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32, 'misc': 'SpaceAfter=No'}]]

    res = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text)

    assert res == target_doc
+ + with pytest.raises(ValueError): + utils.reassemble_doc_from_tokens(bad_addition_tokenization, bad_addition_mwts, bad_addition_expansions, text) + + with pytest.raises(ValueError): + utils.reassemble_doc_from_tokens(bad_inline_tokenization, bad_inline_mwts, bad_inline_mwts, text) + + utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, good_expansions, text) + + + +TRAIN_DATA = """ +# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003 +# text = DPA: Iraqi authorities announced that they'd busted up three terrorist cells operating in Baghdad. +1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No +2 : : PUNCT : _ 1 punct 1:punct _ +3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _ +4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _ +5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _ +6 that that SCONJ IN _ 9 mark 9:mark _ +7-8 they'd _ _ _ _ _ _ _ _ +7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _ +8 'd have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _ +9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _ +10 up up ADP RP _ 9 compound:prt 9:compound:prt _ +11 three three NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _ +12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _ +13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _ +14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _ +15 in in ADP IN _ 16 case 16:case _ +16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No +17 . . PUNCT . _ 1 punct 1:punct _ + +# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004 +# text = Two of them were being run by 2 officials of the Ministry of the Interior! 
+1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _ +2 of of ADP IN _ 3 case 3:case _ +3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _ +4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _ +5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _ +6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _ +7 by by ADP IN _ 9 case 9:case _ +8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _ +9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _ +10 of of ADP IN _ 12 case 12:case _ +11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _ +12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _ +13 of of ADP IN _ 15 case 15:case _ +14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _ +15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No +16 ! ! PUNCT . _ 6 punct 6:punct _ + +""".lstrip() + +def test_lexicon_from_training_data(tmp_path): + """ + Test a couple aspects of building a lexicon from training data + + expected number of words eliminated for being too long + duplicate words counted once + numbers eliminated + """ + conllu_file = str(tmp_path / "train.conllu") + with open(conllu_file, "w", encoding="utf-8") as fout: + fout.write(TRAIN_DATA) + + lexicon, num_dict_feat = utils.create_lexicon("en_test", conllu_file) + lexicon = sorted(lexicon) + expected_lexicon = ["'d", 'announced', 'baghdad', 'being', 'busted', 'by', 'cells', 'dpa', 'in', 'interior', 'iraqi', 'ministry', 'of', 'officials', 'operating', 'run', 'terrorist', 'that', 'the', 'them', 'they', "they'd", 'three', 'two', 'up', 'were'] + assert lexicon == expected_lexicon + assert num_dict_feat == max(len(x) for x in lexicon) + diff --git a/stanza/stanza/utils/charlm/__init__.py b/stanza/stanza/utils/charlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
"""
Turns a directory of conllu files from the conll 2017 shared task to a text file

Part of the process for building a charlm dataset

python conll17_to_text.py

This is an extension of the original script:
  https://github.com/stanfordnlp/stanza-scripts/blob/master/charlm/conll17/conll2txt.py

To build a new charlm for a new language from a conll17 dataset:
- look for conll17 shared task data, possibly here:
  https://lindat.mff.cuni.cz/repository/xmlui/handle/11234/1-1989
- python3 stanza/utils/charlm/conll17_to_text.py ~/extern_data/conll17/Bulgarian --output_directory extern_data/charlm_raw/bg/conll17
- python3 stanza/utils/charlm/make_lm_data.py --langs bg extern_data/charlm_raw extern_data/charlm/
"""

import argparse
import lzma
import sys
import os

def process_file(input_filename, output_directory, compress):
    """
    Convert one .conllu or .conllu.xz file to whitespace separated text.

    Words come from column 2 of each token line; MWT range lines (ids
    such as 1-2) are skipped so each word appears exactly once.  The
    output goes to <name>.txt (plus .xz when compress is True), next to
    the input or in output_directory when that is set.  Existing output
    files are never overwritten; other file types are skipped.

    Raises ValueError on a token line which does not have 10 columns.
    """
    if not input_filename.endswith('.conllu') and not input_filename.endswith(".conllu.xz"):
        print("Skipping {}".format(input_filename))
        return

    # choose a reader and the matching output name based on the input compression
    if input_filename.endswith(".xz"):
        open_fn = lambda x: lzma.open(x, mode='rt', encoding='utf-8')
        output_filename = input_filename[:-3].replace(".conllu", ".txt")
    else:
        open_fn = lambda x: open(x, encoding='utf-8')
        output_filename = input_filename.replace('.conllu', '.txt')

    if output_directory:
        output_filename = os.path.join(output_directory, os.path.split(output_filename)[1])

    if compress:
        output_filename = output_filename + ".xz"
        output_fn = lambda x: lzma.open(x, mode='wt', encoding='utf-8')
    else:
        output_fn = lambda x: open(x, mode='w', encoding='utf-8')

    if os.path.exists(output_filename):
        print("Cowardly refusing to overwrite %s" % output_filename)
        return

    print("Converting %s to %s" % (input_filename, output_filename))
    with open_fn(input_filename) as fin:
        sentences = []
        sentence = []
        for line in fin:
            line = line.strip()
            if len(line) == 0:  # sentence break
                # only flush non-empty sentences so consecutive blank
                # lines cannot produce empty sentences / blank output lines
                if sentence:
                    sentences.append(sentence)
                    sentence = []
                continue
            if line[0] == '#':  # comment
                continue
            splitline = line.split('\t')
            # raise instead of assert: asserts disappear under python -O
            if len(splitline) != 10:
                raise ValueError("Unexpected number of columns (%d) in %s" % (len(splitline), input_filename))
            word_id, word = splitline[0], splitline[1]
            if '-' not in word_id:  # skip mwt range lines; the words themselves follow
                sentence.append(word)

        if sentence:
            sentences.append(sentence)

    print("  Read in {} sentences".format(len(sentences)))
    with output_fn(output_filename) as fout:
        fout.write('\n'.join(' '.join(sentence) for sentence in sentences))
        # end with a newline so concatenating shards never glues two sentences together
        fout.write('\n')

def parse_args():
    """Build the arguments for the conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input_directory", help="Root directory with conllu or conllu.xz files.")
    parser.add_argument("--output_directory", default=None, help="Directory to output to.  Will output to input_directory if None")
    parser.add_argument("--no_xz_output", default=True, dest="xz_output", action="store_false", help="Output compressed xz files")
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    directory = args.input_directory
    filenames = sorted(os.listdir(directory))
    print("Files to process in {}: {}".format(directory, filenames))
    print("Processing to .xz files: {}".format(args.xz_output))

    if args.output_directory:
        os.makedirs(args.output_directory, exist_ok=True)
    for filename in filenames:
        process_file(os.path.join(directory, filename), args.output_directory, args.xz_output)
def parse_args():
    """
    A few specific arguments for the dump program

    Uses lang_to_langcode to process args.language, hopefully converting
    a variety of possible formats to the short code used by HuggingFace
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("language", help="Language to download")
    parser.add_argument("--output", default="oscar_dump", help="Path for saving files")
    parser.add_argument("--no_xz", dest="xz", default=True, action='store_false', help="Don't xz the files - default is to compress while writing")
    parser.add_argument("--prefix", default="oscar_dump", help="Prefix to use for the pieces of the dataset")
    parser.add_argument("--version", choices=["2019", "2023"], default="2023", help="Which version of the Oscar dataset to download")

    args = parser.parse_args()
    # normalize to the short language code HuggingFace expects
    args.language = lang_to_langcode(args.language)
    return args

def download_2023(args):
    # NOTE(review): this helper loads a dataset but never returns or writes
    # anything, and nothing in this file calls it - it looks like leftover
    # scaffolding for the 2023 branch of main(); confirm before relying on it
    dataset = load_dataset('oscar-corpus/OSCAR-2301', 'sd')
    split_names = list(dataset.keys())


def main():
    """
    Download the requested Oscar crawl and dump its text under --output.

    Text is streamed out in pieces of roughly 1e8 characters per file,
    optionally xz compressed, named --prefix plus a zero padded index.
    """
    args = parse_args()

    # this is the 2019 version.  for 2023, you can do
    # dataset = load_dataset('oscar-corpus/OSCAR-2301', 'sd')
    language = args.language
    if args.version == "2019":
        dataset_name = "unshuffled_deduplicated_%s" % language
        try:
            split_names = get_dataset_split_names("oscar", dataset_name)
        except ValueError as e:
            raise ValueError("Language %s not available in HuggingFace Oscar" % language) from e

        # the Oscar datasets are expected to come with exactly one split
        if len(split_names) > 1:
            raise ValueError("Unexpected split_names: {}".format(split_names))

        dataset = load_dataset("oscar", dataset_name)
        dataset = dataset[split_names[0]]
        size_in_bytes = dataset.info.size_in_bytes
        process_item = lambda x: x['text']
    elif args.version == "2023":
        dataset = load_dataset("oscar-corpus/OSCAR-2301", language)
        split_names = list(dataset.keys())
        if len(split_names) > 1:
            raise ValueError("Unexpected split_names: {}".format(split_names))
        # it's not clear if some languages don't support size_in_bytes,
        # or if there was an update to datasets which now allows that
        #
        # previously we did:
        # dataset = dataset[split_names[0]]['text']
        # size_in_bytes = sum(len(x) for x in dataset)
        # process_item = lambda x: x
        dataset = dataset[split_names[0]]
        size_in_bytes = dataset.info.size_in_bytes
        process_item = lambda x: x['text']
    else:
        raise AssertionError("Unknown version: %s" % args.version)

    # estimate how many output pieces there will be so the zero padding of
    # the filenames is wide enough (always at least 3 digits)
    chunks = max(1.0, size_in_bytes // 1e8) # an overestimate
    id_len = max(3, math.floor(math.log10(chunks)) + 1)

    if args.xz:
        format_str = "%s_%%0%dd.txt.xz" % (args.prefix, id_len)
        fopen = lambda file_idx: lzma.open(os.path.join(args.output, format_str % file_idx), "wt")
    else:
        format_str = "%s_%%0%dd.txt" % (args.prefix, id_len)
        fopen = lambda file_idx: open(os.path.join(args.output, format_str % file_idx), "w")

    print("Writing dataset to %s" % args.output)
    print("Dataset length: {}".format(size_in_bytes))
    os.makedirs(args.output, exist_ok=True)

    file_idx = 0
    file_len = 0
    total_len = 0   # NOTE(review): never updated or read below - apparently vestigial
    fout = fopen(file_idx)

    # stream the text, rolling over to a new file every ~1e8 characters
    for item in tqdm(dataset):
        text = process_item(item)
        fout.write(text)
        fout.write("\n")
        file_len += len(text)
        file_len += 1  # account for the newline just written
        if file_len > 1e8:
            file_len = 0
            fout.close()
            file_idx = file_idx + 1
            fout = fopen(file_idx)

    fout.close()

if __name__ == '__main__':
    main()
Expected structure is root dir -> language dirs -> package dirs -> text files to process") + parser.add_argument("tgt_root", default="tgt", help="Root directory with all target files.") + parser.add_argument("--langs", default="", help="A list of language codes to process. If not set, all languages under src_root will be processed.") + parser.add_argument("--packages", default="", help="A list of packages to process. If not set, all packages under the languages found will be processed.") + parser.add_argument("--no_xz_output", default=True, dest="xz_output", action="store_false", help="Output compressed xz files") + parser.add_argument("--split_size", default=50, type=int, help="How large to make each split, in MB") + parser.add_argument("--no_make_test_file", default=True, dest="make_test_file", action="store_false", help="Don't save a test file. Honestly, we never even use it. Best for low resource languages where every bit helps") + args = parser.parse_args() + + print("Processing files:") + print(f"source root: {args.src_root}") + print(f"target root: {args.tgt_root}") + print("") + + langs = [] + if len(args.langs) > 0: + langs = args.langs.split(',') + print("Only processing the following languages: " + str(langs)) + + packages = [] + if len(args.packages) > 0: + packages = args.packages.split(',') + print("Only processing the following packages: " + str(packages)) + + src_root = Path(args.src_root) + tgt_root = Path(args.tgt_root) + + lang_dirs = os.listdir(src_root) + lang_dirs = [l for l in lang_dirs if l not in EXCLUDED_FOLDERS] # skip excluded + lang_dirs = [l for l in lang_dirs if os.path.isdir(src_root / l)] # skip non-directory + if len(langs) > 0: # filter languages if specified + lang_dirs = [l for l in lang_dirs if l in langs] + print(f"{len(lang_dirs)} total languages found:") + print(lang_dirs) + print("") + + split_size = int(args.split_size * 1024 * 1024) + + for lang in lang_dirs: + lang_root = src_root / lang + data_dirs = 
def prepare_lm_data(src_dir, tgt_dir, lang, dataset_name, compress, split_size, make_test_file):
    """
    Combine, shuffle and split data into smaller files, following a naming convention.

    src_dir: Path holding the input *.txt / *.txt.xz / *.txt.gz files
    tgt_dir: Path to receive train/*.txt shards, dev.txt and (optionally) test.txt
    lang, dataset_name: used to build the temp and shard file names
    compress: xz the final files when True
    split_size: maximum shard size in bytes, passed to split -C
    make_test_file: also carve a test.txt out of the shards, not just dev.txt

    NOTE(review): the shell commands below interpolate file paths without
    quoting, so paths containing spaces or shell metacharacters will break
    (or worse) - confirm inputs are safe or switch to list-form subprocess calls
    """
    assert isinstance(src_dir, Path)
    assert isinstance(tgt_dir, Path)
    with tempfile.TemporaryDirectory(dir=tgt_dir) as tempdir:
        tgt_tmp = os.path.join(tempdir, f"{lang}-{dataset_name}.tmp")
        print(f"--> Copying files into {tgt_tmp}...")
        # TODO: we can do this without the shell commands
        input_files = glob.glob(str(src_dir) + '/*.txt') + glob.glob(str(src_dir) + '/*.txt.xz') + glob.glob(str(src_dir) + '/*.txt.gz')
        # concatenate every input, decompressing as needed, into one temp file
        for src_fn in tqdm(input_files):
            if src_fn.endswith(".txt"):
                cmd = f"cat {src_fn} >> {tgt_tmp}"
                subprocess.run(cmd, shell=True)
            elif src_fn.endswith(".txt.xz"):
                cmd = f"xzcat {src_fn} >> {tgt_tmp}"
                subprocess.run(cmd, shell=True)
            elif src_fn.endswith(".txt.gz"):
                cmd = f"zcat {src_fn} >> {tgt_tmp}"
                subprocess.run(cmd, shell=True)
            else:
                raise AssertionError("should not have found %s" % src_fn)
        tgt_tmp_shuffled = os.path.join(tempdir, f"{lang}-{dataset_name}.tmp.shuffled")

        # shuffle lines so dev/test are drawn uniformly from the whole corpus
        print(f"--> Shuffling files into {tgt_tmp_shuffled}...")
        cmd = f"cat {tgt_tmp} | shuf > {tgt_tmp_shuffled}"
        result = subprocess.run(cmd, shell=True)
        if result.returncode != 0:
            raise RuntimeError("Failed to shuffle files!")
        size = os.path.getsize(tgt_tmp_shuffled) / 1024 / 1024 / 1024
        print(f"--> Shuffled file size: {size:.4f} GB")
        if size < 0.1:
            raise RuntimeError("Not enough data found to build a charlm. At least 100MB data expected")

        # split -C caps shard size in bytes while keeping whole lines together
        print(f"--> Splitting into smaller files of size {split_size} ...")
        train_dir = tgt_dir / 'train'
        if not os.path.exists(train_dir): # make training dir
            os.makedirs(train_dir)
        cmd = f"split -C {split_size} -a 4 -d --additional-suffix .txt {tgt_tmp_shuffled} {train_dir}/{lang}-{dataset_name}-"
        result = subprocess.run(cmd, shell=True)
        if result.returncode != 0:
            raise RuntimeError("Failed to split files!")
        total = len(glob.glob(f'{train_dir}/*.txt'))
        print(f"--> {total} total files generated.")
        if total < 3:
            raise RuntimeError("Something went wrong! %d file(s) produced by shuffle and split, expected at least 3" % total)

        # the first shard(s) become dev (and test); the remainder stay as train
        dev_file = f"{tgt_dir}/dev.txt"
        test_file = f"{tgt_dir}/test.txt"
        if make_test_file:
            print("--> Creating dev and test files...")
            shutil.move(f"{train_dir}/{lang}-{dataset_name}-0000.txt", dev_file)
            shutil.move(f"{train_dir}/{lang}-{dataset_name}-0001.txt", test_file)
            txt_files = [dev_file, test_file] + glob.glob(f'{train_dir}/*.txt')
        else:
            print("--> Creating dev file...")
            shutil.move(f"{train_dir}/{lang}-{dataset_name}-0000.txt", dev_file)
            txt_files = [dev_file] + glob.glob(f'{train_dir}/*.txt')

        if compress:
            print("--> Compressing files...")
            for txt_file in tqdm(txt_files):
                subprocess.run(['xz', txt_file])

        # the TemporaryDirectory context removes the intermediate files
        print("--> Cleaning up...")
        print(f"--> All done for {lang}-{dataset_name}.\n")

if __name__ == "__main__":
    main()
stanza.models.constituency import transition_sequence +from stanza.models.constituency import tree_reader +from stanza.models.constituency.parse_transitions import TransitionScheme +from stanza.models.constituency.parse_tree import Tree +from stanza.models.constituency.utils import verify_transitions + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--train_file', type=str, default="data/constituency/en_ptb3_train.mrg", help='Input file for data loader.') + parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()], + help='Transition scheme to use. {}'.format(", ".join(x.name for x in TransitionScheme))) + parser.add_argument('--reversed', default=False, action='store_true', help='Do the transition sequence reversed') + parser.add_argument('--iterations', default=30, type=int, help='How many times to iterate, such as if doing a cProfile') + args = parser.parse_args() + args = vars(args) + + train_trees = tree_reader.read_treebank(args['train_file']) + unary_limit = max(t.count_unary_depth() for t in train_trees) + 1 + train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], args['reversed']) + root_labels = Tree.get_root_labels(train_trees) + for i in range(args['iterations']): + verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit, args['reversed'], "train", root_labels) + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/constituency/list_tensors.py b/stanza/stanza/utils/constituency/list_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..0e852165a1d03a3d19e3a7ae30417fedf7e8c854 --- /dev/null +++ b/stanza/stanza/utils/constituency/list_tensors.py @@ -0,0 +1,16 @@ +""" +Lists all the tensors in a constituency model. + +Currently useful in combination with torchshow for displaying a series of tensors as they change. 
import sys

def contract_mwt(infile, outfile, ignore_gapping=True):
    """
    Simplify the gold tokenizer data for use as MWT processor test files

    The simplifications are to remove the expanded MWTs, and in the
    case of ignore_gapping=True, remove any copy words for the dependencies
    """
    with open(outfile, 'w') as fout, open(infile, 'r') as fin:
        token_index = 0
        span_start, span_end = 0, -1
        for raw_line in fin:
            raw_line = raw_line.strip()
            # comments pass through untouched
            if raw_line.startswith('#'):
                print(raw_line, file=fout)
                continue
            # blank line: sentence boundary, reset all of the counters
            if not raw_line:
                print(raw_line, file=fout)
                token_index = 0
                span_start, span_end = 0, -1
                continue
            fields = raw_line.split('\t')
            # skip copy nodes (ids such as 8.1) used for gapping
            if ignore_gapping and '.' in fields[0]:
                continue
            token_index += 1
            if '-' in fields[0]:
                # MWT range line: keep the surface token, flagged with MWT=Yes
                span_start, span_end = (int(part) for part in fields[0].split('-'))
                misc = "MWT=Yes" if fields[-1] == '_' else fields[-1] + "|MWT=Yes"
                print("{}\t{}\t{}".format(token_index, "\t".join(fields[1:-1]), misc), file=fout)
                # roll back so the counter lines up with the word ids inside the span
                token_index -= 1
            elif span_start <= token_index <= span_end:
                # a word expanded from the MWT range above - drop it
                continue
            else:
                print("{}\t{}".format(token_index, "\t".join(fields[1:])), file=fout)

if __name__ == '__main__':
    contract_mwt(sys.argv[1], sys.argv[2])
sentence_speakers = [x['speaker'] for x in paragraph] + + processed = process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers) + processed_section.append(processed) + return processed_section + +SECTION_NAMES = {"train": "train", + "dev": "validation", + "test": "test"} + +def process_dataset(short_name, ontonotes_path, coref_output_path): + try: + from datasets import load_dataset + except ImportError as e: + raise ImportError("Please install the datasets package to process OntoNotes coref with Stanza") + + if short_name == 'en_ontonotes': + config_name = 'english_v4' + elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'): + config_name = 'chinese_v4' + elif short_name == 'ar_ontonotes': + config_name = 'arabic_v4' + else: + raise ValueError("Unknown short name for downloading ontonotes: %s" % short_name) + + pipe = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True) + dataset = load_dataset("conll2012_ontonotesv5", config_name, cache_dir=ontonotes_path) + for section, hf_name in SECTION_NAMES.items(): + #for section, hf_name in [("test", "test")]: + print("Processing %s" % section) + converted_section = convert_dataset_section(pipe, dataset[hf_name]) + output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section)) + with open(output_filename, "w", encoding="utf-8") as fout: + json.dump(converted_section, fout, indent=2) + + +def main(): + paths = get_default_paths() + coref_input_path = paths['COREF_BASE'] + ontonotes_path = os.path.join(coref_input_path, "english", "en_ontonotes") + coref_output_path = paths['COREF_DATA_DIR'] + process_dataset("en_ontonotes", ontonotes_path, coref_output_path) + +if __name__ == '__main__': + main() + diff --git a/stanza/stanza/utils/datasets/coref/convert_udcoref.py b/stanza/stanza/utils/datasets/coref/convert_udcoref.py new file mode 100644 index 
"""
Convert CorefUD-format conllu files into the json format used by Stanza's coref model.

Entity annotations are read from the Entity= attribute of the MISC column,
as described in the CorefUD / Universal Anaphora guidelines.
"""

from collections import defaultdict
import json
import os
import re
import glob

from stanza.utils.default_paths import get_default_paths
from stanza.utils.get_tqdm import get_tqdm
from stanza.utils.datasets.coref.utils import find_cconj_head

from stanza.utils.conll import CoNLL

from random import Random

import argparse

# fixed seeds so repeated conversions produce the same augmentation / splits
augment_random = Random(7)
split_random = Random(8)

tqdm = get_tqdm()
# CorefUD entity ids look like (e123 ... e123); the older format has no leading "e",
# so UDCOREF_ADDN is how many extra characters to skip when slicing out the id
IS_UDCOREF_FORMAT = True
UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1

def process_documents(docs, augment=False):
    """Convert parsed stanza Documents into the coref model's json-ready dicts.

    docs: iterable of (stanza Document, doc_id, lang) triples
    augment: if True, randomly drop final tokens and lowercase first words
             (training-time data augmentation)

    Returns a list of dicts with cased_words, sent_id, deprel, head,
    span_clusters, word_clusters, head2span, and lang for each document.
    """
    processed_section = []

    for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)):
        # drop the last token 10% of the time
        if augment:
            for i in doc.sentences:
                if len(i.words) > 1:
                    if augment_random.random() < 0.1:
                        i.tokens = i.tokens[:-1]
                        i.words = i.words[:-1]

        # extract the entities
        # get sentence words and lengths
        sentences = [[j.text for j in i.words]
                     for i in doc.sentences]
        sentence_lens = [len(x.words) for x in doc.sentences]

        cased_words = []
        for x in sentences:
            if augment:
                # modify case of the first word with 50% chance
                if augment_random.random() < 0.5:
                    x[0] = x[0].lower()

            for y in x:
                cased_words.append(y)

        # flat list: sentence index for every word
        # (the comprehension's `idx` shadows the loop's `idx` but, being a
        # comprehension variable, does not leak in Python 3)
        sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]

        word_total = 0
        heads = []
        # TODO: does SD vs UD matter?
        deprel = []
        for sentence in doc.sentences:
            for word in sentence.words:
                deprel.append(word.deprel)
                if word.head == 0:
                    # sentence roots are marked "null"; all other heads are
                    # 0-based document-wide word indices
                    heads.append("null")
                else:
                    heads.append(word.head - 1 + word_total)
            word_total += len(sentence.words)

        span_clusters = defaultdict(list)
        word_clusters = defaultdict(list)
        head2span = []
        word_total = 0
        # matches either an opening "(e123" style marker or a closing "e123)" marker
        SPANS = re.compile(r"(\(\w+|[%\w]+\))")
        for parsed_sentence in doc.sentences:
            # spans regex
            # parse the misc column, leaving on "Entity" entries
            misc = [[k.split("=")
                     for k in j
                     if k.split("=")[0] == "Entity"]
                    for i in parsed_sentence.words
                    for j in [i.misc.split("|") if i.misc else []]]
            # and extract the Entity entry values
            entities = [i[0][1] if len(i) > 0 else None for i in misc]
            # extract reference information
            refs = [SPANS.findall(i) if i else [] for i in entities]
            # and calculate spans: the basic rule is (e... begins a reference
            # and ) without e before ends the most recent reference
            # every single time we get a closing element, we pop it off
            # the refdict and insert the pair to final_refs
            refdict = defaultdict(list)
            final_refs = defaultdict(list)
            last_ref = None
            for indx, i in enumerate(refs):
                for j in i:
                    # this is the beginning of a reference
                    if j[0] == "(":
                        refdict[j[1+UDCOREF_ADDN:]].append(indx)
                        last_ref = j[1+UDCOREF_ADDN:]
                    # at the end of a reference, if we got exxxxx, that ends
                    # a particular reference; otherwise, it ends the last reference
                    elif j[-1] == ")" and j[UDCOREF_ADDN:-1].isnumeric():
                        if (not UDCOREF_ADDN) or j[0] == "e":
                            try:
                                final_refs[j[UDCOREF_ADDN:-1]].append((refdict[j[UDCOREF_ADDN:-1]].pop(-1), indx))
                            except IndexError:
                                # this is probably zero anaphora
                                continue
                    elif j[-1] == ")":
                        final_refs[last_ref].append((refdict[last_ref].pop(-1), indx))
                        last_ref = None
            final_refs = dict(final_refs)
            # convert it to the right format (specifically, in (ref, start, end) tuples)
            coref_spans = []
            for k, v in final_refs.items():
                for i in v:
                    coref_spans.append([int(k), i[0], i[1]])
            sentence_upos = [x.upos for x in parsed_sentence.words]
            sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]
            for span in coref_spans:
                # input is expected to be start word, end word + 1
                # counting from 0
                # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
                span_start = span[1] + word_total
                span_end = span[2] + word_total + 1
                candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1)
                if candidate_head is None:
                    for candidate_head in range(span[1], span[2] + 1):
                        # stanza uses 0 to mark the head, whereas OntoNotes is counting
                        # words from 0, so we have to subtract 1 from the stanza heads
                        #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)
                        # treat the head of the phrase as the first word that has a head outside the phrase
                        if (parsed_sentence.words[candidate_head].head - 1 < span[1] or
                            parsed_sentence.words[candidate_head].head - 1 > span[2]):
                            break
                    else:
                        # if none have a head outside the phrase (circular??)
                        # then just take the first word
                        candidate_head = span[1]
                #print("----> %d" % candidate_head)
                candidate_head += word_total
                span_clusters[span[0]].append((span_start, span_end))
                word_clusters[span[0]].append(candidate_head)
                head2span.append((candidate_head, span_start, span_end))
            word_total += len(parsed_sentence.words)
        span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
        word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
        head2span = sorted(head2span)

        processed = {
            "document_id": doc_id,
            "cased_words": cased_words,
            "sent_id": sent_id,
            "part_id": idx,
            # "pos": pos,
            "deprel": deprel,
            "head": heads,
            "span_clusters": span_clusters,
            "word_clusters": word_clusters,
            "head2span": head2span,
            "lang": lang
        }
        processed_section.append(processed)
    return processed_section

def process_dataset(short_name, coref_output_path, split_test, train_files, dev_files):
    """Read train/dev conllu files, optionally carve out a test split, and write json.

    split_test: if truthy, the fraction of train documents to divert to a test
                section (chosen per-document by a seeded Random)
    Output files are <coref_output_path>/<short_name>.<section>.json
    """
    section_names = ('train', 'dev')
    section_filenames = [train_files, dev_files]
    sections = []

    test_sections = []

    for section, filenames in zip(section_names, section_filenames):
        input_file = []
        for load in filenames:
            # NOTE(review): language is inferred from the filename prefix,
            # e.g. "en_gum-..." -> "en"; assumes CorefUD naming — confirm
            lang = load.split("/")[-1].split("_")[0]
            print("Ingesting %s from %s of lang %s" % (section, load, lang))
            docs = CoNLL.conll2multi_docs(load)
            print(" Ingested %d documents" % len(docs))
            if split_test and section == 'train':
                test_section = []
                train_section = []
                for i in docs:
                    # reseed for each doc so that we can attempt to keep things stable in the event
                    # of different file orderings or some change to the number of documents
                    # (this local split_random deliberately shadows the module-level one)
                    split_random = Random(i.sentences[0].doc_id + i.sentences[0].text)
                    if split_random.random() < split_test:
                        test_section.append((i, i.sentences[0].doc_id, lang))
                    else:
                        train_section.append((i, i.sentences[0].doc_id, lang))
                # guarantee at least one test document per file when possible
                if len(test_section) == 0 and len(train_section) >= 2:
                    idx = split_random.randint(0, len(train_section) - 1)
                    test_section = [train_section[idx]]
                    train_section = train_section[:idx] + train_section[idx+1:]
                print(" Splitting %d documents from %s for test" % (len(test_section), load))
                input_file.extend(train_section)
                test_sections.append(test_section)
            else:
                for i in docs:
                    input_file.append((i, i.sentences[0].doc_id, lang))
        print("Ingested %d total documents" % len(input_file))
        sections.append(input_file)

    if split_test:
        section_names = ('train', 'dev', 'test')
        full_test_section = []
        # NOTE(review): `filenames` here is left over from the loop above (the dev
        # list), so this zip only bounds the iteration; `filename` itself is unused
        for filename, test_section in zip(filenames, test_sections):
            # TODO: could write dataset-specific test sections as well
            full_test_section.extend(test_section)
        sections.append(full_test_section)


    for section_data, section_name in zip(sections, section_names):
        # only the training section gets augmentation
        converted_section = process_documents(section_data, augment=(section_name=="train"))

        os.makedirs(coref_output_path, exist_ok=True)
        output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section_name))
        with open(output_filename, "w", encoding="utf-8") as fout:
            json.dump(converted_section, fout, indent=2)

def get_dataset_by_language(coref_input_path, langs):
    """Glob the CorefUD-1.2 data directories for the train/dev files of the given languages.

    Returns (train_filenames, dev_filenames), each sorted for determinism.
    """
    conll_path = os.path.join(coref_input_path, "CorefUD-1.2-public", "data")
    train_filenames = []
    dev_filenames = []
    for lang in langs:
        train_filenames.extend(glob.glob(os.path.join(conll_path, "*%s*" % lang, "*train.conllu")))
        dev_filenames.extend(glob.glob(os.path.join(conll_path, "*%s*" % lang, "*dev.conllu")))
    train_filenames = sorted(train_filenames)
    dev_filenames = sorted(dev_filenames)
    return train_filenames, dev_filenames

def main():
    """Command line entry: convert a UDCoref dataset chosen by --project or --directory."""
    paths = get_default_paths()
    parser = argparse.ArgumentParser(
        prog='Convert UDCoref Data',
    )
    parser.add_argument('--split_test', default=None, type=float, help='How much of the data to randomly split from train to make a test set')

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--directory', type=str, help="the name of the subfolder for data conversion")
    group.add_argument('--project', type=str, help="Look for and use a set of datasets for data conversion - Slavic or Hungarian")

    args = parser.parse_args()
    coref_input_path = paths['COREF_BASE']
    coref_output_path = paths['COREF_DATA_DIR']

    if args.project:
        # NOTE(review): an unrecognized --project value falls through all of these
        # branches and will hit a NameError on `project` below — consider raising
        # a ValueError here instead
        if args.project == 'slavic':
            project = "slavic_udcoref"
            langs = ('Polish', 'Russian', 'Czech')
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
        elif args.project == 'hungarian':
            project = "hu_udcoref"
            langs = ('Hungarian',)
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
        elif args.project == 'gerrom':
            project = "gerrom_udcoref"
            langs = ('Catalan', 'English', 'French', 'German', 'Norwegian', 'Spanish')
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
        elif args.project == 'germanic':
            project = "germanic_udcoref"
            langs = ('English', 'German', 'Norwegian')
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
        elif args.project == 'norwegian':
            project = "norwegian_udcoref"
            langs = ('Norwegian',)
            train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs)
    else:
        project = args.directory
        conll_path = os.path.join(coref_input_path, project)
        # fall back to treating --directory as a literal path if it is not
        # a subdirectory of the coref input path
        if not os.path.exists(conll_path) and os.path.exists(project):
            conll_path = args.directory
        train_filenames = sorted(glob.glob(os.path.join(conll_path, f"*train.conllu")))
        dev_filenames = sorted(glob.glob(os.path.join(conll_path, f"*dev.conllu")))
    process_dataset(project, coref_output_path, args.split_test, train_filenames, dev_filenames)

if __name__ == '__main__':
    main()
from collections import defaultdict
from functools import lru_cache

class DynamicDepth():
    """
    Implements a cache + dynamic programming to find the relative depth of every word in a subphrase given the head word for every word.
    """
    def get_parse_depths(self, heads, start, end):
        """Return the relative depth for every word

        Args:
            heads (list): List where each entry is the index of that entry's head word in the dependency parse (or None for a root)
            start (int): starting index of the heads for the subphrase
            end (int): ending index of the heads for the subphrase

        Returns:
            list: Relative depth in the dependency parse for every word
        """
        self.heads = heads[start:end]
        # BUG FIX: this previously used `h - start if h else -100`, which
        # treated a legitimate head index of 0 the same as a missing (None)
        # head.  Only None marks a missing head; -100 keeps 'none' headwords
        # out of range so they are treated as relative roots.
        self.relative_heads = [h - start if h is not None else -100 for h in self.heads]

        # BUG FIX: the memo is rebuilt per call.  The previous @lru_cache on
        # the method kept results keyed only by (self, index), so reusing the
        # same DynamicDepth instance with different heads returned stale
        # depths, and the cache kept instances alive indefinitely.
        self._depth_cache = {}

        depths = [self._get_depth_recursive(h) for h in range(len(self.relative_heads))]

        return depths

    def _get_depth_recursive(self, index):
        """Recursively get the depths of every index using a cache and recursion

        Args:
            index (int): Index of the word for which to calculate the relative depth

        Returns:
            int: Relative depth of the word at the index
        """
        if index in self._depth_cache:
            return self._depth_cache[index]
        head = self.relative_heads[index]
        # if the head for the current index is outside the scope, this index is a relative root
        if head >= len(self.relative_heads) or head < 0:
            depth = 0
        else:
            depth = self._get_depth_recursive(head) + 1
        self._depth_cache[index] = depth
        return depth

def find_cconj_head(heads, upos, start, end):
    """
    Finds how far each word is from the head of a span, then uses the closest CCONJ to the head as the new head

    If no CCONJ is present, returns None
    """
    # use head information to extract parse depth
    dynamicDepth = DynamicDepth()
    depth = dynamicDepth.get_parse_depths(heads, start, end)
    depth_limit = 2

    # return first 'CCONJ' token above depth limit, if exists
    # unlike the original paper, we expect the parses to use UPOS, hence CCONJ instead of CC
    cc_indexes = [i for i in range(end - start) if upos[i+start] == 'CCONJ' and depth[i] < depth_limit]
    if cc_indexes:
        return cc_indexes[0] + start
    return None

def process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=True):
    """Reparse pretokenized sentences and build the coref model's json-ready dict.

    pipe: a stanza Pipeline (tokenize_pretokenized) used to get deprel/head/upos
    doc_id, part_id: identifiers copied into the output
    sentences: list of sentences, each a list of words
    coref_spans: a list of lists
      one list per sentence
      each sentence has a list of spans, where each span is (span_index, span_start, span_end)
    sentence_speakers: either one speaker per sentence, or a per-word list per sentence
    use_cconj_heads: if True, prefer a shallow CCONJ as the span head (see find_cconj_head)

    Returns a dict with cased_words, sent_id, speaker, deprel, head,
    span_clusters, word_clusters, and head2span.
    """
    sentence_lens = [len(x) for x in sentences]
    # speakers may come per-word (lists) or per-sentence (scalars); flatten either way
    if all(isinstance(x, list) for x in sentence_speakers):
        speaker = [y for x in sentence_speakers for y in x]
    else:
        speaker = [y for x, sent_len in zip(sentence_speakers, sentence_lens) for y in [x] * sent_len]

    cased_words = [y for x in sentences for y in x]
    sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]

    # use the trees to get the xpos tags
    # alternatively, could translate the pos_tags field,
    # but those have numbers, which is annoying
    #tree_text = "\n".join(x['parse_tree'] for x in paragraph)
    #trees = tree_reader.read_trees(tree_text)
    #pos = [x.label for tree in trees for x in tree.yield_preterminals()]
    # actually, the downstream code doesn't use pos at all.  maybe we can skip?

    doc = pipe(sentences)
    word_total = 0
    heads = []
    # TODO: does SD vs UD matter?
    deprel = []
    for sentence in doc.sentences:
        for word in sentence.words:
            deprel.append(word.deprel)
            if word.head == 0:
                # sentence roots are "null"; other heads become 0-based
                # document-wide word indices
                heads.append("null")
            else:
                heads.append(word.head - 1 + word_total)
        word_total += len(sentence.words)

    span_clusters = defaultdict(list)
    word_clusters = defaultdict(list)
    head2span = []
    word_total = 0
    for parsed_sentence, ontonotes_coref, ontonotes_words in zip(doc.sentences, coref_spans, sentences):
        sentence_upos = [x.upos for x in parsed_sentence.words]
        sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]
        for span in ontonotes_coref:
            # input is expected to be start word, end word + 1
            # counting from 0
            # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
            span_start = span[1] + word_total
            span_end = span[2] + word_total + 1
            candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) if use_cconj_heads else None
            if candidate_head is None:
                for candidate_head in range(span[1], span[2] + 1):
                    # stanza uses 0 to mark the head, whereas OntoNotes is counting
                    # words from 0, so we have to subtract 1 from the stanza heads
                    #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)
                    # treat the head of the phrase as the first word that has a head outside the phrase
                    if (parsed_sentence.words[candidate_head].head - 1 < span[1] or
                        parsed_sentence.words[candidate_head].head - 1 > span[2]):
                        break
                else:
                    # if none have a head outside the phrase (circular??)
                    # then just take the first word
                    candidate_head = span[1]
            #print("----> %d" % candidate_head)
            candidate_head += word_total
            span_clusters[span[0]].append((span_start, span_end))
            word_clusters[span[0]].append(candidate_head)
            head2span.append((candidate_head, span_start, span_end))
        word_total += len(ontonotes_words)
    span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
    word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
    head2span = sorted(head2span)

    processed = {
        "document_id": doc_id,
        "part_id": part_id,
        "cased_words": cased_words,
        "sent_id": sent_id,
        "speaker": speaker,
        #"pos": pos,
        "deprel": deprel,
        "head": heads,
        "span_clusters": span_clusters,
        "word_clusters": word_clusters,
        "head2span": head2span,
    }
    # note: the dict literal above already stores part_id; a redundant trailing
    # `if part_id is not None: processed["part_id"] = part_id` was removed
    return processed
help='Which treebanks to run on') + parser.add_argument('--output_dir', type=str, default='.', help='Where to put the results') + return parser + + +def write_segmenter_file(output_filename, dataset): + with open(output_filename, "w") as fout: + for sentence in dataset: + sentence = [x for x in sentence if not x.startswith("#")] + sentence = [x for x in [y.strip() for y in sentence] if x] + # eliminate MWE, although Chinese currently doesn't have any + sentence = [x for x in sentence if x.split("\t")[0].find("-") < 0] + + text = " ".join(x.split("\t")[1] for x in sentence) + fout.write(text) + fout.write("\n") + +def process_treebank(treebank, model_type, paths, output_dir): + with tempfile.TemporaryDirectory() as tokenizer_dir: + paths = dict(paths) + paths["TOKENIZE_DATA_DIR"] = tokenizer_dir + + short_name = treebank_to_short_name(treebank) + + # first we process the tokenization data + args = argparse.Namespace() + args.augment = False + args.prepare_labels = False + prepare_tokenizer_treebank.process_treebank(treebank, model_type, paths, args) + + # TODO: these names should be refactored + train_file = f"{tokenizer_dir}/{short_name}.train.gold.conllu" + dev_file = f"{tokenizer_dir}/{short_name}.dev.gold.conllu" + test_file = f"{tokenizer_dir}/{short_name}.test.gold.conllu" + + train_set = common.read_sentences_from_conllu(train_file) + dev_set = common.read_sentences_from_conllu(dev_file) + test_set = common.read_sentences_from_conllu(test_file) + + train_out = os.path.join(output_dir, f"{short_name}.train.seg.txt") + test_out = os.path.join(output_dir, f"{short_name}.test.seg.txt") + + write_segmenter_file(train_out, train_set + dev_set) + write_segmenter_file(test_out, test_set) + +def main(): + parser = build_argparse() + args = parser.parse_args() + + paths = default_paths.get_default_paths() + for treebank in args.treebanks: + process_treebank(treebank, common.ModelType.TOKENIZER, paths, args.output_dir) + +if __name__ == '__main__': + main() + diff --git 
"""
Converts the BSNLP 2019 shared task data to the two-column BIO format used by Stanza's NER models.

Entities are matched back against the raw text with a stanza tokenizer;
a number of hand-collected data fixes are applied for Bulgarian.
"""

import argparse
import glob
import os
import logging
import random
import re

import stanza

logger = logging.getLogger('stanza')

AVAILABLE_LANGUAGES = ("bg", "cs", "pl", "ru")

def normalize_bg_entity(text, entity, raw):
    """Return the form of `entity` that actually occurs in `text`, or None.

    Tries the literal entity, then various quote-mark substitutions, a
    case-insensitive search, and finally a table of known typos in the
    BG annotations.  `raw` is the raw filename, used only for logging.
    """
    entity = entity.strip()
    # sanity check that the token is in the original text
    if text.find(entity) >= 0:
        return entity

    # some entities have quotes, but the quotes are different from those in the data file
    # for example:
    #   training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_458.txt
    #     'Съвета "Общи въпроси"'
    #   training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1002.txt
    #     'Съвет "Общи въпроси"'
    if sum(1 for x in entity if x == '"') == 2:
        quote_entity = entity.replace('"', '“')
        if text.find(quote_entity) >= 0:
            logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
            return quote_entity

        quote_entity = entity.replace('"', '„', 1).replace('"', '“')
        if text.find(quote_entity) >= 0:
            logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
            return quote_entity

    if sum(1 for x in entity if x == '"') == 1:
        quote_entity = entity.replace('"', '„', 1)
        if text.find(quote_entity) >= 0:
            logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
            return quote_entity

    if entity.find("'") >= 0:
        quote_entity = entity.replace("'", "’")
        if text.find(quote_entity) >= 0:
            logger.info("searching for '%s' instead of '%s' in %s" % (quote_entity, entity, raw))
            return quote_entity

    lower_idx = text.lower().find(entity.lower())
    if lower_idx >= 0:
        fixed_entity = text[lower_idx:lower_idx+len(entity)]
        logger.info("lowercase match found. Searching for '%s' instead of '%s' in %s" % (fixed_entity, entity, raw))
        return fixed_entity

    substitution_pairs = {
        # this exact error happens in:
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_67.txt
        'Съвет по общи въпроси': 'Съвета по общи въпроси',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_214.txt
        'Сумимото Мицуи файненшъл груп': 'Сумитомо Мицуи файненшъл груп',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_245.txt
        'С и Д': 'С&Д',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_348.txt
        'законопроекта за излизане на Великобритания за излизане от Европейския съюз': 'законопроекта за излизане на Великобритания от Европейския съюз',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_771.txt
        'Унивеситета в Есекс': 'Университета в Есекс',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_779.txt
        'Съвет за сигурност на ООН': 'Съвета за сигурност на ООН',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_787.txt
        'Федерика Могерини': 'Федереика Могерини',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_938.txt
        'Уайстейбъл': 'Уайтстейбъл',
        'Партията за независимост на Обединеното кралство': 'Партията на независимостта на Обединеното кралство',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_972.txt
        'Европейска банка за възстановяване и развитие': 'Европейската банка за възстановяване и развитие',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1065.txt
        'Харолд Уилсон': 'Харолд Уилсън',
        'Манчестърски университет': 'Манчестърския университет',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1096.txt
        'Обединеното кралство в променящата се Европа': 'Обединеното кралство в променяща се Европа',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1175.txt
        'The Daily Express': 'Daily Express',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1186.txt
        'демократичната юнионистка партия': 'демократична юнионистка партия',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1192.txt
        'Европейската агенция за безопасността на полетите': 'Европейската агенция за сигурността на полетите',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1219.txt
        'пресцентъра на Външно министертво': 'пресцентъра на Външно министерство',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1281.txt
        'Европейска агенциа за безопасността на полетите': 'Европейската агенция за сигурността на полетите',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1415.txt
        'Хонк Конг': 'Хонг Конг',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1663.txt
        'Лейбъристка партия': 'Лейбъристката партия',
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1963.txt
        'Найджъл Фараж': 'Найджъл Фарадж',
        'Фараж': 'Фарадж',

        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1773.txt has an entity which is mixed Cyrillic and Ascii
        'Tescо': 'Tesco',
    }

    if entity in substitution_pairs and text.find(substitution_pairs[entity]) >= 0:
        fixed_entity = substitution_pairs[entity]
        logger.info("searching for '%s' instead of '%s' in %s" % (fixed_entity, entity, raw))
        return fixed_entity

    # oops, can't find it anywhere
    # want to raise ValueError but there are just too many in the train set for BG
    logger.error("Could not find '%s' in %s" % (entity, raw))

def fix_bg_typos(text, raw_filename):
    """Apply per-file typo corrections so entity strings can be found in the raw text."""
    typo_pairs = {
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_202.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters
        'brexit_bg.txt_file_202.txt': ('Вlооmbеrg', 'Bloomberg'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_261.txt has a typo: Telegaph instead of Telegraph
        'brexit_bg.txt_file_261.txt': ('Telegaph', 'Telegraph'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_574.txt has a typo: politicalskrapbook instead of politicalscrapbook
        'brexit_bg.txt_file_574.txt': ('politicalskrapbook', 'politicalscrapbook'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_861.txt has a mix of cyrillic and ascii
        'brexit_bg.txt_file_861.txt': ('Съвета „Общи въпроси“', 'Съветa "Общи въпроси"'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_992.txt is not exactly a typo, but the word is mixed cyrillic and ascii characters
        'brexit_bg.txt_file_992.txt': ('The Guardiаn', 'The Guardian'),
        # training_pl_cs_ru_bg_rc1/raw/bg/brexit_bg.txt_file_1856.txt has a typo: Southerb instead of Southern
        'brexit_bg.txt_file_1856.txt': ('Southerb', 'Southern'),
    }

    filename = os.path.split(raw_filename)[1]
    if filename in typo_pairs:
        replacement = typo_pairs.get(filename)
        text = text.replace(replacement[0], replacement[1])

    return text

def get_sentences(language, pipeline, annotated, raw):
    """Tokenize one raw file and project its annotated entities onto the tokens.

    annotated: path to the .out annotation file (entity TAB ... TAB type)
    raw: path to the matching raw text file
    Returns the tokenized sentences with token.ner set; sentences where an
    entity match crossed sentence or token boundaries are dropped.
    """
    if language == 'bg':
        normalize_entity = normalize_bg_entity
        fix_typos = fix_bg_typos
    else:
        # BUG FIX: the format string has two placeholders, so the argument must
        # be a tuple; previously `% language` raised a TypeError instead of the
        # intended AssertionError
        raise AssertionError("Please build a normalize_%s_entity and fix_%s_typos first" % (language, language))

    annotated_sentences = []
    # BUG FIX: encodings pinned to UTF-8 rather than the platform default
    with open(raw, encoding="utf-8") as fin:
        lines = fin.readlines()
        if len(lines) < 5:
            raise ValueError("Unexpected format in %s" % raw)
        # the first 4 lines are metadata; the rest is the document text
        text = "\n".join(lines[4:])
        text = fix_typos(text, raw)

    entities = {}
    with open(annotated, encoding="utf-8") as fin:
        # first line
        header = fin.readline().strip()
        if len(header.split("\t")) > 1:
            raise ValueError("Unexpected missing header line in %s" % annotated)
        for line in fin:
            pieces = line.strip().split("\t")
            if len(pieces) < 3 or len(pieces) > 4:
                raise ValueError("Unexpected annotation format in %s" % annotated)

            entity = normalize_entity(text, pieces[0], raw)
            if not entity:
                continue
            if entity in entities:
                if entities[entity] != pieces[2]:
                    # would like to make this an error, but it actually happens and it's not clear how to fix
                    # annotated/nord_stream/bg/nord_stream_bg.txt_file_119.out
                    logger.warning("found multiple definitions for %s in %s" % (pieces[0], annotated))
                    entities[entity] = pieces[2]
            else:
                entities[entity] = pieces[2]

    tokenized = pipeline(text)
    # The benefit of doing these one at a time, instead of all at once,
    # is that nested entities won't clobber previously labeled entities.
    # For example, the file
    # training_pl_cs_ru_bg_rc1/annotated/bg/brexit_bg.txt_file_994.out
    # has each of:
    #   Северна Ирландия
    #   Република Ирландия
    #   Ирландия
    # By doing the larger ones first, we can detect and skip the ones
    # we already labeled when we reach the shorter one
    regexes = [re.compile(re.escape(x)) for x in sorted(entities.keys(), key=len, reverse=True)]

    bad_sentences = set()

    for regex in regexes:
        for match in regex.finditer(text):
            start_char, end_char = match.span()
            # this is inefficient, but for something only run once, it shouldn't matter
            start_token = None
            start_sloppy = False
            end_token = None
            end_sloppy = False
            for token in tokenized.iter_tokens():
                if token.start_char <= start_char and token.end_char > start_char:
                    start_token = token
                    if token.start_char != start_char:
                        start_sloppy = True
                if token.start_char <= end_char and token.end_char >= end_char:
                    end_token = token
                    if token.end_char != end_char:
                        end_sloppy = True
                    break
            if start_token is None or end_token is None:
                raise RuntimeError("Match %s did not align with any tokens in %s" % (match.group(0), raw))
            # (was `if not ... is ...`; `is not` reads correctly)
            if start_token.sent is not end_token.sent:
                bad_sentences.add(start_token.sent.id)
                bad_sentences.add(end_token.sent.id)
                logger.warning("match %s spanned sentences %d and %d in document %s" % (match.group(0), start_token.sent.id, end_token.sent.id, raw))
                continue

            # ids start at 1, not 0, so we have to subtract 1
            # then the end token is included, so we add back the 1
            # TODO: verify that this is correct if the language has MWE - cs, pl, for example
            tokens = start_token.sent.tokens[start_token.id[0]-1:end_token.id[0]]
            if all(token.ner for token in tokens):
                # skip matches which have already been made
                # this has the nice side effect of not complaining if
                # a smaller match is found after a larger match
                # earlier set the NER on those tokens
                continue

            if start_sloppy and end_sloppy:
                bad_sentences.add(start_token.sent.id)
                logger.warning("match %s matched in the middle of a token in %s" % (match.group(0), raw))
                continue
            if start_sloppy:
                bad_sentences.add(end_token.sent.id)
                logger.warning("match %s started matching in the middle of a token in %s" % (match.group(0), raw))
                #print(start_token)
                #print(end_token)
                #print(start_char, end_char)
                continue
            if end_sloppy:
                bad_sentences.add(start_token.sent.id)
                logger.warning("match %s ended matching in the middle of a token in %s" % (match.group(0), raw))
                #print(start_token)
                #print(end_token)
                #print(start_char, end_char)
                continue
            match_text = match.group(0)
            if match_text not in entities:
                raise RuntimeError("Matched %s, which is not in the entities from %s" % (match_text, annotated))
            ner_tag = entities[match_text]
            tokens[0].ner = "B-" + ner_tag
            for token in tokens[1:]:
                token.ner = "I-" + ner_tag

    for sentence in tokenized.sentences:
        if not sentence.id in bad_sentences:
            annotated_sentences.append(sentence)

    return annotated_sentences

def write_sentences(output_filename, annotated_sentences):
    """Write sentences in two-column (token TAB tag) format, blank line between sentences."""
    logger.info("Writing %d sentences to %s" % (len(annotated_sentences), output_filename))
    with open(output_filename, "w", encoding="utf-8") as fout:
        for sentence in annotated_sentences:
            for token in sentence.tokens:
                ner_tag = token.ner
                if not ner_tag:
                    ner_tag = "O"
                fout.write("%s\t%s\n" % (token.text, ner_tag))
            fout.write("\n")


def convert_bsnlp(language, base_input_path, output_filename, split_filename=None):
    """
    Converts the BSNLP dataset for the given language.

    If only one output_filename is provided, all of the output goes to that file.
    If split_filename is provided as well, 15% of the output chosen randomly
    goes there instead.  The dataset has no dev set, so this helps
    divide the data into train/dev/test.
    Note that the custom error fixes are only done for BG currently.
    Please manually correct the data as appropriate before using this
    for another language.
    """
    if language not in AVAILABLE_LANGUAGES:
        raise ValueError("The current BSNLP datasets only include the following languages: %s" % ",".join(AVAILABLE_LANGUAGES))
    if language != "bg":
        raise ValueError("There were quite a few data fixes needed to get the data correct for BG.  Please work on similar fixes before using the model for %s" % language.upper())
    pipeline = stanza.Pipeline(language, processors="tokenize")
    random.seed(1234)

    annotated_path = os.path.join(base_input_path, "annotated", "*", language, "*")
    annotated_files = sorted(glob.glob(annotated_path))
    raw_path = os.path.join(base_input_path, "raw", "*", language, "*")
    raw_files = sorted(glob.glob(raw_path))

    # if the instructions for downloading the data from the
    # process_ner_dataset script are followed, there will be two test
    # directories of data and a separate training directory of data.
    if len(annotated_files) == 0 and len(raw_files) == 0:
        logger.info("Could not find files in %s" % annotated_path)
        annotated_path = os.path.join(base_input_path, "annotated", language, "*")
        logger.info("Trying %s instead" % annotated_path)
        annotated_files = sorted(glob.glob(annotated_path))
        raw_path = os.path.join(base_input_path, "raw", language, "*")
        raw_files = sorted(glob.glob(raw_path))

    if len(annotated_files) != len(raw_files):
        raise ValueError("Unexpected differences in the file lists between %s and %s" % (annotated_files, raw_files))

    # annotated files end in .out, raw files in .txt: compare basenames minus extension
    for i, j in zip(annotated_files, raw_files):
        if os.path.split(i)[1][:-4] != os.path.split(j)[1][:-4]:
            raise ValueError("Unexpected differences in the file lists: found %s instead of %s" % (i, j))

    annotated_sentences = []
    if split_filename:
        split_sentences = []
    for annotated, raw in zip(annotated_files, raw_files):
        new_sentences = get_sentences(language, pipeline, annotated, raw)
        if not split_filename or random.random() < 0.85:
            annotated_sentences.extend(new_sentences)
        else:
            split_sentences.extend(new_sentences)

    write_sentences(output_filename, annotated_sentences)
    if split_filename:
        write_sentences(split_filename, split_sentences)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--language', type=str, default="bg", help="Language to process")
    parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/bsnlp2019", help="Where to find the files")
    parser.add_argument('--output_path', type=str, default="/home/john/stanza/data/ner/bg_bsnlp.test.csv", help="Where to output the results")
    parser.add_argument('--dev_path', type=str, default=None, help="A secondary output path - 15% of the data will go here")
    args = parser.parse_args()

    convert_bsnlp(args.language, args.input_path, args.output_path, args.dev_path)
def normalize(e1, e2, e3):
    """
    Map a (primary, secondary, tertiary) FIRE label triple to a single tag.

    Returns "O" for untagged tokens and for categories we do not keep.
    DATE/TIME/YEAR are collapsed into a single DATETIME class.
    Raises ValueError when the layered labels are inconsistent.
    """
    if e1 == 'o':
        return "O"

    # sanity-check that all the non-empty layers agree on the B-/I- position
    if e2 != 'o' and e1[:2] != e2[:2]:
        raise ValueError("Found a token with conflicting position tags %s,%s" % (e1, e2))
    if e3 != 'o' and e2 == 'o':
        raise ValueError("Found a token with tertiary label but no secondary label %s,%s,%s" % (e1, e2, e3))
    if e3 != 'o' and (e1[:2] != e2[:2] or e1[:2] != e3[:2]):
        raise ValueError("Found a token with conflicting position tags %s,%s,%s" % (e1, e2, e3))

    position, category = e1[:2], e1[2:]
    secondary = e2[2:]

    # the categories we keep as-is, possibly conditioned on the second layer
    keep = (category in ('ORGANIZATION', 'FACILITIES')
            or (category == 'ENTERTAINMENT' and secondary not in ('SPORTS', 'CINEMA'))
            or (category == 'DISEASE' and e2 == 'o')
            or (category == 'PLANTS' and secondary != 'PARTS')
            or (category == 'PERSON' and secondary == 'INDIVIDUAL')
            or (category == 'LOCATION' and secondary == 'PLACE'))
    if keep:
        return e1
    # the three time-like classes collapse into one
    if category in ('DATE', 'TIME', 'YEAR'):
        return position + 'DATETIME'
    return "O"
def write_fileset(output_csv_file, sentences):
    """
    Write sentences (lists of raw six-column lines) as two-column TSV.

    Keeps the word (column 0) and the normalized primary NER label
    (columns 3-5 run through normalize).  Raises ValueError on rows with
    the wrong column count or with inner NER labels under an O top layer.
    """
    # explicit utf-8, consistent with the rest of the dataset scripts
    with open(output_csv_file, "w", encoding="utf-8") as fout:
        for sentence in sentences:
            for line in sentence:
                pieces = line.split("\t")
                if len(pieces) != 6:
                    raise ValueError("Found %d pieces instead of the expected 6" % len(pieces))
                if pieces[3] == 'o' and (pieces[4] != 'o' or pieces[5] != 'o'):
                    raise ValueError("Inner NER labeled but the top layer was O")
                fout.write("%s\t%s\n" % (pieces[0], normalize(pieces[3], pieces[4], pieces[5])))
            fout.write("\n")
def get_label(tok_start_char: int, tok_end_char: int, labels: list) -> list:
    """
    Return the first label span that fully contains the given token.

    A span [start, end, tag] matches when start <= tok_start_char and
    end >= tok_end_char.  Returns [] when no span contains the token.
    """
    return next((span for span in labels
                 if span[0] <= tok_start_char and span[1] >= tok_end_char),
                [])
def write_sentences_to_file(sents, filename):
    """
    Write already-formatted BIOES sentence strings to filename.

    Each sentence string already contains its own token lines; a blank
    line is added between sentences.
    """
    # the original message printed a literal "(unknown)" instead of the
    # destination path, which made the log useless for debugging
    print(f"Writing {len(sents)} sentences to {filename}")
    with open(filename, 'w', encoding='utf-8') as outfile:
        for sent in sents:
            outfile.write(sent + '\n\n')
def train_test_dev_split(sents, base_output_path, short_name, train_fraction=0.7, dev_fraction=0.15):
    """
    Splits a list of sentences into training, dev, and test sets,
    and writes each set to a separate file with write_sentences_to_file

    Raises ValueError if train_fraction + dev_fraction exceeds 1.0.
    Note that the input list is shuffled in place.
    """
    # validate before shuffling the caller's list or doing any work.
    # The original format string had three placeholders for two arguments,
    # so the raise itself crashed with an IndexError instead of ValueError.
    if train_fraction + dev_fraction > 1.0:
        raise ValueError(
            "Train and dev fractions added up to more than 1: {} {}".format(train_fraction, dev_fraction))

    num = len(sents)
    train_num = int(num * train_fraction)
    dev_num = int(num * dev_fraction)

    random.shuffle(sents)
    train_sents = sents[:train_num]
    dev_sents = sents[train_num:train_num + dev_num]
    test_sents = sents[train_num + dev_num:]
    batches = [train_sents, dev_sents, test_sents]
    filenames = [f'{short_name}.train.tsv', f'{short_name}.dev.tsv', f'{short_name}.test.tsv']
    for batch, filename in zip(batches, filenames):
        write_sentences_to_file(batch, os.path.join(base_output_path, filename))
def convert_dataset(in_directory, out_directory, short_name):
    """
    Reads in train, validation, and test data and converts them to .json file

    The upstream IOB2 files are copied into the output directory under
    our standard shard names, then rewritten as .json
    """
    # make sure the destination exists before copying, matching the
    # behavior of the other converters in this package
    os.makedirs(out_directory, exist_ok=True)
    filenames = ("IOB2_train.txt", "IOB2_valid.txt", "IOB2_test.txt")
    for shard, filename in zip(SHARDS, filenames):
        input_filename = os.path.join(in_directory, filename)
        output_filename = os.path.join(out_directory, "%s.%s.bio" % (short_name, shard))
        shutil.copy(input_filename, output_filename)
    convert_bio_to_json(out_directory, out_directory, short_name, "bio")
0000000000000000000000000000000000000000..686792a4e8939fce654552cb04b8bca6f0ce14e1 --- /dev/null +++ b/stanza/stanza/utils/datasets/ner/convert_my_ucsy.py @@ -0,0 +1,102 @@ +""" +Processes the three pieces of the NER dataset we received from UCSY. + +Requires the Myanmar tokenizer to exist, since the text is not already tokenized. + +There are three files sent to us from UCSY, one each for train, dev, test +This script expects them to be in the ner directory with the names + $NERBASE/my_ucsy/Myanmar_NER_train.txt + $NERBASE/my_ucsy/Myanmar_NER_dev.txt + $NERBASE/my_ucsy/Myanmar_NER_test.txt + +The files are in the following format: + unsegmentedtext@LABEL|unsegmentedtext@LABEL|... +with one sentence per line + +Solution: + - break the text up into fragments by splitting on | + - extract the labels + - segment each block of text using the MY tokenizer + +We could take two approaches to breaking up the blocks. One would be +to combine all chunks, then segment an entire sentence at once. This +would require some logic to re-chunk the resulting pieces. Instead, +we resegment each individual chunk by itself. This loses the +information from the neighboring chunks, but guarantees there are no +screwups where segmentation crosses segment boundaries and is simpler +to code. + +Of course, experimenting with the alternate approach might be better. + +There is one stray label of SB in the training data, so we throw out +that entire sentence. 
def convert_file(input_filename, output_filename, pipe):
    """
    Convert one UCSY file (unsegmentedtext@LABEL|... per line) to two-column BIO.

    Each @-labeled chunk is retokenized with the given pipeline; the
    chunk label becomes B-/I- tags on the resulting words.  Sentences
    containing the stray SB label are dropped entirely.
    """
    # explicit utf-8 on both files, consistent with the other converters
    with open(input_filename, encoding="utf-8") as fin:
        lines = fin.readlines()

    all_labels = set()

    with open(output_filename, "w", encoding="utf-8") as fout:
        for line in tqdm(lines):
            pieces = line.split("|")
            texts = []
            labels = []
            skip_sentence = False
            for piece in pieces:
                piece = piece.strip()
                if not piece:
                    continue
                # rsplit in case the text itself contains an @
                text, label = piece.rsplit("@", maxsplit=1)
                text = text.strip()
                if not text:
                    continue
                if label == 'SB':
                    # one stray SB label exists in the training data;
                    # throw out the whole sentence
                    skip_sentence = True
                    break

                texts.append(text)
                labels.append(label)

            if skip_sentence:
                continue

            # tokenize all chunks in one pipeline call; the \n\n separators
            # keep one "sentence" per chunk (pipe uses tokenize_no_ssplit)
            text = "\n\n".join(texts)
            doc = pipe(text)
            assert len(doc.sentences) == len(texts)
            for sentence, label in zip(doc.sentences, labels):
                all_labels.add(label)
                for word_idx, word in enumerate(sentence.words):
                    if label == "O":
                        output_label = "O"
                    elif word_idx == 0:
                        output_label = "B-" + label
                    else:
                        output_label = "I-" + label

                    fout.write("%s\t%s\n" % (word.text, output_label))
                # NOTE(review): this writes a double blank line between
                # sentences - confirm the downstream reader expects that
                fout.write("\n\n")

    print("Finished processing {} Labels found: {}".format(input_filename, sorted(all_labels)))
def fix_sentence(sentence):
    """
    Repair the known mistags in the SiNER data.

    This covers 11 sentences: 1 P-PERSON, 2 tags truncated by line
    breaks, and 8 bare entity names with no B- or I- prefix.
    """
    repaired = []
    for token, tag in sentence:
        if tag == 'P-PERSON':
            tag = 'B-PERSON'
        elif tag == 'B-OT"':
            tag = 'B-OTHERS'
        elif tag == 'B-T"':
            tag = 'B-TITLE'
        elif tag in ('GPE', 'LOC', 'OTHERS'):
            prev_tag = repaired[-1][1] if repaired else ''
            if prev_tag[:2] in ('B-', 'I-') and prev_tag[2:] == tag:
                # one example... no idea if it should be a break or
                # not, but the last word translates to "Corporation",
                # so probably not: ميٽرو پوليٽن ڪارپوريشن
                tag = 'I-' + tag
            else:
                tag = 'B-' + tag
        repaired.append((token, tag))
    return repaired
def convert_sindhi_siner(in_filename, out_directory, short_name, train_frac=0.8, dev_frac=0.1):
    """
    Read lines from the dataset, crudely separate sentences based on . or !, and write the dataset
    """
    with open(in_filename, encoding="utf-8") as fin:
        raw_lines = fin.readlines()

    # keep only lines with exactly two tab-separated columns
    pairs = []
    for raw in raw_lines:
        columns = raw.strip().split("\t")
        if len(columns) == 2:
            pairs.append((columns[0].strip(), columns[1].strip()))
    print("Read %d words from %s" % (len(pairs), in_filename))

    sentences = []
    start = 0
    for idx, (token, _) in enumerate(pairs):
        # maybe also handle the Arabic comma '،' as a sentence break?
        if token in ('.', '!'):
            sentences.append(pairs[start:idx + 1])
            start = idx + 1

    # in case the file doesn't end with punctuation, grab the last few lines
    if start < len(pairs):
        sentences.append(pairs[start:])

    print("Found %d sentences before splitting" % len(sentences))
    sentences = [fix_sentence(sentence) for sentence in sentences]
    # verify that fix_sentence cleaned up every known mistag
    assert not any('"' in tag or tag.startswith("P-") or tag in ("GPE", "LOC", "OTHERS")
                   for sentence in sentences for _, tag in sentence)

    train_cutoff = int(len(sentences) * train_frac)
    dev_cutoff = int(len(sentences) * (train_frac + dev_frac))
    splits = (sentences[:train_cutoff],
              sentences[train_cutoff:dev_cutoff],
              sentences[dev_cutoff:])

    write_dataset(splits, out_directory, short_name, suffix="bio")
def read_tree(text):
    """
    Reads in a tree, then extracts the word and the NER

    One problem is that it is unknown if there are cases of two separate items occurring consecutively

    Note that this is quite similar to the convert_starlang script for constituency.
    """
    trees = tree_reader.read_trees(text)
    if len(trees) > 1:
        raise ValueError("Tree file had two trees!")

    words = []
    for label in trees[0].leaf_labels():
        word_match = TURKISH_WORD_RE.search(label)
        if word_match is None:
            raise ValueError("Could not find word in |{}|".format(label))
        # undo the bracket escaping used in the treebank
        word = word_match.group(1).replace("-LCB-", "{").replace("-RCB-", "}")

        tag_match = TURKISH_LABEL_RE.search(label)
        if tag_match is None:
            raise ValueError("Could not find ner in |{}|".format(label))
        tag = tag_match.group(1)
        # both NONE and null mean "no entity"
        words.append((word, 'O' if tag in ('NONE', 'null') else tag))

    return words
def convert_worldwide_file(filename, bigger_first):
    """
    Add a two-layer multi_ner field to a WorldWide 9-class json file.

    The WorldWide tag fills one layer and the other layer is left blank
    ("-"); bigger_first controls which layer holds the real tags.  The
    result is written next to the input with -multi in the name.
    """
    assert "en_worldwide-9class" in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)

    new_filename = filename.replace("en_worldwide-9class", "en_worldwide-9class-multi")

    with open(filename) as fin:
        doc = json.load(fin)

    for sentence in doc:
        for word in sentence:
            tag = word['ner']
            word['multi_ner'] = ("-", tag) if bigger_first else (tag, "-")

    with open(new_filename, "w") as fout:
        json.dump(doc, fout, indent=2)
def build_multitag_dataset(base_output_path, short_name, simplify, bigger_first):
    """
    Combine OntoNotes and WorldWide into one multi-tag dataset.

    Each shard of both source datasets is rewritten with a two-layer
    multi_ner column, then the train shards are concatenated while the
    OntoNotes dev/test shards are used unchanged for evaluation.
    """
    # rewrite all three shards with loops instead of six repeated calls
    for shard in ("train", "dev", "test"):
        convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.%s.json" % shard), simplify, bigger_first)
        convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.%s.json" % shard), bigger_first)

    # train is the concatenation of both datasets
    combine_files(os.path.join(base_output_path, "%s.train.json" % short_name),
                  os.path.join(base_output_path, "en_ontonotes-multi.train.json"),
                  os.path.join(base_output_path, "en_worldwide-9class-multi.train.json"))
    # dev and test come from OntoNotes only
    for shard in ("dev", "test"):
        shutil.copyfile(os.path.join(base_output_path, "en_ontonotes-multi.%s.json" % shard),
                        os.path.join(base_output_path, "%s.%s.json" % (short_name, shard)))
def process_dataset(input_filename, output_filename):
    """
    Convert a CoNLL03-style column file into the json document format.

    Reads sentences with load_conll03, turns each (words, tags) pair
    into a list of {'text', 'ner'} dicts, and dumps the document as json.
    """
    sentences = load_conll03(input_filename)
    print("{} examples loaded from {}".format(len(sentences), input_filename))

    document = [[{'text': word, 'ner': tag} for word, tag in zip(words, tags)]
                for words, tags in sentences]

    with open(output_filename, 'w', encoding="utf-8") as outfile:
        json.dump(document, outfile, indent=1)
    print("Generated json file {}".format(output_filename))
def process_cache(cached_lines):
    """
    Turn a buffered block of conll03 lines into (tokens, ner_tags).

    Each line is split on tabs, falling back to whitespace splitting
    when tabs yield too few columns; the first column is the token and
    the last column is the NER tag.
    """
    tokens = []
    ner_tags = []
    for line in cached_lines:
        columns = line.split("\t")
        if len(columns) < MIN_NUM_FIELD:
            columns = line.split()
        assert MIN_NUM_FIELD <= len(columns) <= MAX_NUM_FIELD, "Got unexpected line length: {}".format(columns)
        tokens.append(columns[0])
        ner_tags.append(columns[-1])
    return (tokens, ner_tags)
def convert_bioes_to_bio(base_input_path, base_output_path, short_name):
    """
    Convert BIOES files back to BIO (not BIO2)

    Useful for preparing datasets for CoreNLP, which doesn't do great with the more highly split classes
    """
    for shard in SHARDS:
        in_file = os.path.join(base_input_path, '%s.%s.bioes' % (short_name, shard))
        out_file = os.path.join(base_output_path, '%s.%s.bio' % (short_name, shard))

        converted = []
        for sentence in read_tsv(in_file, text_column=0, annotation_column=1):
            # remap only the tag column, keeping the words as-is
            bio_tags = bioes_to_bio([tag for _, tag in sentence])
            converted.append([(word, tag) for (word, _), tag in zip(sentence, bio_tags)])
        write_sentences(out_file, converted)
def write_sentences(output_filename, dataset):
    """
    Write exactly one output file worth of dataset

    Each sentence is written as word<TAB>tag lines followed by a blank
    line.  Entries may be tuples or lists with two or more columns;
    only the first two are written.
    """
    os.makedirs(os.path.split(output_filename)[0], exist_ok=True)
    with open(output_filename, "w", encoding="utf-8") as fout:
        for sent_idx, sentence in enumerate(dataset):
            for word_idx, word in enumerate(sentence):
                # "%s\t%s" % word requires a 2-tuple, but read_tsv with
                # keep_all_columns=True produces lists (and the old code
                # only trimmed entries longer than 2, so 2-element lists
                # still crashed).  Normalize everything to a 2-tuple.
                word = tuple(word[:2])
                try:
                    fout.write("%s\t%s\n" % word)
                except TypeError:
                    raise TypeError("Unable to process sentence %d word %d of file %s" % (sent_idx, word_idx, output_filename))
            fout.write("\n")
def read_tsv(filename, text_column, annotation_column, remap_fn=None, skip_comments=True, keep_broken_tags=False, keep_all_columns=False, separator="\t"):
    """
    Read sentences from a TSV file

    Returns a list of list of (word, tag)

    filename: file to read, one token per line, blank lines between sentences
    text_column: index of the column holding the word itself
    annotation_column: index of the column holding the tag
    remap_fn: if set, each (non-missing) tag is passed through this function
    skip_comments: skip lines starting with "#"
    keep_broken_tags: if True, None is used for a missing tag.
      Otherwise, an IndexError is thrown
    keep_all_columns: if True, keep the entire column list for each word
      (with the tag remapped in place) instead of just (word, tag)
    separator: the column separator
    """
    with open(filename, encoding="utf-8") as fin:
        lines = fin.readlines()

    lines = [x.strip() for x in lines]

    sentences = []
    current_sentence = []
    for line_idx, line in enumerate(lines):
        if not line:
            # blank line terminates the current sentence, if any
            if current_sentence:
                sentences.append(current_sentence)
                current_sentence = []
            continue
        if skip_comments and line.startswith("#"):
            continue

        pieces = line.split(separator)
        try:
            word = pieces[text_column]
        except IndexError as e:
            raise IndexError("Could not find word index %d at line %d |%s|" % (text_column, line_idx, line)) from e
        if word == '\x96':
            # this happens in GermEval2014 for some reason
            continue
        try:
            tag = pieces[annotation_column]
        except IndexError as e:
            if keep_broken_tags:
                tag = None
            else:
                raise IndexError("Could not find tag index %d at line %d |%s|" % (annotation_column, line_idx, line)) from e
        # don't remap a missing (None) tag - remap functions generally
        # expect a string and would crash on None
        if remap_fn and tag is not None:
            tag = remap_fn(tag)

        if keep_all_columns:
            pieces[annotation_column] = tag
            current_sentence.append(pieces)
        else:
            current_sentence.append((word, tag))

    if current_sentence:
        sentences.append(current_sentence)

    return sentences
def random_shuffle_by_prefixes(input_dir, output_dir, short_name, prefix_map):
    """
    Assign each file in input_dir to a named division, then shuffle each division into train/dev/test.

    input_dir: directory of .tsv files to process
    output_dir: where the shuffled datasets are written
    short_name: dataset short name; each division is written as "<short_name>-<division>"
      and then everything is combined into a single "<short_name>" dataset
    prefix_map: maps division name -> list of filename prefixes.
      Every file must match a prefix of exactly one division.

    Raises ValueError if a file matches none of the prefixes.
    """
    input_files = os.listdir(input_dir)
    # sort so that the assignment is deterministic regardless of listdir order
    input_files = sorted(input_files)

    file_divisions = defaultdict(list)
    for filename in input_files:
        # nested for/else: the inner loop breaks when a prefix matches,
        # which skips the inner else and falls through to the outer break.
        # If no prefix in this division matches, the inner else continues
        # to the next division.  If no division matches at all, the outer
        # else raises.
        for division in prefix_map.keys():
            for prefix in prefix_map[division]:
                if filename.startswith(prefix):
                    break
            else: # for/else is intentional
                continue
            break
        else: # yes, stop asking
            raise ValueError("Could not assign %s to any of the divisions in the prefix_map" % filename)
        #print("Assigning %s to %s because of %s" % (filename, division, prefix))
        file_divisions[division].append(filename)

    num_train_files = 0
    num_dev_files = 0
    num_test_files = 0
    for division in file_divisions.keys():
        print()
        print("Processing %d files from %s" % (len(file_divisions[division]), division))
        # shuffle each division independently so its train/dev/test split
        # only depends on the filenames in that division
        d_train, d_dev, d_test = random_shuffle_files(input_dir, file_divisions[division], output_dir, "%s-%s" % (short_name, division))
        num_train_files += d_train
        num_dev_files += d_dev
        num_test_files += d_test

    print()
    print("After shuffling: Train files: %d Dev files: %d Test files: %d" % (num_train_files, num_dev_files, num_test_files))
    # merge the per-division datasets back into one combined dataset
    dataset_divisions = ["%s-%s" % (short_name, division) for division in file_divisions]
    combine_dataset(output_dir, output_dir, dataset_divisions, short_name)
+ """ + destination = None + known_prefixes = set() + prefixes = [] + + prefix_map = {} + with open(destination_file, encoding="utf-8") as fin: + for line in fin: + line = line.strip() + if line.startswith("#"): + continue + if not line: + continue + if line.endswith(":"): + if destination is not None: + prefix_map[destination] = prefixes + prefixes = [] + destination = line[:-1].strip().lower().replace(" ", "_") + else: + if not destination: + raise RuntimeError("Found a prefix before the first label was assigned when reading %s" % destination_file) + prefixes.append(line) + if line in known_prefixes: + raise RuntimeError("Found the same prefix twice! %s" % line) + known_prefixes.add(line) + + if destination and prefixes: + prefix_map[destination] = prefixes + + return prefix_map + +def read_json_entities(filename): + """ + Read entities from a file, return a list of (text, label) + + Should work on both BIOES and BIO + """ + with open(filename) as fin: + doc = Document(json.load(fin)) + + return list_doc_entities(doc) + +def list_doc_entities(doc): + """ + Return a list of (text, label) + + Should work on both BIOES and BIO + """ + entities = [] + for sentence in doc.sentences: + current_entity = [] + previous_label = None + for token in sentence.tokens: + if token.ner == 'O' or token.ner.startswith("E-"): + if token.ner.startswith("E-"): + current_entity.append(token.text) + if current_entity: + assert previous_label is not None + entities.append((current_entity, previous_label)) + current_entity = [] + previous_label = None + elif token.ner.startswith("I-"): + if previous_label is not None and previous_label != 'O' and previous_label != token.ner[2:]: + if current_entity: + assert previous_label is not None + entities.append((current_entity, previous_label)) + current_entity = [] + previous_label = token.ner[2:] + current_entity.append(token.text) + elif token.ner.startswith("B-") or token.ner.startswith("S-"): + if current_entity: + assert previous_label is 
not None + entities.append((current_entity, previous_label)) + current_entity = [] + previous_label = None + current_entity.append(token.text) + previous_label = token.ner[2:] + if token.ner.startswith("S-"): + assert previous_label is not None + entities.append(current_entity) + current_entity = [] + previous_label = None + else: + raise RuntimeError("Expected BIO(ES) format in the json file!") + previous_label = token.ner[2:] + if current_entity: + assert previous_label is not None + entities.append((current_entity, previous_label)) + entities = [(tuple(x[0]), x[1]) for x in entities] + return entities + +def combine_files(output_filename, *input_filenames): + """ + Combine multiple NER json files into one NER file + """ + doc = [] + + for filename in input_filenames: + with open(filename) as fin: + new_doc = json.load(fin) + doc.extend(new_doc) + + with open(output_filename, "w") as fout: + json.dump(doc, fout, indent=2) + diff --git a/stanza/stanza/utils/datasets/vietnamese/__init__.py b/stanza/stanza/utils/datasets/vietnamese/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stanza/stanza/utils/pretrain/compare_pretrains.py b/stanza/stanza/utils/pretrain/compare_pretrains.py new file mode 100644 index 0000000000000000000000000000000000000000..4a498ea52437abdfad9cfa5d9d6f079f75c65a3b --- /dev/null +++ b/stanza/stanza/utils/pretrain/compare_pretrains.py @@ -0,0 +1,54 @@ +import sys +import numpy as np + +from stanza.models.common.pretrain import Pretrain + +pt1_filename = sys.argv[1] +pt2_filename = sys.argv[2] + +pt1 = Pretrain(pt1_filename) +pt2 = Pretrain(pt2_filename) + +vocab1 = pt1.vocab +vocab2 = pt2.vocab + +common_words = [x for x in vocab1 if x in vocab2] +print("%d shared words, out of %d in %s and %d in %s" % (len(common_words), len(vocab1), pt1_filename, len(vocab2), pt2_filename)) + +eps = 0.0001 +total_norm = 0.0 +total_close = 0 + +words_different = [] + +for 
class ArgumentParserWithExtraHelp(argparse.ArgumentParser):
    """
    An ArgumentParser which appends the help of a second, model-specific parser to its own help text.
    """
    def __init__(self, sub_argparse, *args, **kwargs):
        # sub_argparse: the model's own ArgumentParser (or None); its help
        # is spliced into ours in format_help
        super().__init__(*args, **kwargs) # forwards all unused arguments

        self.sub_argparse = sub_argparse

    def print_help(self, file=None):
        # delegates to the parent; format_help below supplies the extra text
        super().print_help(file=file)

    def format_help(self):
        """Return the usual help text, followed by the sub-parser's options under a "model arguments:" heading."""
        help_text = super().format_help()
        if self.sub_argparse is not None:
            sub_text = self.sub_argparse.format_help().split("\n")
            # scan for the last line of the sub-parser's "usage:" block:
            # remember where "usage:" starts, then stop at the first blank
            # line after it, so the usage block itself is skipped and only
            # the option descriptions are appended
            first_line = -1
            for line_idx, line in enumerate(sub_text):
                if line.strip().startswith("usage:"):
                    first_line = line_idx
                elif first_line >= 0 and not line.strip():
                    first_line = line_idx
                    break
            help_text = help_text + "\n\nmodel arguments:" + "\n".join(sub_text[first_line:])
        return help_text
def add_charlm_args(parser):
    """
    Add the charlm selection arguments to an argparse parser.

    --charlm picks a specific charlm (or the per-language default);
    --no_charlm turns the charlm off entirely by forcing the value to None.
    """
    parser.add_argument(
        '--charlm',
        type=str,
        default="default",
        help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')
    parser.add_argument(
        '--no_charlm',
        action="store_const",
        dest='charlm',
        const=None,
        help="Don't use a charlm, even if one is used by default for this package")
+ It also tries to look for an existing model and not overwrite it unless --force is provided + + model_name can be a callable expecting the args + - the charlm, for example, needs this feature, since it makes + both forward and backward models + """ + if args is None: + logger.info("Training program called with:\n" + " ".join(sys.argv)) + args = sys.argv[1:] + else: + logger.info("Training program called with:\n" + " ".join(args)) + + paths = default_paths.get_default_paths() + + parser = build_argparse(sub_argparse) + if add_specific_args is not None: + add_specific_args(parser) + if '--extra_args' in sys.argv: + idx = sys.argv.index('--extra_args') + extra_args = sys.argv[idx+1:] + command_args = parser.parse_args(sys.argv[:idx]) + else: + command_args, extra_args = parser.parse_known_args(args=args) + + # Pass this through to the underlying model as well as use it here + # we don't put --save_name here for the awkward situation of + # --save_name being specified for an invocation with multiple treebanks + if command_args.save_dir: + extra_args.extend(["--save_dir", command_args.save_dir]) + + if callable(model_name): + model_name = model_name(command_args) + + mode = command_args.mode + treebanks = [] + + for treebank in command_args.treebanks: + # this is a really annoying typo to make if you copy/paste a + # UD directory name on the cluster and your job dies 30s after + # being queued for an hour + if treebank.endswith("/"): + treebank = treebank[:-1] + if treebank.lower() in ('ud_all', 'all_ud'): + ud_treebanks = common.get_ud_treebanks(paths["UDBASE"]) + if choose_charlm_method is not None and command_args.charlm_only: + logger.info("Filtering ud_all treebanks to only those which can use charlm for this model") + ud_treebanks = [x for x in ud_treebanks + if choose_charlm_method(*treebank_to_short_name(x).split("_", 1), 'default') is not None] + if command_args.transformer_only: + logger.info("Filtering ud_all treebanks to only those which can use a 
transformer for this model") + ud_treebanks = [x for x in ud_treebanks if treebank_to_short_name(x).split("_")[0] in TRANSFORMERS] + logger.info("Expanding %s to %s", treebank, " ".join(ud_treebanks)) + treebanks.extend(ud_treebanks) + else: + treebanks.append(treebank) + + for treebank_idx, treebank in enumerate(treebanks): + if treebank_idx > 0: + logger.info("=========================================") + + short_name = treebank_to_short_name(treebank) + logger.debug("%s: %s" % (treebank, short_name)) + + save_name_args = [] + if model_name != 'ete': + # ete is several models at once, so we don't set --save_name + # theoretically we could handle a parametrized save_name + if command_args.save_name: + save_name = command_args.save_name + # if there's more than 1 treebank, we can't save them all to this save_name + # we have to override that value for each treebank + if len(treebanks) > 1: + save_name_dir, save_name_filename = os.path.split(save_name) + save_name_filename = "%s_%s" % (short_name, save_name_filename) + save_name = os.path.join(save_name_dir, save_name_filename) + logger.info("Save file for %s model for %s: %s", short_name, treebank, save_name) + save_name_args = ['--save_name', save_name] + # some run scripts can build the model filename + # in order to check for models that are already created + elif build_model_filename is None: + save_name = "%s_%s.pt" % (short_name, model_name) + logger.info("Save file for %s model: %s", short_name, save_name) + save_name_args = ['--save_name', save_name] + else: + save_name_args = [] + + if mode == Mode.TRAIN and not command_args.force: + if build_model_filename is not None: + model_path = build_model_filename(paths, short_name, command_args, extra_args) + elif command_args.save_dir: + model_path = os.path.join(command_args.save_dir, save_name) + else: + save_dir = os.path.join("saved_models", model_dir) + save_name_args.extend(["--save_dir", save_dir]) + model_path = os.path.join(save_dir, save_name) + + if 
def run_eval_script(gold_conllu_file, system_conllu_file, evals=None):
    """
    Score a system CoNLL-U file against the gold file.  Wrapper for the lemma scorer.

    If evals is None, returns the full evaluation table produced by the
    UD eval script.  Otherwise returns a two line string: the requested
    metric names, then their F1 scores as percentages, each padded to a
    common column width.
    """
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)

    if evals is None:
        return ud_eval.build_evaluation_table(evaluation, verbose=True, counts=False, enhanced=False)

    # column width: at least 5, and wide enough for the longest metric name
    width = max(5, max(len(name) for name in evals))
    header = " ".join(f"{name:>{width}}" for name in evals)
    scores = " ".join(f"{100 * evaluation[name].f1:{width}.2f}" for name in evals)
    return header + "\n" + scores
value to override it if that is set + default_pt = default_pretrains.get(language, None) + if dataset is not None and dataset_pretrains is not None: + default_pt = dataset_pretrains.get(language, {}).get(dataset, default_pt) + + if default_pt is not None: + default_pt_path = '{}/{}/pretrain/{}.pt'.format(model_dir, language, default_pt) + if not os.path.exists(default_pt_path): + logger.info("Default pretrain should be {} Attempting to download".format(default_pt_path)) + try: + download(lang=language, package=None, processors={"pretrain": default_pt}, model_dir=model_dir) + except UnknownLanguageError: + # if there's a pretrain in the directory, hiding this + # error will let us find that pretrain later + pass + if os.path.exists(default_pt_path): + if dataset is not None and dataset_pretrains is not None and language in dataset_pretrains and dataset in dataset_pretrains[language]: + logger.info(f"Using default pretrain for {language}:{dataset}, found in {default_pt_path} To use a different pretrain, specify --wordvec_pretrain_file") + else: + logger.info(f"Using default pretrain for language, found in {default_pt_path} To use a different pretrain, specify --wordvec_pretrain_file") + return default_pt_path + + pretrain_path = '{}/{}/pretrain/*.pt'.format(model_dir, language) + pretrains = glob.glob(pretrain_path) + if len(pretrains) == 0: + # we already tried to download the default pretrain once + # and it didn't work. maybe the default language package + # will have something? + logger.warning(f"Cannot figure out which pretrain to use for '{language}'. 
def find_charlm_file(direction, language, charlm, model_dir=DEFAULT_MODEL_DIR):
    """
    Return the path to the forward or backward charlm if it exists for the given package

    Looks in the local saved_models directory first, then the resources
    directory, then tries downloading the charlm into the resources
    directory as a last resort.

    Raises FileNotFoundError if the charlm cannot be found or downloaded.
    """
    # a freshly trained local model takes precedence over downloaded resources
    saved_path = f'saved_models/charlm/{language}_{charlm}_{direction}_charlm.pt'
    resource_path = f'{model_dir}/{language}/{direction}_charlm/{charlm}.pt'
    for candidate in (saved_path, resource_path):
        if os.path.exists(candidate):
            logger.info(f'Using model {candidate} for {direction} charlm')
            return candidate

    missing_error = f"Cannot find {direction} charlm in either {saved_path} or {resource_path} Attempted downloading {charlm} but that did not work"
    try:
        download(lang=language, package=None, processors={f"{direction}_charlm": charlm}, model_dir=model_dir)
    except ValueError as e:
        # download refused the package name entirely
        raise FileNotFoundError(missing_error) from e
    if os.path.exists(resource_path):
        logger.info(f'Downloaded model, using model {resource_path} for {direction} charlm')
        return resource_path
    raise FileNotFoundError(missing_error)
def choose_charlm(language, dataset, charlm, language_charlms, dataset_charlms):
    """
    Resolve which charlm to use for a language/dataset pair.

    charlm == "default" means the default charlm for this dataset or language
    charlm == None is no charlm
    Any other value is an explicitly requested charlm and is returned as-is.
    """
    if charlm is None:
        # charlm explicitly disabled
        return None
    if charlm != "default":
        # an explicitly requested charlm always wins
        return charlm

    language_datasets = dataset_charlms.get(language, {})
    if dataset in language_datasets:
        # an entry in the dataset map is honored even when it is "" or None -
        # "not in the map" is how dataset_charlms signals to use the default
        return language_datasets[dataset]

    # fall back to the per-language default; falsy defaults mean no charlm
    return language_charlms.get(language, None) or None
None and command_args.use_bert and '--bert_model' not in extra_args: + if short_language in TRANSFORMERS: + bert_args = ['--bert_model', TRANSFORMERS.get(short_language)] + if layers and short_language in TRANSFORMER_LAYERS and '--bert_hidden_layers' not in extra_args: + bert_args.extend(['--bert_hidden_layers', str(TRANSFORMER_LAYERS.get(short_language))]) + elif warn: + logger.error("Transformer requested, but no default transformer for %s Specify one using --bert_model" % short_language) + + return bert_args + +def build_pos_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR): + charlm = choose_pos_charlm(short_language, dataset, charlm) + charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir) + return charlm_args + +def build_lemma_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR): + charlm = choose_lemma_charlm(short_language, dataset, charlm) + charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir) + return charlm_args + +def build_depparse_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR): + charlm = choose_depparse_charlm(short_language, dataset, charlm) + charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir) + return charlm_args diff --git a/stanza/stanza/utils/training/compose_ete_results.py b/stanza/stanza/utils/training/compose_ete_results.py new file mode 100644 index 0000000000000000000000000000000000000000..d78117700bcf2b9675412a52de9449bc565200d2 --- /dev/null +++ b/stanza/stanza/utils/training/compose_ete_results.py @@ -0,0 +1,100 @@ +""" +Turn the ETE results into markdown + +Parses blocks like this from the model eval script + +2022-01-14 01:23:34 INFO: End to end results for af_afribooms models on af_afribooms test data: +Metric | Precision | Recall | F1 Score | AligndAcc +-----------+-----------+-----------+-----------+----------- +Tokens | 99.93 | 99.92 | 
 99.93 |
Sentences  | 100.00 | 100.00 | 100.00 |
Words      | 99.93 | 99.92 | 99.93 |
UPOS       | 97.97 | 97.96 | 97.97 | 98.04
XPOS       | 93.98 | 93.97 | 93.97 | 94.04
UFeats     | 97.23 | 97.22 | 97.22 | 97.29
AllTags    | 93.89 | 93.88 | 93.88 | 93.95
Lemmas     | 97.40 | 97.39 | 97.39 | 97.46
UAS        | 87.39 | 87.38 | 87.38 | 87.45
LAS        | 83.57 | 83.56 | 83.57 | 83.63
CLAS       | 76.88 | 76.45 | 76.66 | 76.52
MLAS       | 72.28 | 71.87 | 72.07 | 71.94
BLEX       | 73.20 | 72.79 | 73.00 | 72.86


Turns them into a markdown table.

Included is an attempt to mark the default packages with a green check.

"""

import argparse

from stanza.models.common.constant import pretty_langcode_to_lang
from stanza.models.common.short_name_to_treebank import short_name_to_treebank
from stanza.utils.training.run_ete import RESULTS_STRING
from stanza.resources.default_packages import default_treebanks

# metric rows expected in each ETE result block, in the order the eval script prints them
EXPECTED_ORDER = ["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"]

parser = argparse.ArgumentParser()
parser.add_argument("filenames", type=str, nargs="+", help="Which file(s) to read")
args = parser.parse_args()

# concatenate all input logs into one list of lines
lines = []
for filename in args.filenames:
    with open(filename) as fin:
        lines.extend(fin.readlines())

# walk the log lines, collecting one row per "End to end results" block
blocks = []
index = 0
while index < len(lines):
    line = lines[index]
    if line.find(RESULTS_STRING) < 0:
        index = index + 1
        continue

    # the treebank shorthand (eg af_afribooms) follows the results marker
    line = line[line.find(RESULTS_STRING) + len(RESULTS_STRING):].strip()
    short_name = line.split()[0]

    # skip the header of the expected output
    index = index + 1
    line = lines[index]
    pieces = line.split("|")
    assert pieces[0].strip() == 'Metric', "output format changed?"
    assert pieces[3].strip() == 'F1 Score', "output format changed?"

    index = index + 1
    line = lines[index]
    assert line.startswith("-----"), "output format changed?"

    index = index + 1

    # the 13 metric rows listed in EXPECTED_ORDER
    block = lines[index:index+13]
    assert len(block) == 13
    index = index + 13

    block = [x.split("|") for x in block]
    assert all(x[0].strip() == y for x, y in zip(block, EXPECTED_ORDER)), "output format changed?"
    lcode, short_dataset = short_name.split("_", 1)
    language = pretty_langcode_to_lang(lcode)
    treebank = short_name_to_treebank(short_name)
    long_dataset = treebank.split("-")[-1]

    # NOTE(review): the checkmark literal appears to be an empty string here —
    # presumably a check emoji was stripped in transit; confirm against the repo
    checkmark = ""
    if default_treebanks[lcode] == short_dataset:
        checkmark = ''

    # row: language, markdown link to the UD treebank repo, lcode, default marker, then the 13 F1 scores
    block = [language, "[%s](%s)" % (long_dataset, "https://github.com/UniversalDependencies/%s" % treebank), lcode, checkmark] + [x[3].strip() for x in block]
    blocks.append(block)

# leading cells for the average row; first entries start with a zero-width space
# so the Macro Avg row sorts/renders ahead of the language rows
PREFIX = ["​Macro Avg", "​", "​", ""]

# macro-average each of the 13 score columns (columns after the PREFIX-width label columns)
avg = [sum(float(x[i]) for x in blocks) / len(blocks) for i in range(len(PREFIX), len(EXPECTED_ORDER) + len(PREFIX))]
avg = PREFIX + ["%.2f" % x for x in avg]
blocks = sorted(blocks)
blocks = [avg] + blocks

chart = ["|%s|" % " | ".join(x) for x in blocks]
for line in chart:
    print(line)

# diff --git a/stanza/stanza/utils/training/run_charlm.py b/stanza/stanza/utils/training/run_charlm.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..2aa251382f0caf6605a5861140ca88fee81483b1
# --- /dev/null
# +++ b/stanza/stanza/utils/training/run_charlm.py
# @@ -0,0 +1,86 @@
"""
Trains or scores a charlm model.
"""

import logging
import os

from stanza.models import charlm
from stanza.utils.training import common
from stanza.utils.training.common import Mode

logger = logging.getLogger('stanza')


def add_charlm_args(parser):
    """
    Extra args for the charlm: forward/backward
    """
    parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help="Forward or backward language model")
    parser.add_argument('--forward', action='store_const', dest='direction', const='forward', help="Train a forward language model")
    parser.add_argument('--backward', action='store_const', dest='direction', const='backward', help="Train a backward language model")


def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """
    Train or score one charlm dataset.

    mode selects train / score_dev / score_test; extra_args are passed
    through to stanza.models.charlm, unless the corresponding flag is
    already present there.
    """
    short_language, dataset_name = short_name.split("_", 1)

    train_dir = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "train")

    # fall back to xz-compressed dev/test files if the plain text is not present
    dev_file = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "dev.txt")
    if not os.path.exists(dev_file) and os.path.exists(dev_file + ".xz"):
        dev_file = dev_file + ".xz"

    test_file = os.path.join(paths["CHARLM_DATA_DIR"], short_language, dataset_name, "test.txt")
    if not os.path.exists(test_file) and os.path.exists(test_file + ".xz"):
        test_file = test_file + ".xz"

    # equivalent command lines:
    # python -m stanza.models.charlm --train_dir $train_dir --eval_file $dev_file \
    #   --direction $direction --shorthand $short --mode train $args
    # python -m stanza.models.charlm --eval_file $dev_file \
    #   --direction $direction --shorthand $short --mode predict $args
    # python -m stanza.models.charlm --eval_file $test_file \
    #   --direction $direction --shorthand $short --mode predict $args

    direction = command_args.direction
    default_args = ['--%s' % direction,
                    '--shorthand', short_name]
    if mode == Mode.TRAIN:
        train_args = ['--mode', 'train']
        # only supply defaults the caller has not overridden
        if '--train_dir' not in extra_args:
            train_args += ['--train_dir', train_dir]
        if '--eval_file' not in extra_args:
            train_args += ['--eval_file', dev_file]
        train_args = train_args + default_args + extra_args
        logger.info("Running train step with args: %s", train_args)
        charlm.main(train_args)

    if mode == Mode.SCORE_DEV:
        dev_args = ['--mode', 'predict']
        if '--eval_file' not in extra_args:
            dev_args += ['--eval_file', dev_file]
        dev_args = dev_args + default_args + extra_args
        logger.info("Running dev step with args: %s", dev_args)
        charlm.main(dev_args)

    if mode == Mode.SCORE_TEST:
        test_args = ['--mode', 'predict']
        if '--eval_file' not in extra_args:
            test_args += ['--eval_file', test_file]
        test_args = test_args + default_args + extra_args
        logger.info("Running test step with args: %s", test_args)
        charlm.main(test_args)


def get_model_name(args):
    """
    The charlm saves forward and backward charlms to the same dir, but with different filenames
    """
    return "%s_charlm" % args.direction

def main():
    common.main(run_treebank, "charlm", get_model_name, add_charlm_args, charlm.build_argparse())

if __name__ == "__main__":
    main()

# diff --git a/stanza/stanza/utils/training/run_constituency.py b/stanza/stanza/utils/training/run_constituency.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..4e993f56ae1d87edc1c646e1d09de7b44b3f4320
# --- /dev/null
# +++ b/stanza/stanza/utils/training/run_constituency.py
# @@ -0,0 +1,130 @@
"""
Trains or scores a constituency model.

Currently a suuuuper preliminary script.
+ +Example of how to run on multiple parsers at the same time on the Stanford workqueue: + +for i in `echo 1000 1001 1002 1003 1004`; do nlprun -d a6000 "python3 stanza/utils/training/run_constituency.py vi_vlsp23 --use_bert --stage1_bert_finetun --save_name vi_vlsp23_$i.pt --seed $i --epochs 200 --force" -o vi_vlsp23_$i.out; done + +""" + +import logging +import os + +from stanza.models import constituency_parser +from stanza.models.constituency.retagging import RETAG_METHOD +from stanza.utils.datasets.constituency import prepare_con_dataset +from stanza.utils.training import common +from stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain + +from stanza.resources.default_packages import default_charlms, default_pretrains + +logger = logging.getLogger('stanza') + +def add_constituency_args(parser): + add_charlm_args(parser) + + parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language') + + parser.add_argument('--parse_text', dest='mode', action='store_const', const="parse_text", help='Parse a text file') + +def build_wordvec_args(short_language, dataset, extra_args): + if '--wordvec_pretrain_file' not in extra_args: + # will throw an error if the pretrain can't be found + wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains) + wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain] + else: + wordvec_args = [] + + return wordvec_args + +def build_default_args(paths, short_language, dataset, command_args, extra_args): + if short_language in RETAG_METHOD: + retag_args = ["--retag_method", RETAG_METHOD[short_language]] + else: + retag_args = [] + + wordvec_args = build_wordvec_args(short_language, dataset, extra_args) + + charlm = choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {}) + charlm_args = build_charlm_args(short_language, charlm, base_args=False) + + bert_args = 
common.choose_transformer(short_language, command_args, extra_args, warn=True, layers=True) + default_args = retag_args + wordvec_args + charlm_args + bert_args + + return default_args + +def build_model_filename(paths, short_name, command_args, extra_args): + short_language, dataset = short_name.split("_", 1) + + default_args = build_default_args(paths, short_language, dataset, command_args, extra_args) + + train_args = ["--shorthand", short_name, + "--mode", "train"] + train_args = train_args + default_args + if command_args.save_name is not None: + train_args.extend(["--save_name", command_args.save_name]) + if command_args.save_dir is not None: + train_args.extend(["--save_dir", command_args.save_dir]) + args = constituency_parser.parse_args(train_args) + save_name = constituency_parser.build_model_filename(args) + return save_name + + +def run_treebank(mode, paths, treebank, short_name, temp_output_file, command_args, extra_args): + constituency_dir = paths["CONSTITUENCY_DATA_DIR"] + short_language, dataset = short_name.split("_") + + train_file = os.path.join(constituency_dir, f"{short_name}_train.mrg") + dev_file = os.path.join(constituency_dir, f"{short_name}_dev.mrg") + test_file = os.path.join(constituency_dir, f"{short_name}_test.mrg") + + if not os.path.exists(train_file) or not os.path.exists(dev_file) or not os.path.exists(test_file): + logger.warning(f"The data for {short_name} is missing or incomplete. Attempting to rebuild...") + try: + prepare_con_dataset.main(short_name) + except: + logger.error(f"Unable to build the data. 
Please correctly build the files in {train_file}, {dev_file}, {test_file} and then try again.") + raise + + default_args = build_default_args(paths, short_language, dataset, command_args, extra_args) + + if mode == Mode.TRAIN: + train_args = ['--train_file', train_file, + '--eval_file', dev_file, + '--shorthand', short_name, + '--mode', 'train'] + train_args = train_args + default_args + extra_args + logger.info("Running train step with args: {}".format(train_args)) + constituency_parser.main(train_args) + + if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: + dev_args = ['--eval_file', dev_file, + '--shorthand', short_name, + '--mode', 'predict'] + dev_args = dev_args + default_args + extra_args + logger.info("Running dev step with args: {}".format(dev_args)) + constituency_parser.main(dev_args) + + if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: + test_args = ['--eval_file', test_file, + '--shorthand', short_name, + '--mode', 'predict'] + test_args = test_args + default_args + extra_args + logger.info("Running test step with args: {}".format(test_args)) + constituency_parser.main(test_args) + + if mode == "parse_text": + text_args = ['--shorthand', short_name, + '--mode', 'parse_text'] + text_args = text_args + default_args + extra_args + logger.info("Processing text with args: {}".format(text_args)) + constituency_parser.main(text_args) + +def main(): + common.main(run_treebank, "constituency", "constituency", add_constituency_args, sub_argparse=constituency_parser.build_argparse(), build_model_filename=build_model_filename) + +if __name__ == "__main__": + main() + diff --git a/stanza/stanza/utils/training/run_depparse.py b/stanza/stanza/utils/training/run_depparse.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ea34a4529e2e120a480f5c37525ab15922d749 --- /dev/null +++ b/stanza/stanza/utils/training/run_depparse.py @@ -0,0 +1,133 @@ +import logging +import os + +from stanza.models import parser + +from stanza.utils.training import common 
from stanza.utils.training.common import Mode, add_charlm_args, build_depparse_charlm_args, choose_depparse_charlm, choose_transformer
from stanza.utils.training.run_pos import wordvec_args

from stanza.resources.default_packages import default_charlms, depparse_charlms

logger = logging.getLogger('stanza')

def add_depparse_args(parser):
    """Add depparse-specific command line args: charlm and transformer flag."""
    add_charlm_args(parser)

    parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language')

# TODO: refactor with run_pos
def build_model_filename(paths, short_name, command_args, extra_args):
    """
    Figure out the filename the parser would save to for these settings.

    Useful for checking whether the model already exists before training.
    """
    short_language, dataset = short_name.split("_", 1)

    # TODO: can avoid downloading the charlm at this point, since we
    # might not even be training
    charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm)

    bert_args = choose_transformer(short_language, command_args, extra_args, warn=False)

    train_args = ["--shorthand", short_name,
                  "--mode", "train"]
    # TODO: also, this downloads the wordvec, which we might not want to do yet
    train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args + extra_args
    if command_args.save_name is not None:
        train_args.extend(["--save_name", command_args.save_name])
    if command_args.save_dir is not None:
        train_args.extend(["--save_dir", command_args.save_dir])
    args = parser.parse_args(train_args)
    save_name = parser.model_file_name(args)
    return save_name


def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or score the dependency parser for one treebank."""
    # use maxsplit=1 so dataset names which themselves contain "_" unpack
    # correctly - this matches build_model_filename above
    short_language, dataset = short_name.split("_", 1)

    # TODO: refactor these blocks?
    depparse_dir = paths["DEPPARSE_DATA_DIR"]
    train_file = f"{depparse_dir}/{short_name}.train.in.conllu"
    dev_in_file = f"{depparse_dir}/{short_name}.dev.in.conllu"
    dev_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.dev.pred.conllu"
    test_in_file = f"{depparse_dir}/{short_name}.test.in.conllu"
    test_pred_file = temp_output_file if temp_output_file else f"{depparse_dir}/{short_name}.test.pred.conllu"

    # an explicit --eval_file overrides the default dev/test inputs
    eval_file = None
    if '--eval_file' in extra_args:
        eval_file = extra_args[extra_args.index('--eval_file') + 1]

    charlm_args = build_depparse_charlm_args(short_language, dataset, command_args.charlm)

    bert_args = choose_transformer(short_language, command_args, extra_args)

    if mode == Mode.TRAIN:
        if not os.path.exists(train_file):
            logger.error("TRAIN FILE NOT FOUND: %s ... skipping" % train_file)
            return

        # some languages need reduced batch size
        if short_name == 'de_hdt':
            # 'UD_German-HDT'
            batch_size = "1300"
        elif short_name in ('hr_set', 'fi_tdt', 'ru_taiga', 'cs_cltt', 'gl_treegal', 'lv_lvtb', 'ro_simonero'):
            # 'UD_Croatian-SET', 'UD_Finnish-TDT', 'UD_Russian-Taiga',
            # 'UD_Czech-CLTT', 'UD_Galician-TreeGal', 'UD_Latvian-LVTB' 'Romanian-SiMoNERo'
            batch_size = "3000"
        else:
            batch_size = "5000"

        train_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                      "--train_file", train_file,
                      "--eval_file", eval_file if eval_file else dev_in_file,
                      "--output_file", dev_pred_file,
                      "--batch_size", batch_size,
                      "--lang", short_language,
                      "--shorthand", short_name,
                      "--mode", "train"]
        train_args = train_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
        train_args = train_args + extra_args
        logger.info("Running train depparse for {} with args {}".format(treebank, train_args))
        parser.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        dev_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                    "--eval_file", eval_file if eval_file else dev_in_file,
                    "--output_file", dev_pred_file,
                    "--lang", short_language,
                    "--shorthand", short_name,
                    "--mode", "predict"]
        dev_args = dev_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
        dev_args = dev_args + extra_args
        logger.info("Running dev depparse for {} with args {}".format(treebank, dev_args))
        parser.main(dev_args)

        if '--no_gold_labels' not in extra_args:
            results = common.run_eval_script_depparse(eval_file if eval_file else dev_in_file, dev_pred_file)
            logger.info("Finished running dev set on\n{}\n{}".format(treebank, results))
            if not temp_output_file:
                logger.info("Output saved to %s", dev_pred_file)

    if mode == Mode.SCORE_TEST:
        test_args = ["--wordvec_dir", paths["WORDVEC_DIR"],
                     "--eval_file", eval_file if eval_file else test_in_file,
                     "--output_file", test_pred_file,
                     "--lang", short_language,
                     "--shorthand", short_name,
                     "--mode", "predict"]
        test_args = test_args + wordvec_args(short_language, dataset, extra_args) + charlm_args + bert_args
        test_args = test_args + extra_args
        logger.info("Running test depparse for {} with args {}".format(treebank, test_args))
        parser.main(test_args)

        if '--no_gold_labels' not in extra_args:
            results = common.run_eval_script_depparse(eval_file if eval_file else test_in_file, test_pred_file)
            logger.info("Finished running test set on\n{}\n{}".format(treebank, results))
            if not temp_output_file:
                logger.info("Output saved to %s", test_pred_file)


def main():
    common.main(run_treebank, "depparse", "parser", add_depparse_args, sub_argparse=parser.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_depparse_charlm)

if __name__ == "__main__":
    main()

# diff --git a/stanza/stanza/utils/training/run_lemma.py b/stanza/stanza/utils/training/run_lemma.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..a362e64c4968620c7e5262f1f57606f6be0df5a8
# --- /dev/null
# +++
# b/stanza/stanza/utils/training/run_lemma.py
# @@ -0,0 +1,179 @@
"""
This script allows for training or testing on dev / test of the UD lemmatizer.

If run with a single treebank name, it will train or test that treebank.
If run with ud_all or all_ud, it will iterate over all UD treebanks it can find.

Mode can be set to train&dev with --train, to dev set only
with --score_dev, and to test set only with --score_test.

Treebanks are specified as a list. all_ud or ud_all means to look for
all UD treebanks.

Extra arguments are passed to the lemmatizer. In case the run script
itself is shadowing arguments, you can specify --extra_args as a
parameter to mark where the lemmatizer arguments start.
"""

import logging
import os

from stanza.models import identity_lemmatizer
from stanza.models import lemmatizer
from stanza.models.lemma import attach_lemma_classifier

from stanza.utils.training import common
from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm
from stanza.utils.training import run_lemma_classifier

from stanza.utils.datasets.prepare_lemma_treebank import check_lemmas
import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier

logger = logging.getLogger('stanza')

def add_lemma_args(parser):
    """Add lemma-specific command line args: charlm and the lemma classifier on/off pair."""
    add_charlm_args(parser)

    # fixed copy-paste: the store_true flag previously reused the
    # "Don't use ..." help text of its --no_lemma_classifier counterpart
    parser.add_argument('--lemma_classifier', dest='lemma_classifier', action='store_true', default=None,
                        help="Use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer if the charlm is used")
    parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false',
                        help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer if the charlm is used")

def build_model_filename(paths, short_name, command_args, extra_args):
    """
    Figure out what the model savename will be, taking into account the model settings.

    Useful for figuring out if the model already exists

    None will represent that there is no expected save_name
    """
    short_language, dataset = short_name.split("_", 1)

    lemma_dir = paths["LEMMA_DATA_DIR"]
    train_file = f"{lemma_dir}/{short_name}.train.in.conllu"

    if not os.path.exists(train_file):
        logger.debug("Treebank %s is not prepared for training the lemmatizer. Could not find any training data at %s Cannot figure out the expected save_name without looking at the data, but a later step in the process will skip the training anyway" % (short_name, train_file))
        return None

    # treebanks with no lemmas use the identity lemmatizer, which has no save file
    has_lemmas = check_lemmas(train_file)
    if not has_lemmas:
        return None

    # TODO: can avoid downloading the charlm at this point, since we
    # might not even be training
    charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)

    train_args = ["--train_file", train_file,
                  "--shorthand", short_name,
                  "--mode", "train"]
    train_args = train_args + charlm_args + extra_args
    args = lemmatizer.parse_args(train_args)
    save_name = lemmatizer.build_model_filename(args)
    return save_name

def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or score the lemmatizer (or identity lemmatizer) for one treebank."""
    short_language, dataset = short_name.split("_", 1)

    lemma_dir = paths["LEMMA_DATA_DIR"]
    train_file = f"{lemma_dir}/{short_name}.train.in.conllu"
    dev_in_file = f"{lemma_dir}/{short_name}.dev.in.conllu"
    dev_gold_file = f"{lemma_dir}/{short_name}.dev.gold.conllu"
    dev_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.dev.pred.conllu"
    test_in_file = f"{lemma_dir}/{short_name}.test.in.conllu"
    test_gold_file = f"{lemma_dir}/{short_name}.test.gold.conllu"
    test_pred_file = temp_output_file if temp_output_file else f"{lemma_dir}/{short_name}.test.pred.conllu"

    charlm_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)

    if not os.path.exists(train_file):
        logger.error("Treebank %s is not prepared for training the lemmatizer. Could not find any training data at %s Skipping..." % (treebank, train_file))
        return

    # treebanks without lemmas fall back to the identity lemmatizer
    has_lemmas = check_lemmas(train_file)
    if not has_lemmas:
        logger.info("Treebank " + treebank + " (" + short_name +
                    ") has no lemmas. Using identity lemmatizer")
        if mode == Mode.TRAIN or mode == Mode.SCORE_DEV:
            train_args = ["--train_file", train_file,
                          "--eval_file", dev_in_file,
                          "--output_file", dev_pred_file,
                          "--gold_file", dev_gold_file,
                          "--shorthand", short_name]
            logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
            identity_lemmatizer.main(train_args)
        elif mode == Mode.SCORE_TEST:
            train_args = ["--train_file", train_file,
                          "--eval_file", test_in_file,
                          "--output_file", test_pred_file,
                          "--gold_file", test_gold_file,
                          "--shorthand", short_name]
            logger.info("Running identity lemmatizer for {} with args {}".format(treebank, train_args))
            identity_lemmatizer.main(train_args)
    else:
        if mode == Mode.TRAIN:
            # larger treebanks get fewer epochs
            # ('UD_Czech-PDT', 'UD_Russian-SynTagRus', 'UD_German-HDT')
            if short_name in ('cs_pdt', 'ru_syntagrus', 'de_hdt'):
                num_epochs = "30"
            else:
                num_epochs = "60"

            train_args = ["--train_file", train_file,
                          "--eval_file", dev_in_file,
                          "--output_file", dev_pred_file,
                          "--gold_file", dev_gold_file,
                          "--shorthand", short_name,
                          "--num_epoch", num_epochs,
                          "--mode", "train"]
            train_args = train_args + charlm_args + extra_args
            logger.info("Running train lemmatizer for {} with args {}".format(treebank, train_args))
            lemmatizer.main(train_args)

        if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
            dev_args = ["--eval_file", dev_in_file,
                        "--output_file", dev_pred_file,
                        "--gold_file", dev_gold_file,
                        "--shorthand", short_name,
                        "--mode", "predict"]
            dev_args = dev_args + charlm_args + extra_args
            logger.info("Running dev lemmatizer for {} with args {}".format(treebank, dev_args))
            lemmatizer.main(dev_args)

        if mode == Mode.SCORE_TEST:
            test_args = ["--eval_file", test_in_file,
                         "--output_file", test_pred_file,
                         "--gold_file", test_gold_file,
                         "--shorthand", short_name,
                         "--mode", "predict"]
            test_args = test_args + charlm_args + extra_args
            logger.info("Running test lemmatizer for {} with args {}".format(treebank, test_args))
            lemmatizer.main(test_args)

        # optionally train & attach a lemma classifier on top of the lemmatizer
        # default: only when a charlm was requested and the dataset supports it
        use_lemma_classifier = command_args.lemma_classifier
        if use_lemma_classifier is None:
            use_lemma_classifier = command_args.charlm is not None
        use_lemma_classifier = use_lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING
        if use_lemma_classifier and mode == Mode.TRAIN:
            lc_charlm_args = ['--no_charlm'] if command_args.charlm is None else ['--charlm', command_args.charlm]
            lemma_classifier_args = [treebank] + lc_charlm_args
            if command_args.force:
                lemma_classifier_args.append('--force')
            run_lemma_classifier.main(lemma_classifier_args)

            # NOTE(review): nesting of the attach step under this if is
            # reconstructed from a whitespace-mangled diff - confirm against the repo
            save_name = build_model_filename(paths, short_name, command_args, extra_args)
            # TODO: use a temp path for the lemma_classifier or keep it somewhere
            attach_args = ['--input', save_name,
                           '--output', save_name,
                           '--classifier', 'saved_models/lemma_classifier/%s_lemma_classifier.pt' % short_name]
            attach_lemma_classifier.main(attach_args)

            # now we rerun the dev set - the HI in particular demonstrates some good improvement
            lemmatizer.main(dev_args)

def main():
    common.main(run_treebank, "lemma", "lemmatizer", add_lemma_args, sub_argparse=lemmatizer.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm)

if __name__ == "__main__":
    main()

# diff --git a/stanza/stanza/utils/training/run_lemma_classifier.py
# b/stanza/stanza/utils/training/run_lemma_classifier.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..73669e79b0964488269f305e31fa4e127e0b2c43
# --- /dev/null
# +++ b/stanza/stanza/utils/training/run_lemma_classifier.py
# @@ -0,0 +1:87 @@
import os

from stanza.models.lemma_classifier import evaluate_models
from stanza.models.lemma_classifier import train_lstm_model
from stanza.models.lemma_classifier import train_transformer_model
from stanza.models.lemma_classifier.constants import ModelType

from stanza.resources.default_packages import default_pretrains, TRANSFORMERS
from stanza.utils.training import common
from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm, find_wordvec_pretrain

def add_lemma_args(parser):
    """Add lemma classifier command line args: charlm and the model type (LSTM or transformer)."""
    add_charlm_args(parser)

    parser.add_argument('--model_type', default=ModelType.LSTM, type=lambda x: ModelType[x.upper()],
                        help='Model type to use. {}'.format(", ".join(x.name for x in ModelType)))

def build_model_filename(paths, short_name, command_args, extra_args):
    """Lemma classifiers always save to the same fixed location for a given treebank."""
    return os.path.join("saved_models", "lemma_classifier", short_name + "_lemma_classifier.pt")

def run_treebank(mode, paths, treebank, short_name,
                 temp_output_file, command_args, extra_args):
    """Train and/or evaluate the lemma classifier for one treebank."""
    short_language, dataset = short_name.split("_", 1)

    base_args = []
    if '--save_name' not in extra_args:
        base_args += ['--save_name', build_model_filename(paths, short_name, command_args, extra_args)]

    # charlm + word vector args, unless the caller supplied a pretrain explicitly
    embedding_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm)
    if '--wordvec_pretrain_file' not in extra_args:
        wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, {}, dataset)
        embedding_args += ["--wordvec_pretrain_file", wordvec_pretrain]

    # transformer model is required when the transformer model type is chosen
    bert_args = []
    if command_args.model_type is ModelType.TRANSFORMER:
        if '--bert_model' not in extra_args:
            if short_language in TRANSFORMERS:
                bert_args = ['--bert_model', TRANSFORMERS.get(short_language)]
            else:
                raise ValueError("--bert_model not specified, so cannot figure out which transformer to use for language %s" % short_language)

    extra_train_args = []
    if command_args.force:
        extra_train_args.append('--force')

    if mode == Mode.TRAIN:
        train_args = []
        if "--train_file" not in extra_args:
            train_file = os.path.join("data", "lemma_classifier", "%s.train.lemma" % short_name)
            train_args += ['--train_file', train_file]
        if "--eval_file" not in extra_args:
            eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name)
            train_args += ['--eval_file', eval_file]
        train_args = base_args + train_args + extra_args + extra_train_args

        # the LSTM and transformer variants have separate training entry points
        if command_args.model_type == ModelType.LSTM:
            train_args = embedding_args + train_args
            train_lstm_model.main(train_args)
        else:
            model_type_args = ["--model_type", command_args.model_type.name.lower()]
            train_args = bert_args + model_type_args + train_args
            train_transformer_model.main(train_args)

    if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
        eval_args = []
        if "--eval_file" not in extra_args:
            eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name)
            eval_args += ['--eval_file', eval_file]
        model_type_args = ["--model_type", command_args.model_type.name.lower()]
        eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args
        evaluate_models.main(eval_args)

    if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
        eval_args = []
        if "--eval_file" not in extra_args:
            eval_file = os.path.join("data", "lemma_classifier", "%s.test.lemma" % short_name)
            eval_args += ['--eval_file', eval_file]
        model_type_args = ["--model_type", command_args.model_type.name.lower()]
        eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args
        evaluate_models.main(eval_args)

def main(args=None):
    common.main(run_treebank, "lemma_classifier", "lemma_classifier", add_lemma_args, sub_argparse=train_lstm_model.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm, args=args)


if __name__ == '__main__':
    main()

# diff --git a/stanza/stanza/utils/training/run_mwt.py b/stanza/stanza/utils/training/run_mwt.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..af3171d7150f9694965c062d6e4c84c218af3e98
# --- /dev/null
# +++ b/stanza/stanza/utils/training/run_mwt.py
# @@ -0,0 +1,122 @@
"""
This script allows for training or testing on dev / test of the UD mwt tools.

If run with a single treebank name, it will train or test that treebank.
If run with ud_all or all_ud, it will iterate over all UD treebanks it can find.

Mode can be set to train&dev with --train, to dev set only
with --score_dev, and to test set only with --score_test.

Treebanks are specified as a list. all_ud or ud_all means to look for
all UD treebanks.

Extra arguments are passed to mwt. In case the run script
itself is shadowing arguments, you can specify --extra_args as a
parameter to mark where the mwt arguments start.
+""" + + +import logging +import math + +from stanza.models import mwt_expander +from stanza.models.common.doc import Document +from stanza.utils.conll import CoNLL +from stanza.utils.training import common +from stanza.utils.training.common import Mode + +from stanza.utils.max_mwt_length import max_mwt_length + +logger = logging.getLogger('stanza') + +def check_mwt(filename): + """ + Checks whether or not there are MWTs in the given conll file + """ + doc = CoNLL.conll2doc(filename) + data = doc.get_mwt_expansions(False) + return len(data) > 0 + +def run_treebank(mode, paths, treebank, short_name, + temp_output_file, command_args, extra_args): + short_language = short_name.split("_")[0] + + mwt_dir = paths["MWT_DATA_DIR"] + + train_file = f"{mwt_dir}/{short_name}.train.in.conllu" + dev_in_file = f"{mwt_dir}/{short_name}.dev.in.conllu" + dev_gold_file = f"{mwt_dir}/{short_name}.dev.gold.conllu" + dev_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.dev.pred.conllu" + test_in_file = f"{mwt_dir}/{short_name}.test.in.conllu" + test_gold_file = f"{mwt_dir}/{short_name}.test.gold.conllu" + test_output_file = temp_output_file if temp_output_file else f"{mwt_dir}/{short_name}.test.pred.conllu" + + train_json = f"{mwt_dir}/{short_name}-ud-train-mwt.json" + dev_json = f"{mwt_dir}/{short_name}-ud-dev-mwt.json" + test_json = f"{mwt_dir}/{short_name}-ud-test-mwt.json" + + eval_file = None + if '--eval_file' in extra_args: + eval_file = extra_args[extra_args.index('--eval_file') + 1] + + gold_file = None + if '--gold_file' in extra_args: + gold_file = extra_args[extra_args.index('--gold_file') + 1] + + if not check_mwt(train_file): + logger.info("No training MWTS found for %s. Skipping" % treebank) + return + + if not check_mwt(dev_in_file) and mode == Mode.TRAIN: + logger.info("No dev MWTS found for %s. 
Training only the deterministic MWT expander" % treebank) + extra_args.append('--dict_only') + + if mode == Mode.TRAIN: + max_mwt_len = math.ceil(max_mwt_length([train_json, dev_json]) * 1.1 + 1) + logger.info("Max len: %f" % max_mwt_len) + train_args = ['--train_file', train_file, + '--eval_file', eval_file if eval_file else dev_in_file, + '--output_file', dev_output_file, + '--gold_file', gold_file if gold_file else dev_gold_file, + '--lang', short_language, + '--shorthand', short_name, + '--mode', 'train', + '--max_dec_len', str(max_mwt_len)] + train_args = train_args + extra_args + logger.info("Running train step with args: {}".format(train_args)) + mwt_expander.main(train_args) + + if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: + dev_args = ['--eval_file', eval_file if eval_file else dev_in_file, + '--output_file', dev_output_file, + '--gold_file', gold_file if gold_file else dev_gold_file, + '--lang', short_language, + '--shorthand', short_name, + '--mode', 'predict'] + dev_args = dev_args + extra_args + logger.info("Running dev step with args: {}".format(dev_args)) + mwt_expander.main(dev_args) + + results = common.run_eval_script_mwt(gold_file if gold_file else dev_gold_file, dev_output_file) + logger.info("Finished running dev set on\n{}\n{}".format(treebank, results)) + + if mode == Mode.SCORE_TEST: + test_args = ['--eval_file', eval_file if eval_file else test_in_file, + '--output_file', test_output_file, + '--gold_file', gold_file if gold_file else test_gold_file, + '--lang', short_language, + '--shorthand', short_name, + '--mode', 'predict'] + test_args = test_args + extra_args + logger.info("Running test step with args: {}".format(test_args)) + mwt_expander.main(test_args) + + results = common.run_eval_script_mwt(gold_file if gold_file else test_gold_file, test_output_file) + logger.info("Finished running test set on\n{}\n{}".format(treebank, results)) + +def main(): + common.main(run_treebank, "mwt", "mwt_expander", 
sub_argparse=mwt_expander.build_argparse()) + +if __name__ == "__main__": + main() + diff --git a/stanza/stanza/utils/training/run_ner.py b/stanza/stanza/utils/training/run_ner.py new file mode 100644 index 0000000000000000000000000000000000000000..dca1c4d8acbc23810f40851a53dc378ed56e7326 --- /dev/null +++ b/stanza/stanza/utils/training/run_ner.py @@ -0,0 +1,159 @@ +""" +Trains or scores an NER model. + +Will attempt to guess the appropriate word vector file if none is +specified, and will use the charlms specified in the resources +for a given dataset or language if possible. + +Example command line: + python3 -m stanza.utils.training.run_ner hu_combined + +This script expects the prepared data to be in + data/ner/{lang}_{dataset}.train.json, {lang}_{dataset}.dev.json, {lang}_{dataset}.test.json + +If those files don't exist, it will make an attempt to rebuild them +using the prepare_ner_dataset script. However, this will fail if the +data is not already downloaded. More information on where to find +most of the datasets online is in that script. Some of the datasets +have licenses which must be agreed to, so no attempt is made to +automatically download the data. 
+""" + +import logging +import os + +from stanza.models import ner_tagger +from stanza.resources.common import DEFAULT_MODEL_DIR +from stanza.utils.datasets.ner import prepare_ner_dataset +from stanza.utils.training import common +from stanza.utils.training.common import Mode, add_charlm_args, build_charlm_args, choose_charlm, find_wordvec_pretrain + +from stanza.resources.default_packages import default_charlms, default_pretrains, ner_charlms, ner_pretrains + +# extra arguments specific to a particular dataset +DATASET_EXTRA_ARGS = { + "da_ddt": [ "--dropout", "0.6" ], + "fa_arman": [ "--dropout", "0.6" ], + "vi_vlsp": [ "--dropout", "0.6", + "--word_dropout", "0.1", + "--locked_dropout", "0.1", + "--char_dropout", "0.1" ], +} + +logger = logging.getLogger('stanza') + +def add_ner_args(parser): + add_charlm_args(parser) + + parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language') + + +def build_pretrain_args(language, dataset, charlm="default", command_args=None, extra_args=None, model_dir=DEFAULT_MODEL_DIR): + """ + Returns one list with the args for this language & dataset's charlm and pretrained embedding + """ + charlm = choose_charlm(language, dataset, charlm, default_charlms, ner_charlms) + charlm_args = build_charlm_args(language, charlm, model_dir=model_dir) + + wordvec_args = [] + if extra_args is None or '--wordvec_pretrain_file' not in extra_args: + # will throw an error if the pretrain can't be found + wordvec_pretrain = find_wordvec_pretrain(language, default_pretrains, ner_pretrains, dataset, model_dir=model_dir) + wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain] + + bert_args = common.choose_transformer(language, command_args, extra_args, warn=False) + + return charlm_args + wordvec_args + bert_args + + +# TODO: refactor? 
tagger and depparse should be pretty similar +def build_model_filename(paths, short_name, command_args, extra_args): + short_language, dataset = short_name.split("_", 1) + + # TODO: can avoid downloading the charlm at this point, since we + # might not even be training + pretrain_args = build_pretrain_args(short_language, dataset, command_args.charlm, command_args, extra_args) + + dataset_args = DATASET_EXTRA_ARGS.get(short_name, []) + + train_args = ["--shorthand", short_name, + "--mode", "train"] + train_args = train_args + pretrain_args + dataset_args + extra_args + if command_args.save_name is not None: + train_args.extend(["--save_name", command_args.save_name]) + if command_args.save_dir is not None: + train_args.extend(["--save_dir", command_args.save_dir]) + args = ner_tagger.parse_args(train_args) + save_name = ner_tagger.model_file_name(args) + return save_name + + +# Technically NER datasets are not necessarily treebanks +# (usually not, in fact) +# However, to keep the naming consistent, we leave the +# method which does the training as run_treebank +# TODO: rename treebank -> dataset everywhere +def run_treebank(mode, paths, treebank, short_name, + temp_output_file, command_args, extra_args): + ner_dir = paths["NER_DATA_DIR"] + language, dataset = short_name.split("_") + + train_file = os.path.join(ner_dir, f"{treebank}.train.json") + dev_file = os.path.join(ner_dir, f"{treebank}.dev.json") + test_file = os.path.join(ner_dir, f"{treebank}.test.json") + + # if any files are missing, try to rebuild the dataset + # if that still doesn't work, we have to throw an error + missing_file = [x for x in (train_file, dev_file, test_file) if not os.path.exists(x)] + if len(missing_file) > 0: + logger.warning(f"The data for {treebank} is missing or incomplete. 
Cannot find {missing_file} Attempting to rebuild...") + try: + prepare_ner_dataset.main(treebank) + except Exception as e: + raise FileNotFoundError(f"An exception occurred while trying to build the data for {treebank} At least one portion of the data was missing: {missing_file} Please correctly build these files and then try again.") from e + + pretrain_args = build_pretrain_args(language, dataset, command_args.charlm, command_args, extra_args) + + if mode == Mode.TRAIN: + # VI example arguments: + # --wordvec_pretrain_file ~/stanza_resources/vi/pretrain/vtb.pt + # --train_file data/ner/vi_vlsp.train.json + # --eval_file data/ner/vi_vlsp.dev.json + # --lang vi + # --shorthand vi_vlsp + # --mode train + # --charlm --charlm_shorthand vi_conll17 + # --dropout 0.6 --word_dropout 0.1 --locked_dropout 0.1 --char_dropout 0.1 + dataset_args = DATASET_EXTRA_ARGS.get(short_name, []) + + train_args = ['--train_file', train_file, + '--eval_file', dev_file, + '--shorthand', short_name, + '--mode', 'train'] + train_args = train_args + pretrain_args + dataset_args + extra_args + logger.info("Running train step with args: {}".format(train_args)) + ner_tagger.main(train_args) + + if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: + dev_args = ['--eval_file', dev_file, + '--shorthand', short_name, + '--mode', 'predict'] + dev_args = dev_args + pretrain_args + extra_args + logger.info("Running dev step with args: {}".format(dev_args)) + ner_tagger.main(dev_args) + + if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: + test_args = ['--eval_file', test_file, + '--shorthand', short_name, + '--mode', 'predict'] + test_args = test_args + pretrain_args + extra_args + logger.info("Running test step with args: {}".format(test_args)) + ner_tagger.main(test_args) + + +def main(): + common.main(run_treebank, "ner", "nertagger", add_ner_args, ner_tagger.build_argparse(), build_model_filename=build_model_filename) + +if __name__ == "__main__": + main() + diff --git 
a/stanza/stanza/utils/training/run_sentiment.py b/stanza/stanza/utils/training/run_sentiment.py new file mode 100644 index 0000000000000000000000000000000000000000..faeb9815c98319463380134d79749b7df3b8e843 --- /dev/null +++ b/stanza/stanza/utils/training/run_sentiment.py @@ -0,0 +1,118 @@ +""" +Trains or tests a sentiment model using the classifier package + +The prep script has separate entries for the root-only version of SST, +which is what people typically use to test. When training a model for +SST which uses all the data, the root-only version is used for +dev and test +""" + +import logging +import os + +from stanza.models import classifier +from stanza.utils.training import common +from stanza.utils.training.common import Mode, build_charlm_args, choose_charlm, find_wordvec_pretrain + +from stanza.resources.default_packages import default_charlms, default_pretrains + +logger = logging.getLogger('stanza') + +# TODO: refactor with ner & conparse +def add_sentiment_args(parser): + parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. 
Set to None to turn off charlm for languages with a default charlm') + parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package") + + parser.add_argument('--use_bert', default=False, action="store_true", help='Use the default transformer for this language') + +ALTERNATE_DATASET = { + "en_sst2": "en_sst2roots", + "en_sstplus": "en_sst3roots", +} + +def build_default_args(paths, short_language, dataset, command_args, extra_args): + if '--wordvec_pretrain_file' not in extra_args: + # will throw an error if the pretrain can't be found + wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains) + wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain] + else: + wordvec_args = [] + + charlm = choose_charlm(short_language, dataset, command_args.charlm, default_charlms, {}) + charlm_args = build_charlm_args(short_language, charlm, base_args=False) + + bert_args = common.choose_transformer(short_language, command_args, extra_args) + default_args = wordvec_args + charlm_args + bert_args + + return default_args + +def build_model_filename(paths, short_name, command_args, extra_args): + short_language, dataset = short_name.split("_", 1) + + default_args = build_default_args(paths, short_language, dataset, command_args, extra_args) + + train_args = ["--shorthand", short_name] + train_args = train_args + default_args + if command_args.save_name is not None: + train_args.extend(["--save_name", command_args.save_name]) + if command_args.save_dir is not None: + train_args.extend(["--save_dir", command_args.save_dir]) + args = classifier.parse_args(train_args + extra_args) + save_name = classifier.build_model_filename(args) + return save_name + + +def run_dataset(mode, paths, treebank, short_name, + temp_output_file, command_args, extra_args): + sentiment_dir = paths["SENTIMENT_DATA_DIR"] + short_language, dataset = short_name.split("_", 1) + + train_file = 
os.path.join(sentiment_dir, f"{short_name}.train.json") + + other_name = ALTERNATE_DATASET.get(short_name, short_name) + dev_file = os.path.join(sentiment_dir, f"{other_name}.dev.json") + test_file = os.path.join(sentiment_dir, f"{other_name}.test.json") + + for filename in (train_file, dev_file, test_file): + if not os.path.exists(filename): + raise FileNotFoundError("Cannot find %s" % filename) + + default_args = build_default_args(paths, short_language, dataset, command_args, extra_args) + + if mode == Mode.TRAIN: + train_args = ['--train_file', train_file, + '--dev_file', dev_file, + '--test_file', test_file, + '--shorthand', short_name, + '--wordvec_type', 'word2vec', # TODO: chinese is fasttext + '--extra_wordvec_method', 'SUM'] + train_args = train_args + default_args + extra_args + logger.info("Running train step with args: {}".format(train_args)) + classifier.main(train_args) + + if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: + dev_args = ['--no_train', + '--test_file', dev_file, + '--shorthand', short_name, + '--wordvec_type', 'word2vec'] # TODO: chinese is fasttext + dev_args = dev_args + default_args + extra_args + logger.info("Running dev step with args: {}".format(dev_args)) + classifier.main(dev_args) + + if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: + test_args = ['--no_train', + '--test_file', test_file, + '--shorthand', short_name, + '--wordvec_type', 'word2vec'] # TODO: chinese is fasttext + test_args = test_args + default_args + extra_args + logger.info("Running test step with args: {}".format(test_args)) + classifier.main(test_args) + + + +def main(): + common.main(run_dataset, "classifier", "classifier", add_sentiment_args, classifier.build_argparse(), build_model_filename=build_model_filename) + +if __name__ == "__main__": + main() + diff --git a/stanza/stanza/utils/training/run_tokenizer.py b/stanza/stanza/utils/training/run_tokenizer.py new file mode 100644 index 
0000000000000000000000000000000000000000..7c4245f61dff8a3565185131767b96c1680ed7fd --- /dev/null +++ b/stanza/stanza/utils/training/run_tokenizer.py @@ -0,0 +1,124 @@ +""" +This script allows for training or testing on dev / test of the UD tokenizer. + +If run with a single treebank name, it will train or test that treebank. +If run with ud_all or all_ud, it will iterate over all UD treebanks it can find. + +Mode can be set to train&dev with --train, to dev set only +with --score_dev, and to test set only with --score_test. + +Treebanks are specified as a list. all_ud or ud_all means to look for +all UD treebanks. + +Extra arguments are passed to tokenizer. In case the run script +itself is shadowing arguments, you can specify --extra_args as a +parameter to mark where the tokenizer arguments start. + +Default behavior is to discard the output and just print the results. +To keep the results instead, use --save_output +""" + +import logging +import math +import os + +from stanza.models import tokenizer +from stanza.utils.avg_sent_len import avg_sent_len +from stanza.utils.training import common +from stanza.utils.training.common import Mode + +logger = logging.getLogger('stanza') + +def uses_dictionary(short_language): + """ + Some of the languages (as shown here) have external dictionaries + + We found this helped the overall tokenizer performance + If these can't be found, they can be extracted from the previous iteration of models + """ + if short_language in ('ja', 'th', 'zh', 'zh-hans', 'zh-hant'): + return True + return False + +def run_treebank(mode, paths, treebank, short_name, + temp_output_file, command_args, extra_args): + tokenize_dir = paths["TOKENIZE_DATA_DIR"] + + short_language = short_name.split("_")[0] + label_type = "--label_file" + label_file = f"{tokenize_dir}/{short_name}-ud-train.toklabels" + dev_type = "--txt_file" + dev_file = f"{tokenize_dir}/{short_name}.dev.txt" + test_type = "--txt_file" + test_file = 
f"{tokenize_dir}/{short_name}.test.txt" + train_type = "--txt_file" + train_file = f"{tokenize_dir}/{short_name}.train.txt" + train_dev_args = ["--dev_txt_file", dev_file, "--dev_label_file", f"{tokenize_dir}/{short_name}-ud-dev.toklabels"] + + if short_language == "zh" or short_language.startswith("zh-"): + extra_args = ["--skip_newline"] + extra_args + + train_gold = f"{tokenize_dir}/{short_name}.train.gold.conllu" + dev_gold = f"{tokenize_dir}/{short_name}.dev.gold.conllu" + test_gold = f"{tokenize_dir}/{short_name}.test.gold.conllu" + + train_mwt = f"{tokenize_dir}/{short_name}-ud-train-mwt.json" + dev_mwt = f"{tokenize_dir}/{short_name}-ud-dev-mwt.json" + test_mwt = f"{tokenize_dir}/{short_name}-ud-test-mwt.json" + + train_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.train.pred.conllu" + dev_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.dev.pred.conllu" + test_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.test.pred.conllu" + + if mode == Mode.TRAIN: + seqlen = str(math.ceil(avg_sent_len(label_file) * 3 / 100) * 100) + train_args = ([label_type, label_file, train_type, train_file, "--lang", short_language, + "--max_seqlen", seqlen, "--mwt_json_file", dev_mwt] + + train_dev_args + + ["--dev_conll_gold", dev_gold, "--conll_file", dev_pred, "--shorthand", short_name]) + if uses_dictionary(short_language): + train_args = train_args + ["--use_dictionary"] + train_args = train_args + extra_args + logger.info("Running train step with args: {}".format(train_args)) + tokenizer.main(train_args) + + if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: + dev_args = ["--mode", "predict", dev_type, dev_file, "--lang", short_language, + "--conll_file", dev_pred, "--shorthand", short_name, "--mwt_json_file", dev_mwt] + dev_args = dev_args + extra_args + logger.info("Running dev step with args: {}".format(dev_args)) + tokenizer.main(dev_args) + + # TODO: log these results? 
The original script logged them to + # echo $results $args >> ${TOKENIZE_DATA_DIR}/${short}.results + + results = common.run_eval_script_tokens(dev_gold, dev_pred) + logger.info("Finished running dev set on\n{}\n{}".format(treebank, results)) + + if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: + test_args = ["--mode", "predict", test_type, test_file, "--lang", short_language, + "--conll_file", test_pred, "--shorthand", short_name, "--mwt_json_file", test_mwt] + test_args = test_args + extra_args + logger.info("Running test step with args: {}".format(test_args)) + tokenizer.main(test_args) + + results = common.run_eval_script_tokens(test_gold, test_pred) + logger.info("Finished running test set on\n{}\n{}".format(treebank, results)) + + if mode == Mode.SCORE_TRAIN: + test_args = ["--mode", "predict", test_type, train_file, "--lang", short_language, + "--conll_file", train_pred, "--shorthand", short_name, "--mwt_json_file", train_mwt] + test_args = test_args + extra_args + logger.info("Running test step with args: {}".format(test_args)) + tokenizer.main(test_args) + + results = common.run_eval_script_tokens(train_gold, train_pred) + logger.info("Finished running train set as a test on\n{}\n{}".format(treebank, results)) + + + +def main(): + common.main(run_treebank, "tokenize", "tokenizer", sub_argparse=tokenizer.build_argparse()) + +if __name__ == "__main__": + main() diff --git a/stanza/stanza/utils/training/separate_ner_pretrain.py b/stanza/stanza/utils/training/separate_ner_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9aac5bc5906616843557f5fabea12c5cee04ac --- /dev/null +++ b/stanza/stanza/utils/training/separate_ner_pretrain.py @@ -0,0 +1,215 @@ +""" +Loads NER models & separates out the word vectors to base & delta + +The model will then be resaved without the base word vector, +greatly reducing the size of the model + +This may be useful for any external users of stanza who have an NER +model they wish to reuse without 
retraining + +If you know which pretrain was used to build an NER model, you can +provide that pretrain. Otherwise, you can give a directory of +pretrains and the script will test each one. In the latter case, +the name of the pretrain needs to look like lang_dataset_pretrain.pt +""" + +import argparse +from collections import defaultdict +import logging +import os + +import numpy as np +import torch +import torch.nn as nn + +from stanza import Pipeline +from stanza.models.common.constant import lang_to_langcode +from stanza.models.common.pretrain import Pretrain, PretrainedWordVocab +from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX +from stanza.models.ner.trainer import Trainer + +logger = logging.getLogger('stanza') +logger.setLevel(logging.ERROR) + +DEBUG = False +EPS = 0.0001 + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_path', type=str, default='saved_models/ner', help='Where to find NER models (dir or filename)') + parser.add_argument('--output_path', type=str, default='saved_models/shrunk', help='Where to write shrunk NER models (dir)') + parser.add_argument('--pretrain_path', type=str, default='saved_models/pretrain', help='Where to find pretrains (dir or filename)') + args = parser.parse_args() + + # get list of NER models to shrink + if os.path.isdir(args.input_path): + ner_model_dir = args.input_path + ners = os.listdir(ner_model_dir) + if len(ners) == 0: + raise FileNotFoundError("No ner models found in {}".format(args.input_path)) + else: + if not os.path.isfile(args.input_path): + raise FileNotFoundError("No ner model found at path {}".format(args.input_path)) + ner_model_dir, ners = os.path.split(args.input_path) + ners = [ners] + + # get map from language to candidate pretrains + if os.path.isdir(args.pretrain_path): + pt_model_dir = args.pretrain_path + pretrains = os.listdir(pt_model_dir) + lang_to_pretrain = defaultdict(list) + for pt in pretrains: + lang_to_pretrain[pt.split("_")[0]].append(pt) + 
else: + pt_model_dir, pretrains = os.path.split(pt_model_dir) + pretrains = [pretrains] + lang_to_pretrain = defaultdict(lambda: pretrains) + + # shrunk models will all go in this directory + new_dir = args.output_path + os.makedirs(new_dir, exist_ok=True) + + final_pretrains = [] + missing_pretrains = [] + no_finetune = [] + + # for each model, go through the various pretrains + # until we find one that works or none of them work + for ner_model in ners: + ner_path = os.path.join(ner_model_dir, ner_model) + + expected_ending = "_nertagger.pt" + if not ner_model.endswith(expected_ending): + raise ValueError("Unexpected name: {}".format(ner_model)) + short_name = ner_model[:-len(expected_ending)] + lang, package = short_name.split("_", maxsplit=1) + print("===============================================") + print("Processing lang %s package %s" % (lang, package)) + + # this may look funny - basically, the pipeline has machinery + # to make sure the model has everything it needs to load, + # including downloading other pieces if needed + pipe = Pipeline(lang, processors="tokenize,ner", tokenize_pretokenized=True, package={"ner": package}, ner_model_path=ner_path) + ner_processor = pipe.processors['ner'] + print("Loaded NER processor: {}".format(ner_processor)) + trainer = ner_processor.trainers[0] + vocab = trainer.model.vocab + word_vocab = vocab['word'] + num_vectors = trainer.model.word_emb.weight.shape[0] + + # sanity check, make sure the model loaded matches the + # language from the model's filename + lcode = lang_to_langcode(trainer.args['lang']) + if lang != lcode and not (lcode == 'zh' and lang == 'zh-hans'): + raise ValueError("lang not as expected: {} vs {} ({})".format(lang, trainer.args['lang'], lcode)) + + ner_pretrains = sorted(set(lang_to_pretrain[lang] + lang_to_pretrain[lcode])) + for pt_model in ner_pretrains: + pt_path = os.path.join(pt_model_dir, pt_model) + print("Attempting pretrain: {}".format(pt_path)) + pt = Pretrain(filename=pt_path) + 
print(" pretrain shape: {}".format(pt.emb.shape)) + print(" embedding in ner model shape: {}".format(trainer.model.word_emb.weight.shape)) + if pt.emb.shape[1] != trainer.model.word_emb.weight.shape[1]: + print(" DIMENSION DOES NOT MATCH. SKIPPING") + continue + N = min(pt.emb.shape[0], trainer.model.word_emb.weight.shape[0]) + if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]: + # If the vocab was exactly the same, that's a good + # sign this pretrain was used, just with a different size + # In such a case, we can reuse the rest of the pretrain + # Minor issue: some vectors which were trained will be + # lost in the case of |pt| < |model.word_emb| + if all(word_vocab.id2unit(x) == word_vocab.id2unit(x) for x in range(N)): + print(" Attempting to use pt vectors to replace ner model's vectors") + else: + print(" NUM VECTORS DO NOT MATCH. WORDS DO NOT MATCH. SKIPPING") + continue + if pt.emb.shape[0] < trainer.model.word_emb.weight.shape[0]: + print(" WARNING: if any vectors beyond {} were fine tuned, that fine tuning will be lost".format(N)) + device = next(trainer.model.parameters()).device + delta = trainer.model.word_emb.weight[:N, :] - pt.emb.to(device)[:N, :] + delta = delta.detach() + delta_norms = torch.linalg.norm(delta, dim=1).cpu().numpy() + if np.sum(delta_norms < 0) > 0: + raise ValueError("This should not be - a norm was less than 0!") + num_matching = np.sum(delta_norms < EPS) + if num_matching > N / 2: + print(" Accepted! 
%d of %d vectors match for %s" % (num_matching, N, pt_path)) + if pt.emb.shape[0] != trainer.model.word_emb.weight.shape[0]: + print(" Setting model vocab to match the pretrain") + word_vocab = pt.vocab + vocab['word'] = word_vocab + trainer.args['word_emb_dim'] = pt.emb.shape[1] + break + else: + print(" %d of %d vectors matched for %s - SKIPPING" % (num_matching, N, pt_path)) + vocab_same = sum(x in pt.vocab for x in word_vocab) + print(" %d words were in both vocabs" % vocab_same) + # this is expensive, and in practice doesn't happen, + # but theoretically we might have missed a mostly matching pt + # if the vocab had been scrambled + if DEBUG: + rearranged_count = 0 + for x in word_vocab: + if x not in pt.vocab: + continue + x_id = word_vocab.unit2id(x) + x_vec = trainer.model.word_emb.weight[x_id, :] + pt_id = pt.vocab.unit2id(x) + pt_vec = pt.emb[pt_id, :] + if (x_vec.detach().cpu() - pt_vec).norm() < EPS: + rearranged_count += 1 + print(" %d vectors were close when ignoring id ordering" % rearranged_count) + else: + print("COULD NOT FIND A MATCHING PT: {}".format(ner_processor)) + missing_pretrains.append(ner_model) + continue + + # build a delta vector & embedding + assert 'delta' not in vocab.keys() + delta_vectors = [delta[i].cpu() for i in range(4)] + delta_vocab = [] + for i in range(4, len(delta_norms)): + if delta_norms[i] > 0.0: + delta_vocab.append(word_vocab.id2unit(i)) + delta_vectors.append(delta[i].cpu()) + + trainer.model.unsaved_modules.append("word_emb") + if len(delta_vocab) == 0: + print("No vectors were changed! 
Perhaps this model was trained without finetune.") + no_finetune.append(ner_model) + else: + print("%d delta vocab" % len(delta_vocab)) + print("%d vectors in the delta set" % len(delta_vectors)) + delta_vectors = np.stack(delta_vectors) + delta_vectors = torch.from_numpy(delta_vectors) + assert delta_vectors.shape[0] == len(delta_vocab) + len(VOCAB_PREFIX) + print(delta_vectors.shape) + + delta_vocab = PretrainedWordVocab(delta_vocab, lang=word_vocab.lang, lower=word_vocab.lower) + vocab['delta'] = delta_vocab + trainer.model.delta_emb = nn.Embedding(delta_vectors.shape[0], delta_vectors.shape[1], PAD_ID) + trainer.model.delta_emb.weight.data.copy_(delta_vectors) + + new_path = os.path.join(new_dir, ner_model) + trainer.save(new_path) + + final_pretrains.append((ner_model, pt_model)) + + print() + if len(final_pretrains) > 0: + print("Final pretrain mappings:") + for i in final_pretrains: + print(i) + if len(missing_pretrains) > 0: + print("MISSING EMBEDDINGS:") + for i in missing_pretrains: + print(i) + if len(no_finetune) > 0: + print("NOT FINE TUNED:") + for i in no_finetune: + print(i) + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/visualization/__init__.py b/stanza/stanza/utils/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stanza/stanza/utils/visualization/conll_deprel_visualization.py b/stanza/stanza/utils/visualization/conll_deprel_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..e5f1e98070a2b872a6ce9eab1bf9d2d701dd6f85 --- /dev/null +++ b/stanza/stanza/utils/visualization/conll_deprel_visualization.py @@ -0,0 +1,83 @@ +from stanza.models.common.constant import is_right_to_left +import spacy +import argparse +from spacy import displacy +from spacy.tokens import Doc +from stanza.utils import conll +from stanza.utils.visualization import dependency_visualization as viz + + +def 
conll_to_visual(conll_file, pipeline, sent_count=10, display_all=False): + """ + Takes in a conll file and visualizes it by converting the conll file to a Stanza Document object + and visualizing it with the visualize_doc method. + + Input should be a proper conll file. + + The pipeline for the conll file to be processed in must be provided as well. + + Optionally, the sent_count argument can be tweaked to display a different amount of sentences. + + To display all of the sentences in a conll file, the display_all argument can optionally be set to True. + BEWARE: setting this argument for a large conll file may result in too many renderings, resulting in a crash. + """ + # convert conll file to doc + doc = conll.CoNLL.conll2doc(conll_file) + + if display_all: + viz.visualize_doc(conll.CoNLL.conll2doc(conll_file), pipeline) + else: # visualize a given number of sentences + visualization_options = {"compact": True, "bg": "#09a3d5", "color": "white", "distance": 100, + "font": "Source Sans Pro", "offset_x": 30, + "arrow_spacing": 20} # see spaCy visualization settings doc for more options + nlp = spacy.blank("en") + sentences_to_visualize, rtl, num_sentences = [], is_right_to_left(pipeline), len(doc.sentences) + + for i in range(sent_count): + if i >= num_sentences: # case where there are less sentences than amount requested + break + sentence = doc.sentences[i] + words, lemmas, heads, deps, tags = [], [], [], [], [] + sentence_words = sentence.words + if rtl: # rtl languages will be visually rendered from right to left as well + sentence_words = reversed(sentence.words) + sent_len = len(sentence.words) + for word in sentence_words: + words.append(word.text) + lemmas.append(word.lemma) + deps.append(word.deprel) + tags.append(word.upos) + if rtl and word.head == 0: # word heads are off-by-1 in spaCy doc inits compared to Stanza + heads.append(sent_len - word.id) + elif rtl and word.head != 0: + heads.append(sent_len - word.head) + elif not rtl and word.head == 0: + 
heads.append(word.id - 1) + elif not rtl and word.head != 0: + heads.append(word.head - 1) + + document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags) + sentences_to_visualize.append(document_result) + + print(sentences_to_visualize) + for line in sentences_to_visualize: # render all sentences through displaCy + displacy.render(line, style="dep", options=visualization_options) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--conll_file', type=str, + default="C:\\Users\\Alex\\stanza\\demo\\en_test.conllu.txt", + help="File path of the CoNLL file to visualize dependencies of") + parser.add_argument('--pipeline', type=str, default="en", + help="Language code of the language pipeline to use (ex: 'en' for English)") + parser.add_argument('--sent_count', type=int, default=10, help="Number of sentences to visualize from CoNLL file") + parser.add_argument('--display_all', type=bool, default=False, + help="Whether or not to visualize all of the sentences from the file. 
"""
Visualize named entities from different texts and Stanza documents (+ CoNLL files)
"""

from spacy import displacy
from spacy.tokens import Doc
from spacy.tokens import Span
from stanza.models.common.constant import is_right_to_left
import stanza
import spacy
import copy


def visualize_ner_doc(doc, language, select=None, colors=None):
    """
    Takes a stanza doc object and language pipeline and visualizes the named entities within it.

    Stanza currently supports a limited amount of languages for NER, which you can view here:
    https://stanfordnlp.github.io/stanza/ner_models.html

    To view only a specific type(s) of named entities, set the optional 'select' argument to
    a list of the named entity types. Ex: select=["PER", "ORG", "GPE"] to only see entities tagged as Person(s),
    Organizations, and Geo-political entities. A full list of the available types can be found here:
    https://stanfordnlp.github.io/stanza/ner_models.html (ctrl + F "The following table").

    The colors argument is formatted as a dictionary of NER tags with their corresponding colors, which can be
    represented as a string (ex: "blue"), a color hex value (ex: #aa9cfc), or as a linear gradient of color
    values (ex: "linear-gradient(90deg, #aa9cfc, #fc9ce7)").
    """
    # spaCy is only used for its displaCy visualizer; the blank model carries no components.
    model, documents, visualization_colors = spacy.blank('en'), [], copy.deepcopy(colors)
    # RTL_OVERRIDE is U+202E (RIGHT-TO-LEFT OVERRIDE); displaCy renders LTR, so RTL text
    # and labels are pre-flipped and prefixed with this char to display correctly.
    sentences, rtl, RTL_OVERRIDE = doc.sentences, is_right_to_left(language), "\u202e"
    if rtl:  # need to flip order of all the sentences in rendered display
        sentences = reversed(doc.sentences)
        # Adjust colors to be in LTR flipped format due to the RLO unicode char flipping words.
        # FIX: iterate over a snapshot of the keys -- the original code popped and inserted
        # keys while iterating the dict itself, which is undefined behavior in Python and
        # can skip or revisit keys after an internal rehash.
        if colors:
            for color in list(visualization_colors):
                if RTL_OVERRIDE not in color:
                    clr_val = visualization_colors.pop(color)
                    visualization_colors[RTL_OVERRIDE + color[::-1]] = clr_val
    for sentence in sentences:
        words, display_ents, already_found = [], [], False
        # initialize doc object with words first
        for i, word in enumerate(sentence.words):
            if rtl and word.text.isascii() and not already_found:
                # Collect a run of ASCII (non-RTL-script) words. Each is reversed once here
                # and once more by the RLO char at render time: two flips -> original order.
                to_append = [word.text[::-1]]
                next_word_index = i + 1
                while next_word_index <= len(sentence.words) - 1 and sentence.words[next_word_index].text.isascii():
                    to_append.append(sentence.words[next_word_index].text[::-1])
                    next_word_index += 1
                to_append = reversed(to_append)
                for token in to_append:
                    words.append(token)
                already_found = True
            elif rtl and word.text.isascii() and already_found:  # skip over already collected words
                continue
            else:  # RTL-script (e.g. Arabic) chars pass through unchanged
                words.append(word.text)
                already_found = False

        document = Doc(model.vocab, words=words)

        # tag all NER tokens found
        for ent in sentence.ents:
            if select and ent.type not in select:
                continue
            found_indexes = []
            for token in ent.tokens:
                # stanza token ids are 1-based; spaCy span indexes are 0-based
                found_indexes.append(token.id[0] - 1)
            if not rtl:
                to_add = Span(document, found_indexes[0], found_indexes[-1] + 1, ent.type)
            else:  # RTL languages need the override char to flip label order
                to_add = Span(document, found_indexes[0], found_indexes[-1] + 1, RTL_OVERRIDE + ent.type[::-1])
            display_ents.append(to_add)
        document.set_ents(display_ents)
        documents.append(document)

    # Visualize doc objects
    visualization_options = {"ents": select}
    if colors:
        visualization_options["colors"] = visualization_colors
    for document in documents:
        displacy.render(document, style='ent', options=visualization_options)
+ """ + doc = pipe(text) + visualize_ner_doc(doc, pipe.lang, select, colors) + + +def visualize_strings(texts, language_code, select=None, colors=None): + """ + Takes in a list of strings and a language code (Stanza defines these, ex: 'en' for English) to visualize all + of the strings' named entities. + + The strings are processed by the Stanza pipeline and the named entities are displayed. Each text is separated by a delimiting line. + + Optionally, the 'select' argument may be configured to only visualize given named entities (ex: select=['ORG', 'PERSON']). + + The optional colors argument is formatted as a dictionary of NER tags with their corresponding colors, which can be + represented as a string (ex: "blue"), a color hex value (ex: #aa9cfc), or as a linear gradient of color + values (ex: "linear-gradient(90deg, #aa9cfc, #fc9ce7)"). + """ + lang_pipe = stanza.Pipeline(language_code, processors="tokenize,ner") + + for text in texts: + visualize_ner_str(text, lang_pipe, select=select, colors=colors) + + +def visualize_docs(docs, language_code, select=None, colors=None): + """ + Takes in a list of doc and a language code (Stanza defines these, ex: 'en' for English) to visualize all + of the strings' named entities. + + Each text is separated by a delimiting line. + + Optionally, the 'select' argument may be configured to only visualize given named entities (ex: select=['ORG', 'PERSON']). + + The optional colors argument is formatted as a dictionary of NER tags with their corresponding colors, which can be + represented as a string (ex: "blue"), a color hex value (ex: #aa9cfc), or as a linear gradient of color + values (ex: "linear-gradient(90deg, #aa9cfc, #fc9ce7)"). + """ + for doc in docs: + visualize_ner_doc(doc, language_code, select=select, colors=colors) + + +def main(): + en_strings = ['''Samuel Jackson, a Christian man from Utah, went to the JFK Airport for a flight to New York. 
+ He was thinking of attending the US Open, his favorite tennis tournament besides Wimbledon. + That would be a dream trip, certainly not possible since it is $5000 attendance and 5000 miles away. + On the way there, he watched the Super Bowl for 2 hours and read War and Piece by Tolstoy for 1 hour. + In New York, he crossed the Brooklyn Bridge and listened to the 5th symphony of Beethoven as well as + "All I want for Christmas is You" by Mariah Carey.''', + "Barack Obama was born in Hawaii. He was elected President of the United States in 2008"] + zh_strings = ['''来自犹他州的基督徒塞缪尔杰克逊前往肯尼迪机场搭乘航班飞往纽约。 + 他正在考虑参加美国公开赛,这是除了温布尔登之外他最喜欢的网球赛事。 + 那将是一次梦想之旅,当然不可能,因为它的出勤费为 5000 美元,距离 5000 英里。 + 在去的路上,他看了 2 个小时的超级碗比赛,看了 1 个小时的托尔斯泰的《战争与碎片》。 + 在纽约,他穿过布鲁克林大桥,聆听了贝多芬的第五交响曲以及 玛丽亚凯莉的“圣诞节我想要的就是你”。''', + "我觉得罗家费德勒住在加州, 在美国里面。"] + ar_strings = [ + ".أعيش في سان فرانسيسكو ، كاليفورنيا. اسمي أليكس وأنا ألتحق بجامعة ستانفورد. أنا أدرس علوم الكمبيوتر وأستاذي هو كريس مانينغ" + , "اسمي أليكس ، أنا من الولايات المتحدة.", + '''صامويل جاكسون ، رجل مسيحي من ولاية يوتا ، ذهب إلى مطار جون كنيدي في رحلة إلى نيويورك. كان يفكر في حضور بطولة الولايات المتحدة المفتوحة للتنس ، بطولة التنس المفضلة لديه إلى جانب بطولة ويمبلدون. ستكون هذه رحلة الأحلام ، وبالتأكيد ليست ممكنة لأنها تبلغ 5000 دولار للحضور و 5000 ميل. في الطريق إلى هناك ، شاهد Super Bowl لمدة ساعتين وقرأ War and Piece by Tolstoy لمدة ساعة واحدة. 
في نيويورك ، عبر جسر بروكلين واستمع إلى السيمفونية الخامسة لبيتهوفن وكذلك "كل ما أريده في عيد الميلاد هو أنت" لماريا كاري.'''] + + visualize_strings(en_strings, "en") + visualize_strings(zh_strings, "zh", colors={"PERSON": "yellow", "DATE": "red", "GPE": "blue"}) + visualize_strings(zh_strings, "zh", select=['PERSON', 'DATE']) + visualize_strings(ar_strings, "ar", + colors={"PER": "pink", "LOC": "linear-gradient(90deg, #aa9cfc, #fc9ce7)", "ORG": "yellow"}) + + +if __name__ == "__main__": + main() diff --git a/stanza/stanza/utils/visualization/semgrex_app.py b/stanza/stanza/utils/visualization/semgrex_app.py new file mode 100644 index 0000000000000000000000000000000000000000..5d861040f8f49aa9101c97aa08f7f4845d3ef4a3 --- /dev/null +++ b/stanza/stanza/utils/visualization/semgrex_app.py @@ -0,0 +1,393 @@ +import os +import sys +import streamlit as st +import streamlit.components.v1 as components +import stanza.utils.visualization.ssurgeon_visualizer as ssv +import logging + +from stanza.utils.visualization.semgrex_visualizer import visualize_search_str +from stanza.utils.visualization.semgrex_visualizer import edit_html_overflow +from stanza.utils.visualization.constants import * +from stanza.utils.conll import CoNLL +from stanza.server.ssurgeon import * +from stanza.pipeline.core import Pipeline + +from io import StringIO +import os +from typing import List, Tuple, Any +import argparse + + +def get_semgrex_text_and_query() -> Tuple[str, str]: + """ + Gets user input for the Semgrex text and queries to process. 
def get_semgrex_text_and_query() -> Tuple[str, str]:
    """
    Gets user input for the Semgrex text and queries to process.

    @return: A tuple containing the user's input text and their input queries
    """
    input_txt = st.text_area(
        "Text to analyze",
        DEFAULT_SAMPLE_TEXT,
        placeholder=DEFAULT_SAMPLE_TEXT,
    )
    input_queries = st.text_area(
        "Semgrex search queries (separate each query with a comma)",
        DEFAULT_SEMGREX_QUERY,
        placeholder=DEFAULT_SEMGREX_QUERY,
    )
    return input_txt, input_queries


def get_file_input() -> List[str]:
    """
    Allows user to submit files for analysis.

    @return: List of strings containing the file contents of each submitted file. The i-th element of res is the
    string representing the i-th file uploaded.
    """
    st.markdown("""**Alternatively, upload file(s) to analyze.**""")
    uploaded_files = st.file_uploader(
        "button_label", accept_multiple_files=True, label_visibility="collapsed"
    )
    res = []
    for file in uploaded_files:
        # uploaded files arrive as raw bytes; decode to text for downstream processing
        stringio = StringIO(file.getvalue().decode("utf-8"))
        string_data = stringio.read()
        res.append(string_data)
    return res


def get_semgrex_window_input() -> Tuple[bool, int, int]:
    """
    Allows user to specify a specific window of Semgrex hits to visualize. Works similar to Python splicing.

    @return: A tuple containing a bool representing whether or not the user wants to visualize a splice of
    the visualizations, and two ints representing the start and end indices of the splice
    (both are None when no window is requested).
    """
    show_window = st.checkbox(
        "Visualize a specific window of Semgrex search hits?",
        help="""If you want to visualize all search results, leave this unmarked.""",
    )
    start_window, end_window = None, None
    if show_window:
        start_window = st.number_input(
            "Which search hit should visualizations start from?",
            help="""If you want to visualize the first 10 search results, set this to 0.""",
            min_value=0,
        )
        end_window = st.number_input(
            "Which search hit should visualizations stop on?",
            help="""If you want to visualize the first 10 search results, set this to 11.
            The 11th result will NOT be displayed.""",
            value=11,
            min_value=start_window + 1,
        )
    return show_window, start_window, end_window


def get_pos_input() -> bool:
    """
    Prompts client for whether they want to see xpos tags instead of upos.
    """
    use_xpos = st.checkbox(
        "Would you like to visualize xpos tags?",
        help="The default visualization options use upos tags for part-of-speech labeling. If xpos tags aren't available for the sentence, displays upos.",
    )
    return use_xpos


def get_input() -> Tuple[str, str, List[str], Tuple[bool, int, int], bool]:
    """
    Tie together all inputs to query user for all possible inputs.

    FIX: the return annotation previously declared a 4-element tuple with the xpos flag
    folded into the window tuple; the function actually returns FIVE values — the window
    tuple has 3 elements and the xpos flag is a separate trailing bool.
    """
    input_txt, input_queries = get_semgrex_text_and_query()
    client_files = get_file_input()  # this is already converted to string format
    window_input = get_semgrex_window_input()
    visualize_xpos = get_pos_input()
    return input_txt, input_queries, client_files, window_input, visualize_xpos


def run_semgrex_process(
    input_txt: str,
    input_queries: str,
    client_files: List[str],
    show_window: bool,
    clicked: bool,
    pipe: Any,
    start_window: int,
    end_window: int,
    visualize_xpos: bool,
    show_success: bool = True
) -> None:
    """
    Run Semgrex search on the input text/files with input query and serve the HTML on the app.

    @param input_txt: Text to analyze and draw sentences from.
    @param input_queries: Semgrex queries to parse the input with (comma separated string).
    @param client_files: Alternative to input text, we can parse the content of files for scaled analysis.
    @param show_window: Whether or not the user wants a splice of the visualizations
    @param clicked: Whether or not the button has been clicked to run Semgrex search
    @param pipe: NLP pipeline to process input with
    @param start_window: If displaying a splice of visualizations, this is the start idx
    @param end_window: If displaying a splice of visualizations, this is the end idx
    @param visualize_xpos: Set to true if using xpos tags for part of speech labels, otherwise use upos tags
    @param show_success: Whether to show the st.success banner after rendering
    """

    if clicked:

        # process inputs, reject bad ones
        if not input_txt and not client_files:
            st.error("Please provide a text input or upload files for analysis.")
        elif input_txt and client_files:
            st.error(
                "Please only choose to visualize your input text or your uploaded files, not both."
            )
        elif not input_queries:
            st.error("Please provide a set of Semgrex queries.")
        else:  # no input errors
            try:
                with st.spinner("Processing..."):
                    queries = [
                        query.strip() for query in input_queries.split(",")
                    ]  # separate queries into individual parts
                    if client_files:
                        html_strings, begin_viz_idx, end_viz_idx = [], 0, float("inf")
                        if show_window:
                            # NOTE(review): start_window has min_value=0 in the UI, so
                            # begin_viz_idx can be -1 here — confirm this is intended.
                            begin_viz_idx, end_viz_idx = (
                                start_window - 1,
                                end_window - 1,
                            )
                        for client_file in client_files:
                            client_file_html_strings = visualize_search_str(
                                client_file,
                                queries,
                                "en",
                                start_match=begin_viz_idx,
                                end_match=end_viz_idx,
                                pipe=pipe,
                                visualize_xpos=visualize_xpos
                            )
                            html_strings += client_file_html_strings
                    else:  # just input text, no files
                        if show_window:
                            html_strings = visualize_search_str(
                                input_txt,
                                queries,
                                "en",
                                start_match=start_window - 1,
                                end_match=end_window - 1,
                                pipe=pipe,
                                visualize_xpos=visualize_xpos
                            )
                        else:
                            html_strings = visualize_search_str(
                                input_txt,
                                queries,
                                "en",
                                end_match=float("inf"),
                                pipe=pipe,
                                visualize_xpos=visualize_xpos
                            )

                    if len(html_strings) == 0:
                        st.write("No Semgrex match hits!")

                    # Render successful Semgrex results
                    for s in html_strings:
                        s_no_overflow = edit_html_overflow(s)
                        components.html(
                            s_no_overflow, height=200, width=1000, scrolling=True
                        )
                    if show_success:
                        if len(html_strings) == 1:
                            st.success(
                                f"Completed! Visualized {len(html_strings)} Semgrex search hit."
                            )
                        else:
                            st.success(
                                f"Completed! Visualized {len(html_strings)} Semgrex search hits."
                            )
            except OSError:
                # raised by the CoreNLP bridge when the input or queries cannot be processed
                st.error(
                    "Your text input or your provided Semgrex queries are incorrect. Please try again."
                )

Enter a text below, along with your Semgrex query of choice.

" + ) + st.markdown(html_string, unsafe_allow_html=True) + input_txt, input_queries, client_files, window_input, visualize_xpos = get_input() + + show_window, start_window, end_window = window_input + + clicked = st.button( + "Load Semgrex search visualization", + help="""Semgrex search visualizations only display + sentences with a query match. Non-matching sentences are not shown.""", + ) # use the on_click param + + run_semgrex_process( + input_txt=input_txt, + input_queries=input_queries, + client_files=client_files, + show_window=show_window, + clicked=clicked, + pipe=st.session_state["pipeline"], + start_window=start_window, + end_window=end_window, + visualize_xpos=visualize_xpos + ) + + +def ssurgeon_state(): + """ + Contains the ssurgeon state for the webpage. + + This contains the markdown and calls the processes that run Ssurgeon operations. + + When the text boxes, buttons, or other interactable features are edited by the user, this function + runs with the updated page state and conducts operations (e.g. 
runs a Ssurgeon operation on a submitted file) + """ + + st.title("Displaying Ssurgeon Results") + + # Textbox for input to SSurgeon (text) + input_txt = st.text_area( + "Text to analyze", + SAMPLE_SSURGEON_DOC, + placeholder=SAMPLE_SSURGEON_DOC, + ) + + # Textbox for input queries to SSurgeon (commands + queries) + semgrex_input_queries = st.text_area( + "Semgrex search queries (separate each query with a comma)", + "{}=source >nsubj {} >csubj=bad {}", + placeholder="""{}=source >nsubj {} >csubj=bad {}""", + ) + ssurgeon_input_queries = st.text_area( + "Ssurgeon commands", + "relabelNamedEdge -edge bad -reln advcl", + placeholder="relabelNamedEdge -edge bad -reln advcl" + ) + + # File uploading box + st.markdown("""**Alternatively, upload file(s) to edit.**""") + uploaded_files = st.file_uploader( + "", accept_multiple_files=True, label_visibility="collapsed" + ) + res = [] + # Convert uploaded files to strings for processing + for file in uploaded_files: + stringio = StringIO(file.getvalue().decode("utf-8")) + string_data = stringio.read() + res.append(string_data) + + # Input button to trigger processing phase + clicked = st.button( + "Load Ssurgeon visualization", + ) + clicked_for_file_edit = st.button( + "Edit File" + ) + # Once the user requests the Ssurgeon operation, run this block: + if clicked: + try: + with st.spinner("Processing..."): + semgrex_queries = semgrex_input_queries # separate queries into individual parts + ssurgeon_queries = [ssurgeon_input_queries] + + # use SSurgeon to edit the deprel, get the HTML for new relations + html_strings = ssv.visualize_ssurgeon_deprel_adjusted_str_input(input_txt, semgrex_queries, ssurgeon_queries) + doc = CoNLL.conll2doc(input_str=input_txt) + string_txt = " ".join([word.text for sentence in doc.sentences for word in sentence.words]) + + # Render pre-edited input + html_string = ( + "

Previous deprel visualization:

" + ) + st.markdown(html_string, unsafe_allow_html=True) + components.html( + run_semgrex_process(input_txt=string_txt, input_queries=semgrex_queries, clicked=clicked, + show_window=False, client_files=[], pipe=st.session_state["pipeline"], + start_window=1, end_window=11, visualize_xpos=False, show_success=False) + ) + + if len(html_strings) == 0: + st.write("No Semgrex match hits!") + + # Render edited outputs + for s in html_strings: + html_string = ( + "

Edited deprel visualization:

" + ) + st.markdown(html_string, unsafe_allow_html=True) + s_no_overflow = edit_html_overflow(s) + components.html( + s_no_overflow, height=200, width=1000, scrolling=True + ) + except OSError: + st.error( + "Your text input or your provided Semgrex/Ssurgeon queries are incorrect. Please try again." + ) + # If the input is a file instead of raw text, process the file with Ssurgeon and give an output + # that can be downloaded by the client + if clicked_for_file_edit: + # files are in res + if len(res) == 0: + st.error("You must provide files for analysis.") + with st.spinner("Editing..."): + single_file = res[0] + doc = CoNLL.conll2doc(input_str=single_file) + ssurgeon_response = process_doc_one_operation(doc, semgrex_input_queries, [ssurgeon_input_queries]) + updated_doc = convert_response_to_doc(doc, ssurgeon_response) + output = CoNLL.doc2conll(updated_doc)[0] + output_str = "\n".join(output) + st.download_button("Download your edited file", data=output_str, file_name="SSurgeon.conll") + + +def main(): + + parser = argparse.ArgumentParser() + parser.add_argument( + "--CLASSPATH", + type=str, + default=os.environ.get("CLASSPATH", None), + help="Path to your CoreNLP directory.", + ) # for example, set $CLASSPATH to "C:\\stanford-corenlp-4.5.2\\stanford-corenlp-4.5.2\\*" + args = parser.parse_args() + + CLASSPATH = args.CLASSPATH + os.environ["CLASSPATH"] = CLASSPATH + + if os.environ.get("CLASSPATH") is None: + logging.error("Provide a valid $CLASSPATH value (path to your CoreNLP installation).") + raise ValueError("Provide a valid $CLASSPATH value (path to your CoreNLP installation).") + + # run pipeline once per user session + if "pipeline" not in st.session_state: + en_nlp_stanza = Pipeline( + "en", processors="tokenize, pos, lemma, depparse" + ) + st.session_state["pipeline"] = en_nlp_stanza + + #### Below is the webpage states that run. 
Streamlit operates by having the rendered HTML and when the user interacts with + # the page, these states are run once more with their internal states possibly altered (e.g. user clicks a button). + + semgrex_state() + ssurgeon_state() + +if __name__ == "__main__": + main() diff --git a/stanza/stanza/utils/visualization/semgrex_visualizer.py b/stanza/stanza/utils/visualization/semgrex_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6ddf6e56782bfc5fb0693946c550304d488d8683 --- /dev/null +++ b/stanza/stanza/utils/visualization/semgrex_visualizer.py @@ -0,0 +1,608 @@ +import os +import argparse +import sys + +root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +sys.path.append(root_dir) + +from stanza.pipeline.core import Pipeline +from stanza.server.semgrex import Semgrex +from stanza.models.common.constant import is_right_to_left +import spacy +from spacy import displacy +from spacy.tokens import Doc +from IPython.display import display, HTML +import typing +from typing import List, Tuple, Any + +from stanza.utils.visualization.utils import find_nth, round_base + + +def get_sentences_html(doc: Any, language: str, visualize_xpos: bool = False) -> List[str]: + """ + Returns a list of HTML strings representing the dependency visualizations of a given stanza document. + One HTML string is generated per sentence of the document object. Converts the stanza document object + to a spaCy doc object and generates HTML with displaCy. + + @param doc: a stanza document object which can be generated with an NLP pipeline. + @param language: the two letter language code for the document e.g. "en" for English. + @param visualize_xpos: A toggled option to use xpos tags for part-of-speech labels instead of upos. + + @return: a list of HTML strings which visualize the dependencies of the doc object. 
+ """ + USE_FINE_GRAINED = False if not visualize_xpos else True + html_strings, sentences_to_visualize = [], [] + nlp = spacy.blank( + "en" + ) # blank model - we don't use any of the model features, just the visualization + for sentence in doc.sentences: + words, lemmas, heads, deps, tags = [], [], [], [], [] + if is_right_to_left( + language + ): # order of words displayed is reversed, dependency arcs remain intact + sentence_len = len(sentence.words) + for word in reversed(sentence.words): + words.append(word.text) + lemmas.append(word.lemma) + deps.append(word.deprel) + if visualize_xpos and word.xpos: + tags.append(word.xpos) + else: + tags.append(word.upos) + if word.head == 0: # spaCy head indexes are one-off from Stanza's + heads.append(sentence_len - word.id) + else: + heads.append(sentence_len - word.head) + else: # left to right rendering + for word in sentence.words: + words.append(word.text) + lemmas.append(word.lemma) + deps.append(word.deprel) + if visualize_xpos and word.xpos: + tags.append(word.xpos) + else: + tags.append(word.upos) + if word.head == 0: + heads.append(word.id - 1) + else: + heads.append(word.head - 1) + if USE_FINE_GRAINED: + stanza_to_spacy_doc = Doc( + nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, tags=tags + ) + else: + stanza_to_spacy_doc = Doc( + nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags + ) + sentences_to_visualize.append(stanza_to_spacy_doc) + + for line in sentences_to_visualize: # render all sentences through displaCy + html_strings.append( + displacy.render( + line, + style="dep", + options={ + "compact": True, + "word_spacing": 30, + "distance": 100, + "arrow_spacing": 20, + "fine_grained": USE_FINE_GRAINED + }, + jupyter=False, + ) + ) + return html_strings + + +def semgrexify_html(orig_html: str, semgrex_sentence) -> str: + """ + Modifies the HTML of a sentence's dependency visualization, highlighting words involved in the + semgrex_sentence search queries and adding the 
label of the word inside of the match. + + + @param orig_html: unedited HTML of a sentence's dependency visualization. + @param semgrex_sentence: a Semgrex result object containing the matches to a provided query. + @return: edited HTML containing the visual changes described above. + """ + tracker = {} # keep track of which words have multiple labels + DEFAULT_TSPAN_COUNT = ( + 2 # the original displacy html assigns two objects per object + ) + CLOSING_TSPAN_LEN = 8 # is 8 chars long + colors = [ + "#4477AA", + "#66CCEE", + "#228833", + "#CCBB44", + "#EE6677", + "#AA3377", + "#BBBBBB", + ] # colorblind-friendly scheme + css_bolded_class = "\n" + opening_svg_end_idx = orig_html.find("\n") + # insert the new style class + orig_html = ( + orig_html[: opening_svg_end_idx + 1] + + css_bolded_class + + orig_html[opening_svg_end_idx + 1 :] + ) + + # Color and bold words involved in each Semgrex match + for query in semgrex_sentence.result: + for i, match in enumerate(query.match): + color = colors[i] + paired_dy = 2 + for node in match.node: + name, match_index = node.name, node.matchIndex + # edit existing to change color and bold the text + start = find_nth( + orig_html, " of interest + if ( + match_index not in tracker + ): # if we've already bolded and colored, keep the first color + tspan_start = orig_html.find( + " inside of the + tspan_end = orig_html.find( + "", start + ) # finds start of the end of the above + tspan_substr = ( + orig_html[tspan_start : tspan_end + CLOSING_TSPAN_LEN + 1] + + "\n" + ) + # color and bold words in the search hit + edited_tspan = tspan_substr.replace( + 'class="displacy-word"', 'class="bolded"' + ).replace('fill="currentColor"', f'fill="{color}"') + # insert edited object into html string + + # TODO: DEBUG. This code has a bug in it that causes the svg to not end on an input like + # "The Wimbledon grass-court tennis tournament banned players, resulting in players hating others." 
+ # to malfunction and add another copy to the tail-end of the first svg rendering. + # This bug has been patched in the end of this function, but need to find out what is going on. + orig_html = ( + orig_html[:tspan_start] + + edited_tspan + + orig_html[tspan_end + CLOSING_TSPAN_LEN + 2 :] + ) + + tracker[match_index] = DEFAULT_TSPAN_COUNT + + # next, we have to insert the new object for the label + # Copy old to copy formatting when creating new later + prev_tspan_start = ( + find_nth(orig_html[start:], " start index + prev_tspan_end = ( + find_nth(orig_html[start:], "", tracker[match_index] - 1) + + start + ) # find the prev start index + prev_tspan = orig_html[ + prev_tspan_start : prev_tspan_end + CLOSING_TSPAN_LEN + 1 + ] + + # Find spot to insert new tspan + closing_tspan_start = ( + find_nth(orig_html[start:], "", tracker[match_index]) + + start + ) + up_to_new_tspan = orig_html[ + : closing_tspan_start + CLOSING_TSPAN_LEN + 1 + ] + rest = orig_html[closing_tspan_start + CLOSING_TSPAN_LEN + 1 :] + + # Calculate proper x value in svg + x_value_start = prev_tspan.find('x="') + x_value_end = ( + prev_tspan[x_value_start + 3 :].find('"') + 3 + ) # 3 is the length of the 'x="' substring + x_value = prev_tspan[x_value_start + 3 : x_value_end + x_value_start] + + # Calculate proper y value in svg + DEFAULT_DY_VAL, dy = 2, 2 + if ( + paired_dy != DEFAULT_DY_VAL and node == match.node[1] + ): # we're on the second node and need to adjust height to match the paired node + dy = paired_dy + if node == match.node[0] and len(match.node) > 1: + paired_node_level = 2 + if ( + match.node[1].matchIndex in tracker + ): # check if we need to adjust heights of labels + paired_node_level = tracker[match.node[1].matchIndex] + dif = tracker[match_index] - paired_node_level + if dif > 0: # current node has more labels + paired_dy = DEFAULT_DY_VAL * dif + 1 + dy = DEFAULT_DY_VAL + else: # paired node has more labels, adjust this label down + dy = DEFAULT_DY_VAL * (abs(dif) + 1) + 
def render_html_strings(edited_html_strings: List[str]) -> None:
    """
    Display every HTML fragment in the active IPython/Jupyter output.
    """
    for html_fragment in edited_html_strings:
        display(HTML(html_fragment))


def visualize_search_doc(
    doc: Any,
    semgrex_queries: List[str],
    lang_code: str,
    start_match: int = 0,
    end_match: int = 11,
    render: bool = True,
    visualize_xpos: bool = False
) -> List[str]:
    """
    Visualizes the result of running Semgrex search on a document.

    The i-th element of the returned list is the HTML representation of the i-th
    matching sentence's dependency relationships; sentences without a Semgrex hit
    are skipped entirely.

    @param doc: A Stanza document object that contains dependency relationships.
    @param semgrex_queries: A list of Semgrex queries to search for in the document.
    @param lang_code: A two letter language abbreviation for the language of the document.
    @param start_match: Beginning of the splice of matching sentences to display.
    @param end_match: End of the splice of matching sentences to display (exclusive).
    @param render: When True, also render the HTML strings as they are collected.
    @param visualize_xpos: Use xpos tags for part-of-speech labels instead of upos.

    @return: A list of HTML strings representing the dependency relations of the doc object.
    """
    hit_count = 0  # number of matching sentences seen so far; bounds the visualization window
    with Semgrex(classpath="$CLASSPATH") as sem:
        rendered_strings = []
        semgrex_results = sem.process(doc, *semgrex_queries)
        # one html string per sentence, match or not
        raw_html_strings = get_sentences_html(doc, lang_code, visualize_xpos=visualize_xpos)

        for sent_idx, sentence_html in enumerate(raw_html_strings):

            if hit_count >= end_match:  # the window is full, stop scanning
                break

            # a sentence is visualized only when at least one query produced a match
            sentence_result = semgrex_results.result[sent_idx]
            has_match = any(
                match
                for query in sentence_result.result
                for match in query.match
            )

            if has_match:
                if start_match <= hit_count < end_match:
                    highlighted = semgrexify_html(sentence_html, sentence_result)
                    highlighted = adjust_dep_arrows(highlighted)
                    rendered_strings.append(highlighted)
                hit_count += 1
        if render:
            render_html_strings(rendered_strings)
        return rendered_strings
def visualize_search_str(
    text: str,
    semgrex_queries: List[str],
    lang_code: str,
    start_match: int = 0,
    end_match: int = 11,
    pipe=None,
    render: bool = True,
    visualize_xpos: bool = False
):
    """
    Visualizes the result of running Semgrex search on a raw string.

    The string is processed into a Stanza document (building a fresh pipeline when
    `pipe` is not supplied) and delegated to visualize_search_doc. Only sentences
    with a Semgrex match are shown.

    @param text: The string on which Semgrex search will be run.
    @param semgrex_queries: A list of Semgrex queries to search for in the document.
    @param lang_code: A two letter language abbreviation for the text's language.
    @param start_match: Beginning of the splice of matching sentences to display.
    @param end_match: End of the splice of matching sentences to display (exclusive).
    @param pipe: Optional pre-loaded NLP pipeline; passing one avoids re-loading models.
    @param render: When True, also render the HTML strings in the returned list.
    @param visualize_xpos: Use xpos tags for part-of-speech labeling instead of upos.

    @return: A list of HTML strings representing the dependency relations of the doc object.
    """
    nlp = Pipeline(lang_code, processors="tokenize, pos, lemma, depparse") if pipe is None else pipe
    processed_doc = nlp(text)
    return visualize_search_doc(
        processed_doc,
        semgrex_queries,
        lang_code,
        start_match=start_match,
        end_match=end_match,
        render=render,
        visualize_xpos=visualize_xpos,
    )
edit_dep_arrow(arrow_section) + + final_html = ( + final_html + edited_arrow_section + ) # continually update html with new arrow html until done + + # Prepare for next iteration + arrow_number += 1 + start_idx = find_nth(arrows_html, '', arrow_number) + end_of_class_idx = find_nth(arrows_html, "", arrow_number) + return final_html + + +def edit_dep_arrow(arrow_html: str) -> str: + """ + The formatting of a single displacy arrow in svg is the following: + + + + csubj + + + + + We edit the 'd = ...' parts of the section to fix the arrow direction and length to round to + the nearest 50 units, centering on each word's center. This is because the words start at x=50 and have spacing + of 100, so each word is at an x-value that is a multiple of 50. + + @param arrow_html: Original SVG for a single displaCy arrow. + @return: Edited SVG for the displaCy arrow, adjusting its placement + """ + + WORD_SPACING = 50 # words start at x=50 and are separated by 100s so their x values are multiples of 50 + M_OFFSET = 4 # length of 'd="M' that we search for to extract the number from d="M70, for instance + ARROW_PIXEL_SIZE = 4 + first_d_idx, second_d_idx = ( + find_nth(arrow_html, 'd="M', 1), + find_nth(arrow_html, 'd="M', 2), + ) # find where d="M starts + first_d_cutoff, second_d_cutoff = ( + arrow_html.find(",", first_d_idx), + arrow_html.find(",", second_d_idx), + ) # isolate the number after 'M' e.g. 'M70' + # gives svg x values of arrow body starting position and arrowhead position + arrow_position, arrowhead_position = ( + float(arrow_html[first_d_idx + M_OFFSET : first_d_cutoff]), + float(arrow_html[second_d_idx + M_OFFSET : second_d_cutoff]), + ) + # gives starting index of where 'fill="none"' or 'fill="currentColor"' begin, reference points to end the d= section + first_fill_start_idx, second_fill_start_idx = ( + find_nth(arrow_html, "fill", n=1), + find_nth(arrow_html, "fill", n=3), + ) + + # isolate the d= ... 
section to edit + first_d, second_d = ( + arrow_html[first_d_idx:first_fill_start_idx], + arrow_html[second_d_idx:second_fill_start_idx], + ) + first_d_split, second_d_split = first_d.split(","), second_d.split(",") + + if ( + arrow_position == arrowhead_position + ): # This arrow is incoming onto the word, center the arrow/head to word center + corrected_arrow_pos = corrected_arrowhead_pos = round_base( + arrow_position, base=WORD_SPACING + ) + + # edit first_d -- arrow body + second_term = first_d_split[1].split(" ")[0] + " " + str(corrected_arrow_pos) + first_d = ( + 'd="M' + + str(corrected_arrow_pos) + + "," + + second_term + + "," + + ",".join(first_d_split[2:]) + ) + + # edit second_d -- arrowhead + second_term = ( + second_d_split[1].split(" ")[0] + + " L" + + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE) + ) + third_term = ( + second_d_split[2].split(" ")[0] + + " " + + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE) + ) + second_d = ( + 'd="M' + + str(corrected_arrowhead_pos) + + "," + + second_term + + "," + + third_term + + "," + + ",".join(second_d_split[3:]) + ) + else: # This arrow is outgoing to another word, center the arrow/head to that word's center + corrected_arrowhead_pos = round_base(arrowhead_position, base=WORD_SPACING) + + # edit first_d -- arrow body + third_term = first_d_split[2].split(" ")[0] + " " + str(corrected_arrowhead_pos) + fourth_term = ( + first_d_split[3].split(" ")[0] + " " + str(corrected_arrowhead_pos) + ) + terms = [ + first_d_split[0], + first_d_split[1], + third_term, + fourth_term, + ] + first_d_split[4:] + first_d = ",".join(terms) + + # edit second_d -- arrow head + first_term = f'd="M{corrected_arrowhead_pos}' + second_term = ( + second_d_split[1].split(" ")[0] + + " L" + + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE) + ) + third_term = ( + second_d_split[2].split(" ")[0] + + " " + + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE) + ) + terms = [first_term, second_term, third_term] + second_d_split[3:] + 
second_d = ",".join(terms) + # rebuild and return html from its individual sections + return ( + arrow_html[:first_d_idx] + + first_d + + " " + + arrow_html[first_fill_start_idx:second_d_idx] + + second_d + + " " + + arrow_html[second_fill_start_idx:] + ) + + +def edit_html_overflow(html_string: str) -> str: + """ + Adds to overflow and display settings to the SVG header to visualize overflowing HTML renderings in the + Semgrex streamlit app. Prevents Semgrex search tags from being cut off at the bottom of visualizations. + + The opening of each HTML string looks similar to this; we add to the end of the SVG header. + + + + + Banning + + VERB + Act. + + + @param html_string: HTML of the result of running Semgrex search on a text + @return: Edited HTML to visualize the dependencies even in the case of overflow. + """ + + BUFFER_LEN = 14 # length of 'direction: ltr"' + editing_start_idx = find_nth(html_string, "direction: ltr", n=1) + SVG_HEADER_ADDITION = "overflow: visible; display: block" + return ( + html_string[:editing_start_idx] + + "; " + + SVG_HEADER_ADDITION + + html_string[editing_start_idx + BUFFER_LEN :] + ) + + +def main(): + """ + IMPORTANT: For the code in this module to run, you must have corenlp and Java installed on your machine. Additionally, + set an environment variable CLASSPATH equal to the path of your corenlp directory. + + Example: CLASSPATH=C:\\Users\\Alex\\PycharmProjects\\pythonProject\\stanford-corenlp-4.5.0\\stanford-corenlp-4.5.0\\* + """ + nlp = Pipeline("en", processors="tokenize,pos,lemma,depparse") + doc = nlp( + "Banning opal removed artifact decks from the meta. Banning tennis resulted in players banning people." + ) + queries = [ + "{pos:NN}=object