File size: 5,244 Bytes
6e89446
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"  # Hugging Face Hub repo id
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 keeps memory low on GPU; CPU inference falls back to full fp32.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Default instruction used when the user submits an empty prompt box.
# The explicit English-only wording works around the model occasionally
# answering in Chinese.
DEFAULT_PROMPT = (
    "Do you see any abnormality in the chest? Write briefly. "
    "If yes, also tell where the abnormality is in which part of the chest. "
    "The chest parts include lungs, heart and vessels, spine, diaphragm, "
    "soft tissues, Mediastinum and bones of chest shown in image. "
    "Respond only in English. Do NOT use any other language. "
    "**Do not use Chinese language.**"
)

# ---------------------------------------------------------------------------
# Load model & processor
# ---------------------------------------------------------------------------
print(f"Loading model: {MODEL_ID}")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto" if torch.cuda.is_available() else None,
)
# device_map="auto" already places weights on the GPU; only the CPU path
# needs an explicit .to() call.
if not torch.cuda.is_available():
    model = model.to(DEVICE)

# Skip video_processor attribute to avoid torchvision dependency
# Temporarily drop "video_processor" from the processor class's component
# list so from_pretrained() does not try to build it (which would import
# torchvision); the class attribute is restored right after loading.
_orig_attrs = Qwen2_5_VLProcessor.attributes[:]
Qwen2_5_VLProcessor.attributes = [a for a in _orig_attrs if a != "video_processor"]
processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)
Qwen2_5_VLProcessor.attributes = _orig_attrs
print("Model loaded successfully.")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def pad_to_square(
    image: Image.Image, fill: tuple[int, int, int] = (0, 0, 0)
) -> Image.Image:
    """Center *image* on a square canvas, padding with *fill*.

    Parameters:
        image: Source image of any size.
        fill: RGB padding color; defaults to black, matching typical
            X-ray backgrounds.

    Returns:
        *image* unchanged when it is already square (no copy made),
        otherwise a new RGB image of size ``max(w, h) x max(w, h)`` with
        the original pasted in the middle.
    """
    width, height = image.size
    if width == height:
        return image  # already square; skip the allocation
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), fill)
    # Integer-centered offsets; odd differences bias one pixel toward
    # the top-left, which is imperceptible for photo-sized inputs.
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))
    return canvas


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def predict(image: Image.Image, prompt: str, max_new_tokens: int, temperature: float):
    """Run the VLM on a chest X-ray and return the generated report text.

    Parameters:
        image: Uploaded X-ray; ``None`` when nothing was uploaded.
        prompt: Instruction text; blank or ``None`` falls back to
            ``DEFAULT_PROMPT``.
        max_new_tokens: Cap on the number of newly generated tokens.
        temperature: Sampling temperature; 0 selects greedy decoding.

    Returns:
        The decoded model response, or an upload reminder when no image
        is provided.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    # The vision tower expects 3-channel input; Gradio can hand over
    # grayscale (L) or RGBA images.
    if image.mode != "RGB":
        image = image.convert("RGB")

    image = pad_to_square(image)

    # A cleared Gradio textbox may arrive as None, not just "" — guard both
    # (prompt.strip() alone would raise AttributeError on None).
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Render the chat template to a string, then tokenize text and image
    # together so the image placeholder tokens line up with pixel inputs.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text], images=[image], return_tensors="pt", padding=True
    ).to(model.device)

    # inference_mode is a stricter, slightly faster no_grad for pure inference.
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=temperature > 0,
            # Some transformers versions reject temperature <= 0 even with
            # sampling disabled; pass a harmless 1.0 on the greedy path.
            temperature=temperature if temperature > 0 else 1.0,
        )

    # Drop the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="Chest X-Ray Analysis — Qwen2.5-VL-3B",
    theme=gr.themes.Soft(),
) as demo:
    # Page header.
    gr.Markdown(
        "# Chest X-Ray Analysis\nUpload a chest X-ray and get an automated report powered by **Qwen2.5-VL-3B-Instruct**."
    )

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=1):
            xray_in = gr.Image(type="pil", label="Upload Chest X-Ray")
            prompt_box = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=4,
            )
            with gr.Row():
                tokens_ctl = gr.Slider(
                    minimum=64, maximum=1024, value=512, step=64,
                    label="Max New Tokens",
                )
                temp_ctl = gr.Slider(
                    minimum=0.0, maximum=1.5, value=0.3, step=0.05,
                    label="Temperature (0 = greedy)",
                )
            run_btn = gr.Button("Analyze", variant="primary")

        # Right column: generated report.
        with gr.Column(scale=1):
            report_box = gr.Textbox(label="Model Report", lines=20)

    # Wire the button to the inference function.
    run_btn.click(
        predict,
        inputs=[xray_in, prompt_box, tokens_ctl, temp_ctl],
        outputs=report_box,
    )

    # Footer disclaimer.
    gr.Markdown(
        "---\n*Research purposes only — not a substitute for professional medical diagnosis.*"
    )

if __name__ == "__main__":
    demo.launch()