---
{}
---

A 1-layer simple transformer (attention-only, no MLP or LayerNorm) as described in [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html). Load it with:

```python
import torch
import torch.nn as nn
from transformers import PreTrainedModel, LlamaConfig


class OneLayerTransformer(PreTrainedModel):
    config_class = LlamaConfig

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Single self-attention layer
        self.self_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=getattr(config, 'attention_dropout', 0.0),
            batch_first=True,
        )

        # Output head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape

        # Embeddings
        hidden_states = self.embed_tokens(input_ids)
        assert hidden_states.shape == (batch_size, seq_len, self.config.hidden_size)

        # Create causal mask for self-attention
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        causal_mask = causal_mask.to(hidden_states.device)

        # Self-attention with residual connection
        attn_output, _ = self.self_attn(
            hidden_states,
            hidden_states,
            hidden_states,
            attn_mask=causal_mask,
            key_padding_mask=None if attention_mask is None else ~attention_mask.bool(),
        )
        hidden_states = hidden_states + attn_output
        assert hidden_states.shape == (batch_size, seq_len, self.config.hidden_size)

        # Output projection
        logits = self.lm_head(hidden_states)
        assert logits.shape == (batch_size, seq_len, self.config.vocab_size)

        loss = None
        if labels is not None:
            # Shift logits and labels for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return {"loss": loss, "logits": logits}


model = OneLayerTransformer.from_pretrained('Butanium/simple-stories-one-layer-simple-transformer')
```

The model is trained on the SimpleStories dataset.
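
A minimal usage sketch, assuming the repository also ships a compatible tokenizer loadable via `AutoTokenizer` (if it does not, substitute the tokenizer used for SimpleStories training). It runs a forward pass on a short prompt and greedily decodes the next token:

```python
from transformers import AutoTokenizer
import torch

# Assumption: a tokenizer is available under the same repository id.
tokenizer = AutoTokenizer.from_pretrained('Butanium/simple-stories-one-layer-simple-transformer')
model.eval()

prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )

# Greedy next-token prediction from the last position's logits
next_token_id = outputs["logits"][0, -1].argmax().item()
print(tokenizer.decode([next_token_id]))
```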