"""tests/test_model.py —— DQNNetwork 单元测试
覆盖:
* 正向传播输出维度
* 不同 grid_size 的输出形状
* eval/train 模式切换
* 权重初始化(Conv → Kaiming,Linear → Xavier)
"""
from __future__ import annotations
import sys
from pathlib import Path
import pytest
import torch
import torch.nn as nn
# model.py lives under src/, which is not an installable package, so src/ is
# pushed to the front of sys.path (once) before the model imports below.
_SRC = Path(__file__).resolve().parent.parent.joinpath("src")
if sys.path.count(str(_SRC)) == 0:
    sys.path.insert(0, str(_SRC))
from model import DQNNetwork # noqa: E402
from model import DuelingDQNNetwork # noqa: E402
# ---------------------------------------------------------------------------
# 夹具
# ---------------------------------------------------------------------------
@pytest.fixture
def net5() -> DQNNetwork:
    """Provide a ``DQNNetwork`` built for a 5x5 grid (small, for fast tests)."""
    network = DQNNetwork(grid_size=5)
    return network
@pytest.fixture
def net10() -> DQNNetwork:
    """Provide a ``DQNNetwork`` built for a 10x10 grid.

    The original note states this matches the default training configuration.
    """
    network = DQNNetwork(grid_size=10)
    return network
# ---------------------------------------------------------------------------
# 正向传播:输出维度
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestForwardShape:
    """Forward-pass checks on the shape and dtype of ``DQNNetwork`` outputs.

    Every test feeds a random tensor of shape (batch, 4, grid, grid) and
    inspects the tensor returned by calling the network.
    """

    def test_output_shape_5x5(self, net5: DQNNetwork) -> None:
        """A (32, 4, 5, 5) input yields a (32, 4) output."""
        out = net5(torch.randn(32, 4, 5, 5))
        assert out.shape == (32, 4), f"期望 (32, 4),实际 {out.shape}"

    def test_output_shape_10x10(self, net10: DQNNetwork) -> None:
        """A (16, 4, 10, 10) input yields a (16, 4) output."""
        out = net10(torch.randn(16, 4, 10, 10))
        assert out.shape == (16, 4), f"期望 (16, 4),实际 {out.shape}"

    def test_batch_size_1(self, net5: DQNNetwork) -> None:
        """A single-sample batch is accepted and yields a (1, 4) output."""
        single_sample = torch.randn(1, 4, 5, 5)
        result = net5(single_sample)
        assert result.shape == (1, 4)

    def test_custom_num_actions(self) -> None:
        """With ``num_actions=8`` the second output dimension is 8."""
        network = DQNNetwork(grid_size=6, num_actions=8)
        result = network(torch.randn(4, 4, 6, 6))
        assert result.shape == (4, 8)

    def test_output_dtype_float32(self, net5: DQNNetwork) -> None:
        """The output tensor has dtype ``torch.float32``."""
        result = net5(torch.randn(2, 4, 5, 5))
        assert result.dtype == torch.float32
# ---------------------------------------------------------------------------
# 权重初始化
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestWeightInit:
    """Checks on the parameters of a freshly constructed ``DQNNetwork``.

    Conv2d/Linear biases must be all zero and Conv2d/Linear weights must not
    be all zero. Each test first requires that the network contains at least
    one layer of the kind under test, so the per-layer loop cannot pass
    vacuously when no such layer exists.
    """

    @staticmethod
    def _layers_of(net: nn.Module, kind: type) -> list[nn.Module]:
        """Return all sub-modules of *net* that are instances of *kind*.

        Fails the calling test if no such sub-module is found.
        """
        found = [m for m in net.modules() if isinstance(m, kind)]
        assert found, f"no {kind.__name__} layer found in {type(net).__name__}"
        return found

    def test_conv_bias_zeros(self, net5: DQNNetwork) -> None:
        """Every Conv2d layer that has a bias has it set to all zeros."""
        for m in self._layers_of(net5, nn.Conv2d):
            # Layers built without a bias have nothing to check.
            if m.bias is not None:
                assert torch.all(m.bias == 0.0), "Conv bias 应初始化为 0"

    def test_linear_bias_zeros(self, net5: DQNNetwork) -> None:
        """Every Linear layer that has a bias has it set to all zeros."""
        for m in self._layers_of(net5, nn.Linear):
            # Layers built without a bias have nothing to check.
            if m.bias is not None:
                assert torch.all(m.bias == 0.0), "Linear bias 应初始化为 0"

    def test_conv_weights_not_all_zeros(self, net5: DQNNetwork) -> None:
        """No Conv2d layer has an all-zero weight tensor."""
        for m in self._layers_of(net5, nn.Conv2d):
            assert not torch.all(m.weight == 0.0), "Conv 权重不应全为 0"

    def test_linear_weights_not_all_zeros(self, net5: DQNNetwork) -> None:
        """No Linear layer has an all-zero weight tensor."""
        for m in self._layers_of(net5, nn.Linear):
            assert not torch.all(m.weight == 0.0), "Linear 权重不应全为 0"
# ---------------------------------------------------------------------------
# eval / train 模式
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestModeSwitch:
    """Checks for switching ``DQNNetwork`` between eval and train mode."""

    def test_eval_mode(self, net5: DQNNetwork) -> None:
        """``eval()`` leaves the module with ``training`` set to False."""
        net5.eval()
        assert not net5.training

    def test_train_mode(self, net5: DQNNetwork) -> None:
        """``train()`` after ``eval()`` sets ``training`` back to True."""
        net5.eval()
        net5.train()
        assert net5.training

    def test_inference_no_grad(self, net5: DQNNetwork) -> None:
        """A forward pass under ``torch.no_grad()`` works and is untracked."""
        net5.eval()
        x = torch.randn(1, 4, 5, 5)
        with torch.no_grad():
            out = net5(x)
        # The result must still be a usable tensor of the expected shape ...
        assert out.shape == (1, 4)
        # ... and, having been computed under no_grad, must not be attached
        # to the autograd graph. Without this check the test would pass even
        # if gradient tracking were left on.
        assert not out.requires_grad
# ---------------------------------------------------------------------------
# 参数量回归(防止意外改动网络结构)
# ---------------------------------------------------------------------------
@pytest.mark.unit
def test_param_count_5x5() -> None:
    """Sanity-check the total parameter count of a 5x5 ``DQNNetwork``.

    No exact value is pinned; the count only has to lie strictly between
    ten thousand and ten million.
    """
    lower_bound, upper_bound = 10_000, 10_000_000
    net = DQNNetwork(grid_size=5, input_channels=4, num_actions=4)
    total = 0
    for param in net.parameters():
        total += param.numel()
    assert lower_bound < total < upper_bound, f"参数量异常:{total}"
# ---------------------------------------------------------------------------
# DuelingDQNNetwork 夹具
# ---------------------------------------------------------------------------
@pytest.fixture
def dueling5() -> DuelingDQNNetwork:
    """Provide a ``DuelingDQNNetwork`` built for a 5x5 grid (small, for fast tests)."""
    network = DuelingDQNNetwork(grid_size=5)
    return network
@pytest.fixture
def dueling10() -> DuelingDQNNetwork:
    """Provide a ``DuelingDQNNetwork`` built for a 10x10 grid.

    The original note states this matches the default training configuration.
    """
    network = DuelingDQNNetwork(grid_size=10)
    return network
# ---------------------------------------------------------------------------
# DuelingDQNNetwork:正向传播输出维度
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDuelingForwardShape:
    """Forward-pass checks on the shape and dtype of ``DuelingDQNNetwork`` outputs.

    Every test feeds a random tensor of shape (batch, 4, grid, grid) and
    inspects the tensor returned by calling the network.
    """

    def test_output_shape_5x5(self, dueling5: DuelingDQNNetwork) -> None:
        """A (32, 4, 5, 5) input yields a (32, 4) output."""
        out = dueling5(torch.randn(32, 4, 5, 5))
        assert out.shape == (32, 4), f"期望 (32, 4),实际 {out.shape}"

    def test_output_shape_10x10(self, dueling10: DuelingDQNNetwork) -> None:
        """A (16, 4, 10, 10) input yields a (16, 4) output."""
        out = dueling10(torch.randn(16, 4, 10, 10))
        assert out.shape == (16, 4), f"期望 (16, 4),实际 {out.shape}"

    def test_batch_size_1(self, dueling5: DuelingDQNNetwork) -> None:
        """A single-sample batch is accepted and yields a (1, 4) output."""
        single_sample = torch.randn(1, 4, 5, 5)
        result = dueling5(single_sample)
        assert result.shape == (1, 4)

    def test_output_dtype_float32(self, dueling5: DuelingDQNNetwork) -> None:
        """The output tensor has dtype ``torch.float32``."""
        result = dueling5(torch.randn(2, 4, 5, 5))
        assert result.dtype == torch.float32
# ---------------------------------------------------------------------------
# DuelingDQNNetwork:V+A 分解核心语义
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDuelingDecomposition:
    """Checks the value/advantage decomposition of ``DuelingDQNNetwork``.

    The tests reach into the network's ``conv``, ``flatten``,
    ``value_stream`` and ``advantage_stream`` attributes to rebuild the
    intermediate tensors by hand.
    """

    def test_value_advantage_decomposition(self, dueling5: DuelingDQNNetwork) -> None:
        """Call value_stream / advantage_stream directly and check output shapes."""
        x = torch.randn(8, 4, 5, 5)
        feat = dueling5.flatten(dueling5.conv(x))
        V = dueling5.value_stream(feat)  # expected (8, 1)
        A = dueling5.advantage_stream(feat)  # expected (8, 4)
        assert V.shape == (8, 1), f"value_stream 输出形状错误:{V.shape}"
        assert A.shape == (8, 4), f"advantage_stream 输出形状错误:{A.shape}"

    def test_mean_subtraction(self, dueling5: DuelingDQNNetwork) -> None:
        """Check that the network output equals V + A - mean(A).

        As a consequence the per-row mean of Q must equal V.
        """
        # The expected value and the actual output come from two separate
        # forward passes. Eval mode keeps them comparable even if the network
        # contains layers that behave stochastically in train mode (e.g.
        # dropout), which would otherwise make this test flaky.
        dueling5.eval()
        x = torch.randn(4, 4, 5, 5)
        feat = dueling5.flatten(dueling5.conv(x))
        V = dueling5.value_stream(feat)  # checked against the row means of Q below
        A = dueling5.advantage_stream(feat)
        Q_expected = V + A - A.mean(dim=1, keepdim=True)
        Q_actual = dueling5(x)
        assert torch.allclose(Q_actual, Q_expected, atol=1e-5), \
            "Q 值与 V+A-mean(A) 不匹配"
        # Subtracting mean(A) makes the advantages sum to zero per row, so the
        # row mean of Q collapses to V.
        assert torch.allclose(Q_actual.mean(dim=1), V.squeeze(1), atol=1e-5), \
            "Q 的行均值应等于 V(s)"
# ---------------------------------------------------------------------------
# DuelingDQNNetwork:权重初始化
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDuelingWeightInit:
    """Checks on the parameters of a freshly constructed ``DuelingDQNNetwork``.

    Conv2d/Linear biases must be all zero and Conv2d/Linear weights must not
    be all zero. Each test first requires that the network contains at least
    one layer of the kind under test, so the per-layer loop cannot pass
    vacuously when no such layer exists.
    """

    @staticmethod
    def _layers_of(net: nn.Module, kind: type) -> list[nn.Module]:
        """Return all sub-modules of *net* that are instances of *kind*.

        Fails the calling test if no such sub-module is found.
        """
        found = [m for m in net.modules() if isinstance(m, kind)]
        assert found, f"no {kind.__name__} layer found in {type(net).__name__}"
        return found

    def test_conv_bias_zeros(self, dueling5: DuelingDQNNetwork) -> None:
        """Every Conv2d layer that has a bias has it set to all zeros."""
        for m in self._layers_of(dueling5, nn.Conv2d):
            # Layers built without a bias have nothing to check.
            if m.bias is not None:
                assert torch.all(m.bias == 0.0), "Conv bias 应初始化为 0"

    def test_linear_bias_zeros(self, dueling5: DuelingDQNNetwork) -> None:
        """Every Linear layer that has a bias has it set to all zeros."""
        for m in self._layers_of(dueling5, nn.Linear):
            # Layers built without a bias have nothing to check.
            if m.bias is not None:
                assert torch.all(m.bias == 0.0), "Linear bias 应初始化为 0"

    def test_conv_weights_not_all_zeros(self, dueling5: DuelingDQNNetwork) -> None:
        """No Conv2d layer has an all-zero weight tensor."""
        for m in self._layers_of(dueling5, nn.Conv2d):
            assert not torch.all(m.weight == 0.0), "Conv 权重不应全为 0"

    def test_linear_weights_not_all_zeros(self, dueling5: DuelingDQNNetwork) -> None:
        """No Linear layer has an all-zero weight tensor."""
        for m in self._layers_of(dueling5, nn.Linear):
            assert not torch.all(m.weight == 0.0), "Linear 权重不应全为 0"
# ---------------------------------------------------------------------------
# DuelingDQNNetwork:参数量 & 结构对比
# ---------------------------------------------------------------------------
@pytest.mark.unit
def test_param_count_dueling_5x5() -> None:
    """Sanity-check the total parameter count of a 5x5 ``DuelingDQNNetwork``.

    The count only has to lie strictly between ten thousand and ten million.
    """
    lower_bound, upper_bound = 10_000, 10_000_000
    net = DuelingDQNNetwork(grid_size=5, input_channels=4, num_actions=4)
    total = 0
    for param in net.parameters():
        total += param.numel()
    assert lower_bound < total < upper_bound, f"Dueling 参数量异常:{total}"
@pytest.mark.unit
def test_dueling_vs_dqn_same_conv() -> None:
    """Both networks hold the same number of Conv2d parameters at grid_size=5.

    The two architectures are expected to share the same convolutional
    backbone for a given grid size.
    """

    def conv_param_count(net: nn.Module) -> int:
        # Sum the element counts of all parameters owned by Conv2d sub-modules.
        count = 0
        for module in net.modules():
            if isinstance(module, nn.Conv2d):
                count += sum(param.numel() for param in module.parameters())
        return count

    dqn_conv_params = conv_param_count(DQNNetwork(grid_size=5))
    dueling_conv_params = conv_param_count(DuelingDQNNetwork(grid_size=5))
    assert dqn_conv_params == dueling_conv_params, (
        f"Conv 参数量不一致:DQN={dqn_conv_params}, Dueling={dueling_conv_params}"
    )
|