# cxr-vlm-code / evaluation / evaluate.py
# convitom
# f
# 320063f
"""
evaluate.py
-----------
Evaluation entry point. Runs inference on the chosen split and computes
all metrics per task (findings, impression, report, VQA).
The dataset is selected via `train_cfg.data.dataset_name`:
- "MIMIC-CXR" → evaluates findings, impression, VQA
- "IU-Xray" → evaluates findings, impression only
Results are saved under:
{output_dir}/{dataset_name}_run_{N}/predictions_{task}.json
{output_dir}/{dataset_name}_run_{N}/metrics_summary.json
Usage (local checkpoint):
python -m evaluation.evaluate \
--model_config configs/model_config.yaml \
--train_config configs/train_config.yaml \
--checkpoint checkpoints/IU-Xray_run_1/stage2_instruct/stage2_final.pt \
--task all \
--output_dir results/
Usage (pull best/ from HF Hub first):
huggingface-cli download <user>/cxr-vlm-runs \
IU-Xray_run_1/stage2/best --local-dir ./hf_pulled
python -m evaluation.evaluate \
--checkpoint ./hf_pulled/IU-Xray_run_1/stage2/best/checkpoint_projection.pt \
--task all --output_dir results/
The `--checkpoint` arg may point at any `<dir>/<name>_projection.pt`; the loader
also picks up `<dir>/<name>_lora/` and `<dir>/<name>_chexpert_classifier.pt`
from the same folder.
"""
import os
import sys
from pathlib import Path
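# Put the parent of this file's directory on sys.path so `utils` is importable, then
# silence HF per-shard download tqdm spam — the `utils._quiet` import MUST come
# before transformers/peft/hf_hub imports.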
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import utils._quiet # noqa: F401
import json
import argparse
from typing import List, Dict, Optional
import torch
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
from tqdm.auto import tqdm
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from model import CXRVisionLanguageModel
from model.rad_dino import BioViLTEncoder
from data import CXRInstructDataset, CXRDataCollator
from data.prompt_templates import (
build_findings_prompt,
build_impression_prompt,
build_report_prompt,
build_vqa_prompt,
)
from data.dataset import parse_generated_report
from evaluation.metrics import evaluate_all, print_results
from utils.logger import setup_logger
from utils.checkpoint import load_checkpoint
from utils.hf_uploader import build_tracker_from_cfg
from utils.dataset_resolver import resolve_dataset_spec, resolve_run_id
def _positive_int(value: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {number}"
        )
    return number


def parse_args(argv: Optional[List[str]] = None):
    """
    Parse the command-line options of the evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a no-argument call.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: raised by argparse on invalid input, e.g. a missing
            ``--checkpoint``, an unknown ``--task``, or a non-positive
            ``--batch_size``.
    """
    parser = argparse.ArgumentParser(description="Evaluate CXR VLM")
    parser.add_argument("--model_config", type=str, default="configs/model_config.yaml")
    parser.add_argument("--train_config", type=str, default="configs/train_config.yaml")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--task", type=str, default="all",
                        choices=["findings", "impression", "report", "vqa", "all"])
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--output_dir", type=str, default="results/",
                        help="Root dir; results land under {output_dir}/{run_id}/")
    parser.add_argument("--chexbert_path", type=str, default=None,
                        help="Path to CheXbert weights for ClinicalF1")
    # The batch size is used as the step of a range() during inference, so a
    # zero or negative value is rejected here with a readable message.
    parser.add_argument("--batch_size", type=_positive_int, default=8)
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--run_id", type=str, default=None,
                        help="Explicit run id (e.g. 'IU-Xray_run_3'). "
                             "If unset, resolved from state file.")
    parser.add_argument("--no_hf_upload", action="store_true",
                        help="Disable HuggingFace Hub upload of predictions/metrics.")
    # ── LLM-as-judge (VQA only) ─────────────────────────────────────────────
    parser.add_argument("--llm_judge", action="store_true",
                        help="Enable LLM-as-judge semantic scoring for VQA. "
                             "Requires OPENAI_API_KEY (or compatible).")
    parser.add_argument("--llm_judge_model", type=str, default="gpt-4o-mini",
                        help="Judge model name. Default: gpt-4o-mini "
                             "(~$0.30 / 2k VQA samples).")
    parser.add_argument("--llm_judge_base_url", type=str, default=None,
                        help="Override base URL for non-OpenAI providers "
                             "(e.g. Gemini OpenAI-compat endpoint).")
    parser.add_argument("--llm_judge_max_samples", type=int, default=None,
                        help="Cap number of samples sent to the judge (cost control).")
    return parser.parse_args(argv)
@torch.no_grad()
def run_inference(
    model,
    dataset: CXRInstructDataset,
    task: str,
    batch_size: int,
    max_new_tokens: int,
    device: str,
) -> Dict[str, List[str]]:
    """
    Run inference on a dataset split for a specific task.

    Args:
        model: Object exposing ``generate(images=..., prompts=..., max_new_tokens=...)``.
        dataset: Source of ``samples`` plus the image loaders used below.
        task: Only samples whose ``"task"`` field equals this value are evaluated.
        batch_size: Number of samples per ``generate`` call; must be >= 1.
        max_new_tokens: Forwarded unchanged to ``model.generate``.
        device: Device the stacked image batch is moved to.

    Returns:
        {"hypotheses": [...], "references": [...], "questions": [...]}
        ``questions`` is only filled for the "vqa" task; all three lists are
        empty when the dataset has no sample for ``task``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    # A zero step makes range() raise an opaque error and a negative step
    # yields an empty range (i.e. silently no predictions), so fail loudly.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    task_samples = [s for s in dataset.samples if s["task"] == task]
    if not task_samples:
        return {"hypotheses": [], "references": [], "questions": []}

    hypotheses, references, questions = [], [], []
    for start in tqdm(range(0, len(task_samples), batch_size),
                      desc=f"Evaluating {task}"):
        batch_samples = task_samples[start:start + batch_size]
        images, prompts = [], []
        for s in batch_samples:
            # Use the same code path as training: image_paths (list) → stacked,
            # image_path (string) → single image. Keeps multi-image mode working.
            if s.get("image_paths"):
                img = dataset._load_image_stack(s["image_paths"])  # (N, C, H, W)
            else:
                img = dataset._load_image(s["image_path"])  # (C, H, W)
            images.append(img)

            sf = s.get("structured_findings")
            if task == "findings":
                prompt = build_findings_prompt(sf, randomize=False)
            elif task == "impression":
                prompt = build_impression_prompt(sf, randomize=False)
            elif task == "report":
                prompt = build_report_prompt(sf, randomize=False)
            else:  # vqa
                prompt = build_vqa_prompt(s["question"], sf)
            prompts.append(prompt)

        # torch.stack needs every per-sample tensor in the batch to share a shape.
        images_tensor = torch.stack(images).to(device)
        generated = model.generate(
            images=images_tensor,
            prompts=prompts,
            max_new_tokens=max_new_tokens,
        )
        hypotheses.extend(generated)
        references.extend([s["target"] for s in batch_samples])
        if task == "vqa":
            questions.extend([s.get("question", "") for s in batch_samples])

    return {"hypotheses": hypotheses, "references": references, "questions": questions}
def save_predictions(predictions: Dict, task: str, output_dir: str):
    """
    Write hypothesis/reference pairs to ``{output_dir}/predictions_{task}.json``.

    The output directory is created if needed. Each record holds
    ``hypothesis`` and ``reference``; a ``question`` entry is added for the
    records that have a question at the same index.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = target_dir / f"predictions_{task}.json"

    rows = []
    pairs = zip(predictions["hypotheses"], predictions["references"])
    # A missing, None or empty question list means no record gets a question.
    question_list = predictions.get("questions") or []
    for idx, (hypothesis, reference) in enumerate(pairs):
        row = {"hypothesis": hypothesis, "reference": reference}
        if idx < len(question_list):
            row["question"] = question_list[idx]
        rows.append(row)

    with destination.open("w") as handle:
        json.dump(rows, handle, indent=2)
    print(f"Predictions saved to {destination}")
def _drop_blank_references(hyps, refs):
    """Keep only the (hypothesis, reference) pairs whose reference is non-blank."""
    kept = [(h, r) for h, r in zip(hyps, refs) if r.strip()]
    return [h for h, _ in kept], [r for _, r in kept]


def _score_report_sections(hypotheses, references, chexbert_path, device, logger):
    """
    Score the findings and impression sections of merged reports separately.

    Every hypothesis and reference is split with ``parse_generated_report``;
    pairs whose reference section is blank are dropped because they cannot be
    scored. Returns a dict with the keys ``report__findings_only`` and/or
    ``report__impression_only`` (a key is absent when no pair is left).
    """
    logger.info("\n[report] Computing per-section sub-metrics (parsed)…")
    sections = {"findings": ([], []), "impression": ([], [])}
    for hyp, ref in zip(hypotheses, references):
        hyp_parts = parse_generated_report(hyp)
        ref_parts = parse_generated_report(ref)
        for section, (hyp_texts, ref_texts) in sections.items():
            hyp_texts.append(hyp_parts[section])
            ref_texts.append(ref_parts[section])

    section_results = {}
    for section, (hyp_texts, ref_texts) in sections.items():
        kept_hyps, kept_refs = _drop_blank_references(hyp_texts, ref_texts)
        if not kept_hyps:
            continue
        metrics = evaluate_all(kept_hyps, kept_refs, task=section,
                               chexbert_path=chexbert_path, device=device)
        print_results(metrics, f"report→{section}")
        section_results[f"report__{section}_only"] = metrics
    return section_results


def main():
    """
    Evaluate a checkpoint on the configured dataset and persist the results.

    Steps: parse the CLI, resolve the dataset spec and run id, load the model
    and the requested split, run inference + metrics for every selected task,
    write ``predictions_{task}.json`` and ``metrics_summary.json`` under
    ``{output_dir}/{run_id}/``, and upload that folder through the HF Hub
    tracker unless uploading is disabled.

    If a single ``--task`` is requested that the dataset does not provide, the
    function logs a warning and returns before loading anything, so an existing
    ``metrics_summary.json`` is not overwritten with an empty result.
    """
    args = parse_args()
    logger = setup_logger("cxr_vlm_eval")
    model_cfg = OmegaConf.load(args.model_config)
    train_cfg = OmegaConf.load(args.train_config)

    # ── Resolve dataset + run_id ─────────────────────────────────────
    spec = resolve_dataset_spec(train_cfg)
    logger.info(f"Dataset: {spec.dataset_name}")

    # Build task list, intersected with what's available for this dataset.
    # Done up front so an unavailable task exits before the model is loaded
    # and before any summary file or Hub metadata is written.
    if args.task == "all":
        tasks_to_eval = list(spec.tasks)
    elif args.task in spec.tasks:
        tasks_to_eval = [args.task]
    else:
        logger.warning(
            f"Task '{args.task}' not available for {spec.dataset_name} "
            f"(has: {spec.tasks}). Skipping."
        )
        return

    output_root = str(train_cfg.training.get("output_root", "checkpoints"))
    state_file = str(train_cfg.hf_hub.run_state_file)
    hf_token = os.environ.get(
        train_cfg.hf_hub.token_env, os.environ.get("HF_TOKEN")
    ) if train_cfg.hf_hub.enabled else None
    hf_repo_id = train_cfg.hf_hub.repo_id if train_cfg.hf_hub.enabled else None

    # Evaluation always resumes an existing run
    run_id = resolve_run_id(
        dataset_name=spec.dataset_name,
        output_root=output_root,
        state_file=state_file,
        resuming=True,
        explicit=args.run_id,
        hf_repo_id=hf_repo_id,
        hf_token=hf_token,
    )
    logger.info(f"run_id = {run_id}")

    # Results go under {output_dir}/{run_id}/
    results_dir = Path(args.output_dir) / run_id
    results_dir.mkdir(parents=True, exist_ok=True)

    # HF Hub tracker (stays None when --no_hf_upload is given)
    tracker = None
    if not args.no_hf_upload:
        tracker = build_tracker_from_cfg(
            train_cfg,
            resuming=True,
            explicit_run_id=run_id,
        )

    # Build and load model
    logger.info(f"Loading model from checkpoint: {args.checkpoint}")
    model = CXRVisionLanguageModel(model_cfg)
    load_checkpoint(model, args.checkpoint)
    model = model.to(args.device)
    model.eval()

    # Load test dataset (for the chosen dataset)
    dataset = CXRInstructDataset(
        data_path=spec.instruct_json,
        image_root=spec.image_root,
        tokenizer=model.tokenizer,
        transform=BioViLTEncoder.get_transform("val"),
        task="mixed",
        split=args.split,
        cutoff_len=train_cfg.training.cutoff_len,
        task_weights=spec.task_weights,
        max_images=spec.max_images,
        feature_cache_dir=getattr(train_cfg.data, "feature_cache_dir", None) or None,
    )

    all_results = {}
    for task in tasks_to_eval:
        logger.info(f"\nEvaluating task: {task.upper()}")
        predictions = run_inference(
            model=model,
            dataset=dataset,
            task=task,
            batch_size=args.batch_size,
            max_new_tokens=args.max_new_tokens,
            device=args.device,
        )
        if not predictions["hypotheses"]:
            logger.warning(f"No samples found for task: {task}")
            continue

        save_predictions(predictions, task, str(results_dir))
        metrics = evaluate_all(
            hypotheses=predictions["hypotheses"],
            references=predictions["references"],
            task=task,
            chexbert_path=args.chexbert_path,
            device=args.device,
            questions=predictions.get("questions"),
            # The LLM judge is only ever enabled for VQA.
            llm_judge=args.llm_judge and task == "vqa",
            llm_judge_model=args.llm_judge_model,
            llm_judge_base_url=args.llm_judge_base_url,
            llm_judge_max_samples=args.llm_judge_max_samples,
        )
        print_results(metrics, task)
        all_results[task] = metrics

        # ── If task is "report" (merged mode), also report per-section
        # metrics by parsing the generated and reference reports back into
        # findings / impression. This gives an apples-to-apples comparison
        # against a previous split-mode run that reports those numbers.
        if task == "report":
            all_results.update(
                _score_report_sections(
                    predictions["hypotheses"],
                    predictions["references"],
                    args.chexbert_path,
                    args.device,
                    logger,
                )
            )

    # Save all metrics summary
    summary_path = results_dir / "metrics_summary.json"
    with open(summary_path, "w") as f:
        json.dump(
            {"dataset_name": spec.dataset_name, "run_id": run_id,
             "split": args.split, "metrics": all_results},
            f, indent=2,
        )
    logger.info(f"\nMetrics summary saved to {summary_path}")

    # ── HF Hub upload: results folder ────────────────────────────────
    if tracker is not None:
        tracker.upload_folder(
            str(results_dir),
            "results",
            allow_patterns=["*.json"],
        )
        tracker.write_meta({
            "dataset_name": spec.dataset_name,
            "eval_done": True,
            "eval_split": args.split,
            "eval_tasks": tasks_to_eval,
            "eval_checkpoint": args.checkpoint,
        })
        logger.info(f"Results uploaded to HF Hub → {tracker.repo_id} / {run_id}/results")
# Script entry point: run the evaluation only when this file is executed
# directly (e.g. `python -m evaluation.evaluate`), not when it is imported.
if __name__ == "__main__":
    main()