# Architecture Guide

## Overview

AGIFORMER implements a novel hybrid architecture combining byte-level processing, linear attention, and iterative reasoning.

## Pipeline Flow

```
Input Bytes
    ↓
ByteLatentEncoder (with RoPE)
    ↓
HybridBlock × N (Linear Attention + Sliding Window)
    ↓
RecurrentReasoningBlock (System 2, 3 steps)
    ↓
LocalAutoregressiveHead (GRU-based decoder)
    ↓
Output Bytes
```

---

## 1. ByteLatentEncoder

**File:** `src/models/encoder.py`

### Purpose
Converts raw byte sequences into latent patches with positional information.

### Architecture
- **Input:** `(Batch, Seq_Len)` bytes (0-255)
- **Embedding:** `nn.Embedding(256, d_model)`
- **Patching:** Reshape to `(Batch, Num_Patches, Patch_Size, d_model)`
- **RoPE:** Rotary Positional Embeddings for length generalization
- **Projection:** Linear layer to the final latent dimension
- **Output:** `(Batch, Num_Patches, d_model)`

### Key Design Decisions
- **Why RoPE?** Enables extrapolation to sequences longer than those seen during training
- **Why Patching?** Reduces sequence length by a factor of `patch_size` (default: 4)
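As a concrete illustration, the patching step can be sketched in a few lines (embedding and RoPE omitted; `patchify` is a hypothetical helper name for this guide, not code from `src/models/encoder.py`):

```python
import numpy as np

def patchify(x, patch_size=4):
    """(Batch, Seq_Len, d_model) -> (Batch, Num_Patches, Patch_Size, d_model)."""
    b, n, d = x.shape
    assert n % patch_size == 0, "Seq_Len must be a multiple of patch_size"
    return x.reshape(b, n // patch_size, patch_size, d)

x = np.zeros((2, 1024, 512))   # embedded bytes: (Batch, Seq_Len, d_model)
patches = patchify(x)          # (2, 256, 4, 512): 4x fewer sequence positions
```

After patching, each group of `patch_size` byte embeddings is projected down to a single latent vector, which is what gives the 4× compression noted above.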
---

## 2. HybridBlock

**File:** `src/models/layers.py`

### Components

#### 2.1 LinearAttention
**Complexity:** $O(N)$ instead of $O(N^2)$

**Formula:**
```
Q = elu(Wq * x) + 1.0 + ε
K = elu(Wk * x) + 1.0 + ε
V = Wv * x

Attention(Q, K, V) = (Q @ cumsum(K ⊗ V)) / (Q @ cumsum(K) + ε)
```

**Stability Fixes:**
- `elu(x) + 1.0 + 1e-4` ensures strict positivity (prevents division by zero)
- `Q` scaled by `sqrt(head_dim)` to control magnitude
- Layer norm on the output
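The cumulative-sum formulation can be sketched for a single head in NumPy (a minimal illustration under the formula above, not the code in `src/models/layers.py`):

```python
import numpy as np

def elu(x):
    # ELU: x for x > 0, exp(x) - 1 otherwise
    return np.where(x > 0, x, np.exp(np.minimum(x, 0.0)) - 1.0)

def linear_attention(Q, K, V, eps=1e-4):
    """Causal linear attention for one head in O(N) time.

    Q, K: (N, d) queries/keys before the feature map; V: (N, d_v) values.
    The elu + 1 + eps feature map keeps the denominator strictly positive.
    """
    phi_q = elu(Q) + 1.0 + eps
    phi_k = elu(K) + 1.0 + eps
    # running prefix sums of K ⊗ V and of K (the cumsum terms in the formula)
    S = np.cumsum(phi_k[:, :, None] * V[:, None, :], axis=0)   # (N, d, d_v)
    z = np.cumsum(phi_k, axis=0)                               # (N, d)
    num = np.einsum('nd,ndv->nv', phi_q, S)
    den = np.einsum('nd,nd->n', phi_q, z)[:, None] + eps
    return num / den
```

A quick sanity check of causality: at position 0 only the first key is visible, so the output there is (up to ε) exactly `V[0]`.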
#### 2.2 SlidingWindowAttention
**Complexity:** $O(N \cdot w)$ for window size $w$

**Implementation:**
```python
scores = (Q @ K.T) / sqrt(d_k)
mask = causal_mask | window_mask            # blocks tokens outside the window
scores = scores.masked_fill(mask, -1e4)     # safe masking (avoids -inf NaNs)
attn = softmax(scores)
out = attn @ V
```

**Why Manual?** PyTorch's `scaled_dot_product_attention` was unstable with custom masks.
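One way to build the combined mask is sketched below (a hypothetical NumPy helper; it assumes the convention that position `i` may attend to itself and the previous `window - 1` positions):

```python
import numpy as np

def sliding_window_mask(n, window):
    """Boolean mask where True = blocked: future tokens, and tokens more
    than window - 1 positions in the past (row i sees i-window+1 .. i)."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j > i) | (j <= i - window)
```

Blocked positions are then filled with `-1e4` before the softmax, as in the snippet above.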
83
+ ### Fusion
84
+ ```python
85
+ x = residual + out_proj(attn_out + ssm_out)
86
+ ```
87
+ Parallel branches (not sequential) for efficiency.
88
+
89
+ ---
## 3. RecurrentReasoningBlock (System 2)

**File:** `src/models/reasoning.py`

### Algorithm
```python
z = z_0  # initial latent from the backbone

for t in range(thinking_steps):
    norm_z = LayerNorm(z)
    update = MLP(norm_z)               # candidate thought
    gate = sigmoid(W_gate @ norm_z)    # how much of the update to accept
    z = z + gate * update              # gated residual

return z  # refined latent
```
### Design Philosophy
- **Gated Update:** Prevents exploding or vanishing updates (as in an LSTM)
- **Residual Connection:** Lets the model skip thinking when it is not needed
- **Pre-Norm:** Stabilizes deep iteration

### Measured Activity
- **Latent Change:** Δz = 12.7 (Euclidean distance)
- **Gate Bias:** -0.0065 (near neutral)
- **Interpretation:** The model actively refines its latents; with `d_model` = 512, Δz = 12.7 corresponds to an average change of ~0.56 per dimension
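A minimal NumPy sketch of the loop, with small random matrices standing in for the trained MLP and gate weights (all names here are illustrative, not taken from `src/models/reasoning.py`):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def reasoning_step(z, W_mlp, W_gate):
    """One gated refinement: z <- z + sigmoid(W_gate @ norm_z) * MLP(norm_z)."""
    norm_z = layer_norm(z)
    update = np.tanh(W_mlp @ norm_z)      # stand-in for the candidate MLP
    gate = sigmoid(W_gate @ norm_z)
    return z + gate * update

rng = np.random.default_rng(0)
d = 8
z = rng.standard_normal(d)
W_mlp = 0.1 * rng.standard_normal((d, d))    # small init keeps updates tame
W_gate = 0.1 * rng.standard_normal((d, d))
for _ in range(3):                           # thinking_steps = 3
    z = reasoning_step(z, W_mlp, W_gate)
```

Because each step is a gated residual, a gate near zero reduces the block to an identity map, which is what allows the model to "skip thinking".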
---

## 4. LocalAutoregressiveHead

**File:** `src/models/agiformer.py`

### Purpose
Decodes latent patches back into byte sequences autoregressively.

### Architecture

#### Training Mode
```python
# Teacher forcing
inputs = [SOS, target[0], target[1], ..., target[P-2]]
targets = [target[0], target[1], ..., target[P-1]]

emb = ByteEmb(inputs)                     # (B*N, P, H)
context = LatentProj(latent).expand()     # (B*N, P, H)
rnn_in = concat([emb, context], dim=-1)   # (B*N, P, 2H)

out, _ = GRU(rnn_in)
logits = Linear(out)                      # (B*N, P, 256)
```
#### Inference Mode
```python
current = SOS
hidden = None

for i in range(patch_size):
    emb = ByteEmb(current)
    rnn_in = concat([emb, latent_context], dim=-1)
    out, hidden = GRU(rnn_in, hidden)
    logit = Linear(out)

    # Sampling
    if temperature > 0:
        next_byte = multinomial(softmax(logit / temperature))
    else:
        next_byte = argmax(logit)

    current = next_byte
```
### Key Design
- **Concatenation (not Addition):** Preserves the strength of both the byte and latent signals
- **GRU State:** Carries information across decoding steps within a patch
- **Temperature Sampling:** Breaks repetition loops
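The sampling branch can be illustrated with a self-contained helper (a sketch for this guide, not the project's code; `temperature == 0` falls back to greedy decoding as in the pseudocode above):

```python
import numpy as np

def sample_byte(logits, temperature, rng=None):
    """Pick the next byte from 256 logits; temperature == 0 means greedy."""
    if temperature == 0:
        return int(np.argmax(logits))
    if rng is None:
        rng = np.random.default_rng()
    scaled = logits / temperature
    scaled = scaled - scaled.max()                # for numerical stability
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return int(rng.choice(len(logits), p=probs))
```

Higher temperatures flatten the distribution, which is what breaks the deterministic repetition loops that greedy decoding can fall into.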
---

## Loss Function

**Training:** Cross-entropy on next-patch prediction
```python
loss = CrossEntropy(logits, targets)   # measured in nats
BPC = loss / ln(2)                     # convert nats to bits per character
```
**Metric:** BPC (Bits Per Character), lower is better
- Random baseline: 8.0 BPC (uniform over 256 byte values)
- Good model: < 1.5 BPC
- AGIFORMER: 2.26 BPC (undertrained but stable)
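The conversion is a one-liner; for example, a uniform distribution over 256 bytes has cross-entropy ln(256) nats, which works out to exactly the 8.0 BPC random baseline quoted above:

```python
import math

def bpc(ce_loss_nats):
    """Cross-entropy loss in nats -> bits per character."""
    return ce_loss_nats / math.log(2)

# Uniform over 256 bytes: CE = ln(256) nats -> 8.0 BPC (the random baseline)
assert abs(bpc(math.log(256)) - 8.0) < 1e-12
```

Going the other way, the reported 2.26 BPC corresponds to a raw training loss of about 1.57 nats.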
---

## Hyperparameters

| Parameter | Value | Rationale |
|-----------|-------|-----------|
| `d_model` | 512 | Balance of capacity and speed |
| `n_layers` | 6 | Deep enough for the task |
| `num_heads` | 8 | Standard for 512-d models |
| `patch_size` | 4 | 4× sequence compression |
| `window_size` | 128 | Local attention context |
| `thinking_steps` | 3 | System 2 iterations |
| `learning_rate` | 3e-4 | With warmup |
| `batch_size` | 4 | GPU memory limit |
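For reference, the table maps directly onto a small config object (an illustrative sketch; the project may organize its configuration differently):

```python
from dataclasses import dataclass

@dataclass
class AGIFormerConfig:
    # values from the hyperparameter table above
    d_model: int = 512
    n_layers: int = 6
    num_heads: int = 8
    patch_size: int = 4
    window_size: int = 128
    thinking_steps: int = 3
    learning_rate: float = 3e-4
    batch_size: int = 4

cfg = AGIFormerConfig()
head_dim = cfg.d_model // cfg.num_heads   # 64 dims per attention head
```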
---

## Numerical Stability

### Challenges & Solutions

1. **Linear Attention Division by Zero**
   - **Problem:** `elu(x) + 1.0` underflows to 0 in floating point when x is very negative
   - **Solution:** `elu(x) + 1.0 + 1e-4` (strict positivity)

2. **SDPA Masking Instability**
   - **Problem:** NaNs in `scaled_dot_product_attention` with boolean masks
   - **Solution:** Manual attention with `-1e4` instead of `-inf`

3. **System 2 Explosion**
   - **Problem:** Iterative updates could amplify errors
   - **Solution:** Gated residuals + pre-norm + small initialization

4. **Gradient Clipping**
   - **Value:** 0.5 (aggressive)
   - **Reason:** Prevents loss spikes during early training
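Point 1 is easy to reproduce: although `elu(x) + 1.0 = exp(x)` is mathematically positive for negative x, it underflows to exactly zero in floating point, and the extra epsilon restores strict positivity:

```python
import math

def elu(x):
    # scalar ELU: x for x > 0, exp(x) - 1 otherwise
    return x if x > 0 else math.exp(x) - 1.0

x = -800.0
# elu(-800) + 1.0 == exp(-800), which underflows to exactly 0.0,
# so the attention denominator could hit a division by zero...
print(elu(x) + 1.0)         # 0.0
# ...while the epsilon keeps the feature map strictly positive.
print(elu(x) + 1.0 + 1e-4)  # 0.0001
```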
---

## Memory & Compute

**Training (batch = 4, seq = 1024):**
- GPU memory: ~6 GB (T4)
- Time per step: ~180 ms
- Total for 5000 steps: ~15 min

**Inference (seq = 200):**
- Latency: ~50 ms (greedy)
- Memory: ~2 GB

**Scaling:**
- Linear attention: $O(N)$ time
- System 2: $O(k \cdot N)$, where $k$ = `thinking_steps`

---
## Comparison to Baselines

| Feature | AGIFORMER | GPT-2 | Mamba |
|---------|-----------|-------|-------|
| Tokenization | None (bytes) | BPE | BPE |
| Attention | Linear ($O(N)$) | Quadratic | N/A (SSM) |
| Recurrence | System 2 loop | None | SSM |
| BPC (enwik8) | 2.26 | ~1.1 | ~1.0 |
| Training time | 15 min | Hours | Hours |

**Note:** The BPC gap reflects an undertrained model, not an architectural limit.

---
## Future Improvements

1. **Longer Training:** Target BPC < 1.5
2. **More Thinking Steps:** 3 → 5-7 for harder tasks
3. **Sparse Experts:** Route between different "thinking modes"
4. **Memory Module:** External differentiable memory
5. **Multi-Modal:** Extend to image/audio bytes

---
## References

- **Linear Transformers:** Katharopoulos et al., 2020
- **RoPE:** Su et al., 2021
- **System 2 Deep Learning:** Bengio et al., 2019
- **Mamba:** Gu & Dao, 2023