cxr-vlm-code / evaluation /metrics.py
convitom
feat(eval): add METEOR + optional LLM-as-judge for VQA scoring
8f6cf28
"""
metrics.py
----------
Evaluation metrics for four tasks (findings, impression, merged report, VQA):
Findings / Impression / merged Report Generation:
- BLEU-1, BLEU-4
- ROUGE-1, ROUGE-2, ROUGE-L
- METEOR
- BERTScore (F1)
- ClinicalF1 (via CheXbert — clinical correctness metric)
VQA:
- Accuracy (exact match)
- Token-level F1
- BLEU-1 (for open-ended answers)
- METEOR (synonym + stem aware)
- BERTScore (semantic similarity)
- LLM-as-Judge (optional; any OpenAI-compatible chat model such as GPT, Gemini, or a local Ollama model, for clinical semantic eval)
"""
import os
import re
import json
import time
from typing import List, Dict, Optional, Tuple
import torch
import numpy as np
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score as nltk_meteor
from rouge_score import rouge_scorer
# NLTK resources needed by compute_meteor(): WordNet (+ its multilingual
# wordnet extension) for synonym matching, plus the punkt tokenizer models.
def _ensure_nltk_data():
    """Install any of the required NLTK resources that are not yet available.

    Each resource is looked up on the NLTK data path first; only resources
    that raise ``LookupError`` are downloaded (quietly). Once everything is
    present, repeated calls perform lookups only.
    """
    import nltk

    required_resources = (
        ("wordnet", "corpora/wordnet"),
        ("omw-1.4", "corpora/omw-1.4"),
        ("punkt", "tokenizers/punkt"),
    )
    for package_id, resource_path in required_resources:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            # Not found anywhere on the data path -> fetch it.
            nltk.download(package_id, quiet=True)
# ─── NLG Metrics ─────────────────────────────────────────────────────────────
def compute_bleu(
    hypotheses: List[str],
    references: List[str],
) -> Dict[str, float]:
    """
    Compute corpus-level BLEU-1 and BLEU-4.

    Texts are lower-cased and whitespace-tokenized; each hypothesis has a
    single reference. NLTK's smoothing ``method1`` is applied so short
    corpora with missing higher-order n-grams do not collapse to 0.

    Args:
        hypotheses: list of generated texts
        references: list of ground truth texts (same length as hypotheses)

    Returns:
        {"bleu1": float, "bleu4": float}; both 0.0 for an empty corpus.
    """
    if not hypotheses and not references:
        # corpus_bleu divides by the total hypothesis n-gram count, which is
        # zero for an empty corpus, so short-circuit here.
        return {"bleu1": 0.0, "bleu4": 0.0}

    smooth = SmoothingFunction().method1
    # corpus_bleu expects a list of reference *lists* per hypothesis.
    refs_tokenized = [[ref.lower().split()] for ref in references]
    hyps_tokenized = [hyp.lower().split() for hyp in hypotheses]
    bleu1 = corpus_bleu(refs_tokenized, hyps_tokenized,
                        weights=(1, 0, 0, 0), smoothing_function=smooth)
    bleu4 = corpus_bleu(refs_tokenized, hyps_tokenized,
                        weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smooth)
    # corpus_bleu returns the int 0 when there are no unigram matches; cast so
    # the result always matches the declared float values.
    return {"bleu1": round(float(bleu1), 4), "bleu4": round(float(bleu4), 4)}
def compute_rouge(
    hypotheses: List[str],
    references: List[str],
) -> Dict[str, float]:
    """
    Compute ROUGE-1, ROUGE-2 and ROUGE-L F-measures averaged over pairs.

    Each (hypothesis, reference) pair is scored independently with Porter
    stemming enabled, and the per-pair F-measures are averaged (macro
    average). Pairs are aligned with zip(), so surplus items in the longer
    list are ignored.

    Returns:
        {"rouge1": float, "rouge2": float, "rougeL": float}; all 0.0 when
        there are no pairs to score.
    """
    rouge_types = ("rouge1", "rouge2", "rougeL")
    pairs = list(zip(hypotheses, references))
    if not pairs:
        # np.mean([]) would be NaN (plus a RuntimeWarning); report 0.0 instead,
        # consistent with compute_meteor.
        return {name: 0.0 for name in rouge_types}

    scorer = rouge_scorer.RougeScorer(list(rouge_types), use_stemmer=True)
    per_type = {name: [] for name in rouge_types}
    for hyp, ref in pairs:
        scores = scorer.score(ref, hyp)  # rouge_score order: (target, prediction)
        for name in rouge_types:
            per_type[name].append(scores[name].fmeasure)
    return {name: round(float(np.mean(per_type[name])), 4) for name in rouge_types}
def compute_meteor(
    hypotheses: List[str],
    references: List[str],
) -> Dict[str, float]:
    """
    Compute METEOR as the mean of sentence-level scores.

    Each (hypothesis, reference) pair is scored with NLTK's ``meteor_score``
    against its single reference and the per-pair scores are averaged; this
    is not a corpus-level METEOR.

    METEOR improves over BLEU by:
      - Matching synonyms via WordNet ("big" ↔ "large")
      - Matching stems ("enlarged" ↔ "enlarging")
      - Balancing precision + recall (weighted F-mean)
      - Penalizing fragmented matches (chunk penalty)
    Especially useful for radiology where paraphrasing is common.

    Texts are lower-cased and whitespace-tokenized. A pair where either side
    is empty scores 0.0. Pairs are aligned with zip(), so surplus items in the
    longer list are ignored.

    Returns:
        {"meteor": float}; 0.0 when there are no pairs to score.
    """
    pairs = list(zip(hypotheses, references))
    if not pairs:
        # Nothing to score: skip the NLTK data lookup/download entirely.
        return {"meteor": 0.0}

    _ensure_nltk_data()
    scores = []
    for hyp, ref in pairs:
        ref_tokens = ref.lower().split()
        hyp_tokens = hyp.lower().split()
        if not hyp_tokens or not ref_tokens:
            scores.append(0.0)
            continue
        # nltk_meteor takes a list of tokenized references (here just one).
        scores.append(nltk_meteor([ref_tokens], hyp_tokens))
    return {"meteor": round(float(np.mean(scores)), 4)}
def compute_bertscore(
    hypotheses: List[str],
    references: List[str],
    model_type: str = "distilbert-base-uncased",
    device: str = "cpu",
    num_layers: Optional[int] = None,
) -> Dict[str, float]:
    """
    Compute BERTScore F1 (semantic similarity), averaged over pairs.

    Uses distilbert for speed; use e.g.
    'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract' for higher clinical
    relevance. bert_score only knows a default layer for models in its
    built-in table, so pass ``num_layers`` explicitly for other models.

    Args:
        hypotheses: generated texts
        references: ground truth texts
        model_type: Hugging Face model name forwarded to bert_score
        device: torch device string (e.g. "cpu", "cuda")
        num_layers: hidden layer to use; None lets bert_score pick its default

    Returns:
        {"bertscore_f1": float}; 0.0 when there is nothing to score or when
        the bert-score package is not installed (a warning is printed).
    """
    if not hypotheses and not references:
        # F1.mean() over an empty tensor would be NaN.
        return {"bertscore_f1": 0.0}

    try:
        from bert_score import score as bert_score
    except ImportError:
        print("[WARNING] bert-score not installed. Skipping BERTScore.")
        return {"bertscore_f1": 0.0}

    _, _, f1 = bert_score(
        hypotheses,
        references,
        model_type=model_type,
        num_layers=num_layers,
        device=device,
        verbose=False,
    )
    return {"bertscore_f1": round(f1.mean().item(), 4)}
# ─── Clinical F1 (CheXbert-based) ────────────────────────────────────────────
def compute_clinical_f1(
    hypotheses: List[str],
    references: List[str],
    chexbert_path: Optional[str] = None,
    device: str = "cpu",
) -> Dict[str, float]:
    """
    Clinical-correctness score based on the CheXbert report labeler.

    Generated and reference reports are both labeled with CheXbert. A label
    value of 1 is treated as "positive" and every other value as "not
    positive"; macro-averaged precision, recall and F1 are then computed with
    scikit-learn, using the reference labels as ground truth.

    This is the primary clinical correctness metric used by RaDialog,
    CheXagent, and most CXR report generation papers.

    Args:
        hypotheses: generated report texts
        references: ground truth report texts
        chexbert_path: path to CheXbert model weights
            (download from: stanfordmlgroup.github.io/projects/chexbert)
        device: device string forwarded to the labeler

    Returns:
        {"clinical_f1": float, "clinical_precision": float,
         "clinical_recall": float}. All three are 0.0 when chexbert_path is
        None or when any step of labeling/scoring raises; a warning is
        printed in both cases.
    """
    zero_scores = {"clinical_f1": 0.0, "clinical_precision": 0.0, "clinical_recall": 0.0}

    if chexbert_path is None:
        print("[WARNING] chexbert_path not provided. Skipping ClinicalF1.")
        return zero_scores

    try:
        # NOTE(review): assumes an installed `chexbert` package exposing
        # label(checkpoint_path, texts, device=...) — confirm against the
        # CheXbert install (github.com/stanfordmlgroup/CheXbert).
        from chexbert.label import label as chexbert_label

        # NOTE(review): assumes the labeler returns one row of class labels
        # per report (n_reports x n_classes); if it returns one list per
        # class instead, the arrays need transposing — TODO confirm.
        hyp_binary, ref_binary = (
            (np.array(chexbert_label(chexbert_path, texts, device=device)) == 1).astype(int)
            for texts in (hypotheses, references)
        )

        from sklearn.metrics import f1_score, precision_score, recall_score

        scorers = {
            "clinical_f1": f1_score,
            "clinical_precision": precision_score,
            "clinical_recall": recall_score,
        }
        return {
            key: round(fn(ref_binary, hyp_binary, average="macro", zero_division=0), 4)
            for key, fn in scorers.items()
        }
    except Exception as e:
        print(f"[WARNING] ClinicalF1 computation failed: {e}")
        return zero_scores
# ─── VQA Metrics ─────────────────────────────────────────────────────────────
def compute_vqa_accuracy(
    hypotheses: List[str],
    references: List[str],
) -> Dict[str, float]:
    """
    VQA accuracy metrics:
      - Exact match accuracy (after _normalize_answer: lower-cased,
        punctuation removed, whitespace collapsed)
      - Token F1 (bag-of-tokens overlap between normalized answers)

    Pairs are aligned with zip(), so surplus items in the longer list are
    ignored.

    Returns:
        {"vqa_exact_match": float, "vqa_token_f1": float}; both 0.0 when
        there are no pairs to score.
    """
    pairs = list(zip(hypotheses, references))
    if not pairs:
        # np.mean([]) would be NaN (plus a RuntimeWarning).
        return {"vqa_exact_match": 0.0, "vqa_token_f1": 0.0}

    exact_matches = []
    token_f1s = []
    for hyp, ref in pairs:
        hyp_norm = _normalize_answer(hyp)
        ref_norm = _normalize_answer(ref)
        exact_matches.append(float(hyp_norm == ref_norm))
        token_f1s.append(_token_f1(hyp_norm, ref_norm))
    return {
        "vqa_exact_match": round(float(np.mean(exact_matches)), 4),
        "vqa_token_f1": round(float(np.mean(token_f1s)), 4),
    }
def _normalize_answer(text: str) -> str:
"""Lowercase, remove punctuation, strip whitespace."""
text = text.lower().strip()
text = re.sub(r"[^\w\s]", "", text)
text = re.sub(r"\s+", " ", text).strip()
return text
def _token_f1(prediction: str, ground_truth: str) -> float:
"""Token-level F1 between two strings."""
pred_tokens = prediction.split()
gt_tokens = ground_truth.split()
if not pred_tokens or not gt_tokens:
return 0.0
common = set(pred_tokens) & set(gt_tokens)
if not common:
return 0.0
precision = len(common) / len(pred_tokens)
recall = len(common) / len(gt_tokens)
f1 = 2 * precision * recall / (precision + recall)
return f1
# ─── LLM-as-Judge (semantic correctness via GPT/Claude/Gemini) ───────────────
# Prompt template used by compute_llm_judge(). It is filled with str.format()
# using the placeholders {question}, {reference} and {hypothesis}; the braces
# of the example JSON reply are doubled ({{ }}) so str.format() emits them as
# literal single braces. The judge is asked for an integer score from 0 to 5.
_LLM_JUDGE_PROMPT = """You are a clinical evaluator for chest X-ray VQA.
Judge whether the predicted answer is semantically equivalent to the ground
truth in a medical context. Be tolerant of synonyms ("cardiomegaly" =
"enlarged heart"), paraphrases, and extra/missing function words. Penalize
contradictions (e.g. negating a positive finding) or clinically wrong
content.
Question: {question}
Ground truth: {reference}
Prediction: {hypothesis}
Reply with ONLY a JSON object of the form: {{"score": <0-5 integer>, "reason": "<one short sentence>"}}
Scoring rubric:
5 = clinically equivalent
4 = mostly correct, minor omission
3 = partially correct
2 = mostly incorrect
1 = wrong but on topic
0 = contradicts ground truth / unrelated"""
def compute_llm_judge(
    hypotheses: List[str],
    references: List[str],
    questions: Optional[List[str]] = None,
    model: str = "gpt-4o-mini",
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    max_samples: Optional[int] = None,
    sleep_s: float = 0.0,
) -> Dict[str, float]:
    """
    Score (hyp, ref) pairs with an LLM judge (OpenAI-compatible API).

    Defaults to OpenAI's gpt-4o-mini (author's estimate: ~$0.30 per 2k VQA
    samples). For alternatives, pass:
      - Gemini   : base_url="https://generativelanguage.googleapis.com/v1beta/openai/", model="gemini-1.5-flash"
      - Local    : base_url="http://localhost:11434/v1" (Ollama), model="llama3.1"
      - Anthropic: needs separate SDK — not supported via this OpenAI-compatible path.

    Args:
        hypotheses, references: parallel lists; only the first
            min(len(hypotheses), len(references)) pairs are judged
        questions: optional parallel list of questions; missing or empty
            entries are shown to the judge as "(not provided)"
        model: judge model name
        api_key: defaults to env var OPENAI_API_KEY
        base_url: override for non-OpenAI providers
        max_samples: cap evaluation cost (e.g. 200) — useful for sanity checks
        sleep_s: delay between calls to dodge rate limits

    Returns:
        {"llm_judge_mean": float (0-5), "llm_judge_norm": float (0-1),
         "llm_judge_n": int}. Samples whose request fails or whose reply has
        no parseable "score" are reported and excluded, so llm_judge_n is the
        number of successfully judged samples. All zeros when the openai
        package or an API key is missing.
    """
    try:
        from openai import OpenAI
    except ImportError:
        print("[WARNING] openai package not installed. Skipping LLM-judge.")
        return {"llm_judge_mean": 0.0, "llm_judge_norm": 0.0, "llm_judge_n": 0}
    api_key = api_key or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("[WARNING] OPENAI_API_KEY not set. Skipping LLM-judge.")
        return {"llm_judge_mean": 0.0, "llm_judge_norm": 0.0, "llm_judge_n": 0}
    client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)

    # Only judge aligned pairs (same truncation semantics as zip() in the
    # other metrics) so a short `references` list cannot raise IndexError.
    n = min(len(hypotheses), len(references))
    if max_samples is not None:
        n = min(n, max_samples)
    questions = list(questions or [])

    scores = []
    for i in range(n):
        question = questions[i] if i < len(questions) else ""
        prompt = _LLM_JUDGE_PROMPT.format(
            question=question or "(not provided)",
            reference=references[i],
            hypothesis=hypotheses[i],
        )
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=80,
                response_format={"type": "json_object"},
            )
            scores.append(_parse_judge_score(resp.choices[0].message.content))
        except Exception as e:
            # Best-effort: one bad request/reply must not abort the whole run.
            print(f"[LLM-judge] sample {i} failed: {e}")
        if sleep_s > 0:
            time.sleep(sleep_s)

    if not scores:
        return {"llm_judge_mean": 0.0, "llm_judge_norm": 0.0, "llm_judge_n": 0}
    mean = float(np.mean(scores))
    return {
        "llm_judge_mean": round(mean, 4),
        "llm_judge_norm": round(mean / 5.0, 4),  # 0..1 for easy comparison
        "llm_judge_n": len(scores),
    }


def _parse_judge_score(raw: Optional[str]) -> int:
    """Extract the judge's integer score (clamped to 0-5) from its reply text.

    Raises ValueError / json.JSONDecodeError when the reply is empty, is not
    a JSON object, or has no "score" key, so the caller counts it as a
    failed sample instead of silently recording a score of 0.
    """
    if not raw or not raw.strip():
        raise ValueError("empty judge response")
    text = raw.strip()
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Some OpenAI-compatible backends ignore response_format and wrap the
        # JSON in prose or code fences; fall back to the outermost {...} span.
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match is None:
            raise
        data = json.loads(match.group(0))
    if not isinstance(data, dict) or "score" not in data:
        raise ValueError(f"judge response has no 'score': {text!r}")
    return max(0, min(5, int(data["score"])))
# ─── Master Evaluation Function ──────────────────────────────────────────────
def evaluate_all(
    hypotheses: List[str],
    references: List[str],
    task: str,
    chexbert_path: Optional[str] = None,
    device: str = "cpu",
    questions: Optional[List[str]] = None,
    llm_judge: bool = False,
    llm_judge_model: str = "gpt-4o-mini",
    llm_judge_base_url: Optional[str] = None,
    llm_judge_max_samples: Optional[int] = None,
) -> Dict[str, float]:
    """
    Compute all relevant metrics for a given task.

    Args:
        hypotheses: model-generated texts
        references: ground truth texts
        task: "findings" | "impression" | "report" | "vqa"
        chexbert_path: CheXbert weights for clinical F1 (report tasks only;
            the clinical metrics come back as 0.0 when omitted)
        device: device string forwarded to BERTScore (and CheXbert)
        questions: VQA questions (passed to the LLM judge for context)
        llm_judge: VQA only — if True, also score answers with an LLM judge
            via an OpenAI-compatible endpoint (OpenAI, Gemini, local Ollama,
            ...). The API key is read from the OPENAI_API_KEY env var.
        llm_judge_model: judge model name
        llm_judge_base_url: endpoint override for non-OpenAI providers
        llm_judge_max_samples: cap on how many samples the judge scores

    Returns:
        Dict of metric_name → score

    Raises:
        ValueError: if `task` is not one of the supported task names.
    """
    # "report" is the merged-mode task (full Findings + Impression in one
    # target). Same NLG/clinical metrics apply as for findings/impression.
    report_tasks = ("findings", "impression", "report")
    if task not in report_tasks and task != "vqa":
        # Fail loudly rather than returning an empty dict that looks like a
        # successful run with no metrics.
        raise ValueError(
            f"Unknown task {task!r}; expected one of "
            "'findings', 'impression', 'report', 'vqa'."
        )

    results = {}
    if task in report_tasks:
        results.update(compute_bleu(hypotheses, references))
        results.update(compute_rouge(hypotheses, references))
        results.update(compute_meteor(hypotheses, references))
        results.update(compute_bertscore(hypotheses, references, device=device))
        results.update(compute_clinical_f1(
            hypotheses, references, chexbert_path, device
        ))
    else:  # task == "vqa"
        # Lexical
        results.update(compute_vqa_accuracy(hypotheses, references))
        results.update(compute_bleu(hypotheses, references))
        results.update(compute_meteor(hypotheses, references))
        # Semantic
        results.update(compute_bertscore(hypotheses, references, device=device))
        if llm_judge:
            results.update(compute_llm_judge(
                hypotheses, references,
                questions=questions,
                model=llm_judge_model,
                base_url=llm_judge_base_url,
                max_samples=llm_judge_max_samples,
            ))
    return results
def print_results(results: Dict[str, float], task: str):
    """Pretty-print evaluation results as an aligned two-column table.

    Real-valued metrics are shown with 4 decimals. Integer entries (e.g. a
    sample count such as "llm_judge_n") are shown as plain integers, and any
    value that cannot be formatted as a float (e.g. None) falls back to
    str() instead of raising.
    """
    import numbers

    rule = "=" * 50
    print(f"\n{rule}")
    print(f"Evaluation Results — Task: {task.upper()}")
    print(rule)
    for metric, value in results.items():
        # bool is an Integral subclass; keep it on the float path as before.
        if isinstance(value, numbers.Integral) and not isinstance(value, bool):
            shown = str(value)
        else:
            try:
                shown = f"{value:.4f}"
            except (TypeError, ValueError):
                shown = str(value)
        print(f" {metric:<25} {shown}")
    print(f"{rule}\n")