update
Browse files
server/api/models.py
CHANGED
|
@@ -9,11 +9,10 @@ class Choice(BaseModel):
|
|
| 9 |
class StorySegmentResponse(BaseModel):
|
| 10 |
story_text: str = Field(description="The story text. No more than 30 words.")
|
| 11 |
|
| 12 |
-
@validator('story_text')
|
| 13 |
def validate_story_text_length(cls, v):
|
| 14 |
words = v.split()
|
| 15 |
-
if len(words) > 30:
|
| 16 |
-
raise ValueError('Story text must not exceed 30 words')
|
| 17 |
return v
|
| 18 |
|
| 19 |
class StoryPromptsResponse(BaseModel):
|
|
|
|
| 9 |
class StorySegmentResponse(BaseModel):
|
| 10 |
story_text: str = Field(description="The story text. No more than 30 words.")
|
| 11 |
|
|
|
|
| 12 |
def validate_story_text_length(cls, v):
|
| 13 |
words = v.split()
|
| 14 |
+
if len(words) > 40:
|
| 15 |
+
raise ValueError('Story text must not exceed 30 words')
|
| 16 |
return v
|
| 17 |
|
| 18 |
class StoryPromptsResponse(BaseModel):
|
server/core/generators/story_segment_generator.py
CHANGED
|
@@ -17,7 +17,6 @@ class StorySegmentGenerator(BaseGenerator):
|
|
| 17 |
self.universe_epoch = universe_epoch
|
| 18 |
self.universe_story = universe_story
|
| 19 |
self.universe_macguffin = universe_macguffin
|
| 20 |
-
self.max_retries = 5
|
| 21 |
# Then call parent constructor which will create the prompt
|
| 22 |
super().__init__(mistral_client, hero_name=hero_name, hero_desc=hero_desc)
|
| 23 |
|
|
@@ -90,7 +89,6 @@ Your task is to generate the next segment of the story, following these rules:
|
|
| 90 |
|
| 91 |
Hero Description: {self.hero_desc}
|
| 92 |
|
| 93 |
-
- MANDATORY: Each segment must be close to 15 words, no exceptions.
|
| 94 |
"""
|
| 95 |
|
| 96 |
human_template = """
|
|
@@ -112,15 +110,12 @@ Story history:
|
|
| 112 |
|
| 113 |
{what_to_represent}
|
| 114 |
|
| 115 |
-
|
| 116 |
-
Be short. Never describes game variables.
|
| 117 |
|
| 118 |
IT MUST BE THE DIRECT CONTINUATION OF THE CURRENT STORY.
|
| 119 |
You MUST mention the previous situation and what is happening now with the new choice.
|
| 120 |
Never propose choices or options. Never describe the game variables.
|
| 121 |
-
|
| 122 |
-
MANDATORY: Each segment must be close to 15 words, keep it concise.
|
| 123 |
-
Be short. Never describes game variables.
|
| 124 |
"""
|
| 125 |
return ChatPromptTemplate(
|
| 126 |
messages=[
|
|
@@ -191,10 +186,7 @@ Be short. Never describes game variables.
|
|
| 191 |
return 0 <= word_count <= 30
|
| 192 |
|
| 193 |
async def generate(self, story_beat: int, current_time: str, current_location: str, previous_choice: str, story_history: str = "", turn_before_end: int = 0, is_winning_story: bool = False) -> StorySegmentResponse:
|
| 194 |
-
"""Generate the next story segment
|
| 195 |
-
retry_count = 0
|
| 196 |
-
last_attempt = None
|
| 197 |
-
|
| 198 |
is_end = True if story_beat == turn_before_end else False
|
| 199 |
is_death = True if is_end and is_winning_story else False
|
| 200 |
is_victory = True if is_end and not is_winning_story else False
|
|
@@ -211,10 +203,9 @@ Write a story segment that:
|
|
| 211 |
2. Maintains consistency with the universe and story
|
| 212 |
3. Respects all previous rules about length and style
|
| 213 |
4. Naturally integrates the custom elements while staying true to the plot
|
| 214 |
-
Close to 15 words.
|
| 215 |
"""
|
| 216 |
|
| 217 |
-
# Créer les messages
|
| 218 |
messages = self.prompt.format_messages(
|
| 219 |
hero_description=self.hero_desc,
|
| 220 |
FORMATTING_RULES=FORMATTING_RULES,
|
|
@@ -230,32 +221,6 @@ Close to 15 words.
|
|
| 230 |
universe_macguffin=self.universe_macguffin
|
| 231 |
)
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
try:
|
| 237 |
-
story_text = await self.mistral_client.generate_text(current_messages)
|
| 238 |
-
word_count = len(story_text.split())
|
| 239 |
-
|
| 240 |
-
if self._is_valid_length(story_text):
|
| 241 |
-
return StorySegmentResponse(story_text=story_text)
|
| 242 |
-
|
| 243 |
-
retry_count += 1
|
| 244 |
-
if retry_count < self.max_retries:
|
| 245 |
-
# Créer un nouveau message avec le feedback sur la longueur
|
| 246 |
-
if word_count > 15:
|
| 247 |
-
feedback = f"The previous response was too long ({word_count} words). Here was your last attempt:\n\n{story_text}\n\nPlease generate a MUCH SHORTER story segment close to 15 words that continues from: {story_history}"
|
| 248 |
-
|
| 249 |
-
# Réinitialiser les messages avec les messages de base
|
| 250 |
-
current_messages = messages.copy()
|
| 251 |
-
# Ajouter le feedback
|
| 252 |
-
current_messages.append(HumanMessage(content=feedback))
|
| 253 |
-
last_attempt = story_text
|
| 254 |
-
continue
|
| 255 |
-
|
| 256 |
-
raise ValueError(f"Failed to generate text of valid length after {self.max_retries} attempts. Last attempt had {word_count} words.")
|
| 257 |
-
|
| 258 |
-
except Exception as e:
|
| 259 |
-
retry_count += 1
|
| 260 |
-
if retry_count >= self.max_retries:
|
| 261 |
-
raise e
|
|
|
|
| 17 |
self.universe_epoch = universe_epoch
|
| 18 |
self.universe_story = universe_story
|
| 19 |
self.universe_macguffin = universe_macguffin
|
|
|
|
| 20 |
# Then call parent constructor which will create the prompt
|
| 21 |
super().__init__(mistral_client, hero_name=hero_name, hero_desc=hero_desc)
|
| 22 |
|
|
|
|
| 89 |
|
| 90 |
Hero Description: {self.hero_desc}
|
| 91 |
|
|
|
|
| 92 |
"""
|
| 93 |
|
| 94 |
human_template = """
|
|
|
|
| 110 |
|
| 111 |
{what_to_represent}
|
| 112 |
|
| 113 |
+
Never describes game variables.
|
|
|
|
| 114 |
|
| 115 |
IT MUST BE THE DIRECT CONTINUATION OF THE CURRENT STORY.
|
| 116 |
You MUST mention the previous situation and what is happening now with the new choice.
|
| 117 |
Never propose choices or options. Never describe the game variables.
|
| 118 |
+
LIMIT: 15 words.
|
|
|
|
|
|
|
| 119 |
"""
|
| 120 |
return ChatPromptTemplate(
|
| 121 |
messages=[
|
|
|
|
| 186 |
return 0 <= word_count <= 30
|
| 187 |
|
| 188 |
async def generate(self, story_beat: int, current_time: str, current_location: str, previous_choice: str, story_history: str = "", turn_before_end: int = 0, is_winning_story: bool = False) -> StorySegmentResponse:
|
| 189 |
+
"""Generate the next story segment."""
|
|
|
|
|
|
|
|
|
|
| 190 |
is_end = True if story_beat == turn_before_end else False
|
| 191 |
is_death = True if is_end and is_winning_story else False
|
| 192 |
is_victory = True if is_end and not is_winning_story else False
|
|
|
|
| 203 |
2. Maintains consistency with the universe and story
|
| 204 |
3. Respects all previous rules about length and style
|
| 205 |
4. Naturally integrates the custom elements while staying true to the plot
|
|
|
|
| 206 |
"""
|
| 207 |
|
| 208 |
+
# Créer les messages
|
| 209 |
messages = self.prompt.format_messages(
|
| 210 |
hero_description=self.hero_desc,
|
| 211 |
FORMATTING_RULES=FORMATTING_RULES,
|
|
|
|
| 221 |
universe_macguffin=self.universe_macguffin
|
| 222 |
)
|
| 223 |
|
| 224 |
+
# Générer le texte
|
| 225 |
+
story_text = await self.mistral_client.generate_text(messages)
|
| 226 |
+
return StorySegmentResponse(story_text=story_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/core/story_generator.py
CHANGED
|
@@ -26,7 +26,13 @@ class StoryGenerator:
|
|
| 26 |
self.model_name = model_name
|
| 27 |
self.turn_before_end = random.randint(GameConfig.MIN_SEGMENTS_BEFORE_END, GameConfig.MAX_SEGMENTS_BEFORE_END)
|
| 28 |
self.is_winning_story = random.random() < GameConfig.WINNING_STORY_CHANCE
|
|
|
|
|
|
|
| 29 |
self.mistral_client = MistralClient(api_key=api_key, model_name=model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
self.image_prompt_generator = None # Will be initialized with the first universe style
|
| 31 |
self.metadata_generator = None # Will be initialized with hero description
|
| 32 |
self.segment_generators: Dict[str, StorySegmentGenerator] = {}
|
|
@@ -65,7 +71,7 @@ class StoryGenerator:
|
|
| 65 |
|
| 66 |
# Create a new StorySegmentGenerator with all universe parameters
|
| 67 |
self.segment_generators[session_id] = StorySegmentGenerator(
|
| 68 |
-
self.mistral_client,
|
| 69 |
universe_style=style["name"],
|
| 70 |
universe_genre=genre,
|
| 71 |
universe_epoch=epoch,
|
|
|
|
| 26 |
self.model_name = model_name
|
| 27 |
self.turn_before_end = random.randint(GameConfig.MIN_SEGMENTS_BEFORE_END, GameConfig.MAX_SEGMENTS_BEFORE_END)
|
| 28 |
self.is_winning_story = random.random() < GameConfig.WINNING_STORY_CHANCE
|
| 29 |
+
|
| 30 |
+
# Client principal avec limite standard
|
| 31 |
self.mistral_client = MistralClient(api_key=api_key, model_name=model_name)
|
| 32 |
+
|
| 33 |
+
# Client spécifique pour les segments d'histoire avec limite plus basse
|
| 34 |
+
self.story_segment_client = MistralClient(api_key=api_key, model_name=model_name, max_tokens=50)
|
| 35 |
+
|
| 36 |
self.image_prompt_generator = None # Will be initialized with the first universe style
|
| 37 |
self.metadata_generator = None # Will be initialized with hero description
|
| 38 |
self.segment_generators: Dict[str, StorySegmentGenerator] = {}
|
|
|
|
| 71 |
|
| 72 |
# Create a new StorySegmentGenerator with all universe parameters
|
| 73 |
self.segment_generators[session_id] = StorySegmentGenerator(
|
| 74 |
+
self.story_segment_client,
|
| 75 |
universe_style=style["name"],
|
| 76 |
universe_genre=genre,
|
| 77 |
universe_epoch=epoch,
|
server/services/mistral_client.py
CHANGED
|
@@ -31,17 +31,17 @@ logger = logging.getLogger(__name__)
|
|
| 31 |
# Pricing: https://docs.mistral.ai/platform/pricing/
|
| 32 |
|
| 33 |
class MistralClient:
|
| 34 |
-
def __init__(self, api_key: str, model_name: str = "mistral-large-latest"):
|
| 35 |
-
logger.info(f"Initializing MistralClient with model: {model_name}")
|
| 36 |
self.model = ChatMistralAI(
|
| 37 |
mistral_api_key=api_key,
|
| 38 |
model=model_name,
|
| 39 |
-
max_tokens=1000
|
| 40 |
)
|
| 41 |
self.fixing_model = ChatMistralAI(
|
| 42 |
mistral_api_key=api_key,
|
| 43 |
model=model_name,
|
| 44 |
-
max_tokens=1000
|
| 45 |
)
|
| 46 |
|
| 47 |
# Pour gérer le rate limit
|
|
|
|
| 31 |
# Pricing: https://docs.mistral.ai/platform/pricing/
|
| 32 |
|
| 33 |
class MistralClient:
|
| 34 |
+
def __init__(self, api_key: str, model_name: str = "mistral-large-latest", max_tokens: int = 1000):
|
| 35 |
+
logger.info(f"Initializing MistralClient with model: {model_name}, max_tokens: {max_tokens}")
|
| 36 |
self.model = ChatMistralAI(
|
| 37 |
mistral_api_key=api_key,
|
| 38 |
model=model_name,
|
| 39 |
+
max_tokens=max_tokens
|
| 40 |
)
|
| 41 |
self.fixing_model = ChatMistralAI(
|
| 42 |
mistral_api_key=api_key,
|
| 43 |
model=model_name,
|
| 44 |
+
max_tokens=max_tokens
|
| 45 |
)
|
| 46 |
|
| 47 |
# Pour gérer le rate limit
|