import torch
import torch.nn.functional as F


def log_mean(x, dim):
    """Log of the mean of exp(x) along `dim`, computed stably via logsumexp."""
    return torch.logsumexp(x, dim=dim) - torch.log(
        torch.tensor(x.shape[dim], dtype=torch.float32)
    )
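
# A hedged sanity check (illustrative values only): given log-probabilities,
# log_mean returns the log of their arithmetic mean along `dim`, e.g.
#
#   p = torch.tensor([[0.2, 0.8], [0.4, 0.6]])
#   log_mean(torch.log(p), dim=0)  # ≈ log([0.3, 0.7])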


def entropy_reg(logits: torch.Tensor, mean_over_batch: bool = True) -> torch.Tensor:
    """Entropy regularization for the router.

    Returns the negative entropy of the router distribution (optionally
    batch-averaged first), so minimizing it favors uniform expert usage.
    """
    logprobs = F.log_softmax(logits, dim=-1)
    if mean_over_batch:
        # Average per-token router distributions over the batch in log space.
        logprobs = log_mean(logprobs, 0)
    # Entropy of a log-probability vector: -(p * log p) summed over experts.
    entropy = -(logprobs * logprobs.exp()).sum(-1)
    return -entropy.mean()
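
# Hedged sanity check: uniform logits give a uniform batch-averaged router
# distribution, so the regularizer attains its minimum of -log(num_experts):
#
#   entropy_reg(torch.zeros(8, 4))  # ≈ -log(4) ≈ -1.3863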


def load_balancing_loss(logits: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
    """Computes the auxiliary load balancing loss as in Switch Transformer.

    See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
    implements the loss presented in equations (4) - (6) of the paper, which
    penalizes unbalanced routing between experts.

    Args:
        logits: logits assigned to each expert per token. Shape:
            <float32>[batch_size * sequence_length, num_experts].
        expert_indices: <int>[batch_size * sequence_length, num_selected_experts]
            indices identifying the top num_selected_experts for a given token.

    Returns:
        The auxiliary loss.
    """
    _, num_experts = logits.shape

    # One-hot encode the selected experts:
    # [num_tokens, num_selected_experts, num_experts].
    expert_mask = F.one_hot(expert_indices, num_experts)
    # Collapse the top-k dimension so each token counts at most once per
    # expert: [num_tokens, num_experts].
    expert_mask, _ = torch.max(expert_mask, dim=-2)

    # Fraction of tokens dispatched to each expert (f_i in the paper).
    tokens_per_expert = torch.mean(expert_mask, dim=0, dtype=torch.float32)

    # Mean router probability per expert (P_i), averaged over tokens in
    # log space for numerical stability.
    logprobs = F.log_softmax(logits, dim=-1)
    logprobs = log_mean(logprobs, dim=0)
    router_prob_per_expert = torch.exp(logprobs)
    return (
        torch.mean(
            tokens_per_expert * router_prob_per_expert,
            dtype=torch.float32,
        )
        * num_experts
    )
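
# Hedged sanity check: with uniform logits and perfectly balanced top-1
# routing over E experts, f_i = P_i = 1/E, so this normalization returns
# E * mean(f * P) = 1/E (0.25 for E = 4):
#
#   logits = torch.zeros(8, 4)
#   indices = torch.arange(8).remainder(4).unsqueeze(-1)  # balanced top-1
#   load_balancing_loss(logits, indices)  # ≈ 0.25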


def router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Compute the router z-loss.

    The router z-loss was introduced in Designing Effective Sparse Expert
    Models (https://arxiv.org/abs/2202.08906). It encourages router logits
    to remain small in an effort to improve training stability.

    Args:
        router_logits: <float>[batch_size * sequence_length, num_experts]
            router logits.

    Returns:
        Scalar router z-loss.
    """
    num_tokens, _ = router_logits.shape
    # Penalize the squared log-partition function of the router softmax.
    log_z = torch.logsumexp(router_logits, dim=-1)
    z_loss = log_z**2
    return torch.sum(z_loss, dtype=torch.float32) / num_tokens
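

# Hypothetical usage sketch (shapes and seed are assumptions, not part of the
# original module): route 16 tokens over 4 experts with top-2 selection.
if __name__ == "__main__":
    torch.manual_seed(0)
    router_logits = torch.randn(16, 4)  # [num_tokens, num_experts]
    _, top2_indices = torch.topk(router_logits, k=2, dim=-1)

    print("entropy reg:        ", entropy_reg(router_logits).item())
    print("load balancing loss:", load_balancing_loss(router_logits, top2_indices).item())
    print("router z-loss:      ", router_z_loss(router_logits).item())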