"""Standalone GUI chat client for the Crimson instruct model.

Duplicates the model definition so it can load a checkpoint without the
training scripts, samples replies with top-k / top-p, and streams them into a
customtkinter window.
"""
import os
import queue
import threading
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
import customtkinter as ctk
import tiktoken

# Hyperparameters (must match train_gclm_base.py and finetune_gclm_base.py)
D_MODEL = 256
N_LAYERS = 4
MAX_SEQ_LEN = 1024
LOCAL_KERNEL_SIZE = 5
GLOBAL_KERNEL_SIZE = 256
USE_GLOBAL_EVERY_N_LAYERS = 2
FFT_SIZE = 1024
TOKENIZER_NAME = "gpt2"

# Paths
VOCAB_MAP_PATH = "vocab_map.pt"
MODEL_PATH = "crimson_instruct_8.9M.pt"

# Generation settings
TEMPERATURE = 0.8
TOP_K = 50
TOP_P = 0.9
MAX_GEN_LEN = 256


# --- Model Components (Duplicated for standalone use) ---

class GlobalConv1D(nn.Module):
    """Causal per-channel long convolution evaluated with FFT overlap-save.

    Input and output have shape (B, C, T). Channel c is convolved with
    ``kernel[c, :K]`` where ``K = min(kernel_size, T)``; the input is
    left-padded by ``K - 1`` so output t only sees inputs t-K+1 .. t.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # With K <= fft_size each FFT block yields at least one new output;
        # otherwise the overlap-save loop in forward() would never advance.
        if kernel_size > fft_size:
            raise ValueError(
                f"kernel_size ({kernel_size}) must not exceed fft_size ({fft_size})"
            )
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        B, C, T = x.shape
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap  # valid outputs produced per FFT window
        x = F.pad(x, (overlap, 0))  # causal left padding
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))
        k_f = torch.fft.rfft(k, n=self.fft_size)
        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                n=self.fft_size,
            )
            # The first `overlap` samples are corrupted by circular wrap-around.
            outs.append(y[..., overlap:overlap + block])
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]


class LocalConv1D(nn.Module):
    """Causal depthwise conv (left pad k-1) -> ReLU -> pointwise 1x1 conv on (B, C, T)."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))


class Block(nn.Module):
    """Pre-norm residual block: local conv, optional global conv, then MLP.

    Operates on (B, T, D); the convolutions are applied channel-first.
    """

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))


class CrimsonBase(nn.Module):
    """Token + learned position embeddings, conv blocks, and a tied output head.

    ``forward`` takes (B, T) token ids and returns (B, T', vocab) logits, where
    T' = min(T, MAX_SEQ_LEN) because only the last MAX_SEQ_LEN tokens are kept.
    """

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList([
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        ])
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        self.head.weight = self.emb.weight  # weight tying

    def forward(self, x):
        T = x.size(1)
        if T > MAX_SEQ_LEN:
            x = x[:, -MAX_SEQ_LEN:]
            T = MAX_SEQ_LEN
        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))


# --- Chat Engine ---

class ChatEngine:
    """Loads the vocabulary remap and weights, and streams sampled replies."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        print(f"[INFO] Initializing engine on {self.device}...")

        # Load vocab. The model uses a remapped vocabulary: id2new maps GPT-2
        # token ids to model ids, and new2id is its inverse.
        self.vocab_data = torch.load(VOCAB_MAP_PATH, map_location="cpu")
        self.id2new = self.vocab_data["id2new"]
        self.new2id = {v: k for k, v in self.id2new.items()}
        self.PAD_ID = self.vocab_data["PAD_ID"]
        self.EOS_ID = self.vocab_data["EOS_ID"]
        # NOTE(review): the +3 presumably reserves special ids (PAD/EOS/...);
        # it must equal the vocab size used at training time — confirm.
        self.vocab_size = len(self.vocab_data["used_tokens"]) + 3
        self.tok = tiktoken.get_encoding(TOKENIZER_NAME)

        # Build model
        self.model = CrimsonBase(self.vocab_size).to(self.device).eval()
        if os.path.exists(MODEL_PATH):
            self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
            print(f"[INFO] Loaded model from {MODEL_PATH}")
        else:
            print(f"[ERROR] {MODEL_PATH} not found. UI will be non-functional.")

    @staticmethod
    def _sample(logits, temperature, top_k, top_p):
        """Draw one token per row from (B, V) logits; return a (B, 1) long tensor.

        ``temperature <= 0`` selects greedy (argmax) decoding instead of
        dividing by zero.
        """
        if temperature <= 0:
            return logits.argmax(dim=-1, keepdim=True)
        logits = logits / temperature

        if top_k > 0:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth, float("-inf"))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove_sorted = cumulative > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the most likely token.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            # Map the mask from sorted order back to vocabulary order per row.
            remove = remove_sorted.scatter(1, sorted_indices, remove_sorted)
            logits = logits.masked_fill(remove, float("-inf"))

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def _decode(self, model_ids):
        """Decode model ids to text, dropping ids with no GPT-2 counterpart.

        Previously such ids (e.g. PAD) were decoded as GPT-2 token 0 (``"!"``).
        """
        return self.tok.decode([self.new2id[i] for i in model_ids if i in self.new2id])

    @torch.no_grad()
    def generate(self, prompt, max_new_tokens=MAX_GEN_LEN, *,
                 temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P):
        """Stream a sampled reply to ``prompt``.

        Yields the *cumulative* decoded reply text (not deltas). Sampling
        stops at EOS or after ``max_new_tokens`` tokens. A text ending in
        U+FFFD (a multi-byte character split across tokens) is held back
        until it completes. The final text is always yielded if any tokens
        were produced.
        """
        # Format prompt.
        # NOTE(review): this is only the prompt padded with spaces. If
        # fine-tuning wrapped turns in special markers, they must appear here
        # verbatim — confirm against finetune_gclm_base.py.
        full_prompt = f" {prompt} "
        raw_ids = self.tok.encode(full_prompt)
        input_ids = [self.id2new.get(i, self.PAD_ID) for i in raw_ids]
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        generated: List[int] = []
        text = ""
        held_back = False
        for _ in range(max_new_tokens):
            # The model only uses the last MAX_SEQ_LEN positions; cropping
            # here keeps the context tensor from growing without bound.
            x = x[:, -MAX_SEQ_LEN:]
            logits = self.model(x)[:, -1, :]
            next_token = self._sample(logits, temperature, top_k, top_p)
            token_id = next_token.item()
            if token_id == self.EOS_ID:
                break
            generated.append(token_id)
            x = torch.cat([x, next_token], dim=1)

            # Map back to original IDs and decode.
            text = self._decode(generated)
            held_back = text.endswith("\ufffd")
            if not held_back:
                yield text

        if held_back:
            yield text


# --- UI ---

class ChatApp(ctk.CTk):
    """Chat window. Generation runs on a worker thread. All widget updates
    happen on the Tk main thread by draining ``self._events``."""

    BOT_NAME = "Crimson"
    POLL_MS = 30  # how often the main thread drains worker events

    def __init__(self, engine):
        super().__init__()
        self.engine = engine
        # Worker -> UI messages: ("start" | "text" | "error" | "done", payload)
        self._events = queue.Queue()
        self._busy = False  # True while a generation thread is running

        self.title("Crimson Instruct Chat")
        self.geometry("800x600")
        ctk.set_appearance_mode("dark")
        ctk.set_default_color_theme("blue")

        # Layout
        self.grid_rowconfigure(0, weight=1)
        self.grid_columnconfigure(0, weight=1)

        # Chat display
        self.chat_display = ctk.CTkTextbox(self, state="disabled", font=("Inter", 14))
        self.chat_display.grid(row=0, column=0, padx=20, pady=20, sticky="nsew")

        # Input area
        self.input_frame = ctk.CTkFrame(self)
        self.input_frame.grid(row=1, column=0, padx=20, pady=(0, 20), sticky="ew")
        self.input_frame.grid_columnconfigure(0, weight=1)

        self.user_input = ctk.CTkEntry(
            self.input_frame, placeholder_text="Type your message here...", font=("Inter", 14)
        )
        self.user_input.grid(row=0, column=0, padx=(10, 5), pady=10, sticky="ew")
        # An empty event sequence is rejected by Tk; Enter sends the message.
        self.user_input.bind("<Return>", lambda e: self.send_message())

        self.send_button = ctk.CTkButton(self.input_frame, text="Send", command=self.send_message, width=100)
        self.send_button.grid(row=0, column=1, padx=(5, 10), pady=10)

    def _insert_label(self, sender):
        """Insert the ``"<sender>: "`` prefix; the textbox must already be writable."""
        self.chat_display.insert("end", f"{sender}: ", "bold")

    def append_chat(self, sender, message):
        """Append a complete ``sender: message`` entry to the transcript."""
        self.chat_display.configure(state="normal")
        self._insert_label(sender)
        self.chat_display.insert("end", f"{message}\n\n")
        self.chat_display.configure(state="disabled")
        self.chat_display.see("end")

    def send_message(self):
        """Post the entry text and start a background generation.

        Ignored while a previous reply is still generating. The Enter
        binding would otherwise bypass the disabled Send button.
        """
        if self._busy:
            return
        msg = self.user_input.get().strip()
        if not msg:
            return
        self.user_input.delete(0, "end")
        self.append_chat("You", msg)

        # Start generation in thread
        self._busy = True
        self.send_button.configure(state="disabled")
        threading.Thread(target=self.generate_response, args=(msg,), daemon=True).start()
        self.after(self.POLL_MS, self._drain_events)

    def generate_response(self, prompt):
        """Worker: stream the engine's reply for ``prompt`` into the event queue.

        Runs off the Tk main thread, so it never touches widgets. Always ends
        with a ``"done"`` event, even on failure, so the UI re-enables input.
        """
        self._events.put(("start", None))
        shown = ""
        try:
            for text in self.engine.generate(prompt):
                self._events.put(("text", text[len(shown):]))
                shown = text
        except Exception as exc:  # thread boundary: report instead of dying silently
            print(f"[ERROR] generation failed: {exc!r}")
            self._events.put(("error", str(exc)))
        finally:
            self._events.put(("done", None))

    def _drain_events(self):
        """Main thread: apply queued worker events to the transcript."""
        finished = False
        self.chat_display.configure(state="normal")
        try:
            while True:
                kind, payload = self._events.get_nowait()
                if kind == "start":
                    self._insert_label(self.BOT_NAME)
                elif kind == "text":
                    self.chat_display.insert("end", payload)
                elif kind == "error":
                    self.chat_display.insert("end", f"[error: {payload}]")
                elif kind == "done":
                    self.chat_display.insert("end", "\n\n")
                    finished = True
                    break
        except queue.Empty:
            pass
        self.chat_display.configure(state="disabled")
        self.chat_display.see("end")

        if finished:
            self._busy = False
            self.send_button.configure(state="normal")
        else:
            self.after(self.POLL_MS, self._drain_events)


if __name__ == "__main__":
    eng = ChatEngine()
    app = ChatApp(eng)
    app.mainloop()