# edgemindroboticslabs's picture
# Upload train.py with huggingface_hub
# e28afaf verified
"""Train the GPT model from scratch on one or more datasets."""
import os
import time
import argparse
import torch
import numpy as np
from model import GPT, GPTConfig
from tokenizer import BPETokenizer, CharTokenizer
from data_loader import build_combined_text, tokenize_and_split
def get_device():
    """Pick the torch device to run on.

    Backends are probed in priority order: Apple MPS first, then CUDA.
    The CPU device is returned when neither reports itself available.
    """
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend, is_available in probes:
        if is_available():
            return torch.device(backend)
    return torch.device("cpu")
def get_batch(data, block_size, batch_size, device):
    """Sample a random batch of input/target windows from a token sequence.

    Each target row is the matching input row shifted one position to the
    right, so ``y[b, t]`` is the token that follows ``x[b, t]`` in ``data``.

    Args:
        data: Token sequence supporting ``len()`` and slicing whose slices
            can be passed to ``torch.stack`` (presumably a 1-D tensor of
            token ids -- confirm against ``tokenize_and_split``).
        block_size: Number of tokens in every window.
        batch_size: Number of windows to draw.
        device: Device the two returned tensors are moved to.

    Returns:
        Tuple ``(x, y)`` of stacked windows, both moved to ``device``.

    Raises:
        ValueError: If ``data`` has fewer than ``block_size + 1`` tokens, in
            which case no input window has a complete shifted target.
    """
    # Valid start offsets are 0 .. len(data) - block_size - 1 because the
    # target window reaches one token past the end of the input window.
    num_starts = len(data) - block_size
    if num_starts < 1:
        raise ValueError(
            f"Need at least block_size + 1 = {block_size + 1} tokens to draw a "
            f"batch, but the data has only {len(data)}."
        )
    # Plain Python ints keep the slicing below free of 0-dim tensor indices.
    starts = torch.randint(num_starts, (batch_size,)).tolist()
    x = torch.stack([data[s : s + block_size] for s in starts])
    y = torch.stack([data[s + 1 : s + block_size + 1] for s in starts])
    return x.to(device), y.to(device)
@torch.no_grad()
def estimate_loss(model, train_data, val_data, block_size, batch_size, device, eval_iters=50):
    """Estimate the mean loss on the train and validation splits.

    For each split, ``eval_iters`` random batches are drawn with
    ``get_batch`` and the second value returned by ``model(x, y)`` is
    averaged. The model is put in eval mode while measuring and switched to
    train mode before returning. Gradients are disabled by the decorator.

    Returns:
        Dict with keys ``"train"`` and ``"val"`` holding the mean losses.
    """

    def mean_loss(split_data):
        # Average the scalar loss over freshly sampled batches of one split.
        samples = []
        for _ in range(eval_iters):
            inputs, targets = get_batch(split_data, block_size, batch_size, device)
            _, batch_loss = model(inputs, targets)
            samples.append(batch_loss.item())
        return np.mean(samples)

    model.eval()
    results = {
        "train": mean_loss(train_data),
        "val": mean_loss(val_data),
    }
    model.train()
    return results
def train(args):
    """Train a GPT model, optionally resuming from the best checkpoint.

    Steps: pick a device, build the tokenizer and the combined training
    text, split it into train/val tokens, create (or restore) the model,
    then run the optimisation loop with periodic evaluation. Whenever the
    validation loss improves, the model is written to
    ``checkpoints/best_model.pt``.

    Args:
        args: Parsed command-line namespace. Fields read here: ``datasets``,
            ``weights``, ``data_dir``, ``custom_file``, ``tokenizer``,
            ``block_size``, ``batch_size``, ``n_layer``, ``n_head``,
            ``n_embd``, ``dropout``, ``lr``, ``max_iters``,
            ``eval_interval`` and ``resume``.

    Raises:
        ValueError: When resuming with a tokenizer whose vocabulary is
            larger than the one the checkpointed model was built for.

    Side effects:
        Prints progress, creates ``checkpoints/``, writes
        ``checkpoints/best_model.pt`` and, for the char tokenizer,
        ``tokenizer.json``.
    """
    device = get_device()
    print(f"Using device: {device}")

    dataset_names = args.datasets.split(",")
    weights = [float(w) for w in args.weights.split(",")] if args.weights else None

    def load_text():
        # Single place that assembles the corpus, so it is built exactly once
        # and the char vocabulary always matches the text that gets tokenized.
        return build_combined_text(
            dataset_names,
            data_dir=args.data_dir,
            custom_file=args.custom_file,
            weights=weights,
        )

    # ── Tokenizer ─────────────────────────────────────────────────────────────
    if args.tokenizer == "bpe":
        tokenizer = BPETokenizer()
        print(f"Tokenizer: BPE (GPT-2), vocab size {tokenizer.vocab_size:,}")
        raw = load_text()
    else:
        # char tokenizer needs a full text pass first — load all data, build vocab
        print("Tokenizer: char-level (building vocab from data...)")
        raw = load_text()
        tokenizer = CharTokenizer(text=raw)
        print(f"Tokenizer: char-level, vocab size {tokenizer.vocab_size}")
        tokenizer.save("tokenizer.json")

    # ── Data ──────────────────────────────────────────────────────────────────
    train_data, val_data = tokenize_and_split(raw, tokenizer, split_ratio=0.9)
    print(f"Train tokens: {len(train_data):,} | Val tokens: {len(val_data):,}")

    # ── Model ─────────────────────────────────────────────────────────────────
    ckpt_path = "checkpoints/best_model.pt"
    start_step = 0
    best_val_loss = float("inf")
    ckpt = None
    if args.resume and os.path.exists(ckpt_path):
        print("Resuming from checkpoints/best_model.pt ...")
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only resume from checkpoints you trust. It is kept because older
        # checkpoints stored val_loss as a numpy scalar.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        config = GPTConfig(**ckpt["config"])
        if tokenizer.vocab_size > config.vocab_size:
            # Presumably the model cannot embed ids beyond its own vocab size;
            # fail early with a readable message instead of deep in training.
            raise ValueError(
                f"Checkpoint model has vocab size {config.vocab_size}, but the "
                f"selected tokenizer produces {tokenizer.vocab_size} tokens."
            )
        model = GPT(config).to(device)
        model.load_state_dict(ckpt["model_state"])
        best_val_loss = float(ckpt.get("val_loss", float("inf")))
        start_step = ckpt.get("step", 0) + 1
        print(f"Resumed at step {start_step}, val loss {best_val_loss:.4f}")
    else:
        config = GPTConfig(
            vocab_size=tokenizer.vocab_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            dropout=args.dropout,
        )
        model = GPT(config).to(device)
    print(f"Model parameters: {model.num_params():,}")

    # A resumed model keeps the block size it was built with, which may differ
    # from the command-line value, so batches follow the model config.
    block_size = config.block_size

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1)
    if ckpt is not None and "optimizer_state" in ckpt:
        # Checkpoints written before optimizer state was saved lack this key;
        # they resume with fresh AdamW moments, as before.
        optimizer.load_state_dict(ckpt["optimizer_state"])
    # The scheduler requires "initial_lr" in every param group when it is
    # created at a position other than the very start of the schedule.
    for group in optimizer.param_groups:
        group.setdefault("initial_lr", group["lr"])
    # Position the cosine schedule at the resumed step (last_epoch=-1 on a
    # fresh run) so resuming continues the decay instead of restarting it.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.max_iters, last_epoch=start_step - 1
    )

    os.makedirs("checkpoints", exist_ok=True)
    t0 = time.time()
    final_step = args.max_iters - 1
    for step in range(start_step, args.max_iters):
        x, y = get_batch(train_data, block_size, args.batch_size, device)
        _, loss = model(x, y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        if step % args.eval_interval == 0 or step == final_step:
            losses = estimate_loss(model, train_data, val_data, block_size, args.batch_size, device)
            elapsed = time.time() - t0
            lr_now = scheduler.get_last_lr()[0]
            print(
                f"step {step:5d} | train {losses['train']:.4f} | val {losses['val']:.4f}"
                f" | lr {lr_now:.2e} | {elapsed:.1f}s"
            )
            if losses["val"] < best_val_loss:
                # Store a plain float so the checkpoint holds no numpy scalar.
                best_val_loss = float(losses["val"])
                torch.save(
                    {
                        "model_state": model.state_dict(),
                        # Saved so a resumed run keeps its AdamW moments; this
                        # makes the checkpoint file noticeably larger.
                        "optimizer_state": optimizer.state_dict(),
                        "config": dict(vars(config)),
                        "val_loss": best_val_loss,
                        "step": step,
                        "datasets": args.datasets,
                        "tokenizer": args.tokenizer,
                    },
                    ckpt_path,
                )
                print(f" -> Saved best model (val loss {best_val_loss:.4f})")

    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
if __name__ == "__main__":
    # Command-line entry point: declare every flag, parse, then start training.
    cli = argparse.ArgumentParser(description="Train a GPT model from scratch")

    # Dataset selection and mixing
    cli.add_argument(
        "--datasets",
        default="shakespeare",
        help="Comma-separated dataset names: shakespeare,alpaca,openwebtext,custom",
    )
    cli.add_argument(
        "--weights",
        default=None,
        help="Comma-separated sampling weights matching --datasets, e.g. '1.0,0.5'",
    )
    cli.add_argument("--data_dir", default="data")
    cli.add_argument("--custom_file", default=None, help="Path to a custom .txt file")

    # Tokenizer choice
    cli.add_argument(
        "--tokenizer",
        default="bpe",
        choices=["bpe", "char"],
        help="bpe (GPT-2, 50257 tokens) or char (small vocab, fast)",
    )

    # Model shape and batch geometry (integer flags), then dropout
    model_int_flags = (
        ("--block_size", 256),
        ("--batch_size", 32),
        ("--n_layer", 6),
        ("--n_head", 6),
        ("--n_embd", 384),
    )
    for flag, default_value in model_int_flags:
        cli.add_argument(flag, type=int, default=default_value)
    cli.add_argument("--dropout", type=float, default=0.2)

    # Training schedule
    cli.add_argument("--lr", type=float, default=3e-4)
    for flag, default_value in (("--max_iters", 5000), ("--eval_interval", 500)):
        cli.add_argument(flag, type=int, default=default_value)
    cli.add_argument(
        "--resume",
        action="store_true",
        help="Resume from checkpoints/best_model.pt",
    )

    train(cli.parse_args())