"""
Latent Attention Implementation for nanoKimi

This module implements the Latent Attention mechanism used in Kimi-K2,
which compresses attention representations to reduce the memory footprint
while maintaining performance on long sequences.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

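# Shape sketch of the latent attention below (a summary of the code, not an
# exhaustive trace; dropout omitted). For input x of shape (B, T, n_embd):
#   q            = q_proj(x)            -> (B, n_head, T, head_dim)
#   q_compressed = q_compress(q)        -> (B, n_head, T, latent_dim)
#   k, v         = k_proj(x), v_proj(x) -> (B, n_head, T, latent_dim)
#   att          = softmax(q_compressed @ k^T / sqrt(latent_dim))
#   y            = o_proj(att @ v)      -> (B, T, n_embd)
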

class LatentAttention(nn.Module):
    """
    Latent Attention mechanism that compresses attention representations.

    The key idea is to project keys and values into a lower-dimensional
    latent space, reducing memory usage while preserving attention quality.

    Args:
        n_embd: embedding dimension
        n_head: number of attention heads
        latent_dim: dimension of the latent space
        dropout: dropout probability
        bias: whether to use bias in linear layers
    """
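
    # Illustrative sizing (assumed numbers, not taken from the original code):
    # with n_embd=768 and n_head=12, head_dim is 64, so standard attention
    # caches 2 * 768 = 1536 key/value floats per token per layer, while
    # latent_dim=32 reduces this to 2 * 12 * 32 = 768 floats per token,
    # roughly halving key/value memory.
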
    def __init__(self, n_embd, n_head, latent_dim=64, dropout=0.0, bias=True):
        super().__init__()
        assert n_embd % n_head == 0

        self.n_embd = n_embd
        self.n_head = n_head
        self.latent_dim = latent_dim
        self.head_dim = n_embd // n_head

        # Queries keep the full embedding dimension.
        self.q_proj = nn.Linear(n_embd, n_embd, bias=bias)

        # Keys and values are projected directly into the smaller latent space.
        self.k_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)
        self.v_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)

        # Per-head query compression from head_dim down to latent_dim.
        # Registered here (rather than lazily in forward) so the weights are
        # part of the module's parameters, optimizer state, and checkpoints.
        self.q_compress = nn.Linear(self.head_dim, latent_dim, bias=False)

        # Project the concatenated latent heads back to the embedding dimension.
        self.o_proj = nn.Linear(n_head * latent_dim, n_embd, bias=bias)

        self.dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        self.scale = 1.0 / math.sqrt(latent_dim)

    def forward(self, x, mask=None):
        B, T, C = x.size()

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Split heads: queries are (B, n_head, T, head_dim); keys and values
        # are already in the latent space, (B, n_head, T, latent_dim).
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.latent_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.latent_dim).transpose(1, 2)

        # Compress queries so they match the latent dimension of the keys.
        q_compressed = self.q_compress(q)

        att = torch.matmul(q_compressed, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            att = att.masked_fill(mask == 0, float('-inf'))
        else:
            # Default to a causal mask when none is provided.
            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            att = att.masked_fill(causal_mask == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = torch.matmul(att, v)

        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.latent_dim)
        y = self.o_proj(y)
        y = self.resid_dropout(y)

        return y


class MultiHeadAttention(nn.Module):
    """
    Standard multi-head attention for comparison.
    """

    def __init__(self, n_embd, n_head, dropout=0.0, bias=True):
        super().__init__()
        assert n_embd % n_head == 0

        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=bias)

        self.o_proj = nn.Linear(n_embd, n_embd, bias=bias)

        self.dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, x, mask=None):
        B, T, C = x.size()

        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        att = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            att = att.masked_fill(mask == 0, float('-inf'))
        else:
            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            att = att.masked_fill(causal_mask == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = torch.matmul(att, v)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.o_proj(y)
        y = self.resid_dropout(y)

        return y
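

# Minimal smoke test (my addition, not part of the original module): it checks
# that both attention blocks produce identically shaped outputs and prints
# their parameter counts, since LatentAttention uses smaller key/value
# projections. The sizes below are arbitrary illustrative values.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, n_embd, n_head = 2, 16, 128, 8

    x = torch.randn(B, T, n_embd)
    latent_attn = LatentAttention(n_embd, n_head, latent_dim=8)
    mha = MultiHeadAttention(n_embd, n_head)

    y_latent = latent_attn(x)
    y_mha = mha(x)
    assert y_latent.shape == y_mha.shape == (B, T, n_embd)

    latent_params = sum(p.numel() for p in latent_attn.parameters())
    mha_params = sum(p.numel() for p in mha.parameters())
    print(f"output shape: {tuple(y_latent.shape)}")
    print(f"LatentAttention params: {latent_params}, MultiHeadAttention params: {mha_params}")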