# Spaces: Paused
# File size: 2,891 Bytes
# 8125804
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with bias-free projections.

    Args:
        d_model: Feature size of the inputs and of the returned tensor.
        n_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        # Raise explicitly (rather than ``assert``) so the check is not
        # stripped when Python runs with ``-O``.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        causal: bool = False,
    ) -> torch.Tensor:
        """Attend from ``q`` over ``k``/``v`` and project the result.

        Args:
            q: Queries of shape ``[B, T, d_model]``.
            k: Keys of shape ``[B, T_k, d_model]``.
            v: Values of shape ``[B, T_k, d_model]``.
            causal: When true, query position ``i`` may only attend to key
                positions ``j <= i``. The mask is anchored at the top-left
                corner, so this also holds when ``T != T_k``.

        Returns:
            Tensor of shape ``[B, T, d_model]``.
        """
        B, T, _ = q.shape

        # Split the model dimension into heads: [B, n_heads, length, d_head].
        q = self.w_q(q).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.w_k(k).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = self.w_v(v).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product scores: [B, n_heads, T, T_k].
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)

        if causal:
            T_k = k.size(2)
            # True strictly above the main diagonal, i.e. at "future" keys.
            mask = torch.triu(
                torch.ones(T, T_k, device=q.device, dtype=torch.bool),
                diagonal=1,
            )
            # -inf scores become exactly zero weight after the softmax.
            scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # Merge the heads back into the model dimension: [B, T, d_model].
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.w_o(out)
class AttentionResidual(nn.Module):
    """Replace standard residual with attention over previous layer outputs.

    Each layer learns input-dependent weights for aggregating previous
    representations. Solves PreNorm dilution (Kimi Team, 2026,
    arXiv:2603.15031).

    Args:
        d_model: Feature size of the current and previous representations.
        layer_idx: Index of the owning layer; stored on the module but not
            used in the computation here.
    """

    def __init__(self, d_model: int, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.query_proj = nn.Linear(d_model, d_model, bias=False)
        self.key_proj = nn.Linear(d_model, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(
        self, current: torch.Tensor, layer_outputs: list[torch.Tensor]
    ) -> torch.Tensor:
        """Mix previous layer outputs into ``current`` and normalise.

        Args:
            current: Current representation, ``[B, T, D]``.
            layer_outputs: Previous layer outputs, each ``[B, T, D]``.
                May be empty.

        Returns:
            ``layer_norm(current)`` when there is no history; otherwise
            ``layer_norm(current + aggregated)``, where ``aggregated`` is a
            softmax-weighted sum of the previous outputs, computed
            independently at every position. Shape ``[B, T, D]``.
        """
        if not layer_outputs:
            # Nothing to attend over: only the normalisation applies.
            return self.layer_norm(current)

        # Stack previous outputs: [B, T, N, D]
        stacked = torch.stack(layer_outputs, dim=2)

        q = self.query_proj(current).unsqueeze(2)  # [B, T, 1, D]
        k = self.key_proj(stacked)  # [B, T, N, D]

        # One score per previous layer at each position, scaled by sqrt(D).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
            current.size(-1)
        )
        weights = F.softmax(scores, dim=-1)  # [B, T, 1, N]

        # The raw (unprojected) outputs act as the attention values.
        aggregated = torch.matmul(weights, stacked).squeeze(2)  # [B, T, D]
        return self.layer_norm(current + aggregated)
|