|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Tuple, Optional, Union |
|
|
import math |
|
|
|
|
|
class YARNScaling:
    """Static helpers implementing YaRN (Yet Another RoPE extensioN) scaling."""

    @staticmethod
    def compute_yarn_parameters(
        original_max_len: int,
        target_max_len: int = 8192,
        dim: int = 128,
        base: int = 10000,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, float]:
        """Compute YaRN inverse frequencies and the attention temperature.

        Args:
            original_max_len: context length the model was trained with.
            target_max_len: extended context length to support.
            dim: rotary embedding dimension (frequencies span dim // 2 pairs).
            base: RoPE frequency base.
            beta_fast: correction-range bound for high-frequency dimensions.
            beta_slow: correction-range bound for low-frequency dimensions.
            alpha: forwarded to the mscale computation.
            device: device for the returned tensor.

        Returns:
            (inv_freq, mscale): inv_freq has shape (dim // 2,); mscale is the
            scalar attention temperature (1.0 when no extension is needed).
        """
        scale = float(target_max_len) / original_max_len
        mscale = YARNScaling.compute_mscale(scale, alpha)

        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        # Extrapolated frequencies: plain RoPE, no rescaling.
        freq_extra = 1.0 / (base ** exponents)

        # Target length fits inside the trained window: nothing to scale.
        if scale <= 1.0:
            return freq_extra, 1.0

        # Interpolated frequencies: positions compressed by `scale`.
        freq_inter = 1.0 / (scale * base ** exponents)

        def correction_dim(beta: float) -> float:
            # Dimension index whose rotation wavelength equals
            # original_max_len / beta (YaRN correction-range formula).
            return dim * math.log(original_max_len / (2 * math.pi * beta)) / (2 * math.log(base))

        low = max(math.floor(correction_dim(beta_fast)), 0)
        high = min(math.ceil(correction_dim(beta_slow)), dim // 2 - 1)

        idx = torch.arange(0, dim // 2, dtype=torch.float32, device=device)

        # Below `low`: pure extrapolation; above `high`: pure interpolation;
        # inside [low, high]: linear blend between the two.
        inv_freq = torch.where(idx > high, freq_inter, freq_extra)
        ramp = (idx - low) / max(high - low, 1)
        blended = freq_extra * (1.0 - ramp) + freq_inter * ramp
        inv_freq = torch.where((idx >= low) & (idx <= high), blended, inv_freq)

        return inv_freq, float(mscale)

    @staticmethod
    def compute_mscale(scale: float, alpha: float = 1.0) -> float:
        """Attention temperature: 1.0 for scale <= 1, else 0.1 * ln(scale) + 1.

        NOTE(review): `alpha` is accepted but not used in this formula —
        confirm whether it was meant to multiply the log term.
        """
        return 1.0 if scale <= 1.0 else 0.1 * math.log(scale) + 1.0
|
|
|
|
|
class YARNRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with YaRN context-window extension.

    Inverse frequencies are rescaled per-dimension via YARNScaling so a model
    trained at ``original_max_len`` can attend over ``max_seq_len`` positions;
    the cached cos/sin tables are pre-multiplied by the YaRN attention
    temperature (mscale). Only the first ``rope_dim`` channels of each head
    are rotated; remaining channels pass through unchanged (partial RoPE).
    """

    def __init__(
        self,
        dim: int = 64,
        max_seq_len: int = 8192,
        original_max_len: int = 4096,
        base: int = 10000,
        scaling_factor: float = 1.0,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        rope_percentage: float = 1.0,
        device: Optional[torch.device] = None
    ):
        """
        Args:
            dim: per-head embedding dimension.
            max_seq_len: target (extended) maximum sequence length.
            original_max_len: length the base model was trained with.
            base: RoPE frequency base.
            scaling_factor: retained for interface compatibility; currently
                unused by this module (see NOTE below).
            beta_fast: YaRN correction-range bound (high-frequency side).
            beta_slow: YaRN correction-range bound (low-frequency side).
            alpha: forwarded to the mscale computation.
            rope_percentage: fraction of `dim` that gets rotated.
            device: device for the precomputed frequency buffers.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.original_max_len = original_max_len
        self.base = base
        self.alpha = alpha
        # BUG FIX: beta_fast / beta_slow were previously accepted but silently
        # ignored — _init_yarn_frequencies hard-coded 32 and 1. Store them and
        # actually use them below.
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # NOTE(review): scaling_factor is stored but never read anywhere in
        # this module — confirm whether it should feed the YaRN scale.
        self.scaling_factor = scaling_factor

        # Number of leading channels to rotate; must be even because rotation
        # operates on channel pairs.
        self.rope_dim = int(dim * rope_percentage)
        if self.rope_dim % 2 != 0:
            self.rope_dim -= 1

        self._init_yarn_frequencies(device)

        # Lazily built cos/sin caches, shape (1, 1, alloc_len, rope_dim).
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)

    def _init_yarn_frequencies(self, device: Optional[torch.device] = None):
        """Compute and register YaRN inverse frequencies and temperature."""
        inv_freq, mscale = YARNScaling.compute_yarn_parameters(
            self.original_max_len,
            self.max_seq_len,
            self.rope_dim,
            self.base,
            beta_fast=self.beta_fast,  # was hard-coded to 32
            beta_slow=self.beta_slow,  # was hard-coded to 1
            alpha=self.alpha,
            device=device
        )
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.register_buffer("mscale", torch.tensor(mscale, dtype=torch.float32, device=device), persistent=True)

    def _compute_cos_sin_cache(
        self,
        needed_len: int,
        device: torch.device,
        dtype: torch.dtype
    ):
        """(Re)build the cos/sin tables covering at least `needed_len` positions."""
        alloc_len = max(needed_len, self.max_seq_len)

        # Reuse the existing cache when it is long enough and on this device.
        if (self.cos_cached is not None and
                self.cos_cached.shape[2] >= alloc_len and
                self.cos_cached.device == device):
            return

        t = torch.arange(alloc_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, self.inv_freq.to(device))
        # Duplicate so the table covers both halves used by rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)

        # Fold the YaRN attention temperature into the tables once.
        cos_cached = (emb.cos() * self.mscale).view(1, 1, alloc_len, self.rope_dim)
        sin_cached = (emb.sin() * self.mscale).view(1, 1, alloc_len, self.rope_dim)

        self.cos_cached = cos_cached.to(dtype)
        self.sin_cached = sin_cached.to(dtype)

    @staticmethod
    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Swap the two halves of the last dim, negating the second half."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embedding to q/k of shape (bsz, heads, seq, head_dim).

        When `position_ids` is given it selects per-token rows from the cache
        (broadcast over heads); otherwise positions 0..seq_len-1 are used.
        Rotation is computed in float32 and cast back to the input dtype.
        """
        bsz, num_heads, seq_len, head_dim = q.shape

        if position_ids is not None:
            needed_len = max(position_ids.max().item() + 1, seq_len)
        else:
            needed_len = seq_len

        if (self.cos_cached is None or
                self.cos_cached.shape[2] < needed_len or
                self.cos_cached.device != q.device):
            self._compute_cos_sin_cache(needed_len, q.device, q.dtype)

        if position_ids is not None:
            # (bsz, seq) indices -> (bsz, 1, seq, rope_dim).
            cos = self.cos_cached[0, 0][position_ids].unsqueeze(1)
            sin = self.sin_cached[0, 0][position_ids].unsqueeze(1)
        else:
            cos = self.cos_cached[:, :, :seq_len, :]
            sin = self.sin_cached[:, :, :seq_len, :]

        # Partial RoPE: rotate only the leading rope_dim channels.
        if self.rope_dim < head_dim:
            q_rot = q[..., :self.rope_dim]
            q_pass = q[..., self.rope_dim:]
            k_rot = k[..., :self.rope_dim]
            k_pass = k[..., self.rope_dim:]
        else:
            q_rot = q
            k_rot = k
            q_pass = None
            k_pass = None

        q_rot_float = q_rot.float()
        k_rot_float = k_rot.float()
        cos_float = cos.float()
        sin_float = sin.float()

        q_embed = (q_rot_float * cos_float) + (self.rotate_half(q_rot_float) * sin_float)
        k_embed = (k_rot_float * cos_float) + (self.rotate_half(k_rot_float) * sin_float)

        q_embed = q_embed.type_as(q)
        k_embed = k_embed.type_as(k)

        if q_pass is not None:
            q_embed = torch.cat([q_embed, q_pass], dim=-1)
            k_embed = torch.cat([k_embed, k_pass], dim=-1)

        return q_embed, k_embed

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Alias for apply_rotary_pos_emb."""
        return self.apply_rotary_pos_emb(q, k, position_ids)

    def extra_repr(self) -> str:
        return (f"dim={self.dim}, rope_dim={self.rope_dim}, "
                f"max_seq_len={self.max_seq_len}, original_max_len={self.original_max_len}, "
                f"base={self.base}")
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: no mean centering, no bias.

    Normalization runs in float32 for stability, then casts back to the
    input dtype; an optional learned per-channel gain is applied afterwards.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        elementwise_affine: bool = True
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            # Learned gain, initialized to identity.
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the RMS of the last dimension (eps guards zero input).
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self._norm(x.float()).type_as(x)
        if self.elementwise_affine and self.weight is not None:
            y = y * self.weight
        return y
|
|
|
|
|
class QKNorm(nn.Module):
    """QK-norm: independent RMS normalization of query and key tensors."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Separate norms so q and k each learn their own gain.
        self.query_norm = RMSNorm(dim, eps=eps)
        self.key_norm = RMSNorm(dim, eps=eps)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the normalized (q, k) pair."""
        return self.query_norm(q), self.key_norm(k)
|
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)).

    When `hidden_dim` is not given it defaults to 2/3 of the conventional
    4x expansion (LLaMA-style), then is rounded up to a multiple of
    `multiple_of` for hardware-friendly shapes.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()

        if hidden_dim is None:
            if ffn_dim_multiplier is not None:
                hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                # 2/3 of a 4x expansion keeps parameter count comparable to
                # a plain 4x MLP despite the third projection.
                hidden_dim = int(2 * dim * 4 / 3)

        # Round up to the nearest multiple of `multiple_of`.
        hidden_dim = ((hidden_dim + multiple_of - 1) // multiple_of) * multiple_of
        self.hidden_dim = hidden_dim

        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)  # value projection
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.dropout(self.w2(gate * value))
|
|
|
|
|
class ParallelAttentionFFN(nn.Module):
    """Parallel transformer block: attention and FFN branches read the same
    residual stream (each through its own RMSNorm) and their outputs are
    summed with the residual, instead of the sequential attn-then-FFN layout.
    """

    def __init__(
        self,
        dim: int,
        attn_module: nn.Module,
        ffn_module: nn.Module,
        norm_eps: float = 1e-6
    ):
        super().__init__()
        self.attn_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.attn = attn_module
        self.ffn = ffn_module

    def forward(
        self,
        x: torch.Tensor,
        **attn_kwargs
    ) -> torch.Tensor:
        """Return x + attn(norm(x), **attn_kwargs) + ffn(norm(x))."""
        # Both branches see the *unmodified* input x, normalized independently.
        attn_branch = self.attn(self.attn_norm(x), **attn_kwargs)
        ffn_branch = self.ffn(self.ffn_norm(x))
        return x + attn_branch + ffn_branch