# app.py
# Dermatology-AI-Assistant — HF Spaces (ZeroGPU, Qwen2.5-VL + LoRA adapters)
# - Normal UI for single-image analysis
# - Hidden API endpoint /analyze_batch for batched evaluation
# - Caches & sanitizes LoRA repo once at startup (CPU); attaches on GPU per request
# - No CUDA at import-time; ZeroGPU only inside @spaces.GPU functions

import base64
import dataclasses
import io
import json
import logging
import os
import shutil
import tempfile
from typing import Any, Dict, FrozenSet, List, Optional

import gradio as gr
import spaces
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from peft import LoraConfig, PeftModel
from transformers import AutoProcessor

# Prefer the new class name if your transformers is recent; fall back to the old alias.
try:
    from transformers import AutoModelForImageTextToText as VisionTextModelClass
except ImportError:
    from transformers import AutoModelForVision2Seq as VisionTextModelClass  # deprecated alias

from qwen_vl_utils import process_vision_info

logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
logger = logging.getLogger(__name__)

# ---------------------------
# Config
# ---------------------------
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B-LoRA")

# Duration passed to @spaces.GPU for single requests. Every call loads base+LoRA
# from scratch (see _load_base_plus_lora), so this budget must cover model load
# plus generation.
ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "15"))  # seconds
# analyze_batch loads the model once and then generates for every item inside a
# single GPU call, so it gets its own knob (defaults to ZGPU_DURATION).
ZGPU_BATCH_DURATION = int(os.environ.get("ZGPU_BATCH_DURATION", str(ZGPU_DURATION)))  # seconds

# Prompt used when the caller leaves the question empty (UI default and batch fallback).
DEFAULT_QUESTION = "Describe this skin condition in detail and suggest possible next steps."

# Deterministic (greedy, do_sample=False) decoding for eval; tweak as needed.
GEN_KW = dict(
    max_new_tokens=64,
    do_sample=False,
    temperature=0.0,
    top_p=1.0,
    repetition_penalty=1.02,
)

SYSTEM_PROMPT = (
    "You are a dermatology assistant. First, look carefully at the IMAGE.\n"
    "If the image is NOT a close-up of human skin or a dermatologic lesion, "
    "respond EXACTLY with: 'The image does not appear to show a skin condition; I cannot analyze it.' "
    "Do not invent findings.\n"
    "If it IS a skin/lesion photo, provide a concise description, likely differentials (3–5), "
    "and prudent next steps. Avoid definitive diagnoses and include red flags briefly."
)


# ---------------------------
# Processor (CPU only; safe at import time)
# ---------------------------
def _load_multimodal_processor() -> AutoProcessor:
    """Load the base model's multimodal processor on CPU and apply pixel hints.

    Returns:
        The loaded processor.

    Raises:
        RuntimeError: if the processor does not look multimodal (no image support).
    """
    logger.info("Loading multimodal processor from base: %s", BASE_MODEL_ID)
    proc = AutoProcessor.from_pretrained(
        BASE_MODEL_ID,
        trust_remote_code=True,
        use_fast=False,  # ensure multimodal __call__(images=...) works
    )

    # Sanity check: use an explicit __signature__ when one is attached, otherwise
    # fall back to the presence of an image processor.
    sig = getattr(proc.__call__, "__signature__", None)
    accepts_images = ("images" in str(sig)) if sig else hasattr(proc, "image_processor")
    if not accepts_images or not hasattr(proc, "image_processor"):
        raise RuntimeError(
            "Loaded processor is not multimodal. Ensure transformers>=4.44.2, qwen-vl-utils>=0.0.8, torch>=2.2."
        )

    # Optional: stabilize pixel hints. By default min and max are both 256*28*28
    # (~0.2 MP). Best-effort: a bad env value must not stop the Space, but it
    # should not be ignored silently either.
    try:
        proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", str(256 * 28 * 28)))  # ~0.2MP
        proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", str(256 * 28 * 28)))
    except Exception as e:
        logger.warning("Could not apply QWEN_MIN/MAX_PIXELS hints; using processor defaults: %s", e)

    logger.info("Processor ready: %s", proc.__class__.__name__)
    return proc


processor = _load_multimodal_processor()


# ---------------------------
# LoRA adapter cache & sanitize (CPU-only, startup)
# ---------------------------
_ADAPTER_CONFIG_NAME = "adapter_config.json"

# Keys always kept in adapter_config.json. rank_pattern/alpha_pattern are per-module
# rank/alpha overrides: dropping them changes adapter shapes/scaling.
_STATIC_LORA_KEYS: FrozenSet[str] = frozenset({
    "peft_type", "task_type", "r", "lora_alpha", "lora_dropout", "target_modules",
    "bias", "inference_mode", "base_model_name_or_path", "fan_in_fan_out",
    "modules_to_save", "layers_to_transform", "layers_pattern", "use_rslora",
    "rank_dropout", "module_dropout", "init_lora_weights", "use_dora",
    "rank_pattern", "alpha_pattern",
})


def _known_lora_keys() -> FrozenSet[str]:
    """Return config keys to keep: the static list plus the installed LoraConfig's init fields.

    Using the installed PEFT's own fields means we only strip keys this PEFT
    version genuinely cannot parse (e.g. newer 'corda_config' / 'eva_config').
    """
    try:
        installed = {f.name for f in dataclasses.fields(LoraConfig) if f.init}
    except TypeError:  # LoraConfig is not a dataclass in this PEFT build
        logger.warning("Could not introspect peft.LoraConfig fields; using static key list only")
        installed = set()
    return _STATIC_LORA_KEYS | installed


def _sanitize_lora_config(cfg: Dict[str, Any], allowed: FrozenSet[str]) -> Dict[str, Any]:
    """Return a cleaned copy of an adapter_config dict (``cfg`` is not modified).

    Drops ``dora_config`` when DoRA is off, drops keys not in ``allowed``, fills
    required defaults, and turns string booleans ("true"/"1"/"yes") into bools.
    """
    clean = dict(cfg)
    # If DoRA isn't actually used, remove its block.
    if str(clean.get("use_dora", "false")).lower() in ("false", "0", "no"):
        clean.pop("dora_config", None)
    # Drop unknown top-level keys (e.g., 'corda_config', 'eva_config', etc.); order is preserved.
    clean = {k: v for k, v in clean.items() if k in allowed}
    clean.setdefault("peft_type", "LORA")
    clean.setdefault("task_type", "CAUSAL_LM")
    clean.setdefault("bias", "none")
    clean.setdefault("inference_mode", True)
    # Normalize booleans if strings.
    for k in ("inference_mode", "use_rslora", "use_dora", "fan_in_fan_out"):
        if isinstance(clean.get(k), str):
            clean[k] = clean[k].lower() in ("true", "1", "yes")
    return clean


def _link_or_copy(src: str, dst: str) -> None:
    """copytree copy_function: symlink dst to src's resolved target; fall back to a real copy."""
    try:
        os.symlink(os.path.realpath(src), dst)
    except OSError:
        shutil.copy2(src, dst)


def _sanitize_adapter_repo(src_dir: str) -> str:
    """Return a directory whose adapter_config.json the installed PEFT can parse.

    If there is no config, or it needs no changes, ``src_dir`` is returned as-is.
    Otherwise a private mirror of ``src_dir`` is built in a fresh temp directory
    (weights are symlinked, not duplicated) and the cleaned config is written
    there. ``src_dir`` itself is never written: files in a snapshot_download()
    cache directory are normally symlinks into the shared blob store, so writing
    through them would corrupt the cached file.
    """
    cfg_path = os.path.join(src_dir, _ADAPTER_CONFIG_NAME)
    if not os.path.isfile(cfg_path):
        return src_dir

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    clean = _sanitize_lora_config(cfg, _known_lora_keys())
    if clean == cfg:
        return src_dir  # already parseable; use the snapshot directly

    dropped = sorted(set(cfg) - set(clean))
    if dropped:
        logger.warning("Dropping adapter_config keys unknown to the installed PEFT: %s", dropped)

    dst_dir = os.path.join(tempfile.mkdtemp(prefix="lora_adapter_"), "adapter")
    shutil.copytree(src_dir, dst_dir, copy_function=_link_or_copy)
    dst_cfg = os.path.join(dst_dir, _ADAPTER_CONFIG_NAME)
    # The mirrored config may be a symlink back into the cache: unlink it so the
    # write below creates a new private file instead of following the link.
    os.remove(dst_cfg)
    with open(dst_cfg, "w", encoding="utf-8") as f:
        json.dump(clean, f, indent=2)
    return dst_dir


logger.info("Downloading/caching LoRA adapters: %s", ADAPTER_ID)
# Default cache location (the original passed local_dir=None; local_dir_use_symlinks
# only affects local_dir downloads, so it is omitted).
_ADAPTER_LOCAL = _sanitize_adapter_repo(snapshot_download(ADAPTER_ID))
logger.info("Adapters ready at: %s", _ADAPTER_LOCAL)


# ---------------------------
# Helpers
# ---------------------------
def _messages(image: Image.Image, question: str) -> List[Dict[str, Any]]:
    """Build the chat messages (system prompt + one user turn with image and text)."""
    if image.mode != "RGB":
        image = image.convert("RGB")
    return [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": question}]},
    ]


def build_inputs(image: Image.Image, question: str):
    """Tokenize one (image, question) pair into processor outputs as PyTorch tensors (batch of 1)."""
    msgs = _messages(image, question)
    text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(msgs)
    return processor(text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt")


def _pad_token_id(model) -> int:
    """Pick a pad id: tokenizer eos, else model.config.eos_token_id, else 0."""
    tid = getattr(getattr(processor, "tokenizer", None), "eos_token_id", None)
    return tid if tid is not None else (getattr(getattr(model, "config", None), "eos_token_id", 0) or 0)


def _generate_text(model, inputs: Dict[str, Any]) -> str:
    """Generate with GEN_KW and return only the newly generated text of the first sequence."""
    # Move tensors to the model's device; leave non-tensor entries untouched.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    with torch.no_grad():
        out_ids = model.generate(**inputs, **GEN_KW, pad_token_id=_pad_token_id(model))
    # Trim the prompt tokens so only the completion is decoded.
    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
    return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]


def format_derm_disclaimer(ans: str) -> str:
    """Append the standard medical disclaimer to a model answer."""
    return (
        ans + "\n\n---\n"
        "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
        "Consult a qualified dermatologist for diagnosis and treatment._"
    )


def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
    """Load the base model onto CUDA and attach the cached LoRA adapters (eval mode, not trainable)."""
    logger.info("Loading BASE on GPU: %s", BASE_MODEL_ID)
    base = VisionTextModelClass.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=dtype,
        device_map="cuda",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    logger.info("Attaching LoRA adapters from: %s", _ADAPTER_LOCAL)
    model = PeftModel.from_pretrained(base, _ADAPTER_LOCAL, is_trainable=False)
    model.eval()
    return model


def _coerce_image(obj: Any) -> Optional[Image.Image]:
    """Best-effort conversion of a batch sample's "image" field to a PIL image.

    Accepts a PIL image (returned as-is), a server-local file path, a
    ``{"path": ...}`` dict, or a base64 / ``data:`` URI string. Returns None when
    ``obj`` is none of these. An existing file that is not a readable image
    raises, so the per-item handler in analyze_batch reports it.
    """
    if isinstance(obj, Image.Image):
        return obj
    if isinstance(obj, dict):  # e.g. a FileData-style {"path": ...} payload
        obj = obj.get("path")
    if not isinstance(obj, str) or not obj:
        return None
    if os.path.isfile(obj):
        # NOTE(review): this opens any server-local path a caller names; confirm
        # that is acceptable for a public Space.
        with Image.open(obj) as im:  # close the file handle once decoded
            return im.convert("RGB")
    payload = obj.split(",", 1)[1] if obj.startswith("data:") else obj
    try:
        raw = base64.b64decode(payload, validate=True)
        with Image.open(io.BytesIO(raw)) as im:
            return im.convert("RGB")
    except (ValueError, OSError):  # binascii.Error / UnidentifiedImageError
        return None


# ---------------------------
# Inference (ZeroGPU-safe: only here we touch CUDA)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)
def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
    """Answer one question about one image; errors are returned as "❌ ..." strings.

    An empty question falls back to DEFAULT_QUESTION, matching analyze_batch.
    """
    if image is None:
        return "❌ Please upload an image first."
    model = None
    try:
        inputs = build_inputs(image, question or DEFAULT_QUESTION)
        # pick fp16; bf16 also works on newer GPUs
        model = _load_base_plus_lora(dtype=torch.float16)
        text = _generate_text(model, inputs)
        return format_derm_disclaimer(text)
    except Exception as e:
        logger.exception("Error during inference")
        return f"❌ Error analyzing image: {e}"
    finally:
        if model is not None:
            del model
        torch.cuda.empty_cache()


# ---------------------------
# Batched inference API (hidden; call via /analyze_batch)
# ---------------------------
@spaces.GPU(duration=ZGPU_BATCH_DURATION)
def analyze_batch(samples: List[Dict[str, Any]]) -> List[str]:
    """Run the model over a list of samples in one GPU session.

    Args:
        samples: list of dicts like ``{"image": ..., "question": ...}``. "image" may
            be a PIL image, a server-local file path, a ``{"path": ...}`` dict, or a
            base64 / data-URI string. A missing or empty "question" uses DEFAULT_QUESTION.

    Returns:
        One response per sample, in input order. Per-item failures become
        "❌ ..." strings. A failure outside the per-item loop (e.g. model load)
        returns a single-element error list.
    """
    if not isinstance(samples, list):
        return ["❌ Invalid payload: expected a JSON list of {image, question} dicts."]
    if not samples:
        return []  # nothing to do; skip the expensive model load

    outs: List[str] = []
    model = None
    try:
        model = _load_base_plus_lora(dtype=torch.float16)
        for ex in samples:
            try:
                q = ex.get("question") or DEFAULT_QUESTION
                img = _coerce_image(ex.get("image"))
                if img is None:
                    outs.append("❌ Missing/invalid image")
                    continue
                inputs = build_inputs(img, q)
                outs.append(format_derm_disclaimer(_generate_text(model, inputs)))
            except Exception as ie:
                logger.exception("Error on one batch item")
                outs.append(f"❌ Error analyzing one item: {ie}")
        return outs
    except Exception as e:
        logger.exception("Batch inference failed")
        return [f"❌ Batch error: {e}"]
    finally:
        if model is not None:
            del model
        torch.cuda.empty_cache()


# ---------------------------
# UI
# ---------------------------
def create_interface() -> gr.Blocks:
    """Build the Blocks app: the single-image UI plus a hidden /analyze_batch API route."""
    with gr.Blocks(title="Dermatology AI Assistant") as demo:
        gr.Markdown(
            "# 🩺 Dermatology AI Assistant\n"
            "Upload a skin photo and ask a question. The model will provide an informational response."
        )
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
            question_input = gr.Textbox(
                label="Question / Prompt",
                value=DEFAULT_QUESTION,
                lines=3,
            )
        with gr.Row():
            submit_btn = gr.Button("Analyze", variant="primary")
            clear_btn = gr.Button("Clear")
        output_box = gr.Textbox(label="Response", lines=16, show_copy_button=True)

        submit_btn.click(
            fn=analyze_skin_condition,
            inputs=[image_input, question_input],
            outputs=output_box,
            queue=True,
            api_name="analyze_skin_condition",  # public API for single requests
        )
        clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])

        # Hidden batch API: invisible components + button register /analyze_batch on
        # THIS app. A gr.Interface built inside Blocks is only added via .render().
        batch_in = gr.JSON(label="samples", visible=False)
        batch_out = gr.JSON(label="responses", visible=False)
        batch_trigger = gr.Button(visible=False)
        batch_trigger.click(
            fn=analyze_batch,
            inputs=batch_in,
            outputs=batch_out,
            queue=True,
            api_name="analyze_batch",  # call this from gradio_client
        )

        gr.Markdown(
            "_Tips: Ensure good lighting and focus. Avoid uploading personally identifying information._"
        )
    demo.queue()
    return demo


def main() -> None:
    """Build the app and launch it on 0.0.0.0:7860."""
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        inbrowser=False,
        quiet=False,
        ssr_mode=False,  # no Node requirement
    )


if __name__ == "__main__":
    main()