rairo commited on
Commit
848688d
·
verified ·
1 Parent(s): 3df5109

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +234 -170
sozo_gen.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import re
3
  import json
@@ -7,15 +9,23 @@ import io
7
  from pathlib import Path
8
  import pandas as pd
9
  import numpy as np
10
- import subprocess
11
- from typing import Dict, List, Any
 
 
 
 
 
12
  import tempfile
 
 
13
  from langchain_google_genai import ChatGoogleGenerativeAI
14
  from google import genai
15
  import requests
16
 
17
  # --- Configuration ---
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
 
19
  MAX_CHARTS, VIDEO_SCENES = 5, 5
20
 
21
  # --- Gemini API Initialization ---
@@ -24,111 +34,78 @@ if not API_KEY:
24
  raise ValueError("GOOGLE_API_KEY environment variable not set.")
25
 
26
  # --- Helper Functions ---
27
- def load_dataframe_safely(buf, name: str) -> pd.DataFrame:
28
- """Loads a dataframe from a buffer, handling CSV or Excel files."""
29
  ext = Path(name).suffix.lower()
30
- try:
31
- df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
32
- df.columns = df.columns.astype(str).str.strip()
33
- df = df.dropna(how="all")
34
- if df.empty or len(df.columns) == 0:
35
- raise ValueError("No usable data found in the file.")
36
- # Convert entire dataframe to JSON-compatible types
37
- return df.replace({np.nan: None})
38
- except Exception as e:
39
- logging.error(f"Failed to load dataframe: {e}")
40
- raise ValueError(f"Could not parse the file: {e}")
41
 
42
- def deepgram_tts(txt: str, voice_model: str) -> bytes | None:
43
- """Generates speech from text using Deepgram and returns raw audio bytes."""
44
  DG_KEY = os.getenv("DEEPGRAM_API_KEY")
45
- if not DG_KEY or not txt:
46
- return None
47
- # Clean and truncate text for the API
48
  txt = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
49
  try:
50
- r = requests.post(
51
- "https://api.deepgram.com/v1/speak",
52
- params={"model": voice_model},
53
- headers={"Authorization": f"Token {DG_KEY}", "Content-Type": "application/json"},
54
- json={"text": txt},
55
- timeout=30
56
- )
57
  r.raise_for_status()
58
  return r.content
59
  except Exception as e:
60
  logging.error(f"Deepgram TTS failed: {e}")
61
  return None
62
 
63
- TAG_RE = re.compile(r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I)
64
- extract_chart_tags = lambda t: list(dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")))
 
 
 
 
 
 
 
 
 
65
 
66
  re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
67
  def clean_narration(txt: str) -> str:
68
- """Cleans narration text by removing chart tags and other artifacts."""
69
- txt = TAG_RE.sub("", txt)
70
- txt = re_scene.sub("", txt)
71
  phrases_to_remove = [r"as you can see in the chart", r"this chart shows", r"the chart illustrates", r"in this visual", r"this graph displays"]
72
- for phrase in phrases_to_remove:
73
- txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
74
- txt = re.sub(r"\s*\([^)]*\)", "", txt)
75
- txt = re.sub(r"[\*#_]", "", txt)
76
  return re.sub(r"\s{2,}", " ", txt).strip()
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  # --- Chart Generation System ---
79
  class ChartSpecification:
80
- """A data class to hold the specification for a chart."""
81
- def __init__(self, chart_type: str, title: str, x_col: str, y_col: str, agg_method: str = None, top_n: int = None):
82
- self.chart_type = chart_type
83
- self.title = title
84
- self.x_col = x_col
85
- self.y_col = y_col
86
- self.agg_method = agg_method or "sum"
87
- self.top_n = top_n
88
 
89
  def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
90
- """Enhances the data context for the LLM with column types."""
91
- enhanced_ctx = ctx_dict.copy()
92
- numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
93
- categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
94
  enhanced_ctx.update({"numeric_columns": numeric_cols, "categorical_columns": categorical_cols})
95
  return enhanced_ctx
96
 
97
- def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame) -> List[Dict]:
98
- """Prepares data for a chart based on its specification and returns it in a JSON-friendly format."""
99
- if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns):
100
- raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
101
-
102
- if spec.chart_type in ["bar", "pie"]:
103
- if not spec.y_col: # Case: count occurrences of a single categorical column
104
- plot_data = df[spec.x_col].value_counts().nlargest(spec.top_n or 10)
105
- else: # Case: aggregate a numeric column by a categorical column
106
- grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
107
- plot_data = grouped.nlargest(spec.top_n or 10)
108
- return plot_data.reset_index().rename(columns={'index': spec.x_col, plot_data.name: spec.y_col}).to_dict(orient='records')
109
-
110
- elif spec.chart_type == "line":
111
- plot_data = df.set_index(spec.x_col)[spec.y_col].sort_index()
112
- return plot_data.reset_index().to_dict(orient='records')
113
-
114
- elif spec.chart_type == "scatter":
115
- return df[[spec.x_col, spec.y_col]].dropna().to_dict(orient='records')
116
-
117
- elif spec.chart_type == "hist":
118
- # For histograms, we just need the single column of data. The client will bin it.
119
- return df[[spec.x_col]].dropna().to_dict(orient='records')
120
-
121
- return []
122
-
123
-
124
  class ChartGenerator:
125
  def __init__(self, llm, df: pd.DataFrame):
126
- self.llm = llm
127
- self.df = df
128
- self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape})
129
 
130
- def generate_chart_spec(self, description: str) -> Dict[str, Any]:
131
- """Generates a complete chart specification, including the data, as a dictionary."""
132
  spec_prompt = f"""
133
  You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
134
  **Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
@@ -145,34 +122,21 @@ class ChartGenerator:
145
  if response.startswith("```json"): response = response[7:-3]
146
  elif response.startswith("```"): response = response[3:-3]
147
  spec_dict = json.loads(response)
148
-
149
- # Create a spec object to validate and use helper methods
150
- spec_obj = ChartSpecification(
151
- chart_type=spec_dict.get("chart_type"),
152
- title=spec_dict.get("title"),
153
- x_col=spec_dict.get("x_col"),
154
- y_col=spec_dict.get("y_col"),
155
- agg_method=spec_dict.get("agg_method"),
156
- top_n=spec_dict.get("top_n")
157
- )
158
-
159
- # Prepare the data and add it to the final spec dictionary
160
- chart_data = prepare_plot_data(spec_obj, self.df)
161
- final_spec = {
162
- "id": "chart_" + uuid.uuid4().hex,
163
- "chart_type": spec_obj.chart_type,
164
- "title": spec_obj.title,
165
- "x_col": spec_obj.x_col,
166
- "y_col": spec_obj.y_col,
167
- "data": chart_data
168
- }
169
- return final_spec
170
  except Exception as e:
171
- logging.error(f"Chart spec generation failed: {e}. Cannot proceed.")
172
- # In the new architecture, we cannot fall back to a placeholder. We must fail.
173
- raise ValueError(f"Failed to generate a valid chart specification for '{description}'.") from e
174
-
175
- # PASTE THIS CODE INTO YOUR sozo_gen.py FILE
 
 
 
 
 
 
176
 
177
  def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
178
  try:
@@ -198,14 +162,107 @@ def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame) -> pd.Series:
198
  elif spec.chart_type == "scatter": return df[[spec.x_col, spec.y_col]].dropna()
199
  elif spec.chart_type == "hist": return df[spec.x_col].dropna()
200
  return df[spec.x_col]
201
- # --- Main Business Logic Functions ---
202
 
203
- # In sozo_gen.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
206
  logging.info(f"Generating report draft for project {project_id}")
207
  df = load_dataframe_safely(buf, name)
208
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
209
  ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
210
  enhanced_ctx = enhance_data_context(df, ctx_dict)
211
  report_prompt = f"""
@@ -234,72 +291,79 @@ def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, b
234
  chart_urls[desc] = blob.public_url
235
  logging.info(f"Uploaded chart '{desc}' to {blob.public_url}")
236
  finally:
237
- os.unlink(img_path)
 
238
  return {"raw_md": md, "chartUrls": chart_urls}
239
 
240
- def generate_single_chart(df: pd.DataFrame, description: str) -> Dict[str, Any]:
241
- """Generates a single chart specification dictionary."""
242
- logging.info(f"Generating single chart spec for '{description}'")
243
- llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=API_KEY, temperature=0.1)
244
  chart_generator = ChartGenerator(llm, df)
245
- return chart_generator.generate_chart_spec(description)
246
-
247
-
248
- def generate_video_from_project(df: pd.DataFrame, raw_md: str, voice_model: str) -> Dict[str, Any]:
249
- """
250
- Generates a video script with narration text, chart specs, and raw audio bytes.
251
- This function does NOT interact with storage.
252
- """
253
- logging.info(f"Generating video script with voice {voice_model}")
254
- llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=API_KEY, temperature=0.2)
255
- story_prompt = f"""
256
- Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
257
- Each scene must be separated by '[SCENE_BREAK]'.
258
- Each scene should contain narration text and, if relevant, exactly one chart tag from the report.
259
- Example Scene:
260
- Narration: We saw significant growth in the southern region.
261
- <generate_chart: "A bar chart showing total sales by region">
262
-
263
- [SCENE_BREAK]
264
-
265
- Narration: This growth was driven primarily by our new product line.
266
- <generate_chart: "A pie chart of product line performance">
267
 
268
- Here is the report:
269
- {raw_md}
270
- """
 
271
  script = llm.invoke(story_prompt).content
272
- scenes_text = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
273
-
274
- video_script = {"scenes": []}
275
- chart_generator = ChartGenerator(llm, df)
276
-
277
- for i, sc_text in enumerate(scenes_text):
278
- descs = extract_chart_tags(sc_text)
279
- narrative = clean_narration(sc_text)
280
-
281
- scene_spec = {
282
- "scene_id": f"scene_{i+1}",
283
- "narration": narrative,
284
- "audio_content": None, # Will hold raw bytes
285
- "chart_spec": None
286
- }
287
-
288
- # Generate audio for the narration
289
  audio_bytes = deepgram_tts(narrative, voice_model)
 
290
  if audio_bytes:
291
- scene_spec["audio_content"] = audio_bytes
 
292
  else:
293
- logging.warning(f"Could not generate audio for scene {i+1}. It will be silent.")
294
-
295
- # Generate chart spec if a tag exists
296
- if descs:
297
- try:
298
- chart_spec = chart_generator.generate_chart_spec(descs[0])
299
- scene_spec["chart_spec"] = chart_spec
300
- except Exception as e:
301
- logging.warning(f"Could not generate chart for scene {i+1}: {e}")
302
-
303
- video_script["scenes"].append(scene_spec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
- return video_script
 
 
 
 
 
1
+ # sozo_gen.py
2
+
3
  import os
4
  import re
5
  import json
 
9
  from pathlib import Path
10
  import pandas as pd
11
  import numpy as np
12
+ import matplotlib
13
+ matplotlib.use("Agg")
14
+ import matplotlib.pyplot as plt
15
+ from matplotlib.animation import FuncAnimation, FFMpegWriter
16
+ from PIL import Image
17
+ import cv2
18
+ import inspect
19
  import tempfile
20
+ import subprocess
21
+ from typing import Dict, List
22
  from langchain_google_genai import ChatGoogleGenerativeAI
23
  from google import genai
24
  import requests
25
 
26
  # --- Configuration ---
27
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
28
+ FPS, WIDTH, HEIGHT = 24, 1280, 720
29
  MAX_CHARTS, VIDEO_SCENES = 5, 5
30
 
31
  # --- Gemini API Initialization ---
 
34
  raise ValueError("GOOGLE_API_KEY environment variable not set.")
35
 
36
  # --- Helper Functions ---
37
def load_dataframe_safely(buf, name: str):
    """Read a CSV or Excel buffer into a DataFrame with normalised headers.

    Excel parsing is used when *name* ends in .xlsx/.xls; anything else is
    treated as CSV. Column labels are stringified and stripped, fully-empty
    rows are dropped, and ValueError is raised when nothing usable remains.
    """
    if Path(name).suffix.lower() in (".xlsx", ".xls"):
        frame = pd.read_excel(buf)
    else:
        frame = pd.read_csv(buf)
    frame.columns = frame.columns.astype(str).str.strip()
    frame = frame.dropna(how="all")
    if frame.empty or len(frame.columns) == 0:
        raise ValueError("No usable data found")
    return frame
 
 
 
 
 
 
44
 
45
def deepgram_tts(txt: str, voice_model: str):
    """Synthesise *txt* via Deepgram's speak endpoint and return raw audio bytes.

    Returns None when no API key is configured, when the text is empty, or
    when the HTTP request fails for any reason (failures are logged, never
    raised to the caller).
    """
    api_key = os.getenv("DEEPGRAM_API_KEY")
    if not api_key or not txt:
        return None
    # Drop unusual characters and cap the payload before sending.
    cleaned = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
    try:
        resp = requests.post(
            "https://api.deepgram.com/v1/speak",
            params={"model": voice_model},
            headers={"Authorization": f"Token {api_key}", "Content-Type": "application/json"},
            json={"text": cleaned},
            timeout=30,
        )
        resp.raise_for_status()
        return resp.content
    except Exception as e:
        logging.error(f"Deepgram TTS failed: {e}")
        return None
56
 
57
def generate_silence_mp3(duration: float, out: Path):
    """Write *duration* seconds of silent MP3 audio to *out* using ffmpeg.

    Uses the lavfi ``anullsrc`` source (44.1 kHz, mono). Raises
    CalledProcessError if ffmpeg exits non-zero (check=True).
    """
    # -q:a 9 = lowest VBR quality; perfectly adequate for silence.
    subprocess.run([ "ffmpeg", "-y", "-f", "lavfi", "-i", "anullsrc=r=44100:cl=mono", "-t", f"{duration:.3f}", "-q:a", "9", str(out)], check=True, capture_output=True)
59
+
60
def audio_duration(path: str, default: float = 5.0) -> float:
    """Return the duration in seconds of the media file at *path*.

    Probes the file with ffprobe. If ffprobe is missing, exits non-zero, or
    emits something unparseable, *default* is returned instead of raising
    (5.0 s, matching the previous hard-coded fallback, now a parameter).
    """
    try:
        res = subprocess.run(
            ["ffprobe", "-v", "error", "-show_entries", "format=duration",
             "-of", "default=nw=1:nk=1", path],
            text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True,
        )
        return float(res.stdout.strip())
    except Exception:
        # Best-effort: callers treat an unknown duration as a short scene.
        return default
65
+
66
# Matches chart tags such as <generate_chart: "..."> or [generate_chart ...],
# tolerating an optional ':'/'=', straight or curly quotes, and <> or [] braces.
TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )

def extract_chart_tags(t):
    """Return the unique chart descriptions tagged in *t*, in first-seen order.

    (Was a lambda assigned to a name; PEP 8 E731 prefers a def.)
    """
    return list(dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")))

# Strips "Scene 3:" style markers at the start of any line.
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
def clean_narration(txt: str) -> str:
    """Normalise LLM narration for TTS.

    Removes chart tags, scene markers, chart-referencing filler phrases,
    parentheticals, and markdown emphasis characters, then collapses runs
    of whitespace.
    """
    txt = TAG_RE.sub("", txt); txt = re_scene.sub("", txt)
    phrases_to_remove = [r"as you can see in the chart", r"this chart shows", r"the chart illustrates", r"in this visual", r"this graph displays"]
    for phrase in phrases_to_remove:
        txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
    txt = re.sub(r"\s*\([^)]*\)", "", txt)  # parentheticals (e.g. stage directions)
    txt = re.sub(r"[\*#_]", "", txt)        # markdown emphasis / heading chars
    return re.sub(r"\s{2,}", " ", txt).strip()
76
 
77
def placeholder_img() -> Image.Image:
    """Fallback frame: a flat light-grey canvas at the render resolution."""
    return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))
78
+
79
def generate_image_from_prompt(prompt: str) -> Image.Image:
    """Ask Gemini for an illustration of *prompt*; fall back to a grey frame.

    Any failure — API error, or no inline image part in the response —
    yields placeholder_img(), so video assembly never stalls on image
    generation.
    """
    model_main = "gemini-1.5-flash-latest";
    full_prompt = "A clean business-presentation illustration: " + prompt
    try:
        # NOTE(review): `genai` comes from `from google import genai`; confirm
        # this SDK exposes GenerativeModel and that this model can return
        # inline image data — otherwise this always falls to the placeholder.
        model = genai.GenerativeModel(model_main)
        res = model.generate_content(full_prompt)
        # First response part carrying inline (binary) data, if any.
        img_part = next((part for part in res.candidates[0].content.parts if getattr(part, "inline_data", None)), None)
        if img_part:
            return Image.open(io.BytesIO(img_part.inline_data.data)).convert("RGB")
        return placeholder_img()
    except Exception:
        return placeholder_img()
91
+
92
  # --- Chart Generation System ---
93
class ChartSpecification:
    """Plain value object describing one chart to render.

    Only chart_type, title and x_col are required; y_col may be omitted for
    count-style charts. agg_method falls back to "sum" when not supplied.
    The __init__ signature is introspected elsewhere (inspect.signature),
    so parameter names/order must not change.
    """
    def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
        self.chart_type = chart_type
        self.title = title
        self.x_col = x_col
        self.y_col = y_col
        self.agg_method = agg_method or "sum"
        self.filter_condition = filter_condition
        self.top_n = top_n
        self.color_scheme = color_scheme
 
 
 
 
 
97
 
98
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
    """Return a copy of *ctx_dict* augmented with the dataframe's column split.

    Adds "numeric_columns" and "categorical_columns" keys derived from the
    frame's dtypes; the input dict is not mutated.
    """
    enriched = dict(ctx_dict)
    enriched["numeric_columns"] = df.select_dtypes(include=['number']).columns.tolist()
    enriched["categorical_columns"] = df.select_dtypes(exclude=['number']).columns.tolist()
    return enriched
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  class ChartGenerator:
104
    def __init__(self, llm, df: pd.DataFrame):
        """Bind the LLM client and dataframe, precomputing the dataset context
        (columns, shape, per-column dtypes, plus numeric/categorical splits
        added by enhance_data_context) that is interpolated into prompts."""
        self.llm = llm; self.df = df
        self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape, "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}})
 
107
 
108
+ def generate_chart_spec(self, description: str) -> ChartSpecification:
 
109
  spec_prompt = f"""
110
  You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
111
  **Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
 
122
  if response.startswith("```json"): response = response[7:-3]
123
  elif response.startswith("```"): response = response[3:-3]
124
  spec_dict = json.loads(response)
125
+ valid_keys = [p.name for p in inspect.signature(ChartSpecification).parameters.values() if p.name not in ['reasoning', 'filter_condition', 'color_scheme']]
126
+ filtered_dict = {k: v for k, v in spec_dict.items() if k in valid_keys}
127
+ return ChartSpecification(**filtered_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  except Exception as e:
129
+ logging.error(f"Spec generation failed: {e}. Using fallback.")
130
+ return self._create_fallback_spec(description)
131
+
132
+ def _create_fallback_spec(self, description: str) -> ChartSpecification:
133
+ numeric_cols = self.enhanced_ctx['numeric_columns']; categorical_cols = self.enhanced_ctx['categorical_columns']
134
+ ctype = "bar"
135
+ for t in ["pie", "line", "scatter", "hist"]:
136
+ if t in description.lower(): ctype = t
137
+ x = categorical_cols[0] if categorical_cols else self.df.columns[0]
138
+ y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
139
+ return ChartSpecification(ctype, description, x, y)
140
 
141
  def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
142
  try:
 
162
  elif spec.chart_type == "scatter": return df[[spec.x_col, spec.y_col]].dropna()
163
  elif spec.chart_type == "hist": return df[spec.x_col].dropna()
164
  return df[spec.x_col]
 
165
 
166
+ # --- Animation & Video Generation ---
167
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
    """Render an animated chart for *spec* as an mp4 of roughly *dur* seconds.

    Each chart type gets a simple reveal animation: fade-in for pie/hist,
    growing bars, and progressively plotted scatter/line. The clip is saved
    through Matplotlib's FFMpegWriter; returns the output path as a string.
    """
    plot_data = prepare_plot_data(spec, df)
    # At least 10 frames so i/(frames-1) below is always well-defined.
    frames = max(10, int(dur * fps))
    fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
    plt.tight_layout(pad=3.0)
    ctype = spec.chart_type
    if ctype == "pie":
        # All wedges fade in together.
        wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
        ax.set_title(spec.title); ax.axis('equal')
        def init(): [w.set_alpha(0) for w in wedges]; return wedges
        def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return wedges
    elif ctype == "bar":
        # Bars grow from zero to their final heights.
        bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
        ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
        ax.set_title(spec.title); plt.xticks(rotation=45, ha="right")
        def init(): return bars
        def update(i):
            for b, h in zip(bars, plot_data.values): b.set_height(h * (i / (frames - 1)))
            return bars
    elif ctype == "scatter":
        # Points appear progressively, in row order.
        scat = ax.scatter([], [], alpha=0.7)
        x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
        ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
        ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
        def init(): scat.set_offsets(np.empty((0, 2))); return [scat]
        def update(i):
            k = max(1, int(len(x_full) * (i / (frames - 1))))
            scat.set_offsets(plot_data.iloc[:k].values); return [scat]
    elif ctype == "hist":
        # Histogram bars fade in to 70% opacity.
        _, _, patches = ax.hist(plot_data, bins=20, alpha=0)
        ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
        def init(): [p.set_alpha(0) for p in patches]; return patches
        def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return patches
    else:  # line (also the default for unrecognised chart types)
        # The line is drawn left-to-right over the animation.
        line, = ax.plot([], [], lw=2)
        plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
        x_full, y_full = plot_data.index, plot_data.values
        ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
        ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
        def init(): line.set_data([], []); return [line]
        def update(i):
            k = max(2, int(len(x_full) * (i / (frames - 1))))
            line.set_data(x_full[:k], y_full[:k]); return [line]
    anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
    # NOTE(review): the figure is sized for dpi=100 but saved at dpi=144, so
    # the rendered video is ~1843x1037 rather than WIDTH x HEIGHT — confirm
    # this mismatch is intentional.
    anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
    plt.close(fig)
    return str(out)
214
+
215
def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
    """Write an mp4 at *out* fading *img* in from black over *dur* seconds.

    *img* is expected to be a BGR frame already sized WIDTH x HEIGHT (the
    VideoWriter is opened at that resolution). Returns the path as a string.
    """
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
    total_frames = max(1, int(dur * fps))
    black = np.zeros_like(img)  # hoisted: loop-invariant
    try:
        for i in range(total_frames):
            alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
            # Blend from black (alpha=0) up to the full image (alpha=1).
            frame = cv2.addWeighted(img, alpha, black, 1 - alpha, 0)
            video_writer.write(frame)
    finally:
        # Release in a finally so the container is finalised even if a
        # write raises (the original leaked the writer on errors).
        video_writer.release()
    return str(out)
224
+
225
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
    """Produce an animated chart video for *desc*, degrading gracefully.

    Tries, in order: (1) a native animation via animate_chart; (2) a static
    chart rendered to PNG and faded in; (3) an AI-generated illustration
    faded in. Returns the path of the written mp4 (str(out)).
    """
    try:
        llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
        chart_generator = ChartGenerator(llm, df)
        chart_spec = chart_generator.generate_chart_spec(desc)
        return animate_chart(chart_spec, df, dur, out)
    except Exception as e:
        logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
    # Fallback path: render a static chart image and fade it in.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_png_file:
        temp_png = Path(temp_png_file.name)
    try:
        llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
        chart_generator = ChartGenerator(llm, df)
        chart_spec = chart_generator.generate_chart_spec(desc)
        if execute_chart_spec(chart_spec, df, temp_png):
            img = cv2.imread(str(temp_png))
            if img is not None:  # imread returns None on unreadable files
                img_resized = cv2.resize(img, (WIDTH, HEIGHT))
                return animate_image_fade(img_resized, dur, out)
        # Last resort: a generated illustration instead of a real chart.
        img = generate_image_from_prompt(f"A professional business chart showing {desc}")
        img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
        return animate_image_fade(img_cv, dur, out)
    finally:
        # The PNG was created with delete=False; remove it on every path
        # (the original leaked it whenever execute_chart_spec failed).
        if os.path.exists(temp_png):
            os.unlink(temp_png)
246
+
247
def concat_media(file_paths: List[str], output_path: Path):
    """Concatenate media files into *output_path* with ffmpeg's concat demuxer.

    Inputs that are missing or smaller than ~100 bytes are treated as broken
    and skipped. A single valid input is simply copied. Raises ValueError
    when nothing valid remains, and CalledProcessError when ffmpeg fails.
    """
    valid_paths = [p for p in file_paths if Path(p).exists() and Path(p).stat().st_size > 100]
    if not valid_paths:
        raise ValueError("No valid media files to concatenate.")
    if len(valid_paths) == 1:
        import shutil  # local import, matching the file's lazy-dependency style
        shutil.copy2(valid_paths[0], str(output_path))
        return
    # The concat demuxer needs a list file; it is written next to the output.
    # NOTE(review): a path containing a single quote would break this list file.
    list_file = output_path.with_suffix(".txt")
    with open(list_file, 'w') as f:
        for path in valid_paths:
            f.write(f"file '{Path(path).resolve()}'\n")
    cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(list_file), "-c", "copy", str(output_path)]
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    finally:
        list_file.unlink(missing_ok=True)
259
+
260
+ # --- Main Business Logic Functions for Flask ---
261
 
262
  def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
263
  logging.info(f"Generating report draft for project {project_id}")
264
  df = load_dataframe_safely(buf, name)
265
+ llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
266
  ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
267
  enhanced_ctx = enhance_data_context(df, ctx_dict)
268
  report_prompt = f"""
 
291
  chart_urls[desc] = blob.public_url
292
  logging.info(f"Uploaded chart '{desc}' to {blob.public_url}")
293
  finally:
294
+ if os.path.exists(img_path):
295
+ os.unlink(img_path)
296
  return {"raw_md": md, "chartUrls": chart_urls}
297
 
298
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
    """Render one chart for *description* and upload the PNG to cloud storage.

    Returns the uploaded blob's public URL, or None when execute_chart_spec
    reports that the chart could not be rendered. The temporary PNG is
    removed in all cases.
    """
    logging.info(f"Generating single chart '{description}' for project {project_id}")
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
    chart_generator = ChartGenerator(llm, df)
    # delete=False so the file survives the with-block; cleaned up in finally.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        img_path = Path(temp_file.name)
    try:
        chart_spec = chart_generator.generate_chart_spec(description)
        if execute_chart_spec(chart_spec, df, img_path):
            # Random hex name avoids collisions between charts of one project.
            blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
            blob = bucket.blob(blob_name)
            blob.upload_from_filename(str(img_path))
            logging.info(f"Uploaded single chart to {blob.public_url}")
            return blob.public_url
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)
    return None
 
 
 
 
 
 
 
 
316
 
317
def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
    """Build a narrated multi-scene video from a report and upload it.

    Pipeline: the LLM writes a scene script; per scene, TTS audio (or
    generated silence) fixes the scene duration, and either an animated
    chart or a generated illustration fills the visuals; the scene clips
    and audio tracks are concatenated and muxed with ffmpeg; the final mp4
    is uploaded and its public URL returned.
    """
    logging.info(f"Generating video for project {project_id} with voice {voice_model}")
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.2)
    story_prompt = f"Based on the following report, create a script for a {VIDEO_SCENES}-scene video. Each scene must be separated by '[SCENE_BREAK]' and contain narration and one chart tag. Report: {raw_md}"
    script = llm.invoke(story_prompt).content
    scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
    video_parts, audio_parts, temps = [], [], []
    for sc in scenes:
        descs, narrative = extract_chart_tags(sc), clean_narration(sc)
        # Audio first: its duration drives the length of the visual clip.
        audio_bytes = deepgram_tts(narrative, voice_model)
        mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
        if audio_bytes:
            mp3.write_bytes(audio_bytes); dur = audio_duration(str(mp3))
            if dur <= 0.1: dur = 5.0  # guard against a broken/near-empty clip
        else:
            # TTS unavailable: keep the timeline intact with 5 s of silence.
            dur = 5.0; generate_silence_mp3(dur, mp3)
        audio_parts.append(str(mp3)); temps.append(mp3)
        mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
        if descs: safe_chart(descs[0], df, dur, mp4)  # only the first chart tag is used
        else:
            # No chart tagged: illustrate the narration instead.
            img = generate_image_from_prompt(narrative)
            img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
            animate_image_fade(img_cv, dur, mp4)
        video_parts.append(str(mp4)); temps.append(mp4)

    # The with-block only reserves temp file names (delete=False); the
    # handles are closed before ffmpeg writes to the paths.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
         tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_aud, \
         tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as final_vid:
        silent_vid_path = Path(temp_vid.name)
        audio_mix_path = Path(temp_aud.name)
        final_vid_path = Path(final_vid.name)

    concat_media(video_parts, silent_vid_path)
    concat_media(audio_parts, audio_mix_path)

    # Mux: video from the concatenated clips, audio from the joined
    # narration; -shortest trims whichever stream runs longer.
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
         "-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
         "-map", "0:v:0", "-map", "1:a:0", "-shortest", str(final_vid_path)],
        check=True, capture_output=True,
    )

    blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
    blob = bucket.blob(blob_name)
    blob.upload_from_filename(str(final_vid_path))
    logging.info(f"Uploaded video to {blob.public_url}")

    # Best-effort cleanup of every temp artefact created above.
    for p in temps + [silent_vid_path, audio_mix_path, final_vid_path]:
        if os.path.exists(p): os.unlink(p)

    return blob.public_url
    return None  # NOTE(review): unreachable — leftover from an earlier version