anthonym21 committed on
Commit
a6e6ea4
·
verified ·
1 Parent(s): cf6f81c

perf: remove CPU-GPU sync bottleneck in SharedMoE routing loop

Browse files
Files changed (1) hide show
  1. modeling_eve.py +4 -5
modeling_eve.py CHANGED
@@ -175,11 +175,10 @@ class SharedMoE(nn.Module):
175
  mask = flat_indices == i
176
  batch_idx, rank_idx = torch.where(mask)
177
 
178
- if batch_idx.numel() > 0:
179
- expert_input = flat_x[batch_idx]
180
- expert_output = expert(expert_input)
181
- weight = flat_weights[batch_idx, rank_idx].unsqueeze(-1)
182
- routed_out.view(-1, C).index_add_(0, batch_idx, expert_output * weight)
183
 
184
  return shared_out + routed_out, aux_loss
185
 
 
175
  mask = flat_indices == i
176
  batch_idx, rank_idx = torch.where(mask)
177
 
178
+ expert_input = flat_x[batch_idx]
179
+ expert_output = expert(expert_input)
180
+ weight = flat_weights[batch_idx, rank_idx].unsqueeze(-1)
181
+ routed_out.view(-1, C).index_add_(0, batch_idx, expert_output * weight)
 
182
 
183
  return shared_out + routed_out, aux_loss
184