Spaces:
Sleeping
Sleeping
| """ | |
| EWAAST: Wound Assessment Reasoning Agent | |
| Uses MedGemma 1.5 4B with MST context injection for | |
| equitable wound assessment across all skin tones. | |
| """ | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Literal | |
| from enum import Enum | |
| from PIL import Image | |
| from .classifier import MSTClassifier, MSTResult, MSTCategory | |
class WoundStage(Enum):
    """Pressure ulcer categories used by the EPUAP/NPUAP staging system.

    Each member's value is the human-readable label for that category.
    """

    # Numbered stages.
    STAGE_1 = "Stage 1"
    STAGE_2 = "Stage 2"
    STAGE_3 = "Stage 3"
    STAGE_4 = "Stage 4"

    # Categories without a stage number.
    UNSTAGEABLE = "Unstageable"
    DEEP_TISSUE_INJURY = "Deep Tissue Injury (DTI)"
    NOT_A_PRESSURE_ULCER = "Not a Pressure Ulcer"
@dataclass
class WoundAssessment:
    """Result of wound assessment.

    Declared as a dataclass so that it can be built with keyword arguments
    (``WoundAssessment(stage=..., mst_result=..., ...)``); without the
    decorator the class would have no generated ``__init__`` and such calls
    would raise ``TypeError``.
    """

    # Pressure-ulcer category assigned to the wound.
    stage: WoundStage
    # Skin-tone classification the assessment was based on.
    mst_result: MSTResult
    # Explanation of the staging decision.
    rationale: str
    # Recommended interventions.
    care_plan: str
    # Triage level for the wound.
    urgency: Literal["immediate", "urgent", "standard", "routine"]
    # Confidence score for the assessment (presumably in [0, 1] - TODO confirm).
    confidence: float
    # Unprocessed model text, kept for debugging/auditing; None when not recorded.
    raw_model_output: Optional[str] = None
class WoundAssessmentAgent:
    """
    EWAAST Wound Assessment Agent.

    Combines Monk Skin Tone detection with MedGemma reasoning
    to provide equitable wound assessment.

    Follows the EWAAST Assessment Skill workflow:
        1. MST Analysis → Skin tone classification
        2. Tone-Specific Feature Detection → Apply visual guidance
        3. Clinical Staging → EPUAP/NPUAP criteria
        4. Care Planning → Recommendations

    Usage:
        agent = WoundAssessmentAgent()
        result = agent.assess(image, "Patient reports pain on heel")
        print(f"Stage: {result.stage.value}")
    """

    # Prompt template filled in by _build_prompt(). Doubled braces are literal
    # braces for str.format().
    SYSTEM_PROMPT = """You are EWAAST (Equitable Wound Assessment for All Skin Tones),
an AI clinical assistant specialized in wound assessment that accounts for skin tone diversity.
CRITICAL: You MUST use the detected Monk Skin Tone (MST) to guide your visual assessment.
DO NOT rely solely on "redness" as an indicator of inflammation or pressure injury.
MST-SPECIFIC GUIDANCE:
{visual_guidance}
TASK: Assess the wound in the image and provide:
1. STAGE: EPUAP/NPUAP classification (Stage 1-4, Unstageable, DTI, or Not a Pressure Ulcer)
2. RATIONALE: Explain your staging decision with reference to visible features
3. URGENCY: immediate/urgent/standard/routine
4. CARE_PLAN: Brief recommended interventions
Patient Context: {patient_context}
Detected Skin Tone: MST {mst_value} ({mst_category})
Respond in JSON format:
{{"stage": "...", "rationale": "...", "urgency": "...", "care_plan": "..."}}
"""

    # Urgency labels the model is allowed to return (mirrors the Literal on
    # WoundAssessment.urgency); anything else falls back to _DEFAULT_URGENCY.
    _URGENCY_LEVELS = ("immediate", "urgent", "standard", "routine")
    _DEFAULT_URGENCY = "standard"

    # Extra spellings accepted for a stage in addition to the enum's own labels.
    _STAGE_ALIASES = {
        "deep tissue injury": WoundStage.DEEP_TISSUE_INJURY,
        "dti": WoundStage.DEEP_TISSUE_INJURY,
    }

    # Fixed score reported for a successfully parsed answer.
    # TODO: Calculate from model confidence
    _PARSED_CONFIDENCE = 0.8

    def __init__(
        self,
        model_name: str = "google/medgemma-1.5-4b-it",
        quantize: bool = True
    ):
        """
        Initialize the Wound Assessment Agent.

        The model and processor are not loaded here; they start as None.

        Args:
            model_name: HuggingFace model ID for MedGemma
            quantize: Whether to use 4-bit quantization (recommended for consumer GPUs)
        """
        self.model_name = model_name
        self.quantize = quantize
        self.model = None
        self.processor = None
        self.mst_classifier = MSTClassifier(model_name)

    def _load_model(self) -> None:
        """Load MedGemma model with optional quantization.

        Currently a no-op: the loading code below is only a sketch.
        """
        # TODO: Implement actual model loading (the sketch also needs `import torch`)
        # from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
        #
        # if self.quantize:
        #     bnb_config = BitsAndBytesConfig(
        #         load_in_4bit=True,
        #         bnb_4bit_quant_type="nf4",
        #         bnb_4bit_compute_dtype=torch.bfloat16
        #     )
        #     self.model = AutoModelForVision2Seq.from_pretrained(
        #         self.model_name,
        #         quantization_config=bnb_config
        #     )
        # else:
        #     self.model = AutoModelForVision2Seq.from_pretrained(self.model_name)
        #
        # self.processor = AutoProcessor.from_pretrained(self.model_name)
        pass

    def _build_prompt(
        self,
        mst_result: MSTResult,
        patient_context: str
    ) -> str:
        """Build the MST-aware system prompt.

        Args:
            mst_result: Skin-tone classification whose value, category and
                visual guidance are injected into the prompt.
            patient_context: Free-text context; a stock sentence is used when
                it is empty.

        Returns:
            SYSTEM_PROMPT with all placeholders filled in.

        Raises:
            KeyError: If ``mst_result.category`` is not one of the three
                categories listed below.
        """
        category_names = {
            MSTCategory.LIGHT: "Light",
            MSTCategory.MEDIUM: "Medium",
            MSTCategory.DEEP: "Deep"
        }
        return self.SYSTEM_PROMPT.format(
            visual_guidance=mst_result.visual_guidance,
            patient_context=patient_context or "No additional context provided",
            mst_value=mst_result.value,
            mst_category=category_names[mst_result.category]
        )

    @staticmethod
    def _decode_json_object(output: str) -> Optional[dict]:
        """Return the JSON object contained in *output*, or None if there is none."""
        import json

        candidates = [output]
        # Models often wrap the JSON in a markdown fence or add prose around
        # it, so retry on the outermost {...} span when the raw text fails.
        start, end = output.find("{"), output.rfind("}")
        if start != -1 and end > start:
            inner = output[start:end + 1]
            if inner != output:
                candidates.append(inner)

        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (list, string, number) is unusable.
            if isinstance(data, dict):
                return data
        return None

    @classmethod
    def _resolve_stage(cls, raw: object) -> Optional[WoundStage]:
        """Map the model's stage text to a WoundStage, or None if unrecognised."""
        if not isinstance(raw, str):
            return None
        # Case-insensitive match with whitespace collapsed.
        key = " ".join(raw.lower().split())
        for stage in WoundStage:
            if key == stage.value.lower():
                return stage
        return cls._STAGE_ALIASES.get(key)

    @staticmethod
    def _as_text(value: object) -> str:
        """Coerce a JSON value to the plain string the result fields expect."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, (list, tuple)):
            # e.g. a care plan returned as a list of steps
            return "\n".join(str(item) for item in value)
        return str(value)

    def _parse_model_output(
        self,
        output: str,
        mst_result: MSTResult
    ) -> WoundAssessment:
        """Parse model JSON output into WoundAssessment object.

        Args:
            output: Raw text produced by the model.
            mst_result: Skin-tone classification to attach to the result.

        Returns:
            A WoundAssessment. When no JSON object can be extracted, a
            fallback result is returned (UNSTAGEABLE, urgency "urgent",
            confidence 0.0). When the JSON parses but the stage text is not
            recognised, the stage is UNSTAGEABLE with confidence 0.0 so the
            guess is not presented as a confident finding.
        """
        data = self._decode_json_object(output)
        if data is None:
            # Fallback for unparseable output
            return WoundAssessment(
                stage=WoundStage.UNSTAGEABLE,
                mst_result=mst_result,
                rationale=f"Model output could not be parsed. Raw: {output[:200]}",
                care_plan="Please consult a healthcare professional.",
                urgency="urgent",
                confidence=0.0,
                raw_model_output=output
            )

        stage = self._resolve_stage(data.get("stage"))

        urgency = data.get("urgency")
        urgency = urgency.strip().lower() if isinstance(urgency, str) else ""
        if urgency not in self._URGENCY_LEVELS:
            urgency = self._DEFAULT_URGENCY

        return WoundAssessment(
            stage=stage if stage is not None else WoundStage.UNSTAGEABLE,
            mst_result=mst_result,
            rationale=self._as_text(data.get("rationale")),
            care_plan=self._as_text(data.get("care_plan")),
            urgency=urgency,
            confidence=self._PARSED_CONFIDENCE if stage is not None else 0.0,
            raw_model_output=output
        )

    def assess(
        self,
        image: Image.Image,
        patient_context: str = ""
    ) -> WoundAssessment:
        """
        Perform equitable wound assessment.

        Inference is not implemented yet: the skin tone is classified and the
        prompt is built, but the model response is a hard-coded placeholder.

        Args:
            image: PIL Image of the wound
            patient_context: Optional description from patient/nurse

        Returns:
            WoundAssessment with stage, rationale, and care plan
        """
        # Step 1: Classify skin tone
        mst_result = self.mst_classifier.classify(image)

        # Step 2: Build MST-aware prompt
        prompt = self._build_prompt(mst_result, patient_context)

        # Step 3: Run MedGemma inference
        # TODO: Implement actual inference
        # inputs = self.processor(images=image, text=prompt, return_tensors="pt")
        # outputs = self.model.generate(**inputs, max_new_tokens=500)
        # model_output = self.processor.decode(outputs[0], skip_special_tokens=True)

        # Placeholder response; its stage text is deliberately not a real
        # stage, so the parsed result carries confidence 0.0.
        model_output = (
            '{"stage": "Unable to assess", '
            '"rationale": "Model not loaded - placeholder response", '
            '"urgency": "standard", '
            '"care_plan": "Consult healthcare professional"}'
        )

        # Step 4: Parse and return
        return self._parse_model_output(model_output, mst_result)