kgrabko commited on
Commit
e9bd681
·
verified ·
1 Parent(s): 6dd0c2d

Update JiRackPyTorch_GPT5_class_1b.py

Browse files
Files changed (1) hide show
  1. JiRackPyTorch_GPT5_class_1b.py +475 -481
JiRackPyTorch_GPT5_class_1b.py CHANGED
@@ -1,482 +1,476 @@
1
- # Copyright (c) 2025 CMS Manhattan
2
- # All rights reserved.
3
- # Author: Konstantin Vladimirovich Grabko
4
- # Email: grabko@cmsmanhattan.com
5
- # Phone: +1(516)777-0945
6
- #
7
- # MIT License
8
- #
9
- # Copyright (c) 2025 Konstantin Grabko
10
- #
11
- # Permission is hereby granted, free of charge, to any person obtaining a copy
12
- # of this software and associated documentation files (the "Software"), to deal
13
- # in the Software without restriction, including without limitation the rights
14
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
- # copies of the Software, and to permit persons to whom the Software is
16
- # furnished to do so, subject to the following conditions:
17
- #
18
- # The above copyright notice and this permission notice shall be included in all
19
- # copies or substantial portions of the Software.
20
- #
21
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
- # SOFTWARE.
28
-
29
- """
30
- JiRackPyTorch 1B Model Definition
31
- Complete and final version with SWA, RoPE Scaling, and full generative sampling.
32
- FIXED: Test harness unpacking bug resolved.
33
- """
34
-
35
- import os
36
- import torch
37
- import torch.nn as nn
38
- import torch.nn.functional as F
39
- from typing import Optional, List, Tuple
40
- from pathlib import Path
41
- import math
42
- import torch.utils.checkpoint
43
-
44
# ========================================
# Model Configuration (Llama-Style 1B)
# ~0.94 B params
# ========================================
VOCAB_SIZE = 50257      # 50257 matches the GPT-2 BPE tokenizer (eos defaults to 50256 in generate)
MODEL_DIM = 2048        # hidden size
NUM_HEADS = 32          # attention heads (HEAD_DIM = 64)
NUM_LAYERS = 16         # transformer blocks
MAX_SEQ_LEN = 2048 # Training length
FFN_HIDDEN_DIM = MODEL_DIM * 4   # SwiGLU inner width
HEAD_DIM = MODEL_DIM // NUM_HEADS
EPSILON = 1e-6          # RMSNorm numerical-stability floor
DROPOUT_RATE = 0.1

# --- Sliding Window Attention Parameter ---
WINDOW_SIZE = 512 # The size of the local attention window and maximum KV cache size
# ---------------------------------------------
61
-
62
- # --- 1. RMSNorm ---
63
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = EPSILON):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Scale by the reciprocal RMS of the last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize first, then apply the learned per-channel gain.
        return self.weight * self._norm(x)
74
-
75
- # --- 2. Rotary Positional Embedding (RoPE) with Context Scaling ---
76
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0, max_seq_len: int = MAX_SEQ_LEN):
    """Build the (seq_len, dim // 2) complex RoPE rotation table.

    When ``seq_len`` exceeds ``max_seq_len``, positions are compressed
    linearly (position interpolation) so angles stay inside the trained range.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    positions = torch.arange(seq_len, dtype=torch.float32)
    if seq_len > max_seq_len:
        # Linear position interpolation for context extension.
        positions = positions / (seq_len / max_seq_len)

    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers e^{i*angle}.
    return torch.polar(torch.ones_like(angles), angles)
88
-
89
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Add broadcast axes around the (seq, dim/2) RoPE table.

    NOTE(review): ``x`` is accepted for API symmetry but never inspected;
    the result is always a (1, 1, seq, 1, dim/2) view of ``freqs_cis``.
    """
    return freqs_cis.unsqueeze(-2).unsqueeze(0).unsqueeze(0)
91
-
92
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate query/key head vectors by the complex RoPE table.

    xq, xk: (batch, heads, seq, head_dim) with an even head_dim.
    freqs_cis: (seq, head_dim // 2) complex rotations.
    Returns the rotated pair cast back to ``xq``'s dtype.

    Simplification: the old code routed the table through
    ``reshape_for_broadcast`` (which inserted an axis) and then immediately
    ``squeeze``d that axis away; the broadcast is now done directly.
    """
    in_dtype = xq.dtype
    # Broadcast shape over batch and heads: (1, 1, seq, head_dim // 2).
    rotations = freqs_cis[None, None, :, :]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Pair up the last dimension; each (even, odd) pair is one complex number.
        as_complex = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        rotated = as_complex * rotations
        # Back to interleaved real pairs: (..., seq, head_dim).
        return torch.view_as_real(rotated).flatten(3)

    return _rotate(xq).type(in_dtype), _rotate(xk).type(in_dtype)
108
-
109
- # --- 3. MultiHeadAttention (SWA/SAPA Enabled, Cache Truncation Fixed) ---
110
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with RoPE and a sliding-window KV cache.

    The cache handed back to the caller is always clamped to at most
    ``window_size`` positions, so decode-time memory stays constant.
    """

    def __init__(self, window_size: int = WINDOW_SIZE):
        super().__init__()
        self.q_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.k_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.v_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.out_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.scale = HEAD_DIM ** -0.5
        self.window_size = window_size
        self._build_rope_buffers(MAX_SEQ_LEN)

    def _build_rope_buffers(self, max_context_len: int):
        # Non-persistent: the table is cheap to rebuild and device-dependent.
        self.register_buffer("freqs_cis", precompute_freqs_cis(HEAD_DIM, max_context_len), persistent=False)

    def forward(self, x: torch.Tensor, pos_offset: int, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        batch, seq, dim = x.shape
        device = x.device
        in_dtype = x.dtype

        # --- Grow the RoPE table on demand when generation exceeds its length ---
        needed = seq + pos_offset
        if needed > self.freqs_cis.size(0):
            self.freqs_cis = precompute_freqs_cis(HEAD_DIM, needed).to(device)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(batch, seq, NUM_HEADS, HEAD_DIM).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        cur_k = split_heads(self.k_proj(x))
        cur_v = split_heads(self.v_proj(x))

        # Rotate q/k by the table slice covering these absolute positions.
        rope_slice = self.freqs_cis[pos_offset : pos_offset + seq].to(device)
        q, k = apply_rotary_emb(q, cur_k, rope_slice)
        v = cur_v

        if past_kv is None or past_kv[0] is None:
            # Prefill: keep at most the last `window_size` positions for both
            # the attended keys and the returned cache.
            if seq > self.window_size:
                k = k[:, :, -self.window_size:]
                v = v[:, :, -self.window_size:]
            new_kv = (k, v)
        else:
            past_k, past_v = past_kv
            cache_len = past_k.size(2)

            # Keep only enough history so history + new tokens <= window.
            keep_from = max(0, cache_len - (self.window_size - seq))
            k = torch.cat([past_k[:, :, keep_from:, :], k], dim=2)
            v = torch.cat([past_v[:, :, keep_from:, :], v], dim=2)

            # NOTE(review): this branch appends the PRE-RoPE current_k to the
            # cache, while the prefill branch cached POST-RoPE keys.
            # Preserved as-is — confirm which convention is intended.
            all_k = torch.cat([past_k, cur_k], dim=2)
            all_v = torch.cat([past_v, cur_v], dim=2)
            new_kv = (all_k[:, :, -self.window_size:], all_v[:, :, -self.window_size:])

        kv_len = k.size(2)

        # Score in fp32 for numerical stability regardless of model dtype.
        scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * self.scale

        # Causal mask: query i may attend keys up to (kv_len - seq) + i.
        visible_past = kv_len - seq
        causal = torch.full((seq, kv_len), float('-inf'), device=device, dtype=torch.float32)
        causal = torch.triu(causal, diagonal=visible_past + 1)[None, None]

        probs = F.softmax(scores + causal, dim=-1)

        ctx = torch.matmul(probs, v.float())
        ctx = ctx.transpose(1, 2).contiguous().view(batch, seq, dim)
        out = self.out_proj(ctx)

        return out.type(in_dtype), new_kv
196
-
197
- # --- 4. SwiGLU Feed-Forward ---
198
class SwiGLUFeedForward(nn.Module):
    """Position-wise SwiGLU MLP: w2(silu(w1(x)) * w3(x)), then dropout."""

    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(MODEL_DIM, FFN_HIDDEN_DIM, bias=False)   # gate projection
        self.w3 = nn.Linear(MODEL_DIM, FFN_HIDDEN_DIM, bias=False)   # up projection
        self.w2 = nn.Linear(FFN_HIDDEN_DIM, MODEL_DIM, bias=False)   # down projection
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x):
        # SiLU-gated product of the two parallel projections.
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))
213
-
214
- # --- 5. Transformer Block ---
215
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: RMSNorm → attention → RMSNorm → SwiGLU."""

    def __init__(self):
        super().__init__()
        self.attn = MultiHeadAttention(window_size=WINDOW_SIZE)
        self.ffn = SwiGLUFeedForward()
        self.norm1 = RMSNorm(MODEL_DIM)
        self.norm2 = RMSNorm(MODEL_DIM)
        self.attn_dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x, pos_offset: int, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        # Checkpointing applies only while training, with the flag enabled on
        # the owning model, and with no KV cache in play. `model` is a
        # back-reference installed by JiRackPyTorch (not a registered submodule).
        use_ckpt = (
            self.training
            and getattr(self, 'model', None)
            and self.model.gradient_checkpointing
            and past_kv is None
        )

        if use_ckpt:
            def run(h):
                attn_out, _ = self.attn(self.norm1(h), pos_offset, None)
                h = h + self.attn_dropout(attn_out)
                return h + self.ffn(self.norm2(h))

            x = torch.utils.checkpoint.checkpoint(
                run, x, use_reentrant=False, preserve_rng_state=True
            )
            return x, None

        # Standard pre-norm residual path (training without checkpointing,
        # or inference with an optional KV cache).
        attn_out, new_kv = self.attn(self.norm1(x), pos_offset, past_kv)
        x = x + self.attn_dropout(attn_out)
        x = x + self.ffn(self.norm2(x))
        return x, new_kv
261
-
262
- # --- 6. Main Model (JiRackPyTorch) - FINAL VERSION ---
263
class JiRackPyTorch(nn.Module):
    """Llama-style decoder-only language model (~0.94B params) with SWA.

    forward() returns bare logits when ``past_kv`` is None (training path),
    or a (logits, new_kv_cache) tuple when a cache list is supplied.
    generate() runs cached autoregressive sampling.
    """

    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(NUM_LAYERS)])
        self.ln_f = RMSNorm(MODEL_DIM)
        self.lm_head = nn.Linear(MODEL_DIM, VOCAB_SIZE, bias=False)
        self.emb_dropout = nn.Dropout(DROPOUT_RATE)

        self.apply(self._init_weights)
        # Weight tying: the head shares the embedding matrix, so any
        # lm_head-specific init from _init_weights is overwritten here.
        self.lm_head.weight = self.token_emb.weight

        self.gradient_checkpointing = False

        signature = "Konstantin V Grabko . original author 2025"
        self.register_buffer("proof_of_authorship_cmsmanhattan", torch.tensor([ord(c) for c in signature], dtype=torch.uint8), persistent=False)
        self.register_buffer("birth_date", torch.tensor([20251127], dtype=torch.int64), persistent=False)

        # Back-reference used by TransformerBlock for checkpointing;
        # object.__setattr__ avoids registering the parent as a submodule.
        for block in self.blocks:
            object.__setattr__(block, 'model', self)

    def _init_weights(self, module):
        """GPT-2-style init, scaled down by depth for residual stability."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * NUM_LAYERS))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)

        # NOTE: superseded by weight tying performed in __init__ after apply().
        if isinstance(module, nn.Linear) and hasattr(self, 'lm_head') and module is self.lm_head:
            nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def gradient_checkpointing_enable(self):
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.gradient_checkpointing = False

    def forward(self, input_ids: torch.Tensor, past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None):
        """Run the transformer stack over ``input_ids`` (batch, seq)."""
        x = self.emb_dropout(self.token_emb(input_ids))

        # Position offset = number of tokens already in the first layer's cache.
        pos_offset = 0
        if past_kv is not None and past_kv[0] is not None and past_kv[0][0] is not None:
            pos_offset = past_kv[0][0].size(2)

        new_kv_cache = [] if past_kv is not None else None

        for i, block in enumerate(self.blocks):
            layer_past = past_kv[i] if past_kv and i < len(past_kv) else None
            x, layer_kv = block(x, pos_offset, layer_past)
            if new_kv_cache is not None and layer_kv is not None:
                new_kv_cache.append(layer_kv)

        logits = self.lm_head(self.ln_f(x))

        return logits if past_kv is None else (logits, new_kv_cache)

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 0.8, top_p: float = 0.95, repetition_penalty: float = 1.0, do_sample: bool = True, eos_token_id: int = 50256) -> torch.Tensor:
        """Autoregressive sampling with a KV cache (assumes batch size 1).

        Supports greedy decoding (do_sample=False or temperature == 0),
        temperature + nucleus (top-p) sampling, and repetition penalty.
        Returns the prompt plus generated ids as a 1-D tensor.
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Prefill: run the whole prompt once to build the cache.
        past_kv = [None] * NUM_LAYERS
        forward_output = self(input_ids, past_kv=past_kv)

        if isinstance(forward_output, tuple):
            if len(forward_output) != 2:
                raise ValueError(f"CRITICAL ERROR: forward returned {len(forward_output)} outputs in prefill.")
            logits, past_kv = forward_output
        else:
            logits = forward_output
            past_kv = [None] * NUM_LAYERS

        last_logits = logits[:, -1, :]
        output_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            if repetition_penalty != 1.0:
                # BUGFIX: the previous check ran .count(tid) on the NESTED list
                # produced by a 2-D tensor's tolist(), which is always 0, so
                # the penalty never fired. Penalize every id already emitted.
                for tid in output_ids.unique().tolist():
                    score = last_logits[:, tid]
                    last_logits[:, tid] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)

            if temperature == 0.0 or not do_sample:
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(last_logits.float() / temperature, dim=-1)

                if top_p < 1.0:
                    # Nucleus sampling: keep the smallest prefix of sorted
                    # tokens whose cumulative mass reaches top_p.
                    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    mask = cumulative_probs > top_p
                    mask[..., 1:] = mask[..., :-1].clone()  # shift so the boundary token survives
                    mask[..., 0] = False                    # always keep the top token
                    sorted_probs[mask] = 0.0
                    sorted_probs = sorted_probs / (sorted_probs.sum(dim=-1, keepdim=True) + 1e-9)
                    next_token_index = torch.multinomial(sorted_probs, num_samples=1)
                    next_token = torch.gather(sorted_indices, -1, next_token_index)
                else:
                    next_token = torch.multinomial(probs, num_samples=1)

            if next_token.item() == eos_token_id:
                break

            output_ids = torch.cat([output_ids, next_token], dim=-1)

            # Decode step: feed only the new token; the cache carries history.
            forward_output = self(next_token, past_kv=past_kv)

            if isinstance(forward_output, tuple):
                if len(forward_output) != 2:
                    raise ValueError(f"CRITICAL ERROR: forward returned {len(forward_output)} outputs in decode loop.")
                logits_out, past_kv = forward_output
            else:
                logits_out = forward_output

            last_logits = logits_out[:, -1, :]

        return output_ids.squeeze(0)
393
-
394
-
395
- # === EXPORT SCRIPT (Testing SWA and Generation functionality) ===
396
# === EXPORT SCRIPT (Testing SWA and Generation functionality) ===
if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Creating 0.94B-parameter Llama-style model with SWA on {device}...")

    model = JiRackPyTorch().to(device)
    model.eval()

    # Make sure every attention layer's RoPE table lives on the target device.
    for mod in model.modules():
        if isinstance(mod, MultiHeadAttention) and mod.freqs_cis.device != device:
            mod._build_rope_buffers(MAX_SEQ_LEN)
            mod.freqs_cis = mod.freqs_cis.to(device)

    param_count_b = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"Model ready. Parameters: {param_count_b:.2f}B. SWA Window Size: {WINDOW_SIZE}")

    # --- SWA TEST ---
    print("\n--- Testing SWA/KV Cache Truncation (Inference) ---")
    large_input = torch.randint(0, VOCAB_SIZE, (1, WINDOW_SIZE * 2), device=device)

    with torch.no_grad():
        result = model(large_input, past_kv=[None] * NUM_LAYERS)
    if isinstance(result, tuple):
        logits_out, kv_cache = result
    else:
        logits_out, kv_cache = result, [None] * NUM_LAYERS

    prefill_cache_len = kv_cache[0][0].size(2) if kv_cache and kv_cache[0] is not None else 0

    print(f"Initial Prefill Length: {large_input.size(1)}. Cache Size after Prefill: {prefill_cache_len}")
    if prefill_cache_len == WINDOW_SIZE:
        print("✅ Cache Truncation (SWA) successful.")
    else:
        print(f"❌ Cache Truncation (SWA) failed. Expected {WINDOW_SIZE}, got {prefill_cache_len}")

    # One decode step: the cache must stay clamped at WINDOW_SIZE.
    single_token = torch.randint(0, VOCAB_SIZE, (1, 1), device=device)
    with torch.no_grad():
        result = model(single_token, past_kv=kv_cache)
    if isinstance(result, tuple):
        logits_out, final_kv_cache = result
    else:
        logits_out, final_kv_cache = result, kv_cache

    decode_cache_len = final_kv_cache[0][0].size(2) if final_kv_cache and final_kv_cache[0] is not None else 0
    print(f"Cache Size after 1 token generation: {decode_cache_len}")
    if decode_cache_len == WINDOW_SIZE:
        print("✅ SWA cache size remains fixed during generation.")
    else:
        print(f"❌ SWA cache size changed. Expected {WINDOW_SIZE}, got {decode_cache_len}")

    # --- GENERATE TEST ---
    print("\n--- Testing Generation Loop ---")

    prompt = torch.randint(0, VOCAB_SIZE, (1, 10), device=device)
    max_tokens_to_generate = 20

    try:
        generated_ids = model.generate(
            prompt,
            max_new_tokens=max_tokens_to_generate,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            eos_token_id=-1,
        )

        new_token_count = generated_ids.size(0) - prompt.size(1)
        print(f"Prompt length: {prompt.size(1)}")
        print(f"Generated new tokens: {new_token_count}")

        if new_token_count == max_tokens_to_generate:
            print("✅ Generation output length is correct.")
        else:
            print(f"⚠️ Generation stopped early (should not happen with eos_token_id=-1), got {new_token_count} new tokens.")

        print("✅ Generation Test Succeeded (no errors)!")

    except Exception as e:
        print(f"❌ Generation Test Failed: {e}")

    # Persist the weights for downstream use.
    state_dict_path = Path("models/jirack_swa_1b_class.state_dict.pt")
    state_dict_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), state_dict_path)
    print(f"\nFinal state_dict saved to → {state_dict_path}")
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ # Author: Konstantin Vladimirovich Grabko
4
+ # Email: grabko@cmsmanhattan.com
5
+ # Phone: +1(516)777-0945
6
+ #
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU General Public License as published by
9
+ # the Free Software Foundation, version 3 of the License.
10
+ #
11
+ # This program is distributed in the hope that it will be useful,
12
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ # GNU General Public License for more details.
15
+ #
16
+ # You should have received a copy of the GNU General Public License
17
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
18
+ #
19
+ # Additional terms:
20
+ # Any commercial use or distribution of this software or derivative works
21
+ # requires explicit written permission from the copyright holder.
22
+
23
+ """
24
+ JiRackPyTorch 1B Model Definition
25
+ Complete and final version with SWA, RoPE Scaling, and full generative sampling.
26
+ FIXED: Test harness unpacking bug resolved.
27
+ """
28
+
29
+ import os
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+ from typing import Optional, List, Tuple
34
+ from pathlib import Path
35
+ import math
36
+ import torch.utils.checkpoint
37
+
38
# ========================================
# Model Configuration (Llama-Style 1B)
# ~0.94 B params
# ========================================
VOCAB_SIZE = 50257      # 50257 matches the GPT-2 BPE tokenizer (eos defaults to 50256 in generate)
MODEL_DIM = 2048        # hidden size
NUM_HEADS = 32          # attention heads (HEAD_DIM = 64)
NUM_LAYERS = 16         # transformer blocks
MAX_SEQ_LEN = 2048 # Training length
FFN_HIDDEN_DIM = MODEL_DIM * 4   # SwiGLU inner width
HEAD_DIM = MODEL_DIM // NUM_HEADS
EPSILON = 1e-6          # RMSNorm numerical-stability floor
DROPOUT_RATE = 0.1

# --- Sliding Window Attention Parameter ---
WINDOW_SIZE = 512 # The size of the local attention window and maximum KV cache size
# ---------------------------------------------
55
+
56
+ # --- 1. RMSNorm ---
57
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = EPSILON):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Scale by the reciprocal RMS of the last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize first, then apply the learned per-channel gain.
        return self.weight * self._norm(x)
68
+
69
+ # --- 2. Rotary Positional Embedding (RoPE) with Context Scaling ---
70
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0, max_seq_len: int = MAX_SEQ_LEN):
    """Build the (seq_len, dim // 2) complex RoPE rotation table.

    When ``seq_len`` exceeds ``max_seq_len``, positions are compressed
    linearly (position interpolation) so angles stay inside the trained range.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    positions = torch.arange(seq_len, dtype=torch.float32)
    if seq_len > max_seq_len:
        # Linear position interpolation for context extension.
        positions = positions / (seq_len / max_seq_len)

    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers e^{i*angle}.
    return torch.polar(torch.ones_like(angles), angles)
82
+
83
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Add broadcast axes around the (seq, dim/2) RoPE table.

    NOTE(review): ``x`` is accepted for API symmetry but never inspected;
    the result is always a (1, 1, seq, 1, dim/2) view of ``freqs_cis``.
    """
    return freqs_cis.unsqueeze(-2).unsqueeze(0).unsqueeze(0)
85
+
86
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate query/key head vectors by the complex RoPE table.

    xq, xk: (batch, heads, seq, head_dim) with an even head_dim.
    freqs_cis: (seq, head_dim // 2) complex rotations.
    Returns the rotated pair cast back to ``xq``'s dtype.

    Simplification: the old code routed the table through
    ``reshape_for_broadcast`` (which inserted an axis) and then immediately
    ``squeeze``d that axis away; the broadcast is now done directly.
    """
    in_dtype = xq.dtype
    # Broadcast shape over batch and heads: (1, 1, seq, head_dim // 2).
    rotations = freqs_cis[None, None, :, :]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Pair up the last dimension; each (even, odd) pair is one complex number.
        as_complex = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        rotated = as_complex * rotations
        # Back to interleaved real pairs: (..., seq, head_dim).
        return torch.view_as_real(rotated).flatten(3)

    return _rotate(xq).type(in_dtype), _rotate(xk).type(in_dtype)
102
+
103
+ # --- 3. MultiHeadAttention (SWA/SAPA Enabled, Cache Truncation Fixed) ---
104
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with RoPE and a sliding-window KV cache.

    The cache handed back to the caller is always clamped to at most
    ``window_size`` positions, so decode-time memory stays constant.
    """

    def __init__(self, window_size: int = WINDOW_SIZE):
        super().__init__()
        self.q_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.k_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.v_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.out_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.scale = HEAD_DIM ** -0.5
        self.window_size = window_size
        self._build_rope_buffers(MAX_SEQ_LEN)

    def _build_rope_buffers(self, max_context_len: int):
        # Non-persistent: the table is cheap to rebuild and device-dependent.
        self.register_buffer("freqs_cis", precompute_freqs_cis(HEAD_DIM, max_context_len), persistent=False)

    def forward(self, x: torch.Tensor, pos_offset: int, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        batch, seq, dim = x.shape
        device = x.device
        in_dtype = x.dtype

        # --- Grow the RoPE table on demand when generation exceeds its length ---
        needed = seq + pos_offset
        if needed > self.freqs_cis.size(0):
            self.freqs_cis = precompute_freqs_cis(HEAD_DIM, needed).to(device)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(batch, seq, NUM_HEADS, HEAD_DIM).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        cur_k = split_heads(self.k_proj(x))
        cur_v = split_heads(self.v_proj(x))

        # Rotate q/k by the table slice covering these absolute positions.
        rope_slice = self.freqs_cis[pos_offset : pos_offset + seq].to(device)
        q, k = apply_rotary_emb(q, cur_k, rope_slice)
        v = cur_v

        if past_kv is None or past_kv[0] is None:
            # Prefill: keep at most the last `window_size` positions for both
            # the attended keys and the returned cache.
            if seq > self.window_size:
                k = k[:, :, -self.window_size:]
                v = v[:, :, -self.window_size:]
            new_kv = (k, v)
        else:
            past_k, past_v = past_kv
            cache_len = past_k.size(2)

            # Keep only enough history so history + new tokens <= window.
            keep_from = max(0, cache_len - (self.window_size - seq))
            k = torch.cat([past_k[:, :, keep_from:, :], k], dim=2)
            v = torch.cat([past_v[:, :, keep_from:, :], v], dim=2)

            # NOTE(review): this branch appends the PRE-RoPE current_k to the
            # cache, while the prefill branch cached POST-RoPE keys.
            # Preserved as-is — confirm which convention is intended.
            all_k = torch.cat([past_k, cur_k], dim=2)
            all_v = torch.cat([past_v, cur_v], dim=2)
            new_kv = (all_k[:, :, -self.window_size:], all_v[:, :, -self.window_size:])

        kv_len = k.size(2)

        # Score in fp32 for numerical stability regardless of model dtype.
        scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * self.scale

        # Causal mask: query i may attend keys up to (kv_len - seq) + i.
        visible_past = kv_len - seq
        causal = torch.full((seq, kv_len), float('-inf'), device=device, dtype=torch.float32)
        causal = torch.triu(causal, diagonal=visible_past + 1)[None, None]

        probs = F.softmax(scores + causal, dim=-1)

        ctx = torch.matmul(probs, v.float())
        ctx = ctx.transpose(1, 2).contiguous().view(batch, seq, dim)
        out = self.out_proj(ctx)

        return out.type(in_dtype), new_kv
190
+
191
+ # --- 4. SwiGLU Feed-Forward ---
192
class SwiGLUFeedForward(nn.Module):
    """Position-wise SwiGLU MLP: w2(silu(w1(x)) * w3(x)), then dropout."""

    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(MODEL_DIM, FFN_HIDDEN_DIM, bias=False)   # gate projection
        self.w3 = nn.Linear(MODEL_DIM, FFN_HIDDEN_DIM, bias=False)   # up projection
        self.w2 = nn.Linear(FFN_HIDDEN_DIM, MODEL_DIM, bias=False)   # down projection
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x):
        # SiLU-gated product of the two parallel projections.
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))
207
+
208
+ # --- 5. Transformer Block ---
209
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: RMSNorm → attention → RMSNorm → SwiGLU."""

    def __init__(self):
        super().__init__()
        self.attn = MultiHeadAttention(window_size=WINDOW_SIZE)
        self.ffn = SwiGLUFeedForward()
        self.norm1 = RMSNorm(MODEL_DIM)
        self.norm2 = RMSNorm(MODEL_DIM)
        self.attn_dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x, pos_offset: int, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        # Checkpointing applies only while training, with the flag enabled on
        # the owning model, and with no KV cache in play. `model` is a
        # back-reference installed by JiRackPyTorch (not a registered submodule).
        use_ckpt = (
            self.training
            and getattr(self, 'model', None)
            and self.model.gradient_checkpointing
            and past_kv is None
        )

        if use_ckpt:
            def run(h):
                attn_out, _ = self.attn(self.norm1(h), pos_offset, None)
                h = h + self.attn_dropout(attn_out)
                return h + self.ffn(self.norm2(h))

            x = torch.utils.checkpoint.checkpoint(
                run, x, use_reentrant=False, preserve_rng_state=True
            )
            return x, None

        # Standard pre-norm residual path (training without checkpointing,
        # or inference with an optional KV cache).
        attn_out, new_kv = self.attn(self.norm1(x), pos_offset, past_kv)
        x = x + self.attn_dropout(attn_out)
        x = x + self.ffn(self.norm2(x))
        return x, new_kv
255
+
256
+ # --- 6. Main Model (JiRackPyTorch) - FINAL VERSION ---
257
class JiRackPyTorch(nn.Module):
    """Llama-style decoder-only language model (~0.94B params) with SWA.

    forward() returns bare logits when ``past_kv`` is None (training path),
    or a (logits, new_kv_cache) tuple when a cache list is supplied.
    generate() runs cached autoregressive sampling.
    """

    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(NUM_LAYERS)])
        self.ln_f = RMSNorm(MODEL_DIM)
        self.lm_head = nn.Linear(MODEL_DIM, VOCAB_SIZE, bias=False)
        self.emb_dropout = nn.Dropout(DROPOUT_RATE)

        self.apply(self._init_weights)
        # Weight tying: the head shares the embedding matrix, so any
        # lm_head-specific init from _init_weights is overwritten here.
        self.lm_head.weight = self.token_emb.weight

        self.gradient_checkpointing = False

        signature = "Konstantin V Grabko . original author 2025"
        self.register_buffer("proof_of_authorship_cmsmanhattan", torch.tensor([ord(c) for c in signature], dtype=torch.uint8), persistent=False)
        self.register_buffer("birth_date", torch.tensor([20251127], dtype=torch.int64), persistent=False)

        # Back-reference used by TransformerBlock for checkpointing;
        # object.__setattr__ avoids registering the parent as a submodule.
        for block in self.blocks:
            object.__setattr__(block, 'model', self)

    def _init_weights(self, module):
        """GPT-2-style init, scaled down by depth for residual stability."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * NUM_LAYERS))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)

        # NOTE: superseded by weight tying performed in __init__ after apply().
        if isinstance(module, nn.Linear) and hasattr(self, 'lm_head') and module is self.lm_head:
            nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def gradient_checkpointing_enable(self):
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.gradient_checkpointing = False

    def forward(self, input_ids: torch.Tensor, past_kv: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None):
        """Run the transformer stack over ``input_ids`` (batch, seq)."""
        x = self.emb_dropout(self.token_emb(input_ids))

        # Position offset = number of tokens already in the first layer's cache.
        pos_offset = 0
        if past_kv is not None and past_kv[0] is not None and past_kv[0][0] is not None:
            pos_offset = past_kv[0][0].size(2)

        new_kv_cache = [] if past_kv is not None else None

        for i, block in enumerate(self.blocks):
            layer_past = past_kv[i] if past_kv and i < len(past_kv) else None
            x, layer_kv = block(x, pos_offset, layer_past)
            if new_kv_cache is not None and layer_kv is not None:
                new_kv_cache.append(layer_kv)

        logits = self.lm_head(self.ln_f(x))

        return logits if past_kv is None else (logits, new_kv_cache)

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 0.8, top_p: float = 0.95, repetition_penalty: float = 1.0, do_sample: bool = True, eos_token_id: int = 50256) -> torch.Tensor:
        """Autoregressive sampling with a KV cache (assumes batch size 1).

        Supports greedy decoding (do_sample=False or temperature == 0),
        temperature + nucleus (top-p) sampling, and repetition penalty.
        Returns the prompt plus generated ids as a 1-D tensor.
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Prefill: run the whole prompt once to build the cache.
        past_kv = [None] * NUM_LAYERS
        forward_output = self(input_ids, past_kv=past_kv)

        if isinstance(forward_output, tuple):
            if len(forward_output) != 2:
                raise ValueError(f"CRITICAL ERROR: forward returned {len(forward_output)} outputs in prefill.")
            logits, past_kv = forward_output
        else:
            logits = forward_output
            past_kv = [None] * NUM_LAYERS

        last_logits = logits[:, -1, :]
        output_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            if repetition_penalty != 1.0:
                # BUGFIX: the previous check ran .count(tid) on the NESTED list
                # produced by a 2-D tensor's tolist(), which is always 0, so
                # the penalty never fired. Penalize every id already emitted.
                for tid in output_ids.unique().tolist():
                    score = last_logits[:, tid]
                    last_logits[:, tid] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)

            if temperature == 0.0 or not do_sample:
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(last_logits.float() / temperature, dim=-1)

                if top_p < 1.0:
                    # Nucleus sampling: keep the smallest prefix of sorted
                    # tokens whose cumulative mass reaches top_p.
                    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    mask = cumulative_probs > top_p
                    mask[..., 1:] = mask[..., :-1].clone()  # shift so the boundary token survives
                    mask[..., 0] = False                    # always keep the top token
                    sorted_probs[mask] = 0.0
                    sorted_probs = sorted_probs / (sorted_probs.sum(dim=-1, keepdim=True) + 1e-9)
                    next_token_index = torch.multinomial(sorted_probs, num_samples=1)
                    next_token = torch.gather(sorted_indices, -1, next_token_index)
                else:
                    next_token = torch.multinomial(probs, num_samples=1)

            if next_token.item() == eos_token_id:
                break

            output_ids = torch.cat([output_ids, next_token], dim=-1)

            # Decode step: feed only the new token; the cache carries history.
            forward_output = self(next_token, past_kv=past_kv)

            if isinstance(forward_output, tuple):
                if len(forward_output) != 2:
                    raise ValueError(f"CRITICAL ERROR: forward returned {len(forward_output)} outputs in decode loop.")
                logits_out, past_kv = forward_output
            else:
                logits_out = forward_output

            last_logits = logits_out[:, -1, :]

        return output_ids.squeeze(0)
387
+
388
+
389
# === EXPORT SCRIPT (Testing SWA and Generation functionality) ===
if __name__ == "__main__":
    # FIX: Path is used below but no pathlib import is visible at the top of
    # the file; a local import here is harmless if one already exists.
    from pathlib import Path

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Creating 0.94B-parameter Llama-style model with SWA on {device}...")

    model = JiRackPyTorch().to(device)
    model.eval()

    # Ensure RoPE freqs are on the target device (presumably freqs_cis is not
    # a registered buffer, so model.to(device) does not move it — TODO confirm
    # against MultiHeadAttention's definition).
    for name, module in model.named_modules():
        if isinstance(module, MultiHeadAttention):
            if module.freqs_cis.device != device:
                module._build_rope_buffers(MAX_SEQ_LEN)
                module.freqs_cis = module.freqs_cis.to(device)

    total_params = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"Model ready. Parameters: {total_params:.2f}B. SWA Window Size: {WINDOW_SIZE}")

    # --- SWA TEST: prefill 2x the window; cache must truncate to WINDOW_SIZE ---
    print("\n--- Testing SWA/KV Cache Truncation (Inference) ---")
    large_input = torch.randint(0, VOCAB_SIZE, (1, WINDOW_SIZE * 2), device=device)

    with torch.no_grad():
        output = model(large_input, past_kv=[None] * NUM_LAYERS)
        if isinstance(output, tuple):
            logits_out, kv_cache = output
        else:
            logits_out = output
            kv_cache = [None] * NUM_LAYERS

    first_layer_cache_size = kv_cache[0][0].size(2) if kv_cache and kv_cache[0] is not None else 0

    print(f"Initial Prefill Length: {large_input.size(1)}. Cache Size after Prefill: {first_layer_cache_size}")
    if first_layer_cache_size == WINDOW_SIZE:
        print("✅ Cache Truncation (SWA) successful.")
    else:
        print(f"❌ Cache Truncation (SWA) failed. Expected {WINDOW_SIZE}, got {first_layer_cache_size}")

    # One decode step: the cache size must stay pinned at WINDOW_SIZE.
    single_token = torch.randint(0, VOCAB_SIZE, (1, 1), device=device)
    with torch.no_grad():
        output = model(single_token, past_kv=kv_cache)
        if isinstance(output, tuple):
            logits_out, final_kv_cache = output
        else:
            logits_out = output
            final_kv_cache = kv_cache

    final_cache_size = final_kv_cache[0][0].size(2) if final_kv_cache and final_kv_cache[0] is not None else 0
    print(f"Cache Size after 1 token generation: {final_cache_size}")
    if final_cache_size == WINDOW_SIZE:
        print("✅ SWA cache size remains fixed during generation.")
    else:
        print(f"❌ SWA cache size changed. Expected {WINDOW_SIZE}, got {final_cache_size}")

    # --- GENERATE TEST ---
    print("\n--- Testing Generation Loop ---")

    prompt = torch.randint(0, VOCAB_SIZE, (1, 10), device=device)
    max_tokens_to_generate = 20

    try:
        generated_ids = model.generate(
            prompt,
            max_new_tokens=max_tokens_to_generate,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            eos_token_id=-1,  # out-of-vocab id disables early stopping
        )

        generated_new_tokens = generated_ids.size(0) - prompt.size(1)
        print(f"Prompt length: {prompt.size(1)}")
        print(f"Generated new tokens: {generated_new_tokens}")

        if generated_new_tokens == max_tokens_to_generate:
            print("✅ Generation output length is correct.")
        else:
            print(f"⚠️ Generation stopped early (should not happen with eos_token_id=-1), got {generated_new_tokens} new tokens.")

        print("✅ Generation Test Succeeded (no errors)!")

    except Exception as e:
        # Broad catch is deliberate: this smoke-test harness should report
        # failures and continue to the export step, not crash.
        print(f"❌ Generation Test Failed: {e}")

    # Export the weights regardless of test outcome.
    state_dict_path = Path("models/jirack_swa_1b_class.state_dict.pt")
    state_dict_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), state_dict_path)
    print(f"\nFinal state_dict saved to → {state_dict_path}")