Spaces:
Paused
Paused
ssalb
committed on
Commit
·
bfb4432
1
Parent(s):
3b5aca7
Update space with latest code and dependencies on Wed Jan 1 21:30:51 UTC 2025
Browse files- LICENSE +21 -0
- README.md +15 -4
- app.py +133 -4
- requirements.txt +79 -0
- story_beam_search/beam_search.py +103 -0
- story_beam_search/scoring.py +180 -0
- story_beam_search/stories_generator.py +191 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Salvador Salazar
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,13 +1,24 @@
|
|
| 1 |
---
|
| 2 |
title: Story Generator
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
|
|
|
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.9.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Story Generator
|
| 3 |
+
emoji: 📚
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
python_version: 3.10.13
|
| 6 |
+
colorTo: pink
|
| 7 |
sdk: gradio
|
| 8 |
sdk_version: 5.9.1
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
+
preload_from_hub:
|
| 12 |
+
# - meta-llama/Llama-3.2-1B-Instruct # Do not preload llama, as the token is not available at build time
|
| 13 |
+
- google-bert/bert-base-uncased
|
| 14 |
+
- facebook/bart-large-mnli
|
| 15 |
license: mit
|
| 16 |
---
|
| 17 |
|
| 18 |
+
## Project Overview
|
| 19 |
+
|
| 20 |
+
The Story Generator project leverages advanced natural language processing models to generate coherent and engaging stories. By utilizing models such as Llama, BERT, and BART, this project aims to provide users with a tool to create narratives based on given prompts. The application is built using Gradio for an interactive user interface, making it easy to input prompts and receive generated stories in real-time.
|
| 21 |
+
|
| 22 |
+
The main purpose of this project is to explore the idea of beam search for selecting stories with high coherence, fluency, and genre alignment scores. This ensures that the generated stories are not only creative but also maintain a logical flow and adhere to the specified genre.
|
| 23 |
+
|
| 24 |
+
Note that the final implementation is not strictly beam search and was modified to allow more diversity (creativity) inspired by the DVTS method in [this blog post](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute).
|
app.py
CHANGED
|
@@ -1,7 +1,136 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from typing import Literal
|
| 3 |
+
from pydantic import BaseModel, Field, constr
|
| 4 |
+
from story_beam_search.stories_generator import StoryGenerationSystem
|
| 5 |
+
from typing import Tuple, List
|
| 6 |
|
| 7 |
+
# Genres offered in the UI dropdown. The selected value is forwarded to the
# generation system as the target genre for zero-shot alignment scoring.
genre_choices = [
    "children",
    "mystery",
    "adventure",
    "sci-fi",
    "fantasy",
    "romance",
    "comedy",
    "drama",
    "horror",
]
|
| 18 |
|
| 19 |
+
class InputModel(BaseModel):
    """Validated user input for story generation.

    Gradio validates dropdown choices but not slider ranges, so the numeric
    bounds are enforced here through pydantic ``Field`` constraints. The
    bounds mirror the slider configuration in the Gradio interface.
    """

    prompt: str
    genre: str
    num_stories: int = Field(default=3, ge=2, le=7)
    temperature: float = Field(default=2.5, ge=0.7, le=3.5)
    max_length: int = Field(default=60, ge=30, le=200)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_story_generation_interface() -> gr.Interface:
    """Build and return the Gradio interface for story generation.

    The (expensive) model loading happens once, here, so every request served
    by the returned interface reuses the same StoryGenerationSystem.
    """
    # Initialize the story generation system
    system = StoryGenerationSystem()
    system.initialize()

    def generate_stories(
        prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
    ) -> Tuple[str, str]:
        """
        Generate and evaluate stories based on user input.

        Returns a tuple of (detailed_scores, story_texts), both already
        formatted as display strings.  (The previous annotation said
        ``Tuple[str, List[str]]`` but the second element has always been a
        joined string.)
        """

        # Validate inputs. Gradio seems to validate choices but not the
        # range of the values, so pydantic enforces the numeric bounds here.
        input_values = InputModel(
            prompt=prompt,
            genre=genre,
            num_stories=num_stories,
            temperature=temperature,
            max_length=max_length,
        )

        # Update beam search config with user parameters
        system.beam_search.config.temperature = input_values.temperature
        system.beam_search.config.max_length = input_values.max_length

        # Generate and evaluate stories
        ranked_stories = system.generate_and_evaluate(
            input_values.prompt,
            input_values.genre,
            num_stories=input_values.num_stories,
        )

        # Format the outputs. Built as lists and joined once instead of
        # repeated string concatenation (avoids quadratic behavior).
        score_blocks = []
        story_texts = []

        for i, (story, scores) in enumerate(ranked_stories, 1):
            score_blocks.append(
                f"Story {i}:\n"
                f"Total Score: {scores.total:.3f}\n"
                f"Coherence: {scores.coherence:.3f}\n"
                f"Fluency: {scores.fluency:.3f}\n"
                f"Genre Alignment: {scores.genre_alignment:.3f}\n"
                + "-" * 50
                + "\n"
            )
            story_texts.append(f"Story {i}:\n{story}\n")

        return "".join(score_blocks), "\n".join(story_texts)

    # Define interface components
    prompt_input = gr.Textbox(
        label="Story Prompt",
        placeholder="Enter the beginning of your story...",
        lines=3,
    )

    genre_input = gr.Dropdown(
        choices=genre_choices,
        label="Genre",
        value="fantasy",
    )

    num_stories_input = gr.Slider(
        minimum=2, maximum=7, value=3, step=1, label="Number of Stories to Generate"
    )

    temperature_input = gr.Slider(
        minimum=0.7, maximum=3.5, value=2.5, step=0.1, label="Temperature (Creativity)"
    )

    max_length_input = gr.Slider(
        minimum=30, maximum=200, value=60, step=30, label="Maximum Length"
    )

    # Output components
    scores_output = gr.Textbox(label="Detailed Scores", lines=10, interactive=False)

    stories_output = gr.Textbox(label="Generated Stories", lines=15, interactive=False)

    # Create the interface
    interface = gr.Interface(
        fn=generate_stories,
        inputs=[
            prompt_input,
            genre_input,
            num_stories_input,
            temperature_input,
            max_length_input,
        ],
        outputs=[scores_output, stories_output],
        title="AI Story Generator",
        description="""
    Generate creative stories using AI! Enter a prompt and choose your preferences.
    The system will generate multiple stories and evaluate them based on coherence,
    fluency, and genre alignment.
    """,
        examples=[
            ["Once upon a time in a magical forest,", "fantasy", 3, 1.8, 150],
            [
                "The detective knelt beside the bloodstained carpet, her gaze sharp as she traced the faint outline of a shoeprint.",
                "mystery",
                3,
                2.7,
                200,
            ],
        ],
        theme=gr.themes.Soft(),
    )

    return interface
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
if __name__ == "__main__":
    # Build the Gradio app and serve it, surfacing errors in the UI.
    demo = create_story_generation_interface()
    demo.launch(show_error=True)
|
requirements.txt
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles==23.2.1 ; python_full_version == "3.10.13"
|
| 2 |
+
annotated-types==0.7.0 ; python_full_version == "3.10.13"
|
| 3 |
+
anyio==4.7.0 ; python_full_version == "3.10.13"
|
| 4 |
+
certifi==2024.12.14 ; python_full_version == "3.10.13"
|
| 5 |
+
charset-normalizer==3.4.1 ; python_full_version == "3.10.13"
|
| 6 |
+
click==8.1.8 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
| 7 |
+
colorama==0.4.6 ; python_full_version == "3.10.13" and platform_system == "Windows"
|
| 8 |
+
exceptiongroup==1.2.2 ; python_full_version == "3.10.13"
|
| 9 |
+
fastapi==0.115.6 ; python_full_version == "3.10.13"
|
| 10 |
+
ffmpy==0.5.0 ; python_full_version == "3.10.13"
|
| 11 |
+
filelock==3.16.1 ; python_full_version == "3.10.13"
|
| 12 |
+
fsspec==2024.12.0 ; python_full_version == "3.10.13"
|
| 13 |
+
gradio-client==1.5.2 ; python_full_version == "3.10.13"
|
| 14 |
+
gradio==5.9.1 ; python_full_version == "3.10.13"
|
| 15 |
+
h11==0.14.0 ; python_full_version == "3.10.13"
|
| 16 |
+
httpcore==1.0.7 ; python_full_version == "3.10.13"
|
| 17 |
+
httpx==0.28.1 ; python_full_version == "3.10.13"
|
| 18 |
+
huggingface-hub==0.27.0 ; python_full_version == "3.10.13"
|
| 19 |
+
idna==3.10 ; python_full_version == "3.10.13"
|
| 20 |
+
jinja2==3.1.5 ; python_full_version == "3.10.13"
|
| 21 |
+
joblib==1.4.2 ; python_full_version == "3.10.13"
|
| 22 |
+
markdown-it-py==3.0.0 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
| 23 |
+
markupsafe==2.1.5 ; python_full_version == "3.10.13"
|
| 24 |
+
mdurl==0.1.2 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
| 25 |
+
mpmath==1.3.0 ; python_full_version == "3.10.13"
|
| 26 |
+
networkx==3.4.2 ; python_full_version == "3.10.13"
|
| 27 |
+
numpy==2.2.1 ; python_full_version == "3.10.13"
|
| 28 |
+
nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 29 |
+
nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 30 |
+
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 31 |
+
nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 32 |
+
nvidia-cudnn-cu12==9.1.0.70 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 33 |
+
nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 34 |
+
nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 35 |
+
nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 36 |
+
nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 37 |
+
nvidia-nccl-cu12==2.20.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 38 |
+
nvidia-nvjitlink-cu12==12.6.85 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 39 |
+
nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 40 |
+
orjson==3.10.13 ; python_full_version == "3.10.13"
|
| 41 |
+
packaging==24.2 ; python_full_version == "3.10.13"
|
| 42 |
+
pandas==2.2.3 ; python_full_version == "3.10.13"
|
| 43 |
+
pillow==11.0.0 ; python_full_version == "3.10.13"
|
| 44 |
+
protobuf==5.29.2 ; python_full_version == "3.10.13"
|
| 45 |
+
pydantic-core==2.27.2 ; python_full_version == "3.10.13"
|
| 46 |
+
pydantic==2.10.4 ; python_full_version == "3.10.13"
|
| 47 |
+
pydub==0.25.1 ; python_full_version == "3.10.13"
|
| 48 |
+
pygments==2.18.0 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
| 49 |
+
python-dateutil==2.9.0.post0 ; python_full_version == "3.10.13"
|
| 50 |
+
python-multipart==0.0.20 ; python_full_version == "3.10.13"
|
| 51 |
+
pytz==2024.2 ; python_full_version == "3.10.13"
|
| 52 |
+
pyyaml==6.0.2 ; python_full_version == "3.10.13"
|
| 53 |
+
regex==2024.11.6 ; python_full_version == "3.10.13"
|
| 54 |
+
requests==2.32.3 ; python_full_version == "3.10.13"
|
| 55 |
+
rich==13.9.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
| 56 |
+
ruff==0.8.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
| 57 |
+
safehttpx==0.1.6 ; python_full_version == "3.10.13"
|
| 58 |
+
safetensors==0.4.5 ; python_full_version == "3.10.13"
|
| 59 |
+
scikit-learn==1.6.0 ; python_full_version == "3.10.13"
|
| 60 |
+
scipy==1.14.1 ; python_full_version == "3.10.13"
|
| 61 |
+
semantic-version==2.10.0 ; python_full_version == "3.10.13"
|
| 62 |
+
shellingham==1.5.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
| 63 |
+
six==1.17.0 ; python_full_version == "3.10.13"
|
| 64 |
+
sniffio==1.3.1 ; python_full_version == "3.10.13"
|
| 65 |
+
starlette==0.41.3 ; python_full_version == "3.10.13"
|
| 66 |
+
sympy==1.13.3 ; python_full_version == "3.10.13"
|
| 67 |
+
threadpoolctl==3.5.0 ; python_full_version == "3.10.13"
|
| 68 |
+
tokenizers==0.21.0 ; python_full_version == "3.10.13"
|
| 69 |
+
tomlkit==0.13.2 ; python_full_version == "3.10.13"
|
| 70 |
+
torch==2.4.0 ; python_full_version == "3.10.13"
|
| 71 |
+
tqdm==4.67.1 ; python_full_version == "3.10.13"
|
| 72 |
+
transformers==4.47.1 ; python_full_version == "3.10.13"
|
| 73 |
+
triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
|
| 74 |
+
typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
| 75 |
+
typing-extensions==4.12.2 ; python_full_version == "3.10.13"
|
| 76 |
+
tzdata==2024.2 ; python_full_version == "3.10.13"
|
| 77 |
+
urllib3==2.3.0 ; python_full_version == "3.10.13"
|
| 78 |
+
uvicorn==0.34.0 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
|
| 79 |
+
websockets==14.1 ; python_full_version == "3.10.13"
|
story_beam_search/beam_search.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic.dataclasses import dataclass
|
| 2 |
+
from typing import Optional
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
| 5 |
+
|
| 6 |
+
from story_beam_search.scoring import StoryScorer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class BeamSearchConfig:
    """Hyperparameters for iterative beam-search story generation.

    Validated at construction time via pydantic's dataclass decorator.
    """

    num_beams: int = 3  # beams per generate() call; also survivors kept per iteration
    num_return_sequences: int = 3  # continuations returned per generate() call
    max_length: int = 100  # total new-token budget, split evenly across num_iterations
    no_repeat_ngram_size: int = 2
    temperature: float = 0.8
    top_k: int = 8
    top_p: float = 0.95
    num_iterations: int = 3  # refinement rounds after the initial generation
    continuation_length: int = 10  # recomputed per call by _generate_single_iteration
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class BeamSearchGenerator:
    """Generates story continuations with repeated, sampled beam search.

    Each iteration extends the current best candidates, re-scores them with
    the supplied evaluator, and keeps the top ``num_beams`` survivors.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
        config: Optional[BeamSearchConfig] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Fall back to the default config when none is supplied.
        self.config = config or BeamSearchConfig()

    def generate_iterations(
        self, prompt: str, genre: str, evaluator: StoryScorer
    ) -> list[str]:
        """
        Generate story continuations using multiple iterations of beam search.

        NOTE(review): ``evaluator`` is annotated as StoryScorer but is called
        via ``evaluate_multiple``, which only StoryEvaluator defines — the
        annotation looks wrong; confirm against callers.
        """

        # Adding some instructions to the prompt. These are removed in the end
        instructions = (
            f"Continue the following story in the {genre} genre, "
            "ensuring coherence with the tone, characters, and narrative established so far:\n"
        )
        # Character count used to strip the instruction prefix from decoded
        # output. NOTE(review): this assumes the tokenizer round-trips the
        # prefix verbatim — confirm.
        instructions_len = len(instructions)

        # Initial candidate pool from the raw prompt.
        stories = self._generate_single_iteration(instructions + prompt)
        ranked_stories = evaluator.evaluate_multiple(
            [story[instructions_len:] for story in stories]
        )

        # Keep only the top-scoring candidates as the surviving "beams".
        stories = [story for story, _ in ranked_stories[:self.config.num_beams]]

        if stories:
            # Iteratively extend every survivor, pool the continuations, and
            # re-select the best num_beams after each round.
            for _ in range(self.config.num_iterations):
                all_stories = []
                for story in stories:
                    continuations = self._generate_single_iteration(
                        instructions + story
                    )
                    all_stories.extend(continuations)
                ranked_stories = evaluator.evaluate_multiple(
                    [story[instructions_len:] for story in all_stories]
                )
                stories = [story for story, _ in ranked_stories[:self.config.num_beams]]

        return stories

    def _generate_single_iteration(self, prompt: str) -> list[str]:
        """
        Generate multiple continuations for a single iteration using beam search.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        # Budget for this round: current prompt length plus an equal share of
        # max_length per iteration (// binds tighter than +).
        # NOTE(review): this mutates the shared config object on every call.
        self.config.continuation_length = (
            len(input_ids[0]) + self.config.max_length // self.config.num_iterations
        )

        with torch.no_grad():
            # Sampled beam search (do_sample=True) for extra diversity.
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=self.config.continuation_length,
                num_beams=self.config.num_beams,
                num_return_sequences=self.config.num_return_sequences,
                early_stopping=True,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                temperature=self.config.temperature,
                top_k=self.config.top_k,
                top_p=self.config.top_p,
                do_sample=True,
            )

        # Decode each returned sequence (prompt included) to plain text.
        stories = []
        for output in outputs:
            text = self.tokenizer.decode(output, skip_special_tokens=True)
            stories.append(text)

        return stories
|
story_beam_search/scoring.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Protocol
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer, Pipeline
|
| 6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class StoryScorer(Protocol):
    """Protocol defining the interface for story scoring components.

    Higher scores are better; StoryEvaluator later max-normalizes the raw
    values so different scorers become comparable.
    """

    def score(self, story: str) -> float:
        """Return a score between 0 and 1."""
        ...
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class CombinedScore:
    """Per-metric scores for one story plus their weighted total.

    Values are raw on construction; StoryEvaluator.evaluate_multiple
    max-normalizes them in place and then fills in ``total``.
    """

    coherence: float = 0.0
    fluency: float = 0.0
    genre_alignment: float = 0.0
    total: float = 0.0  # weighted combination, computed after normalization
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class CoherenceScorer(StoryScorer):
    """Scores inter-sentence coherence.

    Coherence is the average cosine similarity between embeddings of
    adjacent sentences.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
        max_pairs: int = 3,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # NOTE(review): max_pairs is stored but never used in score() —
        # presumably intended to cap the number of sentence pairs compared.
        self.max_pairs = max_pairs

    def score(self, story: str) -> float:
        """Calculate coherence score based on sentences cosine similarity."""

        # Naive sentence split on periods; abbreviations/ellipses over-split,
        # which only adds extra (noisier) adjacent pairs.
        sentences = [s.strip() for s in story.split(".") if s.strip()]

        embeddings = []

        # Generate embeddings for each sentence
        for sentence in sentences:
            inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
            with torch.no_grad():
                # First-token hidden state used as the sentence embedding.
                # Assumes the wrapped model exposes a `.bert` submodule
                # (e.g. BertForMaskedLM) — TODO confirm.
                emb = self.model.bert(**inputs).last_hidden_state[:, 0, :]
            embeddings.append(emb.cpu().numpy())

        # Calculate cosine similarity between adjacent embeddings
        coherence_scores = []
        for i in range(len(embeddings) - 1):
            sim = cosine_similarity(embeddings[i], embeddings[i + 1])[0][0]
            coherence_scores.append(sim)

        # Average coherence score; 0.0 when there are fewer than two sentences
        avg_coherence = (
            sum(coherence_scores) / len(coherence_scores) if coherence_scores else 0.0
        )
        return avg_coherence
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class FluencyScorer(StoryScorer):
    """Scores fluency as the mean masked-LM probability of each token."""

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def score(self, story: str) -> float:
        # Mask each token in the story and calculate the probability of the original token
        # Fluency is measured by the average probability of each token in the story
        # NOTE: one forward pass per token, so cost grows linearly with length.
        inputs = self.tokenizer(story, return_tensors="pt").to(self.device)
        input_ids = inputs.input_ids
        mask_token_id = self.tokenizer.mask_token_id

        if mask_token_id is None:
            # Fallback for tokenizers without a mask token.
            # NOTE(review): encode(...)[0] is typically the leading special
            # token's id, not the id of the newly assigned "[MASK]" — verify.
            self.tokenizer.mask_token = "[MASK]"
            mask_token_id = self.tokenizer.encode(self.tokenizer.mask_token)[0]

        fluency_scores = []
        # Skip the first and last positions (special tokens for BERT-style
        # tokenizers).
        for i in range(1, input_ids.size(1) - 1):
            masked_input_ids = input_ids.clone()
            masked_input_ids[0, i] = mask_token_id

            with torch.no_grad():
                outputs = self.model(input_ids=masked_input_ids)
                logits = outputs.logits

            # Probability the model assigns to the original token at the
            # masked position.
            original_token_id = input_ids[0, i]
            token_probability = logits[0, i].softmax(dim=-1)[original_token_id].item()
            fluency_scores.append(token_probability)

        # 0.0 for degenerate inputs with no maskable positions.
        avg_fluency = (
            sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0.0
        )
        return avg_fluency
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class GenreAlignmentScorer(StoryScorer):
    """Scores how well a story matches a target genre via zero-shot classification."""

    def __init__(self, pipeline: Pipeline, genre: str):
        self.pipeline = pipeline
        self.genre = genre

    def score(self, story: str) -> float:
        """Return the mean per-sentence zero-shot score for the target genre."""
        # No target genre means nothing to align to: return a neutral score.
        if not self.genre:
            return 0.5

        # Classify sentence by sentence so a story that drifts away from the
        # genre part-way through is penalized.
        sentence_scores = [
            self.pipeline(
                part.strip(), candidate_labels=[self.genre], multi_label=True
            )["scores"][0]
            for part in story.split(".")
            if part.strip()
        ]

        if not sentence_scores:
            return 0.0
        return sum(sentence_scores) / len(sentence_scores)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class StoryEvaluator:
    """Combines coherence, fluency and genre scorers into one weighted total."""

    def __init__(
        self,
        coherence_scorer: CoherenceScorer,
        fluency_scorer: FluencyScorer,
        genre_scorer: GenreAlignmentScorer,
        weights: tuple[float, float, float] = (0.4, 0.3, 0.3),
    ):
        # Weights are (coherence, fluency, genre_alignment).
        self.coherence_scorer = coherence_scorer
        self.fluency_scorer = fluency_scorer
        self.genre_scorer = genre_scorer
        self.weights = weights

    def evaluate(self, story: str, max_scores: list[float]) -> CombinedScore:
        """Score one story and update the running per-metric maxima in place.

        ``total`` is left at 0 here; evaluate_multiple fills it in once all
        stories have been scored and normalized.
        """
        coherence = self.coherence_scorer.score(story)
        fluency = self.fluency_scorer.score(story)
        genre_alignment = self.genre_scorer.score(story)

        max_scores[0] = max(max_scores[0], coherence)
        max_scores[1] = max(max_scores[1], fluency)
        max_scores[2] = max(max_scores[2], genre_alignment)

        return CombinedScore(
            coherence=coherence,
            fluency=fluency,
            genre_alignment=genre_alignment,
            total=0,
        )

    def evaluate_multiple(self, stories: list[str]) -> list[tuple[str, CombinedScore]]:
        """Evaluate multiple stories and return them sorted by total score."""

        # Scores are normalized by the max scores on every evaluation.
        # This is to ensure that the scores are comparable between each other,
        # as they are originally on different scales.

        # Reset max scores
        max_scores = [0.0, 0.0, 0.0]

        scored_stories = [
            (story, self.evaluate(story, max_scores)) for story in stories
        ]

        # Guard against division by zero: if every story scored 0.0 on a
        # metric, dividing by its max would yield NaN and poison the sort
        # below. A max of 1.0 leaves the (all-zero) raw values unchanged.
        safe_max = [m if m > 0.0 else 1.0 for m in max_scores]

        # Normalize scores and compute the weighted total.
        for _, scores in scored_stories:
            scores.coherence, scores.fluency, scores.genre_alignment = np.divide(
                [scores.coherence, scores.fluency, scores.genre_alignment],
                safe_max,
            )
            scores.total = np.dot(
                [scores.coherence, scores.fluency, scores.genre_alignment], self.weights
            )

        return sorted(scored_stories, key=lambda x: x[1].total, reverse=True)
|
story_beam_search/stories_generator.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from story_beam_search.scoring import CombinedScore
|
| 5 |
+
from transformers import (
|
| 6 |
+
AutoModelForCausalLM,
|
| 7 |
+
AutoTokenizer,
|
| 8 |
+
AutoModelForMaskedLM,
|
| 9 |
+
pipeline,
|
| 10 |
+
Pipeline,
|
| 11 |
+
)
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
auth_token = os.getenv("HF_TOKEN", None)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class ModelConfig:
    """Names of the pretrained models to load, plus the compute device."""

    text_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
    bert_name: str = "bert-base-uncased"  # "answerdotai/ModernBERT-base"
    zero_shot_name: str = "facebook/bart-large-mnli"
    # Device preference: Apple Metal first, then CUDA, then plain CPU.
    # Evaluated once, at class-definition time.
    device: str = (
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    )
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class Models:
    """Container for all loaded models and tokenizers."""

    # Device every model below has been moved to.
    device: torch.device
    # Causal LM that writes the stories (default: Llama-3.2-1B-Instruct).
    text_model: AutoModelForCausalLM
    text_tokenizer: AutoTokenizer
    # Masked LM shared by the coherence and fluency scorers.
    bert_model: AutoModelForMaskedLM
    bert_tokenizer: AutoTokenizer
    # Zero-shot classification pipeline for genre-alignment scoring.
    zero_shot_pipeline: Pipeline
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ModelLoader:
    """Handles loading and initialization of all required models."""

    def __init__(self, config: ModelConfig = ModelConfig()):
        # NOTE: the default ModelConfig() instance is shared across calls
        # (mutable-default pitfall); safe here because it is only read.
        self.config = config
        self.device = torch.device(config.device)

    def load_models(self) -> Models:
        """Load all required models and return them in a Models container.

        Propagates whatever ``from_pretrained`` raises when a download or
        gated-repository access fails.
        """

        # Load Text model for writing stories
        print(f"Loading Text model ({self.config.text_model_name})...")
        text_tokenizer = AutoTokenizer.from_pretrained(
            self.config.text_model_name, token=auth_token
        )
        # Pass the auth token to the model load as well: gated repositories
        # such as meta-llama require it for the weights, not just the
        # tokenizer (previously only the tokenizer received it).
        text_model = AutoModelForCausalLM.from_pretrained(
            self.config.text_model_name, token=auth_token
        ).to(self.device)
        text_model.eval()

        # Load BERT model for coherence and fluency scoring
        print(f"Loading BERT model ({self.config.bert_name})...")
        bert_tokenizer = AutoTokenizer.from_pretrained(self.config.bert_name)
        bert_model = AutoModelForMaskedLM.from_pretrained(self.config.bert_name).to(
            self.device
        )
        bert_model.eval()

        # Load Zero-Shot classification pipeline for genre alignment scoring
        print("Loading Zero-Shot Classification pipeline...")
        zero_shot_pipeline = pipeline(
            "zero-shot-classification",
            model=self.config.zero_shot_name,
            device=self.device,
        )

        return Models(
            device=self.device,
            text_model=text_model,
            text_tokenizer=text_tokenizer,
            bert_model=bert_model,
            bert_tokenizer=bert_tokenizer,
            zero_shot_pipeline=zero_shot_pipeline,
        )
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class StoryGenerationSystem:
    """
    High-level class that coordinates model loading and initialization of all components.
    Acts as a facade for the entire story generation system.
    """

    def __init__(self, model_config: ModelConfig | None = None):
        """Create the system.

        Args:
            model_config: configuration for the underlying models; a fresh
                ModelConfig is built when omitted (a ``ModelConfig()`` parameter
                default would be one mutable instance shared by every system).
        """
        if model_config is None:
            model_config = ModelConfig()
        self.model_loader = ModelLoader(model_config)
        self.models = None
        self.beam_search = None
        self.evaluator = None
        self.storyness = None
        self.injection_guard = None

    def initialize(self) -> None:
        """Initialize all components of the story generation system."""
        # Imported lazily so constructing the facade stays cheap.
        from story_beam_search.beam_search import BeamSearchGenerator, BeamSearchConfig
        from story_beam_search.scoring import (
            CoherenceScorer,
            FluencyScorer,
            GenreAlignmentScorer,
            StoryEvaluator,
        )

        # Load all models
        self.models = self.model_loader.load_models()

        # Initialize beam search
        self.beam_search = BeamSearchGenerator(
            model=self.models.text_model,
            tokenizer=self.models.text_tokenizer,
            device=self.models.device,
            config=BeamSearchConfig(),
        )

        # Initialize scorers
        coherence_scorer = CoherenceScorer(
            model=self.models.bert_model,
            tokenizer=self.models.bert_tokenizer,
            device=self.models.device,
        )

        fluency_scorer = FluencyScorer(
            model=self.models.text_model,
            tokenizer=self.models.text_tokenizer,
            device=self.models.device,
        )

        # Note: genre_scorer will be created per request as it depends on the user's genre choice
        self.create_evaluator = lambda genre: StoryEvaluator(
            coherence_scorer=coherence_scorer,
            fluency_scorer=fluency_scorer,
            genre_scorer=GenreAlignmentScorer(
                pipeline=self.models.zero_shot_pipeline, genre=genre
            ),
        )

        # Last minute addition 'misusing' the GenreAlignmentScorer to check for prompt injections
        self.storyness = GenreAlignmentScorer(
            pipeline=self.models.zero_shot_pipeline, genre="story"
        )
        self.injection_guard = GenreAlignmentScorer(
            pipeline=self.models.zero_shot_pipeline, genre="prompt injection"
        )

    def _screen_prompt(self, prompt: str) -> None:
        """Low-effort prompt-injection screen using the zero-shot classifier.

        Raises:
            ValueError: if the prompt does not classify as a story, or any
                segment of it classifies as a likely prompt injection.
        """
        # The "storyness" of the whole prompt is loop-invariant, so check it
        # once up front.  (Previously it was tested inside the per-segment
        # loop, which silently skipped the check for prompts that produced
        # no non-empty segments.)
        storyness_score = self.storyness.score(prompt)
        if storyness_score < 0.2:
            print("Potential prompt injection detected.")
            print(f"storyness_score: {storyness_score}")
            print("Prompt:", prompt)
            raise ValueError("Prompt does not seem like a story. Please try again.")

        # Split on punctuation and screen each unique segment separately, so
        # an injection cannot hide inside an otherwise story-like prompt.
        for segment in set(re.split(r"[^a-zA-Z0-9 ]+", prompt)):
            if not segment.strip():
                continue
            injection_score = self.injection_guard.score(segment)
            if injection_score > 0.2:
                print("Potential prompt injection detected.")
                print(f"storyness_score: {storyness_score}")
                print(f"injection_score: {injection_score}")
                print("Prompt:", segment)
                raise ValueError("Prompt does not seem like a story. Please try again.")

    def generate_and_evaluate(
        self, prompt: str, genre: str, num_stories: int = 3
    ) -> list[tuple[str, CombinedScore]]:
        """Generate stories for `prompt` and return them ranked by score.

        Args:
            prompt: user-supplied story prompt.
            genre: target genre for the genre-alignment scorer.
            num_stories: number of independent search runs (and stories kept).

        Returns:
            Up to `num_stories` (story, CombinedScore) tuples, best first.

        Raises:
            RuntimeError: if initialize() has not been called.
            ValueError: if the prompt fails the injection screening.
        """
        if not self.models:
            raise RuntimeError("System not initialized. Call initialize() first.")

        # Low effort attempt to detect prompt injections using the zero-shot classifier
        self._screen_prompt(prompt)

        # Create evaluator with specified genre
        evaluator = self.create_evaluator(genre)

        # Generate stories.
        # This is not strict beam search, inspired by
        # https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute
        # to generate more diverse stories: run several independent searches
        # and keep only the best story from each.
        all_stories = []
        for _ in range(num_stories):
            stories = self.beam_search.generate_iterations(prompt, genre, evaluator)
            ranked_stories = evaluator.evaluate_multiple(stories)
            # keep the top story of this beam search iteration
            all_stories.append(ranked_stories[0][0])

        # Rank the per-run winners against each other and return the top k.
        ranked_stories = evaluator.evaluate_multiple(all_stories)
        return ranked_stories[:num_stories]
|