File size: 2,221 Bytes
742c943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import annotations

import torch
import torch.nn as nn

from src.model.testformer_attention import TestFormerAttention
from src.model.testformer_config import TestFormerConfig


class TestFormerFeedForward(nn.Module):
    """Feed-forward sub-module: linear up-projection, GELU, linear down-projection, dropout.

    Layer widths, the bias flag and the dropout probability are read from the
    ``d_model``, ``d_ff``, ``bias`` and ``resid_dropout`` attributes of ``cfg``.
    """

    def __init__(
        self,
        cfg: TestFormerConfig,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Build the two projections, the activation and the dropout layer.

        Args:
            cfg: Configuration providing ``d_model``, ``d_ff``, ``bias`` and
                ``resid_dropout``.
            device: Passed through to both ``nn.Linear`` constructors.
            dtype: Passed through to both ``nn.Linear`` constructors.
        """
        super().__init__()
        # Sub-modules are registered in the order up_proj, down_proj,
        # activation, dropout; that order fixes parameter/state_dict ordering.
        self.up_proj = nn.Linear(
            cfg.d_model,
            cfg.d_ff,
            bias=cfg.bias,
            device=device,
            dtype=dtype,
        )
        self.down_proj = nn.Linear(
            cfg.d_ff,
            cfg.d_model,
            bias=cfg.bias,
            device=device,
            dtype=dtype,
        )
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(cfg.resid_dropout)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the feed-forward transformation on ``x``.

        Args:
            x: Input tensor fed to ``up_proj``.

        Returns:
            A dict with two entries:
            ``"hidden"`` is the activated up-projection (before ``down_proj``);
            ``"output"`` is the down-projection of that tensor after dropout.
        """
        expanded = self.up_proj(x)
        activated = self.activation(expanded)
        projected = self.down_proj(activated)
        return {
            "hidden": activated,
            "output": self.dropout(projected),
        }


class TestFormerBlock(nn.Module):
    """One layer combining attention and a feed-forward module with residual adds.

    Each sub-module receives a LayerNorm-ed copy of its input, and its
    ``"output"`` entry is added back onto the un-normalised input.
    """

    def __init__(
        self,
        cfg: TestFormerConfig,
        layer_idx: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Build the two LayerNorms, the attention module and the feed-forward module.

        Args:
            cfg: Configuration providing ``d_model`` and ``layer_norm_eps``; it is
                also handed to the attention and feed-forward constructors.
            layer_idx: Stored on the instance as ``self.layer_idx``.
            device: Passed through to every sub-module constructor.
            dtype: Passed through to every sub-module constructor.
        """
        super().__init__()
        self.layer_idx = layer_idx
        # Registration order (ln1, attention, ln2, ffn) fixes the ordering of
        # parameters and state_dict entries.
        self.ln1 = nn.LayerNorm(
            cfg.d_model,
            eps=cfg.layer_norm_eps,
            device=device,
            dtype=dtype,
        )
        self.attention = TestFormerAttention(cfg, device=device, dtype=dtype)
        self.ln2 = nn.LayerNorm(
            cfg.d_model,
            eps=cfg.layer_norm_eps,
            device=device,
            dtype=dtype,
        )
        self.ffn = TestFormerFeedForward(cfg, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
        """Apply attention then the feed-forward module, each with a residual add.

        Args:
            x: Input tensor for this layer.

        Returns:
            A dict with:
            ``"hidden"`` - the final residual sum;
            ``"pre_attn"`` - ``ln1(x)``, the tensor given to attention;
            ``"pre_ffn"`` - ``ln2`` of the post-attention residual;
            ``"attention"`` - the full result returned by the attention module;
            ``"ffn"`` - the full result returned by the feed-forward module.
        """
        normed_input = self.ln1(x)
        attn_result = self.attention(normed_input)
        after_attention = x + attn_result["output"]

        normed_residual = self.ln2(after_attention)
        ffn_result = self.ffn(normed_residual)

        return {
            "hidden": after_attention + ffn_result["output"],
            "pre_attn": normed_input,
            "pre_ffn": normed_residual,
            "attention": attn_result,
            "ffn": ffn_result,
        }