| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import torch |
| import torch.nn as nn |
| from torch.distributions.normal import Normal |
| from copy import deepcopy |
| import numpy as np |
|
|
| from swin2_mose.libs import Mlp as MLP |
|
|
|
|
class SparseDispatcher(object):
    """Route batch elements to experts and merge the expert outputs.

    The dispatcher is built from a ``gates`` tensor of shape
    ``[batch_size, num_experts]``. Batch element ``b`` is sent to expert ``e``
    iff ``gates[b, e] != 0``; the gate value is the weight used when the
    expert outputs are merged again.

    Two operations are offered:

    * ``dispatch`` - slice an input tensor into one tensor per expert.
    * ``combine`` - merge the per-expert outputs into one tensor indexed by
      batch element; outputs from different experts for the same batch
      element are summed, weighted by the gates.

    Only the leading dimension of the inputs/outputs is interpreted (as the
    batch dimension); any trailing dimensions are carried along unchanged.

    Example use:
        gates: a float `Tensor` with shape `[batch_size, num_experts]`
        inputs: a float `Tensor` with shape `[batch_size, ...]`
        experts: a list of length `num_experts` containing sub-networks.

        dispatcher = SparseDispatcher(num_experts, gates)
        expert_inputs = dispatcher.dispatch(inputs)
        expert_outputs = [experts[i](expert_inputs[i])
                          for i in range(num_experts)]
        outputs = dispatcher.combine(expert_outputs)

    The preceding code sets the output for a particular example b to:
        output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))

    Sparsity of the gate matrix is exploited by giving expert ``i`` only the
    batch elements whose gate for ``i`` is non-zero.
    """

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher.

        Args:
            num_experts: number of experts (columns of ``gates``).
            gates: a `Tensor` of shape `[batch_size, num_experts]`.
        """
        self._gates = gates
        self._num_experts = num_experts

        # One row per non-zero gate: column 0 is the batch index, column 1
        # the expert index.
        nonzero = torch.nonzero(gates)
        # Each column is sorted independently; only the sorted expert column
        # and the permutation that produced it are used below.
        sorted_experts, index_sorted_experts = nonzero.sort(0)
        # Expert index of every routed entry, grouped by expert, shape [N, 1].
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # Batch index of every routed entry, in the same expert-grouped order.
        self._batch_index = nonzero[index_sorted_experts[:, 1], 0]
        # Number of entries per expert. Uses the same `!= 0` criterion as
        # `torch.nonzero` so the split sizes always add up to N.
        self._part_sizes = (gates != 0).sum(0).tolist()
        # Gate value of every routed entry, shape [N, 1].
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Create one input Tensor for each expert.

        The `Tensor` for expert `i` contains the slices of `inp` corresponding
        to the batch elements `b` where `gates[b, i] != 0`.

        Args:
            inp: a `Tensor` of shape `[batch_size, <extra_input_dims>]`.
        Returns:
            a tuple of `num_experts` `Tensor`s with shapes
            `[expert_batch_size_i, <extra_input_dims>]`.
        """
        # `_batch_index` is 1-D, so plain indexing already yields
        # [N, <extra_input_dims>]; no dimension is squeezed, which keeps
        # size-1 extra dimensions of `inp` intact.
        inp_exp = inp[self._batch_index]
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True, cnn_combine=None):
        """Sum together the expert output, weighted by the gates.

        The slice corresponding to a particular batch element `b` is computed
        as the sum over all experts `i` of the expert output, weighted by the
        corresponding gate values. If `multiply_by_gates` is set to False, the
        gate values are ignored.

        Args:
            expert_out: a sequence of `num_experts` `Tensor`s, each with shape
                `[expert_batch_size_i, <extra_output_dims>]`.
            multiply_by_gates: a boolean.
            cnn_combine: optional callable; when given, the (optionally
                gate-weighted) outputs are merged by `smartly_combine`
                instead of being summed.
        Returns:
            a `Tensor` with shape `[batch_size, <extra_output_dims>]`. On the
            summing path the expert outputs are cast with `.float()` first.
        """
        stitched = torch.cat(expert_out, 0)

        if multiply_by_gates:
            # Give the [N, 1] gates one singleton dim per trailing dim of
            # `stitched` so the product broadcasts for outputs of any rank.
            gate_shape = (-1,) + (1,) * (stitched.dim() - 1)
            stitched = stitched.mul(self._nonzero_gates.reshape(gate_shape))

        if cnn_combine is not None:
            return self.smartly_combine(stitched, cnn_combine)

        zeros = torch.zeros(
            (self._gates.size(0),) + expert_out[-1].shape[1:],
            requires_grad=True, device=stitched.device)
        # Accumulate every routed entry into the row of its batch element.
        return zeros.index_add(0, self._batch_index, stitched.float())

    def smartly_combine(self, stitched, cnn_combine):
        """Merge expert outputs with a callable instead of a plain sum.

        The routed entries are regrouped per batch element into a tensor of
        shape `[num_routed_batch_elements, entries_per_element, ...]`, which
        is passed to `cnn_combine`; dim 1 of its result is then squeezed.
        `torch.stack` requires every routed batch element to have the same
        number of entries.

        Args:
            stitched: concatenated expert outputs, one row per routed entry.
            cnn_combine: callable applied to the regrouped tensor.
        Returns:
            `cnn_combine(...)` with dim 1 squeezed.
        """
        idxes = torch.stack([
            (self._batch_index == i).nonzero().squeeze(1)
            for i in self._batch_index.unique()])
        return cnn_combine(stitched[idxes]).squeeze(1)

    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert `Tensor`s.

        Returns:
            a tuple of `num_experts` `Tensor`s with shapes
            `[expert_batch_size_i, 1]`.
        """
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
|
|
|
|
def build_experts(experts_cfg, default_cfg, num_experts):
    """Build the list of expert sub-networks.

    Args:
        experts_cfg: ``None`` or an iterable of dict-like entries, each with a
            ``'type'`` key. Only ``'mlp'`` is supported; other keys of an
            entry are currently ignored. The config is not modified.
        default_cfg: positional arguments forwarded to ``MLP`` for every
            expert.
        num_experts: number of experts to create when ``experts_cfg`` is
            ``None``. Not consulted when an explicit config is given.

    Returns:
        An ``nn.ModuleList`` with one ``MLP`` per expert.

    Raises:
        ValueError: if an entry has a type other than ``'mlp'``. Skipping it
            silently would leave fewer modules than the caller expects.
        KeyError: if an entry has no ``'type'`` key.
    """
    if experts_cfg is None:
        # No explicit config: every expert is a default MLP.
        return nn.ModuleList(
            [MLP(*default_cfg) for _ in range(num_experts)])

    experts = []
    for position, e_cfg in enumerate(experts_cfg):
        type_ = e_cfg['type']
        if type_ != 'mlp':
            raise ValueError(
                "unsupported expert type %r at position %d; "
                "only 'mlp' is supported" % (type_, position))
        experts.append(MLP(*default_cfg))
    return nn.ModuleList(experts)
|
|
|
|
class MoE(nn.Module):
    """Sparsely gated mixture-of-experts layer.

    A gating network picks the top ``k`` experts for every batch element;
    the input is dispatched to those experts and their outputs are merged,
    weighted by the gate values. See https://arxiv.org/abs/1701.06538.

    Args:
        input_size: integer - size of the last input dimension.
        output_size: integer - size of the last output dimension.
        num_experts: an integer - number of experts.
        hidden_size: an integer - hidden size of the experts.
        experts: optional expert config forwarded to ``build_experts``.
        noisy_gating: a boolean - add learned noise to the gate logits while
            training.
        k: an integer - how many experts to use for each batch element.
        x_gating: ``None`` to average the input over dim 1 before gating, or
            ``'conv1d'`` to reduce dim 1 with a ``Conv1d`` instead.
        with_noise: a boolean - when False, gating noise is disabled even in
            training mode.
        with_smart_merger: ``'v1'`` merges the ``k`` expert outputs with a
            ``Conv2d``; any other value sums them.
        x_gating_channels: ``in_channels`` of the ``'conv1d'`` gate, i.e. the
            size dim 1 of the input must have in ``forward``.

    Raises:
        ValueError: if ``k > num_experts`` or ``x_gating`` is not supported.
    """

    def __init__(self, input_size, output_size, num_experts, hidden_size,
                 experts=None, noisy_gating=True, k=4,
                 x_gating=None, with_noise=True, with_smart_merger=None,
                 x_gating_channels=4096):
        super(MoE, self).__init__()
        # Validate up front with real exceptions (asserts vanish under -O),
        # before any submodule is built.
        if k > num_experts:
            raise ValueError(
                "k (%d) must not exceed num_experts (%d)" % (k, num_experts))
        if x_gating not in (None, 'conv1d'):
            # forward() would otherwise look up an x_gate that was never made.
            raise ValueError(
                "unsupported x_gating %r; expected None or 'conv1d'"
                % (x_gating,))

        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.output_size = output_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.k = k
        self.with_noise = with_noise

        self.experts = build_experts(
            experts,
            (self.input_size, self.hidden_size, self.output_size),
            num_experts)
        self.w_gate = nn.Parameter(
            torch.zeros(input_size, num_experts), requires_grad=True)
        self.w_noise = nn.Parameter(
            torch.zeros(input_size, num_experts), requires_grad=True)

        self.x_gating = x_gating
        if self.x_gating == 'conv1d':
            self.x_gate = nn.Conv1d(
                x_gating_channels, 1, kernel_size=3, padding=1)

        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        # Parameters of the standard normal used by `_prob_in_top_k`; kept as
        # buffers so they follow the module across devices.
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))

        self.cnn_combine = None
        if with_smart_merger == 'v1':
            print('with SMART MERGER')
            # One input channel per selected expert, merged into one.
            self.cnn_combine = nn.Conv2d(self.k, 1, kernel_size=3, padding=1)

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.

        Useful as a loss to encourage a positive distribution to be more
        uniform. An epsilon is added for numerical stability.

        Args:
            x: a `Tensor`.
        Returns:
            a 0-d float `Tensor`: ``var(x) / (mean(x)**2 + eps)``, or 0 when
            `x` has fewer than two entries along its first dimension (the
            variance is undefined there).
        """
        eps = 1e-10
        if x.shape[0] <= 1:
            # Same rank and dtype as the regular path below.
            return torch.zeros((), device=x.device, dtype=torch.float32)
        x = x.float()
        return x.var() / (x.mean() ** 2 + eps)

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.

        The load is the number of examples for which the corresponding gate
        is > 0.

        Args:
            gates: a `Tensor` of shape [batch_size, n]
        Returns:
            an integer `Tensor` of shape [n]
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev,
                       noisy_top_values):
        """Helper function to NoisyTopKGating.

        Computes the probability that value is in top k, given different
        random noise. This gives us a way of backpropagating from a loss that
        balances the number of times each expert is in the top k experts per
        example.

        Args:
            clean_values: a `Tensor` of shape [batch, n].
            noisy_values: a `Tensor` of shape [batch, n]. Equal to clean
                values plus normally distributed noise with standard
                deviation noise_stddev.
            noise_stddev: a `Tensor` of shape [batch, n].
            noisy_top_values: a `Tensor` of shape [batch, m] holding the top
                m noisy values per row, with m >= k + 1 (position k of each
                row is read below).
        Returns:
            a `Tensor` of shape [batch, n].
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        # Flat position of the (k+1)-th largest noisy value of every row: the
        # value an expert currently in the top k has to beat to stay in.
        threshold_positions_if_in = (
            torch.arange(batch, device=clean_values.device) * m + self.k)
        threshold_if_in = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        is_in = torch.gt(noisy_values, threshold_if_in)
        # Flat position of the k-th largest value: the value an expert
        # currently outside the top k has to beat to get in.
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)

        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf(
            (clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf(
            (clean_values - threshold_if_out) / noise_stddev)
        return torch.where(is_in, prob_if_in, prob_if_out)

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """Noisy top-k gating.

        See paper: https://arxiv.org/abs/1701.06538.

        Args:
            x: input Tensor with shape [batch_size, input_size]
            train: a boolean - we only add noise at training time.
            noise_epsilon: a float added to the noise standard deviation.
        Returns:
            gates: a Tensor with shape [batch_size, num_experts]
            load: a Tensor with shape [num_experts]
        """
        clean_logits = x @ self.w_gate
        if self.noisy_gating and train:
            raw_noise_stddev = x @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            noisy_logits = clean_logits + (
                torch.randn_like(clean_logits) * noise_stddev)
            logits = noisy_logits
        else:
            logits = clean_logits

        # Take one value more than k (when available): `_prob_in_top_k`
        # needs the (k+1)-th largest logit as a threshold.
        top_logits, top_indices = logits.topk(
            min(self.k + 1, self.num_experts), dim=1)
        top_k_logits = top_logits[:, :self.k]
        top_k_indices = top_indices[:, :self.k]
        # Softmax over the selected experts only.
        top_k_gates = self.softmax(top_k_logits)

        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        if self.noisy_gating and self.k < self.num_experts and train:
            # Differentiable load estimate.
            load = self._prob_in_top_k(
                clean_logits, noisy_logits, noise_stddev, top_logits).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load

    def forward(self, x, loss_coef=1e-2):
        """Run the mixture of experts on `x`.

        Gating noise is applied only when the module is in training mode and
        `with_noise` is set.

        Args:
            x: tensor of shape [batch_size, tokens, input_size]. Dim 1 is
                reduced (mean, or the `x_gate` convolution) to obtain the
                [batch_size, input_size] gating input.
            loss_coef: a scalar - multiplier on load-balancing losses

        Returns:
            y: the merged expert outputs, one slice per batch element.
            extra_training_loss: a scalar. This should be added into the
                overall training loss of the model. The backpropagation of
                this loss encourages all experts to be approximately equally
                used across a batch.
        """
        if self.x_gating is not None:
            xg = self.x_gate(x).squeeze(1)
        else:
            xg = x.mean(1)

        gates, load = self.noisy_top_k_gating(
            xg, self.training and self.with_noise)

        # Total gate weight per expert over the batch.
        importance = gates.sum(0)

        loss = self.cv_squared(importance) + self.cv_squared(load)
        loss *= loss_coef

        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(x)
        expert_outputs = [self.experts[i](expert_inputs[i])
                          for i in range(self.num_experts)]
        y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine)
        return y, loss
|
|