# ChatGCLM-330M / train_chatgclm.py
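# Training script for ChatGCLM. It expects model.py (providing ChatGCLM and
# MAX_SEQ_LEN) and a data/ directory of .txt files next to this file, and is run
# directly, e.g. `python train_chatgclm.py`.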
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import tiktoken
import contextlib
from model import ChatGCLM, MAX_SEQ_LEN
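
# Allocator and matmul precision tweaks: the expandable-segments allocator setting is
# skipped on Windows (os.name == "nt"), and TF32 is enabled only when CUDA is available.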
if os.name != "nt":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
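
# Training hyper-parameters. With MICRO_BATCH_SIZE = 1 and GRAD_ACCUM_STEPS = 8 the
# effective batch size is 8 sequences per optimizer step. DATA_PCT and SAVE_N_EPOCHS
# are defined here but do not appear to be referenced later in this script.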
DATA_DIR = "data"
DATA_PCT = 0.002
TOKENIZER_NAME = "gpt2"
VOCAB_SAVE_PATH = "vocab_map.pt"
EPOCHS = 50
MICRO_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8
LEARNING_RATE = 5e-4
MIN_LR = 1e-5
SAVE_N_EPOCHS = 1
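
# Reserved special-token ids; build_dataset_vocab() persists them alongside the
# tokenizer vocab size, presumably so inference code can reuse the same mapping.
# OFFSET is not used elsewhere in this file.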
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3
def build_dataset_vocab(data_dir, tokenizer, save_path):
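    """Save the tokenizer vocab size and special-token ids to save_path and return the vocab size (data_dir is currently unused)."""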
    vocab_size = tokenizer.n_vocab
    torch.save({
        "vocab_size": vocab_size,
        "PAD_ID": PAD_ID,
        "SEP_ID": SEP_ID,
        "EOS_ID": EOS_ID,
    }, save_path)
    return vocab_size
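
# Slices the flat token-id stream into fixed-length blocks of max_len tokens with
# next-token targets (shifted by one position); __len__ drops the trailing partial
# block, so the zero-padding below only acts as a safeguard for short slices.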
class RemappedTextDataset(Dataset):
    def __init__(self, ids, max_len):
        self.ids = ids
        self.max_len = max_len

    def __len__(self):
        return max(0, (len(self.ids) - 1) // self.max_len)

    def __getitem__(self, i):
        start = i * self.max_len
        x = self.ids[start : start + self.max_len]
        y = self.ids[start + 1 : start + self.max_len + 1]
        if len(x) < self.max_len:
            x = x + [0] * (self.max_len - len(x))
        if len(y) < self.max_len:
            y = y + [0] * (self.max_len - len(y))
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
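
# Human-readable parameter count, e.g. 331_000_000 -> "331.0M".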
def format_params(num):
    if num >= 1_000_000_000:
        return f"{num/1_000_000_000:.1f}B"
    elif num >= 1_000_000:
        return f"{num/1_000_000:.1f}M"
    else:
        return f"{num/1_000:.1f}K"
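
# Validation loss: average cross-entropy over at most 50 batches, evaluated under
# the same autocast context as training and ignoring PAD_ID positions.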
@torch.no_grad()
def estimate_loss(model, dl, device, ctx):
    model.eval()
    losses = []
    limit = 50
    for i, (x, y) in enumerate(dl):
        if i >= limit: break
        x, y = x.to(device), y.to(device)
        with ctx:
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=PAD_ID)
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses) if losses else 0.0
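
# End-to-end training: load .txt files from DATA_DIR, tokenize with the GPT-2 BPE,
# build a 90/10 train/val split, then run epochs of gradient-accumulated AdamW with
# a cosine LR schedule, saving a checkpoint after every epoch.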
def train():
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    vocab = build_dataset_vocab(DATA_DIR, tok, VOCAB_SAVE_PATH)
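
    # Concatenate every .txt file in DATA_DIR into one corpus string, then tokenize.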
    full_text = ""
    for f in os.listdir(DATA_DIR):
        if not f.endswith(".txt"): continue
        fpath = os.path.join(DATA_DIR, f)
        with open(fpath, "r", encoding="utf-8") as fh:
            content = fh.read()
        full_text += content + "\n"

    ids = tok.encode(full_text) + [EOS_ID]
    n = len(ids)
    split_idx = int(n * 0.9)
    train_ids = ids[:split_idx]
    val_ids = ids[split_idx:]
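
    # Fixed-length block datasets and micro-batch loaders (one sequence per batch).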
    train_ds = RemappedTextDataset(train_ids, MAX_SEQ_LEN)
    val_ds = RemappedTextDataset(val_ids, MAX_SEQ_LEN)
    train_dl = DataLoader(train_ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=MICRO_BATCH_SIZE, shuffle=False)

    model = ChatGCLM(vocab).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    param_str = format_params(num_params)
    save_path = f"ChatGCLM_{param_str}.pt"
    print("-" * 30)
    print("ChatGCLM TRAINING START")
    print(f"Model ID: {save_path}")
    print(f"Parameters: {num_params:,}")
    print(f"Device: {device}")
    print(f"Vocab Size: {vocab}")
    print(f"Learning Rate: {LEARNING_RATE}")
    print(f"Epochs: {EPOCHS}")
    print("-" * 30)
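
    # Resume from an existing checkpoint with the same parameter count, if one is present.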
    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        print(f"⏳ Found checkpoint at {save_path}, loading...")
        model.load_state_dict(torch.load(save_path, map_location=device))
        print("✓ Model weights loaded successfully! Resuming training.")
    else:
        print("ℹ No checkpoint found. Starting training from scratch.")
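
    # AdamW with per-epoch cosine annealing; autocast and GradScaler are used only on
    # CUDA, while MPS/CPU fall back to full precision via a null context.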
    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS, eta_min=MIN_LR)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    ctx = torch.amp.autocast(device) if device == "cuda" else contextlib.nullcontext()
    scaler = torch.amp.GradScaler(device) if device == "cuda" else None
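
    # Each epoch: accumulate gradients over GRAD_ACCUM_STEPS micro-batches before an
    # optimizer step; running_loss is an exponential moving average of the per-batch
    # loss. Gradients from an incomplete accumulation window at the end of an epoch
    # are discarded by the zero_grad() at the top of the next epoch.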
    for ep in range(EPOCHS):
        opt.zero_grad(set_to_none=True)
        pbar = tqdm(train_dl, desc=f"Epoch {ep+1}/{EPOCHS}")
        running_loss = 0.0
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            with ctx:
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))
            loss_val = loss.item()
            loss = loss / GRAD_ACCUM_STEPS
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            if (i + 1) % GRAD_ACCUM_STEPS == 0:
                if scaler:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()
                opt.zero_grad(set_to_none=True)
            running_loss = 0.9 * running_loss + 0.1 * loss_val if running_loss > 0 else loss_val
            pbar.set_postfix(loss=f"{running_loss:.4f}")
        val_loss = estimate_loss(model, val_dl, device, ctx)
        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {ep+1} | Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")
        torch.save(model.state_dict(), save_path)
        print(f"✓ Model saved successfully after epoch {ep+1} to {save_path}")
        scheduler.step()

if __name__ == "__main__":
    train()