Update sozo_gen.py
Browse files- sozo_gen.py +191 -286
sozo_gen.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
# sozo_gen.py
|
| 2 |
-
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
import json
|
|
@@ -9,110 +7,128 @@ import io
|
|
| 9 |
from pathlib import Path
|
| 10 |
import pandas as pd
|
| 11 |
import numpy as np
|
| 12 |
-
import matplotlib
|
| 13 |
-
matplotlib.use("Agg")
|
| 14 |
-
import matplotlib.pyplot as plt
|
| 15 |
-
from matplotlib.animation import FuncAnimation, FFMpegWriter
|
| 16 |
-
from PIL import Image
|
| 17 |
-
import cv2
|
| 18 |
-
import inspect
|
| 19 |
-
import tempfile
|
| 20 |
import subprocess
|
| 21 |
-
from typing import Dict, List
|
|
|
|
| 22 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 23 |
from google import genai
|
| 24 |
import requests
|
| 25 |
|
| 26 |
# --- Configuration ---
|
| 27 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
|
| 28 |
-
FPS, WIDTH, HEIGHT = 24, 1280, 720
|
| 29 |
MAX_CHARTS, VIDEO_SCENES = 5, 5
|
| 30 |
|
| 31 |
# --- Gemini API Initialization ---
|
| 32 |
-
# CORRECTED: Use the correct environment variable name and remove the deprecated configure call.
|
| 33 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 34 |
if not API_KEY:
|
| 35 |
raise ValueError("GOOGLE_API_KEY environment variable not set.")
|
| 36 |
-
# REMOVED: genai.configure(api_key=API_KEY) - This is deprecated. The library now uses the environment variable automatically.
|
| 37 |
|
| 38 |
# --- Helper Functions ---
|
| 39 |
-
def load_dataframe_safely(buf, name: str):
|
|
|
|
| 40 |
ext = Path(name).suffix.lower()
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
def deepgram_tts(txt: str, voice_model: str):
|
|
|
|
| 48 |
DG_KEY = os.getenv("DEEPGRAM_API_KEY")
|
| 49 |
-
if not DG_KEY or not txt:
|
|
|
|
|
|
|
| 50 |
txt = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
|
| 51 |
try:
|
| 52 |
-
r = requests.post(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
r.raise_for_status()
|
| 54 |
return r.content
|
| 55 |
except Exception as e:
|
| 56 |
logging.error(f"Deepgram TTS failed: {e}")
|
| 57 |
return None
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def audio_duration(path: str) -> float:
    """Return the duration of a media file in seconds, as reported by ffprobe.

    Args:
        path: Location of the audio/video file to inspect.

    Returns:
        The duration parsed from ffprobe's output, or ``5.0`` when ffprobe is
        missing, exits with an error, or prints something that is not a number.
    """
    cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=nw=1:nk=1",
        path,
    ]
    try:
        res = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
        return float(res.stdout.strip())
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # Best-effort probe: keep the fallback, but record why it was needed.
        logging.warning(f"Could not determine duration of '{path}': {e}. Using 5.0s.")
        return 5.0
|
| 67 |
-
|
| 68 |
-
# Matches chart placeholders such as <generate_chart: "..."> or [generate_chart "..."];
# the chart description is captured in the named group "d".
TAG_RE = re.compile(r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I)


def extract_chart_tags(t):
    """Return the unique chart descriptions tagged in *t*, in order of first appearance.

    ``None`` or an empty string yields an empty list.
    """
    return list(dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")))
|
| 70 |
|
| 71 |
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
|
| 72 |
def clean_narration(txt: str) -> str:
|
| 73 |
-
|
|
|
|
|
|
|
| 74 |
phrases_to_remove = [r"as you can see in the chart", r"this chart shows", r"the chart illustrates", r"in this visual", r"this graph displays"]
|
| 75 |
-
for phrase in phrases_to_remove:
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
return re.sub(r"\s{2,}", " ", txt).strip()
|
| 78 |
|
| 79 |
-
def placeholder_img() -> Image.Image:
    """Build a solid light-grey RGB image sized WIDTH x HEIGHT."""
    canvas_size = (WIDTH, HEIGHT)
    light_grey = (230, 230, 230)
    return Image.new("RGB", canvas_size, light_grey)
|
| 80 |
-
|
| 81 |
-
def generate_image_from_prompt(prompt: str) -> Image.Image:
    """Ask Gemini for an illustration of *prompt*, falling back to a grey placeholder.

    The primary image model is tried first, then the fallback model. Any
    failure is logged and treated as "no image", so this function does not
    raise for generation errors.

    Args:
        prompt: Description of the desired illustration.

    Returns:
        The generated picture converted to RGB, or ``placeholder_img()`` when
        neither model produced inline image data.
    """
    model_names = (
        "gemini-2.0-flash-exp-image-generation",
        "gemini-2.0-flash-preview-image-generation",
    )
    full_prompt = "A clean business-presentation illustration: " + prompt

    def fetch(model_name):
        """Return the first inline image produced by *model_name*, or None."""
        try:
            # NOTE(review): `genai` comes from `from google import genai`; confirm that
            # this package exposes `GenerativeModel`, otherwise every call ends up here
            # in the except branch and the placeholder is always returned.
            model = genai.GenerativeModel(model_name)
            res = model.generate_content(full_prompt)
            for part in res.candidates[0].content.parts:
                if getattr(part, "inline_data", None):
                    return Image.open(io.BytesIO(part.inline_data.data)).convert("RGB")
        except Exception as e:
            # Deliberately broad: image generation is optional, but say why it failed.
            logging.warning(f"Image generation with '{model_name}' failed: {e}")
        return None

    for model_name in model_names:
        img = fetch(model_name)
        if img:
            return img
    return placeholder_img()
|
| 98 |
-
|
| 99 |
# --- Chart Generation System ---
|
| 100 |
class ChartSpecification:
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
| 107 |
enhanced_ctx.update({"numeric_columns": numeric_cols, "categorical_columns": categorical_cols})
|
| 108 |
return enhanced_ctx
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
class ChartGenerator:
|
| 111 |
def __init__(self, llm, df: pd.DataFrame):
|
| 112 |
-
self.llm = llm
|
| 113 |
-
self.
|
|
|
|
| 114 |
|
| 115 |
-
def generate_chart_spec(self, description: str) ->
|
|
|
|
| 116 |
spec_prompt = f"""
|
| 117 |
You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
|
| 118 |
**Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
|
|
@@ -129,246 +145,135 @@ class ChartGenerator:
|
|
| 129 |
if response.startswith("```json"): response = response[7:-3]
|
| 130 |
elif response.startswith("```"): response = response[3:-3]
|
| 131 |
spec_dict = json.loads(response)
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
except Exception as e:
|
| 136 |
-
logging.error(f"
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
numeric_cols = self.enhanced_ctx['numeric_columns']; categorical_cols = self.enhanced_ctx['categorical_columns']
|
| 141 |
-
ctype = "bar"
|
| 142 |
-
for t in ["pie", "line", "scatter", "hist"]:
|
| 143 |
-
if t in description.lower(): ctype = t
|
| 144 |
-
x = categorical_cols[0] if categorical_cols else self.df.columns[0]
|
| 145 |
-
y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
|
| 146 |
-
return ChartSpecification(ctype, description, x, y)
|
| 147 |
-
|
| 148 |
-
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
    """Render *spec* as a static chart image at *output_path*.

    Args:
        spec: Chart description (type, title, columns).
        df: Source data.
        output_path: Destination image file.

    Returns:
        ``True`` when the image was written, ``False`` when anything failed
        (the error is logged).
    """
    fig = None
    try:
        plot_data = prepare_plot_data(spec, df)
        fig, ax = plt.subplots(figsize=(12, 8))
        plt.style.use('default')

        if spec.chart_type == "bar":
            ax.bar(plot_data.index.astype(str), plot_data.values, color='#2E86AB', alpha=0.8)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.tick_params(axis='x', rotation=45)
        elif spec.chart_type == "pie":
            ax.pie(plot_data.values, labels=plot_data.index, autopct='%1.1f%%', startangle=90)
            ax.axis('equal')
        elif spec.chart_type == "line":
            ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)
        elif spec.chart_type == "scatter":
            ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)
        elif spec.chart_type == "hist":
            ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel('Frequency')
            ax.grid(True, alpha=0.3)

        ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20)
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
        return True
    except Exception as e:
        logging.error(f"Static chart generation failed for '{spec.title}': {e}")
        return False
    finally:
        # Close this specific figure on every path so failed renders do not
        # accumulate open figures.
        if fig is not None:
            plt.close(fig)
|
| 161 |
-
|
| 162 |
-
def prepare_plot_data(spec: "ChartSpecification", df: pd.DataFrame):
    """Select and aggregate the data needed to draw *spec*.

    Args:
        spec: Chart description; ``chart_type``, ``x_col``, ``y_col``,
            ``agg_method`` and ``top_n`` are read.
        df: Source data.

    Returns:
        * bar / pie: a Series indexed by ``x_col`` holding either value counts
          (no ``y_col``) or the aggregated ``y_col``, limited to the
          ``top_n`` (default 10) largest entries.
        * line: the ``y_col`` Series indexed by ``x_col``, sorted by index.
        * scatter: a two-column DataFrame (``x_col``, ``y_col``) without NaN rows.
        * hist: the ``x_col`` Series without NaN values.
        * anything else: the raw ``x_col`` Series.

    Raises:
        ValueError: If a referenced column does not exist, or a line/scatter
            chart is requested without a ``y_col``.
    """
    if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns):
        raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")

    if spec.chart_type in ["bar", "pie"]:
        limit = spec.top_n or 10
        if not spec.y_col:
            return df[spec.x_col].value_counts().nlargest(limit)
        grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
        return grouped.nlargest(limit)

    if spec.chart_type in ("line", "scatter") and not spec.y_col:
        # Without this guard the lookups below fail with an opaque KeyError.
        raise ValueError(f"Chart type '{spec.chart_type}' requires a y column.")

    if spec.chart_type == "line":
        return df.set_index(spec.x_col)[spec.y_col].sort_index()
    if spec.chart_type == "scatter":
        return df[[spec.x_col, spec.y_col]].dropna()
    if spec.chart_type == "hist":
        return df[spec.x_col].dropna()
    return df[spec.x_col]
|
| 172 |
-
|
| 173 |
-
# --- Animation & Video Generation ---
|
| 174 |
-
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
    """Render *spec* as an animated chart video written to *out*.

    Pie and histogram charts fade in, bars grow from zero, scatter points and
    line segments are revealed progressively. Any chart type other than
    pie/bar/scatter/hist is drawn as a line chart.

    Args:
        spec: Chart description (type, title, columns).
        df: Source data.
        dur: Target duration in seconds; at least 10 frames are always produced.
        out: Destination video path.
        fps: Frames per second of the output.

    Returns:
        ``str(out)``.

    Raises:
        Whatever data preparation, plotting or the FFmpeg writer raises; the
        figure is closed in every case.
    """
    plot_data = prepare_plot_data(spec, df)
    frames = max(10, int(dur * fps))
    last = frames - 1  # frames >= 10, so this is never zero
    fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
    try:
        plt.tight_layout(pad=3.0)
        ctype = spec.chart_type

        if ctype == "pie":
            wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
            ax.set_title(spec.title)
            ax.axis('equal')

            def init():
                for w in wedges:
                    w.set_alpha(0)
                return wedges

            def update(i):
                for w in wedges:
                    w.set_alpha(i / last)
                return wedges

        elif ctype == "bar":
            bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
            # Fix the y-range up front so the growing bars do not rescale the axis.
            ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
            ax.set_title(spec.title)
            plt.xticks(rotation=45, ha="right")

            def init():
                return bars

            def update(i):
                for b, h in zip(bars, plot_data.values):
                    b.set_height(h * (i / last))
                return bars

        elif ctype == "scatter":
            scat = ax.scatter([], [], alpha=0.7)
            x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
            ax.set_xlim(x_full.min(), x_full.max())
            ax.set_ylim(y_full.min(), y_full.max())
            ax.set_title(spec.title)
            ax.grid(alpha=.3)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)

            def init():
                scat.set_offsets(np.empty((0, 2)))
                return [scat]

            def update(i):
                k = max(1, int(len(x_full) * (i / last)))
                scat.set_offsets(plot_data.iloc[:k].values)
                return [scat]

        elif ctype == "hist":
            _, _, patches = ax.hist(plot_data, bins=20, alpha=0)
            ax.set_title(spec.title)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel("Frequency")

            def init():
                for p in patches:
                    p.set_alpha(0)
                return patches

            def update(i):
                for p in patches:
                    p.set_alpha((i / last) * 0.7)
                return patches

        else:  # line
            line, = ax.plot([], [], lw=2)
            plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
            x_full, y_full = plot_data.index, plot_data.values
            ax.set_xlim(x_full.min(), x_full.max())
            ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
            ax.set_title(spec.title)
            ax.grid(alpha=.3)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)

            def init():
                line.set_data([], [])
                return [line]

            def update(i):
                k = max(2, int(len(x_full) * (i / last)))
                line.set_data(x_full[:k], y_full[:k])
                return [line]

        anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
        anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
    finally:
        # Release the figure even when plotting or encoding fails.
        plt.close(fig)
    return str(out)
|
| 221 |
-
|
| 222 |
-
def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = FPS) -> str:
    """Write a video in which *img* fades in from black.

    Args:
        img: Frame to fade in; it is blended against a zero array of the same
            shape and written to a WIDTH x HEIGHT writer.
        dur: Duration in seconds; at least one frame is written.
        out: Destination video path (``mp4v`` codec).
        fps: Frames per second of the output.

    Returns:
        ``str(out)``.
    """
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
    total_frames = max(1, int(dur * fps))
    black = np.zeros_like(img)  # loop-invariant blend target
    try:
        for i in range(total_frames):
            # A single-frame clip shows the image at full opacity.
            alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
            frame = cv2.addWeighted(img, alpha, black, 1 - alpha, 0)
            video_writer.write(frame)
    finally:
        # Always release the writer, even when blending or writing a frame fails.
        video_writer.release()
    return str(out)
|
| 231 |
-
|
| 232 |
-
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
    """Produce a chart video for *desc*, degrading gracefully on failure.

    Attempts, in order:
      1. an animated chart,
      2. a static chart image faded in,
      3. a generated illustration (or placeholder) faded in.

    Args:
        desc: Natural-language chart description.
        df: Source data.
        dur: Clip duration in seconds.
        out: Destination video path.

    Returns:
        The path of the written video as a string.
    """
    try:
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
        chart_generator = ChartGenerator(llm, df)
        chart_spec = chart_generator.generate_chart_spec(desc)
        return animate_chart(chart_spec, df, dur, out)
    except Exception as e:
        logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")

    # Reserve a file name only; the handle is closed before anything writes to it.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_png_file:
        temp_png = Path(temp_png_file.name)
    try:
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
        chart_generator = ChartGenerator(llm, df)
        chart_spec = chart_generator.generate_chart_spec(desc)
        if execute_chart_spec(chart_spec, df, temp_png):
            img = cv2.imread(str(temp_png))
            # cv2.imread returns None for unreadable files instead of raising.
            if img is not None:
                img_resized = cv2.resize(img, (WIDTH, HEIGHT))
                return animate_image_fade(img_resized, dur, out)
    except Exception as e:
        logging.error(f"Static chart fallback failed for '{desc}': {e}. Falling back to generated image.")
    finally:
        # Remove the temp file on every path, including a failed render.
        temp_png.unlink(missing_ok=True)

    img = generate_image_from_prompt(f"A professional business chart showing {desc}")
    img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
    return animate_image_fade(img_cv, dur, out)
|
| 253 |
-
|
| 254 |
-
def concat_media(file_paths: List[str], output_path: Path):
    """Concatenate media files into *output_path* with ffmpeg's concat demuxer.

    Paths that do not exist or are 100 bytes or smaller are ignored. A single
    remaining file is simply copied; several are stream-copied (``-c copy``)
    via a temporary list file placed next to the output.

    Args:
        file_paths: Candidate input files, in playback order.
        output_path: Destination file.

    Raises:
        ValueError: If no usable input file remains.
        subprocess.CalledProcessError: If ffmpeg exits with an error.
    """
    valid_paths = [p for p in file_paths if Path(p).exists() and Path(p).stat().st_size > 100]
    if not valid_paths:
        raise ValueError("No valid media files to concatenate.")
    if len(valid_paths) == 1:
        import shutil
        shutil.copy2(valid_paths[0], str(output_path))
        return

    list_file = output_path.with_suffix(".txt")
    with open(list_file, 'w', encoding="utf-8") as f:
        for path in valid_paths:
            # Inside a single-quoted concat entry a literal quote must be written as '\''.
            escaped = str(Path(path).resolve()).replace("'", "'\\''")
            f.write(f"file '{escaped}'\n")

    cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(list_file), "-c", "copy", str(output_path)]
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    finally:
        list_file.unlink(missing_ok=True)
|
| 266 |
|
| 267 |
-
# --- Main Business Logic Functions
|
| 268 |
|
| 269 |
-
def generate_report_draft(
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
| 274 |
enhanced_ctx = enhance_data_context(df, ctx_dict)
|
|
|
|
| 275 |
report_prompt = f"""
|
| 276 |
-
You are a senior data analyst
|
| 277 |
**Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
|
| 278 |
**Instructions:**
|
| 279 |
1. **Executive Summary**: Start with a high-level summary of key findings.
|
| 280 |
-
2. **Key Insights**: Provide 3-5 key insights
|
| 281 |
-
3. **Visual Support**:
|
| 282 |
Valid chart types: bar, pie, line, scatter, hist.
|
| 283 |
Generate insights that would be valuable to C-level executives.
|
| 284 |
"""
|
| 285 |
md = llm.invoke(report_prompt).content
|
| 286 |
chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
|
| 287 |
-
|
|
|
|
| 288 |
chart_generator = ChartGenerator(llm, df)
|
| 289 |
for desc in chart_descs:
|
| 290 |
-
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 291 |
-
img_path = Path(temp_file.name)
|
| 292 |
-
try:
|
| 293 |
-
chart_spec = chart_generator.generate_chart_spec(desc)
|
| 294 |
-
if execute_chart_spec(chart_spec, df, img_path):
|
| 295 |
-
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
|
| 296 |
-
blob = bucket.blob(blob_name)
|
| 297 |
-
blob.upload_from_filename(str(img_path))
|
| 298 |
-
chart_urls[desc] = blob.public_url
|
| 299 |
-
logging.info(f"Uploaded chart '{desc}' to {blob.public_url}")
|
| 300 |
-
finally:
|
| 301 |
-
os.unlink(img_path)
|
| 302 |
-
return {"raw_md": md, "chartUrls": chart_urls}
|
| 303 |
-
|
| 304 |
-
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
|
| 305 |
-
logging.info(f"Generating single chart '{description}' for project {project_id}")
|
| 306 |
-
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 307 |
-
chart_generator = ChartGenerator(llm, df)
|
| 308 |
-
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 309 |
-
img_path = Path(temp_file.name)
|
| 310 |
try:
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
script = llm.invoke(story_prompt).content
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
audio_bytes = deepgram_tts(narrative, voice_model)
|
| 332 |
-
mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
|
| 333 |
if audio_bytes:
|
| 334 |
-
|
| 335 |
-
if dur <= 0.1: dur = 5.0
|
| 336 |
else:
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
if descs:
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_aud, \
|
| 349 |
-
tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as final_vid:
|
| 350 |
-
|
| 351 |
-
silent_vid_path = Path(temp_vid.name)
|
| 352 |
-
audio_mix_path = Path(temp_aud.name)
|
| 353 |
-
final_vid_path = Path(final_vid.name)
|
| 354 |
-
|
| 355 |
-
concat_media(video_parts, silent_vid_path)
|
| 356 |
-
concat_media(audio_parts, audio_mix_path)
|
| 357 |
-
|
| 358 |
-
subprocess.run(
|
| 359 |
-
["ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
|
| 360 |
-
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
|
| 361 |
-
"-map", "0:v:0", "-map", "1:a:0", "-shortest", str(final_vid_path)],
|
| 362 |
-
check=True, capture_output=True,
|
| 363 |
-
)
|
| 364 |
-
|
| 365 |
-
blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
|
| 366 |
-
blob = bucket.blob(blob_name)
|
| 367 |
-
blob.upload_from_filename(str(final_vid_path))
|
| 368 |
-
logging.info(f"Uploaded video to {blob.public_url}")
|
| 369 |
|
| 370 |
-
|
| 371 |
-
if os.path.exists(p): os.unlink(p)
|
| 372 |
-
|
| 373 |
-
return blob.public_url
|
| 374 |
-
return None
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import json
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
import pandas as pd
|
| 9 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
import subprocess
|
| 11 |
+
from typing import Dict, List, Any
|
| 12 |
+
|
| 13 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 14 |
from google import genai
|
| 15 |
import requests
|
| 16 |
|
| 17 |
# --- Configuration ---
|
| 18 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
|
|
|
|
| 19 |
MAX_CHARTS, VIDEO_SCENES = 5, 5
|
| 20 |
|
| 21 |
# --- Gemini API Initialization ---
|
|
|
|
| 22 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 23 |
if not API_KEY:
|
| 24 |
raise ValueError("GOOGLE_API_KEY environment variable not set.")
|
|
|
|
| 25 |
|
| 26 |
# --- Helper Functions ---
|
| 27 |
+
def load_dataframe_safely(buf, name: str) -> pd.DataFrame:
    """Load a CSV or Excel upload into a cleaned DataFrame.

    Args:
        buf: Path or file-like object handed to ``pd.read_csv`` / ``pd.read_excel``.
        name: Original file name; only its extension is used to choose the
            parser (``.xlsx`` / ``.xls`` -> Excel, anything else -> CSV).

    Returns:
        The parsed frame with column names converted to stripped strings, rows
        that are entirely empty removed, and NaN cells replaced by ``None``.

    Raises:
        ValueError: If the file cannot be parsed or contains no usable data.
            The underlying error is chained as ``__cause__``.
    """
    ext = Path(name).suffix.lower()
    try:
        reader = pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv
        df = reader(buf)
        df.columns = df.columns.astype(str).str.strip()
        df = df.dropna(how="all")
        if df.empty or len(df.columns) == 0:
            raise ValueError("No usable data found in the file.")
        # Convert missing values to None so the frame serialises to JSON.
        # NOTE(review): this likely turns numeric columns that contain missing
        # values into object dtype, which would affect dtype-based column
        # detection downstream — confirm.
        return df.replace({np.nan: None})
    except Exception as e:
        logging.error(f"Failed to load dataframe: {e}")
        # Chain explicitly so the original parser error stays attached.
        raise ValueError(f"Could not parse the file: {e}") from e
|
| 41 |
|
| 42 |
+
def deepgram_tts(txt: str, voice_model: str) -> bytes | None:
|
| 43 |
+
"""Generates speech from text using Deepgram and returns raw audio bytes."""
|
| 44 |
DG_KEY = os.getenv("DEEPGRAM_API_KEY")
|
| 45 |
+
if not DG_KEY or not txt:
|
| 46 |
+
return None
|
| 47 |
+
# Clean and truncate text for the API
|
| 48 |
txt = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
|
| 49 |
try:
|
| 50 |
+
r = requests.post(
|
| 51 |
+
"https://api.deepgram.com/v1/speak",
|
| 52 |
+
params={"model": voice_model},
|
| 53 |
+
headers={"Authorization": f"Token {DG_KEY}", "Content-Type": "application/json"},
|
| 54 |
+
json={"text": txt},
|
| 55 |
+
timeout=30
|
| 56 |
+
)
|
| 57 |
r.raise_for_status()
|
| 58 |
return r.content
|
| 59 |
except Exception as e:
|
| 60 |
logging.error(f"Deepgram TTS failed: {e}")
|
| 61 |
return None
|
| 62 |
|
| 63 |
+
# Matches chart placeholders such as <generate_chart: "..."> or [generate_chart "..."];
# the chart description is captured in the named group "d".
TAG_RE = re.compile(r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I)


def extract_chart_tags(t):
    """Return the unique chart descriptions tagged in *t*, in order of first appearance.

    ``None`` or an empty string yields an empty list.
    """
    return list(dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
# Speaker label at the start of a line, optionally wrapped in Markdown emphasis.
# The video-script prompt in this file shows scenes as "Narration: ...".
_re_narration_label = re.compile(r"^[\s*#_]*narration[\s*#_]*:[\s*]*", re.I | re.M)


def clean_narration(txt: str) -> str:
    """Reduce a raw scene script to the text that should be spoken.

    Removes, in order: chart tags, leading "Scene N" headings, leading
    "Narration:" labels, stock phrases that refer to a chart, parenthesised
    asides, and the Markdown characters ``*``, ``#`` and ``_``. Runs of
    whitespace are then collapsed to a single space.

    Args:
        txt: Raw scene text; ``None`` is treated as an empty string.

    Returns:
        The cleaned narration, stripped of surrounding whitespace.
    """
    txt = txt or ""
    txt = TAG_RE.sub("", txt)
    txt = re_scene.sub("", txt)
    # Without this the label itself would be read aloud by the TTS voice.
    txt = _re_narration_label.sub("", txt)
    phrases_to_remove = [r"as you can see in the chart", r"this chart shows", r"the chart illustrates", r"in this visual", r"this graph displays"]
    for phrase in phrases_to_remove:
        txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
    txt = re.sub(r"\s*\([^)]*\)", "", txt)
    txt = re.sub(r"[\*#_]", "", txt)
    return re.sub(r"\s{2,}", " ", txt).strip()
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
# --- Chart Generation System ---
|
| 79 |
class ChartSpecification:
|
| 80 |
+
"""A data class to hold the specification for a chart."""
|
| 81 |
+
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str, agg_method: str = None, top_n: int = None):
|
| 82 |
+
self.chart_type = chart_type
|
| 83 |
+
self.title = title
|
| 84 |
+
self.x_col = x_col
|
| 85 |
+
self.y_col = y_col
|
| 86 |
+
self.agg_method = agg_method or "sum"
|
| 87 |
+
self.top_n = top_n
|
| 88 |
|
| 89 |
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
    """Return a copy of *ctx_dict* extended with the frame's column-type split.

    Two keys are added (or overwritten): ``numeric_columns`` lists the columns
    with a numeric dtype and ``categorical_columns`` lists all others. The
    input dictionary itself is left untouched.
    """
    numeric_names = df.select_dtypes(include=['number']).columns.tolist()
    other_names = df.select_dtypes(exclude=['number']).columns.tolist()

    result = ctx_dict.copy()
    result["numeric_columns"] = numeric_names
    result["categorical_columns"] = other_names
    return result
|
| 96 |
|
| 97 |
+
def prepare_plot_data(spec: "ChartSpecification", df: pd.DataFrame) -> List[Dict]:
    """Build the JSON-friendly records needed to draw *spec*.

    Args:
        spec: Chart description; ``chart_type``, ``x_col``, ``y_col``,
            ``agg_method`` and ``top_n`` are read.
        df: Source data.

    Returns:
        A list of row dictionaries:

        * bar / pie with ``y_col``: ``x_col`` plus the aggregated ``y_col``,
          limited to the ``top_n`` (default 10) largest values.
        * bar / pie without ``y_col``: ``x_col`` plus a ``"count"`` column with
          the number of occurrences, limited the same way.
        * line: ``x_col`` / ``y_col`` pairs sorted by ``x_col``.
        * scatter: ``x_col`` / ``y_col`` pairs without missing values.
        * hist: the non-missing ``x_col`` values (binning is left to the consumer).
        * any other chart type: an empty list.

    Raises:
        ValueError: If a referenced column does not exist, or a line/scatter
            chart is requested without a ``y_col``.
    """
    if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns):
        raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")

    if spec.chart_type in ["bar", "pie"]:
        limit = spec.top_n or 10
        if not spec.y_col:
            # Case: count occurrences of a single categorical column.
            plot_data = df[spec.x_col].value_counts().nlargest(limit)
            # Give the counts a real column name; renaming to y_col (None here)
            # would produce records keyed by None.
            value_col = "count" if spec.x_col != "count" else "count_"
        else:
            # Case: aggregate a numeric column by a categorical column.
            grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
            plot_data = grouped.nlargest(limit)
            value_col = spec.y_col
        # Name both the index and the values explicitly so the output columns do
        # not depend on how a given pandas version labels value_counts() results.
        named = plot_data.rename(value_col).rename_axis(spec.x_col)
        return named.reset_index().to_dict(orient='records')

    if spec.chart_type in ("line", "scatter") and not spec.y_col:
        # Without this guard the lookups below fail with an opaque KeyError.
        raise ValueError(f"Chart type '{spec.chart_type}' requires a y column.")

    if spec.chart_type == "line":
        plot_data = df.set_index(spec.x_col)[spec.y_col].sort_index()
        return plot_data.reset_index().to_dict(orient='records')

    if spec.chart_type == "scatter":
        return df[[spec.x_col, spec.y_col]].dropna().to_dict(orient='records')

    if spec.chart_type == "hist":
        # For histograms, we just need the single column of data. The client will bin it.
        return df[[spec.x_col]].dropna().to_dict(orient='records')

    return []
|
| 122 |
+
|
| 123 |
+
|
| 124 |
class ChartGenerator:
|
| 125 |
def __init__(self, llm, df: pd.DataFrame):
|
| 126 |
+
self.llm = llm
|
| 127 |
+
self.df = df
|
| 128 |
+
self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape})
|
| 129 |
|
| 130 |
+
def generate_chart_spec(self, description: str) -> Dict[str, Any]:
|
| 131 |
+
"""Generates a complete chart specification, including the data, as a dictionary."""
|
| 132 |
spec_prompt = f"""
|
| 133 |
You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
|
| 134 |
**Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
|
|
|
|
| 145 |
if response.startswith("```json"): response = response[7:-3]
|
| 146 |
elif response.startswith("```"): response = response[3:-3]
|
| 147 |
spec_dict = json.loads(response)
|
| 148 |
+
|
| 149 |
+
# Create a spec object to validate and use helper methods
|
| 150 |
+
spec_obj = ChartSpecification(
|
| 151 |
+
chart_type=spec_dict.get("chart_type"),
|
| 152 |
+
title=spec_dict.get("title"),
|
| 153 |
+
x_col=spec_dict.get("x_col"),
|
| 154 |
+
y_col=spec_dict.get("y_col"),
|
| 155 |
+
agg_method=spec_dict.get("agg_method"),
|
| 156 |
+
top_n=spec_dict.get("top_n")
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Prepare the data and add it to the final spec dictionary
|
| 160 |
+
chart_data = prepare_plot_data(spec_obj, self.df)
|
| 161 |
+
final_spec = {
|
| 162 |
+
"id": "chart_" + uuid.uuid4().hex,
|
| 163 |
+
"chart_type": spec_obj.chart_type,
|
| 164 |
+
"title": spec_obj.title,
|
| 165 |
+
"x_col": spec_obj.x_col,
|
| 166 |
+
"y_col": spec_obj.y_col,
|
| 167 |
+
"data": chart_data
|
| 168 |
+
}
|
| 169 |
+
return final_spec
|
| 170 |
except Exception as e:
|
| 171 |
+
logging.error(f"Chart spec generation failed: {e}. Cannot proceed.")
|
| 172 |
+
# In the new architecture, we cannot fall back to a placeholder. We must fail.
|
| 173 |
+
raise ValueError(f"Failed to generate a valid chart specification for '{description}'.") from e
|
| 174 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
+
# --- Main Business Logic Functions ---
|
| 177 |
|
| 178 |
+
def generate_report_draft(df: pd.DataFrame, user_context: str) -> Dict[str, Any]:
    """Draft a Markdown report and the chart specifications it refers to.

    The LLM writes the report from a summary of *df* plus *user_context*. Chart
    tags found in that report (at most ``MAX_CHARTS``) are each turned into a
    chart specification; descriptions whose specification cannot be generated
    are logged and left out. Nothing is written to storage.

    Returns:
        ``{"raw_md": <report markdown>, "chart_specs": [<spec dict>, ...]}``.
    """
    logging.info("Generating report draft and chart specifications.")
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=API_KEY, temperature=0.1)
    base_ctx = {"shape": df.shape, "columns": list(df.columns), "user_ctx": user_context}
    enhanced_ctx = enhance_data_context(df, base_ctx)

    report_prompt = f"""
You are a senior data analyst. Analyze the provided dataset context and write a comprehensive executive-level Markdown report.
**Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
**Instructions:**
1. **Executive Summary**: Start with a high-level summary of key findings.
2. **Key Insights**: Provide 3-5 key insights. For each insight that can be visualized, insert a chart tag on its own line.
3. **Visual Support**: Use chart tags like: <generate_chart: "A bar chart showing total sales by region">.
Valid chart types: bar, pie, line, scatter, hist.
Generate insights that would be valuable to C-level executives.
"""
    md = llm.invoke(report_prompt).content
    wanted_charts = extract_chart_tags(md)[:MAX_CHARTS]

    spec_builder = ChartGenerator(llm, df)
    collected = []
    for chart_desc in wanted_charts:
        try:
            collected.append(spec_builder.generate_chart_spec(chart_desc))
        except Exception as e:
            # A failed chart must not sink the whole report; skip it.
            logging.warning(f"Could not generate spec for chart '{chart_desc}': {e}")

    return {"raw_md": md, "chart_specs": collected}
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def generate_single_chart(df: pd.DataFrame, description: str) -> Dict[str, Any]:
    """Build the chart specification dictionary for one description.

    A fresh LLM client and ``ChartGenerator`` are created for the call; errors
    raised by ``generate_chart_spec`` propagate to the caller.
    """
    logging.info(f"Generating single chart spec for '{description}'")
    client = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=API_KEY, temperature=0.1)
    spec_builder = ChartGenerator(client, df)
    spec = spec_builder.generate_chart_spec(description)
    return spec
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def generate_video_from_project(df: pd.DataFrame, raw_md: str, voice_model: str) -> Dict[str, Any]:
    """Turn a report into a scene-by-scene video script.

    The LLM rewrites *raw_md* as scenes separated by ``[SCENE_BREAK]``. For
    every non-empty scene the narration is cleaned and sent to text-to-speech,
    and the first chart tag (if any) is converted into a chart specification.
    A missing audio track or chart is logged and left as ``None``. Nothing is
    written to storage.

    Returns:
        ``{"scenes": [...]}`` where each scene holds ``scene_id``,
        ``narration``, ``audio_content`` (raw bytes or ``None``) and
        ``chart_spec`` (dict or ``None``).
    """
    logging.info(f"Generating video script with voice {voice_model}")
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=API_KEY, temperature=0.2)
    story_prompt = f"""
Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
Each scene must be separated by '[SCENE_BREAK]'.
Each scene should contain narration text and, if relevant, exactly one chart tag from the report.
Example Scene:
Narration: We saw significant growth in the southern region.
<generate_chart: "A bar chart showing total sales by region">

[SCENE_BREAK]

Narration: This growth was driven primarily by our new product line.
<generate_chart: "A pie chart of product line performance">

Here is the report:
{raw_md}
"""
    script = llm.invoke(story_prompt).content

    scene_chunks = []
    for chunk in script.split("[SCENE_BREAK]"):
        if chunk.strip():
            scene_chunks.append(chunk.strip())

    spec_builder = ChartGenerator(llm, df)
    scenes = []
    for number, raw_scene in enumerate(scene_chunks, start=1):
        tags = extract_chart_tags(raw_scene)
        spoken_text = clean_narration(raw_scene)

        # Narration audio; a scene without audio is kept but stays silent.
        audio = deepgram_tts(spoken_text, voice_model)
        if not audio:
            logging.warning(f"Could not generate audio for scene {number}. It will be silent.")

        # Only the first chart tag of a scene is used.
        chart = None
        if tags:
            try:
                chart = spec_builder.generate_chart_spec(tags[0])
            except Exception as e:
                logging.warning(f"Could not generate chart for scene {number}: {e}")

        scenes.append({
            "scene_id": f"scene_{number}",
            "narration": spoken_text,
            "audio_content": audio if audio else None,
            "chart_spec": chart,
        })

    return {"scenes": scenes}
|
|
|
|
|
|
|
|
|
|
|
|