File size: 7,950 Bytes
c5cfc73
a363048
c5cfc73
a363048
 
c5cfc73
 
 
 
 
 
503bc84
c5cfc73
af7c75f
c5cfc73
65dfc27
c5cfc73
 
a363048
 
503bc84
c5cfc73
 
 
 
 
a363048
c5cfc73
 
 
 
 
a363048
c5cfc73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a363048
c5cfc73
a363048
c5cfc73
 
 
a363048
c5cfc73
 
 
 
 
 
 
 
a363048
c5cfc73
a363048
 
 
 
 
 
 
c5cfc73
 
a363048
c5cfc73
 
 
 
 
a06a840
c5cfc73
 
 
 
 
 
 
 
 
 
a363048
c5cfc73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a363048
 
 
c5cfc73
 
 
 
 
a363048
c5cfc73
a363048
 
 
 
 
c5cfc73
a363048
 
 
 
 
6f7e1b7
a2ae67c
a363048
 
 
 
 
 
 
 
 
a2ae67c
c5cfc73
 
af7c75f
 
 
 
 
 
 
 
 
 
 
38cc60a
af7c75f
c5cfc73
 
 
a363048
c5cfc73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a363048
c5cfc73
 
 
a363048
c5cfc73
a363048
c5cfc73
 
 
 
 
 
 
a363048
c5cfc73
 
 
 
 
 
 
a363048
c5cfc73
 
 
 
 
 
 
 
 
 
503bc84
 
af7c75f
503bc84
c5cfc73
 
 
a363048
bc262f3
 
253ebc2
bc262f3
 
253ebc2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""
FastAPI application for the ESCTR Environment.

Exposes the Enterprise Supply Chain & Tax Reconciliation environment
over HTTP and WebSocket endpoints compatible with the OpenEnv spec.
"""

import json
import logging
from typing import Any, Dict, Optional

import gradio as gr
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from pydantic import BaseModel

from .models import ESCTRAction, ESCTRObservation, ESCTRState
from .environment import ESCTREnvironment
from .gradio_ui import build_gradio_app

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------

class ResetRequest(BaseModel):
    """Body of ``POST /reset``.

    Only the fields a client explicitly sends are forwarded to the
    environment's ``reset``; unset fields are dropped, so the environment
    applies its own defaults.
    """

    # Pydantic v2 configuration; the nested ``class Config`` form is deprecated
    # in v2. ``extra="allow"`` keeps unknown keys so they are forwarded as well.
    model_config = {"extra": "allow"}

    seed: Optional[int] = None
    episode_id: Optional[str] = None
    task_name: str = "procurement_reconciliation"


class StepRequest(BaseModel):
    """Body of ``POST /step``.

    ``action`` is the raw keyword mapping used to construct an ``ESCTRAction``.
    ``timeout_s`` is passed through to the environment's ``step`` call.
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    # Unknown keys are accepted rather than rejected.
    model_config = {"extra": "allow"}

    action: Dict[str, Any]
    timeout_s: Optional[float] = None


class HealthResponse(BaseModel):
    """Payload for ``GET /health``.

    The handler builds it with no arguments, so ``status`` is the default.
    """

    status: str = "healthy"  # default value reported by the liveness probe


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _obs_to_response(obs: ESCTRObservation) -> dict:
    """Wrap an observation in the ``{observation, reward, done}`` envelope.

    ``reward`` and ``done`` are lifted out of the dumped observation to the
    top level, defaulting to ``0.0`` and ``False`` when absent. All remaining
    fields stay under ``"observation"``.
    """
    observation = obs.model_dump()
    reward = observation.pop("reward", 0.0)
    done = observation.pop("done", False)
    return dict(observation=observation, reward=reward, done=done)


# ---------------------------------------------------------------------------
# Application factory
# ---------------------------------------------------------------------------

def create_app() -> FastAPI:
    """Build the ESCTR FastAPI application and mount the Gradio demo UI.

    The HTTP endpoints (``/reset``, ``/step``, ``/state``, ``/trace``) all act
    on one ``ESCTREnvironment`` created here, so every REST client of this app
    shares a single episode. Each ``/ws`` WebSocket connection gets its own
    environment instead, closed when the session ends. The Gradio app from
    ``build_gradio_app()`` is mounted at ``/demo``, and ``/`` redirects there.

    Returns:
        The application returned by ``gr.mount_gradio_app`` (this FastAPI app
        with the Gradio UI mounted).
    """
    app = FastAPI(
        title="ESCTR Environment",
        description=(
            "Enterprise Supply Chain & Tax Reconciliation — an OpenEnv environment "
            "for training LLMs to investigate discrepancies, enforce SLA penalties, "
            "and navigate adversarial vendor disputes."
        ),
        version="1.0.0",
    )

    # Shared by every REST handler below: not per-request, not per-client.
    _env = ESCTREnvironment()

    @app.get("/health")
    def health():
        """Liveness probe."""
        return HealthResponse()

    @app.post("/reset")
    def reset(request: Optional[ResetRequest] = None):
        """Start a new episode on the shared REST environment.

        Only fields the client actually sent are forwarded to ``reset``. A
        missing body is treated the same as ``{}``.
        """
        # ``None`` default instead of a ``ResetRequest()`` instance evaluated
        # once at definition time; both mean "no overrides".
        kwargs = {} if request is None else request.model_dump(exclude_unset=True)
        obs = _env.reset(**kwargs)
        return _obs_to_response(obs)

    @app.post("/step")
    def step(request: StepRequest):
        """Build an ``ESCTRAction`` from ``request.action`` and apply it.

        Responds with 422 and an ``Invalid action`` detail when the action
        mapping cannot be turned into an ``ESCTRAction``.
        """
        try:
            action = ESCTRAction(**request.action)
        except Exception as e:
            return JSONResponse(
                status_code=422,
                content={"detail": f"Invalid action: {str(e)}"},
            )
        obs = _env.step(action, timeout_s=request.timeout_s)
        return _obs_to_response(obs)

    @app.get("/state")
    def get_state():
        """Return the shared environment's state as a plain dict."""
        return _env.state.model_dump()

    @app.get("/schema")
    def get_schema():
        """Return JSON Schemas for the action, observation and state models."""
        return {
            "action": ESCTRAction.model_json_schema(),
            "observation": ESCTRObservation.model_json_schema(),
            "state": ESCTRState.model_json_schema(),
        }

    @app.get("/metadata")
    def get_metadata():
        """Return static descriptive metadata: themes, tasks and tool names."""
        return {
            "name": "esctr_environment",
            "description": (
                "Enterprise Supply Chain & Tax Reconciliation: an environment where "
                "an LLM agent operates as an autonomous financial controller, investigating "
                "procurement discrepancies, enforcing SLA penalties from shipping delays, "
                "and navigating adversarial vendor disputes. Features procedural generation "
                "for infinite scenarios, RLVR composite rewards, and multi-tool agentic workflow."
            ),
            "version": "1.0.0",
            "themes": [
                "World Modeling — Professional Tasks",
                "Long-Horizon Planning & Instruction Following",
                "Multi-Agent Interactions (adversarial vendor)",
            ],
            "tasks": [
                {"name": "procurement_reconciliation", "difficulty": "easy", "max_steps": 10,
                 "description": "Identify overcharged line items between PO and Invoice"},
                {"name": "sla_enforcement", "difficulty": "medium", "max_steps": 15,
                 "description": "Calculate late delivery penalties from shipping logs and SLA contracts"},
                {"name": "adversarial_auditing", "difficulty": "hard", "max_steps": 20,
                 "description": "Navigate vendor disputes, verify warehouse logs, reject settlement offers"},
            ],
            "tools": [
                "query_database", "read_document", "communicate_vendor", "submit_financial_decision",
            ],
        }

    @app.get("/trace")
    def get_trace():
        """Return the shared episode's id, task, step count and action trace."""
        return {
            "episode_id": _env.state.episode_id,
            "task_name": _env.state.task_name,
            "steps": _env.state.step_count,
            "action_trace": _env.action_trace,
        }

    @app.get("/", response_class=HTMLResponse)
    def root():
        """Redirect the bare root URL to the mounted Gradio demo."""
        # A returned Response object is sent as-is; ``response_class`` only
        # affects the OpenAPI documentation for this route.
        return RedirectResponse(url="/demo/", status_code=302)

    async def _send_ws_error(websocket: WebSocket, message: str, code: str) -> None:
        """Send an ``error`` frame to the client without ending the session."""
        await websocket.send_json({
            "type": "error",
            "data": {"message": message, "code": code},
        })

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Serve one WebSocket session backed by its own environment.

        Each text frame is a JSON object ``{"type": ..., "data": ...}``. The
        ``type`` is one of ``reset``, ``step``, ``state`` or ``close``.
        Malformed or failing requests get an ``error`` frame and the session
        continues. ``close`` or a client disconnect ends the session and
        closes the environment.
        """
        await websocket.accept()
        ws_env = ESCTREnvironment()
        logger.info("WebSocket session opened")

        try:
            while True:
                raw = await websocket.receive_text()
                try:
                    msg = json.loads(raw)
                except json.JSONDecodeError:
                    await _send_ws_error(websocket, "Invalid JSON", "INVALID_JSON")
                    continue

                # Valid JSON that is not an object (e.g. a list or a string)
                # would otherwise raise AttributeError on ``.get`` and tear
                # down the whole session.
                if not isinstance(msg, dict):
                    await _send_ws_error(
                        websocket, "Message must be a JSON object", "INVALID_MESSAGE"
                    )
                    continue

                msg_type = msg.get("type", "")
                msg_data = msg.get("data", {})

                if msg_type == "reset":
                    # Guarded like "step": bad kwargs or a non-mapping ``data``
                    # are reported to the client instead of ending the session.
                    try:
                        obs = ws_env.reset(**msg_data)
                        await websocket.send_json({"type": "observation", "data": _obs_to_response(obs)})
                    except Exception as e:
                        await _send_ws_error(websocket, str(e), "EXECUTION_ERROR")

                elif msg_type == "step":
                    try:
                        action = ESCTRAction(**msg_data)
                        obs = ws_env.step(action)
                        await websocket.send_json({"type": "observation", "data": _obs_to_response(obs)})
                    except Exception as e:
                        await _send_ws_error(websocket, str(e), "EXECUTION_ERROR")

                elif msg_type == "state":
                    await websocket.send_json({"type": "state", "data": ws_env.state.model_dump()})

                elif msg_type == "close":
                    break

                else:
                    await _send_ws_error(
                        websocket, f"Unknown message type: {msg_type}", "UNKNOWN_TYPE"
                    )

        except WebSocketDisconnect:
            logger.info("WebSocket session disconnected")
        except Exception as e:
            # Top-level boundary for the session: keep the traceback in the log.
            logger.exception("WebSocket error: %s", e)
        finally:
            ws_env.close()
            logger.info("WebSocket session closed")

    # ── Mount Gradio UI ──────────────────────────────────────────────────
    gradio_app = build_gradio_app()
    app = gr.mount_gradio_app(app, gradio_app, path="/demo")

    return app


# Module-level application object, built at import time: importing this module
# runs create_app(), which constructs an ESCTREnvironment and the Gradio UI.
# main() serves it through the "server.app:app" import string, which presumably
# points at this module (confirm the package path).
app = create_app()


def main():
    """Serve the application with uvicorn on host 0.0.0.0, port 7860."""
    import uvicorn

    # Passed as a "module:attribute" import string rather than the app object.
    target = "server.app:app"
    uvicorn.run(target, host="0.0.0.0", port=7860)


# Start the server when this module is executed directly rather than imported.
if __name__ == "__main__":
    main()