import argparse
import os
import sys
from unittest.mock import patch

import pytorch_lightning as pl
import timeout_decorator
import torch

from distillation import SummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from transformers import MarianMTModel
from transformers.file_utils import cached_path
from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
from utils import load_json
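
# Small distilled Marian en->ro student checkpoint shared by both test classes below.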
MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"


class TestMbartCc25Enro(TestCasePlus):
    def setUp(self):
        super().setUp()
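
        # Download and unpack the WMT en-ro subset (40k train / 0.5k val / 0.5k test);
        # cached_path caches the archive, so repeat runs skip the download.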
        data_cached = cached_path(
            "https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
            extract_compressed_file=True,
        )
        self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"

    @slow
    @require_torch_gpu
    def test_model_download(self):
        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
        MarianMTModel.from_pretrained(MARIAN_MODEL)

    @slow
    @require_torch_gpu
    def test_train_mbart_cc25_enro_script(self):
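        # The shell script hard-codes facebook/mbart-large-cc25 and shell variables;
        # substitute the small Marian student and fast hyperparameters so the test
        # finishes quickly on a single GPU.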
        env_vars_to_replace = {
            "$MAX_LEN": 64,
            "$BS": 64,
            "$GAS": 1,
            "$ENRO_DIR": self.data_dir,
            "facebook/mbart-large-cc25": MARIAN_MODEL,
            "--learning_rate=3e-5": "--learning_rate 3e-4",
            "--num_train_epochs 6": "--num_train_epochs 1",
        }
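
        # Keep only the CLI portion of train_mbart_cc25_enro.sh (everything after
        # "finetune.py"), drop line continuations and "$@", then apply the substitutions.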
        bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
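
        # Flags appended after the script's own arguments: cap dataset sizes, run on
        # one GPU with a frozen encoder, and keep eval cheap with 2 beams.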
        args = f"""
            --output_dir {output_dir}
            --tokenizer_name Helsinki-NLP/opus-mt-en-ro
            --sortish_sampler
            --do_predict
            --gpus 1
            --freeze_encoder
            --n_train 40000
            --n_val 500
            --n_test 500
            --fp16_opt_level O1
            --num_sanity_val_steps 0
            --eval_beams 2
        """.split()
        testargs = ["finetune.py"] + bash_script.split() + args
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            model = main(args)
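
        # Metrics are saved once per validation run, so the "val" list should have
        # exactly max_epochs / val_check_interval entries.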
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
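
        # Average generation time should reflect real work but stay bounded; a value
        # above 1.0 would suggest generate() is hanging.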
        self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
        self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)
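
        # Learning requirements: BLEU improves by more than 2 points during training,
        # ends above 17, and val and test BLEU agree within 1.1 points.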
        self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)
        self.assertGreater(last_step_stats["val_avg_bleu"], 17)
        self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)
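
        # The Lightning checkpoint should load on CPU and store fp32 weights under
        # the expected "model.model..." key prefix.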
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][expected_key].dtype == torch.float32
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            assert len(metrics["test"]) == 1


class TestDistilMarianNoTeacher(TestCasePlus):
    @timeout_decorator.timeout(600)
    @slow
    @require_torch_gpu
    def test_opus_mt_distill_script(self):
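        # Same pattern as above, but for distillation.py: use the tiny bundled
        # en-ro test data and an already-shrunk Marian student, so no teacher is needed.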
        data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 16,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "$m": "sshleifer/student_marian_en_ro_6_1",
            "val_check_interval=0.25": "val_check_interval=1.0",
        }
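
        # Keep only the CLI portion of distil_marian_no_teacher.sh, do the usual
        # cleanup, and strip the --fp16 flag (with and without a trailing space).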
        bash_script = (
            (self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
        )
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        bash_script = bash_script.replace("--fp16 ", " ")

        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
        bash_script = bash_script.replace("--fp16", "")
        epochs = 6
        testargs = (
            ["distillation.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=1e-3",
                f"--num_train_epochs={epochs}",
                "--warmup_steps=10",
                "--val_check_interval=1.0",
                "--do_predict",
            ]
        )
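
        # Run distillation.distill_main() in-process with the assembled argv.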
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()

            model = distill_main(args)
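
        # Looser checks than the finetuning test: at least one validation pass per
        # epoch, BLEU must improve, and generation time must stay within [0.01, 1.0].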
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)
        assert last_step_stats["val_avg_gen_time"] >= 0.01
        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]
        assert last_step_stats["val_avg_gen_time"] <= 1.0
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
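
        # Checkpoint integrity: same CPU-load and fp32 state_dict checks as above.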
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][expected_key].dtype == torch.float32

        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            assert len(metrics["test"]) == 1