# crimson / chat_interface.py
# Uploaded by AGofficial (commit 63a9f45, "Upload 15 files").
import os
import queue
import threading
from typing import List

import customtkinter as ctk
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
# Hyperparameters (must match train_gclm_base.py and finetune_gclm_base.py)
# D_MODEL .. USE_GLOBAL_EVERY_N_LAYERS decide parameter shapes and which layers
# own a global convolution, so they must agree with the checkpoint at MODEL_PATH
# or ChatEngine's load_state_dict call will reject it.
D_MODEL = 256  # embedding / channel width used throughout CrimsonBase
N_LAYERS = 4  # number of Block modules stacked in CrimsonBase
MAX_SEQ_LEN = 1024  # positional-embedding table size; CrimsonBase.forward keeps only the last MAX_SEQ_LEN tokens
LOCAL_KERNEL_SIZE = 5  # kernel width of the causal depthwise conv in every Block
GLOBAL_KERNEL_SIZE = 256  # length of the learned per-channel kernel in GlobalConv1D
USE_GLOBAL_EVERY_N_LAYERS = 2  # layer i gets a GlobalConv1D when i % USE_GLOBAL_EVERY_N_LAYERS == 0
FFT_SIZE = 1024  # FFT length for GlobalConv1D's overlap-save; should exceed GLOBAL_KERNEL_SIZE - 1
TOKENIZER_NAME = "gpt2"  # tiktoken encoding name passed to tiktoken.get_encoding
# Paths
VOCAB_MAP_PATH = "vocab_map.pt"  # torch.load'ed dict read for keys id2new, PAD_ID, EOS_ID, used_tokens
MODEL_PATH = "crimson_instruct_8.9M.pt"  # state dict loaded into CrimsonBase if the file exists
# Generation settings
TEMPERATURE = 0.8  # last-position logits are divided by this before filtering
TOP_K = 50  # keep only the TOP_K largest logits; 0 disables top-k filtering
TOP_P = 0.9  # nucleus-sampling probability mass; 1.0 disables top-p filtering
MAX_GEN_LEN = 256  # default cap on newly generated tokens per reply
# --- Model Components (Duplicated for standalone use) ---
class GlobalConv1D(nn.Module):
    """Causal per-channel long convolution evaluated with FFT overlap-save.

    Every one of the ``d_model`` channels owns a learned kernel of length
    ``kernel_size``. ``forward`` takes a channels-first tensor ``(B, C, T)`` and
    returns a tensor of the same shape in which ``out[..., t]`` depends only on
    ``x[..., :t + 1]``.

    Args:
        d_model: number of channels (one kernel per channel).
        kernel_size: maximum kernel length.
        fft_size: FFT length used for each overlap-save segment.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # Small (0.01-scaled) random initialisation of the per-channel kernels.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        """Apply the causal convolution to ``x`` of shape ``(B, C, T)``.

        Raises:
            ValueError: if ``fft_size`` is too small for the effective kernel
                length (previously this made the segment loop spin forever).
        """
        B, C, T = x.shape
        if T == 0:
            # Nothing to convolve; avoids negative padding and an empty torch.cat.
            return x.new_zeros((B, C, 0))

        # Kernel taps beyond T would only ever multiply the zero left-padding.
        K = min(self.kernel_size, T)
        overlap = K - 1
        # Number of valid (non-wrapped) output samples each FFT segment yields.
        block = self.fft_size - overlap
        if block <= 0:
            raise ValueError(
                f"fft_size ({self.fft_size}) must be larger than the effective "
                f"kernel length minus one ({overlap})"
            )

        # Left-pad so the first outputs see zeros instead of wrapped samples.
        x = F.pad(x, (overlap, 0))
        # rfft(n=...) zero-pads both kernel and segments up to fft_size.
        k_f = torch.fft.rfft(self.kernel[:, :K], n=self.fft_size).unsqueeze(0)

        outs = []
        for pos in range(0, T, block):
            seg = x[..., pos:pos + self.fft_size]
            y = torch.fft.irfft(torch.fft.rfft(seg, n=self.fft_size) * k_f, n=self.fft_size)
            # The first `overlap` samples are corrupted by circular wrap-around.
            outs.append(y[..., overlap:overlap + block])
        return torch.cat(outs, dim=-1)[..., :T]
class LocalConv1D(nn.Module):
    """Causal depthwise-separable convolution over a ``(B, C, T)`` tensor.

    A per-channel (depthwise) kernel of width ``k`` runs over a left-padded
    input, so output step t only sees inputs up to t. A ReLU follows, then a
    1x1 pointwise convolution mixes channels.
    """

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        causal_pad = (self.k - 1, 0)  # pad only on the past side
        depthwise_out = self.dw(F.pad(x, causal_pad))
        activated = F.relu(depthwise_out)
        return self.pw(activated)
class Block(nn.Module):
    """Pre-norm residual block: local causal conv, optional global conv, then MLP.

    Input and output are ``(B, T, d_model)``; LayerNorm works on the last axis
    and the convolutions get a channels-first view via ``transpose(1, 2)``.
    """

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        expanded = d_model * 4
        self.ff = nn.Sequential(
            nn.Linear(d_model, expanded),
            nn.GELU(),
            nn.Linear(expanded, d_model),
        )

    @staticmethod
    def _conv_branch(norm, conv, h):
        """Normalise ``h``, run a channels-first conv on it, return channels-last."""
        channels_first = norm(h).transpose(1, 2)
        return conv(channels_first).transpose(1, 2)

    def forward(self, x):
        x = x + self._conv_branch(self.ln1, self.local, x)
        if self.use_global:
            x = x + self._conv_branch(self.ln2, self.global_conv, x)
        return x + self.ff(self.ln3(x))
class CrimsonBase(nn.Module):
    """Convolutional language model returning per-position vocabulary logits.

    Token + learned positional embeddings feed ``N_LAYERS`` Blocks, followed by
    a final LayerNorm and a linear head whose weight is shared with the token
    embedding. Inputs longer than ``MAX_SEQ_LEN`` are cut to their last
    ``MAX_SEQ_LEN`` tokens.
    """

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        stack = []
        for depth in range(N_LAYERS):
            has_global = depth % USE_GLOBAL_EVERY_N_LAYERS == 0
            stack.append(Block(D_MODEL, has_global))
        self.layers = nn.ModuleList(stack)
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Tie the output projection to the input embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x):
        # Keep at most the last MAX_SEQ_LEN tokens (a no-op for shorter inputs).
        window = x[:, -MAX_SEQ_LEN:]
        positions = torch.arange(window.size(1), device=window.device)
        hidden = self.emb(window) + self.pos(positions)
        for block in self.layers:
            hidden = block(hidden)
        return self.head(self.ln(hidden))
# --- Chat Engine ---
class ChatEngine:
    """Load the Crimson model plus its vocabulary map and sample chat replies.

    The model uses a remapped vocabulary: ``id2new`` (read from
    ``VOCAB_MAP_PATH``) maps tiktoken GPT-2 ids to model ids, and ``new2id``
    is its inverse.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        print(f"[INFO] Initializing engine on {self.device}...")
        # Load vocab. NOTE: torch.load unpickles data; only load trusted files.
        self.vocab_data = torch.load(VOCAB_MAP_PATH, map_location="cpu")
        self.id2new = self.vocab_data["id2new"]
        self.new2id = {v: k for k, v in self.id2new.items()}
        self.PAD_ID = self.vocab_data["PAD_ID"]
        self.EOS_ID = self.vocab_data["EOS_ID"]
        # NOTE(review): "+ 3" presumably reserves ids for special tokens
        # (PAD / EOS / one more); confirm against the training script.
        self.vocab_size = len(self.vocab_data["used_tokens"]) + 3
        self.tok = tiktoken.get_encoding(TOKENIZER_NAME)
        # Build model
        self.model = CrimsonBase(self.vocab_size).to(self.device).eval()
        if os.path.exists(MODEL_PATH):
            self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
            print(f"[INFO] Loaded model from {MODEL_PATH}")
        else:
            print(f"[ERROR] {MODEL_PATH} not found. UI will be non-functional.")

    def _sample_next(self, logits):
        """Sample next-token ids from last-position ``logits`` of shape ``(B, V)``.

        Applies TEMPERATURE, then top-k (if TOP_K > 0), then nucleus/top-p
        (if TOP_P < 1.0) filtering. Returns a ``(B, 1)`` long tensor.
        """
        logits = logits / TEMPERATURE
        if TOP_K > 0:
            kth_best = torch.topk(logits, min(TOP_K, logits.size(-1))).values[..., -1:]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))
        if TOP_P < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove_sorted = cumulative_probs > TOP_P
            # Shift right so the token that crosses the threshold is kept,
            # and never remove the single most likely token.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            # Map the mask back to vocabulary order for every batch row.
            remove = torch.zeros_like(remove_sorted).scatter(-1, sorted_indices, remove_sorted)
            logits = logits.masked_fill(remove, float("-inf"))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def _decode(self, new_ids):
        """Map model ids back to GPT-2 ids and decode them to text.

        Ids with no GPT-2 counterpart (e.g. special tokens) are skipped; the
        previous ``.get(i, 0)`` fallback rendered them as GPT-2 id 0 ("!").
        """
        return self.tok.decode([self.new2id[i] for i in new_ids if i in self.new2id])

    @torch.no_grad()
    def generate(self, prompt, max_new_tokens=MAX_GEN_LEN):
        """Sample a reply to ``prompt``, streaming the cumulative decoded text.

        After each sampled token this yields the full reply decoded so far.
        It skips a step while the text ends in U+FFFD, i.e. while a multi-byte
        character is still incomplete. Generation stops at EOS_ID or after
        ``max_new_tokens`` tokens. The final text is always yielded (``""`` if
        nothing was generated), so the last value is the complete reply.
        """
        full_prompt = f"<user> {prompt} <ai> "
        raw_ids = self.tok.encode(full_prompt)
        # Tokens absent from the reduced vocabulary fall back to PAD_ID.
        input_ids = [self.id2new.get(i, self.PAD_ID) for i in raw_ids]
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)
        generated = []
        last_yielded = None
        for _ in range(max_new_tokens):
            # CrimsonBase.forward only reads the last MAX_SEQ_LEN tokens, so
            # there is no point letting the context tensor grow beyond that.
            x = x[:, -MAX_SEQ_LEN:]
            logits = self.model(x)[:, -1, :]
            next_token = self._sample_next(logits)
            token_id = next_token.item()
            if token_id == self.EOS_ID:
                break
            generated.append(token_id)
            x = torch.cat([x, next_token], dim=1)
            text = self._decode(generated)
            if not text.endswith("\ufffd"):
                last_yielded = text
                yield text
        final_text = self._decode(generated)
        if final_text != last_yielded:
            yield final_text
# --- UI ---
class ChatApp(ctk.CTk):
    """Chat window that sends user messages to a ChatEngine and streams replies.

    Generation runs on a daemon worker thread. The worker never touches Tk
    widgets: it posts ``("text", str)`` and finally ``("done", None)`` messages
    to a queue. The Tk main thread drains that queue by polling with ``after``,
    because tkinter widgets should only be driven from the mainloop thread.
    """

    POLL_MS = 30  # interval (ms) at which the main thread drains worker messages

    def __init__(self, engine):
        super().__init__()
        self.engine = engine
        self._events = queue.Queue()  # worker -> UI messages
        self._busy = False  # True while a reply is being generated
        self.title("Crimson Instruct Chat")
        self.geometry("800x600")
        ctk.set_appearance_mode("dark")
        ctk.set_default_color_theme("blue")
        # Layout
        self.grid_rowconfigure(0, weight=1)
        self.grid_columnconfigure(0, weight=1)
        # Chat display (read-only except while _write is inserting)
        self.chat_display = ctk.CTkTextbox(self, state="disabled", font=("Inter", 14))
        self.chat_display.grid(row=0, column=0, padx=20, pady=20, sticky="nsew")
        # Input area
        self.input_frame = ctk.CTkFrame(self)
        self.input_frame.grid(row=1, column=0, padx=20, pady=(0, 20), sticky="ew")
        self.input_frame.grid_columnconfigure(0, weight=1)
        self.user_input = ctk.CTkEntry(self.input_frame, placeholder_text="Type your message here...", font=("Inter", 14))
        self.user_input.grid(row=0, column=0, padx=(10, 5), pady=10, sticky="ew")
        self.user_input.bind("<Return>", lambda e: self.send_message())
        self.send_button = ctk.CTkButton(self.input_frame, text="Send", command=self.send_message, width=100)
        self.send_button.grid(row=0, column=1, padx=(5, 10), pady=10)

    def _write(self, text, tags=None):
        """Append ``text`` to the transcript (main thread only) and scroll to it."""
        # Unlock only for the duration of the insert so the user can never
        # type into the transcript.
        self.chat_display.configure(state="normal")
        self.chat_display.insert("end", text, tags)
        self.chat_display.configure(state="disabled")
        self.chat_display.see("end")

    def append_chat(self, sender, message):
        """Append one complete message, prefixed with its role tag."""
        tag = "<user>" if sender == "You" else "<ai>"
        # NOTE: no tag_config for "bold" appears in this file, so the tag has
        # no visual effect unless it is configured elsewhere.
        self._write(f"{tag} ", "bold")
        self._write(f"{message}\n\n")

    def send_message(self):
        """Post the typed message and start generating a reply in the background."""
        # The <Return> binding still fires while the Send button is disabled,
        # so guard explicitly against a second concurrent generation.
        if self._busy:
            return
        msg = self.user_input.get().strip()
        if not msg:
            return
        self.user_input.delete(0, "end")
        self.append_chat("You", msg)
        self._busy = True
        self.send_button.configure(state="disabled")
        self._write("<ai> ", "bold")
        threading.Thread(target=self.generate_response, args=(msg,), daemon=True).start()
        self.after(self.POLL_MS, self._drain_events)

    def generate_response(self, prompt):
        """Worker-thread body: forward ``engine.generate`` output to the UI queue.

        ``engine.generate`` yields the cumulative reply text, so only the newly
        added suffix is forwarded. A ``("done", None)`` message is always
        posted, even if generation raises, so the UI gets re-enabled.
        """
        last_text = ""
        try:
            for text in self.engine.generate(prompt):
                new_part = text[len(last_text):]
                last_text = text
                if new_part:
                    self._events.put(("text", new_part))
        except Exception as exc:  # thread boundary: surface the failure in the transcript
            print(f"[ERROR] Generation failed: {exc!r}")
            self._events.put(("text", f"[generation error: {exc}]"))
        finally:
            self._events.put(("done", None))

    def _drain_events(self):
        """Main-thread poller: apply queued worker output, reschedule until done."""
        done = False
        while not done:
            try:
                kind, payload = self._events.get_nowait()
            except queue.Empty:
                break
            if kind == "text":
                self._write(payload)
            else:
                done = True
        if done:
            self._write("\n\n")
            self._busy = False
            self.send_button.configure(state="normal")
        else:
            self.after(self.POLL_MS, self._drain_events)
if __name__ == "__main__":
    # Build the engine (vocab + weights) first, then hand it to the window.
    engine = ChatEngine()
    window = ChatApp(engine)
    window.mainloop()