mrfakename committed
Commit e0b12cf · verified · 1 Parent(s): 712bcbf

Support no flash-attn
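
This commit removes the hard requirement on flash-attn: when Flash Attention 2 is not installed, the import block now logs a warning and the attention layer routes through a standard scaled-dot-product implementation instead of raising a RuntimeError. A minimal usage sketch of the intended effect on a machine without flash-attn (the repo id below is a placeholder, not the actual checkpoint name):

# Hypothetical usage sketch; "your-org/moonshot-model" is a placeholder repo id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/moonshot-model", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "your-org/moonshot-model",
    torch_dtype=torch.float16,
    trust_remote_code=True,  # loads modeling_moonshot.py from the repo
)
# Before this commit, importing modeling_moonshot.py without flash-attn raised:
#   RuntimeError: flash attention must be installed
# After it, a warning is logged and the standard attention path is used.
inputs = tokenizer("Hello", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))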

Files changed (1)
  1. modeling_moonshot.py +46 -2
modeling_moonshot.py CHANGED
@@ -27,13 +27,18 @@ else:
 from transformers.utils import is_flash_attn_available
 from .configuration_moonshot import MoonshotConfig
 import math
+import logging
+
+logger = logging.getLogger(__name__)


 if is_flash_attn_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    _flash_attn_2_available = True
 else:
-    raise RuntimeError("flash attention must be installed")
+    _flash_attn_2_available = False
+    logger.warning("Flash Attention 2 is not available. Falling back to standard attention.")


 logger = logging.get_logger(__name__)
@@ -380,6 +385,13 @@ class Attention(nn.Module):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not _flash_attn_2_available:
+            return self._standard_attention(
+                query_states, key_states, value_states,
+                attention_mask=padding_mask, query_length=query_length,
+                dropout=dropout, softmax_scale=softmax_scale
+            )
+
         # Contains at least one padding token in the sequence
         if padding_mask is not None:
             batch_size = query_states.shape[0]
@@ -411,6 +423,38 @@ class Attention(nn.Module):

         return attn_output

+    def _standard_attention(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        # Standard scaled dot-product attention
+        batch_size, q_length, num_heads, head_dim = query_states.shape
+
+        # Prepare the query, key, value for attention computation
+        # (batch_size, num_heads, seq_length, head_dim)
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        # (batch_size, num_heads, query_length, key_length)
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+
+        if softmax_scale is None:
+            softmax_scale = 1.0 / math.sqrt(head_dim)
+        attn_weights = attn_weights * softmax_scale
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        # Apply softmax and dropout
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
+
+        # Context vectors
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        return attn_output
+

 class DecoderLayer(nn.Module):
     def __init__(self, config: MoonshotConfig):
@@ -854,4 +898,4 @@ class MoonshotForCausalLM(MoonshotPreTrainedModel):
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
-        return reordered_past
+        return reordered_past
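
As a sanity check on the fallback: ignoring dropout and masking, the new _standard_attention is plain scaled dot-product attention, so it should agree numerically with PyTorch's built-in kernel. A standalone sketch (not part of the model file, assuming PyTorch >= 2.0 for F.scaled_dot_product_attention):

# Standalone check that the fallback math matches torch's SDPA kernel.
import math
import torch
import torch.nn.functional as F

def standard_attention(q, k, v, softmax_scale=None):
    # q, k, v arrive as (batch, seq_len, num_heads, head_dim), as in the fallback above
    head_dim = q.shape[-1]
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))     # -> (batch, heads, seq, head_dim)
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    attn = torch.matmul(q, k.transpose(2, 3)) * scale    # (batch, heads, q_len, k_len)
    attn = F.softmax(attn, dim=-1)
    out = torch.matmul(attn, v)
    return out.transpose(1, 2).contiguous()               # back to (batch, seq_len, heads, head_dim)

q, k, v = (torch.randn(2, 16, 8, 64) for _ in range(3))
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
print(torch.allclose(standard_attention(q, k, v), ref, atol=1e-5))  # expect True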