Instructions to use deAPI-ai/acestep-1-5-base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use deAPI-ai/acestep-1-5-base with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("deAPI-ai/acestep-1-5-base", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Sync model code with ace_step==1.6.0 to prevent runtime mutation
Browse files
acestep-v15-base/modeling_acestep_v15_base.py
CHANGED
|
@@ -11,6 +11,7 @@
|
|
| 11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
|
|
|
| 14 |
import math
|
| 15 |
import time
|
| 16 |
from typing import Callable, List, Optional, Union
|
|
@@ -44,10 +45,10 @@ from vector_quantize_pytorch import ResidualFSQ
|
|
| 44 |
# Local config import with fallback
|
| 45 |
try:
|
| 46 |
from .configuration_acestep_v15 import AceStepConfig
|
| 47 |
-
from .apg_guidance import adg_forward, apg_forward, MomentumBuffer
|
| 48 |
except ImportError:
|
| 49 |
from configuration_acestep_v15 import AceStepConfig
|
| 50 |
-
from apg_guidance import adg_forward, apg_forward, MomentumBuffer
|
| 51 |
|
| 52 |
|
| 53 |
logger = logging.get_logger(__name__)
|
|
@@ -115,7 +116,7 @@ def create_4d_mask(
|
|
| 115 |
# We want to mask out invalid keys (columns)
|
| 116 |
# Expand shape: [Batch, 1, 1, Seq_Len]
|
| 117 |
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
|
| 118 |
-
|
| 119 |
# Broadcasting: Geometry Mask [1, 1, L, L] & Padding Mask [B, 1, 1, L]
|
| 120 |
# Result shape: [B, 1, L, L]
|
| 121 |
valid_mask = valid_mask & padding_mask_4d
|
|
@@ -125,13 +126,13 @@ def create_4d_mask(
|
|
| 125 |
# ------------------------------------------------------
|
| 126 |
# Get the minimal value for current dtype
|
| 127 |
min_dtype = torch.finfo(dtype).min
|
| 128 |
-
|
| 129 |
# Create result tensor filled with -inf by default
|
| 130 |
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
|
| 131 |
-
|
| 132 |
# Set valid positions to 0.0
|
| 133 |
mask_tensor.masked_fill_(valid_mask, 0.0)
|
| 134 |
-
|
| 135 |
return mask_tensor
|
| 136 |
|
| 137 |
|
|
@@ -200,7 +201,7 @@ def sample_t_r(batch_size, device, dtype, data_proportion=0.0, timestep_mu=-0.4,
|
|
| 200 |
class TimestepEmbedding(nn.Module):
|
| 201 |
"""
|
| 202 |
Timestep embedding module for diffusion models.
|
| 203 |
-
|
| 204 |
Converts timestep values into high-dimensional embeddings using sinusoidal
|
| 205 |
positional encoding, followed by MLP layers. Used for conditioning diffusion
|
| 206 |
models on timestep information.
|
|
@@ -217,7 +218,7 @@ class TimestepEmbedding(nn.Module):
|
|
| 217 |
self.act1 = nn.SiLU()
|
| 218 |
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
|
| 219 |
self.in_channels = in_channels
|
| 220 |
-
|
| 221 |
self.act2 = nn.SiLU()
|
| 222 |
self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
|
| 223 |
self.scale = scale
|
|
@@ -305,7 +306,7 @@ class AceStepAttention(nn.Module):
|
|
| 305 |
|
| 306 |
# Determine if this is cross-attention (requires encoder_hidden_states)
|
| 307 |
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
|
| 308 |
-
|
| 309 |
# Cross-attention path: attend to encoder hidden states
|
| 310 |
if is_cross_attention:
|
| 311 |
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
|
|
@@ -313,7 +314,7 @@ class AceStepAttention(nn.Module):
|
|
| 313 |
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
| 314 |
# After the first generated token, we can reuse all key/value states from cache
|
| 315 |
curr_past_key_value = past_key_value.cross_attention_cache
|
| 316 |
-
|
| 317 |
# Conditions for calculating key and value states
|
| 318 |
if not is_updated:
|
| 319 |
# Compute and cache K/V for the first time
|
|
@@ -331,7 +332,7 @@ class AceStepAttention(nn.Module):
|
|
| 331 |
# No cache used, compute K/V directly
|
| 332 |
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
|
| 333 |
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
|
| 334 |
-
|
| 335 |
# Self-attention path: attend to the same sequence
|
| 336 |
else:
|
| 337 |
# Project and normalize key/value states for self-attention
|
|
@@ -353,7 +354,7 @@ class AceStepAttention(nn.Module):
|
|
| 353 |
attention_interface: Callable = eager_attention_forward
|
| 354 |
elif self.config._attn_implementation != "eager":
|
| 355 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 356 |
-
|
| 357 |
attn_output, attn_weights = attention_interface(
|
| 358 |
self,
|
| 359 |
query_states,
|
|
@@ -443,12 +444,12 @@ class AceStepEncoderLayer(GradientCheckpointingLayer):
|
|
| 443 |
class AceStepDiTLayer(GradientCheckpointingLayer):
|
| 444 |
"""
|
| 445 |
DiT (Diffusion Transformer) layer for AceStep model.
|
| 446 |
-
|
| 447 |
Implements a transformer layer with three main components:
|
| 448 |
1. Self-attention with adaptive layer norm (AdaLN)
|
| 449 |
2. Cross-attention (optional) for conditioning on encoder outputs
|
| 450 |
3. Feed-forward MLP with adaptive layer norm
|
| 451 |
-
|
| 452 |
Uses scale-shift modulation from timestep embeddings for adaptive normalization.
|
| 453 |
"""
|
| 454 |
def __init__(self, config: AceStepConfig, layer_idx: int, use_cross_attention: bool = True):
|
|
@@ -471,7 +472,7 @@ class AceStepDiTLayer(GradientCheckpointingLayer):
|
|
| 471 |
# Scale-shift table for adaptive layer norm modulation (6 values: 3 for self-attn, 3 for MLP)
|
| 472 |
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5)
|
| 473 |
self.attention_type = config.layer_types[layer_idx]
|
| 474 |
-
|
| 475 |
def forward(
|
| 476 |
self,
|
| 477 |
hidden_states: torch.Tensor,
|
|
@@ -577,14 +578,14 @@ class AceStepPreTrainedModel(PreTrainedModel):
|
|
| 577 |
class AceStepLyricEncoder(AceStepPreTrainedModel):
|
| 578 |
"""
|
| 579 |
Encoder for processing lyric text embeddings.
|
| 580 |
-
|
| 581 |
Encodes lyric text hidden states using a transformer encoder architecture
|
| 582 |
with bidirectional attention. Projects text embeddings to model hidden size
|
| 583 |
and processes them through multiple encoder layers.
|
| 584 |
"""
|
| 585 |
def __init__(self, config):
|
| 586 |
super().__init__(config)
|
| 587 |
-
|
| 588 |
# Project text embeddings to model hidden size
|
| 589 |
self.embed_tokens = nn.Linear(config.text_hidden_dim, config.hidden_size)
|
| 590 |
self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
@@ -618,7 +619,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
|
|
| 618 |
assert input_ids is None, "Only `input_ids` is supported for the lyric encoder."
|
| 619 |
assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
|
| 620 |
assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."
|
| 621 |
-
|
| 622 |
# Project input embeddings: N x T x text_hidden_dim -> N x T x hidden_size
|
| 623 |
inputs_embeds = self.embed_tokens(inputs_embeds)
|
| 624 |
# Cache position: only used for mask construction (not for actual caching)
|
|
@@ -632,7 +633,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
|
|
| 632 |
seq_len = inputs_embeds.shape[1]
|
| 633 |
dtype = inputs_embeds.dtype
|
| 634 |
device = inputs_embeds.device
|
| 635 |
-
|
| 636 |
# 判断是否使用 Flash Attention 2
|
| 637 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 638 |
|
|
@@ -649,7 +650,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
|
|
| 649 |
# 如果没有 padding mask,传 None。
|
| 650 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 651 |
full_attn_mask = attention_mask
|
| 652 |
-
|
| 653 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 654 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 655 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
@@ -659,7 +660,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
|
|
| 659 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 660 |
# -------------------------------------------------------
|
| 661 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 662 |
-
|
| 663 |
# 1. Full Attention (Bidirectional, Global)
|
| 664 |
# 对应原来的 create_causal_mask + bidirectional
|
| 665 |
full_attn_mask = create_4d_mask(
|
|
@@ -734,7 +735,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
|
|
| 734 |
class AttentionPooler(AceStepPreTrainedModel):
|
| 735 |
"""
|
| 736 |
Attention-based pooling module.
|
| 737 |
-
|
| 738 |
Pools sequences of patches using a special token and attention mechanism.
|
| 739 |
The special token attends to all patches and its output is used as the
|
| 740 |
pooled representation. Used for aggregating patch-level features into
|
|
@@ -782,7 +783,7 @@ class AttentionPooler(AceStepPreTrainedModel):
|
|
| 782 |
seq_len = x.shape[1]
|
| 783 |
dtype = x.dtype
|
| 784 |
device = x.device
|
| 785 |
-
|
| 786 |
# 判断是否使用 Flash Attention 2
|
| 787 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 788 |
|
|
@@ -799,7 +800,7 @@ class AttentionPooler(AceStepPreTrainedModel):
|
|
| 799 |
# 如果没有 padding mask,传 None。
|
| 800 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 801 |
full_attn_mask = attention_mask
|
| 802 |
-
|
| 803 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 804 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 805 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
@@ -809,7 +810,7 @@ class AttentionPooler(AceStepPreTrainedModel):
|
|
| 809 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 810 |
# -------------------------------------------------------
|
| 811 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 812 |
-
|
| 813 |
# 1. Full Attention (Bidirectional, Global)
|
| 814 |
# 对应原来的 create_causal_mask + bidirectional
|
| 815 |
full_attn_mask = create_4d_mask(
|
|
@@ -840,7 +841,7 @@ class AttentionPooler(AceStepPreTrainedModel):
|
|
| 840 |
"full_attention": full_attn_mask,
|
| 841 |
"sliding_attention": sliding_attn_mask,
|
| 842 |
}
|
| 843 |
-
|
| 844 |
for layer_module in self.layers:
|
| 845 |
layer_outputs = layer_module(
|
| 846 |
hidden_states,
|
|
@@ -852,7 +853,7 @@ class AttentionPooler(AceStepPreTrainedModel):
|
|
| 852 |
hidden_states = layer_outputs[0]
|
| 853 |
|
| 854 |
hidden_states = self.norm(hidden_states)
|
| 855 |
-
|
| 856 |
# Extract the special token output (first position) as pooled representation
|
| 857 |
cls_output = hidden_states[:, 0, :]
|
| 858 |
cls_output = rearrange(cls_output, "(b t) c -> b t c", b=B)
|
|
@@ -862,7 +863,7 @@ class AttentionPooler(AceStepPreTrainedModel):
|
|
| 862 |
class AudioTokenDetokenizer(AceStepPreTrainedModel):
|
| 863 |
"""
|
| 864 |
Audio token detokenizer module.
|
| 865 |
-
|
| 866 |
Converts quantized audio tokens back to continuous acoustic representations.
|
| 867 |
Expands each token into multiple patches using special tokens, processes them
|
| 868 |
through encoder layers, and projects to acoustic hidden dimension.
|
|
@@ -917,7 +918,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
|
|
| 917 |
seq_len = x.shape[1]
|
| 918 |
dtype = x.dtype
|
| 919 |
device = x.device
|
| 920 |
-
|
| 921 |
# 判断是否使用 Flash Attention 2
|
| 922 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 923 |
|
|
@@ -934,7 +935,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
|
|
| 934 |
# 如果没有 padding mask,传 None。
|
| 935 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 936 |
full_attn_mask = attention_mask
|
| 937 |
-
|
| 938 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 939 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 940 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
@@ -944,7 +945,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
|
|
| 944 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 945 |
# -------------------------------------------------------
|
| 946 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 947 |
-
|
| 948 |
# 1. Full Attention (Bidirectional, Global)
|
| 949 |
# 对应原来的 create_causal_mask + bidirectional
|
| 950 |
full_attn_mask = create_4d_mask(
|
|
@@ -975,7 +976,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
|
|
| 975 |
"full_attention": full_attn_mask,
|
| 976 |
"sliding_attention": sliding_attn_mask,
|
| 977 |
}
|
| 978 |
-
|
| 979 |
for layer_module in self.layers:
|
| 980 |
layer_outputs = layer_module(
|
| 981 |
hidden_states,
|
|
@@ -987,7 +988,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
|
|
| 987 |
hidden_states = layer_outputs[0]
|
| 988 |
|
| 989 |
hidden_states = self.norm(hidden_states)
|
| 990 |
-
|
| 991 |
hidden_states = self.proj_out(hidden_states)
|
| 992 |
|
| 993 |
hidden_states = rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.config.pool_window_size)
|
|
@@ -997,7 +998,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
|
|
| 997 |
class AceStepTimbreEncoder(AceStepPreTrainedModel):
|
| 998 |
"""
|
| 999 |
Encoder for extracting timbre embeddings from reference audio.
|
| 1000 |
-
|
| 1001 |
Processes packed reference audio acoustic features to extract timbre
|
| 1002 |
representations. Uses a special token (CLS-like) to aggregate information
|
| 1003 |
from the entire reference audio sequence. Outputs are unpacked back to
|
|
@@ -1005,7 +1006,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
|
|
| 1005 |
"""
|
| 1006 |
def __init__(self, config):
|
| 1007 |
super().__init__(config)
|
| 1008 |
-
|
| 1009 |
# Project acoustic features to model hidden size
|
| 1010 |
self.embed_tokens = nn.Linear(config.timbre_hidden_dim, config.hidden_size)
|
| 1011 |
self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
@@ -1036,40 +1037,40 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
|
|
| 1036 |
N, d = timbre_embs_packed.shape
|
| 1037 |
device = timbre_embs_packed.device
|
| 1038 |
dtype = timbre_embs_packed.dtype
|
| 1039 |
-
|
| 1040 |
# Get batch size
|
| 1041 |
B = int(refer_audio_order_mask.max().item() + 1)
|
| 1042 |
-
|
| 1043 |
# Calculate element count and positions for each batch
|
| 1044 |
counts = torch.bincount(refer_audio_order_mask, minlength=B)
|
| 1045 |
max_count = counts.max().item()
|
| 1046 |
-
|
| 1047 |
# Calculate positions within batch
|
| 1048 |
sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
|
| 1049 |
sorted_batch_ids = refer_audio_order_mask[sorted_indices]
|
| 1050 |
-
|
| 1051 |
positions = torch.arange(N, device=device)
|
| 1052 |
-
batch_starts = torch.cat([torch.tensor([0], device=device),
|
| 1053 |
torch.cumsum(counts, dim=0)[:-1]])
|
| 1054 |
positions_in_sorted = positions - batch_starts[sorted_batch_ids]
|
| 1055 |
-
|
| 1056 |
inverse_indices = torch.empty_like(sorted_indices)
|
| 1057 |
inverse_indices[sorted_indices] = torch.arange(N, device=device)
|
| 1058 |
positions_in_batch = positions_in_sorted[inverse_indices]
|
| 1059 |
-
|
| 1060 |
# Use one-hot encoding and matrix multiplication (gradient-friendly approach)
|
| 1061 |
# Create one-hot encoding
|
| 1062 |
indices_2d = refer_audio_order_mask * max_count + positions_in_batch # (N,)
|
| 1063 |
one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype) # (N, B*max_count)
|
| 1064 |
-
|
| 1065 |
# Rearrange using matrix multiplication
|
| 1066 |
timbre_embs_flat = one_hot.t() @ timbre_embs_packed # (B*max_count, d)
|
| 1067 |
timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
|
| 1068 |
-
|
| 1069 |
# Create mask indicating valid positions
|
| 1070 |
mask_flat = (one_hot.sum(dim=0) > 0).long() # (B*max_count,)
|
| 1071 |
new_mask = mask_flat.reshape(B, max_count)
|
| 1072 |
-
|
| 1073 |
return timbre_embs_unpack, new_mask
|
| 1074 |
|
| 1075 |
@can_return_tuple
|
|
@@ -1093,7 +1094,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
|
|
| 1093 |
seq_len = inputs_embeds.shape[1]
|
| 1094 |
dtype = inputs_embeds.dtype
|
| 1095 |
device = inputs_embeds.device
|
| 1096 |
-
|
| 1097 |
# 判断是否使用 Flash Attention 2
|
| 1098 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 1099 |
|
|
@@ -1110,7 +1111,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
|
|
| 1110 |
# 如果没有 padding mask,传 None。
|
| 1111 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 1112 |
full_attn_mask = attention_mask
|
| 1113 |
-
|
| 1114 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 1115 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 1116 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
@@ -1120,7 +1121,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
|
|
| 1120 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 1121 |
# -------------------------------------------------------
|
| 1122 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 1123 |
-
|
| 1124 |
# 1. Full Attention (Bidirectional, Global)
|
| 1125 |
# 对应原来的 create_causal_mask + bidirectional
|
| 1126 |
full_attn_mask = create_4d_mask(
|
|
@@ -1151,7 +1152,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
|
|
| 1151 |
"full_attention": full_attn_mask,
|
| 1152 |
"sliding_attention": sliding_attn_mask,
|
| 1153 |
}
|
| 1154 |
-
|
| 1155 |
# Initialize hidden states
|
| 1156 |
hidden_states = inputs_embeds
|
| 1157 |
|
|
@@ -1181,7 +1182,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
|
|
| 1181 |
class AceStepAudioTokenizer(AceStepPreTrainedModel):
|
| 1182 |
"""
|
| 1183 |
Audio tokenizer module.
|
| 1184 |
-
|
| 1185 |
Converts continuous acoustic features into discrete quantized tokens.
|
| 1186 |
Process: project -> pool patches -> quantize. Used for converting audio
|
| 1187 |
representations into discrete tokens for processing by the diffusion model.
|
|
@@ -1208,7 +1209,7 @@ class AceStepAudioTokenizer(AceStepPreTrainedModel):
|
|
| 1208 |
hidden_states: Optional[torch.FloatTensor] = None,
|
| 1209 |
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 1210 |
) -> BaseModelOutput:
|
| 1211 |
-
|
| 1212 |
# Project acoustic features to hidden size
|
| 1213 |
hidden_states = self.audio_acoustic_proj(hidden_states)
|
| 1214 |
# Pool sequences: N x T//pool_window_size x pool_window_size x d -> N x T//pool_window_size x d
|
|
@@ -1225,14 +1226,14 @@ class AceStepAudioTokenizer(AceStepPreTrainedModel):
|
|
| 1225 |
class Lambda(nn.Module):
|
| 1226 |
"""
|
| 1227 |
Wrapper module for arbitrary lambda functions.
|
| 1228 |
-
|
| 1229 |
Allows using lambda functions in nn.Sequential by wrapping them in a Module.
|
| 1230 |
Useful for simple transformations like transpose operations.
|
| 1231 |
"""
|
| 1232 |
def __init__(self, func):
|
| 1233 |
super().__init__()
|
| 1234 |
self.func = func
|
| 1235 |
-
|
| 1236 |
def forward(self, x):
|
| 1237 |
return self.func(x)
|
| 1238 |
|
|
@@ -1240,7 +1241,7 @@ class Lambda(nn.Module):
|
|
| 1240 |
class AceStepDiTModel(AceStepPreTrainedModel):
|
| 1241 |
"""
|
| 1242 |
DiT (Diffusion Transformer) model for AceStep.
|
| 1243 |
-
|
| 1244 |
Main diffusion model that generates audio latents conditioned on text, lyrics,
|
| 1245 |
and timbre. Uses patch-based processing with transformer layers, timestep
|
| 1246 |
conditioning, and cross-attention to encoder outputs.
|
|
@@ -1258,7 +1259,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1258 |
inner_dim = config.hidden_size
|
| 1259 |
patch_size = config.patch_size
|
| 1260 |
self.patch_size = patch_size
|
| 1261 |
-
|
| 1262 |
# Input projection: patch embedding using 1D convolution
|
| 1263 |
# Converts sequence into patches for efficient processing
|
| 1264 |
self.proj_in = nn.Sequential(
|
|
@@ -1277,9 +1278,10 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1277 |
# Two embeddings: one for timestep t, one for timestep difference (t - r)
|
| 1278 |
self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
|
| 1279 |
self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
|
| 1280 |
-
|
| 1281 |
# Project encoder hidden states to model dimension
|
| 1282 |
-
|
|
|
|
| 1283 |
|
| 1284 |
# Output normalization and projection
|
| 1285 |
# Adaptive layer norm with scale-shift modulation, then de-patchify
|
|
@@ -1330,7 +1332,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1330 |
use_cache = False
|
| 1331 |
if self.training:
|
| 1332 |
use_cache = False
|
| 1333 |
-
|
| 1334 |
# Initialize cache if needed (only during inference for auto-regressive generation)
|
| 1335 |
if not self.training and use_cache and past_key_values is None:
|
| 1336 |
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
|
@@ -1357,14 +1359,14 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1357 |
# Project input to patches and project encoder states
|
| 1358 |
hidden_states = self.proj_in(hidden_states)
|
| 1359 |
encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
|
| 1360 |
-
|
| 1361 |
# Cache positions
|
| 1362 |
if cache_position is None:
|
| 1363 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1364 |
cache_position = torch.arange(
|
| 1365 |
past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
|
| 1366 |
)
|
| 1367 |
-
|
| 1368 |
# Position IDs
|
| 1369 |
if position_ids is None:
|
| 1370 |
position_ids = cache_position.unsqueeze(0)
|
|
@@ -1374,7 +1376,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1374 |
encoder_seq_len = encoder_hidden_states.shape[1]
|
| 1375 |
dtype = hidden_states.dtype
|
| 1376 |
device = hidden_states.device
|
| 1377 |
-
|
| 1378 |
# 判断是否使用 Flash Attention 2
|
| 1379 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 1380 |
|
|
@@ -1392,7 +1394,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1392 |
# 如果没有 padding mask,传 None。
|
| 1393 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 1394 |
full_attn_mask = attention_mask
|
| 1395 |
-
|
| 1396 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 1397 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 1398 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
@@ -1402,7 +1404,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1402 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 1403 |
# -------------------------------------------------------
|
| 1404 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 1405 |
-
|
| 1406 |
# 1. Full Attention (Bidirectional, Global)
|
| 1407 |
# 对应原来的 create_causal_mask + bidirectional
|
| 1408 |
full_attn_mask = create_4d_mask(
|
|
@@ -1415,7 +1417,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1415 |
is_causal=False # <--- 关键:双向注意力
|
| 1416 |
)
|
| 1417 |
max_len = max(seq_len, encoder_seq_len)
|
| 1418 |
-
|
| 1419 |
encoder_attention_mask = create_4d_mask(
|
| 1420 |
seq_len=max_len,
|
| 1421 |
dtype=dtype,
|
|
@@ -1483,7 +1485,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1483 |
# Extract the last element which is cross_attn_weights
|
| 1484 |
if len(layer_outputs) >= 3:
|
| 1485 |
all_cross_attentions += (layer_outputs[2],)
|
| 1486 |
-
|
| 1487 |
if return_hidden_states:
|
| 1488 |
return hidden_states
|
| 1489 |
|
|
@@ -1496,10 +1498,10 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1496 |
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
|
| 1497 |
# Project output: de-patchify back to original sequence format
|
| 1498 |
hidden_states = self.proj_out(hidden_states)
|
| 1499 |
-
|
| 1500 |
# Crop back to original sequence length to ensure exact length match (remove padding)
|
| 1501 |
hidden_states = hidden_states[:, :original_seq_len, :]
|
| 1502 |
-
|
| 1503 |
outputs = (hidden_states, past_key_values)
|
| 1504 |
|
| 1505 |
if output_attentions:
|
|
@@ -1509,7 +1511,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
|
|
| 1509 |
class AceStepConditionEncoder(AceStepPreTrainedModel):
|
| 1510 |
"""
|
| 1511 |
Condition encoder for AceStep model.
|
| 1512 |
-
|
| 1513 |
Encodes multiple conditioning inputs (text, lyrics, timbre) and packs them
|
| 1514 |
into a single sequence for cross-attention in the diffusion model. Handles
|
| 1515 |
projection, encoding, and sequence packing.
|
|
@@ -1554,10 +1556,42 @@ class AceStepConditionEncoder(AceStepPreTrainedModel):
|
|
| 1554 |
return encoder_hidden_states, encoder_attention_mask
|
| 1555 |
|
| 1556 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1557 |
class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
| 1558 |
"""
|
| 1559 |
Main conditional generation model for AceStep.
|
| 1560 |
-
|
| 1561 |
End-to-end model for generating audio conditioned on text, lyrics, and timbre.
|
| 1562 |
Combines encoder (for conditioning), decoder (diffusion model), tokenizer
|
| 1563 |
(for discrete tokenization), and detokenizer (for reconstruction).
|
|
@@ -1568,11 +1602,22 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1568 |
self.config = config
|
| 1569 |
# Diffusion model components
|
| 1570 |
self.decoder = AceStepDiTModel(config) # Main diffusion transformer
|
| 1571 |
-
|
| 1572 |
-
|
| 1573 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1574 |
# Null condition embedding for classifier-free guidance
|
| 1575 |
-
self.null_condition_emb = nn.Parameter(torch.randn(1, 1,
|
| 1576 |
|
| 1577 |
# Initialize weights and apply final processing
|
| 1578 |
self.post_init()
|
|
@@ -1621,7 +1666,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1621 |
precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
|
| 1622 |
audio_codes: torch.FloatTensor = None,
|
| 1623 |
):
|
| 1624 |
-
|
| 1625 |
dtype = hidden_states.dtype
|
| 1626 |
encoder_hidden_states, encoder_attention_mask = self.encoder(
|
| 1627 |
text_hidden_states=text_hidden_states,
|
|
@@ -1709,7 +1754,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1709 |
t_ = t.unsqueeze(-1).unsqueeze(-1)
|
| 1710 |
# Interpolate: x_t = t * x1 + (1 - t) * x0
|
| 1711 |
xt = t_ * x1 + (1.0 - t_) * x0
|
| 1712 |
-
|
| 1713 |
# Predict flow (velocity) from diffusion model
|
| 1714 |
decoder_outputs = self.decoder(
|
| 1715 |
hidden_states=xt,
|
|
@@ -1726,10 +1771,10 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1726 |
return {
|
| 1727 |
"diffusion_loss": diffusion_loss,
|
| 1728 |
}
|
| 1729 |
-
|
| 1730 |
def training_losses(self, **kwargs):
|
| 1731 |
return self.forward(**kwargs)
|
| 1732 |
-
|
| 1733 |
def prepare_noise(self, context_latents: torch.FloatTensor, seed: Union[int, List[int], None] = None):
|
| 1734 |
"""
|
| 1735 |
Prepare noise tensor for generation with optional seeding.
|
|
@@ -1770,6 +1815,9 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1770 |
return noise
|
| 1771 |
|
| 1772 |
def get_x0_from_noise(self, zt, vt, t):
|
|
|
|
|
|
|
|
|
|
| 1773 |
return zt - vt * t.unsqueeze(-1).unsqueeze(-1)
|
| 1774 |
|
| 1775 |
def renoise(self, x, t, noise=None):
|
|
@@ -1779,7 +1827,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1779 |
t = t.unsqueeze(-1).unsqueeze(-1)
|
| 1780 |
xt = t * noise + (1 - t) * x
|
| 1781 |
return xt
|
| 1782 |
-
|
| 1783 |
def generate_audio(
|
| 1784 |
self,
|
| 1785 |
text_hidden_states: torch.FloatTensor,
|
|
@@ -1797,7 +1845,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1797 |
infer_method: str = "ode",
|
| 1798 |
use_cache: bool = True,
|
| 1799 |
infer_steps: int = 30,
|
| 1800 |
-
|
| 1801 |
audio_cover_strength: float = 1.0,
|
| 1802 |
non_cover_text_hidden_states: Optional[torch.FloatTensor] = None,
|
| 1803 |
non_cover_text_attention_mask: Optional[torch.FloatTensor] = None,
|
|
@@ -1809,8 +1857,27 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1809 |
use_adg: bool = False,
|
| 1810 |
shift: float = 1.0,
|
| 1811 |
cover_noise_strength: float = 0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1812 |
**kwargs,
|
| 1813 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1814 |
if attention_mask is None:
|
| 1815 |
latent_length = src_latents.shape[1]
|
| 1816 |
attention_mask = torch.ones(src_latents.shape[0], latent_length, device=src_latents.device, dtype=src_latents.dtype)
|
|
@@ -1874,7 +1941,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1874 |
bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype
|
| 1875 |
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
| 1876 |
momentum_buffer = MomentumBuffer()
|
| 1877 |
-
|
| 1878 |
# Cover noise initialization: blend noise with src_latents
|
| 1879 |
if cover_noise_strength > 0.0:
|
| 1880 |
# cover_noise_strength=1 means closest to src, so noise_level should be low
|
|
@@ -1900,16 +1967,24 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1900 |
)
|
| 1901 |
else:
|
| 1902 |
xt = noise
|
| 1903 |
-
|
| 1904 |
# main task condition
|
| 1905 |
-
do_cfg_guidance =
|
| 1906 |
if do_cfg_guidance:
|
| 1907 |
encoder_hidden_states = torch.cat([encoder_hidden_states, self.null_condition_emb.expand_as(encoder_hidden_states)], dim=0)
|
| 1908 |
encoder_attention_mask = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0)
|
| 1909 |
# src_latents
|
| 1910 |
context_latents = torch.cat([context_latents, context_latents], dim=0)
|
| 1911 |
attention_mask = torch.cat([attention_mask, attention_mask], dim=0)
|
| 1912 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1913 |
_switched_to_non_cover = False
|
| 1914 |
with torch.no_grad():
|
| 1915 |
for step_idx, (t_curr, t_prev) in enumerate(iterator):
|
|
@@ -1925,7 +2000,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1925 |
encoder_attention_mask = encoder_attention_mask_non_cover
|
| 1926 |
context_latents = context_latents_non_cover
|
| 1927 |
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
| 1928 |
-
|
| 1929 |
x = torch.cat([xt, xt], dim=0) if do_cfg_guidance else xt
|
| 1930 |
t_curr_tensor = t_curr * torch.ones((x.shape[0],), device=device, dtype=dtype)
|
| 1931 |
decoder_outputs = self.decoder(
|
|
@@ -1939,7 +2014,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1939 |
use_cache=True,
|
| 1940 |
past_key_values=past_key_values,
|
| 1941 |
)
|
| 1942 |
-
|
| 1943 |
vt = decoder_outputs[0]
|
| 1944 |
past_key_values = decoder_outputs[1]
|
| 1945 |
apply_cfg_guidance = t_curr >= cfg_interval_start and t_curr <= cfg_interval_end
|
|
@@ -1950,7 +2025,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1950 |
vt = apg_forward(
|
| 1951 |
pred_cond=pred_cond,
|
| 1952 |
pred_uncond=pred_null_cond,
|
| 1953 |
-
guidance_scale=
|
| 1954 |
momentum_buffer=momentum_buffer,
|
| 1955 |
dims=[1],
|
| 1956 |
)
|
|
@@ -1960,10 +2035,21 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1960 |
noise_pred_cond=pred_cond,
|
| 1961 |
noise_pred_uncond=pred_null_cond,
|
| 1962 |
sigma=t_curr,
|
| 1963 |
-
guidance_scale=
|
| 1964 |
)
|
| 1965 |
else:
|
| 1966 |
vt = pred_cond
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1967 |
# Update x_t based on inference method
|
| 1968 |
if infer_method == "sde":
|
| 1969 |
# Stochastic Differential Equation: predict clean, then re-add noise
|
|
@@ -1971,14 +2057,85 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1971 |
pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
|
| 1972 |
next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
|
| 1973 |
xt = self.renoise(pred_clean, next_timestep)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1974 |
elif infer_method == "ode":
|
| 1975 |
# Ordinary Differential Equation: Euler method
|
| 1976 |
# dx/dt = -v, so x_{t+1} = x_t - v_t * dt
|
| 1977 |
dt = t_curr - t_prev
|
| 1978 |
dt_tensor = dt * torch.ones((bsz,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
|
| 1979 |
xt = xt - vt * dt_tensor
|
| 1980 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1981 |
x_gen = xt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1982 |
end_time = time.time()
|
| 1983 |
time_costs["diffusion_time_cost"] = end_time - start_time
|
| 1984 |
time_costs["diffusion_per_step_time_cost"] = time_costs["diffusion_time_cost"] / infer_steps
|
|
@@ -2001,13 +2158,13 @@ def test_forward(model, seed=42):
|
|
| 2001 |
torch.cuda.manual_seed_all(seed)
|
| 2002 |
torch.backends.cudnn.deterministic = True
|
| 2003 |
torch.backends.cudnn.benchmark = False
|
| 2004 |
-
|
| 2005 |
# Get model dtype and device
|
| 2006 |
model_dtype = next(model.parameters()).dtype
|
| 2007 |
device = next(model.parameters()).device
|
| 2008 |
-
|
| 2009 |
print(f"Testing with dtype: {model_dtype}, device: {device}, seed: {seed}")
|
| 2010 |
-
|
| 2011 |
# Test data preparation with matching dtype
|
| 2012 |
text_hidden_states = torch.randn(2, 77, 1024, dtype=model_dtype, device=device)
|
| 2013 |
text_attention_mask = torch.ones(2, 77, dtype=model_dtype, device=device)
|
|
@@ -2141,4 +2298,4 @@ if __name__ == "__main__":
|
|
| 2141 |
# model = model.float()
|
| 2142 |
model = model.to("cuda")
|
| 2143 |
model = model.bfloat16()
|
| 2144 |
-
test_forward(model)
|
|
|
|
| 11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
+
import copy
|
| 15 |
import math
|
| 16 |
import time
|
| 17 |
from typing import Callable, List, Optional, Union
|
|
|
|
| 45 |
# Local config import with fallback
|
| 46 |
try:
|
| 47 |
from .configuration_acestep_v15 import AceStepConfig
|
| 48 |
+
from .apg_guidance import adg_forward, apg_forward, cfg_forward, MomentumBuffer
|
| 49 |
except ImportError:
|
| 50 |
from configuration_acestep_v15 import AceStepConfig
|
| 51 |
+
from apg_guidance import adg_forward, apg_forward, cfg_forward, MomentumBuffer
|
| 52 |
|
| 53 |
|
| 54 |
logger = logging.get_logger(__name__)
|
|
|
|
| 116 |
# We want to mask out invalid keys (columns)
|
| 117 |
# Expand shape: [Batch, 1, 1, Seq_Len]
|
| 118 |
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
|
| 119 |
+
|
| 120 |
# Broadcasting: Geometry Mask [1, 1, L, L] & Padding Mask [B, 1, 1, L]
|
| 121 |
# Result shape: [B, 1, L, L]
|
| 122 |
valid_mask = valid_mask & padding_mask_4d
|
|
|
|
| 126 |
# ------------------------------------------------------
|
| 127 |
# Get the minimal value for current dtype
|
| 128 |
min_dtype = torch.finfo(dtype).min
|
| 129 |
+
|
| 130 |
# Create result tensor filled with -inf by default
|
| 131 |
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
|
| 132 |
+
|
| 133 |
# Set valid positions to 0.0
|
| 134 |
mask_tensor.masked_fill_(valid_mask, 0.0)
|
| 135 |
+
|
| 136 |
return mask_tensor
|
| 137 |
|
| 138 |
|
|
|
|
| 201 |
class TimestepEmbedding(nn.Module):
|
| 202 |
"""
|
| 203 |
Timestep embedding module for diffusion models.
|
| 204 |
+
|
| 205 |
Converts timestep values into high-dimensional embeddings using sinusoidal
|
| 206 |
positional encoding, followed by MLP layers. Used for conditioning diffusion
|
| 207 |
models on timestep information.
|
|
|
|
| 218 |
self.act1 = nn.SiLU()
|
| 219 |
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
|
| 220 |
self.in_channels = in_channels
|
| 221 |
+
|
| 222 |
self.act2 = nn.SiLU()
|
| 223 |
self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
|
| 224 |
self.scale = scale
|
|
|
|
| 306 |
|
| 307 |
# Determine if this is cross-attention (requires encoder_hidden_states)
|
| 308 |
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
|
| 309 |
+
|
| 310 |
# Cross-attention path: attend to encoder hidden states
|
| 311 |
if is_cross_attention:
|
| 312 |
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
|
|
|
|
| 314 |
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
| 315 |
# After the first generated token, we can reuse all key/value states from cache
|
| 316 |
curr_past_key_value = past_key_value.cross_attention_cache
|
| 317 |
+
|
| 318 |
# Conditions for calculating key and value states
|
| 319 |
if not is_updated:
|
| 320 |
# Compute and cache K/V for the first time
|
|
|
|
| 332 |
# No cache used, compute K/V directly
|
| 333 |
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
|
| 334 |
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
|
| 335 |
+
|
| 336 |
# Self-attention path: attend to the same sequence
|
| 337 |
else:
|
| 338 |
# Project and normalize key/value states for self-attention
|
|
|
|
| 354 |
attention_interface: Callable = eager_attention_forward
|
| 355 |
elif self.config._attn_implementation != "eager":
|
| 356 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 357 |
+
|
| 358 |
attn_output, attn_weights = attention_interface(
|
| 359 |
self,
|
| 360 |
query_states,
|
|
|
|
| 444 |
class AceStepDiTLayer(GradientCheckpointingLayer):
|
| 445 |
"""
|
| 446 |
DiT (Diffusion Transformer) layer for AceStep model.
|
| 447 |
+
|
| 448 |
Implements a transformer layer with three main components:
|
| 449 |
1. Self-attention with adaptive layer norm (AdaLN)
|
| 450 |
2. Cross-attention (optional) for conditioning on encoder outputs
|
| 451 |
3. Feed-forward MLP with adaptive layer norm
|
| 452 |
+
|
| 453 |
Uses scale-shift modulation from timestep embeddings for adaptive normalization.
|
| 454 |
"""
|
| 455 |
def __init__(self, config: AceStepConfig, layer_idx: int, use_cross_attention: bool = True):
|
|
|
|
| 472 |
# Scale-shift table for adaptive layer norm modulation (6 values: 3 for self-attn, 3 for MLP)
|
| 473 |
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5)
|
| 474 |
self.attention_type = config.layer_types[layer_idx]
|
| 475 |
+
|
| 476 |
def forward(
|
| 477 |
self,
|
| 478 |
hidden_states: torch.Tensor,
|
|
|
|
| 578 |
class AceStepLyricEncoder(AceStepPreTrainedModel):
|
| 579 |
"""
|
| 580 |
Encoder for processing lyric text embeddings.
|
| 581 |
+
|
| 582 |
Encodes lyric text hidden states using a transformer encoder architecture
|
| 583 |
with bidirectional attention. Projects text embeddings to model hidden size
|
| 584 |
and processes them through multiple encoder layers.
|
| 585 |
"""
|
| 586 |
def __init__(self, config):
|
| 587 |
super().__init__(config)
|
| 588 |
+
|
| 589 |
# Project text embeddings to model hidden size
|
| 590 |
self.embed_tokens = nn.Linear(config.text_hidden_dim, config.hidden_size)
|
| 591 |
self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
| 619 |
assert input_ids is None, "Only `input_ids` is supported for the lyric encoder."
|
| 620 |
assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
|
| 621 |
assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."
|
| 622 |
+
|
| 623 |
# Project input embeddings: N x T x text_hidden_dim -> N x T x hidden_size
|
| 624 |
inputs_embeds = self.embed_tokens(inputs_embeds)
|
| 625 |
# Cache position: only used for mask construction (not for actual caching)
|
|
|
|
| 633 |
seq_len = inputs_embeds.shape[1]
|
| 634 |
dtype = inputs_embeds.dtype
|
| 635 |
device = inputs_embeds.device
|
| 636 |
+
|
| 637 |
# 判断是否使用 Flash Attention 2
|
| 638 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 639 |
|
|
|
|
| 650 |
# 如果没有 padding mask,传 None。
|
| 651 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 652 |
full_attn_mask = attention_mask
|
| 653 |
+
|
| 654 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 655 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 656 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
|
|
| 660 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 661 |
# -------------------------------------------------------
|
| 662 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 663 |
+
|
| 664 |
# 1. Full Attention (Bidirectional, Global)
|
| 665 |
# 对应原来的 create_causal_mask + bidirectional
|
| 666 |
full_attn_mask = create_4d_mask(
|
|
|
|
| 735 |
class AttentionPooler(AceStepPreTrainedModel):
|
| 736 |
"""
|
| 737 |
Attention-based pooling module.
|
| 738 |
+
|
| 739 |
Pools sequences of patches using a special token and attention mechanism.
|
| 740 |
The special token attends to all patches and its output is used as the
|
| 741 |
pooled representation. Used for aggregating patch-level features into
|
|
|
|
| 783 |
seq_len = x.shape[1]
|
| 784 |
dtype = x.dtype
|
| 785 |
device = x.device
|
| 786 |
+
|
| 787 |
# 判断是否使用 Flash Attention 2
|
| 788 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 789 |
|
|
|
|
| 800 |
# 如果没有 padding mask,传 None。
|
| 801 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 802 |
full_attn_mask = attention_mask
|
| 803 |
+
|
| 804 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 805 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 806 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
|
|
| 810 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 811 |
# -------------------------------------------------------
|
| 812 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 813 |
+
|
| 814 |
# 1. Full Attention (Bidirectional, Global)
|
| 815 |
# 对应原来的 create_causal_mask + bidirectional
|
| 816 |
full_attn_mask = create_4d_mask(
|
|
|
|
| 841 |
"full_attention": full_attn_mask,
|
| 842 |
"sliding_attention": sliding_attn_mask,
|
| 843 |
}
|
| 844 |
+
|
| 845 |
for layer_module in self.layers:
|
| 846 |
layer_outputs = layer_module(
|
| 847 |
hidden_states,
|
|
|
|
| 853 |
hidden_states = layer_outputs[0]
|
| 854 |
|
| 855 |
hidden_states = self.norm(hidden_states)
|
| 856 |
+
|
| 857 |
# Extract the special token output (first position) as pooled representation
|
| 858 |
cls_output = hidden_states[:, 0, :]
|
| 859 |
cls_output = rearrange(cls_output, "(b t) c -> b t c", b=B)
|
|
|
|
| 863 |
class AudioTokenDetokenizer(AceStepPreTrainedModel):
|
| 864 |
"""
|
| 865 |
Audio token detokenizer module.
|
| 866 |
+
|
| 867 |
Converts quantized audio tokens back to continuous acoustic representations.
|
| 868 |
Expands each token into multiple patches using special tokens, processes them
|
| 869 |
through encoder layers, and projects to acoustic hidden dimension.
|
|
|
|
| 918 |
seq_len = x.shape[1]
|
| 919 |
dtype = x.dtype
|
| 920 |
device = x.device
|
| 921 |
+
|
| 922 |
# 判断是否使用 Flash Attention 2
|
| 923 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 924 |
|
|
|
|
| 935 |
# 如果没有 padding mask,传 None。
|
| 936 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 937 |
full_attn_mask = attention_mask
|
| 938 |
+
|
| 939 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 940 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 941 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
|
|
| 945 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 946 |
# -------------------------------------------------------
|
| 947 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 948 |
+
|
| 949 |
# 1. Full Attention (Bidirectional, Global)
|
| 950 |
# 对应原来的 create_causal_mask + bidirectional
|
| 951 |
full_attn_mask = create_4d_mask(
|
|
|
|
| 976 |
"full_attention": full_attn_mask,
|
| 977 |
"sliding_attention": sliding_attn_mask,
|
| 978 |
}
|
| 979 |
+
|
| 980 |
for layer_module in self.layers:
|
| 981 |
layer_outputs = layer_module(
|
| 982 |
hidden_states,
|
|
|
|
| 988 |
hidden_states = layer_outputs[0]
|
| 989 |
|
| 990 |
hidden_states = self.norm(hidden_states)
|
| 991 |
+
|
| 992 |
hidden_states = self.proj_out(hidden_states)
|
| 993 |
|
| 994 |
hidden_states = rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.config.pool_window_size)
|
|
|
|
| 998 |
class AceStepTimbreEncoder(AceStepPreTrainedModel):
|
| 999 |
"""
|
| 1000 |
Encoder for extracting timbre embeddings from reference audio.
|
| 1001 |
+
|
| 1002 |
Processes packed reference audio acoustic features to extract timbre
|
| 1003 |
representations. Uses a special token (CLS-like) to aggregate information
|
| 1004 |
from the entire reference audio sequence. Outputs are unpacked back to
|
|
|
|
| 1006 |
"""
|
| 1007 |
def __init__(self, config):
|
| 1008 |
super().__init__(config)
|
| 1009 |
+
|
| 1010 |
# Project acoustic features to model hidden size
|
| 1011 |
self.embed_tokens = nn.Linear(config.timbre_hidden_dim, config.hidden_size)
|
| 1012 |
self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
| 1037 |
N, d = timbre_embs_packed.shape
|
| 1038 |
device = timbre_embs_packed.device
|
| 1039 |
dtype = timbre_embs_packed.dtype
|
| 1040 |
+
|
| 1041 |
# Get batch size
|
| 1042 |
B = int(refer_audio_order_mask.max().item() + 1)
|
| 1043 |
+
|
| 1044 |
# Calculate element count and positions for each batch
|
| 1045 |
counts = torch.bincount(refer_audio_order_mask, minlength=B)
|
| 1046 |
max_count = counts.max().item()
|
| 1047 |
+
|
| 1048 |
# Calculate positions within batch
|
| 1049 |
sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
|
| 1050 |
sorted_batch_ids = refer_audio_order_mask[sorted_indices]
|
| 1051 |
+
|
| 1052 |
positions = torch.arange(N, device=device)
|
| 1053 |
+
batch_starts = torch.cat([torch.tensor([0], device=device),
|
| 1054 |
torch.cumsum(counts, dim=0)[:-1]])
|
| 1055 |
positions_in_sorted = positions - batch_starts[sorted_batch_ids]
|
| 1056 |
+
|
| 1057 |
inverse_indices = torch.empty_like(sorted_indices)
|
| 1058 |
inverse_indices[sorted_indices] = torch.arange(N, device=device)
|
| 1059 |
positions_in_batch = positions_in_sorted[inverse_indices]
|
| 1060 |
+
|
| 1061 |
# Use one-hot encoding and matrix multiplication (gradient-friendly approach)
|
| 1062 |
# Create one-hot encoding
|
| 1063 |
indices_2d = refer_audio_order_mask * max_count + positions_in_batch # (N,)
|
| 1064 |
one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype) # (N, B*max_count)
|
| 1065 |
+
|
| 1066 |
# Rearrange using matrix multiplication
|
| 1067 |
timbre_embs_flat = one_hot.t() @ timbre_embs_packed # (B*max_count, d)
|
| 1068 |
timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
|
| 1069 |
+
|
| 1070 |
# Create mask indicating valid positions
|
| 1071 |
mask_flat = (one_hot.sum(dim=0) > 0).long() # (B*max_count,)
|
| 1072 |
new_mask = mask_flat.reshape(B, max_count)
|
| 1073 |
+
|
| 1074 |
return timbre_embs_unpack, new_mask
|
| 1075 |
|
| 1076 |
@can_return_tuple
|
|
|
|
| 1094 |
seq_len = inputs_embeds.shape[1]
|
| 1095 |
dtype = inputs_embeds.dtype
|
| 1096 |
device = inputs_embeds.device
|
| 1097 |
+
|
| 1098 |
# 判断是否使用 Flash Attention 2
|
| 1099 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 1100 |
|
|
|
|
| 1111 |
# 如果没有 padding mask,传 None。
|
| 1112 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 1113 |
full_attn_mask = attention_mask
|
| 1114 |
+
|
| 1115 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 1116 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 1117 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
|
|
| 1121 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 1122 |
# -------------------------------------------------------
|
| 1123 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 1124 |
+
|
| 1125 |
# 1. Full Attention (Bidirectional, Global)
|
| 1126 |
# 对应原来的 create_causal_mask + bidirectional
|
| 1127 |
full_attn_mask = create_4d_mask(
|
|
|
|
| 1152 |
"full_attention": full_attn_mask,
|
| 1153 |
"sliding_attention": sliding_attn_mask,
|
| 1154 |
}
|
| 1155 |
+
|
| 1156 |
# Initialize hidden states
|
| 1157 |
hidden_states = inputs_embeds
|
| 1158 |
|
|
|
|
| 1182 |
class AceStepAudioTokenizer(AceStepPreTrainedModel):
|
| 1183 |
"""
|
| 1184 |
Audio tokenizer module.
|
| 1185 |
+
|
| 1186 |
Converts continuous acoustic features into discrete quantized tokens.
|
| 1187 |
Process: project -> pool patches -> quantize. Used for converting audio
|
| 1188 |
representations into discrete tokens for processing by the diffusion model.
|
|
|
|
| 1209 |
hidden_states: Optional[torch.FloatTensor] = None,
|
| 1210 |
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 1211 |
) -> BaseModelOutput:
|
| 1212 |
+
|
| 1213 |
# Project acoustic features to hidden size
|
| 1214 |
hidden_states = self.audio_acoustic_proj(hidden_states)
|
| 1215 |
# Pool sequences: N x T//pool_window_size x pool_window_size x d -> N x T//pool_window_size x d
|
|
|
|
| 1226 |
class Lambda(nn.Module):
|
| 1227 |
"""
|
| 1228 |
Wrapper module for arbitrary lambda functions.
|
| 1229 |
+
|
| 1230 |
Allows using lambda functions in nn.Sequential by wrapping them in a Module.
|
| 1231 |
Useful for simple transformations like transpose operations.
|
| 1232 |
"""
|
| 1233 |
def __init__(self, func):
|
| 1234 |
super().__init__()
|
| 1235 |
self.func = func
|
| 1236 |
+
|
| 1237 |
def forward(self, x):
|
| 1238 |
return self.func(x)
|
| 1239 |
|
|
|
|
| 1241 |
class AceStepDiTModel(AceStepPreTrainedModel):
|
| 1242 |
"""
|
| 1243 |
DiT (Diffusion Transformer) model for AceStep.
|
| 1244 |
+
|
| 1245 |
Main diffusion model that generates audio latents conditioned on text, lyrics,
|
| 1246 |
and timbre. Uses patch-based processing with transformer layers, timestep
|
| 1247 |
conditioning, and cross-attention to encoder outputs.
|
|
|
|
| 1259 |
inner_dim = config.hidden_size
|
| 1260 |
patch_size = config.patch_size
|
| 1261 |
self.patch_size = patch_size
|
| 1262 |
+
|
| 1263 |
# Input projection: patch embedding using 1D convolution
|
| 1264 |
# Converts sequence into patches for efficient processing
|
| 1265 |
self.proj_in = nn.Sequential(
|
|
|
|
| 1278 |
# Two embeddings: one for timestep t, one for timestep difference (t - r)
|
| 1279 |
self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
|
| 1280 |
self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
|
| 1281 |
+
|
| 1282 |
# Project encoder hidden states to model dimension
|
| 1283 |
+
condition_dim = getattr(config, "encoder_hidden_size", None) or config.hidden_size
|
| 1284 |
+
self.condition_embedder = nn.Linear(condition_dim, inner_dim, bias=True)
|
| 1285 |
|
| 1286 |
# Output normalization and projection
|
| 1287 |
# Adaptive layer norm with scale-shift modulation, then de-patchify
|
|
|
|
| 1332 |
use_cache = False
|
| 1333 |
if self.training:
|
| 1334 |
use_cache = False
|
| 1335 |
+
|
| 1336 |
# Initialize cache if needed (only during inference for auto-regressive generation)
|
| 1337 |
if not self.training and use_cache and past_key_values is None:
|
| 1338 |
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
|
|
|
| 1359 |
# Project input to patches and project encoder states
|
| 1360 |
hidden_states = self.proj_in(hidden_states)
|
| 1361 |
encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
|
| 1362 |
+
|
| 1363 |
# Cache positions
|
| 1364 |
if cache_position is None:
|
| 1365 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1366 |
cache_position = torch.arange(
|
| 1367 |
past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
|
| 1368 |
)
|
| 1369 |
+
|
| 1370 |
# Position IDs
|
| 1371 |
if position_ids is None:
|
| 1372 |
position_ids = cache_position.unsqueeze(0)
|
|
|
|
| 1376 |
encoder_seq_len = encoder_hidden_states.shape[1]
|
| 1377 |
dtype = hidden_states.dtype
|
| 1378 |
device = hidden_states.device
|
| 1379 |
+
|
| 1380 |
# 判断是否使用 Flash Attention 2
|
| 1381 |
is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
|
| 1382 |
|
|
|
|
| 1394 |
# 如果没有 padding mask,传 None。
|
| 1395 |
# 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
|
| 1396 |
full_attn_mask = attention_mask
|
| 1397 |
+
|
| 1398 |
# 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
|
| 1399 |
# Layer 会自己决定是否调用带 sliding window 的 kernel
|
| 1400 |
sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
|
|
|
|
| 1404 |
# 场景 B: CPU / Mac / SDPA (Eager 模式)
|
| 1405 |
# -------------------------------------------------------
|
| 1406 |
# 必须手动生成 4D Mask [B, 1, L, L]
|
| 1407 |
+
|
| 1408 |
# 1. Full Attention (Bidirectional, Global)
|
| 1409 |
# 对应原来的 create_causal_mask + bidirectional
|
| 1410 |
full_attn_mask = create_4d_mask(
|
|
|
|
| 1417 |
is_causal=False # <--- 关键:双向注意力
|
| 1418 |
)
|
| 1419 |
max_len = max(seq_len, encoder_seq_len)
|
| 1420 |
+
|
| 1421 |
encoder_attention_mask = create_4d_mask(
|
| 1422 |
seq_len=max_len,
|
| 1423 |
dtype=dtype,
|
|
|
|
| 1485 |
# Extract the last element which is cross_attn_weights
|
| 1486 |
if len(layer_outputs) >= 3:
|
| 1487 |
all_cross_attentions += (layer_outputs[2],)
|
| 1488 |
+
|
| 1489 |
if return_hidden_states:
|
| 1490 |
return hidden_states
|
| 1491 |
|
|
|
|
| 1498 |
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
|
| 1499 |
# Project output: de-patchify back to original sequence format
|
| 1500 |
hidden_states = self.proj_out(hidden_states)
|
| 1501 |
+
|
| 1502 |
# Crop back to original sequence length to ensure exact length match (remove padding)
|
| 1503 |
hidden_states = hidden_states[:, :original_seq_len, :]
|
| 1504 |
+
|
| 1505 |
outputs = (hidden_states, past_key_values)
|
| 1506 |
|
| 1507 |
if output_attentions:
|
|
|
|
| 1511 |
class AceStepConditionEncoder(AceStepPreTrainedModel):
|
| 1512 |
"""
|
| 1513 |
Condition encoder for AceStep model.
|
| 1514 |
+
|
| 1515 |
Encodes multiple conditioning inputs (text, lyrics, timbre) and packs them
|
| 1516 |
into a single sequence for cross-attention in the diffusion model. Handles
|
| 1517 |
projection, encoding, and sequence packing.
|
|
|
|
| 1556 |
return encoder_hidden_states, encoder_attention_mask
|
| 1557 |
|
| 1558 |
|
| 1559 |
+
def _repaint_step_injection(xt, clean_src, mask, t_next, noise):
|
| 1560 |
+
"""Replace non-repaint regions of *xt* with noised source latents."""
|
| 1561 |
+
zt = t_next * noise + (1.0 - t_next) * clean_src
|
| 1562 |
+
m = mask.unsqueeze(-1).expand_as(xt)
|
| 1563 |
+
return torch.where(m, xt, zt)
|
| 1564 |
+
|
| 1565 |
+
|
| 1566 |
+
def _repaint_boundary_blend(x_gen, clean_src, mask, cf_frames):
|
| 1567 |
+
"""Blend generated latents with source at repaint boundaries."""
|
| 1568 |
+
soft = mask.float().clone()
|
| 1569 |
+
if cf_frames <= 0:
|
| 1570 |
+
m = soft.unsqueeze(-1).expand_as(x_gen)
|
| 1571 |
+
return m * x_gen + (1.0 - m) * clean_src
|
| 1572 |
+
B, T = mask.shape
|
| 1573 |
+
for b in range(B):
|
| 1574 |
+
row = mask[b]
|
| 1575 |
+
if row.all() or not row.any():
|
| 1576 |
+
continue
|
| 1577 |
+
idx = torch.nonzero(row, as_tuple=False).squeeze(-1)
|
| 1578 |
+
if idx.numel() == 0:
|
| 1579 |
+
continue
|
| 1580 |
+
left, right = idx[0].item(), idx[-1].item() + 1
|
| 1581 |
+
fs = max(left - cf_frames, 0)
|
| 1582 |
+
if left - fs > 0:
|
| 1583 |
+
soft[b, fs:left] = torch.linspace(0, 1, left - fs + 2, device=soft.device)[1:-1]
|
| 1584 |
+
fe = min(right + cf_frames, T)
|
| 1585 |
+
if fe - right > 0:
|
| 1586 |
+
soft[b, right:fe] = torch.linspace(1, 0, fe - right + 2, device=soft.device)[1:-1]
|
| 1587 |
+
m = soft.unsqueeze(-1).expand_as(x_gen)
|
| 1588 |
+
return m * x_gen + (1.0 - m) * clean_src
|
| 1589 |
+
|
| 1590 |
+
|
| 1591 |
class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
| 1592 |
"""
|
| 1593 |
Main conditional generation model for AceStep.
|
| 1594 |
+
|
| 1595 |
End-to-end model for generating audio conditioned on text, lyrics, and timbre.
|
| 1596 |
Combines encoder (for conditioning), decoder (diffusion model), tokenizer
|
| 1597 |
(for discrete tokenization), and detokenizer (for reconstruction).
|
|
|
|
| 1602 |
self.config = config
|
| 1603 |
# Diffusion model components
|
| 1604 |
self.decoder = AceStepDiTModel(config) # Main diffusion transformer
|
| 1605 |
+
# Build encoder config: use separate encoder_hidden_size when available
|
| 1606 |
+
# (4B models have encoder_hidden_size=2048 != hidden_size=2560)
|
| 1607 |
+
_enc_hs = getattr(config, "encoder_hidden_size", None) or config.hidden_size
|
| 1608 |
+
if _enc_hs != config.hidden_size:
|
| 1609 |
+
encoder_config = copy.deepcopy(config)
|
| 1610 |
+
encoder_config.hidden_size = _enc_hs
|
| 1611 |
+
encoder_config.intermediate_size = getattr(config, "encoder_intermediate_size", None) or config.intermediate_size
|
| 1612 |
+
encoder_config.num_attention_heads = getattr(config, "encoder_num_attention_heads", None) or config.num_attention_heads
|
| 1613 |
+
encoder_config.num_key_value_heads = getattr(config, "encoder_num_key_value_heads", None) or config.num_key_value_heads
|
| 1614 |
+
else:
|
| 1615 |
+
encoder_config = config
|
| 1616 |
+
self.encoder = AceStepConditionEncoder(encoder_config) # Condition encoder
|
| 1617 |
+
self.tokenizer = AceStepAudioTokenizer(encoder_config) # Audio tokenizer
|
| 1618 |
+
self.detokenizer = AudioTokenDetokenizer(encoder_config) # Audio detokenizer
|
| 1619 |
# Null condition embedding for classifier-free guidance
|
| 1620 |
+
self.null_condition_emb = nn.Parameter(torch.randn(1, 1, encoder_config.hidden_size))
|
| 1621 |
|
| 1622 |
# Initialize weights and apply final processing
|
| 1623 |
self.post_init()
|
|
|
|
| 1666 |
precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
|
| 1667 |
audio_codes: torch.FloatTensor = None,
|
| 1668 |
):
|
| 1669 |
+
|
| 1670 |
dtype = hidden_states.dtype
|
| 1671 |
encoder_hidden_states, encoder_attention_mask = self.encoder(
|
| 1672 |
text_hidden_states=text_hidden_states,
|
|
|
|
| 1754 |
t_ = t.unsqueeze(-1).unsqueeze(-1)
|
| 1755 |
# Interpolate: x_t = t * x1 + (1 - t) * x0
|
| 1756 |
xt = t_ * x1 + (1.0 - t_) * x0
|
| 1757 |
+
|
| 1758 |
# Predict flow (velocity) from diffusion model
|
| 1759 |
decoder_outputs = self.decoder(
|
| 1760 |
hidden_states=xt,
|
|
|
|
| 1771 |
return {
|
| 1772 |
"diffusion_loss": diffusion_loss,
|
| 1773 |
}
|
| 1774 |
+
|
| 1775 |
def training_losses(self, **kwargs):
|
| 1776 |
return self.forward(**kwargs)
|
| 1777 |
+
|
| 1778 |
def prepare_noise(self, context_latents: torch.FloatTensor, seed: Union[int, List[int], None] = None):
|
| 1779 |
"""
|
| 1780 |
Prepare noise tensor for generation with optional seeding.
|
|
|
|
| 1815 |
return noise
|
| 1816 |
|
| 1817 |
def get_x0_from_noise(self, zt, vt, t):
|
| 1818 |
+
if t.shape[0] != zt.shape[0]:
|
| 1819 |
+
raise ValueError(f"Batch size mismatch: t has {t.shape[0]}, zt has {zt.shape[0]}")
|
| 1820 |
+
|
| 1821 |
return zt - vt * t.unsqueeze(-1).unsqueeze(-1)
|
| 1822 |
|
| 1823 |
def renoise(self, x, t, noise=None):
|
|
|
|
| 1827 |
t = t.unsqueeze(-1).unsqueeze(-1)
|
| 1828 |
xt = t * noise + (1 - t) * x
|
| 1829 |
return xt
|
| 1830 |
+
|
| 1831 |
def generate_audio(
|
| 1832 |
self,
|
| 1833 |
text_hidden_states: torch.FloatTensor,
|
|
|
|
| 1845 |
infer_method: str = "ode",
|
| 1846 |
use_cache: bool = True,
|
| 1847 |
infer_steps: int = 30,
|
| 1848 |
+
diffusion_guidance_scale: float = 7.0,
|
| 1849 |
audio_cover_strength: float = 1.0,
|
| 1850 |
non_cover_text_hidden_states: Optional[torch.FloatTensor] = None,
|
| 1851 |
non_cover_text_attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
| 1857 |
use_adg: bool = False,
|
| 1858 |
shift: float = 1.0,
|
| 1859 |
cover_noise_strength: float = 0.0,
|
| 1860 |
+
repaint_mask: Optional[torch.Tensor] = None,
|
| 1861 |
+
clean_src_latents: Optional[torch.FloatTensor] = None,
|
| 1862 |
+
repaint_crossfade_frames: int = 10,
|
| 1863 |
+
repaint_injection_ratio: float = 0.5,
|
| 1864 |
+
sampler_mode: str = "euler",
|
| 1865 |
+
velocity_norm_threshold: float = 0.0,
|
| 1866 |
+
velocity_ema_factor: float = 0.0,
|
| 1867 |
**kwargs,
|
| 1868 |
):
|
| 1869 |
+
# Backward-compat: accept the old misspelled key "diffusion_guidance_sale"
|
| 1870 |
+
# so that callers that have not yet updated their code still work correctly.
|
| 1871 |
+
# Note: if both keys are passed simultaneously, the old key wins because Python
|
| 1872 |
+
# cannot distinguish "explicit new key" from "new key at its default value".
|
| 1873 |
+
# In practice callers should only ever pass one of the two.
|
| 1874 |
+
if "diffusion_guidance_sale" in kwargs:
|
| 1875 |
+
logger.warning(
|
| 1876 |
+
"generate_audio() received deprecated kwarg 'diffusion_guidance_sale'; "
|
| 1877 |
+
"please rename it to 'diffusion_guidance_scale'."
|
| 1878 |
+
)
|
| 1879 |
+
diffusion_guidance_scale = kwargs.pop("diffusion_guidance_sale")
|
| 1880 |
+
|
| 1881 |
if attention_mask is None:
|
| 1882 |
latent_length = src_latents.shape[1]
|
| 1883 |
attention_mask = torch.ones(src_latents.shape[0], latent_length, device=src_latents.device, dtype=src_latents.dtype)
|
|
|
|
| 1941 |
bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype
|
| 1942 |
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
| 1943 |
momentum_buffer = MomentumBuffer()
|
| 1944 |
+
|
| 1945 |
# Cover noise initialization: blend noise with src_latents
|
| 1946 |
if cover_noise_strength > 0.0:
|
| 1947 |
# cover_noise_strength=1 means closest to src, so noise_level should be low
|
|
|
|
| 1967 |
)
|
| 1968 |
else:
|
| 1969 |
xt = noise
|
| 1970 |
+
|
| 1971 |
# main task condition
|
| 1972 |
+
do_cfg_guidance = diffusion_guidance_scale > 1.0
|
| 1973 |
if do_cfg_guidance:
|
| 1974 |
encoder_hidden_states = torch.cat([encoder_hidden_states, self.null_condition_emb.expand_as(encoder_hidden_states)], dim=0)
|
| 1975 |
encoder_attention_mask = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0)
|
| 1976 |
# src_latents
|
| 1977 |
context_latents = torch.cat([context_latents, context_latents], dim=0)
|
| 1978 |
attention_mask = torch.cat([attention_mask, attention_mask], dim=0)
|
| 1979 |
+
|
| 1980 |
+
use_heun = sampler_mode == "heun"
|
| 1981 |
+
use_norm_clamp = velocity_norm_threshold > 0.0
|
| 1982 |
+
use_ema = velocity_ema_factor > 0.0
|
| 1983 |
+
prev_vt = None
|
| 1984 |
+
if use_heun and infer_method == "sde":
|
| 1985 |
+
logger.warning("Heun sampler is not compatible with SDE; falling back to Euler.")
|
| 1986 |
+
use_heun = False
|
| 1987 |
+
|
| 1988 |
_switched_to_non_cover = False
|
| 1989 |
with torch.no_grad():
|
| 1990 |
for step_idx, (t_curr, t_prev) in enumerate(iterator):
|
|
|
|
| 2000 |
encoder_attention_mask = encoder_attention_mask_non_cover
|
| 2001 |
context_latents = context_latents_non_cover
|
| 2002 |
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
| 2003 |
+
|
| 2004 |
x = torch.cat([xt, xt], dim=0) if do_cfg_guidance else xt
|
| 2005 |
t_curr_tensor = t_curr * torch.ones((x.shape[0],), device=device, dtype=dtype)
|
| 2006 |
decoder_outputs = self.decoder(
|
|
|
|
| 2014 |
use_cache=True,
|
| 2015 |
past_key_values=past_key_values,
|
| 2016 |
)
|
| 2017 |
+
|
| 2018 |
vt = decoder_outputs[0]
|
| 2019 |
past_key_values = decoder_outputs[1]
|
| 2020 |
apply_cfg_guidance = t_curr >= cfg_interval_start and t_curr <= cfg_interval_end
|
|
|
|
| 2025 |
vt = apg_forward(
|
| 2026 |
pred_cond=pred_cond,
|
| 2027 |
pred_uncond=pred_null_cond,
|
| 2028 |
+
guidance_scale=diffusion_guidance_scale,
|
| 2029 |
momentum_buffer=momentum_buffer,
|
| 2030 |
dims=[1],
|
| 2031 |
)
|
|
|
|
| 2035 |
noise_pred_cond=pred_cond,
|
| 2036 |
noise_pred_uncond=pred_null_cond,
|
| 2037 |
sigma=t_curr,
|
| 2038 |
+
guidance_scale=diffusion_guidance_scale,
|
| 2039 |
)
|
| 2040 |
else:
|
| 2041 |
vt = pred_cond
|
| 2042 |
+
# Velocity norm clamping — prevents outlier predictions
|
| 2043 |
+
if use_norm_clamp:
|
| 2044 |
+
vt_norm = torch.norm(vt, dim=(1, 2), keepdim=True)
|
| 2045 |
+
xt_norm = torch.norm(xt, dim=(1, 2), keepdim=True) + 1e-10
|
| 2046 |
+
scale = torch.clamp(velocity_norm_threshold * xt_norm / (vt_norm + 1e-10), max=1.0)
|
| 2047 |
+
vt = vt * scale
|
| 2048 |
+
|
| 2049 |
+
# Velocity EMA smoothing — stabilises denoising trajectory
|
| 2050 |
+
if use_ema and prev_vt is not None:
|
| 2051 |
+
vt = (1.0 - velocity_ema_factor) * vt + velocity_ema_factor * prev_vt
|
| 2052 |
+
|
| 2053 |
# Update x_t based on inference method
|
| 2054 |
if infer_method == "sde":
|
| 2055 |
# Stochastic Differential Equation: predict clean, then re-add noise
|
|
|
|
| 2057 |
pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
|
| 2058 |
next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
|
| 2059 |
xt = self.renoise(pred_clean, next_timestep)
|
| 2060 |
+
t_after_step = next_timestep
|
| 2061 |
+
elif use_heun and infer_method == "ode":
|
| 2062 |
+
# Heun (second-order) ODE step via trapezoidal rule
|
| 2063 |
+
dt = t_curr - t_prev
|
| 2064 |
+
dt_tensor = dt * torch.ones((bsz,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
|
| 2065 |
+
# Predictor: Euler step to get xt_predicted at t_prev
|
| 2066 |
+
xt_predicted = xt - vt * dt_tensor
|
| 2067 |
+
# Corrector: evaluate model at the predicted point
|
| 2068 |
+
x2 = torch.cat([xt_predicted, xt_predicted], dim=0) if do_cfg_guidance else xt_predicted
|
| 2069 |
+
t_prev_tensor = t_prev * torch.ones((x2.shape[0],), device=device, dtype=dtype)
|
| 2070 |
+
# Reset KV cache for corrector pass (Heun needs fresh evaluation)
|
| 2071 |
+
corrector_kv = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
| 2072 |
+
decoder_outputs2 = self.decoder(
|
| 2073 |
+
hidden_states=x2,
|
| 2074 |
+
timestep=t_prev_tensor,
|
| 2075 |
+
timestep_r=t_prev_tensor,
|
| 2076 |
+
attention_mask=attention_mask,
|
| 2077 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 2078 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 2079 |
+
context_latents=context_latents,
|
| 2080 |
+
use_cache=False,
|
| 2081 |
+
past_key_values=corrector_kv,
|
| 2082 |
+
)
|
| 2083 |
+
vt2 = decoder_outputs2[0]
|
| 2084 |
+
if do_cfg_guidance:
|
| 2085 |
+
pred_cond2, pred_null_cond2 = vt2.chunk(2)
|
| 2086 |
+
# Recompute CFG interval for corrector timestep (t_prev, not t_curr)
|
| 2087 |
+
apply_cfg_corrector = t_prev >= cfg_interval_start and t_prev <= cfg_interval_end
|
| 2088 |
+
if apply_cfg_corrector:
|
| 2089 |
+
if not use_adg:
|
| 2090 |
+
# Use basic CFG for corrector to avoid mutating APG momentum twice per step
|
| 2091 |
+
vt2 = cfg_forward(pred_cond2, pred_null_cond2, diffusion_guidance_scale)
|
| 2092 |
+
elif t_prev > 0:
|
| 2093 |
+
# Guard against sigma=0 which causes NaN in ADG division
|
| 2094 |
+
vt2 = adg_forward(
|
| 2095 |
+
latents=xt_predicted,
|
| 2096 |
+
noise_pred_cond=pred_cond2,
|
| 2097 |
+
noise_pred_uncond=pred_null_cond2,
|
| 2098 |
+
sigma=t_prev,
|
| 2099 |
+
guidance_scale=diffusion_guidance_scale,
|
| 2100 |
+
)
|
| 2101 |
+
else:
|
| 2102 |
+
vt2 = cfg_forward(pred_cond2, pred_null_cond2, diffusion_guidance_scale)
|
| 2103 |
+
else:
|
| 2104 |
+
vt2 = pred_cond2
|
| 2105 |
+
if use_norm_clamp:
|
| 2106 |
+
vt2_norm = torch.norm(vt2, dim=(1, 2), keepdim=True)
|
| 2107 |
+
xt_pred_norm = torch.norm(xt_predicted, dim=(1, 2), keepdim=True) + 1e-10
|
| 2108 |
+
scale2 = torch.clamp(velocity_norm_threshold * xt_pred_norm / (vt2_norm + 1e-10), max=1.0)
|
| 2109 |
+
vt2 = vt2 * scale2
|
| 2110 |
+
if use_ema:
|
| 2111 |
+
vt2 = (1.0 - velocity_ema_factor) * vt2 + velocity_ema_factor * vt
|
| 2112 |
+
# Average the two velocity predictions (trapezoidal rule)
|
| 2113 |
+
vt_avg = 0.5 * (vt + vt2)
|
| 2114 |
+
xt = xt - vt_avg * dt_tensor
|
| 2115 |
+
vt = vt_avg
|
| 2116 |
+
t_after_step = t_prev
|
| 2117 |
elif infer_method == "ode":
|
| 2118 |
# Ordinary Differential Equation: Euler method
|
| 2119 |
# dx/dt = -v, so x_{t+1} = x_t - v_t * dt
|
| 2120 |
dt = t_curr - t_prev
|
| 2121 |
dt_tensor = dt * torch.ones((bsz,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
|
| 2122 |
xt = xt - vt * dt_tensor
|
| 2123 |
+
t_after_step = t_prev
|
| 2124 |
+
|
| 2125 |
+
prev_vt = vt
|
| 2126 |
+
|
| 2127 |
+
injection_cutoff = round(repaint_injection_ratio * infer_steps)
|
| 2128 |
+
if repaint_mask is not None and clean_src_latents is not None and step_idx < injection_cutoff:
|
| 2129 |
+
xt = _repaint_step_injection(
|
| 2130 |
+
xt, clean_src_latents, repaint_mask, t_after_step, noise,
|
| 2131 |
+
)
|
| 2132 |
+
|
| 2133 |
x_gen = xt
|
| 2134 |
+
if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0:
|
| 2135 |
+
x_gen = _repaint_boundary_blend(
|
| 2136 |
+
x_gen, clean_src_latents, repaint_mask, repaint_crossfade_frames,
|
| 2137 |
+
)
|
| 2138 |
+
|
| 2139 |
end_time = time.time()
|
| 2140 |
time_costs["diffusion_time_cost"] = end_time - start_time
|
| 2141 |
time_costs["diffusion_per_step_time_cost"] = time_costs["diffusion_time_cost"] / infer_steps
|
|
|
|
| 2158 |
torch.cuda.manual_seed_all(seed)
|
| 2159 |
torch.backends.cudnn.deterministic = True
|
| 2160 |
torch.backends.cudnn.benchmark = False
|
| 2161 |
+
|
| 2162 |
# Get model dtype and device
|
| 2163 |
model_dtype = next(model.parameters()).dtype
|
| 2164 |
device = next(model.parameters()).device
|
| 2165 |
+
|
| 2166 |
print(f"Testing with dtype: {model_dtype}, device: {device}, seed: {seed}")
|
| 2167 |
+
|
| 2168 |
# Test data preparation with matching dtype
|
| 2169 |
text_hidden_states = torch.randn(2, 77, 1024, dtype=model_dtype, device=device)
|
| 2170 |
text_attention_mask = torch.ones(2, 77, dtype=model_dtype, device=device)
|
|
|
|
| 2298 |
# model = model.float()
|
| 2299 |
model = model.to("cuda")
|
| 2300 |
model = model.bfloat16()
|
| 2301 |
+
test_forward(model)
|