# Spaces file view (status: Sleeping) - file size: 11,196 bytes; revisions 8be69b1, 849b14a.
from __future__ import annotations
import uuid
from typing import Dict, Any, List
from models import Action, ActionType
from tasks.base import BaseTask, InternalState, StepOutput, semantic_match
# Scenario fixtures for the "security" incident task.

# Timestamp (ISO-8601, UTC) used as the incident start time.
INCIDENT_TIME = "2026-04-12T11:37:00Z"

# Service call graph: for each service, the services it calls and the
# services that call it.
DEPENDENCIES = [
    {
        "service": "api-gateway",
        "calls": ["auth-service", "user-service"],
        "called_by": [],
    },
    {
        "service": "auth-service",
        "calls": ["backend-db", "rate-limiter"],
        "called_by": ["api-gateway"],
    },
    {
        "service": "rate-limiter",
        "calls": [],
        "called_by": ["auth-service"],
    },
    {
        "service": "user-service",
        "calls": ["backend-db"],
        "called_by": ["api-gateway"],
    },
    {
        "service": "backend-db",
        "calls": [],
        "called_by": ["auth-service", "user-service"],
    },
]

# Log lines exposed for the api-gateway service.
API_LOGS = [
    "[11:37:00] INFO Traffic normal: 820 req/s",
    "[11:37:30] WARN Traffic spike: 2400 req/s - monitoring",
    "[11:38:00] WARN Traffic spike: 5800 req/s - alert fired",
    "[11:38:30] ERROR Traffic: 12000 req/s - rate limiter overwhelmed",
    "[11:38:45] ERROR 94.2% of requests from 185.x.x.x IP range",
    "[11:38:46] ERROR 99.8% of high-volume requests targeting POST /api/v1/login",
    "[11:38:47] WARN Dropping 78% of requests - circuit breaker opening",
    "[11:39:00] ERROR Connection pool to auth-service exhausted: 500/500 connections active",
    "[11:45:01] INFO GET /api/v1/products 200 12ms 203.0.113.42",
    "[11:45:01] WARN POST /api/v1/login 429 8ms 185.220.101.45 [rate-limited]",
    "[11:45:01] WARN POST /api/v1/login 429 8ms 185.220.101.46 [rate-limited]",
    "[11:45:02] INFO GET /api/v1/health 200 3ms 10.0.0.1",
]

# Log lines exposed for the auth-service.
AUTH_LOGS = [
    "[11:37:45] INFO Login attempt: user_id=NULL ip=185.220.101.45 (failed - no such user)",
    "[11:37:45] INFO Login attempt: user_id=NULL ip=185.220.101.46 (failed - no such user)",
    "[11:38:00] WARN 98% of login attempts are credential stuffing pattern (NULL user_ids)",
    "[11:38:30] ERROR Thread pool saturation: 498/500 threads active",
    "[11:38:45] ERROR Response time degraded: avg 4200ms (normal: 45ms)",
    "[11:39:00] CRIT Auth service overwhelmed - dropping 60% of legitimate login attempts",
]

# Log lines exposed for the rate-limiter service.
RATE_LIMITER_LOGS = [
    "[11:38:00] INFO Rate limit config: 100 req/min per IP (no subnet blocking configured)",
    "[11:38:30] WARN 185.220.101.x subnet generating 8400 req/min across 84 IPs",
    "[11:38:45] WARN Per-IP rate limiting ineffective against distributed botnet",
    "[11:38:46] INFO Subnet 185.220.101.0/24: 84 active IPs, avg 100 req/min each = bypassing limit",
]
class SecurityTask(BaseTask):
    """Incident task: a botnet in the 185.x.x.x range floods the login endpoint.

    The episode is marked resolved once the agent has both blocked an IP
    range whose first octet is 185 and paged on-call with a
    security-related reason (in either order).
    """

    # (action_type value, service) -> (one-time reward tag, reward amount).
    # Reading and searching a service's logs share a tag, so each service's
    # logs are rewarded at most once per episode.
    _GATHER_REWARDS = {
        ("read_logs", "api-gateway"): ("rl_api", 0.10),
        ("search_logs", "api-gateway"): ("rl_api", 0.10),
        ("read_logs", "auth-service"): ("rl_auth", 0.10),
        ("search_logs", "auth-service"): ("rl_auth", 0.10),
        ("read_logs", "rate-limiter"): ("rl_rate", 0.05),
        ("search_logs", "rate-limiter"): ("rl_rate", 0.05),
    }

    # Exact ranges that earn the extra CIDR bonus on top of the block reward.
    # NOTE(review): the /24 named in the rate-limiter logs is not listed here;
    # confirm whether that is intentional.
    _BONUS_CIDRS = ("185.0.0.0/8", "185.220.0.0/16")

    @staticmethod
    def _targets_attack_range(ip_range: str) -> bool:
        """Return True when the first octet of *ip_range* is exactly ``185``.

        A plain substring test would also accept unrelated ranges such as
        ``10.0.185.0/24`` or ``192.168.1.185``; comparing the first octet
        still accepts ``185.220.101.0/24``, ``185.0.0.0/8`` and ``185.x.x.x``.
        """
        address = ip_range.strip().split("/", 1)[0]
        return address.split(".", 1)[0] == "185"

    @staticmethod
    def _grant_once(state: InternalState, tag: str, amount: float) -> float:
        """Return *amount* the first time *tag* is granted, else ``0.0``.

        The tag is recorded in ``state.rewards_given`` on first use.
        """
        if tag in state.rewards_given:
            return 0.0
        state.rewards_given.add(tag)
        return amount

    def initialize(self) -> InternalState:
        """Build the starting state for a fresh episode of the security task.

        Returns:
            An ``InternalState`` with a new episode id, the per-service logs,
            service metrics, alerts, ground truth, and the dependency graph.
        """
        # Copy the module-level fixtures so one episode cannot mutate the
        # data seen by the next.
        logs = {
            "api-gateway": list(API_LOGS),
            "auth-service": list(AUTH_LOGS),
            "rate-limiter": list(RATE_LIMITER_LOGS),
            "user-service": ["[11:37:00] INFO Service normal"],
            "backend-db": ["[11:38:30] WARN High connection count detected from auth-service"],
        }
        dependencies = [
            {
                **entry,
                "calls": list(entry["calls"]),
                "called_by": list(entry["called_by"]),
            }
            for entry in DEPENDENCIES
        ]

        services = {
            "api-gateway": {
                "name": "api-gateway", "status": "degraded",
                "cpu_percent": 95.0, "memory_percent": 45.0,
                "error_rate": 78.0, "latency_p99_ms": 3500.0,
                "replicas_running": 5, "replicas_desired": 5,
                "current_version": "v3.1.0", "last_deployed": "2026-03-20T08:00:00Z",
                "minutes_degraded": 8, "sla_breach": False,
            },
            "auth-service": {
                "name": "auth-service", "status": "degraded",
                "cpu_percent": 99.0, "memory_percent": 80.0,
                "error_rate": 60.0, "latency_p99_ms": 4200.0,
                "replicas_running": 3, "replicas_desired": 3,
                "current_version": "v1.5.0", "last_deployed": "2026-04-10T11:00:00Z",
                "minutes_degraded": 8, "sla_breach": False,
            },
            "backend-db": {
                "name": "backend-db", "status": "degraded",
                "cpu_percent": 82.0, "memory_percent": 65.0,
                "error_rate": 0.0, "latency_p99_ms": 150.0,
                "replicas_running": 1, "replicas_desired": 1,
                "current_version": "v14.1", "last_deployed": "2025-01-01T00:00:00Z",
                "minutes_degraded": 5, "sla_breach": False,
            },
            "rate-limiter": {
                "name": "rate-limiter", "status": "healthy",
                "cpu_percent": 40.0, "memory_percent": 25.0,
                "error_rate": 0.0, "latency_p99_ms": 5.0,
                "replicas_running": 2, "replicas_desired": 2,
                "current_version": "v2.0.0", "last_deployed": "2026-01-15T00:00:00Z",
                "minutes_degraded": 0, "sla_breach": False,
            },
            "user-service": {
                "name": "user-service", "status": "healthy",
                "cpu_percent": 15.0, "memory_percent": 30.0,
                "error_rate": 0.0, "latency_p99_ms": 25.0,
                "replicas_running": 2, "replicas_desired": 2,
                "current_version": "v1.1.2", "last_deployed": "2026-03-01T00:00:00Z",
                "minutes_degraded": 0, "sla_breach": False,
            },
        }

        alerts = [
            {
                "id": "A001", "severity": "critical", "service": "api-gateway",
                "message": "Error rate 78% - requests being dropped (traffic: 12000 req/s)",
                "timestamp": "2026-04-12T11:38:47Z", "acknowledged": False,
            },
            {
                "id": "A002", "severity": "critical", "service": "auth-service",
                "message": "Response time 4200ms (threshold: 500ms) - connection pool exhausted",
                "timestamp": "2026-04-12T11:38:45Z", "acknowledged": False,
            },
            {
                "id": "A003", "severity": "warning", "service": "backend-db",
                "message": "Connection pool 89% utilized - auth query storm",
                "timestamp": "2026-04-12T11:38:50Z", "acknowledged": False,
            },
            {
                "id": "A004", "severity": "info", "service": "rate-limiter",
                "message": "Per-IP rate limits being bypassed by distributed source",
                "timestamp": "2026-04-12T11:38:46Z", "acknowledged": False,
            },
        ]

        return InternalState(
            episode_id=str(uuid.uuid4()), task_id="security", step=0, max_steps=20,
            services=services, alerts=alerts, logs=logs,
            action_history=[], total_reward=0.0, incident_resolved=False,
            ground_truth_root_cause="ddos_attack_185.x.x.x_botnet_targeting_login_endpoint",
            ground_truth_fix="block_ip_range_185.x.x.x AND alert_oncall security team",
            incident_start_time=INCIDENT_TIME,
            healthy_services=["rate-limiter", "user-service"],
            service_dependencies=dependencies,
        )

    def step(self, state: InternalState, action: Action) -> StepOutput:
        """Apply *action* to *state* and score it.

        One-time rewards (each tag is granted at most once per episode):
        reading/searching api-gateway or auth-service logs (+0.10 each),
        rate-limiter logs (+0.05), reading the security runbook (+0.05),
        a matching diagnosis (+0.20), blocking a 185.* range (+0.30, plus
        +0.10 for the two bonus CIDRs), and a security-related on-call page
        (+0.20). Penalties apply on every occurrence: blocking another range
        (-0.10), restart (-0.15), rollback (-0.10), create-index/failover
        (-0.10), and no-op after step 5 (-0.03).

        Args:
            state: Episode state; mutated in place.
            action: The agent's action for this step.

        Returns:
            A ``StepOutput`` carrying the mutated state, this step's reward
            rounded to 4 decimals, the done flag, and an info dict.
        """
        state.step += 1
        state._apply_sla_degradation()

        at = action.action_type
        svc = action.service or ""
        reward = 0.0
        done = False
        info: Dict[str, Any] = {}

        # Base-class helper supplies the default result/error text; the
        # branches below override it where this task has a specific outcome.
        result_text, error_text = self._apply_action_to_logs(state, action)

        gather = self._GATHER_REWARDS.get((at.value, svc))
        if gather is not None:
            tag, amount = gather
            reward += self._grant_once(state, tag, amount)

        if at == ActionType.READ_RUNBOOK:
            runbook = action.runbook or ""
            if runbook.endswith("security_incident.md"):
                reward += self._grant_once(state, "runbook_security", 0.05)

        elif at == ActionType.DIAGNOSE:
            root_cause = action.root_cause or ""
            keywords = ["ddos", "botnet", "185", "attack", "credential stuffing"]
            if semantic_match(root_cause, keywords):
                reward += self._grant_once(state, "diagnose_correct", 0.20)
            result_text = f"Diagnosis recorded: {root_cause}"

        elif at == ActionType.BLOCK_IP_RANGE:
            ip_range = action.ip_range or ""
            if self._targets_attack_range(ip_range):
                reward += self._grant_once(state, "fix_block", 0.30)
                if ip_range.strip() in self._BONUS_CIDRS:
                    reward += self._grant_once(state, "bonus_cidr", 0.10)
                result_text = f"Successfully applied firewall block rule for IP range {ip_range}."
                # The incident resolves only once the security page is also sent.
                if "fix_alert" in state.rewards_given:
                    state.incident_resolved = True
                    done = True
                    info["resolution"] = "incident_resolved"
            else:
                reward -= 0.10
                result_text = f"Blocked IP range {ip_range}, but it did not stop the attack."

        elif at == ActionType.ALERT_ONCALL:
            reason = (action.reason or "").lower()
            if semantic_match(reason, ["security", "ddos", "attack"]):
                reward += self._grant_once(state, "fix_alert", 0.20)
                result_text = "Security team paged. They are actively monitoring the situation."
                # The incident resolves only once the block is also in place.
                if "fix_block" in state.rewards_given:
                    state.incident_resolved = True
                    done = True
                    info["resolution"] = "incident_resolved"
            else:
                result_text = "On-call paged, but without security context they cannot escalate."

        elif at == ActionType.RESTART_SERVICE:
            reward -= 0.15
            result_text = f"Restarted {svc}. Connection pool dropped but immediately overwhelmed again by DDoS."

        elif at == ActionType.ROLLBACK:
            reward -= 0.10
            result_text = f"Rolled back {svc}, but this is an external attack, not a bad deployment."

        elif at == ActionType.NOOP:
            if state.step > 5:
                reward -= 0.03

        elif at in (ActionType.CREATE_INDEX, ActionType.FAILOVER):
            reward -= 0.10
            error_text = f"Action {at.value} is not applicable to this incident."

        state.total_reward = self._clamp(state.total_reward + reward)

        if state.step >= state.max_steps and not done:
            done = True
            info["reason"] = "max_steps_reached"

        # NOTE(review): the built observation is not passed to StepOutput here;
        # the call is kept in case it has side effects on state - confirm.
        state._build_observation(last_action_result=result_text, last_action_error=error_text)

        step_reward = round(reward, 4)
        state.action_history.append(
            {"step": state.step, "action": action.model_dump(), "reward": step_reward}
        )
        return StepOutput(next_state=state, reward=step_reward, done=done, info=info)
|