"""
Unit tests for UGTCModule.
Tests cover:
- Gate computation correctness
- GAE computation
- Advantage shapes and dtype
- EMA normalization
- Parameter counting
- Value blending
"""
import pytest
import torch
import numpy as np
from ugtc.module import ValueNetwork, EnsembleValueNetwork, UGTCModule
# Size of the last dimension of every observation tensor built in this file;
# also passed to the networks under test as their observation dimension.
OBS_DIM = 17
# Number of rows in the default observation batch used by most tests.
BATCH = 32
# Value passed as the hidden dimension of every network constructed here.
HIDDEN = 32
@pytest.fixture
def ugtc():
    """Provide a freshly constructed UGTCModule (M=3) for each test."""
    module = UGTCModule(obs_dim=OBS_DIM, hidden_dim=HIDDEN, M=3)
    return module
@pytest.fixture
def obs():
    """Provide a random observation batch of shape (BATCH, OBS_DIM)."""
    batch_shape = (BATCH, OBS_DIM)
    return torch.randn(batch_shape)
# ── ValueNetwork ─────────────────────────────────────────────────────────────
class TestValueNetwork:
    """Shape and gradient checks for a single ValueNetwork."""

    def test_output_shape(self):
        """A (BATCH, OBS_DIM) input yields a flat (BATCH,) output."""
        model = ValueNetwork(OBS_DIM, HIDDEN)
        out = model(torch.randn(BATCH, OBS_DIM))
        assert out.shape == (BATCH,), f"Expected ({BATCH},), got {out.shape}"

    def test_single_sample(self):
        """A batch of one observation yields an output of shape (1,)."""
        model = ValueNetwork(OBS_DIM, HIDDEN)
        values = model(torch.randn(1, OBS_DIM))
        assert values.shape == (1,)

    def test_grad_flows(self):
        """Backpropagating a scalar loss populates .grad on trainable params."""
        model = ValueNetwork(OBS_DIM, HIDDEN)
        inputs = torch.randn(BATCH, OBS_DIM, requires_grad=False)
        model(inputs).mean().backward()
        # Only parameters that opted into autograd are expected to have grads.
        trainable = [p for p in model.parameters() if p.requires_grad]
        assert all(p.grad is not None for p in trainable)
# ── EnsembleValueNetwork ─────────────────────────────────────────────────────
class TestEnsembleValueNetwork:
    """Shape, sign and diversity checks for EnsembleValueNetwork."""

    def test_output_shapes(self):
        """The forward pass returns two tensors, each of shape (BATCH,)."""
        ensemble = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=3)
        mean, sigma = ensemble(torch.randn(BATCH, OBS_DIM))
        for tensor in (mean, sigma):
            assert tensor.shape == (BATCH,)

    def test_uncertainty_nonneg(self):
        """The second output (sigma) is non-negative for every sample."""
        ensemble = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=3)
        batch = torch.randn(BATCH, OBS_DIM)
        sigma = ensemble(batch)[1]
        assert torch.all(sigma >= 0), "Uncertainty (std) must be non-negative"

    def test_diversity(self):
        """Members should produce different outputs (different random init)."""
        ensemble = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=5)
        batch = torch.randn(BATCH, OBS_DIM)
        per_member = ensemble.forward_all(batch)  # (M, batch)
        # Compare the first two members; they must not coincide.
        first, second = per_member[0], per_member[1]
        mean_abs_gap = torch.mean(torch.abs(first - second))
        assert mean_abs_gap > 0, "Ensemble members should differ"

    def test_forward_all_shape(self):
        """forward_all stacks one row of BATCH values per ensemble member."""
        M = 4
        ensemble = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=M)
        stacked = ensemble.forward_all(torch.randn(BATCH, OBS_DIM))
        assert stacked.shape == (M, BATCH)
# ── UGTCModule ───────────────────────────────────────────────────────────────
class TestUGTCModule:
    """Tests for UGTCModule: gate, advantages, blended value, EMA and config."""

    def test_gate_shape(self, ugtc, obs):
        """compute_gate returns three tensors, each of shape (BATCH,)."""
        gate, v_fast, v_slow = ugtc.compute_gate(obs)
        assert gate.shape == (BATCH,)
        assert v_fast.shape == (BATCH,)
        assert v_slow.shape == (BATCH,)

    def test_gate_in_unit_interval(self, ugtc, obs):
        """Every gate value lies in the closed interval [0, 1]."""
        gate, _, _ = ugtc.compute_gate(obs)
        assert (gate >= 0.0).all(), "Gate must be >= 0"
        assert (gate <= 1.0).all(), "Gate must be <= 1"

    def test_advantage_shape(self, ugtc, obs):
        """compute_advantages returns one advantage per input row."""
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones, gamma=0.99)
        assert adv.shape == (BATCH,), f"Expected ({BATCH},), got {adv.shape}"

    def test_advantage_finite(self, ugtc, obs):
        """Advantages contain no NaN or infinite entries."""
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones)
        assert torch.isfinite(adv).all(), "All advantages should be finite"

    def test_value_shape(self, ugtc, obs):
        """get_value_ugtc returns one value per input row."""
        v = ugtc.get_value_ugtc(obs)
        assert v.shape == (BATCH,)

    def test_value_finite(self, ugtc, obs):
        """get_value_ugtc output contains no NaN or infinite entries."""
        v = ugtc.get_value_ugtc(obs)
        assert torch.isfinite(v).all()

    def test_gae_zero_reward(self, ugtc):
        """Zero rewards without termination give well-formed advantages.

        The advantages are not expected to be zero: with zero reward they
        still depend on the differences between value estimates. Only shape
        and finiteness are checked.
        """
        T = 16
        obs = torch.randn(T, OBS_DIM)
        next_obs = torch.randn(T, OBS_DIM)
        rewards = torch.zeros(T)
        dones = torch.zeros(T)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones, gamma=0.99)
        assert adv.shape == (T,)
        assert torch.isfinite(adv).all(), "All advantages should be finite"

    def test_done_masks(self, ugtc):
        """Done flags should prevent bootstrapping across episodes."""
        T = 4
        obs = torch.randn(T, OBS_DIM)
        next_obs = torch.randn(T, OBS_DIM)
        rewards = torch.ones(T)
        dones_all = torch.ones(T)  # every step terminates
        dones_none = torch.zeros(T)
        adv_all = ugtc.compute_advantages(obs, next_obs, rewards, dones_all)
        adv_none = ugtc.compute_advantages(obs, next_obs, rewards, dones_none)
        # These should differ because done=1 zeroes out future bootstrapping
        assert not torch.allclose(adv_all, adv_none)

    def test_parameter_count(self, ugtc):
        """parameter_count reports both parts and a consistent, positive total."""
        counts = ugtc.parameter_count()
        assert "fast_critic" in counts
        assert "slow_ensemble" in counts
        assert "total" in counts
        assert counts["total"] == counts["fast_critic"] + counts["slow_ensemble"]
        assert counts["total"] > 0

    def test_gate_stats_keys(self, ugtc, obs):
        """get_gate_stats exposes all expected summary keys."""
        stats = ugtc.get_gate_stats(obs)
        for key in ("gate_mean", "gate_std", "gate_min", "gate_max", "sigma_ema"):
            assert key in stats, f"Missing key: {key}"

    def test_ema_updates_during_training(self, ugtc, obs):
        """Repeated compute_gate calls in train mode keep sigma_ema well-formed.

        Whether the EMA actually moves depends on the observed uncertainty
        (it may already sit at its fixed point), so this checks only that the
        statistic stays a finite Python float after several updates.
        """
        ugtc.train()
        for _ in range(10):
            ugtc.compute_gate(obs)
        updated_ema = ugtc.sigma_ema.item()
        assert isinstance(updated_ema, float)
        assert torch.isfinite(ugtc.sigma_ema).all(), "EMA should stay finite"

    def test_ema_frozen_in_eval(self, ugtc, obs):
        """compute_gate leaves sigma_ema untouched in eval mode."""
        ugtc.eval()
        initial_ema = ugtc.sigma_ema.item()
        for _ in range(10):
            ugtc.compute_gate(obs)
        assert ugtc.sigma_ema.item() == initial_ema, "EMA should not update in eval mode"

    def test_different_lambda_values(self):
        """Verify different lambda values produce different advantages."""
        # Seed identically before each construction so both modules start from
        # the same random draws; otherwise differently initialised weights
        # alone would make the advantages differ and the test would pass even
        # if lambda had no effect.
        torch.manual_seed(0)
        ugtc_low = UGTCModule(OBS_DIM, HIDDEN, lambda_fast=0.1, lambda_slow=0.2)
        torch.manual_seed(0)
        ugtc_high = UGTCModule(OBS_DIM, HIDDEN, lambda_fast=0.8, lambda_slow=0.99)
        obs = torch.randn(16, OBS_DIM)
        next_obs = torch.randn(16, OBS_DIM)
        rewards = torch.randn(16)
        dones = torch.zeros(16)
        adv_low = ugtc_low.compute_advantages(obs, next_obs, rewards, dones)
        adv_high = ugtc_high.compute_advantages(obs, next_obs, rewards, dones)
        assert not torch.allclose(adv_low, adv_high)

    def test_beta_affects_gate_sharpness(self):
        """Higher beta should produce sharper gate transitions."""
        # Same seed for both constructions so beta is the only intended
        # difference between the two modules being compared.
        torch.manual_seed(0)
        ugtc_low_beta = UGTCModule(OBS_DIM, HIDDEN, beta=0.1)
        torch.manual_seed(0)
        ugtc_high_beta = UGTCModule(OBS_DIM, HIDDEN, beta=20.0)
        obs = torch.randn(64, OBS_DIM)
        gate_low, _, _ = ugtc_low_beta.compute_gate(obs)
        gate_high, _, _ = ugtc_high_beta.compute_gate(obs)
        # High beta should have more extreme values (closer to 0 or 1)
        extremity_low = ((gate_low - 0.5).abs()).mean()
        extremity_high = ((gate_high - 0.5).abs()).mean()
        assert extremity_high >= extremity_low, "Higher beta should produce sharper gate"

    @pytest.mark.parametrize("M", [1, 2, 5, 10])
    def test_ensemble_sizes(self, M):
        """compute_gate keeps the (BATCH,) gate shape for several values of M."""
        ugtc = UGTCModule(OBS_DIM, HIDDEN, M=M)
        obs = torch.randn(BATCH, OBS_DIM)
        gate, v_fast, v_slow = ugtc.compute_gate(obs)
        assert gate.shape == (BATCH,)

    def test_no_grad_in_advantage_computation(self, ugtc, obs):
        """compute_advantages should not retain gradients on the output."""
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones)
        assert not adv.requires_grad
# ── GAE computation ──────────────────────────────────────────────────────────
class TestGAEComputation:
    """Closed-form checks of UGTCModule._compute_gae on one-step inputs."""

    def test_single_step_gae(self):
        """Single step: advantage = δ = r + γV(s') - V(s)."""
        reward, value, next_value, done = 1.0, 0.5, 0.5, 0.0
        gamma = 0.99
        adv = UGTCModule._compute_gae(
            torch.tensor([reward]),
            torch.tensor([value]),
            torch.tensor([next_value]),
            torch.tensor([done]),
            gamma=gamma,
            lam=0.95,
        )
        expected = reward + gamma * next_value - value
        error = abs(adv[0].item() - expected)
        assert error < 1e-5

    def test_terminal_step(self):
        """Done=1 should zero out future value bootstrap."""
        reward, value, done = 1.0, 0.0, 1.0
        next_value = 100.0  # large, should be masked
        gamma = 0.99
        adv = UGTCModule._compute_gae(
            torch.tensor([reward]),
            torch.tensor([value]),
            torch.tensor([next_value]),
            torch.tensor([done]),
            gamma=gamma,
            lam=0.95,
        )
        # The (1 - done) factor is zero here, so next_value is masked out.
        expected = reward + gamma * next_value * (1.0 - done) - value
        error = abs(adv[0].item() - expected)
        assert error < 1e-5
|