File size: 7,873 Bytes
76de008 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 | from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import nn
from addition.config import ExperimentConfig
@dataclass
class ModelOutput:
    """Bundle of tensors returned by ``AdditionTransformer.forward``."""

    # Digit-head logits for the leading output slots (every slot except the last).
    digit_logits: torch.Tensor
    # Two-way logits from the final-carry head, read from the last output slot.
    final_carry_logits: torch.Tensor
    # Layer-normalised hidden states of the output slots from the last block pass.
    output_hidden: torch.Tensor
    # Summary vectors gathered across passes: the initial one plus one per latent step.
    latent_history: list[torch.Tensor]
    # Per-head attention weights from the last block pass, or None when not requested.
    attention_weights: torch.Tensor | None
class TransformerBlock(nn.Module):
def __init__(self, d_model: int, n_heads: int, ff_dim: int, dropout: float) -> None:
super().__init__()
self.ln_1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
self.ln_2 = nn.LayerNorm(d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model, ff_dim),
nn.GELU(),
nn.Linear(ff_dim, d_model),
nn.Dropout(dropout),
)
def forward(self, hidden_states: torch.Tensor, need_weights: bool = False) -> tuple[torch.Tensor, torch.Tensor | None]:
seq_len = hidden_states.shape[1]
causal_mask = torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool).triu(1)
normed = self.ln_1(hidden_states)
attn_output, attn_weights = self.attn(
normed,
normed,
normed,
need_weights=need_weights,
average_attn_weights=False,
attn_mask=causal_mask,
)
hidden_states = hidden_states + self.dropout(attn_output)
hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
return hidden_states, attn_weights if need_weights else None
class AdditionTransformer(nn.Module):
def __init__(self, config: ExperimentConfig) -> None:
super().__init__()
self.config = config
self.token_embedding = nn.Embedding(config.discrete_vocab_size, config.d_model)
self.position_embedding = nn.Embedding(config.max_sequence_length, config.d_model)
self.latent_type_embedding = nn.Parameter(torch.zeros(config.d_model))
self.output_slot_embeddings = nn.Parameter(torch.zeros(config.output_sequence_length, config.d_model))
self.block = TransformerBlock(
d_model=config.d_model,
n_heads=config.n_heads,
ff_dim=config.ff_dim,
dropout=config.dropout,
)
self.final_ln = nn.LayerNorm(config.d_model)
self.digit_head = nn.Linear(config.d_model, config.digit_vocab_size)
self.final_carry_head = nn.Linear(config.d_model, 2)
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
nn.init.normal_(self.latent_type_embedding, mean=0.0, std=0.02)
nn.init.normal_(self.output_slot_embeddings, mean=0.0, std=0.02)
nn.init.xavier_uniform_(self.digit_head.weight)
nn.init.zeros_(self.digit_head.bias)
nn.init.xavier_uniform_(self.final_carry_head.weight)
nn.init.zeros_(self.final_carry_head.bias)
def embed_discrete_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
seq_len = input_ids.shape[1]
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
return self.token_embedding(input_ids) + self.position_embedding(positions)
def embed_output_slots(
self,
batch_size: int,
output_length: int,
latent_count: int,
input_length: int,
device: torch.device,
) -> torch.Tensor:
positions = torch.arange(output_length, device=device) + input_length + latent_count
positioned = self.output_slot_embeddings[:output_length] + self.position_embedding(positions)
return positioned.unsqueeze(0).expand(batch_size, -1, -1)
def _run_block(
self,
embeddings: torch.Tensor,
*,
need_attention: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
hidden_states, attention_weights = self.block(embeddings, need_weights=need_attention)
hidden_states = self.final_ln(hidden_states)
return hidden_states, attention_weights
def forward(
self,
input_ids: torch.Tensor,
*,
latent_steps: int = 0,
return_attention: bool = False,
) -> ModelOutput:
base_embeddings = self.embed_discrete_tokens(input_ids)
latent_history: list[torch.Tensor] = []
attention_weights: torch.Tensor | None = None
batch_size = input_ids.shape[0]
input_length = input_ids.shape[1]
active_digits = max(1, (input_length - 2) // 2)
output_length = active_digits + 1
output_embeddings = self.embed_output_slots(
batch_size=batch_size,
output_length=output_length,
latent_count=0,
input_length=input_length,
device=input_ids.device,
)
hidden_states, attention_weights = self._run_block(
torch.cat([base_embeddings, output_embeddings], dim=1),
need_attention=return_attention,
)
output_hidden = hidden_states[:, -output_length:, :]
summary_hidden = output_hidden[:, -1, :]
latent_history.append(summary_hidden)
latent_embeddings: list[torch.Tensor] = []
for step_index in range(int(latent_steps)):
latent_token = summary_hidden.unsqueeze(1) + self.latent_type_embedding.view(1, 1, -1)
latent_position_index = input_length + step_index
latent_token = latent_token + self.position_embedding.weight[latent_position_index].view(1, 1, -1)
latent_embeddings.append(latent_token)
output_embeddings = self.embed_output_slots(
batch_size=batch_size,
output_length=output_length,
latent_count=len(latent_embeddings),
input_length=input_length,
device=input_ids.device,
)
hidden_states, attention_weights = self._run_block(
torch.cat([base_embeddings] + latent_embeddings + [output_embeddings], dim=1),
need_attention=return_attention,
)
latent_index = input_length + step_index
summary_hidden = hidden_states[:, latent_index, :]
output_hidden = hidden_states[:, -output_length:, :]
latent_history.append(summary_hidden)
digit_logits = self.digit_head(output_hidden[:, :active_digits, :])
final_carry_logits = self.final_carry_head(output_hidden[:, -1, :])
return ModelOutput(
digit_logits=digit_logits,
final_carry_logits=final_carry_logits,
output_hidden=output_hidden,
latent_history=latent_history,
attention_weights=attention_weights,
)
def parameter_count(self) -> int:
return sum(parameter.numel() for parameter in self.parameters())
def build_model(config: ExperimentConfig, device: str | None = None) -> AdditionTransformer:
    """Instantiate an ``AdditionTransformer`` and optionally place it on ``device``.

    Args:
        config: Configuration forwarded to the model constructor.
        device: Target device; when ``None`` the model is returned where it
            was created.

    Returns:
        The newly constructed model.
    """
    network = AdditionTransformer(config)
    if device is None:
        return network
    return network.to(device)
@torch.no_grad()
def describe_model(config: ExperimentConfig) -> dict[str, int]:
    """Build a throwaway model from ``config`` and summarise its parameter counts.

    Parameters are bucketed by substring match on their names: ``"head"`` for
    head parameters and ``"embedding"`` for embedding parameters.

    Returns:
        A dict with ``total_params``, ``embedding_params``, ``head_params``
        and ``backbone_params``.  ``backbone_params`` is the total minus the
        head parameters only, so it still includes the embedding parameters.
    """
    model = build_model(config)
    total = model.parameter_count()

    head_total = 0
    embedding_total = 0
    for name, parameter in model.named_parameters():
        size = parameter.numel()
        if "head" in name:
            head_total += size
        if "embedding" in name:
            embedding_total += size

    return {
        "total_params": int(total),
        "embedding_params": int(embedding_total),
        "head_params": int(head_total),
        "backbone_params": int(total - head_total),
    }
|