# ChatGCLM-Open / train_gclm_base.py
print("Starting...")
###############################################
# CONFIGURATION — CUSTOMIZE EVERYTHING HERE
###############################################
# ---- data / vocab ----
TXT_PATH = "data.txt"
DATA_PCT = 0.001  # fraction of the text file to use; kept small for quick tests
TOKENIZER_NAME = "gpt2"
REDUCE_VOCAB = True
VOCAB_SAVE_PATH = "vocab_map.pt"
# ---- training ----
EPOCHS = 25
MICRO_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8
LEARNING_RATE = 3e-4
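# Effective batch size = MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS = 8 sequences per optimizer step.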
# ---- model ----
D_MODEL = 256
N_LAYERS = 4
MAX_SEQ_LEN = 1024
LOCAL_KERNEL_SIZE = 5
GLOBAL_KERNEL_SIZE = 256
USE_GLOBAL_EVERY_N_LAYERS = 2
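# With N_LAYERS = 4 and USE_GLOBAL_EVERY_N_LAYERS = 2, layers 0 and 2 get a global FFT conv.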
# ---- FFT conv ----
FFT_SIZE = 1024 # must be > GLOBAL_KERNEL_SIZE; a power of two keeps the FFTs fast
# ---- checkpointing ----
SAVE_PATH = "model.pt"  # note: train() derives its own filename from the parameter count, so this value is currently unused
SAVE_N_EPOCHS = 1
# ---- device ----
USE_DEVICE = "cuda"  # informational only: train() auto-detects cuda/mps/cpu
USE_AMP = True
USE_ACTIVATION_CHECKPOINTING = False  # currently unused by this script
# ---- torch.compile ----
COMPILE = False
COMPILE_MODE = "reduce-overhead"
COMPILE_BACKEND = "eager"
###############################################
# END CONFIG
###############################################
import os
# expandable_segments is a CUDA allocator option that is not supported on Windows,
# so only set it on non-Windows platforms.
if os.name != "nt":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import tiktoken
# performance settings
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
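# (TF32 and the "high" matmul precision setting trade a little float32 accuracy for speed on recent NVIDIA GPUs.)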
###############################################################
# SPECIAL TOKENS
###############################################################
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3
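# Dataset tokens are remapped to start at OFFSET so that ids 0-2 stay reserved
# for the special tokens above.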
###############################################################
# VOCAB
###############################################################
def build_dataset_vocab(txt_path, tokenizer, save_path):
    text = open(txt_path, "r", encoding="utf-8").read()
    if DATA_PCT < 1.0:
        text = text[:int(len(text) * DATA_PCT)]
    token_ids = tokenizer.encode(text)
    used = sorted(set(token_ids))
    id2new = {tok: i + OFFSET for i, tok in enumerate(used)}
    torch.save({
        "used_tokens": used,
        "id2new": id2new,
        "PAD_ID": PAD_ID,
        "SEP_ID": SEP_ID,
        "EOS_ID": EOS_ID,
    }, save_path)
    print(f"[OK] Vocab size: {len(used) + OFFSET}")
    return used, id2new
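# Minimal decoding sketch (not used by the training loop; the helper name is ours,
# not part of the original pipeline): invert the id2new map to recover original
# tokenizer ids, then detokenize with the same tiktoken encoding that built the vocab.
def decode_remapped(ids, id2new, tokenizer):
    new2old = {v: k for k, v in id2new.items()}
    # Special ids (< OFFSET) have no counterpart in the base vocab, so drop them.
    raw = [new2old[i] for i in ids if i in new2old]
    return tokenizer.decode(raw)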
###############################################################
# DATASET
###############################################################
class RemappedTextDataset(Dataset):
    def __init__(self, ids, max_len):
        self.ids = ids
        self.max_len = max_len

    def __len__(self):
        return max(0, len(self.ids) - self.max_len - 1)

    def __getitem__(self, i):
        x = self.ids[i:i+self.max_len]
        y = self.ids[i+1:i+self.max_len+1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
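# Each index i yields the window ids[i:i+max_len] with targets shifted by one
# token; consecutive samples overlap almost entirely, so one "epoch" walks
# every possible window of the corpus.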
###############################################################
# GLOBAL + LOCAL CONVOLUTION
###############################################################
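# GlobalConv1D implements a long causal convolution via FFT overlap-save:
# the input is left-padded by K-1 (causality), then processed in windows of
# FFT_SIZE samples. Each window's circular convolution is valid except for its
# first K-1 samples, so only the trailing `block = FFT_SIZE - (K-1)` samples
# are kept and the window advances by `block`. With the default config
# (K = 256, FFT_SIZE = 1024) each FFT yields 769 output positions.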
class GlobalConv1D(nn.Module):
    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        B, C, T = x.shape
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap
        x = F.pad(x, (overlap, 0))
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))
        k_f = torch.fft.rfft(k, n=self.fft_size)
        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos+self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                n=self.fft_size
            )
            outs.append(y[..., overlap:overlap+block])
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]
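# LocalConv1D: a short causal depthwise conv (one filter per channel) followed
# by a 1x1 pointwise conv that mixes channels; the left padding of k-1 keeps it causal.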
class LocalConv1D(nn.Module):
    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))
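# Pre-norm residual block: local conv branch, optional global conv branch, then
# a feed-forward MLP. LayerNorm runs on (B, T, C) while the conv modules expect
# (B, C, T), hence the transposes around each conv call.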
class Block(nn.Module):
    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model*4),
            nn.GELU(),
            nn.Linear(d_model*4, d_model)
        )

    def forward(self, x):
        x = x + self.local(self.ln1(x).transpose(1,2)).transpose(1,2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1,2)).transpose(1,2)
        return x + self.ff(self.ln3(x))
class GCLM(nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList([
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        ])
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: SIGNIFICANTLY reduces parameter count
        self.head.weight = self.emb.weight

    def forward(self, x):
        T = x.size(1)
        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))
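# Minimal greedy-generation sketch (not part of the training script; the name
# and defaults here are ours). Assumes `model` is a trained GCLM and `ids` is a
# list of remapped token ids, e.g. produced with the id2new map above.
@torch.no_grad()
def generate_greedy(model, ids, max_new_tokens=50, device="cpu"):
    model.eval()
    out = list(ids)
    for _ in range(max_new_tokens):
        # Keep at most MAX_SEQ_LEN tokens of context (the positional table size).
        x = torch.tensor(out[-MAX_SEQ_LEN:], dtype=torch.long, device=device).unsqueeze(0)
        logits = model(x)                      # (1, T, vocab)
        next_id = int(logits[0, -1].argmax())  # greedy pick at the last position
        out.append(next_id)
        if next_id == EOS_ID:
            break
    return out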
###############################################################
# TRAINING LOOP
###############################################################
def format_params(num):
    if num >= 1_000_000_000:
        return f"{num/1_000_000_000:.1f}B"
    elif num >= 1_000_000:
        return f"{num/1_000_000:.1f}M"
    else:
        return f"{num/1_000:.1f}K"
@torch.no_grad()
def estimate_loss(model, dl, device, ctx):
    model.eval()
    losses = []
    # Check up to 50 batches for validation to save time
    limit = 50
    for i, (x, y) in enumerate(dl):
        if i >= limit:
            break
        x, y = x.to(device), y.to(device)
        with ctx:
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=PAD_ID)
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses) if losses else 0.0
def train():
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("[INFO] Device:", device)

    # 1. Prepare Vocab & Data
    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    # Build the vocab map and save it to VOCAB_SAVE_PATH
    used, id2new = build_dataset_vocab(TXT_PATH, tok, VOCAB_SAVE_PATH)
    vocab = len(used) + OFFSET

    # Load and process full text
    print("[INFO] Loading and tokenizing text...")
    text = open(TXT_PATH, "r", encoding="utf-8").read()
    if DATA_PCT < 1.0:
        text = text[:int(len(text) * DATA_PCT)]
    raw_ids = tok.encode(text)
    # Map to new IDs
    ids = [id2new.get(i, PAD_ID) for i in raw_ids] + [EOS_ID]

    # Split Train/Val (90/10)
    n = len(ids)
    split_idx = int(n * 0.9)
    train_ids = ids[:split_idx]
    val_ids = ids[split_idx:]
    print(f"[INFO] Tokens: {n} | Train: {len(train_ids)} | Val: {len(val_ids)}")
    train_ds = RemappedTextDataset(train_ids, MAX_SEQ_LEN)
    val_ds = RemappedTextDataset(val_ids, MAX_SEQ_LEN)
    train_dl = DataLoader(train_ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=MICRO_BATCH_SIZE, shuffle=False)
    model = GCLM(vocab).to(device)

    # Calculate params
    num_params = sum(p.numel() for p in model.parameters())
    param_str = format_params(num_params)
    save_path = f"chatgclm_base_{param_str}.pt"
    print(f"[INFO] Model parameters: {num_params:,} ({param_str})")
    print(f"[INFO] Save path: {save_path}")

    # 🔁 RESUME IF CHECKPOINT EXISTS
    if os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path, map_location=device))
        print(f"[RESUME] Loaded existing checkpoint from {save_path}")
    if device == "cuda" and COMPILE:
        print("[INFO] Compiling model with torch.compile...")
        model = torch.compile(
            model,
            mode=COMPILE_MODE,
            fullgraph=False,
            backend=COMPILE_BACKEND
        )

    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    # AMP Context
    if device == "cuda" and USE_AMP:
        ctx = torch.amp.autocast(device)
        scaler = torch.amp.GradScaler(device)
    else:
        # Dummy context for cpu/mps
        import contextlib
        ctx = contextlib.nullcontext()
        scaler = None
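    # Gradient accumulation: each micro-batch loss is scaled by 1/GRAD_ACCUM_STEPS
    # and the optimizer steps every GRAD_ACCUM_STEPS micro-batches, giving an
    # effective batch of MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS sequences. A trailing
    # partial accumulation at the end of an epoch is dropped by the zero_grad at
    # the start of the next epoch.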
    for ep in range(EPOCHS):
        print(f"\nEpoch {ep+1}/{EPOCHS}")
        opt.zero_grad(set_to_none=True)
        pbar = tqdm(train_dl, desc="Training")
        running_loss = 0.0
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            with ctx:
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))
            loss_val = loss.item()
            loss = loss / GRAD_ACCUM_STEPS
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            if (i+1) % GRAD_ACCUM_STEPS == 0:
                if scaler:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()
                opt.zero_grad(set_to_none=True)
            # Update progress bar
            running_loss = 0.9 * running_loss + 0.1 * loss_val if running_loss > 0 else loss_val
            pbar.set_postfix(loss=f"{running_loss:.4f}")

        # Validate at end of epoch
        val_loss = estimate_loss(model, val_dl, device, ctx)
        print(f"Epoch {ep+1} finished. Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f}")
        if SAVE_N_EPOCHS and (ep+1) % SAVE_N_EPOCHS == 0:
            torch.save(model.state_dict(), save_path)
            print(f"[OK] Saved checkpoint to {save_path}")

    torch.save(model.state_dict(), save_path)
    print("[DONE] Training complete.")
###############################################################
# ENTRY POINT
###############################################################
if __name__ == "__main__":
    train()