"""Minimal one-sample MedLayEval inference example.

Reproduces the call sequence used by
`MedLayXPlain_official/MedLayEval/inference.py` without the batched
leaderboard scaffolding. Run from this directory:

    python inference_example.py --image example.png

If --image is omitted, a dummy white square is used so that the script
verifies the full load-and-forward pipeline end-to-end without an image
on disk.
"""
import argparse
import warnings
from pathlib import Path

import torch
from PIL import Image
from peft import PeftModel
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from model import VLMRegressor, ATTRS

# Default pixel cap handed to the processor; shared by --max_pixels and
# load_model so the CLI flag and the library default cannot drift apart.
DEFAULT_MAX_PIXELS = 448 * 448


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the single-sample scoring example."""
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt_dir", default=str(Path(__file__).parent),
                   help="path containing adapter_*, regression_head.pt, processor files")
    p.add_argument("--base_model", default="Qwen/Qwen2.5-VL-3B-Instruct")
    p.add_argument("--image", default=None, help="optional local image path")
    p.add_argument("--expert", default=(
        "Axial chest CT showing a 1.2 cm spiculated nodule in the right "
        "upper lobe with adjacent pleural tethering, no mediastinal "
        "lymphadenopathy, and no pleural effusion."
    ))
    p.add_argument("--lay", default=(
        "The scan shows a small spot in the upper part of the right lung "
        "with some pulling on the nearby lining. The lymph nodes in the "
        "middle of the chest look normal and there is no fluid build-up."
    ))
    p.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    p.add_argument("--max_pixels", type=int, default=DEFAULT_MAX_PIXELS,
                   help="maximum image pixel count passed to the processor")
    return p.parse_args()


def load_model(ckpt_dir: str, base_model: str, device: str,
               max_pixels: int = DEFAULT_MAX_PIXELS):
    """Load the processor, the PEFT-adapted VLM and the regression head.

    Args:
        ckpt_dir: directory holding the PEFT adapter files and
            ``regression_head.pt``.
        base_model: Hugging Face id or local path of the base Qwen2.5-VL
            model; the processor is loaded from here as well.
        device: torch device string the model is moved to.
        max_pixels: pixel cap forwarded to ``AutoProcessor.from_pretrained``.

    Returns:
        ``(model, processor)`` where ``model`` is a ``VLMRegressor`` in eval
        mode with its backbone in bfloat16 and its head in float32.
    """
    # NOTE(review): the processor is loaded from base_model, not ckpt_dir,
    # although the --ckpt_dir help mentions processor files — confirm intent.
    processor = AutoProcessor.from_pretrained(base_model, max_pixels=max_pixels)
    if processor.tokenizer.pad_token is None:
        # Padding is requested at call time, so a pad token must exist.
        processor.tokenizer.pad_token = processor.tokenizer.eos_token

    base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, attn_implementation="sdpa",
    )
    vlm = PeftModel.from_pretrained(base, ckpt_dir)
    # Merging the adapter into the base weights is best-effort: on failure the
    # unmerged PeftModel is used, but the failure is now reported, not hidden.
    try:
        vlm = vlm.merge_and_unload()
    except Exception as err:
        warnings.warn(
            f"merge_and_unload() failed ({err!r}); using the unmerged adapter model",
            RuntimeWarning,
        )

    # hidden_size may live on the top-level config or only on text_config.
    hidden = (
        vlm.config.hidden_size
        if hasattr(vlm.config, "hidden_size")
        else vlm.config.text_config.hidden_size
    )
    model = VLMRegressor(vlm, hidden).to(device, dtype=torch.bfloat16)

    # weights_only=True limits unpickling to tensors/containers — all a
    # state_dict needs — so a tampered checkpoint cannot execute code.
    head_state = torch.load(
        Path(ckpt_dir) / "regression_head.pt",
        map_location=device,
        weights_only=True,
    )
    model.head.load_state_dict(head_state)
    # Keep the regression head in float32 while the backbone stays bfloat16.
    model.head = model.head.to(device, dtype=torch.float32)
    model.eval()
    return model, processor


def score_triple(model, processor, image: Image.Image, expert: str, lay: str,
                 device: str) -> dict:
    """Score one (image, expert text, lay text) triple.

    Each text is cut to its first 1500 characters; the two are joined with a
    newline and placed after the image in a single user chat turn.

    Returns:
        Dict mapping each name in ``ATTRS`` to its predicted score, plus an
        ``"overall"`` key holding their mean rounded to 4 decimals.
    """
    user_text = f"{expert[:1500]}\n{lay[:1500]}"
    messages = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": user_text},
    ]}]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False,
    )
    # NOTE(review): truncation=True could drop image placeholder tokens if a
    # prompt ever exceeded max_length; the 1500-char caps presumably keep
    # inputs well under 2048 tokens — confirm.
    inputs = processor(
        text=[text], images=[image], padding=True, truncation=True,
        max_length=2048, return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        preds = model(**inputs).cpu().float().numpy()[0]
    scores = {a: float(preds[k]) for k, a in enumerate(ATTRS)}
    # Mean over the per-attribute scores only ("overall" is not yet present).
    scores["overall"] = round(sum(scores.values()) / len(ATTRS), 4)
    return scores


def main() -> None:
    """Load the model, score one triple built from CLI args and print it."""
    args = parse_args()
    model, processor = load_model(
        args.ckpt_dir, args.base_model, args.device, max_pixels=args.max_pixels,
    )
    if args.image:
        # convert() returns an in-memory copy, so the file can be closed here.
        with Image.open(args.image) as img:
            image = img.convert("RGB")
    else:
        image = Image.new("RGB", (224, 224), color=(255, 255, 255))
        print("[no --image; using a 224x224 white square]")
    scores = score_triple(model, processor, image, args.expert, args.lay, args.device)
    for k, v in scores.items():
        print(f" {k:>12s}: {v:.4f}")


if __name__ == "__main__":
    main()