"""Minimal one-sample MedLayEval inference example.

Reproduces the call sequence used by
`MedLayXPlain_official/MedLayEval/inference.py` without the batched
leaderboard scaffolding.

Run from this directory:

    python inference_example.py --image example.png

If --image is omitted, a dummy white square is used so that the script
still verifies the full load-and-forward pipeline end-to-end without
needing an image on disk.
"""
import argparse
import warnings
from pathlib import Path

import torch
from PIL import Image
from peft import PeftModel
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from model import VLMRegressor, ATTRS
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the inference example.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave exactly as before.

    Returns:
        A namespace with the attributes ``ckpt_dir``, ``base_model``,
        ``image``, ``expert``, ``lay``, ``device`` and ``max_pixels``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_dir",
        # Default to the directory that contains this script.
        default=str(Path(__file__).parent),
        help="path containing adapter_*, regression_head.pt, processor files",
    )
    parser.add_argument(
        "--base_model",
        default="Qwen/Qwen2.5-VL-3B-Instruct",
        help="base model id or local path",
    )
    parser.add_argument("--image", default=None, help="optional local image path")
    parser.add_argument(
        "--expert",
        default=(
            "Axial chest CT showing a 1.2 cm spiculated nodule in the right "
            "upper lobe with adjacent pleural tethering, no mediastinal "
            "lymphadenopathy, and no pleural effusion."
        ),
        help="expert-level description text",
    )
    parser.add_argument(
        "--lay",
        default=(
            "The scan shows a small spot in the upper part of the right lung "
            "with some pulling on the nearby lining. The lymph nodes in the "
            "middle of the chest look normal and there is no fluid build-up."
        ),
        help="lay-language description text",
    )
    parser.add_argument(
        "--device",
        # Prefer the first GPU when CUDA is usable, otherwise run on CPU.
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="torch device string",
    )
    parser.add_argument("--max_pixels", type=int, default=448 * 448)
    return parser.parse_args(argv)
def load_model(ckpt_dir: str, base_model: str, device: str,
               max_pixels: int = 448 * 448):
    """Load the processor, the adapter-wrapped VLM and the regression head.

    Args:
        ckpt_dir: Directory holding the PEFT adapter files and
            ``regression_head.pt``.
        base_model: Model id or local path given to ``from_pretrained`` for
            both the processor and the base Qwen2.5-VL model.
        device: Torch device string the assembled model is moved to.
        max_pixels: Forwarded as ``max_pixels`` to
            ``AutoProcessor.from_pretrained``. Defaults to ``448 * 448``,
            the value that used to be hard-coded here.

    Returns:
        ``(model, processor)`` where ``model`` is a ``VLMRegressor`` in eval
        mode: backbone in bfloat16, regression head in float32.
    """
    processor = AutoProcessor.from_pretrained(base_model, max_pixels=max_pixels)
    tokenizer = processor.tokenizer
    if tokenizer.pad_token is None:
        # score_triple() calls the processor with padding=True, so a pad
        # token must exist; reuse EOS when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, attn_implementation="sdpa",
    )
    vlm = PeftModel.from_pretrained(base, ckpt_dir)
    try:
        vlm = vlm.merge_and_unload()
    except Exception as exc:
        # Merging is best-effort: keep the un-merged PEFT wrapper, but tell
        # the user instead of swallowing the failure silently.
        warnings.warn(
            f"merge_and_unload() failed ({exc!r}); continuing with the "
            "un-merged PEFT model.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Use the top-level hidden size when the config has one, otherwise the
    # one nested under text_config.
    config = vlm.config
    hidden = (
        config.hidden_size
        if hasattr(config, "hidden_size")
        else config.text_config.hidden_size
    )

    model = VLMRegressor(vlm, hidden).to(device, dtype=torch.bfloat16)
    # weights_only=True restricts unpickling to tensors and plain
    # containers, which is all load_state_dict() needs, and avoids running
    # arbitrary pickled code from a user-supplied checkpoint directory.
    head_state = torch.load(
        Path(ckpt_dir) / "regression_head.pt",
        map_location=device,
        weights_only=True,
    )
    model.head.load_state_dict(head_state)
    # The whole model was cast to bfloat16 above; put the head back in float32.
    model.head = model.head.to(device, dtype=torch.float32)
    model.eval()
    return model, processor


def score_triple(model, processor, image: Image.Image, expert: str, lay: str, device: str):
    """Score one (image, expert text, lay text) triple.

    Args:
        model: Callable model returned by ``load_model``.
        processor: Processor returned by ``load_model``.
        image: Image passed to the processor alongside the text.
        expert: Expert-level description; only its first 1500 characters
            are used.
        lay: Lay-language description; only its first 1500 characters are
            used.
        device: Torch device string the processor outputs are moved to.

    Returns:
        A dict with one float per name in ``ATTRS`` plus an ``"overall"``
        entry: the mean of the per-attribute scores, rounded to 4 decimals.
    """
    # Both texts go into a single user turn, one per line, after the image.
    user_text = f"{expert[:1500]}\n{lay[:1500]}"
    messages = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": user_text},
    ]}]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False,
    )
    inputs = processor(
        text=[text], images=[image],
        padding=True, truncation=True, max_length=2048,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        # Batch of one: take row 0 of the prediction array.
        preds = model(**inputs).cpu().float().numpy()[0]
    scores = {a: float(preds[k]) for k, a in enumerate(ATTRS)}
    # Computed before "overall" is inserted, so it averages only the
    # per-attribute scores.
    scores["overall"] = round(sum(scores.values()) / len(ATTRS), 4)
    return scores


def main():
    """Parse the CLI options, load the model, score one sample, print scores."""
    args = parse_args()
    model, processor = load_model(
        args.ckpt_dir, args.base_model, args.device,
        max_pixels=args.max_pixels,  # honour --max_pixels instead of ignoring it
    )
    if args.image:
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed as soon as the with-block exits.
        with Image.open(args.image) as source:
            image = source.convert("RGB")
    else:
        image = Image.new("RGB", (224, 224), color=(255, 255, 255))
        print("[no --image; using a 224x224 white square]")
    scores = score_triple(model, processor, image, args.expert, args.lay, args.device)
    for k, v in scores.items():
        print(f" {k:>12s}: {v:.4f}")


if __name__ == "__main__":
    main()