# Author: NagaNithin-V
# Commit: Deploy GraphForge OpenEnv — AST-parsed KG code-editing environment
# Revision: 7952f32
"""Multi-turn rollout — the bridge between the env and a policy.
For each turn:
1. The policy is sampled, given the conversation so far. It returns a
single text completion.
2. The completion is parsed to extract the tool call. If parsing fails,
a synthetic ``schema_rejection`` step is recorded with the reward
engine's MALFORMED magnitude and the loop continues.
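3. The tool call is forwarded to the env via ``EnvClient.step``. The env
returns ``{observation, reward, done, info}``. An env-side schema
rejection (``info["error"] == "schema_rejection"``) is scored the same
way as a parse failure.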
4. The observation is appended to the conversation as a user turn.
5. We stop on ``done`` or when ``episode_cap`` (or ``max_turns``, if given) is reached.
After the loop we compute the discounted return for each turn and produce a
list of ``TurnSample(prompt_messages, completion_text, reward, return_)``
records — exactly the shape ``trl.GRPOTrainer`` consumes when wrapped with
a custom reward function.
The rollout is environment-agnostic via :class:`EnvClient` and
policy-agnostic via :class:`Policy`. Both come from sibling modules; the
rollout function never imports torch or httpx directly.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from graphforge.reward.engine import (
DUPLICATE_ACTION,
PER_TURN_COST,
SCHEMA_REJECTION,
)
from graphforge.training.client import EnvClient
from graphforge.training.policy import Policy
from graphforge.training.prompt import (
Message,
append_completion,
append_observation,
initial_messages,
)
from graphforge.training.protocol import (
ParseFailure,
ParseSuccess,
parse_completion,
)
# ---- per-turn record -------------------------------------------------
@dataclass
class TurnSample:
    """One trainer-facing record: prompt, completion, reward and return.

    ``prompt_messages`` is the conversation as it stood *before* the
    assistant produced ``completion_text`` on this turn. ``return_`` stays
    at its default until discounted returns are filled in once the episode
    has finished.
    """

    turn: int
    prompt_messages: list[Message]
    completion_text: str
    reward: float
    return_: float = field(default=0.0)

    # -- diagnostics below; the trainer does not read these ------------
    parse_ok: bool = field(default=True)
    parse_failure_code: str | None = field(default=None)
    env_response: dict[str, Any] = field(default_factory=dict)
    done: bool = field(default=False)
@dataclass
class Trajectory:
    """One finished episode: its identifiers plus the per-turn samples.

    ``terminated_naturally`` records whether the env reported ``done`` (as
    opposed to the turn cap running out). ``terminal_total`` carries the
    ``total`` from the env's terminal info, when one was supplied.
    """

    episode_id: str
    task_id: str
    samples: list[TurnSample] = field(default_factory=list)
    terminated_naturally: bool = False
    terminal_total: float | None = None

    @property
    def total_reward(self) -> float:
        """Undiscounted sum of per-turn rewards (``0.0`` when empty)."""
        # Start from 0.0 so an empty trajectory still yields a float as
        # annotated; a bare ``sum`` would return the int ``0``.
        return sum((s.reward for s in self.samples), 0.0)

    def __len__(self) -> int:
        """Number of recorded turns, synthetic parse-failure turns included."""
        return len(self.samples)
# ---- rollout ---------------------------------------------------------
def rollout(
    *,
    policy: Policy,
    env: EnvClient,
    task_id: str | None = None,
    seed: int | None = None,
    gamma: float = 0.97,
    max_turns: int | None = None,
    auto_close: bool = True,
) -> Trajectory:
    """Run one episode end-to-end and return it as a :class:`Trajectory`.

    Args:
        policy: Sampled once per turn with the conversation so far.
        env: Client used to ``reset``, ``step`` and (optionally) ``close``
            the episode.
        task_id: Forwarded to ``env.reset``.
        seed: Forwarded to ``env.reset``.
        gamma: Discount factor used to fill each sample's ``return_``.
        max_turns: When not ``None``, overrides the task's ``episode_cap``
            (useful for unit tests); ``0`` runs no turns. Otherwise the
            env's own cap applies.
        auto_close: Call ``env.close`` once the episode ends — also when
            the loop raises. A failing ``close`` is logged at debug level
            and otherwise ignored.

    Returns:
        The trajectory, with discounted returns already filled in.
    """
    reset_resp = env.reset(task_id=task_id, seed=seed)
    episode_id = reset_resp["episode_id"]
    task_visible = reset_resp["observation"]["task"]
    # ``is not None`` so an explicit ``max_turns=0`` is honoured instead of
    # silently falling back to the task's cap.
    cap = task_visible["episode_cap"] if max_turns is None else max_turns

    messages = initial_messages(task_visible)
    samples: list[TurnSample] = []
    done = False
    terminal_total: float | None = None
    # Latest accounting figures reported by the env. They are echoed into
    # synthetic parse-failure observations so the policy is not told it has
    # used zero tokens / still holds its full budget after earlier turns.
    # NOTE(review): assumes env observations use the same keys as the
    # synthetic observation built below — confirm against the env schema.
    accounting: dict[str, Any] = {}

    try:
        for turn_idx in range(cap):
            # 1. Sample the policy.
            completion = policy.sample(messages)
            prompt_at_turn = list(messages)  # snapshot before appending the assistant turn

            # 2. Parse the tool call.
            parsed = parse_completion(completion)
            if isinstance(parsed, ParseFailure):
                # Synthetic step — env never sees the action. Reward mirrors
                # the MALFORMED branch of score_turn (no token cost because
                # nothing came back from the env).
                reward = SCHEMA_REJECTION + PER_TURN_COST
                samples.append(
                    TurnSample(
                        turn=turn_idx,
                        prompt_messages=prompt_at_turn,
                        completion_text=completion,
                        reward=reward,
                        parse_ok=False,
                        parse_failure_code=parsed.code,
                    )
                )
                messages = append_completion(messages, completion)
                messages = append_observation(
                    messages,
                    {
                        "ok": False,
                        "outcome": "malformed",
                        "is_duplicate": False,
                        "reward": reward,
                        "payload": {"error": parsed.code, "message": parsed.message},
                        "turns_total": turn_idx + 1,
                        "tokens_used_total": accounting.get("tokens_used_total", 0),
                        "budget_remaining": (
                            accounting["budget_remaining"]
                            if "budget_remaining" in accounting
                            else task_visible["budget"]
                        ),
                        "episode_cap_remaining": cap - (turn_idx + 1),
                    },
                )
                continue

            # 3. Forward to env.
            assert isinstance(parsed, ParseSuccess)
            env_resp = env.step(episode_id, parsed.action)
            # ``or {}`` also covers an explicit ``"info": None``.
            info = env_resp.get("info") or {}

            # The env client returns a synthetic response on FastAPI 422 — that's
            # a schema_rejection (e.g. unknown kind, missing required field).
            # Score it the same as a parse-side malformed completion.
            is_schema_rejection = info.get("error") == "schema_rejection"
            if is_schema_rejection:
                reward = SCHEMA_REJECTION + PER_TURN_COST
            else:
                reward = float(env_resp.get("reward", 0.0))
            done = bool(env_resp.get("done", False))

            # The embedded observation carries duplicate flags etc.
            obs = env_resp.get("observation", {})
            if isinstance(obs, dict):
                accounting.update(
                    (key, obs[key])
                    for key in ("tokens_used_total", "budget_remaining")
                    if key in obs
                )

            samples.append(
                TurnSample(
                    turn=turn_idx,
                    prompt_messages=prompt_at_turn,
                    completion_text=completion,
                    reward=reward,
                    env_response=env_resp,
                    done=done,
                    parse_ok=not is_schema_rejection,
                    parse_failure_code="env_schema_rejection" if is_schema_rejection else None,
                )
            )
            messages = append_completion(messages, completion)
            messages = append_observation(messages, obs)

            if done:
                terminal_total = (info.get("terminal") or {}).get("total")
                break
    finally:
        # Close even if the loop raised, so env-side episodes are not leaked.
        if auto_close:
            try:
                env.close(episode_id)
            except Exception:
                # Best-effort cleanup: a failing close must not mask the
                # rollout result or an in-flight exception.
                import logging

                logging.getLogger(__name__).debug(
                    "env.close(%r) failed", episode_id, exc_info=True
                )

    _fill_returns(samples, gamma=gamma)
    return Trajectory(
        episode_id=episode_id,
        task_id=task_visible.get("id", ""),
        samples=samples,
        terminated_naturally=done,
        terminal_total=terminal_total,
    )
# ---- discounted returns ---------------------------------------------
def _fill_returns(samples: list[TurnSample], *, gamma: float) -> None:
    """Write each sample's discounted return into ``return_``, in place.

    Walks the episode from the last turn backwards using
    ``G_t = r_t + gamma * G_{t+1}``, with the return past the final turn
    taken as 0.
    """
    acc = 0.0
    for step in samples[::-1]:
        acc = step.reward + gamma * acc
        step.return_ = acc
# ---- helper for stub-policy demo ------------------------------------
def trajectory_summary(traj: Trajectory) -> dict[str, Any]:
    """Flatten a trajectory into a small dict of headline figures.

    ``parse_failures`` counts samples whose ``parse_ok`` flag is false.
    """
    bad_parses = sum(not sample.parse_ok for sample in traj.samples)
    return dict(
        episode_id=traj.episode_id,
        task_id=traj.task_id,
        n_turns=len(traj),
        total_reward=traj.total_reward,
        terminated_naturally=traj.terminated_naturally,
        terminal_total=traj.terminal_total,
        parse_failures=bad_parses,
    )