# YuE / app.py
# Uploaded by um41r -- commit c44ace0 (verified): "Update app.py"
import torch
import gradio as gr
import numpy as np
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face Hub identifier of the checkpoint shared by tokenizer and model.
MODEL_ID = "m-a-p/YuE-s1-0.5B"

# The slow tokenizer implementation is REQUIRED, hence use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=False,
    trust_remote_code=True,
)

# A GPU is REQUIRED: weights are loaded in float16 and device placement is
# delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

# The model is only used for generation, so put it in evaluation mode.
model.eval()
# ----------------------------
# SIMPLE AUDIO TOKEN DECODER
# ----------------------------
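# NOTE:
# YuE uses xcodec tokens.
# This is a *placeholder decoder*: it ignores the token values and only
# synthesizes a 440 Hz sine tone whose length depends on the token count.
# The official xcodec decoder is required to produce real audio.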
def fake_xcodec_decode(token_ids, sample_rate=44100, tokens_per_second=50,
                       frequency=440.0):
    """Return a placeholder sine tone whose length matches the token count.

    This is a minimal stand-in decoder: the token *values* are ignored and
    only ``len(token_ids)`` is used to derive the clip duration.
    For REAL quality, you must use YuE's official xcodec decoder.

    Args:
        token_ids: Sized sequence (list, tuple, NumPy array, ...) of tokens.
        sample_rate: Output sample rate in Hz.
        tokens_per_second: Number of tokens assumed to span one second.
        frequency: Frequency of the generated tone in Hz.

    Returns:
        A ``(audio, sample_rate)`` tuple where ``audio`` is a 1-D
        ``float32`` array with a peak amplitude of 0.1.

    Raises:
        ValueError: If ``sample_rate`` or ``tokens_per_second`` is not
            positive.
    """
    if sample_rate <= 0:
        raise ValueError("sample_rate must be positive")
    if tokens_per_second <= 0:
        raise ValueError("tokens_per_second must be positive")

    # Work in samples rather than whole seconds so that clips shorter than
    # one second (fewer than `tokens_per_second` tokens) are not truncated
    # to silence.
    num_samples = int(round(len(token_ids) * sample_rate / tokens_per_second))

    # Sample k sits at exactly k / sample_rate seconds; this keeps the tone
    # at the requested frequency regardless of clip length.
    t = np.arange(num_samples) / sample_rate
    audio = 0.1 * np.sin(2 * np.pi * frequency * t)
    return audio.astype(np.float32), sample_rate
def generate_music(prompt, max_tokens=2048, temperature=1.0):
    """Generate audio for a text prompt.

    The prompt is tokenized, the model samples a continuation, and the newly
    generated tokens are handed to ``fake_xcodec_decode``.

    Args:
        prompt: Text description of the desired music.
        max_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        A ``(sample_rate, audio)`` tuple, or ``None`` when the prompt is
        missing/blank or no audio was produced.
    """
    # Treat None the same as an empty or whitespace-only prompt.
    if not prompt or not prompt.strip():
        return None

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            # UI sliders may deliver floats; generate() needs an integer
            # token budget.
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            eos_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens; only the
    # continuation should reach the audio decoder.
    generated_ids = output[0][prompt_length:].cpu().numpy()

    audio, sr = fake_xcodec_decode(generated_ids)
    if len(audio) == 0:
        # Nothing to play (generation stopped almost immediately).
        return None
    return (sr, audio)
# ----------------------------
# GRADIO UI
# ----------------------------
# Build the page: header text, the three inputs, the trigger button and the
# audio player, then wire the button to generate_music.
with gr.Blocks(title="YuE Music Generator") as demo:
    # Header shown at the top of the page (content kept flush-left so the
    # string is unchanged).
    gr.Markdown(
        """
# 🎵 YuE Song Generator
Text → AI Music
**Model:** m-a-p/YuE-s1-0.5B
⚠️ Requires GPU
"""
    )

    # Free-text description of the music to generate.
    prompt = gr.Textbox(
        lines=3,
        label="Music Prompt",
        placeholder="A sad lo-fi song with piano and rain ambience",
    )

    # Generation controls; their values are forwarded to generate_music.
    max_tokens = gr.Slider(
        minimum=512,
        maximum=4096,
        value=2048,
        step=256,
        label="Max Tokens",
    )
    temperature = gr.Slider(
        minimum=0.7,
        maximum=1.5,
        value=1.0,
        step=0.1,
        label="Creativity",
    )

    btn = gr.Button(value="Generate Music")
    audio_out = gr.Audio(type="numpy", label="Generated Music")

    # Clicking the button runs generate_music(prompt, max_tokens, temperature)
    # and sends its return value to the audio player.
    btn.click(
        fn=generate_music,
        inputs=[prompt, max_tokens, temperature],
        outputs=audio_out,
    )

demo.launch()