andrecornman committed
Commit d1335fb · verified · Parent: 56745e9

Use torch swiglu

Files changed (1):
  1. modeling_flashppi.py  +2 -3
modeling_flashppi.py CHANGED
@@ -12,7 +12,6 @@ from .configuration_flashppi import FlashPPIConfig
 
 # Detect Flash Attention installation
 try:
-    from flash_attn.ops.activations import swiglu
     from flash_attn.layers.rotary import apply_rotary_emb_func
     from flash_attn import flash_attn_varlen_kvpacked_func
     from flash_attn.bert_padding import pad_input, unpad_input
@@ -22,8 +21,8 @@ except ImportError:
     unpad_input = pad_input = apply_rotary_emb_func = None
     flash_attn_varlen_kvpacked_func = None
 
-    def swiglu(x, y):
-        return F.silu(x) * y
+def swiglu(x, y):
+    return F.silu(x) * y
 
 class RMSNorm(nn.Module):
     """RMSNorm without variance_epsilon buffer for checkpoint compatibility."""