vincentoh commited on
Commit
8244b6b
Β·
verified Β·
1 Parent(s): 7ada1c6

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +167 -172
README.md CHANGED
@@ -1,202 +1,197 @@
1
- # CTM Experiments - Continuous Thought Machine Models
2
 
3
- Experimental checkpoints trained on the [Continuous Thought Machine](https://github.com/SakanaAI/continuous-thought-machines) architecture by Sakana AI.
4
 
5
- **These are community experiments on the original work - not official SakanaAI models.**
6
 
7
- ## Paper Reference
8
 
9
- > **Continuous Thought Machines**
10
- >
11
- > Sakana AI
12
- >
13
- > [arXiv:2505.05522](https://arxiv.org/abs/2505.05522)
14
- >
15
- > [Interactive Demo](https://pub.sakana.ai/ctm/) | [Blog Post](https://sakana.ai/ctm/)
16
 
17
- ```bibtex
18
- @article{sakana2025ctm,
19
- title={Continuous Thought Machines},
20
- author={Sakana AI},
21
- journal={arXiv preprint arXiv:2505.05522},
22
- year={2025}
23
- }
24
- ```
25
 
26
- ## Core Insight
27
-
28
- CTM's key innovation: **accuracy improves with more internal iterations**. The model "thinks longer" to reach better answers. This enables CTM to learn algorithmic reasoning that feedforward networks struggle with.
29
-
30
- ## Models
31
-
32
- | Model | File | Size | Task | Accuracy | Description |
33
- |-------|------|------|------|----------|-------------|
34
- | MNIST | `ctm-mnist.pt` | 1.3M | Digit classification | 97.9% | 10-class MNIST |
35
- | Parity-16 | `ctm-parity-16.pt` | 2.5M | Cumulative parity | 99.0% | 16-bit sequences |
36
- | Parity-64 | `ctm-parity-64.pt` | 66M | Cumulative parity | 75% | 64-bit sequences |
37
- | QAMNIST | `ctm-qamnist.pt` | 39M | Multi-step arithmetic | 100% | 3-5 digits, 3-5 ops |
38
- | Brackets | `ctm-brackets.pt` | 6.1M | Bracket matching | 94.7% | Valid/invalid `(()[])` |
39
- | Tracking-Quadrant | `ctm-tracking-quadrant.pt` | 6.7M | Motion quadrant | 100% | 4-class prediction |
40
- | Tracking-Position | `ctm-tracking-position.pt` | 6.7M | Exact position | 93.8% | 256-class (16x16 grid) |
41
- | Transfer | `ctm-transfer-parity-brackets.pt` | 2.5M | Transfer learning | 94.5% | Parity core to brackets |
42
- | Jigsaw MNIST | `ctm-jigsaw-mnist.pt` | 19M | Jigsaw puzzle solving | 92.3% | Reassemble 2x2 shuffled MNIST |
43
- | Rotation MNIST | `ctm-rotation-mnist.pt` | 4.2M | Rotation prediction | 89.1% | Predict rotation angle (4 classes) |
44
-
45
- ## Model Configurations
46
-
47
- ### MNIST CTM
48
- ```python
49
- config = {
50
- "iterations": 15,
51
- "memory_length": 10,
52
- "d_model": 128,
53
- "d_input": 128,
54
- "heads": 2,
55
- "n_synch_out": 16,
56
- "n_synch_action": 16,
57
- "memory_hidden_dims": 8,
58
- "out_dims": 10,
59
- "synapse_depth": 1,
60
- }
61
- ```
62
 
63
- ### Parity-16 CTM
64
- ```python
65
- config = {
66
- "iterations": 50,
67
- "memory_length": 25,
68
- "d_model": 256,
69
- "d_input": 32,
70
- "heads": 8,
71
- "synapse_depth": 8,
72
- "out_dims": 16, # cumulative parity
73
- }
74
- ```
75
 
76
- ### QAMNIST CTM
77
- ```python
78
- config = {
79
- "iterations": 10,
80
- "memory_length": 30,
81
- "d_model": 1024,
82
- "d_input": 64,
83
- "synapse_depth": 1,
84
- "heads": 4,
85
- "n_synch_out": 32,
86
- "n_synch_action": 32,
87
- }
88
- ```
89
 
90
- ### Brackets CTM
91
- ```python
92
- config = {
93
- "iterations": 30,
94
- "memory_length": 15,
95
- "d_model": 256,
96
- "d_input": 64,
97
- "heads": 4,
98
- "n_synch_out": 32,
99
- "n_synch_action": 32,
100
- "out_dims": 2, # valid/invalid
101
- }
102
- ```
103
 
104
- ### Tracking CTM
105
- ```python
106
- config = {
107
- "iterations": 20,
108
- "memory_length": 15,
109
- "d_model": 256,
110
- "d_input": 64,
111
- "heads": 4,
112
- "n_synch_out": 32,
113
- "n_synch_action": 32,
114
- }
115
- ```
116
 
117
- ### Jigsaw MNIST CTM
118
- ```python
119
- config = {
120
- "iterations": 30,
121
- "memory_length": 20,
122
- "d_model": 512,
123
- "d_input": 128,
124
- "heads": 8,
125
- "n_synch_out": 32,
126
- "n_synch_action": 32,
127
- "synapse_depth": 1,
128
- "out_dims": 24, # 4 tiles x 6 permutation options
129
- "backbone_type": "jigsaw",
130
- }
131
- ```
132
 
133
- ### Rotation MNIST CTM
134
- ```python
135
- config = {
136
- "iterations": 20,
137
- "memory_length": 15,
138
- "d_model": 256,
139
- "d_input": 64,
140
- "heads": 4,
141
- "n_synch_out": 32,
142
- "n_synch_action": 32,
143
- "synapse_depth": 1,
144
- "out_dims": 4, # 0Β°, 90Β°, 180Β°, 270Β°
145
- "backbone_type": "rotation",
146
- }
147
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- ## Usage
150
 
151
- ```python
152
- import torch
153
- from huggingface_hub import hf_hub_download
154
 
155
- # Download model
156
- model_path = hf_hub_download(
157
- repo_id="vincentoh/ctm-experiments",
158
- filename="ctm-mnist.pt"
159
- )
160
 
161
- # Load checkpoint
162
- checkpoint = torch.load(model_path, map_location="cpu")
163
 
164
- # Initialize CTM with matching config
165
- from models.ctm import ContinuousThoughtMachine
166
 
167
- model = ContinuousThoughtMachine(**config)
168
- model.load_state_dict(checkpoint['model_state_dict'])
169
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- # Inference
172
- with torch.no_grad():
173
- output = model(input_tensor)
 
 
 
174
  ```
175
 
176
- ## Training Details
 
 
 
 
 
177
 
178
- - **Hardware**: NVIDIA RTX 4070 Ti SUPER
179
- - **Framework**: PyTorch
180
- - **Optimizer**: AdamW
181
- - **Training time**: 5 minutes (MNIST) to 17 hours (QAMNIST)
182
 
183
- ## Key Findings
 
 
 
 
184
 
185
- 1. **Architecture > Scale**: Small sync dimensions (32) with linear synapses work better than large/deep variants
186
- 2. **"Thinking Longer" = Higher Accuracy**: CTM accuracy improves with more internal iterations
187
- 3. **Transfer Learning Works**: Parity-trained core transfers to brackets with 94.5% accuracy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- ## License
190
 
191
- MIT License (same as original CTM repository)
192
 
193
- ## Acknowledgments
194
 
195
- - [Sakana AI](https://sakana.ai/) for the Continuous Thought Machine architecture
196
- - Original [CTM Repository](https://github.com/SakanaAI/continuous-thought-machines)
197
 
198
- ## Links
199
 
200
- - [Experiment Repository](https://github.com/bigsnarfdude/ctm-experiments)
201
- - [Original Paper](https://arxiv.org/abs/2505.05522)
202
- - [Interactive Demo](https://pub.sakana.ai/ctm/)
 
1
+ # CTM Experiments
2
 
3
+ Personal experiments with [Continuous Thought Machines](https://github.com/SakanaAI/continuous-thought-machines) (SakanaAI).
4
 
5
+ **Interactive Demo**: https://pub.sakana.ai/ctm/
6
 
7
+ ## Core Insight: Thinking Takes Time
8
 
9
+ CTM's key innovation: **accuracy improves with more internal iterations**. The model "thinks longer" to reach better answers.
 
 
 
 
 
 
10
 
11
+ This enables CTM to learn algorithmic reasoning that feedforward networks struggle with:
 
 
 
 
 
 
 
12
 
13
+ | Task | Challenge | What CTM Learns |
14
+ |------|-----------|-----------------|
15
+ | **Parity** | Count bits across sequence | Iterative accumulation |
16
+ | **Brackets** | Track nested structure | Stack-like memory (LIFO) |
17
+ | **Object Tracking** | Extrapolate motion | Physics simulation |
18
+ | **Mazes** | Navigate 2D paths | Sequential decision making |
19
+ | **Jigsaw** | Classify shuffled patches | Part-whole integration |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ ## Results Summary
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ | Experiment | Accuracy | Notes |
24
+ |------------|----------|-------|
25
+ | **MNIST** | **97.9%** | Digit classification, 5 min training |
26
+ | **Parity-16** | **99.0%** | 16-bit cumulative parity |
27
+ | **QAMNIST** | **100%** | Multi-step arithmetic (3-5 digits, 3-5 ops) |
28
+ | **Brackets** | **94.7%** | Stack-like reasoning for `(()[])` vs `([)]` |
29
+ | **Object Tracking** | **100%** | Quadrant prediction from motion (4 classes) |
30
+ | **Velocity Prediction** | **100%** | Direction prediction (9 classes) |
31
+ | **Position Prediction** | **93.8%** | Exact position (256 classes, 16x16 grid) |
32
+ | **Transfer Learning** | **94.5%** | Parity→Brackets (core frozen) |
33
+ | **Maze Solving** | **Visualized** | Pretrained model inference on 15x15 mazes |
34
+ | **Jigsaw MNIST** | **92%** | Classify digits from shuffled patches (no positional encoding) |
 
35
 
36
+ ## Key Findings
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ ### 1. Architecture Matters More Than Scale
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ Early experiments showed 50% accuracy on parity (random guessing). The fix wasn't more parameters - it was using the **correct architecture**:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ | Parameter | Wrong | Correct (Official) |
43
+ |-----------|-------|-------------------|
44
+ | `n_synch_out` | 512 | **32** |
45
+ | `n_synch_action` | 512 | **32** |
46
+ | `synapse_depth` | 4 (U-NET) | **1** (linear) |
47
+
48
+ The official parity implementation uses surprisingly small synchronization dimensions with a linear synapse - this is critical for learning.
49
+
50
+ ### 2. "Thinking Longer" = Higher Accuracy
51
+
52
+ ![MNIST Inference per Tick](continuous-thought-machines/experiments/results/mnist_inference.png)
53
+
54
+ CTM accuracy improves with more internal iterations:
55
+ - **Tick 0**: 7% (random)
56
+ - **Tick 10-11**: 100% (peak)
57
+ - **Final tick**: 98%
58
+
59
+ Harder tasks need more "thinking time" - parity peaks at tick 35.
60
+
61
+ ### 3. Transfer Learning Works
62
+
63
+ Pretrained parity model transfers to brackets:
64
+ - **Baseline**: 52.5% (random)
65
+ - **After transfer**: 94.5% (core frozen, only backbone/output trained)
66
+
67
+ The iterative counting learned for parity transfers to stack tracking for brackets - matching from-scratch performance with only 37.7% of parameters trainable.
68
+
69
+ ### 4. Maze Solving "The Hard Way"
70
+
71
+ CTM solves mazes by outputting action trajectories (Up/Down/Left/Right/Wait), not pixel masks:
72
+ - **Step accuracy**: 60%+ after 2000 iterations
73
+ - Uses auto-extending curriculum (loss only on trajectory up to first error)
74
+ - Demonstrates sequential reasoning capability
75
+
76
+ ![Maze Attention Overlay](continuous-thought-machines/experiments/results/maze_attention.gif)
77
+
78
+ *CTM "thinking" through a 15x15 maze: blue = predicted path, red = attention focus, green = start position. The attention heatmap shows where CTM looks at each internal tick (T=75 iterations).*
79
+
80
+ ## Detailed Results
81
+
82
+ ### MNIST Digit Classification (97.9%)
83
+
84
+ ![MNIST Training Accuracy](continuous-thought-machines/experiments/results/mnist-ctm_smoothed.png)
85
 
86
+ CTM learns digit classification in ~5 minutes on RTX 4070 Ti.
87
 
88
+ ### Parity-16 Cumulative Parity (99.0%)
 
 
89
 
90
+ ![Parity Inference per Tick](continuous-thought-machines/experiments/results/parity_inference.png)
 
 
 
 
91
 
92
+ 16-bit parity with cumulative outputs - harder task shows clearer "thinking" benefit.
 
93
 
94
+ ### QAMNIST Multi-Step Arithmetic (100%)
 
95
 
96
+ ![QAMNIST Training Accuracy](continuous-thought-machines/experiments/results/qamnist-ctm-10_smoothed.png)
97
+
98
+ 100% accuracy on multi-step arithmetic (3-5 MNIST digits, 3-5 operations) after 300k iterations.
99
+
100
+ ### Maze Navigation (Pretrained Model)
101
+
102
+ Using the authors' pretrained checkpoint (`ctm_mazeslarge_D=2048_T=75_M=25.pt`), we ran inference on the small-mazes dataset:
103
+
104
+ - **Model**: D=2048 neurons, T=75 thinking steps, M=25 max trajectory length
105
+ - **Dataset**: 1000 test mazes (15x15 grid)
106
+ - **Output**: Action trajectories (Up/Down/Left/Right/Wait)
107
+
108
+ The visualization shows CTM's attention patterns as it navigates:
109
+ 1. **Red heatmap**: Where CTM "looks" at each thinking step
110
+ 2. **Blue path**: Predicted solution trajectory
111
+ 3. **Green marker**: Start position
112
+
113
+ Key insight: CTM learns sequential decision-making through iterative internal computation, not memorization.
114
+
115
+ ### Object Tracking - Position Prediction (93.8%)
116
+
117
+ ![Position Tracking Training](continuous-thought-machines/experiments/results/tracking_position.png)
118
+
119
+ The hardest tracking task: predict exact cell (256 classes) from 5 frames of motion. CTM reaches 93.8% test accuracy, demonstrating temporal reasoning across video frames.
120
+
121
+ ## Experiment Tracking
122
+
123
+ - **Configs**: [`experiments/experiments.json`](continuous-thought-machines/experiments/experiments.json)
124
+ - **Training Scripts**: [`experiments/training/`](continuous-thought-machines/experiments/training/)
125
+ - **Inference Scripts**: [`experiments/inference/`](continuous-thought-machines/experiments/inference/)
126
+ - **Results**: [`experiments/results/`](continuous-thought-machines/experiments/results/)
127
+
128
+ ## Custom Experiments
129
+
130
+ ### Bracket Matching
131
+ Classify bracket strings as valid or invalid: `(()[])` vs `([)]`
132
+
133
+ Requires tracking nested depth and bracket types - implementing a stack through iterative thinking.
134
+
135
+ ### Object Tracking
136
+ Predict properties of a moving dot from 5 video frames (16x16 grid).
137
 
138
+ ```
139
+ Frame 0 Frame 1 Frame 2 Frame 3 Frame 4
140
+ . . . . . . . . . . . . . . . . . . . .
141
+ . * . . . . * . . . . * . . . . . . . .
142
+ . . . . . . . . . . . . . . . * . . . .
143
+ . . . . . . . . . . . . . . . . . . . *
144
  ```
145
 
146
+ Three prediction tasks tested:
147
+ | Task | Classes | Accuracy | Notes |
148
+ |------|---------|----------|-------|
149
+ | **Quadrant** | 4 | 100% | TL/TR/BL/BR - easiest |
150
+ | **Velocity** | 9 | 100% | 8 directions + stationary |
151
+ | **Position** | 256 | 93.8% | Exact cell (16x16) - hardest |
152
 
153
+ All tasks converged, demonstrating CTM's ability to learn temporal/spatial reasoning.
 
 
 
154
 
155
+ ### Transfer Learning
156
+ Freeze core CTM dynamics from parity-16, train only backbone/output for brackets.
157
+
158
+ ### Maze Inference
159
+ Run pretrained maze model on small-mazes dataset to visualize CTM's "thinking" process:
160
 
161
+ ```bash
162
+ python -m tasks.mazes.analysis.run \
163
+ --actions viz \
164
+ --checkpoint checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt \
165
+ --dataset_for_viz small-mazes
166
+ ```
167
+
168
+ Outputs attention overlay GIFs to `tasks/mazes/analysis/outputs/viz/`.
169
+
170
+ ### Jigsaw MNIST
171
+ Classify MNIST digits from **randomly shuffled patches** without positional encoding.
172
+
173
+ ```
174
+ Original: Shuffled (input):
175
+ β”Œβ”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β” β”Œβ”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”
176
+ β”‚ 1 β”‚ 2 β”‚ 3 β”‚ 4 β”‚ β”‚12 β”‚ 7 β”‚ 2 β”‚15 β”‚
177
+ β”œβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”€ β”œβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”€
178
+ β”‚ 5 β”‚ 6 β”‚ 7 β”‚ 8 β”‚ => β”‚ 4 β”‚11 β”‚ 9 β”‚ 1 β”‚
179
+ β”œβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”€ β”œβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”€
180
+ β”‚ 9 β”‚10 β”‚11 β”‚12 β”‚ β”‚ 6 β”‚ 3 β”‚14 β”‚ 5 β”‚
181
+ β”œβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”€ β”œβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”Όβ”€β”€β”€β”€
182
+ β”‚13 β”‚14 β”‚15 β”‚16 β”‚ β”‚16 β”‚ 8 β”‚10 β”‚13 β”‚
183
+ β””β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”˜ β””β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”˜
184
+ ```
185
 
186
+ **Task**: Given 16 shuffled 7x7 patches, predict the digit class (0-9).
187
 
188
+ **Challenge**: No positional encoding - CTM must learn to recognize digit parts and integrate them correctly through its internal synchronization dynamics.
189
 
190
+ **Result**: **92% test accuracy** - CTM successfully learns part-whole relationships without explicit position information.
191
 
192
+ ![Jigsaw Training](continuous-thought-machines/experiments/results/jigsaw_training.png)
 
193
 
194
+ ## Resources
195
 
196
+ - [CTM Paper](2505.05522v4.pdf)
197
+ - [Original SakanaAI Repo](https://github.com/SakanaAI/continuous-thought-machines)