abpt / src /model /testformer_block.py
Search
feat: add testformer wikitext combo runner
742c943
from __future__ import annotations
import torch
import torch.nn as nn
from src.model.testformer_attention import TestFormerAttention
from src.model.testformer_config import TestFormerConfig
class TestFormerFeedForward(nn.Module):
    """Feed-forward sub-layer: expand, apply GELU, contract, then dropout.

    The input's last dimension must equal ``cfg.d_model``; it is projected
    to ``cfg.d_ff``, passed through GELU, and projected back to
    ``cfg.d_model``. Dropout is applied only to the final projection.
    """

    def __init__(
        self,
        cfg: TestFormerConfig,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Build the two projections, the activation and the dropout layer.

        Args:
            cfg: Supplies ``d_model``, ``d_ff``, ``bias`` and ``resid_dropout``.
            device: Forwarded to both linear layers.
            dtype: Forwarded to both linear layers.
        """
        super().__init__()
        # Creation order (up, down, activation, dropout) fixes both the order
        # of random initialisation and the order of state_dict entries.
        self.up_proj = nn.Linear(
            in_features=cfg.d_model,
            out_features=cfg.d_ff,
            bias=cfg.bias,
            device=device,
            dtype=dtype,
        )
        self.down_proj = nn.Linear(
            in_features=cfg.d_ff,
            out_features=cfg.d_model,
            bias=cfg.bias,
            device=device,
            dtype=dtype,
        )
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=cfg.resid_dropout)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the feed-forward transform on ``x``.

        Returns:
            A dict with two entries:
            ``"hidden"`` - the post-GELU activations (width ``cfg.d_ff``);
            ``"output"`` - the down-projected result after dropout.
        """
        expanded = self.up_proj(x)
        activated = self.activation(expanded)
        contracted = self.down_proj(activated)
        return {
            "hidden": activated,
            "output": self.dropout(contracted),
        }
class TestFormerBlock(nn.Module):
    """One transformer layer: pre-norm attention then pre-norm feed-forward.

    Each sub-layer receives a LayerNorm-ed copy of the running stream and its
    ``"output"`` entry is added back onto the un-normalised stream.
    """

    def __init__(
        self,
        cfg: TestFormerConfig,
        layer_idx: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the two layer norms, the attention and the feed-forward.

        Args:
            cfg: Supplies ``d_model`` and ``layer_norm_eps`` here; it is also
                handed unchanged to the attention and feed-forward modules.
            layer_idx: Stored on the instance as ``self.layer_idx``.
            device: Forwarded to every submodule that is constructed.
            dtype: Forwarded to every submodule that is constructed.
        """
        super().__init__()
        self.layer_idx = layer_idx
        # Keep this construction order: it determines parameter-initialisation
        # order and the ordering of state_dict keys.
        self.ln1 = nn.LayerNorm(
            cfg.d_model, eps=cfg.layer_norm_eps, device=device, dtype=dtype
        )
        self.attention = TestFormerAttention(cfg, device=device, dtype=dtype)
        self.ln2 = nn.LayerNorm(
            cfg.d_model, eps=cfg.layer_norm_eps, device=device, dtype=dtype
        )
        self.ffn = TestFormerFeedForward(cfg, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
        """Apply the block to ``x`` and expose its intermediate tensors.

        Returns:
            A dict with:
            ``"hidden"`` - the block output after both residual additions;
            ``"pre_attn"`` - ``ln1(x)``, the input given to attention;
            ``"pre_ffn"`` - ``ln2`` of the post-attention stream;
            ``"attention"`` - whatever the attention module returned
            (it must be indexable by ``"output"``);
            ``"ffn"`` - the dict returned by the feed-forward module.
        """
        # Attention sub-layer with residual connection.
        normed_input = self.ln1(x)
        attn_result = self.attention(normed_input)
        stream = x + attn_result["output"]

        # Feed-forward sub-layer with residual connection.
        normed_stream = self.ln2(stream)
        ffn_result = self.ffn(normed_stream)
        block_output = stream + ffn_result["output"]

        return {
            "hidden": block_output,
            "pre_attn": normed_input,
            "pre_ffn": normed_stream,
            "attention": attn_result,
            "ffn": ffn_result,
        }