Support running without flash-attn
modeling_moonshot.py  CHANGED  (+46 -2)
@@ -27,13 +27,18 @@ else:
 27       from transformers.utils import is_flash_attn_available
 28       from .configuration_moonshot import MoonshotConfig
 29       import math
 30  +    import logging
 31  +
 32  +    logger = logging.getLogger(__name__)
 33
 34
 35       if is_flash_attn_available():
 36           from flash_attn import flash_attn_func, flash_attn_varlen_func
 37           from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 38  +        _flash_attn_2_available = True
 39       else:
 40  +        _flash_attn_2_available = False
 41  +        logger.warning("Flash Attention 2 is not available. Falling back to standard attention.")
 42
 43
 44       logger = logging.get_logger(__name__)
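Note on the guard above: the availability check runs once at import time, so whichever branch is taken fixes _flash_attn_2_available for the life of the process. If it is unclear which branch a given machine will take, a quick probe like the following (a hypothetical snippet, not part of this change) shows whether the flash_attn package is importable at all; the real is_flash_attn_available() helper also checks the torch/CUDA setup, so this only roughly mirrors it:

import importlib.util

# Probe for the flash_attn package without importing it.
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
print("flash_attn importable:", has_flash_attn)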
@@ -380,6 +385,13 @@ class Attention(nn.Module):
385           softmax_scale (`float`, *optional*):
386               The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
387           """
388  +        if not _flash_attn_2_available:
389  +            return self._standard_attention(
390  +                query_states, key_states, value_states,
391  +                attention_mask=padding_mask, query_length=query_length,
392  +                dropout=dropout, softmax_scale=softmax_scale
393  +            )
394  +
395           # Contains at least one padding token in the sequence
396           if padding_mask is not None:
397               batch_size = query_states.shape[0]
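One caveat with this early return: the flash-attn path consumes padding_mask as a key-padding mask, while the _standard_attention method added below applies it additively (attn_weights + attention_mask), and nothing in the fallback applies a causal mask on its own, so causality has to arrive through that mask as well. A sketch of converting a 0/1 padding mask into the additive form the matmul path expects (a hypothetical helper, not part of the diff):

import torch

def to_additive_mask(padding_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # padding_mask: (batch, kv_len), 1 = keep, 0 = padding.
    # Returns (batch, 1, 1, kv_len): 0.0 where kept, a large negative value where
    # padded, broadcastable against (batch, heads, q_len, kv_len) attention scores.
    inverted = 1.0 - padding_mask.to(dtype)
    return inverted[:, None, None, :] * torch.finfo(dtype).min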
@@ -411,6 +423,38 @@ class Attention(nn.Module):
423
424           return attn_output
425
426  +    def _standard_attention(
427  +        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
428  +    ):
429  +        # Standard scaled dot-product attention
430  +        batch_size, q_length, num_heads, head_dim = query_states.shape
431  +
432  +        # Prepare the query, key, value for attention computation
433  +        # (batch_size, num_heads, seq_length, head_dim)
434  +        query_states = query_states.transpose(1, 2)
435  +        key_states = key_states.transpose(1, 2)
436  +        value_states = value_states.transpose(1, 2)
437  +
438  +        # (batch_size, num_heads, query_length, key_length)
439  +        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
440  +
441  +        if softmax_scale is None:
442  +            softmax_scale = 1.0 / math.sqrt(head_dim)
443  +        attn_weights = attn_weights * softmax_scale
444  +
445  +        if attention_mask is not None:
446  +            attn_weights = attn_weights + attention_mask
447  +
448  +        # Apply softmax and dropout
449  +        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
450  +        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
451  +
452  +        # Context vectors
453  +        attn_output = torch.matmul(attn_weights, value_states)
454  +        attn_output = attn_output.transpose(1, 2).contiguous()
455  +
456  +        return attn_output
457  +
458
459   class DecoderLayer(nn.Module):
460       def __init__(self, config: MoonshotConfig):
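The method above is plain scaled dot-product attention, so its output can be sanity-checked against PyTorch's built-in kernel. A minimal sketch, assuming inputs in (batch, seq_len, num_heads, head_dim) layout as the new method expects, with no mask and no dropout:

import math
import torch
import torch.nn.functional as F

batch, seq, heads, dim = 2, 16, 4, 64
q = torch.randn(batch, seq, heads, dim)
k = torch.randn(batch, seq, heads, dim)
v = torch.randn(batch, seq, heads, dim)

# Fallback path as in the diff: transpose to (b, h, s, d), score, softmax, weight.
qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
scores = torch.matmul(qt, kt.transpose(2, 3)) / math.sqrt(dim)
out_manual = torch.matmul(F.softmax(scores, dim=-1), vt).transpose(1, 2)

# Reference: PyTorch's fused scaled dot-product attention.
out_ref = F.scaled_dot_product_attention(qt, kt, vt).transpose(1, 2)

print(torch.allclose(out_manual, out_ref, atol=1e-5))  # expected: True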
@@ -854,4 +898,4 @@ class MoonshotForCausalLM(MoonshotPreTrainedModel):
898               reordered_past += (
899                   tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
900               )
857  -        return reordered_past
901  +        return reordered_past