File size: 2,398 Bytes
96a5caf
 
 
 
 
c3935be
96a5caf
5107b13
 
96a5caf
da62f9f
 
 
 
 
96a5caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5107b13
 
 
 
 
 
 
96a5caf
5107b13
96a5caf
 
 
5107b13
96a5caf
 
 
5107b13
 
 
 
 
96a5caf
 
 
 
 
 
 
 
c3935be
5107b13
c3935be
96a5caf
c3935be
96a5caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5107b13
96a5caf
 
 
 
 
5107b13
96a5caf
5107b13
96a5caf
 
a4d97e5
4e320a3
a4d97e5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
FastAPI server exposing the Email Triage environment via HTTP.
Endpoints mirror the OpenEnv spec.
"""

from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, Union
import uvicorn
import os
import sys

# Ensure the root directory is in sys.path so environment.py can be imported
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from environment import EmailTriageEnv, Action

# The ASGI application; title/version appear in the generated OpenAPI schema.
app = FastAPI(
    title="Email Triage Environment",
    version="1.0.0",
)

# Wide-open CORS: any origin may call any method with any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Process-wide registry of live environments, keyed by integer task id.
# Entries are created/replaced by /reset; other endpoints only read them.
_envs: dict[int, EmailTriageEnv] = {}


def _parse_task(task: Union[int, str]) -> int:
    if isinstance(task, str):
        if task.startswith("task"):
            return int(task[4:])
        return int(task)
    return task

class ResetRequest(BaseModel):
    """Optional JSON body for ``POST /reset``."""

    # Task to (re)start; an int such as 2 or a string such as "2" / "task2".
    task: Union[int, str] = Field(default=1)


class StepRequest(BaseModel):
    """JSON body for ``POST /step``: the target task plus the action to apply."""

    # Task whose live environment receives the action (int or "taskN" string).
    task: Union[int, str] = Field(default=1)
    # Action forwarded unchanged to the environment's step().
    action: Action


def _get_env(task: Union[int, str]) -> EmailTriageEnv:
    """Look up the live environment registered for ``task``.

    Raises:
        HTTPException: 422 if ``task`` is not a valid task identifier;
            400 if that task has not been initialised via ``/reset``.
    """
    try:
        task_int = _parse_task(task)
    except ValueError as err:
        # Bad client input: report it as a client error instead of a 500.
        raise HTTPException(status_code=422, detail=str(err)) from err
    try:
        return _envs[task_int]
    except KeyError:
        raise HTTPException(
            status_code=400,
            detail=f"Task {task_int} not initialised. Call /reset first.",
        ) from None


@app.get("/health")
def health():
    """Liveness probe; unconditionally answers with status "ok"."""
    return dict(status="ok")


@app.post("/reset")
def reset(req: Optional[ResetRequest] = Body(default=None)):
    """Create a fresh environment for a task, reset it, and register it.

    The request body is optional; without one, task 1 is used. Any
    environment previously registered for the same task is replaced.

    Returns:
        ``{"observation": ..., "state": ...}`` where ``observation`` is the
        reset observation serialised with ``model_dump()`` and ``state`` is
        the value of ``env.state()``.

    Raises:
        HTTPException: 422 if the task identifier cannot be parsed.
    """
    raw_task = req.task if req is not None else 1
    try:
        task = _parse_task(raw_task)
    except ValueError as err:
        # Bad client input: report it as a client error instead of a 500.
        raise HTTPException(status_code=422, detail=str(err)) from err
    env = EmailTriageEnv(task=task)
    obs = env.reset()
    # Register only after reset() returns, so if it raises, any previously
    # stored env for this task is left untouched.
    _envs[task] = env
    return {"observation": obs.model_dump(), "state": env.state()}


@app.post("/step")
def step(req: StepRequest):
    """Apply ``req.action`` to the live environment of ``req.task``.

    The response carries the post-step observation (via ``model_dump()``),
    the step's reward, the ``done`` flag, the env-supplied ``info`` and the
    environment's current score. ``_get_env`` raises HTTPException when the
    task has not been initialised.
    """
    env = _get_env(req.task)
    outcome = env.step(req.action)
    observation = outcome.observation.model_dump()
    return {
        "observation": observation,
        "reward": outcome.reward,
        "done": outcome.done,
        "info": outcome.info,
        "score": env.score(),
    }


@app.get("/state")
def state(task: Union[int, str] = 1):
    """Return the state snapshot and current score of a task's live env."""
    live_env = _get_env(task)
    snapshot = live_env.state()
    return {"state": snapshot, "score": live_env.score()}


@app.get("/score")
def score(task: Union[int, str] = 1):
    """Return the current score of a task's live env and its integer task id."""
    current = _get_env(task).score()
    return {"score": current, "task": _parse_task(task)}


def main(host: str = "0.0.0.0", port: int = 7860):
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind; the default binds all interfaces.
        port: TCP port to listen on.
    """
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()