# Copyright (c) 2026 Meta Platforms, Inc. and affiliates. # MedTriage Environment Implementation import uuid from typing import Any, Dict, Optional from uuid import uuid4 # Imports (Adjust according to actual structure) from openenv.core.env_server.mcp_environment import MCPEnvironment from openenv.core.env_server.types import Action, Observation, State from fastmcp import FastMCP # Use local models try: from .models import TriageLevel, TriageAction, TriageObservation, TriageState except ImportError: from models import TriageLevel, TriageAction, TriageObservation, TriageState # Task Scenarios (Easy -> Medium -> Hard) TASKS = { "TASK_EASY": { "id": "TASK_EASY", "name": "Seasonal Allergies", "patient": { "patient_id": "P-101", "age": 28, "gender": "Female", "symptoms_text": "I've had a runny nose, sneezing, and itchy eyes for the past week. It's really annoying but I don't feel 'sick' otherwise.", "vitals": {"temp": 98.6, "bp": "120/80", "hr": 72, "spo2": 99}, "history": ["No major conditions"] }, "ground_truth": TriageLevel.SELF_CARE }, "TASK_MEDIUM": { "id": "TASK_MEDIUM", "name": "Possible Appendicitis", "patient": { "patient_id": "P-102", "age": 19, "gender": "Male", "symptoms_text": "I woke up with severe pain around my belly button that's moving down to my lower right side. I feel nauseous and have zero appetite.", "vitals": {"temp": 100.8, "bp": "115/75", "hr": 95, "spo2": 98}, "history": ["No major conditions"] }, "ground_truth": TriageLevel.URGENT_CARE }, "TASK_HARD": { "id": "TASK_HARD", "name": "Atypical Myocardial Infarction", "patient": { "patient_id": "P-103", "age": 68, "gender": "Female", "symptoms_text": "I just feel extremely weak and have this weird 'indigestion' sensation in my upper stomach. I'm also sweating a lot for no reason.", "vitals": {"temp": 98.2, "bp": "165/100", "hr": 105, "spo2": 94}, "history": ["Type 2 Diabetes", "High Blood Pressure", "Smoking"] }, "ground_truth": TriageLevel.EMERGENCY } } class MedTriageEnvironment(MCPEnvironment): """ Real-world Triage Environment for Agent Training. """ def __init__(self): mcp = FastMCP("med_triage_env") @mcp.tool def triage_patient(level: int, reasoning: str) -> str: """ Analyze patient data and assign a triage level (0-3). Args: level: 0 (Self-Care), 1 (Clinic), 2 (Urgent Care), 3 (Emergency) reasoning: Medical explanation for your decision """ return f"Triage decision received: Level {level}. Reason: {reasoning}" super().__init__(mcp) self._state = TriageState(episode_id=str(uuid4())) self._current_task = None def reset(self, task_id: Optional[str] = "TASK_EASY", **kwargs: Any) -> TriageObservation: """Reset the environment with a specific task (EASY, MEDIUM, or HARD).""" task_id = task_id or "TASK_EASY" if task_id not in TASKS: task_id = "TASK_EASY" self._current_task = TASKS[task_id] self._state = TriageState( episode_id=str(uuid4()), step_count=0, current_task_id=task_id, ground_truth_level=self._current_task["ground_truth"] ) patient = self._current_task["patient"] return TriageObservation( patient_id=patient["patient_id"], age=patient["age"], gender=patient["gender"], symptoms_text=patient["symptoms_text"], vitals=patient["vitals"], history=patient["history"], message=f"New Patient Triage: {self._current_task['name']}" ) def _calculate_reward(self, agent_level: TriageLevel, ground_truth: TriageLevel) -> float: """ Scoring Logic (0.0 - 1.0): - Perfect Match: 1.0 - Over-triage (too safe): 0.5 (safe but resource heavy) - Minor Under-triage: 0.2 (delay in care) - Major Under-triage (dangerous): 0.0 (unsafe) """ if agent_level == ground_truth: return 1.0 # Dangerously Under-triaging an Emergency if ground_truth == TriageLevel.EMERGENCY and agent_level < TriageLevel.URGENT_CARE: return 0.0 # Over-triaging is better than under-triaging in medicine if agent_level > ground_truth: return 0.5 return 0.2 def _step_impl(self, action: Action, **kwargs: Any) -> TriageObservation: """ Process the agent's triage decision and return a score. """ self._state.step_count += 1 # If the action is an MCP CallToolAction from openenv.core.env_server.mcp_types import CallToolAction if isinstance(action, CallToolAction) and action.tool_name == "triage_patient": agent_level = action.arguments.get("level") reward = self._calculate_reward(TriageLevel(agent_level), self._state.ground_truth_level) self._last_reward = reward patient = self._current_task["patient"] return TriageObservation( **patient, done=True, reward=reward, message=f"Episode complete. Agent Triage: {agent_level}. Ground Truth: {self._state.ground_truth_level.value}. Score: {reward}" ) # Handle non-MCP fallback or invalid actions # For this env, any non-triage_patient action is a no-op or error if self._current_task: patient = self._current_task["patient"] return TriageObservation( **patient, message="Invalid action. Please use the triage_patient tool." ) else: return TriageObservation( patient_id="unknown", age=0, gender="unknown", symptoms_text="unknown", vitals={}, history=[], message="Invalid action and no task loaded." ) @property def state(self) -> State: return self._state