File size: 4,653 Bytes
2669951
 
 
c372f9b
2669951
 
 
 
 
 
 
 
c372f9b
2669951
 
 
c372f9b
2669951
 
 
c372f9b
2669951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c372f9b
2669951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c372f9b
0f2c357
2669951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c372f9b
2669951
 
0b7a658
2669951
c672f0c
2669951
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
# Minimal Muse-0.6b HF Space App (EXPERIMENTAL)

import os
import sys
import tempfile
import torch
import gradio as gr
import numpy as np
import torchaudio
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------------------------------------------------------------
# Basic config
# -----------------------------------------------------------------------------

MODEL_ID = "bolshyC/Muse-0.6b"  # HF Hub repo for the Muse language model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 48000  # output waveform sample rate (Hz); matches the audioldm_48k vocoder

# Force HF cache to writable dir
# NOTE(review): these env vars are set AFTER huggingface_hub/transformers were
# imported above, and both libraries resolve their cache locations at import
# time — so this likely has no effect. Confirm, and if so move these two lines
# above the imports (or export them in the Space settings instead).
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"

# -----------------------------------------------------------------------------
# Load Muse language model
# -----------------------------------------------------------------------------

print("Loading Muse-0.6b...")

# Tokenizer and causal-LM weights come from the same Hub repo; the repo ships
# custom modeling code, hence trust_remote_code on both loads.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

_on_gpu = DEVICE == "cuda"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    # fp16 + automatic device placement on GPU; full precision on CPU.
    torch_dtype=torch.float16 if _on_gpu else torch.float32,
    device_map="auto" if _on_gpu else None,
)
model.eval()  # inference only — disable dropout etc.
print("Muse loaded.")

# -----------------------------------------------------------------------------
# Load MuCodec + AudioLDM (VERY HEAVY)
# -----------------------------------------------------------------------------

print("Loading MuCodec / AudioLDM...")

# MuCodec is vendored as a sibling directory, not an installed package.
sys.path.insert(0, "./MuCodec")
from MuCodec.model import PromptCondAudioDiffusion
from MuCodec.tools.get_melvaehifigan48k import build_pretrained_models

# Download AudioLDM to /tmp
# NOTE(review): local_dir_use_symlinks is deprecated in recent huggingface_hub
# releases — confirm the version pinned for this Space still accepts it.
audioldm_dir = snapshot_download(
    "haoheliu/audioldm_48k",
    local_dir="/tmp/audioldm",
    local_dir_use_symlinks=False
)
audioldm_path = os.path.join(audioldm_dir, "audioldm_48k.pth")

# Mel VAE + vocoder pair used later for latent -> mel -> waveform decoding.
# (stft is returned by the builder but not used elsewhere in this file.)
vae, stft = build_pretrained_models(audioldm_path)
vae = vae.to(DEVICE).eval()

mucodec = PromptCondAudioDiffusion(
    num_channels=32,
    unet_model_name=None,
    unet_model_config_path="./MuCodec/configs/models/transformer2D.json",
    snr_gamma=None,
)

# NOTE(review): torch.load without weights_only unpickles arbitrary objects —
# acceptable for this bundled checkpoint, but never point it at untrusted files.
ckpt = torch.load("./MuCodec/ckpt/mucodec.pt", map_location="cpu")
# strict=False: tolerate missing/unexpected keys between checkpoint and module.
mucodec.load_state_dict(ckpt, strict=False)
mucodec = mucodec.to(DEVICE).eval()

print("MuCodec loaded.")

# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------

def extract_audio_tokens(text: str):
    """Parse the integer audio codes between ``<|audio_0|>`` and ``<|audio_1|>``.

    Returns a LongTensor of shape (1, 1, n_tokens), or None when the start
    marker is absent or no digit tokens are found between the markers.

    Fixes vs. the original: the end marker is searched only AFTER the start
    marker (a stray earlier occurrence can no longer yield an empty slice),
    and a missing end marker now reads to the end of the text instead of
    ``find`` returning -1 and ``text[start:-1]`` silently dropping the last
    character (which could truncate the final token).
    """
    start_marker = "<|audio_0|>"
    end_marker = "<|audio_1|>"

    start = text.find(start_marker)
    if start == -1:
        return None
    start += len(start_marker)

    end = text.find(end_marker, start)
    if end == -1:
        end = len(text)  # no end marker: take everything after the start marker

    tokens = [int(x) for x in text[start:end].split() if x.isdigit()]
    if not tokens:
        return None
    return torch.tensor(tokens).unsqueeze(0).unsqueeze(0)

# -----------------------------------------------------------------------------
# Generation
# -----------------------------------------------------------------------------

def generate(prompt):
    """Generate audio for *prompt*.

    Returns a ``(wav_path, status)`` tuple; ``wav_path`` is None on failure.

    Fix vs. the original: Gradio delivers None (not "") when the textbox was
    never touched, and ``None.strip()`` raised AttributeError — the guard now
    handles None as well as blank input.
    """
    if not prompt or not prompt.strip():
        return None, "Empty prompt"

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    # Greedy decoding; the audio token stream follows the prompt.
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False
        )

    # Keep special tokens — the <|audio_*|> markers are what we parse.
    text = tokenizer.decode(out[0], skip_special_tokens=False)
    codes = extract_audio_tokens(text)
    if codes is None:
        return None, "Failed to parse audio tokens"

    codes = codes.to(DEVICE)

    # Extremely reduced diffusion steps (num_steps=10) to keep latency down.
    latents = mucodec.inference_codes(
        [codes[:, :, :1024]],  # cap the code length fed to the decoder
        torch.zeros([1, 32, 1, 32], device=DEVICE),   # empty "first latent" conditioning
        torch.randn(1, 32, 512, 32, device=DEVICE),   # initial diffusion noise
        latent_length=512,
        first_latent_length=0,
        additional_feats=[],
        guidance_scale=1.0,
        num_steps=10,
        disable_progress=True,
        scenario="other_seg"
    )

    # latents -> mel spectrogram -> waveform via the AudioLDM VAE/vocoder.
    mel = vae.decode_first_stage(latents.float())
    wav = vae.decode_to_waveform(mel)

    # Reserve a unique temp path; Gradio serves the file from disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        out_path = f.name

    # NOTE(review): assumes wav is a (channels, samples) numpy array as
    # torchaudio.save requires — confirm decode_to_waveform's output shape.
    torchaudio.save(out_path, torch.from_numpy(wav), SAMPLE_RATE)
    return out_path, "Done"

# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------

# Minimal single-column layout: prompt in, audio + status out.
with gr.Blocks() as demo:
    gr.Markdown("# Muse-0.6b (Experimental HF Space)")
    prompt_box = gr.Textbox(label="Prompt")
    generate_btn = gr.Button("Generate")
    audio_out = gr.Audio(type="filepath")
    status_box = gr.Textbox()

    generate_btn.click(fn=generate, inputs=prompt_box, outputs=[audio_out, status_box])

demo.launch()