Spaces: Running on A100
| """ | |
| Business Logic Handler | |
| Encapsulates all data processing and business logic as a bridge between model and UI | |
| """ | |
| import os | |
| # Disable tokenizers parallelism to avoid fork warning | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # Disable torchcodec backend to avoid CUDA dependency issues on HuggingFace Space | |
| # This forces torchaudio to use ffmpeg/sox/soundfile backends instead | |
| os.environ["TORCHAUDIO_USE_TORCHCODEC"] = "0" | |
| import math | |
| from copy import deepcopy | |
| import tempfile | |
| import traceback | |
| import re | |
| import random | |
| import uuid | |
| import hashlib | |
| import json | |
| from contextlib import contextmanager | |
| from typing import Optional, Dict, Any, Tuple, List, Union | |
| import torch | |
| import torchaudio | |
| import soundfile as sf | |
| import time | |
| from tqdm import tqdm | |
| from loguru import logger | |
| import warnings | |
| from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM | |
| from transformers.generation.streamers import BaseStreamer | |
| from diffusers.models import AutoencoderOobleck | |
| from acestep.constants import ( | |
| TASK_INSTRUCTIONS, | |
| SFT_GEN_PROMPT, | |
| DEFAULT_DIT_INSTRUCTION, | |
| ) | |
| from acestep.dit_alignment_score import MusicStampsAligner, MusicLyricScorer | |
| warnings.filterwarnings("ignore") | |
class AceStepHandler:
    """ACE-Step Business Logic Handler"""
    # HuggingFace Space environment detection.
    # SPACE_ID is set by the HF Spaces runtime; used below to default
    # persistent storage to the /data volume.
    IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
| def __init__(self, persistent_storage_path: Optional[str] = None): | |
| self.model = None | |
| self.config = None | |
| self.device = "cpu" | |
| self.dtype = torch.float32 # Will be set based on device in initialize_service | |
| # HuggingFace Space persistent storage support | |
| if persistent_storage_path is None and self.IS_HUGGINGFACE_SPACE: | |
| persistent_storage_path = "/data" | |
| self.persistent_storage_path = persistent_storage_path | |
| # VAE for audio encoding/decoding | |
| self.vae = None | |
| # Text encoder and tokenizer | |
| self.text_encoder = None | |
| self.text_tokenizer = None | |
| # Silence latent for initialization | |
| self.silence_latent = None | |
| # Sample rate | |
| self.sample_rate = 48000 | |
| # Reward model (temporarily disabled) | |
| self.reward_model = None | |
| # Batch size | |
| self.batch_size = 2 | |
| # Custom layers config | |
| self.custom_layers_config = {2: [6], 3: [10, 11], 4: [3], 5: [8, 9], 6: [8]} | |
| self.offload_to_cpu = False | |
| self.offload_dit_to_cpu = False | |
| self.current_offload_cost = 0.0 | |
| # LoRA state | |
| self.lora_loaded = False | |
| self.use_lora = False | |
| self._base_decoder = None # Backup of original decoder | |
| def _get_checkpoint_dir(self) -> str: | |
| """Get checkpoint directory, prioritizing persistent storage if available""" | |
| if self.persistent_storage_path: | |
| return os.path.join(self.persistent_storage_path, "checkpoints") | |
| project_root = self._get_project_root() | |
| return os.path.join(project_root, "checkpoints") | |
| def get_available_checkpoints(self) -> str: | |
| """Return project root directory path""" | |
| checkpoint_dir = self._get_checkpoint_dir() | |
| if os.path.exists(checkpoint_dir): | |
| return [checkpoint_dir] | |
| else: | |
| return [] | |
| def get_available_acestep_v15_models(self) -> List[str]: | |
| """Scan and return all model directory names starting with 'acestep-v15-'""" | |
| checkpoint_dir = self._get_checkpoint_dir() | |
| models = [] | |
| if os.path.exists(checkpoint_dir): | |
| for item in os.listdir(checkpoint_dir): | |
| item_path = os.path.join(checkpoint_dir, item) | |
| if os.path.isdir(item_path) and item.startswith("acestep-v15-"): | |
| models.append(item) | |
| models.sort() | |
| return models | |
    # Model name to HuggingFace repository mapping.
    # Models in the same repo will be downloaded together by
    # _ensure_model_downloaded (the unified repo is fetched wholesale).
    MODEL_REPO_MAPPING = {
        # Main unified repository (contains acestep-v15-turbo, LM models, VAE, text encoder)
        "acestep-v15-turbo": "ACE-Step/Ace-Step1.5",
        "acestep-5Hz-lm-0.6B": "ACE-Step/Ace-Step1.5",
        "acestep-5Hz-lm-1.7B": "ACE-Step/Ace-Step1.5",
        "vae": "ACE-Step/Ace-Step1.5",
        "Qwen3-Embedding-0.6B": "ACE-Step/Ace-Step1.5",
        # Separate model repositories (each contains one model's files at repo root)
        "acestep-v15-base": "ACE-Step/acestep-v15-base",
        "acestep-v15-sft": "ACE-Step/acestep-v15-sft",
        "acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3",
    }
    # Default fallback repository for unknown models
    DEFAULT_REPO_ID = "ACE-Step/Ace-Step1.5"
| def _ensure_model_downloaded(self, model_name: str, checkpoint_dir: str) -> str: | |
| """ | |
| Ensure model is downloaded from HuggingFace Hub. | |
| Used for HuggingFace Space auto-download support. | |
| Supports multiple repositories: | |
| - Models in MODEL_REPO_MAPPING will be downloaded from their specific repo | |
| - Unknown models will try the DEFAULT_REPO_ID | |
| For separate model repos (acestep-v15-base, acestep-v15-sft, acestep-v15-turbo-shift3), | |
| downloads directly into the model subdirectory. | |
| Args: | |
| model_name: Model directory name (e.g., "acestep-v15-turbo", "acestep-v15-turbo-shift3") | |
| checkpoint_dir: Target checkpoint directory | |
| Returns: | |
| Path to the downloaded model | |
| """ | |
| from huggingface_hub import snapshot_download | |
| model_path = os.path.join(checkpoint_dir, model_name) | |
| # Check if model already exists | |
| if os.path.exists(model_path) and os.listdir(model_path): | |
| logger.info(f"Model {model_name} already exists at {model_path}") | |
| return model_path | |
| # Get repository ID for this model | |
| repo_id = self.MODEL_REPO_MAPPING.get(model_name, self.DEFAULT_REPO_ID) | |
| # Determine if this is a unified repo or a separate model repo | |
| is_unified_repo = repo_id == self.DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5" | |
| if is_unified_repo: | |
| # Unified repo: download entire repo to checkpoint_dir | |
| # The model will be in checkpoint_dir/model_name | |
| download_dir = checkpoint_dir | |
| logger.info(f"Downloading unified repository {repo_id} to {download_dir}...") | |
| else: | |
| # Separate model repo: download directly to model_path | |
| # The repo contains the model files directly, not in a subdirectory | |
| download_dir = model_path | |
| os.makedirs(download_dir, exist_ok=True) | |
| logger.info(f"Downloading model {model_name} from {repo_id} to {download_dir}...") | |
| try: | |
| snapshot_download( | |
| repo_id=repo_id, | |
| local_dir=download_dir, | |
| local_dir_use_symlinks=False, | |
| ) | |
| logger.info(f"Repository {repo_id} downloaded successfully to {download_dir}") | |
| except Exception as e: | |
| logger.error(f"Failed to download repository {repo_id}: {e}") | |
| raise | |
| return model_path | |
| def is_flash_attention_available(self) -> bool: | |
| """Check if flash attention is available on the system""" | |
| try: | |
| import flash_attn | |
| return True | |
| except ImportError: | |
| return False | |
| def is_turbo_model(self) -> bool: | |
| """Check if the currently loaded model is a turbo model""" | |
| if self.config is None: | |
| return False | |
| return getattr(self.config, 'is_turbo', False) | |
| def load_lora(self, lora_path: str) -> str: | |
| """Load LoRA adapter into the decoder. | |
| Args: | |
| lora_path: Path to the LoRA adapter directory (containing adapter_config.json) | |
| Returns: | |
| Status message | |
| """ | |
| if self.model is None: | |
| return "❌ Model not initialized. Please initialize service first." | |
| if not lora_path or not lora_path.strip(): | |
| return "❌ Please provide a LoRA path." | |
| lora_path = lora_path.strip() | |
| # Check if path exists | |
| if not os.path.exists(lora_path): | |
| return f"❌ LoRA path not found: {lora_path}" | |
| # Check if it's a valid PEFT adapter directory | |
| config_file = os.path.join(lora_path, "adapter_config.json") | |
| if not os.path.exists(config_file): | |
| return f"❌ Invalid LoRA adapter: adapter_config.json not found in {lora_path}" | |
| try: | |
| from peft import PeftModel, PeftConfig | |
| except ImportError: | |
| return "❌ PEFT library not installed. Please install with: pip install peft" | |
| try: | |
| # Backup base decoder if not already backed up | |
| if self._base_decoder is None: | |
| import copy | |
| self._base_decoder = copy.deepcopy(self.model.decoder) | |
| logger.info("Base decoder backed up") | |
| else: | |
| # Restore base decoder before loading new LoRA | |
| self.model.decoder = copy.deepcopy(self._base_decoder) | |
| logger.info("Restored base decoder before loading new LoRA") | |
| # Load PEFT adapter | |
| logger.info(f"Loading LoRA adapter from {lora_path}") | |
| self.model.decoder = PeftModel.from_pretrained( | |
| self.model.decoder, | |
| lora_path, | |
| is_trainable=False, | |
| ) | |
| self.model.decoder = self.model.decoder.to(self.device).to(self.dtype) | |
| self.model.decoder.eval() | |
| self.lora_loaded = True | |
| self.use_lora = True # Enable LoRA by default after loading | |
| logger.info(f"LoRA adapter loaded successfully from {lora_path}") | |
| return f"✅ LoRA loaded from {lora_path}" | |
| except Exception as e: | |
| logger.exception("Failed to load LoRA adapter") | |
| return f"❌ Failed to load LoRA: {str(e)}" | |
| def unload_lora(self) -> str: | |
| """Unload LoRA adapter and restore base decoder. | |
| Returns: | |
| Status message | |
| """ | |
| if not self.lora_loaded: | |
| return "⚠️ No LoRA adapter loaded." | |
| if self._base_decoder is None: | |
| return "❌ Base decoder backup not found. Cannot restore." | |
| try: | |
| import copy | |
| # Restore base decoder | |
| self.model.decoder = copy.deepcopy(self._base_decoder) | |
| self.model.decoder = self.model.decoder.to(self.device).to(self.dtype) | |
| self.model.decoder.eval() | |
| self.lora_loaded = False | |
| self.use_lora = False | |
| logger.info("LoRA unloaded, base decoder restored") | |
| return "✅ LoRA unloaded, using base model" | |
| except Exception as e: | |
| logger.exception("Failed to unload LoRA") | |
| return f"❌ Failed to unload LoRA: {str(e)}" | |
| def set_use_lora(self, use_lora: bool) -> str: | |
| """Toggle LoRA usage for inference. | |
| Args: | |
| use_lora: Whether to use LoRA adapter | |
| Returns: | |
| Status message | |
| """ | |
| if use_lora and not self.lora_loaded: | |
| return "❌ No LoRA adapter loaded. Please load a LoRA first." | |
| self.use_lora = use_lora | |
| # Use PEFT's enable/disable methods if available | |
| if self.lora_loaded and hasattr(self.model.decoder, 'disable_adapter_layers'): | |
| try: | |
| if use_lora: | |
| self.model.decoder.enable_adapter_layers() | |
| logger.info("LoRA adapter enabled") | |
| else: | |
| self.model.decoder.disable_adapter_layers() | |
| logger.info("LoRA adapter disabled") | |
| except Exception as e: | |
| logger.warning(f"Could not toggle adapter layers: {e}") | |
| status = "enabled" if use_lora else "disabled" | |
| return f"✅ LoRA {status}" | |
| def get_lora_status(self) -> Dict[str, Any]: | |
| """Get current LoRA status. | |
| Returns: | |
| Dictionary with LoRA status info | |
| """ | |
| return { | |
| "loaded": self.lora_loaded, | |
| "active": self.use_lora, | |
| } | |
    def initialize_service(
        self,
        project_root: str,
        config_path: str,
        device: str = "auto",
        use_flash_attention: bool = False,
        compile_model: bool = False,
        offload_to_cpu: bool = False,
        offload_dit_to_cpu: bool = False,
        quantization: Optional[str] = None,
        # Shared components (for multi-model setup to save memory)
        shared_vae = None,
        shared_text_encoder = None,
        shared_text_tokenizer = None,
        shared_silence_latent = None,
    ) -> Tuple[str, bool]:
        """
        Initialize DiT model service.

        Loads, in order: the main ACE-Step model (auto-downloading if missing),
        the silence latent, the VAE, and the text encoder/tokenizer — using the
        supplied shared instances where provided.

        Args:
            project_root: Project root path (may be checkpoints directory, will be handled automatically)
            config_path: Model config directory name (e.g., "acestep-v15-turbo")
            device: Device type ("auto" resolves to xpu > cuda > mps > cpu)
            use_flash_attention: Whether to use flash attention (requires flash_attn package)
            compile_model: Whether to use torch.compile to optimize the model
            offload_to_cpu: Whether to offload models to CPU when not in use
            offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True)
            quantization: Optional torchao mode ("int8_weight_only", "fp8_weight_only",
                "w8a8_dynamic"); requires compile_model=True and torchao installed
            shared_vae: Optional shared VAE instance (for multi-model setup)
            shared_text_encoder: Optional shared text encoder instance (for multi-model setup)
            shared_text_tokenizer: Optional shared text tokenizer instance (for multi-model setup)
            shared_silence_latent: Optional shared silence latent tensor (for multi-model setup)

        Returns:
            (status_message, enable_generate_button)
        """
        try:
            # Resolve "auto" to the best available accelerator: xpu > cuda > mps > cpu
            if device == "auto":
                if hasattr(torch, 'xpu') and torch.xpu.is_available():
                    device = "xpu"
                elif torch.cuda.is_available():
                    device = "cuda"
                elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                    device = "mps"
                else:
                    device = "cpu"
            status_msg = ""
            self.device = device
            self.offload_to_cpu = offload_to_cpu
            self.offload_dit_to_cpu = offload_dit_to_cpu
            # Set dtype based on device: bfloat16 for cuda/xpu, float32 otherwise
            self.dtype = torch.bfloat16 if device in ["cuda","xpu"] else torch.float32
            self.quantization = quantization
            if self.quantization is not None:
                # NOTE(review): assert is stripped under `python -O`; the compile
                # requirement would then be silently skipped
                assert compile_model, "Quantization requires compile_model to be True"
                try:
                    import torchao
                except ImportError:
                    raise ImportError("torchao is required for quantization but is not installed. Please install torchao to use quantization features.")
            # Auto-detect project root (independent of passed project_root parameter)
            actual_project_root = self._get_project_root()
            checkpoint_dir = self._get_checkpoint_dir()
            os.makedirs(checkpoint_dir, exist_ok=True)
            # 1. Load main model
            # config_path is relative path (e.g., "acestep-v15-turbo"), concatenate to checkpoints directory
            # If config_path is None (HuggingFace Space with empty checkpoint), use default and auto-download
            if config_path is None:
                config_path = "acestep-v15-turbo"
                logger.info(f"[initialize_service] config_path is None, using default: {config_path}")
            acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
            # Auto-download model if not exists (HuggingFace Space support)
            if not os.path.exists(acestep_v15_checkpoint_path):
                acestep_v15_checkpoint_path = self._ensure_model_downloaded(config_path, checkpoint_dir)
            if os.path.exists(acestep_v15_checkpoint_path):
                # Determine attention implementation; flash attention forces bf16
                if use_flash_attention and self.is_flash_attention_available():
                    attn_implementation = "flash_attention_2"
                    self.dtype = torch.bfloat16
                else:
                    attn_implementation = "sdpa"
                try:
                    logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_implementation}")
                    self.model = AutoModel.from_pretrained(
                        acestep_v15_checkpoint_path,
                        trust_remote_code=True,
                        attn_implementation=attn_implementation,
                        dtype="bfloat16"
                    )
                except Exception as e:
                    logger.warning(f"[initialize_service] Failed to load model with {attn_implementation}: {e}")
                    if attn_implementation == "sdpa":
                        # sdpa can fail on some stacks; eager is the universal fallback
                        logger.info("[initialize_service] Falling back to eager attention")
                        attn_implementation = "eager"
                        self.model = AutoModel.from_pretrained(
                            acestep_v15_checkpoint_path,
                            trust_remote_code=True,
                            attn_implementation=attn_implementation
                        )
                    else:
                        raise e
                # Record the implementation actually used so downstream code can query it
                self.model.config._attn_implementation = attn_implementation
                self.config = self.model.config
                # Move model to device and set dtype
                if not self.offload_to_cpu:
                    self.model = self.model.to(device).to(self.dtype)
                else:
                    # If offload_to_cpu is True, check if we should keep DiT on GPU
                    if not self.offload_dit_to_cpu:
                        logger.info(f"[initialize_service] Keeping main model on {device} (persistent)")
                        self.model = self.model.to(device).to(self.dtype)
                    else:
                        self.model = self.model.to("cpu").to(self.dtype)
                self.model.eval()
                if compile_model:
                    self.model = torch.compile(self.model)
                if self.quantization is not None:
                    # torchao post-training quantization of the DiT weights
                    from torchao.quantization import quantize_
                    if self.quantization == "int8_weight_only":
                        from torchao.quantization import Int8WeightOnlyConfig
                        quant_config = Int8WeightOnlyConfig()
                    elif self.quantization == "fp8_weight_only":
                        from torchao.quantization import Float8WeightOnlyConfig
                        quant_config = Float8WeightOnlyConfig()
                    elif self.quantization == "w8a8_dynamic":
                        from torchao.quantization import Int8DynamicActivationInt8WeightConfig, MappingType
                        quant_config = Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC)
                    else:
                        raise ValueError(f"Unsupported quantization type: {self.quantization}")
                    quantize_(self.model, quant_config)
                    logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
                # Load or use shared silence_latent
                if shared_silence_latent is not None:
                    self.silence_latent = shared_silence_latent
                    logger.info("[initialize_service] Using shared silence_latent")
                else:
                    silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
                    if os.path.exists(silence_latent_path):
                        self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
                        # Always keep silence_latent on GPU - it's used in many places outside model context
                        # and is small enough that it won't significantly impact VRAM
                        self.silence_latent = self.silence_latent.to(device).to(self.dtype)
                    else:
                        raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
            else:
                raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}")
            # 2. Load or use shared VAE
            vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")  # Define for status message
            if shared_vae is not None:
                self.vae = shared_vae
                logger.info("[initialize_service] Using shared VAE")
            else:
                if os.path.exists(vae_checkpoint_path):
                    self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
                    # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
                    vae_dtype = self._get_vae_dtype(device)
                    if not self.offload_to_cpu:
                        self.vae = self.vae.to(device).to(vae_dtype)
                    else:
                        self.vae = self.vae.to("cpu").to(vae_dtype)
                    self.vae.eval()
                else:
                    raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
                # Compile only the locally loaded VAE (shared instances are left as-is)
                if compile_model:
                    self.vae = torch.compile(self.vae)
            # 3. Load or use shared text encoder and tokenizer
            text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")  # Define for status message
            if shared_text_encoder is not None and shared_text_tokenizer is not None:
                self.text_encoder = shared_text_encoder
                self.text_tokenizer = shared_text_tokenizer
                logger.info("[initialize_service] Using shared text encoder and tokenizer")
            else:
                if os.path.exists(text_encoder_path):
                    self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
                    self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
                    if not self.offload_to_cpu:
                        self.text_encoder = self.text_encoder.to(device).to(self.dtype)
                    else:
                        self.text_encoder = self.text_encoder.to("cpu").to(self.dtype)
                    self.text_encoder.eval()
                else:
                    raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
            # Determine actual attention implementation used
            actual_attn = getattr(self.config, "_attn_implementation", "eager")
            # Determine if using shared components
            # NOTE(review): `using_shared` is currently unused below — status lines
            # check shared_vae / shared_text_encoder directly
            using_shared = shared_vae is not None or shared_text_encoder is not None
            # Build the multi-line status report shown in the UI
            status_msg = f"✅ Model initialized successfully on {device}\n"
            status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
            if shared_vae is None:
                status_msg += f"VAE: {vae_checkpoint_path}\n"
            else:
                status_msg += f"VAE: shared\n"
            if shared_text_encoder is None:
                status_msg += f"Text encoder: {text_encoder_path}\n"
            else:
                status_msg += f"Text encoder: shared\n"
            status_msg += f"Dtype: {self.dtype}\n"
            status_msg += f"Attention: {actual_attn}\n"
            status_msg += f"Compiled: {compile_model}\n"
            status_msg += f"Offload to CPU: {self.offload_to_cpu}\n"
            status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}"
            return status_msg, True
        except Exception as e:
            # Any failure is reported to the UI rather than raised
            error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            logger.exception("[initialize_service] Error initializing model")
            return error_msg, False
| def _is_on_target_device(self, tensor, target_device): | |
| """Check if tensor is on the target device (handles cuda vs cuda:0 comparison).""" | |
| if tensor is None: | |
| return True | |
| target_type = "cpu" if target_device == "cpu" else "cuda" | |
| return tensor.device.type == target_type | |
| def _ensure_silence_latent_on_device(self): | |
| """Ensure silence_latent is on the correct device (self.device).""" | |
| if hasattr(self, "silence_latent") and self.silence_latent is not None: | |
| if not self._is_on_target_device(self.silence_latent, self.device): | |
| self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) | |
    def _move_module_recursive(self, module, target_device, dtype=None, visited=None):
        """
        Recursively move a module and all its submodules to the target device.

        This handles modules that may not be properly registered (e.g. plain
        nn.Module attributes that never made it into ``module._modules``).

        Args:
            module: torch.nn.Module to move.
            target_device: Device (string or torch.device) to move to.
            dtype: Optional dtype to also cast parameters to.
            visited: Internal set of id()s to guard against shared/cyclic submodules.
        """
        if visited is None:
            visited = set()
        # Skip modules we've already processed (shared submodules, cycles)
        module_id = id(module)
        if module_id in visited:
            return
        visited.add(module_id)
        # Move the module itself
        module.to(target_device)
        if dtype is not None:
            module.to(dtype)
        # Move all direct parameters — bypasses .to() in case it missed any
        for param_name, param in module._parameters.items():
            if param is not None and not self._is_on_target_device(param, target_device):
                module._parameters[param_name] = param.to(target_device)
                if dtype is not None:
                    module._parameters[param_name] = module._parameters[param_name].to(dtype)
        # Move all direct buffers (dtype deliberately untouched, e.g. integer buffers)
        for buf_name, buf in module._buffers.items():
            if buf is not None and not self._is_on_target_device(buf, target_device):
                module._buffers[buf_name] = buf.to(target_device)
        # Recursively process all submodules (registered and unregistered)
        for name, child in module._modules.items():
            if child is not None:
                self._move_module_recursive(child, target_device, dtype, visited)
        # Also check for any nn.Module attributes that might not be in _modules
        for attr_name in dir(module):
            if attr_name.startswith('_'):
                continue
            try:
                attr = getattr(module, attr_name, None)
                if isinstance(attr, torch.nn.Module) and id(attr) not in visited:
                    self._move_module_recursive(attr, target_device, dtype, visited)
            except Exception:
                # Best-effort: property getters may raise — skip those attributes
                pass
    def _recursive_to_device(self, model, device, dtype=None):
        """
        Recursively move all parameters and buffers of a model to the specified device.

        More thorough than a plain ``model.to()`` for some custom HuggingFace
        models: escalates through three strategies and logs if parameters are
        still stranded afterwards.

        Args:
            model: torch.nn.Module to relocate.
            device: Target device string (e.g. "cpu", "cuda") or torch.device.
            dtype: Optional dtype applied to floating-point tensors.
        """
        target_device = torch.device(device) if isinstance(device, str) else device
        # Method 1: Standard .to() call
        model.to(target_device)
        if dtype is not None:
            model.to(dtype)
        # Method 2: Use our thorough recursive moving for any missed modules
        self._move_module_recursive(model, target_device, dtype)
        # Method 3: Force move via state_dict if there are still parameters on wrong device
        wrong_device_params = []
        for name, param in model.named_parameters():
            if not self._is_on_target_device(param, device):
                wrong_device_params.append(name)
        if wrong_device_params and device != "cpu":
            logger.warning(f"[_recursive_to_device] {len(wrong_device_params)} parameters on wrong device, using state_dict method")
            # Get current state dict and move all tensors
            state_dict = model.state_dict()
            moved_state_dict = {}
            for key, value in state_dict.items():
                if isinstance(value, torch.Tensor):
                    moved_state_dict[key] = value.to(target_device)
                    # Only cast floating tensors; integer buffers keep their dtype
                    if dtype is not None and moved_state_dict[key].is_floating_point():
                        moved_state_dict[key] = moved_state_dict[key].to(dtype)
                else:
                    moved_state_dict[key] = value
            model.load_state_dict(moved_state_dict)
        # Synchronize CUDA to ensure all transfers are complete
        if device != "cpu" and torch.cuda.is_available():
            torch.cuda.synchronize()
        # Final verification: report anything still stranded (diagnostic only)
        if device != "cpu":
            still_wrong = []
            for name, param in model.named_parameters():
                if not self._is_on_target_device(param, device):
                    still_wrong.append(f"{name} on {param.device}")
            if still_wrong:
                logger.error(f"[_recursive_to_device] CRITICAL: {len(still_wrong)} parameters still on wrong device: {still_wrong[:10]}")
| def _load_model_context(self, model_name: str): | |
| """ | |
| Context manager to load a model to GPU and offload it back to CPU after use. | |
| Args: | |
| model_name: Name of the model to load ("text_encoder", "vae", "model") | |
| """ | |
| if not self.offload_to_cpu: | |
| yield | |
| return | |
| # If model is DiT ("model") and offload_dit_to_cpu is False, do not offload | |
| if model_name == "model" and not self.offload_dit_to_cpu: | |
| # Ensure it's on device if not already (should be handled by init, but safe to check) | |
| model = getattr(self, model_name, None) | |
| if model is not None: | |
| # Check if model is on CPU, if so move to device (one-time move if it was somehow on CPU) | |
| # We check the first parameter's device | |
| try: | |
| param = next(model.parameters()) | |
| if param.device.type == "cpu": | |
| logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)") | |
| self._recursive_to_device(model, self.device, self.dtype) | |
| if hasattr(self, "silence_latent"): | |
| self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) | |
| except StopIteration: | |
| pass | |
| yield | |
| return | |
| model = getattr(self, model_name, None) | |
| if model is None: | |
| yield | |
| return | |
| # Load to GPU | |
| logger.info(f"[_load_model_context] Loading {model_name} to {self.device}") | |
| start_time = time.time() | |
| if model_name == "vae": | |
| vae_dtype = self._get_vae_dtype() | |
| self._recursive_to_device(model, self.device, vae_dtype) | |
| else: | |
| self._recursive_to_device(model, self.device, self.dtype) | |
| if model_name == "model" and hasattr(self, "silence_latent"): | |
| self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) | |
| load_time = time.time() - start_time | |
| self.current_offload_cost += load_time | |
| logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s") | |
| try: | |
| yield | |
| finally: | |
| # Offload to CPU | |
| logger.info(f"[_load_model_context] Offloading {model_name} to CPU") | |
| start_time = time.time() | |
| self._recursive_to_device(model, "cpu") | |
| # NOTE: Do NOT offload silence_latent to CPU here! | |
| # silence_latent is used in many places outside of model context, | |
| # so it should stay on GPU to avoid device mismatch errors. | |
| torch.cuda.empty_cache() | |
| offload_time = time.time() - start_time | |
| self.current_offload_cost += offload_time | |
| logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s") | |
| def process_target_audio(self, audio_file) -> Optional[torch.Tensor]: | |
| """Process target audio""" | |
| if audio_file is None: | |
| return None | |
| try: | |
| # Load audio using soundfile | |
| audio_np, sr = sf.read(audio_file, dtype='float32') | |
| # Convert to torch: [samples, channels] or [samples] -> [channels, samples] | |
| if audio_np.ndim == 1: | |
| audio = torch.from_numpy(audio_np).unsqueeze(0) | |
| else: | |
| audio = torch.from_numpy(audio_np.T) | |
| # Normalize to stereo 48kHz | |
| audio = self._normalize_audio_to_stereo_48k(audio, sr) | |
| # Enforce duration limits (10-600 seconds) | |
| audio = self._enforce_audio_duration_limits(audio) | |
| return audio | |
| except Exception as e: | |
| logger.exception("[process_target_audio] Error processing target audio") | |
| return None | |
| def _parse_audio_code_string(self, code_str: str) -> List[int]: | |
| """Extract integer audio codes from prompt tokens like <|audio_code_123|>. | |
| Codes are clamped to valid range [0, 63999] (codebook size = 64000). | |
| """ | |
| if not code_str: | |
| return [] | |
| try: | |
| codes = [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)] | |
| # Clamp codes to valid range [0, 63999] | |
| MAX_AUDIO_CODE = 63999 | |
| clamped_codes = [] | |
| invalid_codes = [] | |
| for code in codes: | |
| if code < 0 or code > MAX_AUDIO_CODE: | |
| invalid_codes.append(code) | |
| clamped_code = max(0, min(code, MAX_AUDIO_CODE)) | |
| clamped_codes.append(clamped_code) | |
| else: | |
| clamped_codes.append(code) | |
| if invalid_codes: | |
| logger.warning(f"[_parse_audio_code_string] Found {len(invalid_codes)} codes outside valid range [0, {MAX_AUDIO_CODE}]: {invalid_codes[:5]}... (clamped to valid range)") | |
| return clamped_codes | |
| except Exception as e: | |
| logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}") | |
| return [] | |
    def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
        """
        Convert serialized audio code string into 25Hz latents using model quantizer/detokenizer.

        Args:
            code_str: Prompt text containing <|audio_code_N|> tokens.

        Returns:
            Latent tensor (per the inline comments, [1, T_25Hz, dim]), or None when
            the model lacks tokenizer/detokenizer, no codes are found, or decoding fails.
        """
        # Requires both the tokenizer (for its quantizer) and the detokenizer
        if not self.model or not hasattr(self.model, 'tokenizer') or not hasattr(self.model, 'detokenizer'):
            return None
        code_ids = self._parse_audio_code_string(code_str)
        if len(code_ids) == 0:
            return None
        try:
            with self._load_model_context("model"):
                quantizer = self.model.tokenizer.quantizer
                detokenizer = self.model.detokenizer
                # Get codebook size for validation
                # Default to 64000 (codebook size = 64000, valid range = 0-63999)
                codebook_size = getattr(quantizer, 'codebook_size', 64000)
                if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
                    codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
                # Validate code IDs are within valid range
                invalid_codes = [c for c in code_ids if c < 0 or c >= codebook_size]
                if invalid_codes:
                    logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {codebook_size}): {invalid_codes[:5]}...")
                    # Clamp invalid codes to valid range
                    code_ids = [max(0, min(c, codebook_size - 1)) for c in code_ids]
                # NOTE(review): fetched but not used below — the quantizer handles
                # its own depth internally; verify before removing
                num_quantizers = getattr(quantizer, "num_quantizers", 1)
                # Create indices tensor: [T_5Hz]
                indices = torch.tensor(code_ids, device=self.device, dtype=torch.long)  # [T_5Hz]
                indices = indices.unsqueeze(0).unsqueeze(-1)  # [1, T_5Hz, 1]
                # Synchronize to catch any CUDA errors before proceeding
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                # Get quantized representation from indices
                # The quantizer expects [batch, T_5Hz] format and handles quantizer dimension internally
                quantized = quantizer.get_output_from_indices(indices)
                if quantized.dtype != self.dtype:
                    quantized = quantized.to(self.dtype)
                # Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
                lm_hints_25hz = detokenizer(quantized)
                return lm_hints_25hz
        except Exception as e:
            logger.exception(f"[_decode_audio_codes_to_latents] Error decoding audio codes: {e}")
            # Clear CUDA error state
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return None
| def _create_default_meta(self) -> str: | |
| """Create default metadata string.""" | |
| return ( | |
| "- bpm: N/A\n" | |
| "- timesignature: N/A\n" | |
| "- keyscale: N/A\n" | |
| "- duration: 30 seconds\n" | |
| ) | |
| def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str: | |
| """Convert metadata dict to formatted string.""" | |
| bpm = meta_dict.get('bpm', meta_dict.get('tempo', 'N/A')) | |
| timesignature = meta_dict.get('timesignature', meta_dict.get('time_signature', 'N/A')) | |
| keyscale = meta_dict.get('keyscale', meta_dict.get('key', meta_dict.get('scale', 'N/A'))) | |
| duration = meta_dict.get('duration', meta_dict.get('length', 30)) | |
| # Format duration | |
| if isinstance(duration, (int, float)): | |
| duration = f"{int(duration)} seconds" | |
| elif not isinstance(duration, str): | |
| duration = "30 seconds" | |
| return ( | |
| f"- bpm: {bpm}\n" | |
| f"- timesignature: {timesignature}\n" | |
| f"- keyscale: {keyscale}\n" | |
| f"- duration: {duration}\n" | |
| ) | |
| def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]: | |
| """ | |
| Parse and normalize metadata with fallbacks. | |
| Args: | |
| metas: List of metadata (can be strings, dicts, or None) | |
| Returns: | |
| List of formatted metadata strings | |
| """ | |
| parsed_metas = [] | |
| for meta in metas: | |
| if meta is None: | |
| # Default fallback metadata | |
| parsed_meta = self._create_default_meta() | |
| elif isinstance(meta, str): | |
| # Already formatted string | |
| parsed_meta = meta | |
| elif isinstance(meta, dict): | |
| # Convert dict to formatted string | |
| parsed_meta = self._dict_to_meta_string(meta) | |
| else: | |
| # Fallback for any other type | |
| parsed_meta = self._create_default_meta() | |
| parsed_metas.append(parsed_meta) | |
| return parsed_metas | |
| def build_dit_inputs( | |
| self, | |
| task: str, | |
| instruction: Optional[str], | |
| caption: str, | |
| lyrics: str, | |
| metas: Optional[Union[str, Dict[str, Any]]] = None, | |
| vocal_language: str = "en", | |
| ) -> Tuple[str, str]: | |
| """ | |
| Build text inputs for the caption and lyric branches used by DiT. | |
| Args: | |
| task: Task name (e.g., text2music, cover, repaint); kept for logging/future branching. | |
| instruction: Instruction text; default fallback matches service_generate behavior. | |
| caption: Caption string (fallback if not in metas). | |
| lyrics: Lyrics string. | |
| metas: Metadata (str or dict); follows _parse_metas formatting. | |
| May contain 'caption' and 'language' fields from LM CoT output. | |
| vocal_language: Language code for lyrics section (fallback if not in metas). | |
| Returns: | |
| (caption_input_text, lyrics_input_text) | |
| Example: | |
| caption_input, lyrics_input = handler.build_dit_inputs( | |
| task="text2music", | |
| instruction=None, | |
| caption="A calm piano melody", | |
| lyrics="la la la", | |
| metas={"bpm": 90, "duration": 45, "caption": "LM generated caption", "language": "en"}, | |
| vocal_language="en", | |
| ) | |
| """ | |
| # Align instruction formatting with _prepare_batch | |
| final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION) | |
| # Extract caption and language from metas if available (from LM CoT output) | |
| # Fallback to user-provided values if not in metas | |
| actual_caption = caption | |
| actual_language = vocal_language | |
| if metas is not None: | |
| # Parse metas to dict if it's a string | |
| if isinstance(metas, str): | |
| # Try to parse as dict-like string or use as-is | |
| parsed_metas = self._parse_metas([metas]) | |
| if parsed_metas and isinstance(parsed_metas[0], dict): | |
| meta_dict = parsed_metas[0] | |
| else: | |
| meta_dict = {} | |
| elif isinstance(metas, dict): | |
| meta_dict = metas | |
| else: | |
| meta_dict = {} | |
| # Extract caption from metas if available | |
| if 'caption' in meta_dict and meta_dict['caption']: | |
| actual_caption = str(meta_dict['caption']) | |
| # Extract language from metas if available | |
| if 'language' in meta_dict and meta_dict['language']: | |
| actual_language = str(meta_dict['language']) | |
| parsed_meta = self._parse_metas([metas])[0] | |
| caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta) | |
| lyrics_input = self._format_lyrics(lyrics, actual_language) | |
| return caption_input, lyrics_input | |
    def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get text hidden states from text encoder.

        Args:
            text_prompt: Raw prompt text to encode.

        Returns:
            Tuple of (text_hidden_states, text_attention_mask): hidden states
            cast to self.dtype, and a bool attention mask on self.device.

        Raises:
            ValueError: If the tokenizer or text encoder is not initialized.
        """
        if self.text_tokenizer is None or self.text_encoder is None:
            raise ValueError("Text encoder not initialized")
        with self._load_model_context("text_encoder"):
            # Tokenize (truncated to at most 256 tokens)
            text_inputs = self.text_tokenizer(
                text_prompt,
                padding="longest",
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(self.device)
            text_attention_mask = text_inputs.attention_mask.to(self.device).bool()
            # Encode
            with torch.no_grad():
                text_outputs = self.text_encoder(text_input_ids)
            # Encoders differ in return convention: HF-style output object,
            # plain tuple, or a bare tensor — normalize all three.
            if hasattr(text_outputs, 'last_hidden_state'):
                text_hidden_states = text_outputs.last_hidden_state
            elif isinstance(text_outputs, tuple):
                text_hidden_states = text_outputs[0]
            else:
                text_hidden_states = text_outputs
            # Cast to the handler's working dtype before returning.
            text_hidden_states = text_hidden_states.to(self.dtype)
            return text_hidden_states, text_attention_mask
| def extract_caption_from_sft_format(self, caption: str) -> str: | |
| try: | |
| if "# Instruction" in caption and "# Caption" in caption: | |
| pattern = r'#\s*Caption\s*\n(.*?)(?:\n\s*#\s*Metas|$)' | |
| match = re.search(pattern, caption, re.DOTALL) | |
| if match: | |
| return match.group(1).strip() | |
| return caption | |
| except Exception as e: | |
| logger.exception("[extract_caption_from_sft_format] Error extracting caption") | |
| return caption | |
| def prepare_seeds(self, actual_batch_size, seed, use_random_seed): | |
| actual_seed_list: List[int] = [] | |
| seed_value_for_ui = "" | |
| if use_random_seed: | |
| # Generate brand new seeds and expose them back to the UI | |
| actual_seed_list = [random.randint(0, 2 ** 32 - 1) for _ in range(actual_batch_size)] | |
| seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list) | |
| else: | |
| # Parse seed input: can be a single number, comma-separated numbers, or -1 | |
| # If seed is a string, try to parse it as comma-separated values | |
| seed_list = [] | |
| if isinstance(seed, str): | |
| # Handle string input (e.g., "123,456" or "-1") | |
| seed_str_list = [s.strip() for s in seed.split(",")] | |
| for s in seed_str_list: | |
| if s == "-1" or s == "": | |
| seed_list.append(-1) | |
| else: | |
| try: | |
| seed_list.append(int(float(s))) | |
| except (ValueError, TypeError) as e: | |
| logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {e}") | |
| seed_list.append(-1) | |
| elif seed is None or (isinstance(seed, (int, float)) and seed < 0): | |
| # If seed is None or negative, use -1 for all items | |
| seed_list = [-1] * actual_batch_size | |
| elif isinstance(seed, (int, float)): | |
| # Single seed value | |
| seed_list = [int(seed)] | |
| else: | |
| # Fallback: use -1 | |
| seed_list = [-1] * actual_batch_size | |
| # Process seed list according to rules: | |
| # 1. If all are -1, generate different random seeds for each batch item | |
| # 2. If one non-negative seed is provided and batch_size > 1, first uses that seed, rest are random | |
| # 3. If more seeds than batch_size, use first batch_size seeds | |
| # Check if user provided only one non-negative seed (not -1) | |
| has_single_non_negative_seed = (len(seed_list) == 1 and seed_list[0] != -1) | |
| for i in range(actual_batch_size): | |
| if i < len(seed_list): | |
| seed_val = seed_list[i] | |
| else: | |
| # If not enough seeds provided, use -1 (will generate random) | |
| seed_val = -1 | |
| # Special case: if only one non-negative seed was provided and batch_size > 1, | |
| # only the first item uses that seed, others are random | |
| if has_single_non_negative_seed and actual_batch_size > 1 and i > 0: | |
| # Generate random seed for remaining items | |
| actual_seed_list.append(random.randint(0, 2 ** 32 - 1)) | |
| elif seed_val == -1: | |
| # Generate a random seed for this item | |
| actual_seed_list.append(random.randint(0, 2 ** 32 - 1)) | |
| else: | |
| actual_seed_list.append(int(seed_val)) | |
| seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list) | |
| return actual_seed_list, seed_value_for_ui | |
    def prepare_metadata(self, bpm, key_scale, time_signature):
        """Build metadata dict - use "N/A" as default for empty fields."""
        # Thin public wrapper; field normalization lives in _build_metadata_dict.
        return self._build_metadata_dict(bpm, key_scale, time_signature)
| def is_silence(self, audio): | |
| return torch.all(audio.abs() < 1e-6) | |
| def _get_project_root(self) -> str: | |
| """Get project root directory path.""" | |
| current_file = os.path.abspath(__file__) | |
| return os.path.dirname(os.path.dirname(current_file)) | |
| def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype: | |
| """Get VAE dtype based on device.""" | |
| device = device or self.device | |
| return torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype | |
| def _format_instruction(self, instruction: str) -> str: | |
| """Format instruction to ensure it ends with colon.""" | |
| if not instruction.endswith(":"): | |
| instruction = instruction + ":" | |
| return instruction | |
| def _load_audio_file(self, audio_file) -> Tuple[torch.Tensor, int]: | |
| """ | |
| Load audio file with ffmpeg backend, fallback to soundfile if failed. | |
| This handles CUDA dependency issues with torchcodec on HuggingFace Space. | |
| Args: | |
| audio_file: Path to the audio file | |
| Returns: | |
| Tuple of (audio_tensor, sample_rate) | |
| Raises: | |
| FileNotFoundError: If the audio file doesn't exist | |
| Exception: If all methods fail to load the audio | |
| """ | |
| # Check if file exists first | |
| if not os.path.exists(audio_file): | |
| raise FileNotFoundError(f"Audio file not found: {audio_file}") | |
| # Try torchaudio with explicit ffmpeg backend first | |
| try: | |
| audio, sr = torchaudio.load(audio_file, backend="ffmpeg") | |
| return audio, sr | |
| except Exception as e: | |
| logger.debug(f"[_load_audio_file] ffmpeg backend failed: {e}, trying soundfile fallback") | |
| # Fallback: use soundfile directly (most compatible) | |
| try: | |
| audio_np, sr = sf.read(audio_file) | |
| # soundfile returns [samples, channels] or [samples], convert to [channels, samples] | |
| audio = torch.from_numpy(audio_np).float() | |
| if audio.dim() == 1: | |
| # Mono: [samples] -> [1, samples] | |
| audio = audio.unsqueeze(0) | |
| else: | |
| # Stereo: [samples, channels] -> [channels, samples] | |
| audio = audio.T | |
| return audio, sr | |
| except Exception as e: | |
| logger.error(f"[_load_audio_file] All methods failed to load audio: {audio_file}, error: {e}") | |
| raise | |
| def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor: | |
| """ | |
| Normalize audio to stereo 48kHz format. | |
| Args: | |
| audio: Audio tensor [channels, samples] or [samples] | |
| sr: Sample rate | |
| Returns: | |
| Normalized audio tensor [2, samples] at 48kHz | |
| """ | |
| # Convert to stereo (duplicate channel if mono) | |
| if audio.shape[0] == 1: | |
| audio = torch.cat([audio, audio], dim=0) | |
| # Keep only first 2 channels | |
| audio = audio[:2] | |
| # Resample to 48kHz if needed | |
| if sr != 48000: | |
| audio = torchaudio.transforms.Resample(sr, 48000)(audio) | |
| # Clamp values to [-1.0, 1.0] | |
| audio = torch.clamp(audio, -1.0, 1.0) | |
| return audio | |
| def _enforce_audio_duration_limits( | |
| self, | |
| audio: torch.Tensor, | |
| sample_rate: int = 48000, | |
| min_duration: float = 10.0, | |
| max_duration: float = 600.0 | |
| ) -> torch.Tensor: | |
| """ | |
| Enforce audio duration limits by truncating or repeating. | |
| Args: | |
| audio: Audio tensor [channels, samples] at target sample rate | |
| sample_rate: Sample rate of the audio (default: 48000) | |
| min_duration: Minimum duration in seconds (default: 10.0) | |
| max_duration: Maximum duration in seconds (default: 600.0) | |
| Returns: | |
| Audio tensor with enforced duration limits | |
| """ | |
| current_samples = audio.shape[-1] | |
| current_duration = current_samples / sample_rate | |
| min_samples = int(min_duration * sample_rate) | |
| max_samples = int(max_duration * sample_rate) | |
| # If audio is longer than max_duration, truncate | |
| if current_samples > max_samples: | |
| logger.info(f"[_enforce_audio_duration_limits] Truncating audio from {current_duration:.1f}s to {max_duration:.1f}s") | |
| audio = audio[..., :max_samples] | |
| # If audio is shorter than min_duration, repeat to fill | |
| elif current_samples < min_samples: | |
| logger.info(f"[_enforce_audio_duration_limits] Repeating audio from {current_duration:.1f}s to reach {min_duration:.1f}s") | |
| # Calculate how many times to repeat | |
| repeat_times = int(math.ceil(min_samples / current_samples)) | |
| # Repeat along the time dimension | |
| audio = audio.repeat(1, repeat_times) | |
| # Truncate to exactly min_samples | |
| audio = audio[..., :min_samples] | |
| return audio | |
| def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]: | |
| """Normalize audio_code_hints to list of correct length.""" | |
| if audio_code_hints is None: | |
| normalized = [None] * batch_size | |
| elif isinstance(audio_code_hints, str): | |
| normalized = [audio_code_hints] * batch_size | |
| elif len(audio_code_hints) == 1 and batch_size > 1: | |
| normalized = audio_code_hints * batch_size | |
| elif len(audio_code_hints) != batch_size: | |
| # Pad or truncate to match batch_size | |
| normalized = list(audio_code_hints[:batch_size]) | |
| while len(normalized) < batch_size: | |
| normalized.append(None) | |
| else: | |
| normalized = list(audio_code_hints) | |
| # Clean up: convert empty strings to None | |
| normalized = [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized] | |
| return normalized | |
| def _normalize_instructions(self, instructions: Optional[Union[str, List[str]]], batch_size: int, default: Optional[str] = None) -> List[str]: | |
| """Normalize instructions to list of correct length.""" | |
| if instructions is None: | |
| default_instruction = default or DEFAULT_DIT_INSTRUCTION | |
| return [default_instruction] * batch_size | |
| elif isinstance(instructions, str): | |
| return [instructions] * batch_size | |
| elif len(instructions) == 1: | |
| return instructions * batch_size | |
| elif len(instructions) != batch_size: | |
| # Pad or truncate to match batch_size | |
| normalized = list(instructions[:batch_size]) | |
| default_instruction = default or DEFAULT_DIT_INSTRUCTION | |
| while len(normalized) < batch_size: | |
| normalized.append(default_instruction) | |
| return normalized | |
| else: | |
| return list(instructions) | |
| def _format_lyrics(self, lyrics: str, language: str) -> str: | |
| """Format lyrics text with language header.""" | |
| return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>" | |
| def _pad_sequences(self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0) -> torch.Tensor: | |
| """Pad sequences to same length.""" | |
| return torch.stack([ | |
| torch.nn.functional.pad(seq, (0, max_length - len(seq)), 'constant', pad_value) | |
| for seq in sequences | |
| ]) | |
| def _extract_caption_and_language(self, metas: List[Union[str, Dict[str, Any]]], captions: List[str], vocal_languages: List[str]) -> Tuple[List[str], List[str]]: | |
| """Extract caption and language from metas with fallback to provided values.""" | |
| actual_captions = list(captions) | |
| actual_languages = list(vocal_languages) | |
| for i, meta in enumerate(metas): | |
| if i >= len(actual_captions): | |
| break | |
| meta_dict = None | |
| if isinstance(meta, str): | |
| parsed = self._parse_metas([meta]) | |
| if parsed and isinstance(parsed[0], dict): | |
| meta_dict = parsed[0] | |
| elif isinstance(meta, dict): | |
| meta_dict = meta | |
| if meta_dict: | |
| if 'caption' in meta_dict and meta_dict['caption']: | |
| actual_captions[i] = str(meta_dict['caption']) | |
| if 'language' in meta_dict and meta_dict['language']: | |
| actual_languages[i] = str(meta_dict['language']) | |
| return actual_captions, actual_languages | |
| def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Encode audio to latents using VAE. | |
| Args: | |
| audio: Audio tensor [channels, samples] or [batch, channels, samples] | |
| Returns: | |
| Latents tensor [T, D] or [batch, T, D] | |
| """ | |
| # Save original dimension info BEFORE modifying audio | |
| input_was_2d = (audio.dim() == 2) | |
| # Ensure batch dimension | |
| if input_was_2d: | |
| audio = audio.unsqueeze(0) | |
| # Ensure input is in VAE's dtype | |
| vae_input = audio.to(self.device).to(self.vae.dtype) | |
| # Encode to latents | |
| with torch.no_grad(): | |
| latents = self.vae.encode(vae_input).latent_dist.sample() | |
| # Cast back to model dtype | |
| latents = latents.to(self.dtype) | |
| # Transpose: [batch, d, T] -> [batch, T, d] | |
| latents = latents.transpose(1, 2) | |
| # Remove batch dimension if input didn't have it | |
| if input_was_2d: | |
| latents = latents.squeeze(0) | |
| return latents | |
| def _build_metadata_dict(self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str, duration: Optional[float] = None) -> Dict[str, Any]: | |
| """ | |
| Build metadata dictionary with default values. | |
| Args: | |
| bpm: BPM value (optional) | |
| key_scale: Key/scale string | |
| time_signature: Time signature string | |
| duration: Duration in seconds (optional) | |
| Returns: | |
| Metadata dictionary | |
| """ | |
| metadata_dict = {} | |
| if bpm: | |
| metadata_dict["bpm"] = bpm | |
| else: | |
| metadata_dict["bpm"] = "N/A" | |
| if key_scale.strip(): | |
| metadata_dict["keyscale"] = key_scale | |
| else: | |
| metadata_dict["keyscale"] = "N/A" | |
| if time_signature.strip() and time_signature != "N/A" and time_signature: | |
| metadata_dict["timesignature"] = time_signature | |
| else: | |
| metadata_dict["timesignature"] = "N/A" | |
| # Add duration if provided | |
| if duration is not None: | |
| metadata_dict["duration"] = f"{int(duration)} seconds" | |
| return metadata_dict | |
| def generate_instruction( | |
| self, | |
| task_type: str, | |
| track_name: Optional[str] = None, | |
| complete_track_classes: Optional[List[str]] = None | |
| ) -> str: | |
| if task_type == "text2music": | |
| return TASK_INSTRUCTIONS["text2music"] | |
| elif task_type == "repaint": | |
| return TASK_INSTRUCTIONS["repaint"] | |
| elif task_type == "cover": | |
| return TASK_INSTRUCTIONS["cover"] | |
| elif task_type == "extract": | |
| if track_name: | |
| # Convert to uppercase | |
| track_name_upper = track_name.upper() | |
| return TASK_INSTRUCTIONS["extract"].format(TRACK_NAME=track_name_upper) | |
| else: | |
| return TASK_INSTRUCTIONS["extract_default"] | |
| elif task_type == "lego": | |
| if track_name: | |
| # Convert to uppercase | |
| track_name_upper = track_name.upper() | |
| return TASK_INSTRUCTIONS["lego"].format(TRACK_NAME=track_name_upper) | |
| else: | |
| return TASK_INSTRUCTIONS["lego_default"] | |
| elif task_type == "complete": | |
| if complete_track_classes and len(complete_track_classes) > 0: | |
| # Convert to uppercase and join with " | " | |
| track_classes_upper = [t.upper() for t in complete_track_classes] | |
| complete_track_classes_str = " | ".join(track_classes_upper) | |
| return TASK_INSTRUCTIONS["complete"].format(TRACK_CLASSES=complete_track_classes_str) | |
| else: | |
| return TASK_INSTRUCTIONS["complete_default"] | |
| else: | |
| return TASK_INSTRUCTIONS["text2music"] | |
| def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]: | |
| if audio_file is None: | |
| return None | |
| try: | |
| # Load audio file with fallback backends | |
| audio, sr = self._load_audio_file(audio_file) | |
| logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}") | |
| logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}") | |
| logger.debug(f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / 48000.0} seconds") | |
| # Normalize to stereo 48kHz | |
| audio = self._normalize_audio_to_stereo_48k(audio, sr) | |
| is_silence = self.is_silence(audio) | |
| if is_silence: | |
| return None | |
| # Target length: 30 seconds at 48kHz | |
| target_frames = 30 * 48000 | |
| segment_frames = 10 * 48000 # 10 seconds per segment | |
| # If audio is less than 30 seconds, repeat to at least 30 seconds | |
| if audio.shape[-1] < target_frames: | |
| repeat_times = math.ceil(target_frames / audio.shape[-1]) | |
| audio = audio.repeat(1, repeat_times) | |
| # If audio is greater than or equal to 30 seconds, no operation needed | |
| # For all cases, select random 10-second segments from front, middle, and back | |
| # then concatenate them to form 30 seconds | |
| total_frames = audio.shape[-1] | |
| segment_size = total_frames // 3 | |
| # Front segment: [0, segment_size] | |
| front_start = random.randint(0, max(0, segment_size - segment_frames)) | |
| front_audio = audio[:, front_start:front_start + segment_frames] | |
| # Middle segment: [segment_size, 2*segment_size] | |
| middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames)) | |
| middle_audio = audio[:, middle_start:middle_start + segment_frames] | |
| # Back segment: [2*segment_size, total_frames] | |
| back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames)) | |
| back_audio = audio[:, back_start:back_start + segment_frames] | |
| # Concatenate three segments to form 30 seconds | |
| audio = torch.cat([front_audio, middle_audio, back_audio], dim=-1) | |
| return audio | |
| except Exception as e: | |
| logger.exception("[process_reference_audio] Error processing reference audio") | |
| return None | |
| def process_src_audio(self, audio_file) -> Optional[torch.Tensor]: | |
| if audio_file is None: | |
| return None | |
| try: | |
| # Load audio file with fallback backends | |
| audio, sr = self._load_audio_file(audio_file) | |
| # Normalize to stereo 48kHz | |
| audio = self._normalize_audio_to_stereo_48k(audio, sr) | |
| # Enforce duration limits (10-600 seconds) | |
| audio = self._enforce_audio_duration_limits(audio) | |
| return audio | |
| except Exception as e: | |
| logger.exception("[process_src_audio] Error processing source audio") | |
| return None | |
    def convert_src_audio_to_codes(self, audio_file) -> str:
        """
        Convert uploaded source audio to audio codes string.

        Pipeline: process_src_audio -> VAE encode to latents ->
        model.tokenize -> serialize indices as '<|audio_code_N|>' tokens.

        Args:
            audio_file: Path to audio file or None

        Returns:
            Formatted codes string like '<|audio_code_123|><|audio_code_456|>...'
            on success, or a human-readable '❌ ...' error message string.
        """
        if audio_file is None:
            return "❌ Please upload source audio first"
        if self.model is None or self.vae is None:
            return "❌ Model not initialized. Please initialize the service first."
        try:
            # Process audio file (load, stereo/48kHz, duration limits)
            processed_audio = self.process_src_audio(audio_file)
            if processed_audio is None:
                return "❌ Failed to process audio file"
            # Encode audio to latents using VAE
            with torch.no_grad():
                with self._load_model_context("vae"):
                    # Check if audio is silence
                    if self.is_silence(processed_audio.unsqueeze(0)):
                        return "❌ Audio file appears to be silent"
                    # Encode to latents using helper method
                    latents = self._encode_audio_to_latents(processed_audio)  # [T, d]
                    # Create attention mask for latents (all positions valid)
                    attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
                # Tokenize latents to get code indices
                with self._load_model_context("model"):
                    # Prepare latents for tokenize: [T, d] -> [1, T, d]
                    hidden_states = latents.unsqueeze(0)  # [1, T, d]
                    # Call tokenize method
                    # tokenize returns: (quantized, indices, attention_mask)
                    _, indices, _ = self.model.tokenize(hidden_states, self.silence_latent, attention_mask.unsqueeze(0))
            # Format indices as code string
            # indices shape: [1, T_5Hz] or [1, T_5Hz, num_quantizers]
            # Flatten and convert to list
            indices_flat = indices.flatten().cpu().tolist()
            codes_string = "".join([f"<|audio_code_{idx}|>" for idx in indices_flat])
            logger.info(f"[convert_src_audio_to_codes] Generated {len(indices_flat)} audio codes")
            return codes_string
        except Exception as e:
            error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
            logger.exception("[convert_src_audio_to_codes] Error converting audio to codes")
            return error_msg
| def prepare_batch_data( | |
| self, | |
| actual_batch_size, | |
| processed_src_audio, | |
| audio_duration, | |
| captions, | |
| lyrics, | |
| vocal_language, | |
| instruction, | |
| bpm, | |
| key_scale, | |
| time_signature | |
| ): | |
| pure_caption = self.extract_caption_from_sft_format(captions) | |
| captions_batch = [pure_caption] * actual_batch_size | |
| instructions_batch = [instruction] * actual_batch_size | |
| lyrics_batch = [lyrics] * actual_batch_size | |
| vocal_languages_batch = [vocal_language] * actual_batch_size | |
| # Calculate duration for metadata | |
| calculated_duration = None | |
| if processed_src_audio is not None: | |
| calculated_duration = processed_src_audio.shape[-1] / 48000.0 | |
| elif audio_duration is not None and audio_duration > 0: | |
| calculated_duration = audio_duration | |
| # Build metadata dict - use "N/A" as default for empty fields | |
| metadata_dict = self._build_metadata_dict(bpm, key_scale, time_signature, calculated_duration) | |
| # Format metadata - inference service accepts dict and will convert to string | |
| # Create a copy for each batch item (in case we modify it) | |
| metas_batch = [metadata_dict.copy() for _ in range(actual_batch_size)] | |
| return captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch | |
| def determine_task_type(self, task_type, audio_code_string): | |
| # Determine task type - repaint and lego tasks can have repainting parameters | |
| # Other tasks (cover, text2music, extract, complete) should NOT have repainting | |
| is_repaint_task = (task_type == "repaint") | |
| is_lego_task = (task_type == "lego") | |
| is_cover_task = (task_type == "cover") | |
| has_codes = False | |
| if isinstance(audio_code_string, list): | |
| has_codes = any((c or "").strip() for c in audio_code_string) | |
| else: | |
| has_codes = bool(audio_code_string and str(audio_code_string).strip()) | |
| if has_codes: | |
| is_cover_task = True | |
| # Both repaint and lego tasks can use repainting parameters for chunk mask | |
| can_use_repainting = is_repaint_task or is_lego_task | |
| return is_repaint_task, is_lego_task, is_cover_task, can_use_repainting | |
| def create_target_wavs(self, duration_seconds: float) -> torch.Tensor: | |
| try: | |
| # Ensure minimum precision of 100ms | |
| duration_seconds = max(0.1, round(duration_seconds, 1)) | |
| # Calculate frames for 48kHz stereo | |
| frames = int(duration_seconds * 48000) | |
| # Create silent stereo audio | |
| target_wavs = torch.zeros(2, frames) | |
| return target_wavs | |
| except Exception as e: | |
| logger.exception("[create_target_wavs] Error creating target audio") | |
| # Fallback to 30 seconds if error | |
| return torch.zeros(2, 30 * 48000) | |
| def prepare_padding_info( | |
| self, | |
| actual_batch_size, | |
| processed_src_audio, | |
| audio_duration, | |
| repainting_start, | |
| repainting_end, | |
| is_repaint_task, | |
| is_lego_task, | |
| is_cover_task, | |
| can_use_repainting, | |
| ): | |
| target_wavs_batch = [] | |
| # Store padding info for each batch item to adjust repainting coordinates | |
| padding_info_batch = [] | |
| for i in range(actual_batch_size): | |
| if processed_src_audio is not None: | |
| if is_cover_task: | |
| # Cover task: Use src_audio directly without padding | |
| batch_target_wavs = processed_src_audio | |
| padding_info_batch.append({ | |
| 'left_padding_duration': 0.0, | |
| 'right_padding_duration': 0.0 | |
| }) | |
| elif is_repaint_task or is_lego_task: | |
| # Repaint/lego task: May need padding for outpainting | |
| src_audio_duration = processed_src_audio.shape[-1] / 48000.0 | |
| # Determine actual end time | |
| if repainting_end is None or repainting_end < 0: | |
| actual_end = src_audio_duration | |
| else: | |
| actual_end = repainting_end | |
| left_padding_duration = max(0, -repainting_start) if repainting_start is not None else 0 | |
| right_padding_duration = max(0, actual_end - src_audio_duration) | |
| # Create padded audio | |
| left_padding_frames = int(left_padding_duration * 48000) | |
| right_padding_frames = int(right_padding_duration * 48000) | |
| if left_padding_frames > 0 or right_padding_frames > 0: | |
| # Pad the src audio | |
| batch_target_wavs = torch.nn.functional.pad( | |
| processed_src_audio, | |
| (left_padding_frames, right_padding_frames), | |
| 'constant', 0 | |
| ) | |
| else: | |
| batch_target_wavs = processed_src_audio | |
| # Store padding info for coordinate adjustment | |
| padding_info_batch.append({ | |
| 'left_padding_duration': left_padding_duration, | |
| 'right_padding_duration': right_padding_duration | |
| }) | |
| else: | |
| # Other tasks: Use src_audio directly without padding | |
| batch_target_wavs = processed_src_audio | |
| padding_info_batch.append({ | |
| 'left_padding_duration': 0.0, | |
| 'right_padding_duration': 0.0 | |
| }) | |
| else: | |
| padding_info_batch.append({ | |
| 'left_padding_duration': 0.0, | |
| 'right_padding_duration': 0.0 | |
| }) | |
| if audio_duration is not None and audio_duration > 0: | |
| batch_target_wavs = self.create_target_wavs(audio_duration) | |
| else: | |
| import random | |
| random_duration = random.uniform(10.0, 120.0) | |
| batch_target_wavs = self.create_target_wavs(random_duration) | |
| target_wavs_batch.append(batch_target_wavs) | |
| # Stack target_wavs into batch tensor | |
| # Ensure all tensors have the same shape by padding to max length | |
| max_frames = max(wav.shape[-1] for wav in target_wavs_batch) | |
| padded_target_wavs = [] | |
| for wav in target_wavs_batch: | |
| if wav.shape[-1] < max_frames: | |
| pad_frames = max_frames - wav.shape[-1] | |
| padded_wav = torch.nn.functional.pad(wav, (0, pad_frames), 'constant', 0) | |
| padded_target_wavs.append(padded_wav) | |
| else: | |
| padded_target_wavs.append(wav) | |
| target_wavs_tensor = torch.stack(padded_target_wavs, dim=0) # [batch_size, 2, frames] | |
| if can_use_repainting: | |
| # Repaint task: Set repainting parameters | |
| if repainting_start is None: | |
| repainting_start_batch = None | |
| elif isinstance(repainting_start, (int, float)): | |
| if processed_src_audio is not None: | |
| adjusted_start = repainting_start + padding_info_batch[0]['left_padding_duration'] | |
| repainting_start_batch = [adjusted_start] * actual_batch_size | |
| else: | |
| repainting_start_batch = [repainting_start] * actual_batch_size | |
| else: | |
| # List input - adjust each item | |
| repainting_start_batch = [] | |
| for i in range(actual_batch_size): | |
| if processed_src_audio is not None: | |
| adjusted_start = repainting_start[i] + padding_info_batch[i]['left_padding_duration'] | |
| repainting_start_batch.append(adjusted_start) | |
| else: | |
| repainting_start_batch.append(repainting_start[i]) | |
| # Handle repainting_end - use src audio duration if not specified or negative | |
| if processed_src_audio is not None: | |
| # If src audio is provided, use its duration as default end | |
| src_audio_duration = processed_src_audio.shape[-1] / 48000.0 | |
| if repainting_end is None or repainting_end < 0: | |
| # Use src audio duration (before padding), then adjust for padding | |
| adjusted_end = src_audio_duration + padding_info_batch[0]['left_padding_duration'] | |
| repainting_end_batch = [adjusted_end] * actual_batch_size | |
| else: | |
| # Adjust repainting_end to be relative to padded audio | |
| adjusted_end = repainting_end + padding_info_batch[0]['left_padding_duration'] | |
| repainting_end_batch = [adjusted_end] * actual_batch_size | |
| else: | |
| # No src audio - repainting doesn't make sense without it | |
| if repainting_end is None or repainting_end < 0: | |
| repainting_end_batch = None | |
| elif isinstance(repainting_end, (int, float)): | |
| repainting_end_batch = [repainting_end] * actual_batch_size | |
| else: | |
| # List input - adjust each item | |
| repainting_end_batch = [] | |
| for i in range(actual_batch_size): | |
| if processed_src_audio is not None: | |
| adjusted_end = repainting_end[i] + padding_info_batch[i]['left_padding_duration'] | |
| repainting_end_batch.append(adjusted_end) | |
| else: | |
| repainting_end_batch.append(repainting_end[i]) | |
| else: | |
| # All other tasks (cover, text2music, extract, complete): No repainting | |
| # Only repaint and lego tasks should have repainting parameters | |
| repainting_start_batch = None | |
| repainting_end_batch = None | |
| return repainting_start_batch, repainting_end_batch, target_wavs_tensor | |
| def _prepare_batch( | |
| self, | |
| captions: List[str], | |
| lyrics: List[str], | |
| keys: Optional[List[str]] = None, | |
| target_wavs: Optional[torch.Tensor] = None, | |
| refer_audios: Optional[List[List[torch.Tensor]]] = None, | |
| metas: Optional[List[Union[str, Dict[str, Any]]]] = None, | |
| vocal_languages: Optional[List[str]] = None, | |
| repainting_start: Optional[List[float]] = None, | |
| repainting_end: Optional[List[float]] = None, | |
| instructions: Optional[List[str]] = None, | |
| audio_code_hints: Optional[List[Optional[str]]] = None, | |
| audio_cover_strength: float = 1.0, | |
| ) -> Dict[str, Any]: | |
| """ | |
| Prepare batch data with fallbacks for missing inputs. | |
| Args: | |
| captions: List of text captions (optional, can be empty strings) | |
| lyrics: List of lyrics (optional, can be empty strings) | |
| keys: List of unique identifiers (optional) | |
| target_wavs: Target audio tensors (optional, will use silence if not provided) | |
| refer_audios: Reference audio tensors (optional, will use silence if not provided) | |
| metas: Metadata (optional, will use defaults if not provided) | |
| vocal_languages: Vocal languages (optional, will default to 'en') | |
| Returns: | |
| Batch dictionary ready for model input | |
| """ | |
| batch_size = len(captions) | |
| # Ensure silence_latent is on the correct device for batch preparation | |
| self._ensure_silence_latent_on_device() | |
| # Normalize audio_code_hints to batch list | |
| audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) | |
| # Synchronize CUDA to catch any pending errors from previous operations | |
| if torch.cuda.is_available(): | |
| torch.cuda.synchronize() | |
| for ii, refer_audio_list in enumerate(refer_audios): | |
| if refer_audio_list is None: | |
| continue | |
| if isinstance(refer_audio_list, list): | |
| for idx, refer_audio in enumerate(refer_audio_list): | |
| if refer_audio is not None and isinstance(refer_audio, torch.Tensor): | |
| refer_audio_list[idx] = refer_audio.to(self.device).to(torch.bfloat16) | |
| elif isinstance(refer_audio_list, torch.Tensor): | |
| refer_audios[ii] = refer_audio_list.to(self.device) | |
| if vocal_languages is None: | |
| vocal_languages = self._create_fallback_vocal_languages(batch_size) | |
| # Parse metas with fallbacks | |
| parsed_metas = self._parse_metas(metas) | |
| # Encode target_wavs to get target_latents | |
| with torch.no_grad(): | |
| target_latents_list = [] | |
| latent_lengths = [] | |
| # Use per-item wavs (may be adjusted if audio_code_hints are provided) | |
| target_wavs_list = [target_wavs[i].clone() for i in range(batch_size)] | |
| if target_wavs.device != self.device: | |
| target_wavs = target_wavs.to(self.device) | |
| with self._load_model_context("vae"): | |
| for i in range(batch_size): | |
| code_hint = audio_code_hints[i] | |
| # Prefer decoding from provided audio codes | |
| if code_hint: | |
| logger.info(f"[generate_music] Decoding audio codes for item {i}...") | |
| decoded_latents = self._decode_audio_codes_to_latents(code_hint) | |
| if decoded_latents is not None: | |
| decoded_latents = decoded_latents.squeeze(0) | |
| target_latents_list.append(decoded_latents) | |
| latent_lengths.append(decoded_latents.shape[0]) | |
| # Create a silent wav matching the latent length for downstream scaling | |
| frames_from_codes = max(1, int(decoded_latents.shape[0] * 1920)) | |
| target_wavs_list[i] = torch.zeros(2, frames_from_codes) | |
| continue | |
| # Fallback to VAE encode from audio | |
| current_wav = target_wavs_list[i].to(self.device).unsqueeze(0) | |
| if self.is_silence(current_wav): | |
| expected_latent_length = current_wav.shape[-1] // 1920 | |
| target_latent = self.silence_latent[0, :expected_latent_length, :] | |
| else: | |
| # Encode using helper method | |
| logger.info(f"[generate_music] Encoding target audio to latents for item {i}...") | |
| target_latent = self._encode_audio_to_latents(current_wav.squeeze(0)) # Remove batch dim for helper | |
| target_latents_list.append(target_latent) | |
| latent_lengths.append(target_latent.shape[0]) | |
| # Pad target_wavs to consistent length for outputs | |
| max_target_frames = max(wav.shape[-1] for wav in target_wavs_list) | |
| padded_target_wavs = [] | |
| for wav in target_wavs_list: | |
| if wav.shape[-1] < max_target_frames: | |
| pad_frames = max_target_frames - wav.shape[-1] | |
| wav = torch.nn.functional.pad(wav, (0, pad_frames), "constant", 0) | |
| padded_target_wavs.append(wav) | |
| target_wavs = torch.stack(padded_target_wavs) | |
| wav_lengths = torch.tensor([target_wavs.shape[-1]] * batch_size, dtype=torch.long) | |
| # Pad latents to same length | |
| max_latent_length = max(latent.shape[0] for latent in target_latents_list) | |
| max_latent_length = max(128, max_latent_length) | |
| padded_latents = [] | |
| for latent in target_latents_list: | |
| latent_length = latent.shape[0] | |
| if latent.shape[0] < max_latent_length: | |
| pad_length = max_latent_length - latent.shape[0] | |
| latent = torch.cat([latent, self.silence_latent[0, :pad_length, :]], dim=0) | |
| padded_latents.append(latent) | |
| target_latents = torch.stack(padded_latents) | |
| latent_masks = torch.stack([ | |
| torch.cat([ | |
| torch.ones(l, dtype=torch.long, device=self.device), | |
| torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device) | |
| ]) | |
| for l in latent_lengths | |
| ]) | |
| # Process instructions early so we can use them for task type detection | |
| # Use custom instructions if provided, otherwise use default | |
| instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) | |
| # Generate chunk_masks and spans based on repainting parameters | |
| # Also determine if this is a cover task (target audio provided without repainting) | |
| chunk_masks = [] | |
| spans = [] | |
| is_covers = [] | |
| # Store repainting latent ranges for later use in src_latents creation | |
| repainting_ranges = {} # {batch_idx: (start_latent, end_latent)} | |
| for i in range(batch_size): | |
| has_code_hint = audio_code_hints[i] is not None | |
| # Check if repainting is enabled for this batch item | |
| has_repainting = False | |
| if repainting_start is not None and repainting_end is not None: | |
| start_sec = repainting_start[i] if repainting_start[i] is not None else 0.0 | |
| end_sec = repainting_end[i] | |
| if end_sec is not None and end_sec > start_sec: | |
| # Repainting mode with outpainting support | |
| # The target_wavs may have been padded for outpainting | |
| # Need to calculate the actual position in the padded audio | |
| # Calculate padding (if start < 0, there's left padding) | |
| left_padding_sec = max(0, -start_sec) | |
| # Adjust positions to account for padding | |
| # In the padded audio, the original start is shifted by left_padding | |
| adjusted_start_sec = start_sec + left_padding_sec | |
| adjusted_end_sec = end_sec + left_padding_sec | |
| # Convert seconds to latent frames (audio_frames / 1920 = latent_frames) | |
| start_latent = int(adjusted_start_sec * self.sample_rate // 1920) | |
| end_latent = int(adjusted_end_sec * self.sample_rate // 1920) | |
| # Clamp to valid range | |
| start_latent = max(0, min(start_latent, max_latent_length - 1)) | |
| end_latent = max(start_latent + 1, min(end_latent, max_latent_length)) | |
| # Create mask: False = keep original, True = generate new | |
| mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device) | |
| mask[start_latent:end_latent] = True | |
| chunk_masks.append(mask) | |
| spans.append(("repainting", start_latent, end_latent)) | |
| # Store repainting range for later use | |
| repainting_ranges[i] = (start_latent, end_latent) | |
| has_repainting = True | |
| is_covers.append(False) # Repainting is not cover task | |
| else: | |
| # Full generation (no valid repainting range) | |
| chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device)) | |
| spans.append(("full", 0, max_latent_length)) | |
| # Determine task type from instruction, not from target_wavs | |
| # Only cover task should have is_cover=True | |
| instruction_i = instructions[i] if instructions and i < len(instructions) else "" | |
| instruction_lower = instruction_i.lower() | |
| # Cover task instruction: "Generate audio semantic tokens based on the given conditions:" | |
| is_cover = ("generate audio semantic tokens" in instruction_lower and | |
| "based on the given conditions" in instruction_lower) or has_code_hint | |
| is_covers.append(is_cover) | |
| else: | |
| # Full generation (no repainting parameters) | |
| chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device)) | |
| spans.append(("full", 0, max_latent_length)) | |
| # Determine task type from instruction, not from target_wavs | |
| # Only cover task should have is_cover=True | |
| instruction_i = instructions[i] if instructions and i < len(instructions) else "" | |
| instruction_lower = instruction_i.lower() | |
| # Cover task instruction: "Generate audio semantic tokens based on the given conditions:" | |
| is_cover = ("generate audio semantic tokens" in instruction_lower and | |
| "based on the given conditions" in instruction_lower) or has_code_hint | |
| is_covers.append(is_cover) | |
| chunk_masks = torch.stack(chunk_masks) | |
| is_covers = torch.BoolTensor(is_covers).to(self.device) | |
| # Create src_latents based on task type | |
| # For cover/extract/complete/lego/repaint tasks: src_latents = target_latents.clone() (if target_wavs provided) | |
| # For text2music task: src_latents = silence_latent (if no target_wavs or silence) | |
| # For repaint task: additionally replace inpainting region with silence_latent | |
| src_latents_list = [] | |
| silence_latent_tiled = self.silence_latent[0, :max_latent_length, :] | |
| for i in range(batch_size): | |
| # Check if target_wavs is provided and not silent (for extract/complete/lego/cover/repaint tasks) | |
| has_code_hint = audio_code_hints[i] is not None | |
| has_target_audio = has_code_hint or (target_wavs is not None and target_wavs[i].abs().sum() > 1e-6) | |
| if has_target_audio: | |
| # For tasks that use input audio (cover/extract/complete/lego/repaint) | |
| # Check if this item has repainting | |
| item_has_repainting = (i in repainting_ranges) | |
| if item_has_repainting: | |
| # Repaint task: src_latents = target_latents with inpainting region replaced by silence_latent | |
| # 1. Clone target_latents (encoded from src audio, preserving original audio) | |
| src_latent = target_latents[i].clone() | |
| # 2. Replace inpainting region with silence_latent | |
| start_latent, end_latent = repainting_ranges[i] | |
| src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent] | |
| src_latents_list.append(src_latent) | |
| else: | |
| # Cover/extract/complete/lego tasks: src_latents = target_latents.clone() | |
| # All these tasks need to base on input audio | |
| src_latents_list.append(target_latents[i].clone()) | |
| else: | |
| # Text2music task: src_latents = silence_latent (no input audio) | |
| # Use silence_latent for the full length | |
| src_latents_list.append(silence_latent_tiled.clone()) | |
| src_latents = torch.stack(src_latents_list) | |
| # Process audio_code_hints to generate precomputed_lm_hints_25Hz | |
| precomputed_lm_hints_25Hz_list = [] | |
| for i in range(batch_size): | |
| if audio_code_hints[i] is not None: | |
| # Decode audio codes to 25Hz latents | |
| logger.info(f"[generate_music] Decoding audio codes for LM hints for item {i}...") | |
| hints = self._decode_audio_codes_to_latents(audio_code_hints[i]) | |
| if hints is not None: | |
| # Pad or crop to match max_latent_length | |
| if hints.shape[1] < max_latent_length: | |
| pad_length = max_latent_length - hints.shape[1] | |
| pad = self.silence_latent | |
| # Match dims: hints is usually [1, T, D], silence_latent is [1, T, D] | |
| if pad.dim() == 2: | |
| pad = pad.unsqueeze(0) | |
| if hints.dim() == 2: | |
| hints = hints.unsqueeze(0) | |
| pad_chunk = pad[:, :pad_length, :] | |
| if pad_chunk.device != hints.device or pad_chunk.dtype != hints.dtype: | |
| pad_chunk = pad_chunk.to(device=hints.device, dtype=hints.dtype) | |
| hints = torch.cat([hints, pad_chunk], dim=1) | |
| elif hints.shape[1] > max_latent_length: | |
| hints = hints[:, :max_latent_length, :] | |
| precomputed_lm_hints_25Hz_list.append(hints[0]) # Remove batch dimension | |
| else: | |
| precomputed_lm_hints_25Hz_list.append(None) | |
| else: | |
| precomputed_lm_hints_25Hz_list.append(None) | |
| # Stack precomputed hints if any exist, otherwise set to None | |
| if any(h is not None for h in precomputed_lm_hints_25Hz_list): | |
| # For items without hints, use silence_latent as placeholder | |
| precomputed_lm_hints_25Hz = torch.stack([ | |
| h if h is not None else silence_latent_tiled | |
| for h in precomputed_lm_hints_25Hz_list | |
| ]) | |
| else: | |
| precomputed_lm_hints_25Hz = None | |
| # Extract caption and language from metas if available (from LM CoT output) | |
| # Fallback to user-provided values if not in metas | |
| actual_captions, actual_languages = self._extract_caption_and_language(parsed_metas, captions, vocal_languages) | |
| # Format text_inputs | |
| text_inputs = [] | |
| text_token_idss = [] | |
| text_attention_masks = [] | |
| lyric_token_idss = [] | |
| lyric_attention_masks = [] | |
| for i in range(batch_size): | |
| # Use custom instruction for this batch item | |
| instruction = self._format_instruction(instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION) | |
| actual_caption = actual_captions[i] | |
| actual_language = actual_languages[i] | |
| # Format text prompt with custom instruction (using LM-generated caption if available) | |
| text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i]) | |
| # Tokenize text | |
| text_inputs_dict = self.text_tokenizer( | |
| text_prompt, | |
| padding="longest", | |
| truncation=True, | |
| max_length=256, | |
| return_tensors="pt", | |
| ) | |
| text_token_ids = text_inputs_dict.input_ids[0] | |
| text_attention_mask = text_inputs_dict.attention_mask[0].bool() | |
| # Format and tokenize lyrics (using LM-generated language if available) | |
| lyrics_text = self._format_lyrics(lyrics[i], actual_language) | |
| lyrics_inputs_dict = self.text_tokenizer( | |
| lyrics_text, | |
| padding="longest", | |
| truncation=True, | |
| max_length=2048, | |
| return_tensors="pt", | |
| ) | |
| lyric_token_ids = lyrics_inputs_dict.input_ids[0] | |
| lyric_attention_mask = lyrics_inputs_dict.attention_mask[0].bool() | |
| # Build full text input | |
| text_input = text_prompt + "\n\n" + lyrics_text | |
| text_inputs.append(text_input) | |
| text_token_idss.append(text_token_ids) | |
| text_attention_masks.append(text_attention_mask) | |
| lyric_token_idss.append(lyric_token_ids) | |
| lyric_attention_masks.append(lyric_attention_mask) | |
| # Pad tokenized sequences | |
| max_text_length = max(len(seq) for seq in text_token_idss) | |
| padded_text_token_idss = self._pad_sequences(text_token_idss, max_text_length, self.text_tokenizer.pad_token_id) | |
| padded_text_attention_masks = self._pad_sequences(text_attention_masks, max_text_length, 0) | |
| max_lyric_length = max(len(seq) for seq in lyric_token_idss) | |
| padded_lyric_token_idss = self._pad_sequences(lyric_token_idss, max_lyric_length, self.text_tokenizer.pad_token_id) | |
| padded_lyric_attention_masks = self._pad_sequences(lyric_attention_masks, max_lyric_length, 0) | |
| padded_non_cover_text_input_ids = None | |
| padded_non_cover_text_attention_masks = None | |
| if audio_cover_strength < 1.0: | |
| non_cover_text_input_ids = [] | |
| non_cover_text_attention_masks = [] | |
| for i in range(batch_size): | |
| # Use custom instruction for this batch item | |
| instruction = self._format_instruction(DEFAULT_DIT_INSTRUCTION) | |
| # Extract caption from metas if available (from LM CoT output) | |
| actual_caption = actual_captions[i] | |
| # Format text prompt with custom instruction (using LM-generated caption if available) | |
| text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i]) | |
| # Tokenize text | |
| text_inputs_dict = self.text_tokenizer( | |
| text_prompt, | |
| padding="longest", | |
| truncation=True, | |
| max_length=256, | |
| return_tensors="pt", | |
| ) | |
| text_token_ids = text_inputs_dict.input_ids[0] | |
| non_cover_text_attention_mask = text_inputs_dict.attention_mask[0].bool() | |
| non_cover_text_input_ids.append(text_token_ids) | |
| non_cover_text_attention_masks.append(non_cover_text_attention_mask) | |
| padded_non_cover_text_input_ids = self._pad_sequences(non_cover_text_input_ids, max_text_length, self.text_tokenizer.pad_token_id) | |
| padded_non_cover_text_attention_masks = self._pad_sequences(non_cover_text_attention_masks, max_text_length, 0) | |
| if audio_cover_strength < 1.0: | |
| assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None" | |
| assert padded_non_cover_text_attention_masks is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_attention_masks must not be None" | |
| # Prepare batch | |
| batch = { | |
| "keys": keys, | |
| "target_wavs": target_wavs.to(self.device), | |
| "refer_audioss": refer_audios, | |
| "wav_lengths": wav_lengths.to(self.device), | |
| "captions": captions, | |
| "lyrics": lyrics, | |
| "metas": parsed_metas, | |
| "vocal_languages": vocal_languages, | |
| "target_latents": target_latents, | |
| "src_latents": src_latents, | |
| "latent_masks": latent_masks, | |
| "chunk_masks": chunk_masks, | |
| "spans": spans, | |
| "text_inputs": text_inputs, | |
| "text_token_idss": padded_text_token_idss, | |
| "text_attention_masks": padded_text_attention_masks, | |
| "lyric_token_idss": padded_lyric_token_idss, | |
| "lyric_attention_masks": padded_lyric_attention_masks, | |
| "is_covers": is_covers, | |
| "precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz, | |
| "non_cover_text_input_ids": padded_non_cover_text_input_ids, | |
| "non_cover_text_attention_masks": padded_non_cover_text_attention_masks, | |
| } | |
| # to device | |
| for k, v in batch.items(): | |
| if isinstance(v, torch.Tensor): | |
| batch[k] = v.to(self.device) | |
| if torch.is_floating_point(v): | |
| batch[k] = v.to(self.dtype) | |
| return batch | |
| def infer_refer_latent(self, refer_audioss): | |
| refer_audio_order_mask = [] | |
| refer_audio_latents = [] | |
| # Ensure silence_latent is on the correct device | |
| self._ensure_silence_latent_on_device() | |
| def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor: | |
| """Normalize audio tensor to [2, T] on current device.""" | |
| if not isinstance(a, torch.Tensor): | |
| raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}") | |
| # Accept [T], [1, T], [2, T], [1, 2, T] | |
| if a.dim() == 3 and a.shape[0] == 1: | |
| a = a.squeeze(0) | |
| if a.dim() == 1: | |
| a = a.unsqueeze(0) | |
| if a.dim() != 2: | |
| raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}") | |
| if a.shape[0] == 1: | |
| a = torch.cat([a, a], dim=0) | |
| a = a[:2] | |
| return a | |
| def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor: | |
| """Ensure latent is [N, T, D] (3D) for packing.""" | |
| if z.dim() == 4 and z.shape[0] == 1: | |
| z = z.squeeze(0) | |
| if z.dim() == 2: | |
| z = z.unsqueeze(0) | |
| return z | |
| for batch_idx, refer_audios in enumerate(refer_audioss): | |
| if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0): | |
| refer_audio_latent = _ensure_latent_3d(self.silence_latent[:, :750, :]) | |
| refer_audio_latents.append(refer_audio_latent) | |
| refer_audio_order_mask.append(batch_idx) | |
| else: | |
| for refer_audio in refer_audios: | |
| refer_audio = _normalize_audio_2d(refer_audio) | |
| # Ensure input is in VAE's dtype | |
| vae_input = refer_audio.unsqueeze(0).to(self.vae.dtype) | |
| refer_audio_latent = self.vae.encode(vae_input).latent_dist.sample() | |
| # Cast back to model dtype | |
| refer_audio_latent = refer_audio_latent.to(self.dtype) | |
| refer_audio_latents.append(_ensure_latent_3d(refer_audio_latent.transpose(1, 2))) | |
| refer_audio_order_mask.append(batch_idx) | |
| refer_audio_latents = torch.cat(refer_audio_latents, dim=0) | |
| refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=self.device, dtype=torch.long) | |
| return refer_audio_latents, refer_audio_order_mask | |
| def infer_text_embeddings(self, text_token_idss): | |
| with torch.no_grad(): | |
| text_embeddings = self.text_encoder(input_ids=text_token_idss, lyric_attention_mask=None).last_hidden_state | |
| return text_embeddings | |
| def infer_lyric_embeddings(self, lyric_token_ids): | |
| with torch.no_grad(): | |
| lyric_embeddings = self.text_encoder.embed_tokens(lyric_token_ids) | |
| return lyric_embeddings | |
    def preprocess_batch(self, batch):
        """
        Unpack a prepared batch and run the encoder passes.

        Pulls tensors out of the dict produced by `_prepare_batch`, encodes
        reference audio through the VAE, embeds prompt/lyric tokens with the
        text encoder, and returns everything as a flat tuple in the order the
        downstream generation call expects.

        Args:
            batch: Dict produced by `_prepare_batch` (latents, masks, token
                ids, reference audio, task flags).

        Returns:
            A 19-tuple: (keys, text_inputs, src_latents, target_latents,
            text_hidden_states, text_attention_mask, lyric_hidden_states,
            lyric_attention_mask, audio_attention_mask,
            refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask,
            chunk_mask, spans, is_covers, audio_codes, lyric_token_idss,
            precomputed_lm_hints_25Hz, non_cover_text_hidden_states,
            non_cover_text_attention_masks).
        """
        # step 1: VAE encode latents, target_latents: N x T x d
        # target_latents: N x T x d
        target_latents = batch["target_latents"]
        src_latents = batch["src_latents"]
        attention_mask = batch["latent_masks"]
        audio_codes = batch.get("audio_codes", None)
        # The audio attention mask is simply the latent validity mask.
        audio_attention_mask = attention_mask
        dtype = target_latents.dtype
        # bs is currently unused; kept for readability/debugging.
        bs = target_latents.shape[0]
        device = target_latents.device
        # step 2: refer_audio timbre
        keys = batch["keys"]
        # VAE must be resident while encoding reference audio.
        with self._load_model_context("vae"):
            refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = self.infer_refer_latent(batch["refer_audioss"])
        if refer_audio_acoustic_hidden_states_packed.dtype != dtype:
            refer_audio_acoustic_hidden_states_packed = refer_audio_acoustic_hidden_states_packed.to(dtype)
        # step 4: chunk mask, N x T x d — broadcast the per-frame bool mask
        # across the latent feature dimension.
        chunk_mask = batch["chunk_masks"]
        chunk_mask = chunk_mask.to(device).unsqueeze(-1).repeat(1, 1, target_latents.shape[2])
        spans = batch["spans"]
        text_token_idss = batch["text_token_idss"]
        text_attention_mask = batch["text_attention_masks"]
        lyric_token_idss = batch["lyric_token_idss"]
        lyric_attention_mask = batch["lyric_attention_masks"]
        text_inputs = batch["text_inputs"]
        logger.info("[preprocess_batch] Inferring prompt embeddings...")
        # Text encoder must be resident for all embedding passes below.
        with self._load_model_context("text_encoder"):
            text_hidden_states = self.infer_text_embeddings(text_token_idss)
            logger.info("[preprocess_batch] Inferring lyric embeddings...")
            lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
            is_covers = batch["is_covers"]
            # Get precomputed hints from batch if available
            precomputed_lm_hints_25Hz = batch.get("precomputed_lm_hints_25Hz", None)
            # Get non-cover text input ids and attention masks from batch if available
            non_cover_text_input_ids = batch.get("non_cover_text_input_ids", None)
            non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
            non_cover_text_hidden_states = None
            if non_cover_text_input_ids is not None:
                logger.info("[preprocess_batch] Inferring non-cover text embeddings...")
                non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
        return (
            keys,
            text_inputs,
            src_latents,
            target_latents,
            # model inputs
            text_hidden_states,
            text_attention_mask,
            lyric_hidden_states,
            lyric_attention_mask,
            audio_attention_mask,
            refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask,
            chunk_mask,
            spans,
            is_covers,
            audio_codes,
            lyric_token_idss,
            precomputed_lm_hints_25Hz,
            non_cover_text_hidden_states,
            non_cover_text_attention_masks,
        )
| def service_generate( | |
| self, | |
| captions: Union[str, List[str]], | |
| lyrics: Union[str, List[str]], | |
| keys: Optional[Union[str, List[str]]] = None, | |
| target_wavs: Optional[torch.Tensor] = None, | |
| refer_audios: Optional[List[List[torch.Tensor]]] = None, | |
| metas: Optional[Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]]] = None, | |
| vocal_languages: Optional[Union[str, List[str]]] = None, | |
| infer_steps: int = 60, | |
| guidance_scale: float = 7.0, | |
| seed: Optional[Union[int, List[int]]] = None, | |
| return_intermediate: bool = False, | |
| repainting_start: Optional[Union[float, List[float]]] = None, | |
| repainting_end: Optional[Union[float, List[float]]] = None, | |
| instructions: Optional[Union[str, List[str]]] = None, | |
| audio_cover_strength: float = 1.0, | |
| use_adg: bool = False, | |
| cfg_interval_start: float = 0.0, | |
| cfg_interval_end: float = 1.0, | |
| shift: float = 1.0, | |
| audio_code_hints: Optional[Union[str, List[str]]] = None, | |
| infer_method: str = "ode", | |
| timesteps: Optional[List[float]] = None, | |
| ) -> Dict[str, Any]: | |
| """ | |
| Generate music from text inputs. | |
| Args: | |
| captions: Text caption(s) describing the music (optional, can be empty strings) | |
| lyrics: Lyric text(s) (optional, can be empty strings) | |
| keys: Unique identifier(s) (optional) | |
| target_wavs: Target audio tensor(s) for conditioning (optional) | |
| refer_audios: Reference audio tensor(s) for style transfer (optional) | |
| metas: Metadata dict(s) or string(s) (optional) | |
| vocal_languages: Language code(s) for lyrics (optional, defaults to 'en') | |
| infer_steps: Number of inference steps (default: 60) | |
| guidance_scale: Guidance scale for generation (default: 7.0) | |
| seed: Random seed (optional) | |
| return_intermediate: Whether to return intermediate results (default: False) | |
| repainting_start: Start time(s) for repainting region in seconds (optional) | |
| repainting_end: End time(s) for repainting region in seconds (optional) | |
| instructions: Instruction text(s) for generation (optional) | |
| audio_cover_strength: Strength of audio cover mode (default: 1.0) | |
| use_adg: Whether to use ADG (Adaptive Diffusion Guidance) (default: False) | |
| cfg_interval_start: Start of CFG interval (0.0-1.0, default: 0.0) | |
| cfg_interval_end: End of CFG interval (0.0-1.0, default: 1.0) | |
| Returns: | |
| Dictionary containing: | |
| - pred_wavs: Generated audio tensors | |
| - target_wavs: Input target audio (if provided) | |
| - vqvae_recon_wavs: VAE reconstruction of target | |
| - keys: Identifiers used | |
| - text_inputs: Formatted text inputs | |
| - sr: Sample rate | |
| - spans: Generation spans | |
| - time_costs: Timing information | |
| - seed_num: Seed used | |
| """ | |
| if self.config.is_turbo: | |
| # Limit inference steps to maximum 8 | |
| if infer_steps > 8: | |
| logger.warning(f"[service_generate] dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8") | |
| infer_steps = 8 | |
| # CFG parameters are not adjustable for dmd_gan (they will be ignored) | |
| # Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model | |
| # Convert single inputs to lists | |
| if isinstance(captions, str): | |
| captions = [captions] | |
| if isinstance(lyrics, str): | |
| lyrics = [lyrics] | |
| if isinstance(keys, str): | |
| keys = [keys] | |
| if isinstance(vocal_languages, str): | |
| vocal_languages = [vocal_languages] | |
| if isinstance(metas, (str, dict)): | |
| metas = [metas] | |
| # Convert repainting parameters to lists | |
| if isinstance(repainting_start, (int, float)): | |
| repainting_start = [repainting_start] | |
| if isinstance(repainting_end, (int, float)): | |
| repainting_end = [repainting_end] | |
| # Get batch size from captions | |
| batch_size = len(captions) | |
| # Normalize instructions and audio_code_hints to match batch size | |
| instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) if instructions is not None else None | |
| audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) if audio_code_hints is not None else None | |
| # Convert seed to list format | |
| if seed is None: | |
| seed_list = None | |
| elif isinstance(seed, list): | |
| seed_list = seed | |
| # Ensure we have enough seeds for batch size | |
| if len(seed_list) < batch_size: | |
| # Pad with last seed or random seeds | |
| import random | |
| while len(seed_list) < batch_size: | |
| seed_list.append(random.randint(0, 2**32 - 1)) | |
| elif len(seed_list) > batch_size: | |
| # Truncate to batch size | |
| seed_list = seed_list[:batch_size] | |
| else: | |
| # Single seed value - use for all batch items | |
| seed_list = [int(seed)] * batch_size | |
| # Don't set global random seed here - each item will use its own seed | |
| # Prepare batch | |
| batch = self._prepare_batch( | |
| captions=captions, | |
| lyrics=lyrics, | |
| keys=keys, | |
| target_wavs=target_wavs, | |
| refer_audios=refer_audios, | |
| metas=metas, | |
| vocal_languages=vocal_languages, | |
| repainting_start=repainting_start, | |
| repainting_end=repainting_end, | |
| instructions=instructions, | |
| audio_code_hints=audio_code_hints, | |
| audio_cover_strength=audio_cover_strength, | |
| ) | |
| processed_data = self.preprocess_batch(batch) | |
| ( | |
| keys, | |
| text_inputs, | |
| src_latents, | |
| target_latents, | |
| # model inputs | |
| text_hidden_states, | |
| text_attention_mask, | |
| lyric_hidden_states, | |
| lyric_attention_mask, | |
| audio_attention_mask, | |
| refer_audio_acoustic_hidden_states_packed, | |
| refer_audio_order_mask, | |
| chunk_mask, | |
| spans, | |
| is_covers, | |
| audio_codes, | |
| lyric_token_idss, | |
| precomputed_lm_hints_25Hz, | |
| non_cover_text_hidden_states, | |
| non_cover_text_attention_masks, | |
| ) = processed_data | |
| # Set generation parameters | |
| # Use seed_list if available, otherwise generate a single seed | |
| if seed_list is not None: | |
| # Pass seed list to model (will be handled there) | |
| seed_param = seed_list | |
| else: | |
| seed_param = random.randint(0, 2**32 - 1) | |
| # Ensure silence_latent is on the correct device before creating generate_kwargs | |
| self._ensure_silence_latent_on_device() | |
| generate_kwargs = { | |
| "text_hidden_states": text_hidden_states, | |
| "text_attention_mask": text_attention_mask, | |
| "lyric_hidden_states": lyric_hidden_states, | |
| "lyric_attention_mask": lyric_attention_mask, | |
| "refer_audio_acoustic_hidden_states_packed": refer_audio_acoustic_hidden_states_packed, | |
| "refer_audio_order_mask": refer_audio_order_mask, | |
| "src_latents": src_latents, | |
| "chunk_masks": chunk_mask, | |
| "is_covers": is_covers, | |
| "silence_latent": self.silence_latent, | |
| "seed": seed_param, | |
| "non_cover_text_hidden_states": non_cover_text_hidden_states, | |
| "non_cover_text_attention_mask": non_cover_text_attention_masks, | |
| "precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz, | |
| "audio_cover_strength": audio_cover_strength, | |
| "infer_method": infer_method, | |
| "infer_steps": infer_steps, | |
| "diffusion_guidance_sale": guidance_scale, | |
| "use_adg": use_adg, | |
| "cfg_interval_start": cfg_interval_start, | |
| "cfg_interval_end": cfg_interval_end, | |
| "shift": shift, | |
| } | |
| # Add custom timesteps if provided (convert to tensor) | |
| if timesteps is not None: | |
| generate_kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32) | |
| logger.info("[service_generate] Generating audio...") | |
| with self._load_model_context("model"): | |
| # Prepare condition tensors first (for LRC timestamp generation) | |
| encoder_hidden_states, encoder_attention_mask, context_latents = self.model.prepare_condition( | |
| text_hidden_states=text_hidden_states, | |
| text_attention_mask=text_attention_mask, | |
| lyric_hidden_states=lyric_hidden_states, | |
| lyric_attention_mask=lyric_attention_mask, | |
| refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, | |
| refer_audio_order_mask=refer_audio_order_mask, | |
| hidden_states=src_latents, | |
| attention_mask=torch.ones(src_latents.shape[0], src_latents.shape[1], device=src_latents.device, dtype=src_latents.dtype), | |
| silence_latent=self.silence_latent, | |
| src_latents=src_latents, | |
| chunk_masks=chunk_mask, | |
| is_covers=is_covers, | |
| precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, | |
| ) | |
| outputs = self.model.generate_audio(**generate_kwargs) | |
| # Add intermediate information to outputs for extra_outputs | |
| outputs["src_latents"] = src_latents | |
| outputs["target_latents_input"] = target_latents # Input target latents (before generation) | |
| outputs["chunk_masks"] = chunk_mask | |
| outputs["spans"] = spans | |
| outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length | |
| # Add condition tensors for LRC timestamp generation | |
| outputs["encoder_hidden_states"] = encoder_hidden_states | |
| outputs["encoder_attention_mask"] = encoder_attention_mask | |
| outputs["context_latents"] = context_latents | |
| outputs["lyric_token_idss"] = lyric_token_idss | |
| return outputs | |
| def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=False): | |
| """ | |
| Decode latents using tiling to reduce VRAM usage. | |
| Uses overlap-discard strategy to avoid boundary artifacts. | |
| Args: | |
| latents: [Batch, Channels, Length] | |
| chunk_size: Size of latent chunk to process at once | |
| overlap: Overlap size in latent frames | |
| offload_wav_to_cpu: If True, offload decoded wav audio to CPU immediately to save VRAM | |
| """ | |
| B, C, T = latents.shape | |
| # If short enough, decode directly | |
| if T <= chunk_size: | |
| return self.vae.decode(latents).sample | |
| # Calculate stride (core size) | |
| stride = chunk_size - 2 * overlap | |
| if stride <= 0: | |
| raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}") | |
| num_steps = math.ceil(T / stride) | |
| if offload_wav_to_cpu: | |
| # Optimized path: offload wav to CPU immediately to save VRAM | |
| return self._tiled_decode_offload_cpu(latents, B, T, stride, overlap, num_steps) | |
| else: | |
| # Default path: keep everything on GPU | |
| return self._tiled_decode_gpu(latents, B, T, stride, overlap, num_steps) | |
def _tiled_decode_gpu(self, latents, B, T, stride, overlap, num_steps):
    """Tiled VAE decode where every decoded chunk stays on the GPU.

    Args:
        latents: Latent tensor [Batch, Channels, Length].
        B: Batch size (kept for a signature uniform with the CPU-offload helper).
        T: Total number of latent frames.
        stride: Core (non-overlapping) latent frames decoded per step.
        overlap: Extra latent frames added on each side of every window.
        num_steps: Number of decode windows to process.

    Returns:
        Concatenated audio tensor [Batch, Channels, AudioSamples].
    """
    audio_parts = []
    samples_per_frame = None  # audio samples produced per latent frame
    for step in tqdm(range(num_steps), desc="Decoding audio chunks"):
        # Core range in latent frames for this step.
        core_lo = step * stride
        core_hi = min(core_lo + stride, T)
        # Widen by `overlap` on each side, clamped to the valid range.
        win_lo = max(0, core_lo - overlap)
        win_hi = min(T, core_hi + overlap)
        window = latents[:, :, win_lo:win_hi]
        # Decode window -> [Batch, Channels, AudioSamples]
        decoded = self.vae.decode(window).sample
        if samples_per_frame is None:
            # Infer the latent->audio upsampling ratio from the first window.
            samples_per_frame = decoded.shape[-1] / window.shape[-1]
        # Convert the overlap padding (latent frames) to audio samples and trim it.
        head_cut = int(round((core_lo - win_lo) * samples_per_frame))
        tail_cut = int(round((win_hi - core_hi) * samples_per_frame))
        total = decoded.shape[-1]
        stop = total - tail_cut if tail_cut > 0 else total
        audio_parts.append(decoded[:, :, head_cut:stop])
    # Stitch the trimmed cores back together along the sample axis.
    return torch.cat(audio_parts, dim=-1)
def _tiled_decode_offload_cpu(self, latents, B, T, stride, overlap, num_steps):
    """Optimized tiled decode that offloads to CPU immediately to save VRAM.

    Decodes the latents window-by-window on the GPU, trims the overlap
    padding from each decoded chunk, and copies the remaining core audio
    straight into a pre-allocated CPU tensor, so at most one decoded chunk
    lives on the GPU at any time.

    Args:
        latents: Latent tensor [Batch, Channels, Length].
        B: Batch size (used to pre-allocate the CPU output tensor).
        T: Total number of latent frames.
        stride: Core (non-overlapping) latent frames decoded per step.
        overlap: Extra latent frames added on each side of a window.
        num_steps: Number of decode windows (ceil(T / stride)).

    Returns:
        CPU audio tensor [Batch, Channels, AudioSamples], trimmed to the
        number of samples actually written.
    """
    # First pass: decode first chunk to get upsample_factor and audio channels
    first_core_start = 0
    first_core_end = min(stride, T)
    # The first window has no left overlap (win_start clamps at 0).
    first_win_start = 0
    first_win_end = min(T, first_core_end + overlap)
    first_latent_chunk = latents[:, :, first_win_start:first_win_end]
    first_audio_chunk = self.vae.decode(first_latent_chunk).sample
    # Audio samples produced per latent frame, inferred from the first decode.
    upsample_factor = first_audio_chunk.shape[-1] / first_latent_chunk.shape[-1]
    audio_channels = first_audio_chunk.shape[1]
    # Calculate total audio length and pre-allocate CPU tensor
    total_audio_length = int(round(T * upsample_factor))
    final_audio = torch.zeros(B, audio_channels, total_audio_length,
                              dtype=first_audio_chunk.dtype, device='cpu')
    # Process first chunk: trim and copy to CPU
    first_added_end = first_win_end - first_core_end  # right-side overlap, in latent frames
    first_trim_end = int(round(first_added_end * upsample_factor))
    first_audio_len = first_audio_chunk.shape[-1]
    first_end_idx = first_audio_len - first_trim_end if first_trim_end > 0 else first_audio_len
    first_audio_core = first_audio_chunk[:, :, :first_end_idx]
    # Running write cursor into final_audio (in audio samples).
    audio_write_pos = first_audio_core.shape[-1]
    final_audio[:, :, :audio_write_pos] = first_audio_core.cpu()
    # Free GPU memory
    del first_audio_chunk, first_audio_core, first_latent_chunk
    # Process remaining chunks
    for i in tqdm(range(1, num_steps), desc="Decoding audio chunks"):
        # Core range in latents
        core_start = i * stride
        core_end = min(core_start + stride, T)
        # Window range (with overlap)
        win_start = max(0, core_start - overlap)
        win_end = min(T, core_end + overlap)
        # Extract chunk
        latent_chunk = latents[:, :, win_start:win_end]
        # Decode on GPU
        # [Batch, Channels, AudioSamples]
        audio_chunk = self.vae.decode(latent_chunk).sample
        # Calculate trim amounts in audio samples
        added_start = core_start - win_start  # latent frames
        trim_start = int(round(added_start * upsample_factor))
        added_end = win_end - core_end  # latent frames
        trim_end = int(round(added_end * upsample_factor))
        # Trim audio
        audio_len = audio_chunk.shape[-1]
        end_idx = audio_len - trim_end if trim_end > 0 else audio_len
        audio_core = audio_chunk[:, :, trim_start:end_idx]
        # Copy to pre-allocated CPU tensor at the current write cursor
        core_len = audio_core.shape[-1]
        final_audio[:, :, audio_write_pos:audio_write_pos + core_len] = audio_core.cpu()
        audio_write_pos += core_len
        # Free GPU memory immediately
        del audio_chunk, audio_core, latent_chunk
    # Trim to actual length (in case of rounding differences)
    final_audio = final_audio[:, :, :audio_write_pos]
    return final_audio
def generate_music(
    self,
    captions: str,
    lyrics: str,
    bpm: Optional[int] = None,
    key_scale: str = "",
    time_signature: str = "",
    vocal_language: str = "en",
    inference_steps: int = 8,
    guidance_scale: float = 7.0,
    use_random_seed: bool = True,
    seed: Optional[Union[str, float, int]] = -1,
    reference_audio=None,
    audio_duration: Optional[float] = None,
    batch_size: Optional[int] = None,
    src_audio=None,
    audio_code_string: Union[str, List[str]] = "",
    repainting_start: float = 0.0,
    repainting_end: Optional[float] = None,
    instruction: str = DEFAULT_DIT_INSTRUCTION,
    audio_cover_strength: float = 1.0,
    task_type: str = "text2music",
    use_adg: bool = False,
    cfg_interval_start: float = 0.0,
    cfg_interval_end: float = 1.0,
    shift: float = 1.0,
    infer_method: str = "ode",
    use_tiled_decode: bool = True,
    timesteps: Optional[List[float]] = None,
    progress=None
) -> Dict[str, Any]:
    """
    Main interface for music generation.

    Args:
        captions: Text description of the desired music (may be empty).
        lyrics: Lyrics text (may be empty).
        bpm / key_scale / time_signature / vocal_language: Musical metadata.
        inference_steps / guidance_scale / shift / infer_method / use_adg /
            cfg_interval_start / cfg_interval_end / timesteps: Diffusion
            sampling parameters, forwarded to service_generate.
        use_random_seed / seed: Seed handling, resolved via prepare_seeds.
        reference_audio / src_audio / audio_code_string: Optional audio
            conditioning inputs; audio codes take precedence over src_audio.
        audio_duration: Target duration in seconds; values <= 0 mean "auto".
        batch_size: Number of samples to generate (defaults to self.batch_size).
        repainting_start / repainting_end: Repainting window in seconds;
            a negative repainting_end means "until the end".
        task_type: "text2music" (auto-switches to "cover" when audio codes
            are supplied), "cover", "repaint", etc.
        use_tiled_decode: Use tiled VAE decoding to reduce VRAM usage.
        progress: Optional progress callback (e.g. gradio Progress).

    Returns:
        Dictionary containing:
        - audios: List of dicts, each with "tensor" (torch.Tensor
          [channels, samples], CPU, float32) and "sample_rate"
        - status_message: Human-readable status message
        - extra_outputs: Dict with latents, masks, spans, time_costs,
          seed_value, and cached condition tensors for LRC timestamps
        - success: Whether generation completed successfully
        - error: Error message if generation failed, else None
    """
    # Fall back to a no-op callback so progress(...) can be called
    # unconditionally below (the original code guarded inconsistently).
    if progress is None:
        def progress(*args, **kwargs):
            pass
    if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
        return {
            "audios": [],
            "status_message": "❌ Model not fully initialized. Please initialize all components first.",
            "extra_outputs": {},
            "success": False,
            "error": "Model not fully initialized",
        }

    def _has_audio_codes(v: Union[str, List[str]]) -> bool:
        # True when a non-blank code string is present (any entry for lists).
        if isinstance(v, list):
            return any((x or "").strip() for x in v)
        return bool(v and str(v).strip())

    # Auto-detect task type based on audio_code_string:
    # if audio codes are provided, switch text2music to the cover task.
    if task_type == "text2music":
        if _has_audio_codes(audio_code_string):
            task_type = "cover"
            # Update instruction for cover task
            instruction = TASK_INSTRUCTIONS["cover"]
    logger.info("[generate_music] Starting generation...")
    progress(0.51, desc="Preparing inputs...")
    logger.info("[generate_music] Preparing inputs...")
    # Reset offload cost accounting for this run
    self.current_offload_cost = 0.0
    # Caption and lyrics are optional - can be empty
    # Use provided batch_size or the handler default, clamped to at least 1
    actual_batch_size = batch_size if batch_size is not None else self.batch_size
    actual_batch_size = max(1, actual_batch_size)
    actual_seed_list, seed_value_for_ui = self.prepare_seeds(actual_batch_size, seed, use_random_seed)
    # Non-positive special values mean "not specified"
    if audio_duration is not None and audio_duration <= 0:
        audio_duration = None
    if repainting_end is not None and repainting_end < 0:
        repainting_end = None
    try:
        # 1. Process reference audio
        refer_audios = None
        if reference_audio is not None:
            logger.info("[generate_music] Processing reference audio...")
            processed_ref_audio = self.process_reference_audio(reference_audio)
            if processed_ref_audio is not None:
                # Service expects List[List[torch.Tensor]]: one list of
                # reference audios per batch item.
                refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)]
            else:
                # Fall back to 30 seconds of stereo silence per batch item
                refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)]
        # 2. Process source audio.
        # If audio codes are provided, they take precedence over src_audio.
        processed_src_audio = None
        if src_audio is not None:
            if _has_audio_codes(audio_code_string):
                logger.info("[generate_music] Audio codes provided, ignoring src_audio and using codes instead")
            else:
                logger.info("[generate_music] Processing source audio...")
                processed_src_audio = self.process_src_audio(src_audio)
        # 3. Prepare per-batch-item inputs
        captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data(
            actual_batch_size,
            processed_src_audio,
            audio_duration,
            captions,
            lyrics,
            vocal_language,
            instruction,
            bpm,
            key_scale,
            time_signature
        )
        is_repaint_task, is_lego_task, is_cover_task, can_use_repainting = self.determine_task_type(task_type, audio_code_string)
        repainting_start_batch, repainting_end_batch, target_wavs_tensor = self.prepare_padding_info(
            actual_batch_size,
            processed_src_audio,
            audio_duration,
            repainting_start,
            repainting_end,
            is_repaint_task,
            is_lego_task,
            is_cover_task,
            can_use_repainting
        )
        progress(0.52, desc=f"Generating music (batch size: {actual_batch_size})...")
        # Prepare audio_code_hints when codes were supplied; this works for
        # both text2music (auto-switched to cover) and explicit cover tasks.
        audio_code_hints_batch = None
        if _has_audio_codes(audio_code_string):
            if isinstance(audio_code_string, list):
                audio_code_hints_batch = audio_code_string
            else:
                audio_code_hints_batch = [audio_code_string] * actual_batch_size
        should_return_intermediate = (task_type == "text2music")
        outputs = self.service_generate(
            captions=captions_batch,
            lyrics=lyrics_batch,
            metas=metas_batch,  # Pass as dict, service will convert to string
            vocal_languages=vocal_languages_batch,
            refer_audios=refer_audios,  # Already in List[List[torch.Tensor]] format
            target_wavs=target_wavs_tensor,  # Shape: [batch_size, 2, frames]
            infer_steps=inference_steps,
            guidance_scale=guidance_scale,
            seed=actual_seed_list,  # List of seeds, one per batch item
            repainting_start=repainting_start_batch,
            repainting_end=repainting_end_batch,
            instructions=instructions_batch,
            audio_cover_strength=audio_cover_strength,
            use_adg=use_adg,
            cfg_interval_start=cfg_interval_start,
            cfg_interval_end=cfg_interval_end,
            shift=shift,
            infer_method=infer_method,  # "ode" or "sde"
            audio_code_hints=audio_code_hints_batch,
            return_intermediate=should_return_intermediate,
            timesteps=timesteps,  # Custom timesteps if provided
        )
        logger.info("[generate_music] Model generation completed. Decoding latents...")
        pred_latents = outputs["target_latents"]  # [batch, latent_length, latent_dim]
        time_costs = outputs["time_costs"]
        time_costs["offload_time_cost"] = self.current_offload_cost
        logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
        logger.debug(f"[generate_music] time_costs: {time_costs}")
        progress(0.8, desc="Decoding audio...")
        logger.info("[generate_music] Decoding latents with VAE...")
        # Decode latents to audio
        start_time = time.time()
        with torch.no_grad():
            with self._load_model_context("vae"):
                # Transpose for VAE decode:
                # [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
                pred_latents_for_decode = pred_latents.transpose(1, 2)
                # Ensure input matches the VAE's dtype
                pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
                if use_tiled_decode:
                    logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
                    pred_wavs = self.tiled_decode(pred_latents_for_decode)  # [batch, channels, samples]
                else:
                    pred_wavs = self.vae.decode(pred_latents_for_decode).sample
                # Cast output to float32 for audio processing/saving
                pred_wavs = pred_wavs.to(torch.float32)
        end_time = time.time()
        time_costs["vae_decode_time_cost"] = end_time - start_time
        time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
        # Update offload cost one last time to include VAE offloading
        time_costs["offload_time_cost"] = self.current_offload_cost
        logger.info("[generate_music] VAE decode completed. Preparing audio tensors...")
        progress(0.99, desc="Preparing audio data...")
        # Prepare audio tensors (no file I/O here, no UUID generation);
        # pred_wavs is [batch, channels, samples]: move to CPU as float32.
        audio_tensors = [pred_wavs[i].cpu().float() for i in range(actual_batch_size)]
        status_message = "✅ Generation completed successfully!"
        logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
        # Extract intermediate information from outputs
        src_latents = outputs.get("src_latents")  # [batch, T, D]
        target_latents_input = outputs.get("target_latents_input")  # [batch, T, D]
        chunk_masks = outputs.get("chunk_masks")  # [batch, T]
        spans = outputs.get("spans", [])  # List of tuples
        latent_masks = outputs.get("latent_masks")  # [batch, T]
        # Condition tensors cached for LRC timestamp generation
        encoder_hidden_states = outputs.get("encoder_hidden_states")
        encoder_attention_mask = outputs.get("encoder_attention_mask")
        context_latents = outputs.get("context_latents")
        lyric_token_idss = outputs.get("lyric_token_idss")
        # Move all tensors to CPU to save VRAM (detach releases the graph)
        extra_outputs = {
            "pred_latents": pred_latents.detach().cpu() if pred_latents is not None else None,
            "target_latents": target_latents_input.detach().cpu() if target_latents_input is not None else None,
            "src_latents": src_latents.detach().cpu() if src_latents is not None else None,
            "chunk_masks": chunk_masks.detach().cpu() if chunk_masks is not None else None,
            "latent_masks": latent_masks.detach().cpu() if latent_masks is not None else None,
            "spans": spans,
            "time_costs": time_costs,
            "seed_value": seed_value_for_ui,
            # Condition tensors for LRC timestamp generation
            "encoder_hidden_states": encoder_hidden_states.detach().cpu() if encoder_hidden_states is not None else None,
            "encoder_attention_mask": encoder_attention_mask.detach().cpu() if encoder_attention_mask is not None else None,
            "context_latents": context_latents.detach().cpu() if context_latents is not None else None,
            "lyric_token_idss": lyric_token_idss.detach().cpu() if lyric_token_idss is not None else None,
        }
        # Build audios list with tensor data (file paths/UUIDs handled by caller)
        audios = []
        for audio_tensor in audio_tensors:
            audios.append({
                "tensor": audio_tensor,  # torch.Tensor [channels, samples], CPU, float32
                "sample_rate": self.sample_rate,
            })
        return {
            "audios": audios,
            "status_message": status_message,
            "extra_outputs": extra_outputs,
            "success": True,
            "error": None,
        }
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        logger.exception("[generate_music] Generation failed")
        # Clean up CUDA state after any error (especially important for CUDA errors)
        if torch.cuda.is_available():
            try:
                torch.cuda.synchronize()
            except Exception:
                pass  # Ignore sync errors during cleanup
            torch.cuda.empty_cache()
        return {
            "audios": [],
            "status_message": error_msg,
            "extra_outputs": {},
            "success": False,
            "error": str(e),
        }
def get_lyric_timestamp(
    self,
    pred_latent: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    context_latents: torch.Tensor,
    lyric_token_ids: torch.Tensor,
    total_duration_seconds: float,
    vocal_language: str = "en",
    inference_steps: int = 8,
    seed: int = 42,
    custom_layers_config: Optional[Dict] = None,
) -> Dict[str, Any]:
    """
    Generate lyrics timestamps from generated audio latents using cross-attention alignment.

    This method adds noise to the final pred_latent and re-infers one step to get
    cross-attention matrices, then uses DTW to align lyrics tokens with audio frames.

    Args:
        pred_latent: Generated latent tensor [batch, T, D]
        encoder_hidden_states: Cached encoder hidden states
        encoder_attention_mask: Cached encoder attention mask
        context_latents: Cached context latents
        lyric_token_ids: Tokenized lyrics tensor [batch, seq_len]
        total_duration_seconds: Total audio duration in seconds
        vocal_language: Language code for lyrics header parsing
        inference_steps: Number of inference steps (for noise level calculation)
        seed: Random seed for noise generation (None for non-deterministic noise)
        custom_layers_config: Dict mapping layer indices to head indices

    Returns:
        Dict containing:
        - lrc_text: LRC formatted lyrics with timestamps
        - sentence_timestamps: List of SentenceTimestamp objects
        - token_timestamps: List of TokenTimestamp objects
        - success: Whether generation succeeded
        - error: Error message if failed
    """
    # <|endoftext|> token id marking the end of the lyric token sequence.
    ENDOFTEXT_TOKEN_ID = 151643
    if self.model is None:
        return {
            "lrc_text": "",
            "sentence_timestamps": [],
            "token_timestamps": [],
            "success": False,
            "error": "Model not initialized"
        }
    if custom_layers_config is None:
        custom_layers_config = self.custom_layers_config
    try:
        # Move tensors to device
        device = self.device
        dtype = self.dtype
        pred_latent = pred_latent.to(device=device, dtype=dtype)
        encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype)
        encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype)
        context_latents = context_latents.to(device=device, dtype=dtype)
        bsz = pred_latent.shape[0]
        # Calculate noise level: t_last = 1.0 / inference_steps
        t_last_val = 1.0 / inference_steps
        t_curr_tensor = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype)
        x1 = pred_latent
        # Generate noise (seeded for reproducibility when a seed is given)
        if seed is None:
            x0 = torch.randn_like(x1)
        else:
            generator = torch.Generator(device=device).manual_seed(int(seed))
            x0 = torch.randn(x1.shape, generator=generator, device=device, dtype=dtype)
        # Add noise to pred_latent: xt = t * noise + (1 - t) * x1
        xt = t_last_val * x0 + (1.0 - t_last_val) * x1
        xt_in = xt
        t_in = t_curr_tensor
        # Use the cached conditioning as-is
        encoder_hidden_states_in = encoder_hidden_states
        encoder_attention_mask_in = encoder_attention_mask
        context_latents_in = context_latents
        latent_length = x1.shape[1]
        attention_mask = torch.ones(bsz, latent_length, device=device, dtype=dtype)
        attention_mask_in = attention_mask
        past_key_values = None
        # Run decoder with output_attentions=True to capture cross-attention
        with self._load_model_context("model"):
            decoder = self.model.decoder
            decoder_outputs = decoder(
                hidden_states=xt_in,
                timestep=t_in,
                timestep_r=t_in,
                attention_mask=attention_mask_in,
                encoder_hidden_states=encoder_hidden_states_in,
                use_cache=False,
                past_key_values=past_key_values,
                encoder_attention_mask=encoder_attention_mask_in,
                context_latents=context_latents_in,
                output_attentions=True,
                custom_layers_config=custom_layers_config,
                enable_early_exit=True
            )
        # Extract cross-attention matrices (index 2 of the decoder outputs)
        if decoder_outputs[2] is None:
            return {
                "lrc_text": "",
                "sentence_timestamps": [],
                "token_timestamps": [],
                "success": False,
                "error": "Model did not return attentions"
            }
        cross_attns = decoder_outputs[2]  # Tuple of tensors (some may be None)
        captured_layers_list = []
        for layer_attn in cross_attns:
            # Skip None values (layers that didn't return attention)
            if layer_attn is None:
                continue
            # Only take conditional part (first half of batch)
            cond_attn = layer_attn[:bsz]
            # Transpose to [..., audio_frames, lyric_tokens] orientation
            layer_matrix = cond_attn.transpose(-1, -2)
            captured_layers_list.append(layer_matrix)
        if not captured_layers_list:
            return {
                "lrc_text": "",
                "sentence_timestamps": [],
                "token_timestamps": [],
                "success": False,
                "error": "No valid attention layers returned"
            }
        stacked = torch.stack(captured_layers_list)
        if bsz == 1:
            all_layers_matrix = stacked.squeeze(1)
        else:
            all_layers_matrix = stacked
        # Process lyric token IDs to extract pure lyrics
        if isinstance(lyric_token_ids, torch.Tensor):
            raw_lyric_ids = lyric_token_ids[0].tolist()
        else:
            raw_lyric_ids = lyric_token_ids
        # Parse header to find lyrics start position
        header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n"
        header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False)
        start_idx = len(header_ids)
        # Find end of lyrics (before endoftext token)
        try:
            end_idx = raw_lyric_ids.index(ENDOFTEXT_TOKEN_ID)
        except ValueError:
            end_idx = len(raw_lyric_ids)
        pure_lyric_ids = raw_lyric_ids[start_idx:end_idx]
        # Restrict attention to the pure-lyrics token span
        pure_lyric_matrix = all_layers_matrix[:, :, start_idx:end_idx, :]
        # Create aligner and generate timestamps
        aligner = MusicStampsAligner(self.text_tokenizer)
        align_info = aligner.stamps_align_info(
            attention_matrix=pure_lyric_matrix,
            lyrics_tokens=pure_lyric_ids,
            total_duration_seconds=total_duration_seconds,
            custom_config=custom_layers_config,
            return_matrices=False,
            violence_level=2.0,
            medfilt_width=1,
        )
        if align_info.get("calc_matrix") is None:
            return {
                "lrc_text": "",
                "sentence_timestamps": [],
                "token_timestamps": [],
                "success": False,
                "error": align_info.get("error", "Failed to process attention matrix")
            }
        # Generate timestamps
        result = aligner.get_timestamps_and_lrc(
            calc_matrix=align_info["calc_matrix"],
            lyrics_tokens=pure_lyric_ids,
            total_duration_seconds=total_duration_seconds
        )
        return {
            "lrc_text": result["lrc_text"],
            "sentence_timestamps": result["sentence_timestamps"],
            "token_timestamps": result["token_timestamps"],
            "success": True,
            "error": None
        }
    except Exception as e:
        error_msg = f"Error generating timestamps: {str(e)}"
        logger.exception("[get_lyric_timestamp] Failed")
        return {
            "lrc_text": "",
            "sentence_timestamps": [],
            "token_timestamps": [],
            "success": False,
            "error": error_msg
        }
| def get_lyric_score( | |
| self, | |
| pred_latent: torch.Tensor, | |
| encoder_hidden_states: torch.Tensor, | |
| encoder_attention_mask: torch.Tensor, | |
| context_latents: torch.Tensor, | |
| lyric_token_ids: torch.Tensor, | |
| vocal_language: str = "en", | |
| inference_steps: int = 8, | |
| seed: int = 42, | |
| custom_layers_config: Optional[Dict] = None, | |
| ) -> Dict[str, Any]: | |
| """ | |
| Calculate both LM and DiT alignment scores in one pass. | |
| - lm_score: Checks structural alignment using pure noise at t=1.0. | |
| - dit_score: Checks denoising alignment using regressed latents at t=1/steps. | |
| Args: | |
| pred_latent: Generated latent tensor [batch, T, D] | |
| encoder_hidden_states: Cached encoder hidden states | |
| encoder_attention_mask: Cached encoder attention mask | |
| context_latents: Cached context latents | |
| lyric_token_ids: Tokenized lyrics tensor [batch, seq_len] | |
| vocal_language: Language code for lyrics header parsing | |
| inference_steps: Number of inference steps (for noise level calculation) | |
| seed: Random seed for noise generation | |
| custom_layers_config: Dict mapping layer indices to head indices | |
| Returns: | |
| Dict containing: | |
| - lm_score: float | |
| - dit_score: float | |
| - success: Whether generation succeeded | |
| - error: Error message if failed | |
| """ | |
| from transformers.cache_utils import EncoderDecoderCache, DynamicCache | |
| if self.model is None: | |
| return { | |
| "lm_score": 0.0, | |
| "dit_score": 0.0, | |
| "success": False, | |
| "error": "Model not initialized" | |
| } | |
| if custom_layers_config is None: | |
| custom_layers_config = self.custom_layers_config | |
| try: | |
| # Move tensors to device | |
| device = self.device | |
| dtype = self.dtype | |
| pred_latent = pred_latent.to(device=device, dtype=dtype) | |
| encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype) | |
| encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype) | |
| context_latents = context_latents.to(device=device, dtype=dtype) | |
| bsz = pred_latent.shape[0] | |
| if seed is None: | |
| x0 = torch.randn_like(pred_latent) | |
| else: | |
| generator = torch.Generator(device=device).manual_seed(int(seed)) | |
| x0 = torch.randn(pred_latent.shape, generator=generator, device=device, dtype=dtype) | |
| # --- Input A: LM Score --- | |
| # t = 1.0, xt = Pure Noise | |
| t_lm = torch.tensor([1.0] * bsz, device=device, dtype=dtype) | |
| xt_lm = x0 | |
| # --- Input B: DiT Score --- | |
| # t = 1.0/steps, xt = Regressed Latent | |
| t_last_val = 1.0 / inference_steps | |
| t_dit = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype) | |
| # Flow Matching Regression: xt = t*x0 + (1-t)*x1 | |
| xt_dit = t_last_val * x0 + (1.0 - t_last_val) * pred_latent | |
| # Order: [Think_Batch, DiT_Batch] | |
| xt_in = torch.cat([xt_lm, xt_dit], dim=0) | |
| t_in = torch.cat([t_lm, t_dit], dim=0) | |
| # Duplicate conditions | |
| encoder_hidden_states_in = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0) | |
| encoder_attention_mask_in = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0) | |
| context_latents_in = torch.cat([context_latents, context_latents], dim=0) | |
| # Prepare Attention Mask | |
| latent_length = xt_in.shape[1] | |
| attention_mask_in = torch.ones(2 * bsz, latent_length, device=device, dtype=dtype) | |
| past_key_values = None | |
| # Run decoder with output_attentions=True | |
| with self._load_model_context("model"): | |
| decoder = self.model.decoder | |
| if hasattr(decoder, 'eval'): | |
| decoder.eval() | |
| decoder_outputs = decoder( | |
| hidden_states=xt_in, | |
| timestep=t_in, | |
| timestep_r=t_in, | |
| attention_mask=attention_mask_in, | |
| encoder_hidden_states=encoder_hidden_states_in, | |
| use_cache=False, | |
| past_key_values=past_key_values, | |
| encoder_attention_mask=encoder_attention_mask_in, | |
| context_latents=context_latents_in, | |
| output_attentions=True, | |
| custom_layers_config=custom_layers_config, | |
| enable_early_exit=True | |
| ) | |
| # Extract cross-attention matrices | |
| if decoder_outputs[2] is None: | |
| return { | |
| "lm_score": 0.0, | |
| "dit_score": 0.0, | |
| "success": False, | |
| "error": "Model did not return attentions" | |
| } | |
| cross_attns = decoder_outputs[2] # Tuple of tensors (some may be None) | |
| captured_layers_list = [] | |
| for layer_attn in cross_attns: | |
| if layer_attn is None: | |
| continue | |
| # Only take conditional part (first half of batch) | |
| layer_matrix = layer_attn.transpose(-1, -2) | |
| captured_layers_list.append(layer_matrix) | |
| if not captured_layers_list: | |
| return { | |
| "lm_score": 0.0, | |
| "dit_score": 0.0, | |
| "success": False, | |
| "error": "No valid attention layers returned" | |
| } | |
| stacked = torch.stack(captured_layers_list) | |
| all_layers_matrix_lm = stacked[:, :bsz, ...] | |
| all_layers_matrix_dit = stacked[:, bsz:, ...] | |
| if bsz == 1: | |
| all_layers_matrix_lm = all_layers_matrix_lm.squeeze(1) | |
| all_layers_matrix_dit = all_layers_matrix_dit.squeeze(1) | |
| else: | |
| pass | |
| # Process lyric token IDs to extract pure lyrics | |
| if isinstance(lyric_token_ids, torch.Tensor): | |
| raw_lyric_ids = lyric_token_ids[0].tolist() | |
| else: | |
| raw_lyric_ids = lyric_token_ids | |
| # Parse header to find lyrics start position | |
| header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n" | |
| header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False) | |
| start_idx = len(header_ids) | |
| # Find end of lyrics (before endoftext token) | |
| try: | |
| end_idx = raw_lyric_ids.index(151643) # <|endoftext|> token | |
| except ValueError: | |
| end_idx = len(raw_lyric_ids) | |
| pure_lyric_ids = raw_lyric_ids[start_idx:end_idx] | |
| if start_idx >= all_layers_matrix_lm.shape[-2]: # Check text dim | |
| return { | |
| "lm_score": 0.0, | |
| "dit_score": 0.0, | |
| "success": False, | |
| "error": "Lyrics indices out of bounds" | |
| } | |
| pure_matrix_lm = all_layers_matrix_lm[..., start_idx:end_idx, :] | |
| pure_matrix_dit = all_layers_matrix_dit[..., start_idx:end_idx, :] | |
| # Create aligner and calculate alignment info | |
| aligner = MusicLyricScorer(self.text_tokenizer) | |
| def calculate_single_score(matrix): | |
| """Helper to run aligner on a matrix""" | |
| info = aligner.lyrics_alignment_info( | |
| attention_matrix=matrix, | |
| token_ids=pure_lyric_ids, | |
| custom_config=custom_layers_config, | |
| return_matrices=False, | |
| medfilt_width=1, | |
| ) | |
| if info.get("energy_matrix") is None: | |
| return 0.0 | |
| res = aligner.calculate_score( | |
| energy_matrix=info["energy_matrix"], | |
| type_mask=info["type_mask"], | |
| path_coords=info["path_coords"], | |
| ) | |
| # Return the final score (check return key) | |
| return res.get("lyrics_score", res.get("final_score", 0.0)) | |
| lm_score = calculate_single_score(pure_matrix_lm) | |
| dit_score = calculate_single_score(pure_matrix_dit) | |
| return { | |
| "lm_score": lm_score, | |
| "dit_score": dit_score, | |
| "success": True, | |
| "error": None | |
| } | |
| except Exception as e: | |
| error_msg = f"Error generating score: {str(e)}" | |
| logger.exception("[get_lyric_score] Failed") | |
| return { | |
| "lm_score": 0.0, | |
| "dit_score": 0.0, | |
| "success": False, | |
| "error": error_msg | |
| } |