import torch
import gradio as gr
import numpy as np
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "m-a-p/YuE-s1-0.5B"

# Load tokenizer (the slow tokenizer is REQUIRED for this model).
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_fast=False,
)

# Load model in float16 (a GPU is REQUIRED); device_map="auto" lets
# transformers/accelerate place the weights.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model.eval()


# ----------------------------
# SIMPLE AUDIO TOKEN DECODER
# ----------------------------
# NOTE:
# YuE uses xcodec tokens.
# This is a *placeholder decoder*: it does not interpret the tokens at all.
# YuE's official xcodec decoder is required for real output.
def fake_xcodec_decode(token_ids, sample_rate=44100, tokens_per_second=50,
                       frequency=440.0):
    """Turn a token sequence into a placeholder sine tone.

    Only the number of tokens is used: the result is a 0.1-amplitude sine
    wave whose length is ``len(token_ids) / tokens_per_second`` seconds.
    Replace this with YuE's official xcodec decoder for real audio.

    Args:
        token_ids: Sequence of generated token ids (only its length is read).
        sample_rate: Output sample rate in Hz.
        tokens_per_second: How many tokens correspond to one second of audio.
        frequency: Pitch of the placeholder tone in Hz.

    Returns:
        ``(audio, sample_rate)`` where ``audio`` is a 1-D float32 array.
    """
    # Keep fractional seconds so fewer than `tokens_per_second` tokens still
    # yield audio (integer-second truncation produced an empty array).
    num_samples = int(len(token_ids) * sample_rate // tokens_per_second)
    # Exact 1/sample_rate spacing; linspace(endpoint=True) was slightly off.
    t = np.arange(num_samples, dtype=np.float64) / sample_rate
    audio = 0.1 * np.sin(2 * np.pi * frequency * t)
    return audio.astype(np.float32), sample_rate


def generate_music(prompt, max_tokens=2048, temperature=1.0):
    """Generate (placeholder) audio for a text prompt.

    Args:
        prompt: Free-text description of the desired music.
        max_tokens: Upper bound on newly generated tokens. Gradio sliders may
            deliver this as a float, so it is coerced to ``int``.
        temperature: Sampling temperature.

    Returns:
        ``(sample_rate, audio)`` as expected by ``gr.Audio(type="numpy")``,
        or ``None`` when the prompt is empty/missing or no new tokens were
        generated.
    """
    if not prompt or not prompt.strip():
        return None

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Pass a pad id explicitly; fall back to EOS when the tokenizer has none.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_id,
        )

    # For decoder-only models generate() returns prompt + continuation;
    # keep only the newly generated tokens so the prompt doesn't pad the audio.
    token_ids = output[0, prompt_len:].cpu().numpy()

    audio, sr = fake_xcodec_decode(token_ids)
    if audio.size == 0:
        return None
    return (sr, audio)


# ----------------------------
# GRADIO UI
# ----------------------------
with gr.Blocks(title="YuE Music Generator") as demo:
    gr.Markdown(
        """
        # 🎵 YuE Song Generator
        Text → AI Music

        **Model:** m-a-p/YuE-s1-0.5B

        ⚠️ Requires GPU
        """
    )

    prompt = gr.Textbox(
        label="Music Prompt",
        placeholder="A sad lo-fi song with piano and rain ambience",
        lines=3,
    )

    max_tokens = gr.Slider(
        minimum=512, maximum=4096, value=2048, step=256, label="Max Tokens"
    )
    temperature = gr.Slider(
        minimum=0.7, maximum=1.5, value=1.0, step=0.1, label="Creativity"
    )

    btn = gr.Button("Generate Music")
    audio_out = gr.Audio(label="Generated Music", type="numpy")

    btn.click(
        generate_music,
        inputs=[prompt, max_tokens, temperature],
        outputs=audio_out,
    )

demo.launch()