# app.py
# Dermatology-AI-Assistant — HF Spaces (ZeroGPU, Qwen2.5-VL + LoRA adapters)
# - Normal UI for single-image analysis
# - Hidden API endpoint /analyze_batch for batched evaluation
# - Caches & sanitizes LoRA repo once at startup (CPU); attaches on GPU per request
# - No CUDA at import-time; ZeroGPU only inside @spaces.GPU functions
import os
import json
import tempfile
import shutil
import inspect
import logging
from typing import Optional, List, Dict, Any

import gradio as gr
import spaces
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import AutoProcessor

# Prefer the new class name if your transformers is recent; fall back to the old alias.
try:
    from transformers import AutoModelForImageTextToText as VisionTextModelClass
except Exception:
    from transformers import AutoModelForVision2Seq as VisionTextModelClass  # deprecated alias

from qwen_vl_utils import process_vision_info
logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
logger = logging.getLogger(__name__)

# ---------------------------
# Config
# ---------------------------
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B-LoRA")
# ZeroGPU allocation per request, in seconds; increase for slow cold starts (first model load)
ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "15"))
# Deterministic decoding for eval; tweak as needed
GEN_KW = dict(
    max_new_tokens=64,
    do_sample=False,
    temperature=0.0,
    top_p=1.0,
    repetition_penalty=1.02,
)
SYSTEM_PROMPT = (
    "You are a dermatology assistant. First, look carefully at the IMAGE.\n"
    "If the image is NOT a close-up of human skin or a dermatologic lesion, "
    "respond EXACTLY with: 'The image does not appear to show a skin condition; I cannot analyze it.' "
    "Do not invent findings.\n"
    "If it IS a skin/lesion photo, provide a concise description, likely differentials (3–5), "
    "and prudent next steps. Avoid definitive diagnoses and include red flags briefly."
)
# ---------------------------
# Processor (CPU only; safe at import time)
# ---------------------------
def _load_multimodal_processor() -> AutoProcessor:
    logger.info(f"Loading multimodal processor from base: {BASE_MODEL_ID}")
    proc = AutoProcessor.from_pretrained(
        BASE_MODEL_ID,
        trust_remote_code=True,
        use_fast=False,  # ensure multimodal __call__(images=...) works
    )
    # Sanity check: the processor must be multimodal (accept images= and carry an image processor)
    try:
        accepts_images = "images" in inspect.signature(proc.__call__).parameters
    except (TypeError, ValueError):
        accepts_images = hasattr(proc, "image_processor")
    if not accepts_images or not hasattr(proc, "image_processor"):
        raise RuntimeError(
            "Loaded processor is not multimodal. Ensure transformers>=4.44.2, qwen-vl-utils>=0.0.8, torch>=2.2."
        )
    # optional: stabilize pixel hints
    try:
        proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", str(256 * 28 * 28)))  # ~0.2MP
        proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", str(256 * 28 * 28)))
    except Exception:
        pass
    logger.info(f"Processor ready: {proc.__class__.__name__}")
    return proc


processor = _load_multimodal_processor()
# ---------------------------
# LoRA adapter cache & sanitize (CPU-only, startup)
# ---------------------------
def _sanitize_adapter_repo(src_dir: str) -> str:
    """Remove unknown keys from adapter_config.json so PEFT can parse."""
    cfg_path = os.path.join(src_dir, "adapter_config.json")
    if not os.path.isfile(cfg_path):
        return src_dir
    with open(cfg_path, "r") as f:
        cfg = json.load(f)
    allowed = {
        "peft_type", "task_type",
        "r", "lora_alpha", "lora_dropout",
        "target_modules", "bias",
        "inference_mode",
        "base_model_name_or_path",
        "fan_in_fan_out",
        "modules_to_save",
        "layers_to_transform",
        "layers_pattern",
        "use_rslora",
        "rank_dropout", "module_dropout",
        "init_lora_weights",
        "use_dora",
    }
    # If DoRA isn't actually used, remove its block
    if str(cfg.get("use_dora", "false")).lower() in ("false", "0", "no"):
        cfg.pop("dora_config", None)
    # Drop unknown top-level keys (e.g., 'corda_config', 'eva_config', etc.)
    for k in list(cfg.keys()):
        if k not in allowed:
            cfg.pop(k, None)
    cfg.setdefault("peft_type", "LORA")
    cfg.setdefault("task_type", "CAUSAL_LM")
    cfg.setdefault("bias", "none")
    cfg.setdefault("inference_mode", True)
    # Normalize booleans if strings
    for k in ("inference_mode", "use_rslora", "use_dora", "fan_in_fan_out"):
        if k in cfg and isinstance(cfg[k], str):
            cfg[k] = cfg[k].lower() in ("true", "1", "yes")
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)
    return src_dir
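
# For reference, a sanitized adapter_config.json ends up containing only the allow-listed
# keys above, e.g. (illustrative values, not the actual adapter's hyperparameters):
# {
#   "peft_type": "LORA",
#   "task_type": "CAUSAL_LM",
#   "base_model_name_or_path": "Qwen/Qwen2.5-VL-3B-Instruct",
#   "r": 16,
#   "lora_alpha": 32,
#   "lora_dropout": 0.05,
#   "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
#   "bias": "none",
#   "inference_mode": true
# }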

logger.info(f"Downloading/caching LoRA adapters: {ADAPTER_ID}")
_ADAPTER_LOCAL = snapshot_download(ADAPTER_ID)  # cached under the default HF hub cache dir
_ADAPTER_LOCAL = _sanitize_adapter_repo(_ADAPTER_LOCAL)
logger.info(f"Adapters ready at: {_ADAPTER_LOCAL}")
# ---------------------------
# Helpers
# ---------------------------
def _messages(image: Image.Image, question: str):
    if image.mode != "RGB":
        image = image.convert("RGB")
    return [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image", "image": image},
                                     {"type": "text", "text": question}]},
    ]


def build_inputs(image: Image.Image, question: str):
    msgs = _messages(image, question)
    text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(msgs)
    return processor(text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt")


def _pad_token_id(model):
    tid = getattr(getattr(processor, "tokenizer", None), "eos_token_id", None)
    return tid if tid is not None else (getattr(getattr(model, "config", None), "eos_token_id", 0) or 0)
def _generate_text(model, inputs: Dict[str, Any]) -> str:
    # move tensors to model device
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    with torch.no_grad():
        out_ids = model.generate(**inputs, **GEN_KW, pad_token_id=_pad_token_id(model))
    # trim prompt tokens, keeping only the generated continuation
    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
    text = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return text
def format_derm_disclaimer(ans: str) -> str:
    return (
        ans
        + "\n\n---\n"
        "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
        "Consult a qualified dermatologist for diagnosis and treatment._"
    )
def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
    logger.info(f"Loading BASE on GPU: {BASE_MODEL_ID}")
    base = VisionTextModelClass.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=dtype,
        device_map="cuda",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    logger.info(f"Attaching LoRA adapters from: {_ADAPTER_LOCAL}")
    model = PeftModel.from_pretrained(base, _ADAPTER_LOCAL, is_trainable=False)
    model.eval()
    return model
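
# Note: if per-request latency matters, the attached adapters could optionally be folded into the
# base weights with `model = model.merge_and_unload()` (a standard PEFT call) before generation;
# this sketch keeps them separate, matching the per-request attach-and-discard flow used below.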
# ---------------------------
# Inference (ZeroGPU-safe: only here we touch CUDA)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)  # request a ZeroGPU slice for the duration of this call
def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
    if image is None:
        return "❌ Please upload an image first."
    model = None
    try:
        inputs = build_inputs(image, question)
        # pick fp16; bf16 also works on newer GPUs
        model = _load_base_plus_lora(dtype=torch.float16)
        text = _generate_text(model, inputs)
        return format_derm_disclaimer(text)
    except Exception as e:
        logger.exception("Error during inference")
        return f"❌ Error analyzing image: {e}"
    finally:
        if model is not None:
            del model
        torch.cuda.empty_cache()

# ---------------------------
# Batched inference API (hidden; call via /analyze_batch)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)  # one GPU slice for the whole batch; raise ZGPU_DURATION for large batches
def analyze_batch(samples: List[Dict[str, Any]]) -> List[str]:
    """
    samples: list of dicts like: {"image": <PIL.Image or filepath>, "question": <str>}
    Returns a list of responses (same order).
    """
    outs: List[str] = []
    if not isinstance(samples, list):
        return ["❌ Invalid payload: expected a JSON list of {image, question} dicts."]
    model = None
    try:
        model = _load_base_plus_lora(dtype=torch.float16)
        for ex in samples:
            try:
                img = ex.get("image")
                q = ex.get("question") or "Describe this skin condition in detail and suggest possible next steps."
                # If the client sent a path (e.g., via gradio_client handle_file), load it:
                if isinstance(img, str) and os.path.isfile(img):
                    img = Image.open(img).convert("RGB")
                if not isinstance(img, Image.Image):
                    outs.append("❌ Missing/invalid image")
                    continue
                inputs = build_inputs(img, q)
                text = _generate_text(model, inputs)
                outs.append(format_derm_disclaimer(text))
            except Exception as ie:
                logger.exception("Error on one batch item")
                outs.append(f"❌ Error analyzing one item: {ie}")
        return outs
    except Exception as e:
        logger.exception("Batch inference failed")
        return [f"❌ Batch error: {e}"]
    finally:
        if model is not None:
            del model
        torch.cuda.empty_cache()

# ---------------------------
# UI
# ---------------------------
def create_interface() -> gr.Blocks:
    with gr.Blocks(title="Dermatology AI Assistant") as demo:
        gr.Markdown(
            "# 🩺 Dermatology AI Assistant\n"
            "Upload a skin photo and ask a question. The model will provide an informational response."
        )
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
            question_input = gr.Textbox(
                label="Question / Prompt",
                value="Describe this skin condition in detail and suggest possible next steps.",
                lines=3,
            )
        with gr.Row():
            submit_btn = gr.Button("Analyze", variant="primary")
            clear_btn = gr.Button("Clear")
        output_box = gr.Textbox(label="Response", lines=16, show_copy_button=True)

        submit_btn.click(
            fn=analyze_skin_condition,
            inputs=[image_input, question_input],
            outputs=output_box,
            queue=True,
            api_name="analyze_skin_condition",  # public API for single requests
        )
        clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, question_input])
        # Hidden components that keep a batch API route alive (/analyze_batch) without showing in the UI
        with gr.Row(visible=False):
            batch_in = gr.JSON(label="samples")
            batch_out = gr.JSON(label="responses")
            batch_btn = gr.Button("Run batch", visible=False)
        batch_btn.click(
            fn=analyze_batch,
            inputs=batch_in,
            outputs=batch_out,
            api_name="analyze_batch",  # call this from gradio_client
        )
        demo.queue()
        gr.Markdown(
            "_Tips: Ensure good lighting and focus. Avoid uploading personally identifying information._"
        )
    return demo

def main():
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        inbrowser=False,
        quiet=False,
        ssr_mode=False,  # no Node requirement
    )

if __name__ == "__main__":
    main()
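
# ---------------------------------------------------------------------------
# Client-side usage (illustrative sketch, not part of the Space itself).
# Assumes a hypothetical Space repo id "your-username/Dermatology-AI-Assistant";
# substitute your own. The single-image route uploads the file via gradio_client;
# the batch route takes a JSON list, so any "image" paths in it must already be
# readable on the server side.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("your-username/Dermatology-AI-Assistant")
#
#   # Single image through the public /analyze_skin_condition route
#   answer = client.predict(
#       handle_file("lesion.jpg"),
#       "Describe this skin condition in detail and suggest possible next steps.",
#       api_name="/analyze_skin_condition",
#   )
#
#   # Batched evaluation through the hidden /analyze_batch route
#   answers = client.predict(
#       [{"image": "/path/visible/to/the/server.jpg", "question": "What does this look like?"}],
#       api_name="/analyze_batch",
#   )
# ---------------------------------------------------------------------------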