JackIsNotInTheBox committed on
Commit
8443669
·
1 Parent(s): d9d27f2

Fix vocoder float32 cast error; default seed is now random

Browse files
Files changed (1) hide show
  1. app.py +80 -71
app.py CHANGED
@@ -3,14 +3,14 @@ import subprocess
3
  import sys
4
 
5
  try:
6
- import mmcv
7
- print("mmcv already installed")
8
  except ImportError:
9
- print("Installing mmcv with --no-build-isolation...")
10
- subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"])
11
- print("mmcv installed successfully")
12
 
13
- import torch
14
  import numpy as np
15
  import random
16
  import soundfile as sf
@@ -30,27 +30,32 @@ onset_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="onset_model.ckpt",
30
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
31
  print("Checkpoints downloaded.")
32
 
 
33
  def set_global_seed(seed):
34
- np.random.seed(seed % (2**32))
35
- random.seed(seed)
36
- torch.manual_seed(seed)
37
- torch.cuda.manual_seed(seed)
38
- torch.backends.cudnn.deterministic = True
39
-
40
- def strip_audio_from_video(video_path, output_path):
41
- """Strip any existing audio from a video file, outputting a silent video."""
42
- (
43
- ffmpeg
44
- .input(video_path)
45
- .output(output_path, vcodec="libx264", an=None)
46
- .run(overwrite_output=True, quiet=True)
47
- )
 
48
 
49
- @spaces.GPU(duration=300)
 
50
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
51
- set_global_seed(int(seed_val))
 
 
 
52
  torch.set_grad_enabled(False)
53
-
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
55
  weight_dtype = torch.bfloat16
56
 
@@ -65,13 +70,13 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
65
  state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
66
  new_state_dict = {}
67
  for key, value in state_dict.items():
68
- if "model.net.model" in key:
69
- new_key = key.replace("model.net.model", "net.model")
70
- elif "model.fc." in key:
71
- new_key = key.replace("model.fc", "fc")
72
- else:
73
- new_key = key
74
- new_state_dict[new_key] = value
75
 
76
  onset_model = VideoOnsetNet(False).to(device)
77
  onset_model.load_state_dict(new_state_dict)
@@ -111,26 +116,27 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
111
  z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
112
 
113
  sampling_kwargs = dict(
114
- model=model,
115
- latents=z,
116
- y=onset_feats_t,
117
- context=video_feats,
118
- num_steps=int(num_steps),
119
- heun=False,
120
- cfg_scale=float(cfg_scale),
121
- guidance_low=0.0,
122
- guidance_high=0.7,
123
- path_type="linear"
124
  )
125
 
126
  with torch.no_grad():
127
- if mode == "sde":
128
- samples = euler_maruyama_sampler(**sampling_kwargs)
129
- else:
130
- samples = euler_sampler(**sampling_kwargs)
131
 
132
- samples = vae.decode(samples / latents_scale).sample
133
- wav_samples = vocoder(samples.squeeze()).detach().cpu().numpy()
 
134
 
135
  audio_path = os.path.join(tmp_dir, "output.wav")
136
  sf.write(audio_path, wav_samples, sr)
@@ -140,40 +146,43 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
140
  # Trim the silent input video to the target duration (no audio)
141
  trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
142
  output_video = os.path.join(tmp_dir, "output.mp4")
143
-
144
  (
145
- ffmpeg
146
- .input(silent_video, ss=0, t=duration)
147
- .output(trimmed_video, vcodec="libx264", an=None)
148
- .run(overwrite_output=True, quiet=True)
149
  )
150
 
151
  # Combine the trimmed silent video with the generated audio
152
  input_v = ffmpeg.input(trimmed_video)
153
  input_a = ffmpeg.input(audio_path)
154
  (
155
- ffmpeg
156
- .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
157
- .run(overwrite_output=True, quiet=True)
158
  )
159
 
160
  return output_video, audio_path
161
 
162
 
163
- demo = gr.Interface(
164
- fn=generate_audio,
165
- inputs=[
166
- gr.Video(label="Input Video"),
167
- gr.Number(label="Seed", value=0, precision=0),
168
- gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
169
- gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
170
- gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde")
171
- ],
172
- outputs=[
173
- gr.Video(label="Output Video with Audio"),
174
- gr.Audio(label="Generated Audio")
175
- ],
176
- title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
177
- description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s."
178
- )
179
- demo.queue().launch()
 
 
 
 
 
3
  import sys
4
 
5
  try:
6
+ import mmcv
7
+ print("mmcv already installed")
8
  except ImportError:
9
+ print("Installing mmcv with --no-build-isolation...")
10
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"])
11
+ print("mmcv installed successfully")
12
 
13
+ import torch
14
  import numpy as np
15
  import random
16
  import soundfile as sf
 
30
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
31
  print("Checkpoints downloaded.")
32
 
33
+
34
def set_global_seed(seed):
    """Seed every RNG used by the pipeline so sampling is reproducible.

    Args:
        seed: Integer seed. NumPy only accepts values in [0, 2**32), so the
            seed is reduced modulo 2**32 for np.random; the other RNGs take
            arbitrary Python ints.
    """
    random.seed(seed)
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    # Safe to call without a GPU; it is a no-op when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    # Pin cuDNN to deterministic kernels so repeated runs match exactly.
    torch.backends.cudnn.deterministic = True
40
+
41
+
42
def strip_audio_from_video(video_path, output_path):
    """Strip any existing audio from a video file, outputting a silent video.

    Args:
        video_path: Path to the input video (may contain an audio track).
        output_path: Path where the silent H.264-encoded video is written.
    """
    (
        ffmpeg
        .input(video_path)
        # an=None maps to ffmpeg's -an flag: drop all audio streams.
        .output(output_path, vcodec="libx264", an=None)
        .run(overwrite_output=True, quiet=True)
    )
50
 
51
+
52
+ @spaces.GPU(duration=300)
53
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
54
+ seed_val = int(seed_val)
55
+ if seed_val < 0:
56
+ seed_val = random.randint(0, 2**32 - 1)
57
+ set_global_seed(seed_val)
58
  torch.set_grad_enabled(False)
 
59
  device = "cuda" if torch.cuda.is_available() else "cpu"
60
  weight_dtype = torch.bfloat16
61
 
 
70
  state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
71
  new_state_dict = {}
72
  for key, value in state_dict.items():
73
+ if "model.net.model" in key:
74
+ new_key = key.replace("model.net.model", "net.model")
75
+ elif "model.fc." in key:
76
+ new_key = key.replace("model.fc", "fc")
77
+ else:
78
+ new_key = key
79
+ new_state_dict[new_key] = value
80
 
81
  onset_model = VideoOnsetNet(False).to(device)
82
  onset_model.load_state_dict(new_state_dict)
 
116
  z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
117
 
118
  sampling_kwargs = dict(
119
+ model=model,
120
+ latents=z,
121
+ y=onset_feats_t,
122
+ context=video_feats,
123
+ num_steps=int(num_steps),
124
+ heun=False,
125
+ cfg_scale=float(cfg_scale),
126
+ guidance_low=0.0,
127
+ guidance_high=0.7,
128
+ path_type="linear"
129
  )
130
 
131
  with torch.no_grad():
132
+ if mode == "sde":
133
+ samples = euler_maruyama_sampler(**sampling_kwargs)
134
+ else:
135
+ samples = euler_sampler(**sampling_kwargs)
136
 
137
+ samples = vae.decode(samples / latents_scale).sample
138
+ # Cast to float32 before vocoder (HiFi-GAN requires float32)
139
+ wav_samples = vocoder(samples.squeeze().float()).detach().cpu().numpy()
140
 
141
  audio_path = os.path.join(tmp_dir, "output.wav")
142
  sf.write(audio_path, wav_samples, sr)
 
146
  # Trim the silent input video to the target duration (no audio)
147
  trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
148
  output_video = os.path.join(tmp_dir, "output.mp4")
 
149
  (
150
+ ffmpeg
151
+ .input(silent_video, ss=0, t=duration)
152
+ .output(trimmed_video, vcodec="libx264", an=None)
153
+ .run(overwrite_output=True, quiet=True)
154
  )
155
 
156
  # Combine the trimmed silent video with the generated audio
157
  input_v = ffmpeg.input(trimmed_video)
158
  input_a = ffmpeg.input(audio_path)
159
  (
160
+ ffmpeg
161
+ .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
162
+ .run(overwrite_output=True, quiet=True)
163
  )
164
 
165
  return output_video, audio_path
166
 
167
 
168
def get_random_seed():
    """Return a fresh random seed in the unsigned 32-bit range [0, 2**32 - 1]."""
    # randrange(2**32) is exactly equivalent to randint(0, 2**32 - 1).
    return random.randrange(2 ** 32)
170
+
171
+
172
# Build the Gradio UI: a video goes in, synthesized audio (plus the video
# remuxed with that audio) comes out.
ui_inputs = [
    gr.Video(label="Input Video"),
    # Passing the function itself (not a value) makes Gradio call it per
    # page load, so every visitor starts with a fresh random seed.
    gr.Number(label="Seed", value=get_random_seed, precision=0),
    gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
    gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
    gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
]
ui_outputs = [
    gr.Video(label="Output Video with Audio"),
    gr.Audio(label="Generated Audio"),
]
demo = gr.Interface(
    fn=generate_audio,
    inputs=ui_inputs,
    outputs=ui_outputs,
    title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
    description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s.",
)
demo.queue().launch()