# Copyright 2025 The ACESTEO Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import time from typing import Callable, List, Optional, Union import torch import torch.nn.functional as F from torch import nn from einops import rearrange # Transformers imports (sorted by submodule, then alphabetically) from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import auto_docstring, can_return_tuple, logging from transformers.models.qwen3.modeling_qwen3 import ( Qwen3MLP, Qwen3RMSNorm, Qwen3RotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward, ) from tqdm import tqdm from vector_quantize_pytorch import ResidualFSQ # Local config import with fallback try: from .configuration_acestep_v15 import AceStepConfig from .apg_guidance import adg_forward, apg_forward, MomentumBuffer except ImportError: from configuration_acestep_v15 import AceStepConfig from apg_guidance import adg_forward, apg_forward, MomentumBuffer logger = logging.get_logger(__name__) def create_4d_mask( seq_len: int, dtype: torch.dtype, device: torch.device, attention_mask: Optional[torch.Tensor] = None, # [Batch, Seq_Len] sliding_window: Optional[int] = None, is_sliding_window: bool = False, is_causal: bool = True, ) -> torch.Tensor: """ General 4D Attention Mask generator compatible with CPU/Mac/SDPA and Eager mode. Supports use cases: 1. Causal Full: is_causal=True, is_sliding_window=False (standard GPT) 2. Causal Sliding: is_causal=True, is_sliding_window=True (Mistral/Qwen local window) 3. Bidirectional Full: is_causal=False, is_sliding_window=False (BERT/Encoder) 4. Bidirectional Sliding: is_causal=False, is_sliding_window=True (Longformer local) Returns: [Batch, 1, Seq_Len, Seq_Len] additive mask (0.0 for keep, -inf for mask) """ # ------------------------------------------------------ # 1. Construct basic geometry mask [Seq_Len, Seq_Len] # ------------------------------------------------------ # Build index matrices # i (Query): [0, 1, ..., L-1] # j (Key): [0, 1, ..., L-1] indices = torch.arange(seq_len, device=device) # diff = i - j diff = indices.unsqueeze(1) - indices.unsqueeze(0) # Initialize all True (all positions visible) valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool) # (A) Handle causality (Causal) if is_causal: # i >= j => diff >= 0 valid_mask = valid_mask & (diff >= 0) # (B) Handle sliding window if is_sliding_window and sliding_window is not None: if is_causal: # Causal sliding: only attend to past window steps # i - j <= window => diff <= window # (diff >= 0 already handled above) valid_mask = valid_mask & (diff <= sliding_window) else: # Bidirectional sliding: attend past and future window steps # |i - j| <= window => abs(diff) <= sliding_window valid_mask = valid_mask & (torch.abs(diff) <= sliding_window) # Expand dimensions to [1, 1, Seq_Len, Seq_Len] for broadcasting valid_mask = valid_mask.unsqueeze(0).unsqueeze(0) # ------------------------------------------------------ # 2. Apply padding mask (Key Masking) # ------------------------------------------------------ if attention_mask is not None: # attention_mask shape: [Batch, Seq_Len] (1=valid, 0=padding) # We want to mask out invalid keys (columns) # Expand shape: [Batch, 1, 1, Seq_Len] padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool) # Broadcasting: Geometry Mask [1, 1, L, L] & Padding Mask [B, 1, 1, L] # Result shape: [B, 1, L, L] valid_mask = valid_mask & padding_mask_4d # ------------------------------------------------------ # 3. Convert to additive mask # ------------------------------------------------------ # Get the minimal value for current dtype min_dtype = torch.finfo(dtype).min # Create result tensor filled with -inf by default mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device) # Set valid positions to 0.0 mask_tensor.masked_fill_(valid_mask, 0.0) return mask_tensor def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor): """ Pack two sequences by concatenating and sorting them based on mask values. Args: hidden1: First hidden states tensor of shape [B, L1, D] hidden2: Second hidden states tensor of shape [B, L2, D] mask1: First mask tensor of shape [B, L1] mask2: Second mask tensor of shape [B, L2] Returns: Tuple of (packed_hidden_states, new_mask) where: - packed_hidden_states: Packed hidden states with valid tokens (mask=1) first, shape [B, L1+L2, D] - new_mask: New mask tensor indicating valid positions, shape [B, L1+L2] """ # Step 1: Concatenate hidden states and masks along sequence dimension hidden_cat = torch.cat([hidden1, hidden2], dim=1) # [B, L, D] mask_cat = torch.cat([mask1, mask2], dim=1) # [B, L] B, L, D = hidden_cat.shape # Step 2: Sort indices so that mask values of 1 come before 0 sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) # [B, L] # Step 3: Reorder hidden states using sorted indices hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D)) # Step 4: Create new mask based on valid sequence lengths lengths = mask_cat.sum(dim=1) # [B] new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1)) return hidden_left, new_mask def sample_t_r(batch_size, device, dtype, data_proportion=0.0, timestep_mu=-0.4, timestep_sigma=1.0, use_meanflow=True): """ Sample timestep t and r for flow matching training. Args: batch_size: Batch size device: Device to create tensors on dtype: Data type for tensors data_proportion: Proportion of data samples (0.0 to 1.0) timestep_mu: Mean for timestep sampling timestep_sigma: Standard deviation for timestep sampling use_meanflow: Whether to use meanflow (if False, data_proportion is set to 1.0) Returns: Tuple of (t, r) tensors, each of shape [batch_size] """ t = torch.sigmoid(torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu) r = torch.sigmoid(torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu) # Assign t = max, r = min, for each pair t, r = torch.maximum(t, r), torch.minimum(t, r) if not use_meanflow: data_proportion = 1.0 data_size = int(batch_size * data_proportion) zero_mask = torch.arange(batch_size, device=device) < data_size r = torch.where(zero_mask, t, r) return t, r class TimestepEmbedding(nn.Module): """ Timestep embedding module for diffusion models. Converts timestep values into high-dimensional embeddings using sinusoidal positional encoding, followed by MLP layers. Used for conditioning diffusion models on timestep information. """ def __init__( self, in_channels: int, time_embed_dim: int, scale: float = 1000, ): super().__init__() self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True) self.act1 = nn.SiLU() self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True) self.in_channels = in_channels self.act2 = nn.SiLU() self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6) self.scale = scale def timestep_embedding(self, t, dim, max_period=10000): """ Create sinusoidal timestep embeddings. Args: t: A 1-D tensor of N indices, one per batch element. These may be fractional. dim: The dimension of the output embeddings. max_period: Controls the minimum frequency of the embeddings. Returns: An (N, D) tensor of positional embeddings. """ t = t * self.scale half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.in_channels) temb = self.linear_1(t_freq.to(t.dtype)) temb = self.act1(temb) temb = self.linear_2(temb) timestep_proj = self.time_proj(self.act2(temb)).unflatten(1, (6, -1)) return temb, timestep_proj class AceStepAttention(nn.Module): """ Multi-headed attention module for AceStep model. Implements the attention mechanism from 'Attention Is All You Need' paper, with support for both self-attention and cross-attention modes. Uses RMSNorm for query and key normalization, and supports sliding window attention for efficient long-sequence processing. """ def __init__(self, config: AceStepConfig, layer_idx: int, is_cross_attention: bool = False, is_causal: bool = False): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout if is_cross_attention: is_causal = False self.is_causal = is_causal self.is_cross_attention = is_cross_attention self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) # Apply RMS normalization only on the head dimension (unlike OLMo) self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] = None, output_attentions: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) # Project and normalize query states query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) # Determine if this is cross-attention (requires encoder_hidden_states) is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None # Cross-attention path: attend to encoder hidden states if is_cross_attention: encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim) if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) # After the first generated token, we can reuse all key/value states from cache curr_past_key_value = past_key_value.cross_attention_cache # Conditions for calculating key and value states if not is_updated: # Compute and cache K/V for the first time key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2) value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) # Update cache: save all key/value states to cache for fast auto-regressive generation key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) # Set flag that this layer's cross-attention cache is updated past_key_value.is_updated[self.layer_idx] = True else: # Reuse cached key/value states for subsequent tokens key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values else: # No cache used, compute K/V directly key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2) value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) # Self-attention path: attend to the same sequence else: # Project and normalize key/value states for self-attention key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) # Apply rotary position embeddings (RoPE) if provided if position_embeddings is not None: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Update cache for auto-regressive generation if past_key_value is not None: # Sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if is_cross_attention and output_attentions: attention_interface: Callable = eager_attention_forward elif self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=self.attention_dropout if self.training else 0.0, scaling=self.scaling, sliding_window=self.sliding_window if not self.is_cross_attention else None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class AceStepEncoderLayer(GradientCheckpointingLayer): """ Encoder layer for AceStep model. Consists of self-attention and MLP (feed-forward) sub-layers with residual connections. """ def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.config = config self.layer_idx = layer_idx # Self-attention sub-layer self.self_attn = AceStepAttention( config=config, layer_idx=layer_idx, is_cross_attention=False, is_causal=False, ) self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # MLP (feed-forward) sub-layer self.mlp = Qwen3MLP(config) self.attention_type = config.layer_types[layer_idx] def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, **kwargs, ) -> tuple[ torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]], ]: # Self-attention with residual connection residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, # Encoders don't use cache use_cache=False, past_key_value=None, **kwargs, ) hidden_states = residual + hidden_states # MLP with residual connection residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs class AceStepDiTLayer(GradientCheckpointingLayer): """ DiT (Diffusion Transformer) layer for AceStep model. Implements a transformer layer with three main components: 1. Self-attention with adaptive layer norm (AdaLN) 2. Cross-attention (optional) for conditioning on encoder outputs 3. Feed-forward MLP with adaptive layer norm Uses scale-shift modulation from timestep embeddings for adaptive normalization. """ def __init__(self, config: AceStepConfig, layer_idx: int, use_cross_attention: bool = True): super().__init__() # 1. Self-attention sub-layer with adaptive normalization self.self_attn_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.self_attn = AceStepAttention(config=config, layer_idx=layer_idx) # 2. Cross-attention sub-layer (optional, for encoder conditioning) self.use_cross_attention = use_cross_attention if self.use_cross_attention: self.cross_attn_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn = AceStepAttention(config=config, layer_idx=layer_idx, is_cross_attention=True) # 3. Feed-forward MLP sub-layer with adaptive normalization self.mlp_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = Qwen3MLP(config) # Scale-shift table for adaptive layer norm modulation (6 values: 3 for self-attn, 3 for MLP) self.scale_shift_table = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5) self.attention_type = config.layer_types[layer_idx] def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: # Extract scale-shift parameters for adaptive layer norm from timestep embeddings # 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( self.scale_shift_table + temb ).chunk(6, dim=1) # Step 1: Self-attention with adaptive layer norm (AdaLN) # Apply adaptive normalization: norm(x) * (1 + scale) + shift norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states) attn_output, self_attn_weights = self.self_attn( hidden_states=norm_hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, use_cache=False, past_key_value=None, **kwargs, ) # Apply gated residual connection: x = x + attn_output * gate hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states) # Step 2: Cross-attention (if enabled) for conditioning on encoder outputs if self.use_cross_attention: norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states) attn_output, cross_attn_weights = self.cross_attn( hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, **kwargs, ) # Standard residual connection for cross-attention hidden_states = hidden_states + attn_output # Step 3: Feed-forward (MLP) with adaptive layer norm # Apply adaptive normalization for MLP: norm(x) * (1 + scale) + shift norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states) ff_output = self.mlp(norm_hidden_states) # Apply gated residual connection: x = x + mlp_output * gate hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) return outputs @auto_docstring class AceStepPreTrainedModel(PreTrainedModel): config_class = AceStepConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["AceStepEncoderLayer", "AceStepDiTLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True def _init_weights(self, module): """ Initialize weights for different module types. TODO: Support separate initialization for encoders and decoders. """ std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, Qwen3RMSNorm): module.weight.data.fill_(1.0) class AceStepLyricEncoder(AceStepPreTrainedModel): """ Encoder for processing lyric text embeddings. Encodes lyric text hidden states using a transformer encoder architecture with bidirectional attention. Projects text embeddings to model hidden size and processes them through multiple encoder layers. """ def __init__(self, config): super().__init__(config) # Project text embeddings to model hidden size self.embed_tokens = nn.Linear(config.text_hidden_dim, config.hidden_size) self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3RotaryEmbedding(config=config) self.gradient_checkpointing = False # Stack of encoder layers self.layers = nn.ModuleList( [AceStepEncoderLayer(config, layer_idx) for layer_idx in range(config.num_lyric_encoder_hidden_layers)] ) # Initialize weights and apply final processing self.post_init() @can_return_tuple def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutput: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) assert input_ids is None, "Only `input_ids` is supported for the lyric encoder." assert attention_mask is not None, "Attention mask must be provided for the lyric encoder." assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder." # Project input embeddings: N x T x text_hidden_dim -> N x T x hidden_size inputs_embeds = self.embed_tokens(inputs_embeds) # Cache position: only used for mask construction (not for actual caching) cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) # Positional IDs if position_ids is None: position_ids = cache_position.unsqueeze(0) # Attention masks seq_len = inputs_embeds.shape[1] dtype = inputs_embeds.dtype device = inputs_embeds.device # 判断是否使用 Flash Attention 2 is_flash_attn = (self.config._attn_implementation == "flash_attention_2") # 初始化 Mask 变量 full_attn_mask = None sliding_attn_mask = None if is_flash_attn: # ------------------------------------------------------- # 场景 A: Flash Attention 模式 # ------------------------------------------------------- # FA 不需要 4D Mask。 # 如果有 padding mask (attention_mask [B, L]),直接传给它即可。 # 如果没有 padding mask,传 None。 # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。 full_attn_mask = attention_mask # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask # Layer 会自己决定是否调用带 sliding window 的 kernel sliding_attn_mask = attention_mask if self.config.use_sliding_window else None else: # ------------------------------------------------------- # 场景 B: CPU / Mac / SDPA (Eager 模式) # ------------------------------------------------------- # 必须手动生成 4D Mask [B, 1, L, L] # 1. Full Attention (Bidirectional, Global) # 对应原来的 create_causal_mask + bidirectional full_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=None, is_sliding_window=False, is_causal=False # <--- 关键:双向注意力 ) # 2. Sliding Attention (Bidirectional, Local) # 对应原来的 create_sliding_window... + bidirectional if self.config.use_sliding_window: sliding_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=self.config.sliding_window, is_sliding_window=True, # <--- 开启滑动窗口 is_causal=False # <--- 关键:双向注意力 ) # 构建 Mapping self_attn_mask_mapping = { "full_attention": full_attn_mask, "sliding_attention": sliding_attn_mask, } # Initialize hidden states with input embeddings hidden_states = inputs_embeds # Create position embeddings to be shared across all layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # Pass through transformer layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for layer_module in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = layer_module( hidden_states, position_embeddings, self_attn_mask_mapping[layer_module.attention_type], position_ids, output_attentions, **flash_attn_kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attns, ) class AttentionPooler(AceStepPreTrainedModel): """ Attention-based pooling module. Pools sequences of patches using a special token and attention mechanism. The special token attends to all patches and its output is used as the pooled representation. Used for aggregating patch-level features into sequence-level representations. """ def __init__(self, config): super().__init__(config) self.config = config self.embed_tokens = nn.Linear(config.hidden_size, config.hidden_size) self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3RotaryEmbedding(config=config) self.gradient_checkpointing = False # Special token used for pooling (CLS-like token) self.special_token = nn.Parameter(torch.randn(1, 1, config.hidden_size) * 0.02) self.layers = nn.ModuleList( [AceStepEncoderLayer(config, layer_idx) for layer_idx in range(config.num_attention_pooler_hidden_layers)] ) # Initialize weights and apply final processing self.post_init() @can_return_tuple def forward(self, x, attention_mask: Optional[torch.Tensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutput: B, T, P, D = x.shape x = self.embed_tokens(x) special_tokens = self.special_token.expand(B, T, 1, -1) x = torch.cat([special_tokens, x], dim=2) x = rearrange(x, "b t p c -> (b t) p c") # Cache position: only used for mask construction. cache_position = torch.arange(0, x.shape[1], device=x.device) # Postional ids. position_ids = cache_position.unsqueeze(0) # embed positions hidden_states = x # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) seq_len = x.shape[1] dtype = x.dtype device = x.device # 判断是否使用 Flash Attention 2 is_flash_attn = (self.config._attn_implementation == "flash_attention_2") # 初始化 Mask 变量 full_attn_mask = None sliding_attn_mask = None if is_flash_attn: # ------------------------------------------------------- # 场景 A: Flash Attention 模式 # ------------------------------------------------------- # FA 不需要 4D Mask。 # 如果有 padding mask (attention_mask [B, L]),直接传给它即可。 # 如果没有 padding mask,传 None。 # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。 full_attn_mask = attention_mask # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask # Layer 会自己决定是否调用带 sliding window 的 kernel sliding_attn_mask = attention_mask if self.config.use_sliding_window else None else: # ------------------------------------------------------- # 场景 B: CPU / Mac / SDPA (Eager 模式) # ------------------------------------------------------- # 必须手动生成 4D Mask [B, 1, L, L] # 1. Full Attention (Bidirectional, Global) # 对应原来的 create_causal_mask + bidirectional full_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=None, is_sliding_window=False, is_causal=False # <--- 关键:双向注意力 ) # 2. Sliding Attention (Bidirectional, Local) # 对应原来的 create_sliding_window... + bidirectional if self.config.use_sliding_window: sliding_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=self.config.sliding_window, is_sliding_window=True, # <--- 开启滑动窗口 is_causal=False # <--- 关键:双向注意力 ) # 构建 Mapping self_attn_mask_mapping = { "full_attention": full_attn_mask, "sliding_attention": sliding_attn_mask, } for layer_module in self.layers: layer_outputs = layer_module( hidden_states, position_embeddings, attention_mask=self_attn_mask_mapping[layer_module.attention_type], **flash_attn_kwargs, ) hidden_states = layer_outputs[0] hidden_states = self.norm(hidden_states) # Extract the special token output (first position) as pooled representation cls_output = hidden_states[:, 0, :] cls_output = rearrange(cls_output, "(b t) c -> b t c", b=B) return cls_output class AudioTokenDetokenizer(AceStepPreTrainedModel): """ Audio token detokenizer module. Converts quantized audio tokens back to continuous acoustic representations. Expands each token into multiple patches using special tokens, processes them through encoder layers, and projects to acoustic hidden dimension. """ def __init__(self, config): super().__init__(config) self.config = config self.embed_tokens = nn.Linear(config.hidden_size, config.hidden_size) self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3RotaryEmbedding(config=config) self.gradient_checkpointing = False # Special tokens for expanding each quantized token into patches self.special_tokens = nn.Parameter(torch.randn(1, config.pool_window_size, config.hidden_size) * 0.02) self.layers = nn.ModuleList( [AceStepEncoderLayer(config, layer_idx) for layer_idx in range(config.num_attention_pooler_hidden_layers)] ) # Project back to acoustic hidden dimension self.proj_out = nn.Linear(config.hidden_size, config.audio_acoustic_hidden_dim) # Initialize weights and apply final processing self.post_init() @can_return_tuple def forward(self, x, attention_mask: Optional[torch.Tensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutput: B, T, D = x.shape x = self.embed_tokens(x) # Expand and add special tokens: N x T x D -> N x T x P x D # Each token is expanded into pool_window_size patches x = x.unsqueeze(2) # N x T x 1 x D x = x.repeat(1, 1, self.config.pool_window_size, 1) # N x T x P x D # Add learnable special tokens to each patch special_tokens = self.special_tokens.expand(B, T, -1, -1) x = x + special_tokens # Reshape for processing: (batch * time) x patches x hidden x = rearrange(x, "b t p c -> (b t) p c") # Cache position: only used for mask construction cache_position = torch.arange(0, x.shape[1], device=x.device) # Positional IDs position_ids = cache_position.unsqueeze(0) # Initialize hidden states hidden_states = x # Create position embeddings to be shared across all layers position_embeddings = self.rotary_emb(hidden_states, position_ids) seq_len = x.shape[1] dtype = x.dtype device = x.device # 判断是否使用 Flash Attention 2 is_flash_attn = (self.config._attn_implementation == "flash_attention_2") # 初始化 Mask 变量 full_attn_mask = None sliding_attn_mask = None if is_flash_attn: # ------------------------------------------------------- # 场景 A: Flash Attention 模式 # ------------------------------------------------------- # FA 不需要 4D Mask。 # 如果有 padding mask (attention_mask [B, L]),直接传给它即可。 # 如果没有 padding mask,传 None。 # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。 full_attn_mask = attention_mask # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask # Layer 会自己决定是否调用带 sliding window 的 kernel sliding_attn_mask = attention_mask if self.config.use_sliding_window else None else: # ------------------------------------------------------- # 场景 B: CPU / Mac / SDPA (Eager 模式) # ------------------------------------------------------- # 必须手动生成 4D Mask [B, 1, L, L] # 1. Full Attention (Bidirectional, Global) # 对应原来的 create_causal_mask + bidirectional full_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=None, is_sliding_window=False, is_causal=False # <--- 关键:双向注意力 ) # 2. Sliding Attention (Bidirectional, Local) # 对应原来的 create_sliding_window... + bidirectional if self.config.use_sliding_window: sliding_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=self.config.sliding_window, is_sliding_window=True, # <--- 开启滑动窗口 is_causal=False # <--- 关键:双向注意力 ) # 构建 Mapping self_attn_mask_mapping = { "full_attention": full_attn_mask, "sliding_attention": sliding_attn_mask, } for layer_module in self.layers: layer_outputs = layer_module( hidden_states, position_embeddings, attention_mask=self_attn_mask_mapping[layer_module.attention_type], **flash_attn_kwargs, ) hidden_states = layer_outputs[0] hidden_states = self.norm(hidden_states) hidden_states = self.proj_out(hidden_states) hidden_states = rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.config.pool_window_size) return hidden_states class AceStepTimbreEncoder(AceStepPreTrainedModel): """ Encoder for extracting timbre embeddings from reference audio. Processes packed reference audio acoustic features to extract timbre representations. Uses a special token (CLS-like) to aggregate information from the entire reference audio sequence. Outputs are unpacked back to batch format for use in conditioning. """ def __init__(self, config): super().__init__(config) # Project acoustic features to model hidden size self.embed_tokens = nn.Linear(config.timbre_hidden_dim, config.hidden_size) self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3RotaryEmbedding(config=config) self.gradient_checkpointing = False # Special token for aggregating timbre information (prepended to sequence) self.special_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.layers = nn.ModuleList( [AceStepEncoderLayer(config, layer_idx) for layer_idx in range(config.num_timbre_encoder_hidden_layers)] ) # Initialize weights and apply final processing self.post_init() def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask): """ Unpack packed timbre embeddings into batch format. Args: timbre_embs_packed: Packed timbre embeddings of shape [N, d] refer_audio_order_mask: Order mask indicating batch assignment for each packed embedding Returns: Tuple of (unpacked_embeddings, mask): - unpacked_embeddings: Unpacked embeddings of shape [B, max_count, d] - new_mask: Mask indicating valid positions, shape [B, max_count] """ N, d = timbre_embs_packed.shape device = timbre_embs_packed.device dtype = timbre_embs_packed.dtype # Get batch size B = int(refer_audio_order_mask.max().item() + 1) # Calculate element count and positions for each batch counts = torch.bincount(refer_audio_order_mask, minlength=B) max_count = counts.max().item() # Calculate positions within batch sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True) sorted_batch_ids = refer_audio_order_mask[sorted_indices] positions = torch.arange(N, device=device) batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]]) positions_in_sorted = positions - batch_starts[sorted_batch_ids] inverse_indices = torch.empty_like(sorted_indices) inverse_indices[sorted_indices] = torch.arange(N, device=device) positions_in_batch = positions_in_sorted[inverse_indices] # Use one-hot encoding and matrix multiplication (gradient-friendly approach) # Create one-hot encoding indices_2d = refer_audio_order_mask * max_count + positions_in_batch # (N,) one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype) # (N, B*max_count) # Rearrange using matrix multiplication timbre_embs_flat = one_hot.t() @ timbre_embs_packed # (B*max_count, d) timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d) # Create mask indicating valid positions mask_flat = (one_hot.sum(dim=0) > 0).long() # (B*max_count,) new_mask = mask_flat.reshape(B, max_count) return timbre_embs_unpack, new_mask @can_return_tuple def forward( self, refer_audio_acoustic_hidden_states_packed: Optional[torch.FloatTensor] = None, refer_audio_order_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutput: inputs_embeds = refer_audio_acoustic_hidden_states_packed # Project embeddings: N x T x timbre_hidden_dim -> N x T x hidden_size inputs_embeds = self.embed_tokens(inputs_embeds) # Prepend special token for timbre aggregation (CLS-like token) # inputs_embeds = torch.cat([self.special_token.expand(inputs_embeds.shape[0], 1, -1), inputs_embeds], dim=1) # Cache position: only used for mask construction (not for actual caching) cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) # Positional IDs position_ids = cache_position.unsqueeze(0) seq_len = inputs_embeds.shape[1] dtype = inputs_embeds.dtype device = inputs_embeds.device # 判断是否使用 Flash Attention 2 is_flash_attn = (self.config._attn_implementation == "flash_attention_2") # 初始化 Mask 变量 full_attn_mask = None sliding_attn_mask = None if is_flash_attn: # ------------------------------------------------------- # 场景 A: Flash Attention 模式 # ------------------------------------------------------- # FA 不需要 4D Mask。 # 如果有 padding mask (attention_mask [B, L]),直接传给它即可。 # 如果没有 padding mask,传 None。 # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。 full_attn_mask = attention_mask # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask # Layer 会自己决定是否调用带 sliding window 的 kernel sliding_attn_mask = attention_mask if self.config.use_sliding_window else None else: # ------------------------------------------------------- # 场景 B: CPU / Mac / SDPA (Eager 模式) # ------------------------------------------------------- # 必须手动生成 4D Mask [B, 1, L, L] # 1. Full Attention (Bidirectional, Global) # 对应原来的 create_causal_mask + bidirectional full_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=None, is_sliding_window=False, is_causal=False # <--- 关键:双向注意力 ) # 2. Sliding Attention (Bidirectional, Local) # 对应原来的 create_sliding_window... + bidirectional if self.config.use_sliding_window: sliding_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=self.config.sliding_window, is_sliding_window=True, # <--- 开启滑动窗口 is_causal=False # <--- 关键:双向注意力 ) # 构建 Mapping self_attn_mask_mapping = { "full_attention": full_attn_mask, "sliding_attention": sliding_attn_mask, } # Initialize hidden states hidden_states = inputs_embeds # Create position embeddings to be shared across all layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # Pass through transformer layers for layer_module in self.layers[: self.config.num_hidden_layers]: layer_outputs = layer_module( hidden_states, position_embeddings, self_attn_mask_mapping[layer_module.attention_type], position_ids, **flash_attn_kwargs, ) hidden_states = layer_outputs[0] hidden_states = self.norm(hidden_states) # Extract special token output (first position) as timbre embedding: N x T x D -> N x D hidden_states = hidden_states[:, 0, :] # Unpack packed embeddings back to batch format timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask) return timbre_embs_unpack, timbre_embs_mask class AceStepAudioTokenizer(AceStepPreTrainedModel): """ Audio tokenizer module. Converts continuous acoustic features into discrete quantized tokens. Process: project -> pool patches -> quantize. Used for converting audio representations into discrete tokens for processing by the diffusion model. """ def __init__(self, config): super().__init__(config) # Project acoustic features to hidden size self.audio_acoustic_proj = nn.Linear(config.audio_acoustic_hidden_dim, config.hidden_size) # Pool patches into sequence-level representations self.attention_pooler = AttentionPooler(config) # Quantize continuous representations into discrete tokens self.quantizer = ResidualFSQ( dim=config.fsq_dim, levels=config.fsq_input_levels, num_quantizers=config.fsq_input_num_quantizers ) self.pool_window_size = config.pool_window_size # Initialize weights and apply final processing self.post_init() @can_return_tuple def forward( self, hidden_states: Optional[torch.FloatTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutput: # Project acoustic features to hidden size hidden_states = self.audio_acoustic_proj(hidden_states) # Pool sequences: N x T//pool_window_size x pool_window_size x d -> N x T//pool_window_size x d hidden_states = self.attention_pooler(hidden_states) # Quantize continuous representations into discrete tokens: N x T//pool_window_size x d quantized, indices = self.quantizer(hidden_states) return quantized, indices def tokenize(self, x): x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=self.pool_window_size) quantized, indices = self.forward(x) return quantized, indices class Lambda(nn.Module): """ Wrapper module for arbitrary lambda functions. Allows using lambda functions in nn.Sequential by wrapping them in a Module. Useful for simple transformations like transpose operations. """ def __init__(self, func): super().__init__() self.func = func def forward(self, x): return self.func(x) class AceStepDiTModel(AceStepPreTrainedModel): """ DiT (Diffusion Transformer) model for AceStep. Main diffusion model that generates audio latents conditioned on text, lyrics, and timbre. Uses patch-based processing with transformer layers, timestep conditioning, and cross-attention to encoder outputs. """ def __init__(self, config: AceStepConfig): super().__init__(config) # Rotary position embeddings for transformer layers self.rotary_emb = Qwen3RotaryEmbedding(config) # Stack of DiT transformer layers self.layers = nn.ModuleList( [AceStepDiTLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) in_channels = config.in_channels inner_dim = config.hidden_size patch_size = config.patch_size self.patch_size = patch_size # Input projection: patch embedding using 1D convolution # Converts sequence into patches for efficient processing self.proj_in = nn.Sequential( Lambda(lambda x: x.transpose(1, 2)), # [B, T, C] -> [B, C, T] nn.Conv1d( in_channels=in_channels, out_channels=inner_dim, kernel_size=patch_size, stride=patch_size, padding=0, ), Lambda(lambda x: x.transpose(1, 2)), # [B, C, T//patch_size] -> [B, T//patch_size, C] ) # Timestep embeddings for diffusion conditioning # Two embeddings: one for timestep t, one for timestep difference (t - r) self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim) self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim) # Project encoder hidden states to model dimension self.condition_embedder = nn.Linear(inner_dim, inner_dim, bias=True) # Output normalization and projection # Adaptive layer norm with scale-shift modulation, then de-patchify self.norm_out = Qwen3RMSNorm(inner_dim, eps=config.rms_norm_eps) self.proj_out = nn.Sequential( Lambda(lambda x: x.transpose(1, 2)), # [B, T//patch_size, inner_dim] -> [B, inner_dim, T//patch_size] nn.ConvTranspose1d( in_channels=inner_dim, out_channels=config.audio_acoustic_hidden_dim, kernel_size=patch_size, stride=patch_size, padding=0, ), Lambda(lambda x: x.transpose(1, 2)), # [B, out_channels, T] -> [B, T, out_channels] ) # Scale-shift table for adaptive output normalization (2 values: shift, scale) self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, timestep: torch.Tensor, timestep_r: torch.Tensor, attention_mask: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, context_latents: torch.Tensor, use_cache: Optional[bool] = None, past_key_values: Optional[EncoderDecoderCache] = None, cache_position: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, return_hidden_states: int = None, custom_layers_config: Optional[dict] = None, enable_early_exit: bool = False, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ): use_cache = use_cache if use_cache is not None else self.config.use_cache # Disable cache during training or when gradient checkpointing is enabled if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False if self.training: use_cache = False # Initialize cache if needed (only during inference for auto-regressive generation) if not self.training and use_cache and past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) # Compute timestep embeddings for diffusion conditioning # Two embeddings: one for timestep t, one for timestep difference (t - r) temb_t, timestep_proj_t = self.time_embed(timestep) temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r) # Combine embeddings temb = temb_t + temb_r timestep_proj = timestep_proj_t + timestep_proj_r # Concatenate context latents (source latents + chunk masks) with hidden states hidden_states = torch.cat([context_latents, hidden_states], dim=-1) # Record original sequence length for later restoration after padding original_seq_len = hidden_states.shape[1] # Apply padding if sequence length is not divisible by patch_size # This ensures proper patch extraction pad_length = 0 if hidden_states.shape[1] % self.patch_size != 0: pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size) hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode='constant', value=0) # Project input to patches and project encoder states hidden_states = self.proj_in(hidden_states) encoder_hidden_states = self.condition_embedder(encoder_hidden_states) # Cache positions if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device ) # Position IDs if position_ids is None: position_ids = cache_position.unsqueeze(0) seq_len = hidden_states.shape[1] encoder_seq_len = encoder_hidden_states.shape[1] dtype = hidden_states.dtype device = hidden_states.device # 判断是否使用 Flash Attention 2 is_flash_attn = (self.config._attn_implementation == "flash_attention_2") # 初始化 Mask 变量 full_attn_mask = None sliding_attn_mask = None encoder_attention_mask = None attention_mask = None if is_flash_attn: # ------------------------------------------------------- # 场景 A: Flash Attention 模式 # ------------------------------------------------------- # FA 不需要 4D Mask。 # 如果有 padding mask (attention_mask [B, L]),直接传给它即可。 # 如果没有 padding mask,传 None。 # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。 full_attn_mask = attention_mask # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask # Layer 会自己决定是否调用带 sliding window 的 kernel sliding_attn_mask = attention_mask if self.config.use_sliding_window else None else: # ------------------------------------------------------- # 场景 B: CPU / Mac / SDPA (Eager 模式) # ------------------------------------------------------- # 必须手动生成 4D Mask [B, 1, L, L] # 1. Full Attention (Bidirectional, Global) # 对应原来的 create_causal_mask + bidirectional full_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=None, is_sliding_window=False, is_causal=False # <--- 关键:双向注意力 ) max_len = max(seq_len, encoder_seq_len) encoder_attention_mask = create_4d_mask( seq_len=max_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=None, is_sliding_window=False, is_causal=False # <--- 关键:双向注意力 ) encoder_attention_mask = encoder_attention_mask[:, :, :seq_len, :encoder_seq_len] # 2. Sliding Attention (Bidirectional, Local) # 对应原来的 create_sliding_window... + bidirectional if self.config.use_sliding_window: sliding_attn_mask = create_4d_mask( seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, # [B, L] sliding_window=self.config.sliding_window, is_sliding_window=True, # <--- 开启滑动窗口 is_causal=False # <--- 关键:双向注意力 ) # 构建 Mapping self_attn_mask_mapping = { "full_attention": full_attn_mask, "sliding_attention": sliding_attn_mask, "encoder_attention_mask": encoder_attention_mask, } # Create position embeddings to be shared across all decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) all_cross_attentions = () if output_attentions else None # Handle early exit for custom layer configurations max_needed_layer = float('inf') if custom_layers_config is not None and enable_early_exit: max_needed_layer = max(custom_layers_config.keys()) # Force output_attentions to True when early exit is enabled for attention extraction output_attentions = True if all_cross_attentions is None: all_cross_attentions = () # Process through transformer layers for index_block, layer_module in enumerate(self.layers): layer_outputs = layer_module( hidden_states, position_embeddings, timestep_proj, self_attn_mask_mapping[layer_module.attention_type], position_ids, past_key_values, output_attentions, use_cache, cache_position, encoder_hidden_states, self_attn_mask_mapping["encoder_attention_mask"], **flash_attn_kwargs, ) hidden_states = layer_outputs[0] if output_attentions and self.layers[index_block].use_cross_attention: # layer_outputs structure: (hidden_states, self_attn_weights, cross_attn_weights) # Extract the last element which is cross_attn_weights if len(layer_outputs) >= 3: all_cross_attentions += (layer_outputs[2],) if return_hidden_states: return hidden_states # Extract scale-shift parameters for adaptive output normalization shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) shift = shift.to(hidden_states.device) scale = scale.to(hidden_states.device) # Apply adaptive layer norm: norm(x) * (1 + scale) + shift hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states) # Project output: de-patchify back to original sequence format hidden_states = self.proj_out(hidden_states) # Crop back to original sequence length to ensure exact length match (remove padding) hidden_states = hidden_states[:, :original_seq_len, :] outputs = (hidden_states, past_key_values) if output_attentions: outputs += (all_cross_attentions,) return outputs class AceStepConditionEncoder(AceStepPreTrainedModel): """ Condition encoder for AceStep model. Encodes multiple conditioning inputs (text, lyrics, timbre) and packs them into a single sequence for cross-attention in the diffusion model. Handles projection, encoding, and sequence packing. """ def __init__(self, config: AceStepConfig): super().__init__(config) self.config = config # Project text embeddings to model hidden size self.text_projector = nn.Linear(config.text_hidden_dim, config.hidden_size, bias=False) # Encoder for lyric text self.lyric_encoder = AceStepLyricEncoder(config) # Encoder for timbre from reference audio self.timbre_encoder = AceStepTimbreEncoder(config) def forward( self, # Text inputs text_hidden_states: Optional[torch.FloatTensor] = None, text_attention_mask: Optional[torch.Tensor] = None, # Lyric inputs lyric_hidden_states: Optional[torch.LongTensor] = None, lyric_attention_mask: Optional[torch.Tensor] = None, # Reference audio for timbre refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None, refer_audio_order_mask: Optional[torch.LongTensor] = None, ): # Project and encode text text_hidden_states = self.text_projector(text_hidden_states) # Encode lyrics lyric_encoder_outputs = self.lyric_encoder( inputs_embeds=lyric_hidden_states, attention_mask=lyric_attention_mask, ) lyric_hidden_states = lyric_encoder_outputs.last_hidden_state # Encode timbre from reference audio timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask) # Pack sequences: combine lyrics and timbre, then add text # This creates a single sequence with all conditioning information encoder_hidden_states, encoder_attention_mask = pack_sequences(lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask) encoder_hidden_states, encoder_attention_mask = pack_sequences(encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask) return encoder_hidden_states, encoder_attention_mask class AceStepConditionGenerationModel(AceStepPreTrainedModel): """ Main conditional generation model for AceStep. End-to-end model for generating audio conditioned on text, lyrics, and timbre. Combines encoder (for conditioning), decoder (diffusion model), tokenizer (for discrete tokenization), and detokenizer (for reconstruction). Supports flow matching training and inference with various sampling methods. """ def __init__(self, config: AceStepConfig): super().__init__(config) self.config = config # Diffusion model components self.decoder = AceStepDiTModel(config) # Main diffusion transformer self.encoder = AceStepConditionEncoder(config) # Condition encoder self.tokenizer = AceStepAudioTokenizer(config) # Audio tokenizer self.detokenizer = AudioTokenDetokenizer(config) # Audio detokenizer # Null condition embedding for classifier-free guidance self.null_condition_emb = nn.Parameter(torch.randn(1, 1, config.hidden_size)) # Initialize weights and apply final processing self.post_init() def tokenize(self, x, silence_latent, attention_mask): if x.shape[1] % self.config.pool_window_size != 0: pad_len = self.config.pool_window_size - (x.shape[1] % self.config.pool_window_size) x = torch.cat([x, silence_latent[:1,:pad_len].repeat(x.shape[0],1,1)], dim=1) attention_mask = F.pad(attention_mask, (0, pad_len), mode='constant', value=0) x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=self.config.pool_window_size) seq_len = x.shape[1] chunk = math.ceil(attention_mask.shape[1] / seq_len) attention_mask = attention_mask.to(x.dtype) attention_mask = F.max_pool1d(attention_mask.unsqueeze(1), kernel_size=chunk, stride=chunk, ceil_mode=True).squeeze(1) quantized, indices = self.tokenizer(x) return quantized, indices, attention_mask def detokenize(self, quantized): """ Detokenize quantized audio tokens back to continuous representations. Args: quantized: Quantized tokens of shape [N, T//pool_window_size, d] Returns: Detokenized hidden states of shape [N, T, d] """ hidden_states = self.detokenizer(quantized) return hidden_states @torch.no_grad() def prepare_condition( self, text_hidden_states: torch.FloatTensor, text_attention_mask: torch.Tensor, lyric_hidden_states: torch.FloatTensor, lyric_attention_mask: torch.Tensor, refer_audio_acoustic_hidden_states_packed: torch.FloatTensor, refer_audio_order_mask: torch.Tensor, hidden_states: torch.FloatTensor, attention_mask: torch.Tensor, silence_latent: torch.FloatTensor, src_latents: torch.FloatTensor, chunk_masks: torch.Tensor, is_covers: torch.Tensor, precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None, audio_codes: torch.FloatTensor = None, ): dtype = hidden_states.dtype encoder_hidden_states, encoder_attention_mask = self.encoder( text_hidden_states=text_hidden_states, text_attention_mask=text_attention_mask, lyric_hidden_states=lyric_hidden_states, lyric_attention_mask=lyric_attention_mask, refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask=refer_audio_order_mask, ) # N x T x d -> N x T//pool_window_size x pool_window_size x d # tokenize and detokenize to get LM hints for cover songs (when is_covers=True) # Use precomputed hints if provided (e.g., from audio codes), otherwise tokenize hidden_states if precomputed_lm_hints_25Hz is not None: print("Using precomputed LM hints") lm_hints_25Hz = precomputed_lm_hints_25Hz[:, :src_latents.shape[1], :] else: if audio_codes is not None: lm_hints_5Hz = self.tokenize.quantizer.get_output_from_indices(audio_codes) else: lm_hints_5Hz, indices, llm_mask = self.tokenize(hidden_states, silence_latent, attention_mask) lm_hints_25Hz = self.detokenize(lm_hints_5Hz) # Crop lm_hints_25Hz to match src_latents length (tokenize may have added padding) lm_hints_25Hz = lm_hints_25Hz[:, :src_latents.shape[1], :] src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints_25Hz, src_latents) # Concatenate source latents with chunk masks as context context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1) return encoder_hidden_states, encoder_attention_mask, context_latents def forward( self, # Diffusion inputs hidden_states: torch.FloatTensor, attention_mask: torch.Tensor, # Encoder inputs # Text text_hidden_states: Optional[torch.FloatTensor] = None, text_attention_mask: Optional[torch.Tensor] = None, # Lyric lyric_hidden_states: Optional[torch.LongTensor] = None, lyric_attention_mask: Optional[torch.Tensor] = None, # Reference audio for timbre refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None, refer_audio_order_mask: Optional[torch.LongTensor] = None, src_latents: torch.FloatTensor = None, chunk_masks: torch.FloatTensor = None, is_covers: torch.Tensor = None, silence_latent: torch.FloatTensor = None, cfg_ratio: float = 0.15, ): """ Forward pass for training (computes training losses). """ # Prepare conditioning inputs (encoder states, context latents) encoder_hidden_states, encoder_attention_mask, context_latents = self.prepare_condition( text_hidden_states=text_hidden_states, text_attention_mask=text_attention_mask, lyric_hidden_states=lyric_hidden_states, lyric_attention_mask=lyric_attention_mask, refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask=refer_audio_order_mask, hidden_states=src_latents, attention_mask=attention_mask, silence_latent=silence_latent, src_latents=src_latents, chunk_masks=chunk_masks, is_covers=is_covers, ) bsz, device, dtype = hidden_states.shape[0], hidden_states.device, hidden_states.dtype # Classifier-free guidance: randomly drop conditions with probability cfg_ratio # This helps the model learn to work with and without conditions full_cfg_condition_mask = torch.where( (torch.rand(size=(bsz,), device=device, dtype=dtype) < cfg_ratio), torch.zeros(size=(bsz,), device=device, dtype=dtype), torch.ones(size=(bsz,), device=device, dtype=dtype) ).view(-1, 1, 1) # Replace dropped conditions with null condition embedding encoder_hidden_states = torch.where(full_cfg_condition_mask > 0, encoder_hidden_states, self.null_condition_emb.expand_as(encoder_hidden_states)) # Flow matching setup: sample noise x1 and interpolate with data x0 x1 = torch.randn_like(hidden_states) # Noise x0 = hidden_states # Data # Sample timesteps t and r for flow matching t, r = sample_t_r(bsz, device, dtype, self.config.data_proportion, self.config.timestep_mu, self.config.timestep_sigma, use_meanflow=False) t_ = t.unsqueeze(-1).unsqueeze(-1) # Interpolate: x_t = t * x1 + (1 - t) * x0 xt = t_ * x1 + (1.0 - t_) * x0 # Predict flow (velocity) from diffusion model decoder_outputs = self.decoder( hidden_states=xt, timestep=t, timestep_r=t, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, context_latents=context_latents, ) # Flow matching loss: predict the flow field v = x1 - x0 flow = x1 - x0 diffusion_loss = F.mse_loss(decoder_outputs[0], flow) return { "diffusion_loss": diffusion_loss, } def training_losses(self, **kwargs): return self.forward(**kwargs) def prepare_noise(self, context_latents: torch.FloatTensor, seed: Union[int, List[int], None] = None): """ Prepare noise tensor for generation with optional seeding. Args: context_latents: Context latents to determine noise shape seed: Can be int, List[int], or None. If None, uses random noise. Returns: Noise tensor of appropriate shape """ bsz = context_latents.shape[0] device = context_latents.device dtype = context_latents.dtype # Handle seed: can be int, List[int], or None src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2) if seed is None: # No seed provided - use random noise = torch.randn(src_latents_shape, device=device, dtype=dtype) elif isinstance(seed, list): # List of seeds - generate noise for each sample separately noise_list = [] for i, s in enumerate(seed): if s is None or s < 0: # Random seed for this sample noise_i = torch.randn(1, src_latents_shape[1], src_latents_shape[2], device=device, dtype=dtype) else: # Use specific seed for this sample generator = torch.Generator(device=device).manual_seed(int(s)) noise_i = torch.randn(1, src_latents_shape[1], src_latents_shape[2], generator=generator, device=device, dtype=dtype) noise_list.append(noise_i) noise = torch.cat(noise_list, dim=0) else: # Single seed for all samples generator = torch.Generator(device=device).manual_seed(int(seed)) noise = torch.randn(src_latents_shape, generator=generator, device=device, dtype=dtype) return noise def get_x0_from_noise(self, zt, vt, t): return zt - vt * t.unsqueeze(-1).unsqueeze(-1) def renoise(self, x, t, noise=None): if noise is None: noise = torch.randn_like(x) if isinstance(t, torch.Tensor) and t.ndim != x.ndim: t = t.unsqueeze(-1).unsqueeze(-1) xt = t * noise + (1 - t) * x return xt def generate_audio( self, text_hidden_states: torch.FloatTensor, text_attention_mask: torch.FloatTensor, lyric_hidden_states: torch.FloatTensor, lyric_attention_mask: torch.FloatTensor, refer_audio_acoustic_hidden_states_packed: torch.FloatTensor, refer_audio_order_mask: torch.LongTensor, src_latents: torch.FloatTensor, chunk_masks: torch.FloatTensor, is_covers: torch.Tensor, silence_latent: Optional[torch.FloatTensor] = None, attention_mask: torch.Tensor = None, seed: int = None, infer_method: str = "ode", use_cache: bool = True, infer_steps: int = 30, diffusion_guidance_sale: float = 7.0, audio_cover_strength: float = 1.0, non_cover_text_hidden_states: Optional[torch.FloatTensor] = None, non_cover_text_attention_mask: Optional[torch.FloatTensor] = None, cfg_interval_start: float = 0.0, cfg_interval_end: float = 1.0, precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None, audio_codes: Optional[torch.FloatTensor] = None, use_progress_bar: bool = True, use_adg: bool = False, shift: float = 1.0, **kwargs, ): if attention_mask is None: latent_length = src_latents.shape[1] attention_mask = torch.ones(src_latents.shape[0], latent_length, device=src_latents.device, dtype=src_latents.dtype) time_costs = {} start_time = time.time() total_start_time = start_time encoder_hidden_states, encoder_attention_mask, context_latents = self.prepare_condition( text_hidden_states=text_hidden_states, text_attention_mask=text_attention_mask, lyric_hidden_states=lyric_hidden_states, lyric_attention_mask=lyric_attention_mask, refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask=refer_audio_order_mask, hidden_states=src_latents, attention_mask=attention_mask, silence_latent=silence_latent, src_latents=src_latents, chunk_masks=chunk_masks, is_covers=is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes, ) encoder_hidden_states_non_cover, encoder_attention_mask_non_cover, context_latents_non_cover = None, None, None if audio_cover_strength < 1.0: non_is_covers = torch.zeros_like(is_covers, device=is_covers.device, dtype=is_covers.dtype) # Use silence_latent for non-cover condition to simulate text2music mode (no reference audio) silence_latent_expanded = silence_latent[:, :src_latents.shape[1], :].expand(src_latents.shape[0], -1, -1) encoder_hidden_states_non_cover, encoder_attention_mask_non_cover, context_latents_non_cover = self.prepare_condition( text_hidden_states=non_cover_text_hidden_states, text_attention_mask=non_cover_text_attention_mask, lyric_hidden_states=lyric_hidden_states, lyric_attention_mask=lyric_attention_mask, refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask=refer_audio_order_mask, hidden_states=silence_latent_expanded, attention_mask=attention_mask, silence_latent=silence_latent, src_latents=silence_latent_expanded, chunk_masks=chunk_masks, is_covers=non_is_covers, precomputed_lm_hints_25Hz=None, audio_codes=None, ) end_time = time.time() time_costs["encoder_time_cost"] = end_time - start_time start_time = end_time # Calculate cover steps based on audio_cover_strength cover_steps = int(infer_steps * audio_cover_strength) device, dtype = context_latents.device, context_latents.dtype t = torch.linspace(1.0, 0.0, infer_steps + 1, device=device, dtype=dtype) # Apply shift transformation to timesteps if shift != 1.0 if shift != 1.0: t = shift * t / (1 + (shift - 1) * t) if use_progress_bar: iterator = tqdm(zip(t[:-1], t[1:]), total=infer_steps) else: iterator = zip(t[:-1], t[1:]) noise = self.prepare_noise(context_latents, seed) bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) momentum_buffer = MomentumBuffer() # main task condition do_cfg_guidance = diffusion_guidance_sale > 1.0 if do_cfg_guidance: encoder_hidden_states = torch.cat([encoder_hidden_states, self.null_condition_emb.expand_as(encoder_hidden_states)], dim=0) encoder_attention_mask = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0) # src_latents context_latents = torch.cat([context_latents, context_latents], dim=0) attention_mask = torch.cat([attention_mask, attention_mask], dim=0) xt = noise with torch.no_grad(): for step_idx, (t_curr, t_prev) in enumerate(iterator): if step_idx >= cover_steps: if do_cfg_guidance: encoder_hidden_states_non_cover = torch.cat([encoder_hidden_states_non_cover, self.null_condition_emb.expand_as(encoder_hidden_states_non_cover)], dim=0) encoder_attention_mask_non_cover = torch.cat([encoder_attention_mask_non_cover, encoder_attention_mask_non_cover], dim=0) # src_latents context_latents_non_cover = torch.cat([context_latents_non_cover, context_latents_non_cover], dim=0) encoder_hidden_states = encoder_hidden_states_non_cover encoder_attention_mask = encoder_attention_mask_non_cover context_latents = context_latents_non_cover past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) x = torch.cat([xt, xt], dim=0) if do_cfg_guidance else xt t_curr_tensor = t_curr * torch.ones((x.shape[0],), device=device, dtype=dtype) decoder_outputs = self.decoder( hidden_states=x, timestep=t_curr_tensor, timestep_r=t_curr_tensor, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, context_latents=context_latents, use_cache=True, past_key_values=past_key_values, ) vt = decoder_outputs[0] past_key_values = decoder_outputs[1] apply_cfg_guidance = t_curr >= cfg_interval_start and t_curr <= cfg_interval_end if do_cfg_guidance: pred_cond, pred_null_cond = vt.chunk(2) if apply_cfg_guidance: if not use_adg: vt = apg_forward( pred_cond=pred_cond, pred_uncond=pred_null_cond, guidance_scale=diffusion_guidance_sale, momentum_buffer=momentum_buffer, dims=[1], ) else: vt = adg_forward( latents=xt, noise_pred_cond=pred_cond, noise_pred_uncond=pred_null_cond, sigma=t_curr, guidance_scale=diffusion_guidance_sale, ) else: vt = pred_cond # Update x_t based on inference method if infer_method == "sde": # Stochastic Differential Equation: predict clean, then re-add noise t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype) pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz) next_timestep = 1.0 - (float(step_idx + 1) / infer_steps) xt = self.renoise(pred_clean, next_timestep) elif infer_method == "ode": # Ordinary Differential Equation: Euler method # dx/dt = -v, so x_{t+1} = x_t - v_t * dt dt = t_curr - t_prev dt_tensor = dt * torch.ones((bsz,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1) xt = xt - vt * dt_tensor x_gen = xt end_time = time.time() time_costs["diffusion_time_cost"] = end_time - start_time time_costs["diffusion_per_step_time_cost"] = time_costs["diffusion_time_cost"] / infer_steps time_costs["total_time_cost"] = end_time - total_start_time return { "target_latents": x_gen, "time_costs": time_costs, } def test_forward(model, seed=42): # Fix random seed for reproducibility import random import numpy as np random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # Get model dtype and device model_dtype = next(model.parameters()).dtype device = next(model.parameters()).device print(f"Testing with dtype: {model_dtype}, device: {device}, seed: {seed}") # Test data preparation with matching dtype text_hidden_states = torch.randn(2, 77, 1024, dtype=model_dtype, device=device) text_attention_mask = torch.ones(2, 77, dtype=model_dtype, device=device) lyric_hidden_states = torch.randn(2, 123, 1024, dtype=model_dtype, device=device) lyric_attention_mask = torch.ones(2, 123, dtype=model_dtype, device=device) refer_audio_acoustic_hidden_states_packed = torch.randn(3, 750, 64, dtype=model_dtype, device=device) refer_audio_order_mask = torch.LongTensor([0, 0, 1]).to(device) # Base config: 25 Hz hidden states → 10 s = 250 frames (round to int) base_seconds = 10 frames_per_second = 25 base_seq_len = base_seconds * frames_per_second hidden_states = torch.randn(2, base_seq_len, 64, dtype=model_dtype, device=device) attention_mask = torch.ones(2, base_seq_len, dtype=model_dtype, device=device) # Add some padding to test mask behavior pad_start = max(base_seq_len // 2, 1) attention_mask[0, pad_start:] = 0 chunk_mask = torch.ones(2, base_seq_len, 64, dtype=model_dtype, device=device) chunk_mask[0, pad_start:] = 0 silence_latent = torch.randn(2, base_seq_len, 64, dtype=model_dtype, device=device) # New required parameters for updated training logic src_latents = torch.randn(2, base_seq_len, 64, dtype=model_dtype, device=device) # Source latents for context is_covers = torch.tensor([0, 1], dtype=torch.long, device=device) # Cover song indicators (0=original, 1=cover) # Test 1: Flow matching training (using 10s sequence for sanity check by default) print(f"Testing flow matching training with {base_seconds}s sequence ({base_seq_len} frames @ {frames_per_second}Hz)...") outputs = model.training_losses( hidden_states=hidden_states, attention_mask=attention_mask, chunk_masks=chunk_mask, text_hidden_states=text_hidden_states, text_attention_mask=text_attention_mask, lyric_hidden_states=lyric_hidden_states, lyric_attention_mask=lyric_attention_mask, refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask=refer_audio_order_mask, silence_latent=silence_latent, src_latents=src_latents, is_covers=is_covers, cfg_ratio=0.15, ) loss = outputs['diffusion_loss'] print(f"Flow matching loss: {loss.item():.6f}") print(f" Loss stats - min: {loss.min().item():.6f}, max: {loss.max().item():.6f}, mean: {loss.mean().item():.6f}, std: {loss.std().item() if loss.numel() > 1 else 0:.6f}") # Test 2: Generation with flow matching, testing throughput for different sequence lengths lengths_seconds = [10, 30, 60, 120, 180, 240] infer_steps = 2 # Can be increased as needed (e.g., 50/100) to better approximate real inference print("\n===== Throughput benchmark (25Hz hidden states) =====") for seconds in lengths_seconds: seq_len = seconds * frames_per_second # Reconstruct inputs for current sequence length cur_hidden_states = torch.randn(2, seq_len, 64, dtype=model_dtype, device=device) cur_attention_mask = torch.ones(2, seq_len, dtype=model_dtype, device=device) cur_chunk_mask = torch.ones(2, seq_len, 64, dtype=model_dtype, device=device) cur_silence_latent = torch.randn(2, seq_len, 64, dtype=model_dtype, device=device) cur_src_latents = torch.randn(2, seq_len, 64, dtype=model_dtype, device=device) print(f"\n--- {seconds}s input ({seq_len} frames @ {frames_per_second}Hz) ---") outputs = model.generate_audio( text_hidden_states=text_hidden_states, text_attention_mask=text_attention_mask, lyric_hidden_states=lyric_hidden_states, lyric_attention_mask=lyric_attention_mask, refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask=refer_audio_order_mask, src_latents=cur_src_latents, chunk_masks=cur_chunk_mask, silence_latent=cur_silence_latent, infer_steps=infer_steps, is_covers=is_covers, seed=1234, ) target_latents = outputs["target_latents"] time_costs = outputs.get("time_costs", {}) total_time = time_costs.get("total_time_cost", None) diffusion_time = time_costs.get("diffusion_time_cost", None) # Output shape and statistics print(f"Generated latents shape: {target_latents.shape}") print( f"Stats - min: {target_latents.min().item():.4f}, " f"max: {target_latents.max().item():.4f}, " f"mean: {target_latents.mean().item():.4f}, " f"std: {target_latents.std().item():.4f}" ) # Calculate throughput: statistics by frame count and audio seconds bsz, t_len = target_latents.shape[0], target_latents.shape[1] audio_seconds = t_len / frames_per_second if total_time is not None: frames_throughput = (bsz * t_len) / total_time seconds_throughput = (bsz * audio_seconds) / total_time print( f"Time costs: total={total_time:.4f}s, diffusion={diffusion_time:.4f}s " f"({infer_steps} steps)" if diffusion_time is not None else f"Time costs: total={total_time:.4f}s" ) print( f"Throughput (based on total_time): " f"{frames_throughput:.2f} frames/s, " f"{seconds_throughput:.2f} audio-seconds/s (batch={bsz})" ) else: print("Time costs not available in outputs['time_costs']; only basic stats printed.") if __name__ == "__main__": from torch.profiler import profile, record_function, ProfilerActivity import os, torch import time from transformers import AutoModel config = AceStepConfig() start = time.time() import os model_dir = os.path.dirname(os.path.abspath(__file__)) model = AceStepConditionGenerationModel.from_pretrained(model_dir) end = time.time() # model.config._attn_implementation = "sdpa" model.config._attn_implementation = "flash_attention_2" model.eval() # model = model.to("cpu") # model = model.float() model = model.to("cuda") model = model.bfloat16() test_forward(model)