feat: reverted monkey patch
- configuration_bert.py +0 -2
- modeling_bert.py +5 -17
configuration_bert.py CHANGED
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ BERT model configuration"""
-from collections import OrderedDict
-from typing import Mapping
 
 from transformers import PretrainedConfig
 
modeling_bert.py CHANGED
@@ -28,16 +28,13 @@ from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     BertForPreTrainingOutput,
 )
-from .patched_padding_bert import index_first_axis as index_first_axis_monkey_patch
-import flash_attn.bert_padding
-flash_attn.bert_padding.index_first_axis = index_first_axis_monkey_patch
-"""
 from flash_attn.bert_padding import (
+    index_first_axis,
     index_first_axis_residual,
     pad_input,
     unpad_input,
 )
-
+
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
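The lines removed above installed a monkey patch at import time: they rebound the `index_first_axis` attribute on the `flash_attn.bert_padding` module, so every later module-qualified lookup resolved to the patched function. With the revert, `index_first_axis` is simply imported from flash-attn again. A toy sketch of the mechanism (stand-in objects, not this repo's or flash-attn's code); note that names already bound via `from module import name` before the patch keep pointing at the original:

```python
# Toy illustration of the reverted monkey-patch mechanism; the module and
# functions below are stand-ins, not flash-attn's actual implementation.
import types

bert_padding = types.ModuleType("bert_padding")  # plays flash_attn.bert_padding
bert_padding.index_first_axis = lambda x, idx: ("original", x, idx)

def patched_index_first_axis(x, idx):  # plays patched_padding_bert.index_first_axis
    return ("patched", x, idx)

direct_ref = bert_padding.index_first_axis                 # like `from m import f` before patching
bert_padding.index_first_axis = patched_index_first_axis   # the monkey patch

print(bert_padding.index_first_axis("h", 0)[0])  # "patched": module lookups see the patch
print(direct_ref("h", 0)[0])                     # "original": earlier from-imports do not
```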
@@ -176,14 +173,14 @@ class BertEncoder(nn.Module):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch =
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-                hidden_states =
+                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
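For reference, the `unpad_input`/`pad_input` pair used here round-trips between the padded `(batch, seqlen, hidden)` layout and a packed `(total_tokens, hidden)` layout that the flash-attention layers consume. A minimal sketch, assuming a flash-attn version where `unpad_input` returns four values, as this file expects; it needs flash-attn installed to run:

```python
import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch, seqlen, hidden = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, hidden)
key_padding_mask = torch.tensor([[1, 1, 1, 0],
                                 [1, 1, 0, 0]], dtype=torch.bool)

# Drop padded positions: (batch, seqlen, hidden) -> (total_tokens, hidden)
hidden_states_unpad, indices, cu_seqlens, max_seqlen = unpad_input(
    hidden_states, key_padding_mask
)
assert hidden_states_unpad.shape[0] == key_padding_mask.sum()  # 5 real tokens

# ... run the transformer layers on the packed tokens ...

# Restore the padded (batch, seqlen, hidden) layout
repadded = pad_input(hidden_states_unpad, indices, batch, seqlen)
assert repadded.shape == hidden_states.shape
```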
@@ -201,7 +198,7 @@ class BertEncoder(nn.Module):
                     subset_cu_seqlens = F.pad(
                         torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                     )
-                hidden_states_subset, hidden_states =
+                hidden_states_subset, hidden_states = index_first_axis_residual(
                     hidden_states, subset_idx
                 )
                 # It's ok to set max_seqlen_q to be much larger
@@ -425,15 +422,6 @@ class BertModel(BertPreTrainedModel):
             pooler_output=pooled_output,
         )
 
-    def to(self, *args, **kwargs):
-        print(f'In BERT, calling to({args, kwargs})')
-        result = super().to(*args, **kwargs)
-        if (len(args) > 0 and isinstance(args[0], torch.dtype)) or "dtype" in kwargs:
-            for layer in result.encoder.layers:
-                layer.mixer.inner_cross_attn.alibi_slopes = layer.mixer.inner_cross_attn.alibi_slopes.to(torch.float32)
-                layer.mixer.inner_attn.alibi_slopes = layer.mixer.inner_attn.alibi_slopes.to(torch.float32)
-        return result
-
 
 class BertForPreTraining(BertPreTrainedModel):
     def __init__(self, config: JinaBertConfig):
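The deleted `to()` override let `nn.Module.to()` cast everything and then forced the ALiBi slope tensors back to float32, since the flash-attn kernels consume fp32 slopes regardless of the model's compute dtype. A toy sketch of that general pattern, using a hypothetical module with a registered buffer rather than this repo's classes:

```python
# Toy sketch of pinning one tensor to float32 across dtype casts; ToyAttn
# and its alibi_slopes buffer are hypothetical, not this repo's code.
import torch
import torch.nn as nn

class ToyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.register_buffer("alibi_slopes", torch.linspace(0.1, 1.0, 8))

    def to(self, *args, **kwargs):
        result = super().to(*args, **kwargs)
        # Undo only the dtype cast for the slopes; device moves are kept.
        if (len(args) > 0 and isinstance(args[0], torch.dtype)) or "dtype" in kwargs:
            result.alibi_slopes = result.alibi_slopes.to(torch.float32)
        return result

m = ToyAttn().to(torch.bfloat16)
assert m.proj.weight.dtype == torch.bfloat16   # weights follow the cast
assert m.alibi_slopes.dtype == torch.float32   # slopes stay fp32
```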