import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """SwiGLU feed-forward expert computing ``dropout(w2(silu(w1(x)) * w3(x)))``.

    Args:
        dim: Size of the last (feature) dimension of the input and output.
        hidden_dim: Size of the gated hidden projection.
        dropout: Output dropout probability; ``0`` replaces dropout with identity.
        bias: Whether the three linear projections use a bias term.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        bias: bool = False,
    ):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self._init_weights()

    def _init_weights(self) -> None:
        """Draw projection weights from N(0, 0.02^2) and zero any biases."""
        for module in (self.w1, self.w2, self.w3):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(..., dim)`` to ``(..., dim)``; ``silu(w1 x)`` gates ``w3 x``."""
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class TopKRouter(nn.Module):
    """Noisy top-k router returning per-token expert choices and an auxiliary loss.

    Args:
        dim: Input feature size.
        num_experts: Number of experts to route between.
        top_k: Number of experts selected per token.
        capacity_factor: Stored on the module; the router itself does not use it.
        noise_std: Std of the Gaussian noise added to the logits in training mode.
        use_expert_capacity: Stored on the module; the router itself does not use it.
        router_z_loss_coef: Weight of the router z-loss in the auxiliary loss.
        router_aux_loss_coef: Weight of the load-balancing loss in the auxiliary loss.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int = 2,
        capacity_factor: float = 1.25,
        noise_std: float = 1.0,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.noise_std = noise_std
        self.use_expert_capacity = use_expert_capacity
        self.router_z_loss_coef = router_z_loss_coef
        self.router_aux_loss_coef = router_aux_loss_coef
        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)

    def _compute_routing_weights(
        self,
        logits: torch.Tensor,
        use_noise: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the top-k experts per row of ``logits`` and normalize their weights.

        Noise is added only when ``use_noise`` is true AND the module is in
        training mode. Gates are a softmax over the selected (possibly noisy)
        logits, so every row of gates sums to 1.

        Returns:
            ``(top_k_gates, top_k_indices)``, each of shape ``(..., top_k)``.
        """
        if use_noise and self.training:
            logits = logits + torch.randn_like(logits) * self.noise_std
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        top_k_gates = F.softmax(top_k_logits, dim=-1)
        return top_k_gates, top_k_indices

    def _compute_auxiliary_loss(
        self,
        logits: torch.Tensor,
        top_k_indices: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(load_balance_loss, z_loss)`` for 2-D routing logits.

        Args:
            logits: Router logits of shape ``(num_tokens, num_experts)``
                (``forward`` passes the noise-free logits).
            top_k_indices: Selected experts of shape ``(num_tokens, top_k)``.

        The load-balancing term is ``num_experts * sum_i(P_i * f_i)``, where
        ``P_i`` is the mean router probability of expert ``i`` and ``f_i`` its
        share of the ``num_tokens * top_k`` assignments. The z-loss is the mean
        of the squared log-partition ``logsumexp(logits)``.
        """
        num_tokens = logits.shape[0]
        router_probs = F.softmax(logits, dim=-1)
        expert_probs = router_probs.mean(dim=0)
        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
        expert_freq = expert_mask.sum(dim=[0, 1]) / (num_tokens * self.top_k)
        load_balance_loss = self.num_experts * torch.sum(expert_probs * expert_freq)
        # Router z-loss: penalize the log-partition function rather than the raw
        # logits, so only the overall logit scale (not their differences) is pushed down.
        z_loss = torch.logsumexp(logits, dim=-1).pow(2).mean()
        return load_balance_loss, z_loss

    def forward(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route ``x``; in training, ``x`` must be 2-D ``(num_tokens, dim)``.

        Returns:
            ``(top_k_gates, top_k_indices, auxiliary_loss)``. The auxiliary loss
            is ``router_aux_loss_coef * load_balance + router_z_loss_coef * z_loss``
            in training mode and a zero scalar otherwise.
        """
        logits = self.gate(x)
        top_k_gates, top_k_indices = self._compute_routing_weights(
            logits, use_noise=self.training
        )
        if self.training:
            load_balance_loss, z_loss = self._compute_auxiliary_loss(logits, top_k_indices)
            auxiliary_loss = (
                self.router_aux_loss_coef * load_balance_loss
                + self.router_z_loss_coef * z_loss
            )
        else:
            auxiliary_loss = torch.tensor(0.0, device=x.device)
        return top_k_gates, top_k_indices, auxiliary_loss


class MixtureOfExperts(nn.Module):
    """Sparse MoE feed-forward block with a noisy top-k router and expert capacity.

    Args:
        dim: Model (feature) dimension.
        num_experts: Number of experts.
        expert_hidden_dim: Hidden size of each expert. If ``None`` it is derived
            as ``dim * ffn_dim_multiplier`` (when given) or ``int(8 * dim / 3)``
            and then rounded up to a multiple of 256. An explicit value is used as-is.
        top_k: Experts selected per token.
        dropout: Expert output dropout.
        capacity_factor: Multiplier on the even share ``num_tokens * top_k / num_experts``
            giving the maximum number of assignments each expert processes.
        use_expert_capacity: If False, experts process every assigned token.
        router_z_loss_coef: Forwarded to :class:`TopKRouter`.
        router_aux_loss_coef: Forwarded to :class:`TopKRouter`.
        noise_std: Forwarded to :class:`TopKRouter`.
        ffn_dim_multiplier: Optional multiplier used to derive ``expert_hidden_dim``.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01,
        noise_std: float = 1.0,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.use_expert_capacity = use_expert_capacity

        if expert_hidden_dim is None:
            if ffn_dim_multiplier is not None:
                expert_hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                expert_hidden_dim = int(2 * dim * 4 / 3)
            # Round the derived size up to a multiple of 256.
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)

        self.experts = nn.ModuleList(
            [Expert(dim, expert_hidden_dim, dropout, bias=False) for _ in range(num_experts)]
        )
        self.router = TopKRouter(
            dim=dim,
            num_experts=num_experts,
            top_k=top_k,
            capacity_factor=capacity_factor,
            noise_std=noise_std,
            use_expert_capacity=use_expert_capacity,
            router_z_loss_coef=router_z_loss_coef,
            router_aux_loss_coef=router_aux_loss_coef,
        )

    def _compute_expert_capacity(self, num_tokens: int) -> int:
        """Compute the per-expert assignment limit (at least 1).

        Returns ``num_tokens`` when capacity is disabled.
        """
        if not self.use_expert_capacity:
            return num_tokens
        capacity = int((num_tokens / self.num_experts) * self.capacity_factor * self.top_k)
        return max(capacity, 1)

    def _select_within_capacity(
        self,
        token_indices: torch.Tensor,
        topk_positions: torch.Tensor,
        top_k_gates: torch.Tensor,
        capacity: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Keep at most ``capacity`` of one expert's ``(token, slot)`` assignments.

        Training mode keeps a random subset. Eval mode keeps the assignments with
        the largest gate weights, so inference output is deterministic.
        """
        if self.training:
            keep = torch.randperm(token_indices.numel(), device=token_indices.device)[:capacity]
        else:
            keep = torch.topk(top_k_gates[token_indices, topk_positions], capacity).indices
        return token_indices[keep], topk_positions[keep]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the MoE block to ``x`` of shape ``(B, T, D)``.

        Assignments dropped by the capacity limit contribute nothing to that
        token's output.

        Returns:
            ``(output, auxiliary_loss)`` with ``output`` shaped like ``x``.
        """
        B, T, D = x.shape
        num_tokens = B * T
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x_flat = x.reshape(num_tokens, D)

        top_k_gates, top_k_indices, auxiliary_loss = self.router(x_flat)
        output = torch.zeros_like(x_flat)
        expert_capacity = self._compute_expert_capacity(num_tokens)

        for expert_idx, expert in enumerate(self.experts):
            # token_indices: which tokens chose this expert; topk_positions: in which top-k slot.
            token_indices, topk_positions = torch.where(top_k_indices == expert_idx)
            if token_indices.numel() == 0:
                continue
            if self.use_expert_capacity and token_indices.numel() > expert_capacity:
                token_indices, topk_positions = self._select_within_capacity(
                    token_indices, topk_positions, top_k_gates, expert_capacity
                )
            expert_gates = top_k_gates[token_indices, topk_positions]
            expert_output = expert(x_flat[token_indices]) * expert_gates.unsqueeze(-1)
            # index_add_ accumulates contributions from the top_k experts of each token.
            output.index_add_(0, token_indices, expert_output.to(output.dtype))

        return output.view(B, T, D), auxiliary_loss


class SparseDispatcher:
    """Scatter tokens to experts and gather their gate-weighted outputs.

    Args:
        num_experts: Number of experts.
        gates: Gate weights indexed as ``gates[token, expert]``; it therefore needs
            at least ``num_experts`` columns (a dense per-expert gate matrix).
        expert_indices: Expert ids whose first dimension indexes tokens (presumably
            ``(num_tokens, top_k)`` top-k indices — confirm with callers). Each entry
            equal to ``i`` routes that row's token to expert ``i``.
    """

    def __init__(
        self,
        num_experts: int,
        gates: torch.Tensor,
        expert_indices: torch.Tensor,
    ):
        self.num_experts = num_experts
        self._gates = gates
        self._expert_indices = expert_indices
        # Row (token) indices routed to each expert.
        self._expert_masks = [
            (expert_indices == i).nonzero(as_tuple=True)[0] for i in range(num_experts)
        ]

    def dispatch(self, inp: torch.Tensor) -> List[torch.Tensor]:
        """Return, per expert, the rows of ``inp`` routed to it (empty ``(0, D)`` if none)."""
        expert_inputs = []
        for rows in self._expert_masks:
            if rows.numel() > 0:
                expert_inputs.append(inp[rows])
            else:
                expert_inputs.append(
                    torch.empty(0, inp.size(-1), device=inp.device, dtype=inp.dtype)
                )
        return expert_inputs

    def combine(self, expert_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Sum each expert's outputs, weighted by ``gates[token, expert]``, back per token.

        The result has ``gates.size(0)`` rows and the dtype of ``expert_outputs[0]``.
        """
        output = torch.zeros(
            (self._gates.size(0), expert_outputs[0].size(-1)),
            device=self._gates.device,
            dtype=expert_outputs[0].dtype,
        )
        for expert_idx, expert_out in enumerate(expert_outputs):
            rows = self._expert_masks[expert_idx]
            if rows.numel() == 0:
                continue
            weighted_output = expert_out * self._gates[rows, expert_idx].unsqueeze(-1)
            output.index_add_(0, rows, weighted_output.to(output.dtype))
        return output

    def expert_to_gates(self) -> List[torch.Tensor]:
        """Return, per expert, the gate values of the tokens routed to it.

        Experts with no tokens get an empty tensor with the gates' dtype and device.
        """
        gates_per_expert = []
        for expert_idx, rows in enumerate(self._expert_masks):
            if rows.numel() > 0:
                gates_per_expert.append(self._gates[rows, expert_idx])
            else:
                gates_per_expert.append(self._gates.new_empty(0))
        return gates_per_expert


class MoELayer(nn.Module):
    """Minimal top-k MoE layer: noise-free softmax gate, no token dropping.

    Args:
        dim: Model (feature) dimension.
        num_experts: Number of experts.
        expert_hidden_dim: Hidden size per expert; if ``None`` it is ``int(8 * dim / 3)``
            rounded up to a multiple of 256.
        top_k: Experts selected per token.
        dropout: Expert output dropout.
        capacity_factor: Stored on the module only; this layer never drops tokens.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        if expert_hidden_dim is None:
            expert_hidden_dim = int(2 * dim * 4 / 3)
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)
        self.experts = nn.ModuleList(
            [Expert(dim, expert_hidden_dim, dropout) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, std=0.02)
        self.capacity_factor = capacity_factor

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer to ``x`` of shape ``(B, T, D)``.

        Returns:
            ``(output, aux_loss)``: ``output`` shaped like ``x`` and the
            load-balancing loss ``num_experts * sum_i(P_i * f_i)``, computed in
            every mode.
        """
        B, T, D = x.shape
        x_flat = x.reshape(B * T, D)

        gates = F.softmax(self.gate(x_flat), dim=-1)
        top_k_probs, top_k_indices = torch.topk(gates, self.top_k, dim=-1)
        # Renormalize the selected probabilities so each row sums to 1. This equals a
        # softmax over the selected logits; a second softmax over probabilities
        # would instead squash the weights toward uniform.
        top_k_gates = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        expert_probs = gates.mean(dim=0)
        expert_counts = F.one_hot(top_k_indices, self.num_experts).float().sum(dim=[0, 1])
        expert_counts = expert_counts / (B * T * self.top_k)
        aux_loss = self.num_experts * torch.sum(expert_probs * expert_counts)

        output = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            token_indices, topk_positions = torch.where(top_k_indices == expert_idx)
            if token_indices.numel() == 0:
                continue
            expert_gates = top_k_gates[token_indices, topk_positions]
            expert_output = expert(x_flat[token_indices]) * expert_gates.unsqueeze(-1)
            output.index_add_(0, token_indices, expert_output.to(output.dtype))

        return output.view(B, T, D), aux_loss