"""18 层主干:前 9 Dense + 后 9 MoE。""" from __future__ import annotations from dataclasses import dataclass, field from typing import Optional import torch import torch.nn as nn import torch.utils.checkpoint as cp from ..modules.moe import MoEStats from .blocks import DenseBlock, MoEBlockWithAttn @dataclass class BackboneOutput: """主干输出。""" hidden_states: torch.Tensor # [B, N, D] moe_stats: list[MoEStats] = field(default_factory=list) class Backbone(nn.Module): """端到端主干。 输入序列已包含位置编码(视觉部分 RoPE 在每层内部应用,非视觉部分使用 可学习 PE 在外部加完)。本模块只负责 18 层堆叠 + 路由统计聚合。 """ def __init__( self, dim: int = 768, num_heads: int = 12, ffn_mult: int = 4, num_dense_layers: int = 9, num_moe_layers: int = 9, num_routed: int = 7, num_shared: int = 1, topk: int = 3, dropout: float = 0.0, ) -> None: super().__init__() self.dim = dim self.num_heads = num_heads self.num_dense_layers = num_dense_layers self.num_moe_layers = num_moe_layers self.dense_layers = nn.ModuleList([ DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout) for _ in range(num_dense_layers) ]) self.moe_layers = nn.ModuleList([ MoEBlockWithAttn( dim, num_heads, num_routed=num_routed, num_shared=num_shared, topk=topk, ffn_mult=ffn_mult, dropout=dropout, ) for _ in range(num_moe_layers) ]) self.final_norm = nn.LayerNorm(dim) # 默认关闭;外部通过 ``set_gradient_checkpointing(True)`` 打开以省显存 self.gradient_checkpointing = False def set_gradient_checkpointing(self, enabled: bool) -> None: """开启/关闭主干各层 gradient checkpointing(约省 2/3 激活显存)。""" self.gradient_checkpointing = enabled def set_moe_mode(self, mode: str) -> None: """切换所有 MoE 层模式('dense' / 'sparse')。""" for blk in self.moe_layers: blk.set_mode(mode) def set_router_temperature(self, t: float) -> None: for blk in self.moe_layers: blk.set_temperature(t) def forward( self, x: torch.Tensor, rope_cos: Optional[torch.Tensor] = None, rope_sin: Optional[torch.Tensor] = None, visual_slice: Optional[tuple[int, int]] = None, ) -> BackboneOutput: moe_stats: list[MoEStats] = [] use_ckpt = self.gradient_checkpointing and self.training for blk in self.dense_layers: if use_ckpt: x = cp.checkpoint( blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False ) else: x = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) for blk in self.moe_layers: if use_ckpt: x, stats = cp.checkpoint( blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False ) else: x, stats = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice) moe_stats.append(stats) x = self.final_norm(x) return BackboneOutput(hidden_states=x, moe_stats=moe_stats)