"""Environment wrapper that drives incident-response episodes.

Exposes ``reset`` / ``step`` / ``state`` coroutines around a ``StateEngine``
simulation, a ``RewardEngine`` for per-step rewards and one grader per task
for the final episode score.
"""

import uuid
from typing import Optional

from openenv_core import Environment
from server.simulation.state_engine import StateEngine
from server.simulation.scenarios.task1_oom_crash import Task1OOMCrashScenario
from server.simulation.scenarios.task2_cascade import Task2CascadeFailureScenario
from server.simulation.scenarios.task3_multi_root import Task3MultiRootCauseScenario
from server.models.action import IncidentAction
from server.models.reward import IncidentReward
from server.rewards.reward_engine import RewardEngine
from server.graders.task1_grader import Task1Grader
from server.graders.task2_grader import Task2Grader
from server.graders.task3_grader import Task3Grader

# Task used when reset() is called without a task_id.
_DEFAULT_TASK_ID = "task1_oom_crash"

# Seed handed to the StateEngine when reset() is called without a seed.
_DEFAULT_SEED = 42

# task_id -> scenario class; a fresh scenario is instantiated on every reset.
_SCENARIO_FACTORIES = {
    "task1_oom_crash": Task1OOMCrashScenario,
    "task2_cascade_failure": Task2CascadeFailureScenario,
    "task3_multi_root_cause": Task3MultiRootCauseScenario,
}

# Action types that end the episode as soon as they are submitted.
# Kept as a tuple so membership uses ``==`` exactly like the comparisons it
# replaces (a set would switch to hash-based lookup).
_TERMINAL_ACTIONS = ("declare_incident_resolved", "request_human_escalation")

# Action types that count towards "root_causes_identified" in the step info.
_INVESTIGATIVE_ACTIONS = ("inspect_logs", "pull_metrics")

# The episode ends once the observed blast radius drops below this value.
_BLAST_RADIUS_DONE_THRESHOLD = 0.05


def _snapshot_state(engine):
    """Return a shallow copy of the engine's current state, or ``{}`` if it is empty/None."""
    # NOTE(review): this is a shallow copy, so nested values are shared with
    # the live state - confirm the engine does not mutate them in place.
    state = engine.current_state
    return state.copy() if state else {}


def _error_response(error, total_reward, action_result):
    """Build the terminal step payload returned when a step cannot be executed."""
    return {
        "observation": {},
        "reward": {"total_reward": total_reward},
        "done": True,
        "info": {
            "error": error,
            "grader_score": 0.0,
            "action_result": action_result,
        },
    }


def _identified_root_causes(history, root_causes):
    """List root-cause services the agent investigated, in first-seen order.

    A service counts when an ``inspect_logs`` or ``pull_metrics`` entry in
    *history* targeted it and it is contained in *root_causes*.
    """
    # dict.fromkeys de-duplicates while keeping insertion order, so the list
    # is deterministic (a plain set would yield an arbitrary order).
    return list(dict.fromkeys(
        entry["target_service"]
        for entry in history
        if entry["action_type"] in _INVESTIGATIVE_ACTIONS
        and entry["target_service"] in root_causes
    ))


class IncidentCommanderEnvironment(Environment):
    """Episode driver for the incident-commander tasks.

    Attributes set by ``__init__`` and refreshed by ``reset``:
        engine: the active ``StateEngine``, or ``None`` before the first reset.
        current_task: task id of the running episode.
        episode_id: UUID4 string identifying the running episode.
        episode_history: one dict per executed step (also returned as the
            ``audit_log`` entry of the step info).
        prev_state: copy of the engine state taken at the end of ``reset``.
    """

    def __init__(self):
        self.engine = None
        self.current_task = None
        self.episode_history = []
        self.reward_engine = RewardEngine()
        self.graders = {
            "task1_oom_crash": Task1Grader(),
            "task2_cascade_failure": Task2Grader(),
            "task3_multi_root_cause": Task3Grader(),
        }
        self.episode_id = None
        self.prev_state = None

    async def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> dict:
        """Start a new episode and return its first observation.

        Args:
            task_id: one of the known task ids; ``None`` selects
                ``"task1_oom_crash"``.
            seed: seed passed to the ``StateEngine``; ``None`` selects 42.
                An explicit ``0`` is passed through unchanged.

        Returns:
            ``{"observation": <observation.model_dump()>}``.

        Raises:
            ValueError: if *task_id* is not a known task.
        """
        if task_id is None:
            task_id = _DEFAULT_TASK_ID

        scenario_factory = _SCENARIO_FACTORIES.get(task_id)
        if scenario_factory is None:
            raise ValueError(f"Unknown task_id: {task_id}")
        scenario = scenario_factory()

        self.current_task = task_id
        self.episode_id = str(uuid.uuid4())
        self.episode_history = []

        # `seed or 42` would silently replace a legitimate seed of 0 with 42,
        # so only fall back to the default when no seed was given at all.
        effective_seed = _DEFAULT_SEED if seed is None else seed
        self.engine = StateEngine(scenario, effective_seed)

        observation = self.engine.tick()
        observation.episode_id = self.episode_id
        observation.task_id = task_id

        # NOTE(review): prev_state is only written here; step() works on its
        # own local snapshot and never refreshes this attribute.
        self.prev_state = _snapshot_state(self.engine)

        return {"observation": observation.model_dump()}

    async def step(self, action_dict: dict) -> dict:
        """Apply one action to the simulation and score it.

        Args:
            action_dict: keyword arguments used to build an ``IncidentAction``.

        Returns:
            A dict with ``observation``, ``reward``, ``done`` and ``info``.
            If ``reset`` has not been called, or the action cannot be
            constructed, a terminal payload with an empty observation is
            returned instead of raising (total reward 0.0 and -1.0
            respectively).
        """
        if self.engine is None:
            return _error_response(
                error="Environment not initialized. Call /reset first.",
                total_reward=0.0,
                action_result="Error: Reset required",
            )

        try:
            action = IncidentAction(**action_dict)
        except Exception as exc:
            # Deliberately broad: whatever goes wrong while building the
            # action is reported back to the caller as a penalised, terminal
            # step rather than propagated.
            return _error_response(
                error=str(exc),
                total_reward=-1.0,
                action_result="Invalid action",
            )

        # Snapshot before the tick so the reward can compare before/after.
        prev_state = _snapshot_state(self.engine)

        observation = self.engine.tick(action)
        observation.episode_id = self.episode_id
        observation.task_id = self.current_task

        new_state = self.engine.current_state

        # The action is treated as allowed when the safety-violation count
        # reported by the observation equals the pre-tick count.
        action_allowed = (
            observation.safety_violations_this_episode
            == prev_state.get("safety_violations", 0)
        )
        root_causes = self.engine.scenario.get_root_causes()
        is_terminal = action.action_type in _TERMINAL_ACTIONS

        reward = self.reward_engine.compute(
            action=action,
            prev_state=prev_state,
            new_state=new_state,
            action_result={"success": True},  # Simplified
            action_allowed=action_allowed,
            root_cause_services=root_causes,
            is_terminal=is_terminal,
        )

        self.episode_history.append({
            "step": observation.step,
            "action_type": action.action_type,
            "target_service": action.target_service,
            "reasoning": action.reasoning,
            "reward": reward.total_reward,
        })

        # The episode ends on a terminal action, on reaching the step budget,
        # or once the blast radius falls below the threshold.
        done = (
            is_terminal
            or observation.step >= observation.max_steps
            or observation.blast_radius < _BLAST_RADIUS_DONE_THRESHOLD
        )

        # The grader only runs on the final step; otherwise the score is 0.0.
        grader_score = 0.0
        if done:
            grader = self.graders.get(self.current_task)
            if grader:
                grader_score = grader.score(self.episode_history, new_state)
                reward.episode_final_score = grader_score

        info = {
            "episode_id": self.episode_id,
            "task_id": self.current_task,
            "step": observation.step,
            "blast_radius": observation.blast_radius,
            "grader_score": grader_score,
            "safety_violations": observation.safety_violations_this_episode,
            "actions_taken": observation.actions_taken,
            "action_result": f"Action executed: {action.action_type}",
            "root_causes_identified": _identified_root_causes(
                self.episode_history, root_causes
            ),
            "audit_log": self.episode_history,
        }

        return {
            "observation": observation.model_dump(),
            "reward": reward.model_dump(),
            "done": done,
            "info": info,
        }

    async def state(self) -> dict:
        """Return the engine state (``{}`` if no engine) plus episode and task ids."""
        return {
            "current_state": self.engine.current_state if self.engine else {},
            "episode_id": self.episode_id,
            "current_task": self.current_task,
        }