|
|
""" |
|
|
ACE-Step Inference API Module |
|
|
|
|
|
This module provides a standardized inference interface for music generation, |
|
|
designed for third-party integration. It offers both a simplified API and |
|
|
backward-compatible Gradio UI support. |
|
|
""" |
|
|
|
|
|
import math |
|
|
import os |
|
|
import tempfile |
|
|
from typing import Optional, Union, List, Dict, Any, Tuple |
|
|
from dataclasses import dataclass, field, asdict |
|
|
from loguru import logger |
|
|
|
|
|
from acestep.audio_utils import AudioSaver, generate_uuid_from_params |
|
|
|
|
|
|
|
|
@dataclass
class GenerationParams:
    """Configuration for music generation parameters.

    Field order is part of the public interface (positional construction and
    ``to_dict()`` key order), so new fields must only be appended.

    Attributes:
        # Task-Specific Parameters
        task_type: Type of generation task. One of: "text2music", "cover", "repaint",
            "lego", "extract", "complete". generate_music never runs the LM for
            "cover" and "repaint".
        instruction: Optional task instruction prompt. If empty, auto-generated by system.
        reference_audio: Path to a reference audio file for style transfer or cover tasks.
        src_audio: Path to a source audio file for audio-to-audio tasks.
        audio_codes: Audio semantic codes as a string (advanced use, for code-control
            generation). When non-empty, generate_music does not ask the LM to
            generate audio codes.

        # Text Inputs
        caption: A short text prompt describing the desired music (main prompt). < 512 characters
        lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
        instrumental: If True, generate instrumental music regardless of lyrics.

        # Music Metadata
        vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown".
            See acestep/constants.py:VALID_LANGUAGES.
        bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
        keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection.
            A-G, #/♭, major/minor
        timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8').
            Leave empty for auto-detection.
        duration: Target audio length in seconds. If <0 or None, the model chooses
            automatically. 10 ~ 600

        # Generation Parameters
        inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
        seed: Integer seed for reproducibility. -1 means use a random seed each time.
        guidance_scale: CFG (classifier-free guidance) strength. Higher values follow the
            prompt more strictly. Only supported by non-turbo models.

        # Advanced DiT Parameters
        use_adg: Whether to use Adaptive Dual Guidance (only works for the base model).
        cfg_interval_start: Start ratio (0.0–1.0) at which to apply CFG.
        cfg_interval_end: End ratio (0.0–1.0) at which to apply CFG.
        shift: Timestep shift factor (default 1.0). When != 1.0, applies
            t = shift * t / (1 + (shift - 1) * t) to timesteps.
        infer_method: Sampling method name forwarded to the DiT handler (default "ode").
        timesteps: Optional list of timesteps forwarded to the DiT handler
            (presumably a custom schedule overriding the default one — confirm in handler).

        # Repaint / Cover Parameters
        repainting_start: For repaint/lego tasks: start time in seconds of the region to repaint.
        repainting_end: For repaint/lego tasks: end time in seconds of the region to repaint
            (-1 for until end).
        audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0).
            Set smaller (e.g. 0.2) for style transfer tasks.

        # 5Hz Language Model Parameters for CoT reasoning
        thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for
            semantic/music metadata and codes. When no audio_codes are given, the LM
            also generates the audio codes used by the DiT.
        lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
        lm_cfg_scale: Classifier-free guidance scale for the LLM.
        lm_top_k: LLM top-k sampling (0 = disabled).
        lm_top_p: LLM top-p nucleus sampling (>= 1.0 = disabled).
        lm_negative_prompt: Negative prompt to use for the LLM (for control).
        use_cot_metas: Whether to let the LLM generate music metadata via CoT reasoning.
        use_cot_caption: Whether to let the LLM rewrite or format the input caption via CoT
            reasoning; the rewritten caption is then fed to the DiT.
        use_cot_lyrics: Whether to let the LLM generate lyrics via CoT. Not read by
            generate_music in this module.
        use_cot_language: Whether to let the LLM detect the vocal language via CoT; the
            detected language is then fed to the DiT.
        use_constrained_decoding: Forwarded to the LM as ``use_constrained_decoding``.

        # CoT outputs
        cot_bpm, cot_keyscale, cot_timesignature, cot_duration, cot_vocal_language,
        cot_caption, cot_lyrics: Output fields. generate_music may overwrite them with
            the LM-derived value when the corresponding user field was not provided.
    """

    # Task selection and conditioning sources
    task_type: str = "text2music"
    instruction: str = "Fill the audio semantic mask based on the given conditions:"

    reference_audio: Optional[str] = None
    src_audio: Optional[str] = None

    audio_codes: str = ""

    # Text inputs
    caption: str = ""
    lyrics: str = ""
    instrumental: bool = False

    # Music metadata (empty / None / negative duration means "let the model decide")
    vocal_language: str = "unknown"
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: float = -1.0

    # DiT sampling parameters
    inference_steps: int = 8
    seed: int = -1
    guidance_scale: float = 7.0
    use_adg: bool = False
    cfg_interval_start: float = 0.0
    cfg_interval_end: float = 1.0
    shift: float = 1.0
    infer_method: str = "ode"

    timesteps: Optional[List[float]] = None

    # Repaint / cover controls
    repainting_start: float = 0.0
    repainting_end: float = -1
    audio_cover_strength: float = 1.0

    # 5Hz LM / CoT controls
    thinking: bool = True
    lm_temperature: float = 0.85
    lm_cfg_scale: float = 2.0
    lm_top_k: int = 0
    lm_top_p: float = 0.9
    lm_negative_prompt: str = "NO USER INPUT"
    use_cot_metas: bool = True
    use_cot_caption: bool = True
    use_cot_lyrics: bool = False
    use_cot_language: bool = True
    use_constrained_decoding: bool = True

    # CoT outputs (populated by generate_music, not by the caller)
    cot_bpm: Optional[int] = None
    cot_keyscale: str = ""
    cot_timesignature: str = ""
    cot_duration: Optional[float] = None
    cot_vocal_language: str = "unknown"
    cot_caption: str = ""
    cot_lyrics: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
|
|
|
|
@dataclass
class GenerationConfig:
    """Configuration for music generation.

    Attributes:
        batch_size: Number of audio samples to generate. generate_music treats None as 1.
        allow_lm_batch: Whether to allow batch processing in LM.
            NOTE(review): not read by generate_music in this module — confirm it is
            consumed elsewhere.
        use_random_seed: Whether to use random seeds.
        seeds: Seed(s) for batch generation. Can be:
            - None: random seeds when use_random_seed=True; intended to fall back to
              params.seed when use_random_seed=False (verify generate_music honours this)
            - List[int]: list of seeds, padded with random seeds if fewer than batch_size
            - int: single seed value (converted to a list and padded)
        lm_batch_chunk_size: Maximum number of samples per LM call; values <= 0 mean the
            whole batch is sent in one call.
        constrained_decoding_debug: Whether to enable constrained decoding debug output.
        audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
            (an empty value also falls back to "flac").
    """

    batch_size: int = 2
    allow_lm_batch: bool = False
    use_random_seed: bool = True
    # Annotation widened to match the documented contract: a bare int is accepted.
    seeds: Optional[Union[int, List[int]]] = None
    lm_batch_chunk_size: int = 8
    constrained_decoding_debug: bool = False
    audio_format: str = "flac"

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
|
|
|
|
@dataclass
class GenerationResult:
    """Result of music generation.

    Attributes:
        # Audio Outputs
        audios: List of audio dictionaries with paths, keys, params. As built by
            generate_music, each dict has the keys "path" (saved file path, or ""
            when the audio was not saved), "tensor", "key", "sample_rate" and "params".
        status_message: Status message from generation.
        extra_outputs: Extra outputs from generation. On success, generate_music
            copies the DiT handler's extra outputs and adds "lm_metadata" and
            "time_costs".
        success: Whether generation completed successfully.
        error: Error message if generation failed.
    """

    # Audio Outputs
    # NOTE(review): asdict() in to_dict deep-copies nested values, including each
    # audio's "tensor" entry; confirm callers strip tensors before JSON encoding.
    audios: List[Dict[str, Any]] = field(default_factory=list)

    # Human-readable status; LM chunk timing lines are prepended when the LM ran.
    status_message: str = ""
    extra_outputs: Dict[str, Any] = field(default_factory=dict)

    # Outcome flags; generate_music sets success=False and error on any failure.
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
|
|
|
|
@dataclass
class UnderstandResult:
    """Result of music understanding from audio codes.

    understand_music maps "N/A" values for keyscale, language and timesignature
    to "" and leaves bpm/duration as None when they are missing or unparsable.

    Attributes:
        # Metadata Fields
        caption: Generated caption describing the music.
        lyrics: Generated or extracted lyrics.
        bpm: Beats per minute (None if not detected).
        duration: Duration in seconds (None if not detected).
        keyscale: Musical key (e.g., "C Major").
        language: Vocal language code (e.g., "en", "zh").
        timesignature: Time signature (e.g., "4/4").
            NOTE(review): GenerationParams documents "4" for 4/4; confirm which form the LM emits.

        # Status
        status_message: Status message from understanding.
        success: Whether understanding completed successfully.
        error: Error message if understanding failed.
    """

    # Metadata parsed from the LM output
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    # Status
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
|
|
|
|
def _update_metadata_from_lm( |
|
|
metadata: Dict[str, Any], |
|
|
bpm: Optional[int], |
|
|
key_scale: str, |
|
|
time_signature: str, |
|
|
audio_duration: Optional[float], |
|
|
vocal_language: str, |
|
|
caption: str, |
|
|
lyrics: str, |
|
|
) -> Tuple[Optional[int], str, str, Optional[float]]: |
|
|
"""Update metadata fields from LM output if not provided by user.""" |
|
|
|
|
|
if bpm is None and metadata.get('bpm'): |
|
|
bpm_value = metadata.get('bpm') |
|
|
if bpm_value not in ["N/A", ""]: |
|
|
try: |
|
|
bpm = int(bpm_value) |
|
|
except (ValueError, TypeError): |
|
|
pass |
|
|
|
|
|
if not key_scale and metadata.get('keyscale'): |
|
|
key_scale_value = metadata.get('keyscale', metadata.get('key_scale', "")) |
|
|
if key_scale_value != "N/A": |
|
|
key_scale = key_scale_value |
|
|
|
|
|
if not time_signature and metadata.get('timesignature'): |
|
|
time_signature_value = metadata.get('timesignature', metadata.get('time_signature', "")) |
|
|
if time_signature_value != "N/A": |
|
|
time_signature = time_signature_value |
|
|
|
|
|
if audio_duration is None: |
|
|
audio_duration_value = metadata.get('duration', -1) |
|
|
if audio_duration_value not in ["N/A", ""]: |
|
|
try: |
|
|
audio_duration = float(audio_duration_value) |
|
|
except (ValueError, TypeError): |
|
|
pass |
|
|
|
|
|
if not vocal_language and metadata.get('vocal_language'): |
|
|
vocal_language = metadata.get('vocal_language') |
|
|
if not caption and metadata.get('caption'): |
|
|
caption = metadata.get('caption') |
|
|
if not lyrics and metadata.get('lyrics'): |
|
|
lyrics = metadata.get('lyrics') |
|
|
return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics |
|
|
|
|
|
|
|
|
def _resolve_seed_string(config, params) -> str:
    """Convert the configured seeds into the comma-separated string used by the DiT handler.

    ``config.seeds`` may be an int, a list/tuple of ints, an already-formatted string,
    or None. When it is None and ``config.use_random_seed`` is False, a non-negative
    ``params.seed`` is used (as documented on GenerationConfig). An empty string
    leaves seed selection to ``dit_handler.prepare_seeds``.
    """
    seeds = config.seeds
    if seeds is None:
        fallback_seed = params.seed
        if not config.use_random_seed and isinstance(fallback_seed, int) and fallback_seed >= 0:
            return str(fallback_seed)
        return ""
    if isinstance(seeds, str):
        return seeds.strip()
    if isinstance(seeds, int):
        # A single int is documented as valid; len() on it used to raise TypeError.
        return str(seeds)
    return ",".join(str(s) for s in seeds)


def _collect_lm_user_metadata(bpm, key_scale, time_signature, audio_duration) -> Optional[Dict[str, Any]]:
    """Build the user-specified metadata constraints for the LM, or None if there are none.

    Non-positive / unparsable bpm and duration, and empty or "n/a" key/time
    signature values, are omitted so the LM is free to choose them.
    """
    user_metadata: Dict[str, Any] = {}

    if bpm is not None:
        try:
            bpm_value = float(bpm)
            if bpm_value > 0:
                user_metadata['bpm'] = int(bpm_value)
        except (ValueError, TypeError):
            pass

    for meta_key, raw_value in (('keyscale', key_scale), ('timesignature', time_signature)):
        cleaned = raw_value.strip() if raw_value else ""
        if cleaned and cleaned.lower() != "n/a":
            user_metadata[meta_key] = cleaned

    if audio_duration is not None:
        try:
            duration_value = float(audio_duration)
            if duration_value > 0:
                user_metadata['duration'] = int(duration_value)
        except (ValueError, TypeError):
            pass

    return user_metadata or None


def _build_unified_time_costs(use_lm, lm_time_costs, dit_time_costs) -> Dict[str, float]:
    """Merge LM and DiT timings into one dict with "lm_"/"dit_" prefixes and a pipeline total."""
    unified: Dict[str, float] = {}
    if use_lm and lm_time_costs:
        unified.update({f"lm_{key}": value for key, value in lm_time_costs.items()})
    if dit_time_costs:
        unified.update({f"dit_{key}": value for key, value in dit_time_costs.items()})
    if unified:
        unified["pipeline_total_time"] = (
            unified.get("lm_total_time", 0.0) + unified.get("dit_total_time_cost", 0.0)
        )
    return unified


def generate_music(
    dit_handler,
    llm_handler,
    params: GenerationParams,
    config: GenerationConfig,
    save_dir: Optional[str] = None,
    progress=None,
) -> GenerationResult:
    """Generate music using ACE-Step model with optional LM reasoning.

    Pipeline: resolve seeds -> optionally run the 5Hz LM in chunks (metadata CoT
    and/or audio codes) -> run the DiT -> optionally save audio files -> assemble
    a GenerationResult.

    Args:
        dit_handler: Initialized DiT model handler (AceStepHandler instance)
        llm_handler: Initialized LLM handler (LLMHandler instance); may be None, in
            which case the LM step is skipped.
        params: Generation parameters (GenerationParams instance). When the LM
            produces metadata, the ``cot_*`` fields of this object are updated in place.
        config: Generation configuration (GenerationConfig instance)
        save_dir: Directory to write audio files into (created if missing). When
            None, nothing is written and each audio's "path" is "".
        progress: Optional progress object forwarded unchanged to both handlers.

    Returns:
        GenerationResult with generated audio files and metadata. Errors are not
        raised: LM/DiT failures and exceptions yield success=False with ``error`` set.
    """
    try:
        audio_code_string_to_use = params.audio_codes
        lm_generated_metadata = None
        lm_generated_audio_codes_list = []
        lm_total_time_costs = {
            "phase1_time": 0.0,
            "phase2_time": 0.0,
            "total_time": 0.0,
        }

        # Working copies of the user metadata; may be filled from LM output below.
        bpm = params.bpm
        key_scale = params.keyscale
        time_signature = params.timesignature
        audio_duration = params.duration
        dit_input_caption = params.caption
        dit_input_vocal_language = params.vocal_language
        dit_input_lyrics = params.lyrics

        # User-supplied codes mean the LM only needs to reason about metadata.
        user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
        need_audio_codes = not user_provided_audio_codes

        actual_batch_size = config.batch_size if config.batch_size is not None else 1

        seed_for_generation = _resolve_seed_string(config, params)
        actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)

        # These tasks are driven by source audio, so the LM is bypassed.
        skip_lm_tasks = {"cover", "repaint"}

        llm_ready = llm_handler is not None and bool(llm_handler.llm_initialized)
        need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
        use_lm = (params.thinking or need_lm_for_cot) and llm_ready and params.task_type not in skip_lm_tasks
        lm_status = []

        if params.task_type in skip_lm_tasks:
            logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")

        logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
                    f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
                    f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
                    f"llm_initialized={llm_ready}, use_lm={use_lm}")

        if use_lm:
            # 0 / None disables top-k; None or >= 1.0 disables top-p.
            top_k_value = int(params.lm_top_k) if params.lm_top_k else None
            top_p_value = params.lm_top_p if params.lm_top_p and params.lm_top_p < 1.0 else None

            user_metadata_to_pass = _collect_lm_user_metadata(bpm, key_scale, time_signature, audio_duration)

            # The LM only generates audio codes when thinking is on and none were supplied.
            infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"

            max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
            max_inference_batch_size = max(1, max_inference_batch_size)  # avoid division by zero
            num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)

            all_metadata_list = []
            all_audio_codes_list = []

            for chunk_idx in range(num_chunks):
                chunk_start = chunk_idx * max_inference_batch_size
                chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
                chunk_size = chunk_end - chunk_start
                chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None

                logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
                            f"(size: {chunk_size}, seeds: {chunk_seeds})")

                result = llm_handler.generate_with_stop_condition(
                    caption=params.caption or "",
                    lyrics=params.lyrics or "",
                    infer_type=infer_type,
                    temperature=params.lm_temperature,
                    cfg_scale=params.lm_cfg_scale,
                    negative_prompt=params.lm_negative_prompt,
                    top_k=top_k_value,
                    top_p=top_p_value,
                    user_metadata=user_metadata_to_pass,
                    use_cot_caption=params.use_cot_caption,
                    use_cot_language=params.use_cot_language,
                    use_cot_metas=params.use_cot_metas,
                    use_constrained_decoding=params.use_constrained_decoding,
                    constrained_decoding_debug=config.constrained_decoding_debug,
                    batch_size=chunk_size,
                    seeds=chunk_seeds,
                    progress=progress,
                )

                if not result.get("success", False):
                    error_msg = result.get("error", "Unknown LM error")
                    return GenerationResult(
                        audios=[],
                        status_message=f"❌ LM generation failed: {error_msg}",
                        extra_outputs={},
                        success=False,
                        error=error_msg,
                    )

                # Batched calls return lists; single-sample calls return scalars.
                if chunk_size > 1:
                    all_metadata_list.extend(result.get("metadata", []))
                    all_audio_codes_list.extend(result.get("audio_codes", []))
                else:
                    all_metadata_list.append(result.get("metadata", {}))
                    all_audio_codes_list.append(result.get("audio_codes", ""))

                lm_chunk_time_costs = result.get("extra_outputs", {}).get("time_costs", {})
                if lm_chunk_time_costs:
                    for key in ("phase1_time", "phase2_time", "total_time"):
                        if key in lm_chunk_time_costs:
                            lm_total_time_costs[key] += lm_chunk_time_costs[key]

                    time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
                    lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")

            # Only the first sample's metadata drives the shared DiT conditioning.
            lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
            lm_generated_audio_codes_list = all_audio_codes_list

            if infer_type == "llm_dit":
                if actual_batch_size > 1:
                    audio_code_string_to_use = all_audio_codes_list
                else:
                    audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
            else:
                audio_code_string_to_use = params.audio_codes

            if lm_generated_metadata:
                bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
                    metadata=lm_generated_metadata,
                    bpm=bpm,
                    key_scale=key_scale,
                    time_signature=time_signature,
                    audio_duration=audio_duration,
                    vocal_language=dit_input_vocal_language,
                    caption=dit_input_caption,
                    lyrics=dit_input_lyrics)

                # Record LM-derived values for fields the user did not provide.
                if not params.bpm:
                    params.cot_bpm = bpm
                if not params.keyscale:
                    params.cot_keyscale = key_scale
                if not params.timesignature:
                    params.cot_timesignature = time_signature
                # duration defaults to -1.0 ("auto"), which is truthy, so test the sentinel.
                if params.duration is None or params.duration <= 0:
                    params.cot_duration = audio_duration
                if not params.vocal_language:
                    params.cot_vocal_language = vocal_language
                if not params.caption:
                    params.cot_caption = caption
                if not params.lyrics:
                    params.cot_lyrics = lyrics

                if params.use_cot_caption:
                    dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
                if params.use_cot_language:
                    dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)

        result = dit_handler.generate_music(
            captions=dit_input_caption,
            lyrics=dit_input_lyrics,
            bpm=bpm,
            key_scale=key_scale,
            time_signature=time_signature,
            vocal_language=dit_input_vocal_language,
            inference_steps=params.inference_steps,
            guidance_scale=params.guidance_scale,
            use_random_seed=config.use_random_seed,
            seed=seed_for_generation,
            reference_audio=params.reference_audio,
            audio_duration=audio_duration,
            batch_size=actual_batch_size,
            src_audio=params.src_audio,
            audio_code_string=audio_code_string_to_use,
            repainting_start=params.repainting_start,
            repainting_end=params.repainting_end,
            instruction=params.instruction,
            audio_cover_strength=params.audio_cover_strength,
            task_type=params.task_type,
            use_adg=params.use_adg,
            cfg_interval_start=params.cfg_interval_start,
            cfg_interval_end=params.cfg_interval_end,
            shift=params.shift,
            infer_method=params.infer_method,
            timesteps=params.timesteps,
            progress=progress,
        )

        if not result.get("success", False):
            return GenerationResult(
                audios=[],
                status_message=result.get("status_message", ""),
                extra_outputs={},
                success=False,
                error=result.get("error"),
            )

        dit_audios = result.get("audios", [])
        status_message = result.get("status_message", "")
        dit_extra_outputs = result.get("extra_outputs") or {}

        base_params_dict = params.to_dict()
        audio_format = config.audio_format or "flac"
        audio_saver = AudioSaver(default_format=audio_format)

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)

        audios = []
        for idx, dit_audio in enumerate(dit_audios):
            audio_params = base_params_dict.copy()
            audio_params["seed"] = actual_seed_list[idx] if idx < len(actual_seed_list) else None
            # Record the codes this particular sample was generated from, when the LM produced them.
            if idx < len(lm_generated_audio_codes_list):
                audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]

            audio_tensor = dit_audio.get("tensor")
            sample_rate = dit_audio.get("sample_rate", 48000)
            audio_key = generate_uuid_from_params(audio_params)

            audio_path = ""
            if audio_tensor is not None and save_dir is not None:
                try:
                    audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
                    audio_path = audio_saver.save_audio(audio_tensor,
                                                        audio_file,
                                                        sample_rate=sample_rate,
                                                        format=audio_format,
                                                        channels_first=True)
                except Exception as e:
                    # Best effort: keep the in-memory tensor even if saving fails.
                    logger.error(f"[generate_music] Failed to save audio file: {e}")
                    audio_path = ""

            audios.append({
                "path": audio_path or "",
                "tensor": audio_tensor,
                "key": audio_key,
                "sample_rate": sample_rate,
                "params": audio_params,
            })

        extra_outputs = dict(dit_extra_outputs)
        extra_outputs["lm_metadata"] = lm_generated_metadata
        extra_outputs["time_costs"] = _build_unified_time_costs(
            use_lm, lm_total_time_costs, dit_extra_outputs.get("time_costs", {}))

        if lm_status:
            status_message = "\n".join(lm_status) + "\n" + status_message

        return GenerationResult(
            audios=audios,
            status_message=status_message,
            extra_outputs=extra_outputs,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Music generation failed")
        return GenerationResult(
            audios=[],
            status_message=f"Error: {str(e)}",
            extra_outputs={},
            success=False,
            error=str(e),
        )
|
|
|
|
|
|
|
|
def understand_music(
    llm_handler,
    audio_codes: str,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> UnderstandResult:
    """Understand music from audio codes using the 5Hz Language Model.

    This function analyzes audio semantic codes and generates metadata about the music,
    including caption, lyrics, BPM, duration, key scale, language, and time signature.

    If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
    instead of analyzing existing codes.

    Note: cfg_scale and negative_prompt are not supported in understand mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance). A None handler is
            reported like an uninitialized one.
        audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
            Use an empty string or "NO USER INPUT" to generate a sample example.
        temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        UnderstandResult with parsed metadata fields and status. Exceptions from the
        LM call are caught and reported via success=False rather than raised.
        "N/A" / None text values become "", and missing or unparsable bpm/duration
        become None.

    Example:
        >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
        ...     print(f"BPM: {result.bpm}")
        ...     print(f"Lyrics: {result.lyrics}")
    """
    if llm_handler is None or not llm_handler.llm_initialized:
        return UnderstandResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    # An empty request switches the LM into "generate a sample example" mode.
    if not audio_codes or not audio_codes.strip():
        audio_codes = "NO USER INPUT"

    def _text_field(metadata: Dict[str, Any], *keys: str) -> str:
        """Return the first key present in metadata; "N/A" and None become ""."""
        value: Any = ""
        for key in keys:
            if key in metadata:
                value = metadata[key]
                break
        if value is None or value == "N/A":
            return ""
        return value

    def _number_field(metadata: Dict[str, Any], key: str, cast):
        """Parse a numeric metadata value with ``cast``; missing/"N/A"/""/invalid -> None."""
        value = metadata.get(key)
        if value is None or value in ("N/A", ""):
            return None
        try:
            return cast(value)
        except (ValueError, TypeError):
            return None

    try:
        metadata, status = llm_handler.understand_audio_from_codes(
            audio_codes=audio_codes,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        if not metadata:
            return UnderstandResult(
                status_message=status or "Failed to understand audio codes",
                success=False,
                error=status or "Empty metadata returned",
            )

        return UnderstandResult(
            caption=metadata.get('caption', '') or '',
            lyrics=metadata.get('lyrics', '') or '',
            bpm=_number_field(metadata, 'bpm', int),
            duration=_number_field(metadata, 'duration', float),
            # Accept both key spellings, matching _update_metadata_from_lm.
            keyscale=_text_field(metadata, 'keyscale', 'key_scale'),
            language=_text_field(metadata, 'language', 'vocal_language'),
            timesignature=_text_field(metadata, 'timesignature', 'time_signature'),
            status_message=status or "",
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Music understanding failed")
        return UnderstandResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
|
|
|
|
|
|
|
@dataclass
class CreateSampleResult:
    """Result of creating a music sample from a natural language query.

    This is used by the "Simple Mode" / "Inspiration Mode" feature, where users
    provide a natural language description and the LLM generates a complete
    sample with caption, lyrics, and metadata.

    Attributes:
        # Metadata Fields
        caption: Generated detailed music description/caption.
        lyrics: Generated lyrics (or "[Instrumental]" for instrumental music).
        bpm: Beats per minute (None if not generated).
        duration: Duration in seconds (None if not generated).
        keyscale: Musical key (e.g., "C Major").
        language: Vocal language code (e.g., "en", "zh").
        timesignature: Time signature (e.g., "4").
            NOTE(review): UnderstandResult's docstring uses "4/4" for the same field;
            confirm which format the LM emits.
        instrumental: Whether this is an instrumental piece.

        # Status
        status_message: Status message from sample creation.
        success: Whether sample creation completed successfully.
        error: Error message if sample creation failed.
    """

    # Generated sample content and metadata
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""
    instrumental: bool = False

    # Status
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
|
|
|
|
def create_sample(
    llm_handler,
    query: str,
    instrumental: bool = False,
    vocal_language: Optional[str] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> CreateSampleResult:
    """Create a music sample from a natural language query using the 5Hz Language Model.

    This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
    language description of music and generates a complete sample including:
    - Detailed caption/description
    - Lyrics (unless instrumental)
    - Metadata (BPM, duration, key, language, time signature)

    Note: cfg_scale and negative_prompt are not supported in create_sample mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance)
        query: User's natural language music description (e.g., "a soft Bengali love song")
        instrumental: Whether to generate instrumental music (no vocals)
        vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
            If provided, the model will be constrained to generate lyrics in this language.
            If None or "unknown", no language constraint is applied.
        temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding
        constrained_decoding_debug: Whether to enable debug logging

    Returns:
        CreateSampleResult with generated sample fields and status.
        Metadata the LLM leaves out, leaves empty, or reports as "N/A" is returned
        as None (bpm, duration) or "" (keyscale, language, timesignature).
        bpm accepts numbers and numeric strings (e.g. "120" or "120.0") and is
        truncated to an int. Non-finite durations (nan/inf) are treated as missing.
        If the LLM does not report ``instrumental``, the requested flag is used.

    Error handling:
        Exceptions raised by the LLM handler call or while parsing its output are
        logged and reported as ``success=False`` with ``error`` set, rather than
        propagated.

    Example:
        >>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
        ...     print(f"Lyrics: {result.lyrics}")
        ...     print(f"BPM: {result.bpm}")
    """
    if not llm_handler.llm_initialized:
        return CreateSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    def _clean_placeholder(value: Any) -> Any:
        # The LLM reports unknown metadata as 'N/A', and a key may also be present
        # with a None value. Surface both as '' so the str-typed fields stay strings.
        return '' if value is None or value == 'N/A' else value

    try:
        metadata, status = llm_handler.create_sample_from_query(
            query=query,
            instrumental=instrumental,
            vocal_language=vocal_language,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        if not metadata:
            return CreateSampleResult(
                status_message=status or "Failed to create sample",
                success=False,
                error=status or "Empty metadata returned",
            )

        # Free-text fields are kept verbatim; only a None value is normalized.
        caption = metadata.get('caption') or ''
        lyrics = metadata.get('lyrics') or ''

        keyscale = _clean_placeholder(metadata.get('keyscale', ''))
        # Some handlers report the language under 'vocal_language' instead.
        language = _clean_placeholder(
            metadata.get('language', metadata.get('vocal_language', ''))
        )
        timesignature = _clean_placeholder(metadata.get('timesignature', ''))

        # Fall back to the requested flag when the LLM did not report one.
        is_instrumental = metadata.get('instrumental')
        if is_instrumental is None:
            is_instrumental = instrumental

        # BPM: parse via float() so numeric strings like "120.0" are accepted;
        # int() then truncates exactly as int(120.7) did before. nan raises
        # ValueError and +/-inf raises OverflowError; both mean "no usable BPM".
        bpm = None
        bpm_value = metadata.get('bpm')
        if bpm_value not in (None, '', 'N/A'):
            try:
                bpm = int(float(bpm_value))
            except (ValueError, TypeError, OverflowError):
                bpm = None

        duration = None
        duration_value = metadata.get('duration')
        if duration_value not in (None, '', 'N/A'):
            try:
                parsed_duration = float(duration_value)
            except (ValueError, TypeError, OverflowError):
                parsed_duration = None
            # float() happily accepts "nan"/"inf", which are not usable durations.
            if parsed_duration is not None and math.isfinite(parsed_duration):
                duration = parsed_duration

        return CreateSampleResult(
            caption=caption,
            lyrics=lyrics,
            bpm=bpm,
            duration=duration,
            keyscale=keyscale,
            language=language,
            timesignature=timesignature,
            instrumental=is_instrumental,
            # The handler may return no status text; keep the field a str.
            status_message=status or "",
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Sample creation failed")
        return CreateSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
|
|
|
|
|
|
|
@dataclass
class FormatSampleResult:
    """Outcome of formatting a user-supplied caption and lyrics with the LLM.

    Produced by the "Format" feature: the user provides a caption and lyrics,
    and the LLM turns them into structured music metadata plus an enhanced
    description.

    Attributes:
        caption: Enhanced/formatted music description/caption.
        lyrics: Formatted lyrics (possibly identical to the input).
        bpm: Tempo in beats per minute; None when it was not detected.
        duration: Length in seconds; None when it was not detected.
        keyscale: Musical key, e.g. "C Major".
        language: Vocal language code, e.g. "en" or "zh".
        timesignature: Time signature, e.g. "4".
        status_message: Status text reported by formatting.
        success: Whether formatting completed successfully.
        error: Error description when formatting failed, otherwise None.
    """

    # Formatted content and detected metadata (order matters for positional use).
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    # Outcome of the request.
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict, ready for JSON serialization."""
        serialized = asdict(self)
        return serialized
|
|
|
|
|
|
|
|
def format_sample(
    llm_handler,
    caption: str,
    lyrics: str,
    user_metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> FormatSampleResult:
    """Format user-provided caption and lyrics using the 5Hz Language Model.

    This function takes user input (caption and lyrics) and generates structured
    music metadata including an enhanced caption, BPM, duration, key, language,
    and time signature.

    If user_metadata is provided, those values will be used to constrain the
    decoding, ensuring the output matches user-specified values.

    Note: cfg_scale and negative_prompt are not supported in format mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance)
        caption: User's caption/description (e.g., "Latin pop, reggaeton")
        lyrics: User's lyrics with structure tags
        user_metadata: Optional dict with user-provided metadata to constrain decoding.
            Supported keys: bpm, duration, keyscale, timesignature, language
        temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        FormatSampleResult with formatted metadata fields and status.
        If the LLM returns no lyrics (key missing or None), the input ``lyrics``
        are returned unchanged. Metadata the LLM leaves out, leaves empty, or
        reports as "N/A" is returned as None (bpm, duration) or ""
        (keyscale, language, timesignature). bpm accepts numbers and numeric
        strings (e.g. "120" or "120.0") and is truncated to an int. Non-finite
        durations (nan/inf) are treated as missing.

    Error handling:
        Exceptions raised by the LLM handler call or while parsing its output are
        logged and reported as ``success=False`` with ``error`` set, rather than
        propagated.

    Example:
        >>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
        ...     print(f"BPM: {result.bpm}")
        ...     print(f"Lyrics: {result.lyrics}")
    """
    if not llm_handler.llm_initialized:
        return FormatSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    def _clean_placeholder(value: Any) -> Any:
        # The LLM reports unknown metadata as 'N/A', and a key may also be present
        # with a None value. Surface both as '' so the str-typed fields stay strings.
        return '' if value is None or value == 'N/A' else value

    try:
        metadata, status = llm_handler.format_sample_from_input(
            caption=caption,
            lyrics=lyrics,
            user_metadata=user_metadata,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        if not metadata:
            return FormatSampleResult(
                status_message=status or "Failed to format input",
                success=False,
                error=status or "Empty metadata returned",
            )

        # Free-text fields are kept verbatim; only a None value is normalized.
        result_caption = metadata.get('caption') or ''
        # Keep the user's lyrics when the LLM returned none at all. An explicit
        # empty string from the LLM is preserved as-is.
        result_lyrics = metadata.get('lyrics')
        if result_lyrics is None:
            result_lyrics = lyrics

        keyscale = _clean_placeholder(metadata.get('keyscale', ''))
        # Some handlers report the language under 'vocal_language' instead.
        language = _clean_placeholder(
            metadata.get('language', metadata.get('vocal_language', ''))
        )
        timesignature = _clean_placeholder(metadata.get('timesignature', ''))

        # BPM: parse via float() so numeric strings like "120.0" are accepted;
        # int() then truncates exactly as int(120.7) did before. nan raises
        # ValueError and +/-inf raises OverflowError; both mean "no usable BPM".
        bpm = None
        bpm_value = metadata.get('bpm')
        if bpm_value not in (None, '', 'N/A'):
            try:
                bpm = int(float(bpm_value))
            except (ValueError, TypeError, OverflowError):
                bpm = None

        duration = None
        duration_value = metadata.get('duration')
        if duration_value not in (None, '', 'N/A'):
            try:
                parsed_duration = float(duration_value)
            except (ValueError, TypeError, OverflowError):
                parsed_duration = None
            # float() happily accepts "nan"/"inf", which are not usable durations.
            if parsed_duration is not None and math.isfinite(parsed_duration):
                duration = parsed_duration

        return FormatSampleResult(
            caption=result_caption,
            lyrics=result_lyrics,
            bpm=bpm,
            duration=duration,
            keyscale=keyscale,
            language=language,
            timesignature=timesignature,
            # The handler may return no status text; keep the field a str.
            status_message=status or "",
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Format sample failed")
        return FormatSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
|
|