import os
import sys
# Ensure root directory is in path so we can import 'env'
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from env.environment import CustomerSupportEnv
from env.models import Action
from env.tasks import TASKS
from fastapi import Request
# ASGI application object; the route decorators below register handlers on it.
app = FastAPI(title="Customer Support OpenEnv", version="1.0.0")
# One environment per session: maps a session id to the CustomerSupportEnv
# that get_env() creates for it on first use. Plain in-process dict.
sessions = {}
def get_env(session_id="default"):
    """Return the environment bound to *session_id*, creating it on first use.

    Args:
        session_id: Key into the module-level ``sessions`` mapping.

    Returns:
        The ``CustomerSupportEnv`` stored for that session.
    """
    try:
        return sessions[session_id]
    except KeyError:
        # First request for this session: build and remember a fresh env.
        fresh_env = CustomerSupportEnv()
        sessions[session_id] = fresh_env
        return fresh_env
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the static HTML landing page that lists the available endpoints."""
    landing_page = """
<html><body style="font-family:sans-serif;background:#0f1117;color:#e0e0e0;max-width:700px;margin:50px auto;padding:0 24px">
<h1 style="color:#7ee787">Customer Support OpenEnv</h1>
<p>An OpenEnv RL environment for customer support automation.</p>
<h2 style="color:#58a6ff">Endpoints</h2>
<ul>
<li><a href="/docs" style="color:#58a6ff">/docs</a> — Swagger UI</li>
<li><code>GET /reset?task_id=easy|medium|hard</code></li>
<li><code>POST /step</code> — send an Action</li>
<li><code>GET /state</code></li>
<li><a href="/tasks" style="color:#58a6ff">GET /tasks</a></li>
</ul>
</body></html>
"""
    return landing_page
@app.get("/health")
def health():
    """Health-check endpoint: always answers with a fixed ``ok`` status."""
    return dict(status="ok")
@app.api_route("/reset", methods=["GET", "POST"])
async def reset(request: Request, task_id: str = None, session_id: str = "default"):
    """Reset the session's environment and return the first observation.

    Accepts both GET and POST. For POST, an optional JSON object body may
    carry ``task_id`` and/or ``session_id``; values present in the body take
    precedence over the query parameters.

    Args:
        request: Incoming request, used to read the method and JSON body.
        task_id: Task identifier forwarded to ``env.reset``; may be omitted.
        session_id: Which per-session environment to reset.

    Returns:
        A dict with the serialized ``observation`` and a ``task`` summary
        (``id``, ``description``, ``max_steps``) read from ``env.current_task``.

    Raises:
        HTTPException: 400 when ``env.reset`` raises ``ValueError``.
    """
    if request.method == "POST":
        try:
            body = await request.json()
        except ValueError:
            # An empty or malformed body is tolerated on purpose (best-effort
            # parsing): fall back to the query parameters. JSONDecodeError and
            # UnicodeDecodeError are both ValueError subclasses.
            body = None
        # Only a JSON object can carry overrides; arrays, scalars and null
        # are ignored instead of being masked by a blanket ``except``.
        if isinstance(body, dict):
            task_id = body.get("task_id", task_id)
            session_id = body.get("session_id", session_id)
    env = get_env(session_id)
    try:
        obs = env.reset(task_id=task_id)
    except ValueError as e:
        raise HTTPException(400, str(e)) from e
    return {
        "observation": obs.model_dump(),
        "task": {
            "id": env.current_task["id"],
            "description": env.current_task["description"],
            "max_steps": env.current_task["max_steps"],
        },
    }
@app.post("/step")
def step(action: Action, session_id: str = "default"):
    """Apply *action* to the session's environment and report the outcome.

    Args:
        action: The ``Action`` parsed from the request body.
        session_id: Which per-session environment to advance.

    Returns:
        A dict with the serialized ``observation`` and ``reward``, plus the
        ``done`` flag and ``info`` value exactly as returned by ``env.step``.

    Raises:
        HTTPException: 400 if no task is active yet, or if ``env.step``
            raises ``RuntimeError``.
    """
    environment = get_env(session_id)
    if not environment.current_task:
        raise HTTPException(400, "Call /reset first.")
    try:
        outcome = environment.step(action)
    except RuntimeError as err:
        raise HTTPException(400, str(err))
    observation, reward, finished, extra = outcome
    return {
        "observation": observation.model_dump(),
        "reward": reward.model_dump(),
        "done": finished,
        "info": extra,
    }
@app.get("/state")
def state(session_id: str = "default"):
    """Return ``env.state()`` for the session's environment.

    Raises:
        HTTPException: 400 when the environment has no current task yet.
    """
    environment = get_env(session_id)
    if environment.current_task:
        return environment.state()
    raise HTTPException(400, "Call /reset first.")
@app.get("/tasks")
def list_tasks():
    """List a summary of every task registered in ``TASKS``.

    Returns:
        A list of dicts, one per task, with its ``id``, ``description``,
        ``max_steps`` and the ``requires_escalation`` value taken from the
        task's ``expected`` entry.
    """
    summaries = []
    for task in TASKS.values():
        summary = {
            "id": task["id"],
            "description": task["description"],
            "max_steps": task["max_steps"],
            "requires_escalation": task["expected"]["requires_escalation"],
        }
        summaries.append(summary)
    return summaries
def main():
    """Run the app under uvicorn on all interfaces.

    The port comes from the ``PORT`` environment variable, defaulting to 7860.
    """
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run("server.app:app", host="0.0.0.0", port=listen_port)
# Script entry point: start the server only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()