Training in progress - step 500
projectors.py  CHANGED  +1 -23
@@ -140,29 +140,7 @@ class MOSAProjector(nn.Module):
         # Projects often drift in magnitude; this clamps them before the LLM.
         self.out_norm = LlamaRMSNorm(self.llm_dim, eps=1e-8)
 
-
-
-    def _init_weights(self):
-        # --- 1. Router Initialization ---
-        # The router is 5 layers deep. We need Kaiming Init to ensure
-        # gradients can penetrate to the first layer.
-        for module in self.router.modules():
-            if isinstance(module, nn.Linear):
-                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
-                if module.bias is not None:
-                    nn.init.zeros_(module.bias)
-
-        # Force the LAST router layer to be small (but not zero)
-        nn.init.normal_(self.router[-1].weight, std=0.01)
-
-        # --- 2. Expert Initialization (Simple ReLU adapter) ---
-        for expert in self.experts:
-            nn.init.kaiming_normal_(expert.fc1.weight, mode='fan_in', nonlinearity='relu')
-            nn.init.kaiming_normal_(expert.fc2.weight, mode='fan_in', nonlinearity='relu')
-            if expert.fc1.bias is not None:
-                nn.init.zeros_(expert.fc1.bias)
-            if expert.fc2.bias is not None:
-                nn.init.zeros_(expert.fc2.bias)
+        # Using PyTorch default initialization (like MOSA paper)
 
     def forward(self, x):
         # x: (B, S, 1280)
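For context on what the new comment relies on: with _init_weights removed, every nn.Linear in the router and the experts keeps PyTorch's built-in reset_parameters(). In recent PyTorch versions that default is Kaiming-uniform on the weight with a = sqrt(5), which works out to roughly U(-1/sqrt(fan_in), +1/sqrt(fan_in)), plus a bias drawn from the same range. A minimal sketch for comparison with the deleted scheme; the helper name default_linear_init and the 1280 -> 4096 sizes are illustrative, not taken from this repo:

    import math
    import torch.nn as nn

    def default_linear_init(linear: nn.Linear) -> None:
        # Approximation of nn.Linear.reset_parameters() in recent PyTorch:
        # Kaiming-uniform weights with a=sqrt(5), i.e. roughly
        # U(-1/sqrt(fan_in), +1/sqrt(fan_in)), and a bias from the same range.
        nn.init.kaiming_uniform_(linear.weight, a=math.sqrt(5))
        if linear.bias is not None:
            fan_in = linear.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0.0
            nn.init.uniform_(linear.bias, -bound, bound)

    # Example: a projector-sized layer (1280 -> llm_dim, here assumed 4096)
    # starts with weights in roughly +/- 0.028 under this default.
    layer = nn.Linear(1280, 4096)
    default_linear_init(layer)

Compared with the deleted code, the visible differences are that the router's final layer now starts at the same ~1/sqrt(fan_in) scale instead of std=0.01, and the expert weights are uniform rather than Kaiming-normal; the out_norm RMSNorm kept above still clamps the projector output scale before it reaches the LLM either way.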