| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from torch import nn |
|
|
| from addition.config import ExperimentConfig |
|
|
|
|
@dataclass()
class ModelOutput:
    """Container for everything ``AdditionTransformer.forward`` returns.

    Fields are declared in positional-constructor order; do not reorder them.
    """

    # Logits from ``digit_head`` over the leading (digit) output slots.
    digit_logits: torch.Tensor
    # Logits from ``final_carry_head`` (2 classes) taken from the last output slot.
    final_carry_logits: torch.Tensor
    # Hidden states of all output slots from the last pass (after ``final_ln``).
    output_hidden: torch.Tensor
    # One summary hidden state per pass: the base pass first, then one per latent step.
    latent_history: list[torch.Tensor]
    # Per-head attention weights of the last pass, or None when not requested.
    attention_weights: torch.Tensor | None
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer layer: causal self-attention, then a GELU MLP.

    Each sub-layer normalises the residual stream, transforms it, and adds the
    result back onto the stream.
    """

    def __init__(self, d_model: int, n_heads: int, ff_dim: int, dropout: float) -> None:
        super().__init__()
        # Attribute names define state_dict keys and construction order fixes the
        # RNG draws of PyTorch's default initialisers -- keep both stable.
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            d_model,
            n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        need_weights: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply the layer to a ``(batch, seq, d_model)`` tensor (``batch_first``).

        Args:
            hidden_states: Residual-stream input.
            need_weights: When true, also return the un-averaged (per-head)
                attention weights.

        Returns:
            ``(updated_hidden_states, weights_or_None)``.
        """
        length = hidden_states.size(1)
        # Boolean mask: True marks a blocked pair, so position i only sees j <= i.
        future_mask = torch.triu(
            torch.ones(length, length, dtype=torch.bool, device=hidden_states.device),
            diagonal=1,
        )
        attn_in = self.ln_1(hidden_states)
        attended, weights = self.attn(
            attn_in,
            attn_in,
            attn_in,
            attn_mask=future_mask,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        residual = hidden_states + self.dropout(attended)
        residual = residual + self.mlp(self.ln_2(residual))
        if not need_weights:
            weights = None
        return residual, weights
|
|
|
|
class AdditionTransformer(nn.Module):
    """Single shared-block transformer producing per-digit and final-carry logits.

    A forward call embeds the discrete input tokens and appends learned output
    slots. It then runs the shared ``block`` followed by ``final_ln`` over the
    whole sequence under a causal mask. Optional latent steps feed a hidden
    state back in as an extra token placed between the inputs and the output
    slots. The full sequence is re-run once per step.
    """

    def __init__(self, config: ExperimentConfig) -> None:
        """Build all submodules from ``config``.

        Reads ``discrete_vocab_size``, ``d_model``, ``max_sequence_length``,
        ``output_sequence_length``, ``n_heads``, ``ff_dim``, ``dropout`` and
        ``digit_vocab_size`` from ``config``.
        """
        super().__init__()
        self.config = config
        # Construction order fixes state_dict keys and the RNG draws made by
        # default initialisers; keep it stable.
        self.token_embedding = nn.Embedding(config.discrete_vocab_size, config.d_model)
        # Learned absolute positions shared by input, latent and output tokens.
        self.position_embedding = nn.Embedding(config.max_sequence_length, config.d_model)
        # Added to every fed-back latent token so it is distinguishable from inputs.
        self.latent_type_embedding = nn.Parameter(torch.zeros(config.d_model))
        # One learned query vector per output slot.
        self.output_slot_embeddings = nn.Parameter(
            torch.zeros(config.output_sequence_length, config.d_model)
        )
        self.block = TransformerBlock(
            d_model=config.d_model,
            n_heads=config.n_heads,
            ff_dim=config.ff_dim,
            dropout=config.dropout,
        )
        self.final_ln = nn.LayerNorm(config.d_model)
        self.digit_head = nn.Linear(config.d_model, config.digit_vocab_size)
        self.final_carry_head = nn.Linear(config.d_model, 2)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise embeddings (N(0, 0.02)) and output heads (Xavier, zero bias).

        ``block`` and ``final_ln`` keep their PyTorch default initialisation.
        """
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.latent_type_embedding, mean=0.0, std=0.02)
        nn.init.normal_(self.output_slot_embeddings, mean=0.0, std=0.02)
        nn.init.xavier_uniform_(self.digit_head.weight)
        nn.init.zeros_(self.digit_head.bias)
        nn.init.xavier_uniform_(self.final_carry_head.weight)
        nn.init.zeros_(self.final_carry_head.bias)

    def embed_discrete_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings plus positions ``0 .. seq_len - 1``.

        ``input_ids`` is expected as ``(batch, seq)`` integer ids; the position
        embedding is broadcast over the batch dimension.
        """
        seq_len = input_ids.shape[1]
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        return self.token_embedding(input_ids) + self.position_embedding(positions)

    def embed_output_slots(
        self,
        batch_size: int,
        output_length: int,
        latent_count: int,
        input_length: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Return ``(batch_size, output_length, d_model)`` output-slot embeddings.

        Slots are placed after the inputs and any latent tokens, so their
        positions start at ``input_length + latent_count``.

        Raises:
            ValueError: If more slots are requested than
                ``output_slot_embeddings`` holds.
        """
        available = self.output_slot_embeddings.shape[0]
        if output_length > available:
            # Without this guard the slice below silently truncates and the add
            # either fails with a shape error or (with one slot row) broadcasts it.
            raise ValueError(
                f"requested {output_length} output slots but only {available} "
                f"are configured (output_sequence_length)"
            )
        positions = torch.arange(output_length, device=device) + input_length + latent_count
        positioned = self.output_slot_embeddings[:output_length] + self.position_embedding(positions)
        return positioned.unsqueeze(0).expand(batch_size, -1, -1)

    def _run_block(
        self,
        embeddings: torch.Tensor,
        *,
        need_attention: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the shared block then ``final_ln``; return hidden states and weights."""
        hidden_states, attention_weights = self.block(embeddings, need_weights=need_attention)
        hidden_states = self.final_ln(hidden_states)
        return hidden_states, attention_weights

    def _check_sequence_fits(self, input_length: int, latent_count: int, output_length: int) -> None:
        """Raise ``ValueError`` if the longest pass exceeds the position table."""
        total = input_length + latent_count + output_length
        capacity = self.position_embedding.num_embeddings
        if total > capacity:
            raise ValueError(
                f"sequence needs {total} positions (inputs={input_length}, "
                f"latents={latent_count}, outputs={output_length}) but "
                f"max_sequence_length is {capacity}"
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        *,
        latent_steps: int = 0,
        return_attention: bool = False,
    ) -> ModelOutput:
        """Run the base pass plus ``latent_steps`` latent passes.

        Args:
            input_ids: Integer token ids; dim 0 is the batch, dim 1 the input
                length.
            latent_steps: Number of latent tokens to append, one extra pass
                each. Values <= 0 run only the base pass.
            return_attention: If true, ``attention_weights`` holds the per-head
                weights of the last pass; otherwise it is ``None``.

        Returns:
            ModelOutput containing:

            - ``digit_logits`` for the first ``active_digits`` output slots.
            - ``final_carry_logits`` from the last output slot.
            - The output-slot hidden states of the last pass.
            - ``latent_history`` with one summary vector per pass.

        Raises:
            ValueError: If inputs, latent tokens and output slots need more
                positions than ``position_embedding`` provides, or more output
                slots than ``output_slot_embeddings`` holds.
        """
        latent_history: list[torch.Tensor] = []
        attention_weights: torch.Tensor | None = None
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[1]
        # NOTE(review): presumably two equal-length operands plus two separator
        # tokens; confirm against the data pipeline.
        active_digits = max(1, (input_length - 2) // 2)
        # One slot per answer digit plus one final slot feeding the carry head.
        output_length = active_digits + 1
        num_steps = int(latent_steps)
        # Fail fast with a readable message instead of an out-of-range embedding
        # lookup (an opaque device-side assert on CUDA).
        self._check_sequence_fits(input_length, max(num_steps, 0), output_length)

        base_embeddings = self.embed_discrete_tokens(input_ids)
        output_embeddings = self.embed_output_slots(
            batch_size=batch_size,
            output_length=output_length,
            latent_count=0,
            input_length=input_length,
            device=input_ids.device,
        )
        hidden_states, attention_weights = self._run_block(
            torch.cat([base_embeddings, output_embeddings], dim=1),
            need_attention=return_attention,
        )
        output_hidden = hidden_states[:, -output_length:, :]
        # The base pass is summarised by the hidden state of the last output slot.
        summary_hidden = output_hidden[:, -1, :]
        latent_history.append(summary_hidden)

        latent_embeddings: list[torch.Tensor] = []
        for step_index in range(num_steps):
            # Feed the previous summary back as a new token at the next latent position.
            latent_token = summary_hidden.unsqueeze(1) + self.latent_type_embedding.view(1, 1, -1)
            latent_position_index = input_length + step_index
            latent_token = latent_token + self.position_embedding.weight[latent_position_index].view(1, 1, -1)
            latent_embeddings.append(latent_token)
            # Output slots shift right by one position for every latent token.
            output_embeddings = self.embed_output_slots(
                batch_size=batch_size,
                output_length=output_length,
                latent_count=len(latent_embeddings),
                input_length=input_length,
                device=input_ids.device,
            )
            hidden_states, attention_weights = self._run_block(
                torch.cat([base_embeddings] + latent_embeddings + [output_embeddings], dim=1),
                need_attention=return_attention,
            )
            # Later passes are summarised by the hidden state at the newest latent token.
            latent_index = input_length + step_index
            summary_hidden = hidden_states[:, latent_index, :]
            output_hidden = hidden_states[:, -output_length:, :]
            latent_history.append(summary_hidden)

        digit_logits = self.digit_head(output_hidden[:, :active_digits, :])
        final_carry_logits = self.final_carry_head(output_hidden[:, -1, :])
        return ModelOutput(
            digit_logits=digit_logits,
            final_carry_logits=final_carry_logits,
            output_hidden=output_hidden,
            latent_history=latent_history,
            attention_weights=attention_weights,
        )

    def parameter_count(self) -> int:
        """Return the total number of (deduplicated) parameter elements."""
        return sum(parameter.numel() for parameter in self.parameters())
|
|
|
|
def build_model(config: ExperimentConfig, device: str | None = None) -> AdditionTransformer:
    """Construct an ``AdditionTransformer`` and optionally move it to ``device``.

    Args:
        config: Experiment configuration forwarded to the model constructor.
        device: Target device string; ``None`` leaves the model where the
            constructor created it.

    Returns:
        The constructed model.
    """
    model = AdditionTransformer(config)
    return model if device is None else model.to(device)
|
|
|
|
@torch.no_grad()
def describe_model(config: ExperimentConfig) -> dict[str, int]:
    """Summarise parameter counts of a freshly built model.

    A parameter is counted as embedding/head when its qualified name contains
    ``"embedding"``/``"head"``; the two tallies are independent.
    ``backbone_params`` is everything that is not a head, so it includes the
    embeddings.
    """
    model = build_model(config)
    total = model.parameter_count()
    embedding_total = 0
    head_total = 0
    for param_name, tensor in model.named_parameters():
        count = tensor.numel()
        if "embedding" in param_name:
            embedding_total += count
        if "head" in param_name:
            head_total += count
    return {
        "total_params": int(total),
        "embedding_params": int(embedding_total),
        "head_params": int(head_total),
        "backbone_params": int(total - head_total),
    }
|
|