# bitnet-1bitllm / vm_backup / code / bit_kernel.py
# hidude562 — commit 4754707: "1bitllm code (checkpoints to follow)"
"""Fast Triton kernel for strict-±1 BitLinear: fused sign(x) @ sign(w).T + STE backward.
`FastBitLinear(in_f, out_f)` drops straight into our model in place of BitLinear.
Uses a Triton forward kernel that signs both operands inside the MMA loop (no
intermediate ±1 tensor materialization). Backward is straight-through: grads flow
as if sign were identity.
On an RTX 5090 for typical strict-±1 shapes (M=B·T=16384, K=D=1024, N=D=1024),
this is ~3× faster than `F.linear(sign(x), sign(w))` in fp32.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    # No Triton: fused_bit_matmul always uses the pure-PyTorch path below.
    HAS_TRITON = False
    triton = None
    tl = None


def _sign_pm1(t):
    """Return +1 where ``t >= 0`` and -1 elsewhere (0 -> +1, NaN -> -1), in t's dtype."""
    return torch.where(t >= 0, torch.ones_like(t), -torch.ones_like(t))


def _ste_sign_matmul_backward(ctx, dy):
    """Straight-through backward shared by both y = sign(x) @ sign(w).T paths.

    sign() is treated as the identity, so
        dx = dy @ sign(w)     (reshaped to x.shape)
        dw = dy.T @ sign(x)   (shape of w)
    Expects ``ctx.saved_tensors == (x, w)``. Gradients autograd did not ask for
    are returned as None.
    """
    x, w = ctx.saved_tensors
    dy_flat = dy.reshape(-1, dy.shape[-1])
    dx = dw = None
    if ctx.needs_input_grad[0]:
        # Sign in dy's dtype so the matmul never mixes dtypes; +-1 is exact in
        # every floating dtype.
        dx = (dy_flat @ _sign_pm1(w).to(dy.dtype)).reshape(x.shape)
    if ctx.needs_input_grad[1]:
        x_s = _sign_pm1(x.reshape(-1, x.shape[-1])).to(dy.dtype)
        dw = dy_flat.t() @ x_s
    return dx, dw


class _SignMatmulSTE(torch.autograd.Function):
    """Pure-PyTorch y = sign(x) @ sign(w).T with the same STE backward as the
    Triton path. x: (..., K), w: (N, K). y: (..., N).

    Used whenever the Triton kernel is not applicable, so gradients flow the
    same way on every path (a plain ``torch.where``/``ones_like`` sign has no
    gradient at all).
    """

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return F.linear(_sign_pm1(x), _sign_pm1(w))

    @staticmethod
    def backward(ctx, dy):
        return _ste_sign_matmul_backward(ctx, dy)


if HAS_TRITON:

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=3, num_warps=4),
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=3, num_warps=8),
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        ],
        key=['M', 'N', 'K'],
    )
    @triton.jit
    def _fused_sign_matmul_kernel(
        X_ptr, W_ptr, Y_ptr,
        M, N, K,
        stride_xm, stride_xk,
        stride_wn, stride_wk,
        stride_ym, stride_yn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        # Each program computes one BLOCK_M x BLOCK_N tile of
        # Y = sign(X) @ sign(W).T, signing both operands inside the K loop.
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k0 in range(0, K, BLOCK_K):
            k_mask = (offs_k + k0) < K
            x = tl.load(
                X_ptr + offs_m[:, None] * stride_xm + (offs_k + k0)[None, :] * stride_xk,
                mask=(offs_m[:, None] < M) & k_mask[None, :], other=0.0)
            w = tl.load(
                W_ptr + offs_n[:, None] * stride_wn + (offs_k + k0)[None, :] * stride_wk,
                mask=(offs_n[:, None] < N) & k_mask[None, :], other=0.0)
            x_s = tl.where(x >= 0.0, 1.0, -1.0)
            w_s = tl.where(w >= 0.0, 1.0, -1.0)
            # sign(0) = +1, so the 0.0 padding of the K tail would become +1 on
            # BOTH operands and add +1 per padded column to every output when
            # K % BLOCK_K != 0. Zero the padded columns after signing; zeroing
            # one operand is enough to cancel the product.
            x_s = tl.where(k_mask[None, :], x_s, 0.0)
            acc += tl.dot(x_s, tl.trans(w_s))
        tl.store(
            Y_ptr + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn,
            acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

    class _FusedSignMatmul(torch.autograd.Function):
        """y = sign(x) @ sign(w).T with STE backward (Triton forward).
        x: (..., K), w: (N, K). y: (..., N), float32."""

        @staticmethod
        def forward(ctx, x, w):
            orig = x.shape
            x_flat = x.reshape(-1, orig[-1]).contiguous()
            M, K = x_flat.shape
            if w.dim() != 2 or w.shape[1] != K:
                # The kernel would otherwise read w out of bounds.
                raise RuntimeError(
                    f'fused_bit_matmul: expected w of shape (N, {K}), got {tuple(w.shape)}')
            N = w.shape[0]
            y = torch.empty(M, N, device=x.device, dtype=torch.float32)
            if M > 0 and N > 0:  # nothing to compute (and no empty grid) otherwise
                grid = lambda META: ((M + META['BLOCK_M'] - 1) // META['BLOCK_M'],
                                     (N + META['BLOCK_N'] - 1) // META['BLOCK_N'])
                _fused_sign_matmul_kernel[grid](
                    x_flat, w, y, M, N, K,
                    x_flat.stride(0), x_flat.stride(1),
                    w.stride(0), w.stride(1),
                    y.stride(0), y.stride(1),
                )
            ctx.save_for_backward(x, w)
            return y.reshape(*orig[:-1], N)

        @staticmethod
        def backward(ctx, dy):
            # STE: grads flow as if sign() = identity.
            return _ste_sign_matmul_backward(ctx, dy)


def fused_bit_matmul(x, w):
    """Drop-in for ``F.linear(sign(x), sign(w))`` with a straight-through backward.

    sign maps values >= 0 to +1 and everything else to -1.
    x: (..., K), w: (N, K) -> y: (..., N).

    Uses the Triton kernel when Triton is importable, both tensors are on CUDA
    and x is float32 (the output is then float32). Otherwise it uses a
    pure-PyTorch implementation with the same STE gradients, so training
    behaves the same on every path.
    """
    if HAS_TRITON and x.is_cuda and w.is_cuda and x.dtype == torch.float32:
        return _FusedSignMatmul.apply(x, w)
    return _SignMatmulSTE.apply(x, w)
class FastBitLinear(nn.Module):
    """Strict-±1 BitLinear whose matmul runs through ``fused_bit_matmul``.

    Meant to match model.BitLinear: fused sign(weight)·sign(x) popcount,
    /sqrt(in) scale, learned threshold, sign_ste_clipped output.

    Forward, for input x of trailing size ``in_features``:
        x'  = clamp(x, -1, 1)                  (only if ``binarize_input``)
        raw = fused_bit_matmul(x', weight)     (signs both operands internally)
        s   = raw / sqrt(in_features) - threshold
        out = +1 where s >= 0 else -1, back-propagating through clamp(s, -1, 1)

    NOTE(review): fused_bit_matmul signs its input whether or not
    ``binarize_input`` is set; the flag only decides whether the input is
    clamped first (which zeroes the STE gradient where |x| > 1).
    """

    def __init__(self, in_features, out_features, binarize_input=True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.binarize_input = binarize_input
        # Parameter registration order (weight, then threshold) fixes the
        # state_dict key order.
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.threshold = nn.Parameter(torch.zeros(out_features))
        self.scale = 1.0 / math.sqrt(in_features)

    def forward(self, x):
        # sign_ste_clipped on the input: only the clip happens here, because
        # the fused matmul already applies sign (with an STE) internally.
        inp = x.clamp(-1.0, 1.0) if self.binarize_input else x
        pre = fused_bit_matmul(inp, self.weight).mul(self.scale) - self.threshold

        # Hard ±1 decision in pre's dtype: True -> +1, False (incl. NaN) -> -1.
        hard = (pre >= 0).to(pre.dtype) * 2 - 1
        soft = pre.clamp(-1.0, 1.0)
        # Straight-through: value of `hard`, gradient of `soft`.
        return soft + (hard - soft).detach()
if __name__ == '__main__':
    # Self-test and fwd+bwd benchmark of fused_bit_matmul against PyTorch.
    import time

    if not torch.cuda.is_available():
        raise SystemExit('bit_kernel.py self-test/benchmark requires a CUDA device')
    if not HAS_TRITON:
        print('warning: Triton not available; timing the PyTorch fallback')

    torch.manual_seed(0)

    def sign_pm1(t):
        # Same convention as the kernel: +1 where t >= 0, else -1.
        return torch.where(t >= 0, torch.ones_like(t), -torch.ones_like(t))

    def ref_ste(x, w):
        # PyTorch reference with a straight-through backward (detach trick).
        # A plain F.linear(sign_pm1(x), sign_pm1(w)) has no grad_fn, so
        # loss.backward() on it raises and it cannot be timed fwd+bwd.
        xs = x + (sign_pm1(x) - x).detach()
        ws = w + (sign_pm1(w) - w).detach()
        return F.linear(xs, ws)

    # Correctness. K=1024 is a multiple of every BLOCK_K; K=1000 is not, so it
    # exercises the masked tail of the kernel's K loop.
    for k in (1024, 1000):
        xc = torch.randn(4, 33, k, device='cuda', requires_grad=True)
        wc = torch.randn(65, k, device='cuda', requires_grad=True)
        yc = fused_bit_matmul(xc, wc)
        y_ref = F.linear(sign_pm1(xc), sign_pm1(wc))
        yc.sum().backward()
        # With loss = y.sum(), dy is all ones, so the STE grads reduce to
        # column sums of the other signed operand.
        dx_ref = sign_pm1(wc).sum(0).expand_as(xc)
        dw_ref = sign_pm1(xc).reshape(-1, k).sum(0).expand_as(wc)
        print(f'K={k}: forward max diff {(yc - y_ref).abs().max().item()}, '
              f'dx max diff {(xc.grad - dx_ref).abs().max().item()}, '
              f'dw max diff {(wc.grad - dw_ref).abs().max().item()}')

    # Benchmark realistic training-shape matmul.
    B, T, D_IN, D_OUT = 64, 256, 1024, 1024
    x = torch.randn(B, T, D_IN, device='cuda', requires_grad=True)
    w = torch.randn(D_OUT, D_IN, device='cuda', requires_grad=True)

    def bench(fn, iters=50, warmup=5):
        """Mean wall-clock milliseconds of one forward + backward of fn().sum()."""
        torch.cuda.synchronize()
        for _ in range(warmup):
            fn().sum().backward()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fn().sum().backward()
        torch.cuda.synchronize()
        return (time.perf_counter() - t0) / iters * 1000

    t_fast = bench(lambda: fused_bit_matmul(x, w))
    x.grad = None
    w.grad = None
    t_ref = bench(lambda: ref_ste(x, w))
    print(f'Triton fused (fwd+bwd): {t_fast:.2f} ms')
    print(f'PyTorch ref (fwd+bwd): {t_ref:.2f} ms')
    print(f'Speedup: {t_ref/t_fast:.2f}x')