um41r committed on
Commit
c44ace0
·
verified ·
1 Parent(s): 9fd2e11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -16
app.py CHANGED
@@ -1,48 +1,100 @@
1
- import gradio as gr
2
  import torch
 
 
 
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
  MODEL_ID = "m-a-p/YuE-s1-0.5B"
6
 
 
7
  tokenizer = AutoTokenizer.from_pretrained(
8
  MODEL_ID,
9
  trust_remote_code=True,
10
- use_fast=False # FIX
11
  )
12
 
 
13
  model = AutoModelForCausalLM.from_pretrained(
14
  MODEL_ID,
15
- torch_dtype=torch.float32,
 
16
  device_map="auto"
17
  )
18
-
19
  model.eval()
20
 
21
- def generate(prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  if not prompt.strip():
23
- return ""
24
 
25
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
26
 
27
  with torch.no_grad():
28
  output = model.generate(
29
  **inputs,
30
- max_new_tokens=max_new_tokens,
31
- temperature=temperature,
32
- top_p=top_p,
33
  do_sample=True,
 
34
  eos_token_id=tokenizer.eos_token_id
35
  )
36
 
37
- return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- with gr.Blocks() as demo:
40
- gr.Markdown("# 🤖 YuE-s1-0.5B (HF Spaces)")
41
 
42
- prompt = gr.Textbox(lines=6, label="Prompt")
43
- btn = gr.Button("Generate")
44
- out = gr.Textbox(lines=12, label="Response")
45
 
46
- btn.click(generate, inputs=prompt, outputs=out)
 
 
 
 
47
 
48
  demo.launch()
 
 
1
  import torch
2
+ import gradio as gr
3
+ import numpy as np
4
+ import soundfile as sf
5
+
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
  MODEL_ID = "m-a-p/YuE-s1-0.5B"
9
 
10
# Author's note: the slow (non-fast) tokenizer is required for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)

# Author's note: a GPU is required. Weights are loaded in half precision and
# placed automatically; .eval() returns the same module, switched to
# inference mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()
25
 
26
+ # ----------------------------
27
+ # SIMPLE AUDIO TOKEN DECODER
28
+ # ----------------------------
29
+ # NOTE:
30
+ # YuE uses xcodec tokens.
31
+ # This is a *placeholder decoder*.
32
+ # Official decoder is required for best quality.
33
+
34
def fake_xcodec_decode(token_ids, sample_rate=44100, tokens_per_second=50):
    """Synthesize a placeholder 440 Hz tone whose length tracks the token count.

    This is NOT a real xcodec decoder: the token values are ignored and only
    their count determines how long the clip is. For real audio quality the
    official YuE xcodec decoder must be used instead.

    Args:
        token_ids: Sized sequence (list, tuple, 1-D array) of token ids.
        sample_rate: Output sample rate in Hz.
        tokens_per_second: Assumed token rate used to convert the token count
            into a duration in seconds.

    Returns:
        A ``(audio, sample_rate)`` tuple where ``audio`` is a mono float32
        NumPy array with amplitude 0.1.
    """
    # True division keeps fractional seconds; floor division used to drop
    # them and returned an empty clip for fewer than `tokens_per_second` tokens.
    num_samples = int(round(len(token_ids) / tokens_per_second * sample_rate))

    # Sample instants spaced exactly 1 / sample_rate apart. linspace with its
    # inclusive endpoint stretched the step slightly and detuned the tone.
    t = np.arange(num_samples, dtype=np.float64) / sample_rate

    audio = 0.1 * np.sin(2 * np.pi * 440 * t)
    return audio.astype(np.float32), sample_rate
43
+
44
+
45
def generate_music(prompt, max_tokens=2048, temperature=1.0):
    """Sample tokens from the model for ``prompt`` and turn them into audio.

    Args:
        prompt: Text description of the desired music. ``None``, empty, or
            whitespace-only prompts produce no output.
        max_tokens: Upper bound on newly generated tokens. UI sliders may
            deliver a float, so the value is coerced to ``int``.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        ``(sample_rate, audio)`` suitable for a numpy-typed ``gr.Audio``
        output, or ``None`` when there is no prompt or nothing to play.
    """
    if not prompt or not prompt.strip():
        return None

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Causal-LM generate() returns prompt + continuation; remember where the
    # prompt ends so only generated tokens reach the audio decoder.
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            eos_token_id=tokenizer.eos_token_id,
        )

    token_ids = output[0][prompt_len:].cpu().numpy()

    audio, sr = fake_xcodec_decode(token_ids)
    # An empty clip (e.g. immediate EOS) is reported as "no audio" rather
    # than handing a zero-length array to the audio component.
    if audio.size == 0:
        return None

    return (sr, audio)
65
+
66
+
67
+ # ----------------------------
68
+ # GRADIO UI
69
+ # ----------------------------
70
+
71
# Build the single-page UI: prompt box, two sampling controls, a trigger
# button, and an audio player fed by generate_music.
with gr.Blocks(title="YuE Music Generator") as demo:
    gr.Markdown(
        """
        # 🎵 YuE Song Generator
        Text → AI Music

        **Model:** m-a-p/YuE-s1-0.5B
        ⚠️ Requires GPU
        """
    )

    prompt = gr.Textbox(
        label="Music Prompt",
        placeholder="A sad lo-fi song with piano and rain ambience",
        lines=3,
    )

    # Sampling controls, forwarded to generate_music in this order.
    max_tokens = gr.Slider(
        minimum=512,
        maximum=4096,
        value=2048,
        step=256,
        label="Max Tokens",
    )
    temperature = gr.Slider(
        minimum=0.7,
        maximum=1.5,
        value=1.0,
        step=0.1,
        label="Creativity",
    )

    btn = gr.Button("Generate Music")
    audio_out = gr.Audio(label="Generated Music", type="numpy")

    # The handler returns (sample_rate, ndarray), matching type="numpy".
    btn.click(
        fn=generate_music,
        inputs=[prompt, max_tokens, temperature],
        outputs=audio_out,
    )

demo.launch()