5ivatej committed on
Commit
ce4a7da
·
1 Parent(s): e1ec6bc

Make HTTP sessions stateless for multi-node deployment

Browse files
Files changed (2) hide show
  1. server.py +70 -14
  2. src/env.py +72 -0
server.py CHANGED
@@ -14,7 +14,12 @@ instances.
14
  """
15
  from __future__ import annotations
16
 
17
- from uuid import uuid4
 
 
 
 
 
18
 
19
  from fastapi import FastAPI, HTTPException
20
  from fastapi import Request, Response
@@ -33,14 +38,52 @@ app = FastAPI(
33
  )
34
 
35
  SESSION_COOKIE = "esc_session_id"
36
- _envs: dict[str, ESCEnv] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  def _get_env_for_request(request: Request) -> ESCEnv:
40
- session_id = request.cookies.get(SESSION_COOKIE)
41
- if not session_id or session_id not in _envs:
42
  raise RuntimeError("env.step() called before reset()")
43
- return _envs[session_id]
44
 
45
 
46
  @app.get("/")
@@ -61,27 +104,40 @@ def list_tasks() -> dict:
61
  @app.post("/reset")
62
  def reset(request: Request, response: Response, req: ResetRequest | None = None) -> dict:
63
  req = req or ResetRequest()
64
- session_id = request.cookies.get(SESSION_COOKIE)
65
- if not session_id:
66
- session_id = uuid4().hex
67
- env = _envs.get(session_id)
68
- if env is None:
 
 
69
  env = ESCEnv()
70
- _envs[session_id] = env
71
  try:
72
  result = env.reset(task_id=req.task_id, seed=req.seed)
73
  except KeyError as e:
74
  raise HTTPException(status_code=400, detail=str(e))
75
- response.set_cookie(key=SESSION_COOKIE, value=session_id, httponly=True, samesite="lax")
 
 
 
 
 
76
  return result.model_dump()
77
 
78
 
79
  @app.post("/step")
80
- def step(req: StepRequest, request: Request) -> dict:
81
  try:
82
- result = _get_env_for_request(request).step(req.action)
 
83
  except RuntimeError as e:
84
  raise HTTPException(status_code=409, detail=str(e))
 
 
 
 
 
 
85
  return result.model_dump()
86
 
87
 
 
14
  """
15
  from __future__ import annotations
16
 
17
+ import base64
18
+ import hashlib
19
+ import hmac
20
+ import json
21
+ import os
22
+ import zlib
23
 
24
  from fastapi import FastAPI, HTTPException
25
  from fastapi import Request, Response
 
38
  )
39
 
40
  SESSION_COOKIE = "esc_session_id"
41
# Key used to sign session cookies; read once at import time from the
# ESC_SESSION_SECRET environment variable.
# NOTE(review): when the variable is unset this falls back to a fixed literal
# that is visible in the source, so anyone could forge a validly signed cookie.
# Presumably every node must be configured with the same value so that a token
# signed by one node verifies on another — confirm deployment config.
+ SESSION_SECRET = os.getenv("ESC_SESSION_SECRET", "esc-openenv-dev-secret").encode("utf-8")
42
+
43
+
44
+ def _urlsafe_b64encode(data: bytes) -> str:
45
+ return base64.urlsafe_b64encode(data).decode("ascii")
46
+
47
+
48
+ def _urlsafe_b64decode(data: str) -> bytes:
49
+ padding = "=" * (-len(data) % 4)
50
+ return base64.urlsafe_b64decode(data + padding)
51
+
52
+
53
def _sign(payload: str) -> str:
    """Return the hex HMAC-SHA256 of *payload* keyed with ``SESSION_SECRET``.

    Args:
        payload: Text to authenticate; it is UTF-8 encoded before hashing.

    Returns:
        The digest as a lowercase hexadecimal string.
    """
    mac = hmac.new(SESSION_SECRET, digestmod=hashlib.sha256)
    mac.update(payload.encode("utf-8"))
    return mac.hexdigest()
55
+
56
+
57
def _encode_env(env: ESCEnv) -> str:
    """Serialize *env* into a signed token of the form ``<payload>.<signature>``.

    The payload is the environment's exported state rendered as compact JSON,
    zlib-compressed at the maximum level and URL-safe base64 encoded; the
    signature is ``_sign(payload)``.

    NOTE(review): the exported state includes the transcript, so the token
    grows with the conversation. Browsers commonly cap a cookie at about
    4 KB — confirm long sessions stay under the limit.

    Raises:
        RuntimeError: propagated from ``env.export_state()`` when the
            environment has not been reset yet.
    """
    state = env.export_state()
    serialized = json.dumps(state, separators=(",", ":"), ensure_ascii=False)
    packed = zlib.compress(serialized.encode("utf-8"), level=9)
    payload = _urlsafe_b64encode(packed)
    signature = _sign(payload)
    return ".".join((payload, signature))
61
+
62
+
63
def _decode_env(token: str) -> ESCEnv:
    """Verify a signed session token and rebuild the environment it carries.

    The token has the form ``<payload>.<signature>``: the payload is URL-safe
    base64 of zlib-compressed JSON, and the signature must equal
    ``_sign(payload)``.

    Args:
        token: The raw cookie value supplied by the client.

    Returns:
        The ``ESCEnv`` reconstructed from the decoded state.

    Raises:
        RuntimeError: if the token has no ``.`` separator, the signature does
            not match, the payload cannot be decoded, or the decoded state
            cannot be turned back into an ``ESCEnv``. Callers rely on
            ``RuntimeError`` being the only failure type.
    """
    try:
        payload, signature = token.rsplit(".", 1)
    except ValueError as exc:
        raise RuntimeError("Invalid session token") from exc

    expected = _sign(payload)
    # Compare as bytes: hmac.compare_digest() raises TypeError when given a
    # str containing non-ASCII characters, and the signature is client input.
    if not hmac.compare_digest(signature.encode("utf-8"), expected.encode("utf-8")):
        raise RuntimeError("Invalid session signature")

    try:
        compressed = _urlsafe_b64decode(payload)
        data = json.loads(zlib.decompress(compressed).decode("utf-8"))
    except Exception as exc:
        # Deliberately broad: any failure while unpacking maps to one error.
        raise RuntimeError("Invalid session payload") from exc

    try:
        return ESCEnv.from_state(data)
    except (KeyError, TypeError, ValueError, AttributeError) as exc:
        # A correctly signed token can still hold state this build cannot
        # rebuild (missing keys, wrong types, unknown enum value or task).
        raise RuntimeError("Invalid session state") from exc
80
 
81
 
82
  def _get_env_for_request(request: Request) -> ESCEnv:
83
+ token = request.cookies.get(SESSION_COOKIE)
84
+ if not token:
85
  raise RuntimeError("env.step() called before reset()")
86
+ return _decode_env(token)
87
 
88
 
89
  @app.get("/")
 
104
  @app.post("/reset")
105
  def reset(request: Request, response: Response, req: ResetRequest | None = None) -> dict:
106
  req = req or ResetRequest()
107
+ token = request.cookies.get(SESSION_COOKIE)
108
+ if token:
109
+ try:
110
+ env = _decode_env(token)
111
+ except RuntimeError:
112
+ env = ESCEnv()
113
+ else:
114
  env = ESCEnv()
 
115
  try:
116
  result = env.reset(task_id=req.task_id, seed=req.seed)
117
  except KeyError as e:
118
  raise HTTPException(status_code=400, detail=str(e))
119
+ response.set_cookie(
120
+ key=SESSION_COOKIE,
121
+ value=_encode_env(env),
122
+ httponly=True,
123
+ samesite="lax",
124
+ )
125
  return result.model_dump()
126
 
127
 
128
  @app.post("/step")
129
+ def step(req: StepRequest, request: Request, response: Response) -> dict:
130
  try:
131
+ env = _get_env_for_request(request)
132
+ result = env.step(req.action)
133
  except RuntimeError as e:
134
  raise HTTPException(status_code=409, detail=str(e))
135
+ response.set_cookie(
136
+ key=SESSION_COOKIE,
137
+ value=_encode_env(env),
138
+ httponly=True,
139
+ samesite="lax",
140
+ )
141
  return result.model_dump()
142
 
143
 
src/env.py CHANGED
@@ -215,3 +215,75 @@ class ESCEnv:
215
  }
216
  for t in TASKS.values()
217
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  }
216
  for t in TASKS.values()
217
  ]
218
+
219
+ # ------------------------------------------------------------- serialization
220
+
221
def export_state(self) -> Dict[str, Any]:
    """Return a plain-dict snapshot of the environment for serialization.

    Stage enum members are stored by their ``.value``; the transcript and
    agent-message lists are shallow-copied.

    Raises:
        RuntimeError: if no task or seeker is set yet (before ``reset()``).
    """
    if not (self._task is not None and self._seeker is not None):
        raise RuntimeError("env.export_state() called before reset()")

    seeker = self._seeker

    # Enum keys are not JSON-friendly, so key the mapping by stage value.
    line_idx_by_stage_value = {}
    for stage, idx in seeker.last_line_idx_by_stage.items():
        line_idx_by_stage_value[stage.value] = idx

    seeker_snapshot = dict(
        distress=seeker.distress,
        trust=seeker.trust,
        openness=seeker.openness,
        revealed=seeker.revealed,
        stage=seeker.stage.value,
        last_line_idx_by_stage=line_idx_by_stage_value,
        turn=seeker.turn,
    )

    snapshot = dict(
        task_id=self._task.id,
        turn=self._turn,
        done=self._done,
        cumulative_reward=self._cumulative_reward,
        transcript=list(self._transcript),
        agent_messages=list(self._agent_messages),
        had_safety_reference=self._had_safety_reference,
    )
    snapshot["seeker"] = seeker_snapshot
    return snapshot
247
+
248
@classmethod
def from_state(cls, data: Dict[str, Any]) -> "ESCEnv":
    """Rebuild an environment from a dict shaped like ``export_state()`` output.

    The task is looked up by ``data["task_id"]`` and the seeker's persona is
    taken from that task rather than from *data*. ``transcript``,
    ``agent_messages`` and ``had_safety_reference`` are optional and default
    to empty / ``False``; the other keys are read with direct indexing.

    When the transcript is non-empty, the last observation is reconstructed
    from the most recent entry whose role is ``"seeker"``, falling back to the
    persona's surface concern if there is none.
    """
    task = get_task(str(data["task_id"]))
    seeker_data = data["seeker"]

    env = cls()
    env._task = task
    env._turn = int(data["turn"])
    env._done = bool(data["done"])
    env._cumulative_reward = float(data["cumulative_reward"])
    env._transcript = list(data.get("transcript", []))
    env._agent_messages = list(data.get("agent_messages", []))
    env._had_safety_reference = bool(data.get("had_safety_reference", False))

    # Values are coerced in the same order the constructor receives them.
    distress = float(seeker_data["distress"])
    trust = float(seeker_data["trust"])
    openness = float(seeker_data["openness"])
    revealed = bool(seeker_data["revealed"])
    stage = Stage(str(seeker_data["stage"]))
    line_idx_by_stage = {}
    for stage_name, idx in seeker_data["last_line_idx_by_stage"].items():
        line_idx_by_stage[Stage(stage_name)] = int(idx)
    env._seeker = SeekerState(
        persona=task.persona,
        distress=distress,
        trust=trust,
        openness=openness,
        revealed=revealed,
        stage=stage,
        last_line_idx_by_stage=line_idx_by_stage,
        turn=int(seeker_data["turn"]),
    )

    if not env._transcript:
        return env

    # Default to the persona's surface concern, then prefer the newest
    # seeker line in the transcript if one exists.
    last_seeker_text = task.persona.surface_concern
    for entry in reversed(env._transcript):
        if entry.get("role") == "seeker":
            last_seeker_text = entry["text"]
            break

    env._last_obs = Observation(
        seeker_utterance=last_seeker_text,
        turn=env._turn,
        remaining_turns=max(0, task.max_turns - env._turn),
        stage_hint=env._seeker.stage.value,
        task_id=task.id,
        scenario_brief=task.persona.scenario_brief,
    )
    return env