# Author: abersbail
# Commit message: Improve story quality with structured composer and educational story mode
# Commit: f82ff10 (verified)
import gradio as gr
from story_gpt.config import StoryGPTConfig
from story_gpt.service import StoryGPTService
# Module-level singletons shared by every Gradio callback in this file:
# one default configuration and one service instance built from it.
config = StoryGPTConfig()
service = StoryGPTService(config=config)
def generate_story(title, genre, tone, idea, opening_line, max_new_tokens, temperature, top_k):
    """Gradio callback for the "Generate Story" button.

    The text fields are forwarded unchanged. The three sampling controls are
    coerced to ``int`` / ``float`` first, in case the UI delivers slider
    values as a different numeric type. Everything is then handed to
    ``service.generate_story``.

    Returns:
        Whatever ``service.generate_story`` returns. The button is wired to
        two output textboxes (story text and status), so presumably a
        2-tuple; confirm against the service.
    """
    sampling_options = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "top_k": int(top_k),
    }
    return service.generate_story(
        title=title,
        genre=genre,
        tone=tone,
        idea=idea,
        opening_line=opening_line,
        **sampling_options,
    )
def train_story_model(extra_story_text, steps):
    """Gradio callback for the "Train Story Model" button.

    Args:
        extra_story_text: Free-form extra story examples typed by the user.
        steps: Number of training steps from the slider; coerced to ``int``.

    Returns:
        The result of ``service.train``, which is shown in the training
        status textbox.
    """
    step_count = int(steps)
    return service.train(extra_story_text=extra_story_text, steps=step_count)
def reset_story_model():
    """Gradio callback for the "Reset Model" button.

    Delegates to ``service.reset`` and returns its result, which is shown in
    the training status textbox.
    """
    reset_result = service.reset()
    return reset_result
# Build the two-tab Gradio UI and wire its buttons to the callbacks above.
# Gradio renders components in the order they are created, so statement
# order inside this block defines the page layout.
# NOTE(review): the source indentation was lost. The nesting below (which
# components sit inside each gr.Row / gr.Tab) is reconstructed; confirm it
# matches the intended layout.
with gr.Blocks(
    title="Story GPT Python",
    theme=gr.themes.Soft(primary_hue="amber", secondary_hue="orange"),
) as demo:
    # Header text rendered above the tabs.
    gr.Markdown(
        """
# Story GPT Python
A tiny story-writing GPT-style model written in Python from scratch.
- Causal transformer decoder
- Word-level tokenizer
- Story-focused local training corpus
- Structured local story composer for clean long-form output
- No external pretrained LLM
"""
    )
    # --- Generation tab: prompt fields, sampling controls, outputs ---
    with gr.Tab("Write Story"):
        # Title / genre / tone side by side.
        with gr.Row():
            title_input = gr.Textbox(label="Title", value="The Intelligent Project")
            genre_input = gr.Dropdown(
                label="Genre",
                choices=[
                    "Fantasy",
                    "Adventure",
                    "Mystery",
                    "Sci-Fi",
                    "Friendship",
                    "Folktale",
                    "Educational",
                ],
                value="Educational",
            )
            tone_input = gr.Dropdown(
                label="Tone",
                choices=["Warm", "Wonder", "Suspense", "Playful", "Calm", "Heroic", "Inspiring"],
                value="Inspiring",
            )
        idea_input = gr.Textbox(
            label="Story Idea",
            value=(
                "A student builds an intelligent AI project step by step using Python, data analysis, "
                "machine learning, deep learning, and language models."
            ),
            lines=5,
        )
        opening_line_input = gr.Textbox(
            label="Opening Line",
            value="Arman was a student who loved technology.",
            lines=2,
        )
        # Sampling controls. Slider args are (minimum, maximum); generate_story
        # casts these values to int/float before calling the service.
        with gr.Row():
            max_tokens_input = gr.Slider(30, 220, value=110, step=5, label="Story Length")
            temperature_input = gr.Slider(0.2, 1.4, value=0.85, step=0.05, label="Temperature")
            top_k_input = gr.Slider(1, 24, value=10, step=1, label="Top-K")
        generate_button = gr.Button("Generate Story", variant="primary")
        output_text = gr.Textbox(label="Story Output", lines=14)
        output_status = gr.Textbox(label="Status", lines=4)
    # --- Training tab: extra examples, step count, train/reset actions ---
    with gr.Tab("Train"):
        extra_story_text_input = gr.Textbox(
            label="Extra Story Examples",
            placeholder="Add more short stories, story prompts, or endings to continue training the model.",
            lines=12,
        )
        steps_input = gr.Slider(10, 500, value=140, step=10, label="Training Steps")
        train_button = gr.Button("Train Story Model", variant="primary")
        reset_button = gr.Button("Reset Model")
        # Shared by both the train and reset actions.
        train_status = gr.Textbox(label="Training Status", lines=6)
    # Event wiring. Gradio passes `inputs` values to `fn` positionally, so
    # each list below mirrors the callback's parameter order.
    # Presumably service.generate_story returns a (story, status) pair to
    # fill the two outputs; TODO confirm.
    generate_button.click(
        fn=generate_story,
        inputs=[
            title_input,
            genre_input,
            tone_input,
            idea_input,
            opening_line_input,
            max_tokens_input,
            temperature_input,
            top_k_input,
        ],
        outputs=[output_text, output_status],
    )
    train_button.click(
        fn=train_story_model,
        inputs=[extra_story_text_input, steps_input],
        outputs=[train_status],
    )
    # Reset takes no inputs and reports into the training status box.
    reset_button.click(fn=reset_story_model, outputs=[train_status])
if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed directly,
    # not when it is imported as a module.
    demo.launch()