| """ |
Multi-Agent Inference Script — Incident Post-Mortem Writer (OpenEnv)
=====================================================================
Demonstrates the Phase 1 multi-agent extension end-to-end.

FLOW PER EPISODE:
  1. QUERY_LOGS                  (find evidence)
  2. WRITE_SECTION summary       (primary agent)
  3. WRITE_SECTION root_cause    (primary agent)
  4. REQUEST_REVIEW              (skeptic critiques the draft)
  5. REVISE_SECTION              (primary revises root_cause based on critique;
                                  skipped when the skeptic returns no critique)
  6. WRITE_SECTION timeline
  7. WRITE_SECTION impact
  8. WRITE_SECTION action_items
  9. ASSIGN_ACTION_ITEM
 10. SUBMIT                      (grader includes collaboration_score bonus)

OUTPUT EXAMPLE (abridged; one [STEP] record is printed for every step):
  [START] task=hard env=incident-postmortem-writer model=gpt-4o-mini mode=multi-agent
  [STEP] step=4 action=REQUEST_REVIEW reward=0.04 done=false error=null
   >> Skeptic: Your root cause blames CDN but alerts show CDN healthy...
  [STEP] step=5 action=REVISE_SECTION_root_cause reward=0.06 done=false error=null
  [STEP] step=10 action=SUBMIT reward=... done=true error=null
  [END] success=true steps=10 score=... rewards=...

REQUIRED env vars (same as inference.py):
  API_BASE_URL, MODEL_NAME, HF_TOKEN, ENV_BASE_URL
Plus (for the skeptic running SERVER-side, optional, has fallback):
  SKEPTIC_API_KEY, SKEPTIC_API_BASE_URL, SKEPTIC_MODEL_NAME

USAGE:
  python inference_multiagent.py
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| import sys |
| import time |
| from typing import Any, Dict, List, Optional |
|
|
| from openai import OpenAI |
| import requests |
|
|
| |
# --- Runtime configuration -------------------------------------------------
# Every value below can be overridden through an environment variable of the
# same name; the defaults target a locally running environment server.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")

# Sampling parameters passed to every chat-completion request.
TEMPERATURE = 0.0
MAX_TOKENS = 1500

# Difficulty levels evaluated by main(), in order.
DIFFICULTIES = ["easy", "medium", "hard", "expert"]

# Shared chat-completions client. A placeholder key is passed when HF_TOKEN
# is empty so that constructing the client never receives an empty api_key.
client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)

# Value reported as `env=` in the [START] log record.
BENCHMARK = "incident-postmortem-writer"
|
|
|
|
def log_start(task: str, env: str, model: str) -> None:
    """Print the single-line ``[START]`` record that opens an episode.

    Args:
        task: Task/difficulty label of the episode.
        env: Benchmark (environment) name.
        model: Name of the model driving the primary agent.
    """
    # flush=True keeps the record ordered when stdout is piped to a parser.
    print(f"[START] task={task} env={env} model={model} mode=multi-agent", flush=True)
|
|
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one single-line ``[STEP]`` record.

    Args:
        step: 1-based step index within the episode.
        action: Label of the action that was executed.
        reward: Step reward; printed with two decimals.
        done: Whether the episode is finished; printed as ``true``/``false``.
        error: Error description, or ``None``/empty for no error (printed as
            ``null``).
    """
    # Exception messages can span several lines; collapse all whitespace so
    # the record always stays on exactly one line.
    error_val = " ".join(str(error).split()) if error else ""
    if not error_val:
        error_val = "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )
|
|
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the single-line ``[END]`` record that closes an episode.

    Args:
        success: Whether the episode counts as successful.
        steps: Number of steps that were executed.
        score: Final episode score; printed with three decimals.
        rewards: Per-step rewards; printed comma-separated with two decimals.
    """
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )
|
|
|
|
class PostMortemEnv:
    """Thin HTTP client for the post-mortem environment server.

    Requests go through one ``requests.Session``. The object can be used as a
    context manager (or ``close()`` can be called) to release the session.
    """

    # Timeouts, in seconds, for state-changing calls and for the health probe.
    REQUEST_TIMEOUT = 30
    HEALTH_TIMEOUT = 5

    def __init__(self, base_url: str):
        """Store the server base URL (trailing slashes removed) and open a session."""
        self.base_url = base_url.rstrip("/")
        self._session = requests.Session()

    def __enter__(self) -> "PostMortemEnv":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``path`` and return the decoded JSON body.

        Raises whatever ``raise_for_status()`` raises on a 4xx/5xx response.
        """
        response = self._session.post(
            f"{self.base_url}{path}", json=payload, timeout=self.REQUEST_TIMEOUT
        )
        response.raise_for_status()
        return response.json()

    def reset(self, difficulty: str = "easy") -> Dict[str, Any]:
        """Start a new episode at ``difficulty`` and return the server's JSON reply."""
        return self._post("/reset", {"difficulty": difficulty})

    def step(self, action: Dict[str, Any]) -> Dict[str, Any]:
        """Send one action dict to ``/step`` and return the server's JSON reply."""
        return self._post("/step", action)

    def health(self) -> bool:
        """Return True only when ``GET /health`` answers with HTTP 200.

        Transport-level failures (connection refused, timeout, bad URL, ...)
        are reported as ``False`` instead of being raised.
        """
        try:
            response = self._session.get(
                f"{self.base_url}/health", timeout=self.HEALTH_TIMEOUT
            )
        except requests.RequestException:
            return False
        return response.status_code == 200
|
|
|
|
def call_llm(system: str, user: str) -> str:
    """Run one chat completion and return the assistant's text.

    Args:
        system: System prompt.
        user: User prompt.

    Returns:
        The content of the first choice, or ``""`` when the content is empty
        or the request fails for any reason. This function never raises, so
        callers are expected to treat an empty string as "no answer" and fall
        back to deterministic content.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        return completion.choices[0].message.content or ""
    except Exception as exc:  # deliberate best-effort boundary around the API call
        # Flushed like the structured log records so the message is not
        # reordered relative to them when stdout is piped.
        print(f" [LLM error] {exc}", flush=True)
        return ""
|
|
|
|
# Matches a JSON object wrapped in a markdown code fence (``` or ```json).
_FENCED_JSON_RE = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)


def extract_json(text: str) -> Optional[Dict]:
    """Extract the first JSON *object* found in an LLM reply.

    Attempts, in order:
      1. the whole (stripped) text as JSON,
      2. the contents of a fenced code block,
      3. every ``{`` in the text as a possible start of an embedded object.

    Only ``dict`` results are accepted, so a reply that is valid JSON but not
    an object (a bare string, number or list) never reaches callers that go on
    to call ``.get`` on the result.

    Args:
        text: Raw model output; anything that is not a non-empty string
            yields ``None``.

    Returns:
        The decoded object, or ``None`` when no JSON object can be decoded.
    """
    if not isinstance(text, str) or not text.strip():
        return None

    candidates = [text.strip()]
    fenced = _FENCED_JSON_RE.search(text)
    if fenced:
        candidates.append(fenced.group(1))

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:
            continue
        if isinstance(parsed, dict):
            return parsed

    # raw_decode parses one complete JSON value starting at the given index
    # and ignores trailing text, so nested braces are handled correctly
    # (a non-greedy regex would stop at the first closing brace).
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            parsed, _end = decoder.raw_decode(text, start)
        except ValueError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed
        start = text.find("{", start + 1)
    return None
|
|
|
|
# System prompt for the log-query step: the model must answer with a JSON
# object holding the service to query and an HH:MM time window.
QUERY_SYSTEM = """You are an expert SRE. Given incident alerts and Slack messages,
identify the best service and time window to query for root cause evidence.
Respond with ONLY valid JSON: {"service": "<service_name>", "from": "<HH:MM>", "to": "<HH:MM>"}

STRATEGY - follow in this exact order:
1. Look for DEPLOYMENT or CONFIG CHANGE in Slack (keywords: deploy, TTL, migration, release, config, schema).
If found, query THAT service at THAT deployment time. Deployments are almost always root cause.
2. If no deployment, identify which service changed behavior FIRST and trace upstream dependencies.
3. Pick a 5-8 minute window AROUND the deployment or first change time.
4. NEVER query the most-alerted service - it is usually a victim not the cause."""


# System prompt shared by every WRITE_SECTION call; output must be plain text.
WRITE_SYSTEM = """You are an expert SRE writing one section of an incident post-mortem.
Write ONLY the section content - no JSON, no section labels, just plain text.
Be specific and factual. Use exact service names and timestamps from the evidence."""


# Per-section instruction appended to the user prompt, keyed by section name.
SECTION_PROMPTS = {
    "summary": "Write 2-3 sentences summarizing the incident. MUST explicitly name the affected service.",
    "timeline": "Write a chronological timeline with 5+ timestamped events in format 'HH:MM - what happened'.",
    "root_cause": "Write root cause analysis. MUST name: (1) which service failed, (2) type of failure (deployment bug / config error / connection leak / etc), (3) specific technical details.",
    "impact": "Write impact assessment of at least 30 words. Include: affected services, outage duration, users affected, business/revenue impact.",
    "action_items": "Write 3 numbered action items. Example: '1. Fix X - Owner: payments-team - Due: 2024-08-01'. Owner must be a team or person from Slack.",
}


# System prompt for rewriting a section in response to the skeptic's critique.
REVISE_SYSTEM = """You are an expert SRE revising a post-mortem section based on a senior reviewer's critique.
The reviewer has identified a specific problem with your previous draft.
Your job is to address that critique by rewriting the section.

RULES:
- Write ONLY the revised section content - no labels, no JSON.
- Must be SUBSTANTIALLY different from the original (at least 30 characters changed).
- Must directly address the critique, not avoid it.
- Keep factual claims that were correct; fix the ones that were challenged."""
|
|
|
|
| def _fallback_section(section: str, observation: Dict, logs_found: List) -> str: |
| alerts = observation.get("alerts", []) |
| services = list({a["service"] for a in alerts}) |
| main_svc = services[0] if services else "payments" |
| t_start = alerts[0]["timestamp"][:5] if alerts else "00:00" |
| t_end = alerts[-1]["timestamp"][:5] if alerts else "01:00" |
| return { |
| "summary": f"The {main_svc} service experienced a significant incident affecting production traffic. The on-call team investigated and resolved the issue.", |
| "timeline": f"{t_start} - First alert for {main_svc}\n{alerts[len(alerts)//2]['timestamp'][:5] if alerts else '00:15'} - On-call engaged\n{t_end} - Service recovery confirmed", |
| "root_cause": f"Root cause: The {main_svc} service experienced a deployment bug or configuration error that caused service degradation affecting production traffic.", |
| "impact": f"The {main_svc} service was unavailable or degraded for approximately 30 minutes. Production users experienced errors and timeouts. Business impact included user-facing failures and potential revenue loss.", |
| "action_items": "1. Add monitoring for the affected service - Owner: sre - Due: next sprint\n2. Review deploy process - Owner: platform - Due: 2024-08-01\n3. Post-mortem review with team - Owner: sre - Due: next sprint", |
| }.get(section, f"Analysis of {main_svc} service incident.") |
|
|
|
|
def do_query(env, observation):
    """Ask the model which service and time window to inspect, then run QUERY_LOGS.

    Args:
        env: Environment client exposing ``step(action_dict)``.
        observation: Current observation; ``incident_title``, ``alerts`` and
            ``slack_thread`` are read to build the prompt.

    Returns:
        Whatever ``env.step`` returns for the QUERY_LOGS action.

    When the model's reply contains no JSON object with a non-empty
    ``service``, the query falls back to the first alerting service (or
    ``"payments"`` when there are no alerts) over the whole day.
    """
    alerts = observation.get("alerts") or []
    slack = observation.get("slack_thread") or []

    alerts_text = "\n".join(
        f"[{a['timestamp']}] [{a['severity']}] {a['service']}: {a['message']}"
        for a in alerts
    )
    slack_text = "\n".join(
        f"[{m['timestamp']}] {m['author']}: {m['text']}"
        for m in slack
    )
    # De-duplicate while keeping first-seen order so that both the prompt and
    # the fallback service are stable between runs (set order is not).
    services = list(dict.fromkeys(a["service"] for a in alerts))
    default_service = services[0] if services else "payments"

    user_prompt = (
        f"INCIDENT: {observation.get('incident_title', '')}\n"
        f"ALERTS:\n{alerts_text}\n"
        f"SLACK:\n{slack_text}\n"
        f"Available services: {services}\n"
        "Which service and time window to query for root cause?"
    )

    query = extract_json(call_llm(QUERY_SYSTEM, user_prompt))
    # Discard anything that is not an object naming a service; the window the
    # model picked is meaningless without the service it was picked for.
    if not (isinstance(query, dict) and query.get("service")):
        query = {}

    return env.step({
        "action_type": "QUERY_LOGS",
        "query_service": query.get("service") or default_service,
        "query_from": query.get("from") or "00:00",
        "query_to": query.get("to") or "23:59",
    })
|
|
|
|
def write_section(env, observation, section: str, logs_found: List) -> Dict:
    """Draft one post-mortem section with the model and submit it via WRITE_SECTION.

    Args:
        env: Environment client exposing ``step(action_dict)``.
        observation: Current observation; ``incident_title``, ``alerts`` and
            ``slack_thread`` are read to build the prompt.
        section: Name of the section to write.
        logs_found: Log entries retrieved earlier (may be empty or ``None``).

    Returns:
        Whatever ``env.step`` returns for the WRITE_SECTION action.

    If the model's reply is shorter than 20 characters or looks like JSON
    (starts with ``{``), deterministic fallback text is submitted instead.
    """
    alerts_text = "\n".join(
        f"[{a['timestamp']}] [{a['severity']}] {a['service']}: {a['message']}"
        for a in observation.get("alerts") or []
    )
    slack_text = "\n".join(
        f"[{m['timestamp']}] {m['author']}: {m['text']}"
        for m in observation.get("slack_thread") or []
    )
    if logs_found:
        logs_text = "\n".join(
            f"[{entry['timestamp']}] [{entry['severity']}] {entry['service']}: {entry['message']}"
            for entry in logs_found
        )
    else:
        logs_text = "(no logs retrieved)"

    base_context = (
        f"INCIDENT: {observation.get('incident_title', '')}\n"
        f"ALERTS:\n{alerts_text}\n"
        f"SLACK:\n{slack_text}\n"
        f"RETRIEVED LOGS:\n{logs_text}"
    )

    # Unknown section names get a generic instruction instead of a KeyError,
    # matching the fallback text generator, which also accepts any name.
    instruction = SECTION_PROMPTS.get(section) or f"Write the {section} section of the post-mortem."
    user_prompt = (
        f"{base_context}\n\nWRITE THE '{section.upper()}' SECTION:\n"
        f"{instruction}\n\nSection content:"
    )

    content = call_llm(WRITE_SYSTEM, user_prompt).strip()
    # An empty reply is covered by the length check.
    if len(content) < 20 or content.startswith("{"):
        content = _fallback_section(section, observation, logs_found)

    return env.step({
        "action_type": "WRITE_SECTION",
        "section_name": section,
        "section_content": content,
    })
|
|
|
|
def request_review(env) -> Dict[str, Any]:
    """Send a REQUEST_REVIEW action and return the environment's response.

    Args:
        env: Environment client exposing ``step(action_dict)``.
    """
    return env.step({"action_type": "REQUEST_REVIEW"})
|
|
|
|
def revise_section_via_llm(env, observation, critique: str, section_name: str, critique_index: int) -> Dict:
    """Rewrite a section in response to a critique and submit it via REVISE_SECTION.

    Args:
        env: Environment client exposing ``step(action_dict)``.
        observation: Current observation; ``sections``, ``incident_title``,
            ``alerts`` and ``slack_thread`` are read.
        critique: The reviewer's critique text.
        section_name: Name of the section to revise.
        critique_index: Index of the critique being addressed; forwarded as
            ``critique_addressed_index``.

    Returns:
        Whatever ``env.step`` returns for the REVISE_SECTION action.

    If the model's reply is shorter than 40 characters, a fallback revision is
    built from the current section text plus the first 150 characters of the
    critique.
    """
    # Tolerate a non-string critique payload; slicing below needs a str.
    critique_text = critique if isinstance(critique, str) else str(critique)

    # Current draft of the section, or "" when it has not been written yet.
    current_content = next(
        (
            s.get("content", "") or ""
            for s in observation.get("sections") or []
            if s.get("name") == section_name
        ),
        "",
    )

    alerts_text = "\n".join(
        f"[{a['timestamp']}] [{a['severity']}] {a['service']}: {a['message']}"
        for a in observation.get("alerts") or []
    )
    slack_text = "\n".join(
        f"[{m['timestamp']}] {m['author']}: {m['text']}"
        for m in observation.get("slack_thread") or []
    )

    heading = section_name.upper()
    user_prompt = (
        f"INCIDENT: {observation.get('incident_title', '')}\n"
        f"ALERTS:\n{alerts_text}\n"
        f"SLACK:\n{slack_text}\n\n"
        f"ORIGINAL {heading} SECTION:\n{current_content}\n\n"
        f"REVIEWER CRITIQUE:\n{critique_text}\n\n"
        f"Write the revised {heading} section that addresses this critique:"
    )
    revised = call_llm(REVISE_SYSTEM, user_prompt).strip()

    # An empty reply is covered by the length check.
    if len(revised) < 40:
        revised = (
            f"REVISED: {current_content} Additionally, based on reviewer "
            f"feedback: {critique_text[:150]}"
        )

    return env.step({
        "action_type": "REVISE_SECTION",
        "section_name": section_name,
        "section_content": revised,
        "critique_addressed_index": critique_index,
    })
|
|
|
|
# An episode counts as successful when its final grade exceeds this value.
_SUCCESS_THRESHOLD = 0.3


def _reward_total(result: Dict[str, Any]) -> float:
    """Return the scalar reward of one env response (0.0 when absent or invalid).

    Accepts both ``{"reward": {"total": x}}`` and a bare numeric or null
    ``reward`` field.
    """
    reward = result.get("reward")
    if isinstance(reward, dict):
        reward = reward.get("total")
    try:
        return float(reward or 0.0)
    except (TypeError, ValueError):
        return 0.0


def run_multiagent_episode(env: PostMortemEnv, difficulty: str) -> float:
    """Run one full multi-agent episode and return its final score.

    Sequence: QUERY_LOGS, WRITE_SECTION (summary, root_cause), REQUEST_REVIEW,
    REVISE_SECTION root_cause (only when a critique came back), WRITE_SECTION
    (timeline, impact, action_items), ASSIGN_ACTION_ITEM, SUBMIT.

    Args:
        env: Environment client.
        difficulty: Difficulty level passed to ``env.reset``.

    Returns:
        ``total_score`` from the grade returned on SUBMIT, or 0.0 when the
        episode failed or no grade was returned. Exceptions raised during the
        episode are caught, logged as an ERROR step, and do not propagate.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f" Task: {difficulty.upper()} (multi-agent)")
    print(banner)

    log_start(task=difficulty, env=BENCHMARK, model=MODEL_NAME)

    step_rewards: List[float] = []  # one entry per completed step
    final_score = 0.0
    success = False

    def record(action_name: str, result: Dict[str, Any]) -> float:
        """Store the step's reward, emit its [STEP] record, and return the reward."""
        reward = _reward_total(result)
        step_rewards.append(reward)
        log_step(
            step=len(step_rewards),
            action=action_name,
            reward=reward,
            done=bool(result.get("done", False)),
            error=None,
        )
        return reward

    def next_step() -> int:
        """1-based index of the step about to run."""
        return len(step_rewards) + 1

    try:
        result = env.reset(difficulty=difficulty)
        observation = result["observation"]
        print(f" Incident: {observation.get('incident_title','')}")
        print(f" Alerts: {len(observation.get('alerts',[]))} | Slack: {len(observation.get('slack_thread',[]))}")

        # Gather evidence.
        print("\n -- Step 1: QUERY_LOGS --")
        result = do_query(env, observation)
        observation = result["observation"]
        r = record("QUERY_LOGS", result)
        logs_found = observation.get("retrieved_logs") or []
        print(f" reward={r:+.3f} | retrieved {len(logs_found)} logs")

        # Primary agent drafts the sections the skeptic will review.
        for section in ["summary", "root_cause"]:
            print(f"\n -- Step {next_step()}: WRITE_SECTION {section} (primary agent) --")
            result = write_section(env, observation, section, logs_found)
            observation = result["observation"]
            r = record(f"WRITE_SECTION_{section}", result)
            print(f" reward={r:+.3f}")

        # Skeptic reviews the draft.
        print(f"\n -- Step {next_step()}: REQUEST_REVIEW (skeptic critiques) --")
        result = request_review(env)
        observation = result["observation"]
        r = record("REQUEST_REVIEW", result)

        critiques = observation.get("skeptic_critiques") or []
        critique = critiques[-1] if critiques else None
        if critiques:
            print(f" reward={r:+.3f}")
            print(f" >> Skeptic: {str(critique)[:180]}")
        else:
            print(f" reward={r:+.3f} | No critique returned")

        # Primary agent revises root_cause against the latest critique.
        if critique:
            print(f"\n -- Step {next_step()}: REVISE_SECTION root_cause (primary revises) --")
            result = revise_section_via_llm(
                env, observation, critique,
                section_name="root_cause",
                critique_index=len(critiques) - 1,
            )
            observation = result["observation"]
            r = record("REVISE_SECTION_root_cause", result)
            print(f" reward={r:+.3f} | critiques_addressed={observation.get('critiques_addressed', 0)}")

        # Remaining sections.
        for section in ["timeline", "impact", "action_items"]:
            print(f"\n -- Step {next_step()}: WRITE_SECTION {section} --")
            result = write_section(env, observation, section, logs_found)
            observation = result["observation"]
            r = record(f"WRITE_SECTION_{section}", result)
            print(f" reward={r:+.3f}")

        print(f"\n -- Step {next_step()}: ASSIGN_ACTION_ITEM --")
        result = env.step({
            "action_type": "ASSIGN_ACTION_ITEM",
            "action_item_description": "Prevent recurrence of incident - implement fixes and monitoring",
            "action_item_owner": "sre",
            "action_item_due_date": "next sprint",
        })
        observation = result["observation"]
        record("ASSIGN_ACTION_ITEM", result)

        print(f"\n -- Step {next_step()}: SUBMIT --")
        result = env.step({"action_type": "SUBMIT"})
        record("SUBMIT", result)

        grade = (result.get("info") or {}).get("grade")
        if grade:
            def score_of(key: str) -> float:
                """Numeric grade field, treating missing or null as 0.0."""
                return float(grade.get(key) or 0.0)

            # Coerced to float so that a null total_score cannot make
            # log_end's numeric formatting fail after the try block.
            final_score = score_of("total_score")
            success = final_score > _SUCCESS_THRESHOLD
            print(f"\n FINAL GRADE: {final_score:.3f}")
            print(f" root_cause={score_of('root_cause_score'):.2f} "
                  f"timeline={score_of('timeline_score'):.2f} "
                  f"action_items={score_of('action_items_score'):.2f} "
                  f"impact={score_of('impact_score'):.2f} "
                  f"completeness={score_of('completeness_score'):.2f}")
            received = grade.get("critiques_received") or 0
            if received > 0:
                print(f" collaboration={score_of('collaboration_score'):.2f} "
                      f"({grade.get('critiques_addressed',0)}/{received} critiques addressed)")
            print(f" {grade.get('explanation','')}")

    except Exception as exc:  # episode boundary: report and move on to the next task
        print(f" [ERROR] Episode failed: {exc}")
        log_step(step=next_step(), action="ERROR", reward=0.0, done=True, error=str(exc))

    log_end(success=success, steps=len(step_rewards), score=final_score, rewards=step_rewards)
    return final_score
|
|
|
|
def main():
    """Run one multi-agent episode per difficulty and print a score summary.

    Exits with status 1 when the environment's health check fails. The last
    line printed is ``JSON_SCORES: {...}`` with the per-difficulty scores.
    """
    rule = "=" * 60
    print(rule)
    print(" Multi-Agent Inference - Incident Post-Mortem Writer")
    print(rule)
    print(f" Model: {MODEL_NAME}")
    print(f" API: {API_BASE_URL}")
    print(f" Env URL: {ENV_BASE_URL}")
    print(" Mode: multi-agent (primary + skeptic)")

    env = PostMortemEnv(ENV_BASE_URL)
    if not env.health():
        print(f"ERROR: Environment not reachable at {ENV_BASE_URL}")
        print("Start it: uvicorn server.app:app --host 0.0.0.0 --port 7860")
        sys.exit(1)

    scores: Dict[str, float] = {}
    t_start = time.time()
    for difficulty in DIFFICULTIES:
        scores[difficulty] = round(run_multiagent_episode(env, difficulty), 4)
    elapsed = time.time() - t_start

    print("\n" + rule)
    print(" MULTI-AGENT BENCHMARK RESULTS")
    print(rule)
    for task, score in scores.items():
        print(f" {task:8s}: {score:.3f}")
    # Guard against an empty difficulty list so the summary cannot divide by zero.
    avg = sum(scores.values()) / len(scores) if scores else 0.0
    print(f" {'average':8s}: {avg:.3f}")
    print(f"\n runtime: {elapsed:.1f}s")

    print(f"\nJSON_SCORES: {json.dumps(scores)}", flush=True)


if __name__ == "__main__":
    main()
|
|