Spaces:
Sleeping
Sleeping
File size: 6,918 Bytes
11a8435 15f00b4 11a8435 15f00b4 11a8435 15f00b4 11a8435 34067c4 11a8435 bf6a463 11a8435 34067c4 11a8435 bf6a463 11a8435 bf6a463 11a8435 bf6a463 11a8435 34067c4 11a8435 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 | # Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
# MedTriage Environment Implementation
import logging
import uuid
from typing import Any, Dict, Optional
from uuid import uuid4

# Imports (Adjust according to actual structure)
from fastmcp import FastMCP
from openenv.core.env_server.mcp_environment import MCPEnvironment
from openenv.core.env_server.types import Action, Observation, State

# Use local models
try:
    from .models import TriageLevel, TriageAction, TriageObservation, TriageState
except ImportError:
    from models import TriageLevel, TriageAction, TriageObservation, TriageState
# Task scenarios, ordered easy -> medium -> hard.
# Each record is stored under its own "id"; "ground_truth" is the triage level
# the agent's decision is graded against. Vitals units are not stated here
# (temp values look like Fahrenheit -- confirm against the model docs).
TASKS = {
    scenario["id"]: scenario
    for scenario in (
        dict(
            id="TASK_EASY",
            name="Seasonal Allergies",
            patient=dict(
                patient_id="P-101",
                age=28,
                gender="Female",
                symptoms_text=(
                    "I've had a runny nose, sneezing, and itchy eyes for the past week. "
                    "It's really annoying but I don't feel 'sick' otherwise."
                ),
                vitals={"temp": 98.6, "bp": "120/80", "hr": 72, "spo2": 99},
                history=["No major conditions"],
            ),
            ground_truth=TriageLevel.SELF_CARE,
        ),
        dict(
            id="TASK_MEDIUM",
            name="Possible Appendicitis",
            patient=dict(
                patient_id="P-102",
                age=19,
                gender="Male",
                symptoms_text=(
                    "I woke up with severe pain around my belly button that's moving down "
                    "to my lower right side. I feel nauseous and have zero appetite."
                ),
                vitals={"temp": 100.8, "bp": "115/75", "hr": 95, "spo2": 98},
                history=["No major conditions"],
            ),
            ground_truth=TriageLevel.URGENT_CARE,
        ),
        dict(
            id="TASK_HARD",
            name="Atypical Myocardial Infarction",
            patient=dict(
                patient_id="P-103",
                age=68,
                gender="Female",
                symptoms_text=(
                    "I just feel extremely weak and have this weird 'indigestion' sensation "
                    "in my upper stomach. I'm also sweating a lot for no reason."
                ),
                vitals={"temp": 98.2, "bp": "165/100", "hr": 105, "spo2": 94},
                history=["Type 2 Diabetes", "High Blood Pressure", "Smoking"],
            ),
            ground_truth=TriageLevel.EMERGENCY,
        ),
    )
}
class MedTriageEnvironment(MCPEnvironment):
    """
    Real-world Triage Environment for Agent Training.

    ``reset`` loads one scenario from ``TASKS``. The agent answers by calling
    the ``triage_patient`` MCP tool with a ``level`` argument (0-3). A valid
    call ends the episode (``done=True``) with a reward from
    ``_calculate_reward``. Any other action, or an unusable level, leaves the
    episode open and returns the patient again with an explanatory message.
    """

    # Replaces the former DEBUG print() calls; enable via logging configuration.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        mcp = FastMCP("med_triage_env")

        @mcp.tool
        def triage_patient(level: int, reasoning: str) -> str:
            """
            Analyze patient data and assign a triage level (0-3).
            Args:
                level: 0 (Self-Care), 1 (Clinic), 2 (Urgent Care), 3 (Emergency)
                reasoning: Medical explanation for your decision
            """
            # The tool only echoes the decision; grading happens in _step_impl.
            # NOTE(review): FastMCP presumably publishes this docstring as the
            # tool description, so its text is kept unchanged.
            return f"Triage decision received: Level {level}. Reason: {reasoning}"

        super().__init__(mcp)
        self._state = TriageState(episode_id=str(uuid4()))
        # Scenario record from TASKS; None until reset() is called.
        self._current_task: Optional[Dict[str, Any]] = None
        # Reward of the most recently graded triage decision; None until graded.
        self._last_reward: Optional[float] = None

    def reset(self, task_id: Optional[str] = "TASK_EASY", **kwargs: Any) -> TriageObservation:
        """Reset the environment with a specific task (EASY, MEDIUM, or HARD).

        Args:
            task_id: Key into ``TASKS``. ``None``, empty, or unknown ids fall
                back to ``"TASK_EASY"``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The opening observation for the selected patient.
        """
        task_id = task_id or "TASK_EASY"
        if task_id not in TASKS:
            task_id = "TASK_EASY"
        self._current_task = TASKS[task_id]
        self._last_reward = None
        self._state = TriageState(
            episode_id=str(uuid4()),
            step_count=0,
            current_task_id=task_id,
            ground_truth_level=self._current_task["ground_truth"],
        )
        return self._patient_observation(
            message=f"New Patient Triage: {self._current_task['name']}"
        )

    def _patient_observation(self, **fields: Any) -> TriageObservation:
        """Build an observation describing the currently loaded patient.

        Requires ``self._current_task`` to be set (i.e. ``reset`` was called).

        Args:
            **fields: Additional observation fields (e.g. ``message``, ``done``,
                ``reward``) forwarded to ``TriageObservation``.
        """
        patient = self._current_task["patient"]
        return TriageObservation(
            patient_id=patient["patient_id"],
            age=patient["age"],
            gender=patient["gender"],
            symptoms_text=patient["symptoms_text"],
            vitals=patient["vitals"],
            history=patient["history"],
            **fields,
        )

    def _calculate_reward(self, agent_level: TriageLevel, ground_truth: TriageLevel) -> float:
        """
        Scoring Logic (Strictly 0.1 - 0.9 to pass Phase 2 Validation):
        - Perfect Match: 0.9
        - Major Under-triage (dangerous): 0.1 -- only when the ground truth is
          EMERGENCY and the agent chose a level below URGENT_CARE
        - Over-triage (too safe): 0.5
        - Any other under-triage: 0.2

        Assumes TriageLevel supports ``<``/``>`` ordered by severity (e.g. an
        IntEnum) -- confirm in models.
        """
        if agent_level == ground_truth:
            return 0.9
        # Dangerously Under-triaging an Emergency
        if ground_truth == TriageLevel.EMERGENCY and agent_level < TriageLevel.URGENT_CARE:
            return 0.1
        # Over-triaging is better than under-triaging in medicine
        if agent_level > ground_truth:
            return 0.5
        return 0.2

    @staticmethod
    def _parse_level(raw: Any) -> Optional[TriageLevel]:
        """Convert a raw ``level`` tool argument to a TriageLevel.

        Returns None when the value is missing, not integer-convertible, or
        not a valid TriageLevel value.
        """
        try:
            return TriageLevel(int(raw))
        except (TypeError, ValueError, OverflowError):
            return None

    def _step_impl(self, action: Action, **kwargs: Any) -> TriageObservation:
        """
        Process the agent's triage decision and return a score.

        Args:
            action: Expected to be a ``CallToolAction`` for ``triage_patient``
                whose arguments carry ``level``; anything else is rejected.
            **kwargs: Unused.

        Returns:
            A terminal observation (``done=True``, ``reward`` set) for a valid
            triage call; otherwise a non-terminal observation whose message
            explains why the action was rejected.
        """
        # Kept as a function-local import, as in the original implementation.
        from openenv.core.env_server.mcp_types import CallToolAction

        self._logger.debug("Received action type: %s", type(action))
        if hasattr(action, "tool_name"):
            self._logger.debug("tool_name: %s", action.tool_name)
        self._state.step_count += 1

        # No patient to show or grade against until reset() has run. This also
        # covers a triage call before reset, which previously raised TypeError.
        if self._current_task is None:
            return TriageObservation(
                patient_id="unknown",
                age=0,
                gender="unknown",
                symptoms_text="unknown",
                vitals={},
                history=[],
                message="Invalid action and no task loaded.",
            )

        # NOTE(review): confirm whether MCPEnvironment.step already dispatches
        # CallToolAction to the MCP server itself; if so, this branch is unreachable.
        is_triage_call = (
            isinstance(action, CallToolAction) and action.tool_name == "triage_patient"
        )
        if not is_triage_call:
            # For this env, any non-triage_patient action is a no-op.
            return self._patient_observation(
                message="Invalid action. Please use the triage_patient tool."
            )

        agent_level = (action.arguments or {}).get("level")
        level = self._parse_level(agent_level)
        if level is None:
            # A missing, non-numeric, or out-of-range level used to crash the step.
            return self._patient_observation(
                message=(
                    f"Invalid triage level: {agent_level!r}. "
                    "Use an integer from 0 (Self-Care) to 3 (Emergency)."
                )
            )

        ground_truth = self._state.ground_truth_level
        reward = self._calculate_reward(level, ground_truth)
        self._last_reward = reward
        return self._patient_observation(
            done=True,
            reward=reward,
            message=f"Episode complete. Agent Triage: {agent_level}. Ground Truth: {ground_truth.value}. Score: {reward}",
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id, step count, task, ground truth)."""
        return self._state
|