"""Reward computation for task episodes.

The engine grades the task's runtime filesystem after each command and turns
the result into a ``RewardSignal`` made of a health delta, a knowledge delta
and an action penalty.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
import re

from sysadmin_env.models import RewardSignal
from sysadmin_env.models import TaskScenarioDefinition
from sysadmin_env.models import TaskScenarioState
from sysadmin_env.tasks import get_task_module


# Penalty added to the reward of every ordinary (non-terminal) step.
DEFAULT_STEP_PENALTY = -0.01

# Reward assigned when a command matches one of the destructive patterns.
DEFAULT_CATASTROPHIC_PENALTY = -1.0

# Regular expressions searched (case-insensitively) in each command; any match
# marks the command as catastrophic.
DEFAULT_DESTRUCTIVE_COMMAND_PATTERNS = (
    r"(^|\s)rm\s+-rf\s+/($|\s)",
    r"(^|\s)rm\s+-rf\s+--no-preserve-root($|\s)",
    r"(^|\s)mkfs(\.|\s|$)",
    r"(^|\s)shutdown(\s|$)",
    r"(^|\s)reboot(\s|$)",
    r"(^|\s)halt(\s|$)",
    r"(^|\s)kill\s+(-9\s+)?1($|\s)",
    r"(^|\s)(dd|truncate)\b.*(of=|>)\s*/(etc|boot)(/|\s|$)",
    r":\s*\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;\s*:",
)


@dataclass
class EpisodeRewardState:
    """Per-episode bookkeeping that ``RewardEngine.evaluate_action`` mutates."""

    task_id: str  # key into the engine's task registry
    runtime_root: str  # filesystem root handed to the task grader
    known_fact_ids: set[str]  # fact ids that have already been rewarded
    last_health: float  # health reported by the most recent ordinary grading
    done: bool  # once True, further actions yield a zero reward


@dataclass
class RewardComputation:
    """Outcome of scoring a single command."""

    signal: RewardSignal
    state: EpisodeRewardState
    task_state: TaskScenarioState
    catastrophic: bool


class RewardEngine:
    """Scores commands against the graded state of a task's filesystem."""

    def __init__(
        self,
        task_registry: dict[str, TaskScenarioDefinition],
        step_penalty: float = DEFAULT_STEP_PENALTY,
        catastrophic_penalty: float = DEFAULT_CATASTROPHIC_PENALTY,
        destructive_command_patterns: tuple[str, ...] = DEFAULT_DESTRUCTIVE_COMMAND_PATTERNS,
    ) -> None:
        """Store the registry and penalties; patterns are copied into a tuple."""
        self.task_registry = task_registry
        self.step_penalty = step_penalty
        self.catastrophic_penalty = catastrophic_penalty
        self.destructive_command_patterns = tuple(destructive_command_patterns)

    def start_episode(
        self,
        task_id: str,
        runtime_root: str | Path | None = None,
    ) -> EpisodeRewardState:
        """Grade the task once and return the initial episode state.

        A falsy ``runtime_root`` falls back to the task definition's
        ``metadata.base_filesystem_path``. An unknown ``task_id`` raises
        ``KeyError`` from the registry lookup.
        """
        scenario = self.task_registry[task_id]
        root = Path(runtime_root or scenario.metadata.base_filesystem_path)
        initial = self._grade_task(scenario, root)
        return EpisodeRewardState(
            task_id=task_id,
            runtime_root=str(root),
            known_fact_ids=set(),
            last_health=initial.health,
            done=initial.done,
        )

    def evaluate_action(self, state: EpisodeRewardState, command: str) -> RewardComputation:
        """Score ``command`` for the episode described by ``state``.

        ``state`` is updated in place and is the same object returned inside
        the ``RewardComputation``.
        """
        scenario = self.task_registry[state.task_id]
        root = Path(state.runtime_root)

        # The task is graded on every call, including for finished episodes.
        graded = self._grade_task(scenario, root)

        if state.done:
            # Finished episode: nothing is earned and the state is untouched.
            return RewardComputation(
                signal=RewardSignal(
                    health_delta=0.0,
                    knowledge_delta=0.0,
                    action_penalty=0.0,
                    total_reward=0.0,
                ),
                state=state,
                task_state=graded,
                catastrophic=False,
            )

        if self.is_catastrophic_action(command):
            # Destructive command: end the episode with the flat penalty.
            # last_health and known_fact_ids are left as they were.
            state.done = True
            return RewardComputation(
                signal=RewardSignal(
                    health_delta=0.0,
                    knowledge_delta=0.0,
                    action_penalty=self.catastrophic_penalty,
                    total_reward=self.catastrophic_penalty,
                ),
                state=state,
                task_state=graded,
                catastrophic=True,
            )

        # Ordinary step: reward health progress plus newly revealed facts,
        # offset by the per-step penalty.
        knowledge_gain = self._knowledge_delta(scenario, state, command)
        health_gain = graded.health - state.last_health
        reward = health_gain + knowledge_gain + self.step_penalty

        state.last_health = graded.health
        state.done = graded.done

        return RewardComputation(
            signal=RewardSignal(
                health_delta=health_gain,
                knowledge_delta=knowledge_gain,
                action_penalty=self.step_penalty,
                total_reward=reward,
            ),
            state=state,
            task_state=graded,
            catastrophic=False,
        )

    def is_catastrophic_action(self, command: str) -> bool:
        """Return True if any destructive pattern is found in ``command``."""
        for pattern in self.destructive_command_patterns:
            if re.search(pattern, command, flags=re.IGNORECASE):
                return True
        return False

    def _knowledge_delta(
        self,
        definition: TaskScenarioDefinition,
        state: EpisodeRewardState,
        command: str,
    ) -> float:
        """Sum the rewards of triggers that ``command`` newly reveals.

        Each revealed fact id is added to ``state.known_fact_ids`` so that it
        is rewarded at most once per episode.
        """
        task_module = get_task_module(state.task_id)
        gained = 0.0
        for trigger in definition.diagnostic_triggers:
            already_known = trigger.fact_id in state.known_fact_ids
            # Short-circuit keeps the task module from being consulted for
            # facts that were already rewarded.
            if already_known or not task_module.command_reveals_fact(command, trigger):
                continue
            state.known_fact_ids.add(trigger.fact_id)
            gained += trigger.reward
        return gained

    def _grade_task(
        self,
        definition: TaskScenarioDefinition,
        runtime_root: Path,
    ) -> TaskScenarioState:
        """Delegate grading of ``runtime_root`` to the task's own module."""
        grader = get_task_module(definition.metadata.task_id)
        return grader.grade(runtime_root)


def build_reward_engine(
    task_registry: dict[str, TaskScenarioDefinition],
    step_penalty: float = DEFAULT_STEP_PENALTY,
    catastrophic_penalty: float = DEFAULT_CATASTROPHIC_PENALTY,
    destructive_command_patterns: tuple[str, ...] = DEFAULT_DESTRUCTIVE_COMMAND_PATTERNS,
) -> RewardEngine:
    """Construct a ``RewardEngine`` from the given arguments."""
    engine = RewardEngine(
        task_registry=task_registry,
        step_penalty=step_penalty,
        catastrophic_penalty=catastrophic_penalty,
        destructive_command_patterns=destructive_command_patterns,
    )
    return engine


__all__ = [
    "DEFAULT_CATASTROPHIC_PENALTY",
    "DEFAULT_DESTRUCTIVE_COMMAND_PATTERNS",
    "DEFAULT_STEP_PENALTY",
    "EpisodeRewardState",
    "RewardComputation",
    "RewardEngine",
    "build_reward_engine",
]