JackIsNotInTheBox committed on
Commit
d0e121d
·
1 Parent(s): 5ce02be

Fix all indentation errors - consistent 4-space indentation throughout

Browse files
Files changed (1) hide show
  1. app.py +89 -82
app.py CHANGED
@@ -3,14 +3,14 @@ import subprocess
3
  import sys
4
 
5
  try:
6
- import mmcv
7
- print("mmcv already installed")
8
  except ImportError:
9
- print("Installing mmcv with --no-build-isolation...")
10
- subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"])
11
- print("mmcv installed successfully")
12
 
13
- import torch
14
  import numpy as np
15
  import random
16
  import soundfile as sf
@@ -30,28 +30,31 @@ onset_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="onset_model.ckpt",
30
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
31
  print("Checkpoints downloaded.")
32
 
 
33
  def set_global_seed(seed):
34
- np.random.seed(seed % (2**32))
35
- random.seed(seed)
36
- torch.manual_seed(seed)
37
- torch.cuda.manual_seed(seed)
38
- torch.backends.cudnn.deterministic = True
39
-
40
- def strip_audio_from_video(video_path, output_path):
41
- """Strip any existing audio from a video file, outputting a silent video."""""
42
- (
43
- ffmpeg
44
- .input(video_path)
45
- .output(output_path, vcodec="libx264", an=None)
46
- .run(overwrite_output=True, quiet=True)
47
- )
48
-
49
- @spaces.GPU(duration=300)
 
 
50
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
51
- seed_val = int(seed_val)
52
- if seed_val < 0:
53
- seed_val = random.randint(0, 2**32 - 1)
54
- set_global_seed(seed_val)
55
  torch.set_grad_enabled(False)
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
57
  weight_dtype = torch.bfloat16
@@ -62,18 +65,20 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
62
  from samplers import euler_sampler, euler_maruyama_sampler
63
  from diffusers import AudioLDM2Pipeline
64
 
65
- extract_cavp = Extract_CAVP_Features(device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path)
 
 
66
 
67
  state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
68
  new_state_dict = {}
69
  for key, value in state_dict.items():
70
- if "model.net.model" in key:
71
- new_key = key.replace("model.net.model", "net.model")
72
- elif "model.fc." in key:
73
- new_key = key.replace("model.fc", "fc")
74
- else:
75
- new_key = key
76
- new_state_dict[new_key] = value
77
  onset_model = VideoOnsetNet(False).to(device)
78
  onset_model.load_state_dict(new_state_dict)
79
  onset_model.eval()
@@ -110,37 +115,37 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
110
  onset_feats_sliced = onset_feats[:truncate_onset]
111
  actual_onset_len = onset_feats_sliced.shape[0]
112
  if actual_onset_len < truncate_onset:
113
- pad_len = truncate_onset - actual_onset_len
114
- onset_feats_sliced = np.pad(
115
- onset_feats_sliced,
116
- ((0, pad_len),),
117
- mode="constant",
118
- constant_values=0,
119
- )
120
- onset_feats_t = torch.from_numpy(onset_feats_sliced).unsqueeze(0).to(device).to(weight_dtype)
121
 
122
  z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
123
 
124
  sampling_kwargs = dict(
125
- model=model,
126
- latents=z,
127
- y=onset_feats_t,
128
- context=video_feats,
129
- num_steps=int(num_steps),
130
- heun=False,
131
- cfg_scale=float(cfg_scale),
132
- guidance_low=0.0,
133
- guidance_high=0.7,
134
- path_type="linear",
135
  )
136
 
137
  with torch.no_grad():
138
- if mode == "sde":
139
- samples = euler_maruyama_sampler(**sampling_kwargs)
140
- else:
141
- samples = euler_sampler(**sampling_kwargs)
142
 
143
- samples = vae.decode(samples / latents_scale).sample
144
 
145
  # Cast to float32 before vocoder (HiFi-GAN requires float32)
146
  wav_samples = vocoder(samples.squeeze().float()).detach().cpu().numpy()
@@ -153,39 +158,41 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
153
  output_video = os.path.join(tmp_dir, "output.mp4")
154
 
155
  (
156
- ffmpeg
157
- .input(silent_video, ss=0, t=duration)
158
- .output(trimmed_video, vcodec="libx264", an=None)
159
- .run(overwrite_output=True, quiet=True)
160
  )
161
 
162
  input_v = ffmpeg.input(trimmed_video)
163
  input_a = ffmpeg.input(audio_path)
164
  (
165
- ffmpeg
166
- .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
167
- .run(overwrite_output=True, quiet=True)
168
  )
169
 
170
  return output_video, audio_path
171
 
 
172
  def get_random_seed():
173
- return random.randint(0, 2**32 - 1)
174
-
175
- demo = gr.Interface(
176
- fn=generate_audio,
177
- inputs=[
178
- gr.Video(label="Input Video"),
179
- gr.Number(label="Seed", value=get_random_seed, precision=0),
180
- gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
181
- gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
182
- gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
183
- ],
184
- outputs=[
185
- gr.Video(label="Output Video with Audio"),
186
- gr.Audio(label="Generated Audio"),
187
- ],
188
- title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
189
- description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s.",
190
- )
 
191
  demo.queue().launch()
 
3
  import sys
4
 
5
# Make sure mmcv is importable before the heavy imports below; the Space's
# base image may not ship it prebuilt, so fall back to a runtime pip install.
try:
    import mmcv  # noqa: F401  # only checking availability
    print("mmcv already installed")
except ImportError:
    print("Installing mmcv with --no-build-isolation...")
    # --no-build-isolation lets mmcv's setup see the already-installed torch.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"])
    print("mmcv installed successfully")
12
 
13
+ import torch
14
  import numpy as np
15
  import random
16
  import soundfile as sf
 
30
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
31
  print("Checkpoints downloaded.")
32
 
33
+
34
def set_global_seed(seed):
    """Seed every RNG the pipeline touches (numpy, random, torch CPU/CUDA).

    Also pins cuDNN to deterministic kernels so repeated runs with the same
    seed produce the same samples.
    """
    # numpy only accepts 32-bit seeds; reduce modulo 2**32 first.
    np_seed = seed % (2 ** 32)
    np.random.seed(np_seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # no-op queued lazily when CUDA is absent
    torch.backends.cudnn.deterministic = True
40
+
41
+
42
def strip_audio_from_video(video_path, output_path):
    """Strip any existing audio from a video file, outputting a silent video."""
    # an=None maps to ffmpeg's -an flag (drop all audio streams);
    # the video is re-encoded with libx264.
    stream = ffmpeg.input(video_path)
    stream = stream.output(output_path, vcodec="libx264", an=None)
    stream.run(overwrite_output=True, quiet=True)
50
+
51
+
52
+ @spaces.GPU(duration=300)
53
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
54
+ seed_val = int(seed_val)
55
+ if seed_val < 0:
56
+ seed_val = random.randint(0, 2**32 - 1)
57
+ set_global_seed(seed_val)
58
  torch.set_grad_enabled(False)
59
  device = "cuda" if torch.cuda.is_available() else "cpu"
60
  weight_dtype = torch.bfloat16
 
65
  from samplers import euler_sampler, euler_maruyama_sampler
66
  from diffusers import AudioLDM2Pipeline
67
 
68
+ extract_cavp = Extract_CAVP_Features(
69
+ device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path
70
+ )
71
 
72
  state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
73
  new_state_dict = {}
74
  for key, value in state_dict.items():
75
+ if "model.net.model" in key:
76
+ new_key = key.replace("model.net.model", "net.model")
77
+ elif "model.fc." in key:
78
+ new_key = key.replace("model.fc", "fc")
79
+ else:
80
+ new_key = key
81
+ new_state_dict[new_key] = value
82
  onset_model = VideoOnsetNet(False).to(device)
83
  onset_model.load_state_dict(new_state_dict)
84
  onset_model.eval()
 
115
  onset_feats_sliced = onset_feats[:truncate_onset]
116
  actual_onset_len = onset_feats_sliced.shape[0]
117
  if actual_onset_len < truncate_onset:
118
+ pad_len = truncate_onset - actual_onset_len
119
+ onset_feats_sliced = np.pad(
120
+ onset_feats_sliced,
121
+ ((0, pad_len),),
122
+ mode="constant",
123
+ constant_values=0,
124
+ )
125
+ onset_feats_t = torch.from_numpy(onset_feats_sliced).unsqueeze(0).to(device).to(weight_dtype)
126
 
127
  z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
128
 
129
  sampling_kwargs = dict(
130
+ model=model,
131
+ latents=z,
132
+ y=onset_feats_t,
133
+ context=video_feats,
134
+ num_steps=int(num_steps),
135
+ heun=False,
136
+ cfg_scale=float(cfg_scale),
137
+ guidance_low=0.0,
138
+ guidance_high=0.7,
139
+ path_type="linear",
140
  )
141
 
142
  with torch.no_grad():
143
+ if mode == "sde":
144
+ samples = euler_maruyama_sampler(**sampling_kwargs)
145
+ else:
146
+ samples = euler_sampler(**sampling_kwargs)
147
 
148
+ samples = vae.decode(samples / latents_scale).sample
149
 
150
  # Cast to float32 before vocoder (HiFi-GAN requires float32)
151
  wav_samples = vocoder(samples.squeeze().float()).detach().cpu().numpy()
 
158
  output_video = os.path.join(tmp_dir, "output.mp4")
159
 
160
  (
161
+ ffmpeg
162
+ .input(silent_video, ss=0, t=duration)
163
+ .output(trimmed_video, vcodec="libx264", an=None)
164
+ .run(overwrite_output=True, quiet=True)
165
  )
166
 
167
  input_v = ffmpeg.input(trimmed_video)
168
  input_a = ffmpeg.input(audio_path)
169
  (
170
+ ffmpeg
171
+ .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
172
+ .run(overwrite_output=True, quiet=True)
173
  )
174
 
175
  return output_video, audio_path
176
 
177
+
178
def get_random_seed():
    """Return a uniformly random seed in [0, 2**32 - 1]."""
    max_seed = 2 ** 32 - 1
    return random.randint(0, max_seed)
180
+
181
+
182
# Gradio UI: a video goes in; the synchronized audio track and the muxed
# video come out. Component construction is hoisted into named variables
# for readability.
seed_input = gr.Number(label="Seed", value=get_random_seed, precision=0)
cfg_input = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5)
steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1)
mode_input = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde")

demo = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Video(label="Input Video"),
        seed_input,
        cfg_input,
        steps_input,
        mode_input,
    ],
    outputs=[
        gr.Video(label="Output Video with Audio"),
        gr.Audio(label="Generated Audio"),
    ],
    title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
    description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s.",
)
demo.queue().launch()