File size: 3,013 Bytes
b144856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict

import torch
from torch import Tensor, nn

from .config import QuantizationConfig
from .observers import ObserverState


@dataclass()
class SmoothQuantWeights:
    """Bundle of tensors describing one SmoothQuant-quantized linear layer.

    When produced by ``quantize_linear_smooth`` every tensor has been moved
    to the CPU.
    """

    # Integer codes of the smoothed weight matrix, one row per output feature.
    weight: Tensor
    # Per-output-row dequantization step, shaped (out_features, 1).
    weight_scales: Tensor
    # Per-input-channel multiplier (reciprocal of the smoothing factor) for activations.
    input_scale: Tensor
    # Per-input-channel activation quantization step measured after smoothing.
    activation_scale: Tensor


def _prepare_stats(
    observer_state: ObserverState,
    weight: Tensor,
    epsilon: float,
) -> tuple[Tensor, Tensor]:
    activation_stats = observer_state.max_abs_values.to(dtype=torch.float32)
    if activation_stats.numel() < weight.size(1):
        activation_stats = torch.nn.functional.pad(
            activation_stats,
            (0, weight.size(1) - activation_stats.numel()),
            value=1.0,
        )
    activation_stats = activation_stats[: weight.size(1)]
    activation_stats = activation_stats.clamp_min(epsilon)

    weight_stats = weight.abs().amax(dim=0).clamp_min(epsilon)
    return activation_stats, weight_stats


def quantize_linear_smooth(
    module: nn.Linear,
    observer_state: ObserverState,
    config: QuantizationConfig,
) -> SmoothQuantWeights:
    """
    Apply SmoothQuant to a linear layer, producing int quantized weights and activation scales.

    The per-input-channel smoothing factor follows the SmoothQuant formula::

        s_j = max|X_j| ** alpha / max|W_j| ** (1 - alpha)

    Activations are multiplied by ``input_scale = 1 / s`` and weights by ``s``,
    so ``(X * input_scale) @ (W * s).T`` equals ``X @ W.T`` before rounding.

    Args:
        module: linear layer whose weight, shaped (out_features, in_features),
            is smoothed and quantized. The bias is not touched.
        observer_state: calibration statistics; ``max_abs_values`` supplies the
            observed activation magnitudes (via ``_prepare_stats``).
        config: provides ``weight_bits``, ``activation_bits``, ``epsilon`` and
            ``alpha`` (migration strength).

    Returns:
        SmoothQuantWeights with CPU tensors: symmetric per-output-row integer
        weights (int8 for <= 8 bits, int16 for <= 16 bits, otherwise int32),
        their float32 scales shaped (out_features, 1), the per-input-channel
        ``input_scale`` and the per-input-channel ``activation_scale``.

    Raises:
        ValueError: if ``weight_bits`` or ``activation_bits`` is below 2, since
            the symmetric range would be empty and every scale would divide by 0.
    """

    weight_bits = config.weight_bits
    activation_bits = config.activation_bits
    epsilon = config.epsilon
    alpha = config.alpha
    if weight_bits < 2 or activation_bits < 2:
        raise ValueError(
            "weight_bits and activation_bits must be >= 2, got "
            f"weight_bits={weight_bits}, activation_bits={activation_bits}"
        )
    if weight_bits <= 8:
        quant_dtype = torch.int8
    elif weight_bits <= 16:
        quant_dtype = torch.int16
    else:
        # Wider codes would silently wrap around in int16.
        quant_dtype = torch.int32

    weight = module.weight.detach().to(torch.float32).clone()
    activation_stats, weight_stats = _prepare_stats(observer_state, weight, epsilon)

    # Migrate an `alpha` share of the quantization difficulty from the
    # activations into the weights (SmoothQuant, Eq. 4).
    smoothing_factor = (
        activation_stats.pow(alpha) / weight_stats.pow(1.0 - alpha)
    ).clamp_min(epsilon)

    input_scale = (1.0 / smoothing_factor).to(torch.float32)
    scaled_weight = weight * smoothing_factor.unsqueeze(0)

    # Activation range per channel after smoothing: max|X_j| / s_j.
    act_max_scaled = activation_stats * input_scale
    act_qmax = (2 ** (activation_bits - 1)) - 1
    activation_scale = (act_max_scaled / act_qmax).clamp_min(epsilon)

    # Symmetric per-output-row weight quantization.
    weight_qmin = -(2 ** (weight_bits - 1))
    weight_qmax = (2 ** (weight_bits - 1)) - 1
    weight_max = scaled_weight.abs().amax(dim=1).clamp_min(epsilon)
    weight_scales = (weight_max / weight_qmax).unsqueeze(1)

    quantized_weight = (
        torch.round(scaled_weight / weight_scales)
        .clamp(weight_qmin, weight_qmax)
        .to(quant_dtype)
    )

    return SmoothQuantWeights(
        weight=quantized_weight.cpu(),
        weight_scales=weight_scales.to(torch.float32).cpu(),
        input_scale=input_scale.cpu(),
        activation_scale=activation_scale.cpu(),
    )


def summarize_smoothquant(
    stats: Dict[str, SmoothQuantWeights]
) -> Dict[str, Dict[str, float]]:
    """Condense each layer's SmoothQuant tensors into two mean scale values.

    Args:
        stats: mapping from layer name to its SmoothQuantWeights record.

    Returns:
        A dict with the same keys, in the same order, each mapping to
        ``{"weight_scale_mean": ..., "activation_scale_mean": ...}`` as
        Python floats.
    """
    return {
        layer_name: {
            "weight_scale_mean": float(layer.weight_scales.mean()),
            "activation_scale_mean": float(layer.activation_scale.mean()),
        }
        for layer_name, layer in stats.items()
    }