Spaces:
Sleeping
Sleeping
| """ | |
| Session-safe OpenEnv environment with seeded partial observability. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import random | |
| import threading | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional, Set, Tuple | |
| from data.db_loader import build_task_log_pool, load_patterns, load_thresholds | |
| from env.models import Action, IncidentBriefing, Observation, Reward, RootCauseHypothesis | |
| from tasks.catalog import CONTAINMENT_DESCRIPTIONS, DEPENDENCY_GRAPH, TASK_SPECS | |
| from tasks.graders import build_dense_reward, containment_alignment, grade_report, hypothesis_match_score | |
# Opt-in switch for SessionStore.debug_state(), which exposes the seed and the
# full (including unrevealed) log pool. True only when OPENENV_DEBUG_STATE is
# set to "true" (case-insensitive); unset or any other value means False.
DEBUG_STATE_ENABLED = (
    os.environ.get("OPENENV_DEBUG_STATE", "false").lower() == "true"
)
@dataclass
class IncidentSession:
    """Mutable per-episode state for one investigation session.

    Holds the seeded log pool for the task together with everything the agent
    has uncovered so far (revealed log ids, visited services, containment
    actions, hypotheses) and the bookkeeping SessionStore uses for rewards.
    Declared as a dataclass: it is constructed with keyword arguments and its
    mutable fields rely on ``field(default_factory=...)`` so that every
    session gets its own fresh set/list/dict.
    """

    session_id: str
    task_id: str
    seed: int
    max_steps: int
    # Full log pool for the episode (revealed or not). Each record is a dict
    # with at least "log_id" and "timestamp"; SessionStore also reads other keys.
    logs: List[Dict[str, Any]]
    thresholds: Dict[str, Dict[str, float]]
    patterns: Dict[str, Dict[str, str]]
    step_number: int = 0
    done: bool = False
    # Ids of logs revealed to the agent; this is what makes observation partial.
    visible_log_ids: Set[int] = field(default_factory=set)
    visited_services: Set[str] = field(default_factory=set)
    # Containment actions in first-submitted order, without duplicates.
    containment_plan: List[str] = field(default_factory=list)
    last_hypothesis: Optional[RootCauseHypothesis] = None
    # Highest hypothesis score reached so far; only improvements are rewarded.
    best_hypothesis_score: float = 0.0
    # Action fingerprint -> number of times it has been seen (repeat penalty).
    query_fingerprints: Dict[str, int] = field(default_factory=dict)
    last_reward: Optional[Reward] = None
    # One summary dict per step (step, action_type, reward, done).
    episode_history: List[Dict[str, Any]] = field(default_factory=list)

    def visible_logs(self) -> List[Dict[str, Any]]:
        """Return the revealed logs ordered by (timestamp, log_id)."""
        revealed = [log for log in self.logs if log["log_id"] in self.visible_log_ids]
        return sorted(revealed, key=lambda log: (log["timestamp"], log["log_id"]))

    def log_map(self) -> Dict[int, Dict[str, Any]]:
        """Return every log in the pool, revealed or not, keyed by log_id."""
        return {log["log_id"]: log for log in self.logs}
class SessionStore:
    """In-memory registry of incident sessions plus the environment step logic.

    ``reset`` creates a session under a random hex id. ``step`` applies one
    action to the session named by ``action.session_id``. Only the registry
    dict is guarded by ``self._lock``. NOTE(review): the per-session mutations
    in ``step`` happen outside that lock, so concurrent steps on the *same*
    session are not serialized.
    """

    # Action types that carry a required payload, mapped to the Action
    # attribute holding it. Enforced by _validate_action before a step counts.
    _REQUIRED_PAYLOADS = {
        "submit_report": "report",
        "query_logs": "query",
        "inspect_dependencies": "target_service",
        "update_hypothesis": "hypothesis",
    }

    # Lower rank = more severe; unknown levels sort after INFO (rank 4).
    _SEVERITY_ORDER = {"CRITICAL": 0, "ERROR": 1, "WARN": 2, "INFO": 3}

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sessions: Dict[str, IncidentSession] = {}

    def reset(self, task_id: str = "easy", seed: Optional[int] = None) -> Observation:
        """Create a new session for ``task_id`` and return its first observation.

        Args:
            task_id: Key into ``TASK_SPECS``.
            seed: Seed passed to ``build_task_log_pool``. Defaults to
                ``2025 + 17 * <position of task_id in TASK_SPECS>``.

        Raises:
            ValueError: ``task_id`` is not a known task.
        """
        if task_id not in TASK_SPECS:
            raise ValueError(f"Unknown task_id '{task_id}'.")
        actual_seed = int(seed if seed is not None else 2025 + (list(TASK_SPECS).index(task_id) * 17))
        session_id = uuid.uuid4().hex
        spec = TASK_SPECS[task_id]
        session = IncidentSession(
            session_id=session_id,
            task_id=task_id,
            seed=actual_seed,
            max_steps=int(spec["max_steps"]),
            logs=build_task_log_pool(task_id, actual_seed),
            thresholds=load_thresholds(),
            patterns=load_patterns(),
        )
        with self._lock:
            self._sessions[session_id] = session
        return self._build_observation(
            session,
            feedback="Episode created. Query the incident window and inspect dependencies to build your case.",
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action to its session.

        Returns:
            ``(observation, reward, done, info)``; ``info`` is a copy of
            ``reward.info``.

        Raises:
            RuntimeError: the session cannot be resolved, or it already ended.
            ValueError: the action lacks the payload its type requires. This is
                checked before the step counter or repeat fingerprints change,
                so a malformed action does not consume step budget.
        """
        session = self._resolve_session(action.session_id)
        if session.done:
            raise RuntimeError("Episode already finished. Call /reset to start a new session.")
        # Validate before mutating. Otherwise a rejected action would still
        # count as a step and could push step_number past max_steps without
        # marking the episode done.
        self._validate_action(action)

        session.step_number += 1
        repeated_action_count = self._register_action(session, action)
        if action.action_type == "submit_report":
            # NOTE(review): log_map() contains *all* logs, not only revealed
            # ones, despite the parameter name; confirm grade_report filters
            # by revealed_log_ids.
            reward = grade_report(
                task_id=session.task_id,
                report=action.report,
                revealed_log_ids=set(session.visible_log_ids),
                revealed_log_map=session.log_map(),
                step_number=session.step_number,
                max_steps=session.max_steps,
                repeated_action_count=repeated_action_count,
            )
            session.done = True
            feedback = "Final report graded."
        elif action.action_type == "no_anomalies":
            # Every task in this benchmark has an incident, so this is a
            # terminal, maximally penalized action.
            reward = build_dense_reward(
                signal_reward=0.0,
                hypothesis_reward=0.0,
                efficiency_reward=0.0,
                penalty=1.0,
                info={"message": "No-incident declaration is invalid for this benchmark."},
            )
            session.done = True
            feedback = "No-incident declaration rejected."
        else:
            reward, feedback = self._handle_non_terminal(session, action, repeated_action_count)
        if session.step_number >= session.max_steps:
            session.done = True
            feedback = f"{feedback} Step budget exhausted."
        session.last_reward = reward
        session.episode_history.append(
            {
                "step": session.step_number,
                "action_type": action.action_type,
                "reward": reward.value,
                "done": session.done,
            }
        )
        observation = self._build_observation(session, feedback=feedback)
        return observation, reward, session.done, dict(reward.info)

    def public_state(self, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Return a summary of a session that does not expose unrevealed logs or the seed."""
        session = self._resolve_session(session_id)
        return {
            "session_id": session.session_id,
            "task_id": session.task_id,
            "step_number": session.step_number,
            "max_steps": session.max_steps,
            "done": session.done,
            "revealed_log_count": len(session.visible_log_ids),
            "visited_services": sorted(session.visited_services),
            "submitted_containment": list(session.containment_plan),
            "last_reward": session.last_reward.model_dump() if session.last_reward else None,
        }

    def debug_state(self, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Return full internal state (seed, all logs, history) for debugging.

        Raises:
            PermissionError: DEBUG_STATE_ENABLED is False.
        """
        if not DEBUG_STATE_ENABLED:
            raise PermissionError("Debug state is disabled.")
        session = self._resolve_session(session_id)
        return {
            "session_id": session.session_id,
            "task_id": session.task_id,
            "seed": session.seed,
            "visible_log_ids": sorted(session.visible_log_ids),
            "all_logs": session.logs,
            "history": session.episode_history,
            "best_hypothesis_score": session.best_hypothesis_score,
        }

    def _resolve_session(self, session_id: Optional[str]) -> IncidentSession:
        """Look up a session by id.

        With a falsy ``session_id``, the single stored session is returned if
        exactly one exists.

        Raises:
            RuntimeError: unknown id, or no id given and zero or several
                sessions exist.
        """
        with self._lock:
            if session_id:
                session = self._sessions.get(session_id)
                if session is None:
                    raise RuntimeError(f"Unknown session_id '{session_id}'.")
                return session
            if len(self._sessions) == 1:
                return next(iter(self._sessions.values()))
        raise RuntimeError("A valid session_id is required.")

    @staticmethod
    def _validate_action(action: Action) -> None:
        """Raise ValueError if ``action`` lacks the payload its type requires."""
        attribute = SessionStore._REQUIRED_PAYLOADS.get(action.action_type)
        if attribute is not None and getattr(action, attribute) is None:
            raise ValueError(f"{action.action_type} requires {attribute}")

    def _handle_non_terminal(
        self,
        session: IncidentSession,
        action: Action,
        repeated_action_count: int,
    ) -> Tuple[Reward, str]:
        """Apply a non-terminal action and return ``(reward, feedback)``.

        Payload presence has already been checked by ``_validate_action``.
        Repeated actions add up to 0.2 extra penalty, 0.1 per prior repeat.
        The efficiency reward decays linearly from 1.0 on the first step
        toward 0.0 at ``max_steps``.
        """
        signal_reward = 0.0
        hypothesis_reward = 0.0
        penalty = 0.0
        info: Dict[str, Any] = {}
        if action.action_type == "query_logs":
            newly_revealed = self._query_logs(session, action.query.model_dump(exclude_none=True))
            relevant = set(TASK_SPECS[session.task_id]["gold_evidence_ids"])
            relevant_new = len(relevant & set(newly_revealed))
            signal_reward = min(1.0, round((0.22 * len(newly_revealed)) + (0.28 * relevant_new), 4))
            penalty = 0.15 if not newly_revealed else 0.0
            feedback = f"Query revealed {len(newly_revealed)} new log(s)."
            info["revealed_log_ids"] = newly_revealed
        elif action.action_type == "inspect_dependencies":
            session.visited_services.add(action.target_service)
            neighbors = DEPENDENCY_GRAPH.get(action.target_service, [])
            revealed = self._inspect_dependencies(session, action.target_service, neighbors)
            relevant = set(TASK_SPECS[session.task_id]["gold_evidence_ids"])
            signal_reward = min(1.0, round((0.15 * len(revealed)) + (0.35 * len(relevant & set(revealed))), 4))
            penalty = 0.1 if not revealed else 0.0
            feedback = f"Dependency inspection around {action.target_service} revealed {len(revealed)} new log(s)."
            info["neighbors"] = neighbors
            info["revealed_log_ids"] = revealed
        elif action.action_type == "update_hypothesis":
            current_score = hypothesis_match_score(action.hypothesis, session.task_id)
            improvement = max(0.0, current_score - session.best_hypothesis_score)
            session.best_hypothesis_score = max(session.best_hypothesis_score, current_score)
            session.last_hypothesis = action.hypothesis
            hypothesis_reward = improvement
            # Penalize a hypothesis that is worse than the best one so far.
            penalty = 0.15 if improvement == 0.0 and current_score < session.best_hypothesis_score else 0.0
            feedback = "Hypothesis recorded."
            info["hypothesis_score"] = current_score
        elif action.action_type == "execute_containment":
            plan = list(action.containment_plan or [])
            positive, negative = containment_alignment(plan, session.task_id)
            # Unknown items get a placeholder instead of raising KeyError after
            # the step has already been counted.
            descriptions = [
                CONTAINMENT_DESCRIPTIONS.get(item, "Unrecognized containment action.") for item in plan
            ]
            for item in plan:
                if item not in session.containment_plan:
                    session.containment_plan.append(item)
            hypothesis_reward = positive
            penalty = min(1.0, negative + (0.05 if not plan else 0.0))
            feedback = "Containment actions recorded."
            info["containment_positive"] = positive
            info["containment_negative"] = negative
            info["containment_descriptions"] = descriptions
        elif action.action_type == "request_more":
            penalty = 0.1
            feedback = "No additional passive data is provided. Use a concrete query."
        else:
            penalty = 0.2
            feedback = "Unsupported action."
        if repeated_action_count > 0:
            penalty = min(1.0, penalty + min(0.2, repeated_action_count * 0.1))
        efficiency_reward = max(
            0.0,
            round(1.0 - ((session.step_number - 1) / max(1, session.max_steps - 1)), 4),
        )
        reward = build_dense_reward(
            signal_reward=signal_reward,
            hypothesis_reward=hypothesis_reward,
            efficiency_reward=efficiency_reward,
            penalty=penalty,
            info=info,
        )
        return reward, feedback

    def _query_logs(self, session: IncidentSession, query: Dict[str, Any]) -> List[int]:
        """Reveal up to ``query["limit"]`` (default 6) unseen logs matching ``query``.

        Matches are ranked most severe first, then by descending ``_seed_rank``.
        The services of revealed logs are marked visited. Returns the newly
        revealed log ids in rank order; a non-positive limit reveals nothing.
        """
        limit = int(query.get("limit", 6))
        if limit <= 0:
            return []
        matched = [log for log in session.logs if self._match_query(log, query)]
        ranked = sorted(matched, key=lambda log: (self._severity_rank(log["log_level"]), -float(log["_seed_rank"])))
        revealed: List[int] = []
        for log in ranked:
            if log["log_id"] in session.visible_log_ids:
                continue
            session.visible_log_ids.add(log["log_id"])
            revealed.append(log["log_id"])
            session.visited_services.add(log["service_name"])
            if len(revealed) >= limit:
                break
        return revealed

    def _inspect_dependencies(self, session: IncidentSession, target_service: str, neighbors: List[str]) -> List[int]:
        """Reveal up to 4 unseen CRITICAL/ERROR/WARN logs around ``target_service``.

        Candidates come from ``target_service`` itself and from those neighbors
        whose names end in "-service". They are ranked by severity, then
        earliest timestamp, then descending ``_seed_rank``. Returns the newly
        revealed log ids.
        """
        candidate_services = {target_service, *[neighbor for neighbor in neighbors if neighbor.endswith("-service")]}
        matched = [
            log
            for log in session.logs
            if log["service_name"] in candidate_services and log["log_level"] in {"CRITICAL", "ERROR", "WARN"}
        ]
        ranked = sorted(matched, key=lambda log: (self._severity_rank(log["log_level"]), log["timestamp"], -float(log["_seed_rank"])))
        revealed: List[int] = []
        for log in ranked:
            if log["log_id"] in session.visible_log_ids:
                continue
            session.visible_log_ids.add(log["log_id"])
            revealed.append(log["log_id"])
            if len(revealed) >= 4:
                break
        return revealed

    @staticmethod
    def _match_query(log: Dict[str, Any], query: Dict[str, Any]) -> bool:
        """Return True if ``log`` satisfies every (truthy) filter in ``query``.

        Timestamps are compared as strings. NOTE(review): this is only correct
        if all timestamps share one sortable format (e.g. ISO-8601); confirm.
        ``text_contains`` is a case-insensitive substring match on the message.
        """
        if query.get("service_name") and log["service_name"] != query["service_name"]:
            return False
        if query.get("server_id") and log["server_id"] != query["server_id"]:
            return False
        if query.get("levels") and log["log_level"] not in set(query["levels"]):
            return False
        if query.get("start_time") and str(log["timestamp"]) < str(query["start_time"]):
            return False
        if query.get("end_time") and str(log["timestamp"]) > str(query["end_time"]):
            return False
        if query.get("text_contains") and query["text_contains"].lower() not in str(log["message"]).lower():
            return False
        return True

    @staticmethod
    def _severity_rank(level: str) -> int:
        """Return the sort rank of a log level (0 = CRITICAL, 4 = unknown)."""
        return SessionStore._SEVERITY_ORDER.get(level, 4)

    @staticmethod
    def _register_action(session: IncidentSession, action: Action) -> int:
        """Record the action's fingerprint and return how often it was seen before.

        The fingerprint joins the action type with whichever payloads are set.
        For reports, only the root cause is included.
        """
        fingerprint_source = [action.action_type]
        if action.query:
            fingerprint_source.append(str(action.query.model_dump(exclude_none=True)))
        if action.target_service:
            fingerprint_source.append(action.target_service)
        if action.hypothesis:
            fingerprint_source.append(str(action.hypothesis.model_dump()))
        if action.containment_plan:
            fingerprint_source.append(",".join(action.containment_plan))
        if action.report:
            fingerprint_source.append(str(action.report.root_cause.model_dump()))
        fingerprint = "::".join(fingerprint_source)
        count = session.query_fingerprints.get(fingerprint, 0)
        session.query_fingerprints[fingerprint] = count + 1
        return count

    def _build_observation(self, session: IncidentSession, feedback: Optional[str]) -> Observation:
        """Assemble the agent-facing Observation: task briefing plus revealed state only."""
        spec = TASK_SPECS[session.task_id]
        return Observation(
            session_id=session.session_id,
            task_id=session.task_id,
            task_title=str(spec["title"]),
            briefing=IncidentBriefing(
                incident_id=str(spec["incident_id"]),
                title=str(spec["title"]),
                objective=str(spec["objective"]),
                incident_window_start=str(spec["incident_window_start"]),
                incident_window_end=str(spec["incident_window_end"]),
                suspected_services=list(spec["suspected_services"]),
                customer_statement=str(spec["customer_statement"]),
                operational_constraints=list(spec["operational_constraints"]),
            ),
            dependency_graph=DEPENDENCY_GRAPH,
            visible_logs=session.visible_logs(),
            revealed_log_count=len(session.visible_log_ids),
            visited_services=sorted(session.visited_services),
            submitted_containment=list(session.containment_plan),
            last_hypothesis=session.last_hypothesis,
            step_number=session.step_number,
            max_steps=session.max_steps,
            feedback=feedback,
            done=session.done,
        )
# Process-wide SessionStore instance; every importer of this module shares it.
store = SessionStore()