Commit ·
b70422e
1
Parent(s): 1b2e4da
Fixes
Browse files
- Model_Architecture/config.json +4 -4
- Model_Architecture/model.py +20 -18
Model_Architecture/config.json
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
{
|
| 2 |
"model": {
|
| 3 |
-
"max_batch_size":
|
| 4 |
-
"max_seq_len":
|
| 5 |
"dtype": "bf16",
|
| 6 |
"scale_fmt": null,
|
| 7 |
"vocab_size": 32768,
|
| 8 |
-
"dim":
|
| 9 |
"inter_dim": 4096,
|
| 10 |
-
"moe_inter_dim":
|
| 11 |
"n_layers": 20,
|
| 12 |
"n_dense_layers": 3,
|
| 13 |
"n_heads": 12,
|
|
|
|
| 1 |
{
|
| 2 |
"model": {
|
| 3 |
+
"max_batch_size": 4,
|
| 4 |
+
"max_seq_len": 1024,
|
| 5 |
"dtype": "bf16",
|
| 6 |
"scale_fmt": null,
|
| 7 |
"vocab_size": 32768,
|
| 8 |
+
"dim": 768,
|
| 9 |
"inter_dim": 4096,
|
| 10 |
+
"moe_inter_dim": 768,
|
| 11 |
"n_layers": 20,
|
| 12 |
"n_dense_layers": 3,
|
| 13 |
"n_heads": 12,
|
Model_Architecture/model.py
CHANGED
|
@@ -394,29 +394,31 @@ class MoE(nn.Module):
|
|
| 394 |
weights = weights / weights.sum(dim=-1, keepdim=True)
|
| 395 |
weights = weights * self.gate.route_scale
|
| 396 |
|
| 397 |
-
# Sequential Training Mode
|
|
|
|
| 398 |
if self.training and self.active_expert_idx is not None:
|
| 399 |
y = torch.zeros_like(x)
|
| 400 |
-
|
| 401 |
-
#
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
|
|
|
| 414 |
# Load balance loss (still needed for gate training)
|
| 415 |
lb_loss = self.compute_load_balance_loss(router_probs, indices)
|
| 416 |
-
|
| 417 |
-
# Shared experts always train
|
| 418 |
z = self.shared_experts(x)
|
| 419 |
-
|
| 420 |
return (y + z).view(original_shape), lb_loss
|
| 421 |
|
| 422 |
# Normal MoE Mode (inference or full training)
|
|
|
|
| 394 |
weights = weights / weights.sum(dim=-1, keepdim=True)
|
| 395 |
weights = weights * self.gate.route_scale
|
| 396 |
|
| 397 |
+
# Sequential Training Mode - MEMORY EFFICIENT
|
| 398 |
+
# ONLY compute forward pass for the active expert to save GPU memory
|
| 399 |
if self.training and self.active_expert_idx is not None:
|
| 400 |
y = torch.zeros_like(x)
|
| 401 |
+
|
| 402 |
+
# Run forward pass ONLY for the active expert
|
| 403 |
+
i = self.active_expert_idx
|
| 404 |
+
idx, top = torch.where(indices == i)
|
| 405 |
+
|
| 406 |
+
if idx.numel() > 0:
|
| 407 |
+
# Only this expert gets gradients and forward pass
|
| 408 |
+
expert_out = self.experts[i](x[idx])
|
| 409 |
+
y[idx] = expert_out * weights[idx, top, None]
|
| 410 |
+
|
| 411 |
+
# Inactive experts: Skip forward pass entirely (save memory!)
|
| 412 |
+
# Note: This means the model output will be degraded during training,
|
| 413 |
+
# but it's acceptable since we're training experts sequentially.
|
| 414 |
+
# The shared experts + active expert still provide reasonable outputs.
|
| 415 |
+
|
| 416 |
# Load balance loss (still needed for gate training)
|
| 417 |
lb_loss = self.compute_load_balance_loss(router_probs, indices)
|
| 418 |
+
|
| 419 |
+
# Shared experts always train (provides baseline performance)
|
| 420 |
z = self.shared_experts(x)
|
| 421 |
+
|
| 422 |
return (y + z).view(original_shape), lb_loss
|
| 423 |
|
| 424 |
# Normal MoE Mode (inference or full training)
|