NoLev committed on
Commit
930332a
·
verified ·
1 Parent(s): 1422914

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +24 -12
app/main.py CHANGED
@@ -274,7 +274,7 @@ def generate_manuscript_excerpts(manuscript: str, prompt: str):
274
  print(f"Excerpt generation error: {e}")
275
  return ""
276
 
277
- # Generate a prompt based on manuscript, outline, characters, and prompt history
278
  def generate_prompt(manuscript: str, outline: str, characters: str, project_id: str):
279
  # Initialize prompt components
280
  prompt_parts = []
@@ -297,10 +297,15 @@ def generate_prompt(manuscript: str, outline: str, characters: str, project_id:
297
  except Exception as e:
298
  print(f"Prompt history context extraction error: {e}")
299
 
300
- # Extract key sentences from manuscript using semantic similarity
 
 
301
  if manuscript:
302
  try:
303
- sentences = nltk.sent_tokenize(manuscript)[:50] # Limit to 50 sentences for efficiency
 
 
 
304
  if sentences:
305
  sentence_embeddings = excerpt_model.encode(sentences)
306
  avg_embedding = np.mean(sentence_embeddings, axis=0)
@@ -309,23 +314,30 @@ def generate_prompt(manuscript: str, outline: str, characters: str, project_id:
309
  )
310
  top_indices = similarities.argsort()[-3:][::-1] # Top 3 sentences
311
  key_sentences = [sentences[i] for i in top_indices]
312
- context += "Recent Manuscript Context: " + " ".join(key_sentences) + "\n"
313
  except Exception as e:
314
  print(f"Manuscript context extraction error: {e}")
315
 
316
- # Extract key plot points from outline
317
  if outline:
318
  try:
319
  sentences = nltk.sent_tokenize(outline)[:20] # Limit to 20 sentences
320
  if sentences:
321
  sentence_embeddings = excerpt_model.encode(sentences)
322
- avg_embedding = np.mean(sentence_embeddings, axis=0)
323
- similarities = np.dot(sentence_embeddings, avg_embedding) / (
324
- np.linalg.norm(sentence_embeddings, axis=1) * np.linalg.norm(avg_embedding)
325
- )
 
 
 
 
 
 
 
326
  top_indices = similarities.argsort()[-2:][::-1] # Top 2 sentences
327
  key_points = [sentences[i] for i in top_indices]
328
- context += "Key Plot Points: " + " ".join(key_points) + "\n"
329
  except Exception as e:
330
  print(f"Outline context extraction error: {e}")
331
 
@@ -346,10 +358,10 @@ def generate_prompt(manuscript: str, outline: str, characters: str, project_id:
346
  print(f"Character context extraction error: {e}")
347
 
348
  # Construct the prompt
349
- prompt_parts.append("Write a detailed prose passage for a novel based on the following context.")
350
  if context:
351
  prompt_parts.append(context)
352
- prompt_parts.append("Focus on advancing the story with vivid descriptions, character development, and alignment with the provided plot, character details, and previous prompts. Create a scene that feels immersive, consistent with the tone and style of the existing manuscript, and builds on prior narrative directions.")
353
 
354
  generated_prompt = "\n".join(prompt_parts)
355
  if len(generated_prompt) > MAX_MEDIUMTEXT_CHARS:
 
274
  print(f"Excerpt generation error: {e}")
275
  return ""
276
 
277
+ # Generate a prompt based on the last 1000 words of the manuscript, outline, characters, and prompt history
278
  def generate_prompt(manuscript: str, outline: str, characters: str, project_id: str):
279
  # Initialize prompt components
280
  prompt_parts = []
 
297
  except Exception as e:
298
  print(f"Prompt history context extraction error: {e}")
299
 
300
+ # Use the last 1000 words of the manuscript from the database
301
+ inputs = get_latest_inputs(project_id)
302
+ manuscript = inputs.get("manuscript") or manuscript
303
  if manuscript:
304
  try:
305
+ words = manuscript.split()
306
+ # Take the last 1000 words, or all if fewer than 1000
307
+ last_1000_words = " ".join(words[-1000:]) if len(words) > 1000 else manuscript
308
+ sentences = nltk.sent_tokenize(last_1000_words)[:15] # Limit to 15 sentences for efficiency
309
  if sentences:
310
  sentence_embeddings = excerpt_model.encode(sentences)
311
  avg_embedding = np.mean(sentence_embeddings, axis=0)
 
314
  )
315
  top_indices = similarities.argsort()[-3:][::-1] # Top 3 sentences
316
  key_sentences = [sentences[i] for i in top_indices]
317
+ context += "Recent Manuscript Context (Last 1000 Words): " + " ".join(key_sentences) + "\n"
318
  except Exception as e:
319
  print(f"Manuscript context extraction error: {e}")
320
 
321
+ # Extract key plot points from outline, focusing on current context
322
  if outline:
323
  try:
324
  sentences = nltk.sent_tokenize(outline)[:20] # Limit to 20 sentences
325
  if sentences:
326
  sentence_embeddings = excerpt_model.encode(sentences)
327
+ # If manuscript context exists, align outline with it
328
+ if manuscript:
329
+ manuscript_embedding = excerpt_model.encode(last_1000_words)
330
+ similarities = np.dot(sentence_embeddings, manuscript_embedding) / (
331
+ np.linalg.norm(sentence_embeddings, axis=1) * np.linalg.norm(manuscript_embedding)
332
+ )
333
+ else:
334
+ avg_embedding = np.mean(sentence_embeddings, axis=0)
335
+ similarities = np.dot(sentence_embeddings, avg_embedding) / (
336
+ np.linalg.norm(sentence_embeddings, axis=1) * np.linalg.norm(avg_embedding)
337
+ )
338
  top_indices = similarities.argsort()[-2:][::-1] # Top 2 sentences
339
  key_points = [sentences[i] for i in top_indices]
340
+ context += "Relevant Plot Points: " + " ".join(key_points) + "\n"
341
  except Exception as e:
342
  print(f"Outline context extraction error: {e}")
343
 
 
358
  print(f"Character context extraction error: {e}")
359
 
360
  # Construct the prompt
361
+ prompt_parts.append("Write a detailed prose passage for a novel based on the following context, focusing on the current scene.")
362
  if context:
363
  prompt_parts.append(context)
364
+ prompt_parts.append("Advance the story with vivid descriptions and character development, aligning with the provided plot points, character details, and previous prompts. Ensure the scene is immersive, consistent with the tone and style of the manuscript's latest section, and continues the immediate narrative arc.")
365
 
366
  generated_prompt = "\n".join(prompt_parts)
367
  if len(generated_prompt) > MAX_MEDIUMTEXT_CHARS: