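# Source page header (Hugging Face file view), kept for provenance:
#   acestep-v15-base / modeling_acestep_v15_base.py
#   Author avatar: ChuxiJ's picture
#   Last commit: "Update modeling_acestep_v15_base.py" (e432212, verified)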
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
from typing import Callable, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
# Transformers imports (sorted by submodule, then alphabetically)
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import auto_docstring, can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3MLP,
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from tqdm import tqdm
from vector_quantize_pytorch import ResidualFSQ
# Local config import with fallback
try:
from .configuration_acestep_v15 import AceStepConfig
from .apg_guidance import adg_forward, apg_forward, MomentumBuffer
except ImportError:
from configuration_acestep_v15 import AceStepConfig
from apg_guidance import adg_forward, apg_forward, MomentumBuffer
logger = logging.get_logger(__name__)
def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,  # [Batch, Seq_Len]
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    """
    Build an additive 4D attention mask that works without fused attention kernels.

    The mask combines up to three constraints:
      * causality (`is_causal=True`): query i may only see keys j with j <= i;
      * locality (`is_sliding_window=True` and `sliding_window` given): keys further than
        `sliding_window` positions away are hidden (only the past side when causal,
        both sides when bidirectional);
      * padding (`attention_mask` given): key columns whose mask entry is 0 are hidden.

    Args:
        seq_len: Length L of the (square) query/key grid.
        dtype: Floating dtype of the returned mask.
        device: Device on which the mask is built.
        attention_mask: Optional [Batch, Seq_Len] padding mask (non-zero = keep, 0 = padding).
        sliding_window: Window radius; ignored unless `is_sliding_window` is True.
        is_sliding_window: Enables the locality constraint.
        is_causal: Enables the causality constraint.

    Returns:
        Tensor of shape [Batch, 1, Seq_Len, Seq_Len] (Batch is 1 when no padding mask is given)
        holding 0.0 at visible positions and `torch.finfo(dtype).min` at hidden ones.
    """
    # offset[i, j] = i - j (query index minus key index)
    positions = torch.arange(seq_len, device=device)
    offset = positions[:, None] - positions[None, :]

    # Start with everything visible, then intersect with each active constraint.
    visible = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
    if is_causal:
        visible = visible & (offset >= 0)
    if is_sliding_window and sliding_window is not None:
        # In the causal case negative offsets are already excluded, so the signed
        # offset is enough; otherwise the window is symmetric around the query.
        distance = offset if is_causal else torch.abs(offset)
        visible = visible & (distance <= sliding_window)

    # [L, L] -> [1, 1, L, L] so it broadcasts over batch and heads
    visible = visible[None, None]

    if attention_mask is not None:
        # [B, L] -> [B, 1, 1, L]: hides padded keys (columns) for every query row
        key_keep = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
        visible = visible & key_keep

    # Hidden everywhere by default, then open up the visible positions.
    additive = torch.full(visible.shape, torch.finfo(dtype).min, dtype=dtype, device=device)
    additive.masked_fill_(visible, 0.0)
    return additive
def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
    """
    Join two padded sequences and move all valid tokens to the front.

    Both inputs are concatenated along the sequence axis, then every batch row is
    reordered so that positions with mask value 1 precede positions with mask value 0.
    The sort is stable, so the relative order of valid tokens (and of padding) is kept.

    Args:
        hidden1: First hidden states, shape [B, L1, D].
        hidden2: Second hidden states, shape [B, L2, D].
        mask1: Validity mask for `hidden1`, shape [B, L1].
        mask2: Validity mask for `hidden2`, shape [B, L2].

    Returns:
        (packed_hidden_states, new_mask):
        - packed_hidden_states: reordered hidden states, shape [B, L1+L2, D];
        - new_mask: boolean mask of shape [B, L1+L2] that is True for the leading
          `mask.sum()` positions of each row.
    """
    merged_hidden = torch.cat((hidden1, hidden2), dim=1)  # [B, L, D]
    merged_mask = torch.cat((mask1, mask2), dim=1)  # [B, L]
    batch, total_len, width = merged_hidden.shape

    # Stable descending sort on the mask puts the 1-entries first without shuffling them.
    order = torch.argsort(merged_mask, dim=1, descending=True, stable=True)  # [B, L]

    # Broadcast the per-row permutation over the feature axis and apply it.
    gather_index = order[:, :, None].expand(batch, total_len, width)
    packed_hidden = merged_hidden.gather(1, gather_index)

    # After packing, row b is valid exactly on its first `valid_counts[b]` slots.
    valid_counts = merged_mask.sum(dim=1)  # [B]
    slots = torch.arange(total_len, dtype=torch.long, device=merged_hidden.device)
    packed_mask = slots[None, :] < valid_counts[:, None]
    return packed_hidden, packed_mask
def sample_t_r(batch_size, device, dtype, data_proportion=0.0, timestep_mu=-0.4, timestep_sigma=1.0, use_meanflow=True):
    """
    Draw the (t, r) timestep pair used for flow-matching training.

    Two independent logit-normal samples (sigmoid of a Gaussian with mean `timestep_mu`
    and standard deviation `timestep_sigma`) are drawn per batch element; the larger one
    becomes `t` and the smaller one `r`. For the first `int(batch_size * data_proportion)`
    elements, `r` is then overwritten with `t`.

    Args:
        batch_size: Number of pairs to sample.
        device: Device on which the tensors are created.
        dtype: Floating dtype of the samples.
        data_proportion: Fraction (0.0 to 1.0) of the batch for which r is forced equal to t.
        timestep_mu: Mean of the underlying Gaussian.
        timestep_sigma: Standard deviation of the underlying Gaussian.
        use_meanflow: When False, `data_proportion` is treated as 1.0 (r == t everywhere).

    Returns:
        Tuple (t, r), each of shape [batch_size], with r <= t element-wise.
    """

    def _draw_logit_normal():
        noise = torch.randn((batch_size,), device=device, dtype=dtype)
        return torch.sigmoid(noise * timestep_sigma + timestep_mu)

    # Two draws, in this order, so the random stream is consumed exactly once per sample.
    first = _draw_logit_normal()
    second = _draw_logit_normal()
    t = torch.maximum(first, second)
    r = torch.minimum(first, second)

    # Without meanflow every sample collapses to r == t.
    effective_proportion = data_proportion if use_meanflow else 1.0
    num_collapsed = int(batch_size * effective_proportion)
    collapse = torch.arange(batch_size, device=device) < num_collapsed
    return t, torch.where(collapse, t, r)
class TimestepEmbedding(nn.Module):
    """
    Embed scalar diffusion timesteps.

    A timestep is first expanded into sinusoidal features of width `in_channels`,
    then passed through a two-layer SiLU MLP to give an embedding of width
    `time_embed_dim`. A further linear layer produces six modulation vectors of the
    same width from that embedding.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        scale: float = 1000,
    ):
        """
        Args:
            in_channels: Width of the sinusoidal feature vector fed to the MLP.
            time_embed_dim: Width of the timestep embedding produced by the MLP.
            scale: Multiplier applied to the raw timestep before the sinusoids.
        """
        super().__init__()
        # Two-layer MLP: in_channels -> time_embed_dim -> time_embed_dim
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act1 = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
        self.in_channels = in_channels
        # Projection head: one embedding -> six stacked vectors of width time_embed_dim
        self.act2 = nn.SiLU()
        self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
        self.scale = scale

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        Compute sinusoidal features for a batch of timesteps.

        Args:
            t: 1-D tensor with one (possibly fractional) timestep per batch element.
            dim: Width of the returned feature vectors.
            max_period: Controls the lowest frequency used.

        Returns:
            Float32 tensor of shape (N, dim): cosines first, then sines, zero-padded
            by one column when `dim` is odd.
        """
        scaled_t = t * self.scale
        half = dim // 2
        # Geometric frequency ladder from 1 down to (almost) 1 / max_period.
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / half).to(device=scaled_t.device)
        angles = scaled_t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # cos/sin halves only cover an even width; pad the last column with zeros.
            padding = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, padding], dim=-1)
        return embedding

    def forward(self, t):
        """
        Args:
            t: 1-D tensor of timesteps; its dtype is used for the MLP input.

        Returns:
            (temb, timestep_proj): `temb` has shape (N, time_embed_dim) and
            `timestep_proj` has shape (N, 6, time_embed_dim).
        """
        sinusoid = self.timestep_embedding(t, self.in_channels).to(t.dtype)
        temb = self.linear_2(self.act1(self.linear_1(sinusoid)))
        modulation = self.time_proj(self.act2(temb))
        return temb, modulation.unflatten(1, (6, -1))
class AceStepAttention(nn.Module):
    """
    Multi-headed attention module for the AceStep model.

    The module runs in one of two modes chosen at construction time:
      * self-attention: keys/values are projected from `hidden_states`, with optional
        rotary position embeddings (RoPE);
      * cross-attention: keys/values are projected from `encoder_hidden_states` and can be
        stored in / read back from the cross-attention part of an encoder-decoder cache.

    Queries and keys are RMS-normalized over the head dimension before attention. For layers
    whose configured type is "sliding_attention", `config.sliding_window` is forwarded to the
    attention backend (self-attention only).
    """

    def __init__(self, config: AceStepConfig, layer_idx: int, is_cross_attention: bool = False, is_causal: bool = False):
        """
        Args:
            config: Model configuration (hidden size, head counts, dropout, layer types, ...).
            layer_idx: Index of the owning layer; selects the layer type and the cache slot.
            is_cross_attention: Build a cross-attention module (forces `is_causal=False`).
            is_causal: Stored on the module as `self.is_causal`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        # Cross-attention is never causal, regardless of what the caller asked for.
        if is_cross_attention:
            is_causal = False
        self.is_causal = is_causal
        self.is_cross_attention = is_cross_attention

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
        # Apply RMS normalization only on the head dimension (unlike OLMo)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: Query-side input; the last dimension is the hidden size.
            attention_mask: Mask handed unchanged to the attention backend.
            past_key_value: Optional cache. In cross-attention mode its `is_updated` flags and
                `cross_attention_cache` are used; in self-attention mode `update` is called on it.
            cache_position: Forwarded to the cache in self-attention mode.
            encoder_hidden_states: Key/value source for cross-attention. When it is None the
                module falls back to the self-attention path even if built for cross-attention.
            position_embeddings: Optional `(cos, sin)` pair for RoPE (self-attention path only).
            output_attentions: With cross-attention, forces the eager backend.
            **kwargs: Extra keyword arguments forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)`; `attn_weights` is whatever the selected backend
            returns as its second value.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project and normalize query states: [..., heads, head_dim] -> heads moved to dim 1
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        # Cross-attention is only possible when encoder states are actually supplied.
        is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None

        if is_cross_attention:
            # Cross-attention path: keys/values come from the encoder hidden states.
            encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
            if past_key_value is not None:
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                # Encoder states do not change between steps, so K/V can be cached once per layer.
                curr_past_key_value = past_key_value.cross_attention_cache
                if not is_updated:
                    # First call for this layer: compute K/V and store them in the cache.
                    key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                    value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
                    key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
                    # Mark this layer's cross-attention cache as filled.
                    past_key_value.is_updated[self.layer_idx] = True
                else:
                    # Later calls: reuse the cached key/value states.
                    key_states = curr_past_key_value.layers[self.layer_idx].keys
                    value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                # No cache: compute K/V directly.
                key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
        else:
            # Self-attention path: keys/values come from the same sequence as the queries.
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

            # Bind cos/sin up front so the cache update below cannot hit an unbound name
            # when a cache is supplied without position embeddings.
            cos, sin = None, None
            if position_embeddings is not None:
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # Sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Backend selection: eager when configured, or when cross-attention weights are
        # requested; otherwise the implementation registered under the configured name.
        use_eager = (is_cross_attention and output_attentions) or self.config._attn_implementation == "eager"
        attention_interface: Callable = (
            eager_attention_forward if use_eager else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window if not self.is_cross_attention else None,
            **kwargs,
        )

        # Merge heads back into the hidden dimension and apply the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class AceStepEncoderLayer(GradientCheckpointingLayer):
    """
    Pre-norm transformer encoder layer for the AceStep model.

    Each call applies two residual sub-layers in sequence:
    RMSNorm -> non-causal self-attention, then RMSNorm -> MLP.
    """

    def __init__(self, config, layer_idx: int):
        """
        Args:
            config: Model configuration shared by the attention, norm and MLP sub-modules.
            layer_idx: Index of this layer; selects its entry in `config.layer_types`.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        self.layer_idx = layer_idx

        # Attention sub-layer (bidirectional self-attention, never cross-attention)
        self.self_attn = AceStepAttention(
            config=config,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Feed-forward sub-layer
        self.mlp = Qwen3MLP(config)

        # Read by the owning model to pick the matching attention mask for this layer.
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor,
        Optional[tuple[torch.FloatTensor, torch.FloatTensor]],
    ]:
        """
        Args:
            hidden_states: Layer input.
            position_embeddings: `(cos, sin)` pair passed to the attention module.
            attention_mask: Mask passed to the attention module.
            position_ids: Forwarded to the attention module as a keyword argument.
            output_attentions: When truthy, the attention weights are appended to the result.
            **kwargs: Extra keyword arguments forwarded to the attention module.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when
            `output_attentions` is truthy.
        """
        # Sub-layer 1: normalized self-attention added back onto the input.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=attn_input,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            # Encoders don't use cache
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Sub-layer 2: normalized MLP added back onto the attention result.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
class AceStepDiTLayer(GradientCheckpointingLayer):
    """
    DiT (Diffusion Transformer) layer for the AceStep model.

    The layer applies, in order:
      1. self-attention with adaptive layer norm (AdaLN) and a gated residual;
      2. optional cross-attention over encoder outputs with a plain residual;
      3. a feed-forward MLP with AdaLN and a gated residual.

    The shift/scale/gate values for steps 1 and 3 are the sum of a learned per-layer table
    and the timestep modulation passed in as `temb`.
    """

    def __init__(self, config: AceStepConfig, layer_idx: int, use_cross_attention: bool = True):
        """
        Args:
            config: Model configuration.
            layer_idx: Index of this layer; selects its entry in `config.layer_types`.
            use_cross_attention: Whether to build and run the cross-attention sub-layer.
        """
        super().__init__()
        # 1. Self-attention sub-layer with adaptive normalization
        self.self_attn_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = AceStepAttention(config=config, layer_idx=layer_idx)

        # 2. Cross-attention sub-layer (optional, for encoder conditioning)
        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.cross_attn_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.cross_attn = AceStepAttention(config=config, layer_idx=layer_idx, is_cross_attention=True)

        # 3. Feed-forward MLP sub-layer with adaptive normalization
        self.mlp_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = Qwen3MLP(config)

        # Learned (shift, scale, gate) rows: 3 for self-attention followed by 3 for the MLP
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[EncoderDecoderCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states: Layer input.
            position_embeddings: `(cos, sin)` pair for the self-attention module.
            temb: Timestep modulation added to `scale_shift_table` and split into six chunks
                along dim 1 — presumably [B, 6, hidden_size]; TODO confirm against the caller.
            attention_mask: Mask for self-attention.
            position_ids: Forwarded to the self-attention module as a keyword argument.
            past_key_value: Cache handed to the cross-attention module only; self-attention
                always runs without a cache.
            output_attentions: When truthy, attention weights are appended to the result.
            use_cache: Forwarded to the cross-attention module as a keyword argument.
            cache_position: Accepted for interface compatibility; not used in this method.
            encoder_hidden_states: Key/value source for cross-attention.
            encoder_attention_mask: Mask for cross-attention.
            **kwargs: Extra keyword arguments forwarded to both attention modules.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights, cross_attn_weights)` when
            `output_attentions` is truthy. `cross_attn_weights` is None when the layer was built
            without cross-attention.
        """
        # Six modulation values: (shift, scale, gate) for self-attention, then for the MLP
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table + temb
        ).chunk(6, dim=1)

        # Step 1: Self-attention with adaptive layer norm: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=norm_hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        # Gated residual connection: x = x + attn_output * gate
        hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)

        # Step 2: Cross-attention (if enabled) for conditioning on encoder outputs.
        # Default to None so the output tuple below is well-defined for layers without it.
        cross_attn_weights = None
        if self.use_cross_attention:
            norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
            attn_output, cross_attn_weights = self.cross_attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )
            # Standard (ungated) residual connection for cross-attention
            hidden_states = hidden_states + attn_output

        # Step 3: Feed-forward (MLP) with adaptive layer norm: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
        ff_output = self.mlp(norm_hidden_states)
        # Gated residual connection: x = x + mlp_output * gate
        hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)
        return outputs
@auto_docstring
class AceStepPreTrainedModel(PreTrainedModel):
    # Common base class for the AceStep sub-models defined in this file. It binds the
    # configuration class, declares class-level capability flags read by the Hugging Face
    # `PreTrainedModel` machinery, and provides the shared weight initialization.
    config_class = AceStepConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["AceStepEncoderLayer", "AceStepDiTLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Attention backends / cache features this model family declares support for
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Initialize the parameters of a single sub-module according to its type.

        - `nn.Linear`: weights ~ N(0, initializer_range), bias (if any) zeroed.
        - `nn.Embedding`: weights ~ N(0, initializer_range), padding row (if any) zeroed.
        - `Qwen3RMSNorm`: weight filled with 1.0.
        Any other module type is left untouched.

        TODO: Support separate initialization for encoders and decoders.
        """
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, Qwen3RMSNorm):
            module.weight.data.fill_(1.0)
class AceStepLyricEncoder(AceStepPreTrainedModel):
    """
    Encoder for processing lyric text embeddings.

    Takes pre-computed text embeddings (not token ids), projects them from
    `config.text_hidden_dim` to `config.hidden_size`, and runs them through a stack of
    bidirectional `AceStepEncoderLayer`s followed by a final RMSNorm.
    """

    def __init__(self, config):
        super().__init__(config)
        # Project text embeddings to model hidden size
        self.embed_tokens = nn.Linear(config.text_hidden_dim, config.hidden_size)
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Stack of encoder layers
        self.layers = nn.ModuleList(
            [AceStepEncoderLayer(config, layer_idx) for layer_idx in range(config.num_lyric_encoder_hidden_layers)]
        )
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutput:
        """
        Args:
            input_ids: Must be None; token ids are not supported by this encoder.
            attention_mask: Required [B, L] padding mask (1 = valid, 0 = padding).
            position_ids: Optional position ids; defaults to `0..L-1` shared by the batch.
            inputs_embeds: Required text embeddings of shape [B, L, text_hidden_dim].
            output_attentions: Collect per-layer attention weights; defaults to the config value.
            output_hidden_states: Collect per-layer hidden states; defaults to the config value.
            **flash_attn_kwargs: Extra keyword arguments forwarded to every layer.

        Returns:
            `BaseModelOutput` with the normalized `last_hidden_state` and, when requested,
            the tuples of hidden states and attention weights.

        Raises:
            AssertionError: If `input_ids` is given, or `attention_mask` / `inputs_embeds` is missing.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # The previous message claimed the opposite of what is checked; only embeddings are accepted.
        assert input_ids is None, "Only `inputs_embeds` is supported for the lyric encoder; `input_ids` must be None."
        assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
        assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."

        # Project input embeddings: N x T x text_hidden_dim -> N x T x hidden_size
        inputs_embeds = self.embed_tokens(inputs_embeds)

        # Cache position: only used to derive default position ids (no actual caching)
        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)

        # Positional IDs
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Attention masks
        seq_len = inputs_embeds.shape[1]
        dtype = inputs_embeds.dtype
        device = inputs_embeds.device

        # Check whether Flash Attention 2 is in use.
        # NOTE(review): other flash variants (e.g. "flash_attention_3") take the 4D-mask branch
        # below — confirm that is intended.
        is_flash_attn = (self.config._attn_implementation == "flash_attention_2")

        # Mask variables, filled in by one of the two branches below
        full_attn_mask = None
        sliding_attn_mask = None

        if is_flash_attn:
            # Case A: Flash Attention.
            # No 4D mask is built; the 2D padding mask [B, L] is passed through as-is.
            # The sliding-window restriction is left to the attention layer, which passes
            # its `sliding_window` value to the attention backend.
            full_attn_mask = attention_mask
            # Sliding layers also receive only the padding mask (when sliding windows are enabled).
            sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
        else:
            # Case B: any other backend (eager, SDPA, ...).
            # An explicit 4D mask [B, 1, L, L] is built here.
            # 1. Full attention (bidirectional, global)
            full_attn_mask = create_4d_mask(
                seq_len=seq_len,
                dtype=dtype,
                device=device,
                attention_mask=attention_mask,  # [B, L]
                sliding_window=None,
                is_sliding_window=False,
                is_causal=False,  # key point: bidirectional attention
            )
            # 2. Sliding attention (bidirectional, local)
            if self.config.use_sliding_window:
                sliding_attn_mask = create_4d_mask(
                    seq_len=seq_len,
                    dtype=dtype,
                    device=device,
                    attention_mask=attention_mask,  # [B, L]
                    sliding_window=self.config.sliding_window,
                    is_sliding_window=True,  # enable the sliding window
                    is_causal=False,  # key point: bidirectional attention
                )

        # Map each layer type to the mask it should receive
        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        # Initialize hidden states with input embeddings
        hidden_states = inputs_embeds

        # Create position embeddings to be shared across all layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Pass through transformer layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # NOTE(review): the stack is built with `num_lyric_encoder_hidden_layers` layers but is
        # sliced with `num_hidden_layers` here; layers beyond that count would be skipped —
        # confirm this is intended.
        for layer_module in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                position_embeddings,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids,
                output_attentions,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # Also record the final, normalized hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class AttentionPooler(AceStepPreTrainedModel):
    """
    Attention-based pooling module.

    For every (batch, time) position, a learned special token is prepended to that position's
    patches, the resulting short sequence is run through a stack of encoder layers, and the
    output at the special-token position is returned as the pooled representation.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embed_tokens = nn.Linear(config.hidden_size, config.hidden_size)
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Special token used for pooling (CLS-like token)
        self.special_token = nn.Parameter(torch.randn(1, 1, config.hidden_size) * 0.02)
        self.layers = nn.ModuleList(
            [AceStepEncoderLayer(config, layer_idx) for layer_idx in range(config.num_attention_pooler_hidden_layers)]
        )
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        x,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """
        Args:
            x: Patch features of shape [B, T, P, D] (batch, time, patches per step, channels).
            attention_mask: Optional padding mask for the per-step sequences — presumably
                [(B*T), P+1] to match the flattened batch; TODO confirm against callers.
            **flash_attn_kwargs: Extra keyword arguments forwarded to every layer.

        Returns:
            Pooled tensor of shape [B, T, C]: the special-token output for each time step.
        """
        # Unpacking four values also enforces that `x` is 4-dimensional.
        B, T, _, _ = x.shape
        x = self.embed_tokens(x)

        # Prepend the special token to every group of patches: [B, T, P, C] -> [B, T, P+1, C]
        special_tokens = self.special_token.expand(B, T, 1, -1)
        x = torch.cat([special_tokens, x], dim=2)
        # Each time step becomes an independent short sequence
        x = rearrange(x, "b t p c -> (b t) p c")

        # Cache position: only used to derive the position ids
        cache_position = torch.arange(0, x.shape[1], device=x.device)
        # Positional ids
        position_ids = cache_position.unsqueeze(0)

        hidden_states = x

        # Create position embeddings to be shared across the layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        seq_len = x.shape[1]
        dtype = x.dtype
        device = x.device

        # Check whether Flash Attention 2 is in use
        is_flash_attn = (self.config._attn_implementation == "flash_attention_2")

        # Mask variables, filled in by one of the two branches below
        full_attn_mask = None
        sliding_attn_mask = None

        if is_flash_attn:
            # Case A: Flash Attention.
            # No 4D mask is built; the padding mask (or None) is passed through as-is.
            # The sliding-window restriction is left to the attention layer, which passes
            # its `sliding_window` value to the attention backend.
            full_attn_mask = attention_mask
            # Sliding layers also receive only the padding mask (when sliding windows are enabled).
            sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
        else:
            # Case B: any other backend (eager, SDPA, ...).
            # An explicit 4D mask [N, 1, L, L] is built here.
            # 1. Full attention (bidirectional, global)
            full_attn_mask = create_4d_mask(
                seq_len=seq_len,
                dtype=dtype,
                device=device,
                attention_mask=attention_mask,
                sliding_window=None,
                is_sliding_window=False,
                is_causal=False,  # key point: bidirectional attention
            )
            # 2. Sliding attention (bidirectional, local)
            if self.config.use_sliding_window:
                sliding_attn_mask = create_4d_mask(
                    seq_len=seq_len,
                    dtype=dtype,
                    device=device,
                    attention_mask=attention_mask,
                    sliding_window=self.config.sliding_window,
                    is_sliding_window=True,  # enable the sliding window
                    is_causal=False,  # key point: bidirectional attention
                )

        # Map each layer type to the mask it should receive
        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        for layer_module in self.layers:
            layer_outputs = layer_module(
                hidden_states,
                position_embeddings,
                attention_mask=self_attn_mask_mapping[layer_module.attention_type],
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        # Extract the special token output (first position) as pooled representation
        cls_output = hidden_states[:, 0, :]
        cls_output = rearrange(cls_output, "(b t) c -> b t c", b=B)
        return cls_output
class AudioTokenDetokenizer(AceStepPreTrainedModel):
    """
    Detokenizer that turns quantized audio tokens back into acoustic features.

    Each incoming token is copied into ``config.pool_window_size`` patch slots,
    a learnable per-slot embedding is added, the slots that belong to one token
    are processed as a short sequence by the encoder layers, and the result is
    projected to ``config.audio_acoustic_hidden_dim``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        hidden = config.hidden_size
        self.embed_tokens = nn.Linear(hidden, hidden)
        self.norm = Qwen3RMSNorm(hidden, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # One learnable offset per patch slot, shared by every token
        self.special_tokens = nn.Parameter(torch.randn(1, config.pool_window_size, hidden) * 0.02)
        self.layers = nn.ModuleList(
            [
                AceStepEncoderLayer(config, layer_idx)
                for layer_idx in range(config.num_attention_pooler_hidden_layers)
            ]
        )
        # Maps the model width back to the acoustic feature width
        self.proj_out = nn.Linear(hidden, config.audio_acoustic_hidden_dim)
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        x,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutput:
        """
        Expand every token into ``pool_window_size`` patches and decode them.

        Args:
            x: 3-D tensor unpacked as (batch, tokens, channels).
            attention_mask: Optional mask. It is handed to the layers unchanged
                under flash-attention and to ``create_4d_mask`` otherwise.
            **flash_attn_kwargs: Extra keyword arguments forwarded to each layer.

        Returns:
            The projected hidden states rearranged to ``b (t p) c``. Note that a
            plain tensor is returned, not a ``BaseModelOutput``.
        """
        batch_size, num_tokens, _ = x.shape
        window = self.config.pool_window_size

        tokens = self.embed_tokens(x)
        # [B, T, D] -> [B, T, P, D]: copy each token into P slots, then add the
        # learnable per-slot embedding
        patches = tokens.unsqueeze(2).repeat(1, 1, window, 1)
        patches = patches + self.special_tokens.expand(batch_size, num_tokens, -1, -1)
        # Fold the token axis into the batch so each token's slots form one sequence
        hidden_states = rearrange(patches, "b t p c -> (b t) p c")

        seq_len = hidden_states.shape[1]
        dtype = hidden_states.dtype
        device = hidden_states.device

        # Positions 0..P-1, shared by every sequence
        position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
        # Rotary embeddings are computed once and reused by all layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        if self.config._attn_implementation == "flash_attention_2":
            # Flash-attention path: no 4D mask is built; the (possibly None)
            # padding mask is passed through and the sliding window is left to
            # the layer.
            full_mask = attention_mask
            sliding_mask = attention_mask if self.config.use_sliding_window else None
        else:
            # Every other backend: build bidirectional 4D masks explicitly.
            shared = dict(
                seq_len=seq_len,
                dtype=dtype,
                device=device,
                attention_mask=attention_mask,
                is_causal=False,
            )
            full_mask = create_4d_mask(sliding_window=None, is_sliding_window=False, **shared)
            sliding_mask = None
            if self.config.use_sliding_window:
                sliding_mask = create_4d_mask(
                    sliding_window=self.config.sliding_window,
                    is_sliding_window=True,
                    **shared,
                )

        # Each layer selects its mask through its ``attention_type``
        masks_by_type = {
            "full_attention": full_mask,
            "sliding_attention": sliding_mask,
        }

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings,
                attention_mask=masks_by_type[layer.attention_type],
                **flash_attn_kwargs,
            )[0]

        hidden_states = self.proj_out(self.norm(hidden_states))
        # Unfold the batch and lay the patches of consecutive tokens end to end
        return rearrange(hidden_states, "(b t) p c -> b (t p) c", b=batch_size, p=window)
class AceStepTimbreEncoder(AceStepPreTrainedModel):
    """
    Encoder that produces one timbre embedding per packed reference-audio clip.

    Packed reference features are projected to the model width, passed through
    a stack of encoder layers, and the hidden state at sequence position 0 is
    taken as the clip embedding. The per-clip embeddings are then scattered
    back into a padded ``[B, max_count, d]`` batch layout.

    NOTE(review): ``special_token`` is registered as a parameter but is not
    prepended to the sequence in ``forward``, so position 0 is the first frame
    of the projected input rather than a CLS-like token. Confirm this matches
    how the checkpoint was trained before re-enabling the prepend.
    """

    def __init__(self, config):
        super().__init__(config)
        # Project acoustic features to model hidden size
        self.embed_tokens = nn.Linear(config.timbre_hidden_dim, config.hidden_size)
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Learnable aggregation token (currently unused by forward, see class note)
        self.special_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.layers = nn.ModuleList(
            [AceStepEncoderLayer(config, layer_idx) for layer_idx in range(config.num_timbre_encoder_hidden_layers)]
        )
        # Initialize weights and apply final processing
        self.post_init()

    def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
        """
        Unpack packed timbre embeddings into batch format.

        Args:
            timbre_embs_packed: Packed timbre embeddings of shape [N, d].
            refer_audio_order_mask: Integer tensor of shape [N] giving, for each
                packed embedding, the index of the batch item it belongs to.

        Returns:
            Tuple of (unpacked_embeddings, mask):
                - unpacked_embeddings: Embeddings of shape [B, max_count, d],
                  zero where a batch item has fewer than ``max_count`` clips.
                - new_mask: Long tensor of shape [B, max_count], 1 at filled slots.
        """
        N, d = timbre_embs_packed.shape
        device = timbre_embs_packed.device
        dtype = timbre_embs_packed.dtype

        # Batch size is inferred from the largest batch index present
        B = int(refer_audio_order_mask.max().item() + 1)

        # Number of clips per batch item and the widest item
        counts = torch.bincount(refer_audio_order_mask, minlength=B)
        max_count = counts.max().item()

        # Order the packed entries by batch index, ties broken by original position
        sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
        sorted_batch_ids = refer_audio_order_mask[sorted_indices]

        # Rank of each entry inside its own batch item, in sorted order
        positions = torch.arange(N, device=device)
        batch_starts = torch.cat([torch.tensor([0], device=device),
                                  torch.cumsum(counts, dim=0)[:-1]])
        positions_in_sorted = positions - batch_starts[sorted_batch_ids]

        # Map the ranks back to the original (packed) order
        inverse_indices = torch.empty_like(sorted_indices)
        inverse_indices[sorted_indices] = torch.arange(N, device=device)
        positions_in_batch = positions_in_sorted[inverse_indices]

        # Scatter through a one-hot matmul so the operation stays differentiable
        indices_2d = refer_audio_order_mask * max_count + positions_in_batch  # (N,)
        one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype)  # (N, B*max_count)

        timbre_embs_flat = one_hot.t() @ timbre_embs_packed  # (B*max_count, d)
        timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)

        # A slot is valid when some packed entry was routed to it
        mask_flat = (one_hot.sum(dim=0) > 0).long()  # (B*max_count,)
        new_mask = mask_flat.reshape(B, max_count)
        return timbre_embs_unpack, new_mask

    @can_return_tuple
    def forward(
        self,
        refer_audio_acoustic_hidden_states_packed: Optional[torch.FloatTensor] = None,
        refer_audio_order_mask: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode packed reference audio and unpack the result per batch item.

        Args:
            refer_audio_acoustic_hidden_states_packed: Packed reference features;
                projected by ``embed_tokens`` from ``timbre_hidden_dim`` to
                ``hidden_size``.
            refer_audio_order_mask: Batch index of every packed clip, forwarded
                to ``unpack_timbre_embeddings``.
            attention_mask: Optional mask. It is handed to the layers unchanged
                under flash-attention and to ``create_4d_mask`` otherwise.
            **flash_attn_kwargs: Extra keyword arguments forwarded to each layer.

        Returns:
            ``(timbre_embs_unpack, timbre_embs_mask)`` as produced by
            ``unpack_timbre_embeddings``.
        """
        inputs_embeds = refer_audio_acoustic_hidden_states_packed
        # Project embeddings: N x T x timbre_hidden_dim -> N x T x hidden_size
        inputs_embeds = self.embed_tokens(inputs_embeds)

        # Cache position: only used to derive position ids (no caching happens here)
        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)

        seq_len = inputs_embeds.shape[1]
        dtype = inputs_embeds.dtype
        device = inputs_embeds.device

        is_flash_attn = (self.config._attn_implementation == "flash_attention_2")

        full_attn_mask = None
        sliding_attn_mask = None
        if is_flash_attn:
            # Flash-attention path: no 4D mask is built. The (possibly None)
            # padding mask is passed through as-is and the sliding window is
            # left to the layer.
            full_attn_mask = attention_mask
            sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
        else:
            # Every other backend: build 4D masks explicitly.
            # 1. Full attention (bidirectional, global)
            full_attn_mask = create_4d_mask(
                seq_len=seq_len,
                dtype=dtype,
                device=device,
                attention_mask=attention_mask,
                sliding_window=None,
                is_sliding_window=False,
                is_causal=False,
            )
            # 2. Sliding attention (bidirectional, local)
            if self.config.use_sliding_window:
                sliding_attn_mask = create_4d_mask(
                    seq_len=seq_len,
                    dtype=dtype,
                    device=device,
                    attention_mask=attention_mask,
                    sliding_window=self.config.sliding_window,
                    is_sliding_window=True,
                    is_causal=False,
                )

        # Each layer selects its mask through its ``attention_type``
        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        hidden_states = inputs_embeds
        # Rotary embeddings are computed once and shared across all layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Run every layer this encoder owns. The stack is sized by
        # ``num_timbre_encoder_hidden_layers``; it must not be truncated by the
        # unrelated ``num_hidden_layers`` setting.
        for layer_module in self.layers:
            layer_outputs = layer_module(
                hidden_states,
                position_embeddings,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids,
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        # Take sequence position 0 as the clip embedding: N x T x D -> N x D
        hidden_states = hidden_states[:, 0, :]
        # Unpack packed embeddings back to batch format
        timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
        return timbre_embs_unpack, timbre_embs_mask
class AceStepAudioTokenizer(AceStepPreTrainedModel):
    """
    Tokenizer that maps continuous acoustic features to quantized tokens.

    The pipeline is: linear projection to the model width, attention pooling
    over each patch window, then residual FSQ quantization.
    """

    def __init__(self, config):
        super().__init__(config)
        # Acoustic feature width -> model hidden size
        self.audio_acoustic_proj = nn.Linear(config.audio_acoustic_hidden_dim, config.hidden_size)
        # Collapses every patch window into a single vector
        self.attention_pooler = AttentionPooler(config)
        # Residual finite-scalar quantizer applied to the pooled vectors
        self.quantizer = ResidualFSQ(
            dim=config.fsq_dim,
            levels=config.fsq_input_levels,
            num_quantizers=config.fsq_input_num_quantizers,
        )
        self.pool_window_size = config.pool_window_size
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutput:
        """
        Project, pool and quantize already-patched acoustic features.

        Args:
            hidden_states: Acoustic features, already grouped into patch
                windows (``tokenize`` performs that grouping).
            **flash_attn_kwargs: Accepted for interface compatibility; not
                forwarded to any submodule here.

        Returns:
            ``(quantized, indices)`` as returned by the quantizer. Note that a
            tuple is returned, not a ``BaseModelOutput``.
        """
        projected = self.audio_acoustic_proj(hidden_states)
        # One pooled vector per patch window
        pooled = self.attention_pooler(projected)
        codes, code_indices = self.quantizer(pooled)
        return codes, code_indices

    def tokenize(self, x):
        """
        Group a flat sequence into patch windows and quantize it.

        Args:
            x: Tensor laid out as ``n (t_patch p) d`` whose sequence length is
                a multiple of ``pool_window_size``.

        Returns:
            ``(quantized, indices)`` from ``forward``.
        """
        windows = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=self.pool_window_size)
        return self.forward(windows)
class Lambda(nn.Module):
    """
    Adapter that exposes a plain callable as an ``nn.Module``.

    This lets stateless transformations (for example a transpose) sit inside
    an ``nn.Sequential`` next to regular layers.
    """

    def __init__(self, func):
        super().__init__()
        # The wrapped callable; it is invoked with the module's single input
        self.func = func

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        transform = self.func
        return transform(x)
class AceStepDiTModel(AceStepPreTrainedModel):
    """
    DiT (Diffusion Transformer) model for AceStep.

    The noisy latents are concatenated with context latents along the channel
    axis, patchified by a strided 1D convolution, processed by a stack of
    ``AceStepDiTLayer`` blocks conditioned on timestep embeddings and on
    encoder hidden states, then normalized with a timestep-dependent
    scale/shift and de-patchified by a transposed convolution.
    """

    def __init__(self, config: AceStepConfig):
        super().__init__(config)
        # Rotary position embeddings for transformer layers
        self.rotary_emb = Qwen3RotaryEmbedding(config)
        # Stack of DiT transformer layers
        self.layers = nn.ModuleList(
            [AceStepDiTLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        in_channels = config.in_channels
        inner_dim = config.hidden_size
        patch_size = config.patch_size
        self.patch_size = patch_size

        # Input projection: patch embedding using a strided 1D convolution
        self.proj_in = nn.Sequential(
            Lambda(lambda x: x.transpose(1, 2)),  # [B, T, C] -> [B, C, T]
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=inner_dim,
                kernel_size=patch_size,
                stride=patch_size,
                padding=0,
            ),
            Lambda(lambda x: x.transpose(1, 2)),  # [B, C, T//patch_size] -> [B, T//patch_size, C]
        )

        # Timestep embeddings: one for t, one for the difference (t - r)
        self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
        self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)

        # Project encoder hidden states to model dimension
        self.condition_embedder = nn.Linear(inner_dim, inner_dim, bias=True)

        # Output normalization and de-patchifying projection
        self.norm_out = Qwen3RMSNorm(inner_dim, eps=config.rms_norm_eps)
        self.proj_out = nn.Sequential(
            Lambda(lambda x: x.transpose(1, 2)),  # [B, T//patch_size, inner_dim] -> [B, inner_dim, T//patch_size]
            nn.ConvTranspose1d(
                in_channels=inner_dim,
                out_channels=config.audio_acoustic_hidden_dim,
                kernel_size=patch_size,
                stride=patch_size,
                padding=0,
            ),
            Lambda(lambda x: x.transpose(1, 2)),  # [B, out_channels, T] -> [B, T, out_channels]
        )
        # Scale-shift table for adaptive output normalization (2 rows: shift, scale)
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        timestep_r: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        context_latents: torch.Tensor,
        use_cache: Optional[bool] = None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        return_hidden_states: Optional[bool] = None,
        custom_layers_config: Optional[dict] = None,
        enable_early_exit: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ):
        """
        Run the diffusion transformer on noisy latents.

        Args:
            hidden_states: Noisy latents; concatenated with ``context_latents``
                on the last axis before patchification.
            timestep: Timestep ``t`` fed to ``time_embed``.
            timestep_r: Timestep ``r``; ``t - r`` is fed to ``time_embed_r``.
            attention_mask: Accepted for interface compatibility but currently
                ignored (see the review note in the body).
            encoder_hidden_states: Conditioning sequence, projected by
                ``condition_embedder`` and passed to every layer.
            encoder_attention_mask: Accepted for interface compatibility but
                currently ignored (see the review note in the body).
            context_latents: Context features concatenated in front of
                ``hidden_states`` on the channel axis.
            use_cache: Whether to use a key/value cache; defaults to
                ``config.use_cache`` and is forced off in training mode.
            past_key_values: Existing cache; a fresh ``EncoderDecoderCache`` is
                created when caching is on and none is given.
            cache_position: Positions of the current tokens; derived from the
                cache length when omitted.
            position_ids: Position ids for the rotary embedding; derived from
                ``cache_position`` when omitted.
            output_attentions: When true, cross-attention weights are collected
                and appended to the returned tuple.
            return_hidden_states: When truthy, return the hidden states right
                after the layer stack, skipping output norm and projection.
            custom_layers_config: Optional dict keyed by layer index. Only its
                largest key is read here, as the last layer that is needed.
            enable_early_exit: With ``custom_layers_config`` given, stop after
                the last needed layer and force ``output_attentions`` on.
            **flash_attn_kwargs: Extra keyword arguments forwarded to each layer.

        Returns:
            The bare hidden-state tensor when ``return_hidden_states`` is
            truthy; otherwise ``(hidden_states, past_key_values)`` followed by
            the tuple of cross-attention weights when ``output_attentions``.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Disable cache during training or when gradient checkpointing is enabled
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if self.training:
            use_cache = False

        # Initialize cache if needed (inference only)
        if not self.training and use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

        # Timestep conditioning: embeddings of t and of (t - r) are summed
        temb_t, timestep_proj_t = self.time_embed(timestep)
        temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r)
        temb = temb_t + temb_r
        timestep_proj = timestep_proj_t + timestep_proj_r

        # Concatenate context latents with hidden states on the channel axis
        hidden_states = torch.cat([context_latents, hidden_states], dim=-1)

        # Remember the length so the padding added below can be cropped again
        original_seq_len = hidden_states.shape[1]

        # Right-pad with zeros so the length is a multiple of patch_size
        pad_length = 0
        if hidden_states.shape[1] % self.patch_size != 0:
            pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode='constant', value=0)

        # Project input to patches and project encoder states
        hidden_states = self.proj_in(hidden_states)
        encoder_hidden_states = self.condition_embedder(encoder_hidden_states)

        # Cache positions
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        # Position IDs
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        seq_len = hidden_states.shape[1]
        encoder_seq_len = encoder_hidden_states.shape[1]
        dtype = hidden_states.dtype
        device = hidden_states.device

        is_flash_attn = (self.config._attn_implementation == "flash_attention_2")

        full_attn_mask = None
        sliding_attn_mask = None
        # NOTE(review): both incoming masks are discarded here, so every mask
        # below is built as if no padding mask had been supplied. This is kept
        # unchanged because honouring the masks would alter model outputs;
        # confirm that ignoring them is intentional.
        encoder_attention_mask = None
        attention_mask = None

        if is_flash_attn:
            # Flash-attention path: no 4D mask is built; the sliding window is
            # left to the layer.
            full_attn_mask = attention_mask
            sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
        else:
            # Every other backend: build 4D masks explicitly.
            # 1. Full attention (bidirectional, global)
            full_attn_mask = create_4d_mask(
                seq_len=seq_len,
                dtype=dtype,
                device=device,
                attention_mask=attention_mask,
                sliding_window=None,
                is_sliding_window=False,
                is_causal=False,
            )
            # Cross-attention mask: build a square mask large enough for both
            # sequences, then crop it to [query length, encoder length]
            max_len = max(seq_len, encoder_seq_len)
            encoder_attention_mask = create_4d_mask(
                seq_len=max_len,
                dtype=dtype,
                device=device,
                attention_mask=attention_mask,
                sliding_window=None,
                is_sliding_window=False,
                is_causal=False,
            )
            encoder_attention_mask = encoder_attention_mask[:, :, :seq_len, :encoder_seq_len]
            # 2. Sliding attention (bidirectional, local)
            if self.config.use_sliding_window:
                sliding_attn_mask = create_4d_mask(
                    seq_len=seq_len,
                    dtype=dtype,
                    device=device,
                    attention_mask=attention_mask,
                    sliding_window=self.config.sliding_window,
                    is_sliding_window=True,
                    is_causal=False,
                )

        # Each layer selects its self-attention mask through its ``attention_type``
        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
            "encoder_attention_mask": encoder_attention_mask,
        }

        # Rotary embeddings are computed once and shared across all layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        all_cross_attentions = () if output_attentions else None

        # Early exit: the largest key of custom_layers_config is the last layer needed
        max_needed_layer = float('inf')
        if custom_layers_config is not None and enable_early_exit:
            max_needed_layer = max(custom_layers_config.keys())
            # Attention weights are required when exiting early for attention extraction
            output_attentions = True
            if all_cross_attentions is None:
                all_cross_attentions = ()

        # Process through transformer layers
        for index_block, layer_module in enumerate(self.layers):
            layer_outputs = layer_module(
                hidden_states,
                position_embeddings,
                timestep_proj,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                encoder_hidden_states,
                self_attn_mask_mapping["encoder_attention_mask"],
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

            if output_attentions and layer_module.use_cross_attention:
                # layer_outputs structure: (hidden_states, self_attn_weights, cross_attn_weights)
                if len(layer_outputs) >= 3:
                    all_cross_attentions += (layer_outputs[2],)

            # Stop once the last requested layer has run; max_needed_layer is
            # infinite unless early exit was requested, so this never fires otherwise.
            if index_block >= max_needed_layer:
                break

        if return_hidden_states:
            return hidden_states

        # Adaptive output norm: norm(x) * (1 + scale) + shift, with shift/scale
        # taken from the learned table offset by the timestep embedding
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)
        hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)

        # De-patchify back to the original sequence format
        hidden_states = self.proj_out(hidden_states)

        # Crop away the padding that was added before patchification
        hidden_states = hidden_states[:, :original_seq_len, :]

        outputs = (hidden_states, past_key_values)
        if output_attentions:
            outputs += (all_cross_attentions,)
        return outputs
class AceStepConditionEncoder(AceStepPreTrainedModel):
    """
    Builds the conditioning sequence consumed by the diffusion model.

    Text embeddings are linearly projected, lyrics and reference-audio timbre
    are each run through their own encoder, and the three results are packed
    into one sequence with a matching attention mask.
    """

    def __init__(self, config: AceStepConfig):
        super().__init__(config)
        self.config = config
        # Bias-free projection of text embeddings to the model width
        self.text_projector = nn.Linear(config.text_hidden_dim, config.hidden_size, bias=False)
        # Lyric branch
        self.lyric_encoder = AceStepLyricEncoder(config)
        # Timbre branch (reference audio)
        self.timbre_encoder = AceStepTimbreEncoder(config)

    def forward(
        self,
        # Text inputs
        text_hidden_states: Optional[torch.FloatTensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        # Lyric inputs
        lyric_hidden_states: Optional[torch.LongTensor] = None,
        lyric_attention_mask: Optional[torch.Tensor] = None,
        # Reference audio for timbre
        refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None,
        refer_audio_order_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Encode text, lyrics and timbre and pack them into one sequence.

        Args:
            text_hidden_states: Text embeddings, projected by ``text_projector``.
            text_attention_mask: Mask for the text sequence.
            lyric_hidden_states: Lyric inputs, passed to the lyric encoder as
                ``inputs_embeds``.
            lyric_attention_mask: Mask for the lyric sequence.
            refer_audio_acoustic_hidden_states_packed: Packed reference-audio
                features for the timbre encoder.
            refer_audio_order_mask: Batch index of each packed reference clip.

        Returns:
            ``(encoder_hidden_states, encoder_attention_mask)`` after packing
            lyrics with timbre and then with text.
        """
        text_states = self.text_projector(text_hidden_states)

        lyric_states = self.lyric_encoder(
            inputs_embeds=lyric_hidden_states,
            attention_mask=lyric_attention_mask,
        ).last_hidden_state

        timbre_states, timbre_mask = self.timbre_encoder(
            refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask,
        )

        # First merge lyrics with timbre ...
        packed_states, packed_mask = pack_sequences(
            lyric_states, timbre_states, lyric_attention_mask, timbre_mask
        )
        # ... then merge that result with the projected text
        packed_states, packed_mask = pack_sequences(
            packed_states, text_states, packed_mask, text_attention_mask
        )
        return packed_states, packed_mask
class AceStepConditionGenerationModel(AceStepPreTrainedModel):
"""
Main conditional generation model for AceStep.
End-to-end model for generating audio conditioned on text, lyrics, and timbre.
Combines encoder (for conditioning), decoder (diffusion model), tokenizer
(for discrete tokenization), and detokenizer (for reconstruction).
Supports flow matching training and inference with various sampling methods.
"""
def __init__(self, config: AceStepConfig):
    """
    Build the generation model from ``config``.

    Creates the diffusion decoder, the condition encoder, the audio
    tokenizer/detokenizer pair and the learnable null-condition embedding,
    then runs ``post_init``.
    """
    super().__init__(config)
    self.config = config
    # Diffusion model components
    self.decoder = AceStepDiTModel(config)  # Main diffusion transformer
    self.encoder = AceStepConditionEncoder(config)  # Condition encoder
    self.tokenizer = AceStepAudioTokenizer(config)  # Audio tokenizer
    self.detokenizer = AudioTokenDetokenizer(config)  # Audio detokenizer
    # Null condition embedding for classifier-free guidance; shape
    # (1, 1, hidden_size) so it can be expanded over batch and sequence
    self.null_condition_emb = nn.Parameter(torch.randn(1, 1, config.hidden_size))
    # Initialize weights and apply final processing
    self.post_init()
def tokenize(self, x, silence_latent, attention_mask):
    """
    Pad, patchify and quantize a latent sequence.

    Args:
        x: Latents laid out as ``n t d``.
        silence_latent: Latents used as filler; the first ``pad_len`` frames
            of its first item are appended when the length of ``x`` is not a
            multiple of ``config.pool_window_size``.
        attention_mask: Per-frame mask for ``x``; zero-padded alongside ``x``
            and max-pooled down to one value per patch window.

    Returns:
        ``(quantized, indices, attention_mask)`` where the first two come from
        the tokenizer and the mask has been pooled to token resolution.
    """
    window = self.config.pool_window_size
    remainder = x.shape[1] % window
    if remainder != 0:
        pad_len = window - remainder
        # Fill the tail with silence and mark the filler as masked out
        filler = silence_latent[:1, :pad_len].repeat(x.shape[0], 1, 1)
        x = torch.cat([x, filler], dim=1)
        attention_mask = F.pad(attention_mask, (0, pad_len), mode='constant', value=0)

    x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=window)
    num_windows = x.shape[1]

    # Frames per window, used as the pooling kernel for the mask
    chunk = math.ceil(attention_mask.shape[1] / num_windows)
    mask_as_float = attention_mask.to(x.dtype).unsqueeze(1)
    # A window counts as valid when any of its frames is valid
    attention_mask = F.max_pool1d(mask_as_float, kernel_size=chunk, stride=chunk, ceil_mode=True).squeeze(1)

    quantized, indices = self.tokenizer(x)
    return quantized, indices, attention_mask
def detokenize(self, quantized):
    """
    Map quantized audio tokens back to continuous representations.

    Args:
        quantized: Quantized tokens of shape [N, T//pool_window_size, d]

    Returns:
        Detokenized hidden states of shape [N, T, d], as produced by
        ``self.detokenizer``.
    """
    return self.detokenizer(quantized)
@torch.no_grad()
def prepare_condition(
    self,
    text_hidden_states: torch.FloatTensor,
    text_attention_mask: torch.Tensor,
    lyric_hidden_states: torch.FloatTensor,
    lyric_attention_mask: torch.Tensor,
    refer_audio_acoustic_hidden_states_packed: torch.FloatTensor,
    refer_audio_order_mask: torch.Tensor,
    hidden_states: torch.FloatTensor,
    attention_mask: torch.Tensor,
    silence_latent: torch.FloatTensor,
    src_latents: torch.FloatTensor,
    chunk_masks: torch.Tensor,
    is_covers: torch.Tensor,
    precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
    audio_codes: torch.FloatTensor = None,
):
    """
    Build the encoder conditioning and the context latents for the decoder.

    Runs under ``torch.no_grad``. LM hints are obtained from the first source
    that is available, in this order: ``precomputed_lm_hints_25Hz``, then
    ``audio_codes`` (decoded by the tokenizer's quantizer and detokenized),
    then tokenizing and detokenizing ``hidden_states``. For samples flagged in
    ``is_covers`` the hints replace ``src_latents``.

    Args:
        text_hidden_states, text_attention_mask: Text inputs for the encoder.
        lyric_hidden_states, lyric_attention_mask: Lyric inputs for the encoder.
        refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask:
            Reference-audio inputs for the encoder.
        hidden_states: Latents that are tokenized when no other hint source is
            given; also determines the dtype of the chunk masks.
        attention_mask: Frame mask passed to ``tokenize``.
        silence_latent: Filler latents passed to ``tokenize``.
        src_latents: Source latents; also fixes the length hints are cropped to.
        chunk_masks: Masks concatenated to the source latents on the last axis.
        is_covers: Per-sample flag; values greater than 0 select the LM hints.
        precomputed_lm_hints_25Hz: Optional ready-made hints.
        audio_codes: Optional code indices understood by
            ``self.tokenizer.quantizer.get_output_from_indices``.

    Returns:
        ``(encoder_hidden_states, encoder_attention_mask, context_latents)``.
    """
    dtype = hidden_states.dtype
    encoder_hidden_states, encoder_attention_mask = self.encoder(
        text_hidden_states=text_hidden_states,
        text_attention_mask=text_attention_mask,
        lyric_hidden_states=lyric_hidden_states,
        lyric_attention_mask=lyric_attention_mask,
        refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
        refer_audio_order_mask=refer_audio_order_mask,
    )

    # Use precomputed hints if provided, otherwise derive them from audio
    # codes or by tokenizing hidden_states
    if precomputed_lm_hints_25Hz is not None:
        print("Using precomputed LM hints")
        lm_hints_25Hz = precomputed_lm_hints_25Hz[:, :src_latents.shape[1], :]
    else:
        if audio_codes is not None:
            # The quantizer lives on the tokenizer module (self.tokenizer),
            # not on the tokenize() method.
            lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes)
        else:
            lm_hints_5Hz, _, _ = self.tokenize(hidden_states, silence_latent, attention_mask)
        lm_hints_25Hz = self.detokenize(lm_hints_5Hz)
        # Crop to the src_latents length (tokenize may have added padding)
        lm_hints_25Hz = lm_hints_25Hz[:, :src_latents.shape[1], :]

    # Cover samples are conditioned on the LM hints instead of the source latents
    src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints_25Hz, src_latents)
    # Concatenate source latents with chunk masks as context
    context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
    return encoder_hidden_states, encoder_attention_mask, context_latents
def forward(
    self,
    # Diffusion inputs
    hidden_states: torch.FloatTensor,
    attention_mask: torch.Tensor,
    # Encoder inputs
    # Text
    text_hidden_states: Optional[torch.FloatTensor] = None,
    text_attention_mask: Optional[torch.Tensor] = None,
    # Lyric
    lyric_hidden_states: Optional[torch.LongTensor] = None,
    lyric_attention_mask: Optional[torch.Tensor] = None,
    # Reference audio for timbre
    refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None,
    refer_audio_order_mask: Optional[torch.LongTensor] = None,
    src_latents: torch.FloatTensor = None,
    chunk_masks: torch.FloatTensor = None,
    is_covers: torch.Tensor = None,
    silence_latent: torch.FloatTensor = None,
    cfg_ratio: float = 0.15,
):
    """
    Training forward pass: compute the flow-matching loss.

    The conditioning is prepared, each sample's condition is replaced by the
    null embedding with probability ``cfg_ratio``, the clean latents are mixed
    with Gaussian noise at a sampled timestep, and the decoder's prediction is
    regressed onto ``noise - data`` with a mean-squared error.

    NOTE(review): ``prepare_condition`` is decorated with ``torch.no_grad``,
    so no gradient from this loss flows through the conditioning path.

    Returns:
        A dict with the single key ``"diffusion_loss"``.
    """
    # Conditioning: encoder states plus context latents (note that the
    # source latents are what gets tokenized for the LM hints)
    encoder_hidden_states, encoder_attention_mask, context_latents = self.prepare_condition(
        text_hidden_states=text_hidden_states,
        text_attention_mask=text_attention_mask,
        lyric_hidden_states=lyric_hidden_states,
        lyric_attention_mask=lyric_attention_mask,
        refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
        refer_audio_order_mask=refer_audio_order_mask,
        hidden_states=src_latents,
        attention_mask=attention_mask,
        silence_latent=silence_latent,
        src_latents=src_latents,
        chunk_masks=chunk_masks,
        is_covers=is_covers,
    )

    bsz = hidden_states.shape[0]
    device = hidden_states.device
    dtype = hidden_states.dtype

    # Classifier-free guidance: drop a sample's condition with probability cfg_ratio
    drop_condition = torch.rand(size=(bsz,), device=device, dtype=dtype) < cfg_ratio
    keep_condition = (~drop_condition).view(-1, 1, 1)
    null_condition = self.null_condition_emb.expand_as(encoder_hidden_states)
    encoder_hidden_states = torch.where(keep_condition, encoder_hidden_states, null_condition)

    # Flow matching endpoints: noise (x1) and data (x0)
    noise = torch.randn_like(hidden_states)
    data = hidden_states

    # Only t is used below; the second value returned by the sampler is unused
    t, _ = sample_t_r(bsz, device, dtype, self.config.data_proportion, self.config.timestep_mu, self.config.timestep_sigma, use_meanflow=False)
    t_col = t.unsqueeze(-1).unsqueeze(-1)

    # x_t = t * x1 + (1 - t) * x0
    noisy = t_col * noise + (1.0 - t_col) * data

    # The decoder receives t for both timestep arguments
    decoder_outputs = self.decoder(
        hidden_states=noisy,
        timestep=t,
        timestep_r=t,
        attention_mask=attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        context_latents=context_latents,
    )

    # Regression target is the flow field v = x1 - x0
    target_flow = noise - data
    return {
        "diffusion_loss": F.mse_loss(decoder_outputs[0], target_flow),
    }
def training_losses(self, **kwargs):
    """Compute the training losses by delegating all keyword arguments to ``forward``."""
    losses = self.forward(**kwargs)
    return losses
def prepare_noise(self, context_latents: torch.FloatTensor, seed: Union[int, List[int], None] = None):
    """
    Draw the initial Gaussian noise for generation, optionally seeded.

    The noise has the batch and sequence dimensions of ``context_latents``
    and half of its last dimension.

    Args:
        context_latents: Context latents that determine shape, device and dtype.
        seed: ``None`` for unseeded noise, an int for one generator shared by
            the whole batch, or a list with one entry per sample. In a list,
            entries that are ``None`` or negative yield unseeded noise.

    Returns:
        Noise tensor; for a list of seeds its batch size equals the list length.
    """
    device = context_latents.device
    dtype = context_latents.dtype
    length = context_latents.shape[1]
    channels = context_latents.shape[-1] // 2
    full_shape = (context_latents.shape[0], length, channels)

    def _gaussian(shape, generator=None):
        # Standard normal sample on the target device/dtype
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)

    def _seeded(value):
        # Fresh generator on the target device, seeded with an int
        return torch.Generator(device=device).manual_seed(int(value))

    if seed is None:
        return _gaussian(full_shape)

    if isinstance(seed, list):
        # One independent draw per sample so each seed is reproducible on its own
        row_shape = (1, length, channels)
        rows = [
            _gaussian(row_shape) if (s is None or s < 0) else _gaussian(row_shape, _seeded(s))
            for s in seed
        ]
        return torch.cat(rows, dim=0)

    # Single seed shared by the whole batch
    return _gaussian(full_shape, _seeded(seed))
def get_x0_from_noise(self, zt, vt, t):
    """
    Estimate the clean sample from a noisy sample and a predicted velocity.

    Computes ``zt - vt * t`` after appending two trailing singleton axes to
    ``t`` so it broadcasts over the last two dimensions.
    """
    step = t.unsqueeze(-1).unsqueeze(-1)
    displacement = vt * step
    return zt - displacement
def renoise(self, x, t, noise=None):
    """
    Blend ``x`` with noise at level ``t``: ``t * noise + (1 - t) * x``.

    Args:
        x: Sample to re-noise.
        t: Noise level. A tensor whose rank differs from that of ``x`` gets
            two trailing singleton axes appended; anything else is used as-is.
        noise: Optional noise; drawn with ``torch.randn_like(x)`` when omitted.

    Returns:
        The re-noised sample.
    """
    eps = torch.randn_like(x) if noise is None else noise
    needs_trailing_axes = isinstance(t, torch.Tensor) and t.ndim != x.ndim
    level = t.unsqueeze(-1).unsqueeze(-1) if needs_trailing_axes else t
    return level * eps + (1 - level) * x
def generate_audio(
    self,
    text_hidden_states: torch.FloatTensor,
    text_attention_mask: torch.FloatTensor,
    lyric_hidden_states: torch.FloatTensor,
    lyric_attention_mask: torch.FloatTensor,
    refer_audio_acoustic_hidden_states_packed: torch.FloatTensor,
    refer_audio_order_mask: torch.LongTensor,
    src_latents: torch.FloatTensor,
    chunk_masks: torch.FloatTensor,
    is_covers: torch.Tensor,
    silence_latent: Optional[torch.FloatTensor] = None,
    attention_mask: torch.Tensor = None,
    seed: int = None,
    infer_method: str = "ode",
    use_cache: bool = True,
    infer_steps: int = 30,
    diffusion_guidance_sale: float = 7.0,
    audio_cover_strength: float = 1.0,
    non_cover_text_hidden_states: Optional[torch.FloatTensor] = None,
    non_cover_text_attention_mask: Optional[torch.FloatTensor] = None,
    cfg_interval_start: float = 0.0,
    cfg_interval_end: float = 1.0,
    precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
    audio_codes: Optional[torch.FloatTensor] = None,
    use_progress_bar: bool = True,
    use_adg: bool = False,
    shift: float = 1.0,
    **kwargs,
):
    """
    Generate latents by integrating the decoder's output from t=1 down to t=0.

    The conditions are built once with ``prepare_condition``; the decoder is then
    called ``infer_steps`` times on a timestep grid ``linspace(1, 0, infer_steps + 1)``
    (optionally warped by ``shift``), and the state is updated with either an Euler
    step (``"ode"``) or a predict-clean / re-noise step (``"sde"``).

    Args:
        text_hidden_states, text_attention_mask, lyric_hidden_states,
        lyric_attention_mask, refer_audio_acoustic_hidden_states_packed,
        refer_audio_order_mask, chunk_masks, is_covers, precomputed_lm_hints_25Hz,
        audio_codes: Forwarded unchanged to ``prepare_condition``.
        src_latents: Source latents; passed to ``prepare_condition`` both as
            ``hidden_states`` and ``src_latents``. Its first two axes define the
            default attention mask.
        silence_latent: Forwarded to ``prepare_condition``. Required when
            ``audio_cover_strength < 1.0``, where it also replaces ``src_latents``
            for the non-cover condition (sliced to the length of ``src_latents``
            and expanded to its batch size).
        attention_mask: Mask over latent frames; defaults to all ones.
        seed: Seed specification handed to ``prepare_noise``.
        infer_method: ``"ode"`` or ``"sde"``.
        use_cache: Accepted for interface compatibility; the decoder is always
            called with ``use_cache=True``.
        infer_steps: Number of decoder evaluations.
        diffusion_guidance_sale: Guidance scale (parameter name kept as-is for
            compatibility). Classifier-free guidance is enabled when it is > 1.0.
        audio_cover_strength: Fraction of the steps (``int(infer_steps * strength)``)
            run with the main condition; the remaining steps use the non-cover
            condition. Values >= 1.0 never switch.
        non_cover_text_hidden_states, non_cover_text_attention_mask: Text inputs
            for the non-cover condition (only used when ``audio_cover_strength < 1.0``).
        cfg_interval_start, cfg_interval_end: Guidance is applied only while the
            current timestep lies inside this closed interval; outside it the
            conditional prediction is used directly.
        use_progress_bar: Wrap the step iterator in ``tqdm``.
        use_adg: Use ``adg_forward`` instead of ``apg_forward`` for guidance.
        shift: Timestep warp ``t -> shift * t / (1 + (shift - 1) * t)``; 1.0 disables it.
        **kwargs: Ignored.

    Returns:
        dict with ``"target_latents"`` (the final state) and ``"time_costs"``
        (wall-clock timings in seconds).

    Raises:
        ValueError: If ``infer_method`` is unknown, or if ``silence_latent`` is
            missing while ``audio_cover_strength < 1.0``.
    """
    # An unknown method used to skip every update and silently return the initial noise.
    if infer_method not in ("ode", "sde"):
        raise ValueError(f"Unsupported infer_method {infer_method!r}; expected 'ode' or 'sde'.")
    use_non_cover_condition = audio_cover_strength < 1.0
    if use_non_cover_condition and silence_latent is None:
        raise ValueError("silence_latent is required when audio_cover_strength < 1.0.")

    if attention_mask is None:
        attention_mask = torch.ones(
            src_latents.shape[0], src_latents.shape[1], device=src_latents.device, dtype=src_latents.dtype
        )

    time_costs = {}
    start_time = time.time()
    total_start_time = start_time

    encoder_hidden_states, encoder_attention_mask, context_latents = self.prepare_condition(
        text_hidden_states=text_hidden_states,
        text_attention_mask=text_attention_mask,
        lyric_hidden_states=lyric_hidden_states,
        lyric_attention_mask=lyric_attention_mask,
        refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
        refer_audio_order_mask=refer_audio_order_mask,
        hidden_states=src_latents,
        attention_mask=attention_mask,
        silence_latent=silence_latent,
        src_latents=src_latents,
        chunk_masks=chunk_masks,
        is_covers=is_covers,
        precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz,
        audio_codes=audio_codes,
    )

    encoder_hidden_states_non_cover, encoder_attention_mask_non_cover, context_latents_non_cover = None, None, None
    if use_non_cover_condition:
        non_is_covers = torch.zeros_like(is_covers)
        # Use silence_latent for non-cover condition to simulate text2music mode (no reference audio)
        silence_latent_expanded = silence_latent[:, :src_latents.shape[1], :].expand(src_latents.shape[0], -1, -1)
        encoder_hidden_states_non_cover, encoder_attention_mask_non_cover, context_latents_non_cover = self.prepare_condition(
            text_hidden_states=non_cover_text_hidden_states,
            text_attention_mask=non_cover_text_attention_mask,
            lyric_hidden_states=lyric_hidden_states,
            lyric_attention_mask=lyric_attention_mask,
            refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask=refer_audio_order_mask,
            hidden_states=silence_latent_expanded,
            attention_mask=attention_mask,
            silence_latent=silence_latent,
            src_latents=silence_latent_expanded,
            chunk_masks=chunk_masks,
            is_covers=non_is_covers,
            precomputed_lm_hints_25Hz=None,
            audio_codes=None,
        )

    end_time = time.time()
    time_costs["encoder_time_cost"] = end_time - start_time
    start_time = end_time

    # Number of leading steps that use the main (cover) condition.
    cover_steps = int(infer_steps * audio_cover_strength)

    bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype
    t = torch.linspace(1.0, 0.0, infer_steps + 1, device=device, dtype=dtype)
    # Apply shift transformation to timesteps if shift != 1.0
    if shift != 1.0:
        t = shift * t / (1 + (shift - 1) * t)

    step_pairs = zip(t[:-1], t[1:])
    iterator = tqdm(step_pairs, total=infer_steps) if use_progress_bar else step_pairs

    noise = self.prepare_noise(context_latents, seed)
    past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
    momentum_buffer = MomentumBuffer()

    # With CFG the batch is doubled: first half conditional, second half null-conditioned.
    do_cfg_guidance = diffusion_guidance_sale > 1.0
    if do_cfg_guidance:
        encoder_hidden_states = torch.cat(
            [encoder_hidden_states, self.null_condition_emb.expand_as(encoder_hidden_states)], dim=0
        )
        encoder_attention_mask = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0)
        context_latents = torch.cat([context_latents, context_latents], dim=0)
        attention_mask = torch.cat([attention_mask, attention_mask], dim=0)

    xt = noise
    switched_to_non_cover = False
    with torch.no_grad():
        for step_idx, (t_curr, t_prev) in enumerate(iterator):
            # Switch to the non-cover condition exactly once. Doing this on every
            # later step would re-double the CFG batch each time and throw away
            # the freshly built KV cache.
            if use_non_cover_condition and not switched_to_non_cover and step_idx >= cover_steps:
                if do_cfg_guidance:
                    encoder_hidden_states_non_cover = torch.cat(
                        [
                            encoder_hidden_states_non_cover,
                            self.null_condition_emb.expand_as(encoder_hidden_states_non_cover),
                        ],
                        dim=0,
                    )
                    encoder_attention_mask_non_cover = torch.cat(
                        [encoder_attention_mask_non_cover, encoder_attention_mask_non_cover], dim=0
                    )
                    context_latents_non_cover = torch.cat(
                        [context_latents_non_cover, context_latents_non_cover], dim=0
                    )
                encoder_hidden_states = encoder_hidden_states_non_cover
                encoder_attention_mask = encoder_attention_mask_non_cover
                context_latents = context_latents_non_cover
                # The cached states belong to the previous condition; start over.
                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
                switched_to_non_cover = True

            x = torch.cat([xt, xt], dim=0) if do_cfg_guidance else xt
            t_curr_tensor = t_curr * torch.ones((x.shape[0],), device=device, dtype=dtype)
            decoder_outputs = self.decoder(
                hidden_states=x,
                timestep=t_curr_tensor,
                timestep_r=t_curr_tensor,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                context_latents=context_latents,
                use_cache=True,
                past_key_values=past_key_values,
            )
            vt = decoder_outputs[0]
            past_key_values = decoder_outputs[1]

            if do_cfg_guidance:
                pred_cond, pred_null_cond = vt.chunk(2)
                apply_cfg_guidance = t_curr >= cfg_interval_start and t_curr <= cfg_interval_end
                if not apply_cfg_guidance:
                    # Outside the guidance interval: plain conditional prediction.
                    vt = pred_cond
                elif use_adg:
                    vt = adg_forward(
                        latents=xt,
                        noise_pred_cond=pred_cond,
                        noise_pred_uncond=pred_null_cond,
                        sigma=t_curr,
                        guidance_scale=diffusion_guidance_sale,
                    )
                else:
                    vt = apg_forward(
                        pred_cond=pred_cond,
                        pred_uncond=pred_null_cond,
                        guidance_scale=diffusion_guidance_sale,
                        momentum_buffer=momentum_buffer,
                        dims=[1],
                    )

            # Update x_t based on inference method
            if infer_method == "sde":
                # Stochastic Differential Equation: predict clean, then re-add noise
                t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
                pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
                next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
                # Re-noise to the same (shifted) level the decoder is told about on
                # the next step; without this the two disagree when shift != 1.0.
                if shift != 1.0:
                    next_timestep = shift * next_timestep / (1 + (shift - 1) * next_timestep)
                xt = self.renoise(pred_clean, next_timestep)
            else:
                # Ordinary Differential Equation: Euler method
                # dx/dt = -v, so x_{t+1} = x_t - v_t * dt
                dt = t_curr - t_prev
                dt_tensor = dt * torch.ones((bsz,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
                xt = xt - vt * dt_tensor

    x_gen = xt
    end_time = time.time()
    time_costs["diffusion_time_cost"] = end_time - start_time
    # max(..., 1) avoids a ZeroDivisionError when infer_steps == 0.
    time_costs["diffusion_per_step_time_cost"] = time_costs["diffusion_time_cost"] / max(infer_steps, 1)
    time_costs["total_time_cost"] = end_time - total_start_time
    return {
        "target_latents": x_gen,
        "time_costs": time_costs,
    }
def test_forward(model, seed=42):
    """
    Smoke-test and benchmark ``model`` with random inputs.

    Runs one ``training_losses`` call on a 10 s (250-frame) batch of two samples
    and prints the loss, then calls ``generate_audio`` for several sequence
    lengths and prints output statistics and throughput.

    Args:
        model: Object exposing ``parameters()``, ``training_losses(...)`` and
            ``generate_audio(...)``; inputs are created with the dtype and
            device of its first parameter.
        seed: Seed applied to ``random``, NumPy and torch RNGs.
    """
    import random
    import numpy as np

    # Seed every RNG in play so repeated runs see the same random inputs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # All test tensors follow the model's own dtype and device.
    first_param = next(model.parameters())
    model_dtype, device = first_param.dtype, first_param.device
    print(f"Testing with dtype: {model_dtype}, device: {device}, seed: {seed}")
    tensor_kwargs = {"dtype": model_dtype, "device": device}

    batch_size, feat_dim = 2, 64

    # NOTE: the order of the torch.randn calls below is significant - it fixes
    # which values each input receives from the seeded RNG stream.
    text_hidden_states = torch.randn(batch_size, 77, 1024, **tensor_kwargs)
    text_attention_mask = torch.ones(batch_size, 77, **tensor_kwargs)
    lyric_hidden_states = torch.randn(batch_size, 123, 1024, **tensor_kwargs)
    lyric_attention_mask = torch.ones(batch_size, 123, **tensor_kwargs)
    refer_audio_acoustic_hidden_states_packed = torch.randn(3, 750, feat_dim, **tensor_kwargs)
    refer_audio_order_mask = torch.LongTensor([0, 0, 1]).to(device)

    # Base config: 25 Hz hidden states, 10 s = 250 frames.
    base_seconds = 10
    frames_per_second = 25
    base_seq_len = base_seconds * frames_per_second
    hidden_states = torch.randn(batch_size, base_seq_len, feat_dim, **tensor_kwargs)
    attention_mask = torch.ones(batch_size, base_seq_len, **tensor_kwargs)
    chunk_mask = torch.ones(batch_size, base_seq_len, feat_dim, **tensor_kwargs)

    # Zero out the second half of sample 0 to exercise mask handling.
    pad_start = max(base_seq_len // 2, 1)
    attention_mask[0, pad_start:] = 0
    chunk_mask[0, pad_start:] = 0

    silence_latent = torch.randn(batch_size, base_seq_len, feat_dim, **tensor_kwargs)
    src_latents = torch.randn(batch_size, base_seq_len, feat_dim, **tensor_kwargs)
    # Cover indicators: sample 0 is marked 0, sample 1 is marked 1.
    is_covers = torch.tensor([0, 1], dtype=torch.long, device=device)

    # Inputs that stay the same for the training call and every generation call.
    shared_condition_inputs = {
        "text_hidden_states": text_hidden_states,
        "text_attention_mask": text_attention_mask,
        "lyric_hidden_states": lyric_hidden_states,
        "lyric_attention_mask": lyric_attention_mask,
        "refer_audio_acoustic_hidden_states_packed": refer_audio_acoustic_hidden_states_packed,
        "refer_audio_order_mask": refer_audio_order_mask,
        "is_covers": is_covers,
    }

    # Test 1: flow matching training on the 10 s batch.
    print(f"Testing flow matching training with {base_seconds}s sequence ({base_seq_len} frames @ {frames_per_second}Hz)...")
    outputs = model.training_losses(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        chunk_masks=chunk_mask,
        silence_latent=silence_latent,
        src_latents=src_latents,
        cfg_ratio=0.15,
        **shared_condition_inputs,
    )
    loss = outputs['diffusion_loss']
    print(f"Flow matching loss: {loss.item():.6f}")
    # std() is only evaluated when the loss has more than one element.
    loss_std = loss.std().item() if loss.numel() > 1 else 0
    print(f" Loss stats - min: {loss.min().item():.6f}, max: {loss.max().item():.6f}, mean: {loss.mean().item():.6f}, std: {loss_std:.6f}")

    # Test 2: generation throughput for several sequence lengths.
    infer_steps = 2  # Kept tiny for speed; raise it to approximate real inference.
    print("\n===== Throughput benchmark (25Hz hidden states) =====")
    for seconds in (10, 30, 60, 120, 180, 240):
        seq_len = seconds * frames_per_second

        # Fresh inputs for this length. The first two tensors are not passed to
        # generate_audio; they are still created so the RNG stream is unchanged.
        cur_hidden_states = torch.randn(batch_size, seq_len, feat_dim, **tensor_kwargs)
        cur_attention_mask = torch.ones(batch_size, seq_len, **tensor_kwargs)
        cur_chunk_mask = torch.ones(batch_size, seq_len, feat_dim, **tensor_kwargs)
        cur_silence_latent = torch.randn(batch_size, seq_len, feat_dim, **tensor_kwargs)
        cur_src_latents = torch.randn(batch_size, seq_len, feat_dim, **tensor_kwargs)

        print(f"\n--- {seconds}s input ({seq_len} frames @ {frames_per_second}Hz) ---")
        outputs = model.generate_audio(
            src_latents=cur_src_latents,
            chunk_masks=cur_chunk_mask,
            silence_latent=cur_silence_latent,
            infer_steps=infer_steps,
            seed=1234,
            **shared_condition_inputs,
        )
        target_latents = outputs["target_latents"]
        time_costs = outputs.get("time_costs", {})
        total_time = time_costs.get("total_time_cost")
        diffusion_time = time_costs.get("diffusion_time_cost")

        # Output shape and statistics
        print(f"Generated latents shape: {target_latents.shape}")
        print(
            f"Stats - min: {target_latents.min().item():.4f}, "
            f"max: {target_latents.max().item():.4f}, "
            f"mean: {target_latents.mean().item():.4f}, "
            f"std: {target_latents.std().item():.4f}"
        )

        if total_time is None:
            print("Time costs not available in outputs['time_costs']; only basic stats printed.")
            continue

        # Throughput in latent frames and in audio seconds per wall-clock second.
        bsz, t_len = target_latents.shape[0], target_latents.shape[1]
        audio_seconds = t_len / frames_per_second
        frames_throughput = (bsz * t_len) / total_time
        seconds_throughput = (bsz * audio_seconds) / total_time

        if diffusion_time is not None:
            timing_message = (
                f"Time costs: total={total_time:.4f}s, diffusion={diffusion_time:.4f}s "
                f"({infer_steps} steps)"
            )
        else:
            timing_message = f"Time costs: total={total_time:.4f}s"
        print(timing_message)
        print(
            f"Throughput (based on total_time): "
            f"{frames_throughput:.2f} frames/s, "
            f"{seconds_throughput:.2f} audio-seconds/s (batch={bsz})"
        )
if __name__ == "__main__":
    # Script entry point: load the checkpoint stored next to this file and run
    # the smoke test / throughput benchmark on GPU in bfloat16.
    import os

    model_dir = os.path.dirname(os.path.abspath(__file__))

    # Report how long loading takes (it was measured before but never shown).
    load_start = time.time()
    model = AceStepConditionGenerationModel.from_pretrained(model_dir)
    print(f"Model loaded from {model_dir} in {time.time() - load_start:.2f}s")

    # Alternative for machines without flash-attn: set this to "sdpa"
    # (optionally with model.to("cpu").float() below).
    model.config._attn_implementation = "flash_attention_2"
    model.eval()
    model = model.to("cuda")
    model = model.bfloat16()
    test_forward(model)