Charlie81 committed on
Commit
ba1b797
·
1 Parent(s): 3130167

add quantization

Browse files
Files changed (1) hide show
  1. quantization.py +313 -0
quantization.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mixed-Precision Quantization Script for Small Language Models
3
+ Supports selective quantization of different model components with configurable bitwidths.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
9
+ import argparse
10
+ import os
11
+ import json
12
+ from pathlib import Path
13
+ from typing import Dict, Optional, Tuple
14
+ import time
15
+
16
+ class MixedPrecisionQuantizer:
17
+ """
18
+ Quantizes model components with different precision levels.
19
+ Supports more aggressive quantization for attention layers while
20
+ preserving higher precision for FFN layers.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ model_name: str,
26
+ attention_bits: int = 4,
27
+ ffn_bits: int = 8,
28
+ embedding_bits: int = 8,
29
+ output_dir: str = "./quantized_models",
30
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
31
+ ):
32
+ self.model_name = model_name
33
+ self.attention_bits = attention_bits
34
+ self.ffn_bits = ffn_bits
35
+ self.embedding_bits = embedding_bits
36
+ self.output_dir = Path(output_dir)
37
+ self.device = device
38
+
39
+ # Create output directory
40
+ self.output_dir.mkdir(parents=True, exist_ok=True)
41
+
42
+ print(f"Initializing quantizer for {model_name}")
43
+ print(f"Attention layers: {attention_bits}-bit")
44
+ print(f"FFN layers: {ffn_bits}-bit")
45
+ print(f"Embeddings: {embedding_bits}-bit")
46
+ print(f"Device: {device}")
47
+
48
+ def load_model(self) -> Tuple[nn.Module, AutoTokenizer]:
49
+ """Load the pretrained model and tokenizer."""
50
+ print(f"\nLoading model: {self.model_name}")
51
+ start_time = time.time()
52
+
53
+ # Load with low_cpu_mem_usage for large models
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ self.model_name,
56
+ torch_dtype=torch.float32,
57
+ low_cpu_mem_usage=True,
58
+ trust_remote_code=True
59
+ )
60
+
61
+ tokenizer = AutoTokenizer.from_pretrained(
62
+ self.model_name,
63
+ trust_remote_code=True
64
+ )
65
+
66
+ load_time = time.time() - start_time
67
+ print(f"Model loaded in {load_time:.2f} seconds")
68
+
69
+ # Calculate original model size
70
+ param_count = sum(p.numel() for p in model.parameters())
71
+ param_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)
72
+ print(f"Parameters: {param_count:,} ({param_size_mb:.2f} MB)")
73
+
74
+ return model, tokenizer
75
+
76
+ def quantize_linear_layer(self, layer: nn.Linear, bits: int) -> nn.Linear:
77
+ """
78
+ Quantize a linear layer to specified bit width using symmetric quantization.
79
+ """
80
+ if bits == 32:
81
+ return layer
82
+
83
+ weight = layer.weight.data
84
+ bias = layer.bias.data if layer.bias is not None else None
85
+
86
+ # Symmetric quantization
87
+ qmin = -(2 ** (bits - 1))
88
+ qmax = 2 ** (bits - 1) - 1
89
+
90
+ # Calculate scale
91
+ max_val = torch.max(torch.abs(weight))
92
+ scale = max_val / qmax
93
+
94
+ # Quantize
95
+ weight_q = torch.clamp(torch.round(weight / scale), qmin, qmax)
96
+
97
+ # Store quantized weights and scale
98
+ layer.weight.data = weight_q.to(torch.int8 if bits <= 8 else torch.int16)
99
+ layer.weight_scale = scale
100
+ layer.quantized = True
101
+ layer.bits = bits
102
+
103
+ return layer
104
+
105
+ def identify_layer_type(self, name: str, module: nn.Module) -> str:
106
+ """
107
+ Identify if a layer is part of attention, FFN, embedding, or other components.
108
+ """
109
+ name_lower = name.lower()
110
+
111
+ # Attention-related patterns
112
+ attention_patterns = [
113
+ 'attn', 'attention', 'q_proj', 'k_proj', 'v_proj',
114
+ 'qkv', 'query', 'key', 'value', 'o_proj', 'out_proj',
115
+ 'c_attn', 'c_proj'
116
+ ]
117
+
118
+ # FFN-related patterns
119
+ ffn_patterns = [
120
+ 'mlp', 'ffn', 'fc', 'dense', 'intermediate',
121
+ 'gate_proj', 'up_proj', 'down_proj', 'w1', 'w2', 'w3'
122
+ ]
123
+
124
+ # Embedding patterns
125
+ embedding_patterns = ['embed', 'wte', 'wpe', 'lm_head']
126
+
127
+ if any(pattern in name_lower for pattern in attention_patterns):
128
+ return 'attention'
129
+ elif any(pattern in name_lower for pattern in ffn_patterns):
130
+ return 'ffn'
131
+ elif any(pattern in name_lower for pattern in embedding_patterns):
132
+ return 'embedding'
133
+ else:
134
+ return 'other'
135
+
136
+ def quantize_model(self, model: nn.Module) -> Tuple[nn.Module, Dict]:
137
+ """
138
+ Apply mixed-precision quantization to the model.
139
+ """
140
+ print("\nApplying mixed-precision quantization...")
141
+ start_time = time.time()
142
+
143
+ stats = {
144
+ 'attention_layers': 0,
145
+ 'ffn_layers': 0,
146
+ 'embedding_layers': 0,
147
+ 'other_layers': 0,
148
+ 'total_quantized': 0
149
+ }
150
+
151
+ # Iterate through all modules
152
+ for name, module in model.named_modules():
153
+ if isinstance(module, nn.Linear):
154
+ layer_type = self.identify_layer_type(name, module)
155
+
156
+ # Select quantization bitwidth based on layer type
157
+ if layer_type == 'attention':
158
+ bits = self.attention_bits
159
+ stats['attention_layers'] += 1
160
+ elif layer_type == 'ffn':
161
+ bits = self.ffn_bits
162
+ stats['ffn_layers'] += 1
163
+ elif layer_type == 'embedding':
164
+ bits = self.embedding_bits
165
+ stats['embedding_layers'] += 1
166
+ else:
167
+ bits = self.ffn_bits # Default to FFN bitwidth
168
+ stats['other_layers'] += 1
169
+
170
+ # Quantize the layer
171
+ self.quantize_linear_layer(module, bits)
172
+ stats['total_quantized'] += 1
173
+
174
+ quant_time = time.time() - start_time
175
+ print(f"\nQuantization completed in {quant_time:.2f} seconds")
176
+ print(f"Quantized layers breakdown:")
177
+ print(f" - Attention: {stats['attention_layers']} layers ({self.attention_bits}-bit)")
178
+ print(f" - FFN: {stats['ffn_layers']} layers ({self.ffn_bits}-bit)")
179
+ print(f" - Embedding: {stats['embedding_layers']} layers ({self.embedding_bits}-bit)")
180
+ print(f" - Other: {stats['other_layers']} layers ({self.ffn_bits}-bit)")
181
+ print(f" - Total quantized: {stats['total_quantized']} layers")
182
+
183
+ return model, stats
184
+
185
+ def save_quantized_model(
186
+ self,
187
+ model: nn.Module,
188
+ tokenizer: AutoTokenizer,
189
+ stats: Dict
190
+ ) -> str:
191
+ """Save the quantized model, tokenizer, and metadata."""
192
+ # Create model-specific output directory
193
+ model_short_name = self.model_name.split('/')[-1]
194
+ quant_config = f"attn{self.attention_bits}_ffn{self.ffn_bits}_emb{self.embedding_bits}"
195
+ save_dir = self.output_dir / f"{model_short_name}_{quant_config}"
196
+ save_dir.mkdir(parents=True, exist_ok=True)
197
+
198
+ print(f"\nSaving quantized model to: {save_dir}")
199
+
200
+ # Save model
201
+ model.save_pretrained(save_dir)
202
+
203
+ # Save tokenizer
204
+ tokenizer.save_pretrained(save_dir)
205
+
206
+ # Calculate quantized model size
207
+ quantized_size_mb = sum(
208
+ p.numel() * p.element_size() for p in model.parameters()
209
+ ) / (1024 ** 2)
210
+
211
+ # Save metadata
212
+ metadata = {
213
+ 'original_model': self.model_name,
214
+ 'quantization_config': {
215
+ 'attention_bits': self.attention_bits,
216
+ 'ffn_bits': self.ffn_bits,
217
+ 'embedding_bits': self.embedding_bits
218
+ },
219
+ 'layer_stats': stats,
220
+ 'model_size_mb': quantized_size_mb,
221
+ 'quantization_timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
222
+ }
223
+
224
+ with open(save_dir / 'quantization_metadata.json', 'w') as f:
225
+ json.dump(metadata, f, indent=2)
226
+
227
+ print(f"Quantized model size: {quantized_size_mb:.2f} MB")
228
+ print(f"Metadata saved to: {save_dir / 'quantization_metadata.json'}")
229
+
230
+ return str(save_dir)
231
+
232
+ def run(self) -> str:
233
+ """Execute the full quantization pipeline."""
234
+ print("=" * 80)
235
+ print("MIXED-PRECISION QUANTIZATION PIPELINE")
236
+ print("=" * 80)
237
+
238
+ # Load model
239
+ model, tokenizer = self.load_model()
240
+
241
+ # Quantize model
242
+ quantized_model, stats = self.quantize_model(model)
243
+
244
+ # Save quantized model
245
+ save_path = self.save_quantized_model(quantized_model, tokenizer, stats)
246
+
247
+ print("\n" + "=" * 80)
248
+ print("QUANTIZATION COMPLETE")
249
+ print("=" * 80)
250
+ print(f"Saved to: {save_path}")
251
+
252
+ return save_path
253
+
254
+
255
def main():
    """Command-line entry point: parse arguments and run the quantizer."""
    arg_parser = argparse.ArgumentParser(
        description="Mixed-Precision Quantization for Small Language Models"
    )
    arg_parser.add_argument(
        '--model_name', type=str, required=True,
        help='HuggingFace model name or path',
    )
    arg_parser.add_argument(
        '--attention_bits', type=int, default=4,
        help='Bit width for attention layers (default: 4)',
    )
    arg_parser.add_argument(
        '--ffn_bits', type=int, default=8,
        help='Bit width for FFN layers (default: 8)',
    )
    arg_parser.add_argument(
        '--embedding_bits', type=int, default=8,
        help='Bit width for embedding layers (default: 8)',
    )
    arg_parser.add_argument(
        '--output_dir', type=str, default='./quantized_models',
        help='Output directory for quantized models',
    )
    arg_parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu',
        help='Device to use (cuda/cpu)',
    )

    cli = arg_parser.parse_args()

    # Build the quantizer from the CLI options and run the full pipeline.
    MixedPrecisionQuantizer(
        model_name=cli.model_name,
        attention_bits=cli.attention_bits,
        ffn_bits=cli.ffn_bits,
        embedding_bits=cli.embedding_bits,
        output_dir=cli.output_dir,
        device=cli.device,
    ).run()


if __name__ == "__main__":
    main()