# ChatGCLM-Open / sample.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
# ----------------- CONFIG -----------------
MODEL_PATH = "chatgclm_base_2.9M.pt"
VOCAB_PATH = "vocab_map.pt"
TOKENIZER_NAME = "gpt2"
# Architecture hyperparameters; these must match the training script.
D_MODEL = 256
N_LAYERS = 4
MAX_SEQ_LEN = 1024
LOCAL_KERNEL_SIZE = 5
GLOBAL_KERNEL_SIZE = 256
USE_GLOBAL_EVERY_N_LAYERS = 2
FFT_SIZE = 1024
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3
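# The sampler uses a compact vocabulary: IDs 0-2 are the special tokens above,
# and every other ID maps to a raw GPT-2 token via `used_tokens` (model ID -
# OFFSET -> GPT-2 ID) and `id2new` (GPT-2 ID -> model ID), loaded below.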
# ------------------------------------------
# ----------------- MODEL DEF -----------------
class GlobalConv1D(nn.Module):
    """Long-range causal convolution computed blockwise in the frequency
    domain (overlap-save), with one learned kernel per channel."""
    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # Small init keeps the long convolution near-zero early in training.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        B, C, T = x.shape
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap
        # Left-pad so position t only sees x[..., :t+1] (causal).
        x = F.pad(x, (overlap, 0))
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))
        k_f = torch.fft.rfft(k, n=self.fft_size)
        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            # Pointwise multiply in frequency space = circular convolution;
            # the first `overlap` samples wrap around and are discarded.
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                n=self.fft_size
            )
            outs.append(y[..., overlap:overlap + block])
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]
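# Illustrative only (not part of the original script): the blockwise FFT path
# should match a direct causal depthwise convolution. `F.conv1d` computes
# cross-correlation, so the kernel is flipped to express true convolution.
def _check_overlap_save():
    torch.manual_seed(0)
    conv = GlobalConv1D(d_model=8, kernel_size=16, fft_size=64)
    x = torch.randn(2, 8, 100)
    y_fft = conv(x)
    w = conv.kernel.flip(-1).unsqueeze(1)                 # (C, 1, K), one filter per channel
    y_direct = F.conv1d(F.pad(x, (15, 0)), w, groups=8)   # left-pad K-1 for causality
    assert torch.allclose(y_fft, y_direct, atol=1e-4)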
class LocalConv1D(nn.Module):
    """Causal depthwise-separable convolution for short-range token mixing."""
    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)  # depthwise
        self.pw = nn.Conv1d(d_model, d_model, 1)                  # pointwise
    def forward(self, x):
        # Left-pad by k-1 so each position only sees the past.
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))
class Block(nn.Module):
    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
    def forward(self, x):
        # Pre-norm residual sublayers; the convs expect (B, C, T), so the
        # (B, T, C) activations are transposed around each conv.
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))
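# With N_LAYERS = 4 and USE_GLOBAL_EVERY_N_LAYERS = 2, layers 0 and 2 include
# the global FFT convolution (GLOBAL_KERNEL_SIZE = 256 tokens of reach), while
# layers 1 and 3 mix only a LOCAL_KERNEL_SIZE = 5 token window; every layer
# keeps the feed-forward sublayer.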
class GCLM(nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList([
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        ])
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: the output head shares the embedding matrix.
        self.head.weight = self.emb.weight
    def forward(self, x):
        T = x.size(1)
        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))
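# Illustrative (not part of the original script): the "2.9M" in MODEL_PATH can
# be sanity-checked once the vocab size is known. Weight tying means the shared
# embedding/output matrix is counted only once by `parameters()`:
#   model = GCLM(vocab_size)
#   print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params")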
# ----------------- UTILS -----------------
def load_model_and_vocab(device):
    if not os.path.exists(VOCAB_PATH):
        print(f"[ERROR] Vocab file not found: {VOCAB_PATH}")
        return None, None, None
    vocab_data = torch.load(VOCAB_PATH, map_location="cpu")
    used_tokens = vocab_data["used_tokens"]  # model ID - OFFSET -> raw GPT-2 ID
    id2new = vocab_data["id2new"]            # raw GPT-2 ID -> model ID
    vocab_size = len(used_tokens) + OFFSET
    print(f"[INFO] Vocab loaded. Size: {vocab_size}")
    model = GCLM(vocab_size).to(device)
    if os.path.exists(MODEL_PATH):
        print(f"[INFO] Loading model from {MODEL_PATH}...")
        state_dict = torch.load(MODEL_PATH, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
    else:
        print(f"[ERROR] Model file not found: {MODEL_PATH}")
        return None, None, None
    return model, used_tokens, id2new
@torch.no_grad()
def generate(model, prompt, tokenizer, id2new, used_tokens, device,
             max_new_tokens=200, temperature=0.8, top_k=50):
    model.eval()
    # Encode the prompt and remap raw GPT-2 IDs to the compact model
    # vocabulary, silently skipping tokens the model was never trained on.
    raw_ids = tokenizer.encode(prompt)
    input_ids = [id2new[rid] for rid in raw_ids if rid in id2new]
    if not input_ids:
        print("[WARN] No known tokens in prompt.")
        input_ids = [PAD_ID]  # fallback; should not happen in practice
    x = torch.tensor([input_ids], dtype=torch.long, device=device)
    generated = []
    for _ in range(max_new_tokens):
        # Crop the context to the positional embedding limit.
        ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
        logits = model(ctx)
        next_token_logits = logits[:, -1, :] / temperature
        # Optional top-k sampling: mask out everything below the k-th logit.
        if top_k is not None:
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()
        if idx == EOS_ID:
            break
        x = torch.cat((x, next_token), dim=1)
        generated.append(idx)
    return decoder(generated, used_tokens, tokenizer)
def decoder(ids, used_tokens, tokenizer):
    # Map model IDs back to raw GPT-2 IDs; special tokens (< OFFSET) are dropped.
    raw_ids = []
    for i in ids:
        if i >= OFFSET:
            raw_ids.append(used_tokens[i - OFFSET])
    return tokenizer.decode(raw_ids)
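# `generate` is not called by the demo loop below; a minimal prompted run
# (illustrative) would be:
#   model, used_tokens, id2new = load_model_and_vocab("cpu")
#   enc = tiktoken.get_encoding(TOKENIZER_NAME)
#   print(generate(model, "Once upon a time", enc, id2new, used_tokens, "cpu"))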
# ----------------- MAIN -----------------
if __name__ == "__main__":
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model, used_tokens, id2new = load_model_and_vocab(device)
    enc = tiktoken.get_encoding(TOKENIZER_NAME)

    if model:
        # Seed generation with a newline token; fall back to the first
        # non-special token if newline is not in the vocabulary.
        newline_id = id2new.get(enc.encode("\n")[0], OFFSET)
        while True:
            print("\n--- Generating Sample (Temp=0.8, TopK=50) ---")
            print("-" * 20)
            x = torch.tensor([[newline_id]], dtype=torch.long, device=device)
            generated = []
            with torch.no_grad():
                for _ in range(500):
                    # Crop the context to the positional embedding limit.
                    ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
                    logits = model(ctx)
                    logits = logits[:, -1, :] / 0.8  # temperature
                    # Top-k filtering, as in generate().
                    v, _ = torch.topk(logits, min(50, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    idx = next_token.item()
                    x = torch.cat((x, next_token), dim=1)
                    generated.append(idx)
                    # Stream each token to stdout as it is sampled.
                    if idx == EOS_ID:
                        print("[EOS]", end="", flush=True)
                        break
                    if idx >= OFFSET:
                        raw_id = used_tokens[idx - OFFSET]
                        print(enc.decode([raw_id]), end="", flush=True)
                    elif idx == PAD_ID:
                        print("[PAD]", end="", flush=True)
                    elif idx == SEP_ID:
                        print("[SEP]", end="", flush=True)
            print("\n" + "-" * 20)
            cont = input("\nPress [Enter] to generate again, or type 'exit': ")
            if cont.lower() == 'exit':
                break