| | |
| | |
| | |
| | |
| |
|
| | from .em import EM, EmptyClusterResolveError |
| |
|
| |
|
class PQ(EM):
    """
    Quantizes the layer weights W with the standard Product Quantization
    technique. This learns a codebook of codewords or centroids of size
    block_size from W. For further reference on using PQ to quantize
    neural networks, see "And the Bit Goes Down: Revisiting the Quantization
    of Neural Networks", Stock et al., ICLR 2020.

    PQ is performed in two steps:
    (1) The matrix W (weights of a fully-connected or convolutional layer)
        is reshaped to (block_size, -1).
            - If W is fully-connected (2D), its columns are split into
              blocks of size block_size.
            - If W is convolutional (4D), its filters are split along the
              spatial dimension.
    (2) We apply the standard EM/k-means algorithm to the resulting reshaped matrix.

    Args:
        - W: weight matrix to quantize of size (in_features x out_features)
        - block_size: size of the blocks (subvectors)
        - n_centroids: number of centroids
        - n_iter: number of k-means iterations
        - eps: for cluster reassignment when an empty cluster is found
        - max_tentatives: number of tentatives for cluster reassignment when
          an empty cluster is found
        - verbose: print information after each iteration

    Remarks:
        - block_size must be compatible with the shape of W
    """

    def __init__(
        self,
        W,
        block_size,
        n_centroids=256,
        n_iter=20,
        eps=1e-6,
        max_tentatives=30,
        verbose=True,
    ):
        self.block_size = block_size
        # Reshape first: _reshape also records the original shape attributes
        # (out_features / out_channels, ...) that decode() needs later.
        W_reshaped = self._reshape(W)
        super().__init__(
            W_reshaped,
            n_centroids=n_centroids,
            n_iter=n_iter,
            eps=eps,
            max_tentatives=max_tentatives,
            verbose=verbose,
        )

    def _reshape(self, W):
        """
        Reshapes the matrix W as explained in step (1).

        Returns a tensor of shape (block_size, n_blocks) whose columns are
        the subvectors to be clustered. Side effect: stores the original
        dimensions of W on self so decode() can invert the reshape.
        """

        # Fully-connected layer: split each row into subvectors of block_size.
        if len(W.size()) == 2:
            self.out_features, self.in_features = W.size()
            assert (
                self.in_features % self.block_size == 0
            ), "Linear: in_features must be a multiple of block_size"
            return (
                W.reshape(self.out_features, -1, self.block_size)
                .permute(2, 1, 0)
                .flatten(1, 2)
            )

        # Convolutional layer: split each filter along the flattened
        # (in_channels, k_h, k_w) dimension.
        elif len(W.size()) == 4:
            self.out_channels, self.in_channels, self.k_h, self.k_w = W.size()
            assert (
                self.in_channels * self.k_h * self.k_w
            ) % self.block_size == 0, (
                "Conv2d: in_channels * k_h * k_w must be a multiple of block_size"
            )
            return (
                W.reshape(self.out_channels, -1, self.block_size)
                .permute(2, 1, 0)
                .flatten(1, 2)
            )

        else:
            raise NotImplementedError(W.size())

    def encode(self):
        """
        Performs self.n_iter EM steps.

        Stops early if an empty cluster cannot be resolved.
        """

        self.initialize_centroids()
        for i in range(self.n_iter):
            try:
                self.step(i)
            except EmptyClusterResolveError:
                break

    def decode(self):
        """
        Returns the encoded full weight matrix. Must be called after
        the encode function.
        """

        # Linear case: k_h is only set by _reshape for 4D (conv) weights,
        # so its absence identifies a fully-connected layer.
        if "k_h" not in self.__dict__:
            return (
                self.centroids[self.assignments]
                .reshape(-1, self.out_features, self.block_size)
                .permute(1, 0, 2)
                .flatten(1, 2)
            )

        # Conv case: restore the original 4D filter shape.
        else:
            return (
                self.centroids[self.assignments]
                .reshape(-1, self.out_channels, self.block_size)
                .permute(1, 0, 2)
                .reshape(self.out_channels, self.in_channels, self.k_h, self.k_w)
            )
|