import torch
import gradio as gr
import tiktoken
import os
from torch.nn import functional as F
from model import GPT, GPTConfig


# The model is loaded lazily on the first request and cached in this module-level global
model = None


def initialize_model():
    """Load the model weights once and cache them for reuse across requests."""
    global model
    if model is None:
        model_path = 'model/model_state_dict.pth'
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")

        try:
            model = GPT(GPTConfig())
            model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            model.eval()
        except Exception as e:
            raise RuntimeError(f"Error loading model: {e}") from e
    return model


def generate_shakespeare(prompt, max_length=100, temperature=0.8):
| """Generate Shakespeare-style text from a prompt""" |
| try: |
| |
| model = initialize_model() |
| |
| |
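        # Tokenize the prompt with the GPT-2 BPE tokenizer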
        enc = tiktoken.get_encoding('gpt2')
        prompt_tokens = enc.encode(prompt)
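        # Reject prompts that do not fit in the model's context window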
        if len(prompt_tokens) > model.config.block_size:
            return f"Prompt too long. Please limit to {model.config.block_size} tokens."
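        # Batch of one sequence, shape (1, T)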
        x = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0)

        # Token id that marks the end of a line; generation stops once it is sampled
        newline_token = enc.encode('\n')[0]
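        # Autoregressive sampling loop (no gradients needed at inference time)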
        with torch.no_grad():
            # Cap the total length (prompt included) at the model's context window
            max_total = min(int(max_length), model.config.block_size)
            while x.size(1) < max_total:
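                # Forward pass; keep only the logits for the final position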
                logits, _ = model(x)
                logits = logits[:, -1, :] / temperature
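                # Convert the temperature-scaled logits to probabilities and sample one token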
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
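                # Append the sampled token to the running sequence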
                x = torch.cat((x, next_token), dim=1)
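                # Stop early once a newline is generated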
                if next_token.item() == newline_token:
                    break
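        # Decode the full sequence (prompt plus continuation) back to text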
        generated_tokens = x[0].tolist()
        generated_text = enc.decode(generated_tokens)

        return generated_text

    except Exception as e:
        return f"Error generating text: {str(e)}"


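# Gradio UI: a prompt textbox plus sliders controlling max length and temperature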
demo = gr.Interface(
    fn=generate_shakespeare,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Enter some Shakespeare-style text...",
            lines=2
        ),
        gr.Slider(
            minimum=10,
            maximum=200,
            value=100,
            step=1,
            label="Max Length",
| info="Maximum length of generated text" |
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.8,
            step=0.1,
            label="Temperature",
| info="Higher values make the output more random, lower values make it more focused" |
        )
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="Shakespeare Text Generator",
    description="""Generate Shakespeare-style text based on your prompt using a fine-tuned GPT model.
    Enter a prompt and adjust the parameters to control the generation.""",
    examples=[
        ["To be, or not to be,", 100, 0.8],
        ["All the world's a stage,", 100, 0.8],
        ["Romeo, Romeo,", 100, 0.8]
    ],
    cache_examples=True
)


if __name__ == "__main__":
    try:
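        # Fail fast if the model cannot be loaded before launching the UI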
        initialize_model()
        demo.launch()
    except Exception as e:
        print(f"Error starting the application: {str(e)}")