| """归一化层:RMSNorm"""
|
|
|
|
|
| import torch
|
| import torch.nn as nn
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (RMSNorm).

    Rescales the last dimension of the input to (approximately) unit RMS,
    then applies a learnable per-feature gain. Unlike LayerNorm, no mean
    is subtracted and there is no bias term.
    """

    def __init__(self, dim, eps=1e-5):
        """Create the layer.

        Args:
            dim: size of the feature (last) dimension to normalize over.
            eps: small constant added under the square root for numerical
                stability (guards against division by zero).
        """
        super().__init__()
        self.eps = float(eps)
        # Per-feature gain, initialized to ones so the freshly constructed
        # layer acts as a pure normalization.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the gain.

        Args:
            x: tensor whose last dimension has size ``dim``, e.g. shape
               (batch_size, seq_len, dim); any leading shape is accepted.

        Returns:
            Tensor of the same shape as ``x``.
        """
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        scale = torch.sqrt(mean_square + self.eps)
        return x / scale * self.weight
|
|
|
|
|
| if __name__ == "__main__":
|
| print("=" * 60)
|
| print("RMSNorm 测试")
|
| print("=" * 60)
|
|
|
|
|
| dim = 32
|
| norm = RMSNorm(dim=dim, eps=1e-5)
|
| print("\n1. 创建 RMSNorm 层")
|
| print(f" 维度: {dim}")
|
| print(f" eps: {norm.eps}")
|
| print(f" 权重形状: {norm.weight.shape}")
|
| print(f" 权重初始值(前5个): {norm.weight[:5]}")
|
|
|
|
|
| batch_size = 2
|
| seq_len = 10
|
| x = torch.randn(batch_size, seq_len, dim)
|
| print("\n2. 创建测试输入")
|
| print(f" 输入形状: {x.shape}")
|
| print(" 输入统计:")
|
| print(f" - 均值: {x.mean().item():.4f}")
|
| print(f" - 标准差: {x.std().item():.4f}")
|
| print(f" - 最小值: {x.min().item():.4f}")
|
| print(f" - 最大值: {x.max().item():.4f}")
|
|
|
|
|
| output = norm(x)
|
| print("\n3. 前向传播结果")
|
| print(f" 输出形状: {output.shape}")
|
| print(" 输出统计:")
|
| print(f" - 均值: {output.mean().item():.4f}")
|
| print(f" - 标准差: {output.std().item():.4f}")
|
|
|
|
|
| print("\n4. 验证归一化效果")
|
|
|
| rms_per_sample = torch.sqrt(torch.mean(output.pow(2), dim=-1))
|
| print(" 每个样本的 RMS(归一化后):")
|
| print(f" - 样本1: {rms_per_sample[0].mean().item():.4f}")
|
| print(f" - 样本2: {rms_per_sample[1].mean().item():.4f}")
|
| print(f" - 平均 RMS: {rms_per_sample.mean().item():.4f}")
|
|
|
|
|
| print("\n5. 验证参数可学习性")
|
| print(f" 权重是否为 Parameter: {isinstance(norm.weight, nn.Parameter)}")
|
| print(f" 权重是否需要梯度: {norm.weight.requires_grad}")
|
|
|
|
|
| print("\n6. 测试梯度计算")
|
| loss = output.sum()
|
| loss.backward()
|
| print(f" 权重梯度是否存在: {norm.weight.grad is not None}")
|
| if norm.weight.grad is not None:
|
| print(f" 权重梯度形状: {norm.weight.grad.shape}")
|
| print(f" 权重梯度统计:")
|
| print(f" - 均值: {norm.weight.grad.mean().item():.4f}")
|
| print(f" - 标准差: {norm.weight.grad.std().item():.4f}")
|
|
|
|
|
| print("\n7. 测试不同输入形状")
|
| test_cases = [
|
| (1, 5, dim),
|
| (4, 20, dim),
|
| (1, 1, dim),
|
| ]
|
|
|
| for i, shape in enumerate(test_cases, 1):
|
| x_test = torch.randn(*shape)
|
| output_test = norm(x_test)
|
| print(f" 测试 {i}: 输入形状 {shape} -> 输出形状 {output_test.shape} ✓")
|
|
|
|
|
| print("\n8. 验证数值稳定性")
|
|
|
| x_small = torch.randn(1, 1, dim) * 1e-6
|
| output_small = norm(x_small)
|
| print(
|
| f" 极小输入测试: 输入范围 [{x_small.min().item():.2e}, {x_small.max().item():.2e}]"
|
| )
|
| print(f" 输出是否包含 NaN: {torch.isnan(output_small).any().item()}")
|
| print(f" 输出是否包含 Inf: {torch.isinf(output_small).any().item()}")
|
|
|
| print("\n" + "=" * 60)
|
| print("所有测试完成!")
|
| print("=" * 60)
|
|
|