rairo commited on
Commit
aa7e0f2
·
verified ·
1 Parent(s): c13819e

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +47 -23
sozo_gen.py CHANGED
@@ -366,12 +366,24 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
366
  plt.tight_layout(pad=3.0)
367
  ctype = spec.chart_type
368
 
369
- # Animation logic remains the same, only the final call to FuncAnimation changes
370
- if ctype == "pie":
371
- wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
372
- ax.set_title(spec.title); ax.axis('equal')
373
- def init(): [w.set_alpha(0) for w in wedges]; return wedges
374
- def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return wedges
 
 
 
 
 
 
 
 
 
 
 
 
375
  elif ctype == "bar":
376
  bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
377
  ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
@@ -380,6 +392,7 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
380
  def update(i):
381
  for b, h in zip(bars, plot_data.values): b.set_height(h * (i / (frames - 1)))
382
  return bars
 
383
  elif ctype == "scatter":
384
  x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
385
  slope, intercept, _, _, _ = stats.linregress(x_full, y_full)
@@ -402,25 +415,36 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
402
  current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
403
  line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
404
  return [scat, line]
405
- else: # line, area, hist, etc.
406
- # This is a simplified representation; the full logic from previous steps is assumed here
407
- # For brevity, we'll just show the line chart example
408
- line, = ax.plot([], [], lw=2, color='#A23B72')
409
- plot_data = plot_data.sort_index()
410
- x_full, y_full = plot_data.index, plot_data.values
411
- ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
412
- ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
413
- def init(): line.set_data([], []); return [line]
414
- def update(i):
415
- k = max(2, int(len(x_full) * (i / (frames - 1))))
416
- line.set_data(x_full[:k], y_full[:k]); return [line]
417
-
418
- # The key change: blit=False
419
- anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=False, interval=1000 / fps)
 
 
 
 
 
 
 
 
 
 
 
420
  anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
421
  plt.close(fig)
422
  return str(out)
423
-
424
  def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
425
  fourcc = cv2.VideoWriter_fourcc(*'mp4v'); video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
426
  total_frames = max(1, int(dur * fps))
@@ -674,7 +698,7 @@ def generate_video_from_project(df: pd.DataFrame, raw_md: str, data_context: Dic
674
  audio_parts.append(str(mp3)); temps.append(mp3)
675
  total_audio_duration += audio_dur
676
 
677
- video_dur = audio_dur + 0.5
678
 
679
  try:
680
  # --- Primary Visual Generation ---
 
366
  plt.tight_layout(pad=3.0)
367
  ctype = spec.chart_type
368
 
369
+ init_func, update_func = None, None
370
+
371
+ if ctype == "line":
372
+ plot_data = plot_data.sort_index()
373
+ x_full, y_full = plot_data.index, plot_data.values
374
+ def init():
375
+ ax.set_xlim(x_full.min(), x_full.max())
376
+ ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
377
+ return []
378
+ def update(i):
379
+ ax.clear()
380
+ ax.set_xlim(x_full.min(), x_full.max())
381
+ ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
382
+ ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
383
+ k = max(2, int(len(x_full) * (i / (frames - 1))))
384
+ ax.plot(x_full[:k], y_full[:k], lw=2, color='#A23B72', marker='o', markersize=5)
385
+ return []
386
+ init_func, update_func = init, update
387
  elif ctype == "bar":
388
  bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
389
  ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
 
392
  def update(i):
393
  for b, h in zip(bars, plot_data.values): b.set_height(h * (i / (frames - 1)))
394
  return bars
395
+ init_func, update_func = init, update
396
  elif ctype == "scatter":
397
  x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
398
  slope, intercept, _, _, _ = stats.linregress(x_full, y_full)
 
415
  current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
416
  line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
417
  return [scat, line]
418
+ init_func, update_func = init, update
419
+ elif ctype == "pie":
420
+ wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
421
+ ax.set_title(spec.title); ax.axis('equal')
422
+ def init(): [w.set_alpha(0) for w in wedges]; return wedges
423
+ def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return wedges
424
+ init_func, update_func = init, update
425
+ elif ctype == "hist":
426
+ _, _, patches = ax.hist(plot_data, bins=20, alpha=0)
427
+ ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
428
+ def init(): [p.set_alpha(0) for p in patches]; return patches
429
+ def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return patches
430
+ init_func, update_func = init, update
431
+ elif ctype == "heatmap":
432
+ sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax, alpha=0)
433
+ ax.set_title(spec.title)
434
+ def init(): ax.collections[0].set_alpha(0); return [ax.collections[0]]
435
+ def update(i): ax.collections[0].set_alpha(i / (frames - 1)); return [ax.collections[0]]
436
+ init_func, update_func = init, update
437
+ else:
438
+ ax.text(0.5, 0.5, f"'{ctype}' animation not implemented", ha='center', va='center')
439
+ def init(): return []
440
+ def update(i): return []
441
+ init_func, update_func = init, update
442
+
443
+ anim = FuncAnimation(fig, update_func, init_func=init_func, frames=frames, blit=False, interval=1000 / fps)
444
  anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
445
  plt.close(fig)
446
  return str(out)
447
+
448
  def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
449
  fourcc = cv2.VideoWriter_fourcc(*'mp4v'); video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
450
  total_frames = max(1, int(dur * fps))
 
698
  audio_parts.append(str(mp3)); temps.append(mp3)
699
  total_audio_duration += audio_dur
700
 
701
+ video_dur = audio_dur + 1.5
702
 
703
  try:
704
  # --- Primary Visual Generation ---