File size: 1,754 Bytes
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from env.environment import SupportTicketEnv
from env.models import Action
from typing import Dict
from uuid import uuid4
# ASGI application object; the route decorators below register handlers on it.
app = FastAPI(title="OpenEnv Support Ticket API")
# Store sessions keyed by UUID to allow concurrent sessions.
# Entries are added by /reset and looked up by /step and /state.
# NOTE(review): no code in this section removes entries, so the mapping only
# grows for the lifetime of the process — confirm whether eviction is needed.
SESSIONS: Dict[str, SupportTicketEnv] = {}
class InitRequest(BaseModel):
    """Request body accepted by ``POST /reset``."""

    # Identifier of the task to load; "task_easy_1" is used when the client
    # omits the field.
    task_id: str = "task_easy_1"
class StepRequest(BaseModel):
    """Request body accepted by ``POST /step``."""

    # Key into SESSIONS, as handed out by /reset.
    session_id: str
    # Action forwarded unchanged to the session environment's ``step``.
    action: Action
@app.get("/")
def read_root():
    """Return a fixed status payload for the API root."""
    payload = {
        "status": "ok",
        "message": "Support Ticket OpenEnv is live.",
    }
    return payload
@app.post("/reset")
def reset_env(req: InitRequest):
    """Start a new session and return its id with the first observation.

    Args:
        req: Request body naming the task to load via ``task_id``.

    Returns:
        A dict with ``session_id`` (a fresh UUID4 string) and ``observation``
        (the ``model_dump()`` of the object returned by ``env.reset()``).

    Raises:
        HTTPException: Status 400, carrying the message of any ``ValueError``
            raised while building the environment, resetting it, or dumping
            the observation.
    """
    try:
        env = SupportTicketEnv(task_id=req.task_id)
        observation = env.reset().model_dump()
    except ValueError as e:
        # Chain the cause so the original traceback is kept alongside the
        # HTTP error.
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Register the session only once the response payload exists, so a failure
    # above cannot leave behind an entry whose id was never returned.
    session_id = str(uuid4())
    SESSIONS[session_id] = env
    return {"session_id": session_id, "observation": observation}
@app.post("/step")
def step_env(req: StepRequest):
    """Apply one action to the environment of an existing session.

    Args:
        req: Carries the ``session_id`` to look up in ``SESSIONS`` and the
            ``action`` passed to ``env.step``.

    Returns:
        A dict with ``observation`` (``obs.model_dump()``), ``reward``,
        ``done`` and ``info``, taken from the 4-tuple returned by
        ``env.step``.

    Raises:
        HTTPException: Status 400 when ``session_id`` is not in ``SESSIONS``.
    """
    env = SESSIONS.get(req.session_id)
    # Compare with None explicitly: a truthiness test would also reject a
    # registered environment object that happens to evaluate as falsy.
    if env is None:
        raise HTTPException(status_code=400, detail="Invalid or expired session_id. Call /reset to create a session.")

    obs, reward, done, info = env.step(req.action)
    return {
        "observation": obs.model_dump(),
        "reward": reward,
        "done": done,
        "info": info,
    }
@app.get("/state")
def state_env(session_id: str):
    """Return the current state of an existing session's environment.

    Args:
        session_id: Key to look up in ``SESSIONS``.

    Returns:
        The ``model_dump()`` of the object returned by ``env.get_state()``.

    Raises:
        HTTPException: Status 400 when ``session_id`` is not in ``SESSIONS``.
    """
    env = SESSIONS.get(session_id)
    # Compare with None explicitly: a truthiness test would also reject a
    # registered environment object that happens to evaluate as falsy.
    if env is None:
        raise HTTPException(status_code=400, detail="Invalid or expired session_id. Call /reset to create a session.")
    return env.get_state().model_dump()
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve ``app`` with uvicorn.

    Args:
        host: Address passed to ``uvicorn.run`` as ``host``. Defaults to
            ``"0.0.0.0"``.
        port: Port passed to ``uvicorn.run`` as ``port``. Defaults to
            ``7860``.
    """
    # Imported here rather than at module level, so uvicorn is only needed
    # when the server is actually launched through this function.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|