kines9661 committed on
Commit
2800a6c
·
verified ·
1 Parent(s): 834749c

Upload 3 files

Browse files
Files changed (1) hide show
  1. app.py +67 -45
app.py CHANGED
@@ -18,7 +18,7 @@ from gradio_client import Client
18
 
19
  # ─── 1. 在啟動時預先下載模型 ────────────────────────────
20
  from huggingface_hub import snapshot_download
21
- print("==== 正在預先下載模型到快取 (避免第一次請求超時) ====")
22
  try:
23
  snapshot_download("ACE-Step/ACE-Step-v1-3.5B")
24
  print("==== 模型下載完成! ====")
@@ -36,7 +36,7 @@ def get_pipeline():
36
  print("初始化 ACE-Step Pipeline...")
37
  from acestep.pipeline_ace_step import ACEStepPipeline
38
  pipeline = ACEStepPipeline(
39
- checkpoint_dir=None,
40
  dtype="bfloat16",
41
  device="cuda",
42
  )
@@ -91,13 +91,10 @@ def parse_input(req: ChatCompletionRequest):
91
  elif isinstance(msg.content, list):
92
  for item in msg.content:
93
  if isinstance(item, dict) and item.get("type") == "text":
94
- last_user_msg = item.get("text", "")
95
- break
96
  break
97
-
98
  prompt = ""
99
  lyrics = req.lyrics or ""
100
-
101
  if "<prompt>" in last_user_msg:
102
  import re
103
  p_match = re.search(r"<prompt>(.*?)</prompt>", last_user_msg, re.DOTALL)
@@ -113,31 +110,52 @@ def parse_input(req: ChatCompletionRequest):
113
  prompt = last_user_msg
114
  return prompt, lyrics
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  @spaces.GPU(duration=120)
117
- def gradio_generate(prompt, lyrics, duration, bpm, vocal_language, instrumental, guidance_scale, seed):
 
 
 
 
 
 
 
 
 
118
  try:
119
- seed_val = int(seed) if seed is not None and str(seed).strip() != "" else None
120
- bpm_val = int(bpm) if bpm is not None and str(bpm).strip() != "" else None
121
- dur_val = float(duration) if duration is not None else 30.0
 
122
 
123
  lang_map = {"英文 (en)": "en", "中文 (zh)": "zh", "日文 (ja)": "ja", "韓文 (ko)": "ko", "自動判定": "en"}
124
  lang = lang_map.get(vocal_language, "en") if vocal_language else "en"
125
- instr = True if instrumental in ["是", "Yes", True] else (False if instrumental in ["否", "No", False] else None)
126
 
127
  pipe = get_pipeline()
128
 
 
129
  final_prompt = prompt if prompt else "instrumental music"
130
-
131
- # 🚨 重大修復:避免將 boolean 傳入可能預期 Iterable 的 ACE-Step 參數中
132
- # 轉換 boolean 為提示詞增強,100% 確保相容性
133
- if instr is True:
134
  final_prompt = final_prompt + ", purely instrumental, no vocals"
135
 
136
  gen_kwargs = dict(
137
  prompt=final_prompt,
138
- lyrics=lyrics,
139
  audio_duration=dur_val,
140
- guidance_scale=float(guidance_scale) if guidance_scale else 7.0,
141
  infer_steps=27,
142
  scheduler_type="euler",
143
  )
@@ -160,6 +178,7 @@ def gradio_generate(prompt, lyrics, duration, bpm, vocal_language, instrumental,
160
  traceback.print_exc()
161
  raise gr.Error(f"生成失敗: {str(e)}")
162
 
 
163
  fastapi_app = FastAPI(title="ACE-Step API")
164
  fastapi_app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
165
 
@@ -176,7 +195,6 @@ async def chat_completions(req: ChatCompletionRequest, request: Request):
176
  check_auth(request)
177
  completion_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
178
  created_ts = int(time.time())
179
-
180
  try:
181
  prompt, lyrics = parse_input(req)
182
  audio_cfg = req.audio_config or AudioConfig()
@@ -184,27 +202,18 @@ async def chat_completions(req: ChatCompletionRequest, request: Request):
184
 
185
  def _call_gradio():
186
  client = Client("http://127.0.0.1:7860/")
187
-
188
- # 對齊 UI 元件所需的格式,防止型別對應錯誤
189
- ui_lang = "英文 (en)"
190
- if audio_cfg.vocal_language == "zh": ui_lang = "中文 (zh)"
191
- elif audio_cfg.vocal_language == "ja": ui_lang = "日文 (ja)"
192
- elif audio_cfg.vocal_language == "ko": ui_lang = "韓文 (ko)"
193
-
194
- ui_instr = "自動判定"
195
- if audio_cfg.instrumental is True: ui_instr = "是"
196
- elif audio_cfg.instrumental is False: ui_instr = "否"
197
-
198
- # 使用嚴格位置傳參
199
  return client.predict(
200
  prompt,
201
  lyrics,
202
- audio_cfg.duration or 30.0,
203
- audio_cfg.bpm,
204
  ui_lang,
205
  ui_instr,
206
- req.guidance_scale,
207
- seed_val,
208
  api_name="/generate_music"
209
  )
210
 
@@ -228,13 +237,12 @@ async def chat_completions(req: ChatCompletionRequest, request: Request):
228
 
229
  if req.stream:
230
  async def event_stream():
231
- chunks = [
232
  {"delta": {"role": "assistant", "content": ""}},
233
  {"delta": {"content": content_text}},
234
  {"delta": {"audio": [{"type": "audio_url", "audio_url": {"url": audio_url}}]}},
235
  {"delta": {}, "finish_reason": "stop"}
236
- ]
237
- for chunk in chunks:
238
  chunk_data = {"id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": MODEL_ID, "choices": [{"index": 0, **chunk}]}
239
  yield f"data: {json.dumps(chunk_data)}\n\n"
240
  yield "data: [DONE]\n\n"
@@ -246,29 +254,43 @@ async def chat_completions(req: ChatCompletionRequest, request: Request):
246
  traceback.print_exc()
247
  raise HTTPException(status_code=500, detail=f"內部伺服器錯誤: {str(e)}")
248
 
 
 
 
249
  with gr.Blocks(title="🎵 ACE-Step v1.5 音樂生成器", theme=gr.themes.Soft()) as demo:
250
  gr.HTML("<h1 style='text-align: center;'>🎵 ACE-Step v1.5 音樂生成器</h1>")
251
  with gr.Tab("🎼 生成音樂"):
252
  with gr.Row():
253
  with gr.Column(scale=1):
254
- prompt_input = gr.Textbox(label="🏷️ 音樂風格描述 (Prompt)")
255
  lyrics_input = gr.Textbox(label="📜 歌詞 (Lyrics,可選填)", lines=4)
256
  with gr.Row():
257
- duration_input = gr.Number(label="⏱️ 時長(秒)", value=30)
258
- bpm_input = gr.Number(label="🥁 BPM", value=None)
 
259
  with gr.Row():
260
- lang_input = gr.Dropdown(label="🌍 語言", choices=["英文 (en)", "中文 (zh)", "日文 (ja)", "韓文 (ko)"], value="英文 (en)")
261
  instr_input = gr.Dropdown(label="🎸 純伴奏", choices=["自動判定", "是", "否"], value="自動判定")
262
  with gr.Row():
263
- cfg_input = gr.Slider(label="🎚️ Guidance Scale", minimum=1, maximum=15, value=7.0)
264
- seed_input = gr.Number(label="🎲 Seed", value=None)
265
  generate_btn = gr.Button("🚀 開始生成音樂", variant="primary")
266
  with gr.Column(scale=1):
267
  audio_output = gr.Audio(label="🎵 生成結果", type="numpy")
268
 
269
- generate_btn.click(fn=gradio_generate, inputs=[prompt_input, lyrics_input, duration_input, bpm_input, lang_input, instr_input, cfg_input, seed_input], outputs=audio_output)
 
 
 
 
 
270
  api_btn = gr.Button("API", visible=False)
271
- api_btn.click(fn=gradio_generate, inputs=[prompt_input, lyrics_input, duration_input, bpm_input, lang_input, instr_input, cfg_input, seed_input], outputs=audio_output, api_name="generate_music")
 
 
 
 
 
272
 
273
  app = gr.mount_gradio_app(fastapi_app, demo, path="/")
274
  if __name__ == "__main__":
 
18
 
19
  # ─── 1. 在啟動時預先下載模型 ────────────────────────────
20
  from huggingface_hub import snapshot_download
21
+ print("==== 正在預先下載模型到快取 ====")
22
  try:
23
  snapshot_download("ACE-Step/ACE-Step-v1-3.5B")
24
  print("==== 模型下載完成! ====")
 
36
  print("初始化 ACE-Step Pipeline...")
37
  from acestep.pipeline_ace_step import ACEStepPipeline
38
  pipeline = ACEStepPipeline(
39
+ checkpoint_dir=None,
40
  dtype="bfloat16",
41
  device="cuda",
42
  )
 
91
  elif isinstance(msg.content, list):
92
  for item in msg.content:
93
  if isinstance(item, dict) and item.get("type") == "text":
94
+ last_user_msg = item.get("text", ""); break
 
95
  break
 
96
  prompt = ""
97
  lyrics = req.lyrics or ""
 
98
  if "<prompt>" in last_user_msg:
99
  import re
100
  p_match = re.search(r"<prompt>(.*?)</prompt>", last_user_msg, re.DOTALL)
 
110
  prompt = last_user_msg
111
  return prompt, lyrics
112
 
113
+ # ─── 核心生成函數(ZeroGPU)─────────────────────────────
114
+ #
115
+ # 🚨 修復 TypeError: argument of type 'bool' is not iterable
116
+ #
117
+ # 根本原因:當 gr.Number(value=None) 傳入 None 作為預設值時,
118
+ # gradio 會產生含有 "additionalProperties": False(布林值)的 JSON Schema。
119
+ # 而 gradio_client 的 get_type() 函數嘗試對這個布林值執行
120
+ # "const" in schema,因此觸發 TypeError。
121
+ #
122
+ # 修復策略:
123
+ # 1. 所有可選數字欄位改用 gr.Textbox 代替 gr.Number,
124
+ # 避免 gr.Number(value=None) 產生有問題的 JSON Schema。
125
+ # 2. 在函數內部手動解析字串。
126
+ #
127
  @spaces.GPU(duration=120)
128
+ def gradio_generate(
129
+ prompt: str, # gr.Textbox
130
+ lyrics: str, # gr.Textbox
131
+ duration: str, # gr.Textbox (取代 gr.Number),避免 additionalProperties: False 問題
132
+ bpm: str, # gr.Textbox (取代 gr.Number)
133
+ vocal_language: str,# gr.Dropdown
134
+ instrumental: str, # gr.Dropdown ("是"/"否"/"自動判定")
135
+ guidance_scale: str,# gr.Textbox (取代 gr.Slider,確保相容性)
136
+ seed: str # gr.Textbox (取代 gr.Number)
137
+ ):
138
  try:
139
+ seed_val = int(seed.strip()) if seed and seed.strip() else None
140
+ bpm_val = int(bpm.strip()) if bpm and bpm.strip() else None
141
+ dur_val = float(duration.strip()) if duration and duration.strip() else 30.0
142
+ cfg_val = float(guidance_scale.strip()) if guidance_scale and guidance_scale.strip() else 7.0
143
 
144
  lang_map = {"英文 (en)": "en", "中文 (zh)": "zh", "日文 (ja)": "ja", "韓文 (ko)": "ko", "自動判定": "en"}
145
  lang = lang_map.get(vocal_language, "en") if vocal_language else "en"
 
146
 
147
  pipe = get_pipeline()
148
 
149
+ # 處理純伴奏:不直接傳入 bool,改以 prompt 增強文字
150
  final_prompt = prompt if prompt else "instrumental music"
151
+ if instrumental == "是":
 
 
 
152
  final_prompt = final_prompt + ", purely instrumental, no vocals"
153
 
154
  gen_kwargs = dict(
155
  prompt=final_prompt,
156
+ lyrics=lyrics or "",
157
  audio_duration=dur_val,
158
+ guidance_scale=cfg_val,
159
  infer_steps=27,
160
  scheduler_type="euler",
161
  )
 
178
  traceback.print_exc()
179
  raise gr.Error(f"生成失敗: {str(e)}")
180
 
181
+ # ─── FastAPI ─────────────────────────────────────────────
182
  fastapi_app = FastAPI(title="ACE-Step API")
183
  fastapi_app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
184
 
 
195
  check_auth(request)
196
  completion_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
197
  created_ts = int(time.time())
 
198
  try:
199
  prompt, lyrics = parse_input(req)
200
  audio_cfg = req.audio_config or AudioConfig()
 
202
 
203
  def _call_gradio():
204
  client = Client("http://127.0.0.1:7860/")
205
+ ui_lang = {"zh": "中文 (zh)", "ja": "日文 (ja)", "ko": "韓文 (ko)"}.get(audio_cfg.vocal_language, "英文 (en)")
206
+ ui_instr = "是" if audio_cfg.instrumental is True else ("否" if audio_cfg.instrumental is False else "自動判定")
207
+ # 所有參數轉換為字串,匹配 Gradio UI 的 Textbox 型別
 
 
 
 
 
 
 
 
 
208
  return client.predict(
209
  prompt,
210
  lyrics,
211
+ str(audio_cfg.duration or 30.0),
212
+ str(audio_cfg.bpm) if audio_cfg.bpm else "",
213
  ui_lang,
214
  ui_instr,
215
+ str(req.guidance_scale),
216
+ str(seed_val) if seed_val is not None else "",
217
  api_name="/generate_music"
218
  )
219
 
 
237
 
238
  if req.stream:
239
  async def event_stream():
240
+ for chunk in [
241
  {"delta": {"role": "assistant", "content": ""}},
242
  {"delta": {"content": content_text}},
243
  {"delta": {"audio": [{"type": "audio_url", "audio_url": {"url": audio_url}}]}},
244
  {"delta": {}, "finish_reason": "stop"}
245
+ ]:
 
246
  chunk_data = {"id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": MODEL_ID, "choices": [{"index": 0, **chunk}]}
247
  yield f"data: {json.dumps(chunk_data)}\n\n"
248
  yield "data: [DONE]\n\n"
 
254
  traceback.print_exc()
255
  raise HTTPException(status_code=500, detail=f"內部伺服器錯誤: {str(e)}")
256
 
257
+ # ─── Gradio Web UI ───────────────────────────────────────
258
+ # 🚨 修復重點:所有可選數值欄位改用 gr.Textbox,
259
+ # 不使用 gr.Number(value=None),防止產生問題 JSON Schema
260
  with gr.Blocks(title="🎵 ACE-Step v1.5 音樂生成器", theme=gr.themes.Soft()) as demo:
261
  gr.HTML("<h1 style='text-align: center;'>🎵 ACE-Step v1.5 音樂生成器</h1>")
262
  with gr.Tab("🎼 生成音樂"):
263
  with gr.Row():
264
  with gr.Column(scale=1):
265
+ prompt_input = gr.Textbox(label="🏷️ 音樂風格描述 (Prompt)", placeholder="例如:節奏強烈的 EDM、重低音與合成器")
266
  lyrics_input = gr.Textbox(label="📜 歌詞 (Lyrics,可選填)", lines=4)
267
  with gr.Row():
268
+ # 改用 Textbox,完全避免 gr.Number(value=None) 問題
269
+ duration_input = gr.Textbox(label="⏱️ 時長 (秒)", value="30", placeholder="30")
270
+ bpm_input = gr.Textbox(label="🥁 BPM (可選)", value="", placeholder="例:120")
271
  with gr.Row():
272
+ lang_input = gr.Dropdown(label="🌍 語言", choices=["英文 (en)", "中文 (zh)", "日文 (ja)", "韓文 (ko)", "自動判定"], value="英文 (en)")
273
  instr_input = gr.Dropdown(label="🎸 純伴奏", choices=["自動判定", "是", "否"], value="自動判定")
274
  with gr.Row():
275
+ cfg_input = gr.Textbox(label="🎚️ Guidance Scale", value="7.0", placeholder="7.0")
276
+ seed_input = gr.Textbox(label="🎲 Seed (可選)", value="", placeholder="例:42")
277
  generate_btn = gr.Button("🚀 開始生成音樂", variant="primary")
278
  with gr.Column(scale=1):
279
  audio_output = gr.Audio(label="🎵 生成結果", type="numpy")
280
 
281
+ generate_btn.click(
282
+ fn=gradio_generate,
283
+ inputs=[prompt_input, lyrics_input, duration_input, bpm_input, lang_input, instr_input, cfg_input, seed_input],
284
+ outputs=audio_output
285
+ )
286
+ # API 路由用隱藏觸發按鈕
287
  api_btn = gr.Button("API", visible=False)
288
+ api_btn.click(
289
+ fn=gradio_generate,
290
+ inputs=[prompt_input, lyrics_input, duration_input, bpm_input, lang_input, instr_input, cfg_input, seed_input],
291
+ outputs=audio_output,
292
+ api_name="generate_music"
293
+ )
294
 
295
  app = gr.mount_gradio_app(fastapi_app, demo, path="/")
296
  if __name__ == "__main__":