from __future__ import annotations

import torch
import torch.nn as nn

from src.model.testformer_block import TestFormerBlock
from src.model.testformer_config import TestFormerConfig
from src.model.testformer_loss import testformer_lm_loss

class TestFormerLM(nn.Module):
    """Decoder-only language model: token embedding, a stack of TestFormerBlocks,
    a final LayerNorm, and a (optionally tied) linear LM head."""

    def __init__(
        self,
        cfg: TestFormerConfig,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        self.cfg = cfg
        factory_kwargs = {"device": device, "dtype": dtype}
        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.d_model, **factory_kwargs)
        self.embedding_dropout = nn.Dropout(cfg.emb_dropout)
        self.blocks = nn.ModuleList(
            [
                TestFormerBlock(cfg, layer_idx=layer_idx, device=device, dtype=dtype)
                for layer_idx in range(cfg.n_layers)
            ]
        )
        self.final_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps, **factory_kwargs)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False, **factory_kwargs)
        if cfg.tie_input_output_embeddings:
            # Weight tying: the LM head shares the embedding matrix, so the
            # tied tensor is stored (and trained) only once.
            self.lm_head.weight = self.token_embedding.weight

    def parameter_count(self) -> int:
        """Total parameter count; nn.Module.parameters() yields shared
        (tied) tensors only once."""
        return sum(parameter.numel() for parameter in self.parameters())

    def body_parameter_count(self) -> int:
        """Parameter count excluding the token embedding and, when untied,
        the LM head projection."""
        embedding_params = self.token_embedding.weight.numel()
        output_params = 0 if self.cfg.tie_input_output_embeddings else self.lm_head.weight.numel()
        return self.parameter_count() - embedding_params - output_params
    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list[torch.Tensor] | list[dict[str, torch.Tensor | dict[str, torch.Tensor]]]]:
        hidden = self.embedding_dropout(self.token_embedding(input_ids))
        # Collect per-layer hidden states and per-block diagnostics
        # (attention/FFN internals) alongside the final logits.
        layer_outputs: list[torch.Tensor] = [hidden]
        block_outputs: list[dict[str, torch.Tensor | dict[str, torch.Tensor]]] = []
        attention_outputs: list[dict[str, torch.Tensor]] = []
        ffn_outputs: list[dict[str, torch.Tensor]] = []
        for block in self.blocks:
            block_out = block(hidden)
            hidden = block_out["hidden"]
            layer_outputs.append(hidden)
            block_outputs.append(block_out)
            attention_outputs.append(block_out["attention"])
            ffn_outputs.append(block_out["ffn"])
        hidden = self.final_norm(hidden)
        logits = self.lm_head(hidden)
        output: dict[str, torch.Tensor | list[torch.Tensor] | list[dict[str, torch.Tensor | dict[str, torch.Tensor]]]] = {
            "hidden": hidden,
            "logits": logits,
            "layer_outputs": layer_outputs,
            "block_outputs": block_outputs,
            "attention_outputs": attention_outputs,
            "ffn_outputs": ffn_outputs,
        }
        if targets is not None:
            # Loss terms are only computed when targets are supplied.
            loss_dict = testformer_lm_loss(logits, targets)
            component_losses = loss_dict["component_losses"]
            output["loss"] = loss_dict["total_loss"]
            output["ce_loss"] = component_losses["ce_loss"]
            output["component_losses"] = component_losses
        return output
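

# Minimal smoke-test sketch, not part of the original module. It assumes
# TestFormerConfig accepts the fields used above (vocab_size, d_model,
# n_layers, emb_dropout, layer_norm_eps, tie_input_output_embeddings) as
# keyword arguments; the actual constructor in
# src/model/testformer_config.py may differ.
if __name__ == "__main__":
    cfg = TestFormerConfig(
        vocab_size=256,
        d_model=64,
        n_layers=2,
        emb_dropout=0.1,
        layer_norm_eps=1e-5,
        tie_input_output_embeddings=True,
    )
    model = TestFormerLM(cfg)

    # Batch of 2 sequences, 16 tokens each; targets trigger the loss branch.
    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
    targets = torch.randint(0, cfg.vocab_size, (2, 16))
    out = model(input_ids, targets=targets)

    print("logits shape:", tuple(out["logits"].shape))  # expected (2, 16, vocab_size)
    print("total params:", model.parameter_count())
    print("body params:", model.body_parameter_count())
    print("loss:", out["loss"].item())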