"""
EWAAST: Monk Skin Tone (MST) Classifier

Uses MedGemma's VQA capabilities to classify patient skin tone
on the 10-point Monk Skin Tone Scale.
"""

from enum import Enum
from dataclasses import dataclass
from typing import Optional

from PIL import Image


class MSTCategory(Enum):
    """Monk Skin Tone categories for clinical guidance."""

    LIGHT = "light"    # MST 1-3
    MEDIUM = "medium"  # MST 4-7
    DEEP = "deep"      # MST 8-10


@dataclass
class MSTResult:
    """Result of Monk Skin Tone classification."""

    value: int  # MST value on the 1-10 scale
    category: MSTCategory  # Coarse bucket derived from ``value``
    confidence: float  # 0.0 is used by the placeholder classifier
    visual_guidance: str  # Clinical examination hints for ``category``

    @property
    def description(self) -> str:
        """Human-readable description, e.g. ``"Medium (MST 5)"``."""
        # The enum values are the lowercase labels ("light", "medium",
        # "deep"), so capitalising them yields the display label.
        return f"{self.category.value.capitalize()} (MST {self.value})"


class MSTClassifier:
    """
    Classifier for Monk Skin Tone (MST) Scale.

    The MST Scale is a 10-point skin tone representation developed by
    Dr. Ellis Monk to improve AI fairness across diverse skin tones.

    Usage:
        classifier = MSTClassifier()
        result = classifier.classify(image)
        print(f"Detected: {result.description}")
    """

    # Inclusive bounds of the Monk Skin Tone scale.
    MST_MIN = 1
    MST_MAX = 10

    # Clinical visual guidance based on MST category
    VISUAL_GUIDANCE = {
        MSTCategory.LIGHT: (
            "Look for: Non-blanchable erythema (redness), warmth, "
            "pallor for ischemia. Inflammatory signs typically present as red."
        ),
        MSTCategory.MEDIUM: (
            "Look for: Subtle color changes (slightly darker/redder than surrounding skin), "
            "warmth, shiny or taut skin. Erythema may not be bright red."
        ),
        MSTCategory.DEEP: (
            "Look for: Purple, blue, or ashen discoloration (NOT redness), "
            "induration (hardness), localized heat, edema. "
            "Stage 1 pressure ulcers may appear as persistent violet/maroon areas."
        ),
    }

    def __init__(self, model_name: str = "google/medgemma-1.5-4b-it"):
        """
        Initialize the MST Classifier.

        No model is loaded here; ``model`` and ``processor`` start as
        ``None`` so that loading can be deferred.

        Args:
            model_name: HuggingFace model ID for MedGemma
        """
        self.model_name = model_name
        self.model = None  # Lazy loading
        self.processor = None

    def _load_model(self) -> None:
        """Load MedGemma model for VQA tasks.

        Currently a no-op stub: it leaves ``model`` and ``processor``
        untouched.
        """
        # TODO: Implement actual model loading
        # from transformers import AutoProcessor, AutoModelForVision2Seq
        # self.processor = AutoProcessor.from_pretrained(self.model_name)
        # self.model = AutoModelForVision2Seq.from_pretrained(self.model_name)

    def _get_category(self, mst_value: int) -> MSTCategory:
        """Determine MST category from numeric value.

        Args:
            mst_value: Monk Skin Tone value (1-10)

        Returns:
            LIGHT for 1-3, MEDIUM for 4-7, DEEP for 8-10.

        Raises:
            ValueError: If ``mst_value`` is outside the 1-10 scale.
        """
        # Reject out-of-scale values instead of silently bucketing them:
        # returning clinical guidance for e.g. 0 or 42 would hide an
        # upstream parsing or data-entry error.
        if not self.MST_MIN <= mst_value <= self.MST_MAX:
            raise ValueError(
                f"MST value must be between {self.MST_MIN} and "
                f"{self.MST_MAX}, got {mst_value!r}"
            )
        if mst_value <= 3:
            return MSTCategory.LIGHT
        if mst_value <= 7:
            return MSTCategory.MEDIUM
        return MSTCategory.DEEP

    def classify(self, image: Image.Image) -> MSTResult:
        """
        Classify the skin tone of the patient in the image.

        Intended to use the healthy skin visible around the wound area to
        determine the patient's Monk Skin Tone value.

        NOTE: This is currently a placeholder. The image is not inspected;
        the result is always MST 5 (medium) with ``confidence`` 0.0, which
        callers can use to detect that no real classification happened.

        Args:
            image: PIL Image containing the wound and surrounding skin

        Returns:
            MSTResult with value (1-10), category, and visual guidance
        """
        # TODO: Implement actual classification using MedGemma VQA
        # Prompt: "What is the Monk Skin Tone (1-10) of the patient's skin
        #          visible in this image? Only return a number 1-10."

        # Placeholder: Return middle value with medium confidence
        mst_value = 5
        confidence = 0.0  # Indicates placeholder

        category = self._get_category(mst_value)
        return MSTResult(
            value=mst_value,
            category=category,
            confidence=confidence,
            visual_guidance=self.VISUAL_GUIDANCE[category],
        )

    def get_guidance_for_mst(self, mst_value: int) -> str:
        """
        Get clinical visual guidance for a given MST value.

        Args:
            mst_value: Monk Skin Tone value (1-10)

        Returns:
            String with visual examination guidance

        Raises:
            ValueError: If ``mst_value`` is outside the 1-10 scale.
        """
        return self.VISUAL_GUIDANCE[self._get_category(mst_value)]