"""
Integration tests for UGTC algorithm wrappers (PPO, TD3, SAC, DDPG).

These tests verify that each integration:
- Initializes without error
- Produces valid losses (finite, not NaN)
- Returns expected metric keys
- Updates parameters (gradient flows; currently checked for PPO)

Note: Full training runs are not tested here — see examples/ for that.
"""
|
|
| import pytest |
| import torch |
| import numpy as np |
|
|
| from ugtc import UGTCPPO, UGTCTD3, UGTCSAC, UGTCDDPG |
| from ugtc.td3 import ReplayBuffer |
|
|
|
|
OBS_DIM = 8  # observation size passed to every agent constructor
ACT_DIM = 2  # action size passed to every agent constructor
BATCH = 32  # minibatch size for off-policy updates
HIDDEN = 32  # hidden-layer width for all networks under test


def make_ppo_rollout(T=64, device="cpu"):
    """Build a synthetic on-policy rollout of ``T`` random transitions.

    Observations, actions, rewards and log-probs are i.i.d. standard-normal
    noise, and ``dones`` is all zeros. The data only exercises the PPO update
    path; it is not a meaningful trajectory.

    Args:
        T: Number of time steps in the rollout.
        device: Torch device on which every tensor is allocated.

    Returns:
        dict with keys ``obs`` (T, OBS_DIM), ``actions`` (T, ACT_DIM),
        ``rewards`` (T,), ``next_obs`` (T, OBS_DIM), ``dones`` (T,) and
        ``log_probs`` (T,).
    """

    def noise(*shape):
        # Previously `device` was ignored and tensors always landed on CPU.
        return torch.randn(*shape, device=device)

    # Dict literal order fixes the RNG draw order: obs, actions, rewards,
    # next_obs, log_probs.
    return {
        "obs": noise(T, OBS_DIM),
        "actions": noise(T, ACT_DIM),
        "rewards": noise(T),
        "next_obs": noise(T, OBS_DIM),
        "dones": torch.zeros(T, device=device),
        "log_probs": noise(T),
    }
|
|
|
|
def make_replay_buffer():
    """Return a ReplayBuffer (capacity 1000) holding ``2 * BATCH`` random transitions.

    Each transition has float32 observation/action vectors drawn from a
    standard normal, a Python-float reward, and ``done=False``, so the buffer
    never contains a terminal step.
    """
    replay = ReplayBuffer(OBS_DIM, ACT_DIM, capacity=1000)

    def gaussian_vec(dim):
        return np.random.randn(dim).astype(np.float32)

    for _step in range(2 * BATCH):
        # Draw order per step: state, action, reward, next_state.
        state = gaussian_vec(OBS_DIM)
        act = gaussian_vec(ACT_DIM)
        rew = float(np.random.randn())
        next_state = gaussian_vec(OBS_DIM)
        replay.add(state, act, rew, next_state, False)
    return replay
|
|
|
|
# ----------------------------------------------------------------- UGTCPPO
|
|
class TestUGTCPPO:
    """Smoke tests for the on-policy UGTCPPO wrapper."""

    @pytest.fixture
    def agent(self):
        """Small UGTCPPO agent (2 epochs per update) so each test stays fast."""
        return UGTCPPO(OBS_DIM, ACT_DIM, hidden_dim=HIDDEN, epochs=2)

    def test_init(self, agent):
        assert agent.policy is not None
        assert agent.ugtc is not None

    def test_select_action_shape(self, agent):
        """A (1, OBS_DIM) observation batch yields one action and one log-prob."""
        obs = torch.randn(1, OBS_DIM)
        action, log_prob = agent.select_action(obs)
        assert action.shape == (1, ACT_DIM)
        assert log_prob.shape == (1,)

    def test_update_returns_dict(self, agent):
        rollout = make_ppo_rollout()
        metrics = agent.update(rollout)
        assert isinstance(metrics, dict)

    def test_update_metrics_finite(self, agent):
        rollout = make_ppo_rollout()
        metrics = agent.update(rollout)
        for key, val in metrics.items():
            assert np.isfinite(val), f"Metric {key} = {val} is not finite"

    def test_update_metrics_keys(self, agent):
        rollout = make_ppo_rollout()
        metrics = agent.update(rollout)
        assert "policy_loss" in metrics
        assert "fast_value_loss" in metrics
        assert "gate_mean" in metrics

    def test_parameters_update(self, agent):
        """Verify that an update step actually modifies parameters."""
        # detach() so the snapshot is a plain tensor rather than an autograd node.
        initial_params = [p.detach().clone() for p in agent.policy.parameters()]
        rollout = make_ppo_rollout()
        agent.update(rollout)
        updated_params = [p.detach() for p in agent.policy.parameters()]
        changed = any(
            not torch.allclose(i, u) for i, u in zip(initial_params, updated_params)
        )
        assert changed, "Policy parameters should change after update"

    def test_save_load(self, agent, tmp_path):
        """Loading a checkpoint should restore the policy weights that were saved."""
        path = str(tmp_path / "ppo.pt")
        saved_params = [p.detach().clone() for p in agent.policy.parameters()]
        agent.save(path)
        # Perturb the live weights so that a no-op load() would be detected.
        with torch.no_grad():
            for p in agent.policy.parameters():
                p.add_(1.0)
        agent.load(path)
        restored_params = [p.detach() for p in agent.policy.parameters()]
        assert len(restored_params) == len(saved_params)
        for before, after in zip(saved_params, restored_params):
            assert torch.allclose(before, after), "load() did not restore saved policy weights"
|
|
|
|
# ----------------------------------------------------------------- UGTCTD3
|
|
class TestUGTCTD3:
    """Smoke tests for the UGTCTD3 wrapper (delayed actor updates, bounded actions)."""

    @pytest.fixture
    def agent(self):
        """Fresh CPU TD3 agent with a small hidden layer."""
        return UGTCTD3(OBS_DIM, ACT_DIM, hidden=HIDDEN, device="cpu")

    def test_init(self, agent):
        assert agent.actor is not None
        assert agent.ugtc is not None

    def test_select_action_shape(self, agent):
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        action = agent.select_action(obs, noise=0.0)
        assert action.shape == (ACT_DIM,)

    def test_update_returns_dict(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert isinstance(metrics, dict)

    def test_critic_loss_finite(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert np.isfinite(metrics["critic_loss"])

    def test_actor_loss_after_delay(self, agent):
        """At least one actor update (reporting ``actor_loss``) must occur
        within the first ``policy_delay + 1`` updates.

        The metrics of every update are inspected, not just the last one.
        Depending on how the internal step counter is aligned with
        ``policy_delay``, the final update may be a critic-only step.
        """
        buf = make_replay_buffer()
        history = [
            agent.update(buf, batch_size=BATCH) for _ in range(agent.policy_delay + 1)
        ]
        actor_losses = [m["actor_loss"] for m in history if "actor_loss" in m]
        assert actor_losses, "expected an actor update within policy_delay + 1 steps"
        assert all(np.isfinite(v) for v in actor_losses)

    def test_action_clipped(self, agent):
        """Large observations must not push the deterministic action out of range."""
        obs = np.random.randn(OBS_DIM).astype(np.float32) * 10
        action = agent.select_action(obs, noise=0.0)
        assert np.all(np.abs(action) <= agent.max_action + 1e-6)

    def test_noisy_action_clipped(self, agent):
        """Exploration noise far larger than the action range must still be clipped."""
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        action = agent.select_action(obs, noise=10.0)
        assert action.shape == (ACT_DIM,)
        assert np.all(np.abs(action) <= agent.max_action + 1e-6)
|
|
|
|
# ----------------------------------------------------------------- UGTCSAC
|
|
class TestUGTCSAC:
    """Smoke tests for the UGTCSAC wrapper with automatic entropy tuning."""

    @pytest.fixture
    def agent(self):
        """Fresh CPU SAC agent with ``auto_alpha`` enabled."""
        return UGTCSAC(OBS_DIM, ACT_DIM, hidden=HIDDEN, auto_alpha=True, device="cpu")

    def test_init(self, agent):
        assert agent.policy is not None
        assert agent.ugtc is not None

    def test_select_action_stochastic(self, agent):
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        a1 = agent.select_action(obs, deterministic=False)
        a2 = agent.select_action(obs, deterministic=False)
        assert a1.shape == (ACT_DIM,)
        assert a2.shape == (ACT_DIM,)
        # Two independent samples from a continuous policy should (almost surely) differ.
        assert not np.allclose(a1, a2)

    def test_select_action_deterministic(self, agent):
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        a1 = agent.select_action(obs, deterministic=True)
        a2 = agent.select_action(obs, deterministic=True)
        assert np.allclose(a1, a2)

    def test_update_returns_dict(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert isinstance(metrics, dict)

    def test_all_metrics_finite(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        for k, v in metrics.items():
            assert np.isfinite(v), f"Metric {k} = {v}"

    def test_auto_alpha_changes(self, agent):
        """With ``auto_alpha=True``, updates should move the entropy temperature."""
        buf = make_replay_buffer()
        initial_alpha = agent.alpha
        for _ in range(5):
            agent.update(buf, batch_size=BATCH)
        assert isinstance(agent.alpha, float)
        # Previously `initial_alpha` was captured but never compared, so a frozen
        # alpha would have passed this test.
        assert agent.alpha != initial_alpha, "alpha did not change with auto_alpha=True"
|
|
|
|
# ---------------------------------------------------------------- UGTCDDPG
|
|
class TestUGTCDDPG:
    """Smoke tests for the UGTCDDPG wrapper."""

    @pytest.fixture
    def agent(self):
        """Fresh CPU DDPG agent with a small hidden layer."""
        return UGTCDDPG(OBS_DIM, ACT_DIM, hidden=HIDDEN, device="cpu")

    def test_init(self, agent):
        assert agent.actor is not None
        assert agent.ugtc is not None

    def test_select_action(self, agent):
        """A single observation without exploration noise maps to one action vector."""
        state = np.random.randn(OBS_DIM).astype(np.float32)
        assert agent.select_action(state, add_noise=False).shape == (ACT_DIM,)

    def test_update_finite(self, agent):
        """Every metric reported by one update step is a finite number."""
        results = agent.update(make_replay_buffer(), batch_size=BATCH)
        for k, v in results.items():
            assert np.isfinite(v), f"Metric {k} = {v}"

    def test_update_keys(self, agent):
        """One update reports critic loss, actor loss and mean gate value."""
        results = agent.update(make_replay_buffer(), batch_size=BATCH)
        for expected in ("critic_loss", "actor_loss", "gate_mean"):
            assert expected in results
|
|