"""AAM Diffusion LLM — Quantization

BitNet 1-bit weights and FP8 training stubs.
Included for completeness — AAM's model is small enough that
quantization is not yet critical, but this prepares for future scaling.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


class BitLinear(nn.Module):
    """1-bit weight quantization layer (BitNet-style).

    During training: uses full-precision weights with a straight-through estimator.
    During inference: uses binarized weights (-1 or +1).

    Note: Only practical for models >1B params. AAM's current size
    doesn't benefit from this, but it's included for future scaling.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)
        # Scale factor for binarized weights
        self.register_buffer("weight_scale", torch.ones(1), persistent=True)

    def _binarize_weights(self) -> torch.Tensor:
        """Binarize weights to -1 or +1 using the sign function."""
        with torch.no_grad():
            self.weight_scale.copy_(self.weight.abs().mean())
            binary_weight = torch.sign(self.weight)
            return binary_weight * self.weight_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Straight-through estimator: the forward pass uses binarized weights,
            # while gradients flow back to the full-precision weights. torch.sign
            # has zero gradient almost everywhere, so the detach trick is needed,
            # and the scale is recomputed from the current weights each step.
            scale = self.weight.abs().mean()
            binary_weight = self.weight + (torch.sign(self.weight) * scale - self.weight).detach()
            output = F.linear(x, binary_weight, self.bias)
        else:
            binary_weight = self._binarize_weights()
            output = F.linear(x, binary_weight, self.bias)
        return output
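

def _bitlinear_ste_check() -> None:
    # Hedged usage sketch: `_bitlinear_ste_check` and its toy shapes are illustrative
    # only, not part of AAM's training loop. It shows that the straight-through
    # estimator lets gradients reach the full-precision weights even though the
    # forward pass uses binarized values.
    layer = BitLinear(in_features=16, out_features=8)
    layer.train()
    out = layer(torch.randn(4, 16))
    out.sum().backward()
    assert layer.weight.grad is not None  # gradients flow via the detach trick
    layer.eval()
    assert layer(torch.randn(4, 16)).shape == (4, 8)  # inference uses binarized weights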


class FP8Linear(nn.Module):
    """FP8 weight-only quantization layer.

    Stores weights in FP8 (E4M3) format for memory efficiency.
    Computation is done in higher precision after dequantization.

    Note: Requires hardware with FP8 support (e.g. H100, MI300X).
    Falls back to FP32/BF16 on unsupported hardware.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Store in FP32 for training, quantize for inference
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)
        self._fp8_available = hasattr(torch, "float8_e4m3fn")

    def _quantize_fp8(self, weight: torch.Tensor) -> torch.Tensor:
        """Round-trip weights through FP8 (E4M3) if supported."""
        if not self._fp8_available:
            return weight
        # Scale so the largest magnitude maps onto the E4M3 maximum (448.0)
        max_val = weight.abs().max()
        scale = (max_val / 448.0).clamp(min=1e-8)
        scaled = weight / scale
        try:
            fp8_weight = scaled.to(torch.float8_e4m3fn)
            # Dequantize back to the weight's original dtype for the matmul
            dequantized = fp8_weight.to(weight.dtype) * scale
            return dequantized
        except (RuntimeError, TypeError):
            return weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training and self._fp8_available:
            weight = self._quantize_fp8(self.weight)
        else:
            weight = self.weight
        return F.linear(x, weight, self.bias)
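

def _fp8_linear_check() -> None:
    # Hedged usage sketch: `_fp8_linear_check` is an illustrative helper, not AAM API.
    # In eval mode the layer round-trips its weights through FP8 when the running
    # PyTorch build exposes torch.float8_e4m3fn, and silently falls back to the
    # full-precision weights otherwise, so this runs on any hardware.
    layer = FP8Linear(in_features=8, out_features=8).eval()
    with torch.no_grad():
        out = layer(torch.randn(2, 8))
    assert out.shape == (2, 8)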


def replace_linear_with_quantized(
    model: nn.Module,
    quantization_type: str = "bitnet",
) -> nn.Module:
    """Replace all nn.Linear layers with quantized versions.

    Args:
        model: The model to quantize (modified in place)
        quantization_type: "bitnet" or "fp8"

    Returns:
        Model with quantized linear layers
    """
    QuantClass = BitLinear if quantization_type == "bitnet" else FP8Linear
    # Snapshot the module list so we don't mutate the model while iterating over it
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            # Skip the final vocab projection
            if "lm_head" in name or "vocab_proj" in name:
                continue
            quantized = QuantClass(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
            )
            # Copy weights
            with torch.no_grad():
                quantized.weight.copy_(module.weight)
                if module.bias is not None:
                    quantized.bias.copy_(module.bias)
            # Replace in the parent module
            *path, attr = name.split(".")
            parent = model
            for p in path:
                parent = getattr(parent, p)
            setattr(parent, attr, quantized)
    return model
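

if __name__ == "__main__":
    # Hedged end-to-end sketch: quantize a toy MLP (not AAM's architecture) in place
    # with both backends and run a forward pass as a smoke test.
    for qtype in ("bitnet", "fp8"):
        toy = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32))
        toy = replace_linear_with_quantized(toy, quantization_type=qtype)
        toy.eval()
        with torch.no_grad():
            y = toy(torch.randn(2, 32))
        print(f"{qtype}: output shape {tuple(y.shape)}")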