from __future__ import annotations import uuid from typing import Dict, Any, List from models import Action, ActionType from tasks.base import BaseTask, InternalState, StepOutput, semantic_match INCIDENT_TIME = "2026-03-30T11:02:00Z" DEPENDENCIES = [ {"service": "api-gateway", "calls": ["order-service", "product-catalog-service"], "called_by": []}, {"service": "order-service", "calls": ["product-catalog-service"], "called_by": ["api-gateway"]}, {"service": "data-pipeline-service", "calls": ["product-catalog-service"], "called_by": []}, {"service": "product-catalog-service", "calls": [], "called_by": ["api-gateway", "order-service", "data-pipeline-service"]}, {"service": "price-validation-service","calls": ["product-catalog-service"], "called_by": []}, {"service": "analytics-service", "calls": ["order-service"], "called_by": []}, ] PIPELINE_LOGS = [ "[11:01:55] INFO Deployment data-pipeline-service:{version} complete", "[11:01:58] INFO Health check passed. Starting pipeline workers.", "[11:02:00] INFO Pipeline worker started. Consuming from topic: product-updates", "[11:02:01] INFO Processed batch: 142 records written to product-catalog", "[11:02:03] INFO Processed batch: 138 records written to product-catalog", "[11:02:07] INFO Processed batch: 147 records written to product-catalog", "[11:02:09] INFO All writes succeeded (HTTP 200) - no errors detected", ] PRICE_VALIDATION_LOGS = [ "[11:02:08] INFO Validation batch started: 312 products", "[11:02:10] WARN PRICE_MISMATCH: product_id=1042 catalog=149.99 expected=14.99 (10x multiplier?)", "[11:02:11] WARN PRICE_MISMATCH: product_id=2891 catalog=899.00 expected=89.00", "[11:02:13] WARN PRICE_MISMATCH: product_id=0391 catalog=24.90 expected=2.49", "[11:02:14] WARN PRICE_MISMATCH: product_id=5521 catalog=1299.90 expected=129.99", "[11:02:17] WARN PRICE_MISMATCH: product_id=7823 catalog=49.90 expected=4.99", "[11:02:21] WARN PRICE_MISMATCH: product_id=3314 catalog=799.00 expected=79.90", "[11:02:24] INFO Validation batch complete: 265 ok, 47 mismatches (15.1% rate, baseline: 0.2%)", "[11:02:24] WARN Mismatch rate 15.1% exceeds SLA threshold 1.0% - notifying data team", ] ANALYTICS_LOGS = [ "[11:01:50] INFO Hourly report: avg_order_value=$89.42 orders=138 (normal)", "[11:02:00] INFO Hourly report: avg_order_value=$91.18 orders=141", "[11:02:10] INFO ANOMALY: avg_order_value=$312.44 (3.5x baseline) in last 2 min", "[11:02:20] WARN avg_order_value=$847.23 - possible pricing issue", "[11:02:21] INFO orders_per_minute=142 (normal: 120-160) - volume is normal", "[11:02:21] INFO Spike NOT correlated with marketing campaign or known event", ] CATALOG_LOGS = [ "[11:02:01] INFO PUT /catalog/product/1042 200 8ms price=149.99", "[11:02:02] INFO PUT /catalog/product/2891 200 7ms price=899.00", "[11:02:03] INFO PUT /catalog/product/0391 200 6ms price=24.90", "[11:02:04] INFO PUT /catalog/product/5521 200 8ms price=1299.90", "[11:02:05] INFO All writes returning 200 OK - no DB errors", ] GATEWAY_LOGS = [ "[11:02:00] INFO GET /api/v1/products 200 12ms", "[11:02:05] INFO POST /api/v1/orders 200 88ms", "[11:02:15] INFO POST /api/v1/orders 200 91ms", "[11:02:20] INFO POST /api/v1/orders 200 87ms", ] ORDER_LOGS = [ "[11:02:05] INFO Order ORD-9901: total=$149.99 (product_id=1042)", "[11:02:08] INFO Order ORD-9902: total=$899.00 (product_id=2891)", "[11:02:12] INFO Order ORD-9903: total=$1299.90 (product_id=5521)", ] # Extra noise alerts that don't point to the real issue NOISE_ALERTS = [ { "id": "A030", "severity": "info", "service": "api-gateway", "message": "TLS certificate renewing in 14 days - scheduled maintenance upcoming", "timestamp": "2026-03-30T11:00:00Z", "acknowledged": False, }, { "id": "A031", "severity": "info", "service": "analytics-service", "message": "Nightly aggregation job starting 5 minutes early due to backlog", "timestamp": "2026-03-30T11:01:45Z", "acknowledged": False, }, { "id": "A032", "severity": "info", "service": "product-catalog-service", "message": "Read replica lag 280ms (threshold: 500ms) - within normal range", "timestamp": "2026-03-30T11:02:00Z", "acknowledged": False, }, ] class HardTask(BaseTask): def initialize(self) -> InternalState: bad_ver = f"v3.1.{self.rng.randint(0, 4)}" logs = { "data-pipeline-service": [l.replace("{version}", bad_ver) for l in PIPELINE_LOGS], "price-validation-service": PRICE_VALIDATION_LOGS[:], "analytics-service": ANALYTICS_LOGS[:], "product-catalog-service": CATALOG_LOGS[:], "api-gateway": GATEWAY_LOGS[:], "order-service": ORDER_LOGS[:], } def healthy_svc(name, ver, deployed): return { "name": name, "status": "healthy", "cpu_percent": round(self.rng.uniform(22, 48), 1), "memory_percent": round(self.rng.uniform(35, 55), 1), "error_rate": 0.0, "latency_p99_ms": round(self.rng.uniform(8, 130), 0), "replicas_running": self.rng.choice([2, 3]), "replicas_desired": self.rng.choice([2, 3]), "current_version": ver, "last_deployed": deployed, "minutes_degraded": 0, "sla_breach": False, } services = { "api-gateway": {**healthy_svc("api-gateway", "v3.1.0", "2026-03-20T08:00:00Z"), "replicas_running": 2, "replicas_desired": 2}, "data-pipeline-service": {**healthy_svc("data-pipeline-service", bad_ver, "2026-03-30T11:01:55Z"), "replicas_running": 3, "replicas_desired": 3}, "product-catalog-service": {**healthy_svc("product-catalog-service", "v2.0.1", "2026-03-10T12:00:00Z"), "replicas_running": 2, "replicas_desired": 2}, "price-validation-service":{**healthy_svc("price-validation-service","v1.4.0", "2026-03-12T14:00:00Z"), "replicas_running": 2, "replicas_desired": 2}, "analytics-service": {**healthy_svc("analytics-service", "v2.3.1", "2026-03-14T10:00:00Z"), "replicas_running": 2, "replicas_desired": 2}, "order-service": {**healthy_svc("order-service", "v1.8.2", "2026-03-22T10:00:00Z"), "replicas_running": 3, "replicas_desired": 3}, } # Real signal alerts + noise alerts = NOISE_ALERTS[:] + [ { "id": "A020", "severity": "info", "service": "price-validation-service", "message": "Price mismatch rate 15.1% — above SLA threshold of 1.0%. Data team notified.", "timestamp": "2026-03-30T11:02:24Z", "acknowledged": False, }, { "id": "A021", "severity": "warning", "service": "analytics-service", "message": "avg_order_value anomaly: $847.23 vs baseline $89.42 — not correlated with campaigns", "timestamp": "2026-03-30T11:02:21Z", "acknowledged": False, }, ] state = InternalState( episode_id=str(uuid.uuid4()), task_id="hard", step=0, max_steps=25, services=services, alerts=alerts, logs=logs, action_history=[], total_reward=0.0, incident_resolved=False, ground_truth_root_cause=f"data_corruption_data_pipeline_{bad_ver}_incorrect_price_writes", ground_truth_fix="rollback data-pipeline-service then alert_oncall for data audit", incident_start_time=INCIDENT_TIME, healthy_services=list(services.keys()), service_dependencies=DEPENDENCIES, ) state._bad_ver = bad_ver return state def step(self, state: InternalState, action: Action) -> StepOutput: state.step += 1 # No SLA degradation on hard task — all services stay green at = action.action_type svc = action.service or "" reward = 0.0 done = False info: Dict[str, Any] = {} result_text, error_text = self._apply_action_to_logs(state, action) gather_map = { ("read_logs", "price-validation-service"): ("rl_price", 0.05), ("search_logs", "price-validation-service"): ("rl_price", 0.05), ("read_logs", "analytics-service"): ("rl_analytics", 0.05), ("search_logs", "analytics-service"): ("rl_analytics", 0.05), ("read_logs", "data-pipeline-service"): ("rl_pipeline", 0.05), ("search_logs", "data-pipeline-service"): ("rl_pipeline", 0.05), ("read_metrics", "analytics-service"): ("rm_analytics", 0.10), ("read_metrics", "data-pipeline-service"): ("rm_pipeline", 0.10), } k = (at.value, svc) if k in gather_map: tag, r = gather_map[k] if tag not in state.rewards_given: reward += r; state.rewards_given.add(tag) if at == ActionType.READ_RUNBOOK: if "runbook" not in state.rewards_given: reward += 0.05; state.rewards_given.add("runbook") # Restarts/scale-ups are always wrong here if at in (ActionType.RESTART_SERVICE, ActionType.SCALE_UP): reward -= 0.15 error_text = ( f"Restarting/scaling {svc} will not fix corrupt data already written. " "You need to rollback the pipeline and audit the data." ) if at == ActionType.DIAGNOSE: rc = action.root_cause or "" has_pipeline = semantic_match(rc, ["pipeline", "data-pipeline"]) has_corruption = semantic_match(rc, ["corrupt", "data", "price", "wrong", "incorrect", "mismatch"]) result_text = f"Diagnosis recorded: {rc}" if has_pipeline and has_corruption: if "diagnose_correct" not in state.rewards_given: reward += 0.20; state.rewards_given.add("diagnose_correct") elif has_pipeline or has_corruption: if "diagnose_partial" not in state.rewards_given and "diagnose_correct" not in state.rewards_given: reward += 0.08; state.rewards_given.add("diagnose_partial") if at == ActionType.ROLLBACK and svc == "data-pipeline-service": reward += self._penalty_blind_remediation(state, action, "rollback_done") if "rollback_done" not in state.rewards_given: reward += 0.25; state.rewards_given.add("rollback_done") state.services["data-pipeline-service"]["current_version"] = "v3.0.9" result_text = ( "data-pipeline-service rolled back to v3.0.9. Future writes corrected. " "WARNING: corrupted prices already written must be audited." ) if "alert_oncall_done" in state.rewards_given: state.incident_resolved = True; done = True; info["resolution"] = "incident_resolved" if at == ActionType.ALERT_ONCALL: if "alert_oncall_done" not in state.rewards_given: reward += 0.15; state.rewards_given.add("alert_oncall_done") result_text = "On-call data team paged for price audit and correction job." if "rollback_done" in state.rewards_given: state.incident_resolved = True; done = True; info["resolution"] = "incident_resolved" if at in (ActionType.BLOCK_IP_RANGE, ActionType.CREATE_INDEX, ActionType.FAILOVER): reward -= 0.10 error_text = f"Action {at.value} is not applicable to this incident." state.total_reward = self._clamp(state.total_reward + reward) if state.step >= state.max_steps and not done: done = True; info["reason"] = "max_steps_reached" obs = state._build_observation(last_action_result=result_text, last_action_error=error_text) state.action_history.append({"step": state.step, "action": action.model_dump(), "reward": round(reward, 4)}) return StepOutput(next_state=state, reward=round(reward, 4), done=done, info=info)