JackIsNotInTheBox committed on
Commit
211f7b6
·
1 Parent(s): 47a6cd2

Strip original audio from uploaded videos before processing and final output

Browse files
Files changed (1) hide show
  1. app.py +117 -44
app.py CHANGED
@@ -3,14 +3,14 @@ import subprocess
3
  import sys
4
 
5
  try:
6
- import mmcv
7
- print("mmcv already installed")
8
  except ImportError:
9
- print("Installing mmcv with --no-build-isolation...")
10
- subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"])
11
- print("mmcv installed successfully")
12
 
13
- import torch
14
  import numpy as np
15
  import random
16
  import soundfile as sf
@@ -31,76 +31,149 @@ taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache
31
  print("Checkpoints downloaded.")
32
 
33
  def set_global_seed(seed):
34
- np.random.seed(seed % (2**32))
35
- random.seed(seed)
36
- torch.manual_seed(seed)
37
- torch.cuda.manual_seed(seed)
38
- torch.backends.cudnn.deterministic = True
39
 
40
- @spaces.GPU(duration=300)
 
 
 
 
 
 
 
 
 
41
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
42
- set_global_seed(int(seed_val))
43
- torch.set_grad_enabled(False)
44
- device = "cuda" if torch.cuda.is_available() else "cpu"
45
- weight_dtype = torch.bfloat16
46
- from cavp_util import Extract_CAVP_Features
47
- from onset_util import VideoOnsetNet, extract_onset
48
- from models import MMDiT
49
- from samplers import euler_sampler, euler_maruyama_sampler
50
- from diffusers import AudioLDM2Pipeline
51
- extract_cavp = Extract_CAVP_Features(device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path)
52
- state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
53
- new_state_dict = {}
54
- for key, value in state_dict.items():
55
- if "model.net.model" in key:
56
- new_key = key.replace("model.net.model", "net.model")
57
- elif "model.fc." in key:
58
- new_key = key.replace("model.fc", "fc")
59
- else:
60
- new_key = key
61
- new_state_dict[new_key] = value
62
- onset_model = VideoOnsetNet(False).to(device)
 
 
 
 
 
63
  onset_model.load_state_dict(new_state_dict)
64
  onset_model.eval()
 
65
  model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
66
  ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
67
  model.load_state_dict(ckpt)
68
  model.eval()
69
  model.to(weight_dtype)
 
70
  model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
71
  vae = model_audioldm.vae.to(device)
72
  vae.eval()
73
  vocoder = model_audioldm.vocoder.to(device)
 
74
  tmp_dir = tempfile.mkdtemp()
75
- cavp_feats = extract_cavp(video_file, tmp_path=tmp_dir)
76
- onset_feats = extract_onset(video_file, onset_model, tmp_path=tmp_dir, device=device)
 
 
 
 
 
 
77
  sr = 16000
78
  truncate = 131072
79
  fps = 4
80
  truncate_frame = int(fps * truncate / sr)
81
  truncate_onset = 120
 
82
  latents_scale = torch.tensor([0.18215]*8).view(1, 8, 1, 1).to(device)
 
83
  video_feats = torch.from_numpy(cavp_feats[:truncate_frame]).unsqueeze(0).to(device).to(weight_dtype)
84
  onset_feats_t = torch.from_numpy(onset_feats[:truncate_onset]).unsqueeze(0).to(device).to(weight_dtype)
 
85
  z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
86
- sampling_kwargs = dict(model=model, latents=z, y=onset_feats_t, context=video_feats, num_steps=int(num_steps), heun=False, cfg_scale=float(cfg_scale), guidance_low=0.0, guidance_high=0.7, path_type="linear")
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  with torch.no_grad():
88
- if mode == "sde":
89
- samples = euler_maruyama_sampler(**sampling_kwargs)
90
- else:
91
- samples = euler_sampler(**sampling_kwargs)
92
- samples = vae.decode(samples / latents_scale).sample
 
93
  wav_samples = vocoder(samples.squeeze()).detach().cpu().numpy()
 
94
  audio_path = os.path.join(tmp_dir, "output.wav")
95
  sf.write(audio_path, wav_samples, sr)
 
96
  duration = truncate / sr
 
 
97
  trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
98
  output_video = os.path.join(tmp_dir, "output.mp4")
99
- ffmpeg.input(video_file, ss=0, t=duration).output(trimmed_video, vcodec="libx264", an=None).run(overwrite_output=True, quiet=True)
 
 
 
 
 
 
 
 
100
  input_v = ffmpeg.input(trimmed_video)
101
  input_a = ffmpeg.input(audio_path)
102
- ffmpeg.output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental").run(overwrite_output=True, quiet=True)
 
 
 
 
 
103
  return output_video, audio_path
104
 
105
- demo = gr.Interface(fn=generate_audio, inputs=[gr.Video(label="Input Video"), gr.Number(label="Seed", value=0, precision=0), gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5), gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1), gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde")], outputs=[gr.Video(label="Output Video with Audio"), gr.Audio(label="Generated Audio")], title="TARO: Video-to-Audio Synthesis (ICCV 2025)", description="Upload a video and generate synchronized audio using TARO. Optimal duration is as close to 8 seconds as possible.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  demo.queue().launch()
 
import sys

# mmcv ships C extensions that can fail to build under pip's build isolation
# on Spaces, so install it lazily with --no-build-isolation if it is missing.
try:
    import mmcv
    print("mmcv already installed")
except ImportError:
    print("Installing mmcv with --no-build-isolation...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"])
    print("mmcv installed successfully")

import torch
import numpy as np
import random
import soundfile as sf
 
31
  print("Checkpoints downloaded.")
32
 
33
def set_global_seed(seed):
    """Seed every RNG source (NumPy, stdlib random, torch CPU/CUDA) for reproducibility.

    Args:
        seed: Integer seed. NumPy only accepts seeds in [0, 2**32), so the
            value is reduced mod 2**32 for it; torch/random take it as-is.
    """
    np.random.seed(seed % (2**32))
    random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not just the current one (safe no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernels so repeated runs match.
    torch.backends.cudnn.deterministic = True
39
 
40
def strip_audio_from_video(video_path, output_path):
    """Strip any existing audio from a video file, outputting a silent video.

    Args:
        video_path: Path to the source video (may contain an audio track).
        output_path: Destination for the re-encoded, audio-free video.
    """
    # NOTE: original docstring had stray trailing quotes ("""...""""") that
    # parsed as an adjacent empty string literal — removed.
    (
        ffmpeg
        .input(video_path)
        # an=None emits ffmpeg's `-an` flag, dropping all audio streams.
        .output(output_path, vcodec="libx264", an=None)
        .run(overwrite_output=True, quiet=True)
    )
48
+
49
@spaces.GPU(duration=300)
def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
    """Generate audio synchronized to the input video and mux it back in.

    Args:
        video_file: Path to the uploaded video.
        seed_val: RNG seed (cast to int; see set_global_seed).
        cfg_scale: Classifier-free-guidance scale passed to the sampler.
        num_steps: Number of sampler integration steps.
        mode: "sde" selects the Euler-Maruyama sampler; any other value
            selects the deterministic Euler sampler.

    Returns:
        Tuple of (output_video_path, generated_audio_wav_path), both inside
        a fresh temporary directory (kept alive so Gradio can serve them).
    """
    set_global_seed(int(seed_val))
    torch.set_grad_enabled(False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.bfloat16

    # Deferred imports: these pull in heavy CUDA/mmcv dependencies, so they
    # run inside the @spaces.GPU context rather than at module import time.
    from cavp_util import Extract_CAVP_Features
    from onset_util import VideoOnsetNet, extract_onset
    from models import MMDiT
    from samplers import euler_sampler, euler_maruyama_sampler
    from diffusers import AudioLDM2Pipeline

    extract_cavp = Extract_CAVP_Features(device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path)

    # Remap Lightning-style checkpoint keys ("model.net.model.*", "model.fc.*")
    # to the bare module names VideoOnsetNet expects.
    state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    new_state_dict = {}
    for key, value in state_dict.items():
        if "model.net.model" in key:
            new_key = key.replace("model.net.model", "net.model")
        elif "model.fc." in key:
            new_key = key.replace("model.fc", "fc")
        else:
            new_key = key
        new_state_dict[new_key] = value

    onset_model = VideoOnsetNet(False).to(device)
    onset_model.load_state_dict(new_state_dict)
    onset_model.eval()

    model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
    model.load_state_dict(ckpt)
    model.eval()
    model.to(weight_dtype)

    # AudioLDM2 supplies the latent VAE decoder and the vocoder.
    model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
    vae = model_audioldm.vae.to(device)
    vae.eval()
    vocoder = model_audioldm.vocoder.to(device)

    tmp_dir = tempfile.mkdtemp()

    # Strip any existing audio from the input video before feature extraction
    silent_video = os.path.join(tmp_dir, "silent_input.mp4")
    strip_audio_from_video(video_file, silent_video)

    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
    onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)

    sr = 16000            # output audio sample rate
    truncate = 131072     # audio samples kept (~8.2 s at 16 kHz)
    fps = 4               # assumed CAVP feature rate — TODO confirm vs cavp_util
    truncate_frame = int(fps * truncate / sr)
    truncate_onset = 120

    # AudioLDM-style VAE latent scaling factor, one per latent channel.
    latents_scale = torch.tensor([0.18215]*8).view(1, 8, 1, 1).to(device)

    video_feats = torch.from_numpy(cavp_feats[:truncate_frame]).unsqueeze(0).to(device).to(weight_dtype)
    onset_feats_t = torch.from_numpy(onset_feats[:truncate_onset]).unsqueeze(0).to(device).to(weight_dtype)

    # Initial noise latents; 204x16 is the model's fixed latent spatial size.
    z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)

    sampling_kwargs = dict(
        model=model,
        latents=z,
        y=onset_feats_t,
        context=video_feats,
        num_steps=int(num_steps),
        heun=False,
        cfg_scale=float(cfg_scale),
        guidance_low=0.0,
        guidance_high=0.7,
        path_type="linear"
    )

    with torch.no_grad():
        if mode == "sde":
            samples = euler_maruyama_sampler(**sampling_kwargs)
        else:
            samples = euler_sampler(**sampling_kwargs)

        samples = vae.decode(samples / latents_scale).sample
    wav_samples = vocoder(samples.squeeze()).detach().cpu().numpy()

    audio_path = os.path.join(tmp_dir, "output.wav")
    sf.write(audio_path, wav_samples, sr)

    duration = truncate / sr

    # Trim the silent input video to the target duration (no audio)
    trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
    output_video = os.path.join(tmp_dir, "output.mp4")

    (
        ffmpeg
        .input(silent_video, ss=0, t=duration)
        .output(trimmed_video, vcodec="libx264", an=None)
        .run(overwrite_output=True, quiet=True)
    )

    # Combine the trimmed silent video with the generated audio
    input_v = ffmpeg.input(trimmed_video)
    input_a = ffmpeg.input(audio_path)
    (
        ffmpeg
        .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
        .run(overwrite_output=True, quiet=True)
    )

    return output_video, audio_path
161
 
162
+
163
# Gradio UI: single-function interface mapping the upload + sampler controls
# onto generate_audio; queue() serializes requests for the GPU worker.
demo = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Number(label="Seed", value=0, precision=0),
        gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
        gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
        gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde")
    ],
    outputs=[
        gr.Video(label="Output Video with Audio"),
        gr.Audio(label="Generated Audio")
    ],
    title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
    description="Upload a video and generate synchronized audio using TARO. Optimal duration is as close to 8 seconds as possible."
)
demo.queue().launch()