import torch
import torch.nn as nn
import torch.nn.functional as F

from ..config import AetherisConfig
from .expert import Expert
class SparseMoELayer(nn.Module):
    """Memory-optimized Sparse MoE with efficient routing."""

    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.d_model = config.d_model
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.load_balancing_coef = config.load_balancing_coef
        self.z_loss_coef = config.router_z_loss_coef

        # Router: produces one logit per expert for each token
        self.gate = nn.Linear(config.d_model, config.num_experts, bias=False)
        # Pool of feed-forward experts
        self.experts = nn.ModuleList(
            [Expert(config.d_model, config.d_ff) for _ in range(config.num_experts)]
        )
        # Pre-norm applied before routing and expert computation
        self.norm = nn.LayerNorm(config.d_model)
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        B, L, D = x.shape
        x_norm = self.norm(x)
        flat_x = x_norm.view(-1, D)

        # Routing logits, clamped to prevent overflow
        gate_logits = self.gate(flat_x)
        gate_logits = torch.clamp(gate_logits, min=-10.0, max=10.0)

        # Router z-loss for stability (discourages large logits)
        z_loss = torch.mean(torch.logsumexp(gate_logits, dim=-1) ** 2) * self.z_loss_coef

        if self.training:
            # Small exploration noise during training (kept tiny for stability)
            gate_logits = gate_logits + torch.randn_like(gate_logits) * 1e-3

        gate_probs = F.softmax(gate_logits, dim=-1)
        gate_weights, expert_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        # Renormalize the top-k weights so they sum to 1
        gate_weights = gate_weights / (gate_weights.sum(dim=-1, keepdim=True) + 1e-8)

        # Load-balancing loss: num_experts * sum_i(fraction_routed_i * mean_prob_i).
        # Use only the top-1 expert for the routed fraction to keep it simple and consistent.
        expert_mask = F.one_hot(expert_indices[:, 0], num_classes=self.num_experts).float()
        fraction_routed = expert_mask.mean(dim=0)
        mean_prob = gate_probs.mean(dim=0)
        aux_loss = (self.num_experts * torch.sum(fraction_routed * mean_prob)) * self.load_balancing_coef
        total_aux_loss = aux_loss + z_loss
        # Efficient dispatch with in-place accumulation.
        # Accumulate in float32 to prevent overflow during aggregation.
        final_output = torch.zeros_like(flat_x, dtype=torch.float32)

        # Iterate over all k selected experts
        for k_idx in range(self.top_k):
            for i, expert in enumerate(self.experts):
                # Find tokens routed to expert 'i' at the k-th position
                mask = expert_indices[:, k_idx] == i
                if not mask.any():
                    continue

                expert_input = flat_x[mask]
                expert_out = expert(expert_input)

                # Apply the corresponding gate weights
                weights = gate_weights[mask, k_idx].unsqueeze(1)

                # Cast to float32 for accumulation
                expert_out = expert_out.to(torch.float32)
                weights = weights.to(torch.float32)

                # Accumulate output (adds to results from other experts)
                final_output[mask] += expert_out * weights

        # Cast back to the original dtype and apply the residual connection
        final_output = final_output.to(flat_x.dtype)
        return x + final_output.view(B, L, D), total_aux_loss
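

# --- Usage sketch (assumptions labeled below) ---
# A minimal smoke test for the layer. `types.SimpleNamespace` stands in for
# AetherisConfig, since only the attributes read above are needed; the real
# AetherisConfig may expose more fields and a different constructor. The
# hyperparameter values are illustrative, not taken from the project. Because
# this module uses relative imports, run it via the package
# (e.g. `python -m <package>.<this_module>`) so they resolve.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        d_model=64,
        d_ff=256,
        num_experts=4,
        top_k=2,
        load_balancing_coef=0.01,
        router_z_loss_coef=0.001,
    )
    layer = SparseMoELayer(cfg)

    x = torch.randn(2, 16, cfg.d_model)  # (batch, seq_len, d_model)
    out, aux_loss = layer(x)
    assert out.shape == x.shape

    # During training, the auxiliary loss is typically added to the task loss:
    #   total_loss = task_loss + aux_loss
    print(out.shape, aux_loss.item())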