Yaswanth-Bolla's picture
Initial commit
1175c0b
"""
Typed models for the SRE Incident Response environment.
Action space: hierarchical — select action_type first, then target + params.
Observation space: POMDP — agent never sees fault_type, only symptoms.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
# ---------------------------------------------------------------------------
# Action Space (Layer 4 — Hierarchical + Masked)
# ---------------------------------------------------------------------------
class ActionType(str, Enum):
    """Level-1 action categories — what kind of operation.

    Members also subclass ``str``, so each member compares equal to its raw
    string value, and ``ActionType("query_logs")`` looks a member up by value
    (raising ``ValueError`` for an unknown value).
    """

    # Information-gathering operations.
    VIEW_ALERTS = "view_alerts"
    QUERY_LOGS = "query_logs"
    CHECK_METRICS = "check_metrics"
    CHECK_DEPENDENCIES = "check_dependencies"
    CHECK_DEPLOY_HISTORY = "check_deploy_history"
    RUN_HEALTH_CHECK = "run_health_check"

    # Operations applied to a service.
    RESTART_SERVICE = "restart_service"
    ROLLBACK_DEPLOY = "rollback_deploy"
    SCALE_SERVICE = "scale_service"

    # Declaring the suspected root cause.
    DECLARE_ROOT_CAUSE = "declare_root_cause"
# Actions that require a target_service (Level 2 — where to apply).
# Every member except VIEW_ALERTS and DECLARE_ROOT_CAUSE is listed here.
TARGETED_ACTIONS = {
    # read-only lookups against one service
    ActionType.QUERY_LOGS,
    ActionType.CHECK_METRICS,
    ActionType.CHECK_DEPENDENCIES,
    ActionType.CHECK_DEPLOY_HISTORY,
    ActionType.RUN_HEALTH_CHECK,
    # operations applied to one service
    ActionType.RESTART_SERVICE,
    ActionType.ROLLBACK_DEPLOY,
    ActionType.SCALE_SERVICE,
}
# Actions that are diagnostic: information-gathering only, no state mutation.
DIAGNOSTIC_ACTIONS = {
    ActionType.VIEW_ALERTS,  # the only diagnostic action not in TARGETED_ACTIONS
    ActionType.QUERY_LOGS,
    ActionType.CHECK_METRICS,
    ActionType.CHECK_DEPENDENCIES,
    ActionType.CHECK_DEPLOY_HISTORY,
    ActionType.RUN_HEALTH_CHECK,
}
# Actions that mutate infrastructure state (as opposed to diagnostic ones).
REMEDIATION_ACTIONS = {
    ActionType.RESTART_SERVICE,
    ActionType.ROLLBACK_DEPLOY,
    ActionType.SCALE_SERVICE,
}
@dataclass
class IncidentAction:
    """
    Agent action — hierarchical: action_type → target_service → parameters.

    The LLM emits JSON with these three fields. The action mask in the
    observation tells it which (action_type, target_service) pairs are legal.
    """

    action_type: str  # ActionType value
    target_service: Optional[str] = None  # Required for TARGETED_ACTIONS
    # default_factory gives each instance its own dict instead of a shared one.
    parameters: Dict[str, Any] = field(default_factory=dict)

    def parsed_type(self) -> ActionType:
        """Return ``action_type`` converted to its ``ActionType`` member.

        Raises:
            ValueError: If ``action_type`` is not the value of any
                ``ActionType`` member.
        """
        return ActionType(self.action_type)
# ---------------------------------------------------------------------------
# Observation Space (Layer 2 — POMDP, partial views only)
# ---------------------------------------------------------------------------
@dataclass
class AlertInfo:
    """A single firing alert — what fired, not why.

    All fields are required plain strings; no validation is performed on
    construction.
    """

    alert_id: str
    severity: str  # "critical" | "warning" | "info"
    source_service: str
    description: str
    firing_since: str  # ISO timestamp
@dataclass
class MetricSnapshot:
    """Time-series metrics for a single service — temporal pattern visible.

    Every metric is a list of floats; all fields are required.
    NOTE(review): the metric lists are presumably index-aligned with
    ``timestamps`` — nothing here enforces equal lengths; confirm with the
    producer.
    """

    service_name: str
    timestamps: List[str]
    cpu_percent: List[float]
    memory_percent: List[float]
    error_rate_percent: List[float]
    latency_p50_ms: List[float]
    latency_p95_ms: List[float]
    latency_p99_ms: List[float]
    requests_per_sec: List[float]
@dataclass
class LogEntry:
    """A single structured log entry — error semantics visible.

    ``trace_id`` and ``extra`` are optional; the other four fields are
    required.
    """

    timestamp: str
    level: str  # "DEBUG" | "INFO" | "WARN" | "ERROR" | "FATAL"
    service: str
    message: str
    trace_id: Optional[str] = None
    # default_factory gives each entry its own dict instead of a shared one.
    extra: Dict[str, Any] = field(default_factory=dict)
@dataclass
class DeployRecord:
    """A single deploy — evidence trail for rollback decisions.

    All fields are required plain strings.
    """

    version: str
    timestamp: str
    author: str
    commit_hash: str
    description: str
@dataclass
class DependencyInfo:
    """Upstream/downstream dependency map for a service.

    Both lists are required and hold service names.
    """

    service_name: str
    depends_on: List[str]  # services this one calls
    depended_by: List[str]  # services that call this one
@dataclass
class IncidentObservation:
    """
    What the agent sees after each step.

    This is a PARTIAL observation — the agent never sees fault_type,
    fault_target, or the internal simulation state. It must infer
    the root cause from the five observation modalities.

    The four incident-context fields are required. Every other field has a
    default; dict and list fields are created per instance through
    ``default_factory`` so instances never share mutable state.
    """

    # --- Incident context (always visible) ---
    incident_summary: str
    severity: str  # "SEV1" | "SEV2" | "SEV3"
    time_elapsed_minutes: int
    time_budget_minutes: int

    # --- Result of last action ---
    action_result: Dict[str, Any] = field(default_factory=dict)
    action_success: bool = True
    action_message: str = ""

    # --- Dashboard (always visible) ---
    # name → "healthy"|"degraded"|"down"
    service_statuses: Dict[str, str] = field(default_factory=dict)
    active_alerts_count: int = 0

    # --- Action mask (Layer 4 — prevents illegal actions) ---
    valid_actions: List[str] = field(default_factory=list)
    available_services: List[str] = field(default_factory=list)

    # --- Episode progress ---
    current_reward: float = 0.0
    cumulative_reward: float = 0.0
    steps_taken: int = 0
    max_steps: int = 20
    done: bool = False
# ---------------------------------------------------------------------------
# State (internal tracking — exposed via state() for debugging)
# ---------------------------------------------------------------------------
@dataclass
class IncidentState:
    """Episode metadata — returned by state() for monitoring/debugging.

    Every field has a default, so ``IncidentState()`` yields an empty,
    not-yet-started episode record.
    """

    episode_id: str = ""
    task_name: str = ""
    step_count: int = 0
    time_elapsed_minutes: int = 0
    done: bool = False
    cumulative_reward: float = 0.0
    declared_root_cause: Optional[str] = None  # None until a value is set
# ---------------------------------------------------------------------------
# Step record — stored per step for trajectory-based grading
# ---------------------------------------------------------------------------
@dataclass
class StepRecord:
    """
    Record of a single step — used by the grader.

    The grader receives List[StepRecord] and scores WITHOUT hidden state.

    Records are meant to be treated as immutable, but the dataclass is not
    declared frozen, so this is a convention rather than something enforced
    at runtime.
    """

    step_number: int
    action: IncidentAction
    reward: float
    observation_summary: Dict[str, Any]  # key fields from observation
    service_statuses_after: Dict[str, str]  # service health after this step
    timestamp_minutes: int  # simulation time