Spaces:
Sleeping
Sleeping
File size: 13,389 Bytes
ab56793 b8a90d2 ab56793 b8a90d2 2800a6c b8a90d2 ab56793 b8a90d2 ab56793 2800a6c ab56793 834749c ab56793 834749c ab56793 834749c ab56793 b8a90d2 ab56793 2800a6c ab56793 b8a90d2 ab56793 2800a6c ab56793 2800a6c b8a90d2 2800a6c b8a90d2 2800a6c 834749c 2800a6c 834749c b8a90d2 834749c 2800a6c b8a90d2 2800a6c b8a90d2 ab56793 b8a90d2 ab56793 b8a90d2 ab56793 2800a6c 834749c b8a90d2 ab56793 b8a90d2 ab56793 b8a90d2 ab56793 b8a90d2 2800a6c b8a90d2 834749c 2800a6c 834749c 2800a6c b8a90d2 ab56793 b8a90d2 ab56793 2800a6c b8a90d2 2800a6c b8a90d2 ab56793 b8a90d2 ab56793 2800a6c b8a90d2 ab56793 2800a6c b8a90d2 ab56793 2800a6c ab56793 2800a6c b8a90d2 ab56793 2800a6c b8a90d2 ab56793 2800a6c b8a90d2 2800a6c b8a90d2 ab56793 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 | import os
import time
import base64
import json
import uuid
import asyncio
import traceback
import torch
import spaces
import gradio as gr
import numpy as np
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Union, Any
from gradio_client import Client
# --- 1. Pre-download the model at startup ------------------------------------
# Warm the Hugging Face cache before the app starts serving, so the first
# generation does not wait on the download. Best-effort: a failure is only
# printed and startup continues.
# NOTE(review): this fetches "ACE-Step/ACE-Step-v1-3.5B" while MODEL_ID
# advertises "acemusic/acestep-v15-turbo" -- confirm which checkpoint
# ACEStepPipeline(checkpoint_dir=None) actually loads.
from huggingface_hub import snapshot_download


def _prefetch_model(repo_id="ACE-Step/ACE-Step-v1-3.5B"):
    """Download *repo_id* into the local HF cache, printing progress.

    Any exception is caught and printed, never raised.
    """
    print("==== 正在預先下載模型到快取 ====")
    try:
        snapshot_download(repo_id)
        print("==== 模型下載完成! ====")
    except Exception as err:
        print(f"模型下載失敗: {err}")


_prefetch_model()
# --- Environment / module state ---------------------------------------------
# Optional bearer token for the HTTP API. When it is unset (None), check_auth
# lets every request through.
API_KEY = os.environ.get("API_KEY")
# Model identifier reported by /v1/models and in completion responses.
MODEL_ID = "acemusic/acestep-v15-turbo"
# Lazily created ACE-Step pipeline instance; populated by get_pipeline().
pipeline = None
def get_pipeline():
    """Return the module-wide ACEStepPipeline, constructing it on first use.

    The heavy ``acestep`` import is deferred until the first call. The instance
    is cached in the module global ``pipeline``. There is no locking, so two
    concurrent first calls could each build a pipeline.

    Returns:
        The cached ``ACEStepPipeline`` (bfloat16, on "cuda").
    """
    global pipeline
    if pipeline is not None:
        return pipeline

    print("初始化 ACE-Step Pipeline...")
    from acestep.pipeline_ace_step import ACEStepPipeline

    pipeline = ACEStepPipeline(checkpoint_dir=None, dtype="bfloat16", device="cuda")
    return pipeline
def check_auth(request: Request):
    """Validate the ``Authorization: Bearer <key>`` header of a request.

    Authentication is disabled when ``API_KEY`` is None; every request then
    passes.

    Args:
        request: The incoming FastAPI request.

    Returns:
        True when the request is authorised.

    Raises:
        HTTPException: 401 when the header is missing or not a Bearer
            credential, or when the key does not match ``API_KEY``.
    """
    import secrets

    if API_KEY is None:
        return True
    auth = request.headers.get("Authorization", "")
    # The auth-scheme token is case-insensitive (RFC 7235 section 2.1).
    if auth[:7].lower() != "bearer ":
        raise HTTPException(status_code=401, detail="缺少或格式錯誤的 API 金鑰")
    token = auth[7:]
    # Constant-time comparison: a plain `!=` can leak through response timing
    # how long a correct prefix of the key the caller guessed. Bytes are used
    # because compare_digest rejects non-ASCII str arguments.
    if not secrets.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="無效的 API 金鑰")
    return True
class AudioConfig(BaseModel):
    """Optional audio settings attached to a chat completion request.

    chat_completions forwards ``duration`` (30 s when unset), ``bpm``,
    ``vocal_language`` and ``instrumental`` to the Gradio generation endpoint.
    ``format``, ``key_scale`` and ``time_signature`` are accepted but are not
    read by this module.
    """

    duration: Optional[float] = None  # seconds; None -> 30.0 is used
    bpm: Optional[int] = None  # None (or 0) -> no BPM is sent
    vocal_language: str = "en"  # "zh"/"ja"/"ko" select those UI options; anything else -> English
    instrumental: Optional[bool] = None  # True/False/None map to the UI's yes/no/auto choices
    format: str = "mp3"  # not read by this module
    key_scale: Optional[str] = None  # not read by this module
    time_signature: Optional[str] = None  # not read by this module
class Message(BaseModel):
    """A single chat message in an OpenAI-style request."""

    role: str  # parse_input only considers messages whose role is "user"
    # Either plain text, or a list of content parts. parse_input uses the first
    # dict part with "type": "text".
    content: Union[str, List[Any]]
class ChatCompletionRequest(BaseModel):
    """OpenAI-style ``/v1/chat/completions`` body, extended with music options.

    In this module, parse_input reads ``messages``, ``lyrics`` and
    ``sample_mode``. chat_completions reads ``stream``, ``audio_config``,
    ``seed`` and ``guidance_scale``. The remaining fields are accepted but not
    referenced here.
    """

    model: Optional[str] = "auto"  # not read; responses always report MODEL_ID
    messages: List[Message]  # the most recent "user" message is parsed
    stream: bool = False  # True -> reply as server-sent events
    audio_config: Optional[AudioConfig] = None  # None -> AudioConfig() defaults
    temperature: float = 0.85  # not read by this module
    top_p: float = 0.9  # not read by this module
    seed: Optional[Union[int, str]] = None  # for strings, the first comma-separated value is used
    lyrics: Optional[str] = ""  # explicit lyrics; win over <lyrics>...</lyrics> in the message
    sample_mode: bool = False  # True (without a <prompt> tag) -> whole message is the prompt
    # The fields below are not read by this module (presumably accepted for
    # client compatibility -- TODO confirm intent).
    thinking: bool = False
    use_format: bool = False
    use_cot_caption: bool = True
    use_cot_language: bool = True
    guidance_scale: float = 7.0  # read: forwarded to generation as the CFG scale
    batch_size: int = 1
    task_type: str = "text2music"
    repainting_start: float = 0.0
    repainting_end: Optional[float] = None
    audio_cover_strength: float = 1.0
# Lower-cased section markers whose presence means a bare message is lyrics.
_LYRIC_SECTION_TAGS = ("[verse", "[chorus", "[bridge")


def _last_user_text(messages) -> str:
    """Return the text of the most recent "user" message ("" if none).

    For list-style content, the first dict part with ``"type": "text"`` is used.
    """
    for msg in reversed(messages):
        if msg.role != "user":
            continue
        if isinstance(msg.content, str):
            return msg.content
        if isinstance(msg.content, list):
            for item in msg.content:
                if isinstance(item, dict) and item.get("type") == "text":
                    # `or ""` guards against an explicit {"text": None}.
                    return item.get("text") or ""
        return ""
    return ""


def parse_input(req: "ChatCompletionRequest"):
    """Split a chat request into a style prompt and lyrics.

    Rules, applied to the latest user message:
      1. If it contains ``<prompt>...</prompt>`` and/or ``<lyrics>...</lyrics>``
         tags, they are extracted. ``req.lyrics``, when non-empty, wins over
         the ``<lyrics>`` tag.
      2. Otherwise, if ``req.lyrics`` is set or ``req.sample_mode`` is true,
         the whole message is the prompt.
      3. Otherwise, a message containing a section marker ([Verse, [Chorus,
         [Bridge -- any case) is treated as lyrics.
      4. Otherwise the whole message is the prompt.

    Args:
        req: Request exposing ``messages`` (items with ``role``/``content``),
            ``lyrics`` and ``sample_mode``.

    Returns:
        ``(prompt, lyrics)`` strings, either of which may be "".
    """
    import re

    last_user_msg = _last_user_text(req.messages)
    prompt = ""
    lyrics = req.lyrics or ""
    # Either tag selects the structured form; previously a message carrying
    # only <lyrics> fell through and kept the raw tags in its output.
    if "<prompt>" in last_user_msg or "<lyrics>" in last_user_msg:
        p_match = re.search(r"<prompt>(.*?)</prompt>", last_user_msg, re.DOTALL)
        l_match = re.search(r"<lyrics>(.*?)</lyrics>", last_user_msg, re.DOTALL)
        prompt = p_match.group(1).strip() if p_match else ""
        if not lyrics and l_match:
            lyrics = l_match.group(1).strip()
    elif lyrics or req.sample_mode:
        prompt = last_user_msg
    elif any(tag in last_user_msg.lower() for tag in _LYRIC_SECTION_TAGS):
        lyrics = last_user_msg
    else:
        prompt = last_user_msg
    return prompt, lyrics
# --- Core generation function (ZeroGPU) --------------------------------------
#
# Fix for "TypeError: argument of type 'bool' is not iterable".
#
# Root cause: when gr.Number(value=None) uses None as its default, gradio emits
# a JSON Schema containing "additionalProperties": False (a bool). gradio_client's
# get_type() then evaluates `"const" in schema` on that bool, which raises.
#
# Workaround:
#   1. Every optional numeric field is a gr.Textbox instead of gr.Number, so
#      gr.Number(value=None) never produces the problematic schema.
#   2. The strings are parsed by hand (helpers below).
#


def _parse_opt_int(text):
    """Parse an optional integer Textbox value; blank or None -> None."""
    return int(text.strip()) if text and text.strip() else None


def _parse_float(text, default):
    """Parse a float Textbox value, returning *default* when it is blank."""
    return float(text.strip()) if text and text.strip() else default


def _to_gradio_audio(result):
    """Convert a pipeline result into an output value for gr.Audio.

    The pipeline's return type is not visible in this file. This accepts an
    object with ``.audio`` (and optionally ``.sample_rate``), a tuple/list
    whose first element is the audio, or a bare tensor/array. If the audio
    turns out to be a file path, it is returned as-is, since gr.Audio also
    accepts paths as output.

    Returns:
        ``(sample_rate, ndarray)`` for gr.Audio(type="numpy"), or a path.
    """
    if hasattr(result, "audio"):
        audio_data = result.audio
    elif isinstance(result, (tuple, list)):
        audio_data = result[0]
    else:
        audio_data = result
    if isinstance(audio_data, (str, os.PathLike)):
        return audio_data
    sample_rate = getattr(result, "sample_rate", 44100)
    if isinstance(audio_data, torch.Tensor):
        # Tensor.numpy() rejects bfloat16 (the dtype the pipeline is built
        # with), so upcast to float32 before leaving torch.
        audio_data = audio_data.detach().float().cpu().numpy()
    audio_data = np.asarray(audio_data)
    if audio_data.ndim > 1:
        audio_data = audio_data.squeeze()
    # gr.Audio expects (samples,) or (samples, channels); transpose an array
    # that looks channel-first, i.e. (channels, samples).
    if audio_data.ndim == 2 and audio_data.shape[0] < audio_data.shape[1]:
        audio_data = audio_data.T
    return (sample_rate, audio_data)


@spaces.GPU(duration=120)
def gradio_generate(
    prompt: str,  # gr.Textbox
    lyrics: str,  # gr.Textbox
    duration: str,  # gr.Textbox (instead of gr.Number; avoids additionalProperties: False)
    bpm: str,  # gr.Textbox (instead of gr.Number)
    vocal_language: str,  # gr.Dropdown
    instrumental: str,  # gr.Dropdown: "是" (yes) / "否" (no) / "自動判定" (auto)
    guidance_scale: str,  # gr.Textbox (instead of gr.Slider, for compatibility)
    seed: str,  # gr.Textbox (instead of gr.Number)
):
    """Generate music from the UI/API inputs on a ZeroGPU worker.

    Numeric inputs arrive as strings; blanks fall back to duration 30.0 s,
    guidance 7.0, and no BPM/seed.

    Returns:
        ``(sample_rate, ndarray)`` for gr.Audio, or a file path if the
        pipeline produced one.

    Raises:
        gr.Error: Wrapping any failure, so the message reaches the UI/API caller.
    """
    try:
        seed_val = _parse_opt_int(seed)
        bpm_val = _parse_opt_int(bpm)
        dur_val = _parse_float(duration, 30.0)
        cfg_val = _parse_float(guidance_scale, 7.0)
        lang_map = {"英文 (en)": "en", "中文 (zh)": "zh", "日文 (ja)": "ja", "韓文 (ko)": "ko", "自動判定": "en"}
        lang = lang_map.get(vocal_language, "en") if vocal_language else "en"
        pipe = get_pipeline()
        # Instrumental-only is expressed through the prompt text, not a bool kwarg.
        final_prompt = prompt if prompt else "instrumental music"
        if instrumental == "是":
            final_prompt = final_prompt + ", purely instrumental, no vocals"
        # NOTE(review): confirm these keyword names (infer_steps, bpm,
        # vocal_language, seed) against the installed ACEStepPipeline.__call__.
        gen_kwargs = dict(
            prompt=final_prompt,
            lyrics=lyrics or "",
            audio_duration=dur_val,
            guidance_scale=cfg_val,
            infer_steps=27,
            scheduler_type="euler",
        )
        if bpm_val:
            gen_kwargs["bpm"] = bpm_val
        if lang:
            gen_kwargs["vocal_language"] = lang
        if seed_val is not None:
            gen_kwargs["seed"] = seed_val
        result = pipe(**gen_kwargs)
        return _to_gradio_audio(result)
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"生成失敗: {str(e)}") from e
# --- FastAPI -----------------------------------------------------------------
fastapi_app = FastAPI(title="ACE-Step API")
# Wide-open CORS: browsers on any origin may call the API with any method and
# any header (including Authorization).
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@fastapi_app.get("/health")
async def health():
    """Liveness probe; unauthenticated and always ``{"status": "ok"}``."""
    payload = dict(status="ok")
    return payload
@fastapi_app.get("/v1/models")
async def list_models(request: Request):
    """List the single model served here, in OpenAI ``/v1/models`` shape.

    Raises:
        HTTPException: 401 from check_auth when the API key is wrong or missing.
    """
    check_auth(request)
    model_card = {
        "id": MODEL_ID,
        "name": "ACE-Step v1.5",
        "created": 1706688000,
        "pricing": {"prompt": "0", "completion": "0", "request": "0"},
    }
    return {"data": [model_card]}
# File extension -> MIME type for the inline data: URL (fallback: audio/wav).
_AUDIO_MIME_TYPES = {
    ".wav": "audio/wav",
    ".mp3": "audio/mpeg",
    ".flac": "audio/flac",
    ".ogg": "audio/ogg",
    ".m4a": "audio/mp4",
}


def _parse_seed(raw):
    """Normalise a request ``seed`` (int or str) to an int, or None.

    For strings, only the first comma-separated value is used ("42, 7" -> 42),
    and a blank string means "no seed".

    Raises:
        HTTPException: 400 when the value is not an integer.
    """
    if raw is None:
        return None
    first = str(raw).split(",")[0].strip()
    if not first:
        return None
    try:
        return int(first)
    except ValueError:
        raise HTTPException(status_code=400, detail=f"無效的 seed: {raw}") from None


@fastapi_app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, request: Request):
    """Generate music for an OpenAI-style chat completion request.

    The request is split into prompt/lyrics by parse_input and sent to the
    Gradio ``/generate_music`` endpoint. The resulting audio is returned inline
    as a base64 ``data:`` URL in ``choices[0].message.audio``. When
    ``req.stream`` is true, the same payload is emitted as server-sent events.

    Raises:
        HTTPException: 401 from check_auth, 400 for a malformed seed, and 500
            for any other failure.
    """
    check_auth(request)
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
    created_ts = int(time.time())
    try:
        prompt, lyrics = parse_input(req)
        audio_cfg = req.audio_config or AudioConfig()
        seed_val = _parse_seed(req.seed)

        def _generate():
            """Blocking: call the Gradio API, read the audio file, delete it.

            Returns:
                ``(audio_bytes, mime_type)``.
            """
            # Presumably this process itself (the __main__ entry serves the app
            # on port 7860) -- keep the URL in sync with that port.
            client = Client("http://127.0.0.1:7860/")
            ui_lang = {"zh": "中文 (zh)", "ja": "日文 (ja)", "ko": "韓文 (ko)"}.get(audio_cfg.vocal_language, "英文 (en)")
            ui_instr = "是" if audio_cfg.instrumental is True else ("否" if audio_cfg.instrumental is False else "自動判定")
            # Every argument is a string, matching the UI's Textbox inputs.
            audio_path = client.predict(
                prompt,
                lyrics,
                str(audio_cfg.duration or 30.0),
                str(audio_cfg.bpm) if audio_cfg.bpm else "",
                ui_lang,
                ui_instr,
                str(req.guidance_scale),
                str(seed_val) if seed_val is not None else "",
                api_name="/generate_music",
            )
            try:
                with open(audio_path, "rb") as f:
                    data = f.read()
            finally:
                # Best-effort cleanup of the temporary file, even if the read
                # failed; a failed delete must not fail the request.
                try:
                    os.remove(audio_path)
                except OSError:
                    pass
            ext = os.path.splitext(str(audio_path))[1].lower()
            return data, _AUDIO_MIME_TYPES.get(ext, "audio/wav")

        # Generation and file I/O both block, so keep them off the event loop.
        audio_bytes, mime_type = await asyncio.to_thread(_generate)
        b64 = base64.b64encode(audio_bytes).decode("utf-8")
        audio_url = f"data:{mime_type};base64,{b64}"

        content_text = f"## 生成中繼資料\n**風格:** {prompt}\n**時長:** {audio_cfg.duration or 30}s\n"
        if lyrics:
            content_text += f"\n## 歌詞\n{lyrics}"
        response = {
            "id": completion_id, "object": "chat.completion", "created": created_ts, "model": MODEL_ID,
            "choices": [{"index": 0, "message": {"role": "assistant", "content": content_text, "audio": [{"type": "audio_url", "audio_url": {"url": audio_url}}]}, "finish_reason": "stop"}],
            # Token counts are placeholders; no tokens are actually counted.
            "usage": {"prompt_tokens": len(prompt.split()), "completion_tokens": 100, "total_tokens": len(prompt.split()) + 100},
        }
        if req.stream:
            async def event_stream():
                for chunk in [
                    {"delta": {"role": "assistant", "content": ""}},
                    {"delta": {"content": content_text}},
                    {"delta": {"audio": [{"type": "audio_url", "audio_url": {"url": audio_url}}]}},
                    {"delta": {}, "finish_reason": "stop"},
                ]:
                    chunk_data = {"id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": MODEL_ID, "choices": [{"index": 0, **chunk}]}
                    yield f"data: {json.dumps(chunk_data)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(event_stream(), media_type="text/event-stream")
        return JSONResponse(response)
    except HTTPException:
        # Client errors raised above (e.g. a malformed seed) keep their status.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"內部伺服器錯誤: {str(e)}") from e
# --- Gradio Web UI -------------------------------------------------------------
# Every optional numeric field is a gr.Textbox. gr.Number(value=None) is avoided
# because it yields a JSON Schema that gradio_client fails to parse.
with gr.Blocks(title="🎵 ACE-Step v1.5 音樂生成器", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 style='text-align: center;'>🎵 ACE-Step v1.5 音樂生成器</h1>")
    with gr.Tab("🎼 生成音樂"):
        with gr.Row():
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    label="🏷️ 音樂風格描述 (Prompt)",
                    placeholder="例如:節奏強烈的 EDM、重低音與合成器",
                )
                lyrics_input = gr.Textbox(label="📜 歌詞 (Lyrics,可選填)", lines=4)
                with gr.Row():
                    # Textboxes rather than gr.Number(value=None); see note above.
                    duration_input = gr.Textbox(label="⏱️ 時長 (秒)", value="30", placeholder="30")
                    bpm_input = gr.Textbox(label="🥁 BPM (可選)", value="", placeholder="例:120")
                with gr.Row():
                    lang_input = gr.Dropdown(
                        label="🌍 語言",
                        choices=["英文 (en)", "中文 (zh)", "日文 (ja)", "韓文 (ko)", "自動判定"],
                        value="英文 (en)",
                    )
                    instr_input = gr.Dropdown(
                        label="🎸 純伴奏",
                        choices=["自動判定", "是", "否"],
                        value="自動判定",
                    )
                with gr.Row():
                    cfg_input = gr.Textbox(label="🎚️ Guidance Scale", value="7.0", placeholder="7.0")
                    seed_input = gr.Textbox(label="🎲 Seed (可選)", value="", placeholder="例:42")
                generate_btn = gr.Button("🚀 開始生成音樂", variant="primary")
            with gr.Column(scale=1):
                audio_output = gr.Audio(label="🎵 生成結果", type="numpy")

    # Both triggers feed gradio_generate the same inputs, in its parameter order.
    generation_inputs = (
        prompt_input, lyrics_input, duration_input, bpm_input,
        lang_input, instr_input, cfg_input, seed_input,
    )
    generate_btn.click(
        fn=gradio_generate,
        inputs=list(generation_inputs),
        outputs=audio_output,
    )
    # Hidden button whose only purpose is to publish the named API route
    # "generate_music", which chat_completions calls through gradio_client.
    api_btn = gr.Button("API", visible=False)
    api_btn.click(
        fn=gradio_generate,
        inputs=list(generation_inputs),
        outputs=audio_output,
        api_name="generate_music",
    )
# Mount the Gradio UI at "/" on the FastAPI app, so the REST routes and the UI
# are served by the same ASGI application.
app = gr.mount_gradio_app(fastapi_app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    # chat_completions reaches the Gradio API at http://127.0.0.1:7860/, so
    # this port must stay in sync with that URL.
    _host, _port = "0.0.0.0", 7860
    uvicorn.run(app, host=_host, port=_port)
|