andrecornman committed
Commit 56745e9 · verified · Parent: 58af421

Use torch RMSNorm

Files changed (1)
  1. modeling_flashppi.py +14 -16
modeling_flashppi.py CHANGED
@@ -16,7 +16,6 @@ try:
     from flash_attn.layers.rotary import apply_rotary_emb_func
     from flash_attn import flash_attn_varlen_kvpacked_func
     from flash_attn.bert_padding import pad_input, unpad_input
-    from flash_attn.ops.triton.layer_norm import RMSNorm
     FLASH_ATTN_AVAILABLE = True
 except ImportError:
     FLASH_ATTN_AVAILABLE = False
@@ -26,21 +25,20 @@ except ImportError:
 def swiglu(x, y):
     return F.silu(x) * y
 
-class RMSNorm(nn.Module):
-    """RMSNorm without variance_epsilon buffer for checkpoint compatibility."""
-    def __init__(self, dim, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(dim))
-        self.eps = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-        return (self.weight * hidden_states).to(input_dtype)
-
-
+class RMSNorm(nn.Module):
+    """RMSNorm without variance_epsilon buffer for checkpoint compatibility."""
+    def __init__(self, dim, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(dim))
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        return (self.weight * hidden_states).to(input_dtype)
+
 @dataclass
 class FlashPPIOutput(ModelOutput):
     """Output type for FlashPPI model.