Spaces:
Running
Running
| # dia/model.py | |
| import os | |
| import logging | |
| import time | |
| import dac # Keep this import name | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file # <<< ADDED Import for safetensors | |
| from .audio import audio_to_codebook, codebook_to_audio | |
| from .config import ( | |
| DiaConfig, | |
| ) # Assuming this is the Pydantic config for model structure | |
| from .layers import DiaModel, KVCache # Assuming these are the nn.Module definitions | |
# Module-level logger; handlers and levels are left to the hosting application.
logger = logging.getLogger(__name__)

# Import-time sanity check of the 'dac' package layout.
# The code below relies on `dac.utils.download()` and `dac.DAC`; if either is
# absent we only warn here, because _load_dac_model will raise later anyway.
# The `and` chain short-circuits, so `dac.utils` is touched only when it exists.
if not (
    hasattr(dac, "utils")
    and hasattr(dac.utils, "download")
    and hasattr(dac, "DAC")
):
    logger.warning(
        "The imported 'dac' module does not appear to have the 'utils.download' structure expected by the original Dia code."
    )
    logger.warning(
        "Ensure 'descript-audio-codec' is installed correctly (pip install descript-audio-codec)."
    )
| def _sample_next_token( | |
| logits_BCxV: torch.Tensor, | |
| temperature: float, | |
| top_p: float, | |
| use_cfg_filter: bool, | |
| cfg_filter_top_k: int | None = None, | |
| ) -> torch.Tensor: | |
| """Samples the next token based on logits, temperature, and top_p.""" | |
| if temperature == 0.0: | |
| # Greedy sampling | |
| return torch.argmax(logits_BCxV, dim=-1) | |
| # Apply temperature scaling | |
| logits_BCxV = logits_BCxV / temperature | |
| # Apply CFG Top-K filtering (optional) | |
| if use_cfg_filter and cfg_filter_top_k is not None: | |
| # Get top K values and indices | |
| _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1) | |
| # Create a mask to keep only top K logits | |
| mask = torch.ones_like(logits_BCxV, dtype=torch.bool) | |
| mask.scatter_( | |
| dim=-1, index=top_k_indices_BCxV, value=False | |
| ) # Set top K positions to False (don't mask) | |
| # Mask out logits not in the top K | |
| logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf) | |
| # Apply Top-P (Nucleus) sampling | |
| if top_p < 1.0: | |
| # Convert logits to probabilities | |
| probs_BCxV = torch.softmax(logits_BCxV, dim=-1) | |
| # Sort probabilities in descending order | |
| sorted_probs_BCxV, sorted_indices_BCxV = torch.sort( | |
| probs_BCxV, dim=-1, descending=True | |
| ) | |
| # Calculate cumulative probabilities | |
| cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1) | |
| # Create mask for tokens to remove (those exceeding top_p threshold) | |
| sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p | |
| # Shift the mask: keep the first token that crosses the threshold | |
| sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[ | |
| ..., :-1 | |
| ].clone() | |
| sorted_indices_to_remove_BCxV[..., 0] = 0 # Always keep the most probable token | |
| # Scatter the mask back to the original order | |
| indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV) | |
| indices_to_remove_BCxV.scatter_( | |
| dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV | |
| ) | |
| # Apply the mask to the logits | |
| logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf) | |
| # Calculate final probabilities after filtering | |
| final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1) | |
| # Sample from the filtered distribution | |
| # multinomial expects probabilities for each item in the batch | |
| sampled_indices_BC = torch.multinomial( | |
| final_probs_BCxV, num_samples=1 | |
| ) # Shape [B*C, 1] | |
| sampled_indices_C = sampled_indices_BC.squeeze( | |
| -1 | |
| ) # Shape [B*C] -> should be [C] if input was [C,V] | |
| return sampled_indices_C | |
| class Dia: | |
| """ | |
| Main class for the Dia Text-to-Speech model, handling loading and generation. | |
| """ | |
def __init__(self, config: DiaConfig, device: torch.device = torch.device("cuda")):
    """
    Build the (weight-less) Dia model structure from a configuration.

    Args:
        config: DiaConfig instance describing the model parameters.
        device: Torch device the model is meant to run on. It is only
            recorded here; the network is instantiated but not moved.
    """
    super().__init__()
    logger.info(
        f"Initializing Dia model structure with config version: {config.version}"
    )

    # Remember the configuration and the intended device. Placement on the
    # device is handled when the weights are loaded, not here.
    self.config, self.target_device = config, device

    # Underlying PyTorch network, constructed from the configuration.
    self.model = DiaModel(config)

    # The DAC codec is loaded separately (see _load_dac_model).
    self.dac_model = None

    logger.info("Dia model structure initialized.")
@classmethod
def load_model_from_files(
    cls,
    config_path: str,
    weights_path: str,
    device: torch.device = torch.device("cuda"),
) -> "Dia":
    """
    Loads the Dia model from local configuration and weights files.
    Handles both .pth and .safetensors weight formats.

    This is an alternate constructor: it is called on the class and returns a
    new instance built with ``cls(config, device)``.

    Args:
        config_path: Path to the configuration JSON file (e.g., 'config.json').
        weights_path: Path to the model weights file (e.g., 'model.pth' or 'model.safetensors').
        device: The torch device ('cuda', 'cpu', etc.) to load the model onto.

    Returns:
        An instance of the Dia model loaded with weights and set to eval mode.

    Raises:
        FileNotFoundError: If the config or weights file is not found.
        ValueError: If the weights file format is unsupported.
        RuntimeError: If there is an error loading the config, weights, or DAC model.
    """
    logger.info("Loading Dia model from local files:")
    logger.info(f" Config: {config_path}")
    logger.info(f" Weights: {weights_path}")
    logger.info(f" Target Device: {device}")

    # 1. Load Configuration
    # Only the load call itself sits inside the try, so that the documented
    # FileNotFoundError is not rewrapped as RuntimeError by the broad handler.
    try:
        config = DiaConfig.load(config_path)
    except FileNotFoundError as e:
        logger.error(f"Configuration file not found at {config_path}")
        raise FileNotFoundError(
            f"Configuration file not found at {config_path}"
        ) from e
    except Exception as e:
        logger.error(
            f"Error loading or validating configuration from {config_path}: {e}",
            exc_info=True,
        )
        raise RuntimeError(
            f"Failed to load configuration from {config_path}"
        ) from e
    if config is None:
        # DiaConfig.load returns None on FileNotFoundError
        logger.error(f"Configuration file not found at {config_path}")
        raise FileNotFoundError(f"Configuration file not found at {config_path}")
    logger.info("Configuration loaded successfully.")

    # 2. Instantiate Model Structure
    # The device is passed mainly so the instance records its target device;
    # the network itself is moved in step 4.
    dia_instance = cls(config, device)

    # 3. Load Weights (State Dictionary)
    logger.info(f"Loading weights from: {weights_path}")
    weights_filename = os.path.basename(weights_path)
    is_safetensors = weights_filename.endswith(".safetensors")
    if not is_safetensors and not weights_filename.endswith(".pth"):
        # Validated outside the try below so the documented ValueError reaches
        # the caller instead of being converted to RuntimeError.
        logger.error(
            f"Unsupported weights file format: {weights_filename}. Expected .pth or .safetensors."
        )
        raise ValueError(f"Unsupported weights file format: {weights_filename}")

    try:
        if is_safetensors:
            logger.info(
                "Detected .safetensors format. Loading using safetensors library."
            )
            # load_file loads directly to the specified device
            state_dict = load_file(weights_path, device=str(device))
            logger.info("Safetensors weights loaded.")
        else:
            logger.info("Detected .pth format. Loading using torch.load.")
            # NOTE(security): torch.load unpickles the file; only load .pth
            # checkpoints from trusted sources (prefer .safetensors otherwise).
            # torch.load needs map_location to load onto the correct device
            state_dict = torch.load(weights_path, map_location=device)
            logger.info("PyTorch weights (.pth) loaded.")

        # strict=True surfaces any key/shape mismatch between file and model.
        logger.info("Applying loaded weights to the model structure...")
        dia_instance.model.load_state_dict(state_dict, strict=True)
        logger.info("Weights applied successfully.")
    except FileNotFoundError as e:
        logger.error(f"Weights file not found at {weights_path}")
        raise FileNotFoundError(f"Weights file not found at {weights_path}") from e
    except Exception as e:
        logger.error(
            f"Error loading weights from {weights_path}: {e}", exc_info=True
        )
        raise RuntimeError(f"Error loading weights from {weights_path}") from e

    # 4. Move Model to Device and Set Eval Mode
    logger.info(f"Moving model to device: {device}...")
    dia_instance.model.to(device)
    logger.info("Setting model to evaluation mode...")
    dia_instance.model.eval()

    # 5. Load Associated DAC Model
    logger.info("Loading associated DAC model...")
    dia_instance._load_dac_model()  # This will log its own progress/errors

    logger.info("Dia model fully loaded and ready.")
    return dia_instance
| # REMOVED from_pretrained - Responsibility moved to engine.py | |
| # @classmethod | |
| # def from_pretrained(...) -> "Dia": ... | |
def _load_dac_model(self):
    """
    Load the Descript Audio Codec (DAC) model and store it on ``self.dac_model``.

    The call is a no-op when a DAC model is already attached. The checkpoint
    path is obtained from ``dac.utils.download()`` and the loaded model is
    moved to ``self.target_device``.

    Raises:
        RuntimeError: If the 'dac' module lacks the expected structure, or if
            downloading/loading the model fails for any reason.
    """
    # Nothing to do when a codec is already attached.
    if self.dac_model is not None:
        logger.info("DAC model already loaded.")
        return

    # The original Dia code depends on dac.utils.download() and dac.DAC.
    # The `and` chain short-circuits, so dac.utils is only read if present.
    has_expected_layout = (
        hasattr(dac, "utils")
        and hasattr(dac.utils, "download")
        and hasattr(dac, "DAC")
    )
    if not has_expected_layout:
        logger.error(
            "Imported 'dac' module structure mismatch. Expected 'dac.utils.download()' and 'dac.DAC'."
        )
        logger.error(
            "Ensure 'descript-audio-codec' is installed correctly via pip."
        )
        raise RuntimeError(
            "Failed to load DAC model: required functions/structure missing from 'dac' module."
        )

    try:
        logger.info("Downloading/finding DAC model using dac.utils.download()...")
        # Caching of the checkpoint is assumed to be handled by dac itself.
        checkpoint_path = dac.utils.download()
        logger.info(f"DAC model path determined: {checkpoint_path}")

        logger.info("Loading DAC model from path...")
        # Keep the codec on the same device as the main Dia model.
        codec = dac.DAC.load(checkpoint_path).to(self.target_device)
        logger.info("DAC model loaded successfully.")
    except AttributeError as ae:
        # Typically indicates an incompatible descript-audio-codec release.
        logger.error(
            f"AttributeError loading DAC model: '{ae}'. The installed 'descript-audio-codec' version might be incompatible with Dia's original code which expects 'dac.utils.download()'."
        )
        logger.error(
            "Please check for specific version requirements of 'descript-audio-codec' for Dia, or potential installation issues."
        )
        raise RuntimeError(
            "Failed to load DAC model due to incompatible library version or structure"
        ) from ae
    except Exception as e:
        logger.error(f"General error loading DAC model: {e}", exc_info=True)
        raise RuntimeError("Failed to load DAC model") from e
    else:
        # Attach only after the whole load sequence succeeded.
        self.dac_model = codec
| def _create_attn_mask( | |
| self, | |
| q_padding_mask_1d: torch.Tensor, | |
| k_padding_mask_1d: torch.Tensor, | |
| is_causal: bool = False, | |
| ) -> torch.Tensor: | |
| """ | |
| Creates the attention mask (self or cross) based on padding masks. | |
| Mimics JAX segment ID logic where attention is allowed between non-padding tokens | |
| OR between padding tokens, but not across the boundary. | |
| Args: | |
| q_padding_mask_1d: Boolean tensor [Batch, SeqLenQ] where True indicates non-padding. | |
| k_padding_mask_1d: Boolean tensor [Batch, SeqLenK] where True indicates non-padding. | |
| is_causal: If True, applies an additional causal mask (for decoder self-attention). | |
| Returns: | |
| Boolean attention mask tensor [Batch, 1, SeqLenQ, SeqLenK] ready for F.scaled_dot_product_attention. | |
| """ | |
| B1, Tq = q_padding_mask_1d.shape | |
| B2, Tk = k_padding_mask_1d.shape | |
| if B1 != B2: | |
| logger.warning( | |
| f"Query ({B1}) and key ({B2}) batch dimensions do not match in _create_attn_mask" | |
| ) | |
| assert B1 == B2, "Query and key batch dimensions must match" | |
| # Expand masks for broadcasting: [B, Tq, 1] and [B, 1, Tk] | |
| p_mask_q = q_padding_mask_1d.unsqueeze(2) | |
| p_mask_k = k_padding_mask_1d.unsqueeze(1) | |
| # True where a non-padding query token attends to a non-padding key token | |
| non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk] | |
| # True where a padding query token attends to a padding key token | |
| pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk] | |
| # Combine: Attention is allowed if tokens are both non-padding OR both padding. | |
| mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk] | |
| if is_causal: | |
| # Apply causal mask for self-attention (query cannot attend to future keys) | |
| if Tq != Tk: | |
| logger.warning(f"Causal mask requested but Tq ({Tq}) != Tk ({Tk})") | |
| assert ( | |
| Tq == Tk | |
| ), "Causal mask requires query and key sequence lengths to be equal" | |
| # Create lower triangular matrix (True allows attention) | |
| causal_mask_2d = torch.tril( | |
| torch.ones((Tq, Tk), dtype=torch.bool, device=self.target_device) | |
| ) | |
| # Combine with padding compatibility mask | |
| mask = mask & causal_mask_2d # Shape [B, Tq, Tk] | |
| # Add head dimension for broadcasting: [B, 1, Tq, Tk] | |
| return mask.unsqueeze(1) | |
def _prepare_text_input(
    self, text: str
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Turn a text prompt into fixed-length encoder inputs.

    The text is UTF-8 encoded, the speaker tags ``[S1]``/``[S2]`` are replaced
    by the single bytes 0x01/0x02, and the byte sequence is truncated or
    right-padded to ``config.data.text_length`` using
    ``config.data.text_pad_value``.

    Args:
        text: The input text string.

    Returns:
        Tuple containing:
        - src_tokens: Padded token IDs [1, SeqLen].
        - src_positions: Position IDs [1, SeqLen].
        - src_padding_mask: Boolean mask (True=non-pad) [1, SeqLen].
        - enc_self_attn_mask: Attention mask for encoder [1, 1, SeqLen, SeqLen].
    """
    data_cfg = self.config.data
    text_pad_value = data_cfg.text_pad_value
    max_len = data_cfg.text_length
    device = self.target_device
    logger.debug(
        f"Preparing text input. Max length: {max_len}, Pad value: {text_pad_value}"
    )
    logger.debug(f"Original text (start): '{text[:100]}...'")

    # Byte-level tokenisation; speaker tags collapse to one byte each.
    encoded = text.encode("utf-8")
    encoded = encoded.replace(b"[S1]", b"\x01")
    encoded = encoded.replace(b"[S2]", b"\x02")
    text_tokens = list(encoded)  # integer byte values
    logger.debug(
        f"Text tokens after byte conversion (first 10): {text_tokens[:10]}"
    )

    # Bring the sequence to exactly max_len entries.
    current_len = len(text_tokens)
    padding_needed = max_len - current_len
    if padding_needed > 0:
        logger.debug(f"Padding text input with {padding_needed} pad tokens.")
        padded_text_np = np.pad(
            text_tokens,
            (0, padding_needed),
            mode="constant",
            constant_values=text_pad_value,
        ).astype(np.uint8)
    else:
        # Either an exact fit or too long; only the latter warrants a warning.
        if current_len > max_len:
            logger.warning(
                f"Input text length ({current_len}) exceeds max length ({max_len}). Truncating."
            )
        text_tokens = text_tokens[:max_len]
        padded_text_np = np.array(text_tokens, dtype=np.uint8)

    # Tensors on the target device, with a leading batch dimension of 1.
    token_tensor = torch.from_numpy(padded_text_np).to(torch.long).to(device)
    src_tokens = token_tensor.unsqueeze(0)
    position_tensor = torch.arange(max_len, device=device).to(torch.long)
    src_positions = position_tensor.unsqueeze(0)

    # True wherever the token differs from the pad value.
    src_padding_mask = src_tokens != text_pad_value  # Shape [1, SeqLen]

    # Non-causal encoder self-attention mask: [1, 1, SeqLen, SeqLen].
    enc_self_attn_mask = self._create_attn_mask(
        src_padding_mask, src_padding_mask, is_causal=False
    )

    logger.debug(f"Prepared src_tokens shape: {src_tokens.shape}")
    logger.debug(f"Prepared src_positions shape: {src_positions.shape}")
    logger.debug(
        f"Prepared src_padding_mask shape: {src_padding_mask.shape} (True means non-padding)"
    )
    logger.debug(f"Prepared enc_self_attn_mask shape: {enc_self_attn_mask.shape}")
    return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask
| def generate( | |
| self, | |
| text: str, | |
| max_tokens: int | None = None, | |
| cfg_scale: float = 3.0, | |
| temperature: float = 1.3, | |
| top_p: float = 0.95, | |
| use_cfg_filter: bool = True, | |
| use_torch_compile: bool = False, # Default to False for broader compatibility | |
| cfg_filter_top_k: int = 35, | |
| audio_prompt_path: str | None = None, | |
| ) -> np.ndarray: | |
| """ | |
| Generates audio waveform from a text prompt, optionally conditioned on an audio prompt. | |
| Args: | |
| text: The input text string. For dialogue, use [S1]/[S2] markers. | |
| For voice cloning, prepend the transcript of the audio prompt. | |
| max_tokens: Maximum number of audio tokens (frames) to generate. Defaults to config value. | |
| cfg_scale: Classifier-Free Guidance scale. Higher values increase adherence to text. | |
| temperature: Sampling temperature. Higher values increase randomness. | |
| top_p: Nucleus sampling probability. Filters vocabulary during sampling. | |
| use_cfg_filter: Whether to apply Top-K filtering based on CFG logits. | |
| use_torch_compile: If True, attempts to compile the decoder step for potential speedup. | |
| cfg_filter_top_k: The 'K' value for CFG Top-K filtering. | |
| audio_prompt_path: Path to an audio file (e.g., WAV, MP3) to use as a voice prompt/clone target. | |
| Returns: | |
| A 1D NumPy array containing the generated audio waveform (float32). | |
| """ | |
| start_time_gen = time.time() | |
| logger.info("Starting audio generation...") | |
| logger.info(f" Text (start): '{text[:100]}...'") | |
| logger.info( | |
| f" Max tokens: {max_tokens if max_tokens is not None else 'Model Default'}" | |
| ) | |
| logger.info(f" CFG Scale: {cfg_scale}") | |
| logger.info(f" Temperature: {temperature}") | |
| logger.info(f" Top P: {top_p}") | |
| logger.info(f" Use CFG Filter: {use_cfg_filter}, Top K: {cfg_filter_top_k}") | |
| logger.info( | |
| f" Audio Prompt: {audio_prompt_path if audio_prompt_path else 'None'}" | |
| ) | |
| logger.info(f" Use torch.compile: {use_torch_compile}") | |
| logger.info(f" Target Device: {self.target_device}") | |
| # --- Parameter Setup --- | |
| num_channels = self.config.data.channels | |
| audio_bos_value = self.config.data.audio_bos_value | |
| audio_eos_value = self.config.data.audio_eos_value | |
| audio_pad_value = self.config.data.audio_pad_value | |
| delay_pattern = self.config.data.delay_pattern | |
| # Use model's default audio length if max_tokens not provided | |
| effective_max_tokens = ( | |
| max_tokens if max_tokens is not None else self.config.data.audio_length | |
| ) | |
| logger.info(f" Effective max_tokens for generation: {effective_max_tokens}") | |
| # Ensure delay pattern is usable | |
| if not isinstance(delay_pattern, list) or not delay_pattern: | |
| logger.warning("Delay pattern is invalid or empty. Using default [0].") | |
| delay_pattern = [ | |
| 0 | |
| ] * num_channels # Fallback, though config should provide default | |
| delay_tensor = torch.tensor( | |
| delay_pattern, dtype=torch.long, device=self.target_device | |
| ) | |
| max_delay_pattern = max(delay_pattern) if delay_pattern else 0 | |
| self.model.eval() # Ensure model is in eval mode | |
| # --- Prepare Conditional and Unconditional Inputs --- | |
| logger.info( | |
| "Preparing text inputs for conditional and unconditional generation..." | |
| ) | |
| ( | |
| cond_src_BxS, | |
| cond_src_positions_BxS, | |
| cond_src_padding_mask_BxS, | |
| cond_enc_self_attn_mask_Bx1xSxS, | |
| ) = self._prepare_text_input(text) | |
| # Create unconditional input (batch of zeros representing padding) | |
| # Assuming pad value 0 for text based on config default | |
| unc_src_BxS = torch.full_like( | |
| cond_src_BxS, fill_value=self.config.data.text_pad_value | |
| ) | |
| # Batch conditional and unconditional inputs together [2, SeqLen] | |
| src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0) | |
| # Expand other inputs to match batch size 2 | |
| src_positions_BxS = cond_src_positions_BxS.expand(2, -1) | |
| src_padding_mask_BxS = torch.cat( | |
| [ | |
| torch.zeros_like(cond_src_padding_mask_BxS[0:1]), | |
| cond_src_padding_mask_BxS, | |
| ], | |
| dim=0, | |
| ) # Uncond mask is all False (padding) | |
| # Encoder mask needs to handle the batched input correctly | |
| # For CFG, typically the unconditional branch attends to nothing useful from text, | |
| # but the structure needs to be maintained. We can reuse the conditional mask structure, | |
| # but the actual attention scores will be based on the zeroed-out unconditional input. | |
| # Alternatively, create a specific mask for the unconditional part if needed. | |
| # Let's expand the conditional mask for simplicity, assuming the model handles zero inputs appropriately. | |
| enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand( | |
| 2, -1, -1, -1 | |
| ) | |
| logger.info("Text inputs prepared (batch size 2 for CFG).") | |
| # --- Encoder Pass --- | |
| logger.info("Running encoder pass...") | |
| start_time_enc = time.time() | |
| # Potentially use autocast for mixed precision if supported and beneficial on device | |
| # Example: with torch.autocast(device_type=self.target_device.type, dtype=torch.bfloat16 if self.target_device.type == 'cuda' else torch.float32): | |
| encoder_out = self.model.encoder( | |
| x_ids=src_BxS, # Shape [2, S] | |
| src_positions=src_positions_BxS, # Shape [2, S] | |
| deterministic=True, # No dropout during inference | |
| attn_mask=enc_self_attn_mask_Bx1xSxS, # Shape [2, 1, S, S] | |
| ) | |
| logger.info( | |
| f"Encoder pass completed in {time.time() - start_time_enc:.3f}s. Output shape: {encoder_out.shape}" | |
| ) # Shape: [2, S, E] | |
| # --- Prepare Decoder Inputs & KV Cache --- | |
| logger.info("Preparing decoder inputs and KV cache...") | |
| start_time_kv = time.time() | |
| # 3-1. Precompute Cross-Attention KV Cache (Static) from encoder output | |
| # This cache is computed once and reused for every decoding step. | |
| decoder_cross_attention_cache: list[KVCache] = ( | |
| self.model.decoder.precompute_cross_attention_kv( | |
| effective_max_tokens, encoder_out, src_positions_BxS | |
| ) | |
| ) | |
| logger.debug( | |
| f"Precomputed cross-attention KV cache for {len(decoder_cross_attention_cache)} layers." | |
| ) | |
| # 3-2. Initialize Self-Attention KV Cache (Dynamic, grows with each step) | |
| decoder_self_attention_cache: list[KVCache] = [] | |
| for i in range(self.model.decoder.num_layers): | |
| decoder_self_attention_cache.append( | |
| KVCache( | |
| self.config.model.decoder.gqa_query_heads, | |
| effective_max_tokens, # Max length the cache can hold | |
| self.config.model.decoder.gqa_head_dim, | |
| self.target_device, # Cache tensors should be on the target device | |
| ) | |
| ) | |
| logger.debug( | |
| f"Initialized self-attention KV cache for {len(decoder_self_attention_cache)} layers." | |
| ) | |
| logger.info( | |
| f"KV cache preparation completed in {time.time() - start_time_kv:.3f}s." | |
| ) | |
| # 3-3. Initialize Decoder Start Tokens (BOS) | |
| # Shape [2, 1, C] (Batch=2 for cond/uncond, T=1 for first step, C=channels) | |
| generated_tokens_history = torch.full( | |
| (2, 1, num_channels), | |
| fill_value=audio_bos_value, | |
| dtype=torch.long, | |
| device=self.target_device, | |
| ) | |
| logger.debug(f"Initial decoder input (BOS): {generated_tokens_history.shape}") | |
| current_step_index = ( | |
| 0 # Index of the step we are currently generating (starts at 0) | |
| ) | |
| prompt_len_inc_bos = 1 # Length of the initial prompt (just BOS initially) | |
| # 3-4. Handle Audio Prompt (Prefill KV Cache) | |
| if audio_prompt_path is not None: | |
| logger.info("Processing audio prompt for prefilling...") | |
| start_time_prompt = time.time() | |
| try: | |
| # Load and potentially resample audio | |
| audio_prompt_waveform, sr = torchaudio.load(audio_prompt_path) | |
| logger.debug( | |
| f"Loaded audio prompt: {audio_prompt_waveform.shape}, Sample Rate: {sr}" | |
| ) | |
| if sr != 44100: | |
| logger.info(f"Resampling audio prompt from {sr}Hz to 44100Hz") | |
| audio_prompt_waveform = torchaudio.functional.resample( | |
| audio_prompt_waveform, sr, 44100 | |
| ) | |
| # Ensure correct shape [B, C, T_audio] and device | |
| # Assuming DAC expects channels first, add batch dim | |
| if audio_prompt_waveform.ndim == 1: # Mono | |
| audio_prompt_waveform = audio_prompt_waveform.unsqueeze( | |
| 0 | |
| ) # Add channel dim | |
| audio_prompt_waveform = audio_prompt_waveform.unsqueeze(0).to( | |
| self.target_device | |
| ) # Add batch dim | |
| # Encode audio prompt to codes using DAC | |
| logger.info("Encoding audio prompt to codes using DAC...") | |
| if self.dac_model is None: | |
| raise RuntimeError( | |
| "DAC model not loaded, required for audio prompt." | |
| ) | |
| # audio_to_codebook returns [B, T_codes, C] | |
| audio_prompt_codes = audio_to_codebook( | |
| self.dac_model, audio_prompt_waveform, data_config=self.config.data | |
| ) # Shape [1, T_codes, C] | |
| logger.info( | |
| f"Encoded audio prompt to codes: {audio_prompt_codes.shape}" | |
| ) | |
| # Concatenate BOS tokens with prompt codes | |
| # Expand prompt codes to batch size 2 (for cond/uncond) | |
| generated_tokens_history = torch.cat( | |
| [generated_tokens_history, audio_prompt_codes.expand(2, -1, -1)], | |
| dim=1, | |
| ) # Shape [2, 1 + T_codes, C] | |
| logger.debug( | |
| f"Decoder input history after prompt concatenation: {generated_tokens_history.shape}" | |
| ) | |
| prefill_len = generated_tokens_history.shape[ | |
| 1 | |
| ] # Length including BOS + prompt | |
| prompt_len_inc_bos = prefill_len | |
| logger.info(f"Prefilling KV cache with length {prefill_len}...") | |
| # Prepare inputs for prefill forward pass | |
| prefill_tgt_pos = ( | |
| torch.arange(prefill_len, device=self.target_device) | |
| .unsqueeze(0) | |
| .expand(2, -1) | |
| ) # Shape [2, T_prefill] | |
| # Padding mask based on actual tokens (BOS and prompt codes are not PAD) | |
| # Shape [2, T_prefill] (True where not PAD) | |
| prefill_tgt_padding_mask = ( | |
| generated_tokens_history != audio_pad_value | |
| ).any(dim=2) | |
| # Create attention masks for prefill | |
| # Shape [2, 1, T_prefill, T_prefill] | |
| prefill_self_attn_mask = self._create_attn_mask( | |
| prefill_tgt_padding_mask, | |
| prefill_tgt_padding_mask, | |
| is_causal=True, | |
| ) | |
| # Shape [2, 1, T_prefill, S] | |
| prefill_cross_attn_mask = self._create_attn_mask( | |
| prefill_tgt_padding_mask, | |
| src_padding_mask_BxS, | |
| is_causal=False, | |
| ) | |
| # Run forward pass through decoder to fill the self-attention KV cache | |
| # We discard the logits from prefill | |
| _ = self.model.decoder.forward( | |
| tgt_ids_BxTxC=generated_tokens_history, # Pass the full history [2, T_prefill, C] | |
| encoder_out=encoder_out, | |
| tgt_positions=prefill_tgt_pos, | |
| src_positions=src_positions_BxS, | |
| deterministic=True, | |
| self_attn_mask=prefill_self_attn_mask, | |
| cross_attn_mask=prefill_cross_attn_mask, | |
| self_attention_cache=decoder_self_attention_cache, # Pass cache to be filled | |
| cross_attention_cache=decoder_cross_attention_cache, # Pass precomputed cache | |
| # prefill=True # Pass prefill flag if decoder layer uses it | |
| ) | |
| # Update the current step index. The next token to generate is at index prefill_len. | |
| current_step_index = prefill_len | |
| logger.info( | |
| f"KV cache prefilled in {time.time() - start_time_prompt:.3f}s. Next step index: {current_step_index}" | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error processing audio prompt: {e}", exc_info=True) | |
| raise RuntimeError("Failed to process audio prompt") from e | |
| # --- Autoregressive Generation Loop --- | |
| logger.info("Starting autoregressive generation loop...") | |
| start_time_loop = time.time() | |
| eos_detected_channel_0 = False | |
| eos_countdown = -1 # Countdown after EOS detected in channel 0 | |
| extra_steps_after_eos = ( | |
| 30 # Generate a few extra steps for delay pattern completion | |
| ) | |
| # Pre-allocate tensor for storing *newly* generated tokens for efficiency | |
| # We already have the prompt in generated_tokens_history | |
| num_steps_to_generate = effective_max_tokens | |
| newly_generated_tokens = torch.full( | |
| (2, num_steps_to_generate, num_channels), | |
| fill_value=audio_pad_value, # Fill with pad initially | |
| dtype=torch.long, | |
| device=self.target_device, | |
| ) | |
| logger.debug( | |
| f"Allocated tensor for newly generated tokens: {newly_generated_tokens.shape}" | |
| ) | |
| # --- Compile decode_step if requested --- | |
| decode_step_fn = self.model.decoder.decode_step | |
| if use_torch_compile: | |
| logger.info("Compiling decoder step function with torch.compile...") | |
| try: | |
| # Experiment with modes: "default", "reduce-overhead", "max-autotune" | |
| decode_step_fn = torch.compile(decode_step_fn, mode="reduce-overhead") | |
| logger.info("Decoder step function compiled.") | |
| except Exception as e: | |
| logger.warning( | |
| f"torch.compile failed: {e}. Using eager mode.", exc_info=True | |
| ) | |
| # --- Prepare static cross-attention mask for single-step decoding --- | |
| # Query mask is always [B, 1] (True, as generated tokens are not PAD) | |
| step_tgt_padding_mask = torch.ones( | |
| (2, 1), dtype=torch.bool, device=self.target_device | |
| ) | |
| # Shape [2, 1, 1, S] | |
| step_decoder_cross_attn_mask = self._create_attn_mask( | |
| step_tgt_padding_mask, | |
| src_padding_mask_BxS, | |
| is_causal=False, | |
| ) | |
| # --- Generation Loop --- | |
| steps_taken = 0 | |
| for step_offset in range(num_steps_to_generate): | |
| # Absolute step index considering prompt length | |
| current_absolute_step = current_step_index + step_offset | |
| # Get the token IDs for the *previous* step to predict the current one | |
| # Shape [2, 1, C] | |
| # If step_offset is 0, use the last token from the prompt history | |
| if step_offset == 0: | |
| input_token_ids = generated_tokens_history[:, -1, :].unsqueeze(1) | |
| else: | |
| # Use the token generated in the previous iteration of this loop | |
| input_token_ids = newly_generated_tokens[ | |
| :, step_offset - 1, : | |
| ].unsqueeze(1) | |
| # Position ID for the current absolute step | |
| # Shape [2, 1] | |
| tgt_pos_Bx1 = torch.full( | |
| (2, 1), | |
| fill_value=current_absolute_step, | |
| dtype=torch.long, | |
| device=self.target_device, | |
| ) | |
| # --- Call Decoder Step --- | |
| # self_attn_mask is None because KV cache handles causality implicitly in single-step decoding | |
| logits_Bx1xCxV, new_self_kv_cache_list = decode_step_fn( | |
| tgt_ids_Bx1xC=input_token_ids, | |
| tgt_pos_Bx1=tgt_pos_Bx1, | |
| encoder_out=encoder_out, | |
| self_attn_mask=None, | |
| cross_attn_mask=step_decoder_cross_attn_mask, | |
| self_attention_cache=decoder_self_attention_cache, | |
| cross_attention_cache=decoder_cross_attention_cache, | |
| ) # Logits shape: [2, 1, C, V] | |
| # --- Update Self-Attention KV Cache --- | |
| for i, layer_cache in enumerate(decoder_self_attention_cache): | |
| if ( | |
| new_self_kv_cache_list | |
| and i < len(new_self_kv_cache_list) | |
| and new_self_kv_cache_list[i] is not None | |
| ): | |
| # new_self_kv_cache_list[i] is a tuple (k_tensor, v_tensor) for the current step | |
| # k_tensor shape: [2, NumHeads, 1, HeadDim] | |
| # v_tensor shape: [2, NumHeads, 1, HeadDim] | |
| layer_cache.update_cache( | |
| new_self_kv_cache_list[i][0], new_self_kv_cache_list[i][1] | |
| ) | |
| else: | |
| logger.warning( | |
| f"Missing KV cache update for layer {i} at step {current_absolute_step}" | |
| ) | |
| # --- Sampling --- | |
| V = self.config.model.tgt_vocab_size | |
| # Get logits for the generated step [2, C, V] | |
| logits_last_BxCxV = logits_Bx1xCxV.squeeze(1) | |
| # Separate conditional and unconditional logits | |
| uncond_logits_CxV = logits_last_BxCxV[0, :, :] # Shape [C, V] | |
| cond_logits_CxV = logits_last_BxCxV[1, :, :] # Shape [C, V] | |
| # Apply Classifier-Free Guidance (CFG) | |
| cfg_logits_CxV = cond_logits_CxV + cfg_scale * ( | |
| cond_logits_CxV - uncond_logits_CxV | |
| ) # Shape [C, V] | |
| # --- Prevent sampling PAD/EOS/BOS tokens inappropriately --- | |
| logits_for_sampling_CxV = ( | |
| cfg_logits_CxV.clone() | |
| ) # Clone to avoid modifying original logits | |
| logits_for_sampling_CxV[:, audio_pad_value] = -torch.inf # Never sample PAD | |
| logits_for_sampling_CxV[:, audio_bos_value] = ( | |
| -torch.inf | |
| ) # Never sample BOS after start | |
| # Mask EOS only once it has been detected in channel 0 and no countdown is running (eos_countdown <= 0); EOS is not masked while the countdown is active | |
| if eos_detected_channel_0 and eos_countdown <= 0: | |
| logits_for_sampling_CxV[:, audio_eos_value] = -torch.inf | |
| # --- Sample the next token for each channel --- | |
| pred_C = _sample_next_token( | |
| logits_for_sampling_CxV.float(), # Ensure float32 for sampling stability | |
| temperature=temperature, | |
| top_p=top_p, | |
| use_cfg_filter=use_cfg_filter, | |
| cfg_filter_top_k=cfg_filter_top_k, | |
| ) # Shape [C] | |
| # --- Handle Delay Pattern (Only if no audio prompt was given) --- | |
| # If there's no prompt, the first few tokens should be BOS according to delay | |
| # generation_step_index is how many tokens generated *after* prompt/initial BOS | |
| generation_step_index = step_offset | |
| if audio_prompt_path is None: | |
| is_before_delay = generation_step_index < delay_tensor # Shape [C] | |
| pred_C = torch.where( | |
| is_before_delay, | |
| torch.tensor( | |
| audio_bos_value, device=self.target_device, dtype=torch.long | |
| ), | |
| pred_C, | |
| ) | |
| # --- Store the predicted token in the newly_generated_tokens tensor --- | |
| newly_generated_tokens[:, step_offset, :] = pred_C.unsqueeze(0).expand( | |
| 2, -1 | |
| ) | |
| steps_taken += 1 # Increment steps taken in this loop | |
| # --- EOS Handling --- | |
| if not eos_detected_channel_0 and pred_C[0] == audio_eos_value: | |
| logger.info( | |
| f"EOS token detected in channel 0 at step {current_absolute_step}. Starting countdown." | |
| ) | |
| eos_detected_channel_0 = True | |
| eos_countdown = extra_steps_after_eos | |
| if eos_countdown > 0: | |
| step_after_eos = extra_steps_after_eos - eos_countdown | |
| logger.debug( | |
| f"EOS countdown: {eos_countdown}, Step after EOS: {step_after_eos}" | |
| ) | |
| # Modify the token *just generated* if needed for EOS/PAD forcing | |
| current_new_tokens = newly_generated_tokens[ | |
| :, step_offset, : | |
| ] # Shape [2, C] | |
| for i, d in enumerate(delay_pattern): | |
| if step_after_eos == d: | |
| logger.debug( | |
| f" Forcing EOS in channel {i} at step {current_absolute_step}" | |
| ) | |
| current_new_tokens[:, i] = audio_eos_value | |
| elif step_after_eos > d: | |
| logger.debug( | |
| f" Forcing PAD in channel {i} at step {current_absolute_step}" | |
| ) | |
| current_new_tokens[:, i] = audio_pad_value | |
| # Put the potentially modified tokens back | |
| newly_generated_tokens[:, step_offset, :] = current_new_tokens | |
| eos_countdown -= 1 | |
| if eos_countdown == 0: | |
| logger.info( | |
| f"EOS countdown finished at step {current_absolute_step}. Stopping generation." | |
| ) | |
| break # Stop generation loop | |
| # Check if we reached the max *new* tokens requested | |
| if steps_taken >= num_steps_to_generate: | |
| logger.info( | |
| f"Reached max generation steps ({num_steps_to_generate}). Stopping." | |
| ) | |
| break | |
| logger.info( | |
| f"Autoregressive loop finished after {steps_taken} steps in {time.time() - start_time_loop:.3f}s." | |
| ) | |
| # --- Extract Generated Codes --- | |
| # Get the conditional generation result (index 1) from the *newly* generated tokens | |
| # Only take the number of steps actually taken | |
| final_new_codes = newly_generated_tokens[ | |
| 1, :steps_taken, : | |
| ] # Shape [T_generated, C] | |
| logger.info(f"Extracted newly generated codes shape: {final_new_codes.shape}") | |
| # --- Convert Codes to Audio using DAC --- | |
| logger.info("Converting generated codes to audio using DAC...") | |
| start_time_decode = time.time() | |
| if self.dac_model is None: | |
| raise RuntimeError("DAC model not loaded, required for audio decoding.") | |
| # codebook_to_audio expects codes shape [C, T] | |
| generated_codes_CxT = final_new_codes.transpose(0, 1) # Shape [C, T_generated] | |
| if generated_codes_CxT.numel() == 0: | |
| logger.warning("No new codes were generated. Returning empty audio.") | |
| return np.array([], dtype=np.float32) | |
| # Call the decoding function (handles delay reversal and DAC decoding) | |
| audio_waveform = codebook_to_audio( | |
| generated_codes_CxT, | |
| self.dac_model, | |
| delay_pattern, | |
| B=1, # Batch size for decoding is 1 | |
| T=generated_codes_CxT.shape[1], # Pass the actual length of generated codes | |
| C=num_channels, | |
| ) # Returns shape [1, T_audio] or [T_audio] | |
| # Ensure output is a 1D numpy array on CPU | |
| final_audio_np = audio_waveform.squeeze().cpu().numpy() | |
| logger.info( | |
| f"Audio decoding completed in {time.time() - start_time_decode:.3f}s. Output shape: {final_audio_np.shape}" | |
| ) | |
| logger.info(f"Total generation time: {time.time() - start_time_gen:.3f}s") | |
| return final_audio_np | |