Spaces:
Running on Zero
Running on Zero
Commit ·
5aeadc9
1
Parent(s): 39d7b17
Add per-model inference timing to calibrate GPU duration constants
Browse files
Wraps the segment inference loop in each generate_* function with
time.perf_counter() and prints actual wall-clock time, steps/segs,
and measured s/step alongside the current constant after every run.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
app.py
CHANGED
|
@@ -14,6 +14,7 @@ import random
|
|
| 14 |
from math import floor
|
| 15 |
from pathlib import Path
|
| 16 |
|
|
|
|
| 17 |
import torch
|
| 18 |
import numpy as np
|
| 19 |
import torchaudio
|
|
@@ -354,6 +355,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 354 |
set_global_seed(sample_seed)
|
| 355 |
onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
|
| 356 |
wavs = []
|
|
|
|
| 357 |
for seg_start_s, seg_end_s in segments:
|
| 358 |
print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s – {seg_end_s:.2f}s")
|
| 359 |
wav = _taro_infer_segment(
|
|
@@ -366,6 +368,12 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 366 |
euler_sampler, euler_maruyama_sampler,
|
| 367 |
)
|
| 368 |
wavs.append(wav)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
_TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
|
| 370 |
|
| 371 |
final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
|
|
@@ -474,6 +482,7 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
|
|
| 474 |
rng.seed()
|
| 475 |
|
| 476 |
seg_audios = [] # list of (channels, samples) numpy arrays
|
|
|
|
| 477 |
|
| 478 |
for seg_i, (seg_start, seg_end) in enumerate(segments):
|
| 479 |
seg_dur = seg_end - seg_start
|
|
@@ -512,6 +521,13 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
|
|
| 512 |
wav = wav[:, :seg_samples]
|
| 513 |
seg_audios.append(wav)
|
| 514 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
# Crossfade-stitch all segments using shared equal-power helper
|
| 516 |
full_wav = seg_audios[0]
|
| 517 |
for nw in seg_audios[1:]:
|
|
@@ -627,6 +643,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
|
| 627 |
for sample_idx in range(num_samples):
|
| 628 |
seg_wavs = []
|
| 629 |
sr = 48000 # HunyuanFoley always outputs 48 kHz
|
|
|
|
| 630 |
for seg_i, (seg_start, seg_end) in enumerate(segments):
|
| 631 |
seg_dur = seg_end - seg_start
|
| 632 |
seg_path = os.path.join(tmp_dir, f"seg_{sample_idx}_{seg_i}.mp4")
|
|
@@ -661,6 +678,13 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
|
| 661 |
wav = wav[:, :seg_samples]
|
| 662 |
seg_wavs.append(wav)
|
| 663 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
# Crossfade-stitch all segments using shared equal-power helper
|
| 665 |
full_wav = seg_wavs[0]
|
| 666 |
for nw in seg_wavs[1:]:
|
|
|
|
| 14 |
from math import floor
|
| 15 |
from pathlib import Path
|
| 16 |
|
| 17 |
+
import time
|
| 18 |
import torch
|
| 19 |
import numpy as np
|
| 20 |
import torchaudio
|
|
|
|
| 355 |
set_global_seed(sample_seed)
|
| 356 |
onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
|
| 357 |
wavs = []
|
| 358 |
+
_t_infer_start = time.perf_counter()
|
| 359 |
for seg_start_s, seg_end_s in segments:
|
| 360 |
print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s – {seg_end_s:.2f}s")
|
| 361 |
wav = _taro_infer_segment(
|
|
|
|
| 368 |
euler_sampler, euler_maruyama_sampler,
|
| 369 |
)
|
| 370 |
wavs.append(wav)
|
| 371 |
+
_t_infer_elapsed = time.perf_counter() - _t_infer_start
|
| 372 |
+
_n_segs = len(segments)
|
| 373 |
+
_secs_per_step = _t_infer_elapsed / (_n_segs * int(num_steps)) if _n_segs * int(num_steps) > 0 else 0
|
| 374 |
+
print(f"[TARO] Inference done: {_n_segs} seg(s) × {int(num_steps)} steps in "
|
| 375 |
+
f"{_t_infer_elapsed:.1f}s wall → {_secs_per_step:.3f}s/step "
|
| 376 |
+
f"(current constant={TARO_SECS_PER_STEP})")
|
| 377 |
_TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
|
| 378 |
|
| 379 |
final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
|
|
|
|
| 482 |
rng.seed()
|
| 483 |
|
| 484 |
seg_audios = [] # list of (channels, samples) numpy arrays
|
| 485 |
+
_t_mma_start = time.perf_counter()
|
| 486 |
|
| 487 |
for seg_i, (seg_start, seg_end) in enumerate(segments):
|
| 488 |
seg_dur = seg_end - seg_start
|
|
|
|
| 521 |
wav = wav[:, :seg_samples]
|
| 522 |
seg_audios.append(wav)
|
| 523 |
|
| 524 |
+
_t_mma_elapsed = time.perf_counter() - _t_mma_start
|
| 525 |
+
_n_segs_mma = len(segments)
|
| 526 |
+
_secs_per_step_mma = _t_mma_elapsed / (_n_segs_mma * int(num_steps)) if _n_segs_mma * int(num_steps) > 0 else 0
|
| 527 |
+
print(f"[MMAudio] Inference done: {_n_segs_mma} seg(s) × {int(num_steps)} steps in "
|
| 528 |
+
f"{_t_mma_elapsed:.1f}s wall → {_secs_per_step_mma:.3f}s/step "
|
| 529 |
+
f"(current constant={MMAUDIO_SECS_PER_STEP})")
|
| 530 |
+
|
| 531 |
# Crossfade-stitch all segments using shared equal-power helper
|
| 532 |
full_wav = seg_audios[0]
|
| 533 |
for nw in seg_audios[1:]:
|
|
|
|
| 643 |
for sample_idx in range(num_samples):
|
| 644 |
seg_wavs = []
|
| 645 |
sr = 48000 # HunyuanFoley always outputs 48 kHz
|
| 646 |
+
_t_hny_start = time.perf_counter()
|
| 647 |
for seg_i, (seg_start, seg_end) in enumerate(segments):
|
| 648 |
seg_dur = seg_end - seg_start
|
| 649 |
seg_path = os.path.join(tmp_dir, f"seg_{sample_idx}_{seg_i}.mp4")
|
|
|
|
| 678 |
wav = wav[:, :seg_samples]
|
| 679 |
seg_wavs.append(wav)
|
| 680 |
|
| 681 |
+
_t_hny_elapsed = time.perf_counter() - _t_hny_start
|
| 682 |
+
_n_segs_hny = len(segments)
|
| 683 |
+
_secs_per_step_hny = _t_hny_elapsed / (_n_segs_hny * int(num_steps)) if _n_segs_hny * int(num_steps) > 0 else 0
|
| 684 |
+
print(f"[HunyuanFoley] Inference done: {_n_segs_hny} seg(s) × {int(num_steps)} steps in "
|
| 685 |
+
f"{_t_hny_elapsed:.1f}s wall → {_secs_per_step_hny:.3f}s/step "
|
| 686 |
+
f"(current constant={HUNYUAN_SECS_PER_STEP})")
|
| 687 |
+
|
| 688 |
# Crossfade-stitch all segments using shared equal-power helper
|
| 689 |
full_wav = seg_wavs[0]
|
| 690 |
for nw in seg_wavs[1:]:
|