File size: 4,731 Bytes
"""每层独立 MoE 块(7 路由 + 1 共享专家,GAP 序列级 Sigmoid Top-3)。
设计要点(与 Design.md 对齐):
- 每层拥有独立的 8 个专家(1 个共享专家 + 7 个路由专家,分别存放在 ``shared`` / ``routed`` 两个列表中),不同层之间不共享。
- 路由:对当前层输入做 ``GAP(序列) -> Linear -> Sigmoid -> Top3 mask``。
- 共享专家始终激活;路由专家依据 sigmoid 概率加权(Stage1 全激活、Stage2 严格 Top-3)。
- 输出 = 共享专家(x) + sum_i (probs_i * mask_i) * expert_i(x)。
- 提供路由 logits / probs / 负载均衡 / 边界正则的辅助统计,外部由
``losses/moe_aux.py`` 聚合成正则损失。
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
from .ffn import SwiGLUFFN
@dataclass
class MoEStats:
    """Auxiliary routing statistics produced by one MoE layer.

    Returned next to the layer output so that auxiliary losses and
    monitoring code can inspect the router's decisions.

    Attributes:
        logits: raw (un-tempered) router outputs, shape ``[B, num_routed]``.
        probs: sigmoid of the temperature-scaled logits, ``[B, num_routed]``.
        topk_mask: 0/1 selection mask over routed experts, ``[B, num_routed]``.
    """

    logits: torch.Tensor
    probs: torch.Tensor
    topk_mask: torch.Tensor
class PerLayerExperts(nn.Module):
    """Expert bank owned by a single layer.

    Holds ``num_shared`` ungated experts in ``shared`` and ``num_routed``
    gated experts in ``routed``. Every expert is an independent
    ``SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout)``. The module is a pure
    container and defines no ``forward`` of its own.
    """

    def __init__(
        self,
        dim: int,
        num_routed: int = 7,
        num_shared: int = 1,
        ffn_mult: int = 4,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_routed = num_routed
        self.num_shared = num_shared

        def build_bank(count: int) -> nn.ModuleList:
            # `count` freshly initialised experts sharing the same hyper-parameters.
            return nn.ModuleList(
                SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout) for _ in range(count)
            )

        # Keep construction order (shared first, then routed) so parameter
        # initialisation order and state_dict key order stay unchanged.
        self.shared = build_bank(num_shared)
        self.routed = build_bank(num_routed)
class MoEBlock(nn.Module):
    """MoE FFN block with a sequence-level sigmoid router and its own expert bank.

    Routing: the input is averaged over dim 1, projected to ``num_routed``
    logits, divided by ``router_temperature`` and passed through a sigmoid.
    In ``"dense"`` mode every routed expert contributes, weighted by its
    probability. In ``"sparse"`` mode only the ``topk`` most probable routed
    experts per sample contribute. Shared experts are always applied
    without gating::

        out = sum_s shared_s(x) + sum_i (probs_i * mask_i) * routed_i(x)

    Args:
        dim: size of the last input axis (router input width).
        num_routed: number of gated experts.
        num_shared: number of always-on experts.
        ffn_mult: forwarded as ``mult=`` to every expert.
        topk: experts kept per sample in ``"sparse"`` mode.
        dropout: forwarded as ``dropout=`` to every expert.
    """

    def __init__(
        self,
        dim: int,
        num_routed: int = 7,
        num_shared: int = 1,
        ffn_mult: int = 4,
        topk: int = 3,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_routed = num_routed
        self.num_shared = num_shared
        self.topk = topk
        self.experts = PerLayerExperts(dim, num_routed, num_shared, ffn_mult, dropout)
        self.router = nn.Linear(dim, num_routed, bias=True)
        # Zero bias and small weights => initial sigmoid probabilities near 0.5.
        nn.init.normal_(self.router.weight, std=0.02)
        nn.init.zeros_(self.router.bias)
        # "dense" behaves like topk == num_routed; "sparse" applies the real topk.
        self._mode: str = "dense"
        # Router temperature (< 1 sharpens the sigmoid). Kept as a buffer so
        # it follows the module across devices and into the state_dict.
        self.register_buffer("router_temperature", torch.tensor(1.0))

    def set_mode(self, mode: str) -> None:
        """Select ``"dense"`` (all routed experts) or ``"sparse"`` (top-k) routing.

        Raises:
            ValueError: if ``mode`` is not ``"dense"`` or ``"sparse"``.
        """
        # Explicit raise rather than ``assert`` so validation survives ``python -O``.
        if mode not in ("dense", "sparse"):
            raise ValueError(f"未知模式: {mode}")
        self._mode = mode

    @property
    def mode(self) -> str:
        """Current routing mode, ``"dense"`` or ``"sparse"``."""
        return self._mode

    def set_temperature(self, t: float) -> None:
        """Set the router temperature in place.

        Values below ``1e-3`` (including non-positive ones) are stored as
        given but clamped to ``1e-3`` when used in ``_route``.
        """
        self.router_temperature.fill_(float(t))

    def _route(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute router logits, probabilities and the expert-selection mask.

        Args:
            x: input of shape ``[B, N, D]``; pooled over ``N``.

        Returns:
            ``(logits, probs, mask)``, each of shape ``[B, num_routed]``.
            ``logits`` are un-tempered, ``probs = sigmoid(logits / T)`` and
            ``mask`` is 1.0 for selected experts, 0.0 otherwise.
        """
        pooled = x.mean(dim=1)  # [B, D] sequence-level average
        logits = self.router(pooled)  # [B, num_routed]
        # Temperature scaling; the clamp guards against division by ~0.
        scaled = logits / self.router_temperature.clamp_min(1e-3)
        probs = torch.sigmoid(scaled)
        if self._mode == "dense" or self.topk >= self.num_routed:
            mask = torch.ones_like(probs)
        else:
            topk_idx = torch.topk(probs, self.topk, dim=-1).indices
            mask = torch.zeros_like(probs)
            mask.scatter_(-1, topk_idx, 1.0)
        return logits, probs, mask

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, MoEStats]:
        """Apply shared and routed experts to ``x``.

        Args:
            x: input of shape ``[B, N, D]``.

        Returns:
            ``(out, stats)``, where ``out`` has the same shape as ``x`` and
            ``stats`` carries the router logits, probabilities and mask.

        Raises:
            ValueError: if ``x`` is not 3-D.
        """
        if x.dim() != 3:
            raise ValueError(f"expected a 3-D input [B, N, D], got shape {tuple(x.shape)}")
        b = x.shape[0]
        logits, probs, mask = self._route(x)

        # Shared experts: always active, no gating.
        out = torch.zeros_like(x)
        for sh in self.experts.shared:
            out = out + sh(x)

        # Routed experts: per-sample weights (sequence-level assignment), so
        # each expert's output is scaled by a [B, 1, 1] weight.
        weights = probs * mask  # [B, num_routed]
        # Decide for all experts at once which are used by at least one
        # sample: a single host<->device sync instead of one per expert.
        active = (weights > 0).any(dim=0).tolist()
        for i, expert in enumerate(self.experts.routed):
            if not active[i]:
                continue  # no sample selected this expert; skip its compute
            # The expert still runs on the full batch; unselected samples get weight 0.
            out = out + weights[:, i].view(b, 1, 1) * expert(x)

        stats = MoEStats(logits=logits, probs=probs, topk_mask=mask)
        return out, stats
|