Update modeling_motif.py
Browse files- modeling_motif.py +1 -20
modeling_motif.py
CHANGED
|
@@ -51,27 +51,8 @@ class PolyNorm(torch.nn.Module):
|
|
| 51 |
return self.weight[0] * self._norm(x ** 3) + self.weight[1] * self._norm(
|
| 52 |
x ** 2) + self.weight[2] * self._norm(x) + self.bias
|
| 53 |
|
| 54 |
-
class PolyNorm_Test(torch.nn.Module):
|
| 55 |
-
"""
|
| 56 |
-
A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
|
| 57 |
-
The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md,
|
| 58 |
-
with the change `* torch.rsqrt` => `/ torch.sqrt` for potential MAF incompatibility.
|
| 59 |
-
"""
|
| 60 |
-
|
| 61 |
-
def __init__(self, eps=1e-6):
|
| 62 |
-
super(PolyNorm_Test, self).__init__()
|
| 63 |
-
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
|
| 64 |
-
self.bias = torch.nn.Parameter(torch.zeros(1))
|
| 65 |
-
self.eps = eps
|
| 66 |
-
|
| 67 |
-
def forward(self, x):
|
| 68 |
-
|
| 69 |
-
#return torch.nn.SiLU(x)
|
| 70 |
-
return moreh_ops.poly_norm(x, self.weight, self.bias)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
CUSTOM_ACT2CLS = {"poly_norm": PolyNorm, "poly_norm_test": PolyNorm_Test}
|
| 74 |
|
|
|
|
| 75 |
ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
|
| 76 |
ACT2FN = ClassInstantier(ACT2CLS)
|
| 77 |
|
|
|
|
| 51 |
return self.weight[0] * self._norm(x ** 3) + self.weight[1] * self._norm(
|
| 52 |
x ** 2) + self.weight[2] * self._norm(x) + self.bias
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
+
CUSTOM_ACT2CLS = {"poly_norm": PolyNorm}
|
| 56 |
ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
|
| 57 |
ACT2FN = ClassInstantier(ACT2CLS)
|
| 58 |
|