Spaces:
Sleeping
Sleeping
File size: 5,369 Bytes
"""Episode state — one per active OpenEnv session.
The server holds episodes in an in-memory dict keyed by ``episode_id``.
Episodes are entirely self-contained: they own a :class:`Graph`, a
:class:`Task`, and the running history. There is no leakage between
episodes (PROPOSAL.md §6.2 — "episode isolation").
Token accounting is a server-side concern. We use a coarse character-based
estimate (``len(json) // 4``) until a real tokenizer is wired in. The
estimate is consistent across baseline and trained runs because both go
through the same envelope.
"""
from __future__ import annotations
import json
import uuid
from dataclasses import dataclass, field
from typing import Any
from graphforge.actions.dispatcher import ActionResult
from graphforge.graph.schema import Graph
from graphforge.reward.engine import ActionOutcome, TurnReward
from graphforge.tasks.schema import Task
# ---- token estimation -----------------------------------------------
def estimate_tokens(payload: Any) -> int:
    """Return a rough token count for ``payload``.

    The payload is rendered to JSON (anything JSON cannot encode natively
    is passed through ``str``) and the character count is divided by four,
    the usual "about 4 characters per token" rule of thumb. The result only
    needs to grow with payload size so that cheaper responses score better
    than expensive ones; it is a stand-in until a real tokenizer is used.

    Args:
        payload: Any object; ideally JSON-serializable.

    Returns:
        ``len(rendered) // 4``, never negative.
    """
    try:
        rendered = json.dumps(payload, default=str)
    except Exception:
        # Best-effort: if JSON rendering fails for any reason, fall back to
        # the plain string form so the caller still gets an estimate.
        rendered = str(payload)
    char_count = len(rendered)
    return max(char_count // 4, 0)
# ---- history records ------------------------------------------------
@dataclass
class TurnRecord:
    """One entry in an episode's history: a single action and its result."""

    # Index of the turn within the episode.
    turn: int
    # Kind of the action that was taken, and the arguments it was given.
    action_kind: str
    action_args: dict[str, Any]
    # String form of the outcome (an ``ActionOutcome`` value).
    outcome: str
    # Whether the action succeeded.
    ok: bool
    # Reward credited for this turn.
    reward: float
    # Data returned by the action; each record gets its own empty dict
    # unless one is supplied.
    payload: dict[str, Any] = field(default_factory=dict)
    # Flag marking the action as a repeat of an earlier one.
    is_duplicate: bool = False
    # Token cost attributed to this turn's response.
    tokens_returned: int = 0
# ---- episode --------------------------------------------------------
@dataclass
class Episode:
    """Mutable state of a single episode.

    Holds the task, the graph, the per-turn history, and the running
    counters (turns taken, tokens returned so far). The ``terminated`` /
    ``terminal_*`` fields start out unset; no method of this class assigns
    them, so they are filled in by whoever drives the episode.
    """

    id: str
    task: Task
    graph: Graph = field(default_factory=Graph.empty)
    history: list[TurnRecord] = field(default_factory=list)
    tokens_used: int = 0
    turns: int = 0
    terminated: bool = False
    terminal_reward: float | None = None
    terminal_payload: dict[str, Any] | None = None

    @classmethod
    def new(cls, task: Task) -> "Episode":
        """Create an episode for ``task`` under a freshly generated UUID4 id."""
        fresh_id = str(uuid.uuid4())
        return cls(id=fresh_id, task=task)

    # ----- duplicate detection ---------------------------------------
    def is_duplicate(self, kind: str, args: dict[str, Any]) -> bool:
        """Report whether the same ``(kind, args)`` pair is already in history."""
        return any(
            past.action_kind == kind and past.action_args == args
            for past in self.history
        )

    # ----- bookkeeping -----------------------------------------------
    def record_turn(
        self,
        kind: str,
        args: dict[str, Any],
        result: ActionResult,
        outcome: ActionOutcome,
        turn_reward: TurnReward,
        is_duplicate: bool,
        tokens_returned: int,
    ) -> TurnRecord:
        """Append a history record for one action and advance the counters.

        The record's ``turn`` is the turn count *before* this call, so the
        first recorded turn of a fresh episode is numbered 0. Afterwards
        ``turns`` grows by one and ``tokens_used`` by ``tokens_returned``.

        Returns:
            The ``TurnRecord`` that was appended to ``history``.
        """
        turn_index = self.turns
        outcome_value = outcome.value
        succeeded = result.ok
        reward_total = turn_reward.total
        entry = TurnRecord(
            turn=turn_index,
            action_kind=kind,
            action_args=args,
            outcome=outcome_value,
            ok=succeeded,
            reward=reward_total,
            payload=result.payload,
            is_duplicate=is_duplicate,
            tokens_returned=tokens_returned,
        )
        self.history.append(entry)
        self.turns = turn_index + 1
        self.tokens_used += tokens_returned
        return entry

    # ----- snapshot --------------------------------------------------
    def state_snapshot(self) -> dict[str, Any]:
        """Build a plain-dict view of the episode.

        Includes the task's visible payload, counters and limits, the graph
        contents (each item dumped via ``model_dump()``), a condensed history
        (turn, action kind, ok, reward only), and the terminal reward.
        """
        task = self.task
        graph = self.graph

        snapshot: dict[str, Any] = {
            "episode_id": self.id,
            "task": task.visible_payload(),
            "turns": self.turns,
            "tokens_used": self.tokens_used,
            "budget": task.budget,
            "episode_cap": task.episode_cap,
            "terminated": self.terminated,
        }
        # Same three sections, in the same order: modules, nodes, edges.
        snapshot["graph"] = {
            section: [item.model_dump() for item in getattr(graph, section)]
            for section in ("modules", "nodes", "edges")
        }
        # Only a summary of each turn is exposed, not args or payloads.
        snapshot["history"] = [
            {
                "turn": rec.turn,
                "action_kind": rec.action_kind,
                "ok": rec.ok,
                "reward": rec.reward,
            }
            for rec in self.history
        ]
        snapshot["terminal_reward"] = self.terminal_reward
        return snapshot
# ---- in-memory store ------------------------------------------------
class EpisodeStore:
    """Minimal episode registry backed by a plain dict.

    Kept as its own class so the backing container can later be replaced
    (for example by a TTL cache) without touching callers.
    """

    def __init__(self) -> None:
        """Start with an empty mapping of episode id to episode."""
        self._eps: dict[str, Episode] = {}

    def put(self, ep: Episode) -> None:
        """Store ``ep`` under its own id, replacing any previous entry."""
        key = ep.id
        self._eps[key] = ep

    def get(self, episode_id: str) -> Episode | None:
        """Return the episode stored under ``episode_id``, or ``None``."""
        found = self._eps.get(episode_id)
        return found

    def drop(self, episode_id: str) -> bool:
        """Remove the episode stored under ``episode_id``.

        Returns:
            ``True`` if an episode was removed, ``False`` if none was stored.
        """
        removed = self._eps.pop(episode_id, None)
        return removed is not None

    def __len__(self) -> int:
        """Return the number of stored episodes."""
        return len(self._eps)
# Process-wide singleton store, created once at import time. The server
# module holds onto this for the lifetime of the process. Tests can
# construct their own EpisodeStore for isolation instead of sharing it.
GLOBAL_STORE = EpisodeStore()
|