ChuxiJ commited on
Commit
51dc2aa
·
1 Parent(s): b3f1425

fix: user_metadata

Browse files
Files changed (2) hide show
  1. acestep/api_server.py +135 -19
  2. acestep/llm_inference.py +14 -0
acestep/api_server.py CHANGED
@@ -51,9 +51,9 @@ class GenerateMusicRequest(BaseModel):
51
  thinking: bool = False
52
 
53
  bpm: Optional[int] = None
54
- # Accept common client keys while keeping internal field names stable.
55
- key_scale: str = Field(default="", alias="keyscale")
56
- time_signature: str = Field(default="", alias="timesignature")
57
  vocal_language: str = "en"
58
  inference_steps: int = 8
59
  guidance_scale: float = 7.0
@@ -62,7 +62,7 @@ class GenerateMusicRequest(BaseModel):
62
 
63
  reference_audio_path: Optional[str] = None
64
  src_audio_path: Optional[str] = None
65
- audio_duration: Optional[float] = Field(default=None, alias="duration")
66
  batch_size: Optional[int] = None
67
 
68
  audio_code_string: str = ""
@@ -532,6 +532,12 @@ def create_app() -> FastAPI:
532
 
533
  thinking = bool(getattr(req, "thinking", False))
534
 
 
 
 
 
 
 
535
  # If LM-generated code hints are used, a too-strong cover strength can suppress lyric/vocal conditioning.
536
  # We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
537
  audio_cover_strength_val = float(req.audio_cover_strength)
@@ -556,6 +562,27 @@ def create_app() -> FastAPI:
556
  or (audio_duration_val is None)
557
  )
558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
  if need_lm_metas or need_lm_codes:
560
  # Lazy init 5Hz LM once
561
  with app.state._llm_init_lock:
@@ -602,6 +629,8 @@ def create_app() -> FastAPI:
602
  top_k=_normalize_optional_int(req.lm_top_k),
603
  top_p=_normalize_optional_float(req.lm_top_p),
604
  repetition_penalty=float(req.lm_repetition_penalty),
 
 
605
  )
606
 
607
  meta, codes, status = _lm_call()
@@ -784,39 +813,122 @@ def create_app() -> FastAPI:
784
  raise HTTPException(status_code=400, detail="Invalid request payload")
785
 
786
  def _get_any(*keys: str, default: Any = None) -> Any:
 
787
  for k in keys:
788
  v = get(k, None)
789
  if v is not None:
790
  return v
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
791
  return default
792
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
793
  return GenerateMusicRequest(
794
  caption=str(get("caption", "") or ""),
795
  lyrics=str(get("lyrics", "") or ""),
796
  thinking=_to_bool(get("thinking"), False),
797
- bpm=_to_int(get("bpm"), None),
798
- key_scale=str(_get_any("key_scale", "keyscale", default="") or ""),
799
- time_signature=str(_get_any("time_signature", "timesignature", default="") or ""),
800
- vocal_language=str(get("vocal_language", "en") or "en"),
801
- inference_steps=_to_int(get("inference_steps"), 8) or 8,
802
- guidance_scale=_to_float(get("guidance_scale"), 7.0) or 7.0,
803
- use_random_seed=_to_bool(get("use_random_seed"), True),
804
  seed=_to_int(get("seed"), -1) or -1,
805
  reference_audio_path=reference_audio_path,
806
  src_audio_path=src_audio_path,
807
- audio_duration=_to_float(_get_any("audio_duration", "duration"), None),
808
  batch_size=_to_int(get("batch_size"), None),
809
- audio_code_string=str(get("audio_code_string", "") or ""),
810
  repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
811
  repainting_end=_to_float(get("repainting_end"), None),
812
  instruction=str(get("instruction", _DEFAULT_DIT_INSTRUCTION) or ""),
813
- audio_cover_strength=_to_float(get("audio_cover_strength"), 1.0) or 1.0,
814
- task_type=str(get("task_type", "text2music") or "text2music"),
815
  use_adg=_to_bool(get("use_adg"), False),
816
  cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
817
  cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
818
  audio_format=str(get("audio_format", "mp3") or "mp3"),
819
- use_tiled_decode=_to_bool(get("use_tiled_decode"), True),
820
  lm_model_path=str(get("lm_model_path") or "").strip() or None,
821
  lm_backend=str(get("lm_backend", "vllm") or "vllm"),
822
  lm_temperature=_to_float(get("lm_temperature"), _LM_DEFAULT_TEMPERATURE) or _LM_DEFAULT_TEMPERATURE,
@@ -834,11 +946,15 @@ def create_app() -> FastAPI:
834
 
835
  if content_type.startswith("application/json"):
836
  body = await request.json()
837
- req = GenerateMusicRequest(**body)
 
 
838
 
839
  elif content_type.endswith("+json"):
840
  body = await request.json()
841
- req = GenerateMusicRequest(**body)
 
 
842
 
843
  elif content_type.startswith("multipart/form-data"):
844
  form = await request.form()
@@ -877,7 +993,7 @@ def create_app() -> FastAPI:
877
  try:
878
  body = json.loads(raw.decode("utf-8"))
879
  if isinstance(body, dict):
880
- req = GenerateMusicRequest(**body)
881
  else:
882
  raise HTTPException(status_code=400, detail="JSON payload must be an object")
883
  except HTTPException:
 
51
  thinking: bool = False
52
 
53
  bpm: Optional[int] = None
54
+ # Accept common client keys via manual parsing (see _build_req_from_mapping).
55
+ key_scale: str = ""
56
+ time_signature: str = ""
57
  vocal_language: str = "en"
58
  inference_steps: int = 8
59
  guidance_scale: float = 7.0
 
62
 
63
  reference_audio_path: Optional[str] = None
64
  src_audio_path: Optional[str] = None
65
+ audio_duration: Optional[float] = None
66
  batch_size: Optional[int] = None
67
 
68
  audio_code_string: str = ""
 
532
 
533
  thinking = bool(getattr(req, "thinking", False))
534
 
535
+ print(
536
+ "[api_server] parsed req: "
537
+ f"thinking={thinking}, caption_len={len((req.caption or '').strip())}, lyrics_len={len((req.lyrics or '').strip())}, "
538
+ f"bpm={req.bpm}, audio_duration={req.audio_duration}, key_scale={req.key_scale!r}, time_signature={req.time_signature!r}"
539
+ )
540
+
541
  # If LM-generated code hints are used, a too-strong cover strength can suppress lyric/vocal conditioning.
542
  # We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
543
  audio_cover_strength_val = float(req.audio_cover_strength)
 
562
  or (audio_duration_val is None)
563
  )
564
 
565
+ # Feishu-compatible: if user explicitly provided some metadata fields,
566
+ # pass them into constrained decoding so LM injects them directly
567
+ # (i.e. does not re-infer / override those fields).
568
+ user_metadata: Dict[str, Optional[str]] = {}
569
+ if bpm_val is not None:
570
+ user_metadata["bpm"] = str(int(bpm_val))
571
+ if audio_duration_val is not None:
572
+ user_metadata["duration"] = str(float(audio_duration_val))
573
+ if (key_scale_val or "").strip():
574
+ user_metadata["keyscale"] = str(key_scale_val)
575
+ if (time_sig_val or "").strip():
576
+ user_metadata["timesignature"] = str(time_sig_val)
577
+
578
+ lm_target_duration: Optional[float] = None
579
+ if need_lm_codes:
580
+ # If user specified a duration, constrain codes generation length accordingly.
581
+ if audio_duration_val is not None and float(audio_duration_val) > 0:
582
+ lm_target_duration = float(audio_duration_val)
583
+
584
+ print(f"[api_server] LM调用参数: user_metadata={user_metadata}, target_duration={lm_target_duration}, need_lm_codes={need_lm_codes}, need_lm_metas={need_lm_metas}")
585
+
586
  if need_lm_metas or need_lm_codes:
587
  # Lazy init 5Hz LM once
588
  with app.state._llm_init_lock:
 
629
  top_k=_normalize_optional_int(req.lm_top_k),
630
  top_p=_normalize_optional_float(req.lm_top_p),
631
  repetition_penalty=float(req.lm_repetition_penalty),
632
+ target_duration=lm_target_duration,
633
+ user_metadata=(user_metadata or None),
634
  )
635
 
636
  meta, codes, status = _lm_call()
 
813
  raise HTTPException(status_code=400, detail="Invalid request payload")
814
 
815
  def _get_any(*keys: str, default: Any = None) -> Any:
816
+ # 1) Top-level keys
817
  for k in keys:
818
  v = get(k, None)
819
  if v is not None:
820
  return v
821
+
822
+ # 2) Nested metas/metadata/user_metadata (dict or JSON string)
823
+ nested = (
824
+ get("metas", None)
825
+ or get("meta", None)
826
+ or get("metadata", None)
827
+ or get("user_metadata", None)
828
+ or get("userMetadata", None)
829
+ )
830
+
831
+ if isinstance(nested, str):
832
+ s = nested.strip()
833
+ if s.startswith("{") and s.endswith("}"):
834
+ try:
835
+ nested = json.loads(s)
836
+ except Exception:
837
+ nested = None
838
+
839
+ if isinstance(nested, dict):
840
+ g2 = nested.get
841
+ for k in keys:
842
+ v = g2(k, None)
843
+ if v is not None:
844
+ return v
845
+
846
  return default
847
 
848
+ # Debug: print what keys we actually received (helps explain empty parsed values)
849
+ try:
850
+ top_keys = list(getattr(mapping, "keys", lambda: [])())
851
+ except Exception:
852
+ top_keys = []
853
+ try:
854
+ nested_probe = (
855
+ get("metas", None)
856
+ or get("meta", None)
857
+ or get("metadata", None)
858
+ or get("user_metadata", None)
859
+ or get("userMetadata", None)
860
+ )
861
+ if isinstance(nested_probe, str):
862
+ sp = nested_probe.strip()
863
+ if sp.startswith("{") and sp.endswith("}"):
864
+ try:
865
+ nested_probe = json.loads(sp)
866
+ except Exception:
867
+ nested_probe = None
868
+ nested_keys = list(nested_probe.keys()) if isinstance(nested_probe, dict) else []
869
+ except Exception:
870
+ nested_keys = []
871
+ print(f"[api_server] request keys: top={sorted(top_keys)}, nested={sorted(nested_keys)}")
872
+
873
+ # Debug: print raw values/types for common meta fields (top-level + common aliases)
874
+ try:
875
+ probe_keys = [
876
+ "thinking",
877
+ "bpm",
878
+ "audio_duration",
879
+ "duration",
880
+ "audioDuration",
881
+ "key_scale",
882
+ "keyscale",
883
+ "keyScale",
884
+ "time_signature",
885
+ "timesignature",
886
+ "timeSignature",
887
+ ]
888
+ raw = {k: get(k, None) for k in probe_keys}
889
+ raw_types = {k: (type(v).__name__ if v is not None else None) for k, v in raw.items()}
890
+ print(f"[api_server] request raw: {raw}")
891
+ print(f"[api_server] request raw types: {raw_types}")
892
+ except Exception:
893
+ pass
894
+
895
+ normalized_audio_duration = _to_float(_get_any("audio_duration", "duration", "audioDuration"), None)
896
+ normalized_bpm = _to_int(_get_any("bpm"), None)
897
+ normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
898
+ normalized_timesig = str(_get_any("time_signature", "timesignature", "timeSignature", default="") or "")
899
+ print(
900
+ "[api_server] normalized: "
901
+ f"thinking={_to_bool(get('thinking'), False)}, bpm={normalized_bpm}, "
902
+ f"audio_duration={normalized_audio_duration}, key_scale={normalized_keyscale!r}, time_signature={normalized_timesig!r}"
903
+ )
904
+
905
  return GenerateMusicRequest(
906
  caption=str(get("caption", "") or ""),
907
  lyrics=str(get("lyrics", "") or ""),
908
  thinking=_to_bool(get("thinking"), False),
909
+ bpm=normalized_bpm,
910
+ key_scale=normalized_keyscale,
911
+ time_signature=normalized_timesig,
912
+ vocal_language=str(_get_any("vocal_language", "vocalLanguage", default="en") or "en"),
913
+ inference_steps=_to_int(_get_any("inference_steps", "inferenceSteps"), 8) or 8,
914
+ guidance_scale=_to_float(_get_any("guidance_scale", "guidanceScale"), 7.0) or 7.0,
915
+ use_random_seed=_to_bool(_get_any("use_random_seed", "useRandomSeed"), True),
916
  seed=_to_int(get("seed"), -1) or -1,
917
  reference_audio_path=reference_audio_path,
918
  src_audio_path=src_audio_path,
919
+ audio_duration=normalized_audio_duration,
920
  batch_size=_to_int(get("batch_size"), None),
921
+ audio_code_string=str(_get_any("audio_code_string", "audioCodeString", default="") or ""),
922
  repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
923
  repainting_end=_to_float(get("repainting_end"), None),
924
  instruction=str(get("instruction", _DEFAULT_DIT_INSTRUCTION) or ""),
925
+ audio_cover_strength=_to_float(_get_any("audio_cover_strength", "audioCoverStrength"), 1.0) or 1.0,
926
+ task_type=str(_get_any("task_type", "taskType", default="text2music") or "text2music"),
927
  use_adg=_to_bool(get("use_adg"), False),
928
  cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
929
  cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
930
  audio_format=str(get("audio_format", "mp3") or "mp3"),
931
+ use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
932
  lm_model_path=str(get("lm_model_path") or "").strip() or None,
933
  lm_backend=str(get("lm_backend", "vllm") or "vllm"),
934
  lm_temperature=_to_float(get("lm_temperature"), _LM_DEFAULT_TEMPERATURE) or _LM_DEFAULT_TEMPERATURE,
 
946
 
947
  if content_type.startswith("application/json"):
948
  body = await request.json()
949
+ if not isinstance(body, dict):
950
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
951
+ req = _build_req_from_mapping(body, reference_audio_path=None, src_audio_path=None)
952
 
953
  elif content_type.endswith("+json"):
954
  body = await request.json()
955
+ if not isinstance(body, dict):
956
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
957
+ req = _build_req_from_mapping(body, reference_audio_path=None, src_audio_path=None)
958
 
959
  elif content_type.startswith("multipart/form-data"):
960
  form = await request.form()
 
993
  try:
994
  body = json.loads(raw.decode("utf-8"))
995
  if isinstance(body, dict):
996
+ req = _build_req_from_mapping(body, reference_audio_path=None, src_audio_path=None)
997
  else:
998
  raise HTTPException(status_code=400, detail="JSON payload must be an object")
999
  except HTTPException:
acestep/llm_inference.py CHANGED
@@ -245,6 +245,7 @@ class LLMHandler:
245
  metadata_temperature: Optional[float] = 0.85,
246
  codes_temperature: Optional[float] = None,
247
  target_duration: Optional[float] = None,
 
248
  ) -> Tuple[Dict[str, Any], str, str]:
249
  """Generate metadata and audio codes using 5Hz LM with vllm backend
250
 
@@ -288,6 +289,8 @@ class LLMHandler:
288
  self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
289
  self.constrained_processor.update_caption(caption)
290
  self.constrained_processor.set_target_duration(target_duration)
 
 
291
 
292
  constrained_processor = self.constrained_processor
293
  update_state_fn = constrained_processor.update_state
@@ -423,6 +426,7 @@ class LLMHandler:
423
  metadata_temperature: Optional[float] = 0.85,
424
  codes_temperature: Optional[float] = None,
425
  target_duration: Optional[float] = None,
 
426
  ) -> Tuple[Dict[str, Any], str, str]:
427
  """Generate metadata and audio codes using 5Hz LM with PyTorch backend
428
 
@@ -495,6 +499,8 @@ class LLMHandler:
495
  self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
496
  self.constrained_processor.update_caption(caption)
497
  self.constrained_processor.set_target_duration(target_duration)
 
 
498
 
499
  constrained_processor = self.constrained_processor
500
 
@@ -769,6 +775,7 @@ class LLMHandler:
769
  metadata_temperature: Optional[float] = 0.85,
770
  codes_temperature: Optional[float] = None,
771
  target_duration: Optional[float] = None,
 
772
  ) -> Tuple[Dict[str, Any], str, str]:
773
  """Generate metadata and audio codes using 5Hz LM
774
 
@@ -819,6 +826,7 @@ class LLMHandler:
819
  metadata_temperature=metadata_temperature,
820
  codes_temperature=codes_temperature,
821
  target_duration=target_duration,
 
822
  )
823
  else:
824
  return self.generate_with_5hz_lm_pt(
@@ -835,6 +843,7 @@ class LLMHandler:
835
  metadata_temperature=metadata_temperature,
836
  codes_temperature=codes_temperature,
837
  target_duration=target_duration,
 
838
  )
839
 
840
  def generate_with_stop_condition(
@@ -853,6 +862,7 @@ class LLMHandler:
853
  metadata_temperature: Optional[float] = 0.85,
854
  codes_temperature: Optional[float] = None,
855
  target_duration: Optional[float] = None,
 
856
  ) -> Tuple[Dict[str, Any], str, str]:
857
  """Feishu-compatible LM generation.
858
 
@@ -862,6 +872,8 @@ class LLMHandler:
862
  Args:
863
  target_duration: Target duration in seconds for codes generation constraint.
864
  5 codes = 1 second. If specified, blocks EOS until target reached.
 
 
865
  """
866
  infer_type = (infer_type or "").strip().lower()
867
  if infer_type not in {"dit", "llm_dit"}:
@@ -882,6 +894,7 @@ class LLMHandler:
882
  metadata_temperature=metadata_temperature,
883
  codes_temperature=codes_temperature,
884
  target_duration=target_duration,
 
885
  )
886
 
887
  # dit: generate and truncate at reasoning end tag
@@ -895,6 +908,7 @@ class LLMHandler:
895
  "top_k": top_k,
896
  "top_p": top_p,
897
  "repetition_penalty": repetition_penalty,
 
898
  },
899
  use_constrained_decoding=use_constrained_decoding,
900
  constrained_decoding_debug=constrained_decoding_debug,
 
245
  metadata_temperature: Optional[float] = 0.85,
246
  codes_temperature: Optional[float] = None,
247
  target_duration: Optional[float] = None,
248
+ user_metadata: Optional[Dict[str, Optional[str]]] = None,
249
  ) -> Tuple[Dict[str, Any], str, str]:
250
  """Generate metadata and audio codes using 5Hz LM with vllm backend
251
 
 
289
  self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
290
  self.constrained_processor.update_caption(caption)
291
  self.constrained_processor.set_target_duration(target_duration)
292
+ # Always call set_user_metadata to ensure previous settings are cleared if None
293
+ self.constrained_processor.set_user_metadata(user_metadata)
294
 
295
  constrained_processor = self.constrained_processor
296
  update_state_fn = constrained_processor.update_state
 
426
  metadata_temperature: Optional[float] = 0.85,
427
  codes_temperature: Optional[float] = None,
428
  target_duration: Optional[float] = None,
429
+ user_metadata: Optional[Dict[str, Optional[str]]] = None,
430
  ) -> Tuple[Dict[str, Any], str, str]:
431
  """Generate metadata and audio codes using 5Hz LM with PyTorch backend
432
 
 
499
  self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
500
  self.constrained_processor.update_caption(caption)
501
  self.constrained_processor.set_target_duration(target_duration)
502
+ # Always call set_user_metadata to ensure previous settings are cleared if None
503
+ self.constrained_processor.set_user_metadata(user_metadata)
504
 
505
  constrained_processor = self.constrained_processor
506
 
 
775
  metadata_temperature: Optional[float] = 0.85,
776
  codes_temperature: Optional[float] = None,
777
  target_duration: Optional[float] = None,
778
+ user_metadata: Optional[Dict[str, Optional[str]]] = None,
779
  ) -> Tuple[Dict[str, Any], str, str]:
780
  """Generate metadata and audio codes using 5Hz LM
781
 
 
826
  metadata_temperature=metadata_temperature,
827
  codes_temperature=codes_temperature,
828
  target_duration=target_duration,
829
+ user_metadata=user_metadata,
830
  )
831
  else:
832
  return self.generate_with_5hz_lm_pt(
 
843
  metadata_temperature=metadata_temperature,
844
  codes_temperature=codes_temperature,
845
  target_duration=target_duration,
846
+ user_metadata=user_metadata,
847
  )
848
 
849
  def generate_with_stop_condition(
 
862
  metadata_temperature: Optional[float] = 0.85,
863
  codes_temperature: Optional[float] = None,
864
  target_duration: Optional[float] = None,
865
+ user_metadata: Optional[Dict[str, Optional[str]]] = None,
866
  ) -> Tuple[Dict[str, Any], str, str]:
867
  """Feishu-compatible LM generation.
868
 
 
872
  Args:
873
  target_duration: Target duration in seconds for codes generation constraint.
874
  5 codes = 1 second. If specified, blocks EOS until target reached.
875
+ user_metadata: User-provided metadata fields (e.g. bpm/duration/keyscale/timesignature).
876
+ If specified, constrained decoding will inject these values directly.
877
  """
878
  infer_type = (infer_type or "").strip().lower()
879
  if infer_type not in {"dit", "llm_dit"}:
 
894
  metadata_temperature=metadata_temperature,
895
  codes_temperature=codes_temperature,
896
  target_duration=target_duration,
897
+ user_metadata=user_metadata,
898
  )
899
 
900
  # dit: generate and truncate at reasoning end tag
 
908
  "top_k": top_k,
909
  "top_p": top_p,
910
  "repetition_penalty": repetition_penalty,
911
+ "user_metadata": user_metadata,
912
  },
913
  use_constrained_decoding=use_constrained_decoding,
914
  constrained_decoding_debug=constrained_decoding_debug,