reset modeling file
Browse files- myolmoe/modeling_myolmoe.py +7 -52
- scripts/train.py +1 -2
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# modeling_myolmoe.py
|
| 2 |
import math
|
| 3 |
from typing import List, Optional, Tuple, Union
|
| 4 |
import torch
|
|
@@ -157,21 +156,6 @@ class OlmoeMLP(nn.Module):
|
|
| 157 |
def forward(self, x):
|
| 158 |
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 159 |
return down_proj
|
| 160 |
-
|
| 161 |
-
class SmallOlmoeMLP(nn.Module):
|
| 162 |
-
def __init__(self, config, small_expert_intermediate_size):
|
| 163 |
-
super().__init__()
|
| 164 |
-
self.config = config
|
| 165 |
-
self.hidden_size = config.hidden_size
|
| 166 |
-
self.intermediate_size = small_expert_intermediate_size
|
| 167 |
-
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 168 |
-
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 169 |
-
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 170 |
-
self.act_fn = ACT2FN[config.hidden_act]
|
| 171 |
-
|
| 172 |
-
def forward(self, x):
|
| 173 |
-
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 174 |
-
return down_proj
|
| 175 |
|
| 176 |
|
| 177 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
@@ -462,34 +446,17 @@ OLMOE_ATTENTION_CLASSES = {
|
|
| 462 |
}
|
| 463 |
|
| 464 |
|
| 465 |
-
|
| 466 |
class OlmoeSparseMoeBlock(nn.Module):
|
| 467 |
def __init__(self, config, layer_idx: int):
|
| 468 |
super().__init__()
|
| 469 |
self.layer_idx = layer_idx
|
| 470 |
self.num_experts = config.num_experts
|
| 471 |
-
self.num_small_experts = getattr(config, "num_small_experts", 0) # Default to 0 if not specified
|
| 472 |
-
self.total_experts = self.num_experts + self.num_small_experts
|
| 473 |
self.top_k = config.num_experts_per_tok
|
| 474 |
self.norm_topk_prob = config.norm_topk_prob
|
| 475 |
-
self.routing_type = getattr(config, "routing_type", "topk")
|
| 476 |
-
self.n_step = getattr(config, "nth_step", 2)
|
| 477 |
-
|
| 478 |
-
# Gate now needs to handle both regular and small experts
|
| 479 |
-
self.gate = nn.Linear(config.hidden_size, self.total_experts, bias=False)
|
| 480 |
-
|
| 481 |
-
# Regular experts
|
| 482 |
self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
|
| 483 |
-
|
| 484 |
-
# Small experts (if any)
|
| 485 |
-
self.small_experts = nn.ModuleList()
|
| 486 |
-
if self.num_small_experts > 0:
|
| 487 |
-
small_expert_intermediate_size = getattr(config, "small_expert_intermediate_size",
|
| 488 |
-
config.intermediate_size // 2) # Default to half size
|
| 489 |
-
self.small_experts = nn.ModuleList([
|
| 490 |
-
SmallOlmoeMLP(config, small_expert_intermediate_size)
|
| 491 |
-
for _ in range(self.num_small_experts)
|
| 492 |
-
])
|
| 493 |
|
| 494 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 495 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
|
@@ -497,6 +464,7 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 497 |
router_logits = self.gate(hidden_states)
|
| 498 |
routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
|
| 499 |
|
|
|
|
| 500 |
routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
|
| 501 |
|
| 502 |
if self.norm_topk_prob:
|
|
@@ -509,9 +477,8 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 509 |
device=hidden_states.device,
|
| 510 |
)
|
| 511 |
|
| 512 |
-
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.
|
| 513 |
|
| 514 |
-
# Process regular experts
|
| 515 |
for expert_idx in range(self.num_experts):
|
| 516 |
expert_layer = self.experts[expert_idx]
|
| 517 |
idx, top_x = torch.where(expert_mask[expert_idx])
|
|
@@ -521,21 +488,9 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 521 |
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
| 522 |
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 523 |
|
| 524 |
-
# Process small experts
|
| 525 |
-
for small_expert_idx in range(self.num_small_experts):
|
| 526 |
-
expert_layer = self.small_experts[small_expert_idx]
|
| 527 |
-
# Offset by num_experts since small experts come after regular ones
|
| 528 |
-
global_expert_idx = self.num_experts + small_expert_idx
|
| 529 |
-
idx, top_x = torch.where(expert_mask[global_expert_idx])
|
| 530 |
-
if top_x.numel() == 0:
|
| 531 |
-
continue
|
| 532 |
-
current_state = hidden_states[top_x]
|
| 533 |
-
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
| 534 |
-
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 535 |
-
|
| 536 |
final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|
| 537 |
return final_hidden_states, router_logits
|
| 538 |
-
|
| 539 |
|
| 540 |
class OlmoeDecoderLayer(nn.Module):
|
| 541 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
|
@@ -997,4 +952,4 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
|
|
| 997 |
router_logits=outputs.router_logits,
|
| 998 |
)
|
| 999 |
|
| 1000 |
-
__all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
from typing import List, Optional, Tuple, Union
|
| 3 |
import torch
|
|
|
|
| 156 |
def forward(self, x):
|
| 157 |
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 158 |
return down_proj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
|
| 161 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
|
|
| 446 |
}
|
| 447 |
|
| 448 |
|
|
|
|
| 449 |
class OlmoeSparseMoeBlock(nn.Module):
|
| 450 |
def __init__(self, config, layer_idx: int):
|
| 451 |
super().__init__()
|
| 452 |
self.layer_idx = layer_idx
|
| 453 |
self.num_experts = config.num_experts
|
|
|
|
|
|
|
| 454 |
self.top_k = config.num_experts_per_tok
|
| 455 |
self.norm_topk_prob = config.norm_topk_prob
|
| 456 |
+
self.routing_type = getattr(config, "routing_type", "topk") # default to topk
|
| 457 |
+
self.n_step = getattr(config, "nth_step", 2) # used in nth-descending
|
| 458 |
+
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
|
| 461 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 462 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
|
|
|
| 464 |
router_logits = self.gate(hidden_states)
|
| 465 |
routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
|
| 466 |
|
| 467 |
+
# === Routing ===
|
| 468 |
routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
|
| 469 |
|
| 470 |
if self.norm_topk_prob:
|
|
|
|
| 477 |
device=hidden_states.device,
|
| 478 |
)
|
| 479 |
|
| 480 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
| 481 |
|
|
|
|
| 482 |
for expert_idx in range(self.num_experts):
|
| 483 |
expert_layer = self.experts[expert_idx]
|
| 484 |
idx, top_x = torch.where(expert_mask[expert_idx])
|
|
|
|
| 488 |
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
| 489 |
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|
| 492 |
return final_hidden_states, router_logits
|
| 493 |
+
|
| 494 |
|
| 495 |
class OlmoeDecoderLayer(nn.Module):
|
| 496 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
|
|
|
| 952 |
router_logits=outputs.router_logits,
|
| 953 |
)
|
| 954 |
|
| 955 |
+
__all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
|
scripts/train.py
CHANGED
|
@@ -41,8 +41,7 @@ def expand_model_with_small_experts(base_model):
|
|
| 41 |
print("# DEBUG: Expanding model with small experts...")
|
| 42 |
config = base_model.config
|
| 43 |
config.num_small_experts = 64 # Add 64 small experts
|
| 44 |
-
|
| 45 |
-
config.small_expert_intermediate_size = config.intermediate_size // 2
|
| 46 |
expanded_model = MyOlmoeForCausalLM(config)
|
| 47 |
|
| 48 |
base_state_dict = base_model.state_dict()
|
|
|
|
| 41 |
print("# DEBUG: Expanding model with small experts...")
|
| 42 |
config = base_model.config
|
| 43 |
config.num_small_experts = 64 # Add 64 small experts
|
| 44 |
+
config.small_expert_intermediate_size = config.intermediate_size // 32
|
|
|
|
| 45 |
expanded_model = MyOlmoeForCausalLM(config)
|
| 46 |
|
| 47 |
base_state_dict = base_model.state_dict()
|