Add SDPA attention
#2
by
Katsumata420
- opened
- modeling_retrieva_bert.py +173 -19
modeling_retrieva_bert.py
CHANGED
|
@@ -34,6 +34,7 @@ from typing import Optional, Tuple, Union
|
|
| 34 |
|
| 35 |
import torch
|
| 36 |
import torch.utils.checkpoint
|
|
|
|
| 37 |
from torch import nn
|
| 38 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 39 |
|
|
@@ -49,6 +50,10 @@ from transformers.modeling_outputs import (
|
|
| 49 |
SequenceClassifierOutput,
|
| 50 |
TokenClassifierOutput,
|
| 51 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
from transformers.modeling_utils import PreTrainedModel
|
| 53 |
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
| 54 |
from transformers.utils import (
|
|
@@ -56,6 +61,7 @@ from transformers.utils import (
|
|
| 56 |
add_code_sample_docstrings,
|
| 57 |
add_start_docstrings,
|
| 58 |
add_start_docstrings_to_model_forward,
|
|
|
|
| 59 |
logging,
|
| 60 |
replace_return_docstrings,
|
| 61 |
)
|
|
@@ -407,6 +413,113 @@ class RetrievaBertSelfAttention(nn.Module):
|
|
| 407 |
return outputs
|
| 408 |
|
| 409 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
# Based transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to RetrievaBertAttention below.
|
| 411 |
class RetrievaBertSelfOutput(nn.Module):
|
| 412 |
def __init__(self, config):
|
|
@@ -420,12 +533,18 @@ class RetrievaBertSelfOutput(nn.Module):
|
|
| 420 |
return residual + hidden_states
|
| 421 |
|
| 422 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
# Based transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm.
|
| 424 |
class RetrievaBertAttention(nn.Module):
|
| 425 |
def __init__(self, config):
|
| 426 |
super().__init__()
|
| 427 |
self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 428 |
-
self.self =
|
| 429 |
self.output = RetrievaBertSelfOutput(config)
|
| 430 |
self.pruned_heads = set()
|
| 431 |
|
|
@@ -808,6 +927,7 @@ class RetrievaBertPreTrainedModel(PreTrainedModel):
|
|
| 808 |
load_tf_weights = load_tf_weights_in_megatron_bert
|
| 809 |
base_model_prefix = "bert"
|
| 810 |
supports_gradient_checkpointing = True
|
|
|
|
| 811 |
|
| 812 |
def _init_weights(self, module):
|
| 813 |
"""Initialize the weights"""
|
|
@@ -953,6 +1073,8 @@ class RetrievaBertModel(RetrievaBertPreTrainedModel):
|
|
| 953 |
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 954 |
)
|
| 955 |
|
|
|
|
|
|
|
| 956 |
# Initialize weights and apply final processing
|
| 957 |
self.post_init()
|
| 958 |
|
|
@@ -1046,9 +1168,48 @@ class RetrievaBertModel(RetrievaBertPreTrainedModel):
|
|
| 1046 |
if position_ids is None:
|
| 1047 |
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1052 |
|
| 1053 |
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 1054 |
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
|
@@ -1057,24 +1218,17 @@ class RetrievaBertModel(RetrievaBertPreTrainedModel):
|
|
| 1057 |
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 1058 |
if encoder_attention_mask is None:
|
| 1059 |
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 1060 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1061 |
else:
|
| 1062 |
encoder_extended_attention_mask = None
|
| 1063 |
|
| 1064 |
-
# Prepare head mask if needed
|
| 1065 |
-
# 1.0 in head_mask indicate we keep the head
|
| 1066 |
-
# attention_probs has shape bsz x n_heads x N x N
|
| 1067 |
-
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 1068 |
-
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 1069 |
-
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 1070 |
-
|
| 1071 |
-
embedding_output = self.embeddings(
|
| 1072 |
-
input_ids=input_ids,
|
| 1073 |
-
position_ids=position_ids,
|
| 1074 |
-
token_type_ids=token_type_ids,
|
| 1075 |
-
inputs_embeds=inputs_embeds,
|
| 1076 |
-
past_key_values_length=past_key_values_length,
|
| 1077 |
-
)
|
| 1078 |
encoder_outputs = self.encoder(
|
| 1079 |
embedding_output,
|
| 1080 |
attention_mask=extended_attention_mask,
|
|
|
|
| 34 |
|
| 35 |
import torch
|
| 36 |
import torch.utils.checkpoint
|
| 37 |
+
from packaging import version
|
| 38 |
from torch import nn
|
| 39 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 40 |
|
|
|
|
| 50 |
SequenceClassifierOutput,
|
| 51 |
TokenClassifierOutput,
|
| 52 |
)
|
| 53 |
+
from transformers.modeling_attn_mask_utils import (
|
| 54 |
+
_prepare_4d_attention_mask_for_sdpa,
|
| 55 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
| 56 |
+
)
|
| 57 |
from transformers.modeling_utils import PreTrainedModel
|
| 58 |
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
| 59 |
from transformers.utils import (
|
|
|
|
| 61 |
add_code_sample_docstrings,
|
| 62 |
add_start_docstrings,
|
| 63 |
add_start_docstrings_to_model_forward,
|
| 64 |
+
get_torch_version,
|
| 65 |
logging,
|
| 66 |
replace_return_docstrings,
|
| 67 |
)
|
|
|
|
| 413 |
return outputs
|
| 414 |
|
| 415 |
|
| 416 |
+
class RetrievaBertSdpaSelfAttention(RetrievaBertSelfAttention):
|
| 417 |
+
def __init__(self, config):
|
| 418 |
+
super().__init__(config)
|
| 419 |
+
self.dropout_prob = config.attention_probs_dropout_prob
|
| 420 |
+
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
|
| 421 |
+
|
| 422 |
+
def forward(
|
| 423 |
+
self,
|
| 424 |
+
hidden_states: torch.Tensor,
|
| 425 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 426 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 427 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 428 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 429 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 430 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 431 |
+
output_attentions: Optional[bool] = False,
|
| 432 |
+
) -> Tuple[torch.Tensor]:
|
| 433 |
+
if output_attentions or head_mask is not None:
|
| 434 |
+
logger.warning_once(
|
| 435 |
+
"RetrievaBertSdpaSelfAttention is used but `torch.nn.fuctional.scaled_dot_product_attention` does not support "
|
| 436 |
+
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation. "
|
| 437 |
+
)
|
| 438 |
+
return super().forward(
|
| 439 |
+
hidden_states,
|
| 440 |
+
attention_mask,
|
| 441 |
+
position_ids,
|
| 442 |
+
head_mask,
|
| 443 |
+
encoder_hidden_states,
|
| 444 |
+
encoder_attention_mask,
|
| 445 |
+
past_key_value,
|
| 446 |
+
output_attentions,
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
bsz, tgt_len, _ = hidden_states.size()
|
| 450 |
+
|
| 451 |
+
mixed_query_layer = self.query(hidden_states)
|
| 452 |
+
query_layer = self.transpose_for_scores(mixed_query_layer, is_query=True)
|
| 453 |
+
|
| 454 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 455 |
+
# and values come from an encoder; the attention mask needs to be
|
| 456 |
+
# such that the encoder's padding tokens are not attended to.
|
| 457 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 458 |
+
|
| 459 |
+
# The following code is based on the implementation of `transformers.BertSdpaSelfAttention`
|
| 460 |
+
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
| 461 |
+
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
|
| 462 |
+
|
| 463 |
+
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
|
| 464 |
+
key_layer, value_layer = past_key_value
|
| 465 |
+
else:
|
| 466 |
+
key_layer = self.transpose_for_scores(self.key(current_states), is_query=False)
|
| 467 |
+
value_layer = self.transpose_for_scores(self.value(current_states), is_query=False)
|
| 468 |
+
|
| 469 |
+
if self.rope_emb is not None:
|
| 470 |
+
cos, sin = self.rope_emb(hidden_states, position_ids)
|
| 471 |
+
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
|
| 472 |
+
|
| 473 |
+
if past_key_value is not None and not is_cross_attention:
|
| 474 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
| 475 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
| 476 |
+
|
| 477 |
+
# For GQA, we repeat the key/value weights.
|
| 478 |
+
key_layer = repeat_kv(key_layer, self.num_key_value_groups)
|
| 479 |
+
value_layer = repeat_kv(value_layer, self.num_key_value_groups)
|
| 480 |
+
|
| 481 |
+
if self.is_decoder:
|
| 482 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
| 483 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
| 484 |
+
# key/value_states (first "if" case)
|
| 485 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
| 486 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 487 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 488 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 489 |
+
past_key_value = (key_layer, value_layer)
|
| 490 |
+
|
| 491 |
+
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
| 492 |
+
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
|
| 493 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577
|
| 494 |
+
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
|
| 495 |
+
query_layer = query_layer.contiguous()
|
| 496 |
+
key_layer = key_layer.contiguous()
|
| 497 |
+
value_layer = value_layer.contiguous()
|
| 498 |
+
|
| 499 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 500 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 501 |
+
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
|
| 502 |
+
# a causal mask in case tgt_len == 1.
|
| 503 |
+
is_causal = (
|
| 504 |
+
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 508 |
+
query_layer,
|
| 509 |
+
key_layer,
|
| 510 |
+
value_layer,
|
| 511 |
+
attn_mask=attention_mask,
|
| 512 |
+
is_causal=is_causal,
|
| 513 |
+
dropout_p=self.dropout_prob if self.training else 0.0,
|
| 514 |
+
)
|
| 515 |
+
attn_output = attn_output.transpose(1, 2)
|
| 516 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
|
| 517 |
+
|
| 518 |
+
outputs = (attn_output,)
|
| 519 |
+
if self.is_decoder:
|
| 520 |
+
outputs = outputs + (past_key_value,)
|
| 521 |
+
return outputs
|
| 522 |
+
|
| 523 |
# Based transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to RetrievaBertAttention below.
|
| 524 |
class RetrievaBertSelfOutput(nn.Module):
|
| 525 |
def __init__(self, config):
|
|
|
|
| 533 |
return residual + hidden_states
|
| 534 |
|
| 535 |
|
| 536 |
+
RETRIEVA_BERT_SELF_ATTENTION_CLASSES = {
|
| 537 |
+
"eager": RetrievaBertSelfAttention,
|
| 538 |
+
"sdpa": RetrievaBertSdpaSelfAttention,
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
|
| 542 |
# Based transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm.
|
| 543 |
class RetrievaBertAttention(nn.Module):
|
| 544 |
def __init__(self, config):
|
| 545 |
super().__init__()
|
| 546 |
self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 547 |
+
self.self = RETRIEVA_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](config)
|
| 548 |
self.output = RetrievaBertSelfOutput(config)
|
| 549 |
self.pruned_heads = set()
|
| 550 |
|
|
|
|
| 927 |
load_tf_weights = load_tf_weights_in_megatron_bert
|
| 928 |
base_model_prefix = "bert"
|
| 929 |
supports_gradient_checkpointing = True
|
| 930 |
+
_supports_sdpa = True
|
| 931 |
|
| 932 |
def _init_weights(self, module):
|
| 933 |
"""Initialize the weights"""
|
|
|
|
| 1073 |
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 1074 |
)
|
| 1075 |
|
| 1076 |
+
self.attn_implementation = config._attn_implementation
|
| 1077 |
+
|
| 1078 |
# Initialize weights and apply final processing
|
| 1079 |
self.post_init()
|
| 1080 |
|
|
|
|
| 1168 |
if position_ids is None:
|
| 1169 |
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 1170 |
|
| 1171 |
+
embedding_output = self.embeddings(
|
| 1172 |
+
input_ids=input_ids,
|
| 1173 |
+
position_ids=position_ids,
|
| 1174 |
+
token_type_ids=token_type_ids,
|
| 1175 |
+
inputs_embeds=inputs_embeds,
|
| 1176 |
+
past_key_values_length=past_key_values_length,
|
| 1177 |
+
)
|
| 1178 |
+
|
| 1179 |
+
# Prepare head mask if needed
|
| 1180 |
+
# 1.0 in head_mask indicate we keep the head
|
| 1181 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 1182 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 1183 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 1184 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 1185 |
+
|
| 1186 |
+
use_sdpa_attention_masks = (
|
| 1187 |
+
self.attn_implementation == "adpa"
|
| 1188 |
+
and head_mask is None
|
| 1189 |
+
and not output_attentions
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
extended_attention_mask: torch.Tensor
|
| 1193 |
+
if use_sdpa_attention_masks:
|
| 1194 |
+
# Expand the attention mask for SDPA.
|
| 1195 |
+
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
| 1196 |
+
if self.config.is_decoder:
|
| 1197 |
+
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
| 1198 |
+
attention_mask,
|
| 1199 |
+
input_shape,
|
| 1200 |
+
embedding_output,
|
| 1201 |
+
past_key_values_length,
|
| 1202 |
+
)
|
| 1203 |
+
else:
|
| 1204 |
+
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
| 1205 |
+
attention_mask,
|
| 1206 |
+
embedding_output.dtype,
|
| 1207 |
+
tgt_len=seq_length,
|
| 1208 |
+
)
|
| 1209 |
+
else:
|
| 1210 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 1211 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 1212 |
+
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
| 1213 |
|
| 1214 |
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 1215 |
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
|
|
|
| 1218 |
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 1219 |
if encoder_attention_mask is None:
|
| 1220 |
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 1221 |
+
if use_sdpa_attention_masks:
|
| 1222 |
+
# Expand the attention mask for SDPA.
|
| 1223 |
+
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
| 1224 |
+
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
| 1225 |
+
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
|
| 1226 |
+
)
|
| 1227 |
+
else:
|
| 1228 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 1229 |
else:
|
| 1230 |
encoder_extended_attention_mask = None
|
| 1231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1232 |
encoder_outputs = self.encoder(
|
| 1233 |
embedding_output,
|
| 1234 |
attention_mask=extended_attention_mask,
|