#!/usr/bin/env python3
"""Minimal Muse-0.6b HF Space app (EXPERIMENTAL).

Loads the Muse-0.6b causal LM to emit discrete audio tokens, then decodes
them to a 48 kHz waveform via MuCodec + the AudioLDM VAE/vocoder.
"""
import os
import sys
import tempfile

# -----------------------------------------------------------------------------
# Basic config
# -----------------------------------------------------------------------------
# NOTE: the cache env vars MUST be set before importing transformers /
# huggingface_hub — both read them at import time, so setting them after the
# imports (as the previous version did) has no effect. HF Spaces containers
# only guarantee /tmp is writable.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"

import torch
import gradio as gr
import numpy as np
import torchaudio
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "bolshyC/Muse-0.6b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 48000  # AudioLDM-48k vocoder output rate

# -----------------------------------------------------------------------------
# Load Muse language model
# -----------------------------------------------------------------------------
print("Loading Muse-0.6b...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    # fp16 only makes sense on GPU; CPU inference stays in fp32.
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto" if DEVICE == "cuda" else None,
)
model.eval()
print("Muse loaded.")

# -----------------------------------------------------------------------------
# Load MuCodec + AudioLDM (VERY HEAVY)
# -----------------------------------------------------------------------------
print("Loading MuCodec / AudioLDM...")
sys.path.insert(0, "./MuCodec")
from MuCodec.model import PromptCondAudioDiffusion
from MuCodec.tools.get_melvaehifigan48k import build_pretrained_models

# Download the AudioLDM checkpoint into /tmp (writable on Spaces).
audioldm_dir = snapshot_download(
    "haoheliu/audioldm_48k",
    local_dir="/tmp/audioldm",
    local_dir_use_symlinks=False,
)
audioldm_path = os.path.join(audioldm_dir, "audioldm_48k.pth")

vae, stft = build_pretrained_models(audioldm_path)
vae = vae.to(DEVICE).eval()

mucodec = PromptCondAudioDiffusion(
    num_channels=32,
    unet_model_name=None,
    unet_model_config_path="./MuCodec/configs/models/transformer2D.json",
    snr_gamma=None,
)
# strict=False: checkpoint may carry extra/missing keys relative to this config.
ckpt = torch.load("./MuCodec/ckpt/mucodec.pt", map_location="cpu")
mucodec.load_state_dict(ckpt, strict=False)
mucodec = mucodec.to(DEVICE).eval()
print("MuCodec loaded.")


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def extract_audio_tokens(text: str):
    """Parse the integer audio codes delimited by <|audio_0|> ... <|audio_1|>.

    Returns a (1, 1, T) int64 tensor of codes, or None when no codes are found.

    Bug fixes vs. the previous version:
    - the end-marker search now starts after the start marker, so a stray
      earlier <|audio_1|> cannot yield an empty/negative slice;
    - a missing end marker now means "take the rest of the string" instead of
      text[start:-1], which silently dropped the final character.
    """
    if "<|audio_0|>" not in text:
        return None
    start = text.find("<|audio_0|>") + len("<|audio_0|>")
    end = text.find("<|audio_1|>", start)
    if end == -1:
        end = len(text)
    tokens = [int(x) for x in text[start:end].split() if x.isdigit()]
    if not tokens:
        return None
    return torch.tensor(tokens).unsqueeze(0).unsqueeze(0)


# -----------------------------------------------------------------------------
# Generation
# -----------------------------------------------------------------------------
def generate(prompt):
    """Generate audio from a text prompt.

    Returns (wav_file_path, status_message); the path is None on failure.
    """
    # Guard against None as well as empty/whitespace — gradio can pass None,
    # and the previous version crashed on None.strip().
    if not prompt or not prompt.strip():
        return None, "Empty prompt"

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
    text = tokenizer.decode(out[0], skip_special_tokens=False)

    codes = extract_audio_tokens(text)
    if codes is None:
        return None, "Failed to parse audio tokens"
    codes = codes.to(DEVICE)

    # Extremely reduced diffusion steps (num_steps=10) to keep latency down.
    latents = mucodec.inference_codes(
        [codes[:, :, :1024]],
        torch.zeros([1, 32, 1, 32], device=DEVICE),
        torch.randn(1, 32, 512, 32, device=DEVICE),
        latent_length=512,
        first_latent_length=0,
        additional_feats=[],
        guidance_scale=1.0,
        num_steps=10,
        disable_progress=True,
        scenario="other_seg",
    )
    mel = vae.decode_first_stage(latents.float())
    # NOTE(review): assumes decode_to_waveform returns a numpy array — confirm
    # against the MuCodec VAE implementation before changing.
    wav = vae.decode_to_waveform(mel)

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        out_path = f.name
    torchaudio.save(out_path, torch.from_numpy(wav), SAMPLE_RATE)
    return out_path, "Done"


# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------
# Gradio front-end: one prompt box wired to generate(), which returns an
# audio file path plus a status string.
with gr.Blocks() as demo:
    gr.Markdown("# Muse-0.6b (Experimental HF Space)")

    prompt_box = gr.Textbox(label="Prompt")
    generate_btn = gr.Button("Generate")
    audio_out = gr.Audio(type="filepath")
    status_out = gr.Textbox()

    generate_btn.click(generate, prompt_box, [audio_out, status_out])

demo.launch()