# app.py
# Dermatology-AI-Assistant — HF Spaces (ZeroGPU, Qwen2.5-VL + LoRA adapters)
# - Normal UI for single-image analysis
# - Hidden API endpoint /analyze_batch for batched evaluation
# - Caches & sanitizes LoRA repo once at startup (CPU); attaches on GPU per request
# - No CUDA at import-time; ZeroGPU only inside @spaces.GPU functions

import os
import json
import inspect
import logging
from typing import Optional, List, Dict, Any

import gradio as gr
import spaces
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import AutoProcessor

# Prefer the new class name if your transformers is recent; fall back to old alias.
try:
    from transformers import AutoModelForImageTextToText as VisionTextModelClass
except Exception:
    from transformers import AutoModelForVision2Seq as VisionTextModelClass  # deprecated alias

from qwen_vl_utils import process_vision_info

logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
logger = logging.getLogger(__name__)

# ---------------------------
# Config
# ---------------------------
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "Qwen/Qwen2.5-VL-3B-Instruct")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "ColdSlim/Dermatology-Qwen2.5-VL-3B-LoRA")

# Give ourselves more time for first load on cold starts (the ZeroGPU default is 60 s;
# 15 s is not enough to load a 3B model plus adapters and generate).
ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", "120"))  # seconds

# Deterministic (greedy) decoding for eval; temperature/top_p are omitted because
# transformers warns when sampling params are set alongside do_sample=False. Tweak as needed.
GEN_KW = dict(
    max_new_tokens=64,
    do_sample=False,
    repetition_penalty=1.02,
)

SYSTEM_PROMPT = (
    "You are a dermatology assistant. First, look carefully at the IMAGE.\n"
    "If the image is NOT a close-up of human skin or a dermatologic lesion, "
    "respond EXACTLY with: 'The image does not appear to show a skin condition; I cannot analyze it.' "
    "Do not invent findings.\n"
    "If it IS a skin/lesion photo, provide a concise description, likely differentials (3–5), "
    "and prudent next steps. Avoid definitive diagnoses and include red flags briefly."
)

# ---------------------------
# Processor (CPU only; safe at import time)
# ---------------------------
def _load_multimodal_processor() -> AutoProcessor:
    logger.info(f"Loading multimodal processor from base: {BASE_MODEL_ID}")
    proc = AutoProcessor.from_pretrained(
        BASE_MODEL_ID,
        trust_remote_code=True,
        use_fast=False,  # ensure multimodal __call__(images=...) works
    )
    # Sanity check: a multimodal processor must accept images in __call__.
    try:
        accepts_images = "images" in inspect.signature(proc.__call__).parameters
    except (TypeError, ValueError):
        accepts_images = hasattr(proc, "image_processor")
    if not accepts_images or not hasattr(proc, "image_processor"):
        raise RuntimeError(
            "Loaded processor is not multimodal. Ensure transformers>=4.44.2, qwen-vl-utils>=0.0.8, torch>=2.2."
        )
    # optional: pin the visual token budget (min == max fixes input at ~0.2 MP)
    try:
        proc.image_processor.max_pixels = int(os.environ.get("QWEN_MAX_PIXELS", str(256 * 28 * 28)))  # ~0.2MP
        proc.image_processor.min_pixels = int(os.environ.get("QWEN_MIN_PIXELS", str(256 * 28 * 28)))
    except Exception:
        pass
    logger.info(f"Processor ready: {proc.__class__.__name__}")
    return proc

processor = _load_multimodal_processor()

# ---------------------------
# LoRA adapter cache & sanitize (CPU-only, startup)
# ---------------------------
def _sanitize_adapter_repo(src_dir: str) -> str:
    """Remove unknown keys from adapter_config.json so PEFT can parse."""
    cfg_path = os.path.join(src_dir, "adapter_config.json")
    if not os.path.isfile(cfg_path):
        return src_dir

    with open(cfg_path, "r") as f:
        cfg = json.load(f)

    allowed = {
        "peft_type", "task_type",
        "r", "lora_alpha", "lora_dropout",
        "target_modules", "bias",
        "inference_mode",
        "base_model_name_or_path",
        "fan_in_fan_out",
        "modules_to_save",
        "layers_to_transform",
        "layers_pattern",
        "use_rslora",
        "rank_dropout", "module_dropout",
        "init_lora_weights",
        "use_dora",
    }

    # If DoRA isn't actually used, remove its block
    if str(cfg.get("use_dora", "false")).lower() in ("false", "0", "no"):
        cfg.pop("dora_config", None)

    # Drop unknown top-level keys (e.g., 'corda_config', 'eva_config', etc.)
    for k in list(cfg.keys()):
        if k not in allowed:
            cfg.pop(k, None)

    cfg.setdefault("peft_type", "LORA")
    cfg.setdefault("task_type", "CAUSAL_LM")
    cfg.setdefault("bias", "none")
    cfg.setdefault("inference_mode", True)

    # Normalize booleans if strings
    for k in ("inference_mode", "use_rslora", "use_dora", "fan_in_fan_out"):
        if k in cfg and isinstance(cfg[k], str):
            cfg[k] = cfg[k].lower() in ("true", "1", "yes")

    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)
    return src_dir
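
# Illustration (hypothetical config) of what _sanitize_adapter_repo does: a repo whose
# adapter_config.json was written by a newer PEFT release and carries extra keys, e.g.
#     {"r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"],
#      "corda_config": null, "eva_config": null}
# is rewritten in place to keep only the allowed keys listed above, so that an older
# PEFT version does not raise on unknown config fields.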

logger.info(f"Downloading/caching LoRA adapters: {ADAPTER_ID}")
_ADAPTER_LOCAL = snapshot_download(ADAPTER_ID)  # local_dir_use_symlinks is deprecated and a no-op without local_dir
_ADAPTER_LOCAL = _sanitize_adapter_repo(_ADAPTER_LOCAL)
logger.info(f"Adapters ready at: {_ADAPTER_LOCAL}")

# ---------------------------
# Helpers
# ---------------------------
def _messages(image: Image.Image, question: str):
    if image.mode != "RGB":
        image = image.convert("RGB")
    return [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image", "image": image},
                                     {"type": "text", "text": question}]},
    ]

def build_inputs(image: Image.Image, question: str):
    msgs = _messages(image, question)
    text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(msgs)
    return processor(text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt")

def _pad_token_id(model):
    tid = getattr(getattr(processor, "tokenizer", None), "eos_token_id", None)
    return tid if tid is not None else (getattr(getattr(model, "config", None), "eos_token_id", 0) or 0)

def _generate_text(model, inputs: Dict[str, Any]) -> str:
    # move tensors to model device
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    with torch.no_grad():
        out_ids = model.generate(**inputs, **GEN_KW, pad_token_id=_pad_token_id(model))
    # strip the prompt tokens so only the newly generated text is decoded
    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
    text = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return text

def format_derm_disclaimer(ans: str) -> str:
    return (
        ans
        + "\n\n---\n"
          "_Disclaimer: This AI is not a medical device. The output is informational and may be inaccurate. "
          "Consult a qualified dermatologist for diagnosis and treatment._"
    )

def _load_base_plus_lora(dtype: torch.dtype = torch.float16):
    logger.info(f"Loading BASE on GPU: {BASE_MODEL_ID}")
    base = VisionTextModelClass.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=dtype,
        device_map="cuda",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    logger.info(f"Attaching LoRA adapters from: {_ADAPTER_LOCAL}")
    model = PeftModel.from_pretrained(base, _ADAPTER_LOCAL, is_trainable=False)
    model.eval()
    return model
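
# Note (design alternative, not used here): if attaching adapters per request ever
# becomes a bottleneck, PEFT can bake the LoRA weights into the base model instead.
# Sketch, assuming enough GPU memory for the merged fp16 weights:
#
#   merged = PeftModel.from_pretrained(base, _ADAPTER_LOCAL).merge_and_unload()
#   merged.eval()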

# ---------------------------
# Inference (ZeroGPU-safe: only here we touch CUDA)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)
def analyze_skin_condition(image: Optional[Image.Image], question: str) -> str:
    if image is None:
        return "❌ Please upload an image first."
    model = None
    try:
        inputs = build_inputs(image, question)
        # pick fp16; bf16 also works on newer GPUs
        model = _load_base_plus_lora(dtype=torch.float16)
        text = _generate_text(model, inputs)
        return format_derm_disclaimer(text)
    except Exception as e:
        logger.exception("Error during inference")
        return f"❌ Error analyzing image: {e}"
    finally:
        if model is not None:
            del model
        torch.cuda.empty_cache()
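
# Minimal client sketch for the single-image route (the Space id below is a
# placeholder; requires gradio_client):
#
#   from gradio_client import Client, handle_file
#   client = Client("your-username/Dermatology-AI-Assistant")
#   answer = client.predict(
#       handle_file("lesion.jpg"),
#       "Describe this skin condition in detail and suggest possible next steps.",
#       api_name="/analyze_skin_condition",
#   )
#   print(answer)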

# ---------------------------
# Batched inference API (hidden; call via /analyze_batch)
# ---------------------------
@spaces.GPU(duration=ZGPU_DURATION)
def analyze_batch(samples: List[Dict[str, Any]]) -> List[str]:
    """
    samples: list of dicts like: {"image": <PIL/Image or filepath>, "question": <str>}
    Returns a list of responses (same order).
    """
    outs: List[str] = []
    if not isinstance(samples, list):
        return ["❌ Invalid payload: expected a JSON list of {image, question} dicts."]
    model = None
    try:
        model = _load_base_plus_lora(dtype=torch.float16)
        for ex in samples:
            try:
                img = ex.get("image")
                q = ex.get("question") or "Describe this skin condition in detail and suggest possible next steps."
                # If the client sent a path (e.g., via gradio_client handle_file), load it:
                if isinstance(img, str) and os.path.isfile(img):
                    img = Image.open(img).convert("RGB")
                if not isinstance(img, Image.Image):
                    outs.append("❌ Missing/invalid image")
                    continue
                inputs = build_inputs(img, q)
                text = _generate_text(model, inputs)
                outs.append(format_derm_disclaimer(text))
            except Exception as ie:
                logger.exception("Error on one batch item")
                outs.append(f"❌ Error analyzing one item: {ie}")
        return outs
    except Exception as e:
        logger.exception("Batch inference failed")
        return [f"❌ Batch error: {e}"]
    finally:
        if model is not None:
            del model
        torch.cuda.empty_cache()
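
# Minimal client sketch for the batch route (Space id is a placeholder). The server
# accepts filepath strings for "image"; whether file handles uploaded through a JSON
# component reach the server as readable paths depends on the gradio_client version,
# so paths readable by the server process are the safe case:
#
#   from gradio_client import Client
#   client = Client("your-username/Dermatology-AI-Assistant")
#   responses = client.predict(
#       [{"image": "/data/a.jpg", "question": "What is this lesion?"},
#        {"image": "/data/b.jpg", "question": "Benign or concerning?"}],
#       api_name="/analyze_batch",
#   )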

# ---------------------------
# UI
# ---------------------------
def create_interface() -> gr.Blocks:
    with gr.Blocks(title="Dermatology AI Assistant") as demo:
        gr.Markdown(
            "# 🩺 Dermatology AI Assistant\n"
            "Upload a skin photo and ask a question. The model will provide an informational response."
        )

        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
            question_input = gr.Textbox(
                label="Question / Prompt",
                value="Describe this skin condition in detail and suggest possible next steps.",
                lines=3,
            )

        with gr.Row():
            submit_btn = gr.Button("Analyze", variant="primary")
            clear_btn = gr.Button("Clear")

        output_box = gr.Textbox(label="Response", lines=16, show_copy_button=True)

        submit_btn.click(
            fn=analyze_skin_condition,
            inputs=[image_input, question_input],
            outputs=output_box,
            queue=True,
            api_name="analyze_skin_condition",  # public API for single requests
        )
        clear_btn.click(
            fn=lambda: (None, "", ""),
            inputs=None,
            outputs=[image_input, question_input, output_box],  # also clear the previous response
        )

        # Hidden components expose the batch API route without rendering any UI
        # (gr.Interface does not accept visible=, so nesting one here would fail).
        with gr.Row(visible=False):
            batch_in = gr.JSON(label="samples")
            batch_out = gr.JSON(label="responses")
            batch_btn = gr.Button("Run batch", visible=False)
        batch_btn.click(
            fn=analyze_batch,
            inputs=batch_in,
            outputs=batch_out,
            api_name="analyze_batch",  # call this from gradio_client
        )

        demo.queue()
        gr.Markdown(
            "_Tips: Ensure good lighting and focus. Avoid uploading personally identifying information._"
        )
    return demo

def main():
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        inbrowser=False,
        quiet=False,
        ssr_mode=False,  # no Node requirement
    )

if __name__ == "__main__":
    main()