NZUONG committed on
Commit
18b3589
·
verified ·
1 Parent(s): e3bedf7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +394 -394
app.py CHANGED
@@ -1,395 +1,395 @@
1
- import os
2
- import torch
3
- import torchaudio
4
- import gradio as gr
5
- import matplotlib.pyplot as plt
6
- from tqdm import tqdm
7
- from transformers import UMT5EncoderModel, AutoTokenizer
8
- from huggingface_hub import hf_hub_download, snapshot_download
9
- import json
10
- import numpy as np
11
- import tempfile
12
- from io import BytesIO
13
- import warnings
14
- warnings.filterwarnings("ignore")
15
-
16
- # Import model components
17
- from model.ae.music_dcae import MusicDCAE
18
- from model.ldm.editing_unet import EditingUNet
19
- from model.ldm.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
20
-
21
- # Configuration
22
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
- DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
24
-
25
- # Model repository - UPDATE THIS TO YOUR MODEL REPO
26
- MODEL_REPO = "NZUONG/mude" # Your uploaded model repository
27
-
28
- # DDPM Parameters
29
- DDPM_NUM_TIMESTEPS = 1000
30
- DDPM_BETA_START = 0.0001
31
- DDPM_BETA_END = 0.02
32
-
33
- class AttrDict(dict):
34
- def __init__(self, *args, **kwargs):
35
- super(AttrDict, self).__init__(*args, **kwargs)
36
- self.__dict__ = self
37
-
38
- def download_models():
39
- """Download models from Hugging Face Hub"""
40
- print("πŸ”„ Downloading models from Hugging Face Hub...")
41
-
42
- # Create local directories
43
- os.makedirs("checkpoints", exist_ok=True)
44
-
45
- try:
46
- # Download the entire repository
47
- local_dir = snapshot_download(
48
- repo_id=MODEL_REPO,
49
- cache_dir="./cache",
50
- local_dir="./checkpoints",
51
- repo_type="model"
52
- )
53
- print(f"βœ… Models downloaded to: {local_dir}")
54
- return True
55
- except Exception as e:
56
- print(f"❌ Error downloading models: {e}")
57
- return False
58
-
59
- class AudioEditor:
60
- def __init__(self):
61
- self.dcae = None
62
- self.tokenizer = None
63
- self.text_encoder = None
64
- self.model = None
65
- self.is_loaded = False
66
-
67
- def load_models(self):
68
- """Load all models once at startup"""
69
- if self.is_loaded:
70
- return True
71
-
72
- # Download models if not present
73
- if not os.path.exists("checkpoints/music_dcae_f8c8"):
74
- print("πŸ“₯ Models not found locally, downloading...")
75
- if not download_models():
76
- return False
77
-
78
- print("πŸ”„ Loading models...")
79
-
80
- try:
81
- # Model paths
82
- dcae_path = "checkpoints/music_dcae_f8c8"
83
- vocoder_path = "checkpoints/music_vocoder"
84
- t5_path = "checkpoints/umt5-base"
85
- unet_config_path = "model/ldm/exp_config.json"
86
- trained_model_path = "checkpoints/fm_checkpoint_epoch_9.pt"
87
-
88
- # Load DCAE
89
- self.dcae = MusicDCAE(
90
- dcae_checkpoint_path=dcae_path,
91
- vocoder_checkpoint_path=vocoder_path
92
- ).to(DEVICE).eval()
93
-
94
- # Load text encoder
95
- self.tokenizer = AutoTokenizer.from_pretrained(t5_path)
96
- self.text_encoder = UMT5EncoderModel.from_pretrained(t5_path).to(DEVICE, dtype=DTYPE).eval()
97
-
98
- # Load UNet config
99
- with open(unet_config_path, 'r') as f:
100
- unet_config = AttrDict(json.load(f)['model']['unet'])
101
-
102
- self.model = EditingUNet(unet_config, use_flow_matching=False).to("cpu", dtype=DTYPE).eval()
103
-
104
- # Load checkpoint
105
- checkpoint = torch.load(trained_model_path, map_location="cpu")
106
- model_state_dict = checkpoint.get('model_state_dict', checkpoint)
107
- if any(key.startswith('_orig_mod.') for key in model_state_dict.keys()):
108
- model_state_dict = {key.replace('_orig_mod.', ''): value for key, value in model_state_dict.items()}
109
- self.model.load_state_dict(model_state_dict, strict=False)
110
-
111
- self.is_loaded = True
112
- print("βœ… All models loaded successfully!")
113
- return True
114
-
115
- except Exception as e:
116
- print(f"❌ Error loading models: {e}")
117
- return False
118
-
119
- def dpm_solver_sampling(self, model, source_latent, instruction_embedding, uncond_embedding,
120
- strength=1.0, steps=25, guidance_scale=7.5, seed=42):
121
- """DPM-Solver sampling function"""
122
- print(f"πŸš€ Starting DPM-Solver++ sampling with {steps} steps...")
123
-
124
- # Setup noise schedule
125
- betas = torch.linspace(DDPM_BETA_START, DDPM_BETA_END, DDMP_NUM_TIMESTEPS, dtype=torch.float32)
126
- alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
127
- noise_schedule = NoiseScheduleVP(schedule='discrete', alphas_cumprod=alphas_cumprod)
128
-
129
- # Setup model wrapper
130
- model_fn = model_wrapper(
131
- model,
132
- noise_schedule,
133
- model_type="noise", # DDPM objective only
134
- model_kwargs={
135
- "source_latent": source_latent,
136
- },
137
- guidance_type="classifier-free",
138
- condition=instruction_embedding,
139
- unconditional_condition=uncond_embedding,
140
- guidance_scale=guidance_scale,
141
- )
142
-
143
- # Initialize DPM-Solver++
144
- dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
145
-
146
- # Calculate time range
147
- t_end = noise_schedule.T / noise_schedule.total_N
148
- t_start = t_end + strength * (noise_schedule.T - t_end)
149
-
150
- # Add initial noise
151
- torch.manual_seed(seed)
152
- noise = torch.randn_like(source_latent)
153
- latents = dpm_solver.add_noise(source_latent, torch.tensor([t_start], device=DEVICE), noise)
154
- latents = latents.to(DTYPE)
155
-
156
- # Run DPM solver sampling
157
- with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
158
- with torch.no_grad():
159
- final_latent, _ = dpm_solver.sample(
160
- latents,
161
- steps=steps,
162
- t_start=t_start,
163
- t_end=t_end,
164
- order=2,
165
- method="multistep",
166
- skip_type="time_uniform",
167
- lower_order_final=True,
168
- return_intermediate=True,
169
- )
170
-
171
- return final_latent
172
-
173
- def process_audio(self, audio_file, instruction, guidance_scale, steps, strength, seed):
174
- """Main audio processing function"""
175
- try:
176
- if not self.load_models():
177
- return None, None, "❌ Failed to load models. Please try again."
178
-
179
- # Load and preprocess audio
180
- print(f"🎡 Processing audio: {audio_file}")
181
- audio, sr = torchaudio.load(audio_file)
182
- TARGET_SR_DCAE = 44100
183
- TARGET_LEN_DCAE = TARGET_SR_DCAE * 10
184
-
185
- if sr != TARGET_SR_DCAE:
186
- audio = torchaudio.transforms.Resample(sr, TARGET_SR_DCAE)(audio)
187
-
188
- if audio.shape[1] > TARGET_LEN_DCAE:
189
- audio = audio[:, :TARGET_LEN_DCAE]
190
- elif audio.shape[1] < TARGET_LEN_DCAE:
191
- audio = torch.nn.functional.pad(audio, (0, TARGET_LEN_DCAE - audio.shape[1]))
192
-
193
- if audio.shape[0] == 1:
194
- audio = audio.repeat(2, 1)
195
-
196
- # Encode audio
197
- with torch.no_grad():
198
- source_latent_scaled, _ = self.dcae.encode(audio.to(DEVICE).unsqueeze(0))
199
-
200
- # Prepare text embeddings
201
- with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
202
- text_input = self.tokenizer([instruction], max_length=32, padding="max_length",
203
- truncation=True, return_tensors="pt")
204
- instruction_embedding = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
205
-
206
- uncond_input = self.tokenizer([""], max_length=32, padding="max_length",
207
- truncation=True, return_tensors="pt")
208
- uncond_embedding = self.text_encoder(uncond_input.input_ids.to(DEVICE))[0]
209
-
210
- # Move models for inference
211
- self.dcae = self.dcae.cpu()
212
- torch.cuda.empty_cache()
213
- self.model = self.model.to(DEVICE, dtype=DTYPE)
214
-
215
- # Generate
216
- print("🎨 Generating edited audio...")
217
- with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
218
- with torch.no_grad():
219
- final_latent = self.dpm_solver_sampling(
220
- model=self.model,
221
- source_latent=source_latent_scaled,
222
- instruction_embedding=instruction_embedding,
223
- uncond_embedding=uncond_embedding,
224
- strength=strength,
225
- steps=int(steps),
226
- guidance_scale=guidance_scale,
227
- seed=int(seed)
228
- )
229
-
230
- # Decode results
231
- self.model = self.model.cpu()
232
- torch.cuda.empty_cache()
233
- self.dcae = self.dcae.to(DEVICE)
234
-
235
- final_latent_unscaled = (final_latent.float() / self.dcae.scale_factor) + self.dcae.shift_factor
236
- source_latent_raw = (source_latent_scaled / self.dcae.scale_factor) + self.dcae.shift_factor
237
-
238
- with torch.no_grad():
239
- source_mel = self.dcae.decode_to_mel(source_latent_raw)
240
- edited_mel = self.dcae.decode_to_mel(final_latent_unscaled)
241
- _, pred_wavs = self.dcae.decode(latents=final_latent.float(), sr=44100)
242
- edited_audio = pred_wavs[0]
243
-
244
- # Create comparison plot
245
- comparison_plot = self.create_mel_comparison(source_mel, edited_mel, instruction)
246
-
247
- # Save output audio
248
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
249
- torchaudio.save(tmp_file.name, edited_audio.cpu().float(), 44100)
250
- output_path = tmp_file.name
251
-
252
- # Cleanup
253
- self.dcae = self.dcae.cpu()
254
- torch.cuda.empty_cache()
255
-
256
- return output_path, comparison_plot, f"βœ… Audio editing completed! Instruction: '{instruction}'"
257
-
258
- except Exception as e:
259
- import traceback
260
- error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
261
- print(error_msg)
262
- return None, None, error_msg
263
-
264
- def create_mel_comparison(self, source_mel, edited_mel, instruction):
265
- """Create mel-spectrogram comparison plot"""
266
- try:
267
- source_mel_np = source_mel.squeeze(0)[0].cpu().float().numpy()
268
- edited_mel_np = edited_mel.squeeze(0)[0].cpu().float().numpy()
269
-
270
- fig, axs = plt.subplots(2, 1, figsize=(12, 8), sharex=True, sharey=True)
271
- fig.suptitle(f'Mel-Spectrogram Comparison', fontsize=14)
272
-
273
- # Plot source
274
- im1 = axs[0].imshow(source_mel_np, aspect='auto', origin='lower', cmap='viridis')
275
- axs[0].set_title('Original Audio')
276
- axs[0].set_ylabel('Mel Bins')
277
- plt.colorbar(im1, ax=axs[0])
278
-
279
- # Plot edited
280
- im2 = axs[1].imshow(edited_mel_np, aspect='auto', origin='lower', cmap='viridis')
281
- axs[1].set_title(f'Edited Audio: "{instruction}"')
282
- axs[1].set_ylabel('Mel Bins')
283
- axs[1].set_xlabel('Time Frames')
284
- plt.colorbar(im2, ax=axs[1])
285
-
286
- plt.tight_layout()
287
-
288
- # Save to temporary file for Gradio
289
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
290
- plt.savefig(tmp_file.name, dpi=100, bbox_inches='tight')
291
- plt.close()
292
- return tmp_file.name
293
-
294
- except Exception as e:
295
- print(f"Error creating plot: {e}")
296
- plt.close()
297
- return None
298
-
299
- # Initialize the audio editor
300
- audio_editor = AudioEditor()
301
-
302
- def gradio_interface(audio_file, instruction, guidance_scale, steps, strength, seed):
303
- """Gradio interface function"""
304
- if audio_file is None:
305
- return None, None, "Please upload an audio file"
306
-
307
- if not instruction.strip():
308
- return None, None, "Please provide an editing instruction"
309
-
310
- return audio_editor.process_audio(audio_file, instruction, guidance_scale, steps, strength, seed)
311
-
312
- # Create Gradio interface
313
- with gr.Blocks(title="🎡 AI Audio Editor", theme=gr.themes.Soft()) as demo:
314
- gr.HTML("""
315
- <div style="text-align: center; margin-bottom: 20px;">
316
- <h1>🎡 AI Audio Editor</h1>
317
- <p>Upload an audio file and provide instructions to edit it using AI.<br/>
318
- The model uses DPM-Solver++ for fast, high-quality generation.</p>
319
- </div>
320
- """)
321
-
322
- with gr.Row():
323
- with gr.Column(scale=1):
324
- # Input components
325
- audio_input = gr.Audio(
326
- label="πŸ“ Upload Audio File",
327
- type="filepath"
328
- )
329
-
330
- instruction_input = gr.Textbox(
331
- label="✏️ Editing Instruction",
332
- placeholder="e.g., 'Add drums', 'Make it more energetic', 'Remove vocals'",
333
- lines=2
334
- )
335
-
336
- with gr.Accordion("πŸ”§ Advanced Settings", open=False):
337
- guidance_scale = gr.Slider(
338
- minimum=1.0, maximum=20.0, value=7.5, step=0.5,
339
- label="Guidance Scale",
340
- info="Higher values follow the instruction more closely"
341
- )
342
-
343
- steps = gr.Slider(
344
- minimum=10, maximum=50, value=25, step=5,
345
- label="Sampling Steps",
346
- info="More steps = better quality, slower generation"
347
- )
348
-
349
- strength = gr.Slider(
350
- minimum=0.1, maximum=1.0, value=1.0, step=0.1,
351
- label="Denoising Strength",
352
- info="1.0 = full denoising, lower = more conservative editing"
353
- )
354
-
355
- seed = gr.Number(
356
- value=42, label="Seed",
357
- info="For reproducible results"
358
- )
359
-
360
- generate_btn = gr.Button("🎨 Generate Edited Audio", variant="primary", size="lg")
361
-
362
- with gr.Column(scale=1):
363
- # Output components
364
- status_output = gr.Textbox(label="πŸ“Š Status", interactive=False)
365
- audio_output = gr.Audio(label="🎡 Generated Audio")
366
- plot_output = gr.Image(label="πŸ“ˆ Mel-Spectrogram Comparison")
367
-
368
- gr.HTML("""
369
- <div style="margin-top: 20px; padding: 20px; background-color: #f0f0f0; border-radius: 10px;">
370
- <h3>πŸ“ Usage Tips:</h3>
371
- <ul>
372
- <li><b>Audio Length:</b> Files are automatically processed to 10 seconds</li>
373
- <li><b>Instructions:</b> Be specific (e.g., "Add heavy drums" vs "Add drums")</li>
374
- <li><b>Guidance Scale:</b> Start with 7.5, increase for stronger effects</li>
375
- <li><b>Steps:</b> 25 steps provide good quality/speed balance</li>
376
- </ul>
377
- </div>
378
- """)
379
-
380
- # Connect the interface
381
- generate_btn.click(
382
- fn=gradio_interface,
383
- inputs=[audio_input, instruction_input, guidance_scale, steps, strength, seed],
384
- outputs=[audio_output, plot_output, status_output],
385
- show_progress=True
386
- )
387
-
388
- # Launch settings
389
- if __name__ == "__main__":
390
- demo.launch(
391
- server_name="0.0.0.0",
392
- server_port=7860,
393
- share=False,
394
- show_error=True
395
  )
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ from tqdm import tqdm
7
+ from transformers import UMT5EncoderModel, AutoTokenizer
8
+ from huggingface_hub import hf_hub_download, snapshot_download
9
+ import json
10
+ import numpy as np
11
+ import tempfile
12
+ from io import BytesIO
13
+ import warnings
14
+ warnings.filterwarnings("ignore")
15
+
16
+ # Import model components
17
+ from model.ae.music_dcae import MusicDCAE
18
+ from model.ldm.editing_unet import EditingUNet
19
+ from model.ldm.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
20
+
21
# Runtime configuration.
# Prefer the GPU when torch can see one; bfloat16 is only selected when the
# GPU reports support for it, otherwise everything runs in float32.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

# Hugging Face Hub repository holding the checkpoints - UPDATE THIS TO YOUR MODEL REPO
MODEL_REPO = "NZUONG/mude"  # Your uploaded model repository

# DDPM noise-schedule parameters (linear beta schedule)
DDPM_NUM_TIMESTEPS = 1000
DDPM_BETA_START = 0.0001
DDPM_BETA_END = 0.02
32
+
33
class AttrDict(dict):
    """A ``dict`` whose items are also reachable as attributes (``d.key``)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the instance namespace at the mapping itself so that
        # attribute access and item access share the same storage.
        self.__dict__ = self
37
+
38
def download_models():
    """Fetch the whole ``MODEL_REPO`` snapshot into ``./checkpoints``.

    Returns:
        ``True`` when the snapshot download finished, ``False`` when any
        exception was raised (the error is printed, not propagated).
    """
    print("πŸ”„ Downloading models from Hugging Face Hub...")

    try:
        # Make sure the target directory exists before the Hub client writes to it
        os.makedirs("checkpoints", exist_ok=True)

        # Pull every file of the repository in one call
        snapshot_path = snapshot_download(
            repo_id=MODEL_REPO,
            cache_dir="./cache",
            local_dir="./checkpoints",
            repo_type="model",
        )
    except Exception as err:
        print(f"❌ Error downloading models: {err}")
        return False
    else:
        print(f"βœ… Models downloaded to: {snapshot_path}")
        return True
58
+
59
class AudioEditor:
    """Instruction-driven audio editor built from a DCAE codec, a UMT5 text encoder and an EditingUNet.

    Models are loaded lazily by :meth:`load_models`. While a request is being
    processed the DCAE and the UNet are swapped between ``DEVICE`` and the CPU
    so that only one of them sits on the accelerator at a time; after every
    request (successful or not) both are parked on the CPU.
    """

    # Sample rate (Hz) and clip length (samples, i.e. 10 s) inputs are normalised to.
    TARGET_SR_DCAE = 44100
    TARGET_LEN_DCAE = TARGET_SR_DCAE * 10
    # Token length used for both the instruction and the empty (unconditional) prompt.
    MAX_TEXT_TOKENS = 32

    def __init__(self):
        self.dcae = None
        self.tokenizer = None
        self.text_encoder = None
        self.model = None
        self.is_loaded = False

    def load_models(self):
        """Load every model once; later calls are no-ops.

        Downloads the checkpoint repository first when the DCAE directory is
        missing locally.

        Returns:
            ``True`` when all models are ready, ``False`` when downloading or
            loading failed (the error is printed, not raised).
        """
        if self.is_loaded:
            return True

        # Model paths
        dcae_path = "checkpoints/music_dcae_f8c8"
        vocoder_path = "checkpoints/music_vocoder"
        t5_path = "checkpoints/umt5-base"
        unet_config_path = "model/ldm/exp_config.json"
        trained_model_path = "checkpoints/fm_checkpoint_epoch_9.pt"

        # Download models if not present
        if not os.path.exists(dcae_path):
            print("πŸ“₯ Models not found locally, downloading...")
            if not download_models():
                return False

        print("πŸ”„ Loading models...")

        try:
            # Load DCAE
            self.dcae = MusicDCAE(
                dcae_checkpoint_path=dcae_path,
                vocoder_checkpoint_path=vocoder_path
            ).to(DEVICE).eval()

            # Load text encoder
            self.tokenizer = AutoTokenizer.from_pretrained(t5_path)
            self.text_encoder = UMT5EncoderModel.from_pretrained(t5_path).to(DEVICE, dtype=DTYPE).eval()

            # Load UNet config
            with open(unet_config_path, 'r') as f:
                unet_config = AttrDict(json.load(f)['model']['unet'])

            # The UNet stays on the CPU until a request actually needs it
            self.model = EditingUNet(unet_config, use_flow_matching=False).to("cpu", dtype=DTYPE).eval()

            # Load checkpoint
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. The checkpoint comes from MODEL_REPO; only point
            # MODEL_REPO at a repository you trust (consider weights_only=True
            # if the checkpoint contains nothing but tensors).
            checkpoint = torch.load(trained_model_path, map_location="cpu")
            model_state_dict = checkpoint.get('model_state_dict', checkpoint)
            # Checkpoints saved from a torch.compile()d module carry an '_orig_mod.' marker in their keys
            if any(key.startswith('_orig_mod.') for key in model_state_dict.keys()):
                model_state_dict = {key.replace('_orig_mod.', ''): value for key, value in model_state_dict.items()}

            # strict=False tolerates key mismatches; report them so that a
            # partially initialised network does not go unnoticed.
            load_result = self.model.load_state_dict(model_state_dict, strict=False)
            missing_keys = list(getattr(load_result, "missing_keys", None) or [])
            unexpected_keys = list(getattr(load_result, "unexpected_keys", None) or [])
            if missing_keys or unexpected_keys:
                print(
                    f"Warning: checkpoint/model key mismatch - "
                    f"{len(missing_keys)} missing, {len(unexpected_keys)} unexpected"
                )

            self.is_loaded = True
            print("βœ… All models loaded successfully!")
            return True

        except Exception as e:
            print(f"❌ Error loading models: {e}")
            return False

    def dpm_solver_sampling(self, model, source_latent, instruction_embedding, uncond_embedding,
                            strength=1.0, steps=25, guidance_scale=7.5, seed=42):
        """Noise ``source_latent`` up to a start time and denoise it with DPM-Solver++.

        Args:
            model: Noise-prediction network handed to ``model_wrapper``; it also
                receives ``source_latent`` as a keyword argument.
            source_latent: Latent produced by ``self.dcae.encode`` for the input clip.
            instruction_embedding: Text-encoder output for the editing instruction.
            uncond_embedding: Text-encoder output for the empty prompt
                (classifier-free guidance).
            strength: Fraction of the schedule to noise up to; ``1.0`` starts
                from ``noise_schedule.T``.
            steps: Number of solver steps.
            guidance_scale: Classifier-free guidance scale.
            seed: Seed passed to ``torch.manual_seed`` before drawing the noise.
                Note that this reseeds torch's global generator.

        Returns:
            The first element returned by ``DPM_Solver.sample`` (the final latent).
        """
        print(f"πŸš€ Starting DPM-Solver++ sampling with {steps} steps...")

        # Setup noise schedule (linear betas -> cumulative alphas)
        betas = torch.linspace(DDPM_BETA_START, DDPM_BETA_END, DDPM_NUM_TIMESTEPS, dtype=torch.float32)
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        noise_schedule = NoiseScheduleVP(schedule='discrete', alphas_cumprod=alphas_cumprod)

        # Setup model wrapper
        model_fn = model_wrapper(
            model,
            noise_schedule,
            model_type="noise",  # DDPM objective only
            model_kwargs={
                "source_latent": source_latent,
            },
            guidance_type="classifier-free",
            condition=instruction_embedding,
            unconditional_condition=uncond_embedding,
            guidance_scale=guidance_scale,
        )

        # Initialize DPM-Solver++
        dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")

        # Calculate time range: t_end is one discrete step, t_start scales with strength
        t_end = noise_schedule.T / noise_schedule.total_N
        t_start = t_end + strength * (noise_schedule.T - t_end)

        # Add initial noise; the time tensor is created on the latent's own device
        torch.manual_seed(seed)
        noise = torch.randn_like(source_latent)
        start_time = torch.tensor([t_start], device=source_latent.device)
        latents = dpm_solver.add_noise(source_latent, start_time, noise)
        latents = latents.to(DTYPE)

        # Run DPM solver sampling
        with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
            with torch.no_grad():
                final_latent, _ = dpm_solver.sample(
                    latents,
                    steps=steps,
                    t_start=t_start,
                    t_end=t_end,
                    order=2,
                    method="multistep",
                    skip_type="time_uniform",
                    lower_order_final=True,
                    return_intermediate=True,
                )

        return final_latent

    @staticmethod
    def _prepare_waveform(audio, sr, target_sr=TARGET_SR_DCAE, target_len=TARGET_LEN_DCAE):
        """Resample, trim/zero-pad and channel-normalise a ``(channels, samples)`` waveform to stereo."""
        if sr != target_sr:
            audio = torchaudio.transforms.Resample(sr, target_sr)(audio)

        num_samples = audio.shape[1]
        if num_samples > target_len:
            audio = audio[:, :target_len]
        elif num_samples < target_len:
            audio = torch.nn.functional.pad(audio, (0, target_len - num_samples))

        if audio.shape[0] == 1:
            # Mono input: duplicate the channel to obtain two channels
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            # Multichannel input: keep the first two channels.
            # NOTE(review): assumes the DCAE expects exactly two channels - confirm.
            audio = audio[:2]
        return audio

    def _embed_text(self, text):
        """Tokenize ``text`` to a fixed length and return the text encoder's first output."""
        tokens = self.tokenizer([text], max_length=self.MAX_TEXT_TOKENS, padding="max_length",
                                truncation=True, return_tensors="pt")
        return self.text_encoder(tokens.input_ids.to(DEVICE))[0]

    def _offload_models(self):
        """Park the DCAE and the UNet on the CPU and release cached accelerator memory."""
        if self.dcae is not None:
            self.dcae = self.dcae.cpu()
        if self.model is not None:
            self.model = self.model.cpu()
        torch.cuda.empty_cache()

    def process_audio(self, audio_file, instruction, guidance_scale, steps, strength, seed):
        """Edit ``audio_file`` according to ``instruction``.

        Args:
            audio_file: Path of the audio file to edit.
            instruction: Free-text editing instruction.
            guidance_scale: Classifier-free guidance scale.
            steps: Number of solver steps (converted with ``int``).
            strength: Denoising strength forwarded to :meth:`dpm_solver_sampling`.
            seed: Random seed (converted with ``int``).

        Returns:
            ``(wav_path, plot_path, status_message)``. On failure the two paths
            are ``None`` and the message carries the error and its traceback.
        """
        try:
            if not self.load_models():
                return None, None, "❌ Failed to load models. Please try again."

            # Load and preprocess audio
            print(f"🎡 Processing audio: {audio_file}")
            audio, sr = torchaudio.load(audio_file)
            audio = self._prepare_waveform(audio, sr)

            # Encode audio. Every request ends with the DCAE parked on the CPU,
            # so it has to be brought back to DEVICE before it is used again.
            self.dcae = self.dcae.to(DEVICE)
            with torch.no_grad():
                source_latent_scaled, _ = self.dcae.encode(audio.to(DEVICE).unsqueeze(0))

            # Prepare text embeddings (instruction + empty prompt for guidance)
            with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
                instruction_embedding = self._embed_text(instruction)
                uncond_embedding = self._embed_text("")

            # Move models for inference: DCAE out, UNet in
            self.dcae = self.dcae.cpu()
            torch.cuda.empty_cache()
            self.model = self.model.to(DEVICE, dtype=DTYPE)

            # Generate
            print("🎨 Generating edited audio...")
            with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
                with torch.no_grad():
                    final_latent = self.dpm_solver_sampling(
                        model=self.model,
                        source_latent=source_latent_scaled,
                        instruction_embedding=instruction_embedding,
                        uncond_embedding=uncond_embedding,
                        strength=strength,
                        steps=int(steps),
                        guidance_scale=guidance_scale,
                        seed=int(seed)
                    )

            # Decode results: UNet out, DCAE back in
            self.model = self.model.cpu()
            torch.cuda.empty_cache()
            self.dcae = self.dcae.to(DEVICE)

            # Undo the latent scaling/shift before decoding to mel
            final_latent_unscaled = (final_latent.float() / self.dcae.scale_factor) + self.dcae.shift_factor
            source_latent_raw = (source_latent_scaled / self.dcae.scale_factor) + self.dcae.shift_factor

            with torch.no_grad():
                source_mel = self.dcae.decode_to_mel(source_latent_raw)
                edited_mel = self.dcae.decode_to_mel(final_latent_unscaled)
                # decode() is given the still-scaled latent, unlike decode_to_mel above
                _, pred_wavs = self.dcae.decode(latents=final_latent.float(), sr=self.TARGET_SR_DCAE)
                edited_audio = pred_wavs[0]

            # Create comparison plot
            comparison_plot = self.create_mel_comparison(source_mel, edited_mel, instruction)

            # Save output audio. Only reserve the file name inside the context
            # manager and write after the handle is closed, so the file is never
            # opened twice at the same time.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                output_path = tmp_file.name
            torchaudio.save(output_path, edited_audio.cpu().float(), self.TARGET_SR_DCAE)

            return output_path, comparison_plot, f"βœ… Audio editing completed! Instruction: '{instruction}'"

        except Exception as e:
            import traceback
            error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return None, None, error_msg

        finally:
            # Cleanup on every exit path so a failed request cannot leave the
            # models spread across devices for the next one.
            self._offload_models()

    def create_mel_comparison(self, source_mel, edited_mel, instruction):
        """Plot source and edited mel-spectrograms one above the other and save the figure as a PNG.

        Args:
            source_mel: Mel tensor decoded from the source latent.
            edited_mel: Mel tensor decoded from the edited latent.
            instruction: Editing instruction, shown in the lower panel's title.

        Returns:
            Path of the temporary PNG file, or ``None`` when plotting failed.
        """
        fig = None
        try:
            # Drop the leading dimension and plot the first entry of the next one
            source_mel_np = source_mel.squeeze(0)[0].cpu().float().numpy()
            edited_mel_np = edited_mel.squeeze(0)[0].cpu().float().numpy()

            fig, axs = plt.subplots(2, 1, figsize=(12, 8), sharex=True, sharey=True)
            fig.suptitle('Mel-Spectrogram Comparison', fontsize=14)

            # Plot source
            im1 = axs[0].imshow(source_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[0].set_title('Original Audio')
            axs[0].set_ylabel('Mel Bins')
            fig.colorbar(im1, ax=axs[0])

            # Plot edited
            im2 = axs[1].imshow(edited_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[1].set_title(f'Edited Audio: "{instruction}"')
            axs[1].set_ylabel('Mel Bins')
            axs[1].set_xlabel('Time Frames')
            fig.colorbar(im2, ax=axs[1])

            fig.tight_layout()

            # Save to temporary file for Gradio (write after the handle is closed)
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
                plot_path = tmp_file.name
            fig.savefig(plot_path, dpi=100, bbox_inches='tight')
            return plot_path

        except Exception as e:
            print(f"Error creating plot: {e}")
            return None

        finally:
            # Close this specific figure instead of whatever pyplot considers current
            if fig is not None:
                plt.close(fig)
298
+
299
# Single module-level editor used by the Gradio callback; constructing it only
# sets placeholders, the models themselves are loaded on the first request.
audio_editor = AudioEditor()
301
+
302
def gradio_interface(audio_file, instruction, guidance_scale, steps, strength, seed):
    """Validate the UI inputs and forward them to the shared ``audio_editor``.

    Args:
        audio_file: Path of the uploaded audio, or ``None`` when nothing was uploaded.
        instruction: Editing instruction text; ``None``, empty or
            whitespace-only values are rejected.
        guidance_scale: Classifier-free guidance scale.
        steps: Number of sampling steps.
        strength: Denoising strength.
        seed: Random seed; ``None`` (a cleared number field) falls back to 42,
            the default shown in the UI.

    Returns:
        ``(audio_path, plot_path, status_message)``; both paths are ``None``
        when validation fails.
    """
    if audio_file is None:
        return None, None, "Please upload an audio file"

    # Guard against None as well as blank text so .strip() can never raise
    if not instruction or not instruction.strip():
        return None, None, "Please provide an editing instruction"

    if seed is None:
        seed = 42

    return audio_editor.process_audio(audio_file, instruction, guidance_scale, steps, strength, seed)
311
+
312
# Build the Gradio UI: inputs on the left, results on the right
with gr.Blocks(title="🎡 AI Audio Editor", theme=gr.themes.Soft()) as demo:
    # Page header
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
    <h1>🎡 AI Audio Editor</h1>
    <p>Upload an audio file and provide instructions to edit it using AI.<br/>
    The model uses DPM-Solver++ for fast, high-quality generation.</p>
    </div>
    """)

    with gr.Row():
        # Left column: everything the user provides
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="πŸ“ Upload Audio File", type="filepath")

            instruction_input = gr.Textbox(
                label="✏️ Editing Instruction",
                placeholder="e.g., 'Add drums', 'Make it more energetic', 'Remove vocals'",
                lines=2,
            )

            # Sampling knobs, collapsed by default
            with gr.Accordion("πŸ”§ Advanced Settings", open=False):
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    info="Higher values follow the instruction more closely",
                    minimum=1.0,
                    maximum=20.0,
                    value=7.5,
                    step=0.5,
                )

                steps = gr.Slider(
                    label="Sampling Steps",
                    info="More steps = better quality, slower generation",
                    minimum=10,
                    maximum=50,
                    value=25,
                    step=5,
                )

                strength = gr.Slider(
                    label="Denoising Strength",
                    info="1.0 = full denoising, lower = more conservative editing",
                    minimum=0.1,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                )

                seed = gr.Number(
                    label="Seed",
                    info="For reproducible results",
                    value=42,
                )

            generate_btn = gr.Button("🎨 Generate Edited Audio", variant="primary", size="lg")

        # Right column: status line, edited audio and the spectrogram plot
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="πŸ“Š Status", interactive=False)
            audio_output = gr.Audio(label="🎡 Generated Audio")
            plot_output = gr.Image(label="πŸ“ˆ Mel-Spectrogram Comparison")

    # Static usage hints below the two columns
    gr.HTML("""
    <div style="margin-top: 20px; padding: 20px; background-color: #f0f0f0; border-radius: 10px;">
    <h3>πŸ“ Usage Tips:</h3>
    <ul>
    <li><b>Audio Length:</b> Files are automatically processed to 10 seconds</li>
    <li><b>Instructions:</b> Be specific (e.g., "Add heavy drums" vs "Add drums")</li>
    <li><b>Guidance Scale:</b> Start with 7.5, increase for stronger effects</li>
    <li><b>Steps:</b> 25 steps provide good quality/speed balance</li>
    </ul>
    </div>
    """)

    # Wire the button to the callback; the output order matches the callback's
    # (audio, plot, status) return tuple
    generate_btn.click(
        fn=gradio_interface,
        inputs=[audio_input, instruction_input, guidance_scale, steps, strength, seed],
        outputs=[audio_output, plot_output, status_output],
        show_progress=True,
    )
387
+
388
# Start the web server only when this file is executed as a script
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, without a public share link,
    # and surface callback errors in the UI.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)