Update sozo_gen.py
Browse files- sozo_gen.py +111 -341
sozo_gen.py
CHANGED
|
@@ -13,8 +13,8 @@ import matplotlib
|
|
| 13 |
matplotlib.use("Agg")
|
| 14 |
import matplotlib.pyplot as plt
|
| 15 |
from matplotlib.animation import FuncAnimation, FFMpegWriter
|
| 16 |
-
import seaborn as sns
|
| 17 |
-
from scipy import stats
|
| 18 |
from PIL import Image
|
| 19 |
import cv2
|
| 20 |
import inspect
|
|
@@ -29,13 +29,13 @@ import requests
|
|
| 29 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
|
| 30 |
FPS, WIDTH, HEIGHT = 24, 1280, 720
|
| 31 |
MAX_CHARTS, VIDEO_SCENES = 5, 5
|
|
|
|
| 32 |
|
| 33 |
# --- API Initialization ---
|
| 34 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 35 |
if not API_KEY:
|
| 36 |
raise ValueError("GOOGLE_API_KEY environment variable not set.")
|
| 37 |
|
| 38 |
-
# NEW: Pexels API Key
|
| 39 |
PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
|
| 40 |
|
| 41 |
# --- Helper Functions ---
|
|
@@ -68,13 +68,11 @@ def audio_duration(path: str) -> float:
|
|
| 68 |
return float(res.stdout.strip())
|
| 69 |
except Exception: return 5.0
|
| 70 |
|
| 71 |
-
# UPDATED: Regex for chart tags and NEW regex for stock video tags
|
| 72 |
TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
|
| 73 |
TAG_RE_PEXELS = re.compile( r'[<[]\s*generate_?stock_?video\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
|
| 74 |
extract_chart_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")) )
|
| 75 |
extract_pexels_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE_PEXELS.finditer(t or "")) )
|
| 76 |
|
| 77 |
-
|
| 78 |
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
|
| 79 |
def clean_narration(txt: str) -> str:
|
| 80 |
txt = TAG_RE.sub("", txt); txt = TAG_RE_PEXELS.sub("", txt); txt = re_scene.sub("", txt)
|
|
@@ -98,7 +96,6 @@ def generate_image_from_prompt(prompt: str) -> Image.Image:
|
|
| 98 |
except Exception:
|
| 99 |
return placeholder_img()
|
| 100 |
|
| 101 |
-
# NEW: Pexels video search and download function
|
| 102 |
def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
|
| 103 |
if not PEXELS_API_KEY:
|
| 104 |
logging.warning("PEXELS_API_KEY not set. Cannot fetch stock video.")
|
|
@@ -113,7 +110,6 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
|
|
| 113 |
logging.warning(f"No Pexels videos found for query: '{query}'")
|
| 114 |
return None
|
| 115 |
|
| 116 |
-
# Find a suitable video file (prefer HD)
|
| 117 |
video_to_download = None
|
| 118 |
for video in videos:
|
| 119 |
for f in video.get('video_files', []):
|
|
@@ -127,7 +123,6 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
|
|
| 127 |
logging.warning(f"No suitable HD video file found for query: '{query}'")
|
| 128 |
return None
|
| 129 |
|
| 130 |
-
# Download to a temporary file
|
| 131 |
with requests.get(video_to_download, stream=True, timeout=60) as r:
|
| 132 |
r.raise_for_status()
|
| 133 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_dl_file:
|
|
@@ -135,7 +130,6 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
|
|
| 135 |
temp_dl_file.write(chunk)
|
| 136 |
temp_dl_path = Path(temp_dl_file.name)
|
| 137 |
|
| 138 |
-
# Use FFmpeg to resize, crop, and trim to exact duration
|
| 139 |
cmd = [
|
| 140 |
"ffmpeg", "-y", "-i", str(temp_dl_path),
|
| 141 |
"-vf", f"scale={WIDTH}:{HEIGHT}:force_original_aspect_ratio=decrease,pad={WIDTH}:{HEIGHT}:(ow-iw)/2:(oh-ih)/2,setsar=1",
|
|
@@ -154,28 +148,19 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
|
|
| 154 |
return None
|
| 155 |
|
| 156 |
# --- Chart Generation System ---
|
| 157 |
-
# UPDATED: ChartSpecification to include size_col for bubble charts
|
| 158 |
class ChartSpecification:
|
| 159 |
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
|
| 160 |
self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col; self.size_col = size_col
|
| 161 |
self.agg_method = agg_method or "sum"; self.filter_condition = filter_condition; self.top_n = top_n; self.color_scheme = color_scheme
|
| 162 |
|
| 163 |
-
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
|
| 164 |
-
enhanced_ctx = ctx_dict.copy(); numeric_cols = df.select_dtypes(include=['number']).columns.tolist(); categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
|
| 165 |
-
enhanced_ctx.update({"numeric_columns": numeric_cols, "categorical_columns": categorical_cols})
|
| 166 |
-
return enhanced_ctx
|
| 167 |
-
|
| 168 |
class ChartGenerator:
|
| 169 |
def __init__(self, llm, df: pd.DataFrame):
|
| 170 |
self.llm = llm; self.df = df
|
| 171 |
-
self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape, "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}})
|
| 172 |
|
| 173 |
-
def generate_chart_spec(self, description: str) -> ChartSpecification:
|
| 174 |
-
safe_ctx = json_serializable(self.enhanced_ctx)
|
| 175 |
-
# UPDATED: Prompt to include new chart types
|
| 176 |
spec_prompt = f"""
|
| 177 |
-
You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
|
| 178 |
-
**Dataset
|
| 179 |
**Chart Request:** {description}
|
| 180 |
**Return a JSON specification with these exact fields:**
|
| 181 |
{{
|
|
@@ -187,7 +172,7 @@ class ChartGenerator:
|
|
| 187 |
"agg_method": "sum|mean|count|max|min|null",
|
| 188 |
"top_n": "number_for_top_n_filtering_or_null"
|
| 189 |
}}
|
| 190 |
-
Return only the JSON specification, no additional text.
|
| 191 |
"""
|
| 192 |
try:
|
| 193 |
response = self.llm.invoke(spec_prompt).content.strip()
|
|
@@ -199,18 +184,15 @@ class ChartGenerator:
|
|
| 199 |
return ChartSpecification(**filtered_dict)
|
| 200 |
except Exception as e:
|
| 201 |
logging.error(f"Spec generation failed: {e}. Using fallback.")
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
if
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
return ChartSpecification(ctype, description, x, y)
|
| 212 |
-
|
| 213 |
-
# UPDATED: execute_chart_spec to include new chart types
|
| 214 |
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
|
| 215 |
try:
|
| 216 |
plot_data = prepare_plot_data(spec, df)
|
|
@@ -231,7 +213,6 @@ def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path:
|
|
| 231 |
return True
|
| 232 |
except Exception as e: logging.error(f"Static chart generation failed for '{spec.title}': {e}"); return False
|
| 233 |
|
| 234 |
-
# UPDATED: prepare_plot_data to handle new chart types
|
| 235 |
def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
|
| 236 |
if spec.chart_type not in ["heatmap"]:
|
| 237 |
if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns): raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
|
|
@@ -253,7 +234,6 @@ def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
|
|
| 253 |
return df[spec.x_col]
|
| 254 |
|
| 255 |
# --- Animation & Video Generation ---
|
| 256 |
-
# UPDATED: animate_chart with enhanced animations and new chart types
|
| 257 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 258 |
plot_data = prepare_plot_data(spec, df)
|
| 259 |
frames = max(10, int(dur * fps))
|
|
@@ -276,30 +256,25 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
|
|
| 276 |
return bars
|
| 277 |
elif ctype == "scatter":
|
| 278 |
x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
|
| 279 |
-
# Calculate regression line
|
| 280 |
slope, intercept, _, _, _ = stats.linregress(x_full, y_full)
|
| 281 |
reg_line_x = np.array([x_full.min(), x_full.max()])
|
| 282 |
reg_line_y = slope * reg_line_x + intercept
|
| 283 |
|
| 284 |
scat = ax.scatter([], [], alpha=0.7, color='#F18F01')
|
| 285 |
-
line, = ax.plot([], [], 'r--', lw=2)
|
| 286 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
|
| 287 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 288 |
|
| 289 |
def init():
|
| 290 |
-
scat.set_offsets(np.empty((0, 2)))
|
| 291 |
-
line.set_data([], [])
|
| 292 |
return [scat, line]
|
| 293 |
def update(i):
|
| 294 |
-
# Animate points for the first 70% of frames
|
| 295 |
point_frames = int(frames * 0.7)
|
| 296 |
if i <= point_frames:
|
| 297 |
k = max(1, int(len(x_full) * (i / point_frames)))
|
| 298 |
scat.set_offsets(plot_data.iloc[:k].values)
|
| 299 |
-
# Animate regression line for the last 30%
|
| 300 |
else:
|
| 301 |
-
line_frame = i - point_frames
|
| 302 |
-
line_total_frames = frames - point_frames
|
| 303 |
current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
|
| 304 |
line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
|
| 305 |
return [scat, line]
|
|
@@ -320,32 +295,19 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
|
|
| 320 |
k = max(2, int(len(x_full) * (i / (frames - 1))))
|
| 321 |
fill = ax.fill_between(x_full[:k], y_full[:k], color="#4E79A7", alpha=0.4)
|
| 322 |
return [fill]
|
| 323 |
-
elif ctype == "heatmap":
|
| 324 |
-
sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax, alpha=0)
|
| 325 |
-
ax.set_title(spec.title)
|
| 326 |
-
def init(): ax.collections[0].set_alpha(0); return [ax.collections[0]]
|
| 327 |
-
def update(i): ax.collections[0].set_alpha(i / (frames - 1)); return [ax.collections[0]]
|
| 328 |
-
elif ctype == "bubble":
|
| 329 |
-
sizes = (plot_data[spec.size_col] - plot_data[spec.size_col].min() + 1) / (plot_data[spec.size_col].max() - plot_data[spec.size_col].min() + 1) * 2000 + 50
|
| 330 |
-
scat = ax.scatter(plot_data[spec.x_col], plot_data[spec.y_col], s=sizes, alpha=0, color='#59A14F')
|
| 331 |
-
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 332 |
-
def init(): scat.set_alpha(0); return [scat]
|
| 333 |
-
def update(i): scat.set_alpha(i / (frames - 1) * 0.7); return [scat]
|
| 334 |
else: # line (Time Series)
|
| 335 |
line, = ax.plot([], [], lw=2, color='#A23B72')
|
| 336 |
-
markers, = ax.plot([], [], 'o', color='#A23B72', markersize=5)
|
| 337 |
plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
|
| 338 |
x_full, y_full = plot_data.index, plot_data.values
|
| 339 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
|
| 340 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 341 |
def init():
|
| 342 |
-
line.set_data([], [])
|
| 343 |
-
markers.set_data([], [])
|
| 344 |
return [line, markers]
|
| 345 |
def update(i):
|
| 346 |
k = max(2, int(len(x_full) * (i / (frames - 1))))
|
| 347 |
-
line.set_data(x_full[:k], y_full[:k])
|
| 348 |
-
markers.set_data(x_full[:k], y_full[:k])
|
| 349 |
return [line, markers]
|
| 350 |
|
| 351 |
anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
|
|
@@ -363,11 +325,11 @@ def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) ->
|
|
| 363 |
video_writer.release()
|
| 364 |
return str(out)
|
| 365 |
|
| 366 |
-
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
|
| 367 |
try:
|
| 368 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 369 |
chart_generator = ChartGenerator(llm, df)
|
| 370 |
-
chart_spec = chart_generator.generate_chart_spec(desc)
|
| 371 |
return animate_chart(chart_spec, df, dur, out)
|
| 372 |
except Exception as e:
|
| 373 |
logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
|
|
@@ -375,7 +337,7 @@ def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
|
|
| 375 |
temp_png = Path(temp_png_file.name)
|
| 376 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 377 |
chart_generator = ChartGenerator(llm, df)
|
| 378 |
-
chart_spec = chart_generator.generate_chart_spec(desc)
|
| 379 |
if execute_chart_spec(chart_spec, df, temp_png):
|
| 380 |
img = cv2.imread(str(temp_png)); os.unlink(temp_png)
|
| 381 |
img_resized = cv2.resize(img, (WIDTH, HEIGHT))
|
|
@@ -398,310 +360,118 @@ def concat_media(file_paths: List[str], output_path: Path):
|
|
| 398 |
finally:
|
| 399 |
list_file.unlink(missing_ok=True)
|
| 400 |
|
| 401 |
-
# --- Main Business Logic
|
| 402 |
-
# This section containing generate_report_draft and its helpers is left unchanged as requested.
|
| 403 |
-
# ... (all functions from sanitize_for_firebase_key to generate_single_chart) ...
|
| 404 |
-
# The following functions are preserved exactly as they were in the original code provided.
|
| 405 |
|
| 406 |
def sanitize_for_firebase_key(text: str) -> str:
|
| 407 |
-
"""Replaces Firebase-forbidden characters in a string with underscores."""
|
| 408 |
forbidden_chars = ['.', '$', '#', '[', ']', '/']
|
| 409 |
for char in forbidden_chars:
|
| 410 |
text = text.replace(char, '_')
|
| 411 |
return text
|
| 412 |
|
| 413 |
-
def
|
| 414 |
-
"""
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
'marketing': ['campaign', 'conversion', 'click', 'impression', 'engagement', 'customer', 'segment'],
|
| 425 |
-
'operational': ['performance', 'efficiency', 'throughput', 'capacity', 'utilization', 'process'],
|
| 426 |
-
'temporal': ['date', 'time', 'timestamp', 'period', 'month', 'year', 'day', 'hour']
|
| 427 |
-
}
|
| 428 |
-
|
| 429 |
-
# Analyze column patterns
|
| 430 |
-
columns_lower = [col.lower() for col in df.columns]
|
| 431 |
-
domain_scores = {}
|
| 432 |
-
|
| 433 |
-
for domain, keywords in domain_signals.items():
|
| 434 |
-
score = sum(1 for col in columns_lower if any(keyword in col for keyword in keywords))
|
| 435 |
-
domain_scores[domain] = score
|
| 436 |
-
|
| 437 |
-
# Determine primary domain
|
| 438 |
-
primary_domain = max(domain_scores, key=domain_scores.get) if max(domain_scores.values()) > 0 else 'general'
|
| 439 |
-
|
| 440 |
-
# Data Structure Analysis
|
| 441 |
-
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
| 442 |
-
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
|
| 443 |
-
datetime_cols = df.select_dtypes(include=['datetime64']).columns.tolist()
|
| 444 |
-
|
| 445 |
-
# Detect time series
|
| 446 |
-
is_timeseries = len(datetime_cols) > 0 or any('date' in col.lower() or 'time' in col.lower() for col in columns_lower)
|
| 447 |
-
|
| 448 |
-
# Statistical Profile
|
| 449 |
-
statistical_summary = {}
|
| 450 |
-
if numeric_cols:
|
| 451 |
-
try:
|
| 452 |
-
correlations = df[numeric_cols].corr().abs().max()
|
| 453 |
-
correlations_dict = {k: float(v) if pd.notna(v) else 0.0 for k, v in correlations.to_dict().items()}
|
| 454 |
-
|
| 455 |
-
distributions = {}
|
| 456 |
-
for col in numeric_cols:
|
| 457 |
-
if len(df[col].dropna()) > 8:
|
| 458 |
-
try:
|
| 459 |
-
p_value = stats.normaltest(df[col].dropna())[1]
|
| 460 |
-
distributions[col] = 'normal' if p_value > 0.05 else 'non_normal'
|
| 461 |
-
except:
|
| 462 |
-
distributions[col] = 'unknown'
|
| 463 |
-
|
| 464 |
-
outliers = {}
|
| 465 |
-
for col in numeric_cols:
|
| 466 |
-
if len(df[col].dropna()) > 0:
|
| 467 |
-
try:
|
| 468 |
-
z_scores = np.abs(stats.zscore(df[col].dropna()))
|
| 469 |
-
outliers[col] = int(len(df[col][z_scores > 3]))
|
| 470 |
-
except:
|
| 471 |
-
outliers[col] = 0
|
| 472 |
-
|
| 473 |
-
statistical_summary = {
|
| 474 |
-
'correlations': correlations_dict,
|
| 475 |
-
'distributions': distributions,
|
| 476 |
-
'outliers': outliers
|
| 477 |
-
}
|
| 478 |
-
except Exception as e:
|
| 479 |
-
statistical_summary = {'error': 'Could not compute statistical summary'}
|
| 480 |
-
|
| 481 |
-
# Pattern Detection
|
| 482 |
-
patterns = {
|
| 483 |
-
'has_missing_data': df.isnull().sum().sum() > 0,
|
| 484 |
-
'has_duplicates': df.duplicated().sum() > 0,
|
| 485 |
-
'has_negative_values': any(df[col].min() < 0 for col in numeric_cols if len(df[col].dropna()) > 0),
|
| 486 |
-
'has_categorical_hierarchy': any(len(df[col].unique()) > 10 for col in categorical_cols),
|
| 487 |
-
'potential_segments': len(categorical_cols) > 0
|
| 488 |
-
}
|
| 489 |
-
|
| 490 |
-
# Insight Opportunities
|
| 491 |
-
insight_opportunities = []
|
| 492 |
-
|
| 493 |
-
if is_timeseries:
|
| 494 |
-
insight_opportunities.append("temporal_trends")
|
| 495 |
-
|
| 496 |
-
if len(numeric_cols) > 1:
|
| 497 |
-
insight_opportunities.append("correlations")
|
| 498 |
-
|
| 499 |
-
if len(categorical_cols) > 0 and len(numeric_cols) > 0:
|
| 500 |
-
insight_opportunities.append("segmentation")
|
| 501 |
-
|
| 502 |
-
if any(statistical_summary.get('outliers', {}).values()):
|
| 503 |
-
insight_opportunities.append("anomalies")
|
| 504 |
-
|
| 505 |
-
return {
|
| 506 |
-
'primary_domain': primary_domain,
|
| 507 |
-
'domain_confidence': domain_scores,
|
| 508 |
-
'data_structure': {
|
| 509 |
-
'is_timeseries': is_timeseries,
|
| 510 |
-
'numeric_cols': numeric_cols,
|
| 511 |
-
'categorical_cols': categorical_cols,
|
| 512 |
-
'datetime_cols': datetime_cols
|
| 513 |
},
|
| 514 |
-
|
| 515 |
-
'patterns': patterns,
|
| 516 |
-
'insight_opportunities': insight_opportunities,
|
| 517 |
-
'narrative_suggestions': get_narrative_suggestions(primary_domain, insight_opportunities, patterns)
|
| 518 |
}
|
| 519 |
-
|
| 520 |
-
def get_narrative_suggestions(domain: str, opportunities: List[str], patterns: Dict) -> Dict[str, str]:
|
| 521 |
-
"""Generate narrative direction based on domain and data characteristics"""
|
| 522 |
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
}
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
'hook': "The data reveals relationships that challenge conventional thinking",
|
| 536 |
-
'structure': "hypothesis → evidence → significance → implications",
|
| 537 |
-
'focus': "statistical significance, correlations, experimental validity"
|
| 538 |
-
},
|
| 539 |
-
'marketing': {
|
| 540 |
-
'hook': "Discover the customer journey patterns driving your growth",
|
| 541 |
-
'structure': "performance → segments → optimization → strategy",
|
| 542 |
-
'focus': "conversion funnels, customer segments, campaign effectiveness"
|
| 543 |
-
},
|
| 544 |
-
'operational': {
|
| 545 |
-
'hook': "Operational excellence lives in the details - here's where to look",
|
| 546 |
-
'structure': "efficiency → bottlenecks → optimization → impact",
|
| 547 |
-
'focus': "process efficiency, capacity utilization, improvement opportunities"
|
| 548 |
-
},
|
| 549 |
-
'general': {
|
| 550 |
-
'hook': "Every dataset tells a story - here's what yours is saying",
|
| 551 |
-
'structure': "overview → patterns → insights → implications",
|
| 552 |
-
'focus': "key patterns, significant relationships, actionable insights"
|
| 553 |
}
|
| 554 |
-
}
|
| 555 |
|
| 556 |
-
return
|
| 557 |
-
|
| 558 |
-
def json_serializable(obj):
|
| 559 |
-
"""Convert objects to JSON-serializable format"""
|
| 560 |
-
if isinstance(obj, (np.integer, np.floating)):
|
| 561 |
-
return float(obj)
|
| 562 |
-
elif isinstance(obj, np.ndarray):
|
| 563 |
-
return obj.tolist()
|
| 564 |
-
elif isinstance(obj, (np.bool_, bool)):
|
| 565 |
-
return bool(obj)
|
| 566 |
-
elif isinstance(obj, dict):
|
| 567 |
-
return {k: json_serializable(v) for k, v in obj.items()}
|
| 568 |
-
elif isinstance(obj, (list, tuple)):
|
| 569 |
-
return [json_serializable(item) for item in obj]
|
| 570 |
-
elif pd.isna(obj):
|
| 571 |
-
return None
|
| 572 |
-
else:
|
| 573 |
-
return obj
|
| 574 |
|
| 575 |
-
def
|
| 576 |
-
""
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
"""
|
| 580 |
-
|
| 581 |
-
domain = intelligence['primary_domain']
|
| 582 |
-
opportunities = intelligence['insight_opportunities']
|
| 583 |
-
narrative = intelligence['narrative_suggestions']
|
| 584 |
-
|
| 585 |
-
# Dynamic chart strategy based on data characteristics
|
| 586 |
-
chart_strategy = generate_chart_strategy(intelligence)
|
| 587 |
-
|
| 588 |
-
# Make context JSON serializable
|
| 589 |
-
serializable_ctx = json_serializable(enhanced_ctx)
|
| 590 |
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
{
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
2. **BUILD TENSION**: Present contrasts, surprises, or unexpected patterns
|
| 608 |
-
3. **REVEAL INSIGHTS**: Use data to resolve the tension with clear comprehensive explanations
|
| 609 |
-
4. **DRIVE ACTION**: End with specific, actionable recommendations
|
| 610 |
-
|
| 611 |
-
**VISUALIZATION STRATEGY:**
|
| 612 |
-
{chart_strategy}
|
| 613 |
-
|
| 614 |
-
**CRITICAL INSTRUCTIONS:**
|
| 615 |
-
- Write as if you're revealing a detective story, not filling a template
|
| 616 |
-
- Every insight must be explained and supported by data evidence
|
| 617 |
-
- Use compelling headers that create curiosity (not "Executive Summary")
|
| 618 |
-
- Weave charts naturally into the narrative flow
|
| 619 |
-
- Focus on business impact and actionable outcomes
|
| 620 |
-
- Let the data's personality shine through your writing style
|
| 621 |
-
|
| 622 |
-
**CHART INTEGRATION:**
|
| 623 |
-
Insert charts using: `<generate_chart: "chart_type | compelling description that advances the story">`
|
| 624 |
-
Available types: bar, pie, line, scatter, hist, heatmap, area, bubble
|
| 625 |
|
| 626 |
-
|
|
|
|
| 627 |
|
| 628 |
-
|
|
|
|
| 629 |
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
strategies = {
|
| 638 |
-
'financial': "Focus on trend lines showing performance over time, comparative bars for different categories, and scatter plots revealing correlations between financial metrics.",
|
| 639 |
-
'survey': "Emphasize distribution histograms for satisfaction scores, segmented bar charts for demographic breakdowns, and correlation matrices for response patterns.",
|
| 640 |
-
'scientific': "Prioritize scatter plots with regression lines, distribution comparisons, and statistical significance visualizations.",
|
| 641 |
-
'marketing': "Highlight conversion funnels, customer segment comparisons, and campaign performance trends.",
|
| 642 |
-
'operational': "Show efficiency trends, capacity utilization charts, and process performance comparisons."
|
| 643 |
-
}
|
| 644 |
-
|
| 645 |
-
base_strategy = strategies.get(domain, "Create visualizations that best tell your data's unique story.")
|
| 646 |
-
|
| 647 |
-
# Add specific guidance based on data characteristics
|
| 648 |
-
if structure['is_timeseries']:
|
| 649 |
-
base_strategy += " Leverage time-series visualizations like line or area charts to show trends and patterns over time."
|
| 650 |
-
|
| 651 |
-
if 'correlations' in opportunities:
|
| 652 |
-
base_strategy += " Include correlation visualizations like scatterplots or heatmaps to reveal hidden relationships."
|
| 653 |
-
|
| 654 |
-
if 'segmentation' in opportunities:
|
| 655 |
-
base_strategy += " Use segmented charts to highlight different groups or categories."
|
| 656 |
-
|
| 657 |
-
return base_strategy
|
| 658 |
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
logging.info(f"Generating autonomous report draft for project {project_id}")
|
| 662 |
-
|
| 663 |
-
df = load_dataframe_safely(buf, name)
|
| 664 |
-
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.1)
|
| 665 |
|
| 666 |
-
ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
|
| 667 |
-
enhanced_ctx = enhance_data_context(df, ctx_dict)
|
| 668 |
-
intelligence = analyze_data_intelligence(df, ctx_dict)
|
| 669 |
-
report_prompt = create_autonomous_prompt(df, enhanced_ctx, intelligence)
|
| 670 |
md = llm.invoke(report_prompt).content
|
| 671 |
|
| 672 |
chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
|
| 673 |
chart_urls = {}
|
| 674 |
chart_generator = ChartGenerator(llm, df)
|
| 675 |
-
|
| 676 |
for desc in chart_descs:
|
| 677 |
safe_desc = sanitize_for_firebase_key(desc)
|
| 678 |
md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
|
| 679 |
md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
|
| 680 |
-
|
| 681 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 682 |
img_path = Path(temp_file.name)
|
| 683 |
try:
|
| 684 |
-
chart_spec = chart_generator.generate_chart_spec(desc)
|
| 685 |
if execute_chart_spec(chart_spec, df, img_path):
|
| 686 |
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
|
| 687 |
blob = bucket.blob(blob_name)
|
| 688 |
blob.upload_from_filename(str(img_path))
|
| 689 |
chart_urls[safe_desc] = blob.public_url
|
| 690 |
-
logging.info(f"Uploaded chart '{desc}' to {blob.public_url} with safe key '{safe_desc}'")
|
| 691 |
finally:
|
| 692 |
if os.path.exists(img_path):
|
| 693 |
os.unlink(img_path)
|
| 694 |
-
|
| 695 |
-
return {"raw_md": md, "chartUrls": chart_urls}
|
| 696 |
|
| 697 |
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
|
| 698 |
logging.info(f"Generating single chart '{description}' for project {project_id}")
|
| 699 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 700 |
chart_generator = ChartGenerator(llm, df)
|
|
|
|
| 701 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 702 |
img_path = Path(temp_file.name)
|
| 703 |
try:
|
| 704 |
-
chart_spec = chart_generator.generate_chart_spec(description)
|
| 705 |
if execute_chart_spec(chart_spec, df, img_path):
|
| 706 |
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
|
| 707 |
blob = bucket.blob(blob_name)
|
|
@@ -713,26 +483,23 @@ def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_
|
|
| 713 |
os.unlink(img_path)
|
| 714 |
return None
|
| 715 |
|
| 716 |
-
|
| 717 |
-
def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
|
| 718 |
logging.info(f"Generating video for project {project_id} with voice {voice_model}")
|
| 719 |
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
|
| 720 |
|
| 721 |
-
# UPDATED: Prompt to create Intro/Conclusion scenes with stock video tags
|
| 722 |
story_prompt = f"""
|
| 723 |
Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
|
| 724 |
1. The first scene MUST be an "Introduction". It must contain narration and a stock video tag like: <generate_stock_video: "search query">.
|
| 725 |
2. The last scene MUST be a "Conclusion". It must also contain narration and a stock video tag.
|
| 726 |
3. The middle scenes should each contain narration and one chart tag from the report.
|
| 727 |
4. Separate each scene with '[SCENE_BREAK]'.
|
| 728 |
-
|
| 729 |
Report: {raw_md}
|
| 730 |
-
|
| 731 |
Only output the script, no extra text.
|
| 732 |
"""
|
| 733 |
script = llm.invoke(story_prompt).content
|
| 734 |
scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
|
| 735 |
video_parts, audio_parts, temps = [], [], []
|
|
|
|
| 736 |
|
| 737 |
for i, sc in enumerate(scenes):
|
| 738 |
chart_descs = extract_chart_tags(sc)
|
|
@@ -745,35 +512,36 @@ def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project
|
|
| 745 |
|
| 746 |
audio_bytes = deepgram_tts(narrative, voice_model)
|
| 747 |
mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
|
|
|
|
| 748 |
if audio_bytes:
|
| 749 |
-
mp3.write_bytes(audio_bytes)
|
| 750 |
-
|
|
|
|
| 751 |
else:
|
| 752 |
-
|
|
|
|
| 753 |
audio_parts.append(str(mp3)); temps.append(mp3)
|
|
|
|
| 754 |
|
|
|
|
| 755 |
mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
|
| 756 |
video_generated = False
|
| 757 |
|
| 758 |
if pexels_descs:
|
| 759 |
-
|
| 760 |
-
video_path = search_and_download_pexels_video(pexels_descs[0], dur, mp4)
|
| 761 |
if video_path:
|
| 762 |
-
video_parts.append(video_path)
|
| 763 |
-
temps.append(Path(video_path))
|
| 764 |
video_generated = True
|
| 765 |
|
| 766 |
if not video_generated and chart_descs:
|
| 767 |
-
|
| 768 |
-
safe_chart(chart_descs[0], df, dur, mp4)
|
| 769 |
video_parts.append(str(mp4)); temps.append(mp4)
|
| 770 |
video_generated = True
|
| 771 |
|
| 772 |
if not video_generated:
|
| 773 |
-
logging.warning(f"Scene {i+1}: No valid chart or stock video tag found. Using fallback image.")
|
| 774 |
img = generate_image_from_prompt(narrative)
|
| 775 |
img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
|
| 776 |
-
animate_image_fade(img_cv,
|
| 777 |
video_parts.append(str(mp4)); temps.append(mp4)
|
| 778 |
|
| 779 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
|
|
@@ -787,12 +555,14 @@ def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project
|
|
| 787 |
concat_media(video_parts, silent_vid_path)
|
| 788 |
concat_media(audio_parts, audio_mix_path)
|
| 789 |
|
| 790 |
-
|
| 791 |
-
|
| 792 |
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
|
| 793 |
-
"-map", "0:v:0", "-map", "1:a:0",
|
| 794 |
-
|
| 795 |
-
|
|
|
|
|
|
|
| 796 |
|
| 797 |
blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
|
| 798 |
blob = bucket.blob(blob_name)
|
|
|
|
| 13 |
matplotlib.use("Agg")
|
| 14 |
import matplotlib.pyplot as plt
|
| 15 |
from matplotlib.animation import FuncAnimation, FFMpegWriter
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
from scipy import stats
|
| 18 |
from PIL import Image
|
| 19 |
import cv2
|
| 20 |
import inspect
|
|
|
|
| 29 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
|
| 30 |
FPS, WIDTH, HEIGHT = 24, 1280, 720
|
| 31 |
MAX_CHARTS, VIDEO_SCENES = 5, 5
|
| 32 |
+
MAX_CONTEXT_TOKENS = 250000 # Set max token limit for full dataset context
|
| 33 |
|
| 34 |
# --- API Initialization ---
|
| 35 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 36 |
if not API_KEY:
|
| 37 |
raise ValueError("GOOGLE_API_KEY environment variable not set.")
|
| 38 |
|
|
|
|
| 39 |
PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
|
| 40 |
|
| 41 |
# --- Helper Functions ---
|
|
|
|
| 68 |
return float(res.stdout.strip())
|
| 69 |
except Exception: return 5.0
|
| 70 |
|
|
|
|
| 71 |
TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
|
| 72 |
TAG_RE_PEXELS = re.compile( r'[<[]\s*generate_?stock_?video\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
|
| 73 |
extract_chart_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")) )
|
| 74 |
extract_pexels_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE_PEXELS.finditer(t or "")) )
|
| 75 |
|
|
|
|
| 76 |
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
|
| 77 |
def clean_narration(txt: str) -> str:
|
| 78 |
txt = TAG_RE.sub("", txt); txt = TAG_RE_PEXELS.sub("", txt); txt = re_scene.sub("", txt)
|
|
|
|
| 96 |
except Exception:
|
| 97 |
return placeholder_img()
|
| 98 |
|
|
|
|
| 99 |
def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
|
| 100 |
if not PEXELS_API_KEY:
|
| 101 |
logging.warning("PEXELS_API_KEY not set. Cannot fetch stock video.")
|
|
|
|
| 110 |
logging.warning(f"No Pexels videos found for query: '{query}'")
|
| 111 |
return None
|
| 112 |
|
|
|
|
| 113 |
video_to_download = None
|
| 114 |
for video in videos:
|
| 115 |
for f in video.get('video_files', []):
|
|
|
|
| 123 |
logging.warning(f"No suitable HD video file found for query: '{query}'")
|
| 124 |
return None
|
| 125 |
|
|
|
|
| 126 |
with requests.get(video_to_download, stream=True, timeout=60) as r:
|
| 127 |
r.raise_for_status()
|
| 128 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_dl_file:
|
|
|
|
| 130 |
temp_dl_file.write(chunk)
|
| 131 |
temp_dl_path = Path(temp_dl_file.name)
|
| 132 |
|
|
|
|
| 133 |
cmd = [
|
| 134 |
"ffmpeg", "-y", "-i", str(temp_dl_path),
|
| 135 |
"-vf", f"scale={WIDTH}:{HEIGHT}:force_original_aspect_ratio=decrease,pad={WIDTH}:{HEIGHT}:(ow-iw)/2:(oh-ih)/2,setsar=1",
|
|
|
|
| 148 |
return None
|
| 149 |
|
| 150 |
# --- Chart Generation System ---
|
|
|
|
| 151 |
class ChartSpecification:
|
| 152 |
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
|
| 153 |
self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col; self.size_col = size_col
|
| 154 |
self.agg_method = agg_method or "sum"; self.filter_condition = filter_condition; self.top_n = top_n; self.color_scheme = color_scheme
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
class ChartGenerator:
|
| 157 |
def __init__(self, llm, df: pd.DataFrame):
|
| 158 |
self.llm = llm; self.df = df
|
|
|
|
| 159 |
|
| 160 |
+
def generate_chart_spec(self, description: str, context: Dict) -> ChartSpecification:
|
|
|
|
|
|
|
| 161 |
spec_prompt = f"""
|
| 162 |
+
You are a data visualization expert. Based on the dataset context and chart description, generate a precise chart specification.
|
| 163 |
+
**Dataset Context:** {json.dumps(context, indent=2)}
|
| 164 |
**Chart Request:** {description}
|
| 165 |
**Return a JSON specification with these exact fields:**
|
| 166 |
{{
|
|
|
|
| 172 |
"agg_method": "sum|mean|count|max|min|null",
|
| 173 |
"top_n": "number_for_top_n_filtering_or_null"
|
| 174 |
}}
|
| 175 |
+
Return only the JSON specification, no additional text.
|
| 176 |
"""
|
| 177 |
try:
|
| 178 |
response = self.llm.invoke(spec_prompt).content.strip()
|
|
|
|
| 184 |
return ChartSpecification(**filtered_dict)
|
| 185 |
except Exception as e:
|
| 186 |
logging.error(f"Spec generation failed: {e}. Using fallback.")
|
| 187 |
+
numeric_cols = context.get('schema', {}).get('numeric_columns', list(self.df.select_dtypes(include=['number']).columns))
|
| 188 |
+
categorical_cols = context.get('schema', {}).get('categorical_columns', list(self.df.select_dtypes(exclude=['number']).columns))
|
| 189 |
+
ctype = "bar"
|
| 190 |
+
for t in ["pie", "line", "scatter", "hist", "heatmap", "area", "bubble"]:
|
| 191 |
+
if t in description.lower(): ctype = t
|
| 192 |
+
x = categorical_cols[0] if categorical_cols else self.df.columns[0]
|
| 193 |
+
y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
|
| 194 |
+
return ChartSpecification(ctype, description, x, y)
|
| 195 |
+
|
|
|
|
|
|
|
|
|
|
| 196 |
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
|
| 197 |
try:
|
| 198 |
plot_data = prepare_plot_data(spec, df)
|
|
|
|
| 213 |
return True
|
| 214 |
except Exception as e: logging.error(f"Static chart generation failed for '{spec.title}': {e}"); return False
|
| 215 |
|
|
|
|
| 216 |
def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
|
| 217 |
if spec.chart_type not in ["heatmap"]:
|
| 218 |
if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns): raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
|
|
|
|
| 234 |
return df[spec.x_col]
|
| 235 |
|
| 236 |
# --- Animation & Video Generation ---
|
|
|
|
| 237 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 238 |
plot_data = prepare_plot_data(spec, df)
|
| 239 |
frames = max(10, int(dur * fps))
|
|
|
|
| 256 |
return bars
|
| 257 |
elif ctype == "scatter":
|
| 258 |
x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
|
|
|
|
| 259 |
slope, intercept, _, _, _ = stats.linregress(x_full, y_full)
|
| 260 |
reg_line_x = np.array([x_full.min(), x_full.max()])
|
| 261 |
reg_line_y = slope * reg_line_x + intercept
|
| 262 |
|
| 263 |
scat = ax.scatter([], [], alpha=0.7, color='#F18F01')
|
| 264 |
+
line, = ax.plot([], [], 'r--', lw=2)
|
| 265 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
|
| 266 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 267 |
|
| 268 |
def init():
|
| 269 |
+
scat.set_offsets(np.empty((0, 2))); line.set_data([], [])
|
|
|
|
| 270 |
return [scat, line]
|
| 271 |
def update(i):
|
|
|
|
| 272 |
point_frames = int(frames * 0.7)
|
| 273 |
if i <= point_frames:
|
| 274 |
k = max(1, int(len(x_full) * (i / point_frames)))
|
| 275 |
scat.set_offsets(plot_data.iloc[:k].values)
|
|
|
|
| 276 |
else:
|
| 277 |
+
line_frame = i - point_frames; line_total_frames = frames - point_frames
|
|
|
|
| 278 |
current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
|
| 279 |
line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
|
| 280 |
return [scat, line]
|
|
|
|
| 295 |
k = max(2, int(len(x_full) * (i / (frames - 1))))
|
| 296 |
fill = ax.fill_between(x_full[:k], y_full[:k], color="#4E79A7", alpha=0.4)
|
| 297 |
return [fill]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
else: # line (Time Series)
|
| 299 |
line, = ax.plot([], [], lw=2, color='#A23B72')
|
| 300 |
+
markers, = ax.plot([], [], 'o', color='#A23B72', markersize=5)
|
| 301 |
plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
|
| 302 |
x_full, y_full = plot_data.index, plot_data.values
|
| 303 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
|
| 304 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 305 |
def init():
|
| 306 |
+
line.set_data([], []); markers.set_data([], [])
|
|
|
|
| 307 |
return [line, markers]
|
| 308 |
def update(i):
|
| 309 |
k = max(2, int(len(x_full) * (i / (frames - 1))))
|
| 310 |
+
line.set_data(x_full[:k], y_full[:k]); markers.set_data(x_full[:k], y_full[:k])
|
|
|
|
| 311 |
return [line, markers]
|
| 312 |
|
| 313 |
anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
|
|
|
|
| 325 |
video_writer.release()
|
| 326 |
return str(out)
|
| 327 |
|
| 328 |
+
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path, context: Dict) -> str:
|
| 329 |
try:
|
| 330 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 331 |
chart_generator = ChartGenerator(llm, df)
|
| 332 |
+
chart_spec = chart_generator.generate_chart_spec(desc, context)
|
| 333 |
return animate_chart(chart_spec, df, dur, out)
|
| 334 |
except Exception as e:
|
| 335 |
logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
|
|
|
|
| 337 |
temp_png = Path(temp_png_file.name)
|
| 338 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 339 |
chart_generator = ChartGenerator(llm, df)
|
| 340 |
+
chart_spec = chart_generator.generate_chart_spec(desc, context)
|
| 341 |
if execute_chart_spec(chart_spec, df, temp_png):
|
| 342 |
img = cv2.imread(str(temp_png)); os.unlink(temp_png)
|
| 343 |
img_resized = cv2.resize(img, (WIDTH, HEIGHT))
|
|
|
|
| 360 |
finally:
|
| 361 |
list_file.unlink(missing_ok=True)
|
| 362 |
|
| 363 |
+
# --- Main Business Logic ---
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
def sanitize_for_firebase_key(text: str) -> str:
|
|
|
|
| 366 |
forbidden_chars = ['.', '$', '#', '[', ']', '/']
|
| 367 |
for char in forbidden_chars:
|
| 368 |
text = text.replace(char, '_')
|
| 369 |
return text
|
| 370 |
|
| 371 |
+
def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
|
| 372 |
+
"""Creates a detailed summary of the dataframe for the AI."""
|
| 373 |
+
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
|
| 374 |
+
categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
|
| 375 |
+
|
| 376 |
+
context = {
|
| 377 |
+
"user_context": user_ctx,
|
| 378 |
+
"dataset_shape": {"rows": df.shape[0], "columns": df.shape[1]},
|
| 379 |
+
"schema": {
|
| 380 |
+
"numeric_columns": numeric_cols,
|
| 381 |
+
"categorical_columns": categorical_cols
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
},
|
| 383 |
+
"data_previews": {}
|
|
|
|
|
|
|
|
|
|
| 384 |
}
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
+
for col in categorical_cols[:5]:
|
| 387 |
+
unique_vals = df[col].unique()
|
| 388 |
+
context["data_previews"][col] = {
|
| 389 |
+
"count": len(unique_vals),
|
| 390 |
+
"values": unique_vals[:5].tolist()
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
for col in numeric_cols[:5]:
|
| 394 |
+
context["data_previews"][col] = {
|
| 395 |
+
"mean": df[col].mean(),
|
| 396 |
+
"min": df[col].min(),
|
| 397 |
+
"max": df[col].max()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
}
|
|
|
|
| 399 |
|
| 400 |
+
return json.loads(json.dumps(context, default=str))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
+
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
|
| 403 |
+
logging.info(f"Generating report draft for project {project_id}")
|
| 404 |
+
df = load_dataframe_safely(buf, name)
|
| 405 |
+
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
+
data_context_str = ""
|
| 408 |
+
context_for_charts = {}
|
| 409 |
+
try:
|
| 410 |
+
df_json = df.to_json(orient='records')
|
| 411 |
+
estimated_tokens = len(df_json) / 4
|
| 412 |
+
if estimated_tokens < MAX_CONTEXT_TOKENS:
|
| 413 |
+
logging.info(f"Dataset is small enough ({estimated_tokens:.0f} tokens). Using full JSON context.")
|
| 414 |
+
data_context_str = f"Here is the full dataset in JSON format:\n{df_json}"
|
| 415 |
+
context_for_charts = get_augmented_context(df, ctx)
|
| 416 |
+
else:
|
| 417 |
+
raise ValueError("Dataset too large for full context.")
|
| 418 |
+
except Exception as e:
|
| 419 |
+
logging.warning(f"Could not use full JSON context ({e}). Falling back to augmented summary.")
|
| 420 |
+
augmented_context = get_augmented_context(df, ctx)
|
| 421 |
+
data_context_str = f"The full dataset is too large to display. Here is a detailed summary:\n{json.dumps(augmented_context, indent=2)}"
|
| 422 |
+
context_for_charts = augmented_context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
+
report_prompt = f"""
|
| 425 |
+
You are an expert data analyst and business intelligence storyteller. Your mission is to analyze the provided data context and write a comprehensive, executive-level report in Markdown format.
|
| 426 |
|
| 427 |
+
**Data Context:**
|
| 428 |
+
{data_context_str}
|
| 429 |
|
| 430 |
+
**Critical Instructions:**
|
| 431 |
+
1. **Data Grounding:** Your entire analysis and narrative **must strictly** use the column names and data provided in the 'Data Context' section. Do not invent, modify, or assume any column names that are not on this list. This is the most important rule.
|
| 432 |
+
2. **Report Goal:** Create a well-structured, professional report in Markdown that tells a compelling story from the data. The structure of the report is entirely up to you, but it should be logical and easy to follow.
|
| 433 |
+
3. **Visual Support:** Wherever a key finding, trend, or significant point is made in your narrative, you **must** support it with a chart tag using the format: `<generate_chart: "chart_type | a specific, compelling description">`.
|
| 434 |
+
4. **Chart Tag Grounding:** The column names used in your chart descriptions **must** also be an exact match from the provided data context.
|
| 435 |
+
5. **Available Chart Types:** `bar, pie, line, scatter, hist, heatmap, area, bubble`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
|
| 437 |
+
Now, generate the complete Markdown report.
|
| 438 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
md = llm.invoke(report_prompt).content
|
| 441 |
|
| 442 |
chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
|
| 443 |
chart_urls = {}
|
| 444 |
chart_generator = ChartGenerator(llm, df)
|
| 445 |
+
|
| 446 |
for desc in chart_descs:
|
| 447 |
safe_desc = sanitize_for_firebase_key(desc)
|
| 448 |
md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
|
| 449 |
md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
|
| 450 |
+
|
| 451 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 452 |
img_path = Path(temp_file.name)
|
| 453 |
try:
|
| 454 |
+
chart_spec = chart_generator.generate_chart_spec(desc, context_for_charts)
|
| 455 |
if execute_chart_spec(chart_spec, df, img_path):
|
| 456 |
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
|
| 457 |
blob = bucket.blob(blob_name)
|
| 458 |
blob.upload_from_filename(str(img_path))
|
| 459 |
chart_urls[safe_desc] = blob.public_url
|
|
|
|
| 460 |
finally:
|
| 461 |
if os.path.exists(img_path):
|
| 462 |
os.unlink(img_path)
|
| 463 |
+
|
| 464 |
+
return {"raw_md": md, "chartUrls": chart_urls, "data_context": context_for_charts}
|
| 465 |
|
| 466 |
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
|
| 467 |
logging.info(f"Generating single chart '{description}' for project {project_id}")
|
| 468 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 469 |
chart_generator = ChartGenerator(llm, df)
|
| 470 |
+
context = get_augmented_context(df, "")
|
| 471 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 472 |
img_path = Path(temp_file.name)
|
| 473 |
try:
|
| 474 |
+
chart_spec = chart_generator.generate_chart_spec(description, context)
|
| 475 |
if execute_chart_spec(chart_spec, df, img_path):
|
| 476 |
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
|
| 477 |
blob = bucket.blob(blob_name)
|
|
|
|
| 483 |
os.unlink(img_path)
|
| 484 |
return None
|
| 485 |
|
| 486 |
+
def generate_video_from_project(df: pd.DataFrame, raw_md: str, data_context: Dict, uid: str, project_id: str, voice_model: str, bucket):
|
|
|
|
| 487 |
logging.info(f"Generating video for project {project_id} with voice {voice_model}")
|
| 488 |
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
|
| 489 |
|
|
|
|
| 490 |
story_prompt = f"""
|
| 491 |
Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
|
| 492 |
1. The first scene MUST be an "Introduction". It must contain narration and a stock video tag like: <generate_stock_video: "search query">.
|
| 493 |
2. The last scene MUST be a "Conclusion". It must also contain narration and a stock video tag.
|
| 494 |
3. The middle scenes should each contain narration and one chart tag from the report.
|
| 495 |
4. Separate each scene with '[SCENE_BREAK]'.
|
|
|
|
| 496 |
Report: {raw_md}
|
|
|
|
| 497 |
Only output the script, no extra text.
|
| 498 |
"""
|
| 499 |
script = llm.invoke(story_prompt).content
|
| 500 |
scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
|
| 501 |
video_parts, audio_parts, temps = [], [], []
|
| 502 |
+
total_audio_duration = 0.0
|
| 503 |
|
| 504 |
for i, sc in enumerate(scenes):
|
| 505 |
chart_descs = extract_chart_tags(sc)
|
|
|
|
| 512 |
|
| 513 |
audio_bytes = deepgram_tts(narrative, voice_model)
|
| 514 |
mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
|
| 515 |
+
audio_dur = 5.0
|
| 516 |
if audio_bytes:
|
| 517 |
+
mp3.write_bytes(audio_bytes)
|
| 518 |
+
audio_dur = audio_duration(str(mp3))
|
| 519 |
+
if audio_dur <= 0.1: audio_dur = 5.0
|
| 520 |
else:
|
| 521 |
+
generate_silence_mp3(audio_dur, mp3)
|
| 522 |
+
|
| 523 |
audio_parts.append(str(mp3)); temps.append(mp3)
|
| 524 |
+
total_audio_duration += audio_dur
|
| 525 |
|
| 526 |
+
video_dur = audio_dur + 0.5
|
| 527 |
mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
|
| 528 |
video_generated = False
|
| 529 |
|
| 530 |
if pexels_descs:
|
| 531 |
+
video_path = search_and_download_pexels_video(pexels_descs[0], video_dur, mp4)
|
|
|
|
| 532 |
if video_path:
|
| 533 |
+
video_parts.append(video_path); temps.append(Path(video_path))
|
|
|
|
| 534 |
video_generated = True
|
| 535 |
|
| 536 |
if not video_generated and chart_descs:
|
| 537 |
+
safe_chart(chart_descs[0], df, video_dur, mp4, data_context)
|
|
|
|
| 538 |
video_parts.append(str(mp4)); temps.append(mp4)
|
| 539 |
video_generated = True
|
| 540 |
|
| 541 |
if not video_generated:
|
|
|
|
| 542 |
img = generate_image_from_prompt(narrative)
|
| 543 |
img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
|
| 544 |
+
animate_image_fade(img_cv, video_dur, mp4)
|
| 545 |
video_parts.append(str(mp4)); temps.append(mp4)
|
| 546 |
|
| 547 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
|
|
|
|
| 555 |
concat_media(video_parts, silent_vid_path)
|
| 556 |
concat_media(audio_parts, audio_mix_path)
|
| 557 |
|
| 558 |
+
cmd = [
|
| 559 |
+
"ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
|
| 560 |
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
|
| 561 |
+
"-map", "0:v:0", "-map", "1:a:0",
|
| 562 |
+
"-t", f"{total_audio_duration:.3f}",
|
| 563 |
+
str(final_vid_path)
|
| 564 |
+
]
|
| 565 |
+
subprocess.run(cmd, check=True, capture_output=True)
|
| 566 |
|
| 567 |
blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
|
| 568 |
blob = bucket.blob(blob_name)
|