# potterGPT demo — Gradio app for a character-level GPT (Hugging Face Space)
import gradio as gr
import torch

from model import CharacterLevelTokenizer, Config, PotterGPT
class GradioApp:
    """Gradio front-end for the potterGPT character-level language model.

    Loads the trained weights from disk, rebuilds the tokenizer from the
    training corpus, and serves a one-button text-generation UI.
    """

    def __init__(self):
        # Path to the trained model weights (CPU checkpoint).
        self.model_path = 'potterGPT/potterGPT.pth'
        # Rebuild the tokenizer from the raw training text so the
        # character<->id mapping matches the one used during training.
        with open('data/harry_potter_data', 'r', encoding='utf-8') as f:
            data = f.read()
        self.tokenizer = CharacterLevelTokenizer(data)
        self.lm = PotterGPT(Config)
        state_dict = torch.load(self.model_path, map_location='cpu')
        self.lm.load_state_dict(state_dict)
        # Inference only: switch off training-mode behaviour (e.g. dropout).
        self.lm.eval()

    def launch(self):
        """Build the Gradio Blocks UI (no clear button) and launch it."""
        with gr.Blocks() as demo:
            gr.Markdown("# potterGPT v0")
            gr.Markdown("Click the button to generate a text prompt using the potterGPT model.")
            generate_button = gr.Button("Generate")
            output_text = gr.Textbox(label="Generated Text")
            # The callback takes no UI inputs; it writes into the textbox.
            generate_button.click(self.generate_text, inputs=None, outputs=output_text)
        demo.launch()

    def generate_text(self, input=None, total=1000):
        """Generate text using the trained model.

        Args:
            input: Unused; kept so the Gradio click callback signature is
                unchanged. (NOTE: shadows the builtin ``input`` — kept for
                backward compatibility with existing callers.)
            total: Number of tokens to generate (default 1000, matching the
                original hard-coded length).

        Returns:
            A formatted string containing the generated text.
        """
        # Seed context: a single token with id 61, which is '\n' in the
        # training vocabulary — TODO confirm against CharacterLevelTokenizer.
        context = torch.full((1, 1), 61, dtype=torch.long, device='cpu')
        # No gradients are needed at inference time; avoids building an
        # autograd graph for the whole generated sequence.
        with torch.no_grad():
            generated = self.lm.generate(context, total=total)
        decoded = self.tokenizer.decode(generated[0].cpu().numpy())
        return f'generated ({total} tokens)\n{"="*50}\n{decoded}\n{"="*50}\n\n'
if __name__ == '__main__':
    # Build the app and start the Gradio server.
    GradioApp().launch()