|
|
|
|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import logging |
|
|
import uuid |
|
|
import io |
|
|
from pathlib import Path |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import matplotlib |
|
|
matplotlib.use("Agg") |
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib.animation import FuncAnimation, FFMpegWriter |
|
|
import seaborn as sns |
|
|
from scipy import stats |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import cv2 |
|
|
import inspect |
|
|
import tempfile |
|
|
import subprocess |
|
|
from typing import Dict, List, Tuple, Any |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from google import genai |
|
|
import requests |
|
|
|
|
|
from google.genai import types as genai_types |
|
|
import math |
|
|
import shutil |
|
|
|
|
|
# Root logger: INFO level, each record tagged with the emitting function name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s',
)

# Output video geometry and frame rate (used by the chart animator and ffmpeg scaling).
FPS = 24
WIDTH = 1280
HEIGHT = 720

# Upper bounds on charts per report and scenes per video script.
MAX_CHARTS = 5
VIDEO_SCENES = 5

# Rough token budget (chars / 4) under which the full dataset is sent to the LLM.
MAX_CONTEXT_TOKENS = 750000

# The Google key is mandatory: importing this module fails fast without it.
API_KEY = os.environ.get("GOOGLE_API_KEY")
if API_KEY is None or API_KEY == "":
    raise ValueError("GOOGLE_API_KEY environment variable not set.")

# The Pexels key is optional: stock-video lookup logs a warning and returns None without it.
PEXELS_API_KEY = os.environ.get("PEXELS_API_KEY")
|
|
|
|
|
|
|
|
def load_dataframe_safely(buf, name: str):
    """Parse an uploaded spreadsheet or CSV into a lightly cleaned DataFrame.

    Args:
        buf: Path or file-like object passed straight to the pandas reader.
        name: Original file name; only its extension is inspected
            (".xlsx"/".xls" use ``read_excel``, anything else ``read_csv``).

    Returns:
        The frame with column labels cast to stripped strings and rows that
        are entirely empty removed.

    Raises:
        ValueError: If no rows or no columns remain after cleaning.
    """
    suffix = Path(name).suffix.lower()
    if suffix in (".xlsx", ".xls"):
        frame = pd.read_excel(buf)
    else:
        frame = pd.read_csv(buf)

    frame.columns = frame.columns.astype(str).str.strip()
    frame = frame.dropna(how="all")

    if frame.empty or len(frame.columns) == 0:
        raise ValueError("No usable data found")
    return frame
|
|
|
|
|
def deepgram_tts(txt: str, voice_model: str):
    """Synthesize speech for ``txt`` through the Deepgram speak endpoint.

    Args:
        txt: Narration text. Characters other than word characters,
            whitespace and ``.,!?;:-`` are stripped before sending.
        voice_model: Value sent as the ``model`` query parameter.

    Returns:
        The raw response body (audio bytes) on success, or ``None`` when the
        DEEPGRAM_API_KEY environment variable is unset, ``txt`` is empty, or
        the request fails (the failure is logged).
    """
    api_key = os.getenv("DEEPGRAM_API_KEY")
    if not (api_key and txt):
        return None

    speakable = re.sub(r"[^\w\s.,!?;:-]", "", txt)
    request_headers = {
        "Authorization": f"Token {api_key}",
        "Content-Type": "application/json",
    }
    try:
        reply = requests.post(
            "https://api.deepgram.com/v1/speak",
            params={"model": voice_model},
            headers=request_headers,
            json={"text": speakable},
            timeout=30,
        )
        reply.raise_for_status()
        return reply.content
    except Exception as e:
        logging.error(f"Deepgram TTS failed: {e}")
        return None
|
|
|
|
|
def generate_silence_mp3(duration: float, out: Path):
    """Write ``duration`` seconds of mono 44.1 kHz silence to ``out`` using ffmpeg.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    ffmpeg_args = [
        "ffmpeg", "-y",
        "-f", "lavfi",
        "-i", "anullsrc=r=44100:cl=mono",
        "-t", f"{duration:.3f}",
        "-q:a", "9",
        str(out),
    ]
    subprocess.run(ffmpeg_args, check=True, capture_output=True)
|
|
|
|
|
def audio_duration(path: str) -> float:
    """Return the duration in seconds that ffprobe reports for ``path``.

    Any failure (ffprobe missing, non-zero exit, unparsable output) yields
    the fallback value ``5.0`` instead of raising.
    """
    probe_cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=nw=1:nk=1",
        path,
    ]
    try:
        probe = subprocess.run(
            probe_cmd,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
        seconds = float(probe.stdout.strip())
    except Exception:
        return 5.0
    return seconds
|
|
|
|
|
# Matches chart tags such as <generate_chart: "bar | Title"> or [generate_chart: ...];
# the named group "d" captures the description between the optional quotes.
TAG_RE = re.compile(
    r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]',
    re.I,
)

# Same tag shape for stock-video requests: <generate_stock_video: "query">.
TAG_RE_PEXELS = re.compile(
    r'[<[]\s*generate_?stock_?video\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]',
    re.I,
)


def _unique_tag_payloads(pattern, text):
    """Collect the stripped "d" group of every match, de-duplicated in first-seen order."""
    ordered = {}
    for match in pattern.finditer(text or ""):
        ordered.setdefault(match.group("d").strip(), None)
    return list(ordered)


def extract_chart_tags(t):
    """Return the unique chart descriptions tagged in ``t`` (``None`` is treated as empty)."""
    return _unique_tag_payloads(TAG_RE, t)


def extract_pexels_tags(t):
    """Return the unique stock-video queries tagged in ``t`` (``None`` is treated as empty)."""
    return _unique_tag_payloads(TAG_RE_PEXELS, t)
|
|
|
|
|
# Leading "Scene 3:" / "scene 2 -" style labels at the start of any line.
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)


def clean_narration(txt: str) -> str:
    """Reduce a scene script to plain narration text.

    Removes chart and stock-video tags, scene labels, a few scaffolding
    phrases, parenthesised asides and the Markdown characters ``*``, ``#``
    and ``_``, then collapses runs of whitespace.
    """
    for tag_pattern in (TAG_RE, TAG_RE_PEXELS, re_scene):
        txt = tag_pattern.sub("", txt)

    # Phrases the script writer sometimes echoes from the prompt.
    for phrase in (r"chart tag", r"chart_tag", r"narration", r"stock video tag"):
        txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)

    txt = re.sub(r"\s*\([^)]*\)", "", txt)
    txt = re.sub(r"[\*#_]", "", txt)

    collapsed = re.sub(r"\s{2,}", " ", txt)
    return collapsed.strip()
|
|
|
|
|
def placeholder_img() -> Image.Image:
    """Return a blank light-grey RGB image sized WIDTH x HEIGHT."""
    frame_size = (WIDTH, HEIGHT)
    light_grey = (230, 230, 230)
    return Image.new("RGB", frame_size, light_grey)
|
|
|
|
|
def _sanitize_for_json(data): |
|
|
"""Recursively sanitizes a dict/list for JSON compliance.""" |
|
|
if isinstance(data, dict): |
|
|
return {k: _sanitize_for_json(v) for k, v in data.items()} |
|
|
if isinstance(data, list): |
|
|
return [_sanitize_for_json(i) for i in data] |
|
|
if isinstance(data, float) and (math.isnan(data) or math.isinf(data)): |
|
|
return None |
|
|
return data |
|
|
|
|
|
def detect_dataset_domain(df: pd.DataFrame) -> str:
    """Guess the dataset's subject area from its column names.

    Each domain has a list of keywords; the first domain (in declaration
    order) for which any keyword occurs as a substring of any lower-cased
    column name wins.

    Args:
        df: Frame whose column labels are inspected. Labels need not be
            strings; they are converted with ``str()`` before matching.

    Returns:
        The matching domain name, or ``"data"`` when nothing matches.
    """
    domain_keywords = {
        "health insurance": ["charges", "bmi", "smoker", "beneficiary"],
        "finance": ["revenue", "profit", "cost", "budget", "expense", "stock"],
        "marketing": ["campaign", "conversion", "click", "customer", "segment"],
        "survey": ["satisfaction", "rating", "feedback", "opinion", "score"],
        "food": ["nutrition", "calories", "ingredients", "restaurant"],
    }
    # str() guards against integer or other non-string labels, which have no .lower().
    columns_lower = [str(col).lower() for col in df.columns]
    for domain, keywords in domain_keywords.items():
        if any(keyword in col for col in columns_lower for keyword in keywords):
            logging.info(f"Dataset domain detected: {domain}")
            return domain
    logging.info("No specific dataset domain detected, using generic terms.")
    return "data"
|
|
|
|
|
|
|
|
def extract_keywords_for_query(text: str, llm) -> str:
    """Ask ``llm`` to condense ``text`` into a short stock-video search query.

    Args:
        text: Narration or description to summarise.
        llm: Object exposing ``invoke(prompt)`` whose result has a
            ``content`` string attribute.

    Returns:
        The stripped model reply, or ``text`` unchanged when the reply is
        empty or the call raises (the error is logged).
    """
    prompt = f"""
    Extract a maximum of 3 key nouns or verbs from the following text to use as a search query for a stock video.
    Focus on concrete actions and subjects.
    Example: 'Our analysis shows a significant growth in quarterly revenue and strong partnerships.' -> 'data analysis growth'
    Output only the search query keywords, separated by spaces.

    Text: "{text}"
    """
    try:
        keywords = llm.invoke(prompt).content.strip()
    except Exception as e:
        logging.error(f"Keyword extraction failed: {e}. Using original text.")
        return text
    return keywords or text
|
|
|
|
|
|
|
|
def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
    """Fetch a landscape HD stock clip from Pexels and conform it to the video format.

    The first video file marked ``hd`` with a width of at least 1280 is
    downloaded to a temporary file, then looped, scaled and padded to
    WIDTH x HEIGHT, trimmed to ``duration`` seconds and written without audio
    to ``out_path`` by ffmpeg.

    Args:
        query: Search terms sent to the Pexels video search endpoint.
        duration: Length of the output clip in seconds.
        out_path: Destination of the encoded clip.

    Returns:
        ``str(out_path)`` on success. ``None`` when PEXELS_API_KEY is unset,
        no suitable file is found, or any step fails (failures are logged).
    """
    if not PEXELS_API_KEY:
        logging.warning("PEXELS_API_KEY not set. Cannot fetch stock video.")
        return None

    temp_dl_path = None
    try:
        headers = {"Authorization": PEXELS_API_KEY}
        params = {"query": query, "per_page": 10, "orientation": "landscape"}
        response = requests.get("https://api.pexels.com/videos/search", headers=headers, params=params, timeout=20)
        response.raise_for_status()
        videos = response.json().get('videos', [])
        if not videos:
            logging.warning(f"No Pexels videos found for query: '{query}'")
            return None

        # `width` may be absent or null for some files; treat that as 0 so the
        # entry is skipped instead of raising TypeError and aborting the search.
        video_to_download = next(
            (
                f['link']
                for video in videos
                for f in video.get('video_files', [])
                if f.get('quality') == 'hd' and (f.get('width') or 0) >= 1280 and f.get('link')
            ),
            None,
        )
        if not video_to_download:
            logging.warning(f"No suitable HD video file found for query: '{query}'")
            return None

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_dl_file:
            # Record the path before streaming so a failed download is still cleaned up.
            temp_dl_path = Path(temp_dl_file.name)
            with requests.get(video_to_download, stream=True, timeout=60) as r:
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=8192):
                    temp_dl_file.write(chunk)

        cmd = [
            "ffmpeg", "-y",
            "-stream_loop", "-1",  # loop the source so short clips can fill `duration`
            "-i", str(temp_dl_path),
            "-vf", f"scale={WIDTH}:{HEIGHT}:force_original_aspect_ratio=decrease,pad={WIDTH}:{HEIGHT}:(ow-iw)/2:(oh-ih)/2,setsar=1",
            "-t", f"{duration:.3f}",
            "-c:v", "libx264", "-pix_fmt", "yuv420p", "-an",
            str(out_path),
        ]
        subprocess.run(cmd, check=True, capture_output=True)
        return str(out_path)

    except Exception as e:
        logging.error(f"Pexels video processing failed for query '{query}': {e}")
        return None

    finally:
        # The raw download is only an intermediate; remove it on every path.
        if temp_dl_path is not None:
            temp_dl_path.unlink(missing_ok=True)
|
|
|
|
|
class ChartSpecification:
    """Plain value holder describing one chart to render.

    ``agg_method`` falls back to ``"sum"`` when it is omitted or falsy; all
    other arguments are stored exactly as given.
    """

    def __init__(
        self,
        chart_type: str,
        title: str,
        x_col: str,
        y_col: str = None,
        size_col: str = None,
        agg_method: str = None,
        filter_condition: str = None,
        top_n: int = None,
        color_scheme: str = "professional",
    ):
        # What to draw and how to label it.
        self.chart_type = chart_type
        self.title = title
        # Columns feeding the chart.
        self.x_col = x_col
        self.y_col = y_col
        self.size_col = size_col
        # Data shaping options.
        self.agg_method = agg_method if agg_method else "sum"
        self.filter_condition = filter_condition
        self.top_n = top_n
        # Presentation.
        self.color_scheme = color_scheme
|
|
|
|
|
class ChartGenerator:
    """Turns a free-text chart request into a ``ChartSpecification`` with LLM help."""

    # Spec fields for which the model may answer with a literal "null"/"none" string.
    _NULLABLE_FIELDS = ("x_col", "y_col", "size_col", "agg_method", "top_n")

    def __init__(self, llm, df: pd.DataFrame):
        """Store the LLM client (must expose ``invoke``) and the frame to chart."""
        self.llm = llm
        self.df = df

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (optionally tagged json) if present."""
        text = text.strip()
        if not text.startswith("```"):
            return text
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
        text = text.rstrip()
        # Only drop the closing fence when it is really there, so an unterminated
        # fence does not cost the last three characters of the JSON payload.
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    @classmethod
    def _normalize_spec_fields(cls, fields: Dict) -> Dict:
        """Map textual nulls to ``None`` and coerce ``top_n`` to a positive int or ``None``."""
        cleaned = dict(fields)
        for key in cls._NULLABLE_FIELDS:
            value = cleaned.get(key)
            if isinstance(value, str) and value.strip().lower() in ("", "null", "none"):
                cleaned[key] = None
        top_n = cleaned.get("top_n")
        if top_n is not None:
            try:
                top_n = int(float(top_n))
            except (TypeError, ValueError):
                top_n = None
            cleaned["top_n"] = top_n if top_n and top_n > 0 else None
        return cleaned

    def generate_chart_spec(self, description: str, context: Dict) -> ChartSpecification:
        """Build a chart specification for ``description``.

        The LLM is asked for a JSON spec. If the call or the parsing fails, a
        heuristic fallback is used: the chart type is taken from the last
        known type name found in the description (default "bar"), x is the
        first categorical column and y the first numeric column.

        Args:
            description: Chart request, e.g. "bar | Average Charges by Region".
            context: JSON-serialisable dataset summary; its ``schema`` entry,
                when present, supplies the column lists for the fallback.

        Returns:
            A ``ChartSpecification``; this method does not raise for LLM or
            parsing errors.
        """
        spec_prompt = f"""
        You are a data visualization expert. Based on the dataset context and chart description, generate a precise chart specification.
        **Dataset Context:** {json.dumps(context, indent=2)}
        **Chart Request:** {description}
        **Return a JSON specification with these exact fields:**
        {{
            "chart_type": "bar|pie|line|scatter|hist|heatmap|area|bubble",
            "title": "Professional chart title",
            "x_col": "column_name_for_x_axis_or_null_for_heatmap",
            "y_col": "column_name_for_y_axis_or_null",
            "size_col": "column_name_for_bubble_size_or_null",
            "agg_method": "sum|mean|count|max|min|null",
            "top_n": "number_for_top_n_filtering_or_null"
        }}
        Return only the JSON specification, no additional text.
        """
        try:
            response = self._strip_code_fence(self.llm.invoke(spec_prompt).content)
            spec_dict = json.loads(response)
            # Only constructor parameters the model is allowed to set are forwarded.
            valid_keys = [p.name for p in inspect.signature(ChartSpecification).parameters.values() if p.name not in ['reasoning', 'filter_condition', 'color_scheme']]
            filtered_dict = {k: v for k, v in spec_dict.items() if k in valid_keys}
            return ChartSpecification(**self._normalize_spec_fields(filtered_dict))
        except Exception as e:
            logging.error(f"Spec generation failed: {e}. Using fallback.")
            numeric_cols = context.get('schema', {}).get('numeric_columns', list(self.df.select_dtypes(include=['number']).columns))
            categorical_cols = context.get('schema', {}).get('categorical_columns', list(self.df.select_dtypes(exclude=['number']).columns))
            ctype = "bar"
            for t in ["pie", "line", "scatter", "hist", "heatmap", "area", "bubble"]:
                if t in description.lower():
                    ctype = t
            x = categorical_cols[0] if categorical_cols else self.df.columns[0]
            y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
            return ChartSpecification(ctype, description, x, y)
|
|
|
|
|
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
    """Render ``spec`` as a static image at ``output_path``.

    Supported chart types are bar, pie, line, scatter, hist, area, heatmap
    and bubble; any other type yields a figure containing only the title.

    Args:
        spec: Chart description (type, title, columns, aggregation).
        df: Source data.
        output_path: Image file to write (saved at 300 dpi on a white background).

    Returns:
        ``True`` when the image was written, ``False`` when anything failed
        (the error is logged; no exception escapes).
    """
    fig = None
    try:
        plot_data = prepare_plot_data(spec, df)
        fig, ax = plt.subplots(figsize=(12, 8))
        plt.style.use('default')
        chart_type = spec.chart_type

        if chart_type == "bar":
            ax.bar(plot_data.index.astype(str), plot_data.values, color='#2E86AB', alpha=0.8)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.tick_params(axis='x', rotation=45)
        elif chart_type == "pie":
            ax.pie(plot_data.values, labels=plot_data.index, autopct='%1.1f%%', startangle=90)
            ax.axis('equal')
        elif chart_type == "line":
            ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)
        elif chart_type == "scatter":
            ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)
        elif chart_type == "hist":
            ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel('Frequency')
            ax.grid(True, alpha=0.3)
        elif chart_type == "area":
            ax.fill_between(plot_data.index, plot_data.values, color="#4E79A7", alpha=0.4)
            ax.plot(plot_data.index, plot_data.values, color="#4E79A7", alpha=0.8)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)
        elif chart_type == "heatmap":
            sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax)
            plt.xticks(rotation=45, ha="right")
            plt.yticks(rotation=0)
        elif chart_type == "bubble":
            # Min-max scale the size column into marker areas between roughly 50 and 2050.
            sizes = (plot_data[spec.size_col] - plot_data[spec.size_col].min() + 1) / (plot_data[spec.size_col].max() - plot_data[spec.size_col].min() + 1) * 2000 + 50
            ax.scatter(plot_data[spec.x_col], plot_data[spec.y_col], s=sizes, alpha=0.6, color='#59A14F')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)

        ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20)
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
        return True
    except Exception as e:
        logging.error(f"Static chart generation failed for '{spec.title}': {e}")
        return False
    finally:
        # Close this specific figure on every path; pyplot keeps figures alive
        # until they are closed, so a failed render would otherwise leak one.
        if fig is not None:
            plt.close(fig)
|
|
|
|
|
def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame): |
|
|
if spec.chart_type not in ["heatmap"]: |
|
|
if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns): raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}") |
|
|
|
|
|
if spec.chart_type in ["bar", "pie"]: |
|
|
if not spec.y_col: return df[spec.x_col].value_counts().nlargest(spec.top_n or 10) |
|
|
grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum') |
|
|
return grouped.nlargest(spec.top_n or 10) |
|
|
elif spec.chart_type in ["line", "area"]: return df.set_index(spec.x_col)[spec.y_col].sort_index() |
|
|
elif spec.chart_type == "scatter": return df[[spec.x_col, spec.y_col]].dropna() |
|
|
elif spec.chart_type == "bubble": |
|
|
if not spec.size_col or spec.size_col not in df.columns: raise ValueError("Bubble chart requires a valid size_col.") |
|
|
return df[[spec.x_col, spec.y_col, spec.size_col]].dropna() |
|
|
elif spec.chart_type == "hist": return df[spec.x_col].dropna() |
|
|
elif spec.chart_type == "heatmap": |
|
|
numeric_cols = df.select_dtypes(include=np.number).columns |
|
|
if not numeric_cols.any(): raise ValueError("Heatmap requires numeric columns.") |
|
|
return df[numeric_cols].corr() |
|
|
return df[spec.x_col] |
|
|
|
|
|
|
|
|
|
|
|
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
    """Render ``spec`` as an animated clip of roughly ``dur`` seconds at ``out``.

    Animated chart types: line (progressive draw), bar (growing bars),
    scatter (points appear, then a dashed regression line is drawn), pie,
    hist and heatmap (fade-in). Any other type, including area and bubble,
    produces a clip showing a "not implemented" notice.

    Args:
        spec: Chart description.
        df: Source data.
        dur: Clip length in seconds; the frame count is ``ceil(dur * fps)``,
            with a minimum of one frame.
        out: Output video path handed to matplotlib's FFMpegWriter.
        fps: Frames per second.

    Returns:
        ``str(out)``.

    Raises:
        Exception: Errors from data preparation, plotting or encoding
            propagate to the caller; the figure is closed first.
    """
    plot_data = prepare_plot_data(spec, df)
    frames = max(1, math.ceil(dur * fps))
    # Progress is i / (frames - 1); clamp so a single-frame clip cannot divide by zero.
    last_frame = max(frames - 1, 1)

    fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
    try:
        plt.tight_layout(pad=3.0)
        ctype = spec.chart_type
        use_blit = False  # only the line animation returns its artists for blitting

        if ctype == "line":
            plot_data = plot_data.sort_index()
            x_full, y_full = plot_data.index, plot_data.values

            ax.set_xlim(x_full.min(), x_full.max())
            # Pad 10% away from the data on both ends, also when values are negative.
            y_min, y_max = y_full.min(), y_full.max()
            ax.set_ylim(
                y_min * 0.9 if y_min >= 0 else y_min * 1.1,
                y_max * 1.1 if y_max >= 0 else y_max * 0.9,
            )
            ax.set_title(spec.title)
            ax.grid(alpha=.3)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)

            line, = ax.plot([], [], lw=2, color='#A23B72')
            markers, = ax.plot([], [], 'o', color='#A23B72', markersize=5)

            def init():
                line.set_data([], [])
                markers.set_data([], [])
                return line, markers

            def update(i):
                # Always show at least two points so a line segment is visible.
                k = max(2, int(len(x_full) * (i / last_frame)))
                line.set_data(x_full[:k], y_full[:k])
                markers.set_data(x_full[:k], y_full[:k])
                return line, markers

            init_func, update_func = init, update
            use_blit = True

        elif ctype == "bar":
            bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
            ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
            ax.set_title(spec.title)
            plt.xticks(rotation=45, ha="right")

            def init():
                return bars

            def update(i):
                for b, h in zip(bars, plot_data.values):
                    b.set_height(h * (i / last_frame))
                return bars

            init_func, update_func = init, update

        elif ctype == "scatter":
            x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
            slope, intercept, _, _, _ = stats.linregress(x_full, y_full)
            reg_line_x = np.array([x_full.min(), x_full.max()])
            reg_line_y = slope * reg_line_x + intercept
            scat = ax.scatter([], [], alpha=0.7, color='#F18F01')
            line, = ax.plot([], [], 'r--', lw=2)
            ax.set_xlim(x_full.min(), x_full.max())
            ax.set_ylim(y_full.min(), y_full.max())
            ax.set_title(spec.title)
            ax.grid(alpha=.3)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)

            # First ~70% of the frames reveal points, the remainder draws the
            # regression line; both phases are clamped to at least one frame.
            point_frames = max(1, int(frames * 0.7))
            line_total_frames = max(1, frames - point_frames)

            def init():
                scat.set_offsets(np.empty((0, 2)))
                line.set_data([], [])
                return []

            def update(i):
                if i <= point_frames:
                    k = max(1, int(len(x_full) * (i / point_frames)))
                    scat.set_offsets(plot_data.iloc[:k].values)
                else:
                    line_frame = i - point_frames
                    current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
                    line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
                return []

            init_func, update_func = init, update

        elif ctype == "pie":
            wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
            ax.set_title(spec.title)
            ax.axis('equal')

            def init():
                for w in wedges:
                    w.set_alpha(0)
                return []

            def update(i):
                for w in wedges:
                    w.set_alpha(i / last_frame)
                return []

            init_func, update_func = init, update

        elif ctype == "hist":
            _, _, patches = ax.hist(plot_data, bins=20, alpha=0)
            ax.set_title(spec.title)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel("Frequency")

            def init():
                for p in patches:
                    p.set_alpha(0)
                return []

            def update(i):
                for p in patches:
                    p.set_alpha((i / last_frame) * 0.7)
                return []

            init_func, update_func = init, update

        elif ctype == "heatmap":
            sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax, alpha=0)
            ax.set_title(spec.title)

            def init():
                ax.collections[0].set_alpha(0)
                return []

            def update(i):
                ax.collections[0].set_alpha(i / last_frame)
                return []

            init_func, update_func = init, update

        else:
            ax.text(0.5, 0.5, f"'{ctype}' animation not implemented", ha='center', va='center')

            def init():
                return []

            def update(i):
                return []

            init_func, update_func = init, update

        anim = FuncAnimation(fig, update_func, init_func=init_func, frames=frames, blit=use_blit, interval=1000 / fps)
        anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
        return str(out)
    finally:
        # Release the figure whether or not encoding succeeded.
        plt.close(fig)
|
|
|
|
|
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path, context: Dict) -> str:
    """Generate a chart spec for ``desc`` with the LLM and animate it to ``out``.

    Returns the path produced by ``animate_chart``. Any failure is logged and
    then re-raised so the caller can apply its own fallback.
    """
    try:
        spec_llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            google_api_key=API_KEY,
            temperature=0.1,
        )
        spec = ChartGenerator(spec_llm, df).generate_chart_spec(desc, context)
        return animate_chart(spec, df, dur, out)
    except Exception as e:
        logging.error(f"Chart animation failed for '{desc}': {e}. Raising exception to trigger fallback.")
        raise e
|
|
|
|
|
def concat_media(file_paths: List[str], output_path: Path):
    """Join media files into ``output_path`` with ffmpeg's concat demuxer.

    Files that do not exist or are 100 bytes or smaller are ignored. A single
    remaining file is simply copied. Otherwise a temporary list file (the
    output path with a ``.txt`` suffix) is written and passed to ffmpeg with
    stream copy; the list file is always removed afterwards.

    Args:
        file_paths: Candidate input files, concatenated in the given order.
        output_path: Destination file.

    Raises:
        ValueError: If no input file passes the validity filter.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    valid_paths = [p for p in file_paths if Path(p).exists() and Path(p).stat().st_size > 100]
    if not valid_paths:
        raise ValueError("No valid media files to concatenate.")
    if len(valid_paths) == 1:
        shutil.copy2(valid_paths[0], str(output_path))
        return

    list_file = output_path.with_suffix(".txt")
    try:
        with open(list_file, 'w', encoding="utf-8") as f:
            for path in valid_paths:
                # Inside a single-quoted concat entry a literal quote must be
                # written as '\'' or ffmpeg mis-parses the path.
                escaped = str(Path(path).resolve()).replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(list_file), "-c", "copy", str(output_path)]
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    finally:
        list_file.unlink(missing_ok=True)
|
|
|
|
|
def sanitize_for_firebase_key(text: str) -> str:
    """Replace each of the characters ``. $ # [ ] /`` in ``text`` with an underscore."""
    replacements = str.maketrans({ch: "_" for ch in ".$#[]/"})
    return text.translate(replacements)
|
|
|
|
|
def analyze_data_intelligence(df: pd.DataFrame) -> Dict:
    """Summarise which kinds of insight the frame's structure supports.

    Args:
        df: Dataset to inspect. Column labels need not be strings.

    Returns:
        A dict with:
            insight_opportunities: ordered list of human-readable hints.
            is_timeseries: True if any column name contains "date" or "time".
            has_correlations: True if there is more than one numeric column.
            has_segments: True if there is at least one non-numeric and one
                numeric column.
    """
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()

    # str() keeps integer or other non-string labels from breaking .lower().
    lowered_names = [str(col).lower() for col in df.columns]
    is_timeseries = any('date' in name or 'time' in name for name in lowered_names)
    has_correlations = len(numeric_cols) > 1
    has_segments = len(categorical_cols) > 0 and len(numeric_cols) > 0

    opportunities = []
    if is_timeseries:
        opportunities.append("temporal trends")
    if has_correlations:
        opportunities.append("correlations between metrics")
    if has_segments:
        opportunities.append("segmentation by category")
    if df.isnull().sum().sum() > 0:
        opportunities.append("impact of missing data")

    return {
        "insight_opportunities": opportunities,
        "is_timeseries": is_timeseries,
        "has_correlations": has_correlations,
        "has_segments": has_segments,
    }
|
|
|
|
|
def generate_visualization_strategy(intelligence: Dict) -> str:
    """Compose chart-type guidance text from the flags in ``intelligence``.

    Reads the keys ``is_timeseries``, ``has_correlations`` and
    ``has_segments`` and appends one sentence for each flag that is truthy.
    """
    flag_hints = (
        ("is_timeseries", "Use 'line' or 'area' charts to explore temporal trends. "),
        ("has_correlations", "Use 'scatter' or 'heatmap' charts to reveal correlations. "),
        ("has_segments", "Use 'bar' or 'pie' charts to compare segments. "),
    )
    sentences = ["Vary your visualizations to keep the report engaging. "]
    sentences.extend(hint for flag, hint in flag_hints if intelligence[flag])
    return "".join(sentences)
|
|
|
|
|
def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
    """Create a compact, JSON-safe summary of the dataframe for the AI.

    Args:
        df: Dataset to summarise.
        user_ctx: Free-text context supplied by the user; stored verbatim.

    Returns:
        A dict with ``user_context``, ``dataset_shape``, ``schema`` (numeric
        and categorical column lists) and ``data_previews``. Previews cover
        at most the first five categorical columns (distinct count plus up to
        five sample values) and the first five numeric columns (mean, min,
        max). NaN and infinite values are replaced by ``None``.
    """
    def _json_default(value):
        # Convert numpy scalars (e.g. the int64 min/max of an integer column) to
        # native Python numbers so they stay numeric instead of becoming strings.
        if isinstance(value, np.generic):
            native = value.item()
            if not isinstance(native, np.generic):
                return native
        return str(value)

    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()

    context = {
        "user_context": user_ctx,
        "dataset_shape": {"rows": df.shape[0], "columns": df.shape[1]},
        "schema": {"numeric_columns": numeric_cols, "categorical_columns": categorical_cols},
        "data_previews": {},
    }

    for col in categorical_cols[:5]:
        unique_vals = df[col].unique()
        context["data_previews"][col] = {
            "count": len(unique_vals),
            "values": unique_vals[:5].tolist(),
        }

    for col in numeric_cols[:5]:
        context["data_previews"][col] = {
            "mean": df[col].mean(),
            "min": df[col].min(),
            "max": df[col].max(),
        }

    # Round-trip through JSON to force plain Python types, then drop NaN/Infinity.
    return _sanitize_for_json(json.loads(json.dumps(context, default=_json_default)))
|
|
|
|
|
def _strip_markdown_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (optionally tagged json) from an LLM reply."""
    text = text.strip()
    if not text.startswith("```"):
        return text
    text = text[3:]
    if text[:4].lower() == "json":
        text = text[4:]
    text = text.rstrip()
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
    """Produce a Markdown report draft with uploaded chart images for a dataset.

    Steps:
        1. Load the dataset and build the LLM data context: the full JSON
           dump when its estimated size is under MAX_CONTEXT_TOKENS, else a
           summary.
        2. Ask the LLM for a palette of chart descriptions, then for a
           narrative that embeds each one as a ``<generate_chart: "...">``
           tag. If either pass fails, a single-pass prompt is used instead.
        3. For up to MAX_CHARTS tagged charts, render a static image, upload
           it to ``bucket`` and record its public URL.

    Args:
        buf: Path or file-like object holding the uploaded dataset.
        name: Original file name (decides the parser).
        ctx: User-supplied context text.
        uid: User id, used in the storage path.
        project_id: Project id, used in the storage path and logs.
        bucket: Storage bucket object exposing ``blob(name)``.

    Returns:
        Dict with ``raw_md`` (Markdown, chart tags rewritten to key-safe
        descriptions), ``chartUrls`` (safe description -> public URL) and
        ``data_context`` (the summary handed to chart generation).
    """
    logging.info(f"Generating guided storyteller report draft for project {project_id}")
    df = load_dataframe_safely(buf, name)
    llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.3)

    data_context_str, context_for_charts = "", {}
    try:
        df_json = df.to_json(orient='records')
        estimated_tokens = len(df_json) / 4  # rough heuristic: four characters per token
        if estimated_tokens < MAX_CONTEXT_TOKENS:
            logging.info("Using full JSON context for report generation.")
            data_context_str = f"Here is the full dataset in JSON format:\n{df_json}"
            context_for_charts = get_augmented_context(df, ctx)
        else:
            raise ValueError("Dataset too large for full context.")
    except Exception as e:
        logging.warning(f"Falling back to augmented summary context for report generation: {e}")
        augmented_context = get_augmented_context(df, ctx)
        data_context_str = f"The full dataset is too large to display. Here is a detailed summary:\n{json.dumps(augmented_context, indent=2)}"
        context_for_charts = augmented_context

    md = ""
    try:
        strategist_prompt = f"""
        You are a data visualization expert. Your task is to create a diverse palette of unique and impactful charts for a data storyteller.
        Based on the provided data context, identify the 4-5 most distinct and insightful stories that can be visualized.

        **Data Context:**
        {data_context_str}

        **Your Goal:**
        Your primary goal is to select a **diverse palette of chart types**. A high-quality response will use a mix of different charts from the available list to create a visually engaging and comprehensive report. **Do not use the same chart type more than twice.**

        **Strategic Hints:**
        - Consider a `histogram` to show the distribution of a key variable (like age or bmi).
        - Consider a `pie chart` for a clear part-to-whole relationship (e.g., smoker vs. non-smoker proportions).
        - Consider a `heatmap` if the dataset has multiple numeric columns and you believe the overall pattern of their correlations is a key insight in itself.

        **Output Format:**
        Return ONLY a valid JSON array of strings. Each string must be a unique chart description tag.

        Example:
        ["bar | Average Charges by Smoker Status", "scatter | Charges vs. BMI", "hist | Distribution of Beneficiary Ages", "pie | Regional Proportions"]
        """
        logging.info("Executing Visualization Strategist Pass...")
        # Accept replies wrapped in ``` or ```json fences, with or without a closing fence.
        strategist_response = _strip_markdown_fence(llm.invoke(strategist_prompt).content)
        chart_palette = json.loads(strategist_response)
        if not isinstance(chart_palette, list):
            # Anything but an array cannot serve as a palette; use the single-pass fallback.
            raise ValueError("Strategist response is not a JSON array.")
        logging.info(f"Strategist Pass successful. Palette has {len(chart_palette)} unique charts.")

        storyteller_prompt = f"""
        You are an elite data storyteller and business intelligence expert. Your mission is to write a comprehensive, flowing narrative that analyzes the entire dataset provided. Your goal is to create a captivating story that **drives action**.

        **Data Context:**
        {data_context_str}

        **Narrative Construction Guidelines:**
        1.  **Use Compelling Headers:** Structure your report with multiple sections using Markdown headings (`##` or `###`). Do not write one long block of text. Create curiosity with your headers (e.g., 'The Smoking Premium: A Costly Habit', 'Geographic Hotspots: Where Charges Are Highest').
        2.  **Weave a Story:** Don't just describe the charts one by one. Connect the findings together. For example, how does 'age' relate to 'smoker status' and how do they both impact 'charges'?
        3.  **Drive to Action:** Conclude your report with a dedicated section titled `## Actionable Recommendations`. Based on your analysis, provide specific, data-driven suggestions that a business leader could implement.

        **Your Toolbox (Most Important):**
        To support your story with visuals, you have been provided with a pre-approved 'palette' of unique charts. As you write your narrative, you **must** integrate each of these chart tags, one time, at the most logical point in the story.
        - You **must** use every chart tag from the provided palette exactly once.
        - Do **not** repeat chart tags.
        - Do **not** invent new chart tags.
        - Insert the tags in the format `<generate_chart: "the_description">`.

        **Chart Palette:**
        {json.dumps(chart_palette, indent=2)}

        Now, write the complete, comprehensive Markdown report.
        """
        logging.info("Executing Master Storyteller Pass...")
        md = llm.invoke(storyteller_prompt).content.strip()
        logging.info("Master Storyteller Pass successful.")

    except Exception as e:
        logging.error(f"Guided Storyteller system failed: {e}. Reverting to single-pass fallback.")
        fallback_prompt = f"""
        You are an elite data storyteller and business intelligence expert. Your mission is to uncover the compelling, hidden narrative in this dataset and present it as a captivating story in Markdown format that drives action.
        **Data Context:** {data_context_str}
        **Your Grounding Rules (Most Important):**
        1.  **Strict Accuracy:** Your entire analysis and narrative **must strictly** use the column names provided in the 'Data Context' section.
        2.  **Chart Support:** Wherever a key finding is made, you **must** support it with a chart tag: `<generate_chart: "chart_type | a specific, compelling description">`.
        3.  **Chart Accuracy:** The column names used in your chart descriptions **must** also be an exact match from the provided data context.
        Now, begin your report. Let the data's story unfold naturally.
        """
        md = llm.invoke(fallback_prompt).content.strip()

    chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
    chart_urls = {}
    chart_generator = ChartGenerator(llm, df)

    for desc in chart_descs:
        # Descriptions double as dictionary keys downstream, so rewrite the tags
        # in the Markdown to the key-safe form (quoted and unquoted variants).
        safe_desc = sanitize_for_firebase_key(desc)
        md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
        md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')

        # Reserve a temp file name; the handle is closed before the chart is saved to it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            img_path = Path(temp_file.name)
        try:
            chart_spec = chart_generator.generate_chart_spec(desc, context_for_charts)
            if execute_chart_spec(chart_spec, df, img_path):
                blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
                blob = bucket.blob(blob_name)
                blob.upload_from_filename(str(img_path))
                chart_urls[safe_desc] = blob.public_url
        finally:
            if os.path.exists(img_path):
                os.unlink(img_path)

    return {"raw_md": md, "chartUrls": chart_urls, "data_context": context_for_charts}
|
|
|
|
|
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
    """Render one chart for ``description`` and upload it to ``bucket``.

    Args:
        df: Source data.
        description: Chart request text passed to the spec generator.
        uid: User id, used in the storage path.
        project_id: Project id, used in the storage path and logs.
        bucket: Storage bucket object exposing ``blob(name)``.

    Returns:
        The uploaded blob's public URL, or ``None`` when rendering failed.
        The temporary image file is removed in either case.
    """
    logging.info(f"Generating single chart '{description}' for project {project_id}")

    spec_llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
    generator = ChartGenerator(spec_llm, df)
    summary = get_augmented_context(df, "")

    # Reserve a temp file name for the rendered PNG.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        img_path = Path(handle.name)

    try:
        spec = generator.generate_chart_spec(description, summary)
        rendered = execute_chart_spec(spec, df, img_path)
        if rendered:
            target = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
            blob = bucket.blob(target)
            blob.upload_from_filename(str(img_path))
            logging.info(f"Uploaded single chart to {blob.public_url}")
            return blob.public_url
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)

    return None
|
|
|
|
|
|
|
|
def generate_video_from_project(df: pd.DataFrame, raw_md: str, data_context: Dict, uid: str, project_id: str, voice_model: str, bucket):
    """Build a narrated multi-scene video from a report and upload it.

    An LLM turns ``raw_md`` into a script whose scenes are separated by
    ``[SCENE_BREAK]``. Every scene that has narration gets:

    * an audio track from ``deepgram_tts``; when no audio comes back,
      ``generate_silence_mp3`` is asked for a 5-second track instead;
    * a visual: a stock clip when the scene has a stock-video tag, otherwise an
      animated chart for the scene's first chart tag.

    If the visual cannot be produced, a stock-clip search built from the
    narration is tried. If that fails too, the scene is filled with a copy of
    the conclusion clip (or of a generic clip fetched for that purpose) when one
    is available; otherwise the scene's visual is dropped.

    Scene videos and audio are concatenated, muxed with ffmpeg, optionally
    given a full-screen ``sozob.png`` end-card over the last two seconds, and
    uploaded to ``sozo_projects/{uid}/{project_id}/video.mp4``.

    Args:
        df: Data used for domain detection and chart rendering.
        raw_md: Report text the script is derived from.
        data_context: Passed through to ``safe_chart``.
        uid: User identifier; used only to build the blob path.
        project_id: Project identifier; used for logging and the blob path.
        voice_model: Voice model name forwarded to ``deepgram_tts``.
        bucket: Storage bucket object; must provide ``blob(name)`` returning an
            object with ``upload_from_filename`` and ``public_url``.

    Returns:
        The uploaded blob's ``public_url``, or ``None`` when no scene yielded
        both audio and video to assemble.

    Raises:
        subprocess.CalledProcessError: If the ffmpeg mux step fails. Errors from
            the helpers called outside the per-scene ``try`` also propagate.
            Temporary files are removed in either case.
    """
    logging.info(f"Generating video for project {project_id} with voice {voice_model}")
    llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)

    domain = detect_dataset_domain(df)

    story_prompt = f"""
    Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
    1. The first scene MUST be an "Introduction". It must contain narration and a stock video tag like: <generate_stock_video: "search query">.
    2. The last scene MUST be a "Conclusion". It must also contain narration and a stock video tag.
    3. The middle scenes should each contain narration and one chart tag from the report.
    4. Separate each scene with '[SCENE_BREAK]'.
    Report: {raw_md}
    Only output the script, no extra text.
    """
    script = llm.invoke(story_prompt).content.strip()
    scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]

    video_parts, audio_parts, temps = [], [], []
    total_audio_duration = 0.0
    conclusion_video_path = None

    def _reserve_temp_path(suffix):
        # Create the file and close the handle straight away: only the path is
        # needed, and the file is written later by name.
        handle = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        handle.close()
        reserved = Path(handle.name)
        temps.append(reserved)
        return reserved

    try:
        for i, sc in enumerate(scenes):
            narrative = clean_narration(sc)
            if not narrative:
                logging.warning(f"Scene {i+1} has no narration, skipping.")
                continue

            mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
            mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
            # Track both paths before anything is written so a failure part-way
            # through the scene still gets cleaned up.
            temps.extend([mp3, mp4])

            audio_bytes = deepgram_tts(narrative, voice_model)
            audio_dur = 5.0
            if audio_bytes:
                mp3.write_bytes(audio_bytes)
                audio_dur = audio_duration(str(mp3))
                if audio_dur <= 0.1:
                    audio_dur = 5.0
            else:
                generate_silence_mp3(audio_dur, mp3)

            audio_parts.append(str(mp3))
            total_audio_duration += audio_dur

            # Each visual is requested 1.5 s longer than its narration.
            video_dur = audio_dur + 1.5

            try:
                chart_descs = extract_chart_tags(sc)
                pexels_descs = extract_pexels_tags(sc)
                is_conclusion_scene = any(k in narrative.lower() for k in ["conclusion", "summary", "in closing", "final thoughts"])

                if pexels_descs:
                    logging.info(f"Scene {i+1}: Processing Pexels scene.")
                    base_keywords = extract_keywords_for_query(narrative, llm)
                    final_query = f"{base_keywords} {domain}"
                    video_path = search_and_download_pexels_video(final_query, video_dur, mp4)
                    if not video_path:
                        raise ValueError("Pexels search returned no results for chained query.")
                    video_parts.append(video_path)
                    if is_conclusion_scene:
                        conclusion_video_path = video_path
                elif chart_descs:
                    logging.info(f"Scene {i+1}: Primary attempt with animated chart.")
                    safe_chart(chart_descs[0], df, video_dur, mp4, data_context)
                    video_parts.append(str(mp4))
                else:
                    raise ValueError("No visual tag found in scene script.")
            except Exception as e:
                logging.warning(f"Scene {i+1}: Primary visual failed ({e}). Triggering Fallback Tier 1.")
                try:
                    fallback_keywords = extract_keywords_for_query(narrative, llm)
                    final_fallback_query = f"{fallback_keywords} {domain}"
                    logging.info(f"Fallback Tier 1: Searching Pexels with query: '{final_fallback_query}'")
                    video_path = search_and_download_pexels_video(final_fallback_query, video_dur, mp4)
                    if not video_path:
                        raise ValueError("Fallback Pexels search returned no results.")
                    video_parts.append(video_path)
                    logging.info(f"Scene {i+1}: Successfully recovered with a relevant Pexels video.")
                except Exception as fallback_e:
                    logging.error(f"Scene {i+1}: Fallback Tier 1 also failed ({fallback_e}). Marking for final failsafe.")
                    # Placeholder resolved after the loop, once it is known
                    # whether a conclusion clip exists.
                    video_parts.append("FALLBACK_NEEDED")

        if not conclusion_video_path:
            logging.warning("No conclusion video was generated; creating a generic one for fallbacks.")
            fallback_mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
            # Tracked unconditionally: a failed download may still leave a file.
            temps.append(fallback_mp4)
            conclusion_video_path = search_and_download_pexels_video(f"data visualization abstract {domain}", 5.0, fallback_mp4)

        final_video_parts = []
        for part in video_parts:
            if part != "FALLBACK_NEEDED":
                final_video_parts.append(part)
            elif conclusion_video_path:
                # Every fallback scene gets its own copy of the clip.
                fallback_copy_path = Path(tempfile.gettempdir()) / f"fallback_{uuid.uuid4().hex}.mp4"
                temps.append(fallback_copy_path)
                shutil.copy(conclusion_video_path, fallback_copy_path)
                final_video_parts.append(str(fallback_copy_path))
                logging.info("Applying unique copy of conclusion video as fallback for a failed scene.")
            else:
                logging.error("Cannot apply fallback; no conclusion video available. A scene will be missing.")

        if not final_video_parts or not audio_parts:
            # Nothing to concatenate or mux.
            logging.error("No renderable scenes were produced; skipping video assembly.")
            return None

        silent_vid_path = _reserve_temp_path(".mp4")
        audio_mix_path = _reserve_temp_path(".mp3")
        final_vid_path = _reserve_temp_path(".mp4")
        branded_vid_path = _reserve_temp_path(".mp4")

        concat_media(final_video_parts, silent_vid_path)
        concat_media(audio_parts, audio_mix_path)

        cmd = [
            "ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
            "-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
            "-map", "0:v:0", "-map", "1:a:0",
            # Trim the result to the total narration length.
            "-t", f"{total_audio_duration:.3f}",
            str(final_vid_path)
        ]
        subprocess.run(cmd, check=True, capture_output=True)

        upload_path = final_vid_path
        logo_path = Path("sozob.png")

        if logo_path.exists():
            logging.info("Logo 'sozob.png' found. Adding full-screen end-card.")
            duration_for_filter = total_audio_duration
            # Scale the logo to the frame size and show it from two seconds
            # before the end of the video.
            filter_complex = f"[1:v]scale={WIDTH}:{HEIGHT}[logo];[0:v][logo]overlay=0:0:enable='gte(t,{duration_for_filter - 2})'"
            logo_cmd = [
                "ffmpeg", "-y",
                "-i", str(final_vid_path),
                "-i", str(logo_path),
                "-filter_complex", filter_complex,
                "-map", "0:a",
                "-c:a", "copy",
                "-c:v", "libx264", "-pix_fmt", "yuv420p",
                str(branded_vid_path)
            ]
            try:
                subprocess.run(logo_cmd, check=True, capture_output=True)
                upload_path = branded_vid_path
            except subprocess.CalledProcessError as e:
                # Branding is best-effort; the unbranded video is still uploaded.
                logging.error(f"Failed to add logo end-card. Uploading unbranded video. Error: {e.stderr.decode()}")
        else:
            logging.warning("Logo 'sozob.png' not found in root directory. Skipping end-card.")

        blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(str(upload_path))
        logging.info(f"Uploaded video to {blob.public_url}")
        return blob.public_url
    finally:
        # Runs on success and on failure so scene and intermediate files never
        # accumulate in the temp directory.
        for p in temps:
            if os.path.exists(p):
                os.unlink(p)
|
|
|
|
|
|
|
|
|
|
|
def generate_image_with_gemini(prompt: str) -> "Image.Image | None":
    """Generate an illustration for ``prompt`` with a Gemini model.

    The prompt is prefixed with a fixed "professional, 3d digital art style"
    instruction and sent with both text and image response modalities
    requested.

    Args:
        prompt: Description of the illustration to generate.

    Returns:
        The first image found in the first candidate, converted to RGB, or
        ``None`` when the response carries no image or any step fails. Failures
        are logged, never raised.
    """
    logging.info(f"Generating Gemini image with prompt: '{prompt}'")
    try:
        client = genai.Client(api_key=API_KEY)
        full_prompt = f"A professional, 3d digital art style illustration for a business presentation: {prompt}"

        response = client.models.generate_content(
            model="gemini-2.0-flash-exp",
            contents=full_prompt,
            config=genai_types.GenerateContentConfig(
                response_modalities=["Text", "Image"]
            ),
        )

        # Generated images come back as inline binary data on a content part;
        # text parts have no inline data and are skipped.
        image_bytes = None
        for part in response.candidates[0].content.parts:
            inline_data = getattr(part, "inline_data", None)
            payload = getattr(inline_data, "data", None)
            if payload:
                image_bytes = payload
                break

        if image_bytes is None:
            logging.error("Gemini response did not contain an image.")
            return None

        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        # Callers treat None as "no image", so every failure is reported the
        # same way instead of propagating.
        logging.error(f"Gemini image generation failed: {e}")
        return None
|
|
|
|
|
def generate_slides_from_report(raw_md: str, chart_urls: dict, uid: str, project_id: str, bucket, llm):
    """Turn a report into a slide deck plan and resolve each slide's visual.

    ``llm`` is asked for a JSON array describing 10 slides. For every planned
    slide the visual is resolved to a URL:

    * ``visual_type == "existing_chart"``: the sanitized ``visual_ref`` is
      looked up in ``chart_urls``;
    * ``visual_type == "new_image"``: an image is generated from ``visual_ref``
      and uploaded to
      ``sozo_projects/{uid}/{project_id}/slides/slide_<random hex>.png``.

    Args:
        raw_md: Full report text embedded in the planner prompt.
        chart_urls: Mapping from sanitized chart description to chart URL.
        uid: User identifier; used only to build the blob path.
        project_id: Project identifier; used for logging and the blob path.
        bucket: Storage bucket object; must provide ``blob(name)`` returning an
            object with ``upload_from_filename`` and ``public_url``.
        llm: Object with ``invoke(prompt)`` whose result has a ``content`` str.

    Returns:
        A list of dicts with ``slide_number``, ``title``, ``content`` and
        ``image_url`` (``""`` when no visual could be resolved). Slides whose
        processing raises, and plan entries that are not JSON objects, are
        omitted. ``None`` when the plan cannot be obtained or is not a JSON
        array.
    """
    logging.info(f"Generating slides for project {project_id}")

    planner_prompt = f"""
    You are an expert presentation designer. Your task is to convert the following data analysis report into a concise and visually engaging 10-slide deck.

    **Full Report Content:**
    ---
    {raw_md}
    ---

    **Instructions:**
    1. Read the entire report to understand the core narrative and key findings.
    2. Create a plan for exactly 10 slides.
    3. For each slide, define a `title` and short `content` (2-3 bullet points or a brief paragraph).
    4. For the visual on each slide, you must decide between two types:
    - If a report section is supported by an existing chart (indicated by a `<generate_chart:...>` tag), you **must** use it. Set `visual_type: "existing_chart"` and `visual_ref: "the exact chart description from the tag"`.
    - For key points without a chart (like introductions, conclusions, or text-only insights), you **must** request a new image. Set `visual_type: "new_image"` and `visual_ref: "a concise, descriptive prompt for an AI to generate a 3D digital art style illustration"`.
    5. You must request exactly 3-4 new images to balance the presentation.

    **Output Format:**
    Return ONLY a valid JSON array of 10 slide objects. Do not include any other text or markdown formatting.

    Example:
    [
    {{ "slide_number": 1, "title": "Introduction", "content": "...", "visual_type": "new_image", "visual_ref": "A 3D illustration of a rising stock chart" }},
    {{ "slide_number": 2, "title": "Sales by Region", "content": "...", "visual_type": "existing_chart", "visual_ref": "bar | Sales by Region" }},
    ...
    ]
    """

    try:
        plan_response = llm.invoke(planner_prompt).content.strip()
        if plan_response.startswith("```"):
            # Models often wrap JSON in a Markdown fence despite instructions.
            # Drop the opening fence (with any language tag) and, only when it
            # is actually there, the closing fence.
            plan_response = re.sub(r"^```[A-Za-z0-9_-]*\s*", "", plan_response)
            plan_response = re.sub(r"\s*```$", "", plan_response)
        slide_plan = json.loads(plan_response)
    except Exception as e:
        logging.error(f"Failed to generate or parse slide plan: {e}")
        return None

    if not isinstance(slide_plan, list):
        logging.error("Failed to generate or parse slide plan: expected a JSON array of slides.")
        return None

    final_slides = []
    for position, slide in enumerate(slide_plan, start=1):
        if not isinstance(slide, dict):
            # Guard before any slide.get(...) call, including the one used for
            # error reporting below.
            logging.warning(f"Skipping slide plan entry {position}: expected an object, got {type(slide).__name__}.")
            continue

        try:
            image_url = None
            visual_type = slide.get("visual_type")
            visual_ref = slide.get("visual_ref")

            if visual_type == "existing_chart":
                sanitized_ref = sanitize_for_firebase_key(visual_ref)
                image_url = chart_urls.get(sanitized_ref)
                if not image_url:
                    logging.warning(f"Could not find existing chart for ref: '{visual_ref}' (sanitized: '{sanitized_ref}')")

            elif visual_type == "new_image":
                img = generate_image_with_gemini(visual_ref)
                if img:
                    # Reserve a path and close the handle before saving to it.
                    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
                        img_path = Path(temp_file.name)
                    try:
                        img.save(img_path, format="PNG")
                        blob_name = f"sozo_projects/{uid}/{project_id}/slides/slide_{uuid.uuid4().hex}.png"
                        blob = bucket.blob(blob_name)
                        blob.upload_from_filename(str(img_path))
                        image_url = blob.public_url
                        logging.info(f"Uploaded new slide image to {image_url}")
                    finally:
                        # Remove the temp PNG even when save or upload fails.
                        if os.path.exists(img_path):
                            os.unlink(img_path)

            if not image_url:
                logging.warning(f"Visual generation failed for slide {slide.get('slide_number')}. Skipping visual for this slide.")

            final_slides.append({
                "slide_number": slide.get("slide_number"),
                "title": slide.get("title"),
                "content": slide.get("content"),
                "image_url": image_url or ""
            })
        except Exception as slide_e:
            # One bad slide must not abort the rest of the deck.
            logging.error(f"Failed to process slide {slide.get('slide_number')}: {slide_e}")
            continue

    return final_slides