mrfakename committed on
Commit
4387a7f
·
verified ·
1 Parent(s): 27c6cbc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +361 -0
app.py CHANGED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ HuggingFace Space app for Muse-8b music generation
4
+ Text input -> Audio output
5
+ """
6
+
7
+ import spaces
8
+ import gradio as gr
9
+ import os
10
+ import sys
11
+ import tempfile
12
+ from typing import Optional, Tuple
13
+ import torch
14
+ import numpy as np
15
+ import torchaudio
16
+
17
+ # Add MuCodec to path
18
+ sys.path.insert(0, "./MuCodec")
19
+
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer
21
+ from MuCodec.model import PromptCondAudioDiffusion
22
+ from MuCodec.tools.get_melvaehifigan48k import build_pretrained_models
23
+ import MuCodec.tools.torch_tools as torch_tools
24
+
25
# Constants
MODEL_NAME = "bolshyC/Muse-8b"
SAMPLE_RATE = 48000

# ============================================================================
# Model Loading at Module Level
# ============================================================================
# Everything below runs once at import time so that the request handlers can
# reuse the already-initialised models.

print("Loading Muse language model...")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load language model: half precision + automatic device placement on GPU,
# full precision with an explicit move on CPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
language_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if device == "cuda" else None,
)
if device == "cpu":
    language_model = language_model.to(device)
language_model.eval()
print("Language model loaded!")

# Load MuCodec decoder
print("Loading MuCodec decoder...")
mucodec_dir = "./MuCodec"
ckpt_path = os.path.join(mucodec_dir, "ckpt/mucodec.pt")
audioldm_path = os.path.join(mucodec_dir, "tools/audioldm_48k.pth")
config_path = os.path.join(mucodec_dir, "configs/models/transformer2D.json")

# Load VAE and STFT
vae, stft = build_pretrained_models(audioldm_path)
vae = vae.eval().to(device)
stft = stft.eval().to(device)

# Load diffusion model
main_config = {
    "num_channels": 32,
    "unet_model_name": None,
    "unet_model_config_path": config_path,
    "snr_gamma": None,
}
mucodec_model = PromptCondAudioDiffusion(**main_config)
# NOTE(review): torch.load unpickles the checkpoint; only load files from a
# trusted source. Consider weights_only=True once the checkpoint is confirmed
# to contain nothing but tensors.
main_weights = torch.load(ckpt_path, map_location='cpu')
# strict=False tolerates key mismatches, so surface them instead of silently
# running with partially initialised weights.
load_result = mucodec_model.load_state_dict(main_weights, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} MuCodec weights missing from checkpoint")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys in MuCodec checkpoint")
mucodec_model = mucodec_model.to(device).eval()
mucodec_model.init_device_dtype(torch.device(device), torch.float32)
print("MuCodec decoder loaded!")
74
+
75
+ # ============================================================================
76
+ # Helper Functions
77
+ # ============================================================================
78
+
79
def parse_tokens_from_text(text: str) -> Optional[torch.Tensor]:
    """Extract audio tokens from generated text.

    If the text contains ``<|audio_0|>``, tokens are read from just after that
    marker up to the next ``<|audio_1|>`` (or to the end of the text when the
    closing marker is absent, e.g. because generation was cut off). Without a
    start marker the whole text is scanned. Whitespace-separated items that
    are not plain decimal integers are ignored.

    Args:
        text: Decoded model output.

    Returns:
        A ``torch.long`` tensor of shape ``(1, 1, num_tokens)``, or ``None``
        when no integer token could be extracted.
    """
    start_marker = "<|audio_0|>"
    end_marker = "<|audio_1|>"

    if not isinstance(text, str) or not text:
        return None

    start = text.find(start_marker)
    if start == -1:
        token_str = text
    else:
        start += len(start_marker)
        # Look for the closing marker only *after* the opening one so that a
        # stray earlier occurrence cannot produce an empty slice.
        end = text.find(end_marker, start)
        token_str = text[start:] if end == -1 else text[start:end]

    # isdecimal() (unlike isdigit()) only accepts characters int() can parse,
    # so a stray "²" is skipped instead of aborting the whole parse.
    tokens = [int(t) for t in token_str.split() if t.isdecimal()]

    if not tokens:
        return None

    try:
        return torch.tensor(tokens, dtype=torch.long).unsqueeze(0).unsqueeze(0)
    except (OverflowError, RuntimeError) as e:
        # A value too large for int64 cannot be a valid code.
        print(f"Error parsing tokens: {e}")
        return None
99
+
100
+
101
def codes_to_audio(
    codes: torch.Tensor,
    num_steps: int = 20
) -> torch.Tensor:
    """Convert audio codes to a waveform using MuCodec.

    The codes are turned into latents by the diffusion model over overlapping
    windows of 1024 codes (hop 768). For every window after the first, the
    tail of the previous window's latent is passed in as in-context latent.
    Each latent is then decoded by the VAE and consecutive audio segments are
    cross-faded together.

    Args:
        codes: Code tensor indexed as ``codes[:, :, time]`` (rank 3).
        num_steps: Number of diffusion steps passed to ``inference_codes``.

    Returns:
        Waveform tensor trimmed along its last dimension to
        ``int(codes_len / 100 * 4 * SAMPLE_RATE)`` samples, where
        ``codes_len`` is the length of ``codes`` before any padding.

    Raises:
        ValueError: If ``codes`` is empty along its last dimension (the
            pad-by-doubling loop below could never terminate).
    """
    if codes.shape[-1] == 0:
        raise ValueError("codes must contain at least one audio token")

    codes = codes.to(device)

    # Initialize latent
    first_latent = torch.randn(codes.shape[0], 32, 512, 32).to(device)
    first_latent_length = 0

    # Sliding window parameters (in codes)
    min_samples = 1024
    hop_samples = min_samples // 4 * 3
    ovlp_samples = min_samples - hop_samples

    codes_len = codes.shape[-1]
    # Computed before padding so the output matches the real input length.
    # NOTE(review): the /100*4 factor presumably encodes the code frame rate —
    # TODO confirm against MuCodec.
    target_len = int(codes_len / 100 * 4 * SAMPLE_RATE)

    # Pad codes if too short: repeat the sequence until one full window fits.
    if codes_len < min_samples:
        while codes.shape[-1] < min_samples:
            codes = torch.cat([codes, codes], -1)
        codes = codes[:, :, :min_samples]
    codes_len = codes.shape[-1]

    # Adjust codes length for sliding window: extend (by repetition) to a
    # whole number of hops plus one overlap.
    if (codes_len - ovlp_samples) % hop_samples > 0:
        len_codes = int(np.ceil((codes_len - ovlp_samples) / hop_samples) * hop_samples + ovlp_samples)
        while codes.shape[-1] < len_codes:
            codes = torch.cat([codes, codes], -1)
        codes = codes[:, :, :len_codes]

    # Generate latents with sliding window
    latent_length = 512
    latent_list = []
    spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)

    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad(), torch.autocast(
        device_type="cuda" if torch.cuda.is_available() else "cpu",
        dtype=torch.float16,
    ):
        for sinx in range(0, codes.shape[-1] - hop_samples, hop_samples):
            codes_input = [codes[:, :, sinx:sinx + min_samples]]

            if sinx == 0:
                latents = mucodec_model.inference_codes(
                    codes_input, spk_embeds, first_latent,
                    latent_length, first_latent_length,
                    additional_feats=[], guidance_scale=1.5,
                    num_steps=num_steps, disable_progress=True,
                    scenario='other_seg'
                )
            else:
                # Seed this window with the tail of the previous latent and
                # fill the remainder with noise up to 512 frames.
                true_latent = latent_list[-1][:, :, -ovlp_samples // 2:, :]
                len_add = 512 - true_latent.shape[-2]
                incontext_length = true_latent.shape[-2]
                true_latent = torch.cat([
                    true_latent,
                    torch.randn(true_latent.shape[0], true_latent.shape[1],
                                len_add, true_latent.shape[-1]).to(device)
                ], -2)

                latents = mucodec_model.inference_codes(
                    codes_input, spk_embeds, true_latent,
                    latent_length, incontext_length,
                    additional_feats=[], guidance_scale=1.5,
                    num_steps=num_steps, disable_progress=True,
                    scenario='other_seg'
                )

            latent_list.append(latents)

    # Decode latents to audio
    latent_list = [latent.float() for latent in latent_list]
    duration = 40.96
    min_samples_audio = int(duration * SAMPLE_RATE)
    hop_samples_audio = min_samples_audio // 4 * 3
    ovlp_samples_audio = min_samples_audio - hop_samples_audio

    # Crossfade window, built once: first half ramps 0 -> 1 (fade-in for the
    # new segment), second half ramps 1 -> 0 (fade-out for the old one).
    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples_audio)[None, :])
    ov_win = torch.cat([ov_win, 1 - ov_win], -1)

    output = None
    with torch.no_grad():
        for latent in latent_list:
            bsz, ch, t, f = latent.shape
            # NOTE(review): presumably splits the channel axis into two audio
            # channels for the VAE — TODO confirm against MuCodec.
            latent = latent.reshape(bsz * 2, ch // 2, t, f)
            mel = vae.decode_first_stage(latent)
            cur_output = vae.decode_to_waveform(mel)
            cur_output = torch.from_numpy(cur_output)[:, :min_samples_audio]

            if output is None:
                output = cur_output
            else:
                # Overlap-add smoothing
                output[:, -ovlp_samples_audio:] = (
                    output[:, -ovlp_samples_audio:] * ov_win[:, -ovlp_samples_audio:] +
                    cur_output[:, :ovlp_samples_audio] * ov_win[:, :ovlp_samples_audio]
                )
                output = torch.cat([output, cur_output[:, ovlp_samples_audio:]], -1)

    # Trim to target length
    output = output[:, :target_len]
    return output
203
+
204
+
205
+ # ============================================================================
206
+ # Main Generation Function with @spaces.GPU
207
+ # ============================================================================
208
+
209
@spaces.GPU
def generate_music(
    prompt: str,
    max_tokens: int = 3000,
    temperature: float = 0.0,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    num_diffusion_steps: int = 20,
) -> Tuple[Optional[str], str]:
    """Generate music from a text prompt.

    The prompt is wrapped in the tokenizer's chat template, the language model
    generates a response, audio codes are parsed out of that response and
    decoded to a waveform, which is written to a temporary ``.wav`` file.

    Args:
        prompt: User description of the desired music.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus sampling threshold passed to ``generate``.
        repetition_penalty: Repetition penalty passed to ``generate``.
        num_diffusion_steps: Diffusion steps used by ``codes_to_audio``.

    Returns:
        ``(path, status)`` where ``path`` is the generated ``.wav`` file, or
        ``None`` on failure, and ``status`` is a human-readable message.
        Exceptions are caught and reported through ``status``.
    """

    if not prompt or not prompt.strip():
        return None, "Please enter a prompt"

    try:
        # Generate tokens
        messages = [{"role": "user", "content": prompt}]
        prompt_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        inputs = tokenizer(prompt_text, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Token id 0 is a valid pad id, so test against None rather than
        # relying on truthiness.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        generation_config = {
            # UI sliders may deliver floats; generate() needs an int here.
            "max_new_tokens": int(max_tokens),
            "temperature": temperature if temperature > 0 else 1.0,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": temperature > 0,
            "pad_token_id": pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }

        with torch.no_grad():
            outputs = language_model.generate(**inputs, **generation_config)

        # Keep only the newly generated part (drop the echoed prompt).
        input_length = inputs["input_ids"].shape[1]
        generated_tokens = outputs[0][input_length:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=False)

        # Parse tokens
        audio_codes = parse_tokens_from_text(response)
        if audio_codes is None:
            return None, "❌ Could not parse audio tokens from model output"

        print(f"Parsed {audio_codes.shape[-1]} audio tokens")

        # Decode to audio
        waveform = codes_to_audio(audio_codes, num_steps=int(num_diffusion_steps))

        # Save audio file: reserve a path first, then let torchaudio write it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            output_path = f.name

        try:
            torchaudio.save(output_path, waveform.cpu(), SAMPLE_RATE)
        except Exception:
            # Do not leave an orphaned temp file behind when saving fails.
            if os.path.exists(output_path):
                os.remove(output_path)
            raise

        duration = waveform.shape[-1] / SAMPLE_RATE
        return output_path, f"✓ Generated {duration:.1f}s audio ({audio_codes.shape[-1]} tokens)"

    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        import traceback
        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, error_msg
274
+
275
+
276
+ # ============================================================================
277
+ # Gradio Interface
278
+ # ============================================================================
279
+
280
# Prompts offered as one-click examples underneath the controls.
_example_prompts = [
    "Please generate a song in style: Pop, Ballad, C-pop. Create an emotional love song with piano accompaniment.",
    "Generate an upbeat electronic dance music track with strong bass and synth leads.",
    "Create a classical orchestral piece with strings and woodwinds, peaceful and serene.",
    "Make a jazz fusion track with saxophone and electric guitar solos.",
]

with gr.Blocks(title="Muse-8b Music Generator") as demo:
    # Page header
    gr.Markdown(
        """
        # 🎵 Muse-8b Music Generator

        Generate music directly from text prompts using Muse-8b + MuCodec.
        """
    )

    with gr.Row():
        # Wide column: prompt entry, trigger button and results.
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                lines=5,
                label="Music Prompt",
                placeholder="Describe the music you want to generate...\n\nExample: Please generate a song in style: Pop, Ballad, C-pop. Create an emotional love song with piano accompaniment.",
            )

            generate_btn = gr.Button("🎵 Generate Music", size="lg", variant="primary")

            status_output = gr.Textbox(lines=2, label="Status")
            audio_output = gr.Audio(type="filepath", label="Generated Music")

        # Narrow column: generation hyper-parameters.
        with gr.Column(scale=1):
            gr.Markdown("### Generation Settings")

            max_tokens_slider = gr.Slider(
                label="Max Tokens",
                minimum=500, maximum=5000, step=100, value=3000,
            )
            temperature_slider = gr.Slider(
                label="Temperature (0 = deterministic)",
                minimum=0.0, maximum=1.0, step=0.1, value=0.0,
            )
            top_p_slider = gr.Slider(
                label="Top P",
                minimum=0.0, maximum=1.0, step=0.05, value=0.9,
            )
            rep_penalty_slider = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0, maximum=2.0, step=0.05, value=1.1,
            )
            diffusion_steps_slider = gr.Slider(
                label="Diffusion Steps (quality vs speed)",
                minimum=10, maximum=50, step=5, value=20,
            )

    # Each example row holds a single value: the prompt text.
    gr.Examples(
        examples=[[example] for example in _example_prompts],
        inputs=prompt_input,
    )

    # Component order here must match generate_music's parameter order.
    _generation_inputs = [
        prompt_input,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        rep_penalty_slider,
        diffusion_steps_slider,
    ]
    generate_btn.click(
        fn=generate_music,
        inputs=_generation_inputs,
        outputs=[audio_output, status_output],
    )

    # Page footer
    gr.Markdown(
        """
        ---
        ### About

        **Model**: [bolshyC/Muse-8b](https://huggingface.co/bolshyC/Muse-8b)
        **Decoder**: MuCodec (Ultra Low-Bitrate Music Codec)

        First generation may take ~1-2 minutes. Subsequent generations are faster.
        """
    )

# Enable request queuing, then start the server.
demo.queue().launch()