"""
Business Logic Handler
Encapsulates all data processing and business logic as a bridge between model and UI
"""
import os
# Disable tokenizers parallelism to avoid fork warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Disable torchcodec backend to avoid CUDA dependency issues on HuggingFace Space
# This forces torchaudio to use ffmpeg/sox/soundfile backends instead
os.environ["TORCHAUDIO_USE_TORCHCODEC"] = "0"
import math
from copy import deepcopy
import tempfile
import traceback
import re
import random
import uuid
import hashlib
import json
from contextlib import contextmanager
from typing import Optional, Dict, Any, Tuple, List, Union
import torch
import torchaudio
import soundfile as sf
import time
from tqdm import tqdm
from loguru import logger
import warnings
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.generation.streamers import BaseStreamer
from diffusers.models import AutoencoderOobleck
from acestep.constants import (
TASK_INSTRUCTIONS,
SFT_GEN_PROMPT,
DEFAULT_DIT_INSTRUCTION,
)
from acestep.dit_alignment_score import MusicStampsAligner, MusicLyricScorer
warnings.filterwarnings("ignore")
class AceStepHandler:
"""ACE-Step Business Logic Handler"""
# HuggingFace Space environment detection
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
def __init__(self, persistent_storage_path: Optional[str] = None):
self.model = None
self.config = None
self.device = "cpu"
self.dtype = torch.float32 # Will be set based on device in initialize_service
# HuggingFace Space persistent storage support
if persistent_storage_path is None and self.IS_HUGGINGFACE_SPACE:
persistent_storage_path = "/data"
self.persistent_storage_path = persistent_storage_path
# VAE for audio encoding/decoding
self.vae = None
# Text encoder and tokenizer
self.text_encoder = None
self.text_tokenizer = None
# Silence latent for initialization
self.silence_latent = None
# Sample rate
self.sample_rate = 48000
# Reward model (temporarily disabled)
self.reward_model = None
# Batch size
self.batch_size = 2
# Custom layers config
self.custom_layers_config = {2: [6], 3: [10, 11], 4: [3], 5: [8, 9], 6: [8]}
self.offload_to_cpu = False
self.offload_dit_to_cpu = False
self.current_offload_cost = 0.0
# LoRA state
self.lora_loaded = False
self.use_lora = False
self._base_decoder = None # Backup of original decoder
def _get_checkpoint_dir(self) -> str:
"""Get checkpoint directory, prioritizing persistent storage if available"""
if self.persistent_storage_path:
return os.path.join(self.persistent_storage_path, "checkpoints")
project_root = self._get_project_root()
return os.path.join(project_root, "checkpoints")
    def get_available_checkpoints(self) -> List[str]:
        """Return the checkpoint directory as a one-item list, or an empty list if it does not exist"""
checkpoint_dir = self._get_checkpoint_dir()
if os.path.exists(checkpoint_dir):
return [checkpoint_dir]
else:
return []
def get_available_acestep_v15_models(self) -> List[str]:
"""Scan and return all model directory names starting with 'acestep-v15-'"""
checkpoint_dir = self._get_checkpoint_dir()
models = []
if os.path.exists(checkpoint_dir):
for item in os.listdir(checkpoint_dir):
item_path = os.path.join(checkpoint_dir, item)
if os.path.isdir(item_path) and item.startswith("acestep-v15-"):
models.append(item)
models.sort()
return models
# Model name to HuggingFace repository mapping
# Models in the same repo will be downloaded together
MODEL_REPO_MAPPING = {
# Main unified repository (contains acestep-v15-turbo, LM models, VAE, text encoder)
"acestep-v15-turbo": "ACE-Step/Ace-Step1.5",
"acestep-5Hz-lm-0.6B": "ACE-Step/Ace-Step1.5",
"acestep-5Hz-lm-1.7B": "ACE-Step/Ace-Step1.5",
"vae": "ACE-Step/Ace-Step1.5",
"Qwen3-Embedding-0.6B": "ACE-Step/Ace-Step1.5",
# Separate model repositories
"acestep-v15-base": "ACE-Step/acestep-v15-base",
"acestep-v15-sft": "ACE-Step/acestep-v15-sft",
"acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3",
}
# Default fallback repository for unknown models
DEFAULT_REPO_ID = "ACE-Step/Ace-Step1.5"
def _ensure_model_downloaded(self, model_name: str, checkpoint_dir: str) -> str:
"""
Ensure model is downloaded from HuggingFace Hub.
Used for HuggingFace Space auto-download support.
Supports multiple repositories:
- Models in MODEL_REPO_MAPPING will be downloaded from their specific repo
- Unknown models will try the DEFAULT_REPO_ID
For separate model repos (acestep-v15-base, acestep-v15-sft, acestep-v15-turbo-shift3),
downloads directly into the model subdirectory.
Args:
model_name: Model directory name (e.g., "acestep-v15-turbo", "acestep-v15-turbo-shift3")
checkpoint_dir: Target checkpoint directory
Returns:
Path to the downloaded model
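        Example (illustrative; repo names follow MODEL_REPO_MAPPING above):
            # downloads ACE-Step/acestep-v15-base into <checkpoint_dir>/acestep-v15-base
            path = handler._ensure_model_downloaded("acestep-v15-base", "/data/checkpoints")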
"""
from huggingface_hub import snapshot_download
model_path = os.path.join(checkpoint_dir, model_name)
# Check if model already exists
if os.path.exists(model_path) and os.listdir(model_path):
logger.info(f"Model {model_name} already exists at {model_path}")
return model_path
# Get repository ID for this model
repo_id = self.MODEL_REPO_MAPPING.get(model_name, self.DEFAULT_REPO_ID)
        # Determine if this is the unified repo or a separate model repo
        is_unified_repo = repo_id == self.DEFAULT_REPO_ID
if is_unified_repo:
# Unified repo: download entire repo to checkpoint_dir
# The model will be in checkpoint_dir/model_name
download_dir = checkpoint_dir
logger.info(f"Downloading unified repository {repo_id} to {download_dir}...")
else:
# Separate model repo: download directly to model_path
# The repo contains the model files directly, not in a subdirectory
download_dir = model_path
os.makedirs(download_dir, exist_ok=True)
logger.info(f"Downloading model {model_name} from {repo_id} to {download_dir}...")
try:
snapshot_download(
repo_id=repo_id,
local_dir=download_dir,
local_dir_use_symlinks=False,
)
logger.info(f"Repository {repo_id} downloaded successfully to {download_dir}")
except Exception as e:
logger.error(f"Failed to download repository {repo_id}: {e}")
raise
return model_path
def is_flash_attention_available(self) -> bool:
"""Check if flash attention is available on the system"""
try:
import flash_attn
return True
except ImportError:
return False
def is_turbo_model(self) -> bool:
"""Check if the currently loaded model is a turbo model"""
if self.config is None:
return False
return getattr(self.config, 'is_turbo', False)
def load_lora(self, lora_path: str) -> str:
"""Load LoRA adapter into the decoder.
Args:
lora_path: Path to the LoRA adapter directory (containing adapter_config.json)
Returns:
Status message
"""
if self.model is None:
return "❌ Model not initialized. Please initialize service first."
if not lora_path or not lora_path.strip():
return "❌ Please provide a LoRA path."
lora_path = lora_path.strip()
# Check if path exists
if not os.path.exists(lora_path):
return f"❌ LoRA path not found: {lora_path}"
# Check if it's a valid PEFT adapter directory
config_file = os.path.join(lora_path, "adapter_config.json")
if not os.path.exists(config_file):
return f"❌ Invalid LoRA adapter: adapter_config.json not found in {lora_path}"
try:
from peft import PeftModel, PeftConfig
except ImportError:
return "❌ PEFT library not installed. Please install with: pip install peft"
        try:
            # Backup base decoder if not already backed up
            if self._base_decoder is None:
                self._base_decoder = deepcopy(self.model.decoder)
                logger.info("Base decoder backed up")
            else:
                # Restore base decoder before loading new LoRA
                self.model.decoder = deepcopy(self._base_decoder)
                logger.info("Restored base decoder before loading new LoRA")
# Load PEFT adapter
logger.info(f"Loading LoRA adapter from {lora_path}")
self.model.decoder = PeftModel.from_pretrained(
self.model.decoder,
lora_path,
is_trainable=False,
)
self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
self.model.decoder.eval()
self.lora_loaded = True
self.use_lora = True # Enable LoRA by default after loading
logger.info(f"LoRA adapter loaded successfully from {lora_path}")
return f"✅ LoRA loaded from {lora_path}"
except Exception as e:
logger.exception("Failed to load LoRA adapter")
return f"❌ Failed to load LoRA: {str(e)}"
def unload_lora(self) -> str:
"""Unload LoRA adapter and restore base decoder.
Returns:
Status message
"""
if not self.lora_loaded:
return "⚠️ No LoRA adapter loaded."
if self._base_decoder is None:
return "❌ Base decoder backup not found. Cannot restore."
        try:
            # Restore base decoder (deepcopy is imported at module level)
            self.model.decoder = deepcopy(self._base_decoder)
self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
self.model.decoder.eval()
self.lora_loaded = False
self.use_lora = False
logger.info("LoRA unloaded, base decoder restored")
return "✅ LoRA unloaded, using base model"
except Exception as e:
logger.exception("Failed to unload LoRA")
return f"❌ Failed to unload LoRA: {str(e)}"
def set_use_lora(self, use_lora: bool) -> str:
"""Toggle LoRA usage for inference.
Args:
use_lora: Whether to use LoRA adapter
Returns:
Status message
"""
if use_lora and not self.lora_loaded:
return "❌ No LoRA adapter loaded. Please load a LoRA first."
self.use_lora = use_lora
# Use PEFT's enable/disable methods if available
if self.lora_loaded and hasattr(self.model.decoder, 'disable_adapter_layers'):
try:
if use_lora:
self.model.decoder.enable_adapter_layers()
logger.info("LoRA adapter enabled")
else:
self.model.decoder.disable_adapter_layers()
logger.info("LoRA adapter disabled")
except Exception as e:
logger.warning(f"Could not toggle adapter layers: {e}")
status = "enabled" if use_lora else "disabled"
return f"✅ LoRA {status}"
def get_lora_status(self) -> Dict[str, Any]:
"""Get current LoRA status.
Returns:
Dictionary with LoRA status info
"""
return {
"loaded": self.lora_loaded,
"active": self.use_lora,
}
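    # Illustrative LoRA workflow (assumes an initialized service and a PEFT adapter
    # directory containing adapter_config.json; the path below is hypothetical):
    #   handler.load_lora("/path/to/lora")  # backs up base decoder, wraps it in a PeftModel
    #   handler.set_use_lora(False)         # bypass adapter layers without unloading
    #   handler.set_use_lora(True)          # re-enable adapter layers
    #   handler.unload_lora()               # restore the backed-up base decoder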
def initialize_service(
self,
project_root: str,
config_path: str,
device: str = "auto",
use_flash_attention: bool = False,
compile_model: bool = False,
offload_to_cpu: bool = False,
offload_dit_to_cpu: bool = False,
quantization: Optional[str] = None,
# Shared components (for multi-model setup to save memory)
shared_vae = None,
shared_text_encoder = None,
shared_text_tokenizer = None,
shared_silence_latent = None,
) -> Tuple[str, bool]:
"""
Initialize DiT model service
Args:
project_root: Project root path (may be checkpoints directory, will be handled automatically)
config_path: Model config directory name (e.g., "acestep-v15-turbo")
device: Device type
use_flash_attention: Whether to use flash attention (requires flash_attn package)
compile_model: Whether to use torch.compile to optimize the model
offload_to_cpu: Whether to offload models to CPU when not in use
offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True)
shared_vae: Optional shared VAE instance (for multi-model setup)
shared_text_encoder: Optional shared text encoder instance (for multi-model setup)
shared_text_tokenizer: Optional shared text tokenizer instance (for multi-model setup)
shared_silence_latent: Optional shared silence latent tensor (for multi-model setup)
Returns:
(status_message, enable_generate_button)
"""
try:
if device == "auto":
if hasattr(torch, 'xpu') and torch.xpu.is_available():
device = "xpu"
elif torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
status_msg = ""
self.device = device
self.offload_to_cpu = offload_to_cpu
self.offload_dit_to_cpu = offload_dit_to_cpu
            # Set dtype based on device: bfloat16 for cuda/xpu, float32 otherwise
            self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32
self.quantization = quantization
if self.quantization is not None:
assert compile_model, "Quantization requires compile_model to be True"
try:
import torchao
except ImportError:
raise ImportError("torchao is required for quantization but is not installed. Please install torchao to use quantization features.")
# Auto-detect project root (independent of passed project_root parameter)
actual_project_root = self._get_project_root()
checkpoint_dir = self._get_checkpoint_dir()
os.makedirs(checkpoint_dir, exist_ok=True)
# 1. Load main model
# config_path is relative path (e.g., "acestep-v15-turbo"), concatenate to checkpoints directory
# If config_path is None (HuggingFace Space with empty checkpoint), use default and auto-download
if config_path is None:
config_path = "acestep-v15-turbo"
logger.info(f"[initialize_service] config_path is None, using default: {config_path}")
acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
# Auto-download model if not exists (HuggingFace Space support)
if not os.path.exists(acestep_v15_checkpoint_path):
acestep_v15_checkpoint_path = self._ensure_model_downloaded(config_path, checkpoint_dir)
if os.path.exists(acestep_v15_checkpoint_path):
# Determine attention implementation
if use_flash_attention and self.is_flash_attention_available():
attn_implementation = "flash_attention_2"
self.dtype = torch.bfloat16
else:
attn_implementation = "sdpa"
try:
logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_implementation}")
self.model = AutoModel.from_pretrained(
acestep_v15_checkpoint_path,
trust_remote_code=True,
attn_implementation=attn_implementation,
dtype="bfloat16"
)
except Exception as e:
logger.warning(f"[initialize_service] Failed to load model with {attn_implementation}: {e}")
if attn_implementation == "sdpa":
logger.info("[initialize_service] Falling back to eager attention")
attn_implementation = "eager"
self.model = AutoModel.from_pretrained(
acestep_v15_checkpoint_path,
trust_remote_code=True,
attn_implementation=attn_implementation
)
else:
raise e
self.model.config._attn_implementation = attn_implementation
self.config = self.model.config
# Move model to device and set dtype
if not self.offload_to_cpu:
self.model = self.model.to(device).to(self.dtype)
else:
# If offload_to_cpu is True, check if we should keep DiT on GPU
if not self.offload_dit_to_cpu:
logger.info(f"[initialize_service] Keeping main model on {device} (persistent)")
self.model = self.model.to(device).to(self.dtype)
else:
self.model = self.model.to("cpu").to(self.dtype)
self.model.eval()
if compile_model:
self.model = torch.compile(self.model)
if self.quantization is not None:
from torchao.quantization import quantize_
if self.quantization == "int8_weight_only":
from torchao.quantization import Int8WeightOnlyConfig
quant_config = Int8WeightOnlyConfig()
elif self.quantization == "fp8_weight_only":
from torchao.quantization import Float8WeightOnlyConfig
quant_config = Float8WeightOnlyConfig()
elif self.quantization == "w8a8_dynamic":
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, MappingType
quant_config = Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC)
else:
raise ValueError(f"Unsupported quantization type: {self.quantization}")
quantize_(self.model, quant_config)
logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
# Load or use shared silence_latent
if shared_silence_latent is not None:
self.silence_latent = shared_silence_latent
logger.info("[initialize_service] Using shared silence_latent")
else:
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
if os.path.exists(silence_latent_path):
self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
# Always keep silence_latent on GPU - it's used in many places outside model context
# and is small enough that it won't significantly impact VRAM
self.silence_latent = self.silence_latent.to(device).to(self.dtype)
else:
raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
else:
raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}")
# 2. Load or use shared VAE
vae_checkpoint_path = os.path.join(checkpoint_dir, "vae") # Define for status message
if shared_vae is not None:
self.vae = shared_vae
logger.info("[initialize_service] Using shared VAE")
else:
if os.path.exists(vae_checkpoint_path):
self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
# Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
vae_dtype = self._get_vae_dtype(device)
if not self.offload_to_cpu:
self.vae = self.vae.to(device).to(vae_dtype)
else:
self.vae = self.vae.to("cpu").to(vae_dtype)
self.vae.eval()
else:
raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
if compile_model:
self.vae = torch.compile(self.vae)
# 3. Load or use shared text encoder and tokenizer
text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B") # Define for status message
if shared_text_encoder is not None and shared_text_tokenizer is not None:
self.text_encoder = shared_text_encoder
self.text_tokenizer = shared_text_tokenizer
logger.info("[initialize_service] Using shared text encoder and tokenizer")
else:
if os.path.exists(text_encoder_path):
self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
if not self.offload_to_cpu:
self.text_encoder = self.text_encoder.to(device).to(self.dtype)
else:
self.text_encoder = self.text_encoder.to("cpu").to(self.dtype)
self.text_encoder.eval()
else:
raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
# Determine actual attention implementation used
actual_attn = getattr(self.config, "_attn_implementation", "eager")
# Determine if using shared components
using_shared = shared_vae is not None or shared_text_encoder is not None
status_msg = f"✅ Model initialized successfully on {device}\n"
status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
if shared_vae is None:
status_msg += f"VAE: {vae_checkpoint_path}\n"
else:
status_msg += f"VAE: shared\n"
if shared_text_encoder is None:
status_msg += f"Text encoder: {text_encoder_path}\n"
else:
status_msg += f"Text encoder: shared\n"
status_msg += f"Dtype: {self.dtype}\n"
status_msg += f"Attention: {actual_attn}\n"
status_msg += f"Compiled: {compile_model}\n"
status_msg += f"Offload to CPU: {self.offload_to_cpu}\n"
status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}"
return status_msg, True
except Exception as e:
error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
logger.exception("[initialize_service] Error initializing model")
return error_msg, False
    def _is_on_target_device(self, tensor, target_device):
        """Check if tensor is on the target device (handles cuda vs cuda:0 comparison)."""
        if tensor is None:
            return True
        # Compare device types so "cuda" matches "cuda:0" (also correct for xpu/mps)
        return tensor.device.type == torch.device(target_device).type
def _ensure_silence_latent_on_device(self):
"""Ensure silence_latent is on the correct device (self.device)."""
if hasattr(self, "silence_latent") and self.silence_latent is not None:
if not self._is_on_target_device(self.silence_latent, self.device):
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
def _move_module_recursive(self, module, target_device, dtype=None, visited=None):
"""
Recursively move a module and all its submodules to the target device.
This handles modules that may not be properly registered.
"""
if visited is None:
visited = set()
module_id = id(module)
if module_id in visited:
return
visited.add(module_id)
# Move the module itself
module.to(target_device)
if dtype is not None:
module.to(dtype)
# Move all direct parameters
for param_name, param in module._parameters.items():
if param is not None and not self._is_on_target_device(param, target_device):
module._parameters[param_name] = param.to(target_device)
if dtype is not None:
module._parameters[param_name] = module._parameters[param_name].to(dtype)
# Move all direct buffers
for buf_name, buf in module._buffers.items():
if buf is not None and not self._is_on_target_device(buf, target_device):
module._buffers[buf_name] = buf.to(target_device)
# Recursively process all submodules (registered and unregistered)
for name, child in module._modules.items():
if child is not None:
self._move_module_recursive(child, target_device, dtype, visited)
# Also check for any nn.Module attributes that might not be in _modules
for attr_name in dir(module):
if attr_name.startswith('_'):
continue
try:
attr = getattr(module, attr_name, None)
if isinstance(attr, torch.nn.Module) and id(attr) not in visited:
self._move_module_recursive(attr, target_device, dtype, visited)
except Exception:
pass
def _recursive_to_device(self, model, device, dtype=None):
"""
Recursively move all parameters and buffers of a model to the specified device.
This is more thorough than model.to() for some custom HuggingFace models.
"""
target_device = torch.device(device) if isinstance(device, str) else device
# Method 1: Standard .to() call
model.to(target_device)
if dtype is not None:
model.to(dtype)
# Method 2: Use our thorough recursive moving for any missed modules
self._move_module_recursive(model, target_device, dtype)
# Method 3: Force move via state_dict if there are still parameters on wrong device
wrong_device_params = []
for name, param in model.named_parameters():
if not self._is_on_target_device(param, device):
wrong_device_params.append(name)
if wrong_device_params and device != "cpu":
logger.warning(f"[_recursive_to_device] {len(wrong_device_params)} parameters on wrong device, using state_dict method")
# Get current state dict and move all tensors
state_dict = model.state_dict()
moved_state_dict = {}
for key, value in state_dict.items():
if isinstance(value, torch.Tensor):
moved_state_dict[key] = value.to(target_device)
if dtype is not None and moved_state_dict[key].is_floating_point():
moved_state_dict[key] = moved_state_dict[key].to(dtype)
else:
moved_state_dict[key] = value
model.load_state_dict(moved_state_dict)
# Synchronize CUDA to ensure all transfers are complete
if device != "cpu" and torch.cuda.is_available():
torch.cuda.synchronize()
# Final verification
if device != "cpu":
still_wrong = []
for name, param in model.named_parameters():
if not self._is_on_target_device(param, device):
still_wrong.append(f"{name} on {param.device}")
if still_wrong:
logger.error(f"[_recursive_to_device] CRITICAL: {len(still_wrong)} parameters still on wrong device: {still_wrong[:10]}")
@contextmanager
def _load_model_context(self, model_name: str):
"""
Context manager to load a model to GPU and offload it back to CPU after use.
Args:
model_name: Name of the model to load ("text_encoder", "vae", "model")
"""
if not self.offload_to_cpu:
yield
return
# If model is DiT ("model") and offload_dit_to_cpu is False, do not offload
if model_name == "model" and not self.offload_dit_to_cpu:
# Ensure it's on device if not already (should be handled by init, but safe to check)
model = getattr(self, model_name, None)
if model is not None:
# Check if model is on CPU, if so move to device (one-time move if it was somehow on CPU)
# We check the first parameter's device
try:
param = next(model.parameters())
if param.device.type == "cpu":
logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
self._recursive_to_device(model, self.device, self.dtype)
if hasattr(self, "silence_latent"):
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
except StopIteration:
pass
yield
return
model = getattr(self, model_name, None)
if model is None:
yield
return
# Load to GPU
logger.info(f"[_load_model_context] Loading {model_name} to {self.device}")
start_time = time.time()
if model_name == "vae":
vae_dtype = self._get_vae_dtype()
self._recursive_to_device(model, self.device, vae_dtype)
else:
self._recursive_to_device(model, self.device, self.dtype)
if model_name == "model" and hasattr(self, "silence_latent"):
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
load_time = time.time() - start_time
self.current_offload_cost += load_time
logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s")
try:
yield
finally:
# Offload to CPU
logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
start_time = time.time()
self._recursive_to_device(model, "cpu")
# NOTE: Do NOT offload silence_latent to CPU here!
# silence_latent is used in many places outside of model context,
# so it should stay on GPU to avoid device mismatch errors.
torch.cuda.empty_cache()
offload_time = time.time() - start_time
self.current_offload_cost += offload_time
logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s")
def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
"""Process target audio"""
if audio_file is None:
return None
try:
# Load audio using soundfile
audio_np, sr = sf.read(audio_file, dtype='float32')
# Convert to torch: [samples, channels] or [samples] -> [channels, samples]
if audio_np.ndim == 1:
audio = torch.from_numpy(audio_np).unsqueeze(0)
else:
audio = torch.from_numpy(audio_np.T)
# Normalize to stereo 48kHz
audio = self._normalize_audio_to_stereo_48k(audio, sr)
# Enforce duration limits (10-600 seconds)
audio = self._enforce_audio_duration_limits(audio)
return audio
except Exception as e:
logger.exception("[process_target_audio] Error processing target audio")
return None
def _parse_audio_code_string(self, code_str: str) -> List[int]:
"""Extract integer audio codes from prompt tokens like <|audio_code_123|>.
Codes are clamped to valid range [0, 63999] (codebook size = 64000).
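        Example (illustrative):
            >>> handler._parse_audio_code_string("<|audio_code_12|><|audio_code_70000|>")
            [12, 63999]  # 70000 is out of range and gets clamped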
"""
if not code_str:
return []
try:
codes = [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
# Clamp codes to valid range [0, 63999]
MAX_AUDIO_CODE = 63999
clamped_codes = []
invalid_codes = []
for code in codes:
if code < 0 or code > MAX_AUDIO_CODE:
invalid_codes.append(code)
clamped_code = max(0, min(code, MAX_AUDIO_CODE))
clamped_codes.append(clamped_code)
else:
clamped_codes.append(code)
if invalid_codes:
logger.warning(f"[_parse_audio_code_string] Found {len(invalid_codes)} codes outside valid range [0, {MAX_AUDIO_CODE}]: {invalid_codes[:5]}... (clamped to valid range)")
return clamped_codes
except Exception as e:
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
return []
def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
"""
Convert serialized audio code string into 25Hz latents using model quantizer/detokenizer.
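        Shape flow (illustrative, assuming a single-codebook quantizer):
            code string -> indices [1, T_5Hz, 1] -> quantized [1, T_5Hz, dim]
            -> detokenizer -> 25Hz latents [1, T_25Hz, dim]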
"""
if not self.model or not hasattr(self.model, 'tokenizer') or not hasattr(self.model, 'detokenizer'):
return None
code_ids = self._parse_audio_code_string(code_str)
if len(code_ids) == 0:
return None
try:
with self._load_model_context("model"):
quantizer = self.model.tokenizer.quantizer
detokenizer = self.model.detokenizer
# Get codebook size for validation
# Default to 64000 (codebook size = 64000, valid range = 0-63999)
codebook_size = getattr(quantizer, 'codebook_size', 64000)
if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
# Validate code IDs are within valid range
invalid_codes = [c for c in code_ids if c < 0 or c >= codebook_size]
if invalid_codes:
logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {codebook_size}): {invalid_codes[:5]}...")
# Clamp invalid codes to valid range
code_ids = [max(0, min(c, codebook_size - 1)) for c in code_ids]
num_quantizers = getattr(quantizer, "num_quantizers", 1)
# Create indices tensor: [T_5Hz]
indices = torch.tensor(code_ids, device=self.device, dtype=torch.long) # [T_5Hz]
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, T_5Hz, 1]
# Synchronize to catch any CUDA errors before proceeding
if torch.cuda.is_available():
torch.cuda.synchronize()
# Get quantized representation from indices
# The quantizer expects [batch, T_5Hz] format and handles quantizer dimension internally
quantized = quantizer.get_output_from_indices(indices)
if quantized.dtype != self.dtype:
quantized = quantized.to(self.dtype)
# Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
lm_hints_25hz = detokenizer(quantized)
return lm_hints_25hz
except Exception as e:
logger.exception(f"[_decode_audio_codes_to_latents] Error decoding audio codes: {e}")
            # Release cached GPU memory after the failure
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
return None
def _create_default_meta(self) -> str:
"""Create default metadata string."""
return (
"- bpm: N/A\n"
"- timesignature: N/A\n"
"- keyscale: N/A\n"
"- duration: 30 seconds\n"
)
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
"""Convert metadata dict to formatted string."""
bpm = meta_dict.get('bpm', meta_dict.get('tempo', 'N/A'))
timesignature = meta_dict.get('timesignature', meta_dict.get('time_signature', 'N/A'))
keyscale = meta_dict.get('keyscale', meta_dict.get('key', meta_dict.get('scale', 'N/A')))
duration = meta_dict.get('duration', meta_dict.get('length', 30))
# Format duration
if isinstance(duration, (int, float)):
duration = f"{int(duration)} seconds"
elif not isinstance(duration, str):
duration = "30 seconds"
return (
f"- bpm: {bpm}\n"
f"- timesignature: {timesignature}\n"
f"- keyscale: {keyscale}\n"
f"- duration: {duration}\n"
)
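    # Example (illustrative):
    #   _dict_to_meta_string({"bpm": 120, "key": "C major", "duration": 45}) returns
    #   "- bpm: 120\n- timesignature: N/A\n- keyscale: C major\n- duration: 45 seconds\n"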
def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
"""
Parse and normalize metadata with fallbacks.
Args:
metas: List of metadata (can be strings, dicts, or None)
Returns:
List of formatted metadata strings
"""
parsed_metas = []
for meta in metas:
if meta is None:
# Default fallback metadata
parsed_meta = self._create_default_meta()
elif isinstance(meta, str):
# Already formatted string
parsed_meta = meta
elif isinstance(meta, dict):
# Convert dict to formatted string
parsed_meta = self._dict_to_meta_string(meta)
else:
# Fallback for any other type
parsed_meta = self._create_default_meta()
parsed_metas.append(parsed_meta)
return parsed_metas
def build_dit_inputs(
self,
task: str,
instruction: Optional[str],
caption: str,
lyrics: str,
metas: Optional[Union[str, Dict[str, Any]]] = None,
vocal_language: str = "en",
) -> Tuple[str, str]:
"""
Build text inputs for the caption and lyric branches used by DiT.
Args:
task: Task name (e.g., text2music, cover, repaint); kept for logging/future branching.
instruction: Instruction text; default fallback matches service_generate behavior.
caption: Caption string (fallback if not in metas).
lyrics: Lyrics string.
metas: Metadata (str or dict); follows _parse_metas formatting.
May contain 'caption' and 'language' fields from LM CoT output.
vocal_language: Language code for lyrics section (fallback if not in metas).
Returns:
(caption_input_text, lyrics_input_text)
Example:
caption_input, lyrics_input = handler.build_dit_inputs(
task="text2music",
instruction=None,
caption="A calm piano melody",
lyrics="la la la",
metas={"bpm": 90, "duration": 45, "caption": "LM generated caption", "language": "en"},
vocal_language="en",
)
"""
# Align instruction formatting with _prepare_batch
final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION)
# Extract caption and language from metas if available (from LM CoT output)
# Fallback to user-provided values if not in metas
actual_caption = caption
actual_language = vocal_language
if metas is not None:
            # Only dict metas can carry caption/language overrides; _parse_metas
            # yields formatted strings (not dicts), so string metas fall back to
            # the user-provided caption and vocal_language
            if isinstance(metas, dict):
                meta_dict = metas
            else:
                meta_dict = {}
# Extract caption from metas if available
if 'caption' in meta_dict and meta_dict['caption']:
actual_caption = str(meta_dict['caption'])
# Extract language from metas if available
if 'language' in meta_dict and meta_dict['language']:
actual_language = str(meta_dict['language'])
parsed_meta = self._parse_metas([metas])[0]
caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
lyrics_input = self._format_lyrics(lyrics, actual_language)
return caption_input, lyrics_input
def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get text hidden states from text encoder."""
if self.text_tokenizer is None or self.text_encoder is None:
raise ValueError("Text encoder not initialized")
with self._load_model_context("text_encoder"):
# Tokenize
text_inputs = self.text_tokenizer(
text_prompt,
padding="longest",
truncation=True,
max_length=256,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.device)
text_attention_mask = text_inputs.attention_mask.to(self.device).bool()
# Encode
with torch.no_grad():
text_outputs = self.text_encoder(text_input_ids)
if hasattr(text_outputs, 'last_hidden_state'):
text_hidden_states = text_outputs.last_hidden_state
elif isinstance(text_outputs, tuple):
text_hidden_states = text_outputs[0]
else:
text_hidden_states = text_outputs
text_hidden_states = text_hidden_states.to(self.dtype)
return text_hidden_states, text_attention_mask
    def extract_caption_from_sft_format(self, caption: str) -> str:
        """Extract the '# Caption' section from an SFT-formatted prompt; return input unchanged if absent."""
        try:
if "# Instruction" in caption and "# Caption" in caption:
pattern = r'#\s*Caption\s*\n(.*?)(?:\n\s*#\s*Metas|$)'
match = re.search(pattern, caption, re.DOTALL)
if match:
return match.group(1).strip()
return caption
except Exception as e:
logger.exception("[extract_caption_from_sft_format] Error extracting caption")
return caption
    def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
        """Resolve per-item seeds from user input; returns (seed_list, ui_display_string)."""
        actual_seed_list: List[int] = []
        seed_value_for_ui = ""
if use_random_seed:
# Generate brand new seeds and expose them back to the UI
actual_seed_list = [random.randint(0, 2 ** 32 - 1) for _ in range(actual_batch_size)]
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
else:
# Parse seed input: can be a single number, comma-separated numbers, or -1
# If seed is a string, try to parse it as comma-separated values
seed_list = []
if isinstance(seed, str):
# Handle string input (e.g., "123,456" or "-1")
seed_str_list = [s.strip() for s in seed.split(",")]
for s in seed_str_list:
if s == "-1" or s == "":
seed_list.append(-1)
else:
try:
seed_list.append(int(float(s)))
except (ValueError, TypeError) as e:
logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {e}")
seed_list.append(-1)
elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
# If seed is None or negative, use -1 for all items
seed_list = [-1] * actual_batch_size
elif isinstance(seed, (int, float)):
# Single seed value
seed_list = [int(seed)]
else:
# Fallback: use -1
seed_list = [-1] * actual_batch_size
# Process seed list according to rules:
# 1. If all are -1, generate different random seeds for each batch item
# 2. If one non-negative seed is provided and batch_size > 1, first uses that seed, rest are random
# 3. If more seeds than batch_size, use first batch_size seeds
# Check if user provided only one non-negative seed (not -1)
has_single_non_negative_seed = (len(seed_list) == 1 and seed_list[0] != -1)
for i in range(actual_batch_size):
if i < len(seed_list):
seed_val = seed_list[i]
else:
# If not enough seeds provided, use -1 (will generate random)
seed_val = -1
# Special case: if only one non-negative seed was provided and batch_size > 1,
# only the first item uses that seed, others are random
if has_single_non_negative_seed and actual_batch_size > 1 and i > 0:
# Generate random seed for remaining items
actual_seed_list.append(random.randint(0, 2 ** 32 - 1))
elif seed_val == -1:
# Generate a random seed for this item
actual_seed_list.append(random.randint(0, 2 ** 32 - 1))
else:
actual_seed_list.append(int(seed_val))
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
return actual_seed_list, seed_value_for_ui
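    # Examples (illustrative):
    #   prepare_seeds(3, "123", False)  -> ([123, <random>, <random>], "123, ...")
    #   prepare_seeds(2, "7, 8", False) -> ([7, 8], "7, 8")
    #   prepare_seeds(2, -1, False)     -> two freshly generated random seeds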
def prepare_metadata(self, bpm, key_scale, time_signature):
"""Build metadata dict - use "N/A" as default for empty fields."""
return self._build_metadata_dict(bpm, key_scale, time_signature)
def is_silence(self, audio):
return torch.all(audio.abs() < 1e-6)
def _get_project_root(self) -> str:
"""Get project root directory path."""
current_file = os.path.abspath(__file__)
return os.path.dirname(os.path.dirname(current_file))
def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype:
"""Get VAE dtype based on device."""
device = device or self.device
return torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
def _format_instruction(self, instruction: str) -> str:
"""Format instruction to ensure it ends with colon."""
if not instruction.endswith(":"):
instruction = instruction + ":"
return instruction
def _load_audio_file(self, audio_file) -> Tuple[torch.Tensor, int]:
"""
Load audio file with ffmpeg backend, fallback to soundfile if failed.
This handles CUDA dependency issues with torchcodec on HuggingFace Space.
Args:
audio_file: Path to the audio file
Returns:
Tuple of (audio_tensor, sample_rate)
Raises:
FileNotFoundError: If the audio file doesn't exist
Exception: If all methods fail to load the audio
"""
# Check if file exists first
if not os.path.exists(audio_file):
raise FileNotFoundError(f"Audio file not found: {audio_file}")
# Try torchaudio with explicit ffmpeg backend first
try:
audio, sr = torchaudio.load(audio_file, backend="ffmpeg")
return audio, sr
except Exception as e:
logger.debug(f"[_load_audio_file] ffmpeg backend failed: {e}, trying soundfile fallback")
# Fallback: use soundfile directly (most compatible)
try:
audio_np, sr = sf.read(audio_file)
# soundfile returns [samples, channels] or [samples], convert to [channels, samples]
audio = torch.from_numpy(audio_np).float()
if audio.dim() == 1:
# Mono: [samples] -> [1, samples]
audio = audio.unsqueeze(0)
else:
# Stereo: [samples, channels] -> [channels, samples]
audio = audio.T
return audio, sr
except Exception as e:
logger.error(f"[_load_audio_file] All methods failed to load audio: {audio_file}, error: {e}")
raise
def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor:
"""
Normalize audio to stereo 48kHz format.
Args:
audio: Audio tensor [channels, samples] or [samples]
sr: Sample rate
Returns:
Normalized audio tensor [2, samples] at 48kHz
"""
# Convert to stereo (duplicate channel if mono)
if audio.shape[0] == 1:
audio = torch.cat([audio, audio], dim=0)
# Keep only first 2 channels
audio = audio[:2]
# Resample to 48kHz if needed
if sr != 48000:
audio = torchaudio.transforms.Resample(sr, 48000)(audio)
# Clamp values to [-1.0, 1.0]
audio = torch.clamp(audio, -1.0, 1.0)
return audio
def _enforce_audio_duration_limits(
self,
audio: torch.Tensor,
sample_rate: int = 48000,
min_duration: float = 10.0,
max_duration: float = 600.0
) -> torch.Tensor:
"""
Enforce audio duration limits by truncating or repeating.
Args:
audio: Audio tensor [channels, samples] at target sample rate
sample_rate: Sample rate of the audio (default: 48000)
min_duration: Minimum duration in seconds (default: 10.0)
max_duration: Maximum duration in seconds (default: 600.0)
Returns:
Audio tensor with enforced duration limits
"""
current_samples = audio.shape[-1]
current_duration = current_samples / sample_rate
min_samples = int(min_duration * sample_rate)
max_samples = int(max_duration * sample_rate)
# If audio is longer than max_duration, truncate
if current_samples > max_samples:
logger.info(f"[_enforce_audio_duration_limits] Truncating audio from {current_duration:.1f}s to {max_duration:.1f}s")
audio = audio[..., :max_samples]
# If audio is shorter than min_duration, repeat to fill
elif current_samples < min_samples:
logger.info(f"[_enforce_audio_duration_limits] Repeating audio from {current_duration:.1f}s to reach {min_duration:.1f}s")
# Calculate how many times to repeat
repeat_times = int(math.ceil(min_samples / current_samples))
# Repeat along the time dimension
audio = audio.repeat(1, repeat_times)
# Truncate to exactly min_samples
audio = audio[..., :min_samples]
return audio
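    # Example (illustrative): a 4 s stereo clip at 48 kHz (192000 samples) is below
    # min_duration=10 s, so it is tiled ceil(480000/192000) = 3 times and then
    # truncated to exactly 480000 samples (10 s).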
def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]:
"""Normalize audio_code_hints to list of correct length."""
if audio_code_hints is None:
normalized = [None] * batch_size
elif isinstance(audio_code_hints, str):
normalized = [audio_code_hints] * batch_size
elif len(audio_code_hints) == 1 and batch_size > 1:
normalized = audio_code_hints * batch_size
elif len(audio_code_hints) != batch_size:
# Pad or truncate to match batch_size
normalized = list(audio_code_hints[:batch_size])
while len(normalized) < batch_size:
normalized.append(None)
else:
normalized = list(audio_code_hints)
# Clean up: convert empty strings to None
normalized = [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized]
return normalized
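    # Examples (illustrative):
    #   _normalize_audio_code_hints("<|audio_code_1|>", 2)       -> [hint, hint]  (broadcast)
    #   _normalize_audio_code_hints(["<|audio_code_1|>", ""], 2) -> [hint, None]  ("" -> None)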
def _normalize_instructions(self, instructions: Optional[Union[str, List[str]]], batch_size: int, default: Optional[str] = None) -> List[str]:
"""Normalize instructions to list of correct length."""
if instructions is None:
default_instruction = default or DEFAULT_DIT_INSTRUCTION
return [default_instruction] * batch_size
elif isinstance(instructions, str):
return [instructions] * batch_size
elif len(instructions) == 1:
return instructions * batch_size
elif len(instructions) != batch_size:
# Pad or truncate to match batch_size
normalized = list(instructions[:batch_size])
default_instruction = default or DEFAULT_DIT_INSTRUCTION
while len(normalized) < batch_size:
normalized.append(default_instruction)
return normalized
else:
return list(instructions)
def _format_lyrics(self, lyrics: str, language: str) -> str:
"""Format lyrics text with language header."""
return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>"
def _pad_sequences(self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0) -> torch.Tensor:
"""Pad sequences to same length."""
return torch.stack([
torch.nn.functional.pad(seq, (0, max_length - len(seq)), 'constant', pad_value)
for seq in sequences
])
def _extract_caption_and_language(self, metas: List[Union[str, Dict[str, Any]]], captions: List[str], vocal_languages: List[str]) -> Tuple[List[str], List[str]]:
"""Extract caption and language from metas with fallback to provided values."""
actual_captions = list(captions)
actual_languages = list(vocal_languages)
for i, meta in enumerate(metas):
if i >= len(actual_captions):
break
            meta_dict = None
            # Only dict metas can override; string metas are already formatted text
            # (see _parse_metas) and never yield dicts
            if isinstance(meta, dict):
                meta_dict = meta
if meta_dict:
if 'caption' in meta_dict and meta_dict['caption']:
actual_captions[i] = str(meta_dict['caption'])
if 'language' in meta_dict and meta_dict['language']:
actual_languages[i] = str(meta_dict['language'])
return actual_captions, actual_languages
def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor:
"""
Encode audio to latents using VAE.
Args:
audio: Audio tensor [channels, samples] or [batch, channels, samples]
Returns:
Latents tensor [T, D] or [batch, T, D]
"""
# Save original dimension info BEFORE modifying audio
input_was_2d = (audio.dim() == 2)
# Ensure batch dimension
if input_was_2d:
audio = audio.unsqueeze(0)
# Ensure input is in VAE's dtype
vae_input = audio.to(self.device).to(self.vae.dtype)
# Encode to latents
with torch.no_grad():
latents = self.vae.encode(vae_input).latent_dist.sample()
# Cast back to model dtype
latents = latents.to(self.dtype)
# Transpose: [batch, d, T] -> [batch, T, d]
latents = latents.transpose(1, 2)
# Remove batch dimension if input didn't have it
if input_was_2d:
latents = latents.squeeze(0)
return latents
def _build_metadata_dict(self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str, duration: Optional[float] = None) -> Dict[str, Any]:
"""
Build metadata dictionary with default values.
Args:
bpm: BPM value (optional)
key_scale: Key/scale string
time_signature: Time signature string
duration: Duration in seconds (optional)
Returns:
Metadata dictionary
"""
metadata_dict = {}
if bpm:
metadata_dict["bpm"] = bpm
else:
metadata_dict["bpm"] = "N/A"
if key_scale.strip():
metadata_dict["keyscale"] = key_scale
else:
metadata_dict["keyscale"] = "N/A"
        if time_signature.strip() and time_signature != "N/A":
metadata_dict["timesignature"] = time_signature
else:
metadata_dict["timesignature"] = "N/A"
# Add duration if provided
if duration is not None:
metadata_dict["duration"] = f"{int(duration)} seconds"
return metadata_dict
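    # Example (illustrative):
    #   _build_metadata_dict(bpm=None, key_scale="", time_signature="4/4", duration=30.5)
    #   -> {"bpm": "N/A", "keyscale": "N/A", "timesignature": "4/4", "duration": "30 seconds"}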
def generate_instruction(
self,
task_type: str,
track_name: Optional[str] = None,
complete_track_classes: Optional[List[str]] = None
) -> str:
if task_type == "text2music":
return TASK_INSTRUCTIONS["text2music"]
elif task_type == "repaint":
return TASK_INSTRUCTIONS["repaint"]
elif task_type == "cover":
return TASK_INSTRUCTIONS["cover"]
elif task_type == "extract":
if track_name:
# Convert to uppercase
track_name_upper = track_name.upper()
return TASK_INSTRUCTIONS["extract"].format(TRACK_NAME=track_name_upper)
else:
return TASK_INSTRUCTIONS["extract_default"]
elif task_type == "lego":
if track_name:
# Convert to uppercase
track_name_upper = track_name.upper()
return TASK_INSTRUCTIONS["lego"].format(TRACK_NAME=track_name_upper)
else:
return TASK_INSTRUCTIONS["lego_default"]
elif task_type == "complete":
if complete_track_classes and len(complete_track_classes) > 0:
# Convert to uppercase and join with " | "
track_classes_upper = [t.upper() for t in complete_track_classes]
complete_track_classes_str = " | ".join(track_classes_upper)
return TASK_INSTRUCTIONS["complete"].format(TRACK_CLASSES=complete_track_classes_str)
else:
return TASK_INSTRUCTIONS["complete_default"]
else:
return TASK_INSTRUCTIONS["text2music"]
    def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
        """Load reference audio and build a 30 s stereo 48 kHz clip from three random 10 s segments."""
        if audio_file is None:
            return None
try:
# Load audio file with fallback backends
audio, sr = self._load_audio_file(audio_file)
logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}")
logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
logger.debug(f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / 48000.0} seconds")
# Normalize to stereo 48kHz
audio = self._normalize_audio_to_stereo_48k(audio, sr)
is_silence = self.is_silence(audio)
if is_silence:
return None
# Target length: 30 seconds at 48kHz
target_frames = 30 * 48000
segment_frames = 10 * 48000 # 10 seconds per segment
# If audio is less than 30 seconds, repeat to at least 30 seconds
if audio.shape[-1] < target_frames:
repeat_times = math.ceil(target_frames / audio.shape[-1])
audio = audio.repeat(1, repeat_times)
# If audio is greater than or equal to 30 seconds, no operation needed
# For all cases, select random 10-second segments from front, middle, and back
# then concatenate them to form 30 seconds
total_frames = audio.shape[-1]
segment_size = total_frames // 3
# Front segment: [0, segment_size]
front_start = random.randint(0, max(0, segment_size - segment_frames))
front_audio = audio[:, front_start:front_start + segment_frames]
# Middle segment: [segment_size, 2*segment_size]
middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
middle_audio = audio[:, middle_start:middle_start + segment_frames]
# Back segment: [2*segment_size, total_frames]
back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
back_audio = audio[:, back_start:back_start + segment_frames]
# Concatenate three segments to form 30 seconds
audio = torch.cat([front_audio, middle_audio, back_audio], dim=-1)
return audio
except Exception as e:
logger.exception("[process_reference_audio] Error processing reference audio")
return None
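    # Example (illustrative): a 90 s input is split into three 30 s thirds; one random
    # 10 s window is drawn from each third and the three windows are concatenated
    # into the final 30 s reference clip.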
    def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
        """Load source audio, normalize to stereo 48 kHz, and enforce the 10-600 s duration limits."""
        if audio_file is None:
            return None
try:
# Load audio file with fallback backends
audio, sr = self._load_audio_file(audio_file)
# Normalize to stereo 48kHz
audio = self._normalize_audio_to_stereo_48k(audio, sr)
# Enforce duration limits (10-600 seconds)
audio = self._enforce_audio_duration_limits(audio)
return audio
except Exception as e:
logger.exception("[process_src_audio] Error processing source audio")
return None
def convert_src_audio_to_codes(self, audio_file) -> str:
"""
Convert uploaded source audio to audio codes string.
Args:
audio_file: Path to audio file or None
Returns:
Formatted codes string like '<|audio_code_123|><|audio_code_456|>...' or error message
"""
if audio_file is None:
return "❌ Please upload source audio first"
if self.model is None or self.vae is None:
return "❌ Model not initialized. Please initialize the service first."
try:
# Process audio file
processed_audio = self.process_src_audio(audio_file)
if processed_audio is None:
return "❌ Failed to process audio file"
# Encode audio to latents using VAE
with torch.no_grad():
with self._load_model_context("vae"):
# Check if audio is silence
if self.is_silence(processed_audio.unsqueeze(0)):
return "❌ Audio file appears to be silent"
# Encode to latents using helper method
latents = self._encode_audio_to_latents(processed_audio) # [T, d]
# Create attention mask for latents
attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
# Tokenize latents to get code indices
with self._load_model_context("model"):
# Prepare latents for tokenize: [T, d] -> [1, T, d]
hidden_states = latents.unsqueeze(0) # [1, T, d]
# Call tokenize method
# tokenize returns: (quantized, indices, attention_mask)
_, indices, _ = self.model.tokenize(hidden_states, self.silence_latent, attention_mask.unsqueeze(0))
# Format indices as code string
# indices shape: [1, T_5Hz] or [1, T_5Hz, num_quantizers]
# Flatten and convert to list
indices_flat = indices.flatten().cpu().tolist()
codes_string = "".join([f"<|audio_code_{idx}|>" for idx in indices_flat])
logger.info(f"[convert_src_audio_to_codes] Generated {len(indices_flat)} audio codes")
return codes_string
except Exception as e:
error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
logger.exception("[convert_src_audio_to_codes] Error converting audio to codes")
return error_msg
def prepare_batch_data(
self,
actual_batch_size,
processed_src_audio,
audio_duration,
captions,
lyrics,
vocal_language,
instruction,
bpm,
key_scale,
time_signature
):
        """Broadcast single-item inputs to batch lists and build per-item metadata dicts."""
        pure_caption = self.extract_caption_from_sft_format(captions)
captions_batch = [pure_caption] * actual_batch_size
instructions_batch = [instruction] * actual_batch_size
lyrics_batch = [lyrics] * actual_batch_size
vocal_languages_batch = [vocal_language] * actual_batch_size
# Calculate duration for metadata
calculated_duration = None
if processed_src_audio is not None:
calculated_duration = processed_src_audio.shape[-1] / 48000.0
elif audio_duration is not None and audio_duration > 0:
calculated_duration = audio_duration
# Build metadata dict - use "N/A" as default for empty fields
metadata_dict = self._build_metadata_dict(bpm, key_scale, time_signature, calculated_duration)
# Format metadata - inference service accepts dict and will convert to string
# Create a copy for each batch item (in case we modify it)
metas_batch = [metadata_dict.copy() for _ in range(actual_batch_size)]
return captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch
    def determine_task_type(self, task_type, audio_code_string):
        """Derive task flags from the task name; any non-empty audio codes force the cover path."""
        # Determine task type - repaint and lego tasks can have repainting parameters
        # Other tasks (cover, text2music, extract, complete) should NOT have repainting
is_repaint_task = (task_type == "repaint")
is_lego_task = (task_type == "lego")
is_cover_task = (task_type == "cover")
has_codes = False
if isinstance(audio_code_string, list):
has_codes = any((c or "").strip() for c in audio_code_string)
else:
has_codes = bool(audio_code_string and str(audio_code_string).strip())
if has_codes:
is_cover_task = True
# Both repaint and lego tasks can use repainting parameters for chunk mask
can_use_repainting = is_repaint_task or is_lego_task
return is_repaint_task, is_lego_task, is_cover_task, can_use_repainting
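    # Examples (illustrative):
    #   determine_task_type("repaint", "")                    -> (True, False, False, True)
    #   determine_task_type("text2music", "<|audio_code_1|>") -> (False, False, True, False)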
    def create_target_wavs(self, duration_seconds: float) -> torch.Tensor:
        """Create silent stereo 48 kHz audio of the requested duration (minimum 0.1 s)."""
        try:
# Ensure minimum precision of 100ms
duration_seconds = max(0.1, round(duration_seconds, 1))
# Calculate frames for 48kHz stereo
frames = int(duration_seconds * 48000)
# Create silent stereo audio
target_wavs = torch.zeros(2, frames)
return target_wavs
except Exception as e:
logger.exception("[create_target_wavs] Error creating target audio")
# Fallback to 30 seconds if error
return torch.zeros(2, 30 * 48000)
def prepare_padding_info(
self,
actual_batch_size,
processed_src_audio,
audio_duration,
repainting_start,
repainting_end,
is_repaint_task,
is_lego_task,
is_cover_task,
can_use_repainting,
):
        """Build per-item target waveforms and padding-adjusted repainting bounds."""
        target_wavs_batch = []
        # Store padding info for each batch item to adjust repainting coordinates
        padding_info_batch = []
for i in range(actual_batch_size):
if processed_src_audio is not None:
if is_cover_task:
# Cover task: Use src_audio directly without padding
batch_target_wavs = processed_src_audio
padding_info_batch.append({
'left_padding_duration': 0.0,
'right_padding_duration': 0.0
})
elif is_repaint_task or is_lego_task:
# Repaint/lego task: May need padding for outpainting
src_audio_duration = processed_src_audio.shape[-1] / 48000.0
# Determine actual end time
if repainting_end is None or repainting_end < 0:
actual_end = src_audio_duration
else:
actual_end = repainting_end
left_padding_duration = max(0, -repainting_start) if repainting_start is not None else 0
right_padding_duration = max(0, actual_end - src_audio_duration)
# Create padded audio
left_padding_frames = int(left_padding_duration * 48000)
right_padding_frames = int(right_padding_duration * 48000)
if left_padding_frames > 0 or right_padding_frames > 0:
# Pad the src audio
batch_target_wavs = torch.nn.functional.pad(
processed_src_audio,
(left_padding_frames, right_padding_frames),
'constant', 0
)
else:
batch_target_wavs = processed_src_audio
# Store padding info for coordinate adjustment
padding_info_batch.append({
'left_padding_duration': left_padding_duration,
'right_padding_duration': right_padding_duration
})
else:
# Other tasks: Use src_audio directly without padding
batch_target_wavs = processed_src_audio
padding_info_batch.append({
'left_padding_duration': 0.0,
'right_padding_duration': 0.0
})
else:
padding_info_batch.append({
'left_padding_duration': 0.0,
'right_padding_duration': 0.0
})
if audio_duration is not None and audio_duration > 0:
batch_target_wavs = self.create_target_wavs(audio_duration)
                else:
                    random_duration = random.uniform(10.0, 120.0)
                    batch_target_wavs = self.create_target_wavs(random_duration)
target_wavs_batch.append(batch_target_wavs)
# Stack target_wavs into batch tensor
# Ensure all tensors have the same shape by padding to max length
max_frames = max(wav.shape[-1] for wav in target_wavs_batch)
padded_target_wavs = []
for wav in target_wavs_batch:
if wav.shape[-1] < max_frames:
pad_frames = max_frames - wav.shape[-1]
padded_wav = torch.nn.functional.pad(wav, (0, pad_frames), 'constant', 0)
padded_target_wavs.append(padded_wav)
else:
padded_target_wavs.append(wav)
target_wavs_tensor = torch.stack(padded_target_wavs, dim=0) # [batch_size, 2, frames]
if can_use_repainting:
# Repaint task: Set repainting parameters
if repainting_start is None:
repainting_start_batch = None
elif isinstance(repainting_start, (int, float)):
if processed_src_audio is not None:
adjusted_start = repainting_start + padding_info_batch[0]['left_padding_duration']
repainting_start_batch = [adjusted_start] * actual_batch_size
else:
repainting_start_batch = [repainting_start] * actual_batch_size
else:
# List input - adjust each item
repainting_start_batch = []
for i in range(actual_batch_size):
if processed_src_audio is not None:
adjusted_start = repainting_start[i] + padding_info_batch[i]['left_padding_duration']
repainting_start_batch.append(adjusted_start)
else:
repainting_start_batch.append(repainting_start[i])
# Handle repainting_end - use src audio duration if not specified or negative
if processed_src_audio is not None:
# If src audio is provided, use its duration as default end
src_audio_duration = processed_src_audio.shape[-1] / 48000.0
if repainting_end is None or repainting_end < 0:
# Use src audio duration (before padding), then adjust for padding
adjusted_end = src_audio_duration + padding_info_batch[0]['left_padding_duration']
repainting_end_batch = [adjusted_end] * actual_batch_size
else:
# Adjust repainting_end to be relative to padded audio
adjusted_end = repainting_end + padding_info_batch[0]['left_padding_duration']
repainting_end_batch = [adjusted_end] * actual_batch_size
else:
# No src audio - repainting doesn't make sense without it
if repainting_end is None or repainting_end < 0:
repainting_end_batch = None
elif isinstance(repainting_end, (int, float)):
repainting_end_batch = [repainting_end] * actual_batch_size
else:
# List input - no src audio in this branch, so no padding adjustment is needed
repainting_end_batch = [repainting_end[i] for i in range(actual_batch_size)]
else:
# All other tasks (cover, text2music, extract, complete): No repainting
# Only repaint and lego tasks should have repainting parameters
repainting_start_batch = None
repainting_end_batch = None
return repainting_start_batch, repainting_end_batch, target_wavs_tensor
def _prepare_batch(
self,
captions: List[str],
lyrics: List[str],
keys: Optional[List[str]] = None,
target_wavs: Optional[torch.Tensor] = None,
refer_audios: Optional[List[List[torch.Tensor]]] = None,
metas: Optional[List[Union[str, Dict[str, Any]]]] = None,
vocal_languages: Optional[List[str]] = None,
repainting_start: Optional[List[float]] = None,
repainting_end: Optional[List[float]] = None,
instructions: Optional[List[str]] = None,
audio_code_hints: Optional[List[Optional[str]]] = None,
audio_cover_strength: float = 1.0,
) -> Dict[str, Any]:
"""
Prepare batch data with fallbacks for missing inputs.
Args:
captions: List of text captions (optional, can be empty strings)
lyrics: List of lyrics (optional, can be empty strings)
keys: List of unique identifiers (optional)
target_wavs: Target audio tensors (optional, will use silence if not provided)
refer_audios: Reference audio tensors (optional, will use silence if not provided)
metas: Metadata (optional, will use defaults if not provided)
vocal_languages: Vocal languages (optional, will default to 'en')
repainting_start: Start time(s) of the repainting region in seconds (optional)
repainting_end: End time(s) of the repainting region in seconds (optional)
instructions: Per-item instruction text(s) (optional, defaults to DEFAULT_DIT_INSTRUCTION)
audio_code_hints: Per-item audio code strings to decode into latent hints (optional)
audio_cover_strength: Strength of audio cover conditioning (default: 1.0)
Returns:
Batch dictionary ready for model input
"""
batch_size = len(captions)
# Fall back to full-silence target audio when none is provided, per the
# docstring contract (the 30 s default duration here is an arbitrary choice)
if target_wavs is None:
target_wavs = torch.stack([self.create_target_wavs(30.0)] * batch_size)
# Ensure silence_latent is on the correct device for batch preparation
self._ensure_silence_latent_on_device()
# Normalize audio_code_hints to batch list
audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
# Synchronize CUDA to catch any pending errors from previous operations
if torch.cuda.is_available():
torch.cuda.synchronize()
# Fall back to per-item silence reference audio when none is provided,
# matching the "will use silence" contract in the docstring above
if refer_audios is None:
refer_audios = [[torch.zeros(2, 30 * self.sample_rate)] for _ in range(batch_size)]
for ii, refer_audio_list in enumerate(refer_audios):
if refer_audio_list is None:
continue
if isinstance(refer_audio_list, list):
for idx, refer_audio in enumerate(refer_audio_list):
if refer_audio is not None and isinstance(refer_audio, torch.Tensor):
refer_audio_list[idx] = refer_audio.to(self.device).to(torch.bfloat16)
elif isinstance(refer_audio_list, torch.Tensor):
refer_audios[ii] = refer_audio_list.to(self.device)
if vocal_languages is None:
vocal_languages = self._create_fallback_vocal_languages(batch_size)
# Parse metas with fallbacks
parsed_metas = self._parse_metas(metas)
# Encode target_wavs to get target_latents
with torch.no_grad():
target_latents_list = []
latent_lengths = []
# Use per-item wavs (may be adjusted if audio_code_hints are provided)
target_wavs_list = [target_wavs[i].clone() for i in range(batch_size)]
if target_wavs.device != self.device:
target_wavs = target_wavs.to(self.device)
with self._load_model_context("vae"):
for i in range(batch_size):
code_hint = audio_code_hints[i]
# Prefer decoding from provided audio codes
if code_hint:
logger.info(f"[generate_music] Decoding audio codes for item {i}...")
decoded_latents = self._decode_audio_codes_to_latents(code_hint)
if decoded_latents is not None:
decoded_latents = decoded_latents.squeeze(0)
target_latents_list.append(decoded_latents)
latent_lengths.append(decoded_latents.shape[0])
# Create a silent wav matching the latent length for downstream scaling
frames_from_codes = max(1, int(decoded_latents.shape[0] * 1920))
target_wavs_list[i] = torch.zeros(2, frames_from_codes)
continue
# Fallback to VAE encode from audio
current_wav = target_wavs_list[i].to(self.device).unsqueeze(0)
if self.is_silence(current_wav):
expected_latent_length = current_wav.shape[-1] // 1920
target_latent = self.silence_latent[0, :expected_latent_length, :]
else:
# Encode using helper method
logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
target_latent = self._encode_audio_to_latents(current_wav.squeeze(0)) # Remove batch dim for helper
target_latents_list.append(target_latent)
latent_lengths.append(target_latent.shape[0])
# Pad target_wavs to consistent length for outputs
max_target_frames = max(wav.shape[-1] for wav in target_wavs_list)
padded_target_wavs = []
for wav in target_wavs_list:
if wav.shape[-1] < max_target_frames:
pad_frames = max_target_frames - wav.shape[-1]
wav = torch.nn.functional.pad(wav, (0, pad_frames), "constant", 0)
padded_target_wavs.append(wav)
target_wavs = torch.stack(padded_target_wavs)
wav_lengths = torch.tensor([target_wavs.shape[-1]] * batch_size, dtype=torch.long)
# Pad latents to same length
max_latent_length = max(latent.shape[0] for latent in target_latents_list)
max_latent_length = max(128, max_latent_length)
padded_latents = []
for latent in target_latents_list:
if latent.shape[0] < max_latent_length:
pad_length = max_latent_length - latent.shape[0]
latent = torch.cat([latent, self.silence_latent[0, :pad_length, :]], dim=0)
padded_latents.append(latent)
target_latents = torch.stack(padded_latents)
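# latent_masks: 1 marks valid latent frames, 0 marks each item's padded tail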
latent_masks = torch.stack([
torch.cat([
torch.ones(l, dtype=torch.long, device=self.device),
torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device)
])
for l in latent_lengths
])
# Process instructions early so we can use them for task type detection
# Use custom instructions if provided, otherwise use default
instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION)
# Generate chunk_masks and spans based on repainting parameters
# Also determine if this is a cover task (target audio provided without repainting)
chunk_masks = []
spans = []
is_covers = []
# Store repainting latent ranges for later use in src_latents creation
repainting_ranges = {} # {batch_idx: (start_latent, end_latent)}
for i in range(batch_size):
has_code_hint = audio_code_hints[i] is not None
# Check if repainting is enabled for this batch item
if repainting_start is not None and repainting_end is not None:
start_sec = repainting_start[i] if repainting_start[i] is not None else 0.0
end_sec = repainting_end[i]
if end_sec is not None and end_sec > start_sec:
# Repainting mode with outpainting support
# The target_wavs may have been padded for outpainting
# Need to calculate the actual position in the padded audio
# Calculate padding (if start < 0, there's left padding)
left_padding_sec = max(0, -start_sec)
# Adjust positions to account for padding
# In the padded audio, the original start is shifted by left_padding
adjusted_start_sec = start_sec + left_padding_sec
adjusted_end_sec = end_sec + left_padding_sec
# Convert seconds to latent frames (audio_frames / 1920 = latent_frames)
start_latent = int(adjusted_start_sec * self.sample_rate // 1920)
end_latent = int(adjusted_end_sec * self.sample_rate // 1920)
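# e.g. a 10 s region: int(10 * 48000 // 1920) = 250 latent frames (25 Hz latent rate)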
# Clamp to valid range
start_latent = max(0, min(start_latent, max_latent_length - 1))
end_latent = max(start_latent + 1, min(end_latent, max_latent_length))
# Create mask: False = keep original, True = generate new
mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device)
mask[start_latent:end_latent] = True
chunk_masks.append(mask)
spans.append(("repainting", start_latent, end_latent))
# Store repainting range for later use
repainting_ranges[i] = (start_latent, end_latent)
is_covers.append(False) # Repainting is not cover task
else:
# Full generation (no valid repainting range)
chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
spans.append(("full", 0, max_latent_length))
# Determine task type from instruction, not from target_wavs
# Only cover task should have is_cover=True
instruction_i = instructions[i] if instructions and i < len(instructions) else ""
instruction_lower = instruction_i.lower()
# Cover task instruction: "Generate audio semantic tokens based on the given conditions:"
is_cover = ("generate audio semantic tokens" in instruction_lower and
"based on the given conditions" in instruction_lower) or has_code_hint
is_covers.append(is_cover)
else:
# Full generation (no repainting parameters)
chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
spans.append(("full", 0, max_latent_length))
# Determine task type from instruction, not from target_wavs
# Only cover task should have is_cover=True
instruction_i = instructions[i] if instructions and i < len(instructions) else ""
instruction_lower = instruction_i.lower()
# Cover task instruction: "Generate audio semantic tokens based on the given conditions:"
is_cover = ("generate audio semantic tokens" in instruction_lower and
"based on the given conditions" in instruction_lower) or has_code_hint
is_covers.append(is_cover)
chunk_masks = torch.stack(chunk_masks)
is_covers = torch.tensor(is_covers, dtype=torch.bool, device=self.device)
# Create src_latents based on task type
# For cover/extract/complete/lego/repaint tasks: src_latents = target_latents.clone() (if target_wavs provided)
# For text2music task: src_latents = silence_latent (if no target_wavs or silence)
# For repaint task: additionally replace inpainting region with silence_latent
src_latents_list = []
silence_latent_tiled = self.silence_latent[0, :max_latent_length, :]
for i in range(batch_size):
# Check if target_wavs is provided and not silent (for extract/complete/lego/cover/repaint tasks)
has_code_hint = audio_code_hints[i] is not None
has_target_audio = has_code_hint or (target_wavs is not None and target_wavs[i].abs().sum() > 1e-6)
if has_target_audio:
# For tasks that use input audio (cover/extract/complete/lego/repaint)
# Check if this item has repainting
item_has_repainting = (i in repainting_ranges)
if item_has_repainting:
# Repaint task: src_latents = target_latents with inpainting region replaced by silence_latent
# 1. Clone target_latents (encoded from src audio, preserving original audio)
src_latent = target_latents[i].clone()
# 2. Replace inpainting region with silence_latent
start_latent, end_latent = repainting_ranges[i]
src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent]
src_latents_list.append(src_latent)
else:
# Cover/extract/complete/lego tasks: src_latents = target_latents.clone()
# All these tasks need to base on input audio
src_latents_list.append(target_latents[i].clone())
else:
# Text2music task: src_latents = silence_latent (no input audio)
# Use silence_latent for the full length
src_latents_list.append(silence_latent_tiled.clone())
src_latents = torch.stack(src_latents_list)
# Process audio_code_hints to generate precomputed_lm_hints_25Hz
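# ("25Hz" is the latent frame rate: 48000 samples/s divided by 1920 samples per latent frame)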
precomputed_lm_hints_25Hz_list = []
for i in range(batch_size):
if audio_code_hints[i] is not None:
# Decode audio codes to 25Hz latents
logger.info(f"[generate_music] Decoding audio codes for LM hints for item {i}...")
hints = self._decode_audio_codes_to_latents(audio_code_hints[i])
if hints is not None:
# Normalize dims first so the time axis is always dim 1
# (hints is usually [1, T, D]; silence_latent is [1, T, D])
if hints.dim() == 2:
hints = hints.unsqueeze(0)
# Pad or crop along the time axis to match max_latent_length
if hints.shape[1] < max_latent_length:
pad_length = max_latent_length - hints.shape[1]
pad = self.silence_latent
if pad.dim() == 2:
pad = pad.unsqueeze(0)
pad_chunk = pad[:, :pad_length, :]
if pad_chunk.device != hints.device or pad_chunk.dtype != hints.dtype:
pad_chunk = pad_chunk.to(device=hints.device, dtype=hints.dtype)
hints = torch.cat([hints, pad_chunk], dim=1)
elif hints.shape[1] > max_latent_length:
hints = hints[:, :max_latent_length, :]
precomputed_lm_hints_25Hz_list.append(hints[0]) # Remove batch dimension
else:
precomputed_lm_hints_25Hz_list.append(None)
else:
precomputed_lm_hints_25Hz_list.append(None)
# Stack precomputed hints if any exist, otherwise set to None
if any(h is not None for h in precomputed_lm_hints_25Hz_list):
# For items without hints, use silence_latent as placeholder
precomputed_lm_hints_25Hz = torch.stack([
h if h is not None else silence_latent_tiled
for h in precomputed_lm_hints_25Hz_list
])
else:
precomputed_lm_hints_25Hz = None
# Extract caption and language from metas if available (from LM CoT output)
# Fallback to user-provided values if not in metas
actual_captions, actual_languages = self._extract_caption_and_language(parsed_metas, captions, vocal_languages)
# Format text_inputs
text_inputs = []
text_token_idss = []
text_attention_masks = []
lyric_token_idss = []
lyric_attention_masks = []
for i in range(batch_size):
# Use custom instruction for this batch item
instruction = self._format_instruction(instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION)
actual_caption = actual_captions[i]
actual_language = actual_languages[i]
# Format text prompt with custom instruction (using LM-generated caption if available)
text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
# Tokenize text
text_inputs_dict = self.text_tokenizer(
text_prompt,
padding="longest",
truncation=True,
max_length=256,
return_tensors="pt",
)
text_token_ids = text_inputs_dict.input_ids[0]
text_attention_mask = text_inputs_dict.attention_mask[0].bool()
# Format and tokenize lyrics (using LM-generated language if available)
lyrics_text = self._format_lyrics(lyrics[i], actual_language)
lyrics_inputs_dict = self.text_tokenizer(
lyrics_text,
padding="longest",
truncation=True,
max_length=2048,
return_tensors="pt",
)
lyric_token_ids = lyrics_inputs_dict.input_ids[0]
lyric_attention_mask = lyrics_inputs_dict.attention_mask[0].bool()
# Build full text input
text_input = text_prompt + "\n\n" + lyrics_text
text_inputs.append(text_input)
text_token_idss.append(text_token_ids)
text_attention_masks.append(text_attention_mask)
lyric_token_idss.append(lyric_token_ids)
lyric_attention_masks.append(lyric_attention_mask)
# Pad tokenized sequences
max_text_length = max(len(seq) for seq in text_token_idss)
padded_text_token_idss = self._pad_sequences(text_token_idss, max_text_length, self.text_tokenizer.pad_token_id)
padded_text_attention_masks = self._pad_sequences(text_attention_masks, max_text_length, 0)
max_lyric_length = max(len(seq) for seq in lyric_token_idss)
padded_lyric_token_idss = self._pad_sequences(lyric_token_idss, max_lyric_length, self.text_tokenizer.pad_token_id)
padded_lyric_attention_masks = self._pad_sequences(lyric_attention_masks, max_lyric_length, 0)
padded_non_cover_text_input_ids = None
padded_non_cover_text_attention_masks = None
if audio_cover_strength < 1.0:
non_cover_text_input_ids = []
non_cover_text_attention_masks = []
for i in range(batch_size):
# Use custom instruction for this batch item
instruction = self._format_instruction(DEFAULT_DIT_INSTRUCTION)
# Extract caption from metas if available (from LM CoT output)
actual_caption = actual_captions[i]
# Format text prompt with custom instruction (using LM-generated caption if available)
text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
# Tokenize text
text_inputs_dict = self.text_tokenizer(
text_prompt,
padding="longest",
truncation=True,
max_length=256,
return_tensors="pt",
)
text_token_ids = text_inputs_dict.input_ids[0]
non_cover_text_attention_mask = text_inputs_dict.attention_mask[0].bool()
non_cover_text_input_ids.append(text_token_ids)
non_cover_text_attention_masks.append(non_cover_text_attention_mask)
padded_non_cover_text_input_ids = self._pad_sequences(non_cover_text_input_ids, max_text_length, self.text_tokenizer.pad_token_id)
padded_non_cover_text_attention_masks = self._pad_sequences(non_cover_text_attention_masks, max_text_length, 0)
if audio_cover_strength < 1.0:
assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None"
assert padded_non_cover_text_attention_masks is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_attention_masks must not be None"
# Prepare batch
batch = {
"keys": keys,
"target_wavs": target_wavs.to(self.device),
"refer_audioss": refer_audios,
"wav_lengths": wav_lengths.to(self.device),
"captions": captions,
"lyrics": lyrics,
"metas": parsed_metas,
"vocal_languages": vocal_languages,
"target_latents": target_latents,
"src_latents": src_latents,
"latent_masks": latent_masks,
"chunk_masks": chunk_masks,
"spans": spans,
"text_inputs": text_inputs,
"text_token_idss": padded_text_token_idss,
"text_attention_masks": padded_text_attention_masks,
"lyric_token_idss": padded_lyric_token_idss,
"lyric_attention_masks": padded_lyric_attention_masks,
"is_covers": is_covers,
"precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz,
"non_cover_text_input_ids": padded_non_cover_text_input_ids,
"non_cover_text_attention_masks": padded_non_cover_text_attention_masks,
}
# to device
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device)
if torch.is_floating_point(v):
batch[k] = v.to(self.dtype)
return batch
def infer_refer_latent(self, refer_audioss):
refer_audio_order_mask = []
refer_audio_latents = []
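# refer_audio_order_mask maps each packed latent row back to its source batch
# index, since a batch item may contribute a variable number of reference clips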
# Ensure silence_latent is on the correct device
self._ensure_silence_latent_on_device()
def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
"""Normalize audio tensor to [2, T] on current device."""
if not isinstance(a, torch.Tensor):
raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
# Accept [T], [1, T], [2, T], [1, 2, T]
if a.dim() == 3 and a.shape[0] == 1:
a = a.squeeze(0)
if a.dim() == 1:
a = a.unsqueeze(0)
if a.dim() != 2:
raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
if a.shape[0] == 1:
a = torch.cat([a, a], dim=0)
a = a[:2]
return a
def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
"""Ensure latent is [N, T, D] (3D) for packing."""
if z.dim() == 4 and z.shape[0] == 1:
z = z.squeeze(0)
if z.dim() == 2:
z = z.unsqueeze(0)
return z
for batch_idx, refer_audios in enumerate(refer_audioss):
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
refer_audio_latent = _ensure_latent_3d(self.silence_latent[:, :750, :])
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
else:
for refer_audio in refer_audios:
refer_audio = _normalize_audio_2d(refer_audio)
# Ensure input is in VAE's dtype
vae_input = refer_audio.unsqueeze(0).to(self.vae.dtype)
refer_audio_latent = self.vae.encode(vae_input).latent_dist.sample()
# Cast back to model dtype
refer_audio_latent = refer_audio_latent.to(self.dtype)
refer_audio_latents.append(_ensure_latent_3d(refer_audio_latent.transpose(1, 2)))
refer_audio_order_mask.append(batch_idx)
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=self.device, dtype=torch.long)
return refer_audio_latents, refer_audio_order_mask
def infer_text_embeddings(self, text_token_idss):
with torch.no_grad():
text_embeddings = self.text_encoder(input_ids=text_token_idss, lyric_attention_mask=None).last_hidden_state
return text_embeddings
def infer_lyric_embeddings(self, lyric_token_ids):
with torch.no_grad():
lyric_embeddings = self.text_encoder.embed_tokens(lyric_token_ids)
return lyric_embeddings
def preprocess_batch(self, batch):
# step 1: VAE-encoded latents, target_latents: N x T x d
target_latents = batch["target_latents"]
src_latents = batch["src_latents"]
attention_mask = batch["latent_masks"]
audio_codes = batch.get("audio_codes", None)
audio_attention_mask = attention_mask
dtype = target_latents.dtype
bs = target_latents.shape[0]
device = target_latents.device
# step 2: refer_audio timbre
keys = batch["keys"]
with self._load_model_context("vae"):
refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = self.infer_refer_latent(batch["refer_audioss"])
if refer_audio_acoustic_hidden_states_packed.dtype != dtype:
refer_audio_acoustic_hidden_states_packed = refer_audio_acoustic_hidden_states_packed.to(dtype)
# step 3: chunk mask, N x T x d
chunk_mask = batch["chunk_masks"]
chunk_mask = chunk_mask.to(device).unsqueeze(-1).repeat(1, 1, target_latents.shape[2])
spans = batch["spans"]
text_token_idss = batch["text_token_idss"]
text_attention_mask = batch["text_attention_masks"]
lyric_token_idss = batch["lyric_token_idss"]
lyric_attention_mask = batch["lyric_attention_masks"]
text_inputs = batch["text_inputs"]
logger.info("[preprocess_batch] Inferring prompt embeddings...")
with self._load_model_context("text_encoder"):
text_hidden_states = self.infer_text_embeddings(text_token_idss)
logger.info("[preprocess_batch] Inferring lyric embeddings...")
lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
is_covers = batch["is_covers"]
# Get precomputed hints from batch if available
precomputed_lm_hints_25Hz = batch.get("precomputed_lm_hints_25Hz", None)
# Get non-cover text input ids and attention masks from batch if available
non_cover_text_input_ids = batch.get("non_cover_text_input_ids", None)
non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
non_cover_text_hidden_states = None
if non_cover_text_input_ids is not None:
logger.info("[preprocess_batch] Inferring non-cover text embeddings...")
non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
return (
keys,
text_inputs,
src_latents,
target_latents,
# model inputs
text_hidden_states,
text_attention_mask,
lyric_hidden_states,
lyric_attention_mask,
audio_attention_mask,
refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask,
chunk_mask,
spans,
is_covers,
audio_codes,
lyric_token_idss,
precomputed_lm_hints_25Hz,
non_cover_text_hidden_states,
non_cover_text_attention_masks,
)
@torch.no_grad()
def service_generate(
self,
captions: Union[str, List[str]],
lyrics: Union[str, List[str]],
keys: Optional[Union[str, List[str]]] = None,
target_wavs: Optional[torch.Tensor] = None,
refer_audios: Optional[List[List[torch.Tensor]]] = None,
metas: Optional[Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]]] = None,
vocal_languages: Optional[Union[str, List[str]]] = None,
infer_steps: int = 60,
guidance_scale: float = 7.0,
seed: Optional[Union[int, List[int]]] = None,
return_intermediate: bool = False,
repainting_start: Optional[Union[float, List[float]]] = None,
repainting_end: Optional[Union[float, List[float]]] = None,
instructions: Optional[Union[str, List[str]]] = None,
audio_cover_strength: float = 1.0,
use_adg: bool = False,
cfg_interval_start: float = 0.0,
cfg_interval_end: float = 1.0,
shift: float = 1.0,
audio_code_hints: Optional[Union[str, List[str]]] = None,
infer_method: str = "ode",
timesteps: Optional[List[float]] = None,
) -> Dict[str, Any]:
"""
Generate music from text inputs.
Args:
captions: Text caption(s) describing the music (optional, can be empty strings)
lyrics: Lyric text(s) (optional, can be empty strings)
keys: Unique identifier(s) (optional)
target_wavs: Target audio tensor(s) for conditioning (optional)
refer_audios: Reference audio tensor(s) for style transfer (optional)
metas: Metadata dict(s) or string(s) (optional)
vocal_languages: Language code(s) for lyrics (optional, defaults to 'en')
infer_steps: Number of inference steps (default: 60)
guidance_scale: Guidance scale for generation (default: 7.0)
seed: Random seed (optional)
return_intermediate: Whether to return intermediate results (default: False)
repainting_start: Start time(s) for repainting region in seconds (optional)
repainting_end: End time(s) for repainting region in seconds (optional)
instructions: Instruction text(s) for generation (optional)
audio_cover_strength: Strength of audio cover mode (default: 1.0)
use_adg: Whether to use ADG (Adaptive Diffusion Guidance) (default: False)
cfg_interval_start: Start of CFG interval (0.0-1.0, default: 0.0)
cfg_interval_end: End of CFG interval (0.0-1.0, default: 1.0)
shift: Timestep shift factor for the sampler (default: 1.0)
audio_code_hints: Audio code string(s) to decode into latent hints (optional)
infer_method: Sampling method, "ode" or "sde" (default: "ode")
timesteps: Custom timestep schedule (optional)
Returns:
Dictionary containing:
- pred_wavs: Generated audio tensors
- target_wavs: Input target audio (if provided)
- vqvae_recon_wavs: VAE reconstruction of target
- keys: Identifiers used
- text_inputs: Formatted text inputs
- sr: Sample rate
- spans: Generation spans
- time_costs: Timing information
- seed_num: Seed used
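Example:
A minimal sketch (assumes an initialized AceStepHandler named `handler`;
the caption and the 30 s silent target are illustrative placeholders):
wavs = handler.create_target_wavs(30.0).unsqueeze(0)  # [1, 2, frames] of silence
out = handler.service_generate(
captions=["energetic synthwave"],
lyrics=[""],
target_wavs=wavs,
infer_steps=8,
)
pred_latents = out["target_latents"]  # decoded to audio with the VAE downstream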
"""
if self.config.is_turbo:
# Limit inference steps to maximum 8
if infer_steps > 8:
logger.warning(f"[service_generate] dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
infer_steps = 8
# CFG parameters are not adjustable for dmd_gan (they will be ignored)
# Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
# Convert single inputs to lists
if isinstance(captions, str):
captions = [captions]
if isinstance(lyrics, str):
lyrics = [lyrics]
if isinstance(keys, str):
keys = [keys]
if isinstance(vocal_languages, str):
vocal_languages = [vocal_languages]
if isinstance(metas, (str, dict)):
metas = [metas]
# Convert repainting parameters to lists
if isinstance(repainting_start, (int, float)):
repainting_start = [repainting_start]
if isinstance(repainting_end, (int, float)):
repainting_end = [repainting_end]
# Get batch size from captions
batch_size = len(captions)
# Normalize instructions and audio_code_hints to match batch size
instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) if instructions is not None else None
audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) if audio_code_hints is not None else None
# Convert seed to list format
if seed is None:
seed_list = None
elif isinstance(seed, list):
seed_list = seed
# Ensure we have enough seeds for batch size
if len(seed_list) < batch_size:
# Pad with random seeds until the list matches the batch size
# (random is imported at module level; a function-local import here would
# shadow the name and break the random fallback further down this function)
while len(seed_list) < batch_size:
seed_list.append(random.randint(0, 2**32 - 1))
elif len(seed_list) > batch_size:
# Truncate to batch size
seed_list = seed_list[:batch_size]
else:
# Single seed value - use for all batch items
seed_list = [int(seed)] * batch_size
# Don't set global random seed here - each item will use its own seed
# Prepare batch
batch = self._prepare_batch(
captions=captions,
lyrics=lyrics,
keys=keys,
target_wavs=target_wavs,
refer_audios=refer_audios,
metas=metas,
vocal_languages=vocal_languages,
repainting_start=repainting_start,
repainting_end=repainting_end,
instructions=instructions,
audio_code_hints=audio_code_hints,
audio_cover_strength=audio_cover_strength,
)
processed_data = self.preprocess_batch(batch)
(
keys,
text_inputs,
src_latents,
target_latents,
# model inputs
text_hidden_states,
text_attention_mask,
lyric_hidden_states,
lyric_attention_mask,
audio_attention_mask,
refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask,
chunk_mask,
spans,
is_covers,
audio_codes,
lyric_token_idss,
precomputed_lm_hints_25Hz,
non_cover_text_hidden_states,
non_cover_text_attention_masks,
) = processed_data
# Set generation parameters
# Use seed_list if available, otherwise generate a single seed
if seed_list is not None:
# Pass seed list to model (will be handled there)
seed_param = seed_list
else:
seed_param = random.randint(0, 2**32 - 1)
# Ensure silence_latent is on the correct device before creating generate_kwargs
self._ensure_silence_latent_on_device()
generate_kwargs = {
"text_hidden_states": text_hidden_states,
"text_attention_mask": text_attention_mask,
"lyric_hidden_states": lyric_hidden_states,
"lyric_attention_mask": lyric_attention_mask,
"refer_audio_acoustic_hidden_states_packed": refer_audio_acoustic_hidden_states_packed,
"refer_audio_order_mask": refer_audio_order_mask,
"src_latents": src_latents,
"chunk_masks": chunk_mask,
"is_covers": is_covers,
"silence_latent": self.silence_latent,
"seed": seed_param,
"non_cover_text_hidden_states": non_cover_text_hidden_states,
"non_cover_text_attention_mask": non_cover_text_attention_masks,
"precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz,
"audio_cover_strength": audio_cover_strength,
"infer_method": infer_method,
"infer_steps": infer_steps,
"diffusion_guidance_sale": guidance_scale,
"use_adg": use_adg,
"cfg_interval_start": cfg_interval_start,
"cfg_interval_end": cfg_interval_end,
"shift": shift,
}
# Add custom timesteps if provided (convert to tensor)
if timesteps is not None:
generate_kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32)
logger.info("[service_generate] Generating audio...")
with self._load_model_context("model"):
# Prepare condition tensors first (for LRC timestamp generation)
encoder_hidden_states, encoder_attention_mask, context_latents = self.model.prepare_condition(
text_hidden_states=text_hidden_states,
text_attention_mask=text_attention_mask,
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask=refer_audio_order_mask,
hidden_states=src_latents,
attention_mask=torch.ones(src_latents.shape[0], src_latents.shape[1], device=src_latents.device, dtype=src_latents.dtype),
silence_latent=self.silence_latent,
src_latents=src_latents,
chunk_masks=chunk_mask,
is_covers=is_covers,
precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz,
)
outputs = self.model.generate_audio(**generate_kwargs)
# Add intermediate information to outputs for extra_outputs
outputs["src_latents"] = src_latents
outputs["target_latents_input"] = target_latents # Input target latents (before generation)
outputs["chunk_masks"] = chunk_mask
outputs["spans"] = spans
outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length
# Add condition tensors for LRC timestamp generation
outputs["encoder_hidden_states"] = encoder_hidden_states
outputs["encoder_attention_mask"] = encoder_attention_mask
outputs["context_latents"] = context_latents
outputs["lyric_token_idss"] = lyric_token_idss
return outputs
def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=False):
"""
Decode latents using tiling to reduce VRAM usage.
Uses overlap-discard strategy to avoid boundary artifacts.
Args:
latents: [Batch, Channels, Length]
chunk_size: Size of latent chunk to process at once
overlap: Overlap size in latent frames
offload_wav_to_cpu: If True, offload decoded wav audio to CPU immediately to save VRAM
"""
B, C, T = latents.shape
# If short enough, decode directly
if T <= chunk_size:
return self.vae.decode(latents).sample
# Calculate stride (core size)
stride = chunk_size - 2 * overlap
if stride <= 0:
raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}")
num_steps = math.ceil(T / stride)
if offload_wav_to_cpu:
# Optimized path: offload wav to CPU immediately to save VRAM
return self._tiled_decode_offload_cpu(latents, B, T, stride, overlap, num_steps)
else:
# Default path: keep everything on GPU
return self._tiled_decode_gpu(latents, B, T, stride, overlap, num_steps)
def _tiled_decode_gpu(self, latents, B, T, stride, overlap, num_steps):
"""Standard tiled decode keeping all data on GPU."""
decoded_audio_list = []
upsample_factor = None
for i in tqdm(range(num_steps), desc="Decoding audio chunks"):
# Core range in latents
core_start = i * stride
core_end = min(core_start + stride, T)
# Window range (with overlap)
win_start = max(0, core_start - overlap)
win_end = min(T, core_end + overlap)
# Extract chunk
latent_chunk = latents[:, :, win_start:win_end]
# Decode
# [Batch, Channels, AudioSamples]
audio_chunk = self.vae.decode(latent_chunk).sample
# Determine upsample factor from the first chunk
if upsample_factor is None:
upsample_factor = audio_chunk.shape[-1] / latent_chunk.shape[-1]
# Calculate trim amounts in audio samples
# How much overlap was added at the start?
added_start = core_start - win_start # latent frames
trim_start = int(round(added_start * upsample_factor))
# How much overlap was added at the end?
added_end = win_end - core_end # latent frames
trim_end = int(round(added_end * upsample_factor))
# Trim audio
audio_len = audio_chunk.shape[-1]
end_idx = audio_len - trim_end if trim_end > 0 else audio_len
audio_core = audio_chunk[:, :, trim_start:end_idx]
decoded_audio_list.append(audio_core)
# Concatenate
final_audio = torch.cat(decoded_audio_list, dim=-1)
return final_audio
def _tiled_decode_offload_cpu(self, latents, B, T, stride, overlap, num_steps):
"""Optimized tiled decode that offloads to CPU immediately to save VRAM."""
# First pass: decode first chunk to get upsample_factor and audio channels
first_core_start = 0
first_core_end = min(stride, T)
first_win_start = 0
first_win_end = min(T, first_core_end + overlap)
first_latent_chunk = latents[:, :, first_win_start:first_win_end]
first_audio_chunk = self.vae.decode(first_latent_chunk).sample
upsample_factor = first_audio_chunk.shape[-1] / first_latent_chunk.shape[-1]
audio_channels = first_audio_chunk.shape[1]
# Calculate total audio length and pre-allocate CPU tensor
total_audio_length = int(round(T * upsample_factor))
final_audio = torch.zeros(B, audio_channels, total_audio_length,
dtype=first_audio_chunk.dtype, device='cpu')
# Process first chunk: trim and copy to CPU
first_added_end = first_win_end - first_core_end
first_trim_end = int(round(first_added_end * upsample_factor))
first_audio_len = first_audio_chunk.shape[-1]
first_end_idx = first_audio_len - first_trim_end if first_trim_end > 0 else first_audio_len
first_audio_core = first_audio_chunk[:, :, :first_end_idx]
audio_write_pos = first_audio_core.shape[-1]
final_audio[:, :, :audio_write_pos] = first_audio_core.cpu()
# Free GPU memory
del first_audio_chunk, first_audio_core, first_latent_chunk
# Process remaining chunks
for i in tqdm(range(1, num_steps), desc="Decoding audio chunks"):
# Core range in latents
core_start = i * stride
core_end = min(core_start + stride, T)
# Window range (with overlap)
win_start = max(0, core_start - overlap)
win_end = min(T, core_end + overlap)
# Extract chunk
latent_chunk = latents[:, :, win_start:win_end]
# Decode on GPU
# [Batch, Channels, AudioSamples]
audio_chunk = self.vae.decode(latent_chunk).sample
# Calculate trim amounts in audio samples
added_start = core_start - win_start # latent frames
trim_start = int(round(added_start * upsample_factor))
added_end = win_end - core_end # latent frames
trim_end = int(round(added_end * upsample_factor))
# Trim audio
audio_len = audio_chunk.shape[-1]
end_idx = audio_len - trim_end if trim_end > 0 else audio_len
audio_core = audio_chunk[:, :, trim_start:end_idx]
# Copy to pre-allocated CPU tensor
core_len = audio_core.shape[-1]
final_audio[:, :, audio_write_pos:audio_write_pos + core_len] = audio_core.cpu()
audio_write_pos += core_len
# Free GPU memory immediately
del audio_chunk, audio_core, latent_chunk
# Trim to actual length (in case of rounding differences)
final_audio = final_audio[:, :, :audio_write_pos]
return final_audio
def generate_music(
self,
captions: str,
lyrics: str,
bpm: Optional[int] = None,
key_scale: str = "",
time_signature: str = "",
vocal_language: str = "en",
inference_steps: int = 8,
guidance_scale: float = 7.0,
use_random_seed: bool = True,
seed: Optional[Union[str, float, int]] = -1,
reference_audio=None,
audio_duration: Optional[float] = None,
batch_size: Optional[int] = None,
src_audio=None,
audio_code_string: Union[str, List[str]] = "",
repainting_start: float = 0.0,
repainting_end: Optional[float] = None,
instruction: str = DEFAULT_DIT_INSTRUCTION,
audio_cover_strength: float = 1.0,
task_type: str = "text2music",
use_adg: bool = False,
cfg_interval_start: float = 0.0,
cfg_interval_end: float = 1.0,
shift: float = 1.0,
infer_method: str = "ode",
use_tiled_decode: bool = True,
timesteps: Optional[List[float]] = None,
progress=None
) -> Dict[str, Any]:
"""
Main interface for music generation
Returns:
Dictionary containing:
- audios: List of audio dictionaries with path, key, params
- generation_info: Markdown-formatted generation information
- status_message: Status message
- extra_outputs: Dictionary with latents, masks, time_costs, etc.
- success: Whether generation completed successfully
- error: Error message if generation failed
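Example:
A minimal sketch (assumes an initialized AceStepHandler named `handler`;
the caption text is an illustrative placeholder):
result = handler.generate_music(
captions="dreamy lo-fi hip hop, mellow piano",
lyrics="[instrumental]",
audio_duration=30.0,
batch_size=1,
)
if result["success"]:
audio = result["audios"][0]  # {"tensor": [channels, samples], "sample_rate": 48000}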
"""
if progress is None:
def progress(*args, **kwargs):
pass
if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
return {
"audios": [],
"status_message": "❌ Model not fully initialized. Please initialize all components first.",
"extra_outputs": {},
"success": False,
"error": "Model not fully initialized",
}
def _has_audio_codes(v: Union[str, List[str]]) -> bool:
if isinstance(v, list):
return any((x or "").strip() for x in v)
return bool(v and str(v).strip())
# Auto-detect task type based on audio_code_string
# If audio_code_string is provided and not empty, use cover task
# Otherwise, use text2music task (or keep current task_type if not text2music)
if task_type == "text2music":
if _has_audio_codes(audio_code_string):
# User has provided audio codes, switch to cover task
task_type = "cover"
# Update instruction for cover task
instruction = TASK_INSTRUCTIONS["cover"]
logger.info("[generate_music] Starting generation...")
if progress:
progress(0.51, desc="Preparing inputs...")
logger.info("[generate_music] Preparing inputs...")
# Reset offload cost
self.current_offload_cost = 0.0
# Caption and lyrics are optional - can be empty
# Use provided batch_size or default
actual_batch_size = batch_size if batch_size is not None else self.batch_size
actual_batch_size = max(1, actual_batch_size) # Ensure at least 1
actual_seed_list, seed_value_for_ui = self.prepare_seeds(actual_batch_size, seed, use_random_seed)
# Convert special values to None
if audio_duration is not None and audio_duration <= 0:
audio_duration = None
if repainting_end is not None and repainting_end < 0:
repainting_end = None
try:
# 1. Process reference audio
refer_audios = None
if reference_audio is not None:
logger.info("[generate_music] Processing reference audio...")
processed_ref_audio = self.process_reference_audio(reference_audio)
if processed_ref_audio is not None:
# Convert to the format expected by the service: List[List[torch.Tensor]]
# Each batch item has a list of reference audios
refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)]
else:
refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)]
# 2. Process source audio
# If audio_code_string is provided, ignore src_audio and use codes instead
processed_src_audio = None
if src_audio is not None:
# Check if audio codes are provided - if so, ignore src_audio
if _has_audio_codes(audio_code_string):
logger.info("[generate_music] Audio codes provided, ignoring src_audio and using codes instead")
else:
logger.info("[generate_music] Processing source audio...")
processed_src_audio = self.process_src_audio(src_audio)
# 3. Prepare batch data
captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data(
actual_batch_size,
processed_src_audio,
audio_duration,
captions,
lyrics,
vocal_language,
instruction,
bpm,
key_scale,
time_signature
)
is_repaint_task, is_lego_task, is_cover_task, can_use_repainting = self.determine_task_type(task_type, audio_code_string)
repainting_start_batch, repainting_end_batch, target_wavs_tensor = self.prepare_padding_info(
actual_batch_size,
processed_src_audio,
audio_duration,
repainting_start,
repainting_end,
is_repaint_task,
is_lego_task,
is_cover_task,
can_use_repainting
)
progress(0.52, desc=f"Generating music (batch size: {actual_batch_size})...")
# Prepare audio_code_hints - use if audio_code_string is provided
# This works for both text2music (auto-switched to cover) and cover tasks
audio_code_hints_batch = None
if _has_audio_codes(audio_code_string):
if isinstance(audio_code_string, list):
audio_code_hints_batch = audio_code_string
else:
audio_code_hints_batch = [audio_code_string] * actual_batch_size
should_return_intermediate = (task_type == "text2music")
outputs = self.service_generate(
captions=captions_batch,
lyrics=lyrics_batch,
metas=metas_batch, # Pass as dict, service will convert to string
vocal_languages=vocal_languages_batch,
refer_audios=refer_audios, # Already in List[List[torch.Tensor]] format
target_wavs=target_wavs_tensor, # Shape: [batch_size, 2, frames]
infer_steps=inference_steps,
guidance_scale=guidance_scale,
seed=actual_seed_list, # Pass list of seeds, one per batch item
repainting_start=repainting_start_batch,
repainting_end=repainting_end_batch,
instructions=instructions_batch, # Pass instructions to service
audio_cover_strength=audio_cover_strength, # Pass audio cover strength
use_adg=use_adg, # Pass use_adg parameter
cfg_interval_start=cfg_interval_start, # Pass CFG interval start
cfg_interval_end=cfg_interval_end, # Pass CFG interval end
shift=shift, # Pass shift parameter
infer_method=infer_method, # Pass infer method (ode or sde)
audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
return_intermediate=should_return_intermediate,
timesteps=timesteps, # Pass custom timesteps if provided
)
logger.info("[generate_music] Model generation completed. Decoding latents...")
pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
time_costs = outputs["time_costs"]
time_costs["offload_time_cost"] = self.current_offload_cost
logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
logger.debug(f"[generate_music] time_costs: {time_costs}")
if progress:
progress(0.8, desc="Decoding audio...")
logger.info("[generate_music] Decoding latents with VAE...")
# Decode latents to audio
start_time = time.time()
with torch.no_grad():
with self._load_model_context("vae"):
# Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
pred_latents_for_decode = pred_latents.transpose(1, 2)
# Ensure input is in VAE's dtype
pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
if use_tiled_decode:
logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
pred_wavs = self.tiled_decode(pred_latents_for_decode) # [batch, channels, samples]
else:
pred_wavs = self.vae.decode(pred_latents_for_decode).sample
# Cast output to float32 for audio processing/saving
pred_wavs = pred_wavs.to(torch.float32)
end_time = time.time()
time_costs["vae_decode_time_cost"] = end_time - start_time
time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
# Update offload cost one last time to include VAE offloading
time_costs["offload_time_cost"] = self.current_offload_cost
logger.info("[generate_music] VAE decode completed. Preparing audio tensors...")
if progress:
progress(0.99, desc="Preparing audio data...")
# Prepare audio tensors (no file I/O here, no UUID generation)
# pred_wavs is already [batch, channels, samples] format
# Move to CPU and convert to float32 for return
audio_tensors = []
for i in range(actual_batch_size):
# Extract audio tensor: [channels, samples] format, CPU, float32
audio_tensor = pred_wavs[i].cpu().float()
audio_tensors.append(audio_tensor)
status_message = f"✅ Generation completed successfully!"
logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
# Extract intermediate information from outputs
src_latents = outputs.get("src_latents") # [batch, T, D]
target_latents_input = outputs.get("target_latents_input") # [batch, T, D]
chunk_masks = outputs.get("chunk_masks") # [batch, T]
spans = outputs.get("spans", []) # List of tuples
latent_masks = outputs.get("latent_masks") # [batch, T]
# Extract condition tensors for LRC timestamp generation
encoder_hidden_states = outputs.get("encoder_hidden_states")
encoder_attention_mask = outputs.get("encoder_attention_mask")
context_latents = outputs.get("context_latents")
lyric_token_idss = outputs.get("lyric_token_idss")
# Move all tensors to CPU to save VRAM (detach to release computation graph)
extra_outputs = {
"pred_latents": pred_latents.detach().cpu() if pred_latents is not None else None,
"target_latents": target_latents_input.detach().cpu() if target_latents_input is not None else None,
"src_latents": src_latents.detach().cpu() if src_latents is not None else None,
"chunk_masks": chunk_masks.detach().cpu() if chunk_masks is not None else None,
"latent_masks": latent_masks.detach().cpu() if latent_masks is not None else None,
"spans": spans,
"time_costs": time_costs,
"seed_value": seed_value_for_ui,
# Condition tensors for LRC timestamp generation
"encoder_hidden_states": encoder_hidden_states.detach().cpu() if encoder_hidden_states is not None else None,
"encoder_attention_mask": encoder_attention_mask.detach().cpu() if encoder_attention_mask is not None else None,
"context_latents": context_latents.detach().cpu() if context_latents is not None else None,
"lyric_token_idss": lyric_token_idss.detach().cpu() if lyric_token_idss is not None else None,
}
# Build audios list with tensor data (no file paths, no UUIDs, handled outside)
audios = []
for audio_tensor in audio_tensors:
audio_dict = {
"tensor": audio_tensor, # torch.Tensor [channels, samples], CPU, float32
"sample_rate": self.sample_rate,
}
audios.append(audio_dict)
return {
"audios": audios,
"status_message": status_message,
"extra_outputs": extra_outputs,
"success": True,
"error": None,
}
except Exception as e:
error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
logger.exception("[generate_music] Generation failed")
# Clean up CUDA state after any error (especially important for CUDA errors)
if torch.cuda.is_available():
try:
torch.cuda.synchronize()
except Exception:
pass # Ignore sync errors during cleanup
torch.cuda.empty_cache()
return {
"audios": [],
"status_message": error_msg,
"extra_outputs": {},
"success": False,
"error": str(e),
}
@torch.no_grad()
def get_lyric_timestamp(
self,
pred_latent: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
context_latents: torch.Tensor,
lyric_token_ids: torch.Tensor,
total_duration_seconds: float,
vocal_language: str = "en",
inference_steps: int = 8,
seed: int = 42,
custom_layers_config: Optional[Dict] = None,
) -> Dict[str, Any]:
"""
Generate lyrics timestamps from generated audio latents using cross-attention alignment.
This method adds noise to the final pred_latent and re-infers one step to get
cross-attention matrices, then uses DTW to align lyrics tokens with audio frames.
Args:
pred_latent: Generated latent tensor [batch, T, D]
encoder_hidden_states: Cached encoder hidden states
encoder_attention_mask: Cached encoder attention mask
context_latents: Cached context latents
lyric_token_ids: Tokenized lyrics tensor [batch, seq_len]
total_duration_seconds: Total audio duration in seconds
vocal_language: Language code for lyrics header parsing
inference_steps: Number of inference steps (for noise level calculation)
seed: Random seed for noise generation
custom_layers_config: Dict mapping layer indices to head indices
Returns:
Dict containing:
- lrc_text: LRC formatted lyrics with timestamps
- sentence_timestamps: List of SentenceTimestamp objects
- token_timestamps: List of TokenTimestamp objects
- success: Whether generation succeeded
- error: Error message if failed
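Example:
A minimal sketch (assumes `result` comes from a prior generate_music call,
whose extra_outputs caches every tensor needed here):
eo = result["extra_outputs"]
ts = handler.get_lyric_timestamp(
pred_latent=eo["pred_latents"],
encoder_hidden_states=eo["encoder_hidden_states"],
encoder_attention_mask=eo["encoder_attention_mask"],
context_latents=eo["context_latents"],
lyric_token_ids=eo["lyric_token_idss"],
total_duration_seconds=30.0,
)
print(ts["lrc_text"])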
"""
if self.model is None:
return {
"lrc_text": "",
"sentence_timestamps": [],
"token_timestamps": [],
"success": False,
"error": "Model not initialized"
}
if custom_layers_config is None:
custom_layers_config = self.custom_layers_config
try:
# Move tensors to device
device = self.device
dtype = self.dtype
pred_latent = pred_latent.to(device=device, dtype=dtype)
encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype)
encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype)
context_latents = context_latents.to(device=device, dtype=dtype)
bsz = pred_latent.shape[0]
# Calculate noise level: t_last = 1.0 / inference_steps
t_last_val = 1.0 / inference_steps
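# e.g. inference_steps=8 -> t_last_val = 0.125, so the latent is re-noised
# only 12.5% of the way back along the flow-matching path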
t_curr_tensor = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype)
x1 = pred_latent
# Generate noise
if seed is None:
x0 = torch.randn_like(x1)
else:
generator = torch.Generator(device=device).manual_seed(int(seed))
x0 = torch.randn(x1.shape, generator=generator, device=device, dtype=dtype)
# Add noise to pred_latent: xt = t * noise + (1 - t) * x1
xt = t_last_val * x0 + (1.0 - t_last_val) * x1
xt_in = xt
t_in = t_curr_tensor
# Get null condition embedding
encoder_hidden_states_in = encoder_hidden_states
encoder_attention_mask_in = encoder_attention_mask
context_latents_in = context_latents
latent_length = x1.shape[1]
attention_mask = torch.ones(bsz, latent_length, device=device, dtype=dtype)
attention_mask_in = attention_mask
past_key_values = None
# Run decoder with output_attentions=True
with self._load_model_context("model"):
decoder = self.model.decoder
decoder_outputs = decoder(
hidden_states=xt_in,
timestep=t_in,
timestep_r=t_in,
attention_mask=attention_mask_in,
encoder_hidden_states=encoder_hidden_states_in,
use_cache=False,
past_key_values=past_key_values,
encoder_attention_mask=encoder_attention_mask_in,
context_latents=context_latents_in,
output_attentions=True,
custom_layers_config=custom_layers_config,
enable_early_exit=True
)
# Extract cross-attention matrices
if decoder_outputs[2] is None:
return {
"lrc_text": "",
"sentence_timestamps": [],
"token_timestamps": [],
"success": False,
"error": "Model did not return attentions"
}
cross_attns = decoder_outputs[2] # Tuple of tensors (some may be None)
captured_layers_list = []
for layer_attn in cross_attns:
# Skip None values (layers that didn't return attention)
if layer_attn is None:
continue
# Only take conditional part (first half of batch)
cond_attn = layer_attn[:bsz]
layer_matrix = cond_attn.transpose(-1, -2)
captured_layers_list.append(layer_matrix)
if not captured_layers_list:
return {
"lrc_text": "",
"sentence_timestamps": [],
"token_timestamps": [],
"success": False,
"error": "No valid attention layers returned"
}
stacked = torch.stack(captured_layers_list)
if bsz == 1:
all_layers_matrix = stacked.squeeze(1)
else:
all_layers_matrix = stacked
# Process lyric token IDs to extract pure lyrics
if isinstance(lyric_token_ids, torch.Tensor):
raw_lyric_ids = lyric_token_ids[0].tolist()
else:
raw_lyric_ids = lyric_token_ids
# Parse header to find lyrics start position
header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n"
header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False)
start_idx = len(header_ids)
# Find end of lyrics (before endoftext token)
try:
end_idx = raw_lyric_ids.index(151643) # <|endoftext|> token
except ValueError:
end_idx = len(raw_lyric_ids)
pure_lyric_ids = raw_lyric_ids[start_idx:end_idx]
pure_lyric_matrix = all_layers_matrix[:, :, start_idx:end_idx, :]
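# pure_lyric_matrix: [num_layers, num_heads, lyric_tokens, latent_frames]
# (assuming bsz == 1, where the batch dim was squeezed out above)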
# Create aligner and generate timestamps
aligner = MusicStampsAligner(self.text_tokenizer)
align_info = aligner.stamps_align_info(
attention_matrix=pure_lyric_matrix,
lyrics_tokens=pure_lyric_ids,
total_duration_seconds=total_duration_seconds,
custom_config=custom_layers_config,
return_matrices=False,
violence_level=2.0,
medfilt_width=1,
)
if align_info.get("calc_matrix") is None:
return {
"lrc_text": "",
"sentence_timestamps": [],
"token_timestamps": [],
"success": False,
"error": align_info.get("error", "Failed to process attention matrix")
}
# Generate timestamps
result = aligner.get_timestamps_and_lrc(
calc_matrix=align_info["calc_matrix"],
lyrics_tokens=pure_lyric_ids,
total_duration_seconds=total_duration_seconds
)
return {
"lrc_text": result["lrc_text"],
"sentence_timestamps": result["sentence_timestamps"],
"token_timestamps": result["token_timestamps"],
"success": True,
"error": None
}
except Exception as e:
error_msg = f"Error generating timestamps: {str(e)}"
logger.exception("[get_lyric_timestamp] Failed")
return {
"lrc_text": "",
"sentence_timestamps": [],
"token_timestamps": [],
"success": False,
"error": error_msg
}
@torch.no_grad()
def get_lyric_score(
self,
pred_latent: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
context_latents: torch.Tensor,
lyric_token_ids: torch.Tensor,
vocal_language: str = "en",
inference_steps: int = 8,
seed: int = 42,
custom_layers_config: Optional[Dict] = None,
) -> Dict[str, Any]:
"""
Calculate both LM and DiT alignment scores in one pass.
- lm_score: Checks structural alignment using pure noise at t=1.0.
- dit_score: Checks denoising alignment using regressed latents at t=1/steps.
Args:
pred_latent: Generated latent tensor [batch, T, D]
encoder_hidden_states: Cached encoder hidden states
encoder_attention_mask: Cached encoder attention mask
context_latents: Cached context latents
lyric_token_ids: Tokenized lyrics tensor [batch, seq_len]
vocal_language: Language code for lyrics header parsing
inference_steps: Number of inference steps (for noise level calculation)
seed: Random seed for noise generation
custom_layers_config: Dict mapping layer indices to head indices
Returns:
Dict containing:
- lm_score: float
- dit_score: float
- success: Whether generation succeeded
- error: Error message if failed
"""
if self.model is None:
return {
"lm_score": 0.0,
"dit_score": 0.0,
"success": False,
"error": "Model not initialized"
}
if custom_layers_config is None:
custom_layers_config = self.custom_layers_config
try:
# Move tensors to device
device = self.device
dtype = self.dtype
pred_latent = pred_latent.to(device=device, dtype=dtype)
encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype)
encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype)
context_latents = context_latents.to(device=device, dtype=dtype)
bsz = pred_latent.shape[0]
if seed is None:
x0 = torch.randn_like(pred_latent)
else:
generator = torch.Generator(device=device).manual_seed(int(seed))
x0 = torch.randn(pred_latent.shape, generator=generator, device=device, dtype=dtype)
# --- Input A: LM Score ---
# t = 1.0, xt = Pure Noise
t_lm = torch.tensor([1.0] * bsz, device=device, dtype=dtype)
xt_lm = x0
# --- Input B: DiT Score ---
# t = 1.0/steps, xt = Regressed Latent
t_last_val = 1.0 / inference_steps
t_dit = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype)
# Flow Matching Regression: xt = t*x0 + (1-t)*x1
xt_dit = t_last_val * x0 + (1.0 - t_last_val) * pred_latent
# Order: [Think_Batch, DiT_Batch]
xt_in = torch.cat([xt_lm, xt_dit], dim=0)
t_in = torch.cat([t_lm, t_dit], dim=0)
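# Resulting layout: rows [0, bsz) probe with pure noise (LM score),
# rows [bsz, 2*bsz) probe with the lightly re-noised prediction (DiT score)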
# Duplicate conditions
encoder_hidden_states_in = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0)
encoder_attention_mask_in = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0)
context_latents_in = torch.cat([context_latents, context_latents], dim=0)
# Prepare Attention Mask
latent_length = xt_in.shape[1]
attention_mask_in = torch.ones(2 * bsz, latent_length, device=device, dtype=dtype)
past_key_values = None
# Run decoder with output_attentions=True
with self._load_model_context("model"):
decoder = self.model.decoder
if hasattr(decoder, 'eval'):
decoder.eval()
decoder_outputs = decoder(
hidden_states=xt_in,
timestep=t_in,
timestep_r=t_in,
attention_mask=attention_mask_in,
encoder_hidden_states=encoder_hidden_states_in,
use_cache=False,
past_key_values=past_key_values,
encoder_attention_mask=encoder_attention_mask_in,
context_latents=context_latents_in,
output_attentions=True,
custom_layers_config=custom_layers_config,
enable_early_exit=True
)
# Extract cross-attention matrices
if decoder_outputs[2] is None:
return {
"lm_score": 0.0,
"dit_score": 0.0,
"success": False,
"error": "Model did not return attentions"
}
cross_attns = decoder_outputs[2] # Tuple of tensors (some may be None)
captured_layers_list = []
for layer_attn in cross_attns:
if layer_attn is None:
continue
# Keep the full [2*bsz] batch here; the LM/DiT halves are split after stacking
layer_matrix = layer_attn.transpose(-1, -2)
captured_layers_list.append(layer_matrix)
if not captured_layers_list:
return {
"lm_score": 0.0,
"dit_score": 0.0,
"success": False,
"error": "No valid attention layers returned"
}
stacked = torch.stack(captured_layers_list)
all_layers_matrix_lm = stacked[:, :bsz, ...]
all_layers_matrix_dit = stacked[:, bsz:, ...]
if bsz == 1:
all_layers_matrix_lm = all_layers_matrix_lm.squeeze(1)
all_layers_matrix_dit = all_layers_matrix_dit.squeeze(1)
# Process lyric token IDs to extract pure lyrics
if isinstance(lyric_token_ids, torch.Tensor):
raw_lyric_ids = lyric_token_ids[0].tolist()
else:
raw_lyric_ids = lyric_token_ids
# Parse header to find lyrics start position
header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n"
header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False)
start_idx = len(header_ids)
# Find end of lyrics (before endoftext token)
try:
end_idx = raw_lyric_ids.index(151643) # <|endoftext|> token
except ValueError:
end_idx = len(raw_lyric_ids)
pure_lyric_ids = raw_lyric_ids[start_idx:end_idx]
if start_idx >= all_layers_matrix_lm.shape[-2]: # Check text dim
return {
"lm_score": 0.0,
"dit_score": 0.0,
"success": False,
"error": "Lyrics indices out of bounds"
}
pure_matrix_lm = all_layers_matrix_lm[..., start_idx:end_idx, :]
pure_matrix_dit = all_layers_matrix_dit[..., start_idx:end_idx, :]
# Create aligner and calculate alignment info
aligner = MusicLyricScorer(self.text_tokenizer)
def calculate_single_score(matrix):
"""Helper to run aligner on a matrix"""
info = aligner.lyrics_alignment_info(
attention_matrix=matrix,
token_ids=pure_lyric_ids,
custom_config=custom_layers_config,
return_matrices=False,
medfilt_width=1,
)
if info.get("energy_matrix") is None:
return 0.0
res = aligner.calculate_score(
energy_matrix=info["energy_matrix"],
type_mask=info["type_mask"],
path_coords=info["path_coords"],
)
# Return the final score (check return key)
return res.get("lyrics_score", res.get("final_score", 0.0))
lm_score = calculate_single_score(pure_matrix_lm)
dit_score = calculate_single_score(pure_matrix_dit)
return {
"lm_score": lm_score,
"dit_score": dit_score,
"success": True,
"error": None
}
except Exception as e:
error_msg = f"Error generating score: {str(e)}"
logger.exception("[get_lyric_score] Failed")
return {
"lm_score": 0.0,
"dit_score": 0.0,
"success": False,
"error": error_msg
}