Spaces:
Sleeping
Sleeping
Update sozo_gen.py
Browse files- sozo_gen.py +52 -55
sozo_gen.py
CHANGED
|
@@ -107,19 +107,19 @@ def detect_dataset_domain(df: pd.DataFrame) -> str:
|
|
| 107 |
# NEW: Keyword extraction for better Pexels searches
|
| 108 |
def extract_keywords_for_query(text: str, llm) -> str:
|
| 109 |
prompt = f"""
|
| 110 |
-
Extract
|
| 111 |
Focus on concrete actions and subjects.
|
| 112 |
-
Example: 'Our analysis shows a significant growth in quarterly revenue and strong partnerships.' -> 'data analysis growth
|
| 113 |
Output only the search query keywords, separated by spaces.
|
| 114 |
|
| 115 |
Text: "{text}"
|
| 116 |
"""
|
| 117 |
try:
|
| 118 |
response = llm.invoke(prompt).content.strip()
|
| 119 |
-
return response
|
| 120 |
except Exception as e:
|
| 121 |
logging.error(f"Keyword extraction failed: {e}. Using original text.")
|
| 122 |
-
return text
|
| 123 |
|
| 124 |
# UPDATED: Pexels search now loops short videos
|
| 125 |
def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
|
|
@@ -265,7 +265,7 @@ def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
|
|
| 265 |
|
| 266 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 267 |
plot_data = prepare_plot_data(spec, df)
|
| 268 |
-
frames = math.ceil(dur * fps)
|
| 269 |
fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
|
| 270 |
plt.tight_layout(pad=3.0)
|
| 271 |
ctype = spec.chart_type
|
|
@@ -275,20 +275,33 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
|
|
| 275 |
if ctype == "line":
|
| 276 |
plot_data = plot_data.sort_index()
|
| 277 |
x_full, y_full = plot_data.index, plot_data.values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
def init():
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
return
|
| 282 |
def update(i):
|
| 283 |
-
ax.clear()
|
| 284 |
-
ax.set_xlim(x_full.min(), x_full.max())
|
| 285 |
-
ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
|
| 286 |
-
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 287 |
k = max(2, int(len(x_full) * (i / (frames - 1))))
|
| 288 |
-
|
| 289 |
-
|
|
|
|
| 290 |
init_func, update_func = init, update
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
|
| 293 |
ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
|
| 294 |
ax.set_title(spec.title); plt.xticks(rotation=45, ha="right")
|
|
@@ -308,7 +321,7 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
|
|
| 308 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 309 |
def init():
|
| 310 |
scat.set_offsets(np.empty((0, 2))); line.set_data([], [])
|
| 311 |
-
return [
|
| 312 |
def update(i):
|
| 313 |
point_frames = int(frames * 0.7)
|
| 314 |
if i <= point_frames:
|
|
@@ -318,25 +331,25 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
|
|
| 318 |
line_frame = i - point_frames; line_total_frames = frames - point_frames
|
| 319 |
current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
|
| 320 |
line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
|
| 321 |
-
return [
|
| 322 |
init_func, update_func = init, update
|
| 323 |
elif ctype == "pie":
|
| 324 |
wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
|
| 325 |
ax.set_title(spec.title); ax.axis('equal')
|
| 326 |
-
def init(): [w.set_alpha(0) for w in wedges]; return
|
| 327 |
-
def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return
|
| 328 |
init_func, update_func = init, update
|
| 329 |
elif ctype == "hist":
|
| 330 |
_, _, patches = ax.hist(plot_data, bins=20, alpha=0)
|
| 331 |
ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
|
| 332 |
-
def init(): [p.set_alpha(0) for p in patches]; return
|
| 333 |
-
def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return
|
| 334 |
init_func, update_func = init, update
|
| 335 |
elif ctype == "heatmap":
|
| 336 |
sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax, alpha=0)
|
| 337 |
ax.set_title(spec.title)
|
| 338 |
-
def init(): ax.collections[0].set_alpha(0); return [
|
| 339 |
-
def update(i): ax.collections[0].set_alpha(i / (frames - 1)); return [
|
| 340 |
init_func, update_func = init, update
|
| 341 |
else:
|
| 342 |
ax.text(0.5, 0.5, f"'{ctype}' animation not implemented", ha='center', va='center')
|
|
@@ -348,16 +361,6 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
|
|
| 348 |
anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
|
| 349 |
plt.close(fig)
|
| 350 |
return str(out)
|
| 351 |
-
|
| 352 |
-
def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
|
| 353 |
-
fourcc = cv2.VideoWriter_fourcc(*'mp4v'); video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
|
| 354 |
-
total_frames = max(1, int(dur * fps))
|
| 355 |
-
for i in range(total_frames):
|
| 356 |
-
alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
|
| 357 |
-
frame = cv2.addWeighted(img, alpha, np.zeros_like(img), 1 - alpha, 0)
|
| 358 |
-
video_writer.write(frame)
|
| 359 |
-
video_writer.release()
|
| 360 |
-
return str(out)
|
| 361 |
|
| 362 |
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path, context: Dict) -> str:
|
| 363 |
try:
|
|
@@ -608,38 +611,32 @@ def generate_video_from_project(df: pd.DataFrame, raw_md: str, data_context: Dic
|
|
| 608 |
video_dur = audio_dur + 1.5
|
| 609 |
|
| 610 |
try:
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
is_conclusion_scene = any(k in
|
| 614 |
-
|
| 615 |
-
if
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
logging.info(f"Scene {i+1}: Pre-emptive guard triggered. Query: '{primary_query}'")
|
| 622 |
-
video_path = search_and_download_pexels_video(primary_query, video_dur, mp4)
|
| 623 |
-
if not video_path: raise ValueError("Pexels search failed for guarded query.")
|
| 624 |
video_parts.append(video_path)
|
| 625 |
if is_conclusion_scene:
|
| 626 |
conclusion_video_path = video_path
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
else:
|
| 628 |
-
|
| 629 |
-
if chart_descs:
|
| 630 |
-
logging.info(f"Scene {i+1}: Primary attempt with animated chart.")
|
| 631 |
-
safe_chart(chart_descs[0], df, video_dur, mp4, data_context)
|
| 632 |
-
video_parts.append(str(mp4))
|
| 633 |
-
else:
|
| 634 |
-
raise ValueError("No chart tag found in a middle scene.")
|
| 635 |
except Exception as e:
|
| 636 |
logging.warning(f"Scene {i+1}: Primary visual failed ({e}). Marking for fallback.")
|
| 637 |
video_parts.append("FALLBACK_NEEDED")
|
| 638 |
|
| 639 |
temps.append(mp4)
|
| 640 |
|
| 641 |
-
|
| 642 |
-
if not conclusion_video_path: # Failsafe if conclusion scene itself failed
|
| 643 |
logging.warning("No conclusion video was generated; creating a generic one for fallbacks.")
|
| 644 |
fallback_mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
|
| 645 |
conclusion_video_path = search_and_download_pexels_video(f"data visualization abstract {domain}", 5.0, fallback_mp4)
|
|
|
|
| 107 |
# NEW: Keyword extraction for better Pexels searches
|
| 108 |
def extract_keywords_for_query(text: str, llm) -> str:
|
| 109 |
prompt = f"""
|
| 110 |
+
Extract a maximum of 3 key nouns or verbs from the following text to use as a search query for a stock video.
|
| 111 |
Focus on concrete actions and subjects.
|
| 112 |
+
Example: 'Our analysis shows a significant growth in quarterly revenue and strong partnerships.' -> 'data analysis growth'
|
| 113 |
Output only the search query keywords, separated by spaces.
|
| 114 |
|
| 115 |
Text: "{text}"
|
| 116 |
"""
|
| 117 |
try:
|
| 118 |
response = llm.invoke(prompt).content.strip()
|
| 119 |
+
return response if response else text
|
| 120 |
except Exception as e:
|
| 121 |
logging.error(f"Keyword extraction failed: {e}. Using original text.")
|
| 122 |
+
return text
|
| 123 |
|
| 124 |
# UPDATED: Pexels search now loops short videos
|
| 125 |
def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
|
|
|
|
| 265 |
|
| 266 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 267 |
plot_data = prepare_plot_data(spec, df)
|
| 268 |
+
frames = math.ceil(dur * fps)
|
| 269 |
fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
|
| 270 |
plt.tight_layout(pad=3.0)
|
| 271 |
ctype = spec.chart_type
|
|
|
|
| 275 |
if ctype == "line":
|
| 276 |
plot_data = plot_data.sort_index()
|
| 277 |
x_full, y_full = plot_data.index, plot_data.values
|
| 278 |
+
|
| 279 |
+
ax.set_xlim(x_full.min(), x_full.max())
|
| 280 |
+
ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
|
| 281 |
+
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 282 |
+
|
| 283 |
+
line, = ax.plot([], [], lw=2, color='#A23B72')
|
| 284 |
+
markers, = ax.plot([], [], 'o', color='#A23B72', markersize=5)
|
| 285 |
+
|
| 286 |
def init():
|
| 287 |
+
line.set_data([], [])
|
| 288 |
+
markers.set_data([], [])
|
| 289 |
+
return line, markers
|
| 290 |
def update(i):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
k = max(2, int(len(x_full) * (i / (frames - 1))))
|
| 292 |
+
line.set_data(x_full[:k], y_full[:k])
|
| 293 |
+
markers.set_data(x_full[:k], y_full[:k])
|
| 294 |
+
return line, markers
|
| 295 |
init_func, update_func = init, update
|
| 296 |
+
|
| 297 |
+
anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
|
| 298 |
+
anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
|
| 299 |
+
plt.close(fig)
|
| 300 |
+
return str(out)
|
| 301 |
+
|
| 302 |
+
# Fallback to the slightly slower but reliable blit=False for other types
|
| 303 |
+
# This ensures stability across all chart types while the line chart is optimized
|
| 304 |
+
if ctype == "bar":
|
| 305 |
bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
|
| 306 |
ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
|
| 307 |
ax.set_title(spec.title); plt.xticks(rotation=45, ha="right")
|
|
|
|
| 321 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 322 |
def init():
|
| 323 |
scat.set_offsets(np.empty((0, 2))); line.set_data([], [])
|
| 324 |
+
return []
|
| 325 |
def update(i):
|
| 326 |
point_frames = int(frames * 0.7)
|
| 327 |
if i <= point_frames:
|
|
|
|
| 331 |
line_frame = i - point_frames; line_total_frames = frames - point_frames
|
| 332 |
current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
|
| 333 |
line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
|
| 334 |
+
return []
|
| 335 |
init_func, update_func = init, update
|
| 336 |
elif ctype == "pie":
|
| 337 |
wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
|
| 338 |
ax.set_title(spec.title); ax.axis('equal')
|
| 339 |
+
def init(): [w.set_alpha(0) for w in wedges]; return []
|
| 340 |
+
def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return []
|
| 341 |
init_func, update_func = init, update
|
| 342 |
elif ctype == "hist":
|
| 343 |
_, _, patches = ax.hist(plot_data, bins=20, alpha=0)
|
| 344 |
ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
|
| 345 |
+
def init(): [p.set_alpha(0) for p in patches]; return []
|
| 346 |
+
def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return []
|
| 347 |
init_func, update_func = init, update
|
| 348 |
elif ctype == "heatmap":
|
| 349 |
sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax, alpha=0)
|
| 350 |
ax.set_title(spec.title)
|
| 351 |
+
def init(): ax.collections[0].set_alpha(0); return []
|
| 352 |
+
def update(i): ax.collections[0].set_alpha(i / (frames - 1)); return []
|
| 353 |
init_func, update_func = init, update
|
| 354 |
else:
|
| 355 |
ax.text(0.5, 0.5, f"'{ctype}' animation not implemented", ha='center', va='center')
|
|
|
|
| 361 |
anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
|
| 362 |
plt.close(fig)
|
| 363 |
return str(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path, context: Dict) -> str:
|
| 366 |
try:
|
|
|
|
| 611 |
video_dur = audio_dur + 1.5
|
| 612 |
|
| 613 |
try:
|
| 614 |
+
chart_descs = extract_chart_tags(sc)
|
| 615 |
+
pexels_descs = extract_pexels_tags(sc)
|
| 616 |
+
is_conclusion_scene = any(k in narrative.lower() for k in ["conclusion", "summary", "in closing", "final thoughts"])
|
| 617 |
+
|
| 618 |
+
if pexels_descs:
|
| 619 |
+
logging.info(f"Scene {i+1}: Processing Pexels scene.")
|
| 620 |
+
base_keywords = extract_keywords_for_query(narrative, llm)
|
| 621 |
+
final_query = f"{base_keywords} {domain}"
|
| 622 |
+
video_path = search_and_download_pexels_video(final_query, video_dur, mp4)
|
| 623 |
+
if not video_path: raise ValueError("Pexels search returned no results for chained query.")
|
|
|
|
|
|
|
|
|
|
| 624 |
video_parts.append(video_path)
|
| 625 |
if is_conclusion_scene:
|
| 626 |
conclusion_video_path = video_path
|
| 627 |
+
elif chart_descs:
|
| 628 |
+
logging.info(f"Scene {i+1}: Primary attempt with animated chart.")
|
| 629 |
+
safe_chart(chart_descs[0], df, video_dur, mp4, data_context)
|
| 630 |
+
video_parts.append(str(mp4))
|
| 631 |
else:
|
| 632 |
+
raise ValueError("No visual tag found in scene.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
except Exception as e:
|
| 634 |
logging.warning(f"Scene {i+1}: Primary visual failed ({e}). Marking for fallback.")
|
| 635 |
video_parts.append("FALLBACK_NEEDED")
|
| 636 |
|
| 637 |
temps.append(mp4)
|
| 638 |
|
| 639 |
+
if not conclusion_video_path:
|
|
|
|
| 640 |
logging.warning("No conclusion video was generated; creating a generic one for fallbacks.")
|
| 641 |
fallback_mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
|
| 642 |
conclusion_video_path = search_and_download_pexels_video(f"data visualization abstract {domain}", 5.0, fallback_mp4)
|