| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Tuple, Optional, List | |
| import math | |
class Expert(nn.Module):
    """Gated (SwiGLU-style) feed-forward expert.

    Computes ``dropout(w2(silu(w1(x)) * w3(x)))``: the input is projected to
    ``hidden_dim`` twice, one branch is passed through SiLU and multiplies the
    other, and the product is projected back to ``dim``.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the gated hidden layer.
        dropout: Dropout probability applied to the output. ``0`` (or any
            non-positive value) uses ``nn.Identity`` instead of ``nn.Dropout``.
        bias: Whether the three linear projections carry a bias term.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()
        # Keep this order: it fixes the state_dict key order and the order in
        # which the RNG is consumed while the layers are initialised.
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # SiLU-activated branch
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)  # projection back to dim
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)  # linear (multiplier) branch
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the projections: weights ~ N(0, 0.02^2), biases set to zero."""
        projections = (self.w1, self.w2, self.w3)
        for layer in projections:
            nn.init.normal_(layer.weight, mean=0.0, std=0.02)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x`` (last dimension ``dim``)."""
        activated = F.silu(self.w1(x))
        multiplier = self.w3(x)
        projected = self.w2(activated * multiplier)
        return self.dropout(projected)
class TopKRouter(nn.Module):
    """Noisy top-k gating network for a mixture-of-experts layer.

    A bias-free linear layer maps each token to ``num_experts`` logits. In
    training mode Gaussian noise is added before the top-k selection. The
    ``top_k`` selected logits are softmax-normalised into routing weights.
    In training mode the router also returns an auxiliary loss:
    ``router_aux_loss_coef * load_balance_loss + router_z_loss_coef * z_loss``.

    Args:
        dim: Size of the last dimension of the router input.
        num_experts: Number of experts to route between.
        top_k: Number of experts selected per token; must be in
            ``[1, num_experts]``.
        capacity_factor: Stored on the module; not used by the router itself.
        noise_std: Standard deviation of the training-time logit noise.
        use_expert_capacity: Stored on the module; not used by the router itself.
        router_z_loss_coef: Weight of the router z-loss.
        router_aux_loss_coef: Weight of the load-balancing loss.

    Raises:
        ValueError: If ``top_k`` is not between 1 and ``num_experts``.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int = 2,
        capacity_factor: float = 1.25,
        noise_std: float = 1.0,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01
    ):
        super().__init__()
        # Fail at construction instead of at the first forward: torch.topk
        # rejects k > num_experts, and k == 0 divides by zero in the aux loss.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be between 1 and num_experts ({num_experts}), got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.noise_std = noise_std
        self.use_expert_capacity = use_expert_capacity
        self.router_z_loss_coef = router_z_loss_coef
        self.router_aux_loss_coef = router_aux_loss_coef
        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)

    def _compute_routing_weights(
        self,
        logits: torch.Tensor,
        use_noise: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the top-k experts per token and softmax-normalise their logits.

        Noise (``randn * noise_std``) is added only when ``use_noise`` is true
        AND the module is in training mode. The caller's ``logits`` tensor is
        not modified.

        Returns:
            ``(top_k_gates, top_k_indices)``, both with last dimension
            ``top_k``. The gates along that dimension sum to 1.
        """
        if use_noise and self.training:
            noise = torch.randn_like(logits) * self.noise_std
            logits = logits + noise
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        top_k_gates = F.softmax(top_k_logits, dim=-1)
        return top_k_gates, top_k_indices

    def _compute_auxiliary_loss(
        self,
        logits: torch.Tensor,
        top_k_indices: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(load_balance_loss, z_loss)`` for 2-D ``(num_tokens, num_experts)`` logits.

        * ``load_balance_loss = num_experts * sum_i(P_i * f_i)``. Here ``P_i``
          is the mean router probability of expert ``i`` and ``f_i`` is the
          fraction of the ``num_tokens * top_k`` assignments that went to
          expert ``i``.
        * ``z_loss = mean(logsumexp(logits, dim=-1) ** 2)``, the router z-loss
          that discourages large router logits.
        """
        # Work in float32: softmax/logsumexp on half-precision logits lose
        # precision, and the losses are small regularisers.
        logits = logits.float()
        num_tokens = logits.shape[0]
        router_probs = F.softmax(logits, dim=-1)
        expert_probs = router_probs.mean(dim=0)
        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
        expert_freq = expert_mask.sum(dim=[0, 1]) / (num_tokens * self.top_k)
        load_balance_loss = self.num_experts * torch.sum(expert_probs * expert_freq)
        # Penalise the squared log-partition function of each token, not the
        # raw logits.
        z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2)
        return load_balance_loss, z_loss

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route tokens to experts.

        Args:
            x: Router input with last dimension ``dim``. In training mode the
                auxiliary loss treats it as 2-D ``(num_tokens, dim)``.

        Returns:
            ``(top_k_gates, top_k_indices, auxiliary_loss)``. In eval mode
            ``auxiliary_loss`` is a scalar zero on ``x.device``.
        """
        logits = self.gate(x)
        top_k_gates, top_k_indices = self._compute_routing_weights(
            logits, use_noise=self.training
        )
        if self.training:
            # The auxiliary loss uses the noise-free logits. The expert
            # selection (top_k_indices) came from the noisy ones.
            load_balance_loss, z_loss = self._compute_auxiliary_loss(logits, top_k_indices)
            auxiliary_loss = (
                self.router_aux_loss_coef * load_balance_loss +
                self.router_z_loss_coef * z_loss
            )
        else:
            auxiliary_loss = torch.tensor(0.0, device=x.device)
        return top_k_gates, top_k_indices, auxiliary_loss
class MixtureOfExperts(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with an optional per-expert capacity.

    Every token is routed to ``top_k`` experts by a ``TopKRouter``. Each
    expert's output is scaled by its routing weight, and the results are
    summed per token. When ``use_expert_capacity`` is set, an expert handles at
    most ``_compute_expert_capacity(num_tokens)`` assignments. Any overflow
    assignment contributes nothing from that expert.

    Args:
        dim: Feature size of the input and output.
        num_experts: Number of ``Expert`` sub-modules.
        expert_hidden_dim: Hidden width of each expert. If ``None`` it is
            ``int(dim * ffn_dim_multiplier)`` when a multiplier is given,
            otherwise ``int(2 * dim * 4 / 3)``. Either value is rounded up to a
            multiple of 256.
        top_k: Experts selected per token.
        dropout: Dropout probability inside each expert.
        capacity_factor: Scales the per-expert capacity.
        use_expert_capacity: Whether to enforce the per-expert capacity.
        router_z_loss_coef: Passed to the router (z-loss weight).
        router_aux_loss_coef: Passed to the router (load-balancing loss weight).
        noise_std: Passed to the router (training-time logit noise).
        ffn_dim_multiplier: Optional multiplier used to derive ``expert_hidden_dim``.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01,
        noise_std: float = 1.0,
        ffn_dim_multiplier: Optional[float] = None
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.use_expert_capacity = use_expert_capacity
        if expert_hidden_dim is None:
            if ffn_dim_multiplier is not None:
                expert_hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                expert_hidden_dim = int(2 * dim * 4 / 3)
            # Round up to a multiple of 256.
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)
        self.experts = nn.ModuleList([
            Expert(dim, expert_hidden_dim, dropout, bias=False)
            for _ in range(num_experts)
        ])
        self.router = TopKRouter(
            dim=dim,
            num_experts=num_experts,
            top_k=top_k,
            capacity_factor=capacity_factor,
            noise_std=noise_std,
            use_expert_capacity=use_expert_capacity,
            router_z_loss_coef=router_z_loss_coef,
            router_aux_loss_coef=router_aux_loss_coef
        )

    def _compute_expert_capacity(self, num_tokens: int) -> int:
        """Return how many (token, expert) assignments one expert may process.

        With capacity disabled this is ``num_tokens``. Otherwise it is
        ``int(num_tokens / num_experts * capacity_factor * top_k)``, floored
        at 1.
        """
        if not self.use_expert_capacity:
            return num_tokens
        capacity = int(
            (num_tokens / self.num_experts) * self.capacity_factor * self.top_k
        )
        return max(capacity, 1)

    def _keep_within_capacity(
        self,
        token_indices: torch.Tensor,
        expert_gates: torch.Tensor,
        capacity: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Trim one expert's assignments to ``capacity`` entries.

        In training mode a uniformly random subset is kept (``torch.randperm``).
        In eval mode the ``capacity`` assignments with the largest routing
        weights are kept, so inference output does not depend on the global
        RNG state.

        Returns:
            The kept ``(token_indices, expert_gates)``, aligned element-wise.
        """
        if self.training:
            keep = torch.randperm(
                token_indices.numel(), device=token_indices.device
            )[:capacity]
        else:
            keep = torch.topk(expert_gates, capacity).indices
        return token_indices[keep], expert_gates[keep]

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the mixture of experts to ``x``.

        Args:
            x: Tensor of shape ``(..., dim)``, e.g. ``(batch, seq, dim)``. All
                leading dimensions are flattened into a token axis.

        Returns:
            ``(output, auxiliary_loss)``. ``output`` has the same shape as
            ``x``, and ``auxiliary_loss`` is the router's auxiliary loss.
        """
        *lead_shape, D = x.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        x_flat = x.reshape(-1, D)
        num_tokens = x_flat.shape[0]
        top_k_gates, top_k_indices, auxiliary_loss = self.router(x_flat)
        output = torch.zeros_like(x_flat)
        expert_capacity = self._compute_expert_capacity(num_tokens)
        for expert_idx, expert in enumerate(self.experts):
            # Rows are tokens; columns are the top-k slot that picked this expert.
            token_indices, topk_positions = torch.where(top_k_indices == expert_idx)
            if token_indices.numel() == 0:
                continue
            expert_gates = top_k_gates[token_indices, topk_positions]
            if self.use_expert_capacity and token_indices.numel() > expert_capacity:
                token_indices, expert_gates = self._keep_within_capacity(
                    token_indices, expert_gates, expert_capacity
                )
            expert_output = expert(x_flat[token_indices])
            expert_output = expert_output * expert_gates.unsqueeze(-1)
            # index_add_ accumulates contributions from several experts per token.
            output.index_add_(0, token_indices, expert_output)
        return output.reshape(*lead_shape, D), auxiliary_loss
class SparseDispatcher:
    """Group token rows by expert, then recombine the expert outputs.

    For each expert ``i`` the dispatcher records the row indices where
    ``expert_indices == i``, taking the first dimension of ``nonzero``.
    ``dispatch`` gathers those input rows per expert. ``combine`` weights each
    expert's output rows by ``gates[rows, i]`` and sums them back into a
    ``(gates.size(0), out_dim)`` tensor.

    NOTE(review): ``gates`` is indexed by expert id in its second dimension,
    so it is expected to hold one column per expert (``(num_tokens,
    num_experts)``), not the ``(num_tokens, top_k)`` weights a top-k router
    returns. Confirm this against callers.
    """

    def __init__(
        self,
        num_experts: int,
        gates: torch.Tensor,
        expert_indices: torch.Tensor
    ):
        self.num_experts = num_experts
        self._gates = gates
        self._expert_indices = expert_indices
        # One 1-D LongTensor of row indices per expert.
        self._expert_masks = [
            (expert_indices == i).nonzero(as_tuple=True)[0]
            for i in range(num_experts)
        ]

    def dispatch(self, inp: torch.Tensor) -> List[torch.Tensor]:
        """Return, for each expert, the rows of ``inp`` routed to it.

        An expert with no rows gets an empty ``(0, inp.size(-1))`` tensor with
        ``inp``'s dtype and device.
        """
        expert_inputs = []
        for mask in self._expert_masks:
            if len(mask) > 0:
                expert_inputs.append(inp[mask])
            else:
                expert_inputs.append(
                    torch.empty(0, inp.size(-1), device=inp.device, dtype=inp.dtype)
                )
        return expert_inputs

    def combine(self, expert_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Sum the gate-weighted expert outputs back into one row per token.

        Args:
            expert_outputs: One tensor per expert, aligned with ``dispatch``'s
                output and each shaped ``(rows_for_expert, out_dim)``.

        Returns:
            A ``(gates.size(0), out_dim)`` tensor with the dtype of
            ``expert_outputs[0]``. Tokens routed to no expert stay zero.
        """
        output_shape = (self._gates.size(0), expert_outputs[0].size(-1))
        output = torch.zeros(
            output_shape,
            device=self._gates.device,
            dtype=expert_outputs[0].dtype
        )
        for expert_idx, expert_out in enumerate(expert_outputs):
            mask = self._expert_masks[expert_idx]
            if len(mask) == 0:
                continue
            weighted_output = expert_out * self._gates[mask, expert_idx].unsqueeze(-1)
            # `output[mask] += x` keeps only the last write for a repeated row.
            # index_add_ accumulates every contribution. It needs matching dtypes,
            # so cast the (possibly type-promoted) product back to the output dtype.
            output.index_add_(0, mask, weighted_output.to(output.dtype))
        return output

    def expert_to_gates(self) -> List[torch.Tensor]:
        """Return, for each expert, the gate values of the rows routed to it.

        An expert with no rows gets an empty tensor with the gates' dtype and
        device.
        """
        gates_per_expert = []
        for expert_idx in range(self.num_experts):
            mask = self._expert_masks[expert_idx]
            if len(mask) > 0:
                gates_per_expert.append(self._gates[mask, expert_idx])
            else:
                gates_per_expert.append(
                    torch.empty(0, device=self._gates.device, dtype=self._gates.dtype)
                )
        return gates_per_expert
class MoELayer(nn.Module):
    """Self-contained top-k mixture-of-experts layer with a built-in linear gate.

    A bias-free linear gate produces router probabilities. For each token the
    ``top_k`` largest are kept and renormalised to sum to 1. The token's output
    is the weighted sum of the selected experts' outputs. No per-expert
    capacity limit is applied.

    Args:
        dim: Feature size of the input and output.
        num_experts: Number of ``Expert`` sub-modules.
        expert_hidden_dim: Hidden width of each expert. If ``None``,
            ``int(2 * dim * 4 / 3)`` rounded up to a multiple of 256.
        top_k: Experts selected per token.
        dropout: Dropout probability inside each expert.
        capacity_factor: Stored on the module; ``forward`` does not use it.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        if expert_hidden_dim is None:
            expert_hidden_dim = int(2 * dim * 4 / 3)
            # Round up to a multiple of 256.
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)
        self.experts = nn.ModuleList([
            Expert(dim, expert_hidden_dim, dropout)
            for _ in range(num_experts)
        ])
        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, std=0.02)
        self.capacity_factor = capacity_factor

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the mixture of experts to ``x``.

        Args:
            x: Tensor of shape ``(..., dim)``, e.g. ``(batch, seq, dim)``. All
                leading dimensions are flattened into a token axis.

        Returns:
            ``(output, aux_loss)``. ``output`` has the same shape as ``x``.
            ``aux_loss = num_experts * sum_i(P_i * f_i)`` is the load-balancing
            loss. ``P_i`` is the mean router probability of expert ``i``, and
            ``f_i`` is the fraction of the ``num_tokens * top_k`` assignments
            that went to it.
        """
        *lead_shape, D = x.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        x_flat = x.reshape(-1, D)
        num_tokens = x_flat.shape[0]
        router_probs = F.softmax(self.gate(x_flat), dim=-1)
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        # Renormalise the selected probabilities so each token's weights sum
        # to 1. A second softmax over values already in [0, 1] would flatten
        # them towards uniform and discard most of the router's preference.
        top_k_gates = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        expert_probs = router_probs.mean(dim=0)
        expert_counts = F.one_hot(top_k_indices, self.num_experts).float().sum(dim=[0, 1])
        expert_counts = expert_counts / (num_tokens * self.top_k)
        aux_loss = self.num_experts * torch.sum(expert_probs * expert_counts)
        output = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            token_indices, topk_positions = torch.where(top_k_indices == expert_idx)
            if token_indices.numel() == 0:
                continue
            expert_gates = top_k_gates[token_indices, topk_positions]
            expert_output = expert(x_flat[token_indices])
            expert_output = expert_output * expert_gates.unsqueeze(-1)
            # index_add_ accumulates contributions from several experts per token.
            output.index_add_(0, token_indices, expert_output)
        return output.reshape(*lead_shape, D), aux_loss