"""
Resonance 200M — Content + RRPRAM dual attention transformer.
Low-rank RRPRAM (Wr = Wr_a @ Wr_b), SwiGLU MLP, RMSNorm, RoPE.
Content attention uses FlashAttention via F.scaled_dot_product_attention.

Architecture matches resonance-bpe.c (with low-rank extension).
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class ResonanceBlock(nn.Module):
    """
    Dual attention block: Content (QKV + RoPE + FlashAttn) + RRPRAM (low-rank Wr) + SwiGLU MLP.
    """

    def __init__(self, config):
        super().__init__()
        E = config['n_embd']
        H = config['n_head']
        D = config['head_dim']
        R = config['rrpram_rank']
        T = config['context_len']
        M = config['ffn_dim']

        self.n_head = H
        self.head_dim = D
        self.n_embd = E

        # Pre-attention norm
        self.norm1 = RMSNorm(E)

        # Content attention (MHA): Q, K, V
        self.wq = nn.Linear(E, H * D, bias=False)
        self.wk = nn.Linear(E, H * D, bias=False)
        self.wv = nn.Linear(E, H * D, bias=False)

        # RRPRAM attention (low-rank): Wr_a[H, E, R] @ Wr_b[H, R, T] = Wr[H, E, T]
        self.wr_a = nn.Parameter(torch.randn(H, E, R) * (2.0 / E) ** 0.5)
        self.wr_b = nn.Parameter(torch.randn(H, R, T) * (2.0 / R) ** 0.5)

        # Per-head gate: sigmoid(gate) blends content vs RRPRAM
        self.gate = nn.Parameter(torch.zeros(H))  # init 0 → sigmoid(0) = 0.5 = balanced

        # Output projection
        self.wo = nn.Linear(E, E, bias=False)

        # Pre-MLP norm
        self.norm2 = RMSNorm(E)

        # SwiGLU MLP
        self.mlp_gate = nn.Linear(E, M, bias=False)
        self.mlp_up = nn.Linear(E, M, bias=False)
        self.mlp_down = nn.Linear(M, E, bias=False)

        # Init output projections with smaller std (GPT-2 convention)
        n_layer = config['n_layer']
        nn.init.normal_(self.wo.weight, std=0.02 / math.sqrt(2 * n_layer))
        nn.init.normal_(self.mlp_down.weight, std=0.02 / math.sqrt(2 * n_layer))

    def forward(self, x, rope_cos, rope_sin, mask):
        B, T, E = x.shape
        H = self.n_head
        D = self.head_dim

        # Pre-norm
        xn = self.norm1(x)

        # === Content attention with RoPE + FlashAttention ===
        q = self.wq(xn).view(B, T, H, D).transpose(1, 2)  # [B, H, T, D]
        k = self.wk(xn).view(B, T, H, D).transpose(1, 2)
        v = self.wv(xn).view(B, T, H, D).transpose(1, 2)

        # Apply RoPE to Q and K
        q = _apply_rope(q, rope_cos, rope_sin)
        k = _apply_rope(k, rope_cos, rope_sin)

        # SDPA dispatches to FlashAttention when available: O(T) attention memory instead of O(T²)
        c_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # [B, H, T, D]

        # === RRPRAM attention (low-rank) ===
        # Wr = Wr_a @ Wr_b: [H, E, R] @ [H, R, T] = [H, E, T]
        # Score: xn @ Wr → [B, T, E] @ [H, E, T] → [B, H, T, T]
        xn_h = xn.unsqueeze(1).expand(-1, H, -1, -1)  # [B, H, T, E]
        # Low-rank: (xn @ Wr_a) @ Wr_b
        temp = torch.einsum('bhie,her->bhir', xn_h, self.wr_a)  # [B, H, T, R]
        r_attn = torch.einsum('bhir,hrj->bhij', temp, self.wr_b[:, :, :T])  # [B, H, T, T]
        r_attn = r_attn * (D ** -0.5)
        r_attn = r_attn.masked_fill(mask, float('-inf'))
        r_attn = F.softmax(r_attn, dim=-1)
        r_out = r_attn @ v  # [B, H, T, D] — shared V with content

        # === Gate: blend content and RRPRAM ===
        g = torch.sigmoid(self.gate).view(1, H, 1, 1)  # [1, H, 1, 1]
        attn_out = g * c_out + (1 - g) * r_out  # [B, H, T, D]

        # Output projection + residual
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, E)
        x = x + self.wo(attn_out)

        # === SwiGLU MLP ===
        xn = self.norm2(x)
        gate = F.silu(self.mlp_gate(xn))
        up = self.mlp_up(xn)
        x = x + self.mlp_down(gate * up)

        return x


def _apply_rope(x, cos, sin):
    """Apply RoPE to tensor x: [B, H, T, D]."""
    x1 = x[..., ::2]   # even dims
    x2 = x[..., 1::2]  # odd dims
    out = torch.stack([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos,
    ], dim=-1).flatten(-2)
    return out


class Resonance(nn.Module):
    """
    Resonance 200M: dual attention (Content + RRPRAM) transformer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        V = config['vocab_size']
        E = config['n_embd']
        T = config['context_len']
        D = config['head_dim']

        # Token embedding (no position — RoPE handles it)
        self.tok_emb = nn.Embedding(V, E)
        nn.init.normal_(self.tok_emb.weight, std=0.02)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            ResonanceBlock(config) for _ in range(config['n_layer'])
        ])

        # Final norm + output head (untied from embedding)
        self.norm_f = RMSNorm(E)
        self.out_head = nn.Linear(E, V, bias=False)
        nn.init.normal_(self.out_head.weight, std=0.02)

        # Precompute RoPE
        freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))
        t = torch.arange(T).float()
        angles = torch.outer(t, freqs)
        self.register_buffer('rope_cos', angles.cos().unsqueeze(0).unsqueeze(0))  # [1,1,T,D//2]
        self.register_buffer('rope_sin', angles.sin().unsqueeze(0).unsqueeze(0))

        # Causal mask (for RRPRAM — content uses is_causal=True in SDPA)
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        self.register_buffer('causal_mask', mask)

        n_params = sum(p.numel() for p in self.parameters())
        print(f"  [Resonance] {n_params:,} parameters")
        self._report_balance()

    def _report_balance(self):
        """Report parameter budget distribution."""
        cfg = self.config
        E, H, D = cfg['n_embd'], cfg['n_head'], cfg['head_dim']
        R, T, M = cfg['rrpram_rank'], cfg['context_len'], cfg['ffn_dim']
        V, L = cfg['vocab_size'], cfg['n_layer']

        emb = V * E * 2  # tok_emb + out_head (untied)
        qkv = L * (3 * E * H * D)
        rrpram = L * (H * E * R + H * R * T + H)  # wr_a + wr_b + gate
        wo = L * E * E
        mlp = L * (3 * E * M)
        norms = L * 2 * E + E  # per-block norms + final

        total = emb + qkv + rrpram + wo + mlp + norms
        print(f"  [Resonance] Budget: emb={emb/total*100:.1f}% qkv={qkv/total*100:.1f}% "
              f"rrpram={rrpram/total*100:.1f}% wo={wo/total*100:.1f}% "
              f"mlp={mlp/total*100:.1f}% norms={norms/total*100:.1f}%")

    def set_gradient_checkpointing(self, enable=True):
        self._grad_ckpt = enable

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.tok_emb(idx)

        cos = self.rope_cos[:, :, :T, :]
        sin = self.rope_sin[:, :, :T, :]
        mask = self.causal_mask[:T, :T]

        for block in self.blocks:
            if getattr(self, '_grad_ckpt', False) and self.training:
                x = torch.utils.checkpoint.checkpoint(
                    block, x, cos, sin, mask, use_reentrant=False)
            else:
                x = block(x, cos, sin, mask)

        logits = self.out_head(self.norm_f(x))

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss


# === Default config: ~200M params ===
RESONANCE_200M = {
    'n_embd': 768,
    'n_head': 12,
    'head_dim': 64,        # n_embd // n_head
    'n_layer': 20,
    'rrpram_rank': 48,     # low-rank R
    'context_len': 2048,
    'ffn_dim': 2048,       # SwiGLU width: 8*768/3 = 2048 (multiple of 256)
    'vocab_size': 16384,   # 256 + 16128 BPE merges
}
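
# --- Minimal smoke test (illustrative sketch, not part of the reference implementation) ---
# Shrinks the default config so a forward pass runs quickly on CPU and checks output shapes.
# The reduced values (n_layer=2, context_len=128, vocab_size=512) are arbitrary test settings.
if __name__ == "__main__":
    cfg = dict(RESONANCE_200M, n_layer=2, context_len=128, vocab_size=512)
    model = Resonance(cfg)
    idx = torch.randint(0, cfg['vocab_size'], (2, 64))      # [B=2, T=64] random token ids
    targets = torch.randint(0, cfg['vocab_size'], (2, 64))
    logits, loss = model(idx, targets)
    assert logits.shape == (2, 64, cfg['vocab_size'])
    print(f"  [smoke test] logits {tuple(logits.shape)}, loss {loss.item():.3f}")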