CausalOps-Env / env /environment.py
omm7's picture
Upload folder using huggingface_hub
bc2ead7 verified
"""
Session-safe OpenEnv environment with seeded partial observability.
"""
from __future__ import annotations
import os
import random
import threading
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple
from data.db_loader import build_task_log_pool, load_patterns, load_thresholds
from env.models import Action, IncidentBriefing, Observation, Reward, RootCauseHypothesis
from tasks.catalog import CONTAINMENT_DESCRIPTIONS, DEPENDENCY_GRAPH, TASK_SPECS
from tasks.graders import build_dense_reward, containment_alignment, grade_report, hypothesis_match_score
# Debug-state access is opt-in: it is enabled only when the
# OPENENV_DEBUG_STATE environment variable equals "true" (case-insensitive).
DEBUG_STATE_ENABLED = os.environ.get("OPENENV_DEBUG_STATE", "false").lower() == "true"
@dataclass
class IncidentSession:
    """Mutable per-episode state tracked by ``SessionStore``.

    The first seven fields are required at construction time; the
    remaining fields start from empty/zero defaults and are updated as the
    episode progresses.
    """

    # --- fixed at creation -------------------------------------------------
    session_id: str
    task_id: str
    seed: int
    max_steps: int
    logs: List[Dict[str, Any]]  # full log pool for the episode
    thresholds: Dict[str, Dict[str, float]]
    patterns: Dict[str, Dict[str, str]]

    # --- evolving episode state -------------------------------------------
    step_number: int = 0
    done: bool = False
    visible_log_ids: Set[int] = field(default_factory=set)  # ids revealed so far
    visited_services: Set[str] = field(default_factory=set)
    containment_plan: List[str] = field(default_factory=list)
    last_hypothesis: Optional[RootCauseHypothesis] = None
    best_hypothesis_score: float = 0.0
    query_fingerprints: Dict[str, int] = field(default_factory=dict)  # fingerprint -> times seen
    last_reward: Optional[Reward] = None
    episode_history: List[Dict[str, Any]] = field(default_factory=list)

    def visible_logs(self) -> List[Dict[str, Any]]:
        """Return the revealed logs ordered by ``(timestamp, log_id)``."""
        revealed_ids = self.visible_log_ids
        shown = [entry for entry in self.logs if entry["log_id"] in revealed_ids]
        shown.sort(key=lambda entry: (entry["timestamp"], entry["log_id"]))
        return shown

    def log_map(self) -> Dict[int, Dict[str, Any]]:
        """Return a fresh ``log_id -> log`` index over the whole log pool."""
        index: Dict[int, Dict[str, Any]] = {}
        for entry in self.logs:
            index[entry["log_id"]] = entry
        return index
class SessionStore:
    """In-memory registry of incident episodes keyed by session id.

    ``reset`` creates an independent ``IncidentSession`` under a fresh id and
    ``step`` advances one session by a single agent action.  The lock guards
    only the session dictionary; per-session state is mutated outside it.

    NOTE(review): nothing in this class removes sessions from the registry,
    so it grows with every ``reset`` call.
    """

    # Payload attribute that each action type must carry.  Checked before any
    # session state is modified.
    _REQUIRED_PAYLOAD = {
        "submit_report": "report",
        "query_logs": "query",
        "inspect_dependencies": "target_service",
        "update_hypothesis": "hypothesis",
    }

    def __init__(self) -> None:
        """Create an empty store."""
        self._lock = threading.Lock()
        self._sessions: Dict[str, IncidentSession] = {}

    def reset(self, task_id: str = "easy", seed: Optional[int] = None) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Key into ``TASK_SPECS``.
            seed: Seed for the log pool.  When omitted, a per-task default is
                derived from the task's position in ``TASK_SPECS``.

        Raises:
            ValueError: If ``task_id`` is not a known task.
        """
        if task_id not in TASK_SPECS:
            raise ValueError(f"Unknown task_id '{task_id}'.")
        actual_seed = int(seed if seed is not None else 2025 + (list(TASK_SPECS).index(task_id) * 17))
        session_id = uuid.uuid4().hex
        spec = TASK_SPECS[task_id]
        session = IncidentSession(
            session_id=session_id,
            task_id=task_id,
            seed=actual_seed,
            max_steps=int(spec["max_steps"]),
            logs=build_task_log_pool(task_id, actual_seed),
            thresholds=load_thresholds(),
            patterns=load_patterns(),
        )
        with self._lock:
            self._sessions[session_id] = session
        return self._build_observation(
            session,
            feedback="Episode created. Query the incident window and inspect dependencies to build your case.",
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one action to its session.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` is a shallow
            copy of ``reward.info``.

        Raises:
            RuntimeError: If the session is unknown or already finished.
            ValueError: If the action lacks the payload its type requires.
                In that case the session is left untouched.
        """
        session = self._resolve_session(action.session_id)
        if session.done:
            raise RuntimeError("Episode already finished. Call /reset to start a new session.")
        # Reject malformed actions before touching the session, so a bad
        # request neither burns step budget nor counts as a repeated action.
        self._validate_action(action)

        session.step_number += 1
        repeated_action_count = self._register_action(session, action)

        if action.action_type == "submit_report":
            reward = grade_report(
                task_id=session.task_id,
                report=action.report,
                revealed_log_ids=set(session.visible_log_ids),
                revealed_log_map=session.log_map(),
                step_number=session.step_number,
                max_steps=session.max_steps,
                repeated_action_count=repeated_action_count,
            )
            session.done = True
            feedback = "Final report graded."
        elif action.action_type == "no_anomalies":
            reward = build_dense_reward(
                signal_reward=0.0,
                hypothesis_reward=0.0,
                efficiency_reward=0.0,
                penalty=1.0,
                info={"message": "No-incident declaration is invalid for this benchmark."},
            )
            session.done = True
            feedback = "No-incident declaration rejected."
        else:
            reward, feedback = self._handle_non_terminal(session, action, repeated_action_count)
            # Only non-terminal actions can run the episode out of budget.
            if session.step_number >= session.max_steps:
                session.done = True
                feedback = f"{feedback} Step budget exhausted."

        session.last_reward = reward
        session.episode_history.append(
            {
                "step": session.step_number,
                "action_type": action.action_type,
                "reward": reward.value,
                "done": session.done,
            }
        )
        observation = self._build_observation(session, feedback=feedback)
        return observation, reward, session.done, dict(reward.info)

    def public_state(self, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Return a summary of the session that omits the hidden log pool.

        Raises:
            RuntimeError: If the session cannot be resolved.
        """
        session = self._resolve_session(session_id)
        return {
            "session_id": session.session_id,
            "task_id": session.task_id,
            "step_number": session.step_number,
            "max_steps": session.max_steps,
            "done": session.done,
            "revealed_log_count": len(session.visible_log_ids),
            "visited_services": sorted(session.visited_services),
            "submitted_containment": list(session.containment_plan),
            "last_reward": session.last_reward.model_dump() if session.last_reward else None,
        }

    def debug_state(self, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Return full internal state, including every log in the pool.

        Raises:
            PermissionError: If ``DEBUG_STATE_ENABLED`` is false.
            RuntimeError: If the session cannot be resolved.
        """
        if not DEBUG_STATE_ENABLED:
            raise PermissionError("Debug state is disabled.")
        session = self._resolve_session(session_id)
        return {
            "session_id": session.session_id,
            "task_id": session.task_id,
            "seed": session.seed,
            "visible_log_ids": sorted(session.visible_log_ids),
            "all_logs": session.logs,
            "history": session.episode_history,
            "best_hypothesis_score": session.best_hypothesis_score,
        }

    def _resolve_session(self, session_id: Optional[str]) -> IncidentSession:
        """Look up a session, defaulting to the only one when no id is given.

        Raises:
            RuntimeError: If the id is unknown, or if no id is given and the
                store does not hold exactly one session.
        """
        with self._lock:
            if session_id:
                session = self._sessions.get(session_id)
                if session is None:
                    raise RuntimeError(f"Unknown session_id '{session_id}'.")
                return session
            if len(self._sessions) == 1:
                return next(iter(self._sessions.values()))
            raise RuntimeError("A valid session_id is required.")

    @classmethod
    def _validate_action(cls, action: Action) -> None:
        """Raise ``ValueError`` if the action lacks its required payload.

        Action types without an entry in ``_REQUIRED_PAYLOAD`` always pass.
        """
        field_name = cls._REQUIRED_PAYLOAD.get(action.action_type)
        if field_name is not None and getattr(action, field_name) is None:
            raise ValueError(f"{action.action_type} requires {field_name}")

    def _handle_non_terminal(
        self,
        session: IncidentSession,
        action: Action,
        repeated_action_count: int,
    ) -> Tuple[Reward, str]:
        """Apply an action that does not end the episode by itself.

        Returns:
            The dense reward for the action and a feedback message.

        Raises:
            ValueError: If the action lacks the payload its type requires.
        """
        # Cheap re-check so this helper is also safe when called directly.
        self._validate_action(action)

        signal_reward = 0.0
        hypothesis_reward = 0.0
        penalty = 0.0
        info: Dict[str, Any] = {}

        if action.action_type == "query_logs":
            newly_revealed = self._query_logs(session, action.query.model_dump(exclude_none=True))
            relevant = set(TASK_SPECS[session.task_id]["gold_evidence_ids"])
            relevant_new = len(relevant & set(newly_revealed))
            signal_reward = min(1.0, round((0.22 * len(newly_revealed)) + (0.28 * relevant_new), 4))
            # A query that surfaces nothing new is penalized.
            penalty = 0.15 if not newly_revealed else 0.0
            feedback = f"Query revealed {len(newly_revealed)} new log(s)."
            info["revealed_log_ids"] = newly_revealed
        elif action.action_type == "inspect_dependencies":
            session.visited_services.add(action.target_service)
            neighbors = DEPENDENCY_GRAPH.get(action.target_service, [])
            revealed = self._inspect_dependencies(session, action.target_service, neighbors)
            relevant = set(TASK_SPECS[session.task_id]["gold_evidence_ids"])
            signal_reward = min(1.0, round((0.15 * len(revealed)) + (0.35 * len(relevant & set(revealed))), 4))
            penalty = 0.1 if not revealed else 0.0
            feedback = f"Dependency inspection around {action.target_service} revealed {len(revealed)} new log(s)."
            info["neighbors"] = neighbors
            info["revealed_log_ids"] = revealed
        elif action.action_type == "update_hypothesis":
            current_score = hypothesis_match_score(action.hypothesis, session.task_id)
            previous_best = session.best_hypothesis_score
            improvement = max(0.0, current_score - previous_best)
            session.best_hypothesis_score = max(previous_best, current_score)
            session.last_hypothesis = action.hypothesis
            # Only improvement over the best score so far is rewarded.
            hypothesis_reward = improvement
            # Falling below the previous best is penalized; matching it is not.
            penalty = 0.15 if current_score < previous_best else 0.0
            feedback = "Hypothesis recorded."
            info["hypothesis_score"] = current_score
        elif action.action_type == "execute_containment":
            plan = list(action.containment_plan or [])
            positive, negative = containment_alignment(plan, session.task_id)
            for item in plan:
                if item not in session.containment_plan:
                    session.containment_plan.append(item)
            hypothesis_reward = positive
            # An empty plan costs a small extra penalty.
            penalty = min(1.0, negative + (0.05 if not plan else 0.0))
            feedback = "Containment actions recorded."
            info["containment_positive"] = positive
            info["containment_negative"] = negative
            # Use a fallback so an unrecognized id cannot raise KeyError
            # after the plan has already been merged into the session.
            info["containment_descriptions"] = [
                CONTAINMENT_DESCRIPTIONS.get(item, "Unknown containment action.") for item in plan
            ]
        elif action.action_type == "request_more":
            penalty = 0.1
            feedback = "No additional passive data is provided. Use a concrete query."
        else:
            penalty = 0.2
            feedback = "Unsupported action."

        if repeated_action_count > 0:
            # The repeat surcharge grows by 0.1 per repetition, capped at 0.2.
            penalty = min(1.0, penalty + min(0.2, repeated_action_count * 0.1))

        # Falls linearly from 1.0 on the first step towards 0.0 on the last.
        efficiency_reward = max(
            0.0,
            round(1.0 - ((session.step_number - 1) / max(1, session.max_steps - 1)), 4),
        )
        reward = build_dense_reward(
            signal_reward=signal_reward,
            hypothesis_reward=hypothesis_reward,
            efficiency_reward=efficiency_reward,
            penalty=penalty,
            info=info,
        )
        return reward, feedback

    def _query_logs(self, session: IncidentSession, query: Dict[str, Any]) -> List[int]:
        """Reveal up to ``query["limit"]`` (default 6) unseen matching logs.

        Matches are taken most severe first, then by descending
        ``_seed_rank``.  Revealed ids are added to
        ``session.visible_log_ids`` and their services to
        ``session.visited_services``.

        Returns:
            The newly revealed log ids in reveal order; empty when the limit
            is not positive.
        """
        limit = int(query.get("limit", 6))
        if limit <= 0:
            # Without this guard the loop reveals one log before it checks the limit.
            return []
        matched = [log for log in session.logs if self._match_query(log, query)]
        ranked = sorted(matched, key=lambda log: (self._severity_rank(log["log_level"]), -float(log["_seed_rank"])))
        revealed: List[int] = []
        for log in ranked:
            if log["log_id"] in session.visible_log_ids:
                continue
            session.visible_log_ids.add(log["log_id"])
            revealed.append(log["log_id"])
            session.visited_services.add(log["service_name"])
            if len(revealed) >= limit:
                break
        return revealed

    def _inspect_dependencies(self, session: IncidentSession, target_service: str, neighbors: List[str]) -> List[int]:
        """Reveal up to four unseen WARN-or-worse logs around a service.

        Candidates are the target service plus the neighbors whose name ends
        in ``-service``.  Logs are taken by severity, then timestamp, then
        descending ``_seed_rank``.

        Returns:
            The newly revealed log ids in reveal order.
        """
        candidate_services = {target_service, *[neighbor for neighbor in neighbors if neighbor.endswith("-service")]}
        matched = [
            log
            for log in session.logs
            if log["service_name"] in candidate_services and log["log_level"] in {"CRITICAL", "ERROR", "WARN"}
        ]
        ranked = sorted(
            matched,
            key=lambda log: (self._severity_rank(log["log_level"]), log["timestamp"], -float(log["_seed_rank"])),
        )
        revealed: List[int] = []
        for log in ranked:
            if log["log_id"] in session.visible_log_ids:
                continue
            session.visible_log_ids.add(log["log_id"])
            revealed.append(log["log_id"])
            if len(revealed) >= 4:
                break
        return revealed

    @staticmethod
    def _match_query(log: Dict[str, Any], query: Dict[str, Any]) -> bool:
        """Return whether ``log`` satisfies every filter present in ``query``.

        Missing or falsy filters are ignored.  Time bounds are inclusive and
        compared as strings.
        """
        if query.get("service_name") and log["service_name"] != query["service_name"]:
            return False
        if query.get("server_id") and log["server_id"] != query["server_id"]:
            return False
        if query.get("levels") and log["log_level"] not in set(query["levels"]):
            return False
        # NOTE(review): string comparison assumes both sides share one
        # lexicographically sortable timestamp format - confirm.
        if query.get("start_time") and str(log["timestamp"]) < str(query["start_time"]):
            return False
        if query.get("end_time") and str(log["timestamp"]) > str(query["end_time"]):
            return False
        if query.get("text_contains") and query["text_contains"].lower() not in str(log["message"]).lower():
            return False
        return True

    @staticmethod
    def _severity_rank(level: str) -> int:
        """Return a sort rank for a log level: 0 is most severe, unknown is 4."""
        order = {"CRITICAL": 0, "ERROR": 1, "WARN": 2, "INFO": 3}
        return order.get(level, 4)

    @staticmethod
    def _register_action(session: IncidentSession, action: Action) -> int:
        """Record the action's fingerprint and return its prior occurrence count.

        The fingerprint joins the action type with whichever payloads are
        truthy.  For a report only the root cause contributes.
        """
        fingerprint_source = [action.action_type]
        if action.query:
            fingerprint_source.append(str(action.query.model_dump(exclude_none=True)))
        if action.target_service:
            fingerprint_source.append(action.target_service)
        if action.hypothesis:
            fingerprint_source.append(str(action.hypothesis.model_dump()))
        if action.containment_plan:
            fingerprint_source.append(",".join(action.containment_plan))
        if action.report:
            fingerprint_source.append(str(action.report.root_cause.model_dump()))
        fingerprint = "::".join(fingerprint_source)
        count = session.query_fingerprints.get(fingerprint, 0)
        session.query_fingerprints[fingerprint] = count + 1
        return count

    def _build_observation(self, session: IncidentSession, feedback: Optional[str]) -> Observation:
        """Assemble the observation for the session's current state."""
        spec = TASK_SPECS[session.task_id]
        return Observation(
            session_id=session.session_id,
            task_id=session.task_id,
            task_title=str(spec["title"]),
            briefing=IncidentBriefing(
                incident_id=str(spec["incident_id"]),
                title=str(spec["title"]),
                objective=str(spec["objective"]),
                incident_window_start=str(spec["incident_window_start"]),
                incident_window_end=str(spec["incident_window_end"]),
                suspected_services=list(spec["suspected_services"]),
                customer_statement=str(spec["customer_statement"]),
                operational_constraints=list(spec["operational_constraints"]),
            ),
            dependency_graph=DEPENDENCY_GRAPH,
            visible_logs=session.visible_logs(),
            revealed_log_count=len(session.visible_log_ids),
            visited_services=sorted(session.visited_services),
            submitted_containment=list(session.containment_plan),
            last_hypothesis=session.last_hypothesis,
            step_number=session.step_number,
            max_steps=session.max_steps,
            feedback=feedback,
            done=session.done,
        )
# Module-level SessionStore instance, created when this module is imported.
store = SessionStore()