|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import sys |
|
|
import tempfile |
|
|
import torch |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torchaudio |
|
|
from huggingface_hub import snapshot_download |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face model repo for the Muse text-to-music language model.
MODEL_ID = "bolshyC/Muse-0.6b"

# Prefer GPU when available; all model placement below follows this choice.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Output sample rate of the AudioLDM 48 kHz vocoder used for waveform decoding.
SAMPLE_RATE = 48000

# Redirect Hugging Face caches to /tmp (writable on HF Spaces).
# NOTE(review): these are assigned *after* huggingface_hub/transformers were
# imported above; both libraries typically resolve their cache directories at
# import time, so these assignments may have no effect here — confirm, or set
# them in the Space's environment settings / before the imports.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Load the Muse language model -----------------------------------------
print("Loading Muse-0.6b...")

# trust_remote_code: the repo ships its own tokenizer/model classes.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    # fp16 on GPU halves memory; CPU stays fp32 (many ops lack fast fp16 CPU kernels).
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    # "auto" lets accelerate place/shard the weights on the GPU; on CPU we
    # leave placement to the default loader.
    device_map="auto" if DEVICE == "cuda" else None,
)
# Inference only: disable dropout/batch-norm training behavior.
model.eval()
print("Muse loaded.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("Loading MuCodec / AudioLDM...")

# Make the vendored MuCodec checkout importable.
# NOTE(review): inserting "./MuCodec" onto sys.path is only needed if
# MuCodec's own modules use bare absolute imports internally; the
# package-style imports below resolve via the current working directory
# instead — confirm both mechanisms are actually required.
sys.path.insert(0, "./MuCodec")
from MuCodec.model import PromptCondAudioDiffusion
from MuCodec.tools.get_melvaehifigan48k import build_pretrained_models
|
|
|
|
|
|
|
|
# Fetch the AudioLDM 48 kHz VAE/vocoder checkpoint from the Hub into a
# concrete local directory (no symlinked cache blobs).
audioldm_dir = snapshot_download(
    "haoheliu/audioldm_48k",
    local_dir="/tmp/audioldm",
    # NOTE(review): local_dir_use_symlinks is deprecated/ignored in recent
    # huggingface_hub releases — safe to drop once the pinned version allows.
    local_dir_use_symlinks=False
)
audioldm_path = os.path.join(audioldm_dir, "audioldm_48k.pth")

# vae: mel-spectrogram VAE + vocoder used to render waveforms;
# stft: the matching spectrogram front-end (unused directly below).
vae, stft = build_pretrained_models(audioldm_path)
vae = vae.to(DEVICE).eval()
|
|
|
|
|
# Diffusion-based codec decoder: turns Muse's audio tokens into mel latents.
mucodec = PromptCondAudioDiffusion(
    num_channels=32,
    unet_model_name=None,
    unet_model_config_path="./MuCodec/configs/models/transformer2D.json",
    snr_gamma=None,
)

# NOTE(review): torch.load unpickles arbitrary objects — fine for this
# vendored checkpoint, but never point it at untrusted files.
# strict=False silently ignores missing/unexpected keys; consider logging
# the returned IncompatibleKeys to catch a bad checkpoint early.
ckpt = torch.load("./MuCodec/ckpt/mucodec.pt", map_location="cpu")
mucodec.load_state_dict(ckpt, strict=False)
mucodec = mucodec.to(DEVICE).eval()

print("MuCodec loaded.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_audio_tokens(text: str):
    """Parse audio codec token ids out of a decoded generation.

    The model emits its audio codes as whitespace-separated integers
    between an ``<|audio_0|>`` start marker and an ``<|audio_1|>`` end
    marker.

    Args:
        text: Full decoded model output (special tokens included).

    Returns:
        A ``(1, 1, T)`` long tensor of token ids, or ``None`` when the
        start marker is absent or no integer tokens are found.
    """
    start_marker = "<|audio_0|>"
    end_marker = "<|audio_1|>"

    start = text.find(start_marker)
    if start == -1:
        return None
    start += len(start_marker)

    # Bug fix: if generation stopped before emitting the end marker,
    # find() returns -1 and the old slice text[start:-1] silently dropped
    # the last character (potentially truncating the final token).
    # Fall back to the end of the string, and only search *after* the
    # start marker so a stray earlier end marker can't invert the slice.
    end = text.find(end_marker, start)
    if end == -1:
        end = len(text)

    tokens = [int(x) for x in text[start:end].split() if x.isdigit()]
    if not tokens:
        return None
    # Shape (1, 1, T): batch and codebook dims expected by the decoder.
    return torch.tensor(tokens).unsqueeze(0).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate(prompt):
    """Generate audio from a text prompt.

    Runs the Muse LM to produce audio codec tokens, decodes them with
    MuCodec into mel latents, and renders a waveform with the AudioLDM
    VAE/vocoder.

    Args:
        prompt: Free-text description of the music to generate.

    Returns:
        Tuple of ``(wav_file_path_or_None, status_message)`` — the shape
        Gradio expects for the (audio, status) outputs.
    """
    # Robustness fix: Gradio can pass None for a cleared textbox, which
    # made prompt.strip() raise AttributeError. Treat it as empty.
    if not prompt or not prompt.strip():
        return None, "Empty prompt"

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    # Greedy decoding keeps output deterministic for a given prompt.
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
        )

    text = tokenizer.decode(out[0], skip_special_tokens=False)
    codes = extract_audio_tokens(text)
    if codes is None:
        return None, "Failed to parse audio tokens"

    codes = codes.to(DEVICE)

    # Decode codes -> mel latents -> waveform. This is pure inference, so
    # run it under no_grad too (the original only guarded model.generate);
    # identical outputs, much lower memory.
    # NOTE(review): the zero/noise tensor shapes and latent_length=512 are
    # assumed to match MuCodec's transformer2D config — confirm against
    # the vendored MuCodec repo.
    with torch.no_grad():
        latents = mucodec.inference_codes(
            [codes[:, :, :1024]],  # cap token count to the decoder window
            torch.zeros([1, 32, 1, 32], device=DEVICE),
            torch.randn(1, 32, 512, 32, device=DEVICE),
            latent_length=512,
            first_latent_length=0,
            additional_feats=[],
            guidance_scale=1.0,
            num_steps=10,
            disable_progress=True,
            scenario="other_seg",
        )

        mel = vae.decode_first_stage(latents.float())
        wav = vae.decode_to_waveform(mel)

    # Reserve a temp path; Gradio serves the file from this path directly.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        out_path = f.name

    # NOTE(review): assumes decode_to_waveform returns a numpy array shaped
    # (channels, samples) — torchaudio.save requires a 2-D tensor; confirm.
    torchaudio.save(out_path, torch.from_numpy(wav), SAMPLE_RATE)
    return out_path, "Done"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Muse-0.6b (Experimental HF Space)")
    # Free-text music description fed straight into generate().
    prompt = gr.Textbox(label="Prompt")
    btn = gr.Button("Generate")
    # generate() returns a wav file path, so the player reads filepaths.
    audio = gr.Audio(type="filepath")
    # Human-readable status ("Done" / error string) from generate().
    status = gr.Textbox()

    btn.click(generate, prompt, [audio, status])

demo.launch()
|
|
|