# Provenance: commit 77eea12 (Arijit-07) — partial log obs, search_logs action,
# CoT inference, dashboard UI, README motivation.
from __future__ import annotations
import random
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Set
from models import (
Action, ActionType, Observation, State, StepResult,
ServiceStatus, Alert, ServiceDependency, EvidenceEntry,
)
# Runbook filenames the agent may request via the read_runbook action.
# BaseTask._load_runbook resolves them under <repo_root>/data/runbooks/.
AVAILABLE_RUNBOOKS = [
    "high_cpu.md",
    "memory_leak.md",
    "db_connection.md",
    "deployment_rollback.md",
    "cascade_failure.md",
    "data_corruption.md",
]

# Per-difficulty incident briefing shown to the agent in every observation
# (keyed by task_id; see InternalState._build_observation).
TASK_DESCRIPTIONS = {
    "easy": (
        "PRODUCTION INCIDENT — One service is crash-looping. "
        "Read its logs and metrics to find the root cause, diagnose precisely, "
        "then apply the correct single-service fix. "
        "Avoid restarting healthy services — collateral damage is penalised."
    ),
    "medium": (
        "PRODUCTION INCIDENT — Multiple services are degraded. "
        "Use the service dependency map to trace the failure to its origin. "
        "A recent deployment is likely involved. One alert is a red herring. "
        "Fix the root service only — downstream victims will self-heal."
    ),
    "hard": (
        "PRODUCTION INCIDENT — All services show green health. No error-rate alerts. "
        "Look for anomalies in business-logic metrics and WARN-level logs. "
        "Correlate signals across services to find silent data corruption. "
        "Two actions are required for full credit: rollback AND alert_oncall."
    ),
    "bonus": (
        "PRODUCTION INCIDENT — Two independent failures are active simultaneously. "
        "They are unrelated — fixing one will NOT fix the other. "
        "Identify both root causes and remediate each independently. "
        "Full credit requires resolving both."
    ),
}
@dataclass
class InternalState:
    """Full mutable simulator state for one incident episode.

    Holds both the agent-observable world (services, alerts, logs) and the
    hidden scoring data (ground truth, rewards already given).
    `to_state_snapshot` projects this into the public `State` model;
    `_build_observation` builds the partially-observable view the agent sees.
    One step corresponds to 2 simulated minutes (see `elapsed_minutes`).
    """
    episode_id: str
    # Key into TASK_DESCRIPTIONS ("easy" / "medium" / "hard" / "bonus").
    task_id: str
    step: int
    max_steps: int
    # Service name -> raw metrics/status dict (status, cpu_percent, ...).
    services: Dict[str, dict]
    # Raw Alert kwargs dicts (unpacked into Alert models in _build_observation).
    alerts: list
    # Service name -> full log-line history; observations only surface a tail.
    logs: Dict[str, List[str]]
    action_history: List[Dict[str, Any]]
    total_reward: float
    incident_resolved: bool
    # Hidden answers used for scoring; never shown in the observation.
    ground_truth_root_cause: str
    ground_truth_fix: str
    incident_start_time: str
    # Sticky: kept until explicitly overwritten by a later action result/error.
    last_action_result: Optional[str] = field(default=None)
    last_action_error: Optional[str] = field(default=None)
    # Reward keys already unlocked — prevents double-crediting.
    rewards_given: Set[str] = field(default_factory=set)
    healthy_services: List[str] = field(default_factory=list)
    # Raw EvidenceEntry kwargs accumulated by diagnostic actions.
    evidence_log: List[dict] = field(default_factory=list)
    # Raw ServiceDependency kwargs for the dependency map.
    service_dependencies: List[dict] = field(default_factory=list)
    # Opaque scenario handles; excluded from repr to keep dumps readable.
    _scenario: Any = field(default=None, repr=False)
    _ml_version: Any = field(default=None, repr=False)

    def to_state_snapshot(self) -> State:
        """Project the internal state into the public `State` model
        (includes the current observation and reward bookkeeping)."""
        obs = self._build_observation()
        return State(
            episode_id=self.episode_id,
            task_id=self.task_id,
            step=self.step,
            current_observation=obs,
            action_history=self.action_history,
            total_reward=round(self.total_reward, 4),
            incident_resolved=self.incident_resolved,
            ground_truth_root_cause=self.ground_truth_root_cause,
            ground_truth_fix=self.ground_truth_fix,
            info={
                "rewards_unlocked": sorted(self.rewards_given),
                "evidence_gathered": len(self.evidence_log),
            },
        )

    def _build_sla_status(self) -> Dict[str, str]:
        """Map each service name to an SLA state: "ok", "warning" or "breached".

        Thresholds are on elapsed incident minutes (step * 2):
        a "down" service warns at 5 min and breaches at 10 min; a "degraded"
        service warns at 10 min and breaches at 20 min; anything else is "ok".
        """
        status = {}
        for name, svc in self.services.items():
            if svc["status"] == "down":
                mins = self.step * 2
                if mins >= 10:
                    status[name] = "breached"
                elif mins >= 5:
                    status[name] = "warning"
                else:
                    status[name] = "ok"
            elif svc["status"] == "degraded":
                mins = self.step * 2
                if mins >= 20:
                    status[name] = "breached"
                elif mins >= 10:
                    status[name] = "warning"
                else:
                    status[name] = "ok"
            else:
                status[name] = "ok"
        return status

    def _apply_sla_degradation(self) -> None:
        """Services get progressively worse if not fixed — adds urgency."""
        # No-op once the incident is resolved: the world stops decaying.
        if self.incident_resolved:
            return
        for name, svc in self.services.items():
            if svc["status"] == "down":
                # Each step represents 2 minutes of wall-clock degradation.
                svc["minutes_degraded"] = svc.get("minutes_degraded", 0) + 2
                # Error rate creeps up (5%/step), capped at 50/s.
                svc["error_rate"] = min(svc["error_rate"] * 1.05, 50.0)
            elif svc["status"] == "degraded":
                svc["minutes_degraded"] = svc.get("minutes_degraded", 0) + 2
                # Latency grows (3%/step), capped at 60s.
                svc["latency_p99_ms"] = min(svc["latency_p99_ms"] * 1.03, 60000.0)
                # Extreme latency eventually starts producing errors too.
                if svc["latency_p99_ms"] > 30000 and svc["error_rate"] < 1.0:
                    svc["error_rate"] = round(svc["error_rate"] + 0.5, 2)

    def _build_observation(
        self,
        last_action_result: Optional[str] = None,
        last_action_error: Optional[str] = None,
    ) -> Observation:
        """Build the agent-facing (partially observable) view of the world.

        Side effect: non-None `last_action_result` / `last_action_error`
        overwrite the sticky copies stored on the state. Logs are truncated
        to the last 2 lines per service, with a hint to use read_logs for
        the full history.
        """
        if last_action_result is not None: self.last_action_result = last_action_result
        if last_action_error is not None: self.last_action_error = last_action_error
        services = []
        for name, s in self.services.items():
            services.append(ServiceStatus(
                name=s["name"],
                status=s["status"],
                cpu_percent=s["cpu_percent"],
                memory_percent=s["memory_percent"],
                error_rate=round(s["error_rate"], 3),
                latency_p99_ms=round(s["latency_p99_ms"], 0),
                replicas_running=s["replicas_running"],
                replicas_desired=s["replicas_desired"],
                current_version=s["current_version"],
                last_deployed=s["last_deployed"],
                sla_breach=s.get("sla_breach", False),
                minutes_degraded=s.get("minutes_degraded", 0),
            ))
        alerts = [Alert(**a) for a in self.alerts]
        deps = [ServiceDependency(**d) for d in self.service_dependencies]
        evidence = [EvidenceEntry(**e) for e in self.evidence_log]
        sla = self._build_sla_status()
        return Observation(
            step=self.step,
            max_steps=self.max_steps,
            task_id=self.task_id,
            task_description=TASK_DESCRIPTIONS.get(self.task_id, ""),
            services=services,
            active_alerts=alerts,
            # Partial observability: last 2 lines only, plus a pointer to
            # the read_logs action when history was elided.
            recent_logs={
                svc: lines[-2:] + ([f"[... {len(lines)-2} more lines — use read_logs to see full history]"] if len(lines) > 2 else [])
                for svc, lines in self.logs.items()
            },
            available_runbooks=AVAILABLE_RUNBOOKS,
            service_dependencies=deps,
            evidence_log=evidence,
            sla_status=sla,
            last_action_result=self.last_action_result,
            last_action_error=self.last_action_error,
            incident_start_time=self.incident_start_time,
            elapsed_minutes=self.step * 2,
        )
@dataclass
class StepOutput:
    """Result of a single simulator transition, as returned by BaseTask.step."""
    next_state: InternalState   # state after the action was applied
    reward: float               # reward earned on this transition
    done: bool                  # True when the episode has terminated
    info: Dict[str, Any]        # auxiliary diagnostics for logging/debugging
def semantic_match(candidate: str, keywords: List[str], threshold: int = 1) -> bool:
    """Return True if `candidate` contains at least `threshold` keywords.

    Matching is case-insensitive; hyphens and underscores in BOTH the
    candidate and the keywords are normalised to spaces, so the keyword
    "memory_leak" matches the text "memory leak detected".

    Args:
        candidate: Free-form text to scan (e.g. an agent's diagnosis).
        keywords: Keyword phrases to look for.
        threshold: Minimum number of keyword hits required (default 1).

    Returns:
        True when at least `threshold` keywords occur in `candidate`.
    """
    if not candidate:
        return False

    def _norm(text: str) -> str:
        # Collapse separator variants so "a_b", "a-b" and "a b" compare equal.
        return text.lower().replace("-", " ").replace("_", " ")

    c = _norm(candidate)
    # BUG FIX: keywords previously had only hyphens normalised (not
    # underscores), so an underscore keyword could never match prose text —
    # contradicting this function's documented behaviour.
    hits = sum(1 for kw in keywords if _norm(kw) in c)
    return hits >= threshold
class BaseTask(ABC):
    """Abstract incident scenario: builds the initial world and advances it.

    Concrete tasks implement `initialize` and `step`; shared handling of
    read-only/diagnostic actions lives in the non-abstract helper methods.
    """

    def __init__(self, rng: random.Random):
        # Injected RNG so scenario generation is reproducible per episode.
        self.rng = rng

    @abstractmethod
    def initialize(self) -> InternalState:
        """Create the fully-populated starting state for this scenario."""
        pass

    @abstractmethod
    def step(self, state: InternalState, action: Action) -> StepOutput:
        """Apply `action` to `state` and return the transition result."""
        pass
    def _apply_action_to_logs(
        self, state: InternalState, action: Action
    ) -> tuple[Optional[str], Optional[str]]:
        """Handle the read-only / diagnostic actions against `state`.

        Returns a `(result, error)` pair: exactly one side is set for a
        handled action, and `(None, None)` when the action type is not a
        diagnostic one (remediation actions are handled elsewhere).
        Successful reads append an entry to `state.evidence_log` so the
        agent's investigation trail is visible in later observations.
        """
        at = action.action_type.value
        if at == "read_logs":
            # Dump the full log history for one service.
            svc = action.service
            if svc and svc in state.logs:
                lines = state.logs[svc]
                result = "\n".join(lines)
                # Add to evidence log
                state.evidence_log.append({
                    "step": state.step,
                    "source": f"logs:{svc}",
                    "summary": f"Read {len(lines)} log lines from {svc}",
                    "raw": result,
                })
                return result, None
            return None, f"No logs found for service '{svc}'"
        if at == "search_logs":
            # Case-insensitive substring search over one service's logs.
            svc = action.service
            query = (action.query or "").lower()
            if not svc or svc not in state.logs:
                return None, f"Unknown service '{svc}'"
            if not query:
                return None, "search_logs requires a query parameter"
            lines = state.logs[svc]
            matches = [l for l in lines if query in l.lower()]
            if not matches:
                result = f"No lines matching '{query}' in {svc} logs."
            else:
                result = f"Found {len(matches)} lines matching '{query}':\n" + "\n".join(matches)
            # Even a zero-match search is recorded as evidence.
            state.evidence_log.append({
                "step": state.step,
                "source": f"search:{svc}",
                "summary": f"Searched {svc} for '{query}': {len(matches)} matches",
                "raw": result,
            })
            return result, None
        if at == "read_metrics":
            # Render the raw service dict as a human-readable metrics panel.
            svc = action.service
            if svc and svc in state.services:
                s = state.services[svc]
                result = (
                    f"=== Metrics: {svc} ===\n"
                    f"Status: {s['status'].upper()}\n"
                    f"CPU: {s['cpu_percent']:.1f}%\n"
                    f"Memory: {s['memory_percent']:.1f}%\n"
                    f"Error rate: {s['error_rate']:.3f}/s\n"
                    f"P99 latency: {s['latency_p99_ms']:.0f}ms\n"
                    f"Replicas: {s['replicas_running']}/{s['replicas_desired']}\n"
                    f"Version: {s['current_version']}\n"
                    f"Last deploy: {s['last_deployed']}\n"
                    f"Degraded for: {s.get('minutes_degraded', 0)} minutes"
                )
                state.evidence_log.append({
                    "step": state.step,
                    "source": f"metrics:{svc}",
                    "summary": (
                        f"{svc}: {s['status']}, cpu={s['cpu_percent']:.0f}%, "
                        f"mem={s['memory_percent']:.0f}%, err={s['error_rate']:.2f}/s, "
                        f"ver={s['current_version']}"
                    ),
                    "raw": result,
                })
                return result, None
            return None, f"Unknown service '{svc}'"
        if at == "read_runbook":
            rb = action.runbook
            if rb in AVAILABLE_RUNBOOKS:
                content = self._load_runbook(rb)
                state.evidence_log.append({
                    "step": state.step,
                    "source": f"runbook:{rb}",
                    # Evidence keeps only a 200-char preview; the full text is
                    # returned to the agent directly.
                    "summary": f"Read runbook: {rb}",
                    "raw": content[:200],
                })
                return content, None
            return None, f"Runbook '{rb}' not found. Available: {AVAILABLE_RUNBOOKS}"
        if at == "acknowledge":
            # NOTE(review): the alert id is carried in action.service —
            # confirm this field reuse is intended rather than a typo.
            alert_id = action.service
            for a in state.alerts:
                if a["id"] == alert_id:
                    a["acknowledged"] = True
                    return f"Alert {alert_id} acknowledged.", None
            return None, f"Alert '{alert_id}' not found."
        if at == "noop":
            return "No action taken.", None
        # Not a diagnostic action: signal "unhandled" to the caller.
        return None, None
def _load_runbook(self, name: str) -> str:
import os
path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "runbooks", name)
try:
with open(path) as f:
return f.read()
except FileNotFoundError:
return f"[Runbook '{name}' not found]"
def _clamp(self, value: float) -> float:
return max(0.0, min(1.0, value))
def _penalty_blind_remediation(
self, state: InternalState, action: Action, fix_key: str
) -> float:
"""
Small penalty if agent remediates without any prior diagnosis.
Encourages evidence-gathering before action.
"""
if fix_key in state.rewards_given:
return 0.0
if "diagnose_correct" not in state.rewards_given and \
"diagnose_partial" not in state.rewards_given:
return -0.05
return 0.0