# app.py
# Dermatology-AI-Assistant — HF Spaces (ZeroGPU, Qwen2.5-VL + LoRA adapters)
# - Normal UI for single-image analysis
# - Hidden API endpoint /analyze_batch for batched evaluation
# - Caches & sanitizes LoRA repo once at startup (CPU); attaches on GPU per request
# - No CUDA at import-time; ZeroGPU only inside @spaces.GPU functions
import os
import json
import tempfile
import shutil
import logging
from typing import Optional, List, Dict, Any
import gradio as gr
import spaces
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import AutoProcessor
# Prefer the new class name if your transformers is recent; fall back to old alias.
try:
    from transformers import AutoModelForImageTextToText as VisionTextModelClass
except Exception:
    from transformers import AutoModelForVision2Seq as VisionTextModelClass  # deprecated alias
from qwen_vl_utils import process_vision_info
logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
logger = logging.getLogger(__name__)
# ---------------------------
# Config
# ---------------------------
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B-LoRA")
# GPU time requested per call; raise ZGPU_DURATION if cold starts need longer
ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "15"))  # seconds
# Deterministic decoding for eval; tweak as needed
GEN_KW = dict(
    max_new_tokens=64,
    do_sample=False,
    temperature=0.0,
    top_p=1.0,
    repetition_penalty=1.02,
)
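# Note: with do_sample=False, decoding is greedy, so `temperature` and `top_p`
# are ignored (recent transformers versions may log a warning about them).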
SYSTEM_PROMPT = (
    "You are a dermatology assistant. First, look carefully at the IMAGE.\n"
    "If the image is NOT a close-up of human skin or a dermatologic lesion, "
    "respond EXACTLY with: 'The image does not appear to show a skin condition; I cannot analyze it.' "
    "Do not invent findings.\n"
    "If it IS a skin/lesion photo, provide a concise description, likely differentials (3–5), "
    "and prudent next steps. Avoid definitive diagnoses and include red flags briefly."
)
# ---------------------------
# Processor (CPU only; safe at import time)
# ---------------------------
def _load_multimodal_processor() -> AutoProcessor:
    logger.info(f"Loading multimodal processor from base: {BASE_MODEL_ID}")
    proc = AutoProcessor.from_pretrained(
        BASE_MODEL_ID,
        trust_remote_code=True,
        use_fast=False,  # ensure multimodal __call__(images=...) works
    )
    # Sanity check: the processor must accept images, not just text.
    sig = getattr(proc.__call__, "__signature__", None)
    accepts_images = ("images" in str(sig)) if sig else hasattr(proc, "image_processor")
    if not accepts_images or not hasattr(proc, "image_processor"):
        raise RuntimeError(
            "Loaded processor is not multimodal. Ensure transformers>=4.44.2, qwen-vl-utils>=0.0.8, torch>=2.2."
        )
    # Optional: pin the vision token budget so memory use stays predictable.
    try:
        proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", str(256 * 28 * 28)))  # ~0.2 MP
        proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", str(256 * 28 * 28)))
    except Exception:
        pass
    logger.info(f"Processor ready: {proc.__class__.__name__}")
    return proc
processor = _load_multimodal_processor()
# ---------------------------
# LoRA adapter cache & sanitize (CPU-only, startup)
# ---------------------------
def _sanitize_adapter_repo(src_dir: str) -> str:
    """Copy the adapter snapshot to a writable temp dir and remove unknown keys
    from adapter_config.json so PEFT can parse it."""
    cfg_path = os.path.join(src_dir, "adapter_config.json")
    if not os.path.isfile(cfg_path):
        return src_dir
    # Work on a copy so we never write through symlinks into the HF cache.
    work_dir = os.path.join(tempfile.mkdtemp(prefix="lora_adapter_"), "repo")
    shutil.copytree(src_dir, work_dir)
    cfg_path = os.path.join(work_dir, "adapter_config.json")
    with open(cfg_path, "r") as f:
        cfg = json.load(f)
    allowed = {
        "peft_type", "task_type",
        "r", "lora_alpha", "lora_dropout",
        "target_modules", "bias",
        "inference_mode",
        "base_model_name_or_path",
        "fan_in_fan_out",
        "modules_to_save",
        "layers_to_transform",
        "layers_pattern",
        "use_rslora",
        "rank_dropout", "module_dropout",
        "init_lora_weights",
        "use_dora",
    }
    # If DoRA isn't actually used, remove its block
    if str(cfg.get("use_dora", "false")).lower() in ("false", "0", "no"):
        cfg.pop("dora_config", None)
    # Drop unknown top-level keys (e.g., 'corda_config', 'eva_config', etc.)
    for k in list(cfg.keys()):
        if k not in allowed:
            cfg.pop(k, None)
    cfg.setdefault("peft_type", "LORA")
    cfg.setdefault("task_type", "CAUSAL_LM")
    cfg.setdefault("bias", "none")
    cfg.setdefault("inference_mode", True)
    # Normalize booleans if strings
    for k in ("inference_mode", "use_rslora", "use_dora", "fan_in_fan_out"):
        if k in cfg and isinstance(cfg[k], str):
            cfg[k] = cfg[k].lower() in ("true", "1", "yes")
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)
    return work_dir
logger.info(f"Downloading/caching LoRA adapters: {ADAPTER_ID}")
_ADAPTER_LOCAL = snapshot_download(ADAPTER_ID, local_dir=None, local_dir_use_symlinks=False)
_ADAPTER_LOCAL = _sanitize_adapter_repo(_ADAPTER_LOCAL)
logger.info(f"Adapters ready at: {_ADAPTER_LOCAL}")
# ---------------------------
# Helpers
# ---------------------------
def _messages(image: Image.Image, question: str):
    if image.mode != "RGB":
        image = image.convert("RGB")
    return [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image", "image": image},
                                     {"type": "text", "text": question}]},
    ]
def build_inputs(image: Image.Image, question: str):
    msgs = _messages(image, question)
    text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(msgs)
    return processor(text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt")
def _pad_token_id(model):
    tid = getattr(getattr(processor, "tokenizer", None), "eos_token_id", None)
    return tid if tid is not None else (getattr(getattr(model, "config", None), "eos_token_id", 0) or 0)
def _generate_text(model, inputs: Dict[str, Any]) -> str:
    # Move input tensors to the model's device.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    with torch.no_grad():
        out_ids = model.generate(**inputs, **GEN_KW, pad_token_id=_pad_token_id(model))
    # Trim the prompt tokens so only the generated answer is decoded.
    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
    return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
def format_derm_disclaimer(ans: str) -> str:
    return (
        ans
        + "\n\n---\n"
        "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
        "Consult a qualified dermatologist for diagnosis and treatment._"
    )
def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
    logger.info(f"Loading BASE on GPU: {BASE_MODEL_ID}")
    base = VisionTextModelClass.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=dtype,
        device_map="cuda",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    logger.info(f"Attaching LoRA adapters from: {_ADAPTER_LOCAL}")
    model = PeftModel.from_pretrained(base, _ADAPTER_LOCAL, is_trainable=False)
    model.eval()
    return model
# ---------------------------
# Inference (ZeroGPU-safe: only here we touch CUDA)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)
def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
    if image is None:
        return "❌ Please upload an image first."
    model = None
    try:
        inputs = build_inputs(image, question)
        # pick fp16; bf16 also works on newer GPUs
        model = _load_base_plus_lora(dtype=torch.float16)
        text = _generate_text(model, inputs)
        return format_derm_disclaimer(text)
    except Exception as e:
        logger.exception("Error during inference")
        return f"❌ Error analyzing image: {e}"
    finally:
        if model is not None:
            del model
        torch.cuda.empty_cache()
# ---------------------------
# Batched inference API (hidden; call via /analyze_batch)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)
def analyze_batch(samples: List[Dict[str, Any]]) -> List[str]:
    """
    samples: list of dicts like {"image": <PIL.Image or filepath>, "question": <str>}.
    Returns a list of responses (same order).
    """
    outs: List[str] = []
    if not isinstance(samples, list):
        return ["❌ Invalid payload: expected a JSON list of {image, question} dicts."]
    model = None
    try:
        model = _load_base_plus_lora(dtype=torch.float16)
        for ex in samples:
            try:
                img = ex.get("image")
                q = ex.get("question") or "Describe this skin condition in detail and suggest possible next steps."
                # If the client sent a path (e.g., via gradio_client handle_file), load it:
                if isinstance(img, str) and os.path.isfile(img):
                    img = Image.open(img).convert("RGB")
                if not isinstance(img, Image.Image):
                    outs.append("❌ Missing/invalid image")
                    continue
                inputs = build_inputs(img, q)
                text = _generate_text(model, inputs)
                outs.append(format_derm_disclaimer(text))
            except Exception as ie:
                logger.exception("Error on one batch item")
                outs.append(f"❌ Error analyzing one item: {ie}")
        return outs
    except Exception as e:
        logger.exception("Batch inference failed")
        return [f"❌ Batch error: {e}"]
    finally:
        if model is not None:
            del model
        torch.cuda.empty_cache()
# ---------------------------
# UI
# ---------------------------
def create_interface() -> gr.Blocks:
    with gr.Blocks(title="Dermatology AI Assistant") as demo:
        gr.Markdown(
            "# 🩺 Dermatology AI Assistant\n"
            "Upload a skin photo and ask a question. The model will provide an informational response."
        )
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
            question_input = gr.Textbox(
                label="Question / Prompt",
                value="Describe this skin condition in detail and suggest possible next steps.",
                lines=3,
            )
        with gr.Row():
            submit_btn = gr.Button("Analyze", variant="primary")
            clear_btn = gr.Button("Clear")
        output_box = gr.Textbox(label="Response", lines=16, show_copy_button=True)
        submit_btn.click(
            fn=analyze_skin_condition,
            inputs=[image_input, question_input],
            outputs=output_box,
            queue=True,
            api_name="analyze_skin_condition",  # public API for single requests
        )
        clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])
        # Hidden minimal interface just to expose a batch API route
        gr.Interface(
            fn=analyze_batch,
            inputs=[gr.JSON(label="samples")],
            outputs=gr.JSON(label="responses"),
            allow_flagging="never",
            api_name="analyze_batch",  # call this from gradio_client
            visible=False,  # hide in UI; keep route alive
        )
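        # A minimal client-side sketch for the batch route (assumes the
        # `gradio_client` package; "user/space-name" is a placeholder):
        #
        #   from gradio_client import Client
        #   client = Client("user/space-name")
        #   responses = client.predict(
        #       [{"image": "/path/on/server.jpg", "question": "What could this be?"}],
        #       api_name="/analyze_batch",
        #   )
        #
        # Note: plain JSON cannot carry binary images, so paths in the payload
        # must be readable by the server (or uploaded first, e.g. via handle_file).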
        demo.queue()
        gr.Markdown(
            "_Tips: Ensure good lighting and focus. Avoid uploading personally identifying information._"
        )
    return demo
def main():
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        inbrowser=False,
        quiet=False,
        ssr_mode=False,  # no Node requirement
    )
if __name__ == "__main__":
    main()
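# A plausible requirements.txt for this Space (the package list follows the
# imports above; the version floors echo the RuntimeError hint and are
# assumptions, not pins from this file):
#
#   torch>=2.2
#   transformers>=4.44.2
#   qwen-vl-utils>=0.0.8
#   peft
#   gradio
#   spaces
#   huggingface_hub
#   Pillow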