#@title Neuro-Synergy Chat Interface
"""
Interactive chat interface for Neuro-Synergy Spiking GPT model.
Loads the fine-tuned checkpoint and provides a conversational interface with real-time stats.
"""
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from torch.utils.cpp_extension import load_inline
# Force expandable segments for the CUDA caching allocator (reduces fragmentation)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Try importing spikingjelly
try:
    import spikingjelly
except ImportError:
    print("Installing spikingjelly...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "spikingjelly"])
    import spikingjelly
from spikingjelly.activation_based import neuron, surrogate, functional
# ==========================================
# CONFIGURATION
# ==========================================
CONFIG = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "d_model": 768,
    "n_layers": 18,
    "n_heads": 12,
    "vocab_size": 50304,
    "seq_len": 1024,
    "checkpoint_path": "neuro_synergy_chat.pt",  # Fine-tuned checkpoint
    "max_new_tokens": 200,
    "temperature": 0.7,
    "top_p": 0.9,
}
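# Back-of-envelope parameter count for this configuration (an added estimate,
# ignoring LayerNorm/BatchNorm, time-mix vectors, and LIF state, which are tiny):
#   embedding + head:  2 * 50304 * 768                              ~  77.3M
#   per block:         att 4 * 768^2 (~2.36M)
#                      + ffn 768*3072 + 3072*768 + 768^2 (~5.31M)   ~  7.67M
#   18 blocks:         18 * ~7.67M                                  ~ 138.0M
#   total:                                                          ~ 215M parameters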
# ==========================================
# CUDA KERNELS (Same as training)
# ==========================================
cuda_source = """
#include <stdio.h>
#include <assert.h>
#define MIN_VALUE (-1e38)
#ifndef Tmax
#define Tmax 1024
#endif
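// Sketch of the math implemented below (one thread per (batch, channel) pair).
// Each kernel keeps a running numerator p, denominator q, and max exponent o,
// so every exp() is evaluated relative to the current max to avoid overflow:
//   output:  no = max(o, u + k_t)
//            y_t = (e^{o-no} p + e^{u+k_t-no} v_t) / (e^{o-no} q + e^{u+k_t-no})
//   update:  no' = max(w + o, k_t)
//            p <- e^{w+o-no'} p + e^{k_t-no'} v_t,  q likewise,  o <- no'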
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u,
                               const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;
    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;
    F p = 0, q = 0, o = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, u + k[ii]);
        F A = exp(o - no);
        F B = exp(u + k[ii] - no);
        y[ii] = (A * p + B * v[ii]) / (A * q + B);
        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
}
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u,
                                const F *__restrict__ const _k, const F *__restrict__ const _v,
                                const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu,
                                F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;
    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;
    F y[Tmax], z[Tmax], zexp[Tmax];
    F gw = 0, gu = 0;
    F p = 0, q = 0;
    F dpdw = 0, dqdw = 0;
    F o = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, k[ii] + u);
        F A = exp(o - no);
        F B = exp(k[ii] + u - no);
        F num = A * p + B * v[ii];
        F iden = 1 / (A * q + B);
        y[i] = num * iden;
        z[i] = iden;
        zexp[i] = k[ii] + u - no;
        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
        gu += gy[ii] * (v[ii] - y[i]) * B * iden;
        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        dpdw = A * (p + dpdw);
        dqdw = A * (q + dqdw);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
    F gp = 0, gq = 0;
    o = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        F A = gy[ii] * z[i] * exp(zexp[i]);
        F B = exp(k[ii] + o);
        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
        gv[ii] = A + B * gp;
        F no = max(w + o, zexp[i] - k[ii] - u);
        A = exp(w + o - no);
        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
        gp = A * gp + B;
        gq = A * gq - B * y[i];
        o = no;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] += gw * _w[_c];
    _gu[_offsetBC] += gu;
}
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
"""
cpp_source = """
#include <torch/extension.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
"""
# Compile CUDA kernels
try:
    import ninja
except ImportError:
    import subprocess, sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ninja"])
    import ninja
wkv_cuda = None
if torch.cuda.is_available():
    try:
        print("🔧 Compiling CUDA kernels...")
        wkv_cuda = load_inline(
            name='wkv_cuda_chat',
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            functions=['forward', 'backward'],
            verbose=False,
            extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={CONFIG["seq_len"]}']
        )
        print("✅ CUDA kernels ready")
    except Exception as e:
        wkv_cuda = None
        print(f"⚠️ CUDA compilation failed ({e}), using PyTorch fallback")
# ==========================================
# MODEL CLASSES (Same as training)
# ==========================================
class WKV_CUDA_Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, u, k, v):
        B, T, C = k.size()
        ctx.save_for_backward(w, u, k, v)
        y = torch.zeros(B, T, C, device=k.device)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y

    @staticmethod
    def backward(ctx, gy):
        w, u, k, v = ctx.saved_tensors
        B, T, C = k.size()
        gw = torch.zeros(B, C, device=k.device)
        gu = torch.zeros(B, C, device=k.device)
        gk = torch.zeros(B, T, C, device=k.device)
        gv = torch.zeros(B, T, C, device=k.device)
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        return gw.sum(0), gu.sum(0), gk, gv
class WKV_PureTorch(nn.Module):
    """Pure-PyTorch fallback for the WKV recurrence (used when CUDA compilation fails)."""
    def __init__(self, d_model):
        super().__init__()  # d_model is unused; kept for interface symmetry

    def forward(self, w, u, k, v):
        B, T, C = k.size()
        aa = torch.zeros(B, C, device=k.device)
        bb = torch.zeros(B, C, device=k.device)
        pp = torch.ones(B, C, device=k.device) * -1e38
        y = torch.zeros(B, T, C, device=k.device)
        for t in range(T):
            kt = k[:, t, :]
            vt = v[:, t, :]
            ww = u + kt
            p = torch.maximum(pp, ww)
            e1 = torch.exp(pp - p)
            e2 = torch.exp(ww - p)
            y[:, t, :] = (e1 * aa + e2 * vt) / (e1 * bb + e2)
            ww = pp + w
            p = torch.maximum(ww, kt)
            e1 = torch.exp(ww - p)
            e2 = torch.exp(kt - p)
            aa = e1 * aa + e2 * vt
            bb = e1 * bb + e2
            pp = p
        return y
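# A minimal sanity-check sketch (a hypothetical helper, not called anywhere in
# this script): it compares the CUDA kernel against the pure-PyTorch fallback on
# random inputs. Assumes the CUDA path compiled, and uses a negative decay, as
# the model's time_decay initialization does.
def _check_wkv_paths_agree(B=2, T=16, C=64, atol=1e-4):
    if wkv_cuda is None:
        return  # no CUDA kernel to compare against
    w = -torch.rand(C, device="cuda")  # negative decay
    u = torch.randn(C, device="cuda")
    k = torch.randn(B, T, C, device="cuda")
    v = torch.randn(B, T, C, device="cuda")
    y_cuda = WKV_CUDA_Function.apply(w, u, k, v)
    y_ref = WKV_PureTorch(C)(w, u, k, v)
    assert torch.allclose(y_cuda, y_ref, atol=atol), "WKV paths diverge"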
class SpikingRWKV(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.time_decay = nn.Parameter(torch.ones(d_model) * -2.0)
        self.time_first = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_k = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_v = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_r = nn.Parameter(torch.ones(d_model) * 0.5)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        self.wkv_torch = WKV_PureTorch(d_model)

    def forward(self, x):
        B, T, C = x.size()
        x_prev = torch.cat([torch.zeros_like(x[:, :1, :]), x[:, :-1, :]], dim=1)
        xk = x * self.time_mix_k + x_prev * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + x_prev * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + x_prev * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        if wkv_cuda is not None:
            rwkv = WKV_CUDA_Function.apply(self.time_decay.float(), self.time_first.float(), k.float(), v.float())
            rwkv = rwkv.type_as(x)
        else:
            rwkv = self.wkv_torch(self.time_decay, self.time_first, k, v)
        sr = torch.sigmoid(r)
        return self.output(sr * rwkv)
class SpikingMLP(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_r = nn.Parameter(torch.ones(d_model) * 0.5)
        self.key = nn.Linear(d_model, 4 * d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(4 * d_model, d_model, bias=False)

    def forward(self, x):
        x_prev = torch.cat([torch.zeros_like(x[:, :1, :]), x[:, :-1, :]], dim=1)
        xk = x * self.time_mix_k + x_prev * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + x_prev * (1 - self.time_mix_r)
        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)
        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv
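# Note: SpikingMLP is RWKV-style channel mixing: k = relu(W_k xk)^2 is projected
# back down through W_v and gated by sigmoid(W_r xr). The squared ReLU follows
# the RWKV design rather than anything spiking-specific.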
class NeuroSynergyBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.att = SpikingRWKV(d_model)
        self.ffn = SpikingMLP(d_model)
        self.bn_att = nn.BatchNorm1d(d_model, momentum=0.1)
        self.lif_att = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=4.0), detach_reset=True, v_threshold=1.0)
        self.bn_ffn = nn.BatchNorm1d(d_model, momentum=0.1)
        self.lif_ffn = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=4.0), detach_reset=True, v_threshold=1.0)
        self.dropout = nn.Dropout(0.05)

    def forward(self, x):
        residual = x
        x = self.ln1(x)
        x = self.att(x)
        x = x.transpose(1, 2)
        x = self.bn_att(x)
        x = x.transpose(1, 2)
        att_spikes = self.lif_att(x)
        x = self.dropout(att_spikes)
        x = residual + x
        residual = x
        x = self.ln2(x)
        x = self.ffn(x)
        x = x.transpose(1, 2)
        x = self.bn_ffn(x)
        x = x.transpose(1, 2)
        ffn_spikes = self.lif_ffn(x)
        x = self.dropout(ffn_spikes)
        x = residual + x
        return x, att_spikes, ffn_spikes
class NeuroSynergyGPT(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.d_model = CONFIG["d_model"]
        self.emb = nn.Embedding(vocab_size, self.d_model)
        self.bn_in = nn.BatchNorm1d(self.d_model, momentum=0.1)
        self.input_lif = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=4.0), detach_reset=True, v_threshold=1.0)
        self.blocks = nn.ModuleList([NeuroSynergyBlock(self.d_model) for _ in range(CONFIG["n_layers"])])
        self.ln_out = nn.LayerNorm(self.d_model)
        self.head = nn.Linear(self.d_model, vocab_size, bias=False)

    def forward(self, idx):
        functional.reset_net(self)
        x = self.emb(idx)
        x = x.transpose(1, 2)
        x = self.bn_in(x)
        x = x.transpose(1, 2)
        in_spikes = self.input_lif(x)
        x = in_spikes
        spike_layers = [in_spikes]
        for block in self.blocks:
            x, s_att, s_ffn = block(x)
            spike_layers.extend([s_att, s_ffn])
        x = self.ln_out(x)
        logits = self.head(x)
        return logits, spike_layers
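# spike_layers holds 1 + 2 * n_layers tensors (input spikes plus attention and
# FFN spikes per block), i.e. 37 for this 18-layer config; generate_with_stats()
# below averages the last 8 of them (the final four blocks).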
# ==========================================
# GENERATION FUNCTION WITH STATS
# ==========================================
def generate_with_stats(model, tokenizer, prompt, max_new_tokens=200, temperature=0.7, top_p=0.9):
    """
    Generate text with real-time statistics tracking.
    Returns: (generated_text, stats_dict)
    """
    model.eval()
    # Tokenize prompt
    tokens = tokenizer.encode(prompt)
    tokens = torch.tensor([tokens], dtype=torch.long, device=CONFIG["device"])
    generated_tokens = []
    all_spike_rates = []
    generation_times = []
    start_time = time.time()
    with torch.no_grad():
        for step in range(max_new_tokens):
            step_start = time.time()
            # Crop the context so T never exceeds the model window (and the
            # CUDA kernel's compiled Tmax)
            if tokens.size(1) > CONFIG["seq_len"]:
                tokens = tokens[:, -CONFIG["seq_len"]:]
            # Forward pass over the full current context
            logits, spike_layers = model(tokens)
            # Calculate spike rate (last 4 blocks = 8 spike tensors)
            active_spikes = spike_layers[-8:] if len(spike_layers) >= 8 else spike_layers
            rates = [s.mean().item() for s in active_spikes]
            current_spike_rate = sum(rates) / len(rates) if rates else 0.0
            all_spike_rates.append(current_spike_rate)
            # Sample next token
            next_token_logits = logits[0, -1, :] / temperature
            # Top-p (nucleus) sampling: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the mask right so at least one candidate always survives
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).unsqueeze(0)
            tokens = torch.cat([tokens, next_token], dim=1)
            generated_tokens.append(next_token.item())
            step_time = time.time() - step_start
            generation_times.append(step_time)
            # Stop on EOS token
            if next_token.item() == tokenizer.eos_token_id:
                break
    total_time = time.time() - start_time
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    # Calculate stats
    stats = {
        "total_tokens": len(generated_tokens),
        "total_time": total_time,
        "tokens_per_second": len(generated_tokens) / total_time if total_time > 0 else 0,
        "avg_time_per_token": total_time / len(generated_tokens) if generated_tokens else 0,
        "avg_spike_rate": sum(all_spike_rates) / len(all_spike_rates) if all_spike_rates else 0,
        "min_spike_rate": min(all_spike_rates) if all_spike_rates else 0,
        "max_spike_rate": max(all_spike_rates) if all_spike_rates else 0,
    }
    return generated_text, stats
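# Example usage sketch (assumes `model` and `tokenizer` are loaded as in main()):
#   text, stats = generate_with_stats(
#       model, tokenizer, "User: Hello!\n\nAssistant:",
#       max_new_tokens=50, temperature=0.7, top_p=0.9,
#   )
#   print(text, stats["tokens_per_second"])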
# ==========================================
# CHAT INTERFACE
# ==========================================
def print_header():
    """Print welcome header"""
    print("\n" + "=" * 70)
    print("🧠 Neuro-Synergy Chat Interface".center(70))
    print("=" * 70)
    print("💡 Type your message and press Enter")
    print("📊 Stats will be shown after each response")
    print("❌ Type 'quit', 'exit', or 'q' to end the conversation")
    print("=" * 70 + "\n")
def print_stats(stats):
    """Print generation statistics with emojis"""
    print("\n" + "─" * 70)
    print("📊 Generation Statistics:")
    print(f" ⚡ Tokens Generated: {stats['total_tokens']}")
    print(f" ⏱️ Total Time: {stats['total_time']:.2f}s")
    print(f" 🚀 Speed: {stats['tokens_per_second']:.2f} tokens/sec")
    print(f" ⏳ Avg Time/Token: {stats['avg_time_per_token']*1000:.2f}ms")
    print(f" 🔥 Avg Spike Rate: {stats['avg_spike_rate']*100:.1f}%")
    print(f" 📉 Min Spike Rate: {stats['min_spike_rate']*100:.1f}%")
    print(f" 📈 Max Spike Rate: {stats['max_spike_rate']*100:.1f}%")
    print("─" * 70 + "\n")
def main():
    print("🚀 Initializing Neuro-Synergy Chat Interface...")
    # Load tokenizer
    print("📚 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = CONFIG["seq_len"]
    print("✅ Tokenizer loaded")
    # Load model
    print(f"🤖 Loading model from {CONFIG['checkpoint_path']}...")
    model = NeuroSynergyGPT(CONFIG["vocab_size"]).to(CONFIG["device"])
    if os.path.exists(CONFIG["checkpoint_path"]):
        try:
            checkpoint = torch.load(CONFIG["checkpoint_path"], map_location=CONFIG["device"])
            # Apply weight normalization to the last 4 blocks (matching the
            # training script). This must happen BEFORE load_state_dict so the
            # parameter names (weight_g / weight_v) match the checkpoint.
            print("🔒 Applying weight normalization to match checkpoint...")
            for block in model.blocks[-4:]:
                if not hasattr(block.att.output, 'weight_g'):
                    block.att.output = torch.nn.utils.weight_norm(block.att.output)
            model.load_state_dict(checkpoint)
            print("✅ Model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading checkpoint: {e}")
            print("💡 Make sure you've run the fine-tuning script first!")
            return
    else:
        print(f"❌ Checkpoint not found: {CONFIG['checkpoint_path']}")
        print("💡 Please run finetune-meuro-synergy.py first to create the checkpoint")
        return
    model.eval()
    print(f"🎯 Model ready on {CONFIG['device']}")
    # Chat loop
    print_header()
    conversation_history = []
    while True:
        try:
            # Get user input
            user_input = input("👤 You: ").strip()
            if not user_input:
                continue
            # Check for exit commands
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("\n👋 Goodbye! Thanks for chatting with Neuro-Synergy!")
                break
            # Format prompt
            if conversation_history:
                # Multi-turn conversation
                prompt = "\n\n".join(conversation_history) + f"\n\nUser: {user_input}\n\nAssistant:"
            else:
                # First turn
                prompt = f"User: {user_input}\n\nAssistant:"
            # Generate response
            print("\n🤔 Thinking...")
            response, stats = generate_with_stats(
                model, tokenizer, prompt,
                max_new_tokens=CONFIG["max_new_tokens"],
                temperature=CONFIG["temperature"],
                top_p=CONFIG["top_p"]
            )
            # Print response
            print(f"\n🤖 Assistant: {response}")
            # Print stats
            print_stats(stats)
            # Update conversation history (keep the last 3 exchanges)
            conversation_history.append(f"User: {user_input}")
            conversation_history.append(f"Assistant: {response}")
            if len(conversation_history) > 6:  # 3 exchanges = 6 entries
                conversation_history = conversation_history[-6:]
        except KeyboardInterrupt:
            print("\n\n👋 Interrupted. Goodbye!")
            break
        except Exception as e:
            print(f"\n❌ Error: {e}")
            print("💡 Please try again or type 'quit' to exit")
if __name__ == "__main__":
    main()