Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files
app.py
CHANGED
|
@@ -18,7 +18,7 @@ from gradio_client import Client
|
|
| 18 |
|
| 19 |
# ─── 1. 在啟動時預先下載模型 ────────────────────────────
|
| 20 |
from huggingface_hub import snapshot_download
|
| 21 |
-
print("==== 正在預先下載模型到快取
|
| 22 |
try:
|
| 23 |
snapshot_download("ACE-Step/ACE-Step-v1-3.5B")
|
| 24 |
print("==== 模型下載完成! ====")
|
|
@@ -36,7 +36,7 @@ def get_pipeline():
|
|
| 36 |
print("初始化 ACE-Step Pipeline...")
|
| 37 |
from acestep.pipeline_ace_step import ACEStepPipeline
|
| 38 |
pipeline = ACEStepPipeline(
|
| 39 |
-
checkpoint_dir=None,
|
| 40 |
dtype="bfloat16",
|
| 41 |
device="cuda",
|
| 42 |
)
|
|
@@ -91,13 +91,10 @@ def parse_input(req: ChatCompletionRequest):
|
|
| 91 |
elif isinstance(msg.content, list):
|
| 92 |
for item in msg.content:
|
| 93 |
if isinstance(item, dict) and item.get("type") == "text":
|
| 94 |
-
last_user_msg = item.get("text", "")
|
| 95 |
-
break
|
| 96 |
break
|
| 97 |
-
|
| 98 |
prompt = ""
|
| 99 |
lyrics = req.lyrics or ""
|
| 100 |
-
|
| 101 |
if "<prompt>" in last_user_msg:
|
| 102 |
import re
|
| 103 |
p_match = re.search(r"<prompt>(.*?)</prompt>", last_user_msg, re.DOTALL)
|
|
@@ -113,31 +110,52 @@ def parse_input(req: ChatCompletionRequest):
|
|
| 113 |
prompt = last_user_msg
|
| 114 |
return prompt, lyrics
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
@spaces.GPU(duration=120)
|
| 117 |
-
def gradio_generate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
try:
|
| 119 |
-
seed_val = int(seed) if seed
|
| 120 |
-
bpm_val
|
| 121 |
-
dur_val
|
|
|
|
| 122 |
|
| 123 |
lang_map = {"英文 (en)": "en", "中文 (zh)": "zh", "日文 (ja)": "ja", "韓文 (ko)": "ko", "自動判定": "en"}
|
| 124 |
lang = lang_map.get(vocal_language, "en") if vocal_language else "en"
|
| 125 |
-
instr = True if instrumental in ["是", "Yes", True] else (False if instrumental in ["否", "No", False] else None)
|
| 126 |
|
| 127 |
pipe = get_pipeline()
|
| 128 |
|
|
|
|
| 129 |
final_prompt = prompt if prompt else "instrumental music"
|
| 130 |
-
|
| 131 |
-
# 🚨 重大修復:避免將 boolean 傳入可能預期 Iterable 的 ACE-Step 參數中
|
| 132 |
-
# 轉換 boolean 為提示詞增強,100% 確保相容性
|
| 133 |
-
if instr is True:
|
| 134 |
final_prompt = final_prompt + ", purely instrumental, no vocals"
|
| 135 |
|
| 136 |
gen_kwargs = dict(
|
| 137 |
prompt=final_prompt,
|
| 138 |
-
lyrics=lyrics,
|
| 139 |
audio_duration=dur_val,
|
| 140 |
-
guidance_scale=
|
| 141 |
infer_steps=27,
|
| 142 |
scheduler_type="euler",
|
| 143 |
)
|
|
@@ -160,6 +178,7 @@ def gradio_generate(prompt, lyrics, duration, bpm, vocal_language, instrumental,
|
|
| 160 |
traceback.print_exc()
|
| 161 |
raise gr.Error(f"生成失敗: {str(e)}")
|
| 162 |
|
|
|
|
| 163 |
fastapi_app = FastAPI(title="ACE-Step API")
|
| 164 |
fastapi_app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
| 165 |
|
|
@@ -176,7 +195,6 @@ async def chat_completions(req: ChatCompletionRequest, request: Request):
|
|
| 176 |
check_auth(request)
|
| 177 |
completion_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
|
| 178 |
created_ts = int(time.time())
|
| 179 |
-
|
| 180 |
try:
|
| 181 |
prompt, lyrics = parse_input(req)
|
| 182 |
audio_cfg = req.audio_config or AudioConfig()
|
|
@@ -184,27 +202,18 @@ async def chat_completions(req: ChatCompletionRequest, request: Request):
|
|
| 184 |
|
| 185 |
def _call_gradio():
|
| 186 |
client = Client("http://127.0.0.1:7860/")
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
if audio_cfg.vocal_language == "zh": ui_lang = "中文 (zh)"
|
| 191 |
-
elif audio_cfg.vocal_language == "ja": ui_lang = "日文 (ja)"
|
| 192 |
-
elif audio_cfg.vocal_language == "ko": ui_lang = "韓文 (ko)"
|
| 193 |
-
|
| 194 |
-
ui_instr = "自動判定"
|
| 195 |
-
if audio_cfg.instrumental is True: ui_instr = "是"
|
| 196 |
-
elif audio_cfg.instrumental is False: ui_instr = "否"
|
| 197 |
-
|
| 198 |
-
# 使用嚴格位置傳參
|
| 199 |
return client.predict(
|
| 200 |
prompt,
|
| 201 |
lyrics,
|
| 202 |
-
audio_cfg.duration or 30.0,
|
| 203 |
-
audio_cfg.bpm,
|
| 204 |
ui_lang,
|
| 205 |
ui_instr,
|
| 206 |
-
req.guidance_scale,
|
| 207 |
-
seed_val,
|
| 208 |
api_name="/generate_music"
|
| 209 |
)
|
| 210 |
|
|
@@ -228,13 +237,12 @@ async def chat_completions(req: ChatCompletionRequest, request: Request):
|
|
| 228 |
|
| 229 |
if req.stream:
|
| 230 |
async def event_stream():
|
| 231 |
-
|
| 232 |
{"delta": {"role": "assistant", "content": ""}},
|
| 233 |
{"delta": {"content": content_text}},
|
| 234 |
{"delta": {"audio": [{"type": "audio_url", "audio_url": {"url": audio_url}}]}},
|
| 235 |
{"delta": {}, "finish_reason": "stop"}
|
| 236 |
-
]
|
| 237 |
-
for chunk in chunks:
|
| 238 |
chunk_data = {"id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": MODEL_ID, "choices": [{"index": 0, **chunk}]}
|
| 239 |
yield f"data: {json.dumps(chunk_data)}\n\n"
|
| 240 |
yield "data: [DONE]\n\n"
|
|
@@ -246,29 +254,43 @@ async def chat_completions(req: ChatCompletionRequest, request: Request):
|
|
| 246 |
traceback.print_exc()
|
| 247 |
raise HTTPException(status_code=500, detail=f"內部伺服器錯誤: {str(e)}")
|
| 248 |
|
|
|
|
|
|
|
|
|
|
| 249 |
with gr.Blocks(title="🎵 ACE-Step v1.5 音樂生成器", theme=gr.themes.Soft()) as demo:
|
| 250 |
gr.HTML("<h1 style='text-align: center;'>🎵 ACE-Step v1.5 音樂生成器</h1>")
|
| 251 |
with gr.Tab("🎼 生成音樂"):
|
| 252 |
with gr.Row():
|
| 253 |
with gr.Column(scale=1):
|
| 254 |
-
prompt_input = gr.Textbox(label="🏷️ 音樂風格描述 (Prompt)")
|
| 255 |
lyrics_input = gr.Textbox(label="📜 歌詞 (Lyrics,可選填)", lines=4)
|
| 256 |
with gr.Row():
|
| 257 |
-
|
| 258 |
-
|
|
|
|
| 259 |
with gr.Row():
|
| 260 |
-
lang_input
|
| 261 |
instr_input = gr.Dropdown(label="🎸 純伴奏", choices=["自動判定", "是", "否"], value="自動判定")
|
| 262 |
with gr.Row():
|
| 263 |
-
cfg_input
|
| 264 |
-
seed_input = gr.
|
| 265 |
generate_btn = gr.Button("🚀 開始生成音樂", variant="primary")
|
| 266 |
with gr.Column(scale=1):
|
| 267 |
audio_output = gr.Audio(label="🎵 生成結果", type="numpy")
|
| 268 |
|
| 269 |
-
generate_btn.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
api_btn = gr.Button("API", visible=False)
|
| 271 |
-
api_btn.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
app = gr.mount_gradio_app(fastapi_app, demo, path="/")
|
| 274 |
if __name__ == "__main__":
|
|
|
|
| 18 |
|
| 19 |
# ─── 1. 在啟動時預先下載模型 ────────────────────────────
|
| 20 |
from huggingface_hub import snapshot_download
|
| 21 |
+
print("==== 正在預先下載模型到快取 ====")
|
| 22 |
try:
|
| 23 |
snapshot_download("ACE-Step/ACE-Step-v1-3.5B")
|
| 24 |
print("==== 模型下載完成! ====")
|
|
|
|
| 36 |
print("初始化 ACE-Step Pipeline...")
|
| 37 |
from acestep.pipeline_ace_step import ACEStepPipeline
|
| 38 |
pipeline = ACEStepPipeline(
|
| 39 |
+
checkpoint_dir=None,
|
| 40 |
dtype="bfloat16",
|
| 41 |
device="cuda",
|
| 42 |
)
|
|
|
|
| 91 |
elif isinstance(msg.content, list):
|
| 92 |
for item in msg.content:
|
| 93 |
if isinstance(item, dict) and item.get("type") == "text":
|
| 94 |
+
last_user_msg = item.get("text", ""); break
|
|
|
|
| 95 |
break
|
|
|
|
| 96 |
prompt = ""
|
| 97 |
lyrics = req.lyrics or ""
|
|
|
|
| 98 |
if "<prompt>" in last_user_msg:
|
| 99 |
import re
|
| 100 |
p_match = re.search(r"<prompt>(.*?)</prompt>", last_user_msg, re.DOTALL)
|
|
|
|
| 110 |
prompt = last_user_msg
|
| 111 |
return prompt, lyrics
|
| 112 |
|
| 113 |
+
# ─── 核心生成函數(ZeroGPU)─────────────────────────────
|
| 114 |
+
#
|
| 115 |
+
# 🚨 修復 TypeError: argument of type 'bool' is not iterable
|
| 116 |
+
#
|
| 117 |
+
# 根本原因:當 gr.Number(value=None) 傳入 None 作為預設值時,
|
| 118 |
+
# gradio 會產生含有 "additionalProperties": False(布林值)的 JSON Schema。
|
| 119 |
+
# 而 gradio_client 的 get_type() 函數嘗試對這個布林值執行
|
| 120 |
+
# "const" in schema,因此觸發 TypeError。
|
| 121 |
+
#
|
| 122 |
+
# 修復策略:
|
| 123 |
+
# 1. 所有可選數字欄位改用 gr.Textbox 代替 gr.Number,
|
| 124 |
+
# 避免 gr.Number(value=None) 產生有問題的 JSON Schema。
|
| 125 |
+
# 2. 在函數內部手動解析字串。
|
| 126 |
+
#
|
| 127 |
@spaces.GPU(duration=120)
|
| 128 |
+
def gradio_generate(
|
| 129 |
+
prompt: str, # gr.Textbox
|
| 130 |
+
lyrics: str, # gr.Textbox
|
| 131 |
+
duration: str, # gr.Textbox (取代 gr.Number),避免 additionalProperties: False 問題
|
| 132 |
+
bpm: str, # gr.Textbox (取代 gr.Number)
|
| 133 |
+
vocal_language: str,# gr.Dropdown
|
| 134 |
+
instrumental: str, # gr.Dropdown ("是"/"否"/"自動判定")
|
| 135 |
+
guidance_scale: str,# gr.Textbox (取代 gr.Slider,確保相容性)
|
| 136 |
+
seed: str # gr.Textbox (取代 gr.Number)
|
| 137 |
+
):
|
| 138 |
try:
|
| 139 |
+
seed_val = int(seed.strip()) if seed and seed.strip() else None
|
| 140 |
+
bpm_val = int(bpm.strip()) if bpm and bpm.strip() else None
|
| 141 |
+
dur_val = float(duration.strip()) if duration and duration.strip() else 30.0
|
| 142 |
+
cfg_val = float(guidance_scale.strip()) if guidance_scale and guidance_scale.strip() else 7.0
|
| 143 |
|
| 144 |
lang_map = {"英文 (en)": "en", "中文 (zh)": "zh", "日文 (ja)": "ja", "韓文 (ko)": "ko", "自動判定": "en"}
|
| 145 |
lang = lang_map.get(vocal_language, "en") if vocal_language else "en"
|
|
|
|
| 146 |
|
| 147 |
pipe = get_pipeline()
|
| 148 |
|
| 149 |
+
# 處理純伴奏:不直接傳入 bool,改以 prompt 增強文字
|
| 150 |
final_prompt = prompt if prompt else "instrumental music"
|
| 151 |
+
if instrumental == "是":
|
|
|
|
|
|
|
|
|
|
| 152 |
final_prompt = final_prompt + ", purely instrumental, no vocals"
|
| 153 |
|
| 154 |
gen_kwargs = dict(
|
| 155 |
prompt=final_prompt,
|
| 156 |
+
lyrics=lyrics or "",
|
| 157 |
audio_duration=dur_val,
|
| 158 |
+
guidance_scale=cfg_val,
|
| 159 |
infer_steps=27,
|
| 160 |
scheduler_type="euler",
|
| 161 |
)
|
|
|
|
| 178 |
traceback.print_exc()
|
| 179 |
raise gr.Error(f"生成失敗: {str(e)}")
|
| 180 |
|
| 181 |
+
# ─── FastAPI ─────────────────────────────────────────────
|
| 182 |
fastapi_app = FastAPI(title="ACE-Step API")
|
| 183 |
fastapi_app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
| 184 |
|
|
|
|
| 195 |
check_auth(request)
|
| 196 |
completion_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
|
| 197 |
created_ts = int(time.time())
|
|
|
|
| 198 |
try:
|
| 199 |
prompt, lyrics = parse_input(req)
|
| 200 |
audio_cfg = req.audio_config or AudioConfig()
|
|
|
|
| 202 |
|
| 203 |
def _call_gradio():
|
| 204 |
client = Client("http://127.0.0.1:7860/")
|
| 205 |
+
ui_lang = {"zh": "中文 (zh)", "ja": "日文 (ja)", "ko": "韓文 (ko)"}.get(audio_cfg.vocal_language, "英文 (en)")
|
| 206 |
+
ui_instr = "是" if audio_cfg.instrumental is True else ("否" if audio_cfg.instrumental is False else "自動判定")
|
| 207 |
+
# 所有參數轉換為字串,匹配 Gradio UI 的 Textbox 型別
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
return client.predict(
|
| 209 |
prompt,
|
| 210 |
lyrics,
|
| 211 |
+
str(audio_cfg.duration or 30.0),
|
| 212 |
+
str(audio_cfg.bpm) if audio_cfg.bpm else "",
|
| 213 |
ui_lang,
|
| 214 |
ui_instr,
|
| 215 |
+
str(req.guidance_scale),
|
| 216 |
+
str(seed_val) if seed_val is not None else "",
|
| 217 |
api_name="/generate_music"
|
| 218 |
)
|
| 219 |
|
|
|
|
| 237 |
|
| 238 |
if req.stream:
|
| 239 |
async def event_stream():
|
| 240 |
+
for chunk in [
|
| 241 |
{"delta": {"role": "assistant", "content": ""}},
|
| 242 |
{"delta": {"content": content_text}},
|
| 243 |
{"delta": {"audio": [{"type": "audio_url", "audio_url": {"url": audio_url}}]}},
|
| 244 |
{"delta": {}, "finish_reason": "stop"}
|
| 245 |
+
]:
|
|
|
|
| 246 |
chunk_data = {"id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": MODEL_ID, "choices": [{"index": 0, **chunk}]}
|
| 247 |
yield f"data: {json.dumps(chunk_data)}\n\n"
|
| 248 |
yield "data: [DONE]\n\n"
|
|
|
|
| 254 |
traceback.print_exc()
|
| 255 |
raise HTTPException(status_code=500, detail=f"內部伺服器錯誤: {str(e)}")
|
| 256 |
|
| 257 |
+
# ─── Gradio Web UI ───────────────────────────────────────
|
| 258 |
+
# 🚨 修復重點:所有可選數值欄位改用 gr.Textbox,
|
| 259 |
+
# 不使用 gr.Number(value=None),防止產生問題 JSON Schema
|
| 260 |
with gr.Blocks(title="🎵 ACE-Step v1.5 音樂生成器", theme=gr.themes.Soft()) as demo:
|
| 261 |
gr.HTML("<h1 style='text-align: center;'>🎵 ACE-Step v1.5 音樂生成器</h1>")
|
| 262 |
with gr.Tab("🎼 生成音樂"):
|
| 263 |
with gr.Row():
|
| 264 |
with gr.Column(scale=1):
|
| 265 |
+
prompt_input = gr.Textbox(label="🏷️ 音樂風格描述 (Prompt)", placeholder="例如:節奏強烈的 EDM、重低音與合成器")
|
| 266 |
lyrics_input = gr.Textbox(label="📜 歌詞 (Lyrics,可選填)", lines=4)
|
| 267 |
with gr.Row():
|
| 268 |
+
# ✅ 改用 Textbox,完全避免 gr.Number(value=None) 問題
|
| 269 |
+
duration_input = gr.Textbox(label="⏱️ 時長 (秒)", value="30", placeholder="30")
|
| 270 |
+
bpm_input = gr.Textbox(label="🥁 BPM (可選)", value="", placeholder="例:120")
|
| 271 |
with gr.Row():
|
| 272 |
+
lang_input = gr.Dropdown(label="🌍 語言", choices=["英文 (en)", "中文 (zh)", "日文 (ja)", "韓文 (ko)", "自動判定"], value="英文 (en)")
|
| 273 |
instr_input = gr.Dropdown(label="🎸 純伴奏", choices=["自動判定", "是", "否"], value="自動判定")
|
| 274 |
with gr.Row():
|
| 275 |
+
cfg_input = gr.Textbox(label="🎚️ Guidance Scale", value="7.0", placeholder="7.0")
|
| 276 |
+
seed_input = gr.Textbox(label="🎲 Seed (可選)", value="", placeholder="例:42")
|
| 277 |
generate_btn = gr.Button("🚀 開始生成音樂", variant="primary")
|
| 278 |
with gr.Column(scale=1):
|
| 279 |
audio_output = gr.Audio(label="🎵 生成結果", type="numpy")
|
| 280 |
|
| 281 |
+
generate_btn.click(
|
| 282 |
+
fn=gradio_generate,
|
| 283 |
+
inputs=[prompt_input, lyrics_input, duration_input, bpm_input, lang_input, instr_input, cfg_input, seed_input],
|
| 284 |
+
outputs=audio_output
|
| 285 |
+
)
|
| 286 |
+
# API 路由用隱藏觸發按鈕
|
| 287 |
api_btn = gr.Button("API", visible=False)
|
| 288 |
+
api_btn.click(
|
| 289 |
+
fn=gradio_generate,
|
| 290 |
+
inputs=[prompt_input, lyrics_input, duration_input, bpm_input, lang_input, instr_input, cfg_input, seed_input],
|
| 291 |
+
outputs=audio_output,
|
| 292 |
+
api_name="generate_music"
|
| 293 |
+
)
|
| 294 |
|
| 295 |
app = gr.mount_gradio_app(fastapi_app, demo, path="/")
|
| 296 |
if __name__ == "__main__":
|