""" Core Environment implementation. Per-step execution order: validate → mutate → tick → observe → reward. Two-phase architecture: Phase 1 — ops/SRE diagnostic loop (existing behavior). Phase 2 — code attribution loop, sandboxed under a CodeWorkspace. Mode selection is automatic per scenario: - Scenario with `code_context = None` → legacy P1-only episode (declare_root_cause terminates) - Scenario with `code_context != None` → unified P1 → P2 episode (declare_root_cause is silent; transition_to_phase2 switches phase; propose_patch / declare_no_change terminate the episode) The environment uses oracle-shaped per-step rewards for training. The oracle-INDEPENDENT graders live on `BaseScenario` and `scenarios.grader_p2`. """ from __future__ import annotations import uuid import random from dataclasses import asdict from typing import Any, Dict, List, Optional, Tuple from ..models import ( ActionType, IncidentAction, IncidentState, StepRecord, BeliefState, DIAGNOSTIC_ACTIONS, REMEDIATION_ACTIONS, TARGETED_ACTIONS, PHASE1_ACTIONS, PHASE2_ACTIONS, PHASE2_DIAGNOSTIC_ACTIONS, PHASE2_TERMINAL_ACTIONS, ) from ..simulation.infrastructure import Infrastructure, SERVICE_NAMES from ..tasks import get_scenario, TASK_NAMES from ..scenarios.base import BaseScenario from ..pools import POOLS, get_pool, sample_task, oracle_belief from .code_workspace import CodeWorkspace, CodeWorkspaceError # Per-step reward constants ------------------------------------------------ _STEP_PENALTY = -0.02 _REPEAT_PENALTY = -0.05 _INVALID_PENALTY = -0.05 # Phase 2 shaping (small — terminal patch quality is graded post-hoc) _P2_DIAG_REWARD = +0.05 _P2_TERMINAL_BONUS = +0.10 class IncidentEnvironment: """ SRE Incident Response Environment. Implements the three OpenEnv methods: - reset(task_name) → initial observation + info - step(action) → dict with observation, reward, done, info - state() → IncidentState for monitoring Plus two extras used by the unified evaluator: - get_trajectory() → P1 + P2 step records - score_unified(...) → component scores for unified grader """ def __init__(self) -> None: self._infra: Optional[Infrastructure] = None self._scenario: Optional[BaseScenario] = None self._state = IncidentState() # ---- Per-episode mutable state ---- self._phase: int = 1 self._workspace: Optional[CodeWorkspace] = None self._belief_at_transition: Optional[BeliefState] = None self._p1_trajectory: List[StepRecord] = [] self._p2_trajectory: List[StepRecord] = [] self._declared_patch: Optional[str] = None self._declared_no_change: bool = False self._declared_root_cause: Optional[str] = None self._cumulative_reward: float = 0.0 self._done: bool = False # ---- Pool / mode (set by reset, drives episode semantics) ---- # mode in {"joint" (default), "p1_only" (Pool A), "p2_only" (Pool B)} self._pool: Optional[str] = None self._mode: str = "joint" self._inject_oracle_belief: bool = False # P2-only tracking (for repeat detection inside P2) self._p2_actions_taken: List[Tuple[str, str]] = [] # (atype, primary_param) # ================================================================== # reset() # ================================================================== def reset( self, task_name: Optional[str] = None, seed: Optional[int] = None, pool: Optional[str] = None, mode: Optional[str] = None, **kwargs: Any, ) -> Dict[str, Any]: """ Initialize a new incident episode. `pool` selects training pool A/B/C/D (overrides default mode). `mode` forces episode semantics ("p1_only"|"p2_only"|"joint"). Explicit `mode` always wins over pool defaults. """ if seed is not None: random.seed(seed) # ---- Pool / task selection ---- pool_obj = None if pool: pool_obj = get_pool(pool) if task_name is None: task_name = sample_task(pool, rng=random) self._pool = pool_obj.name self._mode = pool_obj.mode self._inject_oracle_belief = pool_obj.inject_oracle_belief else: self._pool = None self._mode = "joint" self._inject_oracle_belief = False if mode: self._mode = mode if mode == "p2_only": self._inject_oracle_belief = True if task_name is None: task_name = random.choice(TASK_NAMES) self._infra = Infrastructure() self._scenario = get_scenario(task_name) self._infra.time_budget_minutes = self._scenario.time_budget_minutes self._scenario.inject(self._infra) # Let cascades propagate a few minutes for _ in range(3): self._infra.tick() self._state = IncidentState( episode_id = str(uuid.uuid4()), task_name = task_name, step_count = 0, time_elapsed_minutes = self._infra.current_minute, done = False, cumulative_reward = 0.0, ) self._phase = 1 self._workspace = None self._belief_at_transition = None self._p1_trajectory = [] self._p2_trajectory = [] self._declared_patch = None self._declared_no_change = False self._declared_root_cause = None self._cumulative_reward = 0.0 self._done = False self._p2_actions_taken = [] # ---- Pool B (p2_only) auto-handoff with oracle belief -------- # The agent never sees Phase 1; we synthesise a perfect handoff and # immediately switch the env into Phase 2. if self._mode == "p2_only" and self._scenario.code_context is not None: belief = oracle_belief(self._scenario) self._handle_transition(IncidentAction( action_type = ActionType.TRANSITION_TO_PHASE2.value, target_service = None, parameters = {"belief": asdict(belief)}, )) # _handle_transition already returned; we just consume its # observation as the reset observation so caller sees Phase 2. obs = self._build_observation( action_result = { "message": "[Pool B] Auto-handoff with oracle Phase-1 belief.", "issue": self._scenario.build_p2_issue(belief), "file_tree": (self._workspace.file_tree(max_depth=4) if self._workspace else []), "bad_commit_sha": self._scenario.code_context.bad_commit_sha, "bad_commit": (self._workspace.bad_commit_metadata() if self._workspace else None), }, action_success = True, action_message = "Episode started in Pool B (P2-only) mode", reward = 0.0, ) return { "observation": obs, "reward": 0.01, "done": False, "info": {"task_name": task_name, "pool": self._pool, "mode": self._mode, "has_phase2": True, "phase": 2}, } obs = self._build_observation( action_result = {"message": "Incident triggered. Begin investigation."}, action_success = True, action_message = "Episode started", reward = 0.0, ) return { "observation": obs, "reward": 0.01, "done": False, "info": {"task_name": task_name, "pool": self._pool, "mode": self._mode, "has_phase2": self._scenario.code_context is not None}, } # ================================================================== # step() # ================================================================== def step(self, action_data: Dict[str, Any]) -> Dict[str, Any]: """Execute one agent action — phase-aware dispatch.""" if self._done: return self._final_step_response() if self._infra is None or self._scenario is None: return self._not_initialized_response() action = IncidentAction( action_type = action_data.get("action_type", ""), target_service = action_data.get("target_service"), parameters = action_data.get("parameters", {}) or {}, ) # ---- Type validation ---------------------------------------- try: atype = ActionType(action.action_type) except ValueError: return self._invalid_action_response( f"Unknown action type: {action.action_type!r}", action, ) # ---- Phase-aware dispatch ----------------------------------- if atype == ActionType.TRANSITION_TO_PHASE2: return self._handle_transition(action) if self._phase == 1: if atype not in PHASE1_ACTIONS: return self._invalid_action_response( f"Action {atype.value!r} not allowed in Phase 1", action, ) return self._step_phase1(action, atype) # Phase 2 if atype not in PHASE2_ACTIONS: return self._invalid_action_response( f"Action {atype.value!r} not allowed in Phase 2", action, ) return self._step_phase2(action, atype) # ------------------------------------------------------------------ # Phase 1 step # ------------------------------------------------------------------ def _step_phase1( self, action: IncidentAction, atype: ActionType, ) -> Dict[str, Any]: # Validate target / preconditions via Infrastructure is_valid, err = self._infra.validate_action( action.action_type, action.target_service) if not is_valid: return self._invalid_action_response(err, action) # Mutate action_result, action_msg = self._execute_p1_action(action, atype) # Tick simulation self._infra.tick() self._state.step_count += 1 self._state.time_elapsed_minutes = self._infra.current_minute # Reward (compute BEFORE recording so repeat-detection sees prior actions) reward = self._compute_p1_reward(action, atype) self._infra.record_action(action.action_type, action.target_service) self._cumulative_reward += reward self._state.cumulative_reward = self._cumulative_reward # Done check done = self._check_done_p1(atype) self._done = done self._state.done = done obs = self._build_observation( action_result = action_result, action_success = True, action_message = action_msg, reward = reward, ) record = StepRecord( step_number = self._state.step_count, action = action, reward = reward, observation_summary = { "action_message": obs.get("action_message", ""), "active_alerts_count": obs.get("active_alerts_count", 0), }, service_statuses_after = dict(obs.get("service_statuses", {})), timestamp_minutes = self._infra.current_minute, phase = 1, ) self._p1_trajectory.append(record) info: Dict[str, Any] = {} if done: info["score"] = self._scenario.grade(self._p1_trajectory) info["task_name"] = self._scenario.task_name info["steps_taken"] = self._state.step_count info["trajectory_length"] = len(self._p1_trajectory) return {"observation": obs, "reward": reward, "done": done, "info": info} # ------------------------------------------------------------------ # Phase 2 step # ------------------------------------------------------------------ def _step_phase2( self, action: IncidentAction, atype: ActionType, ) -> Dict[str, Any]: if self._workspace is None: return self._invalid_action_response( "Phase 2 not initialised — must transition_to_phase2 first.", action, ) params = action.parameters or {} # ---- Execute action ---- try: if atype == ActionType.LIST_DIR: result = self._workspace.list_dir(params.get("path", ".")) msg = f"Listed {result.get('count', 0)} entries in {result.get('path', '.')}" elif atype == ActionType.READ_FILE: result = self._workspace.read_file(params.get("path", "")) msg = f"Read {result.get('path')} ({result.get('size', 0)} bytes)" elif atype == ActionType.SEARCH_CODE: result = self._workspace.search_code( query = params.get("query", ""), file_pattern = params.get("file_pattern", "*.py"), max_hits = params.get("max_hits"), ) msg = f"Found {result.get('count', 0)} hit(s) for {params.get('query')!r}" elif atype == ActionType.GET_GIT_LOG: result = self._workspace.get_git_log( path = params.get("path", ""), n_commits = int(params.get("n_commits", 10)), ) msg = f"Returned {result.get('count', 0)} commit(s)" elif atype == ActionType.GET_FILE_DIFF: result = self._workspace.get_file_diff( commit_sha = params.get("commit_sha", ""), path = params.get("path", ""), ) msg = f"Diff for {result.get('commit_sha')[:8]} ({len(result.get('diff', ''))} bytes)" elif atype == ActionType.PROPOSE_PATCH: diff = params.get("diff", "") self._declared_patch = diff result = {"accepted": True, "patch_bytes": len(diff)} msg = "Patch proposal accepted — episode terminating." elif atype == ActionType.DECLARE_NO_CHANGE: self._declared_no_change = True reason = params.get("reason", "") result = {"accepted": True, "reason": reason} msg = "no-change declaration accepted — episode terminating." else: return self._invalid_action_response( f"Unhandled P2 action type: {atype.value!r}", action, ) success = True except CodeWorkspaceError as e: result = {"error": str(e)} msg = f"Workspace error: {e}" success = False # ---- Tick (simulation time still advances during P2) ---- self._infra.tick() self._state.step_count += 1 self._state.time_elapsed_minutes = self._infra.current_minute # ---- Reward ---- reward = self._compute_p2_reward(action, atype, success) self._cumulative_reward += reward self._state.cumulative_reward = self._cumulative_reward # ---- Done ---- done = (atype in PHASE2_TERMINAL_ACTIONS) or self._exceeded_step_budget() self._done = done self._state.done = done obs = self._build_observation( action_result = result, action_success = success, action_message = msg, reward = reward, ) # Record record = StepRecord( step_number = self._state.step_count, action = action, reward = reward, observation_summary = { "action_message": obs.get("action_message", ""), "p2_action": atype.value, }, service_statuses_after = dict(obs.get("service_statuses", {})), timestamp_minutes = self._infra.current_minute, phase = 2, ) self._p2_trajectory.append(record) # Track repeats inside P2 prim_param = self._p2_primary_param(atype, params) self._p2_actions_taken.append((atype.value, prim_param)) info: Dict[str, Any] = {} if done: info["score"] = self._compute_unified_final_score() info["task_name"] = self._scenario.task_name info["steps_taken"] = self._state.step_count info["trajectory_length"] = len(self._p1_trajectory) + len(self._p2_trajectory) return {"observation": obs, "reward": reward, "done": done, "info": info} # ------------------------------------------------------------------ # transition_to_phase2 handler # ------------------------------------------------------------------ def _handle_transition(self, action: IncidentAction) -> Dict[str, Any]: if self._phase != 1: return self._invalid_action_response( "Already in Phase 2 — cannot transition again.", action, ) if self._scenario is None or self._scenario.code_context is None: return self._invalid_action_response( "Scenario has no code_context — Phase 2 unavailable.", action, ) ctx = self._scenario.code_context # Construct workspace try: self._workspace = CodeWorkspace( snapshot_root = ctx.repo_snapshot_path, bad_commit_sha = ctx.bad_commit_sha, ) except CodeWorkspaceError as e: return self._invalid_action_response( f"Cannot open snapshot: {e}", action, ) # Capture handoff belief belief_dict = (action.parameters or {}).get("belief") or {} self._belief_at_transition = self._coerce_belief(belief_dict) # Switch phase self._phase = 2 self._state.step_count += 1 self._infra.tick() self._state.time_elapsed_minutes = self._infra.current_minute # Initial P2 obs issue_text = self._scenario.build_p2_issue(self._belief_at_transition) file_tree = self._workspace.file_tree(max_depth=4) commit_meta = self._workspace.bad_commit_metadata() action_result = { "phase": 2, "issue": issue_text, "file_tree": file_tree, "bad_commit_sha": ctx.bad_commit_sha, "bad_commit": commit_meta, "snapshot_root": str(self._workspace.tree_root), } # Reward: small handoff bonus only when belief is non-trivial reward = 0.0 if self._belief_at_transition.suspected_service: reward += 0.05 self._cumulative_reward += reward self._state.cumulative_reward = self._cumulative_reward obs = self._build_observation( action_result = action_result, action_success = True, action_message = "Transitioned to Phase 2 (code attribution).", reward = reward, ) record = StepRecord( step_number = self._state.step_count, action = action, reward = reward, observation_summary = { "action_message": "transition_to_phase2", "transition": True, }, service_statuses_after = dict(obs.get("service_statuses", {})), timestamp_minutes = self._infra.current_minute, phase = 2, belief_state_snapshot = asdict(self._belief_at_transition), ) self._p2_trajectory.append(record) return {"observation": obs, "reward": reward, "done": False, "info": {}} @staticmethod def _coerce_belief(d: Dict[str, Any]) -> BeliefState: """Best-effort: turn an inference-side dict into the canonical BeliefState.""" gaps = d.get("evidence_gaps", []) if isinstance(gaps, str): gaps = [g.strip() for g in gaps.split(",") if g.strip() and g.strip() != "none"] return BeliefState( suspected_service = d.get("suspected_service") or None, suspected_fault_class = d.get("suspected_fault_class") or None, service_confidence = float(d.get("service_confidence") or 0.0), fault_confidence = float(d.get("fault_confidence") or 0.0), evidence_gaps = list(gaps), estimated_p2_cost = d.get("estimated_p2_cost") or "unknown", decision = d.get("decision") or "transition", reasoning = d.get("reasoning") or "", ) # ================================================================== # state() # ================================================================== @property def state(self) -> IncidentState: return self._state def get_state(self) -> Dict[str, Any]: return { "episode_id": self._state.episode_id, "task_name": self._state.task_name, "step_count": self._state.step_count, "time_elapsed_minutes": self._state.time_elapsed_minutes, "done": self._state.done, "cumulative_reward": round(self._state.cumulative_reward, 3), "declared_root_cause": self._declared_root_cause, "declared_patch": self._declared_patch, "declared_no_change": self._declared_no_change, "phase": self._phase, "phase_transition_at": next( (r.step_number for r in self._p2_trajectory if r.action.action_type == ActionType.TRANSITION_TO_PHASE2.value), None, ), } # ================================================================== # Phase 1 action execution # ================================================================== def _execute_p1_action( self, action: IncidentAction, atype: ActionType, ) -> Tuple[Dict[str, Any], str]: target = action.target_service params = action.parameters or {} if atype == ActionType.VIEW_ALERTS: alerts = self._infra.get_alerts() return {"alerts": alerts, "count": len(alerts)}, \ f"Viewing {len(alerts)} active alerts" if atype == ActionType.QUERY_LOGS: level = params.get("level") keyword = params.get("keyword") limit = params.get("limit", 15) logs = self._infra.get_logs_for_service(target, level, keyword, limit) return {"logs": logs, "count": len(logs), "service": target}, \ f"Queried {len(logs)} logs from {target}" if atype == ActionType.CHECK_METRICS: metrics = self._infra.get_metrics_for_service(target) return {"metrics": metrics, "service": target, "data_points": len(metrics)}, \ f"Retrieved {len(metrics)} metric points for {target}" if atype == ActionType.CHECK_DEPENDENCIES: deps = self._infra.get_dependencies_for_service(target) return {"dependencies": deps, "service": target}, \ f"Retrieved dependency map for {target}" if atype == ActionType.CHECK_DEPLOY_HISTORY: deploys = self._infra.get_deploy_history_for_service(target) return {"deploys": deploys, "service": target, "count": len(deploys)}, \ f"Retrieved {len(deploys)} deploys for {target}" if atype == ActionType.RUN_HEALTH_CHECK: h = self._infra.run_health_check(target) return {"health_check": h, "service": target}, \ f"Health check for {target}: {h['status']}" if atype == ActionType.RESTART_SERVICE: svc = self._infra.get_service(target) msg = svc.restart(self._infra.current_minute) if svc else "Service not found" return {"result": msg, "service": target}, msg if atype == ActionType.ROLLBACK_DEPLOY: svc = self._infra.get_service(target) msg = svc.rollback_deploy(self._infra.current_minute) \ if svc else "Service not found" return {"result": msg, "service": target}, msg if atype == ActionType.SCALE_SERVICE: svc = self._infra.get_service(target) new_replicas = params.get("replicas", 5) msg = svc.scale(new_replicas, self._infra.current_minute) \ if svc else "Service not found" return {"result": msg, "service": target}, msg if atype == ActionType.DECLARE_ROOT_CAUSE: rc = params.get("root_cause", "") self._declared_root_cause = rc self._state.declared_root_cause = rc return { "declared": rc, "message": ("Root cause declared. " + ("Episode continues — Phase 2 awaits." if self._scenario.code_context else "Episode will end after this step.")), }, f"Root cause declared: {rc[:120]}" return {"error": f"Unhandled action type: {atype.value}"}, "Unknown action" # ================================================================== # Reward computation # ================================================================== def _compute_p1_reward( self, action: IncidentAction, atype: ActionType, ) -> float: scenario = self._scenario target = action.target_service reward = _STEP_PENALTY if self._infra.was_action_taken(action.action_type, target): return round(reward + _REPEAT_PENALTY, 3) if atype in DIAGNOSTIC_ACTIONS: if target and target in scenario.involved_services: reward += 0.15 elif target and target not in scenario.involved_services: reward += 0.05 elif atype == ActionType.VIEW_ALERTS: reward += 0.15 elif atype in REMEDIATION_ACTIONS: if target == scenario.root_cause_service: reward += 0.30 elif target and target in scenario.involved_services: reward += 0.10 else: reward -= 0.15 elif atype == ActionType.DECLARE_ROOT_CAUSE: declared = (action.parameters or {}).get("root_cause", "").lower() kws = scenario.root_cause_keywords if kws: ratio = sum(1 for k in kws if k in declared) / len(kws) if ratio >= 0.6: reward += 0.40 elif ratio >= 0.3: reward += 0.15 else: reward -= 0.20 else: reward -= 0.20 # Completion bonus when episode terminates if self._declared_root_cause and not scenario.code_context: if self._infra.all_services_healthy(): reward += 0.20 if self._infra.current_minute > self._infra.time_budget_minutes: reward -= 0.10 return round(reward, 3) def _compute_p2_reward( self, action: IncidentAction, atype: ActionType, success: bool, ) -> float: params = action.parameters or {} prim = self._p2_primary_param(atype, params) reward = _STEP_PENALTY if not success: return round(reward + _INVALID_PENALTY, 3) if (atype.value, prim) in self._p2_actions_taken: return round(reward + _REPEAT_PENALTY, 3) if atype in PHASE2_DIAGNOSTIC_ACTIONS: reward += _P2_DIAG_REWARD elif atype in PHASE2_TERMINAL_ACTIONS: reward += _P2_TERMINAL_BONUS return round(reward, 3) @staticmethod def _p2_primary_param(atype: ActionType, params: Dict[str, Any]) -> str: if atype == ActionType.LIST_DIR: return params.get("path", ".") if atype == ActionType.READ_FILE: return params.get("path", "") if atype == ActionType.SEARCH_CODE: return params.get("query", "") if atype == ActionType.GET_GIT_LOG: return params.get("path", "") if atype == ActionType.GET_FILE_DIFF: return f'{params.get("commit_sha", "")}:{params.get("path", "")}' return "" # ================================================================== # Done logic # ================================================================== def _check_done_p1(self, atype: ActionType) -> bool: # Pool A / explicit p1_only mode: declare_root_cause always terminates, # regardless of whether the scenario could otherwise transition to P2. if atype == ActionType.DECLARE_ROOT_CAUSE: if self._mode == "p1_only" or self._scenario.code_context is None: return True if self._exceeded_step_budget(): return True return False def _exceeded_step_budget(self) -> bool: budget = self._scenario.max_steps if self._scenario else 20 # When code_context exists, allow a bit more headroom for P2 exploration if self._scenario and self._scenario.code_context is not None: budget = budget + 15 return self._state.step_count >= budget # ================================================================== # Observation builder # ================================================================== def _build_observation( self, action_result: Dict[str, Any], action_success: bool, action_message: str, reward: float, ) -> Dict[str, Any]: statuses = self._infra.get_all_statuses() if self._infra else {} alerts = self._infra.get_alerts() if self._infra else [] valid_actions = self._valid_actions_for_phase() return { "incident_summary": self._scenario.incident_summary if self._scenario else "", "severity": self._scenario.severity if self._scenario else "SEV3", "time_elapsed_minutes": self._infra.current_minute if self._infra else 0, "time_budget_minutes": self._infra.time_budget_minutes if self._infra else 30, "action_result": action_result, "action_success": action_success, "action_message": action_message, "service_statuses": statuses, "active_alerts_count": len(alerts), "valid_actions": valid_actions, "available_services": list(SERVICE_NAMES), "current_phase": self._phase, "current_reward": reward, "cumulative_reward": round(self._cumulative_reward, 3), "steps_taken": self._state.step_count, "max_steps": self._scenario.max_steps if self._scenario else 20, "done": self._done, # Convenience field surfaced after transition (so the inference loop # can grab it without re-issuing a step) — only meaningful after # transition_to_phase2 has been called. "bad_commit_sha": (self._scenario.code_context.bad_commit_sha if self._scenario and self._scenario.code_context else None), } def _valid_actions_for_phase(self) -> List[str]: if self._phase == 1: base = self._infra.get_valid_actions() if self._infra else [] # Filter to only P1 + (optionally) transition_to_phase2 valid = [a for a in base if a.split(":", 1)[0] in {at.value for at in PHASE1_ACTIONS}] if self._scenario and self._scenario.code_context is not None: valid.append(ActionType.TRANSITION_TO_PHASE2.value) return valid # Phase 2 return [at.value for at in PHASE2_ACTIONS] # ================================================================== # Trajectory access (used by /score endpoint and Pool runners) # ================================================================== def get_trajectory(self) -> List[StepRecord]: return list(self._p1_trajectory) + list(self._p2_trajectory) def get_p1_trajectory(self) -> List[StepRecord]: return list(self._p1_trajectory) def get_p2_trajectory(self) -> List[StepRecord]: return list(self._p2_trajectory) def get_belief_at_transition(self) -> Optional[BeliefState]: return self._belief_at_transition # ================================================================== # Final unified scoring # ================================================================== def _compute_unified_final_score(self) -> float: """Quick wrapper for the in-step `info.score` field.""" from ..tasks import grade_trajectory_unified if self._scenario is None: return 0.01 breakdown = grade_trajectory_unified( task_name = self._scenario.task_name, p1_trajectory = self._p1_trajectory, p2_trajectory = self._p2_trajectory, declared_patch = self._declared_patch, declared_no_change = self._declared_no_change, p1_belief_history = [], ) return float(breakdown.get("final", 0.01)) def score_unified( self, belief_history: Optional[List[Dict[str, Any]]] = None, ) -> Dict[str, float]: """Public wrapper exposed by the /score endpoint.""" from ..tasks import grade_trajectory_unified if self._scenario is None: return {"final": 0.01} return grade_trajectory_unified( task_name = self._scenario.task_name, p1_trajectory = self._p1_trajectory, p2_trajectory = self._p2_trajectory, declared_patch = self._declared_patch, declared_no_change = self._declared_no_change, p1_belief_history = belief_history or [], ) # ================================================================== # Error / fallback responses # ================================================================== def _invalid_action_response( self, msg: str, action: IncidentAction, ) -> Dict[str, Any]: reward = _INVALID_PENALTY self._cumulative_reward += reward self._state.step_count += 1 obs = self._build_observation( action_result = {"error": msg}, action_success = False, action_message = f"Invalid action: {msg}", reward = reward, ) # Still record the failed attempt so trajectory analysis sees it record = StepRecord( step_number = self._state.step_count, action = action, reward = reward, observation_summary = {"action_message": f"invalid: {msg}"}, service_statuses_after = dict(obs.get("service_statuses", {})), timestamp_minutes = self._infra.current_minute if self._infra else 0, phase = self._phase, ) if self._phase == 1: self._p1_trajectory.append(record) else: self._p2_trajectory.append(record) return {"observation": obs, "reward": reward, "done": False, "info": {"error": msg}} def _final_step_response(self) -> Dict[str, Any]: obs = self._build_observation( action_result = {"error": "Episode is already done."}, action_success = False, action_message = "Episode already finished", reward = 0.0, ) score = (self._compute_unified_final_score() if self._scenario and self._scenario.code_context else (self._scenario.grade(self._p1_trajectory) if self._scenario else 0.01)) return {"observation": obs, "reward": 0.01, "done": True, "info": {"score": score}} def _not_initialized_response(self) -> Dict[str, Any]: obs = self._build_observation( action_result = {"error": "Environment not initialized. Call reset() first."}, action_success = False, action_message = "Not initialized", reward = 0.0, ) return {"observation": obs, "reward": 0.01, "done": False, "info": {}}