"""Quantized linear layers bundled with the Ouro-2.6B_smoothquant_W8A8 checkpoint."""
from __future__ import annotations
import math
from typing import Optional
import torch
from torch import Tensor, nn
def _select_quant_dtype(bits: int) -> torch.dtype:
    """Return the narrowest torch integer dtype that can hold ``bits``-bit values."""
if bits <= 0:
raise ValueError("Quantization bits must be positive.")
if bits <= 8:
return torch.int8
if bits <= 16:
return torch.int16
raise ValueError("Quantization bits above 16 are not supported.")
class QuantizedLinear(nn.Module):
    """Weight-only quantized linear layer with per-group scales, dequantized on the fly."""
def __init__(
self,
in_features: int,
out_features: int,
*,
weight_bits: int = 4,
group_size: int = 128,
bias: bool = True,
) -> None:
        super().__init__()
        if group_size <= 0:
            raise ValueError("group_size must be positive.")
        self.in_features = in_features
        self.out_features = out_features
        self.weight_bits = weight_bits
        self.group_size = group_size
        # Signed integer range implied by the bit width (kept for external
        # quantizers; the forward pass itself only dequantizes).
        self.qmin = -(2 ** (weight_bits - 1))
        self.qmax = (2 ** (weight_bits - 1)) - 1
        self.num_groups = math.ceil(in_features / group_size)
        self.quant_dtype = _select_quant_dtype(weight_bits)
weight_shape = (out_features, in_features)
scale_shape = (out_features, self.num_groups)
self.register_buffer("weight", torch.zeros(weight_shape, dtype=self.quant_dtype))
self.register_buffer(
"weight_scales", torch.ones(scale_shape, dtype=torch.float32)
)
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
self._weight_cache: Optional[Tensor] = None
    def _invalidate_cache(self) -> None:
        """Drop the cached dequantized weight so it is rebuilt on next use."""
        self._weight_cache = None
    def refresh_weight_cache(self) -> None:
        """Eagerly rebuild the dequantized floating-point weight."""
        self._weight_cache = self._dequantize_weight()
    def _dequantize_weight(self) -> Tensor:
        """Expand the integer weight back to float32, one scale per input group."""
        group_tensors = []
        for group_idx in range(self.num_groups):
            start = group_idx * self.group_size
            # The final group may be narrower when in_features is not a
            # multiple of group_size.
            end = min((group_idx + 1) * self.group_size, self.in_features)
            block = self.weight[:, start:end].float()
            scale = self.weight_scales[:, group_idx].unsqueeze(1)
            group_tensors.append(block * scale)
        return torch.cat(group_tensors, dim=1)
    def forward(self, input: Tensor) -> Tensor:
        # Lazily dequantize the weight and keep the cache on the input's device.
        if self._weight_cache is None or self._weight_cache.device != input.device:
            self.refresh_weight_cache()
            self._weight_cache = self._weight_cache.to(input.device)
        weight = self._weight_cache
        if weight.dtype != input.dtype:
            weight = weight.to(input.dtype)
        bias = self.bias
        if bias is not None and (bias.device != input.device or bias.dtype != input.dtype):
            bias = bias.to(device=input.device, dtype=input.dtype)
        return nn.functional.linear(input, weight, bias)
    def load_quant_state(self, weight: Tensor, weight_scales: Tensor) -> None:
        """Copy pre-quantized weights and per-group scales into the layer's buffers."""
if weight.shape != self.weight.shape:
raise ValueError(
f"Quantized weight shape mismatch: expected {tuple(self.weight.shape)}, "
f"got {tuple(weight.shape)}"
)
if weight_scales.shape != self.weight_scales.shape:
raise ValueError(
f"Scale tensor shape mismatch: expected {tuple(self.weight_scales.shape)}, "
f"got {tuple(weight_scales.shape)}"
)
self.weight.copy_(weight.to(dtype=self.quant_dtype))
self.weight_scales.copy_(weight_scales.to(dtype=torch.float32))
self._invalidate_cache()
def extra_repr(self) -> str:
return (
f"in_features={self.in_features}, out_features={self.out_features}, "
f"group_size={self.group_size}, bits={self.weight_bits}, bias={self.bias is not None}"
)
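# Illustrative helper, not part of the bundled sources: a minimal sketch of
# symmetric per-group weight quantization that produces tensors in the layout
# `QuantizedLinear.load_quant_state` expects. The name
# `quantize_weight_per_group` and the absmax-based scale choice are
# assumptions made for this example.
def quantize_weight_per_group(
    weight: Tensor, weight_bits: int, group_size: int
) -> tuple[Tensor, Tensor]:
    out_features, in_features = weight.shape
    qmin = -(2 ** (weight_bits - 1))
    qmax = (2 ** (weight_bits - 1)) - 1
    num_groups = math.ceil(in_features / group_size)
    qweight = torch.empty(out_features, in_features, dtype=_select_quant_dtype(weight_bits))
    scales = torch.ones(out_features, num_groups, dtype=torch.float32)
    for group_idx in range(num_groups):
        start = group_idx * group_size
        end = min(start + group_size, in_features)
        block = weight[:, start:end].float()
        # Symmetric absmax scale per (row, group); guard against all-zero rows.
        scale = block.abs().amax(dim=1).clamp(min=1e-8) / qmax
        scales[:, group_idx] = scale
        qweight[:, start:end] = (
            torch.round(block / scale.unsqueeze(1)).clamp(qmin, qmax).to(qweight.dtype)
        )
    return qweight, scales
# Hypothetical usage: round-trip a float weight into a QuantizedLinear.
#     layer = QuantizedLinear(512, 256, weight_bits=4, group_size=128)
#     qw, sc = quantize_weight_per_group(torch.randn(256, 512), 4, 128)
#     layer.load_quant_state(qw, sc)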
class SmoothQuantLinear(nn.Module):
    """Linear layer with SmoothQuant-style fake quantization (W8A8 by default)."""
def __init__(
self,
in_features: int,
out_features: int,
*,
weight_bits: int = 8,
activation_bits: int = 8,
bias: bool = True,
) -> None:
super().__init__()
if weight_bits <= 0 or weight_bits > 16:
raise ValueError("Weight bits must be in range [1, 16].")
if activation_bits <= 0 or activation_bits > 16:
raise ValueError("Activation bits must be in range [1, 16].")
self.in_features = in_features
self.out_features = out_features
self.weight_bits = weight_bits
self.activation_bits = activation_bits
self.weight_qmin = -(2 ** (weight_bits - 1))
self.weight_qmax = (2 ** (weight_bits - 1)) - 1
self.activation_qmin = -(2 ** (activation_bits - 1))
self.activation_qmax = (2 ** (activation_bits - 1)) - 1
self.quant_dtype = _select_quant_dtype(weight_bits)
weight_shape = (out_features, in_features)
self.register_buffer("weight", torch.zeros(weight_shape, dtype=self.quant_dtype))
self.register_buffer(
"weight_scales", torch.ones(out_features, 1, dtype=torch.float32)
)
self.register_buffer(
"input_scale", torch.ones(in_features, dtype=torch.float32)
)
self.register_buffer(
"activation_scale", torch.ones(in_features, dtype=torch.float32)
)
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
self._weight_cache: Optional[Tensor] = None
    def _invalidate_cache(self) -> None:
        """Drop the cached dequantized weight so it is rebuilt on next use."""
        self._weight_cache = None
    def refresh_weight_cache(self) -> None:
        """Eagerly rebuild the dequantized float32 weight from the integer buffers."""
        self._weight_cache = self.weight.float() * self.weight_scales
    def forward(self, input: Tensor) -> Tensor:
        # Lazily dequantize the weight and keep the cache on the input's device.
        if self._weight_cache is None or self._weight_cache.device != input.device:
            self.refresh_weight_cache()
            self._weight_cache = self._weight_cache.to(input.device)
        # Match scale dtypes to the input so half-precision inputs are not
        # silently promoted to float32 by the float32 scale buffers.
        activation_scale = self.activation_scale.to(device=input.device, dtype=input.dtype)
        input_scale = self.input_scale.to(device=input.device, dtype=input.dtype)
        # SmoothQuant: rescale activations per input channel, then fake-quantize
        # (round, clamp, dequantize) to emulate integer activation arithmetic.
        scaled_input = input * input_scale
        quantized = torch.round(scaled_input / activation_scale).clamp(
            self.activation_qmin, self.activation_qmax
        )
        dequant_input = quantized * activation_scale
        weight = self._weight_cache
        if weight.dtype != dequant_input.dtype:
            weight = weight.to(dequant_input.dtype)
        bias = self.bias
        if bias is not None and (
            bias.device != input.device or bias.dtype != dequant_input.dtype
        ):
            bias = bias.to(device=input.device, dtype=dequant_input.dtype)
        return nn.functional.linear(dequant_input, weight, bias)
    def load_quant_state(
        self,
        weight: Tensor,
        weight_scales: Tensor,
        input_scale: Tensor,
        activation_scale: Tensor,
    ) -> None:
        """Copy pre-quantized weights and calibration scales into the layer's buffers."""
if weight.shape != self.weight.shape:
raise ValueError(
f"Quantized weight shape mismatch: expected {tuple(self.weight.shape)}, "
f"got {tuple(weight.shape)}"
)
if weight_scales.shape != self.weight_scales.shape:
raise ValueError(
f"Weight scale shape mismatch: expected {tuple(self.weight_scales.shape)}, "
f"got {tuple(weight_scales.shape)}"
)
if input_scale.shape != self.input_scale.shape:
raise ValueError(
f"Input scale shape mismatch: expected {tuple(self.input_scale.shape)}, "
f"got {tuple(input_scale.shape)}"
)
if activation_scale.shape != self.activation_scale.shape:
raise ValueError(
f"Activation scale shape mismatch: expected {tuple(self.activation_scale.shape)}, "
f"got {tuple(activation_scale.shape)}"
)
self.weight.copy_(weight.to(dtype=self.quant_dtype))
self.weight_scales.copy_(weight_scales.to(dtype=torch.float32))
self.input_scale.copy_(input_scale.to(dtype=torch.float32))
self.activation_scale.copy_(activation_scale.to(dtype=torch.float32))
self._invalidate_cache()
def extra_repr(self) -> str:
return (
f"in_features={self.in_features}, out_features={self.out_features}, "
f"weight_bits={self.weight_bits}, activation_bits={self.activation_bits}, "
f"bias={self.bias is not None}"
)
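# Illustrative helper, not part of the bundled sources: a minimal
# SmoothQuant-style calibration sketch that derives the four tensors
# `SmoothQuantLinear.load_quant_state` expects from a float weight and a batch
# of calibration activations. The function name, the `alpha` parameter, and
# the per-channel (rather than per-tensor) activation scales are assumptions
# made for this example; the pipeline that produced the checkpoint may differ.
def smoothquant_calibrate(
    weight: Tensor,
    calib_activations: Tensor,
    *,
    weight_bits: int = 8,
    activation_bits: int = 8,
    alpha: float = 0.5,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    weight_qmax = (2 ** (weight_bits - 1)) - 1
    activation_qmax = (2 ** (activation_bits - 1)) - 1
    # Expects activations flattened to (num_tokens, in_features).
    act_max = calib_activations.abs().amax(dim=0).clamp(min=1e-8)
    w_max = weight.abs().amax(dim=0).clamp(min=1e-8)
    # SmoothQuant smoothing factor s = act_max**alpha / w_max**(1 - alpha);
    # activations are divided by s and weights multiplied by s. `input_scale`
    # plays the role of 1 / s because forward() multiplies the input by it.
    smoothing = (act_max ** alpha) / (w_max ** (1.0 - alpha))
    input_scale = 1.0 / smoothing
    smoothed_weight = weight.float() * smoothing
    # Symmetric per-output-channel weight scales over the smoothed weight.
    weight_scales = (
        smoothed_weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / weight_qmax
    )
    qweight = (
        torch.round(smoothed_weight / weight_scales)
        .clamp(-(weight_qmax + 1), weight_qmax)
        .to(_select_quant_dtype(weight_bits))
    )
    # Quantization step for the *smoothed* activations (x / s).
    activation_scale = (act_max * input_scale).clamp(min=1e-8) / activation_qmax
    return qweight, weight_scales, input_scale, activation_scale
# Hypothetical usage: calibrate and load a W8A8 layer.
#     layer = SmoothQuantLinear(512, 256)
#     qw, ws, in_s, act_s = smoothquant_calibrate(
#         torch.randn(256, 512), torch.randn(1024, 512)
#     )
#     layer.load_quant_state(qw, ws, in_s, act_s)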