abidlabs HF Staff commited on
Commit
6d9caf3
·
verified ·
1 Parent(s): 57b1197

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -186
app.py CHANGED
@@ -38,18 +38,32 @@ setup_eval_logging()
38
 
39
 
40
  def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
 
 
 
 
 
 
 
 
 
 
 
 
41
  seq_cfg = model.seq_cfg
42
 
43
  net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
44
  net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
45
  log.info(f'Loaded weights from {model.model_path}')
46
 
47
- feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
48
- synchformer_ckpt=model.synchformer_ckpt,
49
- enable_conditions=True,
50
- mode=model.mode,
51
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
52
- need_vae_encoder=False)
 
 
53
  feature_utils = feature_utils.to(device, dtype).eval()
54
 
55
  return net, feature_utils, seq_cfg
@@ -60,222 +74,134 @@ net, feature_utils, seq_cfg = get_model()
60
 
61
  @spaces.GPU(duration=120)
62
  @torch.inference_mode()
63
- def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
64
- cfg_strength: float, duration: float):
65
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  rng = torch.Generator(device=device)
67
  if seed >= 0:
68
  rng.manual_seed(seed)
69
  else:
70
  rng.seed()
 
71
  fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
72
 
73
  video_info = load_video(video, duration)
74
  clip_frames = video_info.clip_frames
75
  sync_frames = video_info.sync_frames
76
  duration = video_info.duration_sec
 
77
  clip_frames = clip_frames.unsqueeze(0)
78
  sync_frames = sync_frames.unsqueeze(0)
 
79
  seq_cfg.duration = duration
80
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
81
-
82
- audios = generate(clip_frames,
83
- sync_frames, [prompt],
84
- negative_text=[negative_prompt],
85
- feature_utils=feature_utils,
86
- net=net,
87
- fm=fm,
88
- rng=rng,
89
- cfg_strength=cfg_strength)
 
 
 
 
 
 
 
90
  audio = audios.float().cpu()[0]
91
 
92
- # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
93
  video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
94
- # output_dir.mkdir(exist_ok=True, parents=True)
95
- # video_save_path = output_dir / f'{current_time_string}.mp4'
96
  make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
97
  log.info(f'Saved video to {video_save_path}')
 
98
  return video_save_path
99
 
100
 
101
  @spaces.GPU(duration=120)
102
  @torch.inference_mode()
103
- def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
104
- duration: float):
105
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  rng = torch.Generator(device=device)
107
  if seed >= 0:
108
  rng.manual_seed(seed)
109
  else:
110
  rng.seed()
 
111
  fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
112
 
113
  clip_frames = sync_frames = None
114
  seq_cfg.duration = duration
115
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
116
-
117
- audios = generate(clip_frames,
118
- sync_frames, [prompt],
119
- negative_text=[negative_prompt],
120
- feature_utils=feature_utils,
121
- net=net,
122
- fm=fm,
123
- rng=rng,
124
- cfg_strength=cfg_strength)
 
 
 
 
 
 
 
125
  audio = audios.float().cpu()[0]
126
 
127
  audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
128
  torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
129
  log.info(f'Saved audio to {audio_save_path}')
130
- return audio_save_path
131
 
132
-
133
- video_to_audio_tab = gr.Interface(
134
- fn=video_to_audio,
135
- description="""
136
- Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
137
- Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
138
-
139
- Ho Kei Cheng, Masato Ishii, Akio Hayakawa, Takashi Shibuya, Alexander Schwing, Yuki Mitsufuji
140
-
141
- University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
142
-
143
- CVPR 2025
144
-
145
- NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
146
- Doing so does not improve results.
147
-
148
- The model has been trained on 8-second videos. Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine.
149
- """,
150
- inputs=[
151
- gr.Video(),
152
- gr.Text(label='Prompt'),
153
- gr.Text(label='Negative prompt', value='music'),
154
- gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
155
- gr.Number(label='Num steps', value=25, precision=0, minimum=1),
156
- gr.Number(label='Guidance Strength', value=4.5, minimum=1),
157
- gr.Number(label='Duration (sec)', value=8, minimum=1),
158
- ],
159
- outputs='playable_video',
160
- cache_examples=False,
161
- title='MMAudio — Video-to-Audio Synthesis',
162
- examples=[
163
- [
164
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
165
- 'waves, seagulls',
166
- '',
167
- 0,
168
- 25,
169
- 4.5,
170
- 10,
171
- ],
172
- [
173
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
174
- '',
175
- 'music',
176
- 0,
177
- 25,
178
- 4.5,
179
- 10,
180
- ],
181
- [
182
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
183
- 'bubbles',
184
- '',
185
- 0,
186
- 25,
187
- 4.5,
188
- 10,
189
- ],
190
- [
191
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
192
- 'Indian holy music',
193
- '',
194
- 0,
195
- 25,
196
- 4.5,
197
- 10,
198
- ],
199
- [
200
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
201
- 'galloping',
202
- '',
203
- 0,
204
- 25,
205
- 4.5,
206
- 10,
207
- ],
208
- [
209
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
210
- 'waves, storm',
211
- '',
212
- 0,
213
- 25,
214
- 4.5,
215
- 10,
216
- ],
217
- [
218
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
219
- '',
220
- '',
221
- 0,
222
- 25,
223
- 4.5,
224
- 10,
225
- ],
226
- [
227
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
228
- 'storm',
229
- '',
230
- 0,
231
- 25,
232
- 4.5,
233
- 10,
234
- ],
235
- [
236
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
237
- '',
238
- '',
239
- 0,
240
- 25,
241
- 4.5,
242
- 10,
243
- ],
244
- [
245
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
246
- 'typing',
247
- '',
248
- 0,
249
- 25,
250
- 4.5,
251
- 10,
252
- ],
253
- [
254
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
255
- '',
256
- '',
257
- 0,
258
- 25,
259
- 4.5,
260
- 10,
261
- ],
262
- ])
263
-
264
- text_to_audio_tab = gr.Interface(
265
- fn=text_to_audio,
266
- inputs=[
267
- gr.Text(label='Prompt'),
268
- gr.Text(label='Negative prompt'),
269
- gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
270
- gr.Number(label='Num steps', value=25, precision=0, minimum=1),
271
- gr.Number(label='Guidance Strength', value=4.5, minimum=1),
272
- gr.Number(label='Duration (sec)', value=8, minimum=1),
273
- ],
274
- outputs='audio',
275
- cache_examples=False,
276
- title='MMAudio — Text-to-Audio Synthesis',
277
- )
278
-
279
- if __name__ == "__main__":
280
- gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab],
281
- ['Video-to-Audio', 'Text-to-Audio']).launch(allowed_paths=[output_dir])
 
38
 
39
 
40
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """
    Build the MMAudio network and its feature-extraction utilities.

    Constructs the network named by the module-level ``model`` config,
    loads its pretrained weights onto ``device``/``dtype``, creates the
    :class:`FeaturesUtils` helper from the configured checkpoints, and
    puts both in eval mode.

    Returns:
        tuple:
            - net (MMAudio): the network with weights loaded, in eval mode.
            - feature_utils (FeaturesUtils): audio/video feature extractor on ``device``.
            - seq_cfg (SequenceConfig): sequence-length and duration configuration.
    """
    sequence_config = model.seq_cfg

    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    # weights_only=True restricts torch.load to tensor data (safe deserialization).
    state_dict = torch.load(model.model_path, map_location=device, weights_only=True)
    net.load_weights(state_dict)
    log.info(f'Loaded weights from {model.model_path}')

    utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False,
    ).to(device, dtype).eval()

    return net, utils, sequence_config
 
74
 
75
@spaces.GPU(duration=120)
@torch.inference_mode()
def video_to_audio(
    video: gr.Video,
    prompt: str,
    negative_prompt: str,
    seed: int,
    num_steps: int,
    cfg_strength: float,
    duration: float,
):
    """
    Synthesize audio for a video, conditioned on text prompts.

    Extracts CLIP and sync frame streams from the input video, runs the
    flow-matching sampler with the module-level MMAudio network, and muxes
    the generated audio back into an output video file.

    Args:
        video (gr.Video): Input video used for visual and temporal conditioning.
        prompt (str): Text prompt describing the desired audio content.
        negative_prompt (str): Text describing audio characteristics to avoid.
        seed (int): Random seed for reproducibility (-1 for random).
        num_steps (int): Number of flow-matching inference steps.
        cfg_strength (float): Classifier-free guidance strength.
        duration (float): Requested duration of the generated audio in seconds.

    Returns:
        str: File path to the generated video containing synthesized audio.
    """
    generator = torch.Generator(device=device)
    if seed >= 0:
        generator.manual_seed(seed)
    else:
        generator.seed()

    flow = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    video_info = load_video(video, duration)
    # Add a leading batch dimension to both frame streams.
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    # The loader may clamp the duration to the actual clip length.
    duration = video_info.duration_sec

    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len,
                           seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames,
                      [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=flow,
                      rng=generator,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
    make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
    log.info(f'Saved video to {video_save_path}')

    return video_save_path
145
 
146
 
147
@spaces.GPU(duration=120)
@torch.inference_mode()
def text_to_audio(
    prompt: str,
    negative_prompt: str,
    seed: int,
    num_steps: int,
    cfg_strength: float,
    duration: float,
):
    """
    Synthesize standalone audio from text prompts only.

    Runs the flow-matching sampler with the module-level MMAudio network
    without any video conditioning (both frame streams are None) and saves
    the result as a FLAC file.

    Args:
        prompt (str): Text prompt describing the desired audio content.
        negative_prompt (str): Text describing audio characteristics to avoid.
        seed (int): Random seed for reproducibility (-1 for random).
        num_steps (int): Number of flow-matching inference steps.
        cfg_strength (float): Classifier-free guidance strength.
        duration (float): Duration of the generated audio in seconds.

    Returns:
        str: File path to the generated audio file.
    """
    generator = torch.Generator(device=device)
    if seed >= 0:
        generator.manual_seed(seed)
    else:
        generator.seed()

    flow = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    # No visual conditioning in the text-only path.
    clip_frames = sync_frames = None
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len,
                           seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames,
                      [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=flow,
                      rng=generator,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    log.info(f'Saved audio to {audio_save_path}')

    return audio_save_path