# MultiModal / components.py
# Source: Hugging Face upload by szxllm ("Upload 20 files", commit cd66851).
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Union
import math
class YARNScaling:
    """
    YARN (Yet Another RoPE extensioN) scaling strategy.

    Reference: https://arxiv.org/abs/2309.00071
    """

    @staticmethod
    def compute_yarn_parameters(
        original_max_len: int,
        target_max_len: int = 8192,
        dim: int = 128,
        base: int = 10000,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, float]:
        """
        Compute YARN-blended RoPE inverse frequencies and the attention scale.

        Args:
            original_max_len: context length the un-extended model was trained on.
            target_max_len: context length to extend to. The scale factor is
                ``target_max_len / original_max_len``.
            dim: rotary dimension. Frequencies are taken at channel indices
                0, 2, 4, ..., so the result has ``ceil(dim / 2)`` entries.
            base: RoPE base (theta).
            beta_fast: rotation count defining the lower index threshold. Pair
                indices at or below it keep the original frequency.
            beta_slow: rotation count defining the upper index threshold. Pair
                indices above it use the interpolated (scale-divided) frequency.
            alpha: coefficient forwarded to :meth:`compute_mscale`.
            device: device of the returned tensor.

        Returns:
            ``(inv_freq, mscale)``: a float32 tensor of inverse frequencies and
            the float attention scale. When the scale factor is <= 1, the plain
            RoPE frequencies and ``1.0`` are returned.
        """
        scale = float(target_max_len) / original_max_len
        mscale = YARNScaling.compute_mscale(scale, alpha)

        # RoPE frequencies come in pairs: channel indices 0, 2, ..., dim-2.
        freqs_idx = torch.arange(0, dim, 2, dtype=torch.float32, device=device)
        # Original (extrapolated) RoPE frequencies.
        freq_extra = 1.0 / (base ** (freqs_idx / dim))

        # No extension requested: plain RoPE.
        if scale <= 1.0:
            return freq_extra, 1.0

        # Position-interpolated frequencies (every frequency divided by `scale`).
        freq_inter = 1.0 / (scale * base ** (freqs_idx / dim))

        # get_limit(beta) is the pair index i at which a dimension completes exactly
        # `beta` full rotations over original_max_len. That is where the wavelength
        # 2*pi*base**(2i/dim) equals original_max_len / beta.
        def get_limit(beta):
            return dim * math.log(original_max_len / (2 * math.pi * beta)) / (2 * math.log(base))

        # Derive the index range from the actual number of frequencies, so an odd
        # `dim` (ceil(dim/2) frequencies) does not cause a mask/shape mismatch.
        n_freqs = freq_extra.numel()
        low = max(math.floor(get_limit(beta_fast)), 0)
        high = min(math.ceil(get_limit(beta_slow)), n_freqs - 1)

        indices = torch.arange(0, n_freqs, dtype=torch.float32, device=device)
        inv_freq = freq_extra.clone()

        # 1. Long-wavelength dims (index > high): use interpolated frequencies.
        beyond_high = indices > high
        inv_freq[beyond_high] = freq_inter[beyond_high]

        # 2. Short-wavelength dims (index < low) keep freq_extra (already in inv_freq).

        # 3. Dims in [low, high]: linear ramp from extrapolated to interpolated.
        ramp_zone = (indices >= low) & (indices <= high)
        if ramp_zone.any():
            denom = max(high - low, 1)  # avoid division by zero when low == high
            t = (indices[ramp_zone] - low) / denom
            inv_freq[ramp_zone] = freq_extra[ramp_zone] * (1 - t) + freq_inter[ramp_zone] * t

        return inv_freq, float(mscale)

    @staticmethod
    def compute_mscale(scale: float, alpha: float = 1.0) -> float:
        """
        Attention temperature factor for a context-extension ratio ``scale``.

        Returns ``0.1 * alpha * ln(scale) + 1.0`` for ``scale > 1`` and ``1.0``
        otherwise. ``alpha`` plays the role of the ``mscale`` coefficient found
        in common YaRN implementations. The default ``alpha=1.0`` gives the
        empirical ``0.1 * ln(scale) + 1.0``. ``alpha`` was previously accepted
        but ignored.
        """
        if scale <= 1.0:
            return 1.0
        return 0.1 * alpha * math.log(scale) + 1.0
class YARNRotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE) with YARN frequency scaling.

    ``inv_freq`` and ``mscale`` are computed once at construction and stored as
    persistent buffers. The cos/sin tables are non-persistent buffers, built
    lazily and rebuilt when a longer sequence, a larger position id, or a
    different device is encountered. The tables are always kept in float32, and
    the rotation is done in float32 before casting back to the input dtype.
    """

    def __init__(
        self,
        dim: int = 64,
        max_seq_len: int = 8192,
        original_max_len: int = 4096,
        base: int = 10000,
        scaling_factor: float = 1.0,  # reserved, unused: the scale comes from max_seq_len / original_max_len
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        rope_percentage: float = 1.0,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            dim: head dimension this module is configured for.
            max_seq_len: target context length. Also the minimum cos/sin cache length.
            original_max_len: context length of the un-extended model.
            base: RoPE base.
            scaling_factor: reserved, currently unused.
            beta_fast: YARN lower ramp bound, passed to YARNScaling.
            beta_slow: YARN upper ramp bound, passed to YARNScaling.
            alpha: passed to YARNScaling as its ``alpha`` argument.
            rope_percentage: fraction of ``dim`` that is rotated, rounded down to even.
            device: device of the frequency buffers.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.original_max_len = original_max_len
        self.base = base
        self.alpha = alpha
        # Stored so _init_yarn_frequencies honours the caller's values. They were
        # previously replaced by hard-coded 32 / 1.
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

        # Number of leading channels that get rotated; must be even (pairs).
        self.rope_dim = int(dim * rope_percentage)
        if self.rope_dim % 2 != 0:
            self.rope_dim -= 1

        # Persistent state: inv_freq / mscale.
        self._init_yarn_frequencies(device)

        # Transient state: persistent=False keeps the tables out of state_dict.
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)

    def _init_yarn_frequencies(self, device: Optional[torch.device] = None):
        """Compute YARN ``inv_freq`` / ``mscale`` and register them as persistent buffers."""
        inv_freq, mscale = YARNScaling.compute_yarn_parameters(
            self.original_max_len,
            self.max_seq_len,
            self.rope_dim,
            self.base,
            beta_fast=self.beta_fast,
            beta_slow=self.beta_slow,
            alpha=self.alpha,
            device=device,
        )
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.register_buffer(
            "mscale", torch.tensor(mscale, dtype=torch.float32, device=device), persistent=True
        )

    def _compute_cos_sin_cache(
        self,
        needed_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """
        Ensure the cos/sin tables cover ``needed_len`` positions on ``device``.

        The tables have shape ``[1, 1, alloc_len, rope_dim]`` with
        ``alloc_len = max(needed_len, max_seq_len)``, and both include the
        ``mscale`` factor. They are always stored in float32. ``dtype`` is kept
        for backward compatibility but is no longer used for storage. A
        half-precision table would lose the precision the float32 rotation is
        meant to preserve, and would be reused for later calls of other dtypes.
        """
        alloc_len = max(needed_len, self.max_seq_len)
        if (
            self.cos_cached is not None
            and self.cos_cached.shape[2] >= alloc_len
            and self.cos_cached.device == device
            and self.cos_cached.dtype == torch.float32  # e.g. after module.half()
        ):
            return

        t = torch.arange(alloc_len, dtype=torch.float32, device=device)
        inv_freq = self.inv_freq.to(device=device, dtype=torch.float32)
        # freqs[i, j] = t[i] * inv_freq[j] -> [alloc_len, rope_dim // 2]
        freqs = torch.outer(t, inv_freq)
        # Duplicate to match rotate_half's layout: [theta_0..theta_n, theta_0..theta_n].
        emb = torch.cat((freqs, freqs), dim=-1)
        mscale = self.mscale.to(device=device, dtype=torch.float32)

        # [alloc_len, rope_dim] -> [1, 1, alloc_len, rope_dim] for broadcasting.
        self.cos_cached = (emb.cos() * mscale).view(1, 1, alloc_len, self.rope_dim)
        self.sin_cached = (emb.sin() * mscale).view(1, 1, alloc_len, self.rope_dim)

    @staticmethod
    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        """
        Rotate the two halves of the last dimension.

        Input ``[..., d]`` is split into ``x1, x2``; the output is ``[-x2, x1]``.
        """
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply RoPE to the first ``rope_dim`` channels of ``q`` and ``k``.

        Args:
            q: 4-D tensor ``[batch, heads, seq_len, head_dim]``.
            k: tensor rotated with the same cos/sin as ``q``, so it is expected
                to share ``q``'s ``seq_len``.
            position_ids: optional integer positions, either ``[batch, seq_len]``
                or ``[seq_len]``. A 1-D tensor is shared across the batch.
                ``None`` means positions ``0 .. seq_len - 1``.

        Returns:
            ``(q_embed, k_embed)`` in the dtypes of ``q`` and ``k``. Channels
            beyond ``rope_dim`` pass through unchanged.
        """
        _, _, seq_len, head_dim = q.shape

        # A 1-D position vector previously indexed to [s, 1, d], which broadcasts
        # wrongly against [b, h, s, d]. Treat it as a single shared batch row.
        if position_ids is not None and position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)

        # 1. Cache length must cover the largest requested position.
        needed_len = seq_len
        if position_ids is not None and position_ids.numel() > 0:
            needed_len = max(int(position_ids.max().item()) + 1, seq_len)

        # 2. Build or extend the cache if needed (no-op when already sufficient).
        self._compute_cos_sin_cache(needed_len, q.device, q.dtype)

        # 3. Select the cos/sin rows (float32 tables).
        if position_ids is not None:
            # [alloc_len, rope_dim][position_ids] -> [b, s, rope_dim] -> [b, 1, s, rope_dim]
            cos = self.cos_cached[0, 0][position_ids].unsqueeze(1)
            sin = self.sin_cached[0, 0][position_ids].unsqueeze(1)
        else:
            cos = self.cos_cached[:, :, :seq_len, :]
            sin = self.sin_cached[:, :, :seq_len, :]

        # 4. Partial RoPE: only the first rope_dim channels are rotated.
        if self.rope_dim < head_dim:
            q_rot, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:]
            k_rot, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:]
        else:
            q_rot, q_pass = q, None
            k_rot, k_pass = k, None

        # 5. Rotate in float32 to avoid precision loss in half-precision inputs.
        q_rot_float = q_rot.float()
        k_rot_float = k_rot.float()
        q_embed = (q_rot_float * cos) + (self.rotate_half(q_rot_float) * sin)
        k_embed = (k_rot_float * cos) + (self.rotate_half(k_rot_float) * sin)

        # 6. Back to the input dtypes, then re-attach pass-through channels.
        q_embed = q_embed.type_as(q)
        k_embed = k_embed.type_as(k)
        if q_pass is not None:
            q_embed = torch.cat([q_embed, q_pass], dim=-1)
            k_embed = torch.cat([k_embed, k_pass], dim=-1)
        return q_embed, k_embed

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Alias for :meth:`apply_rotary_pos_emb`."""
        return self.apply_rotary_pos_emb(q, k, position_ids)

    def extra_repr(self) -> str:
        return (f"dim={self.dim}, rope_dim={self.rope_dim}, "
                f"max_seq_len={self.max_seq_len}, original_max_len={self.original_max_len}, "
                f"base={self.base}")
class RMSNorm(nn.Module):
    """
    Root-mean-square layer normalization.

    ``y = x / sqrt(mean(x**2, dim=-1) + eps)`` is computed in float32 and cast
    back to ``x``'s dtype. The result is then optionally scaled by a learned
    per-channel weight.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        elementwise_affine: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            # Keep the attribute present (as None) so callers can test for it.
            self.register_parameter('weight', None)
        else:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by its RMS over the last dimension (no dtype handling here)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics in float32 to avoid fp16 underflow/overflow, then restore dtype.
        normalized = self._norm(x.float()).type_as(x)
        if not self.elementwise_affine or self.weight is None:
            return normalized
        return normalized * self.weight
class QKNorm(nn.Module):
    """
    Query-key normalization (as in ViT-22B / Scaling Transformer).

    Queries and keys each pass through their own RMSNorm over the last
    dimension. This is intended to stabilize the attention logits.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.query_norm = RMSNorm(dim, eps=eps)
        self.key_norm = RMSNorm(dim, eps=eps)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(query_norm(q), key_norm(k))``."""
        return self.query_norm(q), self.key_norm(k)
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward network: ``dropout(w2(silu(w1(x)) * w3(x)))``.

    Naming follows LLaMA: w1 = gate, w3 = up, w2 = down projection.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dropout: float = 0.0,
        bias: bool = False,
    ):
        """
        Args:
            dim: input/output width.
            hidden_dim: inner width. If given, it is used as-is. If None, it is
                derived as ``int(dim * ffn_dim_multiplier)`` (when that is set)
                or ``int(8 * dim / 3)``, then rounded up to a multiple of
                ``multiple_of``.
            multiple_of: rounding granularity for the derived inner width.
            ffn_dim_multiplier: optional width multiplier relative to ``dim``.
            dropout: dropout on the output; 0 disables it (Identity).
            bias: whether the three linear layers have biases.
        """
        super().__init__()
        if hidden_dim is None:
            # LLaMA default width: 2/3 * 4 * dim = 8/3 * dim.
            raw_width = (
                int(dim * ffn_dim_multiplier)
                if ffn_dim_multiplier is not None
                else int(8 * dim / 3)
            )
            # Round up to a multiple of `multiple_of` (GPU-friendly matmul sizes).
            hidden_dim = ((raw_width + multiple_of - 1) // multiple_of) * multiple_of
        self.hidden_dim = hidden_dim

        # Creation order (w1, w2, w3) is kept so parameter init/RNG order is unchanged.
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
        self.dropout = nn.Identity() if not dropout > 0 else nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w1(x))
        up = self.w3(x)
        return self.dropout(self.w2(gated * up))
class ParallelAttentionFFN(nn.Module):
    """
    Parallel attention and feed-forward block (PaLM / GPT-J style)::

        y = x + attn(attn_norm(x), **attn_kwargs) + ffn(ffn_norm(x))

    Both branches read the same input ``x``. Each branch has its own RMSNorm.
    """

    def __init__(
        self,
        dim: int,
        attn_module: nn.Module,
        ffn_module: nn.Module,
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        # Some architectures (e.g. PaLM) may share a single norm across both
        # branches. Two independent norms are kept here for flexibility.
        self.attn_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.attn = attn_module
        self.ffn = ffn_module

    def forward(
        self,
        x: torch.Tensor,
        **attn_kwargs,
    ) -> torch.Tensor:
        # Both branches fork from the same x.
        normed_for_attn = self.attn_norm(x)
        normed_for_ffn = self.ffn_norm(x)

        # Extra kwargs are attention-specific and are not forwarded to the FFN.
        branch_attn = self.attn(normed_for_attn, **attn_kwargs)
        branch_ffn = self.ffn(normed_for_ffn)

        # Single residual sum over both branches.
        return x + branch_attn + branch_ffn