|
|
|
|
|
""" |
|
|
Interactive chat interface for Neuro-Synergy Spiking GPT model. |
|
|
Loads the fine-tuned checkpoint and provides a conversational interface with real-time stats. |
|
|
""" |
|
|
import os |
|
|
import time |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoTokenizer |
|
|
from torch.utils.cpp_extension import load_inline |
|
|
|
|
|
|
|
|
# Let the CUDA caching allocator grow segments instead of fragmenting,
# which helps long interactive sessions with growing context lengths.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
|
|
|
|
# Ensure spikingjelly (spiking-neuron library) is available, installing it
# on first run if necessary.
try:
    import spikingjelly
except ImportError:
    print("Installing spikingjelly...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "spikingjelly"])
    import spikingjelly

# LIF neurons, surrogate-gradient functions, and network state utilities.
from spikingjelly.activation_based import neuron, surrogate, functional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Global runtime/model configuration. The architecture fields (d_model,
# n_layers, n_heads, vocab_size, seq_len) must match the fine-tuned
# checkpoint or load_state_dict will fail.
CONFIG = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "d_model": 768,        # hidden / embedding width
    "n_layers": 18,        # number of NeuroSynergyBlocks
    "n_heads": 12,         # kept for checkpoint compatibility (RWKV mixing is per-channel)
    "vocab_size": 50304,   # padded GPT-2 vocab
    "seq_len": 1024,       # context window; also baked into the CUDA kernel as Tmax
    "checkpoint_path": "neuro_synergy_chat.pt",
    "max_new_tokens": 200, # generation length cap per reply
    "temperature": 0.7,    # sampling temperature
    "top_p": 0.9,          # nucleus-sampling probability mass
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# CUDA implementation of the RWKV "WKV" recurrence (forward + backward),
# compiled at runtime via torch.utils.cpp_extension.load_inline.
# NOTE: kernel_backward keeps per-timestep buffers y/z/zexp of static size
# Tmax (set with -DTmax at compile time), so sequences longer than Tmax
# must never reach the backward pass.
cuda_source = """
#include <stdio.h>
#include <assert.h>
#define MIN_VALUE (-1e38)
#ifndef Tmax
#define Tmax 1024
#endif
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
F *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;
F p = 0, q = 0, o = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, u + k[ii]);
F A = exp(o - no);
F B = exp(u + k[ii] - no);
y[ii] = (A * p + B * v[ii]) / (A * q + B);
no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
}
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const gy = _gy + _offset;
F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;
F y[Tmax], z[Tmax], zexp[Tmax];
F gw = 0, gu = 0;
F p = 0, q = 0;
F dpdw = 0, dqdw = 0;
F o = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, k[ii] + u);
F A = exp(o - no);
F B = exp(k[ii] + u - no);
F num = A * p + B * v[ii];
F iden = 1 / (A * q + B);
y[i] = num * iden;
z[i] = iden;
zexp[i] = k[ii] + u - no;
gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
gu += gy[ii] * (v[ii] - y[i]) * B * iden;
no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
dpdw = A * (p + dpdw);
dqdw = A * (q + dqdw);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
F gp = 0, gq = 0;
o = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
F A = gy[ii] * z[i] * exp(zexp[i]);
F B = exp(k[ii] + o);
gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
gv[ii] = A + B * gp;
F no = max(w + o, zexp[i] - k[ii] - u);
A = exp(w + o - no);
B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
gp = A * gp + B;
gq = A * gq - B * y[i];
o = no;
}
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] += gw * _w[_c];
_gu[_offsetBC] += gu;
}
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
dim3 threadsPerBlock( min(C, 32) );
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
dim3 threadsPerBlock( min(C, 32) );
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
"""

# C++ binding layer: declares the CUDA entry points and exposes `forward` /
# `backward` functions that unwrap torch::Tensor arguments to raw float*.
cpp_source = """
#include <torch/extension.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
"""
|
|
|
|
|
|
|
|
# ninja is required by torch.utils.cpp_extension.load_inline to drive the
# JIT build; install it on the fly if missing.
try:
    import ninja
except ImportError:
    import subprocess, sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ninja"])
    import ninja
|
|
|
|
|
# Compile the fused WKV kernels when a GPU is present. On any failure the
# model transparently falls back to the pure-PyTorch WKV implementation.
wkv_cuda = None
if torch.cuda.is_available():
    try:
        print("π§ Compiling CUDA kernels...")
        wkv_cuda = load_inline(
            name='wkv_cuda_chat',
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            functions=['forward', 'backward'],
            verbose=False,
            # Tmax must cover the longest sequence the backward kernel can see.
            extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={CONFIG["seq_len"]}']
        )
        print("✅ CUDA kernels ready")
    # Was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt
    # and hid the actual compilation error from the user.
    except Exception as e:
        wkv_cuda = None
        print(f"β οΈ CUDA compilation failed, using PyTorch fallback ({e})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WKV_CUDA_Function(torch.autograd.Function):
    """Autograd bridge to the compiled `wkv_cuda` extension.

    Allocates output/gradient buffers here and lets the kernels fill them
    in place. `w` and `u` are per-channel parameters shared across the
    batch, so their gradients are reduced over the batch dimension.
    """

    @staticmethod
    def forward(ctx, w, u, k, v):
        batch, seq, chans = k.size()
        ctx.save_for_backward(w, u, k, v)
        out = torch.zeros(batch, seq, chans, device=k.device)
        wkv_cuda.forward(batch, seq, chans, w, u, k, v, out)
        return out

    @staticmethod
    def backward(ctx, gy):
        w, u, k, v = ctx.saved_tensors
        batch, seq, chans = k.size()
        dev = k.device
        gw = torch.zeros(batch, chans, device=dev)
        gu = torch.zeros(batch, chans, device=dev)
        gk = torch.zeros(batch, seq, chans, device=dev)
        gv = torch.zeros(batch, seq, chans, device=dev)
        wkv_cuda.backward(batch, seq, chans, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        # Reduce the shared per-channel parameter grads over the batch axis.
        return gw.sum(0), gu.sum(0), gk, gv
|
|
|
|
|
class WKV_PureTorch(nn.Module):
    """Pure-PyTorch WKV recurrence, the CPU/no-CUDA fallback.

    Implements the numerically-stabilised (log-sum-exp) form of the RWKV
    attention recurrence one timestep at a time. Matches the CUDA kernel's
    forward output.
    """

    def __init__(self, d_model):
        super().__init__()

    def forward(self, w, u, k, v):
        B, T, C = k.size()
        dev = k.device
        num = torch.zeros(B, C, device=dev)    # running weighted numerator
        den = torch.zeros(B, C, device=dev)    # running weighted denominator
        top = torch.ones(B, C, device=dev) * -1e38  # running max exponent
        out = torch.zeros(B, T, C, device=dev)
        for t in range(T):
            kt = k[:, t, :]
            vt = v[:, t, :]
            # Blend state with the current token's bonus term (u + k_t),
            # renormalising exponents for stability.
            bonus = u + kt
            m = torch.maximum(top, bonus)
            e_state = torch.exp(top - m)
            e_tok = torch.exp(bonus - m)
            out[:, t, :] = (e_state * num + e_tok * vt) / (e_state * den + e_tok)
            # Decay the state by w and absorb the current token into it.
            decayed = top + w
            m = torch.maximum(decayed, kt)
            e_state = torch.exp(decayed - m)
            e_tok = torch.exp(kt - m)
            num = e_state * num + e_tok * vt
            den = e_state * den + e_tok
            top = m
        return out
|
|
|
|
|
class SpikingRWKV(nn.Module):
    """RWKV-style time-mixing layer (attention replacement).

    Token-shifts the input for the key/value/receptance paths, runs the WKV
    recurrence (CUDA kernel when compiled, PyTorch fallback otherwise), and
    gates the result with sigmoid(receptance) before the output projection.
    """

    def __init__(self, d_model):
        super().__init__()
        # Per-channel decay (w) and first-token bonus (u) of the recurrence.
        self.time_decay = nn.Parameter(torch.ones(d_model) * -2.0)
        self.time_first = nn.Parameter(torch.ones(d_model) * 0.5)
        # Token-shift interpolation weights for each projection path.
        self.time_mix_k = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_v = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_r = nn.Parameter(torch.ones(d_model) * 0.5)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        self.wkv_torch = WKV_PureTorch(d_model)

    def forward(self, x):
        # Token shift: each position sees a blend of itself and its predecessor.
        shifted = torch.cat([torch.zeros_like(x[:, :1, :]), x[:, :-1, :]], dim=1)
        xk = x * self.time_mix_k + shifted * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + shifted * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + shifted * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        if wkv_cuda is None:
            wkv = self.wkv_torch(self.time_decay, self.time_first, k, v)
        else:
            # Kernel operates in fp32; cast back to the activation dtype after.
            wkv = WKV_CUDA_Function.apply(
                self.time_decay.float(), self.time_first.float(), k.float(), v.float()
            ).type_as(x)
        return self.output(torch.sigmoid(r) * wkv)
|
|
|
|
|
class SpikingMLP(nn.Module):
    """RWKV-style channel-mixing block.

    Token-shifts the input, expands 4x through a squared-ReLU key MLP, and
    gates the projected result with a sigmoid receptance path.
    """

    def __init__(self, d_model):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(d_model) * 0.5)
        self.time_mix_r = nn.Parameter(torch.ones(d_model) * 0.5)
        self.key = nn.Linear(d_model, 4 * d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(4 * d_model, d_model, bias=False)

    def forward(self, x):
        # Token shift: blend each position with its predecessor (zero-padded).
        shifted = torch.cat((torch.zeros_like(x[:, :1, :]), x[:, :-1, :]), dim=1)
        mixed_k = x * self.time_mix_k + shifted * (1 - self.time_mix_k)
        mixed_r = x * self.time_mix_r + shifted * (1 - self.time_mix_r)
        # Squared ReLU keeps activations non-negative and sparse.
        hidden = torch.square(torch.relu(self.key(mixed_k)))
        projected = self.value(hidden)
        gate = torch.sigmoid(self.receptance(mixed_r))
        return gate * projected
|
|
|
|
|
class NeuroSynergyBlock(nn.Module):
    """One model layer: RWKV time-mixing then MLP channel-mixing, each
    followed by BatchNorm + a LIF spiking neuron, with residual connections.

    forward returns (output, att_spikes, ffn_spikes) so callers can track
    spike activity per sub-layer.
    """

    def __init__(self, d_model):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.att = SpikingRWKV(d_model)
        self.ffn = SpikingMLP(d_model)
        self.bn_att = nn.BatchNorm1d(d_model, momentum=0.1)
        self.lif_att = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=4.0), detach_reset=True, v_threshold=1.0)
        self.bn_ffn = nn.BatchNorm1d(d_model, momentum=0.1)
        self.lif_ffn = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=4.0), detach_reset=True, v_threshold=1.0)
        self.dropout = nn.Dropout(0.05)

    @staticmethod
    def _norm_channels(h, bn):
        # BatchNorm1d normalises over (B, C, T); the block works in (B, T, C).
        return bn(h.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        # Time-mixing sub-layer: LN -> RWKV -> BN -> LIF -> dropout -> residual.
        h = self._norm_channels(self.att(self.ln1(x)), self.bn_att)
        att_spikes = self.lif_att(h)
        x = x + self.dropout(att_spikes)
        # Channel-mixing sub-layer, same shape of pipeline.
        h = self._norm_channels(self.ffn(self.ln2(x)), self.bn_ffn)
        ffn_spikes = self.lif_ffn(h)
        x = x + self.dropout(ffn_spikes)
        return x, att_spikes, ffn_spikes
|
|
|
|
|
class NeuroSynergyGPT(nn.Module):
    """Spiking RWKV-style language model.

    Pipeline: token embedding -> BatchNorm -> input LIF -> stack of
    NeuroSynergyBlocks -> LayerNorm -> tied-free vocab head. forward
    returns (logits, spike_layers) where spike_layers is [input spikes,
    then att/ffn spikes of every block in order].
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.d_model = CONFIG["d_model"]
        self.emb = nn.Embedding(vocab_size, self.d_model)
        self.bn_in = nn.BatchNorm1d(self.d_model, momentum=0.1)
        self.input_lif = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=4.0), detach_reset=True, v_threshold=1.0)
        self.blocks = nn.ModuleList([NeuroSynergyBlock(self.d_model) for _ in range(CONFIG["n_layers"])])
        self.ln_out = nn.LayerNorm(self.d_model)
        self.head = nn.Linear(self.d_model, vocab_size, bias=False)

    def forward(self, idx):
        # Clear all LIF membrane state so every call is an independent pass.
        functional.reset_net(self)
        h = self.emb(idx)
        # BatchNorm1d expects channels in dim 1.
        h = self.bn_in(h.transpose(1, 2)).transpose(1, 2)
        h = self.input_lif(h)
        spike_layers = [h]
        for block in self.blocks:
            h, att_s, ffn_s = block(h)
            spike_layers.extend([att_s, ffn_s])
        logits = self.head(self.ln_out(h))
        return logits, spike_layers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_with_stats(model, tokenizer, prompt, max_new_tokens=200, temperature=0.7, top_p=0.9):
    """
    Generate text autoregressively with real-time statistics tracking.

    Args:
        model: NeuroSynergyGPT-style module returning (logits, spike_layers).
        tokenizer: HF tokenizer providing encode/decode and eos_token_id.
        prompt: conversation text used as the generation context.
        max_new_tokens: cap on the number of tokens to sample.
        temperature: logits divisor (> 0); lower values are greedier.
        top_p: nucleus-sampling mass in (0, 1]; 1.0 disables filtering.

    Returns:
        (generated_text, stats) where stats holds token counts, timing, and
        min/avg/max spike rates observed during generation.
    """
    model.eval()

    tokens = tokenizer.encode(prompt)
    tokens = torch.tensor([tokens], dtype=torch.long, device=CONFIG["device"])

    generated_tokens = []
    all_spike_rates = []
    generation_times = []

    start_time = time.time()

    with torch.no_grad():
        for step in range(max_new_tokens):
            step_start = time.time()

            # FIX: crop the context to the model's window. The original fed the
            # full, ever-growing sequence, exceeding seq_len (and the CUDA
            # kernel's compile-time Tmax) in long conversations.
            if tokens.size(1) > CONFIG["seq_len"]:
                idx_cond = tokens[:, -CONFIG["seq_len"]:]
            else:
                idx_cond = tokens
            logits, spike_layers = model(idx_cond)

            # Sample spike activity from the deepest layers only (cheap and
            # representative of the final blocks).
            active_spikes = spike_layers[-8:] if len(spike_layers) >= 8 else spike_layers
            rates = [s.mean().item() for s in active_spikes]
            current_spike_rate = sum(rates) / len(rates) if rates else 0.0
            all_spike_rates.append(current_spike_rate)

            next_token_logits = logits[0, -1, :] / temperature

            # Nucleus (top-p) filtering: mask the low-probability tail.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).unsqueeze(0)

            tokens = torch.cat([tokens, next_token], dim=1)
            generated_tokens.append(next_token.item())

            generation_times.append(time.time() - step_start)

            # Stop on end-of-sequence (decode below skips special tokens).
            if next_token.item() == tokenizer.eos_token_id:
                break

    total_time = time.time() - start_time
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    stats = {
        "total_tokens": len(generated_tokens),
        "total_time": total_time,
        "tokens_per_second": len(generated_tokens) / total_time if total_time > 0 else 0,
        "avg_time_per_token": total_time / len(generated_tokens) if generated_tokens else 0,
        "avg_spike_rate": sum(all_spike_rates) / len(all_spike_rates) if all_spike_rates else 0,
        "min_spike_rate": min(all_spike_rates) if all_spike_rates else 0,
        "max_spike_rate": max(all_spike_rates) if all_spike_rates else 0,
    }

    return generated_text, stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_header():
    """Print the welcome banner and usage hints for the chat session."""
    bar = "=" * 70
    banner = [
        "",
        bar,
        "π§ Neuro-Synergy Chat Interface".center(70),
        bar,
        "π‘ Type your message and press Enter",
        "π Stats will be shown after each response",
        "β Type 'quit', 'exit', or 'q' to end the conversation",
        bar + "\n",
    ]
    for line in banner:
        print(line)
|
|
|
|
|
def print_stats(stats):
    """Pretty-print the statistics dict produced by generate_with_stats()."""
    rule = "β" * 70
    print("\n" + rule)
    print("π Generation Statistics:")
    rows = [
        ("β‘ Tokens Generated", f"{stats['total_tokens']}"),
        ("β±οΈ Total Time", f"{stats['total_time']:.2f}s"),
        ("π Speed", f"{stats['tokens_per_second']:.2f} tokens/sec"),
        ("β³ Avg Time/Token", f"{stats['avg_time_per_token']*1000:.2f}ms"),
        ("π₯ Avg Spike Rate", f"{stats['avg_spike_rate']*100:.1f}%"),
        ("π Min Spike Rate", f"{stats['min_spike_rate']*100:.1f}%"),
        ("π Max Spike Rate", f"{stats['max_spike_rate']*100:.1f}%"),
    ]
    for label, value in rows:
        print(f" {label}: {value}")
    print(rule + "\n")
|
|
|
|
|
def main():
    """Entry point: load the tokenizer and checkpoint, then run the
    interactive chat loop until the user quits or interrupts."""
    print("π Initializing Neuro-Synergy Chat Interface...")

    print("π Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so padding shares its id.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = CONFIG["seq_len"]
    print("✅ Tokenizer loaded")

    print(f"π€ Loading model from {CONFIG['checkpoint_path']}...")
    model = NeuroSynergyGPT(CONFIG["vocab_size"]).to(CONFIG["device"])

    if os.path.exists(CONFIG["checkpoint_path"]):
        try:
            checkpoint = torch.load(CONFIG["checkpoint_path"], map_location=CONFIG["device"])

            # The fine-tuned checkpoint appears to have been saved with
            # weight-norm applied to the last 4 blocks' output projections
            # (extra weight_g/weight_v keys); mirror that before loading so
            # the state dicts line up. TODO(review): confirm against the
            # fine-tuning script.
            print("π Applying weight normalization to match checkpoint...")
            for block in model.blocks[-4:]:
                if not hasattr(block.att.output, 'weight_g'):
                    block.att.output = torch.nn.utils.weight_norm(block.att.output)

            model.load_state_dict(checkpoint)
            print("✅ Model loaded successfully")
        except Exception as e:
            print(f"β Error loading checkpoint: {e}")
            print("π‘ Make sure you've run the fine-tuning script first!")
            return
    else:
        print(f"β Checkpoint not found: {CONFIG['checkpoint_path']}")
        print("π‘ Please run finetune-meuro-synergy.py first to create the checkpoint")
        return

    model.eval()
    print(f"π― Model ready on {CONFIG['device']}")

    print_header()

    # Alternating "User: ..." / "Assistant: ..." turns, most recent last.
    conversation_history = []

    while True:
        try:
            user_input = input("π€ You: ").strip()

            if not user_input:
                continue

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("\nπ Goodbye! Thanks for chatting with Neuro-Synergy!")
                break

            # Build the prompt from the retained conversation history.
            if conversation_history:
                prompt = "\n\n".join(conversation_history) + f"\n\nUser: {user_input}\n\nAssistant:"
            else:
                prompt = f"User: {user_input}\n\nAssistant:"

            print("\nπ€ Thinking...")
            response, stats = generate_with_stats(
                model, tokenizer, prompt,
                max_new_tokens=CONFIG["max_new_tokens"],
                temperature=CONFIG["temperature"],
                top_p=CONFIG["top_p"]
            )

            print(f"\nπ€ Assistant: {response}")

            print_stats(stats)

            conversation_history.append(f"User: {user_input}")
            conversation_history.append(f"Assistant: {response}")
            # Keep only the last 3 exchanges to bound prompt length.
            if len(conversation_history) > 6:
                conversation_history = conversation_history[-6:]

        except KeyboardInterrupt:
            print("\n\nπ Interrupted. Goodbye!")
            break
        except Exception as e:
            # Keep the chat loop alive on per-turn failures.
            print(f"\nβ Error: {e}")
            print("π‘ Please try again or type 'quit' to exit")
|
|
|
|
|
# Start the chat loop only when run as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|