Spaces:
Sleeping
Sleeping
Add Story GPT Python Space
Browse files- README.md +21 -5
- app.py +118 -0
- requirements.txt +2 -0
- story_gpt/__init__.py +4 -0
- story_gpt/config.py +24 -0
- story_gpt/data.py +81 -0
- story_gpt/model.py +111 -0
- story_gpt/service.py +139 -0
- story_gpt/tokenizer.py +69 -0
- story_gpt/trainer.py +52 -0
README.md
CHANGED
|
@@ -1,12 +1,28 @@
|
|
| 1 |
---
|
| 2 |
-
title: Story
|
| 3 |
-
|
| 4 |
-
colorFrom: pink
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.10.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Story GPT Python
|
| 3 |
+
colorFrom: yellow
|
|
|
|
| 4 |
colorTo: red
|
| 5 |
sdk: gradio
|
|
|
|
| 6 |
app_file: app.py
|
| 7 |
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# Story GPT Python
|
| 12 |
+
|
| 13 |
+
This is a tiny story-writing GPT-style language model project written in Python from scratch.
|
| 14 |
+
|
| 15 |
+
## What it includes
|
| 16 |
+
|
| 17 |
+
- Word-level tokenizer
|
| 18 |
+
- Causal transformer decoder with self-attention
|
| 19 |
+
- Story-focused local training corpus
|
| 20 |
+
- Local CPU training loop
|
| 21 |
+
- Checkpoint save and load
|
| 22 |
+
- Gradio user interface
|
| 23 |
+
|
| 24 |
+
## Important
|
| 25 |
+
|
| 26 |
+
- No external pretrained LLM is used
|
| 27 |
+
- This is a small educational GPT-like story model
|
| 28 |
+
- The first generate or train call will initialize and train the model locally
|
app.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
from story_gpt.config import StoryGPTConfig
|
| 4 |
+
from story_gpt.service import StoryGPTService
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Module-level singletons shared by every Gradio callback defined below.
config = StoryGPTConfig()
service = StoryGPTService(config)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def generate_story(title, genre, tone, idea, opening_line, max_new_tokens, temperature, top_k):
    """Gradio callback: forward the form values to the shared story service.

    The three numeric controls are coerced to the types the service signature
    annotates (``int``, ``float``, ``int``) before the call. Returns whatever
    ``service.generate_story`` returns.
    """
    request = {
        "title": title,
        "genre": genre,
        "tone": tone,
        "idea": idea,
        "opening_line": opening_line,
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "top_k": int(top_k),
    }
    return service.generate_story(**request)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def train_story_model(extra_story_text, steps):
    """Gradio callback: train the shared service and return its status text.

    ``steps`` is converted to ``int`` before being handed to ``service.train``.
    """
    step_count = int(steps)
    return service.train(extra_story_text=extra_story_text, steps=step_count)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def reset_story_model():
    """Gradio callback: reset the shared service and return its status text."""
    status_message = service.reset()
    return status_message
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# NOTE(review): the nesting below was reconstructed from a flattened listing;
# confirm the Row/Tab grouping against the deployed layout.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="amber", secondary_hue="orange"),
    title="Story GPT Python",
) as demo:
    # Page header.
    gr.Markdown(
        value="""
        # Story GPT Python
        A tiny story-writing GPT-style model written in Python from scratch.

        - Causal transformer decoder
        - Word-level tokenizer
        - Story-focused local training corpus
        - No external pretrained LLM
        """
    )

    # --- Tab 1: prompt form and generated story -----------------------------
    with gr.Tab(label="Write Story"):
        with gr.Row():
            title_input = gr.Textbox(value="The Lantern in the Rain", label="Title")
            genre_input = gr.Dropdown(
                choices=["Fantasy", "Adventure", "Mystery", "Sci-Fi", "Friendship", "Folktale"],
                value="Fantasy",
                label="Genre",
            )
            tone_input = gr.Dropdown(
                choices=["Warm", "Wonder", "Suspense", "Playful", "Calm", "Heroic"],
                value="Wonder",
                label="Tone",
            )

        idea_input = gr.Textbox(
            value="A child finds a glowing lantern that reveals hidden paths after a storm.",
            lines=5,
            label="Story Idea",
        )
        opening_line_input = gr.Textbox(
            value="When the rain stopped, the alley behind Mira's house began to shine.",
            lines=2,
            label="Opening Line",
        )

        # Sampling controls passed straight through to the generate callback.
        with gr.Row():
            max_tokens_input = gr.Slider(minimum=30, maximum=220, value=110, step=5, label="Story Length")
            temperature_input = gr.Slider(minimum=0.2, maximum=1.4, value=0.85, step=0.05, label="Temperature")
            top_k_input = gr.Slider(minimum=1, maximum=24, value=10, step=1, label="Top-K")

        generate_button = gr.Button(value="Generate Story", variant="primary")
        output_text = gr.Textbox(lines=14, label="Story Output")
        output_status = gr.Textbox(lines=4, label="Status")

    # --- Tab 2: continue training / reset ------------------------------------
    with gr.Tab(label="Train"):
        extra_story_text_input = gr.Textbox(
            placeholder="Add more short stories, story prompts, or endings to continue training the model.",
            lines=12,
            label="Extra Story Examples",
        )
        steps_input = gr.Slider(minimum=10, maximum=500, value=140, step=10, label="Training Steps")
        train_button = gr.Button(value="Train Story Model", variant="primary")
        reset_button = gr.Button(value="Reset Model")
        train_status = gr.Textbox(lines=6, label="Training Status")

    # --- Event wiring --------------------------------------------------------
    # Input order must match the positional parameters of generate_story.
    generate_button.click(
        fn=generate_story,
        inputs=[
            title_input, genre_input, tone_input,
            idea_input, opening_line_input,
            max_tokens_input, temperature_input, top_k_input,
        ],
        outputs=[output_text, output_status],
    )
    train_button.click(
        fn=train_story_model,
        inputs=[extra_story_text_input, steps_input],
        outputs=[train_status],
    )
    reset_button.click(fn=reset_story_model, outputs=[train_status])


if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=5.23.0
|
| 2 |
+
torch>=2.3.0
|
story_gpt/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import StoryGPTConfig
|
| 2 |
+
from .service import StoryGPTService
|
| 3 |
+
|
| 4 |
+
__all__ = ["StoryGPTConfig", "StoryGPTService"]
|
story_gpt/config.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@dataclass
class StoryGPTConfig:
    """Hyperparameters and on-disk locations used by the Story GPT package."""

    # --- model shape ---
    block_size: int = 72  # context length; sizes the position embedding and causal mask
    batch_size: int = 16  # windows sampled per training step
    d_model: int = 112  # embedding / residual width
    n_heads: int = 4  # attention heads per block
    n_layers: int = 4  # number of transformer blocks
    dropout: float = 0.1

    # --- optimisation / runtime ---
    learning_rate: float = 2.0e-3  # AdamW learning rate
    bootstrap_steps: int = 120  # steps run when no checkpoint exists yet
    cpu_threads: int = 4  # passed to torch.set_num_threads by the service
    seed: int = 42

    @property
    def root_dir(self) -> Path:
        """Return the directory one level above the one holding this file."""
        this_file = Path(__file__).resolve()
        return this_file.parent.parent

    @property
    def checkpoint_path(self) -> Path:
        """Return ``<root_dir>/artifacts/story_gpt_checkpoint.pt``."""
        return self.root_dir.joinpath("artifacts", "story_gpt_checkpoint.pt")
|
story_gpt/data.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BASE_CORPUS = """
|
| 2 |
+
Instruction: Write a warm fantasy story.
|
| 3 |
+
Title: The Lantern in the Rain
|
| 4 |
+
Genre: Fantasy
|
| 5 |
+
Tone: Warm
|
| 6 |
+
Idea: A child finds a glowing lantern after a storm.
|
| 7 |
+
Opening: When the rain stopped, the alley behind Mira's house began to shine.
|
| 8 |
+
Story: When the rain stopped, the alley behind Mira's house began to shine. A brass lantern waited beside the last puddle, glowing with a quiet gold light. Mira lifted it, and the beam painted a path through walls the way moonlight slips through leaves. The hidden path led to a garden growing on rooftops, where dripping vines carried silver pears and tiny birds sang in voices soft as bells. An old gardener bowed and told Mira that the lantern appears only to people who still notice wonder after thunder. She picked one silver pear, carried the lantern home, and hung it by her window. Every storm after that left her less afraid, because she knew the dark was only the place where secret gardens begin.
|
| 9 |
+
|
| 10 |
+
Instruction: Write a playful friendship story.
|
| 11 |
+
Title: The Bicycle with Two Bells
|
| 12 |
+
Genre: Friendship
|
| 13 |
+
Tone: Playful
|
| 14 |
+
Idea: Two friends fix a broken bicycle to win a town parade.
|
| 15 |
+
Opening: Ravi found the bicycle under a blanket of marigold petals.
|
| 16 |
+
Story: Ravi found the bicycle under a blanket of marigold petals in his grandfather's shed. The front wheel leaned sideways, one pedal was missing, and two bright brass bells dangled from the handlebars. He called for Sana, who arrived with a box of screws, purple paint, and opinions louder than thunder. They spent the afternoon tightening bolts, polishing spokes, and arguing about whether a bicycle should sparkle. At sunset they rode into the parade with both bells ringing at once, startling pigeons and making children laugh. They did not win the biggest trophy, but the mayor gave them the prize for happiest invention. Ravi said the bicycle rode better because of the bells. Sana said it rode better because they fixed it together.
|
| 17 |
+
|
| 18 |
+
Instruction: Write a suspenseful mystery story.
|
| 19 |
+
Title: The Library Key
|
| 20 |
+
Genre: Mystery
|
| 21 |
+
Tone: Suspense
|
| 22 |
+
Idea: A student discovers a key hidden inside an old atlas.
|
| 23 |
+
Opening: The atlas fell open by itself on the coldest evening of winter.
|
| 24 |
+
Story: The atlas fell open by itself on the coldest evening of winter, dropping a brass key onto Noor's desk. She recognized the number tied to it with blue thread: nineteen, the same number painted on the locked door behind the school library staircase. After everyone left, Noor followed the silent corridor and slipped the key into the iron lock. Inside the room she found shelves of letters never mailed, each one written by students who had once been too afraid to speak. At the center stood a journal kept by the retired librarian, who had hidden the room so secrets could become courage when the time was right. Noor spent all night reading promises, apologies, and impossible dreams. In the morning she placed one fresh letter on the desk outside. It read, The door is open now.
|
| 25 |
+
|
| 26 |
+
Instruction: Write a gentle folktale.
|
| 27 |
+
Title: The River That Remembered
|
| 28 |
+
Genre: Folktale
|
| 29 |
+
Tone: Calm
|
| 30 |
+
Idea: A village forgets its songs until a quiet child listens to the river.
|
| 31 |
+
Opening: Every morning the river moved past the village like folded blue silk.
|
| 32 |
+
Story: Every morning the river moved past the village like folded blue silk, but no one noticed its music anymore. The bakers had forgotten the kneading song, the weavers had forgotten the shuttle rhyme, and even weddings had grown quiet. Only little Tara sat beside the bank and listened with her feet in the water. She heard a melody hidden in the current, a tune that sounded like grandmothers laughing inside clay walls. Tara carried the melody home and hummed it while sweeping the courtyard. Her mother remembered the next line. The baker remembered the drum beat. By evening the entire village was singing to one another across their rooftops. From then on, they said the river remembers for those who are too busy to hear.
|
| 33 |
+
|
| 34 |
+
Instruction: Write a bright science fiction story.
|
| 35 |
+
Title: The Pocket Star Map
|
| 36 |
+
Genre: Sci-Fi
|
| 37 |
+
Tone: Wonder
|
| 38 |
+
Idea: A mechanic builds a map that points toward lost satellites.
|
| 39 |
+
Opening: Arin's workshop smelled like copper, dust, and hot tea.
|
| 40 |
+
Story: Arin's workshop smelled like copper, dust, and hot tea when the tiny projector finally woke in his palm. It scattered blue light over the ceiling, shaping a star map that shifted each time he whispered a satellite name. One marker blinked above the desert, where an old weather station had vanished years ago. Arin packed his tools, followed the projector through the dunes, and found the missing machine half-buried beneath moonlit sand. When he repaired its cracked panel, the station answered with stored messages from children who had once used it to learn constellations. Arin transmitted the lessons back to the city that same night. By morning, every rooftop telescope was turned upward. The mayor called it a scientific recovery. Arin called it returning a lost piece of the sky.
|
| 41 |
+
|
| 42 |
+
Instruction: Write a heroic adventure story.
|
| 43 |
+
Title: The Bridge of Kites
|
| 44 |
+
Genre: Adventure
|
| 45 |
+
Tone: Heroic
|
| 46 |
+
Idea: A village must cross a broken bridge during a wind festival.
|
| 47 |
+
Opening: On the morning of the wind festival, the bridge ropes snapped.
|
| 48 |
+
Story: On the morning of the wind festival, the bridge ropes snapped, leaving the mountain village stranded above a roaring gorge. Mehul looked at the kites lined up for the celebration and saw not decorations, but sails. He tied bamboo poles to the strongest frames, stitched the kite cloth into long panels, and asked every household for spare string. Under a hard blue sky the villagers worked as fast as prayer. By noon a ribbon bridge of red, gold, and indigo stretched from one cliff to the other, trembling but strong. Children crossed first with medicine for the far side, then farmers, then drums, then dancers. When the evening wind rose, the bridge sang above the gorge like a hundred bright birds. The festival carried on, and every kite overhead looked like a flag for courage.
|
| 49 |
+
|
| 50 |
+
Instruction: Write a bedtime story.
|
| 51 |
+
Title: The Moon Rabbit's Blanket
|
| 52 |
+
Genre: Fantasy
|
| 53 |
+
Tone: Calm
|
| 54 |
+
Idea: A moon rabbit sews dreams into a blanket for a sleepless child.
|
| 55 |
+
Opening: Above the sleeping city, one window was still lit.
|
| 56 |
+
Story: Above the sleeping city, one window was still lit, and from the moon a rabbit noticed at once. The little rabbit gathered threads from cloud edges, silver dust from quiet stars, and a single feather from the night wind. It hopped down a beam of moonlight and sat on the child's windowsill, sewing without a sound. Into one square it stitched a forest with glowing fireflies. Into another it stitched a lake so still it could hold a whole sky. When the blanket touched the child, the room filled with the smell of rain and jasmine, and the tired eyes finally closed. By dawn the rabbit was back on the moon, but one silver thread remained on the pillow. The child kept it for years and slept as if the sky itself had tucked them in.
|
| 57 |
+
""".strip()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def build_training_text(extra_text: str = "") -> str:
    """Return the built-in corpus, optionally followed by user-supplied text.

    ``extra_text`` is stripped; when nothing remains (or it is ``None``/empty)
    the base corpus is returned unchanged. Otherwise the extra text is
    appended after one blank line.
    """
    addition = extra_text.strip() if extra_text else ""
    if addition:
        return f"{BASE_CORPUS}\n\n{addition}"
    return BASE_CORPUS
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def build_story_prompt(title: str, genre: str, tone: str, idea: str, opening_line: str) -> str:
    """Build the conditioning prompt in the same layout as the training corpus.

    Each field is whitespace-normalised (runs of spaces/newlines collapse to a
    single space) so every field stays on its own ``Key: value`` line, and a
    field that is ``None``, empty, or whitespace-only falls back to a default.

    Args:
        title: Story title; defaults to ``"A New Story"``.
        genre: Genre label; defaults to ``"Fantasy"``.
        tone: Tone label; defaults to ``"Wonder"``.
        idea: Free-text idea; defaults to ``"Write a short imaginative story."``.
        opening_line: First line of the story; defaults to
            ``"Begin with a vivid first line."``.

    Returns:
        A prompt ending in ``"Story: <opening line>"`` with no trailing
        newline, so generation continues directly after the opening line.
    """

    def _clean(value, fallback):
        # Normalise first, *then* apply the fallback: a whitespace-only value
        # must not slip past the default and produce an empty field.
        collapsed = " ".join((value or "").split())
        return collapsed or fallback

    clean_title = _clean(title, "A New Story")
    clean_genre = _clean(genre, "Fantasy")
    clean_tone = _clean(tone, "Wonder")
    clean_idea = _clean(idea, "Write a short imaginative story.")
    clean_opening = _clean(opening_line, "Begin with a vivid first line.")

    lines = [
        f"Instruction: Write a {clean_tone.lower()} {clean_genre.lower()} story.",
        f"Title: {clean_title}",
        f"Genre: {clean_genre}",
        f"Tone: {clean_tone}",
        f"Idea: {clean_idea}",
        f"Opening: {clean_opening}",
        f"Story: {clean_opening}",
    ]
    return "\n".join(lines)
|
story_gpt/model.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only attends to itself
    and earlier positions.

    Args:
        d_model: Width of the input/output features; must be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.
        block_size: Largest sequence length the precomputed mask supports.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model`` (construction), or if an input sequence is longer
            than ``block_size`` (forward).
    """

    def __init__(self, d_model, n_heads, block_size, dropout):
        super().__init__()
        # Fail early with a clear message; otherwise the per-head reshape in
        # forward() blows up later with an opaque shape error.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # One projection produces queries, keys and values together.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular ones, shaped (1, 1, block_size, block_size) so it
        # broadcasts over batch and head dimensions.
        mask = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
        self.register_buffer("mask", mask)

    def forward(self, x):
        """Apply masked attention to ``x`` of shape (batch, seq_len, d_model)
        and return a tensor of the same shape."""
        batch, seq_len, channels = x.shape
        max_len = self.mask.size(-1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {max_len}"
            )

        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # (batch, seq, channels) -> (batch, heads, seq, head_dim)
        def split_heads(t):
            return t.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Positions above the diagonal (the future) get -inf so softmax gives
        # them zero weight.
        att = att.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float("-inf"))
        att = torch.softmax(att, dim=-1)
        att = self.dropout(att)

        out = att @ v
        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, channels)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.out_proj(out)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: widen to ``4 * d_model``,
    GELU, project back to ``d_model``, then dropout."""

    def __init__(self, d_model, dropout):
        super().__init__()
        hidden = 4 * d_model
        # Layer order (and therefore the ``net.<index>`` parameter names) is
        # part of the checkpoint format; keep it stable.
        layers = [
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        y = self.net(x)
        return y
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class GPTBlock(nn.Module):
    """One transformer block: LayerNorm -> attention and LayerNorm -> MLP,
    each added back onto the input through a residual connection."""

    def __init__(self, d_model, n_heads, block_size, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(
            d_model=d_model,
            n_heads=n_heads,
            block_size=block_size,
            dropout=dropout,
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model=d_model, dropout=dropout)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.ff(self.ln2(after_attention))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class StoryGPTModel(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings and learned position embeddings are summed, passed
    through ``n_layers`` ``GPTBlock`` modules and a final LayerNorm, then
    projected to vocabulary logits by a bias-free linear head whose weight is
    shared with the token embedding.

    Args:
        vocab_size: Number of token ids.
        block_size: Maximum sequence length accepted by ``forward``.
        d_model: Embedding / hidden width.
        n_heads: Attention heads per block.
        n_layers: Number of blocks.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size, block_size, d_model, n_heads, n_layers, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.Sequential(
            *[GPTBlock(d_model, n_heads, block_size, dropout) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.head.weight = self.token_emb.weight

    def forward(self, idx, targets=None):
        """Compute next-token logits, and the loss when ``targets`` is given.

        Args:
            idx: Integer tensor of token ids, shape (batch, seq_len) with
                ``seq_len <= block_size``.
            targets: Optional integer tensor of the same shape as ``idx``.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            (batch, seq_len, vocab_size) and ``loss`` is the mean
            cross-entropy, or ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``block_size``.
        """
        _, seq_len = idx.shape
        if seq_len > self.block_size:
            # Without this check the position-embedding lookup fails with an
            # index error that does not mention the real cause.
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        positions = torch.arange(seq_len, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)[None, :, :]
        x = self.dropout(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )
        return logits, loss

    def generate(self, idx, max_new_tokens, eos_id, temperature=1.0, top_k=8):
        """Sample up to ``max_new_tokens`` tokens after the prompt ``idx``.

        Args:
            idx: Integer tensor of prompt ids, shape (batch, prompt_len).
            max_new_tokens: Upper bound on the number of sampled tokens.
            eos_id: Integer id that ends a sequence.
            temperature: Divisor applied to the logits (floored at ``1e-4``).
            top_k: When positive, sampling is restricted to the ``top_k``
                highest-scoring tokens; ``None`` or ``0`` disables it.

        Returns:
            Tensor of shape (batch, prompt_len + n_generated). Generation
            stops once every row has produced ``eos_id``; rows that finished
            earlier are filled with ``eos_id`` for the remaining steps.
        """
        # Per-row "already emitted EOS" flags. The previous implementation
        # used ``next_id.item()``, which raises for batch sizes above one.
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

        for _ in range(max_new_tokens):
            # Only the most recent ``block_size`` tokens fit in the context.
            idx_cond = idx[:, -self.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-4)

            if top_k is not None and top_k > 0:
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th best logit becomes impossible.
                logits = logits.masked_fill(logits < values[:, [-1]], float("-inf"))

            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            # Rows that have already ended keep emitting EOS.
            next_id = torch.where(
                finished.unsqueeze(1), torch.full_like(next_id, eos_id), next_id
            )
            idx = torch.cat([idx, next_id], dim=1)

            finished = finished | (next_id.squeeze(1) == eos_id)
            if bool(finished.all()):
                break
        return idx
|
story_gpt/service.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import shutil
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .config import StoryGPTConfig
|
| 6 |
+
from .data import build_story_prompt
|
| 7 |
+
from .model import StoryGPTModel
|
| 8 |
+
from .tokenizer import WordTokenizer
|
| 9 |
+
from .trainer import create_model_and_tokenizer, set_seed, train_model
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StoryGPTService:
    """Owns the model/tokenizer pair and exposes generate, train and reset.

    The model is created lazily: the first ``generate_story`` call either
    loads the checkpoint at ``config.checkpoint_path`` or, when none exists,
    trains a fresh model for ``config.bootstrap_steps`` steps and saves it.
    """

    def __init__(self, config: StoryGPTConfig):
        self.config = config
        torch.set_num_threads(max(1, self.config.cpu_threads))
        self.model = None
        self.tokenizer = None

    def generate_story(
        self,
        title: str,
        genre: str,
        tone: str,
        idea: str,
        opening_line: str,
        max_new_tokens: int,
        temperature: float,
        top_k: int,
    ):
        """Generate a story continuation for the given form fields.

        Returns:
            ``(text, status)``: the decoded output with the prompt header
            removed (the opening line is kept), and a one-line status string.
        """
        clean_prompt = build_story_prompt(
            title=title,
            genre=genre,
            tone=tone,
            idea=idea,
            opening_line=opening_line,
        )
        self._ensure_ready()
        encoded = self.tokenizer.encode(clean_prompt, add_bos=True)
        idx = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
        self.model.eval()

        with torch.inference_mode():
            output = self.model.generate(
                idx=idx,
                max_new_tokens=max_new_tokens,
                eos_id=self.tokenizer.eos_id,
                temperature=temperature,
                top_k=top_k,
            )

        text = self._strip_prompt_header(self.tokenizer.decode(output[0].tolist()))
        status = (
            f"Generated with Story GPT Python. "
            f"Architecture=causal transformer, Vocab={self.tokenizer.vocab_size}, Layers={self.config.n_layers}."
        )
        return text, status

    @staticmethod
    def _strip_prompt_header(text: str) -> str:
        """Drop everything up to and including the prompt's ``Story:`` label.

        The label is looked for at the start of a line first, so a title or
        idea that happens to contain ``"Story:"`` does not cut the output in
        the wrong place. Falls back to the first ``"Story:"`` anywhere, and
        returns ``text`` unchanged when the label is absent.
        """
        for marker in ("\nStory:", "Story:"):
            _, found, remainder = text.partition(marker)
            if found:
                return remainder.strip()
        return text

    def train(self, extra_story_text: str, steps: int):
        """Train for ``steps`` steps (at least one) and save a checkpoint.

        A new model and tokenizer are built from the base corpus plus
        ``extra_story_text``. When a checkpoint exists and its vocabulary is
        identical to the new one, training resumes from the checkpoint
        weights; otherwise it starts from the freshly initialised model.

        Returns:
            A multi-line status string with the step count, first and last
            loss, and the checkpoint path.
        """
        steps = max(1, steps)
        checkpoint_exists = self.config.checkpoint_path.exists()
        training_text = extra_story_text or ""

        if checkpoint_exists:
            self._load_or_initialize(extra_text="")

        model, tokenizer, encoded = create_model_and_tokenizer(self.config, training_text)
        can_resume = (
            checkpoint_exists
            and self.model is not None
            and self.tokenizer is not None
            # Weights are only reusable when every token keeps its id.
            and tokenizer.stoi == self.tokenizer.stoi
        )
        if can_resume:
            model.load_state_dict(self.model.state_dict())

        losses = train_model(model, encoded, self.config, steps)
        self.model = model
        self.tokenizer = tokenizer
        self._save_checkpoint(extra_text=training_text)

        return (
            f"Story GPT training finished.\n"
            f"Steps: {steps}\n"
            f"Start Loss: {losses[0]:.4f}\n"
            f"End Loss: {losses[-1]:.4f}\n"
            f"Checkpoint: {self.config.checkpoint_path}"
        )

    def reset(self):
        """Delete the checkpoint directory and drop the in-memory model."""
        checkpoint_dir = self.config.checkpoint_path.parent
        if checkpoint_dir.exists():
            shutil.rmtree(checkpoint_dir)
        self.model = None
        self.tokenizer = None
        return "Story GPT reset complete. Next train or generate call will rebuild from scratch."

    def _ensure_ready(self):
        """Make sure ``self.model`` and ``self.tokenizer`` are populated."""
        if self.model is not None and self.tokenizer is not None:
            return
        self._load_or_initialize(extra_text="")

    def _load_or_initialize(self, extra_text: str):
        """Load the checkpoint if present; otherwise bootstrap-train and save.

        When loading, the model is rebuilt from the hyperparameters stored in
        the checkpoint rather than from ``self.config``.
        """
        checkpoint = self.config.checkpoint_path
        if checkpoint.exists():
            # The checkpoint holds only tensors and plain dict/str/int/float
            # values, so the restricted unpickler is sufficient and avoids
            # executing arbitrary pickled code from the file.
            state = torch.load(checkpoint, map_location="cpu", weights_only=True)
            saved = state["config"]
            self.tokenizer = WordTokenizer.from_state_dict(state["tokenizer"])
            self.model = StoryGPTModel(
                vocab_size=saved["vocab_size"],
                block_size=saved["block_size"],
                d_model=saved["d_model"],
                n_heads=saved["n_heads"],
                n_layers=saved["n_layers"],
                dropout=saved["dropout"],
            )
            self.model.load_state_dict(state["model"])
            self.model.eval()
            return

        set_seed(self.config.seed)
        self.model, self.tokenizer, encoded = create_model_and_tokenizer(self.config, extra_text)
        train_model(self.model, encoded, self.config, self.config.bootstrap_steps)
        self._save_checkpoint(extra_text=extra_text)

    def _save_checkpoint(self, extra_text: str):
        """Write model weights, tokenizer vocabulary and hyperparameters to
        ``config.checkpoint_path``, creating the parent directory if needed."""
        checkpoint = self.config.checkpoint_path
        checkpoint.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "model": self.model.state_dict(),
                "tokenizer": self.tokenizer.state_dict(),
                "config": {
                    "vocab_size": self.tokenizer.vocab_size,
                    "block_size": self.config.block_size,
                    "d_model": self.config.d_model,
                    "n_heads": self.config.n_heads,
                    "n_layers": self.config.n_layers,
                    "dropout": self.config.dropout,
                    "extra_text": extra_text,
                },
            },
            checkpoint,
        )
|
story_gpt/tokenizer.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# A token is: a newline, a run of word characters/apostrophes, or any single
# non-space symbol. ``\w`` (rather than an ASCII-only class) keeps accented
# and non-Latin letters, which were previously matched by no alternative and
# silently dropped from the text.
TOKEN_PATTERN = re.compile(r"\n|[\w']+|[^\w\s]")


class WordTokenizer:
    """Word-level tokenizer with a vocabulary learned from text.

    Ids 0-3 are reserved for ``<pad>``, ``<unk>``, ``<bos>`` and ``<eos>``;
    the remaining ids are assigned to the sorted unique tokens seen by
    ``fit``.
    """

    def __init__(self):
        self.special_tokens = ["<pad>", "<unk>", "<bos>", "<eos>"]
        self.stoi = {}  # token -> id
        self.itos = {}  # id -> token

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.stoi)

    @property
    def bos_id(self):
        """Id of the ``<bos>`` token (requires a fitted/loaded vocabulary)."""
        return self.stoi["<bos>"]

    @property
    def eos_id(self):
        """Id of the ``<eos>`` token (requires a fitted/loaded vocabulary)."""
        return self.stoi["<eos>"]

    def tokenize(self, text: str):
        """Split ``text`` into a list of token strings; spaces are discarded."""
        return TOKEN_PATTERN.findall(text)

    def fit(self, text: str):
        """Build the vocabulary from ``text`` and return ``self``."""
        vocab = self.special_tokens + sorted(set(self.tokenize(text)))
        self.stoi = {token: idx for idx, token in enumerate(vocab)}
        self.itos = {idx: token for token, idx in self.stoi.items()}
        return self

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False):
        """Convert ``text`` to a list of ids.

        Tokens missing from the vocabulary map to the ``<unk>`` id.
        ``add_bos`` / ``add_eos`` prepend / append the corresponding id.
        """
        unk_id = self.stoi["<unk>"]
        ids = [self.stoi.get(token, unk_id) for token in self.tokenize(text)]
        if add_bos:
            ids = [self.bos_id] + ids
        if add_eos:
            ids = ids + [self.eos_id]
        return ids

    def decode(self, ids):
        """Convert ids back to text.

        Special tokens and unknown ids are skipped. Words are joined with
        single spaces, ``. , ! ? : ;`` attach to the preceding word, and a
        newline token replaces any whitespace before it.
        """
        tokens = []
        for idx in ids:
            token = self.itos.get(int(idx), "<unk>")
            if token in self.special_tokens:
                continue
            tokens.append(token)

        text = ""
        for token in tokens:
            if token == "\n":
                text = text.rstrip() + "\n"
            elif token in {".", ",", "!", "?", ":", ";"}:
                # Remove the space left by the previous word before attaching.
                text = text.rstrip() + token + " "
            else:
                text += token + " "
        return text.strip()

    def state_dict(self):
        """Return the serialisable state: ``{"stoi": <token -> id mapping>}``."""
        return {"stoi": self.stoi}

    @classmethod
    def from_state_dict(cls, state):
        """Rebuild a tokenizer from the dictionary produced by ``state_dict``."""
        tok = cls()
        tok.stoi = dict(state["stoi"])
        tok.itos = {idx: token for token, idx in tok.stoi.items()}
        return tok
|
story_gpt/trainer.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .data import build_training_text
|
| 6 |
+
from .model import StoryGPTModel
|
| 7 |
+
from .tokenizer import WordTokenizer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def set_seed(seed: int):
    """Seed both Python's ``random`` module and PyTorch's global generator."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def create_model_and_tokenizer(config, extra_text=""):
    """Build a fresh tokenizer, model and encoded corpus.

    The training text is the base corpus plus ``extra_text``. The tokenizer
    is fitted on that text, and the text is encoded with BOS/EOS markers into
    a 1-D ``torch.long`` tensor. The model is newly initialised with the
    architecture fields from ``config`` and the fitted vocabulary size.

    Returns:
        ``(model, tokenizer, encoded)``.
    """
    corpus = build_training_text(extra_text)

    tokenizer = WordTokenizer()
    tokenizer.fit(corpus)
    token_ids = tokenizer.encode(corpus, add_bos=True, add_eos=True)
    encoded = torch.tensor(token_ids, dtype=torch.long)

    architecture = {
        "block_size": config.block_size,
        "d_model": config.d_model,
        "n_heads": config.n_heads,
        "n_layers": config.n_layers,
        "dropout": config.dropout,
    }
    model = StoryGPTModel(vocab_size=tokenizer.vocab_size, **architecture)
    return model, tokenizer, encoded
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def build_batch(encoded, block_size, batch_size):
    """Sample a random batch of (input, target) windows from a token stream.

    Args:
        encoded: 1-D tensor of token ids.
        block_size: Desired window length. When the stream is too short for a
            full window plus its shifted target, the window shrinks to
            ``len(encoded) - 1``.
        batch_size: Number of windows to draw.

    Returns:
        ``(x, y)``, both of shape (batch_size, window), where ``y`` is ``x``
        shifted one position to the right in the stream.

    Raises:
        ValueError: If ``encoded`` has fewer than two tokens.
    """
    total = len(encoded)
    if total < 2:
        raise ValueError("need at least two tokens to build (input, target) pairs")

    # Shrinking the window keeps x and y the same length on short streams.
    window = min(block_size, total - 1)
    # Valid starts are 0 .. total - window - 1 inclusive, and randint's upper
    # bound is exclusive. The old bound was one smaller, so the last window
    # (the only one whose target contains the final token) was never drawn.
    num_starts = total - window
    starts = torch.randint(0, num_starts, (batch_size,)).tolist()

    x = torch.stack([encoded[s : s + window] for s in starts])
    y = torch.stack([encoded[s + 1 : s + window + 1] for s in starts])
    return x, y
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def train_model(model, encoded, config, steps):
    """Run ``steps`` AdamW updates on random batches and return the losses.

    Args:
        model: Module whose call returns ``(logits, loss)`` when given
            ``targets``.
        encoded: 1-D tensor of token ids to sample batches from.
        config: Object providing ``learning_rate``, ``block_size`` and
            ``batch_size``.
        steps: Number of optimisation steps.

    Returns:
        A list with one float loss per step, in order. The model is left in
        training mode.
    """
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    history = []
    for _step in range(steps):
        inputs, targets = build_batch(encoded, config.block_size, config.batch_size)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        _logits, batch_loss = model(inputs, targets=targets)
        batch_loss.backward()
        optimizer.step()

        history.append(float(batch_loss.item()))
    return history
|