"""Episode runner — the per-step orchestration the server endpoints use.

Pulls together dispatcher, reward engine, constraint checker, and episode
state. Kept separate from the FastAPI app so it can be unit-tested without
spinning up an HTTP server.
"""
from __future__ import annotations
from typing import Any
from graphforge.actions import dispatch
from graphforge.actions.schema import Action, Submit
from graphforge.constraints import evaluate_all
from graphforge.materializer import materialize
from graphforge.reward.engine import (
ActionOutcome,
TurnReward,
score_terminal,
score_turn,
)
from graphforge.server.episode import (
Episode,
TurnRecord,
estimate_tokens,
)
from graphforge.validator import full_check
def _classify_outcome(action: Action, ok: bool) -> ActionOutcome:
    """Translate a handler's success flag into an ``ActionOutcome``.

    Only ``SUCCESS`` and ``FAILURE`` are produced here. Schema rejection is
    dealt with earlier (FastAPI's pydantic validation), so ``action`` has
    already been parsed by the time this runs; it is accepted for context
    but not inspected.
    """
    if ok:
        return ActionOutcome.SUCCESS
    return ActionOutcome.FAILURE
def _render_observation(ep: Episode, turn_record: TurnRecord) -> dict[str, Any]:
    """Build the observation dict for one recorded turn.

    The first group of keys is copied straight off ``turn_record``; the
    second group reports episode-wide totals read from ``ep`` plus the
    remaining token budget and remaining turns, both floored at zero.
    """
    # Fields mirrored one-to-one from the turn record, in output order.
    per_turn_fields = (
        "turn",
        "ok",
        "outcome",
        "payload",
        "reward",
        "is_duplicate",
        "tokens_returned",
    )
    observation: dict[str, Any] = {
        name: getattr(turn_record, name) for name in per_turn_fields
    }

    # Episode-level counters and what is left of each limit.
    observation["tokens_used_total"] = ep.tokens_used
    observation["turns_total"] = ep.turns
    tokens_left = ep.task.budget - ep.tokens_used
    observation["budget_remaining"] = max(0, tokens_left)
    turns_left = ep.task.episode_cap - ep.turns
    observation["episode_cap_remaining"] = max(0, turns_left)
    return observation
def step(ep: Episode, action: Action) -> dict[str, Any]:
    """Run one turn of ``ep`` with ``action`` and build the ``/step`` response.

    The episode is closed automatically when ``action`` is a ``Submit`` or
    when the turn count reaches the task's episode cap. On the closing turn
    the terminal score is stored on the episode and its ``"total"`` is added
    to the turn reward.

    Returns a dict in the OpenEnv ``/step`` response shape:
    ``{observation, reward, done, info}``. Calling this on an episode that
    is already terminated changes nothing and returns an empty observation
    with ``info["error"] == "episode_already_terminated"``.
    """
    if ep.terminated:
        return {
            "observation": {},
            "reward": 0.0,
            "done": True,
            "info": {"error": "episode_already_terminated"},
        }

    action_args = action.model_dump(exclude={"kind"})
    action_kind = action.kind  # type: ignore[attr-defined]

    # Checked before the turn is recorded below.
    repeated = ep.is_duplicate(action_kind, action_args)

    handler_result = dispatch(ep.graph, action)
    token_cost = estimate_tokens(handler_result.payload)
    turn_outcome = _classify_outcome(action, handler_result.ok)
    per_turn_reward = score_turn(
        outcome=turn_outcome,
        is_duplicate=repeated,
        tokens_returned=token_cost,
    )
    record = ep.record_turn(
        kind=action_kind,
        args=action_args,
        result=handler_result,
        outcome=turn_outcome,
        turn_reward=per_turn_reward,
        is_duplicate=repeated,
        tokens_returned=token_cost,
    )

    # Two ways to end: an explicit Submit, or running into the episode cap.
    # Submit takes precedence, so the cap is only consulted otherwise.
    submitted = isinstance(action, Submit)
    hit_cap = not submitted and ep.turns >= ep.task.episode_cap
    done = submitted or hit_cap

    info: dict[str, Any] = {}
    terminal_bonus = 0.0
    if done:
        terminal = _score_terminal(ep)
        ep.terminated = True
        ep.terminal_reward = terminal["total"]
        ep.terminal_payload = terminal
        info["terminal"] = terminal
        if hit_cap:
            info["reason"] = "episode_cap_reached"
        terminal_bonus = terminal["total"]

    return {
        "observation": _render_observation(ep, record),
        "reward": record.reward + terminal_bonus,
        "done": done,
        "info": info,
    }
def _score_terminal(ep: Episode) -> dict[str, Any]:
    """Compute the terminal reward for ``ep`` and return a serialized payload.

    Steps:
      1. Evaluate every task constraint against the episode graph and split
         the result into its structural and behavioral families.
      2. Run the materialization gate: materialize the graph and run the
         full validator check on the produced files.
      3. Feed the satisfied/total counts, the gate result, and token usage
         into ``score_terminal``.

    Returns:
        ``reward.to_dict()`` with an extra ``"satisfaction"`` key holding
        ``sat.to_dict()``.
    """
    # Function-scope import so the module's import block stays as it is.
    import logging

    sat = evaluate_all(ep.graph, ep.task.all_constraints)
    structural, behavioral = sat.split_by_family()

    # Materialization gate: try to materialize + parse-check.
    try:
        materialization_ok = full_check(materialize(ep.graph)).ok
    except Exception:
        # Best-effort gate: any failure while materializing or validating
        # counts as "not ok" rather than aborting terminal scoring. The
        # traceback is logged (debug level) so the cause can be diagnosed
        # instead of being discarded silently.
        logging.getLogger(__name__).debug(
            "materialization gate raised during terminal scoring",
            exc_info=True,
        )
        materialization_ok = False

    reward = score_terminal(
        n_structural_satisfied=len(structural.satisfied),
        n_structural_total=structural.total,
        n_behavioral_passing=len(behavioral.satisfied),
        n_behavioral_total=behavioral.total,
        materialization_ok=materialization_ok,
        type_checks_ok=None,  # mypy not wired yet
        tokens_used=ep.tokens_used,
        budget=ep.task.budget,
    )

    out = reward.to_dict()
    out["satisfaction"] = sat.to_dict()
    return out
|