Spaces:
Sleeping
Sleeping
MedGemma prompt: decision-oriented opening, 3-sentence cap, no numbered lists, no filler
755e948 verified | """ | |
| MedGemma-powered clinical explanation generator. | |
| Uses MedGemma (google/medgemma-4b-it) to produce natural-language | |
| clinical assessments for skin lesion images, enhanced with | |
| classification context from MedSigLIP. | |
| Requires a HuggingFace token with access to MedGemma. | |
| Set the ``HF_TOKEN`` environment variable or pass it to | |
| ``huggingface_hub.login()`` before calling ``load_model()``. | |
| """ | |
| import os | |
| import torch | |
| from PIL import Image | |
| from pathlib import Path | |
| CLASS_INFO = { | |
| "akiec": { | |
| "full_name": "Actinic Keratosis / Intraepithelial Carcinoma", | |
| "risk_level": "MODERATE", | |
| "description": "precancerous scaly lesion caused by sun damage", | |
| "action": "Dermatology referral within 2-4 weeks for evaluation and possible treatment", | |
| }, | |
| "bcc": { | |
| "full_name": "Basal Cell Carcinoma", | |
| "risk_level": "HIGH", | |
| "description": "most common form of skin cancer, typically slow-growing", | |
| "action": "Dermatology referral within 2 weeks for biopsy and treatment planning", | |
| }, | |
| "bkl": { | |
| "full_name": "Benign Keratosis", | |
| "risk_level": "LOW", | |
| "description": "non-cancerous growth including seborrheic keratosis", | |
| "action": "Routine monitoring; removal only if symptomatic or cosmetically desired", | |
| }, | |
| "df": { | |
| "full_name": "Dermatofibroma", | |
| "risk_level": "LOW", | |
| "description": "benign fibrous skin nodule", | |
| "action": "No treatment required; reassure patient", | |
| }, | |
| "mel": { | |
| "full_name": "Melanoma", | |
| "risk_level": "URGENT", | |
| "description": "potentially deadly form of skin cancer requiring immediate attention", | |
| "action": "URGENT dermatology referral within 48 hours; do not delay", | |
| }, | |
| "nv": { | |
| "full_name": "Melanocytic Nevus", | |
| "risk_level": "LOW", | |
| "description": "common benign mole", | |
| "action": "Routine monitoring; educate patient on ABCDE warning signs", | |
| }, | |
| "vasc": { | |
| "full_name": "Vascular Lesion", | |
| "risk_level": "LOW", | |
| "description": "benign blood vessel abnormality such as angioma", | |
| "action": "No treatment required unless symptomatic", | |
| }, | |
| } | |
| class MedGemmaExplainer: | |
| """Lazy-loaded MedGemma explainer for clinical skin lesion analysis.""" | |
| def __init__(self): | |
| self.model = None | |
| self.processor = None | |
| def load_model(self): | |
| """Load MedGemma (idempotent).""" | |
| if self.model is not None: | |
| return | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
| self.model = AutoModelForImageTextToText.from_pretrained( | |
| "google/medgemma-4b-it", | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto", | |
| ) | |
| self.processor = AutoProcessor.from_pretrained("google/medgemma-4b-it") | |
| def generate_explanation(self, image, predicted_class, confidence, uncertainty=None): | |
| """Generate a clinical explanation for a classified skin lesion. | |
| Args: | |
| image: PIL Image of the lesion. | |
| predicted_class: One of the CLASS_INFO keys (e.g. ``"mel"``). | |
| confidence: Model confidence in [0, 1]. | |
| uncertainty: Optional uncertainty score. | |
| Returns: | |
| dict with classification details, AI explanation, and recommendation. | |
| """ | |
| self.load_model() | |
| info = CLASS_INFO.get(predicted_class, CLASS_INFO["nv"]) | |
| prompt = ( | |
| "You are a dermatology AI assistant helping primary care physicians triage skin lesions.\n\n" | |
| "Analyze this dermoscopic image and provide a clinical assessment.\n\n" | |
| f"The AI classification system has identified this lesion as: {info['full_name']}\n" | |
| f"Classification confidence: {confidence * 100:.1f}%\n" | |
| ) | |
| if uncertainty is not None: | |
| level = "HIGH - consider expert review" if uncertainty > 0.3 else "LOW" | |
| prompt += f"Uncertainty level: {level}\n" | |
| prompt += ( | |
| "\nPlease provide:\n" | |
| "1. A brief description of the visible dermoscopic features (2-3 sentences)\n" | |
| "2. Whether the AI classification appears consistent with the visual features\n" | |
| "3. Any additional observations relevant to clinical decision-making\n\n" | |
| "Keep your response concise and clinically focused." | |
| ) | |
| messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}] | |
| inputs = self.processor.apply_chat_template( | |
| messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" | |
| ).to(self.model.device, dtype=torch.bfloat16) | |
| input_len = inputs["input_ids"].shape[-1] | |
| with torch.inference_mode(): | |
| generation = self.model.generate(**inputs, max_new_tokens=300, do_sample=False) | |
| generation = generation[0][input_len:] | |
| explanation = self.processor.decode(generation, skip_special_tokens=True) | |
| report = { | |
| "classification": info["full_name"], | |
| "class_code": predicted_class, | |
| "confidence": confidence, | |
| "risk_level": info["risk_level"], | |
| "ai_explanation": explanation, | |
| "recommended_action": info["action"], | |
| "clinical_description": info["description"], | |
| } | |
| if uncertainty is not None: | |
| report["uncertainty"] = uncertainty | |
| report["uncertainty_note"] = ( | |
| "HIGH uncertainty - recommend expert review regardless of classification" | |
| if uncertainty > 0.3 | |
| else "Uncertainty within acceptable range" | |
| ) | |
| return report | |
| def generate_triage_explanation(self, image, prob_malignant, zone_label, | |
| top_class=None, top_class_prob=None): | |
| """Generate a clinical explanation for a binary triage result. | |
| Designed for the two-step Gradio UI: the triage zone card renders | |
| instantly, then this method is called to fill in the clinical | |
| reasoning (~10-15s). | |
| Args: | |
| image: PIL Image of the lesion. | |
| prob_malignant: Blended malignancy probability in [0, 1]. | |
| zone_label: One of "REFER", "UNCERTAIN", "LOW RISK". | |
| top_class: Optional top-1 predicted class code (e.g. "mel"). | |
| top_class_prob: Optional confidence for top_class in [0, 1]. | |
| Returns: | |
| str — plain-text clinical explanation. | |
| """ | |
| self.load_model() | |
| # Map zone label to a decision-oriented opening phrase. | |
| zone_opening = { | |
| "REFER": "This lesion warrants dermatology referral", | |
| "UNCERTAIN": "This lesion warrants caution and clinical correlation", | |
| "LOW RISK": "This lesion appears low-risk based on visual features", | |
| }.get(zone_label, "This lesion was assessed as " + zone_label) | |
| # Optional specific-diagnosis hint (do not force the model to use it). | |
| diagnosis_hint = "" | |
| if top_class is not None and top_class_prob is not None: | |
| info = CLASS_INFO.get(top_class, {}) | |
| full_name = info.get("full_name", top_class) | |
| if top_class_prob >= 0.30: | |
| diagnosis_hint = ( | |
| f"\nThe specific-class head's most likely diagnosis is {full_name}. " | |
| "Mention this only if the visual features clearly support it." | |
| ) | |
| prompt = ( | |
| "You are a clinical decision-support assistant for primary care physicians " | |
| "reviewing a skin lesion image.\n\n" | |
| f"Triage call: {zone_label}.\n" | |
| f"Open your response with: \"{zone_opening} because...\" and complete the " | |
| "sentence with the 2 or 3 specific visual features that support this call.\n" | |
| f"{diagnosis_hint}\n\n" | |
| "Rules:\n" | |
| "- Maximum 3 short sentences, ~60 words total.\n" | |
| "- Plain clinical voice. No headers, no numbered lists, no bullet points.\n" | |
| "- Do not repeat the malignancy probability number.\n" | |
| "- Do not speculate about anatomical location if not visible.\n" | |
| "- Do not list textbook features; describe what you actually see." | |
| ) | |
| messages = [ | |
| { | |
| "role": "user", | |
| "content": [ | |
| {"type": "image", "image": image}, | |
| {"type": "text", "text": prompt}, | |
| ], | |
| } | |
| ] | |
| inputs = self.processor.apply_chat_template( | |
| messages, add_generation_prompt=True, tokenize=True, | |
| return_dict=True, return_tensors="pt", | |
| ).to(self.model.device, dtype=torch.bfloat16) | |
| input_len = inputs["input_ids"].shape[-1] | |
| with torch.inference_mode(): | |
| generation = self.model.generate( | |
| **inputs, max_new_tokens=160, do_sample=False, | |
| ) | |
| generation = generation[0][input_len:] | |
| return self.processor.decode(generation, skip_special_tokens=True) | |
| def format_report(self, report): | |
| """Format a report dict as a readable clinical summary string.""" | |
| lines = [ | |
| "=" * 60, | |
| "DERMTRIAGE CLINICAL DECISION SUPPORT REPORT", | |
| "=" * 60, | |
| "", | |
| f"CLASSIFICATION: {report['classification']}", | |
| f"RISK LEVEL: {report['risk_level']}", | |
| f"CONFIDENCE: {report['confidence'] * 100:.1f}%", | |
| ] | |
| if "uncertainty" in report: | |
| lines.append(f"UNCERTAINTY: {report['uncertainty']:.2f} - {report['uncertainty_note']}") | |
| lines += [ | |
| "", | |
| "-" * 60, | |
| "AI ANALYSIS:", | |
| "-" * 60, | |
| report["ai_explanation"], | |
| "", | |
| "-" * 60, | |
| "RECOMMENDED ACTION:", | |
| "-" * 60, | |
| report["recommended_action"], | |
| "", | |
| "=" * 60, | |
| "This report is for clinical decision support only.", | |
| "Final diagnosis requires expert dermatologic evaluation.", | |
| "=" * 60, | |
| ] | |
| return "\n".join(lines) | |
| def generate_referral_packet(image_path, classification_result): | |
| """Generate a complete referral packet from an image path and classification result. | |
| Args: | |
| image_path: Path to skin lesion image. | |
| classification_result: dict with ``class``, ``confidence``, and optionally ``uncertainty``. | |
| Returns: | |
| Formatted clinical report string. | |
| """ | |
| image = Image.open(image_path).convert("RGB") | |
| explainer = MedGemmaExplainer() | |
| report = explainer.generate_explanation( | |
| image=image, | |
| predicted_class=classification_result["class"], | |
| confidence=classification_result["confidence"], | |
| uncertainty=classification_result.get("uncertainty"), | |
| ) | |
| return explainer.format_report(report) | |