ColdSlim committed · Commit a2e0d44 · verified · 1 Parent(s): 8ac03b2

Update app.py

Files changed (1): app.py (+106 -108)
app.py CHANGED
@@ -1,10 +1,9 @@
  # app.py
- # Dermatology-AI-Assistant — HF Spaces (ZeroGPU, Qwen2.5-VL multimodal)
- # - GUARANTEES multimodal: loads processor from base with trust_remote_code + use_fast=False
- # - Asserts processor supports images at startup (clear error if deps are wrong)
- # - Tries FT model first; falls back to base model on load/generation issues
- # - Uses qwen-vl-utils for vision inputs
- # - ZeroGPU only during inference; no runtime pip installs
+ # Dermatology-AI-Assistant — HF Spaces (ZeroGPU, Qwen2.5-VL + LoRA adapters)
+ # - Loads base model, then applies LoRA/PEFT adapters from MODEL_ID, merges, and runs multimodal inference
+ # - Uses qwen-vl-utils + AutoProcessor (multimodal) with trust_remote_code, use_fast=False
+ # - Deterministic decoding for stable eval
+ # - ZeroGPU only during inference
  
  import os
  import logging
@@ -14,6 +13,7 @@ import gradio as gr
  import spaces
  import torch
  from PIL import Image
+ from peft import PeftModel  # <-- LoRA/PEFT
  from transformers import AutoProcessor, AutoModelForVision2Seq
  from qwen_vl_utils import process_vision_info
  
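The added `from peft import PeftModel` means the Space's requirements must now list `peft` alongside the stack the old startup error message named (transformers>=4.56.1, qwen-vl-utils>=0.0.10, torch>=2.2.0). A minimal requirements.txt sketch; the peft pin is an assumption, not taken from this commit:

    # requirements.txt (sketch; peft pin is illustrative)
    transformers>=4.56.1
    qwen-vl-utils>=0.0.10
    torch>=2.2.0
    peft>=0.11.0   # assumed; any recent PEFT release with merge_and_unload()
    gradio
    spaces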
@@ -23,65 +23,65 @@ logger = logging.getLogger(__name__)
  # ---------------------------
  # Config
  # ---------------------------
- FT_MODEL_ID = os.environ.get("MODEL_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B")
+ FT_MODEL_ID = os.environ.get("MODEL_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B")  # LoRA adapters repo
  BASE_MODEL_ID = os.environ.get("FALLBACK_BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
  
  GEN_KW = dict(
-     max_new_tokens=512,
-     do_sample=True,
-     temperature=0.7,
-     top_p=0.9,
+     max_new_tokens=256,
+     do_sample=False,  # deterministic for evaluation
+     temperature=0.0,
+     top_p=1.0,
  )
  
  ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "180"))
  
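With do_sample=False, transformers runs greedy decoding and ignores temperature and top_p (recent versions typically log a warning that these flags are unused). A functionally identical sketch without the ignored knobs:

    GEN_KW = dict(
        max_new_tokens=256,
        do_sample=False,  # greedy decoding; temperature/top_p would be ignored anyway
    )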
  # ---------------------------
- # Load MULTIMODAL processor from BASE (NOT FT) and validate it
+ # Processor (try FT first; fall back to base). Must be multimodal.
  # ---------------------------
- logger.info(f"Loading processor from base model (multimodal expected): {BASE_MODEL_ID}")
- processor = AutoProcessor.from_pretrained(
-     BASE_MODEL_ID,
-     trust_remote_code=True,
-     use_fast=False,  # critical: ensure multimodal __call__ supports images/videos
- )
- logger.info(f"Processor class: {processor.__class__.__name__}")
-
- # Validate that processor can handle images
- proc_sig = getattr(processor.__call__, "__signature__", None)
- accepts_images = ("images" in str(proc_sig)) if proc_sig else hasattr(processor, "image_processor")
- if not accepts_images or not hasattr(processor, "image_processor"):
-     raise RuntimeError(
-         "Loaded processor is not multimodal. Ensure requirements include: "
-         "transformers>=4.56.1, qwen-vl-utils>=0.0.10, torch>=2.2.0, and do a Factory reboot."
-     )
-
- # Optional: stabilize tiling/token placeholders
- if hasattr(processor, "image_processor"):
-     try:
-         processor.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", "1500000"))  # ~1.5MP
-         processor.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", "262144"))  # 512x512
-     except Exception:
-         pass
+ def _load_multimodal_processor() -> AutoProcessor:
+     tried = []
+     for mid in (FT_MODEL_ID, BASE_MODEL_ID):
+         try:
+             proc = AutoProcessor.from_pretrained(mid, trust_remote_code=True, use_fast=False)
+             sig = getattr(proc.__call__, "__signature__", None)
+             accepts_images = ("images" in str(sig)) if sig else hasattr(proc, "image_processor")
+             if accepts_images and hasattr(proc, "image_processor"):
+                 logger.info(f"Loaded multimodal processor from: {mid} ({proc.__class__.__name__})")
+                 # optional: stabilize tiling
+                 try:
+                     proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", "1500000"))
+                     proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", "262144"))
+                 except Exception:
+                     pass
+                 return proc
+             tried.append(f"{mid} => {proc.__class__.__name__} (no images support)")
+         except Exception as e:
+             tried.append(f"{mid} => ERROR: {e}")
+     raise RuntimeError("Failed to load a multimodal processor. Tried:\n" + "\n".join(tried))
+
+ processor = _load_multimodal_processor()
  
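The images-support check is indirect (it inspects __call__'s signature, then falls back to looking for an image_processor attribute). A quick interactive sanity check of the same property, assuming the stock Qwen2.5-VL processor class in transformers:

    from transformers import AutoProcessor

    proc = AutoProcessor.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True, use_fast=False
    )
    print(type(proc).__name__)               # expect a multimodal class, e.g. Qwen2_5_VLProcessor
    print(hasattr(proc, "image_processor"))  # True when image inputs are supported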
  # ---------------------------
  # Helpers
  # ---------------------------
+ SYSTEM_PROMPT = (
+     "You are a dermatology assistant. First, look carefully at the IMAGE.\n"
+     "If the image is NOT a close-up of human skin or a dermatologic lesion, "
+     "respond EXACTLY with: 'The image does not appear to show a skin condition; I cannot analyze it.' "
+     "Do not invent findings.\n"
+     "If it IS a skin/lesion photo, provide a concise description, 3–5 likely differentials, "
+     "and prudent next steps (including red flags). Avoid definitive diagnoses."
+ )
+
  def _messages(image: Image.Image, question: str):
      if image.mode != "RGB":
          image = image.convert("RGB")
-     return [{
-         "role": "user",
-         "content": [
-             {"type": "image", "image": image},
-             {"type": "text", "text": question},
-         ],
-     }]
+     return [
+         {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
+         {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": question}]},
+     ]
  
  def build_inputs(image: Image.Image, question: str):
-     """
-     Build Qwen2.5-VL multimodal inputs using processor + qwen-vl-utils.
-     Single-sample, no padding (reduces placeholder mask edge cases).
-     """
      messages = _messages(image, question)
      text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      image_inputs, video_inputs = process_vision_info(messages)
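The hunk ends before build_inputs returns. For reference, the documented qwen-vl-utils pattern continues roughly as below (a sketch, not the committed code; the removed docstring said single-sample with no padding, so padding=True is omitted):

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        return_tensors="pt",
    )
    return inputs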
@@ -89,98 +89,99 @@ def build_inputs(image: Image.Image, question: str):
  
  def _pad_token_id(model):
      tid = getattr(getattr(processor, "tokenizer", None), "eos_token_id", None)
-     if tid is not None:
-         return tid
-     return getattr(getattr(model, "config", None), "eos_token_id", 0) or 0
+     return tid if tid is not None else (getattr(getattr(model, "config", None), "eos_token_id", 0) or 0)
  
  def _generate_text(model, inputs: dict) -> str:
-     # move tensors to CUDA
      inputs = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
      with torch.no_grad():
-         out_ids = model.generate(
-             **inputs,
-             **GEN_KW,
-             pad_token_id=_pad_token_id(model),
-         )
+         out_ids = model.generate(**inputs, **GEN_KW, pad_token_id=_pad_token_id(model))
      trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
      text = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
      return text
  
  def format_derm_disclaimer(ans: str) -> str:
-     tail = (
-         "\n\n---\n"
-         "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
-         "Consult a qualified dermatologist for diagnosis and treatment._"
+     return (
+         ans
+         + "\n\n---\n"
+         "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
+         "Consult a qualified dermatologist for diagnosis and treatment._"
      )
-     return ans + tail
  
- def try_load_model(model_id: str, *, allow_mismatch: bool):
+ # ---------------------------
+ # Model loading (LoRA first, then full weights fallback, then base)
+ # ---------------------------
+ def try_load_model() -> Tuple[Optional[AutoModelForVision2Seq], Optional[str]]:
      """
-     Load Qwen2.5-VL via AutoModelForVision2Seq with trust_remote_code (multimodal weights).
+     Preferred path: load BASE, then apply LoRA adapters from FT repo, merge, unload.
+     Fallbacks: full FT weights -> pure base.
      """
+     # 1) BASE + LoRA adapters (PEFT)
      try:
-         logger.info(f"Loading model on GPU: {model_id}")
+         logger.info(f"Loading BASE model: {BASE_MODEL_ID}")
+         base = AutoModelForVision2Seq.from_pretrained(
+             BASE_MODEL_ID,
+             torch_dtype=torch.float16,
+             device_map="cuda",
+             trust_remote_code=True,
+             low_cpu_mem_usage=True,
+         )
+         logger.info(f"Attaching LoRA adapters from: {FT_MODEL_ID}")
+         model = PeftModel.from_pretrained(base, FT_MODEL_ID, is_trainable=False)
+         try:
+             model = model.merge_and_unload()
+             logger.info("Merged LoRA adapters into base (inference-optimized).")
+         except Exception as e:
+             logger.info(f"Adapters active without merge (PEFT runtime). Reason: {e}")
+         return model, None
+     except Exception as peft_e:
+         logger.warning(f"PEFT adapters load failed: {peft_e}")
+
+     # 2) Try full FT weights (in case you exported merged weights)
+     try:
+         logger.info(f"Loading full FT weights from: {FT_MODEL_ID}")
          model = AutoModelForVision2Seq.from_pretrained(
-             model_id,
+             FT_MODEL_ID,
              torch_dtype=torch.float16,
              device_map="cuda",
              trust_remote_code=True,
              low_cpu_mem_usage=True,
-             ignore_mismatched_sizes=False,
-             offload_state_dict=False,
+             ignore_mismatched_sizes=False,  # strict: do not silently re-init layers
          )
-         logger.info(f"Model loaded: {model_id} ({model.__class__.__name__})")
          return model, None
      except Exception as e:
-         logger.warning(f"Model load failed for {model_id}: {e}")
-         return None, str(e)
+         logger.warning(f"Full FT load failed: {e}")
+
+     # 3) Final fallback: base only (so app still works)
+     try:
+         logger.info("Falling back to BASE model only.")
+         model = AutoModelForVision2Seq.from_pretrained(
+             BASE_MODEL_ID,
+             torch_dtype=torch.float16,
+             device_map="cuda",
+             trust_remote_code=True,
+             low_cpu_mem_usage=True,
+         )
+         return model, "Using base model only (FT not applied)."
+     except Exception as e:
+         return None, f"Base load failed too: {e}"
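Two notes on try_load_model. First, the new annotation uses Tuple and Optional; the unchanged import block at the top of the file is not shown in this diff, so confirm `from typing import Optional, Tuple` is present, or the module will fail at definition time. Second, the three-step cascade could be short-circuited by checking up front whether FT_MODEL_ID is actually an adapter repo; a hedged sketch using huggingface_hub (an editor suggestion, not part of this commit):

    from huggingface_hub import file_exists

    def looks_like_adapter_repo(repo_id: str) -> bool:
        # PEFT adapter repos ship adapter_config.json instead of full model weights
        try:
            return file_exists(repo_id, "adapter_config.json")
        except Exception:
            return False  # offline/gated repo: fall back to the try/except cascade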
 
  # ---------------------------
  # Inference (ZeroGPU)
  # ---------------------------
  @spaces.GPU(duration=ZGPU_DURATION)
  def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
-     """
-     STRICT multimodal: requires processor with images support (asserted at startup).
-     Try FT model first; on ANY load/generation error, fall back to base model.
-     """
      if image is None:
          return "❌ Please upload an image first."
-
      model = None
      try:
          inputs = build_inputs(image, question)
-
-         # Attempt 1: fine-tuned model
-         model, ft_err = try_load_model(FT_MODEL_ID, allow_mismatch=True)
-         if model is not None:
-             try:
-                 text = _generate_text(model, inputs)
-                 return format_derm_disclaimer(text)
-             except ValueError as ve:
-                 if "Image features and image tokens do not match" in str(ve):
-                     logger.warning("Token/feature mismatch on FT model — falling back to base.")
-                 else:
-                     logger.warning(f"FT generation error: {ve}. Falling back to base.")
-             except Exception as gen_e:
-                 logger.warning(f"FT generation failed: {gen_e}. Falling back to base.")
-         else:
-             logger.warning(f"FT model unavailable, error: {ft_err}. Falling back to base.")
-
-         # Free FT before base
-         if model is not None:
-             del model
-             model = None
-         torch.cuda.empty_cache()
-
-         # Attempt 2: base model
-         model, base_err = try_load_model(BASE_MODEL_ID, allow_mismatch=False)
+         model, warn = try_load_model()
          if model is None:
-             return f"❌ Error loading models.\n- FT: {ft_err}\n- BASE: {base_err}"
-
+             return "❌ Could not load any model (see logs)."
+         if warn:
+             logger.warning(warn)
          text = _generate_text(model, inputs)
          return format_derm_disclaimer(text)
-
      except Exception as e:
          logger.exception("Error during inference")
          return f"❌ Error analyzing image: {e}"
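As committed, try_load_model runs inside the @spaces.GPU handler, so weights are re-instantiated on every request (Hub caching softens but does not remove this). The usual ZeroGPU pattern is to build the model once at import time and enter the GPU only inside the decorated function; a rough sketch under that assumption (it would require loading with a CPU device_map outside the handler, which this commit does not do):

    # sketch: one-time load at startup, GPU transfer per request
    MODEL, LOAD_WARN = try_load_model()  # would need device_map="cpu" at import time

    @spaces.GPU(duration=ZGPU_DURATION)
    def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
        if image is None:
            return "❌ Please upload an image first."
        model = MODEL.to("cuda")  # ZeroGPU attaches the GPU only inside this call
        inputs = build_inputs(image, question)
        return format_derm_disclaimer(_generate_text(model, inputs))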
@@ -198,7 +199,6 @@ def create_interface() -> gr.Blocks:
          "# Dermatology AI Assistant\n"
          "Upload a skin photo and ask a question. The model will provide an informational response."
      )
-
      with gr.Row():
          image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
          question_input = gr.Textbox(
@@ -206,17 +206,15 @@
              value="Describe this skin condition in detail and suggest possible next steps.",
              lines=3,
          )
-
      with gr.Row():
          submit_btn = gr.Button("Analyze", variant="primary")
          clear_btn = gr.Button("Clear")
-
      output_box = gr.Textbox(label="Response", lines=16)
  
      submit_btn.click(fn=analyze_skin_condition, inputs=[image_input, question_input], outputs=output_box, queue=True)
      clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])
  
-     demo.queue()  # Gradio 4.44.1: no kwargs
+     demo.queue()
      gr.Markdown("Tips: Ensure good lighting and focus. Avoid uploading personally identifying information.")
      return demo
  
@@ -229,7 +227,7 @@ def main():
          show_error=True,
          inbrowser=False,
          quiet=False,
-         ssr_mode=False,  # avoid Node requirement in container
+         ssr_mode=False,
      )
  
  if __name__ == "__main__":