from __future__ import annotations

import torch
import torch.nn as nn

from src.model.testformer_block import TestFormerBlock
from src.model.testformer_config import TestFormerConfig
from src.model.testformer_loss import testformer_lm_loss


class TestFormerLM(nn.Module):
    """Language model: token embedding + dropout, a stack of TestFormerBlock
    layers, a final LayerNorm, and a linear LM head (optionally weight-tied
    to the input embedding)."""

    def __init__(
        self,
        cfg: TestFormerConfig,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        self.cfg = cfg
        factory_kwargs = {"device": device, "dtype": dtype}

        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.d_model, **factory_kwargs)
        self.embedding_dropout = nn.Dropout(cfg.emb_dropout)
        self.blocks = nn.ModuleList(
            [
                TestFormerBlock(cfg, layer_idx=layer_idx, device=device, dtype=dtype)
                for layer_idx in range(cfg.n_layers)
            ]
        )
        self.final_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps, **factory_kwargs)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False, **factory_kwargs)

        if cfg.tie_input_output_embeddings:
            # Share the input embedding matrix with the output projection.
            self.lm_head.weight = self.token_embedding.weight

    def parameter_count(self) -> int:
        """Total number of parameters (shared parameters counted once)."""
        return sum(parameter.numel() for parameter in self.parameters())

    def body_parameter_count(self) -> int:
        """Parameter count excluding the token embedding and, when untied, the LM head."""
        embedding_params = self.token_embedding.weight.numel()
        output_params = 0 if self.cfg.tie_input_output_embeddings else self.lm_head.weight.numel()
        return self.parameter_count() - embedding_params - output_params

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list[torch.Tensor] | list[dict[str, torch.Tensor | dict[str, torch.Tensor]]]]:
        # Embed tokens and apply embedding dropout.
        hidden = self.embedding_dropout(self.token_embedding(input_ids))

        # Collect per-layer activations and per-block diagnostics.
        layer_outputs: list[torch.Tensor] = [hidden]
        block_outputs: list[dict[str, torch.Tensor | dict[str, torch.Tensor]]] = []
        attention_outputs: list[dict[str, torch.Tensor]] = []
        ffn_outputs: list[dict[str, torch.Tensor]] = []

        for block in self.blocks:
            block_out = block(hidden)
            hidden = block_out["hidden"]
            layer_outputs.append(hidden)
            block_outputs.append(block_out)
            attention_outputs.append(block_out["attention"])
            ffn_outputs.append(block_out["ffn"])

        hidden = self.final_norm(hidden)
        logits = self.lm_head(hidden)

        output: dict[str, torch.Tensor | list[torch.Tensor] | list[dict[str, torch.Tensor | dict[str, torch.Tensor]]]] = {
            "hidden": hidden,
            "logits": logits,
            "layer_outputs": layer_outputs,
            "block_outputs": block_outputs,
            "attention_outputs": attention_outputs,
            "ffn_outputs": ffn_outputs,
        }

        if targets is not None:
            # Language-modeling loss and its components.
            loss_dict = testformer_lm_loss(logits, targets)
            component_losses = loss_dict["component_losses"]
            output["loss"] = loss_dict["total_loss"]
            output["ce_loss"] = component_losses["ce_loss"]
            output["component_losses"] = component_losses

        return output