rairo committed on
Commit
f635a57
·
verified ·
1 Parent(s): 19ee271

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +52 -55
sozo_gen.py CHANGED
@@ -107,19 +107,19 @@ def detect_dataset_domain(df: pd.DataFrame) -> str:
107
  # NEW: Keyword extraction for better Pexels searches
108
  def extract_keywords_for_query(text: str, llm) -> str:
109
  prompt = f"""
110
- Extract 2-4 key nouns and verbs from the following text to use as a search query for a stock video.
111
  Focus on concrete actions and subjects.
112
- Example: 'Our analysis shows a significant growth in quarterly revenue and strong partnerships.' -> 'data analysis growth chart business'
113
  Output only the search query keywords, separated by spaces.
114
 
115
  Text: "{text}"
116
  """
117
  try:
118
  response = llm.invoke(prompt).content.strip()
119
- return response
120
  except Exception as e:
121
  logging.error(f"Keyword extraction failed: {e}. Using original text.")
122
- return text # Fallback to the original text if LLM fails
123
 
124
  # UPDATED: Pexels search now loops short videos
125
  def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
@@ -265,7 +265,7 @@ def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
265
 
266
  def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
267
  plot_data = prepare_plot_data(spec, df)
268
- frames = math.ceil(dur * fps) # Use math.ceil to always round up frames
269
  fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
270
  plt.tight_layout(pad=3.0)
271
  ctype = spec.chart_type
@@ -275,20 +275,33 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
275
  if ctype == "line":
276
  plot_data = plot_data.sort_index()
277
  x_full, y_full = plot_data.index, plot_data.values
 
 
 
 
 
 
 
 
278
  def init():
279
- ax.set_xlim(x_full.min(), x_full.max())
280
- ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
281
- return []
282
  def update(i):
283
- ax.clear()
284
- ax.set_xlim(x_full.min(), x_full.max())
285
- ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
286
- ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
287
  k = max(2, int(len(x_full) * (i / (frames - 1))))
288
- ax.plot(x_full[:k], y_full[:k], lw=2, color='#A23B72', marker='o', markersize=5)
289
- return []
 
290
  init_func, update_func = init, update
291
- elif ctype == "bar":
 
 
 
 
 
 
 
 
292
  bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
293
  ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
294
  ax.set_title(spec.title); plt.xticks(rotation=45, ha="right")
@@ -308,7 +321,7 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
308
  ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
309
  def init():
310
  scat.set_offsets(np.empty((0, 2))); line.set_data([], [])
311
- return [scat, line]
312
  def update(i):
313
  point_frames = int(frames * 0.7)
314
  if i <= point_frames:
@@ -318,25 +331,25 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
318
  line_frame = i - point_frames; line_total_frames = frames - point_frames
319
  current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
320
  line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
321
- return [scat, line]
322
  init_func, update_func = init, update
323
  elif ctype == "pie":
324
  wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
325
  ax.set_title(spec.title); ax.axis('equal')
326
- def init(): [w.set_alpha(0) for w in wedges]; return wedges
327
- def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return wedges
328
  init_func, update_func = init, update
329
  elif ctype == "hist":
330
  _, _, patches = ax.hist(plot_data, bins=20, alpha=0)
331
  ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
332
- def init(): [p.set_alpha(0) for p in patches]; return patches
333
- def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return patches
334
  init_func, update_func = init, update
335
  elif ctype == "heatmap":
336
  sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax, alpha=0)
337
  ax.set_title(spec.title)
338
- def init(): ax.collections[0].set_alpha(0); return [ax.collections[0]]
339
- def update(i): ax.collections[0].set_alpha(i / (frames - 1)); return [ax.collections[0]]
340
  init_func, update_func = init, update
341
  else:
342
  ax.text(0.5, 0.5, f"'{ctype}' animation not implemented", ha='center', va='center')
@@ -348,16 +361,6 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
348
  anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
349
  plt.close(fig)
350
  return str(out)
351
-
352
- def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
353
- fourcc = cv2.VideoWriter_fourcc(*'mp4v'); video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
354
- total_frames = max(1, int(dur * fps))
355
- for i in range(total_frames):
356
- alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
357
- frame = cv2.addWeighted(img, alpha, np.zeros_like(img), 1 - alpha, 0)
358
- video_writer.write(frame)
359
- video_writer.release()
360
- return str(out)
361
 
362
  def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path, context: Dict) -> str:
363
  try:
@@ -608,38 +611,32 @@ def generate_video_from_project(df: pd.DataFrame, raw_md: str, data_context: Dic
608
  video_dur = audio_dur + 1.5
609
 
610
  try:
611
- primary_query = None
612
- narration_lower = narrative.lower()
613
- is_conclusion_scene = any(k in narration_lower for k in ["conclusion", "summary", "in closing", "final thoughts"])
614
-
615
- if any(k in narration_lower for k in ["introduction", "welcome", "let's begin"]):
616
- primary_query = f"abstract technology background {domain}"
617
- elif is_conclusion_scene:
618
- primary_query = f"future strategy business meeting {domain}"
619
-
620
- if primary_query:
621
- logging.info(f"Scene {i+1}: Pre-emptive guard triggered. Query: '{primary_query}'")
622
- video_path = search_and_download_pexels_video(primary_query, video_dur, mp4)
623
- if not video_path: raise ValueError("Pexels search failed for guarded query.")
624
  video_parts.append(video_path)
625
  if is_conclusion_scene:
626
  conclusion_video_path = video_path
 
 
 
 
627
  else:
628
- chart_descs = extract_chart_tags(sc)
629
- if chart_descs:
630
- logging.info(f"Scene {i+1}: Primary attempt with animated chart.")
631
- safe_chart(chart_descs[0], df, video_dur, mp4, data_context)
632
- video_parts.append(str(mp4))
633
- else:
634
- raise ValueError("No chart tag found in a middle scene.")
635
  except Exception as e:
636
  logging.warning(f"Scene {i+1}: Primary visual failed ({e}). Marking for fallback.")
637
  video_parts.append("FALLBACK_NEEDED")
638
 
639
  temps.append(mp4)
640
 
641
- # Post-processing loop to apply the conclusion video as a fallback
642
- if not conclusion_video_path: # Failsafe if conclusion scene itself failed
643
  logging.warning("No conclusion video was generated; creating a generic one for fallbacks.")
644
  fallback_mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
645
  conclusion_video_path = search_and_download_pexels_video(f"data visualization abstract {domain}", 5.0, fallback_mp4)
 
107
  # NEW: Keyword extraction for better Pexels searches
108
  def extract_keywords_for_query(text: str, llm) -> str:
109
  prompt = f"""
110
+ Extract a maximum of 3 key nouns or verbs from the following text to use as a search query for a stock video.
111
  Focus on concrete actions and subjects.
112
+ Example: 'Our analysis shows a significant growth in quarterly revenue and strong partnerships.' -> 'data analysis growth'
113
  Output only the search query keywords, separated by spaces.
114
 
115
  Text: "{text}"
116
  """
117
  try:
118
  response = llm.invoke(prompt).content.strip()
119
+ return response if response else text
120
  except Exception as e:
121
  logging.error(f"Keyword extraction failed: {e}. Using original text.")
122
+ return text
123
 
124
  # UPDATED: Pexels search now loops short videos
125
  def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
 
265
 
266
  def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
267
  plot_data = prepare_plot_data(spec, df)
268
+ frames = math.ceil(dur * fps)
269
  fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
270
  plt.tight_layout(pad=3.0)
271
  ctype = spec.chart_type
 
275
  if ctype == "line":
276
  plot_data = plot_data.sort_index()
277
  x_full, y_full = plot_data.index, plot_data.values
278
+
279
+ ax.set_xlim(x_full.min(), x_full.max())
280
+ ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
281
+ ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
282
+
283
+ line, = ax.plot([], [], lw=2, color='#A23B72')
284
+ markers, = ax.plot([], [], 'o', color='#A23B72', markersize=5)
285
+
286
  def init():
287
+ line.set_data([], [])
288
+ markers.set_data([], [])
289
+ return line, markers
290
  def update(i):
 
 
 
 
291
  k = max(2, int(len(x_full) * (i / (frames - 1))))
292
+ line.set_data(x_full[:k], y_full[:k])
293
+ markers.set_data(x_full[:k], y_full[:k])
294
+ return line, markers
295
  init_func, update_func = init, update
296
+
297
+ anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
298
+ anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
299
+ plt.close(fig)
300
+ return str(out)
301
+
302
+ # Fallback to the slightly slower but reliable blit=False for other types
303
+ # This ensures stability across all chart types while the line chart is optimized
304
+ if ctype == "bar":
305
  bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
306
  ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
307
  ax.set_title(spec.title); plt.xticks(rotation=45, ha="right")
 
321
  ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
322
  def init():
323
  scat.set_offsets(np.empty((0, 2))); line.set_data([], [])
324
+ return []
325
  def update(i):
326
  point_frames = int(frames * 0.7)
327
  if i <= point_frames:
 
331
  line_frame = i - point_frames; line_total_frames = frames - point_frames
332
  current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
333
  line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
334
+ return []
335
  init_func, update_func = init, update
336
  elif ctype == "pie":
337
  wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
338
  ax.set_title(spec.title); ax.axis('equal')
339
+ def init(): [w.set_alpha(0) for w in wedges]; return []
340
+ def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return []
341
  init_func, update_func = init, update
342
  elif ctype == "hist":
343
  _, _, patches = ax.hist(plot_data, bins=20, alpha=0)
344
  ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
345
+ def init(): [p.set_alpha(0) for p in patches]; return []
346
+ def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return []
347
  init_func, update_func = init, update
348
  elif ctype == "heatmap":
349
  sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax, alpha=0)
350
  ax.set_title(spec.title)
351
+ def init(): ax.collections[0].set_alpha(0); return []
352
+ def update(i): ax.collections[0].set_alpha(i / (frames - 1)); return []
353
  init_func, update_func = init, update
354
  else:
355
  ax.text(0.5, 0.5, f"'{ctype}' animation not implemented", ha='center', va='center')
 
361
  anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
362
  plt.close(fig)
363
  return str(out)
 
 
 
 
 
 
 
 
 
 
364
 
365
  def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path, context: Dict) -> str:
366
  try:
 
611
  video_dur = audio_dur + 1.5
612
 
613
  try:
614
+ chart_descs = extract_chart_tags(sc)
615
+ pexels_descs = extract_pexels_tags(sc)
616
+ is_conclusion_scene = any(k in narrative.lower() for k in ["conclusion", "summary", "in closing", "final thoughts"])
617
+
618
+ if pexels_descs:
619
+ logging.info(f"Scene {i+1}: Processing Pexels scene.")
620
+ base_keywords = extract_keywords_for_query(narrative, llm)
621
+ final_query = f"{base_keywords} {domain}"
622
+ video_path = search_and_download_pexels_video(final_query, video_dur, mp4)
623
+ if not video_path: raise ValueError("Pexels search returned no results for chained query.")
 
 
 
624
  video_parts.append(video_path)
625
  if is_conclusion_scene:
626
  conclusion_video_path = video_path
627
+ elif chart_descs:
628
+ logging.info(f"Scene {i+1}: Primary attempt with animated chart.")
629
+ safe_chart(chart_descs[0], df, video_dur, mp4, data_context)
630
+ video_parts.append(str(mp4))
631
  else:
632
+ raise ValueError("No visual tag found in scene.")
 
 
 
 
 
 
633
  except Exception as e:
634
  logging.warning(f"Scene {i+1}: Primary visual failed ({e}). Marking for fallback.")
635
  video_parts.append("FALLBACK_NEEDED")
636
 
637
  temps.append(mp4)
638
 
639
+ if not conclusion_video_path:
 
640
  logging.warning("No conclusion video was generated; creating a generic one for fallbacks.")
641
  fallback_mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
642
  conclusion_video_path = search_and_download_pexels_video(f"data visualization abstract {domain}", 5.0, fallback_mp4)