|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Binarizers take a real-valued matrix as input and produce a binary mask (values in {0, 1}) of the same shape. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from torch import autograd |
|
|
|
|
|
|
|
|
class ThresholdBinarizer(autograd.Function):
    r"""
    Threshold binarizer.

    Computes a binary mask M from a real-valued matrix S such that `M_{i,j} = 1` if and only if
    `S_{i,j} > \tau`, where `\tau` is a real-valued threshold.

    Implementation is inspired from:
        https://github.com/arunmallya/piggyback
        Piggyback: Adapting a Single Network to Multiple Tasks by Learning to Mask Weights
        Arun Mallya, Dillon Davis, Svetlana Lazebnik
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, threshold: float, sigmoid: bool):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input matrix from which the binarizer computes the binary mask.
            threshold (`float`)
                The threshold value (in R).
            sigmoid (`bool`)
                If set to ``True``, we apply the sigmoid function to the `inputs` matrix before comparing to
                `threshold`. In this case, `threshold` should be a value between 0 and 1.
        Returns:
            mask (`torch.FloatTensor`)
                Binary matrix of the same size and dtype as `inputs` acting as a mask (1 - the associated
                weight is retained, 0 - the associated weight is pruned).
        """
        nb_elems = inputs.numel()
        # Floor on the number of retained entries: 0.5% of the elements, and never fewer than one.
        nb_min = int(0.005 * nb_elems) + 1

        scores = torch.sigmoid(inputs) if sigmoid else inputs
        mask = (scores > threshold).to(inputs.dtype)

        if mask.sum() < nb_min:
            if nb_elems <= nb_min:
                # Degenerate sizes (empty or single-element tensors): the floor covers every entry, so
                # keep them all. `kthvalue` would fail on an empty tensor, and the strict comparison
                # below would otherwise prune the only element.
                mask = torch.ones_like(inputs)
            else:
                # The threshold pruned too much: fall back to keeping the entries strictly greater than
                # the (nb_elems - nb_min)-th smallest one. The raw inputs are compared even when `sigmoid`
                # is set, since the sigmoid is monotonic. Ties at the cut-off value are excluded by the
                # strict comparison, so fewer than `nb_min` entries may remain in that case.
                k_threshold = inputs.flatten().kthvalue(nb_elems - nb_min).values
                mask = (inputs > k_threshold).to(inputs.dtype)
        return mask

    @staticmethod
    def backward(ctx, gradOutput):
        """
        Straight-through gradient: the incoming gradient is passed unchanged to `inputs`;
        `threshold` and `sigmoid` receive no gradient.
        """
        return gradOutput, None, None
|
|
|
|
|
|
|
|
class TopKBinarizer(autograd.Function):
    """
    Top-k Binarizer.

    Computes a binary mask M from a real-valued matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
    is among the k% highest values of S.

    Implementation is inspired from:
        https://github.com/allenai/hidden-networks
        What's hidden in a randomly weighted neural network?
        Vivek Ramanujan*, Mitchell Wortsman*, Aniruddha Kembhavi, Ali Farhadi, Mohammad Rastegari
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, threshold: float):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input matrix from which the binarizer computes the binary mask.
            threshold (`float`)
                The percentage of weights to keep (the rest is pruned).
                `threshold` is a float between 0 and 1; values outside that range are clamped.
        Returns:
            mask (`torch.FloatTensor`)
                Binary matrix of the same size and dtype as `inputs` acting as a mask (1 - the associated
                weight is retained, 0 - the associated weight is pruned).
        """
        nb_elems = inputs.numel()
        # Number of entries to keep, clamped so that an out-of-range `threshold` cannot turn into a
        # negative slice bound (which would select nearly every index instead of none).
        nb_kept = min(max(int(threshold * nb_elems), 0), nb_elems)

        # Indices of the entries, in row-major order, sorted from the largest value to the smallest.
        _, idx = inputs.flatten().sort(descending=True)

        # Build the mask in a fresh flat buffer rather than through `inputs.clone().flatten()`: for a
        # non-contiguous input, `flatten()` returns a copy, so writes to it would never reach the clone.
        flat_mask = torch.zeros(nb_elems, dtype=inputs.dtype, device=inputs.device)
        flat_mask[idx[:nb_kept]] = 1
        return flat_mask.view(inputs.shape)

    @staticmethod
    def backward(ctx, gradOutput):
        """
        Straight-through gradient: the incoming gradient is passed unchanged to `inputs`;
        `threshold` receives no gradient.
        """
        return gradOutput, None
|
|
|
|
|
|
|
|
class MagnitudeBinarizer(object):
    """
    Magnitude Binarizer.

    Computes a binary mask M from a real-valued matrix S such that `M_{i,j} = 1` if and only if
    `|S_{i,j}|` is among the k% highest absolute values of S.

    Implementation is inspired from https://github.com/NervanaSystems/distiller/blob/2291fdcc2ea642a98d4e20629acb5a9e2e04b4e6/distiller/pruning/automated_gradual_pruner.py#L24
    """

    @staticmethod
    def apply(inputs: torch.Tensor, threshold: float):
        """
        Args:
            inputs (`torch.FloatTensor`)
                The input matrix from which the binarizer computes the binary mask.
                This input matrix is typically the weight matrix.
            threshold (`float`)
                The percentage of weights to keep (the rest is pruned).
                `threshold` is a float between 0 and 1; values outside that range are clamped.
        Returns:
            mask (`torch.FloatTensor`)
                Binary matrix of the same size and dtype as `inputs` acting as a mask (1 - the associated
                weight is retained, 0 - the associated weight is pruned).
        """
        nb_elems = inputs.numel()
        # Number of entries to keep, clamped so that an out-of-range `threshold` cannot turn into a
        # negative slice bound (which would select nearly every index instead of none).
        nb_kept = min(max(int(threshold * nb_elems), 0), nb_elems)

        # Indices of the entries, in row-major order, sorted from the largest magnitude to the smallest.
        _, idx = inputs.abs().flatten().sort(descending=True)

        # Build the mask in a fresh flat buffer rather than through `inputs.clone().flatten()`: for a
        # non-contiguous input, `flatten()` returns a copy, so writes to it would never reach the clone.
        flat_mask = torch.zeros(nb_elems, dtype=inputs.dtype, device=inputs.device)
        flat_mask[idx[:nb_kept]] = 1
        return flat_mask.view(inputs.shape)
|
|
|