dawidope committed on
Commit
5c006aa
·
verified ·
1 Parent(s): ea5d8bd

Sync model code with ace_step==1.6.0 to prevent runtime mutation

Browse files
acestep-v15-base/modeling_acestep_v15_base.py CHANGED
@@ -11,6 +11,7 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
14
  import math
15
  import time
16
  from typing import Callable, List, Optional, Union
@@ -44,10 +45,10 @@ from vector_quantize_pytorch import ResidualFSQ
44
  # Local config import with fallback
45
  try:
46
  from .configuration_acestep_v15 import AceStepConfig
47
- from .apg_guidance import adg_forward, apg_forward, MomentumBuffer
48
  except ImportError:
49
  from configuration_acestep_v15 import AceStepConfig
50
- from apg_guidance import adg_forward, apg_forward, MomentumBuffer
51
 
52
 
53
  logger = logging.get_logger(__name__)
@@ -115,7 +116,7 @@ def create_4d_mask(
115
  # We want to mask out invalid keys (columns)
116
  # Expand shape: [Batch, 1, 1, Seq_Len]
117
  padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
118
-
119
  # Broadcasting: Geometry Mask [1, 1, L, L] & Padding Mask [B, 1, 1, L]
120
  # Result shape: [B, 1, L, L]
121
  valid_mask = valid_mask & padding_mask_4d
@@ -125,13 +126,13 @@ def create_4d_mask(
125
  # ------------------------------------------------------
126
  # Get the minimal value for current dtype
127
  min_dtype = torch.finfo(dtype).min
128
-
129
  # Create result tensor filled with -inf by default
130
  mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
131
-
132
  # Set valid positions to 0.0
133
  mask_tensor.masked_fill_(valid_mask, 0.0)
134
-
135
  return mask_tensor
136
 
137
 
@@ -200,7 +201,7 @@ def sample_t_r(batch_size, device, dtype, data_proportion=0.0, timestep_mu=-0.4,
200
  class TimestepEmbedding(nn.Module):
201
  """
202
  Timestep embedding module for diffusion models.
203
-
204
  Converts timestep values into high-dimensional embeddings using sinusoidal
205
  positional encoding, followed by MLP layers. Used for conditioning diffusion
206
  models on timestep information.
@@ -217,7 +218,7 @@ class TimestepEmbedding(nn.Module):
217
  self.act1 = nn.SiLU()
218
  self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
219
  self.in_channels = in_channels
220
-
221
  self.act2 = nn.SiLU()
222
  self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
223
  self.scale = scale
@@ -305,7 +306,7 @@ class AceStepAttention(nn.Module):
305
 
306
  # Determine if this is cross-attention (requires encoder_hidden_states)
307
  is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
308
-
309
  # Cross-attention path: attend to encoder hidden states
310
  if is_cross_attention:
311
  encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
@@ -313,7 +314,7 @@ class AceStepAttention(nn.Module):
313
  is_updated = past_key_value.is_updated.get(self.layer_idx)
314
  # After the first generated token, we can reuse all key/value states from cache
315
  curr_past_key_value = past_key_value.cross_attention_cache
316
-
317
  # Conditions for calculating key and value states
318
  if not is_updated:
319
  # Compute and cache K/V for the first time
@@ -331,7 +332,7 @@ class AceStepAttention(nn.Module):
331
  # No cache used, compute K/V directly
332
  key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
333
  value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
334
-
335
  # Self-attention path: attend to the same sequence
336
  else:
337
  # Project and normalize key/value states for self-attention
@@ -353,7 +354,7 @@ class AceStepAttention(nn.Module):
353
  attention_interface: Callable = eager_attention_forward
354
  elif self.config._attn_implementation != "eager":
355
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
356
-
357
  attn_output, attn_weights = attention_interface(
358
  self,
359
  query_states,
@@ -443,12 +444,12 @@ class AceStepEncoderLayer(GradientCheckpointingLayer):
443
  class AceStepDiTLayer(GradientCheckpointingLayer):
444
  """
445
  DiT (Diffusion Transformer) layer for AceStep model.
446
-
447
  Implements a transformer layer with three main components:
448
  1. Self-attention with adaptive layer norm (AdaLN)
449
  2. Cross-attention (optional) for conditioning on encoder outputs
450
  3. Feed-forward MLP with adaptive layer norm
451
-
452
  Uses scale-shift modulation from timestep embeddings for adaptive normalization.
453
  """
454
  def __init__(self, config: AceStepConfig, layer_idx: int, use_cross_attention: bool = True):
@@ -471,7 +472,7 @@ class AceStepDiTLayer(GradientCheckpointingLayer):
471
  # Scale-shift table for adaptive layer norm modulation (6 values: 3 for self-attn, 3 for MLP)
472
  self.scale_shift_table = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5)
473
  self.attention_type = config.layer_types[layer_idx]
474
-
475
  def forward(
476
  self,
477
  hidden_states: torch.Tensor,
@@ -577,14 +578,14 @@ class AceStepPreTrainedModel(PreTrainedModel):
577
  class AceStepLyricEncoder(AceStepPreTrainedModel):
578
  """
579
  Encoder for processing lyric text embeddings.
580
-
581
  Encodes lyric text hidden states using a transformer encoder architecture
582
  with bidirectional attention. Projects text embeddings to model hidden size
583
  and processes them through multiple encoder layers.
584
  """
585
  def __init__(self, config):
586
  super().__init__(config)
587
-
588
  # Project text embeddings to model hidden size
589
  self.embed_tokens = nn.Linear(config.text_hidden_dim, config.hidden_size)
590
  self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -618,7 +619,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
618
  assert input_ids is None, "Only `input_ids` is supported for the lyric encoder."
619
  assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
620
  assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."
621
-
622
  # Project input embeddings: N x T x text_hidden_dim -> N x T x hidden_size
623
  inputs_embeds = self.embed_tokens(inputs_embeds)
624
  # Cache position: only used for mask construction (not for actual caching)
@@ -632,7 +633,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
632
  seq_len = inputs_embeds.shape[1]
633
  dtype = inputs_embeds.dtype
634
  device = inputs_embeds.device
635
-
636
  # 判断是否使用 Flash Attention 2
637
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
638
 
@@ -649,7 +650,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
649
  # 如果没有 padding mask,传 None。
650
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
651
  full_attn_mask = attention_mask
652
-
653
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
654
  # Layer 会自己决定是否调用带 sliding window 的 kernel
655
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
@@ -659,7 +660,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
659
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
660
  # -------------------------------------------------------
661
  # 必须手动生成 4D Mask [B, 1, L, L]
662
-
663
  # 1. Full Attention (Bidirectional, Global)
664
  # 对应原来的 create_causal_mask + bidirectional
665
  full_attn_mask = create_4d_mask(
@@ -734,7 +735,7 @@ class AceStepLyricEncoder(AceStepPreTrainedModel):
734
  class AttentionPooler(AceStepPreTrainedModel):
735
  """
736
  Attention-based pooling module.
737
-
738
  Pools sequences of patches using a special token and attention mechanism.
739
  The special token attends to all patches and its output is used as the
740
  pooled representation. Used for aggregating patch-level features into
@@ -782,7 +783,7 @@ class AttentionPooler(AceStepPreTrainedModel):
782
  seq_len = x.shape[1]
783
  dtype = x.dtype
784
  device = x.device
785
-
786
  # 判断是否使用 Flash Attention 2
787
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
788
 
@@ -799,7 +800,7 @@ class AttentionPooler(AceStepPreTrainedModel):
799
  # 如果没有 padding mask,传 None。
800
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
801
  full_attn_mask = attention_mask
802
-
803
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
804
  # Layer 会自己决定是否调用带 sliding window 的 kernel
805
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
@@ -809,7 +810,7 @@ class AttentionPooler(AceStepPreTrainedModel):
809
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
810
  # -------------------------------------------------------
811
  # 必须手动生成 4D Mask [B, 1, L, L]
812
-
813
  # 1. Full Attention (Bidirectional, Global)
814
  # 对应原来的 create_causal_mask + bidirectional
815
  full_attn_mask = create_4d_mask(
@@ -840,7 +841,7 @@ class AttentionPooler(AceStepPreTrainedModel):
840
  "full_attention": full_attn_mask,
841
  "sliding_attention": sliding_attn_mask,
842
  }
843
-
844
  for layer_module in self.layers:
845
  layer_outputs = layer_module(
846
  hidden_states,
@@ -852,7 +853,7 @@ class AttentionPooler(AceStepPreTrainedModel):
852
  hidden_states = layer_outputs[0]
853
 
854
  hidden_states = self.norm(hidden_states)
855
-
856
  # Extract the special token output (first position) as pooled representation
857
  cls_output = hidden_states[:, 0, :]
858
  cls_output = rearrange(cls_output, "(b t) c -> b t c", b=B)
@@ -862,7 +863,7 @@ class AttentionPooler(AceStepPreTrainedModel):
862
  class AudioTokenDetokenizer(AceStepPreTrainedModel):
863
  """
864
  Audio token detokenizer module.
865
-
866
  Converts quantized audio tokens back to continuous acoustic representations.
867
  Expands each token into multiple patches using special tokens, processes them
868
  through encoder layers, and projects to acoustic hidden dimension.
@@ -917,7 +918,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
917
  seq_len = x.shape[1]
918
  dtype = x.dtype
919
  device = x.device
920
-
921
  # 判断是否使用 Flash Attention 2
922
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
923
 
@@ -934,7 +935,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
934
  # 如果没有 padding mask,传 None。
935
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
936
  full_attn_mask = attention_mask
937
-
938
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
939
  # Layer 会自己决定是否调用带 sliding window 的 kernel
940
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
@@ -944,7 +945,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
944
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
945
  # -------------------------------------------------------
946
  # 必须手动生成 4D Mask [B, 1, L, L]
947
-
948
  # 1. Full Attention (Bidirectional, Global)
949
  # 对应原来的 create_causal_mask + bidirectional
950
  full_attn_mask = create_4d_mask(
@@ -975,7 +976,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
975
  "full_attention": full_attn_mask,
976
  "sliding_attention": sliding_attn_mask,
977
  }
978
-
979
  for layer_module in self.layers:
980
  layer_outputs = layer_module(
981
  hidden_states,
@@ -987,7 +988,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
987
  hidden_states = layer_outputs[0]
988
 
989
  hidden_states = self.norm(hidden_states)
990
-
991
  hidden_states = self.proj_out(hidden_states)
992
 
993
  hidden_states = rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.config.pool_window_size)
@@ -997,7 +998,7 @@ class AudioTokenDetokenizer(AceStepPreTrainedModel):
997
  class AceStepTimbreEncoder(AceStepPreTrainedModel):
998
  """
999
  Encoder for extracting timbre embeddings from reference audio.
1000
-
1001
  Processes packed reference audio acoustic features to extract timbre
1002
  representations. Uses a special token (CLS-like) to aggregate information
1003
  from the entire reference audio sequence. Outputs are unpacked back to
@@ -1005,7 +1006,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
1005
  """
1006
  def __init__(self, config):
1007
  super().__init__(config)
1008
-
1009
  # Project acoustic features to model hidden size
1010
  self.embed_tokens = nn.Linear(config.timbre_hidden_dim, config.hidden_size)
1011
  self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1036,40 +1037,40 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
1036
  N, d = timbre_embs_packed.shape
1037
  device = timbre_embs_packed.device
1038
  dtype = timbre_embs_packed.dtype
1039
-
1040
  # Get batch size
1041
  B = int(refer_audio_order_mask.max().item() + 1)
1042
-
1043
  # Calculate element count and positions for each batch
1044
  counts = torch.bincount(refer_audio_order_mask, minlength=B)
1045
  max_count = counts.max().item()
1046
-
1047
  # Calculate positions within batch
1048
  sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
1049
  sorted_batch_ids = refer_audio_order_mask[sorted_indices]
1050
-
1051
  positions = torch.arange(N, device=device)
1052
- batch_starts = torch.cat([torch.tensor([0], device=device),
1053
  torch.cumsum(counts, dim=0)[:-1]])
1054
  positions_in_sorted = positions - batch_starts[sorted_batch_ids]
1055
-
1056
  inverse_indices = torch.empty_like(sorted_indices)
1057
  inverse_indices[sorted_indices] = torch.arange(N, device=device)
1058
  positions_in_batch = positions_in_sorted[inverse_indices]
1059
-
1060
  # Use one-hot encoding and matrix multiplication (gradient-friendly approach)
1061
  # Create one-hot encoding
1062
  indices_2d = refer_audio_order_mask * max_count + positions_in_batch # (N,)
1063
  one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype) # (N, B*max_count)
1064
-
1065
  # Rearrange using matrix multiplication
1066
  timbre_embs_flat = one_hot.t() @ timbre_embs_packed # (B*max_count, d)
1067
  timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
1068
-
1069
  # Create mask indicating valid positions
1070
  mask_flat = (one_hot.sum(dim=0) > 0).long() # (B*max_count,)
1071
  new_mask = mask_flat.reshape(B, max_count)
1072
-
1073
  return timbre_embs_unpack, new_mask
1074
 
1075
  @can_return_tuple
@@ -1093,7 +1094,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
1093
  seq_len = inputs_embeds.shape[1]
1094
  dtype = inputs_embeds.dtype
1095
  device = inputs_embeds.device
1096
-
1097
  # 判断是否使用 Flash Attention 2
1098
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
1099
 
@@ -1110,7 +1111,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
1110
  # 如果没有 padding mask,传 None。
1111
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
1112
  full_attn_mask = attention_mask
1113
-
1114
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
1115
  # Layer 会自己决定是否调用带 sliding window 的 kernel
1116
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
@@ -1120,7 +1121,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
1120
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
1121
  # -------------------------------------------------------
1122
  # 必须手动生成 4D Mask [B, 1, L, L]
1123
-
1124
  # 1. Full Attention (Bidirectional, Global)
1125
  # 对应原来的 create_causal_mask + bidirectional
1126
  full_attn_mask = create_4d_mask(
@@ -1151,7 +1152,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
1151
  "full_attention": full_attn_mask,
1152
  "sliding_attention": sliding_attn_mask,
1153
  }
1154
-
1155
  # Initialize hidden states
1156
  hidden_states = inputs_embeds
1157
 
@@ -1181,7 +1182,7 @@ class AceStepTimbreEncoder(AceStepPreTrainedModel):
1181
  class AceStepAudioTokenizer(AceStepPreTrainedModel):
1182
  """
1183
  Audio tokenizer module.
1184
-
1185
  Converts continuous acoustic features into discrete quantized tokens.
1186
  Process: project -> pool patches -> quantize. Used for converting audio
1187
  representations into discrete tokens for processing by the diffusion model.
@@ -1208,7 +1209,7 @@ class AceStepAudioTokenizer(AceStepPreTrainedModel):
1208
  hidden_states: Optional[torch.FloatTensor] = None,
1209
  **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
1210
  ) -> BaseModelOutput:
1211
-
1212
  # Project acoustic features to hidden size
1213
  hidden_states = self.audio_acoustic_proj(hidden_states)
1214
  # Pool sequences: N x T//pool_window_size x pool_window_size x d -> N x T//pool_window_size x d
@@ -1225,14 +1226,14 @@ class AceStepAudioTokenizer(AceStepPreTrainedModel):
1225
  class Lambda(nn.Module):
1226
  """
1227
  Wrapper module for arbitrary lambda functions.
1228
-
1229
  Allows using lambda functions in nn.Sequential by wrapping them in a Module.
1230
  Useful for simple transformations like transpose operations.
1231
  """
1232
  def __init__(self, func):
1233
  super().__init__()
1234
  self.func = func
1235
-
1236
  def forward(self, x):
1237
  return self.func(x)
1238
 
@@ -1240,7 +1241,7 @@ class Lambda(nn.Module):
1240
  class AceStepDiTModel(AceStepPreTrainedModel):
1241
  """
1242
  DiT (Diffusion Transformer) model for AceStep.
1243
-
1244
  Main diffusion model that generates audio latents conditioned on text, lyrics,
1245
  and timbre. Uses patch-based processing with transformer layers, timestep
1246
  conditioning, and cross-attention to encoder outputs.
@@ -1258,7 +1259,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1258
  inner_dim = config.hidden_size
1259
  patch_size = config.patch_size
1260
  self.patch_size = patch_size
1261
-
1262
  # Input projection: patch embedding using 1D convolution
1263
  # Converts sequence into patches for efficient processing
1264
  self.proj_in = nn.Sequential(
@@ -1277,9 +1278,10 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1277
  # Two embeddings: one for timestep t, one for timestep difference (t - r)
1278
  self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
1279
  self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
1280
-
1281
  # Project encoder hidden states to model dimension
1282
- self.condition_embedder = nn.Linear(inner_dim, inner_dim, bias=True)
 
1283
 
1284
  # Output normalization and projection
1285
  # Adaptive layer norm with scale-shift modulation, then de-patchify
@@ -1330,7 +1332,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1330
  use_cache = False
1331
  if self.training:
1332
  use_cache = False
1333
-
1334
  # Initialize cache if needed (only during inference for auto-regressive generation)
1335
  if not self.training and use_cache and past_key_values is None:
1336
  past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
@@ -1357,14 +1359,14 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1357
  # Project input to patches and project encoder states
1358
  hidden_states = self.proj_in(hidden_states)
1359
  encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
1360
-
1361
  # Cache positions
1362
  if cache_position is None:
1363
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1364
  cache_position = torch.arange(
1365
  past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
1366
  )
1367
-
1368
  # Position IDs
1369
  if position_ids is None:
1370
  position_ids = cache_position.unsqueeze(0)
@@ -1374,7 +1376,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1374
  encoder_seq_len = encoder_hidden_states.shape[1]
1375
  dtype = hidden_states.dtype
1376
  device = hidden_states.device
1377
-
1378
  # 判断是否使用 Flash Attention 2
1379
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
1380
 
@@ -1392,7 +1394,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1392
  # 如果没有 padding mask,传 None。
1393
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
1394
  full_attn_mask = attention_mask
1395
-
1396
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
1397
  # Layer 会自己决定是否调用带 sliding window 的 kernel
1398
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
@@ -1402,7 +1404,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1402
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
1403
  # -------------------------------------------------------
1404
  # 必须手动生成 4D Mask [B, 1, L, L]
1405
-
1406
  # 1. Full Attention (Bidirectional, Global)
1407
  # 对应原来的 create_causal_mask + bidirectional
1408
  full_attn_mask = create_4d_mask(
@@ -1415,7 +1417,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1415
  is_causal=False # <--- 关键:双向注意力
1416
  )
1417
  max_len = max(seq_len, encoder_seq_len)
1418
-
1419
  encoder_attention_mask = create_4d_mask(
1420
  seq_len=max_len,
1421
  dtype=dtype,
@@ -1483,7 +1485,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1483
  # Extract the last element which is cross_attn_weights
1484
  if len(layer_outputs) >= 3:
1485
  all_cross_attentions += (layer_outputs[2],)
1486
-
1487
  if return_hidden_states:
1488
  return hidden_states
1489
 
@@ -1496,10 +1498,10 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1496
  hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
1497
  # Project output: de-patchify back to original sequence format
1498
  hidden_states = self.proj_out(hidden_states)
1499
-
1500
  # Crop back to original sequence length to ensure exact length match (remove padding)
1501
  hidden_states = hidden_states[:, :original_seq_len, :]
1502
-
1503
  outputs = (hidden_states, past_key_values)
1504
 
1505
  if output_attentions:
@@ -1509,7 +1511,7 @@ class AceStepDiTModel(AceStepPreTrainedModel):
1509
  class AceStepConditionEncoder(AceStepPreTrainedModel):
1510
  """
1511
  Condition encoder for AceStep model.
1512
-
1513
  Encodes multiple conditioning inputs (text, lyrics, timbre) and packs them
1514
  into a single sequence for cross-attention in the diffusion model. Handles
1515
  projection, encoding, and sequence packing.
@@ -1554,10 +1556,42 @@ class AceStepConditionEncoder(AceStepPreTrainedModel):
1554
  return encoder_hidden_states, encoder_attention_mask
1555
 
1556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1557
  class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1558
  """
1559
  Main conditional generation model for AceStep.
1560
-
1561
  End-to-end model for generating audio conditioned on text, lyrics, and timbre.
1562
  Combines encoder (for conditioning), decoder (diffusion model), tokenizer
1563
  (for discrete tokenization), and detokenizer (for reconstruction).
@@ -1568,11 +1602,22 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1568
  self.config = config
1569
  # Diffusion model components
1570
  self.decoder = AceStepDiTModel(config) # Main diffusion transformer
1571
- self.encoder = AceStepConditionEncoder(config) # Condition encoder
1572
- self.tokenizer = AceStepAudioTokenizer(config) # Audio tokenizer
1573
- self.detokenizer = AudioTokenDetokenizer(config) # Audio detokenizer
 
 
 
 
 
 
 
 
 
 
 
1574
  # Null condition embedding for classifier-free guidance
1575
- self.null_condition_emb = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1576
 
1577
  # Initialize weights and apply final processing
1578
  self.post_init()
@@ -1621,7 +1666,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1621
  precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
1622
  audio_codes: torch.FloatTensor = None,
1623
  ):
1624
-
1625
  dtype = hidden_states.dtype
1626
  encoder_hidden_states, encoder_attention_mask = self.encoder(
1627
  text_hidden_states=text_hidden_states,
@@ -1709,7 +1754,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1709
  t_ = t.unsqueeze(-1).unsqueeze(-1)
1710
  # Interpolate: x_t = t * x1 + (1 - t) * x0
1711
  xt = t_ * x1 + (1.0 - t_) * x0
1712
-
1713
  # Predict flow (velocity) from diffusion model
1714
  decoder_outputs = self.decoder(
1715
  hidden_states=xt,
@@ -1726,10 +1771,10 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1726
  return {
1727
  "diffusion_loss": diffusion_loss,
1728
  }
1729
-
1730
  def training_losses(self, **kwargs):
1731
  return self.forward(**kwargs)
1732
-
1733
  def prepare_noise(self, context_latents: torch.FloatTensor, seed: Union[int, List[int], None] = None):
1734
  """
1735
  Prepare noise tensor for generation with optional seeding.
@@ -1770,6 +1815,9 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1770
  return noise
1771
 
1772
  def get_x0_from_noise(self, zt, vt, t):
 
 
 
1773
  return zt - vt * t.unsqueeze(-1).unsqueeze(-1)
1774
 
1775
  def renoise(self, x, t, noise=None):
@@ -1779,7 +1827,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1779
  t = t.unsqueeze(-1).unsqueeze(-1)
1780
  xt = t * noise + (1 - t) * x
1781
  return xt
1782
-
1783
  def generate_audio(
1784
  self,
1785
  text_hidden_states: torch.FloatTensor,
@@ -1797,7 +1845,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1797
  infer_method: str = "ode",
1798
  use_cache: bool = True,
1799
  infer_steps: int = 30,
1800
- diffusion_guidance_sale: float = 7.0,
1801
  audio_cover_strength: float = 1.0,
1802
  non_cover_text_hidden_states: Optional[torch.FloatTensor] = None,
1803
  non_cover_text_attention_mask: Optional[torch.FloatTensor] = None,
@@ -1809,8 +1857,27 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1809
  use_adg: bool = False,
1810
  shift: float = 1.0,
1811
  cover_noise_strength: float = 0.0,
 
 
 
 
 
 
 
1812
  **kwargs,
1813
  ):
 
 
 
 
 
 
 
 
 
 
 
 
1814
  if attention_mask is None:
1815
  latent_length = src_latents.shape[1]
1816
  attention_mask = torch.ones(src_latents.shape[0], latent_length, device=src_latents.device, dtype=src_latents.dtype)
@@ -1874,7 +1941,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1874
  bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype
1875
  past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
1876
  momentum_buffer = MomentumBuffer()
1877
-
1878
  # Cover noise initialization: blend noise with src_latents
1879
  if cover_noise_strength > 0.0:
1880
  # cover_noise_strength=1 means closest to src, so noise_level should be low
@@ -1900,16 +1967,24 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1900
  )
1901
  else:
1902
  xt = noise
1903
-
1904
  # main task condition
1905
- do_cfg_guidance = diffusion_guidance_sale > 1.0
1906
  if do_cfg_guidance:
1907
  encoder_hidden_states = torch.cat([encoder_hidden_states, self.null_condition_emb.expand_as(encoder_hidden_states)], dim=0)
1908
  encoder_attention_mask = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0)
1909
  # src_latents
1910
  context_latents = torch.cat([context_latents, context_latents], dim=0)
1911
  attention_mask = torch.cat([attention_mask, attention_mask], dim=0)
1912
-
 
 
 
 
 
 
 
 
1913
  _switched_to_non_cover = False
1914
  with torch.no_grad():
1915
  for step_idx, (t_curr, t_prev) in enumerate(iterator):
@@ -1925,7 +2000,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1925
  encoder_attention_mask = encoder_attention_mask_non_cover
1926
  context_latents = context_latents_non_cover
1927
  past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
1928
-
1929
  x = torch.cat([xt, xt], dim=0) if do_cfg_guidance else xt
1930
  t_curr_tensor = t_curr * torch.ones((x.shape[0],), device=device, dtype=dtype)
1931
  decoder_outputs = self.decoder(
@@ -1939,7 +2014,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1939
  use_cache=True,
1940
  past_key_values=past_key_values,
1941
  )
1942
-
1943
  vt = decoder_outputs[0]
1944
  past_key_values = decoder_outputs[1]
1945
  apply_cfg_guidance = t_curr >= cfg_interval_start and t_curr <= cfg_interval_end
@@ -1950,7 +2025,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1950
  vt = apg_forward(
1951
  pred_cond=pred_cond,
1952
  pred_uncond=pred_null_cond,
1953
- guidance_scale=diffusion_guidance_sale,
1954
  momentum_buffer=momentum_buffer,
1955
  dims=[1],
1956
  )
@@ -1960,10 +2035,21 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1960
  noise_pred_cond=pred_cond,
1961
  noise_pred_uncond=pred_null_cond,
1962
  sigma=t_curr,
1963
- guidance_scale=diffusion_guidance_sale,
1964
  )
1965
  else:
1966
  vt = pred_cond
 
 
 
 
 
 
 
 
 
 
 
1967
  # Update x_t based on inference method
1968
  if infer_method == "sde":
1969
  # Stochastic Differential Equation: predict clean, then re-add noise
@@ -1971,14 +2057,85 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1971
  pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
1972
  next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
1973
  xt = self.renoise(pred_clean, next_timestep)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1974
  elif infer_method == "ode":
1975
  # Ordinary Differential Equation: Euler method
1976
  # dx/dt = -v, so x_{t+1} = x_t - v_t * dt
1977
  dt = t_curr - t_prev
1978
  dt_tensor = dt * torch.ones((bsz,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
1979
  xt = xt - vt * dt_tensor
1980
-
 
 
 
 
 
 
 
 
 
1981
  x_gen = xt
 
 
 
 
 
1982
  end_time = time.time()
1983
  time_costs["diffusion_time_cost"] = end_time - start_time
1984
  time_costs["diffusion_per_step_time_cost"] = time_costs["diffusion_time_cost"] / infer_steps
@@ -2001,13 +2158,13 @@ def test_forward(model, seed=42):
2001
  torch.cuda.manual_seed_all(seed)
2002
  torch.backends.cudnn.deterministic = True
2003
  torch.backends.cudnn.benchmark = False
2004
-
2005
  # Get model dtype and device
2006
  model_dtype = next(model.parameters()).dtype
2007
  device = next(model.parameters()).device
2008
-
2009
  print(f"Testing with dtype: {model_dtype}, device: {device}, seed: {seed}")
2010
-
2011
  # Test data preparation with matching dtype
2012
  text_hidden_states = torch.randn(2, 77, 1024, dtype=model_dtype, device=device)
2013
  text_attention_mask = torch.ones(2, 77, dtype=model_dtype, device=device)
@@ -2141,4 +2298,4 @@ if __name__ == "__main__":
2141
  # model = model.float()
2142
  model = model.to("cuda")
2143
  model = model.bfloat16()
2144
- test_forward(model)
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ import copy
15
  import math
16
  import time
17
  from typing import Callable, List, Optional, Union
 
45
  # Local config import with fallback
46
  try:
47
  from .configuration_acestep_v15 import AceStepConfig
48
+ from .apg_guidance import adg_forward, apg_forward, cfg_forward, MomentumBuffer
49
  except ImportError:
50
  from configuration_acestep_v15 import AceStepConfig
51
+ from apg_guidance import adg_forward, apg_forward, cfg_forward, MomentumBuffer
52
 
53
 
54
  logger = logging.get_logger(__name__)
 
116
  # We want to mask out invalid keys (columns)
117
  # Expand shape: [Batch, 1, 1, Seq_Len]
118
  padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
119
+
120
  # Broadcasting: Geometry Mask [1, 1, L, L] & Padding Mask [B, 1, 1, L]
121
  # Result shape: [B, 1, L, L]
122
  valid_mask = valid_mask & padding_mask_4d
 
126
  # ------------------------------------------------------
127
  # Get the minimal value for current dtype
128
  min_dtype = torch.finfo(dtype).min
129
+
130
  # Create result tensor filled with -inf by default
131
  mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
132
+
133
  # Set valid positions to 0.0
134
  mask_tensor.masked_fill_(valid_mask, 0.0)
135
+
136
  return mask_tensor
137
 
138
 
 
201
  class TimestepEmbedding(nn.Module):
202
  """
203
  Timestep embedding module for diffusion models.
204
+
205
  Converts timestep values into high-dimensional embeddings using sinusoidal
206
  positional encoding, followed by MLP layers. Used for conditioning diffusion
207
  models on timestep information.
 
218
  self.act1 = nn.SiLU()
219
  self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
220
  self.in_channels = in_channels
221
+
222
  self.act2 = nn.SiLU()
223
  self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
224
  self.scale = scale
 
306
 
307
  # Determine if this is cross-attention (requires encoder_hidden_states)
308
  is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
309
+
310
  # Cross-attention path: attend to encoder hidden states
311
  if is_cross_attention:
312
  encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
 
314
  is_updated = past_key_value.is_updated.get(self.layer_idx)
315
  # After the first generated token, we can reuse all key/value states from cache
316
  curr_past_key_value = past_key_value.cross_attention_cache
317
+
318
  # Conditions for calculating key and value states
319
  if not is_updated:
320
  # Compute and cache K/V for the first time
 
332
  # No cache used, compute K/V directly
333
  key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
334
  value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
335
+
336
  # Self-attention path: attend to the same sequence
337
  else:
338
  # Project and normalize key/value states for self-attention
 
354
  attention_interface: Callable = eager_attention_forward
355
  elif self.config._attn_implementation != "eager":
356
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
357
+
358
  attn_output, attn_weights = attention_interface(
359
  self,
360
  query_states,
 
444
  class AceStepDiTLayer(GradientCheckpointingLayer):
445
  """
446
  DiT (Diffusion Transformer) layer for AceStep model.
447
+
448
  Implements a transformer layer with three main components:
449
  1. Self-attention with adaptive layer norm (AdaLN)
450
  2. Cross-attention (optional) for conditioning on encoder outputs
451
  3. Feed-forward MLP with adaptive layer norm
452
+
453
  Uses scale-shift modulation from timestep embeddings for adaptive normalization.
454
  """
455
  def __init__(self, config: AceStepConfig, layer_idx: int, use_cross_attention: bool = True):
 
472
  # Scale-shift table for adaptive layer norm modulation (6 values: 3 for self-attn, 3 for MLP)
473
  self.scale_shift_table = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5)
474
  self.attention_type = config.layer_types[layer_idx]
475
+
476
  def forward(
477
  self,
478
  hidden_states: torch.Tensor,
 
578
  class AceStepLyricEncoder(AceStepPreTrainedModel):
579
  """
580
  Encoder for processing lyric text embeddings.
581
+
582
  Encodes lyric text hidden states using a transformer encoder architecture
583
  with bidirectional attention. Projects text embeddings to model hidden size
584
  and processes them through multiple encoder layers.
585
  """
586
  def __init__(self, config):
587
  super().__init__(config)
588
+
589
  # Project text embeddings to model hidden size
590
  self.embed_tokens = nn.Linear(config.text_hidden_dim, config.hidden_size)
591
  self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
619
  assert input_ids is None, "Only `input_ids` is supported for the lyric encoder."
620
  assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
621
  assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."
622
+
623
  # Project input embeddings: N x T x text_hidden_dim -> N x T x hidden_size
624
  inputs_embeds = self.embed_tokens(inputs_embeds)
625
  # Cache position: only used for mask construction (not for actual caching)
 
633
  seq_len = inputs_embeds.shape[1]
634
  dtype = inputs_embeds.dtype
635
  device = inputs_embeds.device
636
+
637
  # 判断是否使用 Flash Attention 2
638
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
639
 
 
650
  # 如果没有 padding mask,传 None。
651
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
652
  full_attn_mask = attention_mask
653
+
654
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
655
  # Layer 会自己决定是否调用带 sliding window 的 kernel
656
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
 
660
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
661
  # -------------------------------------------------------
662
  # 必须手动生成 4D Mask [B, 1, L, L]
663
+
664
  # 1. Full Attention (Bidirectional, Global)
665
  # 对应原来的 create_causal_mask + bidirectional
666
  full_attn_mask = create_4d_mask(
 
735
  class AttentionPooler(AceStepPreTrainedModel):
736
  """
737
  Attention-based pooling module.
738
+
739
  Pools sequences of patches using a special token and attention mechanism.
740
  The special token attends to all patches and its output is used as the
741
  pooled representation. Used for aggregating patch-level features into
 
783
  seq_len = x.shape[1]
784
  dtype = x.dtype
785
  device = x.device
786
+
787
  # 判断是否使用 Flash Attention 2
788
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
789
 
 
800
  # 如果没有 padding mask,传 None。
801
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
802
  full_attn_mask = attention_mask
803
+
804
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
805
  # Layer 会自己决定是否调用带 sliding window 的 kernel
806
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
 
810
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
811
  # -------------------------------------------------------
812
  # 必须手动生成 4D Mask [B, 1, L, L]
813
+
814
  # 1. Full Attention (Bidirectional, Global)
815
  # 对应原来的 create_causal_mask + bidirectional
816
  full_attn_mask = create_4d_mask(
 
841
  "full_attention": full_attn_mask,
842
  "sliding_attention": sliding_attn_mask,
843
  }
844
+
845
  for layer_module in self.layers:
846
  layer_outputs = layer_module(
847
  hidden_states,
 
853
  hidden_states = layer_outputs[0]
854
 
855
  hidden_states = self.norm(hidden_states)
856
+
857
  # Extract the special token output (first position) as pooled representation
858
  cls_output = hidden_states[:, 0, :]
859
  cls_output = rearrange(cls_output, "(b t) c -> b t c", b=B)
 
863
  class AudioTokenDetokenizer(AceStepPreTrainedModel):
864
  """
865
  Audio token detokenizer module.
866
+
867
  Converts quantized audio tokens back to continuous acoustic representations.
868
  Expands each token into multiple patches using special tokens, processes them
869
  through encoder layers, and projects to acoustic hidden dimension.
 
918
  seq_len = x.shape[1]
919
  dtype = x.dtype
920
  device = x.device
921
+
922
  # 判断是否使用 Flash Attention 2
923
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
924
 
 
935
  # 如果没有 padding mask,传 None。
936
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
937
  full_attn_mask = attention_mask
938
+
939
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
940
  # Layer 会自己决定是否调用带 sliding window 的 kernel
941
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
 
945
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
946
  # -------------------------------------------------------
947
  # 必须手动生成 4D Mask [B, 1, L, L]
948
+
949
  # 1. Full Attention (Bidirectional, Global)
950
  # 对应原来的 create_causal_mask + bidirectional
951
  full_attn_mask = create_4d_mask(
 
976
  "full_attention": full_attn_mask,
977
  "sliding_attention": sliding_attn_mask,
978
  }
979
+
980
  for layer_module in self.layers:
981
  layer_outputs = layer_module(
982
  hidden_states,
 
988
  hidden_states = layer_outputs[0]
989
 
990
  hidden_states = self.norm(hidden_states)
991
+
992
  hidden_states = self.proj_out(hidden_states)
993
 
994
  hidden_states = rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.config.pool_window_size)
 
998
  class AceStepTimbreEncoder(AceStepPreTrainedModel):
999
  """
1000
  Encoder for extracting timbre embeddings from reference audio.
1001
+
1002
  Processes packed reference audio acoustic features to extract timbre
1003
  representations. Uses a special token (CLS-like) to aggregate information
1004
  from the entire reference audio sequence. Outputs are unpacked back to
 
1006
  """
1007
  def __init__(self, config):
1008
  super().__init__(config)
1009
+
1010
  # Project acoustic features to model hidden size
1011
  self.embed_tokens = nn.Linear(config.timbre_hidden_dim, config.hidden_size)
1012
  self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
1037
  N, d = timbre_embs_packed.shape
1038
  device = timbre_embs_packed.device
1039
  dtype = timbre_embs_packed.dtype
1040
+
1041
  # Get batch size
1042
  B = int(refer_audio_order_mask.max().item() + 1)
1043
+
1044
  # Calculate element count and positions for each batch
1045
  counts = torch.bincount(refer_audio_order_mask, minlength=B)
1046
  max_count = counts.max().item()
1047
+
1048
  # Calculate positions within batch
1049
  sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
1050
  sorted_batch_ids = refer_audio_order_mask[sorted_indices]
1051
+
1052
  positions = torch.arange(N, device=device)
1053
+ batch_starts = torch.cat([torch.tensor([0], device=device),
1054
  torch.cumsum(counts, dim=0)[:-1]])
1055
  positions_in_sorted = positions - batch_starts[sorted_batch_ids]
1056
+
1057
  inverse_indices = torch.empty_like(sorted_indices)
1058
  inverse_indices[sorted_indices] = torch.arange(N, device=device)
1059
  positions_in_batch = positions_in_sorted[inverse_indices]
1060
+
1061
  # Use one-hot encoding and matrix multiplication (gradient-friendly approach)
1062
  # Create one-hot encoding
1063
  indices_2d = refer_audio_order_mask * max_count + positions_in_batch # (N,)
1064
  one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype) # (N, B*max_count)
1065
+
1066
  # Rearrange using matrix multiplication
1067
  timbre_embs_flat = one_hot.t() @ timbre_embs_packed # (B*max_count, d)
1068
  timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
1069
+
1070
  # Create mask indicating valid positions
1071
  mask_flat = (one_hot.sum(dim=0) > 0).long() # (B*max_count,)
1072
  new_mask = mask_flat.reshape(B, max_count)
1073
+
1074
  return timbre_embs_unpack, new_mask
1075
 
1076
  @can_return_tuple
 
1094
  seq_len = inputs_embeds.shape[1]
1095
  dtype = inputs_embeds.dtype
1096
  device = inputs_embeds.device
1097
+
1098
  # 判断是否使用 Flash Attention 2
1099
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
1100
 
 
1111
  # 如果没有 padding mask,传 None。
1112
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
1113
  full_attn_mask = attention_mask
1114
+
1115
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
1116
  # Layer 会自己决定是否调用带 sliding window 的 kernel
1117
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
 
1121
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
1122
  # -------------------------------------------------------
1123
  # 必须手动生成 4D Mask [B, 1, L, L]
1124
+
1125
  # 1. Full Attention (Bidirectional, Global)
1126
  # 对应原来的 create_causal_mask + bidirectional
1127
  full_attn_mask = create_4d_mask(
 
1152
  "full_attention": full_attn_mask,
1153
  "sliding_attention": sliding_attn_mask,
1154
  }
1155
+
1156
  # Initialize hidden states
1157
  hidden_states = inputs_embeds
1158
 
 
1182
  class AceStepAudioTokenizer(AceStepPreTrainedModel):
1183
  """
1184
  Audio tokenizer module.
1185
+
1186
  Converts continuous acoustic features into discrete quantized tokens.
1187
  Process: project -> pool patches -> quantize. Used for converting audio
1188
  representations into discrete tokens for processing by the diffusion model.
 
1209
  hidden_states: Optional[torch.FloatTensor] = None,
1210
  **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
1211
  ) -> BaseModelOutput:
1212
+
1213
  # Project acoustic features to hidden size
1214
  hidden_states = self.audio_acoustic_proj(hidden_states)
1215
  # Pool sequences: N x T//pool_window_size x pool_window_size x d -> N x T//pool_window_size x d
 
1226
class Lambda(nn.Module):
    """
    Wrapper module for arbitrary lambda functions.

    Allows using lambda functions in nn.Sequential by wrapping them in a Module.
    Useful for simple transformations like transpose operations.
    """
    def __init__(self, func):
        # func: the callable that forward() will invoke on its input.
        super().__init__()
        self.func = func

    def forward(self, x):
        # Delegate to the wrapped callable and return whatever it returns.
        return self.func(x)
1239
 
 
1241
  class AceStepDiTModel(AceStepPreTrainedModel):
1242
  """
1243
  DiT (Diffusion Transformer) model for AceStep.
1244
+
1245
  Main diffusion model that generates audio latents conditioned on text, lyrics,
1246
  and timbre. Uses patch-based processing with transformer layers, timestep
1247
  conditioning, and cross-attention to encoder outputs.
 
1259
  inner_dim = config.hidden_size
1260
  patch_size = config.patch_size
1261
  self.patch_size = patch_size
1262
+
1263
  # Input projection: patch embedding using 1D convolution
1264
  # Converts sequence into patches for efficient processing
1265
  self.proj_in = nn.Sequential(
 
1278
  # Two embeddings: one for timestep t, one for timestep difference (t - r)
1279
  self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
1280
  self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
1281
+
1282
  # Project encoder hidden states to model dimension
1283
+ condition_dim = getattr(config, "encoder_hidden_size", None) or config.hidden_size
1284
+ self.condition_embedder = nn.Linear(condition_dim, inner_dim, bias=True)
1285
 
1286
  # Output normalization and projection
1287
  # Adaptive layer norm with scale-shift modulation, then de-patchify
 
1332
  use_cache = False
1333
  if self.training:
1334
  use_cache = False
1335
+
1336
  # Initialize cache if needed (only during inference for auto-regressive generation)
1337
  if not self.training and use_cache and past_key_values is None:
1338
  past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
 
1359
  # Project input to patches and project encoder states
1360
  hidden_states = self.proj_in(hidden_states)
1361
  encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
1362
+
1363
  # Cache positions
1364
  if cache_position is None:
1365
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1366
  cache_position = torch.arange(
1367
  past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
1368
  )
1369
+
1370
  # Position IDs
1371
  if position_ids is None:
1372
  position_ids = cache_position.unsqueeze(0)
 
1376
  encoder_seq_len = encoder_hidden_states.shape[1]
1377
  dtype = hidden_states.dtype
1378
  device = hidden_states.device
1379
+
1380
  # 判断是否使用 Flash Attention 2
1381
  is_flash_attn = (self.config._attn_implementation == "flash_attention_2")
1382
 
 
1394
  # 如果没有 padding mask,传 None。
1395
  # 滑动窗口逻辑由 Layer 内部传给 FA kernel 的 sliding_window 参数控制。
1396
  full_attn_mask = attention_mask
1397
+
1398
  # 这里的逻辑是:如果配置启用了滑动窗口,FA 模式下我们也只需要传基础的 padding mask
1399
  # Layer 会自己决定是否调用带 sliding window 的 kernel
1400
  sliding_attn_mask = attention_mask if self.config.use_sliding_window else None
 
1404
  # 场景 B: CPU / Mac / SDPA (Eager 模式)
1405
  # -------------------------------------------------------
1406
  # 必须手动生成 4D Mask [B, 1, L, L]
1407
+
1408
  # 1. Full Attention (Bidirectional, Global)
1409
  # 对应原来的 create_causal_mask + bidirectional
1410
  full_attn_mask = create_4d_mask(
 
1417
  is_causal=False # <--- 关键:双向注意力
1418
  )
1419
  max_len = max(seq_len, encoder_seq_len)
1420
+
1421
  encoder_attention_mask = create_4d_mask(
1422
  seq_len=max_len,
1423
  dtype=dtype,
 
1485
  # Extract the last element which is cross_attn_weights
1486
  if len(layer_outputs) >= 3:
1487
  all_cross_attentions += (layer_outputs[2],)
1488
+
1489
  if return_hidden_states:
1490
  return hidden_states
1491
 
 
1498
  hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
1499
  # Project output: de-patchify back to original sequence format
1500
  hidden_states = self.proj_out(hidden_states)
1501
+
1502
  # Crop back to original sequence length to ensure exact length match (remove padding)
1503
  hidden_states = hidden_states[:, :original_seq_len, :]
1504
+
1505
  outputs = (hidden_states, past_key_values)
1506
 
1507
  if output_attentions:
 
1511
  class AceStepConditionEncoder(AceStepPreTrainedModel):
1512
  """
1513
  Condition encoder for AceStep model.
1514
+
1515
  Encodes multiple conditioning inputs (text, lyrics, timbre) and packs them
1516
  into a single sequence for cross-attention in the diffusion model. Handles
1517
  projection, encoding, and sequence packing.
 
1556
  return encoder_hidden_states, encoder_attention_mask
1557
 
1558
 
1559
+ def _repaint_step_injection(xt, clean_src, mask, t_next, noise):
1560
+ """Replace non-repaint regions of *xt* with noised source latents."""
1561
+ zt = t_next * noise + (1.0 - t_next) * clean_src
1562
+ m = mask.unsqueeze(-1).expand_as(xt)
1563
+ return torch.where(m, xt, zt)
1564
+
1565
+
1566
+ def _repaint_boundary_blend(x_gen, clean_src, mask, cf_frames):
1567
+ """Blend generated latents with source at repaint boundaries."""
1568
+ soft = mask.float().clone()
1569
+ if cf_frames <= 0:
1570
+ m = soft.unsqueeze(-1).expand_as(x_gen)
1571
+ return m * x_gen + (1.0 - m) * clean_src
1572
+ B, T = mask.shape
1573
+ for b in range(B):
1574
+ row = mask[b]
1575
+ if row.all() or not row.any():
1576
+ continue
1577
+ idx = torch.nonzero(row, as_tuple=False).squeeze(-1)
1578
+ if idx.numel() == 0:
1579
+ continue
1580
+ left, right = idx[0].item(), idx[-1].item() + 1
1581
+ fs = max(left - cf_frames, 0)
1582
+ if left - fs > 0:
1583
+ soft[b, fs:left] = torch.linspace(0, 1, left - fs + 2, device=soft.device)[1:-1]
1584
+ fe = min(right + cf_frames, T)
1585
+ if fe - right > 0:
1586
+ soft[b, right:fe] = torch.linspace(1, 0, fe - right + 2, device=soft.device)[1:-1]
1587
+ m = soft.unsqueeze(-1).expand_as(x_gen)
1588
+ return m * x_gen + (1.0 - m) * clean_src
1589
+
1590
+
1591
  class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1592
  """
1593
  Main conditional generation model for AceStep.
1594
+
1595
  End-to-end model for generating audio conditioned on text, lyrics, and timbre.
1596
  Combines encoder (for conditioning), decoder (diffusion model), tokenizer
1597
  (for discrete tokenization), and detokenizer (for reconstruction).
 
1602
  self.config = config
1603
  # Diffusion model components
1604
  self.decoder = AceStepDiTModel(config) # Main diffusion transformer
1605
+ # Build encoder config: use separate encoder_hidden_size when available
1606
+ # (4B models have encoder_hidden_size=2048 != hidden_size=2560)
1607
+ _enc_hs = getattr(config, "encoder_hidden_size", None) or config.hidden_size
1608
+ if _enc_hs != config.hidden_size:
1609
+ encoder_config = copy.deepcopy(config)
1610
+ encoder_config.hidden_size = _enc_hs
1611
+ encoder_config.intermediate_size = getattr(config, "encoder_intermediate_size", None) or config.intermediate_size
1612
+ encoder_config.num_attention_heads = getattr(config, "encoder_num_attention_heads", None) or config.num_attention_heads
1613
+ encoder_config.num_key_value_heads = getattr(config, "encoder_num_key_value_heads", None) or config.num_key_value_heads
1614
+ else:
1615
+ encoder_config = config
1616
+ self.encoder = AceStepConditionEncoder(encoder_config) # Condition encoder
1617
+ self.tokenizer = AceStepAudioTokenizer(encoder_config) # Audio tokenizer
1618
+ self.detokenizer = AudioTokenDetokenizer(encoder_config) # Audio detokenizer
1619
  # Null condition embedding for classifier-free guidance
1620
+ self.null_condition_emb = nn.Parameter(torch.randn(1, 1, encoder_config.hidden_size))
1621
 
1622
  # Initialize weights and apply final processing
1623
  self.post_init()
 
1666
  precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
1667
  audio_codes: torch.FloatTensor = None,
1668
  ):
1669
+
1670
  dtype = hidden_states.dtype
1671
  encoder_hidden_states, encoder_attention_mask = self.encoder(
1672
  text_hidden_states=text_hidden_states,
 
1754
  t_ = t.unsqueeze(-1).unsqueeze(-1)
1755
  # Interpolate: x_t = t * x1 + (1 - t) * x0
1756
  xt = t_ * x1 + (1.0 - t_) * x0
1757
+
1758
  # Predict flow (velocity) from diffusion model
1759
  decoder_outputs = self.decoder(
1760
  hidden_states=xt,
 
1771
  return {
1772
  "diffusion_loss": diffusion_loss,
1773
  }
1774
+
1775
    def training_losses(self, **kwargs):
        # Alias for forward(): passes every keyword argument through unchanged
        # and returns whatever forward() returns.
        return self.forward(**kwargs)
1777
+
1778
  def prepare_noise(self, context_latents: torch.FloatTensor, seed: Union[int, List[int], None] = None):
1779
  """
1780
  Prepare noise tensor for generation with optional seeding.
 
1815
  return noise
1816
 
1817
  def get_x0_from_noise(self, zt, vt, t):
1818
+ if t.shape[0] != zt.shape[0]:
1819
+ raise ValueError(f"Batch size mismatch: t has {t.shape[0]}, zt has {zt.shape[0]}")
1820
+
1821
  return zt - vt * t.unsqueeze(-1).unsqueeze(-1)
1822
 
1823
  def renoise(self, x, t, noise=None):
 
1827
  t = t.unsqueeze(-1).unsqueeze(-1)
1828
  xt = t * noise + (1 - t) * x
1829
  return xt
1830
+
1831
  def generate_audio(
1832
  self,
1833
  text_hidden_states: torch.FloatTensor,
 
1845
  infer_method: str = "ode",
1846
  use_cache: bool = True,
1847
  infer_steps: int = 30,
1848
+ diffusion_guidance_scale: float = 7.0,
1849
  audio_cover_strength: float = 1.0,
1850
  non_cover_text_hidden_states: Optional[torch.FloatTensor] = None,
1851
  non_cover_text_attention_mask: Optional[torch.FloatTensor] = None,
 
1857
  use_adg: bool = False,
1858
  shift: float = 1.0,
1859
  cover_noise_strength: float = 0.0,
1860
+ repaint_mask: Optional[torch.Tensor] = None,
1861
+ clean_src_latents: Optional[torch.FloatTensor] = None,
1862
+ repaint_crossfade_frames: int = 10,
1863
+ repaint_injection_ratio: float = 0.5,
1864
+ sampler_mode: str = "euler",
1865
+ velocity_norm_threshold: float = 0.0,
1866
+ velocity_ema_factor: float = 0.0,
1867
  **kwargs,
1868
  ):
1869
+ # Backward-compat: accept the old misspelled key "diffusion_guidance_sale"
1870
+ # so that callers that have not yet updated their code still work correctly.
1871
+ # Note: if both keys are passed simultaneously, the old key wins because Python
1872
+ # cannot distinguish "explicit new key" from "new key at its default value".
1873
+ # In practice callers should only ever pass one of the two.
1874
+ if "diffusion_guidance_sale" in kwargs:
1875
+ logger.warning(
1876
+ "generate_audio() received deprecated kwarg 'diffusion_guidance_sale'; "
1877
+ "please rename it to 'diffusion_guidance_scale'."
1878
+ )
1879
+ diffusion_guidance_scale = kwargs.pop("diffusion_guidance_sale")
1880
+
1881
  if attention_mask is None:
1882
  latent_length = src_latents.shape[1]
1883
  attention_mask = torch.ones(src_latents.shape[0], latent_length, device=src_latents.device, dtype=src_latents.dtype)
 
1941
  bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype
1942
  past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
1943
  momentum_buffer = MomentumBuffer()
1944
+
1945
  # Cover noise initialization: blend noise with src_latents
1946
  if cover_noise_strength > 0.0:
1947
  # cover_noise_strength=1 means closest to src, so noise_level should be low
 
1967
  )
1968
  else:
1969
  xt = noise
1970
+
1971
  # main task condition
1972
+ do_cfg_guidance = diffusion_guidance_scale > 1.0
1973
  if do_cfg_guidance:
1974
  encoder_hidden_states = torch.cat([encoder_hidden_states, self.null_condition_emb.expand_as(encoder_hidden_states)], dim=0)
1975
  encoder_attention_mask = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0)
1976
  # src_latents
1977
  context_latents = torch.cat([context_latents, context_latents], dim=0)
1978
  attention_mask = torch.cat([attention_mask, attention_mask], dim=0)
1979
+
1980
+ use_heun = sampler_mode == "heun"
1981
+ use_norm_clamp = velocity_norm_threshold > 0.0
1982
+ use_ema = velocity_ema_factor > 0.0
1983
+ prev_vt = None
1984
+ if use_heun and infer_method == "sde":
1985
+ logger.warning("Heun sampler is not compatible with SDE; falling back to Euler.")
1986
+ use_heun = False
1987
+
1988
  _switched_to_non_cover = False
1989
  with torch.no_grad():
1990
  for step_idx, (t_curr, t_prev) in enumerate(iterator):
 
2000
  encoder_attention_mask = encoder_attention_mask_non_cover
2001
  context_latents = context_latents_non_cover
2002
  past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
2003
+
2004
  x = torch.cat([xt, xt], dim=0) if do_cfg_guidance else xt
2005
  t_curr_tensor = t_curr * torch.ones((x.shape[0],), device=device, dtype=dtype)
2006
  decoder_outputs = self.decoder(
 
2014
  use_cache=True,
2015
  past_key_values=past_key_values,
2016
  )
2017
+
2018
  vt = decoder_outputs[0]
2019
  past_key_values = decoder_outputs[1]
2020
  apply_cfg_guidance = t_curr >= cfg_interval_start and t_curr <= cfg_interval_end
 
2025
  vt = apg_forward(
2026
  pred_cond=pred_cond,
2027
  pred_uncond=pred_null_cond,
2028
+ guidance_scale=diffusion_guidance_scale,
2029
  momentum_buffer=momentum_buffer,
2030
  dims=[1],
2031
  )
 
2035
  noise_pred_cond=pred_cond,
2036
  noise_pred_uncond=pred_null_cond,
2037
  sigma=t_curr,
2038
+ guidance_scale=diffusion_guidance_scale,
2039
  )
2040
  else:
2041
  vt = pred_cond
2042
+ # Velocity norm clamping — prevents outlier predictions
2043
+ if use_norm_clamp:
2044
+ vt_norm = torch.norm(vt, dim=(1, 2), keepdim=True)
2045
+ xt_norm = torch.norm(xt, dim=(1, 2), keepdim=True) + 1e-10
2046
+ scale = torch.clamp(velocity_norm_threshold * xt_norm / (vt_norm + 1e-10), max=1.0)
2047
+ vt = vt * scale
2048
+
2049
+ # Velocity EMA smoothing — stabilises denoising trajectory
2050
+ if use_ema and prev_vt is not None:
2051
+ vt = (1.0 - velocity_ema_factor) * vt + velocity_ema_factor * prev_vt
2052
+
2053
  # Update x_t based on inference method
2054
  if infer_method == "sde":
2055
  # Stochastic Differential Equation: predict clean, then re-add noise
 
2057
  pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
2058
  next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
2059
  xt = self.renoise(pred_clean, next_timestep)
2060
+ t_after_step = next_timestep
2061
+ elif use_heun and infer_method == "ode":
2062
+ # Heun (second-order) ODE step via trapezoidal rule
2063
+ dt = t_curr - t_prev
2064
+ dt_tensor = dt * torch.ones((bsz,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
2065
+ # Predictor: Euler step to get xt_predicted at t_prev
2066
+ xt_predicted = xt - vt * dt_tensor
2067
+ # Corrector: evaluate model at the predicted point
2068
+ x2 = torch.cat([xt_predicted, xt_predicted], dim=0) if do_cfg_guidance else xt_predicted
2069
+ t_prev_tensor = t_prev * torch.ones((x2.shape[0],), device=device, dtype=dtype)
2070
+ # Reset KV cache for corrector pass (Heun needs fresh evaluation)
2071
+ corrector_kv = EncoderDecoderCache(DynamicCache(), DynamicCache())
2072
+ decoder_outputs2 = self.decoder(
2073
+ hidden_states=x2,
2074
+ timestep=t_prev_tensor,
2075
+ timestep_r=t_prev_tensor,
2076
+ attention_mask=attention_mask,
2077
+ encoder_hidden_states=encoder_hidden_states,
2078
+ encoder_attention_mask=encoder_attention_mask,
2079
+ context_latents=context_latents,
2080
+ use_cache=False,
2081
+ past_key_values=corrector_kv,
2082
+ )
2083
+ vt2 = decoder_outputs2[0]
2084
+ if do_cfg_guidance:
2085
+ pred_cond2, pred_null_cond2 = vt2.chunk(2)
2086
+ # Recompute CFG interval for corrector timestep (t_prev, not t_curr)
2087
+ apply_cfg_corrector = t_prev >= cfg_interval_start and t_prev <= cfg_interval_end
2088
+ if apply_cfg_corrector:
2089
+ if not use_adg:
2090
+ # Use basic CFG for corrector to avoid mutating APG momentum twice per step
2091
+ vt2 = cfg_forward(pred_cond2, pred_null_cond2, diffusion_guidance_scale)
2092
+ elif t_prev > 0:
2093
+ # Guard against sigma=0 which causes NaN in ADG division
2094
+ vt2 = adg_forward(
2095
+ latents=xt_predicted,
2096
+ noise_pred_cond=pred_cond2,
2097
+ noise_pred_uncond=pred_null_cond2,
2098
+ sigma=t_prev,
2099
+ guidance_scale=diffusion_guidance_scale,
2100
+ )
2101
+ else:
2102
+ vt2 = cfg_forward(pred_cond2, pred_null_cond2, diffusion_guidance_scale)
2103
+ else:
2104
+ vt2 = pred_cond2
2105
+ if use_norm_clamp:
2106
+ vt2_norm = torch.norm(vt2, dim=(1, 2), keepdim=True)
2107
+ xt_pred_norm = torch.norm(xt_predicted, dim=(1, 2), keepdim=True) + 1e-10
2108
+ scale2 = torch.clamp(velocity_norm_threshold * xt_pred_norm / (vt2_norm + 1e-10), max=1.0)
2109
+ vt2 = vt2 * scale2
2110
+ if use_ema:
2111
+ vt2 = (1.0 - velocity_ema_factor) * vt2 + velocity_ema_factor * vt
2112
+ # Average the two velocity predictions (trapezoidal rule)
2113
+ vt_avg = 0.5 * (vt + vt2)
2114
+ xt = xt - vt_avg * dt_tensor
2115
+ vt = vt_avg
2116
+ t_after_step = t_prev
2117
  elif infer_method == "ode":
2118
  # Ordinary Differential Equation: Euler method
2119
  # dx/dt = -v, so x_{t+1} = x_t - v_t * dt
2120
  dt = t_curr - t_prev
2121
  dt_tensor = dt * torch.ones((bsz,), device=device, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
2122
  xt = xt - vt * dt_tensor
2123
+ t_after_step = t_prev
2124
+
2125
+ prev_vt = vt
2126
+
2127
+ injection_cutoff = round(repaint_injection_ratio * infer_steps)
2128
+ if repaint_mask is not None and clean_src_latents is not None and step_idx < injection_cutoff:
2129
+ xt = _repaint_step_injection(
2130
+ xt, clean_src_latents, repaint_mask, t_after_step, noise,
2131
+ )
2132
+
2133
  x_gen = xt
2134
+ if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0:
2135
+ x_gen = _repaint_boundary_blend(
2136
+ x_gen, clean_src_latents, repaint_mask, repaint_crossfade_frames,
2137
+ )
2138
+
2139
  end_time = time.time()
2140
  time_costs["diffusion_time_cost"] = end_time - start_time
2141
  time_costs["diffusion_per_step_time_cost"] = time_costs["diffusion_time_cost"] / infer_steps
 
2158
  torch.cuda.manual_seed_all(seed)
2159
  torch.backends.cudnn.deterministic = True
2160
  torch.backends.cudnn.benchmark = False
2161
+
2162
  # Get model dtype and device
2163
  model_dtype = next(model.parameters()).dtype
2164
  device = next(model.parameters()).device
2165
+
2166
  print(f"Testing with dtype: {model_dtype}, device: {device}, seed: {seed}")
2167
+
2168
  # Test data preparation with matching dtype
2169
  text_hidden_states = torch.randn(2, 77, 1024, dtype=model_dtype, device=device)
2170
  text_attention_mask = torch.ones(2, 77, dtype=model_dtype, device=device)
 
2298
  # model = model.float()
2299
  model = model.to("cuda")
2300
  model = model.bfloat16()
2301
+ test_forward(model)