import torch
import torch.nn as nn
import torch.nn.functional as F


def activation_quant(x):
    """Fake-quantize activations to signed 8-bit with a per-row absmax scale.

    Each slice along the last dimension gets its own scale (per token, as in
    BitNet b1.58, presumably because the last dimension is the hidden size --
    this function does not check it). Values are snapped to the integer grid
    ``[-128, 127]`` and divided back, so the result keeps the shape and dtype
    of ``x`` but carries only 8-bit precision.

    Args:
        x: Input tensor, quantized independently along ``dim=-1``.

    Returns:
        ``round(x * s).clamp(-128, 127) / s`` with ``s = 127 / max(|x|)`` per row.
    """
    # The 1e-5 floor keeps all-zero rows from dividing by zero.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y


def weight_quant(w, threshold_ratio=0.5):
    """Quantize weights to the ternary set {-1, 0, 1} (BitNet b1.58 absmean).

    An entry becomes ``sign(w)`` when ``|w| > threshold_ratio * mean(|w|)`` and
    0 otherwise, so the zeros act like pruned / inactive connections of a
    static sparse matrix. For non-zero ``w`` and the default ratio of 0.5 this
    equals ``(w / mean(|w|)).round().clamp(-1, 1)`` (torch rounds half to even,
    so entries exactly on the threshold become 0 in both forms).

    The result is NOT rescaled by ``mean(|w|)``; callers that need the
    original magnitude must apply a scale themselves.

    Args:
        w: Weight tensor of any shape; one scale is computed over all of it.
        threshold_ratio: Fraction of ``mean(|w|)`` at or below which entries
            are zeroed. Defaults to 0.5.

    Returns:
        Tensor with the shape and dtype of ``w`` and values in {-1, 0, 1}.
        An all-zero ``w`` yields all zeros.
    """
    scale = w.abs().mean()
    threshold = threshold_ratio * scale
    # Mask and sign must come from the same tensor: taking the mask from |w|
    # but the sign from (w - mean(w)) can turn a positive, above-threshold
    # weight into -1.
    return torch.where(w.abs() > threshold, torch.sign(w), torch.zeros_like(w))


class BitLinear(nn.Linear):
    """``nn.Linear`` with ternary weights and 8-bit activations (fake-quantized).

    The full-precision ``weight`` is kept for training. The forward pass uses
    quantized copies through a straight-through estimator (STE), so gradients
    reach ``weight`` and ``input`` as if quantization were the identity.
    Unlike ``nn.Linear``, ``bias`` defaults to False.
    """

    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        # Per-output-channel multiplier for the ternary weights, shape
        # (out_features, 1) so it broadcasts over each weight row. It is a buffer
        # (saved in state_dict, not trained) initialised to ones, and nothing
        # shown here updates it.
        # NOTE(review): with scale == 1 the effective weights are exactly
        # {-1, 0, 1}, not rescaled by mean(|w|) as in the BitNet b1.58
        # reference -- confirm where this scale is meant to be set.
        self.register_buffer(
            "scale", torch.ones(out_features, 1, device=device, dtype=dtype)
        )

    def forward(self, input):
        """Compute ``F.linear(q8(input), ternary(weight) * scale, bias)``."""
        w = self.weight
        # STE: the forward value is the ternary tensor; backward is identity to w.
        w_quant = w + (weight_quant(w) - w).detach()
        # Same STE trick for the 8-bit activations.
        x_quant = input + (activation_quant(input) - input).detach()
        return F.linear(x_quant, w_quant * self.scale, self.bias)


class SparseMatrixConfig:
    """Plain settings holder for the TurboQuant / QVAC-style ternary setup.

    No code shown here reads these fields; presumably they are consumed
    elsewhere (TODO confirm).
    """

    def __init__(self):
        # Bits of information per ternary weight: log2(3) ~= 1.58.
        self.bitrate = 1.58
        # Fraction of mean(|w|) at or below which weights are zeroed; same value
        # as the default `threshold_ratio` of `weight_quant`.
        self.sparse_threshold = 0.5
        self.use_vulkan = True  # QVAC Fabric target