Spaces:
Sleeping
Sleeping
| """SentinelEnvironment - core OpenEnv Environment for SRE incident triage.""" | |
| import uuid | |
| from typing import Any, Optional | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import Observation | |
| from grading.grader import grade | |
| from grading.rewards import compute_step_reward, _call_signature, _is_relevant | |
| from models import SentinelAction, SentinelObservation, SentinelState | |
| from scenarios.task1_smoking_gun import SmokingGunScenario | |
| from scenarios.task2_upstream_culprit import UpstreamCulpritScenario | |
| from scenarios.task3_cascading_failure import CascadingFailureScenario | |
| from tools.registry import AVAILABLE_TOOLS, dispatch, make_relevance_key | |
# Episode limits checked by SentinelEnvironment.step().
MAX_STEPS = 20  # an episode is marked done once step_count reaches this value
MAX_CONSECUTIVE_INVALID = 5  # invalid tool calls in a row that end the episode

# Scenario classes keyed by 1-based task id; reset() instantiates the one
# selected through its ``task_id`` keyword argument.
SCENARIOS: dict[int, type] = dict(
    enumerate(
        (
            SmokingGunScenario,
            UpstreamCulpritScenario,
            CascadingFailureScenario,
        ),
        start=1,
    )
)

# Human-readable task names, keyed by the same 1-based task ids as SCENARIOS.
TASK_NAMES: dict[int, str] = dict(
    enumerate(
        (
            "The Smoking Gun",
            "The Upstream Culprit",
            "The Cascading Failure",
        ),
        start=1,
    )
)
class SentinelEnvironment(Environment):
    """OpenEnv Environment for SRE incident triage.

    An episode starts with ``reset()``, which instantiates one of the
    scenarios in ``SCENARIOS``. Each ``step()`` either dispatches a tool call
    against that scenario or grades a ``submit_resolution`` action. The
    episode is marked done when a resolution is graded, when ``MAX_STEPS``
    steps have been taken, or after ``MAX_CONSECUTIVE_INVALID`` invalid
    actions in a row.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        """Create an environment with no active scenario; call ``reset()`` next."""
        super().__init__()
        self._scenario = None
        self._state = SentinelState()
        self._cumulative_reward: float = 0.0
        # Signatures of earlier tool calls, handed to compute_step_reward().
        self._previous_calls: list[str] = []
        self._consecutive_invalid: int = 0

    # -- helpers ------------------------------------------------------------

    def _make_obs(
        self,
        *,
        tool_output: str = "",
        error: str = "",
        done: bool = False,
        reward: Optional[float] = None,
        tool_descriptions: Optional[dict] = None,
    ) -> SentinelObservation:
        """Build an observation from the current episode state.

        Args:
            tool_output: Text reported as the result of the last action.
            error: Message stored in ``last_action_error`` ("" for none).
            done: Whether the episode has ended.
            reward: Reward attached to this observation.
            tool_descriptions: Optional tool description mapping; ``None``
                becomes an empty dict.

        Returns:
            A ``SentinelObservation`` carrying the scenario's incident summary
            (empty when no scenario is active), the current step count and the
            cumulative reward.
        """
        return SentinelObservation(
            incident_summary=self._scenario.get_incident_summary() if self._scenario else "",
            tool_output=tool_output,
            available_tools=AVAILABLE_TOOLS,
            step_number=self._state.step_count,
            max_steps=MAX_STEPS,
            cumulative_reward=self._cumulative_reward,
            last_action_error=error,
            done=done,
            reward=reward,
            tool_descriptions=tool_descriptions or {},
        )

    def _apply_termination(self, output: str, step_num: int) -> tuple[bool, str]:
        """Check both episode limits and annotate ``output`` when one is hit.

        Args:
            output: Tool output (or error text) of the action just taken.
            step_num: 1-based index of the step just taken.

        Returns:
            ``(done, output)``. A termination notice is appended to ``output``
            for each limit reached; hitting a limit also sets
            ``final_score`` to 0.0 on the state.
        """
        done = False
        if self._consecutive_invalid >= MAX_CONSECUTIVE_INVALID:
            done = True
            output = f"{output}\n\nEpisode terminated: {MAX_CONSECUTIVE_INVALID} consecutive invalid actions."
            self._state.final_score = 0.0
        if step_num >= MAX_STEPS:
            done = True
            output = f"{output}\n\nEpisode terminated: maximum steps ({MAX_STEPS}) reached."
            self._state.final_score = 0.0
        return done, output

    def _handle_submit(self, params: dict, step_num: int) -> SentinelObservation:
        """Grade a ``submit_resolution`` action.

        Args:
            params: Action parameters; ``root_cause``, ``affected_service``
                and ``recommendation`` must all be present and non-empty.
            step_num: 1-based index of the current step, passed to the grader.

        Returns:
            A terminal observation whose reward is the graded score, or, for a
            malformed submission, a zero-reward error observation.
        """
        required = ("root_cause", "affected_service", "recommendation")
        missing = [field for field in required if not params.get(field)]
        if missing:
            # A malformed submission is an invalid action: it counts toward the
            # consecutive-invalid limit and is subject to the step limit.
            # Returning done=False unconditionally here would let an episode
            # run past MAX_STEPS for as long as the agent keeps sending
            # incomplete submissions.
            self._consecutive_invalid += 1
            done, notice = self._apply_termination("", step_num)
            return self._make_obs(
                # lstrip() drops the blank-line separator that precedes the
                # notice when there is no tool output in front of it.
                tool_output=notice.lstrip(),
                error=f"submit_resolution requires: {', '.join(missing)}",
                done=done,
                reward=0.0,
            )

        self._consecutive_invalid = 0
        result = grade(self._scenario, params, step_num)

        state = self._state
        state.resolution_submitted = True
        state.root_cause_correct = result["root_cause_correct"]
        state.recommendation_correct = result["recommendation_correct"]
        state.final_score = result["score"]
        state.tools_called.append("submit_resolution")

        return self._make_obs(
            tool_output=f"Resolution graded. Score: {result['score']:.2f}",
            done=True,
            reward=result["score"],
        )

    # -- Environment interface ----------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> SentinelObservation:
        """Start a new episode.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when omitted or empty.
            **kwargs: ``task_id`` selects the scenario (a key of
                ``SCENARIOS``). A missing or unknown id falls back to task 1.

        Returns:
            The initial observation, including the scenario's tool
            descriptions; ``done`` is False and ``reward`` is None.
        """
        task_id = kwargs.get("task_id", 1)
        if task_id not in SCENARIOS:
            task_id = 1

        self._scenario = SCENARIOS[task_id]()
        self._cumulative_reward = 0.0
        self._previous_calls = []
        self._consecutive_invalid = 0
        self._state = SentinelState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            task_id=task_id,
            task_name=TASK_NAMES[task_id],
        )
        return self._make_obs(
            done=False,
            reward=None,
            tool_descriptions=self._scenario.get_tool_descriptions(),
        )

    def step(
        self,
        action: SentinelAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> SentinelObservation:
        """Execute one agent action.

        Args:
            action: The action to run; its ``tool_name`` and ``param_dict()``
                are read.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Ignored.

        Returns:
            The observation for this step. If ``reset()`` has not been called,
            a terminal error observation with reward 0.0 is returned instead.
        """
        if self._scenario is None:
            return SentinelObservation(
                incident_summary="",
                last_action_error="Environment not reset. Call reset() first.",
                done=True,
                reward=0.0,
            )

        self._state.step_count += 1
        step_num = self._state.step_count
        tool_name = action.tool_name
        params = action.param_dict()

        # -- submit_resolution ----------------------------------------------
        if tool_name == "submit_resolution":
            return self._handle_submit(params, step_num)

        # -- regular tool dispatch ------------------------------------------
        output, is_valid = dispatch(tool_name, params, self._scenario)
        if is_valid:
            self._consecutive_invalid = 0
        else:
            self._consecutive_invalid += 1

        # The reward must be computed before this call is recorded in
        # _previous_calls, which is passed in as the history of earlier calls.
        reward = compute_step_reward(
            tool_name,
            params,
            is_valid,
            self._scenario.get_relevant_tools(),
            self._previous_calls,
        )
        self._cumulative_reward += reward

        # Track the call, both as a signature and as a readable entry.
        self._previous_calls.append(_call_signature(tool_name, params))
        call_label = f"{tool_name}({params})"
        self._state.tools_called.append(call_label)
        if is_valid and _is_relevant(tool_name, params, self._scenario.get_relevant_tools()):
            self._state.relevant_tools_called.append(call_label)

        # -- check termination ----------------------------------------------
        # For an invalid call, dispatch's output text doubles as the error.
        error_msg = "" if is_valid else output
        done, output = self._apply_termination(output, step_num)

        return self._make_obs(
            tool_output=output,
            error=error_msg,
            done=done,
            reward=reward,
        )

    # NOTE(review): OpenEnv's Environment base class appears to declare
    # ``state`` as a property; confirm whether this should carry @property.
    # Left as a plain method here so existing ``env.state()`` callers keep
    # working.
    def state(self) -> SentinelState:
        """Return the current episode state object."""
        return self._state