"""
Run the tagger for a couple iterations on some fake data

Uses a couple sentences of UD_English-EWT as training/dev data
"""

import os
import pytest
import zipfile

import torch

from stanza.models import parser
from stanza.models.common import pretrain
from stanza.models.depparse.trainer import Trainer
from stanza.tests import TEST_WORKING_DIR

pytestmark = [pytest.mark.pipeline, pytest.mark.travis]

TRAIN_DATA = """
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
1	DPA	DPA	PROPN	NNP	Number=Sing	0	root	0:root	SpaceAfter=No
2	:	:	PUNCT	:	_	1	punct	1:punct	_
3	Iraqi	Iraqi	ADJ	JJ	Degree=Pos	4	amod	4:amod	_
4	authorities	authority	NOUN	NNS	Number=Plur	5	nsubj	5:nsubj	_
5	announced	announce	VERB	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	1	parataxis	1:parataxis	_
6	that	that	SCONJ	IN	_	9	mark	9:mark	_
7	they	they	PRON	PRP	Case=Nom|Number=Plur|Person=3|PronType=Prs	9	nsubj	9:nsubj	_
8	had	have	AUX	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	9	aux	9:aux	_
9	busted	bust	VERB	VBN	Tense=Past|VerbForm=Part	5	ccomp	5:ccomp	_
10	up	up	ADP	RP	_	9	compound:prt	9:compound:prt	_
11	3	3	NUM	CD	NumForm=Digit|NumType=Card	13	nummod	13:nummod	_
12	terrorist	terrorist	ADJ	JJ	Degree=Pos	13	amod	13:amod	_
13	cells	cell	NOUN	NNS	Number=Plur	9	obj	9:obj	_
14	operating	operate	VERB	VBG	VerbForm=Ger	13	acl	13:acl	_
15	in	in	ADP	IN	_	16	case	16:case	_
16	Baghdad	Baghdad	PROPN	NNP	Number=Sing	14	obl	14:obl:in	SpaceAfter=No
17	.	.	PUNCT	.	_	1	punct	1:punct	_

# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
# text = Two of them were being run by 2 officials of the Ministry of the Interior!
1	Two	two	NUM	CD	NumForm=Word|NumType=Card	6	nsubj:pass	6:nsubj:pass	_
2	of	of	ADP	IN	_	3	case	3:case	_
3	them	they	PRON	PRP	Case=Acc|Number=Plur|Person=3|PronType=Prs	1	nmod	1:nmod:of	_
4	were	be	AUX	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	6	aux	6:aux	_
5	being	be	AUX	VBG	VerbForm=Ger	6	aux:pass	6:aux:pass	_
6	run	run	VERB	VBN	Tense=Past|VerbForm=Part|Voice=Pass	0	root	0:root	_
7	by	by	ADP	IN	_	9	case	9:case	_
8	2	2	NUM	CD	NumForm=Digit|NumType=Card	9	nummod	9:nummod	_
9	officials	official	NOUN	NNS	Number=Plur	6	obl	6:obl:by	_
10	of	of	ADP	IN	_	12	case	12:case	_
11	the	the	DET	DT	Definite=Def|PronType=Art	12	det	12:det	_
12	Ministry	Ministry	PROPN	NNP	Number=Sing	9	nmod	9:nmod:of	_
13	of	of	ADP	IN	_	15	case	15:case	_
14	the	the	DET	DT	Definite=Def|PronType=Art	15	det	15:det	_
15	Interior	Interior	PROPN	NNP	Number=Sing	12	nmod	12:nmod:of	SpaceAfter=No
16	!	!	PUNCT	.	_	6	punct	6:punct	_

""".lstrip()


DEV_DATA = """
1	From	from	ADP	IN	_	3	case	3:case	_
2	the	the	DET	DT	Definite=Def|PronType=Art	3	det	3:det	_
3	AP	AP	PROPN	NNP	Number=Sing	4	obl	4:obl:from	_
4	comes	come	VERB	VBZ	Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	0	root	0:root	_
5	this	this	DET	DT	Number=Sing|PronType=Dem	6	det	6:det	_
6	story	story	NOUN	NN	Number=Sing	4	nsubj	4:nsubj	_
7	:	:	PUNCT	:	_	4	punct	4:punct	_

""".lstrip()



class TestParser:
    @pytest.fixture(scope="class")
    def wordvec_pretrain_file(self):
        return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'

    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None, zip_train_data=False):
        """
        Run the training for a few iterations, load & return the model
        """
        train_file = str(tmp_path / "train.zip") if zip_train_data else str(tmp_path / "train.conllu")
        dev_file = str(tmp_path / "dev.conllu")
        pred_file = str(tmp_path / "pred.conllu")

        save_name = "test_parser.pt"
        save_file = str(tmp_path / save_name)

        if zip_train_data:
            with zipfile.ZipFile(train_file, "w") as zout:
                with zout.open('train.conllu', 'w') as fout:
                    fout.write(train_text.encode())
        else:
            with open(train_file, "w", encoding="utf-8") as fout:
                fout.write(train_text)

        with open(dev_file, "w", encoding="utf-8") as fout:
            fout.write(dev_text)

        args = ["--wordvec_pretrain_file", wordvec_pretrain_file,
                "--train_file", train_file,
                "--eval_file", dev_file,
                "--output_file", pred_file,
                "--log_step", "10",
                "--eval_interval", "20",
                "--max_steps", "100",
                "--shorthand", "en_test",
                "--save_dir", str(tmp_path),
                "--save_name", save_name,
                # in case we are doing a bert test
                "--bert_start_finetuning", "10",
                "--bert_warmup_steps", "10",
                "--lang", "en"]
        if not augment_nopunct:
            args.extend(["--augment_nopunct", "0.0"])
        if extra_args is not None:
            args = args + extra_args
        trainer, _ = parser.main(args)

        assert os.path.exists(save_file)
        pt = pretrain.Pretrain(wordvec_pretrain_file)
        # test loading the saved model
        saved_model = Trainer(pretrain=pt, model_file=save_file)
        return trainer

    def test_train(self, tmp_path, wordvec_pretrain_file):
        """
        Simple test of a few 'epochs' of parser training
        """
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA)

    def test_zipfile_train(self, tmp_path, wordvec_pretrain_file):
        """
        Simple test of a few 'epochs' of parser training with the training data in a zipfile
        """
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, zip_train_data=True)

    def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])

    def test_with_bert_finetuning(self, tmp_path, wordvec_pretrain_file):
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])
        assert 'bert_optimizer' in trainer.optimizer.keys()
        assert 'bert_scheduler' in trainer.scheduler.keys()

    def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
        """
        Check that if we save, then load, then save a model with a finetuned bert, that bert isn't lost
        """
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])
        assert 'bert_optimizer' in trainer.optimizer.keys()
        assert 'bert_scheduler' in trainer.scheduler.keys()

        save_name = trainer.args['save_name']
        filename = tmp_path / save_name
        assert os.path.exists(filename)
        # load the raw checkpoint dict to verify the finetuned bert weights were saved
        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())

        # Test loading the saved model, saving it, and still having bert in it
        # even if we have set bert_finetune to False for this incarnation
        pt = pretrain.Pretrain(wordvec_pretrain_file)
        args = {"bert_finetune": False}
        saved_model = Trainer(pretrain=pt, model_file=filename, args=args)

        saved_model.save(filename)

        # This is the part that would fail if the force_bert_saved option did not exist
        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())

    def test_with_peft(self, tmp_path, wordvec_pretrain_file):
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2', '--use_peft'])
        assert 'bert_optimizer' in trainer.optimizer.keys()
        assert 'bert_scheduler' in trainer.scheduler.keys()

    def test_single_optimizer_checkpoint(self, tmp_path, wordvec_pretrain_file):
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam'])

        save_dir = trainer.args['save_dir']
        save_name = trainer.args['save_name']
        checkpoint_name = trainer.args["checkpoint_save_name"]

        assert os.path.exists(os.path.join(save_dir, save_name))
        assert checkpoint_name is not None
        assert os.path.exists(checkpoint_name)

        assert len(trainer.optimizer) == 1
        for opt in trainer.optimizer.values():
            assert isinstance(opt, torch.optim.Adam)

        pt = pretrain.Pretrain(wordvec_pretrain_file)
        checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
        assert checkpoint.optimizer is not None
        assert len(checkpoint.optimizer) == 1
        for opt in checkpoint.optimizer.values():
            assert isinstance(opt, torch.optim.Adam)

    def test_two_optimizers_checkpoint(self, tmp_path, wordvec_pretrain_file):
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam', '--second_optim', 'sgd', '--second_optim_start_step', '40'])

        save_dir = trainer.args['save_dir']
        save_name = trainer.args['save_name']
        checkpoint_name = trainer.args["checkpoint_save_name"]

        assert os.path.exists(os.path.join(save_dir, save_name))
        assert checkpoint_name is not None
        assert os.path.exists(checkpoint_name)

        assert len(trainer.optimizer) == 1
        for opt in trainer.optimizer.values():
            assert isinstance(opt, torch.optim.SGD)

        pt = pretrain.Pretrain(wordvec_pretrain_file)
        checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
        assert checkpoint.optimizer is not None
        assert len(checkpoint.optimizer) == 1
        for opt in checkpoint.optimizer.values():
            assert isinstance(opt, torch.optim.SGD)