robinfaro committed on
Commit
e6ede83
·
verified ·
1 Parent(s): dc96afd

Adding files from hf_modeling_btm_log_prob_mixing

Browse files
Files changed (1) hide show
  1. aux_losses.py +88 -0
aux_losses.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
def log_mean(x, dim):
    """Return log(mean(exp(x))) along *dim*, computed stably via logsumexp.

    Equivalent to ``torch.log(x.exp().mean(dim))`` but avoids overflow/underflow
    by working in log space: logsumexp(x) - log(count).
    """
    count = torch.tensor(x.shape[dim], dtype=torch.float32)
    return torch.logsumexp(x, dim=dim) - torch.log(count)
10
+
11
+
12
def entropy_reg(logits: torch.Tensor, mean_over_batch: bool = True):
    """Entropy regularization for the router.

    Returns the negated entropy of the router's softmax distribution, so that
    minimizing this loss pushes the routing distribution toward higher entropy.

    Args:
        logits: router logits, [batch_size * sequence_length, num_experts].
        mean_over_batch: if True, average the probabilities over the batch
            dimension (in log space) before computing the entropy.
    """
    # softmax over the expert dimension, kept in log space
    logprobs = F.log_softmax(logits, dim=-1)
    if mean_over_batch:
        # take mean probability over batch
        logprobs = log_mean(logprobs, 0)

    # entropy H = -sum(p * log p); return -H averaged
    neg_entropy = (logprobs * logprobs.exp()).sum(-1)
    return neg_entropy.mean()
24
+
25
+
26
# two losses below are adapted from
# https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/routing.py
def load_balancing_loss(logits: torch.Tensor, expert_indices: torch.Tensor) -> float:
    """Computes auxiliary load balancing loss as in Switch Transformer.

    See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
    implements the loss function presented in equations (4) - (6). It aims to
    penalize those cases where the routing between experts is unbalanced.

    Args:
        logits: logits assigned to each expert per token. Shape:
            <float32>[batch_size * sequence_length, num_experts].
        expert_indices: <int>[batch_size * sequence_length, num_selected_experts]
            indices identifying the top num_selected_experts for a given token.

    Returns:
        The auxiliary loss.
    """
    num_experts = logits.shape[-1]

    # One-hot of every selection:
    # [batch_size * sequence_length, num_selected_experts, num_experts].
    one_hot = F.one_hot(expert_indices, num_experts)
    # Collapse the selection dim: entry is 1 iff the token was routed to
    # that expert at all. Shape: [batch_size * sequence_length, num_experts].
    routed = one_hot.max(dim=-2).values

    # Fraction of tokens dispatched to each expert, shape [num_experts].
    tokens_per_expert = torch.mean(routed, dim=0, dtype=torch.float32)

    # Mean router probability per expert, computed in log space for
    # numerical stability; shape [num_experts].
    mean_logprobs = log_mean(F.log_softmax(logits, dim=-1), dim=0)
    router_prob_per_expert = mean_logprobs.exp()

    # Mean over experts, scaled by num_experts as in the paper.
    per_expert = tokens_per_expert * router_prob_per_expert
    return torch.mean(per_expert, dtype=torch.float32) * num_experts
69
+
70
+
71
def router_z_loss(router_logits: torch.Tensor) -> float:
    """Compute router z-loss.

    The router z-loss was introduced in Designing Effective Sparse Expert Models
    (https://arxiv.org/abs/2202.08906). It encourages router logits to remain
    small in an effort to improve stability.

    Args:
        router_logits: <float>[batch_size * sequence_length, num_experts]
            router logits

    Returns:
        Scalar router z-loss.
    """
    num_tokens = router_logits.shape[0]
    # log of the softmax normalizer, per token
    log_z = torch.logsumexp(router_logits, dim=-1)
    # mean of log_z^2 over tokens
    return torch.sum(log_z * log_z, dtype=torch.float32) / num_tokens