JackIsNotInTheBox committed on
Commit
39355fb
·
1 Parent(s): 9c25971

Fix all indentation errors - consistent 4-space indentation throughout

Browse files
Files changed (1) hide show
  1. app.py +72 -75
app.py CHANGED
@@ -3,12 +3,12 @@ import subprocess
3
  import sys
4
 
5
  try:
6
- import mmcv
7
- print("mmcv already installed")
8
  except ImportError:
9
- print("Installing mmcv with --no-build-isolation...")
10
- subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"])
11
- print("mmcv installed successfully")
12
 
13
  import torch
14
  import numpy as np
@@ -32,29 +32,29 @@ print("Checkpoints downloaded.")
32
 
33
 
34
  def set_global_seed(seed):
35
- np.random.seed(seed % (2**32))
36
- random.seed(seed)
37
- torch.manual_seed(seed)
38
- torch.cuda.manual_seed(seed)
39
- torch.backends.cudnn.deterministic = True
40
 
41
 
42
- def strip_audio_from_video(video_path, output_path):
43
- """Strip any existing audio from a video file, outputting a silent video."""""
44
- (
45
- ffmpeg
46
- .input(video_path)
47
- .output(output_path, vcodec="libx264", an=None)
48
- .run(overwrite_output=True, quiet=True)
49
- )
50
 
51
 
52
- @spaces.GPU(duration=300)
53
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
54
- seed_val = int(seed_val)
55
- if seed_val < 0:
56
- seed_val = random.randint(0, 2**32 - 1)
57
- set_global_seed(seed_val)
58
  torch.set_grad_enabled(False)
59
  device = "cuda" if torch.cuda.is_available() else "cpu"
60
  weight_dtype = torch.bfloat16
@@ -70,13 +70,13 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
70
  state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
71
  new_state_dict = {}
72
  for key, value in state_dict.items():
73
- if "model.net.model" in key:
74
- new_key = key.replace("model.net.model", "net.model")
75
- elif "model.fc." in key:
76
- new_key = key.replace("model.fc", "fc")
77
- else:
78
- new_key = key
79
- new_state_dict[new_key] = value
80
 
81
  onset_model = VideoOnsetNet(False).to(device)
82
  onset_model.load_state_dict(new_state_dict)
@@ -95,7 +95,6 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
95
 
96
  tmp_dir = tempfile.mkdtemp()
97
 
98
- # Strip any existing audio from the input video before feature extraction
99
  silent_video = os.path.join(tmp_dir, "silent_input.mp4")
100
  strip_audio_from_video(video_file, silent_video)
101
 
@@ -108,7 +107,7 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
108
  truncate_frame = int(fps * truncate / sr)
109
  truncate_onset = 120
110
 
111
- latents_scale = torch.tensor([0.18215]*8).view(1, 8, 1, 1).to(device)
112
 
113
  video_feats = torch.from_numpy(cavp_feats[:truncate_frame]).unsqueeze(0).to(device).to(weight_dtype)
114
  onset_feats_t = torch.from_numpy(onset_feats[:truncate_onset]).unsqueeze(0).to(device).to(weight_dtype)
@@ -116,25 +115,25 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
116
  z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
117
 
118
  sampling_kwargs = dict(
119
- model=model,
120
- latents=z,
121
- y=onset_feats_t,
122
- context=video_feats,
123
- num_steps=int(num_steps),
124
- heun=False,
125
- cfg_scale=float(cfg_scale),
126
- guidance_low=0.0,
127
- guidance_high=0.7,
128
- path_type="linear"
129
  )
130
 
131
  with torch.no_grad():
132
- if mode == "sde":
133
- samples = euler_maruyama_sampler(**sampling_kwargs)
134
- else:
135
- samples = euler_sampler(**sampling_kwargs)
136
 
137
- samples = vae.decode(samples / latents_scale).sample
138
  # Cast to float32 before vocoder (HiFi-GAN requires float32)
139
  wav_samples = vocoder(samples.squeeze().float()).detach().cpu().numpy()
140
 
@@ -143,46 +142,44 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
143
 
144
  duration = truncate / sr
145
 
146
- # Trim the silent input video to the target duration (no audio)
147
  trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
148
  output_video = os.path.join(tmp_dir, "output.mp4")
149
  (
150
- ffmpeg
151
- .input(silent_video, ss=0, t=duration)
152
- .output(trimmed_video, vcodec="libx264", an=None)
153
- .run(overwrite_output=True, quiet=True)
154
  )
155
 
156
- # Combine the trimmed silent video with the generated audio
157
  input_v = ffmpeg.input(trimmed_video)
158
  input_a = ffmpeg.input(audio_path)
159
  (
160
- ffmpeg
161
- .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
162
- .run(overwrite_output=True, quiet=True)
163
  )
164
 
165
  return output_video, audio_path
166
 
167
 
168
  def get_random_seed():
169
- return random.randint(0, 2**32 - 1)
170
-
171
-
172
- demo = gr.Interface(
173
- fn=generate_audio,
174
- inputs=[
175
- gr.Video(label="Input Video"),
176
- gr.Number(label="Seed", value=get_random_seed, precision=0),
177
- gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
178
- gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
179
- gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde")
180
- ],
181
- outputs=[
182
- gr.Video(label="Output Video with Audio"),
183
- gr.Audio(label="Generated Audio")
184
- ],
185
- title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
186
- description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s."
187
- )
188
  demo.queue().launch()
 
3
  import sys
4
 
5
  try:
6
+ import mmcv
7
+ print("mmcv already installed")
8
  except ImportError:
9
+ print("Installing mmcv with --no-build-isolation...")
10
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "mmcv>=2.0.0"])
11
+ print("mmcv installed successfully")
12
 
13
  import torch
14
  import numpy as np
 
32
 
33
 
34
  def set_global_seed(seed):
35
+ np.random.seed(seed % (2**32))
36
+ random.seed(seed)
37
+ torch.manual_seed(seed)
38
+ torch.cuda.manual_seed(seed)
39
+ torch.backends.cudnn.deterministic = True
40
 
41
 
42
+ def strip_audio_from_video(video_path, output_path):
43
+ """Strip any existing audio from a video file, outputting a silent video."""
44
+ (
45
+ ffmpeg
46
+ .input(video_path)
47
+ .output(output_path, vcodec="libx264", an=None)
48
+ .run(overwrite_output=True, quiet=True)
49
+ )
50
 
51
 
52
+ @spaces.GPU(duration=300)
53
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
54
+ seed_val = int(seed_val)
55
+ if seed_val < 0:
56
+ seed_val = random.randint(0, 2**32 - 1)
57
+ set_global_seed(seed_val)
58
  torch.set_grad_enabled(False)
59
  device = "cuda" if torch.cuda.is_available() else "cpu"
60
  weight_dtype = torch.bfloat16
 
70
  state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
71
  new_state_dict = {}
72
  for key, value in state_dict.items():
73
+ if "model.net.model" in key:
74
+ new_key = key.replace("model.net.model", "net.model")
75
+ elif "model.fc." in key:
76
+ new_key = key.replace("model.fc", "fc")
77
+ else:
78
+ new_key = key
79
+ new_state_dict[new_key] = value
80
 
81
  onset_model = VideoOnsetNet(False).to(device)
82
  onset_model.load_state_dict(new_state_dict)
 
95
 
96
  tmp_dir = tempfile.mkdtemp()
97
 
 
98
  silent_video = os.path.join(tmp_dir, "silent_input.mp4")
99
  strip_audio_from_video(video_file, silent_video)
100
 
 
107
  truncate_frame = int(fps * truncate / sr)
108
  truncate_onset = 120
109
 
110
+ latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
111
 
112
  video_feats = torch.from_numpy(cavp_feats[:truncate_frame]).unsqueeze(0).to(device).to(weight_dtype)
113
  onset_feats_t = torch.from_numpy(onset_feats[:truncate_onset]).unsqueeze(0).to(device).to(weight_dtype)
 
115
  z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
116
 
117
  sampling_kwargs = dict(
118
+ model=model,
119
+ latents=z,
120
+ y=onset_feats_t,
121
+ context=video_feats,
122
+ num_steps=int(num_steps),
123
+ heun=False,
124
+ cfg_scale=float(cfg_scale),
125
+ guidance_low=0.0,
126
+ guidance_high=0.7,
127
+ path_type="linear",
128
  )
129
 
130
  with torch.no_grad():
131
+ if mode == "sde":
132
+ samples = euler_maruyama_sampler(**sampling_kwargs)
133
+ else:
134
+ samples = euler_sampler(**sampling_kwargs)
135
 
136
+ samples = vae.decode(samples / latents_scale).sample
137
  # Cast to float32 before vocoder (HiFi-GAN requires float32)
138
  wav_samples = vocoder(samples.squeeze().float()).detach().cpu().numpy()
139
 
 
142
 
143
  duration = truncate / sr
144
 
 
145
  trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
146
  output_video = os.path.join(tmp_dir, "output.mp4")
147
  (
148
+ ffmpeg
149
+ .input(silent_video, ss=0, t=duration)
150
+ .output(trimmed_video, vcodec="libx264", an=None)
151
+ .run(overwrite_output=True, quiet=True)
152
  )
153
 
 
154
  input_v = ffmpeg.input(trimmed_video)
155
  input_a = ffmpeg.input(audio_path)
156
  (
157
+ ffmpeg
158
+ .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
159
+ .run(overwrite_output=True, quiet=True)
160
  )
161
 
162
  return output_video, audio_path
163
 
164
 
165
  def get_random_seed():
166
+ return random.randint(0, 2**32 - 1)
167
+
168
+
169
+ demo = gr.Interface(
170
+ fn=generate_audio,
171
+ inputs=[
172
+ gr.Video(label="Input Video"),
173
+ gr.Number(label="Seed", value=get_random_seed, precision=0),
174
+ gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
175
+ gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
176
+ gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
177
+ ],
178
+ outputs=[
179
+ gr.Video(label="Output Video with Audio"),
180
+ gr.Audio(label="Generated Audio"),
181
+ ],
182
+ title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
183
+ description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s.",
184
+ )
185
  demo.queue().launch()