potsawee committed on
Commit
3c4f262
·
verified ·
1 Parent(s): 45bf5c5

Upload modeling_backbone_components.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_backbone_components.py +751 -0
modeling_backbone_components.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Backbone components for Mimi models - shared attention transformers."""
2
+
3
+ import math
4
+ from typing import Optional, Union
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
10
+ from transformers.masking_utils import create_causal_mask
11
+ from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
12
+ from transformers.modeling_layers import GradientCheckpointingLayer
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast
14
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
15
+ from transformers.utils import logging
16
+
17
+ from configuration_mimi import MimiConfig
18
+ from modeling_mimi_clean import (
19
+ MimiAttention,
20
+ MimiMLP,
21
+ MimiLayerScale,
22
+ MimiRotaryEmbedding,
23
+ apply_rotary_pos_emb,
24
+ MIMI_ATTENTION_CLASSES
25
+ )
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
class CausalAttentionTransformer(nn.Module):
    """
    Decoder-only stack of `config.num_hidden_layers` [`MimiTransformerLayer`] blocks.

    Every block runs self-attention followed by an MLP; there is no cross-attention in this stack.
    A causal mask is built once per forward pass and shared by all blocks.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()

        decoder_blocks = [
            MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_blocks)
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Run the input through every layer of the stack.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input embeddings or hidden states from a previous module.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Padding mask (1 = keep, 0 = masked). It is handed to `create_causal_mask`, whose result is
                what the layers actually receive.
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Position of each token. Defaults to `cache_position.unsqueeze(0)`.
            past_key_values (`Cache` or legacy tuple format, *optional*):
                Key/value cache for sequential decoding. When `use_cache` is true, a missing cache is
                replaced by a fresh `DynamicCache`, and a non-`Cache` value is converted with
                `DynamicCache.from_legacy_cache` (a deprecation warning is logged once).
            use_cache (`bool`, *optional*):
                Whether to maintain and return the key/value cache. Defaults to `config.use_cache`. Forced to
                `False` when gradient checkpointing is active during training.
            output_attentions (`bool`, *optional*):
                Whether to collect the attention weights of every layer. Defaults to
                `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether to collect the input of every layer plus the final output. Defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Return a [`BaseModelOutputWithPast`] instead of a plain tuple. Defaults to
                `config.use_return_dict`.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Absolute positions of the input tokens. Defaults to a range starting at the number of tokens
                already held by `past_key_values` (0 without a cache).

        Returns:
            A [`BaseModelOutputWithPast`], or, when `return_dict` is false, a tuple of the non-`None` values
            among `(last_hidden_state, cache, all_hidden_states, all_attentions)`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # Normalise whatever the caller supplied into a `Cache` instance when caching is requested.
        if use_cache and not isinstance(past_key_values, Cache):
            if past_key_values is not None:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )
            else:
                past_key_values = DynamicCache()

        if cache_position is None:
            # New tokens are positioned right after whatever the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # One causal mask, shared by all layers.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        collected_states = () if output_hidden_states else None
        collected_attns = () if output_attentions else None
        latest_cache = None
        # A layer returns (hidden, [attn_weights], [cache]); the cache follows the optional weights.
        cache_slot = 2 if output_attentions else 1

        for layer in self.layers:
            if output_hidden_states:
                collected_states = collected_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order must match MimiTransformerLayer.forward.
                layer_result = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_result = layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_result[0]

            if use_cache:
                latest_cache = layer_result[cache_slot]

            if output_attentions:
                collected_attns = collected_attns + (layer_result[1],)

        # The output of the last layer completes the list of hidden states.
        if output_hidden_states:
            collected_states = collected_states + (hidden_states,)

        next_cache = latest_cache if use_cache else None

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=collected_states,
                attentions=collected_attns,
            )

        candidates = [hidden_states, next_cache, collected_states, collected_attns]
        return tuple(item for item in candidates if item is not None)
206
+
207
+
208
class MimiTransformerLayer(GradientCheckpointingLayer):
    """
    Pre-norm transformer block: LayerNorm -> self-attention -> layer scale -> residual add, then
    LayerNorm -> MLP -> layer scale -> residual add.
    """

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # The attention implementation is chosen by the config.
        attention_cls = MIMI_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        self.mlp = MimiMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Apply the self-attention and feed-forward sub-blocks, each with its own residual connection.

        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): mask forwarded unchanged to the attention module
            position_ids (`torch.LongTensor`, *optional*): position indices forwarded to the attention module
            past_key_value (`Cache`, *optional*): key/value cache forwarded to the attention module
            output_attentions (`bool`, *optional*): include the attention weights in the returned tuple
            use_cache (`bool`, *optional*): include the cache returned by the attention module in the tuple
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                positions of the input tokens in the sequence
            kwargs (`dict`, *optional*): extra keyword arguments passed through to the attention module

        Returns:
            `(hidden_states,)`, followed by the attention weights when `output_attentions` is true, followed
            by the cache when `use_cache` is true.
        """
        # Self-attention sub-block (pre-norm); the residual is the un-normalised input.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.self_attn_layer_scale(attn_output)

        # Feed-forward sub-block (pre-norm).
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.mlp_layer_scale(mlp_output)

        extras = ()
        if output_attentions:
            extras = extras + (self_attn_weights,)
        if use_cache:
            extras = extras + (present_key_value,)

        return (hidden_states,) + extras
283
+
284
+
285
class CrossAttention(nn.Module):
    """
    Cross-attention layer with optional monotonic masking.

    Queries are projected from the decoder hidden states; keys and values are projected from the encoder
    hidden states. When `alignment_chunk_sizes` is given, each query may only attend to a progressively
    growing prefix of the keys (see `_create_monotonic_attention_mask`).
    """

    def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # NOTE: this flag is not read anywhere in this class; `forward` applies only the masks it is given
        # (monotonic mask and/or `attention_mask`).
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Query projection for decoder hidden states
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        # Key and value projections for encoder hidden states
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # Rotary embeddings are applied to the queries only (decoder positions)
        self.rotary_emb = MimiRotaryEmbedding(config)

    def forward(
        self,
        hidden_states: torch.Tensor,  # Decoder hidden states (queries)
        encoder_hidden_states: torch.Tensor,  # Encoder hidden states (keys, values)
        attention_mask: Optional[torch.Tensor] = None,  # Mask for encoder positions
        position_ids: Optional[torch.LongTensor] = None,  # Decoder position IDs
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,  # Monotonic attention chunk sizes
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """
        Attend from decoder positions to encoder positions.

        Args:
            hidden_states (`torch.Tensor`): decoder states of shape `(batch, q_len, hidden_size)`.
            encoder_hidden_states (`torch.Tensor`): encoder states of shape `(batch, kv_len, hidden_size)`.
            attention_mask (`torch.Tensor`, *optional*): additive mask that must broadcast against the
                attention scores `(batch, num_heads, q_len, total_key_len)`.
            position_ids (`torch.LongTensor`, *optional*): decoder positions; when given, rotary embeddings
                are applied to the queries (keys are left un-rotated).
            past_key_value (`Cache`, *optional*): when given, the freshly projected keys/values are passed to
                `past_key_value.update(...)` for `self.layer_idx` and the returned tensors are used for
                attention. This happens regardless of `use_cache`.
            output_attentions (`bool`): return the post-softmax attention weights instead of `None`.
            use_cache (`bool`): accepted for interface symmetry; not read by this method.
            cache_position (`torch.LongTensor`, *optional*): forwarded to the cache update.
            alignment_chunk_sizes (`torch.Tensor`, *optional*): per-query chunk sizes used to build the
                monotonic mask.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)` where `attn_output` has shape
            `(batch, q_len, hidden_size)`.

        Raises:
            ValueError: if the attention output does not have shape `(batch, num_heads, q_len, head_dim)`.
        """
        bsz, q_len, _ = hidden_states.size()
        _, kv_len, _ = encoder_hidden_states.size()

        # Queries from decoder, keys and values from encoder
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(encoder_hidden_states)
        value_states = self.v_proj(encoder_hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings only to queries (decoder positions)
        cos = sin = None
        if position_ids is not None:
            cos, sin = self.rotary_emb(value_states, position_ids)
            # The helper takes a (q, k) pair; the queries are passed twice and the second result is dropped.
            query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # After the cache update the keys may cover more positions than this call's encoder input,
        # so every mask must be sized from the tensors actually used for attention.
        total_kv_len = key_states.shape[-2]

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        # Apply monotonic attention mask if alignment_chunk_sizes is provided
        if alignment_chunk_sizes is not None:
            monotonic_mask = _create_monotonic_attention_mask(
                alignment_chunk_sizes=alignment_chunk_sizes,
                query_length=q_len,
                key_length=total_kv_len,
                device=attn_weights.device,
                dtype=attn_weights.dtype,
            )
            attn_weights = attn_weights + monotonic_mask

        # Apply additional attention mask for encoder positions (if provided)
        if attention_mask is not None:
            # Expected shape: [batch_size, 1, 1, key_len] or [batch_size, 1, q_len, key_len]
            attn_weights = attn_weights + attention_mask

        # Softmax in fp32 for numerical stability, then cast back to the query dtype
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        expected_shape = (bsz, self.num_heads, q_len, self.head_dim)
        if attn_output.size() != expected_shape:
            raise ValueError(
                f"`attn_output` should be of size {expected_shape}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, q_len, head_dim) -> (batch, q_len, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
408
+
409
+
410
class CrossAttentionLayer(GradientCheckpointingLayer):
    """
    Transformer block with three pre-norm sub-blocks, each followed by a layer scale and a residual add:
    decoder self-attention, cross-attention to the encoder states, and an MLP.
    """

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Decoder self-attention; the implementation is selected by the config
        self_attention_cls = MIMI_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = self_attention_cls(config=config, layer_idx=layer_idx)

        # Attention from decoder positions to encoder positions
        self.cross_attn = CrossAttention(config=config, layer_idx=layer_idx)

        self.mlp = MimiMLP(config)

        # One LayerNorm in front of each sub-block
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_cross_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)

        # One layer scale behind each sub-block
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.cross_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,  # Decoder hidden states
        encoder_hidden_states: torch.Tensor,  # Encoder hidden states
        attention_mask: Optional[torch.Tensor] = None,  # Causal mask for self-attention
        encoder_attention_mask: Optional[torch.Tensor] = None,  # Mask for encoder positions
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        cross_past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,  # Monotonic attention chunk sizes
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run self-attention, cross-attention and the MLP in sequence.

        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape `(batch, seq_len, embed_dim)`
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch, encoder_seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): mask passed to the self-attention module
            encoder_attention_mask (`torch.FloatTensor`, *optional*): mask passed to the cross-attention module
            position_ids (`torch.LongTensor`, *optional*): decoder positions, passed to both attention modules
            past_key_value (`Cache`, *optional*): cache passed to the self-attention module
            cross_past_key_value (`Cache`, *optional*): cache passed to the cross-attention module
            output_attentions (`bool`, *optional*): include both attention-weight tensors in the result
            use_cache (`bool`, *optional*): include both caches in the result
            cache_position (`torch.LongTensor`, *optional*): cache positions, passed to both attention modules
            alignment_chunk_sizes (`torch.Tensor`, *optional*): monotonic chunk sizes, passed to cross-attention
            kwargs: extra keyword arguments, forwarded to the self-attention module only

        Returns:
            `(hidden_states,)`, then `(self_attn_weights, cross_attn_weights)` when `output_attentions` is
            true, then `(self_attn_cache, cross_attn_cache)` when `use_cache` is true.
        """
        # --- Sub-block 1: decoder self-attention (pre-norm) ---
        normed = self.input_layernorm(hidden_states)
        self_attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.self_attn_layer_scale(self_attn_out)

        # --- Sub-block 2: cross-attention to the encoder (pre-norm) ---
        normed = self.post_attention_layernorm(hidden_states)
        cross_attn_out, cross_attn_weights, cross_present_key_value = self.cross_attn(
            hidden_states=normed,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            position_ids=position_ids,
            past_key_value=cross_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            alignment_chunk_sizes=alignment_chunk_sizes,
        )
        hidden_states = hidden_states + self.cross_attn_layer_scale(cross_attn_out)

        # --- Sub-block 3: feed-forward network (pre-norm) ---
        mlp_out = self.mlp(self.post_cross_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.mlp_layer_scale(mlp_out)

        extras = ()
        if output_attentions:
            extras = extras + (self_attn_weights, cross_attn_weights)
        if use_cache:
            extras = extras + (present_key_value, cross_present_key_value)

        return (hidden_states,) + extras
516
+
517
+
518
class CrossAttentionTransformer(nn.Module):
    """
    Cross-attention transformer consisting of `config.num_hidden_layers` [`CrossAttentionLayer`] blocks.
    Each layer performs self-attention on the decoder states and cross-attention to the encoder states.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()

        self.layers = nn.ModuleList(
            [CrossAttentionLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation

        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,  # Decoder hidden states
        encoder_hidden_states: torch.Tensor,  # Encoder hidden states
        attention_mask: Optional[torch.Tensor] = None,  # Causal mask for decoder
        encoder_attention_mask: Optional[torch.Tensor] = None,  # Mask for encoder
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        cross_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,  # Monotonic attention chunk sizes
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape `(batch_size, decoder_sequence_length, hidden_size)`
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch_size, encoder_sequence_length, hidden_size)`
            attention_mask (`torch.Tensor`, *optional*): padding mask for the decoder; it is turned into the
                causal self-attention mask by `create_causal_mask`
            encoder_attention_mask (`torch.Tensor`, *optional*): attention mask for encoder positions
            position_ids (`torch.LongTensor`, *optional*): position IDs for decoder; defaults to
                `cache_position.unsqueeze(0)`
            past_key_values (`Cache` or `list`, *optional*): cached self-attention states. The whole cache
                object is handed to every layer; each attention module addresses its own slot through its
                `layer_idx`.
            cross_past_key_values (`Cache` or `list`, *optional*): cached cross-attention states, handled
                the same way as `past_key_values`
            use_cache (`bool`, *optional*): whether to use caching. When true, missing caches are created as
                `DynamicCache` and non-`Cache` inputs are converted with `DynamicCache.from_legacy_cache`.
                Forced to `False` when gradient checkpointing is active during training.
            output_attentions (`bool`, *optional*): whether to return attention weights
            output_hidden_states (`bool`, *optional*): whether to return hidden states
            return_dict (`bool`, *optional*): whether to return ModelOutput
            cache_position (`torch.LongTensor`, *optional*): cache positions; defaults to a range starting at
                the number of tokens already held by `past_key_values`
            alignment_chunk_sizes (`torch.Tensor`, *optional*): chunk sizes specifying how many encoder
                positions each decoder position can attend to cumulatively. Enables monotonic attention
                where decoder position i can attend to encoder positions 0 through
                sum(alignment_chunk_sizes[:i+1])-1.

        Returns:
            With `return_dict` false: a tuple of the non-`None` values among `(last_hidden_state,
            self_attn_cache, cross_attn_cache, all_hidden_states, all_self_attentions,
            all_cross_attentions)`.
            With `return_dict` true: a [`BaseModelOutputWithPast`]. Note that this container has no field
            for the cross-attention cache or the cross-attention weights, so they are not part of the
            returned object; pass your own `cross_past_key_values` instance if you need to keep it.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Same guard as CausalAttentionTransformer: caching and gradient checkpointing do not mix.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if use_cache:
            if past_key_values is None:
                logger.warning_once(
                    "use_cache=True was passed, but no past_key_values were given. Creating new cache."
                )
                past_key_values = DynamicCache()
            elif not isinstance(past_key_values, Cache):
                # The attention modules call `.update(...)`, so a legacy tuple/list must be wrapped first.
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

            if cross_past_key_values is None:
                cross_past_key_values = DynamicCache()
            elif not isinstance(cross_past_key_values, Cache):
                cross_past_key_values = DynamicCache.from_legacy_cache(cross_past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Create causal mask for decoder self-attention
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # Initialize output containers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions else None
        next_decoder_cache = None
        next_cross_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # The full cache objects are passed down (not `cache[layer_idx]`): indexing an empty cache
            # fails, and the attention modules need the `Cache` API, which they address via `layer_idx`.
            if self.gradient_checkpointing and self.training:
                # Positional order must match CrossAttentionLayer.forward.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    encoder_hidden_states,
                    causal_mask,
                    encoder_attention_mask,
                    position_ids,
                    past_key_values,
                    cross_past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    alignment_chunk_sizes,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=causal_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    cross_past_key_value=cross_past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    alignment_chunk_sizes=alignment_chunk_sizes,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Layer output layout: (hidden, [self_attn, cross_attn], [self_cache, cross_cache])
                if output_attentions:
                    next_decoder_cache = layer_outputs[3]  # self attn cache
                    next_cross_cache = layer_outputs[4]  # cross attn cache
                else:
                    next_decoder_cache = layer_outputs[1]  # self attn cache
                    next_cross_cache = layer_outputs[2]  # cross attn cache

            if output_attentions:
                all_self_attns += (layer_outputs[1],)  # self attention weights
                all_cross_attns += (layer_outputs[2],)  # cross attention weights

        # Add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        next_cross_cache = next_cross_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    next_cross_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attns,
                ]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
681
+
682
+
683
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input of shape
    `(batch, num_key_value_heads, seqlen, head_dim)` becomes
    `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`. With `n_rep == 1` the input tensor itself
    is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
693
+
694
+
695
def _create_monotonic_attention_mask(
    alignment_chunk_sizes: torch.Tensor,
    query_length: int,
    key_length: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Create a monotonic attention mask where each query can only attend to a progressive subset of keys.

    Query `i` may attend to keys `0 .. min(sum(alignment_chunk_sizes[..., :i+1]), key_length) - 1`.

    Args:
        alignment_chunk_sizes: Tensor of shape (batch_size, query_length), or (query_length,) for a single
            un-batched alignment (treated as batch size 1). Each element is the number of additional keys
            the corresponding query gains access to.
        query_length: Number of queries. Currently not used: the query dimension of the mask is taken from
            `alignment_chunk_sizes`.
        key_length: Number of keys; cumulative sizes are clamped to this value.
        device: Device to create the mask on
        dtype: Data type for the mask (expected to be a floating-point dtype)

    Returns:
        Attention mask of shape (batch_size, 1, query_length, key_length) where
        -inf masks out invalid positions, 0.0 allows attention. A query whose cumulative size is 0 gets a
        row that is entirely -inf.
    """
    # Accept an un-batched alignment by treating it as a batch of one.
    if alignment_chunk_sizes.dim() == 1:
        alignment_chunk_sizes = alignment_chunk_sizes.unsqueeze(0)

    # Work on the requested device so the comparison below never mixes devices.
    chunk_sizes = alignment_chunk_sizes.to(device)

    # Exclusive upper bound of the keys visible to each query: [batch_size, query_length, 1]
    visible_upto = torch.cumsum(chunk_sizes, dim=1).clamp(max=key_length).unsqueeze(2)

    # Key indices: [1, 1, key_length]
    key_positions = torch.arange(key_length, device=device).view(1, 1, key_length)

    # True where attention is allowed: [batch_size, query_length, key_length]
    allowed = key_positions < visible_upto

    # Additive mask: 0.0 where allowed, -inf elsewhere
    mask = torch.zeros(allowed.shape, dtype=dtype, device=device).masked_fill(~allowed, float("-inf"))

    # Add head dimension: [batch_size, 1, query_length, key_length]
    return mask.unsqueeze(1)
741
+
742
+
743
+
744
# Names exported by a star-import of this module. The private helper
# `_create_monotonic_attention_mask` is exported on purpose via this list;
# `repeat_kv`, although defined in this file, is not listed here.
__all__ = [
    "CausalAttentionTransformer",
    "MimiTransformerLayer",
    "CrossAttention",
    "CrossAttentionLayer",
    "CrossAttentionTransformer",
    "_create_monotonic_attention_mask",
]