Spaces:
Paused
Paused
| # Cloud Arena Environment — Mathematical Model RL | |
| # Extracted from cloud_arena_final.py | |
| # This is the MATHEMATICAL model env, NOT the LLM model. | |
| import sys, math, random, copy | |
| from collections import deque | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| import gymnasium as gym | |
| from gymnasium import spaces | |
# ── Seeds ─────────────────────────────────────────────────────────────────────
# Both global generators (stdlib and NumPy) are seeded once, at import time.
GLOBAL_SEED = 42
random.seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)

# ── Observation layout (must sum to OBS_DIM) ──────────────────────────────────
MAX_RES_IN_OBS = 8    # fixed number of resource slots in the obs (unused slots are zero)
N_FEAT_PER_RES = 10   # features per resource slot
N_BLOCK_B = 8         # global security block
N_BLOCK_C = 7         # global cost block
N_BLOCK_D = 6         # environment state block
N_BLOCK_E = 24        # history: 8 actions + 8 rewards + 8 progress
OBS_DIM = sum((
    MAX_RES_IN_OBS * N_FEAT_PER_RES,   # 80
    N_BLOCK_B,                         # 8
    N_BLOCK_C,                         # 7
    N_BLOCK_D,                         # 6
    N_BLOCK_E,                         # 24
))                                     # = 125
assert OBS_DIM == 125, f"OBS_DIM mismatch: {OBS_DIM}"

# ── Action space ──────────────────────────────────────────────────────────────
# A flat action index encodes (action type, resource slot) as
# atype * MAX_RESOURCES + ridx.
N_ACTION_TYPES = 15
MAX_RESOURCES = 10
N_ACTIONS = N_ACTION_TYPES * MAX_RESOURCES  # 150

# Action-type ids 0..14, in this exact order.
(
    A_NOOP,
    A_ANALYZE,
    A_VERIFY_DEPS,
    A_RESIZE_DOWN,
    A_RESIZE_UP,
    A_STOP,
    A_RESTART,
    A_DELETE,
    A_PATCH,
    A_ENCRYPT,
    A_RESTRICT,
    A_ROTATE_CREDS,
    A_ENABLE_LOG,
    A_ARCHIVE,
    A_OPT_NET,
) = range(N_ACTION_TYPES)

# Action cost penalties (small friction — makes actions non-free)
ACTION_COSTS = {
    A_NOOP: 0.0,
    A_ANALYZE: -0.01,
    A_VERIFY_DEPS: -0.01,
    A_RESIZE_DOWN: -0.02,
    A_RESIZE_UP: -0.02,
    A_STOP: -0.03,
    A_RESTART: -0.02,
    A_DELETE: -0.05,
    A_PATCH: -0.02,
    A_ENCRYPT: -0.02,
    A_RESTRICT: -0.02,
    A_ROTATE_CREDS: -0.02,
    A_ENABLE_LOG: -0.01,
    A_ARCHIVE: -0.03,
    A_OPT_NET: -0.02,
}

# ── Curriculum ────────────────────────────────────────────────────────────────
# Every table below is keyed by curriculum phase 0..5.

# Number of resources created per phase.
N_RESOURCES_PHASE = dict(enumerate((4, 5, 6, 7, 8, 10)))

# Phase feature flags.
PHASE_FOG = dict(enumerate((False, True, True, True, True, True)))
PHASE_EVENTS = dict(enumerate((False, False, True, True, True, True)))
PHASE_CHAOS = dict(enumerate((False, False, False, True, True, True)))
CHAOS_INIT_PROB = dict(enumerate((0.0, 0.0, 0.0, 0.20, 0.30, 0.35)))

# Win thresholds: cost must drop below this fraction of the initial cost AND
# the security score (1 - risk) must be >= the security threshold.
WIN_COST_THR = dict(enumerate((0.55, 0.60, 0.60, 0.65, 0.65, 0.70)))
WIN_SEC_THR = dict(enumerate((0.00, 0.60, 0.70, 0.70, 0.75, 0.80)))

# Episode length cap (steps before truncation).
MAX_STEPS = 150
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # RESOURCE OBJECT | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
class ResourceObject:
    """One simulated cloud resource: cost, health, security and fog-of-war state.

    Instances are plain mutable state holders. The environment reads and
    writes their attributes directly and calls the ``do_*`` methods to apply
    agent actions. ``tick`` advances the passive per-step dynamics.
    """

    # Criticality label -> numeric weight stored in ``self.criticality``.
    CRIT = {"LOW": 0.3, "MED": 0.6, "HIGH": 1.0}

    def __init__(self, idx: int, criticality: str = "MED",
                 category: str = "compute", rng: random.Random = None):
        """Create a resource with randomised cost and security state.

        Args:
            idx: Index of this resource.
            criticality: One of the keys of ``CRIT`` (raises KeyError otherwise).
            category: Free-form category label (e.g. "compute", "storage").
            rng: Source of randomness; defaults to ``random.Random(idx)``.
        """
        if not rng:
            rng = random.Random(idx)

        self.idx = idx
        self.criticality = self.CRIT[criticality]
        self.category = category

        # ── Cost state ───────────────────────────────────────────────────────
        # Starts overprovisioned: allocation is drawn well above usage, and
        # usage is capped at least 0.10 below the allocation.
        self.allocated = rng.uniform(0.70, 1.00)
        drawn_usage = rng.uniform(0.15, 0.50)
        self.usage = min(drawn_usage, self.allocated - 0.10)
        self.cost_rate = self.allocated      # cost tracks allocation
        self.activity_status = 1.0           # 1 = active, 0 = idle

        # ── State flags ──────────────────────────────────────────────────────
        self.health = 1
        self.is_stopped = False
        self.is_deleted = False
        self.alert_flag = 0

        # ── Security state (hidden under fog) ────────────────────────────────
        self.risk_score = rng.uniform(0.05, 0.20)
        self.vulnerability = False
        self.encryption = True
        self.over_permission = False
        self.logging_enabled = True
        self.credential_age = rng.uniform(0.0, 0.3)
        self.exposure = rng.uniform(0.0, 0.15)
        self.sensitivity = rng.uniform(0.3, 0.8)

        # ── Fog of war ───────────────────────────────────────────────────────
        self.fog_active = True               # attributes hidden until analysed
        self.cost_known = False
        self.deps_known = False
        self.steps_since_analyze = 0
        self.staleness = 0.0
        self.STALE_STEPS = 15                # ticks after which fog returns

        # ── Dependency ───────────────────────────────────────────────────────
        self.dependency_children: List[int] = []   # indices depending on this one
        self.dependency_parent: Optional[int] = None

        # ── Diagnostics ──────────────────────────────────────────────────────
        self.steps_broken = 0
        self.time_broken = 0.0

    # ── Internal helpers ─────────────────────────────────────────────────────
    def _raise_risk(self, amount: float) -> None:
        """Increase ``risk_score`` by ``amount``, saturating at 1.0."""
        self.risk_score = min(self.risk_score + amount, 1.0)

    def _lower_risk(self, amount: float) -> None:
        """Decrease ``risk_score`` by ``amount``, floored at 0.0."""
        self.risk_score = max(self.risk_score - amount, 0.0)

    # ── Derived properties ───────────────────────────────────────────────────
    def overprovision_ratio(self) -> float:
        """Return the unused fraction of the allocation, never negative."""
        headroom = self.allocated - self.usage
        denominator = max(self.allocated, 1e-6)
        return max(0.0, headroom / denominator)

    def get_cost(self) -> float:
        """Return the current cost: 0 if deleted, 5% of the rate if stopped."""
        if self.is_deleted:
            return 0.0
        # A stopped resource still incurs a minimal maintenance cost.
        return self.cost_rate * 0.05 if self.is_stopped else self.cost_rate

    # ── Observation vector (10 dims) ─────────────────────────────────────────
    def to_obs(self, fog: bool = False) -> np.ndarray:
        """Return the 10-feature float32 observation for this resource.

        When ``fog`` is set and this resource's fog is active, risk and
        exposure read as 0.0 and cost reads as the placeholder 0.5.
        """
        hidden = fog and self.fog_active
        if hidden:
            risk_obs, cost_obs, exp_obs = 0.0, 0.5, 0.0
        else:
            risk_obs, cost_obs, exp_obs = self.risk_score, self.cost_rate, self.exposure

        features = (
            float(self.health),       # 0
            risk_obs,                 # 1 (hidden under fog)
            self.criticality,         # 2
            cost_obs,                 # 3 (hidden under fog)
            self.activity_status,     # 4
            exp_obs,                  # 5 (hidden under fog)
            self.sensitivity,         # 6
            self.staleness,           # 7
            float(self.alert_flag),   # 8
            self.time_broken,         # 9
        )
        return np.array(features, dtype=np.float32)

    # ── Per-step tick ────────────────────────────────────────────────────────
    def tick(self, rng: random.Random, phase: int, event_prob: float = 0.0):
        """Advance passive dynamics by one step (no-op for deleted resources).

        Updates staleness/fog, drifts usage, ages credentials, tracks time
        spent broken, and (phase >= 2) may trigger one random security event.
        """
        if self.is_deleted:
            return

        # Staleness: knowledge expires after STALE_STEPS ticks without analysis.
        self.steps_since_analyze += 1
        self.staleness = min(self.steps_since_analyze / self.STALE_STEPS, 1.0)
        if self.steps_since_analyze >= self.STALE_STEPS:
            self.fog_active = True

        # Usage drifts only while the resource is healthy and running.
        if self.health and not self.is_stopped:
            drifted = self.usage + rng.uniform(-0.03, 0.03)
            self.usage = float(np.clip(drifted, 0.10, self.allocated))

        # Credentials age a little every tick.
        self.credential_age = min(self.credential_age + 0.01, 1.0)

        # Broken resources accumulate downtime and risk.
        if not self.health:
            self.steps_broken += 1
            self.time_broken = min(self.steps_broken / MAX_STEPS, 1.0)
            self._raise_risk(0.015)
            if self.criticality >= 1.0:
                self.alert_flag = 1   # broken high-criticality resource raises an alert

        # Random security events (phase 2+). The probability draw happens
        # before the health check, so it is consumed for broken resources too.
        if phase >= 2 and rng.random() < event_prob and self.health:
            event = rng.choice(["vuln", "expose", "iam", "log_off"])
            risk_bump = {"vuln": 0.25, "expose": 0.20, "iam": 0.15, "log_off": 0.05}[event]
            if event == "vuln":
                self.vulnerability = True
            elif event == "expose":
                self.exposure = min(self.exposure + 0.35, 1.0)
            elif event == "iam":
                self.over_permission = True
            else:  # "log_off"
                self.logging_enabled = False
            self._raise_risk(risk_bump)

    # ── Actions ──────────────────────────────────────────────────────────────
    def do_analyze(self):
        """Lift the fog and reset the staleness clock."""
        self.steps_since_analyze = 0
        self.staleness = 0.0
        self.cost_known = True
        self.fog_active = False

    def do_verify_deps(self):
        """Mark this resource's dependencies as known."""
        self.deps_known = True

    def do_resize_down(self) -> float:
        """Shrink the allocation towards usage; return the saving (0.0 if none).

        The target is ``usage + 0.10`` with a floor of 0.25. It is applied only
        when it undercuts the current allocation by more than 0.02.
        """
        target = max(self.usage + 0.10, 0.25)
        if not target < self.allocated - 0.02:
            return 0.0
        saving = self.allocated - target
        self.allocated = target
        self.cost_rate = target
        return saving

    def do_resize_up(self):
        """Grow the allocation by 0.20, capped at 1.0."""
        grown = min(self.allocated + 0.20, 1.0)
        self.allocated = grown
        self.cost_rate = self.allocated

    def do_stop(self) -> float:
        """Stop the resource; return 95% of the cost rate, or 0.0 if already stopped."""
        if self.is_stopped:
            return 0.0
        self.is_stopped = True
        self.activity_status = 0.0
        return self.cost_rate * 0.95

    def do_restart(self):
        """Bring the resource back up: running, active and healthy."""
        self.health = 1
        self.activity_status = 1.0
        self.is_stopped = False

    def do_delete(self) -> float:
        """Delete the resource; return its cost rate."""
        released = self.cost_rate
        self.is_deleted = True
        self.health = 0
        return released

    def do_patch(self):
        """Clear the vulnerability flag and lower risk by 0.30."""
        self.vulnerability = False
        self._lower_risk(0.30)

    def do_encrypt(self):
        """Enable encryption and lower risk by 0.15."""
        self.encryption = True
        self._lower_risk(0.15)

    def do_restrict(self):
        """Reduce exposure by 0.40 and risk by 0.20."""
        self.exposure = max(self.exposure - 0.40, 0.0)
        self._lower_risk(0.20)

    def do_rotate_creds(self):
        """Reset credential age, clear over-permission, lower risk by 0.10."""
        self.credential_age = 0.0
        self.over_permission = False
        self._lower_risk(0.10)

    def do_enable_logging(self):
        """Turn logging on and lower risk by 0.05."""
        self.logging_enabled = True
        self._lower_risk(0.05)

    def do_archive(self) -> float:
        """Archive (stop) the resource; return 70% of the cost rate, or 0.0 if stopped."""
        if self.is_stopped:
            return 0.0
        self.is_stopped = True
        self.activity_status = 0.0
        return self.cost_rate * 0.70

    def do_opt_network(self):
        """Reduce exposure by 0.15 and risk by 0.08."""
        self.exposure = max(self.exposure - 0.15, 0.0)
        self._lower_risk(0.08)
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # ENVIRONMENT | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
class CloudArenaEnv(gym.Env):
    """
    Cloud Arena: multi-objective cloud operations RL environment.

    Observation: 125-dim flat float32.
    Action space: Discrete(150) = 15 types × 10 resource slots.

    A flat action ``a`` decodes to ``atype = a // MAX_RESOURCES`` and
    ``ridx = a % MAX_RESOURCES``. An episode terminates when the phase's win
    condition is met and is truncated after ``MAX_STEPS`` steps.

    NOTE(review): the observation only carries the first ``MAX_RES_IN_OBS``
    resources, while the action space (and phase 5) addresses up to
    ``MAX_RESOURCES``; resources beyond the observed slots are acted on blind.
    """

    metadata = {"render_modes": []}

    def __init__(self,
                 curriculum_ref: Optional[List[int]] = None,
                 global_step_ref: Optional[List[int]] = None):
        """Create the environment.

        Args:
            curriculum_ref: One-element list whose item 0 is the current
                curriculum phase. It is read on every access, so a caller
                holding the same list can change the phase in place.
                Defaults to ``[0]``.
            global_step_ref: One-element list whose item 0 is incremented on
                every ``step`` and used to seed that step's RNG.
                Defaults to ``[0]``.
        """
        super().__init__()
        self._curriculum_ref = curriculum_ref or [0]
        self._global_step_ref = global_step_ref or [0]
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(OBS_DIM,), dtype=np.float32)
        self.action_space = spaces.Discrete(N_ACTIONS)

        # Episode state (properly initialised in reset)
        self.resources: List[ResourceObject] = []
        self.n_active = 0
        self.step_count = 0
        self.chaos_active = False
        self.chaos_steps = 0
        self.veto_count = 0
        self.cascade_count = 0
        self.initial_total_cost = 1.0
        self.prev_total_cost = 1.0
        self.prev_risk_agg = 0.0
        self._action_hist = deque([0.0] * 8, maxlen=8)
        self._reward_hist = deque([0.0] * 8, maxlen=8)
        self._progress_hist = deque([0.0] * 8, maxlen=8)

    # ── Properties ────────────────────────────────────────────────────────────
    @property
    def curriculum_level(self) -> int:
        """Current curriculum phase, read live from the shared reference.

        This must be a property: every use in this class reads it as an
        attribute (dictionary key, arithmetic operand), never as a call.
        """
        return self._curriculum_ref[0]

    # ── Reset ─────────────────────────────────────────────────────────────────
    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed. When omitted, the episode RNG is seeded with
                ``GLOBAL_SEED`` plus the previous episode's step count.
            options: Optional dict; ``options["scenario"] > 0`` selects a boss
                scenario instead of a normal phase episode.
        """
        super().reset(seed=seed)
        # NOTE: step_count here is still the value left by the previous episode.
        rng = random.Random(seed if seed is not None else GLOBAL_SEED + self.step_count)
        self.step_count = 0
        self.chaos_active = False
        self.chaos_steps = 0
        self.veto_count = 0
        self.cascade_count = 0

        phase = self.curriculum_level
        scenario = options.get("scenario", 0) if options else 0
        if scenario > 0:
            self._setup_boss_scenario(scenario, rng)
        else:
            self._setup_normal_episode(phase, rng)

        self.initial_total_cost = max(sum(r.get_cost() for r in self.resources), 1e-6)
        self.prev_total_cost = self.initial_total_cost
        self.prev_risk_agg = self._risk_aggregate()
        self._action_hist = deque([0.0] * 8, maxlen=8)
        self._reward_hist = deque([0.0] * 8, maxlen=8)
        self._progress_hist = deque([0.0] * 8, maxlen=8)
        return self._build_obs(), {}

    def _setup_normal_episode(self, phase: int, rng: random.Random):
        """Standard episode with phase-appropriate resources."""
        self.n_active = N_RESOURCES_PHASE[phase]
        n = self.n_active

        # Criticality layout: resource 0 is HIGH, indices 1 .. n//2 - 1 are MED,
        # the remainder are LOW.
        crits = []
        for i in range(n):
            if i == 0:
                crits.append("HIGH")
            elif i < n // 2:
                crits.append("MED")
            else:
                crits.append("LOW")

        cats = ["compute", "compute", "storage", "database",
                "compute", "storage", "compute", "database",
                "compute", "storage"][:n]

        self.resources = []
        for i in range(n):
            r = ResourceObject(i, crits[i], cats[i], rng)
            # Phases without fog: full observability — reveal everything upfront
            if not PHASE_FOG[phase]:
                r.fog_active = False
                r.cost_known = True
                r.deps_known = True
            if phase == 0:
                # Phase 0: no security issues to start (clean state)
                r.risk_score = rng.uniform(0.02, 0.08)
                r.vulnerability = False
                r.encryption = True
                r.over_permission = False
                r.logging_enabled = True
                r.exposure = rng.uniform(0.0, 0.05)
            else:
                # Phase 1+: seed real security problems so the agent has to
                # do security work to reach the win threshold.
                r.vulnerability = rng.random() < 0.40
                r.encryption = rng.random() > 0.30       # ~30% unencrypted
                r.over_permission = rng.random() < 0.30
                r.logging_enabled = rng.random() > 0.20
                r.exposure = rng.uniform(0.10, 0.40)
                r.risk_score = rng.uniform(0.30, 0.60)
            self.resources.append(r)

        # Simple dependency: resource 0 (HIGH) has child [1]. Resource 0 can
        # never be deleted (HIGH criticality), so this edge never cascades.
        if n >= 2:
            self.resources[0].dependency_children = [1]
            self.resources[1].dependency_parent = 0

        # Chaos initialisation for Phase 3+
        if PHASE_CHAOS[phase] and rng.random() < CHAOS_INIT_PROB[phase]:
            self.chaos_active = True
            # Break the first two non-HIGH resources
            victims = [r for r in self.resources if r.criticality < 1.0][:2]
            for v in victims:
                v.health = 0
                v.risk_score = min(v.risk_score + 0.40, 1.0)
                v.alert_flag = 0  # no alert: victims are never HIGH criticality

    def _setup_boss_scenario(self, scenario: int, rng: random.Random):
        """Boss fight: a normal episode plus predefined stressful conditions.

        Scenarios: 1 cost crisis, 2 security breach, 3 infrastructure failure,
        4 traffic surge, 5 everything at once. Other positive values yield a
        plain episode at phase >= 3.
        """
        phase = max(self.curriculum_level, 3)  # boss fights at phase 3+ difficulty
        self._setup_normal_episode(phase, rng)

        if scenario == 1:    # Cost Crisis
            for r in self.resources:
                r.allocated = min(r.allocated + rng.uniform(0.10, 0.25), 1.0)
                r.cost_rate = r.allocated
                r.usage = max(r.usage - 0.10, 0.10)
        elif scenario == 2:  # Security Breach
            for r in self.resources:
                r.fog_active = True   # force fog — agent must analyze
                r.cost_known = False
                r.vulnerability = (rng.random() < 0.60)
                r.encryption = (rng.random() < 0.30)
                r.over_permission = (rng.random() < 0.50)
                r.logging_enabled = (rng.random() < 0.40)
                r.exposure = rng.uniform(0.30, 0.80)
                r.risk_score = rng.uniform(0.40, 0.90)
        elif scenario == 3:  # Infrastructure Failure (NOOP Test)
            self.chaos_active = True
            for r in self.resources[:3]:
                r.health = 0
                r.risk_score = min(r.risk_score + 0.50, 1.0)
        elif scenario == 4:  # Traffic Surge (underprovisioned)
            for r in self.resources:
                r.usage = min(r.allocated - 0.05, rng.uniform(0.75, 0.95))
                r.risk_score = min(r.risk_score + 0.10, 0.50)
        elif scenario == 5:  # Final Boss: everything
            self.chaos_active = True
            for i, r in enumerate(self.resources):
                r.allocated = min(r.allocated + 0.15, 1.0)
                r.cost_rate = r.allocated
                r.vulnerability = (rng.random() < 0.50)
                r.encryption = (rng.random() < 0.40)
                r.exposure = rng.uniform(0.20, 0.70)
                r.risk_score = rng.uniform(0.30, 0.80)
                if i < 2:
                    r.health = 0

    # ── Step ──────────────────────────────────────────────────────────────────
    def step(self, action: int):
        """Advance one step.

        Order of operations: tick every resource, roll for a chaos event,
        apply the agent's action, compute the reward, then check win/timeout.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
        """
        action = int(action)
        self.step_count += 1
        self._global_step_ref[0] += 1
        atype = action // MAX_RESOURCES
        ridx = action % MAX_RESOURCES
        phase = self.curriculum_level

        # ── Tick all resources ────────────────────────────────────────────────
        event_prob = 0.04 if PHASE_EVENTS[phase] else 0.0
        # Per-step RNG seeded from the global step counter.
        rng = random.Random(self._global_step_ref[0])
        for r in self.resources:
            r.tick(rng, phase, event_prob)

        # ── Chaos events (Phase 3+) ───────────────────────────────────────────
        if PHASE_CHAOS[phase] and rng.random() < 0.03:
            healthy = [r for r in self.resources if r.health and not r.is_deleted
                       and r.criticality < 1.0]
            if healthy:
                victim = rng.choice(healthy)
                victim.health = 0
                victim.risk_score = min(victim.risk_score + 0.40, 1.0)
                self.chaos_active = True
        if self.chaos_active:
            self.chaos_steps += 1
            if self.chaos_steps > 20:
                # Chaos resolves after ~20 steps. Reset the counter as well,
                # otherwise any later chaos event would be cancelled on the
                # very step it starts.
                self.chaos_active = False
                self.chaos_steps = 0

        # ── Snapshot pre-action state ─────────────────────────────────────────
        cost_before = sum(r.get_cost() for r in self.resources)
        risk_before = self._risk_aggregate()

        # ── Apply action ──────────────────────────────────────────────────────
        # Per-resource deltas are not needed here; the reward uses the global
        # before/after totals instead.
        _, _, veto = self._apply_action(atype, ridx)
        if veto:
            self.veto_count += 1

        # ── Post-action state ─────────────────────────────────────────────────
        cost_now = sum(r.get_cost() for r in self.resources)
        risk_now = self._risk_aggregate()

        # ── Compute reward ────────────────────────────────────────────────────
        reward = self._compute_reward(
            atype, ridx, veto, cost_before, cost_now, risk_before, risk_now)

        # ── Check win/done ────────────────────────────────────────────────────
        win = self._check_win(cost_now, risk_now, phase)
        terminated = win
        truncated = (self.step_count >= MAX_STEPS)
        if terminated or truncated:
            reward += self._terminal_reward(win, cost_now, risk_now, phase)
            reward = float(np.clip(reward, -30.0, 60.0))
        else:
            reward = float(np.clip(reward, -2.0, 5.0))

        # ── Update history ────────────────────────────────────────────────────
        self._action_hist.append(atype / N_ACTION_TYPES)
        self._reward_hist.append(np.clip(reward / 5.0, -1.0, 1.0))
        self._progress_hist.append(max(0.0, (self.initial_total_cost - cost_now)
                                       / max(self.initial_total_cost, 1e-6)))
        self.prev_total_cost = cost_now
        self.prev_risk_agg = risk_now

        info = {
            "win": int(win),
            "cost_score": float(np.clip(1.0 - cost_now / max(self.initial_total_cost, 1e-6), 0, 1)),
            "security_score": float(np.clip(1.0 - risk_now, 0, 1)),
            "reliability_score": self._reliability_score(),
            "savings_pct": float(np.clip(
                (self.initial_total_cost - cost_now)
                / max(self.initial_total_cost, 1e-6) * 100, 0, 100)),
            "veto_rate": self.veto_count / max(self.step_count, 1),
            "cascade_count": self.cascade_count,
            "risk": risk_now,
            "chaos_active": self.chaos_active,
        }
        return self._build_obs(), reward, terminated, truncated, info

    # ── Action application ────────────────────────────────────────────────────
    def _apply_action(self, atype: int, ridx: int) -> Tuple[float, float, bool]:
        """Apply one decoded action to resource ``ridx``.

        An action whose precondition fails is "vetoed": nothing changes and
        the veto flag is returned. Unknown action types fall through without
        effect and without a veto.

        Returns:
            ``(cost_delta, security_delta, was_vetoed)`` for the targeted
            resource, where positive deltas mean cost/risk went down.
        """
        if atype == A_NOOP:
            return 0.0, 0.0, False  # NOOP is never a veto

        # Validate resource index
        if ridx >= len(self.resources):
            return 0.0, 0.0, True
        r = self.resources[ridx]
        if r.is_deleted:
            return 0.0, 0.0, True

        cost_before = r.get_cost()
        risk_before = r.risk_score
        veto = False

        if atype == A_ANALYZE:
            r.do_analyze()
        elif atype == A_VERIFY_DEPS:
            r.do_verify_deps()
        elif atype == A_RESIZE_DOWN:
            if r.overprovision_ratio() > 0.08 and not r.is_stopped:
                r.do_resize_down()
            else:
                veto = True
        elif atype == A_RESIZE_UP:
            if r.usage > r.allocated - 0.12:
                r.do_resize_up()
            else:
                veto = True
        elif atype == A_STOP:
            can_stop = (not r.is_stopped and
                        (r.activity_status < 0.35 or r.criticality <= 0.3) and
                        r.criticality < 1.0)
            if can_stop:
                r.do_stop()
            else:
                veto = True
        elif atype == A_RESTART:
            if r.is_stopped:
                r.do_restart()
            else:
                veto = True
        elif atype == A_DELETE:
            can_delete = (r.deps_known and r.criticality < 1.0 and not r.is_stopped)
            if can_delete:
                has_crit_child = any(
                    (ci < len(self.resources) and
                     not self.resources[ci].is_deleted and
                     self.resources[ci].criticality >= 0.6)
                    for ci in r.dependency_children)
                if has_crit_child:
                    veto = True
                else:
                    r.do_delete()
                    # Cascade: every surviving child breaks.
                    for ci in r.dependency_children:
                        if ci < len(self.resources) and not self.resources[ci].is_deleted:
                            child = self.resources[ci]
                            child.health = 0
                            child.risk_score = min(child.risk_score + 0.3, 1.0)
                            self.cascade_count += 1
            else:
                veto = True
        elif atype == A_PATCH:
            if r.vulnerability:
                r.do_patch()
            else:
                veto = True
        elif atype == A_ENCRYPT:
            if not r.encryption:
                r.do_encrypt()
            else:
                veto = True
        elif atype == A_RESTRICT:
            if r.exposure > 0.15:
                r.do_restrict()
            else:
                veto = True
        elif atype == A_ROTATE_CREDS:
            if r.credential_age > 0.40:
                r.do_rotate_creds()
            else:
                veto = True
        elif atype == A_ENABLE_LOG:
            if not r.logging_enabled:
                r.do_enable_logging()
            else:
                veto = True
        elif atype == A_ARCHIVE:
            # NOTE(review): in this file activity_status only drops below 0.35
            # once a resource is stopped, and do_archive is a no-op on stopped
            # resources — so this branch appears unable to save any cost.
            if r.category == "storage" and r.activity_status < 0.35:
                r.do_archive()
            else:
                veto = True
        elif atype == A_OPT_NET:
            if r.exposure > 0.08:
                r.do_opt_network()
            else:
                veto = True

        cost_after = r.get_cost() if not r.is_deleted else 0.0
        risk_after = r.risk_score if not r.is_deleted else 0.0
        return (cost_before - cost_after), (risk_before - risk_after), veto

    # ── Reward ────────────────────────────────────────────────────────────────
    def _compute_reward(self, atype, ridx, veto,
                        cost_before, cost_now, risk_before, risk_now) -> float:
        """Return the unclipped per-step reward.

        Sum of: dense cost, security and stability penalties; a delta term for
        improvement caused this step; NOOP shaping; the action's fixed cost;
        a veto penalty; and (phase 1+) a neglect penalty for visible high-risk
        resources. ``ridx`` is accepted for interface symmetry but unused.
        """
        phase = self.curriculum_level
        w_cost = 0.25
        w_sec = 0.35 if phase >= 1 else 0.0
        w_stab = 0.25

        # ── 1. Dense cost channel ─────────────────────────────────────────────
        r_cost = -w_cost * (cost_now / max(self.initial_total_cost, 1e-6))

        # ── 2. Dense security channel ─────────────────────────────────────────
        r_sec = -w_sec * risk_now

        # ── 3. Stability/reliability ──────────────────────────────────────────
        n_broken = sum(1 for r in self.resources if not r.health and not r.is_deleted)
        r_stab = -w_stab * (n_broken / max(len(self.resources), 1))

        # ── 4. Delta reward ───────────────────────────────────────────────────
        # Positive when the action improved cost or risk, negative if it
        # worsened them; clipped to [-1, 2].
        cost_improvement = (cost_before - cost_now) / max(self.initial_total_cost, 1e-6)
        risk_improvement = risk_before - risk_now
        r_delta = 3.0 * cost_improvement    # strong signal for cost savings
        r_delta += 4.0 * risk_improvement   # strong signal for security improvements
        r_delta = float(np.clip(r_delta, -1.0, 2.0))

        # ── 5. NOOP shaping ───────────────────────────────────────────────────
        if atype == A_NOOP:
            if self.chaos_active:
                r_noop = +0.10   # correct — don't touch things during chaos
            elif risk_now < 0.10 and cost_now < self.initial_total_cost * 0.60:
                r_noop = +0.05   # correct — system is genuinely healthy
            elif risk_now < 0.25:
                r_noop = +0.01   # acceptable
            elif risk_now < 0.50:
                r_noop = -0.05   # negligence
            else:
                r_noop = -0.15   # gross negligence
        else:
            r_noop = 0.0

        # ── 6. Action cost penalty ────────────────────────────────────────────
        r_action = ACTION_COSTS.get(atype, -0.02)

        # ── 7. Veto penalty ───────────────────────────────────────────────────
        r_veto = -0.10 if veto else 0.0

        # ── 8. Temporal neglect ───────────────────────────────────────────────
        # Phase 1+: penalty for leaving visible (un-fogged) high-risk resources
        # alone; it grows with the time the resource has spent broken.
        r_neglect = 0.0
        if phase >= 1:
            for r in self.resources:
                if (not r.fog_active and not r.is_deleted and
                        r.risk_score > 0.60):
                    neglect_scale = min(r.steps_broken / MAX_STEPS, 1.0)
                    r_neglect -= 0.02 * (1.0 + neglect_scale) * r.criticality
            r_neglect = max(r_neglect, -0.20)

        total = r_cost + r_sec + r_stab + r_delta + r_noop + r_action + r_veto + r_neglect
        return float(total)

    def _terminal_reward(self, win: bool, cost_now: float,
                         risk_now: float, phase: int) -> float:
        """Return the bonus/penalty added on the final step of an episode.

        A win earns 15 plus a speed bonus of up to 10. Otherwise partial
        credit for cost reduction is offset by timeout and residual-risk
        penalties. Cascades are penalised in both cases (capped at 3).
        """
        r = 0.0
        if win:
            speed_bonus = 10.0 * (1.0 - self.step_count / MAX_STEPS)
            r += 15.0 + speed_bonus
        else:
            # Partial credit
            cost_reduction = (self.initial_total_cost - cost_now) / max(self.initial_total_cost, 1e-6)
            r += 3.0 * max(cost_reduction, 0.0)
            r -= 5.0               # timeout penalty
            r -= 10.0 * risk_now   # end-state security penalty
        if self.cascade_count > 0:
            r -= 5.0 * min(self.cascade_count, 3)
        return r

    # ── Win condition ─────────────────────────────────────────────────────────
    def _check_win(self, cost_now: float, risk_now: float, phase: int) -> bool:
        """Return True when cost, security and reliability goals are all met."""
        cost_ratio = cost_now / max(self.initial_total_cost, 1e-6)
        cost_win = cost_ratio < WIN_COST_THR[phase]
        sec_score = 1.0 - risk_now
        sec_win = sec_score >= WIN_SEC_THR[phase]
        # No critical resources broken
        no_crit_broken = not any(
            r.criticality >= 1.0 and not r.health and not r.is_deleted
            for r in self.resources)
        return cost_win and sec_win and no_crit_broken

    # ── Observation ───────────────────────────────────────────────────────────
    def _build_obs(self) -> np.ndarray:
        """Assemble the flat observation: blocks A (resources) through E (history)."""
        phase = self.curriculum_level
        fog = PHASE_FOG[phase]

        # Block A: resource observations (padded to MAX_RES_IN_OBS)
        block_a = np.zeros(MAX_RES_IN_OBS * N_FEAT_PER_RES, dtype=np.float32)
        for i, r in enumerate(self.resources[:MAX_RES_IN_OBS]):
            block_a[i * N_FEAT_PER_RES: (i + 1) * N_FEAT_PER_RES] = r.to_obs(fog)

        # Block B: global security (8 dims)
        active = [r for r in self.resources if not r.is_deleted]
        n_a = max(len(active), 1)
        risk_agg = self._risk_aggregate()
        n_vuln = sum(1 for r in active if r.vulnerability)
        n_exposed = sum(1 for r in active if r.exposure > 0.3)
        n_unenc = sum(1 for r in active if not r.encryption)
        n_no_log = sum(1 for r in active if not r.logging_enabled)
        n_overperm = sum(1 for r in active if r.over_permission)
        block_b = np.array([
            risk_agg,
            n_vuln / n_a,
            n_exposed / n_a,
            n_unenc / n_a,
            n_no_log / n_a,
            n_overperm / n_a,
            min(sum(r.credential_age for r in active) / n_a, 1.0),
            float(self.chaos_active),
        ], dtype=np.float32)

        # Block C: global cost (7 dims)
        total_cost = sum(r.get_cost() for r in self.resources)
        n_idle = sum(1 for r in active if r.activity_status < 0.3)
        n_overprov = sum(1 for r in active if r.overprovision_ratio() > 0.2)
        n_stopped = sum(1 for r in self.resources if r.is_stopped)
        n_deleted = sum(1 for r in self.resources if r.is_deleted)
        block_c = np.array([
            total_cost / max(self.initial_total_cost, 1e-6),
            n_idle / n_a,
            n_overprov / n_a,
            n_stopped / max(len(self.resources), 1),
            n_deleted / max(len(self.resources), 1),
            (self.initial_total_cost - total_cost) / max(self.initial_total_cost, 1e-6),
            float(self._check_win(total_cost, risk_agg, self.curriculum_level)),
        ], dtype=np.float32)

        # Block D: environment state (6 dims)
        n_broken = sum(1 for r in active if not r.health)
        block_d = np.array([
            self.step_count / MAX_STEPS,
            self.curriculum_level / 5.0,
            float(self.chaos_active),
            n_broken / n_a,
            self.veto_count / max(self.step_count, 1),
            self.cascade_count / max(n_a, 1),
        ], dtype=np.float32)

        # Block E: history (24 dims)
        block_e = np.array(
            list(self._action_hist) +
            list(self._reward_hist) +
            list(self._progress_hist),
            dtype=np.float32)

        obs = np.concatenate([block_a, block_b, block_c, block_d, block_e])
        assert obs.shape == (OBS_DIM,), f"Obs shape {obs.shape} != {OBS_DIM}"
        return obs

    # ── Action masks ──────────────────────────────────────────────────────────
    def action_masks(self) -> np.ndarray:
        """Return a boolean mask over the flat action space (True = allowed).

        Each condition mirrors a precondition in ``_apply_action`` so that a
        mask-valid action is not vetoed. While a resource's fog is active,
        only ANALYZE and VERIFY_DEPS are offered for it.
        """
        mask = np.zeros(N_ACTIONS, dtype=bool)
        # NOOP (action 0) — always valid
        mask[A_NOOP * MAX_RESOURCES] = True

        for ridx in range(MAX_RESOURCES):
            # Slots beyond the existing resources stay fully masked
            if ridx >= len(self.resources):
                continue
            r = self.resources[ridx]
            if r.is_deleted:
                continue

            def aid(atype: int, slot: int = ridx) -> int:
                """Flat action index of ``atype`` on the current slot."""
                return atype * MAX_RESOURCES + slot

            # ANALYZE / VERIFY_DEPS — always valid (each costs a small amount)
            mask[aid(A_ANALYZE)] = True
            mask[aid(A_VERIFY_DEPS)] = True

            # Under fog nothing else may be executed on this resource; the
            # remaining entries stay False.
            if r.fog_active:
                continue

            # --- ONLY EVALUATED IF FOG IS LIFTED ---
            # RESIZE_DOWN — valid if overprovisioned and running
            mask[aid(A_RESIZE_DOWN)] = (r.overprovision_ratio() > 0.08
                                        and not r.is_stopped)
            # RESIZE_UP — valid if near capacity and running (stricter than
            # _apply_action, which does not check is_stopped)
            mask[aid(A_RESIZE_UP)] = (r.usage > r.allocated - 0.12
                                      and not r.is_stopped)
            # STOP — valid if idle or LOW criticality and currently running
            mask[aid(A_STOP)] = (not r.is_stopped
                                 and r.criticality < 1.0
                                 and (r.activity_status < 0.35 or r.criticality <= 0.3))
            # RESTART — valid if stopped
            mask[aid(A_RESTART)] = r.is_stopped
            # DELETE — valid if deps known, not critical, running, and no
            # critical children. The running check matches _apply_action,
            # which vetoes deletion of a stopped resource.
            has_crit_child = any(
                (ci < len(self.resources) and
                 not self.resources[ci].is_deleted and
                 self.resources[ci].criticality >= 0.6)
                for ci in r.dependency_children)
            mask[aid(A_DELETE)] = (r.deps_known and r.criticality < 1.0
                                   and not r.is_stopped
                                   and not has_crit_child)
            # Security fixes
            mask[aid(A_PATCH)] = r.vulnerability
            mask[aid(A_ENCRYPT)] = not r.encryption
            mask[aid(A_RESTRICT)] = r.exposure > 0.15
            mask[aid(A_ROTATE_CREDS)] = r.credential_age > 0.40
            mask[aid(A_ENABLE_LOG)] = not r.logging_enabled
            mask[aid(A_ARCHIVE)] = (r.category == "storage"
                                    and r.activity_status < 0.35)
            mask[aid(A_OPT_NET)] = r.exposure > 0.08

        # Collapse guard: try to keep at least 3 valid actions
        if mask.sum() < 3:
            mask[A_NOOP * MAX_RESOURCES] = True
            if len(self.resources) > 0:
                mask[A_ANALYZE * MAX_RESOURCES] = True
            if len(self.resources) > 1:
                mask[A_ANALYZE * MAX_RESOURCES + 1] = True
        return mask

    # ── Helpers ───────────────────────────────────────────────────────────────
    def _risk_aggregate(self) -> float:
        """Criticality-weighted mean risk over non-deleted resources (0.0 if none)."""
        active = [r for r in self.resources if not r.is_deleted]
        if not active:
            return 0.0
        weighted = sum(r.risk_score * r.criticality for r in active)
        total_w = sum(r.criticality for r in active)
        return weighted / max(total_w, 1e-6)

    def _reliability_score(self) -> float:
        """1 minus the criticality-weighted share of broken resources (0.0 if none left)."""
        active = [r for r in self.resources if not r.is_deleted]
        if not active:
            return 0.0
        broken_w = sum(r.criticality for r in active if not r.health)
        total_w = sum(r.criticality for r in active)
        return max(0.0, 1.0 - broken_w / max(total_w, 1e-6))

    def render(self):
        """Rendering is not supported; this is a no-op."""
        pass
| # ── Gymnasium wrapper ───────────────────────────────────────────────────────── | |
| from sb3_contrib.common.wrappers import ActionMasker | |
def get_action_masks(env) -> np.ndarray:
    """Return the action mask of the innermost environment.

    Follows the chain of ``.env`` attributes until an object without one is
    reached, then calls that object's ``action_masks()``.
    """
    missing = object()  # sentinel: distinguishes "no .env" from any real value
    current = env
    nested = getattr(current, "env", missing)
    while nested is not missing:
        current = nested
        nested = getattr(current, "env", missing)
    return current.action_masks()