Upload folder using huggingface_hub
- Dockerfile +29 -0
- README.md +98 -5
- __init__.py +0 -0
- client.py +17 -0
- inference.py +46 -0
- models.py +38 -0
- openenv.yaml +6 -0
- pyproject.toml +28 -0
- server/__init__.py +0 -0
- server/app.py +83 -0
- server/requirements.txt +6 -0
- server/triage_environment.py +153 -0
- uv.lock +0 -0
Dockerfile
ADDED
@@ -0,0 +1,29 @@
+# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
+# Dockerfile for MedTriage Environment
+
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python dependencies
+COPY server/requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy environment code
+COPY . /app/
+
+# Set PYTHONPATH to include current directory for imports
+ENV PYTHONPATH="/app:$PYTHONPATH"
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+    CMD curl -f http://localhost:8002/health || exit 1
+
+# Run the FastAPI server
+ENV ENABLE_WEB_INTERFACE=true
+CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8002"]
README.md
CHANGED
@@ -1,10 +1,103 @@
 ---
-title:
-emoji:
-colorFrom: green
-colorTo: green
+title: MedTriage OpenEnv
+emoji: 🏥
 sdk: docker
 pinned: false
+app_port: 8002
+tags:
+- openenv
+- healthcare
+- ai-agents
+base_path: /web
 ---
 
-
+# MedTriage OpenEnv
+
+A real-world medical triage simulation environment built for the Meta PyTorch OpenEnv Hackathon. This environment allows AI agents to learn how to categorize patient symptoms into appropriate clinical triage levels using the standard OpenEnv API.
+
+## 📋 Environment Overview
+
+**MedTriage** simulates the decision-making process of a clinical triage officer. The agent receives patient demographics, vitals, and unstructured symptom text, and must decide on the safest and most efficient path for care.
+
+### 🎯 Real-World Utility
+In real healthcare settings, accurate triage is critical for:
+1. **Patient Safety**: Ensuring life-threatening conditions (like heart attacks) are seen immediately.
+2. **Resource Optimization**: Preventing hospital ERs from being overwhelmed by minor cases that can be treated at home.
+
+---
+
+## 🎮 Action Space
+
+The agent interacts via the `triage_patient` tool:
+
+- **level**: (IntEnum)
+  - `0`: **Self-Care** (Over-the-counter/rest)
+  - `1`: **Clinic** (Primary Care appointment in 24-48h)
+  - `2`: **Urgent Care** (Same-day care)
+  - `3`: **Emergency** (Immediate ER/Ambulance)
+- **reasoning**: (String) A medical justification for the triage level.
+
+---
+
+## 📥 Observation Space
+
+Each observation provides:
+- **patient_id**: Unique identifier.
+- **age / gender**: Basic demographics.
+- **symptoms_text**: Unstructured description of the patient's complaint.
+- **vitals**: Dictionary containing `temp`, `bp` (Blood Pressure), `hr` (Heart Rate), and `spo2` (Oxygen).
+- **history**: List of prior medical conditions or medications.
+
+---
+
+## 🚀 Tasks & Difficulty
+
+The environment includes 3 built-in tasks with automated graders:
+
+| Task ID | Name | Difficulty | Ground Truth |
+|---------|------|------------|--------------|
+| `TASK_EASY` | Seasonal Allergies | Easy | Self-Care (0) |
+| `TASK_MEDIUM` | Possible Appendicitis | Medium | Urgent Care (2) |
+| `TASK_HARD` | Atypical MI | Hard | Emergency (3) |
+
+---
+
+## 📈 Reward Function (Grader)
+
+Scores range from **0.0 to 1.0**:
+- **1.0**: Perfect match with ground truth.
+- **0.5**: Over-triage (Safe but resource-intensive).
+- **0.2**: Minor under-triage.
+- **0.0**: Dangerous under-triage (e.g., sending a heart attack to self-care).
+
+---
+
+## 🛠️ Setup & Usage
+
+### Local Development
+1. **Install Dependencies**:
+   ```bash
+   pip install -e .
+   ```
+2. **Start the Server**:
+   ```bash
+   python server/app.py
+   ```
+3. **Run Baseline**:
+   ```bash
+   python inference.py
+   ```
+
+### Docker
+```bash
+docker build -t med-triage-env:latest .
+docker run -p 8002:8002 med-triage-env:latest
+```
+
+---
+
+## 🌐 API Endpoints
+- `/tasks`: List all available tasks.
+- `/baseline`: Run the baseline inference.
+- `/grader`: Get the score of the last episode.
+- `/health`: Environment health check.
|
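Review note on the README's reward tiers: they can be cross-checked against `_calculate_reward` in `server/triage_environment.py` with a small standalone sketch. The free function `score`, the helper `reward_matrix`, and the local copy of `TriageLevel` below are illustrative names for this note only, not part of the commit. Assuming the committed branch order is as shown in the diff, the sketch tabulates all sixteen agent/ground-truth pairs.

```python
from enum import IntEnum


class TriageLevel(IntEnum):
    # Local copy of the enum from models.py, repeated so the sketch runs alone.
    SELF_CARE = 0
    CLINIC = 1
    URGENT_CARE = 2
    EMERGENCY = 3


def score(agent: TriageLevel, truth: TriageLevel) -> float:
    """Mirror the branch order of _calculate_reward as shown in the diff."""
    if agent == truth:
        return 1.0
    # Missing an emergency by two or more levels is treated as dangerous.
    if truth == TriageLevel.EMERGENCY and agent < TriageLevel.URGENT_CARE:
        return 0.0
    # Over-triage: safe for the patient, costly for the system.
    if agent > truth:
        return 0.5
    # Every remaining case is an under-triage that was not flagged as dangerous.
    return 0.2


def reward_matrix() -> dict:
    """Return {(agent, truth): reward} for every combination of levels."""
    return {(a, t): score(a, t) for a in TriageLevel for t in TriageLevel}


if __name__ == "__main__":
    matrix = reward_matrix()
    for truth in TriageLevel:
        row = "  ".join(f"{matrix[(agent, truth)]:.1f}" for agent in TriageLevel)
        print(f"{truth.name:<12} {row}")
```

One consequence worth noting: under this logic only a missed Emergency sent to Self-Care or Clinic scores 0.0. Sending an Urgent Care case to Self-Care still scores 0.2, which the README would describe as "minor" under-triage.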
__init__.py
ADDED
|
File without changes
|
client.py
ADDED
@@ -0,0 +1,17 @@
+# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
+# MedTriage Environment Client
+
+from openenv.core.mcp_client import MCPToolClient
+
+class MedTriageEnv(MCPToolClient):
+    """
+    Client for the MedTriage Environment.
+
+    Example:
+    >>> with MedTriageEnv(base_url="http://localhost:8002") as env:
+    ...     obs = env.reset(task_id="TASK_HARD")
+    ...     print(obs.symptoms_text)
+    ...     result = env.call_tool("triage_patient", level=3, reasoning="High BP and atypical symptoms in elderly patient.")
+    ...     print(f"Reward: {result.reward}")
+    """
+    pass
inference.py
ADDED
@@ -0,0 +1,46 @@
+# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
+# MedTriage - Baseline Inference Script
+
+import os
+import time
+import subprocess
+import requests
+from client import MedTriageEnv
+
+def run_baseline(base_url: str = "http://localhost:8002"):
+    """Run baseline agent against all 3 tasks and return results."""
+    print(f"🚀 Starting MedTriage Baseline Inference on {base_url}...")
+
+    tasks = ["TASK_EASY", "TASK_MEDIUM", "TASK_HARD"]
+    scores = {}
+
+    # Simple heuristic-based baseline (no LLM required for this local test)
+    try:
+        from client import MedTriageEnv
+    except ImportError:
+        from .client import MedTriageEnv
+
+    with MedTriageEnv(base_url=base_url).sync() as env:
+        for task_id in tasks:
+            print(f"📋 Running {task_id}...", end=" ", flush=True)
+            obs = env.reset(task_id=task_id)
+
+            # Simple heuristic logic
+            bp_sys = int(obs.vitals.get("bp", "120/80").split("/")[0])
+
+            if bp_sys > 150 or obs.age > 65:
+                level = 3 # EMERGENCY
+            elif "severe pain" in obs.symptoms_text.lower():
+                level = 2 # URGENT_CARE
+            else:
+                level = 0 # SELF_CARE
+
+            result = env.step({"tool_name": "triage_patient", "arguments": {"level": level, "reasoning": "Heuristic baseline."}})
+            scores[task_id] = result.reward
+            print(f"Score: {result.reward}")
+
+    return scores
+
+if __name__ == "__main__":
+    results = run_baseline()
+    print("\n📊 FINAL BASELINE SCORES:", results)
|
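Review note on the baseline: the heuristic inside `run_baseline` can be exercised without a running server by lifting it into a pure function. The sketch below does that under stated assumptions: `heuristic_level` and `SAMPLE_PATIENTS` are hypothetical names for this note, and the patient records are abridged copies of the three entries in `TASKS` from `server/triage_environment.py`.

```python
from typing import Any, Dict


def heuristic_level(age: int, symptoms_text: str, vitals: Dict[str, Any]) -> int:
    """Same decision rule as the loop body in inference.py, without I/O."""
    # Systolic pressure is the number before the slash in a "120/80" reading.
    bp_sys = int(str(vitals.get("bp", "120/80")).split("/")[0])
    if bp_sys > 150 or age > 65:
        return 3  # EMERGENCY
    if "severe pain" in symptoms_text.lower():
        return 2  # URGENT_CARE
    return 0  # SELF_CARE


# Abridged copies of the three built-in patients (symptom text shortened).
SAMPLE_PATIENTS = {
    "TASK_EASY": {"age": 28, "symptoms_text": "runny nose, sneezing, and itchy eyes",
                  "vitals": {"bp": "120/80"}},
    "TASK_MEDIUM": {"age": 19, "symptoms_text": "severe pain around my belly button",
                    "vitals": {"bp": "115/75"}},
    "TASK_HARD": {"age": 68, "symptoms_text": "extremely weak, weird 'indigestion' sensation",
                  "vitals": {"bp": "165/100"}},
}

if __name__ == "__main__":
    for task_id, patient in SAMPLE_PATIENTS.items():
        print(task_id, heuristic_level(**patient))
```

Two properties follow directly from the rule: it never returns level 1 (Clinic), and any patient over 65 is sent to Emergency regardless of symptoms. On the three built-in tasks it happens to pick the ground-truth level each time, so the baseline is not a discriminating benchmark for these tasks.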
models.py
ADDED
@@ -0,0 +1,38 @@
+# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
+# MedTriage Environment - Type-Safe Models
+
+from typing import Dict, List, Optional, Any
+from enum import IntEnum
+from pydantic import BaseModel, Field
+
+# Core Triage Levels
+class TriageLevel(IntEnum):
+    SELF_CARE = 0  # Over-the-counter/rest
+    CLINIC = 1  # Primary Care appointment (next 24-48h)
+    URGENT_CARE = 2  # Same-day clinic (e.g., potential fracture)
+    EMERGENCY = 3  # Immediate ER/Ambulance (life-threatening)
+
+# 1. Action Model
+class TriageAction(BaseModel):
+    level: TriageLevel = Field(..., description="Recommended triage level (0-3)")
+    reasoning: str = Field(..., description="Medical justification for the triage level")
+    follow_up_questions: List[str] = Field(default_factory=list, description="Questions to ask the patient if more info is needed")
+
+# 2. Observation Model
+class TriageObservation(BaseModel):
+    patient_id: str
+    age: int
+    gender: str
+    symptoms_text: str = Field(..., description="Unstructured description of symptoms from patient")
+    vitals: Dict[str, Any] = Field(default_factory=dict, description="Vitals like temp, bp, hr, spo2")
+    history: List[str] = Field(default_factory=list, description="Relevant past conditions or medications")
+    done: bool = False
+    reward: float = 0.0
+    message: str = ""
+
+# 3. State Model (Metadata)
+class TriageState(BaseModel):
+    episode_id: str
+    step_count: int = 0
+    current_task_id: str = ""
+    ground_truth_level: TriageLevel = TriageLevel.SELF_CARE
|
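Review note on `TriageLevel`: `step()` in `server/triage_environment.py` passes the raw tool argument straight to `TriageLevel(...)`, and an `IntEnum` constructor raises `ValueError` for a missing or out-of-range value. A minimal sketch of a defensive parse follows; `parse_level` is a hypothetical helper for this note, not something the commit defines, and it uses only the standard library rather than the pydantic models.

```python
from enum import IntEnum
from typing import Optional


class TriageLevel(IntEnum):
    # Local copy of the enum from models.py, repeated so the sketch runs alone.
    SELF_CARE = 0
    CLINIC = 1
    URGENT_CARE = 2
    EMERGENCY = 3


def parse_level(raw: object) -> Optional[TriageLevel]:
    """Return the matching TriageLevel, or None when raw is not a valid level."""
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(raw, bool) or not isinstance(raw, int):
        return None
    try:
        return TriageLevel(raw)
    except ValueError:
        # Integers outside 0-3 land here.
        return None
```

A caller could then decide what an invalid level means for the episode (for example, a zero reward) instead of letting the exception propagate out of `step()`. Which policy is right is a design choice the commit does not settle.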
openenv.yaml
ADDED
@@ -0,0 +1,6 @@
+spec_version: 1
+name: med_triage_env
+type: space
+runtime: fastapi
+app: server.app:app
+port: 8002
pyproject.toml
ADDED
@@ -0,0 +1,28 @@
+[project]
+name = "med_triage_env"
+version = "0.1.0"
+description = "Real-world Medical Triage environment for OpenEnv training."
+authors = [
+    { name = "Gemini AI", email = "gemini@example.com" }
+]
+dependencies = [
+    "fastapi",
+    "uvicorn",
+    "pydantic",
+    "fastmcp",
+    "requests",
+    "openenv-core>=0.2.0",
+]
+
+[project.scripts]
+server = "server.app:main"
+
+[project.optional-dependencies]
+dev = [
+    "pytest",
+    "httpx",
+]
+
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
server/__init__.py
ADDED
|
File without changes
|
server/app.py
ADDED
@@ -0,0 +1,83 @@
+# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
+# FastAPI application for MedTriage Environment
+
+from fastapi import FastAPI, Request
+from openenv.core.env_server.http_server import create_app
+from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
+from pydantic import BaseModel
+from typing import List, Dict, Any
+
+try:
+    from .triage_environment import MedTriageEnvironment, TASKS
+    from .models import TriageAction
+except ImportError:
+    from triage_environment import MedTriageEnvironment, TASKS
+    from models import TriageAction
+
+# Initialize the environment instance to be used by the app
+env_instance = MedTriageEnvironment()
+
+# Create the base OpenEnv app
+app = create_app(
+    MedTriageEnvironment,
+    CallToolAction,
+    CallToolObservation,
+    env_name="med_triage_env"
+)
+
+# --- Additional Hackathon Endpoints ---
+
+@app.get("/tasks")
+async def get_tasks():
+    """Returns list of tasks and the action schema."""
+    task_list = []
+    for tid, tdata in TASKS.items():
+        task_list.append({
+            "id": tid,
+            "name": tdata["name"],
+            "difficulty": tid.split("_")[1].lower()
+        })
+
+    return {
+        "tasks": task_list,
+        "action_schema": TriageAction.model_json_schema()
+    }
+
+@app.get("/grader")
+async def get_grader():
+    """Returns the most recent grader score."""
+    state = env_instance.state
+    # In a real multi-session env, we'd lookup by session_id
+    # For a simple demo, we return the last calculated reward if available
+    return {"score": getattr(env_instance, "_last_reward", 0.0)}
+
+@app.get("/baseline")
+async def trigger_baseline():
+    """
+    Trigger baseline inference script and return scores.
+    """
+    try:
+        from ..inference import run_baseline
+    except ImportError:
+        import sys
+        import os
+        # Add parent dir to sys.path if not there
+        parent_dir = os.path.dirname(os.path.dirname(__file__))
+        if parent_dir not in sys.path:
+            sys.path.append(parent_dir)
+        from inference import run_baseline
+
+    # Execute actual baseline
+    scores = run_baseline(base_url="http://localhost:8002")
+
+    return {
+        "status": "baseline_completed",
+        "baseline_scores": scores
+    }
+
+def main():
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8002)
+
+if __name__ == "__main__":
+    main()
|
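Review note on the extra endpoints: a small standard-library client is enough to poke `/tasks`, `/grader`, `/baseline`, and `/health` once the server is up. The sketch below is a minimal one under assumptions: `endpoint_url` and `get_json` are hypothetical helper names, the base URL `http://localhost:8002` is taken from the Dockerfile and README, and it is assumed (not confirmed by the diff) that every endpoint, including `/health`, answers a GET with a JSON body.

```python
import json
from urllib.parse import urljoin
from urllib.request import urlopen

# Paths listed in the README's "API Endpoints" section.
ENDPOINTS = ("/tasks", "/grader", "/baseline", "/health")


def endpoint_url(base_url: str, path: str) -> str:
    """Join a base URL and an endpoint path, tolerating stray slashes."""
    return urljoin(base_url.rstrip("/") + "/", path.lstrip("/"))


def get_json(base_url: str, path: str, timeout: float = 5.0):
    """GET an endpoint and decode the body as JSON (needs a running server)."""
    with urlopen(endpoint_url(base_url, path), timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the container running, `get_json("http://localhost:8002", "/tasks")` would be expected to return the task list and action schema built in `get_tasks()`. One thing to check when trying `/grader`: it reads `_last_reward` from the module-level `env_instance`, while `create_app` is given the `MedTriageEnvironment` class rather than that instance, so as far as this diff shows the reported score may stay at the 0.0 default.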
server/requirements.txt
ADDED
@@ -0,0 +1,6 @@
+fastapi
+uvicorn
+pydantic
+fastmcp
+requests
+openenv
server/triage_environment.py
ADDED
@@ -0,0 +1,153 @@
+# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
+# MedTriage Environment Implementation
+
+import uuid
+from typing import Any, Dict, Optional
+from uuid import uuid4
+
+# Imports (Adjust according to actual structure)
+from openenv.core.env_server.mcp_environment import MCPEnvironment
+from openenv.core.env_server.types import Action, Observation, State
+from fastmcp import FastMCP
+
+# Use local models
+try:
+    from .models import TriageLevel, TriageAction, TriageObservation, TriageState
+except ImportError:
+    from models import TriageLevel, TriageAction, TriageObservation, TriageState
+
+# Task Scenarios (Easy -> Medium -> Hard)
+TASKS = {
+    "TASK_EASY": {
+        "id": "TASK_EASY",
+        "name": "Seasonal Allergies",
+        "patient": {
+            "patient_id": "P-101", "age": 28, "gender": "Female",
+            "symptoms_text": "I've had a runny nose, sneezing, and itchy eyes for the past week. It's really annoying but I don't feel 'sick' otherwise.",
+            "vitals": {"temp": 98.6, "bp": "120/80", "hr": 72, "spo2": 99},
+            "history": ["No major conditions"]
+        },
+        "ground_truth": TriageLevel.SELF_CARE
+    },
+    "TASK_MEDIUM": {
+        "id": "TASK_MEDIUM",
+        "name": "Possible Appendicitis",
+        "patient": {
+            "patient_id": "P-102", "age": 19, "gender": "Male",
+            "symptoms_text": "I woke up with severe pain around my belly button that's moving down to my lower right side. I feel nauseous and have zero appetite.",
+            "vitals": {"temp": 100.8, "bp": "115/75", "hr": 95, "spo2": 98},
+            "history": ["No major conditions"]
+        },
+        "ground_truth": TriageLevel.URGENT_CARE
+    },
+    "TASK_HARD": {
+        "id": "TASK_HARD",
+        "name": "Atypical Myocardial Infarction",
+        "patient": {
+            "patient_id": "P-103", "age": 68, "gender": "Female",
+            "symptoms_text": "I just feel extremely weak and have this weird 'indigestion' sensation in my upper stomach. I'm also sweating a lot for no reason.",
+            "vitals": {"temp": 98.2, "bp": "165/100", "hr": 105, "spo2": 94},
+            "history": ["Type 2 Diabetes", "High Blood Pressure", "Smoking"]
+        },
+        "ground_truth": TriageLevel.EMERGENCY
+    }
+}
+
+class MedTriageEnvironment(MCPEnvironment):
+    """
+    Real-world Triage Environment for Agent Training.
+    """
+
+    def __init__(self):
+        mcp = FastMCP("med_triage_env")
+
+        @mcp.tool
+        def triage_patient(level: int, reasoning: str) -> str:
+            """
+            Analyze patient data and assign a triage level (0-3).
+
+            Args:
+                level: 0 (Self-Care), 1 (Clinic), 2 (Urgent Care), 3 (Emergency)
+                reasoning: Medical explanation for your decision
+            """
+            return f"Triage decision received: Level {level}. Reason: {reasoning}"
+
+        super().__init__(mcp)
+        self._state = TriageState(episode_id=str(uuid4()))
+        self._current_task = None
+
+    def reset(self, task_id: Optional[str] = "TASK_EASY", **kwargs: Any) -> TriageObservation:
+        """Reset the environment with a specific task (EASY, MEDIUM, or HARD)."""
+        task_id = task_id or "TASK_EASY"
+        if task_id not in TASKS:
+            task_id = "TASK_EASY"
+
+        self._current_task = TASKS[task_id]
+        self._state = TriageState(
+            episode_id=str(uuid4()),
+            step_count=0,
+            current_task_id=task_id,
+            ground_truth_level=self._current_task["ground_truth"]
+        )
+
+        patient = self._current_task["patient"]
+        return TriageObservation(
+            patient_id=patient["patient_id"],
+            age=patient["age"],
+            gender=patient["gender"],
+            symptoms_text=patient["symptoms_text"],
+            vitals=patient["vitals"],
+            history=patient["history"],
+            message=f"New Patient Triage: {self._current_task['name']}"
+        )
+
+    def _calculate_reward(self, agent_level: TriageLevel, ground_truth: TriageLevel) -> float:
+        """
+        Scoring Logic (0.0 - 1.0):
+        - Perfect Match: 1.0
+        - Over-triage (too safe): 0.5 (safe but resource heavy)
+        - Minor Under-triage: 0.2 (delay in care)
+        - Major Under-triage (dangerous): 0.0 (unsafe)
+        """
+        if agent_level == ground_truth:
+            return 1.0
+
+        # Dangerously Under-triaging an Emergency
+        if ground_truth == TriageLevel.EMERGENCY and agent_level < TriageLevel.URGENT_CARE:
+            return 0.0
+
+        # Over-triaging is better than under-triaging in medicine
+        if agent_level > ground_truth:
+            return 0.5
+
+        return 0.2
+
+    def step(self, action: Action, **kwargs: Any) -> TriageObservation:
+        """
+        Process the agent's triage decision and return a score.
+        """
+        self._state.step_count += 1
+
+        # If the action is an MCP CallToolAction (from step())
+        from openenv.core.env_server.mcp_types import CallToolAction
+
+        if isinstance(action, CallToolAction) and action.tool_name == "triage_patient":
+            agent_level = action.arguments.get("level")
+            reward = self._calculate_reward(TriageLevel(agent_level), self._state.ground_truth_level)
+            self._last_reward = reward
+
+            patient = self._current_task["patient"]
+            return TriageObservation(
+                **patient,
+                done=True,
+                reward=reward,
+                message=f"Episode complete. Agent Triage: {agent_level}. Ground Truth: {self._state.ground_truth_level.value}. Score: {reward}"
+            )
+
+        # Handle non-MCP fallback or invalid actions
+        obs = super().step(action, **kwargs)
+        return obs
+
+    @property
+    def state(self) -> State:
+        return self._state
uv.lock
ADDED
|
The diff for this file is too large to render.