"""Multi-turn rollout — the bridge between the env and a policy.

For each turn:

  1. The policy is sampled, given the conversation so far. It returns a
     single text completion.
  2. The completion is parsed to extract the tool call. If parsing fails,
     a synthetic ``schema_rejection`` step is recorded with the reward
     engine's MALFORMED magnitude and the loop continues.
  3. The tool call is forwarded to the env via ``EnvClient.step``. The env
     returns ``{observation, reward, done, info}``.
  4. The observation is appended to the conversation as a user turn.
  5. We stop on ``done`` or when ``episode_cap`` is reached.

After the loop we compute the discounted return for each turn and package
the per-turn ``TurnSample(prompt_messages, completion_text, reward, return_)``
records into a ``Trajectory`` — exactly the sample shape ``trl.GRPOTrainer``
consumes when wrapped with a custom reward function.

The rollout is environment-agnostic via :class:`EnvClient` and
policy-agnostic via :class:`Policy`. Both come from sibling modules; the
rollout function never imports torch or httpx directly.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from graphforge.reward.engine import (
    DUPLICATE_ACTION,
    PER_TURN_COST,
    SCHEMA_REJECTION,
)
from graphforge.training.client import EnvClient
from graphforge.training.policy import Policy
from graphforge.training.prompt import (
    Message,
    append_completion,
    append_observation,
    initial_messages,
)
from graphforge.training.protocol import (
    ParseFailure,
    ParseSuccess,
    parse_completion,
)


# ---- per-turn record -------------------------------------------------


@dataclass
class TurnSample:
    """One turn of an episode, in the form handed to the trainer.

    The trainer-facing fields are ``prompt_messages``, ``completion_text``,
    ``reward`` and ``return_``. ``prompt_messages`` is the conversation as
    it stood *before* this turn's assistant completion was appended, so the
    completion itself is never part of its own prompt.
    """

    # Index of the turn within the episode (the rollout loop counts from 0).
    turn: int
    # Conversation up to, but excluding, this turn's completion.
    prompt_messages: list[Message]
    # Raw text the policy produced for this turn.
    completion_text: str
    # Immediate reward credited to this turn.
    reward: float
    # Discounted return from this turn onward; 0.0 until it is filled in.
    return_: float = 0.0

    # ---- diagnostics: recorded for inspection, ignored by the trainer ----
    # False when the completion was rejected as malformed.
    parse_ok: bool = True
    # Short code naming the rejection reason; None when nothing was rejected.
    parse_failure_code: str | None = None
    # Response dict from the env step; left empty when the env was not called.
    env_response: dict[str, Any] = field(default_factory=dict)
    # Whether the episode was reported finished at this turn.
    done: bool = False


@dataclass
class Trajectory:
    """Every recorded turn of one episode, plus episode-level bookkeeping."""

    # Identifier of the episode the samples belong to.
    episode_id: str
    # Identifier of the task the episode was run on.
    task_id: str
    # Per-turn records, in the order the turns were played.
    samples: list[TurnSample] = field(default_factory=list)
    # Whether the episode finished on its own rather than being cut short.
    terminated_naturally: bool = False
    # Episode-level terminal score, when one is available.
    terminal_total: float | None = None

    def __len__(self) -> int:
        """Return the number of recorded turns."""
        return len(self.samples)

    @property
    def total_reward(self) -> float:
        """Undiscounted sum of the per-turn rewards."""
        per_turn = [sample.reward for sample in self.samples]
        return sum(per_turn)


# ---- rollout ---------------------------------------------------------


def rollout(
    *,
    policy: Policy,
    env: EnvClient,
    task_id: str | None = None,
    seed: int | None = None,
    gamma: float = 0.97,
    max_turns: int | None = None,
    auto_close: bool = True,
) -> Trajectory:
    """Run one episode end-to-end and return it as a :class:`Trajectory`.

    Args:
        policy: Sampled once per turn with the conversation so far; must
            return the completion text.
        env: Client used to reset, step and (optionally) close the episode.
        task_id: Forwarded to ``env.reset``.
        seed: Forwarded to ``env.reset``.
        gamma: Discount factor used when filling ``return_`` on each sample.
        max_turns: When not ``None``, replaces the task's ``episode_cap`` as
            the turn limit (useful for unit tests). An explicit ``0`` is
            honoured and yields an empty trajectory.
        auto_close: When true, ``env.close`` is called once the episode ends,
            including when the episode is aborted by an exception.

    Returns:
        The trajectory with one :class:`TurnSample` per turn played and
        discounted returns already filled in.

    Raises:
        KeyError: If the reset response lacks ``episode_id``,
            ``observation``/``task``, or a task field that is needed.
        Exception: Anything raised by the policy or the env client while the
            episode is running propagates unchanged (after the best-effort
            close when ``auto_close`` is set).
    """
    # Function-scope import keeps this block self-contained.
    import logging

    reset_resp = env.reset(task_id=task_id, seed=seed)
    episode_id = reset_resp["episode_id"]

    samples: list[TurnSample] = []
    done = False
    terminal_total: float | None = None

    # From here on an episode exists on the env side, so everything runs
    # under try/finally to guarantee it is released even on failure.
    try:
        task_visible = reset_resp["observation"]["task"]
        # ``is None`` rather than truthiness so that max_turns=0 is respected.
        cap = task_visible["episode_cap"] if max_turns is None else max_turns

        messages = initial_messages(task_visible)

        for turn_idx in range(cap):
            # 1. Sample the policy.
            completion = policy.sample(messages)
            # Snapshot before the assistant turn is appended.
            prompt_at_turn = list(messages)

            # 2. Parse the tool call.
            parsed = parse_completion(completion)

            if isinstance(parsed, ParseFailure):
                # Synthetic step — the env never sees the action. Reward
                # mirrors the MALFORMED branch of score_turn (no token cost
                # because nothing came back from the env).
                reward = SCHEMA_REJECTION + PER_TURN_COST
                samples.append(
                    TurnSample(
                        turn=turn_idx,
                        prompt_messages=prompt_at_turn,
                        completion_text=completion,
                        reward=reward,
                        parse_ok=False,
                        parse_failure_code=parsed.code,
                    )
                )
                messages = append_completion(messages, completion)
                # NOTE(review): tokens_used_total / budget_remaining are
                # always reported as 0 / the initial budget here, even after
                # earlier env-backed turns — confirm that is intended.
                messages = append_observation(
                    messages,
                    {
                        "ok": False,
                        "outcome": "malformed",
                        "is_duplicate": False,
                        "reward": reward,
                        "payload": {"error": parsed.code, "message": parsed.message},
                        "turns_total": turn_idx + 1,
                        "tokens_used_total": 0,
                        "budget_remaining": task_visible["budget"],
                        "episode_cap_remaining": cap - (turn_idx + 1),
                    },
                )
                continue

            # 3. Forward to env.
            assert isinstance(parsed, ParseSuccess)
            env_resp = env.step(episode_id, parsed.action)

            # Tolerate an explicit ``"info": None`` as well as a missing key.
            info = env_resp.get("info") or {}
            # The env client returns a synthetic response on FastAPI 422 —
            # that's a schema_rejection (e.g. unknown kind, missing required
            # field). Score it the same as a parse-side malformed completion.
            is_schema_rejection = info.get("error") == "schema_rejection"
            if is_schema_rejection:
                reward = SCHEMA_REJECTION + PER_TURN_COST
            else:
                reward = float(env_resp.get("reward", 0.0))
            done = bool(env_resp.get("done", False))

            # The embedded observation carries duplicate flags etc.
            obs = env_resp.get("observation", {})

            samples.append(
                TurnSample(
                    turn=turn_idx,
                    prompt_messages=prompt_at_turn,
                    completion_text=completion,
                    reward=reward,
                    env_response=env_resp,
                    done=done,
                    parse_ok=not is_schema_rejection,
                    parse_failure_code=(
                        "env_schema_rejection" if is_schema_rejection else None
                    ),
                )
            )

            # 4. Feed the observation back as the next user turn.
            messages = append_completion(messages, completion)
            messages = append_observation(messages, obs)

            # 5. Stop as soon as the env reports the episode finished.
            if done:
                terminal_total = (info.get("terminal") or {}).get("total")
                break
    finally:
        if auto_close:
            try:
                env.close(episode_id)
            except Exception:
                # Closing is best-effort: never mask the rollout result (or
                # an in-flight exception), but leave a trace for debugging.
                logging.getLogger(__name__).debug(
                    "env.close failed for episode %s", episode_id, exc_info=True
                )

    _fill_returns(samples, gamma=gamma)

    return Trajectory(
        episode_id=episode_id,
        task_id=task_visible.get("id", ""),
        samples=samples,
        terminated_naturally=done,
        terminal_total=terminal_total,
    )


# ---- discounted returns ---------------------------------------------


def _fill_returns(samples: list[TurnSample], *, gamma: float) -> None:
    """Populate ``return_`` on every sample with its discounted return.

    Walks the turns from last to first using the recurrence
    ``G_t = r_t + gamma * G_{t+1}``, seeded with ``G_{T+1} = 0``. The samples
    are updated in place; nothing is returned.
    """
    future = 0.0
    for idx in range(len(samples) - 1, -1, -1):
        step = samples[idx]
        future = step.reward + gamma * future
        step.return_ = future


# ---- helper for stub-policy demo ------------------------------------


def trajectory_summary(traj: Trajectory) -> dict[str, Any]:
    """Condense a trajectory into a flat dict of headline numbers.

    The result carries the episode and task ids, the turn count, the
    undiscounted reward total, the termination flags, and how many turns
    were recorded with ``parse_ok`` false.
    """
    rejected_turns = [sample for sample in traj.samples if not sample.parse_ok]

    summary: dict[str, Any] = {}
    summary["episode_id"] = traj.episode_id
    summary["task_id"] = traj.task_id
    summary["n_turns"] = len(traj)
    summary["total_reward"] = traj.total_reward
    summary["terminated_naturally"] = traj.terminated_naturally
    summary["terminal_total"] = traj.terminal_total
    summary["parse_failures"] = len(rejected_turns)
    return summary