bbkdevops's picture
download
raw
11.4 kB
"""INT4 pair-wise 4:8 sparse export helpers for Ampere Sparse Tensor Cores.

The CUDA path in the Codex archive uses PTX `mma.sp` with `.s4/.u4` operands.
For INT4 operands, PTX requires pair-wise structured 4:8 sparsity on matrix A:
in every chunk of eight values, at most two of the four adjacent-value pairs
may be non-zero. We expose that exact format as `int4_4x8_pairwise_sparse`
and keep `int4_2:4sp` (2:4 at pair granularity) as the short user-facing alias.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Literal
import torch
import torch.nn as nn
# Int4 interpretation: "s4" = signed range [-8, 7], "u4" = unsigned range [0, 15]
# (ranges enforced by _check_int4).
Int4Mode = Literal["s4", "u4"]
# Canonical export-format name; stored under "format" by export_sparse_int4_model
# and exposed as INT4SparseLinear.format_name.
FORMAT_NAME = "int4_4x8_pairwise_sparse"
# Short user-facing alias; stored under "user_alias" by export_sparse_int4_model
# and exposed as INT4SparseLinear.user_alias.
USER_ALIAS = "int4_2:4sp"
@dataclass(frozen=True, eq=True)
class SparseChunk:
    """Compressed representation of one 8-value chunk in the pair-wise 4:8 format.

    Instances are immutable; ``pack_sparse_chunk_4x8`` builds them.
    """

    # The four stored int4 values: both values of each kept pair, kept pairs in
    # ascending pair-index order.
    values: tuple[int, int, int, int]
    # 4-bit metadata: index of the first kept pair in bits 0-1, the second in bits 2-3.
    metadata_nibble: int
    # Indices (0-3) of the two kept pairs within the chunk.
    nonzero_pairs: tuple[int, int]
def _check_int4(x: int, mode: Int4Mode) -> None:
    """Raise ``ValueError`` unless *x* fits the int4 range selected by *mode*.

    ``"s4"`` accepts [-8, 7] and ``"u4"`` accepts [0, 15]. Any other *mode*
    is rejected explicitly instead of silently falling back to the unsigned
    range (which would let a typo such as ``"S4"`` accept 8..15).

    Raises:
        ValueError: if *mode* is unknown or *x* is out of range.
    """
    if mode == "s4":
        lo, hi = -8, 7
    elif mode == "u4":
        lo, hi = 0, 15
    else:
        raise ValueError(f"unsupported int4 mode {mode!r}; expected 's4' or 'u4'")
    if x < lo or x > hi:
        raise ValueError(f"{x} is outside {mode} range [{lo}, {hi}]")
def _to_nibble(x: int, mode: Int4Mode) -> int:
    """Range-check *x* for *mode* and return its 4-bit two's-complement pattern.

    Negative s4 values map to 8..15 (for example ``-1 -> 0xF``); values that are
    already non-negative come back unchanged.
    """
    _check_int4(x, mode)
    low_bits = x & 0b1111
    return low_bits
def pack_int4_pair(low: int, high: int, mode: Int4Mode = "s4") -> int:
    """Pack two int4 values into a single byte.

    *low* occupies bits 0-3 and *high* bits 4-7. Both are range-checked for
    *mode* (``low`` first); an out-of-range value raises ``ValueError``.
    """
    low_nibble = _to_nibble(low, mode)
    high_nibble = _to_nibble(high, mode)
    return (high_nibble << 4) | low_nibble
def unpack_int4_pair(byte: int, mode: Int4Mode = "s4") -> tuple[int, int]:
    """Split *byte* into its ``(low, high)`` int4 values.

    With ``mode="s4"`` nibbles 8..15 are read as two's complement (-8..-1);
    for any other mode they are returned as-is (0..15). Bits above bit 7 are
    ignored.
    """
    low, high = byte & 0xF, (byte >> 4) & 0xF
    if mode == "s4":
        low = low - 16 if low >= 8 else low
        high = high - 16 if high >= 8 else high
    return low, high
def pack_sparse_chunk_4x8(chunk: Iterable[int], mode: Int4Mode = "s4") -> SparseChunk:
    """Compress eight int4 values into a :class:`SparseChunk`.

    The chunk is viewed as four adjacent pairs. Every pair must be either fully
    zero or fully non-zero, and exactly two pairs must be non-zero; those two
    pairs are stored (ascending pair order) and their indices encoded in the
    metadata nibble (first index in bits 0-1, second in bits 2-3).

    Raises:
        ValueError: wrong chunk length, a value outside the *mode* range, a
            half-zero pair, or a non-zero pair count other than two.
    """
    values = [int(v) for v in chunk]
    if len(values) != 8:
        raise ValueError(f"expected 8 values, got {len(values)}")
    for v in values:
        _check_int4(v, mode)

    kept_pairs: list[int] = []
    kept_values: list[int] = []
    for idx, (first, second) in enumerate(zip(values[0::2], values[1::2])):
        first_zero = first == 0
        second_zero = second == 0
        # Exactly one zero in a pair is not representable in this format.
        if first_zero != second_zero:
            raise ValueError(
                f"pair {idx} must be both zero or both non-zero, got {(first, second)}"
            )
        if first_zero:
            continue
        kept_pairs.append(idx)
        kept_values += (first, second)

    if len(kept_pairs) != 2:
        raise ValueError(f"expected exactly 2 non-zero pairs, got {kept_pairs}")
    lo_pair, hi_pair = kept_pairs
    return SparseChunk(
        values=tuple(kept_values),  # type: ignore[arg-type]
        metadata_nibble=lo_pair | (hi_pair << 2),
        nonzero_pairs=(lo_pair, hi_pair),
    )
def pack_metadata_word(nibbles: Iterable[int]) -> int:
    """Pack eight 4-bit metadata nibbles into one 32-bit word.

    Nibble ``i`` lands in bits ``4*i .. 4*i+3`` (first nibble is least
    significant).

    Raises:
        ValueError: if there are not exactly eight nibbles, or one is outside
            ``0..15`` (the first offending index is reported).
    """
    values = [int(n) for n in nibbles]
    if len(values) != 8:
        raise ValueError(f"expected 8 metadata nibbles, got {len(values)}")
    for index, nibble in enumerate(values):
        if not 0 <= nibble <= 0xF:
            raise ValueError(f"metadata nibble out of range at {index}: {nibble}")
    # Nibbles occupy disjoint bit ranges, so summing equals OR-ing them in.
    return sum(nibble << (4 * index) for index, nibble in enumerate(values))
def pack_sparse_row_4x8(row: Iterable[int], mode: Int4Mode = "s4") -> tuple[bytes, list[int]]:
    """Pack one row into compressed value bytes plus metadata words.

    The row is split into 8-value chunks. Each chunk contributes two bytes
    (its four stored int4 values, low nibble first) and one metadata nibble;
    every eight consecutive nibbles form one metadata word.

    Raises:
        ValueError: if the row length is not a multiple of 64, or any chunk
            is not encodable (see ``pack_sparse_chunk_4x8``).
    """
    values = tuple(int(v) for v in row)
    if len(values) % 64:
        raise ValueError("row length must be a multiple of 64 for metadata words")

    out = bytearray()
    words: list[int] = []
    group: list[int] = []
    for start in range(0, len(values), 8):
        chunk = pack_sparse_chunk_4x8(values[start : start + 8], mode)
        v0, v1, v2, v3 = chunk.values
        out.append(pack_int4_pair(v0, v1, mode))
        out.append(pack_int4_pair(v2, v3, mode))
        group.append(chunk.metadata_nibble)
        if len(group) == 8:
            words.append(pack_metadata_word(group))
            group = []
    return bytes(out), words
def _pad_to_multiple(x: torch.Tensor, multiple: int, dim: int = -1) -> torch.Tensor:
    """Zero-pad *x* along *dim* up to the next multiple of *multiple*.

    Returns *x* itself (no copy) when its size along *dim* is already a
    multiple; otherwise a new tensor with the same dtype and device.
    """
    remainder = x.shape[dim] % multiple
    if not remainder:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = multiple - remainder
    filler = x.new_zeros(pad_shape)
    return torch.cat((x, filler), dim=dim)
def prune_tensor_pairwise_4x8(weight: torch.Tensor) -> torch.Tensor:
    """Keep the two strongest adjacent pairs in every eight-value row chunk.

    Pair strength is ``|a| + |b|``. Columns are zero-padded to a multiple of
    8 for the grouping and the padding is sliced off again, so the result has
    the same shape as *weight* with the non-kept pairs zeroed.

    Raises:
        ValueError: if *weight* is not 2D.
    """
    if weight.dim() != 2:
        raise ValueError("expected a 2D weight tensor")
    in_cols = weight.shape[1]
    padded = _pad_to_multiple(weight.detach(), 8, dim=1).clone()
    n_rows, n_cols = padded.shape
    grouped = padded.view(n_rows, n_cols // 8, 4, 2)  # (row, chunk, pair, value)
    strength = grouped.abs().sum(dim=-1)
    winners = strength.topk(2, dim=-1).indices
    mask = torch.zeros_like(strength, dtype=torch.bool).scatter_(-1, winners, True)
    pruned = grouped * mask.unsqueeze(-1)
    return pruned.view(n_rows, n_cols)[:, :in_cols]
def quantize_s4_per_row(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Symmetric per-row s4 quantization that preserves the sparsity pattern.

    Each row is scaled by ``max|w| / 7`` (the max clamped to at least 1e-6),
    rounded and clamped to [-8, 7]. Non-zero weights that would round to 0 are
    forced to ``+1``/``-1`` by sign, so zero/non-zero positions are unchanged.

    Returns:
        ``(q, scale)``: ``q`` as int16 with the shape of *weight*, ``scale`` as
        a float32 vector with one entry per row.
    """
    row_max = weight.abs().amax(dim=1)
    scale = row_max.clamp(min=1e-6) / 7.0
    q = (weight / scale.unsqueeze(1)).round().clamp(-8, 7).to(torch.int16)

    is_nonzero = weight != 0
    unit = torch.ones_like(q)
    sign = torch.where(weight >= 0, unit, -unit)
    # Keep tiny non-zero weights non-zero so the pruned pattern survives.
    q = torch.where(is_nonzero & (q == 0), sign, q)
    q = q.masked_fill(~is_nonzero, 0)
    return q, scale.to(torch.float32)
def _make_pairwise_chunks_encodable(qweight: torch.Tensor) -> torch.Tensor:
    """Ensure every 8-value chunk has exactly two fully populated pairs.

    Sparse tensor-core metadata selects two adjacent pairs per eight values,
    and ``pack_sparse_chunk_4x8`` only accepts pairs that are both zero or both
    non-zero. Padded tails and zero-initialized adapters can otherwise produce
    all-zero chunks that have no legal metadata. Tiny sentinel values in the
    selected pairs keep the artifact encodable; padded sentinels are sliced
    away on dequant.

    Rules, applied per chunk (vectorized over the whole matrix):

    * keep the first two pairs (by index) that contain a non-zero value; if
      fewer than two exist, add the lowest-index remaining pairs;
    * zero every pair that is not kept;
    * inside a kept pair, an all-zero pair becomes ``(1, 1)`` and a single
      zero becomes the sign (``+1``/``-1``) of its partner.

    NOTE(review): sentinels written inside the first ``in_features`` columns
    are not sliced away; after dequantization they contribute ``+/-1 * scale``
    where the quantized weight was zero.

    Args:
        qweight: 2D integer-valued tensor (``quantize_s4_per_row`` produces
            int16 here).

    Returns:
        A new CPU tensor of the same dtype, columns zero-padded to a multiple
        of 64, in which every chunk is encodable.
    """
    padded = _pad_to_multiple(qweight, 64, dim=1).cpu()
    if padded.numel() == 0:
        return padded.clone()
    rows, cols = padded.shape
    pairs = padded.reshape(rows, cols // 8, 4, 2)  # (row, chunk, pair, value)

    # A pair is "active" when either of its values is non-zero.
    active = (pairs != 0).any(dim=-1)
    # Active pairs rank before inactive ones, ties broken by pair index; ranks
    # are distinct (0..7) within a chunk, so the two lowest are well defined.
    rank = torch.arange(4) + 4 * (~active).to(torch.int64)
    chosen = rank.topk(2, dim=-1, largest=False).indices
    keep = torch.zeros_like(active).scatter_(-1, chosen, True)

    first, second = pairs[..., 0], pairs[..., 1]
    one = torch.ones_like(first)
    first_zero = first == 0
    second_zero = second == 0
    both_zero = first_zero & second_zero
    # Fill a lone zero with its partner's sign; an all-zero pair becomes (1, 1).
    first_filled = torch.where(
        both_zero, one, torch.where(first_zero, torch.where(second >= 0, one, -one), first)
    )
    second_filled = torch.where(
        both_zero, one, torch.where(second_zero, torch.where(first >= 0, one, -one), second)
    )
    filled = torch.stack((first_filled, second_filled), dim=-1)

    encodable = torch.where(keep.unsqueeze(-1), filled, torch.zeros_like(filled))
    return encodable.reshape(rows, cols)
def _pack_quantized_matrix(qweight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Pack a quantized s4 matrix into flat value-byte and metadata tensors.

    Rows are made encodable (and padded to a multiple of 64 columns), then
    packed one after another.

    Returns:
        ``(packed, metadata, padded_in)``: ``packed`` is a flat uint8 tensor
        (``padded_in // 4`` bytes per row), ``metadata`` a flat int64 tensor
        (``padded_in // 64`` words per row), and ``padded_in`` the padded
        column count.
    """
    encodable = _make_pairwise_chunks_encodable(qweight)
    all_bytes: list[int] = []
    all_words: list[int] = []
    for values in encodable.tolist():
        row_bytes, row_words = pack_sparse_row_4x8(values, "s4")
        all_bytes += row_bytes  # iterating bytes yields ints
        all_words += row_words
    packed = torch.tensor(all_bytes, dtype=torch.uint8)
    metadata = torch.tensor(all_words, dtype=torch.int64)
    return packed, metadata, int(encodable.shape[1])
class INT4SparseLinear(nn.Module):
    """Linear layer whose weight is stored in the pair-wise 4:8 sparse INT4 format.

    Buffers (layout as produced by ``from_dense``):
        packed_weight: flat byte tensor, ``padded_in_features // 4`` bytes per
            output row; each 8-value chunk uses two bytes (four stored s4
            values, low nibble first).
        metadata: flat integer tensor, ``padded_in_features // 64`` words per
            output row; each word holds eight 4-bit chunk nibbles (first kept
            pair index in bits 0-1, second in bits 2-3).
        scales: one dequantization scale per output row.
        bias: optional float32 bias (``None`` when absent).

    ``forward`` dequantizes the weight on every call and runs a dense
    ``torch.nn.functional.linear``.
    """

    format_name = FORMAT_NAME
    user_alias = USER_ALIAS

    def __init__(
        self,
        packed_weight: torch.Tensor,
        metadata: torch.Tensor,
        scales: torch.Tensor,
        out_features: int,
        in_features: int,
        padded_in_features: int,
        bias: torch.Tensor | None = None,
    ):
        super().__init__()
        self.out_features = out_features
        self.in_features = in_features
        # Packed column count; from_dense pads to a multiple of 64.
        self.padded_in_features = padded_in_features
        self.register_buffer("packed_weight", packed_weight.contiguous())
        self.register_buffer("metadata", metadata.contiguous())
        self.register_buffer("scales", scales.contiguous())
        if bias is not None:
            # Detached float32 copy, independent of the source tensor.
            self.register_buffer("bias", bias.detach().clone().to(torch.float32))
        else:
            self.bias = None  # type: ignore[assignment]

    @classmethod
    def from_dense(cls, layer: nn.Linear) -> "INT4SparseLinear":
        """Prune, quantize and pack a dense ``nn.Linear`` into this format.

        The weight is pruned to the two strongest pairs per 8-value chunk,
        quantized to s4 with one scale per output row, made encodable and
        packed; the bias (if any) is copied.
        """
        pruned = prune_tensor_pairwise_4x8(layer.weight.detach().to(torch.float32))
        qweight, scales = quantize_s4_per_row(pruned)
        packed, metadata, padded_in = _pack_quantized_matrix(qweight)
        bias = layer.bias.detach() if layer.bias is not None else None
        return cls(
            packed_weight=packed,
            metadata=metadata,
            scales=scales,
            out_features=layer.out_features,
            in_features=layer.in_features,
            padded_in_features=padded_in,
            bias=bias,
        )

    def dequantize_weight(self) -> torch.Tensor:
        """Rebuild the dense weight from the packed buffers, fully vectorized.

        Each chunk's two stored pairs are placed at the pair slots named by its
        metadata nibble; if both indices name the same slot, the second stored
        pair wins. Padding columns are dropped.

        Returns:
            ``(out_features, in_features)`` float tensor on ``scales.device``:
            decoded s4 values multiplied by the per-row scales.
        """
        rows = self.out_features
        padded = self.padded_in_features
        chunks = padded // 8
        bytes_per_row = padded // 4
        words_per_row = padded // 64
        device = self.packed_weight.device

        # Two nibbles per byte, low nibble first -> four stored values per chunk.
        packed = self.packed_weight[: rows * bytes_per_row].to(torch.int16)
        packed = packed.reshape(rows, bytes_per_row)
        nibbles = torch.stack((packed & 0xF, (packed >> 4) & 0xF), dim=-1)
        # s4 two's complement: nibbles 8..15 encode -8..-1.
        stored = torch.where(nibbles >= 8, nibbles - 16, nibbles)
        stored = stored.reshape(rows, chunks, 2, 2)  # (row, chunk, stored pair, value)

        # Eight metadata nibbles per word, chunk 0 in the lowest nibble.
        words = self.metadata[: rows * words_per_row].to(device=device, dtype=torch.int64)
        words = words.reshape(rows, words_per_row)
        shifts = torch.arange(0, 32, 4, device=device)
        meta = ((words.unsqueeze(-1) >> shifts) & 0xF).reshape(rows, chunks)
        first_slot = (meta & 0x3).unsqueeze(-1)
        second_slot = ((meta >> 2) & 0x3).unsqueeze(-1)

        slot_ids = torch.arange(4, device=device)
        first_mask = (slot_ids == first_slot).unsqueeze(-1)  # (row, chunk, 4, 1)
        second_mask = (slot_ids == second_slot).unsqueeze(-1)
        dense = torch.zeros(rows, chunks, 4, 2, dtype=stored.dtype, device=device)
        dense = torch.where(first_mask, stored[:, :, 0:1, :], dense)
        # Applied second so it overrides the first pair on a shared slot.
        dense = torch.where(second_mask, stored[:, :, 1:2, :], dense)

        q = dense.reshape(rows, padded)[:, : self.in_features]
        q = q.to(device=self.scales.device, dtype=torch.float32)
        return q * self.scales[:, None]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x @ W.T + bias`` with the dequantized weight in ``x``'s dtype/device."""
        weight = self.dequantize_weight().to(device=x.device, dtype=x.dtype)
        bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
def export_sparse_int4_model(model: nn.Module, quality_gate_delta: float = 0.05) -> dict:
    """Convert every eligible ``nn.Linear`` in *model* to :class:`INT4SparseLinear`.

    A module is converted when it is an ``nn.Linear`` with a 2D weight and
    ``in_features >= 64``; other modules are skipped. *model* itself is not
    modified. ``quality_gate_delta`` is only recorded in the returned manifest.

    Returns:
        A dict with keys ``"format"``, ``"user_alias"``,
        ``"quality_gate_delta"`` and ``"layers"`` (module name -> converted layer,
        in ``named_modules`` order).
    """
    layers: dict[str, INT4SparseLinear] = {
        name: INT4SparseLinear.from_dense(module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and module.weight.dim() == 2
        and module.in_features >= 64
    }
    return {
        "format": FORMAT_NAME,
        "user_alias": USER_ALIAS,
        "quality_gate_delta": quality_gate_delta,
        "layers": layers,
    }

Xet Storage Details

Size:
11.4 kB
·
Xet hash:
b82281de5e5659ee088dea2106cffe6ed142677999b975c650a538157f6e1411

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.