| """ |
| evaluate.py |
| ----------- |
| Evaluation entry point. Runs inference on the chosen split and computes |
| all metrics per task (findings, impression, VQA). |
| |
| The dataset is selected via `train_cfg.data.dataset_name`: |
| - "MIMIC-CXR" → evaluates findings, impression, VQA |
| - "IU-Xray" → evaluates findings, impression only |
| |
| Results are saved under: |
| {output_dir}/{dataset_name}_run_{N}/predictions_{task}.json |
| {output_dir}/{dataset_name}_run_{N}/metrics_summary.json |
| |
| Usage (local checkpoint): |
| python -m evaluation.evaluate \ |
| --model_config configs/model_config.yaml \ |
| --train_config configs/train_config.yaml \ |
| --checkpoint checkpoints/IU-Xray_run_1/stage2_instruct/stage2_final.pt \ |
| --task all \ |
| --output_dir results/ |
| |
| Usage (pull best/ from HF Hub first): |
| huggingface-cli download <user>/cxr-vlm-runs \ |
| IU-Xray_run_1/stage2/best --local-dir ./hf_pulled |
| python -m evaluation.evaluate \ |
| --checkpoint ./hf_pulled/IU-Xray_run_1/stage2/best/checkpoint_projection.pt \ |
| --task all --output_dir results/ |
| |
| The `--checkpoint` arg may point at any `<dir>/<name>_projection.pt`; the loader |
| also picks up `<dir>/<name>_lora/` and `<dir>/<name>_chexpert_classifier.pt` |
| from the same folder. |
| """ |
|
|
| import os |
| import sys |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
| import utils._quiet |
|
|
| import json |
| import argparse |
| from typing import List, Dict, Optional |
|
|
| import torch |
| from torch.utils.data import DataLoader |
| from omegaconf import OmegaConf |
| from tqdm.auto import tqdm |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| from model import CXRVisionLanguageModel |
| from model.rad_dino import BioViLTEncoder |
| from data import CXRInstructDataset, CXRDataCollator |
| from data.prompt_templates import ( |
| build_findings_prompt, |
| build_impression_prompt, |
| build_report_prompt, |
| build_vqa_prompt, |
| ) |
| from data.dataset import parse_generated_report |
| from evaluation.metrics import evaluate_all, print_results |
| from utils.logger import setup_logger |
| from utils.checkpoint import load_checkpoint |
| from utils.hf_uploader import build_tracker_from_cfg |
| from utils.dataset_resolver import resolve_dataset_spec, resolve_run_id |
|
|
|
|
def _positive_int(value):
    """argparse ``type=`` converter that only accepts integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def parse_args(argv=None):
    """Parse the evaluation command line.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave exactly as before. Passing a list allows the
            parser to be driven programmatically.

    Returns:
        argparse.Namespace with the attributes ``model_config``,
        ``train_config``, ``checkpoint`` (required), ``task``, ``split``,
        ``output_dir``, ``chexbert_path``, ``batch_size``, ``max_new_tokens``,
        ``device``, ``run_id``, ``no_hf_upload``, ``llm_judge``,
        ``llm_judge_model``, ``llm_judge_base_url`` and
        ``llm_judge_max_samples``.

    Raises:
        SystemExit: raised by argparse on invalid input, including a
            ``--batch_size`` smaller than 1.
    """
    parser = argparse.ArgumentParser(description="Evaluate CXR VLM")

    # Configs, checkpoint and what to evaluate.
    parser.add_argument("--model_config", type=str, default="configs/model_config.yaml")
    parser.add_argument("--train_config", type=str, default="configs/train_config.yaml")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--task", type=str, default="all",
                        choices=["findings", "impression", "report", "vqa", "all"])
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--output_dir", type=str, default="results/",
                        help="Root dir; results land under {output_dir}/{run_id}/")
    parser.add_argument("--chexbert_path", type=str, default=None,
                        help="Path to CheXbert weights for ClinicalF1")

    # Generation settings. batch_size is used as a range() step downstream,
    # so zero/negative values are rejected here instead of failing (or
    # silently producing no predictions) after the model has been loaded.
    parser.add_argument("--batch_size", type=_positive_int, default=8)
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--device", type=str, default="cuda")

    # Run bookkeeping / HuggingFace Hub.
    parser.add_argument("--run_id", type=str, default=None,
                        help="Explicit run id (e.g. 'IU-Xray_run_3'). "
                             "If unset, resolved from state file.")
    parser.add_argument("--no_hf_upload", action="store_true",
                        help="Disable HuggingFace Hub upload of predictions/metrics.")

    # Optional LLM-as-judge scoring.
    parser.add_argument("--llm_judge", action="store_true",
                        help="Enable LLM-as-judge semantic scoring for VQA. "
                             "Requires OPENAI_API_KEY (or compatible).")
    parser.add_argument("--llm_judge_model", type=str, default="gpt-4o-mini",
                        help="Judge model name. Default: gpt-4o-mini "
                             "(~$0.30 / 2k VQA samples).")
    parser.add_argument("--llm_judge_base_url", type=str, default=None,
                        help="Override base URL for non-OpenAI providers "
                             "(e.g. Gemini OpenAI-compat endpoint).")
    parser.add_argument("--llm_judge_max_samples", type=int, default=None,
                        help="Cap number of samples sent to the judge (cost control).")
    return parser.parse_args(argv)
|
|
|
|
@torch.no_grad()
def run_inference(
    model,
    dataset: "CXRInstructDataset",
    task: str,
    batch_size: int,
    max_new_tokens: int,
    device: str,
) -> Dict[str, List[str]]:
    """
    Run inference on a dataset split for a specific task.

    Only the entries of ``dataset.samples`` whose ``"task"`` equals ``task``
    are used. They are processed in chunks of ``batch_size``; for each chunk
    the images are loaded through the dataset's own loaders, one prompt per
    sample is built, and ``model.generate`` is called once.

    Args:
        model: Object exposing ``generate(images=, prompts=, max_new_tokens=)``
            that yields one generated string per prompt.
        dataset: Provides ``samples`` plus the ``_load_image`` /
            ``_load_image_stack`` loaders.
        task: "findings", "impression" or "report"; any other value is
            handled as VQA and requires a ``"question"`` key on each sample.
        batch_size: Number of samples per ``generate`` call; must be >= 1.
        max_new_tokens: Forwarded unchanged to ``model.generate``.
        device: Device the stacked image tensor is moved to.

    Returns:
        {"hypotheses": [...], "references": [...], "questions": [...]}
        ``questions`` is only populated when ``task == "vqa"``; all three
        lists are empty when the split has no sample for ``task``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        RuntimeError: if ``model.generate`` returns a different number of
            outputs than prompts were given.
    """
    # A negative step makes range() empty, which would look like "no samples
    # for this task" to the caller; a zero step only fails once the loop is
    # reached. Reject both immediately.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    task_samples = [s for s in dataset.samples if s["task"] == task]
    if not task_samples:
        return {"hypotheses": [], "references": [], "questions": []}

    hypotheses, references, questions = [], [], []

    for start in tqdm(range(0, len(task_samples), batch_size),
                      desc=f"Evaluating {task}"):
        batch_samples = task_samples[start:start + batch_size]

        images, prompts = [], []
        for s in batch_samples:
            # Samples with a non-empty "image_paths" go through the stack
            # loader; everything else falls back to the single "image_path".
            if s.get("image_paths"):
                img = dataset._load_image_stack(s["image_paths"])
            else:
                img = dataset._load_image(s["image_path"])
            images.append(img)

            # randomize=False is passed to the three report-style builders,
            # as in the original evaluation code.
            sf = s.get("structured_findings")
            if task == "findings":
                prompt = build_findings_prompt(sf, randomize=False)
            elif task == "impression":
                prompt = build_impression_prompt(sf, randomize=False)
            elif task == "report":
                prompt = build_report_prompt(sf, randomize=False)
            else:
                prompt = build_vqa_prompt(s["question"], sf)
            prompts.append(prompt)

        # torch.stack requires every loaded image tensor in the batch to
        # have the same shape.
        images_tensor = torch.stack(images).to(device)

        generated = list(model.generate(
            images=images_tensor,
            prompts=prompts,
            max_new_tokens=max_new_tokens,
        ))

        # Hypotheses and references are paired purely by position later on,
        # so a short or long batch would silently shift every following pair.
        if len(generated) != len(batch_samples):
            raise RuntimeError(
                f"model.generate returned {len(generated)} outputs for "
                f"{len(batch_samples)} prompts (task={task!r})"
            )

        hypotheses.extend(generated)
        references.extend(s["target"] for s in batch_samples)
        if task == "vqa":
            questions.extend(s.get("question", "") for s in batch_samples)

    return {"hypotheses": hypotheses, "references": references, "questions": questions}
|
|
|
|
def save_predictions(predictions: Dict, task: str, output_dir: str):
    """Write one task's predictions to ``{output_dir}/predictions_{task}.json``.

    The file holds a JSON list with one record per (hypothesis, reference)
    pair, in order. A record also carries a "question" entry when
    ``predictions["questions"]`` is non-empty and has an element at the same
    index. ``output_dir`` is created (including parents) if it is missing.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    output_path = target_dir / f"predictions_{task}.json"

    # Pairing is positional; zip stops at the shorter of the two lists.
    paired = zip(predictions["hypotheses"], predictions["references"])
    asked = predictions.get("questions")
    n_asked = len(asked) if asked else 0

    records = []
    for idx, (hyp, ref) in enumerate(paired):
        entry = {"hypothesis": hyp, "reference": ref}
        if idx < n_asked:
            entry["question"] = asked[idx]
        records.append(entry)

    with output_path.open("w") as handle:
        json.dump(records, handle, indent=2)

    print(f"Predictions saved to {output_path}")
|
|
|
|
def _drop_empty_references(hyps, refs):
    """Keep only the (hypothesis, reference) pairs whose reference is non-blank."""
    kept = [(h, r) for h, r in zip(hyps, refs) if r.strip()]
    return [h for h, _ in kept], [r for _, r in kept]


def _evaluate_report_sections(predictions, chexbert_path, device):
    """Score the findings / impression sections parsed out of full reports.

    Each hypothesis and reference is run through ``parse_generated_report``;
    pairs whose reference section is blank are dropped before scoring.
    Returns a dict that may contain "report__findings_only" and
    "report__impression_only" (a key is absent when no pair survived).
    """
    hyp_f, hyp_i, ref_f, ref_i = [], [], [], []
    for h, r in zip(predictions["hypotheses"], predictions["references"]):
        hp = parse_generated_report(h)
        rp = parse_generated_report(r)
        hyp_f.append(hp["findings"])
        ref_f.append(rp["findings"])
        hyp_i.append(hp["impression"])
        ref_i.append(rp["impression"])

    f_h, f_r = _drop_empty_references(hyp_f, ref_f)
    i_h, i_r = _drop_empty_references(hyp_i, ref_i)

    section_results = {}
    if f_h:
        m_f = evaluate_all(f_h, f_r, task="findings",
                           chexbert_path=chexbert_path, device=device)
        print_results(m_f, "report→findings")
        section_results["report__findings_only"] = m_f
    if i_h:
        m_i = evaluate_all(i_h, i_r, task="impression",
                           chexbert_path=chexbert_path, device=device)
        print_results(m_i, "report→impression")
        section_results["report__impression_only"] = m_i
    return section_results


def main():
    """Run the evaluation pipeline end to end.

    Steps, in order: parse the CLI and load both configs; resolve the dataset
    spec and the list of tasks to evaluate; resolve the run id and create
    ``{output_dir}/{run_id}``; optionally build the HF Hub tracker; load the
    model checkpoint; build the dataset for the requested split; run
    inference, save predictions and compute metrics per task (plus
    per-section sub-metrics for the "report" task); write
    ``metrics_summary.json``; and, when a tracker exists, upload the JSON
    results and record evaluation metadata.

    If a single ``--task`` is requested that the dataset does not offer, a
    warning is logged and the function returns immediately. Nothing is
    loaded, written or uploaded in that case, so an existing
    ``metrics_summary.json`` and the run's Hub metadata are left untouched.
    """
    args = parse_args()
    logger = setup_logger("cxr_vlm_eval")

    model_cfg = OmegaConf.load(args.model_config)
    train_cfg = OmegaConf.load(args.train_config)

    spec = resolve_dataset_spec(train_cfg)
    logger.info(f"Dataset: {spec.dataset_name}")

    # Decide what to evaluate before any expensive or destructive step, so an
    # unavailable task neither loads the model nor overwrites earlier results
    # with an empty summary flagged as "eval_done".
    if args.task == "all":
        tasks_to_eval = list(spec.tasks)
    elif args.task in spec.tasks:
        tasks_to_eval = [args.task]
    else:
        logger.warning(
            f"Task '{args.task}' not available for {spec.dataset_name} "
            f"(has: {spec.tasks}). Skipping."
        )
        return

    # Run-id resolution; Hub credentials are only looked up when the Hub
    # integration is enabled in the training config.
    output_root = str(train_cfg.training.get("output_root", "checkpoints"))
    state_file = str(train_cfg.hf_hub.run_state_file)
    if train_cfg.hf_hub.enabled:
        hf_token = os.environ.get(
            train_cfg.hf_hub.token_env, os.environ.get("HF_TOKEN")
        )
        hf_repo_id = train_cfg.hf_hub.repo_id
    else:
        hf_token = None
        hf_repo_id = None

    run_id = resolve_run_id(
        dataset_name=spec.dataset_name,
        output_root=output_root,
        state_file=state_file,
        resuming=True,
        explicit=args.run_id,
        hf_repo_id=hf_repo_id,
        hf_token=hf_token,
    )
    logger.info(f"run_id = {run_id}")

    results_dir = Path(args.output_dir) / run_id
    results_dir.mkdir(parents=True, exist_ok=True)

    tracker = None
    if not args.no_hf_upload:
        tracker = build_tracker_from_cfg(
            train_cfg,
            resuming=True,
            explicit_run_id=run_id,
        )

    logger.info(f"Loading model from checkpoint: {args.checkpoint}")
    model = CXRVisionLanguageModel(model_cfg)
    load_checkpoint(model, args.checkpoint)
    model = model.to(args.device)
    model.eval()

    # task="mixed" keeps every task's samples in one dataset; run_inference
    # filters them per task.
    dataset = CXRInstructDataset(
        data_path=spec.instruct_json,
        image_root=spec.image_root,
        tokenizer=model.tokenizer,
        transform=BioViLTEncoder.get_transform("val"),
        task="mixed",
        split=args.split,
        cutoff_len=train_cfg.training.cutoff_len,
        task_weights=spec.task_weights,
        max_images=spec.max_images,
        feature_cache_dir=getattr(train_cfg.data, "feature_cache_dir", None) or None,
    )

    all_results = {}

    for task in tasks_to_eval:
        logger.info(f"\nEvaluating task: {task.upper()}")

        predictions = run_inference(
            model=model,
            dataset=dataset,
            task=task,
            batch_size=args.batch_size,
            max_new_tokens=args.max_new_tokens,
            device=args.device,
        )

        if not predictions["hypotheses"]:
            logger.warning(f"No samples found for task: {task}")
            continue

        save_predictions(predictions, task, str(results_dir))

        metrics = evaluate_all(
            hypotheses=predictions["hypotheses"],
            references=predictions["references"],
            task=task,
            chexbert_path=args.chexbert_path,
            device=args.device,
            questions=predictions.get("questions"),
            # The judge is only ever enabled for the VQA task.
            llm_judge=args.llm_judge and task == "vqa",
            llm_judge_model=args.llm_judge_model,
            llm_judge_base_url=args.llm_judge_base_url,
            llm_judge_max_samples=args.llm_judge_max_samples,
        )

        print_results(metrics, task)
        all_results[task] = metrics

        # Full reports are additionally scored section by section.
        if task == "report":
            logger.info("\n[report] Computing per-section sub-metrics (parsed)…")
            all_results.update(
                _evaluate_report_sections(
                    predictions, args.chexbert_path, args.device
                )
            )

    summary_path = results_dir / "metrics_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(
            {"dataset_name": spec.dataset_name, "run_id": run_id,
             "split": args.split, "metrics": all_results},
            f, indent=2,
        )
    logger.info(f"\nMetrics summary saved to {summary_path}")

    if tracker is not None:
        tracker.upload_folder(
            str(results_dir),
            "results",
            allow_patterns=["*.json"],
        )
        tracker.write_meta({
            "dataset_name": spec.dataset_name,
            "eval_done": True,
            "eval_split": args.split,
            "eval_tasks": tasks_to_eval,
            "eval_checkpoint": args.checkpoint,
        })
        logger.info(f"Results uploaded to HF Hub → {tracker.repo_id} / {run_id}/results")
|
|
|
|
# Script entry point: run the evaluation only when this module is executed
# directly (e.g. `python -m evaluation.evaluate`), not when it is imported.
if __name__ == "__main__":
    main()
|
|