diff --git "a/acestep/api_server.py" "b/acestep/api_server.py" new file mode 100644--- /dev/null +++ "b/acestep/api_server.py" @@ -0,0 +1,2768 @@ +"""FastAPI server for ACE-Step V1.5. + +Endpoints: +- POST /release_task Create music generation task +- POST /query_result Batch query task results +- POST /create_random_sample Generate random music parameters via LLM +- POST /format_input Format and enhance lyrics/caption via LLM +- GET /v1/models List available models +- GET /v1/audio Download audio file +- GET /health Health check + +NOTE: +- In-memory queue and job store -> run uvicorn with workers=1. +""" + +from __future__ import annotations + +import asyncio +import glob +import json +import os +import random +import sys +import time +import traceback +import tempfile +import urllib.parse +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager +from dataclasses import dataclass +from pathlib import Path +from threading import Lock +from typing import Any, Dict, List, Literal, Optional, Union +from uuid import uuid4 +from loguru import logger + +try: + from dotenv import load_dotenv +except ImportError: # Optional dependency + load_dotenv = None # type: ignore + +from fastapi import FastAPI, HTTPException, Request, Depends, Header +from pydantic import BaseModel, Field +from starlette.datastructures import UploadFile as StarletteUploadFile + +from acestep.handler import AceStepHandler +from acestep.llm_inference import LLMHandler +from acestep.constants import ( + DEFAULT_DIT_INSTRUCTION, + DEFAULT_LM_INSTRUCTION, + TASK_INSTRUCTIONS, +) +from acestep.inference import ( + GenerationParams, + GenerationConfig, + generate_music, + create_sample, + format_sample, +) +from acestep.gradio_ui.events.results_handlers import _build_generation_info +from acestep.gpu_config import ( + get_gpu_config, + get_gpu_memory_gb, + print_gpu_config_info, + set_global_gpu_config, + get_recommended_lm_model, + is_lm_model_supported, + GPUConfig, + VRAM_16GB_MIN_GB, +) + + +# ============================================================================= +# Model Auto-Download Support +# ============================================================================= + +# Model name to repository mapping +MODEL_REPO_MAPPING = { + # Main unified repository (contains: acestep-v15-turbo, acestep-5Hz-lm-1.7B, Qwen3-Embedding-0.6B, vae) + "acestep-v15-turbo": "ACE-Step/Ace-Step1.5", + "acestep-5Hz-lm-1.7B": "ACE-Step/Ace-Step1.5", + "vae": "ACE-Step/Ace-Step1.5", + "Qwen3-Embedding-0.6B": "ACE-Step/Ace-Step1.5", + # Separate model repositories + "acestep-5Hz-lm-0.6B": "ACE-Step/acestep-5Hz-lm-0.6B", + "acestep-5Hz-lm-4B": "ACE-Step/acestep-5Hz-lm-4B", + "acestep-v15-base": "ACE-Step/acestep-v15-base", + "acestep-v15-sft": "ACE-Step/acestep-v15-sft", + "acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3", +} + +DEFAULT_REPO_ID = "ACE-Step/Ace-Step1.5" + + +def _can_access_google(timeout: float = 3.0) -> bool: + """Check if Google is accessible (to determine HuggingFace vs ModelScope).""" + import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + sock.settimeout(timeout) + sock.connect(("www.google.com", 443)) + return True + except (socket.timeout, socket.error, OSError): + return False + finally: + sock.close() + + +def _download_from_huggingface(repo_id: str, local_dir: str, model_name: str) -> str: + """Download model from HuggingFace Hub.""" + from huggingface_hub import snapshot_download + + is_unified_repo = 
repo_id == DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5" + + if is_unified_repo: + download_dir = local_dir + print(f"[Model Download] Downloading unified repo {repo_id} to {download_dir}...") + else: + download_dir = os.path.join(local_dir, model_name) + os.makedirs(download_dir, exist_ok=True) + print(f"[Model Download] Downloading {model_name} from {repo_id} to {download_dir}...") + + snapshot_download( + repo_id=repo_id, + local_dir=download_dir, + local_dir_use_symlinks=False, + ) + + return os.path.join(local_dir, model_name) + + +def _download_from_modelscope(repo_id: str, local_dir: str, model_name: str) -> str: + """Download model from ModelScope.""" + from modelscope import snapshot_download + + is_unified_repo = repo_id == DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5" + + if is_unified_repo: + download_dir = local_dir + print(f"[Model Download] Downloading unified repo {repo_id} from ModelScope to {download_dir}...") + else: + download_dir = os.path.join(local_dir, model_name) + os.makedirs(download_dir, exist_ok=True) + print(f"[Model Download] Downloading {model_name} from ModelScope {repo_id} to {download_dir}...") + + # ModelScope snapshot_download returns the cache path + # Use cache_dir parameter for better compatibility across versions + try: + # Try with local_dir first (newer versions) + result_path = snapshot_download( + model_id=repo_id, + local_dir=download_dir, + ) + print(f"[Model Download] ModelScope download completed: {result_path}") + except TypeError: + # Fallback to cache_dir for older versions + print("[Model Download] Retrying with cache_dir parameter...") + result_path = snapshot_download( + model_id=repo_id, + cache_dir=download_dir, + ) + print(f"[Model Download] ModelScope download completed: {result_path}") + + return os.path.join(local_dir, model_name) + + +def _ensure_model_downloaded(model_name: str, checkpoint_dir: str) -> str: + """ + Ensure model is downloaded. Auto-detect source based on network. 
+ + Args: + model_name: Model directory name (e.g., "acestep-v15-turbo") + checkpoint_dir: Target checkpoint directory + + Returns: + Path to the model directory + """ + model_path = os.path.join(checkpoint_dir, model_name) + + # Check if model already exists + if os.path.exists(model_path) and os.listdir(model_path): + print(f"[Model Download] Model {model_name} already exists at {model_path}") + return model_path + + # Get repository ID + repo_id = MODEL_REPO_MAPPING.get(model_name, DEFAULT_REPO_ID) + + print(f"[Model Download] Model {model_name} not found, checking network...") + + # Check for user preference + prefer_source = os.environ.get("ACESTEP_DOWNLOAD_SOURCE", "").lower() + + # Determine download source + if prefer_source == "huggingface": + use_huggingface = True + print("[Model Download] User preference: HuggingFace Hub") + elif prefer_source == "modelscope": + use_huggingface = False + print("[Model Download] User preference: ModelScope") + else: + use_huggingface = _can_access_google() + print(f"[Model Download] Auto-detected: {'HuggingFace Hub' if use_huggingface else 'ModelScope'}") + + if use_huggingface: + print("[Model Download] Using HuggingFace Hub...") + try: + return _download_from_huggingface(repo_id, checkpoint_dir, model_name) + except Exception as e: + print(f"[Model Download] HuggingFace download failed: {e}") + print("[Model Download] Falling back to ModelScope...") + return _download_from_modelscope(repo_id, checkpoint_dir, model_name) + else: + print("[Model Download] Using ModelScope...") + try: + return _download_from_modelscope(repo_id, checkpoint_dir, model_name) + except Exception as e: + print(f"[Model Download] ModelScope download failed: {e}") + print("[Model Download] Trying HuggingFace as fallback...") + return _download_from_huggingface(repo_id, checkpoint_dir, model_name) + + +def _get_project_root() -> str: + current_file = os.path.abspath(__file__) + return os.path.dirname(os.path.dirname(current_file)) + + +# ============================================================================= +# Constants +# ============================================================================= + +RESULT_KEY_PREFIX = "ace_step_v1.5_" +RESULT_EXPIRE_SECONDS = 7 * 24 * 60 * 60 # 7 days +TASK_TIMEOUT_SECONDS = 3600 # 1 hour +JOB_STORE_CLEANUP_INTERVAL = 300 # 5 minutes - interval for cleaning up old jobs +JOB_STORE_MAX_AGE_SECONDS = 86400 # 24 hours - completed jobs older than this will be cleaned +STATUS_MAP = {"queued": 0, "running": 0, "succeeded": 1, "failed": 2} + +LM_DEFAULT_TEMPERATURE = 0.85 +LM_DEFAULT_CFG_SCALE = 2.5 +LM_DEFAULT_TOP_P = 0.9 + + +def _wrap_response(data: Any, code: int = 200, error: Optional[str] = None) -> Dict[str, Any]: + """Wrap response data in standard format.""" + return { + "data": data, + "code": code, + "error": error, + "timestamp": int(time.time() * 1000), + "extra": None, + } + + +# ============================================================================= +# Example Data for Random Sample +# ============================================================================= + +SIMPLE_MODE_EXAMPLES_DIR = os.path.join(_get_project_root(), "examples", "simple_mode") +CUSTOM_MODE_EXAMPLES_DIR = os.path.join(_get_project_root(), "examples", "text2music") + + +def _load_all_examples(sample_mode: str = "simple_mode") -> List[Dict[str, Any]]: + """Load all example data files from the examples directory.""" + examples = [] + examples_dir = SIMPLE_MODE_EXAMPLES_DIR if sample_mode == "simple_mode" else CUSTOM_MODE_EXAMPLES_DIR + pattern 
= os.path.join(examples_dir, "example_*.json") + + for filepath in glob.glob(pattern): + try: + with open(filepath, 'r', encoding='utf-8') as f: + data = json.load(f) + examples.append(data) + except Exception as e: + print(f"[API Server] Failed to load example file {filepath}: {e}") + + return examples + + +# Pre-load example data at module load time +SIMPLE_EXAMPLE_DATA: List[Dict[str, Any]] = _load_all_examples(sample_mode="simple_mode") +CUSTOM_EXAMPLE_DATA: List[Dict[str, Any]] = _load_all_examples(sample_mode="custom_mode") + +# ============================================================================= +# API Key Authentication +# ============================================================================= + +_api_key: Optional[str] = None + + +def set_api_key(key: Optional[str]): + """Set the API key for authentication""" + global _api_key + _api_key = key + + +def verify_token_from_request(body: dict, authorization: Optional[str] = None) -> Optional[str]: + """ + Verify API key from request body (ai_token) or Authorization header. + Returns the token if valid, None if no auth required. + """ + if _api_key is None: + return None # No auth required + + # Try ai_token from body first + ai_token = body.get("ai_token") if body else None + if ai_token: + if ai_token == _api_key: + return ai_token + raise HTTPException(status_code=401, detail="Invalid ai_token") + + # Fallback to Authorization header + if authorization: + if authorization.startswith("Bearer "): + token = authorization[7:] + else: + token = authorization + if token == _api_key: + return token + raise HTTPException(status_code=401, detail="Invalid API key") + + # No token provided but auth is required + raise HTTPException(status_code=401, detail="Missing ai_token or Authorization header") + + +async def verify_api_key(authorization: Optional[str] = Header(None)): + """Verify API key from Authorization header (legacy, for non-body endpoints)""" + if _api_key is None: + return # No auth required + + if not authorization: + raise HTTPException(status_code=401, detail="Missing Authorization header") + + # Support "Bearer " format + if authorization.startswith("Bearer "): + token = authorization[7:] + else: + token = authorization + + if token != _api_key: + raise HTTPException(status_code=401, detail="Invalid API key") + +# Parameter aliases for request parsing +PARAM_ALIASES = { + "prompt": ["prompt", "caption"], + "lyrics": ["lyrics"], + "thinking": ["thinking"], + "analysis_only": ["analysis_only", "analysisOnly"], + "full_analysis_only": ["full_analysis_only", "fullAnalysisOnly"], + "sample_mode": ["sample_mode", "sampleMode"], + "sample_query": ["sample_query", "sampleQuery", "description", "desc"], + "use_format": ["use_format", "useFormat", "format"], + "model": ["model", "model_name", "modelName", "dit_model", "ditModel"], + "key_scale": ["key_scale", "keyscale", "keyScale", "key"], + "time_signature": ["time_signature", "timesignature", "timeSignature"], + "audio_duration": ["audio_duration", "duration", "audioDuration", "target_duration", "targetDuration"], + "vocal_language": ["vocal_language", "vocalLanguage", "language"], + "bpm": ["bpm"], + "inference_steps": ["inference_steps", "inferenceSteps"], + "guidance_scale": ["guidance_scale", "guidanceScale"], + "use_random_seed": ["use_random_seed", "useRandomSeed"], + "seed": ["seed"], + + "audio_cover_strength": ["audio_cover_strength", "audioCoverStrength"], + "reference_audio_path": ["reference_audio_path", "ref_audio_path", "referenceAudioPath", 
"refAudioPath"], + "src_audio_path": ["src_audio_path", "ctx_audio_path", "sourceAudioPath", "srcAudioPath", "ctxAudioPath"], + "task_type": ["task_type", "taskType"], + "infer_method": ["infer_method", "inferMethod"], + "use_tiled_decode": ["use_tiled_decode", "useTiledDecode"], + "constrained_decoding": ["constrained_decoding", "constrainedDecoding", "constrained"], + "constrained_decoding_debug": ["constrained_decoding_debug", "constrainedDecodingDebug"], + "use_cot_caption": ["use_cot_caption", "cot_caption", "cot-caption"], + "use_cot_language": ["use_cot_language", "cot_language", "cot-language"], + "is_format_caption": ["is_format_caption", "isFormatCaption"], + "allow_lm_batch": ["allow_lm_batch", "allowLmBatch", "parallel_thinking"], +} + + +def _parse_description_hints(description: str) -> tuple[Optional[str], bool]: + """ + Parse a description string to extract language code and instrumental flag. + + This function analyzes user descriptions like "Pop rock. English" or "piano solo" + to detect: + - Language: Maps language names to ISO codes (e.g., "English" -> "en") + - Instrumental: Detects patterns indicating instrumental/no-vocal music + + Args: + description: User's natural language music description + + Returns: + (language_code, is_instrumental) tuple: + - language_code: ISO language code (e.g., "en", "zh") or None if not detected + - is_instrumental: True if description indicates instrumental music + """ + import re + + if not description: + return None, False + + description_lower = description.lower().strip() + + # Language mapping: input patterns -> ISO code + language_mapping = { + 'english': 'en', 'en': 'en', + 'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh', + 'japanese': 'ja', '日本語': 'ja', 'ja': 'ja', + 'korean': 'ko', '한국어': 'ko', 'ko': 'ko', + 'spanish': 'es', 'español': 'es', 'es': 'es', + 'french': 'fr', 'français': 'fr', 'fr': 'fr', + 'german': 'de', 'deutsch': 'de', 'de': 'de', + 'italian': 'it', 'italiano': 'it', 'it': 'it', + 'portuguese': 'pt', 'português': 'pt', 'pt': 'pt', + 'russian': 'ru', 'русский': 'ru', 'ru': 'ru', + 'bengali': 'bn', 'bn': 'bn', + 'hindi': 'hi', 'hi': 'hi', + 'arabic': 'ar', 'ar': 'ar', + 'thai': 'th', 'th': 'th', + 'vietnamese': 'vi', 'vi': 'vi', + 'indonesian': 'id', 'id': 'id', + 'turkish': 'tr', 'tr': 'tr', + 'dutch': 'nl', 'nl': 'nl', + 'polish': 'pl', 'pl': 'pl', + } + + # Detect language + detected_language = None + for lang_name, lang_code in language_mapping.items(): + if len(lang_name) <= 2: + pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])' + else: + pattern = r'\b' + re.escape(lang_name) + r'\b' + + if re.search(pattern, description_lower): + detected_language = lang_code + break + + # Detect instrumental + is_instrumental = False + if 'instrumental' in description_lower: + is_instrumental = True + elif 'pure music' in description_lower or 'pure instrument' in description_lower: + is_instrumental = True + elif description_lower.endswith(' solo') or description_lower == 'solo': + is_instrumental = True + + return detected_language, is_instrumental + + +JobStatus = Literal["queued", "running", "succeeded", "failed"] + + +class GenerateMusicRequest(BaseModel): + prompt: str = Field(default="", description="Text prompt describing the music") + lyrics: str = Field(default="", description="Lyric text") + + # New API semantics: + # - thinking=True: use 5Hz LM to generate audio codes (lm-dit behavior) + # - thinking=False: do not use LM to generate codes (dit behavior) + # Regardless of 
thinking, if some metas are missing, server may use LM to fill them. + thinking: bool = False + # Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt). + sample_mode: bool = False + # Description for sample mode: auto-generate caption/lyrics from description query + sample_query: str = Field(default="", description="Query/description for sample mode (use create_sample)") + # Whether to use format_sample() to enhance input caption/lyrics + use_format: bool = Field(default=False, description="Use format_sample() to enhance input (default: False)") + # Model name for multi-model support (select which DiT model to use) + model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')") + + bpm: Optional[int] = None + # Accept common client keys via manual parsing (see RequestParser). + key_scale: str = "" + time_signature: str = "" + vocal_language: str = "en" + inference_steps: int = 8 + guidance_scale: float = 7.0 + use_random_seed: bool = True + seed: Union[int, str] = -1 + + reference_audio_path: Optional[str] = None + src_audio_path: Optional[str] = None + audio_duration: Optional[float] = None + batch_size: Optional[int] = None + + repainting_start: float = 0.0 + repainting_end: Optional[float] = None + + instruction: str = DEFAULT_DIT_INSTRUCTION + audio_cover_strength: float = 1.0 + task_type: str = "text2music" + analysis_only: bool = False + full_analysis_only: bool = False + + use_adg: bool = False + cfg_interval_start: float = 0.0 + cfg_interval_end: float = 1.0 + infer_method: str = "ode" # "ode" or "sde" - diffusion inference method + shift: float = Field( + default=3.0, + description="Timestep shift factor (range 1.0~5.0, default 3.0). Only effective for base models, not turbo models." + ) + timesteps: Optional[str] = Field( + default=None, + description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift." + ) + + audio_format: str = "mp3" + use_tiled_decode: bool = True + + # 5Hz LM (server-side): used for metadata completion and (when thinking=True) codes generation. + lm_model_path: Optional[str] = None # e.g. 
"acestep-5Hz-lm-0.6B" + lm_backend: Literal["vllm", "pt", "mlx"] = "vllm" + + constrained_decoding: bool = True + constrained_decoding_debug: bool = False + use_cot_caption: bool = True + use_cot_language: bool = True + is_format_caption: bool = False + allow_lm_batch: bool = True + + lm_temperature: float = 0.85 + lm_cfg_scale: float = 2.5 + lm_top_k: Optional[int] = None + lm_top_p: Optional[float] = 0.9 + lm_repetition_penalty: float = 1.0 + lm_negative_prompt: str = "NO USER INPUT" + + class Config: + allow_population_by_field_name = True + allow_population_by_alias = True + + +class CreateJobResponse(BaseModel): + task_id: str + status: JobStatus + queue_position: int = 0 # 1-based best-effort position when queued + progress_text: Optional[str] = "" + + +class JobResult(BaseModel): + first_audio_path: Optional[str] = None + second_audio_path: Optional[str] = None + audio_paths: list[str] = Field(default_factory=list) + + generation_info: str = "" + status_message: str = "" + seed_value: str = "" + + metas: Dict[str, Any] = Field(default_factory=dict) + bpm: Optional[int] = None + duration: Optional[float] = None + genres: Optional[str] = None + keyscale: Optional[str] = None + timesignature: Optional[str] = None + + # Model information + lm_model: Optional[str] = None + dit_model: Optional[str] = None + + +class JobResponse(BaseModel): + job_id: str + status: JobStatus + created_at: float + started_at: Optional[float] = None + finished_at: Optional[float] = None + + # queue observability + queue_position: int = 0 + eta_seconds: Optional[float] = None + avg_job_seconds: Optional[float] = None + + result: Optional[JobResult] = None + error: Optional[str] = None + + +@dataclass +class _JobRecord: + job_id: str + status: JobStatus + created_at: float + started_at: Optional[float] = None + finished_at: Optional[float] = None + result: Optional[Dict[str, Any]] = None + error: Optional[str] = None + progress_text: str = "" + status_text: str = "" + env: str = "development" + progress: float = 0.0 # 0.0 - 1.0 + stage: str = "queued" + updated_at: Optional[float] = None + # OpenRouter integration: synchronous wait / streaming support + done_event: Optional[asyncio.Event] = None + progress_queue: Optional[asyncio.Queue] = None + + +class _JobStore: + def __init__(self, max_age_seconds: int = JOB_STORE_MAX_AGE_SECONDS) -> None: + self._lock = Lock() + self._jobs: Dict[str, _JobRecord] = {} + self._max_age = max_age_seconds + + def create(self) -> _JobRecord: + job_id = str(uuid4()) + now = time.time() + rec = _JobRecord(job_id=job_id, status="queued", created_at=now, progress=0.0, stage="queued", updated_at=now) + with self._lock: + self._jobs[job_id] = rec + return rec + + def create_with_id(self, job_id: str, env: str = "development") -> _JobRecord: + """Create job record with specified ID""" + now = time.time() + rec = _JobRecord( + job_id=job_id, + status="queued", + created_at=now, + env=env, + progress=0.0, + stage="queued", + updated_at=now, + ) + with self._lock: + self._jobs[job_id] = rec + return rec + + def get(self, job_id: str) -> Optional[_JobRecord]: + with self._lock: + return self._jobs.get(job_id) + + def mark_running(self, job_id: str) -> None: + with self._lock: + rec = self._jobs[job_id] + rec.status = "running" + rec.started_at = time.time() + rec.progress = max(rec.progress, 0.01) + rec.stage = "running" + rec.updated_at = time.time() + + def mark_succeeded(self, job_id: str, result: Dict[str, Any]) -> None: + with self._lock: + rec = self._jobs[job_id] + rec.status = 
"succeeded" + rec.finished_at = time.time() + rec.result = result + rec.error = None + rec.progress = 1.0 + rec.stage = "succeeded" + rec.updated_at = time.time() + + def mark_failed(self, job_id: str, error: str) -> None: + with self._lock: + rec = self._jobs[job_id] + rec.status = "failed" + rec.finished_at = time.time() + rec.result = None + rec.error = error + rec.progress = rec.progress if rec.progress > 0 else 0.0 + rec.stage = "failed" + rec.updated_at = time.time() + + def update_progress(self, job_id: str, progress: float, stage: Optional[str] = None) -> None: + with self._lock: + rec = self._jobs.get(job_id) + if not rec: + return + rec.progress = max(0.0, min(1.0, float(progress))) + if stage: + rec.stage = stage + rec.updated_at = time.time() + + def cleanup_old_jobs(self, max_age_seconds: Optional[int] = None) -> int: + """ + Clean up completed jobs older than max_age_seconds. + + Only removes jobs with status 'succeeded' or 'failed'. + Jobs that are 'queued' or 'running' are never removed. + + Returns the number of jobs removed. + """ + max_age = max_age_seconds if max_age_seconds is not None else self._max_age + now = time.time() + removed = 0 + + with self._lock: + to_remove = [] + for job_id, rec in self._jobs.items(): + if rec.status in ("succeeded", "failed"): + finish_time = rec.finished_at or rec.created_at + age = now - finish_time + if age > max_age: + to_remove.append(job_id) + + for job_id in to_remove: + del self._jobs[job_id] + removed += 1 + + return removed + + def get_stats(self) -> Dict[str, int]: + """Get statistics about jobs in the store.""" + with self._lock: + stats = { + "total": len(self._jobs), + "queued": 0, + "running": 0, + "succeeded": 0, + "failed": 0, + } + for rec in self._jobs.values(): + if rec.status in stats: + stats[rec.status] += 1 + return stats + + def update_status_text(self, job_id: str, text: str) -> None: + with self._lock: + if job_id in self._jobs: + self._jobs[job_id].status_text = text + + def update_progress_text(self, job_id: str, text: str) -> None: + with self._lock: + if job_id in self._jobs: + self._jobs[job_id].progress_text = text + +def _env_bool(name: str, default: bool) -> bool: + v = os.getenv(name) + if v is None: + return default + return v.strip().lower() in {"1", "true", "yes", "y", "on"} + + + + +def _get_model_name(config_path: str) -> str: + """ + Extract model name from config_path. + + Args: + config_path: Path like "acestep-v15-turbo" or "/path/to/acestep-v15-turbo" + + Returns: + Model name (last directory name from config_path) + """ + if not config_path: + return "" + normalized = config_path.rstrip("/\\") + return os.path.basename(normalized) + + +_project_env_loaded = False + + +def _load_project_env() -> None: + """Load .env at most once per process to avoid epoch-boundary stalls (e.g. Windows LoRA training).""" + global _project_env_loaded + if _project_env_loaded or load_dotenv is None: + return + try: + project_root = _get_project_root() + env_path = os.path.join(project_root, ".env") + if os.path.exists(env_path): + load_dotenv(env_path, override=False) + _project_env_loaded = True + except Exception: + # Optional best-effort: continue even if .env loading fails. 
+ pass + + +_load_project_env() + + +def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]: + if v is None: + return default + if isinstance(v, int): + return v + s = str(v).strip() + if s == "": + return default + try: + return int(s) + except Exception: + return default + + +def _to_float(v: Any, default: Optional[float] = None) -> Optional[float]: + if v is None: + return default + if isinstance(v, float): + return v + s = str(v).strip() + if s == "": + return default + try: + return float(s) + except Exception: + return default + + +def _to_bool(v: Any, default: bool = False) -> bool: + if v is None: + return default + if isinstance(v, bool): + return v + s = str(v).strip().lower() + if s == "": + return default + return s in {"1", "true", "yes", "y", "on"} + + +def _map_status(status: str) -> int: + """Map job status string to integer code.""" + return STATUS_MAP.get(status, 2) + + +def _parse_timesteps(s: Optional[str]) -> Optional[List[float]]: + """Parse comma-separated timesteps string to list of floats.""" + if not s or not s.strip(): + return None + try: + return [float(t.strip()) for t in s.split(",") if t.strip()] + except (ValueError, Exception): + return None + + +def _is_instrumental(lyrics: str) -> bool: + """ + Determine if the music should be instrumental based on lyrics. + + Returns True if: + - lyrics is empty or whitespace only + - lyrics (lowercased and trimmed) is "[inst]" or "[instrumental]" + """ + if not lyrics: + return True + lyrics_clean = lyrics.strip().lower() + if not lyrics_clean: + return True + return lyrics_clean in ("[inst]", "[instrumental]") + + +class RequestParser: + """Parse request parameters from multiple sources with alias support.""" + + def __init__(self, raw: dict): + self._raw = dict(raw) if raw else {} + self._param_obj = self._parse_json(self._raw.get("param_obj")) + self._metas = self._find_metas() + + def _parse_json(self, v) -> dict: + if isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + return json.loads(v) + except Exception: + pass + return {} + + def _find_metas(self) -> dict: + for key in ("metas", "meta", "metadata", "user_metadata", "userMetadata"): + v = self._raw.get(key) + if v: + return self._parse_json(v) + return {} + + def get(self, name: str, default=None): + """Get parameter by canonical name from all sources.""" + aliases = PARAM_ALIASES.get(name, [name]) + for source in (self._raw, self._param_obj, self._metas): + for alias in aliases: + v = source.get(alias) + if v is not None: + return v + return default + + def str(self, name: str, default: str = "") -> str: + v = self.get(name) + return str(v) if v is not None else default + + def int(self, name: str, default: Optional[int] = None) -> Optional[int]: + return _to_int(self.get(name), default) + + def float(self, name: str, default: Optional[float] = None) -> Optional[float]: + return _to_float(self.get(name), default) + + def bool(self, name: str, default: bool = False) -> bool: + return _to_bool(self.get(name), default) + + +async def _save_upload_to_temp(upload: StarletteUploadFile, *, prefix: str) -> str: + suffix = Path(upload.filename or "").suffix + fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=suffix) + os.close(fd) + try: + with open(path, "wb") as f: + while True: + chunk = await upload.read(1024 * 1024) + if not chunk: + break + f.write(chunk) + except Exception: + try: + os.remove(path) + except Exception: + pass + raise + finally: + try: + await upload.close() + except Exception: + pass + return path 
+ +class LogBuffer: + def __init__(self): + self.last_message = "Waiting" + + def write(self, message): + msg = message.strip() + if msg: + self.last_message = msg + + def flush(self): + pass + +log_buffer = LogBuffer() +logger.add(lambda msg: log_buffer.write(msg), format="{time:HH:mm:ss} | {level} | {message}") + +class StderrLogger: + def __init__(self, original_stderr, buffer): + self.original_stderr = original_stderr + self.buffer = buffer + + def write(self, message): + self.original_stderr.write(message) # Print to terminal + self.buffer.write(message) # Send to API buffer + + def flush(self): + self.original_stderr.flush() + +sys.stderr = StderrLogger(sys.stderr, log_buffer) + + +def create_app() -> FastAPI: + store = _JobStore() + + # API Key authentication (from environment variable) + api_key = os.getenv("ACESTEP_API_KEY", None) + set_api_key(api_key) + + QUEUE_MAXSIZE = int(os.getenv("ACESTEP_QUEUE_MAXSIZE", "200")) + WORKER_COUNT = int(os.getenv("ACESTEP_QUEUE_WORKERS", "1")) # Single GPU recommended + + INITIAL_AVG_JOB_SECONDS = float(os.getenv("ACESTEP_AVG_JOB_SECONDS", "5.0")) + AVG_WINDOW = int(os.getenv("ACESTEP_AVG_WINDOW", "50")) + + def _path_to_audio_url(path: str) -> str: + """Convert local file path to downloadable relative URL""" + if not path: + return path + if path.startswith("http://") or path.startswith("https://"): + return path + encoded_path = urllib.parse.quote(path, safe="") + return f"/v1/audio?path={encoded_path}" + + @asynccontextmanager + async def lifespan(app: FastAPI): + # Clear proxy env that may affect downstream libs + for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]: + os.environ.pop(proxy_var, None) + + # Ensure compilation/temp caches do not fill up small default /tmp. + # Triton/Inductor (and the system compiler) can create large temporary files. + project_root = _get_project_root() + cache_root = os.path.join(project_root, ".cache", "acestep") + tmp_root = (os.getenv("ACESTEP_TMPDIR") or os.path.join(cache_root, "tmp")).strip() + triton_cache_root = (os.getenv("TRITON_CACHE_DIR") or os.path.join(cache_root, "triton")).strip() + inductor_cache_root = (os.getenv("TORCHINDUCTOR_CACHE_DIR") or os.path.join(cache_root, "torchinductor")).strip() + + for p in [cache_root, tmp_root, triton_cache_root, inductor_cache_root]: + try: + os.makedirs(p, exist_ok=True) + except Exception: + # Best-effort: do not block startup if directory creation fails. + pass + + # Respect explicit user overrides; if ACESTEP_TMPDIR is set, it should win. 
+ if os.getenv("ACESTEP_TMPDIR"): + os.environ["TMPDIR"] = tmp_root + os.environ["TEMP"] = tmp_root + os.environ["TMP"] = tmp_root + else: + os.environ.setdefault("TMPDIR", tmp_root) + os.environ.setdefault("TEMP", tmp_root) + os.environ.setdefault("TMP", tmp_root) + + os.environ.setdefault("TRITON_CACHE_DIR", triton_cache_root) + os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", inductor_cache_root) + + handler = AceStepHandler() + llm_handler = LLMHandler() + init_lock = asyncio.Lock() + app.state._initialized = False + app.state._init_error = None + app.state._init_lock = init_lock + + app.state.llm_handler = llm_handler + app.state._llm_initialized = False + app.state._llm_init_error = None + app.state._llm_init_lock = Lock() + app.state._llm_lazy_load_disabled = False # Will be set to True if LLM skipped due to GPU config + + # Multi-model support: secondary DiT handlers + handler2 = None + handler3 = None + config_path2 = os.getenv("ACESTEP_CONFIG_PATH2", "").strip() + config_path3 = os.getenv("ACESTEP_CONFIG_PATH3", "").strip() + + if config_path2: + handler2 = AceStepHandler() + if config_path3: + handler3 = AceStepHandler() + + app.state.handler2 = handler2 + app.state.handler3 = handler3 + app.state._initialized2 = False + app.state._initialized3 = False + app.state._config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo") + app.state._config_path2 = config_path2 + app.state._config_path3 = config_path3 + + max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1")) + executor = ThreadPoolExecutor(max_workers=max_workers) + + # Queue & observability + app.state.job_queue = asyncio.Queue(maxsize=QUEUE_MAXSIZE) # (job_id, req) + app.state.pending_ids = deque() # queued job_ids + app.state.pending_lock = asyncio.Lock() + + # temp files per job (from multipart uploads) + app.state.job_temp_files = {} # job_id -> list[path] + app.state.job_temp_files_lock = asyncio.Lock() + + # stats + app.state.stats_lock = asyncio.Lock() + app.state.recent_durations = deque(maxlen=AVG_WINDOW) + app.state.avg_job_seconds = INITIAL_AVG_JOB_SECONDS + + app.state.handler = handler + app.state.executor = executor + app.state.job_store = store + app.state._python_executable = sys.executable + + # Temporary directory for saving generated audio files + app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio") + os.makedirs(app.state.temp_audio_dir, exist_ok=True) + + # Initialize local cache + try: + from acestep.local_cache import get_local_cache + local_cache_dir = os.path.join(cache_root, "local_redis") + app.state.local_cache = get_local_cache(local_cache_dir) + except ImportError: + app.state.local_cache = None + + async def _ensure_initialized() -> None: + """Check if models are initialized (they should be loaded at startup).""" + if getattr(app.state, "_init_error", None): + raise RuntimeError(app.state._init_error) + if not getattr(app.state, "_initialized", False): + raise RuntimeError("Model not initialized") + + async def _cleanup_job_temp_files(job_id: str) -> None: + async with app.state.job_temp_files_lock: + paths = app.state.job_temp_files.pop(job_id, []) + for p in paths: + try: + os.remove(p) + except Exception: + pass + + def _update_local_cache(job_id: str, result: Optional[Dict], status: str) -> None: + """Update local cache with job result""" + local_cache = getattr(app.state, 'local_cache', None) + if not local_cache: + return + + rec = store.get(job_id) + env = getattr(rec, 'env', 'development') if rec else 'development' + create_time = rec.created_at if rec else 
time.time() + + status_int = _map_status(status) + + if status == "succeeded" and result: + # Check if it's a "Full Analysis" result + if result.get("status_message") == "Full Hardware Analysis Success": + result_data = [result] + else: + audio_paths = result.get("audio_paths", []) + # Final prompt/lyrics (may be modified by thinking/format) + final_prompt = result.get("prompt", "") + final_lyrics = result.get("lyrics", "") + # Original user input from metas + metas_raw = result.get("metas", {}) or {} + original_prompt = metas_raw.get("prompt", "") + original_lyrics = metas_raw.get("lyrics", "") + # metas contains original input + other metadata + metas = { + "bpm": metas_raw.get("bpm"), + "duration": metas_raw.get("duration"), + "genres": metas_raw.get("genres", ""), + "keyscale": metas_raw.get("keyscale", ""), + "timesignature": metas_raw.get("timesignature", ""), + "prompt": original_prompt, + "lyrics": original_lyrics, + } + # Extra fields for Discord bot + generation_info = result.get("generation_info", "") + seed_value = result.get("seed_value", "") + lm_model = result.get("lm_model", "") + dit_model = result.get("dit_model", "") + + if audio_paths: + result_data = [ + { + "file": p, + "wave": "", + "status": status_int, + "create_time": int(create_time), + "env": env, + "prompt": final_prompt, + "lyrics": final_lyrics, + "metas": metas, + "generation_info": generation_info, + "seed_value": seed_value, + "lm_model": lm_model, + "dit_model": dit_model, + "progress": 1.0, + "stage": "succeeded", + } + for p in audio_paths + ] + else: + result_data = [{ + "file": "", + "wave": "", + "status": status_int, + "create_time": int(create_time), + "env": env, + "prompt": final_prompt, + "lyrics": final_lyrics, + "metas": metas, + "generation_info": generation_info, + "seed_value": seed_value, + "lm_model": lm_model, + "dit_model": dit_model, + "progress": 1.0, + "stage": "succeeded", + }] + else: + result_data = [{ + "file": "", + "wave": "", + "status": status_int, + "create_time": int(create_time), + "env": env, + "progress": 0.0, + "stage": "failed" if status == "failed" else status, + }] + + result_key = f"{RESULT_KEY_PREFIX}{job_id}" + local_cache.set(result_key, result_data, ex=RESULT_EXPIRE_SECONDS) + + def _update_local_cache_progress(job_id: str, progress: float, stage: str) -> None: + """Update local cache with job progress for queued/running states.""" + local_cache = getattr(app.state, 'local_cache', None) + if not local_cache: + return + + rec = store.get(job_id) + env = getattr(rec, 'env', 'development') if rec else 'development' + create_time = rec.created_at if rec else time.time() + status_int = _map_status("running") + + result_data = [{ + "file": "", + "wave": "", + "status": status_int, + "create_time": int(create_time), + "env": env, + "progress": float(progress), + "stage": stage, + }] + + result_key = f"{RESULT_KEY_PREFIX}{job_id}" + local_cache.set(result_key, result_data, ex=RESULT_EXPIRE_SECONDS) + + async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None: + job_store: _JobStore = app.state.job_store + llm: LLMHandler = app.state.llm_handler + executor: ThreadPoolExecutor = app.state.executor + + await _ensure_initialized() + job_store.mark_running(job_id) + _update_local_cache_progress(job_id, 0.01, "running") + + # Select DiT handler based on user's model choice + # Default: use primary handler + selected_handler: AceStepHandler = app.state.handler + selected_model_name = _get_model_name(app.state._config_path) + + if req.model: + model_matched = 
False + + # Check if it matches the second model + if app.state.handler2 and getattr(app.state, "_initialized2", False): + model2_name = _get_model_name(app.state._config_path2) + if req.model == model2_name: + selected_handler = app.state.handler2 + selected_model_name = model2_name + model_matched = True + print(f"[API Server] Job {job_id}: Using second model: {model2_name}") + + # Check if it matches the third model + if not model_matched and app.state.handler3 and getattr(app.state, "_initialized3", False): + model3_name = _get_model_name(app.state._config_path3) + if req.model == model3_name: + selected_handler = app.state.handler3 + selected_model_name = model3_name + model_matched = True + print(f"[API Server] Job {job_id}: Using third model: {model3_name}") + + if not model_matched: + available_models = [_get_model_name(app.state._config_path)] + if app.state.handler2 and getattr(app.state, "_initialized2", False): + available_models.append(_get_model_name(app.state._config_path2)) + if app.state.handler3 and getattr(app.state, "_initialized3", False): + available_models.append(_get_model_name(app.state._config_path3)) + print(f"[API Server] Job {job_id}: Model '{req.model}' not found in {available_models}, using primary: {selected_model_name}") + + # Use selected handler for generation + h: AceStepHandler = selected_handler + + def _blocking_generate() -> Dict[str, Any]: + """Generate music using unified inference logic from acestep.inference""" + + def _ensure_llm_ready() -> None: + """Ensure LLM handler is initialized when needed""" + with app.state._llm_init_lock: + initialized = getattr(app.state, "_llm_initialized", False) + had_error = getattr(app.state, "_llm_init_error", None) + if initialized or had_error is not None: + return + print("[API Server] reloading.") + + # Check if lazy loading is disabled (GPU memory insufficient) + if getattr(app.state, "_llm_lazy_load_disabled", False): + app.state._llm_init_error = ( + "LLM not initialized at startup. To enable LLM, set ACESTEP_INIT_LLM=true " + "in .env or environment variables. For this request, optional LLM features " + "(use_cot_caption, use_cot_language) will be auto-disabled." 
+ ) + print(f"[API Server] LLM lazy load blocked: LLM was not initialized at startup") + return + + project_root = _get_project_root() + checkpoint_dir = os.path.join(project_root, "checkpoints") + lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip() + backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower() + if backend not in {"vllm", "pt", "mlx"}: + backend = "vllm" + + # Auto-download LM model if not present + lm_model_name = _get_model_name(lm_model_path) + if lm_model_name: + try: + _ensure_model_downloaded(lm_model_name, checkpoint_dir) + except Exception as e: + print(f"[API Server] Warning: Failed to download LM model {lm_model_name}: {e}") + + lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto")) + lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False) + + status, ok = llm.initialize( + checkpoint_dir=checkpoint_dir, + lm_model_path=lm_model_path, + backend=backend, + device=lm_device, + offload_to_cpu=lm_offload, + dtype=None, + ) + if not ok: + app.state._llm_init_error = status + else: + app.state._llm_initialized = True + + def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]: + """Ensure a stable `metas` dict (keys always present).""" + meta = meta or {} + out: Dict[str, Any] = dict(meta) + + # Normalize key aliases + if "keyscale" not in out and "key_scale" in out: + out["keyscale"] = out.get("key_scale") + if "timesignature" not in out and "time_signature" in out: + out["timesignature"] = out.get("time_signature") + + # Ensure required keys exist + for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]: + if out.get(k) in (None, ""): + out[k] = "N/A" + return out + + # Normalize LM sampling parameters + lm_top_k = req.lm_top_k if req.lm_top_k and req.lm_top_k > 0 else 0 + lm_top_p = req.lm_top_p if req.lm_top_p and req.lm_top_p < 1.0 else 0.9 + + # Determine if LLM is needed + thinking = bool(req.thinking) + sample_mode = bool(req.sample_mode) + has_sample_query = bool(req.sample_query and req.sample_query.strip()) + use_format = bool(req.use_format) + use_cot_caption = bool(req.use_cot_caption) + use_cot_language = bool(req.use_cot_language) + + full_analysis_only = bool(req.full_analysis_only) + + # Unload LM for cover tasks on MPS to reduce memory; reload lazily when needed. 
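+            # Note: unloading below clears _llm_initialized/_llm_init_error, so a later
+            # call to _ensure_llm_ready() will transparently re-initialize the LM when
+            # another feature in this request needs it.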
+ if req.task_type == "cover" and h.device == "mps": + if getattr(app.state, "_llm_initialized", False) and getattr(llm, "llm_initialized", False): + try: + print("[API Server] unloading.") + llm.unload() + app.state._llm_initialized = False + app.state._llm_init_error = None + except Exception as e: + print(f"[API Server] Failed to unload LM: {e}") + + # LLM is REQUIRED for these features (fail if unavailable): + # - thinking mode (LM generates audio codes) + # - sample_mode (LM generates random caption/lyrics/metas) + # - sample_query/description (LM generates from description) + # - use_format (LM enhances caption/lyrics) + # - full_analysis_only (LM understands audio codes) + require_llm = thinking or sample_mode or has_sample_query or use_format or full_analysis_only + + # LLM is OPTIONAL for these features (auto-disable if unavailable): + # - use_cot_caption or use_cot_language (LM enhances metadata) + want_llm = use_cot_caption or use_cot_language + + # Check if LLM is available + llm_available = True + if require_llm or want_llm: + _ensure_llm_ready() + if getattr(app.state, "_llm_init_error", None): + llm_available = False + + # Fail if LLM is required but unavailable + if require_llm and not llm_available: + raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}") + + # Auto-disable optional LLM features if unavailable + if want_llm and not llm_available: + if use_cot_caption or use_cot_language: + print(f"[API Server] LLM unavailable, auto-disabling: use_cot_caption={use_cot_caption}->False, use_cot_language={use_cot_language}->False") + use_cot_caption = False + use_cot_language = False + + # Handle sample mode or description: generate caption/lyrics/metas via LM + caption = req.prompt + lyrics = req.lyrics + bpm = req.bpm + key_scale = req.key_scale + time_signature = req.time_signature + audio_duration = req.audio_duration + + # Save original user input for metas + original_prompt = req.prompt or "" + original_lyrics = req.lyrics or "" + + if sample_mode or has_sample_query: + # Parse description hints from sample_query (if provided) + sample_query = req.sample_query if has_sample_query else "NO USER INPUT" + parsed_language, parsed_instrumental = _parse_description_hints(sample_query) + + # Determine vocal_language with priority: + # 1. User-specified vocal_language (if not default "en") + # 2. Language parsed from description + # 3. 
None (no constraint) + if req.vocal_language and req.vocal_language not in ("en", "unknown", ""): + sample_language = req.vocal_language + else: + sample_language = parsed_language + + sample_result = create_sample( + llm_handler=llm, + query=sample_query, + instrumental=parsed_instrumental, + vocal_language=sample_language, + temperature=req.lm_temperature, + top_k=lm_top_k if lm_top_k > 0 else None, + top_p=lm_top_p if lm_top_p < 1.0 else None, + use_constrained_decoding=True, + ) + + if not sample_result.success: + raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}") + + # Use generated sample data + caption = sample_result.caption + lyrics = sample_result.lyrics + bpm = sample_result.bpm + key_scale = sample_result.keyscale + time_signature = sample_result.timesignature + audio_duration = sample_result.duration + + # Apply format_sample() if use_format is True and caption/lyrics are provided + format_has_duration = False + + if req.use_format and (caption or lyrics): + _ensure_llm_ready() + if getattr(app.state, "_llm_init_error", None): + raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}") + + # Build user_metadata from request params (matching bot.py behavior) + user_metadata_for_format = {} + if bpm is not None: + user_metadata_for_format['bpm'] = bpm + if audio_duration is not None and float(audio_duration) > 0: + user_metadata_for_format['duration'] = float(audio_duration) + if key_scale: + user_metadata_for_format['keyscale'] = key_scale + if time_signature: + user_metadata_for_format['timesignature'] = time_signature + if req.vocal_language and req.vocal_language != "unknown": + user_metadata_for_format['language'] = req.vocal_language + + format_result = format_sample( + llm_handler=llm, + caption=caption, + lyrics=lyrics, + user_metadata=user_metadata_for_format if user_metadata_for_format else None, + temperature=req.lm_temperature, + top_k=lm_top_k if lm_top_k > 0 else None, + top_p=lm_top_p if lm_top_p < 1.0 else None, + use_constrained_decoding=True, + ) + + if format_result.success: + # Extract all formatted data (matching bot.py behavior) + caption = format_result.caption or caption + lyrics = format_result.lyrics or lyrics + if format_result.duration: + audio_duration = format_result.duration + format_has_duration = True + if format_result.bpm: + bpm = format_result.bpm + if format_result.keyscale: + key_scale = format_result.keyscale + if format_result.timesignature: + time_signature = format_result.timesignature + + # Parse timesteps string to list of floats if provided + parsed_timesteps = _parse_timesteps(req.timesteps) + + # Determine actual inference steps (timesteps override inference_steps) + actual_inference_steps = len(parsed_timesteps) if parsed_timesteps else req.inference_steps + + # Auto-select instruction based on task_type if user didn't provide custom instruction + # This matches gradio behavior which uses TASK_INSTRUCTIONS for each task type + instruction_to_use = req.instruction + if instruction_to_use == DEFAULT_DIT_INSTRUCTION and req.task_type in TASK_INSTRUCTIONS: + instruction_to_use = TASK_INSTRUCTIONS[req.task_type] + + # Build GenerationParams using unified interface + # Note: thinking controls LM code generation, sample_mode only affects CoT metas + params = GenerationParams( + task_type=req.task_type, + instruction=instruction_to_use, + reference_audio=req.reference_audio_path, + src_audio=req.src_audio_path, + audio_codes="", + caption=caption, + 
lyrics=lyrics, + instrumental=_is_instrumental(lyrics), + vocal_language=req.vocal_language, + bpm=bpm, + keyscale=key_scale, + timesignature=time_signature, + duration=audio_duration if audio_duration else -1.0, + inference_steps=req.inference_steps, + seed=req.seed, + guidance_scale=req.guidance_scale, + use_adg=req.use_adg, + cfg_interval_start=req.cfg_interval_start, + cfg_interval_end=req.cfg_interval_end, + shift=req.shift, + infer_method=req.infer_method, + timesteps=parsed_timesteps, + repainting_start=req.repainting_start, + repainting_end=req.repainting_end if req.repainting_end else -1, + audio_cover_strength=req.audio_cover_strength, + # LM parameters + thinking=thinking, # Use LM for code generation when thinking=True + lm_temperature=req.lm_temperature, + lm_cfg_scale=req.lm_cfg_scale, + lm_top_k=lm_top_k, + lm_top_p=lm_top_p, + lm_negative_prompt=req.lm_negative_prompt, + # use_cot_metas logic: + # - sample_mode: metas already generated, skip Phase 1 + # - format with duration: metas already generated, skip Phase 1 + # - format without duration: need Phase 1 to generate duration + # - no format: need Phase 1 to generate all metas + use_cot_metas=not sample_mode and not format_has_duration, + use_cot_caption=use_cot_caption, # Use local var (may be auto-disabled) + use_cot_language=use_cot_language, # Use local var (may be auto-disabled) + use_constrained_decoding=True, + ) + + # Build GenerationConfig - default to 2 audios like gradio_ui + batch_size = req.batch_size if req.batch_size is not None else 2 + config = GenerationConfig( + batch_size=batch_size, + allow_lm_batch=req.allow_lm_batch, + use_random_seed=req.use_random_seed, + seeds=None, # Let unified logic handle seed generation + audio_format=req.audio_format, + constrained_decoding_debug=req.constrained_decoding_debug, + ) + + # Check LLM initialization status + llm_is_initialized = getattr(app.state, "_llm_initialized", False) + llm_to_pass = llm if llm_is_initialized else None + + # Progress callback for API polling + last_progress = {"value": -1.0, "time": 0.0, "stage": ""} + + def _progress_cb(value: float, desc: str = "") -> None: + now = time.time() + try: + value_f = max(0.0, min(1.0, float(value))) + except Exception: + value_f = 0.0 + stage = desc or last_progress["stage"] or "running" + # Throttle updates to avoid excessive cache writes + if ( + value_f - last_progress["value"] >= 0.01 + or stage != last_progress["stage"] + or (now - last_progress["time"]) >= 0.5 + ): + last_progress["value"] = value_f + last_progress["time"] = now + last_progress["stage"] = stage + job_store.update_progress(job_id, value_f, stage=stage) + _update_local_cache_progress(job_id, value_f, stage) + + if req.full_analysis_only: + store.update_progress_text(job_id, "Starting Deep Analysis...") + # Step A: Convert source audio to semantic codes + # We use params.src_audio which is the server-side path + audio_codes = h.convert_src_audio_to_codes(params.src_audio) + + if not audio_codes or audio_codes.startswith("❌"): + raise RuntimeError(f"Audio encoding failed: {audio_codes}") + + # Step B: LLM Understanding of those specific codes + # This yields the deep metadata and lyrics transcription + metadata_dict, status_string = llm_to_pass.understand_audio_from_codes( + audio_codes=audio_codes, + temperature=0.3, + use_constrained_decoding=True, + constrained_decoding_debug=config.constrained_decoding_debug + ) + + if not metadata_dict: + raise RuntimeError(f"LLM Understanding failed: {status_string}") + + return { + "status_message": 
"Full Hardware Analysis Success", + "bpm": metadata_dict.get("bpm"), + "keyscale": metadata_dict.get("keyscale"), + "timesignature": metadata_dict.get("timesignature"), + "duration": metadata_dict.get("duration"), + "genre": metadata_dict.get("genres") or metadata_dict.get("genre"), + "prompt": metadata_dict.get("caption", ""), + "lyrics": metadata_dict.get("lyrics", ""), + "language": metadata_dict.get("language", "unknown"), + "metas": metadata_dict, + "audio_paths": [] + } + + if req.analysis_only: + lm_res = llm_to_pass.generate_with_stop_condition( + caption=params.caption, + lyrics=params.lyrics, + infer_type="dit", + temperature=req.lm_temperature, + top_p=req.lm_top_p, + use_cot_metas=True, + use_cot_caption=req.use_cot_caption, + use_cot_language=req.use_cot_language, + use_constrained_decoding=True + ) + + if not lm_res.get("success"): + raise RuntimeError(f"Analysis Failed: {lm_res.get('error')}") + + metas_found = lm_res.get("metadata", {}) + return { + "first_audio_path": None, + "audio_paths": [], + "raw_audio_paths": [], + "generation_info": "Analysis Only Mode Complete", + "status_message": "Success", + "metas": metas_found, + "bpm": metas_found.get("bpm"), + "keyscale": metas_found.get("keyscale"), + "duration": metas_found.get("duration"), + "prompt": metas_found.get("caption", params.caption), + "lyrics": params.lyrics, + "lm_model": os.getenv("ACESTEP_LM_MODEL_PATH", ""), + "dit_model": "None (Analysis Only)" + } + + # Generate music using unified interface + sequential_runs = 1 + if req.task_type == "cover" and h.device == "mps": + # If user asked for multiple outputs, run sequentially on MPS to avoid OOM. + if config.batch_size is not None and config.batch_size > 1: + sequential_runs = int(config.batch_size) + config.batch_size = 1 + print(f"[API Server] Job {job_id}: MPS cover sequential mode enabled (runs={sequential_runs})") + + def _progress_for_slice(start: float, end: float): + base = {"seen": False, "value": 0.0} + def _cb(value: float, desc: str = "") -> None: + try: + value_f = max(0.0, min(1.0, float(value))) + except Exception: + value_f = 0.0 + if not base["seen"]: + base["seen"] = True + base["value"] = value_f + # Normalize progress to avoid initial jump (e.g., 0.51 -> 0.0) + if value_f <= base["value"]: + norm = 0.0 + else: + denom = max(1e-6, 1.0 - base["value"]) + norm = min(1.0, (value_f - base["value"]) / denom) + mapped = start + (end - start) * norm + _progress_cb(mapped, desc=desc) + return _cb + + aggregated_result = None + all_audios: List[Dict[str, Any]] = [] + for run_idx in range(sequential_runs): + if sequential_runs > 1: + print(f"[API Server] Job {job_id}: Sequential cover run {run_idx + 1}/{sequential_runs}") + if sequential_runs > 1: + start = run_idx / sequential_runs + end = (run_idx + 1) / sequential_runs + progress_cb = _progress_for_slice(start, end) + else: + progress_cb = _progress_cb + + result = generate_music( + dit_handler=h, + llm_handler=llm_to_pass, + params=params, + config=config, + save_dir=app.state.temp_audio_dir, + progress=progress_cb, + ) + if not result.success: + raise RuntimeError(f"Music generation failed: {result.error or result.status_message}") + + if aggregated_result is None: + aggregated_result = result + all_audios.extend(result.audios) + + # Use aggregated result with combined audios + if aggregated_result is None: + raise RuntimeError("Music generation failed: no results") + aggregated_result.audios = all_audios + result = aggregated_result + + if not result.success: + raise RuntimeError(f"Music 
generation failed: {result.error or result.status_message}") + + # Extract results + audio_paths = [audio["path"] for audio in result.audios if audio.get("path")] + first_audio = audio_paths[0] if len(audio_paths) > 0 else None + second_audio = audio_paths[1] if len(audio_paths) > 1 else None + + # Get metadata from LM or CoT results + lm_metadata = result.extra_outputs.get("lm_metadata", {}) + metas_out = _normalize_metas(lm_metadata) + + # Update metas with actual values used + if params.cot_bpm: + metas_out["bpm"] = params.cot_bpm + elif bpm: + metas_out["bpm"] = bpm + + if params.cot_duration: + metas_out["duration"] = params.cot_duration + elif audio_duration: + metas_out["duration"] = audio_duration + + if params.cot_keyscale: + metas_out["keyscale"] = params.cot_keyscale + elif key_scale: + metas_out["keyscale"] = key_scale + + if params.cot_timesignature: + metas_out["timesignature"] = params.cot_timesignature + elif time_signature: + metas_out["timesignature"] = time_signature + + # Store original user input in metas (not the final/modified values) + metas_out["prompt"] = original_prompt + metas_out["lyrics"] = original_lyrics + + # Extract seed values for response (comma-separated for multiple audios) + seed_values = [] + for audio in result.audios: + audio_params = audio.get("params", {}) + seed = audio_params.get("seed") + if seed is not None: + seed_values.append(str(seed)) + seed_value = ",".join(seed_values) if seed_values else "" + + # Build generation_info using the helper function (like gradio_ui) + time_costs = result.extra_outputs.get("time_costs", {}) + generation_info = _build_generation_info( + lm_metadata=lm_metadata, + time_costs=time_costs, + seed_value=seed_value, + inference_steps=req.inference_steps, + num_audios=len(result.audios), + ) + + def _none_if_na_str(v: Any) -> Optional[str]: + if v is None: + return None + s = str(v).strip() + if s in {"", "N/A"}: + return None + return s + + # Get model information + lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B") + # Use selected_model_name (set at the beginning of _run_one_job) + dit_model_name = selected_model_name + + return { + "first_audio_path": _path_to_audio_url(first_audio) if first_audio else None, + "second_audio_path": _path_to_audio_url(second_audio) if second_audio else None, + "audio_paths": [_path_to_audio_url(p) for p in audio_paths], + "raw_audio_paths": list(audio_paths), + "generation_info": generation_info, + "status_message": result.status_message, + "seed_value": seed_value, + # Final prompt/lyrics (may be modified by thinking/format) + "prompt": caption or "", + "lyrics": lyrics or "", + # metas contains original user input + other metadata + "metas": metas_out, + "bpm": metas_out.get("bpm") if isinstance(metas_out.get("bpm"), int) else None, + "duration": metas_out.get("duration") if isinstance(metas_out.get("duration"), (int, float)) else None, + "genres": _none_if_na_str(metas_out.get("genres")), + "keyscale": _none_if_na_str(metas_out.get("keyscale")), + "timesignature": _none_if_na_str(metas_out.get("timesignature")), + "lm_model": lm_model_name, + "dit_model": dit_model_name, + } + + t0 = time.time() + try: + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(executor, _blocking_generate) + job_store.mark_succeeded(job_id, result) + + # Update local cache + _update_local_cache(job_id, result, "succeeded") + except Exception as e: + error_traceback = traceback.format_exc() + print(f"[API Server] Job {job_id} FAILED: {e}") + print(f"[API 
Server] Traceback:\n{error_traceback}") + job_store.mark_failed(job_id, error_traceback) + + # Update local cache + _update_local_cache(job_id, None, "failed") + finally: + # Best-effort cache cleanup to reduce MPS memory fragmentation between jobs + try: + if hasattr(h, "_empty_cache"): + h._empty_cache() + else: + import torch + if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): + torch.mps.empty_cache() + except Exception: + pass + dt = max(0.0, time.time() - t0) + async with app.state.stats_lock: + app.state.recent_durations.append(dt) + if app.state.recent_durations: + app.state.avg_job_seconds = sum(app.state.recent_durations) / len(app.state.recent_durations) + + async def _queue_worker(worker_idx: int) -> None: + while True: + job_id, req = await app.state.job_queue.get() + rec = store.get(job_id) + try: + async with app.state.pending_lock: + try: + app.state.pending_ids.remove(job_id) + except ValueError: + pass + + await _run_one_job(job_id, req) + + # Notify OpenRouter waiters after job completion + if rec and rec.progress_queue: + if rec.status == "succeeded" and rec.result: + await rec.progress_queue.put({"type": "result", "result": rec.result}) + elif rec.status == "failed": + await rec.progress_queue.put({"type": "error", "content": rec.error or "Generation failed"}) + await rec.progress_queue.put({"type": "done"}) + if rec and rec.done_event: + rec.done_event.set() + + except Exception as exc: + # _run_one_job raised (e.g. _ensure_initialized failed) + if rec and rec.status not in ("succeeded", "failed"): + store.mark_failed(job_id, str(exc)) + if rec and rec.progress_queue: + await rec.progress_queue.put({"type": "error", "content": str(exc)}) + await rec.progress_queue.put({"type": "done"}) + if rec and rec.done_event: + rec.done_event.set() + finally: + await _cleanup_job_temp_files(job_id) + app.state.job_queue.task_done() + + async def _job_store_cleanup_worker() -> None: + """Background task to periodically clean up old completed jobs.""" + while True: + try: + await asyncio.sleep(JOB_STORE_CLEANUP_INTERVAL) + removed = store.cleanup_old_jobs() + if removed > 0: + stats = store.get_stats() + print(f"[API Server] Cleaned up {removed} old jobs. 
Current stats: {stats}") + except asyncio.CancelledError: + break + except Exception as e: + print(f"[API Server] Job cleanup error: {e}") + + worker_count = max(1, WORKER_COUNT) + workers = [asyncio.create_task(_queue_worker(i)) for i in range(worker_count)] + cleanup_task = asyncio.create_task(_job_store_cleanup_worker()) + app.state.worker_tasks = workers + app.state.cleanup_task = cleanup_task + + # ================================================================= + # Initialize models at startup (not lazily on first request) + # ================================================================= + print("[API Server] Initializing models at startup...") + + # Detect GPU memory and get configuration + gpu_config = get_gpu_config() + set_global_gpu_config(gpu_config) + app.state.gpu_config = gpu_config + + gpu_memory_gb = gpu_config.gpu_memory_gb + auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < VRAM_16GB_MIN_GB + + # Print GPU configuration info + print(f"\n{'='*60}") + print("[API Server] GPU Configuration Detected:") + print(f"{'='*60}") + print(f" GPU Memory: {gpu_memory_gb:.2f} GB") + print(f" Configuration Tier: {gpu_config.tier}") + print(f" Max Duration (with LM): {gpu_config.max_duration_with_lm}s") + print(f" Max Duration (without LM): {gpu_config.max_duration_without_lm}s") + print(f" Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}") + print(f" Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}") + print(f" Default LM Init: {gpu_config.init_lm_default}") + print(f" Available LM Models: {gpu_config.available_lm_models or 'None'}") + print(f"{'='*60}\n") + + if auto_offload: + print(f"[API Server] Auto-enabling CPU offload (GPU < 16GB)") + elif gpu_memory_gb > 0: + print(f"[API Server] CPU offload disabled by default (GPU >= 16GB)") + else: + print("[API Server] No GPU detected, running on CPU") + + project_root = _get_project_root() + config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo") + device = os.getenv("ACESTEP_DEVICE", "auto") + use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True) + + # Auto-determine offload settings based on GPU config if not explicitly set + offload_to_cpu_env = os.getenv("ACESTEP_OFFLOAD_TO_CPU") + if offload_to_cpu_env is not None: + offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False) + else: + offload_to_cpu = auto_offload + if auto_offload: + print(f"[API Server] Auto-setting offload_to_cpu=True based on GPU memory") + + offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False) + + # Checkpoint directory + checkpoint_dir = os.path.join(project_root, "checkpoints") + os.makedirs(checkpoint_dir, exist_ok=True) + + # Download and initialize primary DiT model + dit_model_name = _get_model_name(config_path) + if dit_model_name: + try: + _ensure_model_downloaded(dit_model_name, checkpoint_dir) + except Exception as e: + print(f"[API Server] Warning: Failed to download DiT model: {e}") + + # Download VAE model + try: + _ensure_model_downloaded("vae", checkpoint_dir) + except Exception as e: + print(f"[API Server] Warning: Failed to download VAE model: {e}") + + print(f"[API Server] Loading primary DiT model: {config_path}") + status_msg, ok = handler.initialize_service( + project_root=project_root, + config_path=config_path, + device=device, + use_flash_attention=use_flash_attention, + compile_model=False, + offload_to_cpu=offload_to_cpu, + offload_dit_to_cpu=offload_dit_to_cpu, + ) + if not ok: + app.state._init_error = status_msg + print(f"[API Server] ERROR: Primary 
model failed to load: {status_msg}") + raise RuntimeError(status_msg) + app.state._initialized = True + print(f"[API Server] Primary model loaded: {_get_model_name(config_path)}") + + # Initialize secondary model if configured + if handler2 and config_path2: + model2_name = _get_model_name(config_path2) + if model2_name: + try: + _ensure_model_downloaded(model2_name, checkpoint_dir) + except Exception as e: + print(f"[API Server] Warning: Failed to download secondary model: {e}") + + print(f"[API Server] Loading secondary DiT model: {config_path2}") + try: + status_msg2, ok2 = handler2.initialize_service( + project_root=project_root, + config_path=config_path2, + device=device, + use_flash_attention=use_flash_attention, + compile_model=False, + offload_to_cpu=offload_to_cpu, + offload_dit_to_cpu=offload_dit_to_cpu, + ) + app.state._initialized2 = ok2 + if ok2: + print(f"[API Server] Secondary model loaded: {model2_name}") + else: + print(f"[API Server] Warning: Secondary model failed: {status_msg2}") + except Exception as e: + print(f"[API Server] Warning: Failed to initialize secondary model: {e}") + app.state._initialized2 = False + + # Initialize third model if configured + if handler3 and config_path3: + model3_name = _get_model_name(config_path3) + if model3_name: + try: + _ensure_model_downloaded(model3_name, checkpoint_dir) + except Exception as e: + print(f"[API Server] Warning: Failed to download third model: {e}") + + print(f"[API Server] Loading third DiT model: {config_path3}") + try: + status_msg3, ok3 = handler3.initialize_service( + project_root=project_root, + config_path=config_path3, + device=device, + use_flash_attention=use_flash_attention, + compile_model=False, + offload_to_cpu=offload_to_cpu, + offload_dit_to_cpu=offload_dit_to_cpu, + ) + app.state._initialized3 = ok3 + if ok3: + print(f"[API Server] Third model loaded: {model3_name}") + else: + print(f"[API Server] Warning: Third model failed: {status_msg3}") + except Exception as e: + print(f"[API Server] Warning: Failed to initialize third model: {e}") + app.state._initialized3 = False + + # Initialize LLM model based on GPU configuration + # ACESTEP_INIT_LLM controls LLM initialization: + # - "auto" / empty / not set: Use GPU config default (auto-detect) + # - "true"/"1"/"yes": Force enable LLM after GPU config is applied + # - "false"/"0"/"no": Force disable LLM + # + # Flow: GPU detection → model validation → ACESTEP_INIT_LLM override + # This ensures GPU optimizations (offload, quantization, etc.) are always applied. 
+ init_llm_env = os.getenv("ACESTEP_INIT_LLM", "").strip().lower() + + # Step 1: Start with GPU auto-detection result + init_llm = gpu_config.init_lm_default + print(f"[API Server] GPU auto-detection: init_llm={init_llm} (VRAM: {gpu_config.gpu_memory_gb:.1f}GB, tier: {gpu_config.tier})") + + # Step 2: Apply user override if set + if not init_llm_env or init_llm_env == "auto": + print(f"[API Server] ACESTEP_INIT_LLM=auto, using GPU auto-detection result") + elif init_llm_env in {"1", "true", "yes", "y", "on"}: + if init_llm: + print(f"[API Server] ACESTEP_INIT_LLM=true (GPU already supports LLM, no override needed)") + else: + init_llm = True + print(f"[API Server] ACESTEP_INIT_LLM=true, overriding GPU auto-detection (force enable)") + else: + if not init_llm: + print(f"[API Server] ACESTEP_INIT_LLM=false (GPU already disabled LLM, no override needed)") + else: + init_llm = False + print(f"[API Server] ACESTEP_INIT_LLM=false, overriding GPU auto-detection (force disable)") + + if init_llm: + print("[API Server] Loading LLM model...") + + # Auto-select LM model based on GPU config if not explicitly set + lm_model_path_env = os.getenv("ACESTEP_LM_MODEL_PATH", "").strip() + if lm_model_path_env: + lm_model_path = lm_model_path_env + print(f"[API Server] Using user-specified LM model: {lm_model_path}") + else: + # Get recommended LM model for this GPU tier + recommended_lm = get_recommended_lm_model(gpu_config) + if recommended_lm: + lm_model_path = recommended_lm + print(f"[API Server] Auto-selected LM model: {lm_model_path} based on GPU tier") + else: + # No recommended model (GPU tier too low), default to smallest + lm_model_path = "acestep-5Hz-lm-0.6B" + print(f"[API Server] No recommended model for this GPU tier, using smallest: {lm_model_path}") + + # Validate LM model support (warning only, does not block) + is_supported, warning_msg = is_lm_model_supported(lm_model_path, gpu_config) + if not is_supported: + print(f"[API Server] Warning: {warning_msg}") + # Try to fall back to a supported model + recommended_lm = get_recommended_lm_model(gpu_config) + if recommended_lm: + lm_model_path = recommended_lm + print(f"[API Server] Falling back to supported LM model: {lm_model_path}") + else: + # No supported model, but user may have forced init + print(f"[API Server] No GPU-validated LM model available, attempting {lm_model_path} anyway (may cause OOM)") + + if init_llm: + lm_backend = os.getenv("ACESTEP_LM_BACKEND", "vllm").strip().lower() + if lm_backend not in {"vllm", "pt", "mlx"}: + lm_backend = "vllm" + lm_device = os.getenv("ACESTEP_LM_DEVICE", device) + + # Auto-determine LM offload based on GPU config + lm_offload_env = os.getenv("ACESTEP_LM_OFFLOAD_TO_CPU") + if lm_offload_env is not None: + lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False) + else: + lm_offload = offload_to_cpu + + try: + _ensure_model_downloaded(lm_model_path, checkpoint_dir) + except Exception as e: + print(f"[API Server] Warning: Failed to download LLM model: {e}") + + llm_status, llm_ok = llm_handler.initialize( + checkpoint_dir=checkpoint_dir, + lm_model_path=lm_model_path, + backend=lm_backend, + device=lm_device, + offload_to_cpu=lm_offload, + dtype=None, + ) + if llm_ok: + app.state._llm_initialized = True + print(f"[API Server] LLM model loaded: {lm_model_path}") + else: + app.state._llm_init_error = llm_status + print(f"[API Server] Warning: LLM model failed to load: {llm_status}") + else: + print("[API Server] Skipping LLM initialization (disabled or not supported for this GPU)") + 
app.state._llm_initialized = False + # Disable lazy loading of LLM - don't try to load it later during requests + app.state._llm_lazy_load_disabled = True + print("[API Server] LLM lazy loading disabled. To enable LLM:") + print("[API Server] - Set ACESTEP_INIT_LLM=true in .env or environment") + print("[API Server] - Or use --init-llm command line flag") + + print("[API Server] All models initialized successfully!") + + try: + yield + finally: + cleanup_task.cancel() + for t in workers: + t.cancel() + executor.shutdown(wait=False, cancel_futures=True) + + app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan) + + # Mount OpenRouter-compatible endpoints (/v1/chat/completions, /v1/models) + from acestep.openrouter_adapter import create_openrouter_router + openrouter_router = create_openrouter_router(lambda: app.state) + app.include_router(openrouter_router) + + async def _queue_position(job_id: str) -> int: + async with app.state.pending_lock: + try: + return list(app.state.pending_ids).index(job_id) + 1 + except ValueError: + return 0 + + async def _eta_seconds_for_position(pos: int) -> Optional[float]: + if pos <= 0: + return None + async with app.state.stats_lock: + avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS)) + return pos * avg + + @app.post("/release_task") + async def create_music_generate_job(request: Request, authorization: Optional[str] = Header(None)): + content_type = (request.headers.get("content-type") or "").lower() + temp_files: list[str] = [] + + def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest: + """Build GenerateMusicRequest from parsed parameters.""" + return GenerateMusicRequest( + prompt=p.str("prompt"), + lyrics=p.str("lyrics"), + thinking=p.bool("thinking"), + analysis_only=p.bool("analysis_only"), + full_analysis_only=p.bool("full_analysis_only"), + sample_mode=p.bool("sample_mode"), + sample_query=p.str("sample_query"), + use_format=p.bool("use_format"), + model=p.str("model") or None, + bpm=p.int("bpm"), + key_scale=p.str("key_scale"), + time_signature=p.str("time_signature"), + audio_duration=p.float("audio_duration"), + vocal_language=p.str("vocal_language", "en"), + inference_steps=p.int("inference_steps", 8), + guidance_scale=p.float("guidance_scale", 7.0), + use_random_seed=p.bool("use_random_seed", True), + seed=p.int("seed", -1), + batch_size=p.int("batch_size"), + repainting_start=p.float("repainting_start", 0.0), + repainting_end=p.float("repainting_end"), + instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION), + audio_cover_strength=p.float("audio_cover_strength", 1.0), + reference_audio_path=p.str("reference_audio_path") or None, + src_audio_path=p.str("src_audio_path") or None, + task_type=p.str("task_type", "text2music"), + use_adg=p.bool("use_adg"), + cfg_interval_start=p.float("cfg_interval_start", 0.0), + cfg_interval_end=p.float("cfg_interval_end", 1.0), + infer_method=p.str("infer_method", "ode"), + shift=p.float("shift", 3.0), + audio_format=p.str("audio_format", "mp3"), + use_tiled_decode=p.bool("use_tiled_decode", True), + lm_model_path=p.str("lm_model_path") or None, + lm_backend=p.str("lm_backend", "vllm"), + lm_temperature=p.float("lm_temperature", LM_DEFAULT_TEMPERATURE), + lm_cfg_scale=p.float("lm_cfg_scale", LM_DEFAULT_CFG_SCALE), + lm_top_k=p.int("lm_top_k"), + lm_top_p=p.float("lm_top_p", LM_DEFAULT_TOP_P), + lm_repetition_penalty=p.float("lm_repetition_penalty", 1.0), + lm_negative_prompt=p.str("lm_negative_prompt", "NO USER INPUT"), + 
constrained_decoding=p.bool("constrained_decoding", True), + constrained_decoding_debug=p.bool("constrained_decoding_debug"), + use_cot_caption=p.bool("use_cot_caption", True), + use_cot_language=p.bool("use_cot_language", True), + is_format_caption=p.bool("is_format_caption"), + allow_lm_batch=p.bool("allow_lm_batch", True), + **kwargs, + ) + + if content_type.startswith("application/json"): + body = await request.json() + if not isinstance(body, dict): + raise HTTPException(status_code=400, detail="JSON payload must be an object") + verify_token_from_request(body, authorization) + req = _build_request(RequestParser(body)) + + elif content_type.endswith("+json"): + body = await request.json() + if not isinstance(body, dict): + raise HTTPException(status_code=400, detail="JSON payload must be an object") + verify_token_from_request(body, authorization) + req = _build_request(RequestParser(body)) + + elif content_type.startswith("multipart/form-data"): + form = await request.form() + form_dict = {k: v for k, v in form.items() if not hasattr(v, 'read')} + verify_token_from_request(form_dict, authorization) + + # Support both naming conventions: ref_audio/reference_audio, ctx_audio/src_audio + ref_up = form.get("ref_audio") or form.get("reference_audio") + ctx_up = form.get("ctx_audio") or form.get("src_audio") + + reference_audio_path = None + src_audio_path = None + + if isinstance(ref_up, StarletteUploadFile): + reference_audio_path = await _save_upload_to_temp(ref_up, prefix="ref_audio") + temp_files.append(reference_audio_path) + else: + reference_audio_path = str(form.get("ref_audio_path") or form.get("reference_audio_path") or "").strip() or None + + if isinstance(ctx_up, StarletteUploadFile): + src_audio_path = await _save_upload_to_temp(ctx_up, prefix="ctx_audio") + temp_files.append(src_audio_path) + else: + src_audio_path = str(form.get("ctx_audio_path") or form.get("src_audio_path") or "").strip() or None + + req = _build_request( + RequestParser(dict(form)), + reference_audio_path=reference_audio_path, + src_audio_path=src_audio_path, + ) + + elif content_type.startswith("application/x-www-form-urlencoded"): + form = await request.form() + form_dict = dict(form) + verify_token_from_request(form_dict, authorization) + reference_audio_path = str(form.get("ref_audio_path") or form.get("reference_audio_path") or "").strip() or None + src_audio_path = str(form.get("ctx_audio_path") or form.get("src_audio_path") or "").strip() or None + req = _build_request( + RequestParser(form_dict), + reference_audio_path=reference_audio_path, + src_audio_path=src_audio_path, + ) + + else: + raw = await request.body() + raw_stripped = raw.lstrip() + # Best-effort: accept missing/incorrect Content-Type if payload is valid JSON. + if raw_stripped.startswith(b"{") or raw_stripped.startswith(b"["): + try: + body = json.loads(raw.decode("utf-8")) + if isinstance(body, dict): + verify_token_from_request(body, authorization) + req = _build_request(RequestParser(body)) + else: + raise HTTPException(status_code=400, detail="JSON payload must be an object") + except HTTPException: + raise + except Exception: + raise HTTPException( + status_code=400, + detail="Invalid JSON body (hint: set 'Content-Type: application/json')", + ) + # Best-effort: parse key=value bodies even if Content-Type is missing. 
+ elif raw_stripped and b"=" in raw: + parsed = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True) + flat = {k: (v[0] if isinstance(v, list) and v else v) for k, v in parsed.items()} + verify_token_from_request(flat, authorization) + reference_audio_path = str(flat.get("ref_audio_path") or flat.get("reference_audio_path") or "").strip() or None + src_audio_path = str(flat.get("ctx_audio_path") or flat.get("src_audio_path") or "").strip() or None + req = _build_request( + RequestParser(flat), + reference_audio_path=reference_audio_path, + src_audio_path=src_audio_path, + ) + else: + raise HTTPException( + status_code=415, + detail=( + f"Unsupported Content-Type: {content_type or '(missing)'}; " + "use application/json, application/x-www-form-urlencoded, or multipart/form-data" + ), + ) + + rec = store.create() + + q: asyncio.Queue = app.state.job_queue + if q.full(): + for p in temp_files: + try: + os.remove(p) + except Exception: + pass + raise HTTPException(status_code=429, detail="Server busy: queue is full") + + if temp_files: + async with app.state.job_temp_files_lock: + app.state.job_temp_files[rec.job_id] = temp_files + + async with app.state.pending_lock: + app.state.pending_ids.append(rec.job_id) + position = len(app.state.pending_ids) + + await q.put((rec.job_id, req)) + return _wrap_response({"task_id": rec.job_id, "status": "queued", "queue_position": position}) + + @app.post("/query_result") + async def query_result(request: Request, authorization: Optional[str] = Header(None)): + """Batch query job results""" + content_type = (request.headers.get("content-type") or "").lower() + + if "json" in content_type: + body = await request.json() + else: + form = await request.form() + body = {k: v for k, v in form.items()} + + verify_token_from_request(body, authorization) + task_id_list_str = body.get("task_id_list", "[]") + + # Parse task ID list + if isinstance(task_id_list_str, list): + task_id_list = task_id_list_str + else: + try: + task_id_list = json.loads(task_id_list_str) + except Exception: + task_id_list = [] + + local_cache = getattr(app.state, 'local_cache', None) + data_list = [] + current_time = time.time() + + for task_id in task_id_list: + result_key = f"{RESULT_KEY_PREFIX}{task_id}" + + # Read from local cache first + if local_cache: + data = local_cache.get(result_key) + if data: + try: + data_json = json.loads(data) + except Exception: + data_json = [] + + if len(data_json) <= 0: + data_list.append({"task_id": task_id, "result": data, "status": 2}) + else: + status = data_json[0].get("status") + create_time = data_json[0].get("create_time", 0) + if status == 0 and (current_time - create_time) > TASK_TIMEOUT_SECONDS: + data_list.append({"task_id": task_id, "result": data, "status": 2}) + else: + data_list.append({ + "task_id": task_id, + "result": data, + "status": int(status) if status is not None else 1, + "progress_text": log_buffer.last_message + }) + continue + + # Fallback to job_store query + rec = store.get(task_id) + if rec: + env = getattr(rec, 'env', 'development') + create_time = rec.created_at + status_int = _map_status(rec.status) + + if rec.result and rec.status == "succeeded": + # Check if it's a "Full Analysis" result + if rec.result.get("status_message") == "Full Hardware Analysis Success": + result_data = [rec.result] + else: + audio_paths = rec.result.get("audio_paths", []) + metas = rec.result.get("metas", {}) or {} + result_data = [ + { + "file": p, "wave": "", "status": status_int, + "create_time": int(create_time), "env": env, 
+ "prompt": metas.get("caption", ""), + "lyrics": metas.get("lyrics", ""), + "metas": { + "bpm": metas.get("bpm"), + "duration": metas.get("duration"), + "genres": metas.get("genres", ""), + "keyscale": metas.get("keyscale", ""), + "timesignature": metas.get("timesignature", ""), + } + } + for p in audio_paths + ] if audio_paths else [{ + "file": "", "wave": "", "status": status_int, + "create_time": int(create_time), "env": env, + "prompt": metas.get("caption", ""), + "lyrics": metas.get("lyrics", ""), + "metas": { + "bpm": metas.get("bpm"), + "duration": metas.get("duration"), + "genres": metas.get("genres", ""), + "keyscale": metas.get("keyscale", ""), + "timesignature": metas.get("timesignature", ""), + } + }] + else: + result_data = [{ + "file": "", "wave": "", "status": status_int, + "create_time": int(create_time), "env": env, + "prompt": "", "lyrics": "", + "metas": {}, + "progress": float(rec.progress) if rec else 0.0, + "stage": rec.stage if rec else "queued", + "error": rec.error if rec.error else None, + }] + + current_log = log_buffer.last_message if status_int == 0 else rec.progress_text + data_list.append({ + "task_id": task_id, + "result": json.dumps(result_data, ensure_ascii=False), + "status": status_int, + "progress_text": current_log + }) + else: + data_list.append({"task_id": task_id, "result": "[]", "status": 0}) + + return _wrap_response(data_list) + + @app.get("/health") + async def health_check(): + """Health check endpoint for service status.""" + return _wrap_response({ + "status": "ok", + "service": "ACE-Step API", + "version": "1.0", + }) + + @app.get("/v1/stats") + async def get_stats(_: None = Depends(verify_api_key)): + """Get server statistics including job store stats.""" + job_stats = store.get_stats() + async with app.state.stats_lock: + avg_job_seconds = getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS) + return _wrap_response({ + "jobs": job_stats, + "queue_size": app.state.job_queue.qsize(), + "queue_maxsize": QUEUE_MAXSIZE, + "avg_job_seconds": avg_job_seconds, + }) + + @app.get("/v1/models") + async def list_models(_: None = Depends(verify_api_key)): + """List available DiT models.""" + models = [] + + # Primary model (always available if initialized) + if getattr(app.state, "_initialized", False): + primary_model = _get_model_name(app.state._config_path) + if primary_model: + models.append({ + "name": primary_model, + "is_default": True, + }) + + # Secondary model + if getattr(app.state, "_initialized2", False) and app.state._config_path2: + secondary_model = _get_model_name(app.state._config_path2) + if secondary_model: + models.append({ + "name": secondary_model, + "is_default": False, + }) + + # Third model + if getattr(app.state, "_initialized3", False) and app.state._config_path3: + third_model = _get_model_name(app.state._config_path3) + if third_model: + models.append({ + "name": third_model, + "is_default": False, + }) + + return _wrap_response({ + "models": models, + "default_model": models[0]["name"] if models else None, + }) + + @app.post("/create_random_sample") + async def create_random_sample_endpoint(request: Request, authorization: Optional[str] = Header(None)): + """ + Get random sample parameters from pre-loaded example data. + + Returns a random example from the examples directory for form filling. 
+ """ + content_type = (request.headers.get("content-type") or "").lower() + + if "json" in content_type: + body = await request.json() + else: + form = await request.form() + body = {k: v for k, v in form.items()} + + verify_token_from_request(body, authorization) + sample_type = body.get("sample_type", "simple_mode") or "simple_mode" + + if sample_type == "simple_mode": + example_data = SIMPLE_EXAMPLE_DATA + else: + example_data = CUSTOM_EXAMPLE_DATA + + if not example_data: + return _wrap_response(None, code=500, error="No example data available") + + random_example = random.choice(example_data) + return _wrap_response(random_example) + + @app.post("/format_input") + async def format_input_endpoint(request: Request, authorization: Optional[str] = Header(None)): + """ + Format and enhance lyrics/caption via LLM. + + Takes user-provided caption and lyrics, and uses the LLM to enhance them + with proper structure and metadata. + """ + content_type = (request.headers.get("content-type") or "").lower() + + if "json" in content_type: + body = await request.json() + else: + form = await request.form() + body = {k: v for k, v in form.items()} + + verify_token_from_request(body, authorization) + llm: LLMHandler = app.state.llm_handler + + # Initialize LLM if needed + with app.state._llm_init_lock: + if not getattr(app.state, "_llm_initialized", False): + if getattr(app.state, "_llm_init_error", None): + raise HTTPException(status_code=500, detail=f"LLM init failed: {app.state._llm_init_error}") + + # Check if lazy loading is disabled + if getattr(app.state, "_llm_lazy_load_disabled", False): + raise HTTPException( + status_code=503, + detail="LLM not initialized. Set ACESTEP_INIT_LLM=true in .env to enable." + ) + + project_root = _get_project_root() + checkpoint_dir = os.path.join(project_root, "checkpoints") + lm_model_path = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B").strip() + backend = os.getenv("ACESTEP_LM_BACKEND", "vllm").strip().lower() + if backend not in {"vllm", "pt", "mlx"}: + backend = "vllm" + + # Auto-download LM model if not present + lm_model_name = _get_model_name(lm_model_path) + if lm_model_name: + try: + _ensure_model_downloaded(lm_model_name, checkpoint_dir) + except Exception as e: + print(f"[API Server] Warning: Failed to download LM model {lm_model_name}: {e}") + + lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto")) + lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False) + + h: AceStepHandler = app.state.handler + status, ok = llm.initialize( + checkpoint_dir=checkpoint_dir, + lm_model_path=lm_model_path, + backend=backend, + device=lm_device, + offload_to_cpu=lm_offload, + dtype=None, + ) + if not ok: + app.state._llm_init_error = status + raise HTTPException(status_code=500, detail=f"LLM init failed: {status}") + app.state._llm_initialized = True + + # Parse parameters + prompt = body.get("prompt", "") or "" + lyrics = body.get("lyrics", "") or "" + temperature = _to_float(body.get("temperature"), 0.85) + + # Parse param_obj if provided + param_obj_str = body.get("param_obj", "{}") + if isinstance(param_obj_str, dict): + param_obj = param_obj_str + else: + try: + param_obj = json.loads(param_obj_str) if param_obj_str else {} + except json.JSONDecodeError: + param_obj = {} + + # Extract metadata from param_obj + duration = _to_float(param_obj.get("duration")) + bpm = _to_int(param_obj.get("bpm")) + key_scale = param_obj.get("key", "") or param_obj.get("key_scale", "") or "" + time_signature = 
param_obj.get("time_signature", "") or body.get("time_signature", "") or "" + language = param_obj.get("language", "") or "" + + # Build user_metadata for format_sample + user_metadata_for_format = {} + if bpm is not None: + user_metadata_for_format['bpm'] = bpm + if duration is not None and duration > 0: + user_metadata_for_format['duration'] = int(duration) + if key_scale: + user_metadata_for_format['keyscale'] = key_scale + if time_signature: + user_metadata_for_format['timesignature'] = time_signature + if language and language != "unknown": + user_metadata_for_format['language'] = language + + # Call format_sample + try: + format_result = format_sample( + llm_handler=llm, + caption=prompt, + lyrics=lyrics, + user_metadata=user_metadata_for_format if user_metadata_for_format else None, + temperature=temperature, + use_constrained_decoding=True, + ) + + if not format_result.success: + error_msg = format_result.error or format_result.status_message + return _wrap_response(None, code=500, error=f"format_sample failed: {error_msg}") + + # Use formatted results or fallback to original + result_caption = format_result.caption or prompt + result_lyrics = format_result.lyrics or lyrics + result_duration = format_result.duration or duration + result_bpm = format_result.bpm or bpm + result_key_scale = format_result.keyscale or key_scale + result_time_signature = format_result.timesignature or time_signature + + return _wrap_response({ + "caption": result_caption, + "lyrics": result_lyrics, + "bpm": result_bpm, + "key_scale": result_key_scale, + "time_signature": result_time_signature, + "duration": result_duration, + "vocal_language": format_result.language or language or "unknown", + }) + except Exception as e: + return _wrap_response(None, code=500, error=f"format_sample error: {str(e)}") + + @app.get("/v1/audio") + async def get_audio(path: str, request: Request, _: None = Depends(verify_api_key)): + """Serve audio file by path.""" + from fastapi.responses import FileResponse + + # Security: Validate path is within allowed directory to prevent path traversal + resolved_path = os.path.realpath(path) + allowed_dir = os.path.realpath(request.app.state.temp_audio_dir) + if not resolved_path.startswith(allowed_dir + os.sep) and resolved_path != allowed_dir: + raise HTTPException(status_code=403, detail="Access denied: path outside allowed directory") + if not os.path.exists(resolved_path): + raise HTTPException(status_code=404, detail="Audio file not found") + + ext = os.path.splitext(resolved_path)[1].lower() + media_types = { + ".mp3": "audio/mpeg", + ".wav": "audio/wav", + ".flac": "audio/flac", + ".ogg": "audio/ogg", + } + media_type = media_types.get(ext, "audio/mpeg") + + return FileResponse(resolved_path, media_type=media_type) + + return app + + +app = create_app() + + +def main() -> None: + import argparse + import uvicorn + + parser = argparse.ArgumentParser(description="ACE-Step API server") + parser.add_argument( + "--host", + default=os.getenv("ACESTEP_API_HOST", "127.0.0.1"), + help="Bind host (default from ACESTEP_API_HOST or 127.0.0.1)", + ) + parser.add_argument( + "--port", + type=int, + default=int(os.getenv("ACESTEP_API_PORT", "8001")), + help="Bind port (default from ACESTEP_API_PORT or 8001)", + ) + parser.add_argument( + "--api-key", + type=str, + default=os.getenv("ACESTEP_API_KEY", None), + help="API key for authentication (default from ACESTEP_API_KEY)", + ) + parser.add_argument( + "--download-source", + type=str, + choices=["huggingface", "modelscope", "auto"], + 
default=os.getenv("ACESTEP_DOWNLOAD_SOURCE", "auto"), + help="Preferred model download source: auto (default), huggingface, or modelscope", + ) + parser.add_argument( + "--init-llm", + action="store_true", + default=_env_bool("ACESTEP_INIT_LLM", False), + help="Initialize LLM even if GPU memory is insufficient (may cause OOM). " + "Can also be set via ACESTEP_INIT_LLM=true environment variable.", + ) + parser.add_argument( + "--lm-model-path", + type=str, + default=os.getenv("ACESTEP_LM_MODEL_PATH", ""), + help="LM model to load (e.g., 'acestep-5Hz-lm-0.6B'). Default from ACESTEP_LM_MODEL_PATH.", + ) + args = parser.parse_args() + + # Set API key from command line argument + if args.api_key: + os.environ["ACESTEP_API_KEY"] = args.api_key + + # Set download source preference + if args.download_source and args.download_source != "auto": + os.environ["ACESTEP_DOWNLOAD_SOURCE"] = args.download_source + print(f"Using preferred download source: {args.download_source}") + + # Set init LLM flag + if args.init_llm: + os.environ["ACESTEP_INIT_LLM"] = "true" + print("[API Server] LLM initialization enabled via --init-llm") + + # Set LM model path + if args.lm_model_path: + os.environ["ACESTEP_LM_MODEL_PATH"] = args.lm_model_path + print(f"[API Server] Using LM model: {args.lm_model_path}") + + # IMPORTANT: in-memory queue/store -> workers MUST be 1 + uvicorn.run( + "acestep.api_server:app", + host=str(args.host), + port=int(args.port), + reload=False, + workers=1, + ) + +if __name__ == "__main__": + main()