#@title Neuro-Synergy Chat Interface
"""
Interactive chat interface for the Neuro-Synergy Spiking GPT model.

Loads the fine-tuned checkpoint and provides a conversational interface
with real-time generation statistics.
"""

import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from torch.utils.cpp_extension import load_inline

# Let the CUDA caching allocator use expandable segments to reduce fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Try importing spikingjelly, installing it on the fly if missing.
try:
    import spikingjelly
except ImportError:
    print("Installing spikingjelly...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "spikingjelly"])
    import spikingjelly

from spikingjelly.activation_based import neuron, surrogate, functional

# ==========================================
# CONFIGURATION
# ==========================================
CONFIG = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "d_model": 768,
    "n_layers": 18,
    "n_heads": 12,
    "vocab_size": 50304,
    "seq_len": 1024,
    "checkpoint_path": "neuro_synergy_chat.pt",  # Fine-tuned checkpoint
    "max_new_tokens": 200,
    "temperature": 0.7,
    "top_p": 0.9,
}

# ==========================================
# CUDA KERNELS (Same as training)
# ==========================================
cuda_source = """
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

#ifndef Tmax
#define Tmax 1024
#endif

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u,
                               const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    F p = 0, q = 0, o = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, u + k[ii]);
        F A = exp(o - no);
        F B = exp(u + k[ii] - no);
        y[ii] = (A * p + B * v[ii]) / (A * q + B);

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
}
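/*
 * Added commentary (editorial note, not part of the computation): both kernels
 * stream the WKV recurrence
 *     a_t = e^w * a_{t-1} + e^{k_t} * v_t
 *     b_t = e^w * b_{t-1} + e^{k_t}
 *     y_t = (a_{t-1} + e^{u + k_t} * v_t) / (b_{t-1} + e^{u + k_t})
 * where p and q hold a and b rescaled by e^{-o}, with o a running maximum
 * exponent, so every exp() stays in range (a log-sum-exp style trick).
 */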
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u,
                                const F *__restrict__ const _k, const F *__restrict__ const _v,
                                const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu,
                                F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F y[Tmax], z[Tmax], zexp[Tmax];

    F gw = 0, gu = 0;
    F p = 0, q = 0;
    F dpdw = 0, dqdw = 0;
    F o = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, k[ii] + u);
        F A = exp(o - no);
        F B = exp(k[ii] + u - no);

        F num = A * p + B * v[ii];
        F iden = 1 / (A * q + B);

        y[i] = num * iden;
        z[i] = iden;
        zexp[i] = k[ii] + u - no;

        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
        gu += gy[ii] * (v[ii] - y[i]) * B * iden;

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        dpdw = A * (p + dpdw);
        dqdw = A * (q + dqdw);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }

    F gp = 0, gq = 0;
    o = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        F A = gy[ii] * z[i] * exp(zexp[i]);
        F B = exp(k[ii] + o);
        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
        gv[ii] = A + B * gp;

        F no = max(w + o, zexp[i] - k[ii] - u);
        A = exp(w + o - no);
        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
        gp = A * gp + B;
        gq = A * gq - B * y[i];
        o = no;
    }

    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] += gw * _w[_c];
    _gu[_offsetBC] += gu;
}

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy,
                   float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
"""

cpp_source = """
#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy,
                   float *gw, float *gu, float *gk, float *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u,
             torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(),
                 v.data_ptr<float>(), y.data_ptr<float>());
}

void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u,
              torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw,
              torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(),
                  v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(),
                  gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
"""

# Compile the CUDA kernels (load_inline needs ninja to build).
try:
    import ninja
except ImportError:
    import subprocess, sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ninja"])
    import ninja

wkv_cuda = None
if torch.cuda.is_available():
    try:
        print("šŸ”§ Compiling CUDA kernels...")
        wkv_cuda = load_inline(
            name='wkv_cuda_chat',
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            functions=['forward', 'backward'],
            verbose=False,
            extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math',
                               '-O3', '-Xptxas -O3', f'-DTmax={CONFIG["seq_len"]}'],
        )
        print("āœ… CUDA kernels ready")
    except Exception:
        wkv_cuda = None
        print("āš ļø CUDA compilation failed, using PyTorch fallback")


# ==========================================
# MODEL CLASSES (Same as training)
# ==========================================
class WKV_CUDA_Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, u, k, v):
        B, T, C = k.size()
        ctx.save_for_backward(w, u, k, v)
        y = torch.zeros(B, T, C, device=k.device)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y

    @staticmethod
    def backward(ctx, gy):
        w, u, k, v = ctx.saved_tensors
        B, T, C = k.size()
        gw = torch.zeros(B, C, device=k.device)
        gu = torch.zeros(B, C, device=k.device)
        gk = torch.zeros(B, T, C, device=k.device)
        gv = torch.zeros(B, T, C, device=k.device)
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        return gw.sum(0), gu.sum(0), gk, gv


class WKV_PureTorch(nn.Module):
    """Pure-PyTorch fallback implementing the same stabilized WKV recurrence."""

    def __init__(self, d_model):
        super().__init__()

    def forward(self, w, u, k, v):
        B, T, C = k.size()
        aa = torch.zeros(B, C, device=k.device)
        bb = torch.zeros(B, C, device=k.device)
        pp = torch.ones(B, C, device=k.device) * -1e38
        y = torch.zeros(B, T, C, device=k.device)
        for t in range(T):
            kt = k[:, t, :]
            vt = v[:, t, :]
            ww = u + kt
            p = torch.maximum(pp, ww)
            e1 = torch.exp(pp - p)
            e2 = torch.exp(ww - p)
            y[:, t, :] = (e1 * aa + e2 * vt) / (e1 * bb + e2)
            ww = pp + w
            p = torch.maximum(ww, kt)
            e1 = torch.exp(ww - p)
            e2 = torch.exp(kt - p)
            aa = e1 * aa + e2 * vt
            bb = e1 * bb + e2
            pp = p
        return y
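# Optional consistency check (added sketch, not part of the original pipeline):
# compares the compiled CUDA kernel against the pure-PyTorch fallback on random
# inputs. The sizes and tolerance are arbitrary choices; call it manually after
# the kernels compile.
def check_wkv_consistency(B=2, T=32, C=CONFIG["d_model"], atol=1e-3):
    if wkv_cuda is None:
        print("CUDA kernel unavailable; nothing to compare")
        return
    w = -torch.rand(C, device="cuda")  # negative decay, mirroring time_decay's init sign
    u = 0.1 * torch.randn(C, device="cuda")
    k = 0.1 * torch.randn(B, T, C, device="cuda")
    v = torch.randn(B, T, C, device="cuda")
    with torch.no_grad():
        y_cuda = WKV_CUDA_Function.apply(w, u, k, v)
        y_ref = WKV_PureTorch(C)(w, u, k, v)
    max_err = (y_cuda - y_ref).abs().max().item()
    print(f"max |cuda - torch| = {max_err:.2e} ({'OK' if max_err < atol else 'MISMATCH'})")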
class SpikingRWKV(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.time_decay = nn.Parameter(torch.ones(d_model) * -2.0)
        self.time_first = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_k = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_v = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_r = nn.Parameter(torch.ones(d_model) * 0.5)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        self.wkv_torch = WKV_PureTorch(d_model)

    def forward(self, x):
        B, T, C = x.size()
        x_prev = torch.cat([torch.zeros_like(x[:, :1, :]), x[:, :-1, :]], dim=1)
        xk = x * self.time_mix_k + x_prev * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + x_prev * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + x_prev * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        if wkv_cuda is not None:
            rwkv = WKV_CUDA_Function.apply(self.time_decay.float(), self.time_first.float(),
                                           k.float(), v.float())
            rwkv = rwkv.type_as(x)
        else:
            rwkv = self.wkv_torch(self.time_decay, self.time_first, k, v)
        sr = torch.sigmoid(r)
        return self.output(sr * rwkv)


class SpikingMLP(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_r = nn.Parameter(torch.ones(d_model) * 0.5)
        self.key = nn.Linear(d_model, 4 * d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(4 * d_model, d_model, bias=False)

    def forward(self, x):
        x_prev = torch.cat([torch.zeros_like(x[:, :1, :]), x[:, :-1, :]], dim=1)
        xk = x * self.time_mix_k + x_prev * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + x_prev * (1 - self.time_mix_r)
        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)
        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv


class NeuroSynergyBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.att = SpikingRWKV(d_model)
        self.ffn = SpikingMLP(d_model)
        self.bn_att = nn.BatchNorm1d(d_model, momentum=0.1)
        self.lif_att = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=4.0),
                                      detach_reset=True, v_threshold=1.0)
        self.bn_ffn = nn.BatchNorm1d(d_model, momentum=0.1)
        self.lif_ffn = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=4.0),
                                      detach_reset=True, v_threshold=1.0)
        self.dropout = nn.Dropout(0.05)

    def forward(self, x):
        residual = x
        x = self.ln1(x)
        x = self.att(x)
        x = x.transpose(1, 2)
        x = self.bn_att(x)
        x = x.transpose(1, 2)
        att_spikes = self.lif_att(x)
        x = self.dropout(att_spikes)
        x = residual + x

        residual = x
        x = self.ln2(x)
        x = self.ffn(x)
        x = x.transpose(1, 2)
        x = self.bn_ffn(x)
        x = x.transpose(1, 2)
        ffn_spikes = self.lif_ffn(x)
        x = self.dropout(ffn_spikes)
        x = residual + x
        return x, att_spikes, ffn_spikes
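# Optional smoke test (added sketch, not part of the original pipeline): run a
# single block on random features and report the output shape plus the mean
# spike rates. The batch and sequence sizes here are arbitrary.
def check_block_shapes(B=2, T=8):
    block = NeuroSynergyBlock(CONFIG["d_model"]).to(CONFIG["device"]).eval()
    functional.reset_net(block)  # clear LIF membrane state, as the full model does
    x = torch.randn(B, T, CONFIG["d_model"], device=CONFIG["device"])
    with torch.no_grad():
        y, s_att, s_ffn = block(x)
    print(f"out {tuple(y.shape)}, att spike rate {s_att.mean().item():.3f}, "
          f"ffn spike rate {s_ffn.mean().item():.3f}")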
class NeuroSynergyGPT(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.d_model = CONFIG["d_model"]
        self.emb = nn.Embedding(vocab_size, self.d_model)
        self.bn_in = nn.BatchNorm1d(self.d_model, momentum=0.1)
        self.input_lif = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=4.0),
                                        detach_reset=True, v_threshold=1.0)
        self.blocks = nn.ModuleList([NeuroSynergyBlock(self.d_model)
                                     for _ in range(CONFIG["n_layers"])])
        self.ln_out = nn.LayerNorm(self.d_model)
        self.head = nn.Linear(self.d_model, vocab_size, bias=False)

    def forward(self, idx):
        functional.reset_net(self)
        x = self.emb(idx)
        x = x.transpose(1, 2)
        x = self.bn_in(x)
        x = x.transpose(1, 2)
        in_spikes = self.input_lif(x)
        x = in_spikes
        spike_layers = [in_spikes]
        for block in self.blocks:
            x, s_att, s_ffn = block(x)
            spike_layers.extend([s_att, s_ffn])
        x = self.ln_out(x)
        logits = self.head(x)
        return logits, spike_layers


# ==========================================
# GENERATION FUNCTION WITH STATS
# ==========================================
def generate_with_stats(model, tokenizer, prompt, max_new_tokens=200,
                        temperature=0.7, top_p=0.9):
    """
    Generate text with real-time statistics tracking.

    Returns: generated_text, stats_dict
    """
    model.eval()

    # Tokenize prompt
    tokens = tokenizer.encode(prompt)
    tokens = torch.tensor([tokens], dtype=torch.long, device=CONFIG["device"])

    generated_tokens = []
    all_spike_rates = []
    generation_times = []

    start_time = time.time()
    with torch.no_grad():
        for step in range(max_new_tokens):
            step_start = time.time()

            # Forward pass
            logits, spike_layers = model(tokens)

            # Calculate spike rate (last 4 blocks = 8 spike tensors)
            active_spikes = spike_layers[-8:] if len(spike_layers) >= 8 else spike_layers
            rates = [s.mean().item() for s in active_spikes]
            current_spike_rate = sum(rates) / len(rates) if rates else 0.0
            all_spike_rates.append(current_spike_rate)

            # Sample next token
            next_token_logits = logits[0, -1, :] / temperature

            # Top-p (nucleus) sampling
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    0, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).unsqueeze(0)

            tokens = torch.cat([tokens, next_token], dim=1)
            generated_tokens.append(next_token.item())

            step_time = time.time() - step_start
            generation_times.append(step_time)

            # Stop on EOS token
            if next_token.item() == tokenizer.eos_token_id:
                break

    total_time = time.time() - start_time
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Calculate stats
    stats = {
        "total_tokens": len(generated_tokens),
        "total_time": total_time,
        "tokens_per_second": len(generated_tokens) / total_time if total_time > 0 else 0,
        "avg_time_per_token": total_time / len(generated_tokens) if generated_tokens else 0,
        "avg_spike_rate": sum(all_spike_rates) / len(all_spike_rates) if all_spike_rates else 0,
        "min_spike_rate": min(all_spike_rates) if all_spike_rates else 0,
        "max_spike_rate": max(all_spike_rates) if all_spike_rates else 0,
    }
    return generated_text, stats
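# Non-interactive usage example (added sketch, not wired into main()): a single
# generation call with the chat loop's default settings. Assumes `model` and
# `tokenizer` have been constructed and loaded the same way main() does below.
def demo_single_turn(model, tokenizer, prompt="User: Hello!\n\nAssistant:"):
    text, stats = generate_with_stats(
        model, tokenizer, prompt,
        max_new_tokens=CONFIG["max_new_tokens"],
        temperature=CONFIG["temperature"],
        top_p=CONFIG["top_p"],
    )
    print(f"{text}\n[{stats['tokens_per_second']:.2f} tok/s, "
          f"avg spike rate {stats['avg_spike_rate']*100:.1f}%]")
    return text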
# ==========================================
# CHAT INTERFACE
# ==========================================
def print_header():
    """Print welcome header"""
    print("\n" + "=" * 70)
    print("🧠 Neuro-Synergy Chat Interface".center(70))
    print("=" * 70)
    print("šŸ’” Type your message and press Enter")
    print("šŸ“Š Stats will be shown after each response")
    print("āŒ Type 'quit', 'exit', or 'q' to end the conversation")
    print("=" * 70 + "\n")


def print_stats(stats):
    """Print generation statistics with emojis"""
    print("\n" + "─" * 70)
    print("šŸ“Š Generation Statistics:")
    print(f"   ⚔ Tokens Generated: {stats['total_tokens']}")
    print(f"   ā±ļø Total Time: {stats['total_time']:.2f}s")
    print(f"   šŸš€ Speed: {stats['tokens_per_second']:.2f} tokens/sec")
    print(f"   ā³ Avg Time/Token: {stats['avg_time_per_token']*1000:.2f}ms")
    print(f"   šŸ”„ Avg Spike Rate: {stats['avg_spike_rate']*100:.1f}%")
    print(f"   šŸ“‰ Min Spike Rate: {stats['min_spike_rate']*100:.1f}%")
    print(f"   šŸ“ˆ Max Spike Rate: {stats['max_spike_rate']*100:.1f}%")
    print("─" * 70 + "\n")


def main():
    print("šŸš€ Initializing Neuro-Synergy Chat Interface...")

    # Load tokenizer
    print("šŸ“š Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = CONFIG["seq_len"]
    print("āœ… Tokenizer loaded")

    # Load model
    print(f"šŸ¤– Loading model from {CONFIG['checkpoint_path']}...")
    model = NeuroSynergyGPT(CONFIG["vocab_size"]).to(CONFIG["device"])

    if os.path.exists(CONFIG["checkpoint_path"]):
        try:
            checkpoint = torch.load(CONFIG["checkpoint_path"], map_location=CONFIG["device"])

            # Apply weight normalization to the last 4 layers (matching the
            # training script). This must be done BEFORE loading so the module
            # structure matches the checkpoint.
            print("šŸ”’ Applying weight normalization to match checkpoint...")
            for block in model.blocks[-4:]:
                if not hasattr(block.att.output, 'weight_g'):
                    block.att.output = torch.nn.utils.weight_norm(block.att.output)

            model.load_state_dict(checkpoint)
            print("āœ… Model loaded successfully")
        except Exception as e:
            print(f"āŒ Error loading checkpoint: {e}")
            print("šŸ’” Make sure you've run the fine-tuning script first!")
            return
    else:
        print(f"āŒ Checkpoint not found: {CONFIG['checkpoint_path']}")
        print("šŸ’” Please run finetune-neuro-synergy.py first to create the checkpoint")
        return

    model.eval()
    print(f"šŸŽÆ Model ready on {CONFIG['device']}")

    # Chat loop
    print_header()
    conversation_history = []

    while True:
        try:
            # Get user input
            user_input = input("šŸ‘¤ You: ").strip()
            if not user_input:
                continue

            # Check for exit commands
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("\nšŸ‘‹ Goodbye! Thanks for chatting with Neuro-Synergy!")
                break

            # Format prompt
            if conversation_history:
                # Multi-turn conversation
                prompt = "\n\n".join(conversation_history) + f"\n\nUser: {user_input}\n\nAssistant:"
            else:
                # First turn
                prompt = f"User: {user_input}\n\nAssistant:"

            # Generate response
            print("\nšŸ¤” Thinking...")
            response, stats = generate_with_stats(
                model, tokenizer, prompt,
                max_new_tokens=CONFIG["max_new_tokens"],
                temperature=CONFIG["temperature"],
                top_p=CONFIG["top_p"],
            )

            # Print response
            print(f"\nšŸ¤– Assistant: {response}")

            # Print stats
            print_stats(stats)

            # Update conversation history (keep the last 3 exchanges)
            conversation_history.append(f"User: {user_input}")
            conversation_history.append(f"Assistant: {response}")
            if len(conversation_history) > 6:  # Keep last 3 exchanges
                conversation_history = conversation_history[-6:]

        except KeyboardInterrupt:
            print("\n\nšŸ‘‹ Interrupted. Goodbye!")
            break
        except Exception as e:
            print(f"\nāŒ Error: {e}")
            print("šŸ’” Please try again or type 'quit' to exit")


if __name__ == "__main__":
    main()
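# Programmatic use without the interactive loop (added sketch): replicate
# main()'s setup, including the weight-norm wrapping the checkpoint expects,
# then call demo_single_turn() defined above.
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     tokenizer.pad_token = tokenizer.eos_token
#     model = NeuroSynergyGPT(CONFIG["vocab_size"]).to(CONFIG["device"])
#     for block in model.blocks[-4:]:
#         block.att.output = torch.nn.utils.weight_norm(block.att.output)
#     model.load_state_dict(torch.load(CONFIG["checkpoint_path"],
#                                      map_location=CONFIG["device"]))
#     model.eval()
#     demo_single_turn(model, tokenizer)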