Butanium committed · verified
Commit 771a34c · 1 Parent(s): 9ad2fe4

Upload README.md with huggingface_hub

Files changed (1):
  README.md +58 -32

README.md CHANGED
@@ -1,37 +1,63 @@
- ---
- {}
- ---
- # One-Layer Simple Transformer
-
- A 1-layer transformer described in [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html).
-
- ## Usage
-
- ```python
- from transformers import LlamaConfig
- from migrate_models import OneLayerTransformer
-
- # Load the model
  model = OneLayerTransformer.from_pretrained('Butanium/simple-stories-one-layer-simple-transformer')
-
- # Or create from config
- config = LlamaConfig(vocab_size=4096, hidden_size=128, num_hidden_layers=1)
- model = OneLayerTransformer(config)
  ```
-
- ## Model Architecture
-
- This model consists of:
- - Token embeddings
- - Single self-attention layer with residual connection
- - Linear output head
-
- It serves as a minimal transformer for understanding attention mechanisms and transformer circuits.
-
- ## Training Details
-
- - Trained on SimpleStories dataset
- - Vocabulary size: 4096
- - Hidden size: 128
- - Single self-attention layer
- - 4 attention heads
 
+ 1-layer simple transformer described in [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html).
+ Load with
+ ```python
+ class OneLayerTransformer(PreTrainedModel):
+     config_class = LlamaConfig
+
+     def __init__(self, config: LlamaConfig):
+         super().__init__(config)
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+
+         # Single self-attention layer
+         self.self_attn = nn.MultiheadAttention(
+             embed_dim=config.hidden_size,
+             num_heads=config.num_attention_heads,
+             dropout=getattr(config, 'attention_dropout', 0.0),
+             batch_first=True,
+         )
+
+         # Output head
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+         batch_size, seq_len = input_ids.shape
+
+         # Embeddings
+         hidden_states = self.embed_tokens(input_ids)
+         assert hidden_states.shape == (batch_size, seq_len, self.config.hidden_size)
+
+         # Create causal mask for self-attention
+         causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
+         causal_mask = causal_mask.to(hidden_states.device)
+
+         # Self-attention with residual connection
+         attn_output, _ = self.self_attn(
+             hidden_states,
+             hidden_states,
+             hidden_states,
+             attn_mask=causal_mask,
+             key_padding_mask=None if attention_mask is None else ~attention_mask.bool(),
+         )
+         hidden_states = hidden_states + attn_output
+         assert hidden_states.shape == (batch_size, seq_len, self.config.hidden_size)
+
+         # Output projection
+         logits = self.lm_head(hidden_states)
+         assert logits.shape == (batch_size, seq_len, self.config.vocab_size)
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(
+                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+             )
+
+         return {"loss": loss, "logits": logits}
+
  model = OneLayerTransformer.from_pretrained('Butanium/simple-stories-one-layer-simple-transformer')
  ```
+ The model is trained on the SimpleStories dataset.
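As committed, the README snippet assumes `torch`, `torch.nn`, `PreTrainedModel`, and `LlamaConfig` are already in scope. A minimal, dependency-light sketch of the same forward pass in plain `torch` (causal `nn.MultiheadAttention`, residual connection, shifted next-token loss), using the hyperparameters from the old README's "Training Details" (vocabulary 4096, hidden size 128, 4 heads) — the toy batch shapes here are illustrative, not from the model card:

```python
import torch
import torch.nn as nn

# Hyperparameters from the old README (vocab 4096, hidden 128, 4 heads).
VOCAB, HIDDEN, HEADS = 4096, 128, 4

embed = nn.Embedding(VOCAB, HIDDEN)
attn = nn.MultiheadAttention(embed_dim=HIDDEN, num_heads=HEADS, batch_first=True)
lm_head = nn.Linear(HIDDEN, VOCAB, bias=False)

input_ids = torch.randint(0, VOCAB, (2, 16))  # (batch, seq), arbitrary toy shapes
batch, seq = input_ids.shape

hidden = embed(input_ids)  # (batch, seq, HIDDEN)

# Boolean causal mask: True marks *disallowed* positions, so each token
# attends only to itself and earlier tokens.
causal = torch.triu(torch.ones(seq, seq), diagonal=1).bool()

attn_out, _ = attn(hidden, hidden, hidden, attn_mask=causal)
hidden = hidden + attn_out  # residual connection

logits = lm_head(hidden)  # (batch, seq, VOCAB)

# Next-token loss: position t predicts the token at position t + 1.
shift_logits = logits[:, :-1, :].reshape(-1, VOCAB)
shift_labels = input_ids[:, 1:].reshape(-1)
loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)
```

Note that `nn.MultiheadAttention` fuses the Q/K/V projections, so this mirrors the committed class's behavior but not necessarily the checkpoint's exact parameter layout; loading the published weights still requires the `OneLayerTransformer.from_pretrained(...)` call shown in the diff.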