""" Integration tests for UGTC algorithm wrappers (PPO, TD3, SAC, DDPG). These tests verify that each integration: - Initializes without error - Produces valid losses (finite, not NaN) - Returns expected metric keys - Updates parameters (gradient flows) Note: Full training runs are not tested here — see examples/ for that. """ import pytest import torch import numpy as np from ugtc import UGTCPPO, UGTCTD3, UGTCSAC, UGTCDDPG from ugtc.td3 import ReplayBuffer OBS_DIM = 8 ACT_DIM = 2 BATCH = 32 HIDDEN = 32 def make_ppo_rollout(T=64, device="cpu"): obs = torch.randn(T, OBS_DIM) actions = torch.randn(T, ACT_DIM) return { "obs": obs, "actions": actions, "rewards": torch.randn(T), "next_obs": torch.randn(T, OBS_DIM), "dones": torch.zeros(T), "log_probs": torch.randn(T), } def make_replay_buffer(): buf = ReplayBuffer(OBS_DIM, ACT_DIM, capacity=1000) for _ in range(BATCH * 2): obs = np.random.randn(OBS_DIM).astype(np.float32) action = np.random.randn(ACT_DIM).astype(np.float32) reward = float(np.random.randn()) next_obs = np.random.randn(OBS_DIM).astype(np.float32) done = False buf.add(obs, action, reward, next_obs, done) return buf # ── UGTC-PPO ────────────────────────────────────────────────────────────────── class TestUGTCPPO: @pytest.fixture def agent(self): return UGTCPPO(OBS_DIM, ACT_DIM, hidden_dim=HIDDEN, epochs=2) def test_init(self, agent): assert agent.policy is not None assert agent.ugtc is not None def test_select_action_shape(self, agent): obs = torch.randn(1, OBS_DIM) action, log_prob = agent.select_action(obs) assert action.shape == (1, ACT_DIM) assert log_prob.shape == (1,) def test_update_returns_dict(self, agent): rollout = make_ppo_rollout() metrics = agent.update(rollout) assert isinstance(metrics, dict) def test_update_metrics_finite(self, agent): rollout = make_ppo_rollout() metrics = agent.update(rollout) for key, val in metrics.items(): assert np.isfinite(val), f"Metric {key} = {val} is not finite" def test_update_metrics_keys(self, agent): rollout 
= make_ppo_rollout() metrics = agent.update(rollout) assert "policy_loss" in metrics assert "fast_value_loss" in metrics assert "gate_mean" in metrics def test_parameters_update(self, agent): """Verify that an update step actually modifies parameters.""" initial_params = [p.clone() for p in agent.policy.parameters()] rollout = make_ppo_rollout() agent.update(rollout) updated_params = list(agent.policy.parameters()) changed = any( not torch.allclose(i, u) for i, u in zip(initial_params, updated_params) ) assert changed, "Policy parameters should change after update" def test_save_load(self, agent, tmp_path): path = str(tmp_path / "ppo.pt") agent.save(path) agent.load(path) # ── UGTC-TD3 ────────────────────────────────────────────────────────────────── class TestUGTCTD3: @pytest.fixture def agent(self): return UGTCTD3(OBS_DIM, ACT_DIM, hidden=HIDDEN, device="cpu") def test_init(self, agent): assert agent.actor is not None assert agent.ugtc is not None def test_select_action_shape(self, agent): obs = np.random.randn(OBS_DIM).astype(np.float32) action = agent.select_action(obs, noise=0.0) assert action.shape == (ACT_DIM,) def test_update_returns_dict(self, agent): buf = make_replay_buffer() metrics = agent.update(buf, batch_size=BATCH) assert isinstance(metrics, dict) def test_critic_loss_finite(self, agent): buf = make_replay_buffer() metrics = agent.update(buf, batch_size=BATCH) assert np.isfinite(metrics["critic_loss"]) def test_actor_loss_after_delay(self, agent): """Actor loss should appear after policy_delay updates.""" buf = make_replay_buffer() # Run enough updates to trigger delayed actor update for _ in range(agent.policy_delay + 1): metrics = agent.update(buf, batch_size=BATCH) assert "actor_loss" in metrics def test_action_clipped(self, agent): obs = np.random.randn(OBS_DIM).astype(np.float32) * 10 action = agent.select_action(obs, noise=0.0) assert np.all(np.abs(action) <= agent.max_action + 1e-6) # ── UGTC-SAC 
# ── UGTC-SAC ──────────────────────────────────────────────────────────────────

class TestUGTCSAC:
    @pytest.fixture
    def agent(self):
        return UGTCSAC(OBS_DIM, ACT_DIM, hidden=HIDDEN, auto_alpha=True, device="cpu")

    def test_init(self, agent):
        assert agent.policy is not None
        assert agent.ugtc is not None

    def test_select_action_stochastic(self, agent):
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        a1 = agent.select_action(obs, deterministic=False)
        a2 = agent.select_action(obs, deterministic=False)
        assert a1.shape == (ACT_DIM,)
        # Stochastic: two samples should (almost certainly) differ.
        assert not np.allclose(a1, a2)

    def test_select_action_deterministic(self, agent):
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        a1 = agent.select_action(obs, deterministic=True)
        a2 = agent.select_action(obs, deterministic=True)
        assert a1.shape == (ACT_DIM,)
        assert np.allclose(a1, a2)

    def test_update_returns_dict(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert isinstance(metrics, dict)

    def test_all_metrics_finite(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        for k, v in metrics.items():
            assert np.isfinite(v), f"Metric {k} = {v}"

    def test_auto_alpha_changes(self, agent):
        """With auto_alpha=True, several updates should move the temperature."""
        buf = make_replay_buffer()
        initial_alpha = agent.alpha
        for _ in range(5):
            agent.update(buf, batch_size=BATCH)
        assert isinstance(agent.alpha, float)
        assert np.isfinite(agent.alpha)
        # The agent is freshly initialised and trained on random data. An
        # alpha that is still bit-identical after 5 updates suggests the
        # auto-tuning step is not being applied.
        assert agent.alpha != initial_alpha, "alpha did not change with auto_alpha=True"


# ── UGTC-DDPG ─────────────────────────────────────────────────────────────────

class TestUGTCDDPG:
    @pytest.fixture
    def agent(self):
        return UGTCDDPG(OBS_DIM, ACT_DIM, hidden=HIDDEN, device="cpu")

    def test_init(self, agent):
        assert agent.actor is not None
        assert agent.ugtc is not None

    def test_select_action(self, agent):
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        action = agent.select_action(obs, add_noise=False)
        assert action.shape == (ACT_DIM,)

    def test_update_finite(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        for k, v in metrics.items():
            assert np.isfinite(v), f"Metric {k} = {v}"

    def test_update_keys(self, agent):
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert "critic_loss" in metrics
        assert "actor_loss" in metrics
        assert "gate_mean" in metrics