Spaces:
Sleeping
Sleeping
File size: 4,620 Bytes
59d87be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
"""
EWAAST: Monk Skin Tone (MST) Classifier
Uses MedGemma's VQA capabilities to classify patient skin tone
on the 10-point Monk Skin Tone Scale.
"""
from enum import Enum
from dataclasses import dataclass
from typing import Optional
from PIL import Image
class MSTCategory(Enum):
    """Monk Skin Tone categories for clinical guidance.

    Each member stands for a contiguous band of the 10-point MST scale:
    LIGHT covers MST 1-3, MEDIUM covers MST 4-7 and DEEP covers MST 8-10.
    The string values are lowercase versions of the member names.
    """

    LIGHT = "light"
    MEDIUM = "medium"
    DEEP = "deep"
@dataclass
class MSTResult:
    """Result of Monk Skin Tone classification.

    Attributes:
        value: Position on the MST scale, expected to be 1-10.
        category: Clinical category (light / medium / deep) for ``value``.
        confidence: Classifier confidence reported with the result.
        visual_guidance: Examination guidance text for ``category``.
    """

    value: int  # expected range: 1-10
    category: MSTCategory
    confidence: float
    visual_guidance: str

    @property
    def description(self) -> str:
        """Return a human-readable label such as ``"Medium (MST 5)"``.

        Raises:
            KeyError: If ``category`` is not one of the three MSTCategory members.
        """
        label = {
            MSTCategory.LIGHT: "Light",
            MSTCategory.MEDIUM: "Medium",
            MSTCategory.DEEP: "Deep",
        }[self.category]
        return f"{label} (MST {self.value})"
class MSTClassifier:
    """
    Classifier for Monk Skin Tone (MST) Scale.

    The MST Scale is a 10-point skin tone representation developed
    by Dr. Ellis Monk to improve AI fairness across diverse skin tones.

    Usage:
        classifier = MSTClassifier()
        result = classifier.classify(image)
        print(f"Detected: {result.description}")

    Note:
        Model loading and MedGemma VQA classification are not implemented
        yet; ``classify`` currently returns a fixed placeholder result.
    """

    # Inclusive bounds of the Monk Skin Tone scale.
    _MST_MIN = 1
    _MST_MAX = 10

    # Clinical visual guidance based on MST category
    VISUAL_GUIDANCE = {
        MSTCategory.LIGHT: (
            "Look for: Non-blanchable erythema (redness), warmth, "
            "pallor for ischemia. Inflammatory signs typically present as red."
        ),
        MSTCategory.MEDIUM: (
            "Look for: Subtle color changes (slightly darker/redder than surrounding skin), "
            "warmth, shiny or taut skin. Erythema may not be bright red."
        ),
        MSTCategory.DEEP: (
            "Look for: Purple, blue, or ashen discoloration (NOT redness), "
            "induration (hardness), localized heat, edema. "
            "Stage 1 pressure ulcers may appear as persistent violet/maroon areas."
        ),
    }

    def __init__(self, model_name: str = "google/medgemma-1.5-4b-it"):
        """
        Initialize the MST Classifier.

        Args:
            model_name: HuggingFace model ID for MedGemma
        """
        self.model_name = model_name
        # Lazy loading: model and processor stay None until a loader fills them in.
        self.model = None
        self.processor = None

    def _load_model(self) -> None:
        """Load MedGemma model for VQA tasks.

        Currently a no-op placeholder: ``self.model`` and ``self.processor``
        are left unchanged.
        """
        # TODO: Implement actual model loading
        # from transformers import AutoProcessor, AutoModelForVision2Seq
        # self.processor = AutoProcessor.from_pretrained(self.model_name)
        # self.model = AutoModelForVision2Seq.from_pretrained(self.model_name)
        pass

    def _get_category(self, mst_value: int) -> MSTCategory:
        """Map a numeric MST value to its clinical category.

        Args:
            mst_value: Monk Skin Tone value, 1-10 inclusive.

        Returns:
            LIGHT for 1-3, MEDIUM for 4-7, DEEP for 8-10.

        Raises:
            ValueError: If ``mst_value`` is outside 1-10 (or is NaN), rather
                than silently bucketing it as LIGHT or DEEP.
        """
        # The negated chained comparison also rejects NaN, for which every
        # comparison is False.
        if not self._MST_MIN <= mst_value <= self._MST_MAX:
            raise ValueError(
                f"MST value must be between {self._MST_MIN} and "
                f"{self._MST_MAX} inclusive, got {mst_value!r}"
            )
        if mst_value <= 3:
            return MSTCategory.LIGHT
        if mst_value <= 7:
            return MSTCategory.MEDIUM
        return MSTCategory.DEEP

    def classify(self, image: Image.Image) -> MSTResult:
        """
        Classify the skin tone of the patient in the image.

        Intended to use the healthy skin visible around the wound area to
        determine the patient's Monk Skin Tone value. The MedGemma VQA step
        is not implemented yet: ``image`` is currently ignored and a fixed
        placeholder (MST 5, confidence 0.0) is returned.

        Args:
            image: PIL Image containing the wound and surrounding skin

        Returns:
            MSTResult with value (1-10), category, and visual guidance
        """
        # TODO: Implement actual classification using MedGemma VQA
        # Prompt: "What is the Monk Skin Tone (1-10) of the patient's skin
        # visible in this image? Only return a number 1-10."

        # Placeholder: mid-scale MST value; a confidence of 0.0 marks the
        # result as not coming from a real classification.
        mst_value = 5
        confidence = 0.0

        category = self._get_category(mst_value)
        return MSTResult(
            value=mst_value,
            category=category,
            confidence=confidence,
            visual_guidance=self.VISUAL_GUIDANCE[category],
        )

    def get_guidance_for_mst(self, mst_value: int) -> str:
        """
        Get clinical visual guidance for a given MST value.

        Args:
            mst_value: Monk Skin Tone value (1-10)

        Returns:
            String with visual examination guidance

        Raises:
            ValueError: If ``mst_value`` is outside the 1-10 MST scale.
        """
        return self.VISUAL_GUIDANCE[self._get_category(mst_value)]
|