perf: remove CPU-GPU sync bottleneck in SharedMoE routing loop
Browse files- modeling_eve.py +4 -5
modeling_eve.py
CHANGED
|
@@ -175,11 +175,10 @@ class SharedMoE(nn.Module):
|
|
| 175 |
mask = flat_indices == i
|
| 176 |
batch_idx, rank_idx = torch.where(mask)
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
routed_out.view(-1, C).index_add_(0, batch_idx, expert_output * weight)
|
| 183 |
|
| 184 |
return shared_out + routed_out, aux_loss
|
| 185 |
|
|
|
|
| 175 |
mask = flat_indices == i
|
| 176 |
batch_idx, rank_idx = torch.where(mask)
|
| 177 |
|
| 178 |
+
expert_input = flat_x[batch_idx]
|
| 179 |
+
expert_output = expert(expert_input)
|
| 180 |
+
weight = flat_weights[batch_idx, rank_idx].unsqueeze(-1)
|
| 181 |
+
routed_out.view(-1, C).index_add_(0, batch_idx, expert_output * weight)
|
|
|
|
| 182 |
|
| 183 |
return shared_out + routed_out, aux_loss
|
| 184 |
|