Adding files from hf_modeling_btm_log_prob_mixing
Browse files- aux_losses.py +88 -0
aux_losses.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def log_mean(x, dim):
|
| 7 |
+
return torch.logsumexp(x, dim=dim) - torch.log(
|
| 8 |
+
torch.tensor(x.shape[dim], dtype=torch.float32)
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def entropy_reg(logits: torch.Tensor, mean_over_batch: bool = True):
|
| 13 |
+
"""Entropy regularization for the router."""
|
| 14 |
+
|
| 15 |
+
entropy_l = lambda l: -(l * l.exp()).sum(-1)
|
| 16 |
+
# softmax over experts
|
| 17 |
+
# logits: [batch_size * sequence_length, num_experts]
|
| 18 |
+
logprobs = F.log_softmax(logits, dim=-1)
|
| 19 |
+
if mean_over_batch:
|
| 20 |
+
# take mean probability over batch
|
| 21 |
+
logprobs = log_mean(logprobs, 0)
|
| 22 |
+
|
| 23 |
+
return -entropy_l(logprobs).mean()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# two losses below are adapted from
|
| 27 |
+
# https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/routing.py
|
| 28 |
+
def load_balancing_loss(logits: torch.Tensor, expert_indices: torch.Tensor) -> float:
|
| 29 |
+
"""Computes auxiliary load balancing loss as in Switch Transformer.
|
| 30 |
+
|
| 31 |
+
See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
|
| 32 |
+
implements the loss function presented in equations (4) - (6). It aims to
|
| 33 |
+
penalize those cases where the routing between experts is unbalanced.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
logits: logits assigned to each expert per token. Shape:
|
| 37 |
+
<float32>[batch_size * sequence_length, num_experts].
|
| 38 |
+
expert_indices: <int>[batch_size * sequence_length, num_selected_experts]
|
| 39 |
+
indices identifying the top num_selected_experts for a given token.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
The auxiliary loss.
|
| 43 |
+
"""
|
| 44 |
+
# num_token = batch_size * sequence_length
|
| 45 |
+
num_token, num_experts = logits.shape
|
| 46 |
+
|
| 47 |
+
# Shape: [batch_size * sequence_length, num_selected_experts, num_experts].
|
| 48 |
+
expert_mask = F.one_hot(expert_indices, num_experts)
|
| 49 |
+
# For a given token, determine if it was routed to a given expert.
|
| 50 |
+
# Shape: [batch_size * sequence_length, num_experts]
|
| 51 |
+
expert_mask, _ = torch.max(expert_mask, dim=-2)
|
| 52 |
+
|
| 53 |
+
# shape [num_experts]
|
| 54 |
+
tokens_per_expert = torch.mean(expert_mask, dim=0, dtype=torch.float32)
|
| 55 |
+
|
| 56 |
+
# compute router probability per expert in log space for numerical stability
|
| 57 |
+
logprobs = F.log_softmax(logits, dim=-1)
|
| 58 |
+
# take mean probability over batch
|
| 59 |
+
# shape [num_experts]
|
| 60 |
+
logprobs = log_mean(logprobs, dim=0)
|
| 61 |
+
router_prob_per_expert = torch.exp(logprobs)
|
| 62 |
+
return (
|
| 63 |
+
torch.mean( # mean over experts
|
| 64 |
+
tokens_per_expert * router_prob_per_expert,
|
| 65 |
+
dtype=torch.float32,
|
| 66 |
+
)
|
| 67 |
+
* num_experts
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def router_z_loss(router_logits: torch.Tensor) -> float:
|
| 72 |
+
"""Compute router z-loss.
|
| 73 |
+
|
| 74 |
+
The router z-loss was introduced in Designing Effective Sparse Expert Models
|
| 75 |
+
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
|
| 76 |
+
small in an effort to improve stability.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
router_logits: <float>[batch_size * sequence_length, num_experts]
|
| 80 |
+
router logits
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Scalar router z-loss.
|
| 84 |
+
"""
|
| 85 |
+
num_tokens, _ = router_logits.shape
|
| 86 |
+
log_z = torch.logsumexp(router_logits, dim=-1)
|
| 87 |
+
z_loss = log_z**2
|
| 88 |
+
return torch.sum(z_loss, dtype=torch.float32) / (num_tokens)
|