# Page header captured with the file (not part of the module):
# author SwapnilPatil28; commit 58af620 (verified)
# "Add LLM policy, SFT saving & LLM evaluation"
"""Baseline inference for the Incident Command Center environment.

Three policies are available:

- `HeuristicCoordinator` — a deterministic state machine that exercises the
  full action space, picks role-appropriate actors, and consults the
  observation's `investigation_targets` and `playbook_hints` so the heuristic
  adapts to whatever the server is currently serving. This is the default.
- `random_action` — a pure random baseline for comparison, enabled with
  `RANDOM_BASELINE=true`.
- An LLM-backed policy (base or fine-tuned checkpoint), enabled by setting
  `POLICY_MODEL` to a Hugging Face hub id or a local directory.

Running this script hits a deployed environment (local or Hugging Face Space,
selected with `ENV_URL`) and prints a structured trace the hackathon judges
can follow.
"""
from __future__ import annotations
import asyncio
import json
import os
import random
from typing import Dict, List, Optional
from client import IncidentCommandEnvClient
from models import IncidentAction, IncidentObservation
# Base URL of the environment server; falls back to a local instance.
ENV_URL = os.environ.get("ENV_URL", "http://127.0.0.1:8000")

# Benchmark identifier echoed in the [START] line and the final JSON summary.
BENCHMARK = "incident_command_center_env"

# RANDOM_BASELINE=true (any letter case) switches to the random policy.
RANDOM_BASELINE = os.environ.get("RANDOM_BASELINE", "false").lower() == "true"

# A non-empty value selects the LLM-backed policy (base or fine-tuned
# checkpoint) over the heuristic / random ones. Expected to be a HF hub id or
# a local directory; surrounding whitespace is ignored.
POLICY_MODEL = os.environ.get("POLICY_MODEL", "").strip()
# ---------------------------------------------------------------------------
# Logging helpers (structured line format, easy to grep)
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, policy: str) -> None:
    """Print the `[START]` trace line that opens an episode.

    Args:
        task: Name of the task being run.
        env: Benchmark / environment identifier.
        policy: Label of the policy driving the episode.
    """
    header = " ".join(("[START]", f"task={task}", f"env={env}", f"policy={policy}"))
    print(header, flush=True)
def log_step(
    step: int,
    action: str,
    reward: float,
    done: bool,
    error: Optional[str] = None,
    components: Optional[Dict[str, float]] = None,
) -> None:
    """Print one `[STEP]` trace line.

    Args:
        step: 1-based step counter supplied by the caller.
        action: Pre-rendered action label.
        reward: Reward for this step, printed with sign and two decimals.
        done: Episode-finished flag, printed lowercase (`true` / `false`).
        error: Optional error text; a falsy value is printed as `null`.
        components: Optional reward breakdown, printed as comma-separated
            `name=+0.00` pairs; a missing or empty mapping is printed as `-`.
    """
    if components:
        rendered_components = ",".join(
            f"{name}={value:+.2f}" for name, value in components.items()
        )
    else:
        rendered_components = "-"

    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:+.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
        f"components={rendered_components}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the `[END]` trace line that closes an episode.

    Args:
        success: Episode verdict, printed lowercase.
        steps: Number of steps taken.
        score: Episode score, printed with sign and three decimals.
        rewards: Per-step rewards, printed comma-separated with sign and two
            decimals (empty string when there are none).
    """
    reward_trail = ",".join(map("{:+.2f}".format, rewards))
    summary = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:+.3f} rewards={reward_trail}"
    )
    print(summary, flush=True)
# ---------------------------------------------------------------------------
# Heuristic coordinator
# ---------------------------------------------------------------------------
class HeuristicCoordinator:
"""Deterministic multi-agent playbook agent.
The state machine runs per incident and picks the correct specialist for
each action so it never eats the wrong-actor penalty:
1. Triage inspects logs + metrics using observation-provided targets.
2. Investigator consults a KB article for the playbook.
3. Ops Manager negotiates handoff to the owner the incident expects.
4. Investigator applies a fix matched to inferred root cause.
5. Ops Manager submits a postmortem when the incident marks it required.
6. Ops Manager closes the incident with the inferred root cause.
"""
def __init__(self) -> None:
self._phase_by_incident: Dict[str, int] = {}
self._root_cause_by_incident: Dict[str, str] = {}
def select_action(self, observation: IncidentObservation) -> IncidentAction:
incident_id = observation.incident_id
phase = self._phase_by_incident.get(incident_id, 0)
targets = observation.investigation_targets or {}
log_targets = targets.get("logs", []) or []
metric_targets = targets.get("metrics", []) or []
kb_targets = targets.get("kb", []) or observation.playbook_hints
# Haystack of all visible text we can mine for clues.
haystack = " ".join(
[
observation.incident_title or "",
observation.incident_description or "",
observation.terminal_output or "",
" ".join(observation.visible_signals or []),
]
).lower()
if phase == 0 and log_targets:
self._phase_by_incident[incident_id] = 1
return IncidentAction(
actor="triage_agent",
action_type="inspect_logs",
target=self._best_target(haystack, log_targets),
reason="Initial triage: scan top logs for failure signature.",
)
if phase <= 1 and metric_targets:
self._phase_by_incident[incident_id] = 2
return IncidentAction(
actor="triage_agent",
action_type="inspect_metrics",
target=self._best_target(haystack, metric_targets),
reason="Correlate logs with dashboards.",
)
if phase <= 2 and kb_targets:
self._phase_by_incident[incident_id] = 3
return IncidentAction(
actor="investigator_agent",
action_type="consult_kb",
target=self._best_target(haystack, list(kb_targets)),
reason="Review runbook for candidate fix.",
)
if phase <= 3:
self._phase_by_incident[incident_id] = 4
owner = self._infer_owner(haystack, observation.customer_tier)
return IncidentAction(
actor="ops_manager_agent",
action_type="negotiate_handoff",
target=owner,
reason="Route to accountable specialist.",
)
if phase <= 4:
self._phase_by_incident[incident_id] = 5
guess = self._infer_root_cause(haystack)
self._root_cause_by_incident[incident_id] = guess
return IncidentAction(
actor="investigator_agent",
action_type="apply_fix",
resolution_summary=self._generate_fix_plan(guess),
reason=f"Attempt mitigation for {guess}",
)
if phase <= 5 and observation.postmortem_required and not observation.postmortem_submitted:
self._phase_by_incident[incident_id] = 6
guess = self._root_cause_by_incident.get(
incident_id, self._infer_root_cause(haystack)
)
return IncidentAction(
actor="ops_manager_agent",
action_type="submit_postmortem",
postmortem_note=(
f"Incident {incident_id}: identified root cause {guess}. "
"Mitigation applied. Follow-up actions queued for "
"reliability review."
),
reason="High-impact incident — postmortem required.",
)
guess = self._root_cause_by_incident.get(
incident_id, self._infer_root_cause(haystack)
)
return IncidentAction(
actor="ops_manager_agent",
action_type="close_incident",
root_cause=guess,
resolution_summary=f"Closed with hypothesis {guess}.",
confidence=0.75,
reason="Enough evidence gathered to close incident.",
)
# -- helpers ------------------------------------------------------------
def _best_target(self, haystack: str, candidates: List[str]) -> str:
"""Pick the candidate target whose tokens most overlap with the haystack."""
best = candidates[0]
best_score = -1
for candidate in candidates:
score = sum(1 for token in candidate.lower().split("-") if token in haystack)
if score > best_score:
best = candidate
best_score = score
return best
def _infer_owner(self, haystack: str, tier: str) -> str:
if tier == "enterprise":
return "ops_manager_agent"
if any(
token in haystack
for token in ["deploy", "rate", "sla", "rotation", "cert", "mtls"]
):
return "ops_manager_agent"
if any(
token in haystack
for token in ["schema", "export", "cache", "inventory", "search", "ranking"]
):
return "investigator_agent"
return "triage_agent"
def _infer_root_cause(self, haystack: str) -> str:
table = [
(("redis", "pool"), "redis_connection_pool_exhausted"),
(("jwt",), "jwt_clock_skew_mismatch"),
(("token", "clock"), "jwt_clock_skew_mismatch"),
(("spf",), "spf_record_misconfiguration"),
(("cache", "invalidation"), "cache_invalidation_topic_lag"),
(("timezone",), "timezone_normalization_bug"),
(("offset",), "timezone_normalization_bug"),
(("idempotency",), "idempotency_key_regression"),
(("duplicate", "invoice"), "idempotency_key_regression"),
(("mtls",), "mtls_cert_chain_mismatch"),
(("certificate", "chain"), "mtls_cert_chain_mismatch"),
(("feature", "flag"), "feature_flag_scope_misconfigured"),
(("429",), "rate_limit_misconfigured_for_promo_segment"),
(("promo",), "rate_limit_misconfigured_for_promo_segment"),
(("schema", "drift"), "schema_version_drift"),
(("schema", "mismatch"), "schema_version_drift"),
(("dedupe",), "dedupe_rule_disabled"),
(("alert", "storm"), "dedupe_rule_disabled"),
(("out-of-order",), "event_ordering_race_condition"),
(("oversell",), "event_ordering_race_condition"),
(("deadlock",), "lock_escalation_on_reporting_view"),
(("reporting", "lock"), "lock_escalation_on_reporting_view"),
]
for tokens, guess in table:
if all(tok in haystack for tok in tokens):
return guess
return "unknown"
def _generate_fix_plan(self, root_cause: str) -> str:
fixes = {
"redis_connection_pool_exhausted": "increase redis pool and recycle stale connections",
"jwt_clock_skew_mismatch": "sync clock tolerance and increase jwt leeway",
"spf_record_misconfiguration": "fix spf record and align sending domain",
"cache_invalidation_topic_lag": "scale invalidation consumer and replay partition 3",
"timezone_normalization_bug": "patch timezone parser and use iana timezone map",
"idempotency_key_regression": "restore idempotency guard and persist retry token first",
"mtls_cert_chain_mismatch": "reissue certificate chain with full intermediate chain",
"feature_flag_scope_misconfigured": "rollback feature flag and restrict experiment segment",
"rate_limit_misconfigured_for_promo_segment": (
"hotfix promo segment rate limits and enable exponential backoff"
),
"schema_version_drift": "enforce schema negotiation and pin serializer to v11",
"dedupe_rule_disabled": "restore dedupe rule and replay critical fingerprints",
"event_ordering_race_condition": "enable sequence guards and quarantine out-of-order events",
"lock_escalation_on_reporting_view": (
"offload reporting to replica and schedule reporting off-peak"
),
}
return fixes.get(root_cause, "collect additional diagnostics and rollback last change")
# ---------------------------------------------------------------------------
# Random baseline
# ---------------------------------------------------------------------------
def random_action(observation: IncidentObservation) -> IncidentAction:
    """Build a uniformly random action for the random baseline.

    The action type and actor are drawn from the choices advertised by the
    observation, with built-in defaults when it advertises none. The target
    is drawn from every advertised investigation target plus a fixed set of
    names. `root_cause` and `resolution_summary` are constant placeholders.

    Uses the module-level `random` generator, so seeding it makes the
    sequence reproducible.
    """
    action_type = random.choice(observation.available_actions or ["inspect_logs"])
    teams = observation.available_teams or [
        "triage_agent",
        "investigator_agent",
        "ops_manager_agent",
    ]
    actor = random.choice(teams)

    targets_pool: List[str] = []
    for values in (observation.investigation_targets or {}).values():
        # A tool may map to None instead of a list; HeuristicCoordinator
        # tolerates that too, so treat it as "no targets" rather than crash.
        targets_pool.extend(values or [])
    # Fixed extras keep the pool non-empty even with no advertised targets.
    targets_pool.extend(
        ["payments-api", "auth-service", "dash-auth", "dash-redis", "kb-rate-limits"]
    )
    random_target = random.choice(targets_pool)

    return IncidentAction(
        actor=actor,  # type: ignore[arg-type]
        action_type=action_type,  # type: ignore[arg-type]
        target=random_target,
        root_cause="unknown",
        resolution_summary="random baseline action",
    )
# ---------------------------------------------------------------------------
# Episode driver
# ---------------------------------------------------------------------------
async def run_task(task_name: str, llm_policy=None) -> None:
    """Run one episode of `task_name` and print its structured trace.

    Policy selection: `llm_policy` when given, else the random baseline when
    `RANDOM_BASELINE` is set, else a fresh `HeuristicCoordinator`.

    The score is the mean per-step reward and the episode counts as a success
    when it exceeds 0.1. Both are computed only when the episode runs to
    completion; the `[END]` line is printed either way.

    Args:
        task_name: Task identifier passed to `env.reset`.
        llm_policy: Optional policy object exposing `select_action(observation)`;
            its `label` attribute, if present, is used in the trace.

    Note:
        The environment client is used through its synchronous wrapper, so
        this coroutine contains no awaits and runs the episode to completion
        in one go.
    """
    import sys  # local import: only needed for the cleanup warning below

    env = IncidentCommandEnvClient(base_url=ENV_URL).sync()
    if llm_policy is not None:
        policy_name = f"llm:{getattr(llm_policy, 'label', POLICY_MODEL)}"
    elif RANDOM_BASELINE:
        policy_name = "random_baseline"
    else:
        policy_name = "heuristic_coordinator"
        coordinator = HeuristicCoordinator()

    log_start(task=task_name, env=BENCHMARK, policy=policy_name)

    rewards: List[float] = []
    steps_taken = 0
    success = False
    score = 0.0
    try:
        res = env.reset(task_name=task_name)
        while not res.done:
            steps_taken += 1
            if llm_policy is not None:
                action = llm_policy.select_action(res.observation)
            elif RANDOM_BASELINE:
                action = random_action(res.observation)
            else:
                action = coordinator.select_action(res.observation)

            res = env.step(action)
            # A missing reward is logged and scored as 0.0.
            reward = float(res.reward or 0.0)
            rewards.append(reward)
            log_step(
                step=steps_taken,
                action=f"{action.actor}:{action.action_type}:{action.target or '-'}",
                reward=reward,
                done=res.done,
                components=getattr(res.observation, "reward_components", None),
            )

        score = sum(rewards) / len(rewards) if rewards else 0.0
        success = score > 0.1
    finally:
        try:
            env.close()
        except Exception as exc:
            # Cleanup stays best-effort (it must not mask the episode outcome),
            # but report the failure on stderr instead of dropping it silently.
            print(f"[WARN] env.close() failed: {exc!r}", file=sys.stderr, flush=True)
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
def main() -> None:
    """Run the easy, medium and hard tasks in order and print a JSON summary.

    When `POLICY_MODEL` is set, an `LLMPolicy` is loaded once and shared by
    all three tasks. It is released when the run ends, including when a task
    raises, so the model is not left allocated after a failed run.
    """
    import sys  # local import: only needed for the cleanup warning below

    llm_policy = None
    if POLICY_MODEL:
        # Imported lazily so the heuristic / random baselines do not require
        # the LLM policy module to be importable.
        from llm_policy import LLMPolicy

        llm_policy = LLMPolicy(POLICY_MODEL, label=POLICY_MODEL)

    try:
        for task in ("easy", "medium", "hard"):
            asyncio.run(run_task(task, llm_policy=llm_policy))

        if llm_policy is not None:
            policy_label = f"llm:{POLICY_MODEL}"
        elif RANDOM_BASELINE:
            policy_label = "random_baseline"
        else:
            policy_label = "heuristic_coordinator"

        print(
            json.dumps(
                {
                    "benchmark": BENCHMARK,
                    "policy": policy_label,
                    "env_url": ENV_URL,
                },
                indent=2,
            )
        )
    finally:
        if llm_policy is not None:
            try:
                llm_policy.release()
            except Exception as exc:
                # Best-effort release: warn rather than fail the whole run.
                print(
                    f"[WARN] llm_policy.release() failed: {exc!r}",
                    file=sys.stderr,
                    flush=True,
                )


if __name__ == "__main__":
    main()