rairo committed on
Commit
84f5a2b
·
verified ·
1 Parent(s): 78b5cf2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +694 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,697 @@
1
- import altair as alt
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
 
 
 
 
 
 
3
  import pandas as pd
4
- import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import streamlit as st
2
+ import os
3
+ import re
4
+ import time
5
+ import tempfile
6
+ import requests
7
+ import json
8
+ from google import genai
9
+ from google.genai import types
10
+ import google.generativeai as genai
11
+ import io
12
+ import base64
13
  import numpy as np
14
+ import cv2
15
+ import logging
16
+ import uuid
17
+ import subprocess
18
+ from pathlib import Path
19
+ import urllib.parse
20
  import pandas as pd
21
+ import plotly.graph_objects as go
22
+ import matplotlib.pyplot as plt
23
+ from langchain_google_genai import ChatGoogleGenerativeAI
24
+ # For PandasAI using a single dataframe
25
+ from pandasai import SmartDataframe
26
+ from pandasai.responses.response_parser import ResponseParser
27
+ from pandasai.exceptions import InvalidOutputValueMismatch
28
+ import base64
29
+ import os
30
+ import uuid
31
+ import matplotlib
32
+ import matplotlib.pyplot as plt
33
+ from io import BytesIO
34
+ import dataframe_image as dfi
35
+ import uuid
36
+ from supadata import Supadata, SupadataError
37
+ from PIL import ImageFont, ImageDraw, Image
38
+ import seaborn as sns
39
+
40
# Streamlit response parse
class StreamLitResponse(ResponseParser):
    """PandasAI response parser that adapts results for Streamlit display.

    Normalizes dataframes, matplotlib figures, image file paths, BytesIO
    buffers and base64 payloads into ``{'type': 'plot', 'value': <png path>}``
    dicts; everything else becomes ``{'type': 'text', 'value': <str>}``.
    """

    def __init__(self, context):
        super().__init__(context)
        # Ensure the export directory exists
        os.makedirs("./exports/charts", exist_ok=True)

    def format_dataframe(self, result):
        """
        Convert a DataFrame to an image using dataframe_image,
        and return a dict with type 'plot' to match the expected output.
        """
        try:
            df = result['value']
            # Apply styling if desired
            styled_df = df.style
            # Random file name avoids clobbering charts from earlier prompts.
            img_path = f"./exports/charts/{uuid.uuid4().hex}.png"
            dfi.export(styled_df, img_path)
        except Exception as e:
            print("Error in format_dataframe:", e)
            # Fallback to a string representation if needed
            img_path = str(result['value'])
        print("response_class_path (dataframe):", img_path)
        # Return as a dict with type 'plot'
        return {'type': 'plot', 'value': img_path}

    def format_plot(self, result):
        """Coerce any plot-like value into a PNG file path where possible."""
        img_value = result["value"]
        # Case 1: If it's a matplotlib figure
        if hasattr(img_value, "savefig"):
            try:
                img_path = f"./exports/charts/{uuid.uuid4().hex}.png"
                img_value.savefig(img_path, format="png")
                return {'type': 'plot', 'value': img_path}
            except Exception as e:
                print("Error saving matplotlib figure:", e)
                return {'type': 'plot', 'value': str(img_value)}

        # Case 2: If it's a file path (e.g., a .png file)
        if isinstance(img_value, str) and os.path.isfile(img_value):
            return {'type': 'plot', 'value': str(img_value)}

        # Case 3: If it's a BytesIO object
        if isinstance(img_value, io.BytesIO):
            try:
                img_path = f"./exports/charts/{uuid.uuid4().hex}.png"
                with open(img_path, "wb") as f:
                    f.write(img_value.getvalue())
                return {'type': 'plot', 'value': img_path}
            except Exception as e:
                print("Error writing BytesIO to file:", e)
                return {'type': 'plot', 'value': str(img_value)}

        # Case 4: If it's a base64 string
        if isinstance(img_value, str) and (img_value.startswith("iVBOR") or img_value.startswith("data:image")):
            try:
                # Extract raw base64 if it's a data URI
                if "base64," in img_value:
                    img_value = img_value.split("base64,")[1]
                # Decode and save to file
                img_path = f"./exports/charts/{uuid.uuid4().hex}.png"
                with open(img_path, "wb") as f:
                    f.write(base64.b64decode(img_value))
                return {'type': 'plot', 'value': img_path}
            except Exception as e:
                print("Error decoding base64 image:", e)
                return {'type': 'plot', 'value': str(img_value)}

        # Fallback: Return as a string
        return {'type': 'plot', 'value': str(img_value)}

    def format_other(self, result):
        # For non-image responses, simply return the value as a string.
        return {'type': 'text', 'value': str(result['value'])}
114
+
115
+
116
# Per-run working path for PandasAI chart exports.
guid = uuid.uuid4()
new_filename = f"{guid}"
user_defined_path = os.path.join("./exports/charts/", new_filename)

# SECURITY: the Imgur Client-ID and secret were previously hard-coded here
# and committed to the repository. Read them from the environment instead;
# the old values remain as fallbacks for backward compatibility, but they
# should be revoked and rotated on Imgur's side.
img_ID = os.getenv("IMGUR_CLIENT_ID", "344744a88ad1098")
img_secret = os.getenv("IMGUR_CLIENT_SECRET", "3c542a40c215327045d7155bddfd8b8bc84aebbf")

imgur_url = "https://api.imgur.com/3/image"
# Imgur anonymous uploads authenticate with the Client-ID only.
imgur_headers = {"Authorization": f"Client-ID {img_ID}"}
125
+
126
+
127
# -----------------------
# Configuration and Logging
# -----------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Gemini key for story generation; without it the app shows an error banner
# but continues to render (later calls will fail gracefully).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
else:
    st.error("Google API Key is missing. Please set it in environment variables or secrets.toml.")

# HuggingFace inference token used by the FLUX image backend.
token = os.getenv('HF_API')
headers = {"Authorization": f"Bearer {token}"}
141
+
142
# Pandasai gemini
# LangChain wrapper around Gemini, used as the LLM backend for PandasAI
# chart generation in generateResponse().
llm1 = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-thinking-exp", # MODEL REVERTED
    temperature=0,
    max_tokens=None,
    timeout=1000,
    max_retries=2
)

# -----------------------
# Utility Constants
# -----------------------
# NOTE(review): MAX_CHARACTERS appears unused in this file — confirm before
# relying on it (generate_story_from_dataframe never truncates its prompt).
MAX_CHARACTERS = 200000
155
+
156
def configure_gemini(api_key):
    """Configure the Gemini SDK with *api_key* and return a model handle.

    Raises on failure after logging, so callers can decide how to surface
    the error.
    """
    try:
        genai.configure(api_key=api_key)
        gemini_model = genai.GenerativeModel('gemini-2.0-flash-thinking-exp') # MODEL REVERTED
        return gemini_model
    except Exception as e:
        logger.error(f"Error configuring Gemini: {str(e)}")
        raise
163
+
164
# Initialize Gemini model for story generation
# Robustness fix: this previously ran unconditionally, so a missing key
# raised a TypeError (os.environ values must be str) instead of leaving the
# user with the st.error already shown during configuration above.
if GOOGLE_API_KEY:
    model = configure_gemini(GOOGLE_API_KEY)
    os.environ["GEMINI_API_KEY"] = GOOGLE_API_KEY
else:
    # Story generation is unavailable without a key; downstream use of
    # `model` is wrapped in try/except and reports the failure via st.error.
    model = None
167
+
168
+ # -----------------------
169
+ # PandasAI Response for DataFrame
170
+ # -----------------------
171
def generateResponse(prompt, df):
    """Generate response using PandasAI with SmartDataframe."""
    # NOTE(review): "security": "none" combined with whitelisting "os",
    # "sys" and "builtins" lets generated code run largely unsandboxed —
    # confirm this is intentional for this deployment.
    whitelisted = [
        "os",
        "io",
        "sys",
        "chr",
        "glob",
        "b64decoder",
        "collections",
        "geopy",
        "geopandas",
        "wordcloud",
        "builtins",
    ]
    agent_config = {
        "llm": llm1,
        "custom_whitelisted_dependencies": whitelisted,
        "response_parser": StreamLitResponse,
        "security": "none",
        "enable_cache": False,
        "save_charts": False,
        "save_charts_path": user_defined_path,
    }
    pandas_agent = SmartDataframe(df, config=agent_config)
    return pandas_agent.chat(prompt)
187
+
188
+ # -----------------------
189
+ # DataFrame-Based Story Generation (for CSV/Excel files)
190
+ # -----------------------
191
def generate_story_from_dataframe(df, story_type):
    """
    Generate a data-based story from a CSV/Excel file.

    Serializes the whole dataframe to JSON, asks the module-level Gemini
    `model` for a narrative of exactly 5 [break]-delimited sections (each
    carrying a <chart description>), and returns the normalized text, or
    None on failure.
    """
    # NOTE(review): the full dataframe is embedded in the prompt; large
    # files may exceed the model's context window (MAX_CHARACTERS is
    # defined above but never applied here).
    df_json = json.dumps(df.to_dict())
    # One prompt template per narrative style; keys match the sidebar options.
    prompts = {
        "free_form": "You are a professional storyteller. Using the following dataset in JSON format: " + df_json +
                     ", create an engaging and concise story. ",
        "children": "You are a professional storyteller writing stories for children. Using the following dataset in JSON format: " + df_json +
                    ", create a fun, factual, and concise story appropriate for children. ",
        "education": "You are a professional storyteller writing educational content. Using the following dataset in JSON format: " + df_json +
                     ", create an informative, engaging, and concise educational story. Include interesting facts while keeping it engaging. ",
        "business": "You are a professional storyteller specializing in business narratives. Using the following dataset in JSON format: " + df_json +
                    ", create a professional, concise business story with practical insights. ",
        "entertainment": "You are a professional storyteller writing creative entertaining stories. Using the following dataset in JSON format: " + df_json +
                         ", create an engaging and concise entertaining story. Include interesting facts while keeping it engaging. "
    }
    story_prompt = prompts.get(story_type, prompts["free_form"])
    # Formatting contract relied on by extract_image_prompts_and_story():
    # [break] separators and chart specs inside angle brackets.
    full_prompt = (
        story_prompt +
        "Write a story for a narrator meaning no labels of pages or sections the story should just flow. Divide your story into exactly 5 short and concise sections separated by [break]. " +
        "For each section, provide a brief narrative analysis and include, within angle brackets <>, a clear and plain-text description of a chart visualization that would represent the data. " +
        "Limit the descriptions by specifying only charts. " +
        "Ensure that your response contains only natural language descriptions examples: 'bar chart of', 'pie chart of' , 'histogram of', 'scatterplot of', 'boxplot of' etc and nothing else."
    )

    try:
        response = model.generate_content(full_prompt)
        if not response or not response.text:
            return None

        sections = response.text.split("[break]")
        sections = [s.strip() for s in sections if s.strip()]

        # Normalize to exactly 5 sections: pad with placeholders or truncate.
        if len(sections) < 5:
            sections += ["(Placeholder section)"] * (5 - len(sections))
        elif len(sections) > 5:
            sections = sections[:5]

        return "[break]".join(sections)

    except Exception as e:
        st.error(f"Error generating story from dataframe: {e}")
        return None
235
+
236
+ # -----------------------
237
+ # Extract Image Prompts and Story Sections
238
+ # -----------------------
239
def extract_image_prompts_and_story(story_text):
    """Split a generated story into narration pages and image prompts.

    Sections are delimited by [break]; a <...> span inside a section is
    treated as its image description and stripped from the narration text.
    Sections without one fall back to a generic prompt built from the
    section's first 100 characters. Returns (pages, image_prompts).
    """
    pages, image_prompts = [], []
    for section in re.split(r"\[break\]", story_text):
        stripped = section.strip()
        if not stripped:
            continue
        match = re.search(r"<(.*?)>", section)
        if match:
            image_prompts.append(match.group(1).strip())
            pages.append(re.sub(r"<(.*?)>", "", section).strip())
        else:
            snippet = stripped[:100]
            pages.append(snippet)
            image_prompts.append(f"A concise illustration of {snippet}")
    return pages, image_prompts
255
+
256
def is_valid_png(file_path):
    """Return True iff *file_path* is a readable, structurally valid PNG.

    Checks the 8-byte PNG magic number first, then asks PIL to verify the
    file's structure. Any error (missing file, truncation, wrong format)
    yields False.
    """
    try:
        with open(file_path, "rb") as f:
            magic = f.read(8)
        if magic != b'\x89PNG\r\n\x1a\n':
            return False
        with Image.open(file_path) as img:
            img.verify()
        return True
    except Exception as e:
        print(f"Invalid PNG file at {file_path}: {e}")
        return False
268
+
269
def standardize_and_validate_image(file_path):
    """Verify an image and rewrite it in place as an RGB PNG.

    Returns True on success, False on any failure (file is left untouched
    when verification fails). PIL's verify() invalidates the handle, so the
    file is reopened for the conversion pass.
    """
    try:
        with Image.open(file_path) as img:
            img.verify()
        with Image.open(file_path) as img:
            rgb = img.convert("RGB")
            buffer = io.BytesIO()
            rgb.save(buffer, format="PNG")
        buffer.seek(0)
        with open(file_path, "wb") as f:
            f.write(buffer.getvalue())
        return True
    except Exception as e:
        print(f"Failed to standardize/validate {file_path}: {e}")
        return False
284
+
285
def generate_image(prompt_text, style, model="hf"):
    """Generate an illustration for *prompt_text* via one of three backends.

    model: "pollinations_turbo", "gemini", or anything else (default — the
    HuggingFace FLUX.1-dev inference API). Returns (PIL.Image, base64 str)
    on success, (None, None) on API-level errors, and a plain gray
    1024x1024 placeholder paired with None on unexpected exceptions.
    """
    try:
        if model == "pollinations_turbo":
            # Pollinations: the prompt travels URL-encoded in the path.
            prompt_encoded = urllib.parse.quote(prompt_text)
            api_url = f"https://image.pollinations.ai/prompt/{prompt_encoded}?model=turbo"
            response = requests.get(api_url)
            if response.status_code != 200:
                logger.error(f"Pollinations API error: {response.status_code}, {response.text}")
                return None, None
            image_bytes = response.content

        elif model == "gemini":
            try:
                g_api_key = os.getenv("GEMINI")
                if not g_api_key:
                    st.error("Google Gemini API key is missing.")
                    return None, None
                # NOTE(review): `genai` is rebound at the top of the file by
                # `import google.generativeai as genai`, which shadows
                # `from google import genai` and may not expose `Client` —
                # this branch likely fails at runtime; confirm which Google
                # SDK is intended.
                client = genai.Client(api_key=g_api_key)
                enhanced_prompt = f"image of {prompt_text} in {style} style, high quality, detailed illustration"
                response = client.models.generate_content(
                    model="models/gemini-2.0-flash-exp", # MODEL REVERTED
                    contents=enhanced_prompt,
                    config=types.GenerateContentConfig(response_modalities=['Text', 'Image'])
                )
                # Return the first inline image part found in the response.
                for part in response.candidates[0].content.parts:
                    if part.inline_data is not None:
                        image = Image.open(BytesIO(part.inline_data.data))
                        buffered = io.BytesIO()
                        image.save(buffered, format="JPEG")
                        img_str = base64.b64encode(buffered.getvalue()).decode()
                        return image, img_str
                logger.error("No image was found in the Gemini API response")
                return None, None
            except Exception as e:
                logger.error(f"Gemini API error: {str(e)}")
                return None, None

        else:
            # Default backend: HuggingFace serverless inference (FLUX.1-dev),
            # authenticated with the module-level `headers` (HF_API token).
            enhanced_prompt = f"{prompt_text} in {style} style, high quality, detailed illustration"
            model_id = "black-forest-labs/FLUX.1-dev"
            api_url = f"https://api-inference.huggingface.co/models/{model_id}"
            payload = {"inputs": enhanced_prompt}
            response = requests.post(api_url, headers=headers, json=payload)
            if response.status_code != 200:
                logger.error(f"Hugging Face API error: {response.status_code}, {response.text}")
                return None, None
            image_bytes = response.content

        # Shared decode/encode path for the byte-returning backends
        # (the gemini branch always returns before reaching this point).
        if model != "gemini":
            image = Image.open(io.BytesIO(image_bytes))
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            return image, img_str

    except Exception as e:
        logger.error(f"Image generation error: {str(e)}")

    # Last-resort placeholder so the video pipeline always has a frame.
    return Image.new('RGB', (1024, 1024), color=(200,200,200)), None
344
+
345
def generate_image_with_retry(prompt_text, style, model="hf", max_retries=3):
    """Call generate_image with exponential backoff between attempts.

    Re-raises the last exception once max_retries attempts have failed.
    """
    attempt = 0
    while attempt < max_retries:
        try:
            # Back off 2**attempt seconds before every retry (not the first try).
            if attempt > 0:
                time.sleep(2 ** attempt)
            return generate_image(prompt_text, style, model=model)
        except Exception as e:
            logger.error(f"Attempt {attempt+1} failed: {e}")
            if attempt == max_retries - 1:
                raise
        attempt += 1
    return None, None
356
+
357
+ # -----------------------
358
+ # Video Creation Functions
359
+ # -----------------------
360
def create_silent_video(images, durations, output_path, logo_path="sozo_logo2.png", font_path="lazy_dog.ttf"):
    """Render PIL images into a 1280x720, 24 fps mp4 with no audio track.

    Each image is shown for its paired duration (seconds) and stamped with a
    "Made With / Sozo Business Studio" watermark; a 3-second logo outro is
    appended when the logo file loads. Returns output_path, or None on error.
    """
    try:
        height, width = 720, 1280
        fps = 24
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        if not video.isOpened():
            st.error("Failed to create video file.")
            return None

        font_size = 45
        font = ImageFont.truetype(font_path, font_size)

        # Optional full-frame logo; also reused as a fallback frame below.
        # NOTE(review): cv2.imread yields BGR, but the fallback path feeds
        # `logo` to Image.fromarray as if RGB — colors may be swapped there;
        # confirm.
        logo = None
        if logo_path:
            logo = cv2.imread(logo_path)
            if logo is not None:
                logo = cv2.resize(logo, (width, height))
            else:
                st.warning(f"Failed to load logo from {logo_path}.")

        for img, duration in zip(images, durations):
            try:
                img = img.convert("RGB")
                img_resized = img.resize((width, height))
                frame = np.array(img_resized)
            except Exception as e:
                print(f"Invalid image detected, replacing with logo: {e}")
                frame = logo if logo is not None else np.zeros((height, width, 3), dtype=np.uint8)

            # Watermark is drawn with PIL (TrueType rendering), then the
            # frame is handed back to OpenCV for encoding.
            pil_img = Image.fromarray(frame)
            draw = ImageDraw.Draw(pil_img)

            text1 = "Made With"
            text2 = "Sozo Business Studio" # TEXT UPDATED

            bbox = draw.textbbox((0, 0), text1, font=font)
            text1_height = bbox[3] - bbox[1]

            # Bottom-right placement; offsets are hand-tuned for this font size.
            text_position1 = (width - 270, height - 120)
            text_position2 = (width - 430, height - 120 + text1_height + 5) # Position adjusted for longer text

            draw.text(text_position1, text1, font=font, fill=(81, 34, 97, 255))
            draw.text(text_position2, text2, font=font, fill=(81, 34, 97, 255))

            frame = np.array(pil_img)
            frame_cv = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            # Repeat the frame for `duration` seconds at the target fps.
            for _ in range(int(duration * fps)):
                video.write(frame_cv)

        # 3-second logo outro.
        if logo is not None:
            for _ in range(int(3 * fps)):
                video.write(logo)

        video.release()
        return output_path

    except Exception as e:
        st.error(f"Error creating silent video: {e}")
        return None
422
+
423
def combine_video_audio(video_path, audio_files, output_path=None):
    """Concatenate the audio files and mux them onto the silent video.

    Uses ffmpeg: the concat demuxer joins the narration clips, then a
    libx264/aac mux with -shortest trims to the shorter stream. Returns
    output_path on success; falls back to returning the original
    video_path on any failure or when there is no audio.
    """
    try:
        if output_path is None:
            output_path = f"final_video_{uuid.uuid4()}.mp4"
        # Scratch file that receives the concatenated narration track.
        temp_audio_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        temp_audio_file.close()
        if len(audio_files) > 1:
            # The concat demuxer needs a listing file of "file '<path>'" lines.
            concat_list_path = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
            with open(concat_list_path.name, 'w') as f:
                for af in audio_files:
                    f.write(f"file '{af}'\n")
            concat_cmd = [
                'ffmpeg', '-y', '-f', 'concat', '-safe', '0',
                '-i', concat_list_path.name, '-c', 'copy', temp_audio_file.name
            ]
            # NOTE(review): ffmpeg return codes are not checked; a failed
            # concat/mux still reports success via output_path — confirm.
            subprocess.run(concat_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            os.unlink(concat_list_path.name)
            combined_audio = temp_audio_file.name
        else:
            combined_audio = audio_files[0] if audio_files else None
        if not combined_audio:
            return video_path
        combine_cmd = [
            'ffmpeg', '-y', '-i', video_path, '-i', combined_audio,
            '-map', '0:v', '-map', '1:a', '-c:v', 'libx264',
            '-crf', '23', '-c:a', 'aac', '-shortest', output_path
        ]
        subprocess.run(combine_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        os.unlink(temp_audio_file.name)
        return output_path
    except Exception:
        # Best effort: callers still get a playable (silent) video.
        return video_path
455
+
456
def create_video(images, audio_files, output_path=None):
    """Build the final narrated mp4 from scene images and audio clips.

    Probes for ffmpeg first, renders a silent video (one frame set per
    image, duration taken from the matching audio clip or 5 s), muxes the
    audio on top, and removes the intermediate silent file. Returns the
    final path or None.
    """
    # Bail out early when ffmpeg is unavailable.
    try:
        subprocess.run(['ffmpeg', '-version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError:
        st.error("ffmpeg not installed.")
        return None

    if output_path is None:
        output_path = f"output_video_{uuid.uuid4()}.mp4"
    silent_video_path = f"silent_{uuid.uuid4()}.mp4"

    # One duration per image; missing/empty audio entries default to 5 s.
    durations = []
    for clip in audio_files:
        durations.append(get_audio_duration(clip) if clip else 5.0)
    while len(durations) < len(images):
        durations.append(5.0)

    if not create_silent_video(images, durations, silent_video_path):
        return None

    final_video = combine_video_audio(silent_video_path, audio_files, output_path)
    try:
        os.unlink(silent_video_path)
    except Exception:
        pass
    return final_video
477
+
478
+ # -----------------------
479
+ # Audio Generation Function
480
+ # -----------------------
481
def generate_audio(text, voice_model, audio_model="deepgram"):
    """Synthesize *text* to speech and return the path of a temp .mp3.

    audio_model selects the TTS backend ("deepgram" or "openai-audio" via
    Pollinations); voice_model is the backend-specific voice id. Returns
    None on any error (reported through st.error).
    """
    def _save_mp3(content):
        # Persist the TTS bytes to a non-deleting temp file and return its path.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        tmp.write(content)
        tmp.close()
        return tmp.name

    if audio_model == "deepgram":
        deepgram_api_key = os.getenv("DeepGram")
        if not deepgram_api_key:
            st.error("Deepgram API Key is missing.")
            return None
        headers_tts = {
            "Authorization": f"Token {deepgram_api_key}",
            "Content-Type": "text/plain"
        }
        url = f"https://api.deepgram.com/v1/speak?model={voice_model}"
        response = requests.post(url, headers=headers_tts, data=text)
        if response.status_code != 200:
            st.error(f"DeepGram TTS error: {response.status_code}")
            return None
        return _save_mp3(response.content)

    if audio_model == "openai-audio":
        encoded_text = urllib.parse.quote(text)
        url = f"https://text.pollinations.ai/{encoded_text}?model=openai-audio&voice={voice_model}"
        response = requests.get(url)
        if response.status_code != 200:
            st.error(f"OpenAI Audio TTS error: {response.status_code}")
            return None
        return _save_mp3(response.content)

    st.error("Unsupported audio model selected.")
    return None
516
+
517
def get_audio_duration(audio_file):
    """Probe an audio file's duration in seconds via ffprobe.

    Returns 5.0 as a safe default whenever ffprobe is missing, exits
    non-zero, or its output cannot be parsed.
    """
    fallback = 5.0
    probe_cmd = [
        'ffprobe', '-v', 'error',
        '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        audio_file,
    ]
    try:
        result = subprocess.run(probe_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        if result.returncode != 0:
            return fallback
        return float(result.stdout.strip())
    except Exception:
        return fallback
525
+
526
+ # -----------------------
527
+ # Unified Process-Story Function
528
+ # -----------------------
529
def process_generated_story(style, voice_model):
    """Turn st.session_state.full_story into images, narration and a video.

    Reads module-level state: st.session_state (full_story, dataframe) and
    the sidebar-selected ``audio_model_param`` global. Results are written
    back into st.session_state (story_pages, image_descriptions,
    generated_images, story_audio, final_video_path).
    """
    pages, image_prompts = extract_image_prompts_and_story(st.session_state.full_story)
    st.session_state.story_pages = pages
    st.session_state.image_descriptions = image_prompts
    st.session_state.generated_images = []
    st.session_state.story_audio = []
    progress_bar = st.progress(0)
    total_steps = len(pages) * 2 # 1 for image, 1 for audio
    current_step = 0

    for i, (page, img_prompt) in enumerate(zip(pages, image_prompts)):
        with st.spinner(f"Generating image {i+1}/{len(pages)}..."):
            img = None
            # Prefer a real chart rendered from the uploaded dataframe via
            # PandasAI; fall back to a generative image on any failure or
            # when the chart path is missing/invalid.
            try:
                chart_response = generateResponse("Generate this visualization: " + img_prompt, st.session_state.dataframe)
                if isinstance(chart_response, dict) and chart_response.get("type") == "plot":
                    img_path = chart_response["value"]
                    if isinstance(img_path, str) and os.path.isfile(img_path) and is_valid_png(img_path) and standardize_and_validate_image(img_path):
                        img = Image.open(img_path)
                    else:
                        img, _ = generate_image_with_retry(img_prompt, style)
                else:
                    img, _ = generate_image_with_retry(img_prompt, style)
            except Exception as e:
                st.warning(f"Chart generation failed for section {i+1}: {e}. Using default image.")
                img, _ = generate_image_with_retry(img_prompt, style)

            # Guarantee a frame even if everything above failed.
            img = img if img else Image.new('RGB', (1024, 1024), color=(200, 200, 200))
            st.session_state.generated_images.append(img.convert('RGB'))
            current_step += 1
            progress_bar.progress(current_step / total_steps)

    for i, page in enumerate(pages):
        with st.spinner(f"Generating audio {i+1}/{len(pages)}..."):
            audio = generate_audio(page, voice_model, audio_model=audio_model_param)
            st.session_state.story_audio.append(audio)
            current_step += 1
            progress_bar.progress(current_step / total_steps)

    if st.session_state.generated_images:
        with st.spinner("Assembling video..."):
            audio_paths = [af for af in st.session_state.story_audio if af]
            if audio_paths:
                st.session_state.final_video_path = create_video(st.session_state.generated_images, audio_paths)
            else:
                # No narration at all: produce a silent video, 5 s per scene.
                silent_path = f"silent_video_{uuid.uuid4()}.mp4"
                durations = [5.0] * len(st.session_state.generated_images)
                st.session_state.final_video_path = create_silent_video(st.session_state.generated_images, durations, silent_path)
    progress_bar.empty()
578
+ # -----------------------
579
+ # Display Generated Content
580
+ # -----------------------
581
def display_generated_content():
    """Render the finished video, per-scene pages/audio, and the full script."""
    st.subheader("Generated Narrative Video")
    tab1, tab2, tab3 = st.tabs(["Video Output", "Story Pages", "Full Script"])

    with tab1:
        if st.session_state.final_video_path and os.path.exists(st.session_state.final_video_path):
            with open(st.session_state.final_video_path, "rb") as f:
                video_bytes = f.read()
            st.video(video_bytes)
            st.download_button("Download Video", data=video_bytes, file_name="sozo_business_narrative.mp4", mime="video/mp4")
            share_message = "Check out this AI-generated business narrative video!"
            # WhatsApp deep-link shares the text only; the video is not attached.
            whatsapp_link = f"https://api.whatsapp.com/send?text={urllib.parse.quote(share_message)}"
            st.markdown(f"[Share on WhatsApp]({whatsapp_link})", unsafe_allow_html=True)
        else:
            st.error("Video file not found or not readable.")

    with tab2:
        # One scene per page: image, narration text and (if present) audio.
        for i, (page, img) in enumerate(zip(st.session_state.story_pages, st.session_state.generated_images)):
            st.image(img, caption=f"Scene {i+1}")
            st.markdown(f"**Narration {i+1}**: {page}")
            if i < len(st.session_state.story_audio) and st.session_state.story_audio[i]:
                st.audio(st.session_state.story_audio[i])

    with tab3:
        st.text_area("Complete Narrative Script", st.session_state.full_story, height=400)
606
+
607
+
608
# -----------------------
# Streamlit App Configuration and Sidebar
# -----------------------
st.set_page_config(page_title="Sozo Business Studio", page_icon="💼", layout="wide", initial_sidebar_state="expanded")

# Seed the persistent session keys once per session: "story*"/"generated*"
# keys start as empty lists, everything else starts as None.
for key in ["story_pages", "image_descriptions", "generated_images", "story_audio", "full_story", "final_video_path", "dataframe"]:
    if key not in st.session_state:
        st.session_state[key] = [] if key.startswith("story") or key.startswith("generated") else None

with st.sidebar:
    st.sidebar.image("sozo_logo1.jpeg", use_container_width=True)
    # Narrative style — keys match the prompt templates in
    # generate_story_from_dataframe().
    story_types = {
        "business": "Business Narrative",
        "education": "Educational",
        "entertainment": "Entertaining",
        "free_form": "Free Form (AI's choice)",
        "children": "Children's Story",
    }
    selected_story_type = st.selectbox(
        "Narrative Style",
        options=list(story_types.keys()),
        format_func=lambda x: story_types[x],
        key="story_type_select"
    )

    model_options = ["HuggingFace Flux", "Pollinations Turbo", "Google Gemini"]
    selected_model_name = st.selectbox("Select Image Generation Model", model_options, index=0, key="image_model_select")

    style_options = ["photorealistic", "cinematic", "cartoon", "concept art", "oil painting", "fantasy illustration", "whimsical"]
    selected_style = st.selectbox("Image Style", style_options, key="style_select")

    # UI label -> backend identifier consumed by generate_image().
    model_param = {"HuggingFace Flux": "hf", "Pollinations Turbo": "pollinations_turbo", "Google Gemini": "gemini"}[selected_model_name]

    audio_model_options = ["DeepGram", "Pollinations OpenAI-Audio"]
    selected_audio_model = st.selectbox("Select Audio Generation Model", audio_model_options, key="audio_model_select")

    # Voice ids differ per TTS backend; audio_model_param is read as a
    # module-level global by process_generated_story().
    if selected_audio_model == "DeepGram":
        voice_options = {"aura-asteria-en": "Female", "aura-helios-en": "Male"}
        selected_voice = st.selectbox("Voice Model", options=list(voice_options.keys()), format_func=voice_options.get, key="voice_select_deepgram")
        audio_model_param = "deepgram"
    else:
        voice_options = {"sage": "Female", "echo": "Male"}
        selected_voice = st.selectbox("Voice Model", options=list(voice_options.keys()), format_func=voice_options.get, key="voice_select_pollinations")
        audio_model_param = "openai-audio"

    st.markdown("### Tips for Best Results")
    st.markdown("- Ensure your data has clear column headers.\n- Use the 'Business Narrative' style for professional reports.\n- Try different image styles and voices to match your brand.")
    if st.button("Check System Requirements"):
        try:
            result = subprocess.run(['ffmpeg', '-version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            st.success("✅ ffmpeg is installed.")
        except FileNotFoundError:
            st.error("❌ ffmpeg not found. It must be installed to create videos.")

# --- MAIN PAGE ---
st.subheader("Sozo Business Studio")
st.markdown("#### Turn business data into compelling narratives.")
st.markdown("---")

st.markdown("### 1. Upload Your Business Data")
uploaded_file = st.file_uploader(
    "Upload a CSV or Excel file to begin.",
    type=['csv', 'xlsx', 'xls'],
    label_visibility="collapsed"
)

if uploaded_file:
    try:
        # Extension decides the reader; preview the head so the user can
        # confirm the data parsed as expected.
        df = pd.read_excel(uploaded_file) if uploaded_file.name.endswith(('xlsx', 'xls')) else pd.read_csv(uploaded_file)
        st.session_state.dataframe = df
        st.success(f"✅ Loaded `{uploaded_file.name}`. Data preview:")
        st.dataframe(df.head())
    except Exception as e:
        st.error(f"Error processing {uploaded_file.name}: {e}")
        st.session_state.dataframe = None

st.markdown("### 2. Generate Your Video")
if st.button("Generate Video Narrative", disabled=st.session_state.dataframe is None):
    with st.spinner("Analyzing data and generating narrative script..."):
        st.session_state.full_story = generate_story_from_dataframe(st.session_state.dataframe, selected_story_type)

    if st.session_state.full_story:
        st.success("Script generated! Now creating video assets...")
        process_generated_story(selected_style, selected_voice)
    else:
        st.error("Failed to generate narrative script. The data might be formatted incorrectly or the AI model could be temporarily unavailable.")

# Re-render previously generated content on every Streamlit rerun.
if st.session_state.story_pages:
    st.markdown("---")
    display_generated_content()