# BaseGPT / app.py
# Uploaded by skolvankar in commit 9a1472c ("Upload 3 files").
"""
Hugging Face Spaces app for the from-scratch Shakespeare GPT.
This app:
1. Defines the GPT model architecture.
2. Loads the saved model weights (a raw state_dict); the config and character-level tokenizer are hard-coded to match training.
3. Serves a Gradio UI for text generation.
"""
import gradio as gr
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from dataclasses import dataclass
# -----------------------------------------------------------------------------
# 1. Model Definition (Pasted from train_complete.py)
# -----------------------------------------------------------------------------
# (Mirrors the architecture in train_complete.py; module names and shapes must stay in sync so the saved state_dict loads.)
@dataclass
class GPTConfig:
    """Hyperparameters that define the GPT architecture.

    The defaults are the generic values carried over from the training script;
    this app overrides several of them when it builds ``config`` so that they
    match the checkpoint being loaded.
    """
    block_size: int = 1024  # max context length; sizes the position-embedding table and causal mask
    vocab_size: int = 50257  # default vocab size (presumably GPT-2's BPE vocab -- confirm); this app passes the char vocab size
    n_layer: int = 12  # number of transformer Blocks stacked in GPT.transformer.h
    n_head: int = 12  # attention heads; n_embd must be divisible by this
    n_embd: int = 768  # embedding / residual-stream width
    dropout: float = 0.1  # drop probability used by every nn.Dropout in the model
class CausalSelfAttention(nn.Module):
    """Multi-head masked (causal) self-attention.

    Every position may attend only to itself and earlier positions; the
    lower-triangular ``bias`` buffer supplies that mask.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by num heads"
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        width = self.n_embd
        # A single fused projection yields queries, keys and values side by side.
        self.c_attn = nn.Linear(width, 3 * width, bias=False)
        self.c_proj = nn.Linear(width, width, bias=False)
        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)
        n_ctx = config.block_size
        causal_mask = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)
        # Kept under the name "bias" so state_dict keys line up with the checkpoint.
        self.register_buffer("bias", causal_mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = map(to_heads, self.c_attn(x).split(self.n_embd, dim=2))
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))
        # Merge heads back: (B, n_head, T, head_dim) -> (B, T, C)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))
class MLP(nn.Module):
    """Position-wise feed-forward net: widen 4x, GELU, project back, dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(activated))
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention, then MLP, each as a residual branch."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_branch = self.attn(self.ln_1(x))
        x = x + attn_branch
        mlp_branch = self.mlp(self.ln_2(x))
        return x + mlp_branch
class GPT(nn.Module):
    """Decoder-only GPT language model.

    Token + learned position embeddings feed ``config.n_layer`` transformer
    ``Block``s, a final LayerNorm, and a linear head whose weight is tied to
    the token embedding.
    """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output head share one matrix.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: "torch.Tensor | None" = None):
        """Run the model on a batch of token ids.

        Args:
            idx: integer tensor of shape (B, T) with T <= config.block_size.
            targets: optional next-token ids, flattened to B*T entries for the
                loss; entries equal to -1 are ignored.

        Returns:
            ``(logits, loss)``: logits of shape (B, T, vocab_size) and the mean
            cross-entropy loss, or ``None`` when ``targets`` is not given.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Seq len {T} exceeds block size {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.0, top_k=None):
        """Autoregressively extend ``idx`` by ``max_new_tokens`` tokens.

        Args:
            idx: (B, T) tensor of prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: ``0`` (default), a negative value or ``None`` selects
                greedy decoding (argmax). A positive value samples from
                ``softmax(logits / temperature)``.
            top_k: when sampling, restrict choices to the ``top_k`` most likely
                tokens; ignored for greedy decoding.

        Returns:
            Tensor of shape (B, T + max_new_tokens).

        The module runs in eval mode (dropout off) while generating. Its
        previous train/eval mode is restored afterwards, so calling this from a
        training loop does not silently disable dropout for later steps.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent block_size tokens fit the position table.
                idx_cond = idx[:, -self.config.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature is None or temperature <= 0:
                    # softmax is monotonic, so argmax over raw logits picks the same token.
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None:
                        k = min(top_k, logits.size(-1))
                        kth_best = torch.topk(logits, k).values[:, [-1]]
                        logits = logits.masked_fill(logits < kth_best, float('-inf'))
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx
# -----------------------------------------------------------------------------
# 2. Load Model and Tokenizer (MODIFIED FOR 15-MINUTE DEADLINE)
# -----------------------------------------------------------------------------
# --- Configuration ---
# IMPORTANT: Change this to match the checkpoint file you uploaded!
# e.g., 'model_baby.pth' or 'model_gpt2-124m.pth'
CHECKPOINT_FILE = 'models/model_gpt2-124m.pth'
DEVICE = 'cpu'  # HF Spaces run best on CPU for light inference
# ---------------------
# --- Hard-coded Configuration ---
# We must manually define the config and tokenizer because
# the old .pth file only contains the model weights.
# 1. Define the characters (must match the training vocabulary).
# The training script built its vocabulary with chars = sorted(list(set(text))),
# so the list is sorted here too: a token id is a character's position in it.
EXPECTED_VOCAB_SIZE = 65  # vocabulary size the checkpoint was trained with
chars = sorted(['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?',
                'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
                'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
                'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g',
                'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
                's', 't', 'u', 'v', 'w', 'x', 'y', 'z'])
vocab_size = len(chars)  # Calculate from actual list length
print(f"Vocab size: {vocab_size} characters")
print(f"Characters: {chars}")
# Validate against the trained size rather than against len(chars) itself,
# and reject duplicates, which would shift every later character's token id.
if len(set(chars)) != len(chars):
    raise ValueError("Tokenizer character list contains duplicate characters")
if vocab_size != EXPECTED_VOCAB_SIZE:
    raise ValueError(
        f"Tokenizer character list is incorrect: expected {EXPECTED_VOCAB_SIZE}, got {vocab_size}"
    )
# 2. Define the model config.
# Every value below MUST match the settings used for training; otherwise
# load_state_dict fails on shape mismatches.
config = GPTConfig(
    vocab_size=vocab_size,  # char-level vocab (65), NOT the 50257 default
    block_size=1024,        # training context length (this app previously used 512)
    n_layer=12,
    n_head=12,
    n_embd=936,             # training embedding width (previously 768); 936 / 12 = 78 per head
    dropout=0.1,            # no effect at inference: the model is switched to eval mode
)
# --- End Hard-coded Configuration ---
print(f"Loading model from {CHECKPOINT_FILE}...")
print(f"Using hard-coded config: {config}")
print(f"Using hard-coded vocab size: {vocab_size}")
# 1. Create the model "scaffolding" (randomly initialised, overwritten below).
model = GPT(config)
# 2. Load the weights.
# The .pth file *is* the state_dict (not a checkpoint dict with config/tokenizer),
# so it goes straight into load_state_dict. weights_only=True limits unpickling
# to tensors and plain containers, so a tampered checkpoint cannot run arbitrary
# code when it is loaded.
state_dict = torch.load(CHECKPOINT_FILE, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # disable dropout for inference
print("Model loaded successfully.")
# Re-create the character-level tokenizer: a token id is the character's index
# in the sorted `chars` list.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map string ``s`` to token ids, silently dropping characters not in the vocabulary."""
    return [stoi[c] for c in s if c in stoi]


def decode(l):
    """Map the token ids in ``l`` back to a string."""
    return ''.join(itos[i] for i in l)
# -----------------------------------------------------------------------------
# 3. Gradio Inference Function and UI
# -----------------------------------------------------------------------------
def predict(prompt_text, max_new_tokens=300):
    """Generate a continuation of ``prompt_text`` for the Gradio UI.

    Args:
        prompt_text: user prompt. Characters outside the model's vocabulary
            are dropped by ``encode``.
        max_new_tokens: number of characters to generate. It may arrive from
            the Gradio slider as a float, so it is coerced to ``int``.

    Returns:
        The (vocabulary-filtered) prompt followed by the generated text.
    """
    max_new_tokens = int(max_new_tokens)
    encoded_prompt = encode(prompt_text) if prompt_text else []
    if not encoded_prompt:
        # Nothing usable to condition on: the prompt was empty, or every character
        # was out of vocabulary. An empty (1, 0) tensor would crash generate when
        # it indexes the last position, so seed with token 0 instead, which is
        # '\n' because the vocabulary is sorted.
        encoded_prompt = [0]
    start_context = torch.tensor([encoded_prompt], dtype=torch.long, device=DEVICE)
    generated_tokens = model.generate(start_context, max_new_tokens=max_new_tokens)
    return decode(generated_tokens[0].tolist())
# Launch the Gradio Interface.
# Report the loaded model's real parameter count instead of a hard-coded figure
# that drifts when the config changes. Module.parameters() yields the tied
# embedding/head weight only once.
_n_params_m = sum(p.numel() for p in model.parameters()) / 1e6
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(
            lines=5,
            label="Prompt",
            placeholder="Enter your prompt... (e.g., 'JULIET:')",
        ),
        gr.Slider(
            minimum=50,
            maximum=1000,
            value=300,
            step=50,
            label="Max New Tokens",
        ),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="Shakespeare GPT",
    description=f"A character-level GPT (~{_n_params_m:.0f}M parameters) trained from scratch on Shakespeare. "
    "This app loads a raw state_dict file.",
)

if __name__ == "__main__":
    iface.launch()