BucketOfFish committed · Commit 5e8c4af · Parent: 8a5eabe

Removed all flash_attn usage

Files changed (3)
  1. config.json +0 -3
  2. configuration_phi.py +0 -6
  3. modeling_phi.py +4 -83
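
In practical terms, once the flash_attn imports and config flags are gone, the custom modeling code no longer requires flash-attn to be installed in order to load. A minimal sketch of loading such a checkpoint with transformers (the repo id below is a placeholder for illustration, not taken from this page):

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "your-namespace/phi-checkpoint"  # placeholder repo id, not from this commit

# The three removed flags are no longer present in config.json.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
for key in ("flash_attn", "flash_rotary", "fused_dense"):
    assert not hasattr(config, key)

# Loading the model executes modeling_phi.py, which no longer tries to import flash_attn.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)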
config.json CHANGED
@@ -10,9 +10,6 @@
     "AutoModelForCausalLM": "modeling_phi.PhiForCausalLM"
   },
   "embd_pdrop": 0.0,
-  "flash_attn": false,
-  "flash_rotary": false,
-  "fused_dense": false,
   "img_processor": null,
   "initializer_range": 0.02,
   "layer_norm_epsilon": 1e-05,
configuration_phi.py CHANGED
@@ -29,9 +29,6 @@ class PhiConfig(PretrainedConfig):
         n_head_kv: Optional[int] = None,
         rotary_dim: Optional[int] = 32,
         activation_function: Optional[str] = "gelu_new",
-        flash_attn: bool = False,
-        flash_rotary: bool = False,
-        fused_dense: bool = False,
         attn_pdrop: float = 0.0,
         embd_pdrop: float = 0.0,
         resid_pdrop: float = 0.0,
@@ -50,9 +47,6 @@ class PhiConfig(PretrainedConfig):
         self.n_head_kv = n_head_kv
         self.rotary_dim = min(rotary_dim, n_embd // n_head)
         self.activation_function = activation_function
-        self.flash_attn = flash_attn
-        self.flash_rotary = flash_rotary
-        self.fused_dense = fused_dense
         self.attn_pdrop = attn_pdrop
         self.embd_pdrop = embd_pdrop
         self.resid_pdrop = resid_pdrop
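
After this change, PhiConfig is constructed without the three flags. A minimal sketch, assuming the remaining parameters keep their defaults (the values shown are illustrative, not the checkpoint's actual settings):

from configuration_phi import PhiConfig  # assumes configuration_phi.py is importable locally

config = PhiConfig(
    rotary_dim=32,
    activation_function="gelu_new",
    attn_pdrop=0.0,
    embd_pdrop=0.0,
    resid_pdrop=0.0,
)

# flash_attn, flash_rotary, and fused_dense are no longer accepted keyword
# arguments and are no longer set as attributes in __init__.
print(config.rotary_dim, config.activation_function)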
modeling_phi.py CHANGED
@@ -19,17 +19,6 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from .configuration_phi import PhiConfig
 
-try:
-    from flash_attn.bert_padding import pad_input, unpad_input
-    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
-    from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
-    from flash_attn.ops.fused_dense import FusedDense
-except:
-    pad_input, unpad_input = None, None
-    FlashRotaryEmbedding = None
-    FlashSelfAttention, FlashCrossAttention = None, None
-    FusedDense = None
-
 
 @dataclass
 class InferenceParams:
@@ -532,7 +521,7 @@ class MHA(nn.Module):
         # Rotary embedding
         self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
         if self.rotary_dim > 0:
-            rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
+            rotary_cls = RotaryEmbedding
             if rotary_cls is None:
                 rotary_cls = RotaryEmbedding
 
@@ -555,7 +544,7 @@ class MHA(nn.Module):
         op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
         hidden_size = config.n_embd
 
-        linear_cls = FusedDense if config.fused_dense else nn.Linear
+        linear_cls = nn.Linear
         if linear_cls is None:
             linear_cls = nn.Linear
 
@@ -563,11 +552,11 @@
         self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
         # Attention
-        attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
+        attn_cls = SelfAttention
         if attn_cls is None:
             attn_cls = SelfAttention
 
-        cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
+        cross_attn_cls = CrossAttention
         if cross_attn_cls is None:
             cross_attn_cls = CrossAttention
 
@@ -582,7 +571,6 @@
             attention_dropout=config.attn_pdrop,
         )
 
-        self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
         self.layer_idx = layer_idx
         self.return_residual = return_residual
         self.checkpointing = checkpointing
@@ -596,25 +584,6 @@
         if self.rotary_dim > 0:
             qkv = self.rotary_emb(qkv)
 
-        if self.flash_attn:
-            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-
-            cu_seqlens, max_seqlen = None, None
-            if key_padding_mask is not None:
-                # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
-                # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
-                qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
-
-            if self.checkpointing:
-                attn_output = torch.utils.checkpoint.checkpoint(
-                    self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
-                )
-            else:
-                attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
-
-            # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
-            return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
-
         if self.checkpointing:
             return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
 
@@ -644,54 +613,6 @@
         if past_key_values is not None:
             kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
 
-        if self.flash_attn:
-            batch_size, seqlen_q = q.shape[0], q.shape[1]
-            seqlen_k = kv.shape[1]
-
-            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
-                None,
-                None,
-                None,
-                None,
-            )
-            if key_padding_mask is not None:
-                kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
-
-                if seqlen_q == 1:
-                    key_padding_mask = torch.ones(batch_size, 1, device=q.device)
-                elif seqlen_q != seqlen_k:
-                    key_padding_mask = key_padding_mask[:, -seqlen_q:]
-
-                q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
-
-            if self.checkpointing:
-                attn_output = torch.utils.checkpoint.checkpoint(
-                    self.inner_cross_attn,
-                    q,
-                    kv,
-                    causal=causal,
-                    cu_seqlens=cu_seqlens_q,
-                    max_seqlen=max_seqlen_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_k=max_seqlen_k,
-                )
-            else:
-                attn_output = self.inner_cross_attn(
-                    q,
-                    kv,
-                    causal=causal,
-                    cu_seqlens=cu_seqlens_q,
-                    max_seqlen=max_seqlen_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_k=max_seqlen_k,
-                )
-
-            return (
-                pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
-                if key_padding_mask is not None
-                else attn_output
-            )
-
         if self.checkpointing:
             return torch.utils.checkpoint.checkpoint(
                 self.inner_cross_attn,
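
With the flash_attn branches removed, the self- and cross-attention forward paths always fall through to the inner SelfAttention / CrossAttention modules (optionally wrapped in torch.utils.checkpoint). The toy below sketches what that non-flash self-attention fallback computes, assuming the packed layout qkv[batch, seqlen, 3, n_head, head_dim]; it is an illustration of the remaining code path, not the repository's exact implementation:

import math
from typing import Optional

import torch

def toy_self_attention(qkv: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # qkv: [batch, seqlen, 3, n_head, head_dim]; key_padding_mask: [batch, seqlen], 1 = keep.
    q, k, v = qkv.unbind(dim=2)
    scores = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(q.shape[-1])
    # Causal mask: each position attends only to itself and earlier positions.
    causal = torch.triu(torch.ones(q.shape[1], q.shape[1], dtype=torch.bool, device=qkv.device), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    if key_padding_mask is not None:
        scores = scores.masked_fill(~key_padding_mask[:, None, None, :].bool(), float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhts,bshd->bthd", probs, v)  # [batch, seqlen, n_head, head_dim]

qkv = torch.randn(2, 5, 3, 4, 8)  # batch=2, seqlen=5, heads=4, head_dim=8
print(toy_self_attention(qkv).shape)  # torch.Size([2, 5, 4, 8])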