# Source: TaoNet-mini-T2 / code / TaoTrain / src / taoTrain / models / mla_components.py
# Uploaded by StarMist0012 -- "Add files using upload-large-folder tool" (commit 3270dae, verified)
"""
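DeepSeek-style Multi-head Latent Attention (MLA) with RoPE.

Key ideas:
1. K and V are projected into a smaller latent space (reduces KV memory).
2. Q is projected at full model width; each head's query is truncated to the
   latent head width before attention scores are computed.
3. RoPE positional embeddings are applied to Q and K.
4. A Grouped Query Attention (GQA) group count is accepted as a constructor
   argument; the attention forward pass in this file does not use it.
5. Learnable per-head combination weights are declared as a parameter; the
   attention forward pass in this file does not use them.
6. Numerical stability via pre-norm and attention-score scaling.
"""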
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def _residual_rms_norm(x, enabled=False, target=1.0, eps=1e-6, cap=None):
    """Optionally rescale ``x`` according to its RMS over the last dimension.

    The RMS is ``sqrt(mean(x**2) + eps)``, computed in float32 along the last
    axis. Behaviour depends on ``enabled`` and ``cap``:

    * ``enabled`` falsy and ``cap`` is None: ``x`` itself is returned.
    * ``enabled`` truthy: each vector is multiplied by ``target / rms``
      (``cap`` is ignored in this mode).
    * ``enabled`` falsy and ``cap`` given: each vector is multiplied by
      ``min(1, cap / rms)``, so only vectors whose RMS exceeds ``cap`` shrink.

    Args:
        x: Tensor to rescale.
        enabled: Turn on normalisation to ``target``.
        target: RMS value aimed for when ``enabled`` is truthy.
        eps: Constant added to the mean square before the square root.
        cap: Upper RMS bound used when ``enabled`` is falsy.

    Returns:
        ``x`` unchanged, or ``x`` times the scale cast to ``x``'s dtype.
    """
    has_cap = cap is not None
    if not (enabled or has_cap):
        return x

    # Statistics are taken in float32 regardless of the input dtype.
    mean_square = x.float().square().mean(dim=-1, keepdim=True)
    rms = torch.sqrt(mean_square + eps)

    if enabled:
        scale = target / rms
    else:
        limit = torch.tensor(float(cap), dtype=rms.dtype, device=rms.device)
        scale = torch.minimum(torch.ones_like(rms), limit / rms)

    return x * scale.to(dtype=x.dtype)
class RotaryEmbedding(nn.Module):
    """Angle table for rotary position embeddings (RoPE) with a YaRN-style stretch.

    Positions are divided by ``rope_scale`` and multiplied by the inverse
    frequencies ``1 / 10000 ** (2i / dim)``. The resulting angles are then
    divided by ``(seq_len / max_seq_length) ** (1 / yarn_alpha)`` unless
    ``yarn_alpha == 1.0`` and ``seq_len <= max_seq_length``.

    NOTE(review): when ``yarn_alpha != 1.0`` the stretch is applied for every
    sequence length, including lengths below ``max_seq_length`` (where the
    divisor is smaller than 1) -- confirm this is intended.

    Parameters:
        dim: Width of the returned table (must be even)
        rope_scale: Divisor applied to the position indices (default: 40)
        max_seq_length: Reference sequence length for the stretch (default: 1024)
        yarn_alpha: Exponent control for the stretch (default: 1.0)
    """

    def __init__(self, dim, rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0):
        super().__init__()
        assert dim % 2 == 0, "Dimension must be even for rotary embeddings"

        self.dim = dim
        self.rope_scale = rope_scale
        self.max_seq_length = max_seq_length
        self.yarn_alpha = yarn_alpha

        # One inverse frequency per channel pair: dim // 2 entries.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def _apply_yarn_scaling(self, freqs, seq_len):
        """Divide the angle table by the length-dependent stretch factor.

        Args:
            freqs: [seq_len, dim // 2] angle table
            seq_len: Current sequence length
        Returns:
            ``freqs`` unchanged when ``yarn_alpha == 1.0`` and
            ``seq_len <= max_seq_length``; otherwise ``freqs`` divided by
            ``(seq_len / max_seq_length) ** (1 / yarn_alpha)``.
        """
        no_stretch = self.yarn_alpha == 1.0 and seq_len <= self.max_seq_length
        if no_stretch:
            return freqs

        length_ratio = seq_len / self.max_seq_length
        stretch = length_ratio ** (1.0 / self.yarn_alpha)
        return freqs / stretch

    def forward(self, seq_len, device):
        """Build the angle table for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions
            device: Device on which the position indices are created
        Returns:
            [seq_len, dim] tensor: the [seq_len, dim // 2] angles repeated
            twice along the last axis.
        """
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        positions = positions / self.rope_scale

        # Outer product of positions and inverse frequencies.
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        angles = self._apply_yarn_scaling(angles, seq_len)

        return torch.cat((angles, angles), dim=-1)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    ``[a, b]`` (split along the last axis) becomes ``[-b, a]``.
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
def apply_rotary(x, cos, sin):
    """Rotate the leading channels of ``x`` with the given cos/sin tables.

    Args:
        x: [B, n_heads, seq_len, head_dim] or similar
        cos: [seq_len, rot_dim] or [1, 1, seq_len, rot_dim]
        sin: same layout as ``cos``
    Returns:
        Tensor shaped like ``x``. The first ``min(rot_dim, head_dim)`` channels
        hold ``x * cos + rotate_half(x) * sin``; any remaining channels are
        passed through unchanged.

    NOTE(review): when the tables are wider than ``x`` they are simply cut to
    ``x``'s width, which keeps only the leading columns of the table -- confirm
    callers never rely on that path.
    """
    # 2-D tables get two leading singleton axes so they broadcast over batch and heads.
    if cos.dim() == 2:
        cos, sin = cos[None, None], sin[None, None]

    # Tables may not be wider than x.
    width = x.shape[-1]
    cos, sin = cos[..., :width], sin[..., :width]

    # Channels covered by the table are rotated; the rest pass through.
    n_rot = cos.shape[-1]
    head, tail = x[..., :n_rot], x[..., n_rot:]
    rotated = head * cos + rotate_half(head) * sin

    if tail.shape[-1] > 0:
        return torch.cat([rotated, tail], dim=-1)
    return rotated
class DeepSeekMLA(nn.Module):
    """
    DeepSeek-style Multi-head Latent Attention (MLA) with an internal pre-norm.

    Architecture:
    1. LayerNorm the input
    2. Project to Query: [B, seq_len, d_model] -> [B, seq_len, d_model]
    3. Project to compressed K and V: [B, seq_len, d_model] -> [B, seq_len, d_latent_kv]
    4. Split Q, K and V into n_heads heads
    5. Apply RoPE to the first d_rope channels of every Q and K head (skipped when d_rope == 0)
    6. Keep the first d_latent_kv // n_heads channels of every Q head and compute
       softmax((Q @ K^T) / sqrt(d_latent_kv // n_heads)) @ V
    7. Concatenate heads, project back to d_model and apply dropout

    Parameters:
        d_model: Model dimension (must be divisible by n_heads)
        d_latent_kv: Latent dimension for KV compression (must be divisible by
            n_heads and must not exceed d_model)
        n_heads: Number of attention heads
        d_rope: Number of leading channels per head that receive RoPE
        dropout: Dropout probability (attention weights and output projection)
        gqa_groups: Stored on the module; forward() does not read it
        rope_scale, max_seq_length, yarn_alpha: Forwarded to RotaryEmbedding

    Notes:
        * ``head_weights`` is registered as a parameter but forward() does not use it.
        * Only the first d_latent_kv // n_heads channels of each Q head take part
          in the attention scores; the remaining Q channels are discarded.
    """

    def __init__(self, d_model, d_latent_kv, n_heads, d_rope, dropout=0.1, gqa_groups=1,
                 rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0):
        super().__init__()
        self.d_model = d_model
        self.d_latent_kv = d_latent_kv
        self.n_heads = n_heads
        self.d_rope = d_rope
        # Kept as an attribute only; nothing in this class reads it.
        self.gqa_groups = gqa_groups

        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        assert d_latent_kv % n_heads == 0, f"d_latent_kv ({d_latent_kv}) must be divisible by n_heads ({n_heads})"
        # forward() slices each Q head down to the latent head width. A latent
        # width larger than the full width would leave Q narrower than K and
        # fail inside the attention kernel, so reject it here with a clear message.
        assert d_latent_kv <= d_model, f"d_latent_kv ({d_latent_kv}) must not exceed d_model ({d_model})"

        self.d_head_full = d_model // n_heads  # Full head dimension for Q
        self.d_head_latent = d_latent_kv // n_heads  # Latent head dimension for K/V

        # Scaling factor for attention scores (used by the non-fused fallback path)
        self.scale = 1.0 / math.sqrt(self.d_head_latent)

        # Layer norm before attention for stability
        self.norm = nn.LayerNorm(d_model)

        # Q projection: d_model -> d_model (full dimension)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)

        # K/V projections: d_model -> d_latent_kv (compressed)
        self.k_proj = nn.Linear(d_model, d_latent_kv, bias=False)
        self.v_proj = nn.Linear(d_model, d_latent_kv, bias=False)

        # RoPE for position encoding with YaRN support
        self.rotary = RotaryEmbedding(
            d_rope,
            rope_scale=rope_scale,
            max_seq_length=max_seq_length,
            yarn_alpha=yarn_alpha,
        )

        # Output projection: d_latent_kv -> d_model
        self.out_proj = nn.Linear(d_latent_kv, d_model, bias=False)

        # Per-head weights; registered (and therefore part of the state dict)
        # but not applied anywhere in forward().
        self.head_weights = nn.Parameter(torch.ones(n_heads))

        # Dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def _split_heads(self, t, d_head):
        """Reshape [B, seq_len, n_heads * d_head] -> [B, n_heads, seq_len, d_head]."""
        B, seq_len, _ = t.shape
        return t.view(B, seq_len, self.n_heads, d_head).transpose(1, 2)

    def _apply_rope(self, q, k, seq_len, device):
        """Rotate the first d_rope channels of every Q and K head; other channels pass through."""
        rotary_emb = self.rotary(seq_len, device)  # [seq_len, d_rope]
        cos = torch.cos(rotary_emb).unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, d_rope]
        sin = torch.sin(rotary_emb).unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, d_rope]

        q_rope = apply_rotary(q[..., :self.d_rope], cos, sin)
        q = torch.cat([q_rope, q[..., self.d_rope:]], dim=-1)

        k_rope = apply_rotary(k[..., :self.d_rope], cos, sin)
        k = torch.cat([k_rope, k[..., self.d_rope:]], dim=-1)
        return q, k

    @staticmethod
    def _to_bool_mask(attention_mask):
        """Convert a {0, 1} mask into the boolean form used for attention (True = attend)."""
        if attention_mask is None:
            return None
        if attention_mask.dim() == 2:
            # [B, seq_len] -> [B, 1, 1, seq_len]: broadcast over heads and query positions
            return attention_mask.bool().unsqueeze(1).unsqueeze(1)
        # Any other rank is taken as already broadcastable, e.g. [B, 1, seq_len, seq_len]
        return attention_mask.bool()

    def _attend(self, q, k, v, attn_mask_bool, dropout_p):
        """Scaled dot-product attention; fused kernel when available, manual otherwise."""
        if hasattr(F, "scaled_dot_product_attention"):
            # The `scale` keyword only exists from PyTorch 2.1 on, while this
            # branch is also taken on 2.0. The default scale is
            # 1 / sqrt(q.size(-1)), which equals self.scale, so it is omitted.
            return F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask_bool,
                dropout_p=dropout_p,
            )  # [B, n_heads, seq_len, d_head_latent]

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if attn_mask_bool is not None:
            scores = scores.masked_fill(~attn_mask_bool, torch.finfo(scores.dtype).min)
        attn_weights = F.softmax(scores, dim=-1)
        if dropout_p > 0.0:
            # dropout_p is already 0.0 outside training, so training=True is safe here
            attn_weights = F.dropout(attn_weights, p=dropout_p, training=True)
        return torch.matmul(attn_weights, v)

    def forward(self, x, attention_mask=None):
        """
        Args:
            x: [B, seq_len, d_model]
            attention_mask: [B, seq_len] (1 = keep, 0 = mask) or
                [B, 1, seq_len, seq_len] (1 = keep, 0 = mask); None attends everywhere
        Returns:
            out: [B, seq_len, d_model]
        """
        B, seq_len, _ = x.shape

        # Pre-norm
        x_norm = self.norm(x)

        # Project and split into heads
        q = self._split_heads(self.q_proj(x_norm), self.d_head_full)    # [B, n_heads, seq_len, d_head_full]
        k = self._split_heads(self.k_proj(x_norm), self.d_head_latent)  # [B, n_heads, seq_len, d_head_latent]
        v = self._split_heads(self.v_proj(x_norm), self.d_head_latent)  # [B, n_heads, seq_len, d_head_latent]

        if self.d_rope > 0:
            q, k = self._apply_rope(q, k, seq_len, x.device)

        # Only the first d_head_latent channels of Q are matched against K
        q_for_attn = q[..., :self.d_head_latent]

        attn_mask_bool = self._to_bool_mask(attention_mask)

        # Attention dropout is active only in training mode
        dropout_p = self.attn_dropout.p if self.training else 0.0

        out_heads = self._attend(q_for_attn, k, v, attn_mask_bool, dropout_p)

        # [B, n_heads, seq_len, d_head_latent] -> [B, seq_len, d_latent_kv]
        out_concat = out_heads.transpose(1, 2).reshape(B, seq_len, self.d_latent_kv)

        # Project back to d_model
        out = self.out_proj(out_concat)  # [B, seq_len, d_model]
        return self.proj_dropout(out)
class AttentionBlock(nn.Module):
    """
    Transformer block: an MLA attention sub-layer followed by a SwiGLU feed-forward sub-layer.

    Structure:
        x -> MLA -> dropout -> (+ x) -> optional RMS rescale -> h
        h -> LayerNorm -> SwiGLU FFN -> dropout -> (+ h) -> optional RMS rescale -> out

    The MLA module normalises its own input, so no separate norm is applied
    before it here. The RMS rescale after each residual add is governed by the
    ``residual_rms_*`` arguments and is a no-op with the defaults.

    NOTE(review): the MLA output has already gone through the MLA's own
    projection dropout before this block's dropout is applied -- confirm the
    double dropout on the attention branch is intended.
    """

    def __init__(self, d_model, d_latent_kv, n_heads, d_rope, d_ff, dropout=0.1, gqa_groups=1,
                 rope_scale=40.0, max_seq_length=1024, yarn_alpha=1.0,
                 residual_rms_norm=False, residual_rms_target=1.0, residual_rms_cap=None,
                 residual_rms_eps=1e-6):
        super().__init__()

        # Settings for the rescale applied after each residual add
        self.residual_rms_norm = residual_rms_norm
        self.residual_rms_target = residual_rms_target
        self.residual_rms_cap = residual_rms_cap
        self.residual_rms_eps = residual_rms_eps

        # Attention sub-layer
        self.mla = DeepSeekMLA(
            d_model, d_latent_kv, n_heads, d_rope, dropout, gqa_groups,
            rope_scale=rope_scale,
            max_seq_length=max_seq_length,
            yarn_alpha=yarn_alpha,
        )

        # SwiGLU feed-forward sub-layer
        self.ff_norm = nn.LayerNorm(d_model)
        self.ff_gate = nn.Linear(d_model, d_ff, bias=False)
        self.ff_value = nn.Linear(d_model, d_ff, bias=False)
        self.ff_out = nn.Linear(d_ff, d_model, bias=False)

        # Shared by both residual branches
        self.dropout = nn.Dropout(dropout)

    def _rescale(self, x):
        """Apply the configured residual RMS rescale (identity with default settings)."""
        return _residual_rms_norm(
            x,
            enabled=self.residual_rms_norm,
            target=self.residual_rms_target,
            eps=self.residual_rms_eps,
            cap=self.residual_rms_cap,
        )

    def forward(self, x, attention_mask=None):
        """
        Args:
            x: [B, seq_len, d_model]
            attention_mask: [B, seq_len] or [B, 1, seq_len, seq_len]; passed to the MLA
        Returns:
            out: [B, seq_len, d_model]
        """
        # Attention branch
        attended = self.mla(x, attention_mask)
        x = self._rescale(x + self.dropout(attended))

        # Feed-forward branch: value path gated by SiLU(gate path)
        normed = self.ff_norm(x)
        gate = self.ff_gate(normed)
        value = self.ff_value(normed)
        hidden = self.ff_out(value * F.silu(gate))
        x = self._rescale(x + self.dropout(hidden))

        return x