# sozo_gen.py
import os
import re
import json
import logging
import shutil
import uuid
import io
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter
import seaborn as sns
from scipy import stats
from PIL import Image, ImageDraw, ImageFont
import cv2
import inspect
import tempfile
import subprocess
from typing import Dict, List, Tuple, Any
from langchain_google_genai import ChatGoogleGenerativeAI
from google import genai
from google.genai import types as genai_types
import requests

# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
FPS, WIDTH, HEIGHT = 24, 1280, 720
MAX_CHARTS, VIDEO_SCENES = 5, 5
MAX_CONTEXT_TOKENS = 750000

# --- API Initialization ---
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
    raise ValueError("GOOGLE_API_KEY environment variable not set.")
PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")

# --- Helper Functions ---
def load_dataframe_safely(buf, name: str):
    ext = Path(name).suffix.lower()
    df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
    df.columns = df.columns.astype(str).str.strip()
    df = df.dropna(how="all")
    if df.empty or len(df.columns) == 0:
        raise ValueError("No usable data found")
    return df
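
# Example (illustrative):
#   df = load_dataframe_safely(io.BytesIO(b"region,sales\nNorth,10\nSouth,20\n"), "data.csv")
#   -> DataFrame with stripped column names and fully-empty rows dropped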

def deepgram_tts(txt: str, voice_model: str):
    DG_KEY = os.getenv("DEEPGRAM_API_KEY")
    if not DG_KEY or not txt:
        return None
    txt = re.sub(r"[^\w\s.,!?;:-]", "", txt)  # strip characters the TTS endpoint may reject
    try:
        r = requests.post(
            "https://api.deepgram.com/v1/speak",
            params={"model": voice_model},
            headers={"Authorization": f"Token {DG_KEY}", "Content-Type": "application/json"},
            json={"text": txt},
            timeout=30,
        )
        r.raise_for_status()
        return r.content
    except Exception as e:
        logging.error(f"Deepgram TTS failed: {e}")
        return None

def generate_silence_mp3(duration: float, out: Path):
    subprocess.run(
        ["ffmpeg", "-y", "-f", "lavfi", "-i", "anullsrc=r=44100:cl=mono", "-t", f"{duration:.3f}", "-q:a", "9", str(out)],
        check=True, capture_output=True,
    )

def audio_duration(path: str) -> float:
    try:
        res = subprocess.run(
            ["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=nw=1:nk=1", path],
            text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True,
        )
        return float(res.stdout.strip())
    except Exception:
        return 5.0  # fall back to a 5-second default if ffprobe fails

TAG_RE = re.compile(r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I)
TAG_RE_PEXELS = re.compile(r'[<[]\s*generate_?stock_?video\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I)
extract_chart_tags = lambda t: list(dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")))
extract_pexels_tags = lambda t: list(dict.fromkeys(m.group("d").strip() for m in TAG_RE_PEXELS.finditer(t or "")))
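
# Example (illustrative): both tag spellings below are matched, and
# extract_chart_tags deduplicates while preserving order:
#   extract_chart_tags('<generate_chart: "bar | Sales by Region"> [generate_chart: bar | Sales by Region]')
#   -> ['bar | Sales by Region']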

re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)

def clean_narration(txt: str) -> str:
    txt = TAG_RE.sub("", txt)
    txt = TAG_RE_PEXELS.sub("", txt)
    txt = re_scene.sub("", txt)
    phrases_to_remove = [r"chart tag", r"chart_tag", r"narration", r"stock video tag"]
    for phrase in phrases_to_remove:
        txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
    txt = re.sub(r"\s*\([^)]*\)", "", txt)  # drop parenthetical asides
    txt = re.sub(r"[\*#_]", "", txt)        # drop markdown emphasis characters
    return re.sub(r"\s{2,}", " ", txt).strip()
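
# Example (illustrative):
#   clean_narration('Scene 1: Sales grew 20%. <generate_chart: "bar | Sales by Region">')
#   -> 'Sales grew 20%.'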

def placeholder_img() -> Image.Image:
    return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))

def generate_image_with_gemini(prompt: str) -> Image.Image:
    """Generates an image using the specified Gemini model and client configuration."""
    logging.info(f"Generating Gemini image with prompt: '{prompt}'")
    try:
        # Use the genai.Client as per the correct implementation
        client = genai.Client(api_key=API_KEY)
        full_prompt = f"A professional, 3d digital art style illustration for a business presentation: {prompt}"
        response = client.models.generate_content(
            model="gemini-2.0-flash-exp",
            contents=full_prompt,
            config=genai_types.GenerateContentConfig(
                response_modalities=["Text", "Image"]
            ),
        )
        # Image parts come back as inline_data blobs (raw bytes plus a mime type)
        img_part = next((part for part in response.candidates[0].content.parts if part.inline_data is not None), None)
        if img_part:
            return Image.open(io.BytesIO(img_part.inline_data.data)).convert("RGB")
        logging.error("Gemini response did not contain an image.")
        return None
    except Exception as e:
        logging.error(f"Gemini image generation failed: {e}")
        return None
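
# Minimal usage sketch (assumes GOOGLE_API_KEY is set and the account has access
# to the image-capable model named above):
#   img = generate_image_with_gemini("a team reviewing a sales dashboard")
#   if img:
#       img.save("slide_visual.png")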

def generate_slides_from_report(raw_md: str, chart_urls: dict, uid: str, project_id: str, bucket, llm):
    """
    Uses an AI planner to convert a report into a 10-slide presentation deck.
    """
    logging.info(f"Generating slides for project {project_id}")
    planner_prompt = f"""
    You are an expert presentation designer. Your task is to convert the following data analysis report into a concise and visually engaging 10-slide deck.
    **Full Report Content:**
    ---
    {raw_md}
    ---
    **Instructions:**
    1. Read the entire report to understand the core narrative and key findings.
    2. Create a plan for exactly 10 slides.
    3. For each slide, define a `title` and short `content` (2-3 bullet points or a brief paragraph).
    4. For the visual on each slide, you must decide between two types:
       - If a report section is supported by an existing chart (indicated by a `<generate_chart:...>` tag), you **must** use it. Set `visual_type: "existing_chart"` and `visual_ref: "the exact chart description from the tag"`.
       - For key points without a chart (like introductions, conclusions, or text-only insights), you **must** request a new image. Set `visual_type: "new_image"` and `visual_ref: "a concise, descriptive prompt for an AI to generate a 3D digital art style illustration"`.
    5. Request 3-4 new images in total to balance the presentation.
    **Output Format:**
    Return ONLY a valid JSON array of 10 slide objects. Do not include any other text or markdown formatting.
    Example:
    [
      {{ "slide_number": 1, "title": "Introduction", "content": "...", "visual_type": "new_image", "visual_ref": "A 3D illustration of a rising stock chart" }},
      {{ "slide_number": 2, "title": "Sales by Region", "content": "...", "visual_type": "existing_chart", "visual_ref": "bar | Sales by Region" }},
      ...
    ]
    """
    try:
        plan_response = llm.invoke(planner_prompt).content.strip()
        # Strip an optional ```json ... ``` fence around the model output
        plan_response = re.sub(r"^```(?:json)?\s*|\s*```$", "", plan_response)
        slide_plan = json.loads(plan_response)
    except Exception as e:
        logging.error(f"Failed to generate or parse slide plan: {e}")
        return None
    final_slides = []
    for slide in slide_plan:
        try:
            image_url = None
            visual_type = slide.get("visual_type")
            visual_ref = slide.get("visual_ref")
            if visual_type == "existing_chart":
                sanitized_ref = sanitize_for_firebase_key(visual_ref)
                image_url = chart_urls.get(sanitized_ref)
                if not image_url:
                    logging.warning(f"Could not find existing chart for ref: '{visual_ref}' (sanitized: '{sanitized_ref}')")
            elif visual_type == "new_image":
                img = generate_image_with_gemini(visual_ref)
                if img:
                    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
                        img_path = Path(temp_file.name)
                    img.save(img_path, format="PNG")
                    blob_name = f"sozo_projects/{uid}/{project_id}/slides/slide_{uuid.uuid4().hex}.png"
                    blob = bucket.blob(blob_name)
                    blob.upload_from_filename(str(img_path))
                    image_url = blob.public_url
                    logging.info(f"Uploaded new slide image to {image_url}")
                    os.unlink(img_path)
            if not image_url:
                logging.warning(f"Visual generation failed for slide {slide.get('slide_number')}. Skipping visual for this slide.")
            final_slides.append({
                "slide_number": slide.get("slide_number"),
                "title": slide.get("title"),
                "content": slide.get("content"),
                "image_url": image_url or ""
            })
        except Exception as slide_e:
            logging.error(f"Failed to process slide {slide.get('slide_number')}: {slide_e}")
            continue
    return final_slides
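
# Usage sketch (illustrative): `bucket` is a Cloud Storage bucket object exposing
# .blob()/.upload_from_filename() as used throughout this file, and `llm` any object
# whose .invoke(prompt).content returns text (e.g. the ChatGoogleGenerativeAI instances below):
#   slides = generate_slides_from_report(report_md, chart_urls, uid, project_id, bucket, llm)
#   if slides:
#       print(slides[0]["title"], slides[0]["image_url"])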

# NEW: Keyword extraction for better Pexels searches
def extract_keywords_for_query(text: str, llm) -> str:
    prompt = f"""
    Extract 2-4 key nouns and verbs from the following text to use as a search query for a stock video.
    Focus on concrete actions and subjects.
    Example: 'Our analysis shows a significant growth in quarterly revenue and strong partnerships.' -> 'data analysis growth chart business'
    Output only the search query keywords, separated by spaces.
    Text: "{text}"
    """
    try:
        return llm.invoke(prompt).content.strip()
    except Exception as e:
        logging.error(f"Keyword extraction failed: {e}. Using original text.")
        return text  # Fallback to the original text if LLM fails

# UPDATED: Pexels search now loops short videos
def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
    if not PEXELS_API_KEY:
        logging.warning("PEXELS_API_KEY not set. Cannot fetch stock video.")
        return None
    try:
        headers = {"Authorization": PEXELS_API_KEY}
        params = {"query": query, "per_page": 10, "orientation": "landscape"}
        response = requests.get("https://api.pexels.com/videos/search", headers=headers, params=params, timeout=20)
        response.raise_for_status()
        videos = response.json().get('videos', [])
        if not videos:
            logging.warning(f"No Pexels videos found for query: '{query}'")
            return None
        video_to_download = None
        for video in videos:
            for f in video.get('video_files', []):
                # Guard against a missing width field (Pexels may omit it)
                if f.get('quality') == 'hd' and (f.get('width') or 0) >= 1280:
                    video_to_download = f['link']
                    break
            if video_to_download:
                break
        if not video_to_download:
            logging.warning(f"No suitable HD video file found for query: '{query}'")
            return None
        with requests.get(video_to_download, stream=True, timeout=60) as r:
            r.raise_for_status()
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_dl_file:
                for chunk in r.iter_content(chunk_size=8192):
                    temp_dl_file.write(chunk)
                temp_dl_path = Path(temp_dl_file.name)
        # UPDATED: Added -stream_loop -1 to loop short videos
        cmd = [
            "ffmpeg", "-y",
            "-stream_loop", "-1",  # Loop the input video
            "-i", str(temp_dl_path),
            "-vf", f"scale={WIDTH}:{HEIGHT}:force_original_aspect_ratio=decrease,pad={WIDTH}:{HEIGHT}:(ow-iw)/2:(oh-ih)/2,setsar=1",
            "-t", f"{duration:.3f}",  # Cut the looped video to the exact duration
            "-c:v", "libx264", "-pix_fmt", "yuv420p", "-an",
            str(out_path)
        ]
        subprocess.run(cmd, check=True, capture_output=True)
        temp_dl_path.unlink()
        return str(out_path)
    except Exception as e:
        logging.error(f"Pexels video processing failed for query '{query}': {e}")
        if 'temp_dl_path' in locals() and temp_dl_path.exists():
            temp_dl_path.unlink()
        return None

class ChartSpecification:
    def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None,
                 agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
        self.chart_type = chart_type
        self.title = title
        self.x_col = x_col
        self.y_col = y_col
        self.size_col = size_col
        self.agg_method = agg_method or "sum"
        self.filter_condition = filter_condition
        self.top_n = top_n
        self.color_scheme = color_scheme
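
# Example (illustrative; the column names are hypothetical and must exist in the
# DataFrame being plotted):
#   spec = ChartSpecification("bar", "Sales by Region", x_col="Region", y_col="Sales", top_n=10)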

class ChartGenerator:
    def __init__(self, llm, df: pd.DataFrame):
        self.llm = llm
        self.df = df

    def generate_chart_spec(self, description: str, context: Dict) -> ChartSpecification:
        spec_prompt = f"""
        You are a data visualization expert. Based on the dataset context and chart description, generate a precise chart specification.
        **Dataset Context:** {json.dumps(context, indent=2)}
        **Chart Request:** {description}
        **Return a JSON specification with these exact fields:**
        {{
          "chart_type": "bar|pie|line|scatter|hist|heatmap|area|bubble",
          "title": "Professional chart title",
          "x_col": "column_name_for_x_axis_or_null_for_heatmap",
          "y_col": "column_name_for_y_axis_or_null",
          "size_col": "column_name_for_bubble_size_or_null",
          "agg_method": "sum|mean|count|max|min|null",
          "top_n": "number_for_top_n_filtering_or_null"
        }}
        Return only the JSON specification, no additional text.
        """
        try:
            response = self.llm.invoke(spec_prompt).content.strip()
            if response.startswith("```json"):
                response = response[7:-3]
            elif response.startswith("```"):
                response = response[3:-3]
            spec_dict = json.loads(response)
            valid_keys = [p.name for p in inspect.signature(ChartSpecification).parameters.values() if p.name not in ['reasoning', 'filter_condition', 'color_scheme']]
            filtered_dict = {k: v for k, v in spec_dict.items() if k in valid_keys}
            return ChartSpecification(**filtered_dict)
        except Exception as e:
            logging.error(f"Spec generation failed: {e}. Using fallback.")
            numeric_cols = context.get('schema', {}).get('numeric_columns', list(self.df.select_dtypes(include=['number']).columns))
            categorical_cols = context.get('schema', {}).get('categorical_columns', list(self.df.select_dtypes(exclude=['number']).columns))
            ctype = "bar"
            for t in ["pie", "line", "scatter", "hist", "heatmap", "area", "bubble"]:
                if t in description.lower():
                    ctype = t
            x = categorical_cols[0] if categorical_cols else self.df.columns[0]
            y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
            return ChartSpecification(ctype, description, x, y)

def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
    try:
        plot_data = prepare_plot_data(spec, df)
        fig, ax = plt.subplots(figsize=(12, 8)); plt.style.use('default')
        if spec.chart_type == "bar": ax.bar(plot_data.index.astype(str), plot_data.values, color='#2E86AB', alpha=0.8); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.tick_params(axis='x', rotation=45)
        elif spec.chart_type == "pie": ax.pie(plot_data.values, labels=plot_data.index, autopct='%1.1f%%', startangle=90); ax.axis('equal')
        elif spec.chart_type == "line": ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
        elif spec.chart_type == "scatter": ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
        elif spec.chart_type == "hist": ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black'); ax.set_xlabel(spec.x_col); ax.set_ylabel('Frequency'); ax.grid(True, alpha=0.3)
        elif spec.chart_type == "area": ax.fill_between(plot_data.index, plot_data.values, color="#4E79A7", alpha=0.4); ax.plot(plot_data.index, plot_data.values, color="#4E79A7", alpha=0.8); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
        elif spec.chart_type == "heatmap": sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax); plt.xticks(rotation=45, ha="right"); plt.yticks(rotation=0)
        elif spec.chart_type == "bubble":
            sizes = (plot_data[spec.size_col] - plot_data[spec.size_col].min() + 1) / (plot_data[spec.size_col].max() - plot_data[spec.size_col].min() + 1) * 2000 + 50
            ax.scatter(plot_data[spec.x_col], plot_data[spec.y_col], s=sizes, alpha=0.6, color='#59A14F'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
        ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20); plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white'); plt.close()
        return True
    except Exception as e:
        logging.error(f"Static chart generation failed for '{spec.title}': {e}")
        return False

def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
    if spec.chart_type not in ["heatmap"]:
        if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns):
            raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
    if spec.chart_type in ["bar", "pie"]:
        if not spec.y_col:
            return df[spec.x_col].value_counts().nlargest(spec.top_n or 10)
        grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
        return grouped.nlargest(spec.top_n or 10)
    elif spec.chart_type in ["line", "area"]:
        return df.set_index(spec.x_col)[spec.y_col].sort_index()
    elif spec.chart_type == "scatter":
        return df[[spec.x_col, spec.y_col]].dropna()
    elif spec.chart_type == "bubble":
        if not spec.size_col or spec.size_col not in df.columns:
            raise ValueError("Bubble chart requires a valid size_col.")
        return df[[spec.x_col, spec.y_col, spec.size_col]].dropna()
    elif spec.chart_type == "hist":
        return df[spec.x_col].dropna()
    elif spec.chart_type == "heatmap":
        numeric_cols = df.select_dtypes(include=np.number).columns
        if not numeric_cols.any():
            raise ValueError("Heatmap requires numeric columns.")
        return df[numeric_cols].corr()
    return df[spec.x_col]
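
# Example (illustrative, hypothetical columns):
#   df = pd.DataFrame({"Region": ["N", "S", "N"], "Sales": [10, 20, 5]})
#   prepare_plot_data(ChartSpecification("bar", "Sales by Region", "Region", "Sales"), df)
#   -> Series indexed by Region with summed Sales (S=20, N=15), largest first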

# UPDATED: animate_chart now uses blit=False for accurate timing
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
    plot_data = prepare_plot_data(spec, df)
    frames = max(10, int(dur * fps))
    fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
    plt.tight_layout(pad=3.0)
    ctype = spec.chart_type
    # Animation logic remains the same, only the final call to FuncAnimation changes
    if ctype == "pie":
        wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
        ax.set_title(spec.title); ax.axis('equal')
        def init():
            [w.set_alpha(0) for w in wedges]; return wedges
        def update(i):
            [w.set_alpha(i / (frames - 1)) for w in wedges]; return wedges
    elif ctype == "bar":
        bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
        ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
        ax.set_title(spec.title); plt.xticks(rotation=45, ha="right")
        def init():
            return bars
        def update(i):
            for b, h in zip(bars, plot_data.values):
                b.set_height(h * (i / (frames - 1)))
            return bars
    elif ctype == "scatter":
        x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
        slope, intercept, _, _, _ = stats.linregress(x_full, y_full)
        reg_line_x = np.array([x_full.min(), x_full.max()])
        reg_line_y = slope * reg_line_x + intercept
        scat = ax.scatter([], [], alpha=0.7, color='#F18F01')
        line, = ax.plot([], [], 'r--', lw=2)
        ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
        ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
        def init():
            scat.set_offsets(np.empty((0, 2))); line.set_data([], [])
            return [scat, line]
        def update(i):
            point_frames = int(frames * 0.7)
            if i <= point_frames:
                k = max(1, int(len(x_full) * (i / point_frames)))
                scat.set_offsets(plot_data.iloc[:k].values)
            else:
                line_frame = i - point_frames; line_total_frames = frames - point_frames
                current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
                line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
            return [scat, line]
    else:  # line, area, hist, etc.
        # This is a simplified representation; the full logic from previous steps is assumed here
        # For brevity, we'll just show the line chart example
        line, = ax.plot([], [], lw=2, color='#A23B72')
        plot_data = plot_data.sort_index()
        x_full, y_full = plot_data.index, plot_data.values
        ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
        ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
        def init():
            line.set_data([], []); return [line]
        def update(i):
            k = max(2, int(len(x_full) * (i / (frames - 1))))
            line.set_data(x_full[:k], y_full[:k]); return [line]
    # The key change: blit=False
    anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=False, interval=1000 / fps)
    anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
    plt.close(fig)
    return str(out)

def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
    total_frames = max(1, int(dur * fps))
    for i in range(total_frames):
        alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
        frame = cv2.addWeighted(img, alpha, np.zeros_like(img), 1 - alpha, 0)
        video_writer.write(frame)
    video_writer.release()
    return str(out)

def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path, context: Dict) -> str:
    try:
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
        chart_generator = ChartGenerator(llm, df)
        chart_spec = chart_generator.generate_chart_spec(desc, context)
        return animate_chart(chart_spec, df, dur, out)
    except Exception as e:
        logging.error(f"Chart animation failed for '{desc}': {e}. Raising exception to trigger fallback.")
        raise  # Re-raise (preserving the traceback) so the video generator's fallback logic runs

def concat_media(file_paths: List[str], output_path: Path):
    valid_paths = [p for p in file_paths if Path(p).exists() and Path(p).stat().st_size > 100]
    if not valid_paths:
        raise ValueError("No valid media files to concatenate.")
    if len(valid_paths) == 1:
        shutil.copy2(valid_paths[0], str(output_path))
        return
    list_file = output_path.with_suffix(".txt")
    with open(list_file, 'w') as f:
        for path in valid_paths:
            f.write(f"file '{Path(path).resolve()}'\n")
    cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(list_file), "-c", "copy", str(output_path)]
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    finally:
        list_file.unlink(missing_ok=True)
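
# Note (ffmpeg behavior): the concat demuxer with "-c copy" only splices cleanly when
# every input shares the same codec, resolution, and timebase; if the scene clips come
# from different writers, swapping "-c copy" for a re-encode is the safer option.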

def sanitize_for_firebase_key(text: str) -> str:
    forbidden_chars = ['.', '$', '#', '[', ']', '/']
    for char in forbidden_chars:
        text = text.replace(char, '_')
    return text
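
# Example (illustrative):
#   sanitize_for_firebase_key('bar | Sales by Region ($M)') -> 'bar | Sales by Region (_M)'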

def analyze_data_intelligence(df: pd.DataFrame) -> Dict:
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
    is_timeseries = any('date' in col.lower() or 'time' in col.lower() for col in df.columns)
    opportunities = []
    if is_timeseries: opportunities.append("temporal trends")
    if len(numeric_cols) > 1: opportunities.append("correlations between metrics")
    if len(categorical_cols) > 0 and len(numeric_cols) > 0: opportunities.append("segmentation by category")
    if df.isnull().sum().sum() > 0: opportunities.append("impact of missing data")
    return {
        "insight_opportunities": opportunities,
        "is_timeseries": is_timeseries,
        "has_correlations": len(numeric_cols) > 1,
        "has_segments": len(categorical_cols) > 0 and len(numeric_cols) > 0
    }

def generate_visualization_strategy(intelligence: Dict) -> str:
    strategy = "Vary your visualizations to keep the report engaging. "
    if intelligence["is_timeseries"]: strategy += "Use 'line' or 'area' charts to explore temporal trends. "
    if intelligence["has_correlations"]: strategy += "Use 'scatter' or 'heatmap' charts to reveal correlations. "
    if intelligence["has_segments"]: strategy += "Use 'bar' or 'pie' charts to compare segments. "
    return strategy

def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
    context = {
        "user_context": user_ctx,
        "dataset_shape": {"rows": df.shape[0], "columns": df.shape[1]},
        "schema": {"numeric_columns": numeric_cols, "categorical_columns": categorical_cols},
        "data_previews": {}
    }
    for col in categorical_cols[:5]:
        unique_vals = df[col].unique()
        context["data_previews"][col] = {"count": len(unique_vals), "values": unique_vals[:5].tolist()}
    for col in numeric_cols[:5]:
        context["data_previews"][col] = {"mean": df[col].mean(), "min": df[col].min(), "max": df[col].max()}
    # Round-trip through JSON (with default=str) so every value is prompt-serializable
    return json.loads(json.dumps(context, default=str))
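
# Example (illustrative): for columns "Region" (text) and "Sales" (numeric), the
# returned dict resembles:
#   {"user_context": "", "dataset_shape": {"rows": 3, "columns": 2},
#    "schema": {"numeric_columns": ["Sales"], "categorical_columns": ["Region"]},
#    "data_previews": {"Region": {"count": 2, "values": ["N", "S"]},
#                      "Sales": {"mean": ..., "min": ..., "max": ...}}}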

def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
    logging.info(f"Generating persona-driven report draft for project {project_id}")
    df = load_dataframe_safely(buf, name)
    llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
    data_context_str, context_for_charts = "", {}
    try:
        df_json = df.to_json(orient='records')
        estimated_tokens = len(df_json) / 4  # rough heuristic: ~4 characters per token
        if estimated_tokens < MAX_CONTEXT_TOKENS:
            logging.info("Using full JSON context.")
            data_context_str = f"Here is the full dataset in JSON format:\n{df_json}"
            context_for_charts = get_augmented_context(df, ctx)
        else:
            raise ValueError("Dataset too large.")
    except Exception as e:
        logging.warning(f"Falling back to augmented summary context: {e}")
        augmented_context = get_augmented_context(df, ctx)
        data_context_str = f"The full dataset is too large to display. Here is a detailed summary:\n{json.dumps(augmented_context, indent=2)}"
        context_for_charts = augmented_context
    intelligence = analyze_data_intelligence(df)
    viz_strategy = generate_visualization_strategy(intelligence)
    report_prompt = f"""
    You are an elite data storyteller and business intelligence expert. Your mission is to uncover the compelling, hidden narrative in this dataset and present it as a captivating story in Markdown format that drives action.
    **Data Context:**
    {data_context_str}
    **Intelligence Analysis:**
    - The most interesting parts of this story may lie in the following areas: {', '.join(intelligence['insight_opportunities'])}.
    - Weave these threads into your core narrative.
    **Visualization Strategy:**
    - {viz_strategy}
    - Available Chart Types: `bar, pie, line, scatter, hist, heatmap, area, bubble`.
    **Your Grounding Rules (Most Important):**
    1. **Strict Accuracy:** Your entire analysis and narrative **must strictly** use the column names provided in the 'Data Context' section. Do not invent, modify, or assume any column names that are not on this list.
    2. **Chart Support:** Wherever a key finding is made, you **must** support it with a chart tag: `<generate_chart: "chart_type | a specific, compelling description">`.
    3. **Chart Accuracy:** The column names used in your chart descriptions **must** also be an exact match from the provided data context.
    Now, begin your report. Let the data's story unfold naturally.
    """
    md = llm.invoke(report_prompt).content
    chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
    chart_urls = {}
    chart_generator = ChartGenerator(llm, df)
    for desc in chart_descs:
        safe_desc = sanitize_for_firebase_key(desc)
        md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
        md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            img_path = Path(temp_file.name)
        try:
            chart_spec = chart_generator.generate_chart_spec(desc, context_for_charts)
            if execute_chart_spec(chart_spec, df, img_path):
                blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
                blob = bucket.blob(blob_name)
                blob.upload_from_filename(str(img_path))
                chart_urls[safe_desc] = blob.public_url
        finally:
            if os.path.exists(img_path):
                os.unlink(img_path)
    return {"raw_md": md, "chartUrls": chart_urls, "data_context": context_for_charts}

def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
    logging.info(f"Generating single chart '{description}' for project {project_id}")
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
    chart_generator = ChartGenerator(llm, df)
    context = get_augmented_context(df, "")
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        img_path = Path(temp_file.name)
    try:
        chart_spec = chart_generator.generate_chart_spec(description, context)
        if execute_chart_spec(chart_spec, df, img_path):
            blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
            blob = bucket.blob(blob_name)
            blob.upload_from_filename(str(img_path))
            logging.info(f"Uploaded single chart to {blob.public_url}")
            return blob.public_url
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)
    return None

def generate_video_from_project(df: pd.DataFrame, raw_md: str, data_context: Dict, uid: str, project_id: str, voice_model: str, bucket):
    logging.info(f"Generating video for project {project_id} with voice {voice_model}")
    llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
    story_prompt = f"""
    Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
    1. The first scene MUST be an "Introduction". It must contain narration and a stock video tag like: <generate_stock_video: "search query">.
    2. The last scene MUST be a "Conclusion". It must also contain narration and a stock video tag.
    3. The middle scenes should each contain narration and one chart tag from the report.
    4. Separate each scene with '[SCENE_BREAK]'.
    Report: {raw_md}
    Only output the script, no extra text.
    """
    script = llm.invoke(story_prompt).content
    scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
    video_parts, audio_parts, temps = [], [], []
    total_audio_duration = 0.0
    for i, sc in enumerate(scenes):
        mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
        narrative = clean_narration(sc)
        if not narrative:
            logging.warning(f"Scene {i+1} has no narration, skipping.")
            continue
        audio_bytes = deepgram_tts(narrative, voice_model)
        mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
        audio_dur = 5.0
        if audio_bytes:
            mp3.write_bytes(audio_bytes)
            audio_dur = audio_duration(str(mp3))
            if audio_dur <= 0.1:
                audio_dur = 5.0
        else:
            # No TTS audio: fill the scene with the 5-second silence default
            generate_silence_mp3(audio_dur, mp3)
        audio_parts.append(str(mp3)); temps.append(mp3)
        total_audio_duration += audio_dur
        video_dur = audio_dur + 0.5
        try:
            # --- Primary Visual Generation ---
            chart_descs = extract_chart_tags(sc)
            pexels_descs = extract_pexels_tags(sc)
            if pexels_descs:
                logging.info(f"Scene {i+1}: Primary attempt with Pexels.")
                query = extract_keywords_for_query(narrative, llm)
                video_path = search_and_download_pexels_video(query, video_dur, mp4)
                if not video_path:
                    raise ValueError("Pexels search returned no results.")
                video_parts.append(video_path)
            elif chart_descs:
                logging.info(f"Scene {i+1}: Primary attempt with animated chart.")
                safe_chart(chart_descs[0], df, video_dur, mp4, data_context)
                video_parts.append(str(mp4))
            else:
                raise ValueError("No visual tag found in scene.")
        except Exception as e:
            # --- Fallback Visual Generation ---
            logging.warning(f"Scene {i+1}: Primary visual failed ({e}). Triggering fallback.")
            try:
                fallback_query = "abstract technology background"
                video_path = search_and_download_pexels_video(fallback_query, video_dur, mp4)
                if not video_path:
                    raise ValueError("Fallback Pexels search failed.")
                video_parts.append(video_path)
            except Exception as fallback_e:
                # --- Final Failsafe ---
                logging.error(f"Scene {i+1}: Fallback visual also failed ({fallback_e}). Using placeholder.")
                png_path = str(mp4).replace(".mp4", ".png")
                placeholder_img().save(png_path)
                animate_image_fade(cv2.imread(png_path), video_dur, mp4)
                video_parts.append(str(mp4))
        temps.append(mp4)
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
         tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_aud, \
         tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as final_vid:
        silent_vid_path = Path(temp_vid.name)
        audio_mix_path = Path(temp_aud.name)
        final_vid_path = Path(final_vid.name)
    concat_media(video_parts, silent_vid_path)
    concat_media(audio_parts, audio_mix_path)
    cmd = [
        "ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
        "-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
        "-map", "0:v:0", "-map", "1:a:0",
        "-t", f"{total_audio_duration:.3f}",
        str(final_vid_path)
    ]
    subprocess.run(cmd, check=True, capture_output=True)
    blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
    blob = bucket.blob(blob_name)
    blob.upload_from_filename(str(final_vid_path))
    logging.info(f"Uploaded video to {blob.public_url}")
    for p in temps + [silent_vid_path, audio_mix_path, final_vid_path]:
        if os.path.exists(p):
            os.unlink(p)
    return blob.public_url
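
# --- Local smoke test (a minimal sketch; not part of the pipeline) ---
# Renders one static chart to the temp directory without API keys or cloud buckets,
# exercising ChartSpecification -> prepare_plot_data -> execute_chart_spec. The
# DataFrame and column names below are made up for illustration.
if __name__ == "__main__":
    demo_df = pd.DataFrame({"Region": ["North", "South", "East", "West"], "Sales": [120, 95, 150, 80]})
    demo_spec = ChartSpecification("bar", "Sales by Region (demo)", x_col="Region", y_col="Sales")
    demo_out = Path(tempfile.gettempdir()) / "sozo_demo_chart.png"
    if execute_chart_spec(demo_spec, demo_df, demo_out):
        logging.info(f"Demo chart written to {demo_out}")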