"""AAM Diffusion LLM — Quantization

BitNet 1-bit weights and FP8 training stubs.
Included for completeness — AAM's model is small enough that
quantization is not yet critical, but this prepares for future scaling.
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class BitLinear(nn.Module):
    """1-bit weight quantization layer (BitNet-style).
    
    During training: uses full-precision weights with straight-through estimator
    During inference: uses binarized weights (-1 or +1)
    
    Note: Only practical for models >1B params. AAM's current size
    doesn't benefit from this, but it's included for future scaling.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        # Scale factor for binarized weights
        self.register_buffer("weight_scale", torch.ones(1), persistent=True)

    def _binarize_weights(self) -> torch.Tensor:
        """Binarize weights to -1 or +1 using sign function."""
        with torch.no_grad():
            self.weight_scale.copy_(self.weight.abs().mean())
        binary_weight = torch.sign(self.weight)
        return binary_weight * self.weight_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Straight-through estimator: the forward pass sees binarized
            # weights, but the detached difference lets gradients flow to the
            # full-precision weights as if binarization were the identity.
            scale = self.weight.abs().mean().detach()
            binary_weight = torch.sign(self.weight) * scale
            weight = self.weight + (binary_weight - self.weight).detach()
            output = F.linear(x, weight, self.bias)
        else:
            binary_weight = self._binarize_weights()
            output = F.linear(x, binary_weight, self.bias)

        return output
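
# Illustrative usage (a minimal sketch; the shapes below are arbitrary, not
# AAM's actual dimensions, and any nn.Linear projection could be swapped):
#
#   layer = BitLinear(in_features=512, out_features=512)
#   y = layer(torch.randn(2, 16, 512))   # training mode: straight-through path
#   layer.eval()
#   y = layer(torch.randn(2, 16, 512))   # eval mode: binarized weights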


class FP8Linear(nn.Module):
    """FP8 weight-only quantization layer.
    
    Stores weights in FP8 (E4M3) format for memory efficiency.
    Computation is done in higher precision after dequantization.
    
    Note: Requires hardware with FP8 support (H100, MI300X).
    Falls back to FP32/BF16 on unsupported hardware.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Store in FP32 for training, quantize for inference
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        self._fp8_available = hasattr(torch, "float8_e4m3fn")

    def _quantize_fp8(self, weight: torch.Tensor) -> torch.Tensor:
        """Quantize-dequantize weights through FP8 if supported."""
        if not self._fp8_available:
            return weight

        # Scale so the largest weight magnitude maps to 448, the E4M3 maximum.
        max_val = weight.abs().max()
        scale = (max_val / 448.0).clamp(min=1e-8)
        scaled = weight / scale

        try:
            fp8_weight = scaled.to(torch.float8_e4m3fn)
            # Dequantize back to the original dtype so F.linear sees matching dtypes.
            dequantized = fp8_weight.to(weight.dtype) * scale
            return dequantized
        except (RuntimeError, TypeError):
            return weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training and self._fp8_available:
            weight = self._quantize_fp8(self.weight)
        else:
            weight = self.weight

        return F.linear(x, weight, self.bias)
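
# Illustrative usage (assumption: torch.float8_e4m3fn exists in this PyTorch
# build; otherwise the layer silently uses its full-precision weights):
#
#   layer = FP8Linear(in_features=512, out_features=512).eval()
#   y = layer(torch.randn(4, 512))   # weights are quantize-dequantized to E4M3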


def replace_linear_with_quantized(
    model: nn.Module,
    quantization_type: str = "bitnet",
) -> nn.Module:
    """Replace all nn.Linear layers with quantized versions.
    
    Args:
        model: The model to quantize
        quantization_type: "bitnet" or "fp8"
        
    Returns:
        Model with quantized linear layers
    """
    QuantClass = BitLinear if quantization_type == "bitnet" else FP8Linear

    # Collect targets first: mutating the module tree while iterating over
    # named_modules() can skip entries. The final vocab projection
    # (lm_head / vocab_proj) is kept in full precision.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and "lm_head" not in name
        and "vocab_proj" not in name
    ]

    for name, module in targets:
        quantized = QuantClass(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
        ).to(device=module.weight.device, dtype=module.weight.dtype)

        # Copy the full-precision weights into the quantized layer
        with torch.no_grad():
            quantized.weight.copy_(module.weight)
            if module.bias is not None:
                quantized.bias.copy_(module.bias)

        # Replace in parent module
        *path, attr = name.split(".")
        parent = model
        for p in path:
            parent = getattr(parent, p)
        setattr(parent, attr, quantized)

    return model
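

if __name__ == "__main__":
    # Minimal smoke test, a sketch that assumes nothing about AAM's actual
    # architecture: a toy stack with an "lm_head" output projection, which
    # replace_linear_with_quantized should leave untouched.
    toy = nn.Sequential()
    toy.add_module("proj_in", nn.Linear(32, 64))
    toy.add_module("proj_out", nn.Linear(64, 32))
    toy.add_module("lm_head", nn.Linear(32, 100))

    toy = replace_linear_with_quantized(toy, quantization_type="bitnet")
    assert isinstance(toy.proj_in, BitLinear)
    assert isinstance(toy.proj_out, BitLinear)
    assert isinstance(toy.lm_head, nn.Linear)  # skipped, still full precision

    out = toy(torch.randn(2, 32))
    print("output shape:", tuple(out.shape))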