|
|
import os |
|
|
import torch |
|
|
import torchaudio |
|
|
import gradio as gr |
|
|
import matplotlib.pyplot as plt |
|
|
from tqdm import tqdm |
|
|
from transformers import UMT5EncoderModel, AutoTokenizer |
|
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
|
import json |
|
|
import numpy as np |
|
|
import tempfile |
|
|
from io import BytesIO |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
|
from model.ae.music_dcae import MusicDCAE |
|
|
from model.ldm.editing_unet import EditingUNet |
|
|
from model.ldm.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver |
|
|
|
|
|
|
|
|
# Compute device: prefer CUDA when a GPU is available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Use bfloat16 only when running on a CUDA device that supports it;
# fall back to float32 everywhere else (including CPU-only runs).
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32


# Hugging Face Hub repository that hosts all model checkpoints.
MODEL_REPO = "NZUONG/mude"


# Linear-beta DDPM noise-schedule hyperparameters, used to rebuild the
# training schedule for DPM-Solver sampling (see dpm_solver_sampling).
DDPM_NUM_TIMESTEPS = 1000
DDPM_BETA_START = 0.0001
DDPM_BETA_END = 0.02
|
|
|
|
|
class AttrDict(dict):
    """A dict whose entries are also readable/writable as attributes.

    Used to expose the JSON UNet config (nested plain dicts) with
    ``config.key`` access in addition to ``config["key"]``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the mapping itself, so that
        # setting obj.key and obj["key"] are the same operation.
        self.__dict__ = self
|
|
|
|
|
def download_models():
    """Fetch the model weights from the Hugging Face Hub.

    Snapshots the whole MODEL_REPO into ./checkpoints, using ./cache as the
    hub cache directory.

    Returns:
        bool: True if the snapshot completed, False if the download failed.
    """
    print("π Downloading models from Hugging Face Hub...")

    # Ensure the destination directory exists before downloading into it.
    os.makedirs("checkpoints", exist_ok=True)

    try:
        local_dir = snapshot_download(
            repo_id=MODEL_REPO,
            cache_dir="./cache",
            local_dir="./checkpoints",
            repo_type="model",
        )
    except Exception as exc:
        # Best-effort download: report the failure and let the caller decide.
        print(f"β Error downloading models: {exc}")
        return False

    print(f"β Models downloaded to: {local_dir}")
    return True
|
|
|
|
|
class AudioEditor:
    """Instruction-based audio editing pipeline.

    Bundles three models — the MusicDCAE audio autoencoder/vocoder, a UMT5
    text encoder, and the EditingUNet diffusion backbone — and moves them
    between CPU and GPU during processing to limit peak VRAM usage.
    Models are loaded lazily on first use (see load_models).
    """

    def __init__(self):
        # All model handles stay None until load_models() succeeds.
        self.dcae = None          # MusicDCAE: latent encode/decode + vocoder
        self.tokenizer = None     # UMT5 tokenizer
        self.text_encoder = None  # UMT5EncoderModel (frozen, eval mode)
        self.model = None         # EditingUNet diffusion model
        self.is_loaded = False    # guards against reloading on every request

    def load_models(self):
        """Load all models once; download checkpoints first if missing.

        Returns:
            bool: True when all models are ready, False on any failure.
        """
        # Idempotent: subsequent calls are no-ops once loading succeeded.
        if self.is_loaded:
            return True

        # The DCAE checkpoint directory doubles as the "models are present"
        # marker; if absent, pull everything from the Hub.
        if not os.path.exists("checkpoints/music_dcae_f8c8"):
            print("π₯ Models not found locally, downloading...")
            if not download_models():
                return False

        print("π Loading models...")

        try:
            dcae_path = "checkpoints/music_dcae_f8c8"
            vocoder_path = "checkpoints/music_vocoder"
            t5_path = "checkpoints/umt5-base"
            unet_config_path = "model/ldm/exp_config.json"
            trained_model_path = "checkpoints/fm_checkpoint_epoch_9.pt"

            # Autoencoder + vocoder go straight to the compute device.
            self.dcae = MusicDCAE(
                dcae_checkpoint_path=dcae_path,
                vocoder_checkpoint_path=vocoder_path
            ).to(DEVICE).eval()

            # Text encoder runs in reduced precision when DTYPE is bf16.
            self.tokenizer = AutoTokenizer.from_pretrained(t5_path)
            self.text_encoder = UMT5EncoderModel.from_pretrained(t5_path).to(DEVICE, dtype=DTYPE).eval()

            # UNet hyperparameters come from the experiment config JSON.
            with open(unet_config_path, 'r') as f:
                unet_config = AttrDict(json.load(f)['model']['unet'])

            # The UNet is kept on CPU here; process_audio moves it to the
            # GPU only for the sampling phase to save VRAM.
            self.model = EditingUNet(unet_config, use_flow_matching=False).to("cpu", dtype=DTYPE).eval()

            checkpoint = torch.load(trained_model_path, map_location="cpu")
            # Checkpoint may either be a raw state dict or wrap it under
            # 'model_state_dict'.
            model_state_dict = checkpoint.get('model_state_dict', checkpoint)
            # Strip the '_orig_mod.' prefix that torch.compile adds to
            # parameter names, so keys match the uncompiled module.
            if any(key.startswith('_orig_mod.') for key in model_state_dict.keys()):
                model_state_dict = {key.replace('_orig_mod.', ''): value for key, value in model_state_dict.items()}
            # NOTE(review): strict=False silently ignores missing/unexpected
            # keys — confirm this is intentional for this checkpoint.
            self.model.load_state_dict(model_state_dict, strict=False)

            self.is_loaded = True
            print("β All models loaded successfully!")
            return True

        except Exception as e:
            # Broad catch so the UI gets a status message instead of a crash.
            print(f"β Error loading models: {e}")
            return False

    def dpm_solver_sampling(self, model, source_latent, instruction_embedding, uncond_embedding,
                          strength=1.0, steps=25, guidance_scale=7.5, seed=42):
        """Run DPM-Solver++ sampling conditioned on the instruction.

        Args:
            model: the EditingUNet noise-prediction model (already on DEVICE).
            source_latent: encoded source-audio latent; also passed to the
                model as conditioning via model_kwargs.
            instruction_embedding: UMT5 embedding of the edit instruction.
            uncond_embedding: UMT5 embedding of the empty string, used as the
                unconditional branch for classifier-free guidance.
            strength: 1.0 noises the source latent all the way to t=T (full
                re-generation); lower values start sampling closer to the
                source for more conservative edits.
            steps: number of solver steps.
            guidance_scale: classifier-free guidance weight.
            seed: RNG seed for the initial noise (reproducibility).

        Returns:
            The final denoised latent.
        """
        print(f"π Starting DPM-Solver++ sampling with {steps} steps...")

        # Rebuild the discrete linear-beta DDPM schedule the model was
        # trained with, then wrap it for DPM-Solver.
        betas = torch.linspace(DDPM_BETA_START, DDPM_BETA_END, DDPM_NUM_TIMESTEPS, dtype=torch.float32)
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        noise_schedule = NoiseScheduleVP(schedule='discrete', alphas_cumprod=alphas_cumprod)

        # Classifier-free guidance wrapper: blends conditional and
        # unconditional predictions with guidance_scale.
        model_fn = model_wrapper(
            model,
            noise_schedule,
            model_type="noise",
            model_kwargs={
                "source_latent": source_latent,
            },
            guidance_type="classifier-free",
            condition=instruction_embedding,
            unconditional_condition=uncond_embedding,
            guidance_scale=guidance_scale,
        )

        dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")

        # Sampling interval: t_end is the smallest discrete time step;
        # t_start interpolates between t_end and T according to strength
        # (strength=1.0 -> t_start = T, i.e. start from almost pure noise).
        t_end = noise_schedule.T / noise_schedule.total_N
        t_start = t_end + strength * (noise_schedule.T - t_end)

        # Seed the noise so identical inputs + seed reproduce the output.
        torch.manual_seed(seed)
        noise = torch.randn_like(source_latent)
        latents = dpm_solver.add_noise(source_latent, torch.tensor([t_start], device=DEVICE), noise)
        latents = latents.to(DTYPE)

        # Autocast is a no-op when DTYPE is float32 (CPU / no-bf16 case).
        with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
            with torch.no_grad():
                final_latent, _ = dpm_solver.sample(
                    latents,
                    steps=steps,
                    t_start=t_start,
                    t_end=t_end,
                    order=2,
                    method="multistep",
                    skip_type="time_uniform",
                    lower_order_final=True,
                    return_intermediate=True,  # intermediates are discarded
                )

        return final_latent

    def process_audio(self, audio_file, instruction, guidance_scale, steps, strength, seed):
        """End-to-end editing: load audio, encode, sample, decode, save.

        Args:
            audio_file: path to the input audio file.
            instruction: natural-language edit instruction.
            guidance_scale, steps, strength, seed: forwarded to
                dpm_solver_sampling (steps and seed are cast to int).

        Returns:
            (output_wav_path, comparison_plot_path, status_message); the
            first two are None when processing fails.
        """
        try:
            if not self.load_models():
                return None, None, "β Failed to load models. Please try again."

            # --- Input preparation: fixed 10 s stereo clip at 44.1 kHz ---
            print(f"π΅ Processing audio: {audio_file}")
            audio, sr = torchaudio.load(audio_file)
            TARGET_SR_DCAE = 44100
            TARGET_LEN_DCAE = TARGET_SR_DCAE * 10  # exactly 10 seconds

            if sr != TARGET_SR_DCAE:
                audio = torchaudio.transforms.Resample(sr, TARGET_SR_DCAE)(audio)

            # Trim long clips / zero-pad short clips to the fixed length.
            if audio.shape[1] > TARGET_LEN_DCAE:
                audio = audio[:, :TARGET_LEN_DCAE]
            elif audio.shape[1] < TARGET_LEN_DCAE:
                audio = torch.nn.functional.pad(audio, (0, TARGET_LEN_DCAE - audio.shape[1]))

            # Duplicate mono to stereo.
            if audio.shape[0] == 1:
                audio = audio.repeat(2, 1)

            # --- Encode the audio into the (scaled) DCAE latent space ---
            # Second return value of encode() is intentionally discarded.
            with torch.no_grad():
                source_latent_scaled, _ = self.dcae.encode(audio.to(DEVICE).unsqueeze(0))

            # --- Text conditioning: instruction + empty-string (uncond) ---
            with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
                text_input = self.tokenizer([instruction], max_length=32, padding="max_length",
                                            truncation=True, return_tensors="pt")
                instruction_embedding = self.text_encoder(text_input.input_ids.to(DEVICE))[0]

                uncond_input = self.tokenizer([""], max_length=32, padding="max_length",
                                              truncation=True, return_tensors="pt")
                uncond_embedding = self.text_encoder(uncond_input.input_ids.to(DEVICE))[0]

            # --- VRAM shuffle: park the DCAE on CPU, bring in the UNet ---
            self.dcae = self.dcae.cpu()
            torch.cuda.empty_cache()
            self.model = self.model.to(DEVICE, dtype=DTYPE)

            # --- Diffusion sampling ---
            print("π¨ Generating edited audio...")
            with torch.amp.autocast(device_type="cuda", dtype=DTYPE, enabled=(DTYPE != torch.float32)):
                with torch.no_grad():
                    final_latent = self.dpm_solver_sampling(
                        model=self.model,
                        source_latent=source_latent_scaled,
                        instruction_embedding=instruction_embedding,
                        uncond_embedding=uncond_embedding,
                        strength=strength,
                        steps=int(steps),
                        guidance_scale=guidance_scale,
                        seed=int(seed)
                    )

            # --- Swap back: UNet to CPU, DCAE to GPU for decoding ---
            self.model = self.model.cpu()
            torch.cuda.empty_cache()
            self.dcae = self.dcae.to(DEVICE)

            # Undo the DCAE latent scaling before mel decoding.
            # NOTE(review): dcae.decode() takes the still-scaled latent while
            # decode_to_mel() takes the unscaled one — confirm against the
            # MusicDCAE implementation.
            final_latent_unscaled = (final_latent.float() / self.dcae.scale_factor) + self.dcae.shift_factor
            source_latent_raw = (source_latent_scaled / self.dcae.scale_factor) + self.dcae.shift_factor

            # --- Decode: mels for the comparison plot, waveform for output ---
            with torch.no_grad():
                source_mel = self.dcae.decode_to_mel(source_latent_raw)
                edited_mel = self.dcae.decode_to_mel(final_latent_unscaled)
                _, pred_wavs = self.dcae.decode(latents=final_latent.float(), sr=44100)
                edited_audio = pred_wavs[0]

            comparison_plot = self.create_mel_comparison(source_mel, edited_mel, instruction)

            # Persist the edited audio; delete=False so Gradio can serve it.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                torchaudio.save(tmp_file.name, edited_audio.cpu().float(), 44100)
                output_path = tmp_file.name

            # Free GPU memory between requests.
            self.dcae = self.dcae.cpu()
            torch.cuda.empty_cache()

            return output_path, comparison_plot, f"β Audio editing completed! Instruction: '{instruction}'"

        except Exception as e:
            # Surface the full traceback in the UI status box for debugging.
            import traceback
            error_msg = f"β Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return None, None, error_msg

    def create_mel_comparison(self, source_mel, edited_mel, instruction):
        """Render a stacked original-vs-edited mel-spectrogram figure.

        Args:
            source_mel, edited_mel: decoded mel tensors; the first channel of
                the first batch element is plotted.
            instruction: edit instruction, shown in the second subplot title.

        Returns:
            Path to the saved PNG, or None if plotting failed.
        """
        try:
            # Drop the batch dim, take channel 0, move to CPU numpy.
            source_mel_np = source_mel.squeeze(0)[0].cpu().float().numpy()
            edited_mel_np = edited_mel.squeeze(0)[0].cpu().float().numpy()

            # Shared axes so the two spectrograms are directly comparable.
            fig, axs = plt.subplots(2, 1, figsize=(12, 8), sharex=True, sharey=True)
            fig.suptitle(f'Mel-Spectrogram Comparison', fontsize=14)

            im1 = axs[0].imshow(source_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[0].set_title('Original Audio')
            axs[0].set_ylabel('Mel Bins')
            plt.colorbar(im1, ax=axs[0])

            im2 = axs[1].imshow(edited_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[1].set_title(f'Edited Audio: "{instruction}"')
            axs[1].set_ylabel('Mel Bins')
            axs[1].set_xlabel('Time Frames')
            plt.colorbar(im2, ax=axs[1])

            plt.tight_layout()

            # delete=False so the file survives for Gradio to display.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
                plt.savefig(tmp_file.name, dpi=100, bbox_inches='tight')
                plt.close()
                return tmp_file.name

        except Exception as e:
            # Plot failure is non-fatal: close any open figure, return None.
            print(f"Error creating plot: {e}")
            plt.close()
            return None
|
|
|
|
|
|
|
|
# Single shared editor instance; models are loaded lazily on first request.
audio_editor = AudioEditor()
|
|
|
|
|
def gradio_interface(audio_file, instruction, guidance_scale, steps, strength, seed):
    """Gradio click handler: validate the inputs, then run the editor.

    Returns the (audio_path, plot_path, status) triple expected by the
    three output components; on invalid input the first two are None.
    """
    # Collect the first validation failure, if any.
    problem = None
    if audio_file is None:
        problem = "Please upload an audio file"
    elif not instruction.strip():
        problem = "Please provide an editing instruction"

    if problem is not None:
        return None, None, problem

    return audio_editor.process_audio(
        audio_file, instruction, guidance_scale, steps, strength, seed
    )
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: two-column layout — inputs/settings on the left, results on the
# right — wired to gradio_interface via the Generate button.
# ---------------------------------------------------------------------------
with gr.Blocks(title="π΅ AI Audio Editor", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>π΅ AI Audio Editor</h1>
        <p>Upload an audio file and provide instructions to edit it using AI.<br/>
        The model uses DPM-Solver++ for fast, high-quality generation.</p>
    </div>
    """)

    with gr.Row():
        # Left column: audio upload, instruction, and generation settings.
        with gr.Column(scale=1):
            # type="filepath" hands the handler a path on disk, which
            # process_audio passes straight to torchaudio.load.
            audio_input = gr.Audio(
                label="π Upload Audio File",
                type="filepath"
            )

            instruction_input = gr.Textbox(
                label="βοΈ Editing Instruction",
                placeholder="e.g., 'Add drums', 'Make it more energetic', 'Remove vocals'",
                lines=2
            )

            # Sampling hyperparameters, collapsed by default.
            with gr.Accordion("π§ Advanced Settings", open=False):
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                    label="Guidance Scale",
                    info="Higher values follow the instruction more closely"
                )

                steps = gr.Slider(
                    minimum=10, maximum=50, value=25, step=5,
                    label="Sampling Steps",
                    info="More steps = better quality, slower generation"
                )

                strength = gr.Slider(
                    minimum=0.1, maximum=1.0, value=1.0, step=0.1,
                    label="Denoising Strength",
                    info="1.0 = full denoising, lower = more conservative editing"
                )

                seed = gr.Number(
                    value=42, label="Seed",
                    info="For reproducible results"
                )

            generate_btn = gr.Button("π¨ Generate Edited Audio", variant="primary", size="lg")

        # Right column: status text, edited audio, and spectrogram plot —
        # mirrors the 3-tuple returned by gradio_interface.
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="π Status", interactive=False)
            audio_output = gr.Audio(label="π΅ Generated Audio")
            plot_output = gr.Image(label="π Mel-Spectrogram Comparison")

    # Footer with usage tips.
    gr.HTML("""
    <div style="margin-top: 20px; padding: 20px; background-color: #f0f0f0; border-radius: 10px;">
        <h3>π Usage Tips:</h3>
        <ul>
            <li><b>Audio Length:</b> Files are automatically processed to 10 seconds</li>
            <li><b>Instructions:</b> Be specific (e.g., "Add heavy drums" vs "Add drums")</li>
            <li><b>Guidance Scale:</b> Start with 7.5, increase for stronger effects</li>
            <li><b>Steps:</b> 25 steps provide good quality/speed balance</li>
        </ul>
    </div>
    """)

    # Wire the button to the handler; outputs map to (audio, plot, status).
    generate_btn.click(
        fn=gradio_interface,
        inputs=[audio_input, instruction_input, guidance_scale, steps, strength, seed],
        outputs=[audio_output, plot_output, status_output],
        show_progress=True
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve the app on all interfaces at port 7860 (the Gradio default);
    # share=False keeps it local, show_error surfaces tracebacks in the UI.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )