| """ |
| server/environment.py — Core OpenEnv environment for Cloud Incident Response. |
| |
| Difficulty comes from SCENARIO DESIGN, not mechanics: |
| EASY: 3 services, clear metrics, obvious severity |
| MEDIUM: 8 services, root cause NOT in alert, must follow log breadcrumbs |
| HARD: 8 services + 5-7 remediation steps + quality summary + penalties |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import sys |
| import threading |
| import uuid |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from graders import _svc_match, grade |
| from server.models import Action, ActionParameters, EpisodeState, Observation, Reward |
| from tasks import get_scenario, get_task |
|
|
# Action types that only read scenario data; handled by the diagnostic path.
_DIAGNOSTIC = frozenset((
    "query_logs",
    "check_metrics",
    "check_dependencies",
    "check_recent_deploys",
    "check_service_status",
))

# Action types that act on the simulated system; handled by the
# remediation path.
_REMEDIATION = frozenset((
    "restart_service",
    "rollback_deploy",
    "scale_service",
    "disable_feature_flag",
    "clear_cache",
    "execute_runbook_step",
))

# Task id -> the single submit action that is accepted (and terminal) for
# that task.
_TASK_SUBMIT = {
    "alert_classification": "submit_severity",
    "root_cause_analysis": "submit_root_cause",
    "remediation_planning": "submit_resolution",
}

# Every submission action type, derived from the mapping above so the two
# stay in sync.
_SUBMIT = frozenset(_TASK_SUBMIT.values())
|
|
# Per-step shaping rewards: difficulty tier -> event name -> reward delta.
# Each tier defines the same 14 event names.
_REWARD_TABLE = {
    "easy": dict(
        query_new_svc=+0.04,
        query_new_action=+0.02,
        query_repeat=-0.03,
        query_unknown_svc=-0.06,
        query_no_service=-0.04,
        rem_good=+0.00,
        rem_wrong=-0.08,
        rem_no_target=-0.05,
        submit_correct=+0.02,
        submit_wrong=-0.08,
        past_half=-0.04,
        timeout=-0.15,
        bad_action=-0.05,
        exact_repeat=-0.04,
    ),
    "medium": dict(
        query_new_svc=+0.04,
        query_new_action=+0.02,
        query_repeat=-0.04,
        query_unknown_svc=-0.06,
        query_no_service=-0.04,
        rem_good=+0.06,
        rem_wrong=-0.10,
        rem_no_target=-0.06,
        submit_correct=+0.02,
        submit_wrong=-0.10,
        past_half=-0.02,
        timeout=-0.15,
        bad_action=-0.05,
        exact_repeat=-0.05,
    ),
    "hard": dict(
        query_new_svc=+0.03,
        query_new_action=+0.01,
        query_repeat=-0.03,
        query_unknown_svc=-0.05,
        query_no_service=-0.03,
        rem_good=+0.06,
        rem_wrong=-0.15,
        rem_no_target=-0.05,
        submit_correct=+0.02,
        submit_wrong=-0.12,
        past_half=-0.02,
        timeout=-0.20,
        bad_action=-0.05,
        exact_repeat=-0.04,
    ),
}

# Task id -> difficulty tier used to pick a row of _REWARD_TABLE.
_TASK_DIFFICULTY = dict(
    alert_classification="easy",
    root_cause_analysis="medium",
    remediation_planning="hard",
)
|
|
|
|
class IncidentEnvironment:
    """Lock-guarded, single-episode incident-response environment.

    Lifecycle: call :meth:`reset` to load a task/scenario, then call
    :meth:`step` until the returned ``done`` flag is true. :meth:`state`
    returns a snapshot of the episode bookkeeping.

    Per-step shaping rewards come from ``_REWARD_TABLE`` (row chosen by the
    task's difficulty). When the episode ends, the grader's ``total`` is
    added to the cumulative reward.
    """

    def __init__(self) -> None:
        # One lock serialises reset/step/state against each other.
        self._lock = threading.Lock()
        self._s: dict = {}          # mutable per-episode bookkeeping
        self._scenario: dict = {}   # scenario dict from get_scenario()
        self._task_def: dict = {}   # task dict from get_task()
        self._ready = False         # True once reset() has succeeded

    def reset(self, task_id: str = "alert_classification",
              scenario_index: int = 0) -> Observation:
        """Start a fresh episode and return its initial observation.

        Args:
            task_id: Task name passed to ``get_task`` / ``get_scenario``.
            scenario_index: Which scenario of that task to load.

        Returns:
            The observation at step 0.
        """
        with self._lock:
            # Both lookups happen before any state is touched, so if either
            # raises, the previous episode is left intact.
            task_def = get_task(task_id)
            scenario = get_scenario(task_id, scenario_index)
            self._task_def = task_def
            self._scenario = scenario
            self._s = {
                "episode_id": str(uuid.uuid4()),
                "task_id": task_id,
                "scenario_id": scenario["scenario_id"],
                "step_count": 0,
                "max_steps": task_def["max_steps"],
                "action_history": [],
                "queried_data": {},        # action_type -> target -> result
                "queried_keys": set(),     # (action_type, service) pairs seen
                "services_queried": set(),
                "exact_hashes": set(),     # "action|params-json" already issued
                "submitted": False,
                "resolved": False,
                "done": False,
                "cumulative_reward": 0.0,
                "feedback": f"Episode started. {scenario['description']}",
                "last_action_error": None,
            }
            self._ready = True
            return self._build_obs()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
        """Apply one agent action and return ``(obs, reward, done, info)``.

        ``Reward.score`` is this step's shaping reward only.
        ``Reward.cumulative`` also includes the grader score once the
        episode terminates.

        Raises:
            RuntimeError: if :meth:`reset` has not been called yet.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("Call reset() before step().")
            s = self._s
            s["last_action_error"] = None

            # Stepping a finished episode is a no-op with zero reward.
            if s["done"]:
                return (self._build_obs(),
                        Reward(score=0.0, reason="episode already done",
                               cumulative=s["cumulative_reward"]),
                        True, {})

            s["step_count"] += 1
            step_num = s["step_count"]
            at = action.action_type
            params = action.parameters
            task_id = s["task_id"]
            diff = _TASK_DIFFICULTY.get(task_id, "medium")
            rt = _REWARD_TABLE[diff]

            s["action_history"].append({
                "action_type": at,
                "parameters": params.model_dump(exclude_none=True),
                "step": step_num,
            })

            r = 0.0
            fb: list[str] = []

            # Penalise re-issuing an identical (action, parameters) pair.
            h = f"{at}|{params.model_dump_json(exclude_none=True)}"
            if h in s["exact_hashes"]:
                r += rt["exact_repeat"]
                fb.append(f"exact repeat ({rt['exact_repeat']:+.2f})")
            s["exact_hashes"].add(h)

            # Past half the step budget, every non-submit action costs extra.
            half = max(1, s["max_steps"] // 2)
            if step_num > half and at not in _SUBMIT:
                r += rt["past_half"]
                fb.append(f"past halfway ({rt['past_half']:+.3f})")

            if at in _DIAGNOSTIC:
                r, fb = self._handle_diagnostic(at, params, r, fb, rt)
            elif at in _REMEDIATION:
                r, fb = self._handle_remediation(at, params, r, fb, rt, task_id)
            elif at in _SUBMIT:
                r, fb, terminal = self._handle_submit(at, params, r, fb, rt, task_id)
                if terminal:
                    s["done"] = True
            else:
                r += rt["bad_action"]
                fb.append(f"unknown action '{at}' ({rt['bad_action']:+.2f})")
                s["last_action_error"] = f"Unknown action type: {at}"

            # Exhausting the step budget ends the episode with a penalty.
            if step_num >= s["max_steps"] and not s["done"]:
                r += rt["timeout"]
                fb.append(f"timeout ({rt['timeout']:+.2f})")
                s["done"] = True

            # On termination, fold the grader's total into the cumulative reward.
            if s["done"]:
                result = grade(s["task_id"], s, self._scenario)
                grader_score = result["total"]
                s["cumulative_reward"] = round(
                    s["cumulative_reward"] + r + grader_score, 4)
                fb.append(f"grader={grader_score:.3f} ({result['feedback']})")
            else:
                s["cumulative_reward"] = round(s["cumulative_reward"] + r, 4)

            s["feedback"] = " | ".join(fb) if fb else "ok"
            return (self._build_obs(),
                    Reward(score=round(r, 4), reason=s["feedback"],
                           cumulative=s["cumulative_reward"]),
                    s["done"],
                    {"step": step_num, "feedback": s["feedback"]})

    def state(self) -> EpisodeState:
        """Return a snapshot of the current episode's bookkeeping.

        Raises:
            RuntimeError: if :meth:`reset` has not been called yet.
        """
        with self._lock:
            if not self._ready:
                raise RuntimeError("No active episode — call reset() first.")
            s = self._s
            return EpisodeState(
                episode_id=s["episode_id"], task_id=s["task_id"],
                scenario_id=s["scenario_id"], step_count=s["step_count"],
                max_steps=s["max_steps"],
                action_history=list(s["action_history"]),
                queried_data=self._copy_queried(s["queried_data"]),
                submitted=s["submitted"], resolved=s["resolved"],
                done=s["done"], cumulative_reward=s["cumulative_reward"],
                feedback=s["feedback"])

    @staticmethod
    def _copy_queried(queried: dict) -> dict:
        """Copy ``queried_data`` two levels deep.

        The per-action inner dicts are mutated in place by later steps, so a
        shallow ``dict(...)`` copy would let previously returned snapshots
        change after the fact.
        """
        return {action: dict(by_target) if isinstance(by_target, dict) else by_target
                for action, by_target in queried.items()}

    def _handle_diagnostic(self, at, params, r, fb, rt):
        """Score a read-only diagnostic query and record its result.

        Returns the updated ``(r, fb)`` pair. A missing or unknown service is
        penalised and returns no data. Otherwise the scenario's canned tool
        response (or a "No data" message) is stored in
        ``queried_data[at][svc]``.
        """
        s = self._s
        svc = (params.service or "").lower().strip()
        known = {v.lower() for v in self._scenario.get("known_services", set())}
        tool = self._scenario.get("tool_responses", {}).get(at, {})
        key = (at, svc)

        if not svc:
            r += rt["query_no_service"]
            fb.append(f"{at}: no service ({rt['query_no_service']:+.2f})")
            s["last_action_error"] = f"{at} requires a service parameter"
            return r, fb

        if svc not in known:
            r += rt["query_unknown_svc"]
            fb.append(f"unknown service '{svc}' ({rt['query_unknown_svc']:+.2f})")
            s["last_action_error"] = f"Unknown service: {svc}"
            return r, fb

        # Novelty scoring: a repeated (action, service) pair is penalised, a
        # new action on an already-seen service earns a little, and a
        # never-seen service earns the most.
        if key in s["queried_keys"]:
            r += rt["query_repeat"]
            fb.append(f"repeat [{at}][{svc}] ({rt['query_repeat']:+.2f})")
        elif svc in s["services_queried"]:
            r += rt["query_new_action"]
            fb.append(f"new action on {svc} ({rt['query_new_action']:+.2f})")
            s["queried_keys"].add(key)
        else:
            r += rt["query_new_svc"]
            fb.append(f"new service {svc} ({rt['query_new_svc']:+.2f})")
            s["queried_keys"].add(key)
            s["services_queried"].add(svc)

        # Repeats still (re)write the data, so it stays visible to the agent.
        result = tool.get(svc, f"No data available for '{svc}'.")
        s["queried_data"].setdefault(at, {})[svc] = result
        return r, fb

    def _handle_remediation(self, at, params, r, fb, rt, task_id):
        """Score a remediation action against the scenario's wrong-action map.

        Returns the updated ``(r, fb)`` pair:

        - In ``alert_classification``, any remediation counts as wrong.
        - An action with no service, flag, runbook action or target is
          penalised.
        - Otherwise the action is wrong if a candidate key (or a fuzzy
          ``_svc_match`` against ``"<action>:<service>"`` keys) appears in
          ``scenario["wrong_actions"]``, and good otherwise.
        - Good actions record the scenario's ``remediation_data`` response
          (or a generic success message) in ``queried_data``.
        """
        s = self._s
        if task_id == "alert_classification":
            r += rt["rem_wrong"]
            fb.append(f"remediation in easy task ({rt['rem_wrong']:+.2f})")
            s["last_action_error"] = "Remediation not available in alert_classification"
            return r, fb

        svc = (params.service or "").lower().strip()
        flag = (params.flag or "").lower().strip()
        runbook = (params.runbook_action or "").lower().strip()
        target = (params.target or params.target_version or "").lower().strip()

        if not (svc or flag or runbook or target):
            r += rt["rem_no_target"]
            fb.append(f"{at}: no target ({rt['rem_no_target']:+.2f})")
            s["last_action_error"] = f"{at} requires a target"
            return r, fb

        # Candidate wrong_actions keys, most specific first. An ordered list
        # (not a set) keeps the reported reason deterministic, because str
        # hashing is randomised per process and set order varies with it.
        # NOTE(review): runbook/target keys are always namespaced under
        # execute_runbook_step, even for other action types; kept as-is.
        keys = []
        if svc:
            keys.append(f"{at}:{svc}")
        if flag:
            keys.append(f"{at}:{flag}")
        if runbook:
            keys.append(f"execute_runbook_step:{runbook}")
        if target:
            keys.append(f"execute_runbook_step:{target}")
        keys.append(at)

        wrong_map = self._scenario.get("wrong_actions", {})
        rem_data = self._scenario.get("remediation_data", {})

        # Exact lookup first, then fuzzy service matching. Remember which
        # entry matched so its own reason is reported in either case.
        hit = next((k for k in keys if k in wrong_map), None)
        if hit is None and svc:
            for wk in wrong_map:
                if ":" in wk:
                    w_at, w_svc = wk.split(":", 1)
                    if w_at == at and _svc_match(svc, w_svc):
                        hit = wk
                        break

        if hit is not None:
            r += rt["rem_wrong"]
            reason = wrong_map[hit]
            fb.append(f"wrong: {at} — {str(reason)[:60]} ({rt['rem_wrong']:+.2f})")
        else:
            r += rt["rem_good"]
            tgt = svc or flag or runbook or target
            fb.append(f"executed {at}:{tgt} ({rt['rem_good']:+.2f})")
            at_data = rem_data.get(at, {})
            result = (at_data.get(svc) or at_data.get(flag) or at_data.get(runbook)
                      or at_data.get(target) or "action executed successfully")
            s["queried_data"].setdefault(at, {})[tgt] = result
        return r, fb

    def _handle_submit(self, at, params, r, fb, rt, task_id):
        """Score a submission and return ``(r, fb, terminal)``.

        Only the submit action mapped to the current task in
        ``_TASK_SUBMIT`` is accepted and terminal. Any other submit type is
        penalised and the episode continues.

        For ``submit_resolution``, the episode is marked resolved only if a
        non-blank summary was given and at least one diagnostic or
        remediation action appears in the action history.
        """
        s = self._s
        correct = _TASK_SUBMIT.get(task_id, "")
        if at != correct:
            r += rt["submit_wrong"]
            fb.append(f"wrong submit '{at}' (need '{correct}') ({rt['submit_wrong']:+.2f})")
            s["last_action_error"] = f"Wrong submission type: use {correct}"
            return r, fb, False

        s["submitted"] = True
        r += rt["submit_correct"]
        fb.append(f"submitted ({rt['submit_correct']:+.2f})")

        if at == "submit_severity":
            fb.append(f"severity={(params.severity or '').upper().strip()}")
        elif at == "submit_root_cause":
            fb.append(f"svc={params.service or ''}, mode={params.failure_mode or ''}")
        elif at == "submit_resolution":
            summary = params.summary or ""
            # Build the union once instead of once per history entry.
            investigative = _DIAGNOSTIC | _REMEDIATION
            investigated = any(a.get("action_type") in investigative
                               for a in s["action_history"])
            if summary.strip() and investigated:
                s["resolved"] = True
                fb.append("resolved")
            else:
                fb.append("insufficient investigation")
        return r, fb, True

    def _build_obs(self) -> Observation:
        """Build an :class:`Observation` from the current episode state."""
        s = self._s
        sc = self._scenario
        td = self._task_def
        return Observation(
            episode_id=s["episode_id"], task_id=s["task_id"],
            scenario_id=s["scenario_id"], step_count=s["step_count"],
            max_steps=s["max_steps"],
            incident_summary=sc.get("incident_summary", sc.get("description", "")),
            alert=sc.get("alert", {}),
            available_actions=td.get("available_actions", []),
            queried_data=self._copy_queried(s["queried_data"]),
            cumulative_reward=s["cumulative_reward"],
            done=s["done"], feedback=s["feedback"],
            known_services=sorted(sc.get("known_services", set())),
            last_action_error=s.get("last_action_error"))