# muse / app.py — last change: "Update app.py" by Jacong (commit 2669951, verified)
#!/usr/bin/env python3
# Minimal Muse-0.6b HF Space App (EXPERIMENTAL)
import os
import sys
import tempfile
import torch
import gradio as gr
import numpy as np
import torchaudio
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
# -----------------------------------------------------------------------------
# Basic config
# -----------------------------------------------------------------------------
MODEL_ID = "bolshyC/Muse-0.6b"
SAMPLE_RATE = 48000  # sample rate (Hz) used when writing the output .wav

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Force HF cache to writable dir.
# NOTE(review): huggingface_hub and transformers resolve these variables when
# they are imported, and both are imported before this point, so these
# assignments most likely do not redirect the caches used by this process.
# To make them effective, set them before those imports or pass cache_dir=...
os.environ.update(
    {
        "HF_HOME": "/tmp/hf",
        "TRANSFORMERS_CACHE": "/tmp/hf/transformers",
    }
)
# -----------------------------------------------------------------------------
# Load Muse language model
# -----------------------------------------------------------------------------
# transformers/huggingface_hub read their cache-location environment variables
# when they are imported, which happens before this script overrides them, so
# the override alone does not move the download cache. Pass the intended
# location explicitly.
# NOTE(review): remote-code modules (trust_remote_code=True) are still cached
# under the library's import-time default location.
MUSE_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE", "/tmp/hf/transformers")

print("Loading Muse-0.6b...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    cache_dir=MUSE_CACHE_DIR,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    cache_dir=MUSE_CACHE_DIR,
    # fp16 weights on GPU, fp32 on CPU.
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    # Automatic weight placement on GPU; a plain (CPU) load otherwise.
    device_map="auto" if DEVICE == "cuda" else None,
)
model.eval()  # inference only (disables dropout)
print("Muse loaded.")
# -----------------------------------------------------------------------------
# Load MuCodec + AudioLDM (VERY HEAVY)
# -----------------------------------------------------------------------------
print("Loading MuCodec / AudioLDM...")
# Put the MuCodec checkout itself on sys.path — presumably so MuCodec's own
# internal absolute imports resolve; confirm against the MuCodec sources.
sys.path.insert(0, "./MuCodec")
from MuCodec.model import PromptCondAudioDiffusion
from MuCodec.tools.get_melvaehifigan48k import build_pretrained_models

# Download AudioLDM (48 kHz) into a writable directory under /tmp.
audioldm_dir = snapshot_download(
    "haoheliu/audioldm_48k",
    local_dir="/tmp/audioldm",
    local_dir_use_symlinks=False,
)
audioldm_path = os.path.join(audioldm_dir, "audioldm_48k.pth")
vae, stft = build_pretrained_models(audioldm_path)
vae = vae.to(DEVICE).eval()

mucodec = PromptCondAudioDiffusion(
    num_channels=32,
    unet_model_name=None,
    unet_model_config_path="./MuCodec/configs/models/transformer2D.json",
    snr_gamma=None,
)
ckpt = torch.load("./MuCodec/ckpt/mucodec.pt", map_location="cpu")
# strict=False tolerates key mismatches, but a wholesale mismatch (e.g. a
# checkpoint wrapped in an extra dict level) would otherwise leave the model
# silently at its random initialisation. Report any mismatch instead.
load_result = mucodec.load_state_dict(ckpt, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: MuCodec checkpoint mismatch: "
        f"{len(load_result.missing_keys)} missing keys, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )
del ckpt  # weights are copied into the module; drop the CPU copy
mucodec = mucodec.to(DEVICE).eval()
print("MuCodec loaded.")
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
AUDIO_START_TAG = "<|audio_0|>"
AUDIO_END_TAG = "<|audio_1|>"


def extract_audio_tokens(text: str) -> "torch.Tensor | None":
    """Parse the integer audio codes enclosed by the audio start/end tags.

    The codes are the whitespace-separated decimal integers that follow the
    first ``<|audio_0|>`` and precede the next ``<|audio_1|>``. If the closing
    tag is missing (e.g. generation stopped at ``max_new_tokens``), everything
    up to the end of *text* is used. Pieces that are not decimal integers are
    skipped.

    Args:
        text: Decoded model output.

    Returns:
        An int64 tensor of shape ``(1, 1, n_codes)``, or ``None`` if the start
        tag is absent or no codes were found.
    """
    start = text.find(AUDIO_START_TAG)
    if start == -1:
        return None
    start += len(AUDIO_START_TAG)

    # Look for the closing tag only after the opening one. If it is absent,
    # find() returns -1, and slicing with that would chop off the last
    # character (corrupting the final code), so use the end of the string.
    end = text.find(AUDIO_END_TAG, start)
    if end == -1:
        end = len(text)

    # isdecimal() matches exactly what int() can parse; isdigit() is also true
    # for characters such as "²", which would make int() raise.
    tokens = [int(piece) for piece in text[start:end].split() if piece.isdecimal()]
    if not tokens:
        return None
    return torch.tensor(tokens).unsqueeze(0).unsqueeze(0)
# -----------------------------------------------------------------------------
# Generation
# -----------------------------------------------------------------------------
def generate(prompt):
    """Generate audio for *prompt*; return ``(wav_path, status_message)``.

    Steps:
      1. Greedy-decode up to 1024 new tokens with the Muse LM.
      2. Parse the audio codes with ``extract_audio_tokens``.
      3. Turn the codes into latents with MuCodec (10 diffusion steps).
      4. Decode through the AudioLDM VAE/vocoder.
      5. Write the result to a temporary ``.wav`` file.

    Returns ``(None, reason)`` when the prompt is empty/missing or no audio
    codes could be parsed from the model output.
    """
    if not prompt or not prompt.strip():
        return None, "Empty prompt"

    # With device_map="auto" the model chooses its own placement; send inputs
    # to the device holding its (first) parameters rather than assuming DEVICE.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
        )
    # The full sequence (prompt included) is decoded and searched for tags.
    # NOTE(review): presumably because the opening audio tag may come from the
    # prompt itself — confirm before decoding only the new tokens.
    text = tokenizer.decode(out[0], skip_special_tokens=False)

    codes = extract_audio_tokens(text)
    if codes is None:
        return None, "Failed to parse audio tokens"

    wav = _codes_to_waveform(codes.to(DEVICE))
    return _save_wav(wav), "Done"


def _codes_to_waveform(codes):
    """Decode a ``(1, 1, T)`` code tensor into audio via MuCodec and the AudioLDM VAE."""
    # Inference only: without no_grad, autograd would record the whole
    # diffusion loop and VAE decode, holding on to large intermediate buffers.
    with torch.no_grad():
        # Extremely reduced diffusion steps
        latents = mucodec.inference_codes(
            [codes[:, :, :1024]],
            torch.zeros([1, 32, 1, 32], device=DEVICE),
            torch.randn(1, 32, 512, 32, device=DEVICE),
            latent_length=512,
            first_latent_length=0,
            additional_feats=[],
            guidance_scale=1.0,
            num_steps=10,
            disable_progress=True,
            scenario="other_seg",
        )
        mel = vae.decode_first_stage(latents.float())
        return vae.decode_to_waveform(mel)


def _save_wav(wav):
    """Write *wav* to a fresh temporary ``.wav`` at SAMPLE_RATE and return its path."""
    # The vocoder's return type isn't visible here: accept a numpy array or a
    # tensor, move it to the CPU, and coerce it to the (channels, time) layout
    # that torchaudio.save expects.
    wav_t = torch.as_tensor(wav).detach().cpu()
    if wav_t.dim() == 1:
        wav_t = wav_t.unsqueeze(0)
    elif wav_t.dim() == 3:
        # presumably (batch, channels, time) with batch == 1 — keep the item
        wav_t = wav_t[0]
    # delete=False: the file must outlive this function so the UI can serve it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        out_path = f.name
    torchaudio.save(out_path, wav_t, SAMPLE_RATE)
    return out_path
# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Muse-0.6b (Experimental HF Space)")
    prompt = gr.Textbox(label="Prompt")
    btn = gr.Button("Generate")
    audio = gr.Audio(type="filepath")
    status = gr.Textbox()
    # generate() returns (wav_path or None, status message), one per output.
    btn.click(generate, inputs=prompt, outputs=[audio, status])

# Start the server only when executed as a script; importing this module just
# builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()