from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor, nn


@dataclass
class ObserverState:
    """Activation statistics collected for a linear layer."""

    max_abs_values: Tensor


class LinearInputObserver:
    """
    Collects per-feature activation maxima for a `nn.Linear` module.

    The statistics are later used to derive quantization scales that take the
    input distribution into account, mimicking the behaviour of AWQ.
    """

    def __init__(self, module_name: str):
        self.module_name = module_name
        self._max_abs: Optional[Tensor] = None

    def __call__(self, module: nn.Module, inputs: tuple[Tensor, ...]) -> None:
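        # Runs as a forward pre-hook; attach with module.register_forward_pre_hook(self).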
        if not inputs:
            return

        data = inputs[0]
        if data is None:
            return

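        # Collapse any leading batch/sequence dims so rows are samples and the
        # last dimension indexes input features.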
        if data.dim() > 2:
            data = data.reshape(-1, data.size(-1))
        elif data.dim() < 2:
            data = data.unsqueeze(0)

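        # Detach from autograd and upcast half-precision inputs so the running
        # maxima accumulate in float32.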
        data = data.detach()
        if data.dtype in (torch.float16, torch.bfloat16):
            data = data.to(torch.float32)

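        # Largest absolute activation seen per input feature in this batch.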
        max_vals = data.abs().amax(dim=0)
        if self._max_abs is None:
            self._max_abs = max_vals
        else:
            # Pad to a common length in case the feature dimension differs
            # between calls (e.g. mixed inputs of different widths).
            if max_vals.size(0) != self._max_abs.size(0):
                target = max(self._max_abs.size(0), max_vals.size(0))
                self._max_abs = torch.nn.functional.pad(
                    self._max_abs,
                    (0, target - self._max_abs.size(0)),
                    value=0.0,
                )
                max_vals = torch.nn.functional.pad(
                    max_vals,
                    (0, target - max_vals.size(0)),
                    value=0.0,
                )
            self._max_abs = torch.maximum(self._max_abs, max_vals)

    def to_state(self) -> ObserverState:
        if self._max_abs is None:
            raise RuntimeError(
                f"No activation statistics recorded for module '{self.module_name}'."
            )
        return ObserverState(max_abs_values=self._max_abs)
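

# Usage sketch (illustrative; not part of the module API above). It registers
# an observer on every nn.Linear, runs a toy calibration pass, and derives
# per-feature scales from the recorded maxima. The toy model, the random
# calibration batches, and the fixed exponent 0.5 are assumptions made for
# demonstration; AWQ-style pipelines typically search the exponent against a
# reconstruction loss instead of fixing it.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

    observers: dict[str, LinearInputObserver] = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            observers[name] = LinearInputObserver(name)
            handles.append(module.register_forward_pre_hook(observers[name]))

    # Stand-in calibration data: a few random batches through the model.
    with torch.no_grad():
        for _ in range(4):
            model(torch.randn(8, 16))

    # Hooks are no longer needed once statistics have been collected.
    for handle in handles:
        handle.remove()

    for name, observer in observers.items():
        state = observer.to_state()
        # One simple AWQ-like choice: scale each input feature by
        # max_abs ** 0.5, clamped to avoid zero scales for features that
        # never activated during calibration.
        scales = state.max_abs_values.clamp(min=1e-5).pow(0.5)
        print(f"{name}: scales shape {tuple(scales.shape)}")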