"""Gradio web app for instruction-based audio editing.

Pipeline: audio -> DCAE latent -> instruction-conditioned diffusion UNet
(sampled with DPM-Solver++) -> DCAE decode -> edited audio + mel comparison.
Models are downloaded from the Hugging Face Hub on first use.
"""

import os
import json
import tempfile
import warnings
from io import BytesIO

import torch
import torchaudio
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from transformers import UMT5EncoderModel, AutoTokenizer
from huggingface_hub import hf_hub_download, snapshot_download

warnings.filterwarnings("ignore")

# Import model components (project-local)
from model.ae.music_dcae import MusicDCAE
from model.ldm.editing_unet import EditingUNet
from model.ldm.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver

# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 only when the GPU supports it; otherwise stay in fp32 (autocast is disabled then).
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32

# Model repository - UPDATE THIS TO YOUR MODEL REPO
MODEL_REPO = "NZUONG/mude"  # Your uploaded model repository

# DDPM Parameters (linear beta schedule used to build the discrete noise schedule)
DDPM_NUM_TIMESTEPS = 1000
DDPM_BETA_START = 0.0001
DDPM_BETA_END = 0.02


class AttrDict(dict):
    """Dict whose keys are also readable as attributes (config convenience)."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Aliasing __dict__ to self makes d.key equivalent to d["key"].
        self.__dict__ = self


def download_models():
    """Download models from Hugging Face Hub.

    Snapshots the whole ``MODEL_REPO`` into ``./checkpoints``.

    Returns:
        bool: True on success, False on any download error.
    """
    print("🔄 Downloading models from Hugging Face Hub...")

    # Create local directories
    os.makedirs("checkpoints", exist_ok=True)

    try:
        # Download the entire repository
        local_dir = snapshot_download(
            repo_id=MODEL_REPO,
            cache_dir="./cache",
            local_dir="./checkpoints",
            repo_type="model",
        )
        print(f"✅ Models downloaded to: {local_dir}")
        return True
    except Exception as e:
        print(f"❌ Error downloading models: {e}")
        return False


class AudioEditor:
    """Holds all models and runs the end-to-end editing pipeline.

    Models are loaded lazily on the first request and kept resident; the
    DCAE and the UNet are shuttled between CPU and GPU during a request to
    limit peak GPU memory.
    """

    def __init__(self):
        self.dcae = None          # MusicDCAE autoencoder + vocoder
        self.tokenizer = None     # UMT5 tokenizer
        self.text_encoder = None  # UMT5 encoder for instruction embeddings
        self.model = None         # EditingUNet (diffusion denoiser)
        self.is_loaded = False

    def load_models(self):
        """Load all models once at startup.

        Downloads the checkpoints if they are not present locally.

        Returns:
            bool: True when all models are ready, False on failure.
        """
        if self.is_loaded:
            return True

        # Download models if not present
        if not os.path.exists("checkpoints/music_dcae_f8c8"):
            print("📥 Models not found locally, downloading...")
            if not download_models():
                return False

        print("🔄 Loading models...")
        try:
            # Model paths
            dcae_path = "checkpoints/music_dcae_f8c8"
            vocoder_path = "checkpoints/music_vocoder"
            t5_path = "checkpoints/umt5-base"
            unet_config_path = "model/ldm/exp_config.json"
            trained_model_path = "checkpoints/fm_checkpoint_epoch_9.pt"

            # Load DCAE (autoencoder + vocoder) straight onto the device
            self.dcae = MusicDCAE(
                dcae_checkpoint_path=dcae_path,
                vocoder_checkpoint_path=vocoder_path,
            ).to(DEVICE).eval()

            # Load text encoder
            self.tokenizer = AutoTokenizer.from_pretrained(t5_path)
            self.text_encoder = UMT5EncoderModel.from_pretrained(t5_path).to(DEVICE, dtype=DTYPE).eval()

            # Load UNet config
            with open(unet_config_path, 'r') as f:
                unet_config = AttrDict(json.load(f)['model']['unet'])

            # UNet starts on CPU; it is moved to the device only during sampling.
            self.model = EditingUNet(unet_config, use_flow_matching=False).to("cpu", dtype=DTYPE).eval()

            # Load checkpoint (accepts either a raw state dict or a wrapped one)
            checkpoint = torch.load(trained_model_path, map_location="cpu")
            model_state_dict = checkpoint.get('model_state_dict', checkpoint)
            # Strip the torch.compile prefix if the checkpoint was saved compiled.
            if any(key.startswith('_orig_mod.') for key in model_state_dict.keys()):
                model_state_dict = {key.replace('_orig_mod.', ''): value
                                    for key, value in model_state_dict.items()}
            self.model.load_state_dict(model_state_dict, strict=False)

            self.is_loaded = True
            print("✅ All models loaded successfully!")
            return True
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            return False

    def dpm_solver_sampling(self, model, source_latent, instruction_embedding,
                            uncond_embedding, strength=1.0, steps=25,
                            guidance_scale=7.5, seed=42):
        """DPM-Solver++ sampling function.

        Noises the source latent up to a strength-dependent time, then
        denoises it with classifier-free guidance toward the instruction.

        Args:
            model: the EditingUNet denoiser.
            source_latent: scaled DCAE latent of the input audio.
            instruction_embedding: UMT5 embedding of the edit instruction.
            uncond_embedding: UMT5 embedding of the empty string (CFG null).
            strength: 1.0 = full re-noising; lower keeps more of the source.
            steps: number of solver steps.
            guidance_scale: classifier-free guidance weight.
            seed: RNG seed for the initial noise.

        Returns:
            The final denoised latent tensor.
        """
        print(f"🚀 Starting DPM-Solver++ sampling with {steps} steps...")

        # Setup noise schedule (linear-beta DDPM alphas_cumprod)
        betas = torch.linspace(DDPM_BETA_START, DDPM_BETA_END, DDPM_NUM_TIMESTEPS, dtype=torch.float32)
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        noise_schedule = NoiseScheduleVP(schedule='discrete', alphas_cumprod=alphas_cumprod)

        # Setup model wrapper for classifier-free guidance over the noise objective
        model_fn = model_wrapper(
            model,
            noise_schedule,
            model_type="noise",  # DDPM objective only
            model_kwargs={
                "source_latent": source_latent,
            },
            guidance_type="classifier-free",
            condition=instruction_embedding,
            unconditional_condition=uncond_embedding,
            guidance_scale=guidance_scale,
        )

        # Initialize DPM-Solver++
        dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")

        # Calculate time range: t_start interpolates between t_end and T by strength.
        t_end = noise_schedule.T / noise_schedule.total_N
        t_start = t_end + strength * (noise_schedule.T - t_end)

        # Add initial noise (seeded for reproducibility)
        torch.manual_seed(seed)
        noise = torch.randn_like(source_latent)
        latents = dpm_solver.add_noise(source_latent, torch.tensor([t_start], device=DEVICE), noise)
        latents = latents.to(DTYPE)

        # Run DPM solver sampling
        with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
            with torch.no_grad():
                final_latent, _ = dpm_solver.sample(
                    latents,
                    steps=steps,
                    t_start=t_start,
                    t_end=t_end,
                    order=2,
                    method="multistep",
                    skip_type="time_uniform",
                    lower_order_final=True,
                    return_intermediate=True,
                )

        return final_latent

    def process_audio(self, audio_file, instruction, guidance_scale, steps, strength, seed):
        """Main audio processing function.

        Loads/normalizes the audio, encodes it, samples an edited latent
        conditioned on ``instruction``, decodes, and renders a mel comparison.

        Returns:
            (output_audio_path, comparison_plot_path, status_message); the
            first two are None on failure.
        """
        try:
            if not self.load_models():
                return None, None, "❌ Failed to load models. Please try again."

            # Load and preprocess audio
            print(f"🎵 Processing audio: {audio_file}")
            audio, sr = torchaudio.load(audio_file)

            # The DCAE expects 10 seconds of stereo audio at 44.1 kHz.
            TARGET_SR_DCAE = 44100
            TARGET_LEN_DCAE = TARGET_SR_DCAE * 10

            if sr != TARGET_SR_DCAE:
                audio = torchaudio.transforms.Resample(sr, TARGET_SR_DCAE)(audio)
            if audio.shape[1] > TARGET_LEN_DCAE:
                audio = audio[:, :TARGET_LEN_DCAE]
            elif audio.shape[1] < TARGET_LEN_DCAE:
                audio = torch.nn.functional.pad(audio, (0, TARGET_LEN_DCAE - audio.shape[1]))
            if audio.shape[0] == 1:
                audio = audio.repeat(2, 1)  # mono -> stereo

            # Encode audio
            with torch.no_grad():
                source_latent_scaled, _ = self.dcae.encode(audio.to(DEVICE).unsqueeze(0))

            # Prepare text embeddings (conditional + unconditional for CFG)
            with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=DTYPE,
                                                     enabled=(DTYPE != torch.float32)):
                text_input = self.tokenizer([instruction], max_length=32, padding="max_length",
                                            truncation=True, return_tensors="pt")
                instruction_embedding = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
                uncond_input = self.tokenizer([""], max_length=32, padding="max_length",
                                              truncation=True, return_tensors="pt")
                uncond_embedding = self.text_encoder(uncond_input.input_ids.to(DEVICE))[0]

            # Move models for inference: free GPU memory held by the DCAE, then
            # bring the UNet onto the device for sampling.
            self.dcae = self.dcae.cpu()
            torch.cuda.empty_cache()
            self.model = self.model.to(DEVICE, dtype=DTYPE)

            # Generate
            print("🎨 Generating edited audio...")
            with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
                with torch.no_grad():
                    final_latent = self.dpm_solver_sampling(
                        model=self.model,
                        source_latent=source_latent_scaled,
                        instruction_embedding=instruction_embedding,
                        uncond_embedding=uncond_embedding,
                        strength=strength,
                        steps=int(steps),
                        guidance_scale=guidance_scale,
                        seed=int(seed),
                    )

            # Decode results: swap the UNet out and the DCAE back in.
            self.model = self.model.cpu()
            torch.cuda.empty_cache()
            self.dcae = self.dcae.to(DEVICE)

            # Undo the latent scaling before mel decoding.
            final_latent_unscaled = (final_latent.float() / self.dcae.scale_factor) + self.dcae.shift_factor
            source_latent_raw = (source_latent_scaled / self.dcae.scale_factor) + self.dcae.shift_factor

            with torch.no_grad():
                source_mel = self.dcae.decode_to_mel(source_latent_raw)
                edited_mel = self.dcae.decode_to_mel(final_latent_unscaled)
                _, pred_wavs = self.dcae.decode(latents=final_latent.float(), sr=44100)
                edited_audio = pred_wavs[0]

            # Create comparison plot
            comparison_plot = self.create_mel_comparison(source_mel, edited_mel, instruction)

            # Save output audio
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                torchaudio.save(tmp_file.name, edited_audio.cpu().float(), 44100)
                output_path = tmp_file.name

            # Cleanup
            self.dcae = self.dcae.cpu()
            torch.cuda.empty_cache()

            return output_path, comparison_plot, f"✅ Audio editing completed! Instruction: '{instruction}'"

        except Exception as e:
            import traceback
            error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return None, None, error_msg

    def create_mel_comparison(self, source_mel, edited_mel, instruction):
        """Create mel-spectrogram comparison plot.

        Returns the path of a saved PNG, or None if plotting fails.
        """
        try:
            # Take channel 0 of the first batch item for display.
            source_mel_np = source_mel.squeeze(0)[0].cpu().float().numpy()
            edited_mel_np = edited_mel.squeeze(0)[0].cpu().float().numpy()

            fig, axs = plt.subplots(2, 1, figsize=(12, 8), sharex=True, sharey=True)
            fig.suptitle(f'Mel-Spectrogram Comparison', fontsize=14)

            # Plot source
            im1 = axs[0].imshow(source_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[0].set_title('Original Audio')
            axs[0].set_ylabel('Mel Bins')
            plt.colorbar(im1, ax=axs[0])

            # Plot edited
            im2 = axs[1].imshow(edited_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[1].set_title(f'Edited Audio: "{instruction}"')
            axs[1].set_ylabel('Mel Bins')
            axs[1].set_xlabel('Time Frames')
            plt.colorbar(im2, ax=axs[1])

            plt.tight_layout()

            # Save to temporary file for Gradio
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
                plt.savefig(tmp_file.name, dpi=100, bbox_inches='tight')
                plt.close()
                return tmp_file.name
        except Exception as e:
            print(f"Error creating plot: {e}")
            plt.close()
            return None


# Initialize the audio editor
audio_editor = AudioEditor()


def gradio_interface(audio_file, instruction, guidance_scale, steps, strength, seed):
    """Gradio interface function: validates inputs then delegates to the editor."""
    if audio_file is None:
        return None, None, "Please upload an audio file"
    if not instruction.strip():
        return None, None, "Please provide an editing instruction"
    return audio_editor.process_audio(audio_file, instruction, guidance_scale, steps, strength, seed)


# Create Gradio interface
with gr.Blocks(title="🎵 AI Audio Editor", theme=gr.themes.Soft()) as demo:
    # NOTE(review): the original HTML markup was lost in a formatting mangle;
    # reconstructed minimally around the surviving visible text.
    gr.HTML("""
        <div style="text-align: center;">
            <h1>🎵 AI Audio Editor</h1>
            <p>Upload an audio file and provide instructions to edit it using AI.</p>
            <p>The model uses DPM-Solver++ for fast, high-quality generation.</p>
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input components
            audio_input = gr.Audio(
                label="📁 Upload Audio File",
                type="filepath"
            )
            instruction_input = gr.Textbox(
                label="✏️ Editing Instruction",
                placeholder="e.g., 'Add drums', 'Make it more energetic', 'Remove vocals'",
                lines=2
            )
            with gr.Accordion("🔧 Advanced Settings", open=False):
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                    label="Guidance Scale",
                    info="Higher values follow the instruction more closely"
                )
                steps = gr.Slider(
                    minimum=10, maximum=50, value=25, step=5,
                    label="Sampling Steps",
                    info="More steps = better quality, slower generation"
                )
                strength = gr.Slider(
                    minimum=0.1, maximum=1.0, value=1.0, step=0.1,
                    label="Denoising Strength",
                    info="1.0 = full denoising, lower = more conservative editing"
                )
                seed = gr.Number(
                    value=42,
                    label="Seed",
                    info="For reproducible results"
                )
            generate_btn = gr.Button("🎨 Generate Edited Audio", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output components
            status_output = gr.Textbox(label="📊 Status", interactive=False)
            audio_output = gr.Audio(label="🎵 Generated Audio")
            plot_output = gr.Image(label="📈 Mel-Spectrogram Comparison")

    gr.HTML("""
        <div>
            <h3>📝 Usage Tips:</h3>
        </div>
    """)

    # Connect the interface
    generate_btn.click(
        fn=gradio_interface,
        inputs=[audio_input, instruction_input, guidance_scale, steps, strength, seed],
        outputs=[audio_output, plot_output, status_output],
        show_progress=True
    )

# Launch settings
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )