mazesmazes committed
Commit 311a9ba · verified · 1 parent: 1535c89

Training in progress - step 500

Files changed (1)
projectors.py +23 -24
projectors.py CHANGED

@@ -77,19 +77,17 @@ import torch.nn.functional as F
 # =============================================================================
 
 
-class SwiGLUExpert(nn.Module):
-    """SwiGLU expert MLP."""
+class SimpleAdapter(nn.Module):
+    """Simple 2-layer ReLU adapter (from MOSA paper)."""
 
     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
         super().__init__()
-        # Bias=False is strictly preferred for MoE experts to reduce memory/compute
-        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
-        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
-        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
-        self.act = nn.SiLU()
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.act = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_dim, output_dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+        return self.fc2(self.act(self.fc1(x)))
 
 
 class MOSAProjector(nn.Module):
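
The new expert can be sanity-checked in isolation. A minimal usage sketch of the `SimpleAdapter` defined in the hunk above; the dimensions below are placeholders chosen for illustration (the real `llm_dim` comes from the model config, and 4096 matches the `adapter_hidden_dim` fallback visible in the next hunk):

```python
import torch
import torch.nn as nn

class SimpleAdapter(nn.Module):
    """Simple 2-layer ReLU adapter, as defined in the hunk above."""
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))

# Placeholder dims: 2048 stands in for llm_dim (read from config in the
# real code); 4096 matches the adapter_hidden_dim default.
adapter = SimpleAdapter(2048, 4096, 2048)
x = torch.randn(2, 16, 2048)   # (batch, seq, llm_dim)
print(adapter(x).shape)        # torch.Size([2, 16, 2048])
```
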
@@ -100,9 +98,9 @@ class MOSAProjector(nn.Module):
         self.num_experts = getattr(config, "num_experts", None) or 8
         adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
 
-        # Auxiliary loss coefficients (same defaults as SharedMoE)
-        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.001)
-        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
+        # Auxiliary loss coefficients (MOSA paper uses only cross-entropy, no aux losses)
+        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.0)
+        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.0)
 
         # Store router state for aux loss computation
         self.last_router_logits = None
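
The coefficients now default to 0.0 (previously 0.001), but `last_router_logits` is still cached, so the auxiliary-loss hooks stay in place. For context, a sketch of how such terms are conventionally computed from router logits, in the Switch Transformer style; the helper below is hypothetical, not code from this repo:

```python
import torch
import torch.nn.functional as F

def router_aux_losses(router_logits: torch.Tensor, num_experts: int):
    """Hypothetical helper: Switch-Transformer-style auxiliary losses.

    router_logits: (num_tokens, num_experts) raw router outputs; the
    results would be scaled by aux_loss_coef / z_loss_coef (both 0.0
    here) before being added to the training loss.
    """
    probs = F.softmax(router_logits, dim=-1)                  # (T, E)
    # Load balancing: fraction of tokens per expert (top-1 assignment)
    # times mean router probability per expert, summed over experts.
    assignment = F.one_hot(probs.argmax(dim=-1), num_experts).float()
    tokens_per_expert = assignment.mean(dim=0)                # f_e
    prob_per_expert = probs.mean(dim=0)                       # P_e
    load_balancing = num_experts * (tokens_per_expert * prob_per_expert).sum()
    # z-loss: penalize large logit magnitudes for numerical stability.
    z_loss = torch.logsumexp(router_logits, dim=-1).pow(2).mean()
    return load_balancing, z_loss
```
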
@@ -119,24 +117,22 @@ class MOSAProjector(nn.Module):
             nn.SiLU(),
         )
 
-        # --- 3. Deep Router (Standardized to SiLU) ---
-        # Kept your deep architecture, but added Norms between heavy layers
-        # to prevent "dead neurons" in the router.
+        # --- 3. Deep Router (ReLU per MOSA paper) ---
         self.router = nn.Sequential(
             nn.Linear(self.encoder_dim, 2560),
-            nn.SiLU(),
+            nn.ReLU(),
             nn.Linear(2560, 5120),
-            nn.SiLU(),
+            nn.ReLU(),
             nn.Linear(5120, 2560),
-            nn.SiLU(),
+            nn.ReLU(),
             nn.Linear(2560, 1280),
-            nn.SiLU(),
+            nn.ReLU(),
             nn.Linear(1280, self.num_experts),
         )
 
-        # --- 4. Experts (SwiGLU for LLM compatibility) ---
+        # --- 4. Experts (Simple 2-layer ReLU adapters per MOSA paper) ---
         self.experts = nn.ModuleList([
-            SwiGLUExpert(self.llm_dim, adapter_hidden, self.llm_dim)
+            SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
             for _ in range(self.num_experts)
         ])
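
This hunk defines the router stack and the expert list, but the dispatch that connects them lives in `forward`, outside the diff. As a rough guide, a sketch of one conventional top-k dispatch; note the router here scores encoder features (1280-d input) while the experts transform LLM-dim features, so the sketch keeps the two inputs separate. All names and the `top_k` value are assumptions, not code from this repo:

```python
import torch
import torch.nn.functional as F

def moe_dispatch(enc_feats, llm_feats, router, experts, top_k=2):
    """Hypothetical dispatch; MOSAProjector.forward itself is not shown here.

    enc_feats: (B, S, encoder_dim) features the router scores.
    llm_feats: (B, S, llm_dim) features the experts transform.
    """
    logits = router(enc_feats)                    # (B, S, num_experts)
    topw, topi = logits.topk(top_k, dim=-1)       # keep the k best experts
    topw = F.softmax(topw, dim=-1)                # renormalize their gates
    out = torch.zeros_like(llm_feats)             # experts map llm_dim -> llm_dim
    for e, expert in enumerate(experts):
        # Gate for expert e wherever it appears in the top-k, else 0.
        gate = (topw * (topi == e)).sum(dim=-1, keepdim=True)
        out = out + gate * expert(llm_feats)
    return out, logits                            # logits feed the aux losses
```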
 
@@ -159,11 +155,14 @@ class MOSAProjector(nn.Module):
         # Force the LAST router layer to be small (but not zero)
         nn.init.normal_(self.router[-1].weight, std=0.01)
 
-        # --- 2. Expert Initialization (SwiGLU) ---
+        # --- 2. Expert Initialization (Simple ReLU adapter) ---
         for expert in self.experts:
-            nn.init.orthogonal_(expert.gate_proj.weight)
-            nn.init.orthogonal_(expert.up_proj.weight)
-            nn.init.orthogonal_(expert.down_proj.weight, gain=0.5)
+            nn.init.kaiming_normal_(expert.fc1.weight, mode='fan_in', nonlinearity='relu')
+            nn.init.kaiming_normal_(expert.fc2.weight, mode='fan_in', nonlinearity='relu')
+            if expert.fc1.bias is not None:
+                nn.init.zeros_(expert.fc1.bias)
+            if expert.fc2.bias is not None:
+                nn.init.zeros_(expert.fc2.bias)
 
     def forward(self, x):
         # x: (B, S, 1280)
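
The move from orthogonal to Kaiming initialization matches the new ReLU experts: fan-in Kaiming scaling is designed to keep the second moment of activations roughly constant through a Linear + ReLU pair. A quick empirical check (a sketch, not part of the commit):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4096, 4096)
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
nn.init.zeros_(layer.bias)

x = torch.randn(8192, 4096)       # unit-scale input
y = torch.relu(layer(x))
# The sqrt(2/fan_in) gain gives pre-activations variance ~2; ReLU keeps
# half the mass, so the output second moment E[y^2] stays near 1.
print(x.pow(2).mean().item())     # ~1.0
print(y.pow(2).mean().item())     # ~1.0
```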
 