cxr-vlm-code / evaluation /inference.py
convitom
fix: complete image_encoder→rad_dino rename (repair broken HEAD)
93ddbd3
"""
inference.py
------------
Run a single CXR image through the trained CXR-VLM and print the model's
output. Supports findings / impression / VQA tasks, and runs on either
GPU or CPU.
Examples
--------
# 1) Findings (default — no prompt needed)
python -m evaluation.inference \
--image /path/to/cxr.jpg \
--checkpoint checkpoints/IU-Xray_run_1/stage2_instruct/stage2_final.pt
# 2) Impression
python -m evaluation.inference \
--image /path/to/cxr.jpg \
--task impression \
--checkpoint .../stage2_final.pt
# 3) VQA — pass the clinical question via --prompt
python -m evaluation.inference \
--image /path/to/cxr.jpg \
--task vqa \
--prompt "Is there a pleural effusion?" \
--checkpoint .../stage2_final.pt
# 4) Free-form prompt (auto-treats as VQA-style)
python -m evaluation.inference \
--image /path/to/cxr.jpg \
--prompt "Describe abnormalities in the left lung." \
--checkpoint .../stage2_final.pt
# 5) Force CPU (slow — Vicuna-7B in fp32 on CPU is multi-second per token,
# needs ~28 GB RAM. 4-bit quantization is auto-disabled because bnb
# requires CUDA.)
python -m evaluation.inference --image cxr.jpg --device cpu --checkpoint ...
Device options
--------------
--device auto    (default) cuda if a GPU is visible, else cpu
--device cuda    explicit GPU
--device cpu     CPU only — see CPU notes above
Checkpoint path
---------------
Pass the path to the projection .pt file. The loader reads
`<dir>/<name>_projection.pt`, `<dir>/<name>_lora/`, and (optional)
`<dir>/<name>_chexpert_classifier.pt` from the same folder. Both layouts
work:
.../stage2_instruct/stage2_final.pt (after train.py finishes)
.../stage2/best/checkpoint_projection.pt (pulled from HF Hub)
"""
import argparse
import sys
from pathlib import Path
# Put the parent of this file's directory on sys.path, then import utils._quiet
# to silence HF per-shard download tqdm spam — must run before transformers import.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import utils._quiet # noqa: F401
import torch
from PIL import Image
from omegaconf import OmegaConf
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from model import CXRVisionLanguageModel
from model.rad_dino import BioViLTEncoder
from data.prompt_templates import (
build_findings_prompt,
build_impression_prompt,
build_vqa_prompt,
)
from utils.checkpoint import load_checkpoint
from utils.logger import setup_logger
def parse_args():
    """Parse the command-line options for single-image inference.

    ``--image`` and ``--checkpoint`` are mandatory; every other option has a
    default. Options are registered in a fixed order so ``--help`` output is
    stable.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="CXR-VLM single-image inference")
    add = parser.add_argument

    # ── What to run on and what to ask ───────────────────────────
    add(
        "--image",
        type=str,
        required=True,
        help="Path to a CXR image (.jpg / .png).",
    )
    add(
        "--prompt",
        type=str,
        default=None,
        help=(
            "Optional question or instruction. If omitted, the "
            "default template for the chosen --task is used. "
            "If supplied with --task findings/impression, the "
            "prompt overrides the canned instruction."
        ),
    )
    add(
        "--task",
        type=str,
        default="findings",
        choices=["findings", "impression", "vqa"],
        help=(
            "Prompt template family. Default: findings. "
            "For 'vqa', --prompt is required."
        ),
    )

    # ── Where the weights and config live ────────────────────────
    add(
        "--checkpoint",
        type=str,
        required=True,
        help=(
            "Path to a stage-2 projection .pt file (the loader "
            "also picks up the matching _lora/ folder beside it)."
        ),
    )
    add("--model_config", type=str, default="configs/model_config.yaml")
    add(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cuda", "cpu"],
        help="auto = cuda if available, else cpu.",
    )

    # ── Decoding controls ────────────────────────────────────────
    add("--max_new_tokens", type=int, default=300)
    add(
        "--temperature",
        type=float,
        default=0.0,
        help="0.0 → greedy. >0 enables sampling.",
    )
    add(
        "--num_beams",
        type=int,
        default=1,
        help="Beam search width. 1 = greedy/sampling.",
    )

    # ── Optional context and CPU-only tuning ─────────────────────
    add(
        "--structured_findings",
        type=str,
        default=None,
        help=(
            'Optional "Predicted Findings: ..." string to inject '
            "as context (normally produced by the CheXpert "
            "classifier; pass it manually if you have one)."
        ),
    )
    add(
        "--cpu_dtype",
        type=str,
        default="float32",
        choices=["float32", "float16", "bfloat16"],
        help=(
            "LLM dtype on CPU. float32 is safest (~28 GB RAM); "
            "float16/bfloat16 halve RAM but are very slow on most "
            "CPUs. Ignored on GPU."
        ),
    )

    return parser.parse_args()
def resolve_device(choice: str, logger) -> str:
    """Turn the ``--device`` CLI value into a concrete device string.

    Args:
        choice: the requested device; ``"auto"`` picks CUDA when
            ``torch.cuda.is_available()`` is true and CPU otherwise.
        logger: receives an info message when ``"auto"`` is resolved.

    Returns:
        The resolved device for ``"auto"``; any other value is returned
        unchanged.

    Raises:
        SystemExit: if ``"cuda"`` is requested but no CUDA GPU is visible.
    """
    # Explicit choices first; CUDA is only probed when it actually matters.
    if choice != "auto":
        if choice == "cuda" and not torch.cuda.is_available():
            raise SystemExit(
                "--device cuda requested but no CUDA GPU is visible. "
                "Use --device cpu (slow) or --device auto."
            )
        return choice

    resolved = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"--device auto → resolved to {resolved}")
    return resolved
def patch_cfg_for_device(model_cfg, device: str, cpu_dtype: str, logger):
    """Adjust ``model_cfg.llm`` in place so the model can build on ``device``.

    On CPU, 4-bit / 8-bit loading is switched off (a warning is logged when
    either flag was set), the dtype is replaced by ``cpu_dtype`` and
    ``device_map`` is cleared. On any other device the config is left as-is.

    Args:
        model_cfg: config object exposing an ``llm`` section with attribute
            access.
        device: resolved device string; ``"cpu"`` triggers the patching.
        cpu_dtype: dtype name written to ``llm.torch_dtype`` on CPU.
        logger: receives the warning about disabled quantization.

    Raises:
        RuntimeError: if ``device`` is not ``"cpu"`` and CUDA is unavailable.
    """
    if device != "cpu":
        # GPU — keep whatever the config already says (4-bit on T4, bf16 on A100, …).
        if not torch.cuda.is_available():
            raise RuntimeError("Resolved device=cuda but torch.cuda.is_available()==False")
        return

    llm_cfg = model_cfg.llm
    quant_flags = ("load_in_4bit", "load_in_8bit")

    # bitsandbytes is CUDA-only, so tell the user when quantization gets dropped.
    if any(getattr(llm_cfg, flag, False) for flag in quant_flags):
        logger.warning(
            "CPU device: disabling 4-bit / 8-bit quantization "
            "(bitsandbytes requires CUDA). Falling back to dtype="
            f"{cpu_dtype}. Vicuna-7B in float32 needs ~28 GB RAM; "
            "in float16/bfloat16 ~14 GB but very slow on CPU."
        )

    llm_cfg.load_in_4bit = False
    llm_cfg.load_in_8bit = False
    llm_cfg.torch_dtype = cpu_dtype
    # `device_map='auto'` is unusable on CPU (accelerate would try to find a
    # GPU); None makes HF load everything on the default device (CPU).
    llm_cfg.device_map = None
def build_prompt(task: str, prompt: str | None, structured_findings: str | None) -> str:
"""Decide which prompt template to use given (task, prompt)."""
sf = structured_findings
if task == "vqa":
if not prompt:
raise SystemExit("--task vqa requires --prompt <question>")
return build_vqa_prompt(prompt, sf)
# findings / impression: if user passed a free-form prompt, treat it as
# an instruction (same template family as VQA — single instruction line).
if prompt:
return build_vqa_prompt(prompt, sf)
if task == "findings":
return build_findings_prompt(sf, randomize=False)
if task == "impression":
return build_impression_prompt(sf, randomize=False)
raise ValueError(f"Unknown task: {task}")
def main():
    """Run one CXR image through the model and print the generated text.

    Pipeline: parse CLI options, resolve the device, load and patch the model
    config, build the prompt, load and transform the image, build the model
    and load its checkpoint, generate, then print the result to stdout.

    All cheap input validation (prompt, image path) happens before the model
    is built, so bad CLI input fails immediately instead of after a long
    model load.

    Raises:
        SystemExit: if the image file does not exist, or for the conditions
            raised by ``resolve_device`` and ``build_prompt``.
    """
    args = parse_args()
    logger = setup_logger("cxr_vlm_infer")
    device = resolve_device(args.device, logger)

    # ── Load + patch model config ────────────────────────────────
    model_cfg = OmegaConf.load(args.model_config)
    patch_cfg_for_device(model_cfg, device, args.cpu_dtype, logger)

    logger.info(f"image = {args.image}")
    logger.info(f"checkpoint = {args.checkpoint}")
    logger.info(f"task = {args.task}")
    logger.info(f"prompt = {args.prompt!r}")
    logger.info(f"device = {device}")

    # ── Build the formatted prompt ───────────────────────────────
    # Built before the model so that e.g. `--task vqa` without `--prompt`
    # is rejected up front rather than after loading the checkpoint.
    full_prompt = build_prompt(args.task, args.prompt, args.structured_findings)
    logger.info(f"full prompt = {full_prompt!r}")

    # ── Build image tensor ───────────────────────────────────────
    img_path = Path(args.image)
    if not img_path.is_file():
        raise SystemExit(f"Image not found: {img_path}")
    transform = BioViLTEncoder.get_transform("val")
    # Image.open is lazy and keeps the file open; convert() returns a new
    # in-memory image, so the handle can be closed right away.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    img_t = transform(img).unsqueeze(0)  # (1, C, H, W)

    # ── Build model + load trained weights ───────────────────────
    logger.info("Building CXR-VLM …")
    model = CXRVisionLanguageModel(model_cfg)
    load_checkpoint(model, args.checkpoint)
    # `device_map='auto'` with HF accelerate places submodules already; calling
    # .to(device) on top of that is a no-op for the LLM. We still need to move
    # the projection MLP, BioViL-T encoder, and (optional) CheXpert head.
    model = model.to(device)
    img_t = img_t.to(device)
    model.eval()

    # ── Generate ─────────────────────────────────────────────────
    logger.info("Generating …")
    do_sample = args.temperature > 0.0
    # Inference only: make sure no autograd graph is recorded. Harmless if
    # model.generate already disables gradients internally.
    with torch.no_grad():
        output = model.generate(
            images=img_t,
            prompts=[full_prompt],
            max_new_tokens=args.max_new_tokens,
            temperature=max(args.temperature, 1e-5),  # HF complains if exactly 0
            do_sample=do_sample,
            num_beams=args.num_beams,
        )[0]

    # ── Pretty-print result ──────────────────────────────────────
    answer = output.strip()
    print()
    print("=" * 78)
    print(f"Image : {args.image}")
    print(f"Task : {args.task}")
    if args.prompt:
        print(f"Prompt : {args.prompt}")
    print("-" * 78)
    print(answer if answer else "<empty response>")
    print("=" * 78)


if __name__ == "__main__":
    main()