"""
Cloud SRE Arbiter — Core Environment Engine
============================================
Implements a multi-step, RL-style environment where an AI agent must
simultaneously contain live incidents (Ops) and investigate root causes
(Sec/Data). The grader is fully deterministic and penalizes reckless
actions such as guessing the root cause without gathering evidence first.
"""
|
|
| import json |
| from pathlib import Path |
| from pydantic import BaseModel, Field |
| from typing import Dict, Any, Tuple, Optional, Literal, List |
|
|
|
|
| |
| |
| |
|
|
class Observation(BaseModel):
    """What the agent sees each turn.

    A fresh snapshot is built after every non-terminal step; there is no
    observation on the terminal step (the environment returns None then).
    """
    incident_id: str = Field(..., description="Unique incident identifier")
    severity: str = Field(..., description="Incident severity level (P1/P2/P3)")
    initial_observation: str = Field(..., description="Human-readable summary of what is happening")
    active_alerts: List[str] = Field(..., description="List of active alert names")
    system_metrics: Dict[str, str] = Field(..., description="Current system metric readings")
    timeline: List[str] = Field(..., description="Recent event timeline")
    # Accumulates across turns: one entry per distinct investigation query run.
    investigation_results: Dict[str, str] = Field(
        default_factory=dict,
        description="Results from investigation queries run so far"
    )
    system_health: float = Field(..., ge=0.0, le=100.0, description="Current system health 0-100")
    budget_spent: float = Field(..., ge=0.0, description="Total budget consumed so far ($)")
    turn_number: int = Field(..., ge=0, description="Current turn in this episode")
    turns_remaining: int = Field(..., ge=0, description="Turns left before forced resolution")
    # Maps action-type name (e.g. "containment_action") -> list of legal choices.
    available_actions: Dict[str, List[str]] = Field(
        ..., description="Available action choices for each action type"
    )
|
|
|
|
class Action(BaseModel):
    """The agent's two-pronged decision each turn.

    All three choice fields must be supplied every turn; "do_nothing",
    "none", and "unknown" are the explicit no-op options for each prong.
    Declaring anything other than "unknown" immediately ends the episode.
    """
    # Ops prong: immediate mitigation. Costs and health effects are defined
    # by the CONTAINMENT_COSTS / CONTAINMENT_HEALTH_EFFECTS tables.
    containment_action: Literal[
        "scale_up_nodes",
        "rate_limit_all",
        "rollback_last_deploy",
        "do_nothing"
    ] = Field(..., description="Immediate ops action to keep the system online")


    # Investigation prong: each non-"none" query costs INVESTIGATION_COST
    # and reveals the scenario's hidden data for that query, if any.
    investigation_query: Literal[
        "analyze_ip_traffic",
        "query_db_locks",
        "check_commit_diffs",
        "check_service_mesh",
        "check_resource_utilization",
        "none"
    ] = Field(..., description="Query to run for root-cause investigation")


    # Terminal declaration: any value other than "unknown" ends the episode
    # and triggers grading.
    declare_root_cause: Literal[
        "ddos_attack",
        "viral_traffic",
        "bad_code",
        "database_lock",
        "unknown"
    ] = Field(..., description="Declare the root cause or 'unknown' to keep investigating")


    # Free text; passed through to the step() info dict, not graded.
    justification: str = Field(
        ...,
        min_length=1,
        description="A short explanation for this decision, citing evidence gathered"
    )
|
|
|
|
class Reward(BaseModel):
    """Deterministic grading result returned after each step.

    Non-terminal steps return a placeholder score of 0.001; the bounds are
    exclusive, so 0.0 and 1.0 are never valid scores.
    """
    total_score: float = Field(..., gt=0.0, lt=1.0, description="Final score in (0, 1) exclusive")
    # Besides weighted category scores and negative penalty entries, the
    # grader also appends raw diagnostics (budget_spent, final_health,
    # turns_used) that are not part of the score.
    breakdown: Dict[str, float] = Field(..., description="Score breakdown by category")
|
|
|
|
class State(BaseModel):
    """Metadata about the current episode (for inspection/logging only)."""
    task_name: str = Field(..., description="Current task difficulty level")
    incident_id: str = Field("", description="Current incident ID")
    turn_number: int = Field(0, description="Current turn")
    max_turns: int = Field(0, description="Maximum turns allowed")
    system_health: float = Field(100.0, description="Current system health")
    budget_spent: float = Field(0.0, description="Budget consumed")
    is_done: bool = Field(False, description="Whether the episode has ended")
|
|
|
|
| |
| |
| |
|
|
| |
# Dollar cost charged per containment action, added to the episode budget.
CONTAINMENT_COSTS = {
    "scale_up_nodes": 500.0,
    "rate_limit_all": 100.0,
    "rollback_last_deploy": 200.0,
    "do_nothing": 0.0,
}


# Health delta applied immediately when the containment action runs
# (health is clamped to [0, 100]). Doing nothing actively hurts health.
CONTAINMENT_HEALTH_EFFECTS = {
    "scale_up_nodes": +15.0,
    "rate_limit_all": +10.0,
    "rollback_last_deploy": +20.0,
    "do_nothing": -15.0,
}


# Flat fee charged for any non-"none" investigation query.
INVESTIGATION_COST = 50.0


# Grading weights; they sum to 1.0 (root cause, containment, evidence,
# efficiency, health).
W_ROOT_CAUSE = 0.40
W_CONTAINMENT = 0.25
W_EVIDENCE = 0.15
W_EFFICIENCY = 0.10
W_HEALTH = 0.10


# Penalties subtracted from the raw score, and hard episode limits.
PREMATURE_GUESS_PENALTY = 0.30
SYSTEM_CRASH_PENALTY = 0.50
MAX_BUDGET = 5000.0
MAX_TURNS = 6
|
|
|
|
| |
| |
| |
|
|
class CloudSREEnv:
    """
    Gymnasium-style environment for the Cloud SRE Arbiter.

    The agent loops through reset() -> step() -> step() -> ... until done.
    Each task (easy/medium/hard) contains one incident scenario.
    """

    def __init__(self, data_path: str = "data.json"):
        """Load the scenario dataset.

        Args:
            data_path: Dataset filename. Resolved next to this module first,
                then relative to the current working directory.

        Raises:
            FileNotFoundError: If the dataset exists in neither location.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        p = Path(__file__).parent / data_path
        if not p.exists():
            p = Path(data_path)
        with open(p, "r", encoding="utf-8") as f:
            self.dataset: Dict[str, list] = json.load(f)

        # Per-episode state, reinitialized by reset(). _done starts True so
        # that step() before the first reset() raises instead of operating
        # on stale/absent scenario data.
        self._task_name: str = ""
        self._case: Optional[dict] = None
        self._turn: int = 0
        self._budget: float = 0.0
        self._health: float = 50.0
        self._investigation_results: Dict[str, str] = {}
        self._evidence_gathered: List[str] = []
        self._containment_used: List[str] = []
        self._done: bool = True

    def reset(self, task_name: str = "easy") -> Observation:
        """Start a new episode for the given task difficulty.

        Args:
            task_name: Key into the loaded dataset (e.g. "easy").

        Returns:
            The initial Observation for the scenario.

        Raises:
            ValueError: If `task_name` is not present in the dataset.
        """
        if task_name not in self.dataset:
            raise ValueError(
                f"Task '{task_name}' not found. Available: {list(self.dataset.keys())}"
            )

        self._task_name = task_name
        # Each task currently holds a single scenario; use the first entry.
        self._case = self.dataset[task_name][0]
        self._turn = 0
        self._budget = 0.0
        self._health = 50.0
        self._investigation_results = {}
        self._evidence_gathered = []
        self._containment_used = []
        self._done = False

        return self._build_observation()

    def step(self, action: Action) -> Tuple[Optional[Observation], Reward, bool, Dict[str, Any]]:
        """
        Process one agent turn.

        Applies the containment action (cost + health effect), runs the
        investigation query (if any), then either terminates the episode
        (root cause declared, turn limit reached, or system crashed) and
        grades it, or returns the next observation with a placeholder
        reward.

        Returns:
            (observation, reward, done, info). `observation` is None on the
            terminal step; `info` always carries the agent's justification
            and, on terminal steps, the grading breakdown.

        Raises:
            RuntimeError: If called after the episode has ended.
        """
        if self._done or self._case is None:
            raise RuntimeError("Episode is over. Call reset() to start a new one.")

        self._turn += 1
        ground_truth = self._case["ground_truth"]
        hidden_data = self._case["hidden_data"]
        info: Dict[str, Any] = {"justification": action.justification, "turn": self._turn}

        # --- Containment: pay the cost, apply the health effect. ---
        cost = CONTAINMENT_COSTS.get(action.containment_action, 0.0)
        self._budget += cost
        health_delta = CONTAINMENT_HEALTH_EFFECTS.get(action.containment_action, 0.0)
        self._health = max(0.0, min(100.0, self._health + health_delta))
        if action.containment_action != "do_nothing":
            self._containment_used.append(action.containment_action)

        # --- Investigation: charge the flat fee and reveal hidden data. ---
        if action.investigation_query != "none":
            self._budget += INVESTIGATION_COST
            query = action.investigation_query
            if query in hidden_data:
                self._investigation_results[query] = hidden_data[query]
            else:
                self._investigation_results[query] = "Query returned no anomalies."
            if query not in self._evidence_gathered:
                self._evidence_gathered.append(query)

        # --- Termination checks (after this turn's effects). ---
        declared = action.declare_root_cause != "unknown"
        timed_out = self._turn >= MAX_TURNS
        system_crashed = self._health <= 0.0

        if declared or timed_out or system_crashed:
            self._done = True
            reward = self._grade(action, ground_truth, timed_out, system_crashed)
            info["grading_detail"] = reward.breakdown
            return None, reward, True, info

        # Passive decay: an unresolved incident erodes health every turn.
        self._health = max(0.0, self._health - 5.0)

        # Placeholder reward for non-terminal turns; Reward requires a score
        # strictly inside (0, 1), hence 0.001 rather than 0.
        reward = Reward(
            total_score=0.001,
            breakdown={
                "status": 0.0,
                "message_investigating": 0.0,
                "budget_spent": self._budget,
                "system_health": self._health,
            },
        )
        return self._build_observation(), reward, False, info

    def get_state(self) -> State:
        """Return metadata about the current episode."""
        return State(
            task_name=self._task_name or "none",
            incident_id=self._case["incident_id"] if self._case else "",
            turn_number=self._turn,
            max_turns=MAX_TURNS,
            system_health=self._health,
            budget_spent=self._budget,
            is_done=self._done,
        )

    def _grade(
        self,
        action: Action,
        ground_truth: dict,
        timed_out: bool,
        system_crashed: bool,
    ) -> Reward:
        """
        Score the agent's performance.

        The raw weighted score minus penalties is clamped into
        [0.001, 0.999] to satisfy Reward's exclusive (0, 1) bounds.

        Scoring breakdown:
          - Root cause identification (40%)
          - Containment quality (25%)
          - Evidence gathering (15%)
          - Budget efficiency (10%)
          - System health maintenance (10%)

        Penalties:
          - Premature guess (missing required evidence) -> -0.30
          - System crash (health reached 0)             -> -0.50
        """
        breakdown: Dict[str, float] = {}

        # Root cause: all-or-nothing. (An "unknown" declaration at timeout
        # scores 0 here like any other miss.)
        if action.declare_root_cause == ground_truth["root_cause"]:
            breakdown["root_cause"] = W_ROOT_CAUSE
        else:
            breakdown["root_cause"] = 0.0

        # Containment: full credit if the ideal action was used on any
        # earlier turn or on this final one.
        if (
            ground_truth["ideal_containment"] in self._containment_used
            or action.containment_action == ground_truth["ideal_containment"]
        ):
            breakdown["containment"] = W_CONTAINMENT
        else:
            breakdown["containment"] = 0.0

        # Evidence: full credit when every required query was run;
        # otherwise pro-rated by coverage. No requirements => full credit.
        required = set(ground_truth.get("required_evidence", []))
        gathered = set(self._evidence_gathered)
        if not required or required.issubset(gathered):
            breakdown["evidence"] = W_EVIDENCE
        else:
            overlap = len(required & gathered) / len(required)
            breakdown["evidence"] = round(W_EVIDENCE * overlap, 4)

        # Efficiency: linear falloff as spend approaches MAX_BUDGET.
        if self._budget <= 0:
            breakdown["efficiency"] = W_EFFICIENCY
        else:
            breakdown["efficiency"] = round(
                max(0.0, W_EFFICIENCY * (1.0 - self._budget / MAX_BUDGET)), 4
            )

        # Health: proportional to final system health.
        breakdown["health"] = round(W_HEALTH * (self._health / 100.0), 4)

        penalty = 0.0

        # Declaring a root cause without the required evidence is the
        # reckless behavior this environment exists to punish.
        if (
            action.declare_root_cause != "unknown"
            and required
            and not required.issubset(gathered)
        ):
            penalty += PREMATURE_GUESS_PENALTY
            breakdown["penalty_premature_guess"] = -PREMATURE_GUESS_PENALTY

        if system_crashed:
            penalty += SYSTEM_CRASH_PENALTY
            breakdown["penalty_system_crash"] = -SYSTEM_CRASH_PENALTY

        # penalty_* entries are informational negatives; only the positive
        # category entries contribute to the raw sum.
        raw = sum(v for k, v in breakdown.items() if not k.startswith("penalty_"))
        total = max(0.001, min(0.999, round(raw - penalty, 4)))

        # Diagnostics appended after scoring; not part of the score.
        breakdown["budget_spent"] = self._budget
        breakdown["final_health"] = self._health
        breakdown["turns_used"] = float(self._turn)

        return Reward(total_score=total, breakdown=breakdown)

    def _build_observation(self) -> Observation:
        """Build an Observation from the current case + internal state."""
        case = self._case
        if case is None:
            raise RuntimeError("No active case — call reset() first.")

        return Observation(
            incident_id=case["incident_id"],
            severity=case.get("severity", "P1"),
            initial_observation=case.get("initial_observation", ""),
            active_alerts=case["active_alerts"],
            system_metrics=case["system_metrics"],
            timeline=case.get("timeline", []),
            investigation_results=dict(self._investigation_results),
            system_health=round(self._health, 2),
            budget_spent=round(self._budget, 2),
            turn_number=self._turn,
            turns_remaining=MAX_TURNS - self._turn,
            available_actions={
                "containment_action": [
                    "scale_up_nodes", "rate_limit_all",
                    "rollback_last_deploy", "do_nothing",
                ],
                "investigation_query": [
                    "analyze_ip_traffic", "query_db_locks",
                    "check_commit_diffs", "check_service_mesh",
                    "check_resource_utilization", "none",
                ],
                "declare_root_cause": [
                    "ddos_attack", "viral_traffic",
                    "bad_code", "database_lock", "unknown",
                ],
            },
        )
|
|