mazesmazes committed
Commit c20eb37 (verified) · 1 parent: 7bc28c3

Training in progress - step 500

Files changed (1)
projectors.py: +1 -23
projectors.py CHANGED
@@ -140,29 +140,7 @@ class MOSAProjector(nn.Module):
         # Projections often drift in magnitude; this clamps them before the LLM.
         self.out_norm = LlamaRMSNorm(self.llm_dim, eps=1e-8)
 
-        self._init_weights()
-
-    def _init_weights(self):
-        # --- 1. Router Initialization ---
-        # The router is 5 layers deep. We need Kaiming init to ensure
-        # gradients can penetrate to the first layer.
-        for module in self.router.modules():
-            if isinstance(module, nn.Linear):
-                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
-                if module.bias is not None:
-                    nn.init.zeros_(module.bias)
-
-        # Force the LAST router layer to be small (but not zero).
-        nn.init.normal_(self.router[-1].weight, std=0.01)
-
-        # --- 2. Expert Initialization (simple ReLU adapter) ---
-        for expert in self.experts:
-            nn.init.kaiming_normal_(expert.fc1.weight, mode='fan_in', nonlinearity='relu')
-            nn.init.kaiming_normal_(expert.fc2.weight, mode='fan_in', nonlinearity='relu')
-            if expert.fc1.bias is not None:
-                nn.init.zeros_(expert.fc1.bias)
-            if expert.fc2.bias is not None:
-                nn.init.zeros_(expert.fc2.bias)
+        # Using PyTorch default initialization (like the MOSA paper)
 
     def forward(self, x):
         # x: (B, S, 1280)
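For reference, "PyTorch default initialization" for the nn.Linear layers this diff touches means the scheme applied by nn.Linear.reset_parameters(): Kaiming-uniform on the weight with a=sqrt(5), plus a fan-in-scaled uniform bias. A minimal sketch of the equivalent manual calls (the 1280 -> 4096 shape is illustrative, not taken from this repo, and _calculate_fan_in_and_fan_out is a private PyTorch helper):

import math
import torch.nn as nn

# Hypothetical standalone layer; the shapes are illustrative only.
layer = nn.Linear(1280, 4096)

# Equivalent to what nn.Linear.reset_parameters() already did at construction:
nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
if layer.bias is not None:
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight)
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    nn.init.uniform_(layer.bias, -bound, bound)

Net effect of the commit: the experts still get a Kaiming-family (uniform) init by default, but the router's last layer no longer starts deliberately small (std=0.01) and biases are no longer forced to zero.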