File size: 7,826 Bytes
d92d8cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""
Integration tests for UGTC algorithm wrappers (PPO, TD3, SAC, DDPG).

These tests verify that each integration:
    - Initializes without error
    - Produces valid losses (finite, not NaN)
    - Returns expected metric keys
    - Updates parameters (gradient flows)

Note: Full training runs are not tested here — see examples/ for that.
"""

import pytest
import torch
import numpy as np

from ugtc import UGTCPPO, UGTCTD3, UGTCSAC, UGTCDDPG
from ugtc.td3 import ReplayBuffer


# Shared sizes used by every agent and fixture in this module.
OBS_DIM, ACT_DIM = 8, 2  # observation / action vector lengths
BATCH = 32  # batch_size passed to update(); the replay buffer holds 2 * BATCH
HIDDEN = 32  # hidden size passed to each agent's constructor


def make_ppo_rollout(T=64, device="cpu"):
    """Build a synthetic PPO rollout of ``T`` random transitions.

    Args:
        T: Number of timesteps in the rollout.
        device: Torch device on which every returned tensor is allocated.

    Returns:
        Dict with ``obs`` and ``next_obs`` of shape (T, OBS_DIM), ``actions``
        of shape (T, ACT_DIM), and ``rewards``, ``dones``, ``log_probs`` of
        shape (T,). Values are drawn from ``torch.randn`` except ``dones``,
        which is all zeros (no terminal step in the rollout).
    """
    # Pass ``device`` to every factory call: previously the argument was
    # accepted but ignored, so all tensors silently stayed on CPU.
    # Entries are created in the same order as before so the torch RNG is
    # consumed identically on CPU.
    return {
        "obs": torch.randn(T, OBS_DIM, device=device),
        "actions": torch.randn(T, ACT_DIM, device=device),
        "rewards": torch.randn(T, device=device),
        "next_obs": torch.randn(T, OBS_DIM, device=device),
        "dones": torch.zeros(T, device=device),
        "log_probs": torch.randn(T, device=device),
    }


def make_replay_buffer():
    """Return a ReplayBuffer (capacity 1000) pre-filled with 2 * BATCH transitions.

    Observations and actions are float32 NumPy vectors, rewards are Python
    floats, all drawn from the global NumPy RNG; every transition is
    non-terminal (``done=False``).
    """
    replay = ReplayBuffer(OBS_DIM, ACT_DIM, capacity=1000)
    n_transitions = 2 * BATCH
    for _step in range(n_transitions):
        # Call arguments are evaluated left to right, so the RNG is consumed
        # in the order obs -> action -> reward -> next_obs.
        replay.add(
            np.random.randn(OBS_DIM).astype(np.float32),
            np.random.randn(ACT_DIM).astype(np.float32),
            float(np.random.randn()),
            np.random.randn(OBS_DIM).astype(np.float32),
            False,
        )
    return replay


# ── UGTC-PPO ──────────────────────────────────────────────────────────────────

class TestUGTCPPO:
    """Integration tests for the UGTC-PPO wrapper."""

    @pytest.fixture
    def agent(self):
        """Fresh small UGTCPPO agent (2 update epochs) for each test."""
        return UGTCPPO(OBS_DIM, ACT_DIM, hidden_dim=HIDDEN, epochs=2)

    def test_init(self, agent):
        """The policy and the UGTC component are constructed."""
        assert agent.policy is not None
        assert agent.ugtc is not None

    def test_select_action_shape(self, agent):
        """A (1, OBS_DIM) observation yields a (1, ACT_DIM) action and a (1,) log-prob."""
        obs = torch.randn(1, OBS_DIM)
        action, log_prob = agent.select_action(obs)
        assert action.shape == (1, ACT_DIM)
        assert log_prob.shape == (1,)

    def test_update_returns_dict(self, agent):
        """update() reports its metrics as a dict."""
        rollout = make_ppo_rollout()
        metrics = agent.update(rollout)
        assert isinstance(metrics, dict)

    def test_update_metrics_finite(self, agent):
        """Every reported metric is finite (no NaN / inf)."""
        rollout = make_ppo_rollout()
        metrics = agent.update(rollout)
        for key, val in metrics.items():
            assert np.isfinite(val), f"Metric {key} = {val} is not finite"

    def test_update_metrics_keys(self, agent):
        """The policy loss, fast value loss and gate mean are reported."""
        rollout = make_ppo_rollout()
        metrics = agent.update(rollout)
        assert "policy_loss" in metrics
        assert "fast_value_loss" in metrics
        assert "gate_mean" in metrics

    def test_parameters_update(self, agent):
        """Verify that an update step actually modifies parameters."""
        # detach() so the snapshot is a plain copy, not a clone node attached
        # to the live parameters' autograd graph.
        initial_params = [p.detach().clone() for p in agent.policy.parameters()]
        rollout = make_ppo_rollout()
        agent.update(rollout)
        updated_params = list(agent.policy.parameters())
        changed = any(
            not torch.allclose(i, u) for i, u in zip(initial_params, updated_params)
        )
        assert changed, "Policy parameters should change after update"

    def test_save_load(self, agent, tmp_path):
        """load() restores the policy parameters written by save()."""
        path = str(tmp_path / "ppo.pt")
        agent.save(path)
        saved = [p.detach().clone() for p in agent.policy.parameters()]
        # Perturb the live weights so a load() that restores nothing would be
        # caught; the original test only checked that no exception was raised.
        with torch.no_grad():
            for p in agent.policy.parameters():
                p.add_(1.0)
        agent.load(path)
        restored = list(agent.policy.parameters())
        assert len(restored) == len(saved)
        for before, after in zip(saved, restored):
            assert torch.allclose(before, after), "load() should restore policy parameters"


# ── UGTC-TD3 ──────────────────────────────────────────────────────────────────

class TestUGTCTD3:
    """Integration tests for the UGTC-TD3 wrapper."""

    @pytest.fixture
    def agent(self):
        """Fresh CPU UGTCTD3 agent for each test."""
        return UGTCTD3(OBS_DIM, ACT_DIM, hidden=HIDDEN, device="cpu")

    def test_init(self, agent):
        """The actor and the UGTC component are constructed."""
        assert agent.actor is not None
        assert agent.ugtc is not None

    def test_select_action_shape(self, agent):
        """A single noise-free action is a flat (ACT_DIM,) array."""
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        action = agent.select_action(obs, noise=0.0)
        assert action.shape == (ACT_DIM,)

    def test_update_returns_dict(self, agent):
        """update() reports its metrics as a dict."""
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert isinstance(metrics, dict)

    def test_critic_loss_finite(self, agent):
        """The critic loss from one update is finite."""
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert np.isfinite(metrics["critic_loss"])

    def test_actor_loss_after_delay(self, agent):
        """Actor loss should appear within policy_delay + 1 updates."""
        buf = make_replay_buffer()
        # The delayed actor step fires on only one of every ``policy_delay``
        # calls, and which call that is depends on how the internal counter
        # is indexed. Any policy_delay + 1 consecutive calls contain such a
        # step, but it need not be the last one, so inspect every call rather
        # than only the final metrics dict.
        history = [
            agent.update(buf, batch_size=BATCH)
            for _ in range(agent.policy_delay + 1)
        ]
        assert any("actor_loss" in m for m in history), (
            "No delayed actor update reported within policy_delay + 1 updates"
        )

    def test_action_clipped(self, agent):
        """Actions stay within [-max_action, max_action] even for large observations."""
        obs = np.random.randn(OBS_DIM).astype(np.float32) * 10
        action = agent.select_action(obs, noise=0.0)
        assert np.all(np.abs(action) <= agent.max_action + 1e-6)


# ── UGTC-SAC ──────────────────────────────────────────────────────────────────

class TestUGTCSAC:
    """Integration tests for the UGTC-SAC wrapper (automatic alpha tuning on)."""

    @pytest.fixture
    def agent(self):
        """Fresh CPU UGTCSAC agent with auto_alpha=True for each test."""
        return UGTCSAC(OBS_DIM, ACT_DIM, hidden=HIDDEN, auto_alpha=True, device="cpu")

    def test_init(self, agent):
        """The policy and the UGTC component are constructed."""
        assert agent.policy is not None
        assert agent.ugtc is not None

    def test_select_action_stochastic(self, agent):
        """Two stochastic samples for the same observation differ."""
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        a1 = agent.select_action(obs, deterministic=False)
        a2 = agent.select_action(obs, deterministic=False)
        assert a1.shape == (ACT_DIM,)
        # Stochastic: two samples should (almost certainly) differ
        assert not np.allclose(a1, a2)

    def test_select_action_deterministic(self, agent):
        """Deterministic selection is repeatable for the same observation."""
        obs = np.random.randn(OBS_DIM).astype(np.float32)
        a1 = agent.select_action(obs, deterministic=True)
        a2 = agent.select_action(obs, deterministic=True)
        assert np.allclose(a1, a2)

    def test_update_returns_dict(self, agent):
        """update() reports its metrics as a dict."""
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        assert isinstance(metrics, dict)

    def test_all_metrics_finite(self, agent):
        """Every reported metric is finite (no NaN / inf)."""
        buf = make_replay_buffer()
        metrics = agent.update(buf, batch_size=BATCH)
        for k, v in metrics.items():
            assert np.isfinite(v), f"Metric {k} = {v}"

    def test_auto_alpha_changes(self, agent):
        """With auto_alpha=True, a few updates should move alpha."""
        buf = make_replay_buffer()
        # Snapshot as a plain float so later in-place updates cannot alias it.
        initial_alpha = float(agent.alpha)
        for _ in range(5):
            agent.update(buf, batch_size=BATCH)
        assert isinstance(agent.alpha, float)
        # The original test computed initial_alpha but never compared it, so
        # an inert auto-alpha path would pass. NOTE(review): assumes the alpha
        # optimiser steps within these 5 updates — confirm against UGTCSAC.
        assert agent.alpha != initial_alpha, "alpha did not change with auto_alpha=True"


# ── UGTC-DDPG ─────────────────────────────────────────────────────────────────

class TestUGTCDDPG:
    """Integration tests for the UGTC-DDPG wrapper."""

    @pytest.fixture
    def agent(self):
        """Fresh CPU UGTCDDPG agent built with the shared small hidden size."""
        return UGTCDDPG(OBS_DIM, ACT_DIM, hidden=HIDDEN, device="cpu")

    def test_init(self, agent):
        """Both the actor and the UGTC component exist after construction."""
        for attr in ("actor", "ugtc"):
            assert getattr(agent, attr) is not None

    def test_select_action(self, agent):
        """A noise-free action for one observation is a flat (ACT_DIM,) array."""
        observation = np.random.randn(OBS_DIM).astype(np.float32)
        greedy = agent.select_action(observation, add_noise=False)
        assert greedy.shape == (ACT_DIM,)

    def test_update_finite(self, agent):
        """All metrics from a single update are finite."""
        report = agent.update(make_replay_buffer(), batch_size=BATCH)
        for name, value in report.items():
            assert np.isfinite(value), f"Metric {name} = {value}"

    def test_update_keys(self, agent):
        """A single update reports critic loss, actor loss and gate mean."""
        report = agent.update(make_replay_buffer(), batch_size=BATCH)
        for expected in ("critic_loss", "actor_loss", "gate_mean"):
            assert expected in report