samir-souza committed on
Commit
610de45
·
verified ·
1 Parent(s): 21ed583

Upload folder using huggingface_hub

Browse files
build/torch-neuron/__init__.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from nkilib.core.mlp.mlp import mlp
3
+ from nkilib.core.rmsnorm.rmsnorm_quant import rmsnorm_quant_kernel, RmsNormQuantKernelArgs
4
+ from nkilib.core.utils.common_types import ActFnType, NormType, QuantizationType
5
+ from ._ops import ops
6
+
7
+ from . import layers
8
+
9
def mlp_kernel(x, gate_proj_weight, up_proj_weight, down_proj_weight, activation_fn):
    """Run the fused NKI MLP (gate/up/down projection) kernel on ``x``.

    Args:
        x: Input activations. Cast to bfloat16 for the kernel and cast back
            to the original dtype on return.
        gate_proj_weight: Gate projection weight. The kernel expects the
            transposed layout, hence the ``transpose(1, 0)`` below.
        up_proj_weight: Up projection weight, same layout handling as gate.
        down_proj_weight: Down projection weight, same layout handling as gate.
        activation_fn: Activation name ("silu", "gelu", "gelu_pytorch_tanh"
            or "swish"), case-insensitive.

    Returns:
        Kernel output tensor, cast back to ``x``'s original dtype.

    Raises:
        ValueError: If ``activation_fn`` is not one of the supported names.
    """
    x_dtype = x.dtype
    dtype = torch.bfloat16

    # Map the HF-style activation name onto the kernel's ActFnType enum.
    _ACT_FN_MAP = {
        "silu": ActFnType.SiLU,
        "gelu": ActFnType.GELU,
        "gelu_pytorch_tanh": ActFnType.GELU_Tanh_Approx,
        "swish": ActFnType.Swish,
    }
    try:
        act_fn = _ACT_FN_MAP[activation_fn.lower()]
    except KeyError:
        # ValueError (a subclass of the original bare Exception) is the
        # idiomatic type for an unsupported argument value.
        raise ValueError(f"Activation function not supported: {activation_fn}") from None

    return mlp(
        x.to(dtype),
        gate_proj_weight.transpose(1, 0).to(dtype),
        up_proj_weight.transpose(1, 0).to(dtype),
        down_proj_weight.transpose(1, 0).to(dtype),
        activation_fn=act_fn,
    ).to(x_dtype)
32
+
33
def rmsnorm_kernel(hidden, ln_weight, epsilon):
    """Run the NKI RMSNorm kernel on ``hidden``.

    Args:
        hidden: Input hidden states. Cast to bfloat16 for the kernel and
            cast back to the original dtype on return.
        ln_weight: RMSNorm scale weights.
        epsilon: Variance epsilon used by the normalization.

    Returns:
        Normalized tensor, cast back to ``hidden``'s original dtype.
    """
    hidden_dtype = hidden.dtype
    dtype = torch.bfloat16

    # NOTE(review): presumably QuantizationType.ROW + NormType.RMS_NORM makes
    # the shared quantized kernel behave as a plain RMSNorm — confirm against
    # the nkilib kernel documentation.
    kernel_args = RmsNormQuantKernelArgs(
        quantization_type=QuantizationType.ROW,
        lower_bound=0.0,
        norm_type=NormType.RMS_NORM,
        # Bug fix: was hard-coded to 1e-6, silently ignoring the caller's
        # `epsilon` argument.
        eps=epsilon,
    )
    return rmsnorm_quant_kernel(
        hidden=hidden.to(dtype),
        ln_w=ln_weight.to(dtype),
        kargs=kernel_args,
    ).to(hidden_dtype)
62
+
63
+ __all__ = [
64
+ "layers",
65
+ "MLP",
66
+ "RMSNorm"
67
+ ]
build/torch-neuron/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import torch

# Handle to the custom-op namespace registered by the compiled kernel
# library; `torch.ops` materializes namespaces lazily on attribute access.
ops = torch.ops._nki_kernels_5abad9b


def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this package's torch op namespace."""
    return "_nki_kernels_5abad9b::" + op_name
build/torch-neuron/layers/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import logging
4
+
5
+ from .. import nki_kernels
6
+
7
# Module-level logger for the layers package.
# NOTE(review): calling setLevel on a library logger overrides the
# application's logging configuration — confirm forcing INFO is intended.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
9
+
10
class MLP(nn.Module):
    """MLP whose forward pass is delegated to the fused NKI MLP kernel."""

    # Attributes are declared here and populated externally (e.g. when the
    # module is patched over a model's own MLP) — TODO confirm with callers.
    config: object
    gate_proj: torch.Tensor
    up_proj: torch.Tensor
    down_proj: torch.Tensor
    act_fn: object

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Hand the three projection weights plus the activation named in the
        # model config straight to the fused kernel.
        return nki_kernels.mlp_kernel(
            x,
            self.gate_proj.weight,
            self.up_proj.weight,
            self.down_proj.weight,
            self.config.hidden_act,
        )
19
+
20
class RMSNorm(nn.Module):
    """RMSNorm whose forward pass is delegated to the NKI RMSNorm kernel."""

    # Attributes are populated externally (e.g. when this module is patched
    # over a model's own RMSNorm) — TODO confirm with callers.
    weight: torch.Tensor
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Bug fix: original referenced the undefined name `hidden_stats`,
        # raising NameError on every call.
        return nki_kernels.rmsnorm_kernel(hidden_states, self.weight, self.variance_epsilon)

    def extra_repr(self):
        # e.g. "(4096,), eps=1e-06"
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
build/torch-neuron/metadata-neuron.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "version": 1,
3
+ "python-depends": []
4
+ }
build/torch-neuron/nki_kernels/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
# Re-export every attribute of the package's top-level __init__ (one
# directory up) into this module's namespace. The parent __init__ is loaded
# under a path-hashed module name (see _import_from_path) so that the entry
# in sys.modules does not collide with other imports.
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))