WJAD / src /wjad /modules /moe.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
"""每层独立 MoE 块(7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3)。
设计要点(与 Design.md 对齐):
- 每层独立 8 个专家库(专家[0] 为共享),不同层之间不共享。
- 路由:对当前层输入做 ``GAP(序列) -> Linear -> Sigmoid -> Top3 mask``。
- 共享专家始终激活;路由专家依据 sigmoid 概率加权(Stage1 全激活、Stage2 严格 Top-3)。
- 输出 = 共享专家(x) + sum_i (probs_i * mask_i) * expert_i(x)。
- 提供路由 logits / probs / 负载均衡 / 边界正则的辅助统计,外部由
``losses/moe_aux.py`` 聚合成正则损失。
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
from .ffn import SwiGLUFFN
@dataclass
class MoEStats:
    """Auxiliary routing statistics produced by one MoE layer.

    Intended for regularisation losses and monitoring.

    Attributes:
        logits: Router logits, shape ``[B, num_routed]``.
        probs: Probabilities after the sigmoid, shape ``[B, num_routed]``.
        topk_mask: 0/1 mask of the active routed experts, shape
            ``[B, num_routed]``.
    """

    logits: torch.Tensor
    probs: torch.Tensor
    topk_mask: torch.Tensor
class PerLayerExperts(nn.Module):
    """Expert bank owned by a single layer: shared experts plus routed experts.

    Every expert, shared or routed, is a ``SwiGLUFFN`` constructed with the
    same ``dim``, ``mult`` and ``dropout`` arguments.
    """

    def __init__(
        self,
        dim: int,
        num_routed: int = 7,
        num_shared: int = 1,
        ffn_mult: int = 4,
        dropout: float = 0.0,
    ) -> None:
        """Create the shared and routed expert lists.

        Args:
            dim: Passed to each ``SwiGLUFFN`` as its first argument.
            num_routed: Number of routed experts to create.
            num_shared: Number of shared experts to create.
            ffn_mult: Passed to each ``SwiGLUFFN`` as ``mult``.
            dropout: Passed to each ``SwiGLUFFN`` as ``dropout``.
        """
        super().__init__()

        def build_bank(count: int) -> nn.ModuleList:
            # Each call appends `count` newly constructed FFN experts.
            bank = nn.ModuleList()
            for _ in range(count):
                bank.append(SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout))
            return bank

        self.num_routed = num_routed
        self.num_shared = num_shared
        # Shared experts are built before routed ones, as in the original
        # construction order.
        self.shared = build_bank(num_shared)
        self.routed = build_bank(num_routed)
class MoEBlock(nn.Module):
    """Mixture-of-experts FFN block with sequence-level routing.

    The router mean-pools the input over the sequence axis, projects the
    pooled vector to one logit per routed expert and applies a
    temperature-scaled sigmoid. Shared experts are always applied without any
    gate; the output of routed expert ``i`` is scaled per sample by
    ``probs[:, i] * mask[:, i]``.

    Two modes are available (see :meth:`set_mode`):

    * ``"dense"`` (default): the mask is all ones, every routed expert
      contributes.
    * ``"sparse"``: only the ``topk`` highest-scoring routed experts of each
      sample are kept.
    """

    def __init__(
        self,
        dim: int,
        num_routed: int = 7,
        num_shared: int = 1,
        ffn_mult: int = 4,
        topk: int = 3,
        dropout: float = 0.0,
    ) -> None:
        """Build the per-layer expert bank and the router.

        Args:
            dim: Feature size of the input; also the router's input size.
            num_routed: Number of routed (gated) experts.
            num_shared: Number of always-on shared experts.
            ffn_mult: Forwarded to ``PerLayerExperts``.
            topk: Number of routed experts kept per sample in sparse mode.
            dropout: Forwarded to ``PerLayerExperts``.
        """
        super().__init__()
        self.dim = dim
        self.num_routed = num_routed
        self.num_shared = num_shared
        self.topk = topk
        self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout)
        self.router = nn.Linear(dim, num_routed, bias=True)
        # Small weights and zero bias keep the initial logits near 0, i.e. the
        # initial sigmoid probabilities near 0.5.
        nn.init.normal_(self.router.weight, std=0.02)
        nn.init.zeros_(self.router.bias)
        # 'dense' behaves like topk == num_routed; 'sparse' applies the real topk.
        self._mode: str = "dense"
        # Router temperature; values below 1 sharpen the probabilities.
        self.register_buffer("router_temperature", torch.tensor(1.0))

    def set_mode(self, mode: str) -> None:
        """Switch between ``"dense"`` and ``"sparse"`` routing.

        Raises:
            AssertionError: If ``mode`` is neither ``"dense"`` nor
                ``"sparse"`` (not checked when Python runs with ``-O``).
        """
        assert mode in ("dense", "sparse"), f"未知模式: {mode}"
        self._mode = mode

    @property
    def mode(self) -> str:
        """Current routing mode, ``"dense"`` or ``"sparse"``."""
        return self._mode

    def set_temperature(self, t: float) -> None:
        """Overwrite the router temperature in place.

        Values below ``1e-3`` are accepted here but are clamped to ``1e-3``
        when the temperature is applied during routing.
        """
        self.router_temperature.fill_(float(t))

    def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute router logits, probabilities and the active-expert mask.

        Args:
            x: Input of shape ``[B, N, D]``.

        Returns:
            A tuple ``(logits, probs, mask)``, each of shape
            ``[B, num_routed]``: the raw router outputs (before temperature
            scaling), the sigmoid of the temperature-scaled logits, and a 0/1
            mask of the active routed experts.
        """
        pooled = x.mean(dim=1)  # [B, D]
        logits = self.router(pooled)  # [B, num_routed]
        # Dividing by a temperature < 1 sharpens the probabilities; the clamp
        # guards against division by zero or a negative temperature.
        scaled = logits / self.router_temperature.clamp_min(1e-3)
        probs = torch.sigmoid(scaled)
        if self._mode == "dense" or self.topk >= self.num_routed:
            mask = torch.ones_like(probs)
        else:
            # Rank on the scaled logits rather than on probs: the sigmoid is
            # monotonic, so the ordering is the same, but at low temperature
            # several probabilities can saturate to exactly 1.0 (or 0.0) and
            # ranking those ties would pick experts arbitrarily.
            topk_idx = torch.topk(scaled, self.topk, dim=-1).indices
            mask = torch.zeros_like(probs)
            mask.scatter_(-1, topk_idx, 1.0)
        return logits, probs, mask

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]:
        """Apply shared and routed experts to ``x``.

        Args:
            x: Input of shape ``[B, N, D]``; any other rank raises
                ``ValueError`` from the shape unpacking.

        Returns:
            ``(out, stats)`` where ``out`` has the shape of ``x`` and equals
            the sum of all shared-expert outputs plus the weighted
            routed-expert outputs, and ``stats`` carries the router logits,
            probabilities and mask.
        """
        b, _, _ = x.shape
        logits, probs, mask = self._route(x)

        # Shared experts: always active, no gating.
        out = torch.zeros_like(x)
        for shared_expert in self.experts.shared:
            out = out + shared_expert(x)

        # Routed experts: one weight per (sample, expert), broadcast over the
        # sequence and feature axes.
        weights = probs * mask  # [B, num_routed]
        # Decide once which experts have a positive weight for at least one
        # sample: a single reduction and a single host transfer instead of one
        # blocking check per expert.
        active = (weights > 0).any(dim=0).tolist()
        for i, expert in enumerate(self.experts.routed):
            if not active[i]:
                # No sample in the batch uses this expert; skip its forward pass.
                continue
            out = out + weights[:, i].view(b, 1, 1) * expert(x)

        stats = MoEStats(logits=logits, probs=probs, topk_mask=mask)
        return out, stats