File size: 5,369 Bytes
7952f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""Episode state — one per active OpenEnv session.

The server holds episodes in an in-memory dict keyed by ``episode_id``.
Episodes are entirely self-contained: they own a :class:`Graph`, a
:class:`Task`, and the running history. There is no leakage between
episodes (PROPOSAL.md §6.2 — "episode isolation").

Token accounting is a server-side concern. We use a coarse character-based
estimate (``len(json) // 4``) until a real tokenizer is wired in. The
estimate is consistent across baseline and trained runs because both go
through the same envelope.
"""

from __future__ import annotations

import json
import uuid
from dataclasses import dataclass, field
from typing import Any

from graphforge.actions.dispatcher import ActionResult
from graphforge.graph.schema import Graph
from graphforge.reward.engine import ActionOutcome, TurnReward
from graphforge.tasks.schema import Task


# ---- token estimation -----------------------------------------------


def estimate_tokens(payload: Any) -> int:
    """Return a rough token count for *payload*.

    The payload is rendered to JSON (leaves that JSON cannot encode are
    passed through ``str``) and the character count is divided by four,
    the usual ~4-chars-per-token rule of thumb for GPT-style tokenizers.
    If JSON rendering fails for any reason, ``str(payload)`` is measured
    instead.

    This is only a placeholder until the real Qwen tokenizer is wired in.
    What matters is that it grows with the size of the payload, which is
    enough to drive the "prefer cheap queries over expensive ones" reward
    shaping.
    """
    try:
        rendered = json.dumps(payload, default=str)
    except Exception:
        # Best-effort: never let token accounting crash a turn.
        rendered = str(payload)
    char_count = len(rendered)
    # len() is never negative, so this floor is purely defensive.
    return max(char_count // 4, 0)


# ---- history records ------------------------------------------------


@dataclass
class TurnRecord:
    """A single step in an episode's running history.

    Instances are built by ``Episode.record_turn`` and appended to
    ``Episode.history``. Duplicate detection compares new actions against
    the ``action_kind`` and ``action_args`` stored here.
    """

    # -- what was attempted --
    turn: int                      # Episode.turns at record time (first turn is 0)
    action_kind: str
    action_args: dict[str, Any]
    # -- how it went --
    outcome: str                   # the ``.value`` of an ActionOutcome member
    ok: bool                       # copied from ActionResult.ok
    reward: float                  # TurnReward.total for this step
    payload: dict[str, Any] = field(default_factory=dict)
    is_duplicate: bool = field(default=False)
    tokens_returned: int = field(default=0)  # added to Episode.tokens_used


# ---- episode --------------------------------------------------------


@dataclass
class Episode:
    """Self-contained state for one session: its task, graph, and history.

    Within this class, ``turns`` and ``tokens_used`` are advanced only by
    :meth:`record_turn`. Nothing here sets ``terminated`` or the
    ``terminal_*`` fields; presumably the server does so when the episode
    ends (NOTE(review): confirm against the caller).
    """

    id: str
    task: Task
    graph: Graph = field(default_factory=Graph.empty)
    history: list[TurnRecord] = field(default_factory=list)
    tokens_used: int = 0
    turns: int = 0
    terminated: bool = False
    terminal_reward: float | None = None
    terminal_payload: dict[str, Any] | None = None

    @classmethod
    def new(cls, task: Task) -> "Episode":
        """Create a fresh episode for *task*, identified by a random UUID4 string."""
        fresh_id = uuid.uuid4()
        return cls(id=str(fresh_id), task=task)

    # ----- duplicate detection ---------------------------------------

    def is_duplicate(self, kind: str, args: dict[str, Any]) -> bool:
        """Report whether this exact (kind, args) pair was already tried.

        Args are compared with ``==``, so dict key order does not matter.
        This is a linear scan over ``self.history``.
        """
        return any(
            rec.action_kind == kind and rec.action_args == args
            for rec in self.history
        )

    # ----- bookkeeping -----------------------------------------------

    def record_turn(
        self,
        kind: str,
        args: dict[str, Any],
        result: ActionResult,
        outcome: ActionOutcome,
        turn_reward: TurnReward,
        is_duplicate: bool,
        tokens_returned: int,
    ) -> TurnRecord:
        """Append a history entry for the step just taken and advance counters.

        The entry's ``turn`` is the turn count *before* this call. Afterwards
        ``turns`` grows by one and ``tokens_returned`` is added to
        ``tokens_used``. ``args`` and ``result.payload`` are stored by
        reference, not copied.

        Returns the newly appended :class:`TurnRecord`.
        """
        entry = TurnRecord(
            turn=self.turns,
            action_kind=kind,
            action_args=args,
            outcome=outcome.value,
            ok=result.ok,
            reward=turn_reward.total,
            payload=result.payload,
            is_duplicate=is_duplicate,
            tokens_returned=tokens_returned,
        )
        self.history.append(entry)
        self.turns = self.turns + 1
        self.tokens_used = self.tokens_used + tokens_returned
        return entry

    # ----- snapshot --------------------------------------------------

    def state_snapshot(self) -> dict[str, Any]:
        """Build a plain-dict view of the episode's current state.

        The task is rendered with ``Task.visible_payload()``. Graph modules,
        nodes, and edges are each dumped via ``model_dump()``. History entries
        are trimmed to turn, action kind, ok flag, and reward. Args, payloads,
        outcomes, duplicate flags, per-turn token counts and
        ``terminal_payload`` are not included.
        """
        task_view = self.task.visible_payload()
        g = self.graph
        graph_view = {
            "modules": [mod.model_dump() for mod in g.modules],
            "nodes": [node.model_dump() for node in g.nodes],
            "edges": [edge.model_dump() for edge in g.edges],
        }
        history_view = [
            {
                "turn": rec.turn,
                "action_kind": rec.action_kind,
                "ok": rec.ok,
                "reward": rec.reward,
            }
            for rec in self.history
        ]
        return {
            "episode_id": self.id,
            "task": task_view,
            "turns": self.turns,
            "tokens_used": self.tokens_used,
            "budget": self.task.budget,
            "episode_cap": self.task.episode_cap,
            "terminated": self.terminated,
            "graph": graph_view,
            "history": history_view,
            "terminal_reward": self.terminal_reward,
        }


# ---- in-memory store ------------------------------------------------


class EpisodeStore:
    """In-memory registry of live episodes, keyed by ``Episode.id``.

    Deliberately a thin layer over a plain dict, so that a TTL cache (or
    another backend) can replace it later without changing callers.
    """

    def __init__(self) -> None:
        self._eps: dict[str, Episode] = dict()

    def put(self, ep: Episode) -> None:
        """Store *ep* under its id, replacing any episode already stored there."""
        key = ep.id
        self._eps[key] = ep

    def get(self, episode_id: str) -> Episode | None:
        """Return the episode stored under *episode_id*, or ``None`` if absent."""
        try:
            return self._eps[episode_id]
        except KeyError:
            return None

    def drop(self, episode_id: str) -> bool:
        """Remove *episode_id*; return True if an episode was removed, else False."""
        removed = self._eps.pop(episode_id, None)
        return removed is not None

    def __len__(self) -> int:
        """Number of episodes currently held."""
        return len(self._eps)


# Process-wide shared store. The server module keeps a reference to it for as
# long as the process runs; tests that need isolation should construct their
# own EpisodeStore rather than sharing this one.
GLOBAL_STORE: EpisodeStore = EpisodeStore()