Training in progress - step 500
projectors.py  CHANGED  (+23 -24)
@@ -77,19 +77,17 @@ import torch.nn.functional as F
 # =============================================================================
 
 
-class
-    """
+class SimpleAdapter(nn.Module):
+    """Simple 2-layer ReLU adapter (from MOSA paper)."""
 
     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
         super().__init__()
-
-        self.
-        self.
-        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
-        self.act = nn.SiLU()
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.act = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_dim, output_dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.
+        return self.fc2(self.act(self.fc1(x)))
 
 
 class MOSAProjector(nn.Module):
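The removed adapter (its lines are truncated in this view, though a SiLU activation and a bias-free down_proj are still visible, suggesting a gated SwiGLU-style MLP) is replaced by the plain two-layer ReLU MLP from the MOSA paper. A self-contained sketch of the new module with a quick shape check; the 3584 used for llm_dim is a stand-in, since that value does not appear in these hunks:

```python
import torch
import torch.nn as nn

class SimpleAdapter(nn.Module):
    """Simple 2-layer ReLU adapter (from MOSA paper)."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))

# Each expert maps llm_dim -> adapter_hidden -> llm_dim (4096 is the
# adapter_hidden default from the hunk below; 3584 is a placeholder llm_dim).
adapter = SimpleAdapter(input_dim=3584, hidden_dim=4096, output_dim=3584)
out = adapter(torch.randn(2, 16, 3584))
assert out.shape == (2, 16, 3584)
```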
@@ -100,9 +98,9 @@ class MOSAProjector(nn.Module):
         self.num_experts = getattr(config, "num_experts", None) or 8
         adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
 
-        # Auxiliary loss coefficients (
-        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.
-        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.
+        # Auxiliary loss coefficients (MOSA paper uses only cross-entropy, no aux losses)
+        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.0)
+        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.0)
 
         # Store router state for aux loss computation
         self.last_router_logits = None
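Both coefficients now default to 0.0, so the auxiliary terms are no-ops unless router_aux_loss_coef / router_z_loss_coef are set in the config, and training reduces to plain cross-entropy as the comment says. The hunks never show how the losses are derived from last_router_logits; a common formulation (Switch-Transformer-style load balancing plus a router z-loss) would look like the sketch below, which is an assumption, not this repo's code:

```python
import torch
import torch.nn.functional as F

def router_aux_losses(router_logits: torch.Tensor, num_experts: int):
    """Hypothetical aux losses from router logits of shape (tokens, num_experts)."""
    probs = F.softmax(router_logits, dim=-1)
    top1 = probs.argmax(dim=-1)                     # hard top-1 assignment per token
    f = F.one_hot(top1, num_experts).float().mean(dim=0)  # fraction routed to each expert
    P = probs.mean(dim=0)                           # mean router probability per expert
    load_balance = num_experts * torch.sum(f * P)   # minimized by a uniform split
    z_loss = torch.logsumexp(router_logits, dim=-1).pow(2).mean()  # keeps logits small
    return load_balance, z_loss
```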
@@ -119,24 +117,22 @@ class MOSAProjector(nn.Module):
             nn.SiLU(),
         )
 
-        # --- 3. Deep Router (
-        # Kept your deep architecture, but added Norms between heavy layers
-        # to prevent "dead neurons" in the router.
+        # --- 3. Deep Router (ReLU per MOSA paper) ---
         self.router = nn.Sequential(
             nn.Linear(self.encoder_dim, 2560),
-            nn.
+            nn.ReLU(),
             nn.Linear(2560, 5120),
-            nn.
+            nn.ReLU(),
             nn.Linear(5120, 2560),
-            nn.
+            nn.ReLU(),
             nn.Linear(2560, 1280),
-            nn.
+            nn.ReLU(),
             nn.Linear(1280, self.num_experts),
         )
 
-        # --- 4. Experts (
+        # --- 4. Experts (Simple 2-layer ReLU adapters per MOSA paper) ---
         self.experts = nn.ModuleList([
-
+            SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
             for _ in range(self.num_experts)
         ])
 
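forward() is cut off in this view (only the (B, S, 1280) shape comment survives below), so the routing scheme itself is not visible. One common pattern consistent with the modules defined above is dense soft mixing: softmax the router logits and take a weighted sum of all expert outputs. A runnable sketch under that assumption; input_proj stands in for the unnamed Sequential that must map encoder features to llm_dim, and the real code may well use top-k routing instead:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def soft_moe_combine(router: nn.Module, input_proj: nn.Module,
                     experts: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    # x: (B, S, encoder_dim), e.g. (B, S, 1280)
    logits = router(x)                                    # (B, S, num_experts)
    weights = F.softmax(logits, dim=-1)                   # soft routing weights
    h = input_proj(x)                                     # (B, S, llm_dim)
    outs = torch.stack([e(h) for e in experts], dim=-1)   # (B, S, llm_dim, E)
    return (outs * weights.unsqueeze(-2)).sum(dim=-1)     # (B, S, llm_dim)

# Smoke test with stand-in dims (encoder_dim=1280 as in the shape comment;
# llm_dim=32 and 4 experts just to keep it cheap).
router = nn.Linear(1280, 4)
proj = nn.Linear(1280, 32)
experts = nn.ModuleList(nn.Linear(32, 32) for _ in range(4))
y = soft_moe_combine(router, proj, experts, torch.randn(2, 5, 1280))
assert y.shape == (2, 5, 32)
```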
@@ -159,11 +155,14 @@ class MOSAProjector(nn.Module):
         # Force the LAST router layer to be small (but not zero)
         nn.init.normal_(self.router[-1].weight, std=0.01)
 
-        # --- 2. Expert Initialization (
+        # --- 2. Expert Initialization (Simple ReLU adapter) ---
         for expert in self.experts:
-            nn.init.
-            nn.init.
-
+            nn.init.kaiming_normal_(expert.fc1.weight, mode='fan_in', nonlinearity='relu')
+            nn.init.kaiming_normal_(expert.fc2.weight, mode='fan_in', nonlinearity='relu')
+            if expert.fc1.bias is not None:
+                nn.init.zeros_(expert.fc1.bias)
+            if expert.fc2.bias is not None:
+                nn.init.zeros_(expert.fc2.bias)
 
     def forward(self, x):
         # x: (B, S, 1280)
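With mode='fan_in' and nonlinearity='relu', kaiming_normal_ draws weights from N(0, 2 / fan_in), the scaling that keeps activation variance roughly constant through ReLU layers. (fc2 gets the same ReLU gain even though no nonlinearity follows it, which mildly inflates its output variance; harmless, but worth noting.) A quick check of the resulting std:

```python
import math
import torch
import torch.nn as nn

fc = nn.Linear(4096, 3584)   # fan_in = 4096, matching adapter_hidden above
nn.init.kaiming_normal_(fc.weight, mode='fan_in', nonlinearity='relu')
print(fc.weight.std().item())     # ~0.0221
print(math.sqrt(2 / 4096))        # 0.02209..., the target std
```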