"""
EWAAST: Wound Assessment Reasoning Agent

Uses MedGemma 1.5 4B with MST context injection for equitable wound
assessment across all skin tones.
"""

import json
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Literal
from enum import Enum

from PIL import Image

from .classifier import MSTClassifier, MSTResult, MSTCategory


class WoundStage(Enum):
    """EPUAP/NPUAP Pressure Ulcer Staging."""
    STAGE_1 = "Stage 1"
    STAGE_2 = "Stage 2"
    STAGE_3 = "Stage 3"
    STAGE_4 = "Stage 4"
    UNSTAGEABLE = "Unstageable"
    DEEP_TISSUE_INJURY = "Deep Tissue Injury (DTI)"
    NOT_A_PRESSURE_ULCER = "Not a Pressure Ulcer"


@dataclass
class WoundAssessment:
    """Result of wound assessment."""
    stage: WoundStage
    # Skin-tone classification the assessment prompt was conditioned on.
    mst_result: MSTResult
    rationale: str
    care_plan: str
    urgency: Literal["immediate", "urgent", "standard", "routine"]
    # Placeholder heuristic for now (see WoundAssessmentAgent._parse_model_output);
    # 0.0 means the stage was not actually recognized in the model output.
    confidence: float
    # Unmodified text returned by the model, kept for auditing/debugging.
    raw_model_output: Optional[str] = None


class WoundAssessmentAgent:
    """
    EWAAST Wound Assessment Agent.

    Combines Monk Skin Tone detection with MedGemma reasoning to provide
    equitable wound assessment.

    Follows the EWAAST Assessment Skill workflow:
    1. MST Analysis → Skin tone classification
    2. Tone-Specific Feature Detection → Apply visual guidance
    3. Clinical Staging → EPUAP/NPUAP criteria
    4. Care Planning → Recommendations

    Usage:
        agent = WoundAssessmentAgent()
        result = agent.assess(image, "Patient reports pain on heel")
        print(f"Stage: {result.stage.value}")
    """

    SYSTEM_PROMPT = """You are EWAAST (Equitable Wound Assessment for All Skin Tones), an AI clinical assistant specialized in wound assessment that accounts for skin tone diversity.

CRITICAL: You MUST use the detected Monk Skin Tone (MST) to guide your visual assessment.
DO NOT rely solely on "redness" as an indicator of inflammation or pressure injury.

MST-SPECIFIC GUIDANCE:
{visual_guidance}

TASK: Assess the wound in the image and provide:
1. STAGE: EPUAP/NPUAP classification (Stage 1-4, Unstageable, DTI, or Not a Pressure Ulcer)
2. RATIONALE: Explain your staging decision with reference to visible features
3. URGENCY: immediate/urgent/standard/routine
4. CARE_PLAN: Brief recommended interventions

Patient Context: {patient_context}
Detected Skin Tone: MST {mst_value} ({mst_category})

Respond in JSON format:
{{"stage": "...", "rationale": "...", "urgency": "...", "care_plan": "..."}}
"""

    # Human-readable names for the skin-tone buckets, substituted into the prompt.
    _CATEGORY_NAMES = {
        MSTCategory.LIGHT: "Light",
        MSTCategory.MEDIUM: "Medium",
        MSTCategory.DEEP: "Deep",
    }

    # Values allowed by WoundAssessment.urgency; anything else falls back to
    # the default (the same value used when the key is missing).
    _URGENCY_LEVELS = ("immediate", "urgent", "standard", "routine")
    _DEFAULT_URGENCY = "standard"

    # Exact (lower-cased, whitespace-collapsed) stage strings -> stage.
    # Every enum value is included so that a model echoing the canonical
    # label, e.g. "Deep Tissue Injury (DTI)", is recognized.
    _STAGE_ALIASES: Dict[str, WoundStage] = {
        **{stage.value.lower(): stage for stage in WoundStage},
        "deep tissue injury": WoundStage.DEEP_TISSUE_INJURY,
        "dti": WoundStage.DEEP_TISSUE_INJURY,
    }

    # "Stage 2", "Stage II", "Category 3 ..." -> numbered stage. The negative
    # lookahead rejects ranges/larger numbers such as "Stage 1-4" or "Stage 10".
    _NUMBERED_STAGE = re.compile(r"^(?:stage|category)\s*(iv|i{1,3}|[1-4])(?![\w-])")
    _ROMAN_TO_ARABIC = {"i": "1", "ii": "2", "iii": "3", "iv": "4"}

    # Prefixes for qualified answers such as "Unstageable - covered by eschar".
    _STAGE_PREFIXES = (
        ("deep tissue", WoundStage.DEEP_TISSUE_INJURY),
        ("dti", WoundStage.DEEP_TISSUE_INJURY),
        ("unstageable", WoundStage.UNSTAGEABLE),
        ("not a pressure ulcer", WoundStage.NOT_A_PRESSURE_ULCER),
    )

    def __init__(
        self,
        model_name: str = "google/medgemma-1.5-4b-it",
        quantize: bool = True
    ):
        """
        Initialize the Wound Assessment Agent.

        Args:
            model_name: HuggingFace model ID for MedGemma
            quantize: Whether to use 4-bit quantization (recommended for consumer GPUs)
        """
        self.model_name = model_name
        self.quantize = quantize
        # Populated lazily by _load_model(); None means "not loaded yet".
        self.model = None
        self.processor = None
        # NOTE(review): the MedGemma model ID is forwarded to MSTClassifier —
        # confirm its constructor actually expects a model ID here.
        self.mst_classifier = MSTClassifier(model_name)

    def _load_model(self) -> None:
        """Load MedGemma model with optional quantization (not implemented yet)."""
        # TODO: Implement actual model loading (also requires `import torch`)
        # from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
        #
        # if self.quantize:
        #     bnb_config = BitsAndBytesConfig(
        #         load_in_4bit=True,
        #         bnb_4bit_quant_type="nf4",
        #         bnb_4bit_compute_dtype=torch.bfloat16
        #     )
        #     self.model = AutoModelForVision2Seq.from_pretrained(
        #         self.model_name,
        #         quantization_config=bnb_config
        #     )
        # else:
        #     self.model = AutoModelForVision2Seq.from_pretrained(self.model_name)
        #
        # self.processor = AutoProcessor.from_pretrained(self.model_name)
        pass

    def _build_prompt(
        self,
        mst_result: MSTResult,
        patient_context: str
    ) -> str:
        """
        Build the MST-aware system prompt.

        Args:
            mst_result: Skin-tone classification; its ``visual_guidance``,
                ``value`` and ``category`` are substituted into the prompt.
            patient_context: Free-text context; an empty value is replaced by
                a "No additional context provided" notice.

        Returns:
            The formatted prompt string. Substituted values are inserted
            verbatim (str.format does not re-interpret braces inside them).
        """
        return self.SYSTEM_PROMPT.format(
            visual_guidance=mst_result.visual_guidance,
            patient_context=patient_context or "No additional context provided",
            mst_value=mst_result.value,
            mst_category=self._CATEGORY_NAMES[mst_result.category]
        )

    @staticmethod
    def _extract_json_object(text: Any) -> Optional[Dict[str, Any]]:
        """
        Find the JSON object holding the model's answer inside ``text``.

        Model output is often not bare JSON: it may be wrapped in markdown
        fences or prose, and a decoded generation may echo the prompt, which
        itself contains a JSON template. Every top-level ``{...}`` object is
        therefore decoded, and the LAST one with a ``"stage"`` key is
        preferred (falling back to the last object found).

        Returns:
            The chosen dict, or None if ``text`` is not a string or contains
            no decodable JSON object.
        """
        if not isinstance(text, str):
            return None
        decoder = json.JSONDecoder()
        candidates = []
        idx = text.find("{")
        while idx != -1:
            try:
                obj, end = decoder.raw_decode(text, idx)
            except json.JSONDecodeError:
                idx = text.find("{", idx + 1)
                continue
            if isinstance(obj, dict):
                candidates.append(obj)
            # Skip past the decoded value so nested objects are not re-parsed.
            idx = text.find("{", end)
        for obj in reversed(candidates):
            if "stage" in obj:
                return obj
        return candidates[-1] if candidates else None

    @classmethod
    def _normalize_stage(cls, raw: Any) -> Optional[WoundStage]:
        """
        Map a free-text stage label to a WoundStage.

        Accepts the enum labels, "DTI", numbered stages in arabic or roman
        numerals (optionally prefixed "Category"), and qualified answers that
        start with a known label. Returns None when the label is missing,
        not a string, or not recognized.
        """
        if not isinstance(raw, str):
            return None
        key = " ".join(raw.lower().split())
        if key in cls._STAGE_ALIASES:
            return cls._STAGE_ALIASES[key]
        match = cls._NUMBERED_STAGE.match(key)
        if match:
            number = cls._ROMAN_TO_ARABIC.get(match.group(1), match.group(1))
            return cls._STAGE_ALIASES[f"stage {number}"]
        for prefix, stage in cls._STAGE_PREFIXES:
            if key.startswith(prefix):
                return stage
        return None

    @classmethod
    def _normalize_urgency(cls, raw: Any) -> str:
        """Return ``raw`` as one of _URGENCY_LEVELS, else the default level."""
        urgency = raw.strip().lower() if isinstance(raw, str) else ""
        return urgency if urgency in cls._URGENCY_LEVELS else cls._DEFAULT_URGENCY

    @staticmethod
    def _as_text(value: Any) -> str:
        """Coerce a JSON field to text: None -> "", lists joined by newlines."""
        if value is None:
            return ""
        if isinstance(value, (list, tuple)):
            return "\n".join(str(item) for item in value)
        return str(value)

    def _parse_model_output(
        self,
        output: str,
        mst_result: MSTResult
    ) -> WoundAssessment:
        """
        Parse model JSON output into a WoundAssessment object.

        If no JSON object can be found, a conservative fallback is returned
        (UNSTAGEABLE, urgency "urgent", confidence 0.0). If a stage label is
        present but unrecognized, the stage is reported as UNSTAGEABLE with
        confidence 0.0, so it is not mistaken for a genuine staging.
        """
        data = self._extract_json_object(output)
        if data is None:
            # Fallback for unparseable output
            return WoundAssessment(
                stage=WoundStage.UNSTAGEABLE,
                mst_result=mst_result,
                rationale=f"Model output could not be parsed. Raw: {(output or '')[:200]}",
                care_plan="Please consult a healthcare professional.",
                urgency="urgent",
                confidence=0.0,
                raw_model_output=output
            )

        stage = self._normalize_stage(data.get("stage"))
        return WoundAssessment(
            stage=stage if stage is not None else WoundStage.UNSTAGEABLE,
            mst_result=mst_result,
            rationale=self._as_text(data.get("rationale")),
            care_plan=self._as_text(data.get("care_plan")),
            urgency=self._normalize_urgency(data.get("urgency")),
            # TODO: Calculate from model confidence. Until then a recognized
            # stage gets a fixed 0.8; an unrecognized one gets 0.0.
            confidence=0.8 if stage is not None else 0.0,
            raw_model_output=output
        )

    def assess(
        self,
        image: Image.Image,
        patient_context: str = ""
    ) -> WoundAssessment:
        """
        Perform equitable wound assessment.

        Args:
            image: PIL Image of the wound
            patient_context: Optional description from patient/nurse

        Returns:
            WoundAssessment with stage, rationale, and care plan
        """
        # Step 0: Lazily load the model (currently a no-op; see _load_model)
        if self.model is None:
            self._load_model()

        # Step 1: Classify skin tone
        mst_result = self.mst_classifier.classify(image)

        # Step 2: Build MST-aware prompt
        prompt = self._build_prompt(mst_result, patient_context)

        # Step 3: Run MedGemma inference
        # TODO: Implement actual inference
        # inputs = self.processor(images=image, text=prompt, return_tensors="pt")
        # outputs = self.model.generate(**inputs, max_new_tokens=500)
        # model_output = self.processor.decode(outputs[0], skip_special_tokens=True)

        # Placeholder response
        model_output = '{"stage": "Unable to assess", "rationale": "Model not loaded - placeholder response", "urgency": "standard", "care_plan": "Consult healthcare professional"}'

        # Step 4: Parse and return
        return self._parse_model_output(model_output, mst_result)