Use torch RMSNorm
Browse files — modeling_flashppi.py (+14, −16)
modeling_flashppi.py
CHANGED
|
@@ -16,7 +16,6 @@ try:
|
|
| 16 |
from flash_attn.layers.rotary import apply_rotary_emb_func
|
| 17 |
from flash_attn import flash_attn_varlen_kvpacked_func
|
| 18 |
from flash_attn.bert_padding import pad_input, unpad_input
|
| 19 |
-
from flash_attn.ops.triton.layer_norm import RMSNorm
|
| 20 |
FLASH_ATTN_AVAILABLE = True
|
| 21 |
except ImportError:
|
| 22 |
FLASH_ATTN_AVAILABLE = False
|
|
@@ -26,21 +25,20 @@ except ImportError:
|
|
| 26 |
def swiglu(x, y):
    """SwiGLU activation: pass the gate input *x* through SiLU and use it to scale *y*."""
    gate = F.silu(x)
    return gate * y
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
@dataclass
|
| 45 |
class FlashPPIOutput(ModelOutput):
|
| 46 |
"""Output type for FlashPPI model.
|
|
|
|
| 16 |
from flash_attn.layers.rotary import apply_rotary_emb_func
|
| 17 |
from flash_attn import flash_attn_varlen_kvpacked_func
|
| 18 |
from flash_attn.bert_padding import pad_input, unpad_input
|
|
|
|
| 19 |
FLASH_ATTN_AVAILABLE = True
|
| 20 |
except ImportError:
|
| 21 |
FLASH_ATTN_AVAILABLE = False
|
|
|
|
| 25 |
def swiglu(x, y):
    """Return y scaled elementwise by SiLU(x) (the SwiGLU gating pattern)."""
    return y * F.silu(x)
|
| 27 |
|
| 28 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (scale-only, no mean subtraction).

    Deliberately stores no variance_epsilon buffer so state dicts from
    existing checkpoints load without unexpected-key mismatches.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale, initialized to the identity.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back
        # to the caller's dtype at the end.
        orig_dtype = hidden_states.dtype
        x = hidden_states.float()
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = x * inv_rms
        return (self.weight * normed).to(orig_dtype)
|
| 41 |
+
|
|
|
|
| 42 |
@dataclass
|
| 43 |
class FlashPPIOutput(ModelOutput):
|
| 44 |
"""Output type for FlashPPI model.
|