Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from story_gpt.config import StoryGPTConfig | |
| from story_gpt.service import StoryGPTService | |
# Module-level configuration and the single StoryGPTService instance that the
# Gradio callbacks in this file (generate / train / reset) delegate to.
# NOTE(review): the same `service` object handles generation, training and
# reset for every request; confirm StoryGPTService tolerates overlapping calls
# if the Gradio queue lets different events run concurrently.
config = StoryGPTConfig()
service = StoryGPTService(config=config)
def generate_story(title, genre, tone, idea, opening_line, max_new_tokens, temperature, top_k):
    """Gradio callback for the "Generate Story" button.

    The five prompt fields are forwarded unchanged to
    ``service.generate_story``. The three slider values are coerced first
    (``max_new_tokens`` and ``top_k`` to ``int``, ``temperature`` to
    ``float``), since a Gradio slider may hand numbers over as floats.

    Returns:
        Whatever ``service.generate_story`` returns. The UI binds the result
        to two outputs (story text and status), so it is presumably a 2-item
        sequence — confirm against ``StoryGPTService``.
    """
    generate = service.generate_story
    sampling_options = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "top_k": int(top_k),
    }
    return generate(
        title=title,
        genre=genre,
        tone=tone,
        idea=idea,
        opening_line=opening_line,
        **sampling_options,
    )
def train_story_model(extra_story_text, steps):
    """Gradio callback for the "Train Story Model" button.

    Args:
        extra_story_text: Text from the "Extra Story Examples" box, passed
            through unchanged to ``service.train``.
        steps: Value of the "Training Steps" slider; converted with ``int()``.

    Returns:
        The result of ``service.train``, displayed in the "Training Status" box.
    """
    train = service.train
    step_count = int(steps)
    return train(extra_story_text=extra_story_text, steps=step_count)
def reset_story_model():
    """Gradio callback for the "Reset Model" button.

    Delegates to ``service.reset`` and returns its result, which the UI shows
    in the "Training Status" box.
    """
    reset = service.reset
    return reset()
# Shared look and landing text for the app.
_STORY_THEME = gr.themes.Soft(primary_hue="amber", secondary_hue="orange")
_INTRO_MARKDOWN = """
# Story GPT Python
A tiny story-writing GPT-style model written in Python from scratch.
- Causal transformer decoder
- Word-level tokenizer
- Story-focused local training corpus
- No external pretrained LLM
"""

# Components appear in the container of their enclosing `with`, in the order
# they are created, so statement order below defines the on-screen layout.
with gr.Blocks(title="Story GPT Python", theme=_STORY_THEME) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # Tab 1: prompt fields, sampling controls and the generated story.
    with gr.Tab("Write Story"):
        with gr.Row():
            title_input = gr.Textbox(value="The Lantern in the Rain", label="Title")
            genre_input = gr.Dropdown(
                choices=["Fantasy", "Adventure", "Mystery", "Sci-Fi", "Friendship", "Folktale"],
                value="Fantasy",
                label="Genre",
            )
            tone_input = gr.Dropdown(
                choices=["Warm", "Wonder", "Suspense", "Playful", "Calm", "Heroic"],
                value="Wonder",
                label="Tone",
            )
        idea_input = gr.Textbox(
            value="A child finds a glowing lantern that reveals hidden paths after a storm.",
            lines=5,
            label="Story Idea",
        )
        opening_line_input = gr.Textbox(
            value="When the rain stopped, the alley behind Mira's house began to shine.",
            lines=2,
            label="Opening Line",
        )
        with gr.Row():
            max_tokens_input = gr.Slider(
                minimum=30, maximum=220, value=110, step=5, label="Story Length"
            )
            temperature_input = gr.Slider(
                minimum=0.2, maximum=1.4, value=0.85, step=0.05, label="Temperature"
            )
            top_k_input = gr.Slider(
                minimum=1, maximum=24, value=10, step=1, label="Top-K"
            )
        generate_button = gr.Button("Generate Story", variant="primary")
        output_text = gr.Textbox(lines=14, label="Story Output")
        output_status = gr.Textbox(lines=4, label="Status")

    # Tab 2: extra training text plus train / reset controls.
    with gr.Tab("Train"):
        extra_story_text_input = gr.Textbox(
            lines=12,
            placeholder="Add more short stories, story prompts, or endings to continue training the model.",
            label="Extra Story Examples",
        )
        steps_input = gr.Slider(
            minimum=10, maximum=500, value=140, step=10, label="Training Steps"
        )
        train_button = gr.Button("Train Story Model", variant="primary")
        reset_button = gr.Button("Reset Model")
        train_status = gr.Textbox(lines=6, label="Training Status")

    # Event wiring stays inside the Blocks context. The order of
    # `story_inputs` must match the positional parameters of generate_story.
    story_inputs = [
        title_input,
        genre_input,
        tone_input,
        idea_input,
        opening_line_input,
        max_tokens_input,
        temperature_input,
        top_k_input,
    ]
    generate_button.click(
        fn=generate_story,
        inputs=story_inputs,
        outputs=[output_text, output_status],
    )
    train_button.click(
        fn=train_story_model,
        inputs=[extra_story_text_input, steps_input],
        outputs=[train_status],
    )
    reset_button.click(fn=reset_story_model, outputs=[train_status])
def _launch_app():
    """Start the Gradio server for the ``demo`` Blocks app."""
    demo.launch()


# Only launch when executed as a script; importing the module just builds `demo`.
if __name__ == "__main__":
    _launch_app()