# cloud-sre-arbiter / server / environment.py
# Escanor925's picture
# fix: clamp scores to (0.001, 0.999) for Phase 2 strict range validation
# d1583fb
"""
Cloud SRE Arbiter — Core Environment Engine
============================================
Implements a multi-step, RL-style environment where an AI agent must
simultaneously contain live incidents (Ops) and investigate root causes
(Sec/Data). The grader is fully deterministic and penalizes reckless
actions such as guessing the root cause without gathering evidence first.
"""
import json
from pathlib import Path
from pydantic import BaseModel, Field
from typing import Dict, Any, Tuple, Optional, Literal, List
# ---------------------------------------------------------------------------
# 1. PYDANTIC MODELS (strict schemas required by OpenEnv)
# ---------------------------------------------------------------------------
class Observation(BaseModel):
    """Per-turn view of the incident that is handed to the agent.

    Combines the static scenario description with the mutable episode
    state (health, budget, turn counters, investigation output so far).
    """

    # --- Scenario description -------------------------------------------
    incident_id: str = Field(..., description="Unique incident identifier")
    severity: str = Field(..., description="Incident severity level (P1/P2/P3)")
    initial_observation: str = Field(
        ..., description="Human-readable summary of what is happening"
    )
    active_alerts: List[str] = Field(..., description="List of active alert names")
    system_metrics: Dict[str, str] = Field(
        ..., description="Current system metric readings"
    )
    timeline: List[str] = Field(..., description="Recent event timeline")

    # --- Evidence collected so far (defaults to an empty mapping) ---------
    investigation_results: Dict[str, str] = Field(
        default_factory=dict,
        description="Results from investigation queries run so far",
    )

    # --- Episode progress; all values are validated by the bounds below ---
    system_health: float = Field(
        ..., ge=0.0, le=100.0, description="Current system health 0-100"
    )
    budget_spent: float = Field(
        ..., ge=0.0, description="Total budget consumed so far ($)"
    )
    turn_number: int = Field(..., ge=0, description="Current turn in this episode")
    turns_remaining: int = Field(
        ..., ge=0, description="Turns left before forced resolution"
    )

    # --- Menu of legal choices, keyed by Action field name ----------------
    available_actions: Dict[str, List[str]] = Field(
        ..., description="Available action choices for each action type"
    )
class Action(BaseModel):
    """The agent's decision for one turn.

    Every turn carries an ops move, an investigation query, a root-cause
    declaration (or "unknown"), and a non-empty free-text justification.
    Each choice field only accepts the literal values listed for it.
    """

    # Ops side: what to do right now about the live incident.
    containment_action: Literal[
        "scale_up_nodes", "rate_limit_all", "rollback_last_deploy", "do_nothing"
    ] = Field(..., description="Immediate ops action to keep the system online")

    # Investigation side: which diagnostic to run; "none" runs nothing.
    investigation_query: Literal[
        "analyze_ip_traffic",
        "query_db_locks",
        "check_commit_diffs",
        "check_service_mesh",
        "check_resource_utilization",
        "none",
    ] = Field(..., description="Query to run for root-cause investigation")

    # Verdict: a concrete cause, or "unknown" to keep investigating.
    declare_root_cause: Literal[
        "ddos_attack", "viral_traffic", "bad_code", "database_lock", "unknown"
    ] = Field(
        ...,
        description="Declare the root cause or 'unknown' to keep investigating",
    )

    # Required, at least one character long.
    justification: str = Field(
        ...,
        min_length=1,
        description="A short explanation for this decision, citing evidence gathered",
    )
class Reward(BaseModel):
    """Grading result attached to every step.

    ``total_score`` is validated to lie strictly between 0 and 1
    (``gt=0.0``, ``lt=1.0``); ``breakdown`` maps a category name to its
    numeric contribution or bookkeeping value.
    """

    total_score: float = Field(
        ..., gt=0.0, lt=1.0, description="Final score in (0, 1) exclusive"
    )
    breakdown: Dict[str, float] = Field(
        ..., description="Score breakdown by category"
    )
class State(BaseModel):
    """Lightweight summary of the episode currently loaded in the environment.

    Only ``task_name`` is required; every other field has a default.
    """

    task_name: str = Field(..., description="Current task difficulty level")
    incident_id: str = Field(default="", description="Current incident ID")
    turn_number: int = Field(default=0, description="Current turn")
    max_turns: int = Field(default=0, description="Maximum turns allowed")
    system_health: float = Field(default=100.0, description="Current system health")
    budget_spent: float = Field(default=0.0, description="Budget consumed")
    is_done: bool = Field(default=False, description="Whether the episode has ended")
# ---------------------------------------------------------------------------
# 2. COST & REWARD CONSTANTS
# ---------------------------------------------------------------------------
# Dollar cost charged to the episode budget each time a containment action
# is taken (stands in for real infrastructure spend).
CONTAINMENT_COSTS: Dict[str, float] = {
    "scale_up_nodes": 500.0,
    "rate_limit_all": 100.0,
    "rollback_last_deploy": 200.0,
    "do_nothing": 0.0,
}

# Change in system health (additive, applied once per turn) caused by each
# containment action.
CONTAINMENT_HEALTH_EFFECTS: Dict[str, float] = {
    "scale_up_nodes": 15.0,
    "rate_limit_all": 10.0,
    "rollback_last_deploy": 20.0,
    "do_nothing": -15.0,  # idling while the system is on fire makes it worse
}

# Flat dollar cost of running any single investigation query.
INVESTIGATION_COST: float = 50.0

# Weights of the five score components; together they must add up to 1.0.
W_ROOT_CAUSE: float = 0.40   # naming the correct root cause
W_CONTAINMENT: float = 0.25  # using the ideal containment action
W_EVIDENCE: float = 0.15     # collecting the required evidence before declaring
W_EFFICIENCY: float = 0.10   # spending little of the budget
W_HEALTH: float = 0.10       # keeping system health up

# Amounts subtracted from the weighted score.
PREMATURE_GUESS_PENALTY: float = 0.30  # root cause declared without the evidence
SYSTEM_CRASH_PENALTY: float = 0.50     # system health fell to 0

# Episode limits.
MAX_BUDGET: float = 5000.0  # spend at which the efficiency component reaches 0
MAX_TURNS: int = 6          # turns allowed per incident
# ---------------------------------------------------------------------------
# 3. ENVIRONMENT ENGINE
# ---------------------------------------------------------------------------
class CloudSREEnv:
    """
    Gymnasium-style environment for the Cloud SRE Arbiter.

    The agent loops through reset() -> step() -> step() -> ... until done.
    Each task (easy/medium/hard) contains one incident scenario.
    """

    def __init__(self, data_path: str = "data.json"):
        """Load the scenario dataset and start in an idle (finished) state.

        Args:
            data_path: Location of the JSON dataset. It is first resolved
                relative to the directory containing this module; if no file
                exists there, the path is used exactly as given.

        Raises:
            OSError: If the dataset file cannot be opened.
            json.JSONDecodeError: If the file does not contain valid JSON.
        """
        # Try loading from the same directory as this file first
        p = Path(__file__).parent / data_path
        if not p.exists():
            p = Path(data_path)
        with open(p, "r", encoding="utf-8") as f:
            self.dataset: Dict[str, list] = json.load(f)

        # Episode state
        self._task_name: str = ""
        self._case: Optional[dict] = None
        self._turn: int = 0
        self._budget: float = 0.0
        self._health: float = 50.0  # start at 50 — system is already degraded
        self._investigation_results: Dict[str, str] = {}
        self._evidence_gathered: List[str] = []
        self._containment_used: List[str] = []
        self._done: bool = True  # step() is refused until reset() is called

    # ------------------------------------------------------------------
    # PUBLIC API
    # ------------------------------------------------------------------
    def reset(self, task_name: str = "easy") -> Observation:
        """Start a new episode for the given task difficulty.

        Args:
            task_name: Key into the loaded dataset.

        Returns:
            The initial observation for the selected incident.

        Raises:
            ValueError: If ``task_name`` is not a dataset key, or if the
                dataset holds no incident case for it.
        """
        if task_name not in self.dataset:
            raise ValueError(
                f"Task '{task_name}' not found. Available: {list(self.dataset.keys())}"
            )
        cases = self.dataset[task_name]
        if not cases:
            # An empty list would otherwise surface as a bare IndexError.
            raise ValueError(f"Task '{task_name}' has no incident cases.")

        self._task_name = task_name
        self._case = cases[0]  # one case per difficulty
        self._turn = 0
        self._budget = 0.0
        self._health = 50.0  # system is already hurting
        self._investigation_results = {}
        self._evidence_gathered = []
        self._containment_used = []
        self._done = False
        return self._build_observation()

    def step(self, action: Action) -> Tuple[Optional[Observation], Reward, bool, Dict[str, Any]]:
        """
        Process one agent turn.

        The containment action is applied first (budget and health), then the
        investigation query (budget and evidence), then the end conditions
        are checked: a declared root cause, the turn limit, or health at 0.

        Args:
            action: The agent's decision for this turn.

        Returns:
            ``(observation, reward, done, info)``. When the episode ends the
            observation is ``None`` and the reward is the final grade;
            otherwise the reward is a placeholder with ``total_score=0.001``.

        Raises:
            RuntimeError: If no episode is active (never reset, or finished).
        """
        if self._done or self._case is None:
            raise RuntimeError("Episode is over. Call reset() to start a new one.")

        self._turn += 1
        ground_truth = self._case["ground_truth"]
        hidden_data = self._case["hidden_data"]
        info: Dict[str, Any] = {"justification": action.justification, "turn": self._turn}

        # --- A) PROCESS CONTAINMENT ---
        self._budget += CONTAINMENT_COSTS.get(action.containment_action, 0.0)
        health_delta = CONTAINMENT_HEALTH_EFFECTS.get(action.containment_action, 0.0)
        self._health = max(0.0, min(100.0, self._health + health_delta))
        # "do_nothing" is deliberately not recorded as a containment action.
        if action.containment_action != "do_nothing":
            self._containment_used.append(action.containment_action)

        # --- B) PROCESS INVESTIGATION ---
        query = action.investigation_query
        if query != "none":
            self._budget += INVESTIGATION_COST  # charged even for repeated queries
            self._investigation_results[query] = hidden_data.get(
                query, "Query returned no anomalies."
            )
            if query not in self._evidence_gathered:
                self._evidence_gathered.append(query)

        # --- C) CHECK END CONDITIONS ---
        declared = action.declare_root_cause != "unknown"
        timed_out = self._turn >= MAX_TURNS
        system_crashed = self._health <= 0.0
        if declared or timed_out or system_crashed:
            self._done = True
            reward = self._grade(action, ground_truth, timed_out, system_crashed)
            info["grading_detail"] = reward.breakdown
            return None, reward, True, info

        # --- D) CONTINUE INVESTIGATING ---
        # Natural health decay each turn (the incident is ongoing).
        # NOTE(review): health can reach 0 here without ending the episode;
        # the crash check above only runs on the next step, after that step's
        # containment effect has been applied. Confirm this is intended.
        self._health = max(0.0, self._health - 5.0)
        reward = Reward(
            total_score=0.001,  # smallest score the grader ever emits
            breakdown={
                "status": 0.0,
                "message_investigating": 0.0,
                "budget_spent": self._budget,
                "system_health": self._health,
            },
        )
        return self._build_observation(), reward, False, info

    def get_state(self) -> State:
        """Return metadata about the current episode.

        Before the first reset() the task name is reported as "none" and the
        incident id as an empty string.
        """
        return State(
            task_name=self._task_name or "none",
            incident_id=self._case["incident_id"] if self._case else "",
            turn_number=self._turn,
            max_turns=MAX_TURNS,
            system_health=self._health,
            budget_spent=self._budget,
            is_done=self._done,
        )

    # ------------------------------------------------------------------
    # DETERMINISTIC GRADER
    # ------------------------------------------------------------------
    def _grade(
        self,
        action: Action,
        ground_truth: dict,
        timed_out: bool,
        system_crashed: bool,
    ) -> Reward:
        """
        Score the agent's performance at the end of an episode.

        Scoring breakdown:
        - Root cause identification (40%)
        - Containment quality (25%)
        - Evidence gathering (15%)
        - Budget efficiency (10%)
        - System health maintenance (10%)

        Penalties:
        - Root cause declared with required evidence missing -> -0.30
        - System crash (health at 0) -> -0.50

        Args:
            action: The final action of the episode.
            ground_truth: The case's answer key; ``root_cause`` and
                ``ideal_containment`` are read, ``required_evidence`` is
                optional.
            timed_out: Whether the turn limit was reached. Accepted for
                interface compatibility; it does not change the score.
            system_crashed: Whether system health was at 0 when the episode
                ended.

        Returns:
            A Reward whose ``total_score`` is the weighted sum minus
            penalties, rounded to 4 decimals and clamped to [0.001, 0.999].
            The breakdown holds each component, any penalties (as negative
            values), and the bookkeeping keys ``budget_spent``,
            ``final_health`` and ``turns_used``.
        """
        breakdown: Dict[str, float] = {}

        # 1. Root cause (40%) — all or nothing; a wrong guess and never
        #    guessing ("unknown") both score 0.
        root_cause_correct = action.declare_root_cause == ground_truth["root_cause"]
        breakdown["root_cause"] = W_ROOT_CAUSE if root_cause_correct else 0.0

        # 2. Containment (25%) — credit if the ideal action was used at any
        #    point. The final action is checked explicitly because
        #    "do_nothing" is never recorded in _containment_used.
        ideal = ground_truth["ideal_containment"]
        used_ideal = ideal in self._containment_used or action.containment_action == ideal
        breakdown["containment"] = W_CONTAINMENT if used_ideal else 0.0

        # 3. Evidence (15%) — did the agent gather the required evidence?
        required = set(ground_truth.get("required_evidence", []))
        gathered = set(self._evidence_gathered)
        if not required:
            breakdown["evidence"] = W_EVIDENCE  # no evidence required
        elif required.issubset(gathered):
            breakdown["evidence"] = W_EVIDENCE
        else:
            # Partial credit for gathering some evidence
            overlap = len(required & gathered) / len(required)
            breakdown["evidence"] = round(W_EVIDENCE * overlap, 4)

        # 4. Budget efficiency (10%) — falls linearly to 0 at MAX_BUDGET.
        if self._budget <= 0:
            breakdown["efficiency"] = W_EFFICIENCY
        else:
            breakdown["efficiency"] = round(
                max(0.0, W_EFFICIENCY * (1.0 - self._budget / MAX_BUDGET)), 4
            )

        # 5. System health (10%) — proportional to final health.
        breakdown["health"] = round(W_HEALTH * (self._health / 100.0), 4)

        # --- Penalties ---
        penalty = 0.0

        # Premature guess: the agent declared a root cause while at least one
        # piece of required evidence was still missing. Partial evidence
        # still triggers the penalty (on top of the partial credit above).
        if (
            action.declare_root_cause != "unknown"
            and required
            and not required.issubset(gathered)
        ):
            penalty += PREMATURE_GUESS_PENALTY
            breakdown["penalty_premature_guess"] = -PREMATURE_GUESS_PENALTY

        # System crash penalty
        if system_crashed:
            penalty += SYSTEM_CRASH_PENALTY
            breakdown["penalty_system_crash"] = -SYSTEM_CRASH_PENALTY

        # Penalties are stored as negative breakdown entries for reporting
        # only; they are excluded here and subtracted once via `penalty`.
        raw = sum(v for k, v in breakdown.items() if not k.startswith("penalty_"))
        total = max(0.001, min(0.999, round(raw - penalty, 4)))

        # Bookkeeping values are added after the sum so they never affect it.
        breakdown["budget_spent"] = self._budget
        breakdown["final_health"] = self._health
        breakdown["turns_used"] = float(self._turn)
        return Reward(total_score=total, breakdown=breakdown)

    # ------------------------------------------------------------------
    # HELPERS
    # ------------------------------------------------------------------
    def _build_observation(self) -> Observation:
        """Build an Observation from the current case + internal state.

        Raises:
            RuntimeError: If no case is loaded.
            KeyError: If the case lacks ``incident_id``, ``active_alerts``
                or ``system_metrics``.
        """
        case = self._case
        if case is None:
            raise RuntimeError("No active case — call reset() first.")
        return Observation(
            incident_id=case["incident_id"],
            severity=case.get("severity", "P1"),
            initial_observation=case.get("initial_observation", ""),
            active_alerts=case["active_alerts"],
            system_metrics=case["system_metrics"],
            timeline=case.get("timeline", []),
            # Copy so the observation does not alias the mutable internal dict.
            investigation_results=dict(self._investigation_results),
            system_health=round(self._health, 2),
            budget_spent=round(self._budget, 2),
            turn_number=self._turn,
            turns_remaining=MAX_TURNS - self._turn,
            available_actions={
                "containment_action": [
                    "scale_up_nodes", "rate_limit_all",
                    "rollback_last_deploy", "do_nothing",
                ],
                "investigation_query": [
                    "analyze_ip_traffic", "query_db_locks",
                    "check_commit_diffs", "check_service_mesh",
                    "check_resource_utilization", "none",
                ],
                "declare_root_cause": [
                    "ddos_attack", "viral_traffic",
                    "bad_code", "database_lock", "unknown",
                ],
            },
        )