rairo committed on
Commit
78debf6
·
verified ·
1 Parent(s): 034567d

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +111 -341
sozo_gen.py CHANGED
@@ -13,8 +13,8 @@ import matplotlib
13
  matplotlib.use("Agg")
14
  import matplotlib.pyplot as plt
15
  from matplotlib.animation import FuncAnimation, FFMpegWriter
16
- import seaborn as sns # Added for heatmaps
17
- from scipy import stats # Added for scatterplot regression
18
  from PIL import Image
19
  import cv2
20
  import inspect
@@ -29,13 +29,13 @@ import requests
29
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
30
  FPS, WIDTH, HEIGHT = 24, 1280, 720
31
  MAX_CHARTS, VIDEO_SCENES = 5, 5
 
32
 
33
  # --- API Initialization ---
34
  API_KEY = os.getenv("GOOGLE_API_KEY")
35
  if not API_KEY:
36
  raise ValueError("GOOGLE_API_KEY environment variable not set.")
37
 
38
- # NEW: Pexels API Key
39
  PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
40
 
41
  # --- Helper Functions ---
@@ -68,13 +68,11 @@ def audio_duration(path: str) -> float:
68
  return float(res.stdout.strip())
69
  except Exception: return 5.0
70
 
71
- # UPDATED: Regex for chart tags and NEW regex for stock video tags
72
  TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
73
  TAG_RE_PEXELS = re.compile( r'[<[]\s*generate_?stock_?video\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
74
  extract_chart_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")) )
75
  extract_pexels_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE_PEXELS.finditer(t or "")) )
76
 
77
-
78
  re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
79
  def clean_narration(txt: str) -> str:
80
  txt = TAG_RE.sub("", txt); txt = TAG_RE_PEXELS.sub("", txt); txt = re_scene.sub("", txt)
@@ -98,7 +96,6 @@ def generate_image_from_prompt(prompt: str) -> Image.Image:
98
  except Exception:
99
  return placeholder_img()
100
 
101
- # NEW: Pexels video search and download function
102
  def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
103
  if not PEXELS_API_KEY:
104
  logging.warning("PEXELS_API_KEY not set. Cannot fetch stock video.")
@@ -113,7 +110,6 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
113
  logging.warning(f"No Pexels videos found for query: '{query}'")
114
  return None
115
 
116
- # Find a suitable video file (prefer HD)
117
  video_to_download = None
118
  for video in videos:
119
  for f in video.get('video_files', []):
@@ -127,7 +123,6 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
127
  logging.warning(f"No suitable HD video file found for query: '{query}'")
128
  return None
129
 
130
- # Download to a temporary file
131
  with requests.get(video_to_download, stream=True, timeout=60) as r:
132
  r.raise_for_status()
133
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_dl_file:
@@ -135,7 +130,6 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
135
  temp_dl_file.write(chunk)
136
  temp_dl_path = Path(temp_dl_file.name)
137
 
138
- # Use FFmpeg to resize, crop, and trim to exact duration
139
  cmd = [
140
  "ffmpeg", "-y", "-i", str(temp_dl_path),
141
  "-vf", f"scale={WIDTH}:{HEIGHT}:force_original_aspect_ratio=decrease,pad={WIDTH}:{HEIGHT}:(ow-iw)/2:(oh-ih)/2,setsar=1",
@@ -154,28 +148,19 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
154
  return None
155
 
156
  # --- Chart Generation System ---
157
- # UPDATED: ChartSpecification to include size_col for bubble charts
158
class ChartSpecification:
    """Declarative description of a single chart to render.

    Pure data container: chart type, title, source columns, optional
    aggregation/filter settings, and a color scheme. No logic beyond
    defaulting the aggregation method.
    """

    def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
        self.chart_type = chart_type
        self.title = title
        self.x_col = x_col
        self.y_col = y_col
        self.size_col = size_col  # only meaningful for bubble charts
        # Fall back to summing when no aggregation method was requested.
        self.agg_method = agg_method or "sum"
        self.filter_condition = filter_condition
        self.top_n = top_n
        self.color_scheme = color_scheme
162
 
163
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
    """Return a copy of *ctx_dict* augmented with the dataframe's column split.

    Adds two keys: ``numeric_columns`` (columns with a numeric dtype) and
    ``categorical_columns`` (every non-numeric column). The input dict is
    not mutated.
    """
    enriched = dict(ctx_dict)
    enriched.update({
        "numeric_columns": df.select_dtypes(include=['number']).columns.tolist(),
        "categorical_columns": df.select_dtypes(exclude=['number']).columns.tolist(),
    })
    return enriched
167
-
168
  class ChartGenerator:
169
  def __init__(self, llm, df: pd.DataFrame):
170
  self.llm = llm; self.df = df
171
- self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape, "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}})
172
 
173
- def generate_chart_spec(self, description: str) -> ChartSpecification:
174
- safe_ctx = json_serializable(self.enhanced_ctx)
175
- # UPDATED: Prompt to include new chart types
176
  spec_prompt = f"""
177
- You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
178
- **Dataset Info:** {json.dumps(safe_ctx, indent=2)}
179
  **Chart Request:** {description}
180
  **Return a JSON specification with these exact fields:**
181
  {{
@@ -187,7 +172,7 @@ class ChartGenerator:
187
  "agg_method": "sum|mean|count|max|min|null",
188
  "top_n": "number_for_top_n_filtering_or_null"
189
  }}
190
- Return only the JSON specification, no additional text. For heatmaps, x_col and y_col can be null if it's a correlation matrix of all numeric columns.
191
  """
192
  try:
193
  response = self.llm.invoke(spec_prompt).content.strip()
@@ -199,18 +184,15 @@ class ChartGenerator:
199
  return ChartSpecification(**filtered_dict)
200
  except Exception as e:
201
  logging.error(f"Spec generation failed: {e}. Using fallback.")
202
- return self._create_fallback_spec(description)
203
-
204
- def _create_fallback_spec(self, description: str) -> ChartSpecification:
205
- numeric_cols = self.enhanced_ctx['numeric_columns']; categorical_cols = self.enhanced_ctx['categorical_columns']
206
- ctype = "bar"
207
- for t in ["pie", "line", "scatter", "hist", "heatmap", "area", "bubble"]:
208
- if t in description.lower(): ctype = t
209
- x = categorical_cols[0] if categorical_cols else self.df.columns[0]
210
- y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
211
- return ChartSpecification(ctype, description, x, y)
212
-
213
- # UPDATED: execute_chart_spec to include new chart types
214
  def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
215
  try:
216
  plot_data = prepare_plot_data(spec, df)
@@ -231,7 +213,6 @@ def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path:
231
  return True
232
  except Exception as e: logging.error(f"Static chart generation failed for '{spec.title}': {e}"); return False
233
 
234
- # UPDATED: prepare_plot_data to handle new chart types
235
  def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
236
  if spec.chart_type not in ["heatmap"]:
237
  if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns): raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
@@ -253,7 +234,6 @@ def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
253
  return df[spec.x_col]
254
 
255
  # --- Animation & Video Generation ---
256
- # UPDATED: animate_chart with enhanced animations and new chart types
257
  def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
258
  plot_data = prepare_plot_data(spec, df)
259
  frames = max(10, int(dur * fps))
@@ -276,30 +256,25 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
276
  return bars
277
  elif ctype == "scatter":
278
  x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
279
- # Calculate regression line
280
  slope, intercept, _, _, _ = stats.linregress(x_full, y_full)
281
  reg_line_x = np.array([x_full.min(), x_full.max()])
282
  reg_line_y = slope * reg_line_x + intercept
283
 
284
  scat = ax.scatter([], [], alpha=0.7, color='#F18F01')
285
- line, = ax.plot([], [], 'r--', lw=2) # Regression line
286
  ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
287
  ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
288
 
289
  def init():
290
- scat.set_offsets(np.empty((0, 2)))
291
- line.set_data([], [])
292
  return [scat, line]
293
  def update(i):
294
- # Animate points for the first 70% of frames
295
  point_frames = int(frames * 0.7)
296
  if i <= point_frames:
297
  k = max(1, int(len(x_full) * (i / point_frames)))
298
  scat.set_offsets(plot_data.iloc[:k].values)
299
- # Animate regression line for the last 30%
300
  else:
301
- line_frame = i - point_frames
302
- line_total_frames = frames - point_frames
303
  current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
304
  line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
305
  return [scat, line]
@@ -320,32 +295,19 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
320
  k = max(2, int(len(x_full) * (i / (frames - 1))))
321
  fill = ax.fill_between(x_full[:k], y_full[:k], color="#4E79A7", alpha=0.4)
322
  return [fill]
323
- elif ctype == "heatmap":
324
- sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax, alpha=0)
325
- ax.set_title(spec.title)
326
- def init(): ax.collections[0].set_alpha(0); return [ax.collections[0]]
327
- def update(i): ax.collections[0].set_alpha(i / (frames - 1)); return [ax.collections[0]]
328
- elif ctype == "bubble":
329
- sizes = (plot_data[spec.size_col] - plot_data[spec.size_col].min() + 1) / (plot_data[spec.size_col].max() - plot_data[spec.size_col].min() + 1) * 2000 + 50
330
- scat = ax.scatter(plot_data[spec.x_col], plot_data[spec.y_col], s=sizes, alpha=0, color='#59A14F')
331
- ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
332
- def init(): scat.set_alpha(0); return [scat]
333
- def update(i): scat.set_alpha(i / (frames - 1) * 0.7); return [scat]
334
  else: # line (Time Series)
335
  line, = ax.plot([], [], lw=2, color='#A23B72')
336
- markers, = ax.plot([], [], 'o', color='#A23B72', markersize=5) # Animated markers
337
  plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
338
  x_full, y_full = plot_data.index, plot_data.values
339
  ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
340
  ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
341
  def init():
342
- line.set_data([], [])
343
- markers.set_data([], [])
344
  return [line, markers]
345
  def update(i):
346
  k = max(2, int(len(x_full) * (i / (frames - 1))))
347
- line.set_data(x_full[:k], y_full[:k])
348
- markers.set_data(x_full[:k], y_full[:k])
349
  return [line, markers]
350
 
351
  anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
@@ -363,11 +325,11 @@ def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) ->
363
  video_writer.release()
364
  return str(out)
365
 
366
- def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
367
  try:
368
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
369
  chart_generator = ChartGenerator(llm, df)
370
- chart_spec = chart_generator.generate_chart_spec(desc)
371
  return animate_chart(chart_spec, df, dur, out)
372
  except Exception as e:
373
  logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
@@ -375,7 +337,7 @@ def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
375
  temp_png = Path(temp_png_file.name)
376
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
377
  chart_generator = ChartGenerator(llm, df)
378
- chart_spec = chart_generator.generate_chart_spec(desc)
379
  if execute_chart_spec(chart_spec, df, temp_png):
380
  img = cv2.imread(str(temp_png)); os.unlink(temp_png)
381
  img_resized = cv2.resize(img, (WIDTH, HEIGHT))
@@ -398,310 +360,118 @@ def concat_media(file_paths: List[str], output_path: Path):
398
  finally:
399
  list_file.unlink(missing_ok=True)
400
 
401
- # --- Main Business Logic Functions ---
402
- # This section containing generate_report_draft and its helpers is left unchanged as requested.
403
- # ... (all functions from sanitize_for_firebase_key to generate_single_chart) ...
404
- # The following functions are preserved exactly as they were in the original code provided.
405
 
406
def sanitize_for_firebase_key(text: str) -> str:
    """Replace every Firebase-forbidden key character in *text* with '_'.

    Firebase Realtime Database keys may not contain '.', '$', '#', '[',
    ']' or '/', so each occurrence is mapped to an underscore.
    """
    translation = str.maketrans({bad: '_' for bad in '.$#[]/'})
    return text.translate(translation)
412
 
413
def analyze_data_intelligence(df: pd.DataFrame, ctx_dict: Dict) -> Dict[str, Any]:
    """
    Autonomous data intelligence system that classifies domain,
    detects patterns, and determines optimal analytical approach.

    Args:
        df: Dataset to profile.
        ctx_dict: Caller-supplied context. NOTE(review): not read anywhere
            in this function's body — kept for interface compatibility.

    Returns:
        Dict with keys 'primary_domain', 'domain_confidence',
        'data_structure', 'statistical_profile', 'patterns',
        'insight_opportunities' and 'narrative_suggestions'.
    """

    # Domain Classification Engine: keyword lists whose presence in a
    # column name votes for a business domain.
    domain_signals = {
        'financial': ['amount', 'price', 'cost', 'revenue', 'profit', 'balance', 'transaction', 'payment'],
        'survey': ['rating', 'satisfaction', 'score', 'response', 'feedback', 'opinion', 'agree', 'likert'],
        'scientific': ['measurement', 'experiment', 'trial', 'test', 'control', 'variable', 'hypothesis'],
        'marketing': ['campaign', 'conversion', 'click', 'impression', 'engagement', 'customer', 'segment'],
        'operational': ['performance', 'efficiency', 'throughput', 'capacity', 'utilization', 'process'],
        'temporal': ['date', 'time', 'timestamp', 'period', 'month', 'year', 'day', 'hour']
    }

    # Analyze column patterns (case-insensitive keyword matching)
    columns_lower = [col.lower() for col in df.columns]
    domain_scores = {}

    for domain, keywords in domain_signals.items():
        # Each column contributes at most 1 to a domain's score, even if it
        # matches several of that domain's keywords.
        score = sum(1 for col in columns_lower if any(keyword in col for keyword in keywords))
        domain_scores[domain] = score

    # Determine primary domain; fall back to 'general' when nothing matched
    primary_domain = max(domain_scores, key=domain_scores.get) if max(domain_scores.values()) > 0 else 'general'

    # Data Structure Analysis
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
    datetime_cols = df.select_dtypes(include=['datetime64']).columns.tolist()

    # Detect time series: real datetime dtypes OR date/time-ish column names
    is_timeseries = len(datetime_cols) > 0 or any('date' in col.lower() or 'time' in col.lower() for col in columns_lower)

    # Statistical Profile (best-effort: any failure yields an 'error' marker)
    statistical_summary = {}
    if numeric_cols:
        try:
            # Per-column maximum absolute correlation against any column
            # (note: includes the column's self-correlation of 1.0).
            correlations = df[numeric_cols].corr().abs().max()
            correlations_dict = {k: float(v) if pd.notna(v) else 0.0 for k, v in correlations.to_dict().items()}

            distributions = {}
            for col in numeric_cols:
                if len(df[col].dropna()) > 8:  # normaltest needs a minimum sample size
                    try:
                        p_value = stats.normaltest(df[col].dropna())[1]
                        distributions[col] = 'normal' if p_value > 0.05 else 'non_normal'
                    except:
                        distributions[col] = 'unknown'

            outliers = {}
            for col in numeric_cols:
                if len(df[col].dropna()) > 0:
                    try:
                        # Count values more than 3 standard deviations out.
                        z_scores = np.abs(stats.zscore(df[col].dropna()))
                        outliers[col] = int(len(df[col][z_scores > 3]))
                    except:
                        outliers[col] = 0

            statistical_summary = {
                'correlations': correlations_dict,
                'distributions': distributions,
                'outliers': outliers
            }
        except Exception as e:
            statistical_summary = {'error': 'Could not compute statistical summary'}

    # Pattern Detection (coarse boolean flags used downstream for narration)
    patterns = {
        'has_missing_data': df.isnull().sum().sum() > 0,
        'has_duplicates': df.duplicated().sum() > 0,
        'has_negative_values': any(df[col].min() < 0 for col in numeric_cols if len(df[col].dropna()) > 0),
        'has_categorical_hierarchy': any(len(df[col].unique()) > 10 for col in categorical_cols),
        'potential_segments': len(categorical_cols) > 0
    }

    # Insight Opportunities: which analyses the data can support
    insight_opportunities = []

    if is_timeseries:
        insight_opportunities.append("temporal_trends")

    if len(numeric_cols) > 1:
        insight_opportunities.append("correlations")

    if len(categorical_cols) > 0 and len(numeric_cols) > 0:
        insight_opportunities.append("segmentation")

    if any(statistical_summary.get('outliers', {}).values()):
        insight_opportunities.append("anomalies")

    return {
        'primary_domain': primary_domain,
        'domain_confidence': domain_scores,
        'data_structure': {
            'is_timeseries': is_timeseries,
            'numeric_cols': numeric_cols,
            'categorical_cols': categorical_cols,
            'datetime_cols': datetime_cols
        },
        'statistical_profile': statistical_summary,
        'patterns': patterns,
        'insight_opportunities': insight_opportunities,
        'narrative_suggestions': get_narrative_suggestions(primary_domain, insight_opportunities, patterns)
    }
519
-
520
def get_narrative_suggestions(domain: str, opportunities: List[str], patterns: Dict) -> Dict[str, str]:
    """Pick the storytelling framework (hook/structure/focus) for *domain*.

    Unknown domains fall back to the 'general' framework. The
    *opportunities* and *patterns* arguments are accepted for interface
    compatibility but do not affect the selection here.
    """
    frameworks = {
        'financial': {
            'hook': "Follow the money trail that reveals your business's hidden opportunities",
            'structure': "performance → trends → risks → opportunities",
            'focus': "profitability, efficiency, growth patterns, risk indicators"
        },
        'survey': {
            'hook': "Your customers are speaking - here's what they're really saying",
            'structure': "sentiment segments → drivers → actions",
            'focus': "satisfaction drivers, demographic patterns, improvement areas"
        },
        'scientific': {
            'hook': "The data reveals relationships that challenge conventional thinking",
            'structure': "hypothesis → evidence → significance → implications",
            'focus': "statistical significance, correlations, experimental validity"
        },
        'marketing': {
            'hook': "Discover the customer journey patterns driving your growth",
            'structure': "performance → segments → optimization → strategy",
            'focus': "conversion funnels, customer segments, campaign effectiveness"
        },
        'operational': {
            'hook': "Operational excellence lives in the details - here's where to look",
            'structure': "efficiency → bottlenecks → optimization → impact",
            'focus': "process efficiency, capacity utilization, improvement opportunities"
        },
        'general': {
            'hook': "Every dataset tells a story - here's what yours is saying",
            'structure': "overview → patterns → insights → implications",
            'focus': "key patterns, significant relationships, actionable insights"
        }
    }

    fallback = frameworks['general']
    return frameworks.get(domain, fallback)
557
-
558
def json_serializable(obj):
    """Recursively convert *obj* into plain, JSON-serializable Python values.

    NumPy numeric scalars become floats, arrays become lists, booleans
    (NumPy or native) become native bools, dicts/lists/tuples are converted
    element-wise (tuples come back as lists), NaN-like scalars become None,
    and anything else is returned unchanged.
    """
    if isinstance(obj, (np.integer, np.floating)):
        # NumPy ints are widened to float too, matching JSON's number type.
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (np.bool_, bool)):
        return bool(obj)
    if isinstance(obj, dict):
        return {key: json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [json_serializable(element) for element in obj]
    if pd.isna(obj):
        # Scalar NaN / NaT / None → JSON null.
        return None
    return obj
574
 
575
def create_autonomous_prompt(df: pd.DataFrame, enhanced_ctx: Dict, intelligence: Dict) -> str:
    """
    Generate a dynamic, intelligence-driven prompt that creates compelling narratives
    rather than following templates.

    Args:
        df: Dataset being reported on. NOTE(review): not read in this body —
            kept for interface compatibility.
        enhanced_ctx: Dataset context dict; embedded into the prompt as JSON.
        intelligence: Output of analyze_data_intelligence(); must provide
            'primary_domain', 'insight_opportunities', 'narrative_suggestions'
            and 'data_structure'.

    Returns:
        A single prompt string for the report-writing LLM.
    """

    domain = intelligence['primary_domain']
    opportunities = intelligence['insight_opportunities']
    narrative = intelligence['narrative_suggestions']

    # Dynamic chart strategy based on data characteristics
    chart_strategy = generate_chart_strategy(intelligence)

    # Make context JSON serializable (NumPy scalars/arrays → plain Python)
    serializable_ctx = json_serializable(enhanced_ctx)

    prompt = f"""You are an elite data storyteller with deep expertise in {domain} analytics. Your mission is to uncover the compelling narrative hidden in this dataset and present it as a captivating story that drives action.

**THE DATA'S STORY CONTEXT:**
{json.dumps(serializable_ctx, indent=2)}

**INTELLIGENCE ANALYSIS:**
- Primary Domain: {domain}
- Key Opportunities: {', '.join(opportunities)}
- Data Characteristics: {json_serializable(intelligence['data_structure'])}
- Narrative Framework: {narrative['structure']}

**YOUR STORYTELLING MISSION:**
{narrative['hook']}

**NARRATIVE CONSTRUCTION GUIDELINES:**
1. **LEAD WITH INTRIGUE**: Start with the most compelling finding that hooks the reader
2. **BUILD TENSION**: Present contrasts, surprises, or unexpected patterns
3. **REVEAL INSIGHTS**: Use data to resolve the tension with clear comprehensive explanations
4. **DRIVE ACTION**: End with specific, actionable recommendations

**VISUALIZATION STRATEGY:**
{chart_strategy}

**CRITICAL INSTRUCTIONS:**
- Write as if you're revealing a detective story, not filling a template
- Every insight must be explained and supported by data evidence
- Use compelling headers that create curiosity (not "Executive Summary")
- Weave charts naturally into the narrative flow
- Focus on business impact and actionable outcomes
- Let the data's personality shine through your writing style

**CHART INTEGRATION:**
Insert charts using: `<generate_chart: "chart_type | compelling description that advances the story">`
Available types: bar, pie, line, scatter, hist, heatmap, area, bubble

Transform this data into a story that decision-makers can't stop reading."""

    return prompt
 
629
 
630
def generate_chart_strategy(intelligence: Dict) -> str:
    """Compose a visualization-strategy sentence from data intelligence.

    Starts from a domain-specific base strategy (generic fallback for
    unknown domains) and appends extra guidance for time-series data,
    correlation opportunities, and segmentation opportunities.
    """
    domain = intelligence['primary_domain']
    opportunities = intelligence['insight_opportunities']
    structure = intelligence['data_structure']

    domain_strategies = {
        'financial': "Focus on trend lines showing performance over time, comparative bars for different categories, and scatter plots revealing correlations between financial metrics.",
        'survey': "Emphasize distribution histograms for satisfaction scores, segmented bar charts for demographic breakdowns, and correlation matrices for response patterns.",
        'scientific': "Prioritize scatter plots with regression lines, distribution comparisons, and statistical significance visualizations.",
        'marketing': "Highlight conversion funnels, customer segment comparisons, and campaign performance trends.",
        'operational': "Show efficiency trends, capacity utilization charts, and process performance comparisons."
    }

    parts = [domain_strategies.get(domain, "Create visualizations that best tell your data's unique story.")]

    # Bolt on guidance keyed to the data's specific characteristics.
    if structure['is_timeseries']:
        parts.append(" Leverage time-series visualizations like line or area charts to show trends and patterns over time.")
    if 'correlations' in opportunities:
        parts.append(" Include correlation visualizations like scatterplots or heatmaps to reveal hidden relationships.")
    if 'segmentation' in opportunities:
        parts.append(" Use segmented charts to highlight different groups or categories.")

    return "".join(parts)
658
 
659
- def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
660
- # This function remains unchanged as per the instructions.
661
- logging.info(f"Generating autonomous report draft for project {project_id}")
662
-
663
- df = load_dataframe_safely(buf, name)
664
- llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.1)
665
 
666
- ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
667
- enhanced_ctx = enhance_data_context(df, ctx_dict)
668
- intelligence = analyze_data_intelligence(df, ctx_dict)
669
- report_prompt = create_autonomous_prompt(df, enhanced_ctx, intelligence)
670
  md = llm.invoke(report_prompt).content
671
 
672
  chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
673
  chart_urls = {}
674
  chart_generator = ChartGenerator(llm, df)
675
-
676
  for desc in chart_descs:
677
  safe_desc = sanitize_for_firebase_key(desc)
678
  md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
679
  md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
680
-
681
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
682
  img_path = Path(temp_file.name)
683
  try:
684
- chart_spec = chart_generator.generate_chart_spec(desc)
685
  if execute_chart_spec(chart_spec, df, img_path):
686
  blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
687
  blob = bucket.blob(blob_name)
688
  blob.upload_from_filename(str(img_path))
689
  chart_urls[safe_desc] = blob.public_url
690
- logging.info(f"Uploaded chart '{desc}' to {blob.public_url} with safe key '{safe_desc}'")
691
  finally:
692
  if os.path.exists(img_path):
693
  os.unlink(img_path)
694
-
695
- return {"raw_md": md, "chartUrls": chart_urls}
696
 
697
  def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
698
  logging.info(f"Generating single chart '{description}' for project {project_id}")
699
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
700
  chart_generator = ChartGenerator(llm, df)
 
701
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
702
  img_path = Path(temp_file.name)
703
  try:
704
- chart_spec = chart_generator.generate_chart_spec(description)
705
  if execute_chart_spec(chart_spec, df, img_path):
706
  blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
707
  blob = bucket.blob(blob_name)
@@ -713,26 +483,23 @@ def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_
713
  os.unlink(img_path)
714
  return None
715
 
716
- # UPDATED: generate_video_from_project to handle Pexels integration
717
- def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
718
  logging.info(f"Generating video for project {project_id} with voice {voice_model}")
719
  llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
720
 
721
- # UPDATED: Prompt to create Intro/Conclusion scenes with stock video tags
722
  story_prompt = f"""
723
  Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
724
  1. The first scene MUST be an "Introduction". It must contain narration and a stock video tag like: <generate_stock_video: "search query">.
725
  2. The last scene MUST be a "Conclusion". It must also contain narration and a stock video tag.
726
  3. The middle scenes should each contain narration and one chart tag from the report.
727
  4. Separate each scene with '[SCENE_BREAK]'.
728
-
729
  Report: {raw_md}
730
-
731
  Only output the script, no extra text.
732
  """
733
  script = llm.invoke(story_prompt).content
734
  scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
735
  video_parts, audio_parts, temps = [], [], []
 
736
 
737
  for i, sc in enumerate(scenes):
738
  chart_descs = extract_chart_tags(sc)
@@ -745,35 +512,36 @@ def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project
745
 
746
  audio_bytes = deepgram_tts(narrative, voice_model)
747
  mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
 
748
  if audio_bytes:
749
- mp3.write_bytes(audio_bytes); dur = audio_duration(str(mp3))
750
- if dur <= 0.1: dur = 5.0
 
751
  else:
752
- dur = 5.0; generate_silence_mp3(dur, mp3)
 
753
  audio_parts.append(str(mp3)); temps.append(mp3)
 
754
 
 
755
  mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
756
  video_generated = False
757
 
758
  if pexels_descs:
759
- logging.info(f"Scene {i+1}: Found Pexels tag '{pexels_descs[0]}'. Searching for video.")
760
- video_path = search_and_download_pexels_video(pexels_descs[0], dur, mp4)
761
  if video_path:
762
- video_parts.append(video_path)
763
- temps.append(Path(video_path))
764
  video_generated = True
765
 
766
  if not video_generated and chart_descs:
767
- logging.info(f"Scene {i+1}: Found chart tag '{chart_descs[0]}'. Generating chart animation.")
768
- safe_chart(chart_descs[0], df, dur, mp4)
769
  video_parts.append(str(mp4)); temps.append(mp4)
770
  video_generated = True
771
 
772
  if not video_generated:
773
- logging.warning(f"Scene {i+1}: No valid chart or stock video tag found. Using fallback image.")
774
  img = generate_image_from_prompt(narrative)
775
  img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
776
- animate_image_fade(img_cv, dur, mp4)
777
  video_parts.append(str(mp4)); temps.append(mp4)
778
 
779
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
@@ -787,12 +555,14 @@ def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project
787
  concat_media(video_parts, silent_vid_path)
788
  concat_media(audio_parts, audio_mix_path)
789
 
790
- subprocess.run(
791
- ["ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
792
  "-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
793
- "-map", "0:v:0", "-map", "1:a:0", "-shortest", str(final_vid_path)],
794
- check=True, capture_output=True,
795
- )
 
 
796
 
797
  blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
798
  blob = bucket.blob(blob_name)
 
13
  matplotlib.use("Agg")
14
  import matplotlib.pyplot as plt
15
  from matplotlib.animation import FuncAnimation, FFMpegWriter
16
+ import seaborn as sns
17
+ from scipy import stats
18
  from PIL import Image
19
  import cv2
20
  import inspect
 
29
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
30
  FPS, WIDTH, HEIGHT = 24, 1280, 720
31
  MAX_CHARTS, VIDEO_SCENES = 5, 5
32
+ MAX_CONTEXT_TOKENS = 250000 # Set max token limit for full dataset context
33
 
34
  # --- API Initialization ---
35
  API_KEY = os.getenv("GOOGLE_API_KEY")
36
  if not API_KEY:
37
  raise ValueError("GOOGLE_API_KEY environment variable not set.")
38
 
 
39
  PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
40
 
41
  # --- Helper Functions ---
 
68
  return float(res.stdout.strip())
69
  except Exception: return 5.0
70
 
 
71
  TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
72
  TAG_RE_PEXELS = re.compile( r'[<[]\s*generate_?stock_?video\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
73
  extract_chart_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")) )
74
  extract_pexels_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE_PEXELS.finditer(t or "")) )
75
 
 
76
  re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
77
  def clean_narration(txt: str) -> str:
78
  txt = TAG_RE.sub("", txt); txt = TAG_RE_PEXELS.sub("", txt); txt = re_scene.sub("", txt)
 
96
  except Exception:
97
  return placeholder_img()
98
 
 
99
  def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
100
  if not PEXELS_API_KEY:
101
  logging.warning("PEXELS_API_KEY not set. Cannot fetch stock video.")
 
110
  logging.warning(f"No Pexels videos found for query: '{query}'")
111
  return None
112
 
 
113
  video_to_download = None
114
  for video in videos:
115
  for f in video.get('video_files', []):
 
123
  logging.warning(f"No suitable HD video file found for query: '{query}'")
124
  return None
125
 
 
126
  with requests.get(video_to_download, stream=True, timeout=60) as r:
127
  r.raise_for_status()
128
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_dl_file:
 
130
  temp_dl_file.write(chunk)
131
  temp_dl_path = Path(temp_dl_file.name)
132
 
 
133
  cmd = [
134
  "ffmpeg", "-y", "-i", str(temp_dl_path),
135
  "-vf", f"scale={WIDTH}:{HEIGHT}:force_original_aspect_ratio=decrease,pad={WIDTH}:{HEIGHT}:(ow-iw)/2:(oh-ih)/2,setsar=1",
 
148
  return None
149
 
150
  # --- Chart Generation System ---
 
151
  class ChartSpecification:
152
  def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
153
  self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col; self.size_col = size_col
154
  self.agg_method = agg_method or "sum"; self.filter_condition = filter_condition; self.top_n = top_n; self.color_scheme = color_scheme
155
 
 
 
 
 
 
156
  class ChartGenerator:
157
  def __init__(self, llm, df: pd.DataFrame):
158
  self.llm = llm; self.df = df
 
159
 
160
+ def generate_chart_spec(self, description: str, context: Dict) -> ChartSpecification:
 
 
161
  spec_prompt = f"""
162
+ You are a data visualization expert. Based on the dataset context and chart description, generate a precise chart specification.
163
+ **Dataset Context:** {json.dumps(context, indent=2)}
164
  **Chart Request:** {description}
165
  **Return a JSON specification with these exact fields:**
166
  {{
 
172
  "agg_method": "sum|mean|count|max|min|null",
173
  "top_n": "number_for_top_n_filtering_or_null"
174
  }}
175
+ Return only the JSON specification, no additional text.
176
  """
177
  try:
178
  response = self.llm.invoke(spec_prompt).content.strip()
 
184
  return ChartSpecification(**filtered_dict)
185
  except Exception as e:
186
  logging.error(f"Spec generation failed: {e}. Using fallback.")
187
+ numeric_cols = context.get('schema', {}).get('numeric_columns', list(self.df.select_dtypes(include=['number']).columns))
188
+ categorical_cols = context.get('schema', {}).get('categorical_columns', list(self.df.select_dtypes(exclude=['number']).columns))
189
+ ctype = "bar"
190
+ for t in ["pie", "line", "scatter", "hist", "heatmap", "area", "bubble"]:
191
+ if t in description.lower(): ctype = t
192
+ x = categorical_cols[0] if categorical_cols else self.df.columns[0]
193
+ y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
194
+ return ChartSpecification(ctype, description, x, y)
195
+
 
 
 
196
  def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
197
  try:
198
  plot_data = prepare_plot_data(spec, df)
 
213
  return True
214
  except Exception as e: logging.error(f"Static chart generation failed for '{spec.title}': {e}"); return False
215
 
 
216
  def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
217
  if spec.chart_type not in ["heatmap"]:
218
  if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns): raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
 
234
  return df[spec.x_col]
235
 
236
  # --- Animation & Video Generation ---
 
237
  def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
238
  plot_data = prepare_plot_data(spec, df)
239
  frames = max(10, int(dur * fps))
 
256
  return bars
257
  elif ctype == "scatter":
258
  x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
 
259
  slope, intercept, _, _, _ = stats.linregress(x_full, y_full)
260
  reg_line_x = np.array([x_full.min(), x_full.max()])
261
  reg_line_y = slope * reg_line_x + intercept
262
 
263
  scat = ax.scatter([], [], alpha=0.7, color='#F18F01')
264
+ line, = ax.plot([], [], 'r--', lw=2)
265
  ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
266
  ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
267
 
268
  def init():
269
+ scat.set_offsets(np.empty((0, 2))); line.set_data([], [])
 
270
  return [scat, line]
271
  def update(i):
 
272
  point_frames = int(frames * 0.7)
273
  if i <= point_frames:
274
  k = max(1, int(len(x_full) * (i / point_frames)))
275
  scat.set_offsets(plot_data.iloc[:k].values)
 
276
  else:
277
+ line_frame = i - point_frames; line_total_frames = frames - point_frames
 
278
  current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
279
  line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
280
  return [scat, line]
 
295
  k = max(2, int(len(x_full) * (i / (frames - 1))))
296
  fill = ax.fill_between(x_full[:k], y_full[:k], color="#4E79A7", alpha=0.4)
297
  return [fill]
 
 
 
 
 
 
 
 
 
 
 
298
  else: # line (Time Series)
299
  line, = ax.plot([], [], lw=2, color='#A23B72')
300
+ markers, = ax.plot([], [], 'o', color='#A23B72', markersize=5)
301
  plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
302
  x_full, y_full = plot_data.index, plot_data.values
303
  ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
304
  ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
305
  def init():
306
+ line.set_data([], []); markers.set_data([], [])
 
307
  return [line, markers]
308
  def update(i):
309
  k = max(2, int(len(x_full) * (i / (frames - 1))))
310
+ line.set_data(x_full[:k], y_full[:k]); markers.set_data(x_full[:k], y_full[:k])
 
311
  return [line, markers]
312
 
313
  anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
 
325
  video_writer.release()
326
  return str(out)
327
 
328
+ def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path, context: Dict) -> str:
329
  try:
330
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
331
  chart_generator = ChartGenerator(llm, df)
332
+ chart_spec = chart_generator.generate_chart_spec(desc, context)
333
  return animate_chart(chart_spec, df, dur, out)
334
  except Exception as e:
335
  logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
 
337
  temp_png = Path(temp_png_file.name)
338
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
339
  chart_generator = ChartGenerator(llm, df)
340
+ chart_spec = chart_generator.generate_chart_spec(desc, context)
341
  if execute_chart_spec(chart_spec, df, temp_png):
342
  img = cv2.imread(str(temp_png)); os.unlink(temp_png)
343
  img_resized = cv2.resize(img, (WIDTH, HEIGHT))
 
360
  finally:
361
  list_file.unlink(missing_ok=True)
362
 
363
+ # --- Main Business Logic ---
 
 
 
364
 
365
def sanitize_for_firebase_key(text: str) -> str:
    """Return *text* with characters Firebase forbids in keys replaced by '_'.

    Firebase Realtime Database keys may not contain '.', '$', '#',
    '[', ']', or '/'; each occurrence is mapped to an underscore.
    """
    translation = str.maketrans({bad: '_' for bad in '.$#[]/'})
    return text.translate(translation)
370
 
371
def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
    """Build a compact, JSON-serializable summary of *df* for LLM prompts.

    The summary contains the user-supplied context string, the dataset
    shape, the column schema split into numeric/categorical, and short
    previews (up to 5 columns of each kind): cardinality plus sample
    values for categoricals, mean/min/max for numerics.
    """
    num_cols = list(df.select_dtypes(include=['number']).columns)
    cat_cols = list(df.select_dtypes(exclude=['number']).columns)

    previews = {}
    # Preview at most 5 categorical columns: cardinality plus a few values.
    for column in cat_cols[:5]:
        distinct = df[column].unique()
        previews[column] = {
            "count": len(distinct),
            "values": distinct[:5].tolist()
        }
    # Preview at most 5 numeric columns with basic summary statistics.
    for column in num_cols[:5]:
        series = df[column]
        previews[column] = {
            "mean": series.mean(),
            "min": series.min(),
            "max": series.max()
        }

    summary = {
        "user_context": user_ctx,
        "dataset_shape": {"rows": df.shape[0], "columns": df.shape[1]},
        "schema": {
            "numeric_columns": num_cols,
            "categorical_columns": cat_cols
        },
        "data_previews": previews
    }

    # Round-trip through JSON (stringifying anything non-serializable via
    # default=str) so the result is safe to embed in prompts and to store.
    return json.loads(json.dumps(summary, default=str))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
    """Generate a Markdown analysis report (with rendered chart images) for a dataset.

    Loads the dataset from *buf*, asks the LLM for a Markdown report
    containing `<generate_chart: ...>` tags, renders up to MAX_CHARTS of
    those tags to PNGs, uploads them to the given storage *bucket*, and
    returns the report text, chart URLs keyed by sanitized description,
    and the data context used for chart generation.
    """
    logging.info(f"Generating report draft for project {project_id}")
    df = load_dataframe_safely(buf, name)
    llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.1)

    data_context_str = ""
    context_for_charts = {}
    # Prefer handing the LLM the full dataset as JSON; fall back to a
    # summarized context if the estimated token count is too large or
    # serialization fails for any reason.
    try:
        df_json = df.to_json(orient='records')
        # Rough heuristic: ~4 characters per token.
        estimated_tokens = len(df_json) / 4
        if estimated_tokens < MAX_CONTEXT_TOKENS:
            logging.info(f"Dataset is small enough ({estimated_tokens:.0f} tokens). Using full JSON context.")
            data_context_str = f"Here is the full dataset in JSON format:\n{df_json}"
            context_for_charts = get_augmented_context(df, ctx)
        else:
            # Raising here routes the oversized case through the same
            # fallback path as a serialization failure.
            raise ValueError("Dataset too large for full context.")
    except Exception as e:
        logging.warning(f"Could not use full JSON context ({e}). Falling back to augmented summary.")
        augmented_context = get_augmented_context(df, ctx)
        data_context_str = f"The full dataset is too large to display. Here is a detailed summary:\n{json.dumps(augmented_context, indent=2)}"
        context_for_charts = augmented_context

    report_prompt = f"""
    You are an expert data analyst and business intelligence storyteller. Your mission is to analyze the provided data context and write a comprehensive, executive-level report in Markdown format.

    **Data Context:**
    {data_context_str}

    **Critical Instructions:**
    1. **Data Grounding:** Your entire analysis and narrative **must strictly** use the column names and data provided in the 'Data Context' section. Do not invent, modify, or assume any column names that are not on this list. This is the most important rule.
    2. **Report Goal:** Create a well-structured, professional report in Markdown that tells a compelling story from the data. The structure of the report is entirely up to you, but it should be logical and easy to follow.
    3. **Visual Support:** Wherever a key finding, trend, or significant point is made in your narrative, you **must** support it with a chart tag using the format: `<generate_chart: "chart_type | a specific, compelling description">`.
    4. **Chart Tag Grounding:** The column names used in your chart descriptions **must** also be an exact match from the provided data context.
    5. **Available Chart Types:** `bar, pie, line, scatter, hist, heatmap, area, bubble`.

    Now, generate the complete Markdown report.
    """

    md = llm.invoke(report_prompt).content

    # Render at most MAX_CHARTS of the chart tags embedded in the report.
    chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
    chart_urls = {}
    chart_generator = ChartGenerator(llm, df)

    for desc in chart_descs:
        # Chart descriptions become Firebase map keys downstream, so the
        # report text is rewritten to use the sanitized form of each tag
        # (both the quoted and unquoted tag spellings are replaced).
        safe_desc = sanitize_for_firebase_key(desc)
        md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
        md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')

        # delete=False so the file survives the `with`; cleaned up in `finally`.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            img_path = Path(temp_file.name)
            try:
                chart_spec = chart_generator.generate_chart_spec(desc, context_for_charts)
                if execute_chart_spec(chart_spec, df, img_path):
                    blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
                    blob = bucket.blob(blob_name)
                    blob.upload_from_filename(str(img_path))
                    # NOTE(review): assumes the bucket/blob is publicly
                    # readable so public_url resolves — confirm bucket ACLs.
                    chart_urls[safe_desc] = blob.public_url
            finally:
                if os.path.exists(img_path):
                    os.unlink(img_path)

    return {"raw_md": md, "chartUrls": chart_urls, "data_context": context_for_charts}
465
 
466
  def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
467
  logging.info(f"Generating single chart '{description}' for project {project_id}")
468
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
469
  chart_generator = ChartGenerator(llm, df)
470
+ context = get_augmented_context(df, "")
471
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
472
  img_path = Path(temp_file.name)
473
  try:
474
+ chart_spec = chart_generator.generate_chart_spec(description, context)
475
  if execute_chart_spec(chart_spec, df, img_path):
476
  blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
477
  blob = bucket.blob(blob_name)
 
483
  os.unlink(img_path)
484
  return None
485
 
486
+ def generate_video_from_project(df: pd.DataFrame, raw_md: str, data_context: Dict, uid: str, project_id: str, voice_model: str, bucket):
 
487
  logging.info(f"Generating video for project {project_id} with voice {voice_model}")
488
  llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
489
 
 
490
  story_prompt = f"""
491
  Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
492
  1. The first scene MUST be an "Introduction". It must contain narration and a stock video tag like: <generate_stock_video: "search query">.
493
  2. The last scene MUST be a "Conclusion". It must also contain narration and a stock video tag.
494
  3. The middle scenes should each contain narration and one chart tag from the report.
495
  4. Separate each scene with '[SCENE_BREAK]'.
 
496
  Report: {raw_md}
 
497
  Only output the script, no extra text.
498
  """
499
  script = llm.invoke(story_prompt).content
500
  scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
501
  video_parts, audio_parts, temps = [], [], []
502
+ total_audio_duration = 0.0
503
 
504
  for i, sc in enumerate(scenes):
505
  chart_descs = extract_chart_tags(sc)
 
512
 
513
  audio_bytes = deepgram_tts(narrative, voice_model)
514
  mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
515
+ audio_dur = 5.0
516
  if audio_bytes:
517
+ mp3.write_bytes(audio_bytes)
518
+ audio_dur = audio_duration(str(mp3))
519
+ if audio_dur <= 0.1: audio_dur = 5.0
520
  else:
521
+ generate_silence_mp3(audio_dur, mp3)
522
+
523
  audio_parts.append(str(mp3)); temps.append(mp3)
524
+ total_audio_duration += audio_dur
525
 
526
+ video_dur = audio_dur + 0.5
527
  mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
528
  video_generated = False
529
 
530
  if pexels_descs:
531
+ video_path = search_and_download_pexels_video(pexels_descs[0], video_dur, mp4)
 
532
  if video_path:
533
+ video_parts.append(video_path); temps.append(Path(video_path))
 
534
  video_generated = True
535
 
536
  if not video_generated and chart_descs:
537
+ safe_chart(chart_descs[0], df, video_dur, mp4, data_context)
 
538
  video_parts.append(str(mp4)); temps.append(mp4)
539
  video_generated = True
540
 
541
  if not video_generated:
 
542
  img = generate_image_from_prompt(narrative)
543
  img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
544
+ animate_image_fade(img_cv, video_dur, mp4)
545
  video_parts.append(str(mp4)); temps.append(mp4)
546
 
547
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
 
555
  concat_media(video_parts, silent_vid_path)
556
  concat_media(audio_parts, audio_mix_path)
557
 
558
+ cmd = [
559
+ "ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
560
  "-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
561
+ "-map", "0:v:0", "-map", "1:a:0",
562
+ "-t", f"{total_audio_duration:.3f}",
563
+ str(final_vid_path)
564
+ ]
565
+ subprocess.run(cmd, check=True, capture_output=True)
566
 
567
  blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
568
  blob = bucket.blob(blob_name)