NMT / scripts /post_lora_export_eval.py
marconolimits's picture
deploy: clean orphan branch for HF Spaces - CPU threading optimisation
c7b4419
Raw
History Blame Contribute Delete
3.75 kB
"""
After train-lora completes: export merged LoRA to CTranslate2 and run fast Flores eval both directions.
Requires final_metrics.json under the LoRA output directory.
"""
from __future__ import annotations
import argparse
import subprocess
import sys
import time
from pathlib import Path
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the export-and-evaluate workflow."""
    p = argparse.ArgumentParser(description="Export LoRA to CT2 and evaluate vs Flores (en-it both ways).")
    p.add_argument(
        "--lora-output-dir",
        required=True,
        help="train-lora --output-dir (must contain adapter/ and final_metrics.json)",
    )
    p.add_argument(
        "--ct2-output-dir",
        required=True,
        help="New directory for export-lora CTranslate2 model tree",
    )
    p.add_argument("--base-model", default="facebook/nllb-200-distilled-600M")
    p.add_argument("--quantization", default="int8")
    p.add_argument(
        "--reports-dir",
        default=None,
        help="evaluate_nmt_fast --reports-dir (default: reports/post_tune_<run_id>)",
    )
    p.add_argument(
        "--wait-for-metrics",
        action="store_true",
        help="Poll until final_metrics.json exists, then export and evaluate.",
    )
    # Floats are accepted so sub-second polls and fractional-hour timeouts
    # work; every integer value that was valid before still parses.
    p.add_argument("--poll-seconds", type=float, default=90)
    p.add_argument("--timeout-hours", type=float, default=12)
    return p


def _resolve_under(root: Path, raw: str) -> Path:
    """Return *raw* as a resolved absolute path, anchoring relative paths at *root*."""
    # Joining with an absolute right-hand side discards *root*, so a single
    # expression covers both the relative and the absolute case.
    return (root / raw).resolve()


def _wait_for_file(path: Path, timeout_seconds: float, poll_seconds: float) -> bool:
    """Poll until *path* is a regular file or the timeout elapses.

    The file is checked before the deadline is consulted, so an already
    existing file is found even with a zero timeout, and a file that shows up
    during the last sleep is still detected.

    Returns:
        True if the file exists, False if the timeout elapsed first.
    """
    # monotonic() is immune to wall-clock adjustments during a long wait.
    deadline = time.monotonic() + timeout_seconds
    while True:
        if path.is_file():
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline; clamp at 0 because time.sleep()
        # rejects negative values.
        time.sleep(max(0.0, min(poll_seconds, remaining)))


def _run(cmd: list, cwd: Path) -> None:
    """Echo *cmd* and run it in *cwd*, raising CalledProcessError on failure."""
    print("Running:", " ".join(cmd))
    subprocess.check_call(cmd, cwd=str(cwd))


def main() -> None:
    """Export a trained LoRA adapter to CTranslate2 and evaluate it in both directions.

    Steps:
      1. Locate the LoRA output directory (relative paths are anchored at the
         repository root, i.e. the parent of this script's directory).
      2. Optionally wait for ``final_metrics.json`` to appear.
      3. Run ``scripts/convert_model.py export-lora`` to build the CT2 model.
      4. Run ``scripts/evaluate_nmt_fast.py`` for eng_Latn->ita_Latn and
         ita_Latn->eng_Latn, writing into the reports directory.

    Exits with status 1 if the metrics file or adapter directory is missing,
    or if waiting for the metrics file times out. A failing child process
    propagates as ``subprocess.CalledProcessError``.
    """
    args = _build_parser().parse_args()
    repo_root = Path(__file__).resolve().parent.parent

    lora_out = _resolve_under(repo_root, args.lora_output_dir)
    metrics = lora_out / "final_metrics.json"
    adapter = lora_out / "adapter"

    if args.wait_for_metrics:
        print(f"Waiting for {metrics} ...")
        if not _wait_for_file(metrics, args.timeout_hours * 3600, args.poll_seconds):
            print(f"Timeout waiting for {metrics}", file=sys.stderr)
            sys.exit(1)
    if not metrics.is_file():
        print(f"Missing {metrics}; train-lora must finish first.", file=sys.stderr)
        sys.exit(1)
    if not adapter.is_dir():
        print(f"Missing {adapter}.", file=sys.stderr)
        sys.exit(1)

    run_id = lora_out.name
    if args.reports_dir:
        # Child processes run with cwd=repo_root, so a relative path must be
        # anchored there too; otherwise the directory created here and the
        # one the evaluator writes to can differ.
        reports = _resolve_under(repo_root, args.reports_dir)
    else:
        reports = repo_root / "reports" / f"post_tune_{run_id}"
    reports.mkdir(parents=True, exist_ok=True)

    ct2 = _resolve_under(repo_root, args.ct2_output_dir)
    # Only the parent is created; the CT2 directory itself is left for the
    # export step to create.
    ct2.parent.mkdir(parents=True, exist_ok=True)

    convert_cmd = [
        sys.executable,
        str(repo_root / "scripts" / "convert_model.py"),
        "export-lora",
        "--base-model",
        args.base_model,
        "--adapter-dir",
        str(adapter),
        "--output-dir",
        str(ct2),
        "--quantization",
        args.quantization,
    ]
    _run(convert_cmd, repo_root)

    model_dir = ct2 / "model"
    eval_py = repo_root / "scripts" / "evaluate_nmt_fast.py"
    for src, tgt in (("eng_Latn", "ita_Latn"), ("ita_Latn", "eng_Latn")):
        ev = [
            sys.executable,
            str(eval_py),
            "--model-dir",
            str(model_dir),
            "--spm-model",
            str(model_dir / "sentencepiece.bpe.model"),
            "--source-lang",
            src,
            "--target-lang",
            tgt,
            "--reports-dir",
            str(reports),
        ]
        _run(ev, repo_root)

    print(f"Done. CT2 model: {model_dir}")
    print(f"Reports: {reports}")
# Run the export-and-evaluate workflow only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()