Update modeling_Llamoe.py
Browse files- modeling_Llamoe.py +84 -0
modeling_Llamoe.py
CHANGED
|
@@ -646,11 +646,95 @@ class LlamoeFlashAttention2(LlamoeAttention):
|
|
| 646 |
)
|
| 647 |
|
| 648 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
|
| 650 |
|
| 651 |
LLAMOE_ATTENTION_CLASSES = {
|
| 652 |
"eager": LlamoeAttention,
|
| 653 |
"flash_attention_2": LlamoeFlashAttention2,
|
|
|
|
| 654 |
}
|
| 655 |
|
| 656 |
|
|
|
|
| 646 |
)
|
| 647 |
|
| 648 |
|
| 649 |
+
class LlamoeSdpaAttention(LlamoeAttention):
    """
    Llamoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `LlamoeAttention` as the weights of the module stay untouched. The only changes are on the forward pass to
    adapt to the SDPA API.
    """

    # Adapted from LlamoeAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention with PyTorch's fused SDPA kernel.

        Args:
            hidden_states: input activations, assumed (batch, seq_len, hidden_size) — the
                `.size()` unpack below relies on exactly three dimensions.
            attention_mask: optional additive mask broadcastable to
                (batch, heads, q_len, kv_len).
            position_ids: positions used by the rotary embedding.
            past_key_value: optional KV cache updated in place via `.update(...)`.
            output_attentions: SDPA cannot return attention weights, so this triggers a
                fallback to the eager (manual) implementation.
            use_cache: accepted for interface parity with the other attention classes.
            cache_position: positions in the (possibly static) cache; also used to slice
                the causal mask.

        Returns:
            Tuple of (attn_output, None, past_key_value). The attention-weights slot is
            always None because SDPA does not expose them.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamoeModel is using LlamoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to (batch, heads, seq, head_dim) as expected by SDPA.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # In case static cache is used, it is an instance attribute.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand KV heads to match the number of query heads (grouped-query attention).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None and cache_position is not None:
            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # BUGFIX: when no mask is supplied, SDPA must be told to apply causal masking
        # itself, otherwise prefill/training without an explicit mask attends to future
        # tokens. `q_len > 1` skips it for single-token decode, where causality is moot
        # (and `is_causal=True` with q_len == 1 would be wasteful/incorrect slicing).
        is_causal = causal_mask is None and q_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        # Back to (batch, seq, hidden_size) for the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
|
| 732 |
|
| 733 |
|
| 734 |
# Dispatch table mapping the `attn_implementation` config string to the attention
# class instantiated per decoder layer. All three share LlamoeAttention's weights
# and interface; they differ only in the kernel used for the forward pass.
LLAMOE_ATTENTION_CLASSES = {
    "eager": LlamoeAttention,
    "flash_attention_2": LlamoeFlashAttention2,
    "sdpa": LlamoeSdpaAttention
}
|
| 739 |
|
| 740 |
|