rairo commited on
Commit
ed4e2b3
·
verified ·
1 Parent(s): 9cda0fc

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +63 -43
main.py CHANGED
@@ -247,67 +247,86 @@ def generate_story_endpoint():
247
  return jsonify({'error': 'Invalid or expired token'}), 401
248
 
249
  # --- Read Request Data ---
250
- data = request.get_json()
251
- input_type = data.get('input_type', 'text') # "text", "wiki", "bible", "youtube", or "dataframe"
252
- prompt = data.get('prompt')
 
 
253
  story_type = data.get('story_type', 'free_form')
254
  style = data.get('style', 'whimsical')
255
  voice_model = data.get('voice_model', 'aura-asteria-en')
256
  image_model = data.get('image_model', 'hf')
257
  audio_model = data.get('audio_model', 'deepgram')
258
 
259
- if not prompt:
260
- return jsonify({'error': 'Prompt is required'}), 400
 
 
 
 
 
 
261
 
262
-
263
- # --- Select the Appropriate Story Generation Function ---
264
  story_gen_start = time.time()
265
  full_story = None
266
 
267
  if input_type == "text":
268
- from stories import generate_story_from_text
269
- full_story = generate_story_from_text(prompt, story_type)
270
- if input_type == "pdf":
271
- from stories import generate_story_from_text
272
- from stories import get_pdf_text
273
- prompt = get_pdf_text(pdf)
274
  full_story = generate_story_from_text(prompt, story_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  elif input_type == "wiki":
276
  wiki_url = data.get("wiki_url")
277
  if not wiki_url:
278
  return jsonify({'error': 'wiki_url is required for input_type "wiki"'}), 400
279
  from stories import generate_story_from_wiki
280
  full_story = generate_story_from_wiki(wiki_url, story_type)
 
281
  elif input_type == "bible":
282
  bible_reference = data.get("bible_reference")
283
  if not bible_reference:
284
  return jsonify({'error': 'bible_reference is required for input_type "bible"'}), 400
285
  from stories import generate_story_from_bible
286
  full_story = generate_story_from_bible(bible_reference, story_type)
 
287
  elif input_type == "youtube":
288
  youtube_url = data.get("youtube_url")
289
  if not youtube_url:
290
  return jsonify({'error': 'youtube_url is required for input_type "youtube"'}), 400
291
  from stories import generate_story_from_youtube
292
  full_story = generate_story_from_youtube(youtube_url, story_type)
293
- elif input_type == "dataframe":
294
- # Expecting dataframe data as JSON (list of dicts)
295
- df_data = data.get("data")
296
- if not df_data:
297
- return jsonify({'error': 'Data for dataframe input_type is required'}), 400
298
- df = pd.DataFrame(df_data)
299
- from stories import generate_story_from_dataframe
300
- full_story = generate_story_from_dataframe(df, story_type)
301
- else:
302
- return jsonify({'error': 'Unsupported input_type'}), 400
303
 
 
304
  story_gen_end = time.time()
305
  story_generation_time = story_gen_end - story_gen_start
306
 
307
  if not full_story:
308
  return jsonify({'error': 'Story generation failed'}), 500
309
 
310
- # --- Split the Story into 5 Sections ---
311
  sections_raw = [s.strip() for s in full_story.split("[break]") if s.strip()]
312
  if len(sections_raw) < 5:
313
  sections_raw += ["(Placeholder section)"] * (5 - len(sections_raw))
@@ -318,48 +337,48 @@ def generate_story_endpoint():
318
  image_generation_times = []
319
  audio_generation_times = []
320
 
321
- # Import generation functions from your modules
322
- from image_gen import generate_image_with_retry # image generation function
323
- from audio_gen import generate_audio # audio generation function
324
 
325
- # Process each section
326
- for section in sections_raw:
327
- # --- Image Generation ---
328
- # Extract an image prompt between angle brackets; otherwise, fallback to the first 100 characters.
329
- img_prompt_match = re.search(r"<(.*?)>", section)
330
- img_prompt = img_prompt_match.group(1).strip() if img_prompt_match else section[:100]
331
 
 
332
  image_start = time.time()
333
  image_obj, _ = generate_image_with_retry(img_prompt, style, model=image_model)
334
  image_end = time.time()
335
  image_generation_times.append(image_end - image_start)
336
 
337
- # Save image locally and upload it.
338
  image_filename = f"/tmp/{uuid.uuid4().hex}.jpg"
339
  image_obj.save(image_filename, format="JPEG")
340
  image_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.jpg"
341
  image_url = upload_to_storage(image_filename, image_blob_name)
 
342
 
343
- # --- Audio Generation ---
 
 
344
  audio_start = time.time()
345
- audio_file_path = generate_audio(section, voice_model, audio_model=audio_model)
346
  audio_end = time.time()
347
  audio_generation_times.append(audio_end - audio_start)
348
 
 
349
  audio_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.mp3"
350
  audio_url = upload_to_storage(audio_file_path, audio_blob_name)
 
351
 
352
  sections.append({
353
- "section_text": section,
354
  "image_url": image_url,
355
  "audio_url": audio_url
356
  })
357
 
358
- # Clean up temporary files
359
- os.remove(image_filename)
360
- os.remove(audio_file_path)
361
-
362
- # --- Store the Story Record in Firebase Realtime Database ---
363
  story_id = str(uuid.uuid4())
364
  story_ref = db.reference(f"stories/{story_id}")
365
  story_record = {
@@ -376,7 +395,8 @@ def generate_story_endpoint():
376
  "story_type": story_type
377
  }
378
  story_ref.set(story_record)
379
- # --- Subtract 5 Credits from the User ---
 
380
  user_ref = db.reference(f"users/{uid}")
381
  user_data = user_ref.get() or {}
382
  current_credits = user_data.get("credits", 0)
 
247
  return jsonify({'error': 'Invalid or expired token'}), 401
248
 
249
  # --- Read Request Data ---
250
+ # If the user is uploading a file (PDF or CSV/Excel), we can read from request.files
251
+ # If the user is sending JSON only, we read request.get_json()
252
+ data = request.form.to_dict() # For multipart/form-data fields
253
+ input_type = data.get('input_type', 'text') # "text", "pdf", "wiki", "bible", "youtube", "dataframe"
254
+ prompt = data.get('prompt') # For "text" or fallback
255
  story_type = data.get('story_type', 'free_form')
256
  style = data.get('style', 'whimsical')
257
  voice_model = data.get('voice_model', 'aura-asteria-en')
258
  image_model = data.get('image_model', 'hf')
259
  audio_model = data.get('audio_model', 'deepgram')
260
 
261
+ # Validate if needed
262
+ if input_type not in ["text", "pdf", "wiki", "bible", "youtube", "dataframe"]:
263
+ return jsonify({'error': 'Unsupported input_type'}), 400
264
+
265
+ # 1) Generate the full story text
266
+ from stories import generate_story_from_text
267
+ from stories import get_pdf_text
268
+ from stories import get_df
269
 
 
 
270
  story_gen_start = time.time()
271
  full_story = None
272
 
273
  if input_type == "text":
274
+ if not prompt:
275
+ return jsonify({'error': 'Prompt is required for text input'}), 400
 
 
 
 
276
  full_story = generate_story_from_text(prompt, story_type)
277
+
278
+ elif input_type == "pdf":
279
+ # Expecting a file in request.files["file"]
280
+ uploaded_file = request.files.get("file")
281
+ if not uploaded_file:
282
+ return jsonify({'error': 'No PDF file uploaded'}), 400
283
+
284
+ # Convert PDF to text
285
+ pdf_text = get_pdf_text(uploaded_file)
286
+ full_story = generate_story_from_text(pdf_text, story_type)
287
+
288
+ elif input_type == "dataframe":
289
+ # Expecting a file in request.files["file"] and an "ext" field (csv, xlsx, xls)
290
+ uploaded_file = request.files.get("file")
291
+ ext = data.get("ext") # e.g. "csv", "xlsx", "xls"
292
+ if not uploaded_file or not ext:
293
+ return jsonify({'error': 'File and ext are required for dataframe input'}), 400
294
+
295
+ df = get_df(uploaded_file, ext)
296
+ if df is None:
297
+ return jsonify({'error': f'Failed to read {ext} file'}), 400
298
+ from stories import generate_story_from_dataframe
299
+ full_story = generate_story_from_dataframe(df, story_type)
300
+
301
  elif input_type == "wiki":
302
  wiki_url = data.get("wiki_url")
303
  if not wiki_url:
304
  return jsonify({'error': 'wiki_url is required for input_type "wiki"'}), 400
305
  from stories import generate_story_from_wiki
306
  full_story = generate_story_from_wiki(wiki_url, story_type)
307
+
308
  elif input_type == "bible":
309
  bible_reference = data.get("bible_reference")
310
  if not bible_reference:
311
  return jsonify({'error': 'bible_reference is required for input_type "bible"'}), 400
312
  from stories import generate_story_from_bible
313
  full_story = generate_story_from_bible(bible_reference, story_type)
314
+
315
  elif input_type == "youtube":
316
  youtube_url = data.get("youtube_url")
317
  if not youtube_url:
318
  return jsonify({'error': 'youtube_url is required for input_type "youtube"'}), 400
319
  from stories import generate_story_from_youtube
320
  full_story = generate_story_from_youtube(youtube_url, story_type)
 
 
 
 
 
 
 
 
 
 
321
 
322
+ # Measure generation time
323
  story_gen_end = time.time()
324
  story_generation_time = story_gen_end - story_gen_start
325
 
326
  if not full_story:
327
  return jsonify({'error': 'Story generation failed'}), 500
328
 
329
+ # 2) Split the story into 5 sections
330
  sections_raw = [s.strip() for s in full_story.split("[break]") if s.strip()]
331
  if len(sections_raw) < 5:
332
  sections_raw += ["(Placeholder section)"] * (5 - len(sections_raw))
 
337
  image_generation_times = []
338
  audio_generation_times = []
339
 
340
+ from image_gen import generate_image_with_retry
341
+ from audio_gen import generate_audio
 
342
 
343
+ # 3) Process each section
344
+ for section_text in sections_raw:
345
+ # Extract an image prompt between angle brackets
346
+ img_prompt_match = re.search(r"<(.*?)>", section_text)
347
+ img_prompt = img_prompt_match.group(1).strip() if img_prompt_match else section_text[:100]
 
348
 
349
+ # Generate image
350
  image_start = time.time()
351
  image_obj, _ = generate_image_with_retry(img_prompt, style, model=image_model)
352
  image_end = time.time()
353
  image_generation_times.append(image_end - image_start)
354
 
355
+ # Save image locally -> upload -> get URL
356
  image_filename = f"/tmp/{uuid.uuid4().hex}.jpg"
357
  image_obj.save(image_filename, format="JPEG")
358
  image_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.jpg"
359
  image_url = upload_to_storage(image_filename, image_blob_name)
360
+ os.remove(image_filename)
361
 
362
+ # Generate audio from section text WITHOUT <image> description
363
+ # e.g. remove <...> from text
364
+ audio_text = re.sub(r"<.*?>", "", section_text) # remove anything in angle brackets
365
  audio_start = time.time()
366
+ audio_file_path = generate_audio(audio_text, voice_model, audio_model=audio_model)
367
  audio_end = time.time()
368
  audio_generation_times.append(audio_end - audio_start)
369
 
370
+ # Upload audio
371
  audio_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.mp3"
372
  audio_url = upload_to_storage(audio_file_path, audio_blob_name)
373
+ os.remove(audio_file_path)
374
 
375
  sections.append({
376
+ "section_text": section_text,
377
  "image_url": image_url,
378
  "audio_url": audio_url
379
  })
380
 
381
+ # 4) Store the story record in Firebase Realtime Database
 
 
 
 
382
  story_id = str(uuid.uuid4())
383
  story_ref = db.reference(f"stories/{story_id}")
384
  story_record = {
 
395
  "story_type": story_type
396
  }
397
  story_ref.set(story_record)
398
+
399
+ # Subtract 5 Credits
400
  user_ref = db.reference(f"users/{uid}")
401
  user_data = user_ref.get() or {}
402
  current_credits = user_data.get("credits", 0)