Charlie81 commited on
Commit
c085dea
·
1 Parent(s): 3aa53b4

add different routing types

Browse files
Files changed (2) hide show
  1. myolmoe/modeling_myolmoe.py +37 -23
  2. scripts/eval.py +2 -2
myolmoe/modeling_myolmoe.py CHANGED
@@ -452,46 +452,60 @@ class OlmoeSparseMoeBlock(nn.Module):
452
  self.num_experts = config.num_experts
453
  self.top_k = config.num_experts_per_tok
454
  self.norm_topk_prob = config.norm_topk_prob
 
 
455
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
456
- self.experts = nn.ModuleList(
457
- [OlmoeMLP(config) for _ in range(self.num_experts)]
458
- )
459
 
460
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
461
  batch_size, sequence_length, hidden_dim = hidden_states.shape
462
  hidden_states = hidden_states.view(-1, hidden_dim)
463
  router_logits = self.gate(hidden_states)
464
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
465
- # routing_weights, selected_experts = torch.topk(
466
- # routing_weights, self.top_k, dim=-1
467
- # )
468
- selected_experts = torch.multinomial(routing_weights, self.top_k, replacement=False)
469
- routing_weights = routing_weights.gather(1, selected_experts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
  if self.norm_topk_prob:
472
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
473
  routing_weights = routing_weights.to(hidden_states.dtype)
474
  final_hidden_states = torch.zeros(
475
  (batch_size * sequence_length, hidden_dim),
476
  dtype=hidden_states.dtype,
477
  device=hidden_states.device,
478
  )
479
- expert_mask = torch.nn.functional.one_hot(
480
- selected_experts, num_classes=self.num_experts
481
- ).permute(2, 1, 0)
482
  for expert_idx in range(self.num_experts):
483
  expert_layer = self.experts[expert_idx]
484
  idx, top_x = torch.where(expert_mask[expert_idx])
485
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
486
- current_hidden_states = (
487
- expert_layer(current_state) * routing_weights[top_x, idx, None]
488
- )
489
- final_hidden_states.index_add_(
490
- 0, top_x, current_hidden_states.to(hidden_states.dtype)
491
- )
492
- final_hidden_states = final_hidden_states.reshape(
493
- batch_size, sequence_length, hidden_dim
494
- )
495
  return final_hidden_states, router_logits
496
 
497
 
 
452
  self.num_experts = config.num_experts
453
  self.top_k = config.num_experts_per_tok
454
  self.norm_topk_prob = config.norm_topk_prob
455
+ self.routing_type = getattr(config, "routing_type", "topk") # default to topk
456
+ self.n_step = getattr(config, "nth_step", 2) # used in nth-descending
457
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
458
+ self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
 
 
459
 
460
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
461
  batch_size, sequence_length, hidden_dim = hidden_states.shape
462
  hidden_states = hidden_states.view(-1, hidden_dim)
463
  router_logits = self.gate(hidden_states)
464
+ routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
465
+
466
+ # === Routing strategy selection ===
467
+ if self.routing_type == "topk":
468
+ routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
469
+ elif self.routing_type == "multinomial":
470
+ selected_experts = torch.multinomial(routing_probs, self.top_k, replacement=False)
471
+ routing_weights = routing_probs.gather(1, selected_experts)
472
+ elif self.routing_type == "botk":
473
+ routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1, largest=False)
474
+ elif self.routing_type == "topk+botk":
475
+ top_weights, top_experts = torch.topk(routing_probs, self.top_k, dim=-1)
476
+ bot_weights, bot_experts = torch.topk(routing_probs, self.top_k, dim=-1, largest=False)
477
+ selected_experts = torch.cat([top_experts, bot_experts], dim=-1)
478
+ routing_weights = torch.cat([top_weights, bot_weights], dim=-1)
479
+ elif self.routing_type == "nth-descending":
480
+ # Sort all experts descending and pick every nth
481
+ sorted_weights, sorted_indices = torch.sort(routing_probs, dim=-1, descending=True)
482
+ selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
483
+ routing_weights = routing_probs.gather(1, selected_experts)
484
+ else:
485
+ raise ValueError(f"Unknown routing type: {self.routing_type}")
486
 
487
  if self.norm_topk_prob:
488
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
489
+
490
  routing_weights = routing_weights.to(hidden_states.dtype)
491
  final_hidden_states = torch.zeros(
492
  (batch_size * sequence_length, hidden_dim),
493
  dtype=hidden_states.dtype,
494
  device=hidden_states.device,
495
  )
496
+
497
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
498
+
499
  for expert_idx in range(self.num_experts):
500
  expert_layer = self.experts[expert_idx]
501
  idx, top_x = torch.where(expert_mask[expert_idx])
502
+ if top_x.numel() == 0:
503
+ continue
504
+ current_state = hidden_states[top_x]
505
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
506
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
507
+
508
+ final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
 
 
 
509
  return final_hidden_states, router_logits
510
 
511
 
scripts/eval.py CHANGED
@@ -80,8 +80,8 @@ Examples:
80
  parser.add_argument(
81
  "--routing_type",
82
  type=str,
83
- default="sparse",
84
- choices=["dense", "sparse", "non_deterministic"],
85
  help="Routing type (only used with custom models)"
86
  )
87
  parser.add_argument(
 
80
  parser.add_argument(
81
  "--routing_type",
82
  type=str,
83
+ default="topk",
84
+ choices=["topk", "multinomial", "botk", "topk+botk", "nth-descending"],
85
  help="Routing type (only used with custom models)"
86
  )
87
  parser.add_argument(