File size: 2,891 Bytes
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected by a bias-free linear layer,
    split into ``n_heads`` heads of width ``d_model // n_heads``, attended
    independently, then concatenated and passed through a bias-free output
    projection.

    Args:
        d_model: Feature width of the inputs and of the output.
        n_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        # Raise explicitly rather than ``assert``: assertions are stripped
        # under ``python -O``, which would let a bad configuration through and
        # fail later with an opaque reshape error inside ``forward``.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        causal: bool = False,
    ) -> torch.Tensor:
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input of shape ``(B, T, d_model)``.
            k: Key input of shape ``(B, T_k, d_model)``.
            v: Value input of shape ``(B, T_k, d_model)``.
            causal: When true, query position ``i`` cannot attend to key
                positions ``j > i``. The mask is anchored at the top-left
                corner, so when ``T != T_k`` query ``i`` still sees only keys
                ``0..i``.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, _ = q.shape
        # Project, then split the feature axis into heads: (B, heads, len, d_head).
        q = self.w_q(q).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.w_k(k).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = self.w_v(v).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product logits: (B, heads, T, T_k).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)

        if causal:
            T_k = k.size(2)
            # True strictly above the main diagonal marks "future" keys.
            mask = torch.triu(
                torch.ones(T, T_k, device=q.device, dtype=torch.bool), diagonal=1
            )
            # -inf logits become exactly zero weight after the softmax.
            scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # Merge heads back into the feature axis: (B, T, d_model).
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.w_o(out)


class AttentionResidual(nn.Module):
    """Residual connection that attends over earlier layer outputs.

    Rather than adding a single skip input, the current representation forms
    a query against the outputs of previous layers and adds their
    softmax-weighted mixture, so the aggregation weights depend on the input.
    Proposed as a remedy for PreNorm dilution (Kimi Team, 2026,
    arXiv:2603.15031).
    """

    def __init__(self, d_model: int, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.query_proj = nn.Linear(d_model, d_model, bias=False)
        self.key_proj = nn.Linear(d_model, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(
        self, current: torch.Tensor, layer_outputs: list[torch.Tensor]
    ) -> torch.Tensor:
        """Mix previous layer outputs into ``current`` and normalise.

        With no previous outputs, ``current`` is only layer-normalised.
        Otherwise the result is ``layer_norm(current + mixture)``, where the
        mixture is the attention-weighted sum of ``layer_outputs``.
        """
        # Nothing to attend over yet: fall back to plain normalisation.
        if not layer_outputs:
            return self.layer_norm(current)

        # N previous outputs placed on a new axis: [B, T, N, D].
        history = torch.stack(layer_outputs, dim=2)

        query = self.query_proj(current).unsqueeze(2)  # [B, T, 1, D]
        keys = self.key_proj(history)  # [B, T, N, D]

        # One logit per previous layer at each position: [B, T, 1, N].
        scale = math.sqrt(current.size(-1))
        logits = (query @ keys.transpose(-2, -1)) / scale
        mix = logits.softmax(dim=-1)

        # Weighted sum of the unprojected previous outputs: [B, T, D].
        pooled = (mix @ history).squeeze(2)

        return self.layer_norm(current + pooled)