Spaces:
Sleeping
Sleeping
| """ | |
| evaluate_nmt_fast.py - Fast direct-Python evaluation (no TCP overhead). | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import os | |
| import tarfile | |
| import urllib.request | |
| from datetime import datetime | |
| from pathlib import Path | |
| import ctranslate2 | |
| import sacrebleu | |
| import sentencepiece as spm | |
| from sacrebleu.metrics import BLEU, CHRF | |
| from tqdm import tqdm | |
| try: | |
| from huggingface_hub import snapshot_download | |
| except ImportError: | |
| snapshot_download = None # type: ignore[misc, assignment] | |
| try: | |
| from scripts.nmt_tcp_server import DEFAULT_SRC_LANG, DEFAULT_TGT_LANG, validate_lang_pair | |
| except ModuleNotFoundError: | |
| from nmt_tcp_server import DEFAULT_SRC_LANG, DEFAULT_TGT_LANG, validate_lang_pair | |
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
# Configure root logging (INFO, timestamped) at import time; this is a no-op if the
# root logger already has handlers.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# FLORES-200 release archive fetched by ensure_flores_cache().
FLORES_URL = "https://dl.fbaipublicfiles.com/nllb/flores200_dataset.tar.gz"
def load_model(model_dir: str, inter_threads: int, device: str | None = None, device_index: int | None = None) -> ctranslate2.Translator:
    """Create a CTranslate2 ``Translator`` for the converted model in *model_dir*.

    Args:
        model_dir: Directory containing the converted CTranslate2 model.
        inter_threads: Forwarded as ``inter_threads`` to ``ctranslate2.Translator``.
        device: Device name. When falsy, env ``NMT_DEVICE`` is used, then ``"cpu"``.
            The value is stripped and lower-cased; an empty result means ``"cpu"``.
        device_index: Device index. When ``None``, env ``NMT_DEVICE_INDEX`` is used
            (0 if unset or not an integer). It is only forwarded for non-CPU devices.

    Returns:
        The loaded ``ctranslate2.Translator``.
    """
    chosen_device = (device or os.environ.get("NMT_DEVICE", "cpu")).strip().lower() or "cpu"

    if device_index is not None:
        chosen_index = device_index
    else:
        try:
            chosen_index = int(os.environ.get("NMT_DEVICE_INDEX", "0"))
        except ValueError:
            # Malformed env value: fall back to the first device.
            chosen_index = 0

    logging.info("Loading CTranslate2 model from '%s' (device=%s index=%s) ...", model_dir, chosen_device, chosen_index)

    # intra_threads=0 asks CTranslate2 to use its own default thread count.
    translator_kwargs = {"device": chosen_device, "inter_threads": inter_threads, "intra_threads": 0}
    if chosen_device != "cpu":
        translator_kwargs["device_index"] = chosen_index
    translator = ctranslate2.Translator(model_dir, **translator_kwargs)

    logging.info("Model loaded.")
    return translator
def load_spm(spm_path: str) -> spm.SentencePieceProcessor:
    """Load the SentencePiece model file at *spm_path* and return the processor."""
    processor = spm.SentencePieceProcessor()
    processor.Load(spm_path)
    return processor
def tokenize_batch(sp: spm.SentencePieceProcessor, texts: list[str], src_lang: str) -> list[list[str]]:
    """Split each sentence into SentencePiece pieces and append the NLLB source suffix.

    Each returned sequence is ``pieces + ["</s>", src_lang]``: the pieces of the
    sentence, then the end-of-sentence marker, then the source-language tag.

    NOTE(review): this is the "legacy" NLLB source layout (tag after ``</s>``);
    current Hugging Face NLLB tokenizers put the tag first. Confirm it matches how
    the model was fine-tuned and what the TCP server sends.
    """
    suffix = ["</s>", src_lang]
    # List concatenation builds a fresh list per sentence, so `suffix` is never shared.
    return [sp.EncodeAsPieces(text) + suffix for text in texts]
def detokenize_and_clean(sp: spm.SentencePieceProcessor, token_sequences: list[list[str]], tgt_lang: str) -> list[str]:
    """Decode each token sequence to text and drop a leading target-language tag.

    The hypotheses passed in start with the forced target prefix, so the decoded
    string may begin with ``tgt_lang``. When it does, that prefix and any whitespace
    following it are removed. Otherwise the decoded text is kept unchanged.
    """

    def _without_tag(decoded: str) -> str:
        if not decoded.startswith(tgt_lang):
            return decoded
        return decoded[len(tgt_lang):].lstrip()

    return [_without_tag(sp.Decode(sequence)) for sequence in token_sequences]
def translate_all(
    translator: ctranslate2.Translator,
    sp: spm.SentencePieceProcessor,
    sentences: list[str],
    src_lang: str,
    tgt_lang: str,
    batch_size: int,
    beam_size: int,
    max_decoding_length: int = 256,
) -> list[str]:
    """Translate *sentences* in batches and return one prediction per input sentence.

    Each batch is tokenized with ``tokenize_batch``, decoded with
    ``translator.translate_batch`` (forcing ``tgt_lang`` as the first target token),
    and the best hypothesis is turned back into text with ``detokenize_and_clean``.

    Args:
        translator: Loaded CTranslate2 translator.
        sp: SentencePiece processor matching the model vocabulary.
        sentences: Source sentences, in order.
        src_lang: Source-language tag appended to every source sequence.
        tgt_lang: Target-language tag used as the decoding prefix.
        batch_size: Number of sentences per ``translate_batch`` call; must be >= 1.
        beam_size: Beam width forwarded to ``translate_batch``.
        max_decoding_length: Forwarded to ``translate_batch``; defaults to 256 (the
            previously hard-coded value).

    Returns:
        Predictions aligned index-for-index with *sentences*.

    Raises:
        ValueError: if ``batch_size < 1``. A negative step would otherwise yield no
            batches and silently return an empty list, and a zero step makes
            ``range`` fail with an opaque message.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    predictions: list[str] = []
    for start in tqdm(range(0, len(sentences), batch_size), desc="Translating batches"):
        chunk = sentences[start : start + batch_size]
        results = translator.translate_batch(
            tokenize_batch(sp, chunk, src_lang),
            # One target prefix per sentence: force the target-language tag first.
            target_prefix=[[tgt_lang] for _ in chunk],
            beam_size=beam_size,
            max_decoding_length=max_decoding_length,
        )
        best_hypotheses = [result.hypotheses[0] for result in results]
        predictions.extend(detokenize_and_clean(sp, best_hypotheses, tgt_lang))
    return predictions
def load_flores(lang: str, max_sentences: int, flores_cache: str) -> list[str]:
    """Read ``<flores_cache>/devtest/<lang>.devtest`` and return its sentences.

    Each line is whitespace-stripped and blank lines are dropped. The resulting list
    is then sliced with ``[:max_sentences]``.
    """
    devtest_file = os.path.join(flores_cache, "devtest", f"{lang}.devtest")
    with open(devtest_file, encoding="utf-8") as handle:
        # Strip each line once, then keep only the non-empty results.
        stripped = (raw.strip() for raw in handle)
        sentences = [text for text in stripped if text]
    return sentences[:max_sentences]
def ensure_flores_cache() -> str:
    """Make sure FLORES-200 is available locally and return the expected cache directory.

    The cache lives at ``<TEMP>/flores200_dataset``, where ``TEMP`` defaults to ``/tmp``
    when the environment variable is unset. The dataset is downloaded from
    ``FLORES_URL`` to ``<TEMP>/flores200.tar.gz`` and extracted into ``<TEMP>``
    whenever ``devtest/eng_Latn.devtest`` is missing from the cache directory.

    Returns:
        The cache directory path. The real data root may be nested below it; callers
        resolve that separately.
    """
    temp_root = os.environ.get("TEMP", "/tmp")
    flores_cache = os.path.join(temp_root, "flores200_dataset")
    flores_archive = os.path.join(temp_root, "flores200.tar.gz")

    # The marker file can only exist if the directory does, so it alone decides.
    if os.path.isfile(os.path.join(flores_cache, "devtest", "eng_Latn.devtest")):
        logging.info("Using cached flores200 at %s", flores_cache)
        return flores_cache

    if os.path.isdir(flores_cache):
        logging.warning("FLORES cache at %s is incomplete; re-downloading.", flores_cache)
    logging.info("Downloading flores200 from %s ...", FLORES_URL)
    urllib.request.urlretrieve(FLORES_URL, flores_archive)

    logging.info("Extracting...")
    extract_to = os.path.dirname(flores_cache)
    with tarfile.open(flores_archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # PEP 706 "data" filter: refuses absolute paths, ".." traversal, links that
            # escape the destination, and special files from the downloaded archive.
            tar.extractall(extract_to, filter="data")
        else:
            # Older Python without extraction filters: keep the previous behavior.
            tar.extractall(extract_to)
    logging.info("Done.")
    return flores_cache
def resolve_flores_root(flores_cache: str) -> str:
    """Find the directory whose ``devtest/`` child holds the FLORES language files.

    The directory counts as the root when ``devtest/eng_Latn.devtest`` exists beneath
    it. *flores_cache* itself is checked first, then its subdirectories via a top-down
    ``os.walk``. If nothing matches, *flores_cache* is returned unchanged.
    """
    marker = os.path.join("devtest", "eng_Latn.devtest")
    if os.path.isfile(os.path.join(flores_cache, marker)):
        return flores_cache

    for current, subdirs, _ in os.walk(flores_cache):
        if "devtest" in subdirs and os.path.isfile(os.path.join(current, marker)):
            logging.info("Resolved FLORES devtest under %s", current)
            return current

    return flores_cache
def parse_args() -> argparse.Namespace:
    """Define the evaluator's command-line options and parse ``sys.argv``.

    See each option's ``help`` string for its meaning. Options are registered in a
    fixed order, which is also the order ``--help`` lists them in.
    """
    parser = argparse.ArgumentParser(description="Fast BLEU/chrF evaluator for En<->It.")
    add = parser.add_argument

    # Model location and hardware placement.
    add(
        "--model-dir",
        default="artifacts/ct2/en_it_v4_casual_weighted/model",
        help="Path to converted CT2 model",
    )
    add(
        "--device",
        default=None,
        help="CTranslate2 device (default: env NMT_DEVICE or cpu). Use cuda for GPU inference.",
    )
    add(
        "--device-index",
        type=int,
        default=None,
        help="GPU index when device is cuda (default: env NMT_DEVICE_INDEX or 0).",
    )
    add("--spm-model", default=None, help="Default: <model-dir>/sentencepiece.bpe.model")

    # Language pair.
    add("--source-lang", default=DEFAULT_SRC_LANG)
    add("--target-lang", default=DEFAULT_TGT_LANG)

    # Integer knobs without help text, kept in their original registration order.
    int_options = (
        ("--max-sentences", 1000),
        ("--batch-size", 32),
        ("--inter-threads", 8),
        ("--beam-size", 1),
    )
    for flag, fallback in int_options:
        add(flag, type=int, default=fallback)

    # Output locations.
    add("--reports-dir", default="reports/baseline")
    add(
        "--sentence-metrics-out",
        default=None,
        help="Optional path to write JSONL with per-sentence chrF++ and smoothed sentence BLEU for statistics.",
    )

    # Optional Hugging Face Hub download of the model.
    add(
        "--hf-repo",
        default=None,
        metavar="REPO_ID",
        help="If set, download/sync this Hugging Face CTranslate2 repo and use it as --model-dir (overrides --model-dir).",
    )
    add(
        "--hf-cache-dir",
        default=None,
        help="Cache directory for --hf-repo snapshot_download (default: env HF_HOME or ./.hf_cache).",
    )

    return parser.parse_args()
def resolve_hf_repo_model_dir(repo_id: str, cache_dir: str | None) -> str:
    """Download (or reuse) a Hugging Face snapshot of *repo_id* and return its local path.

    The cache directory is *cache_dir* if given. Otherwise it is env ``HF_HOME``, or
    ``./.hf_cache`` under the current working directory.

    NOTE(review): huggingface_hub's own default cache is ``HF_HOME/hub``. Passing
    ``HF_HOME`` itself stores snapshots one level up, so they may not be shared with
    other HF tooling. Confirm this is intended.

    Raises:
        RuntimeError: if ``huggingface_hub`` could not be imported.
    """
    if snapshot_download is not None:
        cache_root = cache_dir or os.environ.get("HF_HOME") or os.path.join(os.getcwd(), ".hf_cache")
        return snapshot_download(repo_id=repo_id, cache_dir=cache_root)
    raise RuntimeError("huggingface_hub is required for --hf-repo (pip install huggingface_hub).")
# Word n-gram order that turns chrF into chrF++. sacrebleu.corpus_chrf defaults to
# word_order=0 (plain chrF), which is why the corpus score uses CHRF explicitly.
_CHRF_WORD_ORDER = 2


def _write_sentence_metrics(
    out_path: Path,
    predictions: list[str],
    references: list[str],
    source_lang: str,
    target_lang: str,
) -> None:
    """Write one JSON line per sentence pair with sentence-level chrF++ and BLEU.

    Sentence BLEU uses ``effective_order=True``, as sacrebleu recommends for
    sentence-level scoring.
    """
    chrf_metric = CHRF(word_order=_CHRF_WORD_ORDER, beta=2)
    bleu_metric = BLEU(effective_order=True)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as jf:
        for i, (hyp, ref) in enumerate(zip(predictions, references)):
            rec = {
                "i": i,
                "source_lang": source_lang,
                "target_lang": target_lang,
                "chrf": float(chrf_metric.sentence_score(hyp, [ref]).score),
                "bleu": float(bleu_metric.sentence_score(hyp, [ref]).score),
            }
            jf.write(json.dumps(rec) + "\n")


def _write_reports(
    args: argparse.Namespace,
    total: int,
    source_sentences: list[str],
    predictions: list[str],
    target_references: list[str],
    bleu,
    chrf,
) -> tuple[Path, Path]:
    """Write the ``.txt`` and ``.json`` reports into ``args.reports_dir``.

    Both files are named ``<YYYYmmdd_HHMMSS>_<src>_to_<tgt>``. *bleu* and *chrf* are
    sacrebleu corpus score objects.

    Returns:
        ``(txt_path, json_path)``.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    pair_name = f"{args.source_lang}_to_{args.target_lang}"
    reports_dir = Path(args.reports_dir)
    reports_dir.mkdir(parents=True, exist_ok=True)
    txt_path = reports_dir / f"{timestamp}_{pair_name}.txt"
    json_path = reports_dir / f"{timestamp}_{pair_name}.json"

    with txt_path.open("w", encoding="utf-8") as f:
        f.write("=== NMT-MenKan Accuracy Evaluation ===\n")
        f.write("Script: evaluate_nmt_fast.py (direct CTranslate2, batched)\n")
        f.write(f"Dataset: flores200 devtest ({args.source_lang} -> {args.target_lang})\n")
        f.write(f"Sentences: {total}\n")
        f.write(
            f"Batch size: {args.batch_size} Beam size: {args.beam_size} "
            f"Threads: {args.inter_threads}\n\n"
        )
        f.write("--- Metrics ---\n")
        f.write(f"BLEU Score : {bleu.score:.2f}\n")
        f.write(f"{bleu.format()}\n\n")
        f.write(f"chrF++ Score: {chrf.score:.2f}\n")
        f.write(f"{chrf.format()}\n\n")
        f.write("--- Sample Outputs (First 5) ---\n")
        for i in range(min(5, total)):
            f.write(f"SRC: {source_sentences[i]}\n")
            f.write(f"PRED: {predictions[i]}\n")
            f.write(f"REF: {target_references[i]}\n")
            f.write("-" * 40 + "\n")

    payload = {
        "source_lang": args.source_lang,
        "target_lang": args.target_lang,
        "dataset": "flores200/devtest",
        "sentences": total,
        "batch_size": args.batch_size,
        "beam_size": args.beam_size,
        "inter_threads": args.inter_threads,
        "bleu": bleu.score,
        "chrf": chrf.score,
        # Added so consumers can tell chrF++ reports from older plain-chrF ones.
        "chrf_word_order": _CHRF_WORD_ORDER,
        "text_report": str(txt_path),
    }
    json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    return txt_path, json_path


def main() -> None:
    """Evaluate the CTranslate2 model on FLORES-200 devtest and write reports.

    The steps are:
    1. Parse the CLI and resolve the model directory (optionally from the Hub).
    2. Translate the first ``--max-sentences`` devtest sentences.
    3. Score them with corpus BLEU and chrF++.
    4. Optionally dump per-sentence metrics as JSONL.
    5. Write ``.txt`` and ``.json`` reports under ``--reports-dir``.

    Raises:
        FileNotFoundError: if the model directory or SentencePiece model is missing.
        ValueError: if no FLORES sentence pairs were loaded (e.g. ``--max-sentences 0``).
    """
    args = parse_args()
    validate_lang_pair(args.source_lang, args.target_lang)
    model_dir = resolve_hf_repo_model_dir(args.hf_repo, args.hf_cache_dir) if args.hf_repo else args.model_dir
    spm_model = args.spm_model or os.path.join(model_dir, "sentencepiece.bpe.model")
    if not os.path.isdir(model_dir):
        raise FileNotFoundError(f"Model directory not found: {model_dir}")
    if not os.path.isfile(spm_model):
        raise FileNotFoundError(f"SentencePiece model not found: {spm_model}")

    translator = load_model(model_dir, args.inter_threads, device=args.device, device_index=args.device_index)
    sp = load_spm(spm_model)

    flores_cache = resolve_flores_root(ensure_flores_cache())
    source_sentences = load_flores(args.source_lang, args.max_sentences, flores_cache)
    target_references = load_flores(args.target_lang, args.max_sentences, flores_cache)
    total = min(len(source_sentences), len(target_references))
    if len(source_sentences) != len(target_references):
        # FLORES files are line-aligned; a count mismatch hints at a damaged download.
        logging.warning(
            "FLORES source/target line counts differ (%d vs %d); truncating to %d pairs.",
            len(source_sentences),
            len(target_references),
            total,
        )
    if total == 0:
        raise ValueError(
            f"No FLORES sentence pairs loaded for {args.source_lang} -> {args.target_lang} "
            f"(--max-sentences={args.max_sentences})."
        )
    source_sentences = source_sentences[:total]
    target_references = target_references[:total]
    logging.info("Loaded %d flores200 devtest sentence pairs.", total)

    predictions = translate_all(
        translator,
        sp,
        source_sentences,
        args.source_lang,
        args.target_lang,
        args.batch_size,
        args.beam_size,
    )

    logging.info("Calculating BLEU and chrF++ scores...")
    refs = [target_references]
    bleu = sacrebleu.corpus_bleu(predictions, refs)
    # Compute real chrF++ (word bigrams included) so the corpus score matches its
    # "chrF++" label and the per-sentence metric.
    chrf = CHRF(word_order=_CHRF_WORD_ORDER).corpus_score(predictions, refs)
    logging.info("BLEU Score : %.2f", bleu.score)
    logging.info("chrF++ Score: %.2f", chrf.score)

    if args.sentence_metrics_out:
        out_path = Path(args.sentence_metrics_out)
        _write_sentence_metrics(out_path, predictions, target_references, args.source_lang, args.target_lang)
        logging.info("Wrote sentence-level metrics for %d lines to %s", total, out_path)

    txt_path, json_path = _write_reports(
        args, total, source_sentences, predictions, target_references, bleu, chrf
    )
    logging.info("Reports saved to %s and %s", txt_path, json_path)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()