mazesmazes committed · verified
Commit 4bc741d · Parent(s): 9c03397

Training in progress - step 500

Files changed (1): projectors.py (+75 -51)
projectors.py CHANGED
@@ -76,19 +76,31 @@ import torch.nn.functional as F
 # MoE Projector (MOSA-style)
 # =============================================================================
 
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
+class RMSNorm(nn.Module):
+    """Standard RMSNorm for 2025 architectures."""
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
 
-class SimpleAdapter(nn.Module):
-    """Simple 2-layer adapter with ReLU (as per MOSA paper)."""
+    def forward(self, x):
+        var = torch.mean(x ** 2, dim=-1, keepdim=True)
+        x_normed = x * torch.rsqrt(var + self.eps)
+        return self.weight * x_normed
 
+class SimpleAdapter(nn.Module):
+    """
+    Updated Adapter:
+    1. Uses SiLU (better for LLM alignment).
+    2. Relies on the projector's RMSNorm layers for MoE stability.
+    """
     def __init__(self, in_dim, hidden_dim, out_dim):
         super().__init__()
         self.fc1 = nn.Linear(in_dim, hidden_dim)
-        self.act = nn.ReLU()
+        self.act = nn.SiLU()  # Changed from ReLU to SiLU
        self.fc2 = nn.Linear(hidden_dim, out_dim)
-
+        # Optional: Add Dropout if training on small datasets
+
     def forward(self, x):
         return self.fc2(self.act(self.fc1(x)))
 
@@ -100,7 +112,10 @@ class MOSAProjector(nn.Module):
         self.num_experts = getattr(config, "num_experts", None) or 8
         adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
 
-        # 1. Convolutional Subsampling (Stride 4 total)
+        # --- 1. Pre-Norms (CRITICAL for stability) ---
+        self.in_norm = RMSNorm(self.encoder_dim)
+
+        # --- 2. Convolutional Subsampling (Stride 4) ---
         self.conv = nn.Sequential(
             nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
             nn.SiLU(),
@@ -108,87 +123,93 @@ class MOSAProjector(nn.Module):
             nn.SiLU(),
         )
 
-        # 2. Router (MOSA-Large: 1280 -> 2560 -> 5120 -> 2560 -> 1280 -> num_experts)
-        # Deep router with ReLU for better expert sparsity (as per paper)
-        # Router operates on pooled features (same receptive field as conv)
+        # --- 3. Deep Router (Standardized to SiLU) ---
+        # Keeps the deep architecture; SiLU between the heavy layers
+        # helps prevent "dead neurons" in the router.
         self.router = nn.Sequential(
             nn.Linear(self.encoder_dim, 2560),
-            nn.ReLU(),
+            nn.SiLU(),
             nn.Linear(2560, 5120),
-            nn.ReLU(),
+            nn.SiLU(),
             nn.Linear(5120, 2560),
-            nn.ReLU(),
+            nn.SiLU(),
             nn.Linear(2560, 1280),
-            nn.ReLU(),
+            nn.SiLU(),
             nn.Linear(1280, self.num_experts),
         )
 
-        # 3. Experts
+        # --- 4. Experts ---
         self.experts = nn.ModuleList([
             SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
             for _ in range(self.num_experts)
         ])
+
+        # --- 5. Output Norm ---
+        # Projector outputs often drift in magnitude; this clamps them before the LLM.
+        self.out_norm = RMSNorm(self.llm_dim)
 
         self._init_weights()
 
     def _init_weights(self):
-        """Initialize weights for stable training."""
-        for m in self.modules():
-            if isinstance(m, (nn.Linear, nn.Conv1d)):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
-                if m.bias is not None:
-                    nn.init.zeros_(m.bias)
-
-        # Scale down expert output projections for stable residual-like behavior
+        # --- 1. Router Initialization ---
+        # The router is 5 layers deep. We need Kaiming Init to ensure
+        # gradients can penetrate to the first layer.
+        for module in self.router.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+
+        # Force the LAST router layer to be small (but not zero)
+        nn.init.normal_(self.router[-1].weight, std=0.01)
+
+        # --- 2. Expert Initialization ---
         for expert in self.experts:
-            with torch.no_grad():
-                expert.fc2.weight.data.mul_(0.1)
-
-        # Initialize final router layer with small weights for uniform initial routing
-        # This prevents one expert from dominating at the start of training
-        with torch.no_grad():
-            final_router_layer = self.router[-1]  # Last linear layer
-            nn.init.normal_(final_router_layer.weight, mean=0.0, std=0.01)
-            if final_router_layer.bias is not None:
-                nn.init.zeros_(final_router_layer.bias)
+            nn.init.kaiming_uniform_(expert.fc1.weight, a=math.sqrt(5))
+            nn.init.xavier_uniform_(expert.fc2.weight)
+            if expert.fc2.bias is not None:
+                nn.init.zeros_(expert.fc2.bias)
 
     def forward(self, x):
         # x: (B, S, 1280)
         batch_size, seq_len, _ = x.shape
+
+        # Apply Input Norm
+        x = self.in_norm(x)
 
         # --- 1. Conv Branch ---
-        # Downsample: S -> S//4, expand: 1280 -> llm_dim
-        x_trans = x.permute(0, 2, 1)  # (B, 1280, S)
+        x_trans = x.permute(0, 2, 1)  # (B, D, S)
         h_conv = self.conv(x_trans).permute(0, 2, 1)  # (B, S//4, llm_dim)
 
         # --- 2. Router Branch ---
-        # Pool input BEFORE routing so router sees same receptive field as conv
-        # This is more principled than post-hoc averaging of per-frame decisions
         pad_amt = (4 - (seq_len % 4)) % 4
         if pad_amt > 0:
-            x_padded = F.pad(x, (0, 0, 0, pad_amt))  # Pad sequence dim
+            x_padded = F.pad(x, (0, 0, 0, pad_amt))
         else:
            x_padded = x
 
-        # Average pool to match conv stride (B, S, 1280) -> (B, S//4, 1280)
-        x_pooled = x_padded.view(batch_size, -1, 4, self.encoder_dim).mean(dim=2)
+        # Mean pool to align receptive fields
+        x_pooled = x_padded.view(batch_size, -1, 4, self.encoder_dim).mean(dim=2)  # (B, S//4, D)
 
-        # Router makes 1 informed decision per pooled token
+        # Router Logits
         router_logits = self.router(x_pooled)  # (B, S//4, num_experts)
+
+        # Softmax for Dense MoE (Soft Mixing)
         routing_weights = F.softmax(router_logits, dim=-1)
 
-        # --- 3. Expert Mixture ---
-        # expert_outputs shape: (num_experts, B, S//4, llm_dim)
-        expert_outputs = torch.stack([expert(h_conv) for expert in self.experts])
+        # --- 3. Expert Mixture (Dense Execution) ---
+        # Warning: High VRAM usage. Runs all experts.
+        # h_conv: (B, S//4, llm_dim)
+
+        # Stack approach is clean but memory hungry.
+        # Checkpointing could be added here if OOM occurs.
+        expert_outputs = torch.stack([expert(h_conv) for expert in self.experts])  # (E, B, S//4, D)
 
-        # Weighted sum of experts: (B, S//4, llm_dim)
+        # Weighted Sum
+        # (Experts, Batch, Seq, Dim) * (Batch, Seq, Experts) -> (Batch, Seq, Dim)
         final_out = torch.einsum('ebsd, bse -> bsd', expert_outputs, routing_weights)
 
-        return final_out
-
-    def get_aux_loss(self) -> torch.Tensor:
-        """MOSA uses only cross-entropy loss, so aux loss is 0."""
-        return torch.tensor(0.0, device=self.conv[0].weight.device)
+        return self.out_norm(final_out)
 
     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length given input length."""
@@ -196,6 +217,9 @@ class MOSAProjector(nn.Module):
         padded = input_length + (4 - input_length % 4) % 4
         return padded // 4
 
+    def get_aux_loss(self) -> torch.Tensor:
+        """MOSA uses only cross-entropy loss, so aux loss is 0."""
+        return torch.tensor(0.0, device=self.conv[0].weight.device)
 
 # =============================================================================
 # SwiGLU Projector
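
One note on the new init code: `_init_weights` now calls `math.sqrt(5)`, and none of the hunks above shows an `import math` for projectors.py, so that import must already exist elsewhere in the file. Separately, the router branch's pad-then-pool arithmetic lines up with `get_output_length`; a minimal sketch (dummy tensors, illustrative shapes, not part of the commit):

import torch
import torch.nn.functional as F

# Illustrative shapes: B=2 utterances, S=10 encoder frames, D=1280
x = torch.randn(2, 10, 1280)

pad_amt = (4 - (10 % 4)) % 4               # 2 frames of right-padding
x_padded = F.pad(x, (0, 0, 0, pad_amt))    # pads the sequence dim -> (2, 12, 1280)
x_pooled = x_padded.view(2, -1, 4, 1280).mean(dim=2)

print(x_pooled.shape)                      # torch.Size([2, 3, 1280])
print((10 + (4 - 10 % 4) % 4) // 4)        # 3, matching get_output_length(10)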
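The `torch.einsum('ebsd, bse -> bsd', ...)` mixture is simply a routing-weighted sum over experts. A self-contained check with random tensors (shapes are illustrative, not from the commit):

import torch

E, B, S, D = 8, 2, 3, 16                   # experts, batch, pooled seq, dim
expert_outputs = torch.randn(E, B, S, D)
routing_weights = torch.softmax(torch.randn(B, S, E), dim=-1)

mix = torch.einsum('ebsd, bse -> bsd', expert_outputs, routing_weights)
ref = sum(routing_weights[..., e].unsqueeze(-1) * expert_outputs[e] for e in range(E))

print(torch.allclose(mix, ref, atol=1e-6))  # True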
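The dense-execution comment leaves checkpointing as a fallback "if OOM occurs". One possible shape for that, as a hedged sketch: `mix_experts_checkpointed` is a hypothetical helper, not part of the commit, using `torch.utils.checkpoint.checkpoint` to recompute each expert's activations during backward instead of storing them:

import torch
from torch.utils.checkpoint import checkpoint

def mix_experts_checkpointed(experts, h_conv, routing_weights):
    # Hypothetical OOM fallback: trade extra compute for lower peak memory
    # by recomputing each expert's activations in the backward pass.
    expert_outputs = torch.stack([
        checkpoint(expert, h_conv, use_reentrant=False)
        for expert in experts
    ])  # (E, B, S//4, D)
    return torch.einsum('ebsd, bse -> bsd', expert_outputs, routing_weights)

Under torch.no_grad() checkpointing saves nothing, so the plain torch.stack path in the commit is equivalent at inference time.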