Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def activation_quant(x):
    """Fake-quantize ``x`` to signed 8-bit levels with one scale per last-dim slice.

    Per-token quantization to 8-bit (standard for BitNet 1.58b). Each slice
    along the last dimension is scaled so that its largest magnitude maps to
    127, rounded to the nearest integer level, clamped to [-128, 127], and
    then divided back by the same scale. The output has the shape of ``x``.
    """
    # Floor the per-row peak so an all-zero row does not divide by zero.
    peak = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    step = 127.0 / peak
    levels = torch.clamp(torch.round(x * step), -128, 127)
    return levels / step
def weight_quant(w, sparse_threshold=0.5):
    """
    Quantize weights to {-1, 0, 1} (BitNet 1.58b).

    A weight becomes ``sign(w)`` when its magnitude is strictly greater than
    ``sparse_threshold * mean(|w|)`` and ``0`` otherwise. The result mimics a
    static sparse matrix where 0s represent pruned/inactive connections. With
    the default threshold of 0.5 this mirrors absmean ternary rounding,
    ``clip(round(w / mean|w|), -1, 1)``, up to floating-point rounding exactly
    at the boundary.

    Args:
        w: Weight tensor. ``mean(|w|)`` is taken over ALL elements, so the
            threshold is per-tensor, not per-row.
        sparse_threshold: Fraction of ``mean(|w|)`` at or below which a weight
            is zeroed. Defaults to 0.5, the previously hard-coded value.

    Returns:
        Tensor with the same shape and dtype as ``w`` containing only -1, 0
        and 1. The values are NOT rescaled by ``mean(|w|)``.
    """
    scale = w.abs().mean()
    threshold = sparse_threshold * scale
    # Threshold and sign are both taken on the *uncentered* weights. Taking
    # the sign of ``w - w.mean()`` while thresholding ``|w|`` could flip large
    # positive weights to -1 (e.g. [2, 3, 3, 3] -> [-1, 1, 1, 1]) and zeroed a
    # constant matrix entirely, because sign(0) == 0.
    # An all-zero ``w`` gives threshold 0, and ``0 > 0`` is False, so the
    # result is all zeros (no division, hence no NaN).
    return torch.where(w.abs() > threshold, torch.sign(w), torch.zeros_like(w))
class BitLinear(nn.Linear):
    """``nn.Linear`` that quantizes its weights and inputs on the fly in ``forward``.

    Weights go through ``weight_quant`` and inputs through
    ``activation_quant``. Both use a straight-through estimator (STE): the
    forward value is the quantized tensor, while gradients reach the
    full-precision ``weight`` / ``input`` as if quantization were the identity.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If True, adds a learnable bias. Defaults to False, unlike
            ``nn.Linear``.
        device: Forwarded to ``nn.Linear``. Also used for the ``scale`` buffer.
        dtype: Forwarded to ``nn.Linear``. Also used for the ``scale`` buffer.
    """

    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        # Per-output-channel multiplier applied to the quantized weights.
        # It is registered as a buffer, so it is saved to and loaded from
        # state_dict. Nothing in this class updates it, so it stays at 1
        # unless a checkpoint supplies values.
        # NOTE(review): with scale == 1 the output magnitude is that of the
        # raw quantized weights rather than of ``self.weight`` -- confirm this
        # is intended.
        self.register_buffer(
            "scale", torch.ones(out_features, 1, device=device, dtype=dtype)
        )

    def forward(self, input):
        """Return ``F.linear(x_q, w_q * scale, bias)`` using STE-quantized ``x_q`` and ``w_q``."""
        w = self.weight
        # STE: the forward value is weight_quant(w); the backward pass treats it as identity.
        w_quant = w + (weight_quant(w) - w).detach()
        # Same STE trick for the activations. activation_quant is described
        # in its docstring as 8-bit per-token quantization.
        x_quant = input + (activation_quant(input) - input).detach()
        # ``scale`` is (out_features, 1) and broadcasts across each weight row.
        return F.linear(x_quant, w_quant * self.scale, self.bias)
class SparseMatrixConfig:
    """Config for TurboQuant/QVAC style optimizations.

    Every field defaults to its previously hard-coded value, so
    ``SparseMatrixConfig()`` behaves exactly as before. Pass keywords to
    override individual fields.

    Args:
        bitrate: Nominal bits per weight (1.58 ~= log2(3), i.e. three levels).
        sparse_threshold: Sparsity threshold. It presumably mirrors the
            ``0.5 * mean(|w|)`` cut-off used for ternary weight quantization
            -- TODO confirm against the consumer of this config.
        use_vulkan: Whether to target the QVAC Fabric (Vulkan) backend.
    """

    def __init__(self, *, bitrate=1.58, sparse_threshold=0.5, use_vulkan=True):
        self.bitrate = bitrate
        self.sparse_threshold = sparse_threshold
        self.use_vulkan = use_vulkan  # QVAC Fabric target

    def __repr__(self):
        return (
            f"{type(self).__name__}(bitrate={self.bitrate!r}, "
            f"sparse_threshold={self.sparse_threshold!r}, "
            f"use_vulkan={self.use_vulkan!r})"
        )