Spaces:
Running on Zero
Running on Zero
Commit ·
0429f8a
1
Parent(s): d0e121d
Support videos longer than 8.2s via overlapping inference + 2s +3dB crossfade stitching
Browse files
app.py
CHANGED
|
@@ -49,6 +49,105 @@ def strip_audio_from_video(video_path, output_path):
|
|
| 49 |
)
|
| 50 |
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
@spaces.GPU(duration=300)
|
| 53 |
def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
|
| 54 |
seed_val = int(seed_val)
|
|
@@ -101,74 +200,80 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
|
|
| 101 |
cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
|
| 102 |
onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
|
| 103 |
|
| 104 |
-
sr
|
| 105 |
-
truncate
|
| 106 |
-
fps
|
| 107 |
-
truncate_frame = int(fps * truncate / sr)
|
| 108 |
-
truncate_onset = 120
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
with torch.no_grad():
|
| 143 |
-
if mode == "sde":
|
| 144 |
-
samples = euler_maruyama_sampler(**sampling_kwargs)
|
| 145 |
-
else:
|
| 146 |
-
samples = euler_sampler(**sampling_kwargs)
|
| 147 |
-
|
| 148 |
-
samples = vae.decode(samples / latents_scale).sample
|
| 149 |
-
|
| 150 |
-
# Cast to float32 before vocoder (HiFi-GAN requires float32)
|
| 151 |
-
wav_samples = vocoder(samples.squeeze().float()).detach().cpu().numpy()
|
| 152 |
|
| 153 |
audio_path = os.path.join(tmp_dir, "output.wav")
|
| 154 |
-
sf.write(audio_path,
|
| 155 |
|
| 156 |
-
|
| 157 |
-
trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
|
| 158 |
output_video = os.path.join(tmp_dir, "output.mp4")
|
| 159 |
-
|
| 160 |
-
(
|
| 161 |
-
ffmpeg
|
| 162 |
-
.input(silent_video, ss=0, t=duration)
|
| 163 |
-
.output(trimmed_video, vcodec="libx264", an=None)
|
| 164 |
-
.run(overwrite_output=True, quiet=True)
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
-
input_v = ffmpeg.input(trimmed_video)
|
| 168 |
input_a = ffmpeg.input(audio_path)
|
| 169 |
(
|
| 170 |
ffmpeg
|
| 171 |
-
.output(input_v, input_a, output_video,
|
|
|
|
| 172 |
.run(overwrite_output=True, quiet=True)
|
| 173 |
)
|
| 174 |
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
|
| 52 |
+
def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
                  seg_start_s, seg_end_s,
                  sr, fps, truncate_frame, truncate_onset, model_dur,
                  latents_scale, device, weight_dtype,
                  cfg_scale, num_steps, mode,
                  euler_sampler, euler_maruyama_sampler):
    """
    Run one model inference pass over the video window
    [seg_start_s, seg_start_s + model_dur].

    The model always consumes a full model_dur window of features
    (zero-padded near the end of the video). The decoded waveform is then
    trimmed to the actual segment length, so the returned array holds
    round((seg_end_s - seg_start_s) * sr) samples (or fewer only if the
    vocoder output itself is shorter).

    Args:
        model, vae, vocoder: diffusion model, latent VAE and vocoder
            (vocoder input must be float32 — see cast below).
        cavp_feats_full: CAVP visual features for the whole video,
            sampled at `fps` frames per second (numpy, time-major).
        onset_feats_full: onset features for the whole video, at
            truncate_onset frames per model_dur seconds (numpy, time-major).
        seg_start_s, seg_end_s: segment boundaries in seconds.
        sr: audio sample rate in Hz.
        fps: CAVP feature rate (frames per second).
        truncate_frame: CAVP frames consumed per model window.
        truncate_onset: onset frames consumed per model window.
        model_dur: model window duration in seconds (truncate / sr).
        latents_scale: VAE latent scaling tensor.
        device, weight_dtype: torch device / dtype for the features.
        cfg_scale, num_steps, mode: sampling hyper-parameters;
            mode == "sde" selects the Euler–Maruyama sampler.
        euler_sampler, euler_maruyama_sampler: sampler callables.

    Returns:
        1-D numpy float array of the segment's audio samples.
    """
    # -- CAVP features: sampled at `fps` frames per second --
    cavp_start = int(round(seg_start_s * fps))
    cavp_slice = cavp_feats_full[cavp_start : cavp_start + truncate_frame]
    if cavp_slice.shape[0] < truncate_frame:
        # Near the end of the video: zero-pad the time axis up to a full window.
        pad = np.zeros(
            (truncate_frame - cavp_slice.shape[0],) + cavp_slice.shape[1:],
            dtype=cavp_slice.dtype,
        )
        cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
    video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device).to(weight_dtype)

    # -- Onset features: truncate_onset frames per model_dur seconds --
    onset_fps = truncate_onset / model_dur  # onset frames per second
    onset_start = int(round(seg_start_s * onset_fps))
    onset_slice = onset_feats_full[onset_start : onset_start + truncate_onset]
    if onset_slice.shape[0] < truncate_onset:
        pad_len = truncate_onset - onset_slice.shape[0]
        # Pad only the leading (time) axis so this also works for onset
        # features that carry extra axes; the previous ((0, pad_len),)
        # pad spec was valid for 1-D arrays only.
        pad_width = [(0, pad_len)] + [(0, 0)] * (onset_slice.ndim - 1)
        onset_slice = np.pad(onset_slice, pad_width, mode="constant", constant_values=0)
    onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device).to(weight_dtype)

    # -- Diffusion sampling in the VAE latent space --
    # NOTE(review): latent grid 204x16 is assumed fixed by the model — confirm.
    z = torch.randn(1, model.in_channels, 204, 16, device=device).to(weight_dtype)
    sampling_kwargs = dict(
        model=model,
        latents=z,
        y=onset_feats_t,
        context=video_feats,
        num_steps=int(num_steps),
        heun=False,
        cfg_scale=float(cfg_scale),
        guidance_low=0.0,
        guidance_high=0.7,
        path_type="linear",
    )
    with torch.no_grad():
        if mode == "sde":
            samples = euler_maruyama_sampler(**sampling_kwargs)
        else:
            samples = euler_sampler(**sampling_kwargs)

    samples = vae.decode(samples / latents_scale).sample
    # Cast to float32 before the vocoder (HiFi-GAN requires float32).
    wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()

    # Trim the full model_dur window down to the actual segment length.
    seg_samples = int(round((seg_end_s - seg_start_s) * sr))
    return wav[:seg_samples]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def crossfade_join(wav_a, wav_b, crossfade_s, sr):
    """
    Join two waveforms with an equal-power ("+3 dB") crossfade.

    The last `crossfade_s` seconds of `wav_a` overlap the first
    `crossfade_s` seconds of `wav_b`. Across that window `wav_a` fades
    out with gain sqrt(1 - t) while `wav_b` fades in with gain sqrt(t)
    (t: 0 -> 1), so the summed *power* is constant through the joint and
    each curve sits 3 dB above a linear ramp at the midpoint — the
    behaviour audio editors label a "+3 dB" / constant-power crossfade.
    The ramp endpoints are exactly (1, 0) and (0, 1), so the result is
    sample-continuous with the untouched bodies on either side (no clicks).

    Args:
        wav_a, wav_b: 1-D numpy sample arrays at the same sample rate.
        crossfade_s: crossfade window length in seconds.
        sr: sample rate in Hz.

    Returns:
        np.ndarray of length len(wav_a) + len(wav_b) - cf_samples, where
        cf_samples is the crossfade window clamped to both inputs.
    """
    cf_samples = int(round(crossfade_s * sr))
    # Guard: a degenerate window, or a segment shorter than the window
    # (e.g. a tiny final segment), falls back to the largest overlap
    # available; zero overlap degrades to plain concatenation.
    cf_samples = max(0, min(cf_samples, len(wav_a), len(wav_b)))
    if cf_samples == 0:
        return np.concatenate([wav_a, wav_b])

    tail_a = wav_a[-cf_samples:]   # fade-out region of a
    head_b = wav_b[:cf_samples]    # fade-in region of b

    # Equal-power (constant-power) ramps.
    t = np.linspace(0.0, 1.0, cf_samples)
    overlap = tail_a * np.sqrt(1.0 - t) + head_b * np.sqrt(t)

    return np.concatenate([
        wav_a[:-cf_samples],   # body of a (before the crossfade)
        overlap,               # constant-power crossfade region
        wav_b[cf_samples:],    # body of b (after the crossfade)
    ])
|
| 149 |
+
|
| 150 |
+
|
| 151 |
@spaces.GPU(duration=300)
|
| 152 |
def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
|
| 153 |
seed_val = int(seed_val)
|
|
|
|
| 200 |
cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
|
| 201 |
onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
|
| 202 |
|
| 203 |
+
sr = 16000
|
| 204 |
+
truncate = 131072
|
| 205 |
+
fps = 4
|
| 206 |
+
truncate_frame = int(fps * truncate / sr) # 32 cavp frames per segment
|
| 207 |
+
truncate_onset = 120 # onset frames per segment
|
| 208 |
+
model_dur = truncate / sr # 8.192 s
|
| 209 |
+
crossfade_s = 2.0 # 2-second crossfade window
|
| 210 |
+
# Each segment starts (model_dur - crossfade_s) later than the previous,
|
| 211 |
+
# so the tails overlap by crossfade_s giving 1 s of extra audio on each side.
|
| 212 |
+
step_s = model_dur - crossfade_s # 6.192 s
|
| 213 |
|
| 214 |
latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
|
| 215 |
|
| 216 |
+
# Total video duration from cavp features
|
| 217 |
+
total_frames = cavp_feats.shape[0]
|
| 218 |
+
total_dur_s = total_frames / fps
|
| 219 |
+
|
| 220 |
+
# ------------------------------------------------------------------ #
|
| 221 |
+
# Build segment list: each entry is (seg_start_s, seg_end_s) #
|
| 222 |
+
# seg_end_s is the actual content end (clipped to video length), #
|
| 223 |
+
# but we always run the model for a full model_dur window. #
|
| 224 |
+
# ------------------------------------------------------------------ #
|
| 225 |
+
segments = []
|
| 226 |
+
seg_start = 0.0
|
| 227 |
+
while True:
|
| 228 |
+
seg_end = min(seg_start + model_dur, total_dur_s)
|
| 229 |
+
segments.append((seg_start, seg_end))
|
| 230 |
+
if seg_end >= total_dur_s:
|
| 231 |
+
break
|
| 232 |
+
seg_start += step_s
|
| 233 |
+
|
| 234 |
+
# ------------------------------------------------------------------ #
|
| 235 |
+
# Run inference for every segment #
|
| 236 |
+
# ------------------------------------------------------------------ #
|
| 237 |
+
wavs = []
|
| 238 |
+
for seg_start_s, seg_end_s in segments:
|
| 239 |
+
print(f"Inferring segment {seg_start_s:.2f}s – {seg_end_s:.2f}s ...")
|
| 240 |
+
wav = infer_segment(
|
| 241 |
+
model, vae, vocoder,
|
| 242 |
+
cavp_feats, onset_feats,
|
| 243 |
+
seg_start_s, seg_end_s,
|
| 244 |
+
sr, fps, truncate_frame, truncate_onset, model_dur,
|
| 245 |
+
latents_scale, device, weight_dtype,
|
| 246 |
+
cfg_scale, num_steps, mode,
|
| 247 |
+
euler_sampler, euler_maruyama_sampler,
|
| 248 |
)
|
| 249 |
+
wavs.append(wav)
|
| 250 |
+
|
| 251 |
+
# ------------------------------------------------------------------ #
|
| 252 |
+
# Stitch with crossfades #
|
| 253 |
+
# Single segment: no crossfade needed #
|
| 254 |
+
# ------------------------------------------------------------------ #
|
| 255 |
+
if len(wavs) == 1:
|
| 256 |
+
final_wav = wavs[0]
|
| 257 |
+
else:
|
| 258 |
+
final_wav = wavs[0]
|
| 259 |
+
for next_wav in wavs[1:]:
|
| 260 |
+
final_wav = crossfade_join(final_wav, next_wav, crossfade_s, sr)
|
| 261 |
+
|
| 262 |
+
# Clip to exact video duration
|
| 263 |
+
target_samples = int(round(total_dur_s * sr))
|
| 264 |
+
final_wav = final_wav[:target_samples]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
audio_path = os.path.join(tmp_dir, "output.wav")
|
| 267 |
+
sf.write(audio_path, final_wav, sr)
|
| 268 |
|
| 269 |
+
# Mux original silent video (full length) with generated audio
|
|
|
|
| 270 |
output_video = os.path.join(tmp_dir, "output.mp4")
|
| 271 |
+
input_v = ffmpeg.input(silent_video)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
input_a = ffmpeg.input(audio_path)
|
| 273 |
(
|
| 274 |
ffmpeg
|
| 275 |
+
.output(input_v, input_a, output_video,
|
| 276 |
+
vcodec="libx264", acodec="aac", strict="experimental")
|
| 277 |
.run(overwrite_output=True, quiet=True)
|
| 278 |
)
|
| 279 |
|