| """ |
| AntiAtropos Training Loop. |
| |
| Orchestrates episode collection, reward computation, and loss calculation |
| for training LLM-based SRE agents. Works with the local simulator in |
| pure-Python mode (no AWS/GPU needed for validation). |
| |
| On Colab: Replace EpisodeCollector's "model" with a real QLoRA-backed |
| transformers model. The rest of the pipeline stays the same. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import random |
| import math |
| from dataclasses import dataclass, field |
| from typing import List, Optional, Protocol, Callable |
|
|
| import sys |
| import os |
| sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) |
|
|
| from simulator import ClusterSimulator, NodeStatus, COST_PER_CAPACITY_UNIT_PER_HOUR |
| from stability import ( |
| compute_lyapunov, compute_reward, compute_barrier, |
| normalize_reward, smooth_sla_penalty, |
| ) |
| from .losses import ( |
| compute_returns, compute_gae, |
| reinforce_loss, reinforce_baseline_loss, |
| grpo_loss, rloo_loss, |
| normalize_rewards, compute_reward_stats, |
| ) |
|
|
|
|
| |
| |
| |
|
|
# String identifiers for the supported policy-gradient losses; SRETrainer
# checks TrainingConfig.loss_fn against these values.
LOSS_REINFORCE = "reinforce"
LOSS_REINFORCE_BASELINE = "reinforce_baseline"
LOSS_GRPO = "grpo"
LOSS_RLOO = "rloo"

# Every loss identifier the trainer accepts.
VALID_LOSSES = {
    LOSS_REINFORCE,
    LOSS_REINFORCE_BASELINE,
    LOSS_GRPO,
    LOSS_RLOO,
}
|
|
|
|
@dataclass
class TrainingConfig:
    """Settings shared by the episode collector and the trainer.

    Attributes:
        n_nodes: Node count handed to ``ClusterSimulator``.
        max_steps: Number of simulator steps collected per episode.
        tasks: Task ids visited, in order, by ``SRETrainer.train_epoch``.
        loss_fn: One of the ``LOSS_*`` identifiers; selects the loss.
        gamma: Discount factor passed to ``compute_returns``.
        gae_lambda: GAE lambda. Not referenced elsewhere in this file as
            far as visible.
        n_samples_per_state: Group size K used on the GRPO / RLOO path.
        normalize_rewards: When true, rewards are normalized with the
            trainer's running mean/variance before the loss is computed.
        reward_ema_alpha: Step size of the exponential moving average that
            tracks the running reward mean and variance.
        normalize_advantages: Forwarded to ``reinforce_baseline_loss`` as
            ``normalize_advantage``.
        log_every: ``train_epoch`` logs when the trainer's episode counter
            is a multiple of this value.
    """

    # --- Simulator / curriculum ---
    n_nodes: int = 5
    max_steps: int = 100
    tasks: List[str] = field(
        default_factory=lambda: ["task-1", "task-2", "task-3"],
    )

    # --- Loss selection and discounting ---
    loss_fn: str = LOSS_REINFORCE_BASELINE
    gamma: float = 0.99
    gae_lambda: float = 0.95

    # --- Grouped sampling (GRPO / RLOO) ---
    n_samples_per_state: int = 4

    # --- Reward normalization ---
    normalize_rewards: bool = True
    reward_ema_alpha: float = 0.01

    # --- Advantage normalization ---
    normalize_advantages: bool = True

    # --- Logging ---
    log_every: int = 10
|
|
|
|
| |
| |
| |
|
|
class PolicyModel(Protocol):
    """Structural interface for policies used by the training loop.

    Any object providing these two methods (a real LLM wrapper or a mock)
    can be passed wherever a ``PolicyModel`` is expected.
    """

    def get_log_prob(self, prompt: str, action_text: str) -> float:
        """Score an action: return log π(action_text | prompt) for the current policy."""
        ...

    def generate(self, prompt: str) -> str:
        """Draw one action, as text, from the current policy for ``prompt``."""
        ...
|
|
|
|
class MockPolicyModel:
    """
    Seeded random stand-in for a real policy, for pipeline validation only.

    It emits random, well-formed JSON actions and noisy log probabilities so
    that episode collection, reward computation and loss calculation can be
    exercised end-to-end without a real model. It is NOT meant for training.

    The log probability is the uniform value ``log(1 / (5 * n_nodes))`` plus
    Gaussian noise (sigma 0.5) drawn afresh on every call; the noise does not
    depend on ``prompt`` or ``action_text``. The jitter keeps losses
    non-trivial, which a constant log-prob would not.
    """

    def __init__(self, n_nodes: int = 5, seed: int = 42):
        self._rng = random.Random(seed)
        self._n_nodes = n_nodes
        # Five action types per node make up the nominal action space.
        self._n_choices = 5 * n_nodes
        self._base_log_prob = math.log(1.0 / self._n_choices)

    def get_log_prob(self, prompt: str, action_text: str) -> float:
        """Return the uniform log-prob perturbed by one fresh Gaussian draw."""
        return self._base_log_prob + self._rng.gauss(0, 0.5)

    def generate(self, prompt: str) -> str:
        """Return a random action encoded as a JSON object string."""
        import json

        rng = self._rng
        # The draw order (node index, action type, parameter) determines the
        # seeded output sequence, so it must not be rearranged.
        node_index = rng.randint(0, self._n_nodes - 1)
        kind = rng.choice(
            ["SCALE_UP", "SCALE_DOWN", "REROUTE_TRAFFIC", "SHED_LOAD", "NO_OP"]
        )
        amount = round(rng.random(), 2)
        payload = {
            "action_type": kind,
            "target_node_id": f"node-{node_index}",
            "parameter": amount,
        }
        return json.dumps(payload)
|
|
|
|
| |
| |
| |
|
|
# Divisors used to scale raw node metrics; callers clamp the result to [0, 1].
MAX_QUEUE_NORM = 200.0
MAX_LATENCY_NORM = 1000.0
MAX_REQUEST_RATE_NORM = 100.0

# Coefficients passed positionally to compute_reward().
# NOTE(review): presumably the weights of the individual reward terms
# (drift / cost / SLA / barrier) — confirm against stability.compute_reward.
ALPHA = 0.002
BETA = 0.3
GAMMA = 6.0
DELTA = 0.1
|
|
|
|
def format_observation(nodes: List[dict], task_id: str, step: int, max_steps: int) -> str:
    """
    Serialize the simulator state into the compact JSON prompt fed to the model.

    Queue depth, latency and request rate are divided by their ``MAX_*_NORM``
    constant and, like CPU utilization, clamped to [0, 1]. A status that is
    not already a string is replaced by its ``.value``.

    Per the original author, this mirrors inference.py's build_user_prompt
    and observation_for_model.

    Args:
        nodes: Per-node state dicts from the simulator.
        task_id: Identifier of the task being run.
        step: Current step number.
        max_steps: Total number of steps in the episode.

    Returns:
        A JSON string with keys ``task_id``, ``step``, ``max_steps`` and
        ``nodes``, using ``","`` / ``":"`` separators (no whitespace).
    """
    import json

    def _unit(value):
        # Clamp into the closed interval [0, 1].
        return min(1.0, max(0.0, value))

    def _describe(node: dict) -> dict:
        status = node["status"]
        if not isinstance(status, str):
            status = status.value
        return {
            "node_id": node["node_id"],
            "status": status,
            "is_vip": node.get("is_vip", False),
            "queue_depth": _unit(node["queue_depth"] / MAX_QUEUE_NORM),
            "latency_ms": _unit(node["latency_ms"] / MAX_LATENCY_NORM),
            "cpu_utilization": _unit(node.get("cpu_utilization", 0.0)),
            "incoming_request_rate": _unit(
                node["incoming_request_rate"] / MAX_REQUEST_RATE_NORM
            ),
        }

    payload = {
        "task_id": task_id,
        "step": step,
        "max_steps": max_steps,
        "nodes": [_describe(node) for node in nodes],
    }
    return json.dumps(payload, separators=(",", ":"))
|
|
|
|
def parse_action(action_text: str) -> dict:
    """
    Parse raw model output into an action dict, falling back to NO_OP.

    Model output is untrusted, so every malformed input maps to the NO_OP
    fallback instead of raising: text that is not JSON, JSON that is not an
    object (e.g. a list or bare string), a non-string input such as ``None``,
    and a ``parameter`` that is missing a numeric meaning (``null``, a
    non-numeric string, an integer too large for a float, NaN or ±Infinity).

    Args:
        action_text: The text generated by the policy.

    Returns:
        A dict with keys ``action_type`` (upper-cased str),
        ``target_node_id`` (str) and ``parameter`` (finite float).
    """
    import json

    fallback = {"action_type": "NO_OP", "target_node_id": "node-0", "parameter": 0.0}

    try:
        data = json.loads(action_text)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers non-str input.
        return fallback

    # Valid JSON is not necessarily an object; only a dict supports .get().
    if not isinstance(data, dict):
        return fallback

    try:
        parameter = float(data.get("parameter", 0.0))
    except (TypeError, ValueError, OverflowError):
        return fallback

    # json.loads accepts NaN / Infinity; keep those out of the simulator.
    if not math.isfinite(parameter):
        return fallback

    return {
        "action_type": str(data.get("action_type", "NO_OP")).upper(),
        "target_node_id": str(data.get("target_node_id", "node-0")),
        "parameter": parameter,
    }
|
|
|
|
| |
| |
| |
|
|
@dataclass
class StepRecord:
    """One (prompt, action, reward) transition recorded during an episode.

    Attributes:
        prompt: Observation text shown to the model for this step.
        action_text: Raw text the model generated in response.
        log_prob: Log probability the model reported for ``action_text``.
        reward: Raw reward computed after the simulator advanced.
        reward_normalized: ``reward`` after ``normalize_reward``.
    """

    prompt: str
    action_text: str
    log_prob: float
    reward: float
    reward_normalized: float
|
|
|
|
@dataclass
class EpisodeRecord:
    """Trajectory and summary statistics for one collected episode.

    Attributes:
        task_id: Task the episode was run on.
        steps: Per-step records in the order they were collected.
        total_reward: Sum of raw step rewards.
        total_reward_normalized: Sum of normalized step rewards.
        avg_reward: ``total_reward`` divided by the number of steps.
        sla_violations: Number of steps counted as SLA violations.
        final_lyapunov: Lyapunov energy after the last step.
    """

    task_id: str
    steps: List[StepRecord] = field(default_factory=list)
    total_reward: float = 0.0
    total_reward_normalized: float = 0.0
    avg_reward: float = 0.0
    sla_violations: int = 0
    final_lyapunov: float = 0.0
|
|
|
|
class EpisodeCollector:
    """
    Collects episodes by running the simulator with a policy model.

    This is the bridge between the simulator (physics) and the training
    pipeline (loss computation). It produces EpisodeRecords that feed
    directly into the loss functions.
    """

    # A step counts as an SLA violation when the normalized weighted latency
    # or the weighted error rate exceeds its threshold.
    _SLA_LATENCY_THRESHOLD = 0.20
    _SLA_ERROR_THRESHOLD = 0.05

    class _Action:
        """Plain attribute container handed to ``ClusterSimulator.apply_action``."""

        def __init__(self, action_type: str, target_node_id: str, parameter: float):
            self.action_type = action_type
            self.target_node_id = target_node_id
            self.parameter = parameter

    def __init__(self, config: TrainingConfig):
        self._config = config
        self._sim = ClusterSimulator(n_nodes=config.n_nodes)

    @staticmethod
    def _weighted_latency_norm(nodes) -> float:
        """Importance-weighted mean latency scaled to [0, 1]; failed nodes count as MAX_LATENCY_NORM."""
        w_lat = 0.0
        w_sum = 0.0
        for n in nodes:
            w = n.get("importance_weight", 1.0)
            lat = MAX_LATENCY_NORM if n["status"] == NodeStatus.FAILED else n["latency_ms"]
            w_lat += w * lat
            w_sum += w
        if w_sum > 0:
            return min(1.0, max(0.0, w_lat / w_sum / MAX_LATENCY_NORM))
        # No weight at all: report the worst case.
        return 1.0

    @staticmethod
    def _weighted_error_rate(nodes) -> float:
        """Importance-weighted dropped/incoming ratio, capped at 1.0 (0.0 with no traffic)."""
        total_in = sum(
            n.get("incoming_request_rate", 0) * n.get("importance_weight", 1.0) for n in nodes
        )
        total_drop = sum(
            n.get("dropped_requests", 0) * n.get("importance_weight", 1.0) for n in nodes
        )
        return min(1.0, total_drop / total_in) if total_in > 0 else 0.0

    @staticmethod
    def _capacity_cost(nodes) -> float:
        """Cost of active plus pending capacity units on all non-failed nodes."""
        total_cap = sum(
            int(n.get("capacity_units", 0)) + int(n.get("pending_capacity_units", 0))
            for n in nodes
            if n["status"] != NodeStatus.FAILED
        )
        return total_cap * COST_PER_CAPACITY_UNIT_PER_HOUR

    def collect_episode(
        self,
        model: PolicyModel,
        task_id: str,
        seed: Optional[int] = None,
    ) -> EpisodeRecord:
        """
        Run one episode and collect step-level data.

        Each step: build the prompt from the agent-visible state, sample an
        action and its log-prob from ``model``, apply it, tick the simulator,
        then compute the reward from the true (non-agent) state.

        Args:
            model: Policy used to generate actions and score them.
            task_id: Task passed to the simulator on reset.
            seed: Seed passed to the simulator on reset (``None`` is
                forwarded unchanged).

        Returns:
            The populated EpisodeRecord with ``config.max_steps`` steps.
        """
        cfg = self._config
        sim = self._sim
        sim.reset(task_id=task_id, seed=seed)

        record = EpisodeRecord(task_id=task_id)
        # NOTE(review): the first step's Lyapunov drift is measured against
        # 0.0 rather than the post-reset energy — confirm this matches the
        # reward used by the real environment.
        prev_lyapunov = 0.0

        for step in range(1, cfg.max_steps + 1):
            # Observe (agent view) and act.
            prompt = format_observation(
                sim.state(for_agent=True), task_id, step, cfg.max_steps
            )
            action_text = model.generate(prompt)
            log_prob = model.get_log_prob(prompt, action_text)

            # Apply the parsed action, then advance the simulation.
            sim.apply_action(self._Action(**parse_action(action_text)))
            sim.tick()

            # Rewards are computed from the true state, not the agent view.
            nodes_true = sim.state(for_agent=False)
            current_lyapunov = compute_lyapunov(nodes_true)

            avg_lat_norm = self._weighted_latency_norm(nodes_true)
            error_rate = self._weighted_error_rate(nodes_true)
            sla_step = smooth_sla_penalty(avg_lat_norm, error_rate)
            if (
                avg_lat_norm > self._SLA_LATENCY_THRESHOLD
                or error_rate > self._SLA_ERROR_THRESHOLD
            ):
                record.sla_violations += 1

            cost = self._capacity_cost(nodes_true)
            barrier = compute_barrier(nodes_true)
            raw_reward = compute_reward(
                prev_lyapunov, current_lyapunov, cost, sla_step,
                ALPHA, BETA, GAMMA, barrier, DELTA,
            )
            norm_reward = normalize_reward(raw_reward)

            record.steps.append(StepRecord(
                prompt=prompt,
                action_text=action_text,
                log_prob=log_prob,
                reward=raw_reward,
                reward_normalized=norm_reward,
            ))
            record.total_reward += raw_reward
            record.total_reward_normalized += norm_reward
            prev_lyapunov = current_lyapunov

        record.avg_reward = record.total_reward / max(1, len(record.steps))
        record.final_lyapunov = prev_lyapunov
        return record

    def collect_group(
        self,
        model: PolicyModel,
        task_id: str,
        k: int,
        seed: Optional[int] = None,
    ) -> List[EpisodeRecord]:
        """
        Collect K episodes from the same initial state (for GRPO/RLOO).

        Uses the same seed for all K episodes so they start from the same
        domain randomization, but different model samples produce different
        trajectories. When ``seed`` is ``None`` a single random seed is drawn
        and shared by the whole group; passing ``None`` through would let
        every reset pick its own seed and break the shared-start guarantee.

        Args:
            model: Policy used for every episode in the group.
            task_id: Task passed to the simulator on reset.
            k: Number of episodes to collect.
            seed: Shared seed, or ``None`` to draw one for this group.

        Returns:
            A list of ``k`` EpisodeRecords.
        """
        if seed is None:
            seed = random.randrange(2 ** 31)
        return [self.collect_episode(model, task_id, seed=seed) for _ in range(k)]
|
|
|
|
| |
| |
| |
|
|
class SRETrainer:
    """
    Main training orchestrator for AntiAtropos SRE agents.

    Usage (local validation with MockPolicyModel):
        config = TrainingConfig(loss_fn="reinforce_baseline")
        trainer = SRETrainer(config)
        model = MockPolicyModel()
        metrics = trainer.train_step(model, task_id="task-1", seed=42)

    Usage (Colab with real model):
        config = TrainingConfig(loss_fn="grpo", n_samples_per_state=4)
        trainer = SRETrainer(config)
        model = QLoRAModel(...)  # Your transformers model
        for epoch in range(num_epochs):
            for task in config.tasks:
                metrics = trainer.train_step(model, task_id=task)
                model.update(metrics["loss"])  # Backprop
    """

    def __init__(self, config: TrainingConfig):
        """
        Create a trainer and its episode collector.

        Raises:
            ValueError: If ``config.loss_fn`` is not one of VALID_LOSSES.
        """
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O`` and would let an invalid loss name through.
        if config.loss_fn not in VALID_LOSSES:
            raise ValueError(f"Unknown loss: {config.loss_fn}")
        self._config = config
        self._collector = EpisodeCollector(config)
        self._running_reward_mean = 0.0
        self._running_reward_var = 1.0
        self._episode_count = 0

    def _update_running_stats(self, rewards: List[float]) -> None:
        """Fold the mean/variance of ``rewards`` into the running EMA statistics."""
        alpha = self._config.reward_ema_alpha
        ep_mean, ep_var = compute_reward_stats(rewards)
        self._running_reward_mean = (
            (1 - alpha) * self._running_reward_mean + alpha * ep_mean
        )
        self._running_reward_var = (
            (1 - alpha) * self._running_reward_var + alpha * ep_var
        )

    def train_step(
        self,
        model: PolicyModel,
        task_id: str,
        seed: Optional[int] = None,
    ) -> dict:
        """
        Execute one training step: collect episode(s) → compute loss.

        GRPO and RLOO use the grouped path (K episodes); the REINFORCE
        variants use a single episode.

        Returns a metrics dict with:
          - loss: The computed loss value
          - avg_reward: Average raw reward across the episode
          - avg_norm_reward: Average normalized reward
          - episode_length: Number of steps
          - sla_violations: Number of SLA violations
          - final_lyapunov: Lyapunov energy at episode end
          - reward_mean/var: Running reward statistics
          - task_id: The task that was run
          - episode / episodes: The collected record(s)
        """
        if self._config.loss_fn in (LOSS_GRPO, LOSS_RLOO):
            return self._train_step_grouped(model, task_id, seed)
        return self._train_step_single(model, task_id, seed)

    def _train_step_single(
        self,
        model: PolicyModel,
        task_id: str,
        seed: Optional[int] = None,
    ) -> dict:
        """Train step for REINFORCE / REINFORCE+baseline (one episode)."""
        cfg = self._config

        episode = self._collector.collect_episode(model, task_id, seed=seed)
        rewards = [s.reward for s in episode.steps]
        log_probs = [s.log_prob for s in episode.steps]

        # Update the running statistics first; normalization below uses the
        # updated values.
        self._update_running_stats(rewards)

        if cfg.normalize_rewards:
            rewards = normalize_rewards(
                rewards, self._running_reward_mean, self._running_reward_var
            )

        returns = compute_returns(rewards, gamma=cfg.gamma)

        if cfg.loss_fn == LOSS_REINFORCE:
            loss = reinforce_loss(log_probs, returns)
        elif cfg.loss_fn == LOSS_REINFORCE_BASELINE:
            # NOTE(review): the baseline is the running mean of raw per-step
            # rewards, while ``returns`` are discounted sums of (optionally
            # normalized) rewards — confirm the differing scales are intended.
            baselines = [self._running_reward_mean] * len(returns)
            loss = reinforce_baseline_loss(
                log_probs, returns, baselines,
                normalize_advantage=cfg.normalize_advantages,
            )
        else:
            raise ValueError(f"Unexpected loss_fn: {cfg.loss_fn}")

        self._episode_count += 1

        return {
            "loss": loss,
            "avg_reward": episode.avg_reward,
            "avg_norm_reward": episode.total_reward_normalized / max(1, len(episode.steps)),
            "episode_length": len(episode.steps),
            "sla_violations": episode.sla_violations,
            "final_lyapunov": episode.final_lyapunov,
            "reward_mean": self._running_reward_mean,
            "reward_var": self._running_reward_var,
            "task_id": task_id,
            "episode": episode,
        }

    def _train_step_grouped(
        self,
        model: PolicyModel,
        task_id: str,
        seed: Optional[int] = None,
    ) -> dict:
        """
        Train step for GRPO / RLOO (K episodes grouped per timestep).

        Raises:
            ValueError: If ``n_samples_per_state`` is less than 1, or the
                configured loss is not a grouped loss.
        """
        cfg = self._config
        k = cfg.n_samples_per_state
        if k < 1:
            # Without this the failure surfaces as an opaque min() error.
            raise ValueError(f"n_samples_per_state must be >= 1, got {k}")

        episodes = self._collector.collect_group(model, task_id, k=k, seed=seed)

        # Group step t of every episode together; truncate to the shortest
        # episode so every group has exactly K entries.
        min_len = min(len(ep.steps) for ep in episodes)
        log_probs_groups = [
            [ep.steps[t].log_prob for ep in episodes] for t in range(min_len)
        ]
        rewards_groups = [
            [ep.steps[t].reward for ep in episodes] for t in range(min_len)
        ]

        # Running statistics are updated from every step of every episode.
        self._update_running_stats(
            [s.reward for ep in episodes for s in ep.steps]
        )

        if cfg.normalize_rewards:
            rewards_groups = [
                normalize_rewards(rs, self._running_reward_mean, self._running_reward_var)
                for rs in rewards_groups
            ]

        if cfg.loss_fn == LOSS_GRPO:
            loss = grpo_loss(log_probs_groups, rewards_groups)
        elif cfg.loss_fn == LOSS_RLOO:
            loss = rloo_loss(log_probs_groups, rewards_groups)
        else:
            raise ValueError(f"Unexpected grouped loss_fn: {cfg.loss_fn}")

        # Aggregate metrics across the group.
        n_eps = len(episodes)
        avg_reward = sum(ep.avg_reward for ep in episodes) / n_eps
        avg_norm = sum(
            ep.total_reward_normalized / max(1, len(ep.steps)) for ep in episodes
        ) / n_eps
        total_sla = sum(ep.sla_violations for ep in episodes)
        avg_lyapunov = sum(ep.final_lyapunov for ep in episodes) / n_eps

        self._episode_count += k

        return {
            "loss": loss,
            "avg_reward": avg_reward,
            "avg_norm_reward": avg_norm,
            "episode_length": min_len,
            "sla_violations": total_sla,
            "final_lyapunov": avg_lyapunov,
            "reward_mean": self._running_reward_mean,
            "reward_var": self._running_reward_var,
            "task_id": task_id,
            "episodes": episodes,
        }

    def train_epoch(
        self,
        model: PolicyModel,
        seed: Optional[int] = None,
    ) -> List[dict]:
        """
        Run one training step per task in the curriculum.

        When ``seed`` is given, each task gets ``seed`` plus a per-task
        offset in [0, 1000) derived from a CRC32 of the task id. The built-in
        ``hash()`` is not used because string hashes are randomized per
        interpreter process, which would make seeded runs irreproducible.

        Metrics are logged when the episode counter is a multiple of
        ``log_every``; a ``log_every`` of zero or less disables logging.

        Returns a list of metrics dicts (one per task).
        """
        import zlib

        log_every = self._config.log_every
        results = []
        for task_id in self._config.tasks:
            step_seed = None
            if seed is not None:
                step_seed = seed + zlib.crc32(task_id.encode("utf-8")) % 1000
            metrics = self.train_step(model, task_id, seed=step_seed)
            results.append(metrics)
            if log_every > 0 and self._episode_count % log_every == 0:
                self._log_metrics(metrics)
        return results

    def _log_metrics(self, metrics: dict) -> None:
        """Print training metrics."""
        print(
            f"[Episode {self._episode_count}] "
            f"task={metrics['task_id']} "
            f"loss={metrics['loss']:.4f} "
            f"avg_reward={metrics['avg_reward']:.4f} "
            f"avg_norm_reward={metrics['avg_norm_reward']:.4f} "
            f"sla_violations={metrics['sla_violations']} "
            f"lyapunov={metrics['final_lyapunov']:.1f} "
            f"reward_mean={metrics['reward_mean']:.4f} "
            f"reward_var={metrics['reward_var']:.4f}"
        )
|
|