"""Backbone-layer blocks.

* ``DenseBlock``: GateSelfAttention followed by a SwiGLU feed-forward network.
* ``MoEBlockWithAttn``: GateSelfAttention followed by a mixture-of-experts FFN.

Both blocks use a pre-norm residual layout. Each sub-layer receives a
LayerNorm-ed copy of the stream, and its output is added back onto the
un-normalised stream.
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn

from ..modules.ffn import SwiGLUFFN
from ..modules.gate_attention import GateSelfAttention
from ..modules.moe import MoEBlock, MoEStats


def _pre_norm_attention(
    norm: nn.Module,
    attn: nn.Module,
    x: torch.Tensor,
    rope_cos: Optional[torch.Tensor],
    rope_sin: Optional[torch.Tensor],
    visual_slice: Optional[tuple[int, int]],
) -> torch.Tensor:
    """Return ``x + attn(norm(x), ...)``, the attention half of a pre-norm block.

    ``rope_cos``, ``rope_sin`` and ``visual_slice`` are forwarded to ``attn`` as
    keyword arguments without modification.
    """
    normed = norm(x)
    attn_out = attn(normed, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
    # Out-of-place add on purpose: an in-place update would overwrite ``x``,
    # which autograd may still need.
    return x + attn_out


class DenseBlock(nn.Module):
    """Pre-norm GateSelfAttention followed by a pre-norm SwiGLU FFN.

    Args:
        dim: Feature width; used by both LayerNorms and both sub-layers.
        num_heads: Forwarded to ``GateSelfAttention`` as ``num_heads``.
        ffn_mult: Forwarded to ``SwiGLUFFN`` as ``mult``.
        dropout: Forwarded to both ``GateSelfAttention`` and ``SwiGLUFFN``.
    """

    def __init__(self, dim: int, num_heads: int, ffn_mult: int = 4, dropout: float = 0.0) -> None:
        super().__init__()
        # Creation order (norm1, attn, norm2, ffn) is deliberate. It fixes
        # parameter registration order (state_dict / optimizer ordering) and
        # the order of any RNG draws the submodules make at init.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = SwiGLUFFN(dim, mult=ffn_mult, dropout=dropout)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> torch.Tensor:
        """Apply the attention residual, then the FFN residual.

        Args:
            x: Input stream; its last dimension is presumably ``dim`` (required
                by ``nn.LayerNorm(dim)``).
            rope_cos: Passed through to ``self.attn``.
            rope_sin: Passed through to ``self.attn``.
            visual_slice: Passed through to ``self.attn``.

        Returns:
            The updated stream after both residual additions.
        """
        hidden = _pre_norm_attention(self.norm1, self.attn, x, rope_cos, rope_sin, visual_slice)
        ffn_out = self.ffn(self.norm2(hidden))
        return hidden + ffn_out


class MoEBlockWithAttn(nn.Module):
    """Pre-norm GateSelfAttention followed by a pre-norm MoE FFN.

    Args:
        dim: Feature width; used by both LayerNorms and both sub-layers.
        num_heads: Forwarded to ``GateSelfAttention`` as ``num_heads``.
        num_routed: Forwarded to ``MoEBlock`` as ``num_routed``.
        num_shared: Forwarded to ``MoEBlock`` as ``num_shared``.
        topk: Forwarded to ``MoEBlock`` as ``topk``.
        ffn_mult: Forwarded to ``MoEBlock`` as ``ffn_mult``.
        dropout: Forwarded to both ``GateSelfAttention`` and ``MoEBlock``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_routed: int = 7,
        num_shared: int = 1,
        topk: int = 3,
        ffn_mult: int = 4,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        # Same fixed creation order as DenseBlock: norm1, attn, norm2, then the
        # MoE FFN.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = GateSelfAttention(dim, num_heads=num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.moe = MoEBlock(
            dim,
            num_routed=num_routed,
            num_shared=num_shared,
            topk=topk,
            ffn_mult=ffn_mult,
            dropout=dropout,
        )

    def set_mode(self, mode: str) -> None:
        """Forward ``mode`` to ``self.moe.set_mode``."""
        self.moe.set_mode(mode)

    def set_temperature(self, t: float) -> None:
        """Forward ``t`` to ``self.moe.set_temperature``."""
        self.moe.set_temperature(t)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> tuple[torch.Tensor, MoEStats]:
        """Apply the attention residual, then the MoE residual.

        Args:
            x: Input stream; its last dimension is presumably ``dim`` (required
                by ``nn.LayerNorm(dim)``).
            rope_cos: Passed through to ``self.attn``.
            rope_sin: Passed through to ``self.attn``.
            visual_slice: Passed through to ``self.attn``.

        Returns:
            A pair ``(stream, stats)``. ``stream`` is the updated stream after
            both residual additions. ``stats`` is the ``MoEStats`` object
            returned by ``self.moe``.
        """
        hidden = _pre_norm_attention(self.norm1, self.attn, x, rope_cos, rope_sin, visual_slice)
        # ``self.moe`` returns (output, stats); only the output joins the residual.
        moe_out, stats = self.moe(self.norm2(hidden))
        return hidden + moe_out, stats