Charlie81 committed on
Commit
e868dbf
·
1 Parent(s): c085dea

2 new depth routing types

Browse files
Files changed (1) hide show
  1. myolmoe/modeling_myolmoe.py +8 -1
myolmoe/modeling_myolmoe.py CHANGED
@@ -447,8 +447,9 @@ OLMOE_ATTENTION_CLASSES = {
447
 
448
 
449
  class OlmoeSparseMoeBlock(nn.Module):
450
- def __init__(self, config):
451
  super().__init__()
 
452
  self.num_experts = config.num_experts
453
  self.top_k = config.num_experts_per_tok
454
  self.norm_topk_prob = config.norm_topk_prob
@@ -481,6 +482,12 @@ class OlmoeSparseMoeBlock(nn.Module):
481
  sorted_weights, sorted_indices = torch.sort(routing_probs, dim=-1, descending=True)
482
  selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
483
  routing_weights = routing_probs.gather(1, selected_experts)
 
 
 
 
 
 
484
  else:
485
  raise ValueError(f"Unknown routing type: {self.routing_type}")
486
 
 
447
 
448
 
449
  class OlmoeSparseMoeBlock(nn.Module):
450
+ def __init__(self, config, layer_idx: int):
451
  super().__init__()
452
+ self.layer_idx = layer_idx
453
  self.num_experts = config.num_experts
454
  self.top_k = config.num_experts_per_tok
455
  self.norm_topk_prob = config.norm_topk_prob
 
482
  sorted_weights, sorted_indices = torch.sort(routing_probs, dim=-1, descending=True)
483
  selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
484
  routing_weights = routing_probs.gather(1, selected_experts)
485
+ elif self.routing_type == "depthconstant":
486
+ effective_top_k = max(1, self.top_k - (self.layer_idx // 2))
487
+ routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
488
+ elif self.routing_type == "depthlatter":
489
+ effective_top_k = self.top_k if self.layer_idx < 8 else max(1, self.top_k + 8 - self.layer_idx)
490
+ routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
491
  else:
492
  raise ValueError(f"Unknown routing type: {self.routing_type}")
493