ColdSlim commited on
Commit
49e8446
·
verified ·
1 Parent(s): 8b9a9ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -42
app.py CHANGED
@@ -1,13 +1,13 @@
1
  # app.py
2
- # Dermatology-AI-Assistant—HF Spaces (ZeroGPU)
3
- # - Tries fine-tuned model first; on load/mismatch errors, falls back to base
4
  # - Uses qwen-vl-utils for vision preprocessing
5
- # - Acquires ZeroGPU only during inference
6
  # - No runtime pip; pin versions in requirements.txt
7
 
8
  import os
9
  import logging
10
- from typing import Optional
11
 
12
  import gradio as gr
13
  import spaces
@@ -34,7 +34,7 @@ GEN_KW = dict(
34
 
35
  ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "180"))
36
 
37
- # Preload only FT processor on CPU; we may swap to base processor in fallback
38
  logger.info(f"Loading processor from: {FT_MODEL_ID}")
39
  ft_processor = AutoProcessor.from_pretrained(FT_MODEL_ID, trust_remote_code=True)
40
  logger.info("Processor loaded.")
@@ -67,7 +67,7 @@ def build_inputs(processor: AutoProcessor, image: Image.Image, question: str):
67
  messages = _messages(image, question)
68
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
69
  image_inputs, video_inputs = process_vision_info(messages)
70
- # no padding for single sample to avoid mask quirks
71
  inputs = processor(text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt")
72
  return inputs
73
 
@@ -98,65 +98,73 @@ def format_derm_disclaimer(ans: str) -> str:
98
  )
99
  return ans + tail
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # ---------------------------
102
  # Inference (ZeroGPU)
103
  # ---------------------------
104
  @spaces.GPU(duration=ZGPU_DURATION)
105
  def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
106
  """
107
- Try fine-tuned model first; if load or token/feature mismatch occurs, fall back to base model+processor.
108
  """
109
  if image is None:
110
  return "❌ Please upload an image first."
111
 
112
- model = None # ensure defined for finally block
113
  try:
114
- # ------- Attempt 1: Fine-tuned model -------
115
- try:
116
- logger.info(f"Loading fine-tuned model on GPU: {FT_MODEL_ID}")
117
- model = Qwen2VLForConditionalGeneration.from_pretrained(
118
- FT_MODEL_ID,
119
- torch_dtype=torch.float16,
120
- device_map="cuda",
121
- trust_remote_code=True,
122
- low_cpu_mem_usage=True,
123
- ignore_mismatched_sizes=True, # allow partial head diffs
124
- # offload_state_dict can help with odd shards during load
125
- offload_state_dict=True,
126
- )
127
- logger.info("Fine-tuned model loaded.")
128
- inputs = build_inputs(ft_processor, image, question)
129
  try:
 
130
  text = _generate_text(model, ft_processor, inputs)
131
  return format_derm_disclaimer(text)
132
  except ValueError as ve:
133
- # Qwen2-VL edge case: placeholder token vs feature mismatch
134
  if "Image features and image tokens do not match" in str(ve):
135
- logger.warning("Token/feature mismatch on FT model — switching to base model.")
136
  else:
137
- raise
138
- except Exception as e:
139
- # Any FT load error (e.g., Linear size mismatch) triggers fallback
140
- logger.warning(f"Fine-tuned model load failed: {e}. Falling back to base model.")
141
-
142
- # ------- Attempt 2: Base model & its processor -------
143
- # Free FT before loading base
 
144
  if model is not None:
145
  del model
146
  model = None
147
  torch.cuda.empty_cache()
148
 
149
- logger.info(f"Loading BASE model on GPU: {BASE_MODEL_ID}")
150
  base_processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
151
  _tune_image_processor(base_processor)
152
- model = Qwen2VLForConditionalGeneration.from_pretrained(
153
- BASE_MODEL_ID,
154
- torch_dtype=torch.float16,
155
- device_map="cuda",
156
- trust_remote_code=True,
157
- low_cpu_mem_usage=True,
158
- )
159
- logger.info("Base model loaded.")
160
  base_inputs = build_inputs(base_processor, image, question)
161
  text = _generate_text(model, base_processor, base_inputs)
162
  return format_derm_disclaimer(text)
@@ -196,7 +204,7 @@ def create_interface() -> gr.Blocks:
196
  submit_btn.click(fn=analyze_skin_condition, inputs=[image_input, question_input], outputs=output_box, queue=True)
197
  clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])
198
 
199
- # Gradio 4.44.1: simple queue call, no kwargs
200
  demo.queue()
201
 
202
  gr.Markdown("Tips: Ensure good lighting and focus. Avoid uploading personally identifying information.")
 
1
  # app.py
2
+ # Dermatology-AI-Assistant HF Spaces (ZeroGPU)
3
+ # - Robust FT->Base fallback on ANY model load error (incl. Linear size mismatch)
4
  # - Uses qwen-vl-utils for vision preprocessing
5
+ # - ZeroGPU only during inference
6
  # - No runtime pip; pin versions in requirements.txt
7
 
8
  import os
9
  import logging
10
+ from typing import Optional, Tuple
11
 
12
  import gradio as gr
13
  import spaces
 
34
 
35
  ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "180"))
36
 
37
+ # Preload only the FT processor on CPU (we may swap to base processor if we fall back)
38
  logger.info(f"Loading processor from: {FT_MODEL_ID}")
39
  ft_processor = AutoProcessor.from_pretrained(FT_MODEL_ID, trust_remote_code=True)
40
  logger.info("Processor loaded.")
 
67
  messages = _messages(image, question)
68
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
69
  image_inputs, video_inputs = process_vision_info(messages)
70
+ # single-sample: no padding to avoid mask quirks
71
  inputs = processor(text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt")
72
  return inputs
73
 
 
98
  )
99
  return ans + tail
100
 
101
+ def try_load_model(model_id: str, *, allow_mismatch: bool = True) -> Tuple[Optional[Qwen2VLForConditionalGeneration], Optional[str]]:
102
+ """
103
+ Attempt to load a Qwen2-VL model. Return (model_or_None, error_message_or_None).
104
+ Any exception is captured and returned instead of bubbling up.
105
+ """
106
+ try:
107
+ logger.info(f"Loading model on GPU: {model_id}")
108
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
109
+ model_id,
110
+ torch_dtype=torch.float16,
111
+ device_map="cuda",
112
+ trust_remote_code=True,
113
+ low_cpu_mem_usage=True,
114
+ ignore_mismatched_sizes=allow_mismatch, # let FT load even if some heads differ
115
+ offload_state_dict=True, # helps load large shards reliably
116
+ )
117
+ logger.info(f"Model loaded: {model_id}")
118
+ return model, None
119
+ except Exception as e:
120
+ logger.warning(f"Model load failed for {model_id}: {e}")
121
+ return None, str(e)
122
+
123
  # ---------------------------
124
  # Inference (ZeroGPU)
125
  # ---------------------------
126
  @spaces.GPU(duration=ZGPU_DURATION)
127
  def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
128
  """
129
+ Try FT model first; on ANY load error (e.g., Linear size mismatch), fall back to base model+processor.
130
  """
131
  if image is None:
132
  return "❌ Please upload an image first."
133
 
134
+ model = None
135
  try:
136
+ # Attempt 1: fine-tuned model
137
+ model, ft_err = try_load_model(FT_MODEL_ID, allow_mismatch=True)
138
+ if model is not None:
 
 
 
 
 
 
 
 
 
 
 
 
139
  try:
140
+ inputs = build_inputs(ft_processor, image, question)
141
  text = _generate_text(model, ft_processor, inputs)
142
  return format_derm_disclaimer(text)
143
  except ValueError as ve:
 
144
  if "Image features and image tokens do not match" in str(ve):
145
+ logger.warning("Token/feature mismatch on FT model — falling back to base.")
146
  else:
147
+ # Unexpected generation error on FT; fall back anyway
148
+ logger.warning(f"FT generation error: {ve}. Falling back to base.")
149
+ except Exception as gen_e:
150
+ logger.warning(f"FT generation failed: {gen_e}. Falling back to base.")
151
+ else:
152
+ logger.warning(f"FT model unavailable, error: {ft_err}. Falling back to base.")
153
+
154
+ # Free FT model (if any) before loading base
155
  if model is not None:
156
  del model
157
  model = None
158
  torch.cuda.empty_cache()
159
 
160
+ # Attempt 2: base model + its processor
161
  base_processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
162
  _tune_image_processor(base_processor)
163
+ model, base_err = try_load_model(BASE_MODEL_ID, allow_mismatch=False)
164
+ if model is None:
165
+ # Both loads failed — report combined error
166
+ return f"❌ Error loading models.\n- FT: {ft_err}\n- BASE: {base_err}"
167
+
 
 
 
168
  base_inputs = build_inputs(base_processor, image, question)
169
  text = _generate_text(model, base_processor, base_inputs)
170
  return format_derm_disclaimer(text)
 
204
  submit_btn.click(fn=analyze_skin_condition, inputs=[image_input, question_input], outputs=output_box, queue=True)
205
  clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])
206
 
207
+ # Gradio 4.44.1: simple queue() call, no kwargs
208
  demo.queue()
209
 
210
  gr.Markdown("Tips: Ensure good lighting and focus. Avoid uploading personally identifying information.")