# mude / app.py — Hugging Face Space by NZUONG (commit 18b3589, verified)
import os
import torch
import torchaudio
import gradio as gr
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import UMT5EncoderModel, AutoTokenizer
from huggingface_hub import hf_hub_download, snapshot_download
import json
import numpy as np
import tempfile
from io import BytesIO
import warnings
warnings.filterwarnings("ignore")
# Import model components
from model.ae.music_dcae import MusicDCAE
from model.ldm.editing_unet import EditingUNet
from model.ldm.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Use bf16 only when a CUDA device supports it; CPU (and older GPUs) fall back to fp32.
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
# Model repository - UPDATE THIS TO YOUR MODEL REPO
MODEL_REPO = "NZUONG/mude" # Your uploaded model repository
# DDPM Parameters: linear beta schedule reconstructed at sampling time
# (must match the schedule the checkpoint was trained with).
DDPM_NUM_TIMESTEPS = 1000
DDPM_BETA_START = 0.0001
DDPM_BETA_END = 0.02
class AttrDict(dict):
    """A dict whose entries are also reachable as attributes (cfg.key == cfg['key'])."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the instance __dict__ to the mapping itself so attribute
        # access and item access share a single storage.
        self.__dict__ = self
def download_models():
    """Mirror the model repository from the Hugging Face Hub into ./checkpoints.

    Returns:
        bool: True when the snapshot download succeeded, False otherwise.
    """
    print("πŸ”„ Downloading models from Hugging Face Hub...")
    # Make sure the target directory exists before the download starts.
    os.makedirs("checkpoints", exist_ok=True)
    try:
        # Pull the entire repo; the cache directory keeps re-runs cheap.
        target_dir = snapshot_download(
            repo_id=MODEL_REPO,
            repo_type="model",
            local_dir="./checkpoints",
            cache_dir="./cache",
        )
    except Exception as exc:
        print(f"❌ Error downloading models: {exc}")
        return False
    print(f"βœ… Models downloaded to: {target_dir}")
    return True
class AudioEditor:
    """Lazy-loading pipeline: DCAE audio codec + UMT5 text encoder + editing UNet.

    Models are loaded once on first use and shuttled between CPU and GPU per
    stage so the codec and the UNet are never resident on the GPU at the same
    time (keeps peak VRAM low on small Spaces GPUs).
    """
    def __init__(self):
        # All heavyweight components are deferred until load_models().
        self.dcae = None          # MusicDCAE latent codec + vocoder
        self.tokenizer = None     # UMT5 tokenizer
        self.text_encoder = None  # UMT5 encoder producing instruction embeddings
        self.model = None         # EditingUNet denoiser (epsilon prediction)
        self.is_loaded = False
    def load_models(self):
        """Load all models once at startup.

        Downloads checkpoints from the Hub when absent locally.
        Returns True on success, False on any failure (download or load).
        """
        if self.is_loaded:
            return True
        # Download models if not present
        if not os.path.exists("checkpoints/music_dcae_f8c8"):
            print("πŸ“₯ Models not found locally, downloading...")
            if not download_models():
                return False
        print("πŸ”„ Loading models...")
        try:
            # Model paths
            dcae_path = "checkpoints/music_dcae_f8c8"
            vocoder_path = "checkpoints/music_vocoder"
            t5_path = "checkpoints/umt5-base"
            unet_config_path = "model/ldm/exp_config.json"
            trained_model_path = "checkpoints/fm_checkpoint_epoch_9.pt"
            # Load DCAE (codec + vocoder); eval mode, default (fp32) precision.
            self.dcae = MusicDCAE(
                dcae_checkpoint_path=dcae_path,
                vocoder_checkpoint_path=vocoder_path
            ).to(DEVICE).eval()
            # Load text encoder in the working dtype.
            self.tokenizer = AutoTokenizer.from_pretrained(t5_path)
            self.text_encoder = UMT5EncoderModel.from_pretrained(t5_path).to(DEVICE, dtype=DTYPE).eval()
            # Load UNet config
            with open(unet_config_path, 'r') as f:
                unet_config = AttrDict(json.load(f)['model']['unet'])
            # UNet starts on CPU; process_audio() moves it to DEVICE only for sampling.
            self.model = EditingUNet(unet_config, use_flow_matching=False).to("cpu", dtype=DTYPE).eval()
            # Load checkpoint (may be a raw state dict or wrapped in 'model_state_dict').
            checkpoint = torch.load(trained_model_path, map_location="cpu")
            model_state_dict = checkpoint.get('model_state_dict', checkpoint)
            # Strip torch.compile()'s '_orig_mod.' prefix if the model was compiled
            # when the checkpoint was saved.
            if any(key.startswith('_orig_mod.') for key in model_state_dict.keys()):
                model_state_dict = {key.replace('_orig_mod.', ''): value for key, value in model_state_dict.items()}
            # NOTE(review): strict=False silently ignores missing/unexpected keys —
            # verify the checkpoint actually matches this architecture.
            self.model.load_state_dict(model_state_dict, strict=False)
            self.is_loaded = True
            print("βœ… All models loaded successfully!")
            return True
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            return False
    def dpm_solver_sampling(self, model, source_latent, instruction_embedding, uncond_embedding,
                            strength=1.0, steps=25, guidance_scale=7.5, seed=42):
        """Run DPM-Solver++ over a noised copy of the source latent.

        strength=1.0 starts from (almost) fully-noised input; smaller values
        noise to an intermediate time and so preserve more of the source.
        Returns the denoised latent (still in the DCAE's scaled space).
        """
        print(f"πŸš€ Starting DPM-Solver++ sampling with {steps} steps...")
        # Rebuild the linear DDPM beta schedule the checkpoint was trained with.
        betas = torch.linspace(DDPM_BETA_START, DDPM_BETA_END, DDPM_NUM_TIMESTEPS, dtype=torch.float32)
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        noise_schedule = NoiseScheduleVP(schedule='discrete', alphas_cumprod=alphas_cumprod)
        # Wrap the epsilon-prediction UNet with classifier-free guidance;
        # the source latent rides along as a model kwarg for conditioning.
        model_fn = model_wrapper(
            model,
            noise_schedule,
            model_type="noise",  # DDPM objective only
            model_kwargs={
                "source_latent": source_latent,
            },
            guidance_type="classifier-free",
            condition=instruction_embedding,
            unconditional_condition=uncond_embedding,
            guidance_scale=guidance_scale,
        )
        # Initialize DPM-Solver++
        dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
        # Integration span: t_end is the smallest discrete time; strength
        # linearly interpolates t_start between t_end and T.
        t_end = noise_schedule.T / noise_schedule.total_N
        t_start = t_end + strength * (noise_schedule.T - t_end)
        # Seed the forward-diffusion noise for reproducible edits.
        torch.manual_seed(seed)
        noise = torch.randn_like(source_latent)
        latents = dpm_solver.add_noise(source_latent, torch.tensor([t_start], device=DEVICE), noise)
        latents = latents.to(DTYPE)
        # Run DPM solver sampling (2nd-order multistep on a uniform time grid).
        # autocast is a no-op on fp32 (enabled=False), so CPU runs stay correct.
        with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
            with torch.no_grad():
                final_latent, _ = dpm_solver.sample(
                    latents,
                    steps=steps,
                    t_start=t_start,
                    t_end=t_end,
                    order=2,
                    method="multistep",
                    skip_type="time_uniform",
                    lower_order_final=True,
                    return_intermediate=True,
                )
        return final_latent
    def process_audio(self, audio_file, instruction, guidance_scale, steps, strength, seed):
        """Full edit pipeline: load audio → encode → sample → decode → save.

        Args:
            audio_file: path to the input audio file.
            instruction: natural-language editing instruction.
            guidance_scale: classifier-free guidance weight.
            steps: sampler step count (cast to int).
            strength: fraction of the noise schedule to apply, in (0, 1].
            seed: RNG seed for reproducible noise (cast to int).

        Returns:
            (output_wav_path, comparison_png_path, status_message);
            the first two are None on failure.
        """
        try:
            if not self.load_models():
                return None, None, "❌ Failed to load models. Please try again."
            # Load and preprocess audio
            print(f"🎡 Processing audio: {audio_file}")
            audio, sr = torchaudio.load(audio_file)
            # The DCAE expects exactly 10 s of 44.1 kHz stereo.
            TARGET_SR_DCAE = 44100
            TARGET_LEN_DCAE = TARGET_SR_DCAE * 10
            if sr != TARGET_SR_DCAE:
                audio = torchaudio.transforms.Resample(sr, TARGET_SR_DCAE)(audio)
            # Trim or zero-pad to exactly 10 seconds.
            if audio.shape[1] > TARGET_LEN_DCAE:
                audio = audio[:, :TARGET_LEN_DCAE]
            elif audio.shape[1] < TARGET_LEN_DCAE:
                audio = torch.nn.functional.pad(audio, (0, TARGET_LEN_DCAE - audio.shape[1]))
            # Duplicate mono to stereo.
            if audio.shape[0] == 1:
                audio = audio.repeat(2, 1)
            # Encode audio to the scaled latent space.
            with torch.no_grad():
                source_latent_scaled, _ = self.dcae.encode(audio.to(DEVICE).unsqueeze(0))
            # Prepare text embeddings: conditional + empty-string unconditional for CFG.
            with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
                text_input = self.tokenizer([instruction], max_length=32, padding="max_length",
                                            truncation=True, return_tensors="pt")
                instruction_embedding = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
                uncond_input = self.tokenizer([""], max_length=32, padding="max_length",
                                              truncation=True, return_tensors="pt")
                uncond_embedding = self.text_encoder(uncond_input.input_ids.to(DEVICE))[0]
            # Swap models: free the codec's VRAM before the UNet goes on-device.
            self.dcae = self.dcae.cpu()
            torch.cuda.empty_cache()
            self.model = self.model.to(DEVICE, dtype=DTYPE)
            # Generate
            print("🎨 Generating edited audio...")
            with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
                with torch.no_grad():
                    final_latent = self.dpm_solver_sampling(
                        model=self.model,
                        source_latent=source_latent_scaled,
                        instruction_embedding=instruction_embedding,
                        uncond_embedding=uncond_embedding,
                        strength=strength,
                        steps=int(steps),
                        guidance_scale=guidance_scale,
                        seed=int(seed)
                    )
            # Decode results — swap the UNet out and the codec back in.
            self.model = self.model.cpu()
            torch.cuda.empty_cache()
            self.dcae = self.dcae.to(DEVICE)
            # Undo the latent normalization for mel decoding.
            # NOTE(review): decode() below receives the *scaled* latent while
            # decode_to_mel() gets the unscaled one — confirm decode() rescales
            # internally, otherwise one of the two paths is wrong.
            final_latent_unscaled = (final_latent.float() / self.dcae.scale_factor) + self.dcae.shift_factor
            source_latent_raw = (source_latent_scaled / self.dcae.scale_factor) + self.dcae.shift_factor
            with torch.no_grad():
                source_mel = self.dcae.decode_to_mel(source_latent_raw)
                edited_mel = self.dcae.decode_to_mel(final_latent_unscaled)
                _, pred_wavs = self.dcae.decode(latents=final_latent.float(), sr=44100)
                edited_audio = pred_wavs[0]
            # Create comparison plot
            comparison_plot = self.create_mel_comparison(source_mel, edited_mel, instruction)
            # Save output audio to a temp .wav that Gradio can serve
            # (delete=False so the file survives past this scope).
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                torchaudio.save(tmp_file.name, edited_audio.cpu().float(), 44100)
                output_path = tmp_file.name
            # Cleanup: park the codec on CPU again between requests.
            self.dcae = self.dcae.cpu()
            torch.cuda.empty_cache()
            return output_path, comparison_plot, f"βœ… Audio editing completed! Instruction: '{instruction}'"
        except Exception as e:
            import traceback
            # Surface the full traceback in the UI status box for debugging.
            error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return None, None, error_msg
    def create_mel_comparison(self, source_mel, edited_mel, instruction):
        """Render original vs. edited mel-spectrograms into a temp PNG.

        Returns the PNG path, or None if plotting failed.
        """
        try:
            # First batch item, first channel; assumes a (batch, channel, mel, time)
            # layout — TODO confirm against MusicDCAE.decode_to_mel.
            source_mel_np = source_mel.squeeze(0)[0].cpu().float().numpy()
            edited_mel_np = edited_mel.squeeze(0)[0].cpu().float().numpy()
            fig, axs = plt.subplots(2, 1, figsize=(12, 8), sharex=True, sharey=True)
            fig.suptitle(f'Mel-Spectrogram Comparison', fontsize=14)
            # Plot source
            im1 = axs[0].imshow(source_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[0].set_title('Original Audio')
            axs[0].set_ylabel('Mel Bins')
            plt.colorbar(im1, ax=axs[0])
            # Plot edited
            im2 = axs[1].imshow(edited_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[1].set_title(f'Edited Audio: "{instruction}"')
            axs[1].set_ylabel('Mel Bins')
            axs[1].set_xlabel('Time Frames')
            plt.colorbar(im2, ax=axs[1])
            plt.tight_layout()
            # Save to temporary file for Gradio (delete=False keeps it on disk).
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
                plt.savefig(tmp_file.name, dpi=100, bbox_inches='tight')
            plt.close()
            return tmp_file.name
        except Exception as e:
            print(f"Error creating plot: {e}")
            # Close any half-built figure so failures don't leak matplotlib state.
            plt.close()
            return None
# Module-level singleton; heavyweight models are lazy-loaded on first request.
audio_editor = AudioEditor()
def gradio_interface(audio_file, instruction, guidance_scale, steps, strength, seed):
    """Validate the request, then delegate to the global AudioEditor.

    Returns:
        (audio_path, plot_path, status_message); the first two are None
        when validation fails.
    """
    # Reject incomplete requests before touching any model. The audio check
    # must come first: instruction may be None on an empty form.
    if audio_file is None:
        result = (None, None, "Please upload an audio file")
    elif not instruction.strip():
        result = (None, None, "Please provide an editing instruction")
    else:
        result = audio_editor.process_audio(
            audio_file, instruction, guidance_scale, steps, strength, seed
        )
    return result
# Create Gradio interface: inputs on the left column, outputs on the right.
with gr.Blocks(title="🎡 AI Audio Editor", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
    <h1>🎡 AI Audio Editor</h1>
    <p>Upload an audio file and provide instructions to edit it using AI.<br/>
    The model uses DPM-Solver++ for fast, high-quality generation.</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Input components
            audio_input = gr.Audio(
                label="πŸ“ Upload Audio File",
                type="filepath"
            )
            instruction_input = gr.Textbox(
                label="✏️ Editing Instruction",
                placeholder="e.g., 'Add drums', 'Make it more energetic', 'Remove vocals'",
                lines=2
            )
            # Sampler knobs, collapsed by default.
            with gr.Accordion("πŸ”§ Advanced Settings", open=False):
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                    label="Guidance Scale",
                    info="Higher values follow the instruction more closely"
                )
                steps = gr.Slider(
                    minimum=10, maximum=50, value=25, step=5,
                    label="Sampling Steps",
                    info="More steps = better quality, slower generation"
                )
                strength = gr.Slider(
                    minimum=0.1, maximum=1.0, value=1.0, step=0.1,
                    label="Denoising Strength",
                    info="1.0 = full denoising, lower = more conservative editing"
                )
                seed = gr.Number(
                    value=42, label="Seed",
                    info="For reproducible results"
                )
            generate_btn = gr.Button("🎨 Generate Edited Audio", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output components
            status_output = gr.Textbox(label="πŸ“Š Status", interactive=False)
            audio_output = gr.Audio(label="🎡 Generated Audio")
            plot_output = gr.Image(label="πŸ“ˆ Mel-Spectrogram Comparison")
    # Static usage tips shown below the columns.
    gr.HTML("""
    <div style="margin-top: 20px; padding: 20px; background-color: #f0f0f0; border-radius: 10px;">
    <h3>πŸ“ Usage Tips:</h3>
    <ul>
    <li><b>Audio Length:</b> Files are automatically processed to 10 seconds</li>
    <li><b>Instructions:</b> Be specific (e.g., "Add heavy drums" vs "Add drums")</li>
    <li><b>Guidance Scale:</b> Start with 7.5, increase for stronger effects</li>
    <li><b>Steps:</b> 25 steps provide good quality/speed balance</li>
    </ul>
    </div>
    """)
    # Connect the interface; outputs are ordered to match gradio_interface's
    # (audio, plot, status) return tuple.
    generate_btn.click(
        fn=gradio_interface,
        inputs=[audio_input, instruction_input, guidance_scale, steps, strength, seed],
        outputs=[audio_output, plot_output, status_output],
        show_progress=True
    )
# Launch settings — 0.0.0.0:7860 is the standard Hugging Face Spaces binding.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )