"""
FastAPI application for the ESCTR Environment.
Exposes the Enterprise Supply Chain & Tax Reconciliation environment
over HTTP and WebSocket endpoints compatible with the OpenEnv spec.
"""
import json
import logging
from typing import Any, Dict, Optional
import gradio as gr
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from .models import ESCTRAction, ESCTRObservation, ESCTRState
from .environment import ESCTREnvironment
from .gradio_ui import build_gradio_app
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------
class ResetRequest(BaseModel):
    """Request body accepted by ``POST /reset``.

    The ``/reset`` handler dumps this model with ``exclude_unset=True`` and
    forwards the result as keyword arguments to the environment's ``reset``,
    so only fields the client actually sent are passed on. Unknown fields are
    accepted (``extra="allow"``) and therefore forwarded as well.
    """

    # Pydantic v2 configuration. Replaces the nested ``class Config``, which
    # v2 still honours but reports as deprecated.
    model_config = {"extra": "allow"}

    seed: Optional[int] = None
    episode_id: Optional[str] = None
    task_name: str = "procurement_reconciliation"
class StepRequest(BaseModel):
    """Request body accepted by ``POST /step``.

    ``action`` is the raw action payload; the ``/step`` handler validates it
    by constructing an ``ESCTRAction`` from it. ``timeout_s`` is passed through
    to the environment's ``step`` call unchanged. Unknown fields are accepted
    (``extra="allow"``).
    """

    # Pydantic v2 configuration. Replaces the nested ``class Config``, which
    # v2 still honours but reports as deprecated.
    model_config = {"extra": "allow"}

    action: Dict[str, Any]
    timeout_s: Optional[float] = None
class HealthResponse(BaseModel):
    """Body returned by ``GET /health``; ``status`` defaults to ``"healthy"``."""

    status: str = "healthy"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _obs_to_response(obs: ESCTRObservation) -> dict:
    """Split a dumped observation into the ``observation/reward/done`` envelope.

    ``reward`` and ``done`` are lifted out of the dumped observation to the top
    level of the result (falling back to ``0.0`` and ``False`` when the dump
    has no such key); every other field stays under ``"observation"``.
    """
    dumped = obs.model_dump()
    lifted_keys = ("reward", "done")
    remaining = {
        key: value for key, value in dumped.items() if key not in lifted_keys
    }
    return {
        "observation": remaining,
        "reward": dumped.get("reward", 0.0),
        "done": dumped.get("done", False),
    }
# ---------------------------------------------------------------------------
# Application factory
# ---------------------------------------------------------------------------
def create_app() -> FastAPI:
    """Build the FastAPI application for the ESCTR environment.

    The application exposes:

    * HTTP routes ``/health``, ``/reset``, ``/step``, ``/state``, ``/schema``,
      ``/metadata`` and ``/trace``. All of them use one ``ESCTREnvironment``
      instance that is created here and shared by every HTTP request.
    * A ``/ws`` WebSocket route that creates its own ``ESCTREnvironment`` for
      each connection and closes it when the connection ends.
    * The Gradio UI mounted under ``/demo``; ``/`` redirects to it.

    Returns:
        The application with the Gradio UI mounted on it.
    """
    app = FastAPI(
        title="ESCTR Environment",
        # NOTE(review): the "β" in the text below looks like a mis-encoded
        # dash; left untouched here, confirm against the intended wording.
        description=(
            "Enterprise Supply Chain & Tax Reconciliation β an OpenEnv environment "
            "for training LLMs to investigate discrepancies, enforce SLA penalties, "
            "and navigate adversarial vendor disputes."
        ),
        version="1.0.0",
    )

    # Single environment shared by all HTTP routes below.
    # NOTE(review): the HTTP handlers are plain ``def`` functions, so concurrent
    # requests may touch this instance at the same time; confirm that
    # ESCTREnvironment tolerates that or that only one client is expected.
    _env = ESCTREnvironment()

    async def _send_ws_error(websocket: WebSocket, message: str, code: str) -> None:
        """Send one ``{"type": "error", ...}`` frame on *websocket*."""
        await websocket.send_json(
            {"type": "error", "data": {"message": message, "code": code}}
        )

    @app.get("/health")
    def health():
        """Return the default health payload."""
        return HealthResponse()

    @app.post("/reset")
    def reset(request: ResetRequest = ResetRequest()):
        """Reset the shared environment and return the first observation.

        The body is optional. Only fields the client actually set are
        forwarded to ``reset`` (``exclude_unset=True``).
        """
        kwargs = request.model_dump(exclude_unset=True)
        obs = _env.reset(**kwargs)
        return _obs_to_response(obs)

    @app.post("/step")
    def step(request: StepRequest):
        """Validate the action, apply it to the shared environment, return the result.

        Responds with HTTP 422 when the payload cannot be turned into an
        ``ESCTRAction``.
        """
        try:
            action = ESCTRAction(**request.action)
        except Exception as e:
            return JSONResponse(
                status_code=422,
                content={"detail": f"Invalid action: {str(e)}"},
            )
        obs = _env.step(action, timeout_s=request.timeout_s)
        return _obs_to_response(obs)

    @app.get("/state")
    def get_state():
        """Return the shared environment's state as a plain dict."""
        return _env.state.model_dump()

    @app.get("/schema")
    def get_schema():
        """Return the JSON schemas of the action, observation and state models."""
        return {
            "action": ESCTRAction.model_json_schema(),
            "observation": ESCTRObservation.model_json_schema(),
            "state": ESCTRState.model_json_schema(),
        }

    @app.get("/metadata")
    def get_metadata():
        """Return static, hard-coded metadata describing the environment."""
        return {
            "name": "esctr_environment",
            "description": (
                "Enterprise Supply Chain & Tax Reconciliation: an environment where "
                "an LLM agent operates as an autonomous financial controller, investigating "
                "procurement discrepancies, enforcing SLA penalties from shipping delays, "
                "and navigating adversarial vendor disputes. Features procedural generation "
                "for infinite scenarios, RLVR composite rewards, and multi-tool agentic workflow."
            ),
            "version": "1.0.0",
            # NOTE(review): the "β" in the first theme looks like a mis-encoded
            # dash; left untouched here, confirm against the intended wording.
            "themes": [
                "World Modeling β Professional Tasks",
                "Long-Horizon Planning & Instruction Following",
                "Multi-Agent Interactions (adversarial vendor)",
            ],
            "tasks": [
                {"name": "procurement_reconciliation", "difficulty": "easy", "max_steps": 10,
                 "description": "Identify overcharged line items between PO and Invoice"},
                {"name": "sla_enforcement", "difficulty": "medium", "max_steps": 15,
                 "description": "Calculate late delivery penalties from shipping logs and SLA contracts"},
                {"name": "adversarial_auditing", "difficulty": "hard", "max_steps": 20,
                 "description": "Navigate vendor disputes, verify warehouse logs, reject settlement offers"},
            ],
            "tools": [
                "query_database", "read_document", "communicate_vendor", "submit_financial_decision",
            ],
        }

    @app.get("/trace")
    def get_trace():
        """Return the shared environment's episode id, task, step count and action trace."""
        return {
            "episode_id": _env.state.episode_id,
            "task_name": _env.state.task_name,
            "steps": _env.state.step_count,
            "action_trace": _env.action_trace,
        }

    # The route always answers with a redirect, so it is declared as one
    # (rather than as an HTML response) to keep the generated API docs accurate.
    @app.get("/", response_class=RedirectResponse, status_code=302)
    def root():
        """Redirect the bare root URL to the mounted Gradio UI."""
        return RedirectResponse(url="/demo/", status_code=302)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Serve one WebSocket session backed by its own environment instance.

        Each text frame must be a JSON object ``{"type": ..., "data": {...}}``
        where ``type`` is ``reset``, ``step``, ``state`` or ``close``. Malformed
        frames and failing ``reset``/``step`` calls are answered with an
        ``error`` frame and the session keeps running.
        """
        await websocket.accept()
        ws_env = ESCTREnvironment()
        logger.info("WebSocket session opened")
        try:
            while True:
                raw = await websocket.receive_text()
                try:
                    msg = json.loads(raw)
                except json.JSONDecodeError:
                    await _send_ws_error(websocket, "Invalid JSON", "INVALID_JSON")
                    continue

                # Valid JSON that is not an object (e.g. a list or a number)
                # has no ``.get``; reject it instead of crashing the session.
                if not isinstance(msg, dict):
                    await _send_ws_error(
                        websocket, "Message must be a JSON object", "INVALID_MESSAGE"
                    )
                    continue

                msg_type = msg.get("type", "")
                msg_data = msg.get("data", {})

                if msg_type == "reset":
                    # Guarded like "step": bad reset arguments produce an
                    # error frame rather than tearing the connection down.
                    try:
                        payload = _obs_to_response(ws_env.reset(**msg_data))
                    except Exception as e:
                        await _send_ws_error(websocket, str(e), "EXECUTION_ERROR")
                        continue
                    await websocket.send_json({"type": "observation", "data": payload})
                elif msg_type == "step":
                    # Only the environment work sits in the ``try``; a failure
                    # while sending is a connection problem and is left to the
                    # outer handlers.
                    try:
                        action = ESCTRAction(**msg_data)
                        payload = _obs_to_response(ws_env.step(action))
                    except Exception as e:
                        await _send_ws_error(websocket, str(e), "EXECUTION_ERROR")
                        continue
                    await websocket.send_json({"type": "observation", "data": payload})
                elif msg_type == "state":
                    await websocket.send_json({"type": "state", "data": ws_env.state.model_dump()})
                elif msg_type == "close":
                    break
                else:
                    await _send_ws_error(
                        websocket, f"Unknown message type: {msg_type}", "UNKNOWN_TYPE"
                    )
        except WebSocketDisconnect:
            logger.info("WebSocket session disconnected")
        except Exception:
            # logger.exception keeps the traceback, which a plain error
            # message would drop.
            logger.exception("WebSocket error")
        finally:
            ws_env.close()
            logger.info("WebSocket session closed")

    # -- Mount Gradio UI ------------------------------------------------------
    gradio_app = build_gradio_app()
    app = gr.mount_gradio_app(app, gradio_app, path="/demo")
    return app
# Module-level application instance, built at import time (presumably the
# target of the "server.app:app" import string that main() gives to uvicorn).
app = create_app()
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve the application with uvicorn.

    Args:
        host: Address to bind to. Defaults to ``"0.0.0.0"``.
        port: TCP port to listen on. Defaults to ``7860``.

    Calling ``main()`` with no arguments behaves exactly as before the
    parameters existed.
    """
    # Imported inside the function, so uvicorn is only needed when main() runs.
    import uvicorn

    uvicorn.run("server.app:app", host=host, port=port)


if __name__ == "__main__":
    main()