Spaces:
Sleeping
Sleeping
| from openenv_core import Environment | |
| from server.simulation.state_engine import StateEngine | |
| from server.simulation.scenarios.task1_oom_crash import Task1OOMCrashScenario | |
| from server.simulation.scenarios.task2_cascade import Task2CascadeFailureScenario | |
| from server.simulation.scenarios.task3_multi_root import Task3MultiRootCauseScenario | |
| from server.models.action import IncidentAction | |
| from server.models.reward import IncidentReward | |
| from server.rewards.reward_engine import RewardEngine | |
| from server.graders.task1_grader import Task1Grader | |
| from server.graders.task2_grader import Task2Grader | |
| from server.graders.task3_grader import Task3Grader | |
| from typing import Optional | |
| import uuid | |
class IncidentCommanderEnvironment(Environment):
    """Environment that runs one incident scenario per episode on a StateEngine.

    ``reset`` picks a scenario by ``task_id`` and builds a fresh
    :class:`StateEngine`. ``step`` parses an :class:`IncidentAction`, advances
    the engine one tick, and scores the transition with :class:`RewardEngine`.
    When the episode ends, ``step`` grades the action history with the
    task-specific grader.
    """

    # Supported task ids -> scenario class instantiated by reset().
    _SCENARIOS = {
        "task1_oom_crash": Task1OOMCrashScenario,
        "task2_cascade_failure": Task2CascadeFailureScenario,
        "task3_multi_root_cause": Task3MultiRootCauseScenario,
    }
    _DEFAULT_TASK = "task1_oom_crash"
    _DEFAULT_SEED = 42
    # Action types that end the episode as soon as they are taken.
    _TERMINAL_ACTIONS = ("declare_incident_resolved", "request_human_escalation")
    # Action types counted when reporting which root causes were investigated.
    _INVESTIGATIVE_ACTIONS = ("inspect_logs", "pull_metrics")
    # The episode also ends once the observed blast radius drops below this value.
    _RESOLVED_BLAST_RADIUS = 0.05

    def __init__(self):
        # NOTE(review): super().__init__() is not called; confirm that
        # openenv_core.Environment needs no initialisation of its own.
        self.engine = None  # StateEngine for the active episode; None until reset()
        self.current_task = None
        self.episode_history = []  # one audit record per executed step
        self.reward_engine = RewardEngine()
        self.graders = {
            "task1_oom_crash": Task1Grader(),
            "task2_cascade_failure": Task2Grader(),
            "task3_multi_root_cause": Task3Grader(),
        }
        self.episode_id = None
        # Snapshot of the engine state taken after the most recent reset/step.
        self.prev_state = None

    def _snapshot_state(self):
        """Return ``current_state.copy()``, or ``{}`` when the state is empty/unset."""
        state = self.engine.current_state
        return state.copy() if state else {}

    @staticmethod
    def _error_response(total_reward, error, action_result):
        """Build the terminal response returned when a step cannot be executed."""
        return {
            "observation": {},
            "reward": {"total_reward": total_reward},
            "done": True,
            "info": {
                "error": error,
                "grader_score": 0.0,
                "action_result": action_result,
            },
        }

    async def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None):
        """Start a new episode and return its first observation.

        Args:
            task_id: Scenario to run; defaults to ``"task1_oom_crash"``.
            seed: Seed passed to StateEngine; 42 is used only when ``None``
                (an explicit ``0`` is honoured).

        Returns:
            ``{"observation": <observation.model_dump()>}``.

        Raises:
            ValueError: If ``task_id`` is not one of the supported tasks.
        """
        if task_id is None:
            task_id = self._DEFAULT_TASK
        scenario_cls = self._SCENARIOS.get(task_id)
        if scenario_cls is None:
            raise ValueError(f"Unknown task_id: {task_id}")
        scenario = scenario_cls()

        self.current_task = task_id
        self.episode_id = str(uuid.uuid4())
        self.episode_history = []
        # `seed or 42` would wrongly replace seed=0; default only when absent.
        effective_seed = self._DEFAULT_SEED if seed is None else seed
        self.engine = StateEngine(scenario, effective_seed)

        observation = self.engine.tick()
        observation.episode_id = self.episode_id
        observation.task_id = task_id
        self.prev_state = self._snapshot_state()
        return {"observation": observation.model_dump()}

    async def step(self, action_dict: dict):
        """Apply one action and return ``observation``, ``reward``, ``done`` and ``info``.

        If ``reset`` has not been called, or ``action_dict`` cannot be parsed
        into an IncidentAction, a terminal error response is returned instead.
        The reward is 0.0 for the uninitialised case and -1.0 for the invalid
        action.

        The episode is done when a terminal action is taken, when
        ``observation.step`` reaches ``max_steps``, or when ``blast_radius``
        drops below ``_RESOLVED_BLAST_RADIUS``. On the final step the task
        grader's score is reported in ``info["grader_score"]`` and stored on
        ``reward.episode_final_score``.
        """
        if self.engine is None:
            return self._error_response(
                0.0,
                "Environment not initialized. Call /reset first.",
                "Error: Reset required",
            )
        try:
            action = IncidentAction(**action_dict)
        except Exception as e:
            # Boundary: any malformed payload ends the episode with a penalty.
            return self._error_response(-1.0, str(e), "Invalid action")

        prev_state = self._snapshot_state()
        observation = self.engine.tick(action)
        observation.episode_id = self.episode_id
        observation.task_id = self.current_task
        new_state = self.engine.current_state

        # Allowed iff the episode's safety-violation count matches the
        # pre-action state's count (i.e. this action added no violation).
        action_allowed = (
            observation.safety_violations_this_episode
            == prev_state.get("safety_violations", 0)
        )
        root_causes = self.engine.scenario.get_root_causes()
        is_terminal = action.action_type in self._TERMINAL_ACTIONS
        reward = self.reward_engine.compute(
            action=action,
            prev_state=prev_state,
            new_state=new_state,
            action_result={"success": True},  # simplified: parsed actions count as successful
            action_allowed=action_allowed,
            root_cause_services=root_causes,
            is_terminal=is_terminal,
        )

        self.episode_history.append({
            "step": observation.step,
            "action_type": action.action_type,
            "target_service": action.target_service,
            "reasoning": action.reasoning,
            "reward": reward.total_reward,
        })

        done = (
            is_terminal
            or observation.step >= observation.max_steps
            or observation.blast_radius < self._RESOLVED_BLAST_RADIUS
        )

        grader_score = 0.0
        if done:
            grader = self.graders.get(self.current_task)
            if grader:
                grader_score = grader.score(self.episode_history, new_state)
                reward.episode_final_score = grader_score

        # Root-cause services the agent investigated. dict.fromkeys dedupes
        # while keeping first-seen order, so the list is deterministic
        # (set() order is not).
        root_causes_identified = list(dict.fromkeys(
            entry["target_service"]
            for entry in self.episode_history
            if entry["action_type"] in self._INVESTIGATIVE_ACTIONS
            and entry["target_service"] in root_causes
        ))

        info = {
            "episode_id": self.episode_id,
            "task_id": self.current_task,
            "step": observation.step,
            "blast_radius": observation.blast_radius,
            "grader_score": grader_score,
            "safety_violations": observation.safety_violations_this_episode,
            "actions_taken": observation.actions_taken,
            "action_result": f"Action executed: {action.action_type}",
            "root_causes_identified": root_causes_identified,
            # Copy so callers cannot mutate the environment's internal history.
            "audit_log": [dict(entry) for entry in self.episode_history],
        }

        # Keep the stored snapshot in sync with the state the next action sees.
        self.prev_state = self._snapshot_state()

        return {
            "observation": observation.model_dump(),
            "reward": reward.model_dump(),
            "done": done,
            "info": info,
        }

    async def state(self):
        """Return the engine's raw state ({} before reset) plus episode identifiers."""
        return {
            "current_state": self.engine.current_state if self.engine else {},
            "episode_id": self.episode_id,
            "current_task": self.current_task,
        }