print("Starting...")

###############################################
# CONFIGURATION — CUSTOMIZE EVERYTHING HERE
###############################################

# ---- data / vocab ----
TXT_PATH = "data.txt"
# Fraction of the text (by characters) that is used; this is small for
# testing purposes.
DATA_PCT = 0.001
TOKENIZER_NAME = "gpt2"
# NOTE(review): not read anywhere in this script — the vocab is always reduced.
REDUCE_VOCAB = True
VOCAB_SAVE_PATH = "vocab_map.pt"

# ---- training ----
EPOCHS = 25
MICRO_BATCH_SIZE = 1
# Optimizer steps once every GRAD_ACCUM_STEPS micro-batches, so the effective
# batch is MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS sequences.
GRAD_ACCUM_STEPS = 8
LEARNING_RATE = 3e-4

# ---- model ----
D_MODEL = 256
N_LAYERS = 4
# Training window length and size of the learned positional-embedding table.
MAX_SEQ_LEN = 1024
LOCAL_KERNEL_SIZE = 5
GLOBAL_KERNEL_SIZE = 256
# Layers with index % USE_GLOBAL_EVERY_N_LAYERS == 0 get a global (FFT) conv.
USE_GLOBAL_EVERY_N_LAYERS = 2

# ---- FFT conv ----
# Must be >= GLOBAL_KERNEL_SIZE: each overlap-save window yields
# FFT_SIZE - GLOBAL_KERNEL_SIZE + 1 new outputs. A power of 2 keeps the FFT fast.
FFT_SIZE = 1024

# ---- checkpointing ----
# NOTE(review): not read — train() derives the checkpoint name from the
# parameter count (chatgclm_base_<params>.pt).
SAVE_PATH = "model.pt"
# Save every N epochs; 0 disables the periodic saves (a final save still happens).
SAVE_N_EPOCHS = 1

# ---- device ----
# NOTE(review): not read — train() auto-selects cuda, then mps, then cpu.
USE_DEVICE = "cuda"
# Mixed precision; only applied when training on CUDA.
USE_AMP = True
# NOTE(review): not read anywhere in this script.
USE_ACTIVATION_CHECKPOINTING = False

# ---- torch.compile ----
# Only applied when training on CUDA.
COMPILE = False
COMPILE_MODE = "reduce-overhead"
COMPILE_BACKEND = "eager"

###############################################
# END CONFIG
###############################################

import os
# Windows cannot use expandable_segments, so only request it on other
# platforms. It is set before `import torch` so the value is already in the
# environment when PyTorch's CUDA caching allocator is initialized.
if os.name != "nt":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import contextlib
import itertools

import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Performance settings: let float32 matmuls/convolutions use TF32 on CUDA.
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

###############################################################
# SPECIAL TOKENS
###############################################################
# Reserved ids at the bottom of the remapped vocabulary; real tokenizer ids
# are remapped to start at OFFSET.
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3

###############################################################
# VOCAB
###############################################################


def _read_text(txt_path):
    """Read ``txt_path`` as UTF-8 and keep only the leading DATA_PCT fraction."""
    with open(txt_path, "r", encoding="utf-8") as f:
        text = f.read()
    if DATA_PCT < 1.0:
        text = text[: int(len(text) * DATA_PCT)]
    return text


def _build_vocab_from_ids(token_ids, save_path):
    """Build the compact id remapping from already-tokenized ids and save it.

    Returns ``(used, id2new)``: ``used`` is the sorted list of distinct
    tokenizer ids, and ``id2new`` maps the k-th of them to ``OFFSET + k``.
    The map and the special-token ids are written to ``save_path``.
    """
    used = sorted(set(token_ids))
    id2new = {tok: i + OFFSET for i, tok in enumerate(used)}
    torch.save(
        {
            "used_tokens": used,
            "id2new": id2new,
            "PAD_ID": PAD_ID,
            "SEP_ID": SEP_ID,
            "EOS_ID": EOS_ID,
        },
        save_path,
    )
    print(f"[OK] Vocab size: {len(used) + OFFSET}")
    return used, id2new


def build_dataset_vocab(txt_path, tokenizer, save_path):
    """Tokenize the first DATA_PCT of ``txt_path`` and build/save the vocab map.

    Returns ``(used, id2new)`` as described in ``_build_vocab_from_ids``.
    """
    return _build_vocab_from_ids(tokenizer.encode(_read_text(txt_path)), save_path)


###############################################################
# DATASET
###############################################################


class RemappedTextDataset(Dataset):
    """Sliding-window next-token dataset over a flat list of token ids.

    Item ``i`` is ``(ids[i:i+max_len], ids[i+1:i+max_len+1])`` as long tensors.
    """

    def __init__(self, ids, max_len):
        self.ids = ids
        self.max_len = max_len

    def __len__(self):
        # Window i needs ids[i + max_len] as its last target, so valid starts
        # are 0 .. len(ids) - max_len - 1, i.e. len(ids) - max_len windows.
        return max(0, len(self.ids) - self.max_len)

    def __getitem__(self, i):
        x = self.ids[i : i + self.max_len]
        y = self.ids[i + 1 : i + self.max_len + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)


###############################################################
# GLOBAL + LOCAL CONVOLUTION
###############################################################


class GlobalConv1D(nn.Module):
    """Causal per-channel long convolution computed with FFT overlap-save.

    Input is (B, C, T). Each channel c has its own learned kernel, and
    ``out[b, c, t] = sum_j kernel[c, j] * x[b, c, t - j]`` over the first
    ``min(kernel_size, T)`` taps.

    Raises:
        ValueError: if ``fft_size < kernel_size`` (no room for new outputs
            per FFT window).
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        if fft_size < kernel_size:
            # Each window yields fft_size - (kernel_size - 1) outputs; with
            # fft_size < kernel_size that is <= 0 and chunking never ends.
            raise ValueError(
                f"fft_size ({fft_size}) must be >= kernel_size ({kernel_size})"
            )
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        B, C, T = x.shape
        if T == 0:
            return x
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap  # >= 1 since fft_size >= kernel_size >= K
        n_blocks = (T + block - 1) // block

        # Left-pad K-1 zeros for causality, then right-pad so the signal splits
        # into exactly n_blocks windows of fft_size taken every `block` samples.
        total = (n_blocks - 1) * block + self.fft_size
        x = F.pad(x, (overlap, total - (T + overlap)))
        segs = x.unfold(-1, self.fft_size, block)  # (B, C, n_blocks, fft_size)

        # rfft with n=fft_size zero-pads the K-tap kernel.
        k_f = torch.fft.rfft(self.kernel[:, :K], n=self.fft_size)  # (C, F)
        y = torch.fft.irfft(
            torch.fft.rfft(segs, n=self.fft_size) * k_f.unsqueeze(1),
            n=self.fft_size,
        )
        # The first K-1 samples of each window are circular wrap-around; the
        # next `block` samples are exact linear-convolution outputs.
        y = y[..., overlap : overlap + block].reshape(B, C, n_blocks * block)
        return y[..., :T]


class LocalConv1D(nn.Module):
    """Causal depthwise conv (kernel ``k``) -> ReLU -> pointwise 1x1 conv."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # Left-pad by k-1 so output t only sees inputs <= t and length stays T.
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))


class Block(nn.Module):
    """Pre-LayerNorm residual block: local conv, optional global conv, MLP."""

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # x is (B, T, D) as produced by GCLM; the convs take channels-first
        # (B, D, T), hence the transposes around them.
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))


class GCLM(nn.Module):
    """Convolutional LM: token + learned position embeddings, N_LAYERS Blocks,
    final LayerNorm and a weight-tied output head.

    Layer i uses a global conv when ``i % USE_GLOBAL_EVERY_N_LAYERS == 0``.
    """

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList(
            [Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0) for i in range(N_LAYERS)]
        )
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: SIGNIFICANTLY reduces parameter count (the head reuses
        # the vocab x D_MODEL embedding matrix).
        self.head.weight = self.emb.weight

    def forward(self, x):
        """Map token ids (B, T) to logits (B, T, vocab).

        Raises:
            ValueError: if T exceeds the positional-embedding table size.
        """
        T = x.size(1)
        if T > self.pos.num_embeddings:
            raise ValueError(
                f"sequence length {T} exceeds MAX_SEQ_LEN ({self.pos.num_embeddings})"
            )
        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))


###############################################################
# TRAINING LOOP
###############################################################


def format_params(num):
    """Format a parameter count as e.g. ``"1.2B"``, ``"3.4M"`` or ``"5.6K"``."""
    if num >= 1_000_000_000:
        return f"{num/1_000_000_000:.1f}B"
    elif num >= 1_000_000:
        return f"{num/1_000_000:.1f}M"
    else:
        return f"{num/1_000:.1f}K"


@torch.no_grad()
def estimate_loss(model, dl, device, ctx, max_batches=50):
    """Return the mean per-batch cross-entropy of ``model`` over ``dl``.

    Only the first ``max_batches`` batches are evaluated to save time, and
    target positions equal to PAD_ID are ignored. Returns ``nan`` when ``dl``
    yields no batches (e.g. a validation split shorter than MAX_SEQ_LEN + 1
    tokens) rather than a 0.0 that reads as a perfect score. The model's
    train/eval mode is restored afterwards, even on error.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        for x, y in itertools.islice(dl, max_batches):
            x, y = x.to(device), y.to(device)
            with ctx:
                logits = model(x)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    y.reshape(-1),
                    ignore_index=PAD_ID,
                )
            losses.append(loss.item())
    finally:
        model.train(was_training)
    return sum(losses) / len(losses) if losses else float("nan")


def _select_device():
    """Return ``"cuda"`` if available, else ``"mps"`` if available, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _make_amp(device):
    """Return ``(autocast_context, grad_scaler_or_None)`` for ``device``.

    Mixed precision is only used on CUDA when USE_AMP is set; other devices
    get a no-op context and no loss scaling.
    """
    if device == "cuda" and USE_AMP:
        return torch.amp.autocast(device), torch.amp.GradScaler(device)
    return contextlib.nullcontext(), None


def train():
    """Build the vocab, train GCLM on TXT_PATH and checkpoint it.

    Checkpoints go to ``chatgclm_base_<params>.pt``. If that file exists, its
    weights are loaded first. Optimizer state and epoch count are not saved,
    so a resumed run starts again at epoch 1.

    Raises:
        ValueError: if the training split is too short for one MAX_SEQ_LEN window.
    """
    device = _select_device()
    print("[INFO] Device:", device)

    # 1. Prepare vocab & data. The text is read and tokenized once; the same
    # ids feed both the vocab map and the datasets.
    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    print("[INFO] Loading and tokenizing text...")
    raw_ids = tok.encode(_read_text(TXT_PATH))
    used, id2new = _build_vocab_from_ids(raw_ids, VOCAB_SAVE_PATH)
    vocab = len(used) + OFFSET

    # Map to new IDs. The map was built from raw_ids, so the PAD_ID fallback
    # is purely defensive.
    ids = [id2new.get(i, PAD_ID) for i in raw_ids] + [EOS_ID]

    # Split Train/Val (90/10)
    n = len(ids)
    split_idx = int(n * 0.9)
    train_ids = ids[:split_idx]
    val_ids = ids[split_idx:]
    print(f"[INFO] Tokens: {n} | Train: {len(train_ids)} | Val: {len(val_ids)}")

    train_ds = RemappedTextDataset(train_ids, MAX_SEQ_LEN)
    val_ds = RemappedTextDataset(val_ids, MAX_SEQ_LEN)
    if len(train_ds) == 0:
        raise ValueError(
            f"Training split has {len(train_ids)} tokens; need at least "
            f"MAX_SEQ_LEN + 1 ({MAX_SEQ_LEN + 1}). Increase DATA_PCT or lower MAX_SEQ_LEN."
        )
    if len(val_ds) == 0:
        print("[WARN] Validation split is shorter than MAX_SEQ_LEN + 1 tokens; val loss will be nan.")

    train_dl = DataLoader(train_ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=MICRO_BATCH_SIZE, shuffle=False)

    model = GCLM(vocab).to(device)

    # Calculate params (tied weights are counted once by parameters()).
    num_params = sum(p.numel() for p in model.parameters())
    param_str = format_params(num_params)
    save_path = f"chatgclm_base_{param_str}.pt"
    print(f"[INFO] Model parameters: {num_params:,} ({param_str})")
    print(f"[INFO] Save path: {save_path}")

    # 🔁 RESUME IF CHECKPOINT EXISTS (a plain state_dict, so weights_only is safe).
    if os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))
        print(f"[RESUME] Loaded existing checkpoint from {save_path}")

    # Keep the uncompiled module for checkpointing: the torch.compile wrapper's
    # state_dict keys carry an "_orig_mod." prefix that the resume path above
    # could not load.
    raw_model = model
    if device == "cuda" and COMPILE:
        print("[INFO] Compiling model with torch.compile...")
        model = torch.compile(
            model,
            mode=COMPILE_MODE,
            fullgraph=False,
            backend=COMPILE_BACKEND,
        )

    opt = torch.optim.AdamW(raw_model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
    ctx, scaler = _make_amp(device)
    steps_per_epoch = len(train_dl)

    for ep in range(EPOCHS):
        print(f"\nEpoch {ep+1}/{EPOCHS}")
        opt.zero_grad(set_to_none=True)
        pbar = tqdm(train_dl, desc="Training")
        running_loss = 0.0

        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            with ctx:
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))

            loss_val = loss.item()
            loss = loss / GRAD_ACCUM_STEPS
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Also step on the epoch's last micro-batch so a trailing partial
            # accumulation window is applied instead of being discarded by
            # the next epoch's zero_grad.
            if (i + 1) % GRAD_ACCUM_STEPS == 0 or (i + 1) == steps_per_epoch:
                if scaler is not None:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()
                opt.zero_grad(set_to_none=True)

            # Update progress bar with an exponential moving average of the loss.
            running_loss = (
                (0.9 * running_loss + 0.1 * loss_val) if running_loss > 0 else loss_val
            )
            pbar.set_postfix(loss=f"{running_loss:.4f}")

        # Validate at end of epoch
        val_loss = estimate_loss(model, val_dl, device, ctx)
        print(f"Epoch {ep+1} finished. Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f}")

        if SAVE_N_EPOCHS and (ep + 1) % SAVE_N_EPOCHS == 0:
            torch.save(raw_model.state_dict(), save_path)
            print(f"[OK] Saved checkpoint to {save_path}")

    torch.save(raw_model.state_dict(), save_path)
    print("[DONE] Training complete.")


###############################################################
# ENTRY POINT
###############################################################
if __name__ == "__main__":
    train()