"""
HuggingFace Spaces App for GPT-2 124M Shakespeare Model
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import math
from dataclasses import dataclass
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y
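# Note: the explicit mask/softmax attention above is kept for readability; on
# PyTorch >= 2.0 the same computation could likely be fused with
# F.scaled_dot_product_attention(q, k, v, is_causal=True), which would also make
# the registered "bias" buffer unnecessary.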
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
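# The defaults above match GPT-2 small (12 layers, 12 heads, 768-dim embeddings,
# 1024-token context, 50257-token vocab), roughly 124M parameters. Tying
# transformer.wte.weight to lm_head.weight shares the ~38M-parameter embedding
# matrix (50257 x 768) between input embedding and output projection, as in the
# original GPT-2.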
# Load model
print("Loading model...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = GPTConfig()
model = GPT(config)
model_loaded = False
# Try to load model from HuggingFace Model Hub first, then local file
try:
    from huggingface_hub import hf_hub_download
    import os

    # Try to get model path from environment variable or use default
    repo_id = os.getenv('HF_MODEL_REPO', 'shwethd/gpt2-shakespeare-124m')
    try:
        print(f"Attempting to load from HuggingFace Hub: {repo_id}")
        # Try SafeTensors first (more secure, no pickle issues)
        try:
            from safetensors.torch import load_file
            try:
                model_path = hf_hub_download(
                    repo_id=repo_id,
                    filename="model.safetensors",
                    cache_dir=None
                )
                state_dict = load_file(model_path, device=device)
                model.load_state_dict(state_dict)
                # Restore weight sharing (broken during SafeTensors conversion):
                # lm_head.weight and transformer.wte.weight should share memory
                model.transformer.wte.weight = model.lm_head.weight
                model_loaded = True
                print(f"✅ Model loaded successfully from SafeTensors: {repo_id}")
            except Exception as e:
                print(f"SafeTensors not found ({e}), trying .pt file...")
                # Fallback to .pt file
                model_path = hf_hub_download(
                    repo_id=repo_id,
                    filename="model_checkpoint_final.pt",
                    cache_dir=None
                )
                # PyTorch 2.6+ defaults to weights_only=True, so pass weights_only=False
                # for this pickled checkpoint (safe since we trust our own trained model)
                checkpoint = torch.load(model_path, map_location=device, weights_only=False)
                # Handle different checkpoint formats
                if 'model_state_dict' in checkpoint:
                    model.load_state_dict(checkpoint['model_state_dict'])
                elif 'state_dict' in checkpoint:
                    model.load_state_dict(checkpoint['state_dict'])
                else:
                    # If checkpoint is the state dict itself
                    model.load_state_dict(checkpoint)
                model_loaded = True
                print(f"✅ Model loaded successfully from HuggingFace Hub: {repo_id}")
        except ImportError:
            # safetensors not installed, use .pt file
            model_path = hf_hub_download(
                repo_id=repo_id,
                filename="model_checkpoint_final.pt",
                cache_dir=None
            )
            # PyTorch 2.6+ defaults to weights_only=True; this full checkpoint needs weights_only=False
            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
            # Handle different checkpoint formats
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            elif 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                # If checkpoint is the state dict itself
                model.load_state_dict(checkpoint)
            model_loaded = True
            print(f"✅ Model loaded successfully from HuggingFace Hub: {repo_id}")
    except Exception as e:
        print(f"⚠️ Could not load from Hub ({e}), trying local file...")
        try:
            # Fallback to local file
            # PyTorch 2.6+ defaults to weights_only=True; this full checkpoint needs weights_only=False
            checkpoint = torch.load('model_checkpoint_final.pt', map_location=device, weights_only=False)
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            elif 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                model.load_state_dict(checkpoint)
            model_loaded = True
            print("✅ Model loaded from local checkpoint")
        except Exception as e2:
            print(f"❌ Could not load from local file either: {e2}")
except FileNotFoundError:
    print("❌ Warning: Model checkpoint not found. Using untrained model.")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("⚠️ Using untrained model as fallback - output will be random!")

if not model_loaded:
    print("⚠️ WARNING: Model is using random weights! Generation will be nonsensical.")
    print("Please ensure model_checkpoint_final.pt is uploaded to HuggingFace Model Hub.")
model.to(device)
model.eval()
print(f"Model ready on {device}")
enc = tiktoken.get_encoding('gpt2')
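# tiktoken's 'gpt2' BPE matches the 50257-entry vocab in GPTConfig; enc.eot_token
# (50256, "<|endoftext|>") is the end-of-text id, though generation below stops
# only at max_new_tokens or block_size rather than on that token.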
def generate_text(prompt, max_new_tokens=100, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.1):
    """Generate text from prompt with improved sampling"""
    try:
        if not model_loaded:
            return "❌ Error: Model not loaded correctly. Please check that model_checkpoint_final.pt is uploaded to HuggingFace Model Hub (shwethd/gpt2-shakespeare-124m)."
        # Validate inputs
        if not prompt or len(prompt.strip()) == 0:
            return "Please enter a prompt."
        temperature = max(0.1, min(2.0, temperature))  # Clamp temperature
        top_k = max(1, min(100, int(top_k)))  # Clamp top_k
        top_p = max(0.1, min(1.0, float(top_p)))  # Clamp top_p (nucleus sampling)
        repetition_penalty = max(1.0, min(1.5, float(repetition_penalty)))  # Clamp repetition penalty
        max_new_tokens = max(1, min(200, int(max_new_tokens)))  # Clamp max tokens
        # Encode prompt
        tokens = enc.encode(prompt)
        if len(tokens) == 0:
            return "Error: Could not encode prompt."
        tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
        # Generate with improved sampling strategy
        with torch.no_grad():
            # Track recently generated tokens (ordered) for the repetition penalty
            recent_tokens = []
            for i in range(max_new_tokens):
                # Forward pass
                logits, _ = model(tokens)
                logits = logits[:, -1, :] / max(temperature, 0.1)  # Apply temperature
                # Apply repetition penalty to reduce loops
                if repetition_penalty > 1.0 and len(recent_tokens) > 0:
                    for token_id in set(recent_tokens):
                        if logits[0, token_id] > 0:
                            logits[0, token_id] /= repetition_penalty
                        else:
                            logits[0, token_id] *= repetition_penalty
                # Convert to probabilities
                probs = F.softmax(logits, dim=-1)
                # Apply top-p (nucleus) sampling first - often better than just top-k
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    # Remove tokens with cumulative probability above threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Keep at least one token
                    sorted_indices_to_remove[..., 0] = False
                    # Create mask
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs[indices_to_remove] = 0
                    # Renormalize
                    probs = probs / probs.sum()
                # Apply top-k filtering (after top-p for better quality)
                if top_k < logits.size(-1):
                    topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
                    # Create filtered probabilities
                    filtered_probs = torch.zeros_like(probs)
                    filtered_probs.scatter_(-1, topk_indices, topk_probs)
                    # Renormalize
                    filtered_probs = filtered_probs / filtered_probs.sum()
                    probs = filtered_probs
                # Avoid NaN or zero probabilities
                if torch.isnan(probs).any() or (probs.sum() == 0):
                    probs = torch.ones_like(probs) / probs.size(-1)
                # Sample from distribution
                next_token = torch.multinomial(probs, 1)
                # Update recent-token window for repetition penalty (keep the last 20 tokens)
                token_id = next_token.item()
                recent_tokens.append(token_id)
                if len(recent_tokens) > 20:
                    recent_tokens = recent_tokens[-20:]
                # Append to sequence
                tokens = torch.cat([tokens, next_token], dim=1)
                # Early stopping: stop if we generate end-of-text token (if present)
                # For GPT-2 tokenizer, we can check for certain patterns
                if tokens.size(1) >= config.block_size:
                    break
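        # Sampling order above: temperature scaling, optional repetition penalty on the
        # last 20 generated tokens, then top-p (nucleus) filtering followed by top-k,
        # renormalizing after each filter before the final multinomial draw.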
        # Decode
        generated_text = enc.decode(tokens[0].tolist())
        # Post-process to fix spacing issues (common with BPE tokenizers)
        import re
        # Fix 0: Remove the prompt from the beginning if it appears as a speaker name
        # This handles cases where user enters "Romeo and Juliet" and model treats it as speaker
        prompt_lower = prompt.lower().strip()
        generated_lower = generated_text.lower()
        # If prompt appears at the very start and looks like it was treated as a speaker
        if generated_lower.startswith(prompt_lower):
            # Check if it's followed by a newline (speaker format) or dialogue
            prompt_len = len(prompt)
            if len(generated_text) > prompt_len:
                next_chars = generated_text[prompt_len:prompt_len+5].strip()
                # If prompt is followed by newline or colon-like pattern, it was treated as speaker
                if not next_chars or ':' in next_chars or '\n' in generated_text[prompt_len:prompt_len+5]:
                    # Remove the prompt from output (it's the input, not part of generated story)
                    generated_text = generated_text[len(prompt):].strip()
                    # Remove leading newlines/colons
                    generated_text = re.sub(r'^[\s:]+', '', generated_text)
                    # Check if the first line after removal is orphaned dialogue (no speaker)
                    lines = generated_text.split('\n')
                    if lines and lines[0].strip():
                        first_line = lines[0].strip()
                        # If first line is not a speaker name and looks like dialogue, add a speaker
                        if not re.match(r'^([A-Z][A-Z\s]+?):\s*$', first_line):
                            # Check if it's dialogue-like (starts with capital, has punctuation)
                            if re.match(r'^[A-Z]', first_line) and ('.' in first_line or ',' in first_line or '!' in first_line or '?' in first_line):
                                # Add a generic speaker name based on the prompt context
                                # For story prompts like "Romeo and Juliet", use a character from the prompt
                                prompt_words = [w.capitalize() for w in prompt_lower.split() if len(w) > 2]
                                if len(prompt_words) >= 2:
                                    # Use first significant word as speaker (e.g., "Romeo" from "Romeo and Juliet")
                                    speaker_name = prompt_words[0].upper()
                                else:
                                    # Generic speaker
                                    speaker_name = "NARRATOR"
                                # Add speaker before the dialogue
                                generated_text = (f"{speaker_name}:\n{first_line}\n" + '\n'.join(lines[1:])) if len(lines) > 1 else f"{speaker_name}:\n{first_line}"
        # Fix 1: lowercase followed by uppercase (e.g., "perpetualWith" -> "perpetual With")
        generated_text = re.sub(r'([a-z])([A-Z])', r'\1 \2', generated_text)
        # Fix 2: Common word boundaries that got merged (e.g., "perpetualwith" -> "perpetual with")
        # Add space before common words that might have been merged
        common_words = ['with', 'the', 'and', 'that', 'this', 'have', 'from', 'not', 'but', 'for', 'are', 'was', 'were', 'been', 'will', 'shall', 'would', 'could', 'should', 'be', 'your', 'you', 'our', 'my', 'his', 'her', 'their', 'him', 'them']
        for word in common_words:
            # Only add space if it's not already separated and follows a lowercase letter
            pattern = r'([a-z])(' + word + r'\b)'
            generated_text = re.sub(pattern, r'\1 \2', generated_text, flags=re.IGNORECASE)
        # Fix 2b: Fix contractions that got merged (e.g., "You'llbe" -> "You'll be")
        # Add space after contractions before lowercase words
        contractions = ["'ll", "'ve", "'re", "'d", "'t", "'s", "'m"]
        for contraction in contractions:
            # Pattern: contraction followed by lowercase letter (e.g., "You'llbe" -> "You'll be")
            pattern = r"(" + re.escape(contraction) + r")([a-z])"
            generated_text = re.sub(pattern, r'\1 \2', generated_text, flags=re.IGNORECASE)
        # Fix 3: Add space before character names (all caps words)
        generated_text = re.sub(r'([a-z])([A-Z]{2,})', r'\1 \2', generated_text)
        # Fix 3b: Normalize speaker names (e.g., "Romeo and juliet" -> "ROMEO AND JULIET:")
        # Handle mixed case speaker names that should be all caps
        lines = generated_text.split('\n')
        normalized_lines = []
        for i, line in enumerate(lines):
            line_stripped = line.strip()
            # Check if line is a potential speaker name (title case or mixed case, 2+ words)
            # Pattern: "Romeo and juliet", "Romeo And Juliet", etc.
            speaker_pattern = r'^([A-Z][a-z]+(?:\s+[a-zA-Z]+)+)\s*:?\s*$'
            match = re.match(speaker_pattern, line_stripped)
            if match:
                # Check if next line is dialogue (not another speaker)
                is_speaker = False
                if i + 1 < len(lines):
                    next_line = lines[i + 1].strip()
                    # If next line is not empty and not a speaker name, this is likely a speaker
                    if next_line and not re.match(r'^([A-Z][A-Z\s]+?):\s*$', next_line):
                        is_speaker = True
                elif i == 0:  # First line is likely a speaker if it matches pattern
                    is_speaker = True
                if is_speaker:
                    # Convert to all caps and ensure colon
                    speaker_name = match.group(1).upper()
                    normalized_lines.append(speaker_name + ':')
                    continue
            normalized_lines.append(line)
        generated_text = '\n'.join(normalized_lines)
        # Fix 4: Remove duplicate speaker names (e.g., "EDWARD IV:\n...\nEDWARD IV:" -> keep only first)
        # More aggressive: remove same speaker if it appears within 3 lines (tighter window)
        lines = generated_text.split('\n')
        cleaned_lines = []
        speaker_history = []  # Track recent speakers with their line numbers
        for i, line in enumerate(lines):
            line_stripped = line.strip()
            # Check if this line is a speaker name
            speaker_match = re.match(r'^([A-Z][A-Z\s]+?):\s*$', line_stripped)
            if speaker_match:
                speaker = speaker_match.group(1).strip()
                # Check if this speaker appeared recently (within last 3 lines - more aggressive)
                recent_speaker = False
                for hist_speaker, _ in speaker_history[-3:]:
                    if speaker == hist_speaker:
                        recent_speaker = True
                        break
                if recent_speaker:
                    # Skip this duplicate speaker
                    continue
                # Add to history
                speaker_history.append((speaker, i))
                # Keep only last 10 speakers in history
                if len(speaker_history) > 10:
                    speaker_history.pop(0)
                cleaned_lines.append(line)
            else:
                cleaned_lines.append(line)
        generated_text = '\n'.join(cleaned_lines)
        # Fix 5: Remove speaker names with no dialogue (e.g., "KING:\nEDWARD IV:" -> "EDWARD IV:")
        # A speaker name should be followed by actual dialogue, not immediately by another speaker
        lines = generated_text.split('\n')
        final_lines = []
        for i, line in enumerate(lines):
            line_stripped = line.strip()
            speaker_match = re.match(r'^([A-Z][A-Z\s]+?):\s*$', line_stripped)
            if speaker_match:
                # Check if next non-empty line is another speaker (meaning this speaker has no dialogue)
                has_dialogue = False
                for j in range(i + 1, min(i + 3, len(lines))):  # Look ahead up to 2 lines
                    next_line = lines[j].strip()
                    if not next_line:  # Skip empty lines
                        continue
                    # If next non-empty line is NOT a speaker, we have dialogue
                    if not re.match(r'^([A-Z][A-Z\s]+?):\s*$', next_line):
                        has_dialogue = True
                        break
                    # If next non-empty line IS a speaker, this speaker has no dialogue
                    else:
                        # This speaker has no dialogue - skip it
                        break
                if not has_dialogue:
                    # This speaker has no dialogue, skip it
                    continue
            final_lines.append(line)
        generated_text = '\n'.join(final_lines)
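        # The regex heuristics in Fixes 0-5 above assume the play-script layout of the
        # training data (ALL-CAPS speaker name, colon, then dialogue lines) and are
        # best-effort cleanup; they may occasionally drop or rename legitimate lines.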
        # Fix 5b: Fix merged text issues (e.g., "You?A:" -> "You? A:")
        # Add space after question/exclamation marks before capital letters
        generated_text = re.sub(r'([?!])([A-Z])', r'\1 \2', generated_text)
        # Fix 6: Remove multiple empty lines between speaker and dialogue
        generated_text = re.sub(r'([A-Z][A-Z\s]+?):\s*\n\s*\n+', r'\1:\n', generated_text)
        # Fix 7: Remove any remaining consecutive duplicate speakers (final cleanup)
        generated_text = re.sub(
            r'^([A-Z][A-Z\s]+?):\s*\n\s*\n*\1:\s*\n',
            r'\1:\n',
            generated_text,
            flags=re.MULTILINE
        )
        return generated_text
    except Exception as e:
        # Log the full stack trace to the Space logs, then return a readable error
        import traceback
        print(traceback.format_exc())
        return f"❌ Error during generation: {str(e)}\n\nPlease check:\n1. Model is uploaded to HuggingFace Model Hub\n2. Repository name is correct: shwethd/gpt2-shakespeare-124m\n3. File name is exactly: model_checkpoint_final.pt"
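# Minimal usage sketch (hypothetical, for local testing outside the Gradio UI):
#   print(generate_text("ROMEO:", max_new_tokens=60, temperature=0.8))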
# Create Gradio interface
with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
    # Status indicator
    status_color = "🟢" if model_loaded else "🔴"
    status_text = "Model loaded successfully!" if model_loaded else "⚠️ Model not loaded - check HuggingFace Model Hub!"
    gr.Markdown(f"""
# 🎭 GPT-2 124M Shakespeare Language Model

{status_color} **Status:** {status_text}

This is a 124M parameter decoder-only transformer model trained on Shakespeare's complete works.

**Training Results:**
- Final Loss: 0.095127 (Target: < 0.099999) ✅
- Model Parameters: 124.44M
- Training Steps: 1,637

Enter a prompt below to generate Shakespeare-style text!

{"⚠️ **Note:** If you see garbled/random text, the model may not have loaded correctly. Check the logs and ensure the model is uploaded to HuggingFace Model Hub: `shwethd/gpt2-shakespeare-124m`" if not model_loaded else ""}
""")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here (e.g., 'First Citizen:', 'ROMEO:', 'To be or not')",
                value="First Citizen:",
                lines=3
            )
            max_tokens = gr.Slider(
                label="Max Tokens",
                minimum=50,
                maximum=200,
                value=100,
                step=10
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                info="Lower = more focused, Higher = more creative (0.7 recommended for better coherence)"
            )
            top_k = gr.Slider(
                label="Top-K",
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                info="Number of top tokens to consider"
            )
            top_p = gr.Slider(
                label="Top-P (Nucleus)",
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                info="Nucleus sampling - higher = more diverse, lower = more focused (0.9 recommended)"
            )
            repetition_penalty = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=1.5,
                value=1.1,
                step=0.05,
                info="Penalize repeated tokens - higher = less repetition (1.1 recommended)"
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=True,  # Make it interactive so users can select and copy
                show_copy_button=True  # Add copy button
            )
    # Example prompts
    gr.Markdown("### Example Prompts (Click to try):")
    examples = gr.Examples(
        examples=[
            ["First Citizen:"],
            ["ROMEO:"],
            ["To be or not"],
            ["HAMLET:"],
            ["MACBETH:"],
            ["JULIET:"],
            ["KING:"],
            ["LADY MACBETH:"],
            ["OTHELLO:"],
            ["What light through yonder"],
            ["All the world's a stage"],
            ["Double, double toil and trouble"],
            ["Friends, Romans, countrymen"],
            ["A rose by any other name"],
        ],
        inputs=prompt_input
    )
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens, temperature, top_k, top_p, repetition_penalty],
        outputs=output
    )
    gr.Markdown("""
---
**Note:** The model was trained on Shakespeare text and generates text in that style.
Generated text may not always be coherent but should follow Shakespearean patterns.
""")
if __name__ == "__main__":
    # Don't use share=True on HuggingFace Spaces
    demo.launch()