File size: 1,755 Bytes
c61a185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd899b3
 
 
 
c61a185
fd899b3
c61a185
 
fd899b3
c61a185
fd899b3
c61a185
fd899b3
c61a185
fd899b3
c61a185
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
import torch.nn.functional as F

def activation_quant(x):
    """Fake-quantize ``x`` to signed 8-bit levels, one scale per token.

    Every slice along the last dimension is rescaled so that its largest
    magnitude maps to 127. The values are then rounded, clamped to
    [-128, 127] and divided back by the same factor.

    Args:
        x: Floating-point tensor; the last dimension is treated as the
            per-token feature axis.

    Returns:
        Tensor of the same shape as ``x`` whose values lie on the 8-bit grid
        of their slice.
    """
    # Largest magnitude per slice. The 1e-5 floor keeps an all-zero slice
    # from producing a division by zero.
    row_peak = x.abs().max(dim=-1, keepdim=True).values
    row_peak = row_peak.clamp_(min=1e-5)
    gain = 127.0 / row_peak

    levels = (x * gain).round().clamp_(-128, 127)
    return levels / gain

def weight_quant(w, *, threshold_ratio=0.5):
    """
    Quantize weights to {-1, 0, 1} (BitNet 1.58b).
    This mimics a static sparse matrix where 0s represent pruned/inactive connections.

    The weights are first centered on their global mean. An entry becomes
    ``sign(w_centered)`` when its centered magnitude exceeds
    ``threshold_ratio * mean(|w_centered|)``, and 0 otherwise.

    Args:
        w: Floating-point weight tensor of any shape.
        threshold_ratio: Multiple of the mean centered magnitude below which
            an entry is zeroed. Defaults to 0.5.

    Returns:
        Tensor with the same shape and dtype as ``w`` containing only -1, 0
        and 1. The values are not rescaled; any magnitude scaling is left to
        the caller.
    """
    w_centered = w - w.mean()
    # The keep/zero decision and the sign must come from the same centered
    # tensor. Thresholding |w| instead would let an entry whose centered value
    # is ~0 get a +/-1 purely from noise, while zeroing entries that sit far
    # from the mean.
    scale = w_centered.abs().mean()
    threshold = threshold_ratio * scale
    keep = w_centered.abs() > threshold
    return torch.where(keep, torch.sign(w_centered), torch.zeros_like(w))

class BitLinear(nn.Linear):
    """Linear layer with ternary weights and 8-bit per-token activations.

    In ``forward`` the weights pass through ``weight_quant`` and the input
    through ``activation_quant``. Both use a straight-through estimator.
    The quantized weights are multiplied by the per-output-channel ``scale``
    buffer before ``F.linear``.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias)
        # Per-output-channel multiplier of shape (out_features, 1); it
        # broadcasts across the rows of the (out, in) weight matrix.
        # NOTE(review): initialised to ones and never updated in the visible
        # code, so by default forward() uses the unscaled ternary values.
        # Presumably it is set or loaded elsewhere — confirm.
        self.register_buffer("scale", torch.ones(out_features, 1))

    def forward(self, input):
        full_w = self.weight
        # Straight-through estimator. The forward value is the quantized
        # tensor (up to float rounding of w + (q - w)). Because the
        # difference is detached, the gradient w.r.t. the full-precision
        # weights is the identity.
        ste_w = full_w + (weight_quant(full_w) - full_w).detach()

        # Same trick for the activations: quantized values forward,
        # identity gradient backward.
        ste_x = input + (activation_quant(input) - input).detach()

        scaled_w = ste_w * self.scale
        return F.linear(ste_x, scaled_w, self.bias)

class SparseMatrixConfig:
    """Config for TurboQuant/QVAC style optimizations.

    Defaults: ``bitrate=1.58``, ``sparse_threshold=0.5``, ``use_vulkan=True``.
    Any setting can be overridden by keyword, e.g.
    ``SparseMatrixConfig(sparse_threshold=0.3)``.

    Attributes:
        bitrate: Nominal bits per weight; 1.58 ~ log2(3), i.e. ternary weights.
        sparse_threshold: Sparsity threshold factor. NOTE(review): nothing in
            the visible code reads it. Presumably it is meant to agree with
            the 0.5 threshold factor used by ``weight_quant`` — confirm.
        use_vulkan: Vulkan backend flag (QVAC Fabric target). NOTE(review):
            not read anywhere in the visible code.
    """

    def __init__(self, *, bitrate=1.58, sparse_threshold=0.5, use_vulkan=True):
        self.bitrate = bitrate
        self.sparse_threshold = sparse_threshold
        self.use_vulkan = use_vulkan  # QVAC Fabric target

    def __repr__(self):
        return (
            f"{type(self).__name__}(bitrate={self.bitrate!r}, "
            f"sparse_threshold={self.sparse_threshold!r}, "
            f"use_vulkan={self.use_vulkan!r})"
        )