"""
Unit tests for UGTCModule.

Tests cover:
- Gate computation correctness
- GAE computation
- Advantage shapes and dtype
- EMA normalization
- Parameter counting
- Value blending
"""

import pytest
import torch
import numpy as np

from ugtc.module import ValueNetwork, EnsembleValueNetwork, UGTCModule

OBS_DIM = 17
BATCH = 32
HIDDEN = 32


@pytest.fixture(autouse=True)
def _seed_rng():
    """Seed torch's global RNG before every test so runs are reproducible.

    Autouse fixtures are set up before the other fixtures of the same scope,
    so ``ugtc`` and ``obs`` below are built from this seeded state.
    """
    torch.manual_seed(0)


@pytest.fixture
def ugtc():
    return UGTCModule(obs_dim=OBS_DIM, hidden_dim=HIDDEN, M=3)


@pytest.fixture
def obs():
    return torch.randn(BATCH, OBS_DIM)


def _seeded_module(seed=0, **kwargs):
    """Build a UGTCModule from a fixed RNG state.

    Two modules built with the same seed share their random initialisation
    (assuming UGTCModule draws its weights only from torch's global RNG), so
    comparisons between them isolate the effect of ``kwargs``.
    """
    torch.manual_seed(seed)
    return UGTCModule(OBS_DIM, HIDDEN, **kwargs)


# ── ValueNetwork ─────────────────────────────────────────────────────────────

class TestValueNetwork:

    def test_output_shape(self):
        net = ValueNetwork(OBS_DIM, HIDDEN)
        obs = torch.randn(BATCH, OBS_DIM)
        out = net(obs)
        assert out.shape == (BATCH,), f"Expected ({BATCH},), got {out.shape}"

    def test_single_sample(self):
        net = ValueNetwork(OBS_DIM, HIDDEN)
        obs = torch.randn(1, OBS_DIM)
        out = net(obs)
        assert out.shape == (1,)

    def test_grad_flows(self):
        net = ValueNetwork(OBS_DIM, HIDDEN)
        obs = torch.randn(BATCH, OBS_DIM, requires_grad=False)
        loss = net(obs).mean()
        loss.backward()
        for p in net.parameters():
            if p.requires_grad:
                assert p.grad is not None


# ── EnsembleValueNetwork ──────────────────────────────────────────────────────

class TestEnsembleValueNetwork:

    def test_output_shapes(self):
        ens = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=3)
        obs = torch.randn(BATCH, OBS_DIM)
        mean, sigma = ens(obs)
        assert mean.shape == (BATCH,)
        assert sigma.shape == (BATCH,)

    def test_uncertainty_nonneg(self):
        ens = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=3)
        obs = torch.randn(BATCH, OBS_DIM)
        _, sigma = ens(obs)
        assert (sigma >= 0).all(), "Uncertainty (std) must be non-negative"

    def test_diversity(self):
        """Members should produce different outputs (different random init)."""
        ens = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=5)
        obs = torch.randn(BATCH, OBS_DIM)
        all_vals = ens.forward_all(obs)  # (M, batch)
        # Compare the first two members; independently initialised members
        # should not produce identical outputs.
        diffs = all_vals[0] - all_vals[1]
        assert diffs.abs().mean() > 0, "Ensemble members should differ"

    def test_forward_all_shape(self):
        M = 4
        ens = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=M)
        out = ens.forward_all(torch.randn(BATCH, OBS_DIM))
        assert out.shape == (M, BATCH)


# ── UGTCModule ────────────────────────────────────────────────────────────────

class TestUGTCModule:

    def test_gate_shape(self, ugtc, obs):
        gate, v_fast, v_slow = ugtc.compute_gate(obs)
        assert gate.shape == (BATCH,)
        assert v_fast.shape == (BATCH,)
        assert v_slow.shape == (BATCH,)

    def test_gate_in_unit_interval(self, ugtc, obs):
        gate, _, _ = ugtc.compute_gate(obs)
        assert (gate >= 0.0).all(), "Gate must be >= 0"
        assert (gate <= 1.0).all(), "Gate must be <= 1"

    def test_advantage_shape(self, ugtc, obs):
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones, gamma=0.99)
        assert adv.shape == (BATCH,), f"Expected ({BATCH},), got {adv.shape}"
        # The module docstring promises dtype coverage: advantages must be floats.
        assert adv.is_floating_point(), f"Expected a float dtype, got {adv.dtype}"

    def test_advantage_finite(self, ugtc, obs):
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones)
        assert torch.isfinite(adv).all(), "All advantages should be finite"

    def test_value_shape(self, ugtc, obs):
        v = ugtc.get_value_ugtc(obs)
        assert v.shape == (BATCH,)

    def test_value_finite(self, ugtc, obs):
        v = ugtc.get_value_ugtc(obs)
        assert torch.isfinite(v).all()

    def test_gae_zero_reward(self, ugtc):
        """Zero rewards with no termination yield well-formed advantages.

        The advantages are not necessarily zero: they still depend on the
        value differences between consecutive observations.
        """
        T = 16
        obs = torch.randn(T, OBS_DIM)
        next_obs = torch.randn(T, OBS_DIM)
        rewards = torch.zeros(T)
        dones = torch.zeros(T)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones, gamma=0.99)
        assert adv.shape == (T,)
        assert torch.isfinite(adv).all()

    def test_done_masks(self, ugtc):
        """Done flags should prevent bootstrapping across episodes."""
        T = 4
        obs = torch.randn(T, OBS_DIM)
        next_obs = torch.randn(T, OBS_DIM)
        rewards = torch.ones(T)
        dones_all = torch.ones(T)  # every step terminates
        dones_none = torch.zeros(T)
        adv_all = ugtc.compute_advantages(obs, next_obs, rewards, dones_all)
        adv_none = ugtc.compute_advantages(obs, next_obs, rewards, dones_none)
        # These should differ because done=1 zeroes out future bootstrapping
        assert not torch.allclose(adv_all, adv_none)

    def test_parameter_count(self, ugtc):
        counts = ugtc.parameter_count()
        assert "fast_critic" in counts
        assert "slow_ensemble" in counts
        assert "total" in counts
        assert counts["total"] == counts["fast_critic"] + counts["slow_ensemble"]
        assert counts["total"] > 0

    def test_gate_stats_keys(self, ugtc, obs):
        stats = ugtc.get_gate_stats(obs)
        for key in ("gate_mean", "gate_std", "gate_min", "gate_max", "sigma_ema"):
            assert key in stats, f"Missing key: {key}"

    def test_ema_updates_during_training(self, ugtc, obs):
        """In train mode, compute_gate should move the uncertainty EMA."""
        ugtc.train()
        initial_ema = ugtc.sigma_ema.item()
        for _ in range(10):
            ugtc.compute_gate(obs)
        updated_ema = ugtc.sigma_ema.item()
        assert torch.isfinite(ugtc.sigma_ema).all(), "EMA must stay finite"
        # Compare against the starting value: a type check alone can never fail.
        # (With the seeded RNG, the ensemble std is not exactly the EMA's
        # starting value, so ten updates must move it.)
        assert updated_ema != initial_ema, "EMA should update in train mode"

    def test_ema_frozen_in_eval(self, ugtc, obs):
        ugtc.eval()
        initial_ema = ugtc.sigma_ema.item()
        for _ in range(10):
            ugtc.compute_gate(obs)
        assert ugtc.sigma_ema.item() == initial_ema, "EMA should not update in eval mode"

    def test_different_lambda_values(self):
        """Verify different lambda values produce different advantages."""
        # Same seed -> same initial weights, so only lambda differs.
        ugtc_low = _seeded_module(lambda_fast=0.1, lambda_slow=0.2)
        ugtc_high = _seeded_module(lambda_fast=0.8, lambda_slow=0.99)
        obs = torch.randn(16, OBS_DIM)
        next_obs = torch.randn(16, OBS_DIM)
        rewards = torch.randn(16)
        dones = torch.zeros(16)
        adv_low = ugtc_low.compute_advantages(obs, next_obs, rewards, dones)
        adv_high = ugtc_high.compute_advantages(obs, next_obs, rewards, dones)
        assert not torch.allclose(adv_low, adv_high)

    def test_beta_affects_gate_sharpness(self):
        """Higher beta should produce sharper gate transitions."""
        # Same seed -> same initial weights, so only beta differs; otherwise
        # differing random inits could make the comparison flaky.
        ugtc_low_beta = _seeded_module(beta=0.1)
        ugtc_high_beta = _seeded_module(beta=20.0)
        obs = torch.randn(64, OBS_DIM)
        gate_low, _, _ = ugtc_low_beta.compute_gate(obs)
        gate_high, _, _ = ugtc_high_beta.compute_gate(obs)
        # High beta should have more extreme values (closer to 0 or 1)
        extremity_low = ((gate_low - 0.5).abs()).mean()
        extremity_high = ((gate_high - 0.5).abs()).mean()
        assert extremity_high >= extremity_low, "Higher beta should produce sharper gate"

    @pytest.mark.parametrize("M", [1, 2, 5, 10])
    def test_ensemble_sizes(self, M):
        ugtc = UGTCModule(OBS_DIM, HIDDEN, M=M)
        obs = torch.randn(BATCH, OBS_DIM)
        gate, v_fast, v_slow = ugtc.compute_gate(obs)
        assert gate.shape == (BATCH,)
        assert v_fast.shape == (BATCH,)
        assert v_slow.shape == (BATCH,)

    def test_no_grad_in_advantage_computation(self, ugtc, obs):
        """compute_advantages should not retain gradients on the output."""
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones)
        assert not adv.requires_grad


# ── GAE computation ────────────────────────────────────────────────────────────

class TestGAEComputation:

    def test_single_step_gae(self):
        """Single step: advantage = δ = r + γV(s') - V(s)."""
        rewards = torch.tensor([1.0])
        values = torch.tensor([0.5])
        next_values = torch.tensor([0.5])
        dones = torch.tensor([0.0])
        adv = UGTCModule._compute_gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95)
        expected = 1.0 + 0.99 * 0.5 - 0.5
        assert abs(adv[0].item() - expected) < 1e-5

    def test_terminal_step(self):
        """Done=1 should zero out future value bootstrap."""
        rewards = torch.tensor([1.0])
        values = torch.tensor([0.0])
        next_values = torch.tensor([100.0])  # large, should be masked
        dones = torch.tensor([1.0])
        adv = UGTCModule._compute_gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95)
        # δ = r - V(s): the γ·V(s') = 99.0 bootstrap is masked out by done=1.
        expected = 1.0
        assert abs(adv[0].item() - expected) < 1e-5

    def test_multi_step_recursion(self):
        """Two steps, no termination: A_0 = δ_0 + γλ·A_1 and A_1 = δ_1."""
        gamma, lam = 0.99, 0.95
        rewards = torch.tensor([1.0, 1.0])
        values = torch.tensor([0.5, 0.5])
        next_values = torch.tensor([0.5, 0.5])
        dones = torch.tensor([0.0, 0.0])
        adv = UGTCModule._compute_gae(rewards, values, next_values, dones, gamma=gamma, lam=lam)
        delta = 1.0 + gamma * 0.5 - 0.5
        assert abs(adv[1].item() - delta) < 1e-5
        assert abs(adv[0].item() - (delta + gamma * lam * delta)) < 1e-5

    def test_done_cuts_trace(self):
        """A terminal step must not accumulate advantage from the next episode."""
        gamma, lam = 0.99, 0.95
        rewards = torch.tensor([1.0, 1.0])
        values = torch.tensor([0.5, 0.5])
        next_values = torch.tensor([0.5, 0.5])
        dones = torch.tensor([1.0, 0.0])  # episode ends after step 0
        adv = UGTCModule._compute_gae(rewards, values, next_values, dones, gamma=gamma, lam=lam)
        assert abs(adv[0].item() - (1.0 - 0.5)) < 1e-5  # no bootstrap, no trace
        assert abs(adv[1].item() - (1.0 + gamma * 0.5 - 0.5)) < 1e-5