# IncidentCommander-Gym / server/environment.py
# Last commit 27ae240 (author: Chirag0123): "fix: Reduce workers to 1 and add safety check for engine initialization"
from openenv_core import Environment
from server.simulation.state_engine import StateEngine
from server.simulation.scenarios.task1_oom_crash import Task1OOMCrashScenario
from server.simulation.scenarios.task2_cascade import Task2CascadeFailureScenario
from server.simulation.scenarios.task3_multi_root import Task3MultiRootCauseScenario
from server.models.action import IncidentAction
from server.models.reward import IncidentReward
from server.rewards.reward_engine import RewardEngine
from server.graders.task1_grader import Task1Grader
from server.graders.task2_grader import Task2Grader
from server.graders.task3_grader import Task3Grader
from typing import Optional
import uuid
class IncidentCommanderEnvironment(Environment):
    """OpenEnv environment that runs one incident-response scenario per episode.

    ``reset`` selects a scenario by task id and builds a fresh ``StateEngine``;
    ``step`` applies one ``IncidentAction``, computes a reward with
    ``RewardEngine`` and, once the episode is done, a final score from the
    task-specific grader.
    """

    # Task id -> scenario class; a fresh scenario instance is built on every reset.
    _SCENARIOS = {
        "task1_oom_crash": Task1OOMCrashScenario,
        "task2_cascade_failure": Task2CascadeFailureScenario,
        "task3_multi_root_cause": Task3MultiRootCauseScenario,
    }
    _DEFAULT_TASK_ID = "task1_oom_crash"
    _DEFAULT_SEED = 42
    # Action types that end the episode as soon as they are taken.
    _TERMINAL_ACTIONS = frozenset({"declare_incident_resolved", "request_human_escalation"})
    # The episode also ends once the observed blast radius falls below this value.
    _RESOLVED_BLAST_RADIUS = 0.05
    # Action types counted when reporting which root causes the agent investigated.
    _DIAGNOSTIC_ACTIONS = frozenset({"inspect_logs", "pull_metrics"})

    def __init__(self):
        self.engine = None  # StateEngine of the current episode; None until reset()
        self.current_task = None
        self.episode_history = []  # one audit entry per executed (valid) step
        self.reward_engine = RewardEngine()
        self.graders = {
            "task1_oom_crash": Task1Grader(),
            "task2_cascade_failure": Task2Grader(),
            "task3_multi_root_cause": Task3Grader(),
        }
        self.episode_id = None
        # Copy of the engine state taken in reset(); step() does not update it.
        self.prev_state = None

    async def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> dict:
        """Start a new episode.

        Args:
            task_id: Scenario to run; defaults to ``"task1_oom_crash"``.
            seed: Seed passed to ``StateEngine``; 42 is used only when omitted.

        Returns:
            ``{"observation": <model_dump() of the initial observation>}``.

        Raises:
            ValueError: If ``task_id`` is not a known task.
        """
        if task_id is None:
            task_id = self._DEFAULT_TASK_ID
        scenario_cls = self._SCENARIOS.get(task_id)
        if scenario_cls is None:
            raise ValueError(f"Unknown task_id: {task_id}")

        # `seed or 42` would silently turn an explicit seed=0 into 42.
        effective_seed = self._DEFAULT_SEED if seed is None else seed
        engine = StateEngine(scenario_cls(), effective_seed)
        observation = engine.tick()
        episode_id = str(uuid.uuid4())
        observation.episode_id = episode_id
        observation.task_id = task_id

        # Commit the new episode only after the engine produced its first
        # observation, so a failing reset leaves the previous episode intact.
        self.engine = engine
        self.current_task = task_id
        self.episode_id = episode_id
        self.episode_history = []
        self.prev_state = engine.current_state.copy() if engine.current_state else {}
        return {"observation": observation.model_dump()}

    @staticmethod
    def _error_response(total_reward: float, error: str, action_result: str) -> dict:
        """Build the ``done=True`` response returned when a step cannot be executed."""
        return {
            "observation": {},
            "reward": {"total_reward": total_reward},
            "done": True,
            "info": {
                "error": error,
                "grader_score": 0.0,
                "action_result": action_result,
            },
        }

    async def step(self, action_dict: dict) -> dict:
        """Apply one action and advance the simulation by one tick.

        Args:
            action_dict: Keyword arguments used to construct an ``IncidentAction``.

        Returns:
            A dict with keys ``"observation"``, ``"reward"``, ``"done"`` and
            ``"info"``. If the environment has not been reset, or
            ``action_dict`` cannot be turned into an ``IncidentAction``, an
            error response with ``done=True`` is returned and the engine is
            not ticked.
        """
        if self.engine is None:
            return self._error_response(
                total_reward=0.0,
                error="Environment not initialized. Call /reset first.",
                action_result="Error: Reset required",
            )
        try:
            action = IncidentAction(**action_dict)
        except Exception as e:  # any malformed payload is penalised instead of crashing the server
            return self._error_response(
                total_reward=-1.0,
                error=str(e),
                action_result="Invalid action",
            )

        engine = self.engine
        # Snapshot before the tick so the reward can compare before/after.
        prev_state = engine.current_state.copy() if engine.current_state else {}
        observation = engine.tick(action)
        observation.episode_id = self.episode_id
        observation.task_id = self.current_task
        new_state = engine.current_state

        # The action counts as allowed when the episode's safety-violation count
        # did not change. NOTE(review): assumes the state dict tracks the same
        # counter under "safety_violations" — confirm against StateEngine.
        action_allowed = (
            observation.safety_violations_this_episode
            == prev_state.get("safety_violations", 0)
        )
        root_causes = engine.scenario.get_root_causes()
        is_terminal = action.action_type in self._TERMINAL_ACTIONS
        reward = self.reward_engine.compute(
            action=action,
            prev_state=prev_state,
            new_state=new_state,
            action_result={"success": True},  # Simplified: always reported as successful
            action_allowed=action_allowed,
            root_cause_services=root_causes,
            is_terminal=is_terminal,
        )

        self.episode_history.append({
            "step": observation.step,
            "action_type": action.action_type,
            "target_service": action.target_service,
            "reasoning": action.reasoning,
            "reward": reward.total_reward,
        })

        done = (
            is_terminal
            or observation.step >= observation.max_steps
            or observation.blast_radius < self._RESOLVED_BLAST_RADIUS
        )

        # The grader only scores finished episodes.
        grader_score = 0.0
        if done:
            grader = self.graders.get(self.current_task)
            if grader is not None:
                grader_score = grader.score(self.episode_history, new_state)
                reward.episode_final_score = grader_score

        # De-duplicate while keeping first-investigated order (a set has no stable order).
        identified_root_causes = list(dict.fromkeys(
            entry["target_service"]
            for entry in self.episode_history
            if entry["action_type"] in self._DIAGNOSTIC_ACTIONS
            and entry["target_service"] in root_causes
        ))

        info = {
            "episode_id": self.episode_id,
            "task_id": self.current_task,
            "step": observation.step,
            "blast_radius": observation.blast_radius,
            "grader_score": grader_score,
            "safety_violations": observation.safety_violations_this_episode,
            "actions_taken": observation.actions_taken,
            "action_result": f"Action executed: {action.action_type}",
            "root_causes_identified": identified_root_causes,
            # Copies, so callers cannot mutate the history the grader scores.
            "audit_log": [dict(entry) for entry in self.episode_history],
        }
        return {
            "observation": observation.model_dump(),
            "reward": reward.model_dump(),
            "done": done,
            "info": info,
        }

    async def state(self) -> dict:
        """Return the engine's current state (not a copy) and the episode identifiers.

        Before the first reset, ``"current_state"`` is an empty dict.
        """
        return {
            "current_state": self.engine.current_state if self.engine else {},
            "episode_id": self.episode_id,
            "current_task": self.current_task,
        }