import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import io
from PIL import Image

# Implementation of the W8A16LinearLayer
class W8A16LinearLayer(nn.Module):
    """Linear layer storing INT8 weights alongside per-output-channel scales,
    with activations kept in higher precision (FP32/FP16/BF16) — "W8A16"."""

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        
        # Placeholder weights spanning the full int8 range; torch.randint's
        # upper bound is exclusive, so 128 is needed to include 127
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 128, (out_features, in_features), dtype=torch.int8
            )
        )
        
        self.register_buffer("scales", 
                         torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                             torch.randn((1, out_features), 
                                       dtype=dtype))
        else:
            self.bias = None
            
    def quantize(self, weights):
        """
        Quantize floating point weights to int8 precision
        
        Args:
            weights: Tensor of weights to quantize (shape: out_features x in_features)
            
        Returns:
            (int8_weights, scales): the quantized weights and per-row scales;
            the layer's buffers are also updated in place
        """
        w_fp32 = weights.clone().to(torch.float32)

        # Per-output-row scale: the row's max absolute value divided by 127
        # (the largest representable int8 magnitude)
        scales = w_fp32.abs().max(dim=-1).values / 127
        # Guard against all-zero rows (e.g. the "eye" pattern with more output
        # than input features), which would otherwise yield a zero scale and
        # NaNs in the division below
        scales = scales.clamp(min=torch.finfo(torch.float32).eps)
        scales = scales.to(weights.dtype)

        # Quantize: divide each row by its scale, round to the nearest integer,
        # and clamp so reduced-precision rounding can't overflow int8
        int8_weights = torch.round(weights / scales.unsqueeze(1)).clamp(-128, 127).to(torch.int8)

        # Update the model parameters
        self.int8_weights = int8_weights
        self.scales = scales
        
        return int8_weights, scales
    
    def forward(self, input):
        """
        Forward pass through the quantized linear layer
        
        Args:
            input: Input tensor whose last dimension is in_features
                (e.g. batch_size x seq_len x in_features)
            
        Returns:
            output: Output tensor after the linear transformation
        """
        # Cast int8 weights to input dtype while preserving the values
        casted_weights = self.int8_weights.to(input.dtype)
        
        # Perform the linear multiplication and apply the scaling factor
        output = F.linear(input, casted_weights) * self.scales
        
        # Add bias if present
        if self.bias is not None:
            output = output + self.bias
            
        return output
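
# A minimal sanity check for W8A16LinearLayer (a sketch, not used by the Gradio
# app below): quantize a known weight matrix, run a forward pass, and compare
# against the equivalent full-precision linear transform. The helper name and
# the printed report are illustrative choices, not part of the original interface.
def _sanity_check_w8a16():
    torch.manual_seed(0)
    layer = W8A16LinearLayer(in_features=4, out_features=8)
    weights = torch.randn(8, 4)
    layer.quantize(weights)

    x = torch.randn(1, 4)
    quantized_out = layer(x)
    # Reference output using the original float weights and the same bias
    reference_out = F.linear(x, weights) + layer.bias

    # Per-channel absmax quantization keeps the outputs close but not identical
    max_diff = (quantized_out - reference_out).abs().max().item()
    print(f"max |quantized - reference| = {max_diff:.6f}")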

# Helper functions for visualization

def plot_weight_matrix(weights, title="Weight Matrix"):
    """Create a heatmap visualization of weight matrices"""
    plt.figure(figsize=(10, 8))

    # Matplotlib can't consume float16/bfloat16 tensors directly, so convert
    # to float32 before calling .numpy()
    w = weights.detach().float().numpy()

    # Create a colormap centered at zero; the small floor guards against an
    # all-zero matrix, where TwoSlopeNorm would reject vmin == vmax
    vmax = max(abs(w.min()), abs(w.max()), 1e-8)
    norm = TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)

    plt.imshow(w, cmap='RdBu_r', norm=norm)
    plt.colorbar(label='Weight Value')
    plt.title(title)
    
    # Save the plot to a bytes buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    
    return Image.open(buf)

def plot_weight_distribution(weights, title="Weight Distribution"):
    """Create a histogram visualization of weight distributions"""
    plt.figure(figsize=(10, 6))
    
    # Flatten to 1D for the histogram; convert to float32 first because
    # matplotlib can't consume float16/bfloat16 tensors directly
    flat_weights = weights.detach().float().flatten().numpy()
    
    plt.hist(flat_weights, bins=50, alpha=0.7, color='blue')
    plt.xlabel('Weight Value')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.grid(alpha=0.3)
    
    # Save the plot to a bytes buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    
    return Image.open(buf)

def calculate_quantization_error(original, quantized, scales):
    """Calculate error metrics between original and dequantized weights"""
    # Dequantize the weights
    dequantized = quantized.float() * scales.unsqueeze(1)
    
    # Calculate error metrics
    abs_error = (original - dequantized).abs()
    max_error = abs_error.max().item()
    mean_error = abs_error.mean().item()
    
    return max_error, mean_error, dequantized
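
# Worked example of the round trip these helpers measure (illustrative, not
# executed by the app): for a single row [0.5, -1.0, 0.25], the absmax is 1.0,
# so the scale is 1.0 / 127. Quantizing gives round([0.5, -1.0, 0.25] * 127)
# = [64, -127, 32]; dequantizing gives [64, -127, 32] / 127
# ≈ [0.5039, -1.0, 0.2520], for a maximum absolute error of about 0.0039.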

# Gradio UI components

def initialize_model(in_features, out_features, with_bias, dtype_str):
    """Initialize a new quantized linear layer model"""
    # Map dtype string to torch dtype
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16
    }
    dtype = dtype_map[dtype_str]
    
    # Create the model
    model = W8A16LinearLayer(in_features, out_features, bias=with_bias, dtype=dtype)
    
    # Generate random weights for visualization
    random_weights = torch.randn((out_features, in_features), dtype=dtype)
    
    # Original weights visualization
    weights_vis = plot_weight_matrix(random_weights, "Original Weights")
    dist_vis = plot_weight_distribution(random_weights, "Original Weight Distribution")
    
    # Quantize the weights
    int8_weights, scales = model.quantize(random_weights)
    
    # Quantized weights visualization
    q_weights_vis = plot_weight_matrix(int8_weights, "Quantized Weights (INT8)")
    q_dist_vis = plot_weight_distribution(int8_weights, "Quantized Weight Distribution")
    
    # Calculate quantization error
    max_error, mean_error, dequantized = calculate_quantization_error(
        random_weights, int8_weights, scales
    )
    
    # Dequantized weights visualization
    deq_weights_vis = plot_weight_matrix(dequantized, "Dequantized Weights")
    
    # Error visualization
    error = (random_weights - dequantized).abs()
    error_vis = plot_weight_matrix(error, "Quantization Error (Absolute)")
    
    # Create model summary. Bytes per weight after quantization: one int8 byte
    # plus the per-row scale overhead amortized across the row.
    quant_bytes_per_weight = (
        int8_weights.element_size()
        + scales.element_size() * scales.numel() / int8_weights.numel()
    )
    memory_savings = 100 * (1 - quant_bytes_per_weight / random_weights.element_size())

    model_info = f"""
    ## Model Configuration
    - Input Features: {in_features}
    - Output Features: {out_features}
    - Bias: {"Yes" if with_bias else "No"}
    - Data Type: {dtype_str}

    ## Quantization Stats
    - Original Weights Shape: {random_weights.shape}
    - Quantized Weights Shape: {int8_weights.shape}
    - Scales Shape: {scales.shape}
    - Maximum Quantization Error: {max_error:.6f}
    - Mean Quantization Error: {mean_error:.6f}
    - Memory Savings: {memory_savings:.2f}%
    """
    
    # Create sample input/output for the model
    sample_input = torch.randn(1, in_features, dtype=dtype)
    sample_output = model(sample_input)
    
    io_info = f"""
    ## Sample Input/Output
    - Input Shape: {sample_input.shape}
    - Output Shape: {sample_output.shape}
    - Output Range: [{sample_output.min().item():.4f}, {sample_output.max().item():.4f}]
    """
    
    return model_info, io_info, weights_vis, q_weights_vis, deq_weights_vis, dist_vis, q_dist_vis, error_vis

def quantize_custom_weights(in_features, out_features, with_bias, dtype_str, weight_pattern):
    """Quantize custom weights based on the selected pattern"""
    # Map dtype string to torch dtype
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16
    }
    dtype = dtype_map[dtype_str]
    
    # Create the model
    model = W8A16LinearLayer(in_features, out_features, bias=with_bias, dtype=dtype)
    
    # Generate weights based on pattern
    if weight_pattern == "random":
        custom_weights = torch.randn((out_features, in_features), dtype=dtype)
    elif weight_pattern == "eye":
        # Identity matrix (or closest approximation if dimensions don't match)
        custom_weights = torch.zeros((out_features, in_features), dtype=dtype)
        min_dim = min(out_features, in_features)
        custom_weights[:min_dim, :min_dim] = torch.eye(min_dim, dtype=dtype)
    elif weight_pattern == "ones":
        custom_weights = torch.ones((out_features, in_features), dtype=dtype)
    elif weight_pattern == "alternating":
        custom_weights = torch.ones((out_features, in_features), dtype=dtype)
        # Create a checkerboard pattern
        for i in range(out_features):
            for j in range(in_features):
                if (i + j) % 2 == 1:
                    custom_weights[i, j] = -1.0
    elif weight_pattern == "gradient":
        # Linear gradient from -1 to 1
        x = torch.linspace(-1, 1, in_features)
        y = torch.linspace(-1, 1, out_features)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        custom_weights = (xx + yy).t().to(dtype)
    
    # Original weights visualization
    weights_vis = plot_weight_matrix(custom_weights, f"Original Weights ({weight_pattern})")
    dist_vis = plot_weight_distribution(custom_weights, "Original Weight Distribution")
    
    # Quantize the weights
    int8_weights, scales = model.quantize(custom_weights)
    
    # Quantized weights visualization
    q_weights_vis = plot_weight_matrix(int8_weights, "Quantized Weights (INT8)")
    q_dist_vis = plot_weight_distribution(int8_weights, "Quantized Weight Distribution")
    
    # Calculate quantization error
    max_error, mean_error, dequantized = calculate_quantization_error(
        custom_weights, int8_weights, scales
    )
    
    # Dequantized weights visualization
    deq_weights_vis = plot_weight_matrix(dequantized, "Dequantized Weights")
    
    # Error visualization
    error = (custom_weights - dequantized).abs()
    error_vis = plot_weight_matrix(error, "Quantization Error (Absolute)")
    
    # Quantization details. Bytes per weight after quantization: one int8 byte
    # plus the per-row scale overhead amortized across the row.
    quant_bytes_per_weight = (
        int8_weights.element_size()
        + scales.element_size() * scales.numel() / int8_weights.numel()
    )
    memory_savings = 100 * (1 - quant_bytes_per_weight / custom_weights.element_size())

    quant_info = f"""
    ## Quantization Details
    - Original Data Type: {dtype_str}
    - Quantized Data Type: int8 (8-bit)
    - Weight Pattern: {weight_pattern}

    ## Error Analysis
    - Maximum Quantization Error: {max_error:.6f}
    - Mean Quantization Error: {mean_error:.6f}
    - Memory Savings: {memory_savings:.2f}%

    ## Tensor Shapes
    - Original Weights: {custom_weights.shape}
    - Quantized Weights: {int8_weights.shape}
    - Quantization Scales: {scales.shape}
    """
    
    # Histogram of the per-output-row quantization scales (converted to
    # float32 because matplotlib can't consume float16/bfloat16 directly)
    plt.figure(figsize=(10, 6))
    plt.hist(scales.detach().float().cpu().numpy(), bins=30, alpha=0.7, color='green')
    plt.xlabel('Scale Value')
    plt.ylabel('Frequency')
    plt.title('Distribution of Quantization Scales')
    plt.grid(alpha=0.3)
    
    # Save the plot to a bytes buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    scales_vis = Image.open(buf)
    
    return quant_info, weights_vis, q_weights_vis, deq_weights_vis, dist_vis, q_dist_vis, error_vis, scales_vis

# Create Gradio interface
with gr.Blocks(title="8-Bit Weight Quantizer") as demo:
    gr.Markdown("# PyTorch 8-Bit Weight Quantizer")
    gr.Markdown("""
    This tool demonstrates quantization of neural network weights to INT8 precision.
    It implements a custom `W8A16LinearLayer` that stores 8-bit integer weights while keeping activations in higher precision (FP32/FP16/BF16).
    """)
    
    with gr.Tabs():
        with gr.TabItem("Initialize Model"):
            with gr.Row():
                with gr.Column():
                    in_feat = gr.Slider(minimum=1, maximum=512, value=16, step=1, label="Input Features")
                    out_feat = gr.Slider(minimum=1, maximum=512, value=32, step=1, label="Output Features")
                    with_bias = gr.Checkbox(value=True, label="Include Bias")
                    dtype = gr.Dropdown(choices=["float32", "float16", "bfloat16"], value="float32", label="Data Type")
                    init_btn = gr.Button("Initialize Model")
                
                with gr.Column():
                    model_info = gr.Markdown()
                    io_info = gr.Markdown()
            
            with gr.Row():
                orig_weights = gr.Image(label="Original Weights")
                quant_weights = gr.Image(label="Quantized Weights (INT8)")
                dequant_weights = gr.Image(label="Dequantized Weights")
            
            with gr.Row():
                orig_dist = gr.Image(label="Original Weight Distribution")
                quant_dist = gr.Image(label="Quantized Weight Distribution")
                error_vis = gr.Image(label="Quantization Error")
                
        with gr.TabItem("Custom Quantization"):
            with gr.Row():
                with gr.Column():
                    c_in_feat = gr.Slider(minimum=1, maximum=512, value=16, step=1, label="Input Features")
                    c_out_feat = gr.Slider(minimum=1, maximum=512, value=32, step=1, label="Output Features")
                    c_with_bias = gr.Checkbox(value=True, label="Include Bias")
                    c_dtype = gr.Dropdown(choices=["float32", "float16", "bfloat16"], value="float32", label="Data Type")
                    weight_pattern = gr.Dropdown(
                        choices=["random", "eye", "ones", "alternating", "gradient"], 
                        value="random", 
                        label="Weight Pattern"
                    )
                    quantize_btn = gr.Button("Quantize Weights")
                
                with gr.Column():
                    quant_details = gr.Markdown()
            
            with gr.Row():
                c_orig_weights = gr.Image(label="Original Weights")
                c_quant_weights = gr.Image(label="Quantized Weights (INT8)")
                c_dequant_weights = gr.Image(label="Dequantized Weights")
            
            with gr.Row():
                c_orig_dist = gr.Image(label="Original Weight Distribution")
                c_quant_dist = gr.Image(label="Quantized Weight Distribution")
                c_error_vis = gr.Image(label="Quantization Error")
            
            with gr.Row():
                scales_dist = gr.Image(label="Quantization Scales Distribution")
                
        with gr.TabItem("About"):
            gr.Markdown("""
            ## 8-bit Quantizer Implementation
            
            This implementation includes:
            
            1. **W8A16LinearLayer** - A PyTorch module that uses INT8 weights and FP16/BF16/FP32 activations
            2. **Quantization** - Converts FP32/FP16/BF16 weights to INT8 using per-output-channel scaling
            3. **Visualization** - Shows the impact of quantization on weight distributions and errors
            
            ### How It Works:
            
            1. For each output channel, find the maximum absolute weight value
            2. Scale all weights in that channel so the maximum fits in INT8 range (-128 to 127)
            3. Round scaled weights to integers and store as INT8
            4. During inference, multiply INT8 weights by scaling factors to recover approximate FP values
            
            The quantization process reduces memory usage by up to 75% compared to FP32 weights.
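
            As a concrete sketch of these steps (mirroring the `quantize` and
            `forward` methods implemented above):

            ```python
            scales = weights.abs().max(dim=-1).values / 127                      # step 1
            int8_w = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)   # steps 2-3
            approx = int8_w.float() * scales.unsqueeze(1)                        # step 4
            ```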
            
            ### References:
            
            - This implementation is based on modern techniques used in LLM quantization
            - Similar methods are used in libraries like bitsandbytes, AutoGPTQ, and GPTQ-for-LLaMa
            """)
    
    # Connect buttons to functions
    init_btn.click(
        initialize_model,
        inputs=[in_feat, out_feat, with_bias, dtype],
        outputs=[model_info, io_info, orig_weights, quant_weights, dequant_weights, orig_dist, quant_dist, error_vis]
    )
    
    quantize_btn.click(
        quantize_custom_weights,
        inputs=[c_in_feat, c_out_feat, c_with_bias, c_dtype, weight_pattern],
        outputs=[quant_details, c_orig_weights, c_quant_weights, c_dequant_weights, c_orig_dist, c_quant_dist, c_error_vis, scales_dist]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()