File size: 7,587 Bytes
787565d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import gradio as gr
import torch
import torch.nn.functional as F
from train import ShakespeareModel, TextDataset

# Global variables to store model and dataset
# Both start as None. initialize_model() assigns them, and generate_text() /
# complete_sentence() call initialize_model() themselves if either is still None.
model = None
dataset = None

# Load the trained model and dataset once at startup
def initialize_model():
    """Populate the module-level ``model`` and ``dataset`` globals.

    Reads ``input.txt`` to build a ``TextDataset`` (block size 128), creates a
    ``ShakespeareModel`` sized by ``dataset.vocab_size`` on CUDA when it is
    available (CPU otherwise), restores the ``'model_state_dict'`` entry of
    ``shakespeare_model_best.pth`` and switches the model to eval mode.

    Returns:
        The ``(model, dataset)`` pair that was stored in the globals.
    """
    global model, dataset

    # The dataset is rebuilt from the raw corpus because the model's
    # constructor is sized from its vocabulary.
    with open('input.txt', 'r', encoding='utf-8') as corpus_file:
        corpus = corpus_file.read()
    dataset = TextDataset(corpus, block_size=128)

    # Prefer the GPU when one is present.
    if torch.cuda.is_available():
        target = torch.device('cuda')
    else:
        target = torch.device('cpu')

    network = ShakespeareModel(dataset.vocab_size)
    model = network.to(target)

    # map_location lets a checkpoint saved on another device load onto `target`.
    saved = torch.load('shakespeare_model_best.pth', map_location=target)
    model.load_state_dict(saved['model_state_dict'])
    model.eval()

    print("Model loaded successfully!")
    return model, dataset

def generate_text(prompt, max_length=200, temperature=0.8):
    """Continue ``prompt`` by sampling one character at a time from the model.

    Args:
        prompt: Seed text. Every character must be a key of ``dataset.stoi``.
        max_length: Upper bound on the number of characters to generate.
            Coerced with ``int()`` so a float (e.g. from a UI slider) is
            accepted by ``range``.
        temperature: Divisor applied to the last-step logits before softmax.
            Values at or below zero are clamped to a tiny positive number,
            which makes sampling effectively greedy instead of dividing by
            zero.

    Returns:
        The prompt followed by the generated characters, or an error message
        string when the prompt is empty or contains unknown characters.
        Generation stops early once the text contains more than two
        non-overlapping ``'\\n\\n'`` sequences.
    """
    global model, dataset

    if model is None or dataset is None:
        model, dataset = initialize_model()

    # An empty prompt would produce a zero-length context, and indexing the
    # last time step of the logits would then fail.
    if not prompt:
        return "Error: Prompt is empty. Please enter some text to start from."

    # NOTE(review): 128 mirrors the block_size passed to TextDataset in
    # initialize_model; presumably also the model's maximum context length -
    # confirm against train.py.
    block_size = 128
    max_length = int(max_length)
    temperature = max(float(temperature), 1e-6)

    device = next(model.parameters()).device

    try:
        token_ids = [dataset.stoi[c] for c in prompt]
    except KeyError:
        return "Error: Prompt contains characters not in the training dataset. Please use only standard characters."

    # Trim the prompt to the context window up front; previously only the
    # contexts built inside the loop were trimmed, so a long prompt reached
    # the model untruncated on the first step.
    context = torch.tensor(token_ids[-block_size:], dtype=torch.long).unsqueeze(0).to(device)

    generated_text = prompt

    with torch.no_grad():
        for _ in range(max_length):
            # Only the prediction for the final position is needed.
            logits = model(context)[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # Sample the next character id from the distribution.
            next_token = torch.multinomial(probs, num_samples=1)
            generated_text += dataset.itos[next_token.item()]

            # Slide the fixed-size context window forward by one token.
            context = torch.cat([context, next_token], dim=1)[:, -block_size:]

            # Stop if we generate a lot of blank lines (end of scene).
            if generated_text.count('\n\n') > 2:
                break

    return generated_text

def complete_sentence(prompt, num_words=5):
    """Extend ``prompt`` until roughly ``num_words`` more words were sampled.

    Words are counted approximately: each generated space character counts as
    one finished word. Sampling uses a fixed temperature of 0.7.

    Args:
        prompt: Seed text. Every character must be a key of ``dataset.stoi``.
        num_words: Number of space-terminated words to generate. Coerced with
            ``int()`` so a float (e.g. from a UI slider) is accepted.

    Returns:
        The prompt followed by the generated characters, or an error message
        string when the prompt is empty or contains unknown characters.
    """
    global model, dataset

    if model is None or dataset is None:
        model, dataset = initialize_model()

    # An empty prompt would produce a zero-length context, and indexing the
    # last time step of the logits would then fail.
    if not prompt:
        return "Error: Prompt is empty. Please enter some text to start from."

    # NOTE(review): 128 mirrors the block_size passed to TextDataset in
    # initialize_model; presumably also the model's maximum context length -
    # confirm against train.py.
    block_size = 128
    temperature = 0.7  # Lower temperature for more focused completion
    num_words = int(num_words)

    # Safety cap: the loop used to run until enough spaces appeared, which
    # never terminates if the model keeps emitting non-space characters.
    max_chars_per_word = 40
    max_chars = max(num_words, 0) * max_chars_per_word

    device = next(model.parameters()).device

    try:
        token_ids = [dataset.stoi[c] for c in prompt]
    except KeyError:
        return "Error: Prompt contains characters not in the training dataset. Please use only standard characters."

    # Trim the prompt to the context window before the first forward pass;
    # previously only contexts built inside the loop were trimmed.
    context = torch.tensor(token_ids[-block_size:], dtype=torch.long).unsqueeze(0).to(device)

    generated_text = prompt
    word_count = 0
    generated_chars = 0

    with torch.no_grad():
        while word_count < num_words and generated_chars < max_chars:
            # Only the prediction for the final position is needed.
            logits = model(context)[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # Sample the next character id from the distribution.
            next_token = torch.multinomial(probs, num_samples=1)
            next_char = dataset.itos[next_token.item()]
            generated_text += next_char
            generated_chars += 1

            # Count words (roughly) by counting spaces.
            if next_char == ' ':
                word_count += 1

            # Slide the fixed-size context window forward by one token.
            context = torch.cat([context, next_token], dim=1)[:, -block_size:]

    return generated_text

# Create Gradio interface
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching it.

    The layout has two tabs: "Generate Text" wires a prompt box plus length
    and temperature sliders to ``generate_text``; "Complete Sentence" wires a
    sentence box and a word-count slider to ``complete_sentence``. A tips
    section and clickable example prompts (which fill the first tab's prompt
    box) follow the tabs.

    Returns:
        The constructed ``gr.Blocks`` object.
    """
    # Static content is defined up front so the layout code below stays compact.
    tips_markdown = """
        ## Tips for better results:
        1. Start with a character name and a colon (e.g., "HAMLET:")
        2. Use proper names and places from Shakespeare's plays
        3. Try different temperatures for varying creativity levels
        4. Keep initial prompts relatively short (1-2 lines)
        """
    example_prompts = [
        ["HAMLET: To be, or not to be,"],
        ["MACBETH: Is this a dagger"],
        ["ROMEO: But, soft! what light through yonder"],
        ["PROSPERO: Our revels now are"],
    ]

    with gr.Blocks() as app:
        gr.Markdown("# Shakespeare Text Generator")
        gr.Markdown("Enter a prompt and the model will continue the text in Shakespeare's style.")

        # --- Tab 1: free-form generation ---
        with gr.Tab("Generate Text"):
            with gr.Row():
                with gr.Column():
                    prompt_box = gr.Textbox(
                        label="Enter your prompt",
                        placeholder="Enter a few words to start...",
                        lines=3,
                    )
                    temp_slider = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.8, step=0.1,
                        label="Temperature (Higher = more creative, Lower = more focused)",
                    )
                    length_slider = gr.Slider(
                        minimum=50, maximum=500, value=200, step=50,
                        label="Maximum length of generated text",
                    )
                    run_button = gr.Button("Generate")

                with gr.Column():
                    result_box = gr.Textbox(label="Generated Text", lines=10)

            # Input order matches generate_text(prompt, max_length, temperature).
            run_button.click(
                fn=generate_text,
                inputs=[prompt_box, length_slider, temp_slider],
                outputs=result_box,
            )

        # --- Tab 2: short sentence completion ---
        with gr.Tab("Complete Sentence"):
            with gr.Row():
                with gr.Column():
                    sentence_box = gr.Textbox(
                        label="Enter an incomplete sentence",
                        placeholder="Enter a sentence to complete...",
                        lines=2,
                    )
                    words_slider = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1,
                        label="Number of words to generate",
                    )
                    finish_button = gr.Button("Complete Sentence")

                with gr.Column():
                    completion_box = gr.Textbox(label="Completed Sentence", lines=5)

            finish_button.click(
                fn=complete_sentence,
                inputs=[sentence_box, words_slider],
                outputs=completion_box,
            )

        gr.Markdown(tips_markdown)

        # Example prompts populate the prompt box of the first tab.
        gr.Examples(examples=example_prompts, inputs=prompt_box)

    return app

# Initialize model at startup
# This call is outside the __main__ guard, so it also runs when the module is
# merely imported; 'input.txt' and 'shakespeare_model_best.pth' must be
# readable from the working directory at that point.
print("Initializing model...")
initialize_model()

# Launch the app
# The UI is only built and served when this file is executed as a script.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()