Manik Sheokand committed on
Commit
cbc528c
·
1 Parent(s): d5cfeb4

new changes

Browse files
Files changed (3) hide show
  1. .env +4 -0
  2. app.py +3 -3
  3. app.py.backup +324 -0
.env ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Optimized settings for free tier
2
+ ZGPU_DURATION=120
3
+ QWEN_MAX_PIXELS=250880
4
+ QWEN_MIN_PIXELS=200704
app.py CHANGED
@@ -38,11 +38,11 @@ BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
38
  ADAPTER_ID = os.environ.get("ADAPTER_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B-LoRA")
39
 
40
  # Give ourselves more time for first load in cold starts
41
- ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "120")) # seconds
42
 
43
  # Deterministic decoding for eval; tweak as needed
44
  GEN_KW = dict(
45
- max_new_tokens=128,
46
  do_sample=False,
47
  temperature=0.0,
48
  top_p=1.0,
@@ -77,7 +77,7 @@ def _load_multimodal_processor() -> AutoProcessor:
77
  )
78
  # optional: stabilize pixel hints
79
  try:
80
- proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", str(320 * 28 * 28))) # ~0.5MP
81
  proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", str(256 * 28 * 28)))
82
  except Exception:
83
  pass
 
38
  ADAPTER_ID = os.environ.get("ADAPTER_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B-LoRA")
39
 
40
  # Give ourselves more time for first load in cold starts
41
+ ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "60")) # seconds
42
 
43
  # Deterministic decoding for eval; tweak as needed
44
  GEN_KW = dict(
45
+ max_new_tokens=64,
46
  do_sample=False,
47
  temperature=0.0,
48
  top_p=1.0,
 
77
  )
78
  # optional: stabilize pixel hints
79
  try:
80
+ proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", str(256 * 28 * 28))) # ~0.2MP
81
  proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", str(256 * 28 * 28)))
82
  except Exception:
83
  pass
app.py.backup ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # Dermatology-AI-Assistant — HF Spaces (ZeroGPU, Qwen2.5-VL + LoRA adapters)
3
+ # - Normal UI for single-image analysis
4
+ # - Hidden API endpoint /analyze_batch for batched evaluation
5
+ # - Caches & sanitizes LoRA repo once at startup (CPU); attaches on GPU per request
6
+ # - No CUDA at import-time; ZeroGPU only inside @spaces.GPU functions
7
+
8
+ import os
9
+ import json
10
+ import tempfile
11
+ import shutil
12
+ import logging
13
+ from typing import Optional, List, Dict, Any
14
+
15
+ import gradio as gr
16
+ import spaces
17
+ import torch
18
+ from PIL import Image
19
+ from huggingface_hub import snapshot_download
20
+ from peft import PeftModel
21
+ from transformers import AutoProcessor
22
+
23
+ # Prefer the new class name if your transformers is recent; fall back to old alias.
24
+ try:
25
+ from transformers import AutoModelForImageTextToText as VisionTextModelClass
26
+ except Exception:
27
+ from transformers import AutoModelForVision2Seq as VisionTextModelClass # deprecated alias
28
+
29
+ from qwen_vl_utils import process_vision_info
30
+
31
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
32
+ logger = logging.getLogger(__name__)
33
+
34
# ---------------------------
# Config
# ---------------------------
# Model identity: base checkpoint plus the dermatology LoRA adapter repo.
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B-LoRA")

# ZeroGPU lease length; generous default so cold-start model loads fit.
ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "600"))  # seconds

# Deterministic decoding for eval; tweak as needed.
GEN_KW = {
    "max_new_tokens": 256,
    "do_sample": False,
    "temperature": 0.0,
    "top_p": 1.0,
    "repetition_penalty": 1.02,
}

# System prompt: refuse non-skin images outright, otherwise answer with a
# hedged description, differentials, and next steps.
SYSTEM_PROMPT = (
    "You are a dermatology assistant. First, look carefully at the IMAGE.\n"
    "If the image is NOT a close-up of human skin or a dermatologic lesion, "
    "respond EXACTLY with: 'The image does not appear to show a skin condition; I cannot analyze it.' "
    "Do not invent findings.\n"
    "If it IS a skin/lesion photo, provide a concise description, likely differentials (3–5), "
    "and prudent next steps. Avoid definitive diagnoses and include red flags briefly."
)
60
+
61
# ---------------------------
# Processor (CPU only; safe at import time)
# ---------------------------
def _load_multimodal_processor() -> AutoProcessor:
    """Load the Qwen2.5-VL processor on CPU and verify it is multimodal.

    Raises:
        RuntimeError: when the loaded processor cannot accept images, which
            usually means the installed transformers stack is too old.
    """
    logger.info(f"Loading multimodal processor from base: {BASE_MODEL_ID}")
    mm_proc = AutoProcessor.from_pretrained(
        BASE_MODEL_ID,
        trust_remote_code=True,
        use_fast=False,  # slow processor keeps the multimodal __call__(images=...) path
    )
    # Sanity check: the processor must expose an image pipeline.
    call_sig = getattr(mm_proc.__call__, "__signature__", None)
    if call_sig is not None:
        takes_images = "images" in str(call_sig)
    else:
        takes_images = hasattr(mm_proc, "image_processor")
    if not takes_images or not hasattr(mm_proc, "image_processor"):
        raise RuntimeError(
            "Loaded processor is not multimodal. Ensure transformers>=4.44.2, qwen-vl-utils>=0.0.8, torch>=2.2."
        )
    # Optionally pin the vision token budget so image resizing stays stable.
    try:
        mm_proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", str(640 * 28 * 28)))  # ~0.5MP
        mm_proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", str(256 * 28 * 28)))
    except Exception:
        pass  # best-effort; some processor versions may lack these attributes
    logger.info(f"Processor ready: {mm_proc.__class__.__name__}")
    return mm_proc
86
+
87
# Built once at import time on CPU so every request handler can reuse it.
processor = _load_multimodal_processor()
88
+
89
+ # ---------------------------
90
+ # LoRA adapter cache & sanitize (CPU-only, startup)
91
+ # ---------------------------
92
+ def _sanitize_adapter_repo(src_dir: str) -> str:
93
+ """Remove unknown keys from adapter_config.json so PEFT can parse."""
94
+ cfg_path = os.path.join(src_dir, "adapter_config.json")
95
+ if not os.path.isfile(cfg_path):
96
+ return src_dir
97
+
98
+ with open(cfg_path, "r") as f:
99
+ cfg = json.load(f)
100
+
101
+ allowed = {
102
+ "peft_type", "task_type",
103
+ "r", "lora_alpha", "lora_dropout",
104
+ "target_modules", "bias",
105
+ "inference_mode",
106
+ "base_model_name_or_path",
107
+ "fan_in_fan_out",
108
+ "modules_to_save",
109
+ "layers_to_transform",
110
+ "layers_pattern",
111
+ "use_rslora",
112
+ "rank_dropout", "module_dropout",
113
+ "init_lora_weights",
114
+ "use_dora",
115
+ }
116
+
117
+ # If DoRA isn't actually used, remove its block
118
+ if str(cfg.get("use_dora", "false")).lower() in ("false", "0", "no"):
119
+ cfg.pop("dora_config", None)
120
+
121
+ # Drop unknown top-level keys (e.g., 'corda_config', 'eva_config', etc.)
122
+ for k in list(cfg.keys()):
123
+ if k not in allowed:
124
+ cfg.pop(k, None)
125
+
126
+ cfg.setdefault("peft_type", "LORA")
127
+ cfg.setdefault("task_type", "CAUSAL_LM")
128
+ cfg.setdefault("bias", "none")
129
+ cfg.setdefault("inference_mode", True)
130
+
131
+ # Normalize booleans if strings
132
+ for k in ("inference_mode", "use_rslora", "use_dora", "fan_in_fan_out"):
133
+ if k in cfg and isinstance(cfg[k], str):
134
+ cfg[k] = cfg[k].lower() in ("true", "1", "yes")
135
+
136
+ with open(cfg_path, "w") as f:
137
+ json.dump(cfg, f, indent=2)
138
+ return src_dir
139
+
140
# Fetch the adapter snapshot into the local HF cache once at import time
# (CPU-only), then scrub its adapter_config.json so PEFT can load it later.
# NOTE(review): `local_dir_use_symlinks` is deprecated in recent
# huggingface_hub releases and ignored when local_dir is None — confirm
# installed version before removing.
logger.info(f"Downloading/caching LoRA adapters: {ADAPTER_ID}")
_ADAPTER_LOCAL = snapshot_download(ADAPTER_ID, local_dir=None, local_dir_use_symlinks=False)
_ADAPTER_LOCAL = _sanitize_adapter_repo(_ADAPTER_LOCAL)
logger.info(f"Adapters ready at: {_ADAPTER_LOCAL}")
144
+
145
# ---------------------------
# Helpers
# ---------------------------
def _messages(image: Image.Image, question: str):
    """Build the two-turn chat message list for one (image, question) pair."""
    rgb_image = image.convert("RGB") if image.mode != "RGB" else image
    system_turn = {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": rgb_image},
            {"type": "text", "text": question},
        ],
    }
    return [system_turn, user_turn]
156
+
157
def build_inputs(image: Image.Image, question: str):
    """Render the chat template and preprocess vision inputs into pt tensors."""
    chat = _messages(image, question)
    prompt = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    vision_images, vision_videos = process_vision_info(chat)
    return processor(text=[prompt], images=vision_images, videos=vision_videos, return_tensors="pt")
162
+
163
def _pad_token_id(model):
    """Pick a pad token id: tokenizer EOS, else the model config's EOS, else 0."""
    tokenizer = getattr(processor, "tokenizer", None)
    eos_id = getattr(tokenizer, "eos_token_id", None)
    if eos_id is not None:
        return eos_id
    model_cfg = getattr(model, "config", None)
    return getattr(model_cfg, "eos_token_id", 0) or 0
166
+
167
def _generate_text(model, inputs: Dict[str, Any]) -> str:
    """Run deterministic generation and decode only the completion tokens."""
    # Generation happens on whatever device the model's parameters live on.
    target_device = next(model.parameters()).device
    batch = {
        key: (value.to(target_device) if isinstance(value, torch.Tensor) else value)
        for key, value in inputs.items()
    }
    with torch.no_grad():
        sequences = model.generate(**batch, **GEN_KW, pad_token_id=_pad_token_id(model))
    # Strip the echoed prompt so only newly generated tokens are decoded.
    completions = [seq[len(prompt):] for prompt, seq in zip(batch["input_ids"], sequences)]
    decoded = processor.batch_decode(
        completions, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0]
177
+
178
def format_derm_disclaimer(ans: str) -> str:
    """Append the standing non-medical-device disclaimer to a model answer."""
    disclaimer = (
        "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
        "Consult a qualified dermatologist for diagnosis and treatment._"
    )
    return f"{ans}\n\n---\n{disclaimer}"
185
+
186
def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
    """Load the base VLM onto the GPU and attach the cached LoRA adapters.

    Returns the PEFT-wrapped model in eval mode; callers own its lifetime
    and should free it (and CUDA cache) when done.
    """
    logger.info(f"Loading BASE on GPU: {BASE_MODEL_ID}")
    backbone = VisionTextModelClass.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=dtype,
        device_map="cuda",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    logger.info(f"Attaching LoRA adapters from: {_ADAPTER_LOCAL}")
    adapted = PeftModel.from_pretrained(backbone, _ADAPTER_LOCAL, is_trainable=False)
    adapted.eval()
    return adapted
199
+
200
# ---------------------------
# Inference (ZeroGPU-safe: only here we touch CUDA)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)
def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
    """Single-image analysis: encode, load base+LoRA on GPU, generate, clean up.

    Returns the disclaimed answer text, or a "❌"-prefixed message on bad
    input or failure (never raises to the UI).
    """
    if image is None:
        return "❌ Please upload an image first."
    lora_model = None
    try:
        # Encode on CPU first so obviously bad inputs fail before the GPU load.
        encoded = build_inputs(image, question)
        # fp16 here; bf16 is an alternative on newer GPUs.
        lora_model = _load_base_plus_lora(dtype=torch.float16)
        answer = _generate_text(lora_model, encoded)
        return format_derm_disclaimer(answer)
    except Exception as e:
        logger.exception("Error during inference")
        return f"❌ Error analyzing image: {e}"
    finally:
        # Release the model and CUDA cache before the ZeroGPU lease ends.
        if lora_model is not None:
            del lora_model
            torch.cuda.empty_cache()
221
+
222
# ---------------------------
# Batched inference API (hidden; call via /analyze_batch)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)
def analyze_batch(samples: List[Dict[str, Any]]) -> List[str]:
    """
    samples: list of dicts like: {"image": <PIL/Image or filepath>, "question": <str>}
    Returns a list of responses (same order).
    """
    responses: List[str] = []
    if not isinstance(samples, list):
        return ["❌ Invalid payload: expected a JSON list of {image, question} dicts."]
    peft_model = None
    try:
        # Load the model once and reuse it for every item in the batch.
        peft_model = _load_base_plus_lora(dtype=torch.float16)
        for item in samples:
            try:
                picture = item.get("image")
                prompt = item.get("question") or "Describe this skin condition in detail and suggest possible next steps."
                # If the client sent a path (e.g., via gradio_client handle_file), load it:
                if isinstance(picture, str) and os.path.isfile(picture):
                    picture = Image.open(picture).convert("RGB")
                if not isinstance(picture, Image.Image):
                    responses.append("❌ Missing/invalid image")
                    continue
                encoded = build_inputs(picture, prompt)
                answer = _generate_text(peft_model, encoded)
                responses.append(format_derm_disclaimer(answer))
            except Exception as item_err:
                # One bad item must not sink the rest of the batch.
                logger.exception("Error on one batch item")
                responses.append(f"❌ Error analyzing one item: {item_err}")
        return responses
    except Exception as e:
        logger.exception("Batch inference failed")
        return [f"❌ Batch error: {e}"]
    finally:
        # Release the model and CUDA cache before the ZeroGPU lease ends.
        if peft_model is not None:
            del peft_model
            torch.cuda.empty_cache()
261
+
262
# ---------------------------
# UI
# ---------------------------
def create_interface() -> gr.Blocks:
    """Build the Gradio Blocks app: single-image UI plus a hidden batch API.

    Exposes two API routes: ``analyze_skin_condition`` (wired to the submit
    button) and ``analyze_batch`` (an invisible gr.Interface used only to
    register the route for gradio_client callers).
    """
    with gr.Blocks(title="Dermatology AI Assistant") as demo:
        gr.Markdown(
            "# 🩺 Dermatology AI Assistant\n"
            "Upload a skin photo and ask a question. The model will provide an informational response."
        )

        # Input row: image upload plus a free-text prompt with a sensible default.
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
            question_input = gr.Textbox(
                label="Question / Prompt",
                value="Describe this skin condition in detail and suggest possible next steps.",
                lines=3,
            )

        with gr.Row():
            submit_btn = gr.Button("Analyze", variant="primary")
            clear_btn = gr.Button("Clear")

        output_box = gr.Textbox(label="Response", lines=16, show_copy_button=True)

        # queue=True serializes GPU requests instead of running them concurrently.
        submit_btn.click(
            fn=analyze_skin_condition,
            inputs=[image_input, question_input],
            outputs=output_box,
            queue=True,
            api_name="analyze_skin_condition",  # public API for single requests
        )
        # Clear resets both inputs; the lambda returns one value per output.
        clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])

        # Hidden minimal iface just to expose a batch API route
        gr.Interface(
            fn=analyze_batch,
            inputs=[gr.JSON(label="samples")],
            outputs=gr.JSON(label="responses"),
            allow_flagging="never",
            api_name="analyze_batch",  # call this from gradio_client
            visible=False,  # hide in UI; keep route alive
        )

        # Enable request queuing; the trailing Markdown still renders because
        # we are inside the Blocks context.
        demo.queue()
        gr.Markdown(
            "_Tips: Ensure good lighting and focus. Avoid uploading personally identifying information._"
        )
    return demo
310
+
311
def main():
    """Build the Gradio app and serve it on all interfaces, port 7860."""
    app = create_interface()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        inbrowser=False,
        quiet=False,
        ssr_mode=False,  # no Node requirement
    )

if __name__ == "__main__":
    main()