mazesmazes committed · verified
Commit 0e61f3c · 1 Parent(s): 29c35ee

Training in progress - step 500

Files changed (2):
  1. model.safetensors +2 -2
  2. projectors.py +48 -68
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bfcf8f5279040512d65baa819b6fd783c141559fa591e3eb1de7d8ade6a05df0
-size 375027488
+oid sha256:c6cc1a109c001eab177849bb49fba0b584ab2bd29c03f209b74093d2cb9c1e9e
+size 509146304
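The roughly 134 MB jump in model.safetensors is consistent with the projector change in the diff below: each of the 8 experts goes from a two-matrix SimpleAdapter (with biases) to a three-matrix, bias-free SwiGLUExpert. A back-of-envelope check, assuming bf16 (2-byte) weights and an LLM hidden size of 2048 — neither value is stated in this commit, so treat this as a sanity sketch only:

# Hypothetical sizes: llm_dim=2048 and bf16 storage are assumptions, not from the diff.
llm_dim, hidden, num_experts, bytes_per_param = 2048, 4096, 8, 2
old_expert = 2 * llm_dim * hidden + hidden + llm_dim  # fc1 + fc2, with biases
new_expert = 3 * llm_dim * hidden                     # gate/up/down projections, bias=False
print(num_experts * (new_expert - old_expert) * bytes_per_param)
# 134119424, close to the observed delta 509146304 - 375027488 = 134118816 bytes.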
projectors.py CHANGED
@@ -76,33 +76,21 @@ import torch.nn.functional as F
 # MoE Projector (MOSA-style)
 # =============================================================================
 
-class RMSNorm(nn.Module):
-    """Standard RMSNorm for 2025 architectures."""
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
 
-    def forward(self, x):
-        var = torch.mean(x ** 2, dim=-1, keepdim=True)
-        x_normed = x * torch.rsqrt(var + self.eps)
-        return self.weight * x_normed
+class SwiGLUExpert(nn.Module):
+    """SwiGLU expert MLP."""
 
-class SimpleAdapter(nn.Module):
-    """
-    Updated Adapter:
-    1. Uses SiLU (better for LLM alignment).
-    2. Includes internal Norm (crucial for MoE stability).
-    """
-    def __init__(self, in_dim, hidden_dim, out_dim):
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
         super().__init__()
-        self.fc1 = nn.Linear(in_dim, hidden_dim)
-        self.act = nn.SiLU()  # Changed from ReLU to SiLU
-        self.fc2 = nn.Linear(hidden_dim, out_dim)
-        # Optional: Add Dropout if training on small datasets
-
-    def forward(self, x):
-        return self.fc2(self.act(self.fc1(x)))
+        # Bias=False is strictly preferred for MoE experts to reduce memory/compute
+        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
+        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
+        self.act = nn.SiLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+
 
 class MOSAProjector(nn.Module):
     def __init__(self, config):
@@ -112,8 +100,16 @@ class MOSAProjector(nn.Module):
         self.num_experts = getattr(config, "num_experts", None) or 8
         adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
 
+        # Auxiliary loss coefficients (same defaults as SharedMoE)
+        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
+        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
+
+        # Store router state for aux loss computation
+        self.last_router_logits = None
+        self.last_routing_weights = None
+
         # --- 1. Pre-Norms (CRITICAL for stability) ---
-        self.in_norm = RMSNorm(self.encoder_dim)
+        self.in_norm = LlamaRMSNorm(self.encoder_dim, eps=1e-8)
 
         # --- 2. Convolutional Subsampling (Stride 4) ---
         self.conv = nn.Sequential(
@@ -138,15 +134,15 @@ class MOSAProjector(nn.Module):
             nn.Linear(1280, self.num_experts),
         )
 
-        # --- 4. Experts ---
+        # --- 4. Experts (SwiGLU for LLM compatibility) ---
         self.experts = nn.ModuleList([
-            SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
+            SwiGLUExpert(self.llm_dim, adapter_hidden, self.llm_dim)
             for _ in range(self.num_experts)
         ])
 
         # --- 5. Output Norm ---
         # Projects often drift in magnitude; this clamps them before the LLM.
-        self.out_norm = RMSNorm(self.llm_dim)
+        self.out_norm = LlamaRMSNorm(self.llm_dim, eps=1e-8)
 
         self._init_weights()
 
@@ -163,12 +159,11 @@ class MOSAProjector(nn.Module):
         # Force the LAST router layer to be small (but not zero)
         nn.init.normal_(self.router[-1].weight, std=0.01)
 
-        # --- 2. Expert Initialization ---
+        # --- 2. Expert Initialization (SwiGLU) ---
         for expert in self.experts:
-            nn.init.kaiming_uniform_(expert.fc1.weight, a=math.sqrt(5))
-            nn.init.xavier_uniform_(expert.fc2.weight)
-            if expert.fc2.bias is not None:
-                nn.init.zeros_(expert.fc2.bias)
+            nn.init.orthogonal_(expert.gate_proj.weight)
+            nn.init.orthogonal_(expert.up_proj.weight)
+            nn.init.orthogonal_(expert.down_proj.weight, gain=0.5)
 
     def forward(self, x):
         # x: (B, S, 1280)
@@ -193,10 +188,14 @@ class MOSAProjector(nn.Module):
 
         # Router Logits
         router_logits = self.router(x_pooled) # (B, S//4, num_experts)
-
+
         # Softmax for Dense MoE (Soft Mixing)
         routing_weights = F.softmax(router_logits, dim=-1)
 
+        # Store for aux loss computation
+        self.last_router_logits = router_logits
+        self.last_routing_weights = routing_weights
+
         # --- 3. Expert Mixture (Dense Execution) ---
         # Warning: High VRAM usage. Runs all experts.
         # h_conv: (B, S//4, llm_dim)
@@ -218,8 +217,18 @@ class MOSAProjector(nn.Module):
         return padded // 4
 
     def get_aux_loss(self) -> torch.Tensor:
-        """MOSA uses only cross-entropy loss, so aux loss is 0."""
-        return torch.tensor(0.0, device=self.conv[0].weight.device)
+        """Compute auxiliary losses: load balancing + z-loss."""
+        if self.last_router_logits is None:
+            return torch.tensor(0.0, device=self.conv[0].weight.device)
+
+        # Flatten for loss computation: (B, S, E) -> (B*S, E)
+        logits_flat = self.last_router_logits.view(-1, self.num_experts)
+        probs_flat = self.last_routing_weights.view(-1, self.num_experts)
+
+        balance = load_balancing_loss(probs_flat, self.num_experts, top_k=self.num_experts)
+        z = z_loss(logits_flat)
+
+        return self.aux_loss_coef * balance + self.z_loss_coef * z
 
 # =============================================================================
 # SwiGLU Projector
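The new get_aux_loss depends on load_balancing_loss and z_loss helpers that live elsewhere in projectors.py and are not part of this diff. A minimal sketch of the standard formulations they presumably follow (Switch Transformer balance loss, ST-MoE router z-loss); the actual implementations in the repo may differ:

import torch

def load_balancing_loss(probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    # probs: (tokens, num_experts) softmax routing probabilities.
    # Switch-style balance: num_experts * sum_e(load_fraction_e * mean_prob_e).
    topk_idx = probs.topk(top_k, dim=-1).indices
    mask = torch.zeros_like(probs).scatter_(-1, topk_idx, 1.0)  # hard top-k assignment
    load_fraction = mask.mean(dim=0)   # fraction of tokens routed to each expert
    mean_prob = probs.mean(dim=0)      # average router probability per expert
    return num_experts * torch.sum(load_fraction * mean_prob)

def z_loss(logits: torch.Tensor) -> torch.Tensor:
    # Router z-loss: penalizes large logits so the softmax stays well-conditioned.
    return torch.logsumexp(logits, dim=-1).pow(2).mean()

Note that with top_k equal to num_experts (dense routing, as in this projector) the hard mask above is all ones; dense implementations often use the mean probability itself as the load fraction, which this sketch does not attempt to guess.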
@@ -340,13 +349,13 @@ class ResidualAudioProjector(nn.Module):
         dropout_rate = getattr(config, "projector_dropout", 0.0)
 
         self.input_proj = nn.Linear(in_dim, out_dim)
-        self.ln_input = LlamaRMSNorm(out_dim, eps=1e-6)
+        self.ln_input = LlamaRMSNorm(out_dim, eps=1e-8)
 
         self.layers = nn.ModuleList(
             [ResidualMLP(out_dim, hidden_dim, dropout=dropout_rate) for _ in range(self.num_layers)]
         )
         self.layer_norms = nn.ModuleList(
-            [LlamaRMSNorm(out_dim, eps=1e-6) for _ in range(self.num_layers)]
+            [LlamaRMSNorm(out_dim, eps=1e-8) for _ in range(self.num_layers)]
         )
 
         self.output_dropout = nn.Dropout(dropout_rate)
@@ -408,35 +417,6 @@ class ResidualAudioProjector(nn.Module):
 # =============================================================================
 
 
-class RMSNorm(nn.Module):
-    """RMS Normalization (SOTA normalization for transformers)."""
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        var = x.pow(2).mean(-1, keepdim=True)
-        x_normed = x * torch.rsqrt(var + self.eps)
-        return self.weight * x_normed
-
-
-class SwiGLUExpert(nn.Module):
-    """SwiGLU expert MLP."""
-
-    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
-        super().__init__()
-        # Bias=False is strictly preferred for MoE experts to reduce memory/compute
-        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
-        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
-        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
-        self.act = nn.SiLU()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
-
-
 class SharedMoEBlock(nn.Module):
     """MoE block with Shared + Sigmoid-Routed Experts."""
 
@@ -454,7 +434,7 @@ class SharedMoEBlock(nn.Module):
         self.output_dim = output_dim
 
         # RMSNorm before routing
-        self.norm = RMSNorm(input_dim)
+        self.norm = LlamaRMSNorm(input_dim, eps=1e-8)
 
         self.router = nn.Linear(input_dim, num_experts, bias=False)
         nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
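For reference, the "Expert Mixture (Dense Execution)" step mentioned in the forward hunk above sits just below the hunk boundary and is not shown in this diff. A dense (soft) mixture is typically a probability-weighted sum over every expert's output; a generic sketch, not necessarily the exact code in this repo:

import torch
import torch.nn as nn

def dense_moe_mix(h: torch.Tensor, routing_weights: torch.Tensor, experts: nn.ModuleList) -> torch.Tensor:
    # h: (B, T, d); routing_weights: (B, T, E), softmax-normalized over E.
    expert_out = torch.stack([expert(h) for expert in experts], dim=-2)  # (B, T, E, d)
    return (routing_weights.unsqueeze(-1) * expert_out).sum(dim=-2)      # (B, T, d)

This is what the VRAM warning in the diff refers to: every expert runs on every frame, so compute and activation memory scale linearly with num_experts.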
 
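Because get_aux_loss now returns a real regularizer instead of a constant zero, the training step must add it to the cross-entropy objective, or the stored router statistics are never used. A hypothetical sketch; model, batch, and projector are illustrative names, not taken from this repo:

outputs = model(**batch)                    # forward pass populates last_router_logits
ce_loss = outputs.loss                      # cross-entropy from the LLM head
aux_loss = model.projector.get_aux_loss()   # balance + z-loss, already scaled by the coefficients
(ce_loss + aux_loss).backward()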