# ugtc/tests/test_module.py
# Author (per hosting page): Ekrem-the-second
# Initial release: UGTC - Uncertainty-Gated Temporal Credit (commit d92d8cf)
"""
Unit tests for UGTCModule.
Tests cover:
- Gate computation correctness
- GAE computation
- Advantage shapes and dtype
- EMA normalization
- Parameter counting
- Value blending
"""
import pytest
import torch
import numpy as np
from ugtc.module import ValueNetwork, EnsembleValueNetwork, UGTCModule
# Observation feature dimension shared by every network and input tensor below.
OBS_DIM = 17
# Number of rows in the random observation batches fed to the networks.
BATCH = 32
# Value passed as the hidden-size argument when constructing networks/modules.
HIDDEN = 32
@pytest.fixture
def ugtc():
    """Provide a fresh ``UGTCModule`` built with the shared test sizes and ``M=3``."""
    module = UGTCModule(obs_dim=OBS_DIM, hidden_dim=HIDDEN, M=3)
    return module
@pytest.fixture
def obs():
    """Provide a random observation batch of shape ``(BATCH, OBS_DIM)``."""
    batch_shape = (BATCH, OBS_DIM)
    return torch.randn(*batch_shape)
# ── ValueNetwork ─────────────────────────────────────────────────────────────
class TestValueNetwork:
    """Shape and gradient checks for a single ``ValueNetwork``."""

    def test_output_shape(self):
        """A ``(BATCH, OBS_DIM)`` input yields a flat ``(BATCH,)`` output."""
        critic = ValueNetwork(OBS_DIM, HIDDEN)
        batch = torch.randn(BATCH, OBS_DIM)
        out = critic(batch)
        expected_shape = (BATCH,)
        assert out.shape == expected_shape, f"Expected ({BATCH},), got {out.shape}"

    def test_single_sample(self):
        """A one-row batch yields a one-element output."""
        critic = ValueNetwork(OBS_DIM, HIDDEN)
        single = torch.randn(1, OBS_DIM)
        assert critic(single).shape == (1,)

    def test_grad_flows(self):
        """Backprop through the mean output populates a gradient on every trainable parameter."""
        critic = ValueNetwork(OBS_DIM, HIDDEN)
        batch = torch.randn(BATCH, OBS_DIM, requires_grad=False)
        critic(batch).mean().backward()
        trainable = [param for param in critic.parameters() if param.requires_grad]
        assert all(param.grad is not None for param in trainable)
# ── EnsembleValueNetwork ──────────────────────────────────────────────────────
class TestEnsembleValueNetwork:
    """Shape, sign and diversity checks for ``EnsembleValueNetwork``."""

    def test_output_shapes(self):
        """The forward pass returns two ``(BATCH,)`` tensors (mean and sigma)."""
        ensemble = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=3)
        batch = torch.randn(BATCH, OBS_DIM)
        mean, sigma = ensemble(batch)
        for tensor in (mean, sigma):
            assert tensor.shape == (BATCH,)

    def test_uncertainty_nonneg(self):
        """The second forward output (sigma) is never negative."""
        ensemble = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=3)
        batch = torch.randn(BATCH, OBS_DIM)
        sigma = ensemble(batch)[1]
        assert (sigma >= 0).all(), "Uncertainty (std) must be non-negative"

    def test_diversity(self):
        """The first two members should disagree on a random batch."""
        ensemble = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=5)
        batch = torch.randn(BATCH, OBS_DIM)
        member_values = ensemble.forward_all(batch)  # indexed by member first
        # Only members 0 and 1 are compared; a non-zero mean absolute gap is enough.
        first, second = member_values[0], member_values[1]
        gap = (first - second).abs().mean()
        assert gap > 0, "Ensemble members should differ"

    def test_forward_all_shape(self):
        """``forward_all`` stacks per-member outputs into ``(M, BATCH)``."""
        n_members = 4
        ensemble = EnsembleValueNetwork(OBS_DIM, HIDDEN, M=n_members)
        stacked = ensemble.forward_all(torch.randn(BATCH, OBS_DIM))
        assert stacked.shape == (n_members, BATCH)
# ── UGTCModule ────────────────────────────────────────────────────────────────
class TestUGTCModule:
    """Tests for ``UGTCModule``: gate, advantages, blended value, EMA and stats."""

    @staticmethod
    def _build_seeded(seed=0, **kwargs):
        """Construct a ``UGTCModule`` under a fixed torch seed.

        The CPU RNG state is forked and restored afterwards, so seeding here
        does not leak into other tests. Two modules built with the same seed
        are intended to start from identical weights, so a comparison between
        them isolates the hyperparameters passed in ``kwargs``.

        NOTE(review): assumes UGTCModule draws its initial weights from
        torch's global RNG (the default for torch.nn layers) -- confirm.
        """
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            return UGTCModule(OBS_DIM, HIDDEN, **kwargs)

    def test_gate_shape(self, ugtc, obs):
        """``compute_gate`` returns three ``(BATCH,)`` tensors."""
        gate, v_fast, v_slow = ugtc.compute_gate(obs)
        assert gate.shape == (BATCH,)
        assert v_fast.shape == (BATCH,)
        assert v_slow.shape == (BATCH,)

    def test_gate_in_unit_interval(self, ugtc, obs):
        """Every gate value lies in ``[0, 1]``."""
        gate, _, _ = ugtc.compute_gate(obs)
        assert (gate >= 0.0).all(), "Gate must be >= 0"
        assert (gate <= 1.0).all(), "Gate must be <= 1"

    def test_advantage_shape(self, ugtc, obs):
        """``compute_advantages`` returns one advantage per input row."""
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones, gamma=0.99)
        assert adv.shape == (BATCH,), f"Expected ({BATCH},), got {adv.shape}"

    def test_advantage_finite(self, ugtc, obs):
        """Advantages contain no NaN or inf for random inputs."""
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones)
        assert torch.isfinite(adv).all(), "All advantages should be finite"

    def test_value_shape(self, ugtc, obs):
        """``get_value_ugtc`` returns a ``(BATCH,)`` tensor."""
        v = ugtc.get_value_ugtc(obs)
        assert v.shape == (BATCH,)

    def test_value_finite(self, ugtc, obs):
        """``get_value_ugtc`` output contains no NaN or inf."""
        v = ugtc.get_value_ugtc(obs)
        assert torch.isfinite(v).all()

    def test_gae_zero_reward(self, ugtc):
        """Zero rewards with no termination still give well-formed advantages.

        The advantages are not expected to be zero: with untrained critics the
        value differences alone contribute, so only shape and finiteness are
        checked.
        """
        T = 16
        obs = torch.randn(T, OBS_DIM)
        next_obs = torch.randn(T, OBS_DIM)
        rewards = torch.zeros(T)
        dones = torch.zeros(T)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones, gamma=0.99)
        assert adv.shape == (T,)
        assert torch.isfinite(adv).all(), "Zero-reward advantages should be finite"

    def test_done_masks(self, ugtc):
        """Done flags should prevent bootstrapping across episodes."""
        T = 4
        obs = torch.randn(T, OBS_DIM)
        next_obs = torch.randn(T, OBS_DIM)
        rewards = torch.ones(T)
        dones_all = torch.ones(T)  # every step terminates
        dones_none = torch.zeros(T)
        adv_all = ugtc.compute_advantages(obs, next_obs, rewards, dones_all)
        adv_none = ugtc.compute_advantages(obs, next_obs, rewards, dones_none)
        # These should differ because done=1 zeroes out future bootstrapping
        assert not torch.allclose(adv_all, adv_none)

    def test_parameter_count(self, ugtc):
        """``parameter_count`` reports both critics and a consistent total."""
        counts = ugtc.parameter_count()
        assert "fast_critic" in counts
        assert "slow_ensemble" in counts
        assert "total" in counts
        assert counts["total"] == counts["fast_critic"] + counts["slow_ensemble"]
        assert counts["total"] > 0

    def test_gate_stats_keys(self, ugtc, obs):
        """``get_gate_stats`` exposes all expected summary keys."""
        stats = ugtc.get_gate_stats(obs)
        for key in ("gate_mean", "gate_std", "gate_min", "gate_max", "sigma_ema"):
            assert key in stats, f"Missing key: {key}"

    def test_ema_updates_during_training(self, ugtc, obs):
        """Repeated gate computation in train mode leaves a valid ``sigma_ema``."""
        ugtc.train()
        for _ in range(10):
            ugtc.compute_gate(obs)
        updated_ema = ugtc.sigma_ema.item()
        # Whether the EMA actually moves depends on the observed uncertainty,
        # so this test only checks that the statistic stays a finite float.
        assert isinstance(updated_ema, float)
        assert torch.isfinite(ugtc.sigma_ema).all(), "sigma_ema should stay finite"

    def test_ema_frozen_in_eval(self, ugtc, obs):
        """In eval mode ``compute_gate`` must leave ``sigma_ema`` untouched."""
        ugtc.eval()
        initial_ema = ugtc.sigma_ema.item()
        for _ in range(10):
            ugtc.compute_gate(obs)
        assert ugtc.sigma_ema.item() == initial_ema, "EMA should not update in eval mode"

    def test_different_lambda_values(self):
        """Verify different lambda values produce different advantages."""
        # Same seed for both modules so the lambdas are the only intended difference;
        # with independent random inits the assertion would pass regardless of lambda.
        ugtc_low = self._build_seeded(lambda_fast=0.1, lambda_slow=0.2)
        ugtc_high = self._build_seeded(lambda_fast=0.8, lambda_slow=0.99)
        obs = torch.randn(16, OBS_DIM)
        next_obs = torch.randn(16, OBS_DIM)
        rewards = torch.randn(16)
        dones = torch.zeros(16)
        adv_low = ugtc_low.compute_advantages(obs, next_obs, rewards, dones)
        adv_high = ugtc_high.compute_advantages(obs, next_obs, rewards, dones)
        assert not torch.allclose(adv_low, adv_high)

    def test_beta_affects_gate_sharpness(self):
        """Higher beta should produce sharper gate transitions."""
        # Same seed for both modules so beta is the only intended difference.
        ugtc_low_beta = self._build_seeded(beta=0.1)
        ugtc_high_beta = self._build_seeded(beta=20.0)
        obs = torch.randn(64, OBS_DIM)
        gate_low, _, _ = ugtc_low_beta.compute_gate(obs)
        gate_high, _, _ = ugtc_high_beta.compute_gate(obs)
        # High beta should have more extreme values (closer to 0 or 1)
        extremity_low = ((gate_low - 0.5).abs()).mean()
        extremity_high = ((gate_high - 0.5).abs()).mean()
        assert extremity_high >= extremity_low, "Higher beta should produce sharper gate"

    @pytest.mark.parametrize("M", [1, 2, 5, 10])
    def test_ensemble_sizes(self, M):
        """The gate keeps its ``(BATCH,)`` shape for several ensemble sizes."""
        ugtc = UGTCModule(OBS_DIM, HIDDEN, M=M)
        obs = torch.randn(BATCH, OBS_DIM)
        gate, _, _ = ugtc.compute_gate(obs)
        assert gate.shape == (BATCH,)

    def test_no_grad_in_advantage_computation(self, ugtc, obs):
        """compute_advantages should not retain gradients on the output."""
        next_obs = torch.randn(BATCH, OBS_DIM)
        rewards = torch.randn(BATCH)
        dones = torch.zeros(BATCH)
        adv = ugtc.compute_advantages(obs, next_obs, rewards, dones)
        assert not adv.requires_grad
# ── GAE computation ────────────────────────────────────────────────────────────
class TestGAEComputation:
    """Closed-form checks of ``UGTCModule._compute_gae`` on one-step inputs."""

    def test_single_step_gae(self):
        """Single step: advantage = δ = r + γV(s') - V(s)."""
        gamma = 0.99
        reward, value, next_value = 1.0, 0.5, 0.5
        adv = UGTCModule._compute_gae(
            torch.tensor([reward]),
            torch.tensor([value]),
            torch.tensor([next_value]),
            torch.tensor([0.0]),
            gamma=gamma,
            lam=0.95,
        )
        td_error = reward + gamma * next_value - value
        assert abs(adv[0].item() - td_error) < 1e-5

    def test_terminal_step(self):
        """Done=1 should zero out future value bootstrap."""
        gamma = 0.99
        reward, value, done = 1.0, 0.0, 1.0
        next_value = 100.0  # deliberately large: must be masked by the done flag
        adv = UGTCModule._compute_gae(
            torch.tensor([reward]),
            torch.tensor([value]),
            torch.tensor([next_value]),
            torch.tensor([done]),
            gamma=gamma,
            lam=0.95,
        )
        # The bootstrap term is multiplied by (1 - done) = 0, leaving just r - V(s).
        masked_target = reward + gamma * next_value * (1.0 - done) - value
        assert abs(adv[0].item() - masked_target) < 1e-5