import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class YARNScaling:
    """Utilities for computing YaRN-scaled RoPE frequencies and the attention temperature."""

    @staticmethod
    def compute_yarn_parameters(
        original_max_len: int,
        target_max_len: int = 8192,
        dim: int = 128,
        base: int = 10000,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, float]:
        scale = float(target_max_len) / original_max_len
        mscale = YARNScaling.compute_mscale(scale, alpha)

        # RoPE frequencies come in pairs (0, 2, ..., dim - 2); using a float tensor
        # keeps the exponent division in floating point.
        freqs_idx = torch.arange(0, dim, 2, dtype=torch.float32, device=device)

        # Base frequencies (original RoPE).
        freq_extra = 1.0 / (base ** (freqs_idx / dim))

        # No extension needed: return the base frequencies unchanged.
        if scale <= 1.0:
            return freq_extra, 1.0

        # Interpolated frequencies (position interpolation for length extension).
        freq_inter = 1.0 / (scale * base ** (freqs_idx / dim))

        # Dimension index at which a frequency completes `beta` rotations over the
        # original context length; used to place the extrapolation/interpolation ramp.
        def get_limit(beta):
            return dim * math.log(original_max_len / (2 * math.pi * beta)) / (2 * math.log(base))

        low = max(math.floor(get_limit(beta_fast)), 0)
        high = min(math.ceil(get_limit(beta_slow)), dim // 2 - 1)

        indices = torch.arange(0, dim // 2, dtype=torch.float32, device=device)
        inv_freq = freq_extra.clone()

        # Low-frequency dimensions (index > high): pure interpolation.
        mask_low_freq = indices > high
        inv_freq[mask_low_freq] = freq_inter[mask_low_freq]

        # Mid-band dimensions: linear ramp between the original and interpolated frequencies.
        mid_mask = (indices >= low) & (indices <= high)
        if mid_mask.any():
            # Avoid division by zero when low == high.
            denom = max(high - low, 1)
            t = (indices[mid_mask] - low) / denom
            inv_freq[mid_mask] = freq_extra[mid_mask] * (1 - t) + freq_inter[mid_mask] * t

        return inv_freq, float(mscale)

    @staticmethod
    def compute_mscale(scale: float, alpha: float = 1.0) -> float:
        """Compute the attention scaling factor (temperature scaling)."""
        if scale <= 1.0:
            return 1.0
        return 0.1 * math.log(scale) + 1.0
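

# Informal usage sketch (not part of the module API; the concrete numbers below
# follow directly from the formulas above): extending a model trained with a
# 4096-token context to 8192 tokens gives scale = 2.0, so the attention
# temperature is mscale = 0.1 * ln(2.0) + 1.0 ≈ 1.069, and inv_freq blends the
# original and interpolated frequencies across the ramp [low, high] set by
# beta_fast / beta_slow. For example:
#
#   inv_freq, mscale = YARNScaling.compute_yarn_parameters(
#       original_max_len=4096, target_max_len=8192, dim=128
#   )
#   # inv_freq.shape == (64,), mscale ≈ 1.069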
class YARNRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim: int = 64,
        max_seq_len: int = 8192,
        original_max_len: int = 4096,
        base: int = 10000,
        scaling_factor: float = 1.0,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        rope_percentage: float = 1.0,
        device: Optional[torch.device] = None
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.original_max_len = original_max_len
        self.base = base
        self.alpha = alpha
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

        # Number of head dimensions that actually receive RoPE (partial rotary).
        self.rope_dim = int(dim * rope_percentage)
        # Must be even, since rotations act on dimension pairs.
        if self.rope_dim % 2 != 0:
            self.rope_dim -= 1

        # Initialize frequencies (persistent state).
        self._init_yarn_frequencies(device)

        # cos/sin caches, rebuilt lazily on demand.
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)

    def _init_yarn_frequencies(self, device: Optional[torch.device] = None):
        inv_freq, mscale = YARNScaling.compute_yarn_parameters(
            self.original_max_len,
            self.max_seq_len,
            self.rope_dim,
            self.base,
            beta_fast=self.beta_fast,
            beta_slow=self.beta_slow,
            alpha=self.alpha,
            device=device
        )
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.register_buffer(
            "mscale",
            torch.tensor(mscale, dtype=torch.float32, device=device),
            persistent=True
        )

    def _compute_cos_sin_cache(
        self,
        needed_len: int,
        device: torch.device,
        dtype: torch.dtype
    ):
        alloc_len = max(needed_len, self.max_seq_len)
        if (self.cos_cached is not None
                and self.cos_cached.shape[2] >= alloc_len
                and self.cos_cached.device == device):
            return

        t = torch.arange(alloc_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, self.inv_freq.to(device))
        emb = torch.cat((freqs, freqs), dim=-1)

        cos_cached = (emb.cos() * self.mscale).view(1, 1, alloc_len, self.rope_dim)
        sin_cached = (emb.sin() * self.mscale).view(1, 1, alloc_len, self.rope_dim)

        self.cos_cached = cos_cached.to(dtype)
        self.sin_cached = sin_cached.to(dtype)

    @staticmethod
    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        bsz, num_heads, seq_len, head_dim = q.shape

        if position_ids is not None:
            max_pos = position_ids.max().item() + 1
            needed_len = max(max_pos, seq_len)
        else:
            needed_len = seq_len

        if (self.cos_cached is None
                or self.cos_cached.shape[2] < needed_len
                or self.cos_cached.device != q.device):
            self._compute_cos_sin_cache(needed_len, q.device, q.dtype)

        if position_ids is not None:
            cos = self.cos_cached[0, 0][position_ids].unsqueeze(1)
            sin = self.sin_cached[0, 0][position_ids].unsqueeze(1)
        else:
            cos = self.cos_cached[:, :, :seq_len, :]
            sin = self.sin_cached[:, :, :seq_len, :]

        if self.rope_dim < head_dim:
            q_rot = q[..., :self.rope_dim]
            q_pass = q[..., self.rope_dim:]
            k_rot = k[..., :self.rope_dim]
            k_pass = k[..., self.rope_dim:]
        else:
            q_rot = q
            k_rot = k
            q_pass = None
            k_pass = None

        q_rot_float = q_rot.float()
        k_rot_float = k_rot.float()
        cos_float = cos.float()
        sin_float = sin.float()

        q_embed = (q_rot_float * cos_float) + (self.rotate_half(q_rot_float) * sin_float)
        k_embed = (k_rot_float * cos_float) + (self.rotate_half(k_rot_float) * sin_float)

        q_embed = q_embed.type_as(q)
        k_embed = k_embed.type_as(k)

        if q_pass is not None:
            q_embed = torch.cat([q_embed, q_pass], dim=-1)
            k_embed = torch.cat([k_embed, k_pass], dim=-1)

        return q_embed, k_embed

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.apply_rotary_pos_emb(q, k, position_ids)

    def extra_repr(self) -> str:
        return (f"dim={self.dim}, rope_dim={self.rope_dim}, "
                f"max_seq_len={self.max_seq_len}, original_max_len={self.original_max_len}, "
                f"base={self.base}")


class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        elementwise_affine: bool = True
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float())
        output = output.type_as(x)
        if self.elementwise_affine and self.weight is not None:
            output = output * self.weight
        return output


class QKNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.query_norm = RMSNorm(dim, eps=eps)
        self.key_norm = RMSNorm(dim, eps=eps)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q, k


class SwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()
        if hidden_dim is None:
            if ffn_dim_multiplier is not None:
                hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                # Default: 2/3 * 4 * dim = 8/3 * dim
                hidden_dim = int(2 * dim * 4 / 3)
            # Round hidden_dim up to a multiple of multiple_of (typically for GPU kernel efficiency).
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.hidden_dim = hidden_dim

        # w1: gate, w3: up, w2: down (standard LLaMA naming convention).
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = W2 · (SiLU(W1·x) ⊙ W3·x)
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class ParallelAttentionFFN(nn.Module):
    def __init__(
        self,
        dim: int,
        attn_module: nn.Module,
        ffn_module: nn.Module,
        norm_eps: float = 1e-6
    ):
        super().__init__()
        self.attn_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.attn = attn_module
        self.ffn = ffn_module

    def forward(
        self,
        x: torch.Tensor,
        **attn_kwargs
    ) -> torch.Tensor:
        # Parallel branches: both fork from the same x, each after its own norm.
        attn_input = self.attn_norm(x)
        ffn_input = self.ffn_norm(x)

        # Attention branch.
        attn_out = self.attn(attn_input, **attn_kwargs)

        # FFN branch (attention-specific kwargs are not forwarded).
        ffn_out = self.ffn(ffn_input)

        # Single residual connection combining both branches.
        return x + attn_out + ffn_out
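

# Minimal usage sketch, not part of the module API: shapes, hyperparameters, and the
# Linear stand-in for attention below are assumptions chosen only to show how the
# pieces above compose and what tensor shapes they expect.
if __name__ == "__main__":
    torch.manual_seed(0)

    bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
    model_dim = num_heads * head_dim

    rope = YARNRotaryEmbedding(dim=head_dim, max_seq_len=8192, original_max_len=4096)
    qk_norm = QKNorm(head_dim)

    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)

    # Per-head RMS normalization of q/k, then YaRN-scaled rotary embedding.
    q, k = qk_norm(q, k)
    q, k = rope(q, k)
    print(q.shape, k.shape)  # torch.Size([2, 8, 16, 64]) twice

    # A toy parallel block: a Linear stand-in for the attention module plus a SwiGLU FFN,
    # sharing a single residual connection.
    block = ParallelAttentionFFN(
        dim=model_dim,
        attn_module=nn.Linear(model_dim, model_dim, bias=False),
        ffn_module=SwiGLU(model_dim),
    )
    x = torch.randn(bsz, seq_len, model_dim)
    print(block(x).shape)  # torch.Size([2, 16, 512])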