LTTEAM committed on
Commit
8164907
·
verified ·
1 Parent(s): 19b3d93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -279
app.py CHANGED
@@ -1,279 +1,184 @@
1
- import logging
2
- from datetime import datetime
3
- from pathlib import Path
4
-
5
- import gradio as gr
6
- import torch
7
- import torchaudio
8
- import os
9
-
10
- try:
11
- import mmaudio
12
- except ImportError:
13
- os.system("pip install -e .")
14
- import mmaudio
15
-
16
- from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
17
- setup_eval_logging)
18
- from mmaudio.model.flow_matching import FlowMatching
19
- from mmaudio.model.networks import MMAudio, get_my_mmaudio
20
- from mmaudio.model.sequence_config import SequenceConfig
21
- from mmaudio.model.utils.features_utils import FeaturesUtils
22
- import tempfile
23
-
24
- torch.backends.cuda.matmul.allow_tf32 = True
25
- torch.backends.cudnn.allow_tf32 = True
26
-
27
- log = logging.getLogger()
28
-
29
- device = 'cuda'
30
- dtype = torch.bfloat16
31
-
32
- model: ModelConfig = all_model_cfg['large_44k_v2']
33
- model.download_if_needed()
34
- output_dir = Path('./output/gradio')
35
-
36
- setup_eval_logging()
37
-
38
-
39
- def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
40
- seq_cfg = model.seq_cfg
41
-
42
- net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
43
- net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
44
- log.info(f'Loaded weights from {model.model_path}')
45
-
46
- feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
47
- synchformer_ckpt=model.synchformer_ckpt,
48
- enable_conditions=True,
49
- mode=model.mode,
50
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
51
- need_vae_encoder=False)
52
- feature_utils = feature_utils.to(device, dtype).eval()
53
-
54
- return net, feature_utils, seq_cfg
55
-
56
-
57
- net, feature_utils, seq_cfg = get_model()
58
-
59
-
60
- @torch.inference_mode()
61
- def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
62
- cfg_strength: float, duration: float):
63
-
64
- rng = torch.Generator(device=device)
65
- if seed >= 0:
66
- rng.manual_seed(seed)
67
- else:
68
- rng.seed()
69
- fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
70
-
71
- video_info = load_video(video, duration)
72
- clip_frames = video_info.clip_frames
73
- sync_frames = video_info.sync_frames
74
- duration = video_info.duration_sec
75
- clip_frames = clip_frames.unsqueeze(0)
76
- sync_frames = sync_frames.unsqueeze(0)
77
- seq_cfg.duration = duration
78
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
79
-
80
- audios = generate(clip_frames,
81
- sync_frames, [prompt],
82
- negative_text=[negative_prompt],
83
- feature_utils=feature_utils,
84
- net=net,
85
- fm=fm,
86
- rng=rng,
87
- cfg_strength=cfg_strength)
88
- audio = audios.float().cpu()[0]
89
-
90
- # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
91
- video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
92
- # output_dir.mkdir(exist_ok=True, parents=True)
93
- # video_save_path = output_dir / f'{current_time_string}.mp4'
94
- make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
95
- log.info(f'Saved video to {video_save_path}')
96
- return video_save_path
97
-
98
-
99
- @torch.inference_mode()
100
- def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
101
- duration: float):
102
-
103
- rng = torch.Generator(device=device)
104
- if seed >= 0:
105
- rng.manual_seed(seed)
106
- else:
107
- rng.seed()
108
- fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
109
-
110
- clip_frames = sync_frames = None
111
- seq_cfg.duration = duration
112
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
113
-
114
- audios = generate(clip_frames,
115
- sync_frames, [prompt],
116
- negative_text=[negative_prompt],
117
- feature_utils=feature_utils,
118
- net=net,
119
- fm=fm,
120
- rng=rng,
121
- cfg_strength=cfg_strength)
122
- audio = audios.float().cpu()[0]
123
-
124
- audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
125
- torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
126
- log.info(f'Saved audio to {audio_save_path}')
127
- return audio_save_path
128
-
129
-
130
- video_to_audio_tab = gr.Interface(
131
- fn=video_to_audio,
132
- description="""
133
- Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
134
- Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
135
-
136
- Ho Kei Cheng, Masato Ishii, Akio Hayakawa, Takashi Shibuya, Alexander Schwing, Yuki Mitsufuji
137
-
138
- University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
139
-
140
- CVPR 2025
141
-
142
- NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
143
- Doing so does not improve results.
144
-
145
- The model has been trained on 8-second videos. Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine.
146
- """,
147
- inputs=[
148
- gr.Video(),
149
- gr.Text(label='Prompt'),
150
- gr.Text(label='Negative prompt', value='music'),
151
- gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
152
- gr.Number(label='Num steps', value=25, precision=0, minimum=1),
153
- gr.Number(label='Guidance Strength', value=4.5, minimum=1),
154
- gr.Number(label='Duration (sec)', value=8, minimum=1),
155
- ],
156
- outputs='playable_video',
157
- cache_examples=False,
158
- title='MMAudio — Video-to-Audio Synthesis',
159
- examples=[
160
- [
161
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
162
- 'waves, seagulls',
163
- '',
164
- 0,
165
- 25,
166
- 4.5,
167
- 10,
168
- ],
169
- [
170
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
171
- '',
172
- 'music',
173
- 0,
174
- 25,
175
- 4.5,
176
- 10,
177
- ],
178
- [
179
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
180
- 'bubbles',
181
- '',
182
- 0,
183
- 25,
184
- 4.5,
185
- 10,
186
- ],
187
- [
188
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
189
- 'Indian holy music',
190
- '',
191
- 0,
192
- 25,
193
- 4.5,
194
- 10,
195
- ],
196
- [
197
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
198
- 'galloping',
199
- '',
200
- 0,
201
- 25,
202
- 4.5,
203
- 10,
204
- ],
205
- [
206
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
207
- 'waves, storm',
208
- '',
209
- 0,
210
- 25,
211
- 4.5,
212
- 10,
213
- ],
214
- [
215
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
216
- '',
217
- '',
218
- 0,
219
- 25,
220
- 4.5,
221
- 10,
222
- ],
223
- [
224
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
225
- 'storm',
226
- '',
227
- 0,
228
- 25,
229
- 4.5,
230
- 10,
231
- ],
232
- [
233
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
234
- '',
235
- '',
236
- 0,
237
- 25,
238
- 4.5,
239
- 10,
240
- ],
241
- [
242
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
243
- 'typing',
244
- '',
245
- 0,
246
- 25,
247
- 4.5,
248
- 10,
249
- ],
250
- [
251
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
252
- '',
253
- '',
254
- 0,
255
- 25,
256
- 4.5,
257
- 10,
258
- ],
259
- ])
260
-
261
- text_to_audio_tab = gr.Interface(
262
- fn=text_to_audio,
263
- inputs=[
264
- gr.Text(label='Prompt'),
265
- gr.Text(label='Negative prompt'),
266
- gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
267
- gr.Number(label='Num steps', value=25, precision=0, minimum=1),
268
- gr.Number(label='Guidance Strength', value=4.5, minimum=1),
269
- gr.Number(label='Duration (sec)', value=8, minimum=1),
270
- ],
271
- outputs='audio',
272
- cache_examples=False,
273
- title='MMAudio — Text-to-Audio Synthesis',
274
- )
275
-
276
- if __name__ == "__main__":
277
- gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab],
278
- ['Video-to-Audio', 'Text-to-Audio']).launch(allowed_paths=[output_dir])
279
-
 
1
+ import logging
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+ import sys
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import torchaudio
9
+ import os
10
+
11
# Detect whether we are running inside Google Colab (used to pick launch options).
IN_COLAB = "google.colab" in sys.modules

# Automatically select the compute device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# bfloat16 on GPU, float32 on CPU (bfloat16 inference is slow/unsupported on CPU).
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

try:
    import mmaudio
except ImportError:
    # Best-effort editable install of the local package, then retry the import.
    # NOTE(review): os.system("pip ...") may target a different interpreter than
    # sys.executable, and its exit status is ignored — consider
    # subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."]).
    os.system("pip install -e .")
    import mmaudio

from mmaudio.eval_utils import (
    ModelConfig, all_model_cfg, generate, load_video, make_video,
    setup_eval_logging
)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
import tempfile

# Enable TF32 matmuls/convolutions for faster inference on Ampere+ GPUs
# (these flags are harmless no-ops on CPU).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

# Model configuration: the large 44.1 kHz v2 checkpoint; weights are downloaded
# on first run if missing.
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')
setup_eval_logging()
46
+
47
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and feature extractors on the global device.

    Returns:
        The network (weights loaded, eval mode), its feature utilities, and
        the model's sequence configuration.
    """
    sequence_config = model.seq_cfg

    # Instantiate the network, move it to the target device/dtype, then load
    # the checkpoint weights.
    network: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    state = torch.load(model.model_path, map_location=device, weights_only=True)
    network.load_weights(state)
    log.info(f'Loaded weights from {model.model_path}')

    # Feature extractors (VAE decoder, Synchformer, vocoder); the VAE encoder
    # is not needed for generation.
    utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False,
    )
    utils = utils.to(device, dtype).eval()

    return network, utils, sequence_config
67
+
68
# Load the model once at import time so every request reuses the same instances.
net, feature_utils, seq_cfg = get_model()
69
+
70
@torch.inference_mode()
def video_to_audio(
    video: gr.Video, prompt: str, negative_prompt: str, seed: int,
    num_steps: int, cfg_strength: float, duration: float
):
    """Generate audio conditioned on a video plus a text prompt.

    Returns the path to an .mp4 file with the generated soundtrack muxed in.
    """
    rng = torch.Generator(device=device)
    # A negative seed means "random": draw fresh entropy instead of seeding.
    if seed < 0:
        rng.seed()
    else:
        rng.manual_seed(seed)

    flow = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    info = load_video(video, duration)
    # Add a batch dimension of 1 for the network.
    clips = info.clip_frames.unsqueeze(0)
    syncs = info.sync_frames.unsqueeze(0)
    # The loaded clip may be shorter than requested, so trust the loader's
    # reported duration when sizing the sequence lengths.
    seq_cfg.duration = info.duration_sec
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len,
                           seq_cfg.sync_seq_len)

    waveforms = generate(clips, syncs, [prompt],
                         negative_text=[negative_prompt],
                         feature_utils=feature_utils,
                         net=net, fm=flow, rng=rng,
                         cfg_strength=cfg_strength)
    audio = waveforms.float().cpu()[0]

    # NOTE(review): delete=False temp files are never cleaned up, so they
    # accumulate over the app's lifetime — consider a periodic sweep.
    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
    make_video(info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
    log.info(f'Saved video to {video_save_path}')
    return video_save_path
105
+
106
@torch.inference_mode()
def text_to_audio(
    prompt: str, negative_prompt: str, seed: int,
    num_steps: int, cfg_strength: float, duration: float
):
    """Generate audio from a text prompt alone (no video conditioning).

    Returns the path to a .flac file containing the generated audio.
    """
    rng = torch.Generator(device=device)
    # A negative seed means "random": draw fresh entropy instead of seeding.
    if seed < 0:
        rng.seed()
    else:
        rng.manual_seed(seed)

    flow = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    # Text-only mode: size the sequence lengths from the requested duration.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len,
                           seq_cfg.sync_seq_len)

    # No clip/sync frames — both conditioning inputs are None.
    waveforms = generate(None, None, [prompt],
                         negative_text=[negative_prompt],
                         feature_utils=feature_utils,
                         net=net, fm=flow, rng=rng,
                         cfg_strength=cfg_strength)
    audio = waveforms.float().cpu()[0]

    # NOTE(review): delete=False temp files are never cleaned up, so they
    # accumulate over the app's lifetime — consider a periodic sweep.
    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    log.info(f'Saved audio to {audio_save_path}')
    return audio_save_path
137
+
138
# Gradio UI configuration.
# NOTE: the order of `inputs` must match the positional parameters of `fn`
# (video, prompt, negative_prompt, seed, num_steps, cfg_strength, duration).
video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    description="""
    Project page: <a href="https://hkchengrex.com/MMAudio/">MMAudio</a><br>
    Code: <a href="https://github.com/hkchengrex/MMAudio">GitHub/MMAudio</a><br>
    CVPR 2025 — HK Cheng, M. Ishii, A. Hayakawa, T. Shibuya, A. Schwing, Y. Mitsufuji.
    """,
    inputs=[
        gr.Video(),
        gr.Text(label='Prompt'),
        gr.Text(label='Negative prompt', value='music'),
        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
        gr.Number(label='Duration (sec)', value=8, minimum=1),
    ],
    outputs='playable_video',
    cache_examples=False,
    title='MMAudio — Video-to-Audio Synthesis',
)

# Text-to-audio tab: same controls minus the video input.
text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    inputs=[
        gr.Text(label='Prompt'),
        gr.Text(label='Negative prompt'),
        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
        gr.Number(label='Duration (sec)', value=8, minimum=1),
    ],
    outputs='audio',
    cache_examples=False,
    title='MMAudio — Text-to-Audio Synthesis',
)
174
+
175
+ if __name__ == "__main__":
176
+ gr.TabbedInterface(
177
+ [video_to_audio_tab, text_to_audio_tab],
178
+ ['Video-to-Audio', 'Text-to-Audio']
179
+ ).launch(
180
+ server_name="0.0.0.0",
181
+ server_port=7860,
182
+ share=IN_COLAB, # Nếu chạy trên Colab thì share=True
183
+ inbrowser=not IN_COLAB # nếu không phải Colab thì mở trình duyệt local
184
+ )