convitom Claude Opus 4.7 committed on
Commit
8f6cf28
·
1 Parent(s): 2cd8b5a

feat(eval): add METEOR + optional LLM-as-judge for VQA scoring

Browse files

METEOR (WordNet synonym/stem-aware) supplements BLEU/ROUGE for radiology
phrasing variance. LLM-as-judge is opt-in via --llm_judge (defaults to
gpt-4o-mini, OpenAI-compatible base_url overridable for Gemini/Claude),
with --llm_judge_max_samples for cost control.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Files changed (3) hide show
  1. evaluation/evaluate.py +17 -0
  2. evaluation/metrics.py +185 -1
  3. requirements.txt +1 -0
evaluation/evaluate.py CHANGED
@@ -89,6 +89,18 @@ def parse_args():
89
  "If unset, resolved from state file.")
90
  parser.add_argument("--no_hf_upload", action="store_true",
91
  help="Disable HuggingFace Hub upload of predictions/metrics.")
 
 
 
 
 
 
 
 
 
 
 
 
92
  return parser.parse_args()
93
 
94
 
@@ -276,6 +288,11 @@ def main():
276
  task = task,
277
  chexbert_path = args.chexbert_path,
278
  device = args.device,
 
 
 
 
 
279
  )
280
 
281
  print_results(metrics, task)
 
89
  "If unset, resolved from state file.")
90
  parser.add_argument("--no_hf_upload", action="store_true",
91
  help="Disable HuggingFace Hub upload of predictions/metrics.")
92
+ # ── LLM-as-judge (VQA only) ─────────────────────────────────────────────
93
+ parser.add_argument("--llm_judge", action="store_true",
94
+ help="Enable LLM-as-judge semantic scoring for VQA. "
95
+ "Requires OPENAI_API_KEY (or compatible).")
96
+ parser.add_argument("--llm_judge_model", type=str, default="gpt-4o-mini",
97
+ help="Judge model name. Default: gpt-4o-mini "
98
+ "(~$0.30 / 2k VQA samples).")
99
+ parser.add_argument("--llm_judge_base_url", type=str, default=None,
100
+ help="Override base URL for non-OpenAI providers "
101
+ "(e.g. Gemini OpenAI-compat endpoint).")
102
+ parser.add_argument("--llm_judge_max_samples", type=int, default=None,
103
+ help="Cap number of samples sent to the judge (cost control).")
104
  return parser.parse_args()
105
 
106
 
 
288
  task = task,
289
  chexbert_path = args.chexbert_path,
290
  device = args.device,
291
+ questions = predictions.get("questions"),
292
+ llm_judge = args.llm_judge and task == "vqa",
293
+ llm_judge_model = args.llm_judge_model,
294
+ llm_judge_base_url = args.llm_judge_base_url,
295
+ llm_judge_max_samples = args.llm_judge_max_samples,
296
  )
297
 
298
  print_results(metrics, task)
evaluation/metrics.py CHANGED
@@ -14,17 +14,39 @@ Evaluation metrics for 3 tasks:
14
  - Accuracy (exact match)
15
  - Token-level F1
16
  - BLEU-1 (for open-ended answers)
 
 
 
17
  """
18
 
 
19
  import re
 
 
20
  from typing import List, Dict, Optional, Tuple
21
 
22
  import torch
23
  import numpy as np
24
  from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
 
25
  from rouge_score import rouge_scorer
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # ─── NLG Metrics ─────────────────────────────────────────────────────────────
29
 
30
  def compute_bleu(
@@ -82,6 +104,39 @@ def compute_rouge(
82
  }
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def compute_bertscore(
86
  hypotheses: List[str],
87
  references: List[str],
@@ -228,6 +283,114 @@ def _token_f1(prediction: str, ground_truth: str) -> float:
228
  return f1
229
 
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  # ─── Master Evaluation Function ──────────────────────────────────────────────
232
 
233
  def evaluate_all(
@@ -236,6 +399,11 @@ def evaluate_all(
236
  task: str,
237
  chexbert_path: Optional[str] = None,
238
  device: str = "cpu",
 
 
 
 
 
239
  ) -> Dict[str, float]:
240
  """
241
  Compute all relevant metrics for a given task.
@@ -243,8 +411,11 @@ def evaluate_all(
243
  Args:
244
  hypotheses: model-generated texts
245
  references: ground truth texts
246
- task: "findings" | "impression" | "vqa"
247
  chexbert_path: for clinical F1 (optional)
 
 
 
248
 
249
  Returns:
250
  Dict of metric_name → score
@@ -256,14 +427,27 @@ def evaluate_all(
256
  if task in ("findings", "impression", "report"):
257
  results.update(compute_bleu(hypotheses, references))
258
  results.update(compute_rouge(hypotheses, references))
 
259
  results.update(compute_bertscore(hypotheses, references, device=device))
260
  results.update(compute_clinical_f1(
261
  hypotheses, references, chexbert_path, device
262
  ))
263
 
264
  elif task == "vqa":
 
265
  results.update(compute_vqa_accuracy(hypotheses, references))
266
  results.update(compute_bleu(hypotheses, references))
 
 
 
 
 
 
 
 
 
 
 
267
 
268
  return results
269
 
 
14
  - Accuracy (exact match)
15
  - Token-level F1
16
  - BLEU-1 (for open-ended answers)
17
+ - METEOR (synonym + stem aware)
18
+ - BERTScore (semantic similarity)
19
+ - LLM-as-Judge (optional, GPT/Claude/Gemini for clinical semantic eval)
20
  """
21
 
22
+ import os
23
  import re
24
+ import json
25
+ import time
26
  from typing import List, Dict, Optional, Tuple
27
 
28
  import torch
29
  import numpy as np
30
  from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
31
+ from nltk.translate.meteor_score import meteor_score as nltk_meteor
32
  from rouge_score import rouge_scorer
33
 
34
 
35
# Ensure the NLTK data packages used for METEOR scoring are available
# (wordnet, omw-1.4, punkt).
# Safe to call repeatedly: each resource is looked up locally first and is
# only downloaded when that lookup fails.
def _ensure_nltk_data():
    """Download any NLTK resource from the required list that is missing.

    Each resource is first looked up with ``nltk.data.find``; only when that
    raises ``LookupError`` is ``nltk.download`` attempted. A download that
    reports failure is surfaced as a printed warning instead of being
    ignored, so a later ``LookupError`` from NLTK is easier to trace back.

    Returns:
        None
    """
    import nltk

    required = (
        ("wordnet", "corpora/wordnet"),
        ("omw-1.4", "corpora/omw-1.4"),
        ("punkt", "tokenizers/punkt"),
    )
    for pkg, path in required:
        try:
            nltk.data.find(path)
        except LookupError:
            # nltk.download() signals failure through its boolean return
            # value rather than by raising, so check it explicitly.
            if not nltk.download(pkg, quiet=True):
                print(f"[WARNING] Could not download NLTK resource '{pkg}'.")
48
+
49
+
50
  # ─── NLG Metrics ─────────────────────────────────────────────────────────────
51
 
52
  def compute_bleu(
 
104
  }
105
 
106
 
107
def compute_meteor(
    hypotheses: List[str],
    references: List[str],
) -> Dict[str, float]:
    """
    Compute the mean METEOR score over (hypothesis, reference) pairs.

    Each pair is lower-cased, split on whitespace and scored on its own
    against its single reference; the reported value is the arithmetic mean
    of those per-pair scores, rounded to 4 decimals.

    Pairs are formed with ``zip``, so if the two lists differ in length the
    extra items of the longer one are not scored. A pair in which either
    side has no tokens contributes a score of 0.0.

    Args:
        hypotheses: model-generated texts
        references: ground truth texts (one per hypothesis)

    Returns:
        {"meteor": float}  (0.0 when there are no pairs)
    """
    _ensure_nltk_data()

    def _pair_score(hyp: str, ref: str) -> float:
        # Tokenise the reference first, then the hypothesis.
        ref_tokens = ref.lower().split()
        hyp_tokens = hyp.lower().split()
        if not (hyp_tokens and ref_tokens):
            return 0.0
        # The scorer expects a list of tokenised references (here just one).
        return nltk_meteor([ref_tokens], hyp_tokens)

    per_pair = [_pair_score(hyp, ref) for hyp, ref in zip(hypotheses, references)]
    mean_score = float(np.mean(per_pair)) if per_pair else 0.0
    return {"meteor": round(mean_score, 4)}
138
+
139
+
140
  def compute_bertscore(
141
  hypotheses: List[str],
142
  references: List[str],
 
283
  return f1
284
 
285
 
286
+ # ─── LLM-as-Judge (semantic correctness via GPT/Claude/Gemini) ───────────────
287
+
288
+ _LLM_JUDGE_PROMPT = """You are a clinical evaluator for chest X-ray VQA.
289
+ Judge whether the predicted answer is semantically equivalent to the ground
290
+ truth in a medical context. Be tolerant of synonyms ("cardiomegaly" =
291
+ "enlarged heart"), paraphrases, and extra/missing function words. Penalize
292
+ contradictions (e.g. negating a positive finding) or clinically wrong
293
+ content.
294
+
295
+ Question: {question}
296
+ Ground truth: {reference}
297
+ Prediction: {hypothesis}
298
+
299
+ Reply with ONLY a JSON object of the form: {{"score": <0-5 integer>, "reason": "<one short sentence>"}}
300
+ Scoring rubric:
301
+ 5 = clinically equivalent
302
+ 4 = mostly correct, minor omission
303
+ 3 = partially correct
304
+ 2 = mostly incorrect
305
+ 1 = wrong but on topic
306
+ 0 = contradicts ground truth / unrelated"""
307
+
308
+
309
def compute_llm_judge(
    hypotheses: List[str],
    references: List[str],
    questions: Optional[List[str]] = None,
    model: str = "gpt-4o-mini",
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    max_samples: Optional[int] = None,
    sleep_s: float = 0.0,
) -> Dict[str, float]:
    """
    Score (hyp, ref) pairs with an LLM judge (OpenAI-compatible API).

    Defaults to OpenAI's gpt-4o-mini (~$0.30 per 2k VQA samples).
    For free alternatives, pass:
      - Gemini   : base_url="https://generativelanguage.googleapis.com/v1beta/openai/", model="gemini-1.5-flash"
      - Local    : base_url="http://localhost:11434/v1" (Ollama), model="llama3.1"
      - Anthropic: needs separate SDK — not supported via this OpenAI-compatible path.

    Only the first ``min(len(hypotheses), len(references))`` pairs are
    considered, further capped by ``max_samples``. A missing or too-short
    ``questions`` list is tolerated: samples without a question are sent as
    "(not provided)". A sample whose request fails, or whose reply cannot be
    parsed into a score, is logged and left out of the mean.

    Args:
        hypotheses, references, questions: parallel lists
        model: judge model name
        api_key: defaults to env var OPENAI_API_KEY
        base_url: override for non-OpenAI providers
        max_samples: cap evaluation cost (e.g. 200) — useful for sanity checks;
            negative values are treated as 0
        sleep_s: delay between calls to dodge rate limits

    Returns:
        {"llm_judge_mean": float (0-5), "llm_judge_norm": float (0-1),
         "llm_judge_n": int}
        All three are 0 when the judge is skipped (openai package missing,
        no API key) or when no sample could be scored.
    """
    skipped = {"llm_judge_mean": 0.0, "llm_judge_norm": 0.0, "llm_judge_n": 0}

    try:
        from openai import OpenAI
    except ImportError:
        print("[WARNING] openai package not installed. Skipping LLM-judge.")
        return skipped

    api_key = api_key or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("[WARNING] OPENAI_API_KEY not set. Skipping LLM-judge.")
        return skipped

    # Only forward base_url when one was given so the SDK default applies.
    client = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)

    # Never index past the shorter of the two parallel lists.
    n = min(len(hypotheses), len(references))
    if max_samples is not None:
        n = max(0, min(n, max_samples))

    questions = list(questions or [])
    scores = []
    for i in range(n):
        # Pad missing questions instead of raising IndexError.
        question = questions[i] if i < len(questions) else ""
        prompt = _LLM_JUDGE_PROMPT.format(
            question=question or "(not provided)",
            reference=references[i],
            hypothesis=hypotheses[i],
        )
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=80,
                response_format={"type": "json_object"},
            )
            scores.append(_parse_judge_score(resp.choices[0].message.content))
        except Exception as e:
            # Deliberately broad: one failed request or malformed reply must
            # not abort the whole evaluation run; the sample is skipped.
            print(f"[LLM-judge] sample {i} failed: {e}")
        if sleep_s > 0:
            time.sleep(sleep_s)

    if not scores:
        return skipped

    mean = float(np.mean(scores))
    return {
        "llm_judge_mean": round(mean, 4),
        "llm_judge_norm": round(mean / 5.0, 4),  # 0..1 for easy comparison
        "llm_judge_n": len(scores),
    }


def _parse_judge_score(raw: Optional[str]) -> int:
    """Extract the integer score (clamped to 0..5) from a judge's JSON reply.

    Raises:
        ValueError: if the reply is empty, is not a JSON object, has no
            "score" field, or the score is not convertible to an integer
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        TypeError: if the "score" value has a type ``int()`` rejects.
    """
    if not raw:
        raise ValueError("empty judge reply")
    data = json.loads(raw.strip())
    # A reply without a score is malformed; counting it as 0 would silently
    # drag the mean down, so report it as a failure instead.
    if not isinstance(data, dict) or "score" not in data:
        raise ValueError(f"judge reply has no 'score' field: {raw!r}")
    return max(0, min(5, int(data["score"])))
392
+
393
+
394
  # ─── Master Evaluation Function ──────────────────────────────────────────────
395
 
396
  def evaluate_all(
 
399
  task: str,
400
  chexbert_path: Optional[str] = None,
401
  device: str = "cpu",
402
+ questions: Optional[List[str]] = None,
403
+ llm_judge: bool = False,
404
+ llm_judge_model: str = "gpt-4o-mini",
405
+ llm_judge_base_url: Optional[str] = None,
406
+ llm_judge_max_samples: Optional[int] = None,
407
  ) -> Dict[str, float]:
408
  """
409
  Compute all relevant metrics for a given task.
 
411
  Args:
412
  hypotheses: model-generated texts
413
  references: ground truth texts
414
+ task: "findings" | "impression" | "report" | "vqa"
415
  chexbert_path: for clinical F1 (optional)
416
+ questions: VQA questions (passed to LLM judge for context)
417
+ llm_judge: if True, also run GPT/Claude/Gemini as a semantic judge
418
+ (requires OPENAI_API_KEY or compatible endpoint)
419
 
420
  Returns:
421
  Dict of metric_name → score
 
427
  if task in ("findings", "impression", "report"):
428
  results.update(compute_bleu(hypotheses, references))
429
  results.update(compute_rouge(hypotheses, references))
430
+ results.update(compute_meteor(hypotheses, references))
431
  results.update(compute_bertscore(hypotheses, references, device=device))
432
  results.update(compute_clinical_f1(
433
  hypotheses, references, chexbert_path, device
434
  ))
435
 
436
  elif task == "vqa":
437
+ # Lexical
438
  results.update(compute_vqa_accuracy(hypotheses, references))
439
  results.update(compute_bleu(hypotheses, references))
440
+ results.update(compute_meteor(hypotheses, references))
441
+ # Semantic
442
+ results.update(compute_bertscore(hypotheses, references, device=device))
443
+ if llm_judge:
444
+ results.update(compute_llm_judge(
445
+ hypotheses, references,
446
+ questions = questions,
447
+ model = llm_judge_model,
448
+ base_url = llm_judge_base_url,
449
+ max_samples = llm_judge_max_samples,
450
+ ))
451
 
452
  return results
453
 
requirements.txt CHANGED
@@ -12,6 +12,7 @@ wandb==0.16.0
12
  rouge-score==0.1.2
13
  nltk==3.8.1
14
  bert-score==0.3.13
 
15
  scikit-learn==1.3.2
16
  pandas==2.1.0
17
  numpy==1.24.0
 
12
  rouge-score==0.1.2
13
  nltk==3.8.1
14
  bert-score==0.3.13
15
+ openai>=1.30.0 # optional: LLM-as-judge for VQA (also works with Gemini/Ollama via base_url)
16
  scikit-learn==1.3.2
17
  pandas==2.1.0
18
  numpy==1.24.0