WJAD / src /wjad /backbone /blocks.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
"""主干层 block:Dense(GateSelfAttn + SwiGLU FFN)/ MoE(GateSelfAttn + MoE FFN)。"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
from ..modules.ffn import SwiGLUFFN
from ..modules.gate_attention import GateSelfAttention
from ..modules.moe import MoEBlock, MoEStats
class DenseBlock(nn.Module):
    """Pre-norm residual block: gated self-attention followed by a SwiGLU FFN.

    Each sub-layer receives a LayerNorm-ed copy of the running activation and
    its output is added back onto the un-normalised activation.

    Args:
        dim: Size passed to both ``nn.LayerNorm`` layers, the attention module
            and the FFN.
        num_heads: Forwarded to ``GateSelfAttention`` as ``num_heads``.
        ffn_mult: Forwarded to ``SwiGLUFFN`` as ``mult``.
        dropout: Forwarded to both the attention module and the FFN.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_mult: int = 4,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        # Sub-layers are created in execution order: norm -> attn -> norm -> ffn.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = GateSelfAttention(
            dim,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = SwiGLUFFN(
            dim,
            mult=ffn_mult,
            dropout=dropout,
        )

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> torch.Tensor:
        """Run the attention and FFN residual branches in sequence.

        Args:
            x: Input activations.
            rope_cos: Passed through unchanged to the attention module.
            rope_sin: Passed through unchanged to the attention module.
            visual_slice: Passed through unchanged to the attention module;
                its meaning is defined there.

        Returns:
            A new tensor holding ``x`` plus both residual contributions; the
            input tensor is not modified in place.
        """
        # Attention branch on the normalised input, added to the raw input.
        attn_out = self.attn(
            self.norm1(x),
            rope_cos=rope_cos,
            rope_sin=rope_sin,
            visual_slice=visual_slice,
        )
        hidden = x + attn_out

        # Feed-forward branch on the normalised post-attention activations.
        ffn_out = self.ffn(self.norm2(hidden))
        return hidden + ffn_out
class MoEBlockWithAttn(nn.Module):
    """Pre-norm residual block: gated self-attention followed by an MoE FFN.

    Structured like a dense attention + FFN block, except that the
    feed-forward branch is a ``MoEBlock`` which also yields routing
    statistics alongside its output.

    Args:
        dim: Size passed to both ``nn.LayerNorm`` layers, the attention module
            and the MoE block.
        num_heads: Forwarded to ``GateSelfAttention`` as ``num_heads``.
        num_routed: Forwarded to ``MoEBlock`` as ``num_routed``.
        num_shared: Forwarded to ``MoEBlock`` as ``num_shared``.
        topk: Forwarded to ``MoEBlock`` as ``topk``.
        ffn_mult: Forwarded to ``MoEBlock`` as ``ffn_mult``.
        dropout: Forwarded to both the attention module and the MoE block.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_routed: int = 7,
        num_shared: int = 1,
        topk: int = 3,
        ffn_mult: int = 4,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        # Sub-layers are created in execution order: norm -> attn -> norm -> moe.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = GateSelfAttention(
            dim,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.moe = MoEBlock(
            dim,
            num_routed=num_routed, num_shared=num_shared, topk=topk,
            ffn_mult=ffn_mult, dropout=dropout,
        )

    def set_mode(self, mode: str) -> None:
        """Delegate ``mode`` to the wrapped MoE block's ``set_mode``."""
        self.moe.set_mode(mode)

    def set_temperature(self, t: float) -> None:
        """Delegate ``t`` to the wrapped MoE block's ``set_temperature``."""
        self.moe.set_temperature(t)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> tuple[torch.Tensor, MoEStats]:
        """Run the attention and MoE residual branches in sequence.

        Args:
            x: Input activations.
            rope_cos: Passed through unchanged to the attention module.
            rope_sin: Passed through unchanged to the attention module.
            visual_slice: Passed through unchanged to the attention module;
                its meaning is defined there.

        Returns:
            A pair ``(output, stats)``: ``output`` is ``x`` plus both residual
            contributions, and ``stats`` is the second value returned by the
            MoE block, handed back untouched.
        """
        # Attention branch on the normalised input, added to the raw input.
        attn_out = self.attn(
            self.norm1(x),
            rope_cos=rope_cos,
            rope_sin=rope_sin,
            visual_slice=visual_slice,
        )
        hidden = x + attn_out

        # MoE branch on the normalised post-attention activations.
        expert_out, stats = self.moe(self.norm2(hidden))
        return hidden + expert_out, stats