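"""Shared-expert MoE audio projector for tiny-audio.

A SwiGLU shared expert plus sparsely routed SwiGLU experts map pooled
audio-encoder frames into the LLM embedding space, with load-balancing
and z-loss auxiliary terms for the router.
"""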
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812


class SwiGLUExpert(nn.Module):
"""SwiGLU expert MLP (used for both shared and routed experts)."""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
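        # SwiGLU: SiLU-gated branch multiplied elementwise by the up projection, then projected down.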
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


class SharedMoEBlock(nn.Module):
"""MoE block with shared expert + sparse routed experts."""
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_experts: int = 4,
top_k: int = 2,
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.output_dim = output_dim
        # Router: zero-initialized so routing starts uniform over experts
self.router = nn.Linear(input_dim, num_experts, bias=False)
nn.init.zeros_(self.router.weight)
# Shared expert (always active)
self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
# Routed experts (sparse)
self.experts = nn.ModuleList(
[SwiGLUExpert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
)
# For auxiliary loss (cached to avoid recomputation)
self.last_router_logits = None
self.last_router_probs = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
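        # hidden_states: (batch, seq_len, input_dim)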
batch_size, seq_len, dim = hidden_states.shape
# Shared expert output (all tokens)
shared_out = self.shared_expert(hidden_states)
# Routing
flat_hidden = hidden_states.view(-1, dim)
router_logits = self.router(flat_hidden)
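        # Compute routing probabilities in float32 (inputs may be half precision)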
router_probs = F.softmax(router_logits.float(), dim=-1)
# Cache for aux loss
self.last_router_logits = router_logits
self.last_router_probs = router_probs
# Top-k selection and renormalization
top_k_weights, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
top_k_weights = top_k_weights.to(hidden_states.dtype)
# Routed expert output via token dispatch
routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
routed_out = routed_out.view(batch_size, seq_len, -1)
        # Combine: shared-expert baseline + routed experts (routed path grows in via the zero-init router and small-gain down_proj init)
return shared_out + routed_out

    def _dispatch_experts(
self,
hidden_states: torch.Tensor,
top_k_indices: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""Token dispatch - gather tokens per expert, process, scatter back."""
num_tokens = hidden_states.shape[0]
output = torch.zeros(
num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
)
for expert_idx, expert in enumerate(self.experts):
expert_mask = top_k_indices == expert_idx
if not expert_mask.any():
continue
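            # token_indices: token rows routed to this expert; slot_indices: which of their top-k slots selected it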
token_indices, slot_indices = torch.where(expert_mask)
expert_input = hidden_states[token_indices]
expert_output = expert(expert_input)
weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
output.index_add_(0, token_indices, expert_output * weights)
return output


def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
"""Auxiliary loss to encourage balanced expert usage."""
_, selected = torch.topk(router_probs, top_k, dim=-1)
expert_mask = F.one_hot(selected, num_experts).float()
tokens_per_expert = expert_mask.mean(dim=(0, 1))
prob_per_expert = router_probs.mean(dim=0)
return (tokens_per_expert * prob_per_expert).sum() * num_experts


def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
"""Z-loss to prevent router logits from growing too large."""
return torch.logsumexp(router_logits.float(), dim=-1).square().mean()


class SharedMoEAudioProjector(nn.Module):
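    """Audio-to-LLM projector: temporal pooling followed by a shared-expert MoE block."""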
def __init__(self, config):
super().__init__()
# Temporal downsampling
self.k = getattr(config, "projector_pool_stride", 4)
# Dimensions
encoder_dim = config.encoder_dim
in_dim = encoder_dim * self.k
out_dim = config.llm_dim
hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
# MoE config
self.num_experts = getattr(config, "num_experts", 4)
self.top_k = getattr(config, "num_experts_per_tok", 2)
self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
# Layers
self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
# Init
        self._init_weights()

    def _init_weights(self):
with torch.no_grad():
# Shared expert - orthogonal init for stable condition numbers
nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
nn.init.orthogonal_(self.moe.shared_expert.down_proj.weight, gain=0.5)
# Routed experts - orthogonal for gate/up, tiny orthogonal for down (grow-in)
# gain=0.01 gives ~1% initial contribution while maintaining good conditioning
for expert in self.moe.experts:
nn.init.orthogonal_(expert.gate_proj.weight)
nn.init.orthogonal_(expert.up_proj.weight)
nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
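        # x: (batch, num_frames, encoder_dim) audio-encoder features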
batch_size, seq_len, dim = x.size()
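        # Match the input dtype to the projector weights (e.g., under mixed precision)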
target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
if x.dtype != target_dtype:
x = x.to(target_dtype)
# Pad for pooling (at most k-1 frames -> 1 extra token, negligible impact)
if seq_len % self.k:
x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
        # Temporal pooling: concatenate k consecutive frames along the feature dimension
x = x.view(batch_size, -1, dim * self.k)
return self.moe(x)

    def get_aux_loss(self) -> torch.Tensor:
"""Get auxiliary losses (call after forward)."""
if self.moe.last_router_logits is None:
return torch.tensor(0.0, device=self.moe.router.weight.device)
balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
z = z_loss(self.moe.last_router_logits)
return self.aux_loss_coef * balance + self.z_loss_coef * z
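

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original training code).
    # The config attribute names match the lookups above; the dimension values
    # are placeholders, not the ones used in training.
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        encoder_dim=768,            # placeholder audio-encoder hidden size
        llm_dim=2048,               # placeholder LLM embedding size
        projector_pool_stride=4,
        projector_hidden_dim=None,  # falls back to encoder_dim * stride
        num_experts=4,
        num_experts_per_tok=2,
        router_aux_loss_coef=0.02,
        router_z_loss_coef=0.001,
    )

    projector = SharedMoEAudioProjector(demo_config)
    frames = torch.randn(2, 50, demo_config.encoder_dim)  # (batch, frames, encoder_dim)
    projected = projector(frames)                         # (batch, ceil(50 / 4), llm_dim)
    aux = projector.get_aux_loss()
    print(projected.shape, aux.item())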