r"""
inference.py
------------
Run a single CXR image through the trained CXR-VLM and print the model's
output. Supports findings / impression / VQA tasks, and runs on either GPU
or CPU.

Examples
--------
# 1) Findings (default — no prompt needed)
python -m evaluation.inference \
    --image /path/to/cxr.jpg \
    --checkpoint checkpoints/IU-Xray_run_1/stage2_instruct/stage2_final.pt

# 2) Impression
python -m evaluation.inference \
    --image /path/to/cxr.jpg \
    --task impression \
    --checkpoint .../stage2_final.pt

# 3) VQA — pass the clinical question via --prompt
python -m evaluation.inference \
    --image /path/to/cxr.jpg \
    --task vqa \
    --prompt "Is there a pleural effusion?" \
    --checkpoint .../stage2_final.pt

# 4) Free-form prompt (treated as a VQA-style instruction)
python -m evaluation.inference \
    --image /path/to/cxr.jpg \
    --prompt "Describe abnormalities in the left lung." \
    --checkpoint .../stage2_final.pt

# 5) Force CPU (slow — Vicuna-7B in fp32 on CPU takes several seconds per
#    token and needs ~28 GB RAM. 4-bit quantization is disabled
#    automatically because bitsandbytes requires CUDA.)
python -m evaluation.inference --image cxr.jpg --device cpu --checkpoint ...

Device options
--------------
--device auto   (default) cuda if a GPU is visible, else cpu
--device cuda   explicit GPU
--device cpu    CPU only — see the CPU notes above

Checkpoint path
---------------
Pass the path to the projection .pt file. The loader reads
`<dir>/<prefix>_projection.pt`, `<dir>/<prefix>_lora/`, and (optional)
`<dir>/<prefix>_chexpert_classifier.pt` from the same folder. Both layouts
work:
    .../stage2_instruct/stage2_final.pt        (after train.py finishes)
    .../stage2/best/checkpoint_projection.pt   (pulled from HF Hub)
"""
import argparse
import sys
from pathlib import Path

# Silence HF per-shard download tqdm spam — must run before transformers import.
# Repository root (the parent of this file's directory). It is put on
# sys.path so `utils`, `model` and `data` import even when the script is
# executed directly (sys.path[0] is then this file's own directory).
_REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(_REPO_ROOT))
# Silences HF per-shard download progress bars; must precede transformers.
import utils._quiet  # noqa: F401,E402

import torch  # noqa: E402
from PIL import Image  # noqa: E402
from omegaconf import OmegaConf  # noqa: E402

from model import CXRVisionLanguageModel  # noqa: E402
from model.rad_dino import BioViLTEncoder  # noqa: E402
from data.prompt_templates import (  # noqa: E402
    build_findings_prompt,
    build_impression_prompt,
    build_vqa_prompt,
)
from utils.checkpoint import load_checkpoint  # noqa: E402
from utils.logger import setup_logger  # noqa: E402


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for single-image inference."""
    p = argparse.ArgumentParser(description="CXR-VLM single-image inference")
    p.add_argument("--image", type=str, required=True,
                   help="Path to a CXR image (.jpg / .png).")
    p.add_argument("--prompt", type=str, default=None,
                   help="Optional question or instruction. If omitted, the "
                        "default template for the chosen --task is used. "
                        "If supplied with --task findings/impression, the "
                        "prompt overrides the canned instruction.")
    p.add_argument("--task", type=str, default="findings",
                   choices=["findings", "impression", "vqa"],
                   help="Prompt template family. Default: findings. "
                        "For 'vqa', --prompt is required.")
    p.add_argument("--checkpoint", type=str, required=True,
                   help="Path to a stage-2 projection .pt file (the loader "
                        "also picks up the matching _lora/ folder beside it).")
    p.add_argument("--model_config", type=str,
                   default="configs/model_config.yaml",
                   help="Model config YAML. A relative path is tried against "
                        "the current directory first, then the repo root.")
    p.add_argument("--device", type=str, default="auto",
                   choices=["auto", "cuda", "cpu"],
                   help="auto = cuda if available, else cpu.")
    p.add_argument("--max_new_tokens", type=int, default=300)
    p.add_argument("--temperature", type=float, default=0.0,
                   help="0.0 → greedy. >0 enables sampling.")
    p.add_argument("--num_beams", type=int, default=1,
                   help="Beam search width. 1 = greedy/sampling.")
    p.add_argument("--structured_findings", type=str, default=None,
                   help='Optional "Predicted Findings: ..." string to inject '
                        "as context (normally produced by the CheXpert "
                        "classifier; pass it manually if you have one).")
    p.add_argument("--cpu_dtype", type=str, default="float32",
                   choices=["float32", "float16", "bfloat16"],
                   help="LLM dtype on CPU. float32 is safest (~28 GB RAM); "
                        "float16/bfloat16 halve RAM but are very slow on most "
                        "CPUs. Ignored on GPU.")
    return p.parse_args()


def _validate_generation_args(args: argparse.Namespace) -> None:
    """Reject generation settings that would fail or be silently reinterpreted.

    Raises:
        SystemExit: a non-positive ``--max_new_tokens`` / ``--num_beams`` or a
            negative ``--temperature`` (previously clamped to greedy).
    """
    if args.max_new_tokens < 1:
        raise SystemExit("--max_new_tokens must be >= 1.")
    if args.num_beams < 1:
        raise SystemExit("--num_beams must be >= 1.")
    if args.temperature < 0.0:
        raise SystemExit("--temperature must be >= 0.0 (0.0 = greedy).")


def _resolve_config_path(path_str: str) -> Path:
    """Locate the model config, falling back to the repo root for relative paths.

    Returns *path_str* unchanged when it exists, is absolute, or no repo-root
    match exists, so a missing file still errors with the user's own path.
    """
    path = Path(path_str)
    if path.is_file() or path.is_absolute():
        return path
    fallback = _REPO_ROOT / path
    return fallback if fallback.is_file() else path


def resolve_device(choice: str, logger) -> str:
    """Map a ``--device`` choice to a concrete ``"cuda"`` or ``"cpu"``.

    Raises:
        SystemExit: ``"cuda"`` was requested but no CUDA GPU is visible.
    """
    if choice == "auto":
        d = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"--device auto → resolved to {d}")
        return d
    if choice == "cuda" and not torch.cuda.is_available():
        raise SystemExit("--device cuda requested but no CUDA GPU is visible. "
                         "Use --device cpu (slow) or --device auto.")
    return choice


def patch_cfg_for_device(model_cfg, device: str, cpu_dtype: str, logger) -> None:
    """Mutate *model_cfg* in place so the model can be built on *device*.

    On CPU, bitsandbytes 4-bit / 8-bit loading is unavailable (bnb is
    CUDA-only), so both flags are cleared — with a warning if either was set —
    and the LLM is loaded in *cpu_dtype* with no ``device_map``. On GPU the
    config is left as-is (4-bit on T4, bf16 on A100, …).

    Raises:
        RuntimeError: *device* is not ``"cpu"`` but CUDA is unavailable.
    """
    if device != "cpu":
        # GPU — keep whatever the config already says.
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Resolved device=cuda but torch.cuda.is_available()==False")
        return

    llm_cfg = model_cfg.llm
    if (getattr(llm_cfg, "load_in_4bit", False)
            or getattr(llm_cfg, "load_in_8bit", False)):
        logger.warning(
            "CPU device: disabling 4-bit / 8-bit quantization "
            "(bitsandbytes requires CUDA). Falling back to dtype="
            f"{cpu_dtype}. Vicuna-7B in float32 needs ~28 GB RAM; "
            "in float16/bfloat16 ~14 GB but very slow on CPU."
        )
    llm_cfg.load_in_4bit = False
    llm_cfg.load_in_8bit = False
    # Applied even when quantization was off: a GPU-oriented torch_dtype or
    # device_map in the config must not leak into a CPU run.
    llm_cfg.torch_dtype = cpu_dtype
    # `device_map='auto'` makes accelerate look for a GPU; None loads
    # everything on the default device (CPU).
    llm_cfg.device_map = None


def build_prompt(task: str, prompt: str | None,
                 structured_findings: str | None) -> str:
    """Return the fully formatted LLM prompt for *task*.

    Args:
        task: ``"findings"``, ``"impression"`` or ``"vqa"``.
        prompt: Optional user question / instruction. A blank or
            whitespace-only string is treated as absent.
        structured_findings: Optional "Predicted Findings: ..." context,
            forwarded unchanged to the template builder.

    Raises:
        SystemExit: ``task == "vqa"`` without a non-blank *prompt*.
        ValueError: *task* is not a known template family.
    """
    # A whitespace-only prompt would otherwise replace the canned
    # findings/impression instruction with an empty one.
    if prompt is not None and not prompt.strip():
        prompt = None

    if task == "vqa":
        if prompt is None:
            raise SystemExit('--task vqa requires --prompt "<question>"')
        return build_vqa_prompt(prompt, structured_findings)

    # findings / impression: a user-supplied prompt is treated as an
    # instruction (same single-instruction template family as VQA).
    if prompt is not None:
        return build_vqa_prompt(prompt, structured_findings)
    if task == "findings":
        return build_findings_prompt(structured_findings, randomize=False)
    if task == "impression":
        return build_impression_prompt(structured_findings, randomize=False)
    raise ValueError(f"Unknown task: {task}")


def main() -> None:
    """CLI entry point: validate inputs, build the model, print one generation."""
    args = parse_args()
    _validate_generation_args(args)
    logger = setup_logger("cxr_vlm_infer")
    device = resolve_device(args.device, logger)

    # ── Cheap checks first ───────────────────────────────────────
    # Image and prompt errors surface before the slow LLM load.
    img_path = Path(args.image)
    if not img_path.is_file():
        raise SystemExit(f"Image not found: {img_path}")
    full_prompt = build_prompt(args.task, args.prompt, args.structured_findings)

    # ── Load + patch model config ────────────────────────────────
    model_cfg = OmegaConf.load(_resolve_config_path(args.model_config))
    patch_cfg_for_device(model_cfg, device, args.cpu_dtype, logger)

    logger.info(f"image       = {args.image}")
    logger.info(f"checkpoint  = {args.checkpoint}")
    logger.info(f"task        = {args.task}")
    logger.info(f"prompt      = {args.prompt!r}")
    logger.info(f"device      = {device}")
    logger.info(f"full prompt = {full_prompt!r}")

    # ── Build image tensor ───────────────────────────────────────
    transform = BioViLTEncoder.get_transform("val")
    # Context manager closes the file handle Image.open keeps open;
    # convert() returns a fully loaded copy that outlives it.
    with Image.open(img_path) as pil_img:
        img = pil_img.convert("RGB")
    img_t = transform(img).unsqueeze(0)  # (1, C, H, W)

    # ── Build model + load trained weights ───────────────────────
    logger.info("Building CXR-VLM …")
    model = CXRVisionLanguageModel(model_cfg)
    load_checkpoint(model, args.checkpoint)
    # With `device_map='auto'`, accelerate has already placed the LLM, so
    # .to() is expected to leave it in place. It is still needed to move the
    # projection MLP, BioViL-T encoder and (optional) CheXpert head.
    model = model.to(device)
    img_t = img_t.to(device)
    model.eval()

    # ── Generate ─────────────────────────────────────────────────
    logger.info("Generating …")
    do_sample = args.temperature > 0.0
    # Inference only: no_grad stops autograd from recording any forward work
    # the wrapper's generate does outside HF's own generate.
    with torch.no_grad():
        output = model.generate(
            images=img_t,
            prompts=[full_prompt],
            max_new_tokens=args.max_new_tokens,
            temperature=max(args.temperature, 1e-5),  # HF complains if exactly 0
            do_sample=do_sample,
            num_beams=args.num_beams,
        )[0]

    # ── Pretty-print result ──────────────────────────────────────
    text = output.strip()
    print()
    print("=" * 78)
    print(f"Image  : {args.image}")
    print(f"Task   : {args.task}")
    if args.prompt:
        print(f"Prompt : {args.prompt}")
    print("-" * 78)
    # Make an empty generation visible instead of printing a blank line.
    print(text if text else "<empty output>")
    print("=" * 78)


if __name__ == "__main__":
    main()