ewaast-demo / src /agent /classifier.py
Nursing Citizen Development
Deploy EWAAST app
59d87be
"""
EWAAST: Monk Skin Tone (MST) Classifier
Uses MedGemma's VQA capabilities to classify patient skin tone
on the 10-point Monk Skin Tone Scale.
"""
from enum import Enum
from dataclasses import dataclass
from typing import Optional
from PIL import Image
class MSTCategory(Enum):
    """Coarse Monk Skin Tone bands used to pick clinical visual guidance.

    Each member covers a contiguous range of the 10-point MST scale.
    """

    # MST values 1 through 3
    LIGHT = "light"
    # MST values 4 through 7
    MEDIUM = "medium"
    # MST values 8 through 10
    DEEP = "deep"
@dataclass
class MSTResult:
    """Outcome of classifying a patient's skin on the Monk Skin Tone scale.

    Attributes:
        value: MST value (documented range 1-10).
        category: Coarse band (light / medium / deep) the value falls in.
        confidence: Classifier confidence score.
        visual_guidance: Clinical examination guidance for ``category``.
    """

    value: int  # 1-10
    category: MSTCategory
    confidence: float
    visual_guidance: str

    @property
    def description(self) -> str:
        """Return a display label such as ``"Medium (MST 5)"``."""
        label = {
            MSTCategory.LIGHT: "Light",
            MSTCategory.MEDIUM: "Medium",
            MSTCategory.DEEP: "Deep",
        }[self.category]
        return "{} (MST {})".format(label, self.value)
class MSTClassifier:
    """
    Classifier for the Monk Skin Tone (MST) Scale.

    The MST Scale is a 10-point skin tone representation developed
    by Dr. Ellis Monk to improve AI fairness across diverse skin tones.

    Usage:
        classifier = MSTClassifier()
        result = classifier.classify(image)
        print(f"Detected: {result.description}")
    """

    # Inclusive bounds of the 10-point MST scale; values outside this
    # range are rejected rather than silently clamped into a category.
    MIN_MST = 1
    MAX_MST = 10

    # Clinical visual guidance keyed by MST category.
    VISUAL_GUIDANCE = {
        MSTCategory.LIGHT: (
            "Look for: Non-blanchable erythema (redness), warmth, "
            "pallor for ischemia. Inflammatory signs typically present as red."
        ),
        MSTCategory.MEDIUM: (
            "Look for: Subtle color changes (slightly darker/redder than surrounding skin), "
            "warmth, shiny or taut skin. Erythema may not be bright red."
        ),
        MSTCategory.DEEP: (
            "Look for: Purple, blue, or ashen discoloration (NOT redness), "
            "induration (hardness), localized heat, edema. "
            "Stage 1 pressure ulcers may appear as persistent violet/maroon areas."
        ),
    }

    def __init__(self, model_name: str = "google/medgemma-1.5-4b-it"):
        """
        Initialize the MST Classifier.

        Args:
            model_name: HuggingFace model ID for MedGemma.
        """
        self.model_name = model_name
        # Model and processor start unset; _load_model is meant to fill them.
        self.model = None
        self.processor = None

    def _load_model(self) -> None:
        """Load MedGemma model for VQA tasks (not yet implemented; no-op)."""
        # TODO: Implement actual model loading
        # from transformers import AutoProcessor, AutoModelForVision2Seq
        # self.processor = AutoProcessor.from_pretrained(self.model_name)
        # self.model = AutoModelForVision2Seq.from_pretrained(self.model_name)
        pass

    def _get_category(self, mst_value: int) -> MSTCategory:
        """
        Map an MST value onto its category band.

        Args:
            mst_value: Monk Skin Tone value (1-10).

        Returns:
            LIGHT for 1-3, MEDIUM for 4-7, DEEP for 8-10.

        Raises:
            ValueError: If mst_value lies outside the 1-10 scale.
        """
        # Previously 0/negative values fell into LIGHT and >10 into DEEP,
        # silently producing clinical guidance for an invalid input.
        if not self.MIN_MST <= mst_value <= self.MAX_MST:
            raise ValueError(
                f"MST value must be between {self.MIN_MST} and "
                f"{self.MAX_MST}, got {mst_value!r}"
            )
        if mst_value <= 3:
            return MSTCategory.LIGHT
        if mst_value <= 7:
            return MSTCategory.MEDIUM
        return MSTCategory.DEEP

    def classify(self, image: Image.Image) -> MSTResult:
        """
        Classify the skin tone of the patient in the image.

        Intended to use the healthy skin visible around the wound area
        to determine the patient's Monk Skin Tone value. Currently a
        placeholder: the image is not inspected.

        Args:
            image: PIL Image containing the wound and surrounding skin.

        Returns:
            MSTResult with value (1-10), category, and visual guidance.
            While unimplemented, always MST 5 with confidence 0.0.
        """
        # TODO: Implement actual classification using MedGemma VQA
        # Prompt: "What is the Monk Skin Tone (1-10) of the patient's skin
        # visible in this image? Only return a number 1-10."
        # NOTE(review): _load_model is not invoked anywhere in this class yet.

        # Placeholder: middle of the scale; confidence 0.0 flags that no
        # real classification took place.
        mst_value = 5
        confidence = 0.0
        category = self._get_category(mst_value)
        visual_guidance = self.VISUAL_GUIDANCE[category]
        return MSTResult(
            value=mst_value,
            category=category,
            confidence=confidence,
            visual_guidance=visual_guidance,
        )

    def get_guidance_for_mst(self, mst_value: int) -> str:
        """
        Get clinical visual guidance for a given MST value.

        Args:
            mst_value: Monk Skin Tone value (1-10).

        Returns:
            String with visual examination guidance for the value's category.

        Raises:
            ValueError: If mst_value lies outside the 1-10 scale.
        """
        category = self._get_category(mst_value)
        return self.VISUAL_GUIDANCE[category]