ikaganacar committed on
Commit
b70422e
·
1 Parent(s): 1b2e4da
Model_Architecture/config.json CHANGED
@@ -1,13 +1,13 @@
1
  {
2
  "model": {
3
- "max_batch_size": 8,
4
- "max_seq_len": 2048,
5
  "dtype": "bf16",
6
  "scale_fmt": null,
7
  "vocab_size": 32768,
8
- "dim": 1024,
9
  "inter_dim": 4096,
10
- "moe_inter_dim": 1024,
11
  "n_layers": 20,
12
  "n_dense_layers": 3,
13
  "n_heads": 12,
 
1
  {
2
  "model": {
3
+ "max_batch_size": 4,
4
+ "max_seq_len": 1024,
5
  "dtype": "bf16",
6
  "scale_fmt": null,
7
  "vocab_size": 32768,
8
+ "dim": 768,
9
  "inter_dim": 4096,
10
+ "moe_inter_dim": 768,
11
  "n_layers": 20,
12
  "n_dense_layers": 3,
13
  "n_heads": 12,
Model_Architecture/model.py CHANGED
@@ -394,29 +394,31 @@ class MoE(nn.Module):
394
  weights = weights / weights.sum(dim=-1, keepdim=True)
395
  weights = weights * self.gate.route_scale
396
 
397
- # Sequential Training Mode
 
398
  if self.training and self.active_expert_idx is not None:
399
  y = torch.zeros_like(x)
400
-
401
- # Only compute gradients for active expert
402
- for i in range(self.n_routed_experts):
403
- idx, top = torch.where(indices == i)
404
- if idx.numel() == 0:
405
- continue
406
-
407
- # Use gradient context manager
408
- grad_context = nullcontext() if i == self.active_expert_idx else torch.no_grad()
409
-
410
- with grad_context:
411
- expert_out = self.experts[i](x[idx])
412
- y[idx] += expert_out * weights[idx, top, None]
413
-
 
414
  # Load balance loss (still needed for gate training)
415
  lb_loss = self.compute_load_balance_loss(router_probs, indices)
416
-
417
- # Shared experts always train
418
  z = self.shared_experts(x)
419
-
420
  return (y + z).view(original_shape), lb_loss
421
 
422
  # Normal MoE Mode (inference or full training)
 
394
  weights = weights / weights.sum(dim=-1, keepdim=True)
395
  weights = weights * self.gate.route_scale
396
 
397
+ # Sequential Training Mode - MEMORY EFFICIENT
398
+ # ONLY compute forward pass for the active expert to save GPU memory
399
  if self.training and self.active_expert_idx is not None:
400
  y = torch.zeros_like(x)
401
+
402
+ # Run forward pass ONLY for the active expert
403
+ i = self.active_expert_idx
404
+ idx, top = torch.where(indices == i)
405
+
406
+ if idx.numel() > 0:
407
+ # Only this expert gets gradients and forward pass
408
+ expert_out = self.experts[i](x[idx])
409
+ y[idx] = expert_out * weights[idx, top, None]
410
+
411
+ # Inactive experts: Skip forward pass entirely (save memory!)
412
+ # Note: This means the model output will be degraded during training,
413
+ # but it's acceptable since we're training experts sequentially.
414
+ # The shared experts + active expert still provide reasonable outputs.
415
+
416
  # Load balance loss (still needed for gate training)
417
  lb_loss = self.compute_load_balance_loss(router_probs, indices)
418
+
419
+ # Shared experts always train (provides baseline performance)
420
  z = self.shared_experts(x)
421
+
422
  return (y + z).view(original_shape), lb_loss
423
 
424
  # Normal MoE Mode (inference or full training)