# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
# FastAPI application for MedTriage Environment
import asyncio
import functools
from typing import List, Dict, Any

from fastapi import FastAPI, Request
from openenv.core.env_server.http_server import create_app
from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
from pydantic import BaseModel

try:
    from server.triage_environment import MedTriageEnvironment, TASKS
    from models import TriageAction
except ImportError:
    try:
        from .triage_environment import MedTriageEnvironment, TASKS
        from .models import TriageAction
    except ImportError:
        from triage_environment import MedTriageEnvironment, TASKS
        from models import TriageAction
# Environment instance read by the /grader endpoint below.
# NOTE(review): create_app is handed the MedTriageEnvironment class, not this
# object, so this is presumably a separate instance from whatever sessions the
# app creates. Confirm before relying on its state.
env_instance = MedTriageEnvironment()

# Base OpenEnv application: wires the environment class to the
# CallToolAction / CallToolObservation types under the "med_triage_env" name.
app = create_app(
    MedTriageEnvironment,
    CallToolAction,
    CallToolObservation,
    env_name="med_triage_env",
)
# --- Additional Hackathon Endpoints ---
@app.get("/tasks")
async def get_tasks():
    """Returns list of tasks and the action schema.

    Each task entry carries its id, its ``"name"`` from TASKS, and a
    difficulty label. The label is the lower-cased second ``_``-separated
    segment of the id (e.g. ``"task_easy"`` -> ``"easy"``). Ids without an
    underscore get ``"unknown"`` instead of failing the whole request with
    an IndexError.

    Returns:
        dict with ``"tasks"`` (list of task dicts) and ``"action_schema"``
        (the JSON schema of TriageAction).
    """
    def _difficulty(task_id):
        # Second segment of the id, or "unknown" when there is none.
        parts = task_id.split("_")
        return parts[1].lower() if len(parts) > 1 else "unknown"

    task_list = [
        {"id": tid, "name": tdata["name"], "difficulty": _difficulty(tid)}
        for tid, tdata in TASKS.items()
    ]
    return {
        "tasks": task_list,
        "action_schema": TriageAction.model_json_schema(),
    }
@app.get("/grader")
async def get_grader():
    """Returns the most recent grader score.

    Reads ``_last_reward`` from the module-level ``env_instance``. Falls back
    to ``0.0`` when that attribute has not been set.
    """
    # NOTE(review): env_instance is not the object passed to create_app (which
    # receives the MedTriageEnvironment class). Unless the environment shares
    # reward state across instances, this may always report 0.0. Confirm.
    # A real multi-session env would look the score up by session_id.
    return {"score": getattr(env_instance, "_last_reward", 0.0)}
@app.get("/baseline")
async def trigger_baseline():
    """
    Trigger baseline inference script and return scores.

    Returns:
        dict with ``"status"`` set to ``"baseline_completed"`` and
        ``"baseline_scores"`` holding whatever ``run_baseline`` returned.
    """
    try:
        # Works when this module is loaded inside a package whose parent
        # package also contains ``inference``.
        from ..inference import run_baseline
    except ImportError:
        import sys
        import os
        # Fall back to importing ``inference`` from the directory above this
        # file. abspath() matters here: with a relative __file__ such as
        # "app.py", two dirname() calls collapse to "" instead of a real path.
        parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        if parent_dir not in sys.path:
            sys.path.append(parent_dir)
        from inference import run_baseline

    # run_baseline is a plain (non-awaited) call aimed at localhost:7860,
    # the port main() serves this very app on. It presumably issues
    # blocking HTTP requests against base_url. Calling it directly inside
    # this coroutine would block the event loop, so the server could not
    # answer those requests (deadlock until they time out). Run it in the
    # default thread-pool executor instead.
    loop = asyncio.get_running_loop()
    scores = await loop.run_in_executor(
        None, functools.partial(run_baseline, base_url="http://localhost:7860")
    )
    return {
        "status": "baseline_completed",
        "baseline_scores": scores,
    }
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve the FastAPI ``app`` with uvicorn.

    Args:
        host: Interface to bind. The default binds all interfaces.
        port: TCP port to listen on. The ``/baseline`` endpoint targets
            ``http://localhost:7860``, so serving on another port leaves that
            endpoint pointing at the wrong address unless it is updated too.
    """
    # Imported lazily so merely importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|