#!/usr/bin/env python3
"""Train or sample from a small GPT language model on a plain-text corpus.

Usage:
    script [global options] train [--save ... --steps ... --lr ...]
    script [global options] eval  [--load ... --prompt ... --token-count ...]

Both subcommands finish by generating a sample from the (freshly trained or
loaded) model and printing it to stdout.
"""
import argparse
from pathlib import Path

import numpy as np
import torch

from util import CharacterTokenizer, BPETokenizer, Dataset
from gpt import GPTLanguageModel
from loss import estimate_loss
from metrics import Metrics


def train(data, model, tokenizer, steps, report_frequency, lr, save_path,
          checkpoint_freq, warmup_steps=0):
    """Run the AdamW optimization loop for `steps` iterations.

    Args:
        data: Dataset providing `get_batch(split, device)` -> (inputs, targets).
        model: the language model; `model(xb, yb)` must return (logits, loss).
        tokenizer: passed through to the Metrics callable for reporting.
        steps: total number of optimizer steps.
        report_frequency: print loss/metrics every N steps (and on the last step).
        lr: base learning rate.
        save_path: final save path; periodic checkpoints derive their name from it.
        checkpoint_freq: save a checkpoint every N steps (0 disables).
        warmup_steps: if > 0, linear LR warmup followed by cosine decay.
    """
    device = next(model.parameters()).device
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    metrics = Metrics()

    if warmup_steps > 0:
        # Linear warmup from 1% of lr, then cosine decay down to 1% of lr.
        from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
        warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0,
                          total_iters=warmup_steps)
        decay = CosineAnnealingLR(optimizer, T_max=max(1, steps - warmup_steps),
                                  eta_min=lr * 0.01)
        scheduler = SequentialLR(optimizer, schedulers=[warmup, decay],
                                 milestones=[warmup_steps])
    else:
        scheduler = None

    for step in range(steps):
        xb, yb = data.get_batch('train', device)
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        if step % report_frequency == 0 or step == steps - 1:
            losses = estimate_loss(data, model)
            print(f"Step {step}, train loss: {losses['train']:.4f} val loss: {losses['val']:.4f}")
            metrics_dict = metrics(data, model, tokenizer)
            print("Metrics:", metrics_dict)
            print()

        if checkpoint_freq and step > 0 and step % checkpoint_freq == 0:
            # Build "<stem>_step<N><suffix>" so a save path without a ".pth"
            # extension still yields a distinct checkpoint name (a plain
            # str.replace('.pth', ...) would fall through unchanged and
            # clobber save_path).
            p = Path(save_path)
            ckpt_path = str(p.with_name(f"{p.stem}_step{step}{p.suffix}"))
            torch.save(model.state_dict(), ckpt_path)
            print(f"Checkpoint saved: {ckpt_path}")


def _build_parser():
    """Build the CLI parser: global model/data options plus train/eval subcommands."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, default="input.txt")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--tokenizer", type=str, default="char", choices=["char", "bpe"],
                        help="Tokenizer type: char (default) or bpe (subword sentencepiece)")
    parser.add_argument("--bpe-vocab-size", type=int, default=2000,
                        help="Vocab size when training a BPE tokenizer (ignored for char)")
    parser.add_argument("--bpe-model", type=str, default="bpe.model",
                        help="Path to save/load the sentencepiece .model file")
    parser.add_argument("--context-size", type=int, default=256)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--n-embd", type=int, default=384)
    parser.add_argument("--n-head", type=int, default=6)
    parser.add_argument("--n-layer", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.2)

    subparsers = parser.add_subparsers(dest="command", required=True)

    train_parser = subparsers.add_parser("train")
    train_parser.add_argument("--save", type=str, default="model.pth")
    train_parser.add_argument("--steps", type=int, default=5000)
    train_parser.add_argument("--report", type=int, default=500)
    train_parser.add_argument("--lr", type=float, default=1e-3)
    train_parser.add_argument("--checkpoint-freq", type=int, default=0,
                              help="Save a checkpoint every N steps (0 = disabled)")
    train_parser.add_argument("--warmup-steps", type=int, default=0,
                              help="LR warmup steps before cosine decay (0 = constant LR)")

    eval_parser = subparsers.add_parser("eval")
    eval_parser.add_argument("--load", type=str, default="model.pth")
    eval_parser.add_argument("--prompt", type=str)
    eval_parser.add_argument("--token-count", type=int, default=300)

    return parser


def main():
    args = _build_parser().parse_args()

    # `is not None` so an explicit --seed 0 is honored (a bare truthiness
    # test would silently skip seeding for 0).
    if args.seed is not None:
        torch.manual_seed(args.seed)

    context_size = args.context_size
    batch_size = args.batch_size
    # NOTE(review): --dropout is parsed but never forwarded to the model —
    # GPTLanguageModel is constructed without it below. Confirm whether the
    # constructor should accept a dropout argument.
    dropout = args.dropout

    # Pick the best available backend; replace this selection for other hardware.
    device = 'cuda' if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    if device == "cpu":
        print("WARNING: Running on cpu!")

    # Explicit encoding so the corpus decodes the same way on every platform.
    with open(args.input, "r", encoding="utf-8") as f:
        content = f.read()

    if args.tokenizer == "bpe":
        tokenizer = BPETokenizer(args.input, model_path=args.bpe_model,
                                 vocab_size=args.bpe_vocab_size)
    else:
        tokenizer = CharacterTokenizer(content)

    data = torch.tensor(tokenizer.encode(content), dtype=torch.long)
    dataset = Dataset(data, context_size, batch_size)

    model = GPTLanguageModel(tokenizer.vocab_size, args.n_embd, context_size,
                             args.n_head, args.n_layer)
    model = model.to(device)
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()) / 1e6}")
    print(f"Using device: {device}")
    print()

    if args.command == "eval":
        print("=" * 20, "INFERENCE", "=" * 20)
        # map_location lets a checkpoint trained on cuda/mps load on a
        # cpu-only (or differently-equipped) machine.
        model.load_state_dict(torch.load(args.load, map_location=device))
    elif args.command == "train":
        print("=" * 20, "TRAINING", "=" * 20)
        train(dataset, model, tokenizer, args.steps, args.report, args.lr,
              args.save, args.checkpoint_freq, args.warmup_steps)
        torch.save(model.state_dict(), args.save)
        print("=" * 50)

    # Sampling must run with dropout disabled on BOTH paths; previously only
    # the "eval" command called eval(), so post-training samples were drawn
    # with dropout still active.
    model.eval()

    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    max_tokens = 300
    if args.command == "eval":
        if args.prompt is not None:
            context = torch.tensor([tokenizer.encode(args.prompt)],
                                   dtype=torch.long, device=device)
        max_tokens = args.token_count

    print(
        tokenizer.decode(
            model.generate(start_idx=context, number_of_tokens=max_tokens,
                           use_cache=True)[0].tolist()
        )
    )


if __name__ == "__main__":
    main()