NoLev committed on
Commit
e3748bb
·
verified ·
1 Parent(s): 6665991

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +102 -97
app/main.py CHANGED
@@ -11,6 +11,7 @@ from transformers import pipeline
11
  import requests
12
  from urllib.parse import urlparse
13
  import logging # For verbose logging
 
14
 
15
  # Suppress warnings for cleaner logs
16
  import warnings
@@ -44,7 +45,8 @@ OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
44
  # Model setup (CPU only)
45
  SUMMARIZER_MODELS = ["facebook/bart-large-cnn", "distilbart-cnn-6-6"]
46
  EXCERPT_MODELS = ["sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers/all-distilroberta-v1"]
47
- DEFAULT_PROCESSING_MODEL = "facebook/bart-large-cnn"
 
48
 
49
  # Cache for models to avoid reloading
50
  processing_model_cache = {}
@@ -52,12 +54,15 @@ processing_model_cache = {}
52
  def get_processing_model(model_name: str, is_summarizer: bool):
53
  if model_name not in processing_model_cache:
54
  try:
55
- if model_name in SUMMARIZER_MODELS and is_summarizer:
56
  processing_model_cache[model_name] = pipeline("summarization", model=model_name, device=-1) # CPU
57
- elif model_name in EXCERPT_MODELS and not is_summarizer:
58
  processing_model_cache[model_name] = SentenceTransformer(model_name, device='cpu')
59
  else:
60
- raise ValueError(f"Unsupported model: {model_name}")
 
 
 
61
  except Exception as e:
62
  print(f"Error loading model {model_name}: {e}")
63
  raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
@@ -102,7 +107,7 @@ class PromptRequest(BaseModel):
102
  objects: str = "" # New field for objects
103
  prompt: str = "" # Manual prompt field
104
  model: str # OpenRouter model
105
- processing_model: str = DEFAULT_PROCESSING_MODEL
106
  summary_length: int = 1000 # Target ~1000 words for summarizers
107
 
108
  class PasswordRequest(BaseModel):
@@ -188,10 +193,19 @@ def generate_manuscript_summary(manuscript: str, processing_model: str, target_w
188
 
189
  # Split into smaller chunks for CPU efficiency
190
  try:
191
- sentences = nltk.sent_tokenize(last_10000_words)
 
 
 
 
 
 
 
192
  except Exception as e:
193
  logger.error(f"Sentence tokenization error: {e}")
194
- return ""
 
 
195
 
196
  chunks = []
197
  current_chunk = ""
@@ -238,6 +252,38 @@ def generate_manuscript_summary(manuscript: str, processing_model: str, target_w
238
  return combined_summary[:target_max_length]
239
  return combined_summary
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  # Generate prompt based on last 10,000 words, outline, characters, locations, objects, manual prompt, and prompt history
242
  def generate_prompt(manuscript: str, outline: str, characters: str, locations: str, objects: str, manual_prompt: str, project_id: str, model: str, processing_model: str, summary_length: int):
243
  logger.info(f"Generating prompt for {project_id}")
@@ -249,7 +295,8 @@ def generate_prompt(manuscript: str, outline: str, characters: str, locations: s
249
  try:
250
  previous_prompts = [h['prompt'] for h in history if h['prompt']]
251
  if previous_prompts:
252
- processing_model_instance = get_processing_model(processing_model, is_summarizer=False)
 
253
  prompt_embeddings = processing_model_instance.encode(previous_prompts, batch_size=4)
254
  avg_embedding = np.mean(prompt_embeddings, axis=0)
255
  similarities = np.dot(prompt_embeddings, avg_embedding) / (
@@ -276,10 +323,14 @@ def generate_prompt(manuscript: str, outline: str, characters: str, locations: s
276
  summary = generate_manuscript_summary(last_10000_words, processing_model, summary_length)
277
  if summary:
278
  context += "Manuscript Context (Last 10,000 Words Summary):\n" + summary + "\n\n"
279
- # Extract key sentences
280
- sentences = nltk.sent_tokenize(last_10000_words)[:15]
 
 
 
 
281
  if sentences:
282
- processing_model_instance = get_processing_model(processing_model, is_summarizer=False)
283
  sentence_embeddings = processing_model_instance.encode(sentences, batch_size=4)
284
  avg_embedding = np.mean(sentence_embeddings, axis=0)
285
  similarities = np.dot(sentence_embeddings, avg_embedding) / (
@@ -291,73 +342,50 @@ def generate_prompt(manuscript: str, outline: str, characters: str, locations: s
291
  except Exception as e:
292
  logger.warning(f"Manuscript context extraction error: {e}")
293
 
294
- # Extract relevant outline points
295
  if outline:
296
  try:
297
- sentences = nltk.sent_tokenize(outline)[:20]
298
- if sentences and manuscript:
299
- processing_model_instance = get_processing_model(processing_model, is_summarizer=False)
300
- manuscript_embedding = processing_model_instance.encode(last_10000_words, batch_size=4)
301
- sentence_embeddings = processing_model_instance.encode(sentences, batch_size=4)
 
 
 
 
 
302
  similarities = np.dot(sentence_embeddings, manuscript_embedding) / (
303
  np.linalg.norm(sentence_embeddings, axis=1) * np.linalg.norm(manuscript_embedding)
304
  )
305
  top_indices = similarities.argsort()[-3:][::-1]
306
- key_points = [sentences[i] for i in top_indices]
307
  context += "Relevant Plot Points:\n" + "\n".join(key_points) + "\n\n"
308
  except Exception as e:
309
  logger.warning(f"Outline context extraction error: {e}")
310
 
311
- # Extract character details
312
- if characters:
313
- try:
314
- char_sentences = characters.split("\n")[:10]
315
- if char_sentences and manuscript:
316
- processing_model_instance = get_processing_model(processing_model, is_summarizer=False)
317
- manuscript_embedding = processing_model_instance.encode(last_10000_words, batch_size=4)
318
- char_embeddings = processing_model_instance.encode(char_sentences, batch_size=4)
319
- similarities = np.dot(char_embeddings, manuscript_embedding) / (
320
- np.linalg.norm(char_embeddings, axis=1) * np.linalg.norm(manuscript_embedding)
321
- )
322
- top_indices = similarities.argsort()[-3:][::-1]
323
- key_chars = [char_sentences[i] for i in top_indices]
324
- context += "Relevant Character Details:\n" + "\n".join(key_chars) + "\n\n"
325
- except Exception as e:
326
- logger.warning(f"Character context extraction error: {e}")
327
-
328
- # Extract location details
329
- if locations:
330
- try:
331
- loc_sentences = locations.split("\n")[:10]
332
- if loc_sentences and manuscript:
333
- processing_model_instance = get_processing_model(processing_model, is_summarizer=False)
334
- manuscript_embedding = processing_model_instance.encode(last_10000_words, batch_size=4)
335
- loc_embeddings = processing_model_instance.encode(loc_sentences, batch_size=4)
336
- similarities = np.dot(loc_embeddings, manuscript_embedding) / (
337
- np.linalg.norm(loc_embeddings, axis=1) * np.linalg.norm(manuscript_embedding)
338
- )
339
- top_indices = similarities.argsort()[-3:][::-1]
340
- key_locs = [loc_sentences[i] for i in top_indices]
341
- context += "Relevant Location Details:\n" + "\n".join(key_locs) + "\n\n"
342
- except Exception as e:
343
- logger.warning(f"Location context extraction error: {e}")
344
-
345
- # Extract object details
346
- if objects:
347
- try:
348
- obj_sentences = objects.split("\n")[:10]
349
- if obj_sentences and manuscript:
350
- processing_model_instance = get_processing_model(processing_model, is_summarizer=False)
351
- manuscript_embedding = processing_model_instance.encode(last_10000_words, batch_size=4)
352
- obj_embeddings = processing_model_instance.encode(obj_sentences, batch_size=4)
353
- similarities = np.dot(obj_embeddings, manuscript_embedding) / (
354
- np.linalg.norm(obj_embeddings, axis=1) * np.linalg.norm(manuscript_embedding)
355
- )
356
- top_indices = similarities.argsort()[-3:][::-1]
357
- key_objs = [obj_sentences[i] for i in top_indices]
358
- context += "Relevant Object Details:\n" + "\n".join(key_objs) + "\n\n"
359
- except Exception as e:
360
- logger.warning(f"Object context extraction error: {e}")
361
 
362
  # If manual prompt provided, use it as base
363
  if manual_prompt:
@@ -373,39 +401,16 @@ def generate_prompt(manuscript: str, outline: str, characters: str, locations: s
373
  "Craft a prompt that continues the narrative arc from the last 10,000 words of the manuscript, aligns with the provided outline, incorporates relevant character, location, and object details, and builds on the last generated prompt (if available). Ensure the prompt is specific, vivid, and sets up an immersive scene that maintains the tone, style, and direction of the story."
374
  ])
375
 
376
- headers = {
377
- "Authorization": f"Bearer {OPENROUTER_API_KEY}",
378
- "Content-Type": "application/json",
379
- "HTTP-Referer": "https://huggingface.co/spaces/NoLev/NovelCrafter",
380
- "X-Title": "Novel Prompt Generator"
381
- }
382
-
383
- payload = {
384
- "model": model,
385
- "messages": [
386
- {"role": "system", "content": system_prompt},
387
- {"role": "user", "content": user_prompt}
388
- ],
389
- "temperature": 0.7,
390
- "max_tokens": 600
391
- }
392
 
393
- try:
394
- response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload)
395
- if response.status_code != 200:
396
- raise HTTPException(status_code=response.status_code, detail="Error from OpenRouter API")
397
- response_data = response.json()
398
- generated_prompt = response_data.get("choices", [{}])[0].get("message", {}).get("content", "")
399
- if len(generated_prompt) > MAX_MEDIUMTEXT_CHARS:
400
- generated_prompt = generated_prompt[:MAX_MEDIUMTEXT_CHARS]
401
- logger.warning(f"Generated prompt truncated to {MAX_MEDIUMTEXT_CHARS} characters")
402
- return generated_prompt
403
- except Exception as e:
404
- logger.error(f"OpenRouter API request failed: {e}")
405
- raise HTTPException(status_code=500, detail=f"Failed to generate prompt: {str(e)}")
406
 
407
  # Save inputs to database
408
  def save_inputs(project_id: str, manuscript: str, outline: str, characters: str, locations: str = "", objects: str = "", last_prompt: str = None):
 
409
  if len(manuscript) > MAX_MEDIUMTEXT_CHARS:
410
  manuscript = manuscript[:MAX_MEDIUMTEXT_CHARS]
411
  logger.warning(f"Manuscript truncated to {MAX_MEDIUMTEXT_CHARS} characters")
 
11
  import requests
12
  from urllib.parse import urlparse
13
  import logging # For verbose logging
14
+ import time # For retries
15
 
16
  # Suppress warnings for cleaner logs
17
  import warnings
 
45
  # Model setup (CPU only)
46
  SUMMARIZER_MODELS = ["facebook/bart-large-cnn", "distilbart-cnn-6-6"]
47
  EXCERPT_MODELS = ["sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers/all-distilroberta-v1"]
48
+ DEFAULT_SUMMARIZER = "facebook/bart-large-cnn"
49
+ DEFAULT_EXCERPT = "sentence-transformers/all-MiniLM-L6-v2"
50
 
51
  # Cache for models to avoid reloading
52
  processing_model_cache = {}
 
54
  def get_processing_model(model_name: str, is_summarizer: bool):
55
  if model_name not in processing_model_cache:
56
  try:
57
+ if is_summarizer and model_name in SUMMARIZER_MODELS:
58
  processing_model_cache[model_name] = pipeline("summarization", model=model_name, device=-1) # CPU
59
+ elif not is_summarizer and model_name in EXCERPT_MODELS:
60
  processing_model_cache[model_name] = SentenceTransformer(model_name, device='cpu')
61
  else:
62
+ # Fallback: Use default for type
63
+ fallback = DEFAULT_SUMMARIZER if is_summarizer else DEFAULT_EXCERPT
64
+ print(f"Using fallback model {fallback} for {model_name} ({'summarizer' if is_summarizer else 'excerpt'})")
65
+ return get_processing_model(fallback, is_summarizer)
66
  except Exception as e:
67
  print(f"Error loading model {model_name}: {e}")
68
  raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
 
107
  objects: str = "" # New field for objects
108
  prompt: str = "" # Manual prompt field
109
  model: str # OpenRouter model
110
+ processing_model: str = DEFAULT_SUMMARIZER
111
  summary_length: int = 1000 # Target ~1000 words for summarizers
112
 
113
  class PasswordRequest(BaseModel):
 
193
 
194
  # Split into smaller chunks for CPU efficiency
195
  try:
196
+ # Try new punkt_tab first
197
+ try:
198
+ nltk.data.find('tokenizers/punkt_tab')
199
+ tokenizer = nltk.data.load('tokenizers/punkt_tab/english.pickle')
200
+ sentences = tokenizer.tokenize(last_10000_words)
201
+ except LookupError:
202
+ # Fallback to old punkt
203
+ sentences = nltk.sent_tokenize(last_10000_words)
204
  except Exception as e:
205
  logger.error(f"Sentence tokenization error: {e}")
206
+ # Ultimate fallback: split on periods
207
+ sentences = [s.strip() for s in last_10000_words.replace('\n', ' ').split('.') if s.strip()]
208
+ sentences = sentences[:50] # Limit
209
 
210
  chunks = []
211
  current_chunk = ""
 
252
  return combined_summary[:target_max_length]
253
  return combined_summary
254
 
255
# OpenRouter call with retry for rate limits
def call_openrouter_with_retry(messages: list, model: str, max_tokens: int = 600, temperature: float = 0.7, retries: int = 3) -> str:
    """Call the OpenRouter chat-completions endpoint with retries.

    Args:
        messages: Chat messages, e.g. [{"role": "system", "content": ...}, ...].
        model: OpenRouter model identifier to request.
        max_tokens: Completion token cap forwarded to the API.
        temperature: Sampling temperature forwarded to the API.
        retries: Maximum number of attempts before giving up.

    Returns:
        The assistant message content, or "" if the response carries none.

    Raises:
        HTTPException: non-200 API status (re-raised on the final attempt),
            or 500 when all retries are exhausted.
    """
    # Request parts are invariant across attempts -- build them once,
    # not on every iteration of the retry loop.
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/spaces/NoLev/NovelCrafter",
        "X-Title": "Novel Prompt Generator"
    }
    payload = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens
    }
    for attempt in range(retries):
        try:
            response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=30)
            if response.status_code == 429:
                # Honor the server's Retry-After hint when present;
                # fall back to the previous fixed 5 s wait otherwise.
                try:
                    delay = float(response.headers.get("Retry-After", 5))
                except (TypeError, ValueError):
                    delay = 5.0
                logger.warning(f"Rate limit hit (attempt {attempt+1}/{retries}), waiting {delay}s...")
                time.sleep(delay)
                continue
            if response.status_code != 200:
                raise HTTPException(status_code=response.status_code, detail="Error from OpenRouter API")
            response_data = response.json()
            return response_data.get("choices", [{}])[0].get("message", {}).get("content", "")
        except Exception as e:
            # Any failure (network error, JSON decode, non-200 status) is
            # retried with exponential backoff; the last attempt re-raises.
            if attempt == retries - 1:
                raise
            logger.warning(f"API attempt {attempt+1} failed: {e}, retrying...")
            time.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s, 4s, ...
    raise HTTPException(status_code=500, detail="Max retries exceeded for OpenRouter")
286
+
287
  # Generate prompt based on last 10,000 words, outline, characters, locations, objects, manual prompt, and prompt history
288
  def generate_prompt(manuscript: str, outline: str, characters: str, locations: str, objects: str, manual_prompt: str, project_id: str, model: str, processing_model: str, summary_length: int):
289
  logger.info(f"Generating prompt for {project_id}")
 
295
  try:
296
  previous_prompts = [h['prompt'] for h in history if h['prompt']]
297
  if previous_prompts:
298
+ excerpt_model = DEFAULT_EXCERPT # Use default for embeddings
299
+ processing_model_instance = get_processing_model(excerpt_model, is_summarizer=False)
300
  prompt_embeddings = processing_model_instance.encode(previous_prompts, batch_size=4)
301
  avg_embedding = np.mean(prompt_embeddings, axis=0)
302
  similarities = np.dot(prompt_embeddings, avg_embedding) / (
 
323
  summary = generate_manuscript_summary(last_10000_words, processing_model, summary_length)
324
  if summary:
325
  context += "Manuscript Context (Last 10,000 Words Summary):\n" + summary + "\n\n"
326
+ # Extract key sentences (use default excerpt for embeddings)
327
+ excerpt_model = DEFAULT_EXCERPT
328
+ try:
329
+ sentences = nltk.sent_tokenize(last_10000_words)[:15]
330
+ except:
331
+ sentences = [s.strip() for s in last_10000_words.replace('\n', ' ').split('.') if s.strip()][:15]
332
  if sentences:
333
+ processing_model_instance = get_processing_model(excerpt_model, is_summarizer=False)
334
  sentence_embeddings = processing_model_instance.encode(sentences, batch_size=4)
335
  avg_embedding = np.mean(sentence_embeddings, axis=0)
336
  similarities = np.dot(sentence_embeddings, avg_embedding) / (
 
342
  except Exception as e:
343
  logger.warning(f"Manuscript context extraction error: {e}")
344
 
345
+ # Extract relevant outline points (use default excerpt)
346
  if outline:
347
  try:
348
+ try:
349
+ outline_sentences = nltk.sent_tokenize(outline)[:20]
350
+ except:
351
+ outline_sentences = [s.strip() for s in outline.replace('\n', ' ').split('.') if s.strip()][:20]
352
+ if outline_sentences and manuscript:
353
+ excerpt_model = DEFAULT_EXCERPT
354
+ processing_model_instance = get_processing_model(excerpt_model, is_summarizer=False)
355
+ last_10000_words = " ".join(manuscript.split()[-10000:]) if len(manuscript.split()) > 10000 else manuscript
356
+ manuscript_embedding = processing_model_instance.encode([last_10000_words], batch_size=1)[0]
357
+ sentence_embeddings = processing_model_instance.encode(outline_sentences, batch_size=4)
358
  similarities = np.dot(sentence_embeddings, manuscript_embedding) / (
359
  np.linalg.norm(sentence_embeddings, axis=1) * np.linalg.norm(manuscript_embedding)
360
  )
361
  top_indices = similarities.argsort()[-3:][::-1]
362
+ key_points = [outline_sentences[i] for i in top_indices]
363
  context += "Relevant Plot Points:\n" + "\n".join(key_points) + "\n\n"
364
  except Exception as e:
365
  logger.warning(f"Outline context extraction error: {e}")
366
 
367
+ # Similar for characters, locations, objects (using default excerpt)
368
+ for detail_type, detail_text in [("characters", characters), ("locations", locations), ("objects", objects)]:
369
+ if detail_text:
370
+ try:
371
+ try:
372
+ detail_sentences = nltk.sent_tokenize(detail_text)[:10]
373
+ except:
374
+ detail_sentences = [s.strip() for s in detail_text.replace('\n', ' ').split('.') if s.strip()][:10]
375
+ if detail_sentences and manuscript:
376
+ excerpt_model = DEFAULT_EXCERPT
377
+ processing_model_instance = get_processing_model(excerpt_model, is_summarizer=False)
378
+ last_10000_words = " ".join(manuscript.split()[-10000:]) if len(manuscript.split()) > 10000 else manuscript
379
+ manuscript_embedding = processing_model_instance.encode([last_10000_words], batch_size=1)[0]
380
+ detail_embeddings = processing_model_instance.encode(detail_sentences, batch_size=4)
381
+ similarities = np.dot(detail_embeddings, manuscript_embedding) / (
382
+ np.linalg.norm(detail_embeddings, axis=1) * np.linalg.norm(manuscript_embedding)
383
+ )
384
+ top_indices = similarities.argsort()[-3:][::-1]
385
+ key_details = [detail_sentences[i] for i in top_indices]
386
+ context += f"Relevant {detail_type.title()} Details:\n" + "\n".join(key_details) + "\n\n"
387
+ except Exception as e:
388
+ logger.warning(f"{detail_type} context extraction error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
 
390
  # If manual prompt provided, use it as base
391
  if manual_prompt:
 
401
  "Craft a prompt that continues the narrative arc from the last 10,000 words of the manuscript, aligns with the provided outline, incorporates relevant character, location, and object details, and builds on the last generated prompt (if available). Ensure the prompt is specific, vivid, and sets up an immersive scene that maintains the tone, style, and direction of the story."
402
  ])
403
 
404
+ messages = [
405
+ {"role": "system", "content": system_prompt},
406
+ {"role": "user", "content": user_prompt}
407
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
408
 
409
+ return call_openrouter_with_retry(messages, model, max_tokens=600, temperature=0.7)
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
  # Save inputs to database
412
  def save_inputs(project_id: str, manuscript: str, outline: str, characters: str, locations: str = "", objects: str = "", last_prompt: str = None):
413
+ logger.info(f"Saving inputs for {project_id}")
414
  if len(manuscript) > MAX_MEDIUMTEXT_CHARS:
415
  manuscript = manuscript[:MAX_MEDIUMTEXT_CHARS]
416
  logger.warning(f"Manuscript truncated to {MAX_MEDIUMTEXT_CHARS} characters")