| """18 层主干:前 9 Dense + 后 9 MoE。""" |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.utils.checkpoint as cp |
|
|
| from ..modules.moe import MoEStats |
| from .blocks import DenseBlock, MoEBlockWithAttn |
|
|
|
|
@dataclass
class BackboneOutput:
    """Backbone output."""

    hidden_states: torch.Tensor
    moe_stats: list[MoEStats] = field(default_factory=list)


class Backbone(nn.Module):
    """End-to-end backbone.

    The input sequence is expected to already carry positional encoding
    (RoPE for the visual span is applied inside each layer; the non-visual
    tokens get learnable positional embeddings added externally). This
    module is only responsible for stacking the 18 layers and aggregating
    routing statistics.
    """

    def __init__(
        self,
        dim: int = 768,
        num_heads: int = 12,
        ffn_mult: int = 4,
        num_dense_layers: int = 9,
        num_moe_layers: int = 9,
        num_routed: int = 7,
        num_shared: int = 1,
        topk: int = 3,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_dense_layers = num_dense_layers
        self.num_moe_layers = num_moe_layers

        self.dense_layers = nn.ModuleList([
            DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout)
            for _ in range(num_dense_layers)
        ])
        self.moe_layers = nn.ModuleList([
            MoEBlockWithAttn(
                dim,
                num_heads,
                num_routed=num_routed,
                num_shared=num_shared,
                topk=topk,
                ffn_mult=ffn_mult,
                dropout=dropout,
            )
            for _ in range(num_moe_layers)
        ])
        self.final_norm = nn.LayerNorm(dim)

        self.gradient_checkpointing = False

    def set_gradient_checkpointing(self, enabled: bool) -> None:
        """Enable/disable gradient checkpointing for all backbone layers
        (saves roughly 2/3 of activation memory)."""
        self.gradient_checkpointing = enabled

    def set_moe_mode(self, mode: str) -> None:
        """Switch every MoE layer between 'dense' and 'sparse' mode."""
        for blk in self.moe_layers:
            blk.set_mode(mode)

    def set_router_temperature(self, t: float) -> None:
        """Set the routing softmax temperature of every MoE layer."""
        for blk in self.moe_layers:
            blk.set_temperature(t)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> BackboneOutput:
        moe_stats: list[MoEStats] = []
        use_ckpt = self.gradient_checkpointing and self.training

        # First 9 layers: dense transformer blocks.
        for blk in self.dense_layers:
            if use_ckpt:
                x = cp.checkpoint(
                    blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
                )
            else:
                x = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)

        # Last 9 layers: MoE blocks; collect per-layer routing statistics.
        for blk in self.moe_layers:
            if use_ckpt:
                x, stats = cp.checkpoint(
                    blk, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
                )
            else:
                x, stats = blk(x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
            moe_stats.append(stats)

        x = self.final_norm(x)
        return BackboneOutput(hidden_states=x, moe_stats=moe_stats)
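

if __name__ == "__main__":
    # Minimal smoke test -- a usage sketch, not part of the library API.
    # It assumes DenseBlock / MoEBlockWithAttn accept rope_cos=None,
    # rope_sin=None, visual_slice=None (i.e. positional inputs are optional)
    # and that a (batch, seq_len, dim) float tensor is a valid input.
    # Because of the relative imports, run it as a module from the package
    # root, e.g. `python -m <package>.backbone` (package name is a placeholder).
    backbone = Backbone()
    backbone.set_moe_mode("sparse")
    backbone.set_router_temperature(1.0)
    backbone.eval()

    with torch.no_grad():
        out = backbone(torch.randn(2, 16, 768))

    print(out.hidden_states.shape)  # expected: torch.Size([2, 16, 768])
    print(len(out.moe_stats))       # expected: 9, one MoEStats per MoE layer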