# NOTE(review): the three lines below are Hugging Face upload-page residue,
# not Python source; commented out so the module parses. Kept verbatim:
# rayf-07's picture
# Upload Ouro-2.6B_smoothquant_W8A8 with bundled source code
# b144856 verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict
import torch
from torch import Tensor, nn
from .config import QuantizationConfig
from .observers import ObserverState
@dataclass
class SmoothQuantWeights:
    """Quantized representation for SmoothQuant linear layers.

    Produced by ``quantize_linear_smooth``; all tensors are returned on CPU.
    """

    # Integer-quantized, smoothing-scaled weight matrix (int8, or int16 when
    # weight_bits > 8), shape (out_features, in_features).
    weight: Tensor
    # Per-output-channel symmetric weight scales, float32, shape (out_features, 1).
    weight_scales: Tensor
    # Per-input-channel factor (1 / smoothing_factor) to multiply activations
    # by at runtime before quantization, float32, shape (in_features,).
    input_scale: Tensor
    # Per-input-channel activation quantization scales (smoothed max / qmax),
    # float32, shape (in_features,).
    activation_scale: Tensor
def _prepare_stats(
observer_state: ObserverState,
weight: Tensor,
epsilon: float,
) -> tuple[Tensor, Tensor]:
activation_stats = observer_state.max_abs_values.to(dtype=torch.float32)
if activation_stats.numel() < weight.size(1):
activation_stats = torch.nn.functional.pad(
activation_stats,
(0, weight.size(1) - activation_stats.numel()),
value=1.0,
)
activation_stats = activation_stats[: weight.size(1)]
activation_stats = activation_stats.clamp_min(epsilon)
weight_stats = weight.abs().amax(dim=0).clamp_min(epsilon)
return activation_stats, weight_stats
def quantize_linear_smooth(
    module: nn.Linear,
    observer_state: ObserverState,
    config: QuantizationConfig,
) -> SmoothQuantWeights:
    """
    Apply SmoothQuant to a linear layer, producing int quantized weights and activation scales.
    """
    eps = config.epsilon

    # Work on a detached float32 copy so the live module is never mutated.
    fp_weight = module.weight.detach().to(torch.float32).clone()
    act_stats, col_stats = _prepare_stats(observer_state, fp_weight, eps)

    # Per-input-channel smoothing factor s_j = (max|x_j| / max|w_j|)^alpha:
    # migrates quantization difficulty from activations into the weights.
    smooth = (act_stats / col_stats).pow(config.alpha).clamp_min(eps)
    # Runtime multiplier for activations (the inverse of the weight scaling).
    inv_smooth = smooth.reciprocal().to(torch.float32)

    # Fold the smoothing factor into the weights, column-wise.
    balanced = fp_weight * smooth

    # Symmetric activation scales from the smoothed activation maxima.
    a_qmax = (1 << (config.activation_bits - 1)) - 1
    act_scales = (act_stats * inv_smooth / a_qmax).clamp_min(eps)

    # Symmetric per-output-channel weight quantization.
    w_qmax = (1 << (config.weight_bits - 1)) - 1
    row_max = balanced.abs().max(dim=1).values.clamp_min(eps)
    w_scales = (row_max / w_qmax).unsqueeze(1)

    storage_dtype = torch.int8 if config.weight_bits <= 8 else torch.int16
    q_weight = (
        (balanced / w_scales)
        .round()
        .clamp(-(w_qmax + 1), w_qmax)
        .to(storage_dtype)
    )

    return SmoothQuantWeights(
        weight=q_weight.cpu(),
        weight_scales=w_scales.to(torch.float32).cpu(),
        input_scale=inv_smooth.cpu(),
        activation_scale=act_scales.cpu(),
    )
def summarize_smoothquant(
    stats: Dict[str, SmoothQuantWeights]
) -> Dict[str, Dict[str, float]]:
    """Collapse each layer's SmoothQuant record into scalar diagnostics.

    Returns, per layer name, the mean weight scale and mean activation scale
    as plain Python floats (convenient for logging/JSON).
    """
    return {
        layer: {
            "weight_scale_mean": float(record.weight_scales.mean()),
            "activation_scale_mean": float(record.activation_scale.mean()),
        }
        for layer, record in stats.items()
    }