# WJAD/src/wjad/backbone/backbone.py
# (Hugging Face sync artifact — fuzirui, "Sync WJAD codebase", commit 0cfefd2, verified)
"""18 层主干:前 9 Dense + 后 9 MoE。"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from ..modules.moe import MoEStats
from .blocks import DenseBlock, MoEBlockWithAttn
@dataclass
class BackboneOutput:
    """Backbone output bundle: final hidden states plus per-layer MoE routing stats."""

    # Hidden states after the full stack and the final LayerNorm, shape [B, N, D].
    hidden_states: torch.Tensor  # [B, N, D]
    # One MoEStats entry per MoE layer, in layer order; empty if no MoE layers ran.
    moe_stats: list[MoEStats] = field(default_factory=list)
class Backbone(nn.Module):
    """End-to-end backbone: a stack of dense blocks followed by MoE blocks.

    The incoming sequence is assumed to already carry its positional encoding
    (RoPE for the visual span is applied inside each layer; the non-visual span
    has a learnable PE added externally). This module is responsible only for
    stacking the 18 layers and aggregating router statistics.
    """

    def __init__(
        self,
        dim: int = 768,
        num_heads: int = 12,
        ffn_mult: int = 4,
        num_dense_layers: int = 9,
        num_moe_layers: int = 9,
        num_routed: int = 7,
        num_shared: int = 1,
        topk: int = 3,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_dense_layers = num_dense_layers
        self.num_moe_layers = num_moe_layers

        # First stage: plain dense transformer blocks.
        dense_stack = [
            DenseBlock(dim, num_heads, ffn_mult=ffn_mult, dropout=dropout)
            for _ in range(num_dense_layers)
        ]
        self.dense_layers = nn.ModuleList(dense_stack)

        # Second stage: attention + MoE-FFN blocks sharing the router config.
        moe_stack = []
        for _ in range(num_moe_layers):
            moe_stack.append(
                MoEBlockWithAttn(
                    dim,
                    num_heads,
                    num_routed=num_routed,
                    num_shared=num_shared,
                    topk=topk,
                    ffn_mult=ffn_mult,
                    dropout=dropout,
                )
            )
        self.moe_layers = nn.ModuleList(moe_stack)

        self.final_norm = nn.LayerNorm(dim)
        # Off by default; callers opt in via ``set_gradient_checkpointing(True)``
        # to trade recompute time for activation memory.
        self.gradient_checkpointing = False

    def set_gradient_checkpointing(self, enabled: bool) -> None:
        """Toggle per-layer gradient checkpointing (saves roughly 2/3 of activation memory)."""
        self.gradient_checkpointing = enabled

    def set_moe_mode(self, mode: str) -> None:
        """Switch every MoE layer's routing mode ('dense' / 'sparse')."""
        for layer in self.moe_layers:
            layer.set_mode(mode)

    def set_router_temperature(self, t: float) -> None:
        """Propagate a new router temperature to every MoE layer."""
        for layer in self.moe_layers:
            layer.set_temperature(t)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> BackboneOutput:
        """Run the dense stack, then the MoE stack, then the final LayerNorm."""
        stats_per_layer: list[MoEStats] = []
        # Checkpointing only applies while training (no backward pass in eval).
        checkpointed = self.training and self.gradient_checkpointing

        for layer in self.dense_layers:
            if checkpointed:
                x = cp.checkpoint(
                    layer, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
                )
            else:
                x = layer(
                    x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice
                )

        for layer in self.moe_layers:
            if checkpointed:
                x, layer_stats = cp.checkpoint(
                    layer, x, rope_cos, rope_sin, visual_slice, use_reentrant=False
                )
            else:
                x, layer_stats = layer(
                    x, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice
                )
            stats_per_layer.append(layer_stats)

        return BackboneOutput(
            hidden_states=self.final_norm(x), moe_stats=stats_per_layer
        )