""" Session-safe OpenEnv environment with seeded partial observability. """ from __future__ import annotations import os import random import threading import uuid from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Set, Tuple from data.db_loader import build_task_log_pool, load_patterns, load_thresholds from env.models import Action, IncidentBriefing, Observation, Reward, RootCauseHypothesis from tasks.catalog import CONTAINMENT_DESCRIPTIONS, DEPENDENCY_GRAPH, TASK_SPECS from tasks.graders import build_dense_reward, containment_alignment, grade_report, hypothesis_match_score DEBUG_STATE_ENABLED = os.getenv("OPENENV_DEBUG_STATE", "false").lower() == "true" @dataclass class IncidentSession: session_id: str task_id: str seed: int max_steps: int logs: List[Dict[str, Any]] thresholds: Dict[str, Dict[str, float]] patterns: Dict[str, Dict[str, str]] step_number: int = 0 done: bool = False visible_log_ids: Set[int] = field(default_factory=set) visited_services: Set[str] = field(default_factory=set) containment_plan: List[str] = field(default_factory=list) last_hypothesis: Optional[RootCauseHypothesis] = None best_hypothesis_score: float = 0.0 query_fingerprints: Dict[str, int] = field(default_factory=dict) last_reward: Optional[Reward] = None episode_history: List[Dict[str, Any]] = field(default_factory=list) def visible_logs(self) -> List[Dict[str, Any]]: visible = [log for log in self.logs if log["log_id"] in self.visible_log_ids] return sorted(visible, key=lambda log: (log["timestamp"], log["log_id"])) def log_map(self) -> Dict[int, Dict[str, Any]]: return {log["log_id"]: log for log in self.logs} class SessionStore: def __init__(self) -> None: self._lock = threading.Lock() self._sessions: Dict[str, IncidentSession] = {} def reset(self, task_id: str = "easy", seed: Optional[int] = None) -> Observation: if task_id not in TASK_SPECS: raise ValueError(f"Unknown task_id '{task_id}'.") actual_seed = int(seed if seed is not None else 2025 + (list(TASK_SPECS).index(task_id) * 17)) session_id = uuid.uuid4().hex spec = TASK_SPECS[task_id] session = IncidentSession( session_id=session_id, task_id=task_id, seed=actual_seed, max_steps=int(spec["max_steps"]), logs=build_task_log_pool(task_id, actual_seed), thresholds=load_thresholds(), patterns=load_patterns(), ) with self._lock: self._sessions[session_id] = session return self._build_observation( session, feedback="Episode created. Query the incident window and inspect dependencies to build your case.", ) def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]: session = self._resolve_session(action.session_id) if session.done: raise RuntimeError("Episode already finished. Call /reset to start a new session.") session.step_number += 1 repeated_action_count = self._register_action(session, action) if action.action_type == "submit_report": if action.report is None: raise ValueError("submit_report requires report") reward = grade_report( task_id=session.task_id, report=action.report, revealed_log_ids=set(session.visible_log_ids), revealed_log_map=session.log_map(), step_number=session.step_number, max_steps=session.max_steps, repeated_action_count=repeated_action_count, ) session.done = True feedback = "Final report graded." elif action.action_type == "no_anomalies": reward = build_dense_reward( signal_reward=0.0, hypothesis_reward=0.0, efficiency_reward=0.0, penalty=1.0, info={"message": "No-incident declaration is invalid for this benchmark."}, ) session.done = True feedback = "No-incident declaration rejected." else: reward, feedback = self._handle_non_terminal(session, action, repeated_action_count) if session.step_number >= session.max_steps: session.done = True feedback = f"{feedback} Step budget exhausted." session.last_reward = reward session.episode_history.append( { "step": session.step_number, "action_type": action.action_type, "reward": reward.value, "done": session.done, } ) observation = self._build_observation(session, feedback=feedback) return observation, reward, session.done, dict(reward.info) def public_state(self, session_id: Optional[str] = None) -> Dict[str, Any]: session = self._resolve_session(session_id) return { "session_id": session.session_id, "task_id": session.task_id, "step_number": session.step_number, "max_steps": session.max_steps, "done": session.done, "revealed_log_count": len(session.visible_log_ids), "visited_services": sorted(session.visited_services), "submitted_containment": list(session.containment_plan), "last_reward": session.last_reward.model_dump() if session.last_reward else None, } def debug_state(self, session_id: Optional[str] = None) -> Dict[str, Any]: if not DEBUG_STATE_ENABLED: raise PermissionError("Debug state is disabled.") session = self._resolve_session(session_id) return { "session_id": session.session_id, "task_id": session.task_id, "seed": session.seed, "visible_log_ids": sorted(session.visible_log_ids), "all_logs": session.logs, "history": session.episode_history, "best_hypothesis_score": session.best_hypothesis_score, } def _resolve_session(self, session_id: Optional[str]) -> IncidentSession: with self._lock: if session_id: session = self._sessions.get(session_id) if session is None: raise RuntimeError(f"Unknown session_id '{session_id}'.") return session if len(self._sessions) == 1: return next(iter(self._sessions.values())) raise RuntimeError("A valid session_id is required.") def _handle_non_terminal( self, session: IncidentSession, action: Action, repeated_action_count: int, ) -> Tuple[Reward, str]: signal_reward = 0.0 hypothesis_reward = 0.0 penalty = 0.0 info: Dict[str, Any] = {} if action.action_type == "query_logs": if action.query is None: raise ValueError("query_logs requires query") newly_revealed = self._query_logs(session, action.query.model_dump(exclude_none=True)) relevant = set(TASK_SPECS[session.task_id]["gold_evidence_ids"]) relevant_new = len(relevant & set(newly_revealed)) signal_reward = min(1.0, round((0.22 * len(newly_revealed)) + (0.28 * relevant_new), 4)) penalty = 0.15 if not newly_revealed else 0.0 feedback = f"Query revealed {len(newly_revealed)} new log(s)." info["revealed_log_ids"] = newly_revealed elif action.action_type == "inspect_dependencies": if action.target_service is None: raise ValueError("inspect_dependencies requires target_service") session.visited_services.add(action.target_service) neighbors = DEPENDENCY_GRAPH.get(action.target_service, []) revealed = self._inspect_dependencies(session, action.target_service, neighbors) relevant = set(TASK_SPECS[session.task_id]["gold_evidence_ids"]) signal_reward = min(1.0, round((0.15 * len(revealed)) + (0.35 * len(relevant & set(revealed))), 4)) penalty = 0.1 if not revealed else 0.0 feedback = f"Dependency inspection around {action.target_service} revealed {len(revealed)} new log(s)." info["neighbors"] = neighbors info["revealed_log_ids"] = revealed elif action.action_type == "update_hypothesis": if action.hypothesis is None: raise ValueError("update_hypothesis requires hypothesis") current_score = hypothesis_match_score(action.hypothesis, session.task_id) improvement = max(0.0, current_score - session.best_hypothesis_score) session.best_hypothesis_score = max(session.best_hypothesis_score, current_score) session.last_hypothesis = action.hypothesis hypothesis_reward = improvement penalty = 0.15 if improvement == 0.0 and current_score < session.best_hypothesis_score else 0.0 feedback = "Hypothesis recorded." info["hypothesis_score"] = current_score elif action.action_type == "execute_containment": plan = list(action.containment_plan or []) positive, negative = containment_alignment(plan, session.task_id) for item in plan: if item not in session.containment_plan: session.containment_plan.append(item) hypothesis_reward = positive penalty = min(1.0, negative + (0.05 if not plan else 0.0)) feedback = "Containment actions recorded." info["containment_positive"] = positive info["containment_negative"] = negative info["containment_descriptions"] = [CONTAINMENT_DESCRIPTIONS[item] for item in plan] elif action.action_type == "request_more": penalty = 0.1 feedback = "No additional passive data is provided. Use a concrete query." else: penalty = 0.2 feedback = "Unsupported action." if repeated_action_count > 0: penalty = min(1.0, penalty + min(0.2, repeated_action_count * 0.1)) efficiency_reward = max( 0.0, round(1.0 - ((session.step_number - 1) / max(1, session.max_steps - 1)), 4), ) reward = build_dense_reward( signal_reward=signal_reward, hypothesis_reward=hypothesis_reward, efficiency_reward=efficiency_reward, penalty=penalty, info=info, ) return reward, feedback def _query_logs(self, session: IncidentSession, query: Dict[str, Any]) -> List[int]: matched = [log for log in session.logs if self._match_query(log, query)] ranked = sorted(matched, key=lambda log: (self._severity_rank(log["log_level"]), -float(log["_seed_rank"]))) revealed: List[int] = [] for log in ranked: if log["log_id"] in session.visible_log_ids: continue session.visible_log_ids.add(log["log_id"]) revealed.append(log["log_id"]) session.visited_services.add(log["service_name"]) if len(revealed) >= int(query.get("limit", 6)): break return revealed def _inspect_dependencies(self, session: IncidentSession, target_service: str, neighbors: List[str]) -> List[int]: candidate_services = {target_service, *[neighbor for neighbor in neighbors if neighbor.endswith("-service")]} matched = [ log for log in session.logs if log["service_name"] in candidate_services and log["log_level"] in {"CRITICAL", "ERROR", "WARN"} ] ranked = sorted(matched, key=lambda log: (self._severity_rank(log["log_level"]), log["timestamp"], -float(log["_seed_rank"]))) revealed: List[int] = [] for log in ranked: if log["log_id"] in session.visible_log_ids: continue session.visible_log_ids.add(log["log_id"]) revealed.append(log["log_id"]) if len(revealed) >= 4: break return revealed @staticmethod def _match_query(log: Dict[str, Any], query: Dict[str, Any]) -> bool: if query.get("service_name") and log["service_name"] != query["service_name"]: return False if query.get("server_id") and log["server_id"] != query["server_id"]: return False if query.get("levels") and log["log_level"] not in set(query["levels"]): return False if query.get("start_time") and str(log["timestamp"]) < str(query["start_time"]): return False if query.get("end_time") and str(log["timestamp"]) > str(query["end_time"]): return False if query.get("text_contains") and query["text_contains"].lower() not in str(log["message"]).lower(): return False return True @staticmethod def _severity_rank(level: str) -> int: order = {"CRITICAL": 0, "ERROR": 1, "WARN": 2, "INFO": 3} return order.get(level, 4) @staticmethod def _register_action(session: IncidentSession, action: Action) -> int: fingerprint_source = [action.action_type] if action.query: fingerprint_source.append(str(action.query.model_dump(exclude_none=True))) if action.target_service: fingerprint_source.append(action.target_service) if action.hypothesis: fingerprint_source.append(str(action.hypothesis.model_dump())) if action.containment_plan: fingerprint_source.append(",".join(action.containment_plan)) if action.report: fingerprint_source.append(str(action.report.root_cause.model_dump())) fingerprint = "::".join(fingerprint_source) count = session.query_fingerprints.get(fingerprint, 0) session.query_fingerprints[fingerprint] = count + 1 return count def _build_observation(self, session: IncidentSession, feedback: Optional[str]) -> Observation: spec = TASK_SPECS[session.task_id] return Observation( session_id=session.session_id, task_id=session.task_id, task_title=str(spec["title"]), briefing=IncidentBriefing( incident_id=str(spec["incident_id"]), title=str(spec["title"]), objective=str(spec["objective"]), incident_window_start=str(spec["incident_window_start"]), incident_window_end=str(spec["incident_window_end"]), suspected_services=list(spec["suspected_services"]), customer_statement=str(spec["customer_statement"]), operational_constraints=list(spec["operational_constraints"]), ), dependency_graph=DEPENDENCY_GRAPH, visible_logs=session.visible_logs(), revealed_log_count=len(session.visible_log_ids), visited_services=sorted(session.visited_services), submitted_containment=list(session.containment_plan), last_hypothesis=session.last_hypothesis, step_number=session.step_number, max_steps=session.max_steps, feedback=feedback, done=session.done, ) store = SessionStore()