Wolfvin commited on
Commit
0f8aec6
·
verified ·
1 Parent(s): 46c5bd3

Upload diffusion_llm/model/quantization.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion_llm/model/quantization.py +150 -0
diffusion_llm/model/quantization.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AAM Diffusion LLM — Quantization
2
+
3
+ BitNet 1-bit weights and FP8 training stubs.
4
+ Included for completeness — AAM's model is small enough that
5
+ quantization is not yet critical, but this prepares for future scaling.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ class BitLinear(nn.Module):
18
+ """1-bit weight quantization layer (BitNet-style).
19
+
20
+ During training: uses full-precision weights with straight-through estimator
21
+ During inference: uses binarized weights (-1 or +1)
22
+
23
+ Note: Only practical for models >1B params. AAM's current size
24
+ doesn't benefit from this, but it's included for future scaling.
25
+ """
26
+
27
+ def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
28
+ super().__init__()
29
+ self.in_features = in_features
30
+ self.out_features = out_features
31
+
32
+ self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
33
+ if bias:
34
+ self.bias = nn.Parameter(torch.zeros(out_features))
35
+ else:
36
+ self.register_parameter("bias", None)
37
+
38
+ # Scale factor for binarized weights
39
+ self.register_buffer("weight_scale", torch.ones(1), persistent=True)
40
+
41
+ def _binarize_weights(self) -> torch.Tensor:
42
+ """Binarize weights to -1 or +1 using sign function."""
43
+ with torch.no_grad():
44
+ self.weight_scale.copy_(self.weight.abs().mean())
45
+ binary_weight = torch.sign(self.weight)
46
+ return binary_weight * self.weight_scale
47
+
48
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
49
+ if self.training:
50
+ # Straight-through estimator: forward uses binarized, backward uses full-precision
51
+ binary_weight = torch.sign(self.weight) * self.weight_scale
52
+ output = F.linear(x, binary_weight, self.bias)
53
+ else:
54
+ binary_weight = self._binarize_weights()
55
+ output = F.linear(x, binary_weight, self.bias)
56
+
57
+ return output
58
+
59
+
60
+ class FP8Linear(nn.Module):
61
+ """FP8 weight-only quantization layer.
62
+
63
+ Stores weights in FP8 (E4M3) format for memory efficiency.
64
+ Computation is done in higher precision after dequantization.
65
+
66
+ Note: Requires hardware with FP8 support (H100, MI300X).
67
+ Falls back to FP32/BF16 on unsupported hardware.
68
+ """
69
+
70
+ def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
71
+ super().__init__()
72
+ self.in_features = in_features
73
+ self.out_features = out_features
74
+
75
+ # Store in FP32 for training, quantize for inference
76
+ self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
77
+ if bias:
78
+ self.bias = nn.Parameter(torch.zeros(out_features))
79
+ else:
80
+ self.register_parameter("bias", None)
81
+
82
+ self._fp8_available = hasattr(torch, "float8_e4m3fn")
83
+
84
+ def _quantize_fp8(self, weight: torch.Tensor) -> torch.Tensor:
85
+ """Quantize weights to FP8 if supported."""
86
+ if not self._fp8_available:
87
+ return weight
88
+
89
+ # Scale to FP8 range
90
+ max_val = weight.abs().max()
91
+ scale = max_val / 448.0 # E4M3 max value
92
+ scaled = weight / scale.clamp(min=1e-8)
93
+
94
+ try:
95
+ fp8_weight = scaled.to(torch.float8_e4m3fn)
96
+ dequantized = fp8_weight.to(torch.float32) * scale
97
+ return dequantized
98
+ except (RuntimeError, TypeError):
99
+ return weight
100
+
101
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
102
+ if not self.training and self._fp8_available:
103
+ weight = self._quantize_fp8(self.weight)
104
+ else:
105
+ weight = self.weight
106
+
107
+ return F.linear(x, weight, self.bias)
108
+
109
+
110
+ def replace_linear_with_quantized(
111
+ model: nn.Module,
112
+ quantization_type: str = "bitnet",
113
+ ) -> nn.Module:
114
+ """Replace all nn.Linear layers with quantized versions.
115
+
116
+ Args:
117
+ model: The model to quantize
118
+ quantization_type: "bitnet" or "fp8"
119
+
120
+ Returns:
121
+ Model with quantized linear layers
122
+ """
123
+ QuantClass = BitLinear if quantization_type == "bitnet" else FP8Linear
124
+
125
+ for name, module in model.named_modules():
126
+ if isinstance(module, nn.Linear):
127
+ # Skip the final vocab projection
128
+ if "lm_head" in name or "vocab_proj" in name:
129
+ continue
130
+
131
+ quantized = QuantClass(
132
+ in_features=module.in_features,
133
+ out_features=module.out_features,
134
+ bias=module.bias is not None,
135
+ )
136
+
137
+ # Copy weights
138
+ with torch.no_grad():
139
+ quantized.weight.copy_(module.weight)
140
+ if module.bias is not None:
141
+ quantized.bias.copy_(module.bias)
142
+
143
+ # Replace in parent module
144
+ *path, attr = name.split(".")
145
+ parent = model
146
+ for p in path:
147
+ parent = getattr(parent, p)
148
+ setattr(parent, attr, quantized)
149
+
150
+ return model