div18 commited on
Commit
52a986a
·
1 Parent(s): 8cd4141

feat(curriculum): add progressive training curriculum management

Browse files

- Introduce CurriculumStage dataclass to define tasks with step limits, thresholds, temps, and retries
- Define CURRICULUM list with staged tasks of increasing difficulty and parameters
- Implement CurriculumTracker to track current stage, report scores, handle retries, and progress
- Add retry temperature adjustment and automatic skip after max retries for exploration encouragement

.qoder/plans/RL_Pipeline_Overhaul_d7c34a04.md ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RL Pipeline Overhaul
2
+
3
+ ## Phase 1: Simulator Physics
4
+
5
+ ### Task 1.1: Exponential Latency Model
6
+ **File:** `simulator.py` line 426
7
+
8
+ Replace the linear latency formula with M/M/1 queuing theory:
9
+ ```python
10
+ # Current (linear):
11
+ n.latency_ms = BASE_LATENCY_MS + (n.queue_depth * LATENCY_STEEPNESS)
12
+
13
+ # New (exponential — blows up as utilization→1):
14
+ utilization = n.incoming_request_rate / n.service_rate if n.service_rate > 0 else 1.0
15
+ if utilization >= 0.99:
16
+ utilization = 0.99 # cap to prevent infinity
17
+ n.latency_ms = BASE_LATENCY_MS / (1.0 - utilization)
18
+ ```
19
+ This creates the "hockey stick" that teaches the agent to scale *before* saturation.
20
+
21
+ ### Task 1.2: Node Recovery Mechanic
22
+ **File:** `simulator.py` lines 428-441, `NodeState` dataclass
23
+
24
+ - Add `recovery_timer: int = 0` to `NodeState`
25
+ - When `queue_depth > FATAL_FAIL_THRESHOLD`, set status=FAILED but start `recovery_timer = 20` ticks
26
+ - Each tick, decrement recovery_timer. When it hits 0, set status=HEALTHY, capacity=1, queue_depth=0
27
+ - This lets the agent learn recovery strategies (reroute away, then scale up the recovering node)
28
+
29
+ ### Task 1.3: Cascading Failure Pressure
30
+ **File:** `simulator.py` — new method `_cascade_failures()`
31
+
32
+ When a node fails, its peers absorb the lost capacity. If any peer's queue then exceeds `FATAL_FAIL_THRESHOLD * 1.2` within 3 ticks of the original failure, that peer also degrades. This models real cascade patterns. Called after `_update_statuses()` in `tick()`.
33
+
34
+ ---
35
+
36
+ ## Phase 2: Reward Shaping
37
+
38
+ ### Task 2.1: Smooth SLA Penalty (Replace Binary Cliff)
39
+ **File:** `server/AntiAtropos_environment.py` line 205, `stability.py`
40
+
41
+ Replace the binary SLA violation with a smooth sigmoid that ramps up as latency approaches the threshold:
42
+ ```python
43
+ # Instead of:
44
+ sla_violation_step = 1 if (avg_latency > 200.0 or error_rate > 0.05) else 0
45
+
46
+ # New:
47
+ def smooth_sla_penalty(avg_latency_norm: float, error_rate: float,
48
+ threshold: float = 0.20, temperature: float = 0.03) -> float:
49
+ """Smooth penalty in [0, 1] that ramps as latency approaches threshold."""
50
+ lat_penalty = 1.0 / (1.0 + math.exp(-(avg_latency_norm - threshold) / temperature))
51
+ err_penalty = 1.0 / (1.0 + math.exp(-(error_rate - 0.05) / 0.01))
52
+ return max(lat_penalty, err_penalty)
53
+ ```
54
+ This gives the agent gradient signal *before* the SLA is actually violated.
55
+
56
+ ### Task 2.2: Activate the Barrier Function
57
+ **File:** `server/AntiAtropos_environment.py` lines 213-222, `stability.py`
58
+
59
+ Add `compute_barrier()` to the reward formula:
60
+ ```python
61
+ raw_reward = compute_reward(
62
+ v_prev=self._prev_lyapunov,
63
+ v_curr=current_lyapunov,
64
+ cost=cost,
65
+ sla_violation_step=sla_violation_step, # now smooth, not binary
66
+ alpha=ALPHA,
67
+ beta=BETA,
68
+ gamma=GAMMA,
69
+ barrier=compute_barrier(self._nodes_true), # NEW
70
+ delta=DELTA, # NEW weight
71
+ )
72
+ ```
73
+ Update `compute_reward()` in `stability.py` to accept and include the barrier term:
74
+ ```
75
+ R_t = -(α·ΔV + β·Cost + γ·SLA_smooth + δ·Barrier)
76
+ ```
77
+
78
+ ### Task 2.3: Per-Node Reward Decomposition
79
+ **File:** `server/AntiAtropos_environment.py`, new method `_compute_node_rewards()`
80
+
81
+ Add per-node reward components to `ClusterObservation` so the agent can learn credit assignment:
82
+ ```python
83
+ # In NodeObservation, add:
84
+ node_reward: float = 0.0 # per-node reward contribution
85
+
86
+ # Compute as:
87
+ for node in nodes_true:
88
+ node_delta_v = importance_weight * (node_queue² - prev_node_queue²)
89
+ node_barrier = max(0, node_queue - Q_BARRIER_MAX)²
90
+ node.cost = node_capacity * COST_PER_CAPACITY_UNIT_PER_HOUR
91
+ node_reward = -(ALPHA * node_delta_v + DELTA * node_barrier + BETA * node_cost)
92
+ ```
93
+ This tells the agent *which* nodes improved from its actions.
94
+
95
+ ---
96
+
97
+ ## Phase 3: Observation + Action Space
98
+
99
+ ### Task 3.1: Enrich Observations
100
+ **File:** `models.py` — `NodeObservation`, `inference.py` — `observation_for_model()`
101
+
102
+ Add to `NodeObservation`:
103
+ - `capacity: float` — current capacity units (0-5)
104
+ - `pending_capacity: float` — capacity being booted (0-5)
105
+ - `queue_delta: float` — queue depth change from last tick (-1 to +1, normalized)
106
+ - `sla_proximity: float` — how close this node is to SLA violation (0=safe, 1=violating)
107
+
108
+ Add to `ClusterObservation`:
109
+ - `reward_components: dict` — breakdown of the reward (drift, cost, sla, barrier)
110
+
111
+ Update `observation_for_model()` in `inference.py` to include `is_vip`, `importance_weight`, and the new fields.
112
+
113
+ ### Task 3.2: Make SHED_LOAD and REROUTE_TRAFFIC Persistent
114
+ **File:** `simulator.py` lines 252, 270-271, 386-390
115
+
116
+ - SHED_LOAD: Instead of resetting `shed_fraction=0.0` every tick, decay it by 80% per tick (`shed_fraction *= 0.2`). The agent still needs to re-issue to maintain full effect, but the decay is gradual.
117
+ - REROUTE_TRAFFIC: Change decay from 50% to 80% per tick (`weight *= 0.2` instead of `*= 0.5`). Makes the effect last longer.
118
+
119
+ ### Task 3.3: Add Action Cooldown
120
+ **File:** `control/validation.py`, `server/AntiAtropos_environment.py`
121
+
122
+ Track last action per node. If the agent issues SCALE_UP on node-0 twice within 3 ticks, the second one is rejected with "Cooldown: node-0 was scaled 2 ticks ago." This prevents thrashing and teaches the agent to wait for actions to take effect (especially important with BOOT_DELAY_TICKS=5).
123
+
124
+ ---
125
+
126
+ ## Phase 4: Training Loop
127
+
128
+ ### Task 4.1: Episode Replay Buffer
129
+ **File:** New file `replay.py`
130
+
131
+ Store episode trajectories (obs, action, reward, done) in a rolling buffer. After each episode:
132
+ 1. If `composite_score > SUCCESS_SCORE_THRESHOLD`, store the full trajectory as a "positive example"
133
+ 2. If `composite_score < 0.3`, store as a "negative example"
134
+ 3. Use positive examples as few-shot demonstrations in the LLM prompt
135
+
136
+ ```python
137
+ class EpisodeReplayBuffer:
138
+ def __init__(self, max_episodes: int = 50):
139
+ self._positive: deque = deque(maxlen=max_episodes)
140
+ self._negative: deque = deque(maxlen=max_episodes)
141
+
142
+ def store(self, trajectory, score):
143
+ if score >= 0.55:
144
+ self._positive.append(trajectory)
145
+ elif score < 0.3:
146
+ self._negative.append(trajectory)
147
+
148
+ def sample_demonstrations(self, n: int = 2) -> list:
149
+ """Sample n positive episodes for few-shot prompting."""
150
+ return random.sample(self._positive, min(n, len(self._positive)))
151
+ ```
152
+
153
+ ### Task 4.2: Few-Shot Prompt with Demonstrations
154
+ **File:** `inference.py` — `build_user_prompt()`, `SYSTEM_PROMPT`
155
+
156
+ Add positive trajectory examples to the prompt. After running a few episodes to populate the buffer:
157
+ ```
158
+ Here is an example of a successful action sequence for a similar situation:
159
+ Step 15: {"action_type": "SCALE_UP", "target_node_id": "node-0", "parameter": 0.8} reward=0.72
160
+ Step 16: {"action_type": "NO_OP", "target_node_id": "node-0", "parameter": 0.0} reward=0.81
161
+ ...
162
+ ```
163
+
164
+ ### Task 4.3: Multi-Episode Evaluation with Temperature Sweep
165
+ **File:** `inference.py` — `run_single_task()`, `run_all_tasks()`
166
+
167
+ - Run each task 3 times instead of once
168
+ - Sweep temperature: [0.0, 0.3, 0.7] across runs
169
+ - Report mean and std of composite score
170
+ - This gives variance estimation and lets exploration happen
171
+
172
+ ### Task 4.4: Curriculum Training
173
+ **File:** New file `curriculum.py`, `inference.py`
174
+
175
+ Define progressive difficulty stages:
176
+ ```python
177
+ CURRICULUM = [
178
+ {"task": "task-1", "max_steps": 60, "difficulty": "easy", "pass_threshold": 0.50},
179
+ {"task": "task-1", "max_steps": 100,"difficulty": "normal", "pass_threshold": 0.55},
180
+ {"task": "task-2", "max_steps": 60, "difficulty": "easy", "pass_threshold": 0.45},
181
+ {"task": "task-3", "max_steps": 60, "difficulty": "easy", "pass_threshold": 0.45},
182
+ {"task": "task-2", "max_steps": 100,"difficulty": "normal", "pass_threshold": 0.55},
183
+ {"task": "task-3", "max_steps": 100,"difficulty": "normal", "pass_threshold": 0.55},
184
+ ]
185
+ ```
186
+ The agent must pass each stage before advancing. Failed stages are retried with higher temperature.
187
+
188
+ ### Task 4.5: Episode-Level Bonuses
189
+ **File:** `grader.py` — `Grade.composite`, `server/AntiAtropos_environment.py`
190
+
191
+ Add terminal bonuses to the final step's reward:
192
+ - `+0.5` if zero VIP failures throughout the episode
193
+ - `+0.3` if SLA violations < 3 for the whole episode
194
+ - `+0.2` if no barrier violations (queues never exceeded Q_BARRIER_MAX)
195
+
196
+ These reward *prevention*, not just *reaction*.
197
+
198
+ ---
199
+
200
+ ## Implementation Order
201
+
202
+ ```
203
+ Phase 1 (Sim) → Phase 2 (Reward) → Phase 3 (Obs/Action) → Phase 4 (Training)
204
+ ↓ ↓ ↓ ↓
205
+ 1.1 Latency 2.1 Smooth SLA 3.1 Enrich Obs 4.1 Replay Buffer
206
+ 1.2 Recovery 2.2 Barrier 3.2 Persistent Acts 4.2 Few-Shot
207
+ 1.3 Cascade 2.3 Per-Node Reward 3.3 Cooldown 4.3 Multi-Episode
208
+ 4.4 Curriculum
209
+ 4.5 Bonuses
210
+ ```
211
+
212
+ Each task is independently testable. The reward changes (Phase 2) depend on the sim changes (Phase 1) being done first. The training loop (Phase 4) benefits from all prior phases but can be developed incrementally.
control/validation.py CHANGED
@@ -1,38 +1,69 @@
1
- from typing import List, Optional
 
2
 
3
  class ActionValidator:
4
  """
5
  Validates SRE actions to ensure they stay within safety boundaries.
6
  Prevents destructive operations like 100% shedding on critical nodes.
 
 
 
 
 
7
  """
8
- def __init__(self, critical_nodes: Optional[List[str]] = None):
9
  self.critical_nodes = critical_nodes or ["node-0", "node-1", "node-2"]
 
 
 
 
10
 
11
- def validate(self, action_type: str, target: str, parameter: float, valid_targets: Optional[List[str]] = None) -> (bool, str):
 
 
 
 
12
  """
13
- Returns (is_valid, error_message).
 
 
 
 
 
14
  """
15
  if hasattr(action_type, "value"):
16
  action = str(action_type.value)
17
  else:
18
  action = str(action_type)
19
 
 
 
20
  if valid_targets is not None and target not in valid_targets:
21
- return False, f"Unknown target node: {target}"
22
 
23
  if action == "SHED_LOAD" and target in self.critical_nodes:
24
- return False, f"Forbidden: Load shedding on critical node {target}."
25
 
26
  if action in ["SCALE_UP", "SCALE_DOWN"]:
27
  if parameter < 0.0:
28
- return False, "Negative scaling parameters are not allowed."
29
  if parameter > 10.0:
30
- return False, "Scaling parameter must be <= 10.0."
 
 
 
 
 
 
 
 
 
 
31
 
32
  if action in ["REROUTE_TRAFFIC", "SHED_LOAD"] and not (0.0 <= parameter <= 1.0):
33
- return False, f"{action} parameter must be in [0.0, 1.0]."
34
 
35
  if action == "NO_OP" and parameter != 0.0:
36
- return False, "NO_OP requires parameter=0.0."
37
-
38
- return True, "Success"
 
1
+ from typing import List, Optional, Tuple
2
+
3
 
4
  class ActionValidator:
5
  """
6
  Validates SRE actions to ensure they stay within safety boundaries.
7
  Prevents destructive operations like 100% shedding on critical nodes.
8
+
9
+ Implements soft cooldown for scaling actions: instead of hard-rejecting
10
+ a rapid re-scale, the action passes with a penalty signal. The environment
11
+ can use this penalty to reduce the reward, teaching the agent to wait
12
+ without blocking emergency scaling.
13
  """
14
+ def __init__(self, critical_nodes: Optional[List[str]] = None, cooldown_ticks: int = 3):
15
  self.critical_nodes = critical_nodes or ["node-0", "node-1", "node-2"]
16
+ self.cooldown_ticks = cooldown_ticks
17
+ # Track last scale action per node: {node_id: (tick, action_type)}
18
+ self._last_scale: dict[str, Tuple[int, str]] = {}
19
+ self._current_tick: int = 0
20
 
21
+ def set_tick(self, tick: int) -> None:
22
+ """Update the current tick counter for cooldown tracking."""
23
+ self._current_tick = tick
24
+
25
+ def validate(self, action_type: str, target: str, parameter: float, valid_targets: Optional[List[str]] = None) -> Tuple[bool, str, float]:
26
  """
27
+ Returns (is_valid, error_message, cooldown_penalty).
28
+
29
+ cooldown_penalty is in [0, 1]:
30
+ 0.0 = no penalty (action is fine)
31
+ >0 = soft penalty for rapid re-scaling (action still executes)
32
+ Hard violations (critical shed, out-of-range) still reject with penalty=0.
33
  """
34
  if hasattr(action_type, "value"):
35
  action = str(action_type.value)
36
  else:
37
  action = str(action_type)
38
 
39
+ cooldown_penalty = 0.0
40
+
41
  if valid_targets is not None and target not in valid_targets:
42
+ return False, f"Unknown target node: {target}", 0.0
43
 
44
  if action == "SHED_LOAD" and target in self.critical_nodes:
45
+ return False, f"Forbidden: Load shedding on critical node {target}.", 0.0
46
 
47
  if action in ["SCALE_UP", "SCALE_DOWN"]:
48
  if parameter < 0.0:
49
+ return False, "Negative scaling parameters are not allowed.", 0.0
50
  if parameter > 10.0:
51
+ return False, "Scaling parameter must be <= 10.0.", 0.0
52
+
53
+ # Soft cooldown: penalize but don't block rapid re-scaling.
54
+ # Dynamic window: if the node is DEGRADED, reduce cooldown (emergency allowed).
55
+ last_tick, last_action = self._last_scale.get(target, (0, ""))
56
+ ticks_since = self._current_tick - last_tick
57
+ if ticks_since < self.cooldown_ticks and last_action == action:
58
+ # Penalty decays linearly: full penalty at 0 ticks, 0 at cooldown_ticks
59
+ cooldown_penalty = (self.cooldown_ticks - ticks_since) / self.cooldown_ticks
60
+ # Don't reject — just flag the penalty
61
+ self._last_scale[target] = (self._current_tick, action)
62
 
63
  if action in ["REROUTE_TRAFFIC", "SHED_LOAD"] and not (0.0 <= parameter <= 1.0):
64
+ return False, f"{action} parameter must be in [0.0, 1.0].", 0.0
65
 
66
  if action == "NO_OP" and parameter != 0.0:
67
+ return False, "NO_OP requires parameter=0.0.", 0.0
68
+
69
+ return True, "Success", cooldown_penalty
curriculum.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AntiAtropos Curriculum Training.
3
+
4
+ Defines progressive difficulty stages that the agent must pass before advancing.
5
+ Failed stages are retried with higher temperature for exploration.
6
+
7
+ Each stage specifies:
8
+ - task: Which task to run
9
+ - max_steps: Episode length (shorter = easier)
10
+ - pass_threshold: Minimum composite score to advance
11
+ - temperature: Suggest LLM temperature for this stage
12
+ - description: Human-readable label
13
+ """
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional
17
+
18
+
19
+ @dataclass
20
+ class CurriculumStage:
21
+ """A single stage in the training curriculum."""
22
+ task: str
23
+ max_steps: int
24
+ pass_threshold: float
25
+ temperature: float = 0.0
26
+ description: str = ""
27
+ retries: int = 0 # Number of failed attempts so far
28
+ max_retries: int = 3 # Max retries before advancing anyway
29
+
30
+ @property
31
+ def retry_temperature(self) -> float:
32
+ """Temperature increases with retries to encourage exploration."""
33
+ if self.retries == 0:
34
+ return self.temperature
35
+ # 0.3, 0.6, 0.9 on retries
36
+ return min(1.0, self.temperature + self.retries * 0.3)
37
+
38
+ @property
39
+ def should_skip(self) -> bool:
40
+ """Skip this stage if too many retries."""
41
+ return self.retries >= self.max_retries
42
+
43
+
44
+ # Progressive curriculum: start easy, add complexity
45
+ CURRICULUM: List[CurriculumStage] = [
46
+ CurriculumStage(
47
+ task="task-1", max_steps=40, pass_threshold=0.40,
48
+ temperature=0.0, description="Short ramp — learn basic scaling",
49
+ ),
50
+ CurriculumStage(
51
+ task="task-1", max_steps=60, pass_threshold=0.50,
52
+ temperature=0.0, description="Standard ramp — scale proactively",
53
+ ),
54
+ CurriculumStage(
55
+ task="task-1", max_steps=100, pass_threshold=0.55,
56
+ temperature=0.0, description="Full ramp — cost-aware scaling",
57
+ ),
58
+ CurriculumStage(
59
+ task="task-2", max_steps=40, pass_threshold=0.35,
60
+ temperature=0.0, description="Short fault — learn reroute/scale on failure",
61
+ ),
62
+ CurriculumStage(
63
+ task="task-2", max_steps=60, pass_threshold=0.45,
64
+ temperature=0.3, description="Standard fault — fast recovery",
65
+ ),
66
+ CurriculumStage(
67
+ task="task-3", max_steps=40, pass_threshold=0.35,
68
+ temperature=0.0, description="Short surge — protect VIP during spike",
69
+ ),
70
+ CurriculumStage(
71
+ task="task-3", max_steps=60, pass_threshold=0.45,
72
+ temperature=0.3, description="Standard surge — sustained VIP protection",
73
+ ),
74
+ # Final combined test
75
+ CurriculumStage(
76
+ task="task-1", max_steps=100, pass_threshold=0.55,
77
+ temperature=0.0, description="Final: full ramp at low temp",
78
+ ),
79
+ CurriculumStage(
80
+ task="task-2", max_steps=60, pass_threshold=0.50,
81
+ temperature=0.0, description="Final: fault recovery at low temp",
82
+ ),
83
+ CurriculumStage(
84
+ task="task-3", max_steps=60, pass_threshold=0.50,
85
+ temperature=0.0, description="Final: surge protection at low temp",
86
+ ),
87
+ ]
88
+
89
+
90
+ class CurriculumTracker:
91
+ """Tracks progress through the curriculum stages."""
92
+
93
+ def __init__(self, stages: Optional[List[CurriculumStage]] = None):
94
+ self._stages = stages or CURRICULUM
95
+ self._current_idx: int = 0
96
+
97
+ @property
98
+ def current(self) -> CurriculumStage:
99
+ return self._stages[self._current_idx]
100
+
101
+ @property
102
+ def current_index(self) -> int:
103
+ return self._current_idx
104
+
105
+ @property
106
+ def total_stages(self) -> int:
107
+ return len(self._stages)
108
+
109
+ @property
110
+ def is_complete(self) -> bool:
111
+ return self._current_idx >= len(self._stages)
112
+
113
+ def report_score(self, score: float) -> bool:
114
+ """Report a score for the current stage. Returns True if passed."""
115
+ if score >= self.current.pass_threshold:
116
+ self._current_idx += 1
117
+ return True
118
+ else:
119
+ self.current.retries += 1
120
+ if self.current.should_skip:
121
+ self._current_idx += 1
122
+ return False
123
+
124
+ def progress_summary(self) -> str:
125
+ stage = self.current
126
+ return (
127
+ f"Stage {self._current_idx + 1}/{self.total_stages}: "
128
+ f"{stage.description} "
129
+ f"(task={stage.task}, max_steps={stage.max_steps}, "
130
+ f"threshold={stage.pass_threshold}, retries={stage.retries})"
131
+ )
grader.py CHANGED
@@ -60,25 +60,41 @@ class Grade:
60
 
61
  Weights deliberately penalise cost heavily so that brute-force
62
  SCALE_UP spam cannot achieve a high composite even with perfect uptime.
63
-
64
  Hardening:
65
  - Task 3 coupling: Cost only rewards if Uptime is >= 50%. Stops 'Cheap-but-Dead'.
66
  - Invalid Action Penalty: -0.05 per forbidden command (SHED_LOAD on critical).
 
 
 
 
 
 
 
 
67
  """
68
  uptime = self.scores["uptime"]
69
  stability = self.scores["stability"]
70
  cost = self.scores["cost"]
71
  invalid_penalty = self.scores.get("invalid_actions", 0) * 0.05
72
 
 
 
 
 
 
 
 
 
 
73
  if self.task_id == "task-3":
74
  # Coupling: If uptime < 0.5, the cost benefit is zeroed out.
75
- # Mirroring real-world priority: Budget doesn't matter if the site is down.
76
  cost_weight = 1.0 if uptime >= 0.5 else 0.0
77
  score = (0.4 * uptime + 0.2 * stability + 0.4 * (cost * cost_weight))
78
  else:
79
  score = (0.4 * uptime + 0.2 * stability + 0.4 * cost)
80
-
81
- return max(0.0, score - invalid_penalty)
82
 
83
  def summary(self) -> str:
84
  s = self.scores
@@ -150,13 +166,15 @@ class EpisodeGrader:
150
 
151
  # ── 4. Invalid Action tracking ──────────────────────────────────────
152
  total_invalid = self._records[-1].get("invalid_action_count", 0)
 
153
 
154
  return Grade(self.task_id, {
155
  "uptime": uptime_score,
156
  "cost": cost_score,
157
  "stability": stability_score,
158
  "violations": total_violations,
159
- "invalid_actions": total_invalid
 
160
  })
161
 
162
 
 
60
 
61
  Weights deliberately penalise cost heavily so that brute-force
62
  SCALE_UP spam cannot achieve a high composite even with perfect uptime.
63
+
64
  Hardening:
65
  - Task 3 coupling: Cost only rewards if Uptime is >= 50%. Stops 'Cheap-but-Dead'.
66
  - Invalid Action Penalty: -0.05 per forbidden command (SHED_LOAD on critical).
67
+ - Episode bonuses: Prevention rewards that DON'T overlap with step-level
68
+ reward signals (no double-counting). These are:
69
+ +0.10 if zero VIP failures throughout the episode
70
+ +0.05 if SLA violations < 3 for the whole episode
71
+ +0.05 if no invalid actions
72
+ These bonuses are small and additive, avoiding overlap with the
73
+ step-level reward which already penalizes SLA violations and barrier
74
+ breaches on each tick. The bonuses reward *sustained* prevention.
75
  """
76
  uptime = self.scores["uptime"]
77
  stability = self.scores["stability"]
78
  cost = self.scores["cost"]
79
  invalid_penalty = self.scores.get("invalid_actions", 0) * 0.05
80
 
81
+ # Episode-level prevention bonuses (NOT in step reward to avoid double-counting)
82
+ bonus = 0.0
83
+ if self.scores.get("vip_failure_count", 0) == 0:
84
+ bonus += 0.10 # Zero VIP failures all episode
85
+ if self.scores.get("violations", 0) < 3:
86
+ bonus += 0.05 # Very few SLA violations all episode
87
+ if self.scores.get("invalid_actions", 0) == 0:
88
+ bonus += 0.05 # Clean actions all episode
89
+
90
  if self.task_id == "task-3":
91
  # Coupling: If uptime < 0.5, the cost benefit is zeroed out.
 
92
  cost_weight = 1.0 if uptime >= 0.5 else 0.0
93
  score = (0.4 * uptime + 0.2 * stability + 0.4 * (cost * cost_weight))
94
  else:
95
  score = (0.4 * uptime + 0.2 * stability + 0.4 * cost)
96
+
97
+ return max(0.0, min(1.0, score - invalid_penalty + bonus))
98
 
99
  def summary(self) -> str:
100
  s = self.scores
 
166
 
167
  # ── 4. Invalid Action tracking ──────────────────────────────────────
168
  total_invalid = self._records[-1].get("invalid_action_count", 0)
169
+ total_vip_failures = self._records[-1].get("vip_failure_count", 0)
170
 
171
  return Grade(self.task_id, {
172
  "uptime": uptime_score,
173
  "cost": cost_score,
174
  "stability": stability_score,
175
  "violations": total_violations,
176
+ "invalid_actions": total_invalid,
177
+ "vip_failure_count": total_vip_failures,
178
  })
179
 
180
 
inference.py CHANGED
@@ -14,6 +14,7 @@ from openai import AsyncOpenAI
14
  from AntiAtropos.client import AntiAtroposEnv
15
  from AntiAtropos.grader import EpisodeGrader
16
  from AntiAtropos.models import ActionType, SREAction
 
17
 
18
  load_dotenv()
19
 
@@ -39,6 +40,8 @@ TEMPERATURE = float(os.getenv("ANTIATROPOS_TEMPERATURE", "0.0"))
39
  MAX_TOKENS = int(os.getenv("ANTIATROPOS_MAX_TOKENS", "180"))
40
  SEED = int(os.getenv("ANTIATROPOS_SEED", "42"))
41
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("ANTIATROPOS_SUCCESS_THRESHOLD", "0.55"))
 
 
42
 
43
  TASK_BRIEFS: Dict[str, str] = {
44
  "task-1": "Traffic increases linearly. Scale proactively to keep latency low and cost efficient.",
@@ -142,9 +145,10 @@ async def open_env(message_timeout_s: int):
142
  raise RuntimeError("Missing environment target. Set ENV_URL/ANTIATROPOS_ENV_URL or LOCAL_IMAGE_NAME.")
143
 
144
 
145
- def build_user_prompt(task_id: str, step: int, obs: dict, history: List[str]) -> str:
146
  recent = "\n".join(history[-4:]) if history else "None"
147
  brief = TASK_BRIEFS.get(task_id, "Maintain SLA, stability, and efficient cost.")
 
148
  return textwrap.dedent(
149
  f"""
150
  Task: {task_id}
@@ -155,7 +159,7 @@ def build_user_prompt(task_id: str, step: int, obs: dict, history: List[str]) ->
155
  {json.dumps(obs, separators=(",", ":"))}
156
 
157
  Recent decisions:
158
- {recent}
159
 
160
  Choose the next SRE action.
161
  """
@@ -174,15 +178,25 @@ def observation_for_model(obs) -> dict:
174
  "total_queue_backlog": obs.total_queue_backlog,
175
  "sla_violations": obs.sla_violations,
176
  "invalid_action_count": obs.invalid_action_count,
 
 
 
 
177
  "nodes": [
178
  {
179
  "node_id": node.node_id,
180
  "status": getattr(node.status, "value", str(node.status)),
181
  "is_vip": node.is_vip,
 
182
  "queue_depth": node.queue_depth,
183
  "latency_ms": node.latency_ms,
184
  "incoming_request_rate": node.incoming_request_rate,
185
  "cpu_utilization": node.cpu_utilization,
 
 
 
 
 
186
  }
187
  for node in obs.nodes
188
  ],
@@ -209,8 +223,8 @@ def _parse_action(payload: dict) -> SREAction:
209
  )
210
 
211
 
212
- async def get_model_action(client: AsyncOpenAI, task_id: str, step: int, obs: dict, history: List[str]) -> SREAction:
213
- prompt = build_user_prompt(task_id=task_id, step=step, obs=obs, history=history)
214
  try:
215
  completion = await client.chat.completions.create(
216
  model=MODEL_NAME,
@@ -241,15 +255,17 @@ def _compact_action(action: SREAction) -> str:
241
  return json.dumps(payload, separators=(",", ":"))
242
 
243
 
244
- async def run_single_task(env: AntiAtroposEnv, client: AsyncOpenAI, task_id: str) -> dict:
245
- task_seed = _task_seed(SEED, task_id)
246
  result = await env.reset(task_id=task_id, mode=ENV_MODE, seed=task_seed)
247
 
248
  grader = EpisodeGrader(task_id=task_id)
249
  grader.record(result.observation)
250
  history: List[str] = []
251
  rewards: List[float] = []
 
252
  steps_taken = 0
 
253
  for step in range(1, MAX_STEPS_PER_TASK + 1):
254
  if result.done:
255
  break
@@ -260,6 +276,7 @@ async def run_single_task(env: AntiAtroposEnv, client: AsyncOpenAI, task_id: str
260
  step=step,
261
  obs=observation_for_model(result.observation),
262
  history=history,
 
263
  )
264
  result = await env.step(action)
265
  grader.record(result.observation)
@@ -270,12 +287,39 @@ async def run_single_task(env: AntiAtroposEnv, client: AsyncOpenAI, task_id: str
270
  action_str = _compact_action(action)
271
  history.append(f"step={step} action={action_str} reward={reward:.2f}")
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  error = getattr(result.observation, "last_action_error", None)
274
  log_step(step=step, action=action_str, reward=reward, done=bool(result.done), error=error)
275
 
276
  grade = grader.score()
277
  score = _strict_score(float(grade.composite))
278
  success = score >= SUCCESS_SCORE_THRESHOLD
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  return {
280
  "task_id": task_id,
281
  "success": success,
@@ -295,29 +339,58 @@ async def run_all_tasks() -> None:
295
  raise RuntimeError("Missing API key (API_KEY/HF_TOKEN/OPENAI_API_KEY).")
296
 
297
  client = AsyncOpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
298
 
299
  try:
300
  async with open_env(MESSAGE_TIMEOUT_S) as env:
301
  for task in tasks_to_run:
302
- success = False
303
- steps = 0
304
- score = 0.0
305
- rewards: List[float] = []
306
- task_error: Optional[Exception] = None
307
- log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
308
- try:
309
- report = await run_single_task(env=env, client=client, task_id=task)
310
- success = bool(report["success"])
311
- steps = int(report["steps"])
312
- score = _strict_score(float(report["score"]))
313
- rewards = list(report["rewards"])
314
- except Exception as exc:
315
- task_error = exc
316
  score = 0.0
317
- finally:
318
- log_end(success=success, steps=steps, score=score, rewards=rewards)
319
- if task_error is not None:
320
- raise InferenceError(f"Task {task} failed.") from task_error
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  finally:
322
  await client.close()
323
 
 
14
  from AntiAtropos.client import AntiAtroposEnv
15
  from AntiAtropos.grader import EpisodeGrader
16
  from AntiAtropos.models import ActionType, SREAction
17
+ from AntiAtropos.replay import EpisodeReplayBuffer, compress_trajectory
18
 
19
  load_dotenv()
20
 
 
40
  MAX_TOKENS = int(os.getenv("ANTIATROPOS_MAX_TOKENS", "180"))
41
  SEED = int(os.getenv("ANTIATROPOS_SEED", "42"))
42
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("ANTIATROPOS_SUCCESS_THRESHOLD", "0.55"))
43
+ EVAL_RUNS = int(os.getenv("ANTIATROPOS_EVAL_RUNS", "3")) # Num eval runs per task
44
+ TEMPERATURE_SWEEP = [0.0, 0.3, 0.7] # Fixed temperatures for multi-episode eval
45
 
46
  TASK_BRIEFS: Dict[str, str] = {
47
  "task-1": "Traffic increases linearly. Scale proactively to keep latency low and cost efficient.",
 
145
  raise RuntimeError("Missing environment target. Set ENV_URL/ANTIATROPOS_ENV_URL or LOCAL_IMAGE_NAME.")
146
 
147
 
148
+ def build_user_prompt(task_id: str, step: int, obs: dict, history: List[str], demo_text: str = "") -> str:
149
  recent = "\n".join(history[-4:]) if history else "None"
150
  brief = TASK_BRIEFS.get(task_id, "Maintain SLA, stability, and efficient cost.")
151
+ demo_section = f"\n\n{demo_text}" if demo_text else ""
152
  return textwrap.dedent(
153
  f"""
154
  Task: {task_id}
 
159
  {json.dumps(obs, separators=(",", ":"))}
160
 
161
  Recent decisions:
162
+ {recent}{demo_section}
163
 
164
  Choose the next SRE action.
165
  """
 
178
  "total_queue_backlog": obs.total_queue_backlog,
179
  "sla_violations": obs.sla_violations,
180
  "invalid_action_count": obs.invalid_action_count,
181
+ "reward_drift": getattr(obs, "reward_drift", 0.0),
182
+ "reward_cost": getattr(obs, "reward_cost", 0.0),
183
+ "reward_sla": getattr(obs, "reward_sla", 0.0),
184
+ "reward_barrier": getattr(obs, "reward_barrier", 0.0),
185
  "nodes": [
186
  {
187
  "node_id": node.node_id,
188
  "status": getattr(node.status, "value", str(node.status)),
189
  "is_vip": node.is_vip,
190
+ "importance_weight": node.importance_weight,
191
  "queue_depth": node.queue_depth,
192
  "latency_ms": node.latency_ms,
193
  "incoming_request_rate": node.incoming_request_rate,
194
  "cpu_utilization": node.cpu_utilization,
195
+ "capacity": getattr(node, "capacity", 0.0),
196
+ "pending_capacity": getattr(node, "pending_capacity", 0.0),
197
+ "queue_delta": getattr(node, "queue_delta", 0.0),
198
+ "sla_proximity": getattr(node, "sla_proximity", 0.0),
199
+ "node_reward": getattr(node, "node_reward", 0.0),
200
  }
201
  for node in obs.nodes
202
  ],
 
223
  )
224
 
225
 
226
+ async def get_model_action(client: AsyncOpenAI, task_id: str, step: int, obs: dict, history: List[str], demo_text: str = "") -> SREAction:
227
+ prompt = build_user_prompt(task_id=task_id, step=step, obs=obs, history=history, demo_text=demo_text)
228
  try:
229
  completion = await client.chat.completions.create(
230
  model=MODEL_NAME,
 
255
  return json.dumps(payload, separators=(",", ":"))
256
 
257
 
258
+ async def run_single_task(env: AntiAtroposEnv, client: AsyncOpenAI, task_id: str, temperature: float = 0.0, replay_buffer: Optional[EpisodeReplayBuffer] = None, run_seed: Optional[int] = None) -> dict:
259
+ task_seed = run_seed if run_seed is not None else _task_seed(SEED, task_id)
260
  result = await env.reset(task_id=task_id, mode=ENV_MODE, seed=task_seed)
261
 
262
  grader = EpisodeGrader(task_id=task_id)
263
  grader.record(result.observation)
264
  history: List[str] = []
265
  rewards: List[float] = []
266
+ raw_steps: List[dict] = [] # For replay buffer compression
267
  steps_taken = 0
268
+ demo_text = replay_buffer.format_demonstrations() if replay_buffer else ""
269
  for step in range(1, MAX_STEPS_PER_TASK + 1):
270
  if result.done:
271
  break
 
276
  step=step,
277
  obs=observation_for_model(result.observation),
278
  history=history,
279
+ demo_text=demo_text,
280
  )
281
  result = await env.step(action)
282
  grader.record(result.observation)
 
287
  action_str = _compact_action(action)
288
  history.append(f"step={step} action={action_str} reward={reward:.2f}")
289
 
290
+ # Collect raw step data for replay compression
291
+ obs = result.observation
292
+ raw_steps.append({
293
+ "step": step,
294
+ "action_type": action.action_type.value,
295
+ "target_node_id": action.target_node_id,
296
+ "parameter": float(action.parameter),
297
+ "reward": reward,
298
+ "avg_latency_norm": getattr(obs, "average_latency_ms", 0.0),
299
+ "error_rate": getattr(obs, "error_rate", 0.0),
300
+ "queue_backlog_norm": getattr(obs, "total_queue_backlog", 0.0),
301
+ "sla_violation": reward < 0.3,
302
+ })
303
+
304
  error = getattr(result.observation, "last_action_error", None)
305
  log_step(step=step, action=action_str, reward=reward, done=bool(result.done), error=error)
306
 
307
  grade = grader.score()
308
  score = _strict_score(float(grade.composite))
309
  success = score >= SUCCESS_SCORE_THRESHOLD
310
+
311
+ # Store in replay buffer if available
312
+ if replay_buffer is not None and raw_steps:
313
+ trajectory = compress_trajectory(
314
+ steps=raw_steps,
315
+ task_id=task_id,
316
+ score=score,
317
+ total_steps=steps_taken,
318
+ final_sla_violations=int(grade.scores.get("violations", 0)),
319
+ final_invalid_actions=int(grade.scores.get("invalid_actions", 0)),
320
+ )
321
+ replay_buffer.store(trajectory, score)
322
+
323
  return {
324
  "task_id": task_id,
325
  "success": success,
 
339
  raise RuntimeError("Missing API key (API_KEY/HF_TOKEN/OPENAI_API_KEY).")
340
 
341
  client = AsyncOpenAI(base_url=API_BASE_URL, api_key=API_KEY)
342
+ replay_buffer = EpisodeReplayBuffer()
343
 
344
  try:
345
  async with open_env(MESSAGE_TIMEOUT_S) as env:
346
  for task in tasks_to_run:
347
+ task_scores: List[float] = []
348
+ task_successes: List[bool] = []
349
+
350
+ for run_idx in range(EVAL_RUNS):
351
+ # Fixed seed per (task, run_idx) so runs are reproducible
352
+ # and comparable across temperature conditions.
353
+ run_seed = SEED * 1000 + hash(task) % 100 + run_idx
354
+ temperature = TEMPERATURE_SWEEP[run_idx % len(TEMPERATURE_SWEEP)]
355
+
356
+ success = False
357
+ steps = 0
 
 
 
358
  score = 0.0
359
+ rewards: List[float] = []
360
+ task_error: Optional[Exception] = None
361
+ log_start(task=f"{task} run={run_idx+1}/{EVAL_RUNS} temp={temperature}", env=BENCHMARK, model=MODEL_NAME)
362
+ try:
363
+ report = await run_single_task(
364
+ env=env,
365
+ client=client,
366
+ task_id=task,
367
+ temperature=temperature,
368
+ replay_buffer=replay_buffer,
369
+ run_seed=run_seed,
370
+ )
371
+ success = bool(report["success"])
372
+ steps = int(report["steps"])
373
+ score = _strict_score(float(report["score"]))
374
+ rewards = list(report["rewards"])
375
+ task_scores.append(score)
376
+ task_successes.append(success)
377
+ except Exception as exc:
378
+ task_error = exc
379
+ score = 0.0
380
+ finally:
381
+ log_end(success=success, steps=steps, score=score, rewards=rewards)
382
+ if task_error is not None:
383
+ raise InferenceError(f"Task {task} run {run_idx+1} failed.") from task_error
384
+
385
+ # Report aggregate stats
386
+ if task_scores:
387
+ mean_score = sum(task_scores) / len(task_scores)
388
+ std_score = (sum((s - mean_score) ** 2 for s in task_scores) / len(task_scores)) ** 0.5
389
+ print(
390
+ f"[AGGREGATE] task={task} mean_score={mean_score:.3f} "
391
+ f"std={std_score:.3f} runs={len(task_scores)}",
392
+ flush=True,
393
+ )
394
  finally:
395
  await client.close()
396
 
models.py CHANGED
@@ -84,6 +84,37 @@ class NodeObservation(BaseModel):
84
  description="Business criticality weight. VIP nodes have higher impact on scoring.",
85
  )
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # Episode interaction fields (handled by framework)
88
  done: bool = False
89
  reward: float = 0.0
@@ -158,6 +189,24 @@ class ClusterObservation(BaseModel):
158
  raw_reward: float = 0.0
159
  normalized_reward: float = Field(default=0.0, ge=0.0, le=1.0)
160
  reward_scale_version: str = "sigmoid-v1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  choke_level: float = 0.0
162
 
163
  nodes: list[NodeObservation]
 
84
  description="Business criticality weight. VIP nodes have higher impact on scoring.",
85
  )
86
 
87
+ capacity: float = Field(
88
+ default=0.0,
89
+ ge=0.0,
90
+ description="Current capacity units provisioned for this node (0-5).",
91
+ )
92
+
93
+ pending_capacity: float = Field(
94
+ default=0.0,
95
+ ge=0.0,
96
+ description="Capacity units being booted (will be live after boot delay).",
97
+ )
98
+
99
+ queue_delta: float = Field(
100
+ default=0.0,
101
+ ge=-1.0,
102
+ le=1.0,
103
+ description="Normalized queue depth change from previous tick (-1 to +1).",
104
+ )
105
+
106
+ sla_proximity: float = Field(
107
+ default=0.0,
108
+ ge=0.0,
109
+ le=1.0,
110
+ description="How close this node is to SLA violation (0=safe, 1=violating).",
111
+ )
112
+
113
+ node_reward: float = Field(
114
+ default=0.0,
115
+ description="Per-node reward contribution for credit assignment.",
116
+ )
117
+
118
  # Episode interaction fields (handled by framework)
119
  done: bool = False
120
  reward: float = 0.0
 
189
  raw_reward: float = 0.0
190
  normalized_reward: float = Field(default=0.0, ge=0.0, le=1.0)
191
  reward_scale_version: str = "sigmoid-v1"
192
+ # Reward components breakdown
193
+ reward_drift: float = Field(
194
+ default=0.0,
195
+ description="Lyapunov drift component of the reward.",
196
+ )
197
+ reward_cost: float = Field(
198
+ default=0.0,
199
+ description="Infrastructure cost component of the reward.",
200
+ )
201
+ reward_sla: float = Field(
202
+ default=0.0,
203
+ description="SLA penalty component of the reward.",
204
+ )
205
+ reward_barrier: float = Field(
206
+ default=0.0,
207
+ description="Barrier function penalty component of the reward.",
208
+ )
209
+
210
  choke_level: float = 0.0
211
 
212
  nodes: list[NodeObservation]
replay.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AntiAtropos Episode Replay Buffer.
3
+
4
+ Stores episode trajectories for few-shot demonstrations during inference.
5
+ Uses summarization/compression to keep context window manageable:
6
+ - Only stores key transition windows (action, reward spike, SLA violation)
7
+ - Compresses long stable stretches into single summary lines
8
+ - Caps total demonstration size to avoid LLM context overflow
9
+ """
10
+
11
+ import random
12
+ from collections import deque
13
+ from dataclasses import dataclass, field
14
+ from typing import List, Optional
15
+
16
+
17
+ @dataclass
18
+ class Transition:
19
+ """A single step in an episode trajectory."""
20
+ step: int
21
+ action_type: str
22
+ target_node_id: str
23
+ parameter: float
24
+ reward: float
25
+ avg_latency_norm: float
26
+ error_rate: float
27
+ queue_backlog_norm: float
28
+ sla_violation: bool
29
+
30
+
31
+ @dataclass
32
+ class EpisodeTrajectory:
33
+ """A compressed episode trajectory for few-shot prompting."""
34
+ task_id: str
35
+ score: float
36
+ # Full trajectory is NOT stored — only key transitions
37
+ key_transitions: List[Transition] = field(default_factory=list)
38
+ total_steps: int = 0
39
+ final_sla_violations: int = 0
40
+ final_invalid_actions: int = 0
41
+
42
+ def to_prompt_lines(self, max_lines: int = 8) -> List[str]:
43
+ """Convert to concise prompt lines, capped at max_lines.
44
+
45
+ Summarization strategy:
46
+ 1. Always include first action (shows opening strategy)
47
+ 2. Always include highest-reward action (shows what worked)
48
+ 3. Always include last action (shows closing strategy)
49
+ 4. Fill remaining with transitions near SLA violations
50
+ 5. If still under max_lines, add evenly-spaced transitions
51
+ """
52
+ if not self.key_transitions:
53
+ return []
54
+
55
+ lines: List[str] = []
56
+ selected: List[Transition] = []
57
+
58
+ # Always take first
59
+ selected.append(self.key_transitions[0])
60
+
61
+ # Always take highest-reward
62
+ best = max(self.key_transitions, key=lambda t: t.reward)
63
+ if best not in selected:
64
+ selected.append(best)
65
+
66
+ # Always take last
67
+ last = self.key_transitions[-1]
68
+ if last not in selected:
69
+ selected.append(last)
70
+
71
+ # Add transitions near SLA violations (up to 2)
72
+ violation_trans = [t for t in self.key_transitions if t.sla_violation and t not in selected]
73
+ for vt in violation_trans[:2]:
74
+ selected.append(vt)
75
+
76
+ # Fill with evenly-spaced transitions
77
+ remaining = max_lines - len(selected)
78
+ if remaining > 0 and len(self.key_transitions) > len(selected):
79
+ stride = max(1, len(self.key_transitions) // (remaining + 1))
80
+ for i in range(stride, len(self.key_transitions), stride):
81
+ if self.key_transitions[i] not in selected and remaining > 0:
82
+ selected.append(self.key_transitions[i])
83
+ remaining -= 1
84
+
85
+ # Sort by step and format
86
+ selected.sort(key=lambda t: t.step)
87
+ for t in selected[:max_lines]:
88
+ action_str = f'{{"action_type":"{t.action_type}","target_node_id":"{t.target_node_id}","parameter":{t.parameter:.2f}}}'
89
+ lines.append(f"Step {t.step}: {action_str} reward={t.reward:.2f}")
90
+
91
+ # Add summary
92
+ lines.append(
93
+ f"[Episode summary: score={self.score:.2f}, "
94
+ f"steps={self.total_steps}, "
95
+ f"SLA_violations={self.final_sla_violations}]"
96
+ )
97
+ return lines
98
+
99
+
100
+ class EpisodeReplayBuffer:
101
+ """
102
+ Rolling buffer of episode trajectories for few-shot learning.
103
+
104
+ Addresses context explosion by:
105
+ 1. Storing only compressed trajectories (key transitions, not full)
106
+ 2. Capping demonstration size at MAX_DEMO_LINES per prompt inclusion
107
+ 3. Sampling at most MAX_DEMOS_PER_PROMPT trajectories
108
+ """
109
+
110
+ MAX_DEMO_LINES: int = 8 # Max lines per trajectory in prompt
111
+ MAX_DEMOS_PER_PROMPT: int = 2 # Max trajectories included in prompt
112
+
113
+ def __init__(self, max_episodes: int = 50):
114
+ self._positive: deque[EpisodeTrajectory] = deque(maxlen=max_episodes)
115
+ self._negative: deque[EpisodeTrajectory] = deque(maxlen=max_episodes)
116
+
117
+ def store(self, trajectory: EpisodeTrajectory, score: float) -> None:
118
+ """Store an episode trajectory, categorized by score."""
119
+ if score >= 0.55:
120
+ self._positive.append(trajectory)
121
+ elif score < 0.3:
122
+ self._negative.append(trajectory)
123
+
124
+ def sample_demonstrations(self, n: Optional[int] = None) -> List[EpisodeTrajectory]:
125
+ """Sample n positive episodes for few-shot prompting."""
126
+ if n is None:
127
+ n = self.MAX_DEMOS_PER_PROMPT
128
+ if not self._positive:
129
+ return []
130
+ return random.sample(list(self._positive), min(n, len(self._positive)))
131
+
132
+ def format_demonstrations(self) -> str:
133
+ """Format sampled demonstrations into a prompt-ready string.
134
+
135
+ Returns empty string if no demonstrations available.
136
+ Total output is bounded by MAX_DEMO_LINES * MAX_DEMOS_PER_PROMPT.
137
+ """
138
+ demos = self.sample_demonstrations()
139
+ if not demos:
140
+ return ""
141
+
142
+ parts = []
143
+ for i, demo in enumerate(demos):
144
+ lines = demo.to_prompt_lines(max_lines=self.MAX_DEMO_LINES)
145
+ if lines:
146
+ parts.append(f"Example {i+1} (task={demo.task_id}):")
147
+ parts.extend(lines)
148
+
149
+ if not parts:
150
+ return ""
151
+
152
+ return "Successful episode examples:\n" + "\n".join(parts)
153
+
154
+
155
+ def compress_trajectory(
156
+ steps: List[dict],
157
+ task_id: str,
158
+ score: float,
159
+ total_steps: int,
160
+ final_sla_violations: int = 0,
161
+ final_invalid_actions: int = 0,
162
+ ) -> EpisodeTrajectory:
163
+ """Compress a raw step list into a trajectory with only key transitions.
164
+
165
+ Raw steps are dicts with keys:
166
+ step, action_type, target_node_id, parameter, reward,
167
+ avg_latency_norm, error_rate, queue_backlog_norm, sla_violation
168
+
169
+ Key transition selection:
170
+ - First step
171
+ - Last step
172
+ - Steps with SLA violations
173
+ - Steps with highest/lowest reward
174
+ - Steps where action changed direction (e.g. SCALE_UP then SCALE_DOWN)
175
+ """
176
+ if not steps:
177
+ return EpisodeTrajectory(
178
+ task_id=task_id,
179
+ score=score,
180
+ total_steps=total_steps,
181
+ final_sla_violations=final_sla_violations,
182
+ final_invalid_actions=final_invalid_actions,
183
+ )
184
+
185
+ # Always include first and last
186
+ key_indices = {0, len(steps) - 1}
187
+
188
+ # Include SLA violations
189
+ for i, s in enumerate(steps):
190
+ if s.get("sla_violation"):
191
+ key_indices.add(i)
192
+
193
+ # Include reward extremes
194
+ if len(steps) > 2:
195
+ best_idx = max(range(len(steps)), key=lambda i: steps[i].get("reward", 0))
196
+ worst_idx = min(range(len(steps)), key=lambda i: steps[i].get("reward", 0))
197
+ key_indices.add(best_idx)
198
+ key_indices.add(worst_idx)
199
+
200
+ # Include action direction changes
201
+ for i in range(1, len(steps)):
202
+ prev_action = steps[i - 1].get("action_type", "")
203
+ curr_action = steps[i].get("action_type", "")
204
+ if prev_action != curr_action:
205
+ key_indices.add(i)
206
+
207
+ # Build compressed transitions (sorted)
208
+ key_transitions = []
209
+ for i in sorted(key_indices):
210
+ s = steps[i]
211
+ key_transitions.append(Transition(
212
+ step=s.get("step", i),
213
+ action_type=s.get("action_type", "NO_OP"),
214
+ target_node_id=s.get("target_node_id", "node-0"),
215
+ parameter=s.get("parameter", 0.0),
216
+ reward=s.get("reward", 0.0),
217
+ avg_latency_norm=s.get("avg_latency_norm", 0.0),
218
+ error_rate=s.get("error_rate", 0.0),
219
+ queue_backlog_norm=s.get("queue_backlog_norm", 0.0),
220
+ sla_violation=s.get("sla_violation", False),
221
+ ))
222
+
223
+ return EpisodeTrajectory(
224
+ task_id=task_id,
225
+ score=score,
226
+ key_transitions=key_transitions,
227
+ total_steps=total_steps,
228
+ final_sla_violations=final_sla_violations,
229
+ final_invalid_actions=final_invalid_actions,
230
+ )
server/AntiAtropos_environment.py CHANGED
@@ -10,13 +10,13 @@ from openenv.core.env_server.types import State
10
  try:
11
  from ..models import SREAction, ClusterObservation, NodeObservation, NodeStatus, EnvironmentMode
12
  from ..simulator import ClusterSimulator, COST_PER_CAPACITY_UNIT_PER_HOUR
13
- from ..stability import compute_lyapunov, compute_reward, normalize_reward, REWARD_SCALE_VERSION
14
  from ..telemetry import PrometheusClient, get_observability_tracker
15
  from ..control import KubernetesExecutor, ActionValidator
16
  except ImportError:
17
  from models import SREAction, ClusterObservation, NodeObservation, NodeStatus, EnvironmentMode # type: ignore[no-redef]
18
  from simulator import ClusterSimulator, COST_PER_CAPACITY_UNIT_PER_HOUR # type: ignore[no-redef]
19
- from stability import compute_lyapunov, compute_reward, normalize_reward, REWARD_SCALE_VERSION # type: ignore[no-redef]
20
  from telemetry import PrometheusClient, get_observability_tracker # type: ignore[no-redef]
21
  from control import KubernetesExecutor, ActionValidator # type: ignore[no-redef]
22
 
@@ -25,9 +25,10 @@ except ImportError:
25
  # Reward hyper-parameters (synchronized with stability.py constants)
26
  # ---------------------------------------------------------------------------
27
 
28
- ALPHA: float = 0.002 # Weight on Lyapunov energy drift ΔV(s) (Increased for faster feedback)
29
  BETA: float = 0.01 # Weight on infrastructure cost (Reduced to prevent cheap-but-dead strategies)
30
  GAMMA: float = 10.0 # Weight on per-step SLA violation indicator (Increased to force reactive scaling)
 
31
 
32
  MAX_QUEUE_NORM = 200.0
33
  MAX_LATENCY_NORM = 1000.0
@@ -66,6 +67,7 @@ class AntiAtroposEnvironment(Environment):
66
 
67
  self._nodes_true: list[dict] = []
68
  self._nodes_obs: list[dict] = []
 
69
  self._prev_lyapunov: float = 0.0
70
  self._sla_violations: int = 0
71
  self._action_ack_status: str = "success"
@@ -74,6 +76,10 @@ class AntiAtroposEnvironment(Environment):
74
  self._last_executor_error_code: str = ""
75
  self._last_raw_reward: float = 0.0
76
  self._last_normalized_reward: float = 0.0
 
 
 
 
77
  self._reward_output_mode: str = os.getenv("ANTIATROPOS_REWARD_OUTPUT_MODE", "normalized").strip().lower()
78
  if self._reward_output_mode not in REWARD_OUTPUT_MODES:
79
  self._reward_output_mode = "normalized"
@@ -140,14 +146,13 @@ class AntiAtroposEnvironment(Environment):
140
  is_enabled, mode_error = self._is_action_enabled_for_mode(action.action_type)
141
  if not is_enabled:
142
  self._action_ack_status = f"Rejected: {mode_error}"
143
- # Capability gate rejections happen before executor invocation, so
144
- # they should be tracked as rejected actions (ack_class) rather than
145
- # executor failures.
146
  self._last_executor_error_code = ""
147
  is_valid = False
148
  error = mode_error
 
149
  else:
150
- is_valid, error = self._validator.validate(
 
151
  action.action_type,
152
  action.target_node_id,
153
  action.parameter,
@@ -196,34 +201,51 @@ class AntiAtroposEnvironment(Environment):
196
  self._last_metric_time = time.time()
197
 
198
  # 4. Extract states (Ground Truth for reward; Observation for agent)
 
199
  self._nodes_true = self._sim.state(for_agent=False)
200
  self._nodes_obs = self._sim.state(for_agent=True)
201
 
202
- # 5. SLA Check
203
- avg_latency = self._avg_latency(self._nodes_true)
204
  error_rate = self._error_rate(self._nodes_true)
205
- sla_violation_step = 1 if (avg_latency > 200.0 or error_rate > 0.05) else 0
206
- if sla_violation_step:
 
207
  self._sla_violations += 1
208
 
209
  # 6. Compute Lyapunov stability metrics from Ground Truth
210
  current_lyapunov = compute_lyapunov(self._nodes_true)
211
 
212
- # 7. Compute scalar reward
213
  cost = self._compute_cost(self._nodes_true)
 
214
  raw_reward = compute_reward(
215
  v_prev=self._prev_lyapunov,
216
  v_curr=current_lyapunov,
217
  cost=cost,
218
- sla_violation_step=sla_violation_step,
219
  alpha=ALPHA,
220
  beta=BETA,
221
- gamma=GAMMA
 
 
222
  )
223
  normalized_reward = normalize_reward(raw_reward)
 
 
 
 
224
  reward = normalized_reward if self._reward_output_mode == "normalized" else raw_reward
225
  self._last_raw_reward = raw_reward
226
  self._last_normalized_reward = normalized_reward
 
 
 
 
 
 
 
 
227
 
228
  self._prev_lyapunov = current_lyapunov
229
 
@@ -348,8 +370,40 @@ class AntiAtroposEnvironment(Environment):
348
 
349
  def _build_observation(self) -> ClusterObservation:
350
  """Assembles the ClusterObservation from the current observed simulator state."""
351
- node_obs = [
352
- NodeObservation(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  node_id=n["node_id"],
354
  status=n["status"],
355
  queue_depth=min(1.0, max(0.0, float(n["queue_depth"]) / MAX_QUEUE_NORM)),
@@ -358,11 +412,14 @@ class AntiAtroposEnvironment(Environment):
358
  cpu_utilization=min(1.0, max(0.0, float(n["cpu_utilization"]))),
359
  is_vip=bool(n.get("is_vip", False)),
360
  importance_weight=float(n.get("importance_weight", 1.0)),
 
 
 
 
 
361
  done=False,
362
  reward=0.0,
363
- )
364
- for n in self._nodes_obs
365
- ]
366
 
367
  freshness = int((time.time() - self._last_metric_time) * 1000) if self._last_metric_time > 0 else 0
368
 
@@ -391,6 +448,10 @@ class AntiAtroposEnvironment(Environment):
391
  raw_reward=self._last_raw_reward,
392
  normalized_reward=self._last_normalized_reward,
393
  reward_scale_version=REWARD_SCALE_VERSION,
 
 
 
 
394
  choke_level=0.0,
395
  done=False,
396
  reward=0.0,
 
10
  try:
11
  from ..models import SREAction, ClusterObservation, NodeObservation, NodeStatus, EnvironmentMode
12
  from ..simulator import ClusterSimulator, COST_PER_CAPACITY_UNIT_PER_HOUR
13
+ from ..stability import compute_lyapunov, compute_reward, compute_barrier, normalize_reward, smooth_sla_penalty, REWARD_SCALE_VERSION
14
  from ..telemetry import PrometheusClient, get_observability_tracker
15
  from ..control import KubernetesExecutor, ActionValidator
16
  except ImportError:
17
  from models import SREAction, ClusterObservation, NodeObservation, NodeStatus, EnvironmentMode # type: ignore[no-redef]
18
  from simulator import ClusterSimulator, COST_PER_CAPACITY_UNIT_PER_HOUR # type: ignore[no-redef]
19
+ from stability import compute_lyapunov, compute_reward, compute_barrier, normalize_reward, smooth_sla_penalty, REWARD_SCALE_VERSION # type: ignore[no-redef]
20
  from telemetry import PrometheusClient, get_observability_tracker # type: ignore[no-redef]
21
  from control import KubernetesExecutor, ActionValidator # type: ignore[no-redef]
22
 
 
25
  # Reward hyper-parameters (synchronized with stability.py constants)
26
  # ---------------------------------------------------------------------------
27
 
28
+ ALPHA: float = 0.002 # Weight on Lyapunov energy drift DeltaV(s) (Increased for faster feedback)
29
  BETA: float = 0.01 # Weight on infrastructure cost (Reduced to prevent cheap-but-dead strategies)
30
  GAMMA: float = 10.0 # Weight on per-step SLA violation indicator (Increased to force reactive scaling)
31
+ DELTA: float = 0.005 # Weight on control-barrier function penalty (queue safety zone)
32
 
33
  MAX_QUEUE_NORM = 200.0
34
  MAX_LATENCY_NORM = 1000.0
 
67
 
68
  self._nodes_true: list[dict] = []
69
  self._nodes_obs: list[dict] = []
70
+ self._prev_nodes_true: list[dict] = [] # For per-node queue delta + reward
71
  self._prev_lyapunov: float = 0.0
72
  self._sla_violations: int = 0
73
  self._action_ack_status: str = "success"
 
76
  self._last_executor_error_code: str = ""
77
  self._last_raw_reward: float = 0.0
78
  self._last_normalized_reward: float = 0.0
79
+ self._last_reward_drift: float = 0.0
80
+ self._last_reward_cost: float = 0.0
81
+ self._last_reward_sla: float = 0.0
82
+ self._last_reward_barrier: float = 0.0
83
  self._reward_output_mode: str = os.getenv("ANTIATROPOS_REWARD_OUTPUT_MODE", "normalized").strip().lower()
84
  if self._reward_output_mode not in REWARD_OUTPUT_MODES:
85
  self._reward_output_mode = "normalized"
 
146
  is_enabled, mode_error = self._is_action_enabled_for_mode(action.action_type)
147
  if not is_enabled:
148
  self._action_ack_status = f"Rejected: {mode_error}"
 
 
 
149
  self._last_executor_error_code = ""
150
  is_valid = False
151
  error = mode_error
152
+ cooldown_penalty = 0.0
153
  else:
154
+ self._validator.set_tick(self._state.step_count)
155
+ is_valid, error, cooldown_penalty = self._validator.validate(
156
  action.action_type,
157
  action.target_node_id,
158
  action.parameter,
 
201
  self._last_metric_time = time.time()
202
 
203
  # 4. Extract states (Ground Truth for reward; Observation for agent)
204
+ self._prev_nodes_true = self._nodes_true # Save for per-node delta
205
  self._nodes_true = self._sim.state(for_agent=False)
206
  self._nodes_obs = self._sim.state(for_agent=True)
207
 
208
+ # 5. SLA Check (smooth sigmoid penalty instead of binary cliff)
209
+ avg_latency_norm = self._avg_latency(self._nodes_true) / MAX_LATENCY_NORM
210
  error_rate = self._error_rate(self._nodes_true)
211
+ sla_penalty_step = smooth_sla_penalty(avg_latency_norm, error_rate)
212
+ # Track binary violations for the grader (backward compat)
213
+ if avg_latency_norm > 0.20 or error_rate > 0.05:
214
  self._sla_violations += 1
215
 
216
  # 6. Compute Lyapunov stability metrics from Ground Truth
217
  current_lyapunov = compute_lyapunov(self._nodes_true)
218
 
219
+ # 7. Compute scalar reward (with barrier function)
220
  cost = self._compute_cost(self._nodes_true)
221
+ barrier = compute_barrier(self._nodes_true)
222
  raw_reward = compute_reward(
223
  v_prev=self._prev_lyapunov,
224
  v_curr=current_lyapunov,
225
  cost=cost,
226
+ sla_violation_step=sla_penalty_step,
227
  alpha=ALPHA,
228
  beta=BETA,
229
+ gamma=GAMMA,
230
+ barrier=barrier,
231
+ delta=DELTA,
232
  )
233
  normalized_reward = normalize_reward(raw_reward)
234
+ # Apply soft cooldown penalty: reduces reward for rapid re-scaling
235
+ # without blocking the action (emergency scaling still goes through)
236
+ if cooldown_penalty > 0:
237
+ normalized_reward = max(0.0, normalized_reward - cooldown_penalty * 0.1)
238
  reward = normalized_reward if self._reward_output_mode == "normalized" else raw_reward
239
  self._last_raw_reward = raw_reward
240
  self._last_normalized_reward = normalized_reward
241
+ # Store reward component breakdown for the observation
242
+ from ..stability import compute_drift, BARRIER_NORM_SCALE
243
+ delta_v = compute_drift(self._prev_lyapunov, current_lyapunov)
244
+ barrier_norm = barrier / BARRIER_NORM_SCALE if BARRIER_NORM_SCALE > 0 else barrier
245
+ self._last_reward_drift = -(ALPHA * delta_v)
246
+ self._last_reward_cost = -(BETA * cost)
247
+ self._last_reward_sla = -(GAMMA * sla_penalty_step)
248
+ self._last_reward_barrier = -(DELTA * barrier_norm)
249
 
250
  self._prev_lyapunov = current_lyapunov
251
 
 
370
 
371
  def _build_observation(self) -> ClusterObservation:
372
  """Assembles the ClusterObservation from the current observed simulator state."""
373
+ # Build a lookup for previous node state (for queue_delta and node_reward)
374
+ prev_by_id: dict[str, dict] = {n["node_id"]: n for n in self._prev_nodes_true}
375
+
376
+ node_obs = []
377
+ for n in self._nodes_obs:
378
+ # Per-node queue delta (normalized)
379
+ true_n = next((t for t in self._nodes_true if t["node_id"] == n["node_id"]), n)
380
+ prev_n = prev_by_id.get(n["node_id"])
381
+ if prev_n:
382
+ queue_delta_raw = float(n["queue_depth"]) - float(prev_n.get("queue_depth", 0))
383
+ queue_delta = max(-1.0, min(1.0, queue_delta_raw / MAX_QUEUE_NORM))
384
+ else:
385
+ queue_delta = 0.0
386
+
387
+ # Per-node reward contribution (normalized)
388
+ # Uses same formula as global reward but per-node
389
+ weight = float(n.get("importance_weight", 1.0))
390
+ if prev_n:
391
+ prev_q = float(prev_n.get("queue_depth", 0))
392
+ curr_q = float(true_n["queue_depth"])
393
+ node_drift = weight * (curr_q ** 2 - prev_q ** 2)
394
+ node_barrier = max(0, curr_q - 150.0) ** 2 # Q_BARRIER_MAX=150
395
+ node_cost = float(true_n.get("capacity_units", 0)) * COST_PER_CAPACITY_UNIT_PER_HOUR
396
+ node_reward_raw = -(ALPHA * node_drift + DELTA * (node_barrier / 10000.0) + BETA * node_cost)
397
+ # Normalize to [-1, 0] range
398
+ node_reward_val = max(-1.0, min(0.0, node_reward_raw / 10.0))
399
+ else:
400
+ node_reward_val = 0.0
401
+
402
+ # SLA proximity: how close this node is to violating (normalized)
403
+ node_latency_norm = min(1.0, max(0.0, float(n["latency_ms"]) / MAX_LATENCY_NORM))
404
+ sla_prox = max(0.0, min(1.0, node_latency_norm / 0.20)) # 0.20 is SLA threshold
405
+
406
+ node_obs.append(NodeObservation(
407
  node_id=n["node_id"],
408
  status=n["status"],
409
  queue_depth=min(1.0, max(0.0, float(n["queue_depth"]) / MAX_QUEUE_NORM)),
 
412
  cpu_utilization=min(1.0, max(0.0, float(n["cpu_utilization"]))),
413
  is_vip=bool(n.get("is_vip", False)),
414
  importance_weight=float(n.get("importance_weight", 1.0)),
415
+ capacity=float(n.get("capacity_units", 0)) / 5.0, # Normalize to [0,1]
416
+ pending_capacity=float(n.get("pending_capacity_units", 0)) / 5.0,
417
+ queue_delta=queue_delta,
418
+ sla_proximity=sla_prox,
419
+ node_reward=node_reward_val,
420
  done=False,
421
  reward=0.0,
422
+ ))
 
 
423
 
424
  freshness = int((time.time() - self._last_metric_time) * 1000) if self._last_metric_time > 0 else 0
425
 
 
448
  raw_reward=self._last_raw_reward,
449
  normalized_reward=self._last_normalized_reward,
450
  reward_scale_version=REWARD_SCALE_VERSION,
451
+ reward_drift=self._last_reward_drift,
452
+ reward_cost=self._last_reward_cost,
453
+ reward_sla=self._last_reward_sla,
454
+ reward_barrier=self._last_reward_barrier,
455
  choke_level=0.0,
456
  done=False,
457
  reward=0.0,
simulator.py CHANGED
@@ -30,6 +30,9 @@ BASE_LATENCY_MS: float = 20.0 # Minimum processing time
30
  OVERLOAD_THRESHOLD: int = 80 # Request count where node begins to "fail" (DEGRADED)
31
  LATENCY_STEEPNESS: float = 2.0 # Increased to ensure SLA violations before death
32
  FATAL_FAIL_THRESHOLD: int = 200 # Hard cap on queue depth (catastrophic failure boundary)
 
 
 
33
 
34
  SENSOR_DROPOUT_PROB: float = 0.05 # P(node.queue, latency reports 0 or -1.0)
35
  NODE_FAILURE_PROB: float = 0.00 # P(node fails naturally) — largely driven by task profile
@@ -90,6 +93,8 @@ class NodeState:
90
  dropped_requests: float = 0.0
91
  shed_fraction: float = 0.0 # Fraction of incoming traffic to drop this tick
92
  pending_capacity_queue: list[int] = field(default_factory=list)
 
 
93
 
94
  # Derived (recomputed whenever capacity or status changes)
95
  @property
@@ -114,6 +119,8 @@ class NodeState:
114
  "shed_fraction": round(self.shed_fraction, 4),
115
  "capacity_units": int(self.capacity),
116
  "pending_capacity_units": int(len(self.pending_capacity_queue)),
 
 
117
  }
118
 
119
 
@@ -146,6 +153,8 @@ class ClusterSimulator:
146
  self._t3_surge_end: int = T3_SURGE_BASE_END
147
  # Per-node reroute weights for REROUTE_TRAFFIC (node_id → fraction)
148
  self._reroute_weights: dict[str, float] = {}
 
 
149
  self._nodes: list[NodeState] = []
150
  self.invalid_action_count: int = 0
151
  self._randomize_domain()
@@ -176,6 +185,7 @@ class ClusterSimulator:
176
  node_id=f"node-{i}",
177
  is_vip=f"node-{i}" in VIP_NODE_WEIGHTS,
178
  importance_weight=VIP_NODE_WEIGHTS.get(f"node-{i}", 1.0),
 
179
  )
180
  for i in range(self._n_nodes)
181
  ]
@@ -190,6 +200,8 @@ class ClusterSimulator:
190
  self._tick_count = 0
191
  self._failed_node_id = None
192
  self._reroute_weights = {}
 
 
193
  self.invalid_action_count = 0
194
  self._randomize_domain()
195
  self._reset_nodes()
@@ -266,9 +278,16 @@ class ClusterSimulator:
266
  self._update_queues()
267
  self._update_derived_metrics()
268
  self._update_statuses()
269
- # decay/reset shed fractions for next tick
 
 
 
 
 
270
  for node in self._nodes:
271
- node.shed_fraction = 0.0
 
 
272
 
273
  def _update_capacity(self) -> None:
274
  """Process pending capacity from SCALE_UP actions"""
@@ -296,6 +315,10 @@ class ClusterSimulator:
296
  self._failed_node_id = self._rng.choice(
297
  [n.node_id for n in self._nodes if n.node_id != "node-0"]
298
  )
 
 
 
 
299
 
300
  # Physics change: In Task 2, we do NOT redistribute dead node traffic
301
  # automatically. The infrastructure keeps sending λ/N to the failed node
@@ -384,8 +407,10 @@ class ClusterSimulator:
384
  n.incoming_request_rate += share
385
 
386
  # Decay weights — agent must keep re-issuing to maintain effect
 
 
387
  for nid in list(self._reroute_weights.keys()):
388
- self._reroute_weights[nid] *= 0.5
389
  if self._reroute_weights[nid] < 0.01:
390
  del self._reroute_weights[nid]
391
 
@@ -421,24 +446,93 @@ class ClusterSimulator:
421
  # Utilization = Ratio of λ to μ
422
  service_rate = n.service_rate
423
  n.cpu_utilization = n.incoming_request_rate / service_rate if service_rate > 0 else 1.0
424
-
425
- # Latency (simplified M/M/1 wait-time model)
426
- n.latency_ms = BASE_LATENCY_MS + (n.queue_depth * LATENCY_STEEPNESS)
 
 
 
 
 
 
427
 
428
  def _update_statuses(self) -> None:
429
- """Transition node health based on queue boundaries."""
 
 
 
 
 
 
 
 
430
  for n in self._nodes:
431
- if n.node_id == self._failed_node_id:
 
432
  n.status = NodeStatus.FAILED
 
433
  continue
434
 
435
  if n.queue_depth > FATAL_FAIL_THRESHOLD:
436
- n.status = NodeStatus.FAILED
 
 
 
437
  elif n.queue_depth > OVERLOAD_THRESHOLD:
438
  n.status = NodeStatus.DEGRADED
439
  elif n.status == NodeStatus.DEGRADED and n.queue_depth < (OVERLOAD_THRESHOLD / 2):
440
  n.status = NodeStatus.HEALTHY
441
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  def reconcile_state(self, telemetry_map: dict) -> None:
443
  """
444
  Reconcile internal simulator state with external telemetry signals.
 
30
  OVERLOAD_THRESHOLD: int = 80 # Request count where node begins to "fail" (DEGRADED)
31
  LATENCY_STEEPNESS: float = 2.0 # Increased to ensure SLA violations before death
32
  FATAL_FAIL_THRESHOLD: int = 200 # Hard cap on queue depth (catastrophic failure boundary)
33
+ CASCADE_WINDOW_TICKS: int = 3 # Ticks after a failure to check for cascade effects
34
+ CASCADE_QUEUE_MULTIPLIER: float = 1.2 # Queue must exceed FATAL_FAIL_THRESHOLD * this to cascade
35
+ NODE_RECOVERY_TICKS: int = 20 # Ticks before a FAILED node auto-recovers
36
 
37
  SENSOR_DROPOUT_PROB: float = 0.05 # P(node.queue, latency reports 0 or -1.0)
38
  NODE_FAILURE_PROB: float = 0.00 # P(node fails naturally) — largely driven by task profile
 
93
  dropped_requests: float = 0.0
94
  shed_fraction: float = 0.0 # Fraction of incoming traffic to drop this tick
95
  pending_capacity_queue: list[int] = field(default_factory=list)
96
+ recovery_timer: int = 0 # Countdown to auto-recovery from FAILED status
97
+ is_scripted_failure: bool = False # True if failed due to task scripting (no auto-recovery)
98
 
99
  # Derived (recomputed whenever capacity or status changes)
100
  @property
 
119
  "shed_fraction": round(self.shed_fraction, 4),
120
  "capacity_units": int(self.capacity),
121
  "pending_capacity_units": int(len(self.pending_capacity_queue)),
122
+ "recovery_timer": self.recovery_timer,
123
+ "is_scripted_failure": self.is_scripted_failure,
124
  }
125
 
126
 
 
153
  self._t3_surge_end: int = T3_SURGE_BASE_END
154
  # Per-node reroute weights for REROUTE_TRAFFIC (node_id → fraction)
155
  self._reroute_weights: dict[str, float] = {}
156
+ self._cascade_tick: int = 0 # Tick counter for cascade detection window
157
+ self._cascade_triggered: bool = False # Set True when a NEW overload failure occurs
158
  self._nodes: list[NodeState] = []
159
  self.invalid_action_count: int = 0
160
  self._randomize_domain()
 
185
  node_id=f"node-{i}",
186
  is_vip=f"node-{i}" in VIP_NODE_WEIGHTS,
187
  importance_weight=VIP_NODE_WEIGHTS.get(f"node-{i}", 1.0),
188
+ is_scripted_failure=False,
189
  )
190
  for i in range(self._n_nodes)
191
  ]
 
200
  self._tick_count = 0
201
  self._failed_node_id = None
202
  self._reroute_weights = {}
203
+ self._cascade_tick = 0
204
+ self._cascade_triggered = False
205
  self.invalid_action_count = 0
206
  self._randomize_domain()
207
  self._reset_nodes()
 
278
  self._update_queues()
279
  self._update_derived_metrics()
280
  self._update_statuses()
281
+ self._cascade_failures()
282
+ self._process_recovery()
283
+ # Decay shed fractions gradually (retain 80% per tick = slow decay)
284
+ # The agent must still re-issue to maintain full effect, but the
285
+ # effect doesn't vanish instantly. *= 0.8 means after 3 ticks
286
+ # the shed is still at 51% (0.8^3), vs old 0.0 after 1 tick.
287
  for node in self._nodes:
288
+ node.shed_fraction *= 0.8
289
+ if node.shed_fraction < 0.01:
290
+ node.shed_fraction = 0.0
291
 
292
  def _update_capacity(self) -> None:
293
  """Process pending capacity from SCALE_UP actions"""
 
315
  self._failed_node_id = self._rng.choice(
316
  [n.node_id for n in self._nodes if n.node_id != "node-0"]
317
  )
318
+ # Mark the chosen node as a scripted (permanent) failure
319
+ target = next((n for n in self._nodes if n.node_id == self._failed_node_id), None)
320
+ if target:
321
+ target.is_scripted_failure = True
322
 
323
  # Physics change: In Task 2, we do NOT redistribute dead node traffic
324
  # automatically. The infrastructure keeps sending λ/N to the failed node
 
407
  n.incoming_request_rate += share
408
 
409
  # Decay weights — agent must keep re-issuing to maintain effect
410
+ # *= 0.8 retains 80% per tick (slow decay, persistent effect).
411
+ # After 5 ticks without re-issue, effect is at 33% (0.8^5).
412
  for nid in list(self._reroute_weights.keys()):
413
+ self._reroute_weights[nid] *= 0.8
414
  if self._reroute_weights[nid] < 0.01:
415
  del self._reroute_weights[nid]
416
 
 
446
  # Utilization = Ratio of λ to μ
447
  service_rate = n.service_rate
448
  n.cpu_utilization = n.incoming_request_rate / service_rate if service_rate > 0 else 1.0
449
+
450
+ # Latency: Hybrid M/M/1 + backlog term
451
+ # M/M/1 gives exponential blow-up as utilization->1 (the "hockey stick")
452
+ # Backlog term ensures queue_depth still contributes signal even when
453
+ # utilization is capped at 0.99, preventing the flattening problem.
454
+ utilization = min(0.99, n.cpu_utilization) # cap to prevent infinity
455
+ mm1_latency = BASE_LATENCY_MS / (1.0 - utilization)
456
+ backlog_latency = n.queue_depth * LATENCY_STEEPNESS
457
+ n.latency_ms = mm1_latency + backlog_latency
458
 
459
  def _update_statuses(self) -> None:
460
+ """Transition node health based on queue boundaries.
461
+
462
+ Recovery rules:
463
+ - Scripted failures (Task 2 forced node kill): permanent, never auto-recover.
464
+ Marked by is_scripted_failure=True, recovery_timer=0.
465
+ - Overload failures (queue > FATAL_FAIL_THRESHOLD): auto-recover after
466
+ NODE_RECOVERY_TICKS. The agent can learn to reroute away and let the
467
+ node heal.
468
+ """
469
  for n in self._nodes:
470
+ # Scripted (task-forced) failures are permanent
471
+ if n.is_scripted_failure:
472
  n.status = NodeStatus.FAILED
473
+ n.recovery_timer = 0
474
  continue
475
 
476
  if n.queue_depth > FATAL_FAIL_THRESHOLD:
477
+ if n.status != NodeStatus.FAILED:
478
+ n.status = NodeStatus.FAILED
479
+ n.recovery_timer = NODE_RECOVERY_TICKS
480
+ self._cascade_triggered = True # Signal cascade detection
481
  elif n.queue_depth > OVERLOAD_THRESHOLD:
482
  n.status = NodeStatus.DEGRADED
483
  elif n.status == NodeStatus.DEGRADED and n.queue_depth < (OVERLOAD_THRESHOLD / 2):
484
  n.status = NodeStatus.HEALTHY
485
 
486
+ def _cascade_failures(self) -> None:
487
+ """Detect cascading failure: if a peer node's queue exceeds a heightened
488
+ threshold within CASCADE_WINDOW_TICKS of a *new* failure, degrade it.
489
+
490
+ Guardrails:
491
+ - Only triggers when a NEW failure occurred this tick (not any failed node).
492
+ - Max one cascade step per failure event (no cascade chains).
493
+ - Scripted failures (Task 2) do not trigger cascades.
494
+ """
495
+ if not self._cascade_triggered:
496
+ self._cascade_tick = 0
497
+ return
498
+
499
+ self._cascade_tick += 1
500
+ if self._cascade_tick > CASCADE_WINDOW_TICKS:
501
+ self._cascade_triggered = False
502
+ self._cascade_tick = 0
503
+ return
504
+
505
+ cascade_threshold = FATAL_FAIL_THRESHOLD * CASCADE_QUEUE_MULTIPLIER
506
+ cascaded_this_tick = 0
507
+ for n in self._nodes:
508
+ if cascaded_this_tick >= 1:
509
+ break # Max one cascade per window to prevent chain reactions
510
+ if n.status == NodeStatus.FAILED:
511
+ continue
512
+ if n.is_scripted_failure:
513
+ continue
514
+ if n.queue_depth > cascade_threshold:
515
+ n.status = NodeStatus.DEGRADED
516
+ cascaded_this_tick += 1
517
+
518
+ def _process_recovery(self) -> None:
519
+ """Count down recovery timers and bring FAILED nodes back online.
520
+
521
+ Only overload-failed nodes (recovery_timer > 0) can recover.
522
+ Scripted failures (is_scripted_failure=True) are excluded.
523
+ """
524
+ for n in self._nodes:
525
+ if n.is_scripted_failure:
526
+ continue
527
+ if n.status == NodeStatus.FAILED and n.recovery_timer > 0:
528
+ n.recovery_timer -= 1
529
+ if n.recovery_timer <= 0:
530
+ n.status = NodeStatus.HEALTHY
531
+ n.capacity = 1.0 # Recover at minimum capacity
532
+ n.queue_depth = 0.0
533
+ n.latency_ms = BASE_LATENCY_MS
534
+ n.cpu_utilization = 0.0
535
+
536
  def reconcile_state(self, telemetry_map: dict) -> None:
537
  """
538
  Reconcile internal simulator state with external telemetry signals.
stability.py CHANGED
@@ -53,6 +53,13 @@ Q_BARRIER_MAX: float = 150.0
53
  Set higher than OVERLOAD_THRESHOLD (80) to allow the agent time to react
54
  before the barrier penalty kicks in."""
55
 
 
 
 
 
 
 
 
56
  STABILITY_WINDOW: int = 10
57
  """Number of ticks to look back when judging whether the system is
58
  trend-stable (V is on a decreasing trajectory)."""
@@ -65,7 +72,7 @@ trend-stable (V is on a decreasing trajectory)."""
65
  REWARD_NORM_MIDPOINT: float = float(os.getenv("ANTIATROPOS_REWARD_MIDPOINT", "0.0"))
66
  REWARD_NORM_TEMPERATURE: float = float(os.getenv("ANTIATROPOS_REWARD_TEMPERATURE", "5.0"))
67
  REWARD_NORM_EPS: float = float(os.getenv("ANTIATROPOS_REWARD_EPS", "1e-8"))
68
- REWARD_SCALE_VERSION: str = "sigmoid-v1"
69
 
70
 
71
  # ---------------------------------------------------------------------------
@@ -245,36 +252,93 @@ def drift_plus_penalty(
245
  # Convenience: full reward computation (matches environment.py formula)
246
  # ---------------------------------------------------------------------------
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  def compute_reward(
249
  v_prev: float,
250
  v_curr: float,
251
  cost: float,
252
- sla_violation_step: int,
253
  alpha: float = 1.0,
254
  beta: float = 0.05,
255
  gamma: float = 2.0,
 
 
256
  ) -> float:
257
  """
258
- R_t = (α·ΔV(s) + β·Cost + γ·SLA_violation_step)
259
 
260
  Convenience wrapper that mirrors the reward formula in environment.py.
261
- Can be used by the baseline agent to simulate rewards without calling
262
- the server, or by the grader to reconstruct reward trajectories.
263
 
264
  Args:
265
  v_prev: Lyapunov energy at previous tick.
266
  v_curr: Lyapunov energy at current tick.
267
  cost: Infrastructure cost this tick (USD/hr).
268
- sla_violation_step: 1 if this step violated SLA, else 0.
269
  alpha: Weight on Lyapunov drift.
270
  beta: Weight on cost.
271
  gamma: Weight on SLA violations.
 
 
272
 
273
  Returns:
274
- Scalar reward (higher is better, always 0 in a stable episode).
275
  """
276
  delta_v = compute_drift(v_prev, v_curr)
277
- return -(alpha * delta_v + beta * cost + gamma * sla_violation_step)
 
 
 
 
278
 
279
 
280
  def normalize_reward(
 
53
  Set higher than OVERLOAD_THRESHOLD (80) to allow the agent time to react
54
  before the barrier penalty kicks in."""
55
 
56
+ BARRIER_NORM_SCALE: float = 10000.0
57
+ """Normalization divisor for the barrier term.
58
+ The raw barrier H(s) = sum(max(0, Q_i - Q_max)^2) can produce very large values
59
+ (e.g. 5 nodes at Q=200, Q_max=150 gives 5*2500=12500). Without normalization,
60
+ this dominates the reward. Dividing by this scale keeps barrier in the same
61
+ order of magnitude as the other terms when delta=0.005."""
62
+
63
  STABILITY_WINDOW: int = 10
64
  """Number of ticks to look back when judging whether the system is
65
  trend-stable (V is on a decreasing trajectory)."""
 
72
  REWARD_NORM_MIDPOINT: float = float(os.getenv("ANTIATROPOS_REWARD_MIDPOINT", "0.0"))
73
  REWARD_NORM_TEMPERATURE: float = float(os.getenv("ANTIATROPOS_REWARD_TEMPERATURE", "5.0"))
74
  REWARD_NORM_EPS: float = float(os.getenv("ANTIATROPOS_REWARD_EPS", "1e-8"))
75
+ REWARD_SCALE_VERSION: str = "sigmoid-v2" # v2: smooth SLA + barrier active
76
 
77
 
78
  # ---------------------------------------------------------------------------
 
252
  # Convenience: full reward computation (matches environment.py formula)
253
  # ---------------------------------------------------------------------------
254
 
255
+ def smooth_sla_penalty(
256
+ avg_latency_norm: float,
257
+ error_rate: float,
258
+ latency_threshold: float = 0.20,
259
+ error_threshold: float = 0.05,
260
+ latency_temperature: float = 0.03,
261
+ error_temperature: float = 0.01,
262
+ ) -> float:
263
+ """
264
+ Smooth SLA penalty in [0, 1] that ramps up as metrics approach thresholds.
265
+
266
+ Unlike the binary cliff (0 or 1), this gives the agent gradient signal
267
+ BEFORE the SLA is actually violated, enabling preventive learning.
268
+
269
+ Uses two sigmoids (one for latency, one for errors) and takes the max
270
+ so whichever dimension is worse dominates.
271
+
272
+ Args:
273
+ avg_latency_norm: Normalized average latency [0, 1].
274
+ error_rate: Cluster-wide error rate [0, 1].
275
+ latency_threshold: Normalized latency SLA boundary.
276
+ error_threshold: Error rate SLA boundary.
277
+ latency_temperature: Sigmoid temperature for latency (lower = sharper).
278
+ error_temperature: Sigmoid temperature for errors (lower = sharper).
279
+
280
+ Returns:
281
+ Smooth penalty in [0, 1]. Near 0 when safe, near 1 when violating.
282
+
283
+ Raises:
284
+ ValueError: If inputs are outside [0, 1], indicating raw (non-normalized)
285
+ values were passed by mistake. This is a common bug: passing latency
286
+ in raw ms (e.g. 200.0) instead of normalized [0,1] (e.g. 0.20).
287
+ """
288
+ if avg_latency_norm < -0.01 or avg_latency_norm > 1.5:
289
+ raise ValueError(
290
+ f"smooth_sla_penalty: avg_latency_norm={avg_latency_norm:.4f} is outside "
291
+ f"expected [0, 1] range. Did you pass raw ms instead of normalized? "
292
+ f"Divide by MAX_LATENCY_NORM before calling."
293
+ )
294
+ if error_rate < -0.01 or error_rate > 1.5:
295
+ raise ValueError(
296
+ f"smooth_sla_penalty: error_rate={error_rate:.4f} is outside "
297
+ f"expected [0, 1] range."
298
+ )
299
+ lat_z = (avg_latency_norm - latency_threshold) / max(1e-8, latency_temperature)
300
+ err_z = (error_rate - error_threshold) / max(1e-8, error_temperature)
301
+ lat_penalty = 1.0 / (1.0 + math.exp(-lat_z))
302
+ err_penalty = 1.0 / (1.0 + math.exp(-err_z))
303
+ return max(lat_penalty, err_penalty)
304
+
305
+
306
  def compute_reward(
307
  v_prev: float,
308
  v_curr: float,
309
  cost: float,
310
+ sla_violation_step: float = 0.0,
311
  alpha: float = 1.0,
312
  beta: float = 0.05,
313
  gamma: float = 2.0,
314
+ barrier: float = 0.0,
315
+ delta: float = 0.005,
316
  ) -> float:
317
  """
318
+ R_t = -(alpha * DeltaV(s) + beta * Cost + gamma * SLA_smooth + delta * Barrier)
319
 
320
  Convenience wrapper that mirrors the reward formula in environment.py.
 
 
321
 
322
  Args:
323
  v_prev: Lyapunov energy at previous tick.
324
  v_curr: Lyapunov energy at current tick.
325
  cost: Infrastructure cost this tick (USD/hr).
326
+ sla_violation_step: Smooth SLA penalty in [0, 1] (was binary 0/1).
327
  alpha: Weight on Lyapunov drift.
328
  beta: Weight on cost.
329
  gamma: Weight on SLA violations.
330
+ barrier: Control-barrier function violation energy.
331
+ delta: Weight on barrier penalty.
332
 
333
  Returns:
334
+ Scalar reward (higher is better, always <= 0 in a stable episode).
335
  """
336
  delta_v = compute_drift(v_prev, v_curr)
337
+ # Normalize barrier to prevent reward domination: raw barrier can be ~12500,
338
+ # after dividing by BARRIER_NORM_SCALE it's ~1.25, then scaled by delta=0.005
339
+ # gives ~0.006 which is comparable to other terms.
340
+ barrier_normalized = barrier / BARRIER_NORM_SCALE if BARRIER_NORM_SCALE > 0 else barrier
341
+ return -(alpha * delta_v + beta * cost + gamma * sla_violation_step + delta * barrier_normalized)
342
 
343
 
344
  def normalize_reward(