# ACEStepv1.5AI / app.py
# Hugging Face Space file by kines9661 — commit 2800a6c ("Upload 3 files").
import asyncio
import base64
import contextlib
import hmac
import inspect
import json
import os
import re
import time
import traceback
import uuid
from typing import Optional, List, Union, Any

import gradio as gr
import numpy as np
import spaces
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from gradio_client import Client
from pydantic import BaseModel
# ─── 1. 在啟動時預先下載模型 ────────────────────────────
from huggingface_hub import snapshot_download
def _prefetch_model_weights():
    """Warm the Hugging Face cache with the ACE-Step weights at import time.

    Best effort: any failure is printed and swallowed so the app still starts.
    NOTE(review): this fetches "ACE-Step/ACE-Step-v1-3.5B", which differs from
    MODEL_ID; confirm ACEStepPipeline(checkpoint_dir=None) loads from this same
    cache, otherwise the prefetch does not save any work.
    """
    print("==== 正在預先下載模型到快取 ====")
    try:
        snapshot_download("ACE-Step/ACE-Step-v1-3.5B")
        print("==== 模型下載完成! ====")
    except Exception as err:
        print(f"模型下載失敗: {err}")


_prefetch_model_weights()
# ─── Environment / configuration ─────────────────────────
# Optional bearer token for the REST API; None (unset) disables auth in check_auth().
API_KEY = os.getenv("API_KEY")
# Model id reported by the OpenAI-compatible endpoints.
MODEL_ID = "acemusic/acestep-v15-turbo"
# Lazily created pipeline singleton; populated on first call to get_pipeline().
pipeline = None
def get_pipeline():
    """Return the shared ACE-Step pipeline, constructing it on first use.

    The acestep import is deferred so the module can load without it; the
    instance is cached in the module-level ``pipeline`` global.
    NOTE(review): confirm the installed ACEStepPipeline honours the ``device``
    keyword.
    """
    global pipeline
    if pipeline is not None:
        return pipeline
    print("初始化 ACE-Step Pipeline...")
    from acestep.pipeline_ace_step import ACEStepPipeline
    pipeline = ACEStepPipeline(checkpoint_dir=None, dtype="bfloat16", device="cuda")
    return pipeline
def check_auth(request: "Request"):
    """Enforce the optional bearer-token API key.

    Returns:
        True when API_KEY is unset (auth disabled), or when the request carries
        ``Authorization: Bearer <API_KEY>``.

    Raises:
        HTTPException(401): the header is missing or not a Bearer header, or the
            token is empty or does not match API_KEY.
    """
    if API_KEY is None:
        return True
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="缺少或格式錯誤的 API 金鑰")
    token = auth.split(" ", 1)[1]
    # Constant-time comparison, so response timing does not reveal how much of
    # the key matched. An empty token is never accepted, even when API_KEY is "".
    if not token or not hmac.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="無效的 API 金鑰")
    return True
class AudioConfig(BaseModel):
    """Optional audio settings attached to a chat-completion request."""

    # Target length in seconds; chat_completions falls back to 30 when None.
    duration: Optional[float] = None
    # Beats per minute; chat_completions sends an empty string when None/0.
    bpm: Optional[int] = None
    # "zh"/"ja"/"ko" select the matching UI language; any other value maps to English.
    vocal_language: str = "en"
    # True -> "是" (yes), False -> "否" (no), None -> "自動判定" (auto) in the UI dropdown.
    instrumental: Optional[bool] = None
    # NOTE(review): accepted but not read anywhere in this file.
    format: str = "mp3"
    # NOTE(review): accepted but not read anywhere in this file.
    key_scale: Optional[str] = None
    # NOTE(review): accepted but not read anywhere in this file.
    time_signature: Optional[str] = None
class Message(BaseModel):
    """One chat message in OpenAI format."""

    # e.g. "user" / "assistant"; parse_input only reads messages whose role is "user".
    role: str
    # Plain text, or a list of content parts; parse_input reads parts shaped
    # like {"type": "text", "text": ...} and ignores the rest.
    content: Union[str, List[Any]]
class ChatCompletionRequest(BaseModel):
    """Body of POST /v1/chat/completions (OpenAI-style, plus music options).

    Only messages, stream, audio_config, seed, lyrics, sample_mode and
    guidance_scale are read in this file; the remaining fields are accepted for
    client compatibility but currently ignored.
    """

    # Ignored: responses always report MODEL_ID.
    model: Optional[str] = "auto"
    # The latest "user" message supplies the prompt and/or lyrics (see parse_input).
    messages: List[Message]
    # True -> reply as a text/event-stream of chat.completion.chunk events.
    stream: bool = False
    audio_config: Optional[AudioConfig] = None
    temperature: float = 0.85  # not read in this file
    top_p: float = 0.9  # not read in this file
    # int, or a string whose first comma-separated value is used.
    seed: Optional[Union[int, str]] = None
    # When non-empty, takes precedence over lyrics found in the message.
    lyrics: Optional[str] = ""
    # When True, the whole message is treated as the prompt (no lyric detection).
    sample_mode: bool = False
    thinking: bool = False  # not read in this file
    use_format: bool = False  # not read in this file
    use_cot_caption: bool = True  # not read in this file
    use_cot_language: bool = True  # not read in this file
    # Forwarded to the generator as the guidance scale.
    guidance_scale: float = 7.0
    batch_size: int = 1  # not read in this file
    task_type: str = "text2music"  # not read in this file
    repainting_start: float = 0.0  # not read in this file
    repainting_end: Optional[float] = None  # not read in this file
    audio_cover_strength: float = 1.0  # not read in this file
_PROMPT_TAG_RE = re.compile(r"<prompt>(.*?)</prompt>", re.DOTALL)
_LYRICS_TAG_RE = re.compile(r"<lyrics>(.*?)</lyrics>", re.DOTALL)
# Song-section markers ("[Verse 1]", "[chorus]", "[Pre-Chorus]", ...) that mark a message as lyrics.
_SECTION_TAG_RE = re.compile(r"\[\s*(?:verse|pre-chorus|chorus|bridge|intro|outro)", re.IGNORECASE)


def _message_text(content):
    """Return the text of a message: the string itself, or the first "text" part of a parts list."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                return item.get("text") or ""
    return ""


def parse_input(req: "ChatCompletionRequest"):
    """Split the latest user message into a ``(prompt, lyrics)`` pair.

    Rules, in order:
      * ``<prompt>...</prompt>`` / ``<lyrics>...</lyrics>`` tags are extracted.
        With only a lyrics tag, the text outside it becomes the prompt.
      * If ``req.lyrics`` is set or ``req.sample_mode`` is on, the whole message
        is the prompt.
      * A message containing section markers such as ``[Verse``/``[Chorus``
        (case-insensitive) is treated as lyrics with an empty prompt.
      * Otherwise the message is the prompt.
    A non-empty ``req.lyrics`` always wins over lyrics found in the message.
    """
    text = next(
        (_message_text(m.content) for m in reversed(req.messages) if m.role == "user"),
        "",
    )
    lyrics = req.lyrics or ""
    l_match = _LYRICS_TAG_RE.search(text)

    if "<prompt>" in text:
        p_match = _PROMPT_TAG_RE.search(text)
        prompt = p_match.group(1).strip() if p_match else ""
        if not lyrics and l_match:
            lyrics = l_match.group(1).strip()
        return prompt, lyrics

    if l_match:
        # Lyrics tag without a prompt tag: the surrounding text is the style prompt.
        if not lyrics:
            lyrics = l_match.group(1).strip()
        prompt = (text[:l_match.start()] + text[l_match.end():]).strip()
        return prompt, lyrics

    if lyrics or req.sample_mode:
        return text, lyrics
    if _SECTION_TAG_RE.search(text):
        return "", text
    return text, lyrics
# ─── Core generation function (ZeroGPU) ──────────────────
#
# Fix for "TypeError: argument of type 'bool' is not iterable".
#
# Root cause: a gr.Number(value=None) input makes Gradio emit a JSON Schema
# containing "additionalProperties": False (a bare boolean). gradio_client's
# get_type() then evaluates `"const" in schema` on that boolean and raises.
#
# Workaround:
#   1. Every optional numeric field is a gr.Textbox instead of gr.Number, so
#      the problematic schema is never generated.
#   2. The strings are parsed by hand in gradio_generate().
#

# UI dropdown label -> language code. "自動判定" (auto-detect) currently maps to English.
_LANG_MAP = {"英文 (en)": "en", "中文 (zh)": "zh", "日文 (ja)": "ja", "韓文 (ko)": "ko", "自動判定": "en"}

# Alternate spellings some pipeline versions use: our name -> (alias, value converter).
# NOTE(review): ACE-Step v1 is believed to use infer_step / manual_seeds (a list); confirm.
_PIPELINE_ARG_ALIASES = {
    "infer_steps": ("infer_step", lambda v: v),
    "seed": ("manual_seeds", lambda v: [v]),
}


def _parse_optional(text, cast, default=None):
    """Return ``cast(text)`` after stripping whitespace, or *default* when blank/None.

    Raises whatever *cast* raises (e.g. ValueError) for malformed text.
    """
    if text is None:
        return default
    text = str(text).strip()
    return cast(text) if text else default


def _build_pipeline_kwargs(pipe, wanted):
    """Adapt the *wanted* keyword arguments to what ``pipe(...)`` actually accepts.

    If the signature cannot be inspected, or it takes ``**kwargs``, *wanted* is
    returned unchanged. Otherwise each argument is passed under its own name, or
    under a known alias (see _PIPELINE_ARG_ALIASES). Arguments the pipeline
    cannot take are dropped and logged, instead of failing every request with
    TypeError.
    """
    try:
        params = inspect.signature(pipe).parameters
    except (TypeError, ValueError):
        return wanted
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return wanted
    kwargs, dropped = {}, []
    for name, value in wanted.items():
        alias = _PIPELINE_ARG_ALIASES.get(name)
        if name in params:
            kwargs[name] = value
        elif alias is not None and alias[0] in params:
            kwargs[alias[0]] = alias[1](value)
        else:
            dropped.append(name)
    if dropped:
        print(f"ACEStepPipeline 不支援以下參數, 已略過: {dropped}")
    return kwargs


def _to_gradio_audio(result):
    """Convert a pipeline result into a value the gr.Audio output can render.

    Returns a file path (str) when the pipeline produced one, otherwise
    ``(sample_rate, numpy_array)`` with stereo laid out as (samples, channels).
    """
    if hasattr(result, "audio"):
        audio_data = result.audio
    elif isinstance(result, (tuple, list)):
        if not result:
            raise ValueError("pipeline returned no audio")
        # NOTE(review): ACE-Step v1 is believed to return [output_path, ..., params_json].
        audio_data = result[0]
    else:
        audio_data = result
    if isinstance(audio_data, (str, os.PathLike)):
        return os.fspath(audio_data)
    sample_rate = getattr(result, "sample_rate", 44100)
    if isinstance(audio_data, torch.Tensor):
        # numpy has no bfloat16 (the pipeline is created with dtype="bfloat16").
        audio_data = audio_data.detach().float().cpu().numpy()
    audio_data = np.asarray(audio_data)
    if audio_data.ndim > 1:
        audio_data = audio_data.squeeze()
    # Gradio expects (samples, channels); a wide 2-D array is assumed to be
    # (channels, samples), the usual torch audio layout.
    if audio_data.ndim == 2 and audio_data.shape[0] < audio_data.shape[1]:
        audio_data = audio_data.T
    return (sample_rate, audio_data)


@spaces.GPU(duration=120)
def gradio_generate(
    prompt: str,          # gr.Textbox
    lyrics: str,          # gr.Textbox
    duration: str,        # gr.Textbox (instead of gr.Number, see note above)
    bpm: str,             # gr.Textbox (instead of gr.Number)
    vocal_language: str,  # gr.Dropdown
    instrumental: str,    # gr.Dropdown ("是" = yes / "否" = no / "自動判定" = auto)
    guidance_scale: str,  # gr.Textbox (instead of gr.Slider, for client compatibility)
    seed: str,            # gr.Textbox (instead of gr.Number)
):
    """Generate one music clip; backs both the UI button and the /generate_music API.

    Numeric inputs arrive as strings. Blank optional fields fall back to a
    duration of 30 s, a guidance scale of 7.0, no BPM and a random seed.

    Returns:
        A file path or ``(sample_rate, numpy_array)`` for the gr.Audio output.

    Raises:
        gr.Error: on any failure (malformed number, pipeline error, ...); the
            original exception is chained as the cause.
    """
    try:
        seed_val = _parse_optional(seed, int)
        bpm_val = _parse_optional(bpm, int)
        dur_val = _parse_optional(duration, float, 30.0)
        cfg_val = _parse_optional(guidance_scale, float, 7.0)
        lang = _LANG_MAP.get(vocal_language, "en") if vocal_language else "en"
        lyrics = lyrics or ""

        # Instrumental mode is expressed through the prompt text, not a boolean
        # argument. The "instrumental music" default is only used when there are
        # no lyrics; otherwise it would contradict the lyrics being sung.
        final_prompt = prompt or ("music" if lyrics.strip() else "instrumental music")
        if instrumental == "是":
            final_prompt += ", purely instrumental, no vocals"

        wanted = dict(
            prompt=final_prompt,
            lyrics=lyrics,
            audio_duration=dur_val,
            guidance_scale=cfg_val,
            infer_steps=27,
            scheduler_type="euler",
            vocal_language=lang,
        )
        if bpm_val:
            wanted["bpm"] = bpm_val
        if seed_val is not None:
            wanted["seed"] = seed_val

        pipe = get_pipeline()
        result = pipe(**_build_pipeline_kwargs(pipe, wanted))
        return _to_gradio_audio(result)
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"生成失敗: {str(e)}") from e
# ─── FastAPI ─────────────────────────────────────────────
fastapi_app = FastAPI(title="ACE-Step API")
# Fully permissive CORS: any origin, method and header may call the REST API.
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@fastapi_app.get("/health")
async def health():
    """Liveness probe: always answers {"status": "ok"} without touching the model."""
    return dict(status="ok")
@fastapi_app.get("/v1/models")
async def list_models(request: Request):
    """OpenAI-style model listing with a single, free-of-charge model entry."""
    check_auth(request)
    free_pricing = {"prompt": "0", "completion": "0", "request": "0"}
    model_card = {
        "id": MODEL_ID,
        "name": "ACE-Step v1.5",
        "created": 1706688000,
        "pricing": free_pricing,
    }
    return {"data": [model_card]}
@fastapi_app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, request: Request):
    """OpenAI-compatible chat completion that returns generated music.

    The latest user message is split into a style prompt and lyrics
    (parse_input). Generation is delegated to this server's own Gradio
    ``/generate_music`` endpoint through gradio_client, and the resulting audio
    file is returned inline as a base64 ``data:`` URL. The reply is a single
    JSON body, or, when ``req.stream`` is true, an SSE stream of
    ``chat.completion.chunk`` events ending with ``data: [DONE]``.

    Raises:
        HTTPException(401): from check_auth.
        HTTPException(400): ``req.seed`` is not an integer.
        HTTPException(500): any other failure during generation.
    """
    check_auth(request)
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
    created_ts = int(time.time())

    # Only the first comma-separated seed is used; an empty value means "random".
    seed_text = str(req.seed).split(",")[0].strip() if req.seed is not None else ""
    try:
        seed_val = int(seed_text) if seed_text else None
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"無效的 seed: {req.seed}") from e

    try:
        prompt, lyrics = parse_input(req)
        audio_cfg = req.audio_config or AudioConfig()

        def _call_gradio():
            # Blocking HTTP round-trip to this server's own Gradio app; run off the event loop.
            client = Client("http://127.0.0.1:7860/")
            ui_lang = {"zh": "中文 (zh)", "ja": "日文 (ja)", "ko": "韓文 (ko)"}.get(audio_cfg.vocal_language, "英文 (en)")
            ui_instr = "是" if audio_cfg.instrumental is True else ("否" if audio_cfg.instrumental is False else "自動判定")
            # Every argument is a string, matching the Textbox-based UI inputs.
            return client.predict(
                prompt,
                lyrics,
                str(audio_cfg.duration or 30.0),
                str(audio_cfg.bpm) if audio_cfg.bpm else "",
                ui_lang,
                ui_instr,
                str(req.guidance_scale),
                str(seed_val) if seed_val is not None else "",
                api_name="/generate_music",
            )

        result_audio_path = await asyncio.to_thread(_call_gradio)
        with open(result_audio_path, "rb") as f:
            audio_bytes = f.read()
        # The downloaded temp file is no longer needed; failing to delete it is harmless.
        with contextlib.suppress(OSError):
            os.remove(result_audio_path)

        # Label the data URL with the real container format (Gradio may return non-wav files).
        ext = os.path.splitext(result_audio_path)[1].lower()
        mime = {".mp3": "audio/mpeg", ".flac": "audio/flac", ".ogg": "audio/ogg"}.get(ext, "audio/wav")
        b64 = base64.b64encode(audio_bytes).decode("utf-8")
        audio_url = f"data:{mime};base64,{b64}"
        audio_part = [{"type": "audio_url", "audio_url": {"url": audio_url}}]

        content_text = f"## 生成中繼資料\n**風格:** {prompt}\n**時長:** {audio_cfg.duration or 30}s\n"
        if lyrics:
            content_text += f"\n## 歌詞\n{lyrics}"

        if req.stream:
            async def event_stream():
                for chunk in [
                    {"delta": {"role": "assistant", "content": ""}},
                    {"delta": {"content": content_text}},
                    {"delta": {"audio": audio_part}},
                    {"delta": {}, "finish_reason": "stop"},
                ]:
                    chunk_data = {"id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": MODEL_ID, "choices": [{"index": 0, **chunk}]}
                    yield f"data: {json.dumps(chunk_data)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(event_stream(), media_type="text/event-stream")

        # Token counts are rough placeholders; audio generation has no real token usage.
        prompt_tokens = len(prompt.split())
        response = {
            "id": completion_id, "object": "chat.completion", "created": created_ts, "model": MODEL_ID,
            "choices": [{"index": 0, "message": {"role": "assistant", "content": content_text, "audio": audio_part}, "finish_reason": "stop"}],
            "usage": {"prompt_tokens": prompt_tokens, "completion_tokens": 100, "total_tokens": prompt_tokens + 100},
        }
        return JSONResponse(response)
    except HTTPException:
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"內部伺服器錯誤: {str(e)}") from e
# ─── Gradio Web UI ───────────────────────────────────────
# Key point: every optional numeric field is a gr.Textbox. gr.Number(value=None)
# is avoided because it generates a JSON Schema that breaks gradio_client.
with gr.Blocks(title="🎵 ACE-Step v1.5 音樂生成器", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 style='text-align: center;'>🎵 ACE-Step v1.5 音樂生成器</h1>")
    with gr.Tab("🎼 生成音樂"):
        with gr.Row():
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    label="🏷️ 音樂風格描述 (Prompt)",
                    placeholder="例如:節奏強烈的 EDM、重低音與合成器",
                )
                lyrics_input = gr.Textbox(label="📜 歌詞 (Lyrics,可選填)", lines=4)
                with gr.Row():
                    # Textboxes, not gr.Number(value=None), for the reason above.
                    duration_input = gr.Textbox(label="⏱️ 時長 (秒)", value="30", placeholder="30")
                    bpm_input = gr.Textbox(label="🥁 BPM (可選)", value="", placeholder="例:120")
                with gr.Row():
                    lang_input = gr.Dropdown(
                        label="🌍 語言",
                        choices=["英文 (en)", "中文 (zh)", "日文 (ja)", "韓文 (ko)", "自動判定"],
                        value="英文 (en)",
                    )
                    instr_input = gr.Dropdown(
                        label="🎸 純伴奏",
                        choices=["自動判定", "是", "否"],
                        value="自動判定",
                    )
                with gr.Row():
                    cfg_input = gr.Textbox(label="🎚️ Guidance Scale", value="7.0", placeholder="7.0")
                    seed_input = gr.Textbox(label="🎲 Seed (可選)", value="", placeholder="例:42")
                generate_btn = gr.Button("🚀 開始生成音樂", variant="primary")
            with gr.Column(scale=1):
                audio_output = gr.Audio(label="🎵 生成結果", type="numpy")

    # Order must match gradio_generate's positional parameters.
    generate_inputs = [
        prompt_input, lyrics_input, duration_input, bpm_input,
        lang_input, instr_input, cfg_input, seed_input,
    ]
    generate_btn.click(fn=gradio_generate, inputs=generate_inputs, outputs=audio_output)

    # Hidden trigger button that exposes the named API route used by chat_completions.
    api_btn = gr.Button("API", visible=False)
    api_btn.click(
        fn=gradio_generate,
        inputs=generate_inputs,
        outputs=audio_output,
        api_name="generate_music",
    )
# FastAPI routes were registered first, so they take precedence over the Gradio app mounted at "/".
app = gr.mount_gradio_app(app=fastapi_app, blocks=demo, path="/")

if __name__ == "__main__":
    import uvicorn

    # Port 7860 must stay in sync with the loopback URL used by chat_completions.
    uvicorn.run(app, port=7860, host="0.0.0.0")