import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken

# ----------------- CONFIG -----------------
MODEL_PATH = "chatgclm_base_2.9M.pt"
VOCAB_PATH = "vocab_map.pt"
TOKENIZER_NAME = "gpt2"  # Defined in training script

D_MODEL = 256
N_LAYERS = 4
MAX_SEQ_LEN = 1024
LOCAL_KERNEL_SIZE = 5
GLOBAL_KERNEL_SIZE = 256
USE_GLOBAL_EVERY_N_LAYERS = 2
FFT_SIZE = 1024

PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3  # Raw tokenizer IDs are shifted by this to make room for the special tokens above.
# ------------------------------------------


# ----------------- MODEL DEF -----------------
class GlobalConv1D(nn.Module):
    """Causal long convolution computed blockwise in the frequency domain (overlap-save)."""

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        # x: (B, C, T)
        B, C, T = x.shape
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap

        # Left-pad so the convolution is causal (no future positions leak in).
        x = F.pad(x, (overlap, 0))

        # Zero-pad the kernel to the FFT length and transform it once.
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))
        k_f = torch.fft.rfft(k, n=self.fft_size)

        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            # Pointwise multiply in the frequency domain = circular convolution.
            # The first `overlap` samples of each block are wrap-around
            # artifacts and are discarded (overlap-save).
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                n=self.fft_size,
            )
            outs.append(y[..., overlap:overlap + block])
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]


class LocalConv1D(nn.Module):
    """Causal depthwise-separable convolution for short-range mixing."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # Left-pad by k-1 so each position only sees the past.
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))


class Block(nn.Module):
    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # Convolutions expect (B, C, T); the residual stream stays (B, T, C).
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))


class GCLM(nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList([
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        ])
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying
        self.head.weight = self.emb.weight

    def forward(self, x):
        T = x.size(1)
        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))


# ----------------- UTILS -----------------
def load_model_and_vocab(device):
    if not os.path.exists(VOCAB_PATH):
        print(f"[ERROR] Vocab file not found: {VOCAB_PATH}")
        return None, None, None

    vocab_data = torch.load(VOCAB_PATH, map_location="cpu")
    used_tokens = vocab_data["used_tokens"]  # compact model ID -> raw tokenizer ID
    id2new = vocab_data["id2new"]            # raw tokenizer ID -> compact model ID
    vocab_size = len(used_tokens) + OFFSET
    print(f"[INFO] Vocab loaded. Size: {vocab_size}")

    model = GCLM(vocab_size).to(device)
    if os.path.exists(MODEL_PATH):
        print(f"[INFO] Loading model from {MODEL_PATH}...")
        state_dict = torch.load(MODEL_PATH, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
    else:
        print(f"[ERROR] Model file not found: {MODEL_PATH}")
        return None, None, None

    return model, used_tokens, id2new


@torch.no_grad()
def generate(model, prompt, tokenizer, id2new, used_tokens, device,
             max_new_tokens=200, temperature=0.8, top_k=50):
    model.eval()

    # Encode the prompt and map raw tokenizer IDs into the compact model vocab.
    raw_ids = tokenizer.encode(prompt)
    input_ids = []
    for rid in raw_ids:
        if rid in id2new:
            input_ids.append(id2new[rid])
        else:
            # Skip tokens the reduced vocab never saw during training.
            continue

    if not input_ids:
        print("[WARN] No known tokens in prompt.")
        input_ids = [PAD_ID]  # Should not happen ideally

    x = torch.tensor([input_ids], dtype=torch.long, device=device)
    generated = []

    for _ in range(max_new_tokens):
        # Crop the context to the maximum sequence length.
        ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x

        logits = model(ctx)
        next_token_logits = logits[:, -1, :] / temperature

        # Optional: top-k sampling
        if top_k is not None:
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float("Inf")

        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()

        if idx == EOS_ID:
            break

        x = torch.cat((x, next_token), dim=1)
        generated.append(idx)

    # Decode result
    return decoder(generated, used_tokens, tokenizer)


def decoder(ids, used_tokens, tokenizer):
    # Map compact model IDs back to raw tokenizer IDs; special tokens are dropped.
    raw_ids = []
    for i in ids:
        if i >= OFFSET:
            raw_ids.append(used_tokens[i - OFFSET])
    return tokenizer.decode(raw_ids)


# ----------------- MAIN -----------------
if __name__ == "__main__":
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model, used_tokens, id2new = load_model_and_vocab(device)
    enc = tiktoken.get_encoding(TOKENIZER_NAME)

    if model:
        # Find a good starting token ID (e.g., newline or space).
        newline_id = id2new.get(enc.encode("\n")[0], OFFSET)

        while True:
            print("\n--- Generating Sample (Temp=0.8, TopK=50) ---")
            print("-" * 20)

            x = torch.tensor([[newline_id]], dtype=torch.long, device=device)
            generated = []

            with torch.no_grad():
                for _ in range(500):
                    ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x

                    logits = model(ctx)
                    logits = logits[:, -1, :] / 0.8  # Temperature

                    # Top-k
                    v, _ = torch.topk(logits, min(50, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    idx = next_token.item()

                    x = torch.cat((x, next_token), dim=1)
                    generated.append(idx)

                    if idx == EOS_ID:
                        print("[EOS]", end="", flush=True)
                        break

                    # Stream decoded tokens as they are sampled.
                    if idx >= OFFSET:
                        raw_id = used_tokens[idx - OFFSET]
                        token_text = enc.decode([raw_id])
                        print(token_text, end="", flush=True)
                    elif idx == PAD_ID:
                        print("[PAD]", end="", flush=True)
                    elif idx == SEP_ID:
                        print("[SEP]", end="", flush=True)

            print("\n" + "-" * 20)
            cont = input("\nPress [Enter] to generate again, or type 'exit': ")
            if cont.lower() == "exit":
                break
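# ----------------- USAGE NOTE -----------------
# The interactive loop above samples unconditionally from a newline token and
# never exercises the `generate` helper. A minimal sketch of prompted
# generation using it (the prompt string and max_new_tokens value are
# illustrative, not from the training setup):
#
#     model, used_tokens, id2new = load_model_and_vocab("cpu")
#     enc = tiktoken.get_encoding(TOKENIZER_NAME)
#     if model:
#         text = generate(model, "Once upon a time", enc, id2new,
#                         used_tokens, "cpu", max_new_tokens=100)
#         print(text)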