multimodalart HF Staff Claude Opus 4.6 committed on
Commit
4a3004b
·
1 Parent(s): e36805d

Fix LlamaAttention compatibility with transformers 4.53.0

Browse files

The transformers 4.53 API requires `position_embeddings` (cos/sin rotary
tuple) as a positional arg to LlamaAttention.forward() and returns 2
values instead of 3. Update DiffLlama to compute rotary embeddings and
pass them through the decoder layers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. soulxsinger/models/modules/llama.py +26 -54
soulxsinger/models/modules/llama.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
4
  from typing import List, Optional, Tuple, Union
5
  import math
6
 
7
- from transformers.models.llama.modeling_llama import LlamaDecoderLayer
8
  from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
9
 
10
 
@@ -62,27 +62,13 @@ class LlamaNARDecoderLayer(LlamaDecoderLayer):
62
  hidden_states: torch.Tensor,
63
  cond_embedding: torch.Tensor,
64
  attention_mask: Optional[torch.Tensor] = None,
65
- position_ids: Optional[torch.LongTensor] = None,
66
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
67
  output_attentions: Optional[bool] = False,
68
  use_cache: Optional[bool] = False,
69
  ) -> Tuple[
70
  torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
71
  ]:
72
- """
73
- Args:
74
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
75
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
76
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
77
- output_attentions (`bool`, *optional*):
78
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
79
- returned tensors for more detail.
80
- use_cache (`bool`, *optional*):
81
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
82
- (see `past_key_values`).
83
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
84
- """
85
-
86
  residual = hidden_states
87
 
88
  hidden_states = self.input_layernorm(
@@ -90,13 +76,11 @@ class LlamaNARDecoderLayer(LlamaDecoderLayer):
90
  )
91
 
92
  # Self Attention
93
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
94
  hidden_states=hidden_states,
 
95
  attention_mask=attention_mask,
96
- position_ids=position_ids,
97
  past_key_value=past_key_value,
98
- output_attentions=output_attentions,
99
- use_cache=use_cache,
100
  )
101
  hidden_states = residual + hidden_states
102
 
@@ -113,9 +97,6 @@ class LlamaNARDecoderLayer(LlamaDecoderLayer):
113
  if output_attentions:
114
  outputs += (self_attn_weights,)
115
 
116
- if use_cache:
117
- outputs += (present_key_value,)
118
-
119
  return outputs
120
 
121
 
@@ -185,6 +166,15 @@ class DiffLlama(LlamaModel):
185
 
186
  self.embed_tokens = None
187
 
 
 
 
 
 
 
 
 
 
188
  self.post_init()
189
 
190
  # self.reset_parameters()
@@ -309,6 +299,9 @@ class DiffLlama(LlamaModel):
309
 
310
  hidden_states = inputs_embeds
311
 
 
 
 
312
  if self.gradient_checkpointing and self.training:
313
  if use_cache:
314
  use_cache = False
@@ -328,40 +321,19 @@ class DiffLlama(LlamaModel):
328
  past_key_values[idx] if past_key_values is not None else None
329
  )
330
 
331
- if self.gradient_checkpointing and self.training:
332
- raise NotImplementedError
333
-
334
- def create_custom_forward(module):
335
- def custom_forward(*inputs):
336
- # None for past_key_value
337
- return module(*inputs, output_attentions, None)
338
-
339
- return custom_forward
340
-
341
- layer_outputs = torch.utils.checkpoint.checkpoint(
342
- create_custom_forward(decoder_layer),
343
- hidden_states,
344
- attention_mask,
345
- position_ids,
346
- None,
347
- )
348
- else:
349
- layer_outputs = decoder_layer(
350
- hidden_states,
351
- attention_mask=attention_mask,
352
- position_ids=position_ids,
353
- past_key_value=past_key_value,
354
- output_attentions=output_attentions,
355
- use_cache=use_cache,
356
- cond_embedding=diffusion_step,
357
- )
358
 
359
  hidden_states = layer_outputs[0]
360
  all_layer_hidden_states.append(hidden_states.clone())
361
 
362
- if use_cache:
363
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
364
-
365
  if output_attentions:
366
  all_self_attns += (layer_outputs[1],)
367
 
 
4
  from typing import List, Optional, Tuple, Union
5
  import math
6
 
7
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding
8
  from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
9
 
10
 
 
62
  hidden_states: torch.Tensor,
63
  cond_embedding: torch.Tensor,
64
  attention_mask: Optional[torch.Tensor] = None,
65
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
66
+ past_key_value=None,
67
  output_attentions: Optional[bool] = False,
68
  use_cache: Optional[bool] = False,
69
  ) -> Tuple[
70
  torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
71
  ]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  residual = hidden_states
73
 
74
  hidden_states = self.input_layernorm(
 
76
  )
77
 
78
  # Self Attention
79
+ hidden_states, self_attn_weights = self.self_attn(
80
  hidden_states=hidden_states,
81
+ position_embeddings=position_embeddings,
82
  attention_mask=attention_mask,
 
83
  past_key_value=past_key_value,
 
 
84
  )
85
  hidden_states = residual + hidden_states
86
 
 
97
  if output_attentions:
98
  outputs += (self_attn_weights,)
99
 
 
 
 
100
  return outputs
101
 
102
 
 
166
 
167
  self.embed_tokens = None
168
 
169
+ # Re-create rotary_emb with the actual layer config dimensions
170
+ layer_config = LlamaConfig(
171
+ hidden_size=hidden_size,
172
+ num_attention_heads=num_heads,
173
+ max_position_embeddings=4096,
174
+ intermediate_size=hidden_size * 4,
175
+ )
176
+ self.rotary_emb = LlamaRotaryEmbedding(config=layer_config)
177
+
178
  self.post_init()
179
 
180
  # self.reset_parameters()
 
299
 
300
  hidden_states = inputs_embeds
301
 
302
+ # Compute rotary position embeddings
303
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
304
+
305
  if self.gradient_checkpointing and self.training:
306
  if use_cache:
307
  use_cache = False
 
321
  past_key_values[idx] if past_key_values is not None else None
322
  )
323
 
324
+ layer_outputs = decoder_layer(
325
+ hidden_states,
326
+ attention_mask=attention_mask,
327
+ position_embeddings=position_embeddings,
328
+ past_key_value=past_key_value,
329
+ output_attentions=output_attentions,
330
+ use_cache=use_cache,
331
+ cond_embedding=diffusion_step,
332
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
  hidden_states = layer_outputs[0]
335
  all_layer_hidden_states.append(hidden_states.clone())
336
 
 
 
 
337
  if output_attentions:
338
  all_self_attns += (layer_outputs[1],)
339