# ternary-quant-demo / ternary_quant / cuda_kernels.py
# (Hugging Face file-viewer header retained as a comment so the module stays
# importable: AsadIsmail — "Bundle ternary_quant package directly
# (private repo fix)" — commit 162f86a, verified)
"""
Triton CUDA kernels for ternary matmul.
Falls back gracefully when triton is not installed.
Storage format: packed_codes is uint8, 4 ternary values per byte.
Encoding: -1 -> 0, 0 -> 1, +1 -> 2
Packing: byte = v0 | (v1 << 2) | (v2 << 4) | (v3 << 6)
Weight formula (grouped asymmetric):
W[i, j] = alpha[i, j // group_size] * t[i, j] + mu[i, j // group_size]
Decomposition exploiting mu-term precomputation:
y[b, i] = sum_g { alpha[i,g] * dot(t[i, g*gs:(g+1)*gs], x[b, g*gs:(g+1)*gs])
+ mu[i,g] * sum(x[b, g*gs:(g+1)*gs]) }
The x_group_sums term (mu * sum_x) is identical for every output neuron in a
group and is precomputed once on the host before launching the kernel.
"""
from __future__ import annotations
import math
from typing import Optional
import torch
# ---------------------------------------------------------------------------
# GemLite availability
# ---------------------------------------------------------------------------
def gemlite_kernels_available() -> bool:
    """Return True if gemlite is importable and CUDA is available."""
    if not torch.cuda.is_available():
        return False
    try:
        from gemlite import GemLiteLinear  # noqa: F401
    except Exception:
        # Any import failure (missing package, broken install) means the
        # GemLite backend cannot be used.
        return False
    return True
# ---------------------------------------------------------------------------
# Availability check
# ---------------------------------------------------------------------------
def triton_ternary_kernels_available() -> bool:
    """Return True if Triton is importable and CUDA is available."""
    if not torch.cuda.is_available():
        return False
    try:
        import triton  # noqa: F401
        import triton.language as tl  # noqa: F401
    except ImportError:
        return False
    return True
# ---------------------------------------------------------------------------
# Kernel definition — only active when triton is present at import time
# ---------------------------------------------------------------------------
# Import triton eagerly so the @triton.jit kernel below can be defined at
# module-import time. _TRITON_AVAILABLE gates both the kernel definition and
# the public entry points.
try:
    import triton
    import triton.language as tl
    _TRITON_AVAILABLE = True
except ImportError:
    _TRITON_AVAILABLE = False
if _TRITON_AVAILABLE:
    @triton.jit
    def _groupwise_ternary_mv_kernel(
        out_ptr,
        x_ptr,
        packed_ptr,
        alpha_ptr,
        mu_ptr,
        bias_ptr,
        x_group_sums_ptr,
        # Strides (non-constexpr, passed as integers)
        packed_row_stride,
        alpha_row_stride,
        out_row_stride,
        x_row_stride,
        xgs_row_stride,
        # Compile-time constants for loop unrolling
        OUT_FEATURES: tl.constexpr,
        IN_FEATURES: tl.constexpr,
        GROUP_SIZE: tl.constexpr,
        N_GROUPS: tl.constexpr,
        N_CHUNKS: tl.constexpr,  # GROUP_SIZE // 4
        BLOCK_OUT: tl.constexpr,
    ):
        """
        Grid: (ceil(out_features / BLOCK_OUT), batch_size)
        Each program handles BLOCK_OUT output neurons for one batch element.

        For each output neuron i in the slice, accumulates (in float32):
            out[b, i] = sum_g( alpha[i, g] * dot(t[i, group g], x[b, group g])
                               + mu[i, g] * x_group_sums[b, g] ) + bias[i]
        where t is decoded on the fly from 2-bit packed codes
        ({0, 1, 2} -> {-1, 0, +1}, 4 values per uint8 byte) and the result
        is stored as float16.

        NOTE(review): the x / packed loads index up to N_GROUPS * GROUP_SIZE;
        the host must ensure IN_FEATURES is a multiple of GROUP_SIZE or the
        last group reads past the row — confirm at the call site.
        """
        pid_out = tl.program_id(0)  # which BLOCK_OUT slice
        pid_b = tl.program_id(1)  # which batch element
        out_offsets = pid_out * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
        out_mask = out_offsets < OUT_FEATURES
        acc = tl.zeros([BLOCK_OUT], dtype=tl.float32)
        # Row bases for this batch element's activations and group sums.
        x_base = x_ptr + pid_b * x_row_stride
        xgs_base = x_group_sums_ptr + pid_b * xgs_row_stride
        for g in tl.static_range(0, N_GROUPS):
            # Scalar group-sum of activations for this batch/group
            x_group_sum = tl.load(xgs_base + g).to(tl.float32)
            # alpha[out_offsets, g] and mu[out_offsets, g]
            scale_offsets = out_offsets * alpha_row_stride + g
            alpha_g = tl.load(alpha_ptr + scale_offsets, mask=out_mask, other=0.0)
            mu_g = tl.load(mu_ptr + scale_offsets, mask=out_mask, other=0.0)
            ternary_dot = tl.zeros([BLOCK_OUT], dtype=tl.float32)
            group_start = g * GROUP_SIZE
            for chunk in tl.static_range(0, N_CHUNKS):
                k = chunk * 4
                byte_idx = (group_start + k) // 4
                # One packed byte per output neuron in our BLOCK_OUT slice
                packed_offsets = out_offsets * packed_row_stride + byte_idx
                # Use 'other=1' so masked-out lanes decode to t=0 (neutral)
                packed_byte = tl.load(packed_ptr + packed_offsets,
                                      mask=out_mask, other=1).to(tl.int32)
                # Decode 4 ternary values: {0,1,2} -> {-1, 0, +1}
                t0 = ((packed_byte >> 0) & 0x03) - 1
                t1 = ((packed_byte >> 2) & 0x03) - 1
                t2 = ((packed_byte >> 4) & 0x03) - 1
                t3 = ((packed_byte >> 6) & 0x03) - 1
                # Activation scalars — same value across all output neurons
                x0 = tl.load(x_base + group_start + k + 0).to(tl.float32)
                x1 = tl.load(x_base + group_start + k + 1).to(tl.float32)
                x2 = tl.load(x_base + group_start + k + 2).to(tl.float32)
                x3 = tl.load(x_base + group_start + k + 3).to(tl.float32)
                ternary_dot = (ternary_dot
                               + t0.to(tl.float32) * x0
                               + t1.to(tl.float32) * x1
                               + t2.to(tl.float32) * x2
                               + t3.to(tl.float32) * x3)
            # Per-group contribution: scaled ternary dot plus the mu offset
            # term (identical for every neuron in the group).
            acc = acc + alpha_g * ternary_dot + mu_g * x_group_sum
        # Bias
        bias_vals = tl.load(bias_ptr + out_offsets, mask=out_mask, other=0.0)
        acc = acc + bias_vals
        # Store float16 output
        out_base = pid_b * out_row_stride + out_offsets
        tl.store(out_ptr + out_base, acc.to(tl.float16), mask=out_mask)
else:
    # Placeholder so the name exists even without triton; callers gate on
    # _TRITON_AVAILABLE before launching.
    _groupwise_ternary_mv_kernel = None
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def groupwise_ternary_linear_cuda(
    x: torch.Tensor,
    packed_codes: torch.Tensor,
    group_alpha: torch.Tensor,
    group_mu: torch.Tensor,
    bias: Optional[torch.Tensor],
    out_features: int,
    in_features: int,
    group_size: int,
) -> torch.Tensor:
    """
    Grouped ternary linear via Triton kernel.

    Args:
        x: [*batch, in_features] float16 or float32 on CUDA
        packed_codes: uint8 [out_features * in_features // 4] on CUDA
        group_alpha: float32 [out_features, n_groups] on CUDA
        group_mu: float32 [out_features, n_groups] on CUDA
        bias: float32 [out_features] or None
        out_features, in_features, group_size: layer dimensions

    Returns:
        [*batch, out_features] float16

    Raises:
        ValueError: if group_size is not a multiple of 4, or in_features is
            not a multiple of group_size. The kernel indexes both x and
            packed_codes assuming full groups; a ragged last group would
            read past the end of each row (out of bounds / garbage data),
            so it is rejected here instead of silently computing wrong
            results.
        RuntimeError: if triton is not installed.
    """
    # Validate dimensions before checking the backend so callers get a clear
    # error regardless of whether triton is present.
    if group_size % 4 != 0:
        raise ValueError(f"group_size must be divisible by 4, got {group_size}")
    if in_features % group_size != 0:
        raise ValueError(
            f"in_features ({in_features}) must be divisible by group_size "
            f"({group_size}); the packed-code layout does not support ragged groups"
        )
    if not _TRITON_AVAILABLE:
        raise RuntimeError("triton is not installed; cannot use CUDA ternary kernel.")
    x_2d = x.reshape(-1, in_features).contiguous()
    batch_size = x_2d.shape[0]
    n_groups = in_features // group_size
    n_chunks = group_size // 4  # packed bytes per group (4 ternary values/byte)
    # Ensure contiguous, correct dtype on device
    x_f32 = x_2d.to(torch.float32).contiguous()
    packed_c = packed_codes.contiguous()
    alpha_c = group_alpha.to(torch.float32).contiguous()
    mu_c = group_mu.to(torch.float32).contiguous()
    # Precompute per-group activation sums: [batch, n_groups], float32.
    # These feed the mu * sum(x_group) term, which is identical for every
    # output neuron within a group (see module docstring).
    x_group_sums = (
        x_f32.reshape(batch_size, n_groups, group_size).sum(dim=2).contiguous()
    )
    # Bias — always float32, zeros when absent
    if bias is not None:
        bias_c = bias.to(torch.float32).contiguous()
    else:
        bias_c = torch.zeros(out_features, dtype=torch.float32, device=x.device)
    # Output buffer: [batch, out_features] float16
    out = torch.empty(batch_size, out_features, dtype=torch.float16, device=x.device)
    BLOCK_OUT = 16
    grid = (math.ceil(out_features / BLOCK_OUT), batch_size)
    _groupwise_ternary_mv_kernel[grid](
        out,
        x_f32,
        packed_c,
        alpha_c,
        mu_c,
        bias_c,
        x_group_sums,
        # strides
        in_features // 4,  # packed_row_stride
        n_groups,          # alpha_row_stride
        out_features,      # out_row_stride
        in_features,       # x_row_stride
        n_groups,          # xgs_row_stride
        # constexpr dims
        out_features,
        in_features,
        group_size,
        n_groups,
        n_chunks,
        BLOCK_OUT,
    )
    out_shape = list(x.shape[:-1]) + [out_features]
    return out.reshape(out_shape)
def prewarm_groupwise_ternary_cuda(
    packed: torch.Tensor,
    alpha: torch.Tensor,
    mu: torch.Tensor,
    bias: Optional[torch.Tensor],
    out_features: int,
    in_features: int,
    group_size: int,
) -> None:
    """
    Trigger Triton JIT compilation with a dummy forward pass so the first
    real inference call does not pay the compile latency.
    """
    if not triton_ternary_kernels_available():
        return
    dummy = torch.zeros(
        1, in_features, dtype=torch.float16, device=packed.device
    )
    with torch.no_grad():
        groupwise_ternary_linear_cuda(
            dummy, packed, alpha, mu, bias, out_features, in_features, group_size
        )
def tritplane_ternary_linear_cuda(
    x: torch.Tensor,
    planes: list,
    bias: Optional[torch.Tensor],
    out_features: int,
    in_features: int,
) -> torch.Tensor:
    """
    Multi-plane ternary linear via Triton kernel.

    Runs groupwise_ternary_linear_cuda once per plane, sums the per-plane
    float16 outputs in a float32 accumulator, applies the bias once at the
    end, and returns float16.

    Args:
        x: [*batch, in_features] float16 or float32 on CUDA
        planes: list of dicts, each with keys:
            packed_codes: uint8 [out_features * in_features // 4]
            group_alpha: float32 [out_features, n_groups]
            group_mu: float32 [out_features, n_groups]
            group_size: int
        bias: float32 [out_features] or None
        out_features, in_features: layer dimensions

    Returns:
        [*batch, out_features] float16
    """
    if not _TRITON_AVAILABLE:
        raise RuntimeError("triton is not installed; cannot use CUDA ternary kernel.")
    batch_size = x.reshape(-1, in_features).shape[0]
    total = torch.zeros(
        batch_size, out_features, dtype=torch.float32, device=x.device
    )
    for plane in planes:
        # Bias is deliberately omitted per plane; it is added exactly once
        # after all planes are summed.
        contribution = groupwise_ternary_linear_cuda(
            x,
            plane["packed_codes"],
            plane["group_alpha"],
            plane["group_mu"],
            None,
            out_features,
            in_features,
            plane["group_size"],
        )
        # Per-plane output is float16; accumulate in float32 for accuracy.
        total = total + contribution.reshape(batch_size, out_features).float()
    if bias is not None:
        total = total + bias.to(device=x.device, dtype=torch.float32)
    final_shape = list(x.shape[:-1]) + [out_features]
    return total.to(torch.float16).reshape(final_shape)
def prewarm_tritplane_ternary_cuda(
    planes: list,
    bias: Optional[torch.Tensor],
    out_features: int,
    in_features: int,
) -> None:
    """
    Trigger Triton JIT compilation for all planes.

    No-op when there are no planes or the Triton/CUDA backend is
    unavailable — prewarming is best-effort and must never raise.
    """
    # Guard the empty case first: planes[0] below would raise IndexError.
    if not planes:
        return
    if not triton_ternary_kernels_available():
        return
    device = planes[0]["packed_codes"].device
    x_dummy = torch.zeros(1, in_features, dtype=torch.float16, device=device)
    with torch.no_grad():
        tritplane_ternary_linear_cuda(x_dummy, planes, bias, out_features, in_features)
# ---------------------------------------------------------------------------
# GemLite backend — grouped ternary linear via GemLiteLinear (int2)
#
# Math: y = GemLite(x; W_q=T+1, zeros=1, scales=alpha) + x_group_sums @ mu.T + bias
# where T ∈ {-1,0,+1} encoded as W_q ∈ {0,1,2} (2-bit packed).
# The mu offset is applied via a separate small FP16 matmul on precomputed group sums.
# ---------------------------------------------------------------------------
def build_gemlite_layer(
    W_q: torch.Tensor,
    group_alpha: torch.Tensor,
    group_size: int,
    out_features: int,
    in_features: int,
) -> "GemLiteLinear":
    """
    Pack one ternary plane into a GemLiteLinear.

    Args:
        W_q: uint8 [out_features, in_features], values in {0, 1, 2}
        group_alpha: float16 [out_features, n_groups]
        group_size: int
        out_features, in_features: layer shape

    Returns a cuda GemLiteLinear ready for forward(x).
    """
    from gemlite import GemLiteLinear, DType

    layer = GemLiteLinear(
        W_nbits=2,
        group_size=group_size,
        in_features=in_features,
        out_features=out_features,
        input_dtype=DType.FP16,
        output_dtype=DType.FP16,
    ).cuda()
    # Zero-points of 1 make GemLite dequantize (W_q - 1) * alpha = T * alpha,
    # recovering the signed ternary weights.
    zero_points = torch.ones_like(group_alpha)
    layer.pack(W_q.cuda(), group_alpha.cuda(), zero_points.cuda(), bias=None)
    return layer
def gemlite_groupwise_linear(
    x: torch.Tensor,
    gl_layer,
    group_mu: torch.Tensor,
    bias: Optional[torch.Tensor],
    in_features: int,
    group_size: int,
) -> torch.Tensor:
    """
    Forward pass through a GemLite-packed groupwise ternary layer.

    y = gl_layer(x) + x_group_sums @ group_mu.T + bias
    """
    leading_shape = x.shape[:-1]
    flat_x = x.reshape(-1, in_features)
    num_groups = in_features // group_size
    # Symmetric ternary matmul via GemLite (alpha scaling handled inside).
    # .contiguous() is a no-op when the tensor is already contiguous.
    y = flat_x.contiguous()
    y = gl_layer(y)
    # Asymmetric mu correction: per-group activation sums times mu,
    # giving a [batch, out_features] offset.
    group_sums = flat_x.reshape(-1, num_groups, group_size).sum(dim=2)
    y = y + group_sums @ group_mu.to(device=x.device, dtype=x.dtype).T
    if bias is not None:
        y = y + bias.to(device=x.device, dtype=x.dtype)
    return y.reshape(*leading_shape, y.shape[-1])