drbh commited on
Commit ·
dda32b8
0
Parent(s):
feat: basic template
Browse files- .gitattributes +3 -0
- .gitignore +1 -0
- README.md +48 -0
- __KERNEL_NAME_NORMALIZED___cpu/__KERNEL_NAME_NORMALIZED___cpu.cpp +15 -0
- __KERNEL_NAME_NORMALIZED___cuda/__KERNEL_NAME_NORMALIZED__.cu +33 -0
- __KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.metal +14 -0
- __KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.mm +63 -0
- __KERNEL_NAME_NORMALIZED___xpu/__KERNEL_NAME_NORMALIZED__.cpp +20 -0
- build.toml +55 -0
- example.py +41 -0
- flake.nix +11 -0
- tests/__init__.py +0 -0
- tests/test___KERNEL_NAME_NORMALIZED__.py +21 -0
- torch-ext/__KERNEL_NAME_NORMALIZED__/__init__.py +12 -0
- torch-ext/torch_binding.cpp +19 -0
- torch-ext/torch_binding.h +5 -0
.gitattributes
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.so filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
build
|
README.md
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# __KERNEL_NAME__
|
| 2 |
+
|
| 3 |
+
A custom kernel for PyTorch.
|
| 4 |
+
|
| 5 |
+
## Installation
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
pip install __REPO_ID__
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
## Usage
|
| 12 |
+
|
| 13 |
+
```python
|
| 14 |
+
import torch
|
| 15 |
+
from __KERNEL_NAME_NORMALIZED__ import __KERNEL_NAME_NORMALIZED__
|
| 16 |
+
|
| 17 |
+
# Create input tensor
|
| 18 |
+
x = torch.randn(1024, 1024, device="cuda")
|
| 19 |
+
|
| 20 |
+
# Run kernel
|
| 21 |
+
result = __KERNEL_NAME_NORMALIZED__(x)
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
## Development
|
| 25 |
+
|
| 26 |
+
### Building
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
nix develop
|
| 30 |
+
nix run .#build-and-copy
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
### Testing
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
nix develop .#test
|
| 37 |
+
pytest tests/
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
### Test as a `kernels` user
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
uv run example.py
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## License
|
| 47 |
+
|
| 48 |
+
Apache 2.0
|
__KERNEL_NAME_NORMALIZED___cpu/__KERNEL_NAME_NORMALIZED___cpu.cpp
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <torch/all.h>

// CPU reference implementation: writes `input + 1` elementwise into `out`.
//
// Parameters:
//   out   - float32 tensor, same element count as `input`; overwritten.
//   input - float32 tensor to read from.
//
// Raises (TORCH_CHECK): on dtype mismatch, size mismatch, or non-contiguous
// tensors. The contiguity checks are required because the raw-pointer loop
// below assumes a dense layout; without them a strided tensor would be
// read/written incorrectly (the CUDA backend already enforces this for input).
void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(out.dtype() == torch::kFloat32, "Output tensor must be float32");
  TORCH_CHECK(input.dtype() == torch::kFloat32, "Input tensor must be float32");
  TORCH_CHECK(out.numel() == input.numel(), "Tensors must have same size");
  TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous");
  TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");

  const float* in_ptr = input.data_ptr<float>();
  float* out_ptr = out.data_ptr<float>();
  const int64_t n = input.numel();

  for (int64_t i = 0; i < n; ++i) {
    out_ptr[i] = in_ptr[i] + 1.0f;
  }
}
__KERNEL_NAME_NORMALIZED___cuda/__KERNEL_NAME_NORMALIZED__.cu
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <algorithm>

// Elementwise kernel: out[i] = input[i] + 1.0f.
//
// Uses a grid-stride loop with 64-bit indexing so tensors with more than
// 2^31-1 elements are handled correctly (the original 32-bit `int n`
// truncated `numel()` and overflowed the thread index).
__global__ void __KERNEL_NAME_NORMALIZED___kernel(float *__restrict__ out,
                                                  float const *__restrict__ input,
                                                  const int64_t n) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += stride) {
    out[i] = input[i] + 1.0f;
  }
}

// Launches the kernel on the current CUDA stream of `input`'s device.
//
// Parameters:
//   out   - float32 CUDA tensor, same shape/dtype/device as `input`; overwritten.
//   input - contiguous float32 CUDA tensor.
//
// Raises (TORCH_CHECK) on device/dtype/shape mismatches or non-contiguous
// tensors. A zero-element tensor returns early: launching a grid of zero
// blocks is a CUDA "invalid configuration" error.
void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  // The raw-pointer writes below assume a dense layout for `out` as well.
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float,
              "__KERNEL_NAME_NORMALIZED__ only supports float32");
  TORCH_CHECK(input.sizes() == out.sizes(),
              "Tensors must have the same shape");
  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have the same dtype");
  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on the same device");

  const int64_t n = input.numel();
  if (n == 0) {
    return;
  }

  constexpr int threads = 256;
  // Cap the grid size; the grid-stride loop covers any remainder, and 65535
  // is valid for gridDim.x on every supported architecture.
  const int blocks = static_cast<int>(
      std::min<int64_t>((n + threads - 1) / threads, 65535));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  __KERNEL_NAME_NORMALIZED___kernel<<<blocks, threads, 0, stream>>>(
      out.data_ptr<float>(), input.data_ptr<float>(), n);
}
__KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.metal
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <metal_stdlib>
using namespace metal;

// Elementwise "add one" compute kernels, one per supported dtype.
// The kernel names are looked up by string from the Objective-C++ host code
// ("..._forward_kernel_" + "float"/"half"), so they must not be renamed
// without updating the host side.
// One thread per element: `index` is the flat element index; the host
// dispatches exactly numel() threads, so no bounds check is needed here.
kernel void __KERNEL_NAME_NORMALIZED___forward_kernel_float(device const float *input [[buffer(0)]],
                                                            device float *output [[buffer(1)]],
                                                            uint index [[thread_position_in_grid]]) {
  output[index] = input[index] + 1.0f;
}

kernel void __KERNEL_NAME_NORMALIZED___forward_kernel_half(device const half *input [[buffer(0)]],
                                                           device half *output [[buffer(1)]],
                                                           uint index [[thread_position_in_grid]]) {
  output[index] = input[index] + static_cast<half>(1.0);
}
__KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.mm
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <torch/torch.h>

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

// The build system embeds the compiled .metallib as a generated header and
// defines EMBEDDED_METALLIB_HEADER / EMBEDDED_METALLIB_NAMESPACE.
#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif

// Reinterprets a tensor's storage pointer as the MTLBuffer backing an MPS
// tensor. Non-owning: the tensor keeps ownership of the buffer.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

// MPS implementation: dispatches the embedded Metal "add one" kernel so that
// out = input + 1 elementwise, for float32 or float16 tensors.
//
// Parameters:
//   out   - MPS tensor, same shape/dtype/device as `input`; overwritten.
//   input - contiguous float32/float16 MPS tensor.
//
// Raises (TORCH_CHECK) on device/dtype/shape mismatches, non-contiguous
// input, or Metal library/function/pipeline creation failure.
void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(input.scalar_type() == torch::kFloat ||
              input.scalar_type() == torch::kHalf,
              "only float32 and float16 supported");
  TORCH_CHECK(input.sizes() == out.sizes(), "Tensors must have same shape");
  TORCH_CHECK(input.scalar_type() == out.scalar_type(), "Tensors must have same dtype");
  TORCH_CHECK(input.device() == out.device(), "Tensors must be on same device");

  @autoreleasepool {
    // NOTE(review): uses the system default device rather than the device
    // backing `input`'s MPS allocation — fine on single-GPU Macs; confirm if
    // multi-GPU support is ever needed.
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    // NOTE(review): numel() is narrowed to int here; also, dispatching zero
    // threads (empty tensor) is invalid in Metal — confirm empty inputs are
    // excluded by callers.
    int numThreads = input.numel();

    NSError *error = nil;
    // NOTE(review): library and pipeline state are rebuilt on every call;
    // caching them (e.g. in a static) would avoid per-call compilation cost.
    id<MTLLibrary> library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
    TORCH_CHECK(library, "Failed to create Metal library: ",
                error.localizedDescription.UTF8String);

    // Kernel name must match the function names declared in the .metal file.
    std::string kernel_name = std::string("__KERNEL_NAME_NORMALIZED___forward_kernel_") +
        (input.scalar_type() == torch::kFloat ? "float" : "half");
    id<MTLFunction> func = [library newFunctionWithName:
        [NSString stringWithUTF8String:kernel_name.c_str()]];
    TORCH_CHECK(func, "Failed to create function: ", kernel_name.c_str());

    id<MTLComputePipelineState> pso =
        [device newComputePipelineStateWithFunction:func error:&error];
    TORCH_CHECK(pso, error.localizedDescription.UTF8String);

    // Encode on PyTorch's MPS command buffer/queue so the dispatch is ordered
    // with other torch MPS work; dispatch_sync serializes encoder access.
    id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();
    dispatch_sync(torch::mps::get_dispatch_queue(), ^() {
      id<MTLComputeCommandEncoder> encoder = [cmdBuf computeCommandEncoder];
      [encoder setComputePipelineState:pso];
      // Buffer offsets account for tensor views into a larger storage.
      [encoder setBuffer:getMTLBufferStorage(input)
                  offset:input.storage_offset() * input.element_size()
                 atIndex:0];
      [encoder setBuffer:getMTLBufferStorage(out)
                  offset:out.storage_offset() * out.element_size()
                 atIndex:1];

      // One thread per element; threadgroup size capped by the pipeline limit.
      NSUInteger tgSize = MIN(pso.maxTotalThreadsPerThreadgroup, (NSUInteger)numThreads);
      [encoder dispatchThreads:MTLSizeMake(numThreads, 1, 1)
         threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)];
      [encoder endEncoding];
      // Commit (asynchronously) — completion is synchronized by torch's MPS
      // stream machinery, not awaited here.
      torch::mps::commit();
    });
  }
}
__KERNEL_NAME_NORMALIZED___xpu/__KERNEL_NAME_NORMALIZED__.cpp
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <sycl/sycl.hpp>
#include <torch/torch.h>

#include <c10/xpu/XPUStream.h>

// XPU (SYCL) implementation: out = input + 1 elementwise for float32.
//
// Parameters:
//   out   - XPU tensor, same shape/dtype/device as `input`; overwritten.
//   input - contiguous float32 XPU tensor.
//
// Raises (TORCH_CHECK) on device/dtype/shape mismatches or non-contiguous
// input. Uses the queue of PyTorch's current XPU stream instead of a fresh
// default `sycl::queue`: a default-constructed queue has its own context, so
// torch's XPU allocations are not guaranteed to be accessible from it, and
// the kernel would not be ordered with other torch ops.
void __KERNEL_NAME_NORMALIZED__(torch::Tensor& out, const torch::Tensor& input) {
  TORCH_CHECK(input.device().is_xpu(), "input must be a XPU tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(input.scalar_type() == torch::kFloat, "only float32 supported");
  TORCH_CHECK(input.sizes() == out.sizes(), "Tensors must have same shape");
  TORCH_CHECK(input.scalar_type() == out.scalar_type(), "Tensors must have same dtype");
  TORCH_CHECK(input.device() == out.device(), "Tensors must be on same device");

  auto input_ptr = input.data_ptr<float>();
  auto output_ptr = out.data_ptr<float>();
  const auto n = input.numel();
  if (n == 0) {
    return;  // nothing to do; avoid submitting an empty range
  }

  // Queue associated with torch's current XPU stream for input's device.
  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
  queue.parallel_for(sycl::range<1>(n), [=](sycl::id<1> idx) {
    output_ptr[idx[0]] = input_ptr[idx[0]] + 1.0f;
  }).wait();
}
build.toml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[general]
|
| 2 |
+
backends = [
|
| 3 |
+
"cpu",
|
| 4 |
+
"cuda",
|
| 5 |
+
"metal",
|
| 6 |
+
"rocm",
|
| 7 |
+
"xpu",
|
| 8 |
+
]
|
| 9 |
+
name = "__KERNEL_NAME__"
|
| 10 |
+
version = 1
|
| 11 |
+
|
| 12 |
+
[torch]
|
| 13 |
+
src = [
|
| 14 |
+
"torch-ext/torch_binding.cpp",
|
| 15 |
+
"torch-ext/torch_binding.h",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
[kernel.__KERNEL_NAME_NORMALIZED__]
|
| 19 |
+
backend = "cuda"
|
| 20 |
+
depends = ["torch"]
|
| 21 |
+
src = ["__KERNEL_NAME_NORMALIZED___cuda/__KERNEL_NAME_NORMALIZED__.cu"]
|
| 22 |
+
|
| 23 |
+
[kernel.__KERNEL_NAME_NORMALIZED___metal]
|
| 24 |
+
backend = "metal"
|
| 25 |
+
depends = ["torch"]
|
| 26 |
+
src = [
|
| 27 |
+
"__KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.mm",
|
| 28 |
+
"__KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.metal",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[kernel.__KERNEL_NAME_NORMALIZED___rocm]
|
| 32 |
+
backend = "rocm"
|
| 33 |
+
depends = ["torch"]
|
| 34 |
+
rocm-archs = [
|
| 35 |
+
"gfx906",
|
| 36 |
+
"gfx908",
|
| 37 |
+
"gfx90a",
|
| 38 |
+
"gfx940",
|
| 39 |
+
"gfx941",
|
| 40 |
+
"gfx942",
|
| 41 |
+
"gfx1030",
|
| 42 |
+
"gfx1100",
|
| 43 |
+
"gfx1101",
|
| 44 |
+
]
|
| 45 |
+
src = ["__KERNEL_NAME_NORMALIZED___cuda/__KERNEL_NAME_NORMALIZED__.cu"]
|
| 46 |
+
|
| 47 |
+
[kernel.__KERNEL_NAME_NORMALIZED___xpu]
|
| 48 |
+
backend = "xpu"
|
| 49 |
+
depends = ["torch"]
|
| 50 |
+
src = ["__KERNEL_NAME_NORMALIZED___xpu/__KERNEL_NAME_NORMALIZED__.cpp"]
|
| 51 |
+
|
| 52 |
+
[kernel.__KERNEL_NAME_NORMALIZED___cpu]
|
| 53 |
+
backend = "cpu"
|
| 54 |
+
depends = ["torch"]
|
| 55 |
+
src = ["__KERNEL_NAME_NORMALIZED___cpu/__KERNEL_NAME_NORMALIZED___cpu.cpp"]
|
example.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# /// script
# requires-python = ">=3.13"
# dependencies = [
#     "kernels",
#     "torch",
# ]
# ///

import platform
from pathlib import Path

import kernels
import torch

# Load the locally built kernel
kernel = kernels.get_local_kernel(Path("build"), "__KERNEL_NAME_NORMALIZED__")

# Select device: prefer an available accelerator, fall back to CPU.
# The MPS availability check matters: `platform.system() == "Darwin"` alone
# would select MPS on Macs where the MPS backend is not usable and crash on
# the first tensor allocation.
if platform.system() == "Darwin" and torch.backends.mps.is_available():
    device = torch.device("mps")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
elif torch.version.cuda is not None and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Create input tensor
x = torch.tensor([1.0, 2.0, 3.0], device=device)
print(f"Input: {x}")

# Run kernel (adds 1 to each element)
result = kernel.__KERNEL_NAME_NORMALIZED__(x)
print(f"Output: {result}")

# Verify result
expected = x + 1.0
assert torch.allclose(result, expected), "Kernel output doesn't match expected!"
print("Success!")
flake.nix
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
  # Nix flake for building this kernel with the HuggingFace kernel builder.
  inputs = {
    # NOTE(review): the kernel-builder flake normally lives at
    # "github:huggingface/kernel-builder" — confirm this URL resolves.
    kernel-builder.url = "github:huggingface/kernels";
  };
  outputs =
    { self, kernel-builder, ... }:
    # Generates the standard dev shells and build outputs (e.g.
    # `nix develop`, `nix run .#build-and-copy`) from build.toml.
    kernel-builder.lib.genKernelFlakeOutputs {
      inherit self;
      path = ./.;
    };
}
tests/__init__.py
ADDED
|
File without changes
|
tests/test___KERNEL_NAME_NORMALIZED__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import platform

import torch

import __KERNEL_NAME_NORMALIZED__


def test___KERNEL_NAME_NORMALIZED__():
    """End-to-end check: the kernel must add exactly 1.0 to every element.

    Picks the best available device; checks MPS availability explicitly so
    the test does not fail on Macs where the MPS backend is unusable.
    """
    if platform.system() == "Darwin" and torch.backends.mps.is_available():
        device = torch.device("mps")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device("xpu")
    elif torch.version.cuda is not None and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    x = torch.randn(1024, 1024, dtype=torch.float32, device=device)
    expected = x + 1.0
    result = __KERNEL_NAME_NORMALIZED__.__KERNEL_NAME_NORMALIZED__(x)
    torch.testing.assert_close(result, expected)
torch-ext/__KERNEL_NAME_NORMALIZED__/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional

import torch

from ._ops import ops


def __KERNEL_NAME_NORMALIZED__(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the kernel on ``x``, writing the result into ``out``.

    Args:
        x: Input tensor.
        out: Optional preallocated destination tensor. When ``None``, a new
            tensor with the same shape/dtype/device as ``x`` is allocated.

    Returns:
        The destination tensor (``out`` if provided, else the new allocation).
    """
    destination = torch.empty_like(x) if out is None else out
    ops.__KERNEL_NAME_NORMALIZED__(destination, x)
    return destination
torch-ext/torch_binding.cpp
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Registers the op schema and binds exactly one backend implementation,
// selected at build time by the *_KERNEL define the build system sets for
// the backend being compiled.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Schema: mutates `out` in place (Tensor!), reads `input`, returns nothing.
  ops.def("__KERNEL_NAME_NORMALIZED__(Tensor! out, Tensor input) -> ()");
#if defined(CPU_KERNEL)
  ops.impl("__KERNEL_NAME_NORMALIZED__", torch::kCPU, &__KERNEL_NAME_NORMALIZED__);
#elif defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
  // ROCm builds hipify the same .cu source, so both dispatch via kCUDA.
  ops.impl("__KERNEL_NAME_NORMALIZED__", torch::kCUDA, &__KERNEL_NAME_NORMALIZED__);
#elif defined(METAL_KERNEL)
  // `&` added for consistency with the other branches (same behavior:
  // a function name decays to a pointer either way).
  ops.impl("__KERNEL_NAME_NORMALIZED__", torch::kMPS, &__KERNEL_NAME_NORMALIZED__);
#elif defined(XPU_KERNEL)
  ops.impl("__KERNEL_NAME_NORMALIZED__", torch::kXPU, &__KERNEL_NAME_NORMALIZED__);
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once

#include <torch/torch.h>

// Writes `input + 1` elementwise into `out` (see the per-backend
// implementations: CPU, CUDA/ROCm, Metal, XPU). `out` must match `input`'s
// shape and dtype; the call mutates `out` and returns nothing.
void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input);
|