| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | import torch |
| | from torch import Tensor, nn |
| |
|
| |
|
@dataclass
class ObserverState:
    """Immutable snapshot of the activation statistics gathered for one linear layer."""

    # Running per-input-feature maximum of the absolute activation values,
    # as accumulated by `LinearInputObserver`.
    max_abs_values: Tensor
| |
|
| |
|
class LinearInputObserver:
    """
    Collects per-feature activation maxima for a `nn.Linear` module.

    Intended to be registered as a forward pre-hook: each call folds the
    absolute maxima of the incoming activations into a running per-feature
    maximum. The statistics are later used to derive quantization scales that
    take input distribution into account, mimicking the behaviour of AWQ.
    """

    def __init__(self, module_name: str):
        # Name is used only for error reporting; no module lookup happens here.
        self.module_name = module_name
        # Running per-feature |x| maximum; None until the first observation.
        self._max_abs: Optional[Tensor] = None

    def __call__(self, module: nn.Module, inputs: tuple[Tensor, ...]) -> None:
        """Record per-feature absolute maxima from one forward call.

        Matches the ``nn.Module`` forward pre-hook signature. Empty input
        tuples and ``None`` first arguments are silently ignored.
        """
        if not inputs:
            return

        data = inputs[0]
        if data is None:
            return

        # Normalise to 2-D (rows, features). reshape(1, -1) also covers the
        # 0-d scalar case; the previous unsqueeze(0) left a scalar 1-D only
        # after reduction, which made the size(0) lookups below raise
        # IndexError on the next observation.
        if data.dim() > 2:
            data = data.reshape(-1, data.size(-1))
        elif data.dim() < 2:
            data = data.reshape(1, -1)

        # Detach so the stored statistics never keep the autograd graph alive.
        data = data.detach()
        if data.dtype in (torch.float16, torch.bfloat16):
            # Accumulate in float32 to avoid half-precision rounding of maxima.
            data = data.to(torch.float32)

        max_vals = data.abs().amax(dim=0)
        if self._max_abs is None:
            self._max_abs = max_vals
        else:
            # Feature count changed between calls (e.g. a variable trailing
            # dimension): zero-pad the shorter vector so the element-wise
            # maximum is well defined. Zero can never win an abs-max compare,
            # so padding does not distort the statistics.
            if max_vals.size(0) != self._max_abs.size(0):
                target = max(self._max_abs.size(0), max_vals.size(0))
                self._max_abs = torch.nn.functional.pad(
                    self._max_abs,
                    (0, target - self._max_abs.size(0)),
                    value=0.0,
                )
                max_vals = torch.nn.functional.pad(
                    max_vals,
                    (0, target - max_vals.size(0)),
                    value=0.0,
                )
            self._max_abs = torch.maximum(self._max_abs, max_vals)

    def to_state(self) -> ObserverState:
        """Return the collected statistics as an :class:`ObserverState`.

        Raises:
            RuntimeError: if no activations were ever observed for this module.
        """
        if self._max_abs is None:
            raise RuntimeError(
                f"No activation statistics recorded for module '{self.module_name}'."
            )
        return ObserverState(max_abs_values=self._max_abs)
| |
|
| |
|