Arijit-07's picture
Sync: all 7 tasks, openenv-core migration, training notebook, cleanup
849b14a
from __future__ import annotations
import uuid
from typing import Dict, Any, List
from models import Action, ActionType
from tasks.base import BaseTask, InternalState, StepOutput, semantic_match
INCIDENT_TIME = "2026-03-30T10:14:47Z"
SCENARIOS = [
{
"failing_service": "payment-service",
"root_cause": "memory_leak_payment_service",
"fix": "restart payment-service",
"alert_msg": "payment-service pod restarting (OOMKilled)",
"language": "java",
"diagnosis_keywords": ["memory", "oom", "heap", "leak", "outofmemory", "kill"],
},
{
"failing_service": "order-service",
"root_cause": "memory_leak_order_service",
"fix": "restart order-service",
"alert_msg": "order-service pod restarting (OOMKilled)",
"language": "python",
"diagnosis_keywords": ["memory", "oom", "heap", "leak", "segfault", "kill", "allocat"],
},
{
"failing_service": "user-service",
"root_cause": "memory_leak_user_service",
"fix": "restart user-service",
"alert_msg": "user-service pod restarting (OOMKilled)",
"language": "node",
"diagnosis_keywords": ["memory", "heap", "oom", "leak", "javascript", "kill"],
},
]
ALL_SERVICES = ["payment-service", "order-service", "user-service", "api-gateway"]
VERSIONS = {
"payment-service": "v4.2.1", "order-service": "v1.8.2",
"user-service": "v3.0.5", "api-gateway": "v2.1.0",
}
DEPENDENCIES = [
{"service": "api-gateway", "calls": ["payment-service", "order-service", "user-service"], "called_by": []},
{"service": "payment-service", "calls": [], "called_by": ["api-gateway"]},
{"service": "order-service", "calls": [], "called_by": ["api-gateway"]},
{"service": "user-service", "calls": [], "called_by": ["api-gateway"]},
]
def _make_logs(scenario, heap1, heap2, restart_count):
svc = scenario["failing_service"]
lang = scenario["language"]
if lang == "java":
failing = [
"[10:13:55] INFO Request processed 200 38ms",
f"[10:14:35] WARN Heap usage at {heap1}% - approaching threshold",
f"[10:14:41] WARN Heap usage at {heap2}%",
"[10:14:45] WARN GC overhead limit exceeded - major GC running",
"[10:14:47] ERROR java.lang.OutOfMemoryError: Java heap space",
"[10:14:47] ERROR at com.payments.ChargeProcessor.process(ChargeProcessor.java:142)",
f"[10:14:48] FATAL Service entering crash loop - pod restart #{restart_count}",
]
elif lang == "python":
failing = [
"[10:13:55] INFO POST /orders 200 55ms",
f"[10:14:35] WARN RSS memory {heap1}% of pod limit",
f"[10:14:41] WARN RSS memory {heap2}% of pod limit - approaching OOM",
"[10:14:46] ERROR Memory allocator: no more pages available",
"[10:14:47] ERROR Fatal Python error: Segmentation fault (memory allocator exhausted)",
f"[10:14:48] FATAL Pod killed by OOM killer - restart #{restart_count}",
]
else:
failing = [
"[10:13:55] INFO GET /users/profile 200 9ms",
f"[10:14:35] WARN Heap used: {heap1}% ({heap1 * 2}MB / 200MB)",
f"[10:14:41] WARN Heap used: {heap2}% - GC pressure increasing",
"[10:14:47] ERROR FATAL ERROR: Reached heap limit - JavaScript heap out of memory",
f"[10:14:48] FATAL Container OOMKilled - restart #{restart_count}",
]
logs = {svc: failing}
for name in ALL_SERVICES:
if name == svc: continue
if name == "api-gateway":
logs[name] = [
"[10:14:30] INFO GET /api/v1/health 200 3ms",
f"[10:14:48] WARN Upstream {svc} returned 503",
f"[10:14:49] WARN Circuit breaker OPEN for {svc}",
]
else:
logs[name] = ["[10:14:30] INFO Service healthy - 0 errors"]
return logs
class EasyTask(BaseTask):
def initialize(self) -> InternalState:
scenario = SCENARIOS[self.rng.randint(0, len(SCENARIOS) - 1)]
failing = scenario["failing_service"]
heap1 = self.rng.randint(74, 83)
heap2 = heap1 + self.rng.randint(5, 10)
restart_count = self.rng.randint(2, 6)
services: Dict[str, dict] = {}
for name in ALL_SERVICES:
if name == failing:
services[name] = {
"name": name, "status": "down",
"cpu_percent": round(self.rng.uniform(5, 20), 1),
"memory_percent": round(self.rng.uniform(93, 99), 1),
"error_rate": round(self.rng.uniform(8.0, 15.0), 2),
"latency_p99_ms": round(self.rng.uniform(5000, 9000), 0),
"replicas_running": 0, "replicas_desired": 3,
"current_version": VERSIONS[name],
"last_deployed": "2026-03-28T14:00:00Z",
"minutes_degraded": 0, "sla_breach": False,
}
elif name == "api-gateway":
services[name] = {
"name": name, "status": "degraded",
"cpu_percent": round(self.rng.uniform(35, 55), 1),
"memory_percent": round(self.rng.uniform(40, 55), 1),
"error_rate": round(self.rng.uniform(2.0, 5.0), 2),
"latency_p99_ms": round(self.rng.uniform(800, 1500), 0),
"replicas_running": 2, "replicas_desired": 2,
"current_version": VERSIONS[name],
"last_deployed": "2026-03-25T09:00:00Z",
"minutes_degraded": 0, "sla_breach": False,
}
else:
services[name] = {
"name": name, "status": "healthy",
"cpu_percent": round(self.rng.uniform(20, 40), 1),
"memory_percent": round(self.rng.uniform(30, 48), 1),
"error_rate": 0.0,
"latency_p99_ms": round(self.rng.uniform(8, 30), 0),
"replicas_running": 2, "replicas_desired": 2,
"current_version": VERSIONS[name],
"last_deployed": "2026-03-20T11:00:00Z",
"minutes_degraded": 0, "sla_breach": False,
}
alerts = [
{
"id": "A001", "severity": "critical", "service": failing,
"message": f"{scenario['alert_msg']} - {restart_count} times in 5 minutes",
"timestamp": "2026-03-30T10:14:48Z", "acknowledged": False,
},
{
"id": "A002", "severity": "warning", "service": "api-gateway",
"message": f"Upstream {failing} returning 503 - circuit breaker open",
"timestamp": "2026-03-30T10:14:52Z", "acknowledged": False,
},
]
state = InternalState(
episode_id=str(uuid.uuid4()), task_id="easy", step=0, max_steps=15,
services=services, alerts=alerts,
logs=_make_logs(scenario, heap1, heap2, restart_count),
action_history=[], total_reward=0.0, incident_resolved=False,
ground_truth_root_cause=scenario["root_cause"],
ground_truth_fix=scenario["fix"],
incident_start_time=INCIDENT_TIME,
healthy_services=[s for s in ALL_SERVICES if s != failing],
service_dependencies=DEPENDENCIES,
)
state._scenario = scenario
return state
def step(self, state: InternalState, action: Action) -> StepOutput:
state.step += 1
state._apply_sla_degradation()
at = action.action_type
svc = action.service or ""
scenario = state._scenario
failing = scenario["failing_service"]
keywords = scenario["diagnosis_keywords"]
reward = 0.0
done = False
info: Dict[str, Any] = {}
result_text, error_text = self._apply_action_to_logs(state, action)
if at in (ActionType.READ_LOGS, ActionType.SEARCH_LOGS) and svc == failing:
if "logs_investigated" not in state.rewards_given:
reward += 0.15
state.rewards_given.add("logs_investigated")
if at == ActionType.READ_METRICS and svc == failing:
if "read_metrics" not in state.rewards_given:
reward += 0.10
state.rewards_given.add("read_metrics")
if at == ActionType.READ_RUNBOOK:
if "runbook" not in state.rewards_given:
reward += 0.05
state.rewards_given.add("runbook")
if at == ActionType.DIAGNOSE:
rc = action.root_cause or ""
correct_type = semantic_match(rc, keywords, threshold=1)
correct_svc = semantic_match(rc, [failing, failing.split("-")[0]])
result_text = f"Diagnosis recorded: {rc}"
if correct_type and correct_svc:
if "diagnose_correct" not in state.rewards_given:
# Give full reward, remove partial if already given
bonus = 0.30 if "diagnose_partial" not in state.rewards_given else 0.15
reward += bonus
state.rewards_given.add("diagnose_correct")
elif correct_type:
if "diagnose_partial" not in state.rewards_given and "diagnose_correct" not in state.rewards_given:
reward += 0.15
state.rewards_given.add("diagnose_partial")
if at == ActionType.RESTART_SERVICE:
blind_penalty = self._penalty_blind_remediation(state, action, "restarted")
reward += blind_penalty
if svc == failing:
reward += 0.40
state.services[svc]["status"] = "healthy"
state.services[svc]["memory_percent"] = round(self.rng.uniform(38, 48), 1)
state.services[svc]["error_rate"] = 0.0
state.services[svc]["latency_p99_ms"] = round(self.rng.uniform(20, 60), 0)
state.services[svc]["replicas_running"] = state.services[svc]["replicas_desired"]
state.alerts = [a for a in state.alerts if a["id"] != "A001"]
state.incident_resolved = True
result_text = f"{svc} restarted. Memory cleared. All pods healthy."
done = True
info["resolution"] = "incident_resolved"
elif svc in state.healthy_services:
reward -= 0.10
error_text = f"Collateral damage: {svc} was healthy. Unnecessary restart."
if at == ActionType.NOOP and state.step > 3:
reward -= 0.04
if at in (ActionType.BLOCK_IP_RANGE, ActionType.CREATE_INDEX, ActionType.FAILOVER) or str(at) in ("block_ip_range", "create_index", "failover"):
reward -= 0.10
error_text = f"Action {at} is not applicable to this incident."
state.total_reward = self._clamp(state.total_reward + reward)
if state.step >= state.max_steps and not done:
done = True
info["reason"] = "max_steps_reached"
obs = state._build_observation(last_action_result=result_text, last_action_error=error_text)
state.action_history.append({"step": state.step, "action": action.model_dump(), "reward": round(reward, 4)})
return StepOutput(next_state=state, reward=round(reward, 4), done=done, info=info)