File size: 10,354 Bytes
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93ddbd3
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
"""
inference.py
------------
Run a single CXR image through the trained CXR-VLM and print the model's
output. Supports findings / impression / VQA tasks, and runs on either
GPU or CPU.

Examples
--------
  # 1) Findings (default — no prompt needed)
  python -m evaluation.inference \
      --image  /path/to/cxr.jpg \
      --checkpoint checkpoints/IU-Xray_run_1/stage2_instruct/stage2_final.pt

  # 2) Impression
  python -m evaluation.inference \
      --image      /path/to/cxr.jpg \
      --task       impression \
      --checkpoint .../stage2_final.pt

  # 3) VQA — pass the clinical question via --prompt
  python -m evaluation.inference \
      --image      /path/to/cxr.jpg \
      --task       vqa \
      --prompt     "Is there a pleural effusion?" \
      --checkpoint .../stage2_final.pt

  # 4) Free-form prompt with the default task (formatted with the VQA-style
  #    single-instruction template instead of the canned findings prompt)
  python -m evaluation.inference \
      --image      /path/to/cxr.jpg \
      --prompt     "Describe abnormalities in the left lung." \
      --checkpoint .../stage2_final.pt

  # 5) Force CPU (slow — Vicuna-7B in fp32 on CPU is multi-second per token,
  #    needs ~28 GB RAM. 4-bit quantization is auto-disabled because bnb
  #    requires CUDA.)
  python -m evaluation.inference --image cxr.jpg --device cpu --checkpoint ...

Device options
--------------
  --device auto   (default) cuda if a GPU is visible, else cpu
  --device cuda   explicit GPU
  --device cpu    CPU only — see the CPU notes in example 5 above

Checkpoint path
---------------
  Pass the path to the projection .pt file. The loader reads
  `<dir>/<name>_projection.pt`, `<dir>/<name>_lora/`, and (optional)
  `<dir>/<name>_chexpert_classifier.pt` from the same folder. Both layouts
  work:
    .../stage2_instruct/stage2_final.pt          (after train.py finishes)
    .../stage2/best/checkpoint_projection.pt     (pulled from HF Hub)
"""

import argparse
import sys
from pathlib import Path

# Silence HF per-shard download tqdm spam — must run before transformers import.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import utils._quiet  # noqa: F401

import torch
from PIL import Image
from omegaconf import OmegaConf

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from model import CXRVisionLanguageModel
from model.rad_dino import BioViLTEncoder
from data.prompt_templates import (
    build_findings_prompt,
    build_impression_prompt,
    build_vqa_prompt,
)
from utils.checkpoint import load_checkpoint
from utils.logger import setup_logger


def parse_args():
    p = argparse.ArgumentParser(description="CXR-VLM single-image inference")
    p.add_argument("--image", type=str, required=True,
                   help="Path to a CXR image (.jpg / .png).")
    p.add_argument("--prompt", type=str, default=None,
                   help="Optional question or instruction. If omitted, the "
                        "default template for the chosen --task is used. "
                        "If supplied with --task findings/impression, the "
                        "prompt overrides the canned instruction.")
    p.add_argument("--task", type=str, default="findings",
                   choices=["findings", "impression", "vqa"],
                   help="Prompt template family. Default: findings. "
                        "For 'vqa', --prompt is required.")
    p.add_argument("--checkpoint", type=str, required=True,
                   help="Path to a stage-2 projection .pt file (the loader "
                        "also picks up the matching _lora/ folder beside it).")
    p.add_argument("--model_config", type=str, default="configs/model_config.yaml")
    p.add_argument("--device", type=str, default="auto",
                   choices=["auto", "cuda", "cpu"],
                   help="auto = cuda if available, else cpu.")
    p.add_argument("--max_new_tokens", type=int, default=300)
    p.add_argument("--temperature", type=float, default=0.0,
                   help="0.0 → greedy. >0 enables sampling.")
    p.add_argument("--num_beams", type=int, default=1,
                   help="Beam search width. 1 = greedy/sampling.")
    p.add_argument("--structured_findings", type=str, default=None,
                   help='Optional "Predicted Findings: ..." string to inject '
                        "as context (normally produced by the CheXpert "
                        "classifier; pass it manually if you have one).")
    p.add_argument("--cpu_dtype", type=str, default="float32",
                   choices=["float32", "float16", "bfloat16"],
                   help="LLM dtype on CPU. float32 is safest (~28 GB RAM); "
                        "float16/bfloat16 halve RAM but are very slow on most "
                        "CPUs. Ignored on GPU.")
    return p.parse_args()


def resolve_device(choice: str, logger) -> str:
    """Map the ``--device`` flag onto a concrete ``"cuda"`` or ``"cpu"`` string.

    ``"auto"`` selects ``"cuda"`` when ``torch.cuda.is_available()`` is true
    and ``"cpu"`` otherwise, and logs the selection via ``logger.info``.
    Any explicit choice is returned unchanged. The exception is ``"cuda"``
    with no visible GPU, which aborts with ``SystemExit``.
    """
    if choice != "auto":
        # `and` short-circuits, so CUDA is only probed for an explicit "cuda".
        cuda_missing = choice == "cuda" and not torch.cuda.is_available()
        if cuda_missing:
            raise SystemExit("--device cuda requested but no CUDA GPU is visible. "
                             "Use --device cpu (slow) or --device auto.")
        return choice

    resolved = "cpu"
    if torch.cuda.is_available():
        resolved = "cuda"
    logger.info(f"--device auto → resolved to {resolved}")
    return resolved


def patch_cfg_for_device(model_cfg, device: str, cpu_dtype: str, logger):
    """Adjust ``model_cfg.llm`` in place so the model can be built on *device*.

    GPU: the config is left exactly as written (e.g. 4-bit on T4, bf16 on
    A100). A ``RuntimeError`` is raised if CUDA is not actually available.

    CPU: bitsandbytes 4-bit / 8-bit loading requires CUDA, so both
    ``load_in_4bit`` and ``load_in_8bit`` are forced to ``False``. A warning is
    logged when either was enabled. ``torch_dtype`` becomes *cpu_dtype* and
    ``device_map`` is cleared to ``None``.
    """
    if device != "cpu":
        # GPU — keep whatever the config already says (4-bit on T4, bf16 on A100, …).
        if torch.cuda.is_available():
            return
        raise RuntimeError("Resolved device=cuda but torch.cuda.is_available()==False")

    llm_cfg = model_cfg.llm
    wants_quant = (getattr(llm_cfg, "load_in_4bit", False)
                   or getattr(llm_cfg, "load_in_8bit", False))
    if wants_quant:
        logger.warning(
            "CPU device: disabling 4-bit / 8-bit quantization "
            "(bitsandbytes requires CUDA). Falling back to dtype="
            f"{cpu_dtype}. Vicuna-7B in float32 needs ~28 GB RAM; "
            "in float16/bfloat16 ~14 GB but very slow on CPU."
        )

    llm_cfg.load_in_4bit = False
    llm_cfg.load_in_8bit = False
    llm_cfg.torch_dtype = cpu_dtype
    # `device_map='auto'` would make accelerate look for a GPU; None keeps
    # every weight on the default (CPU) device.
    llm_cfg.device_map = None


def build_prompt(task: str, prompt: str | None, structured_findings: str | None) -> str:
    """Decide which prompt template to use given (task, prompt)."""
    sf = structured_findings
    if task == "vqa":
        if not prompt:
            raise SystemExit("--task vqa requires --prompt <question>")
        return build_vqa_prompt(prompt, sf)
    # findings / impression: if user passed a free-form prompt, treat it as
    # an instruction (same template family as VQA — single instruction line).
    if prompt:
        return build_vqa_prompt(prompt, sf)
    if task == "findings":
        return build_findings_prompt(sf, randomize=False)
    if task == "impression":
        return build_impression_prompt(sf, randomize=False)
    raise ValueError(f"Unknown task: {task}")


def _load_image_tensor(img_path: Path) -> torch.Tensor:
    """Read *img_path* as RGB and apply the BioViL-T "val" transform, shape (1, C, H, W).

    Raises:
        SystemExit: the path is not a regular file, or PIL cannot read it.
    """
    if not img_path.is_file():
        raise SystemExit(f"Image not found: {img_path}")
    try:
        # The context manager closes the underlying file handle promptly;
        # .convert() returns a new, fully-loaded image that outlives it.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise SystemExit(f"Could not read image {img_path}: {err}") from err
    transform = BioViLTEncoder.get_transform("val")
    return transform(img).unsqueeze(0)   # (1, C, H, W)


def _print_result(args, output: str) -> None:
    """Print the generated text framed by separator rules, with image/task/prompt."""
    text = output.strip()
    print()
    print("=" * 78)
    print(f"Image  : {args.image}")
    print(f"Task   : {args.task}")
    if args.prompt:
        print(f"Prompt : {args.prompt}")
    print("-" * 78)
    print(text if text else "<empty response>")
    print("=" * 78)


def main():
    """CLI entry point: parse args, build the model, generate for one image, print.

    Cheap validation (device, config patching, prompt construction, image
    loading) runs before the model is built. A bad ``--task``/``--prompt``
    combination or an unreadable image therefore fails immediately instead of
    after the LLM weights have been loaded.
    """
    args   = parse_args()
    logger = setup_logger("cxr_vlm_infer")
    device = resolve_device(args.device, logger)

    # ── Load + patch model config ────────────────────────────────
    model_cfg = OmegaConf.load(args.model_config)
    patch_cfg_for_device(model_cfg, device, args.cpu_dtype, logger)

    logger.info(f"image          = {args.image}")
    logger.info(f"checkpoint     = {args.checkpoint}")
    logger.info(f"task           = {args.task}")
    logger.info(f"prompt         = {args.prompt!r}")
    logger.info(f"device         = {device}")

    # ── Build the formatted prompt (before the expensive model load) ──
    full_prompt = build_prompt(args.task, args.prompt, args.structured_findings)
    logger.info(f"full prompt    = {full_prompt!r}")

    # ── Build image tensor ───────────────────────────────────────
    img_t = _load_image_tensor(Path(args.image))

    # ── Build model + load trained weights ───────────────────────
    logger.info("Building CXR-VLM …")
    model = CXRVisionLanguageModel(model_cfg)
    load_checkpoint(model, args.checkpoint)

    # With `device_map='auto'`, HF accelerate has already placed the LLM's
    # submodules. Moving the wrapper is still needed for the projection MLP,
    # the BioViL-T encoder and the (optional) CheXpert head.
    # `device` is always "cuda" or "cpu" here (see resolve_device).
    model = model.to(device)
    img_t = img_t.to(device)
    model.eval()

    # ── Generate ─────────────────────────────────────────────────
    logger.info("Generating …")
    do_sample = args.temperature > 0.0
    # Inference only: skip autograd bookkeeping for the vision/projection
    # forward pass inside generate (saves memory and time).
    with torch.no_grad():
        output = model.generate(
            images         = img_t,
            prompts        = [full_prompt],
            max_new_tokens = args.max_new_tokens,
            temperature    = max(args.temperature, 1e-5),  # HF complains if exactly 0
            do_sample      = do_sample,
            num_beams      = args.num_beams,
        )[0]

    # ── Pretty-print result ──────────────────────────────────────
    _print_result(args, output)


# Script entry point: run inference only when executed directly
# (e.g. `python -m evaluation.inference ...`), not when imported.
if __name__ == "__main__":
    main()