| """主干层 block:Dense(GateSelfAttn + SwiGLU FFN)/ MoE(GateSelfAttn + MoE FFN)。""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from ..modules.ffn import SwiGLUFFN |
| from ..modules.gate_attention import GateSelfAttention |
| from ..modules.moe import MoEBlock, MoEStats |
|
|
|
|
class DenseBlock(nn.Module):
    """Dense backbone block: pre-norm gated self-attention, then pre-norm SwiGLU FFN.

    Each sub-layer is applied to a LayerNorm-ed copy of its input and the
    result is added back onto the un-normalized input (residual connection).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_mult: int = 4,
        dropout: float = 0.0,
    ) -> None:
        """Build the two sub-layers and their LayerNorms.

        Args:
            dim: Feature size handed to both LayerNorms, the attention
                module and the FFN.
            num_heads: Passed to ``GateSelfAttention`` as ``num_heads``.
            ffn_mult: Passed to ``SwiGLUFFN`` as ``mult``.
            dropout: Passed to both the attention module and the FFN.
        """
        super().__init__()
        # Creation order is kept as norm1 -> attn -> norm2 -> ffn so that
        # parameter registration order (and seeded initialization) is stable.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = GateSelfAttention(
            dim,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = SwiGLUFFN(
            dim,
            mult=ffn_mult,
            dropout=dropout,
        )

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> torch.Tensor:
        """Run the attention and FFN sub-layers with residual connections.

        Args:
            x: Input tensor.
            rope_cos: Forwarded unchanged to the attention module.
            rope_sin: Forwarded unchanged to the attention module.
            visual_slice: Forwarded unchanged to the attention module.

        Returns:
            The input plus the attention output, plus the FFN output computed
            from that intermediate sum.
        """
        # Attention sub-layer: normalize first, then add the result back.
        attn_in = self.norm1(x)
        attn_out = self.attn(
            attn_in,
            rope_cos=rope_cos,
            rope_sin=rope_sin,
            visual_slice=visual_slice,
        )
        hidden = x + attn_out

        # Feed-forward sub-layer, same pre-norm + residual pattern.
        ffn_out = self.ffn(self.norm2(hidden))
        return hidden + ffn_out
|
|
|
|
class MoEBlockWithAttn(nn.Module):
    """MoE backbone block: pre-norm gated self-attention, then pre-norm MoE FFN.

    Same residual layout as a dense block, except the feed-forward stage is a
    ``MoEBlock`` whose call returns both an output and a stats object.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_routed: int = 7,
        num_shared: int = 1,
        topk: int = 3,
        ffn_mult: int = 4,
        dropout: float = 0.0,
    ) -> None:
        """Build the attention sub-layer, the MoE sub-layer and their LayerNorms.

        Args:
            dim: Feature size handed to both LayerNorms, the attention
                module and the MoE block.
            num_heads: Passed to ``GateSelfAttention`` as ``num_heads``.
            num_routed: Passed through to ``MoEBlock``.
            num_shared: Passed through to ``MoEBlock``.
            topk: Passed through to ``MoEBlock``.
            ffn_mult: Passed through to ``MoEBlock``.
            dropout: Passed to both the attention module and the MoE block.
        """
        super().__init__()
        # Creation order is kept as norm1 -> attn -> norm2 -> moe so that
        # parameter registration order (and seeded initialization) is stable.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = GateSelfAttention(
            dim,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.moe = MoEBlock(
            dim,
            num_routed=num_routed,
            num_shared=num_shared,
            topk=topk,
            ffn_mult=ffn_mult,
            dropout=dropout,
        )

    def set_mode(self, mode: str) -> None:
        """Delegate ``mode`` to the MoE sub-block's ``set_mode``."""
        moe = self.moe
        moe.set_mode(mode)

    def set_temperature(self, t: float) -> None:
        """Delegate ``t`` to the MoE sub-block's ``set_temperature``."""
        moe = self.moe
        moe.set_temperature(t)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> tuple[torch.Tensor, MoEStats]:
        """Run the attention and MoE sub-layers with residual connections.

        Args:
            x: Input tensor.
            rope_cos: Forwarded unchanged to the attention module.
            rope_sin: Forwarded unchanged to the attention module.
            visual_slice: Forwarded unchanged to the attention module.

        Returns:
            A pair ``(output, stats)``: the residual-summed tensor and the
            stats object returned by the MoE block for this call.
        """
        # Attention sub-layer: normalize first, then add the result back.
        attn_in = self.norm1(x)
        attn_out = self.attn(
            attn_in,
            rope_cos=rope_cos,
            rope_sin=rope_sin,
            visual_slice=visual_slice,
        )
        hidden = x + attn_out

        # MoE sub-layer: the block yields its output together with stats,
        # which are handed back to the caller untouched.
        expert_out, stats = self.moe(self.norm2(hidden))
        return hidden + expert_out, stats
|
|