File size: 3,270 Bytes
7952f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""FastAPI application — the OpenEnv server.

Endpoints (PROPOSAL.md §6.1):

  POST /reset   { task_id?: str | None, seed?: int }
                -> { episode_id, observation }
  POST /step    { episode_id, action: Action }
                -> { observation, reward, done, info }
  GET  /state?episode_id=...
                -> { ... full snapshot ... }
  POST /close   { episode_id }
                -> { closed: bool }

The handlers are thin: routing, request validation, episode lookup. The
actual per-step orchestration lives in :mod:`graphforge.server.runner`.

In addition to the endpoints listed above, ``GET /healthz`` returns
``{ status, version }``.
"""

from __future__ import annotations

from typing import Any, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from graphforge.actions.schema import Action
from graphforge.server.episode import GLOBAL_STORE, Episode, EpisodeStore
from graphforge.server.runner import step as runner_step
from graphforge.tasks import default_task, get_task

# Application object; the route decorators further down register their
# handlers on it, and ``healthz`` reports its ``version``.
app = FastAPI(
    description="See graphforge.server for the wire shape.",
    title="GraphForge OpenEnv server",
    version="0.0.1",
)


# ---- request / response models --------------------------------------


class ResetRequest(BaseModel):
    """Request body for ``POST /reset``; both fields are optional."""

    # Task to start the episode on. When left as None the reset handler
    # falls back to ``default_task()``.
    task_id: Optional[str] = None
    seed: Optional[int] = None  # reserved for variant generation, unused for tier-0


class StepRequest(BaseModel):
    """Request body for ``POST /step``: the target episode and one action."""

    # Identifier of the episode to advance (as returned by ``POST /reset``).
    episode_id: str
    # ``Action`` is itself an Annotated discriminated union; no need to
    # re-declare the discriminator on this field.
    action: Action


class CloseRequest(BaseModel):
    """Request body for ``POST /close``."""

    # Identifier of the episode to drop from the store.
    episode_id: str


# ---- store wiring (overridable for tests) ---------------------------


def _store() -> EpisodeStore:
    """Return the episode store the handlers should use.

    Every handler in this module reaches the store through this accessor
    instead of naming ``GLOBAL_STORE`` directly, which gives tests a single
    place to substitute a different store.
    """
    active_store = GLOBAL_STORE
    return active_store


# ---- helpers --------------------------------------------------------


def _require_episode(episode_id: str) -> Episode:
    """Fetch the episode registered under *episode_id* or fail with 404.

    Args:
        episode_id: Identifier to look up in the store.

    Returns:
        The episode the store holds for that identifier.

    Raises:
        HTTPException: Status 404 when the store returns ``None``.
    """
    episode = _store().get(episode_id)
    if episode is not None:
        return episode
    raise HTTPException(
        status_code=404,
        detail=f"unknown episode_id: {episode_id!r}",
    )


def _initial_observation(ep: Episode) -> dict[str, Any]:
    """Build the observation returned alongside a freshly created episode.

    Args:
        ep: The episode to describe.

    Returns:
        A dict holding the episode id, the task's visible payload, the two
        running counters (both reported as 0), and the task's ``budget``
        and ``episode_cap``.
    """
    observation: dict[str, Any] = {"episode_id": ep.id}
    task = ep.task
    observation["task"] = task.visible_payload()
    # The turn and token counters are always reported as zero here.
    observation.update(turns_total=0, tokens_used_total=0)
    observation["budget"] = task.budget
    observation["episode_cap"] = task.episode_cap
    return observation


# ---- endpoints ------------------------------------------------------


@app.post("/reset")
def reset(req: ResetRequest) -> dict:
    """Start a new episode and return its id with the initial observation.

    The task is resolved from ``req.task_id``; when that is ``None`` the
    default task is used. ``req.seed`` is accepted but not read here.

    Raises:
        HTTPException: Status 404 when ``req.task_id`` names no known task.
    """
    requested_id = req.task_id
    if requested_id is not None:
        task = get_task(requested_id)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail=f"unknown task_id: {requested_id!r}",
            )
    else:
        task = default_task()

    # Register the episode before replying so later calls can find it by id.
    episode = Episode.new(task=task)
    _store().put(episode)
    return {"episode_id": episode.id, "observation": _initial_observation(episode)}


@app.post("/step")
def step(req: StepRequest) -> dict:
    """Apply one action to an existing episode.

    The episode is looked up by ``req.episode_id`` and handed, together with
    ``req.action``, to the runner's ``step``; its result is returned as is.

    Raises:
        HTTPException: Status 404 when the episode id is unknown.
    """
    return runner_step(_require_episode(req.episode_id), req.action)


@app.get("/state")
def state(episode_id: str) -> dict:
    """Return the state snapshot of the episode named by *episode_id*.

    Raises:
        HTTPException: Status 404 when the episode id is unknown.
    """
    return _require_episode(episode_id).state_snapshot()


@app.post("/close")
def close(req: CloseRequest) -> dict:
    """Drop an episode from the store and report the store's answer.

    Unlike ``step`` and ``state`` this handler does not go through
    ``_require_episode``; whatever ``drop`` returns is passed back under
    the ``closed`` key.
    """
    return {"closed": _store().drop(req.episode_id)}


@app.get("/healthz")
def healthz() -> dict:
    """Report a fixed ``ok`` status together with the application version."""
    payload = {"status": "ok"}
    payload["version"] = app.version
    return payload