Spaces:
Sleeping
Sleeping
| """ | |
Episode rollout utility — drives a Policy through `IncidentEnvironment`
without going through the HTTP server. Used by every other training
module (baseline runner, ablations, dataset builder).

Usage:
    from incident_env.training.rollouts import run_episode
    from incident_env.training.policies import RandomPhase2Policy
    from incident_env.server.incident_environment import IncidentEnvironment

    env = IncidentEnvironment()
    result = run_episode(env, RandomPhase2Policy(), task_name="memory_leak",
                         pool="B", max_steps=30)
    print(result["score_breakdown"])
| """ | |
| from __future__ import annotations | |
| from dataclasses import asdict | |
| from typing import Any, Dict, List, Optional | |
| from ..models import StepRecord | |
| from ..tasks import compute_r_cross | |
| from ..server.incident_environment import IncidentEnvironment | |
| from .policies import Policy | |
def _trajectory_to_dicts(traj: List[StepRecord]) -> List[Dict[str, Any]]:
    """Serialize a list of StepRecord dataclass instances to plain dicts.

    The result is intended for JSON dumps. ``dataclasses.asdict`` recurses
    into dataclass-typed fields, so a dataclass nested inside a record (the
    original author's note mentions ``IncidentAction``) is returned as a
    nested dict as well; no extra per-field handling is needed here.

    Args:
        traj: Iterable of dataclass instances, in episode order.

    Returns:
        A new list with one dict per record, in the same order as ``traj``.
    """
    return [asdict(record) for record in traj]
def _reset_policy(policy: Policy, task_name: Optional[str]) -> None:
    """Call the policy's optional ``reset`` hook, with or without the task name."""
    import inspect

    reset = getattr(policy, "reset", None)
    if not callable(reset):
        return

    try:
        signature = inspect.signature(reset)
    except (TypeError, ValueError):
        # No introspectable signature: fall back to probing by call, i.e.
        # try the zero-argument form first, then the one-argument form.
        try:
            reset()
        except TypeError:
            reset(task_name)
        return

    # Decide the arity up front so that a TypeError raised *inside* a
    # zero-argument hook is not mistaken for "wrong number of arguments"
    # (which would invoke the hook a second time).
    try:
        signature.bind()
    except TypeError:
        reset(task_name)
    else:
        reset()


def run_episode(
    env: IncidentEnvironment,
    policy: Policy,
    task_name: Optional[str] = None,
    pool: Optional[str] = None,
    mode: Optional[str] = None,
    seed: Optional[int] = None,
    max_steps: int = 40,
) -> Dict[str, Any]:
    """Drive ``policy`` through one episode against ``env``.

    The environment is reset with the given selectors, the policy's optional
    ``reset`` hook is invoked, and then the policy is called once per step as
    ``policy(obs, phase, task_name)`` until the environment reports ``done``
    or ``max_steps`` steps have been taken.

    Args:
        env: Environment exposing ``reset``, ``step``, ``get_state``,
            ``score_unified``, ``get_p1_trajectory`` and ``get_p2_trajectory``.
        policy: Callable mapping ``(observation, phase, task_name)`` to an
            action; may optionally define ``reset()`` or ``reset(task_name)``.
        task_name: Forwarded to ``env.reset``; the value echoed in the reset
            ``info`` (if any) takes precedence in the result.
        pool: Forwarded to ``env.reset``; same precedence rule as above.
        mode: Forwarded to ``env.reset``; reported as ``"joint"`` when neither
            the reset ``info`` nor the caller supplies one.
        seed: Forwarded to ``env.reset``.
        max_steps: Upper bound on the number of ``env.step`` calls.

    Returns:
        A dict with the keys ``task_name``, ``pool``, ``mode``,
        ``steps_taken``, ``p1_trajectory``, ``p2_trajectory``,
        ``declared_patch``, ``declared_no_change``, ``score_breakdown``
        (the value of ``env.score_unified()``), ``r_cross``,
        ``per_step_rewards`` and ``phase_transition_at``.
    """
    import logging

    reset_out = env.reset(task_name=task_name, pool=pool, mode=mode, seed=seed)
    obs = reset_out["observation"]
    # Tolerate a missing or None "info" payload from reset.
    info = reset_out.get("info") or {}

    initial_phase = info.get("phase", obs.get("current_phase", 1))
    actual_task = info.get("task_name", task_name)
    actual_pool = info.get("pool", pool)
    actual_mode = info.get("mode", mode or "joint")

    _reset_policy(policy, actual_task)

    rewards: List[float] = []
    for _ in range(max_steps):
        phase = obs.get("current_phase", initial_phase)
        action = policy(obs, phase, actual_task)
        step_out = env.step(action)
        obs = step_out["observation"]
        # A step may omit the reward or report it as None; both count as 0.0.
        reward = step_out.get("reward")
        rewards.append(0.0 if reward is None else float(reward))
        if step_out.get("done"):
            break

    state = env.get_state()
    breakdown = env.score_unified()
    declared_patch = state.get("declared_patch")
    declared_no_change = bool(state.get("declared_no_change"))
    p2_trajectory = env.get_p2_trajectory()

    # r_cross is best-effort: a failure must not discard the whole episode,
    # but it is logged so that a 0.0 here can be told apart from a real score.
    r_cross = 0.0
    try:
        r_cross = compute_r_cross(
            task_name=actual_task,
            declared_patch=declared_patch,
            declared_no_change=declared_no_change,
            p2_trajectory=p2_trajectory,
        )
    except Exception:
        logging.getLogger(__name__).warning(
            "compute_r_cross failed for task %r; reporting r_cross=0.0",
            actual_task,
            exc_info=True,
        )

    return {
        "task_name": actual_task,
        "pool": actual_pool,
        "mode": actual_mode,
        "steps_taken": state.get("step_count", 0),
        "p1_trajectory": _trajectory_to_dicts(env.get_p1_trajectory()),
        "p2_trajectory": _trajectory_to_dicts(p2_trajectory),
        "declared_patch": declared_patch,
        "declared_no_change": declared_no_change,
        "score_breakdown": breakdown,
        "r_cross": float(r_cross),
        "per_step_rewards": rewards,
        "phase_transition_at": state.get("phase_transition_at"),
    }