Jackmin801 committed
Commit 6f3de15 · Parent(s): 4f24e0f
set flash attn as option in config
Browse files:
- configuration_bert.py +4 -0
- flash_attn_triton.py +9 -32
- modeling_bert.py +17 -6
configuration_bert.py
CHANGED

@@ -127,6 +127,8 @@ class JinaBertConfig(PretrainedConfig):
         emb_pooler (`str`, *optional*, defaults to `None`):
             The function to use for pooling the last layer embeddings to get the sentence embeddings.
             Should be one of `None`, `"mean"`.
+        with_flash (`bool`, *optional*, defaults to `False`):
+            Whether to use flash attention. Only works for `triton==2.0.0.dev20230208`
 
     Examples:
 
@@ -164,6 +166,7 @@ class JinaBertConfig(PretrainedConfig):
         classifier_dropout=None,
         feed_forward_type="original",
         emb_pooler=None,
+        with_flash=False,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -185,6 +188,7 @@ class JinaBertConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.feed_forward_type = feed_forward_type
         self.emb_pooler = emb_pooler
+        self.with_flash = with_flash
 
 
 class JinaBertOnnxConfig(OnnxConfig):
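A minimal usage sketch for the new option (an assumption for illustration: it presumes configuration_bert.py is importable from the working directory; the only new API surface is the constructor argument added above):

    from configuration_bert import JinaBertConfig

    # with_flash defaults to False, which keeps the stock attention path;
    # True routes attention through the Triton kernel (modeling_bert.py,
    # below, raises a ValueError if that kernel could not be imported).
    config = JinaBertConfig(with_flash=True)
    print(config.with_flash)  # True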
flash_attn_triton.py
CHANGED

@@ -81,21 +81,11 @@ def _fwd_kernel(
     Lse,
     TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     softmax_scale,
-    stride_qb,
-    stride_qh,
-    stride_qm,
-    stride_kb,
-    stride_kh,
-    stride_kn,
-    stride_vb,
-    stride_vh,
-    stride_vn,
-    stride_bb,
-    stride_bh,
-    stride_bm,
-    stride_ob,
-    stride_oh,
-    stride_om,
+    stride_qb, stride_qh, stride_qm,
+    stride_kb, stride_kh, stride_kn,
+    stride_vb, stride_vh, stride_vn,
+    stride_bb, stride_bh, stride_bm,
+    stride_ob, stride_oh, stride_om,
     nheads,
     seqlen_q,
     seqlen_k,
@@ -316,11 +306,6 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
     elif bias.shape[2:] == (seqlen_q, seqlen_k):
         bias_type = 'matrix'
     else:
-        print(q.shape)
-        print(k.shape)
-        print(seqlen_q)
-        print(seqlen_k)
-        print(bias.shape)
         raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                            ' or (seqlen_q, seqlen_k)')
     if bias.shape[:2] == (1, nheads):
@@ -359,19 +344,11 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
         lse,
         tmp,
         softmax_scale,
-        q.stride(0),
-        q.stride(2),
-        q.stride(1),
-        k.stride(0),
-        k.stride(2),
-        k.stride(1),
-        v.stride(0),
-        v.stride(2),
-        v.stride(1),
+        q.stride(0), q.stride(2), q.stride(1),
+        k.stride(0), k.stride(2), k.stride(1),
+        v.stride(0), v.stride(2), v.stride(1),
         *bias_strides,
-        o.stride(0),
-        o.stride(2),
-        o.stride(1),
+        o.stride(0), o.stride(2), o.stride(1),
         nheads,
         seqlen_q,
         seqlen_k,
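The call site passes each tensor's strides in (batch, head, seqlen) order, while q, k, v are laid out as (batch, seqlen, nheads, headdim); that is why the arguments read q.stride(0), q.stride(2), q.stride(1). A quick sketch to confirm the ordering (the shape is made up for illustration):

    import torch

    q = torch.empty(2, 1024, 12, 64)  # (batch, seqlen, nheads, headdim)
    # contiguous strides: (1024*12*64, 12*64, 64, 1) = (786432, 768, 64, 1)
    print(q.stride(0), q.stride(2), q.stride(1))  # 786432 64 768 -> batch, head, seq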
modeling_bert.py
CHANGED

@@ -55,7 +55,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from .configuration_bert import JinaBertConfig
-
+try:
+    from .flash_attn_triton import flash_attn_func
+except Exception:
+    flash_attn_func = None
 
 try:
     from tqdm.autonotebook import trange
@@ -282,7 +285,7 @@ class JinaBertEmbeddings(nn.Module):
 
 
 class JinaBertSelfAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
+    def __init__(self, config: JinaBertConfig, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
             config, "embedding_size"
@@ -291,6 +294,13 @@ class JinaBertSelfAttention(nn.Module):
             f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
             f"heads ({config.num_attention_heads})"
         )
+
+        self.with_flash = config.with_flash
+        if self.with_flash:
+            if flash_attn_func is None:
+                raise ValueError(
+                    f"flash_attn_func is None, please install flash_attn_triton"
+                )
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -334,14 +344,15 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
-        if
+        if self.with_flash:
             b, s, h = hidden_states.shape
             q = self.query(hidden_states)
             k = self.key(hidden_states)
             v = self.value(hidden_states)
-
-
-
+            # B x S x hidden_dim -> B x S x num_heads x head_dim
+            q = q.view(b, s, self.num_attention_heads, self.attention_head_size)
+            k = k.view(b, s, self.num_attention_heads, self.attention_head_size)
+            v = v.view(b, s, self.num_attention_heads, self.attention_head_size)
             attn = flash_attn_func(q, k, v, bias)
             return (attn.view(b, s, h),)
         mixed_query_layer = self.query(hidden_states)
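For reference, the new flash branch in JinaBertSelfAttention.forward boils down to the shape handling below. This is a standalone sketch, not the module itself: flash_attn_func is stubbed with PyTorch's scaled_dot_product_attention so the snippet runs without Triton, and the sizes are invented.

    import torch
    import torch.nn.functional as F

    def flash_attn_func(q, k, v, bias=None):
        # Stand-in for the Triton kernel: (B, S, nheads, d) in, same shape out.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # SDPA wants (B, nheads, S, d)
        o = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
        return o.transpose(1, 2).contiguous()  # the real kernel writes a contiguous output

    b, s, h = 2, 128, 768                      # batch, seqlen, hidden_dim
    num_heads, head_size = 12, 64              # h == num_heads * head_size
    x = torch.randn(b, s, h)
    # B x S x hidden_dim -> B x S x num_heads x head_dim, as in the new forward path
    q = k = v = x.view(b, s, num_heads, head_size)
    attn = flash_attn_func(q, k, v, bias=None)
    print(attn.view(b, s, h).shape)            # torch.Size([2, 128, 768])

Note the return path relies on the attention output being contiguous so that attn.view(b, s, h) can fold the heads back into the hidden dimension; the stub's .contiguous() mimics that property of the kernel.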