Manik Sheokand commited on
Commit
421600f
·
1 Parent(s): bdbb866

Fix: Replace Qwen2VLForConditionalGeneration with AutoModelForCausalLM and update transformers to 4.44.0

Browse files
Files changed (3) hide show
  1. app.py +233 -196
  2. requirements.txt +3 -3
  3. runtime.txt +1 -0
app.py CHANGED
@@ -1,220 +1,257 @@
1
- # app.py
2
- # Dermatology-AI-Assistant — Hugging Face Space (ZeroGPU-ready)
3
- # - First tries your fine-tuned model
4
- # - If Qwen raises token/feature mismatch, falls back to official base model
5
- # - Acquires ZeroGPU only during inference
6
- # - Uses qwen-vl-utils.process_vision_info
7
-
8
- import os
9
- import logging
10
- from typing import Optional
11
-
12
- import gradio as gr
13
  import spaces
 
14
  import torch
 
15
  from PIL import Image
16
- from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
17
- from qwen_vl_utils import process_vision_info
18
-
19
- logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
20
- logger = logging.getLogger(__name__)
21
-
22
- # ---------------------------
23
- # Config
24
- # ---------------------------
25
- FT_MODEL_ID = os.environ.get("MODEL_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B")
26
- BASE_MODEL_ID = os.environ.get("FALLBACK_BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
27
-
28
- GEN_KW = dict(
29
- max_new_tokens=512,
30
- do_sample=True,
31
- temperature=0.7,
32
- top_p=0.9,
33
- )
34
-
35
- ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "180"))
36
-
37
- # Preload only the fine-tuned processor on CPU; we may swap to base processor in the fallback
38
- logger.info(f"Loading processor from: {FT_MODEL_ID}")
39
- ft_processor = AutoProcessor.from_pretrained(FT_MODEL_ID, trust_remote_code=True)
40
- logger.info("Processor loaded.")
41
-
42
- # Optional: stabilize tiling by constraining pixel range (helps placeholder consistency)
43
- def _tune_image_processor(proc):
44
- if hasattr(proc, "image_processor"):
45
- try:
46
- proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", "1500000")) # ~1.5MP
47
- proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", "262144")) # 512x512
48
- except Exception:
49
- pass
50
-
51
- _tune_image_processor(ft_processor)
52
-
53
- # ---------------------------
54
- # Helpers
55
- # ---------------------------
56
- def _messages(image: Image.Image, question: str):
57
- # ensure RGB to avoid mode surprises
58
- if image.mode != "RGB":
59
- image = image.convert("RGB")
60
- return [
61
- {
62
- "role": "user",
63
- "content": [
64
- {"type": "image", "image": image},
65
- {"type": "text", "text": question},
66
- ],
67
- }
68
- ]
69
-
70
- def build_inputs(processor: AutoProcessor, image: Image.Image, question: str):
71
- """
72
- Build Qwen-style multimodal inputs (no padding, batch size 1).
73
- """
74
- messages = _messages(image, question)
75
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
76
- image_inputs, video_inputs = process_vision_info(messages)
77
- inputs = processor(
78
- text=[text],
79
- images=image_inputs,
80
- videos=video_inputs,
81
- return_tensors="pt", # no padding for single sample
82
- )
83
- return inputs
84
 
85
- def _pad_token_id(processor, model):
86
- # Prefer tokenizer.eos if present; else model config; else 0
87
- tid = getattr(getattr(processor, "tokenizer", None), "eos_token_id", None)
88
- if tid is not None:
89
- return tid
90
- return getattr(getattr(model, "config", None), "eos_token_id", 0)
 
 
 
 
 
91
 
92
- def format_derm_disclaimer(ans: str) -> str:
93
- tail = (
94
- "\n\n---\n"
95
- "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
96
- "Consult a qualified dermatologist for diagnosis and treatment._"
97
- )
98
- return ans + tail
99
 
100
- def _generate_text(model, processor, inputs: dict) -> str:
101
- # move to CUDA
102
- inputs = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
103
- with torch.no_grad():
104
- out_ids = model.generate(
105
- **inputs,
106
- **GEN_KW,
107
- pad_token_id=_pad_token_id(processor, model),
108
- )
109
- trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
110
- text = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
111
- return text
112
 
113
- # ---------------------------
114
- # Inference (ZeroGPU)
115
- # ---------------------------
116
- @spaces.GPU(duration=ZGPU_DURATION)
117
- def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
118
- """
119
- Try fine-tuned model first; on token/feature mismatch, fall back to base model+processor.
120
- """
121
- if image is None:
122
- return "❌ Please upload an image first."
123
 
124
- model = None
 
 
 
125
  try:
126
- # ------- Attempt 1: Fine-tuned model -------
127
- logger.info(f"Loading fine-tuned model on GPU: {FT_MODEL_ID}")
128
- model = Qwen2VLForConditionalGeneration.from_pretrained(
129
- FT_MODEL_ID,
130
- torch_dtype=torch.float16,
131
- device_map="cuda",
 
 
 
132
  trust_remote_code=True,
133
  low_cpu_mem_usage=True,
134
- ignore_mismatched_sizes=True, # your FT ckpt logs suggest some vision head diffs
135
  )
136
- logger.info("Fine-tuned model loaded.")
137
- inputs = build_inputs(ft_processor, image, question)
138
- try:
139
- text = _generate_text(model, ft_processor, inputs)
140
- return format_derm_disclaimer(text)
141
- except ValueError as ve:
142
- msg = str(ve)
143
- if "Image features and image tokens do not match" in msg:
144
- logger.warning("Token/feature mismatch on fine-tuned model — falling back to base model.")
145
- else:
146
- raise
147
-
148
- # ------- Attempt 2: Base model & its processor -------
149
- # Free FT model first
150
- del model
151
- torch.cuda.empty_cache()
152
-
153
- logger.info(f"Loading BASE model on GPU: {BASE_MODEL_ID}")
154
- base_processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
155
- _tune_image_processor(base_processor)
156
- model = Qwen2VLForConditionalGeneration.from_pretrained(
157
- BASE_MODEL_ID,
158
- torch_dtype=torch.float16,
159
- device_map="cuda",
160
- trust_remote_code=True,
161
- low_cpu_mem_usage=True,
162
- )
163
- logger.info("Base model loaded.")
164
- base_inputs = build_inputs(base_processor, image, question)
165
- text = _generate_text(model, base_processor, base_inputs)
166
- return format_derm_disclaimer(text)
167
-
168
  except Exception as e:
169
- logger.exception("Error during inference")
170
- return f"❌ Error analyzing image: {e}"
171
- finally:
172
- if model is not None:
173
- del model
174
- torch.cuda.empty_cache()
175
 
176
- # ---------------------------
177
- # UI
178
- # ---------------------------
179
- def create_interface() -> gr.Blocks:
180
- with gr.Blocks(title="Dermatology AI Assistant") as demo:
181
- gr.Markdown(
182
- "# Dermatology AI Assistant\n"
183
- "Upload a skin photo and ask a question. The model will provide an informational response."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  )
185
-
186
- with gr.Row():
187
- image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
188
- question_input = gr.Textbox(
189
- label="Question / Prompt",
190
- value="Describe this skin condition in detail and suggest possible next steps.",
191
- lines=3,
 
 
 
 
 
 
192
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  with gr.Row():
195
- submit_btn = gr.Button("Analyze", variant="primary")
196
- clear_btn = gr.Button("Clear")
197
-
198
- output_box = gr.Textbox(label="Response", lines=16)
199
-
200
- submit_btn.click(fn=analyze_skin_condition, inputs=[image_input, question_input], outputs=output_box, queue=True)
201
- clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])
202
-
203
- demo.queue()
204
- gr.Markdown("Tips: Ensure good lighting and focus. Avoid uploading personally identifying information.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  return demo
206
 
 
207
  def main():
208
- demo = create_interface()
209
- demo.launch(
210
- server_name="0.0.0.0",
211
- server_port=7860,
212
- share=False,
213
- show_error=True,
214
- inbrowser=False,
215
- quiet=False,
216
- ssr_mode=False,
217
- )
 
 
 
 
 
218
 
219
  if __name__ == "__main__":
220
  main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import spaces
2
+ import gradio as gr
3
  import torch
4
+ from transformers import AutoProcessor, AutoModelForCausalLM
5
  from PIL import Image
6
+ import logging
7
+ import subprocess
8
+ import sys
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Force Gradio update if needed
11
+ def ensure_gradio_version():
12
+ try:
13
+ import pkg_resources
14
+ current_version = pkg_resources.get_distribution("gradio").version
15
+ if current_version.startswith("4.0"):
16
+ logger.warning(f"Detected old Gradio version {current_version}, attempting to upgrade...")
17
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "gradio==4.44.1"])
18
+ logger.info("Gradio upgrade completed")
19
+ except Exception as e:
20
+ logger.warning(f"Could not check/upgrade Gradio: {e}")
21
 
22
+ # Check and upgrade Gradio if needed
23
+ ensure_gradio_version()
 
 
 
 
 
24
 
25
+ # Configure logging
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
28
 
29
+ # Global variables for model and processor
30
+ model = None
31
+ processor = None
 
 
 
 
 
 
 
32
 
33
+ def load_model():
34
+ """Load the fine-tuned dermatology model"""
35
+ global model, processor
36
+
37
  try:
38
+ # Load the merged model (replace with your actual model path)
39
+ model_name = "ColdSlim/Dermatology-Qwen2.5-VL-3B" # Update with your actual model name
40
+
41
+ logger.info(f"Loading model: {model_name}")
42
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
43
+ model = AutoModelForCausalLM.from_pretrained(
44
+ model_name,
45
+ dtype=torch.bfloat16,
46
+ device_map="auto",
47
  trust_remote_code=True,
48
  low_cpu_mem_usage=True,
49
+ ignore_mismatched_sizes=True
50
  )
51
+
52
+ logger.info("Model loaded successfully!")
53
+ return True
54
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  except Exception as e:
56
+ logger.error(f"Error loading model: {e}")
57
+ return False
 
 
 
 
58
 
59
+ def analyze_skin_condition(image, question="Describe this skin condition in detail."):
60
+ """Analyze skin condition from uploaded image"""
61
+ global model, processor
62
+
63
+ if model is None or processor is None:
64
+ return "❌ Model not loaded. Please wait for the model to load or contact the administrator."
65
+
66
+ if image is None:
67
+ return "❌ Please upload an image first."
68
+
69
+ try:
70
+ # Prepare the conversation
71
+ messages = [
72
+ {
73
+ "role": "user",
74
+ "content": [
75
+ {"type": "image", "image": image},
76
+ {"type": "text", "text": question}
77
+ ]
78
+ }
79
+ ]
80
+
81
+ # Process the input
82
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
83
+ image_inputs, video_inputs = processor.process_vision_info(messages)
84
+
85
+ inputs = processor(
86
+ text=[text],
87
+ images=image_inputs,
88
+ videos=video_inputs,
89
+ padding=True,
90
+ return_tensors="pt"
91
  )
92
+
93
+ # Move inputs to the same device as model
94
+ inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
95
+
96
+ # Generate response
97
+ with torch.no_grad():
98
+ generated_ids = model.generate(
99
+ **inputs,
100
+ max_new_tokens=512,
101
+ do_sample=True,
102
+ temperature=0.7,
103
+ top_p=0.9,
104
+ pad_token_id=processor.tokenizer.eos_token_id
105
  )
106
+
107
+ # Decode the response
108
+ generated_ids_trimmed = [
109
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
110
+ ]
111
+ output_text = processor.batch_decode(
112
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
113
+ )[0]
114
+
115
+ return output_text
116
+
117
+ except Exception as e:
118
+ logger.error(f"Error during inference: {e}")
119
+ return f"❌ Error analyzing image: {str(e)}"
120
 
121
+ def create_interface():
122
+ """Create the Gradio interface"""
123
+
124
+ # Load model on startup
125
+ model_loaded = load_model()
126
+
127
+ with gr.Blocks(
128
+ title="Dermatology AI Assistant",
129
+ theme=gr.themes.Soft(),
130
+ css="""
131
+ .gradio-container {
132
+ max-width: 1200px !important;
133
+ margin: auto !important;
134
+ }
135
+ .main-header {
136
+ text-align: center;
137
+ margin-bottom: 2rem;
138
+ }
139
+ .warning-box {
140
+ background-color: #fff3cd;
141
+ border: 1px solid #ffeaa7;
142
+ border-radius: 8px;
143
+ padding: 1rem;
144
+ margin: 1rem 0;
145
+ }
146
+ """
147
+ ) as demo:
148
+
149
+ gr.HTML("""
150
+ <div class="main-header">
151
+ <h1>🩺 Dermatology AI Assistant</h1>
152
+ <p>Powered by Qwen2.5-VL-3B fine-tuned for dermatology analysis</p>
153
+ </div>
154
+ """)
155
+
156
+ # Warning message
157
+ gr.HTML("""
158
+ <div class="warning-box">
159
+ <h3>⚠️ Medical Disclaimer</h3>
160
+ <p>This AI assistant is for educational and research purposes only.
161
+ It should not be used as a substitute for professional medical advice,
162
+ diagnosis, or treatment. Always consult with a qualified healthcare
163
+ provider for medical concerns.</p>
164
+ </div>
165
+ """)
166
+
167
  with gr.Row():
168
+ with gr.Column(scale=1):
169
+ # Image upload
170
+ image_input = gr.Image(
171
+ label="Upload Skin Image",
172
+ type="pil",
173
+ height=400
174
+ )
175
+
176
+ # Question input
177
+ question_input = gr.Textbox(
178
+ label="Question (Optional)",
179
+ placeholder="Describe this skin condition in detail.",
180
+ value="Describe this skin condition in detail.",
181
+ lines=3
182
+ )
183
+
184
+ # Analyze button
185
+ analyze_btn = gr.Button(
186
+ "🔍 Analyze Skin Condition",
187
+ variant="primary",
188
+ size="lg"
189
+ )
190
+
191
+ # Example questions
192
+ gr.HTML("""
193
+ <h4>💡 Example Questions:</h4>
194
+ <ul>
195
+ <li>What type of skin condition is this?</li>
196
+ <li>Describe the characteristics of this lesion.</li>
197
+ <li>What are the potential causes of this skin issue?</li>
198
+ <li>What should I know about this skin condition?</li>
199
+ </ul>
200
+ """)
201
+
202
+ with gr.Column(scale=1):
203
+ # Output
204
+ output_text = gr.Textbox(
205
+ label="AI Analysis",
206
+ lines=15,
207
+ max_lines=20,
208
+ show_copy_button=True
209
+ )
210
+
211
+ # Examples
212
+ gr.Examples(
213
+ examples=[
214
+ ["What type of skin condition is this?", "Describe this skin condition in detail."],
215
+ ["What are the characteristics of this lesion?", "Describe this skin condition in detail."],
216
+ ["What should I know about this skin issue?", "Describe this skin condition in detail."],
217
+ ],
218
+ inputs=[question_input, question_input],
219
+ label="💡 Example Questions"
220
+ )
221
+
222
+ # Event handlers
223
+ analyze_btn.click(
224
+ fn=analyze_skin_condition,
225
+ inputs=[image_input, question_input],
226
+ outputs=output_text
227
+ )
228
+
229
+ # Model status
230
+ if model_loaded:
231
+ gr.HTML("<div style='text-align: center; color: green;'>✅ Model loaded successfully!</div>")
232
+ else:
233
+ gr.HTML("<div style='text-align: center; color: red;'>❌ Model loading failed. Please check the logs.</div>")
234
+
235
  return demo
236
 
237
+ @spaces.GPU
238
  def main():
239
+ """Main function with GPU decorator for Hugging Face Spaces"""
240
+ try:
241
+ # Create and launch the interface
242
+ demo = create_interface()
243
+ demo.launch(
244
+ server_name="0.0.0.0",
245
+ server_port=7860,
246
+ share=False,
247
+ show_error=True,
248
+ inbrowser=False,
249
+ quiet=False
250
+ )
251
+ except Exception as e:
252
+ logger.error(f"Error launching app: {e}")
253
+ raise
254
 
255
  if __name__ == "__main__":
256
  main()
257
+
requirements.txt CHANGED
@@ -3,8 +3,8 @@
3
  # Core dependencies
4
  torch>=2.0.0
5
  torchvision>=0.15.0
6
- transformers==4.44.2
7
- accelerate>=0.34.2
8
  gradio==4.44.1
9
  huggingface_hub>=0.20.0
10
  spaces
@@ -14,7 +14,7 @@ Pillow>=9.0.0
14
  opencv-python>=4.5.0
15
 
16
  # Qwen2-VL specific
17
- qwen-vl-utils>=0.0.8
18
 
19
  # Optional: For better performance
20
  flash-attn>=2.0.0
 
3
  # Core dependencies
4
  torch>=2.0.0
5
  torchvision>=0.15.0
6
+ transformers>=4.44.0
7
+ accelerate>=0.20.0
8
  gradio==4.44.1
9
  huggingface_hub>=0.20.0
10
  spaces
 
14
  opencv-python>=4.5.0
15
 
16
  # Qwen2-VL specific
17
+ qwen-vl-utils>=0.0.1
18
 
19
  # Optional: For better performance
20
  flash-attn>=2.0.0
runtime.txt CHANGED
@@ -1 +1,2 @@
1
  python-3.10
 
 
1
  python-3.10
2
+