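"""Crimson Instruct chat client.

A CustomTkinter chat UI around a small convolutional language model
(causal local depthwise convolutions plus FFT-based global convolutions)
with a remapped GPT-2 (tiktoken) vocabulary and streaming top-k/top-p
sampling.
"""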

import os
import threading

import customtkinter as ctk
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F

# Model hyperparameters
D_MODEL = 256
N_LAYERS = 4
MAX_SEQ_LEN = 1024
LOCAL_KERNEL_SIZE = 5
GLOBAL_KERNEL_SIZE = 256
USE_GLOBAL_EVERY_N_LAYERS = 2
FFT_SIZE = 1024
TOKENIZER_NAME = "gpt2"

# Checkpoint artifacts
VOCAB_MAP_PATH = "vocab_map.pt"
MODEL_PATH = "crimson_instruct_8.9M.pt"

# Sampling defaults
TEMPERATURE = 0.8
TOP_K = 50
TOP_P = 0.9
MAX_GEN_LEN = 256
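
# Note (derived from GlobalConv1D below): FFT_SIZE must exceed
# GLOBAL_KERNEL_SIZE - 1 so each overlap-save chunk yields a
# positive-length block (block = FFT_SIZE - (K - 1)).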


class GlobalConv1D(nn.Module):
    """Causal long convolution computed with FFT overlap-save.

    A learned per-channel kernel of up to `kernel_size` taps is applied
    causally over the sequence; `fft_size` caps the work done per chunk.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        B, C, T = x.shape
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap
        # Left-pad so position 0 only sees (zero) past context.
        x = F.pad(x, (overlap, 0))
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))
        k_f = torch.fft.rfft(k, n=self.fft_size)
        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            # Pointwise product in frequency space is circular convolution;
            # the first `overlap` samples of each chunk are wrap-around
            # artifacts and are discarded below (overlap-save).
            y = torch.fft.irfft(torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0), n=self.fft_size)
            outs.append(y[..., overlap:overlap + block])
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]
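

def _check_global_conv_equivalence():
    """Illustrative sanity check, never called by the app (a hedged sketch):
    the overlap-save path should match a direct causal depthwise conv."""
    m = GlobalConv1D(d_model=8, kernel_size=16, fft_size=64)
    x = torch.randn(2, 8, 100)
    # F.conv1d computes cross-correlation, so flip the kernel to recover
    # true convolution, and left-pad by K - 1 for causality.
    w = m.kernel.detach().flip(-1).unsqueeze(1)  # (C, 1, K)
    ref = F.conv1d(F.pad(x, (m.kernel_size - 1, 0)), w, groups=8)
    assert torch.allclose(m(x), ref, atol=1e-4)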


class LocalConv1D(nn.Module):
    """Causal depthwise-separable convolution over a short local window."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # Left-pad by k - 1 so each position only sees past tokens.
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))


class Block(nn.Module):
    """Pre-norm residual block: causal local conv, optional FFT global
    conv, then a GELU feed-forward."""

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # Convs expect (B, C, T); the residual stream stays (B, T, C).
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))


class CrimsonBase(nn.Module):
    """Token + learned positional embeddings, N_LAYERS conv blocks, and a
    weight-tied output head."""

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList([
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        ])
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x):
        T = x.size(1)
        if T > MAX_SEQ_LEN:
            # Truncate to the positional table; keep the most recent tokens.
            x = x[:, -MAX_SEQ_LEN:]
            T = MAX_SEQ_LEN

        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))
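

# Illustrative check (an assumption based only on the checkpoint's name):
# with the tied embedding counted once, the total should land near 8.9M:
#
#   sum(p.numel() for p in CrimsonBase(vocab_size).parameters())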


class ChatEngine:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        print(f"[INFO] Initializing engine on {self.device}...")

        # The vocab map remaps GPT-2 token ids into a compact id space
        # covering only the tokens used in training.
        self.vocab_data = torch.load(VOCAB_MAP_PATH, map_location="cpu")
        self.id2new = self.vocab_data["id2new"]
        self.new2id = {v: k for k, v in self.id2new.items()}
        self.PAD_ID = self.vocab_data["PAD_ID"]
        self.EOS_ID = self.vocab_data["EOS_ID"]
        self.vocab_size = len(self.vocab_data["used_tokens"]) + 3  # +3 reserves ids for specials (PAD, EOS, ...)

        self.tok = tiktoken.get_encoding(TOKENIZER_NAME)

        self.model = CrimsonBase(self.vocab_size).to(self.device).eval()
        if os.path.exists(MODEL_PATH):
            self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
            print(f"[INFO] Loaded model from {MODEL_PATH}")
        else:
            print(f"[ERROR] {MODEL_PATH} not found. UI will be non-functional.")

    @torch.no_grad()
    def generate(self, prompt, max_new_tokens=MAX_GEN_LEN):
        """Stream the completion, yielding the full decoded text so far
        after each sampled token."""
        full_prompt = f"<user> {prompt} <ai> "
        raw_ids = self.tok.encode(full_prompt)
        # Map GPT-2 ids into the model's compact vocabulary.
        input_ids = [self.id2new.get(i, self.PAD_ID) for i in raw_ids]
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        generated = []
        for _ in range(max_new_tokens):
            logits = self.model(x)
            logits = logits[:, -1, :] / TEMPERATURE

            # Top-k: mask everything below the k-th largest logit.
            if TOP_K > 0:
                v, _ = torch.topk(logits, min(TOP_K, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

            # Top-p (nucleus): drop the tail once cumulative probability
            # exceeds TOP_P, always keeping the most likely token.
            if TOP_P < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > TOP_P
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[0, indices_to_remove] = -float("Inf")

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            if next_token.item() == self.EOS_ID:
                break

            generated.append(next_token.item())
            x = torch.cat([x, next_token], dim=1)

            # Map back to GPT-2 ids and yield the running decode so the
            # UI can stream partial output.
            current_ids = [self.new2id.get(i, 0) for i in generated]
            yield self.tok.decode(current_ids)
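
    # Headless usage sketch (illustrative; assumes the checkpoint and vocab
    # map exist next to this script):
    #
    #   eng = ChatEngine()
    #   for partial in eng.generate("Hello!"):
    #       print("\r" + partial, end="")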


class ChatApp(ctk.CTk):
    def __init__(self, engine):
        super().__init__()
        self.engine = engine
        self.title("Crimson Instruct Chat")
        self.geometry("800x600")

        ctk.set_appearance_mode("dark")
        ctk.set_default_color_theme("blue")

        self.grid_rowconfigure(0, weight=1)
        self.grid_columnconfigure(0, weight=1)

        self.chat_display = ctk.CTkTextbox(self, state="disabled", font=("Inter", 14))
        self.chat_display.grid(row=0, column=0, padx=20, pady=20, sticky="nsew")

        self.input_frame = ctk.CTkFrame(self)
        self.input_frame.grid(row=1, column=0, padx=20, pady=(0, 20), sticky="ew")
        self.input_frame.grid_columnconfigure(0, weight=1)

        self.user_input = ctk.CTkEntry(self.input_frame, placeholder_text="Type your message here...", font=("Inter", 14))
        self.user_input.grid(row=0, column=0, padx=(10, 5), pady=10, sticky="ew")
        self.user_input.bind("<Return>", lambda e: self.send_message())

        self.send_button = ctk.CTkButton(self.input_frame, text="Send", command=self.send_message, width=100)
        self.send_button.grid(row=0, column=1, padx=(5, 10), pady=10)

    def append_chat(self, sender, message):
        self.chat_display.configure(state="normal")
        tag = "<user>" if sender == "You" else "<ai>"
        self.chat_display.insert("end", f"{tag} ", "bold")
        self.chat_display.insert("end", f"{message}\n\n")
        self.chat_display.configure(state="disabled")
        self.chat_display.see("end")

    def send_message(self):
        msg = self.user_input.get().strip()
        if not msg:
            return

        self.user_input.delete(0, "end")
        self.append_chat("You", msg)

        # Run generation off the main thread so the UI stays responsive.
        self.send_button.configure(state="disabled")
        threading.Thread(target=self.generate_response, args=(msg,), daemon=True).start()

    def generate_response(self, prompt):
        # This runs on a worker thread. Tkinter widgets are not thread-safe,
        # so every UI update is marshalled onto the main event loop via
        # `after` instead of touching the widgets directly.
        def on_ui(fn, *args, **kwargs):
            self.after(0, lambda: fn(*args, **kwargs))

        on_ui(self.chat_display.configure, state="normal")
        on_ui(self.chat_display.insert, "end", "<ai> ", "bold")

        last_text = ""
        for text in self.engine.generate(prompt):
            # The engine yields the full text so far; insert only the delta.
            new_part = text[len(last_text):]
            last_text = text
            if new_part:
                on_ui(self.chat_display.insert, "end", new_part)
                on_ui(self.chat_display.see, "end")

        on_ui(self.chat_display.insert, "end", "\n\n")
        on_ui(self.chat_display.configure, state="disabled")
        on_ui(self.send_button.configure, state="normal")


if __name__ == "__main__":
    eng = ChatEngine()
    app = ChatApp(eng)
    app.mainloop()