ashdev commited on
Commit
33fa849
·
verified ·
1 Parent(s): b084668

Upload folder using huggingface_hub

Browse files
Files changed (13) hide show
  1. Dockerfile +29 -0
  2. README.md +98 -5
  3. __init__.py +0 -0
  4. client.py +17 -0
  5. inference.py +46 -0
  6. models.py +38 -0
  7. openenv.yaml +6 -0
  8. pyproject.toml +28 -0
  9. server/__init__.py +0 -0
  10. server/app.py +83 -0
  11. server/requirements.txt +6 -0
  12. server/triage_environment.py +153 -0
  13. uv.lock +0 -0
Dockerfile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
2
+ # Dockerfile for MedTriage Environment
3
+
4
+ FROM python:3.11-slim
5
+
6
+ WORKDIR /app
7
+
8
+ # Install system dependencies
9
+ RUN apt-get update && apt-get install -y --no-install-recommends \
10
+ curl \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Install Python dependencies
14
+ COPY server/requirements.txt .
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Copy environment code
18
+ COPY . /app/
19
+
20
+ # Set PYTHONPATH to include current directory for imports
21
+ ENV PYTHONPATH="/app:$PYTHONPATH"
22
+
23
+ # Health check
24
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
25
+ CMD curl -f http://localhost:8002/health || exit 1
26
+
27
+ # Run the FastAPI server
28
+ ENV ENABLE_WEB_INTERFACE=true
29
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8002"]
README.md CHANGED
@@ -1,10 +1,103 @@
1
  ---
2
- title: Med Triage Openenv
3
- emoji: 🚀
4
- colorFrom: green
5
- colorTo: green
6
  sdk: docker
7
  pinned: false
 
 
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: MedTriage OpenEnv
3
+ emoji: 🏥
 
 
4
  sdk: docker
5
  pinned: false
6
+ app_port: 8002
7
+ tags:
8
+ - openenv
9
+ - healthcare
10
+ - ai-agents
11
+ base_path: /web
12
  ---
13
 
14
+ # MedTriage OpenEnv
15
+
16
+ A real-world medical triage simulation environment built for the Meta PyTorch OpenEnv Hackathon. This environment allows AI agents to learn how to categorize patient symptoms into appropriate clinical triage levels using the standard OpenEnv API.
17
+
18
+ ## 📋 Environment Overview
19
+
20
+ **MedTriage** simulates the decision-making process of a clinical triage officer. The agent receives patient demographics, vitals, and unstructured symptom text, and must decide on the safest and most efficient path for care.
21
+
22
+ ### 🎯 Real-World Utility
23
+ In real healthcare settings, accurate triage is critical for:
24
+ 1. **Patient Safety**: Ensuring life-threatening conditions (like heart attacks) are seen immediately.
25
+ 2. **Resource Optimization**: Preventing hospital ERs from being overwhelmed by minor cases that can be treated at home.
26
+
27
+ ---
28
+
29
+ ## 🎮 Action Space
30
+
31
+ The agent interacts via the `triage_patient` tool:
32
+
33
+ - **level**: (IntEnum)
34
+ - `0`: **Self-Care** (Over-the-counter/rest)
35
+ - `1`: **Clinic** (Primary Care appointment in 24-48h)
36
+ - `2`: **Urgent Care** (Same-day care)
37
+ - `3`: **Emergency** (Immediate ER/Ambulance)
38
+ - **reasoning**: (String) A medical justification for the triage level.
39
+
40
+ ---
41
+
42
+ ## 📥 Observation Space
43
+
44
+ Each observation provides:
45
+ - **patient_id**: Unique identifier.
46
+ - **age / gender**: Basic demographics.
47
+ - **symptoms_text**: Unstructured description of the patient's complaint.
48
+ - **vitals**: Dictionary containing `temp`, `bp` (Blood Pressure), `hr` (Heart Rate), and `spo2` (Oxygen).
49
+ - **history**: List of prior medical conditions or medications.
50
+
51
+ ---
52
+
53
+ ## 🚀 Tasks & Difficulty
54
+
55
+ The environment includes 3 built-in tasks with automated graders:
56
+
57
+ | Task ID | Name | Difficulty | Ground Truth |
58
+ |---------|------|------------|--------------|
59
+ | `TASK_EASY` | Seasonal Allergies | Easy | Self-Care (0) |
60
+ | `TASK_MEDIUM` | Possible Appendicitis | Medium | Urgent Care (2) |
61
+ | `TASK_HARD` | Atypical MI | Hard | Emergency (3) |
62
+
63
+ ---
64
+
65
+ ## 📈 Reward Function (Grader)
66
+
67
+ Scores range from **0.0 to 1.0**:
68
+ - **1.0**: Perfect match with ground truth.
69
+ - **0.5**: Over-triage (Safe but resource-intensive).
70
+ - **0.2**: Minor under-triage.
71
+ - **0.0**: Dangerous under-triage (e.g., sending a heart attack to self-care).
72
+
73
+ ---
74
+
75
+ ## 🛠️ Setup & Usage
76
+
77
+ ### Local Development
78
+ 1. **Install Dependencies**:
79
+ ```bash
80
+ pip install -e .
81
+ ```
82
+ 2. **Start the Server**:
83
+ ```bash
84
+ python server/app.py
85
+ ```
86
+ 3. **Run Baseline**:
87
+ ```bash
88
+ python inference.py
89
+ ```
90
+
91
+ ### Docker
92
+ ```bash
93
+ docker build -t med-triage-env:latest .
94
+ docker run -p 8002:8002 med-triage-env:latest
95
+ ```
96
+
97
+ ---
98
+
99
+ ## 🌐 API Endpoints
100
+ - `/tasks`: List all available tasks.
101
+ - `/baseline`: Run the baseline inference.
102
+ - `/grader`: Get the score of the last episode.
103
+ - `/health`: Environment health check.
__init__.py ADDED
File without changes
client.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
2
+ # MedTriage Environment Client
3
+
4
+ from openenv.core.mcp_client import MCPToolClient
5
+
6
+ class MedTriageEnv(MCPToolClient):
7
+ """
8
+ Client for the MedTriage Environment.
9
+
10
+ Example:
11
+ >>> with MedTriageEnv(base_url="http://localhost:8000") as env:
12
+ ... obs = env.reset(task_id="TASK_HARD")
13
+ ... print(obs.symptoms_text)
14
+ ... result = env.call_tool("triage_patient", level=3, reasoning="High BP and atypical symptoms in elderly patient.")
15
+ ... print(f"Reward: {result.reward}")
16
+ """
17
+ pass
inference.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
2
+ # MedTriage - Baseline Inference Script
3
+
4
+ import os
5
+ import time
6
+ import subprocess
7
+ import requests
8
+ from client import MedTriageEnv
9
+
10
+ def run_baseline(base_url: str = "http://localhost:8002"):
11
+ """Run baseline agent against all 3 tasks and return results."""
12
+ print(f"🚀 Starting MedTriage Baseline Inference on {base_url}...")
13
+
14
+ tasks = ["TASK_EASY", "TASK_MEDIUM", "TASK_HARD"]
15
+ scores = {}
16
+
17
+ # Simple heuristic-based baseline (no LLM required for this local test)
18
+ try:
19
+ from client import MedTriageEnv
20
+ except ImportError:
21
+ from .client import MedTriageEnv
22
+
23
+ with MedTriageEnv(base_url=base_url).sync() as env:
24
+ for task_id in tasks:
25
+ print(f"📋 Running {task_id}...", end=" ", flush=True)
26
+ obs = env.reset(task_id=task_id)
27
+
28
+ # Simple heuristic logic
29
+ bp_sys = int(obs.vitals.get("bp", "120/80").split("/")[0])
30
+
31
+ if bp_sys > 150 or obs.age > 65:
32
+ level = 3 # EMERGENCY
33
+ elif "severe pain" in obs.symptoms_text.lower():
34
+ level = 2 # URGENT_CARE
35
+ else:
36
+ level = 0 # SELF_CARE
37
+
38
+ result = env.step({"tool_name": "triage_patient", "arguments": {"level": level, "reasoning": "Heuristic baseline."}})
39
+ scores[task_id] = result.reward
40
+ print(f"Score: {result.reward}")
41
+
42
+ return scores
43
+
44
+ if __name__ == "__main__":
45
+ results = run_baseline()
46
+ print("\n📊 FINAL BASELINE SCORES:", results)
models.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
2
+ # MedTriage Environment - Type-Safe Models
3
+
4
+ from typing import Dict, List, Optional, Any
5
+ from enum import IntEnum
6
+ from pydantic import BaseModel, Field
7
+
8
+ # Core Triage Levels
9
+ class TriageLevel(IntEnum):
10
+ SELF_CARE = 0 # Over-the-counter/rest
11
+ CLINIC = 1 # Primary Care appointment (next 24-48h)
12
+ URGENT_CARE = 2 # Same-day clinic (e.g., potential fracture)
13
+ EMERGENCY = 3 # Immediate ER/Ambulance (life-threatening)
14
+
15
+ # 1. Action Model
16
+ class TriageAction(BaseModel):
17
+ level: TriageLevel = Field(..., description="Recommended triage level (0-3)")
18
+ reasoning: str = Field(..., description="Medical justification for the triage level")
19
+ follow_up_questions: List[str] = Field(default_factory=list, description="Questions to ask the patient if more info is needed")
20
+
21
+ # 2. Observation Model
22
+ class TriageObservation(BaseModel):
23
+ patient_id: str
24
+ age: int
25
+ gender: str
26
+ symptoms_text: str = Field(..., description="Unstructured description of symptoms from patient")
27
+ vitals: Dict[str, Any] = Field(default_factory=dict, description="Vitals like temp, bp, hr, spo2")
28
+ history: List[str] = Field(default_factory=list, description="Relevant past conditions or medications")
29
+ done: bool = False
30
+ reward: float = 0.0
31
+ message: str = ""
32
+
33
+ # 3. State Model (Metadata)
34
+ class TriageState(BaseModel):
35
+ episode_id: str
36
+ step_count: int = 0
37
+ current_task_id: str = ""
38
+ ground_truth_level: TriageLevel = TriageLevel.SELF_CARE
openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: med_triage_env
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8002
pyproject.toml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "med_triage_env"
3
+ version = "0.1.0"
4
+ description = "Real-world Medical Triage environment for OpenEnv training."
5
+ authors = [
6
+ { name = "Gemini AI", email = "gemini@example.com" }
7
+ ]
8
+ dependencies = [
9
+ "fastapi",
10
+ "uvicorn",
11
+ "pydantic",
12
+ "fastmcp",
13
+ "requests",
14
+ "openenv-core>=0.2.0",
15
+ ]
16
+
17
+ [project.scripts]
18
+ server = "server.app:main"
19
+
20
+ [project.optional-dependencies]
21
+ dev = [
22
+ "pytest",
23
+ "httpx",
24
+ ]
25
+
26
+ [build-system]
27
+ requires = ["setuptools>=61.0"]
28
+ build-backend = "setuptools.build_meta"
server/__init__.py ADDED
File without changes
server/app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
2
+ # FastAPI application for MedTriage Environment
3
+
4
+ from fastapi import FastAPI, Request
5
+ from openenv.core.env_server.http_server import create_app
6
+ from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
7
+ from pydantic import BaseModel
8
+ from typing import List, Dict, Any
9
+
10
+ try:
11
+ from .triage_environment import MedTriageEnvironment, TASKS
12
+ from .models import TriageAction
13
+ except ImportError:
14
+ from triage_environment import MedTriageEnvironment, TASKS
15
+ from models import TriageAction
16
+
17
+ # Initialize the environment instance to be used by the app
18
+ env_instance = MedTriageEnvironment()
19
+
20
+ # Create the base OpenEnv app
21
+ app = create_app(
22
+ MedTriageEnvironment,
23
+ CallToolAction,
24
+ CallToolObservation,
25
+ env_name="med_triage_env"
26
+ )
27
+
28
+ # --- Additional Hackathon Endpoints ---
29
+
30
+ @app.get("/tasks")
31
+ async def get_tasks():
32
+ """Returns list of tasks and the action schema."""
33
+ task_list = []
34
+ for tid, tdata in TASKS.items():
35
+ task_list.append({
36
+ "id": tid,
37
+ "name": tdata["name"],
38
+ "difficulty": tid.split("_")[1].lower()
39
+ })
40
+
41
+ return {
42
+ "tasks": task_list,
43
+ "action_schema": TriageAction.model_json_schema()
44
+ }
45
+
46
+ @app.get("/grader")
47
+ async def get_grader():
48
+ """Returns the most recent grader score."""
49
+ state = env_instance.state
50
+ # In a real multi-session env, we'd lookup by session_id
51
+ # For a simple demo, we return the last calculated reward if available
52
+ return {"score": getattr(env_instance, "_last_reward", 0.0)}
53
+
54
+ @app.get("/baseline")
55
+ async def trigger_baseline():
56
+ """
57
+ Trigger baseline inference script and return scores.
58
+ """
59
+ try:
60
+ from ..inference import run_baseline
61
+ except ImportError:
62
+ import sys
63
+ import os
64
+ # Add parent dir to sys.path if not there
65
+ parent_dir = os.path.dirname(os.path.dirname(__file__))
66
+ if parent_dir not in sys.path:
67
+ sys.path.append(parent_dir)
68
+ from inference import run_baseline
69
+
70
+ # Execute actual baseline
71
+ scores = run_baseline(base_url="http://localhost:8002")
72
+
73
+ return {
74
+ "status": "baseline_completed",
75
+ "baseline_scores": scores
76
+ }
77
+
78
+ def main():
79
+ import uvicorn
80
+ uvicorn.run(app, host="0.0.0.0", port=8002)
81
+
82
+ if __name__ == "__main__":
83
+ main()
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pydantic
4
+ fastmcp
5
+ requests
6
+ openenv
server/triage_environment.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
2
+ # MedTriage Environment Implementation
3
+
4
+ import uuid
5
+ from typing import Any, Dict, Optional
6
+ from uuid import uuid4
7
+
8
+ # Imports (Adjust according to actual structure)
9
+ from openenv.core.env_server.mcp_environment import MCPEnvironment
10
+ from openenv.core.env_server.types import Action, Observation, State
11
+ from fastmcp import FastMCP
12
+
13
+ # Use local models
14
+ try:
15
+ from .models import TriageLevel, TriageAction, TriageObservation, TriageState
16
+ except ImportError:
17
+ from models import TriageLevel, TriageAction, TriageObservation, TriageState
18
+
19
+ # Task Scenarios (Easy -> Medium -> Hard)
20
+ TASKS = {
21
+ "TASK_EASY": {
22
+ "id": "TASK_EASY",
23
+ "name": "Seasonal Allergies",
24
+ "patient": {
25
+ "patient_id": "P-101", "age": 28, "gender": "Female",
26
+ "symptoms_text": "I've had a runny nose, sneezing, and itchy eyes for the past week. It's really annoying but I don't feel 'sick' otherwise.",
27
+ "vitals": {"temp": 98.6, "bp": "120/80", "hr": 72, "spo2": 99},
28
+ "history": ["No major conditions"]
29
+ },
30
+ "ground_truth": TriageLevel.SELF_CARE
31
+ },
32
+ "TASK_MEDIUM": {
33
+ "id": "TASK_MEDIUM",
34
+ "name": "Possible Appendicitis",
35
+ "patient": {
36
+ "patient_id": "P-102", "age": 19, "gender": "Male",
37
+ "symptoms_text": "I woke up with severe pain around my belly button that's moving down to my lower right side. I feel nauseous and have zero appetite.",
38
+ "vitals": {"temp": 100.8, "bp": "115/75", "hr": 95, "spo2": 98},
39
+ "history": ["No major conditions"]
40
+ },
41
+ "ground_truth": TriageLevel.URGENT_CARE
42
+ },
43
+ "TASK_HARD": {
44
+ "id": "TASK_HARD",
45
+ "name": "Atypical Myocardial Infarction",
46
+ "patient": {
47
+ "patient_id": "P-103", "age": 68, "gender": "Female",
48
+ "symptoms_text": "I just feel extremely weak and have this weird 'indigestion' sensation in my upper stomach. I'm also sweating a lot for no reason.",
49
+ "vitals": {"temp": 98.2, "bp": "165/100", "hr": 105, "spo2": 94},
50
+ "history": ["Type 2 Diabetes", "High Blood Pressure", "Smoking"]
51
+ },
52
+ "ground_truth": TriageLevel.EMERGENCY
53
+ }
54
+ }
55
+
56
+ class MedTriageEnvironment(MCPEnvironment):
57
+ """
58
+ Real-world Triage Environment for Agent Training.
59
+ """
60
+
61
+ def __init__(self):
62
+ mcp = FastMCP("med_triage_env")
63
+
64
+ @mcp.tool
65
+ def triage_patient(level: int, reasoning: str) -> str:
66
+ """
67
+ Analyze patient data and assign a triage level (0-3).
68
+
69
+ Args:
70
+ level: 0 (Self-Care), 1 (Clinic), 2 (Urgent Care), 3 (Emergency)
71
+ reasoning: Medical explanation for your decision
72
+ """
73
+ return f"Triage decision received: Level {level}. Reason: {reasoning}"
74
+
75
+ super().__init__(mcp)
76
+ self._state = TriageState(episode_id=str(uuid4()))
77
+ self._current_task = None
78
+
79
+ def reset(self, task_id: Optional[str] = "TASK_EASY", **kwargs: Any) -> TriageObservation:
80
+ """Reset the environment with a specific task (EASY, MEDIUM, or HARD)."""
81
+ task_id = task_id or "TASK_EASY"
82
+ if task_id not in TASKS:
83
+ task_id = "TASK_EASY"
84
+
85
+ self._current_task = TASKS[task_id]
86
+ self._state = TriageState(
87
+ episode_id=str(uuid4()),
88
+ step_count=0,
89
+ current_task_id=task_id,
90
+ ground_truth_level=self._current_task["ground_truth"]
91
+ )
92
+
93
+ patient = self._current_task["patient"]
94
+ return TriageObservation(
95
+ patient_id=patient["patient_id"],
96
+ age=patient["age"],
97
+ gender=patient["gender"],
98
+ symptoms_text=patient["symptoms_text"],
99
+ vitals=patient["vitals"],
100
+ history=patient["history"],
101
+ message=f"New Patient Triage: {self._current_task['name']}"
102
+ )
103
+
104
+ def _calculate_reward(self, agent_level: TriageLevel, ground_truth: TriageLevel) -> float:
105
+ """
106
+ Scoring Logic (0.0 - 1.0):
107
+ - Perfect Match: 1.0
108
+ - Over-triage (too safe): 0.5 (safe but resource heavy)
109
+ - Minor Under-triage: 0.2 (delay in care)
110
+ - Major Under-triage (dangerous): 0.0 (unsafe)
111
+ """
112
+ if agent_level == ground_truth:
113
+ return 1.0
114
+
115
+ # Dangerously Under-triaging an Emergency
116
+ if ground_truth == TriageLevel.EMERGENCY and agent_level < TriageLevel.URGENT_CARE:
117
+ return 0.0
118
+
119
+ # Over-triaging is better than under-triaging in medicine
120
+ if agent_level > ground_truth:
121
+ return 0.5
122
+
123
+ return 0.2
124
+
125
+ def step(self, action: Action, **kwargs: Any) -> TriageObservation:
126
+ """
127
+ Process the agent's triage decision and return a score.
128
+ """
129
+ self._state.step_count += 1
130
+
131
+ # If the action is an MCP CallToolAction (from step())
132
+ from openenv.core.env_server.mcp_types import CallToolAction
133
+
134
+ if isinstance(action, CallToolAction) and action.tool_name == "triage_patient":
135
+ agent_level = action.arguments.get("level")
136
+ reward = self._calculate_reward(TriageLevel(agent_level), self._state.ground_truth_level)
137
+ self._last_reward = reward
138
+
139
+ patient = self._current_task["patient"]
140
+ return TriageObservation(
141
+ **patient,
142
+ done=True,
143
+ reward=reward,
144
+ message=f"Episode complete. Agent Triage: {agent_level}. Ground Truth: {self._state.ground_truth_level.value}. Score: {reward}"
145
+ )
146
+
147
+ # Handle non-MCP fallback or invalid actions
148
+ obs = super().step(action, **kwargs)
149
+ return obs
150
+
151
+ @property
152
+ def state(self) -> State:
153
+ return self._state
uv.lock ADDED
The diff for this file is too large to render. See raw diff