File size: 10,503 Bytes
543a85f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63c7e0c
543a85f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fa79b5
b4cadb7
4fa79b5
543a85f
 
4fa79b5
543a85f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1f8942
b4cadb7
 
543a85f
 
c1f8942
543a85f
 
 
 
 
 
 
 
 
 
 
63c7e0c
543a85f
 
 
 
63c7e0c
 
 
543a85f
63c7e0c
 
543a85f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63c7e0c
543a85f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1f8942
b4cadb7
 
543a85f
 
 
 
c1f8942
543a85f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
"""
FastAPI application for the PrivilegeDesk Environment.

Endpoints (OpenEnv standard):
    POST /reset     - start a new episode
    POST /step      β€” execute a tool call
    GET  /state     β€” current episode state
    GET  /schema    β€” action/observation JSON schemas
    WS   /ws        β€” WebSocket for persistent sessions

Additional endpoints (hackathon required):
    GET  /tasks     β€” list the 3 tasks with action schemas
    POST /grader    β€” return episode score breakdown (0.0–1.0)
    POST /baseline  β€” run baseline agent and return scores

NOTE on HTTP statefulness
--------------------------
OpenEnv's create_app() HTTP /reset and /step handlers are intentionally
stateless β€” each request creates a new throwaway env instance. State is
only preserved over WebSocket.

We remove those stateless routes and replace them with our own singleton-
backed versions so that curl-style HTTP testing works correctly. WebSocket
(/ws) is untouched and still creates proper per-session environments.
"""
from typing import Any, Dict, Optional, List

from fastapi import HTTPException
from pydantic import BaseModel
from starlette.routing import Route

try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:
    raise ImportError("openenv is required. Install with: uv sync") from e

try:
    from ..models import PrivilegeDeskAction, PrivilegeDeskObservation
    from .privilege_desk_environment import PrivilegeDeskEnvironment
except (ImportError, ModuleNotFoundError):
    from models import PrivilegeDeskAction, PrivilegeDeskObservation
    from server.privilege_desk_environment import PrivilegeDeskEnvironment

from env.world_state import WorldState
from reward.grader import grade
from pipeline.task_templates import TASK_TEMPLATES


# ── Module-level singleton β€” shared across /reset, /step, /grader ─────────────
# Mimics what OpenEnv's WebSocket session manager does per-connection, but
# over plain HTTP for easy curl / script testing.

_world: WorldState = WorldState()


# ── Build base app from OpenEnv (registers /ws, /schema, /health, /state) ────

app = create_app(
    PrivilegeDeskEnvironment,
    PrivilegeDeskAction,
    PrivilegeDeskObservation,
    env_name="privilege_desk",
    max_concurrent_envs=1,
)

# Remove the stateless /reset and /step routes that create_app registered.
# FastAPI uses first-match routing, so if we don't remove them our overrides
# (registered below) would never be reached.
_OVERRIDE_PATHS = {"/reset", "/step"}
app.routes[:] = [
    r for r in app.routes
    if not (
        isinstance(r, Route)
        and r.path in _OVERRIDE_PATHS
        and "POST" in (r.methods or set())
    )
]


# ── Pydantic request models ───────────────────────────────────────────────────

class ResetRequest(BaseModel):
    task_id: str = "access_decision"
    seed: Optional[int] = None
    difficulty_level: int = 1


class StepRequest(BaseModel):
    action: Dict[str, Any]  # {"tool_name": "...", "arguments": {...}}


# ── /reset ────────────────────────────────────────────────────────────────────

@app.post("/reset")
def reset_episode(body: ResetRequest = None) -> Dict[str, Any]:
    """Reset to a new episode and return the initial observation.

    The world state is stored in a module-level singleton so that subsequent
    /step and /grader calls operate on the same episode.
    """
    global _world
    req = body or ResetRequest()

    _world = WorldState()
    obs = _world.reset(
        seed=req.seed,
        task_id=req.task_id,
        difficulty_level=req.difficulty_level,
    )

    return {
        "observation": obs,
        "reward": 0.0,
        "done": False,
        "info": {
            "task_id": req.task_id,
            "seed": req.seed,
            "episode": "started",
        },
    }


# ── /step ─────────────────────────────────────────────────────────────────────

@app.post("/step")
def step_episode(body: StepRequest) -> Dict[str, Any]:
    """Execute one tool call on the current episode.

    body.action must be a dict with:
        tool_name (str)   β€” e.g. "request.view", "access.decide"
        arguments (dict)  β€” tool-specific kwargs (can be empty {})
    """
    global _world

    if _world._router is None:
        raise HTTPException(
            status_code=409,
            detail="No active episode. Call POST /reset first.",
        )

    obs, reward, terminated, truncated, info = _world.step(body.action)
    done = terminated or truncated

    # Clamp step reward strictly to (0.01, 0.99) β€” Phase 2 requirement
    clamped_reward = min(max(round(reward, 4), 0.10), 0.90)

    return {
        "observation": obs,
        "reward": clamped_reward,
        "done": done,
        "info": info,
    }


# ── /grader ───────────────────────────────────────────────────────────────────

@app.post("/grader")
def grade_episode(body: Dict[str, Any] = None) -> Dict[str, Any]:
    """Return the grading breakdown for the current live episode.

    Reads the world state already mutated by /reset + /step calls and
    scores it β€” no re-execution, no fake episode.
    """
    global _world

    if _world._router is None or not _world._raw:
        raise HTTPException(
            status_code=409,
            detail="No active episode. Call POST /reset first, then POST /step.",
        )

    score = grade(_world._raw)
    # Ensure score is strictly in (0, 1) β€” never exactly 0.0 or 1.0
    raw_score = score.get("score", 0.10)
    final_score = max(0.10, min(0.90, raw_score))
    return {
        "task_id": _world._raw.get("task_id", "unknown"),
        "score": round(final_score, 4),
        "breakdown": score.get("breakdown", {}),
        "weights": score.get("weights", {}),
        "details": score.get("details", {}),
        "steps_taken": _world.step_count,
        "episode_done": _world.done,
    }


# ── /tasks ────────────────────────────────────────────────────────────────────

@app.get("/tasks")
def list_tasks() -> List[Dict[str, Any]]:
    """List all available tasks with their schemas."""
    tasks = []
    for task_id, template in TASK_TEMPLATES.items():
        tasks.append({
            "id": task_id,
            "name": task_id.replace("_", " ").title(),
            "description": template["task_goal"],
            "difficulty": template["difficulty"],
            "grader": f"graders.{task_id}_grader",
            "time_limit_seconds": 600,
            "max_steps": template["max_steps"],
            "available_tools": template["available_tools"],
            "grading_weights": template["grading_weights"],
            "action_schema": {
                "type": "object",
                "properties": {
                    "tool_name": {
                        "type": "string",
                        "enum": template["available_tools"],
                        "description": "Tool to call",
                    },
                    "arguments": {
                        "type": "object",
                        "description": "Tool-specific arguments",
                    },
                },
                "required": ["tool_name"],
            },
        })
    return tasks


# ── /baseline ─────────────────────────────────────────────────────────────────

@app.post("/baseline")
def run_baseline() -> Dict[str, Any]:
    """Run a naive baseline agent on all 3 tasks and return scores."""
    results = []

    for task_id in TASK_TEMPLATES:
        ws = WorldState()
        ws.reset(seed=42, task_id=task_id)
        total_reward = 0.0
        steps = 0

        if task_id == "access_decision":
            _, r, _, _, _ = ws.step({
                "tool_name": "access.decide",
                "arguments": {"decision": "approve", "role": "viewer",
                              "ttl_hours": 4, "justification_category": "operational"},
            })
            total_reward += r
            steps = 1
        elif task_id == "access_review":
            _, r, _, _, _ = ws.step({
                "tool_name": "review.submit",
                "arguments": {"summary": "baseline review"},
            })
            total_reward += r
            steps = 1
        elif task_id == "jit_escalation":
            req_id = next(iter(ws._raw.get("pending_requests", {})), None)
            if req_id:
                _, r, _, _, _ = ws.step({
                    "tool_name": "access.grant",
                    "arguments": {"request_id": req_id},
                })
                total_reward += r
                steps = 1

        score_info = grade(ws._raw)
        # Ensure score is strictly in (0, 1) β€” never exactly 0.0 or 1.0
        raw_score = score_info.get("score", 0.10)
        episode_score = max(0.10, min(0.90, raw_score))
        results.append({
            "task_id": task_id,
            "steps": steps,
            "total_step_reward": round(total_reward, 4),
            "episode_score": round(episode_score, 4),
        })

    avg_score = sum(r["episode_score"] for r in results) / len(results) if results else 0.0

    return {
        "baseline_agent": "naive_terminal_tool",
        "results": results,
        "average_episode_score": round(avg_score, 4),
        "note": "Naive baseline β€” hits the terminal tool immediately with default args.",
    }


# ── Server entry point ────────────────────────────────────────────────────────

def main(host: str = "0.0.0.0", port: int = 8000):
    import uvicorn
    uvicorn.run(app, host=host, port=port)


if __name__ == '__main__':
    main()