ColdSlim committed
Commit e4aafad · verified · 1 Parent(s): 2157291

Update app.py

Files changed (1)
  1. app.py +67 -38
app.py CHANGED
@@ -1,9 +1,10 @@
 # app.py
 # Dermatology-AI-Assistant — HF Spaces (ZeroGPU, Qwen2.5-VL + LoRA adapters)
-# - Loads base model, then applies LoRA/PEFT adapters from MODEL_ID, merges, and runs multimodal inference
+# - Loads base model, then applies LoRA/PEFT adapters from MODEL_ID (kept active; not merged)
 # - Uses qwen-vl-utils + AutoProcessor (multimodal) with trust_remote_code, use_fast=False
 # - Deterministic decoding for stable eval
-# - ZeroGPU only during inference
+# - ZeroGPU only during inference (ALL CUDA work happens inside @spaces.GPU functions)
+# - Includes a ZeroGPU-safe debug tool: "LoRA ON vs OFF" comparison

 import os
 import logging
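
Note: the header promises deterministic decoding, but GEN_KW itself sits outside the changed hunks. As a hedged sketch only (the names and values below are assumptions, not the file's actual contents), a deterministic transformers generation config typically looks like:

    # Hypothetical GEN_KW; greedy decoding keeps eval output reproducible.
    GEN_KW = dict(
        do_sample=False,     # greedy search, no temperature/top-p sampling
        max_new_tokens=512,  # assumed cap; the real value is not in this diff
    )
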
@@ -13,7 +14,7 @@ import gradio as gr
 import spaces
 import torch
 from PIL import Image
-from peft import PeftModel  # <-- LoRA/PEFT
+from peft import PeftModel  # LoRA/PEFT
 from transformers import AutoProcessor, AutoModelForVision2Seq
 from qwen_vl_utils import process_vision_info

@@ -47,7 +48,7 @@ def _load_multimodal_processor() -> AutoProcessor:
     accepts_images = ("images" in str(sig)) if sig else hasattr(proc, "image_processor")
     if accepts_images and hasattr(proc, "image_processor"):
         logger.info(f"Loaded multimodal processor from: {mid} ({proc.__class__.__name__})")
-        # optional: stabilize tiling
+        # Optional: stabilize tiling
         try:
             proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", "1500000"))
             proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", "262144"))
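
Note: max_pixels/min_pixels bound the resized image area, and Qwen2/2.5-VL spends roughly one visual token per 28×28 pixels after patch merging, so the env-var defaults above imply an approximate per-image token budget:

    # Back-of-envelope visual-token bounds for the defaults above
    # (assumes the usual 28x28-pixels-per-token ratio of Qwen2/2.5-VL).
    print(1_500_000 // (28 * 28))  # ~1913 tokens max per image
    print(262_144 // (28 * 28))    # ~334 tokens min per image
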
@@ -85,6 +86,10 @@ def build_inputs(image: Image.Image, question: str):
     messages = _messages(image, question)
     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
+    logger.info(
+        f"vision: images={len(image_inputs) if image_inputs is not None else 0}, "
+        f"first_shape={getattr(image_inputs[0], 'shape', None) if image_inputs else None}"
+    )
     return processor(text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt")

 def _pad_token_id(model):
@@ -92,6 +97,7 @@ def _pad_token_id(model):
     return tid if tid is not None else (getattr(getattr(model, "config", None), "eos_token_id", 0) or 0)

 def _generate_text(model, inputs: dict) -> str:
+    # IMPORTANT: This is called only inside GPU-decorated functions.
     inputs = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
     with torch.no_grad():
         out_ids = model.generate(**inputs, **GEN_KW, pad_token_id=_pad_token_id(model))
@@ -109,6 +115,7 @@ def format_derm_disclaimer(ans: str) -> str:

 # ---------------------------
 # Model loading (LoRA first, then full weights fallback, then base)
+# NOTE: Do NOT call this outside a @spaces.GPU function, because it loads to CUDA.
 # ---------------------------
 def try_load_model():
     """
@@ -127,7 +134,13 @@ def try_load_model():
         )
         logger.info(f"Attaching LoRA adapters from: {FT_MODEL_ID}")
         model = PeftModel.from_pretrained(base, FT_MODEL_ID, is_trainable=False)
-        # IMPORTANT: do not merge here; keep adapters active so we can toggle on/off for debugging
+        # Log adapter visibility
+        try:
+            if hasattr(model, "get_active_adapters"):
+                logger.info(f"Active adapters: {model.get_active_adapters()}")
+            logger.info(f"PEFT config present: {hasattr(model, 'peft_config')}")
+        except Exception:
+            pass
         logger.info("LoRA adapters attached and active (not merged).")
         model.eval()
         return model, None
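
Note: the commit deliberately keeps the adapters unmerged so they can be toggled for the ON/OFF debug below. For contrast, the opposite design (slightly faster inference, but no toggling) would be PEFT's merge path, sketched here reusing the diff's own names:

    from peft import PeftModel

    # Alternative this commit avoids: bake the LoRA weights into the base model.
    # merge_and_unload() returns a plain transformers model, so the adapter
    # toggles used by compare_with_without_lora() would no longer exist.
    model = PeftModel.from_pretrained(base, FT_MODEL_ID, is_trainable=False)
    model = model.merge_and_unload()
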
@@ -150,7 +163,7 @@ def try_load_model():
     except Exception as e:
         logger.warning(f"Full FT load failed: {e}")

-    # 3) Final fallback: base only
+    # 3) Final fallback: base only (keep app usable)
     try:
         logger.info("Falling back to BASE model only.")
         model = AutoModelForVision2Seq.from_pretrained(
@@ -167,19 +180,27 @@ def try_load_model():

 def compare_with_without_lora(model, inputs):
     """
-    Returns (with_lora_text, without_lora_text). Requires PeftModel with adapters active.
+    Returns (with_lora_text, without_lora_text).
+    Requires adapters active. Tries disable/enable; falls back to set_adapter([]) if available.
     """
-    # Generate WITH LoRA (normal path)
+    # WITH LoRA
     with_lora = _generate_text(model, inputs)

-    # If model supports toggling adapters, compare WITHOUT LoRA
+    # WITHOUT LoRA
     without_lora = "[Adapters could not be toggled on this model]"
-    if hasattr(model, "disable_adapter") and hasattr(model, "enable_adapter"):
-        try:
+    try:
+        if hasattr(model, "disable_adapter") and hasattr(model, "enable_adapter"):
             model.disable_adapter()
             without_lora = _generate_text(model, inputs)
-        finally:
             model.enable_adapter()
+        elif hasattr(model, "set_adapter"):
+            current = model.get_active_adapters() if hasattr(model, "get_active_adapters") else None
+            model.set_adapter([])  # deactivate all
+            without_lora = _generate_text(model, inputs)
+            if current:
+                model.set_adapter(current)
+    except Exception as e:
+        logger.warning(f"Adapter toggle failed: {e}")

     return with_lora, without_lora
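
Note: in recent PEFT releases, PeftModel.disable_adapter() is a context manager rather than one half of a paired on/off switch, so if the hasattr branch above ever misbehaves, the idiomatic form of this comparison is closer to:

    # Hedged sketch: disable_adapter() as a context manager; adapters are
    # re-enabled automatically when the block exits. model, inputs, and
    # _generate_text are the names from the surrounding code.
    with_lora = _generate_text(model, inputs)
    with model.disable_adapter():
        without_lora = _generate_text(model, inputs)
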
 
@@ -193,7 +214,7 @@ def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
     model = None
     try:
         inputs = build_inputs(image, question)
-        model, warn = try_load_model()
+        model, warn = try_load_model()  # SAFE: inside GPU context
         if model is None:
             return "❌ Could not load any model (see logs)."
         if warn:
@@ -208,6 +229,34 @@ def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
         del model
         torch.cuda.empty_cache()

+# ---------------------------
+# Debug (ZeroGPU-safe): LoRA ON vs OFF comparison
+# ---------------------------
+@spaces.GPU(duration=ZGPU_DURATION)
+def debug_compare_lora(image: Optional[Image.Image], question: str) -> str:
+    if image is None:
+        return "Please upload an image first."
+    model = None
+    try:
+        inputs = build_inputs(image, question)
+        model, warn = try_load_model()  # SAFE: inside GPU context
+        if model is None:
+            return f"Load error: {warn}"
+        if warn:
+            logger.warning(warn)
+        on_text, off_text = compare_with_without_lora(model, inputs)
+        return (
+            "=== LoRA ON ===\n" + on_text +
+            "\n\n=== LoRA OFF ===\n" + off_text
+        )
+    except Exception as e:
+        logger.exception("Debug compare failed")
+        return f"Debug error: {e}"
+    finally:
+        if model is not None:
+            del model
+        torch.cuda.empty_cache()
+
 # ---------------------------
 # UI
 # ---------------------------
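
Note: debug_compare_lora mirrors analyze_skin_condition on purpose: it reloads the model inside the GPU context on every click and frees it in finally. On ZeroGPU, caching a CUDA-resident model across calls is unreliable because the GPU is only attached per decorated call, so the reload cost appears to be the intended trade-off for quota-safe behavior.
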
@@ -232,34 +281,14 @@ def create_interface() -> gr.Blocks:
         submit_btn.click(fn=analyze_skin_condition, inputs=[image_input, question_input], outputs=output_box, queue=True)
         clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])

-        demo.queue()
-        gr.Markdown("Tips: Ensure good lighting and focus. Avoid uploading personally identifying information.")
-
+        # Debug: LoRA ON vs OFF (GPU-decorated function)
         with gr.Row():
             debug_btn = gr.Button("Debug: Compare LoRA ON vs OFF")
             debug_out = gr.Textbox(label="Debug Output", lines=14)
-
-        def _debug_compare(image, question):
-            if image is None:
-                return "Please upload an image first."
-            try:
-                inputs = build_inputs(image, question)
-                model, warn = try_load_model()
-                if model is None:
-                    return f"Load error: {warn}"
-                if warn:
-                    logger.warning(warn)
-                on_text, off_text = compare_with_without_lora(model, inputs)
-                return (
-                    "=== LoRA ON ===\n" + on_text +
-                    "\n\n=== LoRA OFF ===\n" + off_text
-                )
-            except Exception as e:
-                logger.exception("Debug compare failed")
-                return f"Debug error: {e}"
-
-        debug_btn.click(_debug_compare, [image_input, question_input], debug_out, queue=True)
+        debug_btn.click(fn=debug_compare_lora, inputs=[image_input, question_input], outputs=debug_out, queue=True)

+    demo.queue()
+    gr.Markdown("Tips: Ensure good lighting and focus. Avoid uploading personally identifying information.")
     return demo

 def main():
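
Note: depending on indentation (lost in this view), the relocated demo.queue() and gr.Markdown(...) may now sit outside the with gr.Blocks() context; a gr.Markdown created outside any Blocks context is not rendered, so if the tips text disappears from the page it likely needs to move back inside the layout.
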
@@ -271,7 +300,7 @@ def main():
         show_error=True,
         inbrowser=False,
         quiet=False,
-        ssr_mode=False,
+        ssr_mode=False,  # avoid Node requirement in container
     )

 if __name__ == "__main__":
 