File size: 3,266 Bytes
742c943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from __future__ import annotations

import torch
import torch.nn as nn

from src.model.testformer_block import TestFormerBlock
from src.model.testformer_config import TestFormerConfig
from src.model.testformer_loss import testformer_lm_loss


class TestFormerLM(nn.Module):
    """Language model built from a stack of ``TestFormerBlock`` layers.

    Pipeline, as wired in ``forward``: token embedding -> embedding dropout ->
    ``cfg.n_layers`` blocks applied in order -> final ``LayerNorm`` -> linear
    LM head producing ``cfg.vocab_size`` logits per position.
    """

    def __init__(
        self,
        cfg: TestFormerConfig,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Build the embedding, block stack, final norm and LM head.

        Args:
            cfg: Model configuration. This constructor reads ``vocab_size``,
                ``d_model``, ``emb_dropout``, ``n_layers``, ``layer_norm_eps``
                and ``tie_input_output_embeddings``; the whole object is also
                forwarded to every ``TestFormerBlock``.
            device: Forwarded to every parameterised submodule.
            dtype: Forwarded to every parameterised submodule.
        """
        super().__init__()
        self.cfg = cfg
        factory_kwargs = {"device": device, "dtype": dtype}
        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.d_model, **factory_kwargs)
        self.embedding_dropout = nn.Dropout(cfg.emb_dropout)
        self.blocks = nn.ModuleList(
            [
                TestFormerBlock(cfg, layer_idx=layer_idx, device=device, dtype=dtype)
                for layer_idx in range(cfg.n_layers)
            ]
        )
        self.final_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps, **factory_kwargs)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False, **factory_kwargs)
        if cfg.tie_input_output_embeddings:
            # Share one Parameter between the input embedding and the output
            # projection; both weights have shape (vocab_size, d_model).
            self.lm_head.weight = self.token_embedding.weight

    def parameter_count(self) -> int:
        """Return the total number of elements over ``self.parameters()``.

        ``nn.Module.parameters()`` yields a shared Parameter only once, so a
        tied embedding / LM-head weight is counted a single time.
        """
        return sum(parameter.numel() for parameter in self.parameters())

    def body_parameter_count(self) -> int:
        """Return the parameter count excluding the embedding and LM-head weights."""
        embedding_params = self.token_embedding.weight.numel()
        # With tied weights the head is the same Parameter as the embedding and
        # was counted once in parameter_count(), so it must not be subtracted
        # a second time.
        output_params = 0 if self.cfg.tie_input_output_embeddings else self.lm_head.weight.numel()
        return self.parameter_count() - embedding_params - output_params

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> dict[
        str,
        torch.Tensor
        | list[torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | list[dict[str, torch.Tensor | dict[str, torch.Tensor]]]
        | dict[str, torch.Tensor],
    ]:
        """Run the model and, when ``targets`` is given, attach the LM loss.

        Args:
            input_ids: Token indices fed to the ``nn.Embedding`` lookup
                (presumably shaped ``(batch, seq)`` -- not enforced here).
            targets: Optional targets handed unchanged to
                ``testformer_lm_loss`` together with the logits. When ``None``
                no loss keys are added to the result.

        Returns:
            A dict with these keys:

            * ``"hidden"``: output of the final ``LayerNorm``.
            * ``"logits"``: ``lm_head`` applied to ``"hidden"``.
            * ``"layer_outputs"``: the post-dropout embedding followed by each
              block's ``"hidden"`` output, i.e. ``len(self.blocks) + 1``
              entries, all taken *before* the final norm.
            * ``"block_outputs"``: the raw dict returned by each block.
            * ``"attention_outputs"`` / ``"ffn_outputs"``: each block's
              ``"attention"`` / ``"ffn"`` entry, in layer order.

            Only when ``targets`` is not ``None``:

            * ``"loss"``: ``total_loss`` from ``testformer_lm_loss``.
            * ``"ce_loss"``: the ``"ce_loss"`` entry of its component losses.
            * ``"component_losses"``: the full component-loss mapping
              (values are presumably tensors -- confirm against
              ``testformer_lm_loss``).
        """
        hidden = self.embedding_dropout(self.token_embedding(input_ids))

        # Index 0 holds the embedding output so that entry i + 1 is the output
        # of block i.
        layer_outputs: list[torch.Tensor] = [hidden]
        block_outputs: list[dict[str, torch.Tensor | dict[str, torch.Tensor]]] = []
        attention_outputs: list[dict[str, torch.Tensor]] = []
        ffn_outputs: list[dict[str, torch.Tensor]] = []

        for block in self.blocks:
            # Each block is expected to return a mapping with at least the
            # keys "hidden", "attention" and "ffn".
            block_out = block(hidden)
            hidden = block_out["hidden"]
            layer_outputs.append(hidden)
            block_outputs.append(block_out)
            attention_outputs.append(block_out["attention"])
            ffn_outputs.append(block_out["ffn"])

        hidden = self.final_norm(hidden)
        logits = self.lm_head(hidden)

        # The value union mirrors the return annotation; it includes the plain
        # dict stored under "component_losses" below.
        output: dict[
            str,
            torch.Tensor
            | list[torch.Tensor]
            | list[dict[str, torch.Tensor]]
            | list[dict[str, torch.Tensor | dict[str, torch.Tensor]]]
            | dict[str, torch.Tensor],
        ] = {
            "hidden": hidden,
            "logits": logits,
            "layer_outputs": layer_outputs,
            "block_outputs": block_outputs,
            "attention_outputs": attention_outputs,
            "ffn_outputs": ffn_outputs,
        }

        if targets is not None:
            loss_dict = testformer_lm_loss(logits, targets)
            component_losses = loss_dict["component_losses"]
            output["loss"] = loss_dict["total_loss"]
            output["ce_loss"] = component_losses["ce_loss"]
            output["component_losses"] = component_losses

        return output