aqibmumtazbits committed on
Commit
6e89446
·
verified ·
1 Parent(s): ed8812d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

# Resolve the runtime target once; bfloat16 is only worthwhile on GPU.
_HAS_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if _HAS_CUDA else "cpu"
DTYPE = torch.bfloat16 if _HAS_CUDA else torch.float32

# Instruction shown to the model when the user leaves the prompt box empty.
DEFAULT_PROMPT = (
    "Do you see any abnormality in the chest? Write briefly. "
    "If yes, also tell where the abnormality is in which part of the chest. "
    "The chest parts include lungs, heart and vessels, spine, diaphragm, "
    "soft tissues, Mediastinum and bones of chest shown in image. "
    "Respond only in English. Do NOT use any other language. "
    "**Do not use Chinese language.**"
)
# ---------------------------------------------------------------------------
# Load model & processor
# ---------------------------------------------------------------------------
print(f"Loading model: {MODEL_ID}")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    # device_map="auto" shards/places on GPU automatically; on CPU we place
    # the model ourselves below.
    device_map="auto" if torch.cuda.is_available() else None,
)
if not torch.cuda.is_available():
    model = model.to(DEVICE)

# Temporarily remove "video_processor" from the processor's attribute list so
# that loading it does not pull in the optional torchvision dependency; the
# original list is restored immediately afterwards.
_saved_attrs = list(Qwen2_5_VLProcessor.attributes)
Qwen2_5_VLProcessor.attributes = [
    attr for attr in _saved_attrs if attr != "video_processor"
]
processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)
Qwen2_5_VLProcessor.attributes = _saved_attrs
print("Model loaded successfully.")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def pad_to_square(image: Image.Image) -> Image.Image:
    """Center *image* on a black square canvas sized to its longer edge.

    Already-square images are returned unchanged (no copy is made).
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    canvas = Image.new("RGB", (side, side), (0, 0, 0))
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return canvas
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def predict(image: Image.Image, prompt: str, max_new_tokens: int, temperature: float):
    """Run Qwen2.5-VL on a chest X-ray and return the generated report.

    Parameters
    ----------
    image : PIL.Image.Image | None
        Uploaded X-ray; ``None`` when the user has not uploaded anything.
    prompt : str | None
        Instruction text; falls back to ``DEFAULT_PROMPT`` when empty or None.
    max_new_tokens : int
        Generation length cap (cast to int — sliders may deliver floats).
    temperature : float
        Sampling temperature; 0 means greedy decoding.

    Returns
    -------
    str
        Decoded model output, or a hint string when no image was provided.
    """
    if image is None:
        return "Please upload a chest X-ray image."

    # The processor expects RGB; X-rays frequently arrive as grayscale ("L").
    if image.mode != "RGB":
        image = image.convert("RGB")

    image = pad_to_square(image)

    # Fix: Gradio can hand over None (not just "") for a cleared textbox;
    # the bare `prompt.strip()` would raise AttributeError in that case.
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text], images=[image], return_tensors="pt", padding=True
    ).to(model.device)

    # inference_mode is a stricter, slightly faster no_grad for pure inference.
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=temperature > 0,
            # transformers warns on temperature=0; pass a neutral value when
            # greedy decoding is requested (it is ignored with do_sample=False).
            temperature=temperature if temperature > 0 else 1.0,
        )

    # Strip the echoed prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# NOTE: the Blocks object must stay named `demo` — Hugging Face Spaces looks
# for that variable when auto-launching the app.
with gr.Blocks(
    title="Chest X-Ray Analysis — Qwen2.5-VL-3B",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        "# Chest X-Ray Analysis\n"
        "Upload a chest X-ray and get an automated report "
        "powered by **Qwen2.5-VL-3B-Instruct**."
    )

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=1):
            xray_image = gr.Image(type="pil", label="Upload Chest X-Ray")
            user_prompt = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=4,
            )
            with gr.Row():
                tokens_ctl = gr.Slider(
                    minimum=64, maximum=1024, value=512, step=64,
                    label="Max New Tokens",
                )
                temp_ctl = gr.Slider(
                    minimum=0.0, maximum=1.5, value=0.3, step=0.05,
                    label="Temperature (0 = greedy)",
                )
            analyze_btn = gr.Button("Analyze", variant="primary")

        # Right column: the generated report.
        with gr.Column(scale=1):
            report_box = gr.Textbox(label="Model Report", lines=20)

    analyze_btn.click(
        predict,
        inputs=[xray_image, user_prompt, tokens_ctl, temp_ctl],
        outputs=report_box,
    )

    gr.Markdown(
        "---\n"
        "*Research purposes only — not a substitute for professional medical diagnosis.*"
    )

if __name__ == "__main__":
    demo.launch()