import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
import os
import math
import sentencepiece as spm
from data import training_data

model_to_train = "SLM"

MODELS = {
    "SLM": {
        "block_size": 256,
        "n_embd": 384,
        "n_head": 6,
        "n_layer": 8,
        "batch_size": 64,
        "max_iters": 12000,
        "learning_rate": 2e-3,
        "weight_decay": 1e-2,
    },
}

if model_to_train not in MODELS:
    raise ValueError(f"Model '{model_to_train}' not found. Available models: {list(MODELS.keys())}")

config = MODELS[model_to_train]
block_size = config["block_size"]
n_embd = config["n_embd"]
n_head = config["n_head"]
n_layer = config["n_layer"]
batch_size = config["batch_size"]
max_iters = config["max_iters"]
learning_rate = config["learning_rate"]
weight_decay = config.get("weight_decay", 0.0)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_interval = 250
eval_iters = 100
dropout = 0.3

MODEL_SAVE_PATH = f"{model_to_train.replace(' ', '_')}.pth"
TOKENIZER_MODEL = f"{model_to_train.replace(' ', '_')}_tokenizer.model"

print(f"--- Using Model: {model_to_train} ---")
print(f"Using device: {device}")
for key, value in config.items():
    print(f"{key.replace('_', ' ').title()}: {value}")

torch.manual_seed(1337)

# Train a BPE tokenizer on the corpus if one does not already exist.
vocab_size = 1000
if not os.path.exists(TOKENIZER_MODEL):
    print("Tokenizer model not found, training a new one...")
    with open("temp_training_data.txt", "w", encoding="utf-8") as f:
        f.write(training_data)
    spm.SentencePieceTrainer.train(
        f'--input=temp_training_data.txt --model_prefix={model_to_train.replace(" ", "_")}_tokenizer '
        f'--vocab_size={vocab_size} --model_type=bpe '
        f'--pad_id=0 --unk_id=1 --bos_id=-1 --eos_id=-1'
    )
    os.remove("temp_training_data.txt")
    print("Tokenizer training complete.")

sp = spm.SentencePieceProcessor()
sp.load(TOKENIZER_MODEL)
vocab_size = sp.get_piece_size()
pad_token_id = sp.pad_id()

def encode(s):
    return sp.encode_as_ids(s)

def decode(l):
    # Drop padding tokens before detokenizing.
    return sp.decode([token for token in l if token != pad_token_id])

# Tokenize the corpus and hold out the last 10% for validation.
data = torch.tensor(encode(training_data), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    # Sample a batch of contiguous sequences; targets are inputs shifted by one.
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

class Head(nn.Module):
    """A single head of causal self-attention."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask so positions cannot attend to the future.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # scaled dot-product scores
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ v

class MultiHeadAttention(nn.Module):
    """Several attention heads in parallel, concatenated and projected back."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))

class FeedForward(nn.Module):
    """Position-wise MLP with a 4x expansion, as in the original transformer."""

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
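# Optional sanity check, never called by this script: in eval mode (attention
# dropout off), the hand-rolled Head above should agree with PyTorch 2.x's
# fused F.scaled_dot_product_attention under a causal mask. The helper name
# and tolerance are illustrative additions, not part of the original script.
def _check_head_matches_sdpa(head: Head, x: torch.Tensor) -> bool:
    head.eval()  # disable attention dropout so both paths are deterministic
    with torch.no_grad():
        manual = head(x)  # masked softmax(QK^T / sqrt(d_k)) @ V, computed by hand
        q, k, v = head.query(x), head.key(x), head.value(x)
        fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return torch.allclose(manual, fused, atol=1e-5)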
class Block(nn.Module):
    """Transformer block: communication (attention) then computation (MLP)."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

    def forward(self, x):
        # Pre-norm residual connections.
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # Weight tying: the input embedding and the output projection share parameters.
        self.token_embedding_table.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            # Padding tokens do not contribute to the loss.
            loss = F.cross_entropy(logits, targets, ignore_index=pad_token_id)
        return logits, loss

    def generate(self, idx, max_new_tokens, top_k=50):
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]  # crop context to the block size
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]  # keep only the last time step
            # Top-k filtering: mask out everything below the k-th largest logit.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        self.train()
        return idx

def get_lr(it):
    # Linear warmup over the first 100 iterations, then cosine decay down to
    # one tenth of the peak learning rate.
    if it < 100:
        return learning_rate * it / 100
    if it > max_iters:
        return learning_rate / 10
    decay_ratio = (it - 100) / (max_iters - 100)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return (learning_rate / 10) + coeff * (learning_rate - (learning_rate / 10))

@torch.no_grad()
def estimate_loss():
    # Average the loss over several batches for a less noisy estimate.
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
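# Illustrative helper, never called by this script: sample the warmup + cosine
# schedule at a few checkpoints. With the SLM config (learning_rate=2e-3,
# max_iters=12000) this prints 0 at step 0, the peak 2e-3 at step 100 (end of
# warmup), the midpoint value 1.1e-3 near step 6050, and roughly the floor
# 2e-4 at the final step. The helper name and checkpoints are assumptions.
def _preview_lr_schedule(steps=(0, 50, 100, 6050, 11999)):
    for it in steps:
        print(f"iter {it:>6}: lr = {get_lr(it):.6f}")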
Starting training...") optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) for i in tqdm(range(max_iters), desc="Training"): lr = get_lr(i) for param_group in optimizer.param_groups: param_group['lr'] = lr if i % eval_interval == 0 or i == max_iters - 1: losses = estimate_loss() tqdm.write(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr: {lr:.6f}") xb, yb = get_batch('train') _, loss = model(xb, yb) optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step() print(f"\nTraining finished. Saving model to {MODEL_SAVE_PATH}...") torch.save(model.state_dict(), MODEL_SAVE_PATH) print("Model saved.") print(f"\n--- Generation from {model_to_train} ---") prompt = "in the heart of a whispering forest" context = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0) generated_tokens = model.generate(context, max_new_tokens=100)[0].tolist() print(decode(generated_tokens)) print("\n--- Generation from 'hello' prompt ---") prompt2 = "hello" context2 = torch.tensor(encode(prompt2), dtype=torch.long, device=device).unsqueeze(0) generated_tokens2 = model.generate(context2, max_new_tokens=100)[0].tolist() print(decode(generated_tokens2))