Commit `7f14743` (verified, parent `6fcf7e1`) by vincentoh: Upload README.md with huggingface_hub

# CTM Experiments - Continuous Thought Machine Models

Experimental checkpoints trained on the [Continuous Thought Machine](https://github.com/SakanaAI/continuous-thought-machines) architecture by Sakana AI.

**These are community experiments built on the original work, not official Sakana AI models.**

## Paper Reference

> **Continuous Thought Machines**
>
> Sakana AI
>
> [arXiv:2505.05522](https://arxiv.org/abs/2505.05522)
>
> [Interactive Demo](https://pub.sakana.ai/ctm/) | [Blog Post](https://sakana.ai/ctm/)

```bibtex
@article{sakana2025ctm,
  title={Continuous Thought Machines},
  author={Sakana AI},
  journal={arXiv preprint arXiv:2505.05522},
  year={2025}
}
```

## Core Insight

CTM's key innovation: **accuracy improves with more internal iterations**. The model "thinks longer" to reach a better answer, which lets it learn algorithmic reasoning tasks that feedforward networks struggle with.

## Models

| Model | File | Size | Task | Accuracy | Description |
|-------|------|------|------|----------|-------------|
| MNIST | `ctm-mnist.pt` | 1.3M | Digit classification | 97.9% | 10-class MNIST |
| Parity-16 | `ctm-parity-16.pt` | 2.5M | Cumulative parity | 99.0% | 16-bit sequences |
| Parity-64 | `ctm-parity-64.pt` | 66M | Cumulative parity | 75% | 64-bit sequences |
| QAMNIST | `ctm-qamnist.pt` | 39M | Multi-step arithmetic | 100% | 3-5 digits, 3-5 ops |
| Brackets | `ctm-brackets.pt` | 6.1M | Bracket matching | 94.7% | Valid/invalid `(()[])` |
| Tracking-Quadrant | `ctm-tracking-quadrant.pt` | 6.7M | Motion quadrant | 100% | 4-class prediction |
| Tracking-Position | `ctm-tracking-position.pt` | 6.7M | Exact position | 93.8% | 256-class (16x16 grid) |
| Transfer | `ctm-transfer-parity-brackets.pt` | 2.5M | Transfer learning | 94.5% | Parity core to brackets |

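For context, the cumulative parity task labels each position of a bit sequence with the parity of the prefix up to that point — one output per position, which is why the Parity-16 config below uses `out_dims: 16`. A plain-Python reference for the target computation (independent of the checkpoints; the even→0 / odd→1 encoding is the usual convention and assumed here):

```python
def cumulative_parity(bits):
    """For each prefix of `bits`, return 0 if the count of 1s so far is even, else 1."""
    targets = []
    parity = 0
    for b in bits:
        parity ^= b  # XOR accumulates the running parity
        targets.append(parity)
    return targets

# A 16-bit sequence and its per-position parity targets
bits = [1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1]
print(cumulative_parity(bits))
# → [1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1]
```
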
## Model Configurations

### MNIST CTM
```python
config = {
    "iterations": 15,
    "memory_length": 10,
    "d_model": 128,
    "d_input": 128,
    "heads": 2,
    "n_synch_out": 16,
    "n_synch_action": 16,
    "memory_hidden_dims": 8,
    "out_dims": 10,
    "synapse_depth": 1,
}
```

### Parity-16 CTM
```python
config = {
    "iterations": 50,
    "memory_length": 25,
    "d_model": 256,
    "d_input": 32,
    "heads": 8,
    "synapse_depth": 8,
    "out_dims": 16,  # cumulative parity
}
```

### QAMNIST CTM
```python
config = {
    "iterations": 10,
    "memory_length": 30,
    "d_model": 1024,
    "d_input": 64,
    "synapse_depth": 1,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```

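For intuition, a QAMNIST question chains MNIST digits with arithmetic operators, evaluated left to right. A rough illustrative evaluator (plain Python, separate from the model; the `+`/`-` operator set and the modular reduction that keeps answers in a fixed class range are assumptions about the task, not read from the checkpoints):

```python
def eval_qamnist(digits, ops, modulus=10):
    """Evaluate e.g. 3 + 7 - 2 + 5 left to right, reduced mod `modulus` (assumed)."""
    acc = digits[0]
    for op, d in zip(ops, digits[1:]):
        acc = acc + d if op == "+" else acc - d
    return acc % modulus

print(eval_qamnist([3, 7, 2, 5], ["+", "-", "+"]))  # 3 + 7 - 2 + 5 = 13 → 3 (mod 10)
```
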
### Brackets CTM
```python
config = {
    "iterations": 30,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "out_dims": 2,  # valid/invalid
}
```

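The valid/invalid labels correspond to standard balanced-bracket matching over `(`, `)`, `[`, `]`, which a stack decides in linear time. A plain-Python reference for the ground-truth labels (independent of the model):

```python
def is_valid(seq):
    """Return True iff brackets in `seq` are balanced and properly nested."""
    pairs = {")": "(", "]": "["}
    stack = []
    for ch in seq:
        if ch in "([":
            stack.append(ch)
        elif ch in pairs:
            # A closer must match the most recent unmatched opener
            if not stack or stack.pop() != pairs[ch]:
                return False
    return not stack  # leftover openers mean invalid

print(is_valid("(()[])"))  # True
print(is_valid("([)]"))    # False
```
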
### Tracking CTM
```python
config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```

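The two tracking heads differ only in label granularity: the quadrant model predicts which quarter of the arena the target occupies (4 classes), while the position model predicts the exact cell of a 16x16 grid (256 classes). A sketch of one plausible label encoding (the row-major cell indexing and quadrant ordering here are my assumptions, not read from the checkpoints):

```python
GRID = 16  # 16x16 arena → 256 position classes

def position_class(x, y, grid=GRID):
    """Flatten (x, y) grid coordinates into one of grid*grid classes (row-major; assumed)."""
    return y * grid + x

def quadrant_class(x, y, grid=GRID):
    """Map (x, y) to 4 classes: 0=top-left, 1=top-right, 2=bottom-left, 3=bottom-right (assumed order)."""
    half = grid // 2
    return int(y >= half) * 2 + int(x >= half)

print(position_class(3, 5))   # 83
print(quadrant_class(3, 5))   # 0
print(quadrant_class(12, 9))  # 3
```
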
## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Download a checkpoint from the Hub
model_path = hf_hub_download(
    repo_id="vincentoh/ctm-experiments",
    filename="ctm-mnist.pt",
)

# Load the checkpoint on CPU
checkpoint = torch.load(model_path, map_location="cpu")

# Initialize CTM with the matching config from "Model Configurations" above
# (requires the SakanaAI/continuous-thought-machines repo on your PYTHONPATH)
from models.ctm import ContinuousThoughtMachine

model = ContinuousThoughtMachine(**config)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Inference
with torch.no_grad():
    output = model(input_tensor)
```

## Training Details

- **Hardware**: NVIDIA RTX 4070 Ti SUPER
- **Framework**: PyTorch
- **Optimizer**: AdamW
- **Training time**: 5 minutes (MNIST) to 17 hours (QAMNIST)

## Key Findings

1. **Architecture > Scale**: Small synchronization dimensions (32) with linear synapses outperform larger/deeper variants
2. **"Thinking Longer" = Higher Accuracy**: CTM accuracy improves with more internal iterations
3. **Transfer Learning Works**: A parity-trained core transfers to brackets at 94.5% accuracy

## License

MIT License (same as the original CTM repository)

## Acknowledgments

- [Sakana AI](https://sakana.ai/) for the Continuous Thought Machine architecture
- Original [CTM Repository](https://github.com/SakanaAI/continuous-thought-machines)

## Links

- [Experiment Repository](https://github.com/bigsnarfdude/ctm-experiments)
- [Original Paper](https://arxiv.org/abs/2505.05522)
- [Interactive Demo](https://pub.sakana.ai/ctm/)