File size: 4,975 Bytes
5028b60 a2ac209 9cba1b2 5028b60 9cba1b2 5028b60 9cba1b2 5028b60 a2ac209 5028b60 f49f4d9 5028b60 7b988a7 5028b60 4471dff 9cba1b2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | """
server.py β E.L.A.R.A. FastAPI server
Endpoints: /health /reset /step /state /tasks /grader /ws
"""
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent / "app"))
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from environment import ElaraEnv
from models import Action, EpisodeState
from grader import grade
app = FastAPI(
title="E.L.A.R.A.",
description="Extensive LLM Applied Reasoning Agent β sandboxed sales ops environment",
version="0.1.0",
)
env = ElaraEnv()
# βββββββββββββββββββββββββββββββββββββββββββββ
# Request/response wrappers
# βββββββββββββββββββββββββββββββββββββββββββββ
class ResetRequest(BaseModel):
task_id: str = "easy"
seed: int | None = None
class StepRequest(BaseModel):
action: Action
# βββββββββββββββββββββββββββββββββββββββββββββ
# Endpoints
# βββββββββββββββββββββββββββββββββββββββββββββ
@app.get("/health")
def health():
return {"status": "ok", "project": "E.L.A.R.A.", "version": "0.1.0"}
@app.post("/reset")
def reset(req: ResetRequest = ResetRequest()):
try:
obs = env.reset(task_id=req.task_id, seed=req.seed)
return {"observation": obs.model_dump()}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.post("/step")
def step(req: StepRequest):
try:
obs, reward, done, info = env.step(req.action)
return {
"observation": obs.model_dump(),
"reward": reward,
"done": done,
"info": info,
}
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/state")
def state():
try:
return env.state()
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/tasks")
def tasks():
import json
from pathlib import Path
task_dir = Path(__file__).parent / "tasks"
result = []
for tid in ["easy", "medium", "hard", "consent"]:
p = task_dir / f"{tid}.json"
if p.exists():
result.append(json.loads(p.read_text()))
return {"tasks": result}
@app.post("/grader")
def grader():
try:
raw = env.state()
s = EpisodeState(**raw)
result = grade(s)
return result
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
# βββββββββββββββββββββββββββββββββββββββββββββ
# WebSocket endpoint (OpenEnv SDK)
# βββββββββββββββββββββββββββββββββββββββββββββ
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
await ws.accept()
try:
while True:
raw = await ws.receive_text()
msg = json.loads(raw)
msg_type = msg.get("type", "")
data = msg.get("data", {})
if msg_type == "reset":
task_id = data.get("task_id", "easy")
seed = data.get("seed")
obs = env.reset(task_id=task_id, seed=seed)
await ws.send_text(json.dumps({
"type": "reset",
"data": {
"observation": obs.model_dump(),
"reward": 0.0,
"done": False,
},
}))
elif msg_type == "step":
action_data = data.get("action", data)
action = Action(**action_data)
obs, reward, done, info = env.step(action)
await ws.send_text(json.dumps({
"type": "step",
"data": {
"observation": obs.model_dump(),
"reward": reward,
"done": done,
"info": info,
},
}))
elif msg_type == "state":
state_data = env.state()
await ws.send_text(json.dumps({
"type": "state",
"data": state_data,
}))
elif msg_type == "close":
await ws.close()
break
else:
await ws.send_text(json.dumps({
"type": "error",
"data": {"message": f"Unknown message type: {msg_type}"},
}))
except WebSocketDisconnect:
pass
|