import os
import re
import time
import base64
import json
import uuid
import asyncio
import secrets
import traceback

import torch
import spaces
import gradio as gr
import numpy as np
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Union, Any
from gradio_client import Client

# ─── 1. Pre-download the model weights into the HF cache at startup ───
from huggingface_hub import snapshot_download

print("==== 正在預先下載模型到快取 ====")
try:
    snapshot_download("ACE-Step/ACE-Step-v1-3.5B")
    print("==== 模型下載完成! ====")
except Exception as e:
    # Best-effort: startup continues even if the pre-download fails.
    # NOTE(review): whether ACEStepPipeline(checkpoint_dir=None) can fetch the
    # weights itself later is not verified here.
    print(f"模型下載失敗: {e}")

# ─── Environment / configuration ──────────────────────────────────────
# An unset OR empty API_KEY disables authentication. Previously an empty value
# meant "any 'Bearer ' header passes", which was auth in name only.
API_KEY = os.environ.get("API_KEY") or None
# NOTE(review): the advertised id is v1.5-turbo, while the weights downloaded
# above are ACE-Step v1 3.5B — confirm which model is actually served.
MODEL_ID = "acemusic/acestep-v15-turbo"

pipeline = None


def get_pipeline():
    """Return the process-wide ACE-Step pipeline, constructing it on first call."""
    global pipeline
    if pipeline is None:
        print("初始化 ACE-Step Pipeline...")
        # Imported lazily so the heavy package loads only when generation is needed.
        from acestep.pipeline_ace_step import ACEStepPipeline

        pipeline = ACEStepPipeline(
            checkpoint_dir=None,
            dtype="bfloat16",
            device="cuda",
        )
    return pipeline


def check_auth(request: Request) -> bool:
    """Validate the ``Authorization: Bearer <key>`` header against ``API_KEY``.

    Returns:
        True when auth is disabled (no API_KEY) or the key matches.

    Raises:
        HTTPException(401): header missing/malformed, or key mismatch.
    """
    if API_KEY is None:
        return True
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="缺少或格式錯誤的 API 金鑰")
    token = auth.split(" ", 1)[1]
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched; bytes avoid compare_digest's ASCII-only str restriction.
    if not secrets.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="無效的 API 金鑰")
    return True


class AudioConfig(BaseModel):
    """Optional audio settings carried in a chat-completion request."""

    duration: Optional[float] = None  # seconds; generation falls back to 30 when unset
    bpm: Optional[int] = None
    vocal_language: str = "en"
    instrumental: Optional[bool] = None  # None = let the model decide
    format: str = "mp3"
    key_scale: Optional[str] = None
    time_signature: Optional[str] = None


class Message(BaseModel):
    """One chat message; content is plain text or a list of typed parts."""

    role: str
    content: Union[str, List[Any]]


class ChatCompletionRequest(BaseModel):
    """OpenAI-style chat-completion request extended with music options."""

    model: Optional[str] = "auto"
    messages: List[Message]
    stream: bool = False
    audio_config: Optional[AudioConfig] = None
    temperature: float = 0.85
    top_p: float = 0.9
    seed: Optional[Union[int, str]] = None  # may be "a,b,c"; only the first is used
    lyrics: Optional[str] = ""
    sample_mode: bool = False
    thinking: bool = False
    use_format: bool = False
    use_cot_caption: bool = True
    use_cot_language: bool = True
    guidance_scale: float = 7.0
    batch_size: int = 1
    task_type: str = "text2music"
    repainting_start: float = 0.0
    repainting_end: Optional[float] = None
    audio_cover_strength: float = 1.0


# Tagged input format: "<prompt>style…</prompt><lyrics>…</lyrics>".
_PROMPT_TAG_RE = re.compile(r"<prompt>(.*?)</prompt>", re.DOTALL)
_LYRICS_TAG_RE = re.compile(r"<lyrics>(.*?)</lyrics>", re.DOTALL)
# Lower-cased section markers whose presence marks an untagged message as lyrics.
_LYRIC_SECTION_MARKERS = ("[verse", "[chorus", "[bridge")


def _last_user_text(messages: List[Message]) -> str:
    """Return the text of the most recent user message ("" if none).

    For list content, the first ``{"type": "text"}`` part is used.
    """
    for msg in reversed(messages):
        if msg.role != "user":
            continue
        if isinstance(msg.content, str):
            return msg.content
        for item in msg.content:
            if isinstance(item, dict) and item.get("type") == "text":
                return item.get("text") or ""
        return ""
    return ""


def parse_input(req: ChatCompletionRequest):
    """Split the last user message into a style prompt and lyrics.

    Resolution order:
      1. A ``<prompt>``/``<lyrics>`` tagged message: take the tag contents.
         ``req.lyrics``, if given, wins over the ``<lyrics>`` tag.
      2. Explicit ``req.lyrics`` or ``sample_mode``: the whole message is the prompt.
      3. A message containing [Verse/[Chorus/[Bridge markers (any case): it is lyrics.
      4. Otherwise the whole message is the prompt.

    Returns:
        (prompt, lyrics) as strings.
    """
    last_user_msg = _last_user_text(req.messages)
    prompt = ""
    lyrics = req.lyrics or ""

    # Previously this tested `"" in last_user_msg` (always true) with regexes
    # that matched the empty string, so the prompt was always empty.
    if "<prompt>" in last_user_msg or "<lyrics>" in last_user_msg:
        p_match = _PROMPT_TAG_RE.search(last_user_msg)
        l_match = _LYRICS_TAG_RE.search(last_user_msg)
        prompt = p_match.group(1).strip() if p_match else ""
        if not lyrics and l_match:
            lyrics = l_match.group(1).strip()
    elif lyrics or req.sample_mode:
        prompt = last_user_msg
    elif any(marker in last_user_msg.lower() for marker in _LYRIC_SECTION_MARKERS):
        lyrics = last_user_msg
    else:
        prompt = last_user_msg
    return prompt, lyrics
# ─── Core generation function (ZeroGPU) ───────────────────────────────
#
# Fix for "TypeError: argument of type 'bool' is not iterable":
# gr.Number(value=None) makes Gradio emit a JSON schema containing
# "additionalProperties": False (a bool). gradio_client's get_type() then
# evaluates `"const" in schema` on that bool and raises.
#
# Strategy:
#   1. Every optional numeric field is a gr.Textbox instead of gr.Number,
#      so the problematic schema is never generated.
#   2. The strings are parsed manually inside the function.

# UI language label -> vocal_language code; "自動判定" (auto) falls back to "en".
_LANG_LABEL_TO_CODE = {
    "英文 (en)": "en",
    "中文 (zh)": "zh",
    "日文 (ja)": "ja",
    "韓文 (ko)": "ko",
    "自動判定": "en",
}


def _parse_optional(text, cast, default=None):
    """Return ``cast(text.strip())``, or *default* when *text* is None or blank."""
    if text is None:
        return default
    stripped = str(text).strip()
    return cast(stripped) if stripped else default


def _fit_kwargs_to_pipeline(pipe, kwargs):
    """Adapt generation kwargs to the parameters ``pipe(...)`` actually accepts.

    Renames ``infer_steps`` -> ``infer_step`` and ``seed`` -> ``manual_seeds=[seed]``
    when only the latter names exist, and drops any remaining unknown keys (with
    a log line) instead of failing with "unexpected keyword argument".
    kwargs are returned unchanged if the signature cannot be inspected or
    accepts ``**kwargs``.
    """
    import inspect

    try:
        params = inspect.signature(pipe).parameters
    except (TypeError, ValueError):
        return kwargs
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs

    fitted = dict(kwargs)
    # NOTE(review): ACE-Step v1's ACEStepPipeline.__call__ appears to use
    # `infer_step` / `manual_seeds` — confirm against the installed version.
    if "infer_steps" in fitted and "infer_steps" not in params and "infer_step" in params:
        fitted["infer_step"] = fitted.pop("infer_steps")
    if "seed" in fitted and "seed" not in params and "manual_seeds" in params:
        fitted["manual_seeds"] = [fitted.pop("seed")]

    dropped = [key for key in fitted if key not in params]
    for key in dropped:
        del fitted[key]
    if dropped:
        print(f"[gradio_generate] pipeline does not accept {dropped}; ignoring them")
    return fitted


def _to_audio_output(result):
    """Convert a pipeline result into a value gr.Audio can display.

    File paths (a str/PathLike, or a list/tuple whose first item is one) are
    returned as a path string. Otherwise the result is returned as
    ``(sample_rate, ndarray)`` shaped (samples,) or (samples, channels).
    """
    if isinstance(result, (str, os.PathLike)):
        return str(result)
    if isinstance(result, (list, tuple)) and result and isinstance(result[0], (str, os.PathLike)):
        return str(result[0])

    sample_rate = getattr(result, "sample_rate", 44100)
    if hasattr(result, "audio"):
        audio = result.audio
    elif isinstance(result, tuple):
        audio = result[0]
    else:
        audio = result

    if isinstance(audio, torch.Tensor):
        # .float() first: numpy has no bfloat16, so a bf16 tensor would fail here.
        audio = audio.detach().float().cpu().numpy()
    audio = np.asarray(audio)
    if audio.ndim > 1:
        audio = audio.squeeze()
    # Gradio expects (samples, channels); a 2-D array with fewer rows than
    # columns is assumed to be (channels, samples) and is transposed.
    if audio.ndim == 2 and audio.shape[0] < audio.shape[1]:
        audio = audio.T
    return (sample_rate, audio)


@spaces.GPU(duration=120)
def gradio_generate(
    prompt: str,          # gr.Textbox
    lyrics: str,          # gr.Textbox
    duration: str,        # gr.Textbox (instead of gr.Number; see note above)
    bpm: str,             # gr.Textbox (instead of gr.Number)
    vocal_language: str,  # gr.Dropdown
    instrumental: str,    # gr.Dropdown ("是" yes / "否" no / "自動判定" auto)
    guidance_scale: str,  # gr.Textbox (instead of gr.Slider, for client compatibility)
    seed: str,            # gr.Textbox (instead of gr.Number)
):
    """Generate one music clip; backs both the UI button and ``/generate_music``.

    All numeric inputs arrive as strings. Blank values mean "unset":
    duration defaults to 30.0 s, guidance scale to 7.0, and bpm and seed are omitted.

    Returns:
        A gr.Audio output value: a file path, or ``(sample_rate, ndarray)``.

    Raises:
        gr.Error: any failure during parsing or generation (traceback is printed).
    """
    try:
        seed_val = _parse_optional(seed, int)
        bpm_val = _parse_optional(bpm, int)
        dur_val = _parse_optional(duration, float, 30.0)
        cfg_val = _parse_optional(guidance_scale, float, 7.0)
        lang = _LANG_LABEL_TO_CODE.get(vocal_language, "en") if vocal_language else "en"

        pipe = get_pipeline()

        # Instrumental mode is expressed through prompt text, not a bool argument.
        final_prompt = prompt or "instrumental music"
        if instrumental == "是":
            final_prompt += ", purely instrumental, no vocals"

        gen_kwargs = dict(
            prompt=final_prompt,
            lyrics=lyrics or "",
            audio_duration=dur_val,
            guidance_scale=cfg_val,
            infer_steps=27,
            scheduler_type="euler",
            vocal_language=lang,
        )
        if bpm_val:
            gen_kwargs["bpm"] = bpm_val
        if seed_val is not None:
            gen_kwargs["seed"] = seed_val

        result = pipe(**_fit_kwargs_to_pipeline(pipe, gen_kwargs))
        return _to_audio_output(result)
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"生成失敗: {str(e)}") from e


# ─── FastAPI ──────────────────────────────────────────────────────────
fastapi_app = FastAPI(title="ACE-Step API")
fastapi_app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)


@fastapi_app.get("/health")
async def health():
    """Liveness probe (no auth)."""
    return {"status": "ok"}


@fastapi_app.get("/v1/models")
async def list_models(request: Request):
    """List the single advertised model, OpenAI ``/v1/models`` style."""
    check_auth(request)
    return {"data": [{"id": MODEL_ID, "name": "ACE-Step v1.5", "created": 1706688000,
                      "pricing": {"prompt": "0", "completion": "0", "request": "0"}}]}


# Loopback client to this same server's Gradio API. Generation is routed
# through Gradio so it runs inside the @spaces.GPU function. The client is
# created lazily (the server must already be listening) and then reused,
# rather than constructing a fresh Client, with its own connection setup,
# on every request.
_gradio_client = None

# File extension -> MIME type for the returned data URL; unknown -> audio/wav.
_AUDIO_MIME_BY_EXT = {
    ".wav": "audio/wav",
    ".mp3": "audio/mpeg",
    ".flac": "audio/flac",
    ".ogg": "audio/ogg",
}


def _get_gradio_client():
    """Return the shared loopback gradio_client.Client, creating it on first use."""
    global _gradio_client
    if _gradio_client is None:
        _gradio_client = Client("http://127.0.0.1:7860/")
    return _gradio_client


def _parse_request_seed(seed):
    """Return the first integer of a seed like ``42`` or ``"42,7"``; None if absent/blank.

    Raises:
        HTTPException(400): the seed is not an integer.
    """
    if seed is None:
        return None
    first = str(seed).split(",")[0].strip()
    if not first:
        return None
    try:
        return int(first)
    except ValueError:
        raise HTTPException(status_code=400, detail=f"無效的 seed: {seed!r}") from None


def _read_audio_as_data_url(path):
    """Read the audio file at *path* into a base64 ``data:`` URL, then delete the file.

    Deletion is best-effort: the file is a temporary download made by gradio_client.
    """
    with open(path, "rb") as f:
        audio_bytes = f.read()
    try:
        os.remove(path)
    except OSError:
        pass  # best-effort cleanup; the payload is already in memory
    mime = _AUDIO_MIME_BY_EXT.get(os.path.splitext(path)[1].lower(), "audio/wav")
    b64 = base64.b64encode(audio_bytes).decode("utf-8")
    return f"data:{mime};base64,{b64}"


@fastapi_app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, request: Request):
    """OpenAI-style chat completion that returns generated audio as a data URL.

    Generation runs to completion first. With ``stream=True`` the finished
    result is then emitted as SSE chunks: role, text, audio, stop, [DONE].

    Raises:
        HTTPException(401): auth failure.
        HTTPException(400): invalid seed.
        HTTPException(500): any other failure.
    """
    check_auth(request)
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
    created_ts = int(time.time())
    seed_val = _parse_request_seed(req.seed)

    try:
        prompt, lyrics = parse_input(req)
        audio_cfg = req.audio_config or AudioConfig()
        ui_lang = {"zh": "中文 (zh)", "ja": "日文 (ja)", "ko": "韓文 (ko)"}.get(
            audio_cfg.vocal_language, "英文 (en)")
        ui_instr = ("是" if audio_cfg.instrumental is True
                    else ("否" if audio_cfg.instrumental is False else "自動判定"))

        def _generate_and_encode():
            # Blocking work (HTTP round-trip, file read/delete) stays off the
            # event loop. Every argument is a string, matching the UI Textboxes.
            result_audio_path = _get_gradio_client().predict(
                prompt,
                lyrics,
                str(audio_cfg.duration or 30.0),
                str(audio_cfg.bpm) if audio_cfg.bpm else "",
                ui_lang,
                ui_instr,
                str(req.guidance_scale),
                str(seed_val) if seed_val is not None else "",
                api_name="/generate_music",
            )
            return _read_audio_as_data_url(result_audio_path)

        audio_url = await asyncio.to_thread(_generate_and_encode)

        content_text = f"## 生成中繼資料\n**風格:** {prompt}\n**時長:** {audio_cfg.duration or 30}s\n"
        if lyrics:
            content_text += f"\n## 歌詞\n{lyrics}"

        audio_part = [{"type": "audio_url", "audio_url": {"url": audio_url}}]
        # Token counts are approximate; no tokenizer is involved.
        prompt_tokens = len(prompt.split())
        response = {
            "id": completion_id,
            "object": "chat.completion",
            "created": created_ts,
            "model": MODEL_ID,
            "choices": [{"index": 0,
                         "message": {"role": "assistant", "content": content_text, "audio": audio_part},
                         "finish_reason": "stop"}],
            "usage": {"prompt_tokens": prompt_tokens, "completion_tokens": 100,
                      "total_tokens": prompt_tokens + 100},
        }

        if req.stream:
            async def event_stream():
                for chunk in [
                    {"delta": {"role": "assistant", "content": ""}},
                    {"delta": {"content": content_text}},
                    {"delta": {"audio": audio_part}},
                    {"delta": {}, "finish_reason": "stop"},
                ]:
                    chunk_data = {"id": completion_id, "object": "chat.completion.chunk",
                                  "created": created_ts, "model": MODEL_ID,
                                  "choices": [{"index": 0, **chunk}]}
                    yield f"data: {json.dumps(chunk_data)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(event_stream(), media_type="text/event-stream")
        return JSONResponse(response)
    except HTTPException:
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"內部伺服器錯誤: {str(e)}") from e


# ─── Gradio Web UI ────────────────────────────────────────────────────
# Every optional numeric field is a gr.Textbox; gr.Number(value=None) is never
# used, so the problematic JSON schema is never produced.
with gr.Blocks(title="🎵 ACE-Step v1.5 音樂生成器", theme=gr.themes.Soft()) as demo:
    # Single-line literal: the previous string spanned physical lines (a SyntaxError).
    gr.HTML("<h1 style='text-align: center'>🎵 ACE-Step v1.5 音樂生成器</h1>")
    with gr.Tab("🎼 生成音樂"):
        with gr.Row():
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(label="🏷️ 音樂風格描述 (Prompt)",
                                          placeholder="例如:節奏強烈的 EDM、重低音與合成器")
                lyrics_input = gr.Textbox(label="📜 歌詞 (Lyrics,可選填)", lines=4)
                with gr.Row():
                    duration_input = gr.Textbox(label="⏱️ 時長 (秒)", value="30", placeholder="30")
                    bpm_input = gr.Textbox(label="🥁 BPM (可選)", value="", placeholder="例:120")
                with gr.Row():
                    lang_input = gr.Dropdown(label="🌍 語言",
                                             choices=["英文 (en)", "中文 (zh)", "日文 (ja)", "韓文 (ko)", "自動判定"],
                                             value="英文 (en)")
                    instr_input = gr.Dropdown(label="🎸 純伴奏", choices=["自動判定", "是", "否"], value="自動判定")
                with gr.Row():
                    cfg_input = gr.Textbox(label="🎚️ Guidance Scale", value="7.0", placeholder="7.0")
                    seed_input = gr.Textbox(label="🎲 Seed (可選)", value="", placeholder="例:42")
                generate_btn = gr.Button("🚀 開始生成音樂", variant="primary")
            with gr.Column(scale=1):
                audio_output = gr.Audio(label="🎵 生成結果", type="numpy")

    generate_inputs = [prompt_input, lyrics_input, duration_input, bpm_input,
                       lang_input, instr_input, cfg_input, seed_input]
    generate_btn.click(fn=gradio_generate, inputs=generate_inputs, outputs=audio_output)

    # Hidden button whose only purpose is to expose the stable API route
    # "/generate_music" used by chat_completions.
    api_btn = gr.Button("API", visible=False)
    api_btn.click(fn=gradio_generate, inputs=generate_inputs, outputs=audio_output,
                  api_name="generate_music")

app = gr.mount_gradio_app(fastapi_app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)