"""
Example Usage: Shakespeare Transformer

This script shows how to download and use the Shakespeare model from Hugging Face.

Usage:
    python example_usage.py
"""

import torch
import torch.nn as nn

print("="*70)
print("🎭 Shakespeare Transformer - Example Usage")
print("="*70)
print()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print()

print("📥 Method 1: Download from Hugging Face")
print("-"*70)
print()
print("To download the model:")
print()
print("from huggingface_hub import hf_hub_download")
print()
print("# Download model file")
print("model_path = hf_hub_download(")
print("    repo_id='YOUR-USERNAME/shakespeare-transformer-learning',")
print("    filename='best_model.pth'")
print(")")
print()
print("# Load the model")
print("checkpoint = torch.load(model_path, map_location=device)")
print()
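# The lines above only print instructions. To actually fetch the checkpoint
# from the Hub (assuming `huggingface_hub` is installed and the placeholder
# repo_id is replaced with the real repository), something like this would work:
#
#     from huggingface_hub import hf_hub_download
#     model_path = hf_hub_download(
#         repo_id='YOUR-USERNAME/shakespeare-transformer-learning',
#         filename='best_model.pth',
#     )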

print()
print("📁 Method 2: Use Local File")
print("-"*70)
print()
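

# CharTokenizer mirrors the character-level tokenizer used during training.
# It is defined here so that the tokenizer object pickled inside the checkpoint
# can be restored, and so prompts can be encoded and outputs decoded.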
class CharTokenizer:
    def __init__(self, text=None):
        if text is not None:
            self.chars = sorted(list(set(text)))
            self.vocab_size = len(self.chars)
            self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
            self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}
        else:
            self.chars = []
            self.vocab_size = 0
            self.char_to_idx = {}
            self.idx_to_char = {}

    def encode(self, text):
        return [self.char_to_idx[ch] for ch in text if ch in self.char_to_idx]

    def decode(self, indices):
        return ''.join([self.idx_to_char.get(i, '') for i in indices])
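

# Character-level language model: token embeddings plus learned positional
# embeddings feed a stack of TransformerEncoder layers that are run with a
# causal mask, so the model acts as a decoder-style next-character predictor.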
class TransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6, dropout=0.2, seq_length=128):
        super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Embedding(seq_length, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        batch_size, seq_len = x.shape
        token_emb = self.embedding(x)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.pos_encoding(positions)
        x = self.dropout(token_emb + pos_emb)
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        x = self.transformer(x, mask=mask, is_causal=True)
        logits = self.fc_out(x)
        return logits
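

# Load the checkpoint from the current directory. weights_only=False is needed
# because the checkpoint stores the pickled CharTokenizer alongside the weights.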
try:
    print("Loading model...")
    checkpoint = torch.load('best_model.pth', map_location=device, weights_only=False)
    tokenizer = checkpoint['tokenizer']

    model = TransformerLanguageModel(
        vocab_size=tokenizer.vocab_size,
        d_model=256,
        nhead=8,
        num_layers=6,
        dropout=0.2,
        seq_length=128
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("✅ Model loaded successfully!")
    print()

except FileNotFoundError:
    print("⚠️ best_model.pth not found in current directory")
    print("Please download it from Hugging Face first.")
    exit()


def generate_text(prompt, max_length=300, temperature=0.8):
    """
    Generate text from a prompt.

    Args:
        prompt: Starting text (e.g., "ROMEO:" or "To be or not to be")
        max_length: Maximum number of characters to generate
        temperature: Sampling temperature (higher = more random)

    Returns:
        Generated text as a string
    """
    model.eval()
    indices = tokenizer.encode(prompt) if prompt else [0]

    with torch.no_grad():
        for _ in range(max_length):
            # Use at most the last 128 characters as context (the model's sequence length)
            x = torch.tensor(indices[-128:], dtype=torch.long).unsqueeze(0).to(device)

            # Left-pad shorter contexts to the full sequence length
            if x.shape[1] < 128:
                padding = torch.zeros(1, 128 - x.shape[1], dtype=torch.long).to(device)
                x = torch.cat([padding, x], dim=1)

            # Temperature-scale the logits at the last position, then sample the next character
            logits = model(x)
            logits = logits[0, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1).item()
            indices.append(next_idx)

    return tokenizer.decode(indices)
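

# For a single ad-hoc generation, something like the following would also work
# (the prompt and settings here are just illustrative):
#
#     print(generate_text("JULIET:", max_length=150, temperature=0.7))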


print("🎬 Example Generations")
print("="*70)
print()

examples = [
    ("ROMEO:", "Character dialogue"),
    ("To be or not to be", "Famous quote continuation"),
    ("Once upon a time", "Story beginning"),
    ("", "Random generation"),
]

for prompt, description in examples:
    print(f"📝 {description}")
    print(f"Prompt: '{prompt}'")
    print("-"*70)

    generated = generate_text(prompt, max_length=200, temperature=0.8)

    # Truncate very long outputs for display
    display_text = generated[:300]
    if len(generated) > 300:
        display_text += "..."

    print(display_text)
    print()
    print("="*70)
    print()
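

# Simple interactive loop: read a prompt from stdin, generate a continuation,
# and repeat until the user types 'quit' or interrupts the program.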
print("🎮 Interactive Mode")
print("="*70)
print("Enter prompts to generate text. Type 'quit' to exit.")
print()

while True:
    try:
        prompt = input("\nEnter prompt (or 'quit'): ")

        if prompt.lower() in ['quit', 'exit', 'q']:
            print("Goodbye! 👋")
            break

        print("\nGenerating...")
        print("-"*70)

        generated = generate_text(prompt, max_length=300, temperature=0.8)
        print(generated[:400])

        print("-"*70)

    except (KeyboardInterrupt, EOFError):
        print("\n\nGoodbye! 👋")
        break


print()
print("💡 Tips for Best Results:")
print("="*70)
print()
print("1. Use character names as prompts: 'ROMEO:', 'JULIET:', etc.")
print("2. Start with famous quotes: 'To be or not to be'")
print("3. Try lower temperature (0.5) for more consistent text")
print("4. Try higher temperature (1.2) for more creative/random text")
print("5. This is a small educational model - expect imperfections!")
print()
print("🎭 Enjoy exploring Shakespeare-style text generation!")