| | """ |
| | ╔══════════════════════════════════════════════════════════════════════════╗ |
| | ║ PROJECT NORD — Крок 3: Чат з моделлю v3.1 ║ |
| | ║ ║ |
| | ║ Просто запусти: ║ |
| | ║ python chat.py ║ |
| | ║ ║ |
| | ║ Воно запитає де лежить модель і запустить інтерактивний чат. ║ |
| | ║ Підтримує STDP: модель вчиться новим словам прямо під час розмови! ║ |
| | ║ v3.1: Repetition Penalty — менше повторень у генерації ║ |
| | ╚══════════════════════════════════════════════════════════════════════════╝ |
| | |
| | Потрібно: |
| | pip install torch transformers |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import os |
| | import sys |
| | import time |
| | from pathlib import Path |
| | from collections import Counter |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from nord_core import NordConfig, NordModel |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_model(model_dir: str) -> tuple:
    """Load the trained model and its tokenizer from *model_dir*.

    Checkpoint resolution order:
      1. ``nord_final.pt``
      2. ``nord_latest.pt``
      3. the newest ``nord_step_*.pt`` (last in lexicographic sort)

    Exits the process with status 1 when no checkpoint is found.

    Returns:
        ``(model, tokenizer, cfg)`` — the restored NordModel in eval mode,
        the Hugging Face tokenizer, and the reconstructed NordConfig.
    """
    from transformers import AutoTokenizer

    model_path = Path(model_dir)

    # Prefer an explicitly finalized checkpoint over the rolling "latest".
    ckpt_path = None
    for name in ("nord_final.pt", "nord_latest.pt"):
        candidate = model_path / name
        if candidate.exists():
            ckpt_path = candidate
            break

    # Fall back to the most recent step checkpoint, if any.
    if ckpt_path is None:
        steps = sorted(model_path.glob("nord_step_*.pt"))
        if steps:
            ckpt_path = steps[-1]

    if ckpt_path is None:
        print(f" [✗] Не знайдено моделі в: {model_dir}")
        print(" Спочатку натренуй: python train_nord.py")
        sys.exit(1)

    print(f" [*] Завантажуємо: {ckpt_path.name}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Rebuild the config from the saved one, falling back to training defaults
    # for any missing key.
    saved_cfg = ckpt.get("config", {})
    cfg = NordConfig(
        device=device,
        dtype=torch.float16 if device == "cuda" else torch.float32,
        d_model=saved_cfg.get("d_model", 512),
        n_heads=saved_cfg.get("n_heads", 8),
        n_layers=saved_cfg.get("n_layers", 6),
        d_ff=saved_cfg.get("d_ff", 1024),
        T=saved_cfg.get("T", 8),
        T_slow=saved_cfg.get("T_slow", 2),
        max_seq_len=saved_cfg.get("max_seq_len", 512),
        vocab_size=saved_cfg.get("vocab_size", 128_256),
        persistent_mem=False,
    )

    model = NordModel(cfg).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    print(" [*] Завантажуємо Llama-3.2 токенізатор...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.tokenizer_id, trust_remote_code=True,
    )
    # Some tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    param_count = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" [✓] Модель завантажена! ({param_count:.1f}M параметрів)")

    return model, tokenizer, cfg
| |
|
| |
|
| | |
| | |
| | |
| |
|
def apply_repetition_penalty(
    logits: torch.Tensor,
    generated_ids: torch.Tensor,
    penalty: float = 1.3,
    window: int = 50,
) -> torch.Tensor:
    """Push down the logits of tokens seen in the last *window* tokens.

    Positive logits are divided by the penalty factor and negative ones
    multiplied by it, so the adjustment always reduces probability.  The
    factor compounds with the number of occurrences, capped at an exponent
    of 5.  ``logits`` is modified in place and returned; ``penalty <= 1.0``
    disables the penalty entirely (recommended range 1.2–1.5).
    """
    if penalty <= 1.0:
        return logits

    vocab_size = logits.size(-1)
    history = Counter(generated_ids[0, -window:].tolist())

    for tok, occurrences in history.items():
        if tok >= vocab_size:
            continue
        # Repeat offenders get a stronger (compounded) penalty, capped at 5x.
        factor = penalty ** min(occurrences, 5)
        value = logits[0, tok]
        logits[0, tok] = value / factor if value > 0 else value * factor

    return logits
| |
|
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def generate(
    model: NordModel,
    tokenizer,
    cfg: NordConfig,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    enable_stdp: bool = True,
    repetition_penalty: float = 1.3,
    rep_window: int = 50,
) -> str:
    """Autoregressive generation with the SNN.

    v3.1: adds repetition penalty for more varied text.

    Samples up to ``max_new_tokens`` tokens, applying in order: repetition
    penalty, temperature scaling, top-k filtering and nucleus (top-p)
    filtering, then multinomial sampling.  Stops early on the tokenizer's
    EOS token.  Returns only the newly generated text (prompt stripped).
    """
    device = cfg.device

    # Clear any recurrent/spiking state left over from the previous call.
    model.reset_state()

    # Truncate the prompt so max_new_tokens still fit in the context window,
    # but never below 32 prompt tokens.
    max_prompt_len = max(32, cfg.max_seq_len - max_new_tokens)
    enc = tokenizer(prompt, return_tensors="pt", truncation=True,
                    max_length=max_prompt_len)
    input_ids = enc.input_ids.to(device)
    generated_ids = input_ids.clone()

    for _ in range(max_new_tokens):
        # Sliding window: the model only ever sees the last max_seq_len tokens.
        context = generated_ids[:, -cfg.max_seq_len:]

        with torch.amp.autocast("cuda", enabled=(device == "cuda")):
            logits, stats = model(context, enable_stdp=enable_stdp)

        # Sample from the last position's logits; float32 for stable softmax.
        next_logits = logits[:, -1, :].float()

        # Penalize tokens already seen in the recent window (mutates in place).
        next_logits = apply_repetition_penalty(
            next_logits, generated_ids,
            penalty=repetition_penalty,
            window=rep_window,
        )

        # NOTE(review): temperature <= 0 merely skips the scaling — sampling
        # below still happens from the unscaled distribution; it is NOT greedy.
        if temperature > 0:
            next_logits = next_logits / temperature

        # Top-k: everything below the k-th largest logit is masked to -inf.
        if top_k > 0:
            top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
            threshold = top_k_vals[:, -1].unsqueeze(-1)
            next_logits[next_logits < threshold] = float("-inf")

        # Nucleus (top-p): mask tokens whose cumulative probability *before*
        # them already exceeds top_p — the subtraction keeps at least the
        # single most likely token alive.
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
            cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove_mask = cumprobs - F.softmax(sorted_logits, dim=-1) > top_p
            sorted_logits[remove_mask] = float("-inf")
            next_logits.scatter_(1, sorted_idx, sorted_logits)

        probs = F.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)

        # Online STDP step: the model's confidence serves as a loss proxy —
        # a low peak probability yields a high proxy loss.
        if enable_stdp:
            loss_proxy = -torch.log(probs.max() + 1e-8).item()
            model.stdp_update(current_loss=loss_proxy)

        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode only the continuation, excluding the original prompt tokens.
    new_ids = generated_ids[0, input_ids.shape[1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def chat_loop(model: NordModel, tokenizer, cfg: NordConfig):
    """Main chat loop: read input, handle slash-commands, print replies."""

    # Generation settings, adjustable at runtime via the slash-commands below.
    temperature = 0.8
    max_tokens = 200
    stdp_enabled = True
    rep_penalty = 1.3
    rep_window = 50

    print(f"\n {'─' * 50}")
    print(f" Пиши повідомлення і натискай Enter.")
    print(f" Команди:")
    print(f" /quit — вийти")
    print(f" /temp 0.5 — змінити temperature")
    print(f" /tokens 300 — макс. токенів у відповіді")
    print(f" /stdp on|off — STDP навчання під час чату")
    print(f" /rep 1.5 — repetition penalty (1.0=вимк, 1.2-1.5=норм)")
    print(f" /stats — показати спайк-статистику")
    print(f" /reset — скинути STDP кеш")
    print(f" {'─' * 50}\n")

    # Spike statistics from the most recent forward pass, shown by /stats.
    last_stats = {}

    while True:
        try:
            user_input = input(" Ти: ").strip()
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C / Ctrl+D ends the chat gracefully.
            print("\n Бувай! 👋")
            break

        if not user_input:
            continue

        # --- Slash-command handling (commands never trigger generation) ---
        if user_input.startswith("/"):
            parts = user_input.split()
            cmd = parts[0].lower()

            if cmd == "/quit":
                print(" Бувай! 👋")
                break

            elif cmd == "/temp" and len(parts) > 1:
                try:
                    temperature = float(parts[1])
                    print(f" [⚙] Temperature = {temperature}")
                except ValueError:
                    print(f" [!] Невірне значення")

            elif cmd == "/tokens" and len(parts) > 1:
                try:
                    max_tokens = int(parts[1])
                    print(f" [⚙] Max tokens = {max_tokens}")
                except ValueError:
                    print(f" [!] Невірне значення")

            elif cmd == "/stdp":
                # Anything other than an explicit "off"/"0"/"ні" enables STDP.
                if len(parts) > 1 and parts[1].lower() in ("off", "0", "ні"):
                    stdp_enabled = False
                    print(f" [⚙] STDP вимкнено")
                else:
                    stdp_enabled = True
                    print(f" [⚙] STDP увімкнено — модель вчиться під час чату!")

            elif cmd == "/rep" and len(parts) > 1:
                try:
                    rep_penalty = float(parts[1])
                    print(f" [⚙] Repetition penalty = {rep_penalty}")
                    if rep_penalty > 2.0:
                        print(f" [!] Увага: значення > 2.0 може зламати генерацію")
                except ValueError:
                    print(f" [!] Невірне значення")

            elif cmd == "/stats":
                if last_stats:
                    print(f" [📊] Остання статистика:")
                    for k, v in last_stats.items():
                        print(f" {k}: {v:.4f}")
                else:
                    print(f" [!] Ще нема статистики — напиши щось спочатку")

            elif cmd == "/reset":
                # NOTE(review): reaches into a private attribute of NordModel;
                # assumes _stdp_cache is a clearable mapping — confirm in nord_core.
                model._stdp_cache.clear()
                print(f" [⚙] STDP кеш скинуто")

            else:
                print(f" [!] Невідома команда: {cmd}")

            continue

        # --- Normal message: generate a reply and time it ---
        t0 = time.time()

        response = generate(
            model, tokenizer, cfg,
            prompt=user_input,
            max_new_tokens=max_tokens,
            temperature=temperature,
            enable_stdp=stdp_enabled,
            repetition_penalty=rep_penalty,
            rep_window=rep_window,
        )

        elapsed = time.time() - t0

        print(f"\n Nord: {response}")

        # Per-reply metrics: token count, wall time, throughput, active flags.
        resp_tokens = len(tokenizer.encode(response, add_special_tokens=False))
        tps = resp_tokens / elapsed if elapsed > 0 else 0
        stdp_tag = " [STDP ✓]" if stdp_enabled else ""
        rep_tag = f" [REP {rep_penalty}]" if rep_penalty > 1.0 else ""
        print(f" [{resp_tokens} tok, {elapsed:.1f}s, {tps:.1f} tok/s{stdp_tag}{rep_tag}]\n")

        # Extra forward pass over the raw user input, done solely to refresh
        # the spike statistics that /stats reports.
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=(cfg.device == "cuda")):
            ids = tokenizer(user_input, return_tensors="pt",
                            truncation=True, max_length=cfg.max_seq_len).input_ids.to(cfg.device)
            _, last_stats = model(ids)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main():
    """Entry point: ask where the trained model lives, then start the chat."""
    print()
    print("═" * 60)
    print(" ⚡ PROJECT NORD — Spiking Neural Network Chat v3.1")
    print("═" * 60)

    # Windows-style default location; os.sep supplies the drive separator.
    default_model = os.path.join("D:", os.sep, "nord_model")
    print("\n Де лежить навчена модель?")
    print(f" (Enter = {default_model})")
    model_input = input(" Шлях: ").strip()
    model_dir = model_input if model_input else default_model

    if not Path(model_dir).exists():
        print(f"\n [✗] Папка не знайдена: {model_dir}")
        print(" Спочатку натренуй: python train_nord.py")
        sys.exit(1)

    model, tokenizer, cfg = load_model(model_dir)
    chat_loop(model, tokenizer, cfg)
| |
|
| |
|
# Run the interactive chat only when executed as a script, not on import.
if __name__ == "__main__":
    main()