# ace-step15-endpoint / handler.py
# Andrew
# Patch Tensor.sort/argsort bool CUDA compatibility
# 56eb0fe
# handler.py
import base64
import io
import os
import traceback
from typing import Any, Dict, Optional, Tuple
import numpy as np
import soundfile as sf
# torch is optional at import time: if the import fails for any reason the name
# is bound to None, and the handler checks `torch is None` before using it.
try:
    import torch
except Exception:
    torch = None
class EndpointHandler:
    """
    Hugging Face Inference Endpoints custom handler for ACE-Step 1.5.

    Supported request shapes:

        {
          "inputs": {
            "prompt": "upbeat pop rap, emotional guitar",
            "lyrics": "[Verse] ...",
            "duration_sec": 12,
            "sample_rate": 44100,
            "seed": 42,
            "guidance_scale": 7.0,
            "steps": 8,
            "use_lm": true,
            "simple_prompt": false,
            "instrumental": false,
            "allow_fallback": false
          }
        }

    Or simple mode:

        {
          "inputs": "upbeat pop rap with emotional guitar"
        }

    Notes:
    - This handler uses ACE-Step's official Python API internally.
    - Fallback sine generation is disabled by default so model failures are explicit.
    - `model_error` on the instance only ever holds the initialization error;
      per-request inference failures are reported in the response and are not
      persisted between requests.
    """

    # Bounds applied to every duration handed to the model (seconds).
    _MIN_DURATION_SEC = 10
    _MAX_DURATION_SEC = 600

    # String spellings understood by `_to_bool` (compared lower-cased, stripped).
    _TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
    _FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off"})

    def __init__(self, path: str = ""):
        """Read configuration from the environment and (unless skipped) load the model.

        Args:
            path: Stored on the instance as `self.path`; not otherwise used here.
        """
        self.path = path
        self.project_root = os.path.dirname(os.path.abspath(__file__))

        self.model_repo = os.getenv("ACE_MODEL_REPO", "ACE-Step/Ace-Step1.5")
        self.config_path = os.getenv("ACE_CONFIG_PATH", "acestep-v15-turbo")
        self.lm_model_path = os.getenv("ACE_LM_MODEL_PATH", "acestep-5Hz-lm-1.7B")
        self.lm_backend = os.getenv("ACE_LM_BACKEND", "pt")
        self.download_source = os.getenv("ACE_DOWNLOAD_SOURCE", "huggingface")

        # A malformed DEFAULT_SAMPLE_RATE falls back to 44100 instead of
        # aborting construction with a ValueError.
        self.default_sr = self._to_int(os.getenv("DEFAULT_SAMPLE_RATE"), 44100)
        self.enable_fallback = self._to_bool(os.getenv("ACE_ENABLE_FALLBACK"), False)
        self.init_lm_on_start = self._to_bool(os.getenv("ACE_INIT_LLM"), False)
        self.skip_init = self._to_bool(os.getenv("ACE_SKIP_INIT"), False)

        self.device = "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu"
        # Informational only: echoed back in the response metadata.
        self.dtype = "float16" if self.device == "cuda" else "float32"

        self.model_loaded = False
        self.model_error: Optional[str] = None
        self.init_details: Dict[str, Any] = {}

        self.dit_handler = None
        self.llm_handler = None
        self.llm_initialized = False
        self.llm_error: Optional[str] = None

        # ACE-Step entry points, bound lazily by `_init_model`.
        self._GenerationParams = None
        self._GenerationConfig = None
        self._generate_music = None
        self._create_sample = None

        if self.skip_init:
            self.model_error = "Initialization skipped because ACE_SKIP_INIT=true"
        else:
            self._init_model()

    # --------------------------
    # Initialization
    # --------------------------
    def _init_model(self) -> None:
        """Import ACE-Step, initialize the DiT handler and create the LLM handler.

        Never raises: failures are collected into `self.model_error` and leave
        `self.model_loaded` False.
        """
        err_msgs = []

        # ACE-Step dynamic config imports layer_type_validation from transformers.
        # Some endpoint base images ship a transformers build without this helper.
        self._patch_transformers_layer_validation()
        # Some CUDA/torch combinations used by managed endpoint images don't support
        # sorting bool tensors on CUDA. ACE-Step/Transformers paths can hit this.
        self._patch_torch_sort_bool_cuda()

        try:
            from acestep.handler import AceStepHandler
            from acestep.inference import GenerationConfig, GenerationParams, create_sample, generate_music
            from acestep.llm_inference import LLMHandler
        except Exception as e:
            self.model_error = f"ACE-Step import failed: {type(e).__name__}: {e}"
            return

        self._GenerationParams = GenerationParams
        self._GenerationConfig = GenerationConfig
        self._generate_music = generate_music
        self._create_sample = create_sample

        try:
            self.dit_handler = AceStepHandler()
            prefer_source = self.download_source if self.download_source in {"huggingface", "modelscope"} else None
            init_status, ok = self.dit_handler.initialize_service(
                project_root=self.project_root,
                config_path=self.config_path,
                device=self.device,
                use_flash_attention=False,
                compile_model=False,
                offload_to_cpu=False,
                offload_dit_to_cpu=False,
                prefer_source=prefer_source,
            )
            self.init_details["dit_status"] = init_status
            if not ok:
                raise RuntimeError(init_status)
        except Exception as e:
            err_msgs.append(f"DiT init failed: {type(e).__name__}: {e}")

        # The LLM handler is created even when DiT init failed so that both
        # problems are reported together.
        try:
            self.llm_handler = LLMHandler()
            if self.init_lm_on_start:
                self._ensure_llm_initialized()
        except Exception as e:
            err_msgs.append(f"LLM bootstrap failed: {type(e).__name__}: {e}")

        if err_msgs:
            self.model_loaded = False
            self.model_error = " | ".join(err_msgs)
            return

        self.model_loaded = True
        self.model_error = None

    @staticmethod
    def _patch_transformers_layer_validation() -> None:
        """Install a minimal `layer_type_validation` into transformers if it is missing.

        Does nothing when transformers cannot be imported or already provides
        the helper.
        """
        try:
            from transformers import configuration_utils as cu
        except Exception:
            return
        if hasattr(cu, "layer_type_validation"):
            return

        def _fallback_layer_type_validation(layer_types, num_hidden_layers=None):
            if layer_types is None:
                return
            if not isinstance(layer_types, (list, tuple)):
                raise TypeError("`layer_types` must be a list/tuple")
            if num_hidden_layers is not None and len(layer_types) != int(num_hidden_layers):
                raise ValueError("`layer_types` length must match `num_hidden_layers`")

        cu.layer_type_validation = _fallback_layer_type_validation

    @staticmethod
    def _patch_torch_sort_bool_cuda() -> None:
        """Wrap torch sort/argsort so CUDA bool tensors are sorted as uint8.

        Patches `torch.sort`, `torch.Tensor.sort`, `torch.argsort` and
        `torch.Tensor.argsort` (those that exist). The patch is applied once:
        a second call detects the wrapper by name and returns. Sorted values
        are cast back to bool and returned in the same result type the
        original function produced, so `.values` / `.indices` access keeps
        working. All other inputs are passed through untouched.
        """
        if torch is None or not hasattr(torch, "sort"):
            return
        if getattr(torch.sort, "__name__", "") == "_sort_bool_cuda_compat":
            return

        _orig_sort = torch.sort
        _orig_tensor_sort = getattr(torch.Tensor, "sort", None)
        _orig_argsort = getattr(torch, "argsort", None)
        _orig_tensor_argsort = getattr(torch.Tensor, "argsort", None)

        def _is_cuda_bool(obj) -> bool:
            return isinstance(obj, torch.Tensor) and obj.is_cuda and obj.dtype == torch.bool

        def _cast_input(args, kwargs):
            # Return (args, kwargs, was_cast). The tensor may arrive either
            # positionally or through the documented `input=` keyword.
            if args:
                if _is_cuda_bool(args[0]):
                    return (args[0].to(torch.uint8),) + tuple(args[1:]), kwargs, True
                return args, kwargs, False
            if _is_cuda_bool(kwargs.get("input")):
                patched = dict(kwargs)
                patched["input"] = patched["input"].to(torch.uint8)
                return args, patched, True
            return args, kwargs, False

        def _restore_bool(out):
            # Rebuild the result in its original container type so attribute
            # access (`.values`, `.indices`) behaves like the unpatched call.
            values = out.values if hasattr(out, "values") else out[0]
            indices = out.indices if hasattr(out, "indices") else out[1]
            pair = (values.to(torch.bool), indices)
            try:
                return type(out)(pair)
            except TypeError:
                return pair

        def _sort_bool_cuda_compat(*args, **kwargs):
            args, kwargs, was_cast = _cast_input(args, kwargs)
            out = _orig_sort(*args, **kwargs)
            return _restore_bool(out) if was_cast else out

        _sort_bool_cuda_compat.__name__ = "_sort_bool_cuda_compat"
        torch.sort = _sort_bool_cuda_compat

        if callable(_orig_tensor_sort):

            def _tensor_sort_bool_cuda_compat(self, *args, **kwargs):
                if _is_cuda_bool(self):
                    return _restore_bool(_orig_tensor_sort(self.to(torch.uint8), *args, **kwargs))
                return _orig_tensor_sort(self, *args, **kwargs)

            _tensor_sort_bool_cuda_compat.__name__ = "_tensor_sort_bool_cuda_compat"
            torch.Tensor.sort = _tensor_sort_bool_cuda_compat

        if callable(_orig_argsort):

            def _argsort_bool_cuda_compat(*args, **kwargs):
                args, kwargs, _ = _cast_input(args, kwargs)
                return _orig_argsort(*args, **kwargs)

            _argsort_bool_cuda_compat.__name__ = "_argsort_bool_cuda_compat"
            torch.argsort = _argsort_bool_cuda_compat

        if callable(_orig_tensor_argsort):

            def _tensor_argsort_bool_cuda_compat(self, *args, **kwargs):
                if _is_cuda_bool(self):
                    return _orig_tensor_argsort(self.to(torch.uint8), *args, **kwargs)
                return _orig_tensor_argsort(self, *args, **kwargs)

            _tensor_argsort_bool_cuda_compat.__name__ = "_tensor_argsort_bool_cuda_compat"
            torch.Tensor.argsort = _tensor_argsort_bool_cuda_compat

    def _ensure_llm_initialized(self) -> bool:
        """Initialize the LLM handler on first use.

        Returns:
            True when the LLM is ready. On failure returns False and records
            the reason in `self.llm_error`; never raises.
        """
        if self.llm_handler is None:
            self.llm_error = "LLM handler is not available"
            return False
        if self.llm_initialized:
            return True

        try:
            checkpoint_dir = os.path.join(self.project_root, "checkpoints")
            status, ok = self.llm_handler.initialize(
                checkpoint_dir=checkpoint_dir,
                lm_model_path=self.lm_model_path,
                backend=self.lm_backend,
                device=self.device,
                offload_to_cpu=False,
            )
            self.init_details["llm_status"] = status
            if not ok:
                self.llm_error = status
                self.llm_initialized = False
                return False
            self.llm_error = None
            self.llm_initialized = True
            return True
        except Exception as e:
            self.llm_error = f"LLM init exception: {type(e).__name__}: {e}"
            self.llm_initialized = False
            return False

    # --------------------------
    # Audio helpers
    # --------------------------
    @staticmethod
    def _as_float32(audio: Any) -> np.ndarray:
        """Convert audio to a float32 NumPy array clipped to [-1, 1].

        A 2-D input whose first axis has length 1 or 2 and is shorter than the
        second axis is transposed. NaN samples become 0.

        Args:
            audio: NumPy array, torch tensor, or anything `np.asarray` accepts.
        """
        if isinstance(audio, np.ndarray):
            arr = audio
        elif torch is not None and isinstance(audio, torch.Tensor):
            # Cast on the torch side: NumPy has no bfloat16, so calling
            # `.numpy()` directly on such a tensor raises.
            arr = audio.detach().cpu().to(torch.float32).numpy()
        else:
            arr = np.asarray(audio)

        if arr.ndim == 2 and arr.shape[0] in (1, 2) and arr.shape[1] > arr.shape[0]:
            arr = arr.T
        if arr.dtype != np.float32:
            arr = arr.astype(np.float32)
        # np.clip propagates NaN, so replace it before clipping.
        arr = np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=-1.0)
        return np.clip(arr, -1.0, 1.0)

    @staticmethod
    def _wav_b64(audio: np.ndarray, sr: int) -> str:
        """Encode `audio` as an in-memory WAV file and return it base64-encoded."""
        bio = io.BytesIO()
        sf.write(bio, audio, sr, format="WAV")
        return base64.b64encode(bio.getvalue()).decode("utf-8")

    @staticmethod
    def _fallback_sine(duration_sec: int, sr: int, seed: int) -> np.ndarray:
        """Return a quiet 440 Hz tone with a little seeded noise (float32, mono).

        Args:
            duration_sec: Length of the tone in seconds.
            sr: Sample rate in Hz.
            seed: Noise seed. Negative seeds are mapped to their absolute
                value because `np.random.default_rng` rejects negative ints.
        """
        rng = np.random.default_rng(abs(int(seed)))
        t = np.linspace(0, duration_sec, int(sr * duration_sec), endpoint=False)
        y = (0.07 * np.sin(2 * np.pi * 440 * t) + 0.01 * rng.standard_normal(len(t))).astype(np.float32)
        return np.clip(y, -1.0, 1.0)

    # --------------------------
    # Request normalization
    # --------------------------
    @staticmethod
    def _to_bool(value: Any, default: bool = False) -> bool:
        """Coerce `value` to bool.

        None and unrecognised values (including unrecognised strings) yield
        `default`; bools pass through; numbers use truthiness; strings are
        matched case-insensitively against the true/false spellings.
        """
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return bool(value)
        if isinstance(value, str):
            token = value.strip().lower()
            if token in EndpointHandler._TRUE_STRINGS:
                return True
            if token in EndpointHandler._FALSE_STRINGS:
                return False
        return default

    @staticmethod
    def _to_int(value: Any, default: int) -> int:
        """Coerce `value` to int, returning `default` when that is impossible.

        Numeric strings with a fractional part (e.g. "30.0") are truncated
        rather than rejected.
        """
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            pass
        try:
            # int() refuses "30.0"; going through float() accepts it. NaN and
            # infinity raise here and fall through to the default.
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _to_float(value: Any, default: float) -> float:
        """Coerce `value` to a finite float, returning `default` otherwise."""
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return default
        # NaN / infinity parse successfully but are not usable parameters.
        return number if np.isfinite(number) else default

    @staticmethod
    def _pick_text(inputs: Dict[str, Any], *keys: str) -> str:
        """Return the first non-empty, stripped string found under `keys`, or ""."""
        for key in keys:
            v = inputs.get(key)
            if v is None:
                continue
            s = str(v).strip()
            if s:
                return s
        return ""

    def _normalize_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate and clamp a raw request into the internal request dict.

        Args:
            data: Either a mapping (optionally wrapping the payload under
                "inputs") or a bare prompt string.

        Returns:
            Dict with keys prompt, lyrics, duration_sec, sample_rate, seed,
            guidance_scale, steps, use_lm, instrumental, simple_prompt and
            allow_fallback.

        Raises:
            ValueError: If the payload is neither a mapping nor a string.
        """
        # A bare (non-dict) payload is treated as the inputs themselves.
        raw_inputs = data.get("inputs", data) if isinstance(data, dict) else data

        if isinstance(raw_inputs, str):
            raw_inputs = {"prompt": raw_inputs, "simple_prompt": True}
        if not isinstance(raw_inputs, dict):
            raise ValueError("`inputs` must be an object or string")

        prompt = self._pick_text(raw_inputs, "prompt", "query", "caption", "text", "description")
        lyrics = self._pick_text(raw_inputs, "lyrics")

        simple_prompt = self._to_bool(raw_inputs.get("simple_prompt"), False) or self._to_bool(
            raw_inputs.get("simple"), False
        )
        instrumental = self._to_bool(raw_inputs.get("instrumental"), False)
        if not lyrics and (instrumental or simple_prompt):
            lyrics = "[Instrumental]"

        duration_sec = self._to_int(raw_inputs.get("duration_sec", raw_inputs.get("duration", 12)), 12)
        duration_sec = max(self._MIN_DURATION_SEC, min(duration_sec, self._MAX_DURATION_SEC))

        sample_rate = self._to_int(raw_inputs.get("sample_rate", self.default_sr), self.default_sr)
        sample_rate = max(8000, min(sample_rate, 48000))

        seed = self._to_int(raw_inputs.get("seed", 42), 42)
        guidance_scale = self._to_float(raw_inputs.get("guidance_scale", 7.0), 7.0)

        steps = self._to_int(raw_inputs.get("steps", raw_inputs.get("inference_steps", 8)), 8)
        steps = max(1, min(steps, 200))

        use_lm = self._to_bool(raw_inputs.get("use_lm", raw_inputs.get("thinking", True)), True)
        allow_fallback = self._to_bool(raw_inputs.get("allow_fallback"), self.enable_fallback)

        return {
            "prompt": prompt,
            "lyrics": lyrics,
            "duration_sec": duration_sec,
            "sample_rate": sample_rate,
            "seed": seed,
            "guidance_scale": guidance_scale,
            "steps": steps,
            "use_lm": use_lm,
            "instrumental": instrumental,
            "simple_prompt": simple_prompt,
            "allow_fallback": allow_fallback,
        }

    # --------------------------
    # ACE-Step invocation
    # --------------------------
    def _build_generation_inputs(self, req: Dict[str, Any], llm_ready: bool) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build the ACE-Step params/config objects for a normalized request.

        In simple-prompt mode with a ready LLM, the prompt is first expanded
        through `create_sample`; a failed expansion is recorded in the extras
        and generation proceeds with the original prompt.

        Args:
            req: Output of `_normalize_request`.
            llm_ready: Whether the LLM handler is initialized.

        Returns:
            `({"params": ..., "config": ...}, extras)` where extras carries the
            resolved prompt/lyrics/duration and simple-expansion status.
        """
        caption = req["prompt"]
        lyrics = req["lyrics"]
        extras: Dict[str, Any] = {
            "simple_expansion_used": False,
            "simple_expansion_error": None,
        }

        bpm = None
        keyscale = ""
        timesignature = ""
        vocal_language = "unknown"
        duration = float(req["duration_sec"])
        lm_active = bool(req["use_lm"] and llm_ready)

        if req["simple_prompt"] and lm_active and caption:
            try:
                sample = self._create_sample(
                    llm_handler=self.llm_handler,
                    query=caption,
                    instrumental=req["instrumental"],
                )
                if getattr(sample, "success", False):
                    caption = getattr(sample, "caption", "") or caption
                    lyrics = getattr(sample, "lyrics", "") or lyrics
                    bpm = getattr(sample, "bpm", None)
                    keyscale = getattr(sample, "keyscale", "") or ""
                    timesignature = getattr(sample, "timesignature", "") or ""
                    vocal_language = getattr(sample, "language", "") or "unknown"
                    sample_duration = getattr(sample, "duration", None)
                    if sample_duration:
                        # Keep the LM-suggested duration inside the same bounds
                        # that are enforced on the request itself.
                        duration = float(
                            max(self._MIN_DURATION_SEC, min(float(sample_duration), self._MAX_DURATION_SEC))
                        )
                    extras["simple_expansion_used"] = True
                else:
                    # `or` so that an explicit error=None still reports a failure.
                    extras["simple_expansion_error"] = getattr(sample, "error", None) or "create_sample failed"
            except Exception as e:
                extras["simple_expansion_error"] = f"{type(e).__name__}: {e}"

        params = self._GenerationParams(
            task_type="text2music",
            caption=caption,
            lyrics=lyrics,
            instrumental=req["instrumental"],
            duration=duration,
            inference_steps=req["steps"],
            guidance_scale=req["guidance_scale"],
            seed=req["seed"],
            bpm=bpm,
            keyscale=keyscale,
            timesignature=timesignature,
            vocal_language=vocal_language,
            thinking=lm_active,
            use_cot_metas=lm_active,
            use_cot_caption=bool(lm_active and not req["simple_prompt"]),
            use_cot_language=lm_active,
        )
        config = self._GenerationConfig(
            batch_size=1,
            allow_lm_batch=False,
            use_random_seed=False,
            seeds=[req["seed"]],
            audio_format="wav",
        )

        extras["resolved_prompt"] = caption
        extras["resolved_lyrics"] = lyrics
        extras["resolved_duration"] = duration
        return {"params": params, "config": config}, extras

    def _call_model(self, req: Dict[str, Any]) -> Tuple[np.ndarray, int, Dict[str, Any]]:
        """Run ACE-Step for a normalized request.

        Returns:
            `(audio, sample_rate, meta)` with audio as float32 in [-1, 1].

        Raises:
            RuntimeError: If the model is not loaded, generation reports
                failure, or no audio tensor comes back.
        """
        if not self.model_loaded or self.dit_handler is None:
            raise RuntimeError(self.model_error or "Model is not loaded")

        llm_ready = False
        if req["use_lm"]:
            llm_ready = self._ensure_llm_initialized()

        generation_inputs, extras = self._build_generation_inputs(req, llm_ready)
        result = self._generate_music(
            self.dit_handler,
            self.llm_handler if llm_ready else None,
            generation_inputs["params"],
            generation_inputs["config"],
            save_dir=None,
            progress=None,
        )

        if not getattr(result, "success", False):
            raise RuntimeError(getattr(result, "error", "generation failed"))

        audios = getattr(result, "audios", None) or []
        if not audios:
            raise RuntimeError("generation succeeded but no audio was returned")

        first = audios[0]
        audio_tensor = first.get("tensor") if isinstance(first, dict) else None
        if audio_tensor is None:
            raise RuntimeError("generated audio tensor is missing")

        # `or` covers both a missing key and an explicit None value.
        sample_rate = int(first.get("sample_rate") or req["sample_rate"])
        status_message = getattr(result, "status_message", "")

        meta = {
            "llm_requested": req["use_lm"],
            "llm_initialized": llm_ready,
            "llm_error": self.llm_error,
            "status_message": status_message,
        }
        meta.update(extras)

        return self._as_float32(audio_tensor), sample_rate, meta

    # --------------------------
    # Endpoint entry
    # --------------------------
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one endpoint request.

        Never raises: any failure is returned as a response dict containing
        "error" and "traceback" with `audio_base64_wav` set to None. When the
        model fails and fallback is allowed, a sine tone is returned and
        `used_fallback` is True.

        Args:
            data: Raw request payload (see the class docstring).

        Returns:
            Response dict with the base64 WAV, sample rate, status flags and a
            "meta" section describing how the request was resolved.
        """
        # Error reported for THIS request. Kept local so a failed inference
        # neither overwrites the initialization error on the instance nor
        # leaks into later responses.
        reported_error = self.model_error
        try:
            req = self._normalize_request(data)

            used_fallback = False
            runtime_meta: Dict[str, Any] = {}

            try:
                audio, out_sr, runtime_meta = self._call_model(req)
            except Exception as model_exc:
                reported_error = f"Inference failed: {type(model_exc).__name__}: {model_exc}"
                if not req["allow_fallback"]:
                    raise RuntimeError(reported_error) from model_exc
                used_fallback = True
                audio = self._fallback_sine(req["duration_sec"], req["sample_rate"], req["seed"])
                out_sr = req["sample_rate"]

            return {
                "audio_base64_wav": self._wav_b64(audio, out_sr),
                "sample_rate": int(out_sr),
                "duration_sec": int(req["duration_sec"]),
                "used_fallback": used_fallback,
                "model_loaded": self.model_loaded,
                "model_repo": self.model_repo,
                "model_error": reported_error,
                "meta": {
                    "device": self.device,
                    "dtype": self.dtype,
                    "prompt_len": len(req["prompt"]),
                    "lyrics_len": len(req["lyrics"]),
                    "seed": req["seed"],
                    "guidance_scale": req["guidance_scale"],
                    "steps": req["steps"],
                    "use_lm": req["use_lm"],
                    "simple_prompt": req["simple_prompt"],
                    "instrumental": req["instrumental"],
                    "allow_fallback": req["allow_fallback"],
                    "resolved_prompt": runtime_meta.get("resolved_prompt", req["prompt"]),
                    "resolved_lyrics": runtime_meta.get("resolved_lyrics", req["lyrics"]),
                    "resolved_duration": runtime_meta.get("resolved_duration", float(req["duration_sec"])),
                    "simple_expansion_used": runtime_meta.get("simple_expansion_used", False),
                    "simple_expansion_error": runtime_meta.get("simple_expansion_error"),
                    "llm_requested": runtime_meta.get("llm_requested", False),
                    "llm_initialized": runtime_meta.get("llm_initialized", False),
                    "llm_error": runtime_meta.get("llm_error", self.llm_error),
                    "status_message": runtime_meta.get("status_message", ""),
                    "init_details": self.init_details,
                },
            }

        except Exception as e:
            return {
                "error": f"{type(e).__name__}: {e}",
                "traceback": traceback.format_exc(limit=4),
                "audio_base64_wav": None,
                "sample_rate": None,
                "duration_sec": None,
                "used_fallback": False,
                "model_loaded": self.model_loaded,
                "model_repo": self.model_repo,
                "model_error": reported_error,
                "meta": {
                    "device": self.device,
                    "dtype": self.dtype,
                    "init_details": self.init_details,
                    "llm_error": self.llm_error,
                },
            }