File size: 12,807 Bytes
fd8c8b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 |
"""
BitLinear layer implementations.
This module provides nn.Module wrappers around the functional implementations,
providing a drop-in replacement for nn.Linear with ternary weights.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from .functional import (
bitlinear_python,
greedy_ternary_decomposition,
multi_ternary_linear_python,
)
from .quantization import weight_to_ternary
class BitLinear(nn.Module):
    """
    BitLinear layer: drop-in replacement for nn.Linear with ternary weights.

    The layer keeps a weight matrix whose entries are meant to lie in
    {-1, 0, +1} plus a per-output-row scale ``gamma``. Ternary values need
    only ~1.58 bits each, so a packed encoding could be roughly 20x smaller
    than float32. Note, however, that this module stores ``W_ternary`` in an
    ordinary floating-point ``nn.Parameter``, so the module itself does not
    realize that memory saving.

    Interface matches nn.Linear:
        - Same initialization arguments (in_features, out_features, bias,
          device, dtype)
        - Same forward signature
        - Can replace nn.Linear in existing architectures

    Example:
        >>> # Standard Linear
        >>> linear = nn.Linear(512, 512)
        >>> # BitLinear replacement
        >>> bitlinear = BitLinear(512, 512)
        >>> x = torch.randn(32, 128, 512)
        >>> output = bitlinear(x)  # Same interface

    Notes:
        - Weights are quantized with ``weight_to_ternary(..., per_channel=True)``
          in ``reset_parameters`` and in ``from_linear``.
        - The forward pass delegates to ``bitlinear_python``.
        - ``W_ternary`` and ``gamma`` are ``nn.Parameter``s, so gradients can
          flow (e.g. for quantization-aware training). Nothing in this class
          re-projects ``W_ternary`` onto {-1, 0, +1} after an optimizer step;
          that is left to the training code.
        - Unlike nn.Linear there is no ``.weight`` attribute. Code that reads
          ``layer.weight`` directly will not work with this module.

    Attributes:
        in_features: Input dimension
        out_features: Output dimension
        W_ternary: Ternary weight matrix [out_features, in_features]
        gamma: Per-output scaling factors [out_features]
        bias: Optional bias term [out_features]
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize BitLinear layer.

        Args:
            in_features: Size of each input sample
            out_features: Size of each output sample
            bias: If True, add learnable bias (default: True)
            device: Device to place parameters on (None = default device)
            dtype: Data type for parameters (None = default dtype)
        """
        super().__init__()
        # Honor device/dtype like nn.Linear does; previously they were ignored
        # and every parameter landed on the default device/dtype.
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.in_features = in_features
        self.out_features = out_features
        # Registered as Parameters (not buffers) so gradients can flow for QAT.
        self.W_ternary = nn.Parameter(torch.zeros(out_features, in_features, **factory_kwargs))
        self.gamma = nn.Parameter(torch.ones(out_features, **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize layer parameters.

        Strategy:
            1. Initialize a dense weight with the same kaiming_uniform_ scheme
               nn.Linear uses.
            2. Quantize it with ``weight_to_ternary(..., per_channel=True)``.
            3. Copy the ternary weights and scales into the parameters.
            4. Initialize the bias uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        """
        # The dense draw uses the default device/dtype; copy_() below moves and
        # casts the result onto the parameters' device/dtype.
        W_dense = torch.empty(self.out_features, self.in_features)
        nn.init.kaiming_uniform_(W_dense, a=math.sqrt(5))
        W_ternary, gamma = weight_to_ternary(W_dense, per_channel=True)
        with torch.no_grad():
            self.W_ternary.copy_(W_ternary)
            self.gamma.copy_(gamma)
        if self.bias is not None:
            # For an [out_features, in_features] weight, fan_in == in_features.
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through BitLinear layer.

        Args:
            x: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., out_features]
        """
        return bitlinear_python(x, self.W_ternary, self.gamma, self.bias)

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> 'BitLinear':
        """
        Convert a standard nn.Linear layer to BitLinear.

        The new layer is created on the same device/dtype as ``linear.weight``.
        The weight is quantized with ``weight_to_ternary(..., per_channel=True)``,
        and the bias (if any) is copied unchanged. The source layer is not
        modified.

        Args:
            linear: Standard nn.Linear layer to convert

        Returns:
            BitLinear layer with quantized weights

        Example:
            >>> linear = nn.Linear(512, 512)
            >>> # ... train linear ...
            >>> bitlinear = BitLinear.from_linear(linear)
        """
        bitlinear = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )
        W_ternary, gamma = weight_to_ternary(linear.weight.detach(), per_channel=True)
        with torch.no_grad():
            bitlinear.W_ternary.copy_(W_ternary)
            bitlinear.gamma.copy_(gamma)
            if linear.bias is not None:
                bitlinear.bias.copy_(linear.bias)
        return bitlinear

    def extra_repr(self) -> str:
        """String representation for print()."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
class MultiTernaryLinear(nn.Module):
    """
    Multi-component ternary linear layer.

    Represents a linear layer as a sum of k ternary components:

        output = sum_{i=1}^k (x @ W_i^T * gamma_i) + bias

    With more components the sum can approximate a dense weight more closely
    than a single ternary quantization, at the cost of k times more
    computation.

    References:
        - JMLR paper on ternary representations: https://jmlr.org/papers/volume26/24-2050/24-2050.pdf
        - Greedy ternary decomposition for neural networks

    Attributes:
        in_features: Input dimension
        out_features: Output dimension
        k: Number of ternary components
        W_ternary: Stacked ternary weights [k, out_features, in_features]
        gammas: Stacked scaling factors [k, out_features]
        bias: Optional bias term [out_features]

    Example:
        >>> # Single ternary component (comparable to BitLinear)
        >>> layer = MultiTernaryLinear(512, 512, k=1)
        >>> # Multiple components for better approximation
        >>> layer = MultiTernaryLinear(512, 512, k=4)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        k: int = 2,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize MultiTernaryLinear layer.

        Args:
            in_features: Size of each input sample
            out_features: Size of each output sample
            k: Number of ternary components (typically 2-4); must be >= 1
            bias: If True, add learnable bias
            device: Device to place parameters on (None = default device)
            dtype: Data type for parameters (None = default dtype)

        Raises:
            ValueError: If ``k`` is less than 1.
        """
        super().__init__()
        if k < 1:
            raise ValueError(f'k must be >= 1, got {k}')
        # Honor device/dtype like nn.Linear does; previously they were ignored.
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.k = k
        # Parameters (not buffers) so gradients can flow for QAT.
        self.W_ternary = nn.Parameter(torch.zeros(k, out_features, in_features, **factory_kwargs))
        self.gammas = nn.Parameter(torch.ones(k, out_features, **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    @staticmethod
    def _stack_components(components) -> torch.Tensor:
        """Return ``components`` as one tensor, stacking along a new dim 0 if it is a sequence."""
        # greedy_ternary_decomposition's return type is not visible here. Its
        # results are named *_list, so accept either an already-stacked tensor
        # or a sequence of per-component tensors.
        if isinstance(components, torch.Tensor):
            return components
        return torch.stack(list(components))

    def reset_parameters(self) -> None:
        """
        Initialize layer parameters using greedy ternary decomposition.

        A dense weight is drawn with nn.Linear's kaiming_uniform_ scheme and
        decomposed into ``self.k`` ternary components. The bias is drawn
        uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        """
        # Drawn on the default device/dtype; copy_() moves/casts into the params.
        W_dense = torch.empty(self.out_features, self.in_features)
        nn.init.kaiming_uniform_(W_dense, a=math.sqrt(5))
        W_components, gamma_components = greedy_ternary_decomposition(W_dense, self.k)
        with torch.no_grad():
            self.W_ternary.copy_(self._stack_components(W_components))
            self.gammas.copy_(self._stack_components(gamma_components))
        if self.bias is not None:
            # For an [out_features, in_features] weight, fan_in == in_features.
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through multi-ternary layer.

        Args:
            x: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., out_features]
        """
        return multi_ternary_linear_python(x, self.W_ternary, self.gammas, self.bias)

    @classmethod
    def from_linear(cls, linear: nn.Linear, k: int = 2) -> 'MultiTernaryLinear':
        """
        Convert nn.Linear to MultiTernaryLinear using greedy decomposition.

        The new layer is created on the same device/dtype as ``linear.weight``.
        The bias (if any) is copied unchanged, and the source layer is not
        modified.

        Args:
            linear: Standard nn.Linear layer
            k: Number of ternary components (must be >= 1)

        Returns:
            MultiTernaryLinear layer

        Raises:
            ValueError: If ``k`` is less than 1.
        """
        multi_ternary = cls(
            linear.in_features,
            linear.out_features,
            k=k,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )
        W_components, gamma_components = greedy_ternary_decomposition(linear.weight.detach(), k)
        with torch.no_grad():
            multi_ternary.W_ternary.copy_(cls._stack_components(W_components))
            multi_ternary.gammas.copy_(cls._stack_components(gamma_components))
            if linear.bias is not None:
                multi_ternary.bias.copy_(linear.bias)
        return multi_ternary

    def extra_repr(self) -> str:
        """String representation."""
        return f'in_features={self.in_features}, out_features={self.out_features}, k={self.k}, bias={self.bias is not None}'
def convert_linear_to_bitlinear(
    module: nn.Module,
    inplace: bool = True,
) -> nn.Module:
    """
    Recursively convert all nn.Linear layers in a module to BitLinear.

    Walks the module tree and replaces every nn.Linear (including subclasses,
    since an isinstance check is used) with ``BitLinear.from_linear(child)``.
    A Linear instance that is registered in several places is converted once,
    so every slot receives the same BitLinear and the sharing is preserved.

    NOTE: BitLinear exposes no ``.weight`` attribute. A parent module that reads
    its Linear child's ``.weight`` directly will break after conversion; for
    instance, nn.MultiheadAttention reads ``out_proj.weight``. Verify before
    converting such models.

    Args:
        module: PyTorch module (e.g., a Transformer model). If ``module`` is
            itself an nn.Linear, a new BitLinear is returned and the input is
            left untouched, regardless of ``inplace``.
        inplace: If True, modify module in place; if False, convert and return
            a deep copy.

    Returns:
        Module with Linear layers replaced by BitLinear

    Example:
        >>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
        >>> model = convert_linear_to_bitlinear(model)
        >>> # All Linear layers are now BitLinear
    """
    if isinstance(module, nn.Linear):
        # There is no parent to setattr on, so the replacement must be returned.
        # from_linear copies the weights, so the original is not mutated.
        return BitLinear.from_linear(module)

    if not inplace:
        import copy
        module = copy.deepcopy(module)

    # id(original Linear) -> (original, replacement). The original is kept
    # alive here so its id() cannot be recycled during the walk.
    converted = {}

    def _convert_children(parent: nn.Module) -> None:
        # Iterate _modules rather than named_children(): named_children() skips
        # a module it has already yielded, so a Linear registered under two
        # names of the same parent would leave its second slot unconverted.
        for name, child in list(parent._modules.items()):
            if child is None:
                continue
            if isinstance(child, nn.Linear):
                entry = converted.get(id(child))
                if entry is None:
                    entry = (child, BitLinear.from_linear(child))
                    converted[id(child)] = entry
                setattr(parent, name, entry[1])
            else:
                _convert_children(child)

    _convert_children(module)
    return module
|