import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

class YARNScaling:
    """
    YARN (Yet Another RoPE extensioN) scaling strategy.
    Implementation reference: https://arxiv.org/abs/2309.00071
    """

    @staticmethod
    def compute_yarn_parameters(
        original_max_len: int,
        target_max_len: int = 8192,
        dim: int = 128,
        base: int = 10000,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, float]:
        scale = float(target_max_len) / original_max_len
        mscale = YARNScaling.compute_mscale(scale, alpha)

        freqs_idx = torch.arange(0, dim, 2, dtype=torch.float32, device=device)

        # Extrapolated (plain RoPE) inverse frequencies.
        freq_extra = 1.0 / (base ** (freqs_idx / dim))

        # No extension needed: fall back to standard RoPE.
        if scale <= 1.0:
            return freq_extra, 1.0

        # Position-interpolated inverse frequencies (divided by the scale).
        freq_inter = 1.0 / (scale * base ** (freqs_idx / dim))

        # NTK-by-parts: find the dimension range whose wavelength crosses the
        # original context window, parameterized by beta_fast / beta_slow.
        def get_limit(beta):
            return dim * math.log(original_max_len / (2 * math.pi * beta)) / (2 * math.log(base))

        low = max(math.floor(get_limit(beta_fast)), 0)
        high = min(math.ceil(get_limit(beta_slow)), dim // 2 - 1)

        indices = torch.arange(0, dim // 2, dtype=torch.float32, device=device)

        # Start from extrapolation; high-frequency dimensions keep it.
        inv_freq = freq_extra.clone()

        # Low-frequency dimensions (above `high`) use pure interpolation.
        mask_low_freq = indices > high
        inv_freq[mask_low_freq] = freq_inter[mask_low_freq]

        # Dimensions between `low` and `high` blend the two with a linear ramp.
        mid_mask = (indices >= low) & (indices <= high)
        if mid_mask.any():
            denom = max(high - low, 1)
            t = (indices[mid_mask] - low) / denom
            inv_freq[mid_mask] = freq_extra[mid_mask] * (1 - t) + freq_inter[mid_mask] * t

        return inv_freq, float(mscale)

    @staticmethod
    def compute_mscale(scale: float, alpha: float = 1.0) -> float:
        """Compute the attention scaling factor (temperature scaling)."""
        if scale <= 1.0:
            return 1.0
        # `alpha` scales the log term; alpha = 1.0 recovers the YaRN paper's
        # 0.1 * ln(s) + 1.0.
        return 0.1 * alpha * math.log(scale) + 1.0

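
# A minimal sanity-check sketch (illustrative only; `_demo_yarn_scaling` is not
# part of the module above). Extending 4096 -> 8192 gives scale = 2.0, so the
# low-frequency dimensions are interpolated and mscale rises above 1.0.
def _demo_yarn_scaling():
    inv_freq, mscale = YARNScaling.compute_yarn_parameters(
        original_max_len=4096,
        target_max_len=8192,
        dim=128,
    )
    assert inv_freq.shape == (64,)  # one inverse frequency per rotated pair
    assert abs(mscale - (0.1 * math.log(2.0) + 1.0)) < 1e-6
    return inv_freq, mscale
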
class YARNRotaryEmbedding(nn.Module):
    """
    Rotary position embedding with integrated YARN scaling.
    Handles precision (float32 cos/sin computation), cache management, and
    out-of-range position_ids.
    """

    def __init__(
        self,
        dim: int = 64,
        max_seq_len: int = 8192,
        original_max_len: int = 4096,
        base: int = 10000,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        rope_percentage: float = 1.0,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.original_max_len = original_max_len
        self.base = base
        self.alpha = alpha
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

        # Only rotate the leading `rope_percentage` of each head dimension.
        self.rope_dim = int(dim * rope_percentage)
        # rotate_half splits the rotated slice in two, so it must be even.
        if self.rope_dim % 2 != 0:
            self.rope_dim -= 1

        self._init_yarn_frequencies(device)

        # Lazily built cos/sin caches (not saved in checkpoints).
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)

    def _init_yarn_frequencies(self, device: Optional[torch.device] = None):
        """Initialize the YARN-adjusted inverse frequencies."""
        inv_freq, mscale = YARNScaling.compute_yarn_parameters(
            self.original_max_len,
            self.max_seq_len,
            self.rope_dim,
            self.base,
            beta_fast=self.beta_fast,
            beta_slow=self.beta_slow,
            alpha=self.alpha,
            device=device,
        )

        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.register_buffer("mscale", torch.tensor(mscale, dtype=torch.float32, device=device), persistent=True)

    def _compute_cos_sin_cache(
        self,
        needed_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Precompute the cos/sin caches, always in float32 for precision."""
        # Allocate at least max_seq_len up front to avoid repeated rebuilds.
        alloc_len = max(needed_len, self.max_seq_len)

        # Reuse the cache if it is long enough and on the right device.
        if (self.cos_cached is not None and
                self.cos_cached.shape[2] >= alloc_len and
                self.cos_cached.device == device):
            return

        t = torch.arange(alloc_len, dtype=torch.float32, device=device)

        # Outer product gives the [alloc_len, rope_dim // 2] angle table.
        freqs = torch.outer(t, self.inv_freq.to(device))

        # Duplicate to cover the full rotated dimension: [alloc_len, rope_dim].
        emb = torch.cat((freqs, freqs), dim=-1)

        # Fold the YARN attention temperature (mscale) into cos/sin directly;
        # move it to the target device in case it differs from the module's.
        mscale = self.mscale.to(device)
        cos_cached = (emb.cos() * mscale).view(1, 1, alloc_len, self.rope_dim)
        sin_cached = (emb.sin() * mscale).view(1, 1, alloc_len, self.rope_dim)

        self.cos_cached = cos_cached.to(dtype)
        self.sin_cached = sin_cached.to(dtype)

    @staticmethod
    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        """
        Rotate the second half of the input into the first.
        Input: [..., d] -> split into x1, x2 -> output [-x2, x1].
        """
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE with float32 precision correction and bounds checks."""
        _bsz, _num_heads, seq_len, head_dim = q.shape

        # The cache must cover the largest requested position, not just
        # seq_len (the two differ during KV-cache decoding).
        if position_ids is not None:
            max_pos = position_ids.max().item() + 1
            needed_len = max(max_pos, seq_len)
        else:
            needed_len = seq_len

        # (Re)build the cache if missing, too short, or on another device.
        if (self.cos_cached is None or
                self.cos_cached.shape[2] < needed_len or
                self.cos_cached.device != q.device):
            self._compute_cos_sin_cache(needed_len, q.device, q.dtype)

        if position_ids is not None:
            # Gather per-position rows: [bsz, seq_len, rope_dim], then add a
            # broadcastable head axis -> [bsz, 1, seq_len, rope_dim].
            cos = self.cos_cached[0, 0][position_ids].unsqueeze(1)
            sin = self.sin_cached[0, 0][position_ids].unsqueeze(1)
        else:
            cos = self.cos_cached[:, :, :seq_len, :]
            sin = self.sin_cached[:, :, :seq_len, :]

        # Partial RoPE: only the first rope_dim channels are rotated.
        if self.rope_dim < head_dim:
            q_rot = q[..., :self.rope_dim]
            q_pass = q[..., self.rope_dim:]
            k_rot = k[..., :self.rope_dim]
            k_pass = k[..., self.rope_dim:]
        else:
            q_rot = q
            k_rot = k
            q_pass = None
            k_pass = None

        # Rotate in float32, then cast back to the input dtype.
        q_rot_float = q_rot.float()
        k_rot_float = k_rot.float()
        cos_float = cos.float()
        sin_float = sin.float()

        q_embed = (q_rot_float * cos_float) + (self.rotate_half(q_rot_float) * sin_float)
        k_embed = (k_rot_float * cos_float) + (self.rotate_half(k_rot_float) * sin_float)

        q_embed = q_embed.type_as(q)
        k_embed = k_embed.type_as(k)

        # Reattach the pass-through (unrotated) channels, if any.
        if q_pass is not None:
            q_embed = torch.cat([q_embed, q_pass], dim=-1)
            k_embed = torch.cat([k_embed, k_pass], dim=-1)

        return q_embed, k_embed

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.apply_rotary_pos_emb(q, k, position_ids)

    def extra_repr(self) -> str:
        return (f"dim={self.dim}, rope_dim={self.rope_dim}, "
                f"max_seq_len={self.max_seq_len}, original_max_len={self.original_max_len}, "
                f"base={self.base}")

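
# A hedged usage sketch (`_demo_rotary` is illustrative, not part of the module
# above): apply YARN RoPE to random query/key tensors, with and without
# explicit position_ids (the latter path matters for KV-cache decoding).
def _demo_rotary():
    rope = YARNRotaryEmbedding(dim=64, max_seq_len=8192, original_max_len=4096)
    q = torch.randn(2, 8, 16, 64)  # [batch, heads, seq_len, head_dim]
    k = torch.randn(2, 8, 16, 64)
    q_rot, k_rot = rope(q, k)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape

    # Passing contiguous positions explicitly must match the implicit path.
    pos = torch.arange(16).unsqueeze(0).expand(2, -1)
    q_rot2, _ = rope(q, k, position_ids=pos)
    assert torch.allclose(q_rot, q_rot2, atol=1e-5)
    return q_rot, k_rot
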
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    Forces float32 computation for numerical stability.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        elementwise_affine: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root mean square of the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32, then cast back to the input dtype.
        output = self._norm(x.float())
        output = output.type_as(x)

        if self.elementwise_affine and self.weight is not None:
            output = output * self.weight

        return output

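
# A small equivalence sketch (`_demo_rmsnorm` is illustrative): with the
# default unit weight, RMSNorm divides by the RMS of the last dimension.
def _demo_rmsnorm():
    norm = RMSNorm(8)
    x = torch.randn(4, 8)
    manual = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(norm(x), manual, atol=1e-5)
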
class QKNorm(nn.Module):
    """
    Query-Key Normalization (as in ViT-22B / Scaling Transformers).
    Stabilizes the attention logits.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.query_norm = RMSNorm(dim, eps=eps)
        self.key_norm = RMSNorm(dim, eps=eps)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q, k

class SwiGLU(nn.Module):
    """
    Feed-forward network with SwiGLU activation.
    Structure: Down(SiLU(Gate) * Up).
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dropout: float = 0.0,
        bias: bool = False,
    ):
        super().__init__()

        if hidden_dim is None:
            if ffn_dim_multiplier is not None:
                hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                # LLaMA convention: 2/3 of the 4x expansion, i.e. 8 * dim / 3.
                hidden_dim = int(2 * dim * 4 / 3)

        # Round up to a multiple of `multiple_of` for hardware efficiency.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.hidden_dim = hidden_dim

        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # gate
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)  # down
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)  # up
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

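
# A worked example of the hidden-dim rounding (`_demo_swiglu_dims` is
# illustrative): dim=4096 -> int(8 * 4096 / 3) = 10922 -> rounded up to the
# next multiple of 256 = 11008, the familiar LLaMA-7B FFN width.
def _demo_swiglu_dims():
    ffn = SwiGLU(dim=4096)
    assert ffn.hidden_dim == 11008
    y = ffn(torch.randn(2, 4096))
    assert y.shape == (2, 4096)
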
class ParallelAttentionFFN(nn.Module):
    """
    Parallel attention and feed-forward block (PaLM / GPT-J style).
    y = x + Attention(LN(x)) + MLP(LN(x))
    """

    def __init__(
        self,
        dim: int,
        attn_module: nn.Module,
        ffn_module: nn.Module,
        norm_eps: float = 1e-6,
    ):
        super().__init__()

        # Separate norms for the two parallel branches.
        self.attn_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.attn = attn_module
        self.ffn = ffn_module

    def forward(
        self,
        x: torch.Tensor,
        **attn_kwargs,
    ) -> torch.Tensor:
        # Both branches read the same input x, so they can run concurrently.
        attn_input = self.attn_norm(x)
        ffn_input = self.ffn_norm(x)

        attn_out = self.attn(attn_input, **attn_kwargs)
        ffn_out = self.ffn(ffn_input)

        # A single residual connection sums both branch outputs.
        return x + attn_out + ffn_out
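

# A hedged end-to-end sketch wiring the pieces together. `_TinySelfAttention`
# is a stand-in causal attention module (not part of the code above) so the
# parallel block can be exercised; a real model would supply its own attention,
# typically applying YARNRotaryEmbedding and QKNorm inside it.
class _TinySelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to [batch, heads, seq_len, head_dim] for SDPA.
        q, k, v = (t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(out.transpose(1, 2).reshape(b, s, d))


def _demo_parallel_block():
    dim = 64
    block = ParallelAttentionFFN(
        dim=dim,
        attn_module=_TinySelfAttention(dim),
        ffn_module=SwiGLU(dim=dim, hidden_dim=4 * dim),
    )
    x = torch.randn(2, 16, dim)
    y = block(x)
    assert y.shape == x.shape
    return y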