Match the transformers sparse MoE block implementation
Browse files- modeling_myolmoe.py +12 -9
modeling_myolmoe.py
CHANGED
|
@@ -319,22 +319,25 @@ class MyOLMoESparseMoeBlock(nn.Module):
|
|
| 319 |
|
| 320 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 321 |
print(f"DEBUG: MoE forward start - hidden_states shape: {hidden_states.shape}")
|
| 322 |
-
batch_size, seq_len,
|
| 323 |
print("absolute precision")
|
| 324 |
-
hidden_states = hidden_states.view(-1,
|
| 325 |
|
| 326 |
# Get routing weights and selected experts
|
| 327 |
print(f"DEBUG: 123: {self.router(hidden_states).shape}")
|
| 328 |
routing_weights, selected_experts, router_logits = self.router(hidden_states)
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
| 331 |
if self.norm_topk_prob:
|
| 332 |
-
routing_weights
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
| 334 |
final_hidden_states = torch.zeros(
|
| 335 |
-
(batch_size * seq_len,
|
| 336 |
-
dtype=hidden_states.dtype,
|
| 337 |
-
device=hidden_states.device
|
| 338 |
)
|
| 339 |
|
| 340 |
# One-hot expert mask
|
|
|
|
| 319 |
|
| 320 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 321 |
print(f"DEBUG: MoE forward start - hidden_states shape: {hidden_states.shape}")
|
| 322 |
+
batch_size, seq_len, hidden_dim = hidden_states.shape
|
| 323 |
print("absolute precision")
|
| 324 |
+
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 325 |
|
| 326 |
# Get routing weights and selected experts
|
| 327 |
print(f"DEBUG: 123: {self.router(hidden_states).shape}")
|
| 328 |
routing_weights, selected_experts, router_logits = self.router(hidden_states)
|
| 329 |
+
router_logits = self.gate(hidden_states)
|
| 330 |
+
|
| 331 |
+
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
| 332 |
+
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
| 333 |
if self.norm_topk_prob:
|
| 334 |
+
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
| 335 |
+
# we cast back to the input dtype
|
| 336 |
+
routing_weights = routing_weights.to(hidden_states.dtype)
|
| 337 |
+
print(f"DEBUG: MoE forward mid - routing_weights shape: {routing_weights.shape}, selected_experts shape: {selected_experts.shape}")
|
| 338 |
+
|
| 339 |
final_hidden_states = torch.zeros(
|
| 340 |
+
(batch_size * seq_len, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
|
|
|
|
|
|
| 341 |
)
|
| 342 |
|
| 343 |
# One-hot expert mask
|