Charlie81 committed on
Commit
2daadcc
·
1 Parent(s): 7bf23fe

match transformers sparse block

Browse files
Files changed (1) hide show
  1. modeling_myolmoe.py +12 -9
modeling_myolmoe.py CHANGED
@@ -319,22 +319,25 @@ class MyOLMoESparseMoeBlock(nn.Module):
319
 
320
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
321
  print(f"DEBUG: MoE forward start - hidden_states shape: {hidden_states.shape}")
322
- batch_size, seq_len, _ = hidden_states.shape
323
  print("absolute precision")
324
- hidden_states = hidden_states.view(-1, self.hidden_size)
325
 
326
  # Get routing weights and selected experts
327
  print(f"DEBUG: 123: {self.router(hidden_states).shape}")
328
  routing_weights, selected_experts, router_logits = self.router(hidden_states)
329
- print(f"DEBUG: MoE forward mid - routing_weights shape: {routing_weights.shape}, selected_experts shape: {selected_experts.shape}")
330
-
 
 
331
  if self.norm_topk_prob:
332
- routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
333
-
 
 
 
334
  final_hidden_states = torch.zeros(
335
- (batch_size * seq_len, self.hidden_size),
336
- dtype=hidden_states.dtype,
337
- device=hidden_states.device
338
  )
339
 
340
  # One-hot expert mask
 
319
 
320
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
321
  print(f"DEBUG: MoE forward start - hidden_states shape: {hidden_states.shape}")
322
+ batch_size, seq_len, hidden_dim = hidden_states.shape
323
  print("absolute precision")
324
+ hidden_states = hidden_states.view(-1, hidden_dim)
325
 
326
  # Get routing weights and selected experts
327
  print(f"DEBUG: 123: {self.router(hidden_states).shape}")
328
  routing_weights, selected_experts, router_logits = self.router(hidden_states)
329
+ router_logits = self.gate(hidden_states)
330
+
331
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
332
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
333
  if self.norm_topk_prob:
334
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
335
+ # we cast back to the input dtype
336
+ routing_weights = routing_weights.to(hidden_states.dtype)
337
+ print(f"DEBUG: MoE forward mid - routing_weights shape: {routing_weights.shape}, selected_experts shape: {selected_experts.shape}")
338
+
339
  final_hidden_states = torch.zeros(
340
+ (batch_size * seq_len, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
 
 
341
  )
342
 
343
  # One-hot expert mask