# NOTE(review): the three lines below were Hugging Face upload-page residue
# accidentally pasted into the source; kept as comments so the file parses.
# LiManshu's picture
# Add files using upload-large-folder tool
# bf6be45 verified
"""归一化层:RMSNorm"""
# 2026-01-22
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization.

    Rescales each feature vector by the reciprocal of its RMS over the
    last dimension, then applies a learnable per-feature gain (the γ
    parameter). Unlike LayerNorm, no mean is subtracted and no bias is
    added.
    """

    def __init__(self, dim, eps=1e-5):
        """
        Args:
            dim: size of the last (feature) dimension to normalize over.
            eps: small stability constant added under the square root to
                avoid division by zero; coerced to float in case a config
                loader (e.g. YAML) parsed it as a string.
        """
        super().__init__()
        self.eps = float(eps)
        # Learnable gain, initialized to ones so the layer starts out as a
        # pure normalization. Registered as nn.Parameter so the optimizer
        # updates it during training.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` over its last dimension.

        Args:
            x: tensor of shape ``(..., dim)``, e.g. (batch, seq_len, dim).

        Returns:
            Tensor of the same shape: each feature vector divided by its
            RMS and scaled by the learned weight.
        """
        # Mean of squares over the feature dimension, kept for broadcasting.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        # RMS with eps inside the root for numerical stability.
        scale = (mean_sq + self.eps).sqrt()
        return x / scale * self.weight
if __name__ == "__main__":
    # Manual smoke test: exercises construction, the forward pass, gradient
    # flow, variable input shapes, and numerical stability of RMSNorm.
    print("=" * 60)
    print("RMSNorm 测试")
    print("=" * 60)

    # 1. Create the RMSNorm layer.
    dim = 32
    norm = RMSNorm(dim=dim, eps=1e-5)
    print("\n1. 创建 RMSNorm 层")
    print(f" 维度: {dim}")
    print(f" eps: {norm.eps}")
    print(f" 权重形状: {norm.weight.shape}")
    print(f" 权重初始值(前5个): {norm.weight[:5]}")

    # 2. Create a random test input.
    batch_size = 2
    seq_len = 10
    x = torch.randn(batch_size, seq_len, dim)
    print("\n2. 创建测试输入")
    print(f" 输入形状: {x.shape}")
    print(" 输入统计:")
    print(f" - 均值: {x.mean().item():.4f}")
    print(f" - 标准差: {x.std().item():.4f}")
    print(f" - 最小值: {x.min().item():.4f}")
    print(f" - 最大值: {x.max().item():.4f}")

    # 3. Forward pass.
    output = norm(x)
    print("\n3. 前向传播结果")
    print(f" 输出形状: {output.shape}")
    print(" 输出统计:")
    print(f" - 均值: {output.mean().item():.4f}")
    print(f" - 标准差: {output.std().item():.4f}")

    # 4. Verify the normalization effect.
    print("\n4. 验证归一化效果")
    # Per-position RMS of the output; should be close to 1 because the
    # weight is initialized to all ones.
    rms_per_sample = torch.sqrt(torch.mean(output.pow(2), dim=-1))
    print(" 每个样本的 RMS(归一化后):")
    print(f" - 样本1: {rms_per_sample[0].mean().item():.4f}")
    print(f" - 样本2: {rms_per_sample[1].mean().item():.4f}")
    print(f" - 平均 RMS: {rms_per_sample.mean().item():.4f}")

    # 5. Verify the weight is a learnable parameter.
    print("\n5. 验证参数可学习性")
    print(f" 权重是否为 Parameter: {isinstance(norm.weight, nn.Parameter)}")
    print(f" 权重是否需要梯度: {norm.weight.requires_grad}")

    # 6. Check that gradients flow back to the weight.
    print("\n6. 测试梯度计算")
    loss = output.sum()  # trivial scalar loss for backprop
    loss.backward()
    print(f" 权重梯度是否存在: {norm.weight.grad is not None}")
    if norm.weight.grad is not None:
        print(f" 权重梯度形状: {norm.weight.grad.shape}")
        # BUGFIX: was an f-string with no placeholders (ruff F541).
        print(" 权重梯度统计:")
        print(f" - 均值: {norm.weight.grad.mean().item():.4f}")
        print(f" - 标准差: {norm.weight.grad.std().item():.4f}")

    # 7. Check that different input shapes are handled.
    print("\n7. 测试不同输入形状")
    test_cases = [
        (1, 5, dim),  # single sample
        (4, 20, dim),  # multiple samples
        (1, 1, dim),  # single token
    ]
    for i, shape in enumerate(test_cases, 1):
        x_test = torch.randn(*shape)
        output_test = norm(x_test)
        print(f" 测试 {i}: 输入形状 {shape} -> 输出形状 {output_test.shape} ✓")

    # 8. Check numerical stability on a near-zero input: eps should keep
    # the output free of NaN/Inf.
    print("\n8. 验证数值稳定性")
    x_small = torch.randn(1, 1, dim) * 1e-6
    output_small = norm(x_small)
    print(
        f" 极小输入测试: 输入范围 [{x_small.min().item():.2e}, {x_small.max().item():.2e}]"
    )
    print(f" 输出是否包含 NaN: {torch.isnan(output_small).any().item()}")
    print(f" 输出是否包含 Inf: {torch.isinf(output_small).any().item()}")

    print("\n" + "=" * 60)
    print("所有测试完成!")
    print("=" * 60)