| """ |
| inference.py |
| ------------ |
| Run a single CXR image through the trained CXR-VLM and print the model's |
| output. Supports findings / impression / VQA tasks, and runs on either |
| GPU or CPU. |
| |
| Examples |
| -------- |
| # 1) Findings (default — no prompt needed) |
| python -m evaluation.inference \ |
| --image /path/to/cxr.jpg \ |
| --checkpoint checkpoints/IU-Xray_run_1/stage2_instruct/stage2_final.pt |
| |
| # 2) Impression |
| python -m evaluation.inference \ |
| --image /path/to/cxr.jpg \ |
| --task impression \ |
| --checkpoint .../stage2_final.pt |
| |
| # 3) VQA — pass the clinical question via --prompt |
| python -m evaluation.inference \ |
| --image /path/to/cxr.jpg \ |
| --task vqa \ |
| --prompt "Is there a pleural effusion?" \ |
| --checkpoint .../stage2_final.pt |
| |
| # 4) Free-form prompt without --task (automatically routed through the VQA template) |
| python -m evaluation.inference \ |
| --image /path/to/cxr.jpg \ |
| --prompt "Describe abnormalities in the left lung." \ |
| --checkpoint .../stage2_final.pt |
| |
| # 5) Force CPU (slow — Vicuna-7B in fp32 on CPU is multi-second per token, |
| # needs ~28 GB RAM. 4-bit quantization is auto-disabled because bnb |
| # requires CUDA.) |
| python -m evaluation.inference --image cxr.jpg --device cpu --checkpoint ... |
| |
| Device options |
| -------------- |
| --device auto (default) cuda if a GPU is visible, else cpu |
| --device cuda explicit GPU |
| --device cpu CPU only — see CPU notes above |
| |
| Checkpoint path |
| --------------- |
| Pass the path to the projection .pt file. The loader reads |
| `<dir>/<name>_projection.pt`, `<dir>/<name>_lora/`, and (optional) |
| `<dir>/<name>_chexpert_classifier.pt` from the same folder. Both layouts |
| work: |
| .../stage2_instruct/stage2_final.pt (after train.py finishes) |
| .../stage2/best/checkpoint_projection.pt (pulled from HF Hub) |
| """ |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
| import utils._quiet |
|
|
| import torch |
| from PIL import Image |
| from omegaconf import OmegaConf |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| from model import CXRVisionLanguageModel |
| from model.rad_dino import BioViLTEncoder |
| from data.prompt_templates import ( |
| build_findings_prompt, |
| build_impression_prompt, |
| build_vqa_prompt, |
| ) |
| from utils.checkpoint import load_checkpoint |
| from utils.logger import setup_logger |
|
|
|
|
| def parse_args(): |
| p = argparse.ArgumentParser(description="CXR-VLM single-image inference") |
| p.add_argument("--image", type=str, required=True, |
| help="Path to a CXR image (.jpg / .png).") |
| p.add_argument("--prompt", type=str, default=None, |
| help="Optional question or instruction. If omitted, the " |
| "default template for the chosen --task is used. " |
| "If supplied with --task findings/impression, the " |
| "prompt overrides the canned instruction.") |
| p.add_argument("--task", type=str, default="findings", |
| choices=["findings", "impression", "vqa"], |
| help="Prompt template family. Default: findings. " |
| "For 'vqa', --prompt is required.") |
| p.add_argument("--checkpoint", type=str, required=True, |
| help="Path to a stage-2 projection .pt file (the loader " |
| "also picks up the matching _lora/ folder beside it).") |
| p.add_argument("--model_config", type=str, default="configs/model_config.yaml") |
| p.add_argument("--device", type=str, default="auto", |
| choices=["auto", "cuda", "cpu"], |
| help="auto = cuda if available, else cpu.") |
| p.add_argument("--max_new_tokens", type=int, default=300) |
| p.add_argument("--temperature", type=float, default=0.0, |
| help="0.0 → greedy. >0 enables sampling.") |
| p.add_argument("--num_beams", type=int, default=1, |
| help="Beam search width. 1 = greedy/sampling.") |
| p.add_argument("--structured_findings", type=str, default=None, |
| help='Optional "Predicted Findings: ..." string to inject ' |
| "as context (normally produced by the CheXpert " |
| "classifier; pass it manually if you have one).") |
| p.add_argument("--cpu_dtype", type=str, default="float32", |
| choices=["float32", "float16", "bfloat16"], |
| help="LLM dtype on CPU. float32 is safest (~28 GB RAM); " |
| "float16/bfloat16 halve RAM but are very slow on most " |
| "CPUs. Ignored on GPU.") |
| return p.parse_args() |
|
|
|
|
def resolve_device(choice: str, logger) -> str:
    """Turn the ``--device`` choice into a concrete device string.

    Args:
        choice: ``"auto"``, ``"cuda"`` or ``"cpu"``.
        logger: Logger used to report what ``"auto"`` resolved to.

    Returns:
        ``"cuda"`` or ``"cpu"`` for ``"auto"`` (depending on
        ``torch.cuda.is_available()``); otherwise ``choice`` unchanged.

    Raises:
        SystemExit: If ``"cuda"`` is requested but CUDA is not available.
    """
    if choice != "auto":
        # Explicit request: only "cuda" needs an availability check.
        if choice == "cuda" and not torch.cuda.is_available():
            raise SystemExit("--device cuda requested but no CUDA GPU is visible. "
                             "Use --device cpu (slow) or --device auto.")
        return choice

    resolved = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"--device auto → resolved to {resolved}")
    return resolved
|
|
|
|
def patch_cfg_for_device(model_cfg, device: str, cpu_dtype: str, logger):
    """Adjust ``model_cfg.llm`` in place for the chosen device.

    On ``device == "cpu"``: a warning is logged if ``load_in_4bit`` or
    ``load_in_8bit`` is currently truthy, then both flags are set to
    ``False``, ``torch_dtype`` is set to ``cpu_dtype`` and ``device_map`` is
    set to ``None``. (The warning text attributes this to bitsandbytes
    requiring CUDA.)

    On any other device the config is left untouched; the function only
    verifies that CUDA is actually available.

    Args:
        model_cfg: Config object exposing a mutable ``llm`` section.
        device: Resolved device string (``"cpu"`` or ``"cuda"``).
        cpu_dtype: Dtype name written to ``llm.torch_dtype`` on CPU.
        logger: Logger that receives the quantization warning.

    Raises:
        RuntimeError: If ``device`` is not ``"cpu"`` and
            ``torch.cuda.is_available()`` is ``False``.
    """
    if device != "cpu":
        if not torch.cuda.is_available():
            raise RuntimeError("Resolved device=cuda but torch.cuda.is_available()==False")
        return

    llm_cfg = model_cfg.llm
    quantized = any(
        getattr(llm_cfg, flag, False) for flag in ("load_in_4bit", "load_in_8bit")
    )
    if quantized:
        logger.warning(
            "CPU device: disabling 4-bit / 8-bit quantization "
            "(bitsandbytes requires CUDA). Falling back to dtype="
            f"{cpu_dtype}. Vicuna-7B in float32 needs ~28 GB RAM; "
            "in float16/bfloat16 ~14 GB but very slow on CPU."
        )

    # Applied on every CPU run, whether or not quantization was requested.
    llm_cfg.load_in_4bit = False
    llm_cfg.load_in_8bit = False
    llm_cfg.torch_dtype = cpu_dtype
    llm_cfg.device_map = None
|
|
|
|
| def build_prompt(task: str, prompt: str | None, structured_findings: str | None) -> str: |
| """Decide which prompt template to use given (task, prompt).""" |
| sf = structured_findings |
| if task == "vqa": |
| if not prompt: |
| raise SystemExit("--task vqa requires --prompt <question>") |
| return build_vqa_prompt(prompt, sf) |
| |
| |
| if prompt: |
| return build_vqa_prompt(prompt, sf) |
| if task == "findings": |
| return build_findings_prompt(sf, randomize=False) |
| if task == "impression": |
| return build_impression_prompt(sf, randomize=False) |
| raise ValueError(f"Unknown task: {task}") |
|
|
|
|
def main() -> None:
    """Run single-image inference end to end and print the model's answer.

    Steps, in order: parse CLI options, resolve the device, load and patch
    the model config, build the prompt, load and preprocess the image, build
    the model and load the checkpoint, generate, and print the result.

    Raises:
        SystemExit: If the image file does not exist. Also propagated from
            the helpers (unavailable CUDA device, ``--task vqa`` without a
            prompt).
    """
    args = parse_args()
    logger = setup_logger("cxr_vlm_infer")
    device = resolve_device(args.device, logger)

    # Config: load, then adapt to the resolved device.
    model_cfg = OmegaConf.load(args.model_config)
    patch_cfg_for_device(model_cfg, device, args.cpu_dtype, logger)

    logger.info(f"image = {args.image}")
    logger.info(f"checkpoint = {args.checkpoint}")
    logger.info(f"task = {args.task}")
    logger.info(f"prompt = {args.prompt!r}")
    logger.info(f"device = {device}")

    # Build the prompt BEFORE the model: build_prompt rejects invalid
    # task/prompt combinations, and failing here avoids paying for a full
    # model build and checkpoint load only to exit on a usage error.
    full_prompt = build_prompt(args.task, args.prompt, args.structured_findings)
    logger.info(f"full prompt = {full_prompt!r}")

    # Image: validate the path, then decode and preprocess.
    img_path = Path(args.image)
    if not img_path.is_file():
        raise SystemExit(f"Image not found: {img_path}")
    transform = BioViLTEncoder.get_transform("val")
    # convert() returns a new in-memory image, so the file handle can be
    # closed as soon as the conversion is done.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    img_t = transform(img).unsqueeze(0)  # add a leading batch axis

    # Model: build, load weights, move to the target device.
    logger.info("Building CXR-VLM …")
    model = CXRVisionLanguageModel(model_cfg)
    load_checkpoint(model, args.checkpoint)

    # `device` is always "cuda" or "cpu" here (see resolve_device).
    model = model.to(device)
    img_t = img_t.to(device)
    model.eval()

    # Generation. temperature == 0.0 means greedy decoding; the value passed
    # on is clamped to a tiny positive number so it is never zero.
    logger.info("Generating …")
    do_sample = args.temperature > 0.0
    with torch.no_grad():  # inference only — no autograd bookkeeping needed
        output = model.generate(
            images=img_t,
            prompts=[full_prompt],
            max_new_tokens=args.max_new_tokens,
            temperature=max(args.temperature, 1e-5),
            do_sample=do_sample,
            num_beams=args.num_beams,
        )[0]

    # Report.
    answer = output.strip()
    print()
    print("=" * 78)
    print(f"Image : {args.image}")
    print(f"Task : {args.task}")
    if args.prompt:
        print(f"Prompt : {args.prompt}")
    print("-" * 78)
    print(answer or "<empty response>")
    print("=" * 78)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script / via `python -m`, not when
    # the module is imported.
    main()
|
|