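# Web-page residue captured with this file (presumably the Hugging Face Hub
# commit header); it is not part of the module:
#   saur7764's picture
#   Upload folder using huggingface_hub
#   fd899b3 verified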
import torch
import torch.nn as nn
import torch.nn.functional as F
def activation_quant(x):
"""Per-token quantization to 8-bit (standard for BitNet 1.58b)"""
scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
y = (x * scale).round().clamp_(-128, 127) / scale
return y
def weight_quant(w, sparse_threshold=0.5):
    """Quantize weights to the ternary set {-1, 0, 1} (BitNet 1.58b style).

    An entry is kept (as -1 or +1) when its absolute value is strictly above
    ``sparse_threshold * mean(|w|)``; every other entry becomes 0, which
    mimics a static sparse matrix where zeros are pruned/inactive
    connections. The sign of a kept entry is taken from the mean-centred
    weight ``w - mean(w)``.

    NOTE(review): the keep/drop mask is computed on the raw magnitude
    ``|w|`` while the sign comes from the centred weight, so a kept entry
    lying between 0 and ``mean(w)`` gets the opposite sign of ``w``. This
    is preserved as-is; confirm it is intended before changing it.

    Args:
        w: Weight tensor. Mean and mean-absolute-value are taken over all
            elements.
        sparse_threshold: Fraction of the mean absolute weight below which
            (inclusive) an entry is zeroed. Defaults to 0.5, the value that
            was previously hard-coded.

    Returns:
        A tensor shaped like ``w`` whose entries are -1, 0 or 1. The
        magnitude scale is not applied to the result.
    """
    magnitude = w.abs()
    cutoff = sparse_threshold * magnitude.mean()
    centred_sign = torch.sign(w - w.mean())
    return torch.where(magnitude > cutoff, centred_sign, torch.zeros_like(w))
class BitLinear(nn.Linear):
    """Linear layer that runs on ternary weights and 8-bit activations.

    The full-precision ``weight`` of ``nn.Linear`` remains the trainable
    parameter. On every forward pass it is quantized with ``weight_quant``
    and the input with ``activation_quant``; both use a straight-through
    estimator so gradients flow as if no quantization had happened.

    A ``scale`` buffer of shape ``(out_features, 1)`` multiplies the
    quantized weight row-wise (one factor per output channel). It is
    initialised to ones and is not updated anywhere in this class.
    """

    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
        """Build the layer.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Whether to learn an additive bias (off by default).
            device: Optional device for the parameters and the ``scale``
                buffer, forwarded to ``nn.Linear``.
            dtype: Optional dtype for the parameters and the ``scale``
                buffer, forwarded to ``nn.Linear``.
        """
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        # Keep the buffer on the same device/dtype as the weight so the
        # product in ``forward`` needs no implicit promotion or transfer.
        self.register_buffer(
            "scale", torch.ones(out_features, 1, device=device, dtype=dtype)
        )

    def forward(self, input):
        """Apply the quantized linear transform to ``input``.

        Args:
            input: Activation tensor accepted by ``F.linear`` for this
                layer's ``in_features``.

        Returns:
            ``F.linear`` of the quantized input with the quantized,
            row-scaled weight and the layer's bias.
        """
        weight = self.weight
        # Straight-through estimator: the forward value is the quantized
        # tensor, but the detached residual makes the gradient that of the
        # unquantized tensor.
        weight_q = weight + (weight_quant(weight) - weight).detach()
        input_q = input + (activation_quant(input) - input).detach()
        # Per-output-channel scaling of the ternary weight.
        return F.linear(input_q, weight_q * self.scale, self.bias)
class SparseMatrixConfig:
    """Config for TurboQuant/QVAC style optimizations.

    Calling ``SparseMatrixConfig()`` with no arguments yields the same
    values that used to be hard-coded; each one can now be overridden by
    keyword at construction time.

    Attributes:
        bitrate: Nominal bits per weight (default 1.58).
        sparse_threshold: Sparsity threshold (default 0.5).
        use_vulkan: Whether Vulkan is targeted (default True); the original
            comment names "QVAC Fabric" as the target.
    """

    def __init__(self, *, bitrate=1.58, sparse_threshold=0.5, use_vulkan=True):
        self.bitrate = bitrate
        self.sparse_threshold = sparse_threshold
        self.use_vulkan = use_vulkan  # QVAC Fabric target

    def __repr__(self):
        return (
            f"{type(self).__name__}(bitrate={self.bitrate!r}, "
            f"sparse_threshold={self.sparse_threshold!r}, "
            f"use_vulkan={self.use_vulkan!r})"
        )