Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -247,67 +247,86 @@ def generate_story_endpoint():
|
|
| 247 |
return jsonify({'error': 'Invalid or expired token'}), 401
|
| 248 |
|
| 249 |
# --- Read Request Data ---
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
| 253 |
story_type = data.get('story_type', 'free_form')
|
| 254 |
style = data.get('style', 'whimsical')
|
| 255 |
voice_model = data.get('voice_model', 'aura-asteria-en')
|
| 256 |
image_model = data.get('image_model', 'hf')
|
| 257 |
audio_model = data.get('audio_model', 'deepgram')
|
| 258 |
|
| 259 |
-
if
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
-
|
| 263 |
-
# --- Select the Appropriate Story Generation Function ---
|
| 264 |
story_gen_start = time.time()
|
| 265 |
full_story = None
|
| 266 |
|
| 267 |
if input_type == "text":
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
if input_type == "pdf":
|
| 271 |
-
from stories import generate_story_from_text
|
| 272 |
-
from stories import get_pdf_text
|
| 273 |
-
prompt = get_pdf_text(pdf)
|
| 274 |
full_story = generate_story_from_text(prompt, story_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
elif input_type == "wiki":
|
| 276 |
wiki_url = data.get("wiki_url")
|
| 277 |
if not wiki_url:
|
| 278 |
return jsonify({'error': 'wiki_url is required for input_type "wiki"'}), 400
|
| 279 |
from stories import generate_story_from_wiki
|
| 280 |
full_story = generate_story_from_wiki(wiki_url, story_type)
|
|
|
|
| 281 |
elif input_type == "bible":
|
| 282 |
bible_reference = data.get("bible_reference")
|
| 283 |
if not bible_reference:
|
| 284 |
return jsonify({'error': 'bible_reference is required for input_type "bible"'}), 400
|
| 285 |
from stories import generate_story_from_bible
|
| 286 |
full_story = generate_story_from_bible(bible_reference, story_type)
|
|
|
|
| 287 |
elif input_type == "youtube":
|
| 288 |
youtube_url = data.get("youtube_url")
|
| 289 |
if not youtube_url:
|
| 290 |
return jsonify({'error': 'youtube_url is required for input_type "youtube"'}), 400
|
| 291 |
from stories import generate_story_from_youtube
|
| 292 |
full_story = generate_story_from_youtube(youtube_url, story_type)
|
| 293 |
-
elif input_type == "dataframe":
|
| 294 |
-
# Expecting dataframe data as JSON (list of dicts)
|
| 295 |
-
df_data = data.get("data")
|
| 296 |
-
if not df_data:
|
| 297 |
-
return jsonify({'error': 'Data for dataframe input_type is required'}), 400
|
| 298 |
-
df = pd.DataFrame(df_data)
|
| 299 |
-
from stories import generate_story_from_dataframe
|
| 300 |
-
full_story = generate_story_from_dataframe(df, story_type)
|
| 301 |
-
else:
|
| 302 |
-
return jsonify({'error': 'Unsupported input_type'}), 400
|
| 303 |
|
|
|
|
| 304 |
story_gen_end = time.time()
|
| 305 |
story_generation_time = story_gen_end - story_gen_start
|
| 306 |
|
| 307 |
if not full_story:
|
| 308 |
return jsonify({'error': 'Story generation failed'}), 500
|
| 309 |
|
| 310 |
-
#
|
| 311 |
sections_raw = [s.strip() for s in full_story.split("[break]") if s.strip()]
|
| 312 |
if len(sections_raw) < 5:
|
| 313 |
sections_raw += ["(Placeholder section)"] * (5 - len(sections_raw))
|
|
@@ -318,48 +337,48 @@ def generate_story_endpoint():
|
|
| 318 |
image_generation_times = []
|
| 319 |
audio_generation_times = []
|
| 320 |
|
| 321 |
-
|
| 322 |
-
from
|
| 323 |
-
from audio_gen import generate_audio # audio generation function
|
| 324 |
|
| 325 |
-
# Process each section
|
| 326 |
-
for
|
| 327 |
-
#
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
img_prompt = img_prompt_match.group(1).strip() if img_prompt_match else section[:100]
|
| 331 |
|
|
|
|
| 332 |
image_start = time.time()
|
| 333 |
image_obj, _ = generate_image_with_retry(img_prompt, style, model=image_model)
|
| 334 |
image_end = time.time()
|
| 335 |
image_generation_times.append(image_end - image_start)
|
| 336 |
|
| 337 |
-
# Save image locally
|
| 338 |
image_filename = f"/tmp/{uuid.uuid4().hex}.jpg"
|
| 339 |
image_obj.save(image_filename, format="JPEG")
|
| 340 |
image_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.jpg"
|
| 341 |
image_url = upload_to_storage(image_filename, image_blob_name)
|
|
|
|
| 342 |
|
| 343 |
-
#
|
|
|
|
|
|
|
| 344 |
audio_start = time.time()
|
| 345 |
-
audio_file_path = generate_audio(
|
| 346 |
audio_end = time.time()
|
| 347 |
audio_generation_times.append(audio_end - audio_start)
|
| 348 |
|
|
|
|
| 349 |
audio_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.mp3"
|
| 350 |
audio_url = upload_to_storage(audio_file_path, audio_blob_name)
|
|
|
|
| 351 |
|
| 352 |
sections.append({
|
| 353 |
-
"section_text":
|
| 354 |
"image_url": image_url,
|
| 355 |
"audio_url": audio_url
|
| 356 |
})
|
| 357 |
|
| 358 |
-
|
| 359 |
-
os.remove(image_filename)
|
| 360 |
-
os.remove(audio_file_path)
|
| 361 |
-
|
| 362 |
-
# --- Store the Story Record in Firebase Realtime Database ---
|
| 363 |
story_id = str(uuid.uuid4())
|
| 364 |
story_ref = db.reference(f"stories/{story_id}")
|
| 365 |
story_record = {
|
|
@@ -376,7 +395,8 @@ def generate_story_endpoint():
|
|
| 376 |
"story_type": story_type
|
| 377 |
}
|
| 378 |
story_ref.set(story_record)
|
| 379 |
-
|
|
|
|
| 380 |
user_ref = db.reference(f"users/{uid}")
|
| 381 |
user_data = user_ref.get() or {}
|
| 382 |
current_credits = user_data.get("credits", 0)
|
|
|
|
| 247 |
return jsonify({'error': 'Invalid or expired token'}), 401
|
| 248 |
|
| 249 |
# --- Read Request Data ---
|
| 250 |
+
# If the user is uploading a file (PDF or CSV/Excel), we can read from request.files
|
| 251 |
+
# If the user is sending JSON only, we read request.get_json()
|
| 252 |
+
data = request.form.to_dict() # For multipart/form-data fields
|
| 253 |
+
input_type = data.get('input_type', 'text') # "text", "pdf", "wiki", "bible", "youtube", "dataframe"
|
| 254 |
+
prompt = data.get('prompt') # For "text" or fallback
|
| 255 |
story_type = data.get('story_type', 'free_form')
|
| 256 |
style = data.get('style', 'whimsical')
|
| 257 |
voice_model = data.get('voice_model', 'aura-asteria-en')
|
| 258 |
image_model = data.get('image_model', 'hf')
|
| 259 |
audio_model = data.get('audio_model', 'deepgram')
|
| 260 |
|
| 261 |
+
# Validate if needed
|
| 262 |
+
if input_type not in ["text", "pdf", "wiki", "bible", "youtube", "dataframe"]:
|
| 263 |
+
return jsonify({'error': 'Unsupported input_type'}), 400
|
| 264 |
+
|
| 265 |
+
# 1) Generate the full story text
|
| 266 |
+
from stories import generate_story_from_text
|
| 267 |
+
from stories import get_pdf_text
|
| 268 |
+
from stories import get_df
|
| 269 |
|
|
|
|
|
|
|
| 270 |
story_gen_start = time.time()
|
| 271 |
full_story = None
|
| 272 |
|
| 273 |
if input_type == "text":
|
| 274 |
+
if not prompt:
|
| 275 |
+
return jsonify({'error': 'Prompt is required for text input'}), 400
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
full_story = generate_story_from_text(prompt, story_type)
|
| 277 |
+
|
| 278 |
+
elif input_type == "pdf":
|
| 279 |
+
# Expecting a file in request.files["file"]
|
| 280 |
+
uploaded_file = request.files.get("file")
|
| 281 |
+
if not uploaded_file:
|
| 282 |
+
return jsonify({'error': 'No PDF file uploaded'}), 400
|
| 283 |
+
|
| 284 |
+
# Convert PDF to text
|
| 285 |
+
pdf_text = get_pdf_text(uploaded_file)
|
| 286 |
+
full_story = generate_story_from_text(pdf_text, story_type)
|
| 287 |
+
|
| 288 |
+
elif input_type == "dataframe":
|
| 289 |
+
# Expecting a file in request.files["file"] and an "ext" field (csv, xlsx, xls)
|
| 290 |
+
uploaded_file = request.files.get("file")
|
| 291 |
+
ext = data.get("ext") # e.g. "csv", "xlsx", "xls"
|
| 292 |
+
if not uploaded_file or not ext:
|
| 293 |
+
return jsonify({'error': 'File and ext are required for dataframe input'}), 400
|
| 294 |
+
|
| 295 |
+
df = get_df(uploaded_file, ext)
|
| 296 |
+
if df is None:
|
| 297 |
+
return jsonify({'error': f'Failed to read {ext} file'}), 400
|
| 298 |
+
from stories import generate_story_from_dataframe
|
| 299 |
+
full_story = generate_story_from_dataframe(df, story_type)
|
| 300 |
+
|
| 301 |
elif input_type == "wiki":
|
| 302 |
wiki_url = data.get("wiki_url")
|
| 303 |
if not wiki_url:
|
| 304 |
return jsonify({'error': 'wiki_url is required for input_type "wiki"'}), 400
|
| 305 |
from stories import generate_story_from_wiki
|
| 306 |
full_story = generate_story_from_wiki(wiki_url, story_type)
|
| 307 |
+
|
| 308 |
elif input_type == "bible":
|
| 309 |
bible_reference = data.get("bible_reference")
|
| 310 |
if not bible_reference:
|
| 311 |
return jsonify({'error': 'bible_reference is required for input_type "bible"'}), 400
|
| 312 |
from stories import generate_story_from_bible
|
| 313 |
full_story = generate_story_from_bible(bible_reference, story_type)
|
| 314 |
+
|
| 315 |
elif input_type == "youtube":
|
| 316 |
youtube_url = data.get("youtube_url")
|
| 317 |
if not youtube_url:
|
| 318 |
return jsonify({'error': 'youtube_url is required for input_type "youtube"'}), 400
|
| 319 |
from stories import generate_story_from_youtube
|
| 320 |
full_story = generate_story_from_youtube(youtube_url, story_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
+
# Measure generation time
|
| 323 |
story_gen_end = time.time()
|
| 324 |
story_generation_time = story_gen_end - story_gen_start
|
| 325 |
|
| 326 |
if not full_story:
|
| 327 |
return jsonify({'error': 'Story generation failed'}), 500
|
| 328 |
|
| 329 |
+
# 2) Split the story into 5 sections
|
| 330 |
sections_raw = [s.strip() for s in full_story.split("[break]") if s.strip()]
|
| 331 |
if len(sections_raw) < 5:
|
| 332 |
sections_raw += ["(Placeholder section)"] * (5 - len(sections_raw))
|
|
|
|
| 337 |
image_generation_times = []
|
| 338 |
audio_generation_times = []
|
| 339 |
|
| 340 |
+
from image_gen import generate_image_with_retry
|
| 341 |
+
from audio_gen import generate_audio
|
|
|
|
| 342 |
|
| 343 |
+
# 3) Process each section
|
| 344 |
+
for section_text in sections_raw:
|
| 345 |
+
# Extract an image prompt between angle brackets
|
| 346 |
+
img_prompt_match = re.search(r"<(.*?)>", section_text)
|
| 347 |
+
img_prompt = img_prompt_match.group(1).strip() if img_prompt_match else section_text[:100]
|
|
|
|
| 348 |
|
| 349 |
+
# Generate image
|
| 350 |
image_start = time.time()
|
| 351 |
image_obj, _ = generate_image_with_retry(img_prompt, style, model=image_model)
|
| 352 |
image_end = time.time()
|
| 353 |
image_generation_times.append(image_end - image_start)
|
| 354 |
|
| 355 |
+
# Save image locally -> upload -> get URL
|
| 356 |
image_filename = f"/tmp/{uuid.uuid4().hex}.jpg"
|
| 357 |
image_obj.save(image_filename, format="JPEG")
|
| 358 |
image_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.jpg"
|
| 359 |
image_url = upload_to_storage(image_filename, image_blob_name)
|
| 360 |
+
os.remove(image_filename)
|
| 361 |
|
| 362 |
+
# Generate audio from section text WITHOUT <image> description
|
| 363 |
+
# e.g. remove <...> from text
|
| 364 |
+
audio_text = re.sub(r"<.*?>", "", section_text) # remove anything in angle brackets
|
| 365 |
audio_start = time.time()
|
| 366 |
+
audio_file_path = generate_audio(audio_text, voice_model, audio_model=audio_model)
|
| 367 |
audio_end = time.time()
|
| 368 |
audio_generation_times.append(audio_end - audio_start)
|
| 369 |
|
| 370 |
+
# Upload audio
|
| 371 |
audio_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.mp3"
|
| 372 |
audio_url = upload_to_storage(audio_file_path, audio_blob_name)
|
| 373 |
+
os.remove(audio_file_path)
|
| 374 |
|
| 375 |
sections.append({
|
| 376 |
+
"section_text": section_text,
|
| 377 |
"image_url": image_url,
|
| 378 |
"audio_url": audio_url
|
| 379 |
})
|
| 380 |
|
| 381 |
+
# 4) Store the story record in Firebase Realtime Database
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
story_id = str(uuid.uuid4())
|
| 383 |
story_ref = db.reference(f"stories/{story_id}")
|
| 384 |
story_record = {
|
|
|
|
| 395 |
"story_type": story_type
|
| 396 |
}
|
| 397 |
story_ref.set(story_record)
|
| 398 |
+
|
| 399 |
+
# Subtract 5 Credits
|
| 400 |
user_ref = db.reference(f"users/{uid}")
|
| 401 |
user_data = user_ref.get() or {}
|
| 402 |
current_credits = user_data.get("credits", 0)
|