postmortem_env / server /postmortem_env_environment.py
yashppawar's picture
Upload folder using huggingface_hub
b29893e verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
"""
PostMortem Environment — incident triage as an OpenEnv env.
The agent plays an on-call SRE and interacts via typed actions (query_logs,
query_metrics, query_traces, ack, scope, hypothesize, mitigate, write_status)
against one of three fixed scenarios that rotate on reset(). The reward is a
5-stage process-reward ladder in [0, 1]:
ack +0.10
scope +0.20 (Jaccard overlap vs. gold service set)
hypothesize +0.20 (fraction of gold keywords mentioned)
mitigate +0.20 (fraction of gold keywords mentioned)
write_status +0.30 (fraction of gold keywords mentioned)
Each sub-goal can only be claimed once. Episodes terminate on `write_status`
or after MAX_STEPS (12).
"""
from typing import Any, Dict, List
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import PostmortemAction, PostmortemObservation
from .scenarios import SCENARIOS, num_scenarios
except (ImportError, ModuleNotFoundError): # Docker / direct-run fallback
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models import PostmortemAction, PostmortemObservation # type: ignore
from scenarios import SCENARIOS, num_scenarios # type: ignore
# Hard cap on step() calls per episode; once step_count reaches it the
# episode is marked done.
MAX_STEPS: int = 12
# ---------- Reward helpers ----------
def _jaccard(a: List[str], b: List[str]) -> float:
if not a and not b:
return 1.0
sa, sb = {x.strip().lower() for x in a}, {x.strip().lower() for x in b}
if not sa or not sb:
return 0.0
return len(sa & sb) / len(sa | sb)
def _keyword_fraction(text: str, keywords: List[str]) -> float:
if not keywords:
return 0.0
t = text.lower()
hits = sum(1 for k in keywords if k.lower() in t)
return hits / len(keywords)
# ---------- Environment ----------
class PostmortemEnvironment(Environment):
    """Incident-triage environment in which the agent acts as an on-call SRE.

    Each episode uses one scenario dict from ``SCENARIOS``, rotated in order
    on every ``reset()``. The agent calls tools through ``step()``. Five
    one-shot subgoals make up the reward ladder:

    * ``ack`` pays a flat 0.10.
    * ``scope``, ``hypothesize`` and ``mitigate`` pay 0.20 each.
    * ``write_status`` pays 0.30.

    The last four are scaled by how well the submission matches the
    scenario's ``gold`` answers. An episode ends when the status update is
    published or after ``MAX_STEPS`` steps. Further ``step()`` calls are then
    rejected with zero reward until ``reset()`` is called.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Subgoal flags reported in every observation, in ladder order.
    _SUBGOAL_NAMES = ("acked", "scoped", "hypothesized", "mitigated", "written")

    def __init__(self) -> None:
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._scenario_idx = 0
        self._scenario: Dict[str, Any] = SCENARIOS[0]
        self._subgoals: Dict[str, bool] = dict.fromkeys(self._SUBGOAL_NAMES, False)
        self._reward_so_far = 0.0
        self._done = False
        self._last_error = ""

    # ---- env API ----

    def reset(self) -> PostmortemObservation:
        """Start a new episode on the next scenario and return its first observation."""
        # Rotate through SCENARIOS in order so consecutive resets cover every
        # scenario before repeating.
        self._scenario = SCENARIOS[self._scenario_idx % num_scenarios()]
        self._scenario_idx += 1
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._subgoals = dict.fromkeys(self._SUBGOAL_NAMES, False)
        self._reward_so_far = 0.0
        self._done = False
        self._last_error = ""
        return self._observation("Incident opened. Begin investigation.", reward=0.0)

    def step(self, action: PostmortemAction) -> PostmortemObservation:  # type: ignore[override]
        """Execute one tool call and return the resulting observation.

        ``action.tool`` is matched after stripping and lower-casing, and
        ``action.args`` supplies the tool's arguments. The observation's
        ``reward`` is what this step earned. ``reward_so_far`` is the running
        total, clamped to [0, 1]. Unexpected failures are reported as an
        ``internal:`` error instead of propagating.
        """
        if self._done:
            # Without this guard, subgoals left unclaimed at termination could
            # still pay out after the episode had already ended.
            message = self._error("episode is done; call reset() to start a new incident")
            return self._observation(message, reward=0.0)

        self._state.step_count += 1
        tool = str(action.tool or "").strip().lower()
        args = action.args or {}
        self._last_error = ""
        try:
            tool_result, step_reward = self._run_tool(tool, args)
        except Exception as exc:  # defensive — never crash the server
            self._last_error = f"internal: {exc}"
            tool_result, step_reward = f"ERROR: {self._last_error}", 0.0

        self._reward_so_far = min(1.0, max(0.0, self._reward_so_far + step_reward))
        if self._state.step_count >= MAX_STEPS:
            self._done = True
        return self._observation(tool_result, reward=step_reward)

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state

    # ---- internals ----

    def _observation(self, tool_result: str, reward: float) -> PostmortemObservation:
        """Build an observation from the current episode state plus this step's result."""
        scenario = self._scenario
        return PostmortemObservation(
            task_id=scenario["task_id"],
            task_description=scenario["description"],
            available_services=list(scenario["services"]),
            available_trace_ids=list(scenario.get("traces", {}).keys()),
            tool_result=tool_result,
            subgoals=dict(self._subgoals),
            reward_so_far=self._reward_so_far,
            steps_remaining=max(0, MAX_STEPS - self._state.step_count),
            last_error=self._last_error,
            done=self._done,
            reward=reward,
            metadata={"difficulty": scenario.get("difficulty", "")},
        )

    def _error(self, message: str) -> str:
        """Record *message* as this step's error and return it as an ``ERROR:`` tool result."""
        self._last_error = message
        return f"ERROR: {message}"

    def _claim_keyword_subgoal(
        self,
        subgoal: str,
        text: str,
        gold_key: str,
        weight: float,
        recorded_msg: str,
        repeat_msg: str,
    ) -> tuple:
        """Score a one-shot, keyword-graded subgoal and return ``(tool_result, step_reward)``.

        The first claim pays ``weight`` times the fraction of
        ``gold[gold_key]`` keywords found in *text*. Repeat claims pay nothing.
        """
        if self._subgoals[subgoal]:
            return repeat_msg, 0.0
        frac = _keyword_fraction(text, self._scenario["gold"][gold_key])
        self._subgoals[subgoal] = True
        return f"{recorded_msg} Keyword match = {frac:.2f}", weight * frac

    def _run_tool(self, tool: str, args: Dict[str, Any]) -> tuple:
        """Dispatch one tool call and return ``(tool_result, step_reward)``.

        Agent-facing errors such as unknown services, traces or tools are
        recorded in ``self._last_error``. Anything else that fails (e.g. a
        non-dict ``args``) raises, and ``step()`` reports it as an internal
        error.
        """
        if tool == "ack":
            if self._subgoals["acked"]:
                return "Already acknowledged.", 0.0
            self._subgoals["acked"] = True
            return "Acknowledged. You now own this incident.", 0.10

        if tool == "query_logs":
            service = str(args.get("service", "")).strip()
            logs = self._scenario.get("logs", {}).get(service)
            if logs is None:
                return self._error(f"unknown service '{service}'"), 0.0
            return "\n".join(logs), 0.0

        if tool == "query_metrics":
            service = str(args.get("service", "")).strip()
            metrics = self._scenario.get("metrics", {}).get(service)
            if metrics is None:
                return self._error(f"unknown service '{service}'"), 0.0
            return ", ".join(f"{k}={v}" for k, v in metrics.items()), 0.0

        if tool == "query_traces":
            trace_id = str(args.get("trace_id", "")).strip()
            trace = self._scenario.get("traces", {}).get(trace_id)
            if trace is None:
                return self._error(f"unknown trace_id '{trace_id}'"), 0.0
            spans = (
                f"{s['service']}:{s['op']} {s['duration_ms']}ms err={s.get('error', False)}"
                for s in trace
            )
            return " | ".join(spans), 0.0

        if tool == "scope":
            services = args.get("services", [])
            # Type validation precedes the already-claimed check, as the
            # agent should learn about malformed input either way.
            if not isinstance(services, list):
                return self._error("scope.services must be a list"), 0.0
            if self._subgoals["scoped"]:
                return "Scope already set.", 0.0
            jac = _jaccard(services, self._scenario["gold"]["scope"])
            self._subgoals["scoped"] = True
            return f"Scope recorded. Match vs gold = {jac:.2f}", 0.20 * jac

        if tool == "hypothesize":
            return self._claim_keyword_subgoal(
                "hypothesized",
                str(args.get("root_cause", "")),
                "hypothesis_keywords",
                0.20,
                "Hypothesis recorded.",
                "Hypothesis already set.",
            )

        if tool == "mitigate":
            return self._claim_keyword_subgoal(
                "mitigated",
                str(args.get("action", "")),
                "mitigation_keywords",
                0.20,
                "Mitigation applied.",
                "Mitigation already applied.",
            )

        if tool == "write_status":
            outcome = self._claim_keyword_subgoal(
                "written",
                str(args.get("text", "")),
                "writeup_keywords",
                0.30,
                "Status update published.",
                "Status update already published.",
            )
            self._done = True  # publishing the status update ends the episode
            return outcome

        self._last_error = f"unknown tool '{tool}'"
        unknown = (
            f"ERROR: {self._last_error}. Valid: ack, query_logs, query_metrics, "
            "query_traces, scope, hypothesize, mitigate, write_status."
        )
        return unknown, 0.0