Charlie81 committed on
Commit
170c7d7
·
1 Parent(s): be9d959

refactor sparse

Browse files
Files changed (1) hide show
  1. modeling_myolmoe.py +11 -12
modeling_myolmoe.py CHANGED
@@ -223,6 +223,7 @@ class MyOLMoERouting(nn.Module):
223
  self.hidden_size = config.hidden_size
224
  self.routing_type = getattr(config, "routing_type", "sparse")
225
  self.router_temperature = getattr(config, "router_temperature", 1.0)
 
226
 
227
  # Shared components
228
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
@@ -230,20 +231,13 @@ class MyOLMoERouting(nn.Module):
230
  # For non-deterministic routing
231
  self.gumbel_noise = getattr(config, "gumbel_noise", 0.1)
232
 
233
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
234
  batch_size, sequence_length, hidden_dim = hidden_states.shape
 
235
  hidden_states = hidden_states.view(-1, hidden_dim)
 
236
  router_logits = self.gate(hidden_states)
237
 
238
- # Always use softmax, even for "dense" routing
239
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
240
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
241
-
242
- if self.norm_topk_prob:
243
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
244
-
245
- routing_weights = routing_weights.to(hidden_states.dtype)
246
-
247
  if self.routing_type == "dense":
248
  # Dense routing - use all experts equally
249
  routing_weights = torch.ones_like(router_logits) / self.num_experts
@@ -262,11 +256,16 @@ class MyOLMoERouting(nn.Module):
262
 
263
  else: # Default sparse routing
264
  # Standard sparse top-k routing
265
- routing_weights = F.softmax(router_logits, dim=-1)
266
  routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
267
 
 
 
 
 
 
268
  return routing_weights, selected_experts, router_logits
269
-
270
  class OlmoeRotaryEmbedding(nn.Module):
271
  def __init__(self, config: OlmoeConfig, device=None):
272
  super().__init__()
 
223
  self.hidden_size = config.hidden_size
224
  self.routing_type = getattr(config, "routing_type", "sparse")
225
  self.router_temperature = getattr(config, "router_temperature", 1.0)
226
+ self.norm_topk_prob = getattr(config, "norm_topk_prob", False)
227
 
228
  # Shared components
229
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
 
231
  # For non-deterministic routing
232
  self.gumbel_noise = getattr(config, "gumbel_noise", 0.1)
233
 
234
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
235
  batch_size, sequence_length, hidden_dim = hidden_states.shape
236
+ print("TEST testtest123")
237
  hidden_states = hidden_states.view(-1, hidden_dim)
238
+ print("TEST 4564645testtest123")
239
  router_logits = self.gate(hidden_states)
240
 
 
 
 
 
 
 
 
 
 
241
  if self.routing_type == "dense":
242
  # Dense routing - use all experts equally
243
  routing_weights = torch.ones_like(router_logits) / self.num_experts
 
256
 
257
  else: # Default sparse routing
258
  # Standard sparse top-k routing
259
+ routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
260
  routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
261
 
262
+ if self.norm_topk_prob:
263
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
264
+
265
+ routing_weights = routing_weights.to(hidden_states.dtype)
266
+
267
  return routing_weights, selected_experts, router_logits
268
+
269
  class OlmoeRotaryEmbedding(nn.Module):
270
  def __init__(self, config: OlmoeConfig, device=None):
271
  super().__init__()