#!/usr/bin/env python3
"""
Example Usage: Shakespeare Transformer

This script shows how to download and use the Shakespeare model
from Hugging Face.

Usage:
    python example_usage.py
"""

import torch
import torch.nn as nn

print("="*70)
print("🎭 Shakespeare Transformer - Example Usage")
print("="*70)
print()

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print()

# ============================================
# METHOD 1: Download from Hugging Face
# ============================================
print("📥 Method 1: Download from Hugging Face")
print("-"*70)
print()
print("To download the model:")
print()
print("from huggingface_hub import hf_hub_download")
print()
print("# Download model file")
print("model_path = hf_hub_download(")
print("    repo_id='YOUR-USERNAME/shakespeare-transformer-learning',")
print("    filename='best_model.pth'")
print(")")
print()
print("# Load the model")
print("checkpoint = torch.load(model_path, map_location=device)")
print()

# ============================================
# METHOD 2: Use Local File
# ============================================
print()
print("📂 Method 2: Use Local File")
print("-"*70)
print()


# Define the CharTokenizer class (needed for loading)
class CharTokenizer:
    def __init__(self, text=None):
        if text is not None:
            self.chars = sorted(list(set(text)))
            self.vocab_size = len(self.chars)
            self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
            self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}
        else:
            self.chars = []
            self.vocab_size = 0
            self.char_to_idx = {}
            self.idx_to_char = {}

    def encode(self, text):
        return [self.char_to_idx[ch] for ch in text if ch in self.char_to_idx]

    def decode(self, indices):
        return ''.join([self.idx_to_char.get(i, '') for i in indices])


# Define the model architecture
class TransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6,
                 dropout=0.2, seq_length=128):
        super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Embedding(seq_length, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        batch_size, seq_len = x.shape

        token_emb = self.embedding(x)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.pos_encoding(positions)
        x = self.dropout(token_emb + pos_emb)

        mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        x = self.transformer(x, mask=mask, is_causal=True)

        logits = self.fc_out(x)
        return logits


# Load model
try:
    print("Loading model...")
    checkpoint = torch.load('best_model.pth', map_location=device, weights_only=False)

    tokenizer = checkpoint['tokenizer']
    model = TransformerLanguageModel(
        vocab_size=tokenizer.vocab_size,
        d_model=256,
        nhead=8,
        num_layers=6,
        dropout=0.2,
        seq_length=128
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("✓ Model loaded successfully!")
    print()
except FileNotFoundError:
    print("⚠️ best_model.pth not found in current directory")
    print("Please download it from Hugging Face first.")
    exit()
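# ------------------------------------------------------------------
# Optional sanity check (a small addition, not part of the original
# script): report how many trainable parameters the loaded model has,
# which helps set expectations for this small educational model.
# ------------------------------------------------------------------
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")
print()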
# ============================================
# GENERATION FUNCTION
# ============================================
def generate_text(prompt, max_length=300, temperature=0.8):
    """
    Generate text from a prompt

    Args:
        prompt: Starting text (e.g., "ROMEO:" or "To be or not to be")
        max_length: Maximum number of characters to generate
        temperature: Sampling temperature (higher = more random)

    Returns:
        Generated text as string
    """
    model.eval()
    indices = tokenizer.encode(prompt) if prompt else [0]

    with torch.no_grad():
        for _ in range(max_length):
            # Get last seq_length characters
            x = torch.tensor(indices[-128:], dtype=torch.long).unsqueeze(0).to(device)

            # Pad if needed
            if x.shape[1] < 128:
                padding = torch.zeros(1, 128 - x.shape[1], dtype=torch.long).to(device)
                x = torch.cat([padding, x], dim=1)

            # Generate next character
            logits = model(x)
            logits = logits[0, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1).item()
            indices.append(next_idx)

    return tokenizer.decode(indices)


# ============================================
# EXAMPLE GENERATIONS
# ============================================
print("🎬 Example Generations")
print("="*70)
print()

examples = [
    ("ROMEO:", "Character dialogue"),
    ("To be or not to be", "Famous quote continuation"),
    ("Once upon a time", "Story beginning"),
    ("", "Random generation"),
]

for prompt, description in examples:
    print(f"📝 {description}")
    print(f"Prompt: '{prompt}'")
    print("-"*70)

    generated = generate_text(prompt, max_length=200, temperature=0.8)

    # Show first 300 characters
    display_text = generated[:300]
    if len(generated) > 300:
        display_text += "..."

    print(display_text)
    print()
    print("="*70)
    print()

# ============================================
# INTERACTIVE MODE
# ============================================
print("🎮 Interactive Mode")
print("="*70)
print("Enter prompts to generate text. Type 'quit' to exit.")
print()

while True:
    try:
        prompt = input("\nEnter prompt (or 'quit'): ")

        if prompt.lower() in ['quit', 'exit', 'q']:
            print("Goodbye! 👋")
            break

        print("\nGenerating...")
        print("-"*70)

        generated = generate_text(prompt, max_length=300, temperature=0.8)
        print(generated[:400])  # Show first 400 characters
        print("-"*70)

    except KeyboardInterrupt:
        print("\n\nGoodbye! 👋")
        break

# ============================================
# TIPS
# ============================================
print()
print("💡 Tips for Best Results:")
print("="*70)
print()
print("1. Use character names as prompts: 'ROMEO:', 'JULIET:', etc.")
print("2. Start with famous quotes: 'To be or not to be'")
print("3. Try lower temperature (0.5) for more consistent text")
print("4. Try higher temperature (1.2) for more creative/random text")
print("5. This is a small educational model - expect imperfections!")
print()
print("🎭 Enjoy exploring Shakespeare-style text generation!")
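
# ------------------------------------------------------------------
# Optional extra (an illustrative sketch, not part of the original
# script): tips 3 and 4 above suggest varying the sampling temperature.
# The loop below generates from the same prompt at a few temperatures
# so the effect is easy to compare side by side.
# ------------------------------------------------------------------
for temp in (0.5, 0.8, 1.2):
    print()
    print(f"--- Prompt: 'ROMEO:' | temperature={temp} ---")
    print(generate_text("ROMEO:", max_length=200, temperature=temp)[:300])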