# devops-incident-response / tasks/task_security.py
# Last commit 849b14a by Arijit-07:
#   "Sync: all 7 tasks, openenv-core migration, training notebook, cleanup"
from __future__ import annotations
import uuid
from typing import Dict, Any, List
from models import Action, ActionType
from tasks.base import BaseTask, InternalState, StepOutput, semantic_match
# Wall-clock start of the simulated incident (UTC); matches the first log line.
INCIDENT_TIME = "2026-04-12T11:37:00Z"

# Forward call graph of the simulated stack, in display order. Only the
# outbound edges are written by hand; the inbound ("called_by") edges below
# are derived from them so the two directions always agree.
_SERVICE_CALLS = {
    "api-gateway": ["auth-service", "user-service"],
    "auth-service": ["backend-db", "rate-limiter"],
    "rate-limiter": [],
    "user-service": ["backend-db"],
    "backend-db": [],
}

# One record per service: who it calls and who calls it. Each record gets its
# own fresh lists so nothing is shared with _SERVICE_CALLS.
DEPENDENCIES = [
    {
        "service": service_name,
        "calls": list(callees),
        "called_by": [
            caller for caller, outbound in _SERVICE_CALLS.items()
            if service_name in outbound
        ],
    }
    for service_name, callees in _SERVICE_CALLS.items()
]
# api-gateway log fixture. It shows the traffic ramp (820 -> 12000 req/s), the
# attribution of that traffic to the 185.x.x.x range and to POST
# /api/v1/login, connection-pool exhaustion towards auth-service, and a few
# later request lines, including rate-limited (429) 185.220.101.x logins.
# SecurityTask.initialize hands each episode a copy of this list.
API_LOGS: List[str] = [
"[11:37:00] INFO Traffic normal: 820 req/s",
"[11:37:30] WARN Traffic spike: 2400 req/s - monitoring",
"[11:38:00] WARN Traffic spike: 5800 req/s - alert fired",
"[11:38:30] ERROR Traffic: 12000 req/s - rate limiter overwhelmed",
"[11:38:45] ERROR 94.2% of requests from 185.x.x.x IP range",
"[11:38:46] ERROR 99.8% of high-volume requests targeting POST /api/v1/login",
"[11:38:47] WARN Dropping 78% of requests - circuit breaker opening",
"[11:39:00] ERROR Connection pool to auth-service exhausted: 500/500 connections active",
"[11:45:01] INFO GET /api/v1/products 200 12ms 203.0.113.42",
"[11:45:01] WARN POST /api/v1/login 429 8ms 185.220.101.45 [rate-limited]",
"[11:45:01] WARN POST /api/v1/login 429 8ms 185.220.101.46 [rate-limited]",
"[11:45:02] INFO GET /api/v1/health 200 3ms 10.0.0.1",
]
# auth-service log fixture. It shows failed logins with NULL user ids from
# 185.220.101.x, flagged as a credential-stuffing pattern. These are followed
# by thread-pool saturation, a latency jump (45ms -> 4200ms) and legitimate
# logins being dropped.
# SecurityTask.initialize hands each episode a copy of this list.
AUTH_LOGS: List[str] = [
"[11:37:45] INFO Login attempt: user_id=NULL ip=185.220.101.45 (failed - no such user)",
"[11:37:45] INFO Login attempt: user_id=NULL ip=185.220.101.46 (failed - no such user)",
"[11:38:00] WARN 98% of login attempts are credential stuffing pattern (NULL user_ids)",
"[11:38:30] ERROR Thread pool saturation: 498/500 threads active",
"[11:38:45] ERROR Response time degraded: avg 4200ms (normal: 45ms)",
"[11:39:00] CRIT Auth service overwhelmed - dropping 60% of legitimate login attempts",
]
# rate-limiter log fixture. It explains why the existing protection fails:
# the limit is 100 req/min per IP with no subnet blocking, and the
# 185.220.101.0/24 subnet spreads its load over 84 IPs.
# SecurityTask.initialize hands each episode a copy of this list.
RATE_LIMITER_LOGS: List[str] = [
"[11:38:00] INFO Rate limit config: 100 req/min per IP (no subnet blocking configured)",
"[11:38:30] WARN 185.220.101.x subnet generating 8400 req/min across 84 IPs",
"[11:38:45] WARN Per-IP rate limiting ineffective against distributed botnet",
"[11:38:46] INFO Subnet 185.220.101.0/24: 84 active IPs, avg 100 req/min each = bypassing limit",
]
class SecurityTask(BaseTask):
    """Security incident: a login-flood DDoS from a 185.x.x.x botnet.

    Scenario (from the fixture logs): many IPs in 185.220.101.0/24 flood
    ``POST /api/v1/login`` with credential-stuffing attempts. This exhausts
    auth-service and degrades api-gateway and backend-db. Per-IP rate
    limiting does not stop it.

    Scoring:

    - Reading the relevant logs earns a one-time reward.
    - So do reading the security runbook and making a correct diagnosis.
    - The two required fixes are blocking a 185.* IP range and paging
      on-call with a security-related reason.
    - The incident is resolved as soon as both fixes have been applied, in
      either order.
    - Restarts, rollbacks, index creation, failover and no-ops after step 5
      are penalized on every use.
    """

    # (action_type value, service) -> (one-time reward tag, reward).
    # read_logs and search_logs on the same service share a tag, so only the
    # first of the two pays out.
    _GATHER_REWARDS = {
        ("read_logs", "api-gateway"): ("rl_api", 0.10),
        ("search_logs", "api-gateway"): ("rl_api", 0.10),
        ("read_logs", "auth-service"): ("rl_auth", 0.10),
        ("search_logs", "auth-service"): ("rl_auth", 0.10),
        ("read_logs", "rate-limiter"): ("rl_rate", 0.05),
        ("search_logs", "rate-limiter"): ("rl_rate", 0.05),
    }
    # Leading octet of the attacking network (all attacker IPs in the
    # fixture logs are 185.220.101.x).
    _ATTACK_PREFIX = "185."
    # Exact CIDR strings that earn the one-time "block broadly" bonus.
    _BONUS_CIDRS = frozenset({"185.0.0.0/8", "185.220.0.0/16"})

    def initialize(self) -> InternalState:
        """Build a fresh episode state for the security incident.

        Service metrics, alerts and per-service logs are rebuilt on every
        call. The three main log lists are copies of the module-level
        fixtures, so log edits made during an episode do not alter the
        fixtures.

        Returns:
            A new InternalState at step 0 with a random episode id.
        """
        logs = {
            "api-gateway": API_LOGS[:],
            "auth-service": AUTH_LOGS[:],
            "rate-limiter": RATE_LIMITER_LOGS[:],
            "user-service": ["[11:37:00] INFO Service normal"],
            "backend-db": ["[11:38:30] WARN High connection count detected from auth-service"],
        }
        services = {
            "api-gateway": {
                "name": "api-gateway", "status": "degraded",
                "cpu_percent": 95.0, "memory_percent": 45.0,
                "error_rate": 78.0, "latency_p99_ms": 3500.0,
                "replicas_running": 5, "replicas_desired": 5,
                "current_version": "v3.1.0", "last_deployed": "2026-03-20T08:00:00Z",
                "minutes_degraded": 8, "sla_breach": False,
            },
            "auth-service": {
                "name": "auth-service", "status": "degraded",
                "cpu_percent": 99.0, "memory_percent": 80.0,
                "error_rate": 60.0, "latency_p99_ms": 4200.0,
                "replicas_running": 3, "replicas_desired": 3,
                "current_version": "v1.5.0", "last_deployed": "2026-04-10T11:00:00Z",
                "minutes_degraded": 8, "sla_breach": False,
            },
            "backend-db": {
                "name": "backend-db", "status": "degraded",
                "cpu_percent": 82.0, "memory_percent": 65.0,
                "error_rate": 0.0, "latency_p99_ms": 150.0,
                "replicas_running": 1, "replicas_desired": 1,
                "current_version": "v14.1", "last_deployed": "2025-01-01T00:00:00Z",
                "minutes_degraded": 5, "sla_breach": False,
            },
            "rate-limiter": {
                "name": "rate-limiter", "status": "healthy",
                "cpu_percent": 40.0, "memory_percent": 25.0,
                "error_rate": 0.0, "latency_p99_ms": 5.0,
                "replicas_running": 2, "replicas_desired": 2,
                "current_version": "v2.0.0", "last_deployed": "2026-01-15T00:00:00Z",
                "minutes_degraded": 0, "sla_breach": False,
            },
            "user-service": {
                "name": "user-service", "status": "healthy",
                "cpu_percent": 15.0, "memory_percent": 30.0,
                "error_rate": 0.0, "latency_p99_ms": 25.0,
                "replicas_running": 2, "replicas_desired": 2,
                "current_version": "v1.1.2", "last_deployed": "2026-03-01T00:00:00Z",
                "minutes_degraded": 0, "sla_breach": False,
            },
        }
        alerts = [
            {
                "id": "A001", "severity": "critical", "service": "api-gateway",
                "message": "Error rate 78% - requests being dropped (traffic: 12000 req/s)",
                "timestamp": "2026-04-12T11:38:47Z", "acknowledged": False,
            },
            {
                "id": "A002", "severity": "critical", "service": "auth-service",
                "message": "Response time 4200ms (threshold: 500ms) - connection pool exhausted",
                "timestamp": "2026-04-12T11:38:45Z", "acknowledged": False,
            },
            {
                "id": "A003", "severity": "warning", "service": "backend-db",
                "message": "Connection pool 89% utilized - auth query storm",
                "timestamp": "2026-04-12T11:38:50Z", "acknowledged": False,
            },
            {
                "id": "A004", "severity": "info", "service": "rate-limiter",
                "message": "Per-IP rate limits being bypassed by distributed source",
                "timestamp": "2026-04-12T11:38:46Z", "acknowledged": False,
            },
        ]
        state = InternalState(
            episode_id=str(uuid.uuid4()), task_id="security", step=0, max_steps=20,
            services=services, alerts=alerts, logs=logs,
            action_history=[], total_reward=0.0, incident_resolved=False,
            ground_truth_root_cause="ddos_attack_185.x.x.x_botnet_targeting_login_endpoint",
            ground_truth_fix="block_ip_range_185.x.x.x AND alert_oncall security team",
            incident_start_time=INCIDENT_TIME,
            healthy_services=["rate-limiter", "user-service"],
            service_dependencies=DEPENDENCIES,
        )
        return state

    @staticmethod
    def _award_once(state: InternalState, tag: str, amount: float) -> float:
        """Return *amount* the first time *tag* is earned this episode, else 0.0.

        The first grant records *tag* in ``state.rewards_given`` so repeated
        actions cannot farm the same reward.
        """
        if tag in state.rewards_given:
            return 0.0
        state.rewards_given.add(tag)
        return amount

    @classmethod
    def _is_attack_range(cls, ip_range: str) -> bool:
        """Return True if the (already stripped) *ip_range* starts in 185.*.

        This matches on the leading octet instead of a bare substring. A
        substring test would also accept unrelated ranges such as
        ``10.0.185.0/24``, which do not block the attacker.
        """
        return ip_range.startswith(cls._ATTACK_PREFIX)

    def step(self, state: InternalState, action: Action) -> StepOutput:
        """Apply one agent action, score it and advance the episode.

        Positive rewards are granted at most once per tag, tracked in
        ``state.rewards_given``. Penalties for wrong remedies apply every
        time the action is taken.

        Args:
            state: Mutable episode state; updated in place.
            action: The agent's action for this step.

        Returns:
            StepOutput containing:

            - the updated state;
            - the reward for this step, rounded to 4 places;
            - the ``done`` flag;
            - an ``info`` dict, which carries ``resolution`` when the
              incident is resolved or ``reason`` when the step budget
              runs out.
        """
        state.step += 1
        state._apply_sla_degradation()
        at = action.action_type
        svc = action.service or ""
        reward = 0.0
        done = False
        info: Dict[str, Any] = {}
        result_text, error_text = self._apply_action_to_logs(state, action)

        # Evidence gathering: a one-time reward for inspecting the key logs.
        gather = self._GATHER_REWARDS.get((at.value, svc))
        if gather is not None:
            tag, amount = gather
            reward += self._award_once(state, tag, amount)

        if at == ActionType.READ_RUNBOOK:
            rb = action.runbook or ""
            if rb.endswith("security_incident.md"):
                reward += self._award_once(state, "runbook_security", 0.05)

        if at == ActionType.DIAGNOSE:
            rc = action.root_cause or ""
            if semantic_match(rc, ["ddos", "botnet", "185", "attack", "credential stuffing"]):
                reward += self._award_once(state, "diagnose_correct", 0.20)
            # Neutral feedback for every diagnosis, so correctness is not revealed.
            result_text = f"Diagnosis recorded: {rc}"

        if at == ActionType.BLOCK_IP_RANGE:
            ip_range = (action.ip_range or "").strip()
            if self._is_attack_range(ip_range):
                reward += self._award_once(state, "fix_block", 0.30)
                # Only these two broad blocks earn the extra bonus.
                if ip_range in self._BONUS_CIDRS:
                    reward += self._award_once(state, "bonus_cidr", 0.10)
                result_text = f"Successfully applied firewall block rule for IP range {ip_range}."
                # Resolution needs both fixes; the security page may have come first.
                if "fix_alert" in state.rewards_given:
                    state.incident_resolved = True
                    done = True
                    info["resolution"] = "incident_resolved"
            else:
                # Blocking an unrelated range is penalized on every attempt.
                reward -= 0.10
                result_text = f"Blocked IP range {ip_range}, but it did not stop the attack."

        if at == ActionType.ALERT_ONCALL:
            reason = (action.reason or "").lower()
            if semantic_match(reason, ["security", "ddos", "attack"]):
                reward += self._award_once(state, "fix_alert", 0.20)
                result_text = "Security team paged. They are actively monitoring the situation."
                # Resolution needs both fixes; the IP block may have come first.
                if "fix_block" in state.rewards_given:
                    state.incident_resolved = True
                    done = True
                    info["resolution"] = "incident_resolved"
            else:
                result_text = "On-call paged, but without security context they cannot escalate."

        # Wrong remedies: the cause is external traffic, not the services.
        if at == ActionType.RESTART_SERVICE:
            reward -= 0.15
            result_text = f"Restarted {svc}. Connection pool dropped but immediately overwhelmed again by DDoS."
        if at == ActionType.ROLLBACK:
            reward -= 0.10
            result_text = f"Rolled back {svc}, but this is an external attack, not a bad deployment."
        # Idling is tolerated for the first 5 steps only.
        if at == ActionType.NOOP and state.step > 5:
            reward -= 0.03
        if at in (ActionType.CREATE_INDEX, ActionType.FAILOVER):
            reward -= 0.10
            error_text = f"Action {at.value} is not applicable to this incident."

        state.total_reward = self._clamp(state.total_reward + reward)
        if state.step >= state.max_steps and not done:
            done = True
            info["reason"] = "max_steps_reached"
        # NOTE(review): the returned observation is not used here; the call is
        # kept in case _build_observation records the last action result on
        # the state - confirm against BaseTask/InternalState.
        state._build_observation(last_action_result=result_text, last_action_error=error_text)
        state.action_history.append({"step": state.step, "action": action.model_dump(), "reward": round(reward, 4)})
        return StepOutput(next_state=state, reward=round(reward, 4), done=done, info=info)