#!/usr/bin/env python3
"""
AntiAtropos Training Validation - Local Test.
Validates the training pipeline (loss functions, episode collection, trainer)
using a MockPolicyModel (no GPU needed). Run before going to Colab.
Run from project root:
python train_test.py
"""
import sys
import os
import math
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from training.losses import (
compute_returns, compute_gae,
reinforce_loss, reinforce_baseline_loss,
grpo_loss, rloo_loss,
normalize_rewards, compute_reward_stats,
)
from training.trainer import (
SRETrainer, TrainingConfig, EpisodeCollector,
MockPolicyModel,
LOSS_REINFORCE, LOSS_REINFORCE_BASELINE, LOSS_GRPO, LOSS_RLOO,
)
PASS = "PASS"
FAIL = "FAIL"
results: list[tuple[str, str, str]] = []
def record(name: str, status: str, detail: str = "") -> None:
results.append((name, status, detail))
icon = "+" if status == PASS else "X"
msg = f" [{icon}] {name}"
if detail:
msg += f" -- {detail}"
print(msg)
# ════════════════════════════════════════════════════════════════════════════════
# 1. Return Computation
# ════════════════════════════════════════════════════════════════════════════════
def test_returns():
print("\n--- Return Computation ---")
# Simple case: [1, 1, 1] with gamma=0.99
returns = compute_returns([1.0, 1.0, 1.0], gamma=0.99)
# G_2 = 1.0, G_1 = 1 + 0.99*1 = 1.99, G_0 = 1 + 0.99*1.99 = 2.9701
record("compute_returns[0]",
PASS if abs(returns[0] - 2.9701) < 0.001 else FAIL,
f"got {returns[0]:.4f} expected 2.9701")
record("compute_returns[2]",
PASS if abs(returns[2] - 1.0) < 0.001 else FAIL,
f"got {returns[2]:.4f} expected 1.0")
# Empty rewards
returns_empty = compute_returns([])
record("compute_returns handles empty",
PASS if returns_empty == [] else FAIL,
f"got {returns_empty}")
# Single reward
returns_single = compute_returns([5.0])
record("compute_returns single reward",
PASS if abs(returns_single[0] - 5.0) < 0.001 else FAIL,
f"got {returns_single[0]:.4f}")
    # Discount factor = 0 → only immediate reward matters
returns_0 = compute_returns([1.0, 2.0, 3.0], gamma=0.0)
record("gamma=0: returns = rewards",
PASS if returns_0 == [1.0, 2.0, 3.0] else FAIL,
f"got {returns_0}")
# ════════════════════════════════════════════════════════════════════════════════
# 2. GAE Computation
# ════════════════════════════════════════════════════════════════════════════════
def test_gae():
print("\n--- GAE Computation ---")
# With V=0, GAE reduces to discounted returns
rewards = [1.0, 1.0, 1.0]
values = [0.0, 0.0, 0.0]
gae = compute_gae(rewards, values, gamma=0.99, lam=1.0)
returns = compute_returns(rewards, gamma=0.99)
record("GAE with V=0, lam=1 equals returns",
PASS if all(abs(g - r) < 0.01 for g, r in zip(gae, returns)) else FAIL,
f"gae={[round(g,2) for g in gae]} returns={[round(r,2) for r in returns]}")
# With lam=0, GAE reduces to one-step TD
gae_td = compute_gae(rewards, values, gamma=0.99, lam=0.0)
    # δ_0 = r_0 + γ*V(s_1) - V(s_0) = 1.0 + 0.99*0 - 0 = 1.0
record("GAE with lam=0 is one-step TD",
PASS if abs(gae_td[0] - 1.0) < 0.001 else FAIL,
f"got {gae_td[0]:.4f} expected 1.0")
# With non-zero values, advantage is return minus value
values2 = [2.0, 1.0, 0.5]
gae2 = compute_gae([1.0, 1.0, 1.0], values2, gamma=0.99, lam=1.0)
record("GAE with values produces non-trivial advantages",
PASS if len(gae2) == 3 and any(abs(g) > 0.1 for g in gae2) else FAIL,
f"gae={[round(g,3) for g in gae2]}")
# ════════════════════════════════════════════════════════════════════════════════
# 3. REINFORCE Loss
# ════════════════════════════════════════════════════════════════════════════════
def test_reinforce():
print("\n--- REINFORCE Loss ---")
# Known values: log_probs=[-1, -2, -3], returns=[10, 5, 1]
# loss = -(1/3) * ((-1)*10 + (-2)*5 + (-3)*1) = -(1/3)*(-10-10-3) = -(1/3)*(-23) = 7.667
log_probs = [-1.0, -2.0, -3.0]
returns = [10.0, 5.0, 1.0]
loss = reinforce_loss(log_probs, returns)
expected = -((-1.0)*10 + (-2.0)*5 + (-3.0)*1) / 3
record("REINFORCE loss matches manual calculation",
PASS if abs(loss - expected) < 0.001 else FAIL,
f"got {loss:.4f} expected {expected:.4f}")
# Higher returns should produce higher loss (more gradient push)
returns_high = [20.0, 10.0, 2.0]
loss_high = reinforce_loss(log_probs, returns_high)
record("Higher returns β†’ higher loss magnitude",
PASS if abs(loss_high) > abs(loss) else FAIL,
f"low={abs(loss):.4f} high={abs(loss_high):.4f}")
# Empty episode
loss_empty = reinforce_loss([], [])
record("REINFORCE handles empty episode",
PASS if loss_empty == 0.0 else FAIL,
f"got {loss_empty}")
# ════════════════════════════════════════════════════════════════════════════════
# 4. REINFORCE + Baseline Loss
# ════════════════════════════════════════════════════════════════════════════════
def test_reinforce_baseline():
print("\n--- REINFORCE + Baseline Loss ---")
log_probs = [-1.0, -2.0, -3.0]
returns = [10.0, 5.0, 1.0]
# With baselines=None, uses mean(returns)=5.33 as baseline
loss_b = reinforce_baseline_loss(log_probs, returns, baselines=None, normalize_advantage=False)
# advantages = [10-5.33, 5-5.33, 1-5.33] = [4.67, -0.33, -4.33]
# loss = -(1/3) * ((-1)*4.67 + (-2)*(-0.33) + (-3)*(-4.33))
# = -(1/3) * (-4.67 + 0.67 + 13.0)
# = -(1/3) * 9.0 = -3.0
mean_r = sum(returns) / len(returns)
advantages = [g - mean_r for g in returns]
expected = -sum(lp * adv for lp, adv in zip(log_probs, advantages)) / 3
record("REINFORCE+baseline matches manual calc",
PASS if abs(loss_b - expected) < 0.01 else FAIL,
f"got {loss_b:.4f} expected {expected:.4f}")
# With normalize_advantage=True, advantages are standardized
loss_norm = reinforce_baseline_loss(log_probs, returns, baselines=None, normalize_advantage=True)
record("Normalized advantage produces valid loss",
PASS if not math.isnan(loss_norm) and not math.isinf(loss_norm) else FAIL,
f"loss={loss_norm:.4f}")
    # A baseline often reduces the loss magnitude vs vanilla REINFORCE, but
    # this is not guaranteed for any single episode, so record the comparison
    # as informational only.
    loss_vanilla = reinforce_loss(log_probs, returns)
    record("Baseline loss magnitude (informational)",
           PASS,
           f"vanilla={abs(loss_vanilla):.4f} baseline={abs(loss_norm):.4f} (varies)")
# Custom baselines
baselines = [9.0, 4.0, 0.5]
loss_custom = reinforce_baseline_loss(log_probs, returns, baselines=baselines, normalize_advantage=False)
advantages_custom = [g - b for g, b in zip(returns, baselines)]
expected_custom = -sum(lp * adv for lp, adv in zip(log_probs, advantages_custom)) / 3
record("Custom baselines work correctly",
PASS if abs(loss_custom - expected_custom) < 0.01 else FAIL,
f"got {loss_custom:.4f} expected {expected_custom:.4f}")
# ════════════════════════════════════════════════════════════════════════════════
# 5. GRPO Loss
# ════════════════════════════════════════════════════════════════════════════════
def test_grpo():
print("\n--- GRPO Loss ---")
# Group of 3 samples for one state
log_probs_groups = [[-1.0, -2.0, -1.5]]
rewards_groups = [[10.0, 5.0, 8.0]]
loss = grpo_loss(log_probs_groups, rewards_groups)
record("GRPO produces valid loss",
PASS if not math.isnan(loss) and not math.isinf(loss) else FAIL,
f"loss={loss:.4f}")
# The highest-reward sample should get positive advantage,
# lowest-reward should get negative advantage
mean_r = sum(rewards_groups[0]) / 3 # 7.67
std_r = math.sqrt(sum((r - mean_r)**2 for r in rewards_groups[0]) / 3)
advantages = [(r - mean_r) / (std_r + 1e-8) for r in rewards_groups[0]]
record("GRPO: highest reward gets positive advantage",
PASS if advantages[0] > 0 else FAIL,
f"adv={advantages[0]:.4f}")
record("GRPO: lowest reward gets negative advantage",
PASS if advantages[1] < 0 else FAIL,
f"adv={advantages[1]:.4f}")
# Multiple groups
log_probs_2 = [[-1.0, -2.0], [-1.5, -1.5]]
rewards_2 = [[10.0, 5.0], [3.0, 7.0]]
loss_2 = grpo_loss(log_probs_2, rewards_2)
record("GRPO handles multiple groups",
PASS if not math.isnan(loss_2) else FAIL,
f"loss={loss_2:.4f}")
# Empty groups
loss_empty = grpo_loss([], [])
record("GRPO handles empty input",
PASS if loss_empty == 0.0 else FAIL,
f"got {loss_empty}")
    # Identical rewards → zero advantage → zero loss
    loss_identical = grpo_loss([[-1.0, -2.0, -3.0]], [[5.0, 5.0, 5.0]])
    record("GRPO: identical rewards → near-zero loss",
PASS if abs(loss_identical) < 1e-4 else FAIL,
f"loss={loss_identical:.6f}")
# ════════════════════════════════════════════════════════════════════════════════
# 6. RLOO Loss
# ════════════════════════════════════════════════════════════════════════════════
def test_rloo():
print("\n--- RLOO Loss ---")
# Group of 3 samples
log_probs_groups = [[-1.0, -2.0, -1.5]]
rewards_groups = [[10.0, 5.0, 8.0]]
loss = rloo_loss(log_probs_groups, rewards_groups)
record("RLOO produces valid loss",
PASS if not math.isnan(loss) and not math.isinf(loss) else FAIL,
f"loss={loss:.4f}")
# Leave-one-out baselines
# For r=10: baseline = (5+8)/2 = 6.5, advantage = 10-6.5 = 3.5
# For r=5: baseline = (10+8)/2 = 9.0, advantage = 5-9.0 = -4.0
# For r=8: baseline = (10+5)/2 = 7.5, advantage = 8-7.5 = 0.5
    advantages = [10 - 6.5, 5 - 9.0, 8 - 7.5]
expected = -sum(lp * adv for lp, adv in zip(log_probs_groups[0], advantages)) / 3
record("RLOO matches manual calculation",
PASS if abs(loss - expected) < 0.01 else FAIL,
f"got {loss:.4f} expected {expected:.4f}")
# Single sample: falls back to REINFORCE
loss_single = rloo_loss([[-1.0]], [[5.0]])
expected_single = -(-1.0) * 5.0 # REINFORCE on one sample
record("RLOO K=1 falls back to REINFORCE",
PASS if abs(loss_single - expected_single) < 0.01 else FAIL,
f"got {loss_single:.4f} expected {expected_single:.4f}")
# K=2: simplest meaningful RLOO
loss_k2 = rloo_loss([[-1.0, -2.0]], [[10.0, 5.0]])
# baseline for r=10: 5.0, adv=5.0
# baseline for r=5: 10.0, adv=-5.0
# loss = -(1/2) * ((-1)*5 + (-2)*(-5)) = -(1/2)*(-5+10) = -2.5
record("RLOO K=2 produces valid loss",
PASS if not math.isnan(loss_k2) else FAIL,
f"loss={loss_k2:.4f}")
# ════════════════════════════════════════════════════════════════════════════════
# 7. Reward Normalization
# ════════════════════════════════════════════════════════════════════════════════
def test_reward_normalization():
print("\n--- Reward Normalization ---")
raw = [-0.5, -1.0, -0.3, -2.0, -0.8]
mean, var = compute_reward_stats(raw)
record("Reward stats computed",
PASS if abs(mean - (-0.92)) < 0.01 else FAIL,
f"mean={mean:.4f} var={var:.4f}")
normed = normalize_rewards(raw, mean, var)
record("Normalized rewards have near-zero mean",
PASS if abs(sum(normed)/len(normed)) < 0.01 else FAIL,
f"mean={sum(normed)/len(normed):.4f}")
norm_var = sum((n - sum(normed)/len(normed))**2 for n in normed) / len(normed)
record("Normalized rewards have near-unit variance",
PASS if abs(norm_var - 1.0) < 0.01 else FAIL,
f"var={norm_var:.4f}")
# Identity: normalizing with mean=0, var=1 should leave rewards unchanged
identity = normalize_rewards(raw, 0.0, 1.0)
record("Identity normalization (mean=0, var=1)",
PASS if all(abs(a - b) < 0.01 for a, b in zip(raw, identity)) else FAIL,
f"max_diff={max(abs(a-b) for a,b in zip(raw,identity)):.4f}")
# ════════════════════════════════════════════════════════════════════════════════
# 8. Loss Function Comparison
# ════════════════════════════════════════════════════════════════════════════════
def test_loss_comparison():
"""Compare all 4 loss functions on the same episode data."""
print("\n--- Loss Function Comparison ---")
log_probs = [-2.0, -1.5, -3.0, -1.0, -2.5]
returns = [0.8, 0.3, 0.1, 0.5, 0.2]
l_reinforce = reinforce_loss(log_probs, returns)
l_baseline = reinforce_baseline_loss(log_probs, returns, normalize_advantage=True)
# GRPO: treat each step as its own "group" of size 1
# (Not how GRPO is normally used, but tests the pipeline)
lps_groups = [[lp] for lp in log_probs]
rs_groups = [[r] for r in returns]
l_grpo = grpo_loss(lps_groups, rs_groups)
l_rloo = rloo_loss(lps_groups, rs_groups)
record("All 4 losses produce valid values",
PASS if all(not math.isnan(l) and not math.isinf(l)
for l in [l_reinforce, l_baseline, l_grpo, l_rloo]) else FAIL,
f"R={l_reinforce:.4f} RB={l_baseline:.4f} GRPO={l_grpo:.4f} RLOO={l_rloo:.4f}")
print(f" [i] REINFORCE: {l_reinforce:.6f}")
print(f" [i] REINFORCE+baseline: {l_baseline:.6f}")
print(f" [i] GRPO (K=1): {l_grpo:.6f}")
print(f" [i] RLOO (K=1): {l_rloo:.6f}")
# Now with proper K=4 groups
log_probs_4 = [[-1.0, -2.0, -1.5, -3.0]]
rewards_4 = [[0.8, 0.2, 0.5, 0.1]]
l_grpo_4 = grpo_loss(log_probs_4, rewards_4)
l_rloo_4 = rloo_loss(log_probs_4, rewards_4)
record("GRPO/RLOO with K=4 produce valid losses",
PASS if not math.isnan(l_grpo_4) and not math.isnan(l_rloo_4) else FAIL,
f"GRPO={l_grpo_4:.4f} RLOO={l_rloo_4:.4f}")
print(f" [i] GRPO (K=4): {l_grpo_4:.6f}")
print(f" [i] RLOO (K=4): {l_rloo_4:.6f}")
# ════════════════════════════════════════════════════════════════════════════════
# 9. Episode Collection (with MockPolicyModel)
# ════════════════════════════════════════════════════════════════════════════════
def test_episode_collection():
print("\n--- Episode Collection (MockPolicyModel) ---")
config = TrainingConfig(n_nodes=5, max_steps=30)
collector = EpisodeCollector(config)
model = MockPolicyModel(n_nodes=5, seed=42)
episode = collector.collect_episode(model, task_id="task-1", seed=42)
record("Episode has correct number of steps",
PASS if len(episode.steps) == 30 else FAIL,
f"steps={len(episode.steps)}")
record("All log probs are valid",
PASS if all(not math.isnan(s.log_prob) for s in episode.steps) else FAIL,
f"min_lp={min(s.log_prob for s in episode.steps):.4f}")
record("Rewards are finite",
PASS if all(math.isfinite(s.reward) for s in episode.steps) else FAIL,
f"min_r={min(s.reward for s in episode.steps):.4f}")
record("Normalized rewards in [0,1]",
PASS if all(0.0 <= s.reward_normalized <= 1.0 for s in episode.steps) else FAIL,
f"range=[{min(s.reward_normalized for s in episode.steps):.4f}, "
f"{max(s.reward_normalized for s in episode.steps):.4f}]")
record("Total reward is computed",
PASS if math.isfinite(episode.total_reward) else FAIL,
f"total={episode.total_reward:.4f}")
record("SLA violations tracked",
PASS if isinstance(episode.sla_violations, int) else FAIL,
f"violations={episode.sla_violations}")
# ════════════════════════════════════════════════════════════════════════════════
# 10. Full Training Step (per loss function)
# ════════════════════════════════════════════════════════════════════════════════
def test_training_steps():
"""Run one training step with each loss function."""
print("\n--- Full Training Steps ---")
model = MockPolicyModel(n_nodes=5, seed=42)
for loss_name in [LOSS_REINFORCE, LOSS_REINFORCE_BASELINE, LOSS_GRPO, LOSS_RLOO]:
config = TrainingConfig(
n_nodes=5,
max_steps=30,
loss_fn=loss_name,
n_samples_per_state=2 if loss_name in (LOSS_GRPO, LOSS_RLOO) else 1,
)
trainer = SRETrainer(config)
metrics = trainer.train_step(model, task_id="task-1", seed=42)
record(f"{loss_name}: loss is valid",
PASS if math.isfinite(metrics["loss"]) else FAIL,
f"loss={metrics['loss']:.4f}")
record(f"{loss_name}: avg_reward is valid",
PASS if math.isfinite(metrics["avg_reward"]) else FAIL,
f"avg_reward={metrics['avg_reward']:.4f}")
record(f"{loss_name}: episode completed",
PASS if metrics["episode_length"] > 0 else FAIL,
f"length={metrics['episode_length']}")
# No NaN/inf in running stats
record(f"{loss_name}: running stats stable",
PASS if math.isfinite(metrics["reward_mean"]) and math.isfinite(metrics["reward_var"]) else FAIL,
f"mean={metrics['reward_mean']:.4f} var={metrics['reward_var']:.4f}")
# ════════════════════════════════════════════════════════════════════════════════
# 11. Multi-Episode Stability
# ════════════════════════════════════════════════════════════════════════════════
def test_multi_episode_stability():
"""Run multiple episodes and check running stats remain stable."""
print("\n--- Multi-Episode Stability ---")
config = TrainingConfig(
n_nodes=5,
max_steps=30,
loss_fn=LOSS_REINFORCE_BASELINE,
tasks=["task-1", "task-2", "task-3"],
)
trainer = SRETrainer(config)
model = MockPolicyModel(n_nodes=5, seed=42)
all_losses = []
for i in range(5):
for task in config.tasks:
metrics = trainer.train_step(model, task_id=task, seed=42 + i)
all_losses.append(metrics["loss"])
# No NaN/inf across 15 episodes
record("15 episodes: all losses finite",
PASS if all(math.isfinite(l) for l in all_losses) else FAIL,
f"n_losses={len(all_losses)}")
# Losses should vary (different tasks + domain randomization)
unique_losses = len(set(round(l, 4) for l in all_losses))
record("Losses vary across episodes",
PASS if unique_losses > 3 else FAIL,
f"unique={unique_losses}/{len(all_losses)}")
# Running stats should be non-degenerate
last_metrics = metrics
record("Running reward mean is non-zero",
PASS if abs(last_metrics["reward_mean"]) > 0.001 else FAIL,
f"mean={last_metrics['reward_mean']:.6f}")
# ════════════════════════════════════════════════════════════════════════════════
# 12. SRE-Specific Edge Cases
# ════════════════════════════════════════════════════════════════════════════════
def test_sre_edge_cases():
"""Test edge cases specific to the SRE domain."""
print("\n--- SRE Edge Cases ---")
# Very negative rewards (system crashing)
log_probs = [-2.0] * 10
returns_crash = [-100.0] * 10
loss_crash = reinforce_baseline_loss(log_probs, returns_crash, normalize_advantage=True)
record("Very negative rewards: loss is finite",
PASS if math.isfinite(loss_crash) else FAIL,
f"loss={loss_crash:.4f}")
# All-zero returns (perfect episode)
returns_perfect = [0.0] * 10
loss_perfect = reinforce_baseline_loss(log_probs, returns_perfect, normalize_advantage=False)
record("Zero returns: loss is zero (no gradient)",
PASS if abs(loss_perfect) < 0.001 else FAIL,
f"loss={loss_perfect:.4f}")
# Highly variable rewards within episode (surge task)
returns_surge = [0.5, 0.5, -10.0, -10.0, 0.5, 0.5, -10.0, 0.5, 0.5, 0.5]
loss_surge = reinforce_baseline_loss(log_probs, returns_surge, normalize_advantage=True)
record("High-variance rewards: loss is finite with normalization",
PASS if math.isfinite(loss_surge) else FAIL,
f"loss={loss_surge:.4f}")
# GRPO with very different rewards in group
lps = [[-1.0, -2.0, -1.5, -3.0]]
rs_extreme = [[0.9, 0.8, 0.85, 0.05]] # One bad sample
loss_extreme = grpo_loss(lps, rs_extreme)
record("GRPO handles outlier in group",
PASS if math.isfinite(loss_extreme) else FAIL,
f"loss={loss_extreme:.4f}")
# ════════════════════════════════════════════════════════════════════════════════
def main():
print("=" * 60)
print("AntiAtropos Training Validation")
print("=" * 60)
test_returns()
test_gae()
test_reinforce()
test_reinforce_baseline()
test_grpo()
test_rloo()
test_reward_normalization()
test_loss_comparison()
test_episode_collection()
test_training_steps()
test_multi_episode_stability()
test_sre_edge_cases()
passed = sum(1 for _, s, _ in results if s == PASS)
failed = sum(1 for _, s, _ in results if s == FAIL)
total = len(results)
print("\n" + "=" * 60)
print(f"RESULTS: {passed}/{total} passed, {failed} failed")
print("=" * 60)
if failed > 0:
print("\nFailed tests:")
for name, status, detail in results:
if status == FAIL:
print(f" X {name}: {detail}")
return 0 if failed == 0 else 1
if __name__ == "__main__":
sys.exit(main())