"""Episode runner — the per-step orchestration the server endpoints use.

Pulls together the dispatcher, reward engine, constraint checker, and episode
state. Kept separate from the FastAPI app so it can be unit-tested without
spinning up an HTTP server.
"""

from __future__ import annotations

import logging
from typing import Any

from graphforge.actions import dispatch
from graphforge.actions.schema import Action, Submit
from graphforge.constraints import evaluate_all
from graphforge.materializer import materialize
from graphforge.reward.engine import (
    ActionOutcome,
    TurnReward,
    score_terminal,
    score_turn,
)
from graphforge.server.episode import (
    Episode,
    TurnRecord,
    estimate_tokens,
)
from graphforge.validator import full_check

logger = logging.getLogger(__name__)


def _classify_outcome(action: Action, ok: bool) -> ActionOutcome:
    """Map a handler's ``ok`` flag to an :class:`ActionOutcome`.

    ``action`` is currently unused; it is kept in the signature for callers.
    """
    # Schema rejection happens before this function (caught by FastAPI's
    # pydantic validation). What we see here is a successfully-parsed
    # action that either succeeded or failed at handler-time.
    return ActionOutcome.SUCCESS if ok else ActionOutcome.FAILURE


def _render_observation(ep: Episode, turn_record: TurnRecord) -> dict[str, Any]:
    """Build the observation dict for one recorded turn.

    Includes the per-turn fields from ``turn_record`` plus episode-wide
    running totals. The remaining token budget and remaining turn count are
    clamped at zero.
    """
    return {
        "turn": turn_record.turn,
        "ok": turn_record.ok,
        "outcome": turn_record.outcome,
        "payload": turn_record.payload,
        "reward": turn_record.reward,
        "is_duplicate": turn_record.is_duplicate,
        "tokens_returned": turn_record.tokens_returned,
        "tokens_used_total": ep.tokens_used,
        "turns_total": ep.turns,
        "budget_remaining": max(0, ep.task.budget - ep.tokens_used),
        "episode_cap_remaining": max(0, ep.task.episode_cap - ep.turns),
    }


def step(ep: Episode, action: Action) -> dict[str, Any]:
    """Apply ``action`` to ``ep``. Auto-terminates on submit or cap.

    Returns a dict in the OpenEnv ``/step`` response shape:
    ``{observation, reward, done, info}``.

    If ``ep`` is already terminated, the action is not applied. The response
    then has an empty observation, ``reward`` 0.0, ``done`` True, and
    ``info["error"] == "episode_already_terminated"``.

    On the step that terminates the episode, ``reward`` is the per-turn
    reward plus the terminal reward total. ``info["terminal"]`` carries the
    full terminal payload. When termination was caused by the episode cap,
    ``info["reason"]`` is ``"episode_cap_reached"``.
    """
    if ep.terminated:
        return {
            "observation": {},
            "reward": 0.0,
            "done": True,
            "info": {"error": "episode_already_terminated"},
        }

    args = action.model_dump(exclude={"kind"})
    kind = action.kind  # type: ignore[attr-defined]
    # Checked before record_turn(); presumably record_turn adds this call to
    # the history that is_duplicate() consults — verify against Episode.
    is_duplicate = ep.is_duplicate(kind, args)

    result = dispatch(ep.graph, action)
    tokens_returned = estimate_tokens(result.payload)
    outcome = _classify_outcome(action, result.ok)
    turn_reward = score_turn(
        outcome=outcome,
        is_duplicate=is_duplicate,
        tokens_returned=tokens_returned,
    )
    rec = ep.record_turn(
        kind=kind,
        args=args,
        result=result,
        outcome=outcome,
        turn_reward=turn_reward,
        is_duplicate=is_duplicate,
        tokens_returned=tokens_returned,
    )

    info: dict[str, Any] = {}
    terminal: dict[str, Any] | None = None
    if isinstance(action, Submit):
        # Terminate on Submit.
        terminal = _terminate(ep, info)
    elif ep.turns >= ep.task.episode_cap:
        # Terminate on episode cap (only if Submit didn't already end it).
        terminal = _terminate(ep, info, reason="episode_cap_reached")

    done = terminal is not None
    reward = rec.reward
    if terminal is not None:
        reward += terminal.get("total", 0.0)

    return {
        "observation": _render_observation(ep, rec),
        "reward": reward,
        "done": done,
        "info": info,
    }


def _terminate(
    ep: Episode, info: dict[str, Any], reason: str | None = None
) -> dict[str, Any]:
    """Score ``ep``'s terminal state, mark it terminated, and fill ``info``.

    Stores the terminal total and payload on ``ep``. Writes the payload to
    ``info["terminal"]`` and, when given, ``reason`` to ``info["reason"]``.
    Returns the terminal payload.
    """
    terminal = _score_terminal(ep)
    ep.terminated = True
    ep.terminal_reward = terminal["total"]
    ep.terminal_payload = terminal
    info["terminal"] = terminal
    if reason is not None:
        info["reason"] = reason
    return terminal


def _materialization_ok(ep: Episode) -> bool:
    """Return whether ``ep.graph`` materializes and passes ``full_check``.

    This gate is best-effort: any exception raised while materializing or
    checking is treated as "not materialized". The exception is logged
    rather than silently dropped.
    """
    try:
        files = materialize(ep.graph)
        return full_check(files).ok
    except Exception:
        logger.warning(
            "materialization gate failed; scoring as not materialized",
            exc_info=True,
        )
        return False


def _score_terminal(ep: Episode) -> dict[str, Any]:
    """Compute terminal reward + return a serialized payload.

    The payload is ``score_terminal(...).to_dict()``. Its ``"satisfaction"``
    key is set to the serialized constraint-evaluation result.
    """
    sat = evaluate_all(ep.graph, ep.task.all_constraints)
    structural, behavioral = sat.split_by_family()

    reward = score_terminal(
        n_structural_satisfied=len(structural.satisfied),
        n_structural_total=structural.total,
        n_behavioral_passing=len(behavioral.satisfied),
        n_behavioral_total=behavioral.total,
        materialization_ok=_materialization_ok(ep),
        type_checks_ok=None,  # mypy not wired yet
        tokens_used=ep.tokens_used,
        budget=ep.task.budget,
    )
    out = reward.to_dict()
    out["satisfaction"] = sat.to_dict()
    return out