# rayf-07's picture
# Upload Ouro-2.6B_smoothquant_W8A8 with bundled source code
# b144856 verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor, nn
@dataclass
class ObserverState:
    """Activation statistics collected for a linear layer."""
    # Per-input-feature maximum of |activation| accumulated across all
    # observed batches (1-D tensor of length = feature count).
    max_abs_values: Tensor
class LinearInputObserver:
    """
    Forward pre-hook that records per-feature activation maxima for a
    `nn.Linear` module.

    The recorded maxima are later turned into quantization scales that
    account for the input distribution, in the spirit of AWQ.
    """

    def __init__(self, module_name: str):
        # Name of the observed module; only used in error reporting.
        self.module_name = module_name
        # Running per-feature |x| maximum; stays None until a batch is seen.
        self._max_abs: Optional[Tensor] = None

    @staticmethod
    def _zero_pad(vec: Tensor, length: int) -> Tensor:
        """Right-pad a 1-D tensor with zeros up to `length` elements."""
        return torch.nn.functional.pad(vec, (0, length - vec.size(0)), value=0.0)

    def __call__(self, module: nn.Module, inputs: tuple[Tensor, ...]) -> None:
        """Fold the current batch's per-feature |x| maxima into the running stats."""
        if not inputs or inputs[0] is None:
            return

        # Normalise the input to a 2-D (samples, features) view.
        x = inputs[0]
        if x.dim() > 2:
            x = x.reshape(-1, x.size(-1))
        elif x.dim() < 2:
            x = x.unsqueeze(0)
        x = x.detach()

        # Accumulate in float32 so half-precision inputs don't lose range.
        if x.dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)

        batch_max = x.abs().amax(dim=0)
        if self._max_abs is None:
            self._max_abs = batch_max
            return

        # Pad in case dimensionality changes due to mixed inputs.
        if batch_max.size(0) != self._max_abs.size(0):
            width = max(self._max_abs.size(0), batch_max.size(0))
            self._max_abs = self._zero_pad(self._max_abs, width)
            batch_max = self._zero_pad(batch_max, width)
        self._max_abs = torch.maximum(self._max_abs, batch_max)

    def to_state(self) -> ObserverState:
        """Snapshot the collected statistics; raises if nothing was recorded."""
        if self._max_abs is None:
            raise RuntimeError(
                f"No activation statistics recorded for module '{self.module_name}'."
            )
        return ObserverState(max_abs_values=self._max_abs)