Spaces:
Running on Zero
Running on Zero
feat: update api-server cot-caption
Browse files- .env.example +4 -0
- acestep/api_server.py +82 -23
.env.example
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ACESTEP_CONFIG_PATH=acestep-v15-turbo-rl
|
| 2 |
+
ACESTEP_LM_MODEL_PATH=acestep-5Hz-lm-0.6B-v3
|
| 3 |
+
ACESTEP_DEVICE=auto
|
| 4 |
+
ACESTEP_LM_BACKEND=vllm
|
acestep/api_server.py
CHANGED
|
@@ -29,6 +29,11 @@ from threading import Lock
|
|
| 29 |
from typing import Any, Dict, Literal, Optional
|
| 30 |
from uuid import uuid4
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
from fastapi import FastAPI, HTTPException, Request
|
| 33 |
from pydantic import BaseModel, Field
|
| 34 |
from starlette.datastructures import UploadFile as StarletteUploadFile
|
|
@@ -89,8 +94,12 @@ class GenerateMusicRequest(BaseModel):
|
|
| 89 |
lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
|
| 90 |
lm_backend: Literal["vllm", "pt"] = "vllm"
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
lm_temperature: float = 0.85
|
| 95 |
lm_cfg_scale: float = 2.0
|
| 96 |
lm_top_k: Optional[int] = None
|
|
@@ -125,8 +134,6 @@ class JobResult(BaseModel):
|
|
| 125 |
status_message: str = ""
|
| 126 |
seed_value: str = ""
|
| 127 |
|
| 128 |
-
# 5Hz LM metadata (present when server invoked LM)
|
| 129 |
-
# Keep a raw-ish dict for clients that expect a `metas` object.
|
| 130 |
metas: Dict[str, Any] = Field(default_factory=dict)
|
| 131 |
bpm: Optional[int] = None
|
| 132 |
duration: Optional[float] = None
|
|
@@ -213,6 +220,22 @@ def _get_project_root() -> str:
|
|
| 213 |
return os.path.dirname(os.path.dirname(current_file))
|
| 214 |
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
|
| 217 |
if v is None:
|
| 218 |
return default
|
|
@@ -372,7 +395,7 @@ def create_app() -> FastAPI:
|
|
| 372 |
raise RuntimeError(app.state._init_error)
|
| 373 |
|
| 374 |
project_root = _get_project_root()
|
| 375 |
-
config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
|
| 376 |
device = os.getenv("ACESTEP_DEVICE", "auto")
|
| 377 |
|
| 378 |
use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
|
|
@@ -568,25 +591,40 @@ def create_app() -> FastAPI:
|
|
| 568 |
|
| 569 |
has_codes = bool(audio_code_string and str(audio_code_string).strip())
|
| 570 |
need_lm_codes = bool(thinking) and (not has_codes)
|
| 571 |
-
need_lm_metas = (
|
| 572 |
-
(bpm_val is None)
|
| 573 |
-
or (not (key_scale_val or "").strip())
|
| 574 |
-
or (not (time_sig_val or "").strip())
|
| 575 |
-
or (audio_duration_val is None)
|
| 576 |
-
)
|
| 577 |
|
| 578 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
# Pass user-supplied metadata (bpm/duration/keyscale/timesignature) into constrained decoding so the LM injects it directly
|
| 580 |
# (i.e. does not re-infer / override those fields).
|
| 581 |
user_metadata: Dict[str, Optional[str]] = {}
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
user_metadata[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
|
| 591 |
lm_target_duration: Optional[float] = None
|
| 592 |
if need_lm_codes:
|
|
@@ -594,7 +632,13 @@ def create_app() -> FastAPI:
|
|
| 594 |
if audio_duration_val is not None and float(audio_duration_val) > 0:
|
| 595 |
lm_target_duration = float(audio_duration_val)
|
| 596 |
|
| 597 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 598 |
|
| 599 |
if need_lm_metas or need_lm_codes:
|
| 600 |
# Lazy init 5Hz LM once
|
|
@@ -602,7 +646,7 @@ def create_app() -> FastAPI:
|
|
| 602 |
if getattr(app.state, "_llm_initialized", False) is False and getattr(app.state, "_llm_init_error", None) is None:
|
| 603 |
project_root = _get_project_root()
|
| 604 |
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 605 |
-
lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip()
|
| 606 |
backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
|
| 607 |
if backend not in {"vllm", "pt"}:
|
| 608 |
backend = "vllm"
|
|
@@ -644,6 +688,11 @@ def create_app() -> FastAPI:
|
|
| 644 |
repetition_penalty=float(req.lm_repetition_penalty),
|
| 645 |
target_duration=lm_target_duration,
|
| 646 |
user_metadata=(user_metadata or None),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 647 |
)
|
| 648 |
|
| 649 |
meta, codes, status = _lm_call()
|
|
@@ -715,7 +764,6 @@ def create_app() -> FastAPI:
|
|
| 715 |
if s in {"", "N/A"}:
|
| 716 |
return None
|
| 717 |
return s
|
| 718 |
-
|
| 719 |
first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
|
| 720 |
captions=req.caption,
|
| 721 |
lyrics=req.lyrics,
|
|
@@ -909,6 +957,11 @@ def create_app() -> FastAPI:
|
|
| 909 |
normalized_bpm = _to_int(_get_any("bpm"), None)
|
| 910 |
normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
|
| 911 |
normalized_timesig = str(_get_any("time_signature", "timesignature", "timeSignature", default="") or "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
print(
|
| 913 |
"[api_server] normalized: "
|
| 914 |
f"thinking={_to_bool(get('thinking'), False)}, bpm={normalized_bpm}, "
|
|
@@ -950,6 +1003,12 @@ def create_app() -> FastAPI:
|
|
| 950 |
lm_top_p=_to_float(get("lm_top_p"), _LM_DEFAULT_TOP_P),
|
| 951 |
lm_repetition_penalty=_to_float(get("lm_repetition_penalty"), 1.0) or 1.0,
|
| 952 |
lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 953 |
)
|
| 954 |
|
| 955 |
def _first_value(v: Any) -> Any:
|
|
|
|
| 29 |
from typing import Any, Dict, Literal, Optional
|
| 30 |
from uuid import uuid4
|
| 31 |
|
| 32 |
+
try:
|
| 33 |
+
from dotenv import load_dotenv
|
| 34 |
+
except ImportError: # Optional dependency
|
| 35 |
+
load_dotenv = None # type: ignore
|
| 36 |
+
|
| 37 |
from fastapi import FastAPI, HTTPException, Request
|
| 38 |
from pydantic import BaseModel, Field
|
| 39 |
from starlette.datastructures import UploadFile as StarletteUploadFile
|
|
|
|
| 94 |
lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
|
| 95 |
lm_backend: Literal["vllm", "pt"] = "vllm"
|
| 96 |
|
| 97 |
+
constrained_decoding: bool = True
|
| 98 |
+
constrained_decoding_debug: bool = False
|
| 99 |
+
use_cot_caption: bool = True
|
| 100 |
+
use_cot_language: bool = True
|
| 101 |
+
is_format_caption: bool = False
|
| 102 |
+
|
| 103 |
lm_temperature: float = 0.85
|
| 104 |
lm_cfg_scale: float = 2.0
|
| 105 |
lm_top_k: Optional[int] = None
|
|
|
|
| 134 |
status_message: str = ""
|
| 135 |
seed_value: str = ""
|
| 136 |
|
|
|
|
|
|
|
| 137 |
metas: Dict[str, Any] = Field(default_factory=dict)
|
| 138 |
bpm: Optional[int] = None
|
| 139 |
duration: Optional[float] = None
|
|
|
|
| 220 |
return os.path.dirname(os.path.dirname(current_file))
|
| 221 |
|
| 222 |
|
| 223 |
+
def _load_project_env() -> None:
    """Load the project-root ``.env`` file into the process environment.

    Best-effort helper:

    * does nothing when the optional ``dotenv`` import failed
      (``load_dotenv`` is ``None``);
    * does nothing when ``<project_root>/.env`` does not exist;
    * calls ``load_dotenv`` with ``override=False``.

    Any exception raised while locating or loading the file is reported on
    stdout and then ignored, so this function never raises to its caller.
    """
    if load_dotenv is None:
        return
    try:
        env_path = os.path.join(_get_project_root(), ".env")
        if os.path.exists(env_path):
            load_dotenv(env_path, override=False)
    except Exception as exc:
        # Stay best-effort, but surface the failure instead of hiding it so a
        # malformed or unreadable .env can be diagnosed from the server log.
        print(f"[api_server] warning: failed to load .env: {exc!r}")
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
_load_project_env()
|
| 237 |
+
|
| 238 |
+
|
| 239 |
def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
|
| 240 |
if v is None:
|
| 241 |
return default
|
|
|
|
| 395 |
raise RuntimeError(app.state._init_error)
|
| 396 |
|
| 397 |
project_root = _get_project_root()
|
| 398 |
+
config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo-rl")
|
| 399 |
device = os.getenv("ACESTEP_DEVICE", "auto")
|
| 400 |
|
| 401 |
use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
|
|
|
|
| 591 |
|
| 592 |
has_codes = bool(audio_code_string and str(audio_code_string).strip())
|
| 593 |
need_lm_codes = bool(thinking) and (not has_codes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
|
| 595 |
+
use_constrained_decoding = bool(getattr(req, "constrained_decoding", True))
|
| 596 |
+
constrained_decoding_debug = bool(getattr(req, "constrained_decoding_debug", False))
|
| 597 |
+
use_cot_caption = bool(getattr(req, "use_cot_caption", True))
|
| 598 |
+
use_cot_language = bool(getattr(req, "use_cot_language", True))
|
| 599 |
+
is_format_caption = bool(getattr(req, "is_format_caption", False))
|
| 600 |
+
|
| 601 |
# Pass user-supplied metadata (bpm/duration/keyscale/timesignature) into constrained decoding so the LM injects it directly
|
| 602 |
# (i.e. does not re-infer / override those fields).
|
| 603 |
user_metadata: Dict[str, Optional[str]] = {}
|
| 604 |
+
|
| 605 |
+
def _set_user_meta(field: str, value: Optional[Any]) -> None:
    """Record ``value`` under ``field`` in ``user_metadata`` as a stripped string.

    ``None``, blank strings and "N/A" (compared case-insensitively) are
    treated as "not provided" and leave ``user_metadata`` untouched.
    """
    text = "" if value is None else str(value).strip()
    if text and text.upper() != "N/A":
        user_metadata[field] = text
|
| 612 |
+
|
| 613 |
+
_set_user_meta("bpm", int(bpm_val) if bpm_val is not None else None)
|
| 614 |
+
_set_user_meta("duration", float(audio_duration_val) if audio_duration_val is not None else None)
|
| 615 |
+
_set_user_meta("keyscale", key_scale_val if (key_scale_val or "").strip() else None)
|
| 616 |
+
_set_user_meta("timesignature", time_sig_val if (time_sig_val or "").strip() else None)
|
| 617 |
+
|
| 618 |
+
def _has_meta(field: str) -> bool:
    """Return True when ``user_metadata`` holds a non-blank value for ``field``."""
    stored = user_metadata.get(field)
    if not stored:
        # Missing key, None, or empty string all count as "not provided".
        return False
    return bool(stored.strip())
|
| 621 |
+
|
| 622 |
+
need_lm_metas = not (
|
| 623 |
+
_has_meta("bpm")
|
| 624 |
+
and _has_meta("duration")
|
| 625 |
+
and _has_meta("keyscale")
|
| 626 |
+
and _has_meta("timesignature")
|
| 627 |
+
)
|
| 628 |
|
| 629 |
lm_target_duration: Optional[float] = None
|
| 630 |
if need_lm_codes:
|
|
|
|
| 632 |
if audio_duration_val is not None and float(audio_duration_val) > 0:
|
| 633 |
lm_target_duration = float(audio_duration_val)
|
| 634 |
|
| 635 |
+
print(
|
| 636 |
+
"[api_server] LM调用参数: "
|
| 637 |
+
f"user_metadata_keys={sorted(user_metadata.keys())}, target_duration={lm_target_duration}, "
|
| 638 |
+
f"need_lm_codes={need_lm_codes}, need_lm_metas={need_lm_metas}, "
|
| 639 |
+
f"use_constrained_decoding={use_constrained_decoding}, use_cot_caption={use_cot_caption}, "
|
| 640 |
+
f"use_cot_language={use_cot_language}, is_format_caption={is_format_caption}"
|
| 641 |
+
)
|
| 642 |
|
| 643 |
if need_lm_metas or need_lm_codes:
|
| 644 |
# Lazy init 5Hz LM once
|
|
|
|
| 646 |
if getattr(app.state, "_llm_initialized", False) is False and getattr(app.state, "_llm_init_error", None) is None:
|
| 647 |
project_root = _get_project_root()
|
| 648 |
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 649 |
+
lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B-v3").strip()
|
| 650 |
backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
|
| 651 |
if backend not in {"vllm", "pt"}:
|
| 652 |
backend = "vllm"
|
|
|
|
| 688 |
repetition_penalty=float(req.lm_repetition_penalty),
|
| 689 |
target_duration=lm_target_duration,
|
| 690 |
user_metadata=(user_metadata or None),
|
| 691 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 692 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 693 |
+
use_cot_caption=use_cot_caption,
|
| 694 |
+
use_cot_language=use_cot_language,
|
| 695 |
+
is_format_caption=is_format_caption,
|
| 696 |
)
|
| 697 |
|
| 698 |
meta, codes, status = _lm_call()
|
|
|
|
| 764 |
if s in {"", "N/A"}:
|
| 765 |
return None
|
| 766 |
return s
|
|
|
|
| 767 |
first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
|
| 768 |
captions=req.caption,
|
| 769 |
lyrics=req.lyrics,
|
|
|
|
| 957 |
normalized_bpm = _to_int(_get_any("bpm"), None)
|
| 958 |
normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
|
| 959 |
normalized_timesig = str(_get_any("time_signature", "timesignature", "timeSignature", default="") or "")
|
| 960 |
+
|
| 961 |
+
# Accept `target_duration` / `targetDuration` as an alias for the audio duration so clients need not special-case this server.
|
| 962 |
+
if normalized_audio_duration is None:
|
| 963 |
+
normalized_audio_duration = _to_float(_get_any("target_duration", "targetDuration"), None)
|
| 964 |
+
|
| 965 |
print(
|
| 966 |
"[api_server] normalized: "
|
| 967 |
f"thinking={_to_bool(get('thinking'), False)}, bpm={normalized_bpm}, "
|
|
|
|
| 1003 |
lm_top_p=_to_float(get("lm_top_p"), _LM_DEFAULT_TOP_P),
|
| 1004 |
lm_repetition_penalty=_to_float(get("lm_repetition_penalty"), 1.0) or 1.0,
|
| 1005 |
lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
|
| 1006 |
+
constrained_decoding=_to_bool(_get_any("constrained_decoding", "constrainedDecoding", "constrained"), True),
|
| 1007 |
+
constrained_decoding_debug=_to_bool(_get_any("constrained_decoding_debug", "constrainedDecodingDebug"), False),
|
| 1008 |
+
# Accept common aliases, including hyphenated keys from some clients.
|
| 1009 |
+
use_cot_caption=_to_bool(_get_any("use_cot_caption", "cot_caption", "cot-caption"), True),
|
| 1010 |
+
use_cot_language=_to_bool(_get_any("use_cot_language", "cot_language", "cot-language"), True),
|
| 1011 |
+
is_format_caption=_to_bool(_get_any("is_format_caption", "isFormatCaption"), False),
|
| 1012 |
)
|
| 1013 |
|
| 1014 |
def _first_value(v: Any) -> Any:
|