Add support for multiple expert-routing strategies (topk, multinomial, botk, topk+botk, nth-descending) to the sparse MoE block
Browse files- myolmoe/modeling_myolmoe.py +37 -23
- scripts/eval.py +2 -2
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -452,46 +452,60 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 452 |
self.num_experts = config.num_experts
|
| 453 |
self.top_k = config.num_experts_per_tok
|
| 454 |
self.norm_topk_prob = config.norm_topk_prob
|
|
|
|
|
|
|
| 455 |
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
| 456 |
-
self.experts = nn.ModuleList(
|
| 457 |
-
[OlmoeMLP(config) for _ in range(self.num_experts)]
|
| 458 |
-
)
|
| 459 |
|
| 460 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 461 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 462 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 463 |
router_logits = self.gate(hidden_states)
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
#
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
if self.norm_topk_prob:
|
| 472 |
-
routing_weights
|
|
|
|
| 473 |
routing_weights = routing_weights.to(hidden_states.dtype)
|
| 474 |
final_hidden_states = torch.zeros(
|
| 475 |
(batch_size * sequence_length, hidden_dim),
|
| 476 |
dtype=hidden_states.dtype,
|
| 477 |
device=hidden_states.device,
|
| 478 |
)
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
for expert_idx in range(self.num_experts):
|
| 483 |
expert_layer = self.experts[expert_idx]
|
| 484 |
idx, top_x = torch.where(expert_mask[expert_idx])
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
)
|
| 489 |
-
final_hidden_states.index_add_(
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
final_hidden_states = final_hidden_states.reshape(
|
| 493 |
-
batch_size, sequence_length, hidden_dim
|
| 494 |
-
)
|
| 495 |
return final_hidden_states, router_logits
|
| 496 |
|
| 497 |
|
|
|
|
# Routing hyper-parameters pulled from the model config.
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
# Optional fields: older configs may not define them, so fall back to
# the defaults ("topk" routing; every 2nd expert for nth-descending).
self.routing_type = getattr(config, "routing_type", "topk")
self.n_step = getattr(config, "nth_step", 2)  # stride for "nth-descending" routing
# Router (gate) projects hidden states to one logit per expert; no bias.
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
# One independent MLP per expert.
self.experts = nn.ModuleList(
    OlmoeMLP(config) for _ in range(self.num_experts)
)
|
|
|
|
|
|
| 459 |
|
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Route each token to a subset of experts and mix their outputs.

    Supported ``self.routing_type`` values:
      - ``"topk"``:           the ``top_k`` highest-probability experts.
      - ``"multinomial"``:    ``top_k`` experts sampled (without replacement)
                              from the softmax distribution — stochastic.
      - ``"botk"``:           the ``top_k`` lowest-probability experts.
      - ``"topk+botk"``:      union of the top-k and bottom-k experts
                              (``2 * top_k`` selections per token).
      - ``"nth-descending"``: every ``n_step``-th expert of the descending
                              probability ordering, truncated to ``top_k``.

    Args:
        hidden_states: input of shape (batch, seq_len, hidden_dim).

    Returns:
        A tuple ``(final_hidden_states, router_logits)`` where
        ``final_hidden_states`` has shape (batch, seq_len, hidden_dim) and
        ``router_logits`` has shape (batch * seq_len, num_experts) (kept
        un-normalized for the auxiliary load-balancing loss).

    Raises:
        ValueError: if ``self.routing_type`` is not one of the above.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # Flatten tokens so routing operates on (num_tokens, hidden_dim).
    hidden_states = hidden_states.view(-1, hidden_dim)
    router_logits = self.gate(hidden_states)
    # Softmax in float32 for numerical stability regardless of model dtype.
    routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)

    # === Routing strategy selection ===
    if self.routing_type == "topk":
        routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
    elif self.routing_type == "multinomial":
        # Stochastic: sampling makes this non-deterministic across calls.
        selected_experts = torch.multinomial(routing_probs, self.top_k, replacement=False)
        routing_weights = routing_probs.gather(1, selected_experts)
    elif self.routing_type == "botk":
        routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1, largest=False)
    elif self.routing_type == "topk+botk":
        top_weights, top_experts = torch.topk(routing_probs, self.top_k, dim=-1)
        bot_weights, bot_experts = torch.topk(routing_probs, self.top_k, dim=-1, largest=False)
        # NOTE(review): when 2 * top_k > num_experts the top and bottom sets
        # overlap, so an expert can be selected (and weighted) twice.
        selected_experts = torch.cat([top_experts, bot_experts], dim=-1)
        routing_weights = torch.cat([top_weights, bot_weights], dim=-1)
    elif self.routing_type == "nth-descending":
        # Sort experts by descending probability and pick every n_step-th one.
        # (argsort: the sorted weights themselves were unused.)
        sorted_indices = torch.argsort(routing_probs, dim=-1, descending=True)
        # NOTE(review): if num_experts // n_step < top_k this selects fewer
        # than top_k experts — confirm that is intended for such configs.
        selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
        routing_weights = routing_probs.gather(1, selected_experts)
    else:
        raise ValueError(f"Unknown routing type: {self.routing_type}")

    if self.norm_topk_prob:
        # Renormalize so the selected experts' weights sum to 1 per token.
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    routing_weights = routing_weights.to(hidden_states.dtype)
    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    # (tokens, k, num_experts) one-hot -> (num_experts, k, tokens) so each
    # expert's slice tells which (slot, token) pairs it must process.
    expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

    for expert_idx in range(self.num_experts):
        expert_layer = self.experts[expert_idx]
        # idx: which routing slot picked this expert; top_x: which tokens.
        idx, top_x = torch.where(expert_mask[expert_idx])
        if top_x.numel() == 0:
            continue  # expert received no tokens this step
        current_state = hidden_states[top_x]
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
        # Scatter-add weighted expert outputs back to their token rows.
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

    final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits
|
| 510 |
|
| 511 |
|
scripts/eval.py
CHANGED
|
@@ -80,8 +80,8 @@ Examples:
|
|
| 80 |
parser.add_argument(
|
| 81 |
"--routing_type",
|
| 82 |
type=str,
|
| 83 |
-
default="
|
| 84 |
-
choices=["
|
| 85 |
help="Routing type (only used with custom models)"
|
| 86 |
)
|
| 87 |
parser.add_argument(
|
|
|
|
# CLI flag selecting the MoE routing strategy; ignored for stock HF models.
parser.add_argument(
    "--routing_type",
    type=str,
    default="topk",
    choices=["topk", "multinomial", "botk", "topk+botk", "nth-descending"],
    help="Routing type (only used with custom models)",
)
|
| 87 |
parser.add_argument(
|