"""Gradio web UI for a character-level Shakespeare text generator.

Loads a trained ``ShakespeareModel`` checkpoint plus the character
vocabulary rebuilt from ``input.txt``, and exposes two tools: free-form
text generation and short sentence completion.
"""

import gradio as gr
import torch
import torch.nn.functional as F

from train import ShakespeareModel, TextDataset

# Context window the model was trained with; also used to truncate the
# rolling sampling context. (Assumes the checkpoint was trained with
# block_size=128 — TODO confirm against train.py.)
BLOCK_SIZE = 128

# Lazily-populated globals shared by the Gradio callbacks.
model = None
dataset = None


def initialize_model():
    """Load the vocabulary and trained weights into the module globals.

    Rebuilds ``TextDataset`` from ``input.txt`` only to recover the
    character mappings (``stoi``/``itos``) the checkpoint was trained
    against, then restores the model weights and switches to eval mode.

    Returns:
        (model, dataset): the freshly loaded objects (also stored in the
        module-level globals for reuse by the callbacks).
    """
    global model, dataset

    with open('input.txt', 'r', encoding='utf-8') as f:
        text = f.read()
    dataset = TextDataset(text, block_size=BLOCK_SIZE)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ShakespeareModel(dataset.vocab_size).to(device)

    checkpoint = torch.load('shakespeare_model_best.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("Model loaded successfully!")
    return model, dataset


def _ensure_loaded():
    """Initialize the global model/dataset on first use (lazy fallback)."""
    global model, dataset
    if model is None or dataset is None:
        model, dataset = initialize_model()


def _encode_prompt(prompt):
    """Encode *prompt* as a (1, len) index tensor on the model's device.

    Returns ``None`` when the prompt contains characters outside the
    training vocabulary (the callers turn that into a user-facing error).
    """
    device = next(model.parameters()).device
    try:
        ids = [dataset.stoi[c] for c in prompt]
    except KeyError:
        return None
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)


# User-facing message for out-of-vocabulary prompts (wording preserved
# from the original implementation).
_VOCAB_ERROR = ("Error: Prompt contains characters not in the training "
                "dataset. Please use only standard characters.")


def _sample_next(context, temperature):
    """Sample one token id from the model given *context*.

    Applies temperature scaling to the last-position logits and draws
    from the resulting categorical distribution.
    """
    logits = model(context)
    logits = logits[:, -1, :] / temperature
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def _advance_context(context, next_token):
    """Append *next_token* to *context*, keeping at most BLOCK_SIZE tokens."""
    context = torch.cat([context, next_token], dim=1)
    if context.size(1) > BLOCK_SIZE:
        context = context[:, -BLOCK_SIZE:]
    return context


def generate_text(prompt, max_length=200, temperature=0.8):
    """Continue *prompt* with up to *max_length* sampled characters.

    Args:
        prompt: seed text; every character must be in the training vocab.
        max_length: maximum number of characters to generate (Gradio may
            pass this as a float, so it is coerced to int).
        temperature: softmax temperature; higher is more random.

    Returns:
        The prompt plus the generated continuation, or an error string.
    """
    _ensure_loaded()

    if not prompt:
        return "Error: Please enter a prompt."

    context = _encode_prompt(prompt)
    if context is None:
        return _VOCAB_ERROR

    generated_text = prompt
    with torch.no_grad():
        # int() guards against Gradio sliders delivering floats
        # (range() rejects them).
        for _ in range(int(max_length)):
            next_token = _sample_next(context, temperature)
            generated_text += dataset.itos[next_token.item()]
            context = _advance_context(context, next_token)
            # Heuristic stop: several blank lines usually mark a scene end.
            if generated_text.count('\n\n') > 2:
                break
    return generated_text


def complete_sentence(prompt, num_words=5):
    """Extend *prompt* by roughly *num_words* words.

    Words are counted by sampled spaces, so the count is approximate. A
    hard cap on generated characters prevents the loop from hanging if
    the model never emits a space.

    Args:
        prompt: seed text; every character must be in the training vocab.
        num_words: approximate number of words to add (coerced to int).

    Returns:
        The prompt plus the completion, or an error string.
    """
    _ensure_loaded()

    if not prompt:
        return "Error: Please enter a prompt."

    context = _encode_prompt(prompt)
    if context is None:
        return _VOCAB_ERROR

    num_words = int(num_words)
    # Safety cap: generously assume words average well under 50 chars;
    # without it, a space-free sampling run would loop forever.
    max_chars = max(num_words, 1) * 50

    generated_text = prompt
    word_count = 0
    with torch.no_grad():
        for _ in range(max_chars):
            if word_count >= num_words:
                break
            # Fixed 0.7 temperature: completion favors focus over variety.
            next_token = _sample_next(context, 0.7)
            next_char = dataset.itos[next_token.item()]
            generated_text += next_char
            # Rough word counting: each sampled space ends a word.
            if next_char == ' ':
                word_count += 1
            context = _advance_context(context, next_token)
    return generated_text


def create_interface():
    """Build and return the two-tab Gradio Blocks interface."""
    with gr.Blocks() as demo:
        gr.Markdown("# Shakespeare Text Generator")
        gr.Markdown("Enter a prompt and the model will continue the text in Shakespeare's style.")

        with gr.Tab("Generate Text"):
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(
                        label="Enter your prompt",
                        placeholder="Enter a few words to start...",
                        lines=3
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature (Higher = more creative, Lower = more focused)"
                    )
                    max_length = gr.Slider(
                        minimum=50,
                        maximum=500,
                        value=200,
                        step=50,
                        label="Maximum length of generated text"
                    )
                    generate_button = gr.Button("Generate")
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Generated Text",
                        lines=10
                    )
            generate_button.click(
                fn=generate_text,
                inputs=[input_text, max_length, temperature],
                outputs=output_text
            )

        with gr.Tab("Complete Sentence"):
            with gr.Row():
                with gr.Column():
                    sentence_input = gr.Textbox(
                        label="Enter an incomplete sentence",
                        placeholder="Enter a sentence to complete...",
                        lines=2
                    )
                    num_words = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Number of words to generate"
                    )
                    complete_button = gr.Button("Complete Sentence")
                with gr.Column():
                    completed_text = gr.Textbox(
                        label="Completed Sentence",
                        lines=5
                    )
            complete_button.click(
                fn=complete_sentence,
                inputs=[sentence_input, num_words],
                outputs=completed_text
            )

        gr.Markdown("""
        ## Tips for better results:
        1. Start with a character name and a colon (e.g., "HAMLET:")
        2. Use proper names and places from Shakespeare's plays
        3. Try different temperatures for varying creativity levels
        4. Keep initial prompts relatively short (1-2 lines)
        """)

        gr.Examples(
            examples=[
                ["HAMLET: To be, or not to be,"],
                ["MACBETH: Is this a dagger"],
                ["ROMEO: But, soft! what light through yonder"],
                ["PROSPERO: Our revels now are"],
            ],
            inputs=input_text
        )

    return demo


# Eagerly load the model at import time so the first request is fast; the
# callbacks still lazy-load as a fallback.
print("Initializing model...")
initialize_model()

if __name__ == "__main__":
    demo = create_interface()
    demo.launch()