Spaces:
Running
on
A100
Running
on
A100
fix: user_metadata
Browse files- acestep/api_server.py +135 -19
- acestep/llm_inference.py +14 -0
acestep/api_server.py
CHANGED
|
@@ -51,9 +51,9 @@ class GenerateMusicRequest(BaseModel):
|
|
| 51 |
thinking: bool = False
|
| 52 |
|
| 53 |
bpm: Optional[int] = None
|
| 54 |
-
# Accept common client keys
|
| 55 |
-
key_scale: str =
|
| 56 |
-
time_signature: str =
|
| 57 |
vocal_language: str = "en"
|
| 58 |
inference_steps: int = 8
|
| 59 |
guidance_scale: float = 7.0
|
|
@@ -62,7 +62,7 @@ class GenerateMusicRequest(BaseModel):
|
|
| 62 |
|
| 63 |
reference_audio_path: Optional[str] = None
|
| 64 |
src_audio_path: Optional[str] = None
|
| 65 |
-
audio_duration: Optional[float] =
|
| 66 |
batch_size: Optional[int] = None
|
| 67 |
|
| 68 |
audio_code_string: str = ""
|
|
@@ -532,6 +532,12 @@ def create_app() -> FastAPI:
|
|
| 532 |
|
| 533 |
thinking = bool(getattr(req, "thinking", False))
|
| 534 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
# If LM-generated code hints are used, a too-strong cover strength can suppress lyric/vocal conditioning.
|
| 536 |
# We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
|
| 537 |
audio_cover_strength_val = float(req.audio_cover_strength)
|
|
@@ -556,6 +562,27 @@ def create_app() -> FastAPI:
|
|
| 556 |
or (audio_duration_val is None)
|
| 557 |
)
|
| 558 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 559 |
if need_lm_metas or need_lm_codes:
|
| 560 |
# Lazy init 5Hz LM once
|
| 561 |
with app.state._llm_init_lock:
|
|
@@ -602,6 +629,8 @@ def create_app() -> FastAPI:
|
|
| 602 |
top_k=_normalize_optional_int(req.lm_top_k),
|
| 603 |
top_p=_normalize_optional_float(req.lm_top_p),
|
| 604 |
repetition_penalty=float(req.lm_repetition_penalty),
|
|
|
|
|
|
|
| 605 |
)
|
| 606 |
|
| 607 |
meta, codes, status = _lm_call()
|
|
@@ -784,39 +813,122 @@ def create_app() -> FastAPI:
|
|
| 784 |
raise HTTPException(status_code=400, detail="Invalid request payload")
|
| 785 |
|
| 786 |
def _get_any(*keys: str, default: Any = None) -> Any:
|
|
|
|
| 787 |
for k in keys:
|
| 788 |
v = get(k, None)
|
| 789 |
if v is not None:
|
| 790 |
return v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 791 |
return default
|
| 792 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 793 |
return GenerateMusicRequest(
|
| 794 |
caption=str(get("caption", "") or ""),
|
| 795 |
lyrics=str(get("lyrics", "") or ""),
|
| 796 |
thinking=_to_bool(get("thinking"), False),
|
| 797 |
-
bpm=
|
| 798 |
-
key_scale=
|
| 799 |
-
time_signature=
|
| 800 |
-
vocal_language=str(
|
| 801 |
-
inference_steps=_to_int(
|
| 802 |
-
guidance_scale=_to_float(
|
| 803 |
-
use_random_seed=_to_bool(
|
| 804 |
seed=_to_int(get("seed"), -1) or -1,
|
| 805 |
reference_audio_path=reference_audio_path,
|
| 806 |
src_audio_path=src_audio_path,
|
| 807 |
-
audio_duration=
|
| 808 |
batch_size=_to_int(get("batch_size"), None),
|
| 809 |
-
audio_code_string=str(
|
| 810 |
repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
|
| 811 |
repainting_end=_to_float(get("repainting_end"), None),
|
| 812 |
instruction=str(get("instruction", _DEFAULT_DIT_INSTRUCTION) or ""),
|
| 813 |
-
audio_cover_strength=_to_float(
|
| 814 |
-
task_type=str(
|
| 815 |
use_adg=_to_bool(get("use_adg"), False),
|
| 816 |
cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
|
| 817 |
cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
|
| 818 |
audio_format=str(get("audio_format", "mp3") or "mp3"),
|
| 819 |
-
use_tiled_decode=_to_bool(
|
| 820 |
lm_model_path=str(get("lm_model_path") or "").strip() or None,
|
| 821 |
lm_backend=str(get("lm_backend", "vllm") or "vllm"),
|
| 822 |
lm_temperature=_to_float(get("lm_temperature"), _LM_DEFAULT_TEMPERATURE) or _LM_DEFAULT_TEMPERATURE,
|
|
@@ -834,11 +946,15 @@ def create_app() -> FastAPI:
|
|
| 834 |
|
| 835 |
if content_type.startswith("application/json"):
|
| 836 |
body = await request.json()
|
| 837 |
-
|
|
|
|
|
|
|
| 838 |
|
| 839 |
elif content_type.endswith("+json"):
|
| 840 |
body = await request.json()
|
| 841 |
-
|
|
|
|
|
|
|
| 842 |
|
| 843 |
elif content_type.startswith("multipart/form-data"):
|
| 844 |
form = await request.form()
|
|
@@ -877,7 +993,7 @@ def create_app() -> FastAPI:
|
|
| 877 |
try:
|
| 878 |
body = json.loads(raw.decode("utf-8"))
|
| 879 |
if isinstance(body, dict):
|
| 880 |
-
req =
|
| 881 |
else:
|
| 882 |
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 883 |
except HTTPException:
|
|
|
|
| 51 |
thinking: bool = False
|
| 52 |
|
| 53 |
bpm: Optional[int] = None
|
| 54 |
+
# Accept common client keys via manual parsing (see _build_req_from_mapping).
|
| 55 |
+
key_scale: str = ""
|
| 56 |
+
time_signature: str = ""
|
| 57 |
vocal_language: str = "en"
|
| 58 |
inference_steps: int = 8
|
| 59 |
guidance_scale: float = 7.0
|
|
|
|
| 62 |
|
| 63 |
reference_audio_path: Optional[str] = None
|
| 64 |
src_audio_path: Optional[str] = None
|
| 65 |
+
audio_duration: Optional[float] = None
|
| 66 |
batch_size: Optional[int] = None
|
| 67 |
|
| 68 |
audio_code_string: str = ""
|
|
|
|
| 532 |
|
| 533 |
thinking = bool(getattr(req, "thinking", False))
|
| 534 |
|
| 535 |
+
print(
|
| 536 |
+
"[api_server] parsed req: "
|
| 537 |
+
f"thinking={thinking}, caption_len={len((req.caption or '').strip())}, lyrics_len={len((req.lyrics or '').strip())}, "
|
| 538 |
+
f"bpm={req.bpm}, audio_duration={req.audio_duration}, key_scale={req.key_scale!r}, time_signature={req.time_signature!r}"
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
# If LM-generated code hints are used, a too-strong cover strength can suppress lyric/vocal conditioning.
|
| 542 |
# We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
|
| 543 |
audio_cover_strength_val = float(req.audio_cover_strength)
|
|
|
|
| 562 |
or (audio_duration_val is None)
|
| 563 |
)
|
| 564 |
|
| 565 |
+
# Feishu-compatible: if user explicitly provided some metadata fields,
|
| 566 |
+
# pass them into constrained decoding so LM injects them directly
|
| 567 |
+
# (i.e. does not re-infer / override those fields).
|
| 568 |
+
user_metadata: Dict[str, Optional[str]] = {}
|
| 569 |
+
if bpm_val is not None:
|
| 570 |
+
user_metadata["bpm"] = str(int(bpm_val))
|
| 571 |
+
if audio_duration_val is not None:
|
| 572 |
+
user_metadata["duration"] = str(float(audio_duration_val))
|
| 573 |
+
if (key_scale_val or "").strip():
|
| 574 |
+
user_metadata["keyscale"] = str(key_scale_val)
|
| 575 |
+
if (time_sig_val or "").strip():
|
| 576 |
+
user_metadata["timesignature"] = str(time_sig_val)
|
| 577 |
+
|
| 578 |
+
lm_target_duration: Optional[float] = None
|
| 579 |
+
if need_lm_codes:
|
| 580 |
+
# If user specified a duration, constrain codes generation length accordingly.
|
| 581 |
+
if audio_duration_val is not None and float(audio_duration_val) > 0:
|
| 582 |
+
lm_target_duration = float(audio_duration_val)
|
| 583 |
+
|
| 584 |
+
print(f"[api_server] LM调用参数: user_metadata={user_metadata}, target_duration={lm_target_duration}, need_lm_codes={need_lm_codes}, need_lm_metas={need_lm_metas}")
|
| 585 |
+
|
| 586 |
if need_lm_metas or need_lm_codes:
|
| 587 |
# Lazy init 5Hz LM once
|
| 588 |
with app.state._llm_init_lock:
|
|
|
|
| 629 |
top_k=_normalize_optional_int(req.lm_top_k),
|
| 630 |
top_p=_normalize_optional_float(req.lm_top_p),
|
| 631 |
repetition_penalty=float(req.lm_repetition_penalty),
|
| 632 |
+
target_duration=lm_target_duration,
|
| 633 |
+
user_metadata=(user_metadata or None),
|
| 634 |
)
|
| 635 |
|
| 636 |
meta, codes, status = _lm_call()
|
|
|
|
| 813 |
raise HTTPException(status_code=400, detail="Invalid request payload")
|
| 814 |
|
| 815 |
def _get_any(*keys: str, default: Any = None) -> Any:
|
| 816 |
+
# 1) Top-level keys
|
| 817 |
for k in keys:
|
| 818 |
v = get(k, None)
|
| 819 |
if v is not None:
|
| 820 |
return v
|
| 821 |
+
|
| 822 |
+
# 2) Nested metas/metadata/user_metadata (dict or JSON string)
|
| 823 |
+
nested = (
|
| 824 |
+
get("metas", None)
|
| 825 |
+
or get("meta", None)
|
| 826 |
+
or get("metadata", None)
|
| 827 |
+
or get("user_metadata", None)
|
| 828 |
+
or get("userMetadata", None)
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
if isinstance(nested, str):
|
| 832 |
+
s = nested.strip()
|
| 833 |
+
if s.startswith("{") and s.endswith("}"):
|
| 834 |
+
try:
|
| 835 |
+
nested = json.loads(s)
|
| 836 |
+
except Exception:
|
| 837 |
+
nested = None
|
| 838 |
+
|
| 839 |
+
if isinstance(nested, dict):
|
| 840 |
+
g2 = nested.get
|
| 841 |
+
for k in keys:
|
| 842 |
+
v = g2(k, None)
|
| 843 |
+
if v is not None:
|
| 844 |
+
return v
|
| 845 |
+
|
| 846 |
return default
|
| 847 |
|
| 848 |
+
# Debug: print what keys we actually received (helps explain empty parsed values)
|
| 849 |
+
try:
|
| 850 |
+
top_keys = list(getattr(mapping, "keys", lambda: [])())
|
| 851 |
+
except Exception:
|
| 852 |
+
top_keys = []
|
| 853 |
+
try:
|
| 854 |
+
nested_probe = (
|
| 855 |
+
get("metas", None)
|
| 856 |
+
or get("meta", None)
|
| 857 |
+
or get("metadata", None)
|
| 858 |
+
or get("user_metadata", None)
|
| 859 |
+
or get("userMetadata", None)
|
| 860 |
+
)
|
| 861 |
+
if isinstance(nested_probe, str):
|
| 862 |
+
sp = nested_probe.strip()
|
| 863 |
+
if sp.startswith("{") and sp.endswith("}"):
|
| 864 |
+
try:
|
| 865 |
+
nested_probe = json.loads(sp)
|
| 866 |
+
except Exception:
|
| 867 |
+
nested_probe = None
|
| 868 |
+
nested_keys = list(nested_probe.keys()) if isinstance(nested_probe, dict) else []
|
| 869 |
+
except Exception:
|
| 870 |
+
nested_keys = []
|
| 871 |
+
print(f"[api_server] request keys: top={sorted(top_keys)}, nested={sorted(nested_keys)}")
|
| 872 |
+
|
| 873 |
+
# Debug: print raw values/types for common meta fields (top-level + common aliases)
|
| 874 |
+
try:
|
| 875 |
+
probe_keys = [
|
| 876 |
+
"thinking",
|
| 877 |
+
"bpm",
|
| 878 |
+
"audio_duration",
|
| 879 |
+
"duration",
|
| 880 |
+
"audioDuration",
|
| 881 |
+
"key_scale",
|
| 882 |
+
"keyscale",
|
| 883 |
+
"keyScale",
|
| 884 |
+
"time_signature",
|
| 885 |
+
"timesignature",
|
| 886 |
+
"timeSignature",
|
| 887 |
+
]
|
| 888 |
+
raw = {k: get(k, None) for k in probe_keys}
|
| 889 |
+
raw_types = {k: (type(v).__name__ if v is not None else None) for k, v in raw.items()}
|
| 890 |
+
print(f"[api_server] request raw: {raw}")
|
| 891 |
+
print(f"[api_server] request raw types: {raw_types}")
|
| 892 |
+
except Exception:
|
| 893 |
+
pass
|
| 894 |
+
|
| 895 |
+
normalized_audio_duration = _to_float(_get_any("audio_duration", "duration", "audioDuration"), None)
|
| 896 |
+
normalized_bpm = _to_int(_get_any("bpm"), None)
|
| 897 |
+
normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
|
| 898 |
+
normalized_timesig = str(_get_any("time_signature", "timesignature", "timeSignature", default="") or "")
|
| 899 |
+
print(
|
| 900 |
+
"[api_server] normalized: "
|
| 901 |
+
f"thinking={_to_bool(get('thinking'), False)}, bpm={normalized_bpm}, "
|
| 902 |
+
f"audio_duration={normalized_audio_duration}, key_scale={normalized_keyscale!r}, time_signature={normalized_timesig!r}"
|
| 903 |
+
)
|
| 904 |
+
|
| 905 |
return GenerateMusicRequest(
|
| 906 |
caption=str(get("caption", "") or ""),
|
| 907 |
lyrics=str(get("lyrics", "") or ""),
|
| 908 |
thinking=_to_bool(get("thinking"), False),
|
| 909 |
+
bpm=normalized_bpm,
|
| 910 |
+
key_scale=normalized_keyscale,
|
| 911 |
+
time_signature=normalized_timesig,
|
| 912 |
+
vocal_language=str(_get_any("vocal_language", "vocalLanguage", default="en") or "en"),
|
| 913 |
+
inference_steps=_to_int(_get_any("inference_steps", "inferenceSteps"), 8) or 8,
|
| 914 |
+
guidance_scale=_to_float(_get_any("guidance_scale", "guidanceScale"), 7.0) or 7.0,
|
| 915 |
+
use_random_seed=_to_bool(_get_any("use_random_seed", "useRandomSeed"), True),
|
| 916 |
seed=_to_int(get("seed"), -1) or -1,
|
| 917 |
reference_audio_path=reference_audio_path,
|
| 918 |
src_audio_path=src_audio_path,
|
| 919 |
+
audio_duration=normalized_audio_duration,
|
| 920 |
batch_size=_to_int(get("batch_size"), None),
|
| 921 |
+
audio_code_string=str(_get_any("audio_code_string", "audioCodeString", default="") or ""),
|
| 922 |
repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
|
| 923 |
repainting_end=_to_float(get("repainting_end"), None),
|
| 924 |
instruction=str(get("instruction", _DEFAULT_DIT_INSTRUCTION) or ""),
|
| 925 |
+
audio_cover_strength=_to_float(_get_any("audio_cover_strength", "audioCoverStrength"), 1.0) or 1.0,
|
| 926 |
+
task_type=str(_get_any("task_type", "taskType", default="text2music") or "text2music"),
|
| 927 |
use_adg=_to_bool(get("use_adg"), False),
|
| 928 |
cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
|
| 929 |
cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
|
| 930 |
audio_format=str(get("audio_format", "mp3") or "mp3"),
|
| 931 |
+
use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
|
| 932 |
lm_model_path=str(get("lm_model_path") or "").strip() or None,
|
| 933 |
lm_backend=str(get("lm_backend", "vllm") or "vllm"),
|
| 934 |
lm_temperature=_to_float(get("lm_temperature"), _LM_DEFAULT_TEMPERATURE) or _LM_DEFAULT_TEMPERATURE,
|
|
|
|
| 946 |
|
| 947 |
if content_type.startswith("application/json"):
|
| 948 |
body = await request.json()
|
| 949 |
+
if not isinstance(body, dict):
|
| 950 |
+
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 951 |
+
req = _build_req_from_mapping(body, reference_audio_path=None, src_audio_path=None)
|
| 952 |
|
| 953 |
elif content_type.endswith("+json"):
|
| 954 |
body = await request.json()
|
| 955 |
+
if not isinstance(body, dict):
|
| 956 |
+
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 957 |
+
req = _build_req_from_mapping(body, reference_audio_path=None, src_audio_path=None)
|
| 958 |
|
| 959 |
elif content_type.startswith("multipart/form-data"):
|
| 960 |
form = await request.form()
|
|
|
|
| 993 |
try:
|
| 994 |
body = json.loads(raw.decode("utf-8"))
|
| 995 |
if isinstance(body, dict):
|
| 996 |
+
req = _build_req_from_mapping(body, reference_audio_path=None, src_audio_path=None)
|
| 997 |
else:
|
| 998 |
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 999 |
except HTTPException:
|
acestep/llm_inference.py
CHANGED
|
@@ -245,6 +245,7 @@ class LLMHandler:
|
|
| 245 |
metadata_temperature: Optional[float] = 0.85,
|
| 246 |
codes_temperature: Optional[float] = None,
|
| 247 |
target_duration: Optional[float] = None,
|
|
|
|
| 248 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 249 |
"""Generate metadata and audio codes using 5Hz LM with vllm backend
|
| 250 |
|
|
@@ -288,6 +289,8 @@ class LLMHandler:
|
|
| 288 |
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 289 |
self.constrained_processor.update_caption(caption)
|
| 290 |
self.constrained_processor.set_target_duration(target_duration)
|
|
|
|
|
|
|
| 291 |
|
| 292 |
constrained_processor = self.constrained_processor
|
| 293 |
update_state_fn = constrained_processor.update_state
|
|
@@ -423,6 +426,7 @@ class LLMHandler:
|
|
| 423 |
metadata_temperature: Optional[float] = 0.85,
|
| 424 |
codes_temperature: Optional[float] = None,
|
| 425 |
target_duration: Optional[float] = None,
|
|
|
|
| 426 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 427 |
"""Generate metadata and audio codes using 5Hz LM with PyTorch backend
|
| 428 |
|
|
@@ -495,6 +499,8 @@ class LLMHandler:
|
|
| 495 |
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 496 |
self.constrained_processor.update_caption(caption)
|
| 497 |
self.constrained_processor.set_target_duration(target_duration)
|
|
|
|
|
|
|
| 498 |
|
| 499 |
constrained_processor = self.constrained_processor
|
| 500 |
|
|
@@ -769,6 +775,7 @@ class LLMHandler:
|
|
| 769 |
metadata_temperature: Optional[float] = 0.85,
|
| 770 |
codes_temperature: Optional[float] = None,
|
| 771 |
target_duration: Optional[float] = None,
|
|
|
|
| 772 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 773 |
"""Generate metadata and audio codes using 5Hz LM
|
| 774 |
|
|
@@ -819,6 +826,7 @@ class LLMHandler:
|
|
| 819 |
metadata_temperature=metadata_temperature,
|
| 820 |
codes_temperature=codes_temperature,
|
| 821 |
target_duration=target_duration,
|
|
|
|
| 822 |
)
|
| 823 |
else:
|
| 824 |
return self.generate_with_5hz_lm_pt(
|
|
@@ -835,6 +843,7 @@ class LLMHandler:
|
|
| 835 |
metadata_temperature=metadata_temperature,
|
| 836 |
codes_temperature=codes_temperature,
|
| 837 |
target_duration=target_duration,
|
|
|
|
| 838 |
)
|
| 839 |
|
| 840 |
def generate_with_stop_condition(
|
|
@@ -853,6 +862,7 @@ class LLMHandler:
|
|
| 853 |
metadata_temperature: Optional[float] = 0.85,
|
| 854 |
codes_temperature: Optional[float] = None,
|
| 855 |
target_duration: Optional[float] = None,
|
|
|
|
| 856 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 857 |
"""Feishu-compatible LM generation.
|
| 858 |
|
|
@@ -862,6 +872,8 @@ class LLMHandler:
|
|
| 862 |
Args:
|
| 863 |
target_duration: Target duration in seconds for codes generation constraint.
|
| 864 |
5 codes = 1 second. If specified, blocks EOS until target reached.
|
|
|
|
|
|
|
| 865 |
"""
|
| 866 |
infer_type = (infer_type or "").strip().lower()
|
| 867 |
if infer_type not in {"dit", "llm_dit"}:
|
|
@@ -882,6 +894,7 @@ class LLMHandler:
|
|
| 882 |
metadata_temperature=metadata_temperature,
|
| 883 |
codes_temperature=codes_temperature,
|
| 884 |
target_duration=target_duration,
|
|
|
|
| 885 |
)
|
| 886 |
|
| 887 |
# dit: generate and truncate at reasoning end tag
|
|
@@ -895,6 +908,7 @@ class LLMHandler:
|
|
| 895 |
"top_k": top_k,
|
| 896 |
"top_p": top_p,
|
| 897 |
"repetition_penalty": repetition_penalty,
|
|
|
|
| 898 |
},
|
| 899 |
use_constrained_decoding=use_constrained_decoding,
|
| 900 |
constrained_decoding_debug=constrained_decoding_debug,
|
|
|
|
| 245 |
metadata_temperature: Optional[float] = 0.85,
|
| 246 |
codes_temperature: Optional[float] = None,
|
| 247 |
target_duration: Optional[float] = None,
|
| 248 |
+
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 249 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 250 |
"""Generate metadata and audio codes using 5Hz LM with vllm backend
|
| 251 |
|
|
|
|
| 289 |
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 290 |
self.constrained_processor.update_caption(caption)
|
| 291 |
self.constrained_processor.set_target_duration(target_duration)
|
| 292 |
+
# Always call set_user_metadata to ensure previous settings are cleared if None
|
| 293 |
+
self.constrained_processor.set_user_metadata(user_metadata)
|
| 294 |
|
| 295 |
constrained_processor = self.constrained_processor
|
| 296 |
update_state_fn = constrained_processor.update_state
|
|
|
|
| 426 |
metadata_temperature: Optional[float] = 0.85,
|
| 427 |
codes_temperature: Optional[float] = None,
|
| 428 |
target_duration: Optional[float] = None,
|
| 429 |
+
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 430 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 431 |
"""Generate metadata and audio codes using 5Hz LM with PyTorch backend
|
| 432 |
|
|
|
|
| 499 |
self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
|
| 500 |
self.constrained_processor.update_caption(caption)
|
| 501 |
self.constrained_processor.set_target_duration(target_duration)
|
| 502 |
+
# Always call set_user_metadata to ensure previous settings are cleared if None
|
| 503 |
+
self.constrained_processor.set_user_metadata(user_metadata)
|
| 504 |
|
| 505 |
constrained_processor = self.constrained_processor
|
| 506 |
|
|
|
|
| 775 |
metadata_temperature: Optional[float] = 0.85,
|
| 776 |
codes_temperature: Optional[float] = None,
|
| 777 |
target_duration: Optional[float] = None,
|
| 778 |
+
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 779 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 780 |
"""Generate metadata and audio codes using 5Hz LM
|
| 781 |
|
|
|
|
| 826 |
metadata_temperature=metadata_temperature,
|
| 827 |
codes_temperature=codes_temperature,
|
| 828 |
target_duration=target_duration,
|
| 829 |
+
user_metadata=user_metadata,
|
| 830 |
)
|
| 831 |
else:
|
| 832 |
return self.generate_with_5hz_lm_pt(
|
|
|
|
| 843 |
metadata_temperature=metadata_temperature,
|
| 844 |
codes_temperature=codes_temperature,
|
| 845 |
target_duration=target_duration,
|
| 846 |
+
user_metadata=user_metadata,
|
| 847 |
)
|
| 848 |
|
| 849 |
def generate_with_stop_condition(
|
|
|
|
| 862 |
metadata_temperature: Optional[float] = 0.85,
|
| 863 |
codes_temperature: Optional[float] = None,
|
| 864 |
target_duration: Optional[float] = None,
|
| 865 |
+
user_metadata: Optional[Dict[str, Optional[str]]] = None,
|
| 866 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 867 |
"""Feishu-compatible LM generation.
|
| 868 |
|
|
|
|
| 872 |
Args:
|
| 873 |
target_duration: Target duration in seconds for codes generation constraint.
|
| 874 |
5 codes = 1 second. If specified, blocks EOS until target reached.
|
| 875 |
+
user_metadata: User-provided metadata fields (e.g. bpm/duration/keyscale/timesignature).
|
| 876 |
+
If specified, constrained decoding will inject these values directly.
|
| 877 |
"""
|
| 878 |
infer_type = (infer_type or "").strip().lower()
|
| 879 |
if infer_type not in {"dit", "llm_dit"}:
|
|
|
|
| 894 |
metadata_temperature=metadata_temperature,
|
| 895 |
codes_temperature=codes_temperature,
|
| 896 |
target_duration=target_duration,
|
| 897 |
+
user_metadata=user_metadata,
|
| 898 |
)
|
| 899 |
|
| 900 |
# dit: generate and truncate at reasoning end tag
|
|
|
|
| 908 |
"top_k": top_k,
|
| 909 |
"top_p": top_p,
|
| 910 |
"repetition_penalty": repetition_penalty,
|
| 911 |
+
"user_metadata": user_metadata,
|
| 912 |
},
|
| 913 |
use_constrained_decoding=use_constrained_decoding,
|
| 914 |
constrained_decoding_debug=constrained_decoding_debug,
|