# updated-policy / training / rollouts.py
# Last commit: srinjoyd — "init" (19f7f7b)
"""
Episode rollout utility — drives a Policy through `IncidentEnvironment`
without going through the HTTP server. Used by every other training
module (baseline runner, ablations, dataset builder).

Usage:

    from incident_env.training.rollouts import run_episode
    from incident_env.training.policies import RandomPhase2Policy
    from incident_env.server.incident_environment import IncidentEnvironment

    env = IncidentEnvironment()
    result = run_episode(env, RandomPhase2Policy(), task_name="memory_leak",
                         pool="B", max_steps=30)
    print(result["score_breakdown"])
"""
from __future__ import annotations

import logging
from dataclasses import asdict
from typing import Any, Dict, List, Optional

from ..models import StepRecord
from ..tasks import compute_r_cross
from ..server.incident_environment import IncidentEnvironment
from .policies import Policy
def _trajectory_to_dicts(traj: List[StepRecord]) -> List[Dict[str, Any]]:
    """Turn every StepRecord in ``traj`` into a plain dict (for JSON dumps).

    ``asdict`` recurses into nested dataclass fields, so the IncidentAction
    carried by each record is converted together with its parent record.
    """
    return [asdict(record) for record in traj]
def run_episode(
    env: IncidentEnvironment,
    policy: Policy,
    task_name: Optional[str] = None,
    pool: Optional[str] = None,
    mode: Optional[str] = None,
    seed: Optional[int] = None,
    max_steps: int = 40,
) -> Dict[str, Any]:
    """Drive ``policy`` through one episode against ``env``.

    The environment is reset with the given ``task_name`` / ``pool`` /
    ``mode`` / ``seed``; the policy is then called as
    ``policy(obs, phase, task_name)`` and its action is passed to
    ``env.step`` until the step result reports ``done`` or ``max_steps``
    steps have been taken.

    Args:
        env: Environment to roll out in. It is reset by this call.
        policy: Callable producing an action from ``(obs, phase, task_name)``.
            If it has a ``reset`` attribute, that is called once before the
            first step (first with no arguments, then with the task name if
            the no-argument call raises ``TypeError``).
        task_name: Forwarded to ``env.reset``.
        pool: Forwarded to ``env.reset``.
        mode: Forwarded to ``env.reset``.
        seed: Forwarded to ``env.reset``.
        max_steps: Upper bound on the number of ``env.step`` calls.

    Returns:
        A dict with the keys ``task_name``, ``pool``, ``mode``,
        ``steps_taken``, ``p1_trajectory``, ``p2_trajectory``,
        ``declared_patch``, ``declared_no_change``, ``score_breakdown``
        (the result of ``env.score_unified()``), ``r_cross``,
        ``per_step_rewards`` and ``phase_transition_at``.
    """
    info_reset = env.reset(task_name=task_name, pool=pool, mode=mode, seed=seed)
    obs = info_reset["observation"]
    # Tolerate a missing or None "info" payload; fall back to the caller's
    # arguments for anything the environment did not report.
    info = info_reset.get("info") or {}
    initial_phase = info.get("phase", obs.get("current_phase", 1))
    actual_task = info.get("task_name", task_name)
    actual_pool = info.get("pool", pool)
    actual_mode = info.get("mode", mode or "joint")

    # Optional reset hook on policy: try the zero-argument form first, then
    # the form that takes the task name.
    if hasattr(policy, "reset"):
        try:
            policy.reset()
        except TypeError:
            policy.reset(actual_task)

    rewards: List[float] = []
    for _ in range(max_steps):
        phase = obs.get("current_phase", initial_phase)
        action = policy(obs, phase, actual_task)
        step_out = env.step(action)
        obs = step_out["observation"]
        # A missing or None reward is recorded as 0.0.
        rewards.append(float(step_out.get("reward") or 0.0))
        if step_out.get("done"):
            break

    state = env.get_state()
    breakdown = env.score_unified()
    # Read the phase-2 trajectory once so r_cross and the returned
    # trajectory are derived from the same snapshot.
    p2_trajectory = env.get_p2_trajectory()
    declared_patch = state.get("declared_patch")
    declared_no_change = bool(state.get("declared_no_change"))

    r_cross = 0.0
    try:
        r_cross = compute_r_cross(
            task_name=actual_task,
            declared_patch=declared_patch,
            declared_no_change=declared_no_change,
            p2_trajectory=p2_trajectory,
        )
    except Exception:
        # Best effort: a failing cross-reward must not lose the rollout, but
        # the failure is logged instead of being swallowed silently.
        logging.getLogger(__name__).warning(
            "compute_r_cross failed for task %r; reporting r_cross=0.0",
            actual_task,
            exc_info=True,
        )

    return {
        "task_name": actual_task,
        "pool": actual_pool,
        "mode": actual_mode,
        "steps_taken": state.get("step_count", 0),
        "p1_trajectory": _trajectory_to_dicts(env.get_p1_trajectory()),
        "p2_trajectory": _trajectory_to_dicts(p2_trajectory),
        "declared_patch": declared_patch,
        "declared_no_change": declared_no_change,
        "score_breakdown": breakdown,
        "r_cross": float(r_cross),
        "per_step_rewards": rewards,
        "phase_transition_at": state.get("phase_transition_at"),
    }