Spaces:
Runtime error
Runtime error
| """ | |
| Parts of this file have been adapted from | |
| https://github.com/nicola-decao/diffmask/blob/master/diffmask/models/gates.py | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from typing import Optional | |
| from utils.distributions import RectifiedStreched, BinaryConcrete | |
| class MLPGate(nn.Module): | |
| def __init__(self, input_size: int, hidden_size: int, bias: bool = True): | |
| """ | |
| This is an MLP with the following structure; | |
| Linear(input_size, hidden_size), Tanh(), Linear(hidden_size, 1) | |
| The bias of the last layer is set to 5.0 to start with high probability | |
| of keeping states (fundamental for good convergence as the initialized | |
| DiffMask has not learned what to mask yet). | |
| Args: | |
| input_size (int): the number of input features | |
| hidden_size (int): the number of hidden units | |
| bias (bool): whether to use a bias term | |
| """ | |
| super().__init__() | |
| self.f = nn.Sequential( | |
| nn.utils.weight_norm(nn.Linear(input_size, hidden_size)), | |
| nn.Tanh(), | |
| nn.utils.weight_norm(nn.Linear(hidden_size, 1, bias=bias)), | |
| ) | |
| if bias: | |
| self.f[-1].bias.data[:] = 5.0 | |
| def forward(self, *args: Tensor) -> Tensor: | |
| return self.f(torch.cat(args, -1)) | |
| class MLPMaxGate(nn.Module): | |
| def __init__( | |
| self, | |
| input_size: int, | |
| hidden_size: int, | |
| mul_activation: float = 10.0, | |
| add_activation: float = 5.0, | |
| bias: bool = True, | |
| ): | |
| """ | |
| This is an MLP with the following structure; | |
| Linear(input_size, hidden_size), Tanh(), Linear(hidden_size, 1) | |
| The bias of the last layer is set to 5.0 to start with high probability | |
| of keeping states (fundamental for good convergence as the initialized | |
| DiffMask has not learned what to mask yet). | |
| It also uses a scaler for the output of the activation function. | |
| Args: | |
| input_size (int): the number of input features | |
| hidden_size (int): the number of hidden units | |
| mul_activation (float): the scaler for the output of the activation function | |
| add_activation (float): the offset for the output of the activation function | |
| bias (bool): whether to use a bias term | |
| """ | |
| super().__init__() | |
| self.f = nn.Sequential( | |
| nn.utils.weight_norm(nn.Linear(input_size, hidden_size)), | |
| nn.Tanh(), | |
| nn.utils.weight_norm(nn.Linear(hidden_size, 1, bias=bias)), | |
| nn.Tanh(), | |
| ) | |
| self.add_activation = nn.Parameter(torch.tensor(add_activation)) | |
| self.mul_activation = mul_activation | |
| def forward(self, *args: Tensor) -> Tensor: | |
| return self.f(torch.cat(args, -1)) * self.mul_activation + self.add_activation | |
| class DiffMaskGateInput(nn.Module): | |
| def __init__( | |
| self, | |
| hidden_size: int, | |
| hidden_attention: int, | |
| num_hidden_layers: int, | |
| max_position_embeddings: int, | |
| gate_fn: nn.Module = MLPMaxGate, | |
| mul_activation: float = 10.0, | |
| add_activation: float = 5.0, | |
| gate_bias: bool = True, | |
| placeholder: bool = False, | |
| init_vector: Tensor = None, | |
| ): | |
| """This is a DiffMask module that masks the input of the first layer. | |
| Args: | |
| hidden_size (int): the size of the hidden representations | |
| hidden_attention (int) the amount of units in the gate's hidden (bottleneck) layer | |
| num_hidden_layers (int): the number of hidden layers (and thus gates to use) | |
| max_position_embeddings (int): the amount of placeholder embeddings to learn for the masked positions | |
| gate_fn (nn.Module): the PyTorch module to use as a gate | |
| mul_activation (float): the scaler for the output of the activation function | |
| add_activation (float): the offset for the output of the activation function | |
| gate_bias (bool): whether to use a bias term | |
| placeholder (bool): whether to use placeholder embeddings or a zero vector | |
| init_vector (Tensor): the initial vector to use for the placeholder embeddings | |
| """ | |
| super().__init__() | |
| # Create a ModuleList with the gates | |
| self.g_hat = nn.ModuleList( | |
| [ | |
| gate_fn( | |
| hidden_size * 2, | |
| hidden_attention, | |
| mul_activation, | |
| add_activation, | |
| gate_bias, | |
| ) | |
| for _ in range(num_hidden_layers) | |
| ] | |
| ) | |
| if placeholder: | |
| # Use a placeholder embedding for the masked positions | |
| self.placeholder = nn.Parameter( | |
| nn.init.xavier_normal_( | |
| torch.empty((1, max_position_embeddings, hidden_size)) | |
| ) | |
| if init_vector is None | |
| else init_vector.view(1, 1, hidden_size).repeat( | |
| 1, max_position_embeddings, 1 | |
| ) | |
| ) | |
| else: | |
| # Use a zero vector for the masked positions | |
| self.register_buffer( | |
| "placeholder", | |
| torch.zeros((1, 1, hidden_size)), | |
| ) | |
| def forward( | |
| self, hidden_states: tuple[Tensor], layer_pred: Optional[int] | |
| ) -> tuple[tuple[Tensor], Tensor, Tensor, Tensor, Tensor]: | |
| # Concatenate the output of all the gates | |
| logits = torch.cat( | |
| [ | |
| self.g_hat[i](hidden_states[0], hidden_states[i]) | |
| for i in range( | |
| (layer_pred + 1) if layer_pred is not None else len(hidden_states) | |
| ) | |
| ], | |
| -1, | |
| ) | |
| # Define a Hard Concrete distribution | |
| dist = RectifiedStreched( | |
| BinaryConcrete(torch.full_like(logits, 0.2), logits), | |
| l=-0.2, | |
| r=1.0, | |
| ) | |
| # Calculate the expectation for the full gate probabilities | |
| # These act as votes for the masked positions | |
| gates_full = dist.rsample().cumprod(-1) | |
| expected_L0_full = dist.log_expected_L0().cumsum(-1) | |
| # Extract the probabilities from the last layer, which acts | |
| # as an aggregation of the votes per position | |
| gates = gates_full[..., -1] | |
| expected_L0 = expected_L0_full[..., -1] | |
| return ( | |
| hidden_states[0] * gates.unsqueeze(-1) | |
| + self.placeholder[:, : hidden_states[0].shape[-2]] | |
| * (1 - gates).unsqueeze(-1), | |
| gates, | |
| expected_L0, | |
| gates_full, | |
| expected_L0_full, | |
| ) | |
| # class DiffMaskGateHidden(nn.Module): | |
| # def __init__( | |
| # self, | |
| # hidden_size: int, | |
| # hidden_attention: int, | |
| # num_hidden_layers: int, | |
| # max_position_embeddings: int, | |
| # gate_fn: nn.Module = MLPMaxGate, | |
| # gate_bias: bool = True, | |
| # placeholder: bool = False, | |
| # init_vector: Tensor = None, | |
| # ): | |
| # super().__init__() | |
| # | |
| # self.g_hat = nn.ModuleList( | |
| # [ | |
| # gate_fn(hidden_size, hidden_attention, bias=gate_bias) | |
| # for _ in range(num_hidden_layers) | |
| # ] | |
| # ) | |
| # | |
| # if placeholder: | |
| # self.placeholder = nn.ParameterList( | |
| # [ | |
| # nn.Parameter( | |
| # nn.init.xavier_normal_( | |
| # torch.empty((1, max_position_embeddings, hidden_size)) | |
| # ) | |
| # if init_vector is None | |
| # else init_vector.view(1, 1, hidden_size).repeat( | |
| # 1, max_position_embeddings, 1 | |
| # ) | |
| # ) | |
| # for _ in range(num_hidden_layers) | |
| # ] | |
| # ) | |
| # else: | |
| # self.register_buffer( | |
| # "placeholder", | |
| # torch.zeros((num_hidden_layers, 1, 1, hidden_size)), | |
| # ) | |
| # | |
| # def forward( | |
| # self, hidden_states: tuple[Tensor], layer_pred: Optional[int] | |
| # ) -> tuple[tuple[Tensor], Tensor, Tensor, Tensor, Tensor]: | |
| # if layer_pred is not None: | |
| # logits = self.g_hat[layer_pred](hidden_states[layer_pred]) | |
| # else: | |
| # logits = torch.cat( | |
| # [self.g_hat[i](hidden_states[i]) for i in range(len(hidden_states))], -1 | |
| # ) | |
| # | |
| # dist = RectifiedStreched( | |
| # BinaryConcrete(torch.full_like(logits, 0.2), logits), | |
| # l=-0.2, | |
| # r=1.0, | |
| # ) | |
| # | |
| # gates_full = dist.rsample() | |
| # expected_L0_full = dist.log_expected_L0() | |
| # | |
| # gates = gates_full if layer_pred is not None else gates_full[..., :1] | |
| # expected_L0 = ( | |
| # expected_L0_full if layer_pred is not None else expected_L0_full[..., :1] | |
| # ) | |
| # | |
| # layer_pred = layer_pred or 0 # equiv to "layer_pred if layer_pred else 0" | |
| # return ( | |
| # hidden_states[layer_pred] * gates | |
| # + self.placeholder[layer_pred][:, : hidden_states[layer_pred].shape[-2]] | |
| # * (1 - gates), | |
| # gates.squeeze(-1), | |
| # expected_L0.squeeze(-1), | |
| # gates_full, | |
| # expected_L0_full, | |
| # ) | |