import os
import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import tiktoken

from model import ChatGCLM, MAX_SEQ_LEN
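# model.py is assumed (from its use below) to provide ChatGCLM(vocab_size), whose forward
# pass returns logits of shape (batch, seq_len, vocab_size), and MAX_SEQ_LEN, the context
# length used to chunk the corpus.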


# Opt in to expandable CUDA allocator segments to reduce fragmentation (skipped on Windows).
if os.name != "nt":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")


# Allow TF32 matmuls on CUDA for faster float32 math at negligible precision cost.
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


# Data and tokenizer configuration.
DATA_DIR = "data"
DATA_PCT = 0.002  # intended fraction of the corpus to use; not currently applied in train()
TOKENIZER_NAME = "gpt2"
VOCAB_SAVE_PATH = "vocab_map.pt"


# Optimization hyperparameters.
EPOCHS = 50
MICRO_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8  # effective batch size = MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS
LEARNING_RATE = 5e-4
MIN_LR = 1e-5

SAVE_N_EPOCHS = 1  # checkpoint every N epochs

# Special token ids reserved below OFFSET; raw tokenizer ids are shifted up by OFFSET.
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3
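# Note: generation code must undo this remap before decoding, e.g. (sketch, with
# `generated` a hypothetical list of sampled model ids):
#     text = tok.decode([t - OFFSET for t in generated if t >= OFFSET])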


def build_dataset_vocab(data_dir, tokenizer, save_path):
    """Save the model vocabulary metadata (tokenizer vocab plus reserved special ids)."""
    vocab_size = tokenizer.n_vocab + OFFSET  # reserve ids 0..OFFSET-1 for PAD/SEP/EOS
    torch.save({
        "vocab_size": vocab_size,
        "PAD_ID": PAD_ID,
        "SEP_ID": SEP_ID,
        "EOS_ID": EOS_ID,
        "OFFSET": OFFSET,
    }, save_path)
    return vocab_size
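# A downstream script (e.g. an inference entry point) can reload the same mapping:
#     meta = torch.load(VOCAB_SAVE_PATH)  # {"vocab_size", "PAD_ID", "SEP_ID", "EOS_ID", "OFFSET"}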


class RemappedTextDataset(Dataset):
    """Serves contiguous (input, target) chunks from an already-remapped token stream."""

    def __init__(self, ids, max_len):
        self.ids = ids
        self.max_len = max_len

    def __len__(self):
        # Number of full-length chunks; a trailing partial chunk is dropped.
        return max(0, (len(self.ids) - 1) // self.max_len)

    def __getitem__(self, i):
        start = i * self.max_len
        x = self.ids[start : start + self.max_len]
        y = self.ids[start + 1 : start + self.max_len + 1]  # next-token targets

        # Pad short chunks with PAD_ID, which the loss ignores.
        if len(x) < self.max_len:
            x = x + [PAD_ID] * (self.max_len - len(x))
        if len(y) < self.max_len:
            y = y + [PAD_ID] * (self.max_len - len(y))

        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
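# RemappedTextDataset example: a 1_025-token stream with max_len=512 yields
# (1_025 - 1) // 512 = 2 chunks, each pairing x = ids[k : k + 512] with y = ids[k + 1 : k + 513].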


def format_params(num):
    if num >= 1_000_000_000:
        return f"{num/1_000_000_000:.1f}B"
    elif num >= 1_000_000:
        return f"{num/1_000_000:.1f}M"
    else:
        return f"{num/1_000:.1f}K"
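# Examples: format_params(1_234_567) -> "1.2M"; format_params(3_400_000_000) -> "3.4B".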


@torch.no_grad()
def estimate_loss(model, dl, device, ctx):
    """Mean cross-entropy over at most 50 validation batches."""
    model.eval()
    losses = []
    limit = 50
    for i, (x, y) in enumerate(dl):
        if i >= limit:
            break
        x, y = x.to(device), y.to(device)
        with ctx:
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=PAD_ID)
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses) if losses else 0.0


def train():
    # Prefer CUDA, then Apple MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    vocab = build_dataset_vocab(DATA_DIR, tok, VOCAB_SAVE_PATH)

    # Concatenate every .txt file under DATA_DIR into one corpus.
    full_text = ""
    for f in os.listdir(DATA_DIR):
        if not f.endswith(".txt"):
            continue
        fpath = os.path.join(DATA_DIR, f)
        with open(fpath, "r", encoding="utf-8") as fh:
            full_text += fh.read() + "\n"

    # Tokenize, shifting every id up by OFFSET so 0..OFFSET-1 stay reserved for
    # PAD/SEP/EOS, then terminate the stream with EOS.
    ids = [t + OFFSET for t in tok.encode(full_text)] + [EOS_ID]

    # 90/10 train/validation split of the token stream.
    n = len(ids)
    split_idx = int(n * 0.9)
    train_ids = ids[:split_idx]
    val_ids = ids[split_idx:]

    train_ds = RemappedTextDataset(train_ids, MAX_SEQ_LEN)
    val_ds = RemappedTextDataset(val_ids, MAX_SEQ_LEN)
    train_dl = DataLoader(train_ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=MICRO_BATCH_SIZE, shuffle=False)
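    # With MICRO_BATCH_SIZE = 1 and GRAD_ACCUM_STEPS = 8, each optimizer step averages
    # gradients over 8 micro-batches, i.e. 8 * MAX_SEQ_LEN tokens per update.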

    model = ChatGCLM(vocab).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    param_str = format_params(num_params)
    save_path = f"ChatGCLM_{param_str}.pt"

    print("-" * 30)
    print("ChatGCLM TRAINING START")
    print(f"Model ID: {save_path}")
    print(f"Parameters: {num_params:,}")
    print(f"Device: {device}")
    print(f"Vocab Size: {vocab}")
    print(f"Learning Rate: {LEARNING_RATE}")
    print(f"Epochs: {EPOCHS}")
    print("-" * 30)

    # Resume from an existing checkpoint with a matching parameter count, if present.
    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        print(f"⏳ Found checkpoint at {save_path}, loading...")
        model.load_state_dict(torch.load(save_path, map_location=device))
        print("✓ Model weights loaded successfully! Resuming training.")
    else:
        print("ℹ No checkpoint found. Starting training from scratch.")

    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS, eta_min=MIN_LR)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
    # Mixed precision (autocast + GradScaler) only on CUDA; MPS and CPU run in full precision.
    ctx = torch.amp.autocast(device) if device == "cuda" else contextlib.nullcontext()
    scaler = torch.amp.GradScaler(device) if device == "cuda" else None
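    # CosineAnnealingLR anneals the learning rate once per epoch (scheduler.step() below):
    #     lr(ep) = MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + cos(pi * ep / EPOCHS))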

    for ep in range(EPOCHS):
        opt.zero_grad(set_to_none=True)
        pbar = tqdm(train_dl, desc=f"Epoch {ep+1}/{EPOCHS}")
        running_loss = 0.0
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            with ctx:
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))
            loss_val = loss.item()
            # Scale the loss so gradients average over the accumulation window.
            loss = loss / GRAD_ACCUM_STEPS
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            # Step the optimizer once every GRAD_ACCUM_STEPS micro-batches.
            if (i + 1) % GRAD_ACCUM_STEPS == 0:
                if scaler:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()
                opt.zero_grad(set_to_none=True)
            # Exponential moving average of the micro-batch loss for the progress bar.
            running_loss = (0.9 * running_loss + 0.1 * loss_val) if running_loss > 0 else loss_val
            pbar.set_postfix(loss=f"{running_loss:.4f}")

        val_loss = estimate_loss(model, val_dl, device, ctx)
        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {ep+1} | Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")
        if (ep + 1) % SAVE_N_EPOCHS == 0:
            torch.save(model.state_dict(), save_path)
            print(f"✓ Model saved successfully after epoch {ep+1} to {save_path}")
        scheduler.step()


if __name__ == "__main__":
    train()