Spaces:
Running
on
A100
Running
on
A100
fix: bot return metas
Browse files- acestep/acestep_v15_pipeline.py +15 -4
- acestep/api_server.py +45 -70
- acestep/handler.py +59 -13
acestep/acestep_v15_pipeline.py
CHANGED
|
@@ -9,10 +9,21 @@ import sys
|
|
| 9 |
for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
|
| 10 |
os.environ.pop(proxy_var, None)
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
from .
|
| 15 |
-
from .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def create_demo(init_params=None):
|
|
|
|
| 9 |
for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
|
| 10 |
os.environ.pop(proxy_var, None)
|
| 11 |
|
| 12 |
+
try:
|
| 13 |
+
# When executed as a module: `python -m acestep.acestep_v15_pipeline`
|
| 14 |
+
from .handler import AceStepHandler
|
| 15 |
+
from .llm_inference import LLMHandler
|
| 16 |
+
from .dataset_handler import DatasetHandler
|
| 17 |
+
from .gradio_ui import create_gradio_interface
|
| 18 |
+
except ImportError:
|
| 19 |
+
# When executed as a script: `python acestep/acestep_v15_pipeline.py`
|
| 20 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 21 |
+
if project_root not in sys.path:
|
| 22 |
+
sys.path.insert(0, project_root)
|
| 23 |
+
from acestep.handler import AceStepHandler
|
| 24 |
+
from acestep.llm_inference import LLMHandler
|
| 25 |
+
from acestep.dataset_handler import DatasetHandler
|
| 26 |
+
from acestep.gradio_ui import create_gradio_interface
|
| 27 |
|
| 28 |
|
| 29 |
def create_demo(init_params=None):
|
acestep/api_server.py
CHANGED
|
@@ -51,8 +51,9 @@ class GenerateMusicRequest(BaseModel):
|
|
| 51 |
thinking: bool = False
|
| 52 |
|
| 53 |
bpm: Optional[int] = None
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
vocal_language: str = "en"
|
| 57 |
inference_steps: int = 8
|
| 58 |
guidance_scale: float = 7.0
|
|
@@ -61,7 +62,7 @@ class GenerateMusicRequest(BaseModel):
|
|
| 61 |
|
| 62 |
reference_audio_path: Optional[str] = None
|
| 63 |
src_audio_path: Optional[str] = None
|
| 64 |
-
audio_duration: Optional[float] = None
|
| 65 |
batch_size: Optional[int] = None
|
| 66 |
|
| 67 |
audio_code_string: str = ""
|
|
@@ -93,6 +94,10 @@ class GenerateMusicRequest(BaseModel):
|
|
| 93 |
lm_repetition_penalty: float = 1.0
|
| 94 |
lm_negative_prompt: str = "NO USER INPUT"
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
_LM_DEFAULT_TEMPERATURE = 0.85
|
| 98 |
_LM_DEFAULT_CFG_SCALE = 2.0
|
|
@@ -501,62 +506,6 @@ def create_app() -> FastAPI:
|
|
| 501 |
max_dur = float(os.getenv("ACESTEP_LYRICS_MAX_DURATION_SECONDS", "180"))
|
| 502 |
return float(min(max(est, min_dur), max_dur))
|
| 503 |
|
| 504 |
-
def _extract_lm_fields(meta: Dict[str, Any]) -> Dict[str, Any]:
|
| 505 |
-
def _parse_first_float(v: Any) -> Optional[float]:
|
| 506 |
-
if v is None:
|
| 507 |
-
return None
|
| 508 |
-
if isinstance(v, (int, float)):
|
| 509 |
-
return float(v)
|
| 510 |
-
s = str(v).strip()
|
| 511 |
-
if not s or s.upper() == "N/A":
|
| 512 |
-
return None
|
| 513 |
-
try:
|
| 514 |
-
return float(s)
|
| 515 |
-
except Exception:
|
| 516 |
-
pass
|
| 517 |
-
m = re.search(r"[-+]?\d*\.?\d+", s)
|
| 518 |
-
if not m:
|
| 519 |
-
return None
|
| 520 |
-
try:
|
| 521 |
-
return float(m.group(0))
|
| 522 |
-
except Exception:
|
| 523 |
-
return None
|
| 524 |
-
|
| 525 |
-
def _parse_first_int(v: Any) -> Optional[int]:
|
| 526 |
-
fv = _parse_first_float(v)
|
| 527 |
-
if fv is None:
|
| 528 |
-
return None
|
| 529 |
-
try:
|
| 530 |
-
return int(round(fv))
|
| 531 |
-
except Exception:
|
| 532 |
-
return None
|
| 533 |
-
|
| 534 |
-
def _none_if_na(v: Any) -> Any:
|
| 535 |
-
if v is None:
|
| 536 |
-
return None
|
| 537 |
-
if isinstance(v, str) and v.strip() in {"", "N/A"}:
|
| 538 |
-
return None
|
| 539 |
-
return v
|
| 540 |
-
|
| 541 |
-
out: Dict[str, Any] = {}
|
| 542 |
-
|
| 543 |
-
bpm_raw = _none_if_na(meta.get("bpm"))
|
| 544 |
-
out["bpm"] = _parse_first_int(bpm_raw)
|
| 545 |
-
|
| 546 |
-
dur_raw = _none_if_na(meta.get("duration"))
|
| 547 |
-
out["duration"] = _parse_first_float(dur_raw)
|
| 548 |
-
|
| 549 |
-
genres_raw = _none_if_na(meta.get("genres"))
|
| 550 |
-
out["genres"] = str(genres_raw) if genres_raw is not None else None
|
| 551 |
-
|
| 552 |
-
keyscale_raw = _none_if_na(meta.get("keyscale", meta.get("key_scale")))
|
| 553 |
-
out["keyscale"] = str(keyscale_raw) if keyscale_raw is not None else None
|
| 554 |
-
|
| 555 |
-
ts_raw = _none_if_na(meta.get("timesignature", meta.get("time_signature")))
|
| 556 |
-
out["timesignature"] = str(ts_raw) if ts_raw is not None else None
|
| 557 |
-
|
| 558 |
-
return out
|
| 559 |
-
|
| 560 |
def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
|
| 561 |
"""Ensure a stable `metas` dict (keys always present)."""
|
| 562 |
meta = meta or {}
|
|
@@ -587,7 +536,7 @@ def create_app() -> FastAPI:
|
|
| 587 |
# We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
|
| 588 |
audio_cover_strength_val = float(req.audio_cover_strength)
|
| 589 |
|
| 590 |
-
|
| 591 |
|
| 592 |
# Determine effective batch size (used for per-sample LM code diversity)
|
| 593 |
effective_batch_size = req.batch_size
|
|
@@ -656,6 +605,7 @@ def create_app() -> FastAPI:
|
|
| 656 |
)
|
| 657 |
|
| 658 |
meta, codes, status = _lm_call()
|
|
|
|
| 659 |
|
| 660 |
if need_lm_codes:
|
| 661 |
if not codes:
|
|
@@ -668,12 +618,6 @@ def create_app() -> FastAPI:
|
|
| 668 |
else:
|
| 669 |
audio_code_string = codes
|
| 670 |
|
| 671 |
-
# Always expose LM metas when we invoked LM (even if user already set some fields).
|
| 672 |
-
lm_fields = {
|
| 673 |
-
"metas": _normalize_metas(meta),
|
| 674 |
-
**_extract_lm_fields(meta),
|
| 675 |
-
}
|
| 676 |
-
|
| 677 |
# Fill only missing fields (user-provided values win)
|
| 678 |
bpm_val, key_scale_val, time_sig_val, audio_duration_val = _maybe_fill_from_metadata(req, meta)
|
| 679 |
|
|
@@ -711,6 +655,25 @@ def create_app() -> FastAPI:
|
|
| 711 |
# thinking=True requires codes generation.
|
| 712 |
raise RuntimeError("thinking=true requires non-empty audio codes (LM generation failed).")
|
| 713 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
|
| 715 |
captions=req.caption,
|
| 716 |
lyrics=req.lyrics,
|
|
@@ -746,7 +709,12 @@ def create_app() -> FastAPI:
|
|
| 746 |
"generation_info": gen_info,
|
| 747 |
"status_message": status_msg,
|
| 748 |
"seed_value": seed_value,
|
| 749 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
}
|
| 751 |
|
| 752 |
t0 = time.time()
|
|
@@ -815,13 +783,20 @@ def create_app() -> FastAPI:
|
|
| 815 |
if not callable(get):
|
| 816 |
raise HTTPException(status_code=400, detail="Invalid request payload")
|
| 817 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 818 |
return GenerateMusicRequest(
|
| 819 |
caption=str(get("caption", "") or ""),
|
| 820 |
lyrics=str(get("lyrics", "") or ""),
|
| 821 |
thinking=_to_bool(get("thinking"), False),
|
| 822 |
bpm=_to_int(get("bpm"), None),
|
| 823 |
-
key_scale=str(
|
| 824 |
-
time_signature=str(
|
| 825 |
vocal_language=str(get("vocal_language", "en") or "en"),
|
| 826 |
inference_steps=_to_int(get("inference_steps"), 8) or 8,
|
| 827 |
guidance_scale=_to_float(get("guidance_scale"), 7.0) or 7.0,
|
|
@@ -829,7 +804,7 @@ def create_app() -> FastAPI:
|
|
| 829 |
seed=_to_int(get("seed"), -1) or -1,
|
| 830 |
reference_audio_path=reference_audio_path,
|
| 831 |
src_audio_path=src_audio_path,
|
| 832 |
-
audio_duration=_to_float(
|
| 833 |
batch_size=_to_int(get("batch_size"), None),
|
| 834 |
audio_code_string=str(get("audio_code_string", "") or ""),
|
| 835 |
repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
|
|
|
|
| 51 |
thinking: bool = False
|
| 52 |
|
| 53 |
bpm: Optional[int] = None
|
| 54 |
+
# Accept common client keys while keeping internal field names stable.
|
| 55 |
+
key_scale: str = Field(default="", alias="keyscale")
|
| 56 |
+
time_signature: str = Field(default="", alias="timesignature")
|
| 57 |
vocal_language: str = "en"
|
| 58 |
inference_steps: int = 8
|
| 59 |
guidance_scale: float = 7.0
|
|
|
|
| 62 |
|
| 63 |
reference_audio_path: Optional[str] = None
|
| 64 |
src_audio_path: Optional[str] = None
|
| 65 |
+
audio_duration: Optional[float] = Field(default=None, alias="duration")
|
| 66 |
batch_size: Optional[int] = None
|
| 67 |
|
| 68 |
audio_code_string: str = ""
|
|
|
|
| 94 |
lm_repetition_penalty: float = 1.0
|
| 95 |
lm_negative_prompt: str = "NO USER INPUT"
|
| 96 |
|
| 97 |
+
class Config:
|
| 98 |
+
allow_population_by_field_name = True
|
| 99 |
+
allow_population_by_alias = True
|
| 100 |
+
|
| 101 |
|
| 102 |
_LM_DEFAULT_TEMPERATURE = 0.85
|
| 103 |
_LM_DEFAULT_CFG_SCALE = 2.0
|
|
|
|
| 506 |
max_dur = float(os.getenv("ACESTEP_LYRICS_MAX_DURATION_SECONDS", "180"))
|
| 507 |
return float(min(max(est, min_dur), max_dur))
|
| 508 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
|
| 510 |
"""Ensure a stable `metas` dict (keys always present)."""
|
| 511 |
meta = meta or {}
|
|
|
|
| 536 |
# We keep backward compatibility: only auto-adjust when user didn't override (still at default 1.0).
|
| 537 |
audio_cover_strength_val = float(req.audio_cover_strength)
|
| 538 |
|
| 539 |
+
lm_meta: Optional[Dict[str, Any]] = None
|
| 540 |
|
| 541 |
# Determine effective batch size (used for per-sample LM code diversity)
|
| 542 |
effective_batch_size = req.batch_size
|
|
|
|
| 605 |
)
|
| 606 |
|
| 607 |
meta, codes, status = _lm_call()
|
| 608 |
+
lm_meta = meta
|
| 609 |
|
| 610 |
if need_lm_codes:
|
| 611 |
if not codes:
|
|
|
|
| 618 |
else:
|
| 619 |
audio_code_string = codes
|
| 620 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
# Fill only missing fields (user-provided values win)
|
| 622 |
bpm_val, key_scale_val, time_sig_val, audio_duration_val = _maybe_fill_from_metadata(req, meta)
|
| 623 |
|
|
|
|
| 655 |
# thinking=True requires codes generation.
|
| 656 |
raise RuntimeError("thinking=true requires non-empty audio codes (LM generation failed).")
|
| 657 |
|
| 658 |
+
# Response metas MUST reflect the actual values used by DiT.
|
| 659 |
+
metas_out = _normalize_metas(lm_meta or {})
|
| 660 |
+
if bpm_val is not None and int(bpm_val) > 0:
|
| 661 |
+
metas_out["bpm"] = int(bpm_val)
|
| 662 |
+
if audio_duration_val is not None and float(audio_duration_val) > 0:
|
| 663 |
+
metas_out["duration"] = float(audio_duration_val)
|
| 664 |
+
if (key_scale_val or "").strip():
|
| 665 |
+
metas_out["keyscale"] = str(key_scale_val)
|
| 666 |
+
if (time_sig_val or "").strip():
|
| 667 |
+
metas_out["timesignature"] = str(time_sig_val)
|
| 668 |
+
|
| 669 |
+
def _none_if_na_str(v: Any) -> Optional[str]:
|
| 670 |
+
if v is None:
|
| 671 |
+
return None
|
| 672 |
+
s = str(v).strip()
|
| 673 |
+
if s in {"", "N/A"}:
|
| 674 |
+
return None
|
| 675 |
+
return s
|
| 676 |
+
|
| 677 |
first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
|
| 678 |
captions=req.caption,
|
| 679 |
lyrics=req.lyrics,
|
|
|
|
| 709 |
"generation_info": gen_info,
|
| 710 |
"status_message": status_msg,
|
| 711 |
"seed_value": seed_value,
|
| 712 |
+
"metas": metas_out,
|
| 713 |
+
"bpm": int(bpm_val) if bpm_val is not None else None,
|
| 714 |
+
"duration": float(audio_duration_val) if audio_duration_val is not None else None,
|
| 715 |
+
"genres": _none_if_na_str(metas_out.get("genres")),
|
| 716 |
+
"keyscale": _none_if_na_str(metas_out.get("keyscale")),
|
| 717 |
+
"timesignature": _none_if_na_str(metas_out.get("timesignature")),
|
| 718 |
}
|
| 719 |
|
| 720 |
t0 = time.time()
|
|
|
|
| 783 |
if not callable(get):
|
| 784 |
raise HTTPException(status_code=400, detail="Invalid request payload")
|
| 785 |
|
| 786 |
+
def _get_any(*keys: str, default: Any = None) -> Any:
|
| 787 |
+
for k in keys:
|
| 788 |
+
v = get(k, None)
|
| 789 |
+
if v is not None:
|
| 790 |
+
return v
|
| 791 |
+
return default
|
| 792 |
+
|
| 793 |
return GenerateMusicRequest(
|
| 794 |
caption=str(get("caption", "") or ""),
|
| 795 |
lyrics=str(get("lyrics", "") or ""),
|
| 796 |
thinking=_to_bool(get("thinking"), False),
|
| 797 |
bpm=_to_int(get("bpm"), None),
|
| 798 |
+
key_scale=str(_get_any("key_scale", "keyscale", default="") or ""),
|
| 799 |
+
time_signature=str(_get_any("time_signature", "timesignature", default="") or ""),
|
| 800 |
vocal_language=str(get("vocal_language", "en") or "en"),
|
| 801 |
inference_steps=_to_int(get("inference_steps"), 8) or 8,
|
| 802 |
guidance_scale=_to_float(get("guidance_scale"), 7.0) or 7.0,
|
|
|
|
| 804 |
seed=_to_int(get("seed"), -1) or -1,
|
| 805 |
reference_audio_path=reference_audio_path,
|
| 806 |
src_audio_path=src_audio_path,
|
| 807 |
+
audio_duration=_to_float(_get_any("audio_duration", "duration"), None),
|
| 808 |
batch_size=_to_int(get("batch_size"), None),
|
| 809 |
audio_code_string=str(get("audio_code_string", "") or ""),
|
| 810 |
repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
|
acestep/handler.py
CHANGED
|
@@ -932,7 +932,14 @@ class AceStepHandler:
|
|
| 932 |
is_repaint_task = (task_type == "repaint")
|
| 933 |
is_lego_task = (task_type == "lego")
|
| 934 |
is_cover_task = (task_type == "cover")
|
| 935 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 936 |
is_cover_task = True
|
| 937 |
# Both repaint and lego tasks can use repainting parameters for chunk mask
|
| 938 |
can_use_repainting = is_repaint_task or is_lego_task
|
|
@@ -1371,10 +1378,16 @@ class AceStepHandler:
|
|
| 1371 |
# Pad or crop to match max_latent_length
|
| 1372 |
if hints.shape[1] < max_latent_length:
|
| 1373 |
pad_length = max_latent_length - hints.shape[1]
|
| 1374 |
-
|
| 1375 |
-
|
| 1376 |
-
|
| 1377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1378 |
elif hints.shape[1] > max_latent_length:
|
| 1379 |
hints = hints[:, :max_latent_length, :]
|
| 1380 |
precomputed_lm_hints_25Hz_list.append(hints[0]) # Remove batch dimension
|
|
@@ -1553,19 +1566,45 @@ class AceStepHandler:
|
|
| 1553 |
def infer_refer_latent(self, refer_audioss):
|
| 1554 |
refer_audio_order_mask = []
|
| 1555 |
refer_audio_latents = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1556 |
for batch_idx, refer_audios in enumerate(refer_audioss):
|
| 1557 |
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
|
| 1558 |
-
refer_audio_latent = self.silence_latent[:, :750, :]
|
| 1559 |
refer_audio_latents.append(refer_audio_latent)
|
| 1560 |
refer_audio_order_mask.append(batch_idx)
|
| 1561 |
else:
|
| 1562 |
for refer_audio in refer_audios:
|
|
|
|
| 1563 |
# Ensure input is in VAE's dtype
|
| 1564 |
vae_input = refer_audio.unsqueeze(0).to(self.vae.dtype)
|
| 1565 |
refer_audio_latent = self.vae.encode(vae_input).latent_dist.sample()
|
| 1566 |
# Cast back to model dtype
|
| 1567 |
refer_audio_latent = refer_audio_latent.to(self.dtype)
|
| 1568 |
-
refer_audio_latents.append(refer_audio_latent.transpose(1, 2))
|
| 1569 |
refer_audio_order_mask.append(batch_idx)
|
| 1570 |
|
| 1571 |
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
|
|
@@ -1949,7 +1988,7 @@ class AceStepHandler:
|
|
| 1949 |
audio_duration: Optional[float] = None,
|
| 1950 |
batch_size: Optional[int] = None,
|
| 1951 |
src_audio=None,
|
| 1952 |
-
audio_code_string: str = "",
|
| 1953 |
repainting_start: float = 0.0,
|
| 1954 |
repainting_end: Optional[float] = None,
|
| 1955 |
instruction: str = "Fill the audio semantic mask based on the given conditions:",
|
|
@@ -1978,11 +2017,16 @@ class AceStepHandler:
|
|
| 1978 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 1979 |
return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
|
| 1980 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1981 |
# Auto-detect task type based on audio_code_string
|
| 1982 |
# If audio_code_string is provided and not empty, use cover task
|
| 1983 |
# Otherwise, use text2music task (or keep current task_type if not text2music)
|
| 1984 |
if task_type == "text2music":
|
| 1985 |
-
if
|
| 1986 |
# User has provided audio codes, switch to cover task
|
| 1987 |
task_type = "cover"
|
| 1988 |
# Update instruction for cover task
|
|
@@ -2031,7 +2075,7 @@ class AceStepHandler:
|
|
| 2031 |
processed_src_audio = None
|
| 2032 |
if src_audio is not None:
|
| 2033 |
# Check if audio codes are provided - if so, ignore src_audio
|
| 2034 |
-
if
|
| 2035 |
logger.info("[generate_music] Audio codes provided, ignoring src_audio and using codes instead")
|
| 2036 |
else:
|
| 2037 |
logger.info("[generate_music] Processing source audio...")
|
|
@@ -2070,9 +2114,11 @@ class AceStepHandler:
|
|
| 2070 |
# Prepare audio_code_hints - use if audio_code_string is provided
|
| 2071 |
# This works for both text2music (auto-switched to cover) and cover tasks
|
| 2072 |
audio_code_hints_batch = None
|
| 2073 |
-
if
|
| 2074 |
-
|
| 2075 |
-
|
|
|
|
|
|
|
| 2076 |
|
| 2077 |
should_return_intermediate = (task_type == "text2music")
|
| 2078 |
outputs = self.service_generate(
|
|
|
|
| 932 |
is_repaint_task = (task_type == "repaint")
|
| 933 |
is_lego_task = (task_type == "lego")
|
| 934 |
is_cover_task = (task_type == "cover")
|
| 935 |
+
|
| 936 |
+
has_codes = False
|
| 937 |
+
if isinstance(audio_code_string, list):
|
| 938 |
+
has_codes = any((c or "").strip() for c in audio_code_string)
|
| 939 |
+
else:
|
| 940 |
+
has_codes = bool(audio_code_string and str(audio_code_string).strip())
|
| 941 |
+
|
| 942 |
+
if has_codes:
|
| 943 |
is_cover_task = True
|
| 944 |
# Both repaint and lego tasks can use repainting parameters for chunk mask
|
| 945 |
can_use_repainting = is_repaint_task or is_lego_task
|
|
|
|
| 1378 |
# Pad or crop to match max_latent_length
|
| 1379 |
if hints.shape[1] < max_latent_length:
|
| 1380 |
pad_length = max_latent_length - hints.shape[1]
|
| 1381 |
+
pad = self.silence_latent
|
| 1382 |
+
# Match dims: hints is usually [1, T, D], silence_latent is [1, T, D]
|
| 1383 |
+
if pad.dim() == 2:
|
| 1384 |
+
pad = pad.unsqueeze(0)
|
| 1385 |
+
if hints.dim() == 2:
|
| 1386 |
+
hints = hints.unsqueeze(0)
|
| 1387 |
+
pad_chunk = pad[:, :pad_length, :]
|
| 1388 |
+
if pad_chunk.device != hints.device or pad_chunk.dtype != hints.dtype:
|
| 1389 |
+
pad_chunk = pad_chunk.to(device=hints.device, dtype=hints.dtype)
|
| 1390 |
+
hints = torch.cat([hints, pad_chunk], dim=1)
|
| 1391 |
elif hints.shape[1] > max_latent_length:
|
| 1392 |
hints = hints[:, :max_latent_length, :]
|
| 1393 |
precomputed_lm_hints_25Hz_list.append(hints[0]) # Remove batch dimension
|
|
|
|
| 1566 |
def infer_refer_latent(self, refer_audioss):
|
| 1567 |
refer_audio_order_mask = []
|
| 1568 |
refer_audio_latents = []
|
| 1569 |
+
|
| 1570 |
+
def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
|
| 1571 |
+
"""Normalize audio tensor to [2, T] on current device."""
|
| 1572 |
+
if not isinstance(a, torch.Tensor):
|
| 1573 |
+
raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
|
| 1574 |
+
# Accept [T], [1, T], [2, T], [1, 2, T]
|
| 1575 |
+
if a.dim() == 3 and a.shape[0] == 1:
|
| 1576 |
+
a = a.squeeze(0)
|
| 1577 |
+
if a.dim() == 1:
|
| 1578 |
+
a = a.unsqueeze(0)
|
| 1579 |
+
if a.dim() != 2:
|
| 1580 |
+
raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
|
| 1581 |
+
if a.shape[0] == 1:
|
| 1582 |
+
a = torch.cat([a, a], dim=0)
|
| 1583 |
+
a = a[:2]
|
| 1584 |
+
return a
|
| 1585 |
+
|
| 1586 |
+
def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
|
| 1587 |
+
"""Ensure latent is [N, T, D] (3D) for packing."""
|
| 1588 |
+
if z.dim() == 4 and z.shape[0] == 1:
|
| 1589 |
+
z = z.squeeze(0)
|
| 1590 |
+
if z.dim() == 2:
|
| 1591 |
+
z = z.unsqueeze(0)
|
| 1592 |
+
return z
|
| 1593 |
+
|
| 1594 |
for batch_idx, refer_audios in enumerate(refer_audioss):
|
| 1595 |
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
|
| 1596 |
+
refer_audio_latent = _ensure_latent_3d(self.silence_latent[:, :750, :])
|
| 1597 |
refer_audio_latents.append(refer_audio_latent)
|
| 1598 |
refer_audio_order_mask.append(batch_idx)
|
| 1599 |
else:
|
| 1600 |
for refer_audio in refer_audios:
|
| 1601 |
+
refer_audio = _normalize_audio_2d(refer_audio)
|
| 1602 |
# Ensure input is in VAE's dtype
|
| 1603 |
vae_input = refer_audio.unsqueeze(0).to(self.vae.dtype)
|
| 1604 |
refer_audio_latent = self.vae.encode(vae_input).latent_dist.sample()
|
| 1605 |
# Cast back to model dtype
|
| 1606 |
refer_audio_latent = refer_audio_latent.to(self.dtype)
|
| 1607 |
+
refer_audio_latents.append(_ensure_latent_3d(refer_audio_latent.transpose(1, 2)))
|
| 1608 |
refer_audio_order_mask.append(batch_idx)
|
| 1609 |
|
| 1610 |
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
|
|
|
|
| 1988 |
audio_duration: Optional[float] = None,
|
| 1989 |
batch_size: Optional[int] = None,
|
| 1990 |
src_audio=None,
|
| 1991 |
+
audio_code_string: Union[str, List[str]] = "",
|
| 1992 |
repainting_start: float = 0.0,
|
| 1993 |
repainting_end: Optional[float] = None,
|
| 1994 |
instruction: str = "Fill the audio semantic mask based on the given conditions:",
|
|
|
|
| 2017 |
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
|
| 2018 |
return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
|
| 2019 |
|
| 2020 |
+
def _has_audio_codes(v: Union[str, List[str]]) -> bool:
|
| 2021 |
+
if isinstance(v, list):
|
| 2022 |
+
return any((x or "").strip() for x in v)
|
| 2023 |
+
return bool(v and str(v).strip())
|
| 2024 |
+
|
| 2025 |
# Auto-detect task type based on audio_code_string
|
| 2026 |
# If audio_code_string is provided and not empty, use cover task
|
| 2027 |
# Otherwise, use text2music task (or keep current task_type if not text2music)
|
| 2028 |
if task_type == "text2music":
|
| 2029 |
+
if _has_audio_codes(audio_code_string):
|
| 2030 |
# User has provided audio codes, switch to cover task
|
| 2031 |
task_type = "cover"
|
| 2032 |
# Update instruction for cover task
|
|
|
|
| 2075 |
processed_src_audio = None
|
| 2076 |
if src_audio is not None:
|
| 2077 |
# Check if audio codes are provided - if so, ignore src_audio
|
| 2078 |
+
if _has_audio_codes(audio_code_string):
|
| 2079 |
logger.info("[generate_music] Audio codes provided, ignoring src_audio and using codes instead")
|
| 2080 |
else:
|
| 2081 |
logger.info("[generate_music] Processing source audio...")
|
|
|
|
| 2114 |
# Prepare audio_code_hints - use if audio_code_string is provided
|
| 2115 |
# This works for both text2music (auto-switched to cover) and cover tasks
|
| 2116 |
audio_code_hints_batch = None
|
| 2117 |
+
if _has_audio_codes(audio_code_string):
|
| 2118 |
+
if isinstance(audio_code_string, list):
|
| 2119 |
+
audio_code_hints_batch = audio_code_string
|
| 2120 |
+
else:
|
| 2121 |
+
audio_code_hints_batch = [audio_code_string] * actual_batch_size
|
| 2122 |
|
| 2123 |
should_return_intermediate = (task_type == "text2music")
|
| 2124 |
outputs = self.service_generate(
|