ChuxiJ committed on
Commit
bbb4f62
·
1 Parent(s): e161e9a

feat: update api-server cot-caption

Browse files
Files changed (2) hide show
  1. .env.example +4 -0
  2. acestep/api_server.py +82 -23
.env.example ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ACESTEP_CONFIG_PATH=acestep-v15-turbo-rl
2
+ ACESTEP_LM_MODEL_PATH=acestep-5Hz-lm-0.6B-v3
3
+ ACESTEP_DEVICE=auto
4
+ ACESTEP_LM_BACKEND=vllm
acestep/api_server.py CHANGED
@@ -29,6 +29,11 @@ from threading import Lock
29
  from typing import Any, Dict, Literal, Optional
30
  from uuid import uuid4
31
 
 
 
 
 
 
32
  from fastapi import FastAPI, HTTPException, Request
33
  from pydantic import BaseModel, Field
34
  from starlette.datastructures import UploadFile as StarletteUploadFile
@@ -89,8 +94,12 @@ class GenerateMusicRequest(BaseModel):
89
  lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
90
  lm_backend: Literal["vllm", "pt"] = "vllm"
91
 
92
- # Align defaults with `acestep/gradio_ui.py` and `feishu_bot/config.py`
93
- # to improve lyric adherence in lm-dit mode.
 
 
 
 
94
  lm_temperature: float = 0.85
95
  lm_cfg_scale: float = 2.0
96
  lm_top_k: Optional[int] = None
@@ -125,8 +134,6 @@ class JobResult(BaseModel):
125
  status_message: str = ""
126
  seed_value: str = ""
127
 
128
- # 5Hz LM metadata (present when server invoked LM)
129
- # Keep a raw-ish dict for clients that expect a `metas` object.
130
  metas: Dict[str, Any] = Field(default_factory=dict)
131
  bpm: Optional[int] = None
132
  duration: Optional[float] = None
@@ -213,6 +220,22 @@ def _get_project_root() -> str:
213
  return os.path.dirname(os.path.dirname(current_file))
214
 
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
217
  if v is None:
218
  return default
@@ -372,7 +395,7 @@ def create_app() -> FastAPI:
372
  raise RuntimeError(app.state._init_error)
373
 
374
  project_root = _get_project_root()
375
- config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
376
  device = os.getenv("ACESTEP_DEVICE", "auto")
377
 
378
  use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
@@ -568,25 +591,40 @@ def create_app() -> FastAPI:
568
 
569
  has_codes = bool(audio_code_string and str(audio_code_string).strip())
570
  need_lm_codes = bool(thinking) and (not has_codes)
571
- need_lm_metas = (
572
- (bpm_val is None)
573
- or (not (key_scale_val or "").strip())
574
- or (not (time_sig_val or "").strip())
575
- or (audio_duration_val is None)
576
- )
577
 
578
- # Feishu-compatible: if user explicitly provided some metadata fields,
 
 
 
 
 
579
  # pass them into constrained decoding so LM injects them directly
580
  # (i.e. does not re-infer / override those fields).
581
  user_metadata: Dict[str, Optional[str]] = {}
582
- if bpm_val is not None:
583
- user_metadata["bpm"] = str(int(bpm_val))
584
- if audio_duration_val is not None:
585
- user_metadata["duration"] = str(float(audio_duration_val))
586
- if (key_scale_val or "").strip():
587
- user_metadata["keyscale"] = str(key_scale_val)
588
- if (time_sig_val or "").strip():
589
- user_metadata["timesignature"] = str(time_sig_val)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
  lm_target_duration: Optional[float] = None
592
  if need_lm_codes:
@@ -594,7 +632,13 @@ def create_app() -> FastAPI:
594
  if audio_duration_val is not None and float(audio_duration_val) > 0:
595
  lm_target_duration = float(audio_duration_val)
596
 
597
- print(f"[api_server] LM调用参数: user_metadata={user_metadata}, target_duration={lm_target_duration}, need_lm_codes={need_lm_codes}, need_lm_metas={need_lm_metas}")
 
 
 
 
 
 
598
 
599
  if need_lm_metas or need_lm_codes:
600
  # Lazy init 5Hz LM once
@@ -602,7 +646,7 @@ def create_app() -> FastAPI:
602
  if getattr(app.state, "_llm_initialized", False) is False and getattr(app.state, "_llm_init_error", None) is None:
603
  project_root = _get_project_root()
604
  checkpoint_dir = os.path.join(project_root, "checkpoints")
605
- lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip()
606
  backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
607
  if backend not in {"vllm", "pt"}:
608
  backend = "vllm"
@@ -644,6 +688,11 @@ def create_app() -> FastAPI:
644
  repetition_penalty=float(req.lm_repetition_penalty),
645
  target_duration=lm_target_duration,
646
  user_metadata=(user_metadata or None),
 
 
 
 
 
647
  )
648
 
649
  meta, codes, status = _lm_call()
@@ -715,7 +764,6 @@ def create_app() -> FastAPI:
715
  if s in {"", "N/A"}:
716
  return None
717
  return s
718
-
719
  first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
720
  captions=req.caption,
721
  lyrics=req.lyrics,
@@ -909,6 +957,11 @@ def create_app() -> FastAPI:
909
  normalized_bpm = _to_int(_get_any("bpm"), None)
910
  normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
911
  normalized_timesig = str(_get_any("time_signature", "timesignature", "timeSignature", default="") or "")
 
 
 
 
 
912
  print(
913
  "[api_server] normalized: "
914
  f"thinking={_to_bool(get('thinking'), False)}, bpm={normalized_bpm}, "
@@ -950,6 +1003,12 @@ def create_app() -> FastAPI:
950
  lm_top_p=_to_float(get("lm_top_p"), _LM_DEFAULT_TOP_P),
951
  lm_repetition_penalty=_to_float(get("lm_repetition_penalty"), 1.0) or 1.0,
952
  lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
 
 
 
 
 
 
953
  )
954
 
955
  def _first_value(v: Any) -> Any:
 
29
  from typing import Any, Dict, Literal, Optional
30
  from uuid import uuid4
31
 
32
+ try:
33
+ from dotenv import load_dotenv
34
+ except ImportError: # Optional dependency
35
+ load_dotenv = None # type: ignore
36
+
37
  from fastapi import FastAPI, HTTPException, Request
38
  from pydantic import BaseModel, Field
39
  from starlette.datastructures import UploadFile as StarletteUploadFile
 
94
  lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
95
  lm_backend: Literal["vllm", "pt"] = "vllm"
96
 
97
+ constrained_decoding: bool = True
98
+ constrained_decoding_debug: bool = False
99
+ use_cot_caption: bool = True
100
+ use_cot_language: bool = True
101
+ is_format_caption: bool = False
102
+
103
  lm_temperature: float = 0.85
104
  lm_cfg_scale: float = 2.0
105
  lm_top_k: Optional[int] = None
 
134
  status_message: str = ""
135
  seed_value: str = ""
136
 
 
 
137
  metas: Dict[str, Any] = Field(default_factory=dict)
138
  bpm: Optional[int] = None
139
  duration: Optional[float] = None
 
220
  return os.path.dirname(os.path.dirname(current_file))
221
 
222
 
223
+ def _load_project_env() -> None:
224
+ if load_dotenv is None:
225
+ return
226
+ try:
227
+ project_root = _get_project_root()
228
+ env_path = os.path.join(project_root, ".env")
229
+ if os.path.exists(env_path):
230
+ load_dotenv(env_path, override=False)
231
+ except Exception:
232
+ # Optional best-effort: continue even if .env loading fails.
233
+ pass
234
+
235
+
236
+ _load_project_env()
237
+
238
+
239
  def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
240
  if v is None:
241
  return default
 
395
  raise RuntimeError(app.state._init_error)
396
 
397
  project_root = _get_project_root()
398
+ config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo-rl")
399
  device = os.getenv("ACESTEP_DEVICE", "auto")
400
 
401
  use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
 
591
 
592
  has_codes = bool(audio_code_string and str(audio_code_string).strip())
593
  need_lm_codes = bool(thinking) and (not has_codes)
 
 
 
 
 
 
594
 
595
+ use_constrained_decoding = bool(getattr(req, "constrained_decoding", True))
596
+ constrained_decoding_debug = bool(getattr(req, "constrained_decoding_debug", False))
597
+ use_cot_caption = bool(getattr(req, "use_cot_caption", True))
598
+ use_cot_language = bool(getattr(req, "use_cot_language", True))
599
+ is_format_caption = bool(getattr(req, "is_format_caption", False))
600
+
601
  # pass them into constrained decoding so LM injects them directly
602
  # (i.e. does not re-infer / override those fields).
603
  user_metadata: Dict[str, Optional[str]] = {}
604
+
605
+ def _set_user_meta(field: str, value: Optional[Any]) -> None:
606
+ if value is None:
607
+ return
608
+ s = str(value).strip()
609
+ if not s or s.upper() == "N/A":
610
+ return
611
+ user_metadata[field] = s
612
+
613
+ _set_user_meta("bpm", int(bpm_val) if bpm_val is not None else None)
614
+ _set_user_meta("duration", float(audio_duration_val) if audio_duration_val is not None else None)
615
+ _set_user_meta("keyscale", key_scale_val if (key_scale_val or "").strip() else None)
616
+ _set_user_meta("timesignature", time_sig_val if (time_sig_val or "").strip() else None)
617
+
618
+ def _has_meta(field: str) -> bool:
619
+ v = user_metadata.get(field)
620
+ return bool((v or "").strip())
621
+
622
+ need_lm_metas = not (
623
+ _has_meta("bpm")
624
+ and _has_meta("duration")
625
+ and _has_meta("keyscale")
626
+ and _has_meta("timesignature")
627
+ )
628
 
629
  lm_target_duration: Optional[float] = None
630
  if need_lm_codes:
 
632
  if audio_duration_val is not None and float(audio_duration_val) > 0:
633
  lm_target_duration = float(audio_duration_val)
634
 
635
+ print(
636
+ "[api_server] LM调用参数: "
637
+ f"user_metadata_keys={sorted(user_metadata.keys())}, target_duration={lm_target_duration}, "
638
+ f"need_lm_codes={need_lm_codes}, need_lm_metas={need_lm_metas}, "
639
+ f"use_constrained_decoding={use_constrained_decoding}, use_cot_caption={use_cot_caption}, "
640
+ f"use_cot_language={use_cot_language}, is_format_caption={is_format_caption}"
641
+ )
642
 
643
  if need_lm_metas or need_lm_codes:
644
  # Lazy init 5Hz LM once
 
646
  if getattr(app.state, "_llm_initialized", False) is False and getattr(app.state, "_llm_init_error", None) is None:
647
  project_root = _get_project_root()
648
  checkpoint_dir = os.path.join(project_root, "checkpoints")
649
+ lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B-v3").strip()
650
  backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
651
  if backend not in {"vllm", "pt"}:
652
  backend = "vllm"
 
688
  repetition_penalty=float(req.lm_repetition_penalty),
689
  target_duration=lm_target_duration,
690
  user_metadata=(user_metadata or None),
691
+ use_constrained_decoding=use_constrained_decoding,
692
+ constrained_decoding_debug=constrained_decoding_debug,
693
+ use_cot_caption=use_cot_caption,
694
+ use_cot_language=use_cot_language,
695
+ is_format_caption=is_format_caption,
696
  )
697
 
698
  meta, codes, status = _lm_call()
 
764
  if s in {"", "N/A"}:
765
  return None
766
  return s
 
767
  first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
768
  captions=req.caption,
769
  lyrics=req.lyrics,
 
957
  normalized_bpm = _to_int(_get_any("bpm"), None)
958
  normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
959
  normalized_timesig = str(_get_any("time_signature", "timesignature", "timeSignature", default="") or "")
960
+
961
+ # Accept "target_duration" / "targetDuration" as aliases for the audio duration so clients do not need to special-case this server.
962
+ if normalized_audio_duration is None:
963
+ normalized_audio_duration = _to_float(_get_any("target_duration", "targetDuration"), None)
964
+
965
  print(
966
  "[api_server] normalized: "
967
  f"thinking={_to_bool(get('thinking'), False)}, bpm={normalized_bpm}, "
 
1003
  lm_top_p=_to_float(get("lm_top_p"), _LM_DEFAULT_TOP_P),
1004
  lm_repetition_penalty=_to_float(get("lm_repetition_penalty"), 1.0) or 1.0,
1005
  lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
1006
+ constrained_decoding=_to_bool(_get_any("constrained_decoding", "constrainedDecoding", "constrained"), True),
1007
+ constrained_decoding_debug=_to_bool(_get_any("constrained_decoding_debug", "constrainedDecodingDebug"), False),
1008
+ # Accept common aliases, including hyphenated keys from some clients.
1009
+ use_cot_caption=_to_bool(_get_any("use_cot_caption", "cot_caption", "cot-caption"), True),
1010
+ use_cot_language=_to_bool(_get_any("use_cot_language", "cot_language", "cot-language"), True),
1011
+ is_format_caption=_to_bool(_get_any("is_format_caption", "isFormatCaption"), False),
1012
  )
1013
 
1014
  def _first_value(v: Any) -> Any: