"""
Business Logic Handler
Encapsulates all data processing and business logic as a bridge between model and UI
"""
import os

# Disable tokenizers parallelism to avoid fork warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Disable torchcodec backend to avoid CUDA dependency issues on HuggingFace Space
# This forces torchaudio to use ffmpeg/sox/soundfile backends instead
os.environ["TORCHAUDIO_USE_TORCHCODEC"] = "0"

import math
from copy import deepcopy
import tempfile
import traceback
import re
import random
import uuid
import hashlib
import json
from contextlib import contextmanager
from typing import Optional, Dict, Any, Tuple, List, Union

import torch
import torchaudio
import soundfile as sf
import time
from tqdm import tqdm
from loguru import logger
import warnings

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.generation.streamers import BaseStreamer
from diffusers.models import AutoencoderOobleck
from acestep.constants import (
    TASK_INSTRUCTIONS,
    SFT_GEN_PROMPT,
    DEFAULT_DIT_INSTRUCTION,
)
from acestep.dit_alignment_score import MusicStampsAligner, MusicLyricScorer

warnings.filterwarnings("ignore")


class AceStepHandler:
    """ACE-Step Business Logic Handler"""

    # HuggingFace Space environment detection
    IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None

    def __init__(self, persistent_storage_path: Optional[str] = None):
        """
        Create an uninitialised handler; models are loaded later by initialize_service().

        Args:
            persistent_storage_path: Root for checkpoints. Defaults to "/data" when
                running on a HuggingFace Space, otherwise the project root is used.
        """
        self.model = None
        self.config = None
        self.device = "cpu"
        self.dtype = torch.float32  # Will be set based on device in initialize_service
        self.quantization = None  # Will be set in initialize_service

        # HuggingFace Space persistent storage support
        if persistent_storage_path is None and self.IS_HUGGINGFACE_SPACE:
            persistent_storage_path = "/data"
        self.persistent_storage_path = persistent_storage_path

        # VAE for audio encoding/decoding
        self.vae = None

        # Text encoder and tokenizer
        self.text_encoder = None
        self.text_tokenizer = None

        # Silence latent for initialization
        self.silence_latent = None

        # Sample rate
        self.sample_rate = 48000

        # Reward model (temporarily disabled)
        self.reward_model = None

        # Batch size
        self.batch_size = 2

        # Custom layers config
        self.custom_layers_config = {2: [6], 3: [10, 11], 4: [3], 5: [8, 9], 6: [8]}
        self.offload_to_cpu = False
        self.offload_dit_to_cpu = False
        self.current_offload_cost = 0.0

        # LoRA state
        self.lora_loaded = False
        self.use_lora = False
        self._base_decoder = None  # Backup of original decoder

    def _get_checkpoint_dir(self) -> str:
        """Get checkpoint directory, prioritizing persistent storage if available"""
        if self.persistent_storage_path:
            return os.path.join(self.persistent_storage_path, "checkpoints")
        project_root = self._get_project_root()
        return os.path.join(project_root, "checkpoints")

    def get_available_checkpoints(self) -> List[str]:
        """Return ``[checkpoint_dir]`` if the checkpoint directory exists, else ``[]``."""
        checkpoint_dir = self._get_checkpoint_dir()
        if os.path.exists(checkpoint_dir):
            return [checkpoint_dir]
        return []

    def get_available_acestep_v15_models(self) -> List[str]:
        """Scan and return all model directory names starting with 'acestep-v15-' (sorted)."""
        checkpoint_dir = self._get_checkpoint_dir()
        models = []
        if os.path.exists(checkpoint_dir):
            for item in os.listdir(checkpoint_dir):
                item_path = os.path.join(checkpoint_dir, item)
                if os.path.isdir(item_path) and item.startswith("acestep-v15-"):
                    models.append(item)
        models.sort()
        return models

    # Model name to HuggingFace repository mapping
    # Models in the same repo will be downloaded together
    MODEL_REPO_MAPPING = {
        # Main unified repository (contains acestep-v15-turbo, LM models, VAE, text encoder)
        "acestep-v15-turbo": "ACE-Step/Ace-Step1.5",
        "acestep-5Hz-lm-0.6B": "ACE-Step/Ace-Step1.5",
        "acestep-5Hz-lm-1.7B": "ACE-Step/Ace-Step1.5",
        "vae": "ACE-Step/Ace-Step1.5",
        "Qwen3-Embedding-0.6B": "ACE-Step/Ace-Step1.5",
        # Separate model repositories
        "acestep-v15-base": "ACE-Step/acestep-v15-base",
        "acestep-v15-sft": "ACE-Step/acestep-v15-sft",
        "acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3",
    }

    # Default fallback repository for unknown models
    DEFAULT_REPO_ID = "ACE-Step/Ace-Step1.5"

    def _ensure_model_downloaded(self, model_name: str, checkpoint_dir: str) -> str:
        """
        Ensure model is downloaded from HuggingFace Hub.
        Used for HuggingFace Space auto-download support.

        Supports multiple repositories:
        - Models in MODEL_REPO_MAPPING will be downloaded from their specific repo
        - Unknown models will try the DEFAULT_REPO_ID

        For separate model repos (acestep-v15-base, acestep-v15-sft, acestep-v15-turbo-shift3),
        downloads directly into the model subdirectory.

        Args:
            model_name: Model directory name (e.g., "acestep-v15-turbo", "acestep-v15-turbo-shift3")
            checkpoint_dir: Target checkpoint directory

        Returns:
            Path to the model directory (checkpoint_dir/model_name). For the unified
            repo this path only exists if the repo actually contains that subdirectory.

        Raises:
            Any exception raised by ``snapshot_download`` (logged, then re-raised).
        """
        from huggingface_hub import snapshot_download

        model_path = os.path.join(checkpoint_dir, model_name)

        # Check if model already exists (non-empty directory)
        if os.path.exists(model_path) and os.listdir(model_path):
            logger.info(f"Model {model_name} already exists at {model_path}")
            return model_path

        # Get repository ID for this model
        repo_id = self.MODEL_REPO_MAPPING.get(model_name, self.DEFAULT_REPO_ID)

        # Determine if this is a unified repo or a separate model repo
        is_unified_repo = repo_id == self.DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5"

        if is_unified_repo:
            # Unified repo: download entire repo to checkpoint_dir
            # The model will be in checkpoint_dir/model_name
            download_dir = checkpoint_dir
            logger.info(f"Downloading unified repository {repo_id} to {download_dir}...")
        else:
            # Separate model repo: download directly to model_path
            # The repo contains the model files directly, not in a subdirectory
            download_dir = model_path
            os.makedirs(download_dir, exist_ok=True)
            logger.info(f"Downloading model {model_name} from {repo_id} to {download_dir}...")

        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=download_dir,
                local_dir_use_symlinks=False,
            )
            logger.info(f"Repository {repo_id} downloaded successfully to {download_dir}")
        except Exception as e:
            logger.error(f"Failed to download repository {repo_id}: {e}")
            raise

        return model_path

    def is_flash_attention_available(self) -> bool:
        """Check if flash attention is available on the system"""
        try:
            import flash_attn  # noqa: F401
            return True
        except ImportError:
            return False

    def is_turbo_model(self) -> bool:
        """Check if the currently loaded model is a turbo model"""
        if self.config is None:
            return False
        return getattr(self.config, 'is_turbo', False)

    def load_lora(self, lora_path: str) -> str:
        """Load LoRA adapter into the decoder.

        The first call backs up the base decoder; later calls restore that backup
        before loading, so adapters never stack.

        Args:
            lora_path: Path to the LoRA adapter directory (containing adapter_config.json)

        Returns:
            Status message
        """
        if self.model is None:
            return "❌ Model not initialized. Please initialize service first."

        if not lora_path or not lora_path.strip():
            return "❌ Please provide a LoRA path."

        lora_path = lora_path.strip()

        # Check if path exists
        if not os.path.exists(lora_path):
            return f"❌ LoRA path not found: {lora_path}"

        # Check if it's a valid PEFT adapter directory
        config_file = os.path.join(lora_path, "adapter_config.json")
        if not os.path.exists(config_file):
            return f"❌ Invalid LoRA adapter: adapter_config.json not found in {lora_path}"

        try:
            from peft import PeftModel
        except ImportError:
            return "❌ PEFT library not installed. Please install with: pip install peft"

        try:
            # Use the module-level ``deepcopy``: a function-local ``import copy`` in only
            # one branch made ``copy`` unbound in the other (UnboundLocalError on reload).
            if self._base_decoder is None:
                # Backup base decoder if not already backed up
                self._base_decoder = deepcopy(self.model.decoder)
                logger.info("Base decoder backed up")
            else:
                # Restore base decoder before loading new LoRA
                self.model.decoder = deepcopy(self._base_decoder)
                # The previous adapter is gone now; keep the flags truthful in case
                # the load below fails.
                self.lora_loaded = False
                self.use_lora = False
                logger.info("Restored base decoder before loading new LoRA")

            # Load PEFT adapter
            logger.info(f"Loading LoRA adapter from {lora_path}")
            self.model.decoder = PeftModel.from_pretrained(
                self.model.decoder,
                lora_path,
                is_trainable=False,
            )
            self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
            self.model.decoder.eval()

            self.lora_loaded = True
            self.use_lora = True  # Enable LoRA by default after loading

            logger.info(f"LoRA adapter loaded successfully from {lora_path}")
            return f"✅ LoRA loaded from {lora_path}"

        except Exception as e:
            logger.exception("Failed to load LoRA adapter")
            return f"❌ Failed to load LoRA: {str(e)}"

    def unload_lora(self) -> str:
        """Unload LoRA adapter and restore base decoder.

        Returns:
            Status message
        """
        if not self.lora_loaded:
            return "⚠️ No LoRA adapter loaded."

        if self._base_decoder is None:
            return "❌ Base decoder backup not found. Cannot restore."

        try:
            # Restore base decoder from the backup taken by load_lora()
            self.model.decoder = deepcopy(self._base_decoder)
            self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
            self.model.decoder.eval()

            self.lora_loaded = False
            self.use_lora = False

            logger.info("LoRA unloaded, base decoder restored")
            return "✅ LoRA unloaded, using base model"

        except Exception as e:
            logger.exception("Failed to unload LoRA")
            return f"❌ Failed to unload LoRA: {str(e)}"

    def set_use_lora(self, use_lora: bool) -> str:
        """Toggle LoRA usage for inference.

        Args:
            use_lora: Whether to use LoRA adapter

        Returns:
            Status message
        """
        if use_lora and not self.lora_loaded:
            return "❌ No LoRA adapter loaded. Please load a LoRA first."

        self.use_lora = use_lora

        # Use PEFT's enable/disable methods if available
        if self.lora_loaded and hasattr(self.model.decoder, 'disable_adapter_layers'):
            try:
                if use_lora:
                    self.model.decoder.enable_adapter_layers()
                    logger.info("LoRA adapter enabled")
                else:
                    self.model.decoder.disable_adapter_layers()
                    logger.info("LoRA adapter disabled")
            except Exception as e:
                logger.warning(f"Could not toggle adapter layers: {e}")

        status = "enabled" if use_lora else "disabled"
        return f"✅ LoRA {status}"

    def get_lora_status(self) -> Dict[str, Any]:
        """Get current LoRA status.

        Returns:
            Dictionary with LoRA status info
        """
        return {
            "loaded": self.lora_loaded,
            "active": self.use_lora,
        }

    def initialize_service(
        self,
        project_root: str,
        config_path: str,
        device: str = "auto",
        use_flash_attention: bool = False,
        compile_model: bool = False,
        offload_to_cpu: bool = False,
        offload_dit_to_cpu: bool = False,
        quantization: Optional[str] = None,
        # Shared components (for multi-model setup to save memory)
        shared_vae=None,
        shared_text_encoder=None,
        shared_text_tokenizer=None,
        shared_silence_latent=None,
    ) -> Tuple[str, bool]:
        """
        Initialize DiT model service

        Args:
            project_root: Kept for API compatibility; not used. The checkpoint directory
                is always resolved via _get_checkpoint_dir().
            config_path: Model config directory name (e.g., "acestep-v15-turbo").
                If None, "acestep-v15-turbo" is used (auto-downloaded if missing).
            device: Device type, or "auto" to pick xpu > cuda > mps > cpu
            use_flash_attention: Whether to use flash attention (requires flash_attn package)
            compile_model: Whether to use torch.compile to optimize the model
            offload_to_cpu: Whether to offload models to CPU when not in use
            offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True)
            quantization: Optional torchao quantization ("int8_weight_only", "fp8_weight_only",
                "w8a8_dynamic"); requires compile_model=True and torchao installed
            shared_vae: Optional shared VAE instance (for multi-model setup)
            shared_text_encoder: Optional shared text encoder instance (for multi-model setup)
            shared_text_tokenizer: Optional shared text tokenizer instance (for multi-model setup);
                only used together with shared_text_encoder
            shared_silence_latent: Optional shared silence latent tensor (for multi-model setup)

        Returns:
            (status_message, enable_generate_button). Errors are caught and reported
            as (error_message, False).
        """
        try:
            if device == "auto":
                if hasattr(torch, 'xpu') and torch.xpu.is_available():
                    device = "xpu"
                elif torch.cuda.is_available():
                    device = "cuda"
                elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                    device = "mps"
                else:
                    device = "cpu"

            self.device = device
            self.offload_to_cpu = offload_to_cpu
            self.offload_dit_to_cpu = offload_dit_to_cpu
            # Set dtype based on device: bfloat16 for cuda/xpu, float32 otherwise
            self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32
            self.quantization = quantization
            if self.quantization is not None:
                # Explicit check instead of `assert`, which is stripped under `python -O`.
                if not compile_model:
                    raise ValueError("Quantization requires compile_model to be True")
                try:
                    import torchao  # noqa: F401
                except ImportError as exc:
                    raise ImportError(
                        "torchao is required for quantization but is not installed. "
                        "Please install torchao to use quantization features."
                    ) from exc

            checkpoint_dir = self._get_checkpoint_dir()
            os.makedirs(checkpoint_dir, exist_ok=True)

            # 1. Load main model
            # config_path is relative path (e.g., "acestep-v15-turbo"), concatenate to checkpoints directory
            # If config_path is None (HuggingFace Space with empty checkpoint), use default and auto-download
            if config_path is None:
                config_path = "acestep-v15-turbo"
                logger.info(f"[initialize_service] config_path is None, using default: {config_path}")
            acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)

            # Auto-download model if not exists (HuggingFace Space support)
            if not os.path.exists(acestep_v15_checkpoint_path):
                acestep_v15_checkpoint_path = self._ensure_model_downloaded(config_path, checkpoint_dir)

            if not os.path.exists(acestep_v15_checkpoint_path):
                raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}")

            # Determine attention implementation
            if use_flash_attention and self.is_flash_attention_available():
                attn_implementation = "flash_attention_2"
                self.dtype = torch.bfloat16
            else:
                attn_implementation = "sdpa"

            try:
                logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_implementation}")
                self.model = AutoModel.from_pretrained(
                    acestep_v15_checkpoint_path,
                    trust_remote_code=True,
                    attn_implementation=attn_implementation,
                    dtype="bfloat16",
                )
            except Exception as e:
                logger.warning(f"[initialize_service] Failed to load model with {attn_implementation}: {e}")
                if attn_implementation != "sdpa":
                    raise
                logger.info("[initialize_service] Falling back to eager attention")
                attn_implementation = "eager"
                self.model = AutoModel.from_pretrained(
                    acestep_v15_checkpoint_path,
                    trust_remote_code=True,
                    attn_implementation=attn_implementation,
                )

            self.model.config._attn_implementation = attn_implementation
            self.config = self.model.config

            # Move model to device and set dtype. With offload_to_cpu the DiT still stays
            # on the device unless offload_dit_to_cpu is also set.
            if not self.offload_to_cpu:
                self.model = self.model.to(device).to(self.dtype)
            elif not self.offload_dit_to_cpu:
                logger.info(f"[initialize_service] Keeping main model on {device} (persistent)")
                self.model = self.model.to(device).to(self.dtype)
            else:
                self.model = self.model.to("cpu").to(self.dtype)
            self.model.eval()

            if compile_model:
                self.model = torch.compile(self.model)

                if self.quantization is not None:
                    from torchao.quantization import quantize_
                    if self.quantization == "int8_weight_only":
                        from torchao.quantization import Int8WeightOnlyConfig
                        quant_config = Int8WeightOnlyConfig()
                    elif self.quantization == "fp8_weight_only":
                        from torchao.quantization import Float8WeightOnlyConfig
                        quant_config = Float8WeightOnlyConfig()
                    elif self.quantization == "w8a8_dynamic":
                        from torchao.quantization import Int8DynamicActivationInt8WeightConfig, MappingType
                        quant_config = Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC)
                    else:
                        raise ValueError(f"Unsupported quantization type: {self.quantization}")

                    quantize_(self.model, quant_config)
                    logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")

            # Load or use shared silence_latent
            if shared_silence_latent is not None:
                self.silence_latent = shared_silence_latent
                logger.info("[initialize_service] Using shared silence_latent")
            else:
                silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
                if not os.path.exists(silence_latent_path):
                    raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
                # map_location="cpu" so a latent saved from a GPU loads on any machine;
                # it is moved to the target device right after.
                self.silence_latent = torch.load(silence_latent_path, map_location="cpu").transpose(1, 2)
                # Always keep silence_latent on GPU - it's used in many places outside model context
                # and is small enough that it won't significantly impact VRAM
                self.silence_latent = self.silence_latent.to(device).to(self.dtype)

            # 2. Load or use shared VAE
            vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")  # Define for status message
            if shared_vae is not None:
                self.vae = shared_vae
                logger.info("[initialize_service] Using shared VAE")
            else:
                if not os.path.exists(vae_checkpoint_path):
                    raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
                self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
                # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
                vae_dtype = self._get_vae_dtype(device)
                vae_device = "cpu" if self.offload_to_cpu else device
                self.vae = self.vae.to(vae_device).to(vae_dtype)
                self.vae.eval()

                if compile_model:
                    self.vae = torch.compile(self.vae)

            # 3. Load or use shared text encoder and tokenizer (both must be shared)
            text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")  # Define for status message
            using_shared_text = shared_text_encoder is not None and shared_text_tokenizer is not None
            if using_shared_text:
                self.text_encoder = shared_text_encoder
                self.text_tokenizer = shared_text_tokenizer
                logger.info("[initialize_service] Using shared text encoder and tokenizer")
            else:
                if not os.path.exists(text_encoder_path):
                    raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
                self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
                self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
                text_encoder_device = "cpu" if self.offload_to_cpu else device
                self.text_encoder = self.text_encoder.to(text_encoder_device).to(self.dtype)
                self.text_encoder.eval()

            # Determine actual attention implementation used
            actual_attn = getattr(self.config, "_attn_implementation", "eager")

            status_msg = f"✅ Model initialized successfully on {device}\n"
            status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
            if shared_vae is None:
                status_msg += f"VAE: {vae_checkpoint_path}\n"
            else:
                status_msg += "VAE: shared\n"
            # Report what was actually used: a shared encoder without a shared
            # tokenizer is ignored and the encoder is loaded from disk.
            if not using_shared_text:
                status_msg += f"Text encoder: {text_encoder_path}\n"
            else:
                status_msg += "Text encoder: shared\n"
            status_msg += f"Dtype: {self.dtype}\n"
            status_msg += f"Attention: {actual_attn}\n"
            status_msg += f"Compiled: {compile_model}\n"
            status_msg += f"Offload to CPU: {self.offload_to_cpu}\n"
            status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}"

            return status_msg, True

        except Exception as e:
            error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            logger.exception("[initialize_service] Error initializing model")
            return error_msg, False

    def _is_on_target_device(self, tensor, target_device):
        """Check whether ``tensor`` is on the same device *type* as ``target_device``.

        Only the type is compared, so "cuda" matches "cuda:0". ``target_device`` may be
        a string or a ``torch.device``. A ``None`` tensor counts as already in place.
        """
        if tensor is None:
            return True
        # Normalise instead of assuming "anything not 'cpu' is cuda": that broke
        # mps/xpu targets and torch.device("cpu") targets.
        if isinstance(target_device, torch.device):
            target_type = target_device.type
        else:
            target_type = torch.device(target_device).type
        return tensor.device.type == target_type

    def _ensure_silence_latent_on_device(self):
        """Ensure silence_latent is on the correct device (self.device)."""
        if hasattr(self, "silence_latent") and self.silence_latent is not None:
            if not self._is_on_target_device(self.silence_latent, self.device):
                self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)

    def _move_module_recursive(self, module, target_device, dtype=None, visited=None):
        """
        Recursively move a module and all its submodules to the target device.
        This handles modules that may not be properly registered.

        Args:
            module: Module to move.
            target_device: Destination device (string or torch.device).
            dtype: Optional dtype applied to the module and its direct parameters
                (buffers are only moved, not cast).
            visited: ids of modules already processed (guards against cycles).
        """
        if visited is None:
            visited = set()

        module_id = id(module)
        if module_id in visited:
            return
        visited.add(module_id)

        # Move the module itself
        module.to(target_device)
        if dtype is not None:
            module.to(dtype)

        # Move all direct parameters
        for param_name, param in module._parameters.items():
            if param is not None and not self._is_on_target_device(param, target_device):
                module._parameters[param_name] = param.to(target_device)
                if dtype is not None:
                    module._parameters[param_name] = module._parameters[param_name].to(dtype)

        # Move all direct buffers
        for buf_name, buf in module._buffers.items():
            if buf is not None and not self._is_on_target_device(buf, target_device):
                module._buffers[buf_name] = buf.to(target_device)

        # Recursively process all registered submodules
        for child in module._modules.values():
            if child is not None:
                self._move_module_recursive(child, target_device, dtype, visited)

        # Also check for any nn.Module attributes that might not be in _modules
        for attr_name in dir(module):
            if attr_name.startswith('_'):
                continue
            try:
                attr = getattr(module, attr_name, None)
                if isinstance(attr, torch.nn.Module) and id(attr) not in visited:
                    self._move_module_recursive(attr, target_device, dtype, visited)
            except Exception:
                # Best effort: some attributes/properties raise on access.
                pass

    def _recursive_to_device(self, model, device, dtype=None):
        """
        Recursively move all parameters and buffers of a model to the specified device.
        This is more thorough than model.to() for some custom HuggingFace models.

        Three passes, each a fallback for the previous one: ``model.to``,
        ``_move_module_recursive``, and (non-CPU targets only) a state_dict reload.
        Logs an error listing up to 10 parameters still on the wrong device.
        """
        target_device = device if isinstance(device, torch.device) else torch.device(device)
        is_cpu_target = target_device.type == "cpu"

        # Method 1: Standard .to() call
        model.to(target_device)
        if dtype is not None:
            model.to(dtype)

        # Method 2: Use our thorough recursive moving for any missed modules
        self._move_module_recursive(model, target_device, dtype)

        # Method 3: Force move via state_dict if there are still parameters on wrong device
        wrong_device_params = [
            name for name, param in model.named_parameters()
            if not self._is_on_target_device(param, target_device)
        ]

        if wrong_device_params and not is_cpu_target:
            logger.warning(f"[_recursive_to_device] {len(wrong_device_params)} parameters on wrong device, using state_dict method")
            # NOTE(review): load_state_dict copies values into the existing parameter
            # tensors in place by default, so this pass may not relocate them — confirm
            # whether it is effective (torch>=2.1 offers assign=True).
            moved_state_dict = {}
            for key, value in model.state_dict().items():
                if isinstance(value, torch.Tensor):
                    value = value.to(target_device)
                    if dtype is not None and value.is_floating_point():
                        value = value.to(dtype)
                moved_state_dict[key] = value
            model.load_state_dict(moved_state_dict)

        # Synchronize CUDA to ensure all transfers are complete
        if target_device.type == "cuda" and torch.cuda.is_available():
            torch.cuda.synchronize()

        # Final verification
        if not is_cpu_target:
            still_wrong = [
                f"{name} on {param.device}"
                for name, param in model.named_parameters()
                if not self._is_on_target_device(param, target_device)
            ]
            if still_wrong:
                logger.error(f"[_recursive_to_device] CRITICAL: {len(still_wrong)} parameters still on wrong device: {still_wrong[:10]}")

    @contextmanager
    def _load_model_context(self, model_name: str):
        """
        Context manager to load a model to GPU and offload it back to CPU after use.

        No-op when offload_to_cpu is False. For the DiT ("model") with
        offload_dit_to_cpu False, the model is only moved to the device (if it is
        found on CPU) and never offloaded.

        Args:
            model_name: Name of the model to load ("text_encoder", "vae", "model")
        """
        if not self.offload_to_cpu:
            yield
            return

        # If model is DiT ("model") and offload_dit_to_cpu is False, do not offload
        if model_name == "model" and not self.offload_dit_to_cpu:
            # Ensure it's on device if not already (should be handled by init, but safe to check)
            model = getattr(self, model_name, None)
            if model is not None:
                # We check the first parameter's device; a one-time move if it is on CPU
                try:
                    param = next(model.parameters())
                    if param.device.type == "cpu":
                        logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
                        self._recursive_to_device(model, self.device, self.dtype)
                        if getattr(self, "silence_latent", None) is not None:
                            self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
                except StopIteration:
                    pass
            yield
            return

        model = getattr(self, model_name, None)
        if model is None:
            yield
            return

        # Load to GPU
        logger.info(f"[_load_model_context] Loading {model_name} to {self.device}")
        start_time = time.time()
        if model_name == "vae":
            vae_dtype = self._get_vae_dtype()
            self._recursive_to_device(model, self.device, vae_dtype)
        else:
            self._recursive_to_device(model, self.device, self.dtype)

        # silence_latent is initialised to None in __init__, so guard against it.
        if model_name == "model" and getattr(self, "silence_latent", None) is not None:
            self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)

        load_time = time.time() - start_time
        self.current_offload_cost += load_time
        logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s")

        try:
            yield
        finally:
            # Offload to CPU
            logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
            start_time = time.time()
            self._recursive_to_device(model, "cpu")

            # NOTE: Do NOT offload silence_latent to CPU here!
            # silence_latent is used in many places outside of model context,
            # so it should stay on GPU to avoid device mismatch errors.

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            offload_time = time.time() - start_time
            self.current_offload_cost += offload_time
            logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s")

    def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
        """Load an audio file as a [channels, samples] tensor normalised to stereo 48 kHz.

        Returns None when ``audio_file`` is None or loading/processing fails.
        """
        if audio_file is None:
            return None

        try:
            # Load audio using soundfile
            audio_np, sr = sf.read(audio_file, dtype='float32')
            # Convert to torch: [samples, channels] or [samples] -> [channels, samples]
            if audio_np.ndim == 1:
                audio = torch.from_numpy(audio_np).unsqueeze(0)
            else:
                audio = torch.from_numpy(audio_np.T)

            # Normalize to stereo 48kHz
            audio = self._normalize_audio_to_stereo_48k(audio, sr)

            # Enforce duration limits (10-600 seconds)
            audio = self._enforce_audio_duration_limits(audio)

            return audio
        except Exception:
            logger.exception("[process_target_audio] Error processing target audio")
            return None

    def _parse_audio_code_string(self, code_str: str) -> List[int]:
        """Extract integer audio codes from prompt tokens like <|audio_code_123|>.

        Codes are clamped to valid range [0, 63999] (codebook size = 64000).
        Returns an empty list for empty input or if parsing fails.
        """
        if not code_str:
            return []
        try:
            codes = [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
            # Clamp codes to valid range [0, 63999]
            MAX_AUDIO_CODE = 63999
            invalid_codes = [code for code in codes if code < 0 or code > MAX_AUDIO_CODE]
            clamped_codes = [max(0, min(code, MAX_AUDIO_CODE)) for code in codes]

            if invalid_codes:
                logger.warning(f"[_parse_audio_code_string] Found {len(invalid_codes)} codes outside valid range [0, {MAX_AUDIO_CODE}]: {invalid_codes[:5]}... (clamped to valid range)")

            return clamped_codes
        except Exception as e:
            logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
            return []

    def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
        """
        Convert serialized audio code string into 25Hz latents using model quantizer/detokenizer.

        Returns None if the model lacks a tokenizer/detokenizer, no codes are found,
        or decoding fails (the error is logged).
        """
        if self.model is None or not hasattr(self.model, 'tokenizer') or not hasattr(self.model, 'detokenizer'):
            return None

        code_ids = self._parse_audio_code_string(code_str)
        if len(code_ids) == 0:
            return None

        try:
            with self._load_model_context("model"):
                quantizer = self.model.tokenizer.quantizer
                detokenizer = self.model.detokenizer

                # Get codebook size for validation
                # Default to 64000 (codebook size = 64000, valid range = 0-63999)
                codebook_size = getattr(quantizer, 'codebook_size', 64000)
                if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
                    codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)

                # Validate code IDs are within valid range
                invalid_codes = [c for c in code_ids if c < 0 or c >= codebook_size]
                if invalid_codes:
                    logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {codebook_size}): {invalid_codes[:5]}...")
                    # Clamp invalid codes to valid range
                    code_ids = [max(0, min(c, codebook_size - 1)) for c in code_ids]

                # Create indices tensor: [T_5Hz] -> [1, T_5Hz, 1]
                indices = torch.tensor(code_ids, device=self.device, dtype=torch.long)
                indices = indices.unsqueeze(0).unsqueeze(-1)

                # Synchronize to catch any CUDA errors before proceeding
                if torch.cuda.is_available():
                    torch.cuda.synchronize()

                # Get quantized representation from indices
                # The quantizer expects [batch, T_5Hz] format and handles quantizer dimension internally
                quantized = quantizer.get_output_from_indices(indices)
                if quantized.dtype != self.dtype:
                    quantized = quantized.to(self.dtype)

                # Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
                lm_hints_25hz = detokenizer(quantized)
                return lm_hints_25hz
        except Exception as e:
            logger.exception(f"[_decode_audio_codes_to_latents] Error decoding audio codes: {e}")
            # Clear CUDA error state
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return None

    def _create_default_meta(self) -> str:
        """Create default metadata string."""
        return (
            "- bpm: N/A\n"
            "- timesignature: N/A\n"
            "- keyscale: N/A\n"
            "- duration: 30 seconds\n"
        )

    def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
        """Convert metadata dict to formatted string.

        Accepts alias keys: tempo for bpm, time_signature for timesignature,
        key/scale for keyscale, length for duration. Missing values become "N/A";
        a numeric duration is rendered as "<int> seconds".
        """
        bpm = meta_dict.get('bpm', meta_dict.get('tempo', 'N/A'))
        timesignature = meta_dict.get('timesignature', meta_dict.get('time_signature', 'N/A'))
        keyscale = meta_dict.get('keyscale', meta_dict.get('key', meta_dict.get('scale', 'N/A')))
        duration = meta_dict.get('duration', meta_dict.get('length', 30))

        # Format duration
        if isinstance(duration, (int, float)):
            duration = f"{int(duration)} seconds"
        elif not isinstance(duration, str):
            duration = "30 seconds"

        return (
            f"- bpm: {bpm}\n"
            f"- timesignature: {timesignature}\n"
            f"- keyscale: {keyscale}\n"
            f"- duration: {duration}\n"
        )

    def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
        """
        Parse and normalize metadata with fallbacks.

        Args:
            metas: List of metadata (can be strings, dicts, or None)

        Returns:
            List of formatted metadata strings (strings are passed through verbatim)
        """
        parsed_metas = []
        for meta in metas:
            if meta is None:
                # Default fallback metadata
                parsed_meta = self._create_default_meta()
            elif isinstance(meta, str):
                # Already formatted string
                parsed_meta = meta
            elif isinstance(meta, dict):
                # Convert dict to formatted string
                parsed_meta = self._dict_to_meta_string(meta)
            else:
                # Fallback for any other type
                parsed_meta = self._create_default_meta()

            parsed_metas.append(parsed_meta)

        return parsed_metas

    def build_dit_inputs(
        self,
        task: str,
        instruction: Optional[str],
        caption: str,
        lyrics: str,
        metas: Optional[Union[str, Dict[str, Any]]] = None,
        vocal_language: str = "en",
    ) -> Tuple[str, str]:
        """
        Build text inputs for the caption and lyric branches used by DiT.

        Args:
            task: Task name (e.g., text2music, cover, repaint); kept for logging/future branching.
            instruction: Instruction text; default fallback matches service_generate behavior.
            caption: Caption string (fallback if not in metas).
            lyrics: Lyrics string.
            metas: Metadata (str or dict); follows _parse_metas formatting.
                A dict may contain 'caption' and 'language' fields from LM CoT output,
                which override ``caption`` / ``vocal_language``. String metas are used
                verbatim and are not inspected for overrides.
            vocal_language: Language code for lyrics section (fallback if not in metas).

        Returns:
            (caption_input_text, lyrics_input_text)

        Example:
            caption_input, lyrics_input = handler.build_dit_inputs(
                task="text2music",
                instruction=None,
                caption="A calm piano melody",
                lyrics="la la la",
                metas={"bpm": 90, "duration": 45, "caption": "LM generated caption", "language": "en"},
                vocal_language="en",
            )
        """
        # Align instruction formatting with _prepare_batch
        final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION)

        # Extract caption and language from metas if available (from LM CoT output).
        # Only dict metas can carry them: _parse_metas returns strings unchanged, so
        # string metas never yielded a dict here.
        meta_dict = metas if isinstance(metas, dict) else {}
        actual_caption = caption
        actual_language = vocal_language
        if meta_dict.get('caption'):
            actual_caption = str(meta_dict['caption'])
        if meta_dict.get('language'):
            actual_language = str(meta_dict['language'])

        parsed_meta = self._parse_metas([metas])[0]
        caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
        lyrics_input = self._format_lyrics(lyrics, actual_language)
        return caption_input, lyrics_input

    def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get text hidden states from text encoder.

        Returns:
            (hidden_states cast to self.dtype, boolean attention mask), both on self.device.

        Raises:
            ValueError: if the text encoder or tokenizer has not been initialised.
        """
        if self.text_tokenizer is None or self.text_encoder is None:
            raise ValueError("Text encoder not initialized")

        with self._load_model_context("text_encoder"):
            # Tokenize
            text_inputs = self.text_tokenizer(
                text_prompt,
                padding="longest",
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(self.device)
            text_attention_mask = text_inputs.attention_mask.to(self.device).bool()

            # Encode
            with torch.no_grad():
                text_outputs = self.text_encoder(text_input_ids)
                if hasattr(text_outputs, 'last_hidden_state'):
                    text_hidden_states = text_outputs.last_hidden_state
                elif isinstance(text_outputs, tuple):
                    text_hidden_states = text_outputs[0]
                else:
                    text_hidden_states = text_outputs

            text_hidden_states = text_hidden_states.to(self.dtype)

            return text_hidden_states, text_attention_mask

    def extract_caption_from_sft_format(self, caption: str) -> str:
        """Return the "# Caption" section of an SFT-formatted prompt.

        Only applies when both "# Instruction" and "# Caption" headers are present;
        the section ends at a "# Metas" header or end of text. Otherwise (or on
        error) ``caption`` is returned unchanged.
        """
        try:
            if "# Instruction" in caption and "# Caption" in caption:
                pattern = r'#\s*Caption\s*\n(.*?)(?:\n\s*#\s*Metas|$)'
                match = re.search(pattern, caption, re.DOTALL)
                if match:
                    return match.group(1).strip()
            return caption
        except Exception:
            logger.exception("[extract_caption_from_sft_format] Error extracting caption")
            return caption

    def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
        """
        Resolve one seed per batch item.

        Args:
            actual_batch_size: Number of items in the batch.
            seed: Seed input: an int/float, a comma-separated string (e.g. "123, 456"),
                or None. Negative, empty, or unparsable entries mean "random".
            use_random_seed: If True, ignore ``seed`` and draw fresh random seeds.

        Returns:
            (actual_seed_list, seed_value_for_ui): ``actual_batch_size`` seeds and the
            same seeds joined with ", " for display.

        Rules when ``use_random_seed`` is False:
            1. -1 (or any negative) entries get a fresh random seed.
            2. If exactly one non-negative seed is given and batch_size > 1, the first
               item uses it and the rest are random.
            3. Seeds beyond batch_size are ignored; missing ones are random.
        """
        def _random_seed() -> int:
            return random.randint(0, 2 ** 32 - 1)

        if use_random_seed:
            # Generate brand new seeds and expose them back to the UI
            actual_seed_list = [_random_seed() for _ in range(actual_batch_size)]
            return actual_seed_list, ", ".join(str(s) for s in actual_seed_list)

        seed_list: List[int] = []
        if isinstance(seed, str):
            # Handle string input (e.g., "123,456" or "-1")
            for s in (part.strip() for part in seed.split(",")):
                if s == "-1" or s == "":
                    seed_list.append(-1)
                    continue
                try:
                    seed_list.append(int(float(s)))
                except (ValueError, TypeError, OverflowError) as e:
                    # "abc"/"nan" -> ValueError; "inf" -> OverflowError
                    logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {e}")
                    seed_list.append(-1)
        elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
            # If seed is None or negative, use -1 for all items
            seed_list = [-1] * actual_batch_size
        elif isinstance(seed, (int, float)):
            # Single seed value; NaN/inf cannot become an int, so treat them as random
            try:
                seed_list = [int(seed)]
            except (ValueError, OverflowError):
                seed_list = [-1]
        else:
            # Fallback: use -1
            seed_list = [-1] * actual_batch_size

        # Negative values mean "random", consistent with the numeric branch above.
        has_single_non_negative_seed = len(seed_list) == 1 and seed_list[0] >= 0

        actual_seed_list: List[int] = []
        for i in range(actual_batch_size):
            # If not enough seeds provided, use -1 (will generate random)
            seed_val = seed_list[i] if i < len(seed_list) else -1
            if has_single_non_negative_seed and actual_batch_size > 1 and i > 0:
                # Only the first item uses the single provided seed
                actual_seed_list.append(_random_seed())
            elif seed_val < 0:
                actual_seed_list.append(_random_seed())
            else:
                actual_seed_list.append(int(seed_val))

        return actual_seed_list, ", ".join(str(s) for s in actual_seed_list)

    def prepare_metadata(self, bpm, key_scale, time_signature):
        """Build metadata dict - use "N/A" as default for empty fields."""
        return self._build_metadata_dict(bpm, key_scale, time_signature)

    def is_silence(self, audio):
        """Return a boolean tensor that is True when every sample's magnitude is below 1e-6."""
        return torch.all(audio.abs() < 1e-6)

    def _get_project_root(self) -> str:
        """Get project root directory path (the parent of this file's directory)."""
        current_file = os.path.abspath(__file__)
        return os.path.dirname(os.path.dirname(current_file))

    def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype:
        """Get VAE dtype based on device: bfloat16 on "cuda"/"xpu", else self.dtype."""
        device = device or self.device
        return torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype

    def _format_instruction(self, instruction: str) -> str:
        """Format instruction to ensure it ends with colon."""
        if not instruction.endswith(":"):
            instruction = instruction + ":"
return instruction def _load_audio_file(self, audio_file) -> Tuple[torch.Tensor, int]: """ Load audio file with ffmpeg backend, fallback to soundfile if failed. This handles CUDA dependency issues with torchcodec on HuggingFace Space. Args: audio_file: Path to the audio file Returns: Tuple of (audio_tensor, sample_rate) Raises: FileNotFoundError: If the audio file doesn't exist Exception: If all methods fail to load the audio """ # Check if file exists first if not os.path.exists(audio_file): raise FileNotFoundError(f"Audio file not found: {audio_file}") # Try torchaudio with explicit ffmpeg backend first try: audio, sr = torchaudio.load(audio_file, backend="ffmpeg") return audio, sr except Exception as e: logger.debug(f"[_load_audio_file] ffmpeg backend failed: {e}, trying soundfile fallback") # Fallback: use soundfile directly (most compatible) try: audio_np, sr = sf.read(audio_file) # soundfile returns [samples, channels] or [samples], convert to [channels, samples] audio = torch.from_numpy(audio_np).float() if audio.dim() == 1: # Mono: [samples] -> [1, samples] audio = audio.unsqueeze(0) else: # Stereo: [samples, channels] -> [channels, samples] audio = audio.T return audio, sr except Exception as e: logger.error(f"[_load_audio_file] All methods failed to load audio: {audio_file}, error: {e}") raise def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor: """ Normalize audio to stereo 48kHz format. 
Args: audio: Audio tensor [channels, samples] or [samples] sr: Sample rate Returns: Normalized audio tensor [2, samples] at 48kHz """ # Convert to stereo (duplicate channel if mono) if audio.shape[0] == 1: audio = torch.cat([audio, audio], dim=0) # Keep only first 2 channels audio = audio[:2] # Resample to 48kHz if needed if sr != 48000: audio = torchaudio.transforms.Resample(sr, 48000)(audio) # Clamp values to [-1.0, 1.0] audio = torch.clamp(audio, -1.0, 1.0) return audio def _enforce_audio_duration_limits( self, audio: torch.Tensor, sample_rate: int = 48000, min_duration: float = 10.0, max_duration: float = 600.0 ) -> torch.Tensor: """ Enforce audio duration limits by truncating or repeating. Args: audio: Audio tensor [channels, samples] at target sample rate sample_rate: Sample rate of the audio (default: 48000) min_duration: Minimum duration in seconds (default: 10.0) max_duration: Maximum duration in seconds (default: 600.0) Returns: Audio tensor with enforced duration limits """ current_samples = audio.shape[-1] current_duration = current_samples / sample_rate min_samples = int(min_duration * sample_rate) max_samples = int(max_duration * sample_rate) # If audio is longer than max_duration, truncate if current_samples > max_samples: logger.info(f"[_enforce_audio_duration_limits] Truncating audio from {current_duration:.1f}s to {max_duration:.1f}s") audio = audio[..., :max_samples] # If audio is shorter than min_duration, repeat to fill elif current_samples < min_samples: logger.info(f"[_enforce_audio_duration_limits] Repeating audio from {current_duration:.1f}s to reach {min_duration:.1f}s") # Calculate how many times to repeat repeat_times = int(math.ceil(min_samples / current_samples)) # Repeat along the time dimension audio = audio.repeat(1, repeat_times) # Truncate to exactly min_samples audio = audio[..., :min_samples] return audio def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> 
List[Optional[str]]: """Normalize audio_code_hints to list of correct length.""" if audio_code_hints is None: normalized = [None] * batch_size elif isinstance(audio_code_hints, str): normalized = [audio_code_hints] * batch_size elif len(audio_code_hints) == 1 and batch_size > 1: normalized = audio_code_hints * batch_size elif len(audio_code_hints) != batch_size: # Pad or truncate to match batch_size normalized = list(audio_code_hints[:batch_size]) while len(normalized) < batch_size: normalized.append(None) else: normalized = list(audio_code_hints) # Clean up: convert empty strings to None normalized = [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized] return normalized def _normalize_instructions(self, instructions: Optional[Union[str, List[str]]], batch_size: int, default: Optional[str] = None) -> List[str]: """Normalize instructions to list of correct length.""" if instructions is None: default_instruction = default or DEFAULT_DIT_INSTRUCTION return [default_instruction] * batch_size elif isinstance(instructions, str): return [instructions] * batch_size elif len(instructions) == 1: return instructions * batch_size elif len(instructions) != batch_size: # Pad or truncate to match batch_size normalized = list(instructions[:batch_size]) default_instruction = default or DEFAULT_DIT_INSTRUCTION while len(normalized) < batch_size: normalized.append(default_instruction) return normalized else: return list(instructions) def _format_lyrics(self, lyrics: str, language: str) -> str: """Format lyrics text with language header.""" return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>" def _pad_sequences(self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0) -> torch.Tensor: """Pad sequences to same length.""" return torch.stack([ torch.nn.functional.pad(seq, (0, max_length - len(seq)), 'constant', pad_value) for seq in sequences ]) def _extract_caption_and_language(self, metas: List[Union[str, Dict[str, Any]]], 
captions: List[str], vocal_languages: List[str]) -> Tuple[List[str], List[str]]: """Extract caption and language from metas with fallback to provided values.""" actual_captions = list(captions) actual_languages = list(vocal_languages) for i, meta in enumerate(metas): if i >= len(actual_captions): break meta_dict = None if isinstance(meta, str): parsed = self._parse_metas([meta]) if parsed and isinstance(parsed[0], dict): meta_dict = parsed[0] elif isinstance(meta, dict): meta_dict = meta if meta_dict: if 'caption' in meta_dict and meta_dict['caption']: actual_captions[i] = str(meta_dict['caption']) if 'language' in meta_dict and meta_dict['language']: actual_languages[i] = str(meta_dict['language']) return actual_captions, actual_languages def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor: """ Encode audio to latents using VAE. Args: audio: Audio tensor [channels, samples] or [batch, channels, samples] Returns: Latents tensor [T, D] or [batch, T, D] """ # Save original dimension info BEFORE modifying audio input_was_2d = (audio.dim() == 2) # Ensure batch dimension if input_was_2d: audio = audio.unsqueeze(0) # Ensure input is in VAE's dtype vae_input = audio.to(self.device).to(self.vae.dtype) # Encode to latents with torch.no_grad(): latents = self.vae.encode(vae_input).latent_dist.sample() # Cast back to model dtype latents = latents.to(self.dtype) # Transpose: [batch, d, T] -> [batch, T, d] latents = latents.transpose(1, 2) # Remove batch dimension if input didn't have it if input_was_2d: latents = latents.squeeze(0) return latents def _build_metadata_dict(self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str, duration: Optional[float] = None) -> Dict[str, Any]: """ Build metadata dictionary with default values. 
Args: bpm: BPM value (optional) key_scale: Key/scale string time_signature: Time signature string duration: Duration in seconds (optional) Returns: Metadata dictionary """ metadata_dict = {} if bpm: metadata_dict["bpm"] = bpm else: metadata_dict["bpm"] = "N/A" if key_scale.strip(): metadata_dict["keyscale"] = key_scale else: metadata_dict["keyscale"] = "N/A" if time_signature.strip() and time_signature != "N/A" and time_signature: metadata_dict["timesignature"] = time_signature else: metadata_dict["timesignature"] = "N/A" # Add duration if provided if duration is not None: metadata_dict["duration"] = f"{int(duration)} seconds" return metadata_dict def generate_instruction( self, task_type: str, track_name: Optional[str] = None, complete_track_classes: Optional[List[str]] = None ) -> str: if task_type == "text2music": return TASK_INSTRUCTIONS["text2music"] elif task_type == "repaint": return TASK_INSTRUCTIONS["repaint"] elif task_type == "cover": return TASK_INSTRUCTIONS["cover"] elif task_type == "extract": if track_name: # Convert to uppercase track_name_upper = track_name.upper() return TASK_INSTRUCTIONS["extract"].format(TRACK_NAME=track_name_upper) else: return TASK_INSTRUCTIONS["extract_default"] elif task_type == "lego": if track_name: # Convert to uppercase track_name_upper = track_name.upper() return TASK_INSTRUCTIONS["lego"].format(TRACK_NAME=track_name_upper) else: return TASK_INSTRUCTIONS["lego_default"] elif task_type == "complete": if complete_track_classes and len(complete_track_classes) > 0: # Convert to uppercase and join with " | " track_classes_upper = [t.upper() for t in complete_track_classes] complete_track_classes_str = " | ".join(track_classes_upper) return TASK_INSTRUCTIONS["complete"].format(TRACK_CLASSES=complete_track_classes_str) else: return TASK_INSTRUCTIONS["complete_default"] else: return TASK_INSTRUCTIONS["text2music"] def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]: if audio_file is None: return None try: # 
Load audio file with fallback backends audio, sr = self._load_audio_file(audio_file) logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}") logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}") logger.debug(f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / 48000.0} seconds") # Normalize to stereo 48kHz audio = self._normalize_audio_to_stereo_48k(audio, sr) is_silence = self.is_silence(audio) if is_silence: return None # Target length: 30 seconds at 48kHz target_frames = 30 * 48000 segment_frames = 10 * 48000 # 10 seconds per segment # If audio is less than 30 seconds, repeat to at least 30 seconds if audio.shape[-1] < target_frames: repeat_times = math.ceil(target_frames / audio.shape[-1]) audio = audio.repeat(1, repeat_times) # If audio is greater than or equal to 30 seconds, no operation needed # For all cases, select random 10-second segments from front, middle, and back # then concatenate them to form 30 seconds total_frames = audio.shape[-1] segment_size = total_frames // 3 # Front segment: [0, segment_size] front_start = random.randint(0, max(0, segment_size - segment_frames)) front_audio = audio[:, front_start:front_start + segment_frames] # Middle segment: [segment_size, 2*segment_size] middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames)) middle_audio = audio[:, middle_start:middle_start + segment_frames] # Back segment: [2*segment_size, total_frames] back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames)) back_audio = audio[:, back_start:back_start + segment_frames] # Concatenate three segments to form 30 seconds audio = torch.cat([front_audio, middle_audio, back_audio], dim=-1) return audio except Exception as e: logger.exception("[process_reference_audio] Error processing reference audio") return None def process_src_audio(self, audio_file) -> Optional[torch.Tensor]: if audio_file is None: return 
None try: # Load audio file with fallback backends audio, sr = self._load_audio_file(audio_file) # Normalize to stereo 48kHz audio = self._normalize_audio_to_stereo_48k(audio, sr) # Enforce duration limits (10-600 seconds) audio = self._enforce_audio_duration_limits(audio) return audio except Exception as e: logger.exception("[process_src_audio] Error processing source audio") return None def convert_src_audio_to_codes(self, audio_file) -> str: """ Convert uploaded source audio to audio codes string. Args: audio_file: Path to audio file or None Returns: Formatted codes string like '<|audio_code_123|><|audio_code_456|>...' or error message """ if audio_file is None: return "❌ Please upload source audio first" if self.model is None or self.vae is None: return "❌ Model not initialized. Please initialize the service first." try: # Process audio file processed_audio = self.process_src_audio(audio_file) if processed_audio is None: return "❌ Failed to process audio file" # Encode audio to latents using VAE with torch.no_grad(): with self._load_model_context("vae"): # Check if audio is silence if self.is_silence(processed_audio.unsqueeze(0)): return "❌ Audio file appears to be silent" # Encode to latents using helper method latents = self._encode_audio_to_latents(processed_audio) # [T, d] # Create attention mask for latents attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device) # Tokenize latents to get code indices with self._load_model_context("model"): # Prepare latents for tokenize: [T, d] -> [1, T, d] hidden_states = latents.unsqueeze(0) # [1, T, d] # Call tokenize method # tokenize returns: (quantized, indices, attention_mask) _, indices, _ = self.model.tokenize(hidden_states, self.silence_latent, attention_mask.unsqueeze(0)) # Format indices as code string # indices shape: [1, T_5Hz] or [1, T_5Hz, num_quantizers] # Flatten and convert to list indices_flat = indices.flatten().cpu().tolist() codes_string = "".join([f"<|audio_code_{idx}|>" 
for idx in indices_flat]) logger.info(f"[convert_src_audio_to_codes] Generated {len(indices_flat)} audio codes") return codes_string except Exception as e: error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}" logger.exception("[convert_src_audio_to_codes] Error converting audio to codes") return error_msg def prepare_batch_data( self, actual_batch_size, processed_src_audio, audio_duration, captions, lyrics, vocal_language, instruction, bpm, key_scale, time_signature ): pure_caption = self.extract_caption_from_sft_format(captions) captions_batch = [pure_caption] * actual_batch_size instructions_batch = [instruction] * actual_batch_size lyrics_batch = [lyrics] * actual_batch_size vocal_languages_batch = [vocal_language] * actual_batch_size # Calculate duration for metadata calculated_duration = None if processed_src_audio is not None: calculated_duration = processed_src_audio.shape[-1] / 48000.0 elif audio_duration is not None and audio_duration > 0: calculated_duration = audio_duration # Build metadata dict - use "N/A" as default for empty fields metadata_dict = self._build_metadata_dict(bpm, key_scale, time_signature, calculated_duration) # Format metadata - inference service accepts dict and will convert to string # Create a copy for each batch item (in case we modify it) metas_batch = [metadata_dict.copy() for _ in range(actual_batch_size)] return captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch def determine_task_type(self, task_type, audio_code_string): # Determine task type - repaint and lego tasks can have repainting parameters # Other tasks (cover, text2music, extract, complete) should NOT have repainting is_repaint_task = (task_type == "repaint") is_lego_task = (task_type == "lego") is_cover_task = (task_type == "cover") has_codes = False if isinstance(audio_code_string, list): has_codes = any((c or "").strip() for c in audio_code_string) else: has_codes = bool(audio_code_string and 
str(audio_code_string).strip()) if has_codes: is_cover_task = True # Both repaint and lego tasks can use repainting parameters for chunk mask can_use_repainting = is_repaint_task or is_lego_task return is_repaint_task, is_lego_task, is_cover_task, can_use_repainting def create_target_wavs(self, duration_seconds: float) -> torch.Tensor: try: # Ensure minimum precision of 100ms duration_seconds = max(0.1, round(duration_seconds, 1)) # Calculate frames for 48kHz stereo frames = int(duration_seconds * 48000) # Create silent stereo audio target_wavs = torch.zeros(2, frames) return target_wavs except Exception as e: logger.exception("[create_target_wavs] Error creating target audio") # Fallback to 30 seconds if error return torch.zeros(2, 30 * 48000) def prepare_padding_info( self, actual_batch_size, processed_src_audio, audio_duration, repainting_start, repainting_end, is_repaint_task, is_lego_task, is_cover_task, can_use_repainting, ): target_wavs_batch = [] # Store padding info for each batch item to adjust repainting coordinates padding_info_batch = [] for i in range(actual_batch_size): if processed_src_audio is not None: if is_cover_task: # Cover task: Use src_audio directly without padding batch_target_wavs = processed_src_audio padding_info_batch.append({ 'left_padding_duration': 0.0, 'right_padding_duration': 0.0 }) elif is_repaint_task or is_lego_task: # Repaint/lego task: May need padding for outpainting src_audio_duration = processed_src_audio.shape[-1] / 48000.0 # Determine actual end time if repainting_end is None or repainting_end < 0: actual_end = src_audio_duration else: actual_end = repainting_end left_padding_duration = max(0, -repainting_start) if repainting_start is not None else 0 right_padding_duration = max(0, actual_end - src_audio_duration) # Create padded audio left_padding_frames = int(left_padding_duration * 48000) right_padding_frames = int(right_padding_duration * 48000) if left_padding_frames > 0 or right_padding_frames > 0: # Pad the src 
audio batch_target_wavs = torch.nn.functional.pad( processed_src_audio, (left_padding_frames, right_padding_frames), 'constant', 0 ) else: batch_target_wavs = processed_src_audio # Store padding info for coordinate adjustment padding_info_batch.append({ 'left_padding_duration': left_padding_duration, 'right_padding_duration': right_padding_duration }) else: # Other tasks: Use src_audio directly without padding batch_target_wavs = processed_src_audio padding_info_batch.append({ 'left_padding_duration': 0.0, 'right_padding_duration': 0.0 }) else: padding_info_batch.append({ 'left_padding_duration': 0.0, 'right_padding_duration': 0.0 }) if audio_duration is not None and audio_duration > 0: batch_target_wavs = self.create_target_wavs(audio_duration) else: import random random_duration = random.uniform(10.0, 120.0) batch_target_wavs = self.create_target_wavs(random_duration) target_wavs_batch.append(batch_target_wavs) # Stack target_wavs into batch tensor # Ensure all tensors have the same shape by padding to max length max_frames = max(wav.shape[-1] for wav in target_wavs_batch) padded_target_wavs = [] for wav in target_wavs_batch: if wav.shape[-1] < max_frames: pad_frames = max_frames - wav.shape[-1] padded_wav = torch.nn.functional.pad(wav, (0, pad_frames), 'constant', 0) padded_target_wavs.append(padded_wav) else: padded_target_wavs.append(wav) target_wavs_tensor = torch.stack(padded_target_wavs, dim=0) # [batch_size, 2, frames] if can_use_repainting: # Repaint task: Set repainting parameters if repainting_start is None: repainting_start_batch = None elif isinstance(repainting_start, (int, float)): if processed_src_audio is not None: adjusted_start = repainting_start + padding_info_batch[0]['left_padding_duration'] repainting_start_batch = [adjusted_start] * actual_batch_size else: repainting_start_batch = [repainting_start] * actual_batch_size else: # List input - adjust each item repainting_start_batch = [] for i in range(actual_batch_size): if processed_src_audio 
is not None: adjusted_start = repainting_start[i] + padding_info_batch[i]['left_padding_duration'] repainting_start_batch.append(adjusted_start) else: repainting_start_batch.append(repainting_start[i]) # Handle repainting_end - use src audio duration if not specified or negative if processed_src_audio is not None: # If src audio is provided, use its duration as default end src_audio_duration = processed_src_audio.shape[-1] / 48000.0 if repainting_end is None or repainting_end < 0: # Use src audio duration (before padding), then adjust for padding adjusted_end = src_audio_duration + padding_info_batch[0]['left_padding_duration'] repainting_end_batch = [adjusted_end] * actual_batch_size else: # Adjust repainting_end to be relative to padded audio adjusted_end = repainting_end + padding_info_batch[0]['left_padding_duration'] repainting_end_batch = [adjusted_end] * actual_batch_size else: # No src audio - repainting doesn't make sense without it if repainting_end is None or repainting_end < 0: repainting_end_batch = None elif isinstance(repainting_end, (int, float)): repainting_end_batch = [repainting_end] * actual_batch_size else: # List input - adjust each item repainting_end_batch = [] for i in range(actual_batch_size): if processed_src_audio is not None: adjusted_end = repainting_end[i] + padding_info_batch[i]['left_padding_duration'] repainting_end_batch.append(adjusted_end) else: repainting_end_batch.append(repainting_end[i]) else: # All other tasks (cover, text2music, extract, complete): No repainting # Only repaint and lego tasks should have repainting parameters repainting_start_batch = None repainting_end_batch = None return repainting_start_batch, repainting_end_batch, target_wavs_tensor def _prepare_batch( self, captions: List[str], lyrics: List[str], keys: Optional[List[str]] = None, target_wavs: Optional[torch.Tensor] = None, refer_audios: Optional[List[List[torch.Tensor]]] = None, metas: Optional[List[Union[str, Dict[str, Any]]]] = None, vocal_languages: 
Optional[List[str]] = None, repainting_start: Optional[List[float]] = None, repainting_end: Optional[List[float]] = None, instructions: Optional[List[str]] = None, audio_code_hints: Optional[List[Optional[str]]] = None, audio_cover_strength: float = 1.0, ) -> Dict[str, Any]: """ Prepare batch data with fallbacks for missing inputs. Args: captions: List of text captions (optional, can be empty strings) lyrics: List of lyrics (optional, can be empty strings) keys: List of unique identifiers (optional) target_wavs: Target audio tensors (optional, will use silence if not provided) refer_audios: Reference audio tensors (optional, will use silence if not provided) metas: Metadata (optional, will use defaults if not provided) vocal_languages: Vocal languages (optional, will default to 'en') Returns: Batch dictionary ready for model input """ batch_size = len(captions) # Ensure silence_latent is on the correct device for batch preparation self._ensure_silence_latent_on_device() # Normalize audio_code_hints to batch list audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) # Synchronize CUDA to catch any pending errors from previous operations if torch.cuda.is_available(): torch.cuda.synchronize() for ii, refer_audio_list in enumerate(refer_audios): if refer_audio_list is None: continue if isinstance(refer_audio_list, list): for idx, refer_audio in enumerate(refer_audio_list): if refer_audio is not None and isinstance(refer_audio, torch.Tensor): refer_audio_list[idx] = refer_audio.to(self.device).to(torch.bfloat16) elif isinstance(refer_audio_list, torch.Tensor): refer_audios[ii] = refer_audio_list.to(self.device) if vocal_languages is None: vocal_languages = self._create_fallback_vocal_languages(batch_size) # Parse metas with fallbacks parsed_metas = self._parse_metas(metas) # Encode target_wavs to get target_latents with torch.no_grad(): target_latents_list = [] latent_lengths = [] # Use per-item wavs (may be adjusted if audio_code_hints are 
provided) target_wavs_list = [target_wavs[i].clone() for i in range(batch_size)] if target_wavs.device != self.device: target_wavs = target_wavs.to(self.device) with self._load_model_context("vae"): for i in range(batch_size): code_hint = audio_code_hints[i] # Prefer decoding from provided audio codes if code_hint: logger.info(f"[generate_music] Decoding audio codes for item {i}...") decoded_latents = self._decode_audio_codes_to_latents(code_hint) if decoded_latents is not None: decoded_latents = decoded_latents.squeeze(0) target_latents_list.append(decoded_latents) latent_lengths.append(decoded_latents.shape[0]) # Create a silent wav matching the latent length for downstream scaling frames_from_codes = max(1, int(decoded_latents.shape[0] * 1920)) target_wavs_list[i] = torch.zeros(2, frames_from_codes) continue # Fallback to VAE encode from audio current_wav = target_wavs_list[i].to(self.device).unsqueeze(0) if self.is_silence(current_wav): expected_latent_length = current_wav.shape[-1] // 1920 target_latent = self.silence_latent[0, :expected_latent_length, :] else: # Encode using helper method logger.info(f"[generate_music] Encoding target audio to latents for item {i}...") target_latent = self._encode_audio_to_latents(current_wav.squeeze(0)) # Remove batch dim for helper target_latents_list.append(target_latent) latent_lengths.append(target_latent.shape[0]) # Pad target_wavs to consistent length for outputs max_target_frames = max(wav.shape[-1] for wav in target_wavs_list) padded_target_wavs = [] for wav in target_wavs_list: if wav.shape[-1] < max_target_frames: pad_frames = max_target_frames - wav.shape[-1] wav = torch.nn.functional.pad(wav, (0, pad_frames), "constant", 0) padded_target_wavs.append(wav) target_wavs = torch.stack(padded_target_wavs) wav_lengths = torch.tensor([target_wavs.shape[-1]] * batch_size, dtype=torch.long) # Pad latents to same length max_latent_length = max(latent.shape[0] for latent in target_latents_list) max_latent_length = max(128, 
max_latent_length) padded_latents = [] for latent in target_latents_list: latent_length = latent.shape[0] if latent.shape[0] < max_latent_length: pad_length = max_latent_length - latent.shape[0] latent = torch.cat([latent, self.silence_latent[0, :pad_length, :]], dim=0) padded_latents.append(latent) target_latents = torch.stack(padded_latents) latent_masks = torch.stack([ torch.cat([ torch.ones(l, dtype=torch.long, device=self.device), torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device) ]) for l in latent_lengths ]) # Process instructions early so we can use them for task type detection # Use custom instructions if provided, otherwise use default instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) # Generate chunk_masks and spans based on repainting parameters # Also determine if this is a cover task (target audio provided without repainting) chunk_masks = [] spans = [] is_covers = [] # Store repainting latent ranges for later use in src_latents creation repainting_ranges = {} # {batch_idx: (start_latent, end_latent)} for i in range(batch_size): has_code_hint = audio_code_hints[i] is not None # Check if repainting is enabled for this batch item has_repainting = False if repainting_start is not None and repainting_end is not None: start_sec = repainting_start[i] if repainting_start[i] is not None else 0.0 end_sec = repainting_end[i] if end_sec is not None and end_sec > start_sec: # Repainting mode with outpainting support # The target_wavs may have been padded for outpainting # Need to calculate the actual position in the padded audio # Calculate padding (if start < 0, there's left padding) left_padding_sec = max(0, -start_sec) # Adjust positions to account for padding # In the padded audio, the original start is shifted by left_padding adjusted_start_sec = start_sec + left_padding_sec adjusted_end_sec = end_sec + left_padding_sec # Convert seconds to latent frames (audio_frames / 1920 = latent_frames) 
start_latent = int(adjusted_start_sec * self.sample_rate // 1920) end_latent = int(adjusted_end_sec * self.sample_rate // 1920) # Clamp to valid range start_latent = max(0, min(start_latent, max_latent_length - 1)) end_latent = max(start_latent + 1, min(end_latent, max_latent_length)) # Create mask: False = keep original, True = generate new mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device) mask[start_latent:end_latent] = True chunk_masks.append(mask) spans.append(("repainting", start_latent, end_latent)) # Store repainting range for later use repainting_ranges[i] = (start_latent, end_latent) has_repainting = True is_covers.append(False) # Repainting is not cover task else: # Full generation (no valid repainting range) chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device)) spans.append(("full", 0, max_latent_length)) # Determine task type from instruction, not from target_wavs # Only cover task should have is_cover=True instruction_i = instructions[i] if instructions and i < len(instructions) else "" instruction_lower = instruction_i.lower() # Cover task instruction: "Generate audio semantic tokens based on the given conditions:" is_cover = ("generate audio semantic tokens" in instruction_lower and "based on the given conditions" in instruction_lower) or has_code_hint is_covers.append(is_cover) else: # Full generation (no repainting parameters) chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device)) spans.append(("full", 0, max_latent_length)) # Determine task type from instruction, not from target_wavs # Only cover task should have is_cover=True instruction_i = instructions[i] if instructions and i < len(instructions) else "" instruction_lower = instruction_i.lower() # Cover task instruction: "Generate audio semantic tokens based on the given conditions:" is_cover = ("generate audio semantic tokens" in instruction_lower and "based on the given conditions" in 
instruction_lower) or has_code_hint is_covers.append(is_cover) chunk_masks = torch.stack(chunk_masks) is_covers = torch.BoolTensor(is_covers).to(self.device) # Create src_latents based on task type # For cover/extract/complete/lego/repaint tasks: src_latents = target_latents.clone() (if target_wavs provided) # For text2music task: src_latents = silence_latent (if no target_wavs or silence) # For repaint task: additionally replace inpainting region with silence_latent src_latents_list = [] silence_latent_tiled = self.silence_latent[0, :max_latent_length, :] for i in range(batch_size): # Check if target_wavs is provided and not silent (for extract/complete/lego/cover/repaint tasks) has_code_hint = audio_code_hints[i] is not None has_target_audio = has_code_hint or (target_wavs is not None and target_wavs[i].abs().sum() > 1e-6) if has_target_audio: # For tasks that use input audio (cover/extract/complete/lego/repaint) # Check if this item has repainting item_has_repainting = (i in repainting_ranges) if item_has_repainting: # Repaint task: src_latents = target_latents with inpainting region replaced by silence_latent # 1. Clone target_latents (encoded from src audio, preserving original audio) src_latent = target_latents[i].clone() # 2. 
Replace inpainting region with silence_latent start_latent, end_latent = repainting_ranges[i] src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent] src_latents_list.append(src_latent) else: # Cover/extract/complete/lego tasks: src_latents = target_latents.clone() # All these tasks need to base on input audio src_latents_list.append(target_latents[i].clone()) else: # Text2music task: src_latents = silence_latent (no input audio) # Use silence_latent for the full length src_latents_list.append(silence_latent_tiled.clone()) src_latents = torch.stack(src_latents_list) # Process audio_code_hints to generate precomputed_lm_hints_25Hz precomputed_lm_hints_25Hz_list = [] for i in range(batch_size): if audio_code_hints[i] is not None: # Decode audio codes to 25Hz latents logger.info(f"[generate_music] Decoding audio codes for LM hints for item {i}...") hints = self._decode_audio_codes_to_latents(audio_code_hints[i]) if hints is not None: # Pad or crop to match max_latent_length if hints.shape[1] < max_latent_length: pad_length = max_latent_length - hints.shape[1] pad = self.silence_latent # Match dims: hints is usually [1, T, D], silence_latent is [1, T, D] if pad.dim() == 2: pad = pad.unsqueeze(0) if hints.dim() == 2: hints = hints.unsqueeze(0) pad_chunk = pad[:, :pad_length, :] if pad_chunk.device != hints.device or pad_chunk.dtype != hints.dtype: pad_chunk = pad_chunk.to(device=hints.device, dtype=hints.dtype) hints = torch.cat([hints, pad_chunk], dim=1) elif hints.shape[1] > max_latent_length: hints = hints[:, :max_latent_length, :] precomputed_lm_hints_25Hz_list.append(hints[0]) # Remove batch dimension else: precomputed_lm_hints_25Hz_list.append(None) else: precomputed_lm_hints_25Hz_list.append(None) # Stack precomputed hints if any exist, otherwise set to None if any(h is not None for h in precomputed_lm_hints_25Hz_list): # For items without hints, use silence_latent as placeholder precomputed_lm_hints_25Hz = torch.stack([ h if h is not 
None else silence_latent_tiled for h in precomputed_lm_hints_25Hz_list ]) else: precomputed_lm_hints_25Hz = None # Extract caption and language from metas if available (from LM CoT output) # Fallback to user-provided values if not in metas actual_captions, actual_languages = self._extract_caption_and_language(parsed_metas, captions, vocal_languages) # Format text_inputs text_inputs = [] text_token_idss = [] text_attention_masks = [] lyric_token_idss = [] lyric_attention_masks = [] for i in range(batch_size): # Use custom instruction for this batch item instruction = self._format_instruction(instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION) actual_caption = actual_captions[i] actual_language = actual_languages[i] # Format text prompt with custom instruction (using LM-generated caption if available) text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i]) # Tokenize text text_inputs_dict = self.text_tokenizer( text_prompt, padding="longest", truncation=True, max_length=256, return_tensors="pt", ) text_token_ids = text_inputs_dict.input_ids[0] text_attention_mask = text_inputs_dict.attention_mask[0].bool() # Format and tokenize lyrics (using LM-generated language if available) lyrics_text = self._format_lyrics(lyrics[i], actual_language) lyrics_inputs_dict = self.text_tokenizer( lyrics_text, padding="longest", truncation=True, max_length=2048, return_tensors="pt", ) lyric_token_ids = lyrics_inputs_dict.input_ids[0] lyric_attention_mask = lyrics_inputs_dict.attention_mask[0].bool() # Build full text input text_input = text_prompt + "\n\n" + lyrics_text text_inputs.append(text_input) text_token_idss.append(text_token_ids) text_attention_masks.append(text_attention_mask) lyric_token_idss.append(lyric_token_ids) lyric_attention_masks.append(lyric_attention_mask) # Pad tokenized sequences max_text_length = max(len(seq) for seq in text_token_idss) padded_text_token_idss = self._pad_sequences(text_token_idss, max_text_length, 
self.text_tokenizer.pad_token_id) padded_text_attention_masks = self._pad_sequences(text_attention_masks, max_text_length, 0) max_lyric_length = max(len(seq) for seq in lyric_token_idss) padded_lyric_token_idss = self._pad_sequences(lyric_token_idss, max_lyric_length, self.text_tokenizer.pad_token_id) padded_lyric_attention_masks = self._pad_sequences(lyric_attention_masks, max_lyric_length, 0) padded_non_cover_text_input_ids = None padded_non_cover_text_attention_masks = None if audio_cover_strength < 1.0: non_cover_text_input_ids = [] non_cover_text_attention_masks = [] for i in range(batch_size): # Use custom instruction for this batch item instruction = self._format_instruction(DEFAULT_DIT_INSTRUCTION) # Extract caption from metas if available (from LM CoT output) actual_caption = actual_captions[i] # Format text prompt with custom instruction (using LM-generated caption if available) text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i]) # Tokenize text text_inputs_dict = self.text_tokenizer( text_prompt, padding="longest", truncation=True, max_length=256, return_tensors="pt", ) text_token_ids = text_inputs_dict.input_ids[0] non_cover_text_attention_mask = text_inputs_dict.attention_mask[0].bool() non_cover_text_input_ids.append(text_token_ids) non_cover_text_attention_masks.append(non_cover_text_attention_mask) padded_non_cover_text_input_ids = self._pad_sequences(non_cover_text_input_ids, max_text_length, self.text_tokenizer.pad_token_id) padded_non_cover_text_attention_masks = self._pad_sequences(non_cover_text_attention_masks, max_text_length, 0) if audio_cover_strength < 1.0: assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None" assert padded_non_cover_text_attention_masks is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_attention_masks must not be None" # Prepare batch batch = { "keys": keys, "target_wavs": 
target_wavs.to(self.device), "refer_audioss": refer_audios, "wav_lengths": wav_lengths.to(self.device), "captions": captions, "lyrics": lyrics, "metas": parsed_metas, "vocal_languages": vocal_languages, "target_latents": target_latents, "src_latents": src_latents, "latent_masks": latent_masks, "chunk_masks": chunk_masks, "spans": spans, "text_inputs": text_inputs, "text_token_idss": padded_text_token_idss, "text_attention_masks": padded_text_attention_masks, "lyric_token_idss": padded_lyric_token_idss, "lyric_attention_masks": padded_lyric_attention_masks, "is_covers": is_covers, "precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz, "non_cover_text_input_ids": padded_non_cover_text_input_ids, "non_cover_text_attention_masks": padded_non_cover_text_attention_masks, } # to device for k, v in batch.items(): if isinstance(v, torch.Tensor): batch[k] = v.to(self.device) if torch.is_floating_point(v): batch[k] = v.to(self.dtype) return batch def infer_refer_latent(self, refer_audioss): refer_audio_order_mask = [] refer_audio_latents = [] # Ensure silence_latent is on the correct device self._ensure_silence_latent_on_device() def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor: """Normalize audio tensor to [2, T] on current device.""" if not isinstance(a, torch.Tensor): raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}") # Accept [T], [1, T], [2, T], [1, 2, T] if a.dim() == 3 and a.shape[0] == 1: a = a.squeeze(0) if a.dim() == 1: a = a.unsqueeze(0) if a.dim() != 2: raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}") if a.shape[0] == 1: a = torch.cat([a, a], dim=0) a = a[:2] return a def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor: """Ensure latent is [N, T, D] (3D) for packing.""" if z.dim() == 4 and z.shape[0] == 1: z = z.squeeze(0) if z.dim() == 2: z = z.unsqueeze(0) return z for batch_idx, refer_audios in enumerate(refer_audioss): if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0): 
refer_audio_latent = _ensure_latent_3d(self.silence_latent[:, :750, :]) refer_audio_latents.append(refer_audio_latent) refer_audio_order_mask.append(batch_idx) else: for refer_audio in refer_audios: refer_audio = _normalize_audio_2d(refer_audio) # Ensure input is in VAE's dtype vae_input = refer_audio.unsqueeze(0).to(self.vae.dtype) refer_audio_latent = self.vae.encode(vae_input).latent_dist.sample() # Cast back to model dtype refer_audio_latent = refer_audio_latent.to(self.dtype) refer_audio_latents.append(_ensure_latent_3d(refer_audio_latent.transpose(1, 2))) refer_audio_order_mask.append(batch_idx) refer_audio_latents = torch.cat(refer_audio_latents, dim=0) refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=self.device, dtype=torch.long) return refer_audio_latents, refer_audio_order_mask def infer_text_embeddings(self, text_token_idss): with torch.no_grad(): text_embeddings = self.text_encoder(input_ids=text_token_idss, lyric_attention_mask=None).last_hidden_state return text_embeddings def infer_lyric_embeddings(self, lyric_token_ids): with torch.no_grad(): lyric_embeddings = self.text_encoder.embed_tokens(lyric_token_ids) return lyric_embeddings def preprocess_batch(self, batch): # step 1: VAE encode latents, target_latents: N x T x d # target_latents: N x T x d target_latents = batch["target_latents"] src_latents = batch["src_latents"] attention_mask = batch["latent_masks"] audio_codes = batch.get("audio_codes", None) audio_attention_mask = attention_mask dtype = target_latents.dtype bs = target_latents.shape[0] device = target_latents.device # step 2: refer_audio timbre keys = batch["keys"] with self._load_model_context("vae"): refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = self.infer_refer_latent(batch["refer_audioss"]) if refer_audio_acoustic_hidden_states_packed.dtype != dtype: refer_audio_acoustic_hidden_states_packed = refer_audio_acoustic_hidden_states_packed.to(dtype) # step 4: chunk mask, N x T x d chunk_mask = 
batch["chunk_masks"] chunk_mask = chunk_mask.to(device).unsqueeze(-1).repeat(1, 1, target_latents.shape[2]) spans = batch["spans"] text_token_idss = batch["text_token_idss"] text_attention_mask = batch["text_attention_masks"] lyric_token_idss = batch["lyric_token_idss"] lyric_attention_mask = batch["lyric_attention_masks"] text_inputs = batch["text_inputs"] logger.info("[preprocess_batch] Inferring prompt embeddings...") with self._load_model_context("text_encoder"): text_hidden_states = self.infer_text_embeddings(text_token_idss) logger.info("[preprocess_batch] Inferring lyric embeddings...") lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss) is_covers = batch["is_covers"] # Get precomputed hints from batch if available precomputed_lm_hints_25Hz = batch.get("precomputed_lm_hints_25Hz", None) # Get non-cover text input ids and attention masks from batch if available non_cover_text_input_ids = batch.get("non_cover_text_input_ids", None) non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None) non_cover_text_hidden_states = None if non_cover_text_input_ids is not None: logger.info("[preprocess_batch] Inferring non-cover text embeddings...") non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids) return ( keys, text_inputs, src_latents, target_latents, # model inputs text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, audio_attention_mask, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, chunk_mask, spans, is_covers, audio_codes, lyric_token_idss, precomputed_lm_hints_25Hz, non_cover_text_hidden_states, non_cover_text_attention_masks, ) @torch.no_grad() def service_generate( self, captions: Union[str, List[str]], lyrics: Union[str, List[str]], keys: Optional[Union[str, List[str]]] = None, target_wavs: Optional[torch.Tensor] = None, refer_audios: Optional[List[List[torch.Tensor]]] = None, metas: Optional[Union[str, Dict[str, Any], List[Union[str, 
Dict[str, Any]]]]] = None, vocal_languages: Optional[Union[str, List[str]]] = None, infer_steps: int = 60, guidance_scale: float = 7.0, seed: Optional[Union[int, List[int]]] = None, return_intermediate: bool = False, repainting_start: Optional[Union[float, List[float]]] = None, repainting_end: Optional[Union[float, List[float]]] = None, instructions: Optional[Union[str, List[str]]] = None, audio_cover_strength: float = 1.0, use_adg: bool = False, cfg_interval_start: float = 0.0, cfg_interval_end: float = 1.0, shift: float = 1.0, audio_code_hints: Optional[Union[str, List[str]]] = None, infer_method: str = "ode", timesteps: Optional[List[float]] = None, ) -> Dict[str, Any]: """ Generate music from text inputs. Args: captions: Text caption(s) describing the music (optional, can be empty strings) lyrics: Lyric text(s) (optional, can be empty strings) keys: Unique identifier(s) (optional) target_wavs: Target audio tensor(s) for conditioning (optional) refer_audios: Reference audio tensor(s) for style transfer (optional) metas: Metadata dict(s) or string(s) (optional) vocal_languages: Language code(s) for lyrics (optional, defaults to 'en') infer_steps: Number of inference steps (default: 60) guidance_scale: Guidance scale for generation (default: 7.0) seed: Random seed (optional) return_intermediate: Whether to return intermediate results (default: False) repainting_start: Start time(s) for repainting region in seconds (optional) repainting_end: End time(s) for repainting region in seconds (optional) instructions: Instruction text(s) for generation (optional) audio_cover_strength: Strength of audio cover mode (default: 1.0) use_adg: Whether to use ADG (Adaptive Diffusion Guidance) (default: False) cfg_interval_start: Start of CFG interval (0.0-1.0, default: 0.0) cfg_interval_end: End of CFG interval (0.0-1.0, default: 1.0) Returns: Dictionary containing: - pred_wavs: Generated audio tensors - target_wavs: Input target audio (if provided) - vqvae_recon_wavs: VAE 
reconstruction of target - keys: Identifiers used - text_inputs: Formatted text inputs - sr: Sample rate - spans: Generation spans - time_costs: Timing information - seed_num: Seed used """ if self.config.is_turbo: # Limit inference steps to maximum 8 if infer_steps > 8: logger.warning(f"[service_generate] dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8") infer_steps = 8 # CFG parameters are not adjustable for dmd_gan (they will be ignored) # Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model # Convert single inputs to lists if isinstance(captions, str): captions = [captions] if isinstance(lyrics, str): lyrics = [lyrics] if isinstance(keys, str): keys = [keys] if isinstance(vocal_languages, str): vocal_languages = [vocal_languages] if isinstance(metas, (str, dict)): metas = [metas] # Convert repainting parameters to lists if isinstance(repainting_start, (int, float)): repainting_start = [repainting_start] if isinstance(repainting_end, (int, float)): repainting_end = [repainting_end] # Get batch size from captions batch_size = len(captions) # Normalize instructions and audio_code_hints to match batch size instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) if instructions is not None else None audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) if audio_code_hints is not None else None # Convert seed to list format if seed is None: seed_list = None elif isinstance(seed, list): seed_list = seed # Ensure we have enough seeds for batch size if len(seed_list) < batch_size: # Pad with last seed or random seeds import random while len(seed_list) < batch_size: seed_list.append(random.randint(0, 2**32 - 1)) elif len(seed_list) > batch_size: # Truncate to batch size seed_list = seed_list[:batch_size] else: # Single seed value - use for all batch items seed_list = [int(seed)] * batch_size # Don't set global random seed 
here - each item will use its own seed # Prepare batch batch = self._prepare_batch( captions=captions, lyrics=lyrics, keys=keys, target_wavs=target_wavs, refer_audios=refer_audios, metas=metas, vocal_languages=vocal_languages, repainting_start=repainting_start, repainting_end=repainting_end, instructions=instructions, audio_code_hints=audio_code_hints, audio_cover_strength=audio_cover_strength, ) processed_data = self.preprocess_batch(batch) ( keys, text_inputs, src_latents, target_latents, # model inputs text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, audio_attention_mask, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, chunk_mask, spans, is_covers, audio_codes, lyric_token_idss, precomputed_lm_hints_25Hz, non_cover_text_hidden_states, non_cover_text_attention_masks, ) = processed_data # Set generation parameters # Use seed_list if available, otherwise generate a single seed if seed_list is not None: # Pass seed list to model (will be handled there) seed_param = seed_list else: seed_param = random.randint(0, 2**32 - 1) # Ensure silence_latent is on the correct device before creating generate_kwargs self._ensure_silence_latent_on_device() generate_kwargs = { "text_hidden_states": text_hidden_states, "text_attention_mask": text_attention_mask, "lyric_hidden_states": lyric_hidden_states, "lyric_attention_mask": lyric_attention_mask, "refer_audio_acoustic_hidden_states_packed": refer_audio_acoustic_hidden_states_packed, "refer_audio_order_mask": refer_audio_order_mask, "src_latents": src_latents, "chunk_masks": chunk_mask, "is_covers": is_covers, "silence_latent": self.silence_latent, "seed": seed_param, "non_cover_text_hidden_states": non_cover_text_hidden_states, "non_cover_text_attention_mask": non_cover_text_attention_masks, "precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz, "audio_cover_strength": audio_cover_strength, "infer_method": infer_method, "infer_steps": infer_steps, "diffusion_guidance_sale": 
guidance_scale, "use_adg": use_adg, "cfg_interval_start": cfg_interval_start, "cfg_interval_end": cfg_interval_end, "shift": shift, } # Add custom timesteps if provided (convert to tensor) if timesteps is not None: generate_kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32) logger.info("[service_generate] Generating audio...") with self._load_model_context("model"): # Prepare condition tensors first (for LRC timestamp generation) encoder_hidden_states, encoder_attention_mask, context_latents = self.model.prepare_condition( text_hidden_states=text_hidden_states, text_attention_mask=text_attention_mask, lyric_hidden_states=lyric_hidden_states, lyric_attention_mask=lyric_attention_mask, refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask=refer_audio_order_mask, hidden_states=src_latents, attention_mask=torch.ones(src_latents.shape[0], src_latents.shape[1], device=src_latents.device, dtype=src_latents.dtype), silence_latent=self.silence_latent, src_latents=src_latents, chunk_masks=chunk_mask, is_covers=is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, ) outputs = self.model.generate_audio(**generate_kwargs) # Add intermediate information to outputs for extra_outputs outputs["src_latents"] = src_latents outputs["target_latents_input"] = target_latents # Input target latents (before generation) outputs["chunk_masks"] = chunk_mask outputs["spans"] = spans outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length # Add condition tensors for LRC timestamp generation outputs["encoder_hidden_states"] = encoder_hidden_states outputs["encoder_attention_mask"] = encoder_attention_mask outputs["context_latents"] = context_latents outputs["lyric_token_idss"] = lyric_token_idss return outputs def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=False): """ Decode latents using tiling to reduce VRAM usage. 
Uses overlap-discard strategy to avoid boundary artifacts. Args: latents: [Batch, Channels, Length] chunk_size: Size of latent chunk to process at once overlap: Overlap size in latent frames offload_wav_to_cpu: If True, offload decoded wav audio to CPU immediately to save VRAM """ B, C, T = latents.shape # If short enough, decode directly if T <= chunk_size: return self.vae.decode(latents).sample # Calculate stride (core size) stride = chunk_size - 2 * overlap if stride <= 0: raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}") num_steps = math.ceil(T / stride) if offload_wav_to_cpu: # Optimized path: offload wav to CPU immediately to save VRAM return self._tiled_decode_offload_cpu(latents, B, T, stride, overlap, num_steps) else: # Default path: keep everything on GPU return self._tiled_decode_gpu(latents, B, T, stride, overlap, num_steps) def _tiled_decode_gpu(self, latents, B, T, stride, overlap, num_steps): """Standard tiled decode keeping all data on GPU.""" decoded_audio_list = [] upsample_factor = None for i in tqdm(range(num_steps), desc="Decoding audio chunks"): # Core range in latents core_start = i * stride core_end = min(core_start + stride, T) # Window range (with overlap) win_start = max(0, core_start - overlap) win_end = min(T, core_end + overlap) # Extract chunk latent_chunk = latents[:, :, win_start:win_end] # Decode # [Batch, Channels, AudioSamples] audio_chunk = self.vae.decode(latent_chunk).sample # Determine upsample factor from the first chunk if upsample_factor is None: upsample_factor = audio_chunk.shape[-1] / latent_chunk.shape[-1] # Calculate trim amounts in audio samples # How much overlap was added at the start? added_start = core_start - win_start # latent frames trim_start = int(round(added_start * upsample_factor)) # How much overlap was added at the end? 
added_end = win_end - core_end # latent frames trim_end = int(round(added_end * upsample_factor)) # Trim audio audio_len = audio_chunk.shape[-1] end_idx = audio_len - trim_end if trim_end > 0 else audio_len audio_core = audio_chunk[:, :, trim_start:end_idx] decoded_audio_list.append(audio_core) # Concatenate final_audio = torch.cat(decoded_audio_list, dim=-1) return final_audio def _tiled_decode_offload_cpu(self, latents, B, T, stride, overlap, num_steps): """Optimized tiled decode that offloads to CPU immediately to save VRAM.""" # First pass: decode first chunk to get upsample_factor and audio channels first_core_start = 0 first_core_end = min(stride, T) first_win_start = 0 first_win_end = min(T, first_core_end + overlap) first_latent_chunk = latents[:, :, first_win_start:first_win_end] first_audio_chunk = self.vae.decode(first_latent_chunk).sample upsample_factor = first_audio_chunk.shape[-1] / first_latent_chunk.shape[-1] audio_channels = first_audio_chunk.shape[1] # Calculate total audio length and pre-allocate CPU tensor total_audio_length = int(round(T * upsample_factor)) final_audio = torch.zeros(B, audio_channels, total_audio_length, dtype=first_audio_chunk.dtype, device='cpu') # Process first chunk: trim and copy to CPU first_added_end = first_win_end - first_core_end first_trim_end = int(round(first_added_end * upsample_factor)) first_audio_len = first_audio_chunk.shape[-1] first_end_idx = first_audio_len - first_trim_end if first_trim_end > 0 else first_audio_len first_audio_core = first_audio_chunk[:, :, :first_end_idx] audio_write_pos = first_audio_core.shape[-1] final_audio[:, :, :audio_write_pos] = first_audio_core.cpu() # Free GPU memory del first_audio_chunk, first_audio_core, first_latent_chunk # Process remaining chunks for i in tqdm(range(1, num_steps), desc="Decoding audio chunks"): # Core range in latents core_start = i * stride core_end = min(core_start + stride, T) # Window range (with overlap) win_start = max(0, core_start - overlap) 
win_end = min(T, core_end + overlap) # Extract chunk latent_chunk = latents[:, :, win_start:win_end] # Decode on GPU # [Batch, Channels, AudioSamples] audio_chunk = self.vae.decode(latent_chunk).sample # Calculate trim amounts in audio samples added_start = core_start - win_start # latent frames trim_start = int(round(added_start * upsample_factor)) added_end = win_end - core_end # latent frames trim_end = int(round(added_end * upsample_factor)) # Trim audio audio_len = audio_chunk.shape[-1] end_idx = audio_len - trim_end if trim_end > 0 else audio_len audio_core = audio_chunk[:, :, trim_start:end_idx] # Copy to pre-allocated CPU tensor core_len = audio_core.shape[-1] final_audio[:, :, audio_write_pos:audio_write_pos + core_len] = audio_core.cpu() audio_write_pos += core_len # Free GPU memory immediately del audio_chunk, audio_core, latent_chunk # Trim to actual length (in case of rounding differences) final_audio = final_audio[:, :, :audio_write_pos] return final_audio def generate_music( self, captions: str, lyrics: str, bpm: Optional[int] = None, key_scale: str = "", time_signature: str = "", vocal_language: str = "en", inference_steps: int = 8, guidance_scale: float = 7.0, use_random_seed: bool = True, seed: Optional[Union[str, float, int]] = -1, reference_audio=None, audio_duration: Optional[float] = None, batch_size: Optional[int] = None, src_audio=None, audio_code_string: Union[str, List[str]] = "", repainting_start: float = 0.0, repainting_end: Optional[float] = None, instruction: str = DEFAULT_DIT_INSTRUCTION, audio_cover_strength: float = 1.0, task_type: str = "text2music", use_adg: bool = False, cfg_interval_start: float = 0.0, cfg_interval_end: float = 1.0, shift: float = 1.0, infer_method: str = "ode", use_tiled_decode: bool = True, timesteps: Optional[List[float]] = None, progress=None ) -> Dict[str, Any]: """ Main interface for music generation Returns: Dictionary containing: - audios: List of audio dictionaries with path, key, params - 
generation_info: Markdown-formatted generation information - status_message: Status message - extra_outputs: Dictionary with latents, masks, time_costs, etc. - success: Whether generation completed successfully - error: Error message if generation failed """ if progress is None: def progress(*args, **kwargs): pass if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None: return { "audios": [], "status_message": "❌ Model not fully initialized. Please initialize all components first.", "extra_outputs": {}, "success": False, "error": "Model not fully initialized", } def _has_audio_codes(v: Union[str, List[str]]) -> bool: if isinstance(v, list): return any((x or "").strip() for x in v) return bool(v and str(v).strip()) # Auto-detect task type based on audio_code_string # If audio_code_string is provided and not empty, use cover task # Otherwise, use text2music task (or keep current task_type if not text2music) if task_type == "text2music": if _has_audio_codes(audio_code_string): # User has provided audio codes, switch to cover task task_type = "cover" # Update instruction for cover task instruction = TASK_INSTRUCTIONS["cover"] logger.info("[generate_music] Starting generation...") if progress: progress(0.51, desc="Preparing inputs...") logger.info("[generate_music] Preparing inputs...") # Reset offload cost self.current_offload_cost = 0.0 # Caption and lyrics are optional - can be empty # Use provided batch_size or default actual_batch_size = batch_size if batch_size is not None else self.batch_size actual_batch_size = max(1, actual_batch_size) # Ensure at least 1 actual_seed_list, seed_value_for_ui = self.prepare_seeds(actual_batch_size, seed, use_random_seed) # Convert special values to None if audio_duration is not None and audio_duration <= 0: audio_duration = None # if seed is not None and seed < 0: # seed = None if repainting_end is not None and repainting_end < 0: repainting_end = None try: # 1. 
Process reference audio refer_audios = None if reference_audio is not None: logger.info("[generate_music] Processing reference audio...") processed_ref_audio = self.process_reference_audio(reference_audio) if processed_ref_audio is not None: # Convert to the format expected by the service: List[List[torch.Tensor]] # Each batch item has a list of reference audios refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)] else: refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)] # 2. Process source audio # If audio_code_string is provided, ignore src_audio and use codes instead processed_src_audio = None if src_audio is not None: # Check if audio codes are provided - if so, ignore src_audio if _has_audio_codes(audio_code_string): logger.info("[generate_music] Audio codes provided, ignoring src_audio and using codes instead") else: logger.info("[generate_music] Processing source audio...") processed_src_audio = self.process_src_audio(src_audio) # 3. 
Prepare batch data captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data( actual_batch_size, processed_src_audio, audio_duration, captions, lyrics, vocal_language, instruction, bpm, key_scale, time_signature ) is_repaint_task, is_lego_task, is_cover_task, can_use_repainting = self.determine_task_type(task_type, audio_code_string) repainting_start_batch, repainting_end_batch, target_wavs_tensor = self.prepare_padding_info( actual_batch_size, processed_src_audio, audio_duration, repainting_start, repainting_end, is_repaint_task, is_lego_task, is_cover_task, can_use_repainting ) progress(0.52, desc=f"Generating music (batch size: {actual_batch_size})...") # Prepare audio_code_hints - use if audio_code_string is provided # This works for both text2music (auto-switched to cover) and cover tasks audio_code_hints_batch = None if _has_audio_codes(audio_code_string): if isinstance(audio_code_string, list): audio_code_hints_batch = audio_code_string else: audio_code_hints_batch = [audio_code_string] * actual_batch_size should_return_intermediate = (task_type == "text2music") outputs = self.service_generate( captions=captions_batch, lyrics=lyrics_batch, metas=metas_batch, # Pass as dict, service will convert to string vocal_languages=vocal_languages_batch, refer_audios=refer_audios, # Already in List[List[torch.Tensor]] format target_wavs=target_wavs_tensor, # Shape: [batch_size, 2, frames] infer_steps=inference_steps, guidance_scale=guidance_scale, seed=actual_seed_list, # Pass list of seeds, one per batch item repainting_start=repainting_start_batch, repainting_end=repainting_end_batch, instructions=instructions_batch, # Pass instructions to service audio_cover_strength=audio_cover_strength, # Pass audio cover strength use_adg=use_adg, # Pass use_adg parameter cfg_interval_start=cfg_interval_start, # Pass CFG interval start cfg_interval_end=cfg_interval_end, # Pass CFG interval end shift=shift, # Pass shift parameter 
infer_method=infer_method, # Pass infer method (ode or sde) audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list return_intermediate=should_return_intermediate, timesteps=timesteps, # Pass custom timesteps if provided ) logger.info("[generate_music] Model generation completed. Decoding latents...") pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim] time_costs = outputs["time_costs"] time_costs["offload_time_cost"] = self.current_offload_cost logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}") logger.debug(f"[generate_music] time_costs: {time_costs}") if progress: progress(0.8, desc="Decoding audio...") logger.info("[generate_music] Decoding latents with VAE...") # Decode latents to audio start_time = time.time() with torch.no_grad(): with self._load_model_context("vae"): # Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length] pred_latents_for_decode = pred_latents.transpose(1, 2) # Ensure input is in VAE's dtype pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype) if use_tiled_decode: logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...") pred_wavs = self.tiled_decode(pred_latents_for_decode) # [batch, channels, samples] else: pred_wavs = self.vae.decode(pred_latents_for_decode).sample # Cast output to float32 for audio processing/saving pred_wavs = pred_wavs.to(torch.float32) end_time = time.time() time_costs["vae_decode_time_cost"] = end_time - start_time time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"] # Update offload cost one last time to include VAE offloading time_costs["offload_time_cost"] = self.current_offload_cost logger.info("[generate_music] VAE decode completed. 
Preparing audio tensors...") if progress: progress(0.99, desc="Preparing audio data...") # Prepare audio tensors (no file I/O here, no UUID generation) # pred_wavs is already [batch, channels, samples] format # Move to CPU and convert to float32 for return audio_tensors = [] for i in range(actual_batch_size): # Extract audio tensor: [channels, samples] format, CPU, float32 audio_tensor = pred_wavs[i].cpu().float() audio_tensors.append(audio_tensor) status_message = f"✅ Generation completed successfully!" logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.") # Extract intermediate information from outputs src_latents = outputs.get("src_latents") # [batch, T, D] target_latents_input = outputs.get("target_latents_input") # [batch, T, D] chunk_masks = outputs.get("chunk_masks") # [batch, T] spans = outputs.get("spans", []) # List of tuples latent_masks = outputs.get("latent_masks") # [batch, T] # Extract condition tensors for LRC timestamp generation encoder_hidden_states = outputs.get("encoder_hidden_states") encoder_attention_mask = outputs.get("encoder_attention_mask") context_latents = outputs.get("context_latents") lyric_token_idss = outputs.get("lyric_token_idss") # Move all tensors to CPU to save VRAM (detach to release computation graph) extra_outputs = { "pred_latents": pred_latents.detach().cpu() if pred_latents is not None else None, "target_latents": target_latents_input.detach().cpu() if target_latents_input is not None else None, "src_latents": src_latents.detach().cpu() if src_latents is not None else None, "chunk_masks": chunk_masks.detach().cpu() if chunk_masks is not None else None, "latent_masks": latent_masks.detach().cpu() if latent_masks is not None else None, "spans": spans, "time_costs": time_costs, "seed_value": seed_value_for_ui, # Condition tensors for LRC timestamp generation "encoder_hidden_states": encoder_hidden_states.detach().cpu() if encoder_hidden_states is not None else None, "encoder_attention_mask": 
encoder_attention_mask.detach().cpu() if encoder_attention_mask is not None else None, "context_latents": context_latents.detach().cpu() if context_latents is not None else None, "lyric_token_idss": lyric_token_idss.detach().cpu() if lyric_token_idss is not None else None, } # Build audios list with tensor data (no file paths, no UUIDs, handled outside) audios = [] for idx, audio_tensor in enumerate(audio_tensors): audio_dict = { "tensor": audio_tensor, # torch.Tensor [channels, samples], CPU, float32 "sample_rate": self.sample_rate, } audios.append(audio_dict) return { "audios": audios, "status_message": status_message, "extra_outputs": extra_outputs, "success": True, "error": None, } except Exception as e: error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}" logger.exception("[generate_music] Generation failed") # Clean up CUDA state after any error (especially important for CUDA errors) if torch.cuda.is_available(): try: torch.cuda.synchronize() except Exception: pass # Ignore sync errors during cleanup torch.cuda.empty_cache() return { "audios": [], "status_message": error_msg, "extra_outputs": {}, "success": False, "error": str(e), } @torch.no_grad() def get_lyric_timestamp( self, pred_latent: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, context_latents: torch.Tensor, lyric_token_ids: torch.Tensor, total_duration_seconds: float, vocal_language: str = "en", inference_steps: int = 8, seed: int = 42, custom_layers_config: Optional[Dict] = None, ) -> Dict[str, Any]: """ Generate lyrics timestamps from generated audio latents using cross-attention alignment. This method adds noise to the final pred_latent and re-infers one step to get cross-attention matrices, then uses DTW to align lyrics tokens with audio frames. 
Args: pred_latent: Generated latent tensor [batch, T, D] encoder_hidden_states: Cached encoder hidden states encoder_attention_mask: Cached encoder attention mask context_latents: Cached context latents lyric_token_ids: Tokenized lyrics tensor [batch, seq_len] total_duration_seconds: Total audio duration in seconds vocal_language: Language code for lyrics header parsing inference_steps: Number of inference steps (for noise level calculation) seed: Random seed for noise generation custom_layers_config: Dict mapping layer indices to head indices Returns: Dict containing: - lrc_text: LRC formatted lyrics with timestamps - sentence_timestamps: List of SentenceTimestamp objects - token_timestamps: List of TokenTimestamp objects - success: Whether generation succeeded - error: Error message if failed """ from transformers.cache_utils import EncoderDecoderCache, DynamicCache if self.model is None: return { "lrc_text": "", "sentence_timestamps": [], "token_timestamps": [], "success": False, "error": "Model not initialized" } if custom_layers_config is None: custom_layers_config = self.custom_layers_config try: # Move tensors to device device = self.device dtype = self.dtype pred_latent = pred_latent.to(device=device, dtype=dtype) encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype) encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype) context_latents = context_latents.to(device=device, dtype=dtype) bsz = pred_latent.shape[0] # Calculate noise level: t_last = 1.0 / inference_steps t_last_val = 1.0 / inference_steps t_curr_tensor = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype) x1 = pred_latent # Generate noise if seed is None: x0 = torch.randn_like(x1) else: generator = torch.Generator(device=device).manual_seed(int(seed)) x0 = torch.randn(x1.shape, generator=generator, device=device, dtype=dtype) # Add noise to pred_latent: xt = t * noise + (1 - t) * x1 xt = t_last_val * x0 + (1.0 - t_last_val) * x1 xt_in = xt 
t_in = t_curr_tensor # Get null condition embedding encoder_hidden_states_in = encoder_hidden_states encoder_attention_mask_in = encoder_attention_mask context_latents_in = context_latents latent_length = x1.shape[1] attention_mask = torch.ones(bsz, latent_length, device=device, dtype=dtype) attention_mask_in = attention_mask past_key_values = None # Run decoder with output_attentions=True with self._load_model_context("model"): decoder = self.model.decoder decoder_outputs = decoder( hidden_states=xt_in, timestep=t_in, timestep_r=t_in, attention_mask=attention_mask_in, encoder_hidden_states=encoder_hidden_states_in, use_cache=False, past_key_values=past_key_values, encoder_attention_mask=encoder_attention_mask_in, context_latents=context_latents_in, output_attentions=True, custom_layers_config=custom_layers_config, enable_early_exit=True ) # Extract cross-attention matrices if decoder_outputs[2] is None: return { "lrc_text": "", "sentence_timestamps": [], "token_timestamps": [], "success": False, "error": "Model did not return attentions" } cross_attns = decoder_outputs[2] # Tuple of tensors (some may be None) captured_layers_list = [] for layer_attn in cross_attns: # Skip None values (layers that didn't return attention) if layer_attn is None: continue # Only take conditional part (first half of batch) cond_attn = layer_attn[:bsz] layer_matrix = cond_attn.transpose(-1, -2) captured_layers_list.append(layer_matrix) if not captured_layers_list: return { "lrc_text": "", "sentence_timestamps": [], "token_timestamps": [], "success": False, "error": "No valid attention layers returned" } stacked = torch.stack(captured_layers_list) if bsz == 1: all_layers_matrix = stacked.squeeze(1) else: all_layers_matrix = stacked # Process lyric token IDs to extract pure lyrics if isinstance(lyric_token_ids, torch.Tensor): raw_lyric_ids = lyric_token_ids[0].tolist() else: raw_lyric_ids = lyric_token_ids # Parse header to find lyrics start position header_str = f"# 
Languages\n{vocal_language}\n\n# Lyric\n" header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False) start_idx = len(header_ids) # Find end of lyrics (before endoftext token) try: end_idx = raw_lyric_ids.index(151643) # <|endoftext|> token except ValueError: end_idx = len(raw_lyric_ids) pure_lyric_ids = raw_lyric_ids[start_idx:end_idx] pure_lyric_matrix = all_layers_matrix[:, :, start_idx:end_idx, :] # Create aligner and generate timestamps aligner = MusicStampsAligner(self.text_tokenizer) align_info = aligner.stamps_align_info( attention_matrix=pure_lyric_matrix, lyrics_tokens=pure_lyric_ids, total_duration_seconds=total_duration_seconds, custom_config=custom_layers_config, return_matrices=False, violence_level=2.0, medfilt_width=1, ) if align_info.get("calc_matrix") is None: return { "lrc_text": "", "sentence_timestamps": [], "token_timestamps": [], "success": False, "error": align_info.get("error", "Failed to process attention matrix") } # Generate timestamps result = aligner.get_timestamps_and_lrc( calc_matrix=align_info["calc_matrix"], lyrics_tokens=pure_lyric_ids, total_duration_seconds=total_duration_seconds ) return { "lrc_text": result["lrc_text"], "sentence_timestamps": result["sentence_timestamps"], "token_timestamps": result["token_timestamps"], "success": True, "error": None } except Exception as e: error_msg = f"Error generating timestamps: {str(e)}" logger.exception("[get_lyric_timestamp] Failed") return { "lrc_text": "", "sentence_timestamps": [], "token_timestamps": [], "success": False, "error": error_msg } @torch.no_grad() def get_lyric_score( self, pred_latent: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, context_latents: torch.Tensor, lyric_token_ids: torch.Tensor, vocal_language: str = "en", inference_steps: int = 8, seed: int = 42, custom_layers_config: Optional[Dict] = None, ) -> Dict[str, Any]: """ Calculate both LM and DiT alignment scores in one pass. 
- lm_score: Checks structural alignment using pure noise at t=1.0. - dit_score: Checks denoising alignment using regressed latents at t=1/steps. Args: pred_latent: Generated latent tensor [batch, T, D] encoder_hidden_states: Cached encoder hidden states encoder_attention_mask: Cached encoder attention mask context_latents: Cached context latents lyric_token_ids: Tokenized lyrics tensor [batch, seq_len] vocal_language: Language code for lyrics header parsing inference_steps: Number of inference steps (for noise level calculation) seed: Random seed for noise generation custom_layers_config: Dict mapping layer indices to head indices Returns: Dict containing: - lm_score: float - dit_score: float - success: Whether generation succeeded - error: Error message if failed """ from transformers.cache_utils import EncoderDecoderCache, DynamicCache if self.model is None: return { "lm_score": 0.0, "dit_score": 0.0, "success": False, "error": "Model not initialized" } if custom_layers_config is None: custom_layers_config = self.custom_layers_config try: # Move tensors to device device = self.device dtype = self.dtype pred_latent = pred_latent.to(device=device, dtype=dtype) encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype) encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype) context_latents = context_latents.to(device=device, dtype=dtype) bsz = pred_latent.shape[0] if seed is None: x0 = torch.randn_like(pred_latent) else: generator = torch.Generator(device=device).manual_seed(int(seed)) x0 = torch.randn(pred_latent.shape, generator=generator, device=device, dtype=dtype) # --- Input A: LM Score --- # t = 1.0, xt = Pure Noise t_lm = torch.tensor([1.0] * bsz, device=device, dtype=dtype) xt_lm = x0 # --- Input B: DiT Score --- # t = 1.0/steps, xt = Regressed Latent t_last_val = 1.0 / inference_steps t_dit = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype) # Flow Matching Regression: xt = t*x0 + (1-t)*x1 xt_dit = 
t_last_val * x0 + (1.0 - t_last_val) * pred_latent # Order: [Think_Batch, DiT_Batch] xt_in = torch.cat([xt_lm, xt_dit], dim=0) t_in = torch.cat([t_lm, t_dit], dim=0) # Duplicate conditions encoder_hidden_states_in = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0) encoder_attention_mask_in = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0) context_latents_in = torch.cat([context_latents, context_latents], dim=0) # Prepare Attention Mask latent_length = xt_in.shape[1] attention_mask_in = torch.ones(2 * bsz, latent_length, device=device, dtype=dtype) past_key_values = None # Run decoder with output_attentions=True with self._load_model_context("model"): decoder = self.model.decoder if hasattr(decoder, 'eval'): decoder.eval() decoder_outputs = decoder( hidden_states=xt_in, timestep=t_in, timestep_r=t_in, attention_mask=attention_mask_in, encoder_hidden_states=encoder_hidden_states_in, use_cache=False, past_key_values=past_key_values, encoder_attention_mask=encoder_attention_mask_in, context_latents=context_latents_in, output_attentions=True, custom_layers_config=custom_layers_config, enable_early_exit=True ) # Extract cross-attention matrices if decoder_outputs[2] is None: return { "lm_score": 0.0, "dit_score": 0.0, "success": False, "error": "Model did not return attentions" } cross_attns = decoder_outputs[2] # Tuple of tensors (some may be None) captured_layers_list = [] for layer_attn in cross_attns: if layer_attn is None: continue # Only take conditional part (first half of batch) layer_matrix = layer_attn.transpose(-1, -2) captured_layers_list.append(layer_matrix) if not captured_layers_list: return { "lm_score": 0.0, "dit_score": 0.0, "success": False, "error": "No valid attention layers returned" } stacked = torch.stack(captured_layers_list) all_layers_matrix_lm = stacked[:, :bsz, ...] all_layers_matrix_dit = stacked[:, bsz:, ...] 
if bsz == 1: all_layers_matrix_lm = all_layers_matrix_lm.squeeze(1) all_layers_matrix_dit = all_layers_matrix_dit.squeeze(1) else: pass # Process lyric token IDs to extract pure lyrics if isinstance(lyric_token_ids, torch.Tensor): raw_lyric_ids = lyric_token_ids[0].tolist() else: raw_lyric_ids = lyric_token_ids # Parse header to find lyrics start position header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n" header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False) start_idx = len(header_ids) # Find end of lyrics (before endoftext token) try: end_idx = raw_lyric_ids.index(151643) # <|endoftext|> token except ValueError: end_idx = len(raw_lyric_ids) pure_lyric_ids = raw_lyric_ids[start_idx:end_idx] if start_idx >= all_layers_matrix_lm.shape[-2]: # Check text dim return { "lm_score": 0.0, "dit_score": 0.0, "success": False, "error": "Lyrics indices out of bounds" } pure_matrix_lm = all_layers_matrix_lm[..., start_idx:end_idx, :] pure_matrix_dit = all_layers_matrix_dit[..., start_idx:end_idx, :] # Create aligner and calculate alignment info aligner = MusicLyricScorer(self.text_tokenizer) def calculate_single_score(matrix): """Helper to run aligner on a matrix""" info = aligner.lyrics_alignment_info( attention_matrix=matrix, token_ids=pure_lyric_ids, custom_config=custom_layers_config, return_matrices=False, medfilt_width=1, ) if info.get("energy_matrix") is None: return 0.0 res = aligner.calculate_score( energy_matrix=info["energy_matrix"], type_mask=info["type_mask"], path_coords=info["path_coords"], ) # Return the final score (check return key) return res.get("lyrics_score", res.get("final_score", 0.0)) lm_score = calculate_single_score(pure_matrix_lm) dit_score = calculate_single_score(pure_matrix_dit) return { "lm_score": lm_score, "dit_score": dit_score, "success": True, "error": None } except Exception as e: error_msg = f"Error generating score: {str(e)}" logger.exception("[get_lyric_score] Failed") return { "lm_score": 0.0, "dit_score": 
0.0, "success": False, "error": error_msg }