File size: 4,620 Bytes
59d87be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
EWAAST: Monk Skin Tone (MST) Classifier

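Intended to use MedGemma's visual question answering (VQA) capabilities to
classify a patient's skin tone on the 10-point Monk Skin Tone (MST) Scale.
Model inference is not implemented yet; classification currently returns a
fixed placeholder result.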
"""

from enum import Enum
from dataclasses import dataclass
from typing import Optional
from PIL import Image


class MSTCategory(Enum):
    """
    Coarse Monk Skin Tone buckets used to select clinical examination guidance.

    Each member's value is a lowercase string label; the MST values each
    member covers are noted above it.
    """

    # MST values 1 through 3
    LIGHT = "light"
    # MST values 4 through 7
    MEDIUM = "medium"
    # MST values 8 through 10
    DEEP = "deep"


@dataclass
class MSTResult:
    """
    Outcome of a Monk Skin Tone classification.

    Fields are declared in positional-constructor order:
        value: MST scale value, expected to be in the range 1-10.
        category: The MSTCategory bucket for ``value``.
        confidence: Classifier confidence score.
        visual_guidance: Clinical visual-examination guidance for ``category``.
    """
    value: int  # expected range 1-10
    category: MSTCategory
    confidence: float
    visual_guidance: str

    @property
    def description(self) -> str:
        """Return a display label such as ``"Medium (MST 5)"``."""
        label = {
            MSTCategory.LIGHT: "Light",
            MSTCategory.MEDIUM: "Medium",
            MSTCategory.DEEP: "Deep",
        }[self.category]
        return f"{label} (MST {self.value})"


class MSTClassifier:
    """
    Classifier for Monk Skin Tone (MST) Scale.

    The MST Scale is a 10-point skin tone representation developed
    by Dr. Ellis Monk to improve AI fairness across diverse skin tones.

    NOTE: MedGemma inference is not implemented yet. ``classify`` currently
    returns a fixed placeholder result (MST 5, confidence 0.0) and
    ``_load_model`` does nothing.

    Usage:
        classifier = MSTClassifier()
        result = classifier.classify(image)
        print(f"Detected: {result.description}")
    """

    # Inclusive bounds of the MST scale; values outside are rejected.
    MST_MIN = 1
    MST_MAX = 10

    # Clinical visual guidance keyed by MST category. The cues differ because
    # erythema may not present as redness on medium and deep skin tones.
    VISUAL_GUIDANCE = {
        MSTCategory.LIGHT: (
            "Look for: Non-blanchable erythema (redness), warmth, "
            "pallor for ischemia. Inflammatory signs typically present as red."
        ),
        MSTCategory.MEDIUM: (
            "Look for: Subtle color changes (slightly darker/redder than surrounding skin), "
            "warmth, shiny or taut skin. Erythema may not be bright red."
        ),
        MSTCategory.DEEP: (
            "Look for: Purple, blue, or ashen discoloration (NOT redness), "
            "induration (hardness), localized heat, edema. "
            "Stage 1 pressure ulcers may appear as persistent violet/maroon areas."
        ),
    }

    def __init__(self, model_name: str = "google/medgemma-1.5-4b-it"):
        """
        Initialize the MST Classifier.

        No model is loaded here; ``model`` and ``processor`` stay ``None``
        until ``_load_model`` is implemented and called.

        Args:
            model_name: HuggingFace model ID for MedGemma
        """
        self.model_name = model_name
        self.model = None  # Lazy loading (see _load_model)
        self.processor = None

    def _load_model(self) -> None:
        """
        Load MedGemma model for VQA tasks.

        Currently a no-op stub; nothing in this class calls it yet.
        """
        # TODO: Implement actual model loading
        # NOTE(review): MedGemma model cards load the model with
        # AutoModelForImageTextToText; confirm the right auto class before
        # enabling the lines below.
        # from transformers import AutoProcessor, AutoModelForVision2Seq
        # self.processor = AutoProcessor.from_pretrained(self.model_name)
        # self.model = AutoModelForVision2Seq.from_pretrained(self.model_name)
        pass

    def _get_category(self, mst_value: int) -> MSTCategory:
        """
        Map a numeric MST value to its category.

        Args:
            mst_value: Monk Skin Tone value (1-10).

        Returns:
            LIGHT for 1-3, MEDIUM for 4-7, DEEP for 8-10.

        Raises:
            ValueError: If ``mst_value`` lies outside the 1-10 scale
                (including NaN).
        """
        # Reject invalid values rather than silently bucketing them into
        # LIGHT/DEEP, which would hand back guidance for the wrong skin tone.
        if not self.MST_MIN <= mst_value <= self.MST_MAX:
            raise ValueError(
                f"MST value must be between {self.MST_MIN} and {self.MST_MAX}, "
                f"got {mst_value!r}"
            )
        if mst_value <= 3:
            return MSTCategory.LIGHT
        if mst_value <= 7:
            return MSTCategory.MEDIUM
        return MSTCategory.DEEP

    def classify(self, image: Image.Image) -> MSTResult:
        """
        Classify the skin tone of the patient in the image.

        Intended to use the healthy skin visible around the wound area
        to determine the patient's Monk Skin Tone value.

        NOTE: Not implemented yet. ``image`` is currently ignored and a fixed
        placeholder result (MST 5, confidence 0.0) is returned; the 0.0
        confidence marks the result as a placeholder.

        Args:
            image: PIL Image containing the wound and surrounding skin

        Returns:
            MSTResult with value (1-10), category, confidence, and visual guidance
        """
        # TODO: Implement actual classification using MedGemma VQA
        # Prompt: "What is the Monk Skin Tone (1-10) of the patient's skin
        #          visible in this image? Only return a number 1-10."
        # The parsed answer should go through _get_category, which rejects
        # values outside the 1-10 scale.

        # Placeholder: mid-scale value with zero confidence so callers can
        # tell this is not a real prediction.
        mst_value = 5
        confidence = 0.0

        category = self._get_category(mst_value)
        return MSTResult(
            value=mst_value,
            category=category,
            confidence=confidence,
            visual_guidance=self.VISUAL_GUIDANCE[category],
        )

    def get_guidance_for_mst(self, mst_value: int) -> str:
        """
        Get clinical visual guidance for a given MST value.

        Args:
            mst_value: Monk Skin Tone value (1-10)

        Returns:
            String with visual examination guidance

        Raises:
            ValueError: If ``mst_value`` lies outside the 1-10 scale.
        """
        return self.VISUAL_GUIDANCE[self._get_category(mst_value)]