"""Fourth GPT model definition and inference using PyTorch (CPU)."""
import json
import math
import os
import re

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescale by 1/RMS(x), then apply a learned gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # rsqrt of the mean square over the feature dimension; eps guards against div-by-zero.
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight

class TransformerBlock(nn.Module):
    """Pre-norm self-attention + MLP block, each with a residual connection."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.norm1 = RMSNorm(n_embd)
        self.wq = nn.Linear(n_embd, n_embd, bias=False)
        self.wk = nn.Linear(n_embd, n_embd, bias=False)
        self.wv = nn.Linear(n_embd, n_embd, bias=False)
        self.wo = nn.Linear(n_embd, n_embd, bias=False)
        self.norm2 = RMSNorm(n_embd)
        self.mlp_fc1 = nn.Linear(n_embd, 4 * n_embd, bias=False)
        self.mlp_fc2 = nn.Linear(4 * n_embd, n_embd, bias=False)

    def forward(self, x, mask):
        B, T, _ = x.shape
        # Multi-head self-attention: project, split into heads, scaled dot-product.
        xn = self.norm1(x)
        q = self.wq(xn).reshape(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.wk(xn).reshape(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.wv(xn).reshape(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att + mask  # additive causal mask: -1e9 above the diagonal
        att = F.softmax(att, dim=-1)
        out = (att @ v).transpose(1, 2).reshape(B, T, -1)
        x = x + self.wo(out)
        # Position-wise MLP with 4x expansion and ReLU.
        xn2 = self.norm2(x)
        h = F.relu(self.mlp_fc1(xn2))
        x = x + self.mlp_fc2(h)
        return x
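
# A minimal equivalence sketch (assumption: PyTorch >= 2.0, not required by this
# file): the explicit attention math above could be replaced by the fused kernel
#     out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#     out = out.transpose(1, 2).reshape(B, T, -1)
# with near-identical results (this code masks with -1e9 rather than -inf).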

class GPT(nn.Module):
    """Decoder-only transformer: token + position embeddings, N blocks, LM head."""

    def __init__(self, vocab_size, n_layer, n_embd, block_size, n_head):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.wpe = nn.Embedding(block_size, n_embd)
        self.ln_pre = RMSNorm(n_embd)
        self.layers = nn.ModuleList([TransformerBlock(n_embd, n_head) for _ in range(n_layer)])
        self.ln_post = RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, tokens):
        B, T = tokens.shape
        x = self.wte(tokens) + self.wpe(torch.arange(T, device=tokens.device))
        x = self.ln_pre(x)
        # Causal mask: large negative entries above the diagonal block attention
        # to future positions once the mask is added to the attention scores.
        mask = torch.triu(torch.full((T, T), -1e9, device=tokens.device), diagonal=1)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.ln_post(x)
        return self.lm_head(x)
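
# A minimal shape check (hypothetical values, not from the checkpoint config):
#     m = GPT(vocab_size=32, n_layer=2, n_embd=64, block_size=128, n_head=4)
#     logits = m(torch.zeros(1, 16, dtype=torch.long))  # -> torch.Size([1, 16, 32])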

class FourthModel:
    """Wraps the GPT model with tokenizer and generation logic."""

    def __init__(self, checkpoint_dir=None):
        if checkpoint_dir is None:
            checkpoint_dir = os.path.join(os.path.dirname(__file__) or ".", "model_weights")
        self.checkpoint_dir = checkpoint_dir
        self.model = None
        self.stoi = None
        self.itos = None
        self.bos = None
        self.config = None

    def load(self):
        config_path = os.path.join(self.checkpoint_dir, "config.json")
        with open(config_path) as f:
            self.config = json.load(f)
        # stoi maps char -> id; invert it for decoding, and decode BOS as "".
        self.stoi = self.config["stoi"]
        self.bos = self.config["bos"]
        self.itos = {int(i): c for c, i in self.stoi.items()}
        self.itos[self.bos] = ""
        self.model = GPT(
            vocab_size=self.config["vocab_size"],
            n_layer=self.config["n_layer"],
            n_embd=self.config["n_embd"],
            block_size=self.config["block_size"],
            n_head=self.config["n_head"],
        )
        # Load weights: try the PyTorch format first, fall back to npz.
        pt_path = os.path.join(self.checkpoint_dir, "weights.pt")
        npz_path = os.path.join(self.checkpoint_dir, "weights.npz")
        if os.path.exists(pt_path):
            state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
        else:
            import numpy as np
            npz = np.load(npz_path)
            state_dict = {k: torch.tensor(npz[k]) for k in npz.files}
        self.model.load_state_dict(state_dict)
        self.model.eval()
        nparams = sum(p.numel() for p in self.model.parameters())
        print(f"Loaded model: {nparams} params, vocab={self.config['vocab_size']}")

    @torch.no_grad()
    def generate(self, prompt: str, max_tokens: int = 128, temperature: float = 0.7) -> str:
        """Generate a response to a prompt."""
        # Normalize to the model's character set: lowercase letters, spaces,
        # and "|" (the prompt/response separator).
        clean = re.sub(r'[^a-z |]', '', prompt.lower().strip())
        clean = re.sub(r' +', ' ', clean).strip()
        if not clean.endswith("|"):
            clean += "|"
        block_size = self.config["block_size"]
        tokens = [self.bos] + [self.stoi.get(c, self.bos) for c in clean]
        # Sample autoregressively until BOS (end-of-text) or the length budget runs out.
        for _ in range(min(max_tokens, block_size - len(tokens))):
            x = torch.tensor([tokens[-block_size:]], dtype=torch.long)
            logits = self.model(x)
            logits = logits[0, -1] / temperature
            probs = F.softmax(logits, dim=-1)
            tok = torch.multinomial(probs, 1).item()
            if tok == self.bos:
                break
            tokens.append(tok)
        # Decode and return only the part after the first "|" separator.
        full = "".join(self.itos.get(t, "?") for t in tokens[1:])
        parts = full.split("|", 1)
        return parts[1] if len(parts) > 1 else full
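
if __name__ == "__main__":
    # A minimal usage sketch (assumption: a model_weights/ directory containing
    # config.json and weights.pt or weights.npz sits next to this file, as
    # load() expects by default). The prompt text is illustrative only.
    fm = FourthModel()
    fm.load()
    print(fm.generate("what do you dream about", max_tokens=64))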