File size: 3,583 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""18 层主干:前 9 Dense + 后 9 MoE。"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

from ..modules.moe import MoEStats
from .blocks import DenseBlock, MoEBlockWithAttn


@dataclass
class BackboneOutput:
    """Container returned by ``Backbone.forward``.

    Attributes:
        hidden_states: Backbone output tensor (annotated ``[B, N, D]`` in the
            original source).
        moe_stats: Routing statistics gathered from the MoE layers. When not
            supplied, every instance gets its own fresh empty list.
    """

    # Output hidden states, expected layout [B, N, D].
    hidden_states: torch.Tensor
    # Per-MoE-layer routing statistics; ``default_factory`` avoids a list
    # shared between instances.
    moe_stats: list[MoEStats] = field(default_factory=list)


class Backbone(nn.Module):
    """End-to-end backbone: a stack of dense blocks followed by MoE blocks.

    With the default arguments this is 18 layers (9 ``DenseBlock`` followed by
    9 ``MoEBlockWithAttn``); both counts are constructor arguments.

    The input sequence is expected to already carry its positional
    information. Per the original design note, RoPE for the visual part is
    applied inside each layer (``rope_cos`` / ``rope_sin`` / ``visual_slice``
    are forwarded to every block), while non-visual tokens have a learnable
    positional embedding added outside this module. This module only stacks
    the layers, applies a final ``LayerNorm``, and collects the routing
    statistics returned by each MoE layer.
    """

    def __init__(
        self,
        dim: int = 768,
        num_heads: int = 12,
        ffn_mult: int = 4,
        num_dense_layers: int = 9,
        num_moe_layers: int = 9,
        num_routed: int = 7,
        num_shared: int = 1,
        topk: int = 3,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_dense_layers = num_dense_layers
        self.num_moe_layers = num_moe_layers

        self.dense_layers = nn.ModuleList([
            DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout)
            for _ in range(num_dense_layers)
        ])
        self.moe_layers = nn.ModuleList([
            MoEBlockWithAttn(
                dim,
                num_heads,
                num_routed=num_routed,
                num_shared=num_shared,
                topk=topk,
                ffn_mult=ffn_mult,
                dropout=dropout,
            )
            for _ in range(num_moe_layers)
        ])
        self.final_norm = nn.LayerNorm(dim)
        # Off by default; enable via ``set_gradient_checkpointing(True)`` to
        # trade extra recomputation for lower activation memory.
        self.gradient_checkpointing = False

    def set_gradient_checkpointing(self, enabled: bool) -> None:
        """Enable/disable activation checkpointing for every backbone layer.

        Checkpointing only takes effect in training mode (see ``forward``).
        The original note estimated roughly 2/3 activation-memory savings.
        NOTE(review): that figure is unverified and depends on block internals.
        """
        self.gradient_checkpointing = bool(enabled)

    def set_moe_mode(self, mode: str) -> None:
        """Switch every MoE layer's mode ('dense' / 'sparse') via ``set_mode``."""
        for blk in self.moe_layers:
            blk.set_mode(mode)

    def set_router_temperature(self, t: float) -> None:
        """Forward router temperature ``t`` to every MoE layer via ``set_temperature``."""
        for blk in self.moe_layers:
            blk.set_temperature(t)

    @staticmethod
    def _run_block(blk: nn.Module, x: torch.Tensor, use_ckpt: bool, **block_kwargs):
        """Call ``blk(x, **block_kwargs)``, optionally under activation checkpointing.

        Both paths pass the per-layer extras by keyword. The checkpointed call
        therefore binds arguments exactly like the eager call, whatever the
        parameter order in the block's ``forward`` signature.
        """
        if not use_ckpt:
            return blk(x, **block_kwargs)

        def _fwd(hidden: torch.Tensor):
            return blk(hidden, **block_kwargs)

        # Non-reentrant checkpointing: tensors captured by the closure still
        # take part in autograd. RNG state is preserved by default, so dropout
        # is reproduced identically when the forward is recomputed.
        return cp.checkpoint(_fwd, x, use_reentrant=False)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> BackboneOutput:
        """Run all dense layers, then all MoE layers, then the final LayerNorm.

        Args:
            x: Input sequence (annotated ``[B, N, D]`` on ``BackboneOutput``;
                presumably ``D == self.dim`` so ``final_norm`` applies).
            rope_cos: Forwarded unchanged to every block.
            rope_sin: Forwarded unchanged to every block.
            visual_slice: Forwarded unchanged to every block.

        Returns:
            ``BackboneOutput`` holding the normalized hidden states and one
            stats entry per MoE layer, in layer order.
        """
        # Checkpointing only helps when a backward pass will follow.
        use_ckpt = self.gradient_checkpointing and self.training
        block_kwargs = {
            "rope_cos": rope_cos,
            "rope_sin": rope_sin,
            "visual_slice": visual_slice,
        }

        for blk in self.dense_layers:
            x = self._run_block(blk, x, use_ckpt, **block_kwargs)

        moe_stats: list[MoEStats] = []
        for blk in self.moe_layers:
            # MoE blocks return (hidden_states, routing_stats).
            x, stats = self._run_block(blk, x, use_ckpt, **block_kwargs)
            moe_stats.append(stats)

        x = self.final_norm(x)
        return BackboneOutput(hidden_states=x, moe_stats=moe_stats)