Kernels
File size: 2,122 Bytes
656a6f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6436ad6
 
 
656a6f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Local CUDA build for activation kernels.

Usage:
    pip install -e .          # editable install
    python setup.py build_ext --inplace  # build only

The built extension is named '_activation' and can be loaded via:
    import _activation
    torch.ops._activation.rms_norm(...)
"""

import os
from pathlib import Path

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Directory that contains this setup script; the relative paths listed below
# are joined onto it when the extension is declared.
ROOT = Path(__file__).parent

# C++ source holding the torch binding code.
CPP_SOURCES = [
    "torch-ext/torch_binding.cpp",
]

# CUDA kernel sources, relative to ROOT.
CUDA_SOURCES = [
    "activation/poly_norm.cu",
    "activation/fused_mul_poly_norm.cu",
    "activation/rms_norm.cu",
    "activation/fused_add_rms_norm.cu",
    "activation/grouped_poly_norm.cu",
]

# Header search path, in order: the project root (for registration.h and
# activation/*.h), the activation/ directory, and torch-ext/ (for
# torch_binding.h).
INCLUDE_DIRS = [
    str(header_dir)
    for header_dir in (ROOT, ROOT / "activation", ROOT / "torch-ext")
]

# CUDA flags matching the existing kernel style
NVCC_FLAGS = [
    "-O3",
    "--use_fast_math",
    "-std=c++17",
    # Generate code for common architectures
    "-gencode=arch=compute_80,code=sm_80",  # A100
    "-gencode=arch=compute_89,code=sm_89",  # L40/4090
    "-gencode=arch=compute_90,code=sm_90",  # H100
]

# Check for B200 support (sm_100, requires CUDA 12.8+)
cuda_version = tuple(int(x) for x in torch.version.cuda.split(".")[:2])
if cuda_version >= (12, 8):
    NVCC_FLAGS.append("-gencode=arch=compute_100,code=sm_100")

CXX_FLAGS = ["-O3", "-std=c++17"]

# Absolute paths of every translation unit: the C++ binding sources first,
# followed by the CUDA kernels.
extension_sources = [
    str(ROOT / relative_path)
    for relative_path in (*CPP_SOURCES, *CUDA_SOURCES)
]

# Per-compiler flag lists, keyed the way CUDAExtension expects them.
compile_args = dict(cxx=CXX_FLAGS, nvcc=NVCC_FLAGS)

# The single native extension module, importable as `_activation`.
activation_extension = CUDAExtension(
    name="_activation",
    sources=extension_sources,
    include_dirs=INCLUDE_DIRS,
    extra_compile_args=compile_args,
)

ext_modules = [activation_extension]

# Distribution metadata plus the Python package that ships with the
# extension; the `activation` package lives under torch-ext/activation.
setup_options = {
    "name": "activation",
    "version": "0.1.0",
    "description": "Custom CUDA normalization kernels for LLM training",
    "ext_modules": ext_modules,
    "cmdclass": {"build_ext": BuildExtension},
    "packages": ["activation"],
    "package_dir": {"activation": "torch-ext/activation"},
    "python_requires": ">=3.10",
    "install_requires": ["torch>=2.7"],
}

setup(**setup_options)