File size: 3,680 Bytes
d0f40b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
H-BitLinear layer for BitSkip v2 (4-bit activations WITH Hadamard transform)
OPTIMIZED: Fast Hadamard transform implementation
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def hadamard_transform(x):
    """
    Orthonormal Fast Walsh-Hadamard Transform (FWHT) over the last dimension.

    Computes ``x @ H_n / sqrt(n)`` where ``H_n`` is the Sylvester-ordered
    Hadamard matrix of size ``n = x.shape[-1]``. The transform runs in
    log2(n) vectorized butterfly passes, i.e. O(n log n) work per row.
    Because of the ``1 / sqrt(n)`` normalization the transform is its own
    inverse: ``hadamard_transform(hadamard_transform(x))`` recovers ``x``
    up to floating-point error.

    Args:
        x: Tensor whose last dimension is a positive power of two. All
            leading dimensions are treated as batch dimensions.

    Returns:
        A new tensor with the same shape as ``x``. The input is never
        modified, so the function is safe to call on tensors that are
        part of an autograd graph (including leaf tensors).

    Raises:
        AssertionError: If the last dimension is not a positive power of 2.
    """
    orig_shape = x.shape
    n = orig_shape[-1]

    # `n & (n - 1) == 0` is also true for n == 0, so reject that explicitly.
    assert n > 0 and n & (n - 1) == 0, f"Dimension must be power of 2, got {n}"

    # Flatten all leading dimensions into one batch axis.
    x = x.reshape(-1, n)

    h = 1
    while h < n:
        # View every row as consecutive (2, h) blocks: the two halves of
        # each block are the operands of one butterfly.
        blocks = x.reshape(-1, n // (2 * h), 2, h)
        first = blocks[:, :, 0, :]
        second = blocks[:, :, 1, :]

        # Butterfly: (a, b) -> (a + b, a - b).
        # The result is built out of place on purpose. `first` and `second`
        # are views into `x`; writing `a + b` back into `x` before computing
        # `a - b` would make the difference read the already-updated value,
        # and would also overwrite the caller's tensor.
        x = torch.stack((first + second, first - second), dim=2).reshape(-1, n)
        h *= 2

    # Scale so the transform is orthonormal (and therefore self-inverse).
    x = x / math.sqrt(n)

    return x.reshape(orig_shape)


class HBitLinear(nn.Module):
    """
    Linear layer with Hadamard-rotated 4-bit activations and ternary weights.

    Steps performed by ``forward``:
    1. LayerNorm over the input features
    2. Hadamard transform of the normalized input
    3. Per-row absmax fake-quantization of the activations (4-bit levels)
    4. Ternary fake-quantization of the weight matrix
    5. Linear projection with the quantized operands (plus optional bias)
    6. Hadamard transform over the output features

    In training mode both quantizers use a straight-through estimator, so
    the forward pass sees quantized values while gradients flow as if the
    quantizers were the identity.

    Args:
        in_features: Size of the input feature axis; must be a power of 2.
        out_features: Size of the output feature axis; must be a power of 2.
        bias: If True, a zero-initialized bias of size ``out_features`` is
            learned; otherwise ``self.bias`` is ``None``.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()

        # Both feature axes go through a Hadamard transform in forward(),
        # which only accepts power-of-2 sizes.
        assert in_features & (in_features - 1) == 0, (
            f"in_features must be power of 2 for Hadamard, got {in_features}"
        )
        assert out_features & (out_features - 1) == 0, (
            f"out_features must be power of 2 for Hadamard, got {out_features}"
        )

        self.in_features = in_features
        self.out_features = out_features

        # Latent full-precision weight; it is ternarized on every forward.
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

        # Normalization applied before the input Hadamard transform.
        self.norm = nn.LayerNorm(in_features)

    def _quantize_activations(self, activations):
        """Fake-quantize along the last axis using a per-row absmax scale."""
        absmax = activations.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        # Scaled values lie in [-7, 7]; the clamp bounds are the signed
        # 4-bit range [-8, 7].
        levels = (activations / absmax * 7).round().clamp(-8, 7)
        dequantized = levels / 7 * absmax

        if not self.training:
            return dequantized
        # Straight-through estimator: quantized forward value, identity grad.
        return activations + (dequantized - activations).detach()

    def _quantize_weights(self):
        """Return the weight mapped to {-1, 0, +1} times the mean |weight|."""
        weight = self.weight
        scale = weight.abs().mean().clamp(min=1e-5)
        cutoff = 0.5 * scale

        # +1 above the cutoff, -1 below its negative, 0 in between.
        positive = (weight > cutoff).to(weight.dtype)
        negative = (weight < -cutoff).to(weight.dtype)
        ternary = (positive - negative) * scale

        if not self.training:
            return ternary
        # Straight-through estimator so the latent weight keeps receiving grads.
        return weight + (ternary - weight).detach()

    def forward(self, x):
        """
        Run the layer on ``x``, whose last axis must have ``in_features`` entries.

        Returns a tensor whose last axis has ``out_features`` entries.
        """
        normalized = self.norm(x)
        rotated = hadamard_transform(normalized)

        quantized_input = self._quantize_activations(rotated)
        quantized_weight = self._quantize_weights()

        projected = F.linear(quantized_input, quantized_weight, self.bias)

        # Second Hadamard transform, this time over the output features.
        return hadamard_transform(projected)