from __future__ import annotations

import torch
import torch.nn as nn

from src.model.testformer_attention import TestFormerAttention
from src.model.testformer_config import TestFormerConfig


class TestFormerFeedForward(nn.Module):
    """Feed-forward sub-layer: ``up_proj`` -> GELU -> ``down_proj`` -> dropout.

    The first linear layer maps the last dimension from ``cfg.d_model`` to
    ``cfg.d_ff``; the second maps it back to ``cfg.d_model``.
    """

    def __init__(
        self,
        cfg: TestFormerConfig,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Build the two projections, the activation and the dropout layer.

        Args:
            cfg: Configuration object. This constructor reads ``d_model``,
                ``d_ff``, ``bias`` and ``resid_dropout`` from it.
            device: Forwarded to both ``nn.Linear`` layers.
            dtype: Forwarded to both ``nn.Linear`` layers.
        """
        super().__init__()
        placement = dict(device=device, dtype=dtype)

        # Attribute names double as state_dict keys; do not rename them.
        self.up_proj = nn.Linear(
            in_features=cfg.d_model,
            out_features=cfg.d_ff,
            bias=cfg.bias,
            **placement,
        )
        self.down_proj = nn.Linear(
            in_features=cfg.d_ff,
            out_features=cfg.d_model,
            bias=cfg.bias,
            **placement,
        )
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=cfg.resid_dropout)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the feed-forward sub-layer on ``x``.

        Args:
            x: Input tensor whose last dimension is ``cfg.d_model`` (required
                by ``up_proj``).

        Returns:
            A dict with two entries:

            * ``"hidden"``: the activated output of ``up_proj`` (before
              ``down_proj`` and before dropout).
            * ``"output"``: the result of ``down_proj`` followed by dropout.
        """
        expanded = self.up_proj(x)
        activated = self.activation(expanded)
        projected = self.down_proj(activated)
        return {
            "hidden": activated,
            "output": self.dropout(projected),
        }


class TestFormerBlock(nn.Module):
    """One layer made of an attention sub-layer and a feed-forward sub-layer.

    Each sub-layer receives a layer-normalised copy of its input, and its
    ``"output"`` entry is added back onto the un-normalised input.
    """

    def __init__(
        self,
        cfg: TestFormerConfig,
        layer_idx: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the two layer norms, the attention module and the MLP.

        Args:
            cfg: Configuration object. This constructor reads ``d_model`` and
                ``layer_norm_eps`` directly and passes ``cfg`` on to the
                attention and feed-forward sub-modules.
            layer_idx: Index of this block; stored on ``self.layer_idx`` and
                not otherwise used in this class.
            device: Forwarded to every sub-module created here.
            dtype: Forwarded to every sub-module created here.
        """
        super().__init__()
        self.layer_idx = layer_idx

        def make_norm() -> nn.LayerNorm:
            # Both norms share the same width, epsilon and placement.
            return nn.LayerNorm(
                cfg.d_model,
                eps=cfg.layer_norm_eps,
                device=device,
                dtype=dtype,
            )

        # Attribute names double as state_dict keys; do not rename them.
        self.ln1 = make_norm()
        self.attention = TestFormerAttention(cfg, device=device, dtype=dtype)
        self.ln2 = make_norm()
        self.ffn = TestFormerFeedForward(cfg, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
        """Apply attention and feed-forward, each with a residual connection.

        Args:
            x: Input tensor whose last dimension is ``cfg.d_model`` (required
                by ``ln1``).

        Returns:
            A dict with these entries:

            * ``"hidden"``: the block result, i.e. the post-attention
              residual plus the feed-forward ``"output"``.
            * ``"pre_attn"``: ``ln1(x)``, the input given to attention.
            * ``"pre_ffn"``: ``ln2`` of the post-attention residual, the input
              given to the feed-forward sub-layer.
            * ``"attention"``: whatever the attention module returned; only
              its ``"output"`` entry is consumed here.
            * ``"ffn"``: the dict returned by the feed-forward sub-layer.
        """
        normed_input = self.ln1(x)
        attn_result = self.attention(normed_input)
        after_attention = x + attn_result["output"]

        normed_residual = self.ln2(after_attention)
        ffn_result = self.ffn(normed_residual)

        return {
            "hidden": after_attention + ffn_result["output"],
            "pre_attn": normed_input,
            "pre_ffn": normed_residual,
            "attention": attn_result,
            "ffn": ffn_result,
        }