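"""Per-layer MoE block (7 routed + 1 shared expert, sequence-level GAP, sigmoid gating, Top-3).

Design notes (aligned with Design.md):
- Each layer owns one independent bank of 8 experts (expert[0] is the shared
  one; the code keeps it in a separate ``shared`` list next to the 7 routed
  experts). Banks are not shared across layers.
- Routing, applied to the input of the current layer:
  ``GAP(sequence) -> Linear -> Sigmoid -> Top-3 mask``.
- The shared expert is always active; routed experts are weighted by their
  sigmoid probability (Stage 1: all routed experts active, Stage 2: strict Top-3).
- Output = shared_expert(x) + sum_i (probs_i * mask_i) * expert_i(x).
- Exposes the router logits / probs / Top-k mask as auxiliary statistics for
  load balancing and boundary regularization; ``losses/moe_aux.py`` aggregates
  them externally into the regularization losses.
"""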
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .ffn import SwiGLUFFN |
|
|
|
|
@dataclass
class MoEStats:
    """Auxiliary routing statistics of one MoE layer, for losses and monitoring."""

    # Raw router outputs, before temperature scaling and the sigmoid.
    logits: torch.Tensor
    # Sigmoid gate values computed from the temperature-scaled logits.
    probs: torch.Tensor
    # 0/1 mask marking which routed experts are active for each sample.
    topk_mask: torch.Tensor
|
|
|
|
class PerLayerExperts(nn.Module):
    """Expert bank owned by a single layer.

    Holds ``num_shared`` shared experts and ``num_routed`` routed experts in
    two separate ``nn.ModuleList`` containers (``shared`` and ``routed``).
    Every expert is a ``SwiGLUFFN`` built with the same ``dim``, ``ffn_mult``
    and ``dropout``.
    """

    def __init__(
        self,
        dim: int,
        num_routed: int = 7,
        num_shared: int = 1,
        ffn_mult: int = 4,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        def build_bank(count: int) -> nn.ModuleList:
            # All experts share one configuration; only the bank size varies.
            bank = nn.ModuleList()
            for _ in range(count):
                bank.append(SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout))
            return bank

        self.num_routed = num_routed
        self.num_shared = num_shared
        # Shared experts are instantiated first, then the routed ones, so the
        # order of parameter initialisation is the same as before.
        self.shared = build_bank(num_shared)
        self.routed = build_bank(num_routed)
|
|
|
|
class MoEBlock(nn.Module):
    """MoE feed-forward block with its own router and per-layer expert bank.

    Routing happens once per sequence: the input is mean-pooled over the
    sequence axis, projected to one logit per routed expert, divided by the
    router temperature and passed through a sigmoid. Shared experts are always
    applied; the output of routed expert ``i`` is scaled by
    ``probs[:, i] * mask[:, i]``.

    Modes (see ``set_mode``):
        ``"dense"``: every routed expert is active (mask of ones). Default.
        ``"sparse"``: only the ``topk`` highest-probability routed experts of
            each sample are active.

    Args:
        dim: Feature width; input size of the router and ``dim`` of the experts.
        num_routed: Number of routed experts (and router outputs).
        num_shared: Number of always-active shared experts.
        ffn_mult: Width multiplier forwarded to every expert.
        topk: Number of routed experts kept per sample in sparse mode.
        dropout: Dropout rate forwarded to every expert.
    """

    def __init__(
        self,
        dim: int,
        num_routed: int = 7,
        num_shared: int = 1,
        ffn_mult: int = 4,
        topk: int = 3,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_routed = num_routed
        self.num_shared = num_shared
        self.topk = topk
        self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout)
        self.router = nn.Linear(dim, num_routed, bias=True)

        # Small-variance weights and a zero bias for the router.
        nn.init.normal_(self.router.weight, std=0.02)
        nn.init.zeros_(self.router.bias)

        # Plain Python attribute: the routing mode is not stored in the
        # state_dict and starts as "dense" on every fresh instance.
        self._mode: str = "dense"

        # Kept as a buffer so it moves with the module and is checkpointed.
        self.register_buffer("router_temperature", torch.tensor(1.0))

    def set_mode(self, mode: str) -> None:
        """Switch between ``"dense"`` and ``"sparse"`` routing.

        Raises:
            ValueError: If ``mode`` is neither ``"dense"`` nor ``"sparse"``.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, and an unknown mode would then silently be stored and
        # handled as sparse by ``_route``.
        if mode not in ("dense", "sparse"):
            raise ValueError(f"未知模式: {mode}")
        self._mode = mode

    @property
    def mode(self) -> str:
        """Current routing mode, ``"dense"`` or ``"sparse"``."""
        return self._mode

    def set_temperature(self, t: float) -> None:
        """Overwrite the router temperature in place.

        Values below ``1e-3`` are accepted here but clamped to ``1e-3`` when
        the temperature is applied during routing.
        """
        self.router_temperature.fill_(float(t))

    def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute router logits, gate probabilities and the active-expert mask.

        Args:
            x: Layer input of shape ``[B, N, D]``.

        Returns:
            ``(logits, probs, mask)``, each of shape ``[B, num_routed]``:
            the raw router outputs, the sigmoid of the temperature-scaled
            logits, and a 0/1 mask that is all ones in dense mode (or when
            ``topk >= num_routed``) and marks the ``topk`` largest
            probabilities per sample otherwise.
        """
        pooled = x.mean(dim=1)
        logits = self.router(pooled)

        # The clamp protects against a zero or negative temperature.
        scaled = logits / self.router_temperature.clamp_min(1e-3)
        probs = torch.sigmoid(scaled)

        if self._mode == "dense" or self.topk >= self.num_routed:
            mask = torch.ones_like(probs)
        else:
            topk_idx = torch.topk(probs, self.topk, dim=-1).indices
            mask = torch.zeros_like(probs)
            mask.scatter_(-1, topk_idx, 1.0)

        return logits, probs, mask

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]:
        """Apply the shared experts plus the gated routed experts.

        Args:
            x: Layer input of shape ``[B, N, D]``.

        Returns:
            ``(out, stats)`` where ``out`` has the shape of ``x`` and ``stats``
            carries the router logits, probabilities and mask of this call.
        """
        # Three-way unpack also rejects inputs that are not 3-D.
        batch, _, _ = x.shape
        logits, probs, mask = self._route(x)

        # Shared experts contribute unconditionally and unweighted.
        out = torch.zeros_like(x)
        for shared_expert in self.experts.shared:
            out = out + shared_expert(x)

        # Per-sample gate of each routed expert; zero where it is masked out.
        weights = probs * mask

        for i, expert in enumerate(self.experts.routed):
            w_i = weights[:, i].view(batch, 1, 1)
            # Skip experts no sample in the batch selected. Note that this
            # condition reads a tensor value in Python control flow.
            if torch.any(w_i > 0):
                out = out + w_i * expert(x)

        stats = MoEStats(logits=logits, probs=probs, topk_mask=mask)
        return out, stats
|
|