Avra98's picture
Initial code dump (rebuttal-ready snapshot)
76de008 verified
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import nn
from addition.config import ExperimentConfig
@dataclass
class ModelOutput:
digit_logits: torch.Tensor
final_carry_logits: torch.Tensor
output_hidden: torch.Tensor
latent_history: list[torch.Tensor]
attention_weights: torch.Tensor | None
class TransformerBlock(nn.Module):
def __init__(self, d_model: int, n_heads: int, ff_dim: int, dropout: float) -> None:
super().__init__()
self.ln_1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
self.ln_2 = nn.LayerNorm(d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model, ff_dim),
nn.GELU(),
nn.Linear(ff_dim, d_model),
nn.Dropout(dropout),
)
def forward(self, hidden_states: torch.Tensor, need_weights: bool = False) -> tuple[torch.Tensor, torch.Tensor | None]:
seq_len = hidden_states.shape[1]
causal_mask = torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool).triu(1)
normed = self.ln_1(hidden_states)
attn_output, attn_weights = self.attn(
normed,
normed,
normed,
need_weights=need_weights,
average_attn_weights=False,
attn_mask=causal_mask,
)
hidden_states = hidden_states + self.dropout(attn_output)
hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
return hidden_states, attn_weights if need_weights else None
class AdditionTransformer(nn.Module):
    """Causal single-block transformer that predicts sum digits through output slots.

    Every pass of the shared ``TransformerBlock`` sees the sequence
    ``[input tokens] [latent tokens ...] [output slots]``. Output slot ``i`` is the
    learned vector ``output_slot_embeddings[i]`` plus a position embedding. The first
    ``active_digits`` slots feed ``digit_head`` and the last slot feeds
    ``final_carry_head``. With ``latent_steps > 0`` the block is re-run once per
    step, each time appending one latent token built from the previous pass.
    """

    def __init__(self, config: ExperimentConfig) -> None:
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.discrete_vocab_size, config.d_model)
        # One position table shared by input tokens, latent tokens and output slots.
        self.position_embedding = nn.Embedding(config.max_sequence_length, config.d_model)
        # Added to every latent token (presumably so the block can tell latents
        # apart from input tokens).
        self.latent_type_embedding = nn.Parameter(torch.zeros(config.d_model))
        # One learned vector per output slot: digit slots followed by the final-carry slot.
        self.output_slot_embeddings = nn.Parameter(torch.zeros(config.output_sequence_length, config.d_model))
        self.block = TransformerBlock(
            d_model=config.d_model,
            n_heads=config.n_heads,
            ff_dim=config.ff_dim,
            dropout=config.dropout,
        )
        self.final_ln = nn.LayerNorm(config.d_model)
        self.digit_head = nn.Linear(config.d_model, config.digit_vocab_size)
        self.final_carry_head = nn.Linear(config.d_model, 2)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise embeddings (N(0, 0.02)) and heads (Xavier-uniform weights, zero bias).

        The transformer block and ``final_ln`` keep PyTorch's default initialisation.
        """
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.latent_type_embedding, mean=0.0, std=0.02)
        nn.init.normal_(self.output_slot_embeddings, mean=0.0, std=0.02)
        nn.init.xavier_uniform_(self.digit_head.weight)
        nn.init.zeros_(self.digit_head.bias)
        nn.init.xavier_uniform_(self.final_carry_head.weight)
        nn.init.zeros_(self.final_carry_head.bias)

    def embed_discrete_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings plus position embeddings for positions ``0..seq_len-1``.

        ``input_ids`` is indexed as ``(batch, seq_len)``; the result is
        ``(batch, seq_len, d_model)``.
        """
        seq_len = input_ids.shape[1]
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        return self.token_embedding(input_ids) + self.position_embedding(positions)

    def _validate_layout(self, output_length: int, end_position: int) -> None:
        """Raise ``ValueError`` if the requested layout exceeds the learned tables.

        Args:
            output_length: number of output slots that will be embedded.
            end_position: one past the largest position index that will be looked up.
        """
        slot_count = self.output_slot_embeddings.shape[0]
        if output_length > slot_count:
            # Without this check the slot slice below is silently short and may
            # broadcast (when slot_count == 1) instead of failing.
            raise ValueError(
                f"requested {output_length} output slots but only {slot_count} are available "
                "(config.output_sequence_length)"
            )
        position_count = self.position_embedding.num_embeddings
        if end_position > position_count:
            raise ValueError(
                f"sequence needs {end_position} positions but position_embedding has only "
                f"{position_count} (config.max_sequence_length)"
            )

    def embed_output_slots(
        self,
        batch_size: int,
        output_length: int,
        latent_count: int,
        input_length: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Build the output-slot embeddings that follow the inputs and latents.

        Slot ``i`` gets ``output_slot_embeddings[i]`` plus the position embedding for
        index ``input_length + latent_count + i``.

        Returns:
            ``(batch_size, output_length, d_model)`` tensor (an expanded view shared
            across the batch).

        Raises:
            ValueError: if ``output_length`` exceeds the learned output slots or the
                last slot position is outside the position-embedding table.
        """
        first_position = input_length + latent_count
        self._validate_layout(output_length, first_position + output_length)
        positions = torch.arange(output_length, device=device) + first_position
        positioned = self.output_slot_embeddings[:output_length] + self.position_embedding(positions)
        return positioned.unsqueeze(0).expand(batch_size, -1, -1)

    def _run_block(
        self,
        embeddings: torch.Tensor,
        *,
        need_attention: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the shared block once and apply ``final_ln`` to its output."""
        hidden_states, attention_weights = self.block(embeddings, need_weights=need_attention)
        hidden_states = self.final_ln(hidden_states)
        return hidden_states, attention_weights

    def _encode_with_slots(
        self,
        base_embeddings: torch.Tensor,
        latent_embeddings: list[torch.Tensor],
        output_length: int,
        device: torch.device,
        need_attention: bool,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Append output slots after the inputs and latents, then run the block once."""
        output_embeddings = self.embed_output_slots(
            batch_size=base_embeddings.shape[0],
            output_length=output_length,
            latent_count=len(latent_embeddings),
            input_length=base_embeddings.shape[1],
            device=device,
        )
        sequence = torch.cat([base_embeddings, *latent_embeddings, output_embeddings], dim=1)
        return self._run_block(sequence, need_attention=need_attention)

    def forward(
        self,
        input_ids: torch.Tensor,
        *,
        latent_steps: int = 0,
        return_attention: bool = False,
    ) -> ModelOutput:
        """Predict sum digits and the final carry for a batch of prompts.

        Args:
            input_ids: token ids indexed as ``(batch, input_length)``.
            latent_steps: number of extra block passes. Each appends one latent token
                and re-encodes the whole sequence. Values <= 0 mean a single pass.
            return_attention: if True, ``attention_weights`` holds the per-head
                weights of the last pass; otherwise it is ``None``.

        Raises:
            ValueError: if the inputs, latents and output slots do not fit in the
                position table, or more output slots are needed than were learned.
        """
        input_length = input_ids.shape[1]
        # One slot per digit plus a final-carry slot. The ``- 2`` / ``// 2`` presumably
        # accounts for two operands and two non-digit tokens -- TODO confirm against
        # the prompt format.
        active_digits = max(1, (input_length - 2) // 2)
        output_length = active_digits + 1
        num_steps = int(latent_steps)
        # Fail fast before any compute. An out-of-range position index would otherwise
        # surface mid-forward as an IndexError (or a device-side assert on CUDA).
        self._validate_layout(output_length, input_length + max(num_steps, 0) + output_length)

        base_embeddings = self.embed_discrete_tokens(input_ids)
        latent_embeddings: list[torch.Tensor] = []
        hidden_states, attention_weights = self._encode_with_slots(
            base_embeddings, latent_embeddings, output_length, input_ids.device, return_attention
        )
        output_hidden = hidden_states[:, -output_length:, :]
        # Before any latent step, the running summary is the last (final-carry) slot.
        summary_hidden = output_hidden[:, -1, :]
        latent_history: list[torch.Tensor] = [summary_hidden]

        for step_index in range(num_steps):
            # Latent k sits directly after the inputs, at position input_length + k.
            latent_position = input_length + step_index
            latent_token = (
                summary_hidden.unsqueeze(1)
                + self.latent_type_embedding.view(1, 1, -1)
                + self.position_embedding.weight[latent_position].view(1, 1, -1)
            )
            latent_embeddings.append(latent_token)
            # The full sequence is re-encoded from scratch on every step.
            hidden_states, attention_weights = self._encode_with_slots(
                base_embeddings, latent_embeddings, output_length, input_ids.device, return_attention
            )
            # Under the causal mask the newest latent sees only the inputs and earlier
            # latents; its hidden state becomes the next summary.
            summary_hidden = hidden_states[:, latent_position, :]
            output_hidden = hidden_states[:, -output_length:, :]
            latent_history.append(summary_hidden)

        digit_logits = self.digit_head(output_hidden[:, :active_digits, :])
        final_carry_logits = self.final_carry_head(output_hidden[:, -1, :])
        return ModelOutput(
            digit_logits=digit_logits,
            final_carry_logits=final_carry_logits,
            output_hidden=output_hidden,
            latent_history=latent_history,
            attention_weights=attention_weights,
        )

    def parameter_count(self) -> int:
        """Total number of elements across all (deduplicated) parameters."""
        return sum(parameter.numel() for parameter in self.parameters())
def build_model(config: ExperimentConfig, device: str | None = None) -> AdditionTransformer:
    """Construct an ``AdditionTransformer`` from ``config``.

    When ``device`` is given, the model is moved there with ``Module.to`` and that
    result is returned; otherwise the freshly built model is returned unmoved.
    """
    model = AdditionTransformer(config)
    return model if device is None else model.to(device)
@torch.no_grad()
def describe_model(config: ExperimentConfig) -> dict[str, int]:
    """Build a fresh model from ``config`` and summarise its parameter counts.

    Parameters are bucketed by substring match on their qualified names. Names
    containing ``"head"`` count toward ``head_params`` and names containing
    ``"embedding"`` count toward ``embedding_params``. The two tests are
    independent, so a name matching both would be counted in both.
    ``backbone_params`` is the total minus the head parameters, so it includes
    the embedding parameters.

    Returns:
        Dict with keys ``total_params``, ``embedding_params``, ``head_params`` and
        ``backbone_params``.
    """
    model = build_model(config)
    bucket_totals = {"head": 0, "embedding": 0}
    for qualified_name, tensor in model.named_parameters():
        for tag in bucket_totals:
            if tag in qualified_name:
                bucket_totals[tag] += tensor.numel()

    total_params = model.parameter_count()
    head_params = bucket_totals["head"]
    return {
        "total_params": int(total_params),
        "embedding_params": int(bucket_totals["embedding"]),
        "head_params": int(head_params),
        "backbone_params": int(total_params - head_params),
    }