File size: 25,609 Bytes
04b46cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
#!/usr/bin/env python3
"""
AntiAtropos Training Validation — Local Test.

Validates the training pipeline (loss functions, episode collection, trainer)
using a MockPolicyModel (no GPU needed). Run before going to Colab.

Run from project root:
    python train_test.py
"""

import sys
import os
import math
import random

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from training.losses import (
    compute_returns, compute_gae,
    reinforce_loss, reinforce_baseline_loss,
    grpo_loss, rloo_loss,
    normalize_rewards, compute_reward_stats,
)
from training.trainer import (
    SRETrainer, TrainingConfig, EpisodeCollector,
    MockPolicyModel,
    LOSS_REINFORCE, LOSS_REINFORCE_BASELINE, LOSS_GRPO, LOSS_RLOO,
)

# Status labels shared by every check in this script.
PASS = "PASS"
FAIL = "FAIL"
# Every recorded check as a (name, status, detail) triple, in call order.
results: list[tuple[str, str, str]] = []


def record(name: str, status: str, detail: str = "") -> None:
    """Store one check outcome in ``results`` and print a one-line summary.

    Args:
        name: Label of the check.
        status: Outcome label; ``PASS`` is printed with a ``+`` marker and any
            other value with an ``X`` marker.
        detail: Optional extra text, appended after `` -- `` when non-empty.
    """
    results.append((name, status, detail))
    icon = "X" if status != PASS else "+"
    pieces = [f"  [{icon}] {name}"]
    if detail:
        pieces.append(f" -- {detail}")
    print("".join(pieces))


# ════════════════════════════════════════════════════════════════════════════════
# 1. Return Computation
# ════════════════════════════════════════════════════════════════════════════════

def test_returns():
    """Exercise ``compute_returns`` on a worked example and on edge inputs.

    Every check is reported through ``record``; nothing is returned.
    """
    print("\n--- Return Computation ---")

    # Three unit rewards discounted at gamma=0.99, worked backwards:
    #   G_2 = 1.0
    #   G_1 = 1 + 0.99 * 1.0  = 1.99
    #   G_0 = 1 + 0.99 * 1.99 = 2.9701
    returns = compute_returns([1.0, 1.0, 1.0], gamma=0.99)
    head_ok = abs(returns[0] - 2.9701) < 0.001
    record("compute_returns[0]",
           PASS if head_ok else FAIL,
           f"got {returns[0]:.4f} expected 2.9701")
    tail_ok = abs(returns[2] - 1.0) < 0.001
    record("compute_returns[2]",
           PASS if tail_ok else FAIL,
           f"got {returns[2]:.4f} expected 1.0")

    # No rewards at all: the check expects an empty list back.
    returns_empty = compute_returns([])
    empty_ok = returns_empty == []
    record("compute_returns handles empty",
           PASS if empty_ok else FAIL,
           f"got {returns_empty}")

    # One reward: the return is expected to be that reward.
    returns_single = compute_returns([5.0])
    single_ok = abs(returns_single[0] - 5.0) < 0.001
    record("compute_returns single reward",
           PASS if single_ok else FAIL,
           f"got {returns_single[0]:.4f}")

    # gamma=0 discards all future reward, so each return is expected to equal
    # the immediate reward.
    returns_0 = compute_returns([1.0, 2.0, 3.0], gamma=0.0)
    undiscounted_ok = returns_0 == [1.0, 2.0, 3.0]
    record("gamma=0: returns = rewards",
           PASS if undiscounted_ok else FAIL,
           f"got {returns_0}")


# ════════════════════════════════════════════════════════════════════════════════
# 2. GAE Computation
# ════════════════════════════════════════════════════════════════════════════════

def test_gae():
    """Check ``compute_gae`` in its two limiting cases and with a value baseline.

    Every check is reported through ``record``; nothing is returned.
    """
    print("\n--- GAE Computation ---")

    # Zero value estimates with lam=1: the check expects the advantages to
    # coincide with the plain discounted returns.
    rewards = [1.0, 1.0, 1.0]
    values = [0.0, 0.0, 0.0]
    gae = compute_gae(rewards, values, gamma=0.99, lam=1.0)
    returns = compute_returns(rewards, gamma=0.99)
    agrees = all(abs(adv - ret) < 0.01 for adv, ret in zip(gae, returns))
    record("GAE with V=0, lam=1 equals returns",
           PASS if agrees else FAIL,
           f"gae={[round(g,2) for g in gae]} returns={[round(r,2) for r in returns]}")

    # lam=0 is expected to collapse GAE to the one-step TD error:
    #   delta_0 = r_0 + gamma * V(s_1) - V(s_0) = 1.0 + 0.99 * 0 - 0 = 1.0
    gae_td = compute_gae(rewards, values, gamma=0.99, lam=0.0)
    td_ok = abs(gae_td[0] - 1.0) < 0.001
    record("GAE with lam=0 is one-step TD",
           PASS if td_ok else FAIL,
           f"got {gae_td[0]:.4f} expected 1.0")

    # Non-zero value estimates: only the shape and a non-trivial magnitude of
    # the advantages are checked, not exact numbers.
    values2 = [2.0, 1.0, 0.5]
    gae2 = compute_gae([1.0, 1.0, 1.0], values2, gamma=0.99, lam=1.0)
    nontrivial = len(gae2) == 3 and any(abs(g) > 0.1 for g in gae2)
    record("GAE with values produces non-trivial advantages",
           PASS if nontrivial else FAIL,
           f"gae={[round(g,3) for g in gae2]}")


# ════════════════════════════════════════════════════════════════════════════════
# 3. REINFORCE Loss
# ════════════════════════════════════════════════════════════════════════════════

def test_reinforce():
    """Compare ``reinforce_loss`` with a hand calculation and probe edge inputs.

    Every check is reported through ``record``; nothing is returned.
    """
    print("\n--- REINFORCE Loss ---")

    log_probs = [-1.0, -2.0, -3.0]
    returns = [10.0, 5.0, 1.0]
    # Reference value: -(1/3) * sum(log_prob * return)
    #   = -(1/3) * (-10 - 10 - 3) = 23 / 3 ~= 7.667
    expected = -sum(lp * g for lp, g in zip(log_probs, returns)) / 3
    loss = reinforce_loss(log_probs, returns)
    matches = abs(loss - expected) < 0.001
    record("REINFORCE loss matches manual calculation",
           PASS if matches else FAIL,
           f"got {loss:.4f} expected {expected:.4f}")

    # Doubling every return is expected to increase the loss magnitude.
    returns_high = [20.0, 10.0, 2.0]
    loss_high = reinforce_loss(log_probs, returns_high)
    grows = abs(loss_high) > abs(loss)
    record("Higher returns β†’ higher loss magnitude",
           PASS if grows else FAIL,
           f"low={abs(loss):.4f} high={abs(loss_high):.4f}")

    # An episode with no steps is expected to yield exactly 0.0.
    loss_empty = reinforce_loss([], [])
    empty_ok = loss_empty == 0.0
    record("REINFORCE handles empty episode",
           PASS if empty_ok else FAIL,
           f"got {loss_empty}")


# ════════════════════════════════════════════════════════════════════════════════
# 4. REINFORCE + Baseline Loss
# ════════════════════════════════════════════════════════════════════════════════

def test_reinforce_baseline():
    """Validate ``reinforce_baseline_loss`` against hand-computed advantages.

    Covers the default baseline (``baselines=None``), advantage normalization,
    a comparison with vanilla REINFORCE, and caller-supplied baselines. Every
    check is reported through ``record``; nothing is returned.
    """
    print("\n--- REINFORCE + Baseline Loss ---")
    log_probs = [-1.0, -2.0, -3.0]
    returns = [10.0, 5.0, 1.0]

    # baselines=None: the expected value below assumes mean(returns) ~= 5.33
    # is used as the baseline.
    #   advantages = [10-5.33, 5-5.33, 1-5.33] = [4.67, -0.33, -4.33]
    #   loss = -(1/3) * ((-1)*4.67 + (-2)*(-0.33) + (-3)*(-4.33))
    #        = -(1/3) * (-4.67 + 0.67 + 13.0)
    #        = -(1/3) * 9.0 = -3.0
    loss_b = reinforce_baseline_loss(log_probs, returns, baselines=None, normalize_advantage=False)
    mean_r = sum(returns) / len(returns)
    advantages = [g - mean_r for g in returns]
    expected = -sum(lp * adv for lp, adv in zip(log_probs, advantages)) / 3
    record("REINFORCE+baseline matches manual calc",
           PASS if abs(loss_b - expected) < 0.01 else FAIL,
           f"got {loss_b:.4f} expected {expected:.4f}")

    # normalize_advantage=True: only finiteness is checked here.
    loss_norm = reinforce_baseline_loss(log_probs, returns, baselines=None, normalize_advantage=True)
    record("Normalized advantage produces valid loss",
           PASS if math.isfinite(loss_norm) else FAIL,
           f"loss={loss_norm:.4f}")

    # Compare against vanilla REINFORCE on the same fixed inputs. The former
    # condition ended in "or True", which made FAIL unreachable; it is now a
    # real comparison. NOTE(review): presumably normalized advantages are of
    # unit scale while the raw returns reach 10, so the inequality should hold
    # for this data -- confirm against reinforce_baseline_loss.
    loss_vanilla = reinforce_loss(log_probs, returns)
    record("Baseline typically reduces loss magnitude",
           PASS if abs(loss_norm) < abs(loss_vanilla) else FAIL,
           f"vanilla={abs(loss_vanilla):.4f} baseline={abs(loss_norm):.4f}")

    # Caller-supplied per-step baselines: advantage = return - baseline.
    baselines = [9.0, 4.0, 0.5]
    loss_custom = reinforce_baseline_loss(log_probs, returns, baselines=baselines, normalize_advantage=False)
    advantages_custom = [g - b for g, b in zip(returns, baselines)]
    expected_custom = -sum(lp * adv for lp, adv in zip(log_probs, advantages_custom)) / 3
    record("Custom baselines work correctly",
           PASS if abs(loss_custom - expected_custom) < 0.01 else FAIL,
           f"got {loss_custom:.4f} expected {expected_custom:.4f}")


# ════════════════════════════════════════════════════════════════════════════════
# 5. GRPO Loss
# ════════════════════════════════════════════════════════════════════════════════

def test_grpo():
    """Sanity-check ``grpo_loss`` on single, multiple, empty and constant groups.

    The two advantage checks recompute group-standardized advantages locally;
    they do not inspect values produced by ``grpo_loss``. Every check is
    reported through ``record``; nothing is returned.
    """
    print("\n--- GRPO Loss ---")

    # One state with three sampled actions.
    log_probs_groups = [[-1.0, -2.0, -1.5]]
    rewards_groups = [[10.0, 5.0, 8.0]]
    loss = grpo_loss(log_probs_groups, rewards_groups)
    record("GRPO produces valid loss",
           PASS if math.isfinite(loss) else FAIL,
           f"loss={loss:.4f}")

    # Standardize the rewards within the group: the best sample (index 0)
    # must land above zero and the worst (index 1) below zero.
    group = rewards_groups[0]
    mean_r = sum(group) / 3  # ~7.67
    variance = sum((r - mean_r)**2 for r in group) / 3
    std_r = math.sqrt(variance)
    advantages = [(r - mean_r) / (std_r + 1e-8) for r in group]
    best_positive = advantages[0] > 0
    record("GRPO: highest reward gets positive advantage",
           PASS if best_positive else FAIL,
           f"adv={advantages[0]:.4f}")
    worst_negative = advantages[1] < 0
    record("GRPO: lowest reward gets negative advantage",
           PASS if worst_negative else FAIL,
           f"adv={advantages[1]:.4f}")

    # Two groups of two samples each; only NaN is ruled out.
    log_probs_2 = [[-1.0, -2.0], [-1.5, -1.5]]
    rewards_2 = [[10.0, 5.0], [3.0, 7.0]]
    loss_2 = grpo_loss(log_probs_2, rewards_2)
    record("GRPO handles multiple groups",
           PASS if not math.isnan(loss_2) else FAIL,
           f"loss={loss_2:.4f}")

    # No groups at all: exactly 0.0 is expected.
    loss_empty = grpo_loss([], [])
    empty_ok = loss_empty == 0.0
    record("GRPO handles empty input",
           PASS if empty_ok else FAIL,
           f"got {loss_empty}")

    # Equal rewards within a group give zero advantages, so the loss is
    # expected to be (near) zero.
    loss_identical = grpo_loss([[-1.0, -2.0, -3.0]], [[5.0, 5.0, 5.0]])
    flat_ok = abs(loss_identical) < 1e-4
    record("GRPO: identical rewards β†’ near-zero loss",
           PASS if flat_ok else FAIL,
           f"loss={loss_identical:.6f}")


# ════════════════════════════════════════════════════════════════════════════════
# 6. RLOO Loss
# ════════════════════════════════════════════════════════════════════════════════

def test_rloo():
    """Validate ``rloo_loss`` against hand-computed leave-one-out advantages.

    Covers a group of three (exact value), a single sample (expected to match
    REINFORCE) and a group of two (validity only). Every check is reported
    through ``record``; nothing is returned.
    """
    print("\n--- RLOO Loss ---")

    # One state with three sampled actions.
    log_probs_groups = [[-1.0, -2.0, -1.5]]
    rewards_groups = [[10.0, 5.0, 8.0]]
    loss = rloo_loss(log_probs_groups, rewards_groups)
    record("RLOO produces valid loss",
           PASS if math.isfinite(loss) else FAIL,
           f"loss={loss:.4f}")

    # Leave-one-out baseline = mean of the OTHER rewards in the group:
    #   r=10: baseline = (5 + 8) / 2  = 6.5 -> advantage  3.5
    #   r=5:  baseline = (10 + 8) / 2 = 9.0 -> advantage -4.0
    #   r=8:  baseline = (10 + 5) / 2 = 7.5 -> advantage  0.5
    # The advantages are derived from the baselines, so the two lists cannot
    # drift apart (the baselines list was previously assigned but unused).
    baselines = [6.5, 9.0, 7.5]
    advantages = [r - b for r, b in zip(rewards_groups[0], baselines)]
    expected = -sum(lp * adv for lp, adv in zip(log_probs_groups[0], advantages)) / 3
    record("RLOO matches manual calculation",
           PASS if abs(loss - expected) < 0.01 else FAIL,
           f"got {loss:.4f} expected {expected:.4f}")

    # K=1: there is nothing to leave out, so the check expects the plain
    # REINFORCE value -log_prob * reward.
    loss_single = rloo_loss([[-1.0]], [[5.0]])
    expected_single = -(-1.0) * 5.0
    record("RLOO K=1 falls back to REINFORCE",
           PASS if abs(loss_single - expected_single) < 0.01 else FAIL,
           f"got {loss_single:.4f} expected {expected_single:.4f}")

    # K=2, the smallest group with a real leave-one-out baseline:
    #   r=10: baseline 5.0,  advantage  5.0
    #   r=5:  baseline 10.0, advantage -5.0
    #   loss = -(1/2) * ((-1)*5 + (-2)*(-5)) = -(1/2) * 5 = -2.5
    # Only NaN is ruled out here; the exact value is not asserted.
    loss_k2 = rloo_loss([[-1.0, -2.0]], [[10.0, 5.0]])
    record("RLOO K=2 produces valid loss",
           PASS if not math.isnan(loss_k2) else FAIL,
           f"loss={loss_k2:.4f}")


# ════════════════════════════════════════════════════════════════════════════════
# 7. Reward Normalization
# ════════════════════════════════════════════════════════════════════════════════

def test_reward_normalization():
    """Check ``compute_reward_stats`` and ``normalize_rewards`` on fixed data.

    Every check is reported through ``record``; nothing is returned.
    """
    print("\n--- Reward Normalization ---")

    raw = [-0.5, -1.0, -0.3, -2.0, -0.8]
    mean, var = compute_reward_stats(raw)
    # Arithmetic mean of the five samples is -4.6 / 5 = -0.92.
    mean_ok = abs(mean - (-0.92)) < 0.01
    record("Reward stats computed",
           PASS if mean_ok else FAIL,
           f"mean={mean:.4f} var={var:.4f}")

    # Normalizing with the data's own statistics should centre the rewards.
    normed = normalize_rewards(raw, mean, var)
    centre = sum(normed) / len(normed)
    record("Normalized rewards have near-zero mean",
           PASS if abs(centre) < 0.01 else FAIL,
           f"mean={sum(normed)/len(normed):.4f}")

    # Population variance (divide by N) of the normalized rewards.
    squared_devs = [(n - sum(normed)/len(normed))**2 for n in normed]
    norm_var = sum(squared_devs) / len(normed)
    unit_ok = abs(norm_var - 1.0) < 0.01
    record("Normalized rewards have near-unit variance",
           PASS if unit_ok else FAIL,
           f"var={norm_var:.4f}")

    # mean=0 and var=1 are expected to leave the rewards unchanged.
    identity = normalize_rewards(raw, 0.0, 1.0)
    unchanged = all(abs(before - after) < 0.01 for before, after in zip(raw, identity))
    record("Identity normalization (mean=0, var=1)",
           PASS if unchanged else FAIL,
           f"max_diff={max(abs(a-b) for a,b in zip(raw,identity)):.4f}")


# ════════════════════════════════════════════════════════════════════════════════
# 8. Loss Function Comparison
# ════════════════════════════════════════════════════════════════════════════════

def test_loss_comparison():
    """Evaluate all four loss functions on shared data and print their values.

    Only finiteness / non-NaN is asserted; the printed ``[i]`` lines are
    informational.
    """
    print("\n--- Loss Function Comparison ---")
    log_probs = [-2.0, -1.5, -3.0, -1.0, -2.5]
    returns = [0.8, 0.3, 0.1, 0.5, 0.2]

    l_reinforce = reinforce_loss(log_probs, returns)
    l_baseline = reinforce_baseline_loss(log_probs, returns, normalize_advantage=True)

    # Wrap every step in its own group of size 1. This is not a typical way
    # to use the group-based losses; it merely pushes the same data through
    # them.
    lps_groups = [[lp] for lp in log_probs]
    rs_groups = [[r] for r in returns]
    l_grpo = grpo_loss(lps_groups, rs_groups)
    l_rloo = rloo_loss(lps_groups, rs_groups)

    all_finite = all(math.isfinite(value)
                     for value in (l_reinforce, l_baseline, l_grpo, l_rloo))
    record("All 4 losses produce valid values",
           PASS if all_finite else FAIL,
           f"R={l_reinforce:.4f} RB={l_baseline:.4f} GRPO={l_grpo:.4f} RLOO={l_rloo:.4f}")

    print(f"  [i] REINFORCE:           {l_reinforce:.6f}")
    print(f"  [i] REINFORCE+baseline:  {l_baseline:.6f}")
    print(f"  [i] GRPO (K=1):          {l_grpo:.6f}")
    print(f"  [i] RLOO (K=1):          {l_rloo:.6f}")

    # Repeat with a single group of four samples.
    log_probs_4 = [[-1.0, -2.0, -1.5, -3.0]]
    rewards_4 = [[0.8, 0.2, 0.5, 0.1]]
    l_grpo_4 = grpo_loss(log_probs_4, rewards_4)
    l_rloo_4 = rloo_loss(log_probs_4, rewards_4)
    any_nan = math.isnan(l_grpo_4) or math.isnan(l_rloo_4)
    record("GRPO/RLOO with K=4 produce valid losses",
           FAIL if any_nan else PASS,
           f"GRPO={l_grpo_4:.4f} RLOO={l_rloo_4:.4f}")
    print(f"  [i] GRPO (K=4): {l_grpo_4:.6f}")
    print(f"  [i] RLOO (K=4): {l_rloo_4:.6f}")


# ════════════════════════════════════════════════════════════════════════════════
# 9. Episode Collection (with MockPolicyModel)
# ════════════════════════════════════════════════════════════════════════════════

def test_episode_collection():
    """Collect one episode with ``MockPolicyModel`` and sanity-check its contents.

    The detail strings call ``min``/``max`` with a NaN default, so an episode
    with no steps is reported through ``record`` (the step-count check fails)
    instead of aborting the whole run with ``ValueError``. Nothing is returned.
    """
    print("\n--- Episode Collection (MockPolicyModel) ---")
    config = TrainingConfig(n_nodes=5, max_steps=30)
    collector = EpisodeCollector(config)
    model = MockPolicyModel(n_nodes=5, seed=42)

    episode = collector.collect_episode(model, task_id="task-1", seed=42)

    # The expected length matches the max_steps=30 passed to TrainingConfig.
    record("Episode has correct number of steps",
           PASS if len(episode.steps) == 30 else FAIL,
           f"steps={len(episode.steps)}")

    log_probs = [s.log_prob for s in episode.steps]
    record("All log probs are valid",
           PASS if all(not math.isnan(lp) for lp in log_probs) else FAIL,
           f"min_lp={min(log_probs, default=math.nan):.4f}")

    rewards = [s.reward for s in episode.steps]
    record("Rewards are finite",
           PASS if all(math.isfinite(r) for r in rewards) else FAIL,
           f"min_r={min(rewards, default=math.nan):.4f}")

    normalized = [s.reward_normalized for s in episode.steps]
    record("Normalized rewards in [0,1]",
           PASS if all(0.0 <= r <= 1.0 for r in normalized) else FAIL,
           f"range=[{min(normalized, default=math.nan):.4f}, "
           f"{max(normalized, default=math.nan):.4f}]")

    record("Total reward is computed",
           PASS if math.isfinite(episode.total_reward) else FAIL,
           f"total={episode.total_reward:.4f}")

    # Only the type of the counter is checked, not its value.
    record("SLA violations tracked",
           PASS if isinstance(episode.sla_violations, int) else FAIL,
           f"violations={episode.sla_violations}")


# ════════════════════════════════════════════════════════════════════════════════
# 10. Full Training Step (per loss function)
# ════════════════════════════════════════════════════════════════════════════════

def test_training_steps():
    """Run a single ``train_step`` under each loss function and check the metrics.

    One ``MockPolicyModel`` is shared across all four runs; a fresh
    ``SRETrainer`` is built per loss. Every check is reported through
    ``record``; nothing is returned.
    """
    print("\n--- Full Training Steps ---")
    model = MockPolicyModel(n_nodes=5, seed=42)
    # These two losses are configured with two samples per state; the others
    # use one.
    group_losses = (LOSS_GRPO, LOSS_RLOO)

    for loss_name in (LOSS_REINFORCE, LOSS_REINFORCE_BASELINE, LOSS_GRPO, LOSS_RLOO):
        samples = 2 if loss_name in group_losses else 1
        config = TrainingConfig(
            n_nodes=5,
            max_steps=30,
            loss_fn=loss_name,
            n_samples_per_state=samples,
        )
        trainer = SRETrainer(config)
        metrics = trainer.train_step(model, task_id="task-1", seed=42)

        loss_ok = math.isfinite(metrics["loss"])
        record(f"{loss_name}: loss is valid",
               PASS if loss_ok else FAIL,
               f"loss={metrics['loss']:.4f}")

        reward_ok = math.isfinite(metrics["avg_reward"])
        record(f"{loss_name}: avg_reward is valid",
               PASS if reward_ok else FAIL,
               f"avg_reward={metrics['avg_reward']:.4f}")

        ran = metrics["episode_length"] > 0
        record(f"{loss_name}: episode completed",
               PASS if ran else FAIL,
               f"length={metrics['episode_length']}")

        # Running reward statistics must be free of NaN/inf.
        stats_ok = all(math.isfinite(metrics[key]) for key in ("reward_mean", "reward_var"))
        record(f"{loss_name}: running stats stable",
               PASS if stats_ok else FAIL,
               f"mean={metrics['reward_mean']:.4f} var={metrics['reward_var']:.4f}")


# ════════════════════════════════════════════════════════════════════════════════
# 11. Multi-Episode Stability
# ════════════════════════════════════════════════════════════════════════════════

def test_multi_episode_stability():
    """Train for 5 rounds over 3 tasks and check losses and running stats.

    A single trainer and model are reused for all 15 ``train_step`` calls.
    Every check is reported through ``record``; nothing is returned.
    """
    print("\n--- Multi-Episode Stability ---")
    config = TrainingConfig(
        n_nodes=5,
        max_steps=30,
        loss_fn=LOSS_REINFORCE_BASELINE,
        tasks=["task-1", "task-2", "task-3"],
    )
    trainer = SRETrainer(config)
    model = MockPolicyModel(n_nodes=5, seed=42)

    all_losses = []
    for round_idx in range(5):
        round_seed = 42 + round_idx
        for task in config.tasks:
            metrics = trainer.train_step(model, task_id=task, seed=round_seed)
            all_losses.append(metrics["loss"])

    # Every collected loss must be free of NaN/inf.
    finite = all(math.isfinite(value) for value in all_losses)
    record("15 episodes: all losses finite",
           PASS if finite else FAIL,
           f"n_losses={len(all_losses)}")

    # Distinct tasks and seeds are expected to yield more than three distinct
    # loss values after rounding to four decimals.
    unique_losses = len({round(value, 4) for value in all_losses})
    record("Losses vary across episodes",
           PASS if unique_losses > 3 else FAIL,
           f"unique={unique_losses}/{len(all_losses)}")

    # Inspect the running statistics reported by the final train_step call.
    last_metrics = metrics
    nonzero_mean = abs(last_metrics["reward_mean"]) > 0.001
    record("Running reward mean is non-zero",
           PASS if nonzero_mean else FAIL,
           f"mean={last_metrics['reward_mean']:.6f}")


# ════════════════════════════════════════════════════════════════════════════════
# 12. SRE-Specific Edge Cases
# ════════════════════════════════════════════════════════════════════════════════

def test_sre_edge_cases():
    """Feed extreme reward patterns to the loss functions and check the results.

    Every check is reported through ``record``; nothing is returned.
    """
    print("\n--- SRE Edge Cases ---")

    log_probs = [-2.0 for _ in range(10)]

    # Uniformly very negative returns (a crashing system).
    returns_crash = [-100.0 for _ in range(10)]
    loss_crash = reinforce_baseline_loss(log_probs, returns_crash, normalize_advantage=True)
    crash_ok = math.isfinite(loss_crash)
    record("Very negative rewards: loss is finite",
           PASS if crash_ok else FAIL,
           f"loss={loss_crash:.4f}")

    # All-zero returns (a perfect episode): the loss is expected to vanish.
    returns_perfect = [0.0 for _ in range(10)]
    loss_perfect = reinforce_baseline_loss(log_probs, returns_perfect, normalize_advantage=False)
    perfect_ok = abs(loss_perfect) < 0.001
    record("Zero returns: loss is zero (no gradient)",
           PASS if perfect_ok else FAIL,
           f"loss={loss_perfect:.4f}")

    # Returns that swing between 0.5 and -10.0 within one episode (surge task).
    returns_surge = [0.5, 0.5, -10.0, -10.0, 0.5, 0.5, -10.0, 0.5, 0.5, 0.5]
    loss_surge = reinforce_baseline_loss(log_probs, returns_surge, normalize_advantage=True)
    surge_ok = math.isfinite(loss_surge)
    record("High-variance rewards: loss is finite with normalization",
           PASS if surge_ok else FAIL,
           f"loss={loss_surge:.4f}")

    # A GRPO group in which one sample is far worse than the rest.
    lps = [[-1.0, -2.0, -1.5, -3.0]]
    rs_extreme = [[0.9, 0.8, 0.85, 0.05]]
    loss_extreme = grpo_loss(lps, rs_extreme)
    outlier_ok = math.isfinite(loss_extreme)
    record("GRPO handles outlier in group",
           PASS if outlier_ok else FAIL,
           f"loss={loss_extreme:.4f}")


# ════════════════════════════════════════════════════════════════════════════════

def main():
    """Run every validation group, print a summary, and return an exit code.

    Returns:
        0 when no check recorded ``FAIL``, otherwise 1.
    """
    banner = "=" * 60
    print(banner)
    print("AntiAtropos Training Validation")
    print(banner)

    # Groups run in this fixed order; each one reports through ``record``.
    suites = (
        test_returns,
        test_gae,
        test_reinforce,
        test_reinforce_baseline,
        test_grpo,
        test_rloo,
        test_reward_normalization,
        test_loss_comparison,
        test_episode_collection,
        test_training_steps,
        test_multi_episode_stability,
        test_sre_edge_cases,
    )
    for suite in suites:
        suite()

    statuses = [status for _, status, _ in results]
    passed = statuses.count(PASS)
    failed = statuses.count(FAIL)
    total = len(results)

    print("\n" + banner)
    print(f"RESULTS: {passed}/{total} passed, {failed} failed")
    print(banner)

    if failed:
        print("\nFailed tests:")
        for name, status, detail in results:
            if status != FAIL:
                continue
            print(f"  X {name}: {detail}")

    return 1 if failed else 0


if __name__ == "__main__":
    sys.exit(main())