Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import base64 | |
| import json | |
| import uuid | |
| import asyncio | |
| import traceback | |
| import torch | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Union, Any | |
| from gradio_client import Client | |
# ─── 1. Pre-download the model into the local cache at startup ──────────
from huggingface_hub import snapshot_download

print("==== 正在預先下載模型到快取 ====")
try:
    snapshot_download("ACE-Step/ACE-Step-v1-3.5B")
except Exception as e:
    # Best effort: a failed download is only reported, startup carries on.
    print(f"模型下載失敗: {e}")
else:
    print("==== 模型下載完成! ====")
# ─── Environment variables ───────────────────────────────
# Optional bearer token; when it is unset, check_auth() accepts every request.
API_KEY = os.getenv("API_KEY")
# Model id reported in the OpenAI-style API responses (it is not the repo that
# is downloaded at startup).
MODEL_ID = "acemusic/acestep-v15-turbo"
# ACEStepPipeline instance, created lazily by get_pipeline().
pipeline = None
def get_pipeline():
    """Return the shared ACE-Step pipeline, constructing it on the first call.

    The instance is cached in the module-level ``pipeline`` global, so every
    later call returns the same object without rebuilding it.
    """
    global pipeline
    if pipeline is not None:
        return pipeline

    print("初始化 ACE-Step Pipeline...")
    # Imported lazily: the acestep package is only loaded when the pipeline
    # is first built.
    from acestep.pipeline_ace_step import ACEStepPipeline

    # NOTE(review): confirm ACEStepPipeline accepts `device=` for the installed
    # acestep version; unknown keywords may be rejected or silently ignored.
    pipeline = ACEStepPipeline(checkpoint_dir=None, dtype="bfloat16", device="cuda")
    return pipeline
def check_auth(request: "Request"):
    """Validate the ``Authorization: Bearer <key>`` header against ``API_KEY``.

    Authentication is disabled, and every request passes, when ``API_KEY`` is
    ``None``.

    Returns:
        True when the request is authorised.

    Raises:
        HTTPException: 401 when the header is missing, is not a Bearer
            credential, or carries the wrong key.
    """
    import hmac

    if API_KEY is None:
        return True
    auth = request.headers.get("Authorization", "")
    scheme, sep, token = auth.partition(" ")
    # Auth-scheme names are case-insensitive (RFC 7235), so "bearer" is accepted too.
    if not sep or scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="缺少或格式錯誤的 API 金鑰")
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Bytes are compared because compare_digest() rejects
    # non-ASCII str, and the header value is caller-controlled.
    if not hmac.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="無效的 API 金鑰")
    return True
class AudioConfig(BaseModel):
    """Optional audio settings for a chat-completions music request."""
    # NOTE(review): in this file only duration, bpm, vocal_language and
    # instrumental are forwarded to generation; format, key_scale and
    # time_signature are accepted but never read.
    duration: Optional[float] = None  # seconds; chat_completions uses 30.0 when unset or 0
    bpm: Optional[int] = None  # sent as an empty string (i.e. omitted) when unset or 0
    vocal_language: str = "en"  # "zh"/"ja"/"ko" map to those UI choices; anything else -> English
    instrumental: Optional[bool] = None  # True -> "是", False -> "否", None -> "自動判定" (auto) UI choice
    format: str = "mp3"  # not read: the returned file format is whatever the Gradio endpoint produces
    key_scale: Optional[str] = None  # not read
    time_signature: Optional[str] = None  # not read
class Message(BaseModel):
    """One chat message in OpenAI chat-completions format."""
    role: str  # parse_input only considers messages whose role is "user"
    # Plain text, or a list of content parts; parse_input reads the first dict
    # part whose "type" is "text".
    content: Union[str, List[Any]]
class ChatCompletionRequest(BaseModel):
    """OpenAI-style chat-completions request body, extended with music-generation options."""
    # NOTE(review): in this file only messages, stream, audio_config, seed,
    # lyrics, sample_mode and guidance_scale are read; the other fields are
    # accepted but unused.
    model: Optional[str] = "auto"  # not read; responses always report MODEL_ID
    messages: List[Message]  # the latest user message supplies the prompt and/or lyrics
    stream: bool = False  # True -> the finished result is replayed as SSE chunks
    audio_config: Optional[AudioConfig] = None  # None -> AudioConfig() defaults
    temperature: float = 0.85  # not read
    top_p: float = 0.9  # not read
    seed: Optional[Union[int, str]] = None  # int or string; for "a,b" only the first value is used
    lyrics: Optional[str] = ""  # explicit lyrics; win over <lyrics>…</lyrics> in the message
    sample_mode: bool = False  # without a <prompt> tag: treat the whole message as the style prompt
    thinking: bool = False  # not read
    use_format: bool = False  # not read
    use_cot_caption: bool = True  # not read
    use_cot_language: bool = True  # not read
    guidance_scale: float = 7.0  # forwarded (as a string) to the Gradio generate endpoint
    batch_size: int = 1  # not read
    task_type: str = "text2music"  # not read
    repainting_start: float = 0.0  # not read
    repainting_end: Optional[float] = None  # not read
    audio_cover_strength: float = 1.0  # not read
def _last_user_text(messages) -> str:
    """Return the text of the most recent user message, or "" if there is none."""
    for msg in reversed(messages):
        if msg.role != "user":
            continue
        if isinstance(msg.content, str):
            return msg.content
        if isinstance(msg.content, list):
            # Multi-part content: use the first text part; a missing or None
            # "text" value is treated as empty.
            for item in msg.content:
                if isinstance(item, dict) and item.get("type") == "text":
                    return item.get("text") or ""
        # Only the latest user message is considered, even if it has no text.
        return ""
    return ""


def parse_input(req: "ChatCompletionRequest"):
    """Split the latest user message into a style prompt and lyrics.

    Rules, in order:
      1. If the message contains ``<prompt>…</prompt>``, the tag content is the
         prompt. ``<lyrics>…</lyrics>`` supplies the lyrics unless
         ``req.lyrics`` is already set.
      2. Otherwise, if ``req.lyrics`` is set or ``req.sample_mode`` is on, the
         whole message is the prompt.
      3. Otherwise, a message containing a lyric section tag (``[verse``,
         ``[chorus``, ``[bridge``, matched case-insensitively) is treated as
         lyrics; anything else is the prompt.

    Returns:
        ``(prompt, lyrics)``, both strings, either possibly empty.
    """
    import re

    text = _last_user_text(req.messages)
    lyrics = req.lyrics or ""

    if "<prompt>" in text:
        p_match = re.search(r"<prompt>(.*?)</prompt>", text, re.DOTALL)
        l_match = re.search(r"<lyrics>(.*?)</lyrics>", text, re.DOTALL)
        prompt = p_match.group(1).strip() if p_match else ""
        if not lyrics and l_match:
            lyrics = l_match.group(1).strip()
        return prompt, lyrics

    if lyrics or req.sample_mode:
        return text, lyrics

    # Case-insensitive so "[bridge]" / "[VERSE]" are detected like "[Verse]".
    lowered = text.lower()
    if any(tag in lowered for tag in ("[verse", "[chorus", "[bridge")):
        return "", text
    return text, lyrics
# ─── Core generation function (ZeroGPU) ──────────────────
#
# 🚨 Fix for: TypeError: argument of type 'bool' is not iterable
#
# Root cause: when gr.Number(value=None) passes None as its default value,
# gradio generates a JSON Schema containing "additionalProperties": False
# (a boolean). gradio_client's get_type() then evaluates `"const" in schema`
# on that boolean, which raises the TypeError.
#
# Fix strategy:
#   1. Every optional numeric field uses gr.Textbox instead of gr.Number, so
#      gr.Number(value=None) never produces the problematic JSON Schema.
#   2. The strings are parsed manually inside this function.
#
# @spaces.GPU: on ZeroGPU hardware a GPU is attached only for the duration of
# a decorated call, and the pipeline is built with device="cuda". Outside
# ZeroGPU the decorator has no effect.
@spaces.GPU
def gradio_generate(
    prompt: str,          # gr.Textbox
    lyrics: str,          # gr.Textbox
    duration: str,        # gr.Textbox (instead of gr.Number; avoids additionalProperties: False)
    bpm: str,             # gr.Textbox (instead of gr.Number)
    vocal_language: str,  # gr.Dropdown
    instrumental: str,    # gr.Dropdown ("是" = yes / "否" = no / "自動判定" = auto)
    guidance_scale: str,  # gr.Textbox (instead of gr.Slider, for compatibility)
    seed: str,            # gr.Textbox (instead of gr.Number)
):
    """Generate one clip with the ACE-Step pipeline for the Gradio UI / API.

    Numeric inputs arrive as strings and are parsed here. Blank duration and
    guidance fall back to 30.0 and 7.0; blank bpm and seed are omitted.

    Returns:
        ``(sample_rate, samples)`` for ``gr.Audio(type="numpy")``. Samples from
        a torch tensor are cast to float32, and a 2-D array whose first axis is
        shorter than its second is transposed to ``(samples, channels)``. If
        the pipeline yields a file path instead, that path is returned
        unchanged (gr.Audio outputs accept file paths too).

    Raises:
        gr.Error: wrapping any parsing or generation failure.
    """
    try:
        seed_val = int(seed.strip()) if seed and seed.strip() else None
        bpm_val = int(bpm.strip()) if bpm and bpm.strip() else None
        dur_val = float(duration.strip()) if duration and duration.strip() else 30.0
        cfg_val = float(guidance_scale.strip()) if guidance_scale and guidance_scale.strip() else 7.0
        lang_map = {"英文 (en)": "en", "中文 (zh)": "zh", "日文 (ja)": "ja", "韓文 (ko)": "ko", "自動判定": "en"}
        lang = lang_map.get(vocal_language, "en") if vocal_language else "en"
        pipe = get_pipeline()
        # Instrumental mode is expressed through the prompt text instead of
        # passing a bool to the pipeline.
        final_prompt = prompt if prompt else "instrumental music"
        if instrumental == "是":
            final_prompt = final_prompt + ", purely instrumental, no vocals"
        # NOTE(review): confirm these keyword names against
        # ACEStepPipeline.__call__ for the installed acestep version
        # (e.g. infer_steps vs infer_step, seed vs manual_seeds); an unknown
        # keyword raises TypeError here.
        gen_kwargs = dict(
            prompt=final_prompt,
            lyrics=lyrics or "",
            audio_duration=dur_val,
            guidance_scale=cfg_val,
            infer_steps=27,
            scheduler_type="euler",
            vocal_language=lang,
        )
        if bpm_val:
            gen_kwargs["bpm"] = bpm_val
        if seed_val is not None:
            gen_kwargs["seed"] = seed_val
        result = pipe(**gen_kwargs)

        if hasattr(result, "audio"):
            audio_data = result.audio
        elif isinstance(result, (tuple, list)):
            audio_data = result[0]
        else:
            audio_data = result
        sample_rate = getattr(result, "sample_rate", 44100)

        # A saved file can be handed to gr.Audio directly.
        if isinstance(audio_data, (str, os.PathLike)):
            return audio_data
        if isinstance(audio_data, torch.Tensor):
            # NumPy has no bfloat16 dtype, so cast before converting.
            audio_data = audio_data.detach().float().cpu().numpy()
        audio_data = np.asarray(audio_data)
        if audio_data.ndim > 1:
            audio_data = audio_data.squeeze()
        # gr.Audio expects (samples, channels); a channel-first array has the
        # shorter axis first.
        if audio_data.ndim == 2 and audio_data.shape[0] < audio_data.shape[1]:
            audio_data = audio_data.T
        return (sample_rate, audio_data)
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"生成失敗: {str(e)}") from e
# ─── FastAPI ─────────────────────────────────────────────
# Application that serves the JSON API handlers; the Gradio UI is mounted
# onto it at the end of the file.
fastapi_app = FastAPI(title="ACE-Step API")
# Fully open CORS: any origin, method and header is allowed.
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
async def health():
    """Health-check handler; always answers ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
async def list_models(request: Request):
    """List the single model this service exposes, in OpenAI ``{"data": [...]}`` form.

    Raises:
        HTTPException: 401 from ``check_auth`` when the API key is wrong.
    """
    check_auth(request)
    model_card = {
        "id": MODEL_ID,
        "name": "ACE-Step v1.5",
        "created": 1706688000,
        "pricing": {"prompt": "0", "completion": "0", "request": "0"},
    }
    return {"data": [model_card]}
async def chat_completions(req: ChatCompletionRequest, request: Request):
    """OpenAI-style chat completion whose answer is a generated music clip.

    The prompt and lyrics come from ``parse_input``. Generation is delegated to
    this server's Gradio ``/generate_music`` endpoint through ``gradio_client``.
    The resulting file is returned inline as a base64 ``data:`` URL and then
    deleted. With ``req.stream`` the finished payload is replayed as
    server-sent-event chunks.

    Raises:
        HTTPException: 401 from ``check_auth``; 400 for an unparsable seed;
            500 for any other failure.
    """
    check_auth(request)
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
    created_ts = int(time.time())
    try:
        prompt, lyrics = parse_input(req)
        audio_cfg = req.audio_config or AudioConfig()

        seed_val = None
        if req.seed is not None:
            # Ints and strings such as "42" or "42,43" are accepted; only the
            # first value is used, and an empty string means "no seed".
            first_seed = str(req.seed).split(",")[0].strip()
            if first_seed:
                try:
                    seed_val = int(first_seed)
                except ValueError:
                    raise HTTPException(status_code=400, detail=f"無效的 seed: {req.seed!r}") from None

        def _call_gradio():
            # Blocking call; run via asyncio.to_thread so the event loop stays free.
            client = Client("http://127.0.0.1:7860/")
            ui_lang = {"zh": "中文 (zh)", "ja": "日文 (ja)", "ko": "韓文 (ko)"}.get(audio_cfg.vocal_language, "英文 (en)")
            ui_instr = "是" if audio_cfg.instrumental is True else ("否" if audio_cfg.instrumental is False else "自動判定")
            # Every argument is a string to match the UI's Textbox/Dropdown inputs.
            return client.predict(
                prompt,
                lyrics,
                str(audio_cfg.duration or 30.0),
                str(audio_cfg.bpm) if audio_cfg.bpm else "",
                ui_lang,
                ui_instr,
                str(req.guidance_scale),
                str(seed_val) if seed_val is not None else "",
                api_name="/generate_music",
            )

        result_audio_path = await asyncio.to_thread(_call_gradio)
        try:
            with open(result_audio_path, "rb") as f:
                audio_bytes = f.read()
        finally:
            # Best-effort cleanup of the downloaded file, even if reading failed.
            try:
                os.remove(result_audio_path)
            except OSError:
                pass
        b64 = base64.b64encode(audio_bytes).decode("utf-8")
        # Label the data URL with the real container format (wav unless the
        # extension says otherwise).
        ext = os.path.splitext(str(result_audio_path))[1].lower()
        mime = {".mp3": "audio/mpeg", ".flac": "audio/flac", ".ogg": "audio/ogg"}.get(ext, "audio/wav")
        audio_url = f"data:{mime};base64,{b64}"

        content_text = f"## 生成中繼資料\n**風格:** {prompt}\n**時長:** {audio_cfg.duration or 30}s\n"
        if lyrics:
            content_text += f"\n## 歌詞\n{lyrics}"
        response = {
            "id": completion_id, "object": "chat.completion", "created": created_ts, "model": MODEL_ID,
            "choices": [{"index": 0, "message": {"role": "assistant", "content": content_text, "audio": [{"type": "audio_url", "audio_url": {"url": audio_url}}]}, "finish_reason": "stop"}],
            # Token counts are rough placeholders: words in the prompt plus a fixed 100.
            "usage": {"prompt_tokens": len(prompt.split()), "completion_tokens": 100, "total_tokens": len(prompt.split()) + 100},
        }
        if req.stream:
            async def event_stream():
                # The audio is already complete; stream it as role, text,
                # audio and stop chunks.
                for chunk in [
                    {"delta": {"role": "assistant", "content": ""}},
                    {"delta": {"content": content_text}},
                    {"delta": {"audio": [{"type": "audio_url", "audio_url": {"url": audio_url}}]}},
                    {"delta": {}, "finish_reason": "stop"},
                ]:
                    chunk_data = {"id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": MODEL_ID, "choices": [{"index": 0, **chunk}]}
                    yield f"data: {json.dumps(chunk_data)}\n\n"
                yield "data: [DONE]\n\n"
            return StreamingResponse(event_stream(), media_type="text/event-stream")
        return JSONResponse(response)
    except HTTPException:
        # Deliberate client errors (e.g. bad seed) keep their status code.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"內部伺服器錯誤: {str(e)}") from e
# ─── Gradio Web UI ───────────────────────────────────────
# 🚨 Key fix: every optional numeric field uses gr.Textbox instead of
# gr.Number(value=None), which would produce a problematic JSON Schema.
with gr.Blocks(title="🎵 ACE-Step v1.5 音樂生成器", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 style='text-align: center;'>🎵 ACE-Step v1.5 音樂生成器</h1>")
    with gr.Tab("🎼 生成音樂"):
        with gr.Row():
            # Left column: generation inputs.
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(label="🏷️ 音樂風格描述 (Prompt)", placeholder="例如:節奏強烈的 EDM、重低音與合成器")
                lyrics_input = gr.Textbox(label="📜 歌詞 (Lyrics,可選填)", lines=4)
                with gr.Row():
                    # ✅ Textbox instead of gr.Number, avoiding the gr.Number(value=None)
                    # problem entirely; gradio_generate parses these strings.
                    duration_input = gr.Textbox(label="⏱️ 時長 (秒)", value="30", placeholder="30")
                    bpm_input = gr.Textbox(label="🥁 BPM (可選)", value="", placeholder="例:120")
                with gr.Row():
                    # The choice strings must match the lookups in gradio_generate
                    # (lang_map) and chat_completions (ui_lang / ui_instr).
                    lang_input = gr.Dropdown(label="🌍 語言", choices=["英文 (en)", "中文 (zh)", "日文 (ja)", "韓文 (ko)", "自動判定"], value="英文 (en)")
                    instr_input = gr.Dropdown(label="🎸 純伴奏", choices=["自動判定", "是", "否"], value="自動判定")
                with gr.Row():
                    cfg_input = gr.Textbox(label="🎚️ Guidance Scale", value="7.0", placeholder="7.0")
                    seed_input = gr.Textbox(label="🎲 Seed (可選)", value="", placeholder="例:42")
                generate_btn = gr.Button("🚀 開始生成音樂", variant="primary")
            # Right column: generated audio.
            with gr.Column(scale=1):
                audio_output = gr.Audio(label="🎵 生成結果", type="numpy")
        generate_btn.click(
            fn=gradio_generate,
            inputs=[prompt_input, lyrics_input, duration_input, bpm_input, lang_input, instr_input, cfg_input, seed_input],
            outputs=audio_output
        )
        # Hidden trigger button whose only job is to expose the named API
        # route "generate_music", which chat_completions calls via gradio_client.
        api_btn = gr.Button("API", visible=False)
        api_btn.click(
            fn=gradio_generate,
            inputs=[prompt_input, lyrics_input, duration_input, bpm_input, lang_input, instr_input, cfg_input, seed_input],
            outputs=audio_output,
            api_name="generate_music"
        )
def _register_api_routes():
    """Attach the JSON API handlers to ``fastapi_app`` unless their paths already exist."""
    # NOTE(review): no route decorators are visible on health / list_models /
    # chat_completions, so without this they are unreachable. The paths follow
    # the OpenAI-style payloads the handlers return — confirm they are the
    # paths clients call.
    existing = {getattr(route, "path", None) for route in fastapi_app.routes}
    for path, endpoint, methods in (
        ("/health", health, ["GET"]),
        ("/v1/models", list_models, ["GET"]),
        ("/v1/chat/completions", chat_completions, ["POST"]),
    ):
        if path not in existing:
            fastapi_app.add_api_route(path, endpoint, methods=methods)


# Routes must be registered before Gradio is mounted at "/": that mount
# matches every path, so routes appended after it would be shadowed.
_register_api_routes()
app = gr.mount_gradio_app(fastapi_app, demo, path="/")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)