| """ |
| metrics.py |
| ---------- |
| Evaluation metrics for 3 tasks: |
| |
| Findings / Impression Generation: |
| - BLEU-1, BLEU-4 |
| - ROUGE-L |
| - METEOR |
| - BERTScore (F1) |
| - ClinicalF1 (via CheXbert — clinical correctness metric) |
| |
| VQA: |
| - Accuracy (exact match) |
| - Token-level F1 |
| - BLEU-1 (for open-ended answers) |
| - METEOR (synonym + stem aware) |
| - BERTScore (semantic similarity) |
| - LLM-as-Judge (optional, GPT/Claude/Gemini for clinical semantic eval) |
| """ |
|
|
| import os |
| import re |
| import json |
| import time |
| from typing import List, Dict, Optional, Tuple |
|
|
| import torch |
| import numpy as np |
| from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction |
| from nltk.translate.meteor_score import meteor_score as nltk_meteor |
| from rouge_score import rouge_scorer |
|
|
|
|
| |
| |
def _ensure_nltk_data():
    """Make sure the NLTK data packages listed below exist locally.

    Each resource path is probed with ``nltk.data.find``; only when that
    raises ``LookupError`` is the matching package fetched with a quiet
    ``nltk.download``. Already-present resources are left untouched.
    """
    import nltk

    # (download package name, resource path probed by nltk.data.find)
    # NOTE(review): "punkt" appears unused by this module's visible code
    # (tokenization is plain str.split) — confirm before removing.
    required_resources = (
        ("wordnet", "corpora/wordnet"),
        ("omw-1.4", "corpora/omw-1.4"),
        ("punkt", "tokenizers/punkt"),
    )
    for package_name, resource_path in required_resources:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package_name, quiet=True)
|
|
|
|
| |
|
|
def compute_bleu(
    hypotheses: List[str],
    references: List[str],
) -> Dict[str, float]:
    """
    Corpus-level BLEU-1 and BLEU-4 over lowercased, whitespace-split text.

    Every hypothesis is paired with the single reference at the same index.
    NLTK's smoothing method1 is applied so that missing higher-order n-gram
    matches do not force the score to zero.

    Args:
        hypotheses: list of generated texts
        references: list of ground truth texts, parallel to hypotheses

    Returns:
        {"bleu1": float, "bleu4": float}, each rounded to 4 decimals
    """
    def _tokens(text):
        return text.lower().split()

    smoothing = SmoothingFunction().method1

    # corpus_bleu takes a list of reference token-lists per hypothesis.
    reference_corpus = [[_tokens(ref)] for ref in references]
    hypothesis_corpus = [_tokens(hyp) for hyp in hypotheses]

    weight_table = (
        ("bleu1", (1, 0, 0, 0)),
        ("bleu4", (0.25, 0.25, 0.25, 0.25)),
    )
    scores = {}
    for name, weights in weight_table:
        raw = corpus_bleu(
            reference_corpus,
            hypothesis_corpus,
            weights=weights,
            smoothing_function=smoothing,
        )
        scores[name] = round(raw, 4)
    return scores
|
|
|
|
def compute_rouge(
    hypotheses: List[str],
    references: List[str],
) -> Dict[str, float]:
    """
    Compute mean sentence-level ROUGE-1, ROUGE-2 and ROUGE-L F-measures.

    Each (hypothesis, reference) pair is scored with stemming enabled
    (``use_stemmer=True``), and the per-pair F-measures are averaged. Pairs
    are formed with ``zip``, so extra items in the longer list are ignored.

    Args:
        hypotheses: generated texts
        references: ground-truth texts, parallel to ``hypotheses``

    Returns:
        {"rouge1": float, "rouge2": float, "rougeL": float}, rounded to 4
        decimals. All three are 0.0 when there is nothing to score, instead
        of the NaN that averaging an empty list would give.
    """
    keys = ("rouge1", "rouge2", "rougeL")
    pairs = list(zip(hypotheses, references))
    if not pairs:
        # np.mean([]) is NaN (with a RuntimeWarning); report 0.0 instead,
        # matching compute_meteor's empty-input behavior.
        return {key: 0.0 for key in keys}

    scorer = rouge_scorer.RougeScorer(list(keys), use_stemmer=True)

    per_key = {key: [] for key in keys}
    for hyp, ref in pairs:
        # RougeScorer.score takes (target, prediction): reference first.
        scores = scorer.score(ref, hyp)
        for key in keys:
            per_key[key].append(scores[key].fmeasure)

    return {key: round(float(np.mean(values)), 4) for key, values in per_key.items()}
|
|
|
|
def compute_meteor(
    hypotheses: List[str],
    references: List[str],
) -> Dict[str, float]:
    """
    Mean sentence-level METEOR over (hypothesis, reference) pairs.

    Each pair is lowercased, whitespace-tokenized and scored with NLTK's
    ``meteor_score`` against its single reference. The per-pair scores are
    averaged; this is an average of sentence scores, not a corpus-level
    METEOR. A pair in which either side has no tokens scores 0.0.

    METEOR improves over BLEU by:
      - Matching synonyms via WordNet ("big" ↔ "large")
      - Matching stems ("enlarged" ↔ "enlarging")
      - Balancing precision + recall (weighted F-mean)
      - Penalizing fragmented matches (chunk penalty)

    Especially useful for radiology where paraphrasing is common.

    Returns:
        {"meteor": float} rounded to 4 decimals; 0.0 when there are no pairs.
    """
    _ensure_nltk_data()

    def _pair_score(hyp, ref):
        ref_words = ref.lower().split()
        hyp_words = hyp.lower().split()
        if ref_words and hyp_words:
            return nltk_meteor([ref_words], hyp_words)
        return 0.0

    pair_scores = [_pair_score(hyp, ref) for hyp, ref in zip(hypotheses, references)]
    average = float(np.mean(pair_scores)) if pair_scores else 0.0
    return {"meteor": round(average, 4)}
|
|
|
|
def compute_bertscore(
    hypotheses: List[str],
    references: List[str],
    model_type: str = "distilbert-base-uncased",
    device: str = "cpu",
    num_layers: Optional[int] = None,
) -> Dict[str, float]:
    """
    Compute mean BERTScore F1 (semantic similarity).

    Uses distilbert by default for speed. For higher clinical relevance pass
    e.g. model_type='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'
    together with ``num_layers``: bert-score only has a default layer for
    models in its built-in table, so other models need it set explicitly.

    Args:
        hypotheses: generated texts
        references: ground-truth texts, parallel to ``hypotheses``
        model_type: model name forwarded to ``bert_score.score``
        device: device string forwarded to ``bert_score.score``
        num_layers: hidden layer to take embeddings from; None lets
            bert-score choose its default for ``model_type``

    Returns:
        {"bertscore_f1": float} rounded to 4 decimals. 0.0 when there is
        nothing to score or when the bert-score package is not installed.
    """
    if not hypotheses or not references:
        # Nothing to score; don't load a model just to fail on an empty batch.
        return {"bertscore_f1": 0.0}

    try:
        from bert_score import score as bert_score
    except ImportError:
        print("[WARNING] bert-score not installed. Skipping BERTScore.")
        return {"bertscore_f1": 0.0}

    _, _, f1 = bert_score(
        hypotheses,
        references,
        model_type=model_type,
        num_layers=num_layers,
        device=device,
        verbose=False,
    )
    return {"bertscore_f1": round(f1.mean().item(), 4)}
|
|
|
|
| |
|
|
def compute_clinical_f1(
    hypotheses: List[str],
    references: List[str],
    chexbert_path: Optional[str] = None,
    device: str = "cpu",
) -> Dict[str, float]:
    """
    Clinical precision / recall / F1 derived from CheXbert labels.

    Generated and ground-truth reports are both run through the CheXbert
    labeler. A finding counts as present only where its label equals 1; all
    other label values are treated as absent. Macro-averaged precision,
    recall and F1 are then computed with scikit-learn (``zero_division=0``).

    This is the primary clinical correctness metric used by RaDialog,
    CheXagent, and most CXR report generation papers.

    Args:
        hypotheses: generated report texts
        references: ground truth report texts
        chexbert_path: path to CheXbert model weights
            (download from: stanfordmlgroup.github.io/projects/chexbert)
        device: device string forwarded to the labeler

    Returns:
        {"clinical_f1": float, "clinical_precision": float,
         "clinical_recall": float}, rounded to 4 decimals. All three are 0.0
        when ``chexbert_path`` is None or when any step fails; a warning is
        printed in both cases.
    """
    def _zeros():
        return {"clinical_f1": 0.0, "clinical_precision": 0.0, "clinical_recall": 0.0}

    if chexbert_path is None:
        print("[WARNING] chexbert_path not provided. Skipping ClinicalF1.")
        return _zeros()

    # Best-effort: any failure (missing package, bad weights, shape mismatch)
    # degrades to zeros with a warning instead of aborting the evaluation.
    try:
        from chexbert.label import label as chexbert_label

        def _positive_mask(texts):
            # NOTE(review): assumes the labeler returns one row per report and
            # one column per observation. If it returns (observations, reports)
            # instead, the macro average below runs over the wrong axis —
            # confirm against the chexbert wrapper in use.
            raw_labels = chexbert_label(chexbert_path, texts, device=device)
            return (np.array(raw_labels) == 1).astype(int)

        hyp_positive = _positive_mask(hypotheses)
        ref_positive = _positive_mask(references)

        from sklearn.metrics import f1_score, precision_score, recall_score

        shared_kwargs = {"average": "macro", "zero_division": 0}
        f1 = f1_score(ref_positive, hyp_positive, **shared_kwargs)
        precision = precision_score(ref_positive, hyp_positive, **shared_kwargs)
        recall = recall_score(ref_positive, hyp_positive, **shared_kwargs)
    except Exception as e:
        print(f"[WARNING] ClinicalF1 computation failed: {e}")
        return _zeros()

    return {
        "clinical_f1": round(f1, 4),
        "clinical_precision": round(precision, 4),
        "clinical_recall": round(recall, 4),
    }
|
|
|
|
| |
|
|
def compute_vqa_accuracy(
    hypotheses: List[str],
    references: List[str],
) -> Dict[str, float]:
    """
    VQA accuracy metrics over (hypothesis, reference) pairs:
      - Exact match accuracy after ``_normalize_answer`` (case-insensitive,
        punctuation removed, whitespace collapsed)
      - Token F1 between the normalized answers, as computed by ``_token_f1``

    Pairs are formed with ``zip``, so extra items in the longer list are
    ignored.

    Returns:
        {"vqa_exact_match": float, "vqa_token_f1": float}, rounded to 4
        decimals; both 0.0 when there are no pairs (instead of NaN).
    """
    exact_matches = []
    token_f1s = []

    for hyp, ref in zip(hypotheses, references):
        hyp_norm = _normalize_answer(hyp)
        ref_norm = _normalize_answer(ref)
        exact_matches.append(int(hyp_norm == ref_norm))
        token_f1s.append(_token_f1(hyp_norm, ref_norm))

    if not exact_matches:
        # np.mean([]) is NaN with a RuntimeWarning; report 0.0 like
        # compute_meteor does for empty input.
        return {"vqa_exact_match": 0.0, "vqa_token_f1": 0.0}

    return {
        "vqa_exact_match": round(float(np.mean(exact_matches)), 4),
        "vqa_token_f1": round(float(np.mean(token_f1s)), 4),
    }
|
|
|
|
| def _normalize_answer(text: str) -> str: |
| """Lowercase, remove punctuation, strip whitespace.""" |
| text = text.lower().strip() |
| text = re.sub(r"[^\w\s]", "", text) |
| text = re.sub(r"\s+", " ", text).strip() |
| return text |
|
|
|
|
| def _token_f1(prediction: str, ground_truth: str) -> float: |
| """Token-level F1 between two strings.""" |
| pred_tokens = prediction.split() |
| gt_tokens = ground_truth.split() |
|
|
| if not pred_tokens or not gt_tokens: |
| return 0.0 |
|
|
| common = set(pred_tokens) & set(gt_tokens) |
| if not common: |
| return 0.0 |
|
|
| precision = len(common) / len(pred_tokens) |
| recall = len(common) / len(gt_tokens) |
| f1 = 2 * precision * recall / (precision + recall) |
| return f1 |
|
|
|
|
| |
|
|
# Prompt template for compute_llm_judge, filled with str.format using the
# {question}, {reference} and {hypothesis} fields. The doubled braces in the
# JSON example ({{ ... }}) are str.format escapes that render as literal
# single braces. compute_llm_judge parses the reply with json.loads and
# clamps its "score" field to the 0-5 rubric below.
_LLM_JUDGE_PROMPT = """You are a clinical evaluator for chest X-ray VQA.
Judge whether the predicted answer is semantically equivalent to the ground
truth in a medical context. Be tolerant of synonyms ("cardiomegaly" =
"enlarged heart"), paraphrases, and extra/missing function words. Penalize
contradictions (e.g. negating a positive finding) or clinically wrong
content.

Question: {question}
Ground truth: {reference}
Prediction: {hypothesis}

Reply with ONLY a JSON object of the form: {{"score": <0-5 integer>, "reason": "<one short sentence>"}}
Scoring rubric:
5 = clinically equivalent
4 = mostly correct, minor omission
3 = partially correct
2 = mostly incorrect
1 = wrong but on topic
0 = contradicts ground truth / unrelated"""
|
|
|
|
def compute_llm_judge(
    hypotheses: List[str],
    references: List[str],
    questions: Optional[List[str]] = None,
    model: str = "gpt-4o-mini",
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    max_samples: Optional[int] = None,
    sleep_s: float = 0.0,
) -> Dict[str, float]:
    """
    Score (hyp, ref) pairs with an LLM judge (OpenAI-compatible API).

    Defaults to OpenAI's gpt-4o-mini (~$0.30 per 2k VQA samples at the time
    of writing; pricing may change). For other providers, pass:
      - Gemini : base_url="https://generativelanguage.googleapis.com/v1beta/openai/", model="gemini-1.5-flash"
      - Local  : base_url="http://localhost:11434/v1" (Ollama), model="llama3.1"
      - Anthropic: needs separate SDK — not supported via this OpenAI-compatible path.

    An api_key (or OPENAI_API_KEY) is required even for local endpoints;
    pass any placeholder string if the server ignores it.

    Args:
        hypotheses: predicted answers
        references: ground-truth answers, parallel to ``hypotheses``
        questions: optional questions, parallel to ``hypotheses``; missing,
            empty or out-of-range entries are rendered as "(not provided)"
        model: judge model name
        api_key: defaults to env var OPENAI_API_KEY
        base_url: override for non-OpenAI providers
        max_samples: cap evaluation cost (e.g. 200) — useful for sanity checks
        sleep_s: delay between calls to dodge rate limits

    Returns:
        {"llm_judge_mean": float (0-5), "llm_judge_norm": float (0-1),
         "llm_judge_n": int}. ``llm_judge_n`` counts only successfully scored
        samples; failed calls are printed and excluded from the mean. All
        zeros when the judge is unavailable or nothing could be scored.
    """
    empty = {"llm_judge_mean": 0.0, "llm_judge_norm": 0.0, "llm_judge_n": 0}

    try:
        from openai import OpenAI
    except ImportError:
        print("[WARNING] openai package not installed. Skipping LLM-judge.")
        return empty

    api_key = api_key or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("[WARNING] OPENAI_API_KEY not set. Skipping LLM-judge.")
        return empty

    client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)

    # Only score indices present in both lists (zip semantics, like the other
    # metrics). Indexing past the shorter list would raise outside the
    # per-sample try below and abort the whole run.
    n = min(len(hypotheses), len(references))
    if max_samples is not None:
        n = min(n, max_samples)

    question_list = list(questions) if questions else []

    scores = []
    for i in range(n):
        question = question_list[i] if i < len(question_list) else None
        prompt = _LLM_JUDGE_PROMPT.format(
            question=question or "(not provided)",
            reference=references[i],
            hypothesis=hypotheses[i],
        )
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=80,
                response_format={"type": "json_object"},
            )
            raw = resp.choices[0].message.content.strip()
            data = json.loads(raw)
            # Clamp to the 0-5 rubric in case the model strays outside it.
            score = max(0, min(5, int(data.get("score", 0))))
            scores.append(score)
        except Exception as e:
            # Best-effort per sample: a failed call or unparsable reply is
            # reported and excluded from the mean (not counted as 0).
            print(f"[LLM-judge] sample {i} failed: {e}")
        if sleep_s > 0:
            time.sleep(sleep_s)

    if not scores:
        return empty

    mean = float(np.mean(scores))
    return {
        "llm_judge_mean": round(mean, 4),
        "llm_judge_norm": round(mean / 5.0, 4),
        "llm_judge_n": len(scores),
    }
|
|
|
|
| |
|
|
def evaluate_all(
    hypotheses: List[str],
    references: List[str],
    task: str,
    chexbert_path: Optional[str] = None,
    device: str = "cpu",
    questions: Optional[List[str]] = None,
    llm_judge: bool = False,
    llm_judge_model: str = "gpt-4o-mini",
    llm_judge_base_url: Optional[str] = None,
    llm_judge_max_samples: Optional[int] = None,
) -> Dict[str, float]:
    """
    Compute all relevant metrics for a given task.

    Report tasks ("findings", "impression", "report"): BLEU-1/4, ROUGE-1/2/L,
    METEOR, BERTScore F1 and CheXbert ClinicalF1.
    VQA ("vqa"): exact match, token F1, BLEU-1/4, METEOR, BERTScore F1 and,
    when ``llm_judge`` is True, the LLM-judge scores.

    Args:
        hypotheses: model-generated texts
        references: ground truth texts
        task: "findings" | "impression" | "report" | "vqa"
        chexbert_path: for clinical F1 (optional)
        device: forwarded to BERTScore and the CheXbert labeler
        questions: VQA questions (passed to LLM judge for context)
        llm_judge: if True (VQA only), also score answers with an LLM judge
            through an OpenAI-compatible endpoint (GPT, Gemini, local Ollama;
            Anthropic is not supported). No api_key is passed through, so
            the OPENAI_API_KEY environment variable must be set.
        llm_judge_model, llm_judge_base_url, llm_judge_max_samples: forwarded
            to compute_llm_judge as ``model``, ``base_url``, ``max_samples``

    Returns:
        Dict of metric_name → score. Empty (with a printed warning) when
        ``task`` is not recognized.
    """
    results = {}

    if task in ("findings", "impression", "report"):
        # Report generation: n-gram overlap, semantic similarity and
        # CheXbert-based clinical correctness.
        results.update(compute_bleu(hypotheses, references))
        results.update(compute_rouge(hypotheses, references))
        results.update(compute_meteor(hypotheses, references))
        results.update(compute_bertscore(hypotheses, references, device=device))
        results.update(compute_clinical_f1(
            hypotheses, references, chexbert_path, device
        ))

    elif task == "vqa":
        # Short answers: exact/token match plus lexical and semantic metrics.
        # The LLM judge is opt-in because it calls an external API.
        results.update(compute_vqa_accuracy(hypotheses, references))
        results.update(compute_bleu(hypotheses, references))
        results.update(compute_meteor(hypotheses, references))
        results.update(compute_bertscore(hypotheses, references, device=device))
        if llm_judge:
            results.update(compute_llm_judge(
                hypotheses, references,
                questions=questions,
                model=llm_judge_model,
                base_url=llm_judge_base_url,
                max_samples=llm_judge_max_samples,
            ))

    else:
        # Keep the dict return contract, but don't let a typo in `task`
        # silently produce an empty result.
        print(
            f"[WARNING] Unknown task '{task}'. Expected one of 'findings', "
            "'impression', 'report', 'vqa'. No metrics computed."
        )

    return results
|
|
|
|
def print_results(results: Dict[str, float], task: str):
    """Pretty-print evaluation results as an aligned name/value table.

    Integer-valued entries (e.g. the sample count ``llm_judge_n``) are shown
    as plain integers; every other value is formatted with 4 decimals.
    """
    rule = "=" * 50
    print(f"\n{rule}")
    print(f"Evaluation Results — Task: {task.upper()}")
    print(rule)
    for metric, value in results.items():
        # bool is an int subclass; keep it on the float path as before.
        if isinstance(value, (int, np.integer)) and not isinstance(value, bool):
            print(f" {metric:<25} {value}")
        else:
            print(f" {metric:<25} {value:.4f}")
    print(f"{rule}\n")
|
|