# Spaces:
# Sleeping
# Sleeping
| """ | |
| Triton CUDA kernels for ternary matmul. | |
| Falls back gracefully when triton is not installed. | |
| Storage format: packed_codes is uint8, 4 ternary values per byte. | |
| Encoding: -1 -> 0, 0 -> 1, +1 -> 2 | |
| Packing: byte = v0 | (v1 << 2) | (v2 << 4) | (v3 << 6) | |
| Weight formula (grouped asymmetric): | |
| W[i, j] = alpha[i, j // group_size] * t[i, j] + mu[i, j // group_size] | |
| Decomposition exploiting mu-term precomputation: | |
| y[b, i] = sum_g { alpha[i,g] * dot(t[i, g*gs:(g+1)*gs], x[b, g*gs:(g+1)*gs]) | |
| + mu[i,g] * sum(x[b, g*gs:(g+1)*gs]) } | |
| The x_group_sums term (mu * sum_x) is identical for every output neuron in a | |
| group and is precomputed once on the host before launching the kernel. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from typing import Optional | |
| import torch | |
| # --------------------------------------------------------------------------- | |
| # GemLite availability | |
| # --------------------------------------------------------------------------- | |
def gemlite_kernels_available() -> bool:
    """Return True if gemlite is importable and CUDA is available."""
    if not torch.cuda.is_available():
        return False
    try:
        # Broad except: gemlite may fail to import for reasons other than
        # ImportError (e.g. CUDA extension load errors).
        from gemlite import GemLiteLinear  # noqa: F401
    except Exception:
        return False
    return True
| # --------------------------------------------------------------------------- | |
| # Availability check | |
| # --------------------------------------------------------------------------- | |
def triton_ternary_kernels_available() -> bool:
    """Return True if Triton is importable and CUDA is available."""
    if not torch.cuda.is_available():
        return False
    try:
        # Importing the submodule verifies the full triton.language stack loads.
        import triton  # noqa: F401
        import triton.language  # noqa: F401
    except ImportError:
        return False
    return True
| # --------------------------------------------------------------------------- | |
| # Kernel definition — only active when triton is present at import time | |
| # --------------------------------------------------------------------------- | |
| try: | |
| import triton | |
| import triton.language as tl | |
| _TRITON_AVAILABLE = True | |
| except ImportError: | |
| _TRITON_AVAILABLE = False | |
if _TRITON_AVAILABLE:

    # BUG FIX: the @triton.jit decorator was missing.  Without it this is a
    # plain Python function and the launch syntax
    # `_groupwise_ternary_mv_kernel[grid](...)` in
    # groupwise_ternary_linear_cuda raises TypeError ("'function' object is
    # not subscriptable").
    @triton.jit
    def _groupwise_ternary_mv_kernel(
        out_ptr,
        x_ptr,
        packed_ptr,
        alpha_ptr,
        mu_ptr,
        bias_ptr,
        x_group_sums_ptr,
        # Strides (non-constexpr, passed as integers)
        packed_row_stride,
        alpha_row_stride,
        out_row_stride,
        x_row_stride,
        xgs_row_stride,
        # Compile-time constants for loop unrolling
        OUT_FEATURES: tl.constexpr,
        IN_FEATURES: tl.constexpr,
        GROUP_SIZE: tl.constexpr,
        N_GROUPS: tl.constexpr,
        N_CHUNKS: tl.constexpr,  # GROUP_SIZE // 4
        BLOCK_OUT: tl.constexpr,
    ):
        """
        Grouped ternary matrix-vector kernel.

        Grid: (ceil(out_features / BLOCK_OUT), batch_size)
        Each program handles BLOCK_OUT output neurons for one batch element.

        Per group g:  acc += alpha[o, g] * dot(t[o, g], x[b, g]) + mu[o, g] * sum(x[b, g])
        where the ternary t values are unpacked 4-per-byte from packed_ptr.

        NOTE: x loads are unmasked, so the host must guarantee each activation
        row holds N_GROUPS * GROUP_SIZE valid (zero-padded) elements.
        """
        pid_out = tl.program_id(0)  # which BLOCK_OUT slice
        pid_b = tl.program_id(1)    # which batch element
        out_offsets = pid_out * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
        out_mask = out_offsets < OUT_FEATURES
        acc = tl.zeros([BLOCK_OUT], dtype=tl.float32)
        x_base = x_ptr + pid_b * x_row_stride
        xgs_base = x_group_sums_ptr + pid_b * xgs_row_stride
        for g in tl.static_range(0, N_GROUPS):
            # Scalar group-sum of activations for this batch/group
            x_group_sum = tl.load(xgs_base + g).to(tl.float32)
            # alpha[out_offsets, g] and mu[out_offsets, g]
            scale_offsets = out_offsets * alpha_row_stride + g
            alpha_g = tl.load(alpha_ptr + scale_offsets, mask=out_mask, other=0.0)
            mu_g = tl.load(mu_ptr + scale_offsets, mask=out_mask, other=0.0)
            ternary_dot = tl.zeros([BLOCK_OUT], dtype=tl.float32)
            group_start = g * GROUP_SIZE
            for chunk in tl.static_range(0, N_CHUNKS):
                k = chunk * 4
                byte_idx = (group_start + k) // 4
                # One packed byte per output neuron in our BLOCK_OUT slice
                packed_offsets = out_offsets * packed_row_stride + byte_idx
                # Use 'other=1' so masked-out lanes decode to t=0 (neutral)
                packed_byte = tl.load(packed_ptr + packed_offsets,
                                      mask=out_mask, other=1).to(tl.int32)
                # Decode 4 ternary values: {0,1,2} -> {-1, 0, +1}
                t0 = ((packed_byte >> 0) & 0x03) - 1
                t1 = ((packed_byte >> 2) & 0x03) - 1
                t2 = ((packed_byte >> 4) & 0x03) - 1
                t3 = ((packed_byte >> 6) & 0x03) - 1
                # Activation scalars — same value across all output neurons
                x0 = tl.load(x_base + group_start + k + 0).to(tl.float32)
                x1 = tl.load(x_base + group_start + k + 1).to(tl.float32)
                x2 = tl.load(x_base + group_start + k + 2).to(tl.float32)
                x3 = tl.load(x_base + group_start + k + 3).to(tl.float32)
                ternary_dot = (ternary_dot
                               + t0.to(tl.float32) * x0
                               + t1.to(tl.float32) * x1
                               + t2.to(tl.float32) * x2
                               + t3.to(tl.float32) * x3)
            acc = acc + alpha_g * ternary_dot + mu_g * x_group_sum
        # Bias
        bias_vals = tl.load(bias_ptr + out_offsets, mask=out_mask, other=0.0)
        acc = acc + bias_vals
        # Store float16 output
        out_base = pid_b * out_row_stride + out_offsets
        tl.store(out_ptr + out_base, acc.to(tl.float16), mask=out_mask)
else:
    # Placeholder so the name exists even without triton
    _groupwise_ternary_mv_kernel = None
| # --------------------------------------------------------------------------- | |
| # Public API | |
| # --------------------------------------------------------------------------- | |
def groupwise_ternary_linear_cuda(
    x: torch.Tensor,
    packed_codes: torch.Tensor,
    group_alpha: torch.Tensor,
    group_mu: torch.Tensor,
    bias: Optional[torch.Tensor],
    out_features: int,
    in_features: int,
    group_size: int,
) -> torch.Tensor:
    """
    Grouped ternary linear via Triton kernel.

    Args:
        x: [*batch, in_features] float16 or float32 on CUDA
        packed_codes: uint8 [out_features * in_features // 4] on CUDA
        group_alpha: float32 [out_features, n_groups] on CUDA
        group_mu: float32 [out_features, n_groups] on CUDA
        bias: float32 [out_features] or None
        out_features, in_features, group_size: layer dimensions

    Returns:
        [*batch, out_features] float16

    Raises:
        RuntimeError: if triton is not installed.
        ValueError: if group_size is not a multiple of 4.
    """
    if not _TRITON_AVAILABLE:
        raise RuntimeError("triton is not installed; cannot use CUDA ternary kernel.")
    if group_size % 4 != 0:
        raise ValueError(f"group_size must be divisible by 4, got {group_size}")
    x_2d = x.reshape(-1, in_features).contiguous()
    batch_size = x_2d.shape[0]
    n_groups = math.ceil(in_features / group_size)
    n_chunks = group_size // 4  # chunks of 4 packed values per group
    padded_in = n_groups * group_size  # the kernel always walks full groups
    # Ensure contiguous, correct dtype on device
    x_f32 = x_2d.to(torch.float32).contiguous()
    packed_c = packed_codes.contiguous()
    alpha_c = group_alpha.to(torch.float32).contiguous()
    mu_c = group_mu.to(torch.float32).contiguous()
    # BUG FIX: the kernel unconditionally reads n_groups * group_size
    # activations per row (its x loads are unmasked).  Previously only the
    # group-sum precompute was padded while the *unpadded* x_f32 (row stride
    # in_features) was handed to the kernel, so a ragged final group read the
    # next batch row's values (and out-of-bounds memory on the last row).
    # Pad the kernel input itself so the trailing loads return zeros, and pass
    # the padded row stride below.
    if padded_in != in_features:
        x_f32 = torch.nn.functional.pad(
            x_f32, (0, padded_in - in_features)
        ).contiguous()
    # NOTE(review): in the ragged case packed rows are still addressed with a
    # stride of in_features // 4; the decoded tail codes are multiplied by the
    # zero-padded activations so they cannot change the result, but the byte
    # reads must stay within the buffer — confirm against the packing code.
    # Precompute per-group activation sums: [batch, n_groups], float32
    x_group_sums = (
        x_f32.reshape(batch_size, n_groups, group_size).sum(dim=2).contiguous()
    )
    # Bias — always float32, zeros when absent
    if bias is not None:
        bias_c = bias.to(torch.float32).contiguous()
    else:
        bias_c = torch.zeros(out_features, dtype=torch.float32, device=x.device)
    # Output buffer: [batch, out_features] float16
    out = torch.empty(batch_size, out_features, dtype=torch.float16, device=x.device)
    BLOCK_OUT = 16
    grid = (math.ceil(out_features / BLOCK_OUT), batch_size)
    _groupwise_ternary_mv_kernel[grid](
        out,
        x_f32,
        packed_c,
        alpha_c,
        mu_c,
        bias_c,
        x_group_sums,
        # strides
        in_features // 4,  # packed_row_stride
        n_groups,          # alpha_row_stride
        out_features,      # out_row_stride
        padded_in,         # x_row_stride (padded width — see BUG FIX above)
        n_groups,          # xgs_row_stride
        # constexpr dims
        out_features,
        in_features,
        group_size,
        n_groups,
        n_chunks,
        BLOCK_OUT,
    )
    out_shape = list(x.shape[:-1]) + [out_features]
    return out.reshape(out_shape)
def prewarm_groupwise_ternary_cuda(
    packed: torch.Tensor,
    alpha: torch.Tensor,
    mu: torch.Tensor,
    bias: Optional[torch.Tensor],
    out_features: int,
    in_features: int,
    group_size: int,
) -> None:
    """
    Fire one throwaway forward pass so Triton JIT-compiles the kernel now,
    rather than during the first real inference call.

    No-op when Triton/CUDA are unavailable.
    """
    if not triton_ternary_kernels_available():
        return
    dummy = torch.zeros(
        (1, in_features), dtype=torch.float16, device=packed.device
    )
    with torch.no_grad():
        groupwise_ternary_linear_cuda(
            dummy, packed, alpha, mu, bias, out_features, in_features, group_size
        )
def tritplane_ternary_linear_cuda(
    x: torch.Tensor,
    planes: list,
    bias: Optional[torch.Tensor],
    out_features: int,
    in_features: int,
) -> torch.Tensor:
    """
    Multi-plane ternary linear via Triton kernel.

    Launches groupwise_ternary_linear_cuda once per plane, sums the plane
    outputs in float32, applies the bias once, and casts to float16.

    Args:
        x: [*batch, in_features] float16 or float32 on CUDA
        planes: list of dicts, each with keys:
            packed_codes: uint8 [out_features * in_features // 4]
            group_alpha: float32 [out_features, n_groups]
            group_mu: float32 [out_features, n_groups]
            group_size: int
        bias: float32 [out_features] or None
        out_features, in_features: layer dimensions

    Returns:
        [*batch, out_features] float16
    """
    if not _TRITON_AVAILABLE:
        raise RuntimeError("triton is not installed; cannot use CUDA ternary kernel.")
    n_rows = x.reshape(-1, in_features).shape[0]
    total = torch.zeros(n_rows, out_features, dtype=torch.float32, device=x.device)
    for plane in planes:
        # Bias is deliberately withheld per-plane and applied once after the sum.
        contribution = groupwise_ternary_linear_cuda(
            x,
            plane["packed_codes"],
            plane["group_alpha"],
            plane["group_mu"],
            None,
            out_features,
            in_features,
            plane["group_size"],
        )
        # Plane output is float16 [*batch, out_features]; accumulate in float32.
        total += contribution.reshape(n_rows, out_features).float()
    if bias is not None:
        total += bias.to(device=x.device, dtype=torch.float32)
    return total.to(torch.float16).reshape(*x.shape[:-1], out_features)
def prewarm_tritplane_ternary_cuda(
    planes: list,
    bias: Optional[torch.Tensor],
    out_features: int,
    in_features: int,
) -> None:
    """
    Trigger Triton JIT compilation for all planes.

    Runs one dummy forward pass so the first real inference call does not
    pay JIT-compile latency.  No-op when Triton/CUDA are unavailable or
    when `planes` is empty.
    """
    # BUG FIX: an empty plane list previously raised IndexError on planes[0]
    # before any availability check; prewarming nothing is a harmless no-op.
    if not planes:
        return
    if not triton_ternary_kernels_available():
        return
    device = planes[0]["packed_codes"].device
    x_dummy = torch.zeros(1, in_features, dtype=torch.float16, device=device)
    with torch.no_grad():
        tritplane_ternary_linear_cuda(x_dummy, planes, bias, out_features, in_features)
| # --------------------------------------------------------------------------- | |
| # GemLite backend — grouped ternary linear via GemLiteLinear (int2) | |
| # | |
| # Math: y = GemLite(x; W_q=T+1, zeros=1, scales=alpha) + x_group_sums @ mu.T + bias | |
| # where T ∈ {-1,0,+1} encoded as W_q ∈ {0,1,2} (2-bit packed). | |
| # The mu offset is applied via a separate small FP16 matmul on precomputed group sums. | |
| # --------------------------------------------------------------------------- | |
def build_gemlite_layer(
    W_q: torch.Tensor,
    group_alpha: torch.Tensor,
    group_size: int,
    out_features: int,
    in_features: int,
) -> "GemLiteLinear":
    """
    Pack one ternary plane into a GemLiteLinear.

    Args:
        W_q: uint8 [out_features, in_features], values in {0, 1, 2}
        group_alpha: float16 [out_features, n_groups]
        group_size: int
        out_features, in_features: layer shape

    Returns a cuda GemLiteLinear ready for forward(x).
    """
    from gemlite import GemLiteLinear, DType

    layer = GemLiteLinear(
        W_nbits=2,
        group_size=group_size,
        in_features=in_features,
        out_features=out_features,
        input_dtype=DType.FP16,
        output_dtype=DType.FP16,
    ).cuda()
    # A zero-point of 1 makes GemLite compute (W_q - 1) * alpha == T * alpha,
    # recovering the signed ternary values from the {0, 1, 2} encoding.
    zero_points = torch.ones_like(group_alpha)
    layer.pack(W_q.cuda(), group_alpha.cuda(), zero_points.cuda(), bias=None)
    return layer
def gemlite_groupwise_linear(
    x: torch.Tensor,
    gl_layer,
    group_mu: torch.Tensor,
    bias: Optional[torch.Tensor],
    in_features: int,
    group_size: int,
) -> torch.Tensor:
    """
    Forward pass through a GemLite-packed groupwise ternary layer.

    Computes y = gl_layer(x) + x_group_sums @ group_mu.T + bias: gl_layer
    handles the alpha-scaled ternary matmul internally, and the asymmetric
    mu offset is added via a small dense matmul over per-group activation
    sums.  Assumes in_features is an exact multiple of group_size.
    """
    flat = x.reshape(-1, in_features)
    rows = flat if flat.is_contiguous() else flat.contiguous()
    n_groups = in_features // group_size
    # Symmetric ternary part (alpha scaling happens inside gl_layer).
    y = gl_layer(rows)
    # Asymmetric mu correction: [B, n_groups] @ [n_groups, out_features].
    per_group_sums = rows.view(rows.shape[0], n_groups, group_size).sum(dim=2)
    y = y + per_group_sums @ group_mu.to(device=x.device, dtype=x.dtype).T
    if bias is not None:
        y = y + bias.to(device=x.device, dtype=x.dtype)
    return y.reshape(*x.shape[:-1], y.shape[-1])