"""
Latent Attention Implementation for nanoKimi

This module implements the Latent Attention mechanism used in Kimi-K2,
which compresses attention representations to reduce memory footprint
while maintaining performance on long sequences.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class LatentAttention(nn.Module):
    """
    Latent Attention mechanism that compresses attention representations
    
    The key idea is to project keys and values into a lower-dimensional
    latent space, reducing memory usage while preserving attention quality.
    
    Args:
        n_embd: embedding dimension
        n_head: number of attention heads
        latent_dim: dimension of the latent space
        dropout: dropout probability
        bias: whether to use bias in linear layers
    """
    
    def __init__(self, n_embd, n_head, latent_dim=64, dropout=0.0, bias=True):
        super().__init__()
        assert n_embd % n_head == 0
        
        self.n_embd = n_embd
        self.n_head = n_head
        self.latent_dim = latent_dim
        self.head_dim = n_embd // n_head
        
        # Query projection (full dimension)
        self.q_proj = nn.Linear(n_embd, n_embd, bias=bias)
        
        # Key and Value projections to latent space
        self.k_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)
        self.v_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)
        
        # Query compression into the latent space (shared across heads),
        # registered here so optimizers and checkpoints see its parameters
        self.q_compress = nn.Linear(self.head_dim, latent_dim, bias=False)
        
        # Output projection
        self.o_proj = nn.Linear(n_head * latent_dim, n_embd, bias=bias)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        
        # Scale factor for attention
        self.scale = 1.0 / math.sqrt(latent_dim)
        
    def forward(self, x, mask=None):
        B, T, C = x.size()  # batch, sequence length, embedding dim
        
        # Project to query, key, value
        q = self.q_proj(x)  # (B, T, n_embd)
        k = self.k_proj(x)  # (B, T, n_head * latent_dim)
        v = self.v_proj(x)  # (B, T, n_head * latent_dim)
        
        # Reshape for multi-head attention
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, self.latent_dim).transpose(1, 2)  # (B, n_head, T, latent_dim)
        v = v.view(B, T, self.n_head, self.latent_dim).transpose(1, 2)  # (B, n_head, T, latent_dim)
        
        # Compress queries to the latent dimension so scores can be computed
        # against the latent-space keys
        q_compressed = self.q_compress(q)  # (B, n_head, T, latent_dim)
        
        # Compute attention scores in latent space
        att = torch.matmul(q_compressed, k.transpose(-2, -1)) * self.scale  # (B, n_head, T, T)
        
        # Apply causal mask
        if mask is not None:
            att = att.masked_fill(mask == 0, float('-inf'))
        else:
            # Create causal mask
            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            att = att.masked_fill(causal_mask == 0, float('-inf'))
        
        # Apply softmax
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        
        # Apply attention to values
        y = torch.matmul(att, v)  # (B, n_head, T, latent_dim)
        
        # Reshape and project back
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.latent_dim)
        y = self.o_proj(y)
        y = self.resid_dropout(y)
        
        return y


class MultiHeadAttention(nn.Module):
    """
    Standard multi-head attention for comparison
    """
    
    def __init__(self, n_embd, n_head, dropout=0.0, bias=True):
        super().__init__()
        assert n_embd % n_head == 0
        
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        
        # QKV projection
        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        
        # Output projection
        self.o_proj = nn.Linear(n_embd, n_embd, bias=bias)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        
        # Scale factor
        self.scale = 1.0 / math.sqrt(self.head_dim)
        
    def forward(self, x, mask=None):
        B, T, C = x.size()
        
        # Compute QKV
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)
        
        # Reshape for multi-head attention
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        
        # Compute attention
        att = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        
        # Apply causal mask
        if mask is not None:
            att = att.masked_fill(mask == 0, float('-inf'))
        else:
            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            att = att.masked_fill(causal_mask == 0, float('-inf'))
        
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        
        # Apply attention to values
        y = torch.matmul(att, v)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.o_proj(y)
        y = self.resid_dropout(y)
        
        return y
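

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module):
    # runs both attention variants on the same input and compares output shapes
    # and parameter counts. The sizes below are arbitrary example values chosen
    # so that latent_dim < head_dim, which is where the compression pays off.
    torch.manual_seed(0)
    n_embd, n_head, latent_dim = 512, 8, 32
    x = torch.randn(2, 16, n_embd)  # (batch, seq_len, n_embd)

    latent_attn = LatentAttention(n_embd, n_head, latent_dim=latent_dim)
    full_attn = MultiHeadAttention(n_embd, n_head)

    y_latent = latent_attn(x)
    y_full = full_attn(x)
    assert y_latent.shape == x.shape and y_full.shape == x.shape

    def n_params(module):
        return sum(p.numel() for p in module.parameters())

    print(f"LatentAttention params:    {n_params(latent_attn):,}")
    print(f"MultiHeadAttention params: {n_params(full_attn):,}")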