add quantization
Browse files- quantization.py +313 -0
quantization.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mixed-Precision Quantization Script for Small Language Models
|
| 3 |
+
Supports selective quantization of different model components with configurable bitwidths.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
import json
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, Optional, Tuple
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
class MixedPrecisionQuantizer:
    """
    Applies mixed-precision (simulated) quantization to a causal language model.

    Attention layers can be quantized more aggressively (fewer bits) while FFN
    and embedding projections keep higher precision. Weights are quantized and
    immediately dequantized in place ("fake quantization"): this preserves the
    quantization error for accuracy experiments while keeping the model's
    forward pass functional and the ``save_pretrained`` output loadable with a
    standard ``from_pretrained`` call. Per-layer scales are attached to the
    modules (``weight_scale``) for inspection.
    """

    def __init__(
        self,
        model_name: str,
        attention_bits: int = 4,
        ffn_bits: int = 8,
        embedding_bits: int = 8,
        output_dir: str = "./quantized_models",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.model_name = model_name
        self.attention_bits = attention_bits
        self.ffn_bits = ffn_bits
        self.embedding_bits = embedding_bits
        self.output_dir = Path(output_dir)
        # NOTE(review): `device` is recorded and printed but the model is never
        # moved to it; quantization runs wherever from_pretrained places the
        # weights (CPU by default) — confirm intent before relying on it.
        self.device = device

        # Create output directory up front so saving cannot fail late.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        print(f"Initializing quantizer for {model_name}")
        print(f"Attention layers: {attention_bits}-bit")
        print(f"FFN layers: {ffn_bits}-bit")
        print(f"Embeddings: {embedding_bits}-bit")
        print(f"Device: {device}")

    def load_model(self) -> Tuple[nn.Module, AutoTokenizer]:
        """Load the pretrained model and tokenizer.

        Returns:
            Tuple of (model, tokenizer) loaded from ``self.model_name``.
        """
        print(f"\nLoading model: {self.model_name}")
        start_time = time.time()

        # Load with low_cpu_mem_usage for large models
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )

        tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True
        )

        load_time = time.time() - start_time
        print(f"Model loaded in {load_time:.2f} seconds")

        # Report original parameter count and in-memory footprint.
        param_count = sum(p.numel() for p in model.parameters())
        param_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)
        print(f"Parameters: {param_count:,} ({param_size_mb:.2f} MB)")

        return model, tokenizer

    def quantize_linear_layer(self, layer: nn.Linear, bits: int) -> nn.Linear:
        """
        Quantize a linear layer's weight to ``bits`` using symmetric per-tensor
        quantization, then dequantize it in place.

        The original implementation stored the raw integer codes in
        ``layer.weight.data``; that breaks the forward pass (int weights in a
        float graph) and makes the saved checkpoint unusable, because the
        ``weight_scale`` attribute is not part of the state dict and is lost on
        ``save_pretrained``/``from_pretrained`` round-trips. Writing back the
        dequantized float values keeps the module usable while retaining the
        quantization error, which is what accuracy experiments measure.

        Args:
            layer: the ``nn.Linear`` to quantize in place.
            bits: target bit width; 32 is treated as a no-op.

        Returns:
            The same layer, with ``weight_scale``, ``quantized`` and ``bits``
            attributes attached (except for the 32-bit no-op).
        """
        if bits == 32:
            return layer

        weight = layer.weight.data

        # Symmetric signed integer range, e.g. [-128, 127] for 8 bits.
        qmin = -(2 ** (bits - 1))
        qmax = 2 ** (bits - 1) - 1

        # Per-tensor scale from the absolute maximum.
        max_val = torch.max(torch.abs(weight))
        if max_val == 0:
            # All-zero weight tensor: any positive scale is exact; a zero
            # scale would produce 0/0 -> NaN below.
            scale = torch.tensor(1.0, dtype=weight.dtype)
        else:
            scale = max_val / qmax

        # Quantize to integer codes, then write back dequantized floats so the
        # module (and the saved checkpoint) remains functional.
        weight_q = torch.clamp(torch.round(weight / scale), qmin, qmax)
        layer.weight.data = weight_q * scale
        layer.weight_scale = scale
        layer.quantized = True
        layer.bits = bits

        return layer

    def identify_layer_type(self, name: str, module: nn.Module) -> str:
        """
        Classify a layer as 'attention', 'ffn', 'embedding', or 'other' from
        its dotted module name.

        ``module`` is accepted for interface compatibility but currently
        unused; classification is purely name-based.
        """
        name_lower = name.lower()

        # Unambiguous FFN block markers are checked FIRST: GPT-2 names its
        # feed-forward output projection ``mlp.c_proj``, which would otherwise
        # match the attention pattern 'c_proj' below and be quantized at the
        # (more aggressive) attention bit width.
        ffn_block_markers = ['mlp', 'ffn', 'intermediate', 'feed_forward']

        # Attention-related patterns
        attention_patterns = [
            'attn', 'attention', 'q_proj', 'k_proj', 'v_proj',
            'qkv', 'query', 'key', 'value', 'o_proj', 'out_proj',
            'c_attn', 'c_proj'
        ]

        # FFN-related patterns
        ffn_patterns = [
            'mlp', 'ffn', 'fc', 'dense', 'intermediate',
            'gate_proj', 'up_proj', 'down_proj', 'w1', 'w2', 'w3'
        ]

        # Embedding patterns
        embedding_patterns = ['embed', 'wte', 'wpe', 'lm_head']

        if any(marker in name_lower for marker in ffn_block_markers):
            return 'ffn'
        elif any(pattern in name_lower for pattern in attention_patterns):
            return 'attention'
        elif any(pattern in name_lower for pattern in ffn_patterns):
            return 'ffn'
        elif any(pattern in name_lower for pattern in embedding_patterns):
            return 'embedding'
        else:
            return 'other'

    def quantize_model(self, model: nn.Module) -> Tuple[nn.Module, Dict]:
        """
        Apply mixed-precision quantization to every ``nn.Linear`` in the model.

        Returns:
            Tuple of (model quantized in place, stats dict with per-category
            layer counts).
        """
        print("\nApplying mixed-precision quantization...")
        start_time = time.time()

        stats = {
            'attention_layers': 0,
            'ffn_layers': 0,
            'embedding_layers': 0,
            'other_layers': 0,
            'total_quantized': 0
        }

        # Iterate through all modules
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                layer_type = self.identify_layer_type(name, module)

                # Select quantization bitwidth based on layer type
                if layer_type == 'attention':
                    bits = self.attention_bits
                    stats['attention_layers'] += 1
                elif layer_type == 'ffn':
                    bits = self.ffn_bits
                    stats['ffn_layers'] += 1
                elif layer_type == 'embedding':
                    bits = self.embedding_bits
                    stats['embedding_layers'] += 1
                else:
                    bits = self.ffn_bits  # Default to FFN bitwidth
                    stats['other_layers'] += 1

                # Quantize the layer (in place)
                self.quantize_linear_layer(module, bits)
                stats['total_quantized'] += 1

        quant_time = time.time() - start_time
        print(f"\nQuantization completed in {quant_time:.2f} seconds")
        print(f"Quantized layers breakdown:")
        print(f"  - Attention: {stats['attention_layers']} layers ({self.attention_bits}-bit)")
        print(f"  - FFN: {stats['ffn_layers']} layers ({self.ffn_bits}-bit)")
        print(f"  - Embedding: {stats['embedding_layers']} layers ({self.embedding_bits}-bit)")
        print(f"  - Other: {stats['other_layers']} layers ({self.ffn_bits}-bit)")
        print(f"  - Total quantized: {stats['total_quantized']} layers")

        return model, stats

    def save_quantized_model(
        self,
        model: nn.Module,
        tokenizer: AutoTokenizer,
        stats: Dict
    ) -> str:
        """Save the quantized model, tokenizer, and metadata.

        Returns:
            Path (as str) of the directory the artifacts were written to.
        """
        # Create model-specific output directory
        model_short_name = self.model_name.split('/')[-1]
        quant_config = f"attn{self.attention_bits}_ffn{self.ffn_bits}_emb{self.embedding_bits}"
        save_dir = self.output_dir / f"{model_short_name}_{quant_config}"
        save_dir.mkdir(parents=True, exist_ok=True)

        print(f"\nSaving quantized model to: {save_dir}")

        # Save model
        model.save_pretrained(save_dir)

        # Save tokenizer
        tokenizer.save_pretrained(save_dir)

        # In-memory size of the saved model. Note: with fake quantization the
        # weights are stored as floats, so this matches the original footprint;
        # it is recorded for bookkeeping, not as a compression claim.
        quantized_size_mb = sum(
            p.numel() * p.element_size() for p in model.parameters()
        ) / (1024 ** 2)

        # Save metadata describing how the checkpoint was produced.
        metadata = {
            'original_model': self.model_name,
            'quantization_config': {
                'attention_bits': self.attention_bits,
                'ffn_bits': self.ffn_bits,
                'embedding_bits': self.embedding_bits
            },
            'layer_stats': stats,
            'model_size_mb': quantized_size_mb,
            'quantization_timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
        }

        with open(save_dir / 'quantization_metadata.json', 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"Quantized model size: {quantized_size_mb:.2f} MB")
        print(f"Metadata saved to: {save_dir / 'quantization_metadata.json'}")

        return str(save_dir)

    def run(self) -> str:
        """Execute the full quantization pipeline: load, quantize, save.

        Returns:
            Path (as str) of the saved quantized model directory.
        """
        print("=" * 80)
        print("MIXED-PRECISION QUANTIZATION PIPELINE")
        print("=" * 80)

        # Load model
        model, tokenizer = self.load_model()

        # Quantize model
        quantized_model, stats = self.quantize_model(model)

        # Save quantized model
        save_path = self.save_quantized_model(quantized_model, tokenizer, stats)

        print("\n" + "=" * 80)
        print("QUANTIZATION COMPLETE")
        print("=" * 80)
        print(f"Saved to: {save_path}")

        return save_path
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def main():
    """Command-line entry point: parse arguments and run the quantization pipeline."""
    parser = argparse.ArgumentParser(
        description="Mixed-Precision Quantization for Small Language Models"
    )

    # Flag specs: (flag, options) — registered in a single pass below.
    argument_table = [
        ('--model_name', dict(
            type=str, required=True,
            help='HuggingFace model name or path')),
        ('--attention_bits', dict(
            type=int, default=4,
            help='Bit width for attention layers (default: 4)')),
        ('--ffn_bits', dict(
            type=int, default=8,
            help='Bit width for FFN layers (default: 8)')),
        ('--embedding_bits', dict(
            type=int, default=8,
            help='Bit width for embedding layers (default: 8)')),
        ('--output_dir', dict(
            type=str, default='./quantized_models',
            help='Output directory for quantized models')),
        ('--device', dict(
            type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
            help='Device to use (cuda/cpu)')),
    ]
    for flag, options in argument_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Build the quantizer from the parsed flags and run the full pipeline.
    MixedPrecisionQuantizer(
        model_name=args.model_name,
        attention_bits=args.attention_bits,
        ffn_bits=args.ffn_bits,
        embedding_bits=args.embedding_bits,
        output_dir=args.output_dir,
        device=args.device,
    ).run()


if __name__ == "__main__":
    main()
|