ssalb committed on
Commit
bfb4432
·
1 Parent(s): 3b5aca7

Update space with latest code and dependencies on Wed Jan 1 21:30:51 UTC 2025

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Salvador Salazar
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,24 @@
1
  ---
2
  title: Story Generator
3
- emoji: 🐨
4
- colorFrom: gray
5
- colorTo: red
 
6
  sdk: gradio
7
  sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
 
 
 
 
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
1
  ---
2
  title: Story Generator
3
+ emoji: 📚
4
+ colorFrom: blue
5
+ python_version: 3.10.13
6
+ colorTo: pink
7
  sdk: gradio
8
  sdk_version: 5.9.1
9
  app_file: app.py
10
  pinned: false
11
+ preload_from_hub:
12
+ # - meta-llama/Llama-3.2-1B-Instruct # Do not preload llama, as the token is not available at build time
13
+ - google-bert/bert-base-uncased
14
+ - facebook/bart-large-mnli
15
  license: mit
16
  ---
17
 
18
+ ## Project Overview
19
+
20
+ The Story Generator project leverages advanced natural language processing models to generate coherent and engaging stories. By utilizing models such as GPT-2, BERT, and BART, this project aims to provide users with a tool to create narratives based on given prompts. The application is built using Gradio for an interactive user interface, making it easy to input prompts and receive generated stories in real-time.
21
+
22
+ The main purpose of this project is to explore the idea of beam search for selecting stories with high coherence, fluency, and genre alignment scores. This ensures that the generated stories are not only creative but also maintain a logical flow and adhere to the specified genre.
23
+
24
+ Note that the final implementation is not strictly beam search and was modified to allow more diversity (creativity) inspired by the DVTS method in [this blog post](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute).
app.py CHANGED
@@ -1,7 +1,136 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from typing import Literal
3
+ from pydantic import BaseModel, Field, constr
4
+ from story_beam_search.stories_generator import StoryGenerationSystem
5
+ from typing import Tuple, List
6
 
7
# Genres offered in the UI dropdown; the selected value is also passed to the
# zero-shot classifier as the candidate label for genre-alignment scoring.
genre_choices = (
    "children mystery adventure sci-fi fantasy romance comedy drama horror".split()
)
18
 
19
class InputModel(BaseModel):
    """Validated request parameters for story generation.

    Gradio validates dropdown choices but not numeric ranges, so the
    ranges are re-checked here via pydantic field constraints.
    """

    prompt: str
    genre: str
    num_stories: int = Field(default=3, ge=2, le=7)
    temperature: float = Field(default=2.5, ge=0.7, le=3.5)
    max_length: int = Field(default=60, ge=30, le=200)
25
+
26
+
27
def create_story_generation_interface() -> gr.Interface:
    """Build and return the Gradio interface, backed by a fully initialized
    StoryGenerationSystem (all models are loaded eagerly here)."""
    # Initialize the story generation system
    system = StoryGenerationSystem()
    system.initialize()

    def generate_stories(
        prompt: str, genre: str, num_stories: int, temperature: float, max_length: int
    ) -> Tuple[str, str]:
        """
        Generate and evaluate stories based on user input.
        Returns a tuple of (detailed_scores, story_texts) — both are single
        newline-joined strings suitable for the two output Textboxes.
        """

        # Validate inputs. Gradio seems to validate choices but not the range of the values
        input_values = InputModel(
            prompt=prompt, genre=genre, num_stories=num_stories, temperature=temperature, max_length=max_length
        )

        # Update beam search config with user parameters
        system.beam_search.config.temperature = input_values.temperature
        system.beam_search.config.max_length = input_values.max_length

        # Generate and evaluate stories
        ranked_stories = system.generate_and_evaluate(
            input_values.prompt, input_values.genre, num_stories=input_values.num_stories
        )

        # Format detailed scores (one block per story, separated by a rule)
        detailed_scores = ""
        story_texts = []

        for i, (story, scores) in enumerate(ranked_stories, 1):
            detailed_scores += f"Story {i}:\n"
            detailed_scores += f"Total Score: {scores.total:.3f}\n"
            detailed_scores += f"Coherence: {scores.coherence:.3f}\n"
            detailed_scores += f"Fluency: {scores.fluency:.3f}\n"
            detailed_scores += f"Genre Alignment: {scores.genre_alignment:.3f}\n"
            detailed_scores += "-" * 50 + "\n"

            story_texts.append(f"Story {i}:\n{story}\n")

        return detailed_scores, "\n".join(story_texts)

    # Define interface components
    prompt_input = gr.Textbox(
        label="Story Prompt",
        placeholder="Enter the beginning of your story...",
        lines=3,
    )

    genre_input = gr.Dropdown(
        choices=genre_choices,
        label="Genre",
        value="fantasy",
    )

    # Slider bounds mirror the InputModel constraints above
    num_stories_input = gr.Slider(
        minimum=2, maximum=7, value=3, step=1, label="Number of Stories to Generate"
    )

    temperature_input = gr.Slider(
        minimum=0.7, maximum=3.5, value=2.5, step=0.1, label="Temperature (Creativity)"
    )

    max_length_input = gr.Slider(
        minimum=30, maximum=200, value=60, step=30, label="Maximum Length"
    )

    # Output components
    scores_output = gr.Textbox(label="Detailed Scores", lines=10, interactive=False)

    stories_output = gr.Textbox(label="Generated Stories", lines=15, interactive=False)

    # Create the interface
    interface = gr.Interface(
        fn=generate_stories,
        inputs=[
            prompt_input,
            genre_input,
            num_stories_input,
            temperature_input,
            max_length_input,
        ],
        outputs=[scores_output, stories_output],
        title="AI Story Generator",
        description="""
        Generate creative stories using AI! Enter a prompt and choose your preferences.
        The system will generate multiple stories and evaluate them based on coherence,
        fluency, and genre alignment.
        """,
        examples=[
            ["Once upon a time in a magical forest,", "fantasy", 3, 1.8, 150],
            [
                "The detective knelt beside the bloodstained carpet, her gaze sharp as she traced the faint outline of a shoeprint.",
                "mystery",
                3,
                2.7,
                200,
            ],
        ],
        theme=gr.themes.Soft(),
    )

    return interface
131
+
132
+
133
if __name__ == "__main__":
    # Build the Gradio app and serve it; show_error surfaces exceptions
    # (e.g. rejected prompts) in the browser rather than failing silently.
    app = create_story_generation_interface()
    app.launch(show_error=True)
requirements.txt ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_full_version == "3.10.13"
2
+ annotated-types==0.7.0 ; python_full_version == "3.10.13"
3
+ anyio==4.7.0 ; python_full_version == "3.10.13"
4
+ certifi==2024.12.14 ; python_full_version == "3.10.13"
5
+ charset-normalizer==3.4.1 ; python_full_version == "3.10.13"
6
+ click==8.1.8 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
7
+ colorama==0.4.6 ; python_full_version == "3.10.13" and platform_system == "Windows"
8
+ exceptiongroup==1.2.2 ; python_full_version == "3.10.13"
9
+ fastapi==0.115.6 ; python_full_version == "3.10.13"
10
+ ffmpy==0.5.0 ; python_full_version == "3.10.13"
11
+ filelock==3.16.1 ; python_full_version == "3.10.13"
12
+ fsspec==2024.12.0 ; python_full_version == "3.10.13"
13
+ gradio-client==1.5.2 ; python_full_version == "3.10.13"
14
+ gradio==5.9.1 ; python_full_version == "3.10.13"
15
+ h11==0.14.0 ; python_full_version == "3.10.13"
16
+ httpcore==1.0.7 ; python_full_version == "3.10.13"
17
+ httpx==0.28.1 ; python_full_version == "3.10.13"
18
+ huggingface-hub==0.27.0 ; python_full_version == "3.10.13"
19
+ idna==3.10 ; python_full_version == "3.10.13"
20
+ jinja2==3.1.5 ; python_full_version == "3.10.13"
21
+ joblib==1.4.2 ; python_full_version == "3.10.13"
22
+ markdown-it-py==3.0.0 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
23
+ markupsafe==2.1.5 ; python_full_version == "3.10.13"
24
+ mdurl==0.1.2 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
25
+ mpmath==1.3.0 ; python_full_version == "3.10.13"
26
+ networkx==3.4.2 ; python_full_version == "3.10.13"
27
+ numpy==2.2.1 ; python_full_version == "3.10.13"
28
+ nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
29
+ nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
30
+ nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
31
+ nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
32
+ nvidia-cudnn-cu12==9.1.0.70 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
33
+ nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
34
+ nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
35
+ nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
36
+ nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
37
+ nvidia-nccl-cu12==2.20.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
38
+ nvidia-nvjitlink-cu12==12.6.85 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
39
+ nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
40
+ orjson==3.10.13 ; python_full_version == "3.10.13"
41
+ packaging==24.2 ; python_full_version == "3.10.13"
42
+ pandas==2.2.3 ; python_full_version == "3.10.13"
43
+ pillow==11.0.0 ; python_full_version == "3.10.13"
44
+ protobuf==5.29.2 ; python_full_version == "3.10.13"
45
+ pydantic-core==2.27.2 ; python_full_version == "3.10.13"
46
+ pydantic==2.10.4 ; python_full_version == "3.10.13"
47
+ pydub==0.25.1 ; python_full_version == "3.10.13"
48
+ pygments==2.18.0 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
49
+ python-dateutil==2.9.0.post0 ; python_full_version == "3.10.13"
50
+ python-multipart==0.0.20 ; python_full_version == "3.10.13"
51
+ pytz==2024.2 ; python_full_version == "3.10.13"
52
+ pyyaml==6.0.2 ; python_full_version == "3.10.13"
53
+ regex==2024.11.6 ; python_full_version == "3.10.13"
54
+ requests==2.32.3 ; python_full_version == "3.10.13"
55
+ rich==13.9.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
56
+ ruff==0.8.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
57
+ safehttpx==0.1.6 ; python_full_version == "3.10.13"
58
+ safetensors==0.4.5 ; python_full_version == "3.10.13"
59
+ scikit-learn==1.6.0 ; python_full_version == "3.10.13"
60
+ scipy==1.14.1 ; python_full_version == "3.10.13"
61
+ semantic-version==2.10.0 ; python_full_version == "3.10.13"
62
+ shellingham==1.5.4 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
63
+ six==1.17.0 ; python_full_version == "3.10.13"
64
+ sniffio==1.3.1 ; python_full_version == "3.10.13"
65
+ starlette==0.41.3 ; python_full_version == "3.10.13"
66
+ sympy==1.13.3 ; python_full_version == "3.10.13"
67
+ threadpoolctl==3.5.0 ; python_full_version == "3.10.13"
68
+ tokenizers==0.21.0 ; python_full_version == "3.10.13"
69
+ tomlkit==0.13.2 ; python_full_version == "3.10.13"
70
+ torch==2.4.0 ; python_full_version == "3.10.13"
71
+ tqdm==4.67.1 ; python_full_version == "3.10.13"
72
+ transformers==4.47.1 ; python_full_version == "3.10.13"
73
+ triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version == "3.10.13"
74
+ typer==0.15.1 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
75
+ typing-extensions==4.12.2 ; python_full_version == "3.10.13"
76
+ tzdata==2024.2 ; python_full_version == "3.10.13"
77
+ urllib3==2.3.0 ; python_full_version == "3.10.13"
78
+ uvicorn==0.34.0 ; sys_platform != "emscripten" and python_full_version == "3.10.13"
79
+ websockets==14.1 ; python_full_version == "3.10.13"
story_beam_search/beam_search.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic.dataclasses import dataclass
2
+ from typing import Optional
3
+ import torch
4
+ from transformers import PreTrainedModel, PreTrainedTokenizer
5
+
6
+ from story_beam_search.scoring import StoryScorer
7
+
8
+
9
@dataclass
class BeamSearchConfig:
    """Generation hyperparameters for the iterative beam-search loop."""

    num_beams: int = 3  # beams per generate() call; also candidates kept between rounds
    num_return_sequences: int = 3  # continuations returned per generate() call
    max_length: int = 100  # total new-token budget, split across num_iterations rounds
    no_repeat_ngram_size: int = 2  # forbid repeating any 2-gram within a continuation
    temperature: float = 0.8  # sampling temperature (overwritten per request by app.py)
    top_k: int = 8
    top_p: float = 0.95
    num_iterations: int = 3  # extend-and-rerank rounds in generate_iterations
    continuation_length: int = 10  # length target used for a single generate() call
20
+
21
+
22
class BeamSearchGenerator:
    """Iteratively extends a prompt with a causal LM, keeping only the
    top-scoring continuations between rounds.

    Not strict beam search: each call samples (do_sample=True) for
    diversity, and an external evaluator re-ranks candidates between
    rounds.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
        config: Optional[BeamSearchConfig] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.config = config or BeamSearchConfig()

    def generate_iterations(
        self, prompt: str, genre: str, evaluator: StoryScorer
    ) -> list[str]:
        """
        Generate story continuations using multiple iterations of beam search.

        Each round extends the current best candidates, scores them with
        ``evaluator`` and keeps the top ``config.num_beams``. Returned
        stories include the original prompt but not the instruction prefix.
        """

        # Instructions are prepended for generation only; they are sliced off
        # (by character length) before scoring and before returning.
        instructions = (
            f"Continue the following story in the {genre} genre, "
            "ensuring coherence with the tone, characters, and narrative established so far:\n"
        )
        instructions_len = len(instructions)

        stories = self._generate_single_iteration(instructions + prompt)
        ranked_stories = evaluator.evaluate_multiple(
            [story[instructions_len:] for story in stories]
        )

        stories = [story for story, _ in ranked_stories[: self.config.num_beams]]

        if stories:
            for _ in range(self.config.num_iterations):
                all_stories = []
                for story in stories:
                    continuations = self._generate_single_iteration(
                        instructions + story
                    )
                    all_stories.extend(continuations)
                ranked_stories = evaluator.evaluate_multiple(
                    [story[instructions_len:] for story in all_stories]
                )
                stories = [story for story, _ in ranked_stories[: self.config.num_beams]]

        return stories

    def _generate_single_iteration(self, prompt: str) -> list[str]:
        """
        Generate multiple continuations for a single iteration using sampled
        beam search. Returns decoded texts (prompt included).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        # Length budget for this call: current prompt length plus this
        # round's share of max_length. Kept local — the original stored this
        # in config.continuation_length, silently clobbering the user's
        # config on every call.
        target_length = (
            len(input_ids[0]) + self.config.max_length // self.config.num_iterations
        )

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=target_length,
                num_beams=self.config.num_beams,
                num_return_sequences=self.config.num_return_sequences,
                early_stopping=True,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                temperature=self.config.temperature,
                top_k=self.config.top_k,
                top_p=self.config.top_p,
                do_sample=True,
            )

        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
story_beam_search/scoring.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from dataclasses import dataclass
3
+ from typing import Protocol
4
+ import torch
5
+ from transformers import PreTrainedModel, PreTrainedTokenizer, Pipeline
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+
8
+
9
class StoryScorer(Protocol):
    """Protocol defining the interface for story scoring components.

    Satisfied structurally (duck typing) by CoherenceScorer, FluencyScorer
    and GenreAlignmentScorer.
    """

    def score(self, story: str) -> float:
        """Return a score between 0 and 1; higher is better.

        NOTE(review): the [0, 1] range is not enforced — CoherenceScorer can
        return negative cosine similarities. StoryEvaluator normalizes raw
        scores before combining them.
        """
        ...
15
+
16
+
17
@dataclass
class CombinedScore:
    """Per-criterion scores and their weighted total for one story.

    Holds raw scores after StoryEvaluator.evaluate(); the fields are
    normalized in place (and total filled in) by evaluate_multiple().
    """

    coherence: float = 0.0
    fluency: float = 0.0
    genre_alignment: float = 0.0
    total: float = 0.0
23
+
24
+
25
class CoherenceScorer(StoryScorer):
    """Scores a story by the average cosine similarity between embeddings of
    adjacent sentences ([CLS]-token vectors from a BERT-style encoder)."""

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
        max_pairs: int = 3,  # NOTE(review): stored but never used in score() — confirm intent
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.max_pairs = max_pairs

    def score(self, story: str) -> float:
        """Calculate coherence score based on sentences cosine similarity."""

        # Naive sentence split on '.'; abbreviations/ellipses are not handled
        sentences = [s.strip() for s in story.split(".") if s.strip()]

        embeddings = []

        # Generate embeddings for each sentence
        for sentence in sentences:
            inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
            with torch.no_grad():
                # Assumes the masked-LM wrapper exposes its encoder as `.bert`
                # (true for bert-base-uncased) — TODO confirm for other
                # backbones. [:, 0, :] takes the [CLS] position as the
                # sentence vector.
                emb = self.model.bert(**inputs).last_hidden_state[:, 0, :]
            embeddings.append(emb.cpu().numpy())

        # Calculate cosine similarity between adjacent embeddings
        coherence_scores = []
        for i in range(len(embeddings) - 1):
            sim = cosine_similarity(embeddings[i], embeddings[i + 1])[0][0]
            coherence_scores.append(sim)

        # Average coherence score; 0.0 for stories with fewer than 2 sentences
        avg_coherence = (
            sum(coherence_scores) / len(coherence_scores) if coherence_scores else 0.0
        )
        return avg_coherence
63
+
64
+
65
class FluencyScorer(StoryScorer):
    """Pseudo-likelihood fluency: mask each token in turn and average the
    model's probability of the original token at that position."""

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def score(self, story: str) -> float:
        # Mask each token in the story and calculate the probability of the original token
        # Fluency is measured by the average probability of each token in the story
        inputs = self.tokenizer(story, return_tensors="pt").to(self.device)
        input_ids = inputs.input_ids
        mask_token_id = self.tokenizer.mask_token_id

        if mask_token_id is None:
            # Fallback for tokenizers without a native mask token (this scorer
            # is wired to the causal text model in stories_generator.py).
            # NOTE(review): encode() may prepend special tokens, so [0] could
            # be a BOS/CLS id rather than the id of "[MASK]" — verify.
            self.tokenizer.mask_token = "[MASK]"
            mask_token_id = self.tokenizer.encode(self.tokenizer.mask_token)[0]

        fluency_scores = []
        # Skips positions 0 and last — presumably special tokens; TODO confirm
        # this holds for the tokenizer actually used.
        for i in range(1, input_ids.size(1) - 1):
            masked_input_ids = input_ids.clone()
            masked_input_ids[0, i] = mask_token_id

            with torch.no_grad():
                outputs = self.model(input_ids=masked_input_ids)
                logits = outputs.logits

            # Probability assigned to the original token at the masked slot
            original_token_id = input_ids[0, i]
            token_probability = logits[0, i].softmax(dim=-1)[original_token_id].item()
            fluency_scores.append(token_probability)

        # Average probability; 0.0 when the input has no scorable positions
        avg_fluency = (
            sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0.0
        )
        return avg_fluency
104
+
105
+
106
class GenreAlignmentScorer(StoryScorer):
    """Scores how strongly a story matches a target genre using a zero-shot
    classification pipeline, averaged over individual sentences."""

    def __init__(self, pipeline: Pipeline, genre: str):
        self.pipeline = pipeline
        self.genre = genre

    def score(self, story: str) -> float:
        # Without a target genre there is nothing to align against: neutral.
        if not self.genre:
            return 0.5

        # Score sentence by sentence so the genre must be maintained
        # throughout the story rather than just on average overall.
        sentence_scores = [
            self.pipeline(part.strip(), candidate_labels=[self.genre], multi_label=True)[
                "scores"
            ][0]
            for part in story.split(".")
            if part.strip()
        ]

        return sum(sentence_scores) / len(sentence_scores) if sentence_scores else 0.0
126
+
127
+
128
class StoryEvaluator:
    """Combines coherence, fluency and genre-alignment scorers into one
    weighted ranking over candidate stories."""

    def __init__(
        self,
        coherence_scorer: CoherenceScorer,
        fluency_scorer: FluencyScorer,
        genre_scorer: GenreAlignmentScorer,
        weights: tuple[float, float, float] = (0.4, 0.3, 0.3),
    ):
        self.coherence_scorer = coherence_scorer
        self.fluency_scorer = fluency_scorer
        self.genre_scorer = genre_scorer
        # Order: (coherence, fluency, genre_alignment)
        self.weights = weights

    def evaluate(self, story: str, max_scores: list[float]) -> CombinedScore:
        """Score one story with all three scorers.

        Updates *max_scores* (running per-criterion maxima) in place;
        ``total`` is left at 0 until normalization in evaluate_multiple().
        """
        coherence = self.coherence_scorer.score(story)
        fluency = self.fluency_scorer.score(story)
        genre_alignment = self.genre_scorer.score(story)

        max_scores[0] = max(max_scores[0], coherence)
        max_scores[1] = max(max_scores[1], fluency)
        max_scores[2] = max(max_scores[2], genre_alignment)

        return CombinedScore(
            coherence=coherence,
            fluency=fluency,
            genre_alignment=genre_alignment,
            total=0,
        )

    def evaluate_multiple(self, stories: list[str]) -> list[tuple[str, CombinedScore]]:
        """Evaluate multiple stories and return them sorted by total score."""

        # Scores are normalized by the max scores observed in this batch so
        # that criteria on different raw scales become comparable before the
        # weighted sum.

        # Reset max scores
        max_scores = [0.0, 0.0, 0.0]

        scored_stories = [
            (story, self.evaluate(story, max_scores)) for story in stories
        ]

        # Guard against division by zero: when every story scored 0 on a
        # criterion, normalizing by 1.0 keeps those scores at 0 instead of
        # producing NaN/inf (which would poison the sort below).
        safe_max = [m if m > 0 else 1.0 for m in max_scores]

        # Normalize scores in place and compute the weighted total
        for _, scores in scored_stories:
            scores.coherence, scores.fluency, scores.genre_alignment = np.divide(
                [scores.coherence, scores.fluency, scores.genre_alignment],
                safe_max,
            )
            scores.total = np.dot(
                [scores.coherence, scores.fluency, scores.genre_alignment], self.weights
            )

        return sorted(scored_stories, key=lambda x: x[1].total, reverse=True)
story_beam_search/stories_generator.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from dataclasses import dataclass
4
+ from story_beam_search.scoring import CombinedScore
5
+ from transformers import (
6
+ AutoModelForCausalLM,
7
+ AutoTokenizer,
8
+ AutoModelForMaskedLM,
9
+ pipeline,
10
+ Pipeline,
11
+ )
12
+ import torch
13
+
14
+ auth_token = os.getenv("HF_TOKEN", None)
15
+
16
+
17
@dataclass
class ModelConfig:
    """Model identifiers and the torch device used across the system."""

    text_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
    bert_name: str = "bert-base-uncased"  # "answerdotai/ModernBERT-base"
    zero_shot_name: str = "facebook/bart-large-mnli"
    # Device preference: Apple MPS, then CUDA, then CPU.
    # Evaluated once, when the class is defined.
    device: str = (
        "mps"
        if torch.backends.mps.is_available()
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )
27
+
28
+
29
@dataclass
class Models:
    """Container for all loaded models and tokenizers."""

    device: torch.device
    text_model: AutoModelForCausalLM  # causal LM: story generation and fluency scoring
    text_tokenizer: AutoTokenizer
    bert_model: AutoModelForMaskedLM  # BERT encoder used for coherence scoring
    bert_tokenizer: AutoTokenizer
    zero_shot_pipeline: Pipeline  # zero-shot classifier: genre alignment + prompt screening
39
+
40
+
41
class ModelLoader:
    """Handles loading and initialization of all required models."""

    def __init__(self, config: ModelConfig | None = None):
        # None default instead of a mutable ModelConfig() default argument,
        # which would be created once and shared across all instances.
        self.config = config or ModelConfig()
        self.device = torch.device(self.config.device)

    def load_models(self) -> Models:
        """Load all required models and return them in a Models container."""

        # Load text model for writing stories. The HF token is passed to both
        # the tokenizer AND the model: gated repos (e.g. meta-llama) require
        # it for the weights as well, not just the tokenizer files.
        print(f"Loading Text model ({self.config.text_model_name})...")
        text_tokenizer = AutoTokenizer.from_pretrained(
            self.config.text_model_name, token=auth_token
        )
        text_model = AutoModelForCausalLM.from_pretrained(
            self.config.text_model_name, token=auth_token
        ).to(self.device)
        text_model.eval()  # inference only

        # Load BERT model for coherence and fluency scoring
        print(f"Loading BERT model ({self.config.bert_name})...")
        bert_tokenizer = AutoTokenizer.from_pretrained(self.config.bert_name)
        bert_model = AutoModelForMaskedLM.from_pretrained(self.config.bert_name).to(
            self.device
        )
        bert_model.eval()

        # Load Zero-Shot classification pipeline for genre alignment scoring
        print("Loading Zero-Shot Classification pipeline...")
        zero_shot_pipeline = pipeline(
            "zero-shot-classification",
            model=self.config.zero_shot_name,
            device=self.device,
        )

        return Models(
            device=self.device,
            text_model=text_model,
            text_tokenizer=text_tokenizer,
            bert_model=bert_model,
            bert_tokenizer=bert_tokenizer,
            zero_shot_pipeline=zero_shot_pipeline,
        )
85
+
86
+
87
class StoryGenerationSystem:
    """
    High-level class that coordinates model loading and initialization of all components.
    Acts as a facade for the entire story generation system.
    """

    def __init__(self, model_config: ModelConfig | None = None):
        # None default instead of a mutable ModelConfig() default argument,
        # which would be created once and shared across all instances.
        self.model_loader = ModelLoader(model_config or ModelConfig())
        self.models = None
        self.beam_search = None
        self.evaluator = None
        self.storyness = None  # zero-shot "is this a story?" screen
        self.injection_guard = None  # zero-shot "prompt injection" screen

    def initialize(self) -> None:
        """Initialize all components of the story generation system."""
        # Imported lazily, at initialization time rather than module import.
        from story_beam_search.beam_search import BeamSearchGenerator, BeamSearchConfig
        from story_beam_search.scoring import (
            CoherenceScorer,
            FluencyScorer,
            GenreAlignmentScorer,
            StoryEvaluator,
        )

        # Load all models
        self.models = self.model_loader.load_models()

        # Initialize beam search
        self.beam_search = BeamSearchGenerator(
            model=self.models.text_model,
            tokenizer=self.models.text_tokenizer,
            device=self.models.device,
            config=BeamSearchConfig(),
        )

        # Initialize scorers
        coherence_scorer = CoherenceScorer(
            model=self.models.bert_model,
            tokenizer=self.models.bert_tokenizer,
            device=self.models.device,
        )

        fluency_scorer = FluencyScorer(
            model=self.models.text_model,
            tokenizer=self.models.text_tokenizer,
            device=self.models.device,
        )

        # Note: genre_scorer is created per request as it depends on the
        # user's genre choice, so only a factory is stored here.
        self.create_evaluator = lambda genre: StoryEvaluator(
            coherence_scorer=coherence_scorer,
            fluency_scorer=fluency_scorer,
            genre_scorer=GenreAlignmentScorer(
                pipeline=self.models.zero_shot_pipeline, genre=genre
            ),
        )

        # 'Misusing' the GenreAlignmentScorer as plain zero-shot classifiers
        # to vet incoming prompts before any generation happens.
        self.storyness = GenreAlignmentScorer(
            pipeline=self.models.zero_shot_pipeline, genre="story"
        )
        self.injection_guard = GenreAlignmentScorer(
            pipeline=self.models.zero_shot_pipeline, genre="prompt injection"
        )

    def _screen_prompt(self, prompt: str) -> None:
        """Low-effort prompt-injection screen using the zero-shot classifier.

        Raises:
            ValueError: if the prompt does not look like a story, or if any
                segment of it scores high as a prompt injection.
        """
        # "Storyness" depends only on the whole prompt, so it is checked once
        # up front (the original re-tested it inside the per-segment loop and
        # skipped it entirely for prompts without alphanumeric segments).
        storyness_score = self.storyness.score(prompt)
        if storyness_score < 0.2:
            print("Potential prompt injection detected.")
            print(f"storyness_score: {storyness_score}")
            raise ValueError("Prompt does not seem like a story. Please try again.")

        # Screen each distinct alphanumeric segment separately so an
        # injection cannot hide behind punctuation boundaries.
        segments = {s for s in re.split(r"[^a-zA-Z0-9 ]+", prompt) if s.strip()}
        for segment in segments:
            injection_score = self.injection_guard.score(segment)
            if injection_score > 0.2:
                print("Potential prompt injection detected.")
                print(f"injection_score: {injection_score}")
                print("Prompt:", segment)
                raise ValueError("Prompt does not seem like a story. Please try again.")

    def generate_and_evaluate(
        self, prompt: str, genre: str, num_stories: int = 3
    ) -> list[tuple[str, CombinedScore]]:
        """Generate stories and evaluate them.

        Raises:
            RuntimeError: if initialize() has not been called.
            ValueError: if the prompt fails the injection screening.
        """
        if not self.models:
            raise RuntimeError("System not initialized. Call initialize() first.")

        # Reject prompts that do not look like stories before spending compute
        self._screen_prompt(prompt)

        # Create evaluator with specified genre
        evaluator = self.create_evaluator(genre)

        # Generate stories.
        # This is not strict beam search, inspired by
        # https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute
        # to generate more diverse stories.
        all_stories = []
        for _ in range(num_stories):
            stories = self.beam_search.generate_iterations(prompt, genre, evaluator)
            ranked_stories = evaluator.evaluate_multiple(stories)
            # keep the top story of this beam search iteration
            all_stories.append(ranked_stories[0][0])

        # Rank the per-iteration winners against each other
        ranked_stories = evaluator.evaluate_multiple(all_stories)
        # Return top k stories with their scores
        return ranked_stories[:num_stories]