medmekk HF Staff commited on
Commit
084e4d5
·
verified ·
1 Parent(s): cdd762c

Build uploaded using `kernels`.

Browse files
build/torch28-metal-aarch64-darwin/mlx_rmsnorm/__init__.py CHANGED
@@ -7,16 +7,19 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, epsilon: float = 1e-5
7
  original_shape = x.shape
8
  x = x.view(-1, x.shape[-1])
9
  weight = weight.view(-1)
10
- out = ops.launch_forward_kernel(x, weight, epsilon)
11
- out = out.view(original_shape)
12
- return out
 
13
 
14
  def rmsnorm_backward(x: torch.Tensor, weight: torch.Tensor, grad_output: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor:
15
  original_shape = x.shape
16
  x = x.view(-1, x.shape[-1])
17
  weight = weight.view(-1)
18
  grad_output = grad_output.view(-1)
19
- grad_input, grad_weight = ops.launch_backward_kernel(x, weight, grad_output, epsilon)
 
 
20
  grad_input = grad_input.view(original_shape)
21
  grad_weight = grad_weight.view(original_shape)
22
  return grad_input, grad_weight
 
7
  original_shape = x.shape
8
  x = x.view(-1, x.shape[-1])
9
  weight = weight.view(-1)
10
+ output = torch.zeros_like(x)
11
+ ops.launch_forward_kernel(x, weight, output, epsilon)
12
+ output = output.view(original_shape)
13
+ return output
14
 
15
  def rmsnorm_backward(x: torch.Tensor, weight: torch.Tensor, grad_output: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor:
16
  original_shape = x.shape
17
  x = x.view(-1, x.shape[-1])
18
  weight = weight.view(-1)
19
  grad_output = grad_output.view(-1)
20
+ grad_input = torch.zeros_like(x)
21
+ grad_weight = torch.zeros_like(weight)
22
+ ops.launch_backward_kernel(x, weight, grad_output, grad_input, grad_weight, epsilon)
23
  grad_input = grad_input.view(original_shape)
24
  grad_weight = grad_weight.view(original_shape)
25
  return grad_input, grad_weight
build/torch28-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/__init__.cpython-313.pyc and b/build/torch28-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch28-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/_ops.cpython-313.pyc and b/build/torch28-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/_ops.cpython-313.pyc differ
 
build/torch28-metal-aarch64-darwin/mlx_rmsnorm/_mlx_rmsnorm_97571a8_dirty.abi3.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:059409da4eeaf664ffb0d335315a89bf5ec93958b9ad7af73f00e093161087ae
3
  size 219216
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9067324afd250e29f55291830a02f3cd197a559ecb38262770ea31206c5cb1b
3
  size 219216
build/torch29-metal-aarch64-darwin/mlx_rmsnorm/__init__.py CHANGED
@@ -7,16 +7,19 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, epsilon: float = 1e-5
7
  original_shape = x.shape
8
  x = x.view(-1, x.shape[-1])
9
  weight = weight.view(-1)
10
- out = ops.launch_forward_kernel(x, weight, epsilon)
11
- out = out.view(original_shape)
12
- return out
 
13
 
14
  def rmsnorm_backward(x: torch.Tensor, weight: torch.Tensor, grad_output: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor:
15
  original_shape = x.shape
16
  x = x.view(-1, x.shape[-1])
17
  weight = weight.view(-1)
18
  grad_output = grad_output.view(-1)
19
- grad_input, grad_weight = ops.launch_backward_kernel(x, weight, grad_output, epsilon)
 
 
20
  grad_input = grad_input.view(original_shape)
21
  grad_weight = grad_weight.view(original_shape)
22
  return grad_input, grad_weight
 
7
  original_shape = x.shape
8
  x = x.view(-1, x.shape[-1])
9
  weight = weight.view(-1)
10
+ output = torch.zeros_like(x)
11
+ ops.launch_forward_kernel(x, weight, output, epsilon)
12
+ output = output.view(original_shape)
13
+ return output
14
 
15
  def rmsnorm_backward(x: torch.Tensor, weight: torch.Tensor, grad_output: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor:
16
  original_shape = x.shape
17
  x = x.view(-1, x.shape[-1])
18
  weight = weight.view(-1)
19
  grad_output = grad_output.view(-1)
20
+ grad_input = torch.zeros_like(x)
21
+ grad_weight = torch.zeros_like(weight)
22
+ ops.launch_backward_kernel(x, weight, grad_output, grad_input, grad_weight, epsilon)
23
  grad_input = grad_input.view(original_shape)
24
  grad_weight = grad_weight.view(original_shape)
25
  return grad_input, grad_weight
build/torch29-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch29-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/__init__.cpython-313.pyc and b/build/torch29-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/__init__.cpython-313.pyc differ
 
build/torch29-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch29-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/_ops.cpython-313.pyc and b/build/torch29-metal-aarch64-darwin/mlx_rmsnorm/__pycache__/_ops.cpython-313.pyc differ
 
build/torch29-metal-aarch64-darwin/mlx_rmsnorm/_mlx_rmsnorm_97571a8_dirty.abi3.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7da3f963a3463c9691b409f8b91bb44388c45aef77af354a915eebfacc1b49d4
3
  size 220160
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9841a7657253626aefe2c9cd346fd61de8d60e4e5484e7ae9230c44232bf9fd1
3
  size 220160