# Sre-Validation / env / environment.py
# abdur0001's picture
# feat: initial push with env and 3 tasks
# 5fe9036
"""
Core environment engine — implements reset/step/state for the SRE Incident Response env.
"""
import uuid
from typing import Any, Dict, Optional, Set, Tuple
from models import (
Action,
ActionType,
GraderResult,
INVESTIGATION_ACTIONS,
Observation,
REMEDIATION_ACTIONS,
RootCauseCategory,
ServiceState,
ServiceStatus,
State,
)
from env.scenario import IncidentScenario, RequiredFix
from tasks import SCENARIOS
from env.services import (
format_config_diff,
format_deploy_history,
format_dependencies,
format_logs,
format_metrics,
format_runbook,
format_traces,
generate_alerts,
ping_service,
recompute_health,
)
class Session:
    """Bookkeeping for one episode: live service state plus action history."""

    def __init__(self, scenario: IncidentScenario, session_id: str):
        # Identity and progress counters.
        self.session_id = session_id
        self.scenario = scenario
        self.step_count = 0
        self.done = False
        self.cumulative_reward = 0.0

        # Mutable per-episode service state, seeded from the scenario's
        # service configs: {name: {status, version, replicas}}.
        self.services: Dict[str, Dict[str, Any]] = {
            svc_name: {
                "status": svc_cfg.status,
                "version": svc_cfg.version,
                "replicas": svc_cfg.replicas,
            }
            for svc_name, svc_cfg in scenario.services.items()
        }

        # Names of services that a remediation has marked as fixed.
        self.fixed_services: Set[str] = set()

        # service_name -> fault_type, only for services flagged as a root
        # cause that also carry a non-empty fault type.
        self.root_cause_map: Dict[str, str] = {
            svc_name: svc_cfg.fault_type
            for svc_name, svc_cfg in scenario.services.items()
            if svc_cfg.is_root_cause and svc_cfg.fault_type
        }

        # History consumed later for grading and state reporting.
        self.actions: list[Action] = []
        self.services_investigated: Set[str] = set()
        self.remediations_applied: list[Dict[str, Any]] = []
        self.diagnosis: Optional[Action] = None
class IncidentResponseEnv:
    """The SRE Incident Response OpenEnv environment.

    Keeps every episode in ``self.sessions`` (keyed by session id) and
    exposes ``reset`` / ``step`` / ``state`` on top of those sessions.
    """

    def __init__(self):
        self.sessions: Dict[str, Session] = {}

    def get_task_ids(self) -> list[str]:
        """Return the ids of all scenarios registered in ``SCENARIOS``."""
        return list(SCENARIOS.keys())

    def reset(self, task_id: str, seed: int = 0) -> Tuple[Observation, str]:
        """Start a new episode for the given task.

        Args:
            task_id: Key into ``SCENARIOS``.
            seed: Accepted for API compatibility; not used by this method.

        Returns:
            ``(initial_observation, session_id)``.

        Raises:
            ValueError: If ``task_id`` is not a known scenario.
        """
        if task_id not in SCENARIOS:
            raise ValueError(f"Unknown task_id: {task_id}. Available: {list(SCENARIOS.keys())}")
        scenario = SCENARIOS[task_id]
        session_id = self._new_session_id()
        session = Session(scenario, session_id)
        self.sessions[session_id] = session
        # Build initial observation
        obs = self._build_observation(session, action_result=None)
        return obs, session_id

    def step(self, session_id: str, action: Action) -> Tuple[Observation, float, bool, Dict]:
        """Execute an action and return ``(observation, reward, done, info)``.

        Every call on a live episode consumes one step, including calls whose
        service parameter is missing or unknown, so the step budget always
        terminates the episode.

        Raises:
            ValueError: If ``session_id`` is unknown.
        """
        session = self._get_session(session_id)
        if session.done:
            obs = self._build_observation(session, action_result="Episode already finished.")
            return obs, 0.0, True, {"error": "Episode already finished."}

        session.step_count += 1
        session.actions.append(action)
        reward = 0.0
        action_result = ""
        info: Dict[str, Any] = {}
        service_name = action.service
        scenario = session.scenario

        error = self._validate_service(scenario, action)
        if error is not None:
            # Invalid actions earn nothing, but they still fall through to
            # the max-steps check below instead of returning early.
            action_result = error
            info["error"] = error
        # ── Investigation actions ──
        elif action.action_type in INVESTIGATION_ACTIONS:
            session.services_investigated.add(service_name)
            action_result = self._handle_investigation(session, action)
            # Small reward for investigating root cause services
            reward = 0.01 if service_name in scenario.root_cause_services else 0.0
        # ── Remediation actions ──
        elif action.action_type in REMEDIATION_ACTIONS:
            action_result, reward = self._handle_remediation(session, action)
            session.remediations_applied.append({
                "action": action.action_type.value,
                "service": service_name,
                "target_version": action.target_version,
                "replicas": action.replicas,
            })
        # ── Submit diagnosis ──
        elif action.action_type == ActionType.SUBMIT_DIAGNOSIS:
            session.diagnosis = action
            session.done = True
            grader_result = self._grade(session)
            reward = grader_result.score
            action_result = f"Diagnosis submitted. Score: {grader_result.score:.2f}"
            info["grader_result"] = grader_result.model_dump()

        session.cumulative_reward += reward

        # Check max steps
        if session.step_count >= scenario.max_steps and not session.done:
            session.done = True
            if session.diagnosis is None:
                # Auto-grade with whatever we have
                grader_result = self._grade(session)
                # The graded score replaces this step's reward in the return
                # value, so swap it into the running total as well; that keeps
                # cumulative_reward equal to the sum of returned rewards.
                session.cumulative_reward += grader_result.score - reward
                reward = grader_result.score
                info["grader_result"] = grader_result.model_dump()
                action_result += f"\n[MAX STEPS REACHED] Episode ended. Score: {grader_result.score:.2f}"

        obs = self._build_observation(session, action_result=action_result, reward=reward)
        obs.done = session.done
        if "grader_result" in info:
            obs.score = info["grader_result"]["score"]
        return obs, reward, session.done, info

    def state(self, session_id: str) -> State:
        """Return current episode state.

        ``services_investigated`` is sorted so the output does not depend on
        set iteration order.

        Raises:
            ValueError: If ``session_id`` is unknown.
        """
        session = self._get_session(session_id)
        return State(
            session_id=session.session_id,
            task_id=session.scenario.task_id,
            step_count=session.step_count,
            max_steps=session.scenario.max_steps,
            done=session.done,
            actions_taken=[a.action_type.value for a in session.actions],
            services_investigated=sorted(session.services_investigated),
            remediations_applied=[f"{r['action']}({r['service']})" for r in session.remediations_applied],
            cumulative_reward=round(session.cumulative_reward, 4),
        )

    # ── Internal helpers ───────────────────────────────────────────────

    def _new_session_id(self) -> str:
        """Return an 8-character id that no stored session is using."""
        while True:
            candidate = str(uuid.uuid4())[:8]
            # A truncated uuid can collide; never overwrite a stored session.
            if candidate not in self.sessions:
                return candidate

    def _get_session(self, session_id: str) -> Session:
        """Look up a session, raising ``ValueError`` for an unknown id."""
        if session_id not in self.sessions:
            raise ValueError(f"Unknown session: {session_id}")
        return self.sessions[session_id]

    @staticmethod
    def _validate_service(scenario: IncidentScenario, action: Action) -> Optional[str]:
        """Return an error message for a bad ``service`` parameter, else None.

        ``SUBMIT_DIAGNOSIS`` is exempt; every other action must name a service
        that exists in the scenario.
        """
        if action.action_type == ActionType.SUBMIT_DIAGNOSIS:
            return None
        name = action.service
        if not name:
            return "Action requires a 'service' parameter."
        if name not in scenario.services:
            return f"Unknown service: '{name}'. Available: {list(scenario.services.keys())}"
        return None

    @staticmethod
    def _format_timestamp(step_count: int) -> str:
        """Return the simulated timestamp: 2026-04-06T04:00Z plus one minute per step.

        Real date arithmetic keeps the value well-formed past step 59, where
        formatting the step number straight into the minute field would not.
        """
        from datetime import datetime, timedelta

        moment = datetime(2026, 4, 6, 4, 0, 0) + timedelta(minutes=step_count)
        return moment.strftime("%Y-%m-%dT%H:%M:%SZ")

    def _build_observation(
        self, session: Session, action_result: Optional[str], reward: float = 0.0,
    ) -> Observation:
        """Assemble the ``Observation`` for the session's current state.

        The incident summary is only included on the initial observation
        (``step_count == 0``); later observations carry an empty string.
        """
        scenario = session.scenario
        svc_states = {
            name: ServiceState(
                status=data["status"],
                version=data["version"],
                replicas=data["replicas"],
            )
            for name, data in session.services.items()
        }
        alerts = generate_alerts(
            session.services, scenario.initial_alerts, session.fixed_services,
        )
        return Observation(
            step_number=session.step_count,
            timestamp=self._format_timestamp(session.step_count),
            services=svc_states,
            active_alerts=alerts,
            incident_summary=scenario.incident_summary if session.step_count == 0 else "",
            action_result=action_result,
            reward=round(reward, 4),
            done=session.done,
        )

    def _handle_investigation(self, session: Session, action: Action) -> str:
        """Return the formatted scenario data requested by an investigation action.

        Missing per-service data falls back to an empty container, which is
        passed to the matching formatter unchanged.
        """
        scenario = session.scenario
        svc = action.service
        kind = action.action_type

        if kind == ActionType.READ_LOGS:
            return format_logs(scenario.logs.get(svc, []))
        if kind == ActionType.CHECK_METRICS:
            return format_metrics(scenario.metrics.get(svc, []))
        if kind == ActionType.PING_SERVICE:
            return ping_service(session.services[svc]["status"], svc)
        if kind == ActionType.CHECK_DEPENDENCIES:
            deps = scenario.dependencies.get(svc, [])
            dep_info = format_dependencies(deps)
            # Also show current health of dependencies
            dep_health = [
                f"  {d}: {session.services[d]['status'].value}"
                for d in deps
                if d in session.services
            ]
            if dep_health:
                dep_info += "\n\nDependency health:\n" + "\n".join(dep_health)
            return dep_info
        if kind == ActionType.INSPECT_DEPLOY:
            return format_deploy_history(scenario.deploy_history.get(svc, []))
        if kind == ActionType.QUERY_TRACES:
            return format_traces(scenario.traces.get(svc, []))
        if kind == ActionType.CHECK_RUNBOOK:
            return format_runbook(scenario.runbooks.get(svc, ""))
        if kind == ActionType.DIFF_CONFIG:
            return format_config_diff(scenario.configs.get(svc, {}))
        return f"No data available for {action.action_type.value} on {svc}."

    def _handle_remediation(self, session: Session, action: Action) -> Tuple[str, float]:
        """Apply a remediation action and return ``(result_text, reward)``.

        An action matching one of the scenario's required fixes marks the
        service as fixed and earns +0.05; most non-matching actions cost
        -0.05. Service health is recomputed afterwards and a summary of
        still-unhealthy services is appended to the result text.
        """
        scenario = session.scenario
        svc = action.service
        reward = 0.0
        result = ""

        # Check if this remediation matches any required fix
        fix_matched = any(
            self._fix_matches(action, req_fix) for req_fix in scenario.required_fixes
        )
        if fix_matched:
            session.fixed_services.add(svc)
            reward = 0.05

        if action.action_type == ActionType.RESTART_SERVICE:
            if fix_matched:
                session.services[svc]["status"] = ServiceStatus.HEALTHY
                result = f"Service '{svc}' restarted successfully. Status: HEALTHY"
            else:
                # Restarting a non-root-cause service: no effect on the underlying issue
                current = session.services[svc]["status"]
                if current == ServiceStatus.DOWN and svc in session.root_cause_map:
                    result = f"Service '{svc}' restarted but crashed again — underlying issue persists."
                elif current == ServiceStatus.HEALTHY:
                    result = f"Service '{svc}' restarted. It was already healthy — no change."
                else:
                    result = f"Service '{svc}' restarted. Status unchanged — issue is caused by an upstream dependency."
                reward = -0.05
        elif action.action_type == ActionType.ROLLBACK_DEPLOY:
            if fix_matched:
                # NOTE(review): a matching rollback without target_version
                # stores "" as the version and prints "None" in the message.
                session.services[svc]["version"] = action.target_version or ""
                session.services[svc]["status"] = ServiceStatus.HEALTHY
                result = (
                    f"Service '{svc}' rolled back to {action.target_version}. "
                    f"Pods restarting with previous version... Status: HEALTHY"
                )
            else:
                # The stored version is deliberately left untouched here.
                current_version = session.services[svc]["version"]
                result = (
                    f"Rolled back '{svc}' to {action.target_version}, but this didn't resolve the issue. "
                    f"Previous version was {current_version}."
                )
                reward = -0.05
        elif action.action_type == ActionType.SCALE_UP:
            replicas = action.replicas or 3
            session.services[svc]["replicas"] = replicas
            # NOTE(review): scaling any root-cause service heals it, even when
            # the required fix is a different action — confirm this is intended.
            if fix_matched or (svc in scenario.root_cause_services):
                session.fixed_services.add(svc)
                session.services[svc]["status"] = ServiceStatus.HEALTHY
                result = f"Service '{svc}' scaled to {replicas} replicas. Memory pressure alleviated. Status: HEALTHY"
                reward = 0.05
            else:
                result = f"Service '{svc}' scaled to {replicas} replicas. No effect on the underlying issue."
                reward = -0.05
        elif action.action_type == ActionType.DRAIN_TRAFFIC:
            result = f"Traffic drained from '{svc}'. Service is no longer receiving requests."
            if svc not in scenario.root_cause_services:
                reward = -0.05

        # Recompute health after remediation
        session.services = recompute_health(
            session.services,
            scenario.dependencies,
            session.fixed_services,
            session.root_cause_map,
        )

        # Add post-remediation status summary
        still_broken = [
            name for name, data in session.services.items()
            if data["status"] != ServiceStatus.HEALTHY
        ]
        if still_broken:
            result += f"\n\n[POST-REMEDIATION CHECK] Services still unhealthy: {', '.join(still_broken)}"
        else:
            result += "\n\n[POST-REMEDIATION CHECK] All services are now HEALTHY."
        return result, reward

    def _fix_matches(self, action: Action, req_fix: RequiredFix) -> bool:
        """Check if an action matches a required fix.

        Action type and service must be equal; the target version is only
        compared when the required fix specifies one.
        """
        if action.action_type.value != req_fix.action:
            return False
        if action.service != req_fix.service:
            return False
        if req_fix.target_version and action.target_version != req_fix.target_version:
            return False
        return True

    def _grade(self, session: Session) -> GraderResult:
        """Deterministic grading of the episode."""
        # Imported at call time rather than module level (presumably to avoid
        # a circular import with the grader package — TODO confirm).
        from graders.grader import grade_episode
        return grade_episode(session)