File size: 2,558 Bytes
33fa849
 
 
 
 
 
 
 
 
 
01b8ff5
33fa849
01b8ff5
 
 
 
 
 
 
33fa849
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38c97c8
33fa849
 
 
 
 
 
 
 
38c97c8
33fa849
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
# FastAPI application for MedTriage Environment

from fastapi import FastAPI, Request
from openenv.core.env_server.http_server import create_app
from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
from pydantic import BaseModel
from typing import List, Dict, Any

try:
    from server.triage_environment import MedTriageEnvironment, TASKS
    from models import TriageAction
except ImportError:
    try:
        from .triage_environment import MedTriageEnvironment, TASKS
        from .models import TriageAction
    except ImportError:
        from triage_environment import MedTriageEnvironment, TASKS
        from models import TriageAction

# Initialize the environment instance to be used by the app.
# NOTE(review): create_app() below is handed the MedTriageEnvironment *class*,
# so the framework presumably constructs its own instance internally; this
# module-level instance may therefore be a different object from the one that
# actually serves steps — verify that /grader reads the live session's state.
env_instance = MedTriageEnvironment()

# Create the base OpenEnv app: registers the standard OpenEnv HTTP routes
# using the MCP tool-call action/observation types. Extra hackathon routes
# are attached to this same app object below.
app = create_app(
    MedTriageEnvironment, 
    CallToolAction, 
    CallToolObservation, 
    env_name="med_triage_env"
)

# --- Additional Hackathon Endpoints ---

@app.get("/tasks")
async def get_tasks():
    """Return the list of available triage tasks plus the action JSON schema.

    Response shape:
        {"tasks": [{"id", "name", "difficulty"}, ...],
         "action_schema": <TriageAction JSON schema>}
    """
    task_list = []
    for tid, tdata in TASKS.items():
        # Task ids are expected to look like "<PREFIX>_<DIFFICULTY>[_...]".
        # Guard against ids without an underscore so a single malformed id
        # doesn't raise IndexError and turn the whole endpoint into a 500.
        parts = tid.split("_")
        task_list.append({
            "id": tid,
            "name": tdata["name"],
            "difficulty": parts[1].lower() if len(parts) > 1 else "unknown",
        })

    return {
        "tasks": task_list,
        "action_schema": TriageAction.model_json_schema(),
    }

@app.get("/grader")
async def get_grader():
    """Return the most recent grader score for the single demo session.

    In a real multi-session environment we would look the score up by
    session_id; for this simple demo we expose the last reward cached on the
    module-level environment instance (0.0 if nothing has been graded yet).
    """
    # Fixed: the original also read `env_instance.state` into an unused
    # local variable — a dead assignment with no effect on the response.
    return {"score": getattr(env_instance, "_last_reward", 0.0)}

@app.get("/baseline")
def trigger_baseline():
    """Trigger the baseline inference script and return its scores.

    Declared as a *sync* endpoint (plain ``def``) on purpose: FastAPI runs
    sync path operations in a worker thread, so the blocking
    ``run_baseline`` call — which issues HTTP requests back to this very
    server at localhost:7860 — cannot stall the event loop. The original
    ``async def`` version performed this blocking work directly on the
    event loop, which can deadlock a single-process server that must
    answer its own baseline requests.
    """
    try:
        # Package-relative import: works when this module runs as part of
        # the parent package.
        from ..inference import run_baseline
    except ImportError:
        # Script-style fallback: make the parent directory importable so a
        # flat `inference.py` next to this package can be found.
        import os
        import sys

        parent_dir = os.path.dirname(os.path.dirname(__file__))
        if parent_dir not in sys.path:
            sys.path.append(parent_dir)
        from inference import run_baseline

    # Execute the actual baseline against this server (see main(): the app
    # listens on port 7860).
    scores = run_baseline(base_url="http://localhost:7860")

    return {
        "status": "baseline_completed",
        "baseline_scores": scores,
    }

def main():
    """Serve the app with uvicorn, listening on all interfaces at port 7860."""
    import uvicorn

    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()