File size: 6,918 Bytes
11a8435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15f00b4
 
 
 
 
11a8435
 
15f00b4
11a8435
 
 
15f00b4
11a8435
 
 
 
 
 
 
34067c4
11a8435
 
 
bf6a463
 
 
11a8435
 
34067c4
11a8435
 
 
 
bf6a463
11a8435
 
 
bf6a463
11a8435
bf6a463
 
 
 
 
 
11a8435
 
 
 
 
 
34067c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11a8435
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
# MedTriage Environment Implementation

import logging
import uuid
from typing import Any, Dict, Optional
from uuid import uuid4

# Imports (Adjust according to actual structure)
from fastmcp import FastMCP
from openenv.core.env_server.mcp_environment import MCPEnvironment
from openenv.core.env_server.types import Action, Observation, State

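# Local models: prefer the package-relative import; fall back to a plain
# import when this file is run outside the package (e.g. as a script).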
try:
    from .models import TriageLevel, TriageAction, TriageObservation, TriageState
except ImportError:
    from models import TriageLevel, TriageAction, TriageObservation, TriageState

# Task scenarios, ordered by difficulty (easy -> medium -> hard).
#
# Each entry is keyed by its task id and carries:
#   "id" / "name"   - the task identifier and a human-readable scenario title
#   "patient"       - the patient record presented to the agent
#   "ground_truth"  - the TriageLevel the agent's decision is scored against
TASKS = {
    # Easy: mild symptoms, normal vitals, no relevant history.
    "TASK_EASY": {
        "id": "TASK_EASY",
        "name": "Seasonal Allergies",
        "patient": {
            "patient_id": "P-101",
            "age": 28,
            "gender": "Female",
            "symptoms_text": (
                "I've had a runny nose, sneezing, and itchy eyes for the past week. "
                "It's really annoying but I don't feel 'sick' otherwise."
            ),
            "vitals": {
                "temp": 98.6,
                "bp": "120/80",
                "hr": 72,
                "spo2": 99,
            },
            "history": ["No major conditions"],
        },
        "ground_truth": TriageLevel.SELF_CARE,
    },
    # Medium: migrating abdominal pain with a raised temperature.
    "TASK_MEDIUM": {
        "id": "TASK_MEDIUM",
        "name": "Possible Appendicitis",
        "patient": {
            "patient_id": "P-102",
            "age": 19,
            "gender": "Male",
            "symptoms_text": (
                "I woke up with severe pain around my belly button that's moving down "
                "to my lower right side. I feel nauseous and have zero appetite."
            ),
            "vitals": {
                "temp": 100.8,
                "bp": "115/75",
                "hr": 95,
                "spo2": 98,
            },
            "history": ["No major conditions"],
        },
        "ground_truth": TriageLevel.URGENT_CARE,
    },
    # Hard: vague complaints whose severity only shows once the vitals and
    # risk-factor history are taken into account.
    "TASK_HARD": {
        "id": "TASK_HARD",
        "name": "Atypical Myocardial Infarction",
        "patient": {
            "patient_id": "P-103",
            "age": 68,
            "gender": "Female",
            "symptoms_text": (
                "I just feel extremely weak and have this weird 'indigestion' sensation "
                "in my upper stomach. I'm also sweating a lot for no reason."
            ),
            "vitals": {
                "temp": 98.2,
                "bp": "165/100",
                "hr": 105,
                "spo2": 94,
            },
            "history": ["Type 2 Diabetes", "High Blood Pressure", "Smoking"],
        },
        "ground_truth": TriageLevel.EMERGENCY,
    },
}

class MedTriageEnvironment(MCPEnvironment):
    """
    Real-world Triage Environment for Agent Training.

    Each episode presents one patient scenario taken from ``TASKS``. The
    agent answers by calling the ``triage_patient`` tool; that decision is
    scored against the scenario's ground-truth level and ends the episode
    (the returned observation has ``done=True``).
    """

    def __init__(self):
        mcp = FastMCP("med_triage_env")

        @mcp.tool
        def triage_patient(level: int, reasoning: str) -> str:
            """
            Analyze patient data and assign a triage level (0-3).

            Args:
                level: 0 (Self-Care), 1 (Clinic), 2 (Urgent Care), 3 (Emergency)
                reasoning: Medical explanation for your decision
            """
            return f"Triage decision received: Level {level}. Reason: {reasoning}"

        super().__init__(mcp)
        self._state = TriageState(episode_id=str(uuid4()))
        # No scenario is loaded until reset() is called.
        self._current_task: Optional[Dict[str, Any]] = None
        # Score of the most recent triage decision; None until one is made.
        self._last_reward: Optional[float] = None

    def reset(self, task_id: Optional[str] = "TASK_EASY", **kwargs: Any) -> TriageObservation:
        """
        Reset the environment with a specific task (EASY, MEDIUM, or HARD).

        Args:
            task_id: Key into ``TASKS``. ``None``, an empty string, or an
                unknown id falls back to ``"TASK_EASY"``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The initial observation describing the selected patient.
        """
        requested = task_id or "TASK_EASY"
        if requested not in TASKS:
            # Keep the lenient fallback, but leave a trace so a mistyped
            # task id does not go unnoticed.
            logging.getLogger(__name__).warning(
                "Unknown task_id %r; falling back to TASK_EASY", requested
            )
            requested = "TASK_EASY"

        self._current_task = TASKS[requested]
        self._last_reward = None
        self._state = TriageState(
            episode_id=str(uuid4()),
            step_count=0,
            current_task_id=requested,
            ground_truth_level=self._current_task["ground_truth"],
        )

        return self._patient_observation(
            f"New Patient Triage: {self._current_task['name']}"
        )

    def _patient_observation(self, message: str, **extra: Any) -> TriageObservation:
        """
        Build an observation for the currently loaded patient.

        Args:
            message: Text placed in the observation's ``message`` field.
            **extra: Additional observation fields (e.g. ``done``, ``reward``).

        Callers must ensure a task is loaded (``self._current_task`` is set).
        """
        patient = self._current_task["patient"]
        return TriageObservation(
            patient_id=patient["patient_id"],
            age=patient["age"],
            gender=patient["gender"],
            symptoms_text=patient["symptoms_text"],
            # Shallow copies: the patient record lives in the module-level
            # TASKS table, so hand out fresh containers rather than the
            # shared ones.
            vitals=dict(patient["vitals"]),
            history=list(patient["history"]),
            message=message,
            **extra,
        )

    def _calculate_reward(self, agent_level: TriageLevel, ground_truth: TriageLevel) -> float:
        """
        Scoring Logic (Strictly 0.1 - 0.9 to pass Phase 2 Validation):
        - Perfect Match: 0.9
        - Over-triage (too safe): 0.5
        - Minor Under-triage: 0.2
        - Major Under-triage (dangerous): 0.1

        "Major" under-triage means the ground truth is EMERGENCY and the
        agent chose a level below URGENT_CARE. The ordering comparisons
        rely on TriageLevel members being orderable.
        """
        if agent_level == ground_truth:
            return 0.9

        # Dangerously Under-triaging an Emergency
        if ground_truth == TriageLevel.EMERGENCY and agent_level < TriageLevel.URGENT_CARE:
            return 0.1

        # Over-triaging is better than under-triaging in medicine
        if agent_level > ground_truth:
            return 0.5

        return 0.2

    def _step_impl(self, action: Action, **kwargs: Any) -> TriageObservation:
        """
        Process the agent's triage decision and return a score.

        A ``triage_patient`` tool call with a valid level is scored and ends
        the episode. Anything else (a different action, a missing or invalid
        level, or a call made before ``reset()``) yields an explanatory
        observation without a reward instead of raising.

        Args:
            action: The action submitted by the agent.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An observation; on a scored decision it carries ``done=True``
            and the computed ``reward``.
        """
        log = logging.getLogger(__name__)
        log.debug("Received action type: %s", type(action))
        if hasattr(action, "tool_name"):
            log.debug("tool_name: %s", action.tool_name)
        self._state.step_count += 1

        # Local import kept inside the method rather than at module level
        # (presumably to avoid an import cycle; confirm before hoisting).
        from openenv.core.env_server.mcp_types import CallToolAction

        is_triage_call = (
            isinstance(action, CallToolAction)
            and action.tool_name == "triage_patient"
        )

        # Nothing can be scored (or shown) before reset() has loaded a task.
        if self._current_task is None:
            if is_triage_call:
                message = "No task loaded. Call reset() before using the triage_patient tool."
            else:
                message = "Invalid action and no task loaded."
            return TriageObservation(
                patient_id="unknown",
                age=0,
                gender="unknown",
                symptoms_text="unknown",
                vitals={},
                history=[],
                message=message,
            )

        # For this env, any non-triage_patient action is a no-op.
        if not is_triage_call:
            return self._patient_observation(
                "Invalid action. Please use the triage_patient tool."
            )

        raw_level = (action.arguments or {}).get("level")
        try:
            # int() raises TypeError for None and ValueError for junk text;
            # TriageLevel() raises ValueError for an out-of-range number.
            agent_level = TriageLevel(int(raw_level))
        except (TypeError, ValueError):
            return self._patient_observation(
                f"Invalid triage level: {raw_level!r}. Expected an integer from 0 to 3."
            )

        reward = self._calculate_reward(agent_level, self._state.ground_truth_level)
        self._last_reward = reward

        return self._patient_observation(
            f"Episode complete. Agent Triage: {raw_level}. "
            f"Ground Truth: {self._state.ground_truth_level.value}. Score: {reward}",
            done=True,
            reward=reward,
        )

    @property
    def state(self) -> State:
        """Return the current episode state object."""
        return self._state