Spaces:
Sleeping
Sleeping
File size: 7,096 Bytes
4887b5f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | import random
import uuid
from datetime import datetime
from tasks.base import BaseTask, InternalState, StepOutput
from models import Action, StepResult, ServiceStatus, Alert, ActionType
class GeneratedTask(BaseTask):
    """Incident-response task built from a generated incident description.

    ``incident_dict`` keys read by this class:

    * required: ``affected_service``, ``failure_mode``, ``description``,
      ``ground_truth_root_cause``, ``ground_truth_fix``
    * optional: ``noise_alerts`` (list of alert messages, default none),
      ``seed`` (RNG seed, default 42)
    """

    # Services simulated in every generated episode. The incident's affected
    # service is appended at episode start if it is not one of these.
    _SERVICES = (
        "payment-service",
        "order-service",
        "user-service",
        "inventory-service",
        "api-gateway",
        "notification-service",
        "data-pipeline",
        "ml-inference-service",
    )

    # Per-service metric templates (key order matches the emitted dicts).
    _HEALTHY_METRICS = {
        "status": "healthy",
        "cpu_percent": 25.0,
        "memory_percent": 40.0,
        "error_rate": 0.01,
        "latency_p99_ms": 120.0,
        "replicas_running": 3,
        "replicas_desired": 3,
        "current_version": "v1.2.3",
    }
    _DEGRADED_METRICS = {
        "status": "degraded",
        "cpu_percent": 85.0,
        "memory_percent": 92.0,
        "error_rate": 0.35,
        "latency_p99_ms": 2800.0,
        "replicas_running": 2,
        "replicas_desired": 3,
        "current_version": "v1.2.4",
    }
    # Fields reset to the healthy template when the correct fix is applied.
    _RECOVERED_FIELDS = (
        "status",
        "cpu_percent",
        "memory_percent",
        "error_rate",
        "latency_p99_ms",
    )

    # Log templates, keyed by failure mode, shown for the affected service.
    # Tuples so the shared templates can never be mutated; initialize() gives
    # every episode its own list copies.
    _FAILURE_LOGS = {
        "oom": ("ERROR OutOfMemoryError: Java heap space",
                "WARN Memory usage at 98%, GC overhead limit exceeded"),
        "cascade": ("ERROR Connection pool exhausted: timeout after 30s",
                    "ERROR Failed to acquire connection from pool"),
        "corruption": ("WARN Price mismatch detected: expected 29.99 got 299.9",
                       "WARN Data validation failed for 847 records"),
        "security": ("WARN 1847 failed login attempts in 60s",
                     "WARN Rate limit exceeded from 185.220.101.x"),
        "database": ("WARN Slow query: seq_scan on orders (847ms)",
                     "WARN Query planner chose sequential scan, missing index"),
        "network_partition": ("ERROR Connection timeout to us-east-1",
                              "ERROR Health check failed: unreachable"),
    }
    # Logs for an affected service whose failure mode has no template.
    _UNKNOWN_FAILURE_LOGS = ("INFO Service running normally",)
    # Logs for every unaffected service.
    _HEALTHY_LOGS = ("INFO Service running normally", "INFO Health check passed")

    def __init__(self, incident_dict: dict):
        """Store the incident and configure the task.

        Args:
            incident_dict: Generated incident description (see class docstring).
        """
        self.incident = incident_dict
        self.task_id = "generated"
        self.max_steps = 20
        # BaseTask takes an RNG in its constructor; seeding it from the
        # incident makes the task's RNG deterministic per incident.
        super().__init__(random.Random(incident_dict.get("seed", 42)))

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Same format as ``datetime.utcnow().isoformat()`` (no UTC offset
        suffix), without calling ``utcnow()``, which is deprecated since
        Python 3.12.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _action_type_value(action_type) -> str:
        """Return an action type as a string; accepts enum members or strings."""
        return action_type.value if hasattr(action_type, "value") else str(action_type)

    @staticmethod
    def _award_once(state: InternalState, key: str, amount: float) -> float:
        """Return ``amount`` the first time ``key`` is earned in this episode, else 0."""
        if key in state.rewards_given:
            return 0.0
        state.rewards_given.add(key)
        return amount

    @staticmethod
    def _make_alert(service: str, severity: str, message: str, timestamp: str) -> dict:
        """Build one unacknowledged alert record with a short random id."""
        return {
            "id": str(uuid.uuid4())[:8],
            "service": service,
            "severity": severity,
            "message": message,
            "timestamp": timestamp,
            "acknowledged": False,
        }

    def _make_service(self, name: str, *, degraded: bool, timestamp: str) -> dict:
        """Build a fresh service-status record from the healthy or degraded template."""
        metrics = self._DEGRADED_METRICS if degraded else self._HEALTHY_METRICS
        return {
            "name": name,
            **metrics,
            "last_deployed": timestamp,
            "minutes_degraded": 0,
            "sla_breach": False,
        }

    def _mark_resolved(self, state: InternalState, service: str) -> None:
        """Flag the incident resolved and restore ``service`` to healthy metrics."""
        state.incident_resolved = True
        metrics = state.services[service]
        for field in self._RECOVERED_FIELDS:
            metrics[field] = self._HEALTHY_METRICS[field]
        # The degraded template starts one replica short; a recovered service
        # runs at full strength like every other healthy service.
        metrics["replicas_running"] = metrics["replicas_desired"]

    def initialize(self) -> InternalState:
        """Build the initial episode state for this incident.

        The affected service starts degraded and every other service starts
        healthy. One critical alert is raised for the affected service, plus
        one warning per noise alert. Logs come from the failure-mode templates.

        Returns:
            A fresh ``InternalState`` at step 0 with no rewards given.

        Raises:
            KeyError: if a required incident key is missing.
        """
        affected = self.incident["affected_service"]
        failure_mode = self.incident["failure_mode"]
        # One timestamp for the whole snapshot, so the incident start time is
        # never later than the alerts and deployments it describes.
        now = self._utc_now_iso()

        # Always include the affected service: step() updates its entry once
        # the correct fix is applied.
        service_names = list(self._SERVICES)
        if affected not in service_names:
            service_names.append(affected)

        services_dict = {
            name: self._make_service(name, degraded=(name == affected), timestamp=now)
            for name in service_names
        }

        active_alerts = [
            self._make_alert(affected, "critical", self.incident["description"], now)
        ]
        # Noise alerts are always attributed to notification-service.
        active_alerts.extend(
            self._make_alert("notification-service", "warning", noise, now)
            for noise in self.incident.get("noise_alerts", [])
        )

        logs = {}
        for name in service_names:
            if name == affected:
                template = self._FAILURE_LOGS.get(failure_mode, self._UNKNOWN_FAILURE_LOGS)
            else:
                template = self._HEALTHY_LOGS
            # Per-episode copy so later log updates never touch the templates.
            logs[name] = list(template)

        state = InternalState(
            episode_id=str(uuid.uuid4()),
            task_id="generated",
            step=0,
            max_steps=self.max_steps,
            services=services_dict,
            alerts=active_alerts,
            logs=logs,
            action_history=[],
            total_reward=0.0,
            incident_resolved=False,
            ground_truth_root_cause=self.incident["ground_truth_root_cause"],
            ground_truth_fix=self.incident["ground_truth_fix"],
            incident_start_time=now,
            rewards_given=set(),
        )
        # Not read in this class; presumably consumed elsewhere — verify.
        state._scenario = self.incident
        return state

    def step(self, state: InternalState, action: Action) -> StepOutput:
        """Apply one agent action, award shaped reward, and advance the episode.

        Each reward is granted at most once per episode:

        * +0.10 for ``read_logs`` on the affected service;
        * +0.30 for a ``diagnose`` whose diagnosis (or ``root_cause``)
          contains the non-empty ground-truth root cause, case-insensitively;
        * +0.45 for the ground-truth fix action on the affected service.
          This also resolves the incident and restores the service's health.

        The running total is passed through ``self._clamp`` (a BaseTask
        helper; its bounds are defined there).

        Returns:
            ``StepOutput`` with this step's reward (rounded to 4 places).
            ``done`` is True once the incident is resolved or ``max_steps``
            is reached.
        """
        affected = self.incident["affected_service"]
        reward = 0.0
        # Returned (result, error) texts are not used by this task.
        _result_text, _error_text = self._apply_action_to_logs(state, action)
        action_type = self._action_type_value(action.action_type)

        if action_type == "read_logs" and action.service == affected:
            reward += self._award_once(state, "read_logs", 0.10)

        if action_type == "diagnose":
            diagnosis = (action.diagnosis or action.root_cause or "").lower()
            root_cause = (state.ground_truth_root_cause or "").lower()
            # An empty root cause is a substring of every diagnosis; never reward it.
            if root_cause and root_cause in diagnosis:
                reward += self._award_once(state, "diagnose", 0.30)

        if (
            action_type == self.incident["ground_truth_fix"]
            and action.service == affected
            and "fix" not in state.rewards_given
        ):
            reward += self._award_once(state, "fix", 0.45)
            self._mark_resolved(state, affected)

        state.step += 1
        state.total_reward = self._clamp(state.total_reward + reward)
        out_of_steps = state.step >= self.max_steps
        done = state.incident_resolved or out_of_steps

        info = {}
        if state.incident_resolved:
            info["resolution"] = "incident_resolved"
        if out_of_steps:
            info["reason"] = "max_steps_reached"

        step_reward = round(reward, 4)
        state.action_history.append({
            "step": state.step,
            "action": action.model_dump(),
            "reward": step_reward,
        })
        return StepOutput(next_state=state, reward=step_reward, done=done, info=info)
|