from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor, nn


@dataclass
class ObserverState:
    """Activation statistics collected for a linear layer."""

    max_abs_values: Tensor


class LinearInputObserver:
    """
    Collects per-feature activation maxima for a `nn.Linear` module.

    The statistics are later used to derive quantization scales that take
    the input distribution into account, mimicking the behaviour of AWQ.
    """

    def __init__(self, module_name: str):
        self.module_name = module_name
        self._max_abs: Optional[Tensor] = None

    def __call__(self, module: nn.Module, inputs: tuple[Tensor, ...]) -> None:
        if not inputs:
            return
        data = inputs[0]
        if data is None:
            return
        if data.dim() > 2:
            data = data.reshape(-1, data.size(-1))
        elif data.dim() < 2:
            data = data.unsqueeze(0)
        data = data.detach()
        if data.dtype in (torch.float16, torch.bfloat16):
            data = data.to(torch.float32)
        max_vals = data.abs().amax(dim=0)
        if self._max_abs is None:
            self._max_abs = max_vals
        else:
            # Pad in case dimensionality changes due to mixed inputs.
            if max_vals.size(0) != self._max_abs.size(0):
                target = max(self._max_abs.size(0), max_vals.size(0))
                self._max_abs = torch.nn.functional.pad(
                    self._max_abs,
                    (0, target - self._max_abs.size(0)),
                    value=0.0,
                )
                max_vals = torch.nn.functional.pad(
                    max_vals,
                    (0, target - max_vals.size(0)),
                    value=0.0,
                )
            self._max_abs = torch.maximum(self._max_abs, max_vals)

    def to_state(self) -> ObserverState:
        if self._max_abs is None:
            raise RuntimeError(
                f"No activation statistics recorded for module '{self.module_name}'."
            )
        return ObserverState(max_abs_values=self._max_abs)
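

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; everything below apart from
# LinearInputObserver, ObserverState, torch, and nn is a hypothetical
# placeholder). The observer's (module, inputs) call signature matches what
# `register_forward_pre_hook` passes to its hook, so a minimal calibration
# loop, assuming a toy model and random stand-in calibration data, might
# look like this:
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

    # Attach one observer per nn.Linear, keeping the hook handles so they
    # can be removed once calibration is done.
    observers: dict[str, LinearInputObserver] = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            observer = LinearInputObserver(name)
            handles.append(module.register_forward_pre_hook(observer))
            observers[name] = observer

    # Run a few calibration batches; the hooks record per-feature maxima.
    with torch.no_grad():
        for _ in range(4):
            model(torch.randn(8, 16))

    for handle in handles:
        handle.remove()

    # Collect the statistics, e.g. to derive AWQ-style scales downstream.
    for name, observer in observers.items():
        state = observer.to_state()
        print(name, state.max_abs_values.shape)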