"""
Latent Attention Implementation for nanoKimi
This module implements the Latent Attention mechanism used in Kimi-K2,
which compresses attention representations to reduce memory footprint
while maintaining performance on long sequences.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
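
# A rough memory sketch (illustrative numbers, not from the original): per
# token, standard attention caches K and V of combined size
# 2 * n_embd = 2 * n_head * head_dim floats, while latent attention caches
# 2 * n_head * latent_dim. For example, with n_embd=768, n_head=12
# (head_dim=64), and latent_dim=32, the K/V footprint halves from 1536 to
# 768 floats per token.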
class LatentAttention(nn.Module):
"""
Latent Attention mechanism that compresses attention representations
The key idea is to project keys and values into a lower-dimensional
latent space, reducing memory usage while preserving attention quality.
Args:
n_embd: embedding dimension
n_head: number of attention heads
latent_dim: dimension of the latent space
dropout: dropout probability
bias: whether to use bias in linear layers
"""
def __init__(self, n_embd, n_head, latent_dim=64, dropout=0.0, bias=True):
super().__init__()
assert n_embd % n_head == 0
self.n_embd = n_embd
self.n_head = n_head
self.latent_dim = latent_dim
self.head_dim = n_embd // n_head
        # Query projection (full dimension)
        self.q_proj = nn.Linear(n_embd, n_embd, bias=bias)
        # Key and value projections into the latent space
        self.k_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)
        self.v_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)
        # Learnable compression of per-head queries into the latent space.
        # Registered here rather than lazily in forward() so the parameter is
        # tracked by the optimizer, moved by .to(), and saved in state_dict.
        self.q_compress = nn.Linear(self.head_dim, latent_dim, bias=False)
        # Output projection back to the embedding dimension
        self.o_proj = nn.Linear(n_head * latent_dim, n_embd, bias=bias)
# Dropout
self.dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
        # Scale by 1/sqrt(latent_dim) rather than 1/sqrt(head_dim), since
        # attention scores are computed in the latent space
        self.scale = 1.0 / math.sqrt(latent_dim)
def forward(self, x, mask=None):
B, T, C = x.size() # batch, sequence length, embedding dim
# Project to query, key, value
q = self.q_proj(x) # (B, T, n_embd)
k = self.k_proj(x) # (B, T, n_head * latent_dim)
v = self.v_proj(x) # (B, T, n_head * latent_dim)
# Reshape for multi-head attention
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, n_head, T, head_dim)
k = k.view(B, T, self.n_head, self.latent_dim).transpose(1, 2) # (B, n_head, T, latent_dim)
v = v.view(B, T, self.n_head, self.latent_dim).transpose(1, 2) # (B, n_head, T, latent_dim)
        # Compress queries to the latent dimension using the learnable
        # compression matrix registered in __init__ (creating it lazily here
        # would hide the parameter from the optimizer and state_dict)
        q_compressed = self.q_compress(q)  # (B, n_head, T, latent_dim)
# Compute attention scores in latent space
att = torch.matmul(q_compressed, k.transpose(-2, -1)) * self.scale # (B, n_head, T, T)
# Apply causal mask
if mask is not None:
att = att.masked_fill(mask == 0, float('-inf'))
else:
# Create causal mask
causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
att = att.masked_fill(causal_mask == 0, float('-inf'))
# Apply softmax
att = F.softmax(att, dim=-1)
att = self.dropout(att)
# Apply attention to values
y = torch.matmul(att, v) # (B, n_head, T, latent_dim)
# Reshape and project back
y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.latent_dim)
y = self.o_proj(y)
y = self.resid_dropout(y)
return y
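
# Note on complexity: latent attention still materializes (B, n_head, T, T)
# attention scores, so compute remains O(T^2). The savings come from smaller
# K/V tensors when latent_dim < head_dim, which shrink the inference-time
# KV cache; the standard attention below serves as the comparison baseline.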
class MultiHeadAttention(nn.Module):
"""
Standard multi-head attention for comparison
"""
def __init__(self, n_embd, n_head, dropout=0.0, bias=True):
super().__init__()
assert n_embd % n_head == 0
self.n_embd = n_embd
self.n_head = n_head
self.head_dim = n_embd // n_head
# QKV projection
self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=bias)
# Output projection
self.o_proj = nn.Linear(n_embd, n_embd, bias=bias)
# Dropout
self.dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
# Scale factor
self.scale = 1.0 / math.sqrt(self.head_dim)
def forward(self, x, mask=None):
B, T, C = x.size()
# Compute QKV
qkv = self.qkv_proj(x)
q, k, v = qkv.chunk(3, dim=-1)
# Reshape for multi-head attention
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
# Compute attention
att = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# Apply causal mask
if mask is not None:
att = att.masked_fill(mask == 0, float('-inf'))
else:
causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
att = att.masked_fill(causal_mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.dropout(att)
# Apply attention to values
y = torch.matmul(att, v)
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.o_proj(y)
y = self.resid_dropout(y)
return y
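
# Minimal smoke test: a sketch, not part of the original module. It assumes
# arbitrary example sizes (B=2, T=16, n_embd=256, n_head=8, latent_dim=32),
# checks that both variants map (B, T, n_embd) -> (B, T, n_embd), and
# compares their parameter counts.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, n_embd, n_head = 2, 16, 256, 8

    x = torch.randn(B, T, n_embd)
    latent_attn = LatentAttention(n_embd, n_head, latent_dim=32)
    standard_attn = MultiHeadAttention(n_embd, n_head)

    # Both modules should preserve the input shape
    assert latent_attn(x).shape == (B, T, n_embd)
    assert standard_attn(x).shape == (B, T, n_embd)

    n_latent = sum(p.numel() for p in latent_attn.parameters())
    n_standard = sum(p.numel() for p in standard_attn.parameters())
    print(f"latent attention params:   {n_latent:,}")
    print(f"standard attention params: {n_standard:,}")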