| | |
| | """ |
| | Evaluate GPT-5.2 translation quality on MultiClinSum files. |
| | |
| | What this script does: |
| | 1) Loads EN/ES/FR/PT json files (expects fields like id/fulltext/summary) |
| | 2) Aligns EN with each non-EN language by shared numeric case id |
| | 3) Samples N aligned instances per language pair |
| | 4) Runs bidirectional translation with GPT-5.2: |
| | - EN -> X |
| | - X -> EN |
| | 5) Reports common MT metrics used in top venues: |
| | - BLEU (sacreBLEU) |
| | - chrF++ (sacreBLEU chrF) |
| | - COMET (if installed) |
| | - BERTScore F1 (if installed) |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import csv |
| | import json |
| | import os |
| | import random |
| | import re |
| | import sys |
| | import time |
| | from dataclasses import dataclass |
| | from datetime import datetime |
| | from pathlib import Path |
| | from typing import Dict, List, Optional, Tuple |
| |
|
| | from openai import OpenAI |
| | import sacrebleu |
| |
|
| |
|
# Captures the digits of a trailing "_<digits>.txt" suffix in a raw record id;
# normalize_case_id() uses the captured number as the alignment key.
ID_NUM_RE = re.compile(r"_(\d+)\.txt$")
| |
|
| |
|
@dataclass
class Example:
    """A single text record taken from one language's JSON file."""

    # Key used to pair this record with the same case in another language.
    case_id: str
    # Text content of the record (whitespace-stripped by the loader).
    text: str
    # The record's "id" value as it appeared in the JSON row.
    raw_id: str
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a translation-evaluation run.

    Options are declared in a table (flag, add_argument keyword arguments)
    and registered in order, so ``--help`` lists them in the order below.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    option_table = [
        (
            "--en-file",
            {
                "default": "data/testing_data_gs/multiclinsum_gs_train_en.json",
                "help": "Path to English json file",
            },
        ),
        (
            "--es-file",
            {
                "default": "data/testing_data_gs/multiclinsum_gs_train_es.json",
                "help": "Path to Spanish json file",
            },
        ),
        (
            "--fr-file",
            {
                "default": "data/testing_data_gs/multiclinsum_gs_train_fr.json",
                "help": "Path to French json file",
            },
        ),
        (
            "--pt-file",
            {
                "default": "data/testing_data_gs/multiclinsum_gs_train_pt.json",
                "help": "Path to Portuguese json file",
            },
        ),
        (
            "--num-samples",
            {"type": int, "default": 20, "help": "Samples per language pair"},
        ),
        ("--seed", {"type": int, "default": 42, "help": "Random seed"}),
        ("--model", {"default": "gpt-5.2", "help": "OpenAI model name"}),
        (
            "--max-chars",
            {
                "type": int,
                "default": 2500,
                "help": "Character cap per sample to control cost/latency",
            },
        ),
        (
            "--openai-api-key-env",
            {
                "default": "OPENAI_API_KEY",
                "help": "Environment variable that stores OpenAI API key",
            },
        ),
        (
            "--api-file",
            {
                "default": "/home/mshahidul/api_new.json",
                "help": "JSON file containing API keys (expects key 'openai')",
            },
        ),
        (
            "--output-dir",
            {
                "default": "/home/mshahidul/readctrl/code/translation_quality_check",
                "help": "Directory to save outputs",
            },
        ),
        (
            "--skip-comet",
            {"action": "store_true", "help": "Skip COMET even if installed"},
        ),
        (
            "--skip-bertscore",
            {"action": "store_true", "help": "Skip BERTScore even if installed"},
        ),
        (
            "--temperature",
            {"type": float, "default": 0.0, "help": "Decoding temperature"},
        ),
        (
            "--save-every",
            {
                "type": int,
                "default": 10,
                "help": "Checkpoint save interval (in translated instances)",
            },
        ),
    ]

    cli = argparse.ArgumentParser(description="GPT-5.2 translation evaluation")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
| |
|
| |
|
def load_json(path: str) -> List[dict]:
    """Read the UTF-8 encoded JSON file at *path* and return its decoded content.

    The callers in this script expect the document to be a list of row
    dictionaries; the value is returned as decoded, without validation.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
| |
|
| |
|
def normalize_case_id(raw_id: str) -> str:
    """Reduce a raw record id to the key used for cross-language alignment.

    When *raw_id* ends in ``_<digits>.txt`` the digit string is returned;
    any other id is returned unchanged.
    """
    found = ID_NUM_RE.search(raw_id)
    return raw_id if found is None else found.group(1)
| |
|
| |
|
def dataset_to_examples(rows: List[dict], field: str) -> Dict[str, Example]:
    """Index JSON rows by normalized case id.

    Args:
        rows: Row dictionaries; each is read for ``"id"`` and the text field.
        field: Name of the preferred text field.

    Returns:
        Mapping from normalized case id to ``Example``. Rows whose text is
        empty after stripping are dropped; when two rows share a case id the
        later row replaces the earlier one.
    """
    examples: Dict[str, Example] = {}
    for record in rows:
        original_id = str(record.get("id", ""))
        value = record.get(field)
        if value is None:
            # Preferred field absent: fall back to summary, then fulltext.
            value = record.get("summary") or record.get("fulltext") or ""
        cleaned = str(value).strip()
        if cleaned:
            key = normalize_case_id(original_id)
            examples[key] = Example(case_id=key, text=cleaned, raw_id=original_id)
    return examples
| |
|
| |
|
def truncate_text(text: str, max_chars: int) -> str:
    """Cap *text* at *max_chars* characters, marking any cut with `` ...``.

    A non-positive *max_chars* disables the cap. When the text is cut,
    trailing whitespace of the kept prefix is removed before the marker
    is appended, so the result may be up to four characters longer than
    *max_chars*.
    """
    over_limit = max_chars > 0 and len(text) > max_chars
    if not over_limit:
        return text
    kept = text[:max_chars].rstrip()
    return kept + " ..."
| |
|
| |
|
def translate_one(
    client: OpenAI,
    model: str,
    text: str,
    src_lang_name: str,
    tgt_lang_name: str,
    temperature: float,
) -> str:
    """Translate *text* with one call to ``client.responses.create``.

    Args:
        client: OpenAI client used for the request.
        model: Model name forwarded to the API.
        text: Source text to translate.
        src_lang_name: Human-readable source language name used in the prompt.
        tgt_lang_name: Human-readable target language name used in the prompt.
        temperature: Decoding temperature forwarded to the API.

    Returns:
        The response's ``output_text`` with surrounding whitespace removed.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are a professional medical translator. "
                "Translate faithfully and naturally. Preserve meaning, numbers, units, "
                "named entities, and clinical terminology. Return only the translation."
            ),
        },
        {
            "role": "user",
            "content": (
                f"Translate the following text from {src_lang_name} to {tgt_lang_name}.\n\n{text}"
            ),
        },
    ]
    reply = client.responses.create(model=model, temperature=temperature, input=messages)
    translation = reply.output_text
    return translation.strip()
| |
|
| |
|
def compute_bleu_chrf(hypotheses: List[str], references: List[str]) -> Dict[str, float]:
    """Compute corpus-level BLEU and chrF++ with sacreBLEU.

    Args:
        hypotheses: System translations, one per segment.
        references: Reference translations, parallel to *hypotheses*
            (a single reference stream).

    Returns:
        ``{"bleu": ..., "chrf++": ...}`` with both scores rounded to four
        decimals. An empty corpus yields ``0.0`` for both scores instead of
        being passed to sacreBLEU.
    """
    if not hypotheses:
        # Nothing to score; avoid handing an empty corpus to sacreBLEU.
        return {"bleu": 0.0, "chrf++": 0.0}

    bleu = sacrebleu.corpus_bleu(hypotheses, [references]).score
    # chrF++ is chrF with word n-grams up to order 2. sacreBLEU's default
    # (word_order=0) is plain chrF, which would not match the "chrf++" label.
    chrf_pp = sacrebleu.corpus_chrf(hypotheses, [references], word_order=2).score
    return {"bleu": round(bleu, 4), "chrf++": round(chrf_pp, 4)}
| |
|
| |
|
def maybe_compute_bertscore(
    hypotheses: List[str],
    references: List[str],
    target_lang: str,
) -> Optional[float]:
    """Return the mean BERTScore F1, or ``None`` if ``bert_score`` cannot be imported.

    Args:
        hypotheses: System translations.
        references: Reference translations, parallel to *hypotheses*.
        target_lang: Language code passed to ``bert_score`` as ``lang``.

    Returns:
        Mean F1 over all segments rounded to six decimals, or ``None`` when
        importing ``bert_score`` fails for any reason.
    """
    try:
        from bert_score import score as run_bertscore
    except Exception:
        # Optional dependency: treat any import failure as "not available".
        return None

    _precision, _recall, f1_per_segment = run_bertscore(
        hypotheses,
        references,
        lang=target_lang,
        verbose=False,
    )
    mean_f1 = float(f1_per_segment.mean().item())
    return round(mean_f1, 6)
| |
|
| |
|
| | def maybe_compute_comet( |
| | sources: List[str], |
| | hypotheses: List[str], |
| | references: List[str], |
| | ) -> Optional[float]: |
| | try: |
| | from comet import download_model, load_from_checkpoint |
| | except Exception: |
| | return None |
| | model_path = download_model("Unbabel/wmt22-comet-da") |
| | comet_model = load_from_checkpoint(model_path) |
| | data = [{"src": s, "mt": h, "ref": r} for s, h, r in zip(sources, hypotheses, references)] |
| | result = comet_model.predict(data, batch_size=8, gpus=1 if os.environ.get("CUDA_VISIBLE_DEVICES") else 0) |
| | return round(float(result.system_score), 6) |
| |
|
| |
|
def ensure_dir(path: str) -> None:
    """Create directory *path*, including missing parents; no error if it exists."""
    target = Path(path)
    target.mkdir(parents=True, exist_ok=True)
| |
|
| |
|
def persist_outputs(
    json_path: Path,
    details_path: Path,
    csv_path: Path,
    all_results: dict,
    detailed_rows: List[dict],
    summary_rows: List[dict],
) -> None:
    """Write the current run state to disk, overwriting all three output files.

    Args:
        json_path: Destination for *all_results* as indented JSON.
        details_path: Destination for *detailed_rows* as JSON Lines.
        csv_path: Destination for *summary_rows* as CSV with a fixed header.
        all_results: Settings and scores collected so far.
        detailed_rows: One dictionary per translated instance.
        summary_rows: One dictionary per finished direction; keys must be
            among the CSV columns below.
    """
    summary_columns = (
        "language_file",
        "direction",
        "n_samples",
        "bleu",
        "chrf++",
        "bertscore_f1",
        "comet",
        "elapsed_sec",
    )

    with open(json_path, "w", encoding="utf-8") as scores_file:
        json.dump(all_results, scores_file, ensure_ascii=False, indent=2)

    with open(details_path, "w", encoding="utf-8") as details_file:
        details_file.writelines(
            json.dumps(entry, ensure_ascii=False) + "\n" for entry in detailed_rows
        )

    # newline="" lets the csv module control line endings itself.
    with open(csv_path, "w", encoding="utf-8", newline="") as summary_file:
        table = csv.DictWriter(summary_file, fieldnames=list(summary_columns))
        table.writeheader()
        table.writerows(summary_rows or [])
| |
|
| |
|
def resolve_openai_api_key(api_file: str, env_var: str) -> str:
    """Find the OpenAI API key, preferring the key file over the environment.

    Args:
        api_file: Path to a JSON file that may hold the key under ``"openai"``.
        env_var: Name of the environment variable used as a fallback.

    Returns:
        The key from *api_file* when the file exists and has a non-empty
        ``"openai"`` entry; otherwise the non-empty value of *env_var*.

    Raises:
        RuntimeError: If neither source provides a key.
    """
    key_file = Path(api_file)
    if key_file.exists():
        with key_file.open("r", encoding="utf-8") as handle:
            stored_keys = json.load(handle)
        file_key = stored_keys.get("openai")
        if file_key:
            return str(file_key)

    env_key = os.getenv(env_var)
    if not env_key:
        raise RuntimeError(
            f"Missing OpenAI API key. Expected '{api_file}' with key 'openai' "
            f"or environment variable '{env_var}'."
        )
    return env_key
| |
|
| |
|
def main() -> None:
    """Run the bidirectional EN<->{ES, FR, PT} translation evaluation.

    Steps:
        1. Parse options, resolve the API key, and load the four JSON files.
        2. For each non-English language, align records with English by case
           id and draw a seeded random sample.
        3. Translate each sampled record EN->X and X->EN, scoring hypotheses
           against the other language's (truncated) text.
        4. Write ``scores.json``, ``translations.jsonl`` and ``summary.csv``
           into a timestamped run directory.

    Outputs are written at every ``--save-every`` checkpoint and once more when
    the loops end for any reason, including an API error or Ctrl-C, so work
    done since the last checkpoint is not lost.

    Raises:
        SystemExit: If ``--num-samples`` is not a positive integer.
        RuntimeError: If no OpenAI API key can be found.
    """
    # Local import: only needed here to take a timezone-aware "now".
    from datetime import timezone

    args = parse_args()
    if args.num_samples <= 0:
        # Zero would hand empty corpora to the metrics; a negative value makes
        # random.sample fail with an unrelated-looking error.
        raise SystemExit("--num-samples must be a positive integer")
    api_key = resolve_openai_api_key(args.api_file, args.openai_api_key_env)

    rng = random.Random(args.seed)
    client = OpenAI(api_key=api_key)

    en_rows = load_json(args.en_file)
    lang_files = {"es": args.es_file, "fr": args.fr_file, "pt": args.pt_file}

    # Text field that is translated and used as the reference.
    field = "fulltext"
    en_map = dataset_to_examples(en_rows, field)
    lang_maps = {
        lang: dataset_to_examples(load_json(path), field)
        for lang, path in lang_files.items()
    }

    lang_name = {"en": "English", "es": "Spanish", "fr": "French", "pt": "Portuguese"}
    bert_lang = {"en": "en", "es": "es", "fr": "fr", "pt": "pt"}

    # One naive-UTC instant for both the directory name and the recorded run
    # time. Equivalent in format to the deprecated datetime.utcnow().
    started_at = datetime.now(timezone.utc).replace(tzinfo=None)
    timestamp = started_at.strftime("%Y%m%d_%H%M%S")
    run_dir = Path(args.output_dir) / f"run_{timestamp}"
    ensure_dir(str(run_dir))

    all_results = {
        "run_time_utc": started_at.isoformat(),
        "settings": {
            "model": args.model,
            "field": field,
            "num_samples": args.num_samples,
            "max_chars": args.max_chars,
            "seed": args.seed,
            "files": {
                "en": args.en_file,
                "es": args.es_file,
                "fr": args.fr_file,
                "pt": args.pt_file,
            },
        },
        "scores": {},
        "partial_scores": {},
    }

    detailed_rows: List[dict] = []
    summary_rows: List[dict] = []

    json_path = run_dir / "scores.json"
    details_path = run_dir / "translations.jsonl"
    csv_path = run_dir / "summary.csv"

    try:
        for tgt_lang, tgt_map in lang_maps.items():
            # Sorting makes the seeded sample independent of dict/set ordering.
            common_ids = sorted(set(en_map) & set(tgt_map))
            if not common_ids:
                print(f"[WARN] No aligned IDs between en and {tgt_lang}. Skipping.")
                continue
            k = min(args.num_samples, len(common_ids))
            sampled_ids = rng.sample(common_ids, k=k)

            print(f"[INFO] Evaluating EN <-> {tgt_lang.upper()} with {k} samples")

            for src_lang, out_lang in (("en", tgt_lang), (tgt_lang, "en")):
                if src_lang == "en":
                    src_map, ref_map = en_map, tgt_map
                else:
                    src_map, ref_map = tgt_map, en_map

                sources: List[str] = []
                refs: List[str] = []
                hyps: List[str] = []
                direction = f"{src_lang}_to_{out_lang}"

                start = time.time()
                for idx, case_id in enumerate(sampled_ids, start=1):
                    src_ex = src_map[case_id]
                    ref_ex = ref_map[case_id]

                    src_text = truncate_text(src_ex.text, args.max_chars)
                    ref_text = truncate_text(ref_ex.text, args.max_chars)

                    hyp = translate_one(
                        client=client,
                        model=args.model,
                        text=src_text,
                        src_lang_name=lang_name[src_lang],
                        tgt_lang_name=lang_name[out_lang],
                        temperature=args.temperature,
                    )

                    sources.append(src_text)
                    refs.append(ref_text)
                    hyps.append(hyp)

                    detailed_rows.append(
                        {
                            "target_language_file": tgt_lang,
                            "direction": direction,
                            "case_id": case_id,
                            "src_raw_id": src_ex.raw_id,
                            "ref_raw_id": ref_ex.raw_id,
                            "source_text": src_text,
                            "reference_text": ref_text,
                            "hypothesis_text": hyp,
                        }
                    )
                    print(
                        f" [{src_lang}->{out_lang}] {idx}/{k} done "
                        f"(case_id={case_id})"
                    )

                    if args.save_every > 0 and idx % args.save_every == 0:
                        partial_key = f"{tgt_lang}:{direction}"
                        all_results["partial_scores"][partial_key] = {
                            "completed": idx,
                            "total": k,
                            **compute_bleu_chrf(hyps, refs),
                        }
                        persist_outputs(
                            json_path=json_path,
                            details_path=details_path,
                            csv_path=csv_path,
                            all_results=all_results,
                            detailed_rows=detailed_rows,
                            summary_rows=summary_rows,
                        )
                        print(
                            f" [checkpoint] saved at {idx}/{k} "
                            f"for {src_lang}->{out_lang}"
                        )

                metric_dict = compute_bleu_chrf(hyps, refs)
                # Optional metrics are always recorded: a float when computed,
                # None when skipped by flag or when the package is unavailable.
                metric_dict["bertscore_f1"] = (
                    None
                    if args.skip_bertscore
                    else maybe_compute_bertscore(hyps, refs, bert_lang[out_lang])
                )
                metric_dict["comet"] = (
                    None
                    if args.skip_comet
                    else maybe_compute_comet(sources, hyps, refs)
                )

                metric_dict["n_samples"] = k
                metric_dict["elapsed_sec"] = round(time.time() - start, 2)

                # Register each direction as soon as it is scored so a later
                # failure still leaves it in scores.json.
                all_results["scores"].setdefault(tgt_lang, {})[direction] = metric_dict
                summary_rows.append(
                    {
                        "language_file": tgt_lang,
                        "direction": direction,
                        **metric_dict,
                    }
                )
    finally:
        # Runs on success and on any failure/interrupt alike, so everything
        # collected so far reaches disk.
        persist_outputs(
            json_path=json_path,
            details_path=details_path,
            csv_path=csv_path,
            all_results=all_results,
            detailed_rows=detailed_rows,
            summary_rows=summary_rows,
        )

    print("\n=== Translation Evaluation Complete ===")
    print(f"Run directory: {run_dir}")
    print(f"Scores JSON: {json_path}")
    print(f"Summary CSV: {csv_path}")
    print(f"Details JSONL: {details_path}")
    if not args.skip_comet:
        print("Note: COMET requires the `unbabel-comet` package and model download.")
    if not args.skip_bertscore:
        print("Note: BERTScore requires the `bert-score` package.")
| |
|
| |
|
# Script entry point. Ctrl-C ends the run with a short message and exit
# status 130 instead of a traceback.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nInterrupted by user.")
        raise SystemExit(130)
| |
|