# Uploaded to Hugging Face by gopi87
# Commit message: "Upload 6 files"
# Commit e8bf402 (verified)
#!/usr/bin/env python3
"""
Example Usage: Shakespeare Transformer
This script shows how to download and use the Shakespeare model from Hugging Face.
Usage:
python example_usage.py
"""
import torch
import torch.nn as nn
# Banner. The emoji was mis-decoded (UTF-8 bytes read with a legacy code page)
# in the source text; it is restored here.
print("="*70)
print("🎭 Shakespeare Transformer - Example Usage")
print("="*70)
print()
# Check device: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print()
# ============================================
# METHOD 1: Download from Hugging Face
# ============================================
# Print copy-pasteable instructions for fetching the checkpoint from the Hub.
# Nothing is downloaded here; these lines are only shown to the user.
print("📥 Method 1: Download from Hugging Face")
print("-"*70)
print()
print("To download the model:")
print()
for snippet_line in (
    "from huggingface_hub import hf_hub_download",
    "",
    "# Download model file",
    "model_path = hf_hub_download(",
    " repo_id='YOUR-USERNAME/shakespeare-transformer-learning',",
    " filename='best_model.pth'",
    ")",
    "",
    "# Load the model",
    # The checkpoint holds a pickled tokenizer object, so the snippet must
    # pass weights_only=False exactly like the real load later in this script;
    # recent PyTorch releases default to weights-only loading and would refuse it.
    "checkpoint = torch.load(model_path, map_location=device, weights_only=False)",
    "",
):
    print(snippet_line)
# ============================================
# METHOD 2: Use Local File
# ============================================
# Section header for loading the checkpoint from the current directory.
# The folder emoji was mis-decoded in the source text and is restored here.
print()
print("📂 Method 2: Use Local File")
print("-"*70)
print()
# Define the CharTokenizer class (needed for loading)
class CharTokenizer:
    """Character-level tokenizer: one integer id per distinct character.

    The script reads an instance of this class out of the checkpoint
    (``checkpoint['tokenizer']``), so the class name and the attribute names
    ``chars``, ``vocab_size``, ``char_to_idx`` and ``idx_to_char`` must stay
    as they are for saved instances to keep working.
    """

    def __init__(self, text=None):
        """Build the vocabulary from ``text``.

        Args:
            text: Corpus whose distinct characters form the vocabulary, in
                sorted order. ``None`` (the default) creates an empty
                tokenizer.
        """
        # sorted() accepts any iterable, so no intermediate list is needed.
        self.chars = sorted(set(text)) if text is not None else []
        self.vocab_size = len(self.chars)
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = dict(enumerate(self.chars))

    def encode(self, text):
        """Return the ids of the characters in ``text``.

        Characters that are not in the vocabulary are skipped, so the result
        may be shorter than ``text``.
        """
        return [self.char_to_idx[ch] for ch in text if ch in self.char_to_idx]

    def decode(self, indices):
        """Return the string for ``indices``; unknown ids contribute nothing."""
        return ''.join(self.idx_to_char.get(i, '') for i in indices)
# Define the model architecture
class TransformerLanguageModel(nn.Module):
    """Causal character-level language model built on ``nn.TransformerEncoder``.

    Token embeddings are summed with learned positional embeddings, passed
    through a stack of encoder layers under a causal mask, and projected to
    vocabulary logits. Submodule names (``embedding``, ``pos_encoding``,
    ``transformer``, ``dropout``, ``fc_out``) define the ``state_dict`` keys
    and must not be renamed if saved weights are to load.
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6, dropout=0.2, seq_length=128):
        """Create the model.

        Args:
            vocab_size: Number of distinct token ids.
            d_model: Embedding / hidden width.
            nhead: Attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
            dropout: Dropout probability used after the embeddings and inside
                the encoder layers.
            seq_length: Number of positions in the learned positional table,
                i.e. the longest sequence ``forward`` accepts.
        """
        super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned (not sinusoidal) positions: one row per index in [0, seq_length).
        self.pos_encoding = nn.Embedding(seq_length, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Compute next-token logits for every position.

        Args:
            x: Integer tensor of token ids with shape ``(batch, seq_len)``,
                where ``seq_len <= seq_length``.

        Returns:
            Tensor of logits with shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table size.
        """
        batch_size, seq_len = x.shape
        # Without this check an over-long input fails inside the positional
        # embedding lookup with an opaque out-of-range index error.
        if seq_len > self.seq_length:
            raise ValueError(
                f"Input length {seq_len} exceeds the model's maximum "
                f"sequence length {self.seq_length}"
            )
        token_emb = self.embedding(x)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.pos_encoding(positions)
        x = self.dropout(token_emb + pos_emb)
        # Causal mask so each position only attends to itself and earlier ones.
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        x = self.transformer(x, mask=mask, is_causal=True)
        logits = self.fc_out(x)
        return logits
# Load model
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects (it is needed here because the checkpoint stores the tokenizer
# object). Only load checkpoint files you trust.
try:
    print("Loading model...")
    checkpoint = torch.load('best_model.pth', map_location=device, weights_only=False)
except FileNotFoundError:
    print("⚠️ best_model.pth not found in current directory")
    print("Please download it from Hugging Face first.")
    # Exit with a non-zero status; the bare exit() helper comes from the
    # `site` module and is not guaranteed to exist in every interpreter setup.
    raise SystemExit(1)

# The tokenizer saved with the checkpoint determines the vocabulary size.
tokenizer = checkpoint['tokenizer']
# These hyperparameters must match the ones the saved weights were trained
# with, otherwise load_state_dict fails on mismatched shapes.
model = TransformerLanguageModel(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=8,
    num_layers=6,
    dropout=0.2,
    seq_length=128
).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode: disables dropout
print("✓ Model loaded successfully!")
print()
# ============================================
# GENERATION FUNCTION
# ============================================
def generate_text(prompt, max_length=300, temperature=0.8):
    """
    Generate text from a prompt by sampling one character at a time.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        prompt: Starting text (e.g., "ROMEO:" or "To be or not to be").
            Characters unknown to the tokenizer are dropped. An empty
            prompt starts generation from token id 0.
        max_length: Maximum number of characters to generate
        temperature: Sampling temperature (higher = more random). A value
            of 0 or less selects the most likely character at every step
            (greedy decoding) instead of dividing by zero.

    Returns:
        Generated text as string (the encoded prompt followed by the
        generated characters)
    """
    model.eval()
    # Take the context window from the model instead of hard-coding 128 so the
    # function keeps working if the model is built with another seq_length.
    context_len = model.seq_length
    indices = tokenizer.encode(prompt) if prompt else [0]
    with torch.no_grad():
        for _ in range(max_length):
            # Keep only the most recent context_len characters. If the prompt
            # contained no known characters, seed the model with token id 0.
            context = indices[-context_len:] or [0]
            # No padding: the model is causal and accepts shorter sequences.
            # Left-padding with id 0 would make it attend to made-up context,
            # because id 0 is an ordinary vocabulary entry, not a pad token.
            x = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
            # Logits for the character following the last position
            logits = model(x)[0, -1, :]
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_idx = torch.multinomial(probs, num_samples=1).item()
            else:
                next_idx = torch.argmax(logits).item()
            indices.append(next_idx)
    return tokenizer.decode(indices)
# ============================================
# EXAMPLE GENERATIONS
# ============================================
# Run a few fixed prompts through generate_text and print the results.
# Emoji in the strings were mis-decoded in the source text and are restored.
print("🎬 Example Generations")
print("="*70)
print()
# (prompt, human-readable description) pairs; the empty prompt lets the model
# start from scratch.
examples = [
    ("ROMEO:", "Character dialogue"),
    ("To be or not to be", "Famous quote continuation"),
    ("Once upon a time", "Story beginning"),
    ("", "Random generation"),
]
for prompt, description in examples:
    print(f"📝 {description}")
    print(f"Prompt: '{prompt}'")
    print("-"*70)
    generated = generate_text(prompt, max_length=200, temperature=0.8)
    # Show first 300 characters
    display_text = generated[:300]
    if len(generated) > 300:
        display_text += "..."
    print(display_text)
    print()
    # NOTE(review): the source's indentation was lost, so it is assumed that
    # this separator closes each example rather than the whole section - confirm.
    print("="*70)
    print()
# ============================================
# INTERACTIVE MODE
# ============================================
# Simple read-generate-print loop. Emoji in the strings were mis-decoded in
# the source text and are restored.
print("🎮 Interactive Mode")
print("="*70)
print("Enter prompts to generate text. Type 'quit' to exit.")
print()
while True:
    try:
        prompt = input("\nEnter prompt (or 'quit'): ")
        # Ignore surrounding whitespace and case when checking for a quit command.
        if prompt.strip().lower() in ('quit', 'exit', 'q'):
            print("Goodbye! 👋")
            break
        print("\nGenerating...")
        print("-"*70)
        generated = generate_text(prompt, max_length=300, temperature=0.8)
        print(generated[:400])  # Show first 400 characters
        print("-"*70)
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C, or end of input (Ctrl-D / piped stdin exhausted): leave the
        # loop cleanly instead of ending with a traceback.
        print("\n\nGoodbye! 👋")
        break
# ============================================
# TIPS
# ============================================
# Closing usage tips. Emoji in the strings were mis-decoded in the source
# text and are restored.
print()
print("💡 Tips for Best Results:")
print("="*70)
print()
print("1. Use character names as prompts: 'ROMEO:', 'JULIET:', etc.")
print("2. Start with famous quotes: 'To be or not to be'")
print("3. Try lower temperature (0.5) for more consistent text")
print("4. Try higher temperature (1.2) for more creative/random text")
print("5. This is a small educational model - expect imperfections!")
print()
print("🎭 Enjoy exploring Shakespeare-style text generation!")