ColdSlim committed · Commit ab5e55b · verified · 1 Parent(s): a79b20b

Update app.py

Files changed (1)
  1. app.py +82 -85
app.py CHANGED
@@ -1,8 +1,9 @@
 # app.py
 # Dermatology-AI-Assistant — Hugging Face Space (ZeroGPU-ready)
-# - Uses qwen-vl-utils for vision inputs
+# - First tries your fine-tuned model
+# - If Qwen raises token/feature mismatch, falls back to official base model
 # - Acquires ZeroGPU only during inference
-# - Handles Qwen2-VL token/feature mismatch with a safe fallback retry
+# - Uses qwen-vl-utils.process_vision_info
 
 import os
 import logging
@@ -15,42 +16,47 @@ from PIL import Image
 from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
 from qwen_vl_utils import process_vision_info
 
-# ---------------------------
-# Logging
-# ---------------------------
 logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
 logger = logging.getLogger(__name__)
 
 # ---------------------------
 # Config
 # ---------------------------
-MODEL_ID = os.environ.get("MODEL_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B")
+FT_MODEL_ID = os.environ.get("MODEL_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B")
+BASE_MODEL_ID = os.environ.get("FALLBACK_BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
+
 GEN_KW = dict(
     max_new_tokens=512,
     do_sample=True,
     temperature=0.7,
     top_p=0.9,
 )
+
 ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "180"))
 
-logger.info(f"Loading processor from: {MODEL_ID}")
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+# Preload only the fine-tuned processor on CPU; we may swap to the base processor in the fallback
+logger.info(f"Loading processor from: {FT_MODEL_ID}")
+ft_processor = AutoProcessor.from_pretrained(FT_MODEL_ID, trust_remote_code=True)
+logger.info("Processor loaded.")
 
-# (Optional) Tame resolution to reduce tiling variance; adjust if you like.
-if hasattr(processor, "image_processor"):
-    try:
-        # Keep images within a predictable pixel band so placeholder count is stable.
-        processor.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", "1500000"))  # ~1.5MP
-        processor.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", "262144"))   # 512x512
-    except Exception:
-        pass
+# Optional: stabilize tiling by constraining pixel range (helps placeholder consistency)
+def _tune_image_processor(proc):
+    if hasattr(proc, "image_processor"):
+        try:
+            proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", "1500000"))  # ~1.5MP
+            proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", "262144"))   # 512x512
+        except Exception:
+            pass
 
-logger.info("Processor loaded.")
+_tune_image_processor(ft_processor)
 
 # ---------------------------
 # Helpers
 # ---------------------------
 def _messages(image: Image.Image, question: str):
+    # ensure RGB to avoid mode surprises
+    if image.mode != "RGB":
+        image = image.convert("RGB")
     return [
         {
             "role": "user",
@@ -61,47 +67,28 @@ def _messages(image: Image.Image, question: str):
         }
     ]
 
-def build_inputs(image: Image.Image, question: str, *, disable_splitting: bool = False):
+def build_inputs(processor: AutoProcessor, image: Image.Image, question: str):
     """
-    Build Qwen-style multimodal chat inputs.
-    When disable_splitting is True, we hint the image processor to avoid tiling,
-    which can fix token/feature mismatches for some edge cases.
+    Build Qwen-style multimodal inputs (no padding, batch size 1).
     """
     messages = _messages(image, question)
-
-    # Apply chat template (inserts <image> placeholders automatically)
     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-    # Prepare vision inputs
     image_inputs, video_inputs = process_vision_info(messages)
-
-    # Optionally force-disable splitting (fallback path)
-    if disable_splitting and hasattr(processor, "image_processor"):
-        ip = processor.image_processor
-        # Cache old setting to not mutate global defaults permanently
-        prev = getattr(ip, "do_image_splitting", None)
-        try:
-            if hasattr(ip, "do_image_splitting"):
-                ip.do_image_splitting = False
-            inputs = processor(
-                text=[text],
-                images=image_inputs,
-                videos=video_inputs,
-                return_tensors="pt",  # <- no padding for single-sample path
-            )
-        finally:
-            if prev is not None:
-                ip.do_image_splitting = prev
-    else:
-        inputs = processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            return_tensors="pt",  # <- no padding to avoid mask quirks
-        )
-
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        return_tensors="pt",  # no padding for single sample
+    )
     return inputs
 
+def _pad_token_id(processor, model):
+    # Prefer tokenizer.eos if present; else model config; else 0
+    tid = getattr(getattr(processor, "tokenizer", None), "eos_token_id", None)
+    if tid is not None:
+        return tid
+    return getattr(getattr(model, "config", None), "eos_token_id", 0)
+
 def format_derm_disclaimer(ans: str) -> str:
     tail = (
         "\n\n---\n"
@@ -110,69 +97,81 @@ def format_derm_disclaimer(ans: str) -> str:
     )
     return ans + tail
 
+def _generate_text(model, processor, inputs: dict) -> str:
+    # move to CUDA
+    inputs = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+    with torch.no_grad():
+        out_ids = model.generate(
+            **inputs,
+            **GEN_KW,
+            pad_token_id=_pad_token_id(processor, model),
+        )
+    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
+    text = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    return text
+
 # ---------------------------
 # Inference (ZeroGPU)
 # ---------------------------
 @spaces.GPU(duration=ZGPU_DURATION)
 def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
     """
-    Runs inside a ZeroGPU reservation window.
-    Loads model on GPU, generates, frees VRAM.
-    Includes a fallback retry if Qwen raises a token/feature mismatch.
+    Try fine-tuned model first; on token/feature mismatch, fall back to base model+processor.
     """
     if image is None:
         return "❌ Please upload an image first."
 
+    model = None
     try:
-        logger.info(f"Loading model on GPU: {MODEL_ID}")
+        # ------- Attempt 1: Fine-tuned model -------
+        logger.info(f"Loading fine-tuned model on GPU: {FT_MODEL_ID}")
         model = Qwen2VLForConditionalGeneration.from_pretrained(
-            MODEL_ID,
+            FT_MODEL_ID,
             torch_dtype=torch.float16,
             device_map="cuda",
             trust_remote_code=True,
             low_cpu_mem_usage=True,
-            ignore_mismatched_sizes=True,  # keep until weights align perfectly
+            ignore_mismatched_sizes=True,  # your FT ckpt logs suggest some vision head diffs
         )
-        logger.info("Model loaded successfully!")
-
-        def _run_infer(disable_splitting: bool = False) -> str:
-            inputs = build_inputs(image, question, disable_splitting=disable_splitting)
-            # Move tensors to CUDA
-            inputs = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
-            with torch.no_grad():
-                out_ids = model.generate(
-                    **inputs,
-                    **GEN_KW,
-                    pad_token_id=processor.tokenizer.eos_token_id,
-                )
-            # Strip prompt tokens before decoding
-            trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
-            text = processor.batch_decode(
-                trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-            )[0]
-            return text
-
-        # First attempt: normal path
+        logger.info("Fine-tuned model loaded.")
+        inputs = build_inputs(ft_processor, image, question)
         try:
-            text = _run_infer(disable_splitting=False)
+            text = _generate_text(model, ft_processor, inputs)
+            return format_derm_disclaimer(text)
         except ValueError as ve:
             msg = str(ve)
-            # Known Qwen2-VL edge case: token/feature mismatch — retry with splitting disabled
             if "Image features and image tokens do not match" in msg:
-                logger.warning("Token/feature mismatch detected; retrying with image splitting disabled.")
-                text = _run_infer(disable_splitting=True)
+                logger.warning("Token/feature mismatch on fine-tuned model; falling back to base model.")
             else:
                 raise
 
-        # Free VRAM
+        # ------- Attempt 2: Base model & its processor -------
+        # Free FT model first
         del model
         torch.cuda.empty_cache()
 
+        logger.info(f"Loading BASE model on GPU: {BASE_MODEL_ID}")
+        base_processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
+        _tune_image_processor(base_processor)
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            BASE_MODEL_ID,
+            torch_dtype=torch.float16,
+            device_map="cuda",
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+        )
+        logger.info("Base model loaded.")
+        base_inputs = build_inputs(base_processor, image, question)
+        text = _generate_text(model, base_processor, base_inputs)
         return format_derm_disclaimer(text)
 
     except Exception as e:
         logger.exception("Error during inference")
         return f"❌ Error analyzing image: {e}"
+    finally:
+        # Re-bind instead of `del` so `model` stays defined on every path
+        model = None
+        torch.cuda.empty_cache()
 
 # ---------------------------
 # UI
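
The retry logic in this hunk boils down to one pattern: catch only the single known-bad ValueError, re-raise everything else, and run the second attempt outside the except block so its own failures surface normally. A stripped-down illustration (primary and fallback are hypothetical callables):

def run_with_fallback(primary, fallback):
    try:
        return primary()
    except ValueError as ve:
        # Swallow only the known Qwen2-VL mismatch; anything else propagates.
        if "Image features and image tokens do not match" not in str(ve):
            raise
    return fallback()  # runs outside the except handler, so its errors are not masked
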
@@ -201,9 +200,7 @@ def create_interface() -> gr.Blocks:
     submit_btn.click(fn=analyze_skin_condition, inputs=[image_input, question_input], outputs=output_box, queue=True)
     clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])
 
-    # Gradio 4.44.1: simple queue() call (no kwargs)
     demo.queue()
-
     gr.Markdown("Tips: Ensure good lighting and focus. Avoid uploading personally identifying information.")
     return demo
 
@@ -216,7 +213,7 @@ def main():
         show_error=True,
         inbrowser=False,
         quiet=False,
-        ssr_mode=False,  # avoid Node 20 requirement in container
+        ssr_mode=False,
     )
 
 if __name__ == "__main__":
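
For reference, every runtime knob this commit touches is an environment variable; the names come straight from the diff, and the values below are just the defaults (they must be set before app.py is imported):

import os

os.environ.setdefault("MODEL_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B")         # fine-tuned checkpoint (FT_MODEL_ID)
os.environ.setdefault("FALLBACK_BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")  # mismatch fallback (BASE_MODEL_ID)
os.environ.setdefault("ZGPU_DURATION", "180")        # ZeroGPU reservation window, in seconds
os.environ.setdefault("QWEN_MAX_PIXELS", "1500000")  # ~1.5 MP tiling upper bound
os.environ.setdefault("QWEN_MIN_PIXELS", "262144")   # 512x512 lower bound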