| |
| """ |
AntiAtropos Training Validation - Local Test.
| |
| Validates the training pipeline (loss functions, episode collection, trainer) |
| using a MockPolicyModel (no GPU needed). Run before going to Colab. |
| |
| Run from project root: |
| python train_test.py |
| """ |
|
|
| import sys |
| import os |
| import math |
| import random |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from training.losses import ( |
| compute_returns, compute_gae, |
| reinforce_loss, reinforce_baseline_loss, |
| grpo_loss, rloo_loss, |
| normalize_rewards, compute_reward_stats, |
| ) |
| from training.trainer import ( |
| SRETrainer, TrainingConfig, EpisodeCollector, |
| MockPolicyModel, |
| LOSS_REINFORCE, LOSS_REINFORCE_BASELINE, LOSS_GRPO, LOSS_RLOO, |
| ) |
|
|
# Status constants used throughout the result log.
PASS = "PASS"
FAIL = "FAIL"
# Accumulated (name, status, detail) tuples; summarized by main().
results: list[tuple[str, str, str]] = []


def record(name: str, status: str, detail: str = "") -> None:
    """Append one test result to the global log and print a one-line summary."""
    results.append((name, status, detail))
    marker = "+" if status == PASS else "X"
    line = f" [{marker}] {name}"
    if detail:
        line += f" -- {detail}"
    print(line)
|
|
|
|
| |
| |
| |
|
|
def test_returns():
    """Sanity-check discounted return computation on hand-verifiable inputs."""
    print("\n--- Return Computation ---")

    rets = compute_returns([1.0, 1.0, 1.0], gamma=0.99)
    # First step: 1 + 0.99 + 0.99^2 = 2.9701.
    record("compute_returns[0]",
           PASS if abs(rets[0] - 2.9701) < 0.001 else FAIL,
           f"got {rets[0]:.4f} expected 2.9701")
    # Last step has no future rewards, so its return is the raw reward.
    record("compute_returns[2]",
           PASS if abs(rets[2] - 1.0) < 0.001 else FAIL,
           f"got {rets[2]:.4f} expected 1.0")

    # Degenerate input: no rewards at all.
    empty = compute_returns([])
    record("compute_returns handles empty",
           PASS if empty == [] else FAIL,
           f"got {empty}")

    single = compute_returns([5.0])
    record("compute_returns single reward",
           PASS if abs(single[0] - 5.0) < 0.001 else FAIL,
           f"got {single[0]:.4f}")

    # gamma=0 disables bootstrapping entirely: returns equal raw rewards.
    undiscounted = compute_returns([1.0, 2.0, 3.0], gamma=0.0)
    record("gamma=0: returns = rewards",
           PASS if undiscounted == [1.0, 2.0, 3.0] else FAIL,
           f"got {undiscounted}")
|
|
|
|
| |
| |
| |
|
|
def test_gae():
    """Validate GAE against its known limiting cases."""
    print("\n--- GAE Computation ---")

    rewards = [1.0, 1.0, 1.0]
    zero_values = [0.0, 0.0, 0.0]

    # With V=0 and lambda=1, GAE collapses to discounted Monte-Carlo returns.
    gae = compute_gae(rewards, zero_values, gamma=0.99, lam=1.0)
    returns = compute_returns(rewards, gamma=0.99)
    matches = all(abs(g - r) < 0.01 for g, r in zip(gae, returns))
    record("GAE with V=0, lam=1 equals returns",
           PASS if matches else FAIL,
           f"gae={[round(g,2) for g in gae]} returns={[round(r,2) for r in returns]}")

    # lambda=0 reduces to the one-step TD error r + gamma*V' - V = 1.0 here.
    gae_td = compute_gae(rewards, zero_values, gamma=0.99, lam=0.0)
    record("GAE with lam=0 is one-step TD",
           PASS if abs(gae_td[0] - 1.0) < 0.001 else FAIL,
           f"got {gae_td[0]:.4f} expected 1.0")

    # Non-zero value estimates should shift the advantages away from zero.
    gae_vals = compute_gae([1.0, 1.0, 1.0], [2.0, 1.0, 0.5], gamma=0.99, lam=1.0)
    nontrivial = len(gae_vals) == 3 and any(abs(g) > 0.1 for g in gae_vals)
    record("GAE with values produces non-trivial advantages",
           PASS if nontrivial else FAIL,
           f"gae={[round(g,3) for g in gae_vals]}")
|
|
|
|
| |
| |
| |
|
|
def test_reinforce():
    """Check the REINFORCE loss: -mean(log_prob * return)."""
    print("\n--- REINFORCE Loss ---")

    log_probs = [-1.0, -2.0, -3.0]
    returns = [10.0, 5.0, 1.0]
    loss = reinforce_loss(log_probs, returns)
    # Manual: -((-1)*10 + (-2)*5 + (-3)*1) / 3.
    expected = -((-1.0) * 10 + (-2.0) * 5 + (-3.0) * 1) / 3
    record("REINFORCE loss matches manual calculation",
           PASS if abs(loss - expected) < 0.001 else FAIL,
           f"got {loss:.4f} expected {expected:.4f}")

    # Doubling the returns should scale the loss magnitude up as well.
    # (Label fixed: mojibake "β" replaced with "->".)
    returns_high = [20.0, 10.0, 2.0]
    loss_high = reinforce_loss(log_probs, returns_high)
    record("Higher returns -> higher loss magnitude",
           PASS if abs(loss_high) > abs(loss) else FAIL,
           f"low={abs(loss):.4f} high={abs(loss_high):.4f}")

    # Empty episode must not crash and contributes zero loss.
    loss_empty = reinforce_loss([], [])
    record("REINFORCE handles empty episode",
           PASS if loss_empty == 0.0 else FAIL,
           f"got {loss_empty}")
|
|
|
|
| |
| |
| |
|
|
def test_reinforce_baseline():
    """Check REINFORCE with mean and custom baselines against manual math."""
    print("\n--- REINFORCE + Baseline Loss ---")
    log_probs = [-1.0, -2.0, -3.0]
    returns = [10.0, 5.0, 1.0]

    # Default baseline (baselines=None) is the batch-mean return.
    loss_b = reinforce_baseline_loss(log_probs, returns, baselines=None, normalize_advantage=False)
    mean_r = sum(returns) / len(returns)
    advantages = [g - mean_r for g in returns]
    expected = -sum(lp * adv for lp, adv in zip(log_probs, advantages)) / 3
    record("REINFORCE+baseline matches manual calc",
           PASS if abs(loss_b - expected) < 0.01 else FAIL,
           f"got {loss_b:.4f} expected {expected:.4f}")

    # Advantage normalization must still yield a finite, non-NaN loss.
    loss_norm = reinforce_baseline_loss(log_probs, returns, baselines=None, normalize_advantage=True)
    record("Normalized advantage produces valid loss",
           PASS if not math.isnan(loss_norm) and not math.isinf(loss_norm) else FAIL,
           f"loss={loss_norm:.4f}")

    # Informational comparison only: variance reduction is typical but not
    # guaranteed, so this is recorded as PASS unconditionally.  The previous
    # condition ended in "... or True", which made the FAIL branch unreachable
    # and dressed the comparison up as a real check.
    loss_vanilla = reinforce_loss(log_probs, returns)
    record("Baseline typically reduces loss magnitude",
           PASS,
           f"vanilla={abs(loss_vanilla):.4f} baseline={abs(loss_norm):.4f} (varies)")

    # Explicit per-step baselines must be subtracted element-wise.
    baselines = [9.0, 4.0, 0.5]
    loss_custom = reinforce_baseline_loss(log_probs, returns, baselines=baselines, normalize_advantage=False)
    advantages_custom = [g - b for g, b in zip(returns, baselines)]
    expected_custom = -sum(lp * adv for lp, adv in zip(log_probs, advantages_custom)) / 3
    record("Custom baselines work correctly",
           PASS if abs(loss_custom - expected_custom) < 0.01 else FAIL,
           f"got {loss_custom:.4f} expected {expected_custom:.4f}")
|
|
|
|
| |
| |
| |
|
|
def test_grpo():
    """Exercise the GRPO loss: group-standardized advantages."""
    print("\n--- GRPO Loss ---")

    lp_groups = [[-1.0, -2.0, -1.5]]
    rw_groups = [[10.0, 5.0, 8.0]]

    loss = grpo_loss(lp_groups, rw_groups)
    record("GRPO produces valid loss",
           PASS if not math.isnan(loss) and not math.isinf(loss) else FAIL,
           f"loss={loss:.4f}")

    # Replicate GRPO's per-group standardization by hand to check advantage signs.
    group = rw_groups[0]
    mean_r = sum(group) / 3
    std_r = math.sqrt(sum((r - mean_r) ** 2 for r in group) / 3)
    advantages = [(r - mean_r) / (std_r + 1e-8) for r in group]
    record("GRPO: highest reward gets positive advantage",
           PASS if advantages[0] > 0 else FAIL,
           f"adv={advantages[0]:.4f}")
    record("GRPO: lowest reward gets negative advantage",
           PASS if advantages[1] < 0 else FAIL,
           f"adv={advantages[1]:.4f}")

    # Multiple independent groups must be combined without NaNs.
    loss_2 = grpo_loss([[-1.0, -2.0], [-1.5, -1.5]], [[10.0, 5.0], [3.0, 7.0]])
    record("GRPO handles multiple groups",
           PASS if not math.isnan(loss_2) else FAIL,
           f"loss={loss_2:.4f}")

    loss_empty = grpo_loss([], [])
    record("GRPO handles empty input",
           PASS if loss_empty == 0.0 else FAIL,
           f"got {loss_empty}")

    # Zero within-group spread means zero advantages, hence (near-)zero loss.
    # (Label fixed: mojibake "β" replaced with "->".)
    loss_identical = grpo_loss([[-1.0, -2.0, -3.0]], [[5.0, 5.0, 5.0]])
    record("GRPO: identical rewards -> near-zero loss",
           PASS if abs(loss_identical) < 1e-4 else FAIL,
           f"loss={loss_identical:.6f}")
|
|
|
|
| |
| |
| |
|
|
def test_rloo():
    """Exercise the RLOO loss: leave-one-out baselines within each group."""
    print("\n--- RLOO Loss ---")

    lp_groups = [[-1.0, -2.0, -1.5]]
    rw_groups = [[10.0, 5.0, 8.0]]

    loss = rloo_loss(lp_groups, rw_groups)
    valid = not math.isnan(loss) and not math.isinf(loss)
    record("RLOO produces valid loss",
           PASS if valid else FAIL,
           f"loss={loss:.4f}")

    # Leave-one-out baseline for each sample = mean of the other two rewards.
    baselines = [6.5, 9.0, 7.5]
    advantages = [10 - 6.5, 5 - 9.0, 8 - 7.5]
    expected = -sum(lp * adv for lp, adv in zip(lp_groups[0], advantages)) / 3
    record("RLOO matches manual calculation",
           PASS if abs(loss - expected) < 0.01 else FAIL,
           f"got {loss:.4f} expected {expected:.4f}")

    # With one sample there is nothing to leave out: plain REINFORCE.
    loss_single = rloo_loss([[-1.0]], [[5.0]])
    expected_single = -(-1.0) * 5.0
    record("RLOO K=1 falls back to REINFORCE",
           PASS if abs(loss_single - expected_single) < 0.01 else FAIL,
           f"got {loss_single:.4f} expected {expected_single:.4f}")

    # Smallest genuine leave-one-out case.
    loss_k2 = rloo_loss([[-1.0, -2.0]], [[10.0, 5.0]])
    record("RLOO K=2 produces valid loss",
           PASS if not math.isnan(loss_k2) else FAIL,
           f"loss={loss_k2:.4f}")
|
|
|
|
| |
| |
| |
|
|
def test_reward_normalization():
    """Check the reward-statistics and normalization helpers."""
    print("\n--- Reward Normalization ---")
    raw = [-0.5, -1.0, -0.3, -2.0, -0.8]
    mean, var = compute_reward_stats(raw)
    # Hand-computed mean of raw: -4.6 / 5 = -0.92.
    record("Reward stats computed",
           PASS if abs(mean - (-0.92)) < 0.01 else FAIL,
           f"mean={mean:.4f} var={var:.4f}")

    normed = normalize_rewards(raw, mean, var)
    normed_mean = sum(normed) / len(normed)
    record("Normalized rewards have near-zero mean",
           PASS if abs(normed_mean) < 0.01 else FAIL,
           f"mean={normed_mean:.4f}")

    norm_var = sum((n - normed_mean) ** 2 for n in normed) / len(normed)
    record("Normalized rewards have near-unit variance",
           PASS if abs(norm_var - 1.0) < 0.01 else FAIL,
           f"var={norm_var:.4f}")

    # mean=0, var=1 should leave the rewards untouched.
    identity = normalize_rewards(raw, 0.0, 1.0)
    diffs = [abs(a - b) for a, b in zip(raw, identity)]
    record("Identity normalization (mean=0, var=1)",
           PASS if all(d < 0.01 for d in diffs) else FAIL,
           f"max_diff={max(diffs):.4f}")
|
|
|
|
| |
| |
| |
|
|
def test_loss_comparison():
    """Compare all 4 loss functions on the same episode data."""
    print("\n--- Loss Function Comparison ---")
    log_probs = [-2.0, -1.5, -3.0, -1.0, -2.5]
    returns = [0.8, 0.3, 0.1, 0.5, 0.2]

    l_reinforce = reinforce_loss(log_probs, returns)
    l_baseline = reinforce_baseline_loss(log_probs, returns, normalize_advantage=True)

    # Group-based losses with K=1: every sample is its own singleton group.
    lps_groups = [[lp] for lp in log_probs]
    rs_groups = [[r] for r in returns]
    l_grpo = grpo_loss(lps_groups, rs_groups)
    l_rloo = rloo_loss(lps_groups, rs_groups)

    losses = [l_reinforce, l_baseline, l_grpo, l_rloo]
    record("All 4 losses produce valid values",
           PASS if all(math.isfinite(v) for v in losses) else FAIL,
           f"R={l_reinforce:.4f} RB={l_baseline:.4f} GRPO={l_grpo:.4f} RLOO={l_rloo:.4f}")

    print(f" [i] REINFORCE: {l_reinforce:.6f}")
    print(f" [i] REINFORCE+baseline: {l_baseline:.6f}")
    print(f" [i] GRPO (K=1): {l_grpo:.6f}")
    print(f" [i] RLOO (K=1): {l_rloo:.6f}")

    # Same comparison with a real group of K=4 samples.
    lp_k4 = [[-1.0, -2.0, -1.5, -3.0]]
    rw_k4 = [[0.8, 0.2, 0.5, 0.1]]
    l_grpo_4 = grpo_loss(lp_k4, rw_k4)
    l_rloo_4 = rloo_loss(lp_k4, rw_k4)
    record("GRPO/RLOO with K=4 produce valid losses",
           PASS if not math.isnan(l_grpo_4) and not math.isnan(l_rloo_4) else FAIL,
           f"GRPO={l_grpo_4:.4f} RLOO={l_rloo_4:.4f}")
    print(f" [i] GRPO (K=4): {l_grpo_4:.6f}")
    print(f" [i] RLOO (K=4): {l_rloo_4:.6f}")
|
|
|
|
| |
| |
| |
|
|
def test_episode_collection():
    """Collect one episode with the mock policy and sanity-check its fields."""
    print("\n--- Episode Collection (MockPolicyModel) ---")
    config = TrainingConfig(n_nodes=5, max_steps=30)
    collector = EpisodeCollector(config)
    model = MockPolicyModel(n_nodes=5, seed=42)

    episode = collector.collect_episode(model, task_id="task-1", seed=42)
    steps = episode.steps

    record("Episode has correct number of steps",
           PASS if len(steps) == 30 else FAIL,
           f"steps={len(steps)}")

    record("All log probs are valid",
           PASS if all(not math.isnan(s.log_prob) for s in steps) else FAIL,
           f"min_lp={min(s.log_prob for s in steps):.4f}")

    record("Rewards are finite",
           PASS if all(math.isfinite(s.reward) for s in steps) else FAIL,
           f"min_r={min(s.reward for s in steps):.4f}")

    norm_vals = [s.reward_normalized for s in steps]
    record("Normalized rewards in [0,1]",
           PASS if all(0.0 <= v <= 1.0 for v in norm_vals) else FAIL,
           f"range=[{min(norm_vals):.4f}, {max(norm_vals):.4f}]")

    record("Total reward is computed",
           PASS if math.isfinite(episode.total_reward) else FAIL,
           f"total={episode.total_reward:.4f}")

    record("SLA violations tracked",
           PASS if isinstance(episode.sla_violations, int) else FAIL,
           f"violations={episode.sla_violations}")
|
|
|
|
| |
| |
| |
|
|
def test_training_steps():
    """Run one training step with each loss function."""
    print("\n--- Full Training Steps ---")
    model = MockPolicyModel(n_nodes=5, seed=42)

    for loss_name in (LOSS_REINFORCE, LOSS_REINFORCE_BASELINE, LOSS_GRPO, LOSS_RLOO):
        # Group-based losses need multiple samples per state; the others need one.
        group_based = loss_name in (LOSS_GRPO, LOSS_RLOO)
        config = TrainingConfig(
            n_nodes=5,
            max_steps=30,
            loss_fn=loss_name,
            n_samples_per_state=2 if group_based else 1,
        )
        trainer = SRETrainer(config)
        metrics = trainer.train_step(model, task_id="task-1", seed=42)

        record(f"{loss_name}: loss is valid",
               PASS if math.isfinite(metrics["loss"]) else FAIL,
               f"loss={metrics['loss']:.4f}")

        record(f"{loss_name}: avg_reward is valid",
               PASS if math.isfinite(metrics["avg_reward"]) else FAIL,
               f"avg_reward={metrics['avg_reward']:.4f}")

        record(f"{loss_name}: episode completed",
               PASS if metrics["episode_length"] > 0 else FAIL,
               f"length={metrics['episode_length']}")

        stats_ok = math.isfinite(metrics["reward_mean"]) and math.isfinite(metrics["reward_var"])
        record(f"{loss_name}: running stats stable",
               PASS if stats_ok else FAIL,
               f"mean={metrics['reward_mean']:.4f} var={metrics['reward_var']:.4f}")
|
|
|
|
| |
| |
| |
|
|
def test_multi_episode_stability():
    """Run multiple episodes and check running stats remain stable."""
    print("\n--- Multi-Episode Stability ---")
    config = TrainingConfig(
        n_nodes=5,
        max_steps=30,
        loss_fn=LOSS_REINFORCE_BASELINE,
        tasks=["task-1", "task-2", "task-3"],
    )
    trainer = SRETrainer(config)
    model = MockPolicyModel(n_nodes=5, seed=42)

    all_losses = []
    # Fix: capture the final metrics explicitly instead of relying on the
    # for-loop variable leaking out of the loop (which would raise NameError
    # if the task list were ever empty).
    last_metrics = None
    for i in range(5):
        for task in config.tasks:
            last_metrics = trainer.train_step(model, task_id=task, seed=42 + i)
            all_losses.append(last_metrics["loss"])

    record("15 episodes: all losses finite",
           PASS if all(math.isfinite(loss) for loss in all_losses) else FAIL,
           f"n_losses={len(all_losses)}")

    # Losses should not collapse to a constant across differing seeds/tasks.
    unique_losses = len(set(round(loss, 4) for loss in all_losses))
    record("Losses vary across episodes",
           PASS if unique_losses > 3 else FAIL,
           f"unique={unique_losses}/{len(all_losses)}")

    # After 15 episodes the running reward mean should have moved off zero.
    if last_metrics is None:
        record("Running reward mean is non-zero", FAIL, "no episodes ran")
    else:
        record("Running reward mean is non-zero",
               PASS if abs(last_metrics["reward_mean"]) > 0.001 else FAIL,
               f"mean={last_metrics['reward_mean']:.6f}")
|
|
|
|
| |
| |
| |
|
|
def test_sre_edge_cases():
    """Test edge cases specific to the SRE domain."""
    print("\n--- SRE Edge Cases ---")

    log_probs = [-2.0] * 10

    # Catastrophic episode: every step heavily penalized.
    crash_returns = [-100.0] * 10
    loss_crash = reinforce_baseline_loss(log_probs, crash_returns, normalize_advantage=True)
    record("Very negative rewards: loss is finite",
           PASS if math.isfinite(loss_crash) else FAIL,
           f"loss={loss_crash:.4f}")

    # Perfect episode: zero return everywhere means zero advantage.
    perfect_returns = [0.0] * 10
    loss_perfect = reinforce_baseline_loss(log_probs, perfect_returns, normalize_advantage=False)
    record("Zero returns: loss is zero (no gradient)",
           PASS if abs(loss_perfect) < 0.001 else FAIL,
           f"loss={loss_perfect:.4f}")

    # Intermittent incidents: reward swings between nominal and deep penalty.
    surge_returns = [0.5, 0.5, -10.0, -10.0, 0.5, 0.5, -10.0, 0.5, 0.5, 0.5]
    loss_surge = reinforce_baseline_loss(log_probs, surge_returns, normalize_advantage=True)
    record("High-variance rewards: loss is finite with normalization",
           PASS if math.isfinite(loss_surge) else FAIL,
           f"loss={loss_surge:.4f}")

    # One outlier sample inside an otherwise tight group.
    group_lps = [[-1.0, -2.0, -1.5, -3.0]]
    group_rewards = [[0.9, 0.8, 0.85, 0.05]]
    loss_outlier = grpo_loss(group_lps, group_rewards)
    record("GRPO handles outlier in group",
           PASS if math.isfinite(loss_outlier) else FAIL,
           f"loss={loss_outlier:.4f}")
|
|
|
|
| |
|
|
def main():
    """Run every validation suite and print a pass/fail summary.

    Returns 0 when all recorded checks passed, 1 otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("AntiAtropos Training Validation")
    print(banner)

    # Suites run in dependency-free order; each appends to the global results.
    for suite in (
        test_returns,
        test_gae,
        test_reinforce,
        test_reinforce_baseline,
        test_grpo,
        test_rloo,
        test_reward_normalization,
        test_loss_comparison,
        test_episode_collection,
        test_training_steps,
        test_multi_episode_stability,
        test_sre_edge_cases,
    ):
        suite()

    passed = sum(1 for _, status, _ in results if status == PASS)
    failed = sum(1 for _, status, _ in results if status == FAIL)
    total = len(results)

    print("\n" + banner)
    print(f"RESULTS: {passed}/{total} passed, {failed} failed")
    print(banner)

    if failed > 0:
        print("\nFailed tests:")
        for name, status, detail in results:
            if status == FAIL:
                print(f" X {name}: {detail}")

    # Non-zero exit status signals failure to CI / shell callers.
    return 0 if failed == 0 else 1
|
|
|
|
# Script entry point: propagate main()'s result (0 on success, 1 on any
# failed check) as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
|
|