"""每层独立 MoE 块(7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3)。 设计要点(与 Design.md 对齐): - 每层独立 8 个专家库(专家[0] 为共享),不同层之间不共享。 - 路由:对当前层输入做 ``GAP(序列) -> Linear -> Sigmoid -> Top3 mask``。 - 共享专家始终激活;路由专家依据 sigmoid 概率加权(Stage1 全激活、Stage2 严格 Top-3)。 - 输出 = 共享专家(x) + sum_i (probs_i * mask_i) * expert_i(x)。 - 提供路由 logits / probs / 负载均衡 / 边界正则的辅助统计,外部由 ``losses/moe_aux.py`` 聚合成正则损失。 """ from __future__ import annotations from dataclasses import dataclass import torch import torch.nn as nn from .ffn import SwiGLUFFN @dataclass class MoEStats: """单层 MoE 输出的辅助统计,用于损失与监控。""" logits: torch.Tensor # [B, num_routed] probs: torch.Tensor # [B, num_routed],sigmoid 后的概率 topk_mask: torch.Tensor # [B, num_routed],0/1 class PerLayerExperts(nn.Module): """单层的专家库:1 个共享 + N 个路由,全部为 SwiGLUFFN。""" def __init__( self, dim: int, num_routed: int = 7, num_shared: int = 1, ffn_mult: int = 4, dropout: float = 0.0, ) -> None: super().__init__() self.num_routed = num_routed self.num_shared = num_shared self.shared = nn.ModuleList( [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_shared)] ) self.routed = nn.ModuleList( [SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(num_routed)] ) class MoEBlock(nn.Module): """带路由的 MoE FFN 块(每层独立专家库)。""" def __init__( self, dim: int, num_routed: int = 7, num_shared: int = 1, ffn_mult: int = 4, topk: int = 3, dropout: float = 0.0, ) -> None: super().__init__() self.dim = dim self.num_routed = num_routed self.num_shared = num_shared self.topk = topk self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout) self.router = nn.Linear(dim, num_routed, bias=True) # 路由初始化:bias=0、weight 较小,以使初始概率接近 0.5 nn.init.normal_(self.router.weight, std=0.02) nn.init.zeros_(self.router.bias) # 训练阶段:'dense' 等同于 topk=num_routed;'sparse' 用真实 topk self._mode: str = "dense" # 路由温度(温度 < 1 => 锐化) self.register_buffer("router_temperature", torch.tensor(1.0)) def set_mode(self, mode: str) -> None: assert mode in ("dense", "sparse"), f"未知模式: {mode}" self._mode = mode @property def mode(self) -> str: return self._mode def set_temperature(self, t: float) -> None: self.router_temperature.fill_(float(t)) def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """计算 logits / probs / topk_mask。x: [B, N, D]。""" pooled = x.mean(dim=1) # [B, D] logits = self.router(pooled) # [B, num_routed] # 温度锐化(温度小 => 概率更尖) scaled = logits / self.router_temperature.clamp_min(1e-3) probs = torch.sigmoid(scaled) if self._mode == "dense" or self.topk >= self.num_routed: mask = torch.ones_like(probs) else: topk_vals, topk_idx = torch.topk(probs, self.topk, dim=-1) mask = torch.zeros_like(probs) mask.scatter_(-1, topk_idx, 1.0) return logits, probs, mask def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]: b, n, d = x.shape logits, probs, mask = self._route(x) # 共享专家(恒激活,无门控) out = torch.zeros_like(x) for sh in self.experts.shared: out = out + sh(x) # 路由专家:按 (probs * mask) 在 batch 维加权 weights = probs * mask # [B, num_routed] # 注意:每个样本各自的权重独立。逐专家计算后 batch 级加权,避免 token 级 # 路由的索引开销;与 Design.md "序列级分配" 一致。 for i, expert in enumerate(self.experts.routed): w_i = weights[:, i].view(b, 1, 1) # [B,1,1] # 仅当批内任一样本权重 > 0 时才前向以减少计算 if torch.any(w_i > 0): out = out + w_i * expert(x) stats = MoEStats(logits=logits, probs=probs, topk_mask=mask) return out, stats