Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| from src.model.testformer_attention import TestFormerAttention | |
| from src.model.testformer_config import TestFormerConfig | |
class TestFormerFeedForward(nn.Module):
    """Two-layer GELU MLP applied independently at every position.

    The input's last dimension is projected from ``cfg.d_model`` up to
    ``cfg.d_ff``, passed through a GELU, projected back down to
    ``cfg.d_model`` and finally run through dropout with probability
    ``cfg.resid_dropout``.

    Args:
        cfg: Model configuration; this module reads ``d_model``, ``d_ff``,
            ``bias`` and ``resid_dropout``.
        device: Optional device on which the linear layers are allocated.
        dtype: Optional dtype for the linear layers' parameters.
    """

    def __init__(
        self,
        cfg: TestFormerConfig,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        model_width, inner_width = cfg.d_model, cfg.d_ff
        use_bias = cfg.bias
        # Keep the construction order (up_proj before down_proj): nn.Linear
        # initialises its weights from the global RNG, so reordering would
        # change the seeded initial weights and the parameter ordering.
        self.up_proj = nn.Linear(
            model_width, inner_width, bias=use_bias, device=device, dtype=dtype
        )
        self.down_proj = nn.Linear(
            inner_width, model_width, bias=use_bias, device=device, dtype=dtype
        )
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(cfg.resid_dropout)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the MLP on ``x``.

        Args:
            x: Input tensor whose last dimension is ``cfg.d_model``.

        Returns:
            A dict with ``"hidden"`` (the post-GELU activations, last dim
            ``d_ff``) and ``"output"`` (the down-projected result after
            dropout, last dim ``d_model``).
        """
        expanded = self.up_proj(x)
        hidden = self.activation(expanded)
        contracted = self.down_proj(hidden)
        return dict(hidden=hidden, output=self.dropout(contracted))
class TestFormerBlock(nn.Module):
    """One pre-LayerNorm TestFormer layer: attention, then feed-forward.

    Each sub-layer receives a LayerNorm-ed copy of the residual stream, and
    the sub-layer's ``"output"`` entry is added back onto the un-normalised
    stream (residual connection).

    Args:
        cfg: Model configuration; this class reads ``d_model`` and
            ``layer_norm_eps`` and passes ``cfg`` on to the attention and
            feed-forward sub-modules.
        layer_idx: Index of this block in the layer stack. It is stored as
            ``self.layer_idx`` and is not used by the block itself.
        device: Optional device for all parameters.
        dtype: Optional dtype for all parameters.
    """

    def __init__(
        self,
        cfg: TestFormerConfig,
        layer_idx: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        norm_width, norm_eps = cfg.d_model, cfg.layer_norm_eps
        # Keep the construction order (ln1, attention, ln2, ffn) so that the
        # parameter ordering and seeded initialisation stay unchanged.
        self.ln1 = nn.LayerNorm(norm_width, eps=norm_eps, device=device, dtype=dtype)
        self.attention = TestFormerAttention(cfg, device=device, dtype=dtype)
        self.ln2 = nn.LayerNorm(norm_width, eps=norm_eps, device=device, dtype=dtype)
        self.ffn = TestFormerFeedForward(cfg, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
        """Apply attention and feed-forward sub-layers with residuals.

        Args:
            x: Residual-stream input; its last dimension must match
                ``cfg.d_model`` for the LayerNorms.

        Returns:
            A dict with:
            ``"hidden"`` – the block output (input + attention + FFN residuals);
            ``"pre_attn"`` – the LayerNorm-ed input fed to attention;
            ``"pre_ffn"`` – the LayerNorm-ed stream fed to the FFN;
            ``"attention"`` – the attention sub-module's full result;
            ``"ffn"`` – the feed-forward sub-module's full result.
        """
        # Attention sub-layer.
        normed_for_attn = self.ln1(x)
        attn_result = self.attention(normed_for_attn)
        after_attn = x + attn_result["output"]

        # Feed-forward sub-layer, normalising the post-attention stream.
        normed_for_ffn = self.ln2(after_attn)
        ffn_result = self.ffn(normed_for_ffn)
        block_out = after_attn + ffn_result["output"]

        return dict(
            hidden=block_out,
            pre_attn=normed_for_attn,
            pre_ffn=normed_for_ffn,
            attention=attn_result,
            ffn=ffn_result,
        )