File size: 6,015 Bytes
0ca67f7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 | # Architecture Guide
## Overview
AGIFORMER implements a novel hybrid architecture combining byte-level processing, linear attention, and iterative reasoning.
## Pipeline Flow
```
Input Bytes
↓
ByteLatentEncoder (with RoPE)
↓
HybridBlock × N (Linear Attention + Sliding Window)
↓
RecurrentReasoningBlock (System 2 - 3 steps)
↓
LocalAutoregressiveHead (GRU-based decoder)
↓
Output Bytes
```
---
## 1. ByteLatentEncoder
**File:** `src/models/encoder.py`
### Purpose
Converts raw byte sequences into latent patches with positional information.
### Architecture
- **Input:** `(Batch, Seq_Len)` bytes (0-255)
- **Embedding:** `nn.Embedding(256, d_model)`
- **Patching:** Reshape to `(Batch, Num_Patches, Patch_Size, d_model)`
- **RoPE:** Rotary Positional Embeddings for length generalization
- **Projection:** Linear layer to final latent dimension
- **Output:** `(Batch, Num_Patches, d_model)`
### Key Design Decisions
- **Why RoPE?** Enables extrapolation to longer sequences than training
- **Why Patching?** Reduces sequence length by factor of `patch_size` (default: 4)
---
## 2. HybridBlock
**File:** `src/models/layers.py`
### Components
#### 2.1 LinearAttention
**Complexity:** $O(N)$ instead of $O(N^2)$
**Formula:**
```
Q = elu(Wq * x) + 1.0 + ε
K = elu(Wk * x) + 1.0 + ε
V = Wv * x
Attention(Q, K, V) = (Q @ cumsum(K ⊗ V)) / (Q @ cumsum(K) + ε)
```
**Stability Fixes:**
- `elu(x) + 1.0 + 1e-4` ensures strict positivity (prevents division by zero)
- `Q` scaled by `sqrt(head_dim)` to control magnitude
- Layer norm on output
#### 2.2 SlidingWindowAttention
**Complexity:** $O(N × window_size)$
**Implementation:**
```python
scores = (Q @ K.T) / sqrt(d_k)
mask = causal_mask | window_mask # Blocks far tokens
scores = scores.masked_fill(mask, -1e4) # Safe masking
attn = softmax(scores)
out = attn @ V
```
**Why Manual?** PyTorch's `scaled_dot_product_attention` was unstable with custom masks.
### Fusion
```python
- **Residual Connection:** Allows model to skip thinking if not needed
- **Pre-Norm:** Stabilizes deep iteration
### Measured Activity
- **Latent Change:** Δz = 12.7 (Euclidean distance)
- **Gate Bias:** -0.0065 (near neutral)
- **Interpretation:** Model actively refines latents by ~56% per dimension
---
## 4. LocalAutoregressiveHead
**File:** `src/models/agiformer.py`
### Purpose
Decodes latent patches into byte sequences autoregressively.
### Architecture
#### Training Mode
```python
# Teacher forcing
inputs = [SOS, target[0], target[1], ..., target[P-2]]
targets = [target[0], target[1], ..., target[P-1]]
emb = ByteEmb(inputs) # (B*N, P, H)
context = LatentProj(latent).expand() # (B*N, P, H)
rnn_in = concat([emb, context], dim=-1) # (B*N, P, 2H)
out, _ = GRU(rnn_in)
logits = Linear(out) # (B*N, P, 256)
```
#### Inference Mode
```python
current = SOS
hidden = None
for i in range(patch_size):
emb = ByteEmb(current)
rnn_in = concat([emb, latent_context], dim=-1)
out, hidden = GRU(rnn_in, hidden)
logit = Linear(out)
# Sampling
if temperature > 0:
next_byte = multinomial(softmax(logit / temp))
else:
next_byte = argmax(logit)
current = next_byte
```
### Key Design
- **Concatenation (not Addition):** Preserves signal strength
- **GRU State:** Carries info across steps within a patch
- **Temperature Sampling:** Breaks repetition loops
---
## Loss Function
**Training:** Cross-entropy on next-patch prediction
```python
loss = CrossEntropy(logits, targets)
BPC = loss / ln(2) # Bits per character
```
**Metric:** BPC (Bits Per Character) - lower is better
- Random baseline: 8.0 BPC
- Good model: < 1.5 BPC
- AGIFORMER: 2.26 BPC (undertrained but stable)
---
## Hyperparameters
| Parameter | Value | Rationale |
|-----------|-------|-----------|
| `d_model` | 512 | Balance capacity/speed |
| `n_layers` | 6 | Deep enough for complexity |
| `num_heads` | 8 | Standard for 512-D |
| `patch_size` | 4 | 4× compression |
| `window_size` | 128 | Local attention context |
| `thinking_steps` | 3 | System 2 iterations |
| `learning_rate` | 3e-4 | With warmup |
| `batch_size` | 4 | GPU memory limit |
---
## Numerical Stability
### Challenges & Solutions
1. **Linear Attention Division by Zero**
- **Problem:** `elu(x) + 1.0` can = 0 if x very negative
- **Solution:** `elu(x) + 1.0 + 1e-4` (strict positivity)
2. **SDPA Masking Instability**
- **Problem:** NaN in `scaled_dot_product_attention` with bool masks
- **Solution:** Manual attention with `-1e4` instead of `-inf`
3. **System 2 Explosion**
- **Problem:** Iterative updates could amplify errors
- **Solution:** Gated residuals + pre-norm + small init
4. **Gradient Clipping**
- **Value:** 0.5 (aggressive)
- **Reason:** Prevents spikes during early training
---
## Memory & Compute
**Training (Batch=4, Seq=1024):**
- GPU Memory: ~6 GB (T4)
- Time/Step: ~180ms
- Total for 5000 steps: ~15 min
**Inference (Seq=200):**
- Latency: ~50ms (greedy)
- Memory: ~2 GB
**Scaling:**
- Linear Attention: $O(N)$ time
- System 2: $O(k × N)$ where k = thinking_steps
---
## Comparison to Baselines
| Feature | AGIFORMER | GPT-2 | Mamba |
|---------|-----------|-------|-------|
| Tokenization | None (bytes) | BPE | BPE |
| Attention | Linear ($O(N)$) | Quadratic | N/A |
| Recurrence | System 2 Loop | None | SSM |
| BPC (enwik8) | 2.26 | ~1.1 | ~1.0 |
| Training Time | 15 min | Hours | Hours |
**Note:** BPC gap due to undertrained model, not architecture limit.
---
## Future Improvements
1. **Longer Training:** Target BPC < 1.5
2. **More Thinking Steps:** 3 → 5-7 for harder tasks
3. **Sparse Experts:** Route different "thinking modes"
4. **Memory Module:** External differentiable memory
5. **Multi-Modal:** Extend to images/audio bytes
---
## References
- **Linear Transformers:** Katharopoulos et al., 2020
- **RoPE:** Su et al., 2021
- **System 2 Deep Learning:** Bengio et al., 2019
- **Mamba:** Gu & Dao, 2023
|