TreeLeek committed on
Commit
f6e3ff4
·
verified ·
1 Parent(s): 207e44b

Upload leeknet_500m.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. leeknet_500m.py +264 -0
leeknet_500m.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ leeknet_500m.py — Scaled TCF-1 architecture for ~500M params.
4
+
5
+ Same hybrid attention + Mamba SSM design as the 36M character-level model.
6
+ Differences:
7
+ - BPE tokenizer (vocab 32k) instead of character-level
8
+ - Wider: n_embed 1024 (vs 512)
9
+ - Deeper: 12 hybrid pairs (vs 4)
10
+ - Longer context: block_size 2048 (vs 512)
11
+ - Persistent SSM state still threads through all pairs and across tokens
12
+
13
+ Architecture (per hybrid pair):
14
+ Attention (reasons over context)
15
+ + Mamba SSM (holds and updates persistent state)
16
+ + FeedForward (transforms)
17
+
18
+ Usage:
19
+ python3 leeknet_500m.py info # show parameter count
20
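+ python3 leeknet_500m.py train_a # Stage A pretraining (planned; not wired in yet)
21
+ python3 leeknet_500m.py train_b # Stage B SFT (planned; not wired in yet)
22
+ python3 leeknet_500m.py train_c # Stage C voice imprint (planned; not wired in yet)
23
+ python3 leeknet_500m.py chat # interactive (planned; not wired in yet)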
24
+ """
25
+
26
+ import math
27
+ import json
28
+ import sys
29
+ import time
30
+ from pathlib import Path
31
+
32
+ import mlx.core as mx
33
+ import mlx.nn as nn
34
+ import mlx.optimizers as optim
35
+ import mlx.utils as mlx_utils
36
+ import numpy as np
37
+ import sentencepiece as spm
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Paths
41
+ # ---------------------------------------------------------------------------
42
# All locations are resolved relative to the directory holding this file.
ROOT = Path(__file__).parent
TOKENIZER_DIR = ROOT.joinpath('tokenizer')
DATA_A = ROOT.joinpath('data', 'A_knowledge')
DATA_B = ROOT.joinpath('data', 'B_instruction')
VOICE_DIR = ROOT.joinpath('memory', 'corpus')
CKPT_DIR = ROOT.joinpath('checkpoints_500m')
# NOTE: runs at import time, so importing this module creates the directory.
CKPT_DIR.mkdir(exist_ok=True)

TOKENIZER_MODEL = TOKENIZER_DIR.joinpath('leek_bpe_32k.model')

# ---------------------------------------------------------------------------
# Config — scales from the 36M version
# ---------------------------------------------------------------------------
N_VOCAB = 32_000        # from BPE tokenizer
N_EMBED = 1_024         # was 512
N_HEAD = 16             # was 8
N_PAIRS = 12            # was 4
SSM_D_STATE = 16        # size of the per-channel SSM state vector
SSM_D_CONV = 4          # kernel width of the SSM's 1-D convolution
SSM_EXPAND = 2          # SSM inner width = SSM_EXPAND * N_EMBED
DROPOUT = 0.0           # disabled — relying on data diversity
BLOCK_SIZE = 2_048      # was 512

# Tools (still emitted as text — harness handles execution)
TOOLS = [
    '<none>',
    'query_soul',
    'bash',
    'read_file',
    'write_file',
    'query_memory',
]

# Training defaults — adjust per stage
BATCH_SIZE = 8
LEARN_RATE = 3e-4
WARMUP_STEPS = 500
WEIGHT_DECAY = 0.1
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # SSM block — Mamba-style selective state
76
+ # ---------------------------------------------------------------------------
77
class MambaBlock(nn.Module):
    """Pre-norm, residual selective state-space block with a carried state.

    The block normalises its input, projects it to a signal path and a gate
    path, runs a 1-D convolution and a per-timestep state recurrence over the
    signal path, gates the result and projects it back to ``d_model``.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        """Build the projections, convolution and state parameters.

        Args:
            d_model: width of the block's input and output features.
            d_state: length of the state vector kept per inner channel.
            d_conv: kernel size of the 1-D convolution.
            expand: multiplier giving the inner width ``int(expand * d_model)``.
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)
        inner = self.d_inner

        # One projection produces both the signal half and the gate half.
        self.in_proj = nn.Linear(d_model, inner * 2, bias=False)
        # Padding of d_conv - 1 lets __call__ trim the output back to T steps.
        self.conv1d = nn.Conv1d(
            in_channels=inner,
            out_channels=inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            bias=True,
        )
        # Emits, per position: 1 step-size scalar, d_state B values, d_state C values.
        self.x_proj = nn.Linear(inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, inner, bias=True)
        self.out_proj = nn.Linear(inner, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

        # A is stored as log(1..d_state); __call__ uses -exp(A_log), i.e. -1..-d_state.
        state_index = np.arange(1, d_state + 1, dtype=np.float32)
        self.A_log = mx.array(np.log(state_index))
        self.D = mx.ones(inner)

    def __call__(self, x, h_prev=None):
        """Run the block over a sequence.

        Args:
            x: rank-3 input unpacked as (B, T, d_model).
            h_prev: optional starting state of shape (B, d_inner, d_state);
                a zero state is used when it is ``None``.

        Returns:
            A tuple ``(x + out_proj(y), h)`` where ``h`` is the state after
            the last timestep.
        """
        batch, seq_len, _ = x.shape
        n_state = self.d_state
        width = self.d_inner

        projected = self.in_proj(self.norm(x))
        signal = projected[..., :width]
        z = projected[..., width:]

        # Keep only the first T conv outputs. NOTE(review): this is causal
        # provided Conv1d pads both ends symmetrically — confirm for MLX.
        conv_out = self.conv1d(signal)[:, :seq_len, :]
        # relu(x) * sigmoid(x): equals SiLU for x > 0 and is zero for x <= 0.
        x_act = mx.maximum(conv_out, 0) * mx.sigmoid(conv_out)

        # Input-dependent SSM parameters, split along the last axis.
        ssm_params = self.x_proj(x_act)
        dt_raw = ssm_params[..., :1]
        B_in = ssm_params[..., 1:1 + n_state]
        C_out = ssm_params[..., 1 + n_state:]

        delta = nn.softplus(self.dt_proj(dt_raw))
        A = -mx.exp(self.A_log)

        # serial scan with persistent state
        if h_prev is None:
            h = mx.zeros((batch, width, n_state))
        else:
            h = h_prev

        outputs = []
        for step in range(seq_len):
            dt_t = delta[:, step, :]    # (B, d_inner)
            x_t = x_act[:, step, :]     # (B, d_inner)
            B_t = B_in[:, step, :]      # (B, d_state)
            C_t = C_out[:, step, :]     # (B, d_state)

            # Discretise A and B for this timestep: both (B, d_inner, d_state).
            dt_col = dt_t[:, :, None]
            dA = mx.exp(dt_col * A[None, None, :])
            dB = dt_col * B_t[:, None, :]

            # State update: h_t = dA * h_{t-1} + dB * x_t
            h = dA * h + dB * x_t[:, :, None]

            # Readout: y_t = sum over the state axis of h_t * C_t -> (B, d_inner)
            y_t = (h * C_t[:, None, :]).sum(axis=-1)
            outputs.append(y_t[:, None, :])

        y_out = mx.concatenate(outputs, axis=1)
        y_out = y_out + self.D * x_act      # skip term scaled by D
        y_out = y_out * mx.sigmoid(z)       # gate with the second projection half
        return x + self.out_proj(y_out), h
142
+
143
+ # ---------------------------------------------------------------------------
144
+ # Attention block
145
+ # ---------------------------------------------------------------------------
146
class AttentionBlock(nn.Module):
    """Pre-norm multi-head causal self-attention with a residual connection."""

    def __init__(self, n_embed, n_head, dropout):
        """Create the fused QKV projection, output projection, norm and dropout.

        Args:
            n_embed: feature width; must be divisible by ``n_head``.
            n_head: number of attention heads.
            dropout: dropout probability applied to the projected output.
        """
        super().__init__()
        assert n_embed % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embed // n_head
        self.qkv = nn.Linear(n_embed, 3 * n_embed, bias=False)
        self.proj = nn.Linear(n_embed, n_embed, bias=False)
        self.norm = nn.LayerNorm(n_embed)
        self.drop = nn.Dropout(dropout)

    def __call__(self, x):
        """Attend over a rank-3 input unpacked as (B, T, D); returns the same shape."""
        batch, seq_len, width = x.shape
        heads, head_dim = self.n_head, self.head_dim

        fused = self.qkv(self.norm(x))
        # (B, T, 3*D) -> (3, B, heads, T, head_dim) so q/k/v sit on axis 0.
        fused = fused.reshape(batch, seq_len, 3, heads, head_dim)
        fused = fused.transpose(2, 0, 3, 1, 4)
        q = fused[0]
        k = fused[1]
        v = fused[2]

        scale = math.sqrt(head_dim)
        scores = (q @ k.transpose(0, 1, 3, 2)) / scale

        # True strictly above the diagonal, i.e. at positions later than the query.
        future = mx.tril(mx.ones((seq_len, seq_len))) == 0
        scores = mx.where(future, -1e9, scores)

        weights = mx.softmax(scores, axis=-1)
        # (B, heads, T, head_dim) -> (B, T, heads, head_dim) -> (B, T, D)
        merged = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, width)
        return x + self.drop(self.proj(merged))
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # FeedForward
172
+ # ---------------------------------------------------------------------------
173
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP (4x expansion, GELU) with a residual connection."""

    def __init__(self, n_embed, dropout):
        """Build the MLP stack and its input LayerNorm.

        Args:
            n_embed: input and output feature width.
            dropout: dropout probability applied after the second linear layer.
        """
        super().__init__()
        hidden = 4 * n_embed
        stack = [
            nn.Linear(n_embed, hidden, bias=False),
            nn.GELU(),
            nn.Linear(hidden, n_embed, bias=False),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stack)
        self.norm = nn.LayerNorm(n_embed)

    def __call__(self, x):
        """Return ``x`` plus the MLP applied to the normalised ``x``."""
        normed = self.norm(x)
        return x + self.net(normed)
186
+
187
+ # ---------------------------------------------------------------------------
188
+ # Hybrid pair: Attention + SSM + FFN
189
+ # ---------------------------------------------------------------------------
190
class HybridPair(nn.Module):
    """One layer group: attention, then the SSM block, then the feed-forward."""

    def __init__(self, n_embed, n_head, dropout):
        """Create the three sub-blocks.

        The SSM block is sized from the module-level SSM_D_STATE, SSM_D_CONV
        and SSM_EXPAND constants.
        """
        super().__init__()
        self.attn = AttentionBlock(n_embed, n_head, dropout)
        self.ssm = MambaBlock(n_embed, SSM_D_STATE, SSM_D_CONV, SSM_EXPAND)
        self.ff = FeedForward(n_embed, dropout)

    def __call__(self, x, h=None):
        """Apply the three sub-blocks in order.

        Args:
            x: input activations.
            h: optional SSM state handed to the SSM block.

        Returns:
            ``(activations, state)`` where ``state`` is what the SSM block returned.
        """
        attended = self.attn(x)
        mixed, h_next = self.ssm(attended, h)
        return self.ff(mixed), h_next
202
+
203
+ # ---------------------------------------------------------------------------
204
+ # LeekNet 500M
205
+ # ---------------------------------------------------------------------------
206
class LeekNet500M(nn.Module):
    """Token + learned-position embeddings, a stack of HybridPair layers, LM head."""

    def __init__(self, vocab_size=N_VOCAB, n_embed=N_EMBED, n_head=N_HEAD,
                 n_pairs=N_PAIRS, block_size=BLOCK_SIZE, dropout=DROPOUT):
        """Build the embeddings, the hybrid layers, the final norm and the head.

        Args:
            vocab_size: number of token ids (embedding rows and head outputs).
            n_embed: model width.
            n_head: attention heads per hybrid pair.
            n_pairs: number of HybridPair layers.
            block_size: number of learned positions, i.e. the longest sequence
                ``forward`` accepts.
            dropout: dropout probability used throughout.
        """
        super().__init__()
        self.block_size = block_size
        self.tok_embed = nn.Embedding(vocab_size, n_embed)
        self.pos_embed = nn.Embedding(block_size, n_embed)
        self.drop = nn.Dropout(dropout)
        self.pairs = [HybridPair(n_embed, n_head, dropout) for _ in range(n_pairs)]
        self.ln_final = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size, bias=False)

    def forward(self, idx, states=None):
        """Run one pass through every hybrid pair.

        Args:
            idx: rank-2 array of token ids, unpacked as (B, T).
            states: optional sequence with one SSM state (or ``None``) per
                hybrid pair; defaults to all ``None``.

        Returns:
            ``(x, new_states)``: the final-normalised activations and the list
            of SSM states returned by each pair, in layer order.

        Raises:
            ValueError: if ``T`` exceeds ``block_size`` or ``states`` does not
                hold exactly one entry per hybrid pair.
        """
        _, T = idx.shape
        # Out-of-range positions would index past the position table. MLX does
        # not bounds-check indexing, so fail loudly instead of reading garbage.
        if T > self.block_size:
            raise ValueError(
                f'sequence length {T} exceeds block_size {self.block_size}'
            )

        n_layers = len(self.pairs)
        if states is None:
            states = [None] * n_layers
        else:
            states = list(states)
            # zip() would silently skip layers on a length mismatch.
            if len(states) != n_layers:
                raise ValueError(
                    f'expected {n_layers} SSM states, got {len(states)}'
                )

        pos = mx.arange(T)
        x = self.drop(self.tok_embed(idx) + self.pos_embed(pos))

        new_states = []
        for pair, h in zip(self.pairs, states):
            x, h = pair(x, h)
            new_states.append(h)
        x = self.ln_final(x)
        return x, new_states

    def __call__(self, idx, n_think=1):
        """Return logits for ``idx``.

        Args:
            idx: rank-2 array of token ids (B, T).
            n_think: number of times ``forward`` is run over the same ``idx``;
                each pass starts from the SSM states left by the previous one.

        Returns:
            ``lm_head`` applied to the activations of the last pass.

        Raises:
            ValueError: if ``n_think`` is less than 1 (no pass would run, so
                there would be nothing to project).
        """
        if n_think < 1:
            raise ValueError(f'n_think must be >= 1, got {n_think}')
        states = None
        for _ in range(n_think):
            x, states = self.forward(idx, states)
        return self.lm_head(x)
236
+
237
+ # ---------------------------------------------------------------------------
238
+ # Quick sanity / param count
239
+ # ---------------------------------------------------------------------------
240
def info():
    """Print the model configuration, parameter count and tokenizer summary.

    Instantiates LeekNet500M with its defaults, sums the sizes of all of its
    parameter arrays, prints the configuration constants, then loads the
    SentencePiece model at TOKENIZER_MODEL and prints its name and vocab size.
    """
    net = LeekNet500M()
    n_params = 0
    for _name, array in mlx_utils.tree_flatten(net.parameters()):
        n_params += array.size

    # Model lines are printed before the tokenizer is loaded, so they still
    # appear if the tokenizer load fails.
    model_lines = (
        '\nLeekNet 500M:',
        f' vocab: {N_VOCAB:,}',
        f' n_embed: {N_EMBED}',
        f' n_pairs: {N_PAIRS}',
        f' n_head: {N_HEAD}',
        f' block_size: {BLOCK_SIZE}',
        f' parameters: {n_params/1e6:.1f}M',
    )
    for line in model_lines:
        print(line)

    tok = spm.SentencePieceProcessor(model_file=str(TOKENIZER_MODEL))
    print(f' tokenizer: {TOKENIZER_MODEL.name}')
    print(f' vocab_size: {tok.vocab_size()}')
254
+
255
+ # ---------------------------------------------------------------------------
256
+ # Entry
257
+ # ---------------------------------------------------------------------------
258
if __name__ == '__main__':
    # The first CLI argument selects the command; with none given, show info.
    cli_args = sys.argv[1:]
    command = cli_args[0] if cli_args else 'info'
    if command != 'info':
        # Every other command currently only prints this notice.
        print('training entry points (train_a/b/c) will be wired in next.')
        print('usage: python3 leeknet_500m.py info')
    else:
        info()