Buckets:
| """INT4 pair-wise 4:8 sparse export helpers for Ampere Sparse Tensor Cores. | |
| The CUDA path in the Codex archive uses PTX `mma.sp` with `.s4/.u4` operands. | |
| For INT4, PTX requires pair-wise structured sparse 4:8 on matrix A. We expose | |
| that exact format as `int4_4x8_pairwise_sparse` and keep `int4_2:4sp` as the | |
| short user-facing alias. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Iterable, Literal | |
| import torch | |
| import torch.nn as nn | |
# 4-bit operand flavour used throughout this module: "s4" is signed
# (-8..7) and "u4" is unsigned (0..15); see _check_int4.
Int4Mode = Literal["s4", "u4"]

# Canonical identifier of the exported sparse format, plus the short
# alias exposed to users.
FORMAT_NAME: str = "int4_4x8_pairwise_sparse"
USER_ALIAS: str = "int4_2:4sp"
@dataclass(frozen=True)
class SparseChunk:
    """Compressed form of one eight-value chunk under pair-wise 4:8 sparsity.

    As produced by ``pack_sparse_chunk_4x8``:

    Attributes:
        values: The four kept int4 values, i.e. the two non-zero pairs in
            ascending pair order, flattened.
        metadata_nibble: 4-bit selector with the first kept pair index in
            bits 0-1 and the second kept pair index in bits 2-3.
        nonzero_pairs: Indices (0-3) of the two kept pairs, ascending.
    """

    values: tuple[int, int, int, int]
    metadata_nibble: int
    nonzero_pairs: tuple[int, int]
| def _check_int4(x: int, mode: Int4Mode) -> None: | |
| lo, hi = (-8, 7) if mode == "s4" else (0, 15) | |
| if x < lo or x > hi: | |
| raise ValueError(f"{x} is outside {mode} range [{lo}, {hi}]") | |
def _to_nibble(x: int, mode: Int4Mode) -> int:
    """Return the low four bits of *x* after checking it is a valid *mode* value.

    Masking keeps the two's-complement bit pattern, so ``-1`` maps to ``0xF``.
    """
    _check_int4(x, mode)
    low_bits = x & 0b1111
    return low_bits
def pack_int4_pair(low: int, high: int, mode: Int4Mode = "s4") -> int:
    """Pack two 4-bit values into one byte: *low* in bits 0-3, *high* in bits 4-7.

    Raises ``ValueError`` (via ``_to_nibble``) if either value is outside the
    range allowed by *mode*; *low* is checked first.
    """
    low_nibble = _to_nibble(low, mode)
    high_nibble = _to_nibble(high, mode)
    return (high_nibble << 4) | low_nibble
def unpack_int4_pair(byte: int, mode: Int4Mode = "s4") -> tuple[int, int]:
    """Split *byte* into ``(low, high)`` nibbles; the inverse of ``pack_int4_pair``.

    For ``"s4"`` each nibble is sign-extended from 4 bits (``0xF`` -> ``-1``);
    for ``"u4"`` nibbles are returned as 0..15. Bits above bit 7 are ignored.

    Raises:
        ValueError: if *mode* is neither ``"s4"`` nor ``"u4"`` (matching the
            validation applied on the packing side).
    """
    if mode not in ("s4", "u4"):
        raise ValueError(f"unknown int4 mode {mode!r}; expected 's4' or 'u4'")
    signed = mode == "s4"

    def decode(nibble: int) -> int:
        nibble &= 0xF
        # Two's complement: nibbles 8..15 encode -8..-1 in signed mode.
        return nibble - 16 if signed and nibble >= 8 else nibble

    return decode(byte), decode(byte >> 4)
def pack_sparse_chunk_4x8(chunk: Iterable[int], mode: Int4Mode = "s4") -> SparseChunk:
    """Compress eight int4 values that obey pair-wise 4:8 sparsity.

    The chunk is read as four adjacent pairs ``(x0, x1), (x2, x3), ...``.
    Every pair must be entirely zero or entirely non-zero, and exactly two
    pairs must be non-zero. The kept pairs are returned in ascending order
    together with a 4-bit metadata value: first kept pair index in bits 0-1,
    second in bits 2-3.

    Raises:
        ValueError: wrong length, out-of-range value, a half-populated pair,
            or a non-zero pair count other than two.
    """
    values = tuple(int(v) for v in chunk)
    if len(values) != 8:
        raise ValueError(f"expected 8 values, got {len(values)}")
    for v in values:
        _check_int4(v, mode)

    kept_pairs: list[int] = []
    kept_values: list[int] = []
    for idx, (first, second) in enumerate(zip(values[0::2], values[1::2])):
        first_zero, second_zero = first == 0, second == 0
        # Exactly one zero in a pair cannot be expressed by pair-wise metadata.
        if first_zero != second_zero:
            raise ValueError(
                f"pair {idx} must be both zero or both non-zero, got {(first, second)}"
            )
        if not first_zero:
            kept_pairs.append(idx)
            kept_values += [first, second]

    if len(kept_pairs) != 2:
        raise ValueError(f"expected exactly 2 non-zero pairs, got {kept_pairs}")

    lo_pair, hi_pair = kept_pairs
    return SparseChunk(
        values=tuple(kept_values),  # type: ignore[arg-type]
        metadata_nibble=lo_pair | (hi_pair << 2),
        nonzero_pairs=(lo_pair, hi_pair),
    )
def pack_metadata_word(nibbles: Iterable[int]) -> int:
    """Pack eight 4-bit metadata values into one 32-bit word.

    Nibble ``i`` occupies bits ``4*i .. 4*i+3``, so the first nibble lands in
    the lowest bits.

    Raises:
        ValueError: if there are not exactly eight nibbles, or one is outside 0..15.
    """
    values = [int(n) for n in nibbles]
    if len(values) != 8:
        raise ValueError(f"expected 8 metadata nibbles, got {len(values)}")
    word = 0
    for position, nibble in enumerate(values):
        if not 0 <= nibble <= 0xF:
            raise ValueError(f"metadata nibble out of range at {position}: {nibble}")
        word |= nibble << (4 * position)
    return word
def pack_sparse_row_4x8(row: Iterable[int], mode: Int4Mode = "s4") -> tuple[bytes, list[int]]:
    """Compress one row under pair-wise 4:8 sparsity.

    The row is split into eight-value chunks. Each chunk contributes two bytes
    (its four kept int4 values, two per byte, low nibble first) and one
    metadata nibble. Every eight consecutive nibbles form one metadata word.

    Returns:
        ``(packed_bytes, metadata_words)`` with ``len(row) // 4`` bytes and
        ``len(row) // 64`` words.

    Raises:
        ValueError: if the row length is not a multiple of 64, or any chunk is
            not encodable (see ``pack_sparse_chunk_4x8``).
    """
    values = tuple(int(v) for v in row)
    if len(values) % 64 != 0:
        raise ValueError("row length must be a multiple of 64 for metadata words")

    out = bytearray()
    words: list[int] = []
    nibble_buffer: list[int] = []
    for start in range(0, len(values), 8):
        chunk = pack_sparse_chunk_4x8(values[start : start + 8], mode)
        v0, v1, v2, v3 = chunk.values
        out += bytes((pack_int4_pair(v0, v1, mode), pack_int4_pair(v2, v3, mode)))
        nibble_buffer.append(chunk.metadata_nibble)
        # Flush one 32-bit word per eight chunks (64 input values).
        if len(nibble_buffer) == 8:
            words.append(pack_metadata_word(nibble_buffer))
            nibble_buffer = []
    return bytes(out), words
| def _pad_to_multiple(x: torch.Tensor, multiple: int, dim: int = -1) -> torch.Tensor: | |
| size = x.shape[dim] | |
| pad = (multiple - size % multiple) % multiple | |
| if pad == 0: | |
| return x | |
| shape = list(x.shape) | |
| shape[dim] = pad | |
| zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) | |
| return torch.cat([x, zeros], dim=dim) | |
def prune_tensor_pairwise_4x8(weight: torch.Tensor) -> torch.Tensor:
    """Keep the two strongest adjacent pairs in every eight-value row chunk.

    Each chunk of columns is grouped as ``[c0 c1 | c2 c3 | c4 c5 | c6 c7]``.
    A pair's strength is the sum of its absolute values. The two strongest
    pairs are kept and the other two are zeroed. Columns are temporarily
    zero-padded to a multiple of 8 and trimmed back afterwards, so the result
    has *weight*'s shape and is detached from autograd.

    Raises:
        ValueError: if *weight* is not 2-D.
    """
    if weight.dim() != 2:
        raise ValueError("expected a 2D weight tensor")
    in_cols = weight.shape[1]
    padded = _pad_to_multiple(weight.detach(), 8, dim=1).clone()
    n_rows, n_cols = padded.shape
    pairs = padded.view(n_rows, n_cols // 8, 4, 2)
    strength = pairs.abs().sum(dim=-1)
    winners = strength.topk(2, dim=-1).indices
    mask = torch.zeros_like(strength, dtype=torch.bool).scatter_(-1, winners, True)
    pruned = pairs * mask.unsqueeze(-1)
    return pruned.view(n_rows, n_cols)[:, :in_cols]
def quantize_s4_per_row(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Symmetric per-row s4 quantization that never turns a non-zero into zero.

    Each row ``r`` gets ``scale[r] = max(max_c |weight[r, c]|, 1e-6) / 7``.
    Values become ``round(weight / scale)`` clamped to [-8, 7]. A non-zero
    weight that would round to 0 is bumped to +1 or -1 (following its sign),
    so for finite inputs the zero/non-zero pattern of *weight* is kept. Zero
    weights stay zero.

    Returns:
        ``(q, scale)``: ``q`` is int16 with *weight*'s shape, and ``scale`` is
        a float32 vector with one entry per row.
    """
    row_max = weight.abs().amax(dim=1)
    scale = row_max.clamp(min=1e-6) / 7.0
    q = (weight / scale.unsqueeze(1)).round().clamp(-8, 7).to(torch.int16)
    is_nonzero = weight != 0
    unit = torch.ones_like(q)
    sign_unit = torch.where(weight >= 0, unit, -unit)
    q = torch.where(is_nonzero & (q == 0), sign_unit, q)
    q = torch.where(is_nonzero, q, torch.zeros_like(q))
    return q, scale.to(torch.float32)
def _make_pairwise_chunks_encodable(qweight: torch.Tensor) -> torch.Tensor:
    """Ensure every 8-value chunk has exactly two fully populated pairs.

    Sparse tensor-core metadata selects two adjacent pairs per eight values,
    and ``pack_sparse_chunk_4x8`` infers that selection from which pairs are
    non-zero. Padded tails and zero-initialized adapters can otherwise produce
    chunks with no legal encoding. For each chunk this function:

    1. keeps the first two pairs (in index order) that contain any non-zero
       value, and zeroes every other pair;
    2. if fewer than two pairs were kept, also selects the lowest-indexed
       remaining pairs until two are selected;
    3. replaces each zero inside a selected pair with a +/-1 sentinel. The
       sentinel takes the sign of its partner, or is +1 when the whole pair
       is zero.

    Sentinels written into padded columns are sliced away on dequantization.
    Sentinels inside real columns remain as +/-1 quantization steps.

    The work uses whole-tensor operations instead of a per-chunk Python loop
    with ``.item()`` calls. Zero tests compare values directly, so non-integer
    inputs are no longer truncated by ``int()``.

    Args:
        qweight: 2-D quantized weight matrix, e.g. the int16 output of
            ``quantize_s4_per_row``. It is not modified.

    Returns:
        A new CPU tensor with *qweight*'s dtype, with columns zero-padded to a
        multiple of 64.
    """
    padded = _pad_to_multiple(qweight, 64, dim=1).cpu()
    rows, cols = padded.shape
    pairs = padded.reshape(rows, cols // 8, 4, 2)

    # Step 1: keep at most the first two active pairs of each chunk.
    active = (pairs != 0).any(dim=-1)
    kept_active = active & (active.long().cumsum(dim=-1) <= 2)

    # Step 2: top up with the lowest-indexed unselected pairs. This only
    # triggers when a chunk had fewer than two active pairs.
    missing = 2 - kept_active.long().sum(dim=-1, keepdim=True)
    unselected = ~kept_active
    filler = unselected & (unselected.long().cumsum(dim=-1) <= missing)
    selected = kept_active | filler

    # Out-of-place, so qweight is never mutated even when no padding occurred.
    pairs = pairs.masked_fill(~selected.unsqueeze(-1), 0)

    # Step 3: patch zeros inside selected pairs, using the partner's sign
    # (both become +1 when the pair is entirely zero).
    first, second = pairs[..., 0], pairs[..., 1]
    one = torch.ones_like(first)
    first_fixed = torch.where(first == 0, torch.where(second >= 0, one, -one), first)
    second_fixed = torch.where(second == 0, torch.where(first >= 0, one, -one), second)
    first = torch.where(selected, first_fixed, first)
    second = torch.where(selected, second_fixed, second)

    return torch.stack((first, second), dim=-1).reshape(rows, cols)
def _pack_quantized_matrix(qweight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Pack a quantized s4 weight matrix into the flat pair-wise 4:8 layout.

    Rows are first made encodable (and zero-padded to a multiple of 64
    columns) by ``_make_pairwise_chunks_encodable``. They are then packed row
    by row with ``pack_sparse_row_4x8`` and concatenated in row-major order.

    Returns:
        ``(packed, metadata, padded_cols)``:
        - ``packed``: a flat uint8 tensor of packed nibble pairs.
        - ``metadata``: a flat int64 tensor of 32-bit metadata words.
        - ``padded_cols``: the padded column count.
    """
    encodable = _make_pairwise_chunks_encodable(qweight)
    byte_stream: list[int] = []
    word_stream: list[int] = []
    for row_values in encodable.tolist():
        row_bytes, row_words = pack_sparse_row_4x8(row_values, "s4")
        byte_stream += row_bytes
        word_stream += row_words
    padded_cols = int(encodable.shape[1])
    return (
        torch.tensor(byte_stream, dtype=torch.uint8),
        torch.tensor(word_stream, dtype=torch.int64),
        padded_cols,
    )
class INT4SparseLinear(nn.Module):
    """Linear layer whose weight is stored as pair-wise 4:8 sparse s4 data.

    Buffers:
        packed_weight: Flat uint8 tensor. Each output row contributes
            ``padded_in_features // 4`` bytes: two bytes per eight-value
            chunk, low nibble first.
        metadata: Flat int64 tensor of 32-bit words. Each output row
            contributes ``padded_in_features // 64`` words, and each word
            holds eight 4-bit chunk selectors (chunk ``i`` of a word in bits
            ``4*i .. 4*i+3``).
        scales: Per-output-row dequantization scales.
        bias: float32 bias, or ``None``.

    ``forward`` is a reference path. It dequantizes the full weight on every
    call and runs a dense ``linear``; it does not itself issue sparse
    tensor-core instructions.
    """

    format_name = FORMAT_NAME
    user_alias = USER_ALIAS

    def __init__(
        self,
        packed_weight: torch.Tensor,
        metadata: torch.Tensor,
        scales: torch.Tensor,
        out_features: int,
        in_features: int,
        padded_in_features: int,
        bias: torch.Tensor | None = None,
    ):
        super().__init__()
        self.out_features = out_features
        self.in_features = in_features
        self.padded_in_features = padded_in_features
        self.register_buffer("packed_weight", packed_weight.contiguous())
        self.register_buffer("metadata", metadata.contiguous())
        self.register_buffer("scales", scales.contiguous())
        # Mirrors nn.Linear's register_parameter("bias", None): ``bias`` is
        # always a declared buffer slot, and None buffers are skipped by
        # state_dict.
        bias_buffer = None if bias is None else bias.detach().clone().to(torch.float32)
        self.register_buffer("bias", bias_buffer)

    @classmethod
    def from_dense(cls, layer: nn.Linear) -> "INT4SparseLinear":
        """Prune, quantize and pack a dense ``nn.Linear`` into this format.

        The weight (as float32) is pruned to pair-wise 4:8 sparsity, quantized
        per row to s4, and packed. The bias, if present, is copied.
        """
        pruned = prune_tensor_pairwise_4x8(layer.weight.detach().to(torch.float32))
        qweight, scales = quantize_s4_per_row(pruned)
        packed, metadata, padded_in = _pack_quantized_matrix(qweight)
        bias = layer.bias.detach() if layer.bias is not None else None
        return cls(
            packed_weight=packed,
            metadata=metadata,
            scales=scales,
            out_features=layer.out_features,
            in_features=layer.in_features,
            padded_in_features=padded_in,
            bias=bias,
        )

    def dequantize_weight(self) -> torch.Tensor:
        """Rebuild the dense float32 weight of shape ``(out_features, in_features)``.

        Columns beyond ``in_features`` are dropped, which discards any
        sentinel values that packing wrote into the padding.
        """
        packed_per_row = self.padded_in_features // 4
        words_per_row = self.padded_in_features // 64
        chunks_per_row = self.padded_in_features // 8
        rows: list[list[int]] = []
        for row_idx in range(self.out_features):
            start = row_idx * packed_per_row
            bytes_row = self.packed_weight[start : start + packed_per_row].tolist()
            meta_start = row_idx * words_per_row
            meta_row = self.metadata[meta_start : meta_start + words_per_row].tolist()
            vals = [0] * self.padded_in_features
            for chunk_idx in range(chunks_per_row):
                word = int(meta_row[chunk_idx // 8])
                nibble = (word >> ((chunk_idx % 8) * 4)) & 0xF
                # Bits 0-1 / 2-3 name the pairs that the two stored pairs belong to.
                pair_indices = (nibble & 0x3, (nibble >> 2) & 0x3)
                byte_offset = chunk_idx * 2
                compressed = (
                    *unpack_int4_pair(int(bytes_row[byte_offset]), "s4"),
                    *unpack_int4_pair(int(bytes_row[byte_offset + 1]), "s4"),
                )
                chunk_base = chunk_idx * 8
                for slot, pair_index in enumerate(pair_indices):
                    dst = chunk_base + pair_index * 2
                    vals[dst : dst + 2] = compressed[slot * 2 : slot * 2 + 2]
            rows.append(vals[: self.in_features])
        q = torch.tensor(rows, dtype=torch.float32, device=self.scales.device)
        # Explicit reshape so out_features == 0 still yields (0, in_features)
        # rather than a 1-D empty tensor.
        q = q.reshape(self.out_features, self.in_features)
        return q * self.scales[:, None]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dense reference forward: ``linear(x, dequantized_weight, bias)``."""
        weight = self.dequantize_weight().to(device=x.device, dtype=x.dtype)
        bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
def export_sparse_int4_model(model: nn.Module, quality_gate_delta: float = 0.05) -> dict:
    """Convert every eligible ``nn.Linear`` in *model* to ``INT4SparseLinear``.

    A linear module is eligible when its weight is 2-D and it has at least 64
    input features; smaller layers are skipped. Modules are not replaced
    inside *model*; the converted layers are returned in a dict.
    ``quality_gate_delta`` is only recorded in the result; no quality check
    is performed here.

    Returns:
        A dict with the following keys:
        - ``format``: the format name.
        - ``user_alias``: the user-facing alias.
        - ``quality_gate_delta``: the recorded value.
        - ``layers``: maps each converted module's qualified name (from
          ``named_modules``) to its converted layer.
    """
    converted: dict[str, INT4SparseLinear] = {
        name: INT4SparseLinear.from_dense(module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and module.weight.dim() == 2
        and module.in_features >= 64
    }
    return {
        "format": FORMAT_NAME,
        "user_alias": USER_ALIAS,
        "quality_gate_delta": quality_gate_delta,
        "layers": converted,
    }
Xet Storage Details
- Size: 11.4 kB
- Xet hash: b82281de5e5659ee088dea2106cffe6ed142677999b975c650a538157f6e1411
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.