zerdovzad committed (verified)
Commit d831a32 · Parent(s): 7bd2553

Upload 4 files

Files changed (4):
  1. chat.py          +359 -0
  2. download_data.py +191 -0
  3. nord_core.py     +778 -0
  4. train_nord.py    +456 -0

chat.py ADDED
@@ -0,0 +1,359 @@
"""
╔══════════════════════════════════════════════════════════════════════════╗
║  PROJECT NORD — Step 3: Chat with the model, v3.1                         ║
║                                                                           ║
║  Just run:                                                                ║
║      python chat.py                                                       ║
║                                                                           ║
║  It asks where the model lives and starts an interactive chat.            ║
║  Supports STDP: the model learns new words right during the conversation! ║
║  v3.1: Repetition Penalty — fewer repetitions in generation               ║
╚══════════════════════════════════════════════════════════════════════════╝

Requires:
    pip install torch transformers
"""

from __future__ import annotations

import os
import sys
import time
from pathlib import Path
from collections import Counter

import torch
import torch.nn.functional as F

from nord_core import NordConfig, NordModel


# ─────────────────────────────────────────────────────────────────────────────
#  MODEL LOADING
# ─────────────────────────────────────────────────────────────────────────────

def load_model(model_dir: str) -> tuple:
    """Load the model and tokenizer."""
    from transformers import AutoTokenizer

    model_path = Path(model_dir)

    # Find the checkpoint file
    candidates = ["nord_final.pt", "nord_latest.pt"]
    ckpt_path = None
    for name in candidates:
        p = model_path / name
        if p.exists():
            ckpt_path = p
            break

    if ckpt_path is None:
        steps = sorted(model_path.glob("nord_step_*.pt"))
        if steps:
            ckpt_path = steps[-1]

    if ckpt_path is None:
        print(f" [✗] No model found in: {model_dir}")
        print(f"     Train one first: python train_nord.py")
        sys.exit(1)

    print(f" [*] Loading: {ckpt_path.name}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    saved_cfg = ckpt.get("config", {})
    cfg = NordConfig(
        device=device,
        dtype=torch.float16 if device == "cuda" else torch.float32,
        d_model=saved_cfg.get("d_model", 512),
        n_heads=saved_cfg.get("n_heads", 8),
        n_layers=saved_cfg.get("n_layers", 6),
        d_ff=saved_cfg.get("d_ff", 1024),
        T=saved_cfg.get("T", 8),
        T_slow=saved_cfg.get("T_slow", 2),
        max_seq_len=saved_cfg.get("max_seq_len", 512),
        vocab_size=saved_cfg.get("vocab_size", 128_256),
        persistent_mem=False,
    )

    model = NordModel(cfg).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    print(f" [*] Loading the Llama-3.2 tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.tokenizer_id, trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    param_count = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" [✓] Model loaded! ({param_count:.1f}M parameters)")

    return model, tokenizer, cfg


# ─────────────────────────────────────────────────────────────────────────────
#  REPETITION PENALTY
# ─────────────────────────────────────────────────────────────────────────────

def apply_repetition_penalty(
    logits: torch.Tensor,
    generated_ids: torch.Tensor,
    penalty: float = 1.3,
    window: int = 50,
) -> torch.Tensor:
    """
    Lowers the probability of tokens that already appeared in the last
    `window` tokens.
    penalty > 1.0 reduces repetition (1.2-1.5 recommended).
    The more often a token has appeared, the stronger the penalty (up to 5x).
    """
    if penalty <= 1.0:
        return logits

    recent_ids = generated_ids[0, -window:].tolist()
    token_counts = Counter(recent_ids)

    for token_id, count in token_counts.items():
        if token_id < logits.size(-1):
            # Exponential penalty: penalty ** min(count, 5)
            effective_penalty = penalty ** min(count, 5)
            if logits[0, token_id] > 0:
                logits[0, token_id] = logits[0, token_id] / effective_penalty
            else:
                logits[0, token_id] = logits[0, token_id] * effective_penalty

    return logits
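
# A worked toy example of the penalty direction (the logits and ids below are
# assumptions for illustration; they are not part of the chat pipeline):
#
#   logits = torch.tensor([[2.0, -1.0, 0.5]])      # (1, vocab=3)
#   generated = torch.tensor([[0, 0, 2]])          # token 0 appeared twice
#   out = apply_repetition_penalty(logits, generated, penalty=1.3)
#   # token 0: positive, count=2 → 2.0 / 1.3**2 ≈ 1.18  (less likely)
#   # token 1: never generated   → unchanged at -1.0
#   # token 2: positive, count=1 → 0.5 / 1.3     ≈ 0.38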


# ─────────────────────────────────────────────────────────────────────────────
#  TEXT GENERATION
# ─────────────────────────────────────────────────────────────────────────────

@torch.no_grad()
def generate(
    model: NordModel,
    tokenizer,
    cfg: NordConfig,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    enable_stdp: bool = True,
    repetition_penalty: float = 1.3,
    rep_window: int = 50,
) -> str:
    """
    Autoregressive generation with the SNN.
    v3.1: + repetition penalty for more varied text.
    """
    device = cfg.device

    model.reset_state()

    max_prompt_len = max(32, cfg.max_seq_len - max_new_tokens)
    enc = tokenizer(prompt, return_tensors="pt", truncation=True,
                    max_length=max_prompt_len)
    input_ids = enc.input_ids.to(device)
    generated_ids = input_ids.clone()

    for _ in range(max_new_tokens):
        context = generated_ids[:, -cfg.max_seq_len:]

        with torch.amp.autocast("cuda", enabled=(device == "cuda")):
            logits, stats = model(context, enable_stdp=enable_stdp)

        next_logits = logits[:, -1, :].float()

        # ── Repetition penalty (applied before temperature!) ──
        next_logits = apply_repetition_penalty(
            next_logits, generated_ids,
            penalty=repetition_penalty,
            window=rep_window,
        )

        if temperature > 0:
            next_logits = next_logits / temperature

        if top_k > 0:
            top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
            threshold = top_k_vals[:, -1].unsqueeze(-1)
            next_logits[next_logits < threshold] = float("-inf")

        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
            cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Subtracting each token's own probability keeps the first token
            # that crosses the top_p boundary inside the nucleus.
            remove_mask = cumprobs - F.softmax(sorted_logits, dim=-1) > top_p
            sorted_logits[remove_mask] = float("-inf")
            next_logits.scatter_(1, sorted_idx, sorted_logits)

        probs = F.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)

        # v3: Reward-modulated STDP
        if enable_stdp:
            loss_proxy = -torch.log(probs.max() + 1e-8).item()
            model.stdp_update(current_loss=loss_proxy)

        if next_token.item() == tokenizer.eos_token_id:
            break

    new_ids = generated_ids[0, input_ids.shape[1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)
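
# Usage sketch (assumes a trained checkpoint already exists; the path below
# is just the script's default, not a requirement):
#
#   model, tokenizer, cfg = load_model("D:/nord_model")
#   print(generate(model, tokenizer, cfg, prompt="The sun is",
#                  max_new_tokens=50, enable_stdp=False))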


# ─────────────────────────────────────────────────────────────────────────────
#  INTERACTIVE CHAT
# ─────────────────────────────────────────────────────────────────────────────

def chat_loop(model: NordModel, tokenizer, cfg: NordConfig):
    """Main chat loop."""

    temperature = 0.8
    max_tokens = 200
    stdp_enabled = True
    rep_penalty = 1.3
    rep_window = 50

    print(f"\n {'─' * 50}")
    print(f"  Type a message and press Enter.")
    print(f"  Commands:")
    print(f"    /quit         — exit")
    print(f"    /temp 0.5     — change temperature")
    print(f"    /tokens 300   — max tokens per reply")
    print(f"    /stdp on|off  — STDP learning during the chat")
    print(f"    /rep 1.5      — repetition penalty (1.0=off, 1.2-1.5=normal)")
    print(f"    /stats        — show spike statistics")
    print(f"    /reset        — reset the STDP cache")
    print(f" {'─' * 50}\n")

    last_stats = {}

    while True:
        try:
            user_input = input(" You: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n Bye! 👋")
            break

        if not user_input:
            continue

        # ── Commands ──
        if user_input.startswith("/"):
            parts = user_input.split()
            cmd = parts[0].lower()

            if cmd == "/quit":
                print(" Bye! 👋")
                break

            elif cmd == "/temp" and len(parts) > 1:
                try:
                    temperature = float(parts[1])
                    print(f" [⚙] Temperature = {temperature}")
                except ValueError:
                    print(f" [!] Invalid value")

            elif cmd == "/tokens" and len(parts) > 1:
                try:
                    max_tokens = int(parts[1])
                    print(f" [⚙] Max tokens = {max_tokens}")
                except ValueError:
                    print(f" [!] Invalid value")

            elif cmd == "/stdp":
                if len(parts) > 1 and parts[1].lower() in ("off", "0", "ні"):
                    stdp_enabled = False
                    print(f" [⚙] STDP disabled")
                else:
                    stdp_enabled = True
                    print(f" [⚙] STDP enabled — the model learns during the chat!")

            elif cmd == "/rep" and len(parts) > 1:
                try:
                    rep_penalty = float(parts[1])
                    print(f" [⚙] Repetition penalty = {rep_penalty}")
                    if rep_penalty > 2.0:
                        print(f" [!] Warning: values > 2.0 can break generation")
                except ValueError:
                    print(f" [!] Invalid value")

            elif cmd == "/stats":
                if last_stats:
                    print(f" [📊] Latest statistics:")
                    for k, v in last_stats.items():
                        print(f"      {k}: {v:.4f}")
                else:
                    print(f" [!] No statistics yet — say something first")

            elif cmd == "/reset":
                model._stdp_cache.clear()
                print(f" [⚙] STDP cache cleared")

            else:
                print(f" [!] Unknown command: {cmd}")

            continue

        # ── Generation ──
        t0 = time.time()

        response = generate(
            model, tokenizer, cfg,
            prompt=user_input,
            max_new_tokens=max_tokens,
            temperature=temperature,
            enable_stdp=stdp_enabled,
            repetition_penalty=rep_penalty,
            rep_window=rep_window,
        )

        elapsed = time.time() - t0

        print(f"\n Nord: {response}")

        resp_tokens = len(tokenizer.encode(response, add_special_tokens=False))
        tps = resp_tokens / elapsed if elapsed > 0 else 0
        stdp_tag = " [STDP ✓]" if stdp_enabled else ""
        rep_tag = f" [REP {rep_penalty}]" if rep_penalty > 1.0 else ""
        print(f" [{resp_tokens} tok, {elapsed:.1f}s, {tps:.1f} tok/s{stdp_tag}{rep_tag}]\n")

        # Collect spike statistics for /stats
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=(cfg.device == "cuda")):
            ids = tokenizer(user_input, return_tensors="pt",
                            truncation=True, max_length=cfg.max_seq_len).input_ids.to(cfg.device)
            _, last_stats = model(ids)


# ─────────────────────────────────────────────────────────────────────────────
#  ENTRY POINT
# ─────────────────────────────────────────────────────────────────────────────

def main():
    print()
    print("═" * 60)
    print("  ⚡ PROJECT NORD — Spiking Neural Network Chat v3.1")
    print("═" * 60)

    default_model = os.path.join("D:", os.sep, "nord_model")
    print(f"\n Where is the trained model?")
    print(f" (Enter = {default_model})")
    model_input = input(" Path: ").strip()
    model_dir = model_input if model_input else default_model

    if not Path(model_dir).exists():
        print(f"\n [✗] Folder not found: {model_dir}")
        print(f"     Train one first: python train_nord.py")
        sys.exit(1)

    model, tokenizer, cfg = load_model(model_dir)
    chat_loop(model, tokenizer, cfg)


if __name__ == "__main__":
    main()
download_data.py ADDED
@@ -0,0 +1,191 @@
"""
╔══════════════════════════════════════════════════════════════════════════╗
║  PROJECT NORD — Step 1: Downloading the dataset                           ║
║                                                                           ║
║  Just run:                                                                ║
║      python download_data.py                                              ║
║                                                                           ║
║  It asks where to save and starts downloading.                            ║
║  Dataset: FineWeb-Edu (high-quality educational texts in English)         ║
║  Size: ~40 GB of text (JSONL format)                                      ║
╚══════════════════════════════════════════════════════════════════════════╝

Install once:
    pip install datasets tqdm
"""

import json
import os
import sys
import time


def format_size(bytes_val: int) -> str:
    """Format bytes in a human-readable form."""
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if bytes_val < 1024:
            return f"{bytes_val:.1f} {unit}"
        bytes_val /= 1024
    return f"{bytes_val:.1f} PB"
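
# Example outputs (quick sanity check of the loop above):
#   format_size(1536)         → "1.5 KB"
#   format_size(3 * 1024**3)  → "3.0 GB"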


def download():
    print("=" * 60)
    print("  PROJECT NORD — Dataset download")
    print("=" * 60)
    print()

    # ── Ask where to save ──
    default_path = os.path.join("D:", os.sep, "nord_dataset", "train_data.jsonl")
    print(f"  Where should the dataset be saved?")
    print(f"  (Enter = {default_path})")
    user_path = input("  Path: ").strip()
    save_path = user_path if user_path else default_path

    # ── Ask for the size ──
    print()
    print("  How many gigabytes to download?")
    print("  Recommended: 10 GB — quick test")
    print("               40 GB — full training")
    print(f"  (Enter = 40)")
    size_input = input("  Size (GB): ").strip()
    target_gb = float(size_input) if size_input else 40.0
    target_bytes = int(target_gb * (1024 ** 3))

    # Create the folder
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    print()
    print(f"  📁 Saving to: {save_path}")
    print(f"  📦 Target size: {target_gb:.0f} GB")
    print()

    # ── Check whether part of the file already exists (to resume) ──
    bytes_written = 0
    samples_written = 0
    mode = "w"

    if os.path.exists(save_path):
        existing_size = os.path.getsize(save_path)
        if existing_size > 0:
            print(f"  [!] File already exists ({format_size(existing_size)})")
            print(f"      Resume the download? (y/n, Enter = y)")
            choice = input("  > ").strip().lower()
            if choice in ("", "y", "yes", "так", "д"):
                bytes_written = existing_size
                # Count existing lines
                print("      Counting existing lines...")
                with open(save_path, "r", encoding="utf-8") as f:
                    samples_written = sum(1 for _ in f)
                mode = "a"
                print(f"      Resuming from {samples_written:,} samples ({format_size(bytes_written)})")
            else:
                print("      Starting from scratch...")

    if bytes_written >= target_bytes:
        print(f"\n  [✓] The dataset is already complete! ({format_size(bytes_written)})")
        print(f"      Next, run: python train_nord.py")
        return save_path

    # ── Download ──
    print()
    print("  [*] Connecting to HuggingFace...")
    print("  [*] Dataset: HuggingFaceFW/fineweb-edu (sample-10BT)")
    print("      High-quality educational texts — great for training an LLM")
    print()

    try:
        from datasets import load_dataset
    except ImportError:
        print("  [✗] The 'datasets' library is not installed!")
        print("      Run: pip install datasets")
        sys.exit(1)

    # Stream the dataset — it is NEVER loaded fully into RAM
    dataset = load_dataset(
        "HuggingFaceFW/fineweb-edu",
        name="sample-10BT",
        split="train",
        streaming=True,
    )
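
    # A sketch of what the stream yields (field names per FineWeb-Edu; the
    # exact schema may vary between dataset versions, so treat this record
    # as an assumption):
    #   {"text": "...", "id": "<urn:...>", "url": "https://...", ...}
    # Only the "text" field is used below; everything else is dropped.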

    # If resuming — skip the samples that were already downloaded
    # (approximate: samples filtered out below were consumed from the stream
    # without being written, so the skip may re-read a few documents)
    data_iter = iter(dataset)
    if samples_written > 0:
        print(f"  [*] Skipping {samples_written:,} already-downloaded samples...")
        for _ in range(samples_written):
            try:
                next(data_iter)
            except StopIteration:
                break

    print(f"  [*] Writing... (Ctrl+C to stop; you can resume later)")
    print()

    t_start = time.time()
    last_print = t_start
    session_start = bytes_written  # bytes already on disk before this session

    try:
        with open(save_path, mode, encoding="utf-8") as f:
            for sample in data_iter:
                text = sample.get("text", "")
                if not text or len(text) < 50:
                    continue

                line = json.dumps({"text": text}, ensure_ascii=False) + "\n"
                line_bytes = len(line.encode("utf-8"))
                f.write(line)

                bytes_written += line_bytes
                samples_written += 1

                # Progress every 2 seconds
                now = time.time()
                if now - last_print >= 2.0:
                    elapsed = now - t_start
                    # Speed is measured over this session's bytes only
                    speed = (bytes_written - session_start) / elapsed if elapsed > 0 else 0
                    pct = bytes_written / target_bytes * 100
                    bar_len = 30
                    filled = int(bar_len * min(pct, 100) / 100)
                    bar = "█" * filled + "░" * (bar_len - filled)

                    print(
                        f"\r  [{bar}] {pct:.1f}%  "
                        f"{format_size(bytes_written)}/{format_size(target_bytes)}  "
                        f"{samples_written:,} samples  "
                        f"{format_size(int(speed))}/s  ",
                        end="", flush=True,
                    )
                    last_print = now

                # Flush periodically
                if samples_written % 10000 == 0:
                    f.flush()

                # Target size reached
                if bytes_written >= target_bytes:
                    break

    except KeyboardInterrupt:
        print(f"\n\n  [⏸] Stopped! Saved {format_size(bytes_written)} ({samples_written:,} samples)")
        print(f"      To resume later — just run this script again.")
        return save_path

    elapsed = time.time() - t_start
    print(f"\n\n  {'═' * 50}")
    print(f"  [✓] DONE!")
    print(f"  📁 File:    {save_path}")
    print(f"  📦 Size:    {format_size(bytes_written)}")
    print(f"  📝 Samples: {samples_written:,}")
    print(f"  ⏱  Time:    {elapsed/60:.0f} minutes")
    print(f"  {'═' * 50}")
    print()
    print(f"  Next step:")
    print(f"      python train_nord.py")
    print()

    return save_path


if __name__ == "__main__":
    download()
nord_core.py ADDED
@@ -0,0 +1,778 @@
"""
╔══════════════════════════════════════════════════════════════════════════════╗
║  PROJECT NORD — Core Engine v3                                                ║
║  Spiking Neural Network LLM with Associative Memory Manifold                  ║
║                                                                               ║
║  v3 — All 7 bottleneck fixes:                                                 ║
║    1. Multi-Scale Temporal: T_fast + T_slow + persistent membrane state       ║
║    2. LeakyClamp: keeps small negatives (parametric floor, not hard ReLU)     ║
║    3. Adaptive Cascade: learnable per-cluster gain + soft neighbor weights    ║
║    4. Reward-Modulated STDP: LM loss guides plasticity direction              ║
║    5. Sparse Resonance: top-K co-firing instead of full O(S²)                 ║
║    6. Temporal Smoothing Readout: EMA on membrane for long dependencies       ║
║    7. Fused ops: no per-block GPU sync, sparse spike buffers                  ║
║                                                                               ║
║  Target HW: NVIDIA RTX 5070 (8 GB VRAM)                                       ║
╚══════════════════════════════════════════════════════════════════════════════╝
"""

from __future__ import annotations

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from dataclasses import dataclass
from typing import Dict, Tuple, Optional

# ─────────────────────────────────────────────────────────────────────────────
#  §0 CONFIGURATION
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class NordConfig:
    # Tokenizer
    tokenizer_id: str = "meta-llama/Llama-3.2-1B"

    # Dimensions
    vocab_size: int = 128_256
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 6
    d_ff: int = 1024
    max_seq_len: int = 1024

    # ═══ FIX #1: Multi-Scale Temporal ═══
    T: int = 8                    # fast timesteps (local spike dynamics)
    T_slow: int = 2               # slow timesteps (decimated, longer memory)
    persistent_mem: bool = True   # carry membrane state between batches

    # LIF Neuron Dynamics
    tau_mem: float = 0.9
    tau_syn: float = 0.50
    v_threshold: float = 0.25
    v_reset: float = -0.1
    refractory_t: int = 2
    threshold_lr: float = 0.01

    # ═══ FIX #3: Adaptive Cascade ═══
    n_clusters: int = 64
    cascade_radius: int = 3
    cascade_gain: float = 0.8     # initial gain (now learnable per-cluster)

    # ═══ FIX #4: Reward-Modulated STDP ═══
    stdp_a_plus: float = 0.005
    stdp_a_minus: float = 0.005
    stdp_tau_plus: float = 20.0
    stdp_tau_minus: float = 20.0
    stdp_w_max: float = 1.0
    stdp_w_min: float = -0.3
    stdp_reward_scale: float = 1.0  # how much the loss modulates STDP

    # ═══ FIX #5: Sparse Resonance ═══
    resonance_top_k: int = 64     # attend to top-K co-firing positions only

    # ═══ FIX #2: LeakyClamp ═══
    clamp_floor: float = -0.1     # initial floor (learnable per-channel)

    # Surrogate Gradient
    surrogate_alpha: float = 4.0

    # Training
    batch_size: int = 4
    grad_accum: int = 8
    lr: float = 5e-4
    min_lr: float = 1e-5
    weight_decay: float = 0.01
    warmup_steps: int = 500
    max_steps: int = 100_000
    save_every: int = 1000
    log_every: int = 10
    max_grad_norm: float = 1.0

    # Hardware
    dtype: torch.dtype = torch.float16
    device: str = "cuda"

    @property
    def T_total(self) -> int:
        """Total effective timesteps (fast + slow)."""
        return self.T + self.T_slow


# ─────────────────────────────────────────────────────────────────────────────
#  §1 SURROGATE GRADIENT — ATan
# ─────────────────────────────────────────────────────────────────────────────

class ATanSurrogate(torch.autograd.Function):
    alpha = 2.0

    @staticmethod
    def forward(ctx, membrane: Tensor, threshold: Tensor) -> Tensor:
        ctx.save_for_backward(membrane, threshold)
        return (membrane >= threshold).to(membrane.dtype)

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, Tensor]:
        membrane, threshold = ctx.saved_tensors
        orig_dtype = membrane.dtype
        x = (membrane.float() - threshold.float())
        grad = ATanSurrogate.alpha / (
            2.0 * math.pi * (1.0 + (ATanSurrogate.alpha * x) ** 2))
        grad_v = (grad_output.float() * grad).to(orig_dtype)
        # threshold (d,) is broadcast over the batch in forward(), so its
        # gradient must be reduced back to the threshold's own shape
        grad_th = (-grad_v).reshape(-1, grad_v.shape[-1]).sum(dim=0)
        return grad_v, grad_th


def spike_fn(v: Tensor, th: Tensor, alpha: float = 2.0) -> Tensor:
    ATanSurrogate.alpha = alpha
    return ATanSurrogate.apply(v, th)
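
# The surrogate replaces the non-differentiable Heaviside step with a smooth
# bump in the backward pass:
#   d(spike)/dv ≈ α / (2π · (1 + (α·(v − θ))²))
# e.g. with α = 2 the gradient peaks at v = θ with value 2/(2π) ≈ 0.318 and
# decays quickly on both sides, so learning focuses near the firing threshold.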


# ─────────────────────────────────────────────────────────────────────────────
#  §2 ASSOCIATIVE LIF NEURON (v3 — Adaptive Cascade + Persistent State)
# ─────────────────────────────────────────────────────────────────────────────

class AssociativeLIF(nn.Module):
    """
    v3 improvements:
      • FIX #3: Learnable per-cluster cascade gain + soft neighbor weights
      • FIX #1: Optional persistent membrane state between calls
    """

    def __init__(self, d: int, cfg: NordConfig, persistent: bool = False):
        super().__init__()
        self.cfg = cfg
        self.d = d
        self.persistent = persistent

        self.threshold = nn.Parameter(torch.full((d,), cfg.v_threshold))
        self.beta_mem_raw = nn.Parameter(torch.tensor(
            math.log(cfg.tau_mem / (1 - cfg.tau_mem + 1e-6))))
        self.beta_syn_raw = nn.Parameter(torch.tensor(
            math.log(cfg.tau_syn / (1 - cfg.tau_syn + 1e-6))))

        # Cluster topology
        nc = cfg.n_clusters
        cluster_ids = torch.arange(d) % nc
        self.register_buffer("cluster_ids", cluster_ids)

        # ═══ FIX #3: Adaptive Cascade ═══
        # Instead of a fixed boolean neighbor_mask + fixed gain:
        #   - Learnable soft neighbor weights (nc × nc), initialized from topology
        #   - Learnable per-cluster gain
        r = cfg.cascade_radius
        idx = torch.arange(nc)
        init_weights = torch.zeros(nc, nc)
        for offset in range(-r, r + 1):
            if offset != 0:
                # Closer neighbors get a higher initial weight
                dist_weight = 1.0 - abs(offset) / (r + 1)
                init_weights[idx, (idx + offset) % nc] = dist_weight
        # Learnable: the network can strengthen/weaken/extend neighbor connections
        self.neighbor_weights = nn.Parameter(init_weights)
        # Per-cluster gain (not a global scalar anymore)
        self.cluster_gain = nn.Parameter(torch.full((nc,), cfg.cascade_gain))

        # ═══ FIX #1: Persistent membrane state ═══
        if persistent:
            self.register_buffer("_v_mem_state", torch.zeros(1, d))
            self.register_buffer("_i_syn_state", torch.zeros(1, d))

    @property
    def beta_mem(self) -> Tensor:
        return torch.sigmoid(self.beta_mem_raw)

    @property
    def beta_syn(self) -> Tensor:
        return torch.sigmoid(self.beta_syn_raw)

    def _cascade_amplify(self, spikes: Tensor) -> Tensor:
        """v3: Soft learnable neighbor weights + per-cluster gain."""
        B, D = spikes.shape
        nc = self.cfg.n_clusters
        cid = self.cluster_ids.unsqueeze(0).expand(B, -1)

        cluster_fire = torch.zeros(B, nc, device=spikes.device, dtype=spikes.dtype)
        cluster_fire.scatter_add_(1, cid, spikes)
        cluster_fire = cluster_fire / max(D // nc, 1)

        # Soft neighbor weights (sigmoid → [0,1] so they can't go negative)
        W = torch.sigmoid(self.neighbor_weights)                         # (nc, nc)
        neighbor_signal = (W.to(cluster_fire.dtype) @ cluster_fire.T).T  # (B, nc)

        # Per-cluster gain
        gain = self.cluster_gain.to(cluster_fire.dtype)                  # (nc,)
        neighbor_signal = neighbor_signal * gain.unsqueeze(0)

        return neighbor_signal.gather(1, cid)

    def reset_state(self):
        """Reset persistent membrane state (call at the start of a new sequence)."""
        if self.persistent:
            self._v_mem_state.zero_()
            self._i_syn_state.zero_()

    def forward(self, current_in: Tensor) -> Tuple[Tensor, Tensor]:
        T, B, D = current_in.shape
        device = current_in.device
        dtype = current_in.dtype
        beta_m = self.beta_mem
        beta_s = self.beta_syn

        # ═══ FIX #1: Persistent membrane — carry state from the previous batch ═══
        if self.persistent and self._v_mem_state.shape[0] == B:
            v_mem = self._v_mem_state.clone()
            i_syn = self._i_syn_state.clone()
        else:
            v_mem = torch.zeros(B, D, device=device, dtype=dtype)
            i_syn = torch.zeros(B, D, device=device, dtype=dtype)
            if self.persistent:
                # Resize state buffers for the new batch size
                self._v_mem_state = torch.zeros(B, D, device=device, dtype=dtype)
                self._i_syn_state = torch.zeros(B, D, device=device, dtype=dtype)

        refrac_counter = torch.zeros(B, D, device=device, dtype=torch.int32)

        spikes_out = []
        v_trace = []

        for t in range(T):
            i_syn = beta_s * i_syn + current_in[t]

            refractory_mask = (refrac_counter > 0)
            v_mem = torch.where(
                refractory_mask,
                torch.full_like(v_mem, self.cfg.v_reset),
                beta_m * v_mem + (1.0 - beta_m) * i_syn,
            )

            s = spike_fn(v_mem, self.threshold, self.cfg.surrogate_alpha)

            if s.sum() > 0:
                cascade = self._cascade_amplify(s)
                i_syn = i_syn + cascade

            v_mem = v_mem - s * self.threshold.detach()
            refrac_counter = torch.where(
                s.bool(),
                torch.full_like(refrac_counter, self.cfg.refractory_t),
                (refrac_counter - 1).clamp(min=0),
            )

            spikes_out.append(s)
            v_trace.append(v_mem)

        # Save state for the next batch
        if self.persistent:
            self._v_mem_state = v_mem.detach()
            self._i_syn_state = i_syn.detach()

        return torch.stack(spikes_out), torch.stack(v_trace)
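
# Topology illustration (with the defaults d_model=512, n_clusters=64):
# neuron i belongs to cluster i % 64, so neurons 3, 67, 131, ... share
# cluster 3. When cluster 3 fires strongly, its six ring neighbors within
# cascade_radius=3 (clusters 0-2 and 4-6) receive a soft, learnable boost
# through neighbor_weights, weighted higher the closer they are.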


# ─────────────────────────────────────────────────────────────────────────────
#  §3 TEMPORAL ENCODER (v3 — Multi-Scale)
# ─────────────────────────────────────────────────────────────────────────────

class TemporalSpikeEncoder(nn.Module):
    """
    v3 — Multi-Scale Temporal Coding:
      Fast path (T timesteps):       standard temporal basis modulation
      Slow path (T_slow timesteps):  decimated, larger time constants
      → concatenated along the time axis → (T+T_slow, B*S, D)

    The slow path captures longer-range dependencies that T=8 misses.
    """

    def __init__(self, cfg: NordConfig):
        super().__init__()
        self.cfg = cfg
        D = cfg.d_model
        T = cfg.T
        T_slow = cfg.T_slow

        self.embed = nn.Embedding(cfg.vocab_size, D)
        nn.init.kaiming_uniform_(self.embed.weight, a=math.sqrt(5))

        self.temporal_proj = nn.Linear(D, D, bias=False)
        self.drive_scale = nn.Parameter(torch.tensor(15.0))

        # Fast temporal basis (T gates)
        self.fast_basis = nn.Parameter(torch.randn(T, D) * 0.02)

        # ═══ FIX #1: Slow temporal basis (T_slow gates, wider receptive field) ═══
        self.slow_basis = nn.Parameter(torch.randn(T_slow, D) * 0.02)
        # The slow drive is weaker — it's a "summary" signal
        self.slow_scale = nn.Parameter(torch.tensor(5.0))

    def forward(self, token_ids: Tensor) -> Tensor:
        """Returns: (T + T_slow, B*S, D) current."""
        B, S = token_ids.shape
        D = self.cfg.d_model

        x = self.temporal_proj(self.embed(token_ids))
        x = x.reshape(B * S, D)

        # Fast path
        fast_gates = torch.sigmoid(self.fast_basis)  # (T, D)
        fast = fast_gates.unsqueeze(1) * x.unsqueeze(0) * self.drive_scale

        # Slow path — fewer timesteps, gentler drive
        slow_gates = torch.sigmoid(self.slow_basis)  # (T_slow, D)
        slow = slow_gates.unsqueeze(1) * x.unsqueeze(0) * self.slow_scale

        # Concatenate: fast then slow timesteps
        return torch.cat([fast, slow], dim=0)  # (T+T_slow, B*S, D)
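
# Shape walk-through (defaults T=8, T_slow=2, d_model=512; the batch values
# are illustrative): token_ids (B=2, S=16) → embed + temporal_proj (32, 512)
# → fast (8, 32, 512) concatenated with slow (2, 32, 512) → output
# (10, 32, 512), i.e. one input current per effective timestep.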


# ─────────────────────────────────────────────────────────────────────────────
#  §4 SPIKING SYNAPTIC RESONANCE (v3 — Sparse Top-K)
# ─────────────────────────────────────────────────────────────────────────────

class SpikingSynapticResonance(nn.Module):
    """
    v3 — FIX #5: Sparse Resonance

    Instead of a full O(S²) attention matrix:
      1. Compute the full co-fire resonance (still needed for causality)
      2. Keep only the top-K values per query position
      3. Zero out the rest → sparse attention → less memory, faster

    For S=512, top_k=64 → 87.5% sparsity in the attention matrix.
    """

    def __init__(self, cfg: NordConfig):
        super().__init__()
        self.cfg = cfg
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_model // cfg.n_heads
        self.top_k = cfg.resonance_top_k
        D = cfg.d_model

        self.W_q = nn.Linear(D, D, bias=False)
        self.W_k = nn.Linear(D, D, bias=False)
        self.W_v = nn.Linear(D, D, bias=False)
        self.W_o = nn.Linear(D, D, bias=False)

        self.lif_q = AssociativeLIF(D, cfg)
        self.lif_k = AssociativeLIF(D, cfg)

        self.resonance_temp = nn.Parameter(
            torch.tensor(1.0 / math.sqrt(self.d_head)))

    def forward(self, x_spikes: Tensor) -> Tensor:
        T_total, B, S, D = x_spikes.shape
        H, Dh = self.n_heads, self.d_head

        x_flat = x_spikes.reshape(T_total * B * S, D)
        q_current = self.W_q(x_flat).reshape(T_total, B * S, D)
        k_current = self.W_k(x_flat).reshape(T_total, B * S, D)
        v_raw = self.W_v(x_flat).reshape(T_total, B, S, D)

        q_spikes, _ = self.lif_q(q_current)
        k_spikes, _ = self.lif_k(k_current)

        q_spikes = q_spikes.reshape(T_total, B, S, H, Dh)
        k_spikes = k_spikes.reshape(T_total, B, S, H, Dh)

        q_flat = q_spikes.permute(1, 3, 2, 0, 4).reshape(B, H, S, T_total * Dh)
        k_flat = k_spikes.permute(1, 3, 2, 0, 4).reshape(B, H, S, T_total * Dh)

        resonance = torch.matmul(q_flat, k_flat.transpose(-2, -1))
        resonance = resonance * self.resonance_temp

        # Causal mask
        causal_mask = torch.triu(
            torch.ones(S, S, device=x_spikes.device, dtype=torch.bool), diagonal=1
        )
        resonance.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        # ═══ FIX #5: Top-K Sparse Attention ═══
        # Keep only the top-K resonance scores per query position; zero out the
        # rest. This makes attention sparse → less memory for long sequences.
        K = min(self.top_k, S)
        if K < S:
            # Find the top-K per query row
            top_vals, top_idx = torch.topk(resonance, K, dim=-1)  # (B,H,S,K)
            # Create a sparse mask: -inf everywhere, then scatter the top-K back
            sparse_res = torch.full_like(resonance, float("-inf"))
            sparse_res.scatter_(-1, top_idx, top_vals)
            resonance = sparse_res

        attn = F.softmax(resonance.float(), dim=-1).to(resonance.dtype)

        v_mean = v_raw.mean(dim=0)
        v_heads = v_mean.reshape(B, S, H, Dh).permute(0, 2, 1, 3)
        context = torch.matmul(attn, v_heads)
        context = context.permute(0, 2, 1, 3).reshape(B, S, D)
        out = self.W_o(context)

        return out.unsqueeze(0).expand(T_total, -1, -1, -1)
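
# Toy illustration of the top-K masking used above (values are assumptions
# chosen only to show the mechanics on a single 4-key row):
#
#   res = torch.tensor([[3.0, 1.0, 2.0, 0.5]])
#   vals, idx = torch.topk(res, 2, dim=-1)  # keeps 3.0 and 2.0
#   sparse = torch.full_like(res, float("-inf")).scatter_(-1, idx, vals)
#   # sparse → [[3.0, -inf, 2.0, -inf]]; softmax then ignores the -inf slots,
#   # so each query attends to at most K keys.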


# ─────────────────────────────────────────────────────────────────────────────
#  §5 NORD BLOCK (v3 — LeakyClamp + LayerScale)
# ─────────────────────────────────────────────────────────────────────────────

class SpikingFeedForward(nn.Module):
    def __init__(self, cfg: NordConfig):
        super().__init__()
        self.up = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.down = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
        self.lif1 = AssociativeLIF(cfg.d_ff, cfg)
        self.lif2 = AssociativeLIF(cfg.d_model, cfg)

    def forward(self, x: Tensor) -> Tensor:
        T, B, S, D = x.shape
        h = self.up(x.reshape(T * B * S, D)).reshape(T, B * S, -1)
        h, _ = self.lif1(h)
        h = h.reshape(T, B, S, -1)
        h = self.down(h.reshape(T * B * S, -1)).reshape(T, B * S, D)
        h, _ = self.lif2(h)
        return h.reshape(T, B, S, D)


class LeakyClamp(nn.Module):
    """
    ═══ FIX #2: LeakyClamp ═══

    Instead of a hard ReLU (which kills all negatives):
        output = x                      if x >= 0
        output = max(floor, leak · x)   if x < 0

    Where `floor` and `leak` are learnable per-channel.
    This preserves sub-threshold membrane information that ReLU discards.
    Initialized so floor ≈ -0.1, leak ≈ 0.1 (gentle pass-through of negatives).
    """

    def __init__(self, d: int, floor_init: float = -0.1, leak_init: float = 0.1):
        super().__init__()
        # Learnable floor (per-channel): how far below zero we allow
        self.floor = nn.Parameter(torch.full((d,), floor_init))
        # Learnable leak slope (per-channel): how much negative signal passes
        self.leak_raw = nn.Parameter(torch.full((d,), math.log(leak_init / (1 - leak_init + 1e-6))))

    @property
    def leak(self) -> Tensor:
        return torch.sigmoid(self.leak_raw)  # always in (0, 1)

    def forward(self, x: Tensor) -> Tensor:
        # Positive: pass through unchanged
        # Negative: leak * x, clamped above the floor
        neg_part = (self.leak * x).clamp(min=self.floor)
        return torch.where(x >= 0, x, neg_part)
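
# Worked example with the initial values floor ≈ -0.1, leak ≈ 0.1:
#   x =  0.50 → passes through unchanged          →  0.50
#   x = -0.50 → leak·x = -0.05 (above the floor)  → -0.05
#   x = -5.00 → leak·x = -0.50, clamped to floor  → -0.10
# A hard ReLU would have mapped both negative inputs to exactly 0.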


class NordBlock(nn.Module):
    """
    v3: LayerScale + LeakyClamp (not ReLU).
    """

    def __init__(self, cfg: NordConfig, layer_idx: int = 0):
        super().__init__()
        D = cfg.d_model
        self.norm1 = nn.LayerNorm(D)
        self.norm2 = nn.LayerNorm(D)
        self.resonance = SpikingSynapticResonance(cfg)
        self.ffn = SpikingFeedForward(cfg)

        init_scale = 0.1 / max(cfg.n_layers, 1)
        self.gamma_attn = nn.Parameter(torch.full((D,), init_scale))
        self.gamma_ffn = nn.Parameter(torch.full((D,), init_scale))

        # ═══ FIX #2: LeakyClamp instead of ReLU ═══
        self.clamp = LeakyClamp(D, floor_init=cfg.clamp_floor)

    @staticmethod
    def _safe_norm(norm_layer: nn.LayerNorm, x: Tensor) -> Tensor:
        orig_dtype = x.dtype
        return F.layer_norm(
            x.float(),
            norm_layer.normalized_shape,
            norm_layer.weight.float() if norm_layer.weight is not None else None,
            norm_layer.bias.float() if norm_layer.bias is not None else None,
            norm_layer.eps,
        ).to(orig_dtype)

    def forward(self, x: Tensor) -> Tensor:
        x_norm = self._safe_norm(self.norm1, x)
        x = x + self.gamma_attn * self.resonance(x_norm)

        x_norm = self._safe_norm(self.norm2, x)
        x = x + self.gamma_ffn * self.ffn(x_norm)

        # FIX #2: LeakyClamp preserves sub-threshold info
        x = self.clamp(x)
        return x


# ─────────────────────────────────────────────────────────────────────────────
#  §6 STDP ENGINE (v3 — Reward-Modulated)
# ─────────────────────────────────────────────────────────────────────────────

class STDPEngine:
    """
    ═══ FIX #4: Reward-Modulated STDP ═══

    Classic STDP is blind — it strengthens any co-firing, even if it hurts
    the LM loss. Reward modulation fixes this:

        dW_final = dW_stdp × reward_signal

    Where reward_signal = sigmoid(baseline_loss - current_loss)
      - If current loss < baseline → reward > 0.5 → strengthen
      - If current loss > baseline → reward < 0.5 → weaken/suppress
      - baseline_loss is an exponential moving average

    This aligns local Hebbian plasticity with the global training objective.
    """

    def __init__(self, cfg: NordConfig):
        self.cfg = cfg
        self.a_plus = cfg.stdp_a_plus
        self.a_minus = cfg.stdp_a_minus
        self.tau_plus = cfg.stdp_tau_plus
        self.tau_minus = cfg.stdp_tau_minus
        self.w_max = cfg.stdp_w_max
        self.w_min = cfg.stdp_w_min
        self.reward_scale = cfg.stdp_reward_scale

        # Running baseline loss (EMA)
        self._loss_ema: float = 10.0  # initialize high
        self._ema_decay: float = 0.99

    def update_reward(self, current_loss: float):
        """Call after each forward pass with the current loss."""
        self._loss_ema = self._ema_decay * self._loss_ema + (1 - self._ema_decay) * current_loss

    def _compute_reward(self, current_loss: float) -> float:
        """Reward signal: how much better than the baseline?"""
        delta = self._loss_ema - current_loss  # positive = improving
        return float(torch.sigmoid(torch.tensor(delta * self.reward_scale)).item())

    @torch.no_grad()
    def compute_stdp_update(self, pre_spikes: Tensor, post_spikes: Tensor) -> Tensor:
        T = pre_spikes.shape[0]
        device = pre_spikes.device
        trace_pre = torch.zeros_like(pre_spikes[0])
        trace_post = torch.zeros_like(post_spikes[0])
        decay_plus = math.exp(-1.0 / self.tau_plus)
        decay_minus = math.exp(-1.0 / self.tau_minus)

        dW = torch.zeros(
            post_spikes.shape[1], pre_spikes.shape[1],
            device=device, dtype=pre_spikes.dtype)

        for t in range(T):
            trace_pre = trace_pre * decay_plus + pre_spikes[t]
            trace_post = trace_post * decay_minus + post_spikes[t]
            if post_spikes[t].any():
                dW += self.a_plus * torch.outer(post_spikes[t], trace_pre)
            if pre_spikes[t].any():
                dW -= self.a_minus * torch.outer(trace_post, pre_spikes[t])
        return dW

    @torch.no_grad()
    def apply_to_layer(self, layer: nn.Linear, pre_spikes: Tensor,
                       post_spikes: Tensor, current_loss: Optional[float] = None):
        if pre_spikes.dim() == 3:
            pre_spikes = pre_spikes.mean(dim=1)
        if post_spikes.dim() == 3:
            post_spikes = post_spikes.mean(dim=1)

        dW = self.compute_stdp_update(pre_spikes, post_spikes)

        # ═══ Reward modulation ═══
        if current_loss is not None:
            reward = self._compute_reward(current_loss)
            # reward ∈ (0, 1): >0.5 means improving → full STDP
            #                  <0.5 means worsening → suppress/reverse STDP
            dW = dW * (2.0 * reward - 1.0)  # map (0,1) → (-1,1)
            self.update_reward(current_loss)

        out_dim, in_dim = layer.weight.shape
        dW = dW[:out_dim, :in_dim]
        layer.weight.data = (layer.weight.data + dW).clamp(self.w_min, self.w_max)
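
# Worked example of the reward modulation (using the initial EMA of 10.0):
#   current_loss = 9.0  → delta = +1.0 → reward = sigmoid(1.0) ≈ 0.73
#                         multiplier = 2·0.73 − 1 ≈ +0.46  (STDP applied)
#   current_loss = 11.0 → delta = −1.0 → reward ≈ 0.27
#                         multiplier ≈ −0.46               (STDP reversed)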


# ─────────────────────────────────────────────────────────────────────────────
#  §7 NORD MODEL (v3 — Multi-Scale + Temporal Smoothing Readout)
# ─────────────────────────────────────────────────────────────────────────────

class NordModel(nn.Module):
    """
    v3 — Full architecture:

    Pipeline:
      tokens → MultiScale TemporalEncoder → input_LIF(persistent)
             → [NordBlock(LeakyClamp, SparseResonance) × N]
             → readout_LIF → EMA-smoothed membrane → LM_head

    FIX #6 — Temporal Smoothing Readout:
      Instead of a simple mean over timesteps, apply an exponential moving
      average on the membrane potential → later timesteps get more weight →
      captures the "final state" while retaining history. Learnable smoothing
      factor.
    """

    def __init__(self, cfg: NordConfig):
        super().__init__()
        self.cfg = cfg

        self.encoder = TemporalSpikeEncoder(cfg)

        # Input LIF with persistent membrane state
        self.input_lif = AssociativeLIF(
            cfg.d_model, cfg, persistent=cfg.persistent_mem)

        self.blocks = nn.ModuleList([
            NordBlock(cfg, layer_idx=i) for i in range(cfg.n_layers)
        ])

        # Readout LIF (persistent → accumulates cross-batch info)
        self.readout_lif = AssociativeLIF(
            cfg.d_model, cfg, persistent=cfg.persistent_mem)

        # ═══ FIX #6: Temporal Smoothing ═══
        # Learnable EMA decay for the readout: how much to weight recent vs
        # old timesteps. Higher = more weight on recent (initialized 0.8).
        self.readout_ema_raw = nn.Parameter(torch.tensor(1.4))  # sigmoid(1.4) ≈ 0.8

        self.readout_norm = nn.LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        self.stdp = STDPEngine(cfg)
        self._stdp_cache: Dict[str, Tensor] = {}
        self._last_loss: Optional[float] = None

    @property
    def readout_ema_decay(self) -> Tensor:
        return torch.sigmoid(self.readout_ema_raw)

    def reset_state(self):
        """Reset all persistent membrane states (call between unrelated sequences)."""
        self.input_lif.reset_state()
        self.readout_lif.reset_state()

    def forward(
        self,
        token_ids: Tensor,
        enable_stdp: bool = False,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        B, S = token_ids.shape
        T_total = self.cfg.T_total
        D = self.cfg.d_model

        # ── Encode (Multi-Scale) → Spike ──
        current = self.encoder(token_ids)    # (T+T_slow, B*S, D)
        spikes, _ = self.input_lif(current)  # (T_total, B*S, D)
        spikes = spikes.reshape(T_total, B, S, D)

        _rates = [spikes.detach().mean()]

        if enable_stdp:
            self._stdp_cache["input"] = spikes.detach()

        # ── Nord Blocks ──
        x = spikes
        for i, block in enumerate(self.blocks):
            prev = x.detach() if enable_stdp else None
            x = block(x)
            _rates.append(x.detach().mean())

            if enable_stdp and prev is not None:
                self._stdp_cache[f"block_{i}_pre"] = prev
                self._stdp_cache[f"block_{i}_post"] = x.detach()

        # ── Readout: EMA-smoothed membrane potential ──
        x_flat = x.reshape(T_total, B * S, D)
        readout_spikes, v_membrane = self.readout_lif(x_flat)

        # ═══ FIX #6: EMA temporal smoothing ═══
        # Instead of a simple mean, exponentially weight later timesteps more
        alpha = self.readout_ema_decay  # scalar in (0, 1)
        ema = torch.zeros(B * S, D, device=x.device, dtype=v_membrane.dtype)
        for t in range(T_total):
            ema = alpha * ema + (1 - alpha) * v_membrane[t]
        # ema now holds the smoothed membrane potential
        v_smooth = ema.reshape(B, S, D)

        # Hybrid: smoothed membrane + spike rate
        s_mean = readout_spikes.mean(dim=0).reshape(B, S, D)
        readout = v_smooth + s_mean

        x_norm = F.layer_norm(
            readout.float(),
            self.readout_norm.normalized_shape,
            self.readout_norm.weight.float() if self.readout_norm.weight is not None else None,
            self.readout_norm.bias.float() if self.readout_norm.bias is not None else None,
            self.readout_norm.eps,
        ).to(readout.dtype)
        logits = self.lm_head(x_norm)

        # Stats (all .item() calls — and hence GPU syncs — grouped here)
        stats = {}
        stats["encoder_spike_rate"] = _rates[0].item()
        for i in range(self.cfg.n_layers):
            stats[f"block_{i}_spike_rate"] = _rates[i + 1].item()
        out_rate = readout_spikes.detach().mean().item()
        stats["output_spike_rate"] = out_rate
        stats["sparsity"] = 1.0 - out_rate

        return logits, stats

    @torch.no_grad()
    def stdp_update(self, current_loss: Optional[float] = None):
        """
        v3: Pass current_loss for reward modulation.
        If None, falls back to the last stored loss (or unmodulated STDP).
        """
        # Explicit None check: a loss of exactly 0.0 is still a valid signal
        loss_val = current_loss if current_loss is not None else self._last_loss
        for i, block in enumerate(self.blocks):
            pre_key = f"block_{i}_pre"
            post_key = f"block_{i}_post"
            if pre_key in self._stdp_cache and post_key in self._stdp_cache:
                pre = self._stdp_cache[pre_key]
                post = self._stdp_cache[post_key]
                T_dim = pre.shape[0]
                pre_flat = pre.reshape(T_dim, -1, self.cfg.d_model).mean(dim=1)
                post_flat = post.reshape(T_dim, -1, self.cfg.d_model).mean(dim=1)
                self.stdp.apply_to_layer(
                    block.resonance.W_v, pre_flat, post_flat,
                    current_loss=loss_val,
                )
        self._stdp_cache.clear()

    def set_last_loss(self, loss: float):
        """Store the loss for reward-modulated STDP during inference."""
        self._last_loss = loss

    def count_params(self) -> str:
        total = sum(p.numel() for p in self.parameters())
        train = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return f"Total: {total/1e6:.1f}M | Trainable: {train/1e6:.1f}M"
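
# EMA readout weighting example (with the initial decay α ≈ 0.8, T_total=10):
# after the loop, timestep t contributes weight (1−α)·α^(9−t), so the last
# timestep gets 0.2, the one before 0.16, then 0.128, ... — recent membrane
# states dominate while early history still leaks through.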


# ─────────────────────────────────────────────────────────────────────────────
#  §8 UTILITY
# ─────────────────────────────────────────────────────────────────────────────

def estimate_vram(cfg: NordConfig) -> str:
    param_bytes = (
        cfg.vocab_size * cfg.d_model
        + cfg.n_layers * (
            4 * cfg.d_model * cfg.d_model
            + 2 * cfg.d_model * cfg.d_ff
            + 6 * cfg.d_model
            + cfg.n_clusters * cfg.n_clusters  # neighbor_weights
        )
        + cfg.vocab_size * cfg.d_model
    ) * (2 if cfg.dtype == torch.float16 else 4)

    act_bytes = cfg.T_total * 1 * cfg.max_seq_len * cfg.d_model * cfg.n_layers * 2 * 2
    total_gb = (param_bytes + act_bytes) / (1024 ** 3)
    return (
        f"Parameters:  ~{param_bytes / 1e6:.0f} MB\n"
        f"Activations: ~{act_bytes / 1e6:.0f} MB (B=1, S={cfg.max_seq_len})\n"
        f"Total Est:   ~{total_gb:.2f} GB (target: 8 GB RTX 5070)"
    )
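
# Usage sketch — with the default NordConfig the formula above works out to
# roughly (computed from the defaults, rounded):
#   Parameters:  ~288 MB
#   Activations: ~126 MB (B=1, S=1024)
#   Total Est:   ~0.39 GB (target: 8 GB RTX 5070)
#
#   print(estimate_vram(NordConfig()))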
train_nord.py ADDED
@@ -0,0 +1,456 @@
"""
╔══════════════════════════════════════════════════════════════════════════╗
║  PROJECT NORD — Step 2: Training the SNN model                            ║
║                                                                           ║
║  Just run:                                                                ║
║      python train_nord.py                                                 ║
║                                                                           ║
║  It asks:                                                                 ║
║    1. Where the dataset lives (a JSONL file)                              ║
║    2. Where to save the model                                             ║
║  That's it — training then runs automatically.                            ║
║                                                                           ║
║  You can stop with Ctrl+C and resume later — the model is saved.          ║
╚══════════════════════════════════════════════════════════════════════════╝

Install once:
    pip install torch transformers lmdb tqdm
"""

from __future__ import annotations

import json
import math
import os
import shutil
import struct
import sys
import time
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from torch.amp import autocast
from torch.utils.data import Dataset, DataLoader

from nord_core import NordConfig, NordModel


# ─────────────────────────────────────────────────────────────────────────────
#  TOKENIZER
# ─────────────────────────────────────────────────────────────────────────────

class NordTokenizer:
    """A wrapper around the Llama-3.2 tokenizer for Project Nord."""

    def __init__(self, cfg: NordConfig):
        from transformers import AutoTokenizer

        print(f" [*] Loading the Llama-3.2 tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            cfg.tokenizer_id, trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.max_len = cfg.max_seq_len
        self.vocab_size = self.tokenizer.vocab_size
        if cfg.vocab_size < self.vocab_size:
            cfg.vocab_size = self.vocab_size

        print(f" [✓] Tokenizer ready (vocab={self.vocab_size:,})")

    def encode(self, text: str) -> torch.Tensor:
        enc = self.tokenizer(
            text, return_tensors="pt",
            max_length=self.max_len, truncation=True, padding="max_length",
        )
        return enc.input_ids

    def decode(self, ids) -> str:
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    @property
    def pad_id(self) -> int:
        return self.tokenizer.pad_token_id
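
# Usage sketch:
#   tok = NordTokenizer(NordConfig())
#   ids = tok.encode("Hello, world")  # (1, max_seq_len), right-padded with pad_id
#   tok.decode(ids[0])                # → "Hello, world"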


# ─────────────────────────────────────────────────────────────────────────────
#  LMDB DATASET (on-disk, zero RAM)
# ─────────────────────────────────────────────────────────────────────────────

class LMDBDataset(Dataset):
    def __init__(self, db_path: str, max_seq_len: int):
        import lmdb
        self.db_path = db_path
        self.max_seq_len = max_seq_len
        self._env = None  # opened lazily — can't pickle lmdb.Environment on Windows

        # Read the length once, then close
        env = lmdb.open(db_path, readonly=True, lock=False, readahead=False, meminit=False)
        with env.begin(write=False) as txn:
            raw = txn.get(b"__len__")
            self.length = struct.unpack("<Q", raw)[0]
        env.close()
        print(f" [✓] LMDB: {self.length:,} samples")

    def _get_env(self):
        """Lazy-open LMDB per worker process (safe for multiprocessing)."""
        if self._env is None:
            import lmdb
            self._env = lmdb.open(
                self.db_path, readonly=True, lock=False,
                readahead=True, meminit=False, max_readers=64,
            )
        return self._env

    def __len__(self): return self.length

    def __getitem__(self, idx):
        env = self._get_env()
        with env.begin(write=False) as txn:
            raw = txn.get(f"sample_{idx:010d}".encode())
        ids = torch.frombuffer(bytearray(raw), dtype=torch.int32).long()
        S = self.max_seq_len
        return ids[:S] if ids.shape[0] >= S else F.pad(ids, (0, S - ids.shape[0]))
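
# Key layout inside the LMDB database (as written by build_lmdb below):
#   b"sample_0000000000", b"sample_0000000001", ...  → int32 token arrays
#   b"__len__"           → little-endian uint64 sample count (struct "<Q")
#   b"__total_tokens__"  → little-endian uint64 non-pad token count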
118
+
+
+ def build_lmdb(jsonl_path: str, db_path: str, tokenizer: NordTokenizer,
+                max_seq_len: int, map_size_gb: float = 50.0):
+     """Convert JSONL → an LMDB database (done once)."""
+     import lmdb
+     from tqdm import tqdm
+
+     print(f"\n [*] Building the LMDB database...")
+     print(f"     This happens ONCE; afterwards you can train from the database indefinitely.")
+     print(f"     Source: {jsonl_path}")
+     print(f"     Target: {db_path}")
+
+     # Count the lines
+     print(f" [*] Counting lines...")
+     with open(jsonl_path, "r", encoding="utf-8") as f:
+         n_lines = sum(1 for _ in f)
+     print(f"     Found: {n_lines:,} lines")
+
+     env = lmdb.open(db_path, map_size=int(map_size_gb * (1024 ** 3)))
+     count = 0
+     total_tokens = 0
+
+     txn = env.begin(write=True)
+     try:
+         with open(jsonl_path, "r", encoding="utf-8") as f:
+             for line in tqdm(f, total=n_lines, desc=" Tokenizing", unit=" doc"):
+                 line = line.strip()
+                 if not line:
+                     continue
+                 try:
+                     obj = json.loads(line)
+                 except json.JSONDecodeError:
+                     continue
+
+                 text = obj.get("text") or obj.get("content") or obj.get("passage", "")
+                 if len(text) < 30:
+                     continue
+
+                 ids = tokenizer.encode(text).squeeze(0)
+                 non_pad = (ids != tokenizer.pad_id).sum().item()
+                 if non_pad < 10:
+                     continue
+
+                 txn.put(f"sample_{count:010d}".encode(),
+                         ids.to(torch.int32).numpy().tobytes())
+                 count += 1
+                 total_tokens += non_pad
+
+                 # Commit in batches so a crash loses at most the last 50k samples
+                 if count % 50_000 == 0:
+                     txn.commit()
+                     txn = env.begin(write=True)
+                     print(f"     ... {count:,} samples, {total_tokens/1e6:.1f}M tokens")
+
+         txn.put(b"__len__", struct.pack("<Q", count))  # little-endian uint64
+         txn.put(b"__total_tokens__", struct.pack("<Q", total_tokens))
+         txn.commit()
+     except BaseException:
+         txn.abort()
+         raise
+
+     env.close()
+
+     db_size = sum(f.stat().st_size for f in Path(db_path).rglob("*") if f.is_file())
+     print(f"\n [✓] LMDB ready!")
+     print(f"     Samples: {count:,}")
+     print(f"     Tokens:  {total_tokens:,} ({total_tokens/1e6:.1f}M)")
+     print(f"     On disk: {db_size / (1024**3):.2f} GB")
+
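+ # Quick sanity check (sketch): read the metadata back after a build. It uses
+ # only the key layout written above (__len__ / __total_tokens__ as
+ # little-endian uint64, samples stored under "sample_<idx:010d>" as int32):
+ #
+ #     env = lmdb.open(db_path, readonly=True, lock=False)
+ #     with env.begin() as txn:
+ #         n = struct.unpack("<Q", txn.get(b"__len__"))[0]
+ #         first = txn.get(b"sample_0000000000")
+ #     print(n, "samples;", len(first) // 4, "int32 tokens in sample 0")
+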
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # LR SCHEDULE
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ def get_lr(step: int, cfg: NordConfig) -> float:
+     """Linear warmup, then cosine decay from cfg.lr down to cfg.min_lr."""
+     if step < cfg.warmup_steps:
+         return cfg.lr * (step + 1) / cfg.warmup_steps
+     progress = min((step - cfg.warmup_steps) / max(1, cfg.max_steps - cfg.warmup_steps), 1.0)
+     return cfg.min_lr + 0.5 * (1.0 + math.cos(math.pi * progress)) * (cfg.lr - cfg.min_lr)
+
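+ # Shape of the schedule, worked through with illustrative numbers (the real
+ # warmup_steps / min_lr defaults live in NordConfig in nord_core.py).
+ # Assuming lr=5e-4, min_lr=5e-5, warmup_steps=1000, max_steps=100_000:
+ #     step 0       → 5e-7      (warmup ramps linearly from near zero)
+ #     step 999     → 5e-4      (peak, end of warmup)
+ #     step 50_500  → 2.75e-4   (cosine midpoint = (lr + min_lr) / 2)
+ #     step 100_000 → 5e-5      (floor = min_lr)
+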
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # CHECKPOINT MANAGER
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ class CheckpointManager:
+     """Rolling checkpoints with auto-resume; keeps the `keep_last` newest step files."""
+
+     def __init__(self, save_dir: str, keep_last: int = 5):
+         self.save_dir = Path(save_dir)
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+         self.keep_last = keep_last
+
+     def save(self, model, optimizer, scaler, step, loss, cfg):
+         path = self.save_dir / f"nord_step_{step:07d}.pt"
+         torch.save({
+             "step": step, "loss": loss,
+             "model_state_dict": model.state_dict(),
+             "optimizer_state_dict": optimizer.state_dict(),
+             "scaler_state_dict": scaler.state_dict(),
+             "config": {k: v for k, v in cfg.__dict__.items()
+                        if not k.startswith("_") and k != "dtype"},
+         }, path)
+
+         latest = self.save_dir / "nord_latest.pt"
+         if latest.exists():
+             latest.unlink()
+         shutil.copy2(path, latest)
+
+         # Clean up old checkpoints
+         ckpts = sorted(self.save_dir.glob("nord_step_*.pt"), key=lambda p: p.stat().st_mtime)
+         for old in ckpts[:max(0, len(ckpts) - self.keep_last)]:
+             old.unlink()
+
+         print(f"   [💾] Saved: {path.name} (loss={loss:.4f})")
+
+     def load(self, model, optimizer, scaler, device) -> int:
+         latest = self.save_dir / "nord_latest.pt"
+         if not latest.exists():
+             ckpts = sorted(self.save_dir.glob("nord_step_*.pt"))
+             latest = ckpts[-1] if ckpts else None
+         if latest is None:
+             return 0
+
+         print(f" [*] Resuming from: {latest.name}")
+         ckpt = torch.load(latest, map_location=device, weights_only=False)
+         model.load_state_dict(ckpt["model_state_dict"])
+         optimizer.load_state_dict(ckpt["optimizer_state_dict"])
+         scaler.load_state_dict(ckpt["scaler_state_dict"])
+         step = ckpt["step"]
+         print(f" [✓] Resumed at step {step:,} (loss={ckpt.get('loss', '?')})")
+         return step
+
+     def save_final(self, model, cfg):
+         """Save the model weights only, for inference (a smaller file)."""
+         path = self.save_dir / "nord_final.pt"
+         torch.save({
+             "model_state_dict": model.state_dict(),
+             "config": {k: v for k, v in cfg.__dict__.items()
+                        if not k.startswith("_") and k != "dtype"},
+         }, path)
+         print(f"   [⭐] Final model: {path}")
+         return path
+
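+ # Consuming nord_final.pt (sketch; this mirrors the loader in chat.py): the
+ # file holds only "model_state_dict" plus a plain-dict "config", so inference
+ # code rebuilds NordConfig from that dict and then loads the weights.
+ #
+ #     ckpt = torch.load("nord_final.pt", map_location="cpu", weights_only=False)
+ #     cfg = NordConfig(device="cpu", dtype=torch.float32,
+ #                      d_model=ckpt["config"]["d_model"],
+ #                      n_layers=ckpt["config"]["n_layers"])
+ #     model = NordModel(cfg)
+ #     model.load_state_dict(ckpt["model_state_dict"])
+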
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # MAIN TRAINING FUNCTION
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ def train(dataset_path: str, model_dir: str):
+     # ── Config ──
+     use_cuda = torch.cuda.is_available()
+     cfg = NordConfig(
+         device="cuda" if use_cuda else "cpu",
+         dtype=torch.float16 if use_cuda else torch.float32,  # AMP only makes sense on GPU
+         d_model=512,
+         n_heads=8,
+         n_layers=6,
+         d_ff=1024,
+         T=8,
+         T_slow=2,
+         persistent_mem=False,  # shuffled batches → no persistent state during training
+         max_seq_len=512,
+         batch_size=4,
+         grad_accum=8,
+         lr=5e-4,
+         max_steps=100_000,
+         save_every=1000,
+         log_every=10,
+     )
+
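+     # Rough budget under this config: each optimizer step consumes
+     # batch_size × grad_accum × max_seq_len = 4 × 8 × 512 = 16,384 token
+     # positions (padding included), so 100k steps ≈ 1.6B token positions total.
+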
+     print()
+     print("═" * 60)
+     print("   PROJECT NORD v3 — Training the SNN model")
+     print("═" * 60)
+     print(f"   GPU: {torch.cuda.get_device_name()}" if torch.cuda.is_available() else "   CPU mode")
+     print(f"   Model: d={cfg.d_model}, layers={cfg.n_layers}, T={cfg.T}+{cfg.T_slow}={cfg.T_total}")
+     print(f"   Effective batch: {cfg.batch_size} × {cfg.grad_accum} = {cfg.batch_size * cfg.grad_accum}")
+     print(f"   Steps: {cfg.max_steps:,}")
+     print(f"   Dataset: {dataset_path}")
+     print(f"   Model → {model_dir}")
+     print()
+
+     # ── Tokenizer ──
+     tokenizer = NordTokenizer(cfg)
+
+     # ── LMDB database (built automatically if it doesn't exist yet) ──
+     db_path = str(Path(dataset_path).with_suffix("")) + "_lmdb"
+     if not Path(db_path).exists():
+         build_lmdb(dataset_path, db_path, tokenizer, cfg.max_seq_len)
+
+     dataset = LMDBDataset(db_path, cfg.max_seq_len)
+     dataloader = DataLoader(
+         dataset, batch_size=cfg.batch_size, shuffle=True,
+         num_workers=2, pin_memory=True, drop_last=True, persistent_workers=True,
+     )
+
+     # ── Model ──
+     # Do NOT call .half(): autocast converts the forward pass to fp16 by itself,
+     # while the parameters stay fp32, which GradScaler needs to work correctly.
+     print(f"\n [*] Building the model...")
+     model = NordModel(cfg).to(cfg.device)
+     print(f"   [✓] {model.count_params()}")
+
+     # ── Optimizer ──
+     optimizer = torch.optim.AdamW(
+         model.parameters(), lr=cfg.lr,
+         weight_decay=cfg.weight_decay, betas=(0.9, 0.95),
+     )
+     scaler = torch.amp.GradScaler("cuda", enabled=(cfg.dtype == torch.float16))
+
+     # ── Checkpoints (auto-resume) ──
+     ckpt_mgr = CheckpointManager(model_dir)
+     start_step = ckpt_mgr.load(model, optimizer, scaler, cfg.device)
+
+     # ── TRAINING ──
+     model.train()
+     data_iter = iter(dataloader)
+     running_loss = 0.0
+     tokens_seen = 0
+     t_start = time.time()
+     step, accum_loss = start_step, 0.0  # defined up front so Ctrl+C can always save
+
+     print(f"\n {'─' * 50}")
+     print(f" Starting from step {start_step:,} | {len(dataset):,} samples in the database")
+     print(f" Ctrl+C = stop (the model will be saved!)")
+     print(f" {'─' * 50}\n")
+
+     try:
+         for step in range(start_step, cfg.max_steps):
+             accum_loss = 0.0
+             stats = {}
+
+             for _ in range(cfg.grad_accum):
+                 try:
+                     input_ids = next(data_iter)
+                 except StopIteration:
+                     data_iter = iter(dataloader)  # epoch boundary: restart the loader
+                     input_ids = next(data_iter)
+
+                 input_ids = input_ids.to(cfg.device, non_blocking=True)
+
+                 with autocast(device_type="cuda", dtype=torch.float16,
+                               enabled=(cfg.dtype == torch.float16)):
+                     logits, stats = model(input_ids)
+
+                     shift_logits = logits[:, :-1, :].contiguous()
+                     shift_labels = input_ids[:, 1:].contiguous()
+
+                     loss = F.cross_entropy(
+                         shift_logits.reshape(-1, cfg.vocab_size),
+                         shift_labels.reshape(-1),
+                         ignore_index=tokenizer.pad_id,
+                     ) / cfg.grad_accum  # scale so accumulated grads match one big batch
+
+                 scaler.scale(loss).backward()
+                 accum_loss += loss.item()
+                 tokens_seen += input_ids.numel()
+
+             # Optimizer step
+             scaler.unscale_(optimizer)
+             grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
+             scaler.step(optimizer)
+             scaler.update()
+             optimizer.zero_grad(set_to_none=True)
+
+             # LR schedule
+             lr = get_lr(step, cfg)
+             for pg in optimizer.param_groups:
+                 pg["lr"] = lr
+
+             running_loss += accum_loss
+
+             # Logging
+             if step % cfg.log_every == 0 and step > start_step:
+                 avg = running_loss / cfg.log_every
+                 elapsed = time.time() - t_start
+                 tps = tokens_seen / elapsed / 1000 if elapsed > 0 else 0
+                 sp = stats.get("sparsity", 0)
+
+                 print(
+                     f" step {step:>7,} │ "
+                     f"loss {avg:.4f} │ "
+                     f"lr {lr:.1e} │ "
+                     f"grad {grad_norm:.1f} │ "
+                     f"sparsity {sp:.0%} │ "
+                     f"{tps:.1f}k tok/s"
+                 )
+                 running_loss = 0.0
+
+             # Save
+             if step > 0 and step % cfg.save_every == 0:
+                 ckpt_mgr.save(model, optimizer, scaler, step, accum_loss, cfg)
+
+     except KeyboardInterrupt:
+         print(f"\n\n [⏸] Stopped at step {step:,}")
+         ckpt_mgr.save(model, optimizer, scaler, step, accum_loss, cfg)
+         print(f"     To resume, just run the script again.")
+
+     # Save the final model for chat
+     ckpt_mgr.save_final(model, cfg)
+
+     print(f"\n {'═' * 50}")
+     print(f" Training finished!")
+     print(f" Model saved to: {model_dir}")
+     print(f" Now run: python chat.py")
+     print(f" {'═' * 50}")
+
+
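+ # Optional smoke test (a sketch, not called anywhere by default): one
+ # forward/backward pass on random ids to confirm the model builds and the
+ # loss is finite before committing to a long run. It assumes NordConfig
+ # accepts the same keyword overrides used elsewhere in this file; the tiny
+ # vocab_size keeps the embedding small enough to run on CPU in seconds.
+ def _smoke_test():
+     cfg = NordConfig(device="cpu", dtype=torch.float32,
+                      vocab_size=1000, persistent_mem=False)
+     model = NordModel(cfg).to(cfg.device)
+     ids = torch.randint(0, cfg.vocab_size, (2, 32))
+     logits, _ = model(ids)
+     loss = F.cross_entropy(logits[:, :-1].reshape(-1, cfg.vocab_size),
+                            ids[:, 1:].reshape(-1))
+     loss.backward()
+     assert torch.isfinite(loss), "loss is not finite"
+     print(f"   [✓] smoke test ok, loss={loss.item():.4f}")
+
+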
+ # ─────────────────────────────────────────────────────────────────────────────
+ # ENTRY POINT
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ def main():
+     print("=" * 60)
+     print("   PROJECT NORD — Training the SNN")
+     print("=" * 60)
+
+     # ── Ask for the dataset path ──
+     default_data = os.path.join("D:", os.sep, "nord_dataset", "train_data.jsonl")
+     print(f"\n Where is the dataset? (a JSONL file)")
+     print(f" (Enter = {default_data})")
+     data_input = input(" Dataset path: ").strip()
+     dataset_path = data_input if data_input else default_data
+
+     if not Path(dataset_path).exists():
+         print(f"\n [✗] File not found: {dataset_path}")
+         print(f"     Run this first: python download_data.py")
+         sys.exit(1)
+
+     # ── Ask where to save the model ──
+     default_model = os.path.join("D:", os.sep, "nord_model")
+     print(f"\n Where should the model be saved?")
+     print(f" (Enter = {default_model})")
+     model_input = input(" Model path: ").strip()
+     model_dir = model_input if model_input else default_model
+
+     # ── Off we go ──
+     train(dataset_path, model_dir)
+
+
+ if __name__ == "__main__":
+     main()