File size: 26,323 Bytes
b43d8da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
"""DriftCall env Space — FastAPI + OpenEnv-compliant REST surface.

Implements ``docs/modules/deploy_env_space.md`` and DESIGN.md §3.3 / §11.1.

Endpoints:
    GET  /healthz   → 200 text/plain "ok" (unauthenticated)
    POST /reset     → 200 application/json (create / recycle session)
    POST /step      → 200 application/json (advance one turn)
    GET  /state     → 200 application/json (read DriftCallState)
    POST /close     → 200 application/json (evict session)

Headers (mutating endpoints): ``Authorization: Bearer <DRIFTCALL_ENV_TOKEN>``
and ``X-Session-Id: <[A-Za-z0-9_-]{1,64}>``.

Error modes (deploy_env_space.md §5):
    M1 401 unauthorized          M7  400 bad_json
    M2 400 missing_session_id    M8  400 invalid_action
    M3 404 session_not_found     M9  500 internal_error
    M4 404 session_expired       M10 500 io_error
    M5 429 max_sessions          M11 413 payload_too_large
    M6 503 model_not_ready       M12 409 reset_in_progress

All error bodies: ``{"error": {"code": <slug>, "message": <str>,
"request_id": <per-request id>}}``; ``Cache-Control: no-store``; only M5 carries
``Retry-After: 30``. No stack traces ever leak across the wire.
"""

from __future__ import annotations

import asyncio
import contextlib
import dataclasses
import enum
import hmac
import json
import logging
import os
import re
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any

from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware

from cells.step_04_models import ActionType, DriftCallAction
from cells.step_10_env import (
    DriftCallEnv,
    EnvClosedError,
    EnvNotReadyError,
    EpisodeAlreadyTerminalError,
    InvalidActionError,
    InvalidConfigError,
    UnknownDomainError,
    UnknownToolError,
)

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Awaitable, Callable

    from starlette.types import ASGIApp

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

_MAX_SESSIONS: int = 10
_TTL_S: float = 3600.0
_SWEEP_INTERVAL_S: float = 60.0
_MAX_SESSION_ID_LEN: int = 64
_SESSION_ID_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
_MAX_BODY_BYTES: int = 1 * 1024 * 1024  # 1 MiB
_RETRY_AFTER_S: str = "30"
_TOKEN_ENV_VAR: str = "DRIFTCALL_ENV_TOKEN"


# ---------------------------------------------------------------------------
# Time source (test-overridable)
# ---------------------------------------------------------------------------


def _monotonic() -> float:
    """Indirection for tests to monkeypatch."""

    return time.monotonic()


# ---------------------------------------------------------------------------
# Errors / envelope
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class _ApiError(Exception):
    """Internal exception → uniform error envelope (deploy_env_space.md §5).

    Raised by handlers and helpers in this module. The route wrappers catch it
    and render it with ``_error_response``. As a frozen dataclass, its fields
    are fixed at construction.

    NOTE(review): every raise site in this module passes the fields as
    keywords, so ``Exception.args`` is presumably empty and ``str(err)`` is not
    informative. Read ``err.message`` instead.
    """

    code: str  # machine-readable slug written to error.code, e.g. "unauthorized"
    message: str  # human-readable text written to error.message
    http_status: int  # status code of the rendered error response
    retry_after: bool = False  # when True, _error_response adds a Retry-After header


_NO_STORE: dict[str, str] = {"Cache-Control": "no-store"}


def _error_response(err: _ApiError, request_id: str) -> JSONResponse:
    """Render ``err`` as the uniform JSON error envelope.

    Body shape: ``{"error": {"code", "message", "request_id"}}``. Every error
    response carries ``Cache-Control: no-store``. A ``Retry-After`` header is
    added only when ``err.retry_after`` is set.
    """
    envelope_headers = {**_NO_STORE}
    if err.retry_after:
        envelope_headers["Retry-After"] = _RETRY_AFTER_S
    payload = {
        "error": {
            "code": err.code,
            "message": err.message,
            "request_id": request_id,
        }
    }
    return JSONResponse(
        status_code=err.http_status,
        content=payload,
        headers=envelope_headers,
    )


# ---------------------------------------------------------------------------
# Session cache
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class SessionEntry:
    """One cached session: its env plus bookkeeping for TTL/LRU and resets.

    Frozen per project rule — every touch produces a new entry (SessionCache
    builds updated copies with ``dataclasses.replace``).
    """

    env: DriftCallEnv  # live environment instance serving this session
    created_at: float  # _monotonic() reading when this entry's env was created
    last_touched: float  # _monotonic() reading of the last access; drives TTL expiry and LRU choice
    reset_count: int  # number of in-place resets performed under this session id
    lock: asyncio.Lock  # per-session lock held by /reset; carried over across in-place resets


class SessionCache:
    """In-memory session registry with LRU + TTL eviction.

    Alongside the ``sid -> SessionEntry`` store, the cache keeps a registry
    of per-session ``asyncio.Lock`` objects. ``acquire_lock`` registers the
    lock it hands out, and ``insert_or_replace`` attaches that *same* lock to
    the entry it creates. This is what lets ``/reset`` detect a concurrent
    reset (409) even for a session id that has no entry yet: both requests
    receive the one registered lock instead of two unrelated throwaway locks.
    """

    def __init__(self, *, max_sessions: int = _MAX_SESSIONS, ttl_s: float = _TTL_S) -> None:
        self._max = max_sessions
        self._ttl = ttl_s
        self._store: dict[str, SessionEntry] = {}
        # sid -> per-session lock. Registered on first acquire_lock(); may
        # briefly outlive its entry while held, and sweep() prunes idle orphans.
        self._locks: dict[str, asyncio.Lock] = {}
        self._guard = asyncio.Lock()

    @property
    def size(self) -> int:
        """Number of stored sessions (including expired-but-unswept ones)."""
        return len(self._store)

    def get(self, sid: str) -> SessionEntry | None:
        """Return the entry for ``sid`` without touching or expiring it."""
        return self._store.get(sid)

    async def acquire_lock(self, sid: str) -> asyncio.Lock:
        """Return (or lazily create and register) the per-session lock.

        The returned lock is the one ``insert_or_replace`` attaches to the
        session's entry, so callers serialize on it even before the session
        exists. The lock is *not* acquired here.
        """
        async with self._guard:
            lock = self._locks.get(sid)
            if lock is None:
                entry = self._store.get(sid)
                lock = entry.lock if entry is not None else asyncio.Lock()
                self._locks[sid] = lock
            return lock

    async def insert_or_replace(self, sid: str, env_factory: Callable[[], DriftCallEnv]) -> SessionEntry:
        """Insert a new env or replace an existing one (in-place reset).

        The replacement env is built *before* anything is torn down. A
        failing ``env_factory`` therefore leaves the existing session (or
        the LRU victim) intact.

        Raises:
            _ApiError: ``max_sessions`` (429) when the cache is full and no
                session can be evicted.
            Exception: anything ``env_factory`` raises is propagated.
        """
        async with self._guard:
            now = _monotonic()
            existing = self._store.get(sid)
            if existing is not None:
                # In-place reset (§7.1 case after winner completed). Build the
                # new env first so a failing factory does not leave an
                # already-closed env stored under this sid.
                env = env_factory()
                self._close_quietly(existing, "in-place reset", sid)
                entry = SessionEntry(
                    env=env,
                    created_at=now,
                    last_touched=now,
                    reset_count=existing.reset_count + 1,
                    lock=self._locks.setdefault(sid, existing.lock),
                )
                self._store[sid] = entry
                return entry

            # New session — enforce the cap. Pick the victim (or fail with 429)
            # before building anything. Evict it only once the new env exists,
            # so a failing factory costs nobody their session.
            victim_sid = self._pick_lru_victim(now) if len(self._store) >= self._max else None
            env = env_factory()
            if victim_sid is not None:
                victim = self._store.pop(victim_sid)
                self._forget_lock(victim_sid)
                self._close_quietly(victim, "LRU eviction", victim_sid)
            entry = SessionEntry(
                env=env,
                created_at=now,
                last_touched=now,
                reset_count=0,
                lock=self._locks.setdefault(sid, asyncio.Lock()),
            )
            self._store[sid] = entry
            return entry

    def touch(self, sid: str) -> tuple[SessionEntry | None, bool]:
        """Update last_touched. Returns ``(entry, was_expired)``.

        - ``(entry, False)`` on hit
        - ``(None, True)`` if the entry was present but evicted by this call
          due to TTL expiry
        - ``(None, False)`` if there was never an entry under this sid
        """
        entry = self._store.get(sid)
        if entry is None:
            return None, False
        now = _monotonic()
        if now - entry.last_touched > self._ttl:
            self._store.pop(sid, None)
            self._forget_lock(sid)
            self._close_quietly(entry, "expired touch", sid)
            return None, True
        new = replace(entry, last_touched=now)
        self._store[sid] = new
        return new, False

    def evict(self, sid: str) -> SessionEntry | None:
        """Pop a session out of the cache. Returns the removed entry or None.

        The env is *not* closed here; that is left to the caller.
        """
        entry = self._store.pop(sid, None)
        self._forget_lock(sid)
        return entry

    def sweep(self) -> int:
        """Synchronous TTL sweep — evict and close every entry past TTL.

        Also prunes registered locks that no longer back a stored session
        and are not currently held. Returns the number of expired sessions.
        """
        now = _monotonic()
        expired = [sid for sid, e in self._store.items() if now - e.last_touched > self._ttl]
        for sid in expired:
            entry = self._store.pop(sid)
            self._close_quietly(entry, "sweep", sid)
        orphaned = [
            sid for sid, lk in self._locks.items() if sid not in self._store and not lk.locked()
        ]
        for sid in orphaned:
            del self._locks[sid]
        if expired:
            logger.info(
                json.dumps(
                    {
                        "event": "session_sweep",
                        "expired_count": len(expired),
                        "cache_size": len(self._store),
                    }
                )
            )
        return len(expired)

    # -- internals -----------------------------------------------------------

    def _pick_lru_victim(self, now: float) -> str:
        """Return the least-recently-touched sid that may be evicted.

        Sessions whose lock is held (a /reset in flight) are skipped, so
        their env is never closed mid-reset.

        Raises:
            _ApiError: ``max_sessions`` (429, with Retry-After) when no
                session can be evicted.
        """
        idle = [k for k, e in self._store.items() if not e.lock.locked()]
        if idle:
            victim_sid = min(idle, key=lambda k: self._store[k].last_touched)
            # Refuse rather than evict when even the LRU candidate was touched
            # at this very clock reading.
            if now - self._store[victim_sid].last_touched > 0.0:
                return victim_sid
        raise _ApiError(
            code="max_sessions",
            message=f"max concurrent sessions reached ({self._max})",
            http_status=429,
            retry_after=True,
        )

    def _forget_lock(self, sid: str) -> None:
        """Unregister ``sid``'s lock unless it is currently held.

        A held lock stays registered so a concurrent /reset still observes
        it. ``sweep`` prunes it once it is released and unused.
        """
        lock = self._locks.get(sid)
        if lock is not None and not lock.locked():
            del self._locks[sid]

    @staticmethod
    def _close_quietly(entry: SessionEntry, reason: str, sid: str) -> None:
        """Best-effort ``entry.env.close()``; failures are logged, not raised."""
        try:
            entry.env.close()
        except Exception:
            logger.exception("env.close() raised on %s for sid=%s", reason, sid)


# ---------------------------------------------------------------------------
# App state container
# ---------------------------------------------------------------------------


@dataclass
class _AppState:
    """Mutable (intentional) — owned by lifespan; readers go through getters.

    ``lifespan`` stores one instance on ``app.state.driftcall``. Request
    handlers read it back through ``_get_state``.
    """

    cache: SessionCache  # session registry shared by all request handlers
    models_ready: bool = False  # set True by lifespan after the eager model load; checked for M6
    sweep_task: asyncio.Task[None] | None = None  # background TTL sweep task; cancelled on shutdown
    bearer_token: str = ""  # expected bearer token, read from DRIFTCALL_ENV_TOKEN at startup


def _get_state(app: FastAPI) -> _AppState:
    state: _AppState = app.state.driftcall
    return state


# ---------------------------------------------------------------------------
# Lifespan — eager-load Kokoro + Whisper before serving (M6 guard)
# ---------------------------------------------------------------------------


def _eager_load_models() -> None:
    """Instantiate the TTS and then the ASR singletons (tests patch this out).

    The audio module is imported inside the function, so importing this
    module alone does not import ``cells.step_09_audio``.
    """
    from cells.step_09_audio import get_asr_engine, get_tts_engine

    for load in (get_tts_engine, get_asr_engine):
        load()


def _close_all_sessions(cache: SessionCache) -> None:
    """Evict every cached session and close its env; failures are logged."""
    # SessionCache exposes no way to list sids, so read its (same-module) store.
    for sid in list(cache._store):
        entry = cache.evict(sid)
        if entry is None:
            continue
        try:
            entry.env.close()
        except Exception:
            logger.exception("env.close() raised on shutdown for sid=%s", sid)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """App lifespan: validate config, eager-load models, run the TTL sweeper.

    Startup:
      1. Fail fast if ``DRIFTCALL_ENV_TOKEN`` is unset or empty (§3.5).
      2. Publish a fresh ``_AppState`` on ``app.state.driftcall``.
      3. Load TTS + ASR in a worker thread, then set ``models_ready``.
      4. Start the background TTL sweep task.

    Shutdown: cancel the sweeper, then close every env still held by the
    cache so none is left open.

    Raises:
        RuntimeError: if the bearer-token environment variable is missing.
        Exception: whatever the eager model load raises (logged, re-raised).
    """
    cache = SessionCache()
    token = os.environ.get(_TOKEN_ENV_VAR, "")
    if not token:
        # Fail-fast per deploy_env_space.md §3.5.
        raise RuntimeError(
            f"{_TOKEN_ENV_VAR} environment variable not set; refusing to start"
        )
    state = _AppState(cache=cache, bearer_token=token)
    app.state.driftcall = state

    # Eager model load (M6 guard — must complete before serving).
    try:
        await asyncio.to_thread(_eager_load_models)
    except Exception:
        logger.exception("eager model load failed")
        raise
    state.models_ready = True

    async def _sweep_loop() -> None:
        """Call ``cache.sweep()`` every ``_SWEEP_INTERVAL_S`` until cancelled."""
        while True:
            await asyncio.sleep(_SWEEP_INTERVAL_S)
            try:
                cache.sweep()
            except Exception:
                # Without this, one failing sweep would end the task silently
                # and no further background sweeps would run.
                logger.exception("session sweep failed; retrying next interval")

    state.sweep_task = asyncio.create_task(_sweep_loop())
    try:
        yield
    finally:
        if state.sweep_task is not None:
            state.sweep_task.cancel()
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await state.sweep_task
        # Release every env the cache still holds.
        _close_all_sessions(cache)


# ---------------------------------------------------------------------------
# Body-size middleware (M11)
# ---------------------------------------------------------------------------


class _BodySizeMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared Content-Length exceeds ``max_bytes`` (M11).

    Only the header is inspected. A missing or unparsable Content-Length is
    passed through; the body size is checked again when the body is parsed.
    """

    def __init__(self, app: ASGIApp, *, max_bytes: int = _MAX_BODY_BYTES) -> None:
        super().__init__(app)
        self._max_bytes = max_bytes

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        declared = request.headers.get("content-length")
        if declared is None:
            return await call_next(request)
        try:
            declared_len = int(declared)
        except ValueError:
            declared_len = -1  # unparsable → treated as "not too large"
        if declared_len <= self._max_bytes:
            return await call_next(request)
        too_large = _ApiError(
            code="payload_too_large",
            message="request body exceeds 1 MiB",
            http_status=413,
        )
        return _error_response(too_large, _request_id(request))


# ---------------------------------------------------------------------------
# Helpers — auth, headers, body parsing
# ---------------------------------------------------------------------------


def _request_id(request: Request) -> str:
    return str(id(request))


def _check_bearer(request: Request, state: _AppState) -> None:
    auth = request.headers.get("authorization", "")
    if not auth.startswith("Bearer "):
        raise _ApiError(
            code="unauthorized",
            message="missing or non-Bearer Authorization header",
            http_status=401,
        )
    token = auth[len("Bearer ") :].strip()
    if token != state.bearer_token or not token:
        raise _ApiError(
            code="unauthorized",
            message="invalid bearer token",
            http_status=401,
        )


def _check_session_header(request: Request) -> str:
    """Return the validated ``X-Session-Id`` header value (M2).

    Raises:
        _ApiError: ``missing_session_id`` (400) when the header is absent,
            empty, or does not match ``_SESSION_ID_RE``.
    """
    session_id = request.headers.get("x-session-id", "")
    well_formed = bool(session_id) and _SESSION_ID_RE.match(session_id) is not None
    if well_formed:
        return session_id
    raise _ApiError(
        code="missing_session_id",
        message="X-Session-Id header missing or malformed",
        http_status=400,
    )


def _check_models_ready(state: _AppState) -> None:
    if not state.models_ready:
        raise _ApiError(
            code="model_not_ready",
            message="audio models still loading; retry shortly",
            http_status=503,
        )


async def _parse_json_body(request: Request) -> dict[str, Any]:
    """Read the request body and decode it as a JSON object.

    The body is consumed incrementally and rejected as soon as more than
    ``_MAX_BODY_BYTES`` have arrived. This covers bodies without a
    Content-Length header (e.g. chunked transfer), which the header-only
    body-size middleware cannot see, without first buffering them whole.

    Returns:
        The decoded object, or ``{}`` for an empty body.

    Raises:
        _ApiError: ``payload_too_large`` (413) past the size limit;
            ``bad_json`` (400) for undecodable JSON or a non-object top level.
    """
    chunks: list[bytes] = []
    received = 0
    async for chunk in request.stream():
        received += len(chunk)
        if received > _MAX_BODY_BYTES:
            raise _ApiError(
                code="payload_too_large",
                message="request body exceeds 1 MiB",
                http_status=413,
            )
        chunks.append(chunk)
    raw = b"".join(chunks)
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        raise _ApiError(
            code="bad_json",
            message=f"malformed JSON: {exc.__class__.__name__}",
            http_status=400,
        ) from exc
    if not isinstance(parsed, dict):
        raise _ApiError(
            code="bad_json",
            message="request body must be a JSON object",
            http_status=400,
        )
    return parsed


# ---------------------------------------------------------------------------
# Action / config validation (envelope-level — env owns deep validation)
# ---------------------------------------------------------------------------


def _build_action(raw: Any) -> DriftCallAction:
    """Validate the ``action`` envelope of a /step body and build the action.

    Only envelope-level checks happen here: the shape, the ``action_type``
    value, and the TOOL_CALL contract. Deep checks happen inside
    env._validate_action. Optional fields of the wrong JSON type are dropped
    to ``None`` rather than rejected.

    Raises:
        _ApiError: ``invalid_action`` (400) when ``raw`` is not an object,
            ``action_type`` is missing or unknown, or a TOOL_CALL lacks a
            string ``tool_name`` or an object ``tool_args``.
    """
    if not isinstance(raw, dict):
        raise _ApiError(
            code="invalid_action",
            message="action must be a JSON object",
            http_status=400,
        )
    atype_raw = raw.get("action_type")
    if not isinstance(atype_raw, str):
        raise _ApiError(
            code="invalid_action",
            message="action.action_type must be a string",
            http_status=400,
        )
    try:
        atype = ActionType(atype_raw)
    except ValueError as exc:
        raise _ApiError(
            code="invalid_action",
            message=f"unknown action_type {atype_raw!r}",
            http_status=400,
        ) from exc

    tool_name = raw.get("tool_name")
    tool_args = raw.get("tool_args")
    message = raw.get("message")
    confidence = raw.get("confidence")
    rationale = raw.get("rationale")

    # TOOL_CALL contract. tool_args must really be an object: previously a
    # list/string slipped through here and was silently turned into None.
    if atype == ActionType.TOOL_CALL and (
        not isinstance(tool_name, str) or not isinstance(tool_args, dict)
    ):
        raise _ApiError(
            code="invalid_action",
            message="TOOL_CALL requires tool_name (str) and tool_args (object)",
            http_status=400,
        )

    # bool is an int subclass; exclude it so JSON ``true`` is not read as 1.0.
    is_number = isinstance(confidence, (int, float)) and not isinstance(confidence, bool)
    return DriftCallAction(
        action_type=atype,
        tool_name=tool_name if isinstance(tool_name, str) else None,
        tool_args=tool_args if isinstance(tool_args, dict) else None,
        message=message if isinstance(message, str) else None,
        confidence=float(confidence) if is_number else None,
        rationale=rationale if isinstance(rationale, str) else None,
    )


def _build_env_config(reset_body: dict[str, Any]) -> dict[str, Any]:
    raw_cfg = reset_body.get("config")
    if raw_cfg is None:
        raw_cfg = {}
    if not isinstance(raw_cfg, dict):
        raise _ApiError(
            code="invalid_action",
            message="config must be a JSON object",
            http_status=400,
        )
    return raw_cfg


# ---------------------------------------------------------------------------
# Serialization helpers
# ---------------------------------------------------------------------------


def _to_jsonable(obj: Any) -> Any:
    """Recursively convert frozen dataclasses / tuples / enums to JSON-safe form."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return {k: _to_jsonable(v) for k, v in dataclasses.asdict(obj).items()}
    if isinstance(obj, ActionType):
        return obj.value
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    return obj


# ---------------------------------------------------------------------------
# Endpoint handlers (one function per route)
# ---------------------------------------------------------------------------


def _discard_session(cache: SessionCache, sid: str) -> None:
    """Evict ``sid`` and close its env (best effort; failures are logged).

    ``SessionCache.evict`` only removes the entry. Without the explicit close,
    the env of a session whose ``reset`` failed would never be released.
    """
    entry = cache.evict(sid)
    if entry is None:
        return
    try:
        entry.env.close()
    except Exception:
        logger.exception("env.close() raised while discarding sid=%s", sid)


async def _handle_reset(request: Request, state: _AppState) -> Response:
    """POST /reset — create a session, or reset an existing one in place.

    Body (all optional): ``{"config": {...}, "seed": int | null}``.
    Response: ``{"observation", "episode_id", "max_turns"}``.

    Raises:
        _ApiError: auth/header/body errors from the shared checks; 400
            ``invalid_action`` for a bad config or seed; 409
            ``reset_in_progress`` when another /reset holds this session's
            lock; ``max_sessions`` from the cache; 500 ``io_error`` /
            ``internal_error`` when env construction or ``env.reset`` fails
            (a session whose reset failed is discarded).
    """
    _check_bearer(request, state)
    _check_models_ready(state)
    sid = _check_session_header(request)
    body = await _parse_json_body(request)
    cfg = _build_env_config(body)
    seed_raw = body.get("seed")
    # bool is an int subclass; JSON true/false is not a valid seed.
    if seed_raw is not None and (not isinstance(seed_raw, int) or isinstance(seed_raw, bool)):
        raise _ApiError(
            code="invalid_action",
            message="seed must be an int or null",
            http_status=400,
        )
    seed: int | None = seed_raw

    def _conflict() -> _ApiError:
        return _ApiError(
            code="reset_in_progress",
            message="concurrent /reset on same session id",
            http_status=409,
        )

    cache = state.cache
    # Per-session reset lock (§7.1).
    existing = cache.get(sid)
    if existing is not None and existing.lock.locked():
        raise _conflict()
    lock = await cache.acquire_lock(sid)
    if lock.locked():
        raise _conflict()

    def _factory() -> DriftCallEnv:
        try:
            return DriftCallEnv(cfg)
        except InvalidConfigError as exc:
            raise _ApiError(
                code="invalid_action",
                message=f"invalid config: {exc}",
                http_status=400,
            ) from exc

    async with lock:
        try:
            entry = await cache.insert_or_replace(sid, _factory)
        except _ApiError:
            raise
        except Exception as exc:
            logger.exception("env construction failed for sid=%s", sid)
            raise _ApiError(
                code="internal_error",
                message="env construction failed",
                http_status=500,
            ) from exc

        try:
            obs = await asyncio.to_thread(entry.env.reset, seed)
        except InvalidConfigError as exc:
            _discard_session(cache, sid)
            raise _ApiError(
                code="invalid_action",
                message=f"invalid config at reset: {exc}",
                http_status=400,
            ) from exc
        except OSError as exc:
            _discard_session(cache, sid)
            raise _ApiError(
                code="io_error",
                message=f"I/O error during reset: {exc.__class__.__name__}",
                http_status=500,
            ) from exc
        except Exception as exc:
            _discard_session(cache, sid)
            logger.exception("env.reset raised for sid=%s", sid)
            raise _ApiError(
                code="internal_error",
                message="env.reset raised",
                http_status=500,
            ) from exc

        # Snapshot while still holding the lock, so no other /reset can swap
        # this session's env between env.reset and the read.
        env_state = entry.env.state()

    return JSONResponse(
        status_code=200,
        content={
            "observation": _to_jsonable(obs),
            "episode_id": env_state.episode_id,
            "max_turns": env_state.max_turns,
        },
    )


async def _handle_step(request: Request, state: _AppState) -> Response:
    """POST /step — apply one action to the caller's live session.

    Touching the session refreshes its TTL. Once the episode is done, the
    response also carries the scalar reward and ``info.terminated_by``.

    Raises:
        _ApiError: auth/header/body errors from the shared checks; 404
            ``session_expired`` / ``session_not_found``; 400
            ``invalid_action`` for env-level action/state errors; 500
            ``io_error`` / ``internal_error`` for failures inside ``env.step``.
    """
    _check_bearer(request, state)
    _check_models_ready(state)
    sid = _check_session_header(request)
    body = await _parse_json_body(request)
    action = _build_action(body.get("action"))

    entry, was_expired = state.cache.touch(sid)
    if entry is None:
        if was_expired:
            raise _ApiError(
                code="session_expired",
                message="session TTL expired; call /reset",
                http_status=404,
            )
        raise _ApiError(
            code="session_not_found",
            message="X-Session-Id has no live session; call /reset",
            http_status=404,
        )

    try:
        obs = await asyncio.to_thread(entry.env.step, action)
    except (
        InvalidActionError,
        UnknownToolError,
        UnknownDomainError,
        EnvNotReadyError,
        EnvClosedError,
        EpisodeAlreadyTerminalError,
    ) as exc:
        raise _ApiError(
            code="invalid_action",
            message=str(exc),
            http_status=400,
        ) from exc
    except OSError as exc:
        raise _ApiError(
            code="io_error",
            message=f"I/O error during step: {exc.__class__.__name__}",
            http_status=500,
        ) from exc
    except Exception as exc:
        logger.exception("env.step raised for sid=%s", sid)
        raise _ApiError(
            code="internal_error",
            message="env.step raised",
            http_status=500,
        ) from exc

    done = bool(entry.env.done())
    reward: float | None = None
    info: dict[str, Any] = {}
    if done:
        try:
            rewards = entry.env.rewards()
            reward = float(getattr(rewards, "reward", 0.0))
            # terminated_by's type is not visible here; route it through
            # _to_jsonable so an enum/dataclass value cannot break the response.
            info["terminated_by"] = _to_jsonable(entry.env.episode().terminated_by)
        except Exception:
            # Best effort: still return the observation, but leave a trace
            # instead of silently dropping the reward.
            logger.exception("reward/episode lookup failed for sid=%s", sid)
            reward = None

    return JSONResponse(
        status_code=200,
        content={
            "observation": _to_jsonable(obs),
            "reward": reward,
            "done": done,
            "info": info,
        },
    )


async def _handle_state(request: Request, state: _AppState) -> Response:
    """GET /state — return the session's current state snapshot and turn.

    Touching the session refreshes its TTL. Responds 404 ``session_expired``
    or ``session_not_found`` when there is no live session, and 400
    ``invalid_action`` when the env reports it is not ready.
    """
    _check_bearer(request, state)
    _check_models_ready(state)
    sid = _check_session_header(request)

    entry, was_expired = state.cache.touch(sid)
    if entry is None and was_expired:
        raise _ApiError(
            code="session_expired",
            message="session TTL expired; call /reset",
            http_status=404,
        )
    if entry is None:
        raise _ApiError(
            code="session_not_found",
            message="X-Session-Id has no live session; call /reset",
            http_status=404,
        )

    try:
        snapshot = entry.env.state()
    except EnvNotReadyError as exc:
        raise _ApiError(
            code="invalid_action",
            message=str(exc),
            http_status=400,
        ) from exc

    return JSONResponse(
        status_code=200,
        content={"state": _to_jsonable(snapshot), "turn": snapshot.turn},
    )


async def _handle_close(request: Request, state: _AppState) -> Response:
    """POST /close — evict the caller's session and return its final state.

    Closing an unknown session id still answers ``{"closed": true,
    "final_state": null}``. The final-state snapshot is best effort. The entry
    is already out of the cache at that point, so the env is closed even when
    ``state()`` fails; otherwise nothing would ever close it.
    """
    _check_bearer(request, state)
    _check_models_ready(state)
    sid = _check_session_header(request)
    entry = state.cache.evict(sid)
    if entry is None:
        return JSONResponse(status_code=200, content={"closed": True, "final_state": None})

    final_state: Any = None
    try:
        final_state = _to_jsonable(entry.env.state())
    except EnvNotReadyError:
        final_state = None
    except Exception:
        logger.exception("env.state raised on /close for sid=%s", sid)
        final_state = None

    try:
        entry.env.close()
    except Exception:
        logger.exception("env.close raised on /close for sid=%s", sid)
    return JSONResponse(status_code=200, content={"closed": True, "final_state": final_state})


# ---------------------------------------------------------------------------
# App factory + route wiring
# ---------------------------------------------------------------------------


async def _dispatch(
    app: FastAPI,
    request: Request,
    handler: Callable[[Request, _AppState], Awaitable[Response]],
) -> Response:
    """Run ``handler`` and map every failure to the uniform error envelope.

    ``_ApiError`` is rendered as declared. Any other ``Exception`` is logged
    and reported as M9 ``internal_error`` (500), so the client never receives
    a non-envelope body or a stack trace (module contract).
    """
    try:
        return await handler(request, _get_state(app))
    except _ApiError as err:
        return _error_response(err, _request_id(request))
    except Exception:
        logger.exception("unhandled error serving %s %s", request.method, request.url.path)
        internal = _ApiError(
            code="internal_error",
            message="internal server error",
            http_status=500,
        )
        return _error_response(internal, _request_id(request))


def create_app() -> FastAPI:
    """Construct a fresh FastAPI app. Used by tests to get an isolated instance.

    Wires the lifespan, the body-size middleware, the unauthenticated
    ``/healthz`` probe, and the four session endpoints. Each endpoint goes
    through ``_dispatch`` for uniform error handling.
    """
    app = FastAPI(lifespan=lifespan, title="DriftCall Env", version="0.1.0")
    app.add_middleware(_BodySizeMiddleware, max_bytes=_MAX_BODY_BYTES)

    @app.get("/healthz", response_class=PlainTextResponse)
    async def healthz() -> PlainTextResponse:
        return PlainTextResponse(content="ok", status_code=200)

    @app.post("/reset")
    async def reset_route(request: Request) -> Response:
        return await _dispatch(app, request, _handle_reset)

    @app.post("/step")
    async def step_route(request: Request) -> Response:
        return await _dispatch(app, request, _handle_step)

    @app.get("/state")
    async def state_route(request: Request) -> Response:
        return await _dispatch(app, request, _handle_state)

    @app.post("/close")
    async def close_route(request: Request) -> Response:
        return await _dispatch(app, request, _handle_close)

    return app


# Module-level app instance, built at import time. Presumably the ASGI entry
# point the Space serves (e.g. ``uvicorn <module>:app``) — TODO confirm. Tests
# should call ``create_app()`` for an isolated instance instead.
app = create_app()


# Public API of this module; everything else is an internal (underscore) helper.
__all__ = [
    "SessionCache",
    "SessionEntry",
    "app",
    "create_app",
    "lifespan",
]