ColdSlim committed on
Commit
ed31cbb
·
verified ·
1 Parent(s): 421600f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -233
app.py CHANGED
@@ -1,257 +1,220 @@
1
- import spaces
 
 
 
 
 
 
 
 
 
 
2
  import gradio as gr
 
3
  import torch
4
- from transformers import AutoProcessor, AutoModelForCausalLM
5
  from PIL import Image
6
- import logging
7
- import subprocess
8
- import sys
9
 
10
- # Force Gradio update if needed
11
- def ensure_gradio_version():
12
- try:
13
- import pkg_resources
14
- current_version = pkg_resources.get_distribution("gradio").version
15
- if current_version.startswith("4.0"):
16
- logger.warning(f"Detected old Gradio version {current_version}, attempting to upgrade...")
17
- subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "gradio==4.44.1"])
18
- logger.info("Gradio upgrade completed")
19
- except Exception as e:
20
- logger.warning(f"Could not check/upgrade Gradio: {e}")
21
 
22
- # Check and upgrade Gradio if needed
23
- ensure_gradio_version()
 
 
 
24
 
25
- # Configure logging
26
- logging.basicConfig(level=logging.INFO)
27
- logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Global variables for model and processor
30
- model = None
31
- processor = None
32
 
33
- def load_model():
34
- """Load the fine-tuned dermatology model"""
35
- global model, processor
36
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  try:
38
- # Load the merged model (replace with your actual model path)
39
- model_name = "ColdSlim/Dermatology-Qwen2.5-VL-3B" # Update with your actual model name
40
-
41
- logger.info(f"Loading model: {model_name}")
42
- processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
43
- model = AutoModelForCausalLM.from_pretrained(
44
- model_name,
45
- dtype=torch.bfloat16,
46
- device_map="auto",
47
  trust_remote_code=True,
48
  low_cpu_mem_usage=True,
49
- ignore_mismatched_sizes=True
50
  )
51
-
52
- logger.info("Model loaded successfully!")
53
- return True
54
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  except Exception as e:
56
- logger.error(f"Error loading model: {e}")
57
- return False
 
 
 
 
58
 
59
- def analyze_skin_condition(image, question="Describe this skin condition in detail."):
60
- """Analyze skin condition from uploaded image"""
61
- global model, processor
62
-
63
- if model is None or processor is None:
64
- return "❌ Model not loaded. Please wait for the model to load or contact the administrator."
65
-
66
- if image is None:
67
- return "❌ Please upload an image first."
68
-
69
- try:
70
- # Prepare the conversation
71
- messages = [
72
- {
73
- "role": "user",
74
- "content": [
75
- {"type": "image", "image": image},
76
- {"type": "text", "text": question}
77
- ]
78
- }
79
- ]
80
-
81
- # Process the input
82
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
83
- image_inputs, video_inputs = processor.process_vision_info(messages)
84
-
85
- inputs = processor(
86
- text=[text],
87
- images=image_inputs,
88
- videos=video_inputs,
89
- padding=True,
90
- return_tensors="pt"
91
  )
92
-
93
- # Move inputs to the same device as model
94
- inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
95
-
96
- # Generate response
97
- with torch.no_grad():
98
- generated_ids = model.generate(
99
- **inputs,
100
- max_new_tokens=512,
101
- do_sample=True,
102
- temperature=0.7,
103
- top_p=0.9,
104
- pad_token_id=processor.tokenizer.eos_token_id
105
  )
106
-
107
- # Decode the response
108
- generated_ids_trimmed = [
109
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
110
- ]
111
- output_text = processor.batch_decode(
112
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
113
- )[0]
114
-
115
- return output_text
116
-
117
- except Exception as e:
118
- logger.error(f"Error during inference: {e}")
119
- return f"❌ Error analyzing image: {str(e)}"
120
 
121
- def create_interface():
122
- """Create the Gradio interface"""
123
-
124
- # Load model on startup
125
- model_loaded = load_model()
126
-
127
- with gr.Blocks(
128
- title="Dermatology AI Assistant",
129
- theme=gr.themes.Soft(),
130
- css="""
131
- .gradio-container {
132
- max-width: 1200px !important;
133
- margin: auto !important;
134
- }
135
- .main-header {
136
- text-align: center;
137
- margin-bottom: 2rem;
138
- }
139
- .warning-box {
140
- background-color: #fff3cd;
141
- border: 1px solid #ffeaa7;
142
- border-radius: 8px;
143
- padding: 1rem;
144
- margin: 1rem 0;
145
- }
146
- """
147
- ) as demo:
148
-
149
- gr.HTML("""
150
- <div class="main-header">
151
- <h1>🩺 Dermatology AI Assistant</h1>
152
- <p>Powered by Qwen2.5-VL-3B fine-tuned for dermatology analysis</p>
153
- </div>
154
- """)
155
-
156
- # Warning message
157
- gr.HTML("""
158
- <div class="warning-box">
159
- <h3>⚠️ Medical Disclaimer</h3>
160
- <p>This AI assistant is for educational and research purposes only.
161
- It should not be used as a substitute for professional medical advice,
162
- diagnosis, or treatment. Always consult with a qualified healthcare
163
- provider for medical concerns.</p>
164
- </div>
165
- """)
166
-
167
  with gr.Row():
168
- with gr.Column(scale=1):
169
- # Image upload
170
- image_input = gr.Image(
171
- label="Upload Skin Image",
172
- type="pil",
173
- height=400
174
- )
175
-
176
- # Question input
177
- question_input = gr.Textbox(
178
- label="Question (Optional)",
179
- placeholder="Describe this skin condition in detail.",
180
- value="Describe this skin condition in detail.",
181
- lines=3
182
- )
183
-
184
- # Analyze button
185
- analyze_btn = gr.Button(
186
- "🔍 Analyze Skin Condition",
187
- variant="primary",
188
- size="lg"
189
- )
190
-
191
- # Example questions
192
- gr.HTML("""
193
- <h4>💡 Example Questions:</h4>
194
- <ul>
195
- <li>What type of skin condition is this?</li>
196
- <li>Describe the characteristics of this lesion.</li>
197
- <li>What are the potential causes of this skin issue?</li>
198
- <li>What should I know about this skin condition?</li>
199
- </ul>
200
- """)
201
-
202
- with gr.Column(scale=1):
203
- # Output
204
- output_text = gr.Textbox(
205
- label="AI Analysis",
206
- lines=15,
207
- max_lines=20,
208
- show_copy_button=True
209
- )
210
-
211
- # Examples
212
- gr.Examples(
213
- examples=[
214
- ["What type of skin condition is this?", "Describe this skin condition in detail."],
215
- ["What are the characteristics of this lesion?", "Describe this skin condition in detail."],
216
- ["What should I know about this skin issue?", "Describe this skin condition in detail."],
217
- ],
218
- inputs=[question_input, question_input],
219
- label="💡 Example Questions"
220
- )
221
-
222
- # Event handlers
223
- analyze_btn.click(
224
- fn=analyze_skin_condition,
225
- inputs=[image_input, question_input],
226
- outputs=output_text
227
- )
228
-
229
- # Model status
230
- if model_loaded:
231
- gr.HTML("<div style='text-align: center; color: green;'>✅ Model loaded successfully!</div>")
232
- else:
233
- gr.HTML("<div style='text-align: center; color: red;'>❌ Model loading failed. Please check the logs.</div>")
234
-
235
  return demo
236
 
237
- @spaces.GPU
238
  def main():
239
- """Main function with GPU decorator for Hugging Face Spaces"""
240
- try:
241
- # Create and launch the interface
242
- demo = create_interface()
243
- demo.launch(
244
- server_name="0.0.0.0",
245
- server_port=7860,
246
- share=False,
247
- show_error=True,
248
- inbrowser=False,
249
- quiet=False
250
- )
251
- except Exception as e:
252
- logger.error(f"Error launching app: {e}")
253
- raise
254
 
255
  if __name__ == "__main__":
256
  main()
257
-
 
1
+ # app.py
2
+ # Dermatology-AI-Assistant — Hugging Face Space (ZeroGPU-ready)
3
+ # - First tries your fine-tuned model
4
+ # - If Qwen raises token/feature mismatch, falls back to official base model
5
+ # - Acquires ZeroGPU only during inference
6
+ # - Uses qwen-vl-utils.process_vision_info
7
+
8
+ import os
9
+ import logging
10
+ from typing import Optional
11
+
12
  import gradio as gr
13
+ import spaces
14
  import torch
 
15
  from PIL import Image
16
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
17
+ from qwen_vl_utils import process_vision_info
 
18
 
19
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
20
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
21
 
22
+ # ---------------------------
23
+ # Config
24
+ # ---------------------------
25
# Checkpoint ids; both can be overridden through environment variables.
FT_MODEL_ID = os.environ.get("MODEL_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B")
BASE_MODEL_ID = os.environ.get("FALLBACK_BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")

# Sampling settings forwarded verbatim to `model.generate`.
GEN_KW = dict(
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)

# Passed as `duration` to the `spaces.GPU` decorator on the inference function.
ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "180"))

# Preload only the fine-tuned processor on CPU; we may swap to base processor in the fallback
logger.info(f"Loading processor from: {FT_MODEL_ID}")
ft_processor = AutoProcessor.from_pretrained(FT_MODEL_ID, trust_remote_code=True)
logger.info("Processor loaded.")
41
+
42
+ # Optional: stabilize tiling by constraining pixel range (helps placeholder consistency)
43
def _tune_image_processor(proc):
    """Constrain the pixel range of ``proc.image_processor`` (best effort).

    The limits come from the ``QWEN_MAX_PIXELS`` / ``QWEN_MIN_PIXELS``
    environment variables, defaulting to 1500000 and 262144.

    Both values are parsed before either is applied, so a malformed variable
    can no longer leave the processor with only one of the two limits set.
    Failures are logged instead of being silently swallowed; the function
    never raises for them.

    Args:
        proc: Any object; it is left untouched when it has no
            ``image_processor`` attribute (or that attribute is ``None``).
    """
    image_processor = getattr(proc, "image_processor", None)
    if image_processor is None:
        return

    log = logging.getLogger(__name__)
    try:
        max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", "1500000"))  # ~1.5MP
        min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", "262144"))  # 512x512
    except ValueError:
        log.warning(
            "Ignoring QWEN_MAX_PIXELS/QWEN_MIN_PIXELS: values must be integers",
            exc_info=True,
        )
        return

    try:
        image_processor.max_pixels = max_pixels
        image_processor.min_pixels = min_pixels
    except Exception:
        # Deliberately best effort: an image processor that rejects these
        # attributes must not prevent the app from starting.
        log.warning("Could not apply pixel limits to the image processor", exc_info=True)
50
 
51
+ _tune_image_processor(ft_processor)
 
 
52
 
53
+ # ---------------------------
54
+ # Helpers
55
+ # ---------------------------
56
def _messages(image: Image.Image, question: str):
    """Wrap an image and a question into a single-turn chat message list.

    Args:
        image: Image to analyse; converted to RGB when it is in another mode.
        question: Text placed after the image in the user turn.

    Returns:
        A one-element list holding a ``"user"`` message whose content is the
        image entry followed by the text entry.
    """
    # ensure RGB to avoid mode surprises
    if image.mode != "RGB":
        image = image.convert("RGB")
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        }
    ]
69
+
70
def build_inputs(processor: AutoProcessor, image: Image.Image, question: str):
    """Build Qwen-style multimodal inputs (no padding, batch size 1).

    Args:
        processor: Processor used both for the chat template and for encoding.
        image: Image to analyse.
        question: User prompt.

    Returns:
        Whatever ``processor(...)`` returns when called with
        ``return_tensors="pt"`` on the templated text and the vision inputs
        extracted by ``process_vision_info``.
    """
    messages = _messages(image, question)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        return_tensors="pt",  # no padding for single sample
    )
    return inputs
84
+
85
def _pad_token_id(processor, model):
    """Pick the token id to pass to ``generate`` as ``pad_token_id``.

    Lookup order: ``processor.tokenizer.eos_token_id``, then
    ``model.config.eos_token_id``, then ``0``.

    A candidate that is ``None`` is skipped, so the final ``0`` fallback is
    actually reached when the config defines the attribute as ``None``. A
    list/tuple of ids is reduced to its first element, because a single id is
    expected here.

    Args:
        processor: Object that may expose ``tokenizer.eos_token_id``.
        model: Object that may expose ``config.eos_token_id``.

    Returns:
        The selected token id, or ``0`` when none is available.
    """
    tokenizer = getattr(processor, "tokenizer", None)
    config = getattr(model, "config", None)
    candidates = (
        getattr(tokenizer, "eos_token_id", None),
        getattr(config, "eos_token_id", None),
    )
    for candidate in candidates:
        if isinstance(candidate, (list, tuple)):
            candidate = candidate[0] if candidate else None
        if candidate is not None:
            return candidate
    return 0
91
+
92
def format_derm_disclaimer(ans: str) -> str:
    """Append the medical disclaimer footer to a model answer.

    Args:
        ans: Generated answer. ``None`` is treated as an empty answer so the
            caller still receives the disclaimer instead of a ``TypeError``.

    Returns:
        ``ans`` followed by a horizontal rule and the italic disclaimer.
    """
    tail = (
        "\n\n---\n"
        "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
        "Consult a qualified dermatologist for diagnosis and treatment._"
    )
    body = "" if ans is None else ans
    return body + tail
99
+
100
def _generate_text(model, processor, inputs: dict) -> str:
    """Run generation for one prepared sample and return the decoded reply.

    Args:
        model: Model exposing ``generate``.
        processor: Processor exposing ``batch_decode``; also consulted for the
            pad token id.
        inputs: Mapping of encoded inputs; must contain ``"input_ids"``.

    Returns:
        The decoded text of the first sequence, with the prompt tokens cut off
        and special tokens skipped.
    """
    # Follow the model's own device rather than assuming "cuda"; fall back to
    # "cuda" only when the model does not report one.
    device = getattr(model, "device", None)
    if device is None:
        device = "cuda"
    inputs = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in inputs.items()
    }
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            **GEN_KW,
            pad_token_id=_pad_token_id(processor, model),
        )
    # Each output row starts with its prompt; keep only the newly generated part.
    trimmed = [out[len(prompt):] for prompt, out in zip(inputs["input_ids"], out_ids)]
    text = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return text
112
+
113
+ # ---------------------------
114
+ # Inference (ZeroGPU)
115
+ # ---------------------------
116
def _qwen_vl_model_class():
    """Return the model class to instantiate for the configured checkpoints.

    Both configured ids name Qwen2.5-VL checkpoints, so the Qwen2.5-VL class is
    preferred. If the installed transformers release does not provide it, the
    Qwen2-VL class already imported by this file is used instead.
    """
    try:
        from transformers import Qwen2_5_VLForConditionalGeneration
    except ImportError:
        logger.warning(
            "transformers has no Qwen2_5_VLForConditionalGeneration; "
            "falling back to Qwen2VLForConditionalGeneration."
        )
        return Qwen2VLForConditionalGeneration
    return Qwen2_5_VLForConditionalGeneration


def _load_model_on_gpu(model_id: str, **extra_kwargs):
    """Load ``model_id`` in float16 on CUDA; ``extra_kwargs`` go to ``from_pretrained``."""
    return _qwen_vl_model_class().from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="cuda",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        **extra_kwargs,
    )


@spaces.GPU(duration=ZGPU_DURATION)
def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
    """Answer ``question`` about ``image``.

    Try fine-tuned model first; on token/feature mismatch, fall back to base
    model+processor.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        question: User prompt.

    Returns:
        The generated answer with the disclaimer appended, or a message
        starting with "❌" when no image was given or inference failed. Errors
        are logged and reported in the return value; they are not raised.
    """
    if image is None:
        return "❌ Please upload an image first."

    model = None
    try:
        # ------- Attempt 1: Fine-tuned model -------
        logger.info(f"Loading fine-tuned model on GPU: {FT_MODEL_ID}")
        model = _load_model_on_gpu(
            FT_MODEL_ID,
            ignore_mismatched_sizes=True,  # your FT ckpt logs suggest some vision head diffs
        )
        logger.info("Fine-tuned model loaded.")
        inputs = build_inputs(ft_processor, image, question)
        try:
            text = _generate_text(model, ft_processor, inputs)
            return format_derm_disclaimer(text)
        except ValueError as ve:
            if "Image features and image tokens do not match" not in str(ve):
                raise
            logger.warning("Token/feature mismatch on fine-tuned model — falling back to base model.")

        # ------- Attempt 2: Base model & its processor -------
        # Free FT model first. Rebind to None rather than `del`: the name must
        # stay bound so the `finally` clause still works if loading the base
        # model (or its processor) fails.
        model = None
        torch.cuda.empty_cache()

        logger.info(f"Loading BASE model on GPU: {BASE_MODEL_ID}")
        base_processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
        _tune_image_processor(base_processor)
        model = _load_model_on_gpu(BASE_MODEL_ID)
        logger.info("Base model loaded.")
        base_inputs = build_inputs(base_processor, image, question)
        text = _generate_text(model, base_processor, base_inputs)
        return format_derm_disclaimer(text)

    except Exception as e:
        logger.exception("Error during inference")
        return f"❌ Error analyzing image: {e}"
    finally:
        # Drop the last reference before asking CUDA to release cached blocks.
        model = None
        torch.cuda.empty_cache()
175
 
176
+ # ---------------------------
177
+ # UI
178
+ # ---------------------------
179
def create_interface() -> gr.Blocks:
    """Build the Gradio UI and wire its buttons.

    Layout: a heading, a row with the image upload and the question box, a row
    with the "Analyze" and "Clear" buttons, the response box and a tips line.

    "Analyze" calls ``analyze_skin_condition`` with the image and question and
    writes the result into the response box. "Clear" resets the image, the
    question and the response, so no stale answer stays on screen.

    Returns:
        The configured ``gr.Blocks`` app with queuing enabled.
    """
    with gr.Blocks(title="Dermatology AI Assistant") as demo:
        gr.Markdown(
            "# Dermatology AI Assistant\n"
            "Upload a skin photo and ask a question. The model will provide an informational response."
        )

        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
            question_input = gr.Textbox(
                label="Question / Prompt",
                value="Describe this skin condition in detail and suggest possible next steps.",
                lines=3,
            )

        with gr.Row():
            submit_btn = gr.Button("Analyze", variant="primary")
            clear_btn = gr.Button("Clear")

        output_box = gr.Textbox(label="Response", lines=16)

        submit_btn.click(
            fn=analyze_skin_condition,
            inputs=[image_input, question_input],
            outputs=output_box,
            queue=True,
        )
        clear_btn.click(
            fn=lambda: (None, "", ""),
            inputs=None,
            outputs=[image_input, question_input, output_box],
        )

        gr.Markdown("Tips: Ensure good lighting and focus. Avoid uploading personally identifying information.")

    # Enable queuing once the layout is complete.
    demo.queue()
    return demo
206
 
 
def main():
    """Build the interface and serve it on 0.0.0.0:7860."""
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        inbrowser=False,
        quiet=False,
        ssr_mode=False,
    )


if __name__ == "__main__":
    main()