| """ |
| Unit tests for UGTCModule. |
| |
| Tests cover: |
| - Gate computation correctness |
| - GAE computation |
| - Advantage shapes and finiteness |
| - EMA normalization |
| - Parameter counting |
| - Value blending |
| """ |
|
|
| import pytest |
| import torch |
| import numpy as np |
|
|
| from ugtc.module import ValueNetwork, EnsembleValueNetwork, UGTCModule |
|
|
|
|
# Dimensions shared by every test in this module.
OBS_DIM: int = 17  # width of each observation row fed to the networks
BATCH: int = 32  # number of observation rows in a default test batch
HIDDEN: int = 32  # hidden width passed to every network constructor
|
|
|
|
@pytest.fixture
def ugtc():
    """Provide a freshly constructed UGTCModule for a single test.

    Built with the module-wide OBS_DIM / HIDDEN sizes and M=3
    (presumably the slow-ensemble size).
    """
    module = UGTCModule(obs_dim=OBS_DIM, hidden_dim=HIDDEN, M=3)
    return module
|
|
|
|
@pytest.fixture
def obs():
    """Provide a random (BATCH, OBS_DIM) batch of observations."""
    shape = (BATCH, OBS_DIM)
    return torch.randn(*shape)
|
|
|
|
| |
|
|
class TestValueNetwork:
    """Shape and gradient checks for the single-head ValueNetwork."""

    def test_output_shape(self):
        """A (BATCH, OBS_DIM) input yields one value per row."""
        net = ValueNetwork(OBS_DIM, HIDDEN)
        obs = torch.randn(BATCH, OBS_DIM)
        out = net(obs)
        assert out.shape == (BATCH,), f"Expected ({BATCH},), got {out.shape}"

    def test_single_sample(self):
        """A batch of one must still give a 1-D output, not a 0-D scalar."""
        net = ValueNetwork(OBS_DIM, HIDDEN)
        out = net(torch.randn(1, OBS_DIM))
        assert out.shape == (1,), f"Expected (1,), got {out.shape}"

    def test_grad_flows(self):
        """Backprop from the output must reach every trainable parameter."""
        net = ValueNetwork(OBS_DIM, HIDDEN)
        obs = torch.randn(BATCH, OBS_DIM)
        net(obs).mean().backward()

        trainable = [p for p in net.parameters() if p.requires_grad]
        # Without this guard a network with no trainable parameters would
        # make the loop below assert nothing and the test pass vacuously.
        assert trainable, "ValueNetwork exposes no trainable parameters"
        for p in trainable:
            assert p.grad is not None
            assert torch.isfinite(p.grad).all(), "Gradients must be finite"
|
|
|
|
| |
|
|
class TestEnsembleValueNetwork:
    """Shape, uncertainty and diversity checks for EnsembleValueNetwork."""

    def test_output_shapes(self):
        """forward() returns a (mean, sigma) pair, each one value per row."""
        ens = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=3)
        obs = torch.randn(BATCH, OBS_DIM)
        mean, sigma = ens(obs)
        assert mean.shape == (BATCH,)
        assert sigma.shape == (BATCH,)

    def test_uncertainty_nonneg(self):
        """The spread estimate must be a finite, non-negative quantity."""
        ens = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=3)
        obs = torch.randn(BATCH, OBS_DIM)
        _, sigma = ens(obs)
        assert (sigma >= 0).all(), "Uncertainty (std) must be non-negative"
        assert torch.isfinite(sigma).all(), "Uncertainty must be finite"

    def test_diversity(self):
        """Every pair of members should differ (independent random init)."""
        M = 5
        ens = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=M)
        obs = torch.randn(BATCH, OBS_DIM)
        with torch.no_grad():
            all_vals = ens.forward_all(obs)  # (M, BATCH), see test_forward_all_shape
        # pairwise[i, j] = mean |V_i - V_j| over the batch, shape (M, M).
        pairwise = (all_vals.unsqueeze(0) - all_vals.unsqueeze(1)).abs().mean(dim=-1)
        off_diagonal = ~torch.eye(M, dtype=torch.bool)
        assert (pairwise[off_diagonal] > 0).all(), "Ensemble members should differ"

    def test_forward_all_shape(self):
        """forward_all() stacks one row of values per ensemble member."""
        M = 4
        ens = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=M)
        out = ens.forward_all(torch.randn(BATCH, OBS_DIM))
        assert out.shape == (M, BATCH)
|
|
|
|
| |
|
|
class TestUGTCModule:
    """Behavioural tests for UGTCModule's gate, advantages, values and EMA."""

    @staticmethod
    def _seeded_module(seed=0, **kwargs):
        """Build a UGTCModule whose initial weights depend only on *seed*.

        Comparing two independently initialised modules confounds the
        hyper-parameter under test with random-init noise. Seeding both
        constructions identically removes that noise, assuming the
        hyper-parameters do not change which parameters get created.
        ``fork_rng`` restores the global RNG state afterwards so other tests
        are unaffected.
        """
        with torch.random.fork_rng():
            torch.manual_seed(seed)
            return UGTCModule(OBS_DIM, HIDDEN, **kwargs)

    def test_gate_shape(self, ugtc, obs):
        gate, v_fast, v_slow = ugtc.compute_gate(obs)
        assert gate.shape == (BATCH,)
        assert v_fast.shape == (BATCH,)
        assert v_slow.shape == (BATCH,)

    def test_gate_in_unit_interval(self, ugtc, obs):
        gate, _, _ = ugtc.compute_gate(obs)
        assert (gate >= 0.0).all(), "Gate must be >= 0"
        assert (gate <= 1.0).all(), "Gate must be <= 1"

    def test_advantage_shape(self, ugtc, obs):
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones, gamma=0.99)
        assert adv.shape == (BATCH,), f"Expected ({BATCH},), got {adv.shape}"

    def test_advantage_finite(self, ugtc, obs):
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones)
        assert torch.isfinite(adv).all(), "All advantages should be finite"

    def test_value_shape(self, ugtc, obs):
        v = ugtc.get_value_ugtc(obs)
        assert v.shape == (BATCH,)

    def test_value_finite(self, ugtc, obs):
        v = ugtc.get_value_ugtc(obs)
        assert torch.isfinite(v).all()

    def test_gae_zero_reward(self, ugtc):
        """Zero rewards must still give a well-formed, finite advantage vector.

        The value networks are randomly initialised, so V(s) and V(s') are
        not zero and the advantages are NOT expected to be near zero; only
        shape and finiteness are checked.
        """
        T = 16
        obs = torch.randn(T, OBS_DIM)
        next_obs = torch.randn(T, OBS_DIM)
        rewards = torch.zeros(T)
        dones = torch.zeros(T)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones, gamma=0.99)
        assert adv.shape == (T,)
        assert torch.isfinite(adv).all()

    def test_done_masks(self, ugtc):
        """Done flags should prevent bootstrapping across episodes."""
        # compute_gate updates sigma_ema in train mode. If compute_advantages
        # goes through it, the two calls below could differ because of the
        # EMA alone, hiding an implementation that ignores `dones`. Eval mode
        # freezes the EMA (see test_ema_frozen_in_eval).
        ugtc.eval()
        T = 4
        obs = torch.randn(T, OBS_DIM)
        next_obs = torch.randn(T, OBS_DIM)
        rewards = torch.ones(T)
        dones_all = torch.ones(T)
        dones_none = torch.zeros(T)

        adv_all = ugtc.compute_advantages(obs, next_obs, rewards, dones_all)
        adv_none = ugtc.compute_advantages(obs, next_obs, rewards, dones_none)
        assert not torch.allclose(adv_all, adv_none)

    def test_parameter_count(self, ugtc):
        counts = ugtc.parameter_count()
        assert "fast_critic" in counts
        assert "slow_ensemble" in counts
        assert "total" in counts
        assert counts["total"] == counts["fast_critic"] + counts["slow_ensemble"]
        assert counts["total"] > 0

    def test_gate_stats_keys(self, ugtc, obs):
        stats = ugtc.get_gate_stats(obs)
        for key in ("gate_mean", "gate_std", "gate_min", "gate_max", "sigma_ema"):
            assert key in stats, f"Missing key: {key}"

    def test_ema_updates_during_training(self, ugtc, obs):
        """In train mode, repeated gate computations must move sigma_ema."""
        ugtc.train()
        initial_ema = ugtc.sigma_ema.item()
        for _ in range(10):
            ugtc.compute_gate(obs)
        updated_ema = ugtc.sigma_ema.item()
        assert isinstance(updated_ema, float)
        assert torch.isfinite(ugtc.sigma_ema).all(), "EMA must stay finite"
        # Previously only the type was checked, so a never-updating EMA passed.
        assert updated_ema != initial_ema, "EMA should update in train mode"

    def test_ema_frozen_in_eval(self, ugtc, obs):
        ugtc.eval()
        initial_ema = ugtc.sigma_ema.item()
        for _ in range(10):
            ugtc.compute_gate(obs)
        assert ugtc.sigma_ema.item() == initial_ema, "EMA should not update in eval mode"

    def test_different_lambda_values(self):
        """Verify different lambda values produce different advantages."""
        # Identical seeds give identical networks, so any difference is due to lambda.
        ugtc_low = self._seeded_module(lambda_fast=0.1, lambda_slow=0.2)
        ugtc_high = self._seeded_module(lambda_fast=0.8, lambda_slow=0.99)
        obs = torch.randn(16, OBS_DIM)
        next_obs = torch.randn(16, OBS_DIM)
        rewards = torch.randn(16)
        dones = torch.zeros(16)
        adv_low = ugtc_low.compute_advantages(obs, next_obs, rewards, dones)
        adv_high = ugtc_high.compute_advantages(obs, next_obs, rewards, dones)
        assert not torch.allclose(adv_low, adv_high)

    def test_beta_affects_gate_sharpness(self):
        """Higher beta should produce sharper gate transitions."""
        # Identical seeds give identical networks, so any difference is due to beta.
        ugtc_low_beta = self._seeded_module(beta=0.1)
        ugtc_high_beta = self._seeded_module(beta=20.0)
        obs = torch.randn(64, OBS_DIM)
        gate_low, _, _ = ugtc_low_beta.compute_gate(obs)
        gate_high, _, _ = ugtc_high_beta.compute_gate(obs)
        # Mean distance from 0.5: larger means the gate sits closer to 0 or 1.
        extremity_low = (gate_low - 0.5).abs().mean()
        extremity_high = (gate_high - 0.5).abs().mean()
        assert extremity_high >= extremity_low, "Higher beta should produce sharper gate"

    @pytest.mark.parametrize("M", [1, 2, 5, 10])
    def test_ensemble_sizes(self, M):
        """The gate pipeline should produce per-row outputs for any ensemble size."""
        module = UGTCModule(OBS_DIM, HIDDEN, M=M)
        obs = torch.randn(BATCH, OBS_DIM)
        gate, v_fast, v_slow = module.compute_gate(obs)
        assert gate.shape == (BATCH,)
        assert v_fast.shape == (BATCH,)
        assert v_slow.shape == (BATCH,)

    def test_no_grad_in_advantage_computation(self, ugtc, obs):
        """compute_advantages should not retain gradients on the output."""
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones)
        assert not adv.requires_grad
|
|
|
|
| |
|
|
class TestGAEComputation:
    """Hand-computed checks of the static UGTCModule._compute_gae helper."""

    def test_single_step_gae(self):
        """Single step: advantage = δ = r + γV(s') - V(s)."""
        rewards = torch.tensor([1.0])
        values = torch.tensor([0.5])
        next_values = torch.tensor([0.5])
        dones = torch.tensor([0.0])
        adv = UGTCModule._compute_gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95)
        expected = 1.0 + 0.99 * 0.5 - 0.5
        assert abs(adv[0].item() - expected) < 1e-5

    def test_terminal_step(self):
        """Done=1 should zero out future value bootstrap."""
        rewards = torch.tensor([1.0])
        values = torch.tensor([0.0])
        next_values = torch.tensor([100.0])
        dones = torch.tensor([1.0])
        adv = UGTCModule._compute_gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95)
        # The (1 - done) mask removes gamma * V(s') = 99.0 entirely,
        # leaving r - V(s).
        expected = 1.0 - 0.0
        assert abs(adv[0].item() - expected) < 1e-5

    def test_multi_step_recursion(self):
        """Earlier steps accumulate later deltas discounted by gamma * lambda.

        With all values zero every delta equals its reward, so for two unit
        rewards: A_1 = 1 and A_0 = 1 + gamma * lam * A_1.
        """
        gamma, lam = 0.99, 0.95
        rewards = torch.tensor([1.0, 1.0])
        values = torch.zeros(2)
        next_values = torch.zeros(2)
        dones = torch.zeros(2)
        adv = UGTCModule._compute_gae(rewards, values, next_values, dones, gamma=gamma, lam=lam)
        expected = torch.tensor([1.0 + gamma * lam, 1.0])
        assert torch.allclose(adv, expected, atol=1e-5)

    def test_done_cuts_recursion(self):
        """A done at step 0 must stop step 1's advantage leaking backwards."""
        rewards = torch.tensor([1.0, 1.0])
        values = torch.zeros(2)
        next_values = torch.zeros(2)
        dones = torch.tensor([1.0, 0.0])
        adv = UGTCModule._compute_gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95)
        # Each step reduces to its own delta (= reward, since values are zero).
        assert torch.allclose(adv, torch.tensor([1.0, 1.0]), atol=1e-5)
|
|