# Simple LLM Training with GPT-2 Architecture

This repository demonstrates how to train a Large Language Model (LLM) from scratch using the GPT-2 architecture. The model is trained on numerical sequences to learn and predict patterns.

## 📌 Overview

This project implements a full machine learning pipeline:

- 📊 **Synthetic dataset generation** (number sequences)
- 🔤 **Custom tokenizer training**
- 🧠 **Model training** using GPT-2
- 🤖 **Inference capabilities**

---

## 🚧 Progress So Far

We have trained a **6.4 million parameter** model that:
- Uses **base-16 (hexadecimal)** conversion for tokenization.
- Can **add numbers of up to 4 digits with 100% accuracy**.
- Is publicly available on GitHub: 🔗 [Rajesh-Nair/simple-llm](https://github.com/Rajesh-Nair/simple-llm)

---

## 🏗️ Dataset Generator

Synthetic number sequences are generated based on parameters defined in `data_config.yaml`.

**Example Configuration:**
- **Number range:** `0 - 9999`
- **Number of sequences:** `100,000`
- **Output path:** `../simple-llm-data/sequences.txt`
- **Delimiters:** `|` (columns), `\n` (rows)

### 🔧 To Generate the Dataset:
1. Update `data_config.yaml` with your desired parameters.
2. Run the generator:
   ```bash
   python3 data_generator.py
   ```
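
The generator's output shape can be sketched as follows. This is a minimal, hypothetical reconstruction based only on the configuration described above (number range, `|` column delimiter, `\n` row delimiter); the actual contents of each row in `data_generator.py` may differ, and the `a|b|sum` layout is an assumption.

```python
import random

def generate_rows(num_rows, low=0, high=9999, col_delim="|", seed=None):
    # Hypothetical sketch of the generator: each row holds two random
    # operands and their sum, with '|' between columns and '\n' between
    # rows. The exact column contents are an assumption, not the repo's code.
    rng = random.Random(seed)
    rows = []
    for _ in range(num_rows):
        a = rng.randint(low, high)
        b = rng.randint(low, high)
        rows.append(col_delim.join(str(x) for x in (a, b, a + b)))
    return "\n".join(rows)
```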

---

## 🎯 Training

### Step 1: Train the Tokenizer
```bash
python3 tokenizer.py
```

### Step 2: Train the Model
```bash
python3 trainer.py
```

Training configurations are managed in `train_config.yaml`, including:

- 🔧 Model architecture (layers, heads, embedding size)
- ⚙️ Training hyperparameters (batch size, learning rate)
- 💾 Checkpointing and saving
- ☁️ Hugging Face Hub integration

---

## 🔢 Position Embeddings

### 📐 Learnable vs. Sinusoidal Embeddings

- **Learnable Embeddings**: Adapt to numeric patterns.
- **Sinusoidal Embeddings**: Provide a mathematical basis for position understanding.
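
For reference, the standard sinusoidal encoding can be sketched as below (the classic sin/cos formulation from "Attention Is All You Need"; this is an illustrative sketch, not necessarily the variant selected in `train_config.yaml`):

```python
import math

def sinusoidal_position(pos, dim):
    # Classic sin/cos position encoding: even dimensions get sine,
    # odd dimensions cosine, with geometrically increasing wavelengths.
    return [
        math.sin(pos / 10000 ** (i / dim)) if i % 2 == 0
        else math.cos(pos / 10000 ** ((i - 1) / dim))
        for i in range(dim)
    ]
```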

---

### 🧮 Block Position IDs (Abacus Embedding)

Inspired by the [Abacus Embedding paper](https://arxiv.org/pdf/2405.17399), we use **block position IDs**.

**Example:**

- Input:     `+1342+879+2221+`
- Block IDs: `012340123012340`
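
The mapping above can be computed as follows (a minimal sketch; the repo's own implementation may differ). The counter restarts at every `+` separator, so digits with the same place value get the same ID across all operands:

```python
def block_position_ids(sequence, sep="+"):
    # Separator tokens get ID 0; each digit gets its 1-based offset
    # within its own number block.
    ids, pos = [], 0
    for ch in sequence:
        pos = 0 if ch == sep else pos + 1
        ids.append(pos)
    return ids
```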

#### 🔍 Why Block Position IDs?

1. ✅ **Commutative Support**: block position IDs reinforce that `a + b = b + a`.
2. 🧠 **Digit Alignment**: aligns units, tens, and hundreds positions across operands for easier digit-wise processing.

---


### 🔄 Digit Reversal

As part of preprocessing:
- `5672 → 2765` (reversed)
- Output is reversed back during evaluation.
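
A minimal sketch of this preprocessing, assuming training samples take the `+a+b+sum+` layout used elsewhere in this README with every number digit-reversed (the helper names are hypothetical, not the repo's actual functions):

```python
def encode_addition(a, b):
    # Assumed sample layout: '+' delimiters around digit-reversed
    # operands and result, e.g. 101 + 1002 -> "+101+2001+3011+".
    rev = lambda n: str(n)[::-1]
    return "+" + rev(a) + "+" + rev(b) + "+" + rev(a + b) + "+"

def decode_result(encoded):
    # Reverse the final field back when evaluating the model's output.
    return int(encoded.strip("+").split("+")[-1][::-1])
```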

#### 📈 Benefits of Reversal

1. 🧒 **Human-like learning**: Mimics how humans add, starting from the least-significant digit.
2. 🎯 **Causal attention compatibility**: Enables better carryover handling under left-to-right generation.
3. 📚 **Research-backed approach**: Digit reversal has been used successfully in several papers, including:
   - [Transformers Can Do Arithmetic with the Right Embeddings](https://arxiv.org/pdf/2405.17399) (which also introduces the Abacus embedding)
   - [Transformers Can Achieve Length Generalization But Not Robustly](https://arxiv.org/pdf/2402.09371)

---

## 🧩 Tokenization Strategy

Tokenization is **critical** for arithmetic modeling. Our approach:

1. 📏 **Shortens sequences**: Optimizes context-window usage.
2. 🧬 **Boosts generalization**: Learns across number patterns.
3. 🔄 **Uses base conversion** (e.g., decimal → hexadecimal) for compact, arithmetic-aware tokens.
4. 🧠 **Preserves arithmetic logic**: Even in higher bases, the same arithmetic rules apply.
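
The base conversion in step 3 can be sketched with repeated division (a minimal illustration of the idea, not the repo's tokenizer code):

```python
def to_base(n, base=16, digits="0123456789abcdef"):
    # Repeated division converts a decimal number into a base-16
    # string, so large numbers need fewer tokens per sequence.
    if n == 0:
        return digits[0]
    out = []
    while n:
        n, r = divmod(n, base)
        out.append(digits[r])
    return "".join(reversed(out))
```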

_We're experimenting with different bases to improve efficiency further._

---

## 🔁 Multi-token Prediction

Predicting **multiple tokens at once** increases inference efficiency. This is possible because all the numbers are digit-reversed: the least-significant digits of the result are determined first.

### Example: predicting two tokens at a time, the output `99` appears in the first iteration

```
Input (reversed):     +12+873+PPPPPPPP      (P = padding tokens)
Output (reversed):    PPPPPP99PPPPPPPP      (P = padding tokens)
Position IDs:         0120123000000000
```

We currently support **2-token prediction**, and it works well:
🔗 [mirajnair/simple-llm-gpt2-v2.0](https://huggingface.co/mirajnair/simple-llm-gpt2-v2.0)

We are also working to generalize this method, i.e. emitting each output token at the earliest opportunity so that two or more tokens can be predicted in one go.



## 📊 Attention Visualization

Visualizing attention patterns reveals how the model processes arithmetic operations. Below is an example showing attention patterns for the addition problem: `101 + 1002 = 1103` (represented in reversed form as `+101+2001+3011+`).

### Layer 1 Attention Patterns

![Layer 1 Attention Visualization](https://github.com/Rajesh-Nair/simple-llm/blob/master/attention_visualizations/layer_1_attention.png)

In this visualization:
- **Bright vertical bars** at positions 1, 5, and 10 show how the model focuses on unit digits from both inputs and the output
- The model learns to align corresponding digit positions (units with units, tens with tens, etc.)
- Attention patterns reveal how information flows during the addition process, including carry operations

This confirms our block position ID approach helps the model understand the commutative nature of addition and properly align digits for arithmetic operations.

The visualization demonstrates how the model has learned to focus on relevant digits when performing calculations, similar to how humans process arithmetic problems.

## 🎯 Performance Results

We've rigorously tested our model's arithmetic capabilities with impressive results:

### Addition Performance Test
- **Test Set**: 10,000 random pairs of 4-digit numbers
- **Accuracy**: 100%
- **Consistency**: Maintained perfect accuracy across multiple test runs

This perfect accuracy demonstrates that our approach successfully teaches the model to perform addition operations with complete reliability, even on previously unseen number combinations. The combination of our specialized tokenization strategy, position encoding, and multi-token prediction enables the model to generalize arithmetic rules effectively.

These results validate our architectural choices and confirm that transformer-based models can master fundamental arithmetic operations when properly designed.

## 🚀 Next Steps

1. **Multi-token Generation**: 
   - We've shown that the model can output more than one token at a time 
   - Test whether the model can generate all tokens in one go (greedy generation)

2. **Scale Up**:
   - Increase the length/number of digits in operations
   - Scale up model size for more complex operations

3. **Length Generalization**:
   - Implement and test length generalization techniques as described in [Transformers Can Achieve Length Generalization But Not Robustly](https://arxiv.org/pdf/2402.09371)
   - Explore methods to improve the model's ability to handle varying sequence lengths

4. **Add batch prediction**:
   - Implement parallel processing of multiple arithmetic operations
   - Optimize throughput by processing multiple sequences simultaneously
   - Reduce overall inference time for bulk operations

5. **KV cache**:
   - Implement key-value caching to reduce redundant computations
   - Optimize memory usage during autoregressive generation
   - Speed up sequential token generation by reusing previous computations