File size: 12,174 Bytes
bfa9a3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 |
import torch
import torch.nn as nn
import math
@torch.no_grad()
def quantize_complex_tensor(w_real: torch.Tensor, w_imag: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply PhaseQuant logic to complex weight tensors.

    ``w_real`` and ``w_imag`` are the real and imaginary parts of a complex
    weight. Each element is snapped, according to its phase, to one of four
    points on the axes:

    * phase in [-pi/4, pi/4)           -> (+s_re, 0)
    * phase in [pi/4, 3*pi/4)          -> (0, +s_im)
    * phase in [-3*pi/4, -pi/4)        -> (0, -s_im)
    * phase >= 3*pi/4 or < -3*pi/4     -> (-s_re, 0)

    ``s_re`` is the mean ``|w_real|`` over the elements that landed on the real
    axis, and ``s_im`` the mean ``|w_imag|`` over those on the imaginary axis.
    Each scale is clamped to at least 1e-6 (also used when no element landed on
    that axis), and replaced by 1e-6 if it is NaN or infinite.

    Returns:
        ``(quantized_real, quantized_imag)``, cast back to the input dtypes.
    """
    # atan2(im, re) is the same angle torch.angle(re + 1j*im) returns, but needs
    # no complex intermediate. That avoids doubling memory, and it also works for
    # bfloat16 weights (PyTorch has no complex bfloat16 dtype) and float16
    # weights (ComplexHalf support is limited). Half-precision inputs are upcast
    # to float32 only for the phase computation.
    compute_dtype = torch.promote_types(torch.promote_types(w_real.dtype, w_imag.dtype), torch.float32)
    phase = torch.atan2(w_imag.to(compute_dtype), w_real.to(compute_dtype))

    real_pos = (phase >= -math.pi / 4) & (phase < math.pi / 4)
    real_neg = (phase >= 3 * math.pi / 4) | (phase < -3 * math.pi / 4)
    imag_pos = (phase >= math.pi / 4) & (phase < 3 * math.pi / 4)
    imag_neg = (phase >= -3 * math.pi / 4) & (phase < -math.pi / 4)

    def _axis_scale(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Mean magnitude of the elements assigned to this axis, kept strictly positive and finite.
        if mask.any():
            scale = values[mask].abs().mean()
        else:
            scale = torch.tensor(0.0, device=values.device)
        scale = torch.clamp(scale, min=1e-6)
        if not torch.isfinite(scale):
            scale = torch.tensor(1e-6, device=values.device)
        return scale

    s_re = _axis_scale(w_real, real_pos | real_neg)
    s_im = _axis_scale(w_imag, imag_pos | imag_neg)

    qw_real = torch.zeros_like(w_real)
    qw_imag = torch.zeros_like(w_imag)
    qw_real[real_pos] = 1.0
    qw_real[real_neg] = -1.0
    qw_imag[imag_pos] = 1.0
    qw_imag[imag_neg] = -1.0

    return (qw_real * s_re).to(w_real.dtype), (qw_imag * s_im).to(w_imag.dtype)
def apply_complex_inspired_quantization(model: nn.Module):
    """Apply complex-inspired quantization to real-valued model.

    Every ``nn.Linear`` weight of shape (2n, 2m) is split into four n x m blocks
    ``[[A11, A12], [A21, A22]]`` and rewritten as the sum of two parts:

        U = 0.5*(A11 + A22) + i * 0.5*(A21 - A12)
        W = 0.5*(A11 - A22) + i * 0.5*(A12 + A21)

    If the first/second halves of the input and output features are read as
    real/imaginary parts, U is the complex-linear part and W the
    conjugate-linear part. Both are quantized with ``quantize_complex_tensor``
    and the block matrix is rebuilt from the quantized parts.

    Layers with an odd number of rows or columns are skipped and a message is
    printed. Weights are replaced in place, and ``model`` is returned.
    """
    print("Applying complex-inspired quantization (PhaseQuant-based)...")

    @torch.no_grad()
    def _quantize(layer: nn.Linear):
        weight = layer.weight.data
        rows, cols = weight.shape[0], weight.shape[1]
        if rows % 2 or cols % 2:
            print(f" -> Skipping layer (non-even dimensions): {weight.shape}")
            return
        half_r, half_c = rows // 2, cols // 2
        top, bottom = weight[:half_r], weight[half_r:]
        a11, a12 = top[:, :half_c], top[:, half_c:]
        a21, a22 = bottom[:, :half_c], bottom[:, half_c:]

        u_re_q, u_im_q = quantize_complex_tensor(0.5 * (a11 + a22), 0.5 * (a21 - a12))
        w_re_q, w_im_q = quantize_complex_tensor(0.5 * (a11 - a22), 0.5 * (a12 + a21))

        # Inverse of the split: [[W_re + U_re, W_im - U_im], [W_im + U_im, U_re - W_re]].
        rebuilt = torch.cat(
            [
                torch.cat([w_re_q + u_re_q, w_im_q - u_im_q], dim=1),
                torch.cat([w_im_q + u_im_q, -w_re_q + u_re_q], dim=1),
            ],
            dim=0,
        )
        layer.weight.data = rebuilt.to(weight.dtype)

    def _visit(module: nn.Module):
        if isinstance(module, nn.Linear):
            _quantize(module)

    model.apply(_visit)
    print("Complex-inspired quantization completed.")
    return model
def apply_bitnet_quantization(model: nn.Module):
    """Apply BitNet 1-bit quantization to real-valued model.

    For every ``nn.Linear`` weight W, each entry becomes ``+mean(|W|)`` if it
    lies strictly above ``mean(W)``, and ``-mean(|W|)`` otherwise. The magnitude
    uses the uncentred weights. Biases are not touched. Weights are replaced in
    place, and ``model`` is returned.
    """
    print("Applying BitNet (true 1-bit, affine) quantization to real-valued model...")

    @torch.no_grad()
    def _binarize(layer: nn.Linear):
        weight = layer.weight.data
        magnitude = weight.abs().mean()
        above_mean = (weight - weight.mean()) > 0
        signs = torch.where(above_mean, 1.0, -1.0)
        layer.weight.data = signs.to(weight.dtype) * magnitude

    def _visit(module: nn.Module):
        if isinstance(module, nn.Linear):
            _binarize(module)

    model.apply(_visit)
    print("BitNet quantization completed.")
    return model
def apply_bitnet_1_58bit_quantization_standard(model: nn.Module):
    """Apply BitNet 1.58-bit quantization to real-valued model (quantize to {-1, 0, +1}).

    For every ``nn.Linear`` weight W, with ``gamma = mean(|W|)``, each entry
    becomes ``clamp(round(w / (gamma + 1e-5)), -1, 1) * gamma``. Weights are
    replaced in place, and ``model`` is returned.
    """
    print("Applying BitNet 1.58-bit (absmean threshold) quantization to real-valued model...")

    @torch.no_grad()
    def _ternarize(layer: nn.Linear):
        weight = layer.weight.data
        gamma = weight.abs().mean()
        levels = (weight / (gamma + 1e-5)).round().clamp(-1.0, 1.0)
        layer.weight.data = levels.to(weight.dtype) * gamma

    def _visit(module: nn.Module):
        if isinstance(module, nn.Linear):
            _ternarize(module)

    model.apply(_visit)
    print("BitNet 1.58-bit (absmean threshold) quantization completed.")
    return model
def apply_bitnet_1_58bit_quantization_variant(model: nn.Module, threshold: float = 0.5):
    """Apply BitNet 1.58-bit quantization to real-valued model (quantize to {-1, 0, +1}).

    Thresholded ternary variant. For every ``nn.Linear`` weight W, with
    ``gamma = mean(|W|)`` and ``x = w / (gamma + 1e-5)``, each entry becomes:

    * ``-gamma`` if ``x < -threshold``;
    * ``+gamma`` if ``x > threshold``;
    * ``0`` otherwise.

    The ``-gamma`` rule takes precedence when both conditions hold, which can
    only happen for a negative ``threshold``. Weights are replaced in place,
    and ``model`` is returned.
    """
    print("Applying BitNet 1.58-bit (ternary) quantization to real-valued model...")

    @torch.no_grad()
    def _ternarize(layer: nn.Linear):
        weight = layer.weight.data
        gamma = weight.abs().mean()
        scaled = weight / (gamma + 1e-5)
        levels = torch.where(scaled > threshold, 1.0, torch.zeros_like(scaled))
        levels = torch.where(scaled < -threshold, -1.0, levels)
        layer.weight.data = levels.to(weight.dtype) * gamma

    def _visit(module: nn.Module):
        if isinstance(module, nn.Linear):
            _ternarize(module)

    model.apply(_visit)
    print("BitNet 1.58-bit quantization completed.")
    return model
def minmax_1bit_quantize_dequantize(w: torch.Tensor) -> torch.Tensor:
    """Apply 1-bit Min-Max quantization and dequantization to weight tensor.

    Every element is rounded to the nearer of ``w.min()`` and ``w.max()``.
    ``torch.round`` rounds halves to even, so an exact midpoint goes to the
    minimum. If the range is below 1e-9, the input tensor itself is returned
    unchanged.
    """
    lo = w.min()
    hi = w.max()
    step = (hi - lo) / 1.0  # 2**1 - 1 = one step between the two levels
    if step.abs() < 1e-9:
        return w
    codes = ((w - lo) / step).round()  # 0 -> min level, 1 -> max level
    return (codes * step + lo).to(w.dtype)
def apply_minmax_1bit_quantization(model: nn.Module):
    """Apply Min-Max 1-bit quantization to real-valued model.

    Each ``nn.Linear`` weight is replaced by
    ``minmax_1bit_quantize_dequantize(weight)``. The model is modified in place
    and returned.
    """
    print("Applying Min-Max (1-bit) quantization to real-valued model...")

    @torch.no_grad()
    def _quantize(layer: nn.Linear):
        layer.weight.data = minmax_1bit_quantize_dequantize(layer.weight.data)

    def _visit(module: nn.Module):
        if isinstance(module, nn.Linear):
            _quantize(module)

    model.apply(_visit)
    print("Min-Max 1-bit quantization completed.")
    return model
def symmetric_minmax_1bit_quantize_dequantize(w: torch.Tensor) -> torch.Tensor:
    """Apply symmetric 1-bit Min-Max quantization to weight tensor (quantize to {-1, 1}).

    Every element is replaced by ``+max|w|`` or ``-max|w|`` according to its
    sign. Exact zeros (including -0.0) map to ``+max|w|``, so the output
    really has only two levels; ``torch.sign`` alone returns 0 for them,
    which would add a third level. If ``max|w| < 1e-9``, the input tensor
    itself is returned unchanged.

    Returns:
        The dequantized tensor, in the dtype of ``w``.
    """
    max_abs = w.abs().max()
    scale = max_abs
    if scale < 1e-9:
        return w
    # scale > 0 here, so sign(w / scale) == sign(w). Taking the sign of w
    # directly also keeps tiny values from underflowing to 0 when scale is huge.
    signs = w.sign()
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    return (signs * scale).to(w.dtype)
def apply_symmetric_minmax_1bit_quantization(model: nn.Module):
    """Apply symmetric Min-Max 1-bit quantization to real-valued model.

    Each ``nn.Linear`` weight is replaced by
    ``symmetric_minmax_1bit_quantize_dequantize(weight)``. The model is
    modified in place and returned.
    """
    print("Applying symmetric Min-Max (1-bit, to {-1, 1}) quantization to real-valued model...")

    @torch.no_grad()
    def _quantize(layer: nn.Linear):
        layer.weight.data = symmetric_minmax_1bit_quantize_dequantize(layer.weight.data)

    def _visit(module: nn.Module):
        if isinstance(module, nn.Linear):
            _quantize(module)

    model.apply(_visit)
    print("Symmetric Min-Max 1-bit quantization completed.")
    return model
class BitNetQuantSTE(torch.autograd.Function):
    """BitNet 1-bit quantizer with a straight-through estimator.

    Forward: each entry becomes ``+mean(|w|)`` if it lies strictly above
    ``mean(w)``, and ``-mean(|w|)`` otherwise (NaNs fall in the "otherwise" case).
    Backward: the incoming gradient is returned unchanged.
    """

    @staticmethod
    def forward(ctx, w):
        magnitude = w.abs().mean()
        above_mean = (w - w.mean()) > 0
        # True -> 2 - 1 = +1, False -> 0 - 1 = -1, computed in w's dtype.
        signs = above_mean.to(w.dtype) * 2 - 1
        return signs * magnitude

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
class BitNet1_58QuantSTE(torch.autograd.Function):
    """BitNet 1.58-bit (ternary) quantizer with a straight-through estimator.

    Forward: with ``gamma = mean(|w|)``, returns
    ``clamp(round(w / (gamma + 1e-5)), -1, 1) * gamma`` in w's dtype.
    Backward: the incoming gradient is returned unchanged.
    """

    @staticmethod
    def forward(ctx, w):
        gamma = w.abs().mean()
        ternary = (w / (gamma + 1e-5)).round().clamp(-1.0, 1.0)
        return (ternary * gamma).to(w.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
class PhaseQuantSTE(torch.autograd.Function):
    """Complex-phase quantizer with a straight-through estimator.

    Forward: ``(w_real, w_imag)`` are treated as the real and imaginary parts of
    complex weights. Each element is snapped, according to its phase, to one of
    four points on the axes:

    * phase in [-pi/4, pi/4)           -> (+s_re, 0)
    * phase in [pi/4, 3*pi/4)          -> (0, +s_im)
    * phase in [-3*pi/4, -pi/4)        -> (0, -s_im)
    * phase >= 3*pi/4 or < -3*pi/4     -> (-s_re, 0)

    ``s_re`` is the mean ``|w_real|`` over the real-axis elements and ``s_im``
    the mean ``|w_imag|`` over the imaginary-axis elements. Each is clamped to
    at least 1e-6, which is also the value used when no element lands on that
    axis. Outputs are cast back to the input dtypes.

    Backward: both incoming gradients are passed through unchanged.
    """

    @staticmethod
    def forward(ctx, w_real, w_imag):
        # atan2(im, re) is the same angle torch.angle(re + 1j*im) returns, but
        # needs no complex intermediate. That avoids doubling memory, and it
        # also works for bfloat16 weights (PyTorch has no complex bfloat16
        # dtype) and float16 weights (ComplexHalf support is limited).
        # Half-precision inputs are upcast to float32 only for the phase.
        compute_dtype = torch.promote_types(torch.promote_types(w_real.dtype, w_imag.dtype), torch.float32)
        phase = torch.atan2(w_imag.to(compute_dtype), w_real.to(compute_dtype))

        real_pos = (phase >= -math.pi / 4) & (phase < math.pi / 4)
        real_neg = (phase >= 3 * math.pi / 4) | (phase < -3 * math.pi / 4)
        imag_pos = (phase >= math.pi / 4) & (phase < 3 * math.pi / 4)
        imag_neg = (phase >= -3 * math.pi / 4) & (phase < -math.pi / 4)
        mask_real = real_pos | real_neg
        mask_imag = imag_pos | imag_neg

        s_re = w_real[mask_real].abs().mean() if mask_real.any() else torch.tensor(0.0, device=w_real.device)
        s_im = w_imag[mask_imag].abs().mean() if mask_imag.any() else torch.tensor(0.0, device=w_imag.device)
        s_re = torch.clamp(s_re, min=1e-6)
        s_im = torch.clamp(s_im, min=1e-6)

        qw_real = torch.zeros_like(w_real)
        qw_imag = torch.zeros_like(w_imag)
        qw_real[real_pos] = 1.0
        qw_real[real_neg] = -1.0
        qw_imag[imag_pos] = 1.0
        qw_imag[imag_neg] = -1.0

        return (qw_real * s_re).to(w_real.dtype), (qw_imag * s_im).to(w_imag.dtype)

    @staticmethod
    def backward(ctx, grad_w_real, grad_w_imag):
        return grad_w_real, grad_w_imag
class PhaseQuantSTE_V2(torch.autograd.Function):
    """Two-step residual PhaseQuant with a straight-through backward.

    Forward: quantize ``(w_real, w_imag)`` with ``PhaseQuantSTE``, quantize the
    remaining error once more, and return the sum of the two quantized terms.
    Backward: both incoming gradients are passed through unchanged.
    """

    @staticmethod
    def forward(ctx, w_real: torch.Tensor, w_imag: torch.Tensor):
        first_re, first_im = PhaseQuantSTE.apply(w_real, w_imag)
        second_re, second_im = PhaseQuantSTE.apply(w_real - first_re, w_imag - first_im)
        return first_re + second_re, first_im + second_im

    @staticmethod
    def backward(ctx, grad_real, grad_imag):
        return grad_real, grad_imag
class PhaseQuantSTE_V3(torch.autograd.Function):
    """Three-step residual PhaseQuant with a straight-through backward.

    Forward: quantize ``(w_real, w_imag)`` with ``PhaseQuantSTE``, then twice
    more quantize whatever error remains, and return the sum of all three
    quantized terms. Backward: both incoming gradients are passed through
    unchanged.
    """

    @staticmethod
    def forward(ctx, w_real: torch.Tensor, w_imag: torch.Tensor):
        step_re, step_im = PhaseQuantSTE.apply(w_real, w_imag)
        total_re, total_im = step_re, step_im
        residual_re, residual_im = w_real, w_imag
        for _ in range(2):
            residual_re = residual_re - step_re
            residual_im = residual_im - step_im
            step_re, step_im = PhaseQuantSTE.apply(residual_re, residual_im)
            total_re = total_re + step_re
            total_im = total_im + step_im
        return total_re, total_im

    @staticmethod
    def backward(ctx, grad_real, grad_imag):
        return grad_real, grad_imag
class PhaseQuantSTE_V4(torch.autograd.Function):
    """Four-step residual PhaseQuant with a straight-through backward.

    Forward: quantize ``(w_real, w_imag)`` with ``PhaseQuantSTE``, then three
    more times quantize whatever error remains, and return the sum of all four
    quantized terms. Backward: both incoming gradients are passed through
    unchanged.
    """

    @staticmethod
    def forward(ctx, w_real: torch.Tensor, w_imag: torch.Tensor):
        step_re, step_im = PhaseQuantSTE.apply(w_real, w_imag)
        total_re, total_im = step_re, step_im
        residual_re, residual_im = w_real, w_imag
        for _ in range(3):
            residual_re = residual_re - step_re
            residual_im = residual_im - step_im
            step_re, step_im = PhaseQuantSTE.apply(residual_re, residual_im)
            total_re = total_re + step_re
            total_im = total_im + step_im
        return total_re, total_im

    @staticmethod
    def backward(ctx, grad_real, grad_imag):
        return grad_real, grad_imag
|