# Dar3devil's picture
# Allow empty POST body on reset
# 1f76703 verified
# Raw
# History Blame Contribute Delete
# 6.06 kB
from __future__ import annotations
from threading import Lock
from uuid import uuid4
from typing import Any
from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, RedirectResponse
from pydantic import BaseModel, ConfigDict
from support_ticket_env import SupportTicketEnvironment, list_task_ids
# Request body accepted by POST /reset. Every field is optional, so an empty
# JSON object is a valid body.
class ResetRequest(BaseModel):
    # extra="forbid": bodies carrying fields other than the two below fail validation.
    model_config = ConfigDict(extra="forbid")
    # Task to run; the /reset handler forwards it to the session manager and to env.reset().
    task_id: str | None = None
    # Session to reuse; the /reset handler hands it to SessionManager.create_or_reuse().
    session_id: str | None = None
# Request body accepted by POST /step. Both fields are required.
class StepRequest(BaseModel):
    # extra="forbid": bodies carrying fields other than the two below fail validation.
    model_config = ConfigDict(extra="forbid")
    # Identifier of the session to advance; the /step handler answers 404 when it is unknown.
    session_id: str
    # Free-form action mapping, passed unchanged to the environment's step().
    action: dict[str, Any]
class SessionManager:
    """In-memory registry mapping session ids to ``SupportTicketEnvironment`` objects.

    Every access to the underlying dict happens while holding a
    ``threading.Lock``. The environments themselves are handed back to the
    caller and used outside that lock.
    """

    def __init__(self) -> None:
        self._sessions: dict[str, SupportTicketEnvironment] = {}
        self._lock = Lock()

    def create_or_reuse(self, session_id: str | None = None, task_id: str | None = None) -> tuple[str, SupportTicketEnvironment]:
        """Return ``(session_id, env)``, creating the environment when needed.

        A non-empty ``session_id`` that is already registered is returned
        together with its existing environment; ``task_id`` is not consulted
        in that case. Otherwise a new environment is built with ``task_id``
        and stored under the supplied id, or under a fresh UUID4 string when
        the id is missing or empty.
        """
        with self._lock:
            already_known = bool(session_id) and session_id in self._sessions
            if already_known:
                key = session_id
            else:
                key = session_id or str(uuid4())
                self._sessions[key] = SupportTicketEnvironment(task_id=task_id)
            return key, self._sessions[key]

    def get(self, session_id: str) -> SupportTicketEnvironment:
        """Return the environment registered under ``session_id``.

        Raises:
            KeyError: if no such session exists; the id is the exception argument.
        """
        with self._lock:
            # A plain dict lookup raises KeyError(session_id) for unknown ids.
            return self._sessions[session_id]

    def delete(self, session_id: str) -> None:
        """Forget ``session_id``; unknown ids are ignored."""
        with self._lock:
            if session_id in self._sessions:
                del self._sessions[session_id]
# Single process-wide session registry shared by all HTTP route handlers below.
manager = SessionManager()
# The ASGI application; the route decorators that follow register handlers on it.
app = FastAPI(
    title="AcmeCloud Customer Support Ticket Handler",
    version="0.1.0",
    description="Deterministic OpenEnv-style customer support benchmark for B2B SaaS ticket handling.",
)
def _step_payload(result, session_id: str) -> dict[str, Any]:
    """Serialize an environment result and tag it with its session id.

    ``result`` must expose ``model_dump(mode="json")``. The dumped dict gets
    ``info["session_id"]`` set, creating the ``info`` entry when it is absent.
    """
    body = result.model_dump(mode="json")
    info = body.setdefault("info", {})
    info["session_id"] = session_id
    return body
# Liveness probe: reports a fixed "healthy" status plus the ids from list_task_ids().
@app.get("/health")
def health() -> dict[str, Any]:
    available_tasks = list_task_ids()
    return {"status": "healthy", "tasks": available_tasks}
# The bare root URL only forwards to the landing page at /web (307 redirect);
# include_in_schema=False keeps this route out of the generated schema.
@app.get("/", include_in_schema=False)
def root() -> RedirectResponse:
    landing_page = "/web"
    return RedirectResponse(url=landing_page, status_code=307)
# Start (or restart) an episode. The body is optional: a missing body is
# treated like an empty ResetRequest. The session is looked up or created via
# the shared manager, then the environment is reset with the requested task.
@app.post("/reset")
def reset(request: ResetRequest | None = Body(default=None)) -> dict[str, Any]:
    params = request or ResetRequest()
    session_id, env = manager.create_or_reuse(params.session_id, params.task_id)
    return _step_payload(env.reset(params.task_id), session_id)
# Apply one action to an existing session's environment.
# Answers 404 when the session id is not registered with the manager.
@app.post("/step")
def step(request: StepRequest) -> dict[str, Any]:
    sid = request.session_id
    try:
        env = manager.get(sid)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=f"Unknown session_id: {sid}") from exc
    return _step_payload(env.step(request.action), sid)
# Return the environment's current state merged with the session id.
# Answers 404 when the session id is not registered with the manager.
@app.get("/state")
def state(session_id: str) -> dict[str, Any]:
    try:
        env = manager.get(session_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=f"Unknown session_id: {session_id}") from exc
    snapshot: dict[str, Any] = {"session_id": session_id}
    # Keys coming from env.state() win over the "session_id" entry set above.
    snapshot.update(env.state())
    return snapshot
# Drop a session from the manager. SessionManager.delete ignores unknown ids,
# so the same "deleted" response is returned whether or not the session existed.
@app.delete("/session/{session_id}")
def close_session(session_id: str) -> dict[str, str]:
    manager.delete(session_id)
    response = {"status": "deleted"}
    response["session_id"] = session_id
    return response
# Minimal HTML landing page: lists the task ids from list_task_ids() and shows
# an example curl call against /reset.
@app.get("/web")
def web_ui() -> HTMLResponse:
    bullets = [f"<li><code>{task_id}</code></li>" for task_id in list_task_ids()]
    task_items = "".join(bullets)
    # The template lines stay flush-left so the emitted markup is unchanged;
    # doubled braces are literal braces inside the f-string.
    page = f"""
<html>
<head>
<title>AcmeCloud Customer Support Ticket Handler</title>
<style>
body {{ font-family: Segoe UI, sans-serif; margin: 2rem auto; max-width: 900px; line-height: 1.5; }}
code {{ background: #f4f4f4; padding: 0.15rem 0.35rem; border-radius: 0.25rem; }}
pre {{ background: #111827; color: #f9fafb; padding: 1rem; border-radius: 0.5rem; overflow-x: auto; }}
</style>
</head>
<body>
<h1>AcmeCloud Customer Support Ticket Handler</h1>
<p>One episode equals one support ticket. Available fixed tasks:</p>
<ul>{task_items}</ul>
<p>Example local reset:</p>
<pre>curl -X POST http://localhost:8000/reset -H "Content-Type: application/json" -d '{{"task_id":"password_reset_guidance"}}'</pre>
</body>
</html>
"""
    return HTMLResponse(page)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve one private environment over a WebSocket connection.

    A fresh ``SupportTicketEnvironment`` and a random session id are created
    per connection; the environment is not registered with the module-level
    ``manager``. Each incoming message must be a JSON object whose ``"type"``
    is one of:

    * ``"reset"`` - reset the environment with the optional ``"task_id"``.
    * ``"step"``  - apply ``"action"`` (defaults to ``{}``).
    * ``"state"`` - return the current state merged with the session id.
    * ``"close"`` - acknowledge, close the socket and stop.

    Anything else is answered with an ``{"error": ..., "message": ...,
    "session_id": ...}`` object and the connection stays open. A client
    disconnect ends the handler silently.
    """
    await websocket.accept()
    session_id = str(uuid4())
    env = SupportTicketEnvironment()

    async def send_error(code: str, message: str) -> None:
        # Uniform error envelope shared by every rejected message.
        await websocket.send_json(
            {
                "error": code,
                "message": message,
                "session_id": session_id,
            }
        )

    try:
        while True:
            try:
                payload = await websocket.receive_json()
            except ValueError:
                # json.JSONDecodeError is a ValueError: a malformed frame gets
                # an error reply instead of tearing down the whole connection.
                await send_error("invalid_json", "Messages must be valid JSON.")
                continue
            if not isinstance(payload, dict):
                # Valid JSON that is not an object (list, string, number...)
                # has no .get(); reject it rather than crash on AttributeError.
                await send_error("invalid_message", "Messages must be JSON objects with a 'type' field.")
                continue

            message_type = payload.get("type")
            if message_type == "reset":
                result = env.reset(payload.get("task_id"))
                await websocket.send_json(_step_payload(result, session_id))
            elif message_type == "step":
                result = env.step(payload.get("action", {}))
                await websocket.send_json(_step_payload(result, session_id))
            elif message_type == "state":
                await websocket.send_json({"session_id": session_id, **env.state()})
            elif message_type == "close":
                await websocket.send_json({"status": "closed", "session_id": session_id})
                # Close explicitly instead of relying on the server to tear the
                # connection down once the handler returns.
                await websocket.close()
                break
            else:
                await send_error("unsupported_message_type", "Use reset, step, state, or close.")
    except WebSocketDisconnect:
        return
def main() -> None:
    """Launch the app under uvicorn on 0.0.0.0:8000 with auto-reload disabled.

    uvicorn is imported lazily so that merely importing this module does not
    require it. The app is referenced by the import string ``server.app:app``.
    """
    from uvicorn import run as serve

    serve("server.app:app", host="0.0.0.0", port=8000, reload=False)
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()