Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -20,6 +20,8 @@ import zipfile
|
|
| 20 |
from fpdf import FPDF
|
| 21 |
import tempfile
|
| 22 |
import urllib.parse
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# Initialize Flask app and CORS
|
| 25 |
app = Flask(__name__)
|
|
@@ -247,7 +249,7 @@ def generate_story_endpoint():
|
|
| 247 |
return jsonify({'error': 'Invalid or expired token'}), 401
|
| 248 |
|
| 249 |
# --- Read Request Data ---
|
| 250 |
-
data = request.form.to_dict() # For multipart/form-data
|
| 251 |
input_type = data.get('input_type', 'text') # "text", "pdf", "wiki", "bible", "youtube", "dataframe"
|
| 252 |
prompt = data.get('prompt') # For "text" only
|
| 253 |
story_type = data.get('story_type', 'free_form')
|
|
@@ -256,20 +258,21 @@ def generate_story_endpoint():
|
|
| 256 |
image_model = data.get('image_model', 'hf')
|
| 257 |
audio_model = data.get('audio_model', 'deepgram')
|
| 258 |
|
| 259 |
-
# Validate if needed
|
| 260 |
if input_type not in ["text", "pdf", "wiki", "bible", "youtube", "dataframe"]:
|
| 261 |
return jsonify({'error': 'Unsupported input_type'}), 400
|
| 262 |
|
| 263 |
-
from stories import
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
| 267 |
story_gen_start = time.time()
|
| 268 |
full_story = None
|
| 269 |
|
|
|
|
| 270 |
if input_type == "text":
|
| 271 |
-
# <-- CHANGE HERE
|
| 272 |
-
# We only check 'prompt' if input_type == "text"
|
| 273 |
if not prompt:
|
| 274 |
return jsonify({'error': 'Prompt is required for text input'}), 400
|
| 275 |
full_story = generate_story_from_text(prompt, story_type)
|
|
@@ -283,14 +286,14 @@ def generate_story_endpoint():
|
|
| 283 |
|
| 284 |
elif input_type == "dataframe":
|
| 285 |
uploaded_file = request.files.get("file")
|
| 286 |
-
ext = data.get("ext") #
|
| 287 |
if not uploaded_file or not ext:
|
| 288 |
return jsonify({'error': 'File and ext are required for dataframe input'}), 400
|
| 289 |
|
| 290 |
df = get_df(uploaded_file, ext)
|
| 291 |
if df is None:
|
| 292 |
return jsonify({'error': f'Failed to read {ext} file'}), 400
|
| 293 |
-
|
| 294 |
full_story = generate_story_from_dataframe(df, story_type)
|
| 295 |
|
| 296 |
elif input_type == "wiki":
|
|
@@ -321,7 +324,7 @@ def generate_story_endpoint():
|
|
| 321 |
if not full_story:
|
| 322 |
return jsonify({'error': 'Story generation failed'}), 500
|
| 323 |
|
| 324 |
-
# 2) Split
|
| 325 |
sections_raw = [s.strip() for s in full_story.split("[break]") if s.strip()]
|
| 326 |
if len(sections_raw) < 5:
|
| 327 |
sections_raw += ["(Placeholder section)"] * (5 - len(sections_raw))
|
|
@@ -335,33 +338,55 @@ def generate_story_endpoint():
|
|
| 335 |
from image_gen import generate_image_with_retry
|
| 336 |
from audio_gen import generate_audio
|
| 337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
# 3) Process each section
|
| 339 |
for section_text in sections_raw:
|
| 340 |
# Extract an image prompt between angle brackets
|
| 341 |
img_prompt_match = re.search(r"<(.*?)>", section_text)
|
| 342 |
img_prompt = img_prompt_match.group(1).strip() if img_prompt_match else section_text[:100]
|
| 343 |
|
| 344 |
-
# Generate image
|
| 345 |
image_start = time.time()
|
| 346 |
-
image_obj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
image_end = time.time()
|
| 348 |
image_generation_times.append(image_end - image_start)
|
| 349 |
|
| 350 |
-
# Save
|
| 351 |
image_filename = f"/tmp/{uuid.uuid4().hex}.jpg"
|
| 352 |
image_obj.save(image_filename, format="JPEG")
|
| 353 |
image_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.jpg"
|
| 354 |
image_url = upload_to_storage(image_filename, image_blob_name)
|
| 355 |
os.remove(image_filename)
|
| 356 |
|
| 357 |
-
# Generate audio
|
| 358 |
-
audio_text = re.sub(r"<.*?>", "", section_text)
|
| 359 |
audio_start = time.time()
|
| 360 |
audio_file_path = generate_audio(audio_text, voice_model, audio_model=audio_model)
|
| 361 |
audio_end = time.time()
|
| 362 |
audio_generation_times.append(audio_end - audio_start)
|
| 363 |
|
| 364 |
-
# Upload audio
|
| 365 |
audio_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.mp3"
|
| 366 |
audio_url = upload_to_storage(audio_file_path, audio_blob_name)
|
| 367 |
os.remove(audio_file_path)
|
|
@@ -372,7 +397,7 @@ def generate_story_endpoint():
|
|
| 372 |
"audio_url": audio_url
|
| 373 |
})
|
| 374 |
|
| 375 |
-
# 4) Store the story
|
| 376 |
story_id = str(uuid.uuid4())
|
| 377 |
story_ref = db.reference(f"stories/{story_id}")
|
| 378 |
story_record = {
|
|
|
|
| 20 |
from fpdf import FPDF
|
| 21 |
import tempfile
|
| 22 |
import urllib.parse
|
| 23 |
+
from stories import generateResponse
|
| 24 |
+
|
| 25 |
|
| 26 |
# Initialize Flask app and CORS
|
| 27 |
app = Flask(__name__)
|
|
|
|
| 249 |
return jsonify({'error': 'Invalid or expired token'}), 401
|
| 250 |
|
| 251 |
# --- Read Request Data ---
|
| 252 |
+
data = request.form.to_dict() # For multipart/form-data
|
| 253 |
input_type = data.get('input_type', 'text') # "text", "pdf", "wiki", "bible", "youtube", "dataframe"
|
| 254 |
prompt = data.get('prompt') # For "text" only
|
| 255 |
story_type = data.get('story_type', 'free_form')
|
|
|
|
| 258 |
image_model = data.get('image_model', 'hf')
|
| 259 |
audio_model = data.get('audio_model', 'deepgram')
|
| 260 |
|
|
|
|
| 261 |
if input_type not in ["text", "pdf", "wiki", "bible", "youtube", "dataframe"]:
|
| 262 |
return jsonify({'error': 'Unsupported input_type'}), 400
|
| 263 |
|
| 264 |
+
from stories import (
|
| 265 |
+
generate_story_from_text,
|
| 266 |
+
get_pdf_text,
|
| 267 |
+
get_df,
|
| 268 |
+
generate_story_from_dataframe,
|
| 269 |
+
generateResponse # <-- for chart images
|
| 270 |
+
)
|
| 271 |
story_gen_start = time.time()
|
| 272 |
full_story = None
|
| 273 |
|
| 274 |
+
# 1) Generate the full story text
|
| 275 |
if input_type == "text":
|
|
|
|
|
|
|
| 276 |
if not prompt:
|
| 277 |
return jsonify({'error': 'Prompt is required for text input'}), 400
|
| 278 |
full_story = generate_story_from_text(prompt, story_type)
|
|
|
|
| 286 |
|
| 287 |
elif input_type == "dataframe":
|
| 288 |
uploaded_file = request.files.get("file")
|
| 289 |
+
ext = data.get("ext") # "csv", "xlsx", "xls"
|
| 290 |
if not uploaded_file or not ext:
|
| 291 |
return jsonify({'error': 'File and ext are required for dataframe input'}), 400
|
| 292 |
|
| 293 |
df = get_df(uploaded_file, ext)
|
| 294 |
if df is None:
|
| 295 |
return jsonify({'error': f'Failed to read {ext} file'}), 400
|
| 296 |
+
|
| 297 |
full_story = generate_story_from_dataframe(df, story_type)
|
| 298 |
|
| 299 |
elif input_type == "wiki":
|
|
|
|
| 324 |
if not full_story:
|
| 325 |
return jsonify({'error': 'Story generation failed'}), 500
|
| 326 |
|
| 327 |
+
# 2) Split into 5 sections
|
| 328 |
sections_raw = [s.strip() for s in full_story.split("[break]") if s.strip()]
|
| 329 |
if len(sections_raw) < 5:
|
| 330 |
sections_raw += ["(Placeholder section)"] * (5 - len(sections_raw))
|
|
|
|
| 338 |
from image_gen import generate_image_with_retry
|
| 339 |
from audio_gen import generate_audio
|
| 340 |
|
| 341 |
+
# If input_type is "dataframe", we have a df for chart generation
|
| 342 |
+
df = None
|
| 343 |
+
if input_type == "dataframe":
|
| 344 |
+
uploaded_file = request.files.get("file")
|
| 345 |
+
ext = data.get("ext")
|
| 346 |
+
df = get_df(uploaded_file, ext) # re-use the same df
|
| 347 |
+
|
| 348 |
# 3) Process each section
|
| 349 |
for section_text in sections_raw:
|
| 350 |
# Extract an image prompt between angle brackets
|
| 351 |
img_prompt_match = re.search(r"<(.*?)>", section_text)
|
| 352 |
img_prompt = img_prompt_match.group(1).strip() if img_prompt_match else section_text[:100]
|
| 353 |
|
|
|
|
| 354 |
image_start = time.time()
|
| 355 |
+
image_obj = None
|
| 356 |
+
|
| 357 |
+
# If we are dealing with "dataframe", attempt chart generation first
|
| 358 |
+
if input_type == "dataframe" and df is not None:
|
| 359 |
+
try:
|
| 360 |
+
chart_str = generateResponse(img_prompt, df) # returns a Python string or None
|
| 361 |
+
if chart_str and chart_str.startswith("data:image/png;base64,"):
|
| 362 |
+
# decode base64 -> PIL Image
|
| 363 |
+
base64_data = chart_str.split(",", 1)[1]
|
| 364 |
+
chart_bytes = base64.b64decode(base64_data)
|
| 365 |
+
image_obj = Image.open(io.BytesIO(chart_bytes))
|
| 366 |
+
except Exception as e:
|
| 367 |
+
print("DataFrame chart generation error:", e)
|
| 368 |
+
|
| 369 |
+
# Fallback to generate_image_with_retry
|
| 370 |
+
if not image_obj:
|
| 371 |
+
image_obj, _ = generate_image_with_retry(img_prompt, style, model=image_model)
|
| 372 |
+
|
| 373 |
image_end = time.time()
|
| 374 |
image_generation_times.append(image_end - image_start)
|
| 375 |
|
| 376 |
+
# Save & upload
|
| 377 |
image_filename = f"/tmp/{uuid.uuid4().hex}.jpg"
|
| 378 |
image_obj.save(image_filename, format="JPEG")
|
| 379 |
image_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.jpg"
|
| 380 |
image_url = upload_to_storage(image_filename, image_blob_name)
|
| 381 |
os.remove(image_filename)
|
| 382 |
|
| 383 |
+
# Generate audio without <image> description
|
| 384 |
+
audio_text = re.sub(r"<.*?>", "", section_text)
|
| 385 |
audio_start = time.time()
|
| 386 |
audio_file_path = generate_audio(audio_text, voice_model, audio_model=audio_model)
|
| 387 |
audio_end = time.time()
|
| 388 |
audio_generation_times.append(audio_end - audio_start)
|
| 389 |
|
|
|
|
| 390 |
audio_blob_name = f"stories/{uid}/{uuid.uuid4().hex}.mp3"
|
| 391 |
audio_url = upload_to_storage(audio_file_path, audio_blob_name)
|
| 392 |
os.remove(audio_file_path)
|
|
|
|
| 397 |
"audio_url": audio_url
|
| 398 |
})
|
| 399 |
|
| 400 |
+
# 4) Store the story
|
| 401 |
story_id = str(uuid.uuid4())
|
| 402 |
story_ref = db.reference(f"stories/{story_id}")
|
| 403 |
story_record = {
|