File size: 4,634 Bytes
7952f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Episode runner — the per-step orchestration the server endpoints use.

Pulls together dispatcher, reward engine, constraint checker, and episode
state. Kept separate from the FastAPI app so it can be unit-tested without
spinning up an HTTP server.
"""

from __future__ import annotations

import logging
from typing import Any

from graphforge.actions import dispatch
from graphforge.actions.schema import Action, Submit
from graphforge.constraints import evaluate_all
from graphforge.materializer import materialize
from graphforge.reward.engine import (
    ActionOutcome,
    TurnReward,
    score_terminal,
    score_turn,
)
from graphforge.server.episode import (
    Episode,
    TurnRecord,
    estimate_tokens,
)
from graphforge.validator import full_check


def _classify_outcome(action: Action, ok: bool) -> ActionOutcome:
    """Translate a handler's success flag into an ``ActionOutcome``.

    Only handler-time results reach this point: actions that fail schema
    validation are rejected earlier by FastAPI's pydantic layer, so every
    ``action`` seen here parsed successfully. ``action`` itself is not
    consulted; only the truthiness of ``ok`` decides the outcome.
    """
    if not ok:
        return ActionOutcome.FAILURE
    return ActionOutcome.SUCCESS


def _render_observation(ep: Episode, turn_record: TurnRecord) -> dict[str, Any]:
    return {
        "turn": turn_record.turn,
        "ok": turn_record.ok,
        "outcome": turn_record.outcome,
        "payload": turn_record.payload,
        "reward": turn_record.reward,
        "is_duplicate": turn_record.is_duplicate,
        "tokens_returned": turn_record.tokens_returned,
        "tokens_used_total": ep.tokens_used,
        "turns_total": ep.turns,
        "budget_remaining": max(0, ep.task.budget - ep.tokens_used),
        "episode_cap_remaining": max(0, ep.task.episode_cap - ep.turns),
    }


def _terminate(ep: Episode, info: dict[str, Any]) -> Any:
    """Score ``ep``'s final state, mark it terminated, and expose the payload.

    Stores the terminal payload on ``ep`` (``terminal_reward`` /
    ``terminal_payload``) and under ``info["terminal"]``. Returns the
    payload's ``"total"`` so the caller can fold it into the step reward.
    """
    terminal = _score_terminal(ep)
    ep.terminated = True
    ep.terminal_reward = terminal["total"]
    ep.terminal_payload = terminal
    info["terminal"] = terminal
    return terminal["total"]


def step(ep: Episode, action: Action) -> dict[str, Any]:
    """Apply ``action`` to ``ep``. Auto-terminates on submit or cap.

    Dispatches the action against the episode graph, scores the turn, and
    records it on ``ep``. A ``Submit`` action, or ``ep.turns`` reaching
    ``ep.task.episode_cap``, ends the episode. The terminal reward is then
    added to this step's reward, and ``info["terminal"]`` carries the full
    terminal payload. A cap-triggered end also sets
    ``info["reason"] = "episode_cap_reached"``.

    Stepping an already-terminated episode changes nothing. It returns an
    empty observation, zero reward, ``done=True`` and
    ``info["error"] == "episode_already_terminated"``.

    Returns a dict in the OpenEnv ``/step`` response shape:
    ``{observation, reward, done, info}``.
    """
    if ep.terminated:
        return {
            "observation": {},
            "reward": 0.0,
            "done": True,
            "info": {"error": "episode_already_terminated"},
        }

    args = action.model_dump(exclude={"kind"})
    kind = action.kind  # type: ignore[attr-defined]
    # Checked before record_turn, presumably so this call is compared only
    # against earlier turns -- confirm against Episode.is_duplicate.
    is_duplicate = ep.is_duplicate(kind, args)

    result = dispatch(ep.graph, action)
    tokens_returned = estimate_tokens(result.payload)
    outcome = _classify_outcome(action, result.ok)
    turn_reward = score_turn(
        outcome=outcome,
        is_duplicate=is_duplicate,
        tokens_returned=tokens_returned,
    )
    rec = ep.record_turn(
        kind=kind,
        args=args,
        result=result,
        outcome=outcome,
        turn_reward=turn_reward,
        is_duplicate=is_duplicate,
        tokens_returned=tokens_returned,
    )

    done = False
    info: dict[str, Any] = {}
    terminal_total: Any = 0.0

    if isinstance(action, Submit):
        # Explicit submit always ends the episode.
        done = True
        terminal_total = _terminate(ep, info)
    elif ep.turns >= ep.task.episode_cap:
        # Out of turns: force termination and score whatever was built so far.
        done = True
        terminal_total = _terminate(ep, info)
        info["reason"] = "episode_cap_reached"

    return {
        "observation": _render_observation(ep, rec),
        # Terminal reward is paid out on the same step that ends the episode.
        "reward": rec.reward + terminal_total,
        "done": done,
        "info": info,
    }


def _score_terminal(ep: Episode) -> dict[str, Any]:
    """Compute the terminal reward for ``ep`` and return it as a serialized payload.

    Evaluates all task constraints against the episode graph and splits the
    results into structural and behavioral families. It then runs the
    materialization gate and feeds the counts, along with token usage versus
    budget, into ``score_terminal``. Type checking is not wired in yet, so
    ``type_checks_ok`` is passed as ``None``.

    Returns:
        The reward engine's ``to_dict()`` output, extended with a
        ``"satisfaction"`` key holding the serialized constraint results.
        Callers read ``["total"]`` from this dict.
    """
    sat = evaluate_all(ep.graph, ep.task.all_constraints)
    structural, behavioral = sat.split_by_family()

    # Materialization gate: the graph must materialize to files that pass
    # ``full_check``. Any exception from either step counts as a failed gate
    # instead of aborting terminal scoring. The error is logged with its
    # traceback so that real crashes in the materializer/validator stay visible.
    materialization_ok = False
    try:
        files = materialize(ep.graph)
        materialization_ok = full_check(files).ok
    except Exception:
        logging.getLogger(__name__).warning(
            "materialization gate raised; scoring episode as not materialized",
            exc_info=True,
        )

    reward = score_terminal(
        n_structural_satisfied=len(structural.satisfied),
        n_structural_total=structural.total,
        n_behavioral_passing=len(behavioral.satisfied),
        n_behavioral_total=behavioral.total,
        materialization_ok=materialization_ok,
        type_checks_ok=None,  # mypy not wired yet
        tokens_used=ep.tokens_used,
        budget=ep.task.budget,
    )
    out = reward.to_dict()
    out["satisfaction"] = sat.to_dict()
    return out