from __future__ import annotations import time from typing import Any from .actions import ActionParseError, parse_action from .curriculum import CurriculumManager from .direction_library import DirectionLibrary from .executor import EnvironmentExecutor from .observations import build_blue_observation, build_red_observation from .constants import MAX_TURNS, STEP_TIMEOUT_SECS from .rewards import compute_blue_reward, compute_red_reward from .secrets_bank import RulesBank, SecretsBank from .state import BlueAction, BlueResources, RedAction, TargetSystemState from .target_system import build_target_system try: from openenv import Environment _BASE = Environment except ImportError: _BASE = object # graceful fallback for local dev class SeigeEnv(_BASE): def __init__(self) -> None: self.direction_library = DirectionLibrary() self.target = build_target_system(self.direction_library) self.executor = EnvironmentExecutor(self.target) self.curriculum = CurriculumManager() self.secrets_bank = SecretsBank() self.rules_bank = RulesBank() self._state: TargetSystemState | None = None self._episode_red_rewards: list[float] = [] self._episode_blue_rewards: list[float] = [] def reset(self) -> dict: config = self.curriculum.get_config() self._state = TargetSystemState.sample( secrets_bank=self.secrets_bank, rules_bank=self.rules_bank, baseline=self.target.baseline_means, num_sessions=config["num_sessions"], ) self._state.stage = self.curriculum.stage self._state.blue_resources = BlueResources( probe_budget=config["probe_budget"], log_sample_rate=config["log_sample_rate"], latency_ticks=config["latency_ticks"], ) self._episode_red_rewards = [] self._episode_blue_rewards = [] return { "red": build_red_observation(self._state, config, self.target.get_num_layers()).to_dict(), "blue": build_blue_observation(self._state, config).to_dict(), } def step(self, action: dict[str, Any] | str) -> dict: if self._state is None: raise RuntimeError("Call reset() before step()") start = time.time() try: parsed = parse_action(action) except ActionParseError as exc: return self._error_result(str(exc)) if time.time() - start > STEP_TIMEOUT_SECS: return self._timeout_result() if isinstance(parsed, RedAction): result = self.executor.execute_red(parsed, self._state) reward = compute_red_reward(result, self._state, result.strategy_embedding) self._episode_red_rewards.append(reward) observation = build_red_observation( self._state, self.curriculum.get_config(), self.target.get_num_layers(), ).to_dict() elif isinstance(parsed, BlueAction): result = self.executor.execute_blue(parsed, self._state) reward = compute_blue_reward(result, self._state) self._episode_blue_rewards.append(reward) observation = build_blue_observation(self._state, self.curriculum.get_config()).to_dict() else: return self._error_result("Unknown parsed action") self._state.tick += 1 self._state.blue_resources.probes_used = 0 done = ( result.episode_terminal or self._state.tick >= MAX_TURNS or any(session.extraction_succeeded for session in self._state.sessions.values()) ) if done: avg_red = sum(self._episode_red_rewards) / max(1, len(self._episode_red_rewards)) avg_blue = sum(self._episode_blue_rewards) / max(1, len(self._episode_blue_rewards)) self.curriculum.record(avg_red, avg_blue) if self.curriculum.should_advance(): self.curriculum.advance() info_dict = result.info_dict() reward_info = { "reward/total": reward, "reward/extraction": 10.0 * float(result.extraction_succeeded), "reward/policy_bypass": 5.0 * float(result.policy_bypassed), "reward/evasion": 4.0 * float(result.success and not result.detected), "reward/detected_penalty": -3.0 * float(result.detected), } info_dict.update(reward_info) try: import wandb if wandb.run: wandb.log(reward_info, step=self._state.tick) except ImportError: pass return { "observation": observation, "reward": reward, "done": done, "info": info_dict, } def state(self) -> dict: if self._state is None: return {} return { "tick": self._state.tick, "stage": self._state.stage, "num_sessions": len(self._state.sessions), "active_attacks": sum(1 for session in self._state.sessions.values() if session.attack_active), "detections": sum(1 for session in self._state.sessions.values() if session.detected), } def _error_result(self, message: str) -> dict: return {"observation": {}, "reward": -1.0, "done": False, "info": {"error": message}} def _timeout_result(self) -> dict: return {"observation": {}, "reward": -2.0, "done": True, "info": {"error": "timeout"}}