File size: 4,010 Bytes
67d9005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Minimal one-sample MedLayEval inference example.

Reproduces the call sequence used by `MedLayXPlain_official/MedLayEval/inference.py`
without the batched leaderboard scaffolding.

Run from this directory:

    python inference_example.py --image example.png

If --image is omitted, a dummy white square is used so that the script
verifies the full load-and-forward pipeline end-to-end without an image
on disk.
"""
import argparse
from pathlib import Path

import torch
from PIL import Image
from peft import PeftModel
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from model import VLMRegressor, ATTRS


def parse_args():
    """Build the command-line interface and return the parsed namespace.

    Every option has a default, so the script runs with no arguments at all:
    the checkpoint directory defaults to the folder containing this file, the
    device to the first CUDA GPU when one is available (else the CPU), and the
    expert/lay texts to a built-in chest-CT example pair.
    """
    sample_expert = (
        "Axial chest CT showing a 1.2 cm spiculated nodule in the right "
        "upper lobe with adjacent pleural tethering, no mediastinal "
        "lymphadenopathy, and no pleural effusion."
    )
    sample_lay = (
        "The scan shows a small spot in the upper part of the right lung "
        "with some pulling on the nearby lining. The lymph nodes in the "
        "middle of the chest look normal and there is no fluid build-up."
    )
    fallback_device = "cpu"
    if torch.cuda.is_available():
        fallback_device = "cuda:0"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_dir",
        default=str(Path(__file__).parent),
        help="path containing adapter_*, regression_head.pt, processor files",
    )
    parser.add_argument("--base_model", default="Qwen/Qwen2.5-VL-3B-Instruct")
    parser.add_argument("--image", default=None, help="optional local image path")
    parser.add_argument("--expert", default=sample_expert)
    parser.add_argument("--lay", default=sample_lay)
    parser.add_argument("--device", default=fallback_device)
    parser.add_argument("--max_pixels", type=int, default=448 * 448)
    return parser.parse_args()


def load_model(ckpt_dir: str, base_model: str, device: str, max_pixels: int = 448 * 448):
    """Load the base VLM, apply the LoRA adapter, and attach the regression head.

    Args:
        ckpt_dir: Directory holding the PEFT adapter files and
            ``regression_head.pt``.
        base_model: Model id or path of the Qwen2.5-VL base model; the
            processor is loaded from here as well.
        device: Torch device string the model is moved to (e.g. ``"cuda:0"``).
        max_pixels: Upper bound on image pixels, forwarded to the processor.
            The default equals the value that used to be hard-coded here.

    Returns:
        ``(model, processor)``: the ``VLMRegressor`` in eval mode (cast to
        bfloat16, with its head in float32) and the matching processor.
    """
    processor = AutoProcessor.from_pretrained(base_model, max_pixels=max_pixels)
    if processor.tokenizer.pad_token is None:
        # Fall back to EOS so that padding=True (see score_triple) has a pad token.
        processor.tokenizer.pad_token = processor.tokenizer.eos_token

    base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, attn_implementation="sdpa",
    )
    vlm = PeftModel.from_pretrained(base, ckpt_dir)
    try:
        # Folding the LoRA weights into the base model is best-effort; if it
        # fails, `vlm` stays the PeftModel and runs with the unmerged adapter.
        vlm = vlm.merge_and_unload()
    except Exception as err:  # deliberate best-effort: report, don't crash
        print(f"[merge_and_unload failed ({err!r}); using unmerged LoRA adapter]")

    # Some configs expose hidden_size at the top level, others only on text_config.
    hidden = (
        vlm.config.hidden_size
        if hasattr(vlm.config, "hidden_size")
        else vlm.config.text_config.hidden_size
    )
    model = VLMRegressor(vlm, hidden).to(device, dtype=torch.bfloat16)
    # Cast the head to float32 BEFORE loading its weights: load_state_dict
    # copies into the existing parameters and keeps their dtype, so loading
    # into the bf16 head would round the saved weights to bfloat16 precision.
    model.head = model.head.to(device, dtype=torch.float32)
    # weights_only=True: the file is consumed as a plain tensor state dict, so
    # refuse arbitrary pickled objects coming from a user-supplied directory.
    head_state = torch.load(
        Path(ckpt_dir) / "regression_head.pt", map_location=device, weights_only=True,
    )
    model.head.load_state_dict(head_state)
    model.eval()
    return model, processor


def score_triple(model, processor, image: Image.Image, expert: str, lay: str, device: str):
    """Score one (image, expert text, lay text) triple.

    Each text is clipped to its first 1500 characters and wrapped in
    ``<expert>``/``<lay>`` tags inside a single user chat turn that also
    carries the image.

    Returns:
        Dict mapping each name in ``ATTRS`` to the model's prediction at that
        position for the single batch row, plus ``"overall"``: the mean of
        those predictions rounded to 4 decimals.
    """
    user_text = f"<expert>{expert[:1500]}</expert>\n<lay>{lay[:1500]}</lay>"
    messages = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": user_text},
    ]}]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False,
    )
    # NOTE(review): max_length=2048 presumably also counts the expanded image
    # tokens; with a much larger --max_pixels, truncation could cut into them
    # — confirm before raising max_pixels far above the default.
    inputs = processor(
        text=[text], images=[image],
        padding=True, truncation=True, max_length=2048,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        preds = model(**inputs).cpu().float().numpy()[0]
    scores = {a: float(preds[k]) for k, a in enumerate(ATTRS)}
    # Computed before "overall" is inserted, so this averages only the ATTRS scores.
    scores["overall"] = round(sum(scores.values()) / len(ATTRS), 4)
    return scores


def main():
    """Parse CLI args, load the checkpoint, score one triple, and print the scores."""
    args = parse_args()
    model, processor = load_model(
        args.ckpt_dir, args.base_model, args.device, max_pixels=args.max_pixels,
    )

    if args.image:
        # Context manager closes the underlying file once the RGB copy is made.
        with Image.open(args.image) as img:
            image = img.convert("RGB")
    else:
        image = Image.new("RGB", (224, 224), color=(255, 255, 255))
        print("[no --image; using a 224x224 white square]")

    scores = score_triple(model, processor, image, args.expert, args.lay, args.device)
    for k, v in scores.items():
        print(f"  {k:>12s}: {v:.4f}")


# Run inference only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()