File size: 15,205 Bytes
d831a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
"""
╔══════════════════════════════════════════════════════════════════════════╗
β•‘         PROJECT NORD β€” ΠšΡ€ΠΎΠΊ 3: Π§Π°Ρ‚ Π· модСллю  v3.1                     β•‘
β•‘                                                                        β•‘
β•‘  ΠŸΡ€ΠΎΡΡ‚ΠΎ запусти:                                                       β•‘
β•‘      python chat.py                                                    β•‘
β•‘                                                                        β•‘
β•‘  Π’ΠΎΠ½ΠΎ Π·Π°ΠΏΠΈΡ‚Π°Ρ” Π΄Π΅ Π»Π΅ΠΆΠΈΡ‚ΡŒ модСль Ρ– Π·Π°ΠΏΡƒΡΡ‚ΠΈΡ‚ΡŒ Ρ–Π½Ρ‚Π΅Ρ€Π°ΠΊΡ‚ΠΈΠ²Π½ΠΈΠΉ Ρ‡Π°Ρ‚.           β•‘
β•‘  ΠŸΡ–Π΄Ρ‚Ρ€ΠΈΠΌΡƒΡ” STDP: модСль Π²Ρ‡ΠΈΡ‚ΡŒΡΡ Π½ΠΎΠ²ΠΈΠΌ словам прямо ΠΏΡ–Π΄ час Ρ€ΠΎΠ·ΠΌΠΎΠ²ΠΈ!    β•‘
β•‘  v3.1: Repetition Penalty β€” мСншС ΠΏΠΎΠ²Ρ‚ΠΎΡ€Π΅Π½ΡŒ Ρƒ Π³Π΅Π½Π΅Ρ€Π°Ρ†Ρ–Ρ—                 β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•

ΠŸΠΎΡ‚Ρ€Ρ–Π±Π½ΠΎ:
    pip install torch transformers
"""

from __future__ import annotations

import os
import sys
import time
from pathlib import Path
from collections import Counter

import torch
import torch.nn.functional as F

from nord_core import NordConfig, NordModel


# ─────────────────────────────────────────────────────────────────────────────
# ЗАВАНВАЖЕННЯ ΠœΠžΠ”Π•Π›Π†
# ─────────────────────────────────────────────────────────────────────────────

def load_model(model_dir: str) -> tuple:
    """Π—Π°Π²Π°Π½Ρ‚Π°ΠΆΠΈΡ‚ΠΈ модСль Ρ– Ρ‚ΠΎΠΊΠ΅Π½Ρ–Π·Π°Ρ‚ΠΎΡ€."""
    from transformers import AutoTokenizer

    model_path = Path(model_dir)

    # Π—Π½Π°ΠΉΡ‚ΠΈ Ρ„Π°ΠΉΠ» ΠΌΠΎΠ΄Π΅Π»Ρ–
    candidates = ["nord_final.pt", "nord_latest.pt"]
    ckpt_path = None
    for name in candidates:
        p = model_path / name
        if p.exists():
            ckpt_path = p
            break

    if ckpt_path is None:
        steps = sorted(model_path.glob("nord_step_*.pt"))
        if steps:
            ckpt_path = steps[-1]

    if ckpt_path is None:
        print(f"  [βœ—] НС Π·Π½Π°ΠΉΠ΄Π΅Π½ΠΎ ΠΌΠΎΠ΄Π΅Π»Ρ– Π²: {model_dir}")
        print(f"  Π‘ΠΏΠΎΡ‡Π°Ρ‚ΠΊΡƒ Π½Π°Ρ‚Ρ€Π΅Π½ΡƒΠΉ:  python train_nord.py")
        sys.exit(1)

    print(f"  [*] Π—Π°Π²Π°Π½Ρ‚Π°ΠΆΡƒΡ”ΠΌΠΎ: {ckpt_path.name}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    saved_cfg = ckpt.get("config", {})
    cfg = NordConfig(
        device=device,
        dtype=torch.float16 if device == "cuda" else torch.float32,
        d_model=saved_cfg.get("d_model", 512),
        n_heads=saved_cfg.get("n_heads", 8),
        n_layers=saved_cfg.get("n_layers", 6),
        d_ff=saved_cfg.get("d_ff", 1024),
        T=saved_cfg.get("T", 8),
        T_slow=saved_cfg.get("T_slow", 2),
        max_seq_len=saved_cfg.get("max_seq_len", 512),
        vocab_size=saved_cfg.get("vocab_size", 128_256),
        persistent_mem=False,
    )

    model = NordModel(cfg).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    print(f"  [*] Π—Π°Π²Π°Π½Ρ‚Π°ΠΆΡƒΡ”ΠΌΠΎ Llama-3.2 Ρ‚ΠΎΠΊΠ΅Π½Ρ–Π·Π°Ρ‚ΠΎΡ€...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.tokenizer_id, trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    param_count = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"  [βœ“] МодСль Π·Π°Π²Π°Π½Ρ‚Π°ΠΆΠ΅Π½Π°! ({param_count:.1f}M ΠΏΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ–Π²)")

    return model, tokenizer, cfg


# ─────────────────────────────────────────────────────────────────────────────
# REPETITION PENALTY
# ─────────────────────────────────────────────────────────────────────────────

def apply_repetition_penalty(
    logits: torch.Tensor,
    generated_ids: torch.Tensor,
    penalty: float = 1.3,
    window: int = 50,
) -> torch.Tensor:
    """
    Π—ΠΌΠ΅Π½ΡˆΡƒΡ” ΠΉΠΌΠΎΠ²Ρ–Ρ€Π½Ρ–ΡΡ‚ΡŒ Ρ‚ΠΎΠΊΠ΅Π½Ρ–Π² які Π²ΠΆΠ΅ Π·'явились Π² останніх `window` Ρ‚ΠΎΠΊΠ΅Π½Π°Ρ….
    penalty > 1.0 = Π·ΠΌΠ΅Π½ΡˆΡƒΡ” повторСння (Ρ€Π΅ΠΊΠΎΠΌΠ΅Π½Π΄ΠΎΠ²Π°Π½ΠΎ 1.2-1.5)
    Π§ΠΈΠΌ Π±Ρ–Π»ΡŒΡˆΠ΅ Ρ€Π°Π·Ρ–Π² Ρ‚ΠΎΠΊΠ΅Π½ Π·'явився β€” Ρ‚ΠΈΠΌ ΡΠΈΠ»ΡŒΠ½Ρ–ΡˆΠΈΠΉ penalty (Π΄ΠΎ 5x).
    """
    if penalty <= 1.0:
        return logits

    recent_ids = generated_ids[0, -window:].tolist()
    token_counts = Counter(recent_ids)

    for token_id, count in token_counts.items():
        if token_id < logits.size(-1):
            # ЕкспонСнційний penalty: penalty^min(count, 5)
            effective_penalty = penalty ** min(count, 5)
            if logits[0, token_id] > 0:
                logits[0, token_id] = logits[0, token_id] / effective_penalty
            else:
                logits[0, token_id] = logits[0, token_id] * effective_penalty

    return logits


# ─────────────────────────────────────────────────────────────────────────────
# ГЕНЕРАЦІЯ Π’Π•ΠšΠ‘Π’Π£
# ─────────────────────────────────────────────────────────────────────────────

@torch.no_grad()
def generate(
    model: NordModel,
    tokenizer,
    cfg: NordConfig,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    enable_stdp: bool = True,
    repetition_penalty: float = 1.3,
    rep_window: int = 50,
) -> str:
    """
    АвторСгрСсивна гСнСрація Π· SNN.
    v3.1: + repetition penalty для Ρ€Ρ–Π·Π½ΠΎΠΌΠ°Π½Ρ–Ρ‚Π½Ρ–ΡˆΠΎΠ³ΠΎ тСксту.
    """
    device = cfg.device

    model.reset_state()

    max_prompt_len = max(32, cfg.max_seq_len - max_new_tokens)
    enc = tokenizer(prompt, return_tensors="pt", truncation=True,
                    max_length=max_prompt_len)
    input_ids = enc.input_ids.to(device)
    generated_ids = input_ids.clone()

    for _ in range(max_new_tokens):
        context = generated_ids[:, -cfg.max_seq_len:]

        with torch.amp.autocast("cuda", enabled=(device == "cuda")):
            logits, stats = model(context, enable_stdp=enable_stdp)

        next_logits = logits[:, -1, :].float()

        # ── Repetition Penalty (Π΄ΠΎ temperature!) ──
        next_logits = apply_repetition_penalty(
            next_logits, generated_ids,
            penalty=repetition_penalty,
            window=rep_window,
        )

        if temperature > 0:
            next_logits = next_logits / temperature

        if top_k > 0:
            top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
            threshold = top_k_vals[:, -1].unsqueeze(-1)
            next_logits[next_logits < threshold] = float("-inf")

        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
            cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove_mask = cumprobs - F.softmax(sorted_logits, dim=-1) > top_p
            sorted_logits[remove_mask] = float("-inf")
            next_logits.scatter_(1, sorted_idx, sorted_logits)

        probs = F.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)

        # v3: Reward-modulated STDP
        if enable_stdp:
            loss_proxy = -torch.log(probs.max() + 1e-8).item()
            model.stdp_update(current_loss=loss_proxy)

        if next_token.item() == tokenizer.eos_token_id:
            break

    new_ids = generated_ids[0, input_ids.shape[1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)


# ─────────────────────────────────────────────────────────────────────────────
# Π†ΠΠ’Π•Π ΠΠšΠ’Π˜Π’ΠΠ˜Π™ ЧАВ
# ─────────────────────────────────────────────────────────────────────────────

def chat_loop(model: NordModel, tokenizer, cfg: NordConfig):
    """Π“ΠΎΠ»ΠΎΠ²Π½ΠΈΠΉ Ρ†ΠΈΠΊΠ» Ρ‡Π°Ρ‚Ρƒ."""

    temperature = 0.8
    max_tokens = 200
    stdp_enabled = True
    rep_penalty = 1.3
    rep_window = 50

    print(f"\n  {'─' * 50}")
    print(f"  Пиши повідомлСння Ρ– натискай Enter.")
    print(f"  Команди:")
    print(f"    /quit          β€” Π²ΠΈΠΉΡ‚ΠΈ")
    print(f"    /temp 0.5      β€” Π·ΠΌΡ–Π½ΠΈΡ‚ΠΈ temperature")
    print(f"    /tokens 300    β€” макс. Ρ‚ΠΎΠΊΠ΅Π½Ρ–Π² Ρƒ Π²Ρ–Π΄ΠΏΠΎΠ²Ρ–Π΄Ρ–")
    print(f"    /stdp on|off   β€” STDP навчання ΠΏΡ–Π΄ час Ρ‡Π°Ρ‚Ρƒ")
    print(f"    /rep 1.5       β€” repetition penalty (1.0=Π²ΠΈΠΌΠΊ, 1.2-1.5=Π½ΠΎΡ€ΠΌ)")
    print(f"    /stats         β€” ΠΏΠΎΠΊΠ°Π·Π°Ρ‚ΠΈ спайк-статистику")
    print(f"    /reset         β€” скинути STDP кСш")
    print(f"  {'─' * 50}\n")

    last_stats = {}

    while True:
        try:
            user_input = input("  Π’ΠΈ: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n  Π‘ΡƒΠ²Π°ΠΉ! πŸ‘‹")
            break

        if not user_input:
            continue

        # ── Команди ──
        if user_input.startswith("/"):
            parts = user_input.split()
            cmd = parts[0].lower()

            if cmd == "/quit":
                print("  Π‘ΡƒΠ²Π°ΠΉ! πŸ‘‹")
                break

            elif cmd == "/temp" and len(parts) > 1:
                try:
                    temperature = float(parts[1])
                    print(f"  [βš™] Temperature = {temperature}")
                except ValueError:
                    print(f"  [!] НСвірнС значСння")

            elif cmd == "/tokens" and len(parts) > 1:
                try:
                    max_tokens = int(parts[1])
                    print(f"  [βš™] Max tokens = {max_tokens}")
                except ValueError:
                    print(f"  [!] НСвірнС значСння")

            elif cmd == "/stdp":
                if len(parts) > 1 and parts[1].lower() in ("off", "0", "Π½Ρ–"):
                    stdp_enabled = False
                    print(f"  [βš™] STDP Π²ΠΈΠΌΠΊΠ½Π΅Π½ΠΎ")
                else:
                    stdp_enabled = True
                    print(f"  [βš™] STDP ΡƒΠ²Ρ–ΠΌΠΊΠ½Π΅Π½ΠΎ β€” модСль Π²Ρ‡ΠΈΡ‚ΡŒΡΡ ΠΏΡ–Π΄ час Ρ‡Π°Ρ‚Ρƒ!")

            elif cmd == "/rep" and len(parts) > 1:
                try:
                    rep_penalty = float(parts[1])
                    print(f"  [βš™] Repetition penalty = {rep_penalty}")
                    if rep_penalty > 2.0:
                        print(f"  [!] Π£Π²Π°Π³Π°: значСння > 2.0 ΠΌΠΎΠΆΠ΅ Π·Π»Π°ΠΌΠ°Ρ‚ΠΈ Π³Π΅Π½Π΅Ρ€Π°Ρ†Ρ–ΡŽ")
                except ValueError:
                    print(f"  [!] НСвірнС значСння")

            elif cmd == "/stats":
                if last_stats:
                    print(f"  [πŸ“Š] ΠžΡΡ‚Π°Π½Π½Ρ статистика:")
                    for k, v in last_stats.items():
                        print(f"       {k}: {v:.4f}")
                else:
                    print(f"  [!] Π©Π΅ Π½Π΅ΠΌΠ° статистики β€” напиши Ρ‰ΠΎΡΡŒ спочатку")

            elif cmd == "/reset":
                model._stdp_cache.clear()
                print(f"  [βš™] STDP кСш скинуто")

            else:
                print(f"  [!] НСвідома ΠΊΠΎΠΌΠ°Π½Π΄Π°: {cmd}")

            continue

        # ── ГСнСрація ──
        t0 = time.time()

        response = generate(
            model, tokenizer, cfg,
            prompt=user_input,
            max_new_tokens=max_tokens,
            temperature=temperature,
            enable_stdp=stdp_enabled,
            repetition_penalty=rep_penalty,
            rep_window=rep_window,
        )

        elapsed = time.time() - t0

        print(f"\n  Nord: {response}")

        resp_tokens = len(tokenizer.encode(response, add_special_tokens=False))
        tps = resp_tokens / elapsed if elapsed > 0 else 0
        stdp_tag = " [STDP βœ“]" if stdp_enabled else ""
        rep_tag = f" [REP {rep_penalty}]" if rep_penalty > 1.0 else ""
        print(f"  [{resp_tokens} tok, {elapsed:.1f}s, {tps:.1f} tok/s{stdp_tag}{rep_tag}]\n")

        # Π—Π±Π΅Ρ€Π΅Π³Ρ‚ΠΈ статистику
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=(cfg.device == "cuda")):
            ids = tokenizer(user_input, return_tensors="pt",
                          truncation=True, max_length=cfg.max_seq_len).input_ids.to(cfg.device)
            _, last_stats = model(ids)


# ─────────────────────────────────────────────────────────────────────────────
# ENTRY POINT
# ─────────────────────────────────────────────────────────────────────────────

def main():
    print()
    print("═" * 60)
    print("  ⚑ PROJECT NORD β€” Spiking Neural Network Chat v3.1")
    print("═" * 60)

    default_model = os.path.join("D:", os.sep, "nord_model")
    print(f"\n  Π”Π΅ Π»Π΅ΠΆΠΈΡ‚ΡŒ Π½Π°Π²Ρ‡Π΅Π½Π° модСль?")
    print(f"  (Enter = {default_model})")
    model_input = input("  Шлях: ").strip()
    model_dir = model_input if model_input else default_model

    if not Path(model_dir).exists():
        print(f"\n  [βœ—] Папка Π½Π΅ Π·Π½Π°ΠΉΠ΄Π΅Π½Π°: {model_dir}")
        print(f"  Π‘ΠΏΠΎΡ‡Π°Ρ‚ΠΊΡƒ Π½Π°Ρ‚Ρ€Π΅Π½ΡƒΠΉ:  python train_nord.py")
        sys.exit(1)

    model, tokenizer, cfg = load_model(model_dir)
    chat_loop(model, tokenizer, cfg)


if __name__ == "__main__":
    main()