#!/usr/bin/env python3
# Minimal Muse-0.6b HF Space App (EXPERIMENTAL)
import os
import sys
import tempfile
import torch
import gradio as gr
import numpy as np
import torchaudio
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
# -----------------------------------------------------------------------------
# Basic config
# -----------------------------------------------------------------------------
MODEL_ID = "bolshyC/Muse-0.6b"  # HF Hub repo id of the Muse language model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 48000  # output rate (Hz); matches the 48 kHz AudioLDM checkpoint below
# Force HF cache to writable dir
# NOTE(review): these are set *after* huggingface_hub/transformers were imported
# (L13-L14); some transformers versions read TRANSFORMERS_CACHE at import time,
# so the redirect may not take effect — confirm, or set them before the imports.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
# -----------------------------------------------------------------------------
# Load Muse language model
# -----------------------------------------------------------------------------
print("Loading Muse-0.6b...")
# trust_remote_code: the repo ships custom tokenizer/model classes.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True
)
# fp16 + device_map="auto" on GPU; plain fp32 on CPU (fp16 kernels are slow or
# unsupported on most CPUs, and device_map needs accelerate's GPU dispatch).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto" if DEVICE == "cuda" else None,
)
model.eval()  # inference mode: disables dropout etc.
print("Muse loaded.")
# -----------------------------------------------------------------------------
# Load MuCodec + AudioLDM (VERY HEAVY)
# -----------------------------------------------------------------------------
print("Loading MuCodec / AudioLDM...")
# NOTE(review): inserting "./MuCodec" on sys.path while importing "MuCodec.*"
# only resolves if the checkout nests MuCodec/MuCodec/ — confirm repo layout.
sys.path.insert(0, "./MuCodec")
from MuCodec.model import PromptCondAudioDiffusion
from MuCodec.tools.get_melvaehifigan48k import build_pretrained_models
# Download AudioLDM to /tmp
# NOTE(review): local_dir_use_symlinks is deprecated/ignored in recent
# huggingface_hub releases; harmless, but can be dropped.
audioldm_dir = snapshot_download(
    "haoheliu/audioldm_48k",
    local_dir="/tmp/audioldm",
    local_dir_use_symlinks=False
)
audioldm_path = os.path.join(audioldm_dir, "audioldm_48k.pth")
# Presumably a mel-VAE + HiFi-GAN vocoder pair and an STFT helper (module name
# suggests 48 kHz) — stft is currently unused below.
vae, stft = build_pretrained_models(audioldm_path)
vae = vae.to(DEVICE).eval()
mucodec = PromptCondAudioDiffusion(
    num_channels=32,
    unet_model_name=None,
    unet_model_config_path="./MuCodec/configs/models/transformer2D.json",
    snr_gamma=None,
)
ckpt = torch.load("./MuCodec/ckpt/mucodec.pt", map_location="cpu")
# strict=False tolerates missing/unexpected keys: a wrong or partial checkpoint
# would load without any error — watch for silently degraded output.
mucodec.load_state_dict(ckpt, strict=False)
mucodec = mucodec.to(DEVICE).eval()
print("MuCodec loaded.")
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def extract_audio_tokens(text: str):
    """Parse the audio-token span between <|audio_0|> and <|audio_1|>.

    Args:
        text: decoded LM output (special tokens kept).

    Returns:
        LongTensor of shape (1, 1, n_tokens), or None when no start marker
        or no numeric tokens are present.
    """
    start_marker = "<|audio_0|>"
    end_marker = "<|audio_1|>"
    if start_marker not in text:
        return None
    start = text.find(start_marker) + len(start_marker)
    # Search for the end marker *after* the start marker. If generation was
    # truncated before emitting it, fall back to the end of the text: the
    # original used a bare find() whose -1 made the slice text[start:-1],
    # silently corrupting the final token.
    end = text.find(end_marker, start)
    if end == -1:
        end = len(text)
    tokens = [int(x) for x in text[start:end].split() if x.isdigit()]
    if not tokens:
        return None
    return torch.tensor(tokens).unsqueeze(0).unsqueeze(0)
# -----------------------------------------------------------------------------
# Generation
# -----------------------------------------------------------------------------
def generate(prompt):
    """Prompt -> Muse audio tokens -> MuCodec diffusion -> 48 kHz wav file.

    Args:
        prompt: user text from the Gradio textbox (may be None or blank).

    Returns:
        (path, status): path to a temp .wav (None on failure) and a
        human-readable status string.
    """
    # Gradio delivers None for an untouched textbox; guard before .strip()
    # to avoid an AttributeError.
    if not prompt or not prompt.strip():
        return None, "Empty prompt"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False
        )
    text = tokenizer.decode(out[0], skip_special_tokens=False)
    codes = extract_audio_tokens(text)
    if codes is None:
        return None, "Failed to parse audio tokens"
    codes = codes.to(DEVICE)
    # Extremely reduced diffusion steps (num_steps=10) to keep the Space usable.
    # no_grad: decoding is inference-only; without it the diffusion + VAE decode
    # builds an autograd graph and wastes a large amount of memory.
    with torch.no_grad():
        latents = mucodec.inference_codes(
            [codes[:, :, :1024]],  # cap code length to the model window
            torch.zeros([1, 32, 1, 32], device=DEVICE),
            torch.randn(1, 32, 512, 32, device=DEVICE),
            latent_length=512,
            first_latent_length=0,
            additional_feats=[],
            guidance_scale=1.0,
            num_steps=10,
            disable_progress=True,
            scenario="other_seg"
        )
        mel = vae.decode_first_stage(latents.float())
        wav = vae.decode_to_waveform(mel)
    # decode_to_waveform presumably yields a (channels, samples) numpy array —
    # TODO confirm; torchaudio.save requires exactly that layout.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        out_path = f.name
    torchaudio.save(out_path, torch.from_numpy(wav), SAMPLE_RATE)
    return out_path, "Done"
# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------
# Minimal single-column UI: one textbox in, one audio player + status line out.
with gr.Blocks() as demo:
    gr.Markdown("# Muse-0.6b (Experimental HF Space)")
    prompt_box = gr.Textbox(label="Prompt")
    generate_btn = gr.Button("Generate")
    audio_out = gr.Audio(type="filepath")
    status_box = gr.Textbox()
    generate_btn.click(
        fn=generate,
        inputs=prompt_box,
        outputs=[audio_out, status_box],
    )
demo.launch()