# ugtc/tests/test_integrations.py
# Author: Ekrem-the-second
# Commit d92d8cf: "Initial release: UGTC - Uncertainty-Gated Temporal Credit"
"""
Integration tests for UGTC algorithm wrappers (PPO, TD3, SAC, DDPG).
These tests verify that each integration:
- Initializes without error
- Produces valid losses (finite, not NaN)
- Returns expected metric keys
- Updates parameters (gradient flows)
Note: Full training runs are not tested here — see examples/ for that.
"""
import pytest
import torch
import numpy as np
from ugtc import UGTCPPO, UGTCTD3, UGTCSAC, UGTCDDPG
from ugtc.td3 import ReplayBuffer
# Small, fixed sizes shared by every agent and fixture in this module.
OBS_DIM, ACT_DIM = 8, 2  # observation / action vector sizes
BATCH = 32               # batch_size passed to agent.update(); buffers hold 2x this
HIDDEN = 32              # hidden-layer width passed to each agent's constructor
def make_ppo_rollout(T=64, device="cpu", *, obs_dim=None, act_dim=None):
    """Build a synthetic PPO rollout of ``T`` random transitions.

    Args:
        T: Number of timesteps in the rollout.
        device: Torch device on which every returned tensor is allocated.
        obs_dim: Observation size; ``None`` means the module-level ``OBS_DIM``.
        act_dim: Action size; ``None`` means the module-level ``ACT_DIM``.

    Returns:
        Dict with ``obs``/``next_obs`` of shape ``(T, obs_dim)``, ``actions``
        of shape ``(T, act_dim)``, and ``rewards``/``dones``/``log_probs`` of
        shape ``(T,)``. ``dones`` is all zeros (no episode ends inside the
        rollout); every other tensor is standard-normal noise.
    """
    obs_dim = OBS_DIM if obs_dim is None else obs_dim
    act_dim = ACT_DIM if act_dim is None else act_dim
    # Draw order (obs, actions, rewards, next_obs, log_probs) matches the
    # original helper, so seeded runs produce the same values on CPU.
    return {
        "obs": torch.randn(T, obs_dim, device=device),
        "actions": torch.randn(T, act_dim, device=device),
        "rewards": torch.randn(T, device=device),
        "next_obs": torch.randn(T, obs_dim, device=device),
        "dones": torch.zeros(T, device=device),
        "log_probs": torch.randn(T, device=device),
    }
def make_replay_buffer():
    """Return a ``ReplayBuffer`` pre-filled with ``2 * BATCH`` random transitions.

    Values come from the global NumPy RNG, drawn per transition in the order
    obs, action, reward, next_obs; every transition is stored with done=False.
    """
    def _noise(dim):
        # float32 standard-normal vector of length ``dim``.
        return np.random.randn(dim).astype(np.float32)

    replay = ReplayBuffer(OBS_DIM, ACT_DIM, capacity=1000)
    for _step in range(2 * BATCH):
        # Arguments evaluate left to right, preserving the RNG draw order.
        replay.add(
            _noise(OBS_DIM),
            _noise(ACT_DIM),
            float(np.random.randn()),
            _noise(OBS_DIM),
            False,
        )
    return replay
# ── UGTC-PPO ──────────────────────────────────────────────────────────────────
class TestUGTCPPO:
    """Integration tests for the ``UGTCPPO`` wrapper on synthetic rollouts."""

    @pytest.fixture
    def agent(self):
        """A freshly constructed small PPO agent (2 update epochs) per test."""
        return UGTCPPO(OBS_DIM, ACT_DIM, hidden_dim=HIDDEN, epochs=2)

    def test_init(self, agent):
        """Policy network and UGTC module are constructed."""
        assert agent.policy is not None
        assert agent.ugtc is not None

    def test_select_action_shape(self, agent):
        """A batch of one observation yields one action and one log-prob."""
        obs = torch.randn(1, OBS_DIM)
        action, log_prob = agent.select_action(obs)
        assert action.shape == (1, ACT_DIM)
        assert log_prob.shape == (1,)

    def test_update_returns_dict(self, agent):
        metrics = agent.update(make_ppo_rollout())
        assert isinstance(metrics, dict)

    def test_update_metrics_finite(self, agent):
        metrics = agent.update(make_ppo_rollout())
        for key, val in metrics.items():
            assert np.isfinite(val), f"Metric {key} = {val} is not finite"

    def test_update_metrics_keys(self, agent):
        metrics = agent.update(make_ppo_rollout())
        assert "policy_loss" in metrics
        assert "fast_value_loss" in metrics
        assert "gate_mean" in metrics

    def test_parameters_update(self, agent):
        """Verify that an update step actually modifies parameters."""
        # detach() so the snapshot is a plain copy outside the autograd graph.
        initial_params = [p.detach().clone() for p in agent.policy.parameters()]
        agent.update(make_ppo_rollout())
        changed = any(
            not torch.allclose(before, after)
            for before, after in zip(initial_params, agent.policy.parameters())
        )
        assert changed, "Policy parameters should change after update"

    def test_save_load(self, agent, tmp_path):
        """A checkpoint loaded into a fresh agent restores the policy weights.

        The fresh agent has its own random initialisation, so the equality
        check only passes if ``load`` actually restored what ``save`` wrote.
        """
        path = str(tmp_path / "ppo.pt")
        agent.save(path)
        restored = UGTCPPO(OBS_DIM, ACT_DIM, hidden_dim=HIDDEN, epochs=2)
        restored.load(path)
        saved = list(agent.policy.parameters())
        loaded = list(restored.policy.parameters())
        assert len(saved) == len(loaded)
        for original, reloaded in zip(saved, loaded):
            assert torch.equal(original, reloaded), "Loaded weights differ from saved"
# ── UGTC-TD3 ──────────────────────────────────────────────────────────────────
class TestUGTCTD3:
    """Integration tests for the ``UGTCTD3`` wrapper trained from a replay buffer."""

    @pytest.fixture
    def agent(self):
        """A freshly constructed CPU TD3 agent per test."""
        return UGTCTD3(OBS_DIM, ACT_DIM, hidden=HIDDEN, device="cpu")

    def test_init(self, agent):
        """Actor network and UGTC module are constructed."""
        assert agent.actor is not None
        assert agent.ugtc is not None

    def test_select_action_shape(self, agent):
        """One unbatched observation yields one ``(ACT_DIM,)`` action."""
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        action = agent.select_action(obs, noise=0.0)
        assert action.shape == (ACT_DIM,)

    def test_update_returns_dict(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert isinstance(metrics, dict)

    def test_critic_loss_finite(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert np.isfinite(metrics["critic_loss"])

    def test_actor_loss_after_delay(self, agent):
        """Actor loss should appear within ``policy_delay + 1`` updates.

        The actor is updated only on some calls (delayed policy update), and
        that call is not necessarily the last one in the loop. So the test
        inspects every call's metrics, not just the final dict.
        """
        buf = make_replay_buffer()
        history = [
            agent.update(buf, batch_size=BATCH)
            for _ in range(agent.policy_delay + 1)
        ]
        actor_steps = [m for m in history if "actor_loss" in m]
        assert actor_steps, "Expected at least one delayed actor update"
        assert all(np.isfinite(m["actor_loss"]) for m in actor_steps)

    def test_action_clipped(self, agent):
        """Noiseless actions for large observations stay within +/- max_action."""
        obs = np.random.randn(OBS_DIM).astype(np.float32) * 10
        action = agent.select_action(obs, noise=0.0)
        # Small tolerance absorbs float32 rounding at the bound.
        assert np.all(np.abs(action) <= agent.max_action + 1e-6)
# ── UGTC-SAC ──────────────────────────────────────────────────────────────────
class TestUGTCSAC:
    """Integration tests for the ``UGTCSAC`` wrapper with automatic temperature tuning."""

    @pytest.fixture
    def agent(self):
        """A freshly constructed CPU SAC agent with ``auto_alpha=True`` per test."""
        return UGTCSAC(OBS_DIM, ACT_DIM, hidden=HIDDEN, auto_alpha=True, device="cpu")

    def test_init(self, agent):
        """Policy network and UGTC module are constructed."""
        assert agent.policy is not None
        assert agent.ugtc is not None

    def test_select_action_stochastic(self, agent):
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        a1 = agent.select_action(obs, deterministic=False)
        a2 = agent.select_action(obs, deterministic=False)
        assert a1.shape == (ACT_DIM,)
        # Stochastic: two samples should (almost certainly) differ
        assert not np.allclose(a1, a2)

    def test_select_action_deterministic(self, agent):
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        a1 = agent.select_action(obs, deterministic=True)
        a2 = agent.select_action(obs, deterministic=True)
        assert np.allclose(a1, a2)

    def test_update_returns_dict(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert isinstance(metrics, dict)

    def test_all_metrics_finite(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        for k, v in metrics.items():
            assert np.isfinite(v), f"Metric {k} = {v}"

    def test_auto_alpha_changes(self, agent):
        """With ``auto_alpha=True``, a few updates should move the temperature."""
        buf = make_replay_buffer()
        # Snapshot as a plain float so the baseline cannot alias agent state.
        initial_alpha = float(agent.alpha)
        for _ in range(5):
            agent.update(buf, batch_size=BATCH)
        assert isinstance(agent.alpha, float)
        assert agent.alpha != initial_alpha, "alpha should adapt when auto_alpha=True"
# ── UGTC-DDPG ─────────────────────────────────────────────────────────────────
class TestUGTCDDPG:
    """Integration tests for the ``UGTCDDPG`` wrapper trained from a replay buffer."""

    @pytest.fixture
    def agent(self):
        """A freshly constructed CPU DDPG agent per test."""
        return UGTCDDPG(OBS_DIM, ACT_DIM, hidden=HIDDEN, device="cpu")

    def test_init(self, agent):
        """Actor network and UGTC module are constructed."""
        for component in (agent.actor, agent.ugtc):
            assert component is not None

    def test_select_action(self, agent):
        """A noiseless action for one unbatched observation has shape ``(ACT_DIM,)``."""
        single_obs = np.random.randn(OBS_DIM).astype(np.float32)
        assert agent.select_action(single_obs, add_noise=False).shape == (ACT_DIM,)

    def test_update_finite(self, agent):
        """Every metric returned by one update is finite."""
        metrics = agent.update(make_replay_buffer(), batch_size=BATCH)
        for k, v in metrics.items():
            assert np.isfinite(v), f"Metric {k} = {v}"

    def test_update_keys(self, agent):
        """One update reports critic, actor and gate statistics."""
        metrics = agent.update(make_replay_buffer(), batch_size=BATCH)
        for expected_key in ("critic_loss", "actor_loss", "gate_mean"):
            assert expected_key in metrics