# shakespeareGPT / app.py
# Gradio demo app for a character-level Shakespeare text generator.
import gradio as gr
import torch
import torch.nn.functional as F
from train import ShakespeareModel, TextDataset
# Module-level cache for the model and dataset, populated lazily by
# initialize_model() and reused by every Gradio request handler.
model = None    # trained ShakespeareModel instance (None until loaded)
dataset = None  # TextDataset providing stoi/itos/vocab_size (None until loaded)
def initialize_model():
    """Load the corpus and the trained checkpoint once at startup.

    Rebuilds the TextDataset from ``input.txt`` (purely to recover the
    vocabulary: vocab_size / stoi / itos), restores the model weights from
    ``shakespeare_model_best.pth``, stores both in the module-level globals,
    and returns them.

    Returns:
        tuple: ``(model, dataset)`` — the eval-mode model and its dataset.
    """
    global model, dataset

    # Read the full training corpus; only the vocabulary is needed here.
    with open('input.txt', 'r', encoding='utf-8') as f:
        corpus = f.read()
    dataset = TextDataset(corpus, block_size=128)

    # Build the model on GPU when available, CPU otherwise.
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ShakespeareModel(dataset.vocab_size).to(run_device)

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    state = torch.load('shakespeare_model_best.pth', map_location=run_device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    print("Model loaded successfully!")
    return model, dataset
def generate_text(prompt, max_length=200, temperature=0.8):
    """Continue `prompt` with up to `max_length` sampled characters.

    Args:
        prompt: Seed text; every character must exist in the training vocab.
        max_length: Maximum number of characters to generate. Gradio sliders
            deliver floats, so this is coerced to int (the original code
            crashed in ``range()`` on a float).
        temperature: Softmax temperature; higher is more random. Clamped to a
            small positive value to avoid division by zero.

    Returns:
        The prompt plus the generated continuation, or an "Error: ..." string
        for an empty prompt / out-of-vocabulary characters.
    """
    global model, dataset
    if model is None or dataset is None:
        model, dataset = initialize_model()
    device = next(model.parameters()).device

    # An empty prompt would produce a (1, 0) context tensor and crash the
    # model forward pass — reject it up front.
    if not prompt:
        return "Error: Prompt is empty. Please enter some text to continue."

    # Gradio Slider values arrive as floats; range() requires an int.
    max_length = int(max_length)
    # Guard against zero/negative temperature from a programmatic caller.
    temperature = max(float(temperature), 1e-4)

    try:
        # Encode the prompt as a (1, len) batch of vocab indices.
        context = torch.tensor([dataset.stoi[c] for c in prompt],
                               dtype=torch.long).unsqueeze(0).to(device)
    except KeyError:
        return "Error: Prompt contains characters not in the training dataset. Please use only standard characters."

    generated_text = prompt
    with torch.no_grad():
        for _ in range(max_length):
            # Predict the next-token distribution from the last position.
            logits = model(context)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            # Sample one token from the distribution.
            next_token = torch.multinomial(probs, num_samples=1)
            next_char = dataset.itos[next_token.item()]
            generated_text += next_char
            # Slide the context window, keeping at most the model's
            # block_size (128) most recent tokens.
            context = torch.cat([context, next_token], dim=1)
            if context.size(1) > 128:
                context = context[:, -128:]
            # Stop early once several blank lines appear (end of scene).
            if generated_text.count('\n\n') > 2:
                break
    return generated_text
def complete_sentence(prompt, num_words=5):
global model, dataset
if model is None or dataset is None:
model, dataset = initialize_model()
device = next(model.parameters()).device
try:
# Convert prompt to tensor
context = torch.tensor([dataset.stoi[c] for c in prompt], dtype=torch.long).unsqueeze(0).to(device)
except KeyError:
return "Error: Prompt contains characters not in the training dataset. Please use only standard characters."
generated_text = prompt
word_count = 0
with torch.no_grad():
while word_count < num_words:
# Get model predictions
logits = model(context)
logits = logits[:, -1, :] / 0.7 # Lower temperature for more focused completion
probs = F.softmax(logits, dim=-1)
# Sample from the distribution
next_token = torch.multinomial(probs, num_samples=1)
# Convert to character and append to generated text
next_char = dataset.itos[next_token.item()]
generated_text += next_char
# Count words (roughly) by counting spaces
if next_char == ' ':
word_count += 1
# Update context for next prediction
context = torch.cat([context, next_token], dim=1)
if context.size(1) > 128:
context = context[:, -128:]
return generated_text
# Create Gradio interface
def create_interface():
    """Build the two-tab Gradio UI: free-form generation and sentence completion.

    Returns the gr.Blocks app; the caller is responsible for launching it.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Shakespeare Text Generator")
        gr.Markdown("Enter a prompt and the model will continue the text in Shakespeare's style.")
        # --- Tab 1: open-ended generation with tunable sampling controls ---
        with gr.Tab("Generate Text"):
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(
                        label="Enter your prompt",
                        placeholder="Enter a few words to start...",
                        lines=3
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature (Higher = more creative, Lower = more focused)"
                    )
                    max_length = gr.Slider(
                        minimum=50,
                        maximum=500,
                        value=200,
                        step=50,
                        label="Maximum length of generated text"
                    )
                    generate_button = gr.Button("Generate")
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Generated Text",
                        lines=10
                    )
            # Wire the button to generate_text; slider values are passed
            # positionally as (prompt, max_length, temperature).
            generate_button.click(
                fn=generate_text,
                inputs=[input_text, max_length, temperature],
                outputs=output_text
            )
        # --- Tab 2: short completion with a fixed internal temperature ---
        with gr.Tab("Complete Sentence"):
            with gr.Row():
                with gr.Column():
                    sentence_input = gr.Textbox(
                        label="Enter an incomplete sentence",
                        placeholder="Enter a sentence to complete...",
                        lines=2
                    )
                    num_words = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Number of words to generate"
                    )
                    complete_button = gr.Button("Complete Sentence")
                with gr.Column():
                    completed_text = gr.Textbox(
                        label="Completed Sentence",
                        lines=5
                    )
            complete_button.click(
                fn=complete_sentence,
                inputs=[sentence_input, num_words],
                outputs=completed_text
            )
        gr.Markdown("""
        ## Tips for better results:
        1. Start with a character name and a colon (e.g., "HAMLET:")
        2. Use proper names and places from Shakespeare's plays
        3. Try different temperatures for varying creativity levels
        4. Keep initial prompts relatively short (1-2 lines)
        """)
        # Add some example prompts (clicking one fills the Generate tab's
        # prompt box).
        gr.Examples(
            examples=[
                ["HAMLET: To be, or not to be,"],
                ["MACBETH: Is this a dagger"],
                ["ROMEO: But, soft! what light through yonder"],
                ["PROSPERO: Our revels now are"],
            ],
            inputs=input_text
        )
    return demo
# Initialize model at startup.
# NOTE(review): this runs at import time, not under the __main__ guard —
# presumably so hosting platforms that import this module get a warm model;
# confirm before moving it.
print("Initializing model...")
initialize_model()
# Launch the app only when executed directly.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()