|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys
|
|
|
from typing import Callable, Optional, Tuple
|
|
|
|
|
|
import torch
|
|
|
|
|
|
# NOTE(review): both branches of this interpreter-version gate are empty, so it
# currently has no effect at import time. Presumably version-specific imports
# used to live here -- confirm before removing it.
if sys.version_info >= (3, 11):
    pass
else:
    pass
|
|
|
|
|
|
from transformers.cache_utils import Cache
|
|
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
|
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
|
from transformers.utils import logging
|
|
|
|
|
|
from verl.utils.ulysses import (
|
|
|
gather_heads_scatter_seq,
|
|
|
gather_seq_scatter_heads,
|
|
|
get_ulysses_sequence_parallel_world_size,
|
|
|
validate_ulysses_config,
|
|
|
)
|
|
|
|
|
|
# Module-level logger from the transformers logging utilities; the attention
# forwards below use it to emit one-time warnings via `warning_once`.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
def llama_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Flash-attention forward for a Llama attention module with Ulysses sequence parallelism.

    Adapted from transformers 4.47.1.

    NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1].

    Args:
        hidden_states: 3-D input; its first two sizes are read as (batch, local sequence length).
        attention_mask: Forwarded unchanged to `_flash_attention_forward`.
        position_ids: Used to compute RoPE when `position_embeddings` is None, and also
            forwarded to `_flash_attention_forward`.
        past_key_value: Optional cache; when given, its `update` method is called with the
            rotated key/value states for `self.layer_idx`.
        output_attentions: Ignored; attention weights are never returned.
        use_cache: Not read by this function.
        cache_position: Passed to the cache update through `cache_kwargs`.
        position_embeddings: Optional precomputed `(cos, sin)` pair.
        **kwargs: Forwarded to `_flash_attention_forward`.

    Returns:
        A tuple `(attn_output, None, past_key_value)`, where `attn_output` is the result of
        `self.o_proj` applied to the attention output.
    """
    batch_size, local_seq_len, _ = hidden_states.size()

    def _split_heads(projected: torch.Tensor, n_heads: int) -> torch.Tensor:
        # (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
        return projected.view(batch_size, local_seq_len, n_heads, self.head_dim).transpose(1, 2)

    # Run the three projections first (same order as upstream), then split heads.
    q_proj_out = self.q_proj(hidden_states)
    k_proj_out = self.k_proj(hidden_states)
    v_proj_out = self.v_proj(hidden_states)

    query_states = _split_heads(q_proj_out, self.num_heads)
    key_states = _split_heads(k_proj_out, self.num_key_value_heads)
    value_states = _split_heads(v_proj_out, self.num_key_value_heads)

    # ---- Ulysses: exchange the sequence axis (2) and the head axis (1) across ranks ----
    sp_world_size = get_ulysses_sequence_parallel_world_size()
    use_ulysses = sp_world_size > 1

    if use_ulysses:
        validate_ulysses_config(self.num_heads, sp_world_size)
        # Keep the q -> k -> v call order; these are presumably collective ops -- TODO confirm.
        query_states, key_states, value_states = [
            gather_seq_scatter_heads(states, seq_dim=2, head_dim=1)
            for states in (query_states, key_states, value_states)
        ]

    # Sequence length as seen by attention (after the optional gather above).
    full_q_len = query_states.size(2)

    # ---- Rotary position embeddings ----
    if position_embeddings is not None:
        cos, sin = position_embeddings
    else:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    # ---- KV cache ----
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # Swap dims 1 and 2 back before handing the tensors to `_flash_attention_forward`.
    query_states, key_states, value_states = [
        states.transpose(1, 2) for states in (query_states, key_states, value_states)
    ]

    # Dropout is only active while training.
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # If the states arrive in float32, cast them to the dtype picked below before attention.
    if query_states.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(f"The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in {target_dtype}.")

        query_states, key_states, value_states = [
            states.to(target_dtype) for states in (query_states, key_states, value_states)
        ]

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=getattr(self, "sliding_window", None),
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
        is_causal=self.is_causal,
        **kwargs,
    )

    attn_output = attn_output.reshape(batch_size, full_q_len, -1, self.head_dim).contiguous()

    # ---- Ulysses: reverse exchange (sequence axis is 1 and head axis is 2 here) ----
    if use_ulysses:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)

    # Merge the head axes back together at the local sequence length, then project out.
    attn_output = attn_output.reshape(batch_size, local_seq_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    # Attention weights are never produced on this path.
    return attn_output, None, past_key_value
|
|
|
|
|
|
|
|
|
def llama_attn_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Attention forward for a Llama attention module with Ulysses sequence parallelism.

    Adapted from transformers 4.49.0 for transformers >= 4.48.0.

    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.

    Args:
        hidden_states: 3-D input; its first two sizes are read as (batch, local sequence length).
        position_embeddings: Precomputed `(cos, sin)` pair applied via `apply_rotary_pos_emb`.
        attention_mask: Forwarded unchanged to the selected attention implementation.
        past_key_value: Optional cache; when given, its `update` method is called with the
            rotated key/value states for `self.layer_idx`.
        cache_position: Passed to the cache update through `cache_kwargs`.
        **kwargs: Forwarded to the attention implementation; `output_attentions` is also
            consulted when choosing between SDPA and eager attention.

    Returns:
        A 2-tuple `(attn_output, attn_weights)`. Note that the return annotation lists three
        elements, but only two are returned.
    """
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
    from transformers.models.llama.modeling_llama import eager_attention_forward

    batch_size, local_seq_len, _ = hidden_states.shape

    # Head count is inferred (-1) from the projection width and `self.head_dim`.
    per_head_shape = (batch_size, local_seq_len, -1, self.head_dim)
    query_states = self.q_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(*per_head_shape).transpose(1, 2)

    # ---- Ulysses: exchange the sequence axis (2) and the head axis (1) across ranks ----
    sp_world_size = get_ulysses_sequence_parallel_world_size()
    use_ulysses = sp_world_size > 1

    if use_ulysses:
        validate_ulysses_config(self.config.num_attention_heads, sp_world_size)
        # Keep the q -> k -> v call order; these are presumably collective ops -- TODO confirm.
        query_states, key_states, value_states = [
            gather_seq_scatter_heads(states, seq_dim=2, head_dim=1)
            for states in (query_states, key_states, value_states)
        ]

    # Sequence length as seen by attention (after the optional gather above).
    full_q_len = query_states.size(2)

    # ---- Rotary position embeddings ----
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    # ---- KV cache ----
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # ---- Pick the attention backend ----
    attn_impl = self.config._attn_implementation
    attention_interface: Callable
    if attn_impl == "eager":
        attention_interface = eager_attention_forward
    elif attn_impl == "sdpa" and kwargs.get("output_attentions", False):
        logger.warning_once('`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.')
        attention_interface = eager_attention_forward
    else:
        attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]

    # Dropout is only active while training.
    dropout_rate = self.attention_dropout if self.training else 0.0

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=dropout_rate,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(batch_size, full_q_len, -1, self.head_dim).contiguous()

    # ---- Ulysses: reverse exchange (sequence axis is 1 and head axis is 2 here) ----
    if use_ulysses:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)

    # Merge the head axes back together at the local sequence length, then project out.
    attn_output = attn_output.reshape(batch_size, local_seq_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    return attn_output, attn_weights
|
|
|
|