rairo commited on
Commit
f3f35a4
·
verified ·
1 Parent(s): 4408186

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +191 -286
sozo_gen.py CHANGED
@@ -1,5 +1,3 @@
1
- # sozo_gen.py
2
-
3
  import os
4
  import re
5
  import json
@@ -9,110 +7,128 @@ import io
9
  from pathlib import Path
10
  import pandas as pd
11
  import numpy as np
12
- import matplotlib
13
- matplotlib.use("Agg")
14
- import matplotlib.pyplot as plt
15
- from matplotlib.animation import FuncAnimation, FFMpegWriter
16
- from PIL import Image
17
- import cv2
18
- import inspect
19
- import tempfile
20
  import subprocess
21
- from typing import Dict, List
 
22
  from langchain_google_genai import ChatGoogleGenerativeAI
23
  from google import genai
24
  import requests
25
 
26
  # --- Configuration ---
27
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
28
- FPS, WIDTH, HEIGHT = 24, 1280, 720
29
  MAX_CHARTS, VIDEO_SCENES = 5, 5
30
 
31
  # --- Gemini API Initialization ---
32
- # CORRECTED: Use the correct environment variable name and remove the deprecated configure call.
33
  API_KEY = os.getenv("GOOGLE_API_KEY")
34
  if not API_KEY:
35
  raise ValueError("GOOGLE_API_KEY environment variable not set.")
36
- # REMOVED: genai.configure(api_key=API_KEY) - This is deprecated. The library now uses the environment variable automatically.
37
 
38
  # --- Helper Functions ---
39
- def load_dataframe_safely(buf, name: str):
 
40
  ext = Path(name).suffix.lower()
41
- df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
42
- df.columns = df.columns.astype(str).str.strip()
43
- df = df.dropna(how="all")
44
- if df.empty or len(df.columns) == 0: raise ValueError("No usable data found")
45
- return df
 
 
 
 
 
 
46
 
47
- def deepgram_tts(txt: str, voice_model: str):
 
48
  DG_KEY = os.getenv("DEEPGRAM_API_KEY")
49
- if not DG_KEY or not txt: return None
 
 
50
  txt = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
51
  try:
52
- r = requests.post("https://api.deepgram.com/v1/speak", params={"model": voice_model}, headers={"Authorization": f"Token {DG_KEY}", "Content-Type": "application/json"}, json={"text": txt}, timeout=30)
 
 
 
 
 
 
53
  r.raise_for_status()
54
  return r.content
55
  except Exception as e:
56
  logging.error(f"Deepgram TTS failed: {e}")
57
  return None
58
 
59
- def generate_silence_mp3(duration: float, out: Path):
60
- subprocess.run([ "ffmpeg", "-y", "-f", "lavfi", "-i", "anullsrc=r=44100:cl=mono", "-t", f"{duration:.3f}", "-q:a", "9", str(out)], check=True, capture_output=True)
61
-
62
- def audio_duration(path: str) -> float:
63
- try:
64
- res = subprocess.run([ "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=nw=1:nk=1", path], text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
65
- return float(res.stdout.strip())
66
- except Exception: return 5.0
67
-
68
- TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
69
- extract_chart_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")) )
70
 
71
  re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
72
  def clean_narration(txt: str) -> str:
73
- txt = TAG_RE.sub("", txt); txt = re_scene.sub("", txt)
 
 
74
  phrases_to_remove = [r"as you can see in the chart", r"this chart shows", r"the chart illustrates", r"in this visual", r"this graph displays"]
75
- for phrase in phrases_to_remove: txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
76
- txt = re.sub(r"\s*\([^)]*\)", "", txt); txt = re.sub(r"[\*#_]", "", txt)
 
 
77
  return re.sub(r"\s{2,}", " ", txt).strip()
78
 
79
- def placeholder_img() -> Image.Image: return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))
80
-
81
- def generate_image_from_prompt(prompt: str) -> Image.Image:
82
- model_main = "gemini-2.0-flash-exp-image-generation"; model_fallback = "gemini-2.0-flash-preview-image-generation"
83
- full_prompt = "A clean business-presentation illustration: " + prompt
84
- def fetch(model_name):
85
- try:
86
- model = genai.GenerativeModel(model_name)
87
- res = model.generate_content(full_prompt)
88
- for part in res.candidates[0].content.parts:
89
- if getattr(part, "inline_data", None):
90
- return Image.open(io.BytesIO(part.inline_data.data)).convert("RGB")
91
- return None
92
- except Exception:
93
- return None
94
- try:
95
- img = fetch(model_main) or fetch(model_fallback)
96
- return img if img else placeholder_img()
97
- except Exception: return placeholder_img()
98
-
99
  # --- Chart Generation System ---
100
  class ChartSpecification:
101
- def __init__(self, chart_type: str, title: str, x_col: str, y_col: str, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
102
- self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col
103
- self.agg_method = agg_method or "sum"; self.filter_condition = filter_condition; self.top_n = top_n; self.color_scheme = color_scheme
 
 
 
 
 
104
 
105
  def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
106
- enhanced_ctx = ctx_dict.copy(); numeric_cols = df.select_dtypes(include=['number']).columns.tolist(); categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
 
 
 
107
  enhanced_ctx.update({"numeric_columns": numeric_cols, "categorical_columns": categorical_cols})
108
  return enhanced_ctx
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  class ChartGenerator:
111
  def __init__(self, llm, df: pd.DataFrame):
112
- self.llm = llm; self.df = df
113
- self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape, "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}})
 
114
 
115
- def generate_chart_spec(self, description: str) -> ChartSpecification:
 
116
  spec_prompt = f"""
117
  You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
118
  **Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
@@ -129,246 +145,135 @@ class ChartGenerator:
129
  if response.startswith("```json"): response = response[7:-3]
130
  elif response.startswith("```"): response = response[3:-3]
131
  spec_dict = json.loads(response)
132
- valid_keys = [p.name for p in inspect.signature(ChartSpecification).parameters.values() if p.name not in ['reasoning', 'filter_condition', 'color_scheme']]
133
- filtered_dict = {k: v for k, v in spec_dict.items() if k in valid_keys}
134
- return ChartSpecification(**filtered_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  except Exception as e:
136
- logging.error(f"Spec generation failed: {e}. Using fallback.")
137
- return self._create_fallback_spec(description)
138
-
139
- def _create_fallback_spec(self, description: str) -> ChartSpecification:
140
- numeric_cols = self.enhanced_ctx['numeric_columns']; categorical_cols = self.enhanced_ctx['categorical_columns']
141
- ctype = "bar"
142
- for t in ["pie", "line", "scatter", "hist"]:
143
- if t in description.lower(): ctype = t
144
- x = categorical_cols[0] if categorical_cols else self.df.columns[0]
145
- y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
146
- return ChartSpecification(ctype, description, x, y)
147
-
148
- def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
149
- try:
150
- plot_data = prepare_plot_data(spec, df)
151
- fig, ax = plt.subplots(figsize=(12, 8)); plt.style.use('default')
152
- if spec.chart_type == "bar": ax.bar(plot_data.index.astype(str), plot_data.values, color='#2E86AB', alpha=0.8); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.tick_params(axis='x', rotation=45)
153
- elif spec.chart_type == "pie": ax.pie(plot_data.values, labels=plot_data.index, autopct='%1.1f%%', startangle=90); ax.axis('equal')
154
- elif spec.chart_type == "line": ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
155
- elif spec.chart_type == "scatter": ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
156
- elif spec.chart_type == "hist": ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black'); ax.set_xlabel(spec.x_col); ax.set_ylabel('Frequency'); ax.grid(True, alpha=0.3)
157
- ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20); plt.tight_layout()
158
- plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white'); plt.close()
159
- return True
160
- except Exception as e: logging.error(f"Static chart generation failed for '{spec.title}': {e}"); return False
161
-
162
- def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame) -> pd.Series:
163
- if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns): raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
164
- if spec.chart_type in ["bar", "pie"]:
165
- if not spec.y_col: return df[spec.x_col].value_counts().nlargest(spec.top_n or 10)
166
- grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
167
- return grouped.nlargest(spec.top_n or 10)
168
- elif spec.chart_type == "line": return df.set_index(spec.x_col)[spec.y_col].sort_index()
169
- elif spec.chart_type == "scatter": return df[[spec.x_col, spec.y_col]].dropna()
170
- elif spec.chart_type == "hist": return df[spec.x_col].dropna()
171
- return df[spec.x_col]
172
-
173
- # --- Animation & Video Generation ---
174
- def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
175
- plot_data = prepare_plot_data(spec, df)
176
- frames = max(10, int(dur * fps))
177
- fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
178
- plt.tight_layout(pad=3.0)
179
- ctype = spec.chart_type
180
- if ctype == "pie":
181
- wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
182
- ax.set_title(spec.title); ax.axis('equal')
183
- def init(): [w.set_alpha(0) for w in wedges]; return wedges
184
- def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return wedges
185
- elif ctype == "bar":
186
- bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
187
- ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
188
- ax.set_title(spec.title); plt.xticks(rotation=45, ha="right")
189
- def init(): return bars
190
- def update(i):
191
- for b, h in zip(bars, plot_data.values): b.set_height(h * (i / (frames - 1)))
192
- return bars
193
- elif ctype == "scatter":
194
- scat = ax.scatter([], [], alpha=0.7)
195
- x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
196
- ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
197
- ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
198
- def init(): scat.set_offsets(np.empty((0, 2))); return [scat]
199
- def update(i):
200
- k = max(1, int(len(x_full) * (i / (frames - 1))))
201
- scat.set_offsets(plot_data.iloc[:k].values); return [scat]
202
- elif ctype == "hist":
203
- _, _, patches = ax.hist(plot_data, bins=20, alpha=0)
204
- ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
205
- def init(): [p.set_alpha(0) for p in patches]; return patches
206
- def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return patches
207
- else: # line
208
- line, = ax.plot([], [], lw=2)
209
- plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
210
- x_full, y_full = plot_data.index, plot_data.values
211
- ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
212
- ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
213
- def init(): line.set_data([], []); return [line]
214
- def update(i):
215
- k = max(2, int(len(x_full) * (i / (frames - 1))))
216
- line.set_data(x_full[:k], y_full[:k]); return [line]
217
- anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
218
- anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
219
- plt.close(fig)
220
- return str(out)
221
-
222
- def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
223
- fourcc = cv2.VideoWriter_fourcc(*'mp4v'); video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
224
- total_frames = max(1, int(dur * fps))
225
- for i in range(total_frames):
226
- alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
227
- frame = cv2.addWeighted(img, alpha, np.zeros_like(img), 1 - alpha, 0)
228
- video_writer.write(frame)
229
- video_writer.release()
230
- return str(out)
231
-
232
- def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
233
- try:
234
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
235
- chart_generator = ChartGenerator(llm, df)
236
- chart_spec = chart_generator.generate_chart_spec(desc)
237
- return animate_chart(chart_spec, df, dur, out)
238
- except Exception as e:
239
- logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
240
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_png_file:
241
- temp_png = Path(temp_png_file.name)
242
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
243
- chart_generator = ChartGenerator(llm, df)
244
- chart_spec = chart_generator.generate_chart_spec(desc)
245
- if execute_chart_spec(chart_spec, df, temp_png):
246
- img = cv2.imread(str(temp_png)); os.unlink(temp_png)
247
- img_resized = cv2.resize(img, (WIDTH, HEIGHT))
248
- return animate_image_fade(img_resized, dur, out)
249
- else:
250
- img = generate_image_from_prompt(f"A professional business chart showing {desc}")
251
- img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
252
- return animate_image_fade(img_cv, dur, out)
253
-
254
- def concat_media(file_paths: List[str], output_path: Path):
255
- valid_paths = [p for p in file_paths if Path(p).exists() and Path(p).stat().st_size > 100]
256
- if not valid_paths: raise ValueError("No valid media files to concatenate.")
257
- if len(valid_paths) == 1: import shutil; shutil.copy2(valid_paths[0], str(output_path)); return
258
- list_file = output_path.with_suffix(".txt")
259
- with open(list_file, 'w') as f:
260
- for path in valid_paths: f.write(f"file '{Path(path).resolve()}'\n")
261
- cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(list_file), "-c", "copy", str(output_path)]
262
- try:
263
- subprocess.run(cmd, check=True, capture_output=True, text=True)
264
- finally:
265
- list_file.unlink(missing_ok=True)
266
 
267
- # --- Main Business Logic Functions for Flask ---
268
 
269
- def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
270
- logging.info(f"Generating report draft for project {project_id}")
271
- df = load_dataframe_safely(buf, name)
272
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
273
- ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
 
 
 
274
  enhanced_ctx = enhance_data_context(df, ctx_dict)
 
275
  report_prompt = f"""
276
- You are a senior data analyst and business intelligence expert. Analyze the provided dataset and write a comprehensive executive-level Markdown report.
277
  **Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
278
  **Instructions:**
279
  1. **Executive Summary**: Start with a high-level summary of key findings.
280
- 2. **Key Insights**: Provide 3-5 key insights, each with its own chart tag.
281
- 3. **Visual Support**: Insert chart tags like: `<generate_chart: "chart_type | specific description">`.
282
  Valid chart types: bar, pie, line, scatter, hist.
283
  Generate insights that would be valuable to C-level executives.
284
  """
285
  md = llm.invoke(report_prompt).content
286
  chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
287
- chart_urls = {}
 
288
  chart_generator = ChartGenerator(llm, df)
289
  for desc in chart_descs:
290
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
291
- img_path = Path(temp_file.name)
292
- try:
293
- chart_spec = chart_generator.generate_chart_spec(desc)
294
- if execute_chart_spec(chart_spec, df, img_path):
295
- blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
296
- blob = bucket.blob(blob_name)
297
- blob.upload_from_filename(str(img_path))
298
- chart_urls[desc] = blob.public_url
299
- logging.info(f"Uploaded chart '{desc}' to {blob.public_url}")
300
- finally:
301
- os.unlink(img_path)
302
- return {"raw_md": md, "chartUrls": chart_urls}
303
-
304
- def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
305
- logging.info(f"Generating single chart '{description}' for project {project_id}")
306
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
307
- chart_generator = ChartGenerator(llm, df)
308
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
309
- img_path = Path(temp_file.name)
310
  try:
311
- chart_spec = chart_generator.generate_chart_spec(description)
312
- if execute_chart_spec(chart_spec, df, img_path):
313
- blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
314
- blob = bucket.blob(blob_name)
315
- blob.upload_from_filename(str(img_path))
316
- logging.info(f"Uploaded single chart to {blob.public_url}")
317
- return blob.public_url
318
- finally:
319
- os.unlink(img_path)
320
- return None
321
-
322
- def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
323
- logging.info(f"Generating video for project {project_id} with voice {voice_model}")
324
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.2)
325
- story_prompt = f"Based on the following report, create a script for a {VIDEO_SCENES}-scene video. Each scene must be separated by '[SCENE_BREAK]' and contain narration and one chart tag. Report: {raw_md}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  script = llm.invoke(story_prompt).content
327
- scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
328
- video_parts, audio_parts, temps = [], [], []
329
- for sc in scenes:
330
- descs, narrative = extract_chart_tags(sc), clean_narration(sc)
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  audio_bytes = deepgram_tts(narrative, voice_model)
332
- mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
333
  if audio_bytes:
334
- mp3.write_bytes(audio_bytes); dur = audio_duration(str(mp3))
335
- if dur <= 0.1: dur = 5.0
336
  else:
337
- dur = 5.0; generate_silence_mp3(dur, mp3)
338
- audio_parts.append(str(mp3)); temps.append(mp3)
339
- mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
340
- if descs: safe_chart(descs[0], df, dur, mp4)
341
- else:
342
- img = generate_image_from_prompt(narrative)
343
- img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
344
- animate_image_fade(img_cv, dur, mp4)
345
- video_parts.append(str(mp4)); temps.append(mp4)
346
-
347
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
348
- tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_aud, \
349
- tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as final_vid:
350
-
351
- silent_vid_path = Path(temp_vid.name)
352
- audio_mix_path = Path(temp_aud.name)
353
- final_vid_path = Path(final_vid.name)
354
-
355
- concat_media(video_parts, silent_vid_path)
356
- concat_media(audio_parts, audio_mix_path)
357
-
358
- subprocess.run(
359
- ["ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
360
- "-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
361
- "-map", "0:v:0", "-map", "1:a:0", "-shortest", str(final_vid_path)],
362
- check=True, capture_output=True,
363
- )
364
-
365
- blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
366
- blob = bucket.blob(blob_name)
367
- blob.upload_from_filename(str(final_vid_path))
368
- logging.info(f"Uploaded video to {blob.public_url}")
369
 
370
- for p in temps + [silent_vid_path, audio_mix_path, final_vid_path]:
371
- if os.path.exists(p): os.unlink(p)
372
-
373
- return blob.public_url
374
- return None
 
 
 
1
  import os
2
  import re
3
  import json
 
7
  from pathlib import Path
8
  import pandas as pd
9
  import numpy as np
 
 
 
 
 
 
 
 
10
  import subprocess
11
+ from typing import Dict, List, Any
12
+
13
  from langchain_google_genai import ChatGoogleGenerativeAI
14
  from google import genai
15
  import requests
16
 
17
  # --- Configuration ---
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
 
19
  MAX_CHARTS, VIDEO_SCENES = 5, 5
20
 
21
  # --- Gemini API Initialization ---
 
22
  API_KEY = os.getenv("GOOGLE_API_KEY")
23
  if not API_KEY:
24
  raise ValueError("GOOGLE_API_KEY environment variable not set.")
 
25
 
26
  # --- Helper Functions ---
27
+ def load_dataframe_safely(buf, name: str) -> pd.DataFrame:
28
+ """Loads a dataframe from a buffer, handling CSV or Excel files."""
29
  ext = Path(name).suffix.lower()
30
+ try:
31
+ df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
32
+ df.columns = df.columns.astype(str).str.strip()
33
+ df = df.dropna(how="all")
34
+ if df.empty or len(df.columns) == 0:
35
+ raise ValueError("No usable data found in the file.")
36
+ # Convert entire dataframe to JSON-compatible types
37
+ return df.replace({np.nan: None})
38
+ except Exception as e:
39
+ logging.error(f"Failed to load dataframe: {e}")
40
+ raise ValueError(f"Could not parse the file: {e}")
41
 
42
+ def deepgram_tts(txt: str, voice_model: str) -> bytes | None:
43
+ """Generates speech from text using Deepgram and returns raw audio bytes."""
44
  DG_KEY = os.getenv("DEEPGRAM_API_KEY")
45
+ if not DG_KEY or not txt:
46
+ return None
47
+ # Clean and truncate text for the API
48
  txt = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
49
  try:
50
+ r = requests.post(
51
+ "https://api.deepgram.com/v1/speak",
52
+ params={"model": voice_model},
53
+ headers={"Authorization": f"Token {DG_KEY}", "Content-Type": "application/json"},
54
+ json={"text": txt},
55
+ timeout=30
56
+ )
57
  r.raise_for_status()
58
  return r.content
59
  except Exception as e:
60
  logging.error(f"Deepgram TTS failed: {e}")
61
  return None
62
 
63
+ TAG_RE = re.compile(r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I)
64
+ extract_chart_tags = lambda t: list(dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")))
 
 
 
 
 
 
 
 
 
65
 
66
  re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
67
  def clean_narration(txt: str) -> str:
68
+ """Cleans narration text by removing chart tags and other artifacts."""
69
+ txt = TAG_RE.sub("", txt)
70
+ txt = re_scene.sub("", txt)
71
  phrases_to_remove = [r"as you can see in the chart", r"this chart shows", r"the chart illustrates", r"in this visual", r"this graph displays"]
72
+ for phrase in phrases_to_remove:
73
+ txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
74
+ txt = re.sub(r"\s*\([^)]*\)", "", txt)
75
+ txt = re.sub(r"[\*#_]", "", txt)
76
  return re.sub(r"\s{2,}", " ", txt).strip()
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  # --- Chart Generation System ---
79
  class ChartSpecification:
80
+ """A data class to hold the specification for a chart."""
81
+ def __init__(self, chart_type: str, title: str, x_col: str, y_col: str, agg_method: str = None, top_n: int = None):
82
+ self.chart_type = chart_type
83
+ self.title = title
84
+ self.x_col = x_col
85
+ self.y_col = y_col
86
+ self.agg_method = agg_method or "sum"
87
+ self.top_n = top_n
88
 
89
  def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
90
+ """Enhances the data context for the LLM with column types."""
91
+ enhanced_ctx = ctx_dict.copy()
92
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
93
+ categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
94
  enhanced_ctx.update({"numeric_columns": numeric_cols, "categorical_columns": categorical_cols})
95
  return enhanced_ctx
96
 
97
+ def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame) -> List[Dict]:
98
+ """Prepares data for a chart based on its specification and returns it in a JSON-friendly format."""
99
+ if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns):
100
+ raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
101
+
102
+ if spec.chart_type in ["bar", "pie"]:
103
+ if not spec.y_col: # Case: count occurrences of a single categorical column
104
+ plot_data = df[spec.x_col].value_counts().nlargest(spec.top_n or 10)
105
+ else: # Case: aggregate a numeric column by a categorical column
106
+ grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
107
+ plot_data = grouped.nlargest(spec.top_n or 10)
108
+ return plot_data.reset_index().rename(columns={'index': spec.x_col, plot_data.name: spec.y_col}).to_dict(orient='records')
109
+
110
+ elif spec.chart_type == "line":
111
+ plot_data = df.set_index(spec.x_col)[spec.y_col].sort_index()
112
+ return plot_data.reset_index().to_dict(orient='records')
113
+
114
+ elif spec.chart_type == "scatter":
115
+ return df[[spec.x_col, spec.y_col]].dropna().to_dict(orient='records')
116
+
117
+ elif spec.chart_type == "hist":
118
+ # For histograms, we just need the single column of data. The client will bin it.
119
+ return df[[spec.x_col]].dropna().to_dict(orient='records')
120
+
121
+ return []
122
+
123
+
124
  class ChartGenerator:
125
  def __init__(self, llm, df: pd.DataFrame):
126
+ self.llm = llm
127
+ self.df = df
128
+ self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape})
129
 
130
+ def generate_chart_spec(self, description: str) -> Dict[str, Any]:
131
+ """Generates a complete chart specification, including the data, as a dictionary."""
132
  spec_prompt = f"""
133
  You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
134
  **Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
 
145
  if response.startswith("```json"): response = response[7:-3]
146
  elif response.startswith("```"): response = response[3:-3]
147
  spec_dict = json.loads(response)
148
+
149
+ # Create a spec object to validate and use helper methods
150
+ spec_obj = ChartSpecification(
151
+ chart_type=spec_dict.get("chart_type"),
152
+ title=spec_dict.get("title"),
153
+ x_col=spec_dict.get("x_col"),
154
+ y_col=spec_dict.get("y_col"),
155
+ agg_method=spec_dict.get("agg_method"),
156
+ top_n=spec_dict.get("top_n")
157
+ )
158
+
159
+ # Prepare the data and add it to the final spec dictionary
160
+ chart_data = prepare_plot_data(spec_obj, self.df)
161
+ final_spec = {
162
+ "id": "chart_" + uuid.uuid4().hex,
163
+ "chart_type": spec_obj.chart_type,
164
+ "title": spec_obj.title,
165
+ "x_col": spec_obj.x_col,
166
+ "y_col": spec_obj.y_col,
167
+ "data": chart_data
168
+ }
169
+ return final_spec
170
  except Exception as e:
171
+ logging.error(f"Chart spec generation failed: {e}. Cannot proceed.")
172
+ # In the new architecture, we cannot fall back to a placeholder. We must fail.
173
+ raise ValueError(f"Failed to generate a valid chart specification for '{description}'.") from e
174
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
+ # --- Main Business Logic Functions ---
177
 
178
+ def generate_report_draft(df: pd.DataFrame, user_context: str) -> Dict[str, Any]:
179
+ """
180
+ Generates a markdown report draft and a list of chart specifications.
181
+ This function does NOT interact with storage.
182
+ """
183
+ logging.info("Generating report draft and chart specifications.")
184
+ llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=API_KEY, temperature=0.1)
185
+ ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": user_context}
186
  enhanced_ctx = enhance_data_context(df, ctx_dict)
187
+
188
  report_prompt = f"""
189
+ You are a senior data analyst. Analyze the provided dataset context and write a comprehensive executive-level Markdown report.
190
  **Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
191
  **Instructions:**
192
  1. **Executive Summary**: Start with a high-level summary of key findings.
193
+ 2. **Key Insights**: Provide 3-5 key insights. For each insight that can be visualized, insert a chart tag on its own line.
194
+ 3. **Visual Support**: Use chart tags like: <generate_chart: "A bar chart showing total sales by region">.
195
  Valid chart types: bar, pie, line, scatter, hist.
196
  Generate insights that would be valuable to C-level executives.
197
  """
198
  md = llm.invoke(report_prompt).content
199
  chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
200
+ chart_specs = []
201
+
202
  chart_generator = ChartGenerator(llm, df)
203
  for desc in chart_descs:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  try:
205
+ spec = chart_generator.generate_chart_spec(desc)
206
+ chart_specs.append(spec)
207
+ except Exception as e:
208
+ logging.warning(f"Could not generate spec for chart '{desc}': {e}")
209
+ # Continue without the failed chart
210
+
211
+ return {"raw_md": md, "chart_specs": chart_specs}
212
+
213
+
214
+ def generate_single_chart(df: pd.DataFrame, description: str) -> Dict[str, Any]:
215
+ """Generates a single chart specification dictionary."""
216
+ logging.info(f"Generating single chart spec for '{description}'")
217
+ llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=API_KEY, temperature=0.1)
218
+ chart_generator = ChartGenerator(llm, df)
219
+ return chart_generator.generate_chart_spec(description)
220
+
221
+
222
+ def generate_video_from_project(df: pd.DataFrame, raw_md: str, voice_model: str) -> Dict[str, Any]:
223
+ """
224
+ Generates a video script with narration text, chart specs, and raw audio bytes.
225
+ This function does NOT interact with storage.
226
+ """
227
+ logging.info(f"Generating video script with voice {voice_model}")
228
+ llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=API_KEY, temperature=0.2)
229
+ story_prompt = f"""
230
+ Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
231
+ Each scene must be separated by '[SCENE_BREAK]'.
232
+ Each scene should contain narration text and, if relevant, exactly one chart tag from the report.
233
+ Example Scene:
234
+ Narration: We saw significant growth in the southern region.
235
+ <generate_chart: "A bar chart showing total sales by region">
236
+
237
+ [SCENE_BREAK]
238
+
239
+ Narration: This growth was driven primarily by our new product line.
240
+ <generate_chart: "A pie chart of product line performance">
241
+
242
+ Here is the report:
243
+ {raw_md}
244
+ """
245
  script = llm.invoke(story_prompt).content
246
+ scenes_text = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
247
+
248
+ video_script = {"scenes": []}
249
+ chart_generator = ChartGenerator(llm, df)
250
+
251
+ for i, sc_text in enumerate(scenes_text):
252
+ descs = extract_chart_tags(sc_text)
253
+ narrative = clean_narration(sc_text)
254
+
255
+ scene_spec = {
256
+ "scene_id": f"scene_{i+1}",
257
+ "narration": narrative,
258
+ "audio_content": None, # Will hold raw bytes
259
+ "chart_spec": None
260
+ }
261
+
262
+ # Generate audio for the narration
263
  audio_bytes = deepgram_tts(narrative, voice_model)
 
264
  if audio_bytes:
265
+ scene_spec["audio_content"] = audio_bytes
 
266
  else:
267
+ logging.warning(f"Could not generate audio for scene {i+1}. It will be silent.")
268
+
269
+ # Generate chart spec if a tag exists
270
+ if descs:
271
+ try:
272
+ chart_spec = chart_generator.generate_chart_spec(descs[0])
273
+ scene_spec["chart_spec"] = chart_spec
274
+ except Exception as e:
275
+ logging.warning(f"Could not generate chart for scene {i+1}: {e}")
276
+
277
+ video_script["scenes"].append(scene_spec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
+ return video_script