Jacong committed on
Commit
2669951
·
verified ·
1 Parent(s): dccd291

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -45
app.py CHANGED
@@ -1,55 +1,153 @@
1
- import gradio as gr
2
- import subprocess
 
3
  import os
4
- import uuid
5
-
6
- OUTPUT_DIR = "outputs"
7
- os.makedirs(OUTPUT_DIR, exist_ok=True)
8
-
9
- def generate_music(prompt):
10
- uid = str(uuid.uuid4())[:8]
11
- out_wav = os.path.join(OUTPUT_DIR, f"{uid}.wav")
12
-
13
- cmd = [
14
- "python",
15
- "scripts/infer.py",
16
- "--model", "bolshyC/Muse-0.6b",
17
- "--prompt", prompt,
18
- "--out", out_wav
19
- ]
20
-
21
- try:
22
- subprocess.run(
23
- cmd,
24
- check=True,
25
- timeout=600
26
- )
27
- except Exception as e:
28
- return None, f"็”Ÿๆˆๅคฑ่ดฅ๏ผš{e}"
29
 
30
- if not os.path.exists(out_wav):
31
- return None, "็”Ÿๆˆๅคฑ่ดฅ๏ผšๆœชๆ‰พๅˆฐ้Ÿณ้ข‘ๆ–‡ไปถ"
 
32
 
33
- return out_wav, "็”ŸๆˆๅฎŒๆˆ"
 
 
34
 
35
- with gr.Blocks() as demo:
36
- gr.Markdown("# 🎵 Muse Music Generator (HF Space)")
37
- gr.Markdown("使用 Muse-0.6b 在 HF Space 上生成音乐（CLI 接入）")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- prompt = gr.Textbox(
40
- label="Prompt",
41
- placeholder="A calm piano melody, slow tempo"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  btn = gr.Button("Generate")
45
- audio = gr.Audio(label="Output", type="filepath")
46
- status = gr.Textbox(label="Status")
47
 
48
- btn.click(
49
- fn=generate_music,
50
- inputs=prompt,
51
- outputs=[audio, status]
52
- )
53
 
54
- if __name__ == "__main__":
55
- demo.launch()
 
1
+ #!/usr/bin/env python3
2
+ # Minimal Muse-0.6b HF Space App (EXPERIMENTAL)
3
+
4
  import os
5
+ import sys
6
+ import tempfile
7
+ import torch
8
+ import gradio as gr
9
+ import numpy as np
10
+ import torchaudio
11
+ from huggingface_hub import snapshot_download
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # -----------------------------------------------------------------------------
15
+ # Basic config
16
+ # -----------------------------------------------------------------------------
17
 
18
+ MODEL_ID = "bolshyC/Muse-0.6b"
19
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
+ SAMPLE_RATE = 48000
21
 
22
+ # Force HF cache to writable dir
23
+ os.environ["HF_HOME"] = "/tmp/hf"
24
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
25
+
26
+ # -----------------------------------------------------------------------------
27
+ # Load Muse language model
28
+ # -----------------------------------------------------------------------------
29
+
30
+ print("Loading Muse-0.6b...")
31
+ tokenizer = AutoTokenizer.from_pretrained(
32
+ MODEL_ID,
33
+ trust_remote_code=True
34
+ )
35
+
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ MODEL_ID,
38
+ trust_remote_code=True,
39
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
40
+ device_map="auto" if DEVICE == "cuda" else None,
41
+ )
42
+ model.eval()
43
+ print("Muse loaded.")
44
+
45
+ # -----------------------------------------------------------------------------
46
+ # Load MuCodec + AudioLDM (VERY HEAVY)
47
+ # -----------------------------------------------------------------------------
48
+
49
+ print("Loading MuCodec / AudioLDM...")
50
+
51
+ sys.path.insert(0, "./MuCodec")
52
+ from MuCodec.model import PromptCondAudioDiffusion
53
+ from MuCodec.tools.get_melvaehifigan48k import build_pretrained_models
54
+
55
+ # Download AudioLDM to /tmp
56
+ audioldm_dir = snapshot_download(
57
+ "haoheliu/audioldm_48k",
58
+ local_dir="/tmp/audioldm",
59
+ local_dir_use_symlinks=False
60
+ )
61
+ audioldm_path = os.path.join(audioldm_dir, "audioldm_48k.pth")
62
+
63
+ vae, stft = build_pretrained_models(audioldm_path)
64
+ vae = vae.to(DEVICE).eval()
65
+
66
+ mucodec = PromptCondAudioDiffusion(
67
+ num_channels=32,
68
+ unet_model_name=None,
69
+ unet_model_config_path="./MuCodec/configs/models/transformer2D.json",
70
+ snr_gamma=None,
71
+ )
72
+
73
+ ckpt = torch.load("./MuCodec/ckpt/mucodec.pt", map_location="cpu")
74
+ mucodec.load_state_dict(ckpt, strict=False)
75
+ mucodec = mucodec.to(DEVICE).eval()
76
+
77
+ print("MuCodec loaded.")
78
 
79
+ # -----------------------------------------------------------------------------
80
+ # Helpers
81
+ # -----------------------------------------------------------------------------
82
+
83
+ def extract_audio_tokens(text: str):
84
+ if "<|audio_0|>" not in text:
85
+ return None
86
+ start = text.find("<|audio_0|>") + len("<|audio_0|>")
87
+ end = text.find("<|audio_1|>")
88
+ tokens = [int(x) for x in text[start:end].split() if x.isdigit()]
89
+ if not tokens:
90
+ return None
91
+ return torch.tensor(tokens).unsqueeze(0).unsqueeze(0)
92
+
93
+ # -----------------------------------------------------------------------------
94
+ # Generation
95
+ # -----------------------------------------------------------------------------
96
+
97
+ def generate(prompt):
98
+ if not prompt.strip():
99
+ return None, "Empty prompt"
100
+
101
+ inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
102
+
103
+ with torch.no_grad():
104
+ out = model.generate(
105
+ **inputs,
106
+ max_new_tokens=1024,
107
+ do_sample=False
108
+ )
109
+
110
+ text = tokenizer.decode(out[0], skip_special_tokens=False)
111
+ codes = extract_audio_tokens(text)
112
+ if codes is None:
113
+ return None, "Failed to parse audio tokens"
114
+
115
+ codes = codes.to(DEVICE)
116
+
117
+ # Extremely reduced diffusion steps
118
+ latents = mucodec.inference_codes(
119
+ [codes[:, :, :1024]],
120
+ torch.zeros([1, 32, 1, 32], device=DEVICE),
121
+ torch.randn(1, 32, 512, 32, device=DEVICE),
122
+ latent_length=512,
123
+ first_latent_length=0,
124
+ additional_feats=[],
125
+ guidance_scale=1.0,
126
+ num_steps=10,
127
+ disable_progress=True,
128
+ scenario="other_seg"
129
  )
130
 
131
+ mel = vae.decode_first_stage(latents.float())
132
+ wav = vae.decode_to_waveform(mel)
133
+
134
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
135
+ out_path = f.name
136
+
137
+ torchaudio.save(out_path, torch.from_numpy(wav), SAMPLE_RATE)
138
+ return out_path, "Done"
139
+
140
+ # -----------------------------------------------------------------------------
141
+ # UI
142
+ # -----------------------------------------------------------------------------
143
+
144
+ with gr.Blocks() as demo:
145
+ gr.Markdown("# Muse-0.6b (Experimental HF Space)")
146
+ prompt = gr.Textbox(label="Prompt")
147
  btn = gr.Button("Generate")
148
+ audio = gr.Audio(type="filepath")
149
+ status = gr.Textbox()
150
 
151
+ btn.click(generate, prompt, [audio, status])
 
 
 
 
152
 
153
+ demo.launch()