File size: 12,807 Bytes
fd8c8b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
"""

BitLinear layer implementations.



This module provides nn.Module wrappers around the functional implementations,

providing a drop-in replacement for nn.Linear with ternary weights.

"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from .functional import (
    bitlinear_python,
    greedy_ternary_decomposition,
    multi_ternary_linear_python,
)
from .quantization import weight_to_ternary


class BitLinear(nn.Module):
    """
    BitLinear layer: drop-in replacement for nn.Linear with ternary weights.

    This layer uses ternary weights ({-1, 0, +1}) instead of full-precision
    weights, achieving ~20x memory compression while maintaining competitive
    performance on Transformer models.

    Interface matches nn.Linear:
        - Same initialization arguments (in_features, out_features, bias)
        - Same forward signature
        - Can replace nn.Linear in existing architectures

    Example:
        >>> # Standard Linear
        >>> linear = nn.Linear(512, 512)
        >>> # BitLinear replacement
        >>> bitlinear = BitLinear(512, 512)
        >>> x = torch.randn(32, 128, 512)
        >>> output = bitlinear(x)  # Same interface

    Notes:
        - Weights are quantized to ternary on initialization or conversion
        - Stores ternary weights + scaling factors (gamma)
        - W_ternary and gamma are nn.Parameters so QAT
          (Quantization-Aware Training) can update them

    Attributes:
        in_features: Input dimension
        out_features: Output dimension
        W_ternary: Ternary weight matrix [out_features, in_features]
        gamma: Per-output scaling factors [out_features]
        bias: Optional bias term [out_features]
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize BitLinear layer.

        Args:
            in_features: Size of each input sample
            out_features: Size of each output sample
            bias: If True, add learnable bias (default: True)
            device: Device to place parameters on
            dtype: Data type for parameters
        """
        super().__init__()

        # Honor device/dtype the way nn.Linear does. Previously these
        # arguments were accepted but silently ignored, so from_linear()
        # on a CUDA or non-float32 layer produced CPU/float32 parameters.
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.in_features = in_features
        self.out_features = out_features

        # Parameters (not buffers) so gradients can flow during QAT.
        self.W_ternary = nn.Parameter(
            torch.zeros(out_features, in_features, **factory_kwargs))
        self.gamma = nn.Parameter(torch.ones(out_features, **factory_kwargs))

        if bias:
            self.bias = nn.Parameter(
                torch.zeros(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize layer parameters.

        Strategy:
            1. Draw dense weights with kaiming_uniform_ (nn.Linear's scheme)
            2. Quantize to ternary using weight_to_ternary()
            3. Copy ternary weights and scaling factors into the parameters
        """
        # Build the dense init directly on the parameters' device/dtype so
        # quantization happens where the layer lives.
        W_dense = torch.empty(
            self.out_features, self.in_features,
            device=self.W_ternary.device, dtype=self.W_ternary.dtype)
        nn.init.kaiming_uniform_(W_dense, a=math.sqrt(5))

        # Quantize to ternary (per-output-channel scaling factors).
        W_ternary, gamma = weight_to_ternary(W_dense, per_channel=True)
        self.W_ternary.data.copy_(W_ternary)
        self.gamma.data.copy_(gamma)

        # Bias init mirrors nn.Linear: uniform in +/- 1/sqrt(fan_in).
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(W_dense)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through BitLinear layer.

        Args:
            x: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., out_features]
        """
        return bitlinear_python(x, self.W_ternary, self.gamma, self.bias)

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> 'BitLinear':
        """
        Convert a standard nn.Linear layer to BitLinear.

        This allows converting pre-trained models to use ternary weights.

        Args:
            linear: Standard nn.Linear layer to convert

        Returns:
            BitLinear layer with quantized weights

        Example:
            >>> linear = nn.Linear(512, 512)
            >>> # ... train linear ...
            >>> bitlinear = BitLinear.from_linear(linear)
        """
        # Create new BitLinear matching the source layer's shape, device
        # and dtype.
        bitlinear = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )

        # Quantize the trained dense weights to ternary.
        W_ternary, gamma = weight_to_ternary(linear.weight.data, per_channel=True)
        bitlinear.W_ternary.data.copy_(W_ternary)
        bitlinear.gamma.data.copy_(gamma)

        # Copy bias if present.
        if linear.bias is not None:
            bitlinear.bias.data.copy_(linear.bias.data)

        return bitlinear

    def extra_repr(self) -> str:
        """String representation for print()."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'


class MultiTernaryLinear(nn.Module):
    """
    Multi-component ternary linear layer.

    Represents a linear layer as a sum of k ternary components:

        output = sum_{i=1}^k (x @ W_i^T * gamma_i) + bias

    This provides better approximation of dense weights compared to single
    ternary quantization, at the cost of k× more computation.

    References:
        - JMLR paper on ternary representations:
          https://jmlr.org/papers/volume26/24-2050/24-2050.pdf
        - Greedy ternary decomposition for neural networks

    Attributes:
        in_features: Input dimension
        out_features: Output dimension
        k: Number of ternary components
        W_ternary: Stacked ternary weights [k, out_features, in_features]
        gammas: Stacked scaling factors [k, out_features]
        bias: Optional bias term [out_features]

    Example:
        >>> # Single ternary component (equivalent to BitLinear)
        >>> layer = MultiTernaryLinear(512, 512, k=1)
        >>> # Multiple components for better approximation
        >>> layer = MultiTernaryLinear(512, 512, k=4)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        k: int = 2,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize MultiTernaryLinear layer.

        Args:
            in_features: Size of each input sample
            out_features: Size of each output sample
            k: Number of ternary components (typically 2-4)
            bias: If True, add learnable bias
            device: Device to place parameters on
            dtype: Data type for parameters
        """
        super().__init__()

        # Honor device/dtype like nn.Linear. Previously these arguments
        # were accepted but silently ignored, so from_linear() on a CUDA
        # or non-float32 layer produced CPU/float32 parameters.
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.in_features = in_features
        self.out_features = out_features
        self.k = k

        # Parameters (not buffers) so gradients can flow during QAT.
        self.W_ternary = nn.Parameter(
            torch.zeros(k, out_features, in_features, **factory_kwargs))
        self.gammas = nn.Parameter(torch.ones(k, out_features, **factory_kwargs))

        if bias:
            self.bias = nn.Parameter(
                torch.zeros(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize layer parameters using greedy ternary decomposition.
        """
        # Build the dense init directly on the parameters' device/dtype.
        W_dense = torch.empty(
            self.out_features, self.in_features,
            device=self.W_ternary.device, dtype=self.W_ternary.dtype)
        nn.init.kaiming_uniform_(W_dense, a=math.sqrt(5))

        # Decompose the dense weights into k ternary components.
        # NOTE(review): assumes greedy_ternary_decomposition returns tensors
        # stackable into [k, out, in] / [k, out] — matches the copy_ below.
        W_ternary_list, gamma_list = greedy_ternary_decomposition(W_dense, self.k)
        self.W_ternary.data.copy_(W_ternary_list)
        self.gammas.data.copy_(gamma_list)

        # Bias init mirrors nn.Linear: uniform in +/- 1/sqrt(fan_in).
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(W_dense)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through multi-ternary layer.

        Args:
            x: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., out_features]
        """
        return multi_ternary_linear_python(x, self.W_ternary, self.gammas, self.bias)

    @classmethod
    def from_linear(cls, linear: nn.Linear, k: int = 2) -> 'MultiTernaryLinear':
        """
        Convert nn.Linear to MultiTernaryLinear using greedy decomposition.

        Args:
            linear: Standard nn.Linear layer
            k: Number of ternary components

        Returns:
            MultiTernaryLinear layer
        """
        # Create new instance matching the source layer's shape, device
        # and dtype.
        multi_ternary = cls(
            linear.in_features,
            linear.out_features,
            k=k,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )

        # Decompose the trained dense weights into k ternary components.
        W_ternary_list, gamma_list = greedy_ternary_decomposition(linear.weight.data, k)
        multi_ternary.W_ternary.data.copy_(W_ternary_list)
        multi_ternary.gammas.data.copy_(gamma_list)

        # Copy bias if present.
        if linear.bias is not None:
            multi_ternary.bias.data.copy_(linear.bias.data)

        return multi_ternary

    def extra_repr(self) -> str:
        """String representation."""
        return f'in_features={self.in_features}, out_features={self.out_features}, k={self.k}, bias={self.bias is not None}'


def convert_linear_to_bitlinear(
    module: nn.Module,
    inplace: bool = True,
) -> nn.Module:
    """
    Recursively convert all nn.Linear layers in a module to BitLinear.

    Walks the module tree and swaps every nn.Linear for a BitLinear built
    via BitLinear.from_linear(), which is useful for converting pre-trained
    models to ternary weights.

    Args:
        module: PyTorch module (e.g., a Transformer model)
        inplace: If True, modify module in place; if False, return a copy

    Returns:
        Module with Linear layers replaced by BitLinear

    Example:
        >>> model = transformers.GPT2Model.from_pretrained('gpt2')
        >>> model = convert_linear_to_bitlinear(model)
        >>> # All Linear layers are now BitLinear
    """
    if not inplace:
        # Work on a deep copy so the caller's model is untouched.
        from copy import deepcopy
        module = deepcopy(module)

    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Linear):
            # Descend into composite children; the copy (if any) was
            # already made above, so recursion is always in-place.
            convert_linear_to_bitlinear(submodule, inplace=True)
            continue
        # Direct Linear child: replace it with a quantized equivalent.
        module.add_module(child_name, BitLinear.from_linear(submodule))

    return module