Charlie81 commited on
Commit
dff49bf
·
1 Parent(s): e868dbf

corrected depth routing

Browse files
Files changed (1) hide show
  1. myolmoe/modeling_myolmoe.py +9 -2
myolmoe/modeling_myolmoe.py CHANGED
@@ -483,11 +483,18 @@ class OlmoeSparseMoeBlock(nn.Module):
483
  selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
484
  routing_weights = routing_probs.gather(1, selected_experts)
485
  elif self.routing_type == "depthconstant":
486
- effective_top_k = max(1, self.top_k - (self.layer_idx // 2))
 
 
487
  routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
488
  elif self.routing_type == "depthlatter":
489
- effective_top_k = self.top_k if self.layer_idx < 8 else max(1, self.top_k + 8 - self.layer_idx)
 
 
 
 
490
  routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
 
491
  else:
492
  raise ValueError(f"Unknown routing type: {self.routing_type}")
493
 
 
483
  selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
484
  routing_weights = routing_probs.gather(1, selected_experts)
485
  elif self.routing_type == "depthconstant":
486
+ # Assumes there are 16 layers
487
+ slope = (self.top_k - 1) / 15
488
+ effective_top_k = max(1, round(self.top_k - self.layer_idx * slope))
489
  routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
490
  elif self.routing_type == "depthlatter":
491
+ if self.layer_idx < 8:
492
+ effective_top_k = self.top_k
493
+ else:
494
+ slope = (self.top_k - 1) / 7
495
+ effective_top_k = max(1, round(self.top_k - (self.layer_idx - 8) * slope))
496
  routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
497
+
498
  else:
499
  raise ValueError(f"Unknown routing type: {self.routing_type}")
500