File size: 11,032 Bytes
ac024f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84842ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755e948
 
 
 
 
 
 
 
 
84842ba
 
 
755e948
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84842ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755e948
84842ba
 
 
 
 
ac024f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""
MedGemma-powered clinical explanation generator.

Uses MedGemma (google/medgemma-4b-it) to produce natural-language
clinical assessments for skin lesion images, enhanced with
classification context from MedSigLIP.

Requires a HuggingFace token with access to MedGemma.
Set the ``HF_TOKEN`` environment variable or pass it to
``huggingface_hub.login()`` before calling ``load_model()``.
"""

import os
import torch
from PIL import Image
from pathlib import Path

# Lookup table keyed by lesion class code. Every entry carries the same four
# fields: a display name, a triage risk level, a short plain-language
# description, and the recommended clinical action.
CLASS_INFO = dict(
    akiec=dict(
        full_name="Actinic Keratosis / Intraepithelial Carcinoma",
        risk_level="MODERATE",
        description="precancerous scaly lesion caused by sun damage",
        action="Dermatology referral within 2-4 weeks for evaluation and possible treatment",
    ),
    bcc=dict(
        full_name="Basal Cell Carcinoma",
        risk_level="HIGH",
        description="most common form of skin cancer, typically slow-growing",
        action="Dermatology referral within 2 weeks for biopsy and treatment planning",
    ),
    bkl=dict(
        full_name="Benign Keratosis",
        risk_level="LOW",
        description="non-cancerous growth including seborrheic keratosis",
        action="Routine monitoring; removal only if symptomatic or cosmetically desired",
    ),
    df=dict(
        full_name="Dermatofibroma",
        risk_level="LOW",
        description="benign fibrous skin nodule",
        action="No treatment required; reassure patient",
    ),
    mel=dict(
        full_name="Melanoma",
        risk_level="URGENT",
        description="potentially deadly form of skin cancer requiring immediate attention",
        action="URGENT dermatology referral within 48 hours; do not delay",
    ),
    nv=dict(
        full_name="Melanocytic Nevus",
        risk_level="LOW",
        description="common benign mole",
        action="Routine monitoring; educate patient on ABCDE warning signs",
    ),
    vasc=dict(
        full_name="Vascular Lesion",
        risk_level="LOW",
        description="benign blood vessel abnormality such as angioma",
        action="No treatment required unless symptomatic",
    ),
)


class MedGemmaExplainer:
    """Lazy-loaded MedGemma explainer for clinical skin lesion analysis.

    The model and processor are loaded on first use (see ``load_model``), so
    constructing an explainer is cheap. Reuse a single instance across calls
    to avoid loading the model more than once.
    """

    # Default HuggingFace model id; override per instance via ``__init__``.
    DEFAULT_MODEL_ID = "google/medgemma-4b-it"
    # dtype used both when loading the weights and when casting float inputs.
    TORCH_DTYPE = torch.bfloat16
    # Uncertainty scores strictly above this value are flagged for expert review.
    UNCERTAINTY_THRESHOLD = 0.3
    # Minimum top-class probability before the specific diagnosis is offered
    # to the model as a hint in ``generate_triage_explanation``.
    DIAGNOSIS_HINT_MIN_PROB = 0.30
    # Decision-oriented opening phrase for each known triage zone label.
    _ZONE_OPENINGS = {
        "REFER": "This lesion warrants dermatology referral",
        "UNCERTAIN": "This lesion warrants caution and clinical correlation",
        "LOW RISK": "This lesion appears low-risk based on visual features",
    }

    def __init__(self, model_id=DEFAULT_MODEL_ID):
        """Create an explainer without loading any weights.

        Args:
            model_id: HuggingFace model id passed to ``from_pretrained``.
        """
        self.model_id = model_id
        self.model = None
        self.processor = None

    def load_model(self):
        """Load MedGemma and its processor (idempotent).

        Both objects are assigned only after both loads succeed. A failed load
        therefore leaves the explainer unloaded, and the next call retries,
        instead of leaving a model with no processor.
        """
        if self.model is not None and self.processor is not None:
            return

        # Imported lazily so importing this module does not require transformers.
        from transformers import AutoProcessor, AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(
            self.model_id,
            torch_dtype=self.TORCH_DTYPE,
            device_map="auto",
        )
        processor = AutoProcessor.from_pretrained(self.model_id)
        self.model = model
        self.processor = processor

    @staticmethod
    def _lookup_class_info(predicted_class):
        """Return the CLASS_INFO entry for *predicted_class*.

        An unrecognized code gets an explicit UNKNOWN entry that recommends
        expert review. Mapping it to a benign class would silently report an
        unknown lesion as low risk.
        """
        info = CLASS_INFO.get(predicted_class)
        if info is not None:
            return info
        return {
            "full_name": f"Unrecognized class '{predicted_class}'",
            "risk_level": "UNKNOWN",
            "description": "classification code not recognized by this explainer",
            "action": "Expert dermatology review recommended; classification code not recognized",
        }

    def _generate(self, image, prompt, max_new_tokens):
        """Run one greedy image+text generation and return the decoded reply.

        Loads the model on demand. Only newly generated tokens are decoded,
        and surrounding whitespace is stripped.
        """
        self.load_model()

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt",
        ).to(self.model.device, dtype=self.TORCH_DTYPE)

        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            output = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens, do_sample=False,
            )
        # generate() returns prompt + continuation; keep only the continuation.
        new_tokens = output[0][input_len:]
        return self.processor.decode(new_tokens, skip_special_tokens=True).strip()

    def generate_explanation(self, image, predicted_class, confidence, uncertainty=None):
        """Generate a clinical explanation for a classified skin lesion.

        Args:
            image: PIL Image of the lesion.
            predicted_class: One of the CLASS_INFO keys (e.g. ``"mel"``). Unknown
                codes are reported with risk level ``"UNKNOWN"``.
            confidence: Model confidence in [0, 1].
            uncertainty: Optional uncertainty score. Values above
                ``UNCERTAINTY_THRESHOLD`` are flagged for expert review.

        Returns:
            dict with classification details, AI explanation, and recommendation.
        """
        info = self._lookup_class_info(predicted_class)
        high_uncertainty = (
            uncertainty is not None and uncertainty > self.UNCERTAINTY_THRESHOLD
        )

        prompt = (
            "You are a dermatology AI assistant helping primary care physicians triage skin lesions.\n\n"
            "Analyze this dermoscopic image and provide a clinical assessment.\n\n"
            f"The AI classification system has identified this lesion as: {info['full_name']}\n"
            f"Classification confidence: {confidence * 100:.1f}%\n"
        )
        if uncertainty is not None:
            level = "HIGH - consider expert review" if high_uncertainty else "LOW"
            prompt += f"Uncertainty level: {level}\n"
        prompt += (
            "\nPlease provide:\n"
            "1. A brief description of the visible dermoscopic features (2-3 sentences)\n"
            "2. Whether the AI classification appears consistent with the visual features\n"
            "3. Any additional observations relevant to clinical decision-making\n\n"
            "Keep your response concise and clinically focused."
        )

        explanation = self._generate(image, prompt, max_new_tokens=300)

        report = {
            "classification": info["full_name"],
            "class_code": predicted_class,
            "confidence": confidence,
            "risk_level": info["risk_level"],
            "ai_explanation": explanation,
            "recommended_action": info["action"],
            "clinical_description": info["description"],
        }

        if uncertainty is not None:
            report["uncertainty"] = uncertainty
            report["uncertainty_note"] = (
                "HIGH uncertainty - recommend expert review regardless of classification"
                if high_uncertainty
                else "Uncertainty within acceptable range"
            )
        return report

    def generate_triage_explanation(self, image, prob_malignant, zone_label,
                                    top_class=None, top_class_prob=None):
        """Generate a clinical explanation for a binary triage result.

        Designed for the two-step Gradio UI: the triage zone card renders
        instantly, then this method is called to fill in the clinical
        reasoning (~10-15s).

        Args:
            image: PIL Image of the lesion.
            prob_malignant: Blended malignancy probability in [0, 1]. Not sent
                to the model; the prompt tells it not to repeat the number.
            zone_label: One of "REFER", "UNCERTAIN", "LOW RISK".
            top_class: Optional top-1 predicted class code (e.g. "mel").
            top_class_prob: Optional confidence for top_class in [0, 1].

        Returns:
            str — plain-text clinical explanation.
        """
        zone_opening = self._ZONE_OPENINGS.get(
            zone_label, f"This lesion was assessed as {zone_label}"
        )

        # Optional specific-diagnosis hint (do not force the model to use it).
        diagnosis_hint = ""
        if (
            top_class is not None
            and top_class_prob is not None
            and top_class_prob >= self.DIAGNOSIS_HINT_MIN_PROB
        ):
            full_name = CLASS_INFO.get(top_class, {}).get("full_name", top_class)
            diagnosis_hint = (
                f"\nThe specific-class head's most likely diagnosis is {full_name}. "
                "Mention this only if the visual features clearly support it."
            )

        prompt = (
            "You are a clinical decision-support assistant for primary care physicians "
            "reviewing a skin lesion image.\n\n"
            f"Triage call: {zone_label}.\n"
            f"Open your response with: \"{zone_opening} because...\" and complete the "
            "sentence with the 2 or 3 specific visual features that support this call.\n"
            f"{diagnosis_hint}\n\n"
            "Rules:\n"
            "- Maximum 3 short sentences, ~60 words total.\n"
            "- Plain clinical voice. No headers, no numbered lists, no bullet points.\n"
            "- Do not repeat the malignancy probability number.\n"
            "- Do not speculate about anatomical location if not visible.\n"
            "- Do not list textbook features; describe what you actually see."
        )

        return self._generate(image, prompt, max_new_tokens=160)

    def format_report(self, report):
        """Format a report dict (as built by ``generate_explanation``) as text.

        Uncertainty lines are included only when the report has an
        ``uncertainty`` key, which then also requires ``uncertainty_note``.
        """
        lines = [
            "=" * 60,
            "DERMTRIAGE CLINICAL DECISION SUPPORT REPORT",
            "=" * 60,
            "",
            f"CLASSIFICATION: {report['classification']}",
            f"RISK LEVEL: {report['risk_level']}",
            f"CONFIDENCE: {report['confidence'] * 100:.1f}%",
        ]
        if "uncertainty" in report:
            lines.append(f"UNCERTAINTY: {report['uncertainty']:.2f} - {report['uncertainty_note']}")
        lines += [
            "",
            "-" * 60,
            "AI ANALYSIS:",
            "-" * 60,
            report["ai_explanation"],
            "",
            "-" * 60,
            "RECOMMENDED ACTION:",
            "-" * 60,
            report["recommended_action"],
            "",
            "=" * 60,
            "This report is for clinical decision support only.",
            "Final diagnosis requires expert dermatologic evaluation.",
            "=" * 60,
        ]
        return "\n".join(lines)


def generate_referral_packet(image_path, classification_result, explainer=None):
    """Generate a complete referral packet from an image path and classification result.

    Args:
        image_path: Path to the skin lesion image (anything ``PIL.Image.open``
            accepts).
        classification_result: dict with ``class``, ``confidence``, and optionally ``uncertainty``.
        explainer: Optional existing ``MedGemmaExplainer`` to reuse. When omitted,
            a new one is created and MedGemma is loaded for this call. Pass a
            shared instance when producing several packets so the model loads once.

    Returns:
        Formatted clinical report string.
    """
    # The context manager closes the underlying file deterministically;
    # convert() returns an independent in-memory image that outlives the block.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    if explainer is None:
        explainer = MedGemmaExplainer()

    report = explainer.generate_explanation(
        image=image,
        predicted_class=classification_result["class"],
        confidence=classification_result["confidence"],
        uncertainty=classification_result.get("uncertainty"),
    )
    return explainer.format_report(report)