Kernels
iamwyldecat committed on
Commit
605f22e
·
1 Parent(s): f3b99fb

feat: support reset_parameters()

Browse files
build/torch26-cxx11-rocm62-x86_64-linux/activation/{_activation_d14fd4d_dirty.abi3.so → _activation_f3b99fb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:179bfe6bd5484e81b1d8fa6cc3e2596837946a17f0761b0bb2521fd162669046
3
  size 2656296
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bf9f4d85d15bc4869292e6a293ec53b7658cee61284457ea727c4be435062f7
3
  size 2656296
build/torch26-cxx11-rocm62-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_d14fd4d_dirty
3
- ops = torch.ops._activation_d14fd4d_dirty
4
 
5
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this build's op namespace."""
    # The namespace string must match the compiled extension's registration name.
    return "_activation_d14fd4d_dirty::" + op_name
 
1
  import torch
2
+ from . import _activation_f3b99fb_dirty
3
+ ops = torch.ops._activation_f3b99fb_dirty
4
 
5
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this build's op namespace."""
    # The namespace string must match the compiled extension's registration name.
    return "_activation_f3b99fb_dirty::" + op_name
build/torch26-cxx11-rocm62-x86_64-linux/activation/layers.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import torch.nn as nn
 
3
 
4
  from .poly_norm import PolyNormFunction
5
  from .rms_norm import RMSNormFunction
@@ -18,6 +19,13 @@ class PolyNorm(nn.Module):
18
  ):
19
  return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
20
 
 
 
 
 
 
 
 
21
 
22
  class RMSNorm(nn.Module):
23
  def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -30,3 +38,9 @@ class RMSNorm(nn.Module):
30
  x: torch.Tensor,
31
  ):
32
  return RMSNormFunction.apply(x, self.weight, self.eps)
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from torch.nn import init
4
 
5
  from .poly_norm import PolyNormFunction
6
  from .rms_norm import RMSNormFunction
 
19
  ):
20
  return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
21
 
22
def reset_parameters(self) -> None:
    """Restore ``weight`` and ``bias`` to their ``__init__`` values.

    Equivalent to ``init.ones_(self.weight)`` / ``init.zeros_(self.bias)``:
    in-place refills performed outside autograd tracking.
    """
    with torch.no_grad():
        self.weight.fill_(1.0)
        self.bias.fill_(0.0)
28
+
29
 
30
  class RMSNorm(nn.Module):
31
  def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
 
38
  x: torch.Tensor,
39
  ):
40
  return RMSNormFunction.apply(x, self.weight, self.eps)
41
+
42
def reset_parameters(self) -> None:
    """Restore ``weight`` to its ``__init__`` value.

    Equivalent to ``init.ones_(self.weight)``: an in-place refill
    performed outside autograd tracking.
    """
    with torch.no_grad():
        self.weight.fill_(1.0)
build/torch27-cxx11-rocm63-x86_64-linux/activation/{_activation_d14fd4d_dirty.abi3.so → _activation_f3b99fb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:94debfd52e15f782eb9dd328d9311080d803276745e440b176b20a7031299e3f
3
  size 2642736
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad289cf495aa7bcb7318535f2d76a6543bd44827369ec358ff7411e182ce089f
3
  size 2642736
build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_d14fd4d_dirty
3
- ops = torch.ops._activation_d14fd4d_dirty
4
 
5
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this build's op namespace."""
    # The namespace string must match the compiled extension's registration name.
    return "_activation_d14fd4d_dirty::" + op_name
 
1
  import torch
2
+ from . import _activation_f3b99fb_dirty
3
+ ops = torch.ops._activation_f3b99fb_dirty
4
 
5
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this build's op namespace."""
    # The namespace string must match the compiled extension's registration name.
    return "_activation_f3b99fb_dirty::" + op_name
build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import torch.nn as nn
 
3
 
4
  from .poly_norm import PolyNormFunction
5
  from .rms_norm import RMSNormFunction
@@ -18,6 +19,13 @@ class PolyNorm(nn.Module):
18
  ):
19
  return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
20
 
 
 
 
 
 
 
 
21
 
22
  class RMSNorm(nn.Module):
23
  def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -30,3 +38,9 @@ class RMSNorm(nn.Module):
30
  x: torch.Tensor,
31
  ):
32
  return RMSNormFunction.apply(x, self.weight, self.eps)
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from torch.nn import init
4
 
5
  from .poly_norm import PolyNormFunction
6
  from .rms_norm import RMSNormFunction
 
19
  ):
20
  return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
21
 
22
def reset_parameters(self) -> None:
    """Restore ``weight`` and ``bias`` to their ``__init__`` values.

    Equivalent to ``init.ones_(self.weight)`` / ``init.zeros_(self.bias)``:
    in-place refills performed outside autograd tracking.
    """
    with torch.no_grad():
        self.weight.fill_(1.0)
        self.bias.fill_(0.0)
28
+
29
 
30
  class RMSNorm(nn.Module):
31
  def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
 
38
  x: torch.Tensor,
39
  ):
40
  return RMSNormFunction.apply(x, self.weight, self.eps)
41
+
42
def reset_parameters(self) -> None:
    """Restore ``weight`` to its ``__init__`` value.

    Equivalent to ``init.ones_(self.weight)``: an in-place refill
    performed outside autograd tracking.
    """
    with torch.no_grad():
        self.weight.fill_(1.0)
torch-ext/activation/layers.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import torch.nn as nn
 
3
 
4
  from .poly_norm import PolyNormFunction
5
  from .rms_norm import RMSNormFunction
@@ -18,6 +19,13 @@ class PolyNorm(nn.Module):
18
  ):
19
  return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
20
 
 
 
 
 
 
 
 
21
 
22
  class RMSNorm(nn.Module):
23
  def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -30,3 +38,9 @@ class RMSNorm(nn.Module):
30
  x: torch.Tensor,
31
  ):
32
  return RMSNormFunction.apply(x, self.weight, self.eps)
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from torch.nn import init
4
 
5
  from .poly_norm import PolyNormFunction
6
  from .rms_norm import RMSNormFunction
 
19
  ):
20
  return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
21
 
22
def reset_parameters(self) -> None:
    """Restore ``weight`` and ``bias`` to their ``__init__`` values.

    Equivalent to ``init.ones_(self.weight)`` / ``init.zeros_(self.bias)``:
    in-place refills performed outside autograd tracking.
    """
    with torch.no_grad():
        self.weight.fill_(1.0)
        self.bias.fill_(0.0)
28
+
29
 
30
  class RMSNorm(nn.Module):
31
  def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
 
38
  x: torch.Tensor,
39
  ):
40
  return RMSNormFunction.apply(x, self.weight, self.eps)
41
+
42
def reset_parameters(self) -> None:
    """Restore ``weight`` to its ``__init__`` value.

    Equivalent to ``init.ones_(self.weight)``: an in-place refill
    performed outside autograd tracking.
    """
    with torch.no_grad():
        self.weight.fill_(1.0)