mirajnair committed on
Commit bd8e93b · verified · 1 Parent(s): ef3d128

Update README.md

Files changed (1):
  1. README.md +125 -133
README.md CHANGED
---
{}
---

# Simple LLM Training with GPT-2 Architecture

This repository demonstrates how to train a large language model (LLM) from scratch using the GPT-2 architecture. The model is trained on numerical sequences so that it learns to predict arithmetic patterns.

## 📌 Overview

This project implements a full machine learning pipeline:

- 📊 **Synthetic dataset generation** (number sequences)
- 🔤 **Custom tokenizer training**
- 🧠 **Model training** using GPT-2
- 🤖 **Inference capabilities**

---

## 🚧 Progress So Far

We have trained a **6.4 million parameter** model that:

- Uses **base-16 (hexadecimal)** conversion for tokenization.
- Can **add numbers of up to 4 digits with 100% accuracy**.
- Is publicly available on GitHub: 🔗 [Rajesh-Nair/simple-llm](https://github.com/Rajesh-Nair/simple-llm)

---

## 🏗️ Dataset Generator

Synthetic number sequences are generated based on the parameters defined in `data_config.yaml`.

**Example Configuration:**

- **Number range:** `0 - 9999`
- **Number of sequences:** `100,000`
- **Output path:** `../simple-llm-data/sequences.txt`
- **Delimiters:** `|` (columns), `\n` (rows)

### 🔧 To Generate the Dataset

1. Update `data_config.yaml` with your desired parameters.
2. Run the generator:

```bash
python3 data_generator.py
```
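
For illustration, the sketch below shows what such a generator could look like. The config keys (`min_value`, `max_value`, `num_sequences`, `output_path`) and the three-column row layout are assumptions made for this example, not the actual contents of `data_config.yaml`.

```python
# Illustrative sketch only; the real data_generator.py reads its parameters
# from data_config.yaml and may use a different row layout.
import random

config = {                       # assumed keys, for illustration
    "min_value": 0,
    "max_value": 9999,
    "num_sequences": 100_000,
    "output_path": "sequences.txt",
}

with open(config["output_path"], "w") as f:
    for _ in range(config["num_sequences"]):
        a = random.randint(config["min_value"], config["max_value"])
        b = random.randint(config["min_value"], config["max_value"])
        # columns separated by '|', one example per row
        f.write(f"{a}|{b}|{a + b}\n")
```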

---

## 🎯 Training

### Step 1: Train the Tokenizer

```bash
python3 tokenizer.py
```

### Step 2: Train the Model

```bash
python3 trainer.py
```

Training configuration is managed in `train_config.yaml` (a rough illustration follows the list below), including:

- 🔧 Model architecture (layers, heads, embedding size)
- ⚙️ Training hyperparameters (batch size, learning rate)
- 💾 Checkpointing and saving
- ☁️ Hugging Face Hub integration
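
As a rough illustration of the architecture settings such a config controls, the snippet below builds a small GPT-2 model with Hugging Face `transformers` and counts its parameters. The specific values are placeholders chosen for this example, not the settings in `train_config.yaml`.

```python
# Illustrative only: a small GPT-2 configuration built with Hugging Face
# transformers. The values are placeholders, not the train_config.yaml settings.
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(
    vocab_size=32,    # small vocab for digits, operators and padding (assumed)
    n_positions=64,   # maximum sequence length (assumed)
    n_embd=128,       # embedding size
    n_layer=4,        # number of transformer layers
    n_head=4,         # attention heads per layer
)
model = GPT2LMHeadModel(config)
print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")
```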

---

## 🔢 Position Embeddings

### 📏 Learnable vs. Sinusoidal Embeddings

- **Learnable embeddings**: trained position vectors that can adapt to the numeric patterns in the data.
- **Sinusoidal embeddings**: fixed, mathematically defined encodings of position.
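
For reference, sinusoidal embeddings follow the fixed formulation from the original Transformer paper ("Attention Is All You Need"); a minimal sketch:

```python
# Sinusoidal position embeddings (Vaswani et al., 2017), minimal sketch.
import numpy as np

def sinusoidal_embeddings(max_len: int, d_model: int) -> np.ndarray:
    positions = np.arange(max_len)[:, None]      # (max_len, 1)
    dims = np.arange(0, d_model, 2)[None, :]     # (1, d_model / 2)
    angles = positions / (10000 ** (dims / d_model))
    emb = np.zeros((max_len, d_model))
    emb[:, 0::2] = np.sin(angles)                # even dimensions
    emb[:, 1::2] = np.cos(angles)                # odd dimensions
    return emb

print(sinusoidal_embeddings(max_len=16, d_model=8).shape)  # (16, 8)
```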

---

### 🧮 Block Position IDs (Abacus Embedding)

Inspired by the [Abacus Embedding paper](https://arxiv.org/pdf/2405.17399), we use **block position IDs**: the position counter restarts at every `+` delimiter, so each digit is indexed by its place within its own number.

**Example:**

- Input: `+1342+879+2221+`
- Block IDs: `012340123012340`
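
A minimal sketch of how these IDs can be derived from a token string (assuming one character per token):

```python
# Block position IDs: reset the position counter at every '+' delimiter.
def block_position_ids(tokens: str) -> list[int]:
    ids, pos = [], 0
    for tok in tokens:
        if tok == "+":
            pos = 0          # a delimiter starts a new block
        ids.append(pos)
        pos += 1
    return ids

print("".join(map(str, block_position_ids("+1342+879+2221+"))))
# -> 012340123012340
```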

#### 🔍 Why Block Position IDs?

1. ✅ **Commutative support**: `a + b = b + a`; block IDs give both operands the same per-digit positions, reinforcing this symmetry.
2. 🧠 **Digit alignment**: helps align the units, tens, and hundreds positions across numbers for easier digit-wise processing.

---

### 🔄 Digit Reversal

As part of preprocessing, every number is written least-significant digit first:

- `5672 → 2765` (reversed)
- The model's output is reversed back during evaluation.
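
A minimal sketch of the reversal applied to one training example (the exact prompt formatting lives in the repository's code and may differ from this):

```python
# Digit reversal: write numbers least-significant digit first, so the model
# produces the units digit before any digits that depend on its carry.
def reverse_digits(n: int) -> str:
    return str(n)[::-1]

a, b = 5672, 489
prompt = f"+{reverse_digits(a)}+{reverse_digits(b)}+"   # model input
target = reverse_digits(a + b)                           # expected completion

print(prompt)             # +2765+984+
print(target)             # 1616  (5672 + 489 = 6161, reversed)
print(int(target[::-1]))  # 6161  -> reversed back for evaluation
```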

#### 📈 Benefits of Reversal

1. 🧒 **Human-like computation**: mirrors how humans add starting from the least-significant digit; after reversal, the model generates digits in that same order.
2. 🎯 **Causal attention compatibility**: each digit can attend to the lower-order digits it needs for carry handling.
3. 📚 **Research-backed approach**: digit reversal has been used successfully in several papers, including:
   - [Transformers Can Do Arithmetic with the Right Embeddings](https://arxiv.org/pdf/2405.17399) (which also introduces the Abacus embedding)
   - [Transformers Can Achieve Length Generalization But Not Robustly](https://arxiv.org/pdf/2402.09371)

---

## 🧩 Tokenization Strategy

Tokenization is **critical** for arithmetic modeling. Our approach:

1. 📏 **Shortens sequences**: makes better use of the context window.
2. 🧬 **Boosts generalization**: encourages learning across number patterns.
3. 🔄 **Uses base conversion** (e.g., decimal → hexadecimal) for compact, arithmetic-aware tokens (see the example below).
4. 🧠 **Preserves arithmetic logic**: place value and carrying work the same way in higher bases.

_We're experimenting with different bases to improve efficiency further._
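
For example, converting operands to base 16 before tokenization shortens each number while keeping the arithmetic structure intact; a minimal sketch (the actual conversion in `tokenizer.py` may be implemented differently):

```python
# Base conversion before tokenization: hexadecimal digits make numbers
# shorter while preserving place value and carrying.
def to_hex_digits(n: int) -> str:
    return format(n, "x")    # e.g. 4095 -> "fff"

a, b = 4095, 255
print(to_hex_digits(a), to_hex_digits(b), to_hex_digits(a + b))
# fff ff 10fe   (4095 + 255 = 4350 = 0x10FE)
```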

---

## 🔁 Multi-token Prediction

Predicting **multiple tokens at once** increases efficiency. This is possible because all the numbers are reversed.

### Example: predicting two tokens at a time, the output `99` appears in the first iteration

```
Input (reversed):  +12+873+PPPPPPPP (P = padding tokens)
Output (reversed): PPPPPP99PPPPPPPP (P = padding tokens)
Position IDs:      0120123000000000
```
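
To unpack the example: `+12+873+` encodes the reversed operands of `21 + 378`; the sum is `399`, which is `993` when reversed, so the first two predicted tokens are `99`. A quick check:

```python
# Sanity check for the multi-token prediction example above.
a_rev, b_rev = "12", "873"                  # reversed operands from the prompt
a, b = int(a_rev[::-1]), int(b_rev[::-1])   # 21 and 378
answer_rev = str(a + b)[::-1]               # 399 -> "993"
print(a, b, a + b, answer_rev[:2])          # 21 378 399 99
```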

We currently support **2-token prediction**, and it works well:
🔗 [mirajnair/simple-llm-gpt2-v2.0](https://huggingface.co/mirajnair/simple-llm-gpt2-v2.0)

We are also working on generalizing this method, i.e. emitting each output token at the earliest opportunity so that two or more tokens can be predicted in one pass.
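
The published checkpoint can be pulled from the Hub like any other `transformers` model; a minimal sketch (the custom tokenizer, prompt format, and multi-token decoding follow the repository's own inference scripts, which this does not reproduce):

```python
# Load the published checkpoint from the Hugging Face Hub (sketch only).
# Tokenization and multi-token decoding are handled by the repository's
# own inference code and are not shown here.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mirajnair/simple-llm-gpt2-v2.0")
print(model.config)
```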

## 📊 Attention Visualization

Visualizing attention patterns reveals how the model processes arithmetic operations. Below is an example for the addition problem `101 + 1002 = 1103`, represented in reversed form as `+101+2001+3011+`.

### Layer 1 Attention Patterns

![Layer 1 Attention Visualization](https://github.com/Rajesh-Nair/simple-llm/blob/master/attention_visualizations/layer_1_attention.png)

In this visualization:

- **Bright vertical bars** at positions 1, 5, and 10 show the model focusing on the unit digits of both inputs and the output.
- The model learns to align corresponding digit positions (units with units, tens with tens, and so on).
- The attention patterns show how information flows during the addition, including carry operations.

This supports our block position ID approach: it helps the model exploit the commutative nature of addition and align digits correctly for arithmetic. The model has learned to focus on the digits relevant to each step of the calculation, much as humans do when adding by hand.
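
Attention maps like this can be extracted from a GPT-2 model in `transformers` by requesting attention outputs and then plotting them (e.g. with matplotlib); a minimal sketch using a stand-in checkpoint, since the repository's own visualization script may differ:

```python
# Extract layer-1 attention weights from a GPT-2 model (sketch only).
import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")  # stand-in checkpoint
model.eval()

input_ids = torch.tensor([[0, 1, 2, 3]])         # placeholder token IDs
with torch.no_grad():
    outputs = model(input_ids, output_attentions=True)

layer_1 = outputs.attentions[0]                  # shape: (batch, heads, seq, seq)
print(layer_1.shape)
```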

## 🎯 Performance Results

We have tested the model's addition capability extensively:

### Addition Performance Test

- **Test set**: 10,000 random pairs of 4-digit numbers
- **Accuracy**: 100%
- **Consistency**: perfect accuracy maintained across multiple test runs
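
The evaluation itself is a simple loop over random problems; a hedged sketch is below, where `predict_sum` is a hypothetical stand-in for the repository's actual inference routine (with the placeholder the loop trivially reports 100%; the real figure comes from the model):

```python
# Evaluation sketch: 10,000 random 4-digit addition problems.
import random

def predict_sum(a: int, b: int) -> int:
    # Hypothetical placeholder: replace with the model's inference routine.
    return a + b

correct = 0
for _ in range(10_000):
    a, b = random.randint(1000, 9999), random.randint(1000, 9999)
    correct += int(predict_sum(a, b) == a + b)

print(f"accuracy: {correct / 10_000:.2%}")
```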

This perfect accuracy demonstrates that our approach successfully teaches the model to perform addition with complete reliability, even on previously unseen number combinations. The combination of our specialized tokenization strategy, position encoding, and multi-token prediction enables the model to generalize arithmetic rules effectively.

These results validate our architectural choices and confirm that transformer-based models can master fundamental arithmetic operations when properly designed.

## 🚀 Next Steps

1. **Multi-token generation**:
   - We've shown that the model can output more than one token at a time.
   - Test whether the model can generate all tokens in one go (greedy generation).

2. **Scale up**:
   - Increase the length/number of digits in the operations.
   - Scale up the model size for more complex operations.

3. **Length generalization**:
   - Implement and test length generalization techniques as described in [Transformers Can Achieve Length Generalization But Not Robustly](https://arxiv.org/pdf/2402.09371).
   - Explore methods to improve the model's ability to handle varying sequence lengths.

4. **Batch prediction**:
   - Implement parallel processing of multiple arithmetic operations.
   - Optimize throughput by processing multiple sequences simultaneously.
   - Reduce overall inference time for bulk operations.

5. **KV cache**:
   - Implement key-value caching to avoid redundant computation.
   - Optimize memory usage during autoregressive generation.
   - Speed up sequential token generation by reusing previous computations.