rairo commited on
Commit
5bfdda5
·
verified ·
1 Parent(s): 1e67e77

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +43 -18
main.py CHANGED
@@ -20,6 +20,8 @@ import zipfile
20
  from fpdf import FPDF
21
  import tempfile
22
  import urllib.parse
 
 
23
 
24
  # Initialize Flask app and CORS
25
  app = Flask(__name__)
@@ -247,7 +249,7 @@ def generate_story_endpoint():
247
  return jsonify({'error': 'Invalid or expired token'}), 401
248
 
249
  # --- Read Request Data ---
250
- data = request.form.to_dict() # For multipart/form-data fields
251
  input_type = data.get('input_type', 'text') # "text", "pdf", "wiki", "bible", "youtube", "dataframe"
252
  prompt = data.get('prompt') # For "text" only
253
  story_type = data.get('story_type', 'free_form')
@@ -256,20 +258,21 @@ def generate_story_endpoint():
256
  image_model = data.get('image_model', 'hf')
257
  audio_model = data.get('audio_model', 'deepgram')
258
 
259
- # Validate if needed
260
  if input_type not in ["text", "pdf", "wiki", "bible", "youtube", "dataframe"]:
261
  return jsonify({'error': 'Unsupported input_type'}), 400
262
 
263
- from stories import generate_story_from_text
264
- from stories import get_pdf_text
265
- from stories import get_df
266
-
 
 
 
267
  story_gen_start = time.time()
268
  full_story = None
269
 
 
270
  if input_type == "text":
271
- # <-- CHANGE HERE
272
- # We only check 'prompt' if input_type == "text"
273
  if not prompt:
274
  return jsonify({'error': 'Prompt is required for text input'}), 400
275
  full_story = generate_story_from_text(prompt, story_type)
@@ -283,14 +286,14 @@ def generate_story_endpoint():
283
 
284
  elif input_type == "dataframe":
285
  uploaded_file = request.files.get("file")
286
- ext = data.get("ext") # e.g. "csv", "xlsx", "xls"
287
  if not uploaded_file or not ext:
288
  return jsonify({'error': 'File and ext are required for dataframe input'}), 400
289
 
290
  df = get_df(uploaded_file, ext)
291
  if df is None:
292
  return jsonify({'error': f'Failed to read {ext} file'}), 400
293
- from stories import generate_story_from_dataframe
294
  full_story = generate_story_from_dataframe(df, story_type)
295
 
296
  elif input_type == "wiki":
@@ -321,7 +324,7 @@ def generate_story_endpoint():
321
  if not full_story:
322
  return jsonify({'error': 'Story generation failed'}), 500
323
 
324
- # 2) Split the story into 5 sections
325
  sections_raw = [s.strip() for s in full_story.split("[break]") if s.strip()]
326
  if len(sections_raw) < 5:
327
  sections_raw += ["(Placeholder section)"] * (5 - len(sections_raw))
@@ -335,33 +338,55 @@ def generate_story_endpoint():
335
  from image_gen import generate_image_with_retry
336
  from audio_gen import generate_audio
337
 
 
 
 
 
 
 
 
338
  # 3) Process each section
339
  for section_text in sections_raw:
340
  # Extract an image prompt between angle brackets
341
  img_prompt_match = re.search(r"<(.*?)>", section_text)
342
  img_prompt = img_prompt_match.group(1).strip() if img_prompt_match else section_text[:100]
343
 
344
- # Generate image
345
  image_start = time.time()
346
- image_obj, _ = generate_image_with_retry(img_prompt, style, model=image_model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  image_end = time.time()
348
  image_generation_times.append(image_end - image_start)
349
 
350
- # Save image locally -> upload -> get URL
351
  image_filename = f"/tmp/{uuid.uuid4().hex}.jpg"
352
  image_obj.save(image_filename, format="JPEG")
353
  image_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.jpg"
354
  image_url = upload_to_storage(image_filename, image_blob_name)
355
  os.remove(image_filename)
356
 
357
- # Generate audio from section text WITHOUT <image> description
358
- audio_text = re.sub(r"<.*?>", "", section_text) # remove anything in angle brackets
359
  audio_start = time.time()
360
  audio_file_path = generate_audio(audio_text, voice_model, audio_model=audio_model)
361
  audio_end = time.time()
362
  audio_generation_times.append(audio_end - audio_start)
363
 
364
- # Upload audio
365
  audio_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.mp3"
366
  audio_url = upload_to_storage(audio_file_path, audio_blob_name)
367
  os.remove(audio_file_path)
@@ -372,7 +397,7 @@ def generate_story_endpoint():
372
  "audio_url": audio_url
373
  })
374
 
375
- # 4) Store the story record in Firebase Realtime Database
376
  story_id = str(uuid.uuid4())
377
  story_ref = db.reference(f"stories/{story_id}")
378
  story_record = {
 
20
  from fpdf import FPDF
21
  import tempfile
22
  import urllib.parse
23
+ from stories import generateResponse
24
+
25
 
26
  # Initialize Flask app and CORS
27
  app = Flask(__name__)
 
249
  return jsonify({'error': 'Invalid or expired token'}), 401
250
 
251
  # --- Read Request Data ---
252
+ data = request.form.to_dict() # For multipart/form-data
253
  input_type = data.get('input_type', 'text') # "text", "pdf", "wiki", "bible", "youtube", "dataframe"
254
  prompt = data.get('prompt') # For "text" only
255
  story_type = data.get('story_type', 'free_form')
 
258
  image_model = data.get('image_model', 'hf')
259
  audio_model = data.get('audio_model', 'deepgram')
260
 
 
261
  if input_type not in ["text", "pdf", "wiki", "bible", "youtube", "dataframe"]:
262
  return jsonify({'error': 'Unsupported input_type'}), 400
263
 
264
+ from stories import (
265
+ generate_story_from_text,
266
+ get_pdf_text,
267
+ get_df,
268
+ generate_story_from_dataframe,
269
+ generateResponse # <-- for chart images
270
+ )
271
  story_gen_start = time.time()
272
  full_story = None
273
 
274
+ # 1) Generate the full story text
275
  if input_type == "text":
 
 
276
  if not prompt:
277
  return jsonify({'error': 'Prompt is required for text input'}), 400
278
  full_story = generate_story_from_text(prompt, story_type)
 
286
 
287
  elif input_type == "dataframe":
288
  uploaded_file = request.files.get("file")
289
+ ext = data.get("ext") # "csv", "xlsx", "xls"
290
  if not uploaded_file or not ext:
291
  return jsonify({'error': 'File and ext are required for dataframe input'}), 400
292
 
293
  df = get_df(uploaded_file, ext)
294
  if df is None:
295
  return jsonify({'error': f'Failed to read {ext} file'}), 400
296
+
297
  full_story = generate_story_from_dataframe(df, story_type)
298
 
299
  elif input_type == "wiki":
 
324
  if not full_story:
325
  return jsonify({'error': 'Story generation failed'}), 500
326
 
327
+ # 2) Split into 5 sections
328
  sections_raw = [s.strip() for s in full_story.split("[break]") if s.strip()]
329
  if len(sections_raw) < 5:
330
  sections_raw += ["(Placeholder section)"] * (5 - len(sections_raw))
 
338
  from image_gen import generate_image_with_retry
339
  from audio_gen import generate_audio
340
 
341
+ # If input_type is "dataframe", we have a df for chart generation
342
+ df = None
343
+ if input_type == "dataframe":
344
+ uploaded_file = request.files.get("file")
345
+ ext = data.get("ext")
346
+ df = get_df(uploaded_file, ext) # re-use the same df
347
+
348
  # 3) Process each section
349
  for section_text in sections_raw:
350
  # Extract an image prompt between angle brackets
351
  img_prompt_match = re.search(r"<(.*?)>", section_text)
352
  img_prompt = img_prompt_match.group(1).strip() if img_prompt_match else section_text[:100]
353
 
 
354
  image_start = time.time()
355
+ image_obj = None
356
+
357
+ # If we are dealing with "dataframe", attempt chart generation first
358
+ if input_type == "dataframe" and df is not None:
359
+ try:
360
+ chart_str = generateResponse(img_prompt, df) # returns a Python string or None
361
+ if chart_str and chart_str.startswith("data:image/png;base64,"):
362
+ # decode base64 -> PIL Image
363
+ base64_data = chart_str.split(",", 1)[1]
364
+ chart_bytes = base64.b64decode(base64_data)
365
+ image_obj = Image.open(io.BytesIO(chart_bytes))
366
+ except Exception as e:
367
+ print("DataFrame chart generation error:", e)
368
+
369
+ # Fallback to generate_image_with_retry
370
+ if not image_obj:
371
+ image_obj, _ = generate_image_with_retry(img_prompt, style, model=image_model)
372
+
373
  image_end = time.time()
374
  image_generation_times.append(image_end - image_start)
375
 
376
+ # Save & upload
377
  image_filename = f"/tmp/{uuid.uuid4().hex}.jpg"
378
  image_obj.save(image_filename, format="JPEG")
379
  image_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.jpg"
380
  image_url = upload_to_storage(image_filename, image_blob_name)
381
  os.remove(image_filename)
382
 
383
+ # Generate audio without <image> description
384
+ audio_text = re.sub(r"<.*?>", "", section_text)
385
  audio_start = time.time()
386
  audio_file_path = generate_audio(audio_text, voice_model, audio_model=audio_model)
387
  audio_end = time.time()
388
  audio_generation_times.append(audio_end - audio_start)
389
 
 
390
  audio_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.mp3"
391
  audio_url = upload_to_storage(audio_file_path, audio_blob_name)
392
  os.remove(audio_file_path)
 
397
  "audio_url": audio_url
398
  })
399
 
400
+ # 4) Store the story
401
  story_id = str(uuid.uuid4())
402
  story_ref = db.reference(f"stories/{story_id}")
403
  story_record = {