"""Fourth GPT model definition and inference using PyTorch (CPU).""" import torch import torch.nn as nn import torch.nn.functional as F import math import json import os import re class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x): norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return x * norm * self.weight class TransformerBlock(nn.Module): def __init__(self, n_embd, n_head): super().__init__() self.n_head = n_head self.head_dim = n_embd // n_head self.norm1 = RMSNorm(n_embd) self.wq = nn.Linear(n_embd, n_embd, bias=False) self.wk = nn.Linear(n_embd, n_embd, bias=False) self.wv = nn.Linear(n_embd, n_embd, bias=False) self.wo = nn.Linear(n_embd, n_embd, bias=False) self.norm2 = RMSNorm(n_embd) self.mlp_fc1 = nn.Linear(n_embd, 4 * n_embd, bias=False) self.mlp_fc2 = nn.Linear(4 * n_embd, n_embd, bias=False) def forward(self, x, mask): B, T, _ = x.shape xn = self.norm1(x) q = self.wq(xn).reshape(B, T, self.n_head, self.head_dim).transpose(1, 2) k = self.wk(xn).reshape(B, T, self.n_head, self.head_dim).transpose(1, 2) v = self.wv(xn).reshape(B, T, self.n_head, self.head_dim).transpose(1, 2) att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim) att = att + mask att = F.softmax(att, dim=-1) out = (att @ v).transpose(1, 2).reshape(B, T, -1) x = x + self.wo(out) xn2 = self.norm2(x) h = F.relu(self.mlp_fc1(xn2)) x = x + self.mlp_fc2(h) return x class GPT(nn.Module): def __init__(self, vocab_size, n_layer, n_embd, block_size, n_head): super().__init__() self.block_size = block_size self.wte = nn.Embedding(vocab_size, n_embd) self.wpe = nn.Embedding(block_size, n_embd) self.ln_pre = RMSNorm(n_embd) self.layers = nn.ModuleList([TransformerBlock(n_embd, n_head) for _ in range(n_layer)]) self.ln_post = RMSNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) def forward(self, tokens): B, T = tokens.shape x = self.wte(tokens) + self.wpe(torch.arange(T, device=tokens.device)) x = self.ln_pre(x) mask = torch.triu(torch.full((T, T), -1e9, device=tokens.device), diagonal=1) for layer in self.layers: x = layer(x, mask) x = self.ln_post(x) return self.lm_head(x) class FourthModel: """Wraps the GPT model with tokenizer and generation logic.""" def __init__(self, checkpoint_dir=None): if checkpoint_dir is None: checkpoint_dir = os.path.join(os.path.dirname(__file__) or ".", "model_weights") self.checkpoint_dir = checkpoint_dir self.model = None self.stoi = None self.itos = None self.bos = None self.config = None def load(self): config_path = os.path.join(self.checkpoint_dir, "config.json") with open(config_path) as f: self.config = json.load(f) self.stoi = self.config["stoi"] self.bos = self.config["bos"] self.itos = {int(i): c for c, i in self.stoi.items()} self.itos[self.bos] = "" self.model = GPT( vocab_size=self.config["vocab_size"], n_layer=self.config["n_layer"], n_embd=self.config["n_embd"], block_size=self.config["block_size"], n_head=self.config["n_head"], ) # Load weights — try PyTorch format first, fall back to npz pt_path = os.path.join(self.checkpoint_dir, "weights.pt") npz_path = os.path.join(self.checkpoint_dir, "weights.npz") if os.path.exists(pt_path): state_dict = torch.load(pt_path, map_location="cpu", weights_only=True) else: import numpy as np npz = np.load(npz_path) state_dict = {k: torch.tensor(npz[k]) for k in npz.files} self.model.load_state_dict(state_dict) self.model.eval() nparams = sum(p.numel() for p in self.model.parameters()) print(f"Loaded model: {nparams} params, vocab={self.config['vocab_size']}") @torch.no_grad() def generate(self, prompt: str, max_tokens: int = 128, temperature: float = 0.7) -> str: """Generate a response to a prompt.""" clean = re.sub(r'[^a-z |]', '', prompt.lower().strip()) clean = re.sub(r' +', ' ', clean).strip() if not clean.endswith("|"): clean += "|" block_size = self.config["block_size"] tokens = [self.bos] + [self.stoi.get(c, self.bos) for c in clean] for _ in range(min(max_tokens, block_size - len(tokens))): x = torch.tensor([tokens[-block_size:]], dtype=torch.long) logits = self.model(x) logits = logits[0, -1] / temperature probs = F.softmax(logits, dim=-1) tok = torch.multinomial(probs, 1).item() if tok == self.bos: break tokens.append(tok) full = "".join(self.itos.get(t, "?") for t in tokens[1:]) parts = full.split("|", 1) return parts[1] if len(parts) > 1 else full