# ewaast-demo/src/agent/reasoning.py — "Deploy EWAAST app" (commit 59d87be)
"""
EWAAST: Wound Assessment Reasoning Agent
Uses MedGemma 1.5 4B with MST context injection for
equitable wound assessment across all skin tones.
"""
from dataclasses import dataclass, field
from typing import Optional, Literal
from enum import Enum
from PIL import Image
from .classifier import MSTClassifier, MSTResult, MSTCategory
class WoundStage(Enum):
    """EPUAP/NPUAP Pressure Ulcer Staging.

    The string values are the human-readable labels exposed to callers
    (e.g. ``result.stage.value``) and compared, lowercased, against the
    model's reported stage in the assessment agent.
    """

    STAGE_1 = "Stage 1"
    STAGE_2 = "Stage 2"
    STAGE_3 = "Stage 3"
    STAGE_4 = "Stage 4"
    # Also used by the agent as the conservative fallback when the model's
    # stage string is unrecognized or its output cannot be parsed.
    UNSTAGEABLE = "Unstageable"
    DEEP_TISSUE_INJURY = "Deep Tissue Injury (DTI)"
    # Model judged the finding not to be a pressure ulcer at all.
    NOT_A_PRESSURE_ULCER = "Not a Pressure Ulcer"
@dataclass
class WoundAssessment:
    """Result of wound assessment.

    Produced by ``WoundAssessmentAgent.assess``; bundles the staging
    decision with the skin-tone context the assessment was made under.
    """

    # EPUAP/NPUAP stage chosen by the model (UNSTAGEABLE on parse failure).
    stage: WoundStage
    # Monk Skin Tone classification that was injected into the prompt.
    mst_result: MSTResult
    # Model's explanation of the staging decision.
    rationale: str
    # Recommended interventions text from the model.
    care_plan: str
    # Triage priority reported by the model.
    urgency: Literal["immediate", "urgent", "standard", "routine"]
    # Confidence in [0, 1]; currently a fixed placeholder in the agent
    # (0.8 on success, 0.0 on parse failure) — see the agent's TODO.
    confidence: float
    # Unmodified model response, kept for debugging/auditing.
    raw_model_output: Optional[str] = None
class WoundAssessmentAgent:
    """
    EWAAST Wound Assessment Agent.
    Combines Monk Skin Tone detection with MedGemma reasoning
    to provide equitable wound assessment.
    Follows the EWAAST Assessment Skill workflow:
    1. MST Analysis → Skin tone classification
    2. Tone-Specific Feature Detection → Apply visual guidance
    3. Clinical Staging → EPUAP/NPUAP criteria
    4. Care Planning → Recommendations
    Usage:
        agent = WoundAssessmentAgent()
        result = agent.assess(image, "Patient reports pain on heel")
        print(f"Stage: {result.stage.value}")
    """

    # Runtime prompt template; placeholders are filled by _build_prompt().
    SYSTEM_PROMPT = """You are EWAAST (Equitable Wound Assessment for All Skin Tones),
an AI clinical assistant specialized in wound assessment that accounts for skin tone diversity.
CRITICAL: You MUST use the detected Monk Skin Tone (MST) to guide your visual assessment.
DO NOT rely solely on "redness" as an indicator of inflammation or pressure injury.
MST-SPECIFIC GUIDANCE:
{visual_guidance}
TASK: Assess the wound in the image and provide:
1. STAGE: EPUAP/NPUAP classification (Stage 1-4, Unstageable, DTI, or Not a Pressure Ulcer)
2. RATIONALE: Explain your staging decision with reference to visible features
3. URGENCY: immediate/urgent/standard/routine
4. CARE_PLAN: Brief recommended interventions
Patient Context: {patient_context}
Detected Skin Tone: MST {mst_value} ({mst_category})
Respond in JSON format:
{{"stage": "...", "rationale": "...", "urgency": "...", "care_plan": "..."}}
"""

    def __init__(
        self,
        model_name: str = "google/medgemma-1.5-4b-it",
        quantize: bool = True
    ):
        """
        Initialize the Wound Assessment Agent.

        Args:
            model_name: HuggingFace model ID for MedGemma
            quantize: Whether to use 4-bit quantization (recommended for
                consumer GPUs)
        """
        self.model_name = model_name
        self.quantize = quantize
        # Loaded lazily by _load_model() (currently a stub).
        self.model = None
        self.processor = None
        # NOTE(review): the MST classifier is handed the MedGemma model ID —
        # confirm MSTClassifier actually expects this argument.
        self.mst_classifier = MSTClassifier(model_name)

    def _load_model(self) -> None:
        """Load MedGemma model with optional quantization.

        Currently a no-op stub; the commented plan below documents the
        intended transformers-based implementation.
        """
        # TODO: Implement actual model loading
        # from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
        #
        # if self.quantize:
        #     bnb_config = BitsAndBytesConfig(
        #         load_in_4bit=True,
        #         bnb_4bit_quant_type="nf4",
        #         bnb_4bit_compute_dtype=torch.bfloat16
        #     )
        #     self.model = AutoModelForVision2Seq.from_pretrained(
        #         self.model_name,
        #         quantization_config=bnb_config
        #     )
        # else:
        #     self.model = AutoModelForVision2Seq.from_pretrained(self.model_name)
        #
        # self.processor = AutoProcessor.from_pretrained(self.model_name)
        pass

    def _build_prompt(
        self,
        mst_result: MSTResult,
        patient_context: str
    ) -> str:
        """Build the MST-aware system prompt.

        Args:
            mst_result: Skin-tone classification providing the visual
                guidance, numeric MST value and category.
            patient_context: Free-text context; empty/None becomes a
                "No additional context provided" placeholder.

        Returns:
            SYSTEM_PROMPT with all placeholders substituted.
        """
        category_names = {
            MSTCategory.LIGHT: "Light",
            MSTCategory.MEDIUM: "Medium",
            MSTCategory.DEEP: "Deep"
        }
        return self.SYSTEM_PROMPT.format(
            visual_guidance=mst_result.visual_guidance,
            patient_context=patient_context or "No additional context provided",
            mst_value=mst_result.value,
            mst_category=category_names[mst_result.category]
        )

    @staticmethod
    def _extract_json(output: str) -> str:
        """Return the outermost ``{...}`` substring of *output*.

        LLMs frequently wrap JSON in prose or markdown code fences; slicing
        from the first '{' to the last '}' tolerates that. If no brace pair
        is found the text is returned unchanged, so ``json.loads`` fails and
        the caller's fallback path is taken.
        """
        start = output.find("{")
        end = output.rfind("}")
        if start != -1 and end > start:
            return output[start:end + 1]
        return output

    def _parse_model_output(
        self,
        output: str,
        mst_result: MSTResult
    ) -> WoundAssessment:
        """Parse model JSON output into WoundAssessment object.

        Never raises: any unparseable or malformed response falls back to a
        conservative UNSTAGEABLE assessment with confidence 0.0.
        """
        import json

        # Normalized (lowercased) stage labels → enum members. Includes the
        # enum's own lowercased .value strings so assessments round-trip.
        stage_map = {
            "stage 1": WoundStage.STAGE_1,
            "stage 2": WoundStage.STAGE_2,
            "stage 3": WoundStage.STAGE_3,
            "stage 4": WoundStage.STAGE_4,
            "unstageable": WoundStage.UNSTAGEABLE,
            "deep tissue injury": WoundStage.DEEP_TISSUE_INJURY,
            # BUGFIX: WoundStage.DEEP_TISSUE_INJURY.value lowercases to this
            # form, which the original map did not recognize.
            "deep tissue injury (dti)": WoundStage.DEEP_TISSUE_INJURY,
            "dti": WoundStage.DEEP_TISSUE_INJURY,
            "not a pressure ulcer": WoundStage.NOT_A_PRESSURE_ULCER,
        }
        try:
            data = json.loads(self._extract_json(output))
            # BUGFIX: json.loads may return a list/str/number, which would
            # previously raise an uncaught AttributeError on data.get().
            if not isinstance(data, dict):
                raise ValueError("model output is not a JSON object")

            # Unknown stage labels default to UNSTAGEABLE (conservative).
            stage_str = str(data.get("stage", "")).strip().lower()
            stage = stage_map.get(stage_str, WoundStage.UNSTAGEABLE)

            # Clamp urgency to the Literal set on WoundAssessment.urgency;
            # previously any string from the model was passed through.
            urgency = str(data.get("urgency", "standard")).strip().lower()
            if urgency not in ("immediate", "urgent", "standard", "routine"):
                urgency = "standard"

            return WoundAssessment(
                stage=stage,
                mst_result=mst_result,
                rationale=str(data.get("rationale", "")),
                care_plan=str(data.get("care_plan", "")),
                urgency=urgency,
                confidence=0.8,  # TODO: Calculate from model confidence
                raw_model_output=output
            )
        except (json.JSONDecodeError, ValueError, KeyError):
            # Fallback for unparseable output
            return WoundAssessment(
                stage=WoundStage.UNSTAGEABLE,
                mst_result=mst_result,
                rationale=f"Model output could not be parsed. Raw: {output[:200]}",
                care_plan="Please consult a healthcare professional.",
                urgency="urgent",
                confidence=0.0,
                raw_model_output=output
            )

    def assess(
        self,
        image: Image.Image,
        patient_context: str = ""
    ) -> WoundAssessment:
        """
        Perform equitable wound assessment.

        Args:
            image: PIL Image of the wound
            patient_context: Optional description from patient/nurse

        Returns:
            WoundAssessment with stage, rationale, and care plan
        """
        # Step 1: Classify skin tone
        mst_result = self.mst_classifier.classify(image)
        # Step 2: Build MST-aware prompt
        prompt = self._build_prompt(mst_result, patient_context)
        # Step 3: Run MedGemma inference
        # TODO: Implement actual inference
        # inputs = self.processor(images=image, text=prompt, return_tensors="pt")
        # outputs = self.model.generate(**inputs, max_new_tokens=500)
        # model_output = self.processor.decode(outputs[0], skip_special_tokens=True)
        # Placeholder response until the model pipeline above is implemented.
        model_output = '{"stage": "Unable to assess", "rationale": "Model not loaded - placeholder response", "urgency": "standard", "care_plan": "Consult healthcare professional"}'
        # Step 4: Parse and return
        return self._parse_model_output(model_output, mst_result)