JackIsNotInTheBox committed on
Commit
5ce02be
·
1 Parent(s): 39355fb

Fix RuntimeError: pad onset_feats to truncate_onset=120 when video is shorter than 8.2s

Browse files
Files changed (1) hide show
  1. app.py +88 -82
app.py CHANGED
@@ -3,14 +3,14 @@ import subprocess
3
  import sys
4
 
5
  try:
6
- import mmcv
7
- print("mmcv already installed")
8
  except ImportError:
9
- print("Installing mmcv with --no-build-isolation...")
10
- subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"])
11
- print("mmcv installed successfully")
12
 
13
- import torch
14
  import numpy as np
15
  import random
16
  import soundfile as sf
@@ -30,31 +30,28 @@ onset_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="onset_model.ckpt",
30
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
31
  print("Checkpoints downloaded.")
32
 
33
-
34
  def set_global_seed(seed):
35
- np.random.seed(seed % (2**32))
36
- random.seed(seed)
37
- torch.manual_seed(seed)
38
- torch.cuda.manual_seed(seed)
39
- torch.backends.cudnn.deterministic = True
40
-
41
-
42
- def strip_audio_from_video(video_path, output_path):
43
- """Strip any existing audio from a video file, outputting a silent video."""
44
- (
45
- ffmpeg
46
- .input(video_path)
47
- .output(output_path, vcodec="libx264", an=None)
48
- .run(overwrite_output=True, quiet=True)
49
- )
50
-
51
-
52
- @spaces.GPU(duration=300)
53
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
54
- seed_val = int(seed_val)
55
- if seed_val < 0:
56
- seed_val = random.randint(0, 2**32 - 1)
57
- set_global_seed(seed_val)
58
  torch.set_grad_enabled(False)
59
  device = "cuda" if torch.cuda.is_available() else "cpu"
60
  weight_dtype = torch.bfloat16
@@ -70,14 +67,13 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
70
  state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
71
  new_state_dict = {}
72
  for key, value in state_dict.items():
73
- if "model.net.model" in key:
74
- new_key = key.replace("model.net.model", "net.model")
75
- elif "model.fc." in key:
76
- new_key = key.replace("model.fc", "fc")
77
- else:
78
- new_key = key
79
- new_state_dict[new_key] = value
80
-
81
  onset_model = VideoOnsetNet(False).to(device)
82
  onset_model.load_state_dict(new_state_dict)
83
  onset_model.eval()
@@ -94,7 +90,6 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
94
  vocoder = model_audioldm.vocoder.to(device)
95
 
96
  tmp_dir = tempfile.mkdtemp()
97
-
98
  silent_video = os.path.join(tmp_dir, "silent_input.mp4")
99
  strip_audio_from_video(video_file, silent_video)
100
 
@@ -110,30 +105,43 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
110
  latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
111
 
112
  video_feats = torch.from_numpy(cavp_feats[:truncate_frame]).unsqueeze(0).to(device).to(weight_dtype)
113
- onset_feats_t = torch.from_numpy(onset_feats[:truncate_onset]).unsqueeze(0).to(device).to(weight_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
116
 
117
  sampling_kwargs = dict(
118
- model=model,
119
- latents=z,
120
- y=onset_feats_t,
121
- context=video_feats,
122
- num_steps=int(num_steps),
123
- heun=False,
124
- cfg_scale=float(cfg_scale),
125
- guidance_low=0.0,
126
- guidance_high=0.7,
127
- path_type="linear",
128
  )
129
 
130
  with torch.no_grad():
131
- if mode == "sde":
132
- samples = euler_maruyama_sampler(**sampling_kwargs)
133
- else:
134
- samples = euler_sampler(**sampling_kwargs)
 
 
135
 
136
- samples = vae.decode(samples / latents_scale).sample
137
  # Cast to float32 before vocoder (HiFi-GAN requires float32)
138
  wav_samples = vocoder(samples.squeeze().float()).detach().cpu().numpy()
139
 
@@ -141,45 +149,43 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
141
  sf.write(audio_path, wav_samples, sr)
142
 
143
  duration = truncate / sr
144
-
145
  trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
146
  output_video = os.path.join(tmp_dir, "output.mp4")
 
147
  (
148
- ffmpeg
149
- .input(silent_video, ss=0, t=duration)
150
- .output(trimmed_video, vcodec="libx264", an=None)
151
- .run(overwrite_output=True, quiet=True)
152
  )
153
 
154
  input_v = ffmpeg.input(trimmed_video)
155
  input_a = ffmpeg.input(audio_path)
156
  (
157
- ffmpeg
158
- .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
159
- .run(overwrite_output=True, quiet=True)
160
  )
161
 
162
  return output_video, audio_path
163
 
164
-
165
  def get_random_seed():
166
- return random.randint(0, 2**32 - 1)
167
-
168
-
169
- demo = gr.Interface(
170
- fn=generate_audio,
171
- inputs=[
172
- gr.Video(label="Input Video"),
173
- gr.Number(label="Seed", value=get_random_seed, precision=0),
174
- gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
175
- gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
176
- gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
177
- ],
178
- outputs=[
179
- gr.Video(label="Output Video with Audio"),
180
- gr.Audio(label="Generated Audio"),
181
- ],
182
- title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
183
- description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s.",
184
- )
185
  demo.queue().launch()
 
3
  import sys
4
 
5
# Make sure mmcv is importable; install it lazily on first run of the Space.
try:
    import mmcv  # noqa: F401 -- imported only to probe availability
    print("mmcv already installed")
except ImportError:
    # --no-build-isolation lets the mmcv build see the already-installed torch.
    print("Installing mmcv with --no-build-isolation...")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"]
    )
    print("mmcv installed successfully")
12
 
13
+ import torch
14
  import numpy as np
15
  import random
16
  import soundfile as sf
 
30
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
31
  print("Checkpoints downloaded.")
32
 
 
33
def set_global_seed(seed):
    """Seed every RNG used by the pipeline so generation is reproducible.

    Covers Python's ``random``, NumPy, and PyTorch (CPU and all CUDA
    devices), and forces deterministic cuDNN kernel selection.

    Args:
        seed: Integer seed. Reduced mod 2**32 for NumPy, whose seeder only
            accepts 32-bit values.
    """
    np.random.seed(seed % (2**32))
    random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every CUDA device, not just the current one;
    # it is a safe no-op on CPU-only hosts.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
39
+
40
def strip_audio_from_video(video_path, output_path):
    """Strip any existing audio from a video file, outputting a silent video.

    Args:
        video_path: Path to the input video (with or without an audio track).
        output_path: Destination path for the re-encoded, audio-free video.
    """
    (
        ffmpeg
        .input(video_path)
        # an=None maps to ffmpeg's `-an` flag: drop all audio streams.
        .output(output_path, vcodec="libx264", an=None)
        .run(overwrite_output=True, quiet=True)
    )
48
+
49
+ @spaces.GPU(duration=300)
 
 
50
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
51
+ seed_val = int(seed_val)
52
+ if seed_val < 0:
53
+ seed_val = random.randint(0, 2**32 - 1)
54
+ set_global_seed(seed_val)
55
  torch.set_grad_enabled(False)
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
57
  weight_dtype = torch.bfloat16
 
67
  state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
68
  new_state_dict = {}
69
  for key, value in state_dict.items():
70
+ if "model.net.model" in key:
71
+ new_key = key.replace("model.net.model", "net.model")
72
+ elif "model.fc." in key:
73
+ new_key = key.replace("model.fc", "fc")
74
+ else:
75
+ new_key = key
76
+ new_state_dict[new_key] = value
 
77
  onset_model = VideoOnsetNet(False).to(device)
78
  onset_model.load_state_dict(new_state_dict)
79
  onset_model.eval()
 
90
  vocoder = model_audioldm.vocoder.to(device)
91
 
92
  tmp_dir = tempfile.mkdtemp()
 
93
  silent_video = os.path.join(tmp_dir, "silent_input.mp4")
94
  strip_audio_from_video(video_file, silent_video)
95
 
 
105
  latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
106
 
107
  video_feats = torch.from_numpy(cavp_feats[:truncate_frame]).unsqueeze(0).to(device).to(weight_dtype)
108
+
109
+ # Slice onset features and pad to truncate_onset if the video is shorter than expected
110
+ onset_feats_sliced = onset_feats[:truncate_onset]
111
+ actual_onset_len = onset_feats_sliced.shape[0]
112
+ if actual_onset_len < truncate_onset:
113
+ pad_len = truncate_onset - actual_onset_len
114
+ onset_feats_sliced = np.pad(
115
+ onset_feats_sliced,
116
+ ((0, pad_len),),
117
+ mode="constant",
118
+ constant_values=0,
119
+ )
120
+ onset_feats_t = torch.from_numpy(onset_feats_sliced).unsqueeze(0).to(device).to(weight_dtype)
121
 
122
  z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
123
 
124
  sampling_kwargs = dict(
125
+ model=model,
126
+ latents=z,
127
+ y=onset_feats_t,
128
+ context=video_feats,
129
+ num_steps=int(num_steps),
130
+ heun=False,
131
+ cfg_scale=float(cfg_scale),
132
+ guidance_low=0.0,
133
+ guidance_high=0.7,
134
+ path_type="linear",
135
  )
136
 
137
  with torch.no_grad():
138
+ if mode == "sde":
139
+ samples = euler_maruyama_sampler(**sampling_kwargs)
140
+ else:
141
+ samples = euler_sampler(**sampling_kwargs)
142
+
143
+ samples = vae.decode(samples / latents_scale).sample
144
 
 
145
  # Cast to float32 before vocoder (HiFi-GAN requires float32)
146
  wav_samples = vocoder(samples.squeeze().float()).detach().cpu().numpy()
147
 
 
149
  sf.write(audio_path, wav_samples, sr)
150
 
151
  duration = truncate / sr
 
152
  trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
153
  output_video = os.path.join(tmp_dir, "output.mp4")
154
+
155
  (
156
+ ffmpeg
157
+ .input(silent_video, ss=0, t=duration)
158
+ .output(trimmed_video, vcodec="libx264", an=None)
159
+ .run(overwrite_output=True, quiet=True)
160
  )
161
 
162
  input_v = ffmpeg.input(trimmed_video)
163
  input_a = ffmpeg.input(audio_path)
164
  (
165
+ ffmpeg
166
+ .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
167
+ .run(overwrite_output=True, quiet=True)
168
  )
169
 
170
  return output_video, audio_path
171
 
 
172
def get_random_seed():
    """Return a uniformly random seed in the 32-bit unsigned range."""
    return random.getrandbits(32)
174
+
175
# Gradio UI: wires the video-to-audio generator into a simple web form.
# The Seed field defaults to a fresh random value on each page load.
demo = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Number(label="Seed", value=get_random_seed, precision=0),
        gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
        gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
        gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
    ],
    outputs=[
        gr.Video(label="Output Video with Audio"),
        gr.Audio(label="Generated Audio"),
    ],
    title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
    description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s.",
)
# queue() serializes GPU jobs so concurrent users don't contend for the device.
demo.queue().launch()