Spaces:
Running
Running
| #!/usr/bin/env python3 | |
"""
AntiAtropos Training Validation — Local Test.

Validates the training pipeline (loss functions, episode collection, trainer)
using a MockPolicyModel (no GPU needed). Run before going to Colab.

Run from project root:
    python train_test.py
"""
| import sys | |
| import os | |
| import math | |
| import random | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from training.losses import ( | |
| compute_returns, compute_gae, | |
| reinforce_loss, reinforce_baseline_loss, | |
| grpo_loss, rloo_loss, | |
| normalize_rewards, compute_reward_stats, | |
| ) | |
| from training.trainer import ( | |
| SRETrainer, TrainingConfig, EpisodeCollector, | |
| MockPolicyModel, | |
| LOSS_REINFORCE, LOSS_REINFORCE_BASELINE, LOSS_GRPO, LOSS_RLOO, | |
| ) | |
PASS = "PASS"
FAIL = "FAIL"

# Every check outcome as (name, status, detail), in the order it was recorded.
results: list[tuple[str, str, str]] = []


def record(name: str, status: str, detail: str = "") -> None:
    """Store one check outcome in ``results`` and echo it to stdout.

    Args:
        name: Label of the check.
        status: ``PASS`` or ``FAIL``; any value other than ``PASS`` is
            printed with an ``X`` marker.
        detail: Optional text appended to the printed line after `` -- ``.
    """
    results.append((name, status, detail))
    marker = "X" if status != PASS else "+"
    suffix = f" -- {detail}" if detail else ""
    print(f" [{marker}] {name}" + suffix)
# ──────────────────────────────────────────────────────────────────────────────
# 1. Return Computation
# ──────────────────────────────────────────────────────────────────────────────
def test_returns():
    """Exercise ``compute_returns`` on a hand-worked case plus edge cases.

    Records five checks: the first and last discounted return of a
    three-step episode, the empty episode, a single-step episode, and
    ``gamma=0``.
    """
    print("\n--- Return Computation ---")

    # Three unit rewards with gamma=0.99, worked backwards:
    #   G_2 = 1.0
    #   G_1 = 1 + 0.99 * 1.0  = 1.99
    #   G_0 = 1 + 0.99 * 1.99 = 2.9701
    discounted = compute_returns([1.0, 1.0, 1.0], gamma=0.99)
    head_ok = abs(discounted[0] - 2.9701) < 0.001
    record(
        "compute_returns[0]",
        PASS if head_ok else FAIL,
        f"got {discounted[0]:.4f} expected 2.9701",
    )
    tail_ok = abs(discounted[2] - 1.0) < 0.001
    record(
        "compute_returns[2]",
        PASS if tail_ok else FAIL,
        f"got {discounted[2]:.4f} expected 1.0",
    )

    # An episode with no rewards is expected to yield an empty list.
    no_rewards = compute_returns([])
    empty_status = PASS if no_rewards == [] else FAIL
    record("compute_returns handles empty", empty_status, f"got {no_rewards}")

    # A one-step episode: the return is the reward itself.
    one_step = compute_returns([5.0])
    single_status = PASS if abs(one_step[0] - 5.0) < 0.001 else FAIL
    record("compute_returns single reward", single_status, f"got {one_step[0]:.4f}")

    # gamma = 0 -- only the immediate reward counts, so returns == rewards.
    undiscounted = compute_returns([1.0, 2.0, 3.0], gamma=0.0)
    zero_gamma_status = PASS if undiscounted == [1.0, 2.0, 3.0] else FAIL
    record("gamma=0: returns = rewards", zero_gamma_status, f"got {undiscounted}")
# ──────────────────────────────────────────────────────────────────────────────
# 2. GAE Computation
# ──────────────────────────────────────────────────────────────────────────────
def test_gae():
    """Validate ``compute_gae`` against two limiting cases and a smoke case.

    Checks recorded:
      * all values zero and ``lam=1``: advantages equal the discounted returns;
      * ``lam=0``: the first advantage equals the one-step TD error;
      * non-zero values: three advantages, at least one with magnitude > 0.1.
    """
    print("\n--- GAE Computation ---")

    # With V=0 everywhere, GAE(lam=1) reduces to the discounted returns.
    rewards = [1.0, 1.0, 1.0]
    values = [0.0, 0.0, 0.0]
    gae = compute_gae(rewards, values, gamma=0.99, lam=1.0)
    returns = compute_returns(rewards, gamma=0.99)
    # zip() stops at the shorter input, so compare lengths first; otherwise a
    # truncated or empty result would pass the element-wise check vacuously.
    same_length = len(gae) == len(returns) == len(rewards)
    matches_returns = same_length and all(
        abs(g - r) < 0.01 for g, r in zip(gae, returns)
    )
    record("GAE with V=0, lam=1 equals returns",
           PASS if matches_returns else FAIL,
           f"gae={[round(g,2) for g in gae]} returns={[round(r,2) for r in returns]}")

    # With lam=0, GAE reduces to the one-step TD error:
    #   delta_0 = r_0 + gamma * V(s_1) - V(s_0) = 1.0 + 0.99 * 0 - 0 = 1.0
    gae_td = compute_gae(rewards, values, gamma=0.99, lam=0.0)
    record("GAE with lam=0 is one-step TD",
           PASS if abs(gae_td[0] - 1.0) < 0.001 else FAIL,
           f"got {gae_td[0]:.4f} expected 1.0")

    # With non-zero value estimates the advantages should be non-trivial.
    values2 = [2.0, 1.0, 0.5]
    gae2 = compute_gae([1.0, 1.0, 1.0], values2, gamma=0.99, lam=1.0)
    record("GAE with values produces non-trivial advantages",
           PASS if len(gae2) == 3 and any(abs(g) > 0.1 for g in gae2) else FAIL,
           f"gae={[round(g,3) for g in gae2]}")
# ──────────────────────────────────────────────────────────────────────────────
# 3. REINFORCE Loss
# ──────────────────────────────────────────────────────────────────────────────
def test_reinforce():
    """Validate ``reinforce_loss`` against a hand calculation and edge cases.

    Checks recorded:
      * the loss equals ``-(1/N) * sum(log_prob_t * G_t)`` for known inputs;
      * doubling the returns increases the loss magnitude;
      * an empty episode yields exactly ``0.0``.
    """
    print("\n--- REINFORCE Loss ---")

    # Known values: log_probs=[-1, -2, -3], returns=[10, 5, 1]
    # loss = -(1/3) * ((-1)*10 + (-2)*5 + (-3)*1) = -(1/3) * (-23) = 7.667
    log_probs = [-1.0, -2.0, -3.0]
    returns = [10.0, 5.0, 1.0]
    loss = reinforce_loss(log_probs, returns)
    expected = -sum(lp * g for lp, g in zip(log_probs, returns)) / len(log_probs)
    record("REINFORCE loss matches manual calculation",
           PASS if abs(loss - expected) < 0.001 else FAIL,
           f"got {loss:.4f} expected {expected:.4f}")

    # Higher returns should produce a larger loss magnitude (more gradient push).
    returns_high = [20.0, 10.0, 2.0]
    loss_high = reinforce_loss(log_probs, returns_high)
    # The label uses an ASCII arrow: the non-ASCII character that was here is
    # garbled and can also fail to print on consoles with a narrow encoding.
    record("Higher returns -> higher loss magnitude",
           PASS if abs(loss_high) > abs(loss) else FAIL,
           f"low={abs(loss):.4f} high={abs(loss_high):.4f}")

    # Empty episode.
    loss_empty = reinforce_loss([], [])
    record("REINFORCE handles empty episode",
           PASS if loss_empty == 0.0 else FAIL,
           f"got {loss_empty}")
# ──────────────────────────────────────────────────────────────────────────────
# 4. REINFORCE + Baseline Loss
# ──────────────────────────────────────────────────────────────────────────────
def test_reinforce_baseline():
    """Validate ``reinforce_baseline_loss`` against hand calculations.

    Checks recorded:
      * ``baselines=None`` without normalization matches a manual calculation
        that uses ``mean(returns)`` as the baseline;
      * with ``normalize_advantage=True`` the loss is finite;
      * the normalized-baseline loss has a smaller magnitude than vanilla
        REINFORCE on the same data;
      * explicit per-step baselines match a manual calculation.
    """
    print("\n--- REINFORCE + Baseline Loss ---")
    log_probs = [-1.0, -2.0, -3.0]
    returns = [10.0, 5.0, 1.0]
    n_steps = len(log_probs)

    # With baselines=None the expected baseline is mean(returns) = 5.33:
    #   advantages = [10-5.33, 5-5.33, 1-5.33] = [4.67, -0.33, -4.33]
    #   loss = -(1/3) * ((-1)*4.67 + (-2)*(-0.33) + (-3)*(-4.33))
    #        = -(1/3) * (-4.67 + 0.67 + 13.0)
    #        = -(1/3) * 9.0 = -3.0
    loss_b = reinforce_baseline_loss(log_probs, returns, baselines=None, normalize_advantage=False)
    mean_r = sum(returns) / len(returns)
    advantages = [g - mean_r for g in returns]
    expected = -sum(lp * adv for lp, adv in zip(log_probs, advantages)) / n_steps
    record("REINFORCE+baseline matches manual calc",
           PASS if abs(loss_b - expected) < 0.01 else FAIL,
           f"got {loss_b:.4f} expected {expected:.4f}")

    # With normalize_advantage=True the advantages are standardized.
    loss_norm = reinforce_baseline_loss(log_probs, returns, baselines=None, normalize_advantage=True)
    record("Normalized advantage produces valid loss",
           PASS if math.isfinite(loss_norm) else FAIL,
           f"loss={loss_norm:.4f}")

    # The baseline should reduce the loss magnitude versus vanilla REINFORCE.
    # This used to be `... or True`, which could never report FAIL; the inputs
    # are fixed, so the comparison is deterministic and is now checked for real.
    loss_vanilla = reinforce_loss(log_probs, returns)
    record("Baseline typically reduces loss magnitude",
           PASS if abs(loss_norm) < abs(loss_vanilla) else FAIL,
           f"vanilla={abs(loss_vanilla):.4f} baseline={abs(loss_norm):.4f}")

    # Custom per-step baselines.
    baselines = [9.0, 4.0, 0.5]
    loss_custom = reinforce_baseline_loss(log_probs, returns, baselines=baselines, normalize_advantage=False)
    advantages_custom = [g - b for g, b in zip(returns, baselines)]
    expected_custom = -sum(lp * adv for lp, adv in zip(log_probs, advantages_custom)) / n_steps
    record("Custom baselines work correctly",
           PASS if abs(loss_custom - expected_custom) < 0.01 else FAIL,
           f"got {loss_custom:.4f} expected {expected_custom:.4f}")
# ──────────────────────────────────────────────────────────────────────────────
# 5. GRPO Loss
# ──────────────────────────────────────────────────────────────────────────────
def test_grpo():
    """Validate ``grpo_loss`` on single, multiple, empty and degenerate groups.

    Checks recorded:
      * a three-sample group yields a finite loss;
      * reference group-normalized advantages have the expected sign for the
        best and worst sample (computed locally, not read back from
        ``grpo_loss``);
      * two groups yield a finite loss;
      * empty input yields exactly ``0.0``;
      * identical rewards in a group yield a near-zero loss.
    """
    print("\n--- GRPO Loss ---")

    # One state with a group of three samples.
    log_probs_groups = [[-1.0, -2.0, -1.5]]
    rewards_groups = [[10.0, 5.0, 8.0]]
    loss = grpo_loss(log_probs_groups, rewards_groups)
    record("GRPO produces valid loss",
           PASS if math.isfinite(loss) else FAIL,
           f"loss={loss:.4f}")

    # Reference advantages: (r - mean) / (std + eps) within the group.
    # The highest-reward sample should get a positive advantage and the
    # lowest-reward sample a negative one.
    group = rewards_groups[0]
    mean_r = sum(group) / len(group)  # 7.67
    std_r = math.sqrt(sum((r - mean_r) ** 2 for r in group) / len(group))
    advantages = [(r - mean_r) / (std_r + 1e-8) for r in group]
    best = group.index(max(group))
    worst = group.index(min(group))
    record("GRPO: highest reward gets positive advantage",
           PASS if advantages[best] > 0 else FAIL,
           f"adv={advantages[best]:.4f}")
    record("GRPO: lowest reward gets negative advantage",
           PASS if advantages[worst] < 0 else FAIL,
           f"adv={advantages[worst]:.4f}")

    # Multiple groups. isfinite (not just "not NaN") so that an infinite loss
    # is reported as a failure, matching the single-group check above.
    log_probs_2 = [[-1.0, -2.0], [-1.5, -1.5]]
    rewards_2 = [[10.0, 5.0], [3.0, 7.0]]
    loss_2 = grpo_loss(log_probs_2, rewards_2)
    record("GRPO handles multiple groups",
           PASS if math.isfinite(loss_2) else FAIL,
           f"loss={loss_2:.4f}")

    # Empty groups.
    loss_empty = grpo_loss([], [])
    record("GRPO handles empty input",
           PASS if loss_empty == 0.0 else FAIL,
           f"got {loss_empty}")

    # Identical rewards -> zero advantage -> zero loss.
    # The label uses an ASCII arrow in place of the garbled character.
    loss_identical = grpo_loss([[-1.0, -2.0, -3.0]], [[5.0, 5.0, 5.0]])
    record("GRPO: identical rewards -> near-zero loss",
           PASS if abs(loss_identical) < 1e-4 else FAIL,
           f"loss={loss_identical:.6f}")
# ──────────────────────────────────────────────────────────────────────────────
# 6. RLOO Loss
# ──────────────────────────────────────────────────────────────────────────────
def test_rloo():
    """Validate ``rloo_loss`` (REINFORCE leave-one-out) on K=3, K=1 and K=2.

    Checks recorded:
      * a three-sample group yields a finite loss;
      * that loss matches a manual leave-one-out calculation;
      * a single-sample group equals plain REINFORCE on that sample;
      * a two-sample group yields a finite loss.
    """
    print("\n--- RLOO Loss ---")

    # Group of three samples.
    log_probs_groups = [[-1.0, -2.0, -1.5]]
    rewards_groups = [[10.0, 5.0, 8.0]]
    loss = rloo_loss(log_probs_groups, rewards_groups)
    record("RLOO produces valid loss",
           PASS if math.isfinite(loss) else FAIL,
           f"loss={loss:.4f}")

    # Leave-one-out baselines (mean reward of the *other* samples):
    #   r=10: baseline = (5+8)/2  = 6.5, advantage = 10-6.5 =  3.5
    #   r=5:  baseline = (10+8)/2 = 9.0, advantage = 5-9.0  = -4.0
    #   r=8:  baseline = (10+5)/2 = 7.5, advantage = 8-7.5  =  0.5
    group_lps = log_probs_groups[0]
    group_rewards = rewards_groups[0]
    baselines = [6.5, 9.0, 7.5]
    # Derive the advantages from the baselines instead of repeating the
    # arithmetic by hand, so the two cannot drift apart.
    advantages = [r - b for r, b in zip(group_rewards, baselines)]
    expected = -sum(lp * adv for lp, adv in zip(group_lps, advantages)) / len(group_lps)
    record("RLOO matches manual calculation",
           PASS if abs(loss - expected) < 0.01 else FAIL,
           f"got {loss:.4f} expected {expected:.4f}")

    # Single sample: no other samples to form a baseline, so the expected
    # result is REINFORCE on that one sample.
    loss_single = rloo_loss([[-1.0]], [[5.0]])
    expected_single = -(-1.0) * 5.0
    record("RLOO K=1 falls back to REINFORCE",
           PASS if abs(loss_single - expected_single) < 0.01 else FAIL,
           f"got {loss_single:.4f} expected {expected_single:.4f}")

    # K=2: the simplest meaningful RLOO.
    #   baseline for r=10: 5.0,  adv =  5.0
    #   baseline for r=5:  10.0, adv = -5.0
    #   loss = -(1/2) * ((-1)*5 + (-2)*(-5)) = -(1/2) * (-5 + 10) = -2.5
    loss_k2 = rloo_loss([[-1.0, -2.0]], [[10.0, 5.0]])
    # isfinite rather than "not NaN": an infinite loss is not a valid loss.
    record("RLOO K=2 produces valid loss",
           PASS if math.isfinite(loss_k2) else FAIL,
           f"loss={loss_k2:.4f}")
# ──────────────────────────────────────────────────────────────────────────────
# 7. Reward Normalization
# ──────────────────────────────────────────────────────────────────────────────
def test_reward_normalization():
    """Validate ``compute_reward_stats`` and ``normalize_rewards``.

    Checks recorded:
      * the mean from ``compute_reward_stats`` matches the hand value (-0.92);
      * rewards normalized with those stats have near-zero mean;
      * ... and near-unit (population) variance;
      * normalizing with ``mean=0, var=1`` leaves the rewards unchanged.

    An empty or wrongly sized result is reported as FAIL rather than raising,
    so one broken helper does not abort the remaining test sections.
    """
    print("\n--- Reward Normalization ---")
    nan = float("nan")

    raw = [-0.5, -1.0, -0.3, -2.0, -0.8]
    mean, var = compute_reward_stats(raw)
    record("Reward stats computed",
           PASS if abs(mean - (-0.92)) < 0.01 else FAIL,
           f"mean={mean:.4f} var={var:.4f}")

    normed = normalize_rewards(raw, mean, var)
    count = len(normed)
    # Compute the mean once; it was previously recomputed for every element of
    # the variance sum. NaN makes the comparisons below evaluate to FAIL when
    # the result is empty.
    norm_mean = sum(normed) / count if count else nan
    record("Normalized rewards have near-zero mean",
           PASS if abs(norm_mean) < 0.01 else FAIL,
           f"mean={norm_mean:.4f}")

    norm_var = sum((n - norm_mean) ** 2 for n in normed) / count if count else nan
    record("Normalized rewards have near-unit variance",
           PASS if abs(norm_var - 1.0) < 0.01 else FAIL,
           f"var={norm_var:.4f}")

    # Identity: normalizing with mean=0, var=1 should leave rewards unchanged.
    identity = normalize_rewards(raw, 0.0, 1.0)
    diffs = [abs(a - b) for a, b in zip(raw, identity)]
    # zip() truncates to the shorter input, so a short result must fail on
    # length before the element-wise comparison is trusted.
    unchanged = len(identity) == len(raw) and all(d < 0.01 for d in diffs)
    record("Identity normalization (mean=0, var=1)",
           PASS if unchanged else FAIL,
           f"max_diff={max(diffs, default=nan):.4f}")
# ──────────────────────────────────────────────────────────────────────────────
# 8. Loss Function Comparison
# ──────────────────────────────────────────────────────────────────────────────
def test_loss_comparison():
    """Compare all 4 loss functions on the same episode data.

    Records one check that REINFORCE, REINFORCE+baseline, GRPO and RLOO all
    return finite values on a five-step episode (GRPO/RLOO fed groups of size
    one), and one check that GRPO and RLOO return finite values on a single
    group of four samples. The individual loss values are printed for
    inspection.
    """
    print("\n--- Loss Function Comparison ---")
    log_probs = [-2.0, -1.5, -3.0, -1.0, -2.5]
    returns = [0.8, 0.3, 0.1, 0.5, 0.2]

    l_reinforce = reinforce_loss(log_probs, returns)
    l_baseline = reinforce_baseline_loss(log_probs, returns, normalize_advantage=True)

    # GRPO/RLOO: treat each step as its own "group" of size 1.
    # (Not how they are normally used, but it exercises the pipeline.)
    lps_groups = [[lp] for lp in log_probs]
    rs_groups = [[r] for r in returns]
    l_grpo = grpo_loss(lps_groups, rs_groups)
    l_rloo = rloo_loss(lps_groups, rs_groups)

    record("All 4 losses produce valid values",
           PASS if all(math.isfinite(value)
                       for value in (l_reinforce, l_baseline, l_grpo, l_rloo)) else FAIL,
           f"R={l_reinforce:.4f} RB={l_baseline:.4f} GRPO={l_grpo:.4f} RLOO={l_rloo:.4f}")
    print(f" [i] REINFORCE: {l_reinforce:.6f}")
    print(f" [i] REINFORCE+baseline: {l_baseline:.6f}")
    print(f" [i] GRPO (K=1): {l_grpo:.6f}")
    print(f" [i] RLOO (K=1): {l_rloo:.6f}")

    # Now with a proper K=4 group.
    log_probs_4 = [[-1.0, -2.0, -1.5, -3.0]]
    rewards_4 = [[0.8, 0.2, 0.5, 0.1]]
    l_grpo_4 = grpo_loss(log_probs_4, rewards_4)
    l_rloo_4 = rloo_loss(log_probs_4, rewards_4)
    # isfinite rather than "not NaN", so an infinite loss is not reported as
    # valid; this matches the K=1 check above.
    record("GRPO/RLOO with K=4 produce valid losses",
           PASS if math.isfinite(l_grpo_4) and math.isfinite(l_rloo_4) else FAIL,
           f"GRPO={l_grpo_4:.4f} RLOO={l_rloo_4:.4f}")
    print(f" [i] GRPO (K=4): {l_grpo_4:.6f}")
    print(f" [i] RLOO (K=4): {l_rloo_4:.6f}")
# ──────────────────────────────────────────────────────────────────────────────
# 9. Episode Collection (with MockPolicyModel)
# ──────────────────────────────────────────────────────────────────────────────
def test_episode_collection():
    """Collect one episode with ``MockPolicyModel`` and sanity-check its fields.

    Records six checks on the collected episode: step count equals the
    configured ``max_steps`` (30), no NaN log-probs, finite rewards,
    normalized rewards within [0, 1], a finite total reward, and an integer
    SLA-violation counter.
    """
    print("\n--- Episode Collection (MockPolicyModel) ---")

    config = TrainingConfig(n_nodes=5, max_steps=30)
    collector = EpisodeCollector(config)
    model = MockPolicyModel(n_nodes=5, seed=42)
    episode = collector.collect_episode(model, task_id="task-1", seed=42)

    n_steps = len(episode.steps)
    record(
        "Episode has correct number of steps",
        PASS if n_steps == 30 else FAIL,
        f"steps={n_steps}",
    )

    log_probs = [step.log_prob for step in episode.steps]
    has_nan = any(math.isnan(lp) for lp in log_probs)
    record(
        "All log probs are valid",
        FAIL if has_nan else PASS,
        f"min_lp={min(log_probs):.4f}",
    )

    rewards = [step.reward for step in episode.steps]
    rewards_finite = all(math.isfinite(r) for r in rewards)
    record(
        "Rewards are finite",
        PASS if rewards_finite else FAIL,
        f"min_r={min(rewards):.4f}",
    )

    scaled = [step.reward_normalized for step in episode.steps]
    in_unit_interval = all(0.0 <= value <= 1.0 for value in scaled)
    record(
        "Normalized rewards in [0,1]",
        PASS if in_unit_interval else FAIL,
        f"range=[{min(scaled):.4f}, "
        f"{max(scaled):.4f}]",
    )

    total_status = PASS if math.isfinite(episode.total_reward) else FAIL
    record("Total reward is computed", total_status, f"total={episode.total_reward:.4f}")

    sla_status = PASS if isinstance(episode.sla_violations, int) else FAIL
    record("SLA violations tracked", sla_status, f"violations={episode.sla_violations}")
# ──────────────────────────────────────────────────────────────────────────────
# 10. Full Training Step (per loss function)
# ──────────────────────────────────────────────────────────────────────────────
def test_training_steps():
    """Run a single ``SRETrainer.train_step`` for each of the four loss names.

    One ``MockPolicyModel`` is shared across the runs; a fresh
    ``TrainingConfig`` and ``SRETrainer`` are built per loss. GRPO and RLOO
    are configured with two samples per state, the others with one. For every
    loss, four checks are recorded on the returned metrics: finite loss,
    finite average reward, positive episode length, and finite running
    reward mean/variance.
    """
    print("\n--- Full Training Steps ---")
    model = MockPolicyModel(n_nodes=5, seed=42)
    group_based = (LOSS_GRPO, LOSS_RLOO)

    for loss_name in (LOSS_REINFORCE, LOSS_REINFORCE_BASELINE, LOSS_GRPO, LOSS_RLOO):
        samples = 2 if loss_name in group_based else 1
        config = TrainingConfig(
            n_nodes=5,
            max_steps=30,
            loss_fn=loss_name,
            n_samples_per_state=samples,
        )
        trainer = SRETrainer(config)
        metrics = trainer.train_step(model, task_id="task-1", seed=42)

        loss_status = PASS if math.isfinite(metrics["loss"]) else FAIL
        record(f"{loss_name}: loss is valid", loss_status, f"loss={metrics['loss']:.4f}")

        reward_status = PASS if math.isfinite(metrics["avg_reward"]) else FAIL
        record(
            f"{loss_name}: avg_reward is valid",
            reward_status,
            f"avg_reward={metrics['avg_reward']:.4f}",
        )

        length_status = PASS if metrics["episode_length"] > 0 else FAIL
        record(
            f"{loss_name}: episode completed",
            length_status,
            f"length={metrics['episode_length']}",
        )

        # The running reward statistics must not contain NaN/inf.
        stats_finite = math.isfinite(metrics["reward_mean"]) and math.isfinite(metrics["reward_var"])
        record(
            f"{loss_name}: running stats stable",
            PASS if stats_finite else FAIL,
            f"mean={metrics['reward_mean']:.4f} var={metrics['reward_var']:.4f}",
        )
# ──────────────────────────────────────────────────────────────────────────────
# 11. Multi-Episode Stability
# ──────────────────────────────────────────────────────────────────────────────
def test_multi_episode_stability():
    """Train over several episodes and check the metrics stay well-behaved.

    Runs 5 rounds over the 3 configured tasks (15 ``train_step`` calls) with
    one trainer and one mock model, varying the seed per round. Records that
    every loss is finite, that more than three distinct loss values (rounded
    to 4 decimals) were seen, and that the running reward mean reported by
    the final step is not (near) zero.
    """
    print("\n--- Multi-Episode Stability ---")

    config = TrainingConfig(
        n_nodes=5,
        max_steps=30,
        loss_fn=LOSS_REINFORCE_BASELINE,
        tasks=["task-1", "task-2", "task-3"],
    )
    trainer = SRETrainer(config)
    model = MockPolicyModel(n_nodes=5, seed=42)

    all_losses = []
    for round_idx in range(5):
        for task in config.tasks:
            metrics = trainer.train_step(model, task_id=task, seed=42 + round_idx)
            all_losses.append(metrics["loss"])

    # No NaN/inf across the 15 episodes.
    every_loss_finite = all(math.isfinite(value) for value in all_losses)
    record(
        "15 episodes: all losses finite",
        PASS if every_loss_finite else FAIL,
        f"n_losses={len(all_losses)}",
    )

    # Losses are expected to differ between episodes (different tasks/seeds).
    unique_losses = len({round(value, 4) for value in all_losses})
    record(
        "Losses vary across episodes",
        PASS if unique_losses > 3 else FAIL,
        f"unique={unique_losses}/{len(all_losses)}",
    )

    # Running stats should be non-degenerate; `metrics` still refers to the
    # result of the final train_step call of the loop above.
    last_metrics = metrics
    mean_nonzero = abs(last_metrics["reward_mean"]) > 0.001
    record(
        "Running reward mean is non-zero",
        PASS if mean_nonzero else FAIL,
        f"mean={last_metrics['reward_mean']:.6f}",
    )
# ──────────────────────────────────────────────────────────────────────────────
# 12. SRE-Specific Edge Cases
# ──────────────────────────────────────────────────────────────────────────────
def test_sre_edge_cases():
    """Probe the loss functions with reward patterns typical of SRE episodes.

    Records four checks: constant very negative returns, all-zero returns,
    returns that swing between 0.5 and -10, and a GRPO group containing one
    outlier reward. All but the zero-returns case only require a finite loss;
    the zero-returns case requires a loss of (near) zero.
    """
    print("\n--- SRE Edge Cases ---")
    log_probs = [-2.0] * 10

    def _finite(value):
        # Map "is this loss finite?" onto the PASS/FAIL status strings.
        return PASS if math.isfinite(value) else FAIL

    # Very negative rewards (system crashing).
    returns_crash = [-100.0] * 10
    loss_crash = reinforce_baseline_loss(log_probs, returns_crash, normalize_advantage=True)
    record("Very negative rewards: loss is finite", _finite(loss_crash), f"loss={loss_crash:.4f}")

    # All-zero returns (perfect episode): nothing to learn from.
    returns_perfect = [0.0] * 10
    loss_perfect = reinforce_baseline_loss(log_probs, returns_perfect, normalize_advantage=False)
    record(
        "Zero returns: loss is zero (no gradient)",
        PASS if abs(loss_perfect) < 0.001 else FAIL,
        f"loss={loss_perfect:.4f}",
    )

    # Highly variable rewards within one episode (surge task).
    returns_surge = [0.5, 0.5, -10.0, -10.0, 0.5, 0.5, -10.0, 0.5, 0.5, 0.5]
    loss_surge = reinforce_baseline_loss(log_probs, returns_surge, normalize_advantage=True)
    record(
        "High-variance rewards: loss is finite with normalization",
        _finite(loss_surge),
        f"loss={loss_surge:.4f}",
    )

    # GRPO with one sample far below the rest of its group.
    lps = [[-1.0, -2.0, -1.5, -3.0]]
    rs_extreme = [[0.9, 0.8, 0.85, 0.05]]
    loss_extreme = grpo_loss(lps, rs_extreme)
    record("GRPO handles outlier in group", _finite(loss_extreme), f"loss={loss_extreme:.4f}")
# ──────────────────────────────────────────────────────────────────────────────
def main():
    """Run every test section in order, print a summary, and return an exit code.

    Returns:
        ``0`` when no recorded check has status ``FAIL``, otherwise ``1``.
    """
    rule = "=" * 60
    print(rule)
    print("AntiAtropos Training Validation")
    print(rule)

    sections = (
        test_returns,
        test_gae,
        test_reinforce,
        test_reinforce_baseline,
        test_grpo,
        test_rloo,
        test_reward_normalization,
        test_loss_comparison,
        test_episode_collection,
        test_training_steps,
        test_multi_episode_stability,
        test_sre_edge_cases,
    )
    for run_section in sections:
        run_section()

    statuses = [status for _, status, _ in results]
    passed = statuses.count(PASS)
    failed = statuses.count(FAIL)
    total = len(results)

    print("\n" + rule)
    print(f"RESULTS: {passed}/{total} passed, {failed} failed")
    print(rule)

    if failed > 0:
        print("\nFailed tests:")
        for name, status, detail in results:
            if status != FAIL:
                continue
            print(f" X {name}: {detail}")

    return 1 if failed != 0 else 0


if __name__ == "__main__":
    sys.exit(main())