# Commit metadata pasted with this file (author / message / hash):
# Manik Sheokand
# new timeout
# 1ca0b67
# app.py
# Dermatology-AI-Assistant — HF Spaces (ZeroGPU, Qwen2.5-VL + LoRA adapters)
# - Normal UI for single-image analysis
# - Hidden API endpoint /analyze_batch for batched evaluation
# - Caches & sanitizes LoRA repo once at startup (CPU); attaches on GPU per request
# - No CUDA at import-time; ZeroGPU only inside @spaces.GPU functions
import dataclasses
import gc
import inspect
import json
import logging
import os
import shutil
import tempfile
from typing import Any, Dict, List, Optional

import gradio as gr
import spaces
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import AutoProcessor

# Prefer the new class name if your transformers is recent; fall back to old alias.
try:
    from transformers import AutoModelForImageTextToText as VisionTextModelClass
except Exception:
    from transformers import AutoModelForVision2Seq as VisionTextModelClass  # deprecated alias
from qwen_vl_utils import process_vision_info
logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
logger = logging.getLogger(__name__)
# ---------------------------
# Config
# ---------------------------
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B-LoRA")
# Per-call ZeroGPU time budget in seconds, used by both @spaces.GPU entry points.
# NOTE(review): every call loads the base model + LoRA from scratch before it
# generates, and /analyze_batch processes all of its items inside one call, so
# 15 s may be too short on cold starts or for multi-item batches. Raise it via
# the ZGPU_DURATION env var if GPU tasks get aborted.
ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "15"))  # seconds
# Deterministic (greedy) decoding for eval; tweak as needed.
# temperature/top_p are intentionally not set: greedy decoding ignores them, and
# temperature=0.0 only produced a transformers warning (it would be rejected
# outright if do_sample were ever switched on).
# The system prompt asks for a description, 3-5 differentials and next steps,
# which often exceeds 64 new tokens; override with GEN_MAX_NEW_TOKENS.
GEN_KW = dict(
    max_new_tokens=int(os.environ.get("GEN_MAX_NEW_TOKENS", "64")),
    do_sample=False,
    repetition_penalty=1.02,
)
SYSTEM_PROMPT = (
    "You are a dermatology assistant. First, look carefully at the IMAGE.\n"
    "If the image is NOT a close-up of human skin or a dermatologic lesion, "
    "respond EXACTLY with: 'The image does not appear to show a skin condition; I cannot analyze it.' "
    "Do not invent findings.\n"
    "If it IS a skin/lesion photo, provide a concise description, likely differentials (3–5), "
    "and prudent next steps. Avoid definitive diagnoses and include red flags briefly."
)
# ---------------------------
# Processor (CPU only; safe at import time)
# ---------------------------
def _apply_pixel_hints(image_processor) -> None:
    """Best-effort: pin the image processor's min/max pixel budget from env vars.

    Reads QWEN_MIN_PIXELS / QWEN_MAX_PIXELS (both default to 256*28*28, ~0.2 MP).
    Unparseable or inconsistent values are logged and ignored instead of
    aborting startup.
    """
    default_px = str(256 * 28 * 28)  # ~0.2MP
    try:
        max_px = int(os.environ.get("QWEN_MAX_PIXELS", default_px))
        min_px = int(os.environ.get("QWEN_MIN_PIXELS", default_px))
    except ValueError as err:
        logger.warning("Ignoring invalid QWEN_MIN_PIXELS/QWEN_MAX_PIXELS: %s", err)
        return
    if min_px > max_px:
        logger.warning(
            "Ignoring pixel hints: QWEN_MIN_PIXELS (%d) > QWEN_MAX_PIXELS (%d)", min_px, max_px
        )
        return
    try:
        image_processor.max_pixels = max_px
        image_processor.min_pixels = min_px
        # NOTE(review): some transformers releases read the resize bounds from
        # `size["shortest_edge"/"longest_edge"]` rather than these attributes;
        # mirror them there when that dict exists. Confirm against the pinned version.
        size = getattr(image_processor, "size", None)
        if isinstance(size, dict) and {"shortest_edge", "longest_edge"} <= size.keys():
            image_processor.size = {**size, "shortest_edge": min_px, "longest_edge": max_px}
    except AttributeError as err:
        logger.warning("Could not apply pixel hints to image processor: %s", err)


def _load_multimodal_processor() -> AutoProcessor:
    """Load the base model's processor (CPU only) and verify it can take images.

    Returns:
        The processor, with optional pixel-budget hints applied.

    Raises:
        RuntimeError: if the loaded object is not multimodal (for example, a
            text-only tokenizer without an ``image_processor``).
    """
    logger.info(f"Loading multimodal processor from base: {BASE_MODEL_ID}")
    proc = AutoProcessor.from_pretrained(
        BASE_MODEL_ID,
        trust_remote_code=True,
        use_fast=False,  # ensure multimodal __call__(images=...) works
    )
    # Sanity check. Bound methods usually carry no `__signature__` attribute, so
    # introspect the real signature: an `images` parameter or a **kwargs
    # catch-all counts as "accepts images".
    try:
        params = inspect.signature(proc.__call__).parameters.values()
        accepts_images = any(
            p.name == "images" or p.kind is inspect.Parameter.VAR_KEYWORD for p in params
        )
    except (TypeError, ValueError):
        # Signature not introspectable; rely on the image_processor check below.
        accepts_images = True
    if not accepts_images or not hasattr(proc, "image_processor"):
        raise RuntimeError(
            "Loaded processor is not multimodal. Ensure transformers>=4.44.2, qwen-vl-utils>=0.0.8, torch>=2.2."
        )
    # optional: stabilize pixel hints
    _apply_pixel_hints(proc.image_processor)
    logger.info(f"Processor ready: {proc.__class__.__name__}")
    return proc


processor = _load_multimodal_processor()
# ---------------------------
# LoRA adapter cache & sanitize (CPU-only, startup)
# ---------------------------
# Keys kept when PEFT's LoraConfig cannot be introspected (e.g. PEFT not importable).
_FALLBACK_LORA_CONFIG_KEYS = frozenset({
    "peft_type", "task_type",
    "r", "lora_alpha", "lora_dropout",
    "target_modules", "bias",
    "inference_mode",
    "base_model_name_or_path",
    "fan_in_fan_out",
    "modules_to_save",
    "layers_to_transform",
    "layers_pattern",
    "use_rslora",
    "rank_dropout", "module_dropout",
    "init_lora_weights",
    "use_dora",
    # Per-module rank/alpha overrides; dropping them changes adapter shapes/scaling.
    "rank_pattern", "alpha_pattern",
})


def _known_lora_config_keys() -> frozenset:
    """Return the config keys the installed PEFT ``LoraConfig`` accepts as init args."""
    try:
        from peft import LoraConfig

        return frozenset(f.name for f in dataclasses.fields(LoraConfig) if f.init)
    except (ImportError, TypeError):
        # ImportError: PEFT unavailable; TypeError: LoraConfig is not a dataclass here.
        return _FALLBACK_LORA_CONFIG_KEYS


def _sanitize_adapter_repo(src_dir: str) -> str:
    """Return a directory whose adapter_config.json the installed PEFT can parse.

    Keys that the installed ``peft.LoraConfig`` does not declare (for instance,
    extras written by a newer PEFT release) are removed, common defaults are
    filled in, and string booleans are normalized. The source directory is never
    modified: the files are copied to a fresh temporary directory and only the
    copy's config is rewritten. Files in a Hugging Face cache snapshot are
    typically symlinks into a shared blob store, so writing through them would
    alter the cached blob.

    Args:
        src_dir: Local directory holding the downloaded adapter files.

    Returns:
        ``src_dir`` itself if it has no ``adapter_config.json`` (or the file is
        not a JSON object); otherwise the path of the sanitized copy.
    """
    cfg_path = os.path.join(src_dir, "adapter_config.json")
    if not os.path.isfile(cfg_path):
        return src_dir
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    if not isinstance(cfg, dict):
        return src_dir  # not a config object; let PEFT report the problem

    # If DoRA isn't actually used, remove its block
    if str(cfg.get("use_dora", "false")).lower() in ("false", "0", "no"):
        cfg.pop("dora_config", None)
    # Drop top-level keys the installed PEFT does not know about.
    allowed = _known_lora_config_keys()
    cfg = {key: value for key, value in cfg.items() if key in allowed}

    cfg.setdefault("peft_type", "LORA")
    cfg.setdefault("task_type", "CAUSAL_LM")
    cfg.setdefault("bias", "none")
    cfg.setdefault("inference_mode", True)
    # Normalize booleans if strings
    for key in ("inference_mode", "use_rslora", "use_dora", "fan_in_fan_out"):
        if isinstance(cfg.get(key), str):
            cfg[key] = cfg[key].lower() in ("true", "1", "yes")

    # Write into a private copy instead of the (possibly shared, read-only) source.
    dst_dir = os.path.join(tempfile.mkdtemp(prefix="lora_adapter_"), "adapter")
    shutil.copytree(src_dir, dst_dir)  # symlinks=False (default): copies file contents
    with open(os.path.join(dst_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    return dst_dir
logger.info(f"Downloading/caching LoRA adapters: {ADAPTER_ID}")
# Download into (or reuse) the shared Hugging Face cache. `local_dir=None` was
# already the default, and `local_dir_use_symlinks` only ever affected
# `local_dir` downloads and is deprecated in huggingface_hub, so both are dropped.
_ADAPTER_LOCAL = snapshot_download(ADAPTER_ID)
# PEFT loads from whatever directory _sanitize_adapter_repo returns.
_ADAPTER_LOCAL = _sanitize_adapter_repo(_ADAPTER_LOCAL)
logger.info(f"Adapters ready at: {_ADAPTER_LOCAL}")
# ---------------------------
# Helpers
# ---------------------------
def _messages(image: Image.Image, question: str):
    """Build the two-turn chat for Qwen: the system prompt, then image + question.

    Non-RGB images are converted to RGB before being placed in the user turn.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": SYSTEM_PROMPT}],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": rgb_image},
            {"type": "text", "text": question},
        ],
    }
    return [system_turn, user_turn]
def build_inputs(image: Image.Image, question: str):
    """Turn one image + question into processor tensors (a batch of one, "pt").

    The chat template is rendered as text with the generation prompt appended,
    and the vision inputs are extracted from the same message list.
    """
    conversation = _messages(image, question)
    prompt = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(conversation)
    return processor(
        text=[prompt],
        images=images,
        videos=videos,
        return_tensors="pt",
    )
def _pad_token_id(model):
    """Choose the pad token id passed to generate().

    Preference order: the processor tokenizer's EOS id, then the model config's
    EOS id (when truthy), and finally 0.
    """
    tokenizer = getattr(processor, "tokenizer", None)
    eos_id = getattr(tokenizer, "eos_token_id", None)
    if eos_id is not None:
        return eos_id
    model_config = getattr(model, "config", None)
    return getattr(model_config, "eos_token_id", 0) or 0
def _generate_text(model, inputs: Dict[str, Any]) -> str:
    """Generate a reply for one prepared sample and return only the new text.

    Tensor entries of ``inputs`` are moved to the device of the model's first
    parameter (other entries pass through unchanged). Generation uses GEN_KW
    under ``torch.no_grad()``. Prompt tokens are sliced off each output
    sequence, and the first decoded sequence is returned.
    """
    target_device = next(model.parameters()).device
    on_device = {}
    for key, value in inputs.items():
        on_device[key] = value.to(target_device) if isinstance(value, torch.Tensor) else value
    with torch.no_grad():
        generated = model.generate(**on_device, **GEN_KW, pad_token_id=_pad_token_id(model))
    # generate() returns prompt + continuation; keep only the continuation.
    continuations = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(on_device["input_ids"], generated)
    ]
    decoded = processor.batch_decode(
        continuations, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0]
def format_derm_disclaimer(ans: str) -> str:
    """Append the standard "not a medical device" disclaimer below a model answer."""
    disclaimer = (
        "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
        "Consult a qualified dermatologist for diagnosis and treatment._"
    )
    return "\n\n---\n".join([ans, disclaimer])
def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
    """Load the base VL model onto "cuda" and wrap it with the cached LoRA adapters.

    Called from the @spaces.GPU entry points, since it places weights on CUDA.
    Returns the PeftModel in eval mode, with adapters loaded as non-trainable.
    """
    logger.info(f"Loading BASE on GPU: {BASE_MODEL_ID}")
    load_options = dict(
        torch_dtype=dtype,
        device_map="cuda",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    base_model = VisionTextModelClass.from_pretrained(BASE_MODEL_ID, **load_options)
    logger.info(f"Attaching LoRA adapters from: {_ADAPTER_LOCAL}")
    peft_model = PeftModel.from_pretrained(base_model, _ADAPTER_LOCAL, is_trainable=False)
    peft_model.eval()
    return peft_model
# ---------------------------
# Inference (ZeroGPU-safe: only here we touch CUDA)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)
def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
    """Answer one question about one uploaded image (UI button + public API route).

    Args:
        image: Uploaded PIL image, or None when nothing was uploaded.
        question: Free-text prompt. A blank or missing prompt falls back to the
            same default question the batch endpoint uses.

    Returns:
        The model answer with the disclaimer appended, or a "❌ ..." message.
        Exceptions from preprocessing, model loading or generation are logged
        and returned as an error string rather than raised.
    """
    if image is None:
        return "❌ Please upload an image first."
    # Same fallback as analyze_batch, so an emptied textbox doesn't send a bare image.
    if not (question or "").strip():
        question = "Describe this skin condition in detail and suggest possible next steps."
    model = None
    try:
        # CPU preprocessing runs before the model load, so bad inputs fail early.
        inputs = build_inputs(image, question)
        # pick fp16; bf16 also works on newer GPUs
        model = _load_base_plus_lora(dtype=torch.float16)
        text = _generate_text(model, inputs)
        return format_derm_disclaimer(text)
    except Exception as e:
        logger.exception("Error during inference")
        return f"❌ Error analyzing image: {e}"
    finally:
        if model is not None:
            del model
        # Collect reference cycles so the model's CUDA tensors are actually
        # released before the allocator cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
# ---------------------------
# Batched inference API (hidden; call via /analyze_batch)
# ---------------------------
def _load_batch_image(img: Any) -> Optional[Image.Image]:
    """Resolve a batch item's "image" field to a PIL image, or None if unusable.

    Accepts a PIL image (returned as-is), a path to an existing local file, or a
    dict carrying such a path under "path" (the shape of a Gradio FileData
    payload). Files are decoded to RGB and their handles closed before returning.
    """
    if isinstance(img, dict):
        img = img.get("path")
    if isinstance(img, str) and os.path.isfile(img):
        # NOTE(review): this is a public API route and any readable server path is
        # accepted here; consider restricting to Gradio's upload/temp directories.
        with Image.open(img) as handle:
            return handle.convert("RGB")
    return img if isinstance(img, Image.Image) else None


@spaces.GPU(duration=ZGPU_DURATION)
def analyze_batch(samples: List[Dict[str, Any]]) -> List[str]:
    """
    samples: list of dicts like: {"image": <PIL/Image or filepath>, "question": <str>}
    Returns a list of responses (same order).

    "image" may also be a dict with a local file path under "path". A missing or
    blank "question" falls back to a default prompt. Per-item failures become
    "❌ ..." strings in that item's slot. If the batch itself fails (e.g. model
    loading), every unanswered slot gets the batch error, so the result stays
    aligned with ``samples``. A non-list payload yields a single error string.
    """
    default_question = "Describe this skin condition in detail and suggest possible next steps."
    if not isinstance(samples, list):
        return ["❌ Invalid payload: expected a JSON list of {image, question} dicts."]
    if not samples:
        # Nothing to do: don't spend the GPU window loading the model.
        return []
    outs: List[str] = []
    model = None
    try:
        model = _load_base_plus_lora(dtype=torch.float16)
        for ex in samples:
            try:
                if not isinstance(ex, dict):
                    outs.append("❌ Invalid item: expected an {image, question} dict")
                    continue
                img = _load_batch_image(ex.get("image"))
                if img is None:
                    outs.append("❌ Missing/invalid image")
                    continue
                q = ex.get("question")
                if not isinstance(q, str) or not q.strip():
                    q = default_question
                inputs = build_inputs(img, q)
                text = _generate_text(model, inputs)
                outs.append(format_derm_disclaimer(text))
            except Exception as ie:
                logger.exception("Error on one batch item")
                outs.append(f"❌ Error analyzing one item: {ie}")
        return outs
    except Exception as e:
        logger.exception("Batch inference failed")
        # Keep the one-response-per-sample contract: fill every unanswered slot.
        return outs + [f"❌ Batch error: {e}"] * (len(samples) - len(outs))
    finally:
        if model is not None:
            del model
        # Collect reference cycles so CUDA memory is really freed before emptying the cache.
        gc.collect()
        torch.cuda.empty_cache()
# ---------------------------
# UI
# ---------------------------
def create_interface() -> gr.Blocks:
    """Build the Gradio app: the single-image UI plus a hidden batch API route.

    Registers two API endpoints on the returned Blocks:
    "analyze_skin_condition" (the visible Analyze button) and "analyze_batch"
    (wired to invisible JSON components). The queue is enabled on the app.
    """
    default_question = "Describe this skin condition in detail and suggest possible next steps."
    with gr.Blocks(title="Dermatology AI Assistant") as demo:
        gr.Markdown(
            "# 🩺 Dermatology AI Assistant\n"
            "Upload a skin photo and ask a question. The model will provide an informational response."
        )
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
            question_input = gr.Textbox(
                label="Question / Prompt",
                value=default_question,
                lines=3,
            )
        with gr.Row():
            submit_btn = gr.Button("Analyze", variant="primary")
            clear_btn = gr.Button("Clear")
        output_box = gr.Textbox(label="Response", lines=16, show_copy_button=True)
        submit_btn.click(
            fn=analyze_skin_condition,
            inputs=[image_input, question_input],
            outputs=output_box,
            queue=True,
            api_name="analyze_skin_condition",  # public API for single requests
        )
        # Reset the image, restore the default prompt (not ""), and clear the previous answer.
        clear_btn.click(
            fn=lambda: (None, default_question, ""),
            inputs=None,
            outputs=[image_input, question_input, output_box],
        )
        # Hidden batch route. A gr.Interface built inside this context is not
        # rendered into `demo` unless .render() is called, so its api_name would
        # never be registered here. Invisible components wired with api_name
        # expose /analyze_batch on this app directly.
        batch_in = gr.JSON(label="samples", visible=False)
        batch_out = gr.JSON(label="responses", visible=False)
        batch_btn = gr.Button("Run batch", visible=False)
        batch_btn.click(
            fn=analyze_batch,
            inputs=[batch_in],
            outputs=batch_out,
            queue=True,
            api_name="analyze_batch",  # call this from gradio_client
        )
        gr.Markdown(
            "_Tips: Ensure good lighting and focus. Avoid uploading personally identifying information._"
        )
    demo.queue()
    return demo
def main():
    """Build the Gradio app and serve it on 0.0.0.0:7860 without a public share link."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "inbrowser": False,
        "quiet": False,
        "ssr_mode": False,  # no Node requirement
    }
    app = create_interface()
    app.launch(**launch_options)


if __name__ == "__main__":
    main()