"""
Latent Attention Implementation for nanoKimi
This module implements the Latent Attention mechanism used in Kimi-K2,
which compresses attention representations to reduce memory footprint
while maintaining performance on long sequences.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class LatentAttention(nn.Module):
"""
    Latent Attention mechanism that compresses attention representations.
The key idea is to project keys and values into a lower-dimensional
latent space, reducing memory usage while preserving attention quality.
Args:
n_embd: embedding dimension
n_head: number of attention heads
latent_dim: dimension of the latent space
dropout: dropout probability
bias: whether to use bias in linear layers
"""
def __init__(self, n_embd, n_head, latent_dim=64, dropout=0.0, bias=True):
super().__init__()
assert n_embd % n_head == 0
self.n_embd = n_embd
self.n_head = n_head
self.latent_dim = latent_dim
self.head_dim = n_embd // n_head
        # Query projection (full dimension)
        self.q_proj = nn.Linear(n_embd, n_embd, bias=bias)
        # Key and Value projections to latent space
        self.k_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)
        self.v_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)
        # Learnable compression of per-head queries into the latent space.
        # Defined here (rather than lazily in forward) so the weights are
        # registered with the module and follow .to()/optimizer setup.
        self.q_compress = nn.Linear(self.head_dim, latent_dim, bias=False)
        # Output projection
        self.o_proj = nn.Linear(n_head * latent_dim, n_embd, bias=bias)
        # Dropout
        self.dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        # Scale factor for attention (scores are computed in the latent space)
        self.scale = 1.0 / math.sqrt(latent_dim)
def forward(self, x, mask=None):
B, T, C = x.size() # batch, sequence length, embedding dim
# Project to query, key, value
q = self.q_proj(x) # (B, T, n_embd)
k = self.k_proj(x) # (B, T, n_head * latent_dim)
v = self.v_proj(x) # (B, T, n_head * latent_dim)
# Reshape for multi-head attention
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, n_head, T, head_dim)
k = k.view(B, T, self.n_head, self.latent_dim).transpose(1, 2) # (B, n_head, T, latent_dim)
v = v.view(B, T, self.n_head, self.latent_dim).transpose(1, 2) # (B, n_head, T, latent_dim)
        # Compress queries to the latent dimension for the attention computation,
        # using the learnable compression matrix defined in __init__
        q_compressed = self.q_compress(q)  # (B, n_head, T, latent_dim)
# Compute attention scores in latent space
att = torch.matmul(q_compressed, k.transpose(-2, -1)) * self.scale # (B, n_head, T, T)
# Apply causal mask
if mask is not None:
att = att.masked_fill(mask == 0, float('-inf'))
else:
# Create causal mask
causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
att = att.masked_fill(causal_mask == 0, float('-inf'))
# Apply softmax
att = F.softmax(att, dim=-1)
att = self.dropout(att)
# Apply attention to values
y = torch.matmul(att, v) # (B, n_head, T, latent_dim)
# Reshape and project back
y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.latent_dim)
y = self.o_proj(y)
y = self.resid_dropout(y)
return y
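
# Rough memory arithmetic for the compression (illustrative sketch with
# hypothetical numbers, not taken from the nanoKimi config): with n_embd = 768,
# n_head = 12 (head_dim = 64) and latent_dim = 16, the projected keys and values
# per token shrink from 2 * n_embd = 1536 values to 2 * n_head * latent_dim = 384
# values, a 4x reduction, while queries keep their full head_dim and are mapped
# into the latent space on the fly by q_compress.
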
class MultiHeadAttention(nn.Module):
"""
Standard multi-head attention for comparison
"""
def __init__(self, n_embd, n_head, dropout=0.0, bias=True):
super().__init__()
assert n_embd % n_head == 0
self.n_embd = n_embd
self.n_head = n_head
self.head_dim = n_embd // n_head
# QKV projection
self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=bias)
# Output projection
self.o_proj = nn.Linear(n_embd, n_embd, bias=bias)
# Dropout
self.dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
# Scale factor
self.scale = 1.0 / math.sqrt(self.head_dim)
def forward(self, x, mask=None):
B, T, C = x.size()
# Compute QKV
qkv = self.qkv_proj(x)
q, k, v = qkv.chunk(3, dim=-1)
# Reshape for multi-head attention
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
# Compute attention
att = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# Apply causal mask
if mask is not None:
att = att.masked_fill(mask == 0, float('-inf'))
else:
causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
att = att.masked_fill(causal_mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.dropout(att)
# Apply attention to values
y = torch.matmul(att, v)
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.o_proj(y)
y = self.resid_dropout(y)
return y
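

# Minimal smoke test comparing the two attention variants. This is an
# illustrative sketch: the configuration below (n_embd=128, n_head=4,
# latent_dim=16, batch of 2, sequence length of 16) is an arbitrary choice,
# not part of the nanoKimi training setup.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 16, 128)  # (batch, seq_len, n_embd)

    latent_attn = LatentAttention(n_embd=128, n_head=4, latent_dim=16)
    standard_attn = MultiHeadAttention(n_embd=128, n_head=4)

    # Both variants preserve the input shape: (2, 16, 128)
    print("latent output:  ", tuple(latent_attn(x).shape))
    print("standard output:", tuple(standard_attn(x).shape))

    # Parameter counts illustrate the compression: here latent_dim (16) is half
    # of head_dim (32), so the K/V and output projections of LatentAttention map
    # to n_head * latent_dim = 64 dimensions instead of n_embd = 128.
    print("latent params:  ", sum(p.numel() for p in latent_attn.parameters()))
    print("standard params:", sum(p.numel() for p in standard_attn.parameters()))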