| |
"""
Evaluate GPT-5.2 translation quality on MultiClinSum files.

What this script does:
1) Loads EN/ES/FR/PT json files (expects fields like id/fulltext/summary)
2) Aligns EN with each non-EN language by shared numeric case id
3) Samples N aligned instances per language pair
4) Runs bidirectional translation with GPT-5.2:
   - EN -> X
   - X -> EN
5) Reports common MT metrics used in top venues:
   - BLEU (sacreBLEU)
   - chrF++ (sacreBLEU chrF)
   - COMET (if installed)
   - BERTScore F1 (if installed)
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import json |
| import os |
| import random |
| import re |
| import sys |
| import time |
| from dataclasses import dataclass |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| from openai import OpenAI |
| import sacrebleu |
|
|
|
|
| ID_NUM_RE = re.compile(r"_(\d+)\.txt$") |
|
|
|
|
@dataclass
class Example:
    """One document from a language file, keyed for cross-language alignment."""

    case_id: str  # normalized numeric id shared across languages
    text: str     # document body used as translation source / reference
    raw_id: str   # original id string exactly as it appeared in the json
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for one evaluation run."""
    parser = argparse.ArgumentParser(description="GPT-5.2 translation evaluation")

    # Input corpora: one gold-standard json file per language.
    for code, language in (
        ("en", "English"),
        ("es", "Spanish"),
        ("fr", "French"),
        ("pt", "Portuguese"),
    ):
        parser.add_argument(
            f"--{code}-file",
            default=f"data/testing_data_gs/multiclinsum_gs_train_{code}.json",
            help=f"Path to {language} json file",
        )

    # Sampling and model knobs.
    parser.add_argument("--num-samples", type=int, default=20,
                        help="Samples per language pair")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--model", default="gpt-5.2", help="OpenAI model name")
    parser.add_argument("--max-chars", type=int, default=2500,
                        help="Character cap per sample to control cost/latency")

    # Credentials and output locations.
    parser.add_argument("--openai-api-key-env", default="OPENAI_API_KEY",
                        help="Environment variable that stores OpenAI API key")
    parser.add_argument("--api-file", default="/home/mshahidul/api_new.json",
                        help="JSON file containing API keys (expects key 'openai')")
    parser.add_argument("--output-dir",
                        default="/home/mshahidul/readctrl/code/translation_quality_check",
                        help="Directory to save outputs")

    # Optional metrics, decoding, and checkpointing.
    parser.add_argument("--skip-comet", action="store_true",
                        help="Skip COMET even if installed")
    parser.add_argument("--skip-bertscore", action="store_true",
                        help="Skip BERTScore even if installed")
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="Decoding temperature")
    parser.add_argument("--save-every", type=int, default=10,
                        help="Checkpoint save interval (in translated instances)")
    return parser.parse_args()
|
|
|
|
def load_json(path: str) -> List[dict]:
    """Read a UTF-8 json file and return its parsed content (a list of rows)."""
    raw = Path(path).read_text(encoding="utf-8")
    return json.loads(raw)
|
|
|
|
def normalize_case_id(raw_id: str) -> str:
    """Reduce an id like "case_123.txt" to "123"; return *raw_id* unchanged
    when it does not match the expected trailing-number pattern."""
    match = ID_NUM_RE.search(raw_id)
    return match.group(1) if match else raw_id
|
|
|
|
def dataset_to_examples(rows: List[dict], field: str) -> Dict[str, Example]:
    """Index *rows* by normalized case id, keeping the text in *field*.

    When *field* is absent (None), falls back to "summary" then "fulltext".
    Rows whose text is empty after stripping are dropped; when several rows
    share a case id, the last one wins.
    """
    examples: Dict[str, Example] = {}
    for row in rows:
        raw_id = str(row.get("id", ""))
        value = row.get(field)
        if value is None:
            # Degrade gracefully when the preferred field is missing.
            value = row.get("summary") or row.get("fulltext") or ""
        cleaned = str(value).strip()
        if not cleaned:
            continue
        case_id = normalize_case_id(raw_id)
        examples[case_id] = Example(case_id=case_id, text=cleaned, raw_id=raw_id)
    return examples
|
|
|
|
def truncate_text(text: str, max_chars: int) -> str:
    """Cap *text* at *max_chars* characters, appending " ..." when truncated.

    A non-positive *max_chars* disables truncation entirely.
    """
    if max_chars <= 0 or len(text) <= max_chars:
        return text
    return text[:max_chars].rstrip() + " ..."
|
|
|
|
def translate_one(
    client: OpenAI,
    model: str,
    text: str,
    src_lang_name: str,
    tgt_lang_name: str,
    temperature: float,
) -> str:
    """Translate *text* with a single model call and return the stripped output."""
    system_msg = (
        "You are a professional medical translator. "
        "Translate faithfully and naturally. Preserve meaning, numbers, units, "
        "named entities, and clinical terminology. Return only the translation."
    )
    user_msg = (
        f"Translate the following text from {src_lang_name} to {tgt_lang_name}.\n\n"
        f"{text}"
    )
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
    response = client.responses.create(model=model, temperature=temperature, input=messages)
    return response.output_text.strip()
|
|
|
|
def compute_bleu_chrf(hypotheses: List[str], references: List[str]) -> Dict[str, float]:
    """Corpus-level BLEU and chrF++ via sacreBLEU.

    Bug fix: sacrebleu.corpus_chrf defaults to plain chrF (word_order=0),
    while the returned key — and the rest of this script — advertises chrF++.
    chrF++ is chrF with word n-grams up to order 2, so pass word_order=2.
    """
    bleu = sacrebleu.corpus_bleu(hypotheses, [references]).score
    # word_order=2 turns chrF into chrF++ (the "++" = two word-ngram orders).
    chrf = sacrebleu.corpus_chrf(hypotheses, [references], word_order=2).score
    return {"bleu": round(bleu, 4), "chrf++": round(chrf, 4)}
|
|
|
|
def maybe_compute_bertscore(
    hypotheses: List[str],
    references: List[str],
    target_lang: str,
) -> Optional[float]:
    """Mean BERTScore F1 over the corpus, or None when bert-score is absent."""
    try:
        from bert_score import score as bert_score_fn
    except Exception:
        # Optional dependency: treat "not installed" as "metric unavailable".
        return None
    # bert_score returns (precision, recall, F1) tensors; only F1 is reported.
    *_, f1 = bert_score_fn(hypotheses, references, lang=target_lang, verbose=False)
    return round(float(f1.mean().item()), 6)
|
|
|
|
def maybe_compute_comet(
    sources: List[str],
    hypotheses: List[str],
    references: List[str],
) -> Optional[float]:
    """System-level COMET (wmt22-comet-da) score, or None when COMET is absent.

    Improvement: the checkpoint is downloaded/loaded once and cached on the
    function object, so repeated calls (this runs once per translation
    direction, up to six times per run) do not reload the model each time.
    """
    try:
        from comet import download_model, load_from_checkpoint
    except Exception:
        # Optional dependency: treat "not installed" as "metric unavailable".
        return None
    comet_model = getattr(maybe_compute_comet, "_cached_model", None)
    if comet_model is None:
        model_path = download_model("Unbabel/wmt22-comet-da")
        comet_model = load_from_checkpoint(model_path)
        maybe_compute_comet._cached_model = comet_model
    data = [{"src": s, "mt": h, "ref": r} for s, h, r in zip(sources, hypotheses, references)]
    # Use a GPU only when CUDA_VISIBLE_DEVICES is set to a non-empty value.
    result = comet_model.predict(data, batch_size=8, gpus=1 if os.environ.get("CUDA_VISIBLE_DEVICES") else 0)
    return round(float(result.system_score), 6)
|
|
|
|
def ensure_dir(path: str) -> None:
    """Create *path* (and any missing parents); a no-op if it already exists."""
    target = Path(path)
    target.mkdir(parents=True, exist_ok=True)
|
|
|
|
def persist_outputs(
    json_path: Path,
    details_path: Path,
    csv_path: Path,
    all_results: dict,
    detailed_rows: List[dict],
    summary_rows: List[dict],
) -> None:
    """Write the three run artifacts: scores json, translations jsonl, summary csv.

    Called both at checkpoints and at the end of the run; each call rewrites
    the files in full so readers never see a partially appended state.

    Robustness fix: DictWriter defaults to extrasaction="raise", which would
    abort a checkpoint save if a summary row ever carries a key that is not in
    the column list; extrasaction="ignore" drops unknown keys instead. Rows
    missing optional metrics (skip flags) are filled with the "" restval.
    """
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)

    with open(details_path, "w", encoding="utf-8") as f:
        for row in detailed_rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    cols = [
        "language_file",
        "direction",
        "n_samples",
        "bleu",
        "chrf++",
        "bertscore_f1",
        "comet",
        "elapsed_sec",
    ]
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=cols, extrasaction="ignore")
        writer.writeheader()
        if summary_rows:
            writer.writerows(summary_rows)
|
|
|
|
def resolve_openai_api_key(api_file: str, env_var: str) -> str:
    """Look up the OpenAI API key: the json key file wins, then the environment.

    Raises RuntimeError when neither source yields a non-empty key.
    """
    key_file = Path(api_file)
    if key_file.exists():
        with open(api_file, "r", encoding="utf-8") as f:
            api_keys = json.load(f)
        file_key = api_keys.get("openai")
        if file_key:
            return str(file_key)
    env_key = os.getenv(env_var)
    if env_key:
        return env_key
    raise RuntimeError(
        f"Missing OpenAI API key. Expected '{api_file}' with key 'openai' "
        f"or environment variable '{env_var}'."
    )
|
|
|
|
def main() -> None:
    """Orchestrate the evaluation: load data, align, translate, score, persist.

    Fixes: datetime.utcnow() is deprecated since Python 3.12 and returns a
    naive timestamp; use an explicit timezone-aware UTC clock instead. The
    function-local import keeps the file's top-level import block unchanged.
    """
    from datetime import timezone

    args = parse_args()
    api_key = resolve_openai_api_key(args.api_file, args.openai_api_key_env)

    rng = random.Random(args.seed)
    client = OpenAI(api_key=api_key)

    en_rows = load_json(args.en_file)
    lang_files = {"es": args.es_file, "fr": args.fr_file, "pt": args.pt_file}

    # Alignment and translation operate on the full clinical case text.
    field = "fulltext"
    en_map = dataset_to_examples(en_rows, field)
    lang_maps = {
        lang: dataset_to_examples(load_json(path), field)
        for lang, path in lang_files.items()
    }

    lang_name = {"en": "English", "es": "Spanish", "fr": "French", "pt": "Portuguese"}
    bert_lang = {"en": "en", "es": "es", "fr": "fr", "pt": "pt"}

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    run_dir = Path(args.output_dir) / f"run_{timestamp}"
    ensure_dir(str(run_dir))

    all_results = {
        "run_time_utc": datetime.now(timezone.utc).isoformat(),
        "settings": {
            "model": args.model,
            "field": field,
            "num_samples": args.num_samples,
            "max_chars": args.max_chars,
            "seed": args.seed,
            "files": {
                "en": args.en_file,
                "es": args.es_file,
                "fr": args.fr_file,
                "pt": args.pt_file,
            },
        },
        "scores": {},
    }

    detailed_rows: List[dict] = []
    summary_rows: List[dict] = []
    all_results["partial_scores"] = {}

    json_path = run_dir / "scores.json"
    details_path = run_dir / "translations.jsonl"
    csv_path = run_dir / "summary.csv"

    for tgt_lang, tgt_map in lang_maps.items():
        # Pair EN with the target language on the shared numeric case id.
        common_ids = sorted(set(en_map.keys()) & set(tgt_map.keys()))
        if not common_ids:
            print(f"[WARN] No aligned IDs between en and {tgt_lang}. Skipping.")
            continue
        k = min(args.num_samples, len(common_ids))
        sampled_ids = rng.sample(common_ids, k=k)

        pair_results = {}
        print(f"[INFO] Evaluating EN <-> {tgt_lang.upper()} with {k} samples")

        # Evaluate both directions on the same sampled ids.
        directions = [("en", tgt_lang), (tgt_lang, "en")]
        for src_lang, out_lang in directions:
            sources: List[str] = []
            refs: List[str] = []
            hyps: List[str] = []

            start = time.time()
            for idx, case_id in enumerate(sampled_ids, start=1):
                src_ex = en_map[case_id] if src_lang == "en" else tgt_map[case_id]
                ref_ex = tgt_map[case_id] if out_lang == tgt_lang else en_map[case_id]

                # Truncate both sides so hypothesis and reference cover
                # comparable spans of the document.
                src_text = truncate_text(src_ex.text, args.max_chars)
                ref_text = truncate_text(ref_ex.text, args.max_chars)

                hyp = translate_one(
                    client=client,
                    model=args.model,
                    text=src_text,
                    src_lang_name=lang_name[src_lang],
                    tgt_lang_name=lang_name[out_lang],
                    temperature=args.temperature,
                )

                sources.append(src_text)
                refs.append(ref_text)
                hyps.append(hyp)

                detailed_rows.append(
                    {
                        "target_language_file": tgt_lang,
                        "direction": f"{src_lang}_to_{out_lang}",
                        "case_id": case_id,
                        "src_raw_id": src_ex.raw_id,
                        "ref_raw_id": ref_ex.raw_id,
                        "source_text": src_text,
                        "reference_text": ref_text,
                        "hypothesis_text": hyp,
                    }
                )
                print(
                    f" [{src_lang}->{out_lang}] {idx}/{k} done "
                    f"(case_id={case_id})"
                )

                # Periodic checkpoint so long/expensive runs can be inspected
                # (and salvaged) before they finish.
                if args.save_every > 0 and (idx % args.save_every == 0):
                    partial_key = f"{tgt_lang}:{src_lang}_to_{out_lang}"
                    all_results["partial_scores"][partial_key] = {
                        "completed": idx,
                        "total": k,
                        **compute_bleu_chrf(hyps, refs),
                    }
                    persist_outputs(
                        json_path=json_path,
                        details_path=details_path,
                        csv_path=csv_path,
                        all_results=all_results,
                        detailed_rows=detailed_rows,
                        summary_rows=summary_rows,
                    )
                    print(
                        f" [checkpoint] saved at {idx}/{k} "
                        f"for {src_lang}->{out_lang}"
                    )

            # Final metrics for this direction; optional metrics stay None
            # when their package is unavailable.
            metric_dict = compute_bleu_chrf(hyps, refs)
            if not args.skip_bertscore:
                metric_dict["bertscore_f1"] = maybe_compute_bertscore(
                    hyps, refs, bert_lang[out_lang]
                )
            if not args.skip_comet:
                metric_dict["comet"] = maybe_compute_comet(sources, hyps, refs)

            metric_dict["n_samples"] = k
            metric_dict["elapsed_sec"] = round(time.time() - start, 2)
            key = f"{src_lang}_to_{out_lang}"
            pair_results[key] = metric_dict

            summary_rows.append(
                {
                    "language_file": tgt_lang,
                    "direction": key,
                    **metric_dict,
                }
            )

        all_results["scores"][tgt_lang] = pair_results

    # Final write of all three artifacts.
    persist_outputs(
        json_path=json_path,
        details_path=details_path,
        csv_path=csv_path,
        all_results=all_results,
        detailed_rows=detailed_rows,
        summary_rows=summary_rows,
    )

    print("\n=== Translation Evaluation Complete ===")
    print(f"Run directory: {run_dir}")
    print(f"Scores JSON: {json_path}")
    print(f"Summary CSV: {csv_path}")
    print(f"Details JSONL: {details_path}")
    if not args.skip_comet:
        print("Note: COMET requires the `unbabel-comet` package and model download.")
    if not args.skip_bertscore:
        print("Note: BERTScore requires the `bert-score` package.")
|
|
|
|
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # 130 is the conventional exit status for SIGINT-terminated processes.
        print("\nInterrupted by user.")
        sys.exit(130)
|
|