Elliotasdasdasfasas committed on
Commit
ed89628
·
1 Parent(s): 9bb3382

Deploy CTM Codebase bypass FUSE 503

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .dockerignore +5 -0
  2. Dockerfile +34 -0
  3. INSTRUCCIONES_DESPLIEGUE.md +34 -0
  4. LICENSE +201 -0
  5. README.md +22 -7
  6. app.py +945 -0
  7. app_v1_backup.py +464 -0
  8. data/custom_datasets.py +324 -0
  9. examples/01_mnist.ipynb +0 -0
  10. examples/02_inference.ipynb +0 -0
  11. examples/03_mazes.ipynb +0 -0
  12. examples/04_parity.ipynb +0 -0
  13. examples/05_huggingface.ipynb +0 -0
  14. models/README.md +7 -0
  15. models/constants.py +10 -0
  16. models/ctm.py +633 -0
  17. models/ctm_qamnist.py +208 -0
  18. models/ctm_rl.py +192 -0
  19. models/ctm_sort.py +126 -0
  20. models/ff.py +75 -0
  21. models/lstm.py +244 -0
  22. models/lstm_qamnist.py +184 -0
  23. models/lstm_rl.py +96 -0
  24. models/modules.py +692 -0
  25. models/resnet.py +374 -0
  26. models/utils.py +122 -0
  27. mount_azure.sh +44 -0
  28. requirements.txt +21 -0
  29. requirements_v1.txt +2 -0
  30. setup_hf_space.sh +37 -0
  31. tasks/image_classification/README.md +31 -0
  32. tasks/image_classification/analysis/README.md +7 -0
  33. tasks/image_classification/analysis/run_imagenet_analysis.py +972 -0
  34. tasks/image_classification/imagenet_classes.py +1007 -0
  35. tasks/image_classification/plotting.py +494 -0
  36. tasks/image_classification/scripts/train_cifar10.sh +286 -0
  37. tasks/image_classification/scripts/train_imagenet.sh +38 -0
  38. tasks/image_classification/train.py +690 -0
  39. tasks/image_classification/train_distributed.py +799 -0
  40. tasks/mazes/README.md +16 -0
  41. tasks/mazes/analysis/README.md +10 -0
  42. tasks/mazes/analysis/run.py +407 -0
  43. tasks/mazes/plotting.py +214 -0
  44. tasks/mazes/scripts/train_ctm.sh +35 -0
  45. tasks/mazes/train.py +704 -0
  46. tasks/mazes/train_distributed.py +782 -0
  47. tasks/parity/README.md +16 -0
  48. tasks/parity/analysis/make_blog_gifs.py +263 -0
  49. tasks/parity/analysis/run.py +269 -0
  50. tasks/parity/plotting.py +897 -0
.dockerignore ADDED
@@ -0,0 +1,5 @@
+ checkpoints/
+ data/
+ .git/
+ __pycache__/
+ *.ipynb
Dockerfile ADDED
@@ -0,0 +1,34 @@
+ FROM python:3.10-slim
+
+ WORKDIR /code
+
+ # 1. Install System Dependencies (SSHFS + Curl)
+ RUN apt-get update && apt-get install -y \
+     sshfs \
+     curl \
+     fuse \
+     && rm -rf /var/lib/apt/lists/*
+
+ # 2. Install Cloudflared
+ RUN curl -L --output cloudflared.deb https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb && \
+     dpkg -i cloudflared.deb && \
+     rm cloudflared.deb
+
+ COPY ./requirements.txt /code/requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # Create mount point
+ RUN mkdir -p /data/persistent && chmod 777 /data/persistent
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+ RUN chmod +x mount_azure.sh
+
+ # Port 7860 is required on Hugging Face Spaces
+ CMD ["python", "app.py"]
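The last comment in the Dockerfile says the app must listen on port 7860 on Hugging Face Spaces. The following is a minimal sketch of how a Python entry point could resolve its bind address, assuming `app.py` starts a Gradio server (its launch call is not visible in this part of the diff). `resolve_bind`, `DEFAULT_HOST`, and `DEFAULT_PORT` are hypothetical names; `GRADIO_SERVER_NAME` and `GRADIO_SERVER_PORT` are environment variables that Gradio itself reads.

```python
import os
from typing import Mapping, Tuple

DEFAULT_HOST = "0.0.0.0"  # listen on all interfaces so the container is reachable
DEFAULT_PORT = 7860       # the port the Dockerfile comment says Spaces requires


def resolve_bind(env: Mapping[str, str] = os.environ) -> Tuple[str, int]:
    """Hypothetical helper: pick host and port from the environment, with defaults."""
    host = env.get("GRADIO_SERVER_NAME", DEFAULT_HOST)
    port = int(env.get("GRADIO_SERVER_PORT", DEFAULT_PORT))
    if not 1 <= port <= 65535:
        raise ValueError(f"port out of range: {port}")
    return host, port


if __name__ == "__main__":
    host, port = resolve_bind()
    print(f"would bind to {host}:{port}")
```

With Gradio, the result would typically be passed on as `demo.launch(server_name=host, server_port=port)`.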
INSTRUCCIONES_DESPLIEGUE.md ADDED
@@ -0,0 +1,34 @@
+ # 🚀 Deployment Instructions: Continuous Thought Machines (CTM)
+
+ Due to permission restrictions in the current terminal, the final deployment requires you to run the "bridge" I have built from your WSL environment (Ubuntu/Debian).
+
+ ## 1. Prerequisites (Already configured)
+ * **Code**: Cloned into `c:\Users\elliot\Downloads\simulacion\ctm_sakana`
+ * **Docker**: `Dockerfile` and `.dockerignore` files created.
+ * **Dependencies**: `requirements.txt` patched (opencv-headless).
+ * **Connection script**: `setup_hf_space.sh` created with your token.
+
+ ## 2. Execution Steps (In your WSL)
+
+ Open your WSL terminal and run the following block of commands:
+
+ ```bash
+ # 1. Go to the project folder (WSL mounts C: at /mnt/c)
+ cd /mnt/c/Users/elliot/Downloads/simulacion/ctm_sakana
+
+ # 2. Make the script executable
+ chmod +x setup_hf_space.sh
+
+ # 3. Run the automated script
+ ./setup_hf_space.sh
+ ```
+
+ ### 3. During Execution
+ The script will ask you for the **Space name**.
+ * Based on the screenshot, the user appears to be `Alex Herbert Vilca Puente` or `ROBOT-GANSTA`.
+ * Enter the Space name in `USUARIO/NOMBRE_SPACE` format if the script asks for it, or only the name if the script already has the user hardcoded (the script has `Elliotasdasdasfasas`; make sure it matches your real Space, or edit the script).
+
+ ## 4. Verification
+ Once uploaded, go to your Space on Hugging Face. You will see its status change to **"Building"**. This means Docker is installing the dependencies we defined.
+
+ > **Note**: The build process will take a few minutes the first time.
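Step 3 of the instructions accepts the Space name either as `USUARIO/NOMBRE_SPACE` (user/space name) or as a bare name when the user is already hardcoded. `setup_hf_space.sh` is not shown in this part of the diff, so the following is only a sketch of that convention in Python. `normalize_space_id` is a hypothetical helper, and the allowed character set is an assumption rather than a documented rule.

```python
import re

# Assumption: user and Space names use letters, digits, '-', '_' and '.' only.
_NAME = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")


def normalize_space_id(answer: str, default_user: str) -> str:
    """Hypothetical helper: return 'USER/SPACE' from either accepted input form."""
    answer = answer.strip()
    if "/" in answer:
        # Full form: everything before the first slash is the user.
        user, _, space = answer.partition("/")
    else:
        # Bare form: fall back to the hardcoded user.
        user, space = default_user, answer
    if not (_NAME.match(user) and _NAME.match(space)):
        raise ValueError(f"invalid Space id: {answer!r}")
    return f"{user}/{space}"


if __name__ == "__main__":
    print(normalize_space_id("ctm", "Elliotasdasdasfasas"))
```

Normalizing before the upload makes a mismatch between the hardcoded user and the real Space owner visible early, which is the failure the instructions warn about.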
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2020 Rémi Louf
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,11 +1,26 @@
  ---
- title: Gansta
- emoji: 🦀
- colorFrom: green
- colorTo: indigo
- sdk: docker
+ title: CTM Nervous System
+ emoji: 🧬
+ colorFrom: purple
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 5.9.1
+ app_file: app.py
  pinned: false
- license: gemma
  ---
  
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # 🧬 CTM Nervous System
+
+ **Continuous Thought Machine for Hypergraph Maintenance**
+
+ Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
+
+ ## Endpoints
+
+ | Endpoint | Function |
+ |----------|----------|
+ | `/sense_snn` | Process 72D SNN input |
+ | `/reason_hypergraph` | Reason about context, propose edges |
+ | `/validate_physics` | Validate against 5 physics losses |
+ | `/dream` | Offline consolidation (T=500+) |
+ | `/calibrate_stdp` | Suggest STDP weight adjustments |
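The README table says `/sense_snn` processes a 72D SNN input. The request schema is not documented here, so the sketch below only covers a client-side check a caller might run before sending a vector. `check_snn_input` and `SNN_INPUT_DIM` are hypothetical names, not part of the committed code.

```python
import math
from typing import Iterable, List

SNN_INPUT_DIM = 72  # from the README: `/sense_snn` processes a 72D SNN input


def check_snn_input(values: Iterable[float]) -> List[float]:
    """Hypothetical client-side check: return the vector as floats or raise ValueError."""
    vector = [float(v) for v in values]
    if len(vector) != SNN_INPUT_DIM:
        raise ValueError(f"expected {SNN_INPUT_DIM} values, got {len(vector)}")
    if not all(math.isfinite(v) for v in vector):
        raise ValueError("input contains NaN or infinite values")
    return vector


if __name__ == "__main__":
    vector = check_snn_input([0.0] * SNN_INPUT_DIM)
    print(f"vector of length {len(vector)} accepted")
```

Rejecting wrong-length or non-finite vectors on the client side keeps malformed input from reaching the model in the first place.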
app.py ADDED
@@ -0,0 +1,945 @@
+ """
+ CTM Nervous System Server v2.0 - Full PyTorch Implementation
+ =============================================================
+ Continuous Thought Machine for ART-17 Hypergraph Coherence Generation
+
+ PURPOSE (from skills):
+ 1. REGULATION: Calibrate the STDP weights of the 16 dendrites
+ 2. COHERENCE: Generate deterministic hypergraphs
+ 3. REASONING: Active inference engine (internal ticks)
+ 4. SYNCHRONIZATION: Representation via Neural Synchronization
+
+ TRAINING STRATEGY:
+ - Progressive online learning with use
+ - Integrates with Brain server (Qwen + VL-JEPA) for semantic grounding
+ - Automatic checkpoint saving
+
+ Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
+ Adapted for: ART-17 Dendrite Regulation & Hypergraph Generation
+ """
+
+ import gradio as gr
+ import numpy as np
+ import json
+ import os
+ from typing import List, Dict, Any, Optional
+ from datetime import datetime
+ from utils.bunker_client import BunkerClient
+
+
+ # ============================================================================
+ # PYTORCH IMPORTS WITH FALLBACK
+ # ============================================================================
+ try:
+     import torch
+     import torch.nn as nn
+     import torch.nn.functional as F
+     TORCH_AVAILABLE = True
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"🔧 PyTorch available. Device: {DEVICE}")
+ except ImportError:
+     TORCH_AVAILABLE = False
+     DEVICE = "cpu"
+     print("⚠️ PyTorch not available. Using simplified NumPy fallback.")
+
+ # ============================================================================
+ # FULL CTM IMPORT (with fallback to simplified)
+ # ============================================================================
+ if TORCH_AVAILABLE:
+     try:
+         from models.ctm import ContinuousThoughtMachine
+         from models.modules import SynapseUNET, SuperLinear
+         from utils.losses import image_classification_loss
+         CTM_FULL = True
+         print("✅ Full CTM model loaded from models/ctm.py")
+     except ImportError as e:
+         CTM_FULL = False
+         print(f"⚠️ Could not import full CTM: {e}. Using simplified.")
+ else:
+     CTM_FULL = False
+
+ # ============================================================================
+ # CONFIGURATION FOR ART-17 INTEGRATION (v3.0)
+ # ============================================================================
+ CONFIG = {
+     # CTM Architecture (matching ART-17)
+     "iterations": 50,  # T internal ticks (max)
+     "d_model": 256,  # Latent dimension
+     "d_input": 72,  # Input from SNN (72D)
+     "memory_length": 16,  # History length (16 dendrites)
+     "n_synch_out": 32,  # Output sync neurons
+     "n_synch_action": 16,  # Action sync neurons
+     "out_dims": 16,  # Output: 16 dendrite adjustments
+
+     # v3.0 Improvements
+     "adaptive_halting": True,  # Enable early stopping
+     "certainty_threshold": 0.85,  # Halt if certainty > threshold
+     "sync_decay_alpha": 0.9,  # S_new = α*S_old + (1-α)*S_current
+     "use_backbone": True,  # Use Backbone72D transformation
+
+     # Training
+     "learning_rate": 1e-4,
+     "weight_decay": 1e-5,
+     "checkpoint_dir": "checkpoints",
+     "auto_save_every": 100,  # Save every N forward passes
+
+     # Integration
+     "brain_server_url": "https://elliotasdasdasfasas-brain.hf.space",
+
+     # Physics validation
+     "physics_thresholds": {
+         "P_max": 1000.0,
+         "v_max": 100.0,
+         "T_dew": 15.0,
+         "T_amb": 25.0
+     }
+ }
+
+ # ============================================================================
+ # BACKBONE 72D (v3.0 - Transform input before CTM)
+ # ============================================================================
+ class Backbone72D(nn.Module if TORCH_AVAILABLE else object):
+     """
+     Transform 72D SNN input to d_model dimensions.
+     Paper insight: Raw input needs proper embedding for CTM to work well.
+     """
+     def __init__(self, d_input=72, d_model=256):
+         if not TORCH_AVAILABLE:
+             return
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(d_input, 128),
+             nn.LayerNorm(128),
+             nn.GELU(),
+             nn.Linear(128, d_model),
+             nn.LayerNorm(d_model)
+         )
+
+     def forward(self, x):
+         # x: [B, 72]
+         return self.net(x)  # [B, 256]
+
+ # ============================================================================
+ # FULL CTM WRAPPER FOR ART-17
+ # ============================================================================
+ class CTM_ART17:
+     """
+     Full Continuous Thought Machine adapted for ART-17.
+
+     Key mechanisms from paper:
+     1. NLMs (Neuron-Level Models) - Each neuron processes its own history
+     2. Neural Synchronization - Representation is S = Z·Z^T
+     3. Adaptive Compute - Can halt early when confident
+
+     Purpose in ART-17:
+     - Regulate 16 dendrite STDP weights
+     - Generate coherent hypergraph edges
+     - Serve as "nervous system" for the whole system
+     """
+
+     def __init__(self, config: dict):
+         self.config = config
+         self.forward_count = 0
+         self.training_samples = []
+         self.bunker = BunkerClient(buffer_dir=config.get("buffer_dir", "_ctm_buffer"))
+
+         if CTM_FULL and TORCH_AVAILABLE:
+             self._init_full_ctm()
+         else:
+             self._init_simplified_ctm()
+
+     def _init_full_ctm(self):
+         """Initialize full PyTorch CTM model."""
+         self.model = ContinuousThoughtMachine(
+             iterations=self.config["iterations"],
+             d_model=self.config["d_model"],
+             d_input=self.config["d_input"],
+             heads=4,
+             n_synch_out=self.config["n_synch_out"],
+             n_synch_action=self.config["n_synch_action"],
+             synapse_depth=2,
+             memory_length=self.config["memory_length"],
+             deep_nlms=True,
+             memory_hidden_dims=32,
+             do_layernorm_nlm=False,
+             backbone_type='none',
+             positional_embedding_type='none',
+             out_dims=self.config["out_dims"],
+             prediction_reshaper=[self.config["out_dims"]],
+             dropout=0.1,
+             neuron_select_type='random-pairing'
+         ).to(DEVICE)
+
+         # Dummy forward to initialize lazy modules
+         with torch.no_grad():
+             dummy = torch.randn(1, self.config["d_input"], device=DEVICE)
+             dummy = dummy.unsqueeze(-1).unsqueeze(-1)  # [1, 72, 1, 1]
+             try:
+                 _ = self.model(dummy)
+             except Exception as e:
+                 print(f"⚠️ Lazy init failed: {e}")
+
+         self.model.eval()
+         self.optimizer = torch.optim.AdamW(
+             self.model.parameters(),
+             lr=self.config["learning_rate"],
+             weight_decay=self.config["weight_decay"]
+         )
+
+         self.is_full = True
+         param_count = sum(p.numel() for p in self.model.parameters())
+         print(f"✅ Full CTM initialized: {param_count:,} parameters")
+
+         # Try to load existing checkpoint
+         self._load_checkpoint()
+
+     def _init_simplified_ctm(self):
+         """Initialize simplified NumPy CTM (fallback)."""
+         self.d_model = self.config["d_model"]
+         self.memory_length = self.config["memory_length"]
+         self.n_ticks = self.config["iterations"]
+
+         # State traces
+         self.state_trace = np.zeros((self.d_model, self.memory_length))
+         self.activated_state = np.random.randn(self.d_model) * 0.1
+
+         # NLM weights (simplified: 16 groups for 16 dendrites)
+         self.nlm_weights = np.random.randn(16, self.memory_length) * 0.1
+
+         self.is_full = False
+         print("✅ Simplified CTM initialized (NumPy fallback)")
+
+     def forward(self, input_72d: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
+         """
+         Process input through CTM.
+
+         Args:
+             input_72d: 72D input from SNN
+             n_ticks: Override number of internal ticks
+
+         Returns:
+             Dict with predictions, certainty, sync matrix
+         """
+         n_ticks = n_ticks or self.config["iterations"]
+         self.forward_count += 1
+
+         if self.is_full:
+             return self._forward_full(input_72d, n_ticks)
+         else:
+             return self._forward_simplified(input_72d, n_ticks)
+
+     def _forward_full(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
+         """Forward pass with full PyTorch CTM."""
+         # Prepare tensor
+         x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
+         if len(x.shape) == 1:
+             x = x.unsqueeze(0)  # Add batch dim
+         x = x.unsqueeze(-1).unsqueeze(-1)  # [B, 72, 1, 1]
+
+         with torch.no_grad():
+             predictions, certainties, sync_out = self.model(x)
+
+         # Extract results
+         final_pred = predictions[:, :, -1].cpu().numpy()[0]  # Last tick [16]
+         final_cert = certainties[:, 1, -1].cpu().numpy()[0]  # 1-entropy
+
+         # Find tick with highest certainty
+         best_tick_idx = certainties[:, 1, :].argmax(dim=-1)[0].item()
+         best_pred = predictions[:, :, best_tick_idx].cpu().numpy()[0]
+
+         # Sync matrix for hypergraph edge proposals
+         sync_matrix = sync_out.cpu().numpy()[0] if sync_out is not None else None
+
+         return {
+             "predictions": final_pred.tolist(),
+             "best_predictions": best_pred.tolist(),
+             "certainty": float(final_cert),
+             "best_tick": int(best_tick_idx),
+             "ticks_used": n_ticks,
+             "sync_matrix": sync_matrix.tolist() if sync_matrix is not None else None,
+             "model": "ContinuousThoughtMachine (Full PyTorch)"
+         }
+
263
+ def _forward_simplified(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
264
+ """
265
+ Forward pass with simplified NumPy CTM (v3.0).
266
+
267
+ v3.0 Features:
268
+ 1. Backbone transformation (72D -> 256D)
269
+ 2. Sync Decay (S = α*S_prev + (1-α)*S_current)
270
+ 3. Adaptive Halting (stop if certainty > threshold)
271
+ """
272
+ # v3.0: Backbone transformation (simple linear projection)
273
+ if self.config.get("use_backbone", True):
274
+ # Learned transformation: 72D -> 256D
275
+ input_256 = np.zeros(self.d_model)
276
+ # Simple linear projection + normalization (simulates Backbone72D)
277
+ projected = np.tanh(input_72d[:72] * np.random.randn(72) * 0.1) if len(input_72d) >= 72 else input_72d
278
+ input_256[:min(len(projected), self.d_model)] = projected[:min(len(projected), self.d_model)]
279
+ else:
280
+ input_256 = np.zeros(self.d_model)
281
+ input_256[:min(len(input_72d), self.d_model)] = input_72d[:self.d_model]
282
+
283
+ # v3.0: Sync Decay initialization
284
+ alpha = self.config.get("sync_decay_alpha", 0.9)
285
+ sync_matrix_prev = np.zeros((self.d_model, self.d_model))
286
+
287
+ # v3.0: Adaptive halting config
288
+ adaptive_halting = self.config.get("adaptive_halting", True)
289
+ certainty_threshold = self.config.get("certainty_threshold", 0.85)
290
+
291
+ certainties = []
292
+ all_predictions = []
293
+ ticks_actually_used = 0
294
+
295
+ for t in range(n_ticks):
296
+ ticks_actually_used = t + 1
297
+
298
+ # Synapse update (simplified global mixing)
299
+ combined = np.concatenate([self.activated_state, input_256[:self.d_model//2]])
300
+ pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
301
+
302
+ # Update trace (memory)
303
+ self.state_trace = np.roll(self.state_trace, -1, axis=1)
304
+ self.state_trace[:, -1] = pre_activation
305
+
306
+ # NLM processing (simplified: 16 groups for 16 dendrites)
307
+ post_activation = np.zeros(self.d_model)
308
+ group_size = self.d_model // 16
309
+ for g in range(16):
310
+ start = g * group_size
311
+ end = start + group_size
312
+ group_trace = self.state_trace[start:end, :]
313
+ group_output = np.mean(group_trace @ self.nlm_weights[g])
314
+ post_activation[start:end] = np.tanh(group_output)
315
+
316
+ self.activated_state = post_activation
317
+
318
+ # v3.0: Sync Decay - S = α*S_prev + (1-α)*Z·Z^T
319
+ z_norm = self.activated_state / (np.linalg.norm(self.activated_state) + 1e-8)
320
+ sync_current = np.outer(z_norm, z_norm)
321
+ sync_matrix = alpha * sync_matrix_prev + (1 - alpha) * sync_current
322
+ sync_matrix_prev = sync_matrix
323
+
324
+ # Store predictions at this tick
325
+ all_predictions.append(self.activated_state[:16].copy())
326
+
327
+ # Compute certainty
328
+ probs = np.abs(self.activated_state) / (np.sum(np.abs(self.activated_state)) + 1e-8)
329
+ probs = np.clip(probs, 1e-10, 1.0)
330
+ entropy = -np.sum(probs * np.log(probs))
331
+ max_entropy = np.log(len(probs))
332
+ certainty = float(1.0 - entropy / (max_entropy + 1e-8))
333
+ certainties.append(certainty)
334
+
335
+ # v3.0: Adaptive Halting - stop early if confident enough
336
+ if adaptive_halting and certainty > certainty_threshold:
337
+ break
338
+
339
+ # Best tick selection
340
+ best_tick_idx = int(np.argmax(certainties))
341
+ best_predictions = all_predictions[best_tick_idx].tolist()
342
+
343
+ return {
344
+ "predictions": self.activated_state[:16].tolist(),
345
+ "best_predictions": best_predictions,
346
+ "certainty": certainties[-1],
347
+ "best_tick": best_tick_idx,
348
+ "ticks_used": ticks_actually_used, # v3.0: Actual ticks, may be < n_ticks
349
+ "max_ticks": n_ticks,
350
+ "halted_early": ticks_actually_used < n_ticks, # v3.0: Flag
351
+ "sync_matrix": sync_matrix[:16, :16].tolist(),
352
+ "model": "SimplifiedCTM v3.0 (NumPy + AdaptiveHalt + SyncDecay)"
353
+ }
354
+
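The certainty and adaptive-halting logic in the loop above can be sketched in isolation. This is a minimal pure-Python sketch, not the code path the server runs: the helper names `certainty` and `ticks_until_halt` are hypothetical, and it assumes the same definition used above (1 minus the normalized entropy of the absolute activations, halting on the first tick whose certainty exceeds the threshold).

```python
import math

def certainty(activations, eps=1e-8):
    """Return 1 - normalized entropy of |activations| treated as a distribution."""
    total = sum(abs(a) for a in activations) + eps
    # Clip to [1e-10, 1] so log() is always defined, as in the loop above.
    probs = [min(max(abs(a) / total, 1e-10), 1.0) for a in activations]
    entropy = -sum(p * math.log(p) for p in probs)
    max_entropy = math.log(len(probs))
    return 1.0 - entropy / (max_entropy + eps)

def ticks_until_halt(certainty_per_tick, threshold=0.85):
    """Ticks executed before halting on the first certainty above threshold."""
    for tick, value in enumerate(certainty_per_tick, start=1):
        if value > threshold:
            return tick
    return len(certainty_per_tick)

flat = certainty([0.5, 0.5, 0.5, 0.5])      # evenly spread activations
peaked = certainty([1.0, 0.0, 0.0, 0.0])    # one dominant activation
used = ticks_until_halt([0.2, 0.6, 0.9, 0.95])
```

Evenly spread activations give a certainty near 0 and a single dominant activation gives a certainty near 1, which is why a threshold such as 0.85 only triggers early halting once the state has concentrated on a few units.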
355
+ def train_step(self, input_72d: np.ndarray, target_16d: np.ndarray,
356
+ physics_loss: float = 0.0) -> Dict:
357
+ """
358
+ Online training step.
359
+
360
+ Args:
361
+ input_72d: Input from SNN
362
+ target_16d: Target dendrite adjustments (ground truth)
363
+ physics_loss: Current physics loss for weighting
364
+
365
+ Returns:
366
+ Dict with loss and gradient info
367
+ """
368
+ if not self.is_full or not TORCH_AVAILABLE:
369
+ return {"status": "skip", "reason": "Training requires full PyTorch CTM"}
370
+
371
+ self.model.train()
372
+
373
+ # Prepare tensors
374
+ x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
375
+ x = x.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # [1, 72, 1, 1]
376
+ y = torch.tensor(target_16d, dtype=torch.float32, device=DEVICE).unsqueeze(0)
377
+
378
+ # Forward
379
+ predictions, certainties, _ = self.model(x)
380
+
381
+ # Loss: dendrite_regulation_loss
382
+ # predictions: [B, 16, T], y: [B, 16]
383
+ y_exp = y.unsqueeze(-1).expand(-1, -1, predictions.size(-1)) # [B, 16, T]
384
+ mse_per_tick = F.mse_loss(predictions, y_exp, reduction='none').mean(dim=1) # [B, T]
385
+
386
+ # Select best tick (min loss) and most certain tick
387
+ loss_min_idx = mse_per_tick.argmin(dim=1) # [B]
388
+ loss_cert_idx = certainties[:, 1, :].argmax(dim=1) # [B]
389
+
390
+ batch_idx = torch.arange(predictions.size(0), device=DEVICE)
391
+ loss_min = mse_per_tick[batch_idx, loss_min_idx].mean()
392
+ loss_cert = mse_per_tick[batch_idx, loss_cert_idx].mean()
393
+
394
+ # Combined loss with physics penalty
395
+ mse_loss = (loss_min + loss_cert) / 2
396
+ physics_penalty = physics_loss * 0.1
397
+ total_loss = mse_loss + physics_penalty
398
+
399
+ # Backward
400
+ self.optimizer.zero_grad()
401
+ total_loss.backward()
402
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
403
+ self.optimizer.step()
404
+
405
+ self.model.eval()
406
+
407
+ # Auto-save checkpoint
408
+ if self.forward_count % self.config["auto_save_every"] == 0:
409
+ self._save_checkpoint()
410
+
411
+ return {
412
+ "status": "trained",
413
+ "loss": float(total_loss.item()),
414
+ "mse_loss": float(mse_loss.item()),
415
+ "physics_penalty": float(physics_penalty),
416
+ "best_tick": int(loss_cert_idx[0].item())
417
+ }
418
+
419
+ def _save_checkpoint(self):
420
+ """Save model checkpoint."""
421
+ if not self.is_full:
422
+ return
423
+
424
+ os.makedirs(self.config["checkpoint_dir"], exist_ok=True)
425
+ path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
426
+
427
+ torch.save({
428
+ "model_state_dict": self.model.state_dict(),
429
+ "optimizer_state_dict": self.optimizer.state_dict(),
430
+ "forward_count": self.forward_count,
431
+ "timestamp": datetime.now().isoformat()
432
+ }, path)
433
+ print(f"💾 Checkpoint saved: {path}")
434
+
435
+ # Upload to Bunker (Async/Fail-Safe)
436
+ self.bunker.save_file(path, remote_folder="ctm_backups")
437
+
438
+
439
+ def _load_checkpoint(self):
440
+ """Load model checkpoint if exists."""
441
+ path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
442
+ if os.path.exists(path):
443
+ try:
444
+ checkpoint = torch.load(path, map_location=DEVICE)
445
+ self.model.load_state_dict(checkpoint["model_state_dict"])
446
+ self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
447
+ self.forward_count = checkpoint.get("forward_count", 0)
448
+ print(f"✅ Checkpoint loaded: {path}")
449
+ except Exception as e:
450
+ print(f"⚠️ Could not load checkpoint: {e}")
451
+
452
+ # ============================================================================
453
+ # GLOBAL CTM INSTANCE
454
+ # ============================================================================
455
+ ctm = CTM_ART17(CONFIG)
456
+
457
+ # ============================================================================
458
+ # PHYSICS VALIDATION (from SNN Omega-21)
459
+ # ============================================================================
460
+ def validate_physics(trajectory: List[float], params: Dict) -> Dict:
461
+ """Validate against 5 physics losses from SNN Omega-21."""
462
+ trajectory = np.array(trajectory)
463
+
464
+ # L_energy: Energy conservation
465
+ energy = np.sum(trajectory ** 2)
466
+ P_max = params.get("P_max", CONFIG["physics_thresholds"]["P_max"])
467
+ L_energy = float(max(0, energy - P_max) ** 2)
468
+
469
+ # L_thermo: Thermodynamics (dew point check)
470
+ T_dew = params.get("T_dew", CONFIG["physics_thresholds"]["T_dew"])
471
+ T_amb = params.get("T_amb", CONFIG["physics_thresholds"]["T_amb"])
472
+ L_thermo = float(max(0, T_dew - T_amb) ** 2)
473
+
474
+ # L_causal: Causality (velocity limit)
475
+ velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
476
+ v_max = params.get("v_max", CONFIG["physics_thresholds"]["v_max"])
477
+ L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))
478
+
479
+ # L_conserv: Flux conservation
480
+ flux_in = params.get("flux_in", 1.0)
481
+ flux_out = params.get("flux_out", 1.0)
482
+ L_conserv = float((flux_in - flux_out) ** 2)
483
+
484
+ # L_entropy: 2nd Law (entropy must increase)
485
+ entropy_change = params.get("entropy_change", 0.1)
486
+ L_entropy = float(max(0, -entropy_change) ** 2)
487
+
488
+ # Total physics loss
489
+ L_total = L_energy + L_thermo + L_causal + L_conserv + L_entropy
490
+
491
+ return {
492
+ "valid": L_total < 0.01,
493
+ "L_energy": L_energy,
494
+ "L_thermo": L_thermo,
495
+ "L_causal": L_causal,
496
+ "L_conserv": L_conserv,
497
+ "L_entropy": L_entropy,
498
+ "L_total": L_total
499
+ }
500
+
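Several of the losses in `validate_physics` share one pattern: a squared hinge that is zero inside the allowed range and grows quadratically beyond it. The following is a small sketch of that pattern in plain Python; the helper names `hinge_sq`, `causal_loss`, and `energy_loss` are hypothetical and only mirror `L_causal` and `L_energy` above.

```python
def hinge_sq(x):
    """Squared hinge: zero while x <= 0, quadratic once the limit is exceeded."""
    return max(0.0, x) ** 2

def causal_loss(trajectory, v_max):
    """Sum of squared excess step sizes, mirroring L_causal."""
    steps = [b - a for a, b in zip(trajectory, trajectory[1:])]
    return sum(hinge_sq(abs(s) - v_max) for s in steps)

def energy_loss(trajectory, p_max):
    """Squared excess of the summed squares over p_max, mirroring L_energy."""
    return hinge_sq(sum(x * x for x in trajectory) - p_max)

within_limit = causal_loss([0.0, 50.0, 120.0], v_max=100.0)   # steps 50 and 70
over_limit = causal_loss([0.0, 50.0, 200.0], v_max=100.0)     # steps 50 and 150
energy = energy_loss([3.0, 4.0], p_max=20.0)                  # 25 - 20 = 5 over
```

Because the penalty is squared, a trajectory that exceeds a limit by even a moderate amount pushes `L_total` far above the 0.01 validity cutoff used in the return value.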
501
+ # ============================================================================
502
+ # ENDPOINT FUNCTIONS
503
+ # ============================================================================
504
+
505
+ def sense_snn(snn_json: str) -> str:
506
+ """
507
+ /sense_snn - Process 72D SNN input through CTM
508
+
509
+ Input: JSON with dendrite values or 72D vector
510
+ Output: Coherent features, certainty, sync matrix
511
+ """
512
+ try:
513
+ data = json.loads(snn_json)
514
+
515
+ # Extract 72D vector
516
+ if "vector_72d" in data:
517
+ input_vec = np.array(data["vector_72d"])
518
+ elif "dendrites" in data:
519
+ input_vec = np.array(list(data["dendrites"].values()))
520
+ else:
521
+ input_vec = np.random.randn(72)
522
+
523
+ # Pad to 72D if needed
524
+ if len(input_vec) < 72:
525
+ input_vec = np.pad(input_vec, (0, 72 - len(input_vec)))
526
+
527
+ # Process through CTM
528
+ n_ticks = data.get("ticks", 25)
529
+ result = ctm.forward(input_vec[:72], n_ticks)
530
+
531
+ # Detect anomalies (low certainty)
532
+ anomalies = []
533
+ if result["certainty"] < 0.5:
534
+ anomalies.append("Low overall certainty - consider retraining")
535
+
536
+ return json.dumps({
537
+ "status": "success",
538
+ "coherent_features": result["predictions"],
539
+ "certainty": result["certainty"],
540
+ "best_tick": result["best_tick"],
541
+ "anomalies": anomalies,
542
+ "ticks_used": result["ticks_used"],
543
+ "model": result["model"]
544
+ }, indent=2)
545
+ except Exception as e:
546
+ return json.dumps({"status": "error", "message": str(e)})
547
+
548
+
549
+ def reason_hypergraph(context_json: str) -> str:
550
+ """
551
+ /reason_hypergraph - Reason about hypergraph context, propose edges
552
+
553
+ Uses CTM synchronization matrix to find strongly correlated node pairs.
554
+ These become proposed hyperedges.
555
+ """
556
+ try:
557
+ data = json.loads(context_json)
558
+
559
+ node_features = np.array(data.get("node_features", [[0]*16]*8))
560
+ existing_edges = data.get("existing_edges", [])
561
+ n_ticks = data.get("ticks", 50)
562
+
563
+ # Flatten node features for CTM input and pad to 72D
564
+ flattened = node_features.flatten()
565
+ input_vec = np.zeros(72)
566
+ input_vec[:min(len(flattened), 72)] = flattened[:min(len(flattened), 72)]
567
+
568
+ # Process through CTM with more ticks for reasoning
569
+ result = ctm.forward(input_vec, n_ticks)
570
+
571
+ # Extract proposed edges from sync matrix (S_ij > 0.7); the NumPy fallback's unit-normalized S has off-diagonal |S_ij| <= 0.5, so it proposes none
572
+ proposed_edges = []
573
+ if result["sync_matrix"] is not None:
574
+ sync = np.array(result["sync_matrix"])
575
+ # Ensure sync is 2D
576
+ if len(sync.shape) == 1:
577
+ # 1D array - skip edge extraction
578
+ pass
579
+ elif len(sync.shape) >= 2:
580
+ n_nodes = min(len(node_features), sync.shape[0])
581
+
582
+ for i in range(n_nodes):
583
+ for j in range(i+1, n_nodes):
584
+ if j < sync.shape[1]: # Check bounds
585
+ sync_ij = sync[i, j]
586
+ if sync_ij > 0.7: # Threshold for edge proposal
587
+ edge_exists = any(
588
+ (e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
589
+ for e in existing_edges
590
+ )
591
+ if not edge_exists:
592
+ proposed_edges.append([i, j, float(sync_ij)])
593
+
594
+ return json.dumps({
595
+ "status": "success",
596
+ "proposed_edges": proposed_edges,
597
+ "certainty": result["certainty"],
598
+ "best_tick": result["best_tick"],
599
+ "ticks_used": result["ticks_used"],
600
+ "model": result["model"]
601
+ }, indent=2)
602
+ except Exception as e:
603
+ return json.dumps({"status": "error", "message": str(e)})
604
+
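The edge-proposal step in `reason_hypergraph` reduces to a thresholded scan of the upper triangle of the synchronization matrix. A minimal sketch under stated assumptions: `propose_edges` is a hypothetical helper, the matrix values are made up for illustration, and a set of unordered pairs replaces the linear `any(...)` scan used above.

```python
def propose_edges(sync, existing_edges, threshold=0.7):
    """Return [i, j, strength] for pairs above threshold that are not yet edges."""
    # Unordered pairs, so [2, 1] and [1, 2] count as the same edge.
    existing = {frozenset((e[0], e[1])) for e in existing_edges}
    proposed = []
    n = len(sync)
    for i in range(n):
        for j in range(i + 1, n):   # upper triangle only; S is symmetric
            if sync[i][j] > threshold and frozenset((i, j)) not in existing:
                proposed.append([i, j, float(sync[i][j])])
    return proposed

# Illustrative 3-node matrix; the values are invented for this example.
sync = [
    [1.0, 0.9, 0.2],
    [0.9, 1.0, 0.8],
    [0.2, 0.8, 1.0],
]
edges = propose_edges(sync, existing_edges=[[2, 1]])
```

Here the pair (0, 1) is proposed, (0, 2) falls below the threshold, and (1, 2) is skipped because it already exists in reversed order.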
605
+
606
+ def validate_physics_endpoint(physics_json: str) -> str:
607
+ """
608
+ /validate_physics - Validate trajectory against 5 physics losses
609
+ """
610
+ try:
611
+ data = json.loads(physics_json)
612
+ trajectory = data.get("trajectory", [0.0])
613
+ params = data.get("physics_params", {})
614
+
615
+ result = validate_physics(trajectory, params)
616
+ result["status"] = "success"
617
+
618
+ return json.dumps(result, indent=2)
619
+ except Exception as e:
620
+ return json.dumps({"status": "error", "message": str(e)})
621
+
622
+
623
+ def dream_endpoint(dream_json: str) -> str:
624
+ """
625
+ /dream - Offline consolidation with many ticks
626
+
627
+ Discovers patterns, proposes new edges, identifies edges to prune.
628
+ """
629
+ try:
630
+ data = json.loads(dream_json)
631
+
632
+ snapshot = data.get("hypergraph_snapshot", {})
633
+ n_ticks = min(data.get("ticks", 100), 100) # Cap at 100 for CPU
634
+
635
+ # Extract features from snapshot
636
+ nodes = snapshot.get("nodes", [])
637
+ if nodes:
638
+ input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()[:72]
639
+ else:
640
+ input_vec = np.random.randn(72)
641
+
642
+ # Dream: run CTM with many ticks
643
+ result = ctm.forward(input_vec, n_ticks)
644
+
645
+ # Analyze sync for patterns
646
+ new_edges = []
647
+ pruned_edges = []
648
+
649
+ if result["sync_matrix"] is not None:
650
+ sync = np.array(result["sync_matrix"])
651
+ n = min(len(nodes), sync.shape[0]) if nodes else 16
652
+
653
+ for i in range(n):
654
+ for j in range(i+1, n):
655
+ if sync[i, j] > 0.85:
656
+ new_edges.append([i, j, float(sync[i, j])])
657
+ elif sync[i, j] < 0.1:
658
+ pruned_edges.append([i, j])
659
+
660
+ return json.dumps({
661
+ "status": "success",
662
+ "discovered_patterns": len(new_edges),
663
+ "new_edges": new_edges[:10],
664
+ "pruned_edges": pruned_edges[:10],
665
+ "consolidation_certainty": result["certainty"],
666
+ "ticks_used": result["ticks_used"],
667
+ "model": result["model"]
668
+ }, indent=2)
669
+ except Exception as e:
670
+ return json.dumps({"status": "error", "message": str(e)})
671
+
672
+
673
+ def calibrate_stdp_endpoint(stdp_json: str) -> str:
674
+ """
675
+ /calibrate_stdp - Suggest STDP weight adjustments
676
+
677
+ This is the CORE regulatory function:
678
+ - Receives current 16 dendrite weights
679
+ - Processes through CTM to get sync patterns
680
+ - Returns suggested weight adjustments
681
+ """
682
+ try:
683
+ data = json.loads(stdp_json)
684
+
685
+ current_weights = np.array(data.get("current_weights", [1.0]*16))
686
+ node_features = np.array(data.get("node_features", [[0]*16]*4))
687
+
688
+ # Flatten features and zero-pad to 72D for CTM input
689
+ input_vec = np.pad(node_features.flatten()[:72], (0, max(0, 72 - node_features.size)))
690
+
691
+ # Process through CTM
692
+ result = ctm.forward(input_vec, n_ticks=25)
693
+
694
+ # Use predictions as weight adjustments
695
+ predictions = np.array(result["best_predictions"])
696
+
697
+ # Scale based on certainty
698
+ confidence = result["certainty"]
699
+ weight_changes = (predictions - 0.5) * confidence * 0.1
700
+
701
+ new_weights = current_weights + weight_changes
702
+
703
+ return json.dumps({
704
+ "status": "success",
705
+ "suggested_weights": new_weights.tolist(),
706
+ "weight_changes": weight_changes.tolist(),
707
+ "confidence": confidence,
708
+ "best_tick": result["best_tick"],
709
+ "model": result["model"]
710
+ }, indent=2)
711
+ except Exception as e:
712
+ return json.dumps({"status": "error", "message": str(e)})
713
+
714
+
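The weight update in `calibrate_stdp_endpoint` is a confidence-scaled step centered on 0.5. A short sketch of that arithmetic in plain Python; `suggest_weights` and its `rate` parameter are hypothetical names for the 0.1 factor used above.

```python
def suggest_weights(current, predictions, confidence, rate=0.1):
    """Shift each weight by (prediction - 0.5) * confidence * rate."""
    changes = [(p - 0.5) * confidence * rate for p in predictions]
    new_weights = [w + d for w, d in zip(current, changes)]
    return new_weights, changes

new_w, delta = suggest_weights([1.0, 1.0], [0.9, 0.1], confidence=0.5)
```

With a confidence of 0.5, a prediction of 0.9 raises its weight by 0.02 and a prediction of 0.1 lowers its weight by 0.02, so low-certainty runs move the weights only slightly.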
715
+ def regulate_endpoint(regulate_json: str) -> str:
716
+ """
717
+ /regulate - Full feedback loop for ART-17 regulation (NEW)
718
+
719
+ Combines all signals to provide comprehensive regulation:
720
+ - Dendrite state
721
+ - Latent representation
722
+ - Physics loss
723
+ - Anomaly score
724
+
725
+ Returns action recommendation with confidence.
726
+ """
727
+ try:
728
+ data = json.loads(regulate_json)
729
+
730
+ # Inputs from local system
731
+ dendrites = np.array(data.get("dendrites", [0.0]*16))
732
+ latent_256 = np.array(data.get("latent_256", [0.0]*256))
733
+ physics_loss = data.get("physics_loss", 0.0)
734
+ anomaly_score = data.get("anomaly_score", 0.0)
735
+
736
+ # Combine into 72D input
737
+ input_72 = np.concatenate([
738
+ dendrites, # 16D
739
+ latent_256[:56] # 56D from latent
740
+ ])
741
+
742
+ # Process through CTM
743
+ result = ctm.forward(input_72, n_ticks=50)
744
+
745
+ # Compute regulation signals
746
+ predictions = np.array(result["best_predictions"])
747
+ certainty = result["certainty"]
748
+
749
+ # Urgency based on physics and anomaly
750
+ urgency = min(1.0, physics_loss + anomaly_score)
751
+ regulation_strength = urgency * certainty
752
+
753
+ # Weight adjustments
754
+ dendrite_deltas = predictions * regulation_strength * 0.05
755
+
756
+ # Determine if intervention needed
757
+ needs_intervention = urgency > 0.5 or certainty < 0.3
758
+
759
+ return json.dumps({
760
+ "status": "success",
761
+ "dendrite_deltas": dendrite_deltas.tolist(),
762
+ "regulation_strength": float(regulation_strength),
763
+ "confidence": certainty,
764
+ "urgency": float(urgency),
765
+ "needs_intervention": needs_intervention,
766
+ "recommended_action": "ADJUST" if needs_intervention else "MAINTAIN",
767
+ "best_tick": result["best_tick"],
768
+ "model": result["model"]
769
+ }, indent=2)
770
+ except Exception as e:
771
+ return json.dumps({"status": "error", "message": str(e)})
772
+
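A request body for `/regulate` has to be valid JSON, so list shorthands such as `[0.5]*16` must be expanded in Python before serializing. A minimal client-side sketch, assuming only the field names read by `regulate_endpoint` above:

```python
import json

# Build the payload in Python, then serialize; the lists are fully expanded.
payload = json.dumps({
    "dendrites": [0.5] * 16,       # 16 dendrite values
    "latent_256": [0.1] * 256,     # only the first 56 are used by the endpoint
    "physics_loss": 0.01,
    "anomaly_score": 0.05,
})

decoded = json.loads(payload)
# Same urgency rule as the endpoint: sum of both signals, capped at 1.0.
urgency = min(1.0, decoded["physics_loss"] + decoded["anomaly_score"])

# A string that embeds Python list arithmetic is rejected by the JSON parser.
try:
    json.loads('{"dendrites": [0.5]*16}')
    shorthand_is_valid = True
except json.JSONDecodeError:
    shorthand_is_valid = False
```

The 16 dendrite values plus the first 56 latent values make up the 72D vector passed to `ctm.forward`.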
773
+
774
+ def train_online_endpoint(train_json: str) -> str:
775
+ """
776
+ /train_online - Progressive online training (NEW)
777
+
778
+ Allows the local system to train the CTM with experience.
779
+ Sends input-output pairs and receives training feedback.
780
+ """
781
+ try:
782
+ data = json.loads(train_json)
783
+
784
+ input_72d = np.array(data.get("input_72d", [0.0]*72))
785
+ target_16d = np.array(data.get("target_16d", [0.0]*16))
786
+ physics_loss = data.get("physics_loss", 0.0)
787
+
788
+ # Perform training step
789
+ result = ctm.train_step(input_72d, target_16d, physics_loss)
790
+
791
+ return json.dumps({
792
+ "status": result["status"],
793
+ "loss": result.get("loss"),
794
+ "mse_loss": result.get("mse_loss"),
795
+ "physics_penalty": result.get("physics_penalty"),
796
+ "best_tick": result.get("best_tick"),
797
+ "forward_count": ctm.forward_count,
798
+ "message": "Training step completed" if result["status"] == "trained" else result.get("reason")
799
+ }, indent=2)
800
+ except Exception as e:
801
+ return json.dumps({"status": "error", "message": str(e)})
802
+
803
+
804
+ def health_check() -> str:
805
+ """Health check with model info."""
806
+ return json.dumps({
807
+ "status": "healthy",
808
+ "model": f"CTM Nervous System v2.0 ({'Full PyTorch' if ctm.is_full else 'NumPy Fallback'})",
809
+ "device": DEVICE,
810
+ "d_model": CONFIG["d_model"],
811
+ "iterations": CONFIG["iterations"],
812
+ "memory_length": CONFIG["memory_length"],
813
+ "forward_count": ctm.forward_count,
814
+ "endpoints": [
815
+ "/sense_snn",
816
+ "/reason_hypergraph",
817
+ "/validate_physics",
818
+ "/dream",
819
+ "/calibrate_stdp",
820
+ "/regulate", # NEW
821
+ "/train_online" # NEW
822
+ ]
823
+ }, indent=2)
824
+
825
+
826
+ # ============================================================================
827
+ # GRADIO INTERFACE
828
+ # ============================================================================
829
+
830
+ with gr.Blocks(title="CTM Nervous System v2.0", theme=gr.themes.Soft()) as demo:
831
+ gr.Markdown("""
832
+ # 🧬 CTM Nervous System v2.0
833
+ **Continuous Thought Machine for ART-17 Hypergraph Coherence**
834
+
835
+ Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
836
+
837
+ ---
838
+
839
+ ## Key Innovations
840
+ - **NLMs (Neuron-Level Models)**: Each neuron processes its own history
841
+ - **Neural Synchronization**: Representation via S = Z·Z^T
842
+ - **Adaptive Compute**: Halts when confident
843
+ - **Online Training**: Progressive learning with use
844
+
845
+ ---
846
+ """)
847
+
848
+ with gr.Tabs():
849
+ with gr.Tab("🔌 /sense_snn"):
850
+ gr.Markdown("Process 72D SNN input through CTM")
851
+ snn_input = gr.Textbox(
852
+ label="SNN JSON Input",
853
+ value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}, "ticks": 25}',
854
+ lines=5
855
+ )
856
+ snn_output = gr.Textbox(label="Output", lines=10)
857
+ snn_btn = gr.Button("Process", variant="primary")
858
+ snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
859
+
860
+ with gr.Tab("🧠 /reason_hypergraph"):
861
+ gr.Markdown("Reason about hypergraph context, propose edges")
862
+ reason_input = gr.Textbox(
863
+ label="Context JSON",
864
+ value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
865
+ lines=5
866
+ )
867
+ reason_output = gr.Textbox(label="Output", lines=10)
868
+ reason_btn = gr.Button("Reason", variant="primary")
869
+ reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
870
+
871
+ with gr.Tab("⚡ /validate_physics"):
872
+ gr.Markdown("Validate trajectory against 5 physics losses")
873
+ physics_input = gr.Textbox(
874
+ label="Physics JSON",
875
+ value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
876
+ lines=5
877
+ )
878
+ physics_output = gr.Textbox(label="Output", lines=10)
879
+ physics_btn = gr.Button("Validate", variant="primary")
880
+ physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
881
+
882
+ with gr.Tab("💤 /dream"):
883
+ gr.Markdown("Offline consolidation - discover patterns")
884
+ dream_input = gr.Textbox(
885
+ label="Dream JSON",
886
+ value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
887
+ lines=5
888
+ )
889
+ dream_output = gr.Textbox(label="Output", lines=10)
890
+ dream_btn = gr.Button("Dream", variant="primary")
891
+ dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
892
+
893
+ with gr.Tab("🔧 /calibrate_stdp"):
894
+ gr.Markdown("Calibrate STDP weights (Core regulatory function)")
895
+ stdp_input = gr.Textbox(
896
+ label="STDP JSON",
897
+ value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
898
+ lines=5
899
+ )
900
+ stdp_output = gr.Textbox(label="Output", lines=10)
901
+ stdp_btn = gr.Button("Calibrate", variant="primary")
902
+ stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
903
+
904
+ with gr.Tab("🎯 /regulate [NEW]"):
905
+ gr.Markdown("Full feedback loop for ART-17 regulation")
906
+ regulate_input = gr.Textbox(
907
+ label="Regulate JSON",
908
+ value=json.dumps({"dendrites": [0.5]*16, "latent_256": [0.1]*256, "physics_loss": 0.01, "anomaly_score": 0.05}),
909
+ lines=5
910
+ )
911
+ regulate_output = gr.Textbox(label="Output", lines=10)
912
+ regulate_btn = gr.Button("Regulate", variant="primary")
913
+ regulate_btn.click(regulate_endpoint, inputs=regulate_input, outputs=regulate_output, api_name="regulate")
914
+
915
+ with gr.Tab("📚 /train_online [NEW]"):
916
+ gr.Markdown("Progressive online training with experience")
917
+ train_input = gr.Textbox(
918
+ label="Training JSON",
919
+ value=json.dumps({"input_72d": [0.1]*72, "target_16d": [0.5]*16, "physics_loss": 0.01}),
920
+ lines=5
921
+ )
922
+ train_output = gr.Textbox(label="Output", lines=10)
923
+ train_btn = gr.Button("Train Step", variant="primary")
924
+ train_btn.click(train_online_endpoint, inputs=train_input, outputs=train_output, api_name="train_online")
925
+
926
+ with gr.Tab("❤️ Health"):
927
+ health_output = gr.Textbox(label="Health Status", lines=15)
928
+ health_btn = gr.Button("Check Health", variant="secondary")
929
+ health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
930
+
931
+ gr.Markdown("""
932
+ ---
933
+ **Architecture**: CTM as Nervous System → Hypergraph as Coherent Thought
934
+
935
+ **Integration**: Local ART-17 ↔ CTM (regulation) ↔ Brain Server (semantics)
936
+
937
+ **Training**: Progressive online learning + Physics-Informed Loss
938
+ """)
939
+
940
+ if __name__ == "__main__":
941
+ demo.launch(
942
+ server_name="0.0.0.0",
943
+ server_port=7860,
944
+ show_error=True
945
+ )
app_v1_backup.py ADDED
@@ -0,0 +1,464 @@
1
+ """
2
+ CTM Nervous System Server - Continuous Thought Machine for Hypergraph Maintenance
3
+ ===================================================================================
4
+ Implementation of the definitive proposal: CTM as Nervous System for ART-17 Hypergraph.
5
+
6
+ Endpoints:
7
+ - /sense_snn: Process 72D SNN input with NLM-style processing
8
+ - /reason_hypergraph: Reason about hypergraph context, propose edges
9
+ - /validate_physics: Validate proposals against 5 physics losses
10
+ - /dream: Offline consolidation with T=500+ ticks
11
+ - /calibrate_stdp: Suggest STDP weight adjustments from sync matrix
12
+ - /health: Health check endpoint
13
+
14
+ Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
15
+ """
16
+
17
+ import gradio as gr
18
+ import numpy as np
19
+ import json
20
+ from typing import List, Dict, Any, Optional
21
+ import os
22
+
23
+ # ============================================================================
24
+ # SIMPLIFIED CTM SIMULATION (CPU-only for Hugging Face free tier)
25
+ # ============================================================================
26
+
27
+ class SimplifiedCTM:
28
+ """
29
+ Simplified CTM for CPU-only environment.
30
+ Simulates the key mechanisms without full PyTorch model.
31
+ """
32
+
33
+ def __init__(self, d_model: int = 256, memory_length: int = 16, n_ticks: int = 50):
34
+ self.d_model = d_model
35
+ self.memory_length = memory_length
36
+ self.n_ticks = n_ticks
37
+
38
+ # Initialize state
39
+ self.state_trace = np.zeros((d_model, memory_length))
40
+ self.activated_state = np.random.randn(d_model) * 0.1
41
+
42
+ # NLM weights (simplified: one weight matrix per "neuron group")
43
+ self.nlm_weights = np.random.randn(16, memory_length) * 0.1 # 16 groups for 16 dendrites
44
+
45
+ def compute_sync_matrix(self, z: np.ndarray) -> np.ndarray:
46
+ """S^t = Z · Z^T (normalized)"""
47
+ z_norm = z / (np.linalg.norm(z) + 1e-8)
48
+ S = np.outer(z_norm, z_norm)
49
+ return S
50
+
51
+ def compute_certainty(self, predictions: np.ndarray) -> float:
52
+ """Certainty = 1 - normalized entropy"""
53
+ probs = np.abs(predictions) / (np.sum(np.abs(predictions)) + 1e-8)
54
+ probs = np.clip(probs, 1e-10, 1.0)
55
+ entropy = -np.sum(probs * np.log(probs))
56
+ max_entropy = np.log(len(probs))
57
+ normalized_entropy = entropy / (max_entropy + 1e-8)
58
+ return float(1.0 - normalized_entropy)
59
+
60
+ def process_ticks(self, input_features: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
61
+ """Run T internal ticks and return sync matrix + certainty"""
62
+ n_ticks = n_ticks or self.n_ticks
63
+
64
+ # Ensure input is right size
65
+ if len(input_features) < self.d_model:
66
+ input_features = np.pad(input_features, (0, self.d_model - len(input_features)))
67
+ else:
68
+ input_features = input_features[:self.d_model]
69
+
70
+ certainties = []
71
+ sync_matrices = []
72
+
73
+ for t in range(n_ticks):
74
+ # Simulate synapse update
75
+ combined = np.concatenate([self.activated_state, input_features[:self.d_model//2]])
76
+ pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
77
+
78
+ # Update trace
79
+ self.state_trace = np.roll(self.state_trace, -1, axis=1)
80
+ self.state_trace[:, -1] = pre_activation
81
+
82
+ # Simulate NLM (simplified)
83
+ post_activation = np.zeros(self.d_model)
84
+ group_size = self.d_model // 16
85
+ for g in range(16):
86
+ start = g * group_size
87
+ end = start + group_size
88
+ group_trace = self.state_trace[start:end, :]
89
+ group_output = np.mean(group_trace @ self.nlm_weights[g])
90
+ post_activation[start:end] = np.tanh(group_output)
91
+
92
+ self.activated_state = post_activation
93
+
94
+ # Compute sync and certainty
95
+ sync = self.compute_sync_matrix(self.activated_state)
96
+ cert = self.compute_certainty(self.activated_state)
97
+
98
+ sync_matrices.append(sync)
99
+ certainties.append(cert)
100
+
101
+ # Find best tick (min-loss proxy: max certainty)
102
+ best_tick = int(np.argmax(certainties))
103
+
104
+ return {
105
+ "final_sync_matrix": sync_matrices[-1].tolist(),
106
+ "best_sync_matrix": sync_matrices[best_tick].tolist(),
107
+ "certainties": certainties,
108
+ "final_certainty": float(certainties[-1]),
109
+ "max_certainty": float(max(certainties)),
110
+ "best_tick": best_tick,
111
+ "ticks_used": n_ticks
112
+ }
113
+
114
+ # Global CTM instance
115
+ ctm = SimplifiedCTM(d_model=256, memory_length=16, n_ticks=50)
116
+
117
+ # ============================================================================
118
+ # PHYSICS VALIDATION (from SNN Omega-21)
119
+ # ============================================================================
120
+
121
+ def validate_physics(trajectory: List[float], params: Dict) -> Dict:
122
+ """Validate against 5 physics losses from SNN Omega-21"""
123
+ trajectory = np.array(trajectory)
124
+
125
+ # L_energy: Energy conservation
126
+ energy = np.sum(trajectory ** 2)
127
+ P_max = params.get("P_max", 1000.0)
128
+ L_energy = float(max(0, energy - P_max) ** 2)
129
+
130
+ # L_thermo: Thermodynamics (dew point check)
131
+ T_dew = params.get("T_dew", 15.0)
132
+ T_amb = params.get("T_amb", 25.0)
133
+ L_thermo = float(max(0, T_dew - T_amb) ** 2)
134
+
135
+ # L_causal: Causality (velocity limit)
136
+ velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
137
+ v_max = params.get("v_max", 100.0)
138
+ L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))
139
+
140
+ # L_conserv: Flux conservation
141
+ flux_in = params.get("flux_in", 1.0)
142
+ flux_out = params.get("flux_out", 1.0)
143
+ L_conserv = float((flux_in - flux_out) ** 2)
144
+
145
+ # L_entropy: 2nd Law (entropy must increase)
146
+ entropy_change = params.get("entropy_change", 0.1)
147
+ L_entropy = float(max(0, -entropy_change) ** 2)
148
+
149
+ # Total physics loss
150
+ L_total = L_energy + L_thermo + L_causal + L_conserv + L_entropy
151
+
152
+ return {
153
+ "valid": L_total < 0.01,
154
+ "L_energy": L_energy,
155
+ "L_thermo": L_thermo,
156
+ "L_causal": L_causal,
157
+ "L_conserv": L_conserv,
158
+ "L_entropy": L_entropy,
159
+ "L_total": L_total
160
+ }
161
+
162
+ # ============================================================================
163
+ # ENDPOINT FUNCTIONS
164
+ # ============================================================================
165
+
166
+ def sense_snn(snn_json: str) -> str:
167
+ """
168
+ /sense_snn - Process 72D SNN input
169
+ Input: JSON with dendrite values
170
+ Output: Coherent features + anomalies
171
+ """
172
+ try:
173
+ data = json.loads(snn_json)
174
+
175
+ # Extract 72D vector (or create from dendrites)
176
+ if "vector_72d" in data:
177
+ input_vec = np.array(data["vector_72d"])
178
+ elif "dendrites" in data:
179
+ dendrite_values = list(data["dendrites"].values())
180
+ input_vec = np.array(dendrite_values)
181
+ else:
182
+ input_vec = np.random.randn(72) # Fallback
183
+
184
+ # Pad to 256D
185
+ input_256 = np.zeros(256)
186
+ input_256[:min(len(input_vec), 256)] = input_vec[:min(len(input_vec), 256)]
187
+
188
+ # Process through CTM
189
+ result = ctm.process_ticks(input_256, n_ticks=25)
190
+
191
+ # Detect anomalies (low certainty regions)
192
+ anomalies = []
193
+ if result["final_certainty"] < 0.5:
194
+ anomalies.append("Low overall certainty")
195
+
196
+ return json.dumps({
197
+ "status": "success",
198
+ "coherent_features": [row[:16] for row in result["final_sync_matrix"][:16]], # 16x16 subset
199
+ "certainty": result["final_certainty"],
200
+ "anomalies": anomalies,
201
+ "ticks_used": result["ticks_used"]
202
+ }, indent=2)
203
+ except Exception as e:
204
+ return json.dumps({"status": "error", "message": str(e)})
205
+
206
+ def reason_hypergraph(context_json: str) -> str:
207
+ """
208
+ /reason_hypergraph - Reason about hypergraph, propose edges
209
+ Input: Node features + existing edges
210
+ Output: Proposed new edges + certainty
211
+ """
212
+ try:
213
+ data = json.loads(context_json)
214
+
215
+ node_features = np.array(data.get("node_features", [[0]*16]*8))
216
+ existing_edges = data.get("existing_edges", [])
217
+ n_ticks = data.get("ticks", 50)
218
+
219
+ # Flatten node features for CTM input
220
+ input_vec = node_features.flatten()
221
+
222
+ # Process through CTM with more ticks for reasoning
223
+ result = ctm.process_ticks(input_vec, n_ticks=n_ticks)
224
+
225
+ # Extract proposed edges from sync matrix (S_ij > 0.8)
226
+ sync = np.array(result["best_sync_matrix"])
227
+ proposed_edges = []
228
+
229
+ n_nodes = min(len(node_features), sync.shape[0])
230
+ for i in range(n_nodes):
231
+ for j in range(i+1, n_nodes):
232
+ sync_ij = sync[i, j]
233
+ if sync_ij > 0.8:
234
+ # Check if edge already exists
235
+ edge_exists = any(
236
+ (e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
237
+ for e in existing_edges
238
+ )
239
+ if not edge_exists:
240
+ proposed_edges.append([i, j, float(sync_ij)])
241
+
242
+ return json.dumps({
243
+ "status": "success",
244
+ "proposed_edges": proposed_edges,
245
+ "certainty": result["max_certainty"],
246
+ "best_tick": result["best_tick"],
247
+ "ticks_used": result["ticks_used"]
248
+ }, indent=2)
249
+ except Exception as e:
250
+ return json.dumps({"status": "error", "message": str(e)})
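Review note on `reason_hypergraph`: the edge-proposal loop walks the upper triangle of the sync matrix and keeps pairs above a threshold that are not already connected. A minimal sketch of that idea follows; `propose_edges` is a hypothetical name, and the set of `frozenset` pairs is a swapped-in alternative to the `any(...)` scan in the endpoint, not what the commit does.

```python
def propose_edges(sync, n_nodes, existing_edges, threshold=0.8):
    """Return [i, j, score] for unordered node pairs whose sync exceeds threshold."""
    # Normalise existing edges to unordered pairs so (i, j) and (j, i) match.
    known = {frozenset((e[0], e[1])) for e in existing_edges}
    n = min(n_nodes, len(sync))
    proposed = []
    for i in range(n):
        for j in range(i + 1, n):  # upper triangle: no self pairs, no mirrored pairs
            if sync[i][j] > threshold and frozenset((i, j)) not in known:
                proposed.append([i, j, float(sync[i][j])])
    return proposed


sync = [
    [1.0, 0.9, 0.2],
    [0.9, 1.0, 0.85],
    [0.2, 0.85, 1.0],
]
# Edge 0-1 already exists (stored as [1, 0]), so only 1-2 is proposed.
edges = propose_edges(sync, n_nodes=3, existing_edges=[[1, 0]])
```

With many existing edges the set lookup avoids rescanning the edge list for every candidate pair.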
251
+
252
+ def validate_physics_endpoint(physics_json: str) -> str:
253
+ """
254
+ /validate_physics - Validate trajectory against 5 physics losses
255
+ """
256
+ try:
257
+ data = json.loads(physics_json)
258
+ trajectory = data.get("trajectory", [0.0])
259
+ params = data.get("physics_params", {})
260
+
261
+ result = validate_physics(trajectory, params)
262
+ result["status"] = "success"
263
+
264
+ return json.dumps(result, indent=2)
265
+ except Exception as e:
266
+ return json.dumps({"status": "error", "message": str(e)})
267
+
268
+ def dream_endpoint(dream_json: str) -> str:
269
+ """
270
+ /dream - Offline consolidation (T=500+ ticks requested; currently capped at 100 on CPU)
271
+ Input: Hypergraph snapshot
272
+ Output: Discovered patterns + new edges
273
+ """
274
+ try:
275
+ data = json.loads(dream_json)
276
+
277
+ snapshot = data.get("hypergraph_snapshot", {})
278
+ n_ticks = data.get("ticks", 500)
279
+
280
+ # Extract features from snapshot
281
+ nodes = snapshot.get("nodes", [])
282
+ if nodes:
283
+ input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()
284
+ else:
285
+ input_vec = np.random.randn(256) # Random dream if no nodes
286
+
287
+ # Dream: run CTM with many ticks and no external input after initial
288
+ result = ctm.process_ticks(input_vec, n_ticks=min(n_ticks, 100)) # Cap at 100 for CPU
289
+
290
+ # Analyze sync evolution to find patterns
291
+ sync = np.array(result["final_sync_matrix"])
292
+
293
+ # Find strong sync pairs (new edges)
294
+ new_edges = []
295
+ n = min(len(nodes), sync.shape[0]) if nodes else 16
296
+ for i in range(n):
297
+ for j in range(i+1, n):
298
+ if sync[i, j] > 0.85:
299
+ new_edges.append([i, j, float(sync[i, j])])
300
+
301
+ # Find weak sync pairs (edges to prune)
302
+ pruned_edges = []
303
+ for i in range(n):
304
+ for j in range(i+1, n):
305
+ if sync[i, j] < 0.1:
306
+ pruned_edges.append([i, j])
307
+
308
+ return json.dumps({
309
+ "status": "success",
310
+ "discovered_patterns": len(new_edges),
311
+ "new_edges": new_edges[:10], # Top 10
312
+ "pruned_edges": pruned_edges[:10], # Top 10
313
+ "consolidation_certainty": result["max_certainty"],
314
+ "ticks_used": result["ticks_used"]
315
+ }, indent=2)
316
+ except Exception as e:
317
+ return json.dumps({"status": "error", "message": str(e)})
318
+
319
+ def calibrate_stdp_endpoint(stdp_json: str) -> str:
320
+ """
321
+ /calibrate_stdp - Suggest STDP weight adjustments from sync
322
+ """
323
+ try:
324
+ data = json.loads(stdp_json)
325
+
326
+ current_weights = np.array(data.get("current_weights", [1.0]*16))
327
+ node_features = np.array(data.get("node_features", [[0]*16]*8))
328
+
329
+ # Process to get sync matrix
330
+ input_vec = node_features.flatten()
331
+ result = ctm.process_ticks(input_vec, n_ticks=25)
332
+
333
+ sync = np.array(result["final_sync_matrix"])
334
+
335
+ # Suggest weight adjustments based on sync patterns
336
+ # Uses each neuron's mean sync with all neurons (row mean) to scale weights
337
+ suggested = np.zeros(16)
338
+ for i in range(16):
339
+ # Average sync of neuron i with others
340
+ avg_sync = np.mean(sync[i, :])
341
+ # Scale current weight by sync
342
+ suggested[i] = current_weights[i] * (0.5 + avg_sync)
343
+
344
+ return json.dumps({
345
+ "status": "success",
346
+ "suggested_weights": suggested.tolist(),
347
+ "weight_changes": (suggested - current_weights).tolist(),
348
+ "confidence": result["final_certainty"]
349
+ }, indent=2)
350
+ except Exception as e:
351
+ return json.dumps({"status": "error", "message": str(e)})
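Review note on `calibrate_stdp_endpoint`: each weight is multiplied by `0.5 + mean(sync[i, :])`, so a row mean of 0.5 leaves the weight unchanged, higher synchronisation strengthens it and lower synchronisation weakens it. The sketch below works through that arithmetic on a 2x2 toy matrix. `suggest_weights` is a hypothetical helper, and bounding the loop by the shorter input is an assumption added here; the endpoint hard-codes 16.

```python
from statistics import fmean


def suggest_weights(current_weights, sync):
    """Scale each weight by 0.5 plus the mean of its row in the sync matrix."""
    # Assumption: iterate over the shorter of the two inputs instead of a fixed 16.
    n = min(len(current_weights), len(sync))
    return [current_weights[i] * (0.5 + fmean(sync[i])) for i in range(n)]


weights = [1.0, 2.0]
sync = [
    [1.0, 0.5],  # row mean 0.75 -> factor 1.25
    [0.0, 0.0],  # row mean 0.0  -> factor 0.5
]
suggested = suggest_weights(weights, sync)
changes = [s - w for s, w in zip(suggested, weights)]
```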
352
+
353
+ def health_check() -> str:
354
+ """Health check for the CTM server"""
355
+ return json.dumps({
356
+ "status": "healthy",
357
+ "model": "CTM Nervous System v1.0",
358
+ "d_model": ctm.d_model,
359
+ "memory_length": ctm.memory_length,
360
+ "default_ticks": ctm.n_ticks,
361
+ "endpoints": [
362
+ "/sense_snn",
363
+ "/reason_hypergraph",
364
+ "/validate_physics",
365
+ "/dream",
366
+ "/calibrate_stdp"
367
+ ]
368
+ }, indent=2)
369
+
370
+ # ============================================================================
371
+ # GRADIO INTERFACE
372
+ # ============================================================================
373
+
374
+ with gr.Blocks(title="CTM Nervous System") as demo:
375
+ gr.Markdown("""
376
+ # 🧬 CTM Nervous System
377
+ **Continuous Thought Machine for Hypergraph Maintenance**
378
+
379
+ Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
380
+
381
+ ---
382
+
383
+ ## Endpoints
384
+ - **/sense_snn**: Process 72D SNN input
385
+ - **/reason_hypergraph**: Reason about context, propose edges
386
+ - **/validate_physics**: Validate against 5 physics losses
387
+ - **/dream**: Offline consolidation (T=500+)
388
+ - **/calibrate_stdp**: Suggest STDP weight adjustments
389
+ """)
390
+
391
+ with gr.Tabs():
392
+ with gr.Tab("🔌 /sense_snn"):
393
+ gr.Markdown("Process 72D SNN input vector")
394
+ snn_input = gr.Textbox(
395
+ label="SNN JSON Input",
396
+ value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}}',
397
+ lines=5
398
+ )
399
+ snn_output = gr.Textbox(label="Output", lines=10)
400
+ snn_btn = gr.Button("Process", variant="primary")
401
+ snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
402
+
403
+ with gr.Tab("🧠 /reason_hypergraph"):
404
+ gr.Markdown("Reason about hypergraph context")
405
+ reason_input = gr.Textbox(
406
+ label="Context JSON",
407
+ value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
408
+ lines=5
409
+ )
410
+ reason_output = gr.Textbox(label="Output", lines=10)
411
+ reason_btn = gr.Button("Reason", variant="primary")
412
+ reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
413
+
414
+ with gr.Tab("⚡ /validate_physics"):
415
+ gr.Markdown("Validate against 5 physics losses")
416
+ physics_input = gr.Textbox(
417
+ label="Physics JSON",
418
+ value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
419
+ lines=5
420
+ )
421
+ physics_output = gr.Textbox(label="Output", lines=10)
422
+ physics_btn = gr.Button("Validate", variant="primary")
423
+ physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
424
+
425
+ with gr.Tab("💤 /dream"):
426
+ gr.Markdown("Offline consolidation")
427
+ dream_input = gr.Textbox(
428
+ label="Dream JSON",
429
+ value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
430
+ lines=5
431
+ )
432
+ dream_output = gr.Textbox(label="Output", lines=10)
433
+ dream_btn = gr.Button("Dream", variant="primary")
434
+ dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
435
+
436
+ with gr.Tab("🔧 /calibrate_stdp"):
437
+ gr.Markdown("Calibrate STDP weights")
438
+ stdp_input = gr.Textbox(
439
+ label="STDP JSON",
440
+ value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
441
+ lines=5
442
+ )
443
+ stdp_output = gr.Textbox(label="Output", lines=10)
444
+ stdp_btn = gr.Button("Calibrate", variant="primary")
445
+ stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
446
+
447
+ with gr.Tab("❤️ Health"):
448
+ health_output = gr.Textbox(label="Health Status", lines=10)
449
+ health_btn = gr.Button("Check Health", variant="secondary")
450
+ health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
451
+
452
+ gr.Markdown("""
453
+ ---
454
+ **Architecture**: CTM as Nervous System → Hypergraph as Thought
455
+
456
+ **Training**: Min-Loss + Max-Certainty + Physics Regularization
457
+ """)
458
+
459
+ if __name__ == "__main__":
460
+ demo.launch(
461
+ server_name="0.0.0.0",
462
+ server_port=7860,
463
+ show_error=True
464
+ )
data/custom_datasets.py ADDED
@@ -0,0 +1,324 @@
1
+ import torch
2
+ from torchvision.datasets import ImageFolder
3
+ from torch.utils.data import Dataset
4
+ import random
5
+ import numpy as np
6
+ from tqdm.auto import tqdm
7
+ from PIL import Image
8
+ from datasets import load_dataset
9
+
10
+ class SortDataset(Dataset):
11
+ def __init__(self, N):
12
+ self.N = N
13
+ def __len__(self):
14
+ return 10000000
15
+ def __getitem__(self, idx):
16
+ data = torch.zeros(self.N).normal_()
17
+ ordering = torch.argsort(data)
18
+ inputs = data
19
+ return (inputs), (ordering)
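Review note on `SortDataset`: each item pairs a random vector with its argsort, i.e. the indices that would put the values in ascending order. The pure-Python sketch below illustrates that relationship; `make_sort_example` is a hypothetical name and it uses `random.Random` in place of `torch.zeros(N).normal_()`.

```python
import random


def make_sort_example(n, rng):
    """Return (data, ordering) where ordering lists indices of data in ascending order."""
    data = [rng.gauss(0.0, 1.0) for _ in range(n)]
    ordering = sorted(range(n), key=lambda i: data[i])
    return data, ordering


data, ordering = make_sort_example(5, random.Random(0))
# Reading data through ordering yields the sorted values.
sorted_values = [data[i] for i in ordering]
```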
20
+
21
+ class QAMNISTDataset(Dataset):
22
+ """A QAMNIST dataset that includes plus and minus operations on MNIST digits."""
23
+ def __init__(self, base_dataset, num_images, num_images_delta, num_repeats_per_input, num_operations, num_operations_delta):
24
+ self.base_dataset = base_dataset
25
+
26
+ self.num_images = num_images
27
+ self.num_images_delta = num_images_delta
28
+ self.num_images_range = self._calculate_num_images_range()
29
+
30
+ self.operators = ["+", "-"]
31
+ self.num_operations = num_operations
32
+ self.num_operations_delta = num_operations_delta
33
+ self.num_operations_range = self._calculate_num_operations_range()
34
+
35
+ self.num_repeats_per_input = num_repeats_per_input
36
+
37
+ self.current_num_digits = num_images
38
+ self.current_num_operations = num_operations
39
+
40
+ self.modulo_base = 10
41
+
42
+ self.output_range = [0, 9]
43
+
44
+ def _calculate_num_images_range(self):
45
+ min_val = self.num_images - self.num_images_delta
46
+ max_val = self.num_images + self.num_images_delta
47
+ assert min_val >= 1, f"Minimum number of images must be at least 1, got {min_val}"
48
+ return [min_val, max_val]
49
+
50
+ def _calculate_num_operations_range(self):
51
+ min_val = self.num_operations - self.num_operations_delta
52
+ max_val = self.num_operations + self.num_operations_delta
53
+ assert min_val >= 1, f"Minimum number of operations must be at least 1, got {min_val}"
54
+ return [min_val, max_val]
55
+
56
+ def set_num_digits(self, num_digits):
57
+ self.current_num_digits = num_digits
58
+
59
+ def set_num_operations(self, num_operations):
60
+ self.current_num_operations = num_operations
61
+
62
+ def _get_target_and_question(self, targets):
63
+ question = []
64
+ equations = []
65
+ num_digits = self.current_num_digits
66
+ num_operations = self.current_num_operations
67
+
68
+ # Select the initial digit
69
+ selection_idx = np.random.randint(num_digits)
70
+ first_digit = targets[selection_idx]
71
+ question.extend([selection_idx] * self.num_repeats_per_input)
72
+ # Set current_value to the initial digit (mod is applied in each operation)
73
+ current_value = first_digit % self.modulo_base
74
+
75
+ # For each operation, build an equation line
76
+ for _ in range(num_operations):
77
+ # Choose the operator ('+' or '-')
78
+ operator_idx = np.random.randint(len(self.operators))
79
+ operator = self.operators[operator_idx]
80
+ encoded_operator = -(operator_idx + 1) # -1 for '+', -2 for '-'
81
+ question.extend([encoded_operator] * self.num_repeats_per_input)
82
+
83
+ # Choose the next digit
84
+ selection_idx = np.random.randint(num_digits)
85
+ digit = targets[selection_idx]
86
+ question.extend([selection_idx] * self.num_repeats_per_input)
87
+
88
+ # Compute the new value with immediate modulo reduction
89
+ if operator == '+':
90
+ new_value = (current_value + digit) % self.modulo_base
91
+ else: # operator is '-'
92
+ new_value = (current_value - digit) % self.modulo_base
93
+
94
+ # Build the equation string for this step
95
+ equations.append(f"({current_value} {operator} {digit}) mod {self.modulo_base} = {new_value}")
96
+ # Update current value for the next operation
97
+ current_value = new_value
98
+
99
+ target = current_value
100
+ question_readable = "\n".join(equations)
101
+ return target, question, question_readable
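Review note on `_get_target_and_question`: the target is built by applying each operation with an immediate modulo reduction, so intermediate values always stay in the range 0 to 9. A small worked sketch follows; `evaluate_chain` is a hypothetical helper that takes the digits and operators explicitly rather than sampling them.

```python
def evaluate_chain(first_digit, steps, modulo_base=10):
    """Apply (operator, digit) steps with modulo reduction after every step."""
    current = first_digit % modulo_base
    equations = []
    for operator, digit in steps:
        if operator == "+":
            new_value = (current + digit) % modulo_base
        elif operator == "-":
            # Python's % returns a non-negative result for a positive modulus.
            new_value = (current - digit) % modulo_base
        else:
            raise ValueError(f"unsupported operator: {operator!r}")
        equations.append(f"({current} {operator} {digit}) mod {modulo_base} = {new_value}")
        current = new_value
    return current, equations


target, equations = evaluate_chain(3, [("+", 9), ("-", 7)])
```

Here the chain is (3 + 9) mod 10 = 2, then (2 - 7) mod 10 = 5, so the target is 5.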
102
+
103
+ def __len__(self):
104
+ return len(self.base_dataset)
105
+
106
+ def __getitem__(self, idx):
107
+ images, targets = [],[]
108
+ for _ in range(self.current_num_digits):
109
+ image, target = self.base_dataset[np.random.randint(self.__len__())]
110
+ images.append(image)
111
+ targets.append(target)
112
+
113
+ observations = torch.repeat_interleave(torch.stack(images, 0), repeats=self.num_repeats_per_input, dim=0)
114
+ target, question, question_readable = self._get_target_and_question(targets)
115
+ return observations, question, question_readable, target
116
+
117
+ class ImageNet(Dataset):
118
+ def __init__(self, which_split, transform):
119
+ """
120
+ Thin wrapper around the Hugging Face `imagenet-1k` dataset.
121
+ Args:
122
+ which_split (str): Name of the split to load (passed to `load_dataset` as `split`).
123
+ transform (callable): Transform applied to each image after conversion to RGB.
124
+
125
+ Each item is an (image, target) pair: the transformed image and
126
+ its integer class label (the `label` field of the dataset).
127
+ """
128
+ dataset = load_dataset('imagenet-1k', split=which_split, trust_remote_code=True)
129
+
130
+ self.transform = transform
131
+ self.base_dataset = dataset
132
+
133
+ def __len__(self):
134
+ return len(self.base_dataset)
135
+
136
+ def __getitem__(self, idx):
137
+ data_item = self.base_dataset[idx]
138
+ image = self.transform(data_item['image'].convert('RGB'))
139
+ target = data_item['label']
140
+ return image, target
141
+
142
+ class MazeImageFolder(ImageFolder):
143
+ """
144
+ A custom dataset class that extends the ImageFolder class.
145
+
146
+ Args:
147
+ root (string): Root directory path.
148
+ transform (callable, optional): A function/transform that takes in
149
+ a sample and returns a transformed version.
150
+ E.g, ``transforms.RandomCrop`` for images.
151
+ target_transform (callable, optional): A function/transform that takes
152
+ in the target and transforms it.
153
+ loader (callable, optional): A function to load an image given its path.
154
+ is_valid_file (callable, optional): A function that takes path of an Image file
155
+ and checks if the file is a valid file (used to check for corrupt files)
156
+
157
+ Attributes:
158
+ classes (list): List of the class names.
159
+ class_to_idx (dict): Dict with items (class_name, class_index).
160
+ imgs (list): List of (image path, class_index) tuples
161
+ """
162
+
163
+ def __init__(self, root, transform=None, target_transform=None,
164
+ loader=Image.open,
165
+ is_valid_file=None,
166
+ which_set='train',
167
+ augment_p=0.5,
168
+ maze_route_length=10,
169
+ trunc=False,
170
+ expand_range=True):
171
+ super(MazeImageFolder, self).__init__(root, transform, target_transform, loader, is_valid_file)
172
+ self.which_set = which_set
173
+ self.augment_p = augment_p
174
+ self.maze_route_length = maze_route_length
175
+ self.all_paths = {}
176
+ self.trunc = trunc
177
+ self.expand_range = expand_range
178
+
179
+ self._preload()
180
+ print('Solving all mazes...')
181
+ for index in range(len(self.preloaded_samples)):
182
+ path = self.get_solution(self.preloaded_samples[index])
183
+ self.all_paths[index] = path
184
+
185
+ def _preload(self):
186
+ preloaded_samples = []
187
+ with tqdm(total=self.__len__(), initial=0, leave=True, position=0, dynamic_ncols=True) as pbar:
188
+
189
+ for index in range(self.__len__()):
190
+ pbar.set_description('Loading mazes')
191
+ path, target = self.samples[index]
192
+ sample = self.loader(path)
193
+ sample = np.array(sample).astype(np.float32)/255
194
+ preloaded_samples.append(sample)
195
+ pbar.update(1)
196
+ if self.trunc and index == 999: break
197
+ self.preloaded_samples = preloaded_samples
198
+
199
+ def __len__(self):
200
+ if hasattr(self, 'preloaded_samples') and self.preloaded_samples is not None:
201
+ return len(self.preloaded_samples)
202
+ else:
203
+ return super().__len__()
204
+
205
+ def get_solution(self, x):
206
+ x = np.copy(x)
207
+ # Find start (red) and end (green) pixel coordinates
208
+ start_coords = np.argwhere((x == [1, 0, 0]).all(axis=2))
209
+ end_coords = np.argwhere((x == [0, 1, 0]).all(axis=2))
210
+
211
+ if len(start_coords) == 0 or len(end_coords) == 0:
212
+ print("Start or end point not found.")
213
+ return None
214
+
215
+ start_y, start_x = start_coords[0]
216
+ end_y, end_x = end_coords[0]
217
+
218
+ current_y, current_x = start_y, start_x
219
+ path = [4] * self.maze_route_length
220
+
221
+ pi = 0
222
+ while (current_y, current_x) != (end_y, end_x):
223
+ next_y, next_x = -1, -1 # Initialize to invalid coordinates
224
+ direction = -1 # Initialize to an invalid direction
225
+
226
+
227
+ # Check Up
228
+ if current_y > 0 and ((x[current_y - 1, current_x] == [0, 0, 1]).all() or (x[current_y - 1, current_x] == [0, 1, 0]).all()):
229
+ next_y, next_x = current_y - 1, current_x
230
+ direction = 0
231
+
232
+ # Check Down
233
+ elif current_y < x.shape[0] - 1 and ((x[current_y + 1, current_x] == [0, 0, 1]).all() or (x[current_y + 1, current_x] == [0, 1, 0]).all()):
234
+ next_y, next_x = current_y + 1, current_x
235
+ direction = 1
236
+
237
+ # Check Left
238
+ elif current_x > 0 and ((x[current_y, current_x - 1] == [0, 0, 1]).all() or (x[current_y, current_x - 1] == [0, 1, 0]).all()):
239
+ next_y, next_x = current_y, current_x - 1
240
+ direction = 2
241
+
242
+ # Check Right
243
+ elif current_x < x.shape[1] - 1 and ((x[current_y, current_x + 1] == [0, 0, 1]).all() or (x[current_y, current_x + 1] == [0, 1, 0]).all()):
244
+ next_y, next_x = current_y, current_x + 1
245
+ direction = 3
246
+
247
+
248
+ path[pi] = direction
249
+ pi += 1
250
+
251
+ x[current_y, current_x] = [255,255,255] # mark the current as white to avoid going in circles
252
+ current_y, current_x = next_y, next_x
253
+ if pi == len(path):
254
+ break
255
+
256
+ return np.array(path)
257
+
258
+ def __getitem__(self, index):
259
+ """
260
+ Args:
261
+ index (int): Index
262
+
263
+ Returns:
264
+ tuple: (sample, target) where target is the maze solution path (direction indices 0-3: up, down, left, right; 4 = padding).
265
+ """
266
+
267
+ sample = np.copy(self.preloaded_samples[index])
268
+
269
+ path = np.copy(self.all_paths[index])
270
+
271
+ if self.which_set == 'train':
272
+ # Randomly rotate -90 or +90 degrees
273
+ if random.random() < self.augment_p:
274
+ which_rot = random.choice([-1, 1])
275
+ sample = np.rot90(sample, k=which_rot, axes=(0, 1))
276
+ for pi in range(len(path)):
277
+ if path[pi] == 0: path[pi] = 3 if which_rot == -1 else 2
278
+ elif path[pi] == 1: path[pi] = 2 if which_rot == -1 else 3
279
+ elif path[pi] == 2: path[pi] = 0 if which_rot == -1 else 1
280
+ elif path[pi] == 3: path[pi] = 1 if which_rot == -1 else 0
281
+
282
+
283
+ # Random horizontal flip
284
+ if random.random() < self.augment_p:
285
+ sample = np.fliplr(sample)
286
+ for pi in range(len(path)):
287
+ if path[pi] == 2: path[pi] = 3
288
+ elif path[pi] == 3: path[pi] = 2
289
+
290
+
291
+ # Random vertical flip
292
+ if random.random() < self.augment_p:
293
+ sample = np.flipud(sample)
294
+ for pi in range(len(path)):
295
+ if path[pi] == 0: path[pi] = 1
296
+ elif path[pi] == 1: path[pi] = 0
297
+
298
+ sample = torch.from_numpy(np.copy(sample)).permute(2,0,1)
299
+
300
+ blue_mask = (sample[0] == 0) & (sample[1] == 0) & (sample[2] == 1)
301
+
302
+ sample[:, blue_mask] = 1
303
+ target = path
304
+
305
+
306
+ if not self.expand_range:
307
+ return sample, target
308
+ return (sample*2)-1, (target)
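Review note on the augmentation in `MazeImageFolder.__getitem__`: whenever the image is flipped or rotated, the direction labels in the path must be remapped to match. The sketch below writes those remappings as lookup tables, a swapped-in alternative to the `if`/`elif` chains in the commit. The constant and table names are hypothetical; the mappings themselves are copied from the code above.

```python
UP, DOWN, LEFT, RIGHT, PAD = 0, 1, 2, 3, 4

HFLIP = {LEFT: RIGHT, RIGHT: LEFT}  # np.fliplr
VFLIP = {UP: DOWN, DOWN: UP}        # np.flipud
ROT_PLUS = {UP: LEFT, DOWN: RIGHT, LEFT: DOWN, RIGHT: UP}   # which_rot == 1
ROT_MINUS = {UP: RIGHT, DOWN: LEFT, LEFT: UP, RIGHT: DOWN}  # which_rot == -1


def remap(path, table):
    """Relabel every step; labels missing from the table (e.g. PAD) stay unchanged."""
    return [table.get(step, step) for step in path]


path = [UP, RIGHT, RIGHT, DOWN, PAD]
flipped = remap(path, HFLIP)
# Rotating one way and then back restores the original labels.
round_trip = remap(remap(path, ROT_PLUS), ROT_MINUS)
```

A table also makes it easy to check that each transform is a bijection on the four moves and leaves the padding label alone.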
309
+
310
+ class ParityDataset(Dataset):
311
+ def __init__(self, sequence_length=64, length=100000):
312
+ self.sequence_length = sequence_length
313
+ self.length = length
314
+
315
+ def __len__(self):
316
+ return self.length
317
+
318
+ def __getitem__(self, idx):
319
+ vector = 2 * torch.randint(0, 2, (self.sequence_length,)) - 1
320
+ vector = vector.float()
321
+ negatives = (vector == -1).to(torch.long)
322
+ cumsum = torch.cumsum(negatives, dim=0)
323
+ target = (cumsum % 2 != 0).to(torch.long)
324
+ return vector, target
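Review note on `ParityDataset`: the target at position t is 1 when the number of -1 entries seen up to and including t is odd. A pure-Python sketch of the same cumulative-parity computation, using a hypothetical `cumulative_parity` helper in place of `torch.cumsum`:

```python
from itertools import accumulate


def cumulative_parity(vector):
    """Return 1 at each position where the running count of -1 entries is odd."""
    negatives = [1 if v == -1 else 0 for v in vector]
    return [count % 2 for count in accumulate(negatives)]


# negatives = [0, 1, 1, 0, 1], running counts = [0, 1, 2, 2, 3]
target = cumulative_parity([1, -1, -1, 1, -1])
```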
examples/01_mnist.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/02_inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/03_mazes.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/04_parity.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/05_huggingface.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/README.md ADDED
@@ -0,0 +1,7 @@
1
+ # Continuous Thought Machines
2
+ ## Models
3
+
4
+ This folder contains all model-related code.
5
+
6
+ Some notes for clarity:
7
+ 1. The resnet structure we used (see resnet.py) has a few minor changes that enable constraining the receptive field of the features yielded. We do this because we want the CTM (or baseline methods) to learn a process whereby they gather information. Neural networks that use SGD will find the [path of least resistance](https://era.ed.ac.uk/handle/1842/39606), even if that path doesn't result in actually intelligent behaviour. Constraining the receptive field helps to prevent this, a bit.
models/constants.py ADDED
@@ -0,0 +1,10 @@
1
+ VALID_NEURON_SELECT_TYPES = ['first-last', 'random', 'random-pairing']
2
+
3
+ VALID_BACKBONE_TYPES = [
4
+ f'resnet{depth}-{i}' for depth in [18, 34, 50, 101, 152] for i in range(1, 5)
5
+ ] + ['shallow-wide', 'parity_backbone']
6
+
7
+ VALID_POSITIONAL_EMBEDDING_TYPES = [
8
+ 'learnable-fourier', 'multi-learnable-fourier',
9
+ 'custom-rotational', 'custom-rotational-1d'
10
+ ]
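Review note on `models/constants.py`: the backbone list is generated by a nested comprehension, giving four variants for each of five ResNet depths plus two special backbones. The sketch below reproduces the list and adds a membership check. `check_backbone_type` is a hypothetical helper; how the repository validates this argument (for example in `verify_args`) is not shown in this part of the diff.

```python
VALID_BACKBONE_TYPES = [
    f"resnet{depth}-{i}" for depth in [18, 34, 50, 101, 152] for i in range(1, 5)
] + ["shallow-wide", "parity_backbone"]


def check_backbone_type(name):
    """Raise ValueError for names outside VALID_BACKBONE_TYPES (hypothetical helper)."""
    if name not in VALID_BACKBONE_TYPES:
        raise ValueError(f"Invalid backbone_type: {name!r}")
    return name


n_types = len(VALID_BACKBONE_TYPES)  # 5 depths x 4 variants + 2 extras
first_four = VALID_BACKBONE_TYPES[:4]
```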
models/ctm.py ADDED
@@ -0,0 +1,633 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+ import math
5
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
6
+
7
+ from models.modules import ParityBackbone, SynapseUNET, Squeeze, SuperLinear, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
8
+ from models.resnet import prepare_resnet_backbone
9
+ from models.utils import compute_normalized_entropy
10
+
11
+ from models.constants import (
12
+ VALID_NEURON_SELECT_TYPES,
13
+ VALID_BACKBONE_TYPES,
14
+ VALID_POSITIONAL_EMBEDDING_TYPES
15
+ )
16
+
17
+ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
18
+ """
19
+ Continuous Thought Machine (CTM).
20
+
21
+ Technical report: https://arxiv.org/abs/2505.05522
22
+
23
+ Interactive Website: https://pub.sakana.ai/ctm/
24
+
25
+ Blog: https://sakana.ai/ctm/
26
+
27
+ Thought takes time and reasoning is a process.
28
+
29
+ The CTM consists of three main ideas:
30
+ 1. The use of internal recurrence, enabling a dimension over which a concept analogous to thought can occur.
31
+ 2. Neuron-level models, that compute post-activations by applying private (i.e., on a per-neuron basis) MLP
32
+ models to a history of incoming pre-activations.
33
+ 3. Synchronisation as representation, where the neural activity over time is tracked and used to compute how
34
+ pairs of neurons synchronise with one another over time. This measure of synchronisation is the representation
35
+ with which the CTM takes action and makes predictions.
36
+
37
+
38
+ Args:
39
+ iterations (int): Number of internal 'thought' ticks (T, in paper).
40
+ d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
41
+ NOTE: this is NOT the representation used for action or prediction, but rather that which
42
+ is fully internal to the model and not directly connected to data.
43
+ d_input (int): Dimensionality of projected attention outputs or direct input features.
44
+ heads (int): Number of attention heads.
45
+ n_synch_out (int): Number of neurons used for output synchronisation (D_out, in paper).
46
+ n_synch_action (int): Number of neurons used for action/attention synchronisation (D_action, in paper).
47
+ synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP).
48
+ memory_length (int): History length for Neuron-Level Models (M, in paper).
49
+ deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear.
50
+ NOTE: we almost always use deep NLMs, but a linear NLM is faster.
51
+ memory_hidden_dims (int): Hidden dimension size for deep NLMs.
52
+ do_layernorm_nlm (bool): Apply LayerNorm within NLMs.
53
+ NOTE: we never set this to true in the paper. If you set this to true you will get strange behaviour,
54
+ but you can potentially encourage more periodic behaviour in the dynamics. Untested; be careful.
55
+ backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
56
+ positional_embedding_type (str): Type of positional embedding for backbone features.
57
+ out_dims (int): Output dimension size.
58
+ NOTE: projected from synchronisation!
59
+ prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
60
+ NOTE: this is used to compute certainty and is needed when applying softmax for probabilities
61
+ dropout (float): Dropout rate.
62
+ neuron_select_type (str): Neuron selection strategy ('first-last', 'random', 'random-pairing').
63
+ NOTE: some of this is legacy from our experimentation, but all three strategies are valid and useful.
64
+ We delineate exactly which strategies we use per experiment in the paper.
65
+ - first-last: build a 'dense' sync matrix for output from the first D_out neurons and action from the
66
+ last D_action neurons. Flatten this matrix into the synchronisation representation.
67
+ This approach shares relationships for neurons and bottlenecks the gradients through them.
68
+ NOTE: the synchronisation size will be (D_out/action * (D_out/action + 1))/2
69
+ - random: randomly select D_out neurons for the 'i' side pairings, and also D_out for the 'j' side pairings,
70
+ also pairing those across densely, resulting in a bottleneck roughly 2x as wide.
71
+ NOTE: the synchronisation size will be (D_out/action * (D_out/action + 1))/2
72
+ - random-pairing (DEFAULT!): randomly select D_out neurons and pair these with another D_out neurons.
73
+ This results in much less bottlenecking and is the most up-to-date variant.
74
+ NOTE: the synchronisation size will be D_out in this case; better control.
75
+ n_random_pairing_self (int): Number of neurons to select for self-to-self synch when random-pairing is used.
76
+ NOTE: when using random-pairing, i-to-i (self) synchronisation is rare, meaning that 'recovering a
77
+ snapshot representation' (see paper) is difficult. This alleviates that.
78
+ NOTE: works fine when set to 0.
79
+ """
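Review note on the docstring above: it states two sizes for the synchronisation representation, D(D+1)/2 for the dense `first-last` and `random` strategies and D for `random-pairing`. The sketch below only turns those stated formulas into code. `synch_size` is a hypothetical function; the class's own `calculate_synch_representation_size` is not shown in this part of the diff and may differ.

```python
def synch_size(neuron_select_type, n_synch):
    """Size of the synchronisation vector, following the formulas in the docstring."""
    if neuron_select_type == "random-pairing":
        # One entry per selected pair.
        return n_synch
    if neuron_select_type in ("first-last", "random"):
        # Upper triangle of an n x n matrix, diagonal included.
        return n_synch * (n_synch + 1) // 2
    raise ValueError(f"Invalid neuron_select_type: {neuron_select_type!r}")


sizes = {kind: synch_size(kind, 32) for kind in ("first-last", "random", "random-pairing")}
```

For 32 neurons the dense variants give 528 values while `random-pairing` gives 32, which is the "better control" the docstring refers to.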
80
+
81
+ def __init__(self,
82
+ iterations,
83
+ d_model,
84
+ d_input,
85
+ heads,
86
+ n_synch_out,
87
+ n_synch_action,
88
+ synapse_depth,
89
+ memory_length,
90
+ deep_nlms,
91
+ memory_hidden_dims,
92
+ do_layernorm_nlm,
93
+ backbone_type,
94
+ positional_embedding_type,
95
+ out_dims,
96
+ prediction_reshaper=[-1],
97
+ dropout=0,
98
+ dropout_nlm=None,
99
+ neuron_select_type='random-pairing',
100
+ n_random_pairing_self=0,
101
+ ):
102
+ super(ContinuousThoughtMachine, self).__init__()
103
+
104
+ # --- Core Parameters ---
105
+ self.iterations = iterations
106
+ self.d_model = d_model
107
+ self.d_input = d_input
108
+ self.memory_length = memory_length
109
+ self.prediction_reshaper = prediction_reshaper
110
+ self.n_synch_out = n_synch_out
111
+ self.n_synch_action = n_synch_action
112
+ self.backbone_type = backbone_type
113
+ self.out_dims = out_dims
114
+ self.positional_embedding_type = positional_embedding_type
115
+ self.neuron_select_type = neuron_select_type
116
+ self.memory_length = memory_length
117
+ dropout_nlm = dropout if dropout_nlm is None else dropout_nlm
118
+
119
+ # --- Assertions ---
120
+ self.verify_args()
121
+
122
+ # --- Input Processing ---
123
+ d_backbone = self.get_d_backbone()
124
+ self.set_initial_rgb()
125
+ self.set_backbone()
126
+ self.positional_embedding = self.get_positional_embedding(d_backbone)
127
+ self.kv_proj = nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input)) if heads else None
128
+ self.q_proj = nn.LazyLinear(self.d_input) if heads else None
129
+ self.attention = nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True) if heads else None
130
+
131
+ # --- Core CTM Modules ---
132
+ self.synapses = self.get_synapses(synapse_depth, d_model, dropout)
133
+ self.trace_processor = self.get_neuron_level_models(deep_nlms, do_layernorm_nlm, memory_length, memory_hidden_dims, d_model, dropout_nlm)
134
+
135
+ # --- Start States ---
136
+ self.register_parameter('start_activated_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model)))))
137
+ self.register_parameter('start_trace', nn.Parameter(torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length)))))
138
+
139
+ # --- Synchronisation ---
140
+ self.neuron_select_type_out, self.neuron_select_type_action = self.get_neuron_select_type()
141
+ self.synch_representation_size_action = self.calculate_synch_representation_size(self.n_synch_action)
142
+ self.synch_representation_size_out = self.calculate_synch_representation_size(self.n_synch_out)
143
+
144
+ for synch_type, size in (('action', self.synch_representation_size_action), ('out', self.synch_representation_size_out)):
145
+ print(f"Synch representation size {synch_type}: {size}")
146
+ if self.synch_representation_size_action: # if not zero
147
+ self.set_synchronisation_parameters('action', self.n_synch_action, n_random_pairing_self)
148
+ self.set_synchronisation_parameters('out', self.n_synch_out, n_random_pairing_self)
149
+
150
+ # --- Output Processing ---
151
+ self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))
152
+
153
+ @classmethod
154
+ def _from_pretrained(
155
+ cls,
156
+ *,
157
+ model_id: str,
158
+ revision=None,
159
+ cache_dir=None,
160
+ force_download=False,
161
+ proxies=None,
162
+ resume_download=None,
163
+ local_files_only=False,
164
+ token=None,
165
+ map_location="cpu",
166
+ strict=False,
167
+ **model_kwargs,
168
+ ):
169
+ """Override to handle lazy weights initialization."""
170
+ model = cls(**model_kwargs).to(map_location)
171
+
172
+ # The CTM contains Lazy modules, so we must run a dummy forward pass to initialize them
173
+ if "imagenet" in model_id:
174
+ dummy_input = torch.randn(1, 3, 224, 224, device=map_location)
175
+ elif "maze-large" in model_id:
176
+ dummy_input = torch.randn(1, 3, 99, 99, device=map_location)
177
+ else:
178
+ raise NotImplementedError
179
+
180
+ with torch.no_grad():
181
+ _ = model(dummy_input)
182
+
183
+ model_file = hf_hub_download(
184
+ repo_id=model_id,
185
+ filename="model.safetensors",
186
+ revision=revision,
187
+ cache_dir=cache_dir,
188
+ force_download=force_download,
189
+ proxies=proxies,
190
+ resume_download=resume_download,
191
+ token=token,
192
+ local_files_only=local_files_only,
193
+ )
194
+ from safetensors.torch import load_model as load_model_as_safetensor
195
+ load_model_as_safetensor(model, model_file, strict=strict, device=map_location)
196
+
197
+ model.eval()
198
+ return model
199
+
200
+ # --- Core CTM Methods ---
201
+
202
+ def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, r, synch_type):
203
+ """
204
+ Computes synchronisation to be used as a vector representation.
205
+
206
+ A neuron has what we call a 'trace', which is a history (time series) that changes with internal
207
+ recurrence. i.e., it gets longer with every internal tick. There are pre-activation traces
208
+ that are used in the NLMs and post-activation traces that, in theory, are used in this method.
209
+
210
+ We define synchronisation between neuron i and j as the dot product between their respective
211
+ time series. Since there can be many internal ticks, this process can be quite compute heavy as it
212
+ involves many dot products that repeat computation at each step.
213
+
214
+ Therefore, in practice, we update the synchronisation based on the current post-activations,
215
+ which we call the 'activated state' here. This is possible because the inputs to synchronisation
216
+ are only updated recurrently at each step, meaning that there is a linear recurrence we can
217
+ leverage.
218
+
219
+ See Appendix TODO of the Technical Report (TODO:LINK) for the maths that enables this method.
220
+ """
221
+
222
+ if synch_type == 'action': # Get action parameters
223
+ n_synch = self.n_synch_action
224
+ neuron_indices_left = self.action_neuron_indices_left
225
+ neuron_indices_right = self.action_neuron_indices_right
226
+ elif synch_type == 'out': # Get output parameters
227
+ n_synch = self.n_synch_out
228
+ neuron_indices_left = self.out_neuron_indices_left
229
+ neuron_indices_right = self.out_neuron_indices_right
230
+
231
+ if self.neuron_select_type in ('first-last', 'random'):
232
+ # For first-last and random, we compute the pairwise sync between all selected neurons
233
+ if self.neuron_select_type == 'first-last':
234
+ if synch_type == 'action': # Use last n_synch neurons for action
235
+ selected_left = selected_right = activated_state[:, -n_synch:]
236
+ elif synch_type == 'out': # Use first n_synch neurons for out
237
+ selected_left = selected_right = activated_state[:, :n_synch]
238
+ else: # Use the randomly selected neurons
239
+ selected_left = activated_state[:, neuron_indices_left]
240
+ selected_right = activated_state[:, neuron_indices_right]
241
+
242
+ # Compute outer product of selected neurons
243
+ outer = selected_left.unsqueeze(2) * selected_right.unsqueeze(1)
244
+ # Resulting matrix is symmetric, so we only need the upper triangle
245
+ i, j = torch.triu_indices(n_synch, n_synch)
246
+ pairwise_product = outer[:, i, j]
247
+
248
+ elif self.neuron_select_type == 'random-pairing':
249
+ # For random-pairing, we compute the sync between specific pairs of neurons
250
+ left = activated_state[:, neuron_indices_left]
251
+ right = activated_state[:, neuron_indices_right]
252
+ pairwise_product = left * right
253
+ else:
254
+ raise ValueError("Invalid neuron selection type")
255
+
256
+
257
+
258
+ # Compute synchronisation recurrently
259
+ if decay_alpha is None or decay_beta is None:
260
+ decay_alpha = pairwise_product
261
+ decay_beta = torch.ones_like(pairwise_product)
262
+ else:
263
+ decay_alpha = r * decay_alpha + pairwise_product
264
+ decay_beta = r * decay_beta + 1
265
+
266
+ synchronisation = decay_alpha / (torch.sqrt(decay_beta))
267
+ return synchronisation, decay_alpha, decay_beta
268
+
269
+ def compute_features(self, x):
270
+ """
271
+ Compute the key-value features from the input data using the backbone.
272
+ """
273
+ initial_rgb = self.initial_rgb(x)
274
+ self.kv_features = self.backbone(initial_rgb)
275
+ pos_emb = self.positional_embedding(self.kv_features)
276
+ combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2)
277
+ kv = self.kv_proj(combined_features)
278
+ return kv
279
+
280
+ def compute_certainty(self, current_prediction):
281
+ """
282
+ Compute the certainty of the current prediction.
283
+
284
+ We define certainty as 1 minus the normalised entropy.
285
+
286
+ For legacy reasons we stack that in a 2D vector as this can be used for optimisation later.
287
+ """
288
+ B = current_prediction.size(0)
289
+ reshaped_pred = current_prediction.reshape([B] + self.prediction_reshaper)
290
+ ne = compute_normalized_entropy(reshaped_pred)
291
+ current_certainty = torch.stack((ne, 1-ne), -1)
292
+ return current_certainty
293
+
294
+ # --- Setup Methods ---
295
+
296
+ def set_initial_rgb(self):
297
+ """
298
+ This is largely to accommodate training on grayscale images and is legacy, but it
299
+ doesn't hurt the model in any way that we can tell.
300
+ """
301
+ if 'resnet' in self.backbone_type:
302
+ self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
303
+ else:
304
+ self.initial_rgb = nn.Identity()
305
+
306
+ def get_d_backbone(self):
307
+ """
308
+ Get the dimensionality of the backbone output, to be used for positional embedding setup.
309
+
310
+ This is a little bit complicated for resnets, but the logic should be easy enough to read below.
311
+ """
312
+ if self.backbone_type == 'shallow-wide':
313
+ return 2048
314
+ elif self.backbone_type == 'parity_backbone':
315
+ return self.d_input
316
+ elif 'resnet' in self.backbone_type:
317
+ if '18' in self.backbone_type or '34' in self.backbone_type:
318
+ if self.backbone_type.split('-')[1]=='1': return 64
319
+ elif self.backbone_type.split('-')[1]=='2': return 128
320
+ elif self.backbone_type.split('-')[1]=='3': return 256
321
+ elif self.backbone_type.split('-')[1]=='4': return 512
322
+ else:
323
+ raise NotImplementedError
324
+ else:
325
+ if self.backbone_type.split('-')[1]=='1': return 256
326
+ elif self.backbone_type.split('-')[1]=='2': return 512
327
+ elif self.backbone_type.split('-')[1]=='3': return 1024
328
+ elif self.backbone_type.split('-')[1]=='4': return 2048
329
+ else:
330
+ raise NotImplementedError
331
+ elif self.backbone_type == 'none':
332
+ return None
333
+ else:
334
+ raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
335
+
336
+ def set_backbone(self):
337
+ """
338
+ Set the backbone module based on the specified type.
339
+ """
340
+ if self.backbone_type == 'shallow-wide':
341
+ self.backbone = ShallowWide()
342
+ elif self.backbone_type == 'parity_backbone':
343
+ d_backbone = self.get_d_backbone()
344
+ self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone)
345
+ elif 'resnet' in self.backbone_type:
346
+ self.backbone = prepare_resnet_backbone(self.backbone_type)
347
+ elif self.backbone_type == 'none':
348
+ self.backbone = nn.Identity()
349
+ else:
350
+ raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
351
+
352
+ def get_positional_embedding(self, d_backbone):
353
+ """
354
+ Get the positional embedding module.
355
+
356
+ For Imagenet and mazes we used NO positional embedding, and largely don't think
357
+ that it is necessary as the CTM can build up its own internal world model when
358
+ observing.
359
+
360
+ LearnableFourierPositionalEncoding:
361
+ Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
362
+ Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
363
+ Provides positional information for 2D feature maps.
364
+
365
+ (MultiLearnableFourierPositionalEncoding uses multiple feature scales)
366
+
367
+ CustomRotationalEmbedding:
368
+ Simple sinusoidal embedding to encourage interpretability
369
+ """
370
+ if self.positional_embedding_type == 'learnable-fourier':
371
+ return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5)
372
+ elif self.positional_embedding_type == 'multi-learnable-fourier':
373
+ return MultiLearnableFourierPositionalEncoding(d_backbone)
374
+ elif self.positional_embedding_type == 'custom-rotational':
375
+ return CustomRotationalEmbedding(d_backbone)
376
+ elif self.positional_embedding_type == 'custom-rotational-1d':
377
+ return CustomRotationalEmbedding1D(d_backbone)
378
+ elif self.positional_embedding_type == 'none':
379
+ return lambda x: 0 # Default no-op
380
+ else:
381
+ raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}")
382
+
383
+ def get_neuron_level_models(self, deep_nlms, do_layernorm_nlm, memory_length, memory_hidden_dims, d_model, dropout):
384
+ """
385
+ Neuron level models are one of the core innovations of the CTM. They apply separate MLPs/linears to
386
+ each neuron.
387
+ NOTE: the name 'SuperLinear' is largely legacy, but its purpose is to apply separate linear layers
388
+ per neuron. It is sort of a 'grouped linear' function, where the group size is equal to 1.
389
+ One could make the group size bigger and use fewer parameters, but that is future work.
390
+
391
+ NOTE: We used GLU() nonlinearities because they worked well in practice.
392
+ """
393
+ if deep_nlms:
394
+ return nn.Sequential(
395
+ nn.Sequential(
396
+ SuperLinear(in_dims=memory_length, out_dims=2 * memory_hidden_dims, N=d_model,
397
+ do_norm=do_layernorm_nlm, dropout=dropout),
398
+ nn.GLU(),
399
+ SuperLinear(in_dims=memory_hidden_dims, out_dims=2, N=d_model,
400
+ do_norm=do_layernorm_nlm, dropout=dropout),
401
+ nn.GLU(),
402
+ Squeeze(-1)
403
+ )
404
+ )
405
+ else:
406
+ return nn.Sequential(
407
+ nn.Sequential(
408
+ SuperLinear(in_dims=memory_length, out_dims=2, N=d_model,
409
+ do_norm=do_layernorm_nlm, dropout=dropout),
410
+ nn.GLU(),
411
+ Squeeze(-1)
412
+ )
413
+ )
414
+
415
+ def get_synapses(self, synapse_depth, d_model, dropout):
416
+ """
417
+ The synapse model is the recurrent model in the CTM. Its purpose is to share information
418
+ across neurons. If using depth of 1, this is just a simple single layer with nonlinearity and layernorm.
419
+ For deeper synapse models we use a U-NET structure with many skip connections. In practice this performs
420
+ better as it enables multi-level information mixing.
421
+
422
+ The intuition with having a deep UNET model for synapses is that the action of synaptic connections is
423
+ not necessarily a linear one, and that approximating a synapse 'update' step in the brain is non-trivial.
424
+ Hence, we set it up so that the CTM can learn some complex internal rule instead of trying to approximate
425
+ it ourselves.
426
+ """
427
+ if synapse_depth == 1:
428
+ return nn.Sequential(
429
+ nn.Dropout(dropout),
430
+ nn.LazyLinear(d_model * 2),
431
+ nn.GLU(),
432
+ nn.LayerNorm(d_model)
433
+ )
434
+ else:
435
+ return SynapseUNET(d_model, synapse_depth, 16, dropout) # hard-coded minimum width of 16; future work TODO.
436
+
437
+ def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
438
+ """
439
+ 1. Set the buffers for selecting neurons so that these indices are saved into the model state_dict.
440
+ 2. Set the parameters for learnable exponential decay when computing synchronisation between all
441
+ neurons.
442
+ """
443
+ assert synch_type in ('out', 'action'), f"Invalid synch_type: {synch_type}"
444
+ left, right = self.initialize_left_right_neurons(synch_type, self.d_model, n_synch, n_random_pairing_self)
445
+ synch_representation_size = self.synch_representation_size_action if synch_type == 'action' else self.synch_representation_size_out
446
+ self.register_buffer(f'{synch_type}_neuron_indices_left', left)
447
+ self.register_buffer(f'{synch_type}_neuron_indices_right', right)
448
+ self.register_parameter(f'decay_params_{synch_type}', nn.Parameter(torch.zeros(synch_representation_size), requires_grad=True))
449
+
450
+ def initialize_left_right_neurons(self, synch_type, d_model, n_synch, n_random_pairing_self=0):
451
+ """
452
+ Initialize the left and right neuron indices based on the neuron selection type.
453
+ This complexity stems from legacy experiments, but we maintain that these types of
454
+ neuron selections are interesting to experiment with.
455
+ """
456
+ if self.neuron_select_type=='first-last':
457
+ if synch_type == 'out':
458
+ neuron_indices_left = neuron_indices_right = torch.arange(0, n_synch)
459
+ elif synch_type == 'action':
460
+ neuron_indices_left = neuron_indices_right = torch.arange(d_model-n_synch, d_model)
461
+
462
+ elif self.neuron_select_type=='random':
463
+ neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
464
+ neuron_indices_right = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
465
+
466
+ elif self.neuron_select_type=='random-pairing':
467
+ assert n_synch > n_random_pairing_self, f"Need at least {n_random_pairing_self} pairs for {self.neuron_select_type}"
468
+ neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
469
+ neuron_indices_right = torch.concatenate((neuron_indices_left[:n_random_pairing_self], torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch-n_random_pairing_self))))
470
+
471
+ device = self.start_activated_state.device
472
+ return neuron_indices_left.to(device), neuron_indices_right.to(device)
473
+
474
+ def get_neuron_select_type(self):
475
+ """
476
+ Another helper method to accommodate our legacy neuron selection types.
477
+ TODO: additional experimentation and possible removal of 'first-last' and 'random'
478
+ """
479
+ print(f"Using neuron select type: {self.neuron_select_type}")
480
+ if self.neuron_select_type == 'first-last':
481
+ neuron_select_type_out, neuron_select_type_action = 'first', 'last'
482
+ elif self.neuron_select_type in ('random', 'random-pairing'):
483
+ neuron_select_type_out = neuron_select_type_action = self.neuron_select_type
484
+ else:
485
+ raise ValueError(f"Invalid neuron selection type: {self.neuron_select_type}")
486
+ return neuron_select_type_out, neuron_select_type_action
487
+
488
+ # --- Utility Methods ---
489
+
490
+ def verify_args(self):
491
+ """
492
+ Verify the validity of the input arguments to ensure consistent behaviour.
493
+ Specifically when selecting neurons for synchronisation using 'first-last' or 'random',
494
+ one needs the right number of neurons
495
+ """
496
+ assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
497
+ f"Invalid neuron selection type: {self.neuron_select_type}"
498
+
499
+ assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \
500
+ f"Invalid backbone_type: {self.backbone_type}"
501
+
502
+ assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \
503
+ f"Invalid positional_embedding_type: {self.positional_embedding_type}"
504
+
505
+ if self.neuron_select_type == 'first-last':
506
+ assert self.d_model >= (self.n_synch_out + self.n_synch_action), \
507
+ "d_model must be >= n_synch_out + n_synch_action for neuron subsets"
508
+
509
+ if self.backbone_type=='none' and self.positional_embedding_type!='none':
510
+ raise AssertionError("There should be no positional embedding if there is no backbone.")
511
+
512
+ def calculate_synch_representation_size(self, n_synch):
513
+ """
514
+ Calculate the size of the synchronisation representation based on neuron selection type.
515
+ """
516
+ if self.neuron_select_type == 'random-pairing':
517
+ synch_representation_size = n_synch
518
+ elif self.neuron_select_type in ('first-last', 'random'):
519
+ synch_representation_size = (n_synch * (n_synch + 1)) // 2
520
+ else:
521
+ raise ValueError(f"Invalid neuron selection type: {self.neuron_select_type}")
522
+ return synch_representation_size
523
+
524
+
525
+
526
+
527
+ def forward(self, x, track=False, adaptive_halt=True, certainty_threshold=0.85):
528
+ """
529
+ Forward pass through CTM.
530
+
531
+ v3.0 Features:
532
+ - Adaptive Halting: Stop early if certainty > threshold
533
+ - Tracks actual ticks used internally (may be < self.iterations); the count is not part of the return value
534
+
535
+ Args:
536
+ x: Input tensor
537
+ track: Whether to track intermediate states
538
+ adaptive_halt: Enable early stopping (v3.0)
539
+ certainty_threshold: Halt if certainty exceeds this (v3.0)
540
+ """
541
+ B = x.size(0)
542
+ device = x.device
543
+
544
+ # --- Tracking Initialization ---
545
+ pre_activations_tracking = []
546
+ post_activations_tracking = []
547
+ synch_out_tracking = []
548
+ synch_action_tracking = []
549
+ attention_tracking = []
550
+
551
+ # --- Featurise Input Data ---
552
+ kv = self.compute_features(x)
553
+
554
+ # --- Initialise Recurrent State ---
555
+ state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
556
+ activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
557
+
558
+ # --- Prepare Storage for Outputs per Iteration ---
559
+ predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=torch.float32)
560
+ certainties = torch.empty(B, 2, self.iterations, device=device, dtype=torch.float32)
561
+
562
+ # --- Initialise Recurrent Synch Values ---
563
+ decay_alpha_action, decay_beta_action = None, None
564
+ self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15) # Fix from github user: kuviki
565
+ self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
566
+ r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
567
+
568
+ _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
569
+ # Compute learned weighting for synchronisation
570
+
571
+ # v3.0: Track actual ticks used
572
+ ticks_used = self.iterations
573
+ halted_early = False
574
+
575
+ # --- Recurrent Loop ---
576
+ for stepi in range(self.iterations):
577
+
578
+ # --- Calculate Synchronisation for Input Data Interaction ---
579
+ synchronisation_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
580
+
581
+ # --- Interact with Data via Attention ---
582
+ q = self.q_proj(synchronisation_action).unsqueeze(1)
583
+ attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
584
+ attn_out = attn_out.squeeze(1)
585
+ pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
586
+
587
+ # --- Apply Synapses ---
588
+ state = self.synapses(pre_synapse_input)
589
+ # The 'state_trace' is the history of incoming pre-activations
590
+ state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
591
+
592
+ # --- Apply Neuron-Level Models ---
593
+ activated_state = self.trace_processor(state_trace)
594
+ # One would also keep an 'activated_state_trace' as the history of outgoing post-activations
595
+ # BUT, this is unnecessary because the synchronisation calculation is fully linear and can be
596
+ # done using only the current activated state (see compute_synchronisation method for explanation)
597
+
598
+ # --- Calculate Synchronisation for Output Predictions ---
599
+ synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
600
+
601
+ # --- Get Predictions and Certainties ---
602
+ current_prediction = self.output_projector(synchronisation_out)
603
+ current_certainty = self.compute_certainty(current_prediction)
604
+
605
+ predictions[..., stepi] = current_prediction
606
+ certainties[..., stepi] = current_certainty
607
+
608
+ # --- Tracking ---
609
+ if track:
610
+ pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
611
+ post_activations_tracking.append(activated_state.detach().cpu().numpy())
612
+ attention_tracking.append(attn_weights.detach().cpu().numpy())
613
+ synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
614
+ synch_action_tracking.append(synchronisation_action.detach().cpu().numpy())
615
+
616
+ # --- v3.0: Adaptive Halting ---
617
+ if adaptive_halt and not self.training:
618
+ # Check whether the mean certainty across the batch exceeds the threshold
619
+ batch_certainty = current_certainty[:, 1].mean().item() # 1-entropy
620
+ if batch_certainty > certainty_threshold:
621
+ ticks_used = stepi + 1
622
+ halted_early = True
623
+ # Fill remaining predictions with current value
624
+ for remaining in range(stepi + 1, self.iterations):
625
+ predictions[..., remaining] = current_prediction
626
+ certainties[..., remaining] = current_certainty
627
+ break
628
+
629
+ # --- Return Values ---
630
+ if track:
631
+ return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
632
+ return predictions, certainties, synchronisation_out
633
+
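The docstring of `compute_synchronisation` says that the decay-weighted dot product between two neuron traces can be updated step by step instead of being recomputed over the whole history. The sketch below checks that claim on plain Python lists. It is an illustration only, not the repository's implementation: the names `sync_direct` and `sync_recurrent` are hypothetical, and `r` is assumed to play the role of `exp(-decay_params)` for a single neuron pair.

```python
import math

def sync_direct(trace_i, trace_j, r):
    """Decay-weighted dot product over the full history (hypothetical helper).

    The most recent step has weight 1, the one before it r, then r**2, and so on.
    The result is normalised by the square root of the summed weights.
    """
    T = len(trace_i)
    weights = [r ** (T - 1 - t) for t in range(T)]
    numerator = sum(w * a * b for w, a, b in zip(weights, trace_i, trace_j))
    return numerator / math.sqrt(sum(weights))

def sync_recurrent(trace_i, trace_j, r):
    """Same quantity, updated one tick at a time as in the diff above."""
    alpha, beta = None, None
    for a, b in zip(trace_i, trace_j):
        product = a * b
        if alpha is None:
            # First tick: initialise the running sums.
            alpha, beta = product, 1.0
        else:
            # Later ticks: decay the running sums, then add the new term.
            alpha = r * alpha + product
            beta = r * beta + 1.0
    return alpha / math.sqrt(beta)

# Example post-activation histories for two neurons (made-up values).
z_i = [0.5, -1.0, 2.0, 0.25]
z_j = [1.5, 0.5, -0.75, 3.0]
r = math.exp(-0.3)  # assumed decay, mirroring r = exp(-decay_params)
gap = abs(sync_direct(z_i, z_j, r) - sync_recurrent(z_i, z_j, r))
```

Both functions produce the same value up to floating-point error, which is why the forward pass only needs the current activated state and the two running sums rather than a stored post-activation trace.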
models/ctm_qamnist.py ADDED
@@ -0,0 +1,208 @@
1
+ import torch
2
+ import numpy as np
3
+ from models.ctm import ContinuousThoughtMachine
4
+ from models.modules import MNISTBackbone, QAMNISTIndexEmbeddings, QAMNISTOperatorEmbeddings
5
+
6
+ class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine):
7
+ def __init__(self,
8
+ iterations,
9
+ d_model,
10
+ d_input,
11
+ heads,
12
+ n_synch_out,
13
+ n_synch_action,
14
+ synapse_depth,
15
+ memory_length,
16
+ deep_nlms,
17
+ memory_hidden_dims,
18
+ do_layernorm_nlm,
19
+ out_dims,
20
+ iterations_per_digit,
21
+ iterations_per_question_part,
22
+ iterations_for_answering,
23
+ prediction_reshaper=[-1],
24
+ dropout=0,
25
+ neuron_select_type='first-last',
26
+ n_random_pairing_self=256
27
+ ):
28
+ super().__init__(
29
+ iterations=iterations,
30
+ d_model=d_model,
31
+ d_input=d_input,
32
+ heads=heads,
33
+ n_synch_out=n_synch_out,
34
+ n_synch_action=n_synch_action,
35
+ synapse_depth=synapse_depth,
36
+ memory_length=memory_length,
37
+ deep_nlms=deep_nlms,
38
+ memory_hidden_dims=memory_hidden_dims,
39
+ do_layernorm_nlm=do_layernorm_nlm,
40
+ out_dims=out_dims,
41
+ prediction_reshaper=prediction_reshaper,
42
+ dropout=dropout,
43
+ neuron_select_type=neuron_select_type,
44
+ n_random_pairing_self=n_random_pairing_self,
45
+ backbone_type='none',
46
+ positional_embedding_type='none',
47
+ )
48
+
49
+ # --- Core Parameters ---
50
+ self.iterations_per_digit = iterations_per_digit
51
+ self.iterations_per_question_part = iterations_per_question_part
52
+ self.iterations_for_answering = iterations_for_answering
53
+
54
+ # --- Setup Methods ---
55
+
56
+ def set_initial_rgb(self):
57
+ """Set the initial RGB values for the backbone."""
58
+ return None
59
+
60
+ def get_d_backbone(self):
61
+ """Get the dimensionality of the backbone output."""
62
+ return self.d_input
63
+
64
+ def set_backbone(self):
65
+ """Set the backbone module based on the specified type."""
66
+ self.backbone_digit = MNISTBackbone(self.d_input)
67
+ self.index_backbone = QAMNISTIndexEmbeddings(50, self.d_input)
68
+ self.operator_backbone = QAMNISTOperatorEmbeddings(2, self.d_input)
69
+ pass
70
+
71
+ # --- Utility Methods ---
72
+
73
+ def determine_step_type(self, total_iterations_for_digits, total_iterations_for_question, stepi: int):
74
+ """Determine whether the current step is for digits, questions, or answers."""
75
+ is_digit_step = stepi < total_iterations_for_digits
76
+ is_question_step = total_iterations_for_digits <= stepi < total_iterations_for_digits + total_iterations_for_question
77
+ is_answer_step = stepi >= total_iterations_for_digits + total_iterations_for_question
78
+ return is_digit_step, is_question_step, is_answer_step
79
+
80
+ def determine_index_operator_step_type(self, total_iterations_for_digits, stepi: int):
81
+ """Determine whether the current step is for index or operator."""
82
+ step_within_questions = stepi - total_iterations_for_digits
83
+ if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
84
+ is_index_step = True
85
+ is_operator_step = False
86
+ else:
87
+ is_index_step = False
88
+ is_operator_step = True
89
+ return is_index_step, is_operator_step
90
+
91
+ def get_kv_for_step(self, total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input=None, prev_kv=None):
92
+ """Get the key-value for the current step."""
93
+ is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
94
+
95
+ if is_digit_step:
96
+ current_input = x[:, stepi]
97
+ if prev_input is not None and torch.equal(current_input, prev_input):
98
+ return prev_kv, prev_input
99
+ kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
100
+
101
+ elif is_question_step:
102
+ offset = stepi - total_iterations_for_digits
103
+ current_input = z[:, offset]
104
+ if prev_input is not None and torch.equal(current_input, prev_input):
105
+ return prev_kv, prev_input
106
+ is_index_step, is_operator_step = self.determine_index_operator_step_type(total_iterations_for_digits, stepi)
107
+ if is_index_step:
108
+ kv = self.index_backbone(current_input)
109
+ elif is_operator_step:
110
+ kv = self.operator_backbone(current_input)
111
+ else:
112
+ raise ValueError("Invalid step type for question processing.")
113
+
114
+ elif is_answer_step:
115
+ current_input = None
116
+ kv = torch.zeros((x.size(0), self.d_input), device=x.device)
117
+
118
+ else:
119
+ raise ValueError("Invalid step type.")
120
+
121
+ return kv, current_input
122
+
123
+
124
+
125
+
126
+ def forward(self, x, z, track=False):
127
+ B = x.size(0)
128
+ device = x.device
129
+
130
+ # --- Tracking Initialization ---
131
+ pre_activations_tracking = []
132
+ post_activations_tracking = []
133
+ attention_tracking = []
134
+ embedding_tracking = []
135
+
136
+ total_iterations_for_digits = x.size(1)
137
+ total_iterations_for_question = z.size(1)
138
+ total_iterations = total_iterations_for_digits + total_iterations_for_question + self.iterations_for_answering
139
+
140
+ # --- Initialise Recurrent State ---
141
+ state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
142
+ activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
143
+
144
+ # --- Storage for outputs per iteration ---
145
+ predictions = torch.empty(B, self.out_dims, total_iterations, device=device, dtype=x.dtype)
146
+ certainties = torch.empty(B, 2, total_iterations, device=device, dtype=x.dtype)
147
+
148
+ # --- Initialise Recurrent Synch Values ---
149
+ decay_alpha_action, decay_beta_action = None, None
150
+ self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15) # Fix from github user: kuviki
151
+ self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
152
+ r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
153
+
154
+ _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
155
+
156
+ prev_input = None
157
+ prev_kv = None
158
+
159
+ # --- Recurrent Loop ---
160
+ for stepi in range(total_iterations):
161
+ is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
162
+
163
+ kv, prev_input = self.get_kv_for_step(total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input, prev_kv)
164
+ prev_kv = kv
165
+
166
+ synchronization_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
167
+
168
+ # --- Interact with Data via Attention ---
169
+ attn_weights = None
170
+ if is_digit_step:
171
+ q = self.q_proj(synchronization_action).unsqueeze(1)
172
+ attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
173
+ attn_out = attn_out.squeeze(1)
174
+ pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
175
+ else:
176
+ kv = kv.squeeze(1)
177
+ pre_synapse_input = torch.concatenate((kv, activated_state), dim=-1)
178
+
179
+ # --- Apply Synapses ---
180
+ state = self.synapses(pre_synapse_input)
181
+ state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
182
+
183
+ # --- Apply NLMs ---
184
+ activated_state = self.trace_processor(state_trace)
185
+
186
+ # --- Calculate Synchronisation for Output Predictions ---
187
+ synchronization_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
188
+
189
+ # --- Get Predictions and Certainties ---
190
+ current_prediction = self.output_projector(synchronization_out)
191
+ current_certainty = self.compute_certainty(current_prediction)
192
+
193
+ predictions[..., stepi] = current_prediction
194
+ certainties[..., stepi] = current_certainty
195
+
196
+ # --- Tracking ---
197
+ if track:
198
+ pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
199
+ post_activations_tracking.append(activated_state.detach().cpu().numpy())
200
+ if attn_weights is not None:
201
+ attention_tracking.append(attn_weights.detach().cpu().numpy())
202
+ if is_question_step:
203
+ embedding_tracking.append(kv.detach().cpu().numpy())
204
+
205
+ # --- Return Values ---
206
+ if track:
207
+ return predictions, certainties, synchronization_out, np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
208
+ return predictions, certainties, synchronization_out
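The QAMNIST model above splits its internal ticks into a digit phase, a question phase that alternates between index and operator inputs, and a final answering phase. The following sketch reproduces that schedule as one standalone function so the phase boundaries are easy to inspect. The function name `step_phase` and its string labels are hypothetical; only the comparisons mirror `determine_step_type` and `determine_index_operator_step_type`.

```python
def step_phase(stepi, n_digit_steps, n_question_steps, iterations_per_question_part):
    """Label a tick as 'digit', 'index', 'operator', or 'answer' (hypothetical helper)."""
    if stepi < n_digit_steps:
        return "digit"
    if stepi < n_digit_steps + n_question_steps:
        # Within the question phase, each block of 2 * iterations_per_question_part
        # ticks is split evenly: index ticks first, operator ticks second.
        offset = stepi - n_digit_steps
        in_index_half = offset % (2 * iterations_per_question_part) < iterations_per_question_part
        return "index" if in_index_half else "operator"
    return "answer"

# Made-up sizes: 2 digit ticks, 4 question ticks, 1 tick per question part, 1 answer tick.
schedule = [step_phase(s, 2, 4, 1) for s in range(7)]
```

With these assumed sizes, `schedule` lists two digit ticks, then index and operator ticks alternating, then a single answer tick.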
models/ctm_rl.py ADDED
@@ -0,0 +1,192 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import math
5
+ from models.ctm import ContinuousThoughtMachine
6
+ from models.modules import MiniGridBackbone, ClassicControlBackbone, SynapseUNET
7
+ from models.utils import compute_decay
8
+ from models.constants import VALID_NEURON_SELECT_TYPES
9
+
10
+ class ContinuousThoughtMachineRL(ContinuousThoughtMachine):
11
+ def __init__(self,
12
+ iterations,
13
+ d_model,
14
+ d_input,
15
+ n_synch_out,
16
+ synapse_depth,
17
+ memory_length,
18
+ deep_nlms,
19
+ memory_hidden_dims,
20
+ do_layernorm_nlm,
21
+ backbone_type,
22
+ prediction_reshaper=[-1],
23
+ dropout=0,
24
+ neuron_select_type='first-last',
25
+ ):
26
+ super().__init__(
27
+ iterations=iterations,
28
+ d_model=d_model,
29
+ d_input=d_input,
30
+ heads=0, # Setting heads to 0 disables attention (the attention modules are None)
31
+ n_synch_out=n_synch_out,
32
+ n_synch_action=0,
33
+ synapse_depth=synapse_depth,
34
+ memory_length=memory_length,
35
+ deep_nlms=deep_nlms,
36
+ memory_hidden_dims=memory_hidden_dims,
37
+ do_layernorm_nlm=do_layernorm_nlm,
38
+ out_dims=0,
39
+ prediction_reshaper=prediction_reshaper,
40
+ dropout=dropout,
41
+ neuron_select_type=neuron_select_type,
42
+ backbone_type=backbone_type,
43
+ n_random_pairing_self=0,
44
+ positional_embedding_type='none',
45
+ )
46
+
47
+ # --- Use a minimal CTM w/out input (action) synch ---
48
+ self.neuron_select_type_action = None
49
+ self.synch_representation_size_action = None
50
+
51
+ # --- Start dynamics with a learned activated state trace ---
52
+ self.register_parameter('start_activated_trace', nn.Parameter(torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length))), requires_grad=True))
53
+ self.start_activated_state = None
54
+
55
+ self.register_buffer('diagonal_mask_out', torch.triu(torch.ones(self.n_synch_out, self.n_synch_out, dtype=torch.bool)))
56
+
57
+ self.attention = None # Should already be None because super(... heads=0... )
58
+ self.q_proj = None # Should already be None because super(... heads=0... )
59
+ self.kv_proj = None # Should already be None because super(... heads=0... )
60
+ self.output_projector = None
61
+
62
+ # --- Core CTM Methods ---
63
+
64
+ def compute_synchronisation(self, activated_state_trace):
65
+ """Compute the synchronisation between neurons."""
66
+ assert self.neuron_select_type == "first-last", "only first-last neuron selection is supported here"
67
+ # For RL tasks we track a sliding window of activations from which we compute synchronisation
68
+ S = activated_state_trace.permute(0, 2, 1)
69
+ diagonal_mask = self.diagonal_mask_out.to(S.device)
70
+ decay = compute_decay(S.size(1), self.decay_params_out, clamp_lims=(0, 4))
71
+ synchronisation = ((decay.unsqueeze(0) *(S[:,:,-self.n_synch_out:].unsqueeze(-1) * S[:,:,-self.n_synch_out:].unsqueeze(-2))[:,:,diagonal_mask]).sum(1))/torch.sqrt(decay.unsqueeze(0).sum(1,))
72
+ return synchronisation
73
+
74
+ # --- Setup Methods ---
75
+
76
+ def set_initial_rgb(self):
77
+ """Set the initial RGB values for the backbone."""
78
+ return None
79
+
80
+ def get_d_backbone(self):
81
+ """Get the dimensionality of the backbone output."""
82
+ return self.d_input
83
+
84
+ def set_backbone(self):
85
+ """Set the backbone module based on the specified type."""
86
+ if self.backbone_type == 'navigation-backbone':
87
+ self.backbone = MiniGridBackbone(self.d_input)
88
+ elif self.backbone_type == 'classic-control-backbone':
89
+ self.backbone = ClassicControlBackbone(self.d_input)
90
+ else:
91
+ raise NotImplementedError('The only backbones supported for RL are navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')
92
+ pass
93
+
94
+ def get_positional_embedding(self, d_backbone):
95
+ """Get the positional embedding module."""
96
+ return None
97
+
98
+
99
+ def get_synapses(self, synapse_depth, d_model, dropout):
100
+ """
101
+ Get the synapse module.
102
+
103
+ We found in our early experimentation that a single Linear, GLU and LayerNorm block performed worse than two blocks.
104
+ For that reason we set the default synapse depth to two blocks.
105
+
106
+ TODO: This is legacy and needs further experimentation to iron out.
107
+ """
108
+ if synapse_depth == 1:
109
+ return nn.Sequential(
110
+ nn.Dropout(dropout),
111
+ nn.LazyLinear(d_model*2),
112
+ nn.GLU(),
113
+ nn.LayerNorm(d_model),
114
+ nn.LazyLinear(d_model*2),
115
+ nn.GLU(),
116
+ nn.LayerNorm(d_model)
117
+ )
118
+ else:
119
+ return SynapseUNET(d_model, synapse_depth, 16, dropout)
120
+
121
+ def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
122
+ """Set the parameters for the synchronisation of neurons."""
123
+ if synch_type == 'action':
124
+ pass
125
+ elif synch_type == 'out':
126
+ left, right = self.initialize_left_right_neurons("out", self.d_model, n_synch, n_random_pairing_self)
127
+ self.register_buffer(f'out_neuron_indices_left', left)
128
+ self.register_buffer(f'out_neuron_indices_right', right)
129
+ self.register_parameter(f'decay_params_out', nn.Parameter(torch.zeros(self.synch_representation_size_out), requires_grad=True))
130
+ pass
131
+ else:
132
+ raise ValueError(f"Invalid synch_type: {synch_type}")
133
+
134
+ # --- Utility Methods ---
135
+
136
+ def verify_args(self):
137
+ """Verify the validity of the input arguments."""
138
+ assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
139
+ f"Invalid neuron selection type: {self.neuron_select_type}"
140
+ assert self.neuron_select_type != 'random-pairing', \
141
+ f"Random pairing is not supported for RL."
142
+ assert self.backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
143
+ f"Invalid backbone_type: {self.backbone_type}"
144
+ assert self.d_model >= (self.n_synch_out), \
145
+ "d_model must be >= n_synch_out for neuron subsets"
146
+ pass
147
+
148
+
149
+
150
+
151
+ def forward(self, x, hidden_states, track=False):
152
+
153
+ # --- Tracking Initialization ---
154
+ pre_activations_tracking = []
155
+ post_activations_tracking = []
156
+
157
+ # --- Featurise Input Data ---
158
+ features = self.backbone(x)
159
+
160
+ # --- Get Recurrent State ---
161
+ state_trace, activated_state_trace = hidden_states
162
+
163
+ # --- Recurrent Loop ---
164
+ for stepi in range(self.iterations):
165
+
166
+ pre_synapse_input = torch.concatenate((features.reshape(x.size(0), -1), activated_state_trace[:,:,-1]), -1)
167
+
168
+ # --- Apply Synapses ---
169
+ state = self.synapses(pre_synapse_input)
170
+ state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
171
+
172
+ # --- Apply NLMs ---
173
+ activated_state = self.trace_processor(state_trace)
174
+ activated_state_trace = torch.concatenate((activated_state_trace[:,:,1:], activated_state.unsqueeze(-1)), -1)
175
+
176
+ # --- Tracking ---
177
+ if track:
178
+ pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
179
+ post_activations_tracking.append(activated_state.detach().cpu().numpy())
180
+
181
+ hidden_states = (
182
+ state_trace,
183
+ activated_state_trace,
184
+ )
185
+
186
+ # --- Calculate Output Synchronisation ---
187
+ synchronisation_out = self.compute_synchronisation(activated_state_trace)
188
+
189
+ # --- Return Values ---
190
+ if track:
191
+ return synchronisation_out, hidden_states, np.array(pre_activations_tracking), np.array(post_activations_tracking)
192
+ return synchronisation_out, hidden_states
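Note on `compute_synchronisation` in this file: it weights the pairwise products of post-activations over the sliding window by a decay term and normalises by the square root of the summed weights. The sketch below shows that calculation in pure Python for one sample. It assumes that `compute_decay` yields weights of the form exp(-r * age), with age 0 for the newest step; that helper lives in `models/utils.py` and is not shown here, so treat the weighting as an assumption. The function name `windowed_synchronisation` is hypothetical.

```python
import math


def windowed_synchronisation(trace, decay_rates):
    """Decay-weighted synchronisation over a window of post-activations.

    trace: list of T time steps (oldest first), each a list of D activations.
    decay_rates: dict mapping a neuron pair (i, j) to a non-negative rate r.
    Assumed weighting (not confirmed by the source): w_t = exp(-r * (T - 1 - t)).
    """
    T = len(trace)
    result = {}
    for (i, j), r in decay_rates.items():
        weights = [math.exp(-r * (T - 1 - t)) for t in range(T)]
        weighted_sum = sum(w * trace[t][i] * trace[t][j] for t, w in enumerate(weights))
        result[(i, j)] = weighted_sum / math.sqrt(sum(weights))
    return result


# Two time steps, two neurons, upper-triangular pairs, no decay (r = 0).
example = windowed_synchronisation(
    [[1.0, 2.0], [3.0, 4.0]],
    {(0, 0): 0.0, (0, 1): 0.0, (1, 1): 0.0},
)
print(example)
```

With r = 0 every step counts equally; larger rates make the most recent activations dominate.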
models/ctm_sort.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ import numpy as np
3
+ from models.ctm import ContinuousThoughtMachine
4
+
5
+ class ContinuousThoughtMachineSORT(ContinuousThoughtMachine):
6
+ """
7
+ Slight adaptation of the CTM to work with the sort task.
8
+ """
9
+
10
+ def __init__(self,
11
+ iterations,
12
+ d_model,
13
+ d_input,
14
+ heads,
15
+ n_synch_out,
16
+ n_synch_action,
17
+ synapse_depth,
18
+ memory_length,
19
+ deep_nlms,
20
+ memory_hidden_dims,
21
+ do_layernorm_nlm,
22
+ backbone_type,
23
+ positional_embedding_type,
24
+ out_dims,
25
+ prediction_reshaper=[-1],
26
+ dropout=0,
27
+ dropout_nlm=None,
28
+ neuron_select_type='random-pairing',
29
+ n_random_pairing_self=0,
30
+ ):
31
+ super().__init__(
32
+ iterations=iterations,
33
+ d_model=d_model,
34
+ d_input=d_input,
35
+ heads=0,
36
+ n_synch_out=n_synch_out,
37
+ n_synch_action=0,
38
+ synapse_depth=synapse_depth,
39
+ memory_length=memory_length,
40
+ deep_nlms=deep_nlms,
41
+ memory_hidden_dims=memory_hidden_dims,
42
+ do_layernorm_nlm=do_layernorm_nlm,
43
+ backbone_type='none',
44
+ positional_embedding_type='none',
45
+ out_dims=out_dims,
46
+ prediction_reshaper=prediction_reshaper,
47
+ dropout=dropout,
48
+ dropout_nlm=dropout_nlm,
49
+ neuron_select_type=neuron_select_type,
50
+ n_random_pairing_self=n_random_pairing_self,
51
+ )
52
+
53
+ # --- Use a minimal CTM w/out input (action) synch ---
54
+ self.neuron_select_type_action = None
55
+ self.synch_representation_size_action = None
56
+
57
+ self.attention = None # Should already be None because super(... heads=0... )
58
+ self.q_proj = None # Should already be None because super(... heads=0... )
59
+ self.kv_proj = None # Should already be None because super(... heads=0... )
60
+
61
+
62
+
63
+
64
+ def forward(self, x, track=False):
65
+ B = x.size(0)
66
+ device = x.device
67
+
68
+ # --- Tracking Initialization ---
69
+ pre_activations_tracking = []
70
+ post_activations_tracking = []
71
+ synch_out_tracking = []
72
+ attention_tracking = []
73
+
74
+ # --- For SORT: no need to featurise data ---
75
+
76
+
77
+ # --- Initialise Recurrent State ---
78
+ state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
79
+ activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
80
+
81
+ # --- Prepare Storage for Outputs per Iteration ---
82
+ predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
83
+ certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
84
+
85
+ # --- Initialise Recurrent Synch Values ---
86
+ r_out = torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
87
+ _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
88
+ # Compute learned weighting for synchronisation
89
+
90
+
91
+ # --- Recurrent Loop ---
92
+ for stepi in range(self.iterations):
93
+
94
+ pre_synapse_input = torch.concatenate((x, activated_state), dim=-1)
95
+
96
+ # --- Apply Synapses ---
97
+ state = self.synapses(pre_synapse_input)
98
+ # The 'state_trace' is the history of incoming pre-activations
99
+ state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
100
+
101
+ # --- Apply Neuron-Level Models ---
102
+ activated_state = self.trace_processor(state_trace)
103
+ # One would also keep an 'activated_state_trace' as the history of outgoing post-activations
104
+ # BUT, this is unnecessary because the synchronisation calculation is fully linear and can be
105
+ # done using only the current activated state (see compute_synchronisation method for explanation)
106
+
107
+ # --- Calculate Synchronisation for Output Predictions ---
108
+ synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
109
+
110
+ # --- Get Predictions and Certainties ---
111
+ current_prediction = self.output_projector(synchronisation_out)
112
+ current_certainty = self.compute_certainty(current_prediction)
113
+
114
+ predictions[..., stepi] = current_prediction
115
+ certainties[..., stepi] = current_certainty
116
+
117
+ # --- Tracking ---
118
+ if track:
119
+ pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
120
+ post_activations_tracking.append(activated_state.detach().cpu().numpy())
121
+ synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
122
+
123
+ # --- Return Values ---
124
+ if track:
125
+ return predictions, certainties, np.array(synch_out_tracking), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
126
+ return predictions, certainties, synchronisation_out
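Note on the comment about not keeping an `activated_state_trace`: a decayed sum over the history can be updated recursively from the current step alone. The sketch below assumes the update has the form alpha <- r * alpha + z_i * z_j and beta <- r * beta + 1, with the output alpha / sqrt(beta). The base-class `compute_synchronisation` is not shown in this section, so this form is an assumption, and both function names are hypothetical.

```python
import math


def recursive_sync(products, r):
    """Update the numerator (alpha) and normaliser (beta) one step at a time."""
    alpha, beta = None, None
    for p in products:
        if alpha is None:
            alpha, beta = p, 1.0
        else:
            alpha = r * alpha + p
            beta = r * beta + 1.0
    return alpha / math.sqrt(beta)


def direct_sync(products, r):
    """Same quantity computed from the full history, for comparison."""
    T = len(products)
    weights = [r ** (T - 1 - t) for t in range(T)]
    return sum(w * p for w, p in zip(weights, products)) / math.sqrt(sum(weights))


# Pairwise products z_i * z_j for one neuron pair over three steps (illustrative).
history = [0.5, -1.0, 2.0]
print(recursive_sync(history, 0.9), direct_sync(history, 0.9))
```

Both functions agree, which is why only `decay_alpha_out` and `decay_beta_out` need to be carried between iterations.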
models/ff.py ADDED
@@ -0,0 +1,75 @@
1
+ import torch.nn as nn
2
+
3
+ # Local imports (Assuming these contain necessary custom modules)
4
+ from models.modules import *
5
+ from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
6
+
7
+
8
+ class FFBaseline(nn.Module):
9
+ """
10
+ Feed-forward (FF) Baseline.
11
+
12
+ Wrapper that lets us use the same backbone as the CTM and LSTM baselines.
13
+
14
+
15
+ Args:
16
+ d_model (int): Width of the hidden projection applied after pooling; scaling it allows the parameter count to be matched to the CTM.
17
+ backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
18
+ out_dims (int): Dimensionality of the final output projection.
19
+ dropout (float): Dropout rate applied before the final linear layer.
20
+ """
21
+
22
+ def __init__(self,
23
+ d_model,
24
+ backbone_type,
25
+ out_dims,
26
+ dropout=0,
27
+ ):
28
+ super(FFBaseline, self).__init__()
29
+
30
+ # --- Core Parameters ---
31
+ self.d_model = d_model
32
+ self.backbone_type = backbone_type
33
+ self.out_dims = out_dims
34
+
35
+ # --- Input Assertions ---
36
+ assert backbone_type in ['resnet18-1', 'resnet18-2', 'resnet18-3', 'resnet18-4',
37
+ 'resnet34-1', 'resnet34-2', 'resnet34-3', 'resnet34-4',
38
+ 'resnet50-1', 'resnet50-2', 'resnet50-3', 'resnet50-4',
39
+ 'resnet101-1', 'resnet101-2', 'resnet101-3', 'resnet101-4',
40
+ 'resnet152-1', 'resnet152-2', 'resnet152-3', 'resnet152-4',
41
+ 'none', 'shallow-wide', 'parity_backbone'], f"Invalid backbone_type: {backbone_type}"
42
+
43
+ # --- Backbone / Feature Extraction ---
44
+ self.initial_rgb = Identity() # Placeholder, potentially replaced if using ResNet
45
+
46
+
47
+ self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
48
+ resnet_family = resnet18 # Default
49
+ if '34' in self.backbone_type: resnet_family = resnet34
50
+ if '50' in self.backbone_type: resnet_family = resnet50
51
+ if '101' in self.backbone_type: resnet_family = resnet101
52
+ if '152' in self.backbone_type: resnet_family = resnet152
53
+
54
+ # Determine which ResNet blocks to keep
55
+ block_num_str = self.backbone_type.split('-')[-1]
56
+ hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]
57
+
58
+ self.backbone = resnet_family(
59
+ 3, # initial_rgb handles input channels now
60
+ hyper_blocks_to_keep,
61
+ stride=2,
62
+ pretrained=False,
63
+ progress=True,
64
+ device="cpu", # Initialise on CPU, move later via .to(device)
65
+ do_initial_max_pool=True,
66
+ )
67
+
68
+
69
+ # At this point we will have a 4D tensor of features: [B, C, H, W]
70
+ # The following lets us scale up the resnet with d_model until it matches the CTM
71
+ self.output_projector = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), Squeeze(-1), Squeeze(-1), nn.LazyLinear(d_model), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_model, out_dims))
72
+
73
+
74
+ def forward(self, x):
75
+ return self.output_projector((self.backbone(self.initial_rgb(x))))
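Note on the backbone string handling in `FFBaseline.__init__`: the ResNet family is chosen by substring checks and the number of blocks to keep is read from the suffix after the hyphen. A stricter alternative is sketched below. The function `parse_resnet_backbone` is hypothetical and is not part of the repository; it rejects non-ResNet strings instead of falling back to the defaults.

```python
import re


def parse_resnet_backbone(backbone_type):
    """Split a string such as 'resnet34-2' into (depth, blocks_to_keep)."""
    match = re.fullmatch(r"resnet(18|34|50|101|152)-([1-4])", backbone_type)
    if match is None:
        raise ValueError(f"Not a ResNet backbone string: {backbone_type!r}")
    depth = int(match.group(1))
    last_block = int(match.group(2))
    return depth, list(range(1, last_block + 1))


print(parse_resnet_backbone("resnet34-2"))
```

Anchoring the pattern to the whole string avoids accidental substring matches when new backbone names are added.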
models/lstm.py ADDED
@@ -0,0 +1,244 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+ import math
5
+
6
+ from models.modules import ParityBackbone, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
7
+ from models.resnet import prepare_resnet_backbone
8
+ from models.utils import compute_normalized_entropy
9
+
10
+ from models.constants import (
11
+ VALID_BACKBONE_TYPES,
12
+ VALID_POSITIONAL_EMBEDDING_TYPES
13
+ )
14
+
15
+ class LSTMBaseline(nn.Module):
16
+ """
17
+ LSTM Baseline
18
+
19
+ Args:
20
+ iterations (int): Number of internal 'thought' steps (T, in paper).
21
+ d_model (int): Core dimensionality of the latent space.
22
+ d_input (int): Dimensionality of projected attention outputs or direct input features.
23
+ heads (int): Number of attention heads.
24
+ backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
25
+ positional_embedding_type (str): Type of positional embedding for backbone features.
26
+ out_dims (int): Dimensionality of the final output projection.
27
+ prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
28
+ dropout (float): Dropout rate.
29
+ """
30
+
31
+ def __init__(self,
32
+ iterations,
33
+ d_model,
34
+ d_input,
35
+ heads,
36
+ backbone_type,
37
+ num_layers,
38
+ positional_embedding_type,
39
+ out_dims,
40
+ prediction_reshaper=[-1],
41
+ dropout=0,
42
+ ):
43
+ super(LSTMBaseline, self).__init__()
44
+
45
+ # --- Core Parameters ---
46
+ self.iterations = iterations
47
+ self.d_model = d_model
48
+ self.d_input = d_input
49
+ self.prediction_reshaper = prediction_reshaper
50
+ self.backbone_type = backbone_type
51
+ self.positional_embedding_type = positional_embedding_type
52
+ self.out_dims = out_dims
53
+
54
+ # --- Assertions ---
55
+ self.verify_args()
56
+
57
+ # --- Input Processing ---
58
+ d_backbone = self.get_d_backbone()
59
+
60
+ self.set_initial_rgb()
61
+ self.set_backbone()
62
+ self.positional_embedding = self.get_positional_embedding(d_backbone)
63
+ self.kv_proj = self.get_kv_proj()
64
+ self.lstm = nn.LSTM(d_input, d_model, num_layers, batch_first=True, dropout=dropout)
65
+ self.q_proj = self.get_q_proj()
66
+ self.attention = self.get_attention(heads, dropout)
67
+ self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
68
+
69
+ # --- Start States ---
70
+ self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
71
+ self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
72
+
73
+
74
+
75
+ # --- Core LSTM Methods ---
76
+
77
+ def compute_features(self, x):
78
+ """Applies backbone and positional embedding to input."""
79
+ x = self.initial_rgb(x)
80
+ self.kv_features = self.backbone(x)
81
+ pos_emb = self.positional_embedding(self.kv_features)
82
+ combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2)
83
+ kv = self.kv_proj(combined_features)
84
+ return kv
85
+
86
+ def compute_certainty(self, current_prediction):
87
+ """Compute the certainty of the current prediction."""
88
+ B = current_prediction.size(0)
89
+ reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper)
90
+ ne = compute_normalized_entropy(reshaped_pred)
91
+ current_certainty = torch.stack((ne, 1-ne), -1)
92
+ return current_certainty
93
+
94
+ # --- Setup Methods ---
95
+
96
+ def set_initial_rgb(self):
97
+ """Set the initial RGB processing module based on the backbone type."""
98
+ if 'resnet' in self.backbone_type:
99
+ self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
100
+ else:
101
+ self.initial_rgb = nn.Identity()
102
+
103
+ def get_d_backbone(self):
104
+ """
105
+ Get the dimensionality of the backbone output, to be used for positional embedding setup.
106
+
107
+ This is a little bit complicated for resnets, but the logic should be easy enough to read below.
108
+ """
109
+ if self.backbone_type == 'shallow-wide':
110
+ return 2048
111
+ elif self.backbone_type == 'parity_backbone':
112
+ return self.d_input
113
+ elif 'resnet' in self.backbone_type:
114
+ if '18' in self.backbone_type or '34' in self.backbone_type:
115
+ if self.backbone_type.split('-')[1]=='1': return 64
116
+ elif self.backbone_type.split('-')[1]=='2': return 128
117
+ elif self.backbone_type.split('-')[1]=='3': return 256
118
+ elif self.backbone_type.split('-')[1]=='4': return 512
119
+ else:
120
+ raise NotImplementedError
121
+ else:
122
+ if self.backbone_type.split('-')[1]=='1': return 256
123
+ elif self.backbone_type.split('-')[1]=='2': return 512
124
+ elif self.backbone_type.split('-')[1]=='3': return 1024
125
+ elif self.backbone_type.split('-')[1]=='4': return 2048
126
+ else:
127
+ raise NotImplementedError
128
+ elif self.backbone_type == 'none':
129
+ return None
130
+ else:
131
+ raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
132
+
133
+ def set_backbone(self):
134
+ """Set the backbone module based on the specified type."""
135
+ if self.backbone_type == 'shallow-wide':
136
+ self.backbone = ShallowWide()
137
+ elif self.backbone_type == 'parity_backbone':
138
+ d_backbone = self.get_d_backbone()
139
+ self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone)
140
+ elif 'resnet' in self.backbone_type:
141
+ self.backbone = prepare_resnet_backbone(self.backbone_type)
142
+ elif self.backbone_type == 'none':
143
+ self.backbone = nn.Identity()
144
+ else:
145
+ raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
146
+
147
+ def get_positional_embedding(self, d_backbone):
148
+ """Get the positional embedding module."""
149
+ if self.positional_embedding_type == 'learnable-fourier':
150
+ return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5)
151
+ elif self.positional_embedding_type == 'multi-learnable-fourier':
152
+ return MultiLearnableFourierPositionalEncoding(d_backbone)
153
+ elif self.positional_embedding_type == 'custom-rotational':
154
+ return CustomRotationalEmbedding(d_backbone)
155
+ elif self.positional_embedding_type == 'custom-rotational-1d':
156
+ return CustomRotationalEmbedding1D(d_backbone)
157
+ elif self.positional_embedding_type == 'none':
158
+ return lambda x: 0 # Default no-op
159
+ else:
160
+ raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}")
161
+
162
+ def get_attention(self, heads, dropout):
163
+ """Get the attention module."""
164
+ return nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True)
165
+
166
+ def get_kv_proj(self):
167
+ """Get the key-value projection module."""
168
+ return nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input))
169
+
170
+ def get_q_proj(self):
171
+ """Get the query projection module."""
172
+ return nn.LazyLinear(self.d_input)
173
+
174
+
175
+ def verify_args(self):
176
+ """Verify the validity of the input arguments."""
177
+
178
+ assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \
179
+ f"Invalid backbone_type: {self.backbone_type}"
180
+
181
+ assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \
182
+ f"Invalid positional_embedding_type: {self.positional_embedding_type}"
183
+
184
+ if self.backbone_type=='none' and self.positional_embedding_type!='none':
185
+ raise AssertionError("There should be no positional embedding if there is no backbone.")
186
+
187
+ pass
188
+
189
+
190
+
191
+
192
+ def forward(self, x, track=False):
193
+ """
194
+ Forward pass - Reverted to structure closer to user's working version.
195
+ Executes T=iterations steps.
196
+ """
197
+ B = x.size(0)
198
+ device = x.device
199
+
200
+ # --- Tracking Initialization ---
201
+ activations_tracking = []
202
+ attention_tracking = []
203
+
204
+ # --- Featurise Input Data ---
205
+ kv = self.compute_features(x)
206
+
207
+ # --- Initialise Recurrent State ---
208
+ hn = torch.repeat_interleave(self.start_hidden_state.unsqueeze(1), x.size(0), 1)
209
+ cn = torch.repeat_interleave(self.start_cell_state.unsqueeze(1), x.size(0), 1)
210
+ state_trace = [hn[-1]]
211
+
212
+ # --- Prepare Storage for Outputs per Iteration ---
213
+ predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
214
+ certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
215
+
216
+ # --- Recurrent Loop ---
217
+ for stepi in range(self.iterations):
218
+
219
+ # --- Interact with Data via Attention ---
220
+ q = self.q_proj(hn[-1].unsqueeze(1))
221
+ attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
222
+ lstm_input = attn_out
223
+
224
+ # --- Apply LSTM ---
225
+ hidden_state, (hn,cn) = self.lstm(lstm_input, (hn, cn))
226
+ hidden_state = hidden_state.squeeze(1)
227
+ state_trace.append(hidden_state)
228
+
229
+ # --- Get Predictions and Certainties ---
230
+ current_prediction = self.output_projector(hidden_state)
231
+ current_certainty = self.compute_certainty(current_prediction)
232
+
233
+ predictions[..., stepi] = current_prediction
234
+ certainties[..., stepi] = current_certainty
235
+
236
+ # --- Tracking ---
237
+ if track:
238
+ activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
239
+ attention_tracking.append(attn_weights.detach().cpu().numpy())
240
+
241
+ # --- Return Values ---
242
+ if track:
243
+ return predictions, certainties, None, np.zeros_like(activations_tracking), np.array(activations_tracking), np.array(attention_tracking)
244
+ return predictions, certainties, None
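Note on `compute_certainty`: it stacks the normalised entropy `ne` and `1 - ne`, so the second entry is high when the prediction is confident. The sketch below works on a single logit vector in pure Python. It assumes that `compute_normalized_entropy` (defined in `models/utils.py`, not shown here) is the softmax entropy divided by the log of the number of classes; the function names below are hypothetical.

```python
import math


def normalized_entropy(logits):
    """Softmax entropy scaled to [0, 1] by dividing by log(number of classes)."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]  # shift for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return entropy / math.log(len(logits))


def certainty(logits):
    """Return (normalised entropy, 1 - normalised entropy), as stacked above."""
    ne = normalized_entropy(logits)
    return (ne, 1.0 - ne)


print(certainty([0.0, 0.0, 0.0, 0.0]))
print(certainty([100.0, 0.0, 0.0]))
```

A uniform prediction gives a certainty close to 0, while a sharply peaked one gives a certainty close to 1.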
models/lstm_qamnist.py ADDED
@@ -0,0 +1,184 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F # Used for GLU if not in modules
4
+ import numpy as np
5
+ import math
6
+
7
+ # Local imports (Assuming these contain necessary custom modules)
8
+ from models.modules import *
9
+ from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here
10
+
11
+ class LSTMBaseline(nn.Module):
12
+ """
13
+ LSTM Baseline
14
+
15
+ Args:
16
+ iterations (int): Number of internal 'thought' steps (T, in paper).
17
+ d_model (int): Core dimensionality of the LSTM's hidden state.
18
+ d_input (int): Dimensionality of projected attention outputs or direct input features.
19
+ heads (int): Number of attention heads.
20
+ n_synch_out (int): Number of neurons used for output synchronisation (No, in paper).
21
+ n_synch_action (int): Number of neurons used for action/attention synchronisation (Ni, in paper).
22
+ synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP).
23
+ memory_length (int): History length for Neuron-Level Models (M, in paper).
24
+ deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear.
25
+ memory_hidden_dims (int): Hidden dimension size for deep NLMs.
26
+ do_layernorm_nlm (bool): Apply LayerNorm within NLMs.
27
+ backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
28
+ positional_embedding_type (str): Type of positional embedding for backbone features.
29
+ out_dims (int): Dimensionality of the final output projection.
30
+ prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
31
+ dropout (float): Dropout rate.
32
+ """
33
+
34
+ def __init__(self,
35
+ iterations,
36
+ d_model,
37
+ d_input,
38
+ heads,
39
+ out_dims,
40
+ iterations_per_digit,
41
+ iterations_per_question_part,
42
+ iterations_for_answering,
43
+ prediction_reshaper=[-1],
44
+ dropout=0,
45
+ ):
46
+ super(LSTMBaseline, self).__init__()
47
+
48
+ # --- Core Parameters ---
49
+ self.iterations = iterations
50
+ self.d_model = d_model
51
+ self.prediction_reshaper = prediction_reshaper
52
+ self.out_dims = out_dims
53
+ self.d_input = d_input
54
+ self.backbone_type = 'qamnist_backbone'
55
+ self.iterations_per_digit = iterations_per_digit
56
+ self.iterations_per_question_part = iterations_per_question_part
57
+ self.total_iterations_for_answering = iterations_for_answering
58
+
59
+ # --- Backbone / Feature Extraction ---
60
+ self.backbone_digit = MNISTBackbone(d_input)
61
+ self.index_backbone = QAMNISTIndexEmbeddings(50, d_input)
62
+ self.operator_backbone = QAMNISTOperatorEmbeddings(2, d_input)
63
+
64
+ # --- Core LSTM Modules ---
65
+ self.lstm_cell = nn.LSTMCell(d_input, d_model)
66
+ self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
67
+ self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
68
+
69
+ # Attention
70
+ self.q_proj = nn.LazyLinear(d_input)
71
+ self.kv_proj = nn.Sequential(nn.LazyLinear(d_input), nn.LayerNorm(d_input))
72
+ self.attention = nn.MultiheadAttention(d_input, heads, dropout, batch_first=True)
73
+
74
+ # Output Projection
75
+ self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
76
+
77
+ def compute_certainty(self, current_prediction):
78
+ """Compute the certainty of the current prediction."""
79
+ B = current_prediction.size(0)
80
+ reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper)
81
+ ne = compute_normalized_entropy(reshaped_pred)
82
+ current_certainty = torch.stack((ne, 1-ne), -1)
83
+ return current_certainty
84
+
85
+ def get_kv_for_step(self, stepi, x, z, thought_steps, prev_input=None, prev_kv=None):
86
+ is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
87
+
88
+ if is_digit_step:
89
+ current_input = x[:, stepi]
90
+ if prev_input is not None and torch.equal(current_input, prev_input):
91
+ return prev_kv, prev_input
92
+ kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
93
+
94
+ elif is_question_step:
95
+ offset = stepi - thought_steps.total_iterations_for_digits
96
+ current_input = z[:, offset].squeeze(0)
97
+ if prev_input is not None and torch.equal(current_input, prev_input):
98
+ return prev_kv, prev_input
99
+ is_index_step, is_operator_step = thought_steps.determine_answer_step_type(stepi)
100
+ if is_index_step:
101
+ kv = self.kv_proj(self.index_backbone(current_input))
102
+ elif is_operator_step:
103
+ kv = self.kv_proj(self.operator_backbone(current_input))
104
+ else:
105
+ raise ValueError("Invalid step type for question processing.")
106
+
107
+ elif is_answer_step:
108
+ current_input = None
109
+ kv = torch.zeros((x.size(0), self.d_input), device=x.device)
110
+
111
+ else:
112
+ raise ValueError("Invalid step type.")
113
+
114
+ return kv, current_input
115
+
116
+ def forward(self, x, z, track=False):
117
+ """
118
+ Forward pass - Reverted to structure closer to user's working version.
119
+ Executes T=iterations steps.
120
+ """
121
+ B = x.size(0) # Batch size
122
+
123
+ # --- Tracking Initialization ---
124
+ activations_tracking = []
125
+ attention_tracking = [] # Note: reshaping this correctly requires knowing num_heads
126
+ embedding_tracking = []
127
+
128
+ thought_steps = ThoughtSteps(self.iterations_per_digit, self.iterations_per_question_part, self.total_iterations_for_answering, x.size(1), z.size(1))
129
+
130
+ # --- Step 2: Initialise Recurrent State ---
131
+ hidden_state = torch.repeat_interleave(self.start_hidden_state.unsqueeze(0), x.size(0), 0)
132
+ cell_state = torch.repeat_interleave(self.start_cell_state.unsqueeze(0), x.size(0), 0)
133
+
134
+ state_trace = [hidden_state]
135
+
136
+ device = hidden_state.device
137
+
138
+ # Storage for outputs per iteration
139
+ predictions = torch.empty(B, self.out_dims, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
140
+ certainties = torch.empty(B, 2, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
141
+
142
+ prev_input = None
143
+ prev_kv = None
144
+
145
+ # --- Recurrent Loop (T=iterations steps) ---
146
+ for stepi in range(thought_steps.total_iterations):
147
+
148
+ is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
149
+ kv, prev_input = self.get_kv_for_step(stepi, x, z, thought_steps, prev_input, prev_kv)
150
+ prev_kv = kv
151
+
152
+ # --- Interact with Data via Attention ---
153
+ attn_weights = None
154
+ if is_digit_step:
155
+ q = self.q_proj(hidden_state).unsqueeze(1)
156
+ attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
157
+ lstm_input = attn_out.squeeze(1)
158
+ else:
159
+ lstm_input = kv
160
+
161
+
162
+
163
+ hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state))
164
+ state_trace.append(hidden_state)
165
+
166
+ # --- Get Predictions and Certainties ---
167
+ current_prediction = self.output_projector(hidden_state)
168
+ current_certainty = self.compute_certainty(current_prediction)
169
+
170
+ predictions[..., stepi] = current_prediction
171
+ certainties[..., stepi] = current_certainty
172
+
173
+ # --- Tracking ---
174
+ if track:
175
+ activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
176
+ if attn_weights is not None:
177
+ attention_tracking.append(attn_weights.detach().cpu().numpy())
178
+ if is_question_step:
179
+ embedding_tracking.append(kv.detach().cpu().numpy())
180
+
181
+ # --- Return Values ---
182
+ if track:
183
+ return predictions, certainties, None, np.array(activations_tracking), np.array(activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
184
+ return predictions, certainties, None
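The certainty returned by `compute_certainty` above pairs a normalized entropy with its complement. `compute_normalized_entropy` is imported from `models/utils.py` and is not shown in this part of the diff, so `normalized_entropy_sketch` below is a hypothetical stand-in. It assumes 2D logits of shape (B, C), a softmax over the last dimension, and division by the log of the class count; the real helper may reshape or reduce differently.

```python
import math
import torch
import torch.nn.functional as F


def normalized_entropy_sketch(logits):
    """Hypothetical stand-in: softmax entropy over the last dim, scaled to [0, 1]."""
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    return entropy / math.log(logits.size(-1))


def certainty_sketch(logits):
    """Stack (normalized entropy, 1 - normalized entropy), mirroring compute_certainty."""
    ne = normalized_entropy_sketch(logits)
    return torch.stack((ne, 1 - ne), dim=-1)


# Uniform logits give maximal entropy; one dominant logit gives near-zero entropy.
uniform_logits = torch.zeros(2, 10)
peaked_logits = torch.zeros(2, 10)
peaked_logits[:, 0] = 50.0
uniform_certainty = certainty_sketch(uniform_logits)
peaked_certainty = certainty_sketch(peaked_logits)
```

Under these assumptions the second channel, `1 - ne`, behaves as a confidence score: close to 0 for a uniform prediction and close to 1 for a sharply peaked one.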
models/lstm_rl.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F # Used for GLU if not in modules
4
+ import numpy as np
5
+ import math
6
+
7
+ # Local imports (Assuming these contain necessary custom modules)
8
+ from models.modules import *
9
+ from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here
10
+
11
+
12
+ class LSTMBaseline(nn.Module):
13
+ """
14
+
15
+ LSTM Baseline
16
+
17
+ Args:
18
+ iterations (int): Number of internal 'thought' steps (T, in paper).
19
+ d_model (int): Dimensionality of the LSTM hidden and cell states.
20
+ d_input (int): Dimensionality of projected attention outputs or direct input features.
21
+ backbone_type (str): Type of feature extraction backbone ('navigation-backbone' or 'classic-control-backbone').
22
+ """
23
+
24
+ def __init__(self,
25
+ iterations,
26
+ d_model,
27
+ d_input,
28
+ backbone_type,
29
+ ):
30
+ super(LSTMBaseline, self).__init__()
31
+
32
+ # --- Core Parameters ---
33
+ self.iterations = iterations
34
+ self.d_model = d_model
35
+ self.backbone_type = backbone_type
36
+
37
+ # --- Input Assertions ---
38
+ assert backbone_type in ('navigation-backbone', 'classic-control-backbone'), f"Invalid backbone_type: {backbone_type}"
39
+
40
+ # --- Backbone / Feature Extraction ---
41
+ if self.backbone_type == 'navigation-backbone':
42
+ grid_size = 7
43
+ self.backbone = MiniGridBackbone(d_input=d_input, grid_size=grid_size)
44
+ lstm_cell_input_dim = grid_size * grid_size * d_input
45
+
46
+ elif self.backbone_type == 'classic-control-backbone':
47
+ self.backbone = ClassicControlBackbone(d_input=d_input)
48
+ lstm_cell_input_dim = d_input
49
+
50
+ else:
51
+ raise NotImplementedError('The only backbones supported for RL are for navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')
52
+
53
+ # --- Core LSTM Modules ---
54
+ self.lstm_cell = nn.LSTMCell(lstm_cell_input_dim, d_model)
55
+ self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
56
+ self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
57
+
58
+ def compute_features(self, x):
59
+ """Applies the backbone to the input."""
60
+ return self.backbone(x)
61
+
62
+
63
+ def forward(self, x, hidden_states, track=False):
64
+ """
65
+ Forward pass: featurises x once, then feeds the same features to the LSTM cell at every step.
66
+ Executes T=iterations steps.
67
+ """
68
+
69
+ # --- Tracking Initialization ---
70
+ activations_tracking = []
71
+
72
+ # --- Featurise Input Data ---
73
+ features = self.compute_features(x)
74
+
75
+ hidden_state = hidden_states[0]
76
+ cell_state = hidden_states[1]
77
+
78
+ # --- Recurrent Loop ---
79
+ for stepi in range(self.iterations):
80
+
81
+ lstm_input = features.reshape(x.size(0), -1)
82
+ hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state))
83
+
84
+ # --- Tracking ---
85
+ if track:
86
+ activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
87
+
88
+ hidden_states = (
89
+ hidden_state,
90
+ cell_state
91
+ )
92
+
93
+ # --- Return Values ---
94
+ if track:
95
+ return hidden_state, hidden_states, np.array(activations_tracking), np.array(activations_tracking)
96
+ return hidden_state, hidden_states
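`LSTMBaseline.forward` above expects the caller to pass `hidden_states` as a `(hidden_state, cell_state)` pair. The sketch below shows one way a caller might build that pair from learnable start vectors and run a few recurrent steps on fixed features. The sizes and the `.repeat` expansion are assumptions made for illustration; the RL training code that actually does this is not part of this section.

```python
import math
import torch
import torch.nn as nn

d_in, d_model, batch, iterations = 6, 4, 3, 5

cell = nn.LSTMCell(d_in, d_model)
bound = math.sqrt(1 / d_model)
# Learnable start states, initialised uniformly as in the class above.
start_h = nn.Parameter(torch.empty(d_model).uniform_(-bound, bound))
start_c = nn.Parameter(torch.empty(d_model).uniform_(-bound, bound))

# Expand the per-model start vectors to one row per batch element.
h = start_h.unsqueeze(0).repeat(batch, 1)
c = start_c.unsqueeze(0).repeat(batch, 1)

# The same flattened features are fed to the cell at every internal step.
features = torch.randn(batch, d_in)
for _ in range(iterations):
    h, c = cell(features, (h, c))
```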
models/modules.py ADDED
@@ -0,0 +1,692 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F # Used for GLU
4
+ import math
5
+ import numpy as np
6
+
7
+ # Assuming 'add_coord_dim' is defined in models.utils
8
+ from models.utils import add_coord_dim
9
+
10
+ # --- Basic Utility Modules ---
11
+
12
+ class Identity(nn.Module):
13
+ """
14
+ Identity Module.
15
+
16
+ Returns the input tensor unchanged. Useful as a placeholder or a no-op layer
17
+ in nn.Sequential containers or conditional network parts.
18
+ """
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ def forward(self, x):
23
+ return x
24
+
25
+
26
+ class Squeeze(nn.Module):
27
+ """
28
+ Squeeze Module.
29
+
30
+ Removes a specified dimension of size 1 from the input tensor.
31
+ Useful for incorporating tensor dimension squeezing within nn.Sequential.
32
+
33
+ Args:
34
+ dim (int): The dimension to squeeze.
35
+ """
36
+ def __init__(self, dim):
37
+ super().__init__()
38
+ self.dim = dim
39
+
40
+ def forward(self, x):
41
+ return x.squeeze(self.dim)
42
+
43
+ # --- Core CTM Component Modules ---
44
+
45
+ class SynapseUNET(nn.Module):
46
+ """
47
+ UNET-style architecture for the Synapse Model (f_theta1 in the paper).
48
+
49
+ This module implements the connections between neurons in the CTM's latent
50
+ space. It processes the combined input (previous post-activation state z^t
51
+ and attention output o^t) to produce the pre-activations (a^t) for the
52
+ next internal tick (Eq. 1 in the paper).
53
+
54
+ While a simpler Linear or MLP layer can be used, the paper notes
55
+ that this U-Net structure empirically performed better, suggesting benefit
56
+ from more flexible synaptic connections. This implementation
57
+ uses `depth` points in linspace and creates `depth-1` down/up blocks.
58
+
59
+ Args:
60
+ in_dims: Not a constructor argument; the input width (d_model + d_input) is inferred by nn.LazyLinear on the first call.
61
+ out_dims (int): Number of output dimensions (d_model).
62
+ depth (int): Determines structure size; creates `depth-1` down/up blocks.
63
+ minimum_width (int): Smallest channel width at the U-Net bottleneck.
64
+ dropout (float): Dropout rate applied within down/up projections.
65
+ """
66
+ def __init__(self,
67
+ out_dims,
68
+ depth,
69
+ minimum_width=16,
70
+ dropout=0.0):
71
+ super().__init__()
72
+ self.width_out = out_dims
73
+ self.n_deep = depth # Store depth just for reference if needed
74
+
75
+ # Define UNET structure based on depth
76
+ # Creates `depth` width values, leading to `depth-1` blocks
77
+ widths = np.linspace(out_dims, minimum_width, depth)
78
+
79
+ # Initial projection layer
80
+ self.first_projection = nn.Sequential(
81
+ nn.LazyLinear(int(widths[0])), # Project to the first width
82
+ nn.LayerNorm(int(widths[0])),
83
+ nn.SiLU()
84
+ )
85
+
86
+ # Downward path (encoding layers)
87
+ self.down_projections = nn.ModuleList()
88
+ self.up_projections = nn.ModuleList()
89
+ self.skip_lns = nn.ModuleList()
90
+ num_blocks = len(widths) - 1 # Number of down/up blocks created
91
+
92
+ for i in range(num_blocks):
93
+ # Down block: widths[i] -> widths[i+1]
94
+ self.down_projections.append(nn.Sequential(
95
+ nn.Dropout(dropout),
96
+ nn.Linear(int(widths[i]), int(widths[i+1])),
97
+ nn.LayerNorm(int(widths[i+1])),
98
+ nn.SiLU()
99
+ ))
100
+ # Up block: widths[i+1] -> widths[i]
101
+ # Note: Up blocks are added in order matching down blocks conceptually,
102
+ # but applied in reverse order in the forward pass.
103
+ self.up_projections.append(nn.Sequential(
104
+ nn.Dropout(dropout),
105
+ nn.Linear(int(widths[i+1]), int(widths[i])),
106
+ nn.LayerNorm(int(widths[i])),
107
+ nn.SiLU()
108
+ ))
109
+ # Skip connection LayerNorm operates on width[i]
110
+ self.skip_lns.append(nn.LayerNorm(int(widths[i])))
111
+
112
+ def forward(self, x):
113
+ # Initial projection
114
+ out_first = self.first_projection(x)
115
+
116
+ # Downward path, storing outputs for skip connections
117
+ outs_down = [out_first]
118
+ for layer in self.down_projections:
119
+ outs_down.append(layer(outs_down[-1]))
120
+ # outs_down contains [level_0, level_1, ..., level_depth-1=bottleneck] outputs
121
+
122
+ # Upward path, starting from the bottleneck output
123
+ outs_up = outs_down[-1] # Bottleneck activation
124
+ num_blocks = len(self.up_projections) # Should be depth - 1
125
+
126
+ for i in range(num_blocks):
127
+ # Apply up projection in reverse order relative to down blocks
128
+ # up_projection[num_blocks - 1 - i] processes deeper features first
129
+ up_layer_idx = num_blocks - 1 - i
130
+ out_up = self.up_projections[up_layer_idx](outs_up)
131
+
132
+ # Get corresponding skip connection from downward path
133
+ # skip_connection index = num_blocks - 1 - i (same as up_layer_idx)
134
+ # This matches the output width of the up_projection[up_layer_idx]
135
+ skip_idx = up_layer_idx
136
+ skip_connection = outs_down[skip_idx]
137
+
138
+ # Add skip connection and apply LayerNorm corresponding to this level
139
+ # skip_lns index also corresponds to the level = skip_idx
140
+ outs_up = self.skip_lns[skip_idx](out_up + skip_connection)
141
+
142
+ # The final output after all up-projections
143
+ return outs_up
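The width schedule used by `SynapseUNET` can be checked in isolation. The sketch below reproduces only the `np.linspace` step and the pairing of down and up blocks; `unet_widths` is a hypothetical helper name, and the sizes are illustrative.

```python
import numpy as np


def unet_widths(out_dims, minimum_width, depth):
    """Hypothetical helper: integer layer widths from out_dims down to minimum_width."""
    return [int(w) for w in np.linspace(out_dims, minimum_width, depth)]


widths = unet_widths(64, 16, 4)
# Down block i maps widths[i] -> widths[i + 1]; up block i maps it back.
down = list(zip(widths[:-1], widths[1:]))
up = [(narrow, wide) for wide, narrow in down]
# The forward pass applies up blocks deepest-first, so the bottleneck is expanded first.
up_applied_order = list(reversed(up))
```

With `depth=4` this yields four widths and three down/up pairs, matching the `depth-1` blocks described in the docstring.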
144
+
145
+
146
+ class SuperLinear(nn.Module):
147
+ """
148
+ SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM.
149
+
150
+ This layer is the core component enabling Neuron-Level Models (NLMs),
151
+ referred to as g_theta_d in the paper (Eq. 3). It applies N independent
152
+ linear transformations (or small MLPs when used sequentially) to corresponding
153
+ slices of the input tensor along a specified dimension (typically the neuron
154
+ or feature dimension).
155
+
156
+ How it works for NLMs:
157
+ - The input `x` is expected to be the pre-activation history for each neuron,
158
+ shaped (batch_size, n_neurons=N, history_length=in_dims).
159
+ - This layer holds unique weights (`w1`) and biases (`b1`) for *each* of the `N` neurons.
160
+ `w1` has shape (in_dims, out_dims, N), `b1` has shape (1, N, out_dims).
161
+ - `torch.einsum('bni,ion->bno', x, self.w1)` performs N independent matrix
162
+ multiplications in parallel (mapping from dim `i` to `o` for each neuron `n`):
163
+ - For each neuron `n` (from 0 to N-1):
164
+ - It takes the neuron's history `x[:, n, :]` (shape B, in_dims).
165
+ - Multiplies it by the neuron's unique weight matrix `self.w1[:, :, n]` (shape in_dims, out_dims).
166
+ - Resulting in `out[:, n, :]` (shape B, out_dims).
167
+ - The unique bias `self.b1[:, n, :]` is added.
168
+ - The result is squeezed on the last dim (if out_dims=1) and divided by the learnable temperature `T`.
169
+
170
+ This allows each neuron `d` to process its temporal history `A_d^t` using
171
+ its private parameters `theta_d` to produce the post-activation `z_d^{t+1}`,
172
+ enabling the fine-grained temporal dynamics central to the CTM.
173
+ It's typically used within the `trace_processor` module of the main CTM class.
174
+
175
+ Args:
176
+ in_dims (int): Input dimension (typically `memory_length`).
177
+ out_dims (int): Output dimension per neuron.
178
+ N (int): Number of independent linear models (typically `d_model`).
179
+ T (float): Initial value of the learnable temperature; the output is divided by it.
180
+ do_norm (bool): Apply Layer Normalization to the input history before linear transform.
181
+ dropout (float): Dropout rate applied to the input.
182
+ """
183
+ def __init__(self,
184
+ in_dims,
185
+ out_dims,
186
+ N,
187
+ T=1.0,
188
+ do_norm=False,
189
+ dropout=0):
190
+ super().__init__()
191
+ # N is the number of neurons (d_model), in_dims is the history length (memory_length)
192
+ self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()
193
+ self.in_dims = in_dims # Corresponds to memory_length
194
+ # LayerNorm applied across the history dimension for each neuron independently
195
+ self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True) if do_norm else Identity()
196
+ self.do_norm = do_norm
197
+
198
+ # Initialize weights and biases
199
+ # w1 shape: (memory_length, out_dims, d_model)
200
+ self.register_parameter('w1', nn.Parameter(
201
+ torch.empty((in_dims, out_dims, N)).uniform_(
202
+ -1/math.sqrt(in_dims + out_dims),
203
+ 1/math.sqrt(in_dims + out_dims)
204
+ ), requires_grad=True)
205
+ )
206
+ # b1 shape: (1, d_model, out_dims)
207
+ self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))
208
+ # Learnable temperature/scaler T
209
+ self.register_parameter('T', nn.Parameter(torch.Tensor([T])))
210
+
211
+ def forward(self, x):
212
+ """
213
+ Args:
214
+ x (torch.Tensor): Input tensor, expected shape (B, N, in_dims)
215
+ where B=batch, N=d_model, in_dims=memory_length.
216
+ Returns:
217
+ torch.Tensor: Output tensor, shape (B, N, out_dims), or (B, N) when out_dims == 1 (after squeeze(-1)).
218
+ """
219
+ # Input shape: (B, D, M) where D=d_model=N neurons in CTM, M=history/memory length
220
+ out = self.dropout(x)
221
+ # LayerNorm across the memory_length dimension (dim=-1)
222
+ out = self.layernorm(out) # Shape remains (B, N, M)
223
+
224
+ # Apply N independent linear models using einsum
225
+ # einsum('BDM,MHD->BDH', ...)
226
+ # x: (B=batch size, D=N neurons, one NLM per each of these, M=history/memory length)
227
+ # w1: (M, H=hidden dims if using MLP, otherwise output, D=N neurons, parallel)
228
+ # b1: (1, D=N neurons, H)
229
+ # einsum result: (B, D, H)
230
+ # Applying bias requires matching shapes, b1 is broadcasted.
231
+ out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1
232
+
233
+ # Squeeze the output dimension (usually of size 1) and divide by T
234
+ # This matches the original code's structure exactly.
235
+ out = out.squeeze(-1) / self.T
236
+ return out
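The claim that the einsum in `SuperLinear.forward` applies one private linear map per neuron can be verified against an explicit loop. This is a minimal check under illustrative sizes; the tensors `x`, `w1`, and `b1` are random stand-ins with the shapes described in the docstring, not the module's trained parameters.

```python
import torch

torch.manual_seed(0)
B, D, M, H = 3, 5, 7, 2  # batch, neurons, memory length, outputs per neuron

x = torch.randn(B, D, M)
w1 = torch.randn(M, H, D)
b1 = torch.randn(1, D, H)

# Batched form, same subscripts as the forward pass above.
batched = torch.einsum('BDM,MHD->BDH', x, w1) + b1

# Reference form: neuron d sees only its own history and its own weights.
looped = torch.stack(
    [x[:, d, :] @ w1[:, :, d] + b1[0, d] for d in range(D)],
    dim=1,
)
```

The two results agree up to floating-point tolerance, which is why the einsum can replace `D` separate `nn.Linear` layers.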
237
+
238
+
239
+ # --- Backbone Modules ---
240
+
241
+ class ParityBackbone(nn.Module):
242
+ def __init__(self, n_embeddings, d_embedding):
243
+ super(ParityBackbone, self).__init__()
244
+ self.embedding = nn.Embedding(n_embeddings, d_embedding)
245
+
246
+ def forward(self, x):
247
+ """
248
+ Maps -1 (negative parity) to 0 and 1 (positive) to 1
249
+ """
250
+ x = (x == 1).long()
251
+ return self.embedding(x.long()).transpose(1, 2) # Transpose for compatibility with other backbones
252
+
253
+ class QAMNISTOperatorEmbeddings(nn.Module):
254
+ def __init__(self, num_operator_types, d_projection):
255
+ super(QAMNISTOperatorEmbeddings, self).__init__()
256
+ self.embedding = nn.Embedding(num_operator_types, d_projection)
257
+
258
+ def forward(self, x):
259
+ # -1 for plus and -2 for minus
260
+ return self.embedding(-x - 1)
261
+
262
+ class QAMNISTIndexEmbeddings(torch.nn.Module):
263
+ def __init__(self, max_seq_length, embedding_dim):
264
+ super().__init__()
265
+ self.max_seq_length = max_seq_length
266
+ self.embedding_dim = embedding_dim
267
+
268
+ embedding = torch.zeros(max_seq_length, embedding_dim)
269
+ position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
270
+ div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
271
+
272
+ embedding[:, 0::2] = torch.sin(position * div_term)
273
+ embedding[:, 1::2] = torch.cos(position * div_term)
274
+
275
+ self.register_buffer('embedding', embedding)
276
+
277
+ def forward(self, x):
278
+ return self.embedding[x]
279
+
280
+ class ThoughtSteps:
281
+ """
282
+ Helper class for managing "thought steps" in the ctm_qamnist pipeline.
283
+
284
+ Args:
285
+ iterations_per_digit (int): Number of iterations for each digit.
286
+ iterations_per_question_part (int): Number of iterations for each question part.
287
+ total_iterations_for_answering (int): Total number of iterations for answering.
288
+ total_iterations_for_digits (int): Total number of iterations for digits.
289
+ total_iterations_for_question (int): Total number of iterations for question.
290
+ """
291
+ def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question):
292
+ self.iterations_per_digit = iterations_per_digit
293
+ self.iterations_per_question_part = iterations_per_question_part
294
+ self.total_iterations_for_digits = total_iterations_for_digits
295
+ self.total_iterations_for_question = total_iterations_for_question
296
+ self.total_iterations_for_answering = total_iterations_for_answering
297
+ self.total_iterations = self.total_iterations_for_digits + self.total_iterations_for_question + self.total_iterations_for_answering
298
+
299
+ def determine_step_type(self, stepi: int):
300
+ is_digit_step = stepi < self.total_iterations_for_digits
301
+ is_question_step = self.total_iterations_for_digits <= stepi < self.total_iterations_for_digits + self.total_iterations_for_question
302
+ is_answer_step = stepi >= self.total_iterations_for_digits + self.total_iterations_for_question
303
+ return is_digit_step, is_question_step, is_answer_step
304
+
305
+ def determine_answer_step_type(self, stepi: int):
306
+ step_within_questions = stepi - self.total_iterations_for_digits
307
+ if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
308
+ is_index_step = True
309
+ is_operator_step = False
310
+ else:
311
+ is_index_step = False
312
+ is_operator_step = True
313
+ return is_index_step, is_operator_step
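The phase logic in `ThoughtSteps` can be traced with a small standalone function. `step_kind` is a hypothetical helper that folds `determine_step_type` and `determine_answer_step_type` into one label per step; the step counts are made up for the example.

```python
def step_kind(stepi, total_digit_steps, total_question_steps, per_question_part):
    """Hypothetical helper: label a step as digit, index, operator, or answer."""
    if stepi < total_digit_steps:
        return "digit"
    if stepi < total_digit_steps + total_question_steps:
        # Question steps alternate between index parts and operator parts.
        within = stepi - total_digit_steps
        is_index = within % (2 * per_question_part) < per_question_part
        return "index" if is_index else "operator"
    return "answer"


# 4 digit steps, 4 question steps (2 per part), then 2 answer steps.
schedule = [step_kind(i, 4, 4, 2) for i in range(10)]
```

Here steps 0 to 3 are digit steps, 4 and 5 are index steps, 6 and 7 are operator steps, and 8 and 9 are answer steps.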
314
+
315
+ class MNISTBackbone(nn.Module):
316
+ """
317
+ Simple backbone for MNIST feature extraction.
318
+ """
319
+ def __init__(self, d_input):
320
+ super(MNISTBackbone, self).__init__()
321
+ self.layers = nn.Sequential(
322
+ nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
323
+ nn.BatchNorm2d(d_input),
324
+ nn.ReLU(),
325
+ nn.MaxPool2d(2, 2),
326
+ nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
327
+ nn.BatchNorm2d(d_input),
328
+ nn.ReLU(),
329
+ nn.MaxPool2d(2, 2),
330
+ )
331
+
332
+ def forward(self, x):
333
+ return self.layers(x)
334
+
335
+
336
+ class MiniGridBackbone(nn.Module):
337
+ def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8):
338
+ super().__init__()
339
+ self.object_embedding = nn.Embedding(num_objects, embedding_dim)
340
+ self.color_embedding = nn.Embedding(num_colors, embedding_dim)
341
+ self.state_embedding = nn.Embedding(num_states, embedding_dim)
342
+
343
+ self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim)
344
+
345
+ self.project_to_d_projection = nn.Sequential(
346
+ nn.Linear(embedding_dim * 4, d_input * 2),
347
+ nn.GLU(),
348
+ nn.LayerNorm(d_input),
349
+ nn.Linear(d_input, d_input * 2),
350
+ nn.GLU(),
351
+ nn.LayerNorm(d_input)
352
+ )
353
+
354
+ def forward(self, x):
355
+ x = x.long()
356
+ B, H, W, C = x.size()
357
+
358
+ object_idx = x[:,:,:, 0]
359
+ color_idx = x[:,:,:, 1]
360
+ state_idx = x[:,:,:, 2]
361
+
362
+ obj_embed = self.object_embedding(object_idx)
363
+ color_embed = self.color_embedding(color_idx)
364
+ state_embed = self.state_embedding(state_idx)
365
+
366
+ pos_idx = torch.arange(H * W, device=x.device).view(1, H, W).expand(B, -1, -1)
367
+ pos_embed = self.position_embedding(pos_idx)
368
+
369
+ out = self.project_to_d_projection(torch.cat([obj_embed, color_embed, state_embed, pos_embed], dim=-1))
370
+ return out
371
+
372
+ class ClassicControlBackbone(nn.Module):
373
+ def __init__(self, d_input):
374
+ super().__init__()
375
+ self.input_projector = nn.Sequential(
376
+ nn.Flatten(),
377
+ nn.LazyLinear(d_input * 2),
378
+ nn.GLU(),
379
+ nn.LayerNorm(d_input),
380
+ nn.LazyLinear(d_input * 2),
381
+ nn.GLU(),
382
+ nn.LayerNorm(d_input)
383
+ )
384
+
385
+ def forward(self, x):
386
+ return self.input_projector(x)
387
+
388
+
389
+ class ShallowWide(nn.Module):
390
+ """
391
+ Simple, wide, shallow convolutional backbone for image feature extraction.
392
+
393
+ Alternative to ResNet, uses grouped convolutions and GLU activations.
394
+ Fixed structure, useful for specific experiments.
395
+ """
396
+ def __init__(self):
397
+ super(ShallowWide, self).__init__()
398
+ # LazyConv2d infers input channels
399
+ self.layers = nn.Sequential(
400
+ nn.LazyConv2d(4096, kernel_size=3, stride=2, padding=1), # Output channels = 4096
401
+ nn.GLU(dim=1), # Halves channels to 2048
402
+ nn.BatchNorm2d(2048),
403
+ # Grouped convolution maintains width but processes groups independently
404
+ nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1, groups=32),
405
+ nn.GLU(dim=1), # Halves channels to 2048
406
+ nn.BatchNorm2d(2048)
407
+ )
408
+ def forward(self, x):
409
+ return self.layers(x)
410
+
411
+
412
+ class PretrainedResNetWrapper(nn.Module):
413
+ """
414
+ Wrapper to use standard pre-trained ResNet models from torchvision.
415
+
416
+ Loads a specified ResNet architecture pre-trained on ImageNet, removes the
417
+ final classification layer (fc), average pooling, and optionally later layers
418
+ (e.g., layer4), allowing it to be used as a feature extractor backbone.
419
+
420
+ Args:
421
+ resnet_type (str): Name of the ResNet model (e.g., 'resnet18', 'resnet50').
422
+ fine_tune (bool): If False, freezes the weights of the pre-trained backbone.
423
+ """
424
+ def __init__(self, resnet_type, fine_tune=True):
425
+ super(PretrainedResNetWrapper, self).__init__()
426
+ self.resnet_type = resnet_type
427
+ self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)
428
+
429
+ if not fine_tune:
430
+ for param in self.backbone.parameters():
431
+ param.requires_grad = False
432
+
433
+ # Remove final layers to use as feature extractor
434
+ self.backbone.avgpool = Identity()
435
+ self.backbone.fc = Identity()
436
+ # Keep layer4 by default, user can modify instance if needed
437
+ # self.backbone.layer4 = Identity()
438
+
439
+ def forward(self, x):
440
+ # Get features from the modified ResNet
441
+ out = self.backbone(x)
442
+
443
+ # Reshape output to (B, C, H, W) - This is heuristic based on original comment.
444
+ # User might need to adjust this based on which layers are kept/removed.
445
+ # Infer C based on ResNet type (example values)
446
+ nc = 256 if ('18' in self.resnet_type or '34' in self.resnet_type) else 512 if '50' in self.resnet_type else 1024 if '101' in self.resnet_type else 2048 # Approx for layer3/4 output channel numbers
447
+ # Infer H, W assuming output is flattened C * H * W
448
+ num_features = out.shape[-1]
449
+ # This calculation assumes nc is correct and feature map is square
450
+ wh_squared = num_features / nc
451
+ if wh_squared < 0 or not float(wh_squared).is_integer():
452
+ print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}")
453
+ # Return potentially flattened features if reshape fails
454
+ return out
455
+ wh = int(np.sqrt(wh_squared))
456
+
457
+ return out.reshape(x.size(0), nc, wh, wh)
458
+
459
+ # --- Positional Encoding Modules ---
460
+
461
+ class LearnableFourierPositionalEncoding(nn.Module):
462
+ """
463
+ Learnable Fourier Feature Positional Encoding.
464
+
465
+ Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
466
+ Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
467
+ Provides positional information for 2D feature maps.
468
+
469
+ Args:
470
+ d_model (int): The output dimension of the positional encoding (D).
471
+ G (int): Positional groups (default 1).
472
+ M (int): Dimensionality of input coordinates (default 2 for H, W).
473
+ F_dim (int): Dimension of the Fourier features.
474
+ H_dim (int): Hidden dimension of the MLP.
475
+ gamma (float): Initialization scale for the Fourier projection weights (Wr).
476
+ """
477
+ def __init__(self, d_model,
478
+ G=1, M=2,
479
+ F_dim=256,
480
+ H_dim=128,
481
+ gamma=1/2.5,
482
+ ):
483
+ super().__init__()
484
+ self.G = G
485
+ self.M = M
486
+ self.F_dim = F_dim
487
+ self.H_dim = H_dim
488
+ self.D = d_model
489
+ self.gamma = gamma
490
+
491
+ self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
492
+ self.mlp = nn.Sequential(
493
+ nn.Linear(self.F_dim, self.H_dim, bias=True),
494
+ nn.GLU(), # Halves H_dim
495
+ nn.Linear(self.H_dim // 2, self.D // self.G),
496
+ nn.LayerNorm(self.D // self.G)
497
+ )
498
+
499
+ self.init_weights()
500
+
501
+ def init_weights(self):
502
+ nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
503
+
504
+ def forward(self, x):
505
+ """
506
+ Computes positional encodings for the input feature map x.
507
+
508
+ Args:
509
+ x (torch.Tensor): Input feature map, shape (B, C, H, W).
510
+
511
+ Returns:
512
+ torch.Tensor: Positional encoding tensor, shape (B, D, H, W).
513
+ """
514
+ B, C, H, W = x.shape
515
+ # Creates coordinates based on (H, W) and repeats for batch B.
516
+ # Takes x[:,0] assuming channel dim isn't needed for coords.
517
+ x_coord = add_coord_dim(x[:,0]) # Expects (B, H, W) -> (B, H, W, 2)
518
+
519
+ # Compute Fourier features
520
+ projected = self.Wr(x_coord) # (B, H, W, F_dim // 2)
521
+ cosines = torch.cos(projected)
522
+ sines = torch.sin(projected)
523
+ F = (1.0 / math.sqrt(self.F_dim)) * torch.cat([cosines, sines], dim=-1) # (B, H, W, F_dim)
524
+
525
+ # Project features through MLP
526
+ Y = self.mlp(F) # (B, H, W, D // G)
527
+
528
+ # Reshape to (B, D, H, W)
529
+ PEx = Y.permute(0, 3, 1, 2) # Assuming G=1
530
+ return PEx
531
+
532
+
533
+ class MultiLearnableFourierPositionalEncoding(nn.Module):
534
+ """
535
+ Combines multiple LearnableFourierPositionalEncoding modules with different
536
+ initialization scales (gamma) via a learnable weighted sum.
537
+
538
+ Allows the model to learn an optimal combination of positional frequencies.
539
+
540
+ Args:
541
+ d_model (int): Output dimension of the encoding.
542
+ G, M, F_dim, H_dim: Parameters passed to underlying LearnableFourierPositionalEncoding.
543
+ gamma_range (list[float]): Min and max gamma values for the linspace.
544
+ N (int): Number of parallel embedding modules to create.
545
+ """
546
+ def __init__(self, d_model,
547
+ G=1, M=2,
548
+ F_dim=256,
549
+ H_dim=128,
550
+ gamma_range=[1.0, 0.1], # Default range
551
+ N=10,
552
+ ):
553
+ super().__init__()
554
+ self.embedders = nn.ModuleList()
555
+ for gamma in np.linspace(gamma_range[0], gamma_range[1], N):
556
+ self.embedders.append(LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma))
557
+
558
+ # Learnable combination weights, one per embedder (softmax-normalized in forward).
559
+ # The parameter is registered under the name 'combination'.
560
+ self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True))
561
+ self.N = N
562
+
563
+
564
+ def forward(self, x):
565
+ """
566
+ Computes combined positional encoding.
567
+
568
+ Args:
569
+ x (torch.Tensor): Input feature map, shape (B, C, H, W).
570
+
571
+ Returns:
572
+ torch.Tensor: Combined positional encoding tensor, shape (B, D, H, W).
573
+ """
574
+ # Compute embeddings from all modules and stack: (N, B, D, H, W)
575
+ pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0)
576
+
577
+ # Compute combination weights using softmax
578
+ # Use registered parameter name 'combination'
579
+ # Reshape weights for broadcasting: (N,) -> (N, 1, 1, 1, 1)
580
+ weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1)
581
+
582
+ # Compute weighted sum over the N dimension
583
+ combined_emb = (pos_embs * weights).sum(0) # (B, D, H, W)
584
+ return combined_emb
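The softmax-weighted sum in `MultiLearnableFourierPositionalEncoding.forward` reduces to a plain average at initialization, because `combination` starts as a vector of ones. The sketch below checks that with random stand-ins for the stacked embeddings; the sizes are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, B, D, H, W = 4, 2, 8, 3, 3

# Stand-in for the stacked outputs of the N embedders: (N, B, D, H, W).
pos_embs = torch.randn(N, B, D, H, W)
combination = torch.ones(N)  # same initial value as the registered parameter

# Softmax over N, reshaped so it broadcasts over (B, D, H, W).
weights = F.softmax(combination, dim=-1).view(N, 1, 1, 1, 1)
combined = (pos_embs * weights).sum(0)
```

During training the weights move away from uniform, so the module can favour some gamma scales over others.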
585
+
586
+
587
+ class CustomRotationalEmbedding(nn.Module):
588
+ """
589
+ Custom Rotational Positional Embedding.
590
+
591
+ Generates 2D positional embeddings based on rotating a fixed start vector.
592
+ The rotation angle for each grid position is determined primarily by its
593
+ horizontal position (width dimension). The resulting rotated vectors are
594
+ concatenated and projected.
595
+
596
+ Note: The current implementation derives angles only from the width dimension (`x.size(-1)`).
597
+
598
+ Args:
599
+ d_model (int): Dimensionality of the output embeddings.
600
+ """
601
+ def __init__(self, d_model):
602
+ super(CustomRotationalEmbedding, self).__init__()
603
+ # Learnable 2D start vector
604
+ self.register_parameter('start_vector', nn.Parameter(torch.Tensor([0, 1]), requires_grad=True))
605
+ # Projects the 4D concatenated rotated vectors to d_model
606
+ # Input size 4 comes from concatenating two 2D rotated vectors
607
+ self.projection = nn.Sequential(nn.Linear(4, d_model))
608
+
609
+ def forward(self, x):
610
+ """
611
+ Computes rotational positional embeddings based on input width.
612
+
613
+ Args:
614
+ x (torch.Tensor): Input tensor (used for shape and device),
615
+ shape (batch_size, channels, height, width).
616
+ Returns:
617
+ Output tensor containing positional embeddings,
618
+ shape (1, d_model, width, width); the batch dim is 1 because the PE is the same for every sample. This equals (1, d_model, height, width) only when H == W.
619
+ """
620
+ B, C, H, W = x.shape
621
+ device = x.device
622
+
623
+ # --- Generate rotations based only on Width ---
624
+ # Angles derived from width dimension
625
+ theta_rad = torch.deg2rad(torch.linspace(0, 180, W, device=device)) # Angle per column
626
+ cos_theta = torch.cos(theta_rad)
627
+ sin_theta = torch.sin(theta_rad)
628
+
629
+ # Create rotation matrices: Shape (W, 2, 2)
630
+ # Build the two matrix rows as (W, 2) tensors, then stack them along dim 1
631
+ rotation_matrices = torch.stack([
632
+ torch.stack([cos_theta, -sin_theta], dim=-1), # Shape (W, 2)
633
+ torch.stack([sin_theta, cos_theta], dim=-1) # Shape (W, 2)
634
+ ], dim=1) # Stacks along dim 1 -> Shape (W, 2, 2)
635
+
636
+ # Rotate the start vector by column angle: Shape (W, 2)
637
+ rotated_vectors = torch.einsum('wij,j->wi', rotation_matrices, self.start_vector)
638
+
639
+ # --- Create Grid Key ---
640
+ # Original code uses repeats based on rotated_vectors.shape[0] (which is W) for both dimensions.
641
+ # This creates a (W, W, 4) key tensor.
642
+ key = torch.cat((
643
+ torch.repeat_interleave(rotated_vectors.unsqueeze(1), W, dim=1), # (W, 1, 2) -> (W, W, 2)
644
+ torch.repeat_interleave(rotated_vectors.unsqueeze(0), W, dim=0) # (1, W, 2) -> (W, W, 2)
645
+ ), dim=-1) # Shape (W, W, 4)
646
+
647
+ # Project the 4D key vector to d_model: Shape (W, W, d_model)
648
+ pe_grid = self.projection(key)
649
+
650
+ # Reshape to (1, d_model, W, W) and then select/resize to target H, W?
651
+ # Original code permutes to (d_model, W, W) and unsqueezes to (1, d_model, W, W)
652
+ pe = pe_grid.permute(2, 0, 1).unsqueeze(0)
653
+
654
+ # If H != W, this needs adjustment. Assuming H=W or cropping/padding happens later.
655
+ # Let's return the (1, d_model, W, W) tensor as generated by the original logic.
656
+ # If H != W, downstream code must handle the mismatch or this PE needs modification.
657
+ if H != W:
658
+ # Simple interpolation/cropping could be added, but sticking to original logic:
659
+ # Option 1: Interpolate
660
+ # pe = F.interpolate(pe, size=(H, W), mode='bilinear', align_corners=False)
661
+ # Option 2: Crop/Pad (e.g., crop if W > W_target, pad if W < W_target)
662
+ # Sticking to original: return shape (1, d_model, W, W)
663
+ pass
664
+
665
+ return pe
666
+
667
+ class CustomRotationalEmbedding1D(nn.Module):
668
+ def __init__(self, d_model):
669
+ super(CustomRotationalEmbedding1D, self).__init__()
670
+ self.projection = nn.Linear(2, d_model)
671
+
672
+ def forward(self, x):
673
+ start_vector = torch.tensor([0., 1.], device=x.device, dtype=torch.float)
674
+ theta_rad = torch.deg2rad(torch.linspace(0, 180, x.size(2), device=x.device))
675
+ cos_theta = torch.cos(theta_rad)
676
+ sin_theta = torch.sin(theta_rad)
677
+ cos_theta = cos_theta.unsqueeze(1) # Shape: (height, 1)
678
+ sin_theta = sin_theta.unsqueeze(1) # Shape: (height, 1)
679
+
680
+ # Create rotation matrices
681
+ rotation_matrices = torch.stack([
682
+ torch.cat([cos_theta, -sin_theta], dim=1),
683
+ torch.cat([sin_theta, cos_theta], dim=1)
684
+ ], dim=1) # Shape: (height, 2, 2)
685
+
686
+ # Rotate the start vector
687
+ rotated_vectors = torch.einsum('bij,j->bi', rotation_matrices, start_vector)
688
+
689
+ pe = self.projection(rotated_vectors)
690
+ pe = torch.repeat_interleave(pe.unsqueeze(0), x.size(0), 0)
691
+ return pe.transpose(1, 2) # Transpose for compatibility with other backbones
692
+
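The two rotational embedding classes above rotate a fixed start vector `[0, 1]` by one angle per position, with the angles spaced evenly from 0 to 180 degrees. The following is a minimal pure-Python sketch of that rotation step only, with no learned projection. The helper name `rotated_start_vectors` is hypothetical and is not part of the repository.

```python
import math

def rotated_start_vectors(n_positions, start=(0.0, 1.0)):
    """Rotate `start` by n_positions angles evenly spaced over [0, 180] degrees.

    Hypothetical helper for illustration; it mirrors the rotation step only,
    not the linear projection that follows it in the modules.
    """
    if n_positions == 1:
        angles = [0.0]  # a single-step linspace yields only the start value
    else:
        step = 180.0 / (n_positions - 1)
        angles = [i * step for i in range(n_positions)]

    vectors = []
    for deg in angles:
        theta = math.radians(deg)
        c, s = math.cos(theta), math.sin(theta)
        x, y = start
        # Apply the rotation matrix [[c, -s], [s, c]] to (x, y)
        vectors.append((c * x - s * y, s * x + c * y))
    return vectors

vecs = rotated_start_vectors(5)
for v in vecs:
    print(tuple(round(component, 3) for component in v))
```

With five positions the start vector sweeps from pointing up, through pointing left at 90 degrees, to pointing down at 180 degrees, so neighbouring positions receive similar vectors.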
models/resnet.py ADDED
@@ -0,0 +1,374 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ from models.modules import Identity
5
+
6
+ __all__ = [
7
+ "ResNet",
8
+ "resnet18",
9
+ "resnet34",
10
+ "resnet50",
11
+ "resnet101",
12
+ "resnet152",
13
+ ]
14
+
15
+
16
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
17
+ """3x3 convolution with padding"""
18
+ return nn.Conv2d(
19
+ in_planes,
20
+ out_planes,
21
+ kernel_size=3,
22
+ stride=stride,
23
+ padding=dilation,
24
+ groups=groups,
25
+ bias=False,
26
+ dilation=dilation,
27
+ )
28
+
29
+
30
+ def conv1x1(in_planes, out_planes, stride=1):
31
+ """1x1 convolution"""
32
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
33
+
34
+
35
+ class BasicBlock(nn.Module):
36
+ expansion = 1
37
+
38
+ def __init__(
39
+ self,
40
+ inplanes,
41
+ planes,
42
+ stride=1,
43
+ downsample=None,
44
+ groups=1,
45
+ base_width=64,
46
+ dilation=1,
47
+ norm_layer=None,
48
+ ):
49
+ super(BasicBlock, self).__init__()
50
+ if norm_layer is None:
51
+ norm_layer = nn.BatchNorm2d
52
+ if groups != 1 or base_width != 64:
53
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
54
+ if dilation > 1:
55
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
56
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
57
+ self.conv1 = conv3x3(inplanes, planes, stride)
58
+ self.bn1 = norm_layer(planes)
59
+ self.relu = nn.ReLU(inplace=True)
60
+ self.conv2 = conv3x3(planes, planes)
61
+ self.bn2 = norm_layer(planes)
62
+ self.downsample = downsample
63
+ self.stride = stride
64
+
65
+ def forward(self, x):
66
+ identity = x
67
+
68
+ out = self.conv1(x)
69
+ out = self.bn1(out)
70
+ out = self.relu(out)
71
+
72
+ out = self.conv2(out)
73
+ out = self.bn2(out)
74
+
75
+ if self.downsample is not None:
76
+ identity = self.downsample(x)
77
+
78
+ out += identity
79
+
80
+ out = self.relu(out)
81
+ return out
82
+
83
+
84
+ class Bottleneck(nn.Module):
85
+ expansion = 4
86
+
87
+ def __init__(
88
+ self,
89
+ inplanes,
90
+ planes,
91
+ stride=1,
92
+ downsample=None,
93
+ groups=1,
94
+ base_width=64,
95
+ dilation=1,
96
+ norm_layer=None,
97
+ ):
98
+ super(Bottleneck, self).__init__()
99
+ if norm_layer is None:
100
+ norm_layer = nn.BatchNorm2d
101
+ width = int(planes * (base_width / 64.0)) * groups
102
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
103
+ self.conv1 = conv1x1(inplanes, width)
104
+ self.bn1 = norm_layer(width)
105
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
106
+ self.bn2 = norm_layer(width)
107
+ self.conv3 = conv1x1(width, planes * self.expansion)
108
+ self.bn3 = norm_layer(planes * self.expansion)
109
+ self.relu = nn.ReLU(inplace=True)
110
+ self.downsample = downsample
111
+ self.stride = stride
112
+
113
+ def forward(self, x):
114
+ identity = x
115
+
116
+ out = self.conv1(x)
117
+ out = self.bn1(out)
118
+ out = self.relu(out)
119
+
120
+ out = self.conv2(out)
121
+ out = self.bn2(out)
122
+ out = self.relu(out)
123
+
124
+ out = self.conv3(out)
125
+ out = self.bn3(out)
126
+
127
+ if self.downsample is not None:
128
+ identity = self.downsample(x)
129
+
130
+ out += identity
131
+
132
+
133
+ # activation = None
134
+ # activation = out.detach().cpu().numpy()
135
+ out = self.relu(out)
136
+ # return out, activation
137
+
138
+ return out
139
+
140
+
141
+ class ResNet(nn.Module):
142
+ def __init__(
143
+ self,
144
+ in_channels,
145
+ feature_scales,
146
+ stride,
147
+ block,
148
+ layers,
149
+ num_classes=10,
150
+ zero_init_residual=False,
151
+ groups=1,
152
+ width_per_group=64,
153
+ replace_stride_with_dilation=None,
154
+ norm_layer=None,
155
+ do_initial_max_pool=True,
156
+ ):
157
+ super(ResNet, self).__init__()
158
+ if norm_layer is None:
159
+ norm_layer = nn.BatchNorm2d
160
+ self._norm_layer = norm_layer
161
+
162
+ self.inplanes = 64
163
+ self.dilation = 1
164
+ if replace_stride_with_dilation is None:
165
+ # each element in the tuple indicates if we should replace
166
+ # the 2x2 stride with a dilated convolution instead
167
+ replace_stride_with_dilation = [False, False, False]
168
+ if len(replace_stride_with_dilation) != 3:
169
+ raise ValueError(
170
+ "replace_stride_with_dilation should be None "
171
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
172
+ )
173
+ self.groups = groups
174
+ self.base_width = width_per_group
175
+
176
+ # NOTE: Important!
177
+ # This has changed from a kernel size of 7 (padding=3) to a kernel of 3 (padding=1)
178
+ # The reason for this was to limit the receptive field to constrain models to
179
+ # "Looking around" to gather information.
180
+
181
+ self.conv1 = nn.Conv2d(
182
+ in_channels, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
183
+ ) if in_channels in [1, 3] else nn.LazyConv2d(
184
+ self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
185
+ )
186
+ # END
187
+
188
+ self.bn1 = norm_layer(self.inplanes)
189
+ self.relu = nn.ReLU(inplace=True)
190
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if do_initial_max_pool else Identity()
191
+ self.layer1 = self._make_layer(block, 64, layers[0])
192
+ self.feature_scales = feature_scales
193
+ if 2 in feature_scales:
194
+ self.layer2 = self._make_layer(
195
+ block, 128, layers[1], stride=stride, dilate=replace_stride_with_dilation[0]
196
+ )
197
+ if 3 in feature_scales:
198
+ self.layer3 = self._make_layer(
199
+ block, 256, layers[2], stride=stride, dilate=replace_stride_with_dilation[1]
200
+ )
201
+ if 4 in feature_scales:
202
+ self.layer4 = self._make_layer(
203
+ block, 512, layers[3], stride=stride, dilate=replace_stride_with_dilation[2]
204
+ )
205
+
206
+ # NOTE: Commented this out as it is not used anymore for this work, kept it for reference
207
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
208
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
209
+
210
+ # for m in self.modules():
211
+ # if isinstance(m, nn.Conv2d):
212
+ # nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
213
+ # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
214
+ # nn.init.constant_(m.weight, 1)
215
+ # nn.init.constant_(m.bias, 0)
216
+
217
+ # Zero-initialize the last BN in each residual branch,
218
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
219
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
220
+ if zero_init_residual:
221
+ for m in self.modules():
222
+ if isinstance(m, Bottleneck):
223
+ nn.init.constant_(m.bn3.weight, 0)
224
+ elif isinstance(m, BasicBlock):
225
+ nn.init.constant_(m.bn2.weight, 0)
226
+
227
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
228
+ norm_layer = self._norm_layer
229
+ downsample = None
230
+ previous_dilation = self.dilation
231
+ if dilate:
232
+ self.dilation *= stride
233
+ stride = 1
234
+ if stride != 1 or self.inplanes != planes * block.expansion:
235
+ downsample = nn.Sequential(
236
+ conv1x1(self.inplanes, planes * block.expansion, stride),
237
+ norm_layer(planes * block.expansion),
238
+ )
239
+
240
+ layers = []
241
+ layers.append(
242
+ block(
243
+ self.inplanes,
244
+ planes,
245
+ stride,
246
+ downsample,
247
+ self.groups,
248
+ self.base_width,
249
+ previous_dilation,
250
+ norm_layer,
251
+ )
252
+ )
253
+ self.inplanes = planes * block.expansion
254
+ for _ in range(1, blocks):
255
+ layers.append(
256
+ block(
257
+ self.inplanes,
258
+ planes,
259
+ groups=self.groups,
260
+ base_width=self.base_width,
261
+ dilation=self.dilation,
262
+ norm_layer=norm_layer,
263
+ )
264
+ )
265
+
266
+ return nn.Sequential(*layers)
267
+
268
+ def forward(self, x):
269
+ activations = []
270
+ x = self.conv1(x)
271
+ x = self.bn1(x)
272
+ x = self.relu(x)
273
+ x = self.maxpool(x)
274
+ # if return_activations: activations.append(torch.clone(x))
275
+ x = self.layer1(x)
276
+
277
+ if 2 in self.feature_scales:
278
+ x = self.layer2(x)
279
+ if 3 in self.feature_scales:
280
+ x = self.layer3(x)
281
+ if 4 in self.feature_scales:
282
+ x = self.layer4(x)
283
+ return x
284
+
285
+
286
+ def _resnet(in_channels, feature_scales, stride, arch, block, layers, pretrained, progress, device, do_initial_max_pool, **kwargs):
287
+ model = ResNet(in_channels, feature_scales, stride, block, layers, do_initial_max_pool=do_initial_max_pool, **kwargs)
288
+ if pretrained:
289
+ assert in_channels==3
290
+ script_dir = os.path.dirname(__file__)
291
+ state_dict = torch.load(
292
+ script_dir + '/state_dicts/' + arch + ".pt", map_location=device
293
+ )
294
+ model.load_state_dict(state_dict, strict=False)
295
+ return model
296
+
297
+
298
+ def resnet18(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
299
+ """Constructs a ResNet-18 model.
300
+ Args:
301
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
302
+ progress (bool): If True, displays a progress bar of the download to stderr
303
+ """
304
+ return _resnet(in_channels,
305
+ feature_scales, stride, "resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, do_initial_max_pool, **kwargs
306
+ )
307
+
308
+
309
+ def resnet34(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
310
+ """Constructs a ResNet-34 model.
311
+ Args:
312
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
313
+ progress (bool): If True, displays a progress bar of the download to stderr
314
+ """
315
+ return _resnet(in_channels,
316
+ feature_scales, stride, "resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
317
+ )
318
+
319
+
320
+ def resnet50(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
321
+ """Constructs a ResNet-50 model.
322
+ Args:
323
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
324
+ progress (bool): If True, displays a progress bar of the download to stderr
325
+ """
326
+ return _resnet(in_channels,
327
+ feature_scales, stride, "resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
328
+ )
329
+
330
+
331
+ def resnet101(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
332
+ """Constructs a ResNet-101 model.
333
+ Args:
334
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
335
+ progress (bool): If True, displays a progress bar of the download to stderr
336
+ """
337
+ return _resnet(in_channels,
338
+ feature_scales, stride, "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
339
+ )
340
+
341
+
342
+ def resnet152(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
343
+ """Constructs a ResNet-152 model.
344
+ Args:
345
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
346
+ progress (bool): If True, displays a progress bar of the download to stderr
347
+ """
348
+ return _resnet(in_channels,
349
+ feature_scales, stride, "resnet152", Bottleneck, [3, 4, 36, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
350
+ )
351
+
352
+ def prepare_resnet_backbone(backbone_type):
353
+
354
+ resnet_family = resnet18 # Default
355
+ if '34' in backbone_type: resnet_family = resnet34
356
+ if '50' in backbone_type: resnet_family = resnet50
357
+ if '101' in backbone_type: resnet_family = resnet101
358
+ if '152' in backbone_type: resnet_family = resnet152
359
+
360
+ # Determine which ResNet blocks to keep
361
+ block_num_str = backbone_type.split('-')[-1]
362
+ hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]
363
+
364
+ backbone = resnet_family(
365
+ 3,
366
+ hyper_blocks_to_keep,
367
+ stride=2,
368
+ pretrained=False,
369
+ progress=True,
370
+ device="cpu",
371
+ do_initial_max_pool=True,
372
+ )
373
+
374
+ return backbone
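The string handling in `prepare_resnet_backbone` can be hard to follow at a glance. The sketch below isolates it under the assumption that backbone names look like `resnet152-4` (family, then an optional block count after the last hyphen). The helper name `parse_backbone_type` is hypothetical; it returns names instead of building a model.

```python
def parse_backbone_type(backbone_type):
    """Return (family_name, blocks_to_keep) for strings such as 'resnet152-4'.

    Hypothetical helper: it mirrors the substring checks and suffix parsing
    of prepare_resnet_backbone without constructing any network.
    """
    family = "resnet18"  # default when no other depth matches
    for depth in ("34", "50", "101", "152"):
        if depth in backbone_type:
            family = "resnet" + depth  # later matches override earlier ones

    suffix = backbone_type.split("-")[-1]
    if suffix.isdigit():
        blocks = list(range(1, int(suffix) + 1))
    else:
        blocks = [1, 2, 3, 4]  # no numeric suffix: keep all four stages
    return family, blocks

print(parse_backbone_type("resnet152-4"))
print(parse_backbone_type("resnet18-2"))
print(parse_backbone_type("resnet34"))
```

Because the depth checks are substring matches over the whole string, a name whose suffix happens to contain `34`, `50`, `101`, or `152` would select that family. Matching only the part before the hyphen would avoid this.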
models/utils.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import re
4
+ import os
5
+
6
+ def compute_decay(T, params, clamp_lims=(0, 15)):
7
+ """
8
+ This function computes exponential decays for learnable synchronisation
9
+ interactions between pairs of neurons.
10
+ """
11
+ assert len(clamp_lims) == 2, 'Clamp lims should be length 2'
12
+ assert isinstance(clamp_lims, tuple), 'Clamp lims should be tuple'
13
+
14
+ indices = torch.arange(T-1, -1, -1, device=params.device).reshape(T, 1).expand(T, params.shape[0])
15
+ out = torch.exp(-indices * torch.clamp(params, clamp_lims[0], clamp_lims[1]).unsqueeze(0))
16
+ return out
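To make the indexing in `compute_decay` concrete, here is a small pure-Python mirror of the same formula. It is a sketch for illustration only; the name `decay_table` is hypothetical, and plain lists stand in for tensors.

```python
import math

def decay_table(T, rates, clamp_lims=(0.0, 15.0)):
    """Pure-Python mirror of the decay computation, for illustration only.

    Entry [t][n] is exp(-(T - 1 - t) * clamp(rates[n])), so the most recent
    step (t = T - 1) always has weight 1 and older steps decay exponentially.
    """
    lo, hi = clamp_lims
    clamped = [min(max(r, lo), hi) for r in rates]
    return [[math.exp(-(T - 1 - t) * r) for r in clamped] for t in range(T)]

table = decay_table(T=3, rates=[0.0, 1.0, 100.0])
for row in table:
    print(row)
```

A rate of 0 gives uniform weights across all steps, while the rate of 100 is clamped to 15 before use, which keeps the exponent bounded.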
17
+
18
+ def add_coord_dim(x, scaled=True):
19
+ """
20
+ Adds a final dimension to the tensor representing 2D coordinates.
21
+
22
+ Args:
23
+ x: A PyTorch tensor of shape (B, H, W).
23
+
24
+ Returns:
25
+ A PyTorch tensor of shape (B, H, W, 2) with the last dimension
26
+ representing the 2D coordinates within the HW dimensions.
28
+ """
29
+ B, H, W = x.shape
30
+ # Create coordinate grids
31
+ x_coords = torch.arange(W, device=x.device, dtype=x.dtype).repeat(H, 1) # Shape (H, W)
32
+ y_coords = torch.arange(H, device=x.device, dtype=x.dtype).unsqueeze(-1).repeat(1, W) # Shape (H, W)
33
+ if scaled:
34
+ x_coords /= (W-1)
35
+ y_coords /= (H-1)
36
+ # Stack coordinates and expand dimensions
37
+ coords = torch.stack((x_coords, y_coords), dim=-1) # Shape (H, W, 2)
38
+ coords = coords.unsqueeze(0) # Shape (1, H, W, 2)
39
+ coords = coords.repeat(B, 1, 1, 1) # Shape (B, H, W, 2)
40
+ return coords
41
+
42
+ def compute_normalized_entropy(logits, reduction='mean'):
43
+ """
44
+ Calculates the normalized entropy of a PyTorch tensor of logits along the
45
+ final dimension.
46
+
47
+ Args:
48
+ logits: A PyTorch tensor of logits.
49
+
50
+ Returns:
51
+ A PyTorch tensor containing the normalized entropy values.
52
+ """
53
+
54
+ # Apply softmax to get probabilities
55
+ preds = F.softmax(logits, dim=-1)
56
+
57
+ # Calculate the log probabilities
58
+ log_preds = torch.log_softmax(logits, dim=-1)
59
+
60
+ # Calculate the entropy
61
+ entropy = -torch.sum(preds * log_preds, dim=-1)
62
+
63
+ # Calculate the maximum possible entropy
64
+ num_classes = preds.shape[-1]
65
+ max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32))
66
+
67
+ # Normalize the entropy
68
+ normalized_entropy = entropy / max_entropy
69
+ if len(logits.shape)>2 and reduction == 'mean':
70
+ normalized_entropy = normalized_entropy.flatten(1).mean(-1)
71
+
72
+ return normalized_entropy
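As a quick sanity check on what `compute_normalized_entropy` returns, the same quantity can be computed for a single logit vector in plain Python. This is a sketch assuming at least two classes; the function name `normalized_entropy` is hypothetical.

```python
import math

def normalized_entropy(logits):
    """Entropy of softmax(logits) divided by log(num_classes); lies in [0, 1].

    Illustrative only: handles one vector of logits with at least two entries.
    """
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return entropy / math.log(len(logits))

uniform = normalized_entropy([0.0, 0.0, 0.0, 0.0])
peaked = normalized_entropy([50.0, 0.0, 0.0, 0.0])
print(uniform, peaked)
```

Equal logits give the maximum value (close to 1), and one dominant logit gives a value close to 0, so a lower value indicates a more certain prediction.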
73
+
74
+ def reshape_predictions(predictions, prediction_reshaper):
75
+ B, T = predictions.size(0), predictions.size(-1)
76
+ new_shape = [B] + prediction_reshaper + [T]
77
+ reshaped_predictions = predictions.reshape(new_shape)
78
+ return reshaped_predictions
79
+
80
+ def get_all_log_dirs(root_dir):
81
+ folders = []
82
+ for dirpath, dirnames, filenames in os.walk(root_dir):
83
+ if any(f.endswith(".pt") for f in filenames):
84
+ folders.append(dirpath)
85
+ return folders
86
+
87
+ def get_latest_checkpoint(log_dir):
88
+ files = [f for f in os.listdir(log_dir) if re.match(r'checkpoint_\d+\.pt', f)]
89
+ return os.path.join(log_dir, max(files, key=lambda f: int(re.search(r'\d+', f).group()))) if files else None
90
+
91
+ def get_latest_checkpoint_file(filepath, limit=300000):
92
+ checkpoint_files = get_checkpoint_files(filepath)
93
+ checkpoint_files = [
94
+ f for f in checkpoint_files if int(re.search(r'checkpoint_(\d+)\.pt', f).group(1)) <= limit
95
+ ]
96
+ if not checkpoint_files:
97
+ return None
98
+ return checkpoint_files[-1]
99
+
100
+ def get_checkpoint_files(filepath):
101
+ regex = r'checkpoint_(\d+)\.pt'
102
+ files = [f for f in os.listdir(filepath) if re.match(regex, f)]
103
+ files = sorted(files, key=lambda f: int(re.search(regex, f).group(1)))
104
+ return [os.path.join(filepath, f) for f in files]
105
+
106
+ def load_checkpoint(checkpoint_path, device):
107
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
108
+ return checkpoint
109
+
110
+ def get_model_args_from_checkpoint(checkpoint):
111
+ if "args" in checkpoint:
112
+ return checkpoint["args"]
113
+ else:
114
+ raise ValueError("Checkpoint does not contain saved args.")
115
+
116
+ def get_accuracy_and_loss_from_checkpoint(checkpoint, device="cpu"):
117
+ training_iteration = checkpoint.get('training_iteration', 0)
118
+ train_losses = checkpoint.get('train_losses', [])
119
+ test_losses = checkpoint.get('test_losses', [])
120
+ train_accuracies = checkpoint.get('train_accuracies_most_certain', [])
121
+ test_accuracies = checkpoint.get('test_accuracies_most_certain', [])
122
+ return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies
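The checkpoint helpers above sort by the integer in the filename and not by the filename string. The sketch below shows why, using a temporary directory with empty placeholder files. The helper `latest_checkpoint` is a hypothetical stand-in, and it uses `fullmatch` where the original uses the looser `re.match`.

```python
import os
import re
import tempfile

CHECKPOINT_RE = re.compile(r"checkpoint_(\d+)\.pt")

def latest_checkpoint(log_dir):
    """Return the path of the checkpoint with the largest iteration, or None."""
    matches = [
        (int(m.group(1)), name)
        for name in os.listdir(log_dir)
        if (m := CHECKPOINT_RE.fullmatch(name))
    ]
    if not matches:
        return None
    # Tuples compare by iteration number first, so max() picks the newest
    return os.path.join(log_dir, max(matches)[1])

with tempfile.TemporaryDirectory() as tmp:
    for step in (2000, 10000, 900):
        open(os.path.join(tmp, f"checkpoint_{step}.pt"), "w").close()
    open(os.path.join(tmp, "notes.txt"), "w").close()  # ignored by the regex

    latest_name = os.path.basename(latest_checkpoint(tmp))
    lexicographic_max = max(n for n in os.listdir(tmp) if n.endswith(".pt"))

print(latest_name)
print(lexicographic_max)
```

A plain string comparison ranks `checkpoint_900.pt` above `checkpoint_10000.pt`, which is why the iteration number is parsed to an integer before comparing.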
mount_azure.sh ADDED
@@ -0,0 +1,44 @@
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ # Configuration
5
+ REMOTE_HOST="ssh.my-robot.dev"
6
+ REMOTE_USER="azureuser"
7
+ REMOTE_PATH="/mnt/lightrag"
8
+ LOCAL_MOUNT="/data/persistent"
9
+
10
+ # 1. Set up the SSH key
11
+ if [ -z "$SSH_KEY" ]; then
12
+ echo "❌ SSH_KEY is not set. Configure it under Settings -> Secrets"
13
+ exit 1
14
+ fi
15
+
16
+ mkdir -p ~/.ssh
17
+ echo "$SSH_KEY" > ~/.ssh/id_rsa
18
+ chmod 600 ~/.ssh/id_rsa
19
+
20
+ # 2. Configure the Cloudflare ProxyCommand
21
+ echo "Host $REMOTE_HOST
22
+ User $REMOTE_USER
23
+ IdentityFile ~/.ssh/id_rsa
24
+ ProxyCommand /usr/bin/cloudflared access ssh --hostname %h
25
+ StrictHostKeyChecking no
26
+ " > ~/.ssh/config
27
+
28
+ # 3. Create the mount point
29
+ if [ ! -d "$LOCAL_MOUNT" ]; then
30
+ mkdir -p $LOCAL_MOUNT
31
+ chmod 777 $LOCAL_MOUNT || true
32
+ fi
33
+
34
+ # 4. Mount via SSHFS
35
+ echo "🔌 Connecting to the Bunker ($REMOTE_HOST)..."
36
+ sshfs $REMOTE_HOST:$REMOTE_PATH $LOCAL_MOUNT \
37
+ -o allow_other,reconnect,ServerAliveInterval=15,ServerAliveCountMax=3,idmap=user
38
+
39
+ if mountpoint -q $LOCAL_MOUNT; then
40
+ echo "✅ Azure disk mounted at $LOCAL_MOUNT"
40
+ ls -la $LOCAL_MOUNT || echo "Cannot list contents, but it is mounted"
41
+ else
42
+ echo "❌ Mount failed"
44
+ fi
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ # Requirements for CTM Nervous System v2.0 (Full PyTorch)
2
+ # =========================================================
3
+ # Upgraded to use full ContinuousThoughtMachine implementation
4
+
5
+ # Core Framework
6
+ gradio>=5.0.0
7
+ numpy
8
+
9
+ # PyTorch (CPU-only for HuggingFace free tier)
10
+ # Use torch-cpu to minimize memory footprint
11
+ --extra-index-url https://download.pytorch.org/whl/cpu
12
+ torch>=2.0.0
13
+
14
+ # For checkpoint saving
15
+ safetensors
16
+
17
+ # For potential model weights download
18
+ huggingface_hub
19
+
20
+ # HTTP client for Brain server integration
21
+ requests
requirements_v1.txt ADDED
@@ -0,0 +1,2 @@
1
+ gradio>=5.0.0
2
+ numpy
setup_hf_space.sh ADDED
@@ -0,0 +1,37 @@
1
+ #!/bin/bash
2
+
3
+ # User configuration
4
+ git config --global user.email "ksj6ftj8f8@gmail.com"
5
+ git config --global user.name "Elliotasdasdasfasas"
6
+ git config --global credential.helper store
7
+
8
+ # Initialize Git if the repository does not exist yet
9
+ if [ ! -d ".git" ]; then
10
+ git init
11
+ git branch -M main
12
+ fi
13
+
14
+ # Define token and user (abort early if HF_TOKEN is missing from the environment)
15
+ HF_TOKEN="${HF_TOKEN:?HF_TOKEN must be set in the environment}"
16
+ HF_USER="ROBOT-GANSTA"
17
+ SPACE_NAME="robot-gansta-redneural"
18
+
19
+ # Configure the remote with an explicit token (read from the environment variable)
20
+ REMOTE_URL="https://$HF_USER:$HF_TOKEN@huggingface.co/spaces/$HF_USER/$SPACE_NAME"
21
+
22
+ echo "🔗 Connecting to: https://huggingface.co/spaces/$HF_USER/$SPACE_NAME"
23
+
24
+ # Remove the previous remote if it exists
25
+ git remote remove space 2>/dev/null
26
+
27
+ # Add the remote and pull/push
28
+ git remote add space "$REMOTE_URL"
29
+
30
+ # Add files
31
+ git add .
32
+ git commit -m "Deploy RedNeural Docker Environment with Fixes"
33
+
34
+ # Force the initial push
35
+ git push --force space main
36
+
37
+ echo "✅ Deployment complete."
tasks/image_classification/README.md ADDED
@@ -0,0 +1,31 @@
1
+ # Image classification
2
+
3
+ This folder contains code for training and analysing ImageNet- and CIFAR-related experiments.
4
+
5
+ ## Accessing and loading ImageNet
6
+
7
+ We use the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset in our paper.
8
+
9
+ To get this to work for you, you will need to do the following:
10
+ 1. Log in to Hugging Face (creating an account if needed) and agree to the terms and conditions of this dataset.
11
+ 2. Make a new access token.
12
+ 3. Install huggingface_hub on the target machine with ```pip install huggingface_hub```
13
+ 4. Run ```huggingface-cli login``` and use your token. This will authenticate you on the backend and allow the code to run.
14
+ 5. Run an ImageNet experiment. The dataset is downloaded automatically.
15
+
16
+
17
+ ## Training
18
+ There are two training files: `train.py` and `train_distributed.py`. The training code uses mixed precision. For the settings in the paper, the following command was used for distributed training:
19
+
20
+ ```
21
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp
22
+ ```
23
+
24
+ You can run the same setup on a single GPU with:
25
+ ```
26
+ python -m tasks.image_classification.train --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp --device 0
27
+ ```
28
+
29
+ ## Checkpoint
30
+
31
+ The checkpoint for the model used in the paper can be found [here](https://drive.google.com/file/d/1Lr_3RZU9X9SS8lBhAhECBiSZDKfKhDkJ/view?usp=drive_link).
tasks/image_classification/analysis/README.md ADDED
@@ -0,0 +1,7 @@
1
+ # Analysis
2
+
3
+ This folder contains the analysis code for the image classification experiments. Running the following from the base directory will generate figures, GIFs, and MP4 files:
4
+
5
+ ```
6
+ python -m tasks.image_classification.analysis.run_imagenet_analysis
7
+ ```
tasks/image_classification/analysis/run_imagenet_analysis.py ADDED
@@ -0,0 +1,972 @@
1
+ # --- Core Libraries ---
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ import argparse
6
+ from tqdm.auto import tqdm
7
+ import torch.nn.functional as F # Used for interpolate
8
+
9
+ # --- Plotting & Visualization ---
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib as mpl
12
+ mpl.use('Agg')
13
+ import seaborn as sns
14
+ sns.set_style('darkgrid')
15
+ from matplotlib import patheffects
17
+ import imageio
18
+ import cv2
19
+ from scipy.special import softmax
20
+ from tasks.image_classification.plotting import save_frames_to_mp4
21
+
22
+ # --- Data Handling & Model ---
23
+ from torchvision import transforms
24
+ from torchvision import datasets # Only used for CIFAR100 in debug mode
25
+ from scipy import ndimage # Used in find_island_centers
26
+ from data.custom_datasets import ImageNet
27
+ from models.ctm import ContinuousThoughtMachine
28
+ from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
29
+ from tasks.image_classification.plotting import plot_neural_dynamics
30
+
31
+ # --- Global Settings ---
32
+ np.seterr(divide='ignore')
33
+ mpl.use('Agg')
34
+ sns.set_style('darkgrid')
35
+
36
+ # --- Helper Functions ---
37
+
38
+ def find_island_centers(array_2d, threshold):
39
+ """
40
+ Finds the center of mass of each island (connected component > threshold)
41
+ in a 2D array, weighted by the array's values.
42
+ Returns list of (y, x) centers and list of areas.
43
+ """
44
+ binary_image = array_2d > threshold
45
+ labeled_image, num_labels = ndimage.label(binary_image)
46
+ centers = []
47
+ areas = []
48
+ # Calculate center of mass for each labeled island (label 0 is background)
49
+ for i in range(1, num_labels + 1):
50
+ island_mask = (labeled_image == i)
51
+ total_mass = np.sum(array_2d[island_mask])
52
+ if total_mass > 0:
53
+ # Get coordinates for this island
54
+ y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
55
+ # Calculate weighted average for center
56
+ x_center = np.average(x_coords[island_mask], weights=array_2d[island_mask])
57
+ y_center = np.average(y_coords[island_mask], weights=array_2d[island_mask])
58
+ centers.append((round(y_center, 4), round(x_center, 4)))
59
+ areas.append(np.sum(island_mask)) # Area is the count of pixels in the island
60
+ return centers, areas
61
+
62
+ def parse_args():
63
+ """Parses command-line arguments."""
64
+ # Note: Original had two ArgumentParser instances, using the second one.
65
+ parser = argparse.ArgumentParser(description="Visualize Continuous Thought Machine Attention")
66
+ parser.add_argument('--actions', type=str, nargs='+', default=['videos'], choices=['plots', 'videos', 'demo'], help="Actions to take. Plots=results plots; videos=gifs/mp4s to watch attention; demo: last frame of internal ticks")
67
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help="GPU device index or -1 for CPU")
68
+
69
+ parser.add_argument('--checkpoint', type=str, default='checkpoints/imagenet/ctm_clean.pt', help="Path to CTM checkpoint")
70
+ parser.add_argument('--output_dir', type=str, default='tasks/image_classification/analysis/outputs/imagenet_viz', help="Directory for visualization outputs")
71
+ parser.add_argument('--debug', action=argparse.BooleanOptionalAction, default=True, help='Debug mode: use CIFAR100 instead of ImageNet for debugging.')
72
+ parser.add_argument('--plot_every', type=int, default=10, help="Regenerate the plots every this many batches.")
73
+
74
+ parser.add_argument('--inference_iterations', type=int, default=50, help="Iterations to use during inference.")
75
+ parser.add_argument('--data_indices', type=int, nargs='+', default=[], help="Use specific indices in validation data for demos, otherwise random.")
76
+ parser.add_argument('--N_to_viz', type=int, default=5, help="Number of random samples to visualize when --data_indices is not supplied.")
77
+
78
+ return parser.parse_args()
79
+
80
+
81
+ # --- Main Execution Block ---
82
+ if __name__=='__main__':
83
+
84
+ # --- Setup ---
85
+ args = parse_args()
86
+ if args.device[0] != -1 and torch.cuda.is_available():
87
+ device = f'cuda:{args.device[0]}'
88
+ else:
89
+ device = 'cpu'
90
+ print(f"Using device: {device}")
91
+
92
+ # --- Load Checkpoint & Model ---
93
+ print(f"Loading checkpoint: {args.checkpoint}")
94
+ checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False) # weights_only=False is required: the checkpoint also stores the pickled training args
95
+ model_args = checkpoint['args']
96
+
97
+ # Handle legacy arguments from checkpoint if necessary
98
+ if not hasattr(model_args, 'backbone_type') and hasattr(model_args, 'resnet_type'):
99
+ model_args.backbone_type = f'{model_args.resnet_type}-{getattr(model_args, "resnet_feature_scales", [4])[-1]}'
100
+ if not hasattr(model_args, 'neuron_select_type'):
101
+ model_args.neuron_select_type = 'first-last'
102
+
103
+
104
+ # Instantiate Model based on checkpoint args
105
+ print("Instantiating CTM model...")
106
+ model = ContinuousThoughtMachine(
107
+ iterations=model_args.iterations,
108
+ d_model=model_args.d_model,
109
+ d_input=model_args.d_input,
110
+ heads=model_args.heads,
111
+ n_synch_out=model_args.n_synch_out,
112
+ n_synch_action=model_args.n_synch_action,
113
+ synapse_depth=model_args.synapse_depth,
114
+ memory_length=model_args.memory_length,
115
+ deep_nlms=model_args.deep_memory,
116
+ memory_hidden_dims=model_args.memory_hidden_dims,
117
+ do_layernorm_nlm=model_args.do_normalisation,
118
+ backbone_type=model_args.backbone_type,
119
+ positional_embedding_type=model_args.positional_embedding_type,
120
+ out_dims=model_args.out_dims,
121
+ prediction_reshaper=[-1], # Kept fixed value from original code
122
+ dropout=0, # No dropout for eval
123
+ neuron_select_type=model_args.neuron_select_type,
124
+ n_random_pairing_self=model_args.n_random_pairing_self,
125
+ ).to(device)
126
+
127
+ # Load weights into model
128
+ load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
129
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
130
+ model.eval() # Set model to evaluation mode
131
+
132
+ # --- Prepare Dataset ---
133
+ if args.debug:
134
+ print("Debug mode: Using CIFAR100")
135
+ # CIFAR100 specific normalization constants
136
+ dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
137
+ dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
138
+ img_size = 256 # Resize CIFAR images for consistency
139
+ transform = transforms.Compose([
140
+ transforms.Resize(img_size),
141
+ transforms.ToTensor(),
142
+ transforms.Normalize(mean=dataset_mean, std=dataset_std), # Normalize
143
+ ])
144
+ validation_dataset = datasets.CIFAR100('data/', train=False, transform=transform, download=True)
145
+ validation_dataset_centercrop = datasets.CIFAR100('data/', train=True, transform=transform, download=True)
146
+ else:
147
+ print("Using ImageNet")
148
+ # ImageNet specific normalization constants
149
+ dataset_mean = [0.485, 0.456, 0.406]
150
+ dataset_std = [0.229, 0.224, 0.225]
151
+ img_size = 256 # Resize ImageNet images
152
+ # Note: Original comment mentioned no CenterCrop, this transform reflects that.
153
+ transform = transforms.Compose([
154
+ transforms.Resize(img_size),
155
+ transforms.ToTensor(),
156
+ transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
157
+ ])
158
+ validation_dataset = ImageNet(which_split='validation', transform=transform)
159
+ validation_dataset_centercrop = ImageNet(which_split='train', transform=transforms.Compose([
160
+ transforms.Resize(img_size),
161
+ transforms.RandomCrop(img_size),
162
+ transforms.ToTensor(),
163
+ transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
164
+ ]))
165
+ class_labels = list(IMAGENET2012_CLASSES.values()) # Load actual class names
166
+
167
+ os.makedirs(f'{args.output_dir}', exist_ok=True)
168
+
169
+ interp_mode = 'nearest'
170
+ cmap_calib = sns.color_palette('viridis', as_cmap=True)
171
+ loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False)
172
+ loader_crop = torch.utils.data.DataLoader(validation_dataset_centercrop, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
173
+
174
+ model.eval()
175
+
176
+ figscale = 0.85
177
+ topk = 5
178
+ mean_certainties_correct, mean_certainties_incorrect = [],[]
179
+ tracked_certainties = []
180
+ tracked_targets = []
181
+ tracked_predictions = []
182
+
183
+ if model.iterations != args.inference_iterations:
184
+ print('WARNING: you are setting inference iterations to a value not used during training!')
185
+
186
+ model.iterations = args.inference_iterations
187
+
188
+ if 'plots' in args.actions:
189
+
190
+ with torch.inference_mode(): # Disable gradient calculations
191
+ with tqdm(total=len(loader), initial=0, leave=False, position=0, dynamic_ncols=True) as pbar:
192
+ imgi = 0
193
+ for bi, (inputs, targets) in enumerate(loader):
194
+ inputs = inputs.to(device)
195
+ targets = targets.to(device)
196
+ if bi==0:
197
+ dynamics_inputs, _ = next(iter(loader_crop)) # Use this because of batching
198
+ _, _, _, _, post_activations_viz, _ = model(inputs, track=True)
199
+ plot_neural_dynamics(post_activations_viz, 15*10, args.output_dir, axis_snap=True, N_per_row=15)
200
+ predictions, certainties, synchronisation = model(inputs)
201
+
202
+ tracked_predictions.append(predictions.detach().cpu().numpy())
203
+ tracked_targets.append(targets.detach().cpu().numpy())
204
+ tracked_certainties.append(certainties.detach().cpu().numpy())
205
+
206
+
207
+
208
+
209
+ pbar.set_description(f'Processing base image of size {inputs.shape}')
210
+ pbar.update(1)
211
+ if ((bi % args.plot_every == 0) or bi == len(loader)-1) and bi!=0: #
212
+
213
+ concatenated_certainties = np.concatenate(tracked_certainties, axis=0)
214
+ concatenated_targets = np.concatenate(tracked_targets, axis=0)
215
+ concatenated_predictions = np.concatenate(tracked_predictions, axis=0)
216
+ concatenated_predictions_argsorted = np.argsort(concatenated_predictions, 1)[:,::-1]
217
+
218
+
219
+
220
+ for topk in [1, 5]:
221
+ concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
222
+
223
+ accs_instant, accs_avg, accs_certain = [], [], []
224
+ accs_avg_logits, accs_weighted_logits = [],[]
225
+ with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
226
+ pbarinner.set_description('Acc types')
227
+ for stepi in np.arange(concatenated_predictions.shape[-1]):
228
+ pred_avg = softmax(concatenated_predictions, 1)[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
229
+ pred_instant = concatenated_predictions_argsorted_topk[:,:,stepi]
230
+ pred_certain = concatenated_predictions_argsorted_topk[np.arange(concatenated_predictions.shape[0]),:, concatenated_certainties[:,1,:stepi+1].argmax(1)]
231
+ pred_avg_logits = concatenated_predictions[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
232
+ pred_weighted_logits = (concatenated_predictions[:,:,:stepi+1] * concatenated_certainties[:,1:,:stepi+1]).sum(-1).argsort(1)[:, -topk:]
233
+ pbarinner.update(1)
234
+ accs_instant.append(np.any(pred_instant==concatenated_targets[...,np.newaxis], -1).mean())
235
+ accs_avg.append(np.any(pred_avg==concatenated_targets[...,np.newaxis], -1).mean())
236
+ accs_avg_logits.append(np.any(pred_avg_logits==concatenated_targets[...,np.newaxis], -1).mean())
237
+ accs_weighted_logits.append(np.any(pred_weighted_logits==concatenated_targets[...,np.newaxis], -1).mean())
238
+ accs_certain.append(np.any(pred_certain==concatenated_targets[...,np.newaxis], -1).mean())
239
+ fig = plt.figure(figsize=(10*figscale, 4*figscale))
240
+ ax = fig.add_subplot(111)
241
+ cp = sns.color_palette("bright")
242
+ ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_instant), linestyle='-', color=cp[0], label='Instant')
243
+ # ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg), linestyle='--', color=cp[1], label='Based on average probability up to this step')
244
+ ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_certain), linestyle=':', color=cp[2], label='Most certain')
245
+ ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg_logits), linestyle='-.', color=cp[3], label='Average logits')
246
+ ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_weighted_logits), linestyle='--', color=cp[4], label='Logits weighted by certainty')
247
+ ax.set_xlim([0, concatenated_predictions.shape[-1]+1])
248
+ ax.set_ylim([75, 92])
249
+ ax.set_xlabel('Internal ticks')
250
+ ax.set_ylabel(f'Top-k={topk} accuracy')
251
+ ax.legend(loc='lower right')
252
+ fig.tight_layout(pad=0.1)
253
+ fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.png', dpi=200)
254
+ fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.pdf', dpi=200)
255
+ plt.close(fig)
256
+ print(f'k={topk}. Accuracy most certain at last internal tick={100*np.array(accs_certain)[-1]:0.4f}') # Using certainty based approach
257
+
258
+
259
+ indices_over_80 = []
260
+ classes_80 = {}
261
+ corrects_80 = {}
262
+
263
+ topk = 5
264
+ concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
265
+ for certainty_threshold in [0.5, 0.8, 0.9]:
266
+ # certainty_threshold = 0.6
267
+ percentage_corrects = []
268
+ percentage_incorrects = []
269
+ with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
270
+ pbarinner.set_description(f'Certainty threshold={certainty_threshold}')
271
+ for stepi in np.arange(concatenated_predictions.shape[-1]):
272
+ certainty_here = concatenated_certainties[:,1,stepi]
273
+ certainty_mask = certainty_here>=certainty_threshold
274
+ predictions_here = concatenated_predictions_argsorted_topk[:,:,stepi]
275
+ is_correct_here = np.any(predictions_here==concatenated_targets[...,np.newaxis], axis=-1)
276
+ percentage_corrects.append(is_correct_here[certainty_mask].sum()/predictions_here.shape[0])
277
+ percentage_incorrects.append((~is_correct_here)[certainty_mask].sum()/predictions_here.shape[0])
278
+
279
+ if certainty_threshold==0.8:
280
+ indices_certain = np.where(certainty_mask)[0]
281
+ for index in indices_certain:
282
+ if index not in indices_over_80:
283
+ indices_over_80.append(index)
284
+ if concatenated_targets[index] not in classes_80:
285
+ classes_80[concatenated_targets[index]] = [stepi]
286
+ corrects_80[concatenated_targets[index]] = [is_correct_here[index]]
287
+ else:
288
+ classes_80[concatenated_targets[index]] = classes_80[concatenated_targets[index]]+[stepi]
289
+ corrects_80[concatenated_targets[index]] = corrects_80[concatenated_targets[index]]+[is_correct_here[index]]
290
+
291
+
292
+ pbarinner.update(1)
293
+ fig = plt.figure(figsize=(6.5*figscale, 4*figscale))
294
+ ax = fig.add_subplot(111)
295
+ ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
296
+ percentage_corrects,
297
+ color='forestgreen',
298
+ hatch='OO',
299
+ width=0.9,
300
+ label='Positive',
301
+ alpha=0.9,
302
+ linewidth=1.0*figscale)
303
+
304
+ ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
305
+ percentage_incorrects,
306
+ bottom=percentage_corrects,
307
+ color='crimson',
308
+ hatch='xx',
309
+ width=0.9,
310
+ label='Negative',
311
+ alpha=0.9,
312
+ linewidth=1.0*figscale)
313
+ ax.set_xlim(-1, concatenated_predictions.shape[-1]+1)
314
+ ax.set_xlabel('Internal tick')
315
+ ax.set_ylabel('% of data')
316
+ ax.legend(loc='lower right')
317
+
318
+
319
+ fig.tight_layout(pad=0.1)
320
+ fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.png', dpi=200)
321
+ fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.pdf', dpi=200)
322
+ plt.close(fig)
323
+
324
+
325
+ class_list = list(classes_80.keys())
326
+ mean_steps = [np.mean(classes_80[cls]) for cls in class_list]
327
+ std_steps = [np.std(classes_80[cls]) for cls in class_list]
328
+
329
+
330
+ # Following code plots the class distribution over internal ticks
331
+ indices_to_show = np.arange(1000)
332
+
333
+ colours = cmap_diverse = plt.get_cmap('rainbow')(np.linspace(0, 1, 1000))
334
+ # np.random.shuffle(colours)
335
+ bottom = np.zeros(concatenated_predictions.shape[-1])
336
+
337
+ fig = plt.figure(figsize=(7*figscale, 4*figscale))
338
+ ax = fig.add_subplot(111)
339
+ for iii, idx in enumerate(indices_to_show):
340
+ if idx in classes_80:
341
+ steps = classes_80[idx]
342
+ colour = colours[iii]
343
+ vs, cts = np.unique(steps, return_counts=True)
344
+
345
+ bar = np.zeros(concatenated_predictions.shape[-1])
346
+ bar[vs] = cts
347
+ ax.bar(np.arange(concatenated_predictions.shape[-1])+1, bar, bottom=bottom, color=colour, width=1, edgecolor='none')
348
+ bottom += bar
349
+ ax.set_xlabel('Internal ticks')
350
+ ax.set_ylabel('Counts over 0.8 certainty')
351
+ fig.tight_layout(pad=0.1)
352
+ fig.savefig(f'{args.output_dir}/class_counts.png', dpi=200)
353
+ fig.savefig(f'{args.output_dir}/class_counts.pdf', dpi=200)
354
+ plt.close(fig)
355
+
356
+
357
+
358
+
359
+
360
+ # The following code plots calibration
361
+ probability_space = np.linspace(0, 1, 10)
362
+ fig = plt.figure(figsize=(6*figscale, 4*figscale))
363
+ ax = fig.add_subplot(111)
364
+
365
+
366
+ color_linspace = np.linspace(0, 1, concatenated_predictions.shape[-1])
367
+ with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
368
+ pbarinner.set_description(f'Calibration')
369
+ for stepi in np.arange(concatenated_predictions.shape[-1]):
370
+ color = cmap_calib(color_linspace[stepi])
371
+ pred = concatenated_predictions[:,:,stepi].argmax(1)
372
+ is_correct = pred == concatenated_targets # BxT
373
+ probabilities = softmax(concatenated_predictions[:,:,:stepi+1], axis=1)[np.arange(concatenated_predictions.shape[0]),pred].mean(-1)#softmax(concatenated_predictions[:,:,stepi], axis=1).max(1)
374
+ probability_space = np.linspace(0, 1, 10)
375
+ accuracies_per_bin = []
376
+ bin_centers = []
377
+ for pi in range(len(probability_space)-1):
378
+ bin_low = probability_space[pi]
379
+ bin_high = probability_space[pi+1]
380
+ mask = ((probabilities >=bin_low) & (probabilities < bin_high)) if pi !=len(probability_space)-2 else ((probabilities >=bin_low) & (probabilities <= bin_high))
381
+ accuracies_per_bin.append(is_correct[mask].mean())
382
+ bin_centers.append(probabilities[mask].mean())
383
+
384
+
385
+ if stepi==concatenated_predictions.shape[-1]-1:
386
+ ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color='#4050f7', alpha=1, label='After all ticks')
387
+ else: ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color=color, alpha=0.65)
388
+ pbarinner.update(1)
389
+ ax.plot(probability_space, np.linspace(0, 1, len(probability_space)), 'k--')
390
+
391
+ ax.legend(loc='upper left')
392
+ ax.set_xlim([-0.01, 1.01])
393
+ ax.set_ylim([-0.01, 1.01])
394
+
395
+ sm = plt.cm.ScalarMappable(cmap=cmap_calib, norm=plt.Normalize(vmin=0, vmax=concatenated_predictions.shape[-1] - 1))
396
+ sm.set_array([]) # Empty array for colormap
397
+ cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.02)
398
+ cbar.set_label('Internal ticks')
399
+
400
+ ax.set_xlabel('Mean predicted probabilities')
401
+ ax.set_ylabel('Ratio of positives')
402
+ fig.tight_layout(pad=0.1)
403
+ fig.savefig(f'{args.output_dir}/imagenet_calibration.png', dpi=200)
404
+ fig.savefig(f'{args.output_dir}/imagenet_calibration.pdf', dpi=200)
405
+ plt.close(fig)
406
+ if 'videos' in args.actions:
407
+ if not args.data_indices: # If list is empty
408
+ n_samples = len(validation_dataset)
409
+ num_to_sample = min(args.N_to_viz, n_samples)
410
+ replace = n_samples < num_to_sample
411
+ data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
412
+ print(f"Selected random indices: {data_indices}")
413
+ else:
414
+ data_indices = args.data_indices
415
+ print(f"Using specified indices: {data_indices}")
416
+
417
+
418
+ for di in data_indices:
419
+ print(f'\nBuilding viz for dataset index {di}.')
420
+
421
+ # --- Get Data & Run Inference ---
422
+ # inputs is already normalized by the dataset transform
423
+ inputs, ground_truth_target = validation_dataset.__getitem__(int(di))
424
+
425
+ # Add batch dimension and send to device
426
+ inputs = inputs.to(device).unsqueeze(0)
427
+
428
+ # Run model inference
429
+ predictions, certainties, synchronisation, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
430
+ # predictions: (B, Classes, Steps); attention_tracking is indexed below as (Steps, B, Heads, SeqLen)
431
+ n_steps = predictions.size(-1)
432
+
433
+ # --- Reshape Attention ---
434
+ # Infer feature map size from model internals (assuming B=1)
435
+ h_feat, w_feat = model.kv_features.shape[-2:]
436
+
437
+ n_heads = attention_tracking.shape[2]
438
+ # Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
439
+ attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
440
+
441
+ # --- Setup for Plotting ---
442
+ step_linspace = np.linspace(0, 1, n_steps) # For step colors
443
+ # Define color maps
444
+ cmap_spectral = sns.color_palette("Spectral", as_cmap=True)
445
+ cmap_attention = sns.color_palette('viridis', as_cmap=True)
446
+
447
+ # Create output directory for this index
448
+ index_output_dir = os.path.join(args.output_dir, str(di))
449
+ os.makedirs(index_output_dir, exist_ok=True)
450
+
451
+ frames = [] # Store frames for GIF
452
+ head_routes = {h: [] for h in range(n_heads)} # Store (y,x) path points per head
453
+ head_routes[-1] = []
454
+ route_colours_step = [] # Store colors for each step's path segments
455
+
456
+ # --- Loop Through Each Step ---
457
+ for step_i in range(n_steps):
458
+
459
+ # --- Prepare Image for Display ---
460
+ # Denormalize the input tensor for visualization
461
+ data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
462
+ mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
463
+ std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
464
+ data_img_denorm = data_img_tensor * std_tensor + mean_tensor
465
+ # Permute to (H, W, C) and convert to numpy, clip to [0, 1]
466
+ data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
467
+ data_img_np = np.clip(data_img_np, 0, 1)
468
+ img_h, img_w = data_img_np.shape[:2]
469
+
470
+ # --- Process Attention & Certainty ---
471
+ # Average attention over last few steps (from original code)
472
+ start_step = max(0, step_i - 5)
473
+ attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
474
+ # Get certainties up to current step
475
+ certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
476
+
477
+ # --- Calculate Attention Paths (interpolated with interp_mode, i.e. 'nearest') ---
478
+ # Interpolate attention to image size for center finding (the variable name says bilinear, but interp_mode is 'nearest')
479
+ attention_interp_bilinear = F.interpolate(
480
+ torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
481
+ size=(img_h, img_w),
482
+ mode=interp_mode,
483
+ # align_corners=False
484
+ ).squeeze(0) # Remove batch dim -> (Heads, H, W)
485
+
486
+ # Normalize each head's map to [0, 1]
487
+ # Deal with mean
488
+ attn_mean = attention_interp_bilinear.mean(0)
489
+ attn_mean_min = attn_mean.min()
490
+ attn_mean_max = attn_mean.max()
491
+ attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min + 1e-6) # epsilon guards against a constant map, matching the per-head normalization
492
+ centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
493
+
494
+ if centers: # If islands found
495
+ largest_island_idx = np.argmax(areas)
496
+ current_center = centers[largest_island_idx] # (y, x)
497
+ head_routes[-1].append(current_center)
498
+ elif head_routes[-1]: # If no center now, repeat last known center if history exists
499
+ head_routes[-1].append(head_routes[-1][-1])
500
+
501
+
502
+ attn_min = attention_interp_bilinear.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
503
+ attn_max = attention_interp_bilinear.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
504
+ attention_interp_bilinear = (attention_interp_bilinear - attn_min) / (attn_max - attn_min + 1e-6)
505
+
506
+ # Store step color
507
+ current_colour = list(cmap_spectral(step_linspace[step_i]))
508
+ route_colours_step.append(current_colour)
509
+
510
+ # Find island center for each head
511
+ for head_i in range(n_heads):
512
+ attn_head_np = attention_interp_bilinear[head_i].detach().cpu().numpy()
513
+ # Keep threshold=0.7 based on original call
514
+ centers, areas = find_island_centers(attn_head_np, threshold=0.7)
515
+
516
+ if centers: # If islands found
517
+ largest_island_idx = np.argmax(areas)
518
+ current_center = centers[largest_island_idx] # (y, x)
519
+ head_routes[head_i].append(current_center)
520
+ elif head_routes[head_i]: # If no center now, repeat last known center if history exists
521
+ head_routes[head_i].append(head_routes[head_i][-1])
522
+
523
+
524
+
525
+ # --- Plotting Setup ---
526
+ mosaic = [['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
527
+ ['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
528
+ ['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
529
+ ['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
530
+ ['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay', 'head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],
531
+ ['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay','head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],
532
+ ['head_8', 'head_8_overlay', 'head_9', 'head_9_overlay','head_10', 'head_10_overlay', 'head_11', 'head_11_overlay'],
533
+ ['head_12', 'head_12_overlay', 'head_13', 'head_13_overlay','head_14', 'head_14_overlay', 'head_15', 'head_15_overlay'],
534
+ ['probabilities', 'probabilities','probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty'],
535
+ ]
536
+
537
+ img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
538
+ aspect_ratio = (8 * figscale, 9 * figscale * img_aspect) # W, H
539
+ fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
540
+
541
+ for ax in axes.values():
542
+ ax.axis('off')
543
+
544
+ # --- Plot Certainty ---
545
+ ax_cert = axes['certainty']
546
+ ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
547
+ # Add background color based on prediction correctness at each step
548
+ for ii in range(len(certainties_now)):
549
+ is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
550
+ facecolor = 'limegreen' if is_correct else 'orchid'
551
+ ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
552
+ # Mark the last point
553
+ ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
554
+ ax_cert.axis('off')
555
+ ax_cert.set_ylim([0.05, 1.05])
556
+ ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
557
+
558
+ # --- Plot Probabilities ---
559
+ ax_prob = axes['probabilities']
560
+ # Get probabilities for the current step
561
+ ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
562
+ k = 15 # Top k predictions
563
+ topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
564
+ topk_indices = topk_indices.numpy()
565
+ topk_probs = topk_probs.numpy()
566
+
567
+ top_classes = np.array(class_labels)[topk_indices]
568
+ true_class_idx = ground_truth_target # Ground truth index
569
+
570
+ # Determine bar colors (green if correct, blue otherwise - consistent with original)
571
+ colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
572
+
573
+ # Plot horizontal bars (inverted range for top-down display)
574
+ ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
575
+ ax_prob.set_xlim([0, 1])
576
+ ax_prob.axis('off')
577
+
578
+ # Add text labels for top classes
579
+ for i, name_idx in enumerate(topk_indices):
580
+ name = class_labels[name_idx] # Get name from index
581
+ is_correct = name_idx == true_class_idx
582
+ fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
583
+ text_str = f'{name[:40]}' # Truncate long names
584
+ # Position text on the left side of the horizontal bars
585
+ ax_prob.text(
586
+ 0.01, # Small offset from left edge
587
+ k - 1 - i, # Y-position corresponding to the bar
588
+ text_str,
589
+ #transform=ax_prob.transAxes, # Use data coordinates for Y
590
+ verticalalignment='center',
591
+ horizontalalignment='left',
592
+ fontsize=8,
593
+ color=fg_color,
594
+ alpha=0.9, # Slightly more visible than 0.5
595
+ path_effects=[
596
+ patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
597
+ patheffects.Normal()
598
+ ])
599
+
600
+
601
+ # --- Plot Attention Heads & Overlays (using nearest interp) ---
602
+ # Re-interpolate attention using nearest neighbor for visual plotting
603
+ attention_interp_plot = F.interpolate(
604
+ torch.from_numpy(attention_now).unsqueeze(0).float(),
605
+ size=(img_h, img_w),
606
+ mode=interp_mode, # 'nearest'
607
+ ).squeeze(0)
608
+
609
+ attn_mean = attention_interp_plot.mean(0)
610
+ attn_mean_min = attn_mean.min()
611
+ attn_mean_max = attn_mean.max()
612
+ attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min + 1e-6) # epsilon guards against a constant map, matching the per-head normalization
613
+
614
+
615
+ # Normalize each head's map to [0, 1]
616
+ attn_min_plot = attention_interp_plot.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
617
+ attn_max_plot = attention_interp_plot.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
618
+ attention_interp_plot = (attention_interp_plot - attn_min_plot) / (attn_max_plot - attn_min_plot + 1e-6)
619
+ attention_interp_plot_np = attention_interp_plot.detach().cpu().numpy()
620
+
621
+
622
+
623
+
624
+
625
+
626
+ for head_i in list(range(n_heads)) + [-1]:
627
+ axname = f'head_{head_i}' if head_i != -1 else 'head_mean'
628
+ if axname not in axes: continue # Skip if mosaic doesn't have this head
629
+
630
+ ax = axes[axname]
631
+ ax_overlay = axes[f'{axname}_overlay']
632
+
633
+ # Plot attention heatmap
634
+ this_attn = attention_interp_plot_np[head_i] if head_i != -1 else attn_mean
635
+ img_to_plot = cmap_attention(this_attn)
636
+ ax.imshow(img_to_plot)
637
+ ax.axis('off')
638
+
639
+ # Plot overlay: image + paths
640
+ these_route_steps = head_routes[head_i]
641
+ arrow_scale = 1.5 if head_i != -1 else 3
642
+
643
+ if these_route_steps: # Only plot if path exists
644
+ # Separate y and x coordinates
645
+ y_coords, x_coords = zip(*these_route_steps)
646
+ y_coords = np.array(y_coords)
647
+ x_coords = np.array(x_coords)
648
+
649
+ # Flip y-coordinates for correct plotting (imshow origin is top-left)
650
+ # NOTE: Original flip seemed complex, simplifying to standard flip
651
+ y_coords_flipped = img_h - 1 - y_coords
652
+
653
+ # Show original image flipped vertically to match coordinate system
654
+ ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
655
+
656
+ # Draw arrows for path segments
657
+ # Arrow size scaling from original
658
+ for i in range(len(these_route_steps) - 1):
659
+ dx = x_coords[i+1] - x_coords[i]
660
+ dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
661
+
662
+ # Draw white background arrow (thicker)
663
+ ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
664
+ linewidth=1.6 * arrow_scale * 1.3,
665
+ head_width=1.9 * arrow_scale * 1.3,
666
+ head_length=1.4 * arrow_scale * 1.45,
667
+ fc='white', ec='white', length_includes_head=True, alpha=1)
668
+ # Draw colored foreground arrow
669
+ ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
670
+ linewidth=1.6 * arrow_scale,
671
+ head_width=1.9 * arrow_scale,
672
+ head_length=1.4 * arrow_scale,
673
+ fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
674
+ length_includes_head=True)
675
+
676
+ else: # If no path yet, just show the image
677
+ ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
678
+
679
+
680
+ # Set limits and turn off axes for overlay
681
+ ax_overlay.set_xlim([0, img_w - 1])
682
+ ax_overlay.set_ylim([0, img_h - 1])
683
+ ax_overlay.axis('off')
684
+
685
+
686
+ # --- Finalize and Save Frame ---
687
+ fig.tight_layout(pad=0.1) # Adjust spacing
688
+
689
+ # Render the plot to a numpy array
690
+ canvas = fig.canvas
691
+ canvas.draw()
692
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
693
+ image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
694
+
695
+ frames.append(image_numpy) # Add to list for GIF
696
+
697
+
698
+
699
+ plt.close(fig) # Close figure to free memory
700
+
701
+ # --- Save GIF ---
702
+ gif_path = os.path.join(index_output_dir, f'{str(di)}_viz.gif')
703
+ print(f"Saving GIF to {gif_path}...")
704
+ imageio.mimsave(gif_path, frames, fps=15, loop=0) # loop=0 means infinite loop
705
+ save_frames_to_mp4([fm[:,:,::-1] for fm in frames], os.path.join(index_output_dir, f'{str(di)}_viz.mp4'), fps=15, gop_size=1, preset='veryslow')
706
+ if 'demo' in args.actions:
707
+
708
+
709
+
710
+ # --- Select Data Indices ---
711
+ if not args.data_indices: # If list is empty
712
+ n_samples = len(validation_dataset)
713
+ num_to_sample = min(args.N_to_viz, n_samples)
714
+ replace = n_samples < num_to_sample
715
+ data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
716
+ print(f"Selected random indices: {data_indices}")
717
+ else:
718
+ data_indices = args.data_indices
719
+ print(f"Using specified indices: {data_indices}")
720
+
721
+
722
+ for di in data_indices:
723
+
724
+ index_output_dir = os.path.join(args.output_dir, str(di))
725
+ os.makedirs(index_output_dir, exist_ok=True)
726
+
727
+ print(f'\nBuilding viz for dataset index {di}.')
728
+
729
+ inputs, ground_truth_target = validation_dataset[int(di)]
730
+
731
+ # Add batch dimension and send to device
732
+ inputs = inputs.to(device).unsqueeze(0)
733
+ predictions, certainties, synchronisations_over_time, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
734
+
735
+ # --- Reshape Attention ---
736
+ # Infer feature map size from model internals (assuming B=1)
737
+ h_feat, w_feat = model.kv_features.shape[-2:]
738
+ n_steps = predictions.size(-1)
739
+ n_heads = attention_tracking.shape[2]
740
+ # Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
741
+ attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
742
+
743
+ # --- Setup for Plotting ---
744
+ step_linspace = np.linspace(0, 1, n_steps) # For step colors
745
+ # Define color maps
746
+ cmap_steps = sns.color_palette("Spectral", as_cmap=True)
747
+ cmap_attention = sns.color_palette('viridis', as_cmap=True)
748
+
749
+ # (Output directory for this index was already created above)
750
+
751
+
752
+ frames = [] # Store frames for GIF
753
+ head_routes = [] # Store (y,x) path points per head
754
+ route_colours_step = [] # Store colors for each step's path segments
755
+
756
+ # --- Loop Through Each Step ---
757
+ for step_i in range(n_steps):
758
+
759
+ # Store step color
760
+ current_colour = list(cmap_steps(step_linspace[step_i]))
761
+ route_colours_step.append(current_colour)
762
+
763
+ # --- Prepare Image for Display ---
764
+ # Denormalize the input tensor for visualization
765
+ data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
766
+ mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
767
+ std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
768
+ data_img_denorm = data_img_tensor * std_tensor + mean_tensor
769
+ # Permute to (H, W, C) and convert to numpy, clip to [0, 1]
770
+ data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
771
+ data_img_np = np.clip(data_img_np, 0, 1)
772
+ img_h, img_w = data_img_np.shape[:2]
773
+
774
+ # --- Process Attention & Certainty ---
775
+ # Average attention over last few steps (from original code)
776
+ start_step = max(0, step_i - 5)
777
+ attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
778
+ # Get certainties up to current step
779
+ certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
780
+
781
+ # --- Calculate Attention Paths ---
782
+ # Interpolate attention to image size (mode set by interp_mode) for center finding
783
+ attention_interp_bilinear = F.interpolate(
784
+ torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
785
+ size=(img_h, img_w),
786
+ mode=interp_mode,
787
+ ).squeeze(0) # Remove batch dim -> (Heads, H, W)
788
+
789
+ attn_mean = attention_interp_bilinear.mean(0)
790
+ attn_mean_min = attn_mean.min()
791
+ attn_mean_max = attn_mean.max()
792
+ attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
793
+ centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
794
+
795
+ if centers: # If islands found
796
+ largest_island_idx = np.argmax(areas)
797
+ current_center = centers[largest_island_idx] # (y, x)
798
+ head_routes.append(current_center)
799
+ elif head_routes: # If no center now, repeat last known center if history exists
800
+ head_routes.append(head_routes[-1])
801
+
802
+ # --- Plotting Setup ---
803
+ # NOTE: The mosaic below has slots for 16 heads (head_0..head_15); any additional heads are skipped when plotting.
804
+ mosaic = [['head_0', 'head_1', 'head_2', 'head_3', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
805
+ ['head_4', 'head_5', 'head_6', 'head_7', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
806
+ ['head_8', 'head_9', 'head_10', 'head_11', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
807
+ ['head_12', 'head_13', 'head_14', 'head_15', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
808
+ ['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
809
+ ['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
810
+ ]
811
+
812
+ img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
813
+ aspect_ratio = (12 * figscale, 6 * figscale * img_aspect) # W, H
814
+ fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
815
+ for ax in axes.values():
816
+ ax.axis('off')
817
+
818
+ # --- Plot Certainty ---
819
+ ax_cert = axes['certainty']
820
+ ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
821
+ # Add background color based on prediction correctness at each step
822
+ for ii in range(len(certainties_now)):
823
+ is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
824
+ facecolor = 'limegreen' if is_correct else 'orchid'
825
+ ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
826
+ # Mark the last point
827
+ ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
828
+ ax_cert.axis('off')
829
+ ax_cert.set_ylim([0.05, 1.05])
830
+ ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
831
+
832
+ # --- Plot Probabilities ---
833
+ ax_prob = axes['probabilities']
834
+ # Get probabilities for the current step
835
+ ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
836
+ k = 15 # Top k predictions
837
+ topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
838
+ topk_indices = topk_indices.numpy()
839
+ topk_probs = topk_probs.numpy()
840
+
841
+ top_classes = np.array(class_labels)[topk_indices]
842
+ true_class_idx = ground_truth_target # Ground truth index
843
+
844
+ # Determine bar colors (green if correct, blue otherwise - consistent with original)
845
+ colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
846
+
847
+ # Plot horizontal bars (inverted range for top-down display)
848
+ ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
849
+ ax_prob.set_xlim([0, 1])
850
+ ax_prob.axis('off')
851
+
852
+ # Add text labels for top classes
853
+ for i, name_idx in enumerate(topk_indices):
854
+ name = class_labels[name_idx] # Get name from index
855
+ is_correct = name_idx == true_class_idx
856
+ fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
857
+ text_str = f'{name[:40]}' # Truncate long names
858
+ # Position text on the left side of the horizontal bars
859
+ ax_prob.text(
860
+ 0.01, # Small offset from left edge
861
+ k - 1 - i, # Y-position corresponding to the bar
862
+ text_str,
863
+ #transform=ax_prob.transAxes, # Use data coordinates for Y
864
+ verticalalignment='center',
865
+ horizontalalignment='left',
866
+ fontsize=8,
867
+ color=fg_color,
868
+ alpha=0.7, # Slightly more visible than 0.5
869
+ path_effects=[
870
+ patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
871
+ patheffects.Normal()
872
+ ])
873
+
874
+
875
+ # --- Plot Attention Heads & Overlays ---
876
+ # Re-interpolate attention for visual plotting (mode set by interp_mode)
877
+ attention_interp_plot = F.interpolate(
878
+ torch.from_numpy(attention_now).unsqueeze(0).float(),
879
+ size=(img_h, img_w),
880
+ mode=interp_mode
881
+ ).squeeze(0)
882
+
883
+
884
+ attn_mean = attention_interp_plot.mean(0)
885
+ attn_mean_min = attn_mean.min()
886
+ attn_mean_max = attn_mean.max()
887
+ attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
888
+
889
+
890
+ img_to_plot = cmap_attention(attn_mean)
891
+ axes['head_mean'].imshow(img_to_plot)
892
+ axes['head_mean'].axis('off')
893
+
894
+
895
+ these_route_steps = head_routes
896
+ ax_overlay = axes['overlay']
897
+
898
+ if these_route_steps: # Only plot if path exists
899
+ # Separate y and x coordinates
900
+ y_coords, x_coords = zip(*these_route_steps)
901
+ y_coords = np.array(y_coords)
902
+ x_coords = np.array(x_coords)
903
+
904
+ # Flip y-coordinates for correct plotting (imshow origin is top-left)
905
+ # NOTE: Original flip seemed complex, simplifying to standard flip
906
+ y_coords_flipped = img_h - 1 - y_coords
907
+
908
+ # Show original image flipped vertically to match coordinate system
909
+ ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
910
+
911
+ # Draw arrows for path segments
912
+ arrow_scale = 2 # Arrow size scaling from original
913
+ for i in range(len(these_route_steps) - 1):
914
+ dx = x_coords[i+1] - x_coords[i]
915
+ dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
916
+
917
+ # Draw white background arrow (thicker)
918
+ ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
919
+ linewidth=1.6 * arrow_scale * 1.3,
920
+ head_width=1.9 * arrow_scale * 1.3,
921
+ head_length=1.4 * arrow_scale * 1.45,
922
+ fc='white', ec='white', length_includes_head=True, alpha=1)
923
+ # Draw colored foreground arrow
924
+ ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
925
+ linewidth=1.6 * arrow_scale,
926
+ head_width=1.9 * arrow_scale,
927
+ head_length=1.4 * arrow_scale,
928
+ fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
929
+ length_includes_head=True)
930
+ # Set limits and turn off axes for overlay
931
+ ax_overlay.set_xlim([0, img_w - 1])
932
+ ax_overlay.set_ylim([0, img_h - 1])
933
+ ax_overlay.axis('off')
934
+
935
+
936
+ for head_i in range(n_heads):
937
+ if f'head_{head_i}' not in axes: continue # Skip if mosaic doesn't have this head
938
+
939
+ ax = axes[f'head_{head_i}']
940
+
941
+ # Plot attention heatmap
942
+ attn_up_to_now = attention_tracking[:step_i + 1, head_i].mean(0)
943
+ attn_up_to_now = (attn_up_to_now - attn_up_to_now.min())/(attn_up_to_now.max() - attn_up_to_now.min())
944
+ img_to_plot = cmap_attention(attn_up_to_now)
945
+ ax.imshow(img_to_plot)
946
+ ax.axis('off')
947
+
948
+
949
+
950
+
951
+
952
+
953
+ # --- Finalize and Save Frame ---
954
+ fig.tight_layout(pad=0.1) # Adjust spacing
955
+
956
+ # Render the plot to a numpy array
957
+ canvas = fig.canvas
958
+ canvas.draw()
959
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
960
+ image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
961
+
962
+ frames.append(image_numpy) # Add to list for GIF
963
+
964
+ # Save the final step's frame as a standalone PNG
965
+ if step_i==model.iterations-1:
966
+ fig.savefig(os.path.join(index_output_dir, f'frame_{step_i}.png'), dpi=200)
967
+
968
+ plt.close(fig) # Close figure to free memory
969
+ outfilename = os.path.join(index_output_dir, f'{di}_demo.mp4')
970
+ save_frames_to_mp4([fm[:,:,::-1] for fm in frames], outfilename, fps=15, gop_size=1, preset='veryslow')
971
+
972
+
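The per-step loops in `run_imagenet_analysis.py` above min-max normalize attention maps by dividing by `(max - min)`, which produces NaNs whenever a map is constant. The following is a minimal sketch of a guarded variant, assuming NumPy input. The names `minmax_normalize` and `eps` are hypothetical and are not part of the repository.

```python
import numpy as np


def minmax_normalize(attn, eps=1e-8):
    """Scale an array to [0, 1]; return all zeros when the map is constant.

    Hypothetical helper for illustration only (not from the repository).
    """
    attn = np.asarray(attn, dtype=np.float64)
    lo, hi = attn.min(), attn.max()
    span = hi - lo
    if span < eps:
        # A constant map carries no spatial information, so avoid 0/0.
        return np.zeros_like(attn)
    return (attn - lo) / span
```

With a helper like this, a uniform attention map yields an all-zero heatmap instead of NaNs, so downstream thresholding and colormapping receive finite values.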
tasks/image_classification/imagenet_classes.py ADDED
@@ -0,0 +1,1007 @@
1
+ from collections import OrderedDict
2
+
3
+
4
+ IMAGENET2012_CLASSES = OrderedDict(
5
+ {
6
+ "n01440764": "tench, Tinca tinca",
7
+ "n01443537": "goldfish, Carassius auratus",
8
+ "n01484850": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
9
+ "n01491361": "tiger shark, Galeocerdo cuvieri",
10
+ "n01494475": "hammerhead, hammerhead shark",
11
+ "n01496331": "electric ray, crampfish, numbfish, torpedo",
12
+ "n01498041": "stingray",
13
+ "n01514668": "cock",
14
+ "n01514859": "hen",
15
+ "n01518878": "ostrich, Struthio camelus",
16
+ "n01530575": "brambling, Fringilla montifringilla",
17
+ "n01531178": "goldfinch, Carduelis carduelis",
18
+ "n01532829": "house finch, linnet, Carpodacus mexicanus",
19
+ "n01534433": "junco, snowbird",
20
+ "n01537544": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
21
+ "n01558993": "robin, American robin, Turdus migratorius",
22
+ "n01560419": "bulbul",
23
+ "n01580077": "jay",
24
+ "n01582220": "magpie",
25
+ "n01592084": "chickadee",
26
+ "n01601694": "water ouzel, dipper",
27
+ "n01608432": "kite",
28
+ "n01614925": "bald eagle, American eagle, Haliaeetus leucocephalus",
29
+ "n01616318": "vulture",
30
+ "n01622779": "great grey owl, great gray owl, Strix nebulosa",
31
+ "n01629819": "European fire salamander, Salamandra salamandra",
32
+ "n01630670": "common newt, Triturus vulgaris",
33
+ "n01631663": "eft",
34
+ "n01632458": "spotted salamander, Ambystoma maculatum",
35
+ "n01632777": "axolotl, mud puppy, Ambystoma mexicanum",
36
+ "n01641577": "bullfrog, Rana catesbeiana",
37
+ "n01644373": "tree frog, tree-frog",
38
+ "n01644900": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
39
+ "n01664065": "loggerhead, loggerhead turtle, Caretta caretta",
40
+ "n01665541": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
41
+ "n01667114": "mud turtle",
42
+ "n01667778": "terrapin",
43
+ "n01669191": "box turtle, box tortoise",
44
+ "n01675722": "banded gecko",
45
+ "n01677366": "common iguana, iguana, Iguana iguana",
46
+ "n01682714": "American chameleon, anole, Anolis carolinensis",
47
+ "n01685808": "whiptail, whiptail lizard",
48
+ "n01687978": "agama",
49
+ "n01688243": "frilled lizard, Chlamydosaurus kingi",
50
+ "n01689811": "alligator lizard",
51
+ "n01692333": "Gila monster, Heloderma suspectum",
52
+ "n01693334": "green lizard, Lacerta viridis",
53
+ "n01694178": "African chameleon, Chamaeleo chamaeleon",
54
+ "n01695060": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
55
+ "n01697457": "African crocodile, Nile crocodile, Crocodylus niloticus",
56
+ "n01698640": "American alligator, Alligator mississipiensis",
57
+ "n01704323": "triceratops",
58
+ "n01728572": "thunder snake, worm snake, Carphophis amoenus",
59
+ "n01728920": "ringneck snake, ring-necked snake, ring snake",
60
+ "n01729322": "hognose snake, puff adder, sand viper",
61
+ "n01729977": "green snake, grass snake",
62
+ "n01734418": "king snake, kingsnake",
63
+ "n01735189": "garter snake, grass snake",
64
+ "n01737021": "water snake",
65
+ "n01739381": "vine snake",
66
+ "n01740131": "night snake, Hypsiglena torquata",
67
+ "n01742172": "boa constrictor, Constrictor constrictor",
68
+ "n01744401": "rock python, rock snake, Python sebae",
69
+ "n01748264": "Indian cobra, Naja naja",
70
+ "n01749939": "green mamba",
71
+ "n01751748": "sea snake",
72
+ "n01753488": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
73
+ "n01755581": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
74
+ "n01756291": "sidewinder, horned rattlesnake, Crotalus cerastes",
75
+ "n01768244": "trilobite",
76
+ "n01770081": "harvestman, daddy longlegs, Phalangium opilio",
77
+ "n01770393": "scorpion",
78
+ "n01773157": "black and gold garden spider, Argiope aurantia",
79
+ "n01773549": "barn spider, Araneus cavaticus",
80
+ "n01773797": "garden spider, Aranea diademata",
81
+ "n01774384": "black widow, Latrodectus mactans",
82
+ "n01774750": "tarantula",
83
+ "n01775062": "wolf spider, hunting spider",
84
+ "n01776313": "tick",
85
+ "n01784675": "centipede",
86
+ "n01795545": "black grouse",
87
+ "n01796340": "ptarmigan",
88
+ "n01797886": "ruffed grouse, partridge, Bonasa umbellus",
89
+ "n01798484": "prairie chicken, prairie grouse, prairie fowl",
90
+ "n01806143": "peacock",
91
+ "n01806567": "quail",
92
+ "n01807496": "partridge",
93
+ "n01817953": "African grey, African gray, Psittacus erithacus",
94
+ "n01818515": "macaw",
95
+ "n01819313": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
96
+ "n01820546": "lorikeet",
97
+ "n01824575": "coucal",
98
+ "n01828970": "bee eater",
99
+ "n01829413": "hornbill",
100
+ "n01833805": "hummingbird",
101
+ "n01843065": "jacamar",
102
+ "n01843383": "toucan",
103
+ "n01847000": "drake",
104
+ "n01855032": "red-breasted merganser, Mergus serrator",
105
+ "n01855672": "goose",
106
+ "n01860187": "black swan, Cygnus atratus",
107
+ "n01871265": "tusker",
108
+ "n01872401": "echidna, spiny anteater, anteater",
109
+ "n01873310": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
110
+ "n01877812": "wallaby, brush kangaroo",
111
+ "n01882714": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
112
+ "n01883070": "wombat",
113
+ "n01910747": "jellyfish",
114
+ "n01914609": "sea anemone, anemone",
115
+ "n01917289": "brain coral",
116
+ "n01924916": "flatworm, platyhelminth",
117
+ "n01930112": "nematode, nematode worm, roundworm",
118
+ "n01943899": "conch",
119
+ "n01944390": "snail",
120
+ "n01945685": "slug",
121
+ "n01950731": "sea slug, nudibranch",
122
+ "n01955084": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
123
+ "n01968897": "chambered nautilus, pearly nautilus, nautilus",
124
+ "n01978287": "Dungeness crab, Cancer magister",
125
+ "n01978455": "rock crab, Cancer irroratus",
126
+ "n01980166": "fiddler crab",
127
+ "n01981276": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
128
+ "n01983481": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
129
+ "n01984695": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
130
+ "n01985128": "crayfish, crawfish, crawdad, crawdaddy",
131
+ "n01986214": "hermit crab",
132
+ "n01990800": "isopod",
133
+ "n02002556": "white stork, Ciconia ciconia",
134
+ "n02002724": "black stork, Ciconia nigra",
135
+ "n02006656": "spoonbill",
136
+ "n02007558": "flamingo",
137
+ "n02009229": "little blue heron, Egretta caerulea",
138
+ "n02009912": "American egret, great white heron, Egretta albus",
139
+ "n02011460": "bittern",
140
+ "n02012849": "crane",
141
+ "n02013706": "limpkin, Aramus pictus",
142
+ "n02017213": "European gallinule, Porphyrio porphyrio",
143
+ "n02018207": "American coot, marsh hen, mud hen, water hen, Fulica americana",
144
+ "n02018795": "bustard",
145
+ "n02025239": "ruddy turnstone, Arenaria interpres",
146
+ "n02027492": "red-backed sandpiper, dunlin, Erolia alpina",
147
+ "n02028035": "redshank, Tringa totanus",
148
+ "n02033041": "dowitcher",
149
+ "n02037110": "oystercatcher, oyster catcher",
150
+ "n02051845": "pelican",
151
+ "n02056570": "king penguin, Aptenodytes patagonica",
152
+ "n02058221": "albatross, mollymawk",
153
+ "n02066245": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
154
+ "n02071294": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
155
+ "n02074367": "dugong, Dugong dugon",
156
+ "n02077923": "sea lion",
157
+ "n02085620": "Chihuahua",
158
+ "n02085782": "Japanese spaniel",
159
+ "n02085936": "Maltese dog, Maltese terrier, Maltese",
160
+ "n02086079": "Pekinese, Pekingese, Peke",
161
+ "n02086240": "Shih-Tzu",
162
+ "n02086646": "Blenheim spaniel",
163
+ "n02086910": "papillon",
164
+ "n02087046": "toy terrier",
165
+ "n02087394": "Rhodesian ridgeback",
166
+ "n02088094": "Afghan hound, Afghan",
167
+ "n02088238": "basset, basset hound",
168
+ "n02088364": "beagle",
169
+ "n02088466": "bloodhound, sleuthhound",
170
+ "n02088632": "bluetick",
171
+ "n02089078": "black-and-tan coonhound",
172
+ "n02089867": "Walker hound, Walker foxhound",
173
+ "n02089973": "English foxhound",
174
+ "n02090379": "redbone",
175
+ "n02090622": "borzoi, Russian wolfhound",
176
+ "n02090721": "Irish wolfhound",
177
+ "n02091032": "Italian greyhound",
178
+ "n02091134": "whippet",
179
+ "n02091244": "Ibizan hound, Ibizan Podenco",
180
+ "n02091467": "Norwegian elkhound, elkhound",
181
+ "n02091635": "otterhound, otter hound",
182
+ "n02091831": "Saluki, gazelle hound",
183
+ "n02092002": "Scottish deerhound, deerhound",
184
+ "n02092339": "Weimaraner",
185
+ "n02093256": "Staffordshire bullterrier, Staffordshire bull terrier",
186
+ "n02093428": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
187
+ "n02093647": "Bedlington terrier",
188
+ "n02093754": "Border terrier",
189
+ "n02093859": "Kerry blue terrier",
190
+ "n02093991": "Irish terrier",
191
+ "n02094114": "Norfolk terrier",
192
+ "n02094258": "Norwich terrier",
193
+ "n02094433": "Yorkshire terrier",
194
+ "n02095314": "wire-haired fox terrier",
195
+ "n02095570": "Lakeland terrier",
196
+ "n02095889": "Sealyham terrier, Sealyham",
197
+ "n02096051": "Airedale, Airedale terrier",
198
+ "n02096177": "cairn, cairn terrier",
199
+ "n02096294": "Australian terrier",
200
+ "n02096437": "Dandie Dinmont, Dandie Dinmont terrier",
201
+ "n02096585": "Boston bull, Boston terrier",
202
+ "n02097047": "miniature schnauzer",
203
+ "n02097130": "giant schnauzer",
204
+ "n02097209": "standard schnauzer",
205
+ "n02097298": "Scotch terrier, Scottish terrier, Scottie",
206
+ "n02097474": "Tibetan terrier, chrysanthemum dog",
207
+ "n02097658": "silky terrier, Sydney silky",
208
+ "n02098105": "soft-coated wheaten terrier",
209
+ "n02098286": "West Highland white terrier",
210
+ "n02098413": "Lhasa, Lhasa apso",
211
+ "n02099267": "flat-coated retriever",
212
+ "n02099429": "curly-coated retriever",
213
+ "n02099601": "golden retriever",
214
+ "n02099712": "Labrador retriever",
215
+ "n02099849": "Chesapeake Bay retriever",
216
+ "n02100236": "German short-haired pointer",
217
+ "n02100583": "vizsla, Hungarian pointer",
218
+ "n02100735": "English setter",
219
+ "n02100877": "Irish setter, red setter",
220
+ "n02101006": "Gordon setter",
221
+ "n02101388": "Brittany spaniel",
222
+ "n02101556": "clumber, clumber spaniel",
223
+ "n02102040": "English springer, English springer spaniel",
224
+ "n02102177": "Welsh springer spaniel",
225
+ "n02102318": "cocker spaniel, English cocker spaniel, cocker",
226
+ "n02102480": "Sussex spaniel",
227
+ "n02102973": "Irish water spaniel",
228
+ "n02104029": "kuvasz",
229
+ "n02104365": "schipperke",
230
+ "n02105056": "groenendael",
231
+ "n02105162": "malinois",
232
+ "n02105251": "briard",
233
+ "n02105412": "kelpie",
234
+ "n02105505": "komondor",
235
+ "n02105641": "Old English sheepdog, bobtail",
236
+ "n02105855": "Shetland sheepdog, Shetland sheep dog, Shetland",
237
+ "n02106030": "collie",
238
+ "n02106166": "Border collie",
239
+ "n02106382": "Bouvier des Flandres, Bouviers des Flandres",
240
+ "n02106550": "Rottweiler",
241
+ "n02106662": "German shepherd, German shepherd dog, German police dog, alsatian",
242
+ "n02107142": "Doberman, Doberman pinscher",
243
+ "n02107312": "miniature pinscher",
244
+ "n02107574": "Greater Swiss Mountain dog",
245
+ "n02107683": "Bernese mountain dog",
246
+ "n02107908": "Appenzeller",
247
+ "n02108000": "EntleBucher",
248
+ "n02108089": "boxer",
249
+ "n02108422": "bull mastiff",
250
+ "n02108551": "Tibetan mastiff",
251
+ "n02108915": "French bulldog",
252
+ "n02109047": "Great Dane",
253
+ "n02109525": "Saint Bernard, St Bernard",
254
+ "n02109961": "Eskimo dog, husky",
255
+ "n02110063": "malamute, malemute, Alaskan malamute",
256
+ "n02110185": "Siberian husky",
257
+ "n02110341": "dalmatian, coach dog, carriage dog",
258
+ "n02110627": "affenpinscher, monkey pinscher, monkey dog",
259
+ "n02110806": "basenji",
260
+ "n02110958": "pug, pug-dog",
261
+ "n02111129": "Leonberg",
262
+ "n02111277": "Newfoundland, Newfoundland dog",
263
+ "n02111500": "Great Pyrenees",
264
+ "n02111889": "Samoyed, Samoyede",
265
+ "n02112018": "Pomeranian",
266
+ "n02112137": "chow, chow chow",
267
+ "n02112350": "keeshond",
268
+ "n02112706": "Brabancon griffon",
269
+ "n02113023": "Pembroke, Pembroke Welsh corgi",
270
+ "n02113186": "Cardigan, Cardigan Welsh corgi",
271
+ "n02113624": "toy poodle",
272
+ "n02113712": "miniature poodle",
273
+ "n02113799": "standard poodle",
274
+ "n02113978": "Mexican hairless",
275
+ "n02114367": "timber wolf, grey wolf, gray wolf, Canis lupus",
276
+ "n02114548": "white wolf, Arctic wolf, Canis lupus tundrarum",
277
+ "n02114712": "red wolf, maned wolf, Canis rufus, Canis niger",
278
+ "n02114855": "coyote, prairie wolf, brush wolf, Canis latrans",
279
+ "n02115641": "dingo, warrigal, warragal, Canis dingo",
280
+ "n02115913": "dhole, Cuon alpinus",
281
+ "n02116738": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
282
+ "n02117135": "hyena, hyaena",
283
+ "n02119022": "red fox, Vulpes vulpes",
284
+ "n02119789": "kit fox, Vulpes macrotis",
285
+ "n02120079": "Arctic fox, white fox, Alopex lagopus",
286
+ "n02120505": "grey fox, gray fox, Urocyon cinereoargenteus",
287
+ "n02123045": "tabby, tabby cat",
288
+ "n02123159": "tiger cat",
289
+ "n02123394": "Persian cat",
290
+ "n02123597": "Siamese cat, Siamese",
291
+ "n02124075": "Egyptian cat",
292
+ "n02125311": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
293
+ "n02127052": "lynx, catamount",
294
+ "n02128385": "leopard, Panthera pardus",
295
+ "n02128757": "snow leopard, ounce, Panthera uncia",
296
+ "n02128925": "jaguar, panther, Panthera onca, Felis onca",
297
+ "n02129165": "lion, king of beasts, Panthera leo",
298
+ "n02129604": "tiger, Panthera tigris",
299
+ "n02130308": "cheetah, chetah, Acinonyx jubatus",
300
+ "n02132136": "brown bear, bruin, Ursus arctos",
301
+ "n02133161": "American black bear, black bear, Ursus americanus, Euarctos americanus",
302
+ "n02134084": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
303
+ "n02134418": "sloth bear, Melursus ursinus, Ursus ursinus",
304
+ "n02137549": "mongoose",
305
+ "n02138441": "meerkat, mierkat",
306
+ "n02165105": "tiger beetle",
307
+ "n02165456": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
308
+ "n02167151": "ground beetle, carabid beetle",
309
+ "n02168699": "long-horned beetle, longicorn, longicorn beetle",
310
+ "n02169497": "leaf beetle, chrysomelid",
311
+ "n02172182": "dung beetle",
312
+ "n02174001": "rhinoceros beetle",
313
+ "n02177972": "weevil",
314
+ "n02190166": "fly",
315
+ "n02206856": "bee",
316
+ "n02219486": "ant, emmet, pismire",
317
+ "n02226429": "grasshopper, hopper",
318
+ "n02229544": "cricket",
319
+ "n02231487": "walking stick, walkingstick, stick insect",
320
+ "n02233338": "cockroach, roach",
321
+ "n02236044": "mantis, mantid",
322
+ "n02256656": "cicada, cicala",
323
+ "n02259212": "leafhopper",
324
+ "n02264363": "lacewing, lacewing fly",
325
+ "n02268443": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
326
+ "n02268853": "damselfly",
327
+ "n02276258": "admiral",
328
+ "n02277742": "ringlet, ringlet butterfly",
329
+ "n02279972": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
330
+ "n02280649": "cabbage butterfly",
331
+ "n02281406": "sulphur butterfly, sulfur butterfly",
332
+ "n02281787": "lycaenid, lycaenid butterfly",
333
+ "n02317335": "starfish, sea star",
334
+ "n02319095": "sea urchin",
335
+ "n02321529": "sea cucumber, holothurian",
336
+ "n02325366": "wood rabbit, cottontail, cottontail rabbit",
337
+ "n02326432": "hare",
338
+ "n02328150": "Angora, Angora rabbit",
339
+ "n02342885": "hamster",
340
+ "n02346627": "porcupine, hedgehog",
341
+ "n02356798": "fox squirrel, eastern fox squirrel, Sciurus niger",
342
+ "n02361337": "marmot",
343
+ "n02363005": "beaver",
344
+ "n02364673": "guinea pig, Cavia cobaya",
345
+ "n02389026": "sorrel",
346
+ "n02391049": "zebra",
347
+ "n02395406": "hog, pig, grunter, squealer, Sus scrofa",
348
+ "n02396427": "wild boar, boar, Sus scrofa",
349
+ "n02397096": "warthog",
350
+ "n02398521": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
351
+ "n02403003": "ox",
352
+ "n02408429": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
353
+ "n02410509": "bison",
354
+ "n02412080": "ram, tup",
355
+ "n02415577": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
356
+ "n02417914": "ibex, Capra ibex",
357
+ "n02422106": "hartebeest",
358
+ "n02422699": "impala, Aepyceros melampus",
359
+ "n02423022": "gazelle",
360
+ "n02437312": "Arabian camel, dromedary, Camelus dromedarius",
361
+ "n02437616": "llama",
362
+ "n02441942": "weasel",
363
+ "n02442845": "mink",
364
+ "n02443114": "polecat, fitch, foulmart, foumart, Mustela putorius",
365
+ "n02443484": "black-footed ferret, ferret, Mustela nigripes",
366
+ "n02444819": "otter",
367
+ "n02445715": "skunk, polecat, wood pussy",
368
+ "n02447366": "badger",
369
+ "n02454379": "armadillo",
370
+ "n02457408": "three-toed sloth, ai, Bradypus tridactylus",
371
+ "n02480495": "orangutan, orang, orangutang, Pongo pygmaeus",
372
+ "n02480855": "gorilla, Gorilla gorilla",
373
+ "n02481823": "chimpanzee, chimp, Pan troglodytes",
374
+ "n02483362": "gibbon, Hylobates lar",
375
+ "n02483708": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
376
+ "n02484975": "guenon, guenon monkey",
377
+ "n02486261": "patas, hussar monkey, Erythrocebus patas",
378
+ "n02486410": "baboon",
379
+ "n02487347": "macaque",
380
+ "n02488291": "langur",
381
+ "n02488702": "colobus, colobus monkey",
382
+ "n02489166": "proboscis monkey, Nasalis larvatus",
383
+ "n02490219": "marmoset",
384
+ "n02492035": "capuchin, ringtail, Cebus capucinus",
385
+ "n02492660": "howler monkey, howler",
386
+ "n02493509": "titi, titi monkey",
387
+ "n02493793": "spider monkey, Ateles geoffroyi",
388
+ "n02494079": "squirrel monkey, Saimiri sciureus",
389
+ "n02497673": "Madagascar cat, ring-tailed lemur, Lemur catta",
390
+ "n02500267": "indri, indris, Indri indri, Indri brevicaudatus",
391
+ "n02504013": "Indian elephant, Elephas maximus",
392
+ "n02504458": "African elephant, Loxodonta africana",
393
+ "n02509815": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
394
+ "n02510455": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
395
+ "n02514041": "barracouta, snoek",
396
+ "n02526121": "eel",
397
+ "n02536864": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
398
+ "n02606052": "rock beauty, Holocanthus tricolor",
399
+ "n02607072": "anemone fish",
400
+ "n02640242": "sturgeon",
401
+ "n02641379": "gar, garfish, garpike, billfish, Lepisosteus osseus",
402
+ "n02643566": "lionfish",
403
+ "n02655020": "puffer, pufferfish, blowfish, globefish",
404
+ "n02666196": "abacus",
405
+ "n02667093": "abaya",
406
+ "n02669723": "academic gown, academic robe, judge's robe",
407
+ "n02672831": "accordion, piano accordion, squeeze box",
408
+ "n02676566": "acoustic guitar",
409
+ "n02687172": "aircraft carrier, carrier, flattop, attack aircraft carrier",
410
+ "n02690373": "airliner",
411
+ "n02692877": "airship, dirigible",
412
+ "n02699494": "altar",
413
+ "n02701002": "ambulance",
414
+ "n02704792": "amphibian, amphibious vehicle",
415
+ "n02708093": "analog clock",
416
+ "n02727426": "apiary, bee house",
417
+ "n02730930": "apron",
418
+ "n02747177": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
419
+ "n02749479": "assault rifle, assault gun",
420
+ "n02769748": "backpack, back pack, knapsack, packsack, rucksack, haversack",
421
+ "n02776631": "bakery, bakeshop, bakehouse",
422
+ "n02777292": "balance beam, beam",
423
+ "n02782093": "balloon",
424
+ "n02783161": "ballpoint, ballpoint pen, ballpen, Biro",
425
+ "n02786058": "Band Aid",
426
+ "n02787622": "banjo",
427
+ "n02788148": "bannister, banister, balustrade, balusters, handrail",
428
+ "n02790996": "barbell",
429
+ "n02791124": "barber chair",
430
+ "n02791270": "barbershop",
431
+ "n02793495": "barn",
432
+ "n02794156": "barometer",
433
+ "n02795169": "barrel, cask",
434
+ "n02797295": "barrow, garden cart, lawn cart, wheelbarrow",
435
+ "n02799071": "baseball",
436
+ "n02802426": "basketball",
437
+ "n02804414": "bassinet",
438
+ "n02804610": "bassoon",
439
+ "n02807133": "bathing cap, swimming cap",
440
+ "n02808304": "bath towel",
441
+ "n02808440": "bathtub, bathing tub, bath, tub",
442
+ "n02814533": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
443
+ "n02814860": "beacon, lighthouse, beacon light, pharos",
444
+ "n02815834": "beaker",
445
+ "n02817516": "bearskin, busby, shako",
446
+ "n02823428": "beer bottle",
447
+ "n02823750": "beer glass",
448
+ "n02825657": "bell cote, bell cot",
449
+ "n02834397": "bib",
450
+ "n02835271": "bicycle-built-for-two, tandem bicycle, tandem",
451
+ "n02837789": "bikini, two-piece",
452
+ "n02840245": "binder, ring-binder",
453
+ "n02841315": "binoculars, field glasses, opera glasses",
454
+ "n02843684": "birdhouse",
455
+ "n02859443": "boathouse",
456
+ "n02860847": "bobsled, bobsleigh, bob",
457
+ "n02865351": "bolo tie, bolo, bola tie, bola",
458
+ "n02869837": "bonnet, poke bonnet",
459
+ "n02870880": "bookcase",
460
+ "n02871525": "bookshop, bookstore, bookstall",
461
+ "n02877765": "bottlecap",
462
+ "n02879718": "bow",
463
+ "n02883205": "bow tie, bow-tie, bowtie",
464
+ "n02892201": "brass, memorial tablet, plaque",
465
+ "n02892767": "brassiere, bra, bandeau",
466
+ "n02894605": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
467
+ "n02895154": "breastplate, aegis, egis",
468
+ "n02906734": "broom",
469
+ "n02909870": "bucket, pail",
470
+ "n02910353": "buckle",
471
+ "n02916936": "bulletproof vest",
472
+ "n02917067": "bullet train, bullet",
473
+ "n02927161": "butcher shop, meat market",
474
+ "n02930766": "cab, hack, taxi, taxicab",
475
+ "n02939185": "caldron, cauldron",
476
+ "n02948072": "candle, taper, wax light",
477
+ "n02950826": "cannon",
478
+ "n02951358": "canoe",
479
+ "n02951585": "can opener, tin opener",
480
+ "n02963159": "cardigan",
481
+ "n02965783": "car mirror",
482
+ "n02966193": "carousel, carrousel, merry-go-round, roundabout, whirligig",
483
+ "n02966687": "carpenter's kit, tool kit",
484
+ "n02971356": "carton",
485
+ "n02974003": "car wheel",
486
+ "n02977058": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
487
+ "n02978881": "cassette",
488
+ "n02979186": "cassette player",
489
+ "n02980441": "castle",
490
+ "n02981792": "catamaran",
491
+ "n02988304": "CD player",
492
+ "n02992211": "cello, violoncello",
493
+ "n02992529": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
494
+ "n02999410": "chain",
495
+ "n03000134": "chainlink fence",
496
+ "n03000247": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
497
+ "n03000684": "chain saw, chainsaw",
498
+ "n03014705": "chest",
499
+ "n03016953": "chiffonier, commode",
500
+ "n03017168": "chime, bell, gong",
501
+ "n03018349": "china cabinet, china closet",
502
+ "n03026506": "Christmas stocking",
503
+ "n03028079": "church, church building",
504
+ "n03032252": "cinema, movie theater, movie theatre, movie house, picture palace",
505
+ "n03041632": "cleaver, meat cleaver, chopper",
506
+ "n03042490": "cliff dwelling",
507
+ "n03045698": "cloak",
508
+ "n03047690": "clog, geta, patten, sabot",
509
+ "n03062245": "cocktail shaker",
510
+ "n03063599": "coffee mug",
511
+ "n03063689": "coffeepot",
512
+ "n03065424": "coil, spiral, volute, whorl, helix",
513
+ "n03075370": "combination lock",
514
+ "n03085013": "computer keyboard, keypad",
515
+ "n03089624": "confectionery, confectionary, candy store",
516
+ "n03095699": "container ship, containership, container vessel",
517
+ "n03100240": "convertible",
518
+ "n03109150": "corkscrew, bottle screw",
519
+ "n03110669": "cornet, horn, trumpet, trump",
520
+ "n03124043": "cowboy boot",
521
+ "n03124170": "cowboy hat, ten-gallon hat",
522
+ "n03125729": "cradle",
523
+ "n03126707": "crane2",
524
+ "n03127747": "crash helmet",
525
+ "n03127925": "crate",
526
+ "n03131574": "crib, cot",
527
+ "n03133878": "Crock Pot",
528
+ "n03134739": "croquet ball",
529
+ "n03141823": "crutch",
530
+ "n03146219": "cuirass",
531
+ "n03160309": "dam, dike, dyke",
532
+ "n03179701": "desk",
533
+ "n03180011": "desktop computer",
534
+ "n03187595": "dial telephone, dial phone",
535
+ "n03188531": "diaper, nappy, napkin",
536
+ "n03196217": "digital clock",
537
+ "n03197337": "digital watch",
538
+ "n03201208": "dining table, board",
539
+ "n03207743": "dishrag, dishcloth",
540
+ "n03207941": "dishwasher, dish washer, dishwashing machine",
541
+ "n03208938": "disk brake, disc brake",
542
+ "n03216828": "dock, dockage, docking facility",
543
+ "n03218198": "dogsled, dog sled, dog sleigh",
544
+ "n03220513": "dome",
545
+ "n03223299": "doormat, welcome mat",
546
+ "n03240683": "drilling platform, offshore rig",
547
+ "n03249569": "drum, membranophone, tympan",
548
+ "n03250847": "drumstick",
549
+ "n03255030": "dumbbell",
550
+ "n03259280": "Dutch oven",
551
+ "n03271574": "electric fan, blower",
552
+ "n03272010": "electric guitar",
553
+ "n03272562": "electric locomotive",
554
+ "n03290653": "entertainment center",
555
+ "n03291819": "envelope",
556
+ "n03297495": "espresso maker",
557
+ "n03314780": "face powder",
558
+ "n03325584": "feather boa, boa",
559
+ "n03337140": "file, file cabinet, filing cabinet",
560
+ "n03344393": "fireboat",
561
+ "n03345487": "fire engine, fire truck",
562
+ "n03347037": "fire screen, fireguard",
563
+ "n03355925": "flagpole, flagstaff",
564
+ "n03372029": "flute, transverse flute",
565
+ "n03376595": "folding chair",
566
+ "n03379051": "football helmet",
567
+ "n03384352": "forklift",
568
+ "n03388043": "fountain",
569
+ "n03388183": "fountain pen",
570
+ "n03388549": "four-poster",
571
+ "n03393912": "freight car",
572
+ "n03394916": "French horn, horn",
573
+ "n03400231": "frying pan, frypan, skillet",
574
+ "n03404251": "fur coat",
575
+ "n03417042": "garbage truck, dustcart",
576
+ "n03424325": "gasmask, respirator, gas helmet",
577
+ "n03425413": "gas pump, gasoline pump, petrol pump, island dispenser",
578
+ "n03443371": "goblet",
579
+ "n03444034": "go-kart",
580
+ "n03445777": "golf ball",
581
+ "n03445924": "golfcart, golf cart",
582
+ "n03447447": "gondola",
583
+ "n03447721": "gong, tam-tam",
584
+ "n03450230": "gown",
585
+ "n03452741": "grand piano, grand",
586
+ "n03457902": "greenhouse, nursery, glasshouse",
587
+ "n03459775": "grille, radiator grille",
588
+ "n03461385": "grocery store, grocery, food market, market",
589
+ "n03467068": "guillotine",
590
+ "n03476684": "hair slide",
591
+ "n03476991": "hair spray",
592
+ "n03478589": "half track",
593
+ "n03481172": "hammer",
594
+ "n03482405": "hamper",
595
+ "n03483316": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
596
+ "n03485407": "hand-held computer, hand-held microcomputer",
597
+ "n03485794": "handkerchief, hankie, hanky, hankey",
598
+ "n03492542": "hard disc, hard disk, fixed disk",
599
+ "n03494278": "harmonica, mouth organ, harp, mouth harp",
600
+ "n03495258": "harp",
601
+ "n03496892": "harvester, reaper",
602
+ "n03498962": "hatchet",
603
+ "n03527444": "holster",
604
+ "n03529860": "home theater, home theatre",
605
+ "n03530642": "honeycomb",
606
+ "n03532672": "hook, claw",
607
+ "n03534580": "hoopskirt, crinoline",
608
+ "n03535780": "horizontal bar, high bar",
609
+ "n03538406": "horse cart, horse-cart",
610
+ "n03544143": "hourglass",
611
+ "n03584254": "iPod",
612
+ "n03584829": "iron, smoothing iron",
613
+ "n03590841": "jack-o'-lantern",
614
+ "n03594734": "jean, blue jean, denim",
615
+ "n03594945": "jeep, landrover",
616
+ "n03595614": "jersey, T-shirt, tee shirt",
617
+ "n03598930": "jigsaw puzzle",
618
+ "n03599486": "jinrikisha, ricksha, rickshaw",
619
+ "n03602883": "joystick",
620
+ "n03617480": "kimono",
621
+ "n03623198": "knee pad",
622
+ "n03627232": "knot",
623
+ "n03630383": "lab coat, laboratory coat",
624
+ "n03633091": "ladle",
625
+ "n03637318": "lampshade, lamp shade",
626
+ "n03642806": "laptop, laptop computer",
627
+ "n03649909": "lawn mower, mower",
628
+ "n03657121": "lens cap, lens cover",
629
+ "n03658185": "letter opener, paper knife, paperknife",
630
+ "n03661043": "library",
631
+ "n03662601": "lifeboat",
632
+ "n03666591": "lighter, light, igniter, ignitor",
633
+ "n03670208": "limousine, limo",
634
+ "n03673027": "liner, ocean liner",
635
+ "n03676483": "lipstick, lip rouge",
636
+ "n03680355": "Loafer",
637
+ "n03690938": "lotion",
638
+ "n03691459": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
639
+ "n03692522": "loupe, jeweler's loupe",
640
+ "n03697007": "lumbermill, sawmill",
641
+ "n03706229": "magnetic compass",
642
+ "n03709823": "mailbag, postbag",
643
+ "n03710193": "mailbox, letter box",
644
+ "n03710637": "maillot",
645
+ "n03710721": "maillot, tank suit",
646
+ "n03717622": "manhole cover",
647
+ "n03720891": "maraca",
648
+ "n03721384": "marimba, xylophone",
649
+ "n03724870": "mask",
650
+ "n03729826": "matchstick",
651
+ "n03733131": "maypole",
652
+ "n03733281": "maze, labyrinth",
653
+ "n03733805": "measuring cup",
654
+ "n03742115": "medicine chest, medicine cabinet",
655
+ "n03743016": "megalith, megalithic structure",
656
+ "n03759954": "microphone, mike",
657
+ "n03761084": "microwave, microwave oven",
658
+ "n03763968": "military uniform",
659
+ "n03764736": "milk can",
660
+ "n03769881": "minibus",
661
+ "n03770439": "miniskirt, mini",
662
+ "n03770679": "minivan",
663
+ "n03773504": "missile",
664
+ "n03775071": "mitten",
665
+ "n03775546": "mixing bowl",
666
+ "n03776460": "mobile home, manufactured home",
667
+ "n03777568": "Model T",
668
+ "n03777754": "modem",
669
+ "n03781244": "monastery",
670
+ "n03782006": "monitor",
671
+ "n03785016": "moped",
672
+ "n03786901": "mortar",
673
+ "n03787032": "mortarboard",
674
+ "n03788195": "mosque",
675
+ "n03788365": "mosquito net",
676
+ "n03791053": "motor scooter, scooter",
677
+ "n03792782": "mountain bike, all-terrain bike, off-roader",
678
+ "n03792972": "mountain tent",
679
+ "n03793489": "mouse, computer mouse",
680
+ "n03794056": "mousetrap",
681
+ "n03796401": "moving van",
682
+ "n03803284": "muzzle",
683
+ "n03804744": "nail",
684
+ "n03814639": "neck brace",
685
+ "n03814906": "necklace",
686
+ "n03825788": "nipple",
687
+ "n03832673": "notebook, notebook computer",
688
+ "n03837869": "obelisk",
689
+ "n03838899": "oboe, hautboy, hautbois",
690
+ "n03840681": "ocarina, sweet potato",
691
+ "n03841143": "odometer, hodometer, mileometer, milometer",
692
+ "n03843555": "oil filter",
693
+ "n03854065": "organ, pipe organ",
694
+ "n03857828": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
695
+ "n03866082": "overskirt",
696
+ "n03868242": "oxcart",
697
+ "n03868863": "oxygen mask",
698
+ "n03871628": "packet",
699
+ "n03873416": "paddle, boat paddle",
700
+ "n03874293": "paddlewheel, paddle wheel",
701
+ "n03874599": "padlock",
702
+ "n03876231": "paintbrush",
703
+ "n03877472": "pajama, pyjama, pj's, jammies",
704
+ "n03877845": "palace",
705
+ "n03884397": "panpipe, pandean pipe, syrinx",
706
+ "n03887697": "paper towel",
707
+ "n03888257": "parachute, chute",
708
+ "n03888605": "parallel bars, bars",
709
+ "n03891251": "park bench",
710
+ "n03891332": "parking meter",
711
+ "n03895866": "passenger car, coach, carriage",
712
+ "n03899768": "patio, terrace",
713
+ "n03902125": "pay-phone, pay-station",
714
+ "n03903868": "pedestal, plinth, footstall",
715
+ "n03908618": "pencil box, pencil case",
716
+ "n03908714": "pencil sharpener",
717
+ "n03916031": "perfume, essence",
718
+ "n03920288": "Petri dish",
719
+ "n03924679": "photocopier",
720
+ "n03929660": "pick, plectrum, plectron",
721
+ "n03929855": "pickelhaube",
722
+ "n03930313": "picket fence, paling",
723
+ "n03930630": "pickup, pickup truck",
724
+ "n03933933": "pier",
725
+ "n03935335": "piggy bank, penny bank",
726
+ "n03937543": "pill bottle",
727
+ "n03938244": "pillow",
728
+ "n03942813": "ping-pong ball",
729
+ "n03944341": "pinwheel",
730
+ "n03947888": "pirate, pirate ship",
731
+ "n03950228": "pitcher, ewer",
732
+ "n03954731": "plane, carpenter's plane, woodworking plane",
733
+ "n03956157": "planetarium",
734
+ "n03958227": "plastic bag",
735
+ "n03961711": "plate rack",
736
+ "n03967562": "plow, plough",
737
+ "n03970156": "plunger, plumber's helper",
738
+ "n03976467": "Polaroid camera, Polaroid Land camera",
739
+ "n03976657": "pole",
740
+ "n03977966": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
741
+ "n03980874": "poncho",
742
+ "n03982430": "pool table, billiard table, snooker table",
743
+ "n03983396": "pop bottle, soda bottle",
744
+ "n03991062": "pot, flowerpot",
745
+ "n03992509": "potter's wheel",
746
+ "n03995372": "power drill",
747
+ "n03998194": "prayer rug, prayer mat",
748
+ "n04004767": "printer",
749
+ "n04005630": "prison, prison house",
750
+ "n04008634": "projectile, missile",
751
+ "n04009552": "projector",
752
+ "n04019541": "puck, hockey puck",
753
+ "n04023962": "punching bag, punch bag, punching ball, punchball",
754
+ "n04026417": "purse",
755
+ "n04033901": "quill, quill pen",
756
+ "n04033995": "quilt, comforter, comfort, puff",
757
+ "n04037443": "racer, race car, racing car",
758
+ "n04039381": "racket, racquet",
759
+ "n04040759": "radiator",
760
+ "n04041544": "radio, wireless",
761
+ "n04044716": "radio telescope, radio reflector",
762
+ "n04049303": "rain barrel",
763
+ "n04065272": "recreational vehicle, RV, R.V.",
764
+ "n04067472": "reel",
765
+ "n04069434": "reflex camera",
766
+ "n04070727": "refrigerator, icebox",
767
+ "n04074963": "remote control, remote",
768
+ "n04081281": "restaurant, eating house, eating place, eatery",
769
+ "n04086273": "revolver, six-gun, six-shooter",
770
+ "n04090263": "rifle",
771
+ "n04099969": "rocking chair, rocker",
772
+ "n04111531": "rotisserie",
773
+ "n04116512": "rubber eraser, rubber, pencil eraser",
774
+ "n04118538": "rugby ball",
775
+ "n04118776": "rule, ruler",
776
+ "n04120489": "running shoe",
777
+ "n04125021": "safe",
778
+ "n04127249": "safety pin",
779
+ "n04131690": "saltshaker, salt shaker",
780
+ "n04133789": "sandal",
781
+ "n04136333": "sarong",
782
+ "n04141076": "sax, saxophone",
783
+ "n04141327": "scabbard",
784
+ "n04141975": "scale, weighing machine",
785
+ "n04146614": "school bus",
786
+ "n04147183": "schooner",
787
+ "n04149813": "scoreboard",
788
+ "n04152593": "screen, CRT screen",
789
+ "n04153751": "screw",
790
+ "n04154565": "screwdriver",
791
+ "n04162706": "seat belt, seatbelt",
792
+ "n04179913": "sewing machine",
793
+ "n04192698": "shield, buckler",
794
+ "n04200800": "shoe shop, shoe-shop, shoe store",
795
+ "n04201297": "shoji",
796
+ "n04204238": "shopping basket",
797
+ "n04204347": "shopping cart",
798
+ "n04208210": "shovel",
799
+ "n04209133": "shower cap",
800
+ "n04209239": "shower curtain",
801
+ "n04228054": "ski",
802
+ "n04229816": "ski mask",
803
+ "n04235860": "sleeping bag",
804
+ "n04238763": "slide rule, slipstick",
805
+ "n04239074": "sliding door",
806
+ "n04243546": "slot, one-armed bandit",
807
+ "n04251144": "snorkel",
808
+ "n04252077": "snowmobile",
809
+ "n04252225": "snowplow, snowplough",
810
+ "n04254120": "soap dispenser",
811
+ "n04254680": "soccer ball",
812
+ "n04254777": "sock",
813
+ "n04258138": "solar dish, solar collector, solar furnace",
814
+ "n04259630": "sombrero",
815
+ "n04263257": "soup bowl",
816
+ "n04264628": "space bar",
817
+ "n04265275": "space heater",
818
+ "n04266014": "space shuttle",
819
+ "n04270147": "spatula",
820
+ "n04273569": "speedboat",
821
+ "n04275548": "spider web, spider's web",
822
+ "n04277352": "spindle",
823
+ "n04285008": "sports car, sport car",
824
+ "n04286575": "spotlight, spot",
825
+ "n04296562": "stage",
826
+ "n04310018": "steam locomotive",
827
+ "n04311004": "steel arch bridge",
828
+ "n04311174": "steel drum",
829
+ "n04317175": "stethoscope",
830
+ "n04325704": "stole",
831
+ "n04326547": "stone wall",
832
+ "n04328186": "stopwatch, stop watch",
833
+ "n04330267": "stove",
834
+ "n04332243": "strainer",
835
+ "n04335435": "streetcar, tram, tramcar, trolley, trolley car",
836
+ "n04336792": "stretcher",
837
+ "n04344873": "studio couch, day bed",
838
+ "n04346328": "stupa, tope",
839
+ "n04347754": "submarine, pigboat, sub, U-boat",
840
+ "n04350905": "suit, suit of clothes",
841
+ "n04355338": "sundial",
842
+ "n04355933": "sunglass",
843
+ "n04356056": "sunglasses, dark glasses, shades",
844
+ "n04357314": "sunscreen, sunblock, sun blocker",
845
+ "n04366367": "suspension bridge",
846
+ "n04367480": "swab, swob, mop",
847
+ "n04370456": "sweatshirt",
848
+ "n04371430": "swimming trunks, bathing trunks",
849
+ "n04371774": "swing",
850
+ "n04372370": "switch, electric switch, electrical switch",
851
+ "n04376876": "syringe",
852
+ "n04380533": "table lamp",
853
+ "n04389033": "tank, army tank, armored combat vehicle, armoured combat vehicle",
854
+ "n04392985": "tape player",
855
+ "n04398044": "teapot",
856
+ "n04399382": "teddy, teddy bear",
857
+ "n04404412": "television, television system",
858
+ "n04409515": "tennis ball",
859
+ "n04417672": "thatch, thatched roof",
860
+ "n04418357": "theater curtain, theatre curtain",
861
+ "n04423845": "thimble",
862
+ "n04428191": "thresher, thrasher, threshing machine",
863
+ "n04429376": "throne",
864
+ "n04435653": "tile roof",
865
+ "n04442312": "toaster",
866
+ "n04443257": "tobacco shop, tobacconist shop, tobacconist",
867
+ "n04447861": "toilet seat",
868
+ "n04456115": "torch",
869
+ "n04458633": "totem pole",
870
+ "n04461696": "tow truck, tow car, wrecker",
871
+ "n04462240": "toyshop",
872
+ "n04465501": "tractor",
873
+ "n04467665": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
874
+ "n04476259": "tray",
875
+ "n04479046": "trench coat",
876
+ "n04482393": "tricycle, trike, velocipede",
877
+ "n04483307": "trimaran",
878
+ "n04485082": "tripod",
879
+ "n04486054": "triumphal arch",
880
+ "n04487081": "trolleybus, trolley coach, trackless trolley",
881
+ "n04487394": "trombone",
882
+ "n04493381": "tub, vat",
883
+ "n04501370": "turnstile",
884
+ "n04505470": "typewriter keyboard",
885
+ "n04507155": "umbrella",
886
+ "n04509417": "unicycle, monocycle",
887
+ "n04515003": "upright, upright piano",
888
+ "n04517823": "vacuum, vacuum cleaner",
889
+ "n04522168": "vase",
890
+ "n04523525": "vault",
891
+ "n04525038": "velvet",
892
+ "n04525305": "vending machine",
893
+ "n04532106": "vestment",
894
+ "n04532670": "viaduct",
895
+ "n04536866": "violin, fiddle",
896
+ "n04540053": "volleyball",
897
+ "n04542943": "waffle iron",
898
+ "n04548280": "wall clock",
899
+ "n04548362": "wallet, billfold, notecase, pocketbook",
900
+ "n04550184": "wardrobe, closet, press",
901
+ "n04552348": "warplane, military plane",
902
+ "n04553703": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
903
+ "n04554684": "washer, automatic washer, washing machine",
904
+ "n04557648": "water bottle",
905
+ "n04560804": "water jug",
906
+ "n04562935": "water tower",
907
+ "n04579145": "whiskey jug",
908
+ "n04579432": "whistle",
909
+ "n04584207": "wig",
910
+ "n04589890": "window screen",
911
+ "n04590129": "window shade",
912
+ "n04591157": "Windsor tie",
913
+ "n04591713": "wine bottle",
914
+ "n04592741": "wing",
915
+ "n04596742": "wok",
916
+ "n04597913": "wooden spoon",
917
+ "n04599235": "wool, woolen, woollen",
918
+ "n04604644": "worm fence, snake fence, snake-rail fence, Virginia fence",
919
+ "n04606251": "wreck",
920
+ "n04612504": "yawl",
921
+ "n04613696": "yurt",
922
+ "n06359193": "web site, website, internet site, site",
923
+ "n06596364": "comic book",
924
+ "n06785654": "crossword puzzle, crossword",
925
+ "n06794110": "street sign",
926
+ "n06874185": "traffic light, traffic signal, stoplight",
927
+ "n07248320": "book jacket, dust cover, dust jacket, dust wrapper",
928
+ "n07565083": "menu",
929
+ "n07579787": "plate",
930
+ "n07583066": "guacamole",
931
+ "n07584110": "consomme",
932
+ "n07590611": "hot pot, hotpot",
933
+ "n07613480": "trifle",
934
+ "n07614500": "ice cream, icecream",
935
+ "n07615774": "ice lolly, lolly, lollipop, popsicle",
936
+ "n07684084": "French loaf",
937
+ "n07693725": "bagel, beigel",
938
+ "n07695742": "pretzel",
939
+ "n07697313": "cheeseburger",
940
+ "n07697537": "hotdog, hot dog, red hot",
941
+ "n07711569": "mashed potato",
942
+ "n07714571": "head cabbage",
943
+ "n07714990": "broccoli",
944
+ "n07715103": "cauliflower",
945
+ "n07716358": "zucchini, courgette",
946
+ "n07716906": "spaghetti squash",
947
+ "n07717410": "acorn squash",
948
+ "n07717556": "butternut squash",
949
+ "n07718472": "cucumber, cuke",
950
+ "n07718747": "artichoke, globe artichoke",
951
+ "n07720875": "bell pepper",
952
+ "n07730033": "cardoon",
953
+ "n07734744": "mushroom",
954
+ "n07742313": "Granny Smith",
955
+ "n07745940": "strawberry",
956
+ "n07747607": "orange",
957
+ "n07749582": "lemon",
958
+ "n07753113": "fig",
959
+ "n07753275": "pineapple, ananas",
960
+ "n07753592": "banana",
961
+ "n07754684": "jackfruit, jak, jack",
962
+ "n07760859": "custard apple",
963
+ "n07768694": "pomegranate",
964
+ "n07802026": "hay",
965
+ "n07831146": "carbonara",
966
+ "n07836838": "chocolate sauce, chocolate syrup",
967
+ "n07860988": "dough",
968
+ "n07871810": "meat loaf, meatloaf",
969
+ "n07873807": "pizza, pizza pie",
970
+ "n07875152": "potpie",
971
+ "n07880968": "burrito",
972
+ "n07892512": "red wine",
973
+ "n07920052": "espresso",
974
+ "n07930864": "cup",
975
+ "n07932039": "eggnog",
976
+ "n09193705": "alp",
977
+ "n09229709": "bubble",
978
+ "n09246464": "cliff, drop, drop-off",
979
+ "n09256479": "coral reef",
980
+ "n09288635": "geyser",
981
+ "n09332890": "lakeside, lakeshore",
982
+ "n09399592": "promontory, headland, head, foreland",
983
+ "n09421951": "sandbar, sand bar",
984
+ "n09428293": "seashore, coast, seacoast, sea-coast",
985
+ "n09468604": "valley, vale",
986
+ "n09472597": "volcano",
987
+ "n09835506": "ballplayer, baseball player",
988
+ "n10148035": "groom, bridegroom",
989
+ "n10565667": "scuba diver",
990
+ "n11879895": "rapeseed",
991
+ "n11939491": "daisy",
992
+ "n12057211": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
993
+ "n12144580": "corn",
994
+ "n12267677": "acorn",
995
+ "n12620546": "hip, rose hip, rosehip",
996
+ "n12768682": "buckeye, horse chestnut, conker",
997
+ "n12985857": "coral fungus",
998
+ "n12998815": "agaric",
999
+ "n13037406": "gyromitra",
1000
+ "n13040303": "stinkhorn, carrion fungus",
1001
+ "n13044778": "earthstar",
1002
+ "n13052670": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1003
+ "n13054560": "bolete",
1004
+ "n13133613": "ear, spike, capitulum",
1005
+ "n15075141": "toilet tissue, toilet paper, bathroom tissue",
1006
+ }
1007
+ )
tasks/image_classification/plotting.py ADDED
@@ -0,0 +1,494 @@
1
+
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ import os
6
+ import imageio
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib as mpl
9
+ from matplotlib import patheffects
10
+ mpl.use('Agg')
11
+ import seaborn as sns
12
+ import numpy as np
13
+ from tqdm.auto import tqdm
14
+ sns.set_style('darkgrid')
15
+
16
+ from tqdm.auto import tqdm
17
+ from scipy import ndimage
18
+ import umap
19
+ from scipy.special import softmax
20
+
21
+ import subprocess as sp
22
+ import cv2 # Still potentially useful for color conversion checks if needed
23
+ import os
24
+
25
+ def save_frames_to_mp4(frames, output_filename, fps=15.0, gop_size=None, crf=23, preset='medium', pix_fmt='yuv420p'):
26
+     """
27
+     Saves a list of NumPy array frames to an MP4 video file using FFmpeg via subprocess.
28
+
29
+     Includes fix for odd frame dimensions by padding to the nearest even number using -vf pad.
30
+
31
+     Requires FFmpeg to be installed and available in the system PATH.
32
+
33
+     Args:
34
+         frames (list): A list of NumPy arrays representing the video frames.
35
+             Expected format: uint8, (height, width, 3) for BGR color
36
+             or (height, width) for grayscale. Should be consistent.
37
+         output_filename (str): The path and name for the output MP4 file.
38
+         fps (float, optional): Frames per second for the output video. Defaults to 15.0.
39
+         gop_size (int, optional): Group of Pictures (GOP) size. This determines the
40
+             maximum interval between keyframes. Lower values
41
+             mean more frequent keyframes (better seeking, larger file).
42
+             Defaults to int(fps) (approx 1 keyframe per second).
43
+         crf (int, optional): Constant Rate Factor for H.264 encoding. Lower values mean
44
+             better quality and larger files. Typical range: 18-28.
45
+             Defaults to 23.
46
+         preset (str, optional): FFmpeg encoding speed preset. Affects encoding time
47
+             and compression efficiency. Options include 'ultrafast',
48
+             'superfast', 'veryfast', 'faster', 'fast', 'medium',
49
+             'slow', 'slower', 'veryslow'. Defaults to 'medium'.
50
+     """
51
+     if not frames:
52
+         print("Error: The 'frames' list is empty. No video to save.")
53
+         return
54
+
55
+     # --- Determine Parameters from First Frame ---
56
+     try:
57
+         first_frame = frames[0]
58
+         print(first_frame.shape)
59
+         if not isinstance(first_frame, np.ndarray):
60
+             print(f"Error: Frame 0 is not a NumPy array (type: {type(first_frame)}).")
61
+             return
62
+
63
+         frame_height, frame_width = first_frame.shape[:2]
64
+         frame_size_str = f"{frame_width}x{frame_height}"
65
+
66
+         # Determine input pixel format based on first frame's shape
67
+         if len(first_frame.shape) == 3 and first_frame.shape[2] == 3:
68
+             input_pixel_format = 'bgr24'  # Assume OpenCV's default BGR uint8
69
+             expected_dims = 3
70
+             print(f"Info: Detected color frames (shape: {first_frame.shape}). Expecting BGR input.")
71
+         elif len(first_frame.shape) == 2:
72
+             input_pixel_format = 'gray'
73
+             expected_dims = 2
74
+             print(f"Info: Detected grayscale frames (shape: {first_frame.shape}).")
75
+         else:
76
+             print(f"Error: Unsupported frame shape {first_frame.shape}. Must be (h, w) or (h, w, 3).")
77
+             return
78
+
79
+         if first_frame.dtype != np.uint8:
80
+             print(f"Warning: First frame dtype is {first_frame.dtype}. Will attempt conversion to uint8.")
81
+
82
+     except IndexError:
83
+         print("Error: Could not access the first frame to determine dimensions.")
84
+         return
85
+     except Exception as e:
86
+         print(f"Error processing first frame: {e}")
87
+         return
88
+
89
+     # --- Set GOP size default if not provided ---
90
+     if gop_size is None:
91
+         gop_size = int(fps)
92
+         print(f"Info: GOP size not specified, defaulting to {gop_size} (approx 1 keyframe/sec).")
93
+
94
+     # --- Construct FFmpeg Command ---
95
+     # ADDED -vf pad filter to ensure even dimensions for libx264/yuv420p
96
+     # It calculates the nearest even dimensions >= original dimensions
97
+     # Example: 1600x1351 -> 1600x1352
98
+     pad_filter = "pad=ceil(iw/2)*2:ceil(ih/2)*2"
99
+
100
+     command = [
101
+         'ffmpeg',
102
+         '-y',
103
+         '-f', 'rawvideo',
104
+         '-vcodec', 'rawvideo',
105
+         '-pix_fmt', input_pixel_format,
106
+         '-s', frame_size_str,
107
+         '-r', str(float(fps)),
108
+         '-i', '-',
109
+         '-vf', pad_filter,  # <--- ADDED VIDEO FILTER HERE
110
+         '-c:v', 'libx264',
111
+         '-pix_fmt', pix_fmt,
112
+         '-preset', preset,
113
+         '-crf', str(crf),
114
+         '-g', str(gop_size),
115
+         '-movflags', '+faststart',
116
+         output_filename
117
+     ]
118
+
119
+     print(f"\n--- Starting FFmpeg ---")
120
+     print(f"Output File: {output_filename}")
121
+     print(f"Parameters: FPS={fps}, Size={frame_size_str}, GOP={gop_size}, CRF={crf}, Preset={preset}")
122
+     print(f"Applying Filter: -vf {pad_filter} (Ensures even dimensions)")
123
+     # print(f"FFmpeg Command: {' '.join(command)}")  # Uncomment for debugging
124
+
125
+     # --- Execute FFmpeg via Subprocess ---
126
+     try:
127
+         process = sp.Popen(command, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE)
128
+
129
+         print(f"\nWriting {len(frames)} frames to FFmpeg...")
130
+         progress_interval = max(1, len(frames) // 10)  # Print progress roughly 10 times
131
+
132
+         for i, frame in enumerate(frames):
133
+             # Basic validation and conversion for each frame
134
+             if not isinstance(frame, np.ndarray):
135
+                 print(f"Warning: Frame {i} is not a numpy array (type: {type(frame)}). Skipping.")
136
+                 continue
137
+             if frame.shape[0] != frame_height or frame.shape[1] != frame_width:
138
+                 print(f"Warning: Frame {i} has different dimensions {frame.shape[:2]}! Expected ({frame_height},{frame_width}). Skipping.")
139
+                 continue
140
+
141
+             current_dims = len(frame.shape)
142
+             if current_dims != expected_dims:
143
+                 print(f"Warning: Frame {i} has inconsistent dimensions ({current_dims}D vs expected {expected_dims}D). Skipping.")
144
+                 continue
145
+             if expected_dims == 3 and frame.shape[2] != 3:
146
+                 print(f"Warning: Frame {i} is color but doesn't have 3 channels ({frame.shape}). Skipping.")
147
+                 continue
148
+
149
+             if frame.dtype != np.uint8:
150
+                 try:
151
+                     frame = np.clip(frame, 0, 255).astype(np.uint8)
152
+                 except Exception as clip_err:
153
+                     print(f"Error clipping/converting frame {i} dtype: {clip_err}. Skipping.")
154
+                     continue
155
+
156
+             # Write frame bytes to FFmpeg's stdin
157
+             try:
158
+                 process.stdin.write(frame.tobytes())
159
+             except (OSError, BrokenPipeError) as pipe_err:
160
+                 print(f"\nError writing frame {i} to FFmpeg stdin: {pipe_err}")
161
+                 print("FFmpeg process likely terminated prematurely. Check FFmpeg errors below.")
162
+                 try:
163
+                     # Immediately try to read stderr if pipe breaks
164
+                     stderr_output_on_error = process.stderr.read()
165
+                     if stderr_output_on_error:
166
+                         print("\n--- FFmpeg stderr output on error ---")
167
+                         print(stderr_output_on_error.decode(errors='ignore'))
168
+                         print("--- End FFmpeg stderr ---")
169
+                 except Exception as read_err:
170
+                     print(f"(Could not read stderr after pipe error: {read_err})")
171
+                 return
172
+             except Exception as write_err:
173
+ print(f"Unexpected error writing frame {i}: {write_err}. Skipping.")
174
+ continue
175
+
176
+ if (i + 1) % progress_interval == 0 or (i + 1) == len(frames):
177
+ print(f" Processed frame {i + 1}/{len(frames)}")
178
+
179
+ print("\nFinished writing frames. Closing FFmpeg stdin and waiting for completion...")
180
+ # communicate() flushes and closes stdin itself before collecting stdout/stderr
181
+ stdout, stderr = process.communicate()
182
+ return_code = process.wait()
183
+
184
+ print("\n--- FFmpeg Final Status ---")
185
+ if return_code == 0:
186
+ print(f"FFmpeg process completed successfully.")
187
+ print(f"Video saved as: {output_filename}")
188
+ else:
189
+ print(f"FFmpeg process failed with return code {return_code}.")
190
+ print("--- FFmpeg Standard Error Output: ---")
191
+ print(stderr.decode(errors='replace')) # Print stderr captured by communicate()
192
+ print("--- End FFmpeg Output ---")
193
+ print("Review the FFmpeg error message above for details (e.g., dimension errors, parameter issues).")
194
+
195
+ except FileNotFoundError:
196
+ print("\n--- FATAL ERROR ---")
197
+ print("Error: 'ffmpeg' command not found.")
198
+ print("Please ensure FFmpeg is installed and its directory is included in your system's PATH environment variable.")
199
+ print("Download from: https://ffmpeg.org/")
200
+ print("-------------------")
201
+ except Exception as e:
202
+ print(f"\nAn unexpected error occurred during FFmpeg execution: {e}")
203
+
+ def find_island_centers(array_2d, threshold):
+     """
+     Finds the center of mass of each island (connected component) in a 2D array.
+
+     Args:
+         array_2d: A 2D numpy array of values.
+         threshold: The threshold to binarize the array.
+
+     Returns:
+         A tuple (centers, areas): centers is a list of (y, x) centers of mass,
+         one per island with positive total mass, and areas holds the pixel
+         count of each of those islands, in the same order.
+     """
+     binary_image = array_2d > threshold
+     labeled_image, num_labels = ndimage.label(binary_image)
+     # The coordinate grids are identical for every island, so build them once.
+     y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
+     centers = []
+     areas = []  # Store the area of each island
+     for i in range(1, num_labels + 1):
+         island = (labeled_image == i)
+         total_mass = np.sum(array_2d[island])
+         if total_mass > 0:
+             x_center = np.average(x_coords[island], weights=array_2d[island])
+             y_center = np.average(y_coords[island], weights=array_2d[island])
+             centers.append((round(y_center, 4), round(x_center, 4)))
+             areas.append(np.sum(island))  # Pixel count; kept aligned with centers
+     return centers, areas
229
+
230
+ def plot_neural_dynamics(post_activations_history, N_to_plot, save_location, axis_snap=False, N_per_row=5, which_neurons_mid=None, mid_colours=None, use_most_active_neurons=False):
231
+ assert N_to_plot%N_per_row==0, f'For nice visualisation, N_to_plot={N_to_plot} must be a multiple of N_per_row={N_per_row}'
232
+ assert post_activations_history.shape[-1] >= N_to_plot
233
+ figscale = 2
234
+ aspect_ratio = 3
235
+ mosaic = np.array([[f'{i}'] for i in range(N_to_plot)]).flatten().reshape(-1, N_per_row)
236
+ fig_synch, axes_synch = plt.subplot_mosaic(mosaic=mosaic, figsize=(figscale*mosaic.shape[1]*aspect_ratio*0.2, figscale*mosaic.shape[0]*0.2))
237
+ fig_mid, axes_mid = plt.subplot_mosaic(mosaic=mosaic, figsize=(figscale*mosaic.shape[1]*aspect_ratio*0.2, figscale*mosaic.shape[0]*0.2), dpi=200)
238
+
239
+ palette = sns.color_palette("husl", 8)
240
+
241
+ which_neurons_synch = np.arange(N_to_plot)
242
+ # which_neurons_mid = np.arange(N_to_plot, N_to_plot*2) if post_activations_history.shape[-1] >= 2*N_to_plot else np.random.choice(np.arange(post_activations_history.shape[-1]), size=N_to_plot, replace=True)
243
+ random_indices = np.random.choice(np.arange(post_activations_history.shape[-1]), size=N_to_plot, replace=post_activations_history.shape[-1] < N_to_plot)
244
+ if use_most_active_neurons:
245
+ metric = np.abs(np.fft.rfft(post_activations_history, axis=0))[3:].mean(0).std(0)
246
+ random_indices = np.argsort(metric)[-N_to_plot:]
247
+ np.random.shuffle(random_indices)
248
+ which_neurons_mid = which_neurons_mid if which_neurons_mid is not None else random_indices
249
+
250
+ if mid_colours is None:
251
+ mid_colours = [palette[np.random.randint(0, 8)] for ndx in range(N_to_plot)]
252
+ with tqdm(total=N_to_plot, initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
253
+ pbar_inner.set_description('Plotting neural dynamics')
254
+ for ndx in range(N_to_plot):
255
+
256
+ ax_s = axes_synch[f'{ndx}']
257
+ ax_m = axes_mid[f'{ndx}']
258
+
259
+ traces_s = post_activations_history[:,:,which_neurons_synch[ndx]].T
260
+ traces_m = post_activations_history[:,:,which_neurons_mid[ndx]].T
261
+ c_s = palette[np.random.randint(0, 8)]
262
+ c_m = mid_colours[ndx]
263
+
264
+ for traces_s_here, traces_m_here in zip(traces_s, traces_m):
265
+ ax_s.plot(np.arange(len(traces_s_here)), traces_s_here, linestyle='-', color=c_s, alpha=0.05, linewidth=0.6)
266
+ ax_m.plot(np.arange(len(traces_m_here)), traces_m_here, linestyle='-', color=c_m, alpha=0.05, linewidth=0.6)
267
+
268
+
269
+ ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color='white', alpha=1, linewidth=2.5)
270
+ ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color=c_s, alpha=1, linewidth=1.3)
271
+ ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color='black', alpha=1, linewidth=0.3)
272
+ ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color='white', alpha=1, linewidth=2.5)
273
+ ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color=c_m, alpha=1, linewidth=1.3)
274
+ ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color='black', alpha=1, linewidth=0.3)
275
+ if axis_snap and np.all(np.isfinite(traces_s[0])):
276
+ ax_s.set_ylim([np.min(traces_s[0])-np.ptp(traces_s[0])*0.05, np.max(traces_s[0])+np.ptp(traces_s[0])*0.05])
277
+ ax_m.set_ylim([np.min(traces_m[0])-np.ptp(traces_m[0])*0.05, np.max(traces_m[0])+np.ptp(traces_m[0])*0.05])
278
+
279
+
280
+ ax_s.grid(False)
281
+ ax_m.grid(False)
282
+ ax_s.set_xlim([0, len(traces_s[0])-1])
283
+ ax_m.set_xlim([0, len(traces_m[0])-1])
284
+
285
+ ax_s.set_xticklabels([])
286
+ ax_s.set_yticklabels([])
287
+
288
+ ax_m.set_xticklabels([])
289
+ ax_m.set_yticklabels([])
290
+ pbar_inner.update(1)
291
+ fig_synch.tight_layout(pad=0.05)
292
+ fig_mid.tight_layout(pad=0.05)
293
+ if save_location is not None:
294
+ fig_synch.savefig(f'{save_location}/neural_dynamics_synch.pdf', dpi=200)
295
+ fig_synch.savefig(f'{save_location}/neural_dynamics_synch.png', dpi=200)
296
+ fig_mid.savefig(f'{save_location}/neural_dynamics_other.pdf', dpi=200)
297
+ fig_mid.savefig(f'{save_location}/neural_dynamics_other.png', dpi=200)
298
+ plt.close(fig_synch)
299
+ plt.close(fig_mid)
300
+ return fig_synch, fig_mid, which_neurons_mid, mid_colours
301
+
302
+
303
+
304
+ def make_classification_gif(image, target, predictions, certainties, post_activations, attention_tracking, class_labels, save_location):
305
+ cmap_viridis = sns.color_palette('viridis', as_cmap=True)
306
+ cmap_spectral = sns.color_palette("Spectral", as_cmap=True)
307
+ figscale = 2
308
+ with tqdm(total=post_activations.shape[0]+1, initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
309
+ pbar_inner.set_description('Computing UMAP')
310
+
311
+
312
+ low = np.percentile(post_activations, 1, axis=0, keepdims=True)
313
+ high = np.percentile(post_activations, 99, axis=0, keepdims=True)
314
+ post_activations_normed = np.clip((post_activations - low)/(high - low), 0, 1)
315
+ metric = 'cosine'
316
+ reducer = umap.UMAP(n_components=2,
317
+ n_neighbors=100,
318
+ min_dist=3,
319
+ spread=3.0,
320
+ metric=metric,
321
+ random_state=None,
322
+ # low_memory=True,
323
+ ) if post_activations.shape[-1] > 2048 else umap.UMAP(n_components=2,
324
+ n_neighbors=20,
325
+ min_dist=1,
326
+ spread=1.0,
327
+ metric=metric,
328
+ random_state=None,
329
+ # low_memory=True,
330
+ )
331
+ positions = reducer.fit_transform(post_activations_normed.T)
332
+
333
+ x_umap = positions[:, 0]
334
+ y_umap = positions[:, 1]
335
+
336
+ pbar_inner.update(1)
337
+ pbar_inner.set_description('Iterating through to build frames')
338
+
339
+
340
+
341
+ frames = []
342
+ route_steps = {}
343
+ route_colours = []
344
+
345
+ n_steps = len(post_activations)
346
+ n_heads = attention_tracking.shape[1]
347
+ step_linspace = np.linspace(0, 1, n_steps)
348
+
349
+ for stepi in np.arange(0, n_steps, 1):
350
+ pbar_inner.set_description('Making frames for gif')
351
+
352
+
353
+ attention_now = attention_tracking[max(0, stepi-5):stepi+1].mean(0) # Average over up to the last 6 ticks to smooth the attention maps
354
+ # attention_now[:,0,0] = 0 # Corners can be weird looking
355
+ # attention_now[:,0,-1] = 0
356
+ # attention_now[:,-1,0] = 0
357
+ # attention_now[:,-1,-1] = 0
358
+ # attention_now = (attention_tracking[:stepi+1, 0] * decay).sum(0)/(decay.sum(0))
359
+ certainties_now = certainties[1, :stepi+1]
360
+ attention_interp = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), image.shape[:2], mode='bilinear')[0]
361
+ attention_interp = (attention_interp.flatten(1) - attention_interp.flatten(1).min(-1, keepdim=True)[0])/(attention_interp.flatten(1).max(-1, keepdim=True)[0] - attention_interp.flatten(1).min(-1, keepdim=True)[0])
362
+ attention_interp = attention_interp.reshape(n_heads, image.shape[0], image.shape[1])
363
+
364
+
365
+ colour = list(cmap_spectral(step_linspace[stepi]))
366
+ route_colours.append(colour)
367
+ for headi in range(min(8, n_heads)):
368
+ com_attn = np.copy(attention_interp[headi])
369
+ com_attn[com_attn < np.percentile(com_attn, 97)] = 0.0
370
+ if headi not in route_steps:
371
+ A = attention_interp[headi].detach().cpu().numpy()
372
+ centres, areas = find_island_centers(A, threshold=0.7)
373
+ route_steps[headi] = [centres[np.argmax(areas)]]
374
+ else:
375
+ A = attention_interp[headi].detach().cpu().numpy()
376
+ centres, areas = find_island_centers(A, threshold=0.7)
377
+ route_steps[headi] = route_steps[headi] + [centres[np.argmax(areas)]]
378
+
379
+ mosaic = [['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay'],
380
+ ['head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],
381
+ ['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay'],
382
+ ['head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],
383
+ ['probabilities', 'probabilities','certainty', 'certainty'],
384
+ ['umap', 'umap', 'umap', 'umap'],
385
+ ['umap', 'umap', 'umap', 'umap'],
386
+ ['umap', 'umap', 'umap', 'umap'],
387
+
388
+ ]
389
+
390
+
391
+ img_aspect = image.shape[0]/image.shape[1]
392
+ # print(img_aspect)
393
+ aspect_ratio = (4*figscale, 8*figscale*img_aspect)
394
+ fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
395
+ for ax in axes.values():
396
+ ax.axis('off')
397
+
398
+
399
+ axes['certainty'].plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1, label='1-(normalised entropy)')
400
+ for ii, (x, y) in enumerate(zip(np.arange(len(certainties_now)), certainties_now)):
401
+ is_correct = predictions[:, ii].argmax(-1)==target
402
+ if is_correct: axes['certainty'].axvspan(ii, ii + 1, facecolor='limegreen', edgecolor=None, lw=0, alpha=0.3)
403
+ else:
404
+ axes['certainty'].axvspan(ii, ii + 1, facecolor='orchid', edgecolor=None, lw=0, alpha=0.3)
405
+ axes['certainty'].plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
406
+ axes['certainty'].axis('off')
407
+ axes['certainty'].set_ylim([-0.05, 1.05])
408
+ axes['certainty'].set_xlim([0, certainties.shape[-1]+1])
409
+
410
+ ps = torch.softmax(torch.from_numpy(predictions[:, stepi]), -1)
411
+ k = 15 if len(class_labels) > 15 else len(class_labels)
412
+ topk = torch.topk (ps, k, dim = 0, largest=True).indices.detach().cpu().numpy()
413
+ top_classes = np.array(class_labels)[topk]
414
+ true_class = target
415
+ colours = [('b' if ci != true_class else 'g') for ci in topk]
416
+ bar_heights = ps[topk].detach().cpu().numpy()
417
+
418
+
419
+ axes['probabilities'].bar(np.arange(len(bar_heights))[::-1], bar_heights, color=np.array(colours), alpha=1)
420
+ axes['probabilities'].set_ylim([0, 1])
421
+ axes['probabilities'].axis('off')
422
+
423
+
424
+ for i, (name) in enumerate(top_classes):
425
+ prob = ps[i]
426
+ is_correct = name==class_labels[true_class]
427
+ fg_color = 'darkgreen' if is_correct else 'crimson'
428
+ text_str = f'{name[:40]}'
429
+ axes['probabilities'].text(
430
+ 0.05,
431
+ 0.95 - i * 0.055, # Adjust vertical position for each line
432
+ text_str,
433
+ transform=axes['probabilities'].transAxes,
434
+ verticalalignment='top',
435
+ fontsize=8, # Increased font size
436
+ color=fg_color,
437
+ alpha=0.5,
438
+ path_effects=[
439
+ patheffects.Stroke(linewidth=3, foreground='aliceblue'),
440
+ patheffects.Normal()
441
+ ])
442
+
443
+
444
+
445
+ attention_now = attention_tracking[max(0, stepi-5):stepi+1].mean(0) # Average over up to the last 6 ticks to smooth the attention maps
446
+ # attention_now = (attention_tracking[:stepi+1, 0] * decay).sum(0)/(decay.sum(0))
447
+ certainties_now = certainties[1, :stepi+1]
448
+ attention_interp = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), image.shape[:2], mode='nearest')[0]
449
+ attention_interp = (attention_interp.flatten(1) - attention_interp.flatten(1).min(-1, keepdim=True)[0])/(attention_interp.flatten(1).max(-1, keepdim=True)[0] - attention_interp.flatten(1).min(-1, keepdim=True)[0])
450
+ attention_interp = attention_interp.reshape(n_heads, image.shape[0], image.shape[1])
451
+
452
+ for hi in range(min(8, n_heads)):
453
+ ax = axes[f'head_{hi}']
454
+ img_to_plot = cmap_viridis(attention_interp[hi].detach().cpu().numpy())
455
+ ax.imshow(img_to_plot)
456
+
457
+ ax_overlay = axes[f'head_{hi}_overlay']
458
+
459
+ these_route_steps = route_steps[hi]
460
+ y_coords, x_coords = zip(*these_route_steps)
461
+ y_coords = image.shape[-2] - np.array(list(y_coords))-1
462
+
463
+ ax_overlay.imshow(np.flip(image, axis=0), origin='lower')
464
+ # ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
465
+ arrow_scale = 1.5 if image.shape[0] > 32 else 0.8
466
+ for i in range(len(these_route_steps)-1):
467
+ dx = x_coords[i+1] - x_coords[i]
468
+ dy = y_coords[i+1] - y_coords[i]
469
+
470
+ ax_overlay.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=1.6*arrow_scale*1.3, head_width=1.9*arrow_scale*1.3, head_length=1.4*arrow_scale*1.45, fc='white', ec='white', length_includes_head = True, alpha=1)
471
+ ax_overlay.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=1.6*arrow_scale, head_width=1.9*arrow_scale, head_length=1.4*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
472
+
473
+ ax_overlay.set_xlim([0,image.shape[1]-1])
474
+ ax_overlay.set_ylim([0,image.shape[0]-1])
475
+ ax_overlay.axis('off')
476
+
477
+
478
+ z = post_activations_normed[stepi]
479
+
480
+ axes['umap'].scatter(x_umap, y_umap, s=30, c=cmap_spectral(z))
481
+
482
+ fig.tight_layout(pad=0.1)
483
+
484
+
485
+
486
+ canvas = fig.canvas
487
+ canvas.draw()
488
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
489
+ image_numpy = (image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3])
490
+ frames.append(image_numpy)
491
+ plt.close(fig)
492
+ pbar_inner.update(1)
493
+ pbar_inner.set_description('Saving gif')
494
+ imageio.mimsave(save_location, frames, fps=15, loop=100)
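The FFmpeg helper earlier in this file pads frames with `pad=ceil(iw/2)*2:ceil(ih/2)*2` so that libx264 with `yuv420p` receives even dimensions. The same arithmetic can be sketched in plain Python. Note that `even_padded_size` is a hypothetical helper name used only for illustration; it is not part of the repository.

```python
import math


def even_padded_size(width, height):
    """Return the smallest even (width, height) that is >= the input size.

    Mirrors the FFmpeg expression pad=ceil(iw/2)*2:ceil(ih/2)*2.
    """
    return math.ceil(width / 2) * 2, math.ceil(height / 2) * 2


# The example given in the source comment: 1600x1351 becomes 1600x1352.
print(even_padded_size(1600, 1351))
```

Even sizes are returned unchanged, so the filter is a no-op for frames that already satisfy the encoder's constraint.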
tasks/image_classification/scripts/train_cifar10.sh ADDED
@@ -0,0 +1,286 @@
1
+ python -m tasks.image_classification.train \
2
+ --log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \
3
+ --model ctm \
4
+ --dataset cifar10 \
5
+ --d_model 256 \
6
+ --d_input 64 \
7
+ --synapse_depth 5 \
8
+ --heads 16 \
9
+ --n_synch_out 256 \
10
+ --n_synch_action 512 \
11
+ --n_random_pairing_self 0 \
12
+ --neuron_select_type random-pairing \
13
+ --iterations 50 \
14
+ --memory_length 15 \
15
+ --deep_memory \
16
+ --memory_hidden_dims 64 \
17
+ --dropout 0.0 \
18
+ --dropout_nlm 0 \
19
+ --no-do_normalisation \
20
+ --positional_embedding_type none \
21
+ --backbone_type resnet18-1 \
22
+ --training_iterations 600001 \
23
+ --warmup_steps 1000 \
24
+ --use_scheduler \
25
+ --scheduler_type cosine \
26
+ --weight_decay 0.0001 \
27
+ --save_every 1000 \
28
+ --track_every 2000 \
29
+ --n_test_batches 50 \
30
+ --num_workers_train 8 \
31
+ --batch_size 512 \
32
+ --batch_size_test 512 \
33
+ --lr 1e-4 \
34
+ --device 0 \
35
+ --seed 1
36
+
37
+
38
+ python -m tasks.image_classification.train \
39
+ --log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \
40
+ --model ctm \
41
+ --dataset cifar10 \
42
+ --d_model 256 \
43
+ --d_input 64 \
44
+ --synapse_depth 5 \
45
+ --heads 16 \
46
+ --n_synch_out 256 \
47
+ --n_synch_action 512 \
48
+ --n_random_pairing_self 0 \
49
+ --neuron_select_type random-pairing \
50
+ --iterations 50 \
51
+ --memory_length 15 \
52
+ --deep_memory \
53
+ --memory_hidden_dims 64 \
54
+ --dropout 0.0 \
55
+ --dropout_nlm 0 \
56
+ --no-do_normalisation \
57
+ --positional_embedding_type none \
58
+ --backbone_type resnet18-1 \
59
+ --training_iterations 600001 \
60
+ --warmup_steps 1000 \
61
+ --use_scheduler \
62
+ --scheduler_type cosine \
63
+ --weight_decay 0.0001 \
64
+ --save_every 1000 \
65
+ --track_every 2000 \
66
+ --n_test_batches 50 \
67
+ --num_workers_train 8 \
68
+ --batch_size 512 \
69
+ --batch_size_test 512 \
70
+ --lr 1e-4 \
71
+ --device 0 \
72
+ --seed 2
73
+
74
+ python -m tasks.image_classification.train \
75
+ --log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \
76
+ --model ctm \
77
+ --dataset cifar10 \
78
+ --d_model 256 \
79
+ --d_input 64 \
80
+ --synapse_depth 5 \
81
+ --heads 16 \
82
+ --n_synch_out 256 \
83
+ --n_synch_action 512 \
84
+ --n_random_pairing_self 0 \
85
+ --neuron_select_type random-pairing \
86
+ --iterations 50 \
87
+ --memory_length 15 \
88
+ --deep_memory \
89
+ --memory_hidden_dims 64 \
90
+ --dropout 0.0 \
91
+ --dropout_nlm 0 \
92
+ --no-do_normalisation \
93
+ --positional_embedding_type none \
94
+ --backbone_type resnet18-1 \
95
+ --training_iterations 600001 \
96
+ --warmup_steps 1000 \
97
+ --use_scheduler \
98
+ --scheduler_type cosine \
99
+ --weight_decay 0.0001 \
100
+ --save_every 1000 \
101
+ --track_every 2000 \
102
+ --n_test_batches 50 \
103
+ --num_workers_train 8 \
104
+ --batch_size 512 \
105
+ --batch_size_test 512 \
106
+ --lr 1e-4 \
107
+ --device 0 \
108
+ --seed 42
109
+
110
+
111
+
112
+
113
+
114
+
115
+ python -m tasks.image_classification.train \
116
+ --log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \
117
+ --dataset cifar10 \
118
+ --model lstm \
119
+ --num_layers 2 \
120
+ --d_model 256 \
121
+ --d_input 64 \
122
+ --heads 16 \
123
+ --iterations 50 \
124
+ --dropout 0.0 \
125
+ --positional_embedding_type none \
126
+ --backbone_type resnet18-1 \
127
+ --training_iterations 600001 \
128
+ --warmup_steps 2000 \
129
+ --use_scheduler \
130
+ --scheduler_type cosine \
131
+ --weight_decay 0.0001 \
132
+ --save_every 1000 \
133
+ --track_every 2000 \
134
+ --n_test_batches 50 \
135
+ --reload \
136
+ --num_workers_train 8 \
137
+ --batch_size 512 \
138
+ --batch_size_test 512 \
139
+ --lr 1e-4 \
140
+ --device 0 \
141
+ --seed 1 \
142
+ --no-reload
143
+
144
+
145
+ python -m tasks.image_classification.train \
146
+ --log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \
147
+ --dataset cifar10 \
148
+ --model lstm \
149
+ --num_layers 2 \
150
+ --d_model 256 \
151
+ --d_input 64 \
152
+ --heads 16 \
153
+ --iterations 50 \
154
+ --dropout 0.0 \
155
+ --positional_embedding_type none \
156
+ --backbone_type resnet18-1 \
157
+ --training_iterations 600001 \
158
+ --warmup_steps 2000 \
159
+ --use_scheduler \
160
+ --scheduler_type cosine \
161
+ --weight_decay 0.0001 \
162
+ --save_every 1000 \
163
+ --track_every 2000 \
164
+ --n_test_batches 50 \
165
+ --reload \
166
+ --num_workers_train 8 \
167
+ --batch_size 512 \
168
+ --batch_size_test 512 \
169
+ --lr 1e-4 \
170
+ --device 0 \
171
+ --seed 2 \
172
+ --no-reload
173
+
174
+
175
+ python -m tasks.image_classification.train \
176
+ --log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \
177
+ --dataset cifar10 \
178
+ --model lstm \
179
+ --num_layers 2 \
180
+ --d_model 256 \
181
+ --d_input 64 \
182
+ --heads 16 \
183
+ --iterations 50 \
184
+ --dropout 0.0 \
185
+ --positional_embedding_type none \
186
+ --backbone_type resnet18-1 \
187
+ --training_iterations 600001 \
188
+ --warmup_steps 2000 \
189
+ --use_scheduler \
190
+ --scheduler_type cosine \
191
+ --weight_decay 0.0001 \
192
+ --save_every 1000 \
193
+ --track_every 2000 \
194
+ --n_test_batches 50 \
195
+ --reload \
196
+ --num_workers_train 8 \
197
+ --batch_size 512 \
198
+ --batch_size_test 512 \
199
+ --lr 1e-4 \
200
+ --device 0 \
201
+ --seed 42 \
202
+ --no-reload
203
+
204
+
205
+
206
+
207
+
208
+ python -m tasks.image_classification.train \
209
+ --log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=1 \
210
+ --dataset cifar10 \
211
+ --model ff \
212
+ --d_model 256 \
213
+ --memory_hidden_dims 64 \
214
+ --dropout 0.0 \
215
+ --dropout_nlm 0 \
216
+ --backbone_type resnet18-1 \
217
+ --training_iterations 600001 \
218
+ --warmup_steps 1000 \
219
+ --use_scheduler \
220
+ --scheduler_type cosine \
221
+ --weight_decay 0.0001 \
222
+ --save_every 1000 \
223
+ --track_every 2000 \
224
+ --n_test_batches 50 \
225
+ --num_workers_train 8 \
226
+ --batch_size 512 \
227
+ --batch_size_test 512 \
228
+ --lr 1e-4 \
229
+ --device 0 \
230
+ --seed 1
231
+
232
+
233
+ python -m tasks.image_classification.train \
234
+ --log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=2 \
235
+ --dataset cifar10 \
236
+ --model ff \
237
+ --d_model 256 \
238
+ --memory_hidden_dims 64 \
239
+ --dropout 0.0 \
240
+ --dropout_nlm 0 \
241
+ --backbone_type resnet18-1 \
242
+ --training_iterations 600001 \
243
+ --warmup_steps 1000 \
244
+ --use_scheduler \
245
+ --scheduler_type cosine \
246
+ --weight_decay 0.0001 \
247
+ --save_every 1000 \
248
+ --track_every 2000 \
249
+ --n_test_batches 50 \
250
+ --num_workers_train 8 \
251
+ --batch_size 512 \
252
+ --batch_size_test 512 \
253
+ --lr 1e-4 \
254
+ --device 0 \
255
+ --seed 2
256
+
257
+ python -m tasks.image_classification.train \
258
+ --log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=42 \
259
+ --dataset cifar10 \
260
+ --model ff \
261
+ --d_model 256 \
262
+ --memory_hidden_dims 64 \
263
+ --dropout 0.0 \
264
+ --dropout_nlm 0 \
265
+ --backbone_type resnet18-1 \
266
+ --training_iterations 600001 \
267
+ --warmup_steps 1000 \
268
+ --use_scheduler \
269
+ --scheduler_type cosine \
270
+ --weight_decay 0.0001 \
271
+ --save_every 1000 \
272
+ --track_every 2000 \
273
+ --n_test_batches 50 \
274
+ --num_workers_train 8 \
275
+ --batch_size 512 \
276
+ --batch_size_test 512 \
277
+ --lr 1e-4 \
278
+ --device 0 \
279
+ --seed 42
280
+
281
+
282
+
283
+
284
+
285
+
286
+
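The script above repeats one command per model for seeds 1, 2 and 42, changing only `--seed` and the `--seed=` suffix of `--log_dir`. A minimal sketch of generating the CTM commands from a single template is shown below. It assumes only a subset of the flags (the rest are omitted for brevity), and `build_command`, `BASE_FLAGS`, `LOG_ROOT` and `RUN_NAME` are hypothetical names that do not appear in the repository.

```python
import shlex

# Subset of the flags shared by the three CTM runs; the full script passes more.
BASE_FLAGS = [
    "--model", "ctm",
    "--dataset", "cifar10",
    "--d_model", "256",
    "--iterations", "50",
    "--memory_length", "15",
    "--backbone_type", "resnet18-1",
    "--lr", "1e-4",
    "--device", "0",
]
LOG_ROOT = "logs/cifar10-versus-humans/ctm"
RUN_NAME = ("d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing"
            "--iters=50x15--backbone=18-1")


def build_command(seed):
    """Build the argument list for one seed; only the seed-dependent parts vary."""
    return [
        "python", "-m", "tasks.image_classification.train",
        "--log_dir", f"{LOG_ROOT}/{RUN_NAME}--seed={seed}",
        *BASE_FLAGS,
        "--seed", str(seed),
    ]


for seed in (1, 2, 42):
    # shlex.join quotes arguments where needed, giving a copy-pasteable line.
    print(shlex.join(build_command(seed)))
```

Building the argument list in one place also avoids line-continuation slips, since no trailing backslashes are involved.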
tasks/image_classification/scripts/train_imagenet.sh ADDED
@@ -0,0 +1,38 @@
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed \
+     --log_dir logs/imagenet/d=4096--i=1024--heads=16--sd=8--nlm=64--synch=8192-2048-32-h=64-random-pairing--iters=50x25--backbone=152x4 \
+     --model ctm \
+     --dataset imagenet \
+     --d_model 4096 \
+     --d_input 1024 \
+     --synapse_depth 8 \
+     --heads 16 \
+     --n_synch_out 8196 \
+     --n_synch_action 2048 \
+     --n_random_pairing_self 32 \
+     --neuron_select_type random-pairing \
+     --iterations 50 \
+     --memory_length 25 \
+     --deep_memory \
+     --memory_hidden_dims 64 \
+     --dropout 0.2 \
+     --dropout_nlm 0 \
+     --no-do_normalisation \
+     --positional_embedding_type none \
+     --backbone_type resnet152-4 \
+     --batch_size 64 \
+     --batch_size_test 64 \
+     --n_test_batches 200 \
+     --lr 5e-4 \
+     --gradient_clipping 20 \
+     --training_iterations 500001 \
+     --save_every 1000 \
+     --track_every 5000 \
+     --warmup_steps 10000 \
+     --use_scheduler \
+     --scheduler_type cosine \
+     --weight_decay 0.0 \
+     --seed 1 \
+     --use_amp \
+     --reload \
+     --num_workers_train 8 \
+     --use_custom_sampler
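The `--log_dir` names in these scripts encode hyperparameters, including a `synch=A-B-C` field. Judging from the CIFAR-10 script (`synch=256-512-0` next to `--n_synch_out 256`, `--n_synch_action 512`, `--n_random_pairing_self 0`), the three numbers appear to map to those flags in that order; that mapping is an inference, not something the source states. Under that assumption, the sketch below compares a run name against the flag values. `synch_fields` is a hypothetical helper, not part of the repository.

```python
import re


def synch_fields(log_dir):
    """Extract the three integers from a 'synch=A-B-C' field in a run name."""
    match = re.search(r"synch=(\d+)-(\d+)-(\d+)", log_dir)
    if match is None:
        raise ValueError("no synch=A-B-C field found")
    return tuple(int(group) for group in match.groups())


# Values copied from the ImageNet command above.
log_dir = ("logs/imagenet/d=4096--i=1024--heads=16--sd=8--nlm=64"
           "--synch=8192-2048-32-h=64-random-pairing--iters=50x25--backbone=152x4")
flags = {"n_synch_out": 8196, "n_synch_action": 2048, "n_random_pairing_self": 32}

# Assumed order: n_synch_out, n_synch_action, n_random_pairing_self.
named = dict(zip(flags, synch_fields(log_dir)))
for key, value in flags.items():
    status = "ok" if named[key] == value else f"mismatch (name says {named[key]})"
    print(f"{key}={value}: {status}")
```

For the ImageNet command this reports that `--n_synch_out 8196` differs from the `8192` in the directory name. The source does not say which value is intended, so the script is left as written.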
tasks/image_classification/train.py ADDED
@@ -0,0 +1,690 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import seaborn as sns
8
+ sns.set_style('darkgrid')
9
+ import torch
10
+ if torch.cuda.is_available():
11
+ # Allow reduced-precision (e.g. TF32) float32 matmuls for speed
12
+ torch.set_float32_matmul_precision('high')
13
+ import torch.nn as nn
14
+ from tqdm.auto import tqdm
15
+
16
+ from data.custom_datasets import ImageNet
17
+ from torchvision import datasets
18
+ from torchvision import transforms
19
+ from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
20
+ from models.ctm import ContinuousThoughtMachine
21
+ from models.lstm import LSTMBaseline
22
+ from models.ff import FFBaseline
23
+ from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
24
+ from utils.housekeeping import set_seed, zip_python_code
25
+ from utils.losses import image_classification_loss # Used by CTM, LSTM
26
+ from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
27
+
28
+ from autoclip.torch import QuantileClip
29
+
30
+ import gc
31
+ import torchvision
32
+ torchvision.disable_beta_transforms_warning()
33
+
34
+
35
+ import warnings
36
+ warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
37
+ warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
38
+ warnings.filterwarnings(
39
+ "ignore",
40
+ "Corrupt EXIF data",
41
+ UserWarning,
42
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
43
+ )
44
+ warnings.filterwarnings(
45
+ "ignore",
46
+ "UserWarning: Metadata Warning",
47
+ UserWarning,
48
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
49
+ )
50
+ warnings.filterwarnings(
51
+ "ignore",
52
+ "UserWarning: Truncated File Read",
53
+ UserWarning,
54
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
55
+ )
56
+
57
+
58
+ def parse_args():
59
+ parser = argparse.ArgumentParser()
60
+
61
+ # Model Selection
62
+ parser.add_argument('--model', type=str, default='ctm', choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
63
+
64
+ # Model Architecture
65
+ # Common
66
+ parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
67
+ parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
68
+ parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone featuriser.')
69
+ # CTM / LSTM specific
70
+ parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
71
+ parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
72
+ parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
73
+ parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
74
+ choices=['none',
75
+ 'learnable-fourier',
76
+ 'multi-learnable-fourier',
77
+ 'custom-rotational'])
78
+ # CTM specific
79
+ parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
80
+ parser.add_argument('--n_synch_out', type=int, default=512, help='Number of neurons to use for output synch (CTM only).')
81
+ parser.add_argument('--n_synch_action', type=int, default=512, help='Number of neurons to use for observation/action synch (CTM only).')
82
+ parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
83
+ parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
84
+ parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMs (CTM only).')
85
+ parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
86
+ parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory if using deep memory (CTM only).')
87
+ parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
88
+ parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
89
+ # LSTM specific
90
+ parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
91
+
92
+ # Training
93
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training.')
94
+ parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing.')
95
+ parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
96
+ parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
97
+ parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
98
+ parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
99
+ parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
100
+ parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
101
+ parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
102
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
103
+ parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
104
+ parser.add_argument('--gradient_clipping', type=float, default=-1, help='Maximum gradient norm for clipping (-1 to disable).')
105
+ parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components (backbone, synapses if CTM).')
106
+ parser.add_argument('--num_workers_train', type=int, default=1, help='Num workers training.')
107
+
108
+ # Housekeeping
109
+ parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
110
+ parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
111
+ parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
112
+ parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
113
+ parser.add_argument('--seed', type=int, default=412, help='Random seed.')
114
+ parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
115
+ parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
116
+ parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
117
+ parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
118
+ parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
119
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
120
+ parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
121
+
122
+
123
+ args = parser.parse_args()
124
+ return args
125
+
126
+
127
+ def get_dataset(dataset, root):
128
+ if dataset=='imagenet':
129
+ dataset_mean = [0.485, 0.456, 0.406]
130
+ dataset_std = [0.229, 0.224, 0.225]
131
+
132
+ normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
133
+ train_transform = transforms.Compose([
134
+ transforms.RandomResizedCrop(224),
135
+ transforms.RandomHorizontalFlip(),
136
+ transforms.ToTensor(),
137
+ normalize])
138
+ test_transform = transforms.Compose([
139
+ transforms.Resize(256),
140
+ transforms.CenterCrop(224),
141
+ transforms.ToTensor(),
142
+ normalize])
143
+
144
+ class_labels = list(IMAGENET2012_CLASSES.values())
145
+
146
+ train_data = ImageNet(which_split='train', transform=train_transform)
147
+ test_data = ImageNet(which_split='validation', transform=test_transform)
148
+ elif dataset=='cifar10':
149
+ dataset_mean = [0.49139968, 0.48215827, 0.44653124]
150
+ dataset_std = [0.24703233, 0.24348505, 0.26158768]
151
+ normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
152
+ train_transform = transforms.Compose(
153
+ [transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
154
+ transforms.ToTensor(),
155
+ normalize,
156
+ ])
157
+
158
+ test_transform = transforms.Compose(
159
+ [transforms.ToTensor(),
160
+ normalize,
161
+ ])
162
+ train_data = datasets.CIFAR10(root, train=True, transform=train_transform, download=True)
163
+ test_data = datasets.CIFAR10(root, train=False, transform=test_transform, download=True)
164
+ class_labels = ['air', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
165
+ elif dataset=='cifar100':
166
+ dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
167
+ dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
168
+ normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
169
+
170
+ train_transform = transforms.Compose(
171
+ [transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
172
+ transforms.ToTensor(),
173
+ normalize,
174
+ ])
175
+ test_transform = transforms.Compose(
176
+ [transforms.ToTensor(),
177
+ normalize,
178
+ ])
179
+ train_data = datasets.CIFAR100(root, train=True, transform=train_transform, download=True)
180
+ test_data = datasets.CIFAR100(root, train=False, transform=test_transform, download=True)
181
+ idx_order = np.argsort(np.array(list(train_data.class_to_idx.values())))
182
+ class_labels = list(np.array(list(train_data.class_to_idx.keys()))[idx_order])
183
+ else:
184
+ raise NotImplementedError
185
+
186
+ return train_data, test_data, class_labels, dataset_mean, dataset_std
187
+
188
+
189
+
190
+ if __name__=='__main__':
191
+
192
+ # Housekeeping
193
+ args = parse_args()
194
+
195
+ set_seed(args.seed, False)
196
+ os.makedirs(args.log_dir, exist_ok=True)
197
+
198
+ assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
199
+
200
+ # Data
201
+ train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)
202
+
203
+ num_workers_test = 1 # Defaulting to 1, change if needed
204
+ trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train)
205
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)
206
+
207
+ prediction_reshaper = [-1] # Problem specific
208
+ args.out_dims = len(class_labels)
209
+
210
+ # For total reproducibility
211
+ zip_python_code(f'{args.log_dir}/repo_state.zip')
212
+ with open(f'{args.log_dir}/args.txt', 'w') as f:
213
+ print(args, file=f)
214
+
215
+ # Configure device string (support MPS on macOS)
216
+ if args.device[0] != -1:
217
+ device = f'cuda:{args.device[0]}'
218
+ elif torch.backends.mps.is_available():
219
+ device = 'mps'
220
+ else:
221
+ device = 'cpu'
222
+ print(f'Running model {args.model} on {device}')
223
+
224
+ # Build model conditionally
225
+ model = None
226
+ if args.model == 'ctm':
227
+ model = ContinuousThoughtMachine(
228
+ iterations=args.iterations,
229
+ d_model=args.d_model,
230
+ d_input=args.d_input,
231
+ heads=args.heads,
232
+ n_synch_out=args.n_synch_out,
233
+ n_synch_action=args.n_synch_action,
234
+ synapse_depth=args.synapse_depth,
235
+ memory_length=args.memory_length,
236
+ deep_nlms=args.deep_memory,
237
+ memory_hidden_dims=args.memory_hidden_dims,
238
+ do_layernorm_nlm=args.do_normalisation,
239
+ backbone_type=args.backbone_type,
240
+ positional_embedding_type=args.positional_embedding_type,
241
+ out_dims=args.out_dims,
242
+ prediction_reshaper=prediction_reshaper,
243
+ dropout=args.dropout,
244
+ dropout_nlm=args.dropout_nlm,
245
+ neuron_select_type=args.neuron_select_type,
246
+ n_random_pairing_self=args.n_random_pairing_self,
247
+ ).to(device)
248
+ elif args.model == 'lstm':
249
+ model = LSTMBaseline(
250
+ num_layers=args.num_layers,
251
+ iterations=args.iterations,
252
+ d_model=args.d_model,
253
+ d_input=args.d_input,
254
+ heads=args.heads,
255
+ backbone_type=args.backbone_type,
256
+ positional_embedding_type=args.positional_embedding_type,
257
+ out_dims=args.out_dims,
258
+ prediction_reshaper=prediction_reshaper,
259
+ dropout=args.dropout,
260
+ ).to(device)
261
+ elif args.model == 'ff':
262
+ model = FFBaseline(
263
+ d_model=args.d_model,
264
+ backbone_type=args.backbone_type,
265
+ out_dims=args.out_dims,
266
+ dropout=args.dropout,
267
+ ).to(device)
268
+ else:
269
+ raise ValueError(f"Unknown model type: {args.model}")
270
+
271
+
272
+ # For lazy modules so that we can get param count
273
+ pseudo_inputs = train_data[0][0].unsqueeze(0).to(device)
274
+ model(pseudo_inputs)
275
+
276
+ model.train()
277
+
278
+
279
+ print(f'Total params: {sum(p.numel() for p in model.parameters())}')
280
+ decay_params = []
281
+ no_decay_params = []
282
+ no_decay_names = []
283
+ for name, param in model.named_parameters():
284
+ if not param.requires_grad:
285
+ continue # Skip parameters that don't require gradients
286
+ if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
287
+ no_decay_params.append(param)
288
+ no_decay_names.append(name)
289
+ else:
290
+ decay_params.append(param)
291
+ if len(no_decay_names):
292
+ print(f'WARNING, excluding: {no_decay_names}')
293
+
294
+ # Optimizer and scheduler (Common setup)
295
+ if len(no_decay_names) and args.weight_decay!=0:
296
+ optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
297
+ {'params': no_decay_params, 'weight_decay':0}],
298
+ lr=args.lr,
299
+ eps=1e-8 if not args.use_amp else 1e-6)
300
+ else:
301
+ optimizer = torch.optim.AdamW(model.parameters(),
302
+ lr=args.lr,
303
+ eps=1e-8 if not args.use_amp else 1e-6,
304
+ weight_decay=args.weight_decay)
305
+
306
+
307
+ warmup_schedule = warmup(args.warmup_steps)
308
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
309
+ if args.use_scheduler:
310
+ if args.scheduler_type == 'multistep':
311
+ scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
312
+ elif args.scheduler_type == 'cosine':
313
+ scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
314
+ else:
315
+ raise NotImplementedError
316
+
317
+
318
+ # Metrics tracking
319
+ start_iter = 0
320
+ train_losses = []
321
+ test_losses = []
322
+ train_accuracies = []
323
+ test_accuracies = []
324
+ iters = []
325
+ # Conditional metrics for CTM/LSTM
326
+ train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
327
+ test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
328
+
329
+ scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
330
+
331
+ # Reloading logic
332
+ if args.reload:
333
+ checkpoint_path = f'{args.log_dir}/checkpoint.pt'
334
+ if os.path.isfile(checkpoint_path):
335
+ print(f'Reloading from: {checkpoint_path}')
336
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
337
+ if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
338
+ load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
339
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
340
+
341
+ if not args.reload_model_only:
342
+ print('Reloading optimizer etc.')
343
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
344
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
345
+ scaler.load_state_dict(checkpoint['scaler_state_dict'])
346
+ start_iter = checkpoint['iteration']
347
+ # Load common metrics
348
+ train_losses = checkpoint['train_losses']
349
+ test_losses = checkpoint['test_losses']
350
+ train_accuracies = checkpoint['train_accuracies']
351
+ test_accuracies = checkpoint['test_accuracies']
352
+ iters = checkpoint['iters']
353
+
354
+ # Load conditional metrics if they exist in checkpoint and are expected for current model
355
+ if args.model in ['ctm', 'lstm']:
356
+ train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
357
+ test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
358
+
359
+ else:
360
+ print('Only reloading model!')
361
+
362
+ if 'torch_rng_state' in checkpoint:
363
+ # Reset seeds
364
+ torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
365
+ np.random.set_state(checkpoint['numpy_rng_state'])
366
+ random.setstate(checkpoint['random_rng_state'])
367
+
368
+ del checkpoint
369
+ gc.collect()
370
+ if torch.cuda.is_available():
371
+ torch.cuda.empty_cache()
372
+
373
+ # Conditional Compilation
374
+ if args.do_compile:
375
+ print('Compiling...')
376
+ if hasattr(model, 'backbone'):
377
+ model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
378
+
379
+ # Compile synapses only for CTM
380
+ if args.model == 'ctm':
381
+ model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
382
+
383
+ # Training
384
+ iterator = iter(trainloader)
385
+
386
+
387
+ with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
388
+ for bi in range(start_iter, args.training_iterations):
389
+ current_lr = optimizer.param_groups[-1]['lr']
390
+
391
+ try:
392
+ inputs, targets = next(iterator)
393
+ except StopIteration:
394
+ iterator = iter(trainloader)
395
+ inputs, targets = next(iterator)
396
+
397
+ inputs = inputs.to(device)
398
+ targets = targets.to(device)
399
+
400
+ loss = None
401
+ accuracy = None
402
+ # Model-specific forward and loss calculation
403
+ with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
404
+ if args.do_compile: # CUDAGraph marking for clean compile
405
+ torch.compiler.cudagraph_mark_step_begin()
406
+
407
+ if args.model == 'ctm':
408
+ predictions, certainties, synchronisation = model(inputs)
409
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
410
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
411
+ pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
412
+
413
+ elif args.model == 'lstm':
414
+ predictions, certainties, synchronisation = model(inputs)
415
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
416
+ # use_most_certain=True is passed here, so where_most_certain indexes each sample's most-certain tick; it would only be fixed at -1 (the last tick) with use_most_certain=False
417
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
418
+ pbar_desc = f'LSTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
419
+
420
+ elif args.model == 'ff':
421
+ predictions = model(inputs)
422
+ loss = nn.CrossEntropyLoss()(predictions, targets)
423
+ accuracy = (predictions.argmax(1) == targets).float().mean().item()
424
+ pbar_desc = f'FF Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
425
+
426
+ scaler.scale(loss).backward()
427
+
428
+ if args.gradient_clipping!=-1:
429
+ scaler.unscale_(optimizer)
430
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
431
+
432
+ scaler.step(optimizer)
433
+ scaler.update()
434
+ optimizer.zero_grad(set_to_none=True)
435
+ scheduler.step()
436
+
437
+ pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
438
+
439
+
440
+ # Metrics tracking and plotting (conditional logic needed)
441
+ if (bi % args.track_every == 0 or bi == args.warmup_steps) and (bi != 0 or args.reload_model_only):
442
+
443
+ iters.append(bi)
444
+ current_train_losses = []
445
+ current_test_losses = []
446
+ current_train_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
447
+ current_test_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
448
+ current_train_accuracies_most_certain = [] # Only for CTM/LSTM
449
+ current_test_accuracies_most_certain = [] # Only for CTM/LSTM
450
+
451
+
452
+ # Reset BN stats using train mode
453
+ pbar.set_description('Resetting BN')
454
+ model.train()
455
+ for module in model.modules():
456
+ if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
457
+ module.reset_running_stats()
458
+
459
+ pbar.set_description('Tracking: Computing TRAIN metrics')
460
+ with torch.no_grad(): # Should use inference_mode? CTM/LSTM scripts used no_grad
461
+ loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
462
+ all_targets_list = []
463
+ all_predictions_list = [] # List to store raw predictions (B, C, T) or (B, C)
464
+ all_predictions_most_certain_list = [] # Only for CTM/LSTM
465
+ all_losses = []
466
+
467
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
468
+ for inferi, (inputs, targets) in enumerate(loader):
469
+ inputs = inputs.to(device)
470
+ targets = targets.to(device)
471
+ all_targets_list.append(targets.detach().cpu().numpy())
472
+
473
+ # Model-specific forward and loss for evaluation
474
+ if args.model == 'ctm':
475
+ these_predictions, certainties, _ = model(inputs)
476
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
477
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
478
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
479
+
480
+ elif args.model == 'lstm':
481
+ these_predictions, certainties, _ = model(inputs)
482
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
483
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
484
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
485
+
486
+ elif args.model == 'ff':
487
+ these_predictions = model(inputs)
488
+ loss = nn.CrossEntropyLoss()(these_predictions, targets)
489
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B,)
490
+
491
+ all_losses.append(loss.item())
492
+
493
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1 : break # Check condition >= N-1
494
+ pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
495
+ pbar_inner.update(1)
496
+
497
+ all_targets = np.concatenate(all_targets_list)
498
+ all_predictions = np.concatenate(all_predictions_list) # Shape (N, T) or (N,)
499
+ train_losses.append(np.mean(all_losses))
500
+
501
+ if args.model in ['ctm', 'lstm']:
502
+ # Accuracies per tick for CTM/LSTM
503
+ current_train_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0) # Mean over batch dim -> Shape (T,)
504
+ train_accuracies.append(current_train_accuracies)
505
+ # Most certain accuracy
506
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
507
+ current_train_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
508
+ train_accuracies_most_certain.append(current_train_accuracies_most_certain)
509
+ else: # FF
510
+ current_train_accuracies = (all_targets == all_predictions).mean() # Shape scalar
511
+ train_accuracies.append(current_train_accuracies)
512
+
513
+ del these_predictions
514
+
515
+
516
+ # Switch to eval mode for test metrics (fixed BN stats)
517
+ model.eval()
518
+ pbar.set_description('Tracking: Computing TEST metrics')
519
+ with torch.inference_mode(): # Use inference_mode for test eval
520
+ loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
521
+ all_targets_list = []
522
+ all_predictions_list = []
523
+ all_predictions_most_certain_list = [] # Only for CTM/LSTM
524
+ all_losses = []
525
+
526
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
527
+ for inferi, (inputs, targets) in enumerate(loader):
528
+ inputs = inputs.to(device)
529
+ targets = targets.to(device)
530
+ all_targets_list.append(targets.detach().cpu().numpy())
531
+
532
+ # Model-specific forward and loss for evaluation
533
+ if args.model == 'ctm':
534
+ these_predictions, certainties, _ = model(inputs)
535
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
536
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
537
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
538
+
539
+ elif args.model == 'lstm':
540
+ these_predictions, certainties, _ = model(inputs)
541
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
542
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
543
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
544
+
545
+ elif args.model == 'ff':
546
+ these_predictions = model(inputs)
547
+ loss = nn.CrossEntropyLoss()(these_predictions, targets)
548
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
549
+
550
+ all_losses.append(loss.item())
551
+
552
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
553
+ pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
554
+ pbar_inner.update(1)
555
+
556
+ all_targets = np.concatenate(all_targets_list)
557
+ all_predictions = np.concatenate(all_predictions_list)
558
+ test_losses.append(np.mean(all_losses))
559
+
560
+ if args.model in ['ctm', 'lstm']:
561
+ current_test_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0)
562
+ test_accuracies.append(current_test_accuracies)
563
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
564
+ current_test_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
565
+ test_accuracies_most_certain.append(current_test_accuracies_most_certain)
566
+ else: # FF
567
+ current_test_accuracies = (all_targets == all_predictions).mean()
568
+ test_accuracies.append(current_test_accuracies)
569
+
570
+ # Plotting (conditional)
571
+ figacc = plt.figure(figsize=(10, 10))
572
+ axacc_train = figacc.add_subplot(211)
573
+ axacc_test = figacc.add_subplot(212)
574
+ cm = sns.color_palette("viridis", as_cmap=True)
575
+
576
+ if args.model in ['ctm', 'lstm']:
577
+ # Plot per-tick accuracy for CTM/LSTM
578
+ train_acc_arr = np.array(train_accuracies) # Shape (N_iters, T)
579
+ test_acc_arr = np.array(test_accuracies) # Shape (N_iters, T)
580
+ num_ticks = train_acc_arr.shape[1]
581
+ for ti in range(num_ticks):
582
+ axacc_train.plot(iters, train_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
583
+ axacc_test.plot(iters, test_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
584
+ # Plot most certain accuracy
585
+ axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
586
+ axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
587
+ else: # FF
588
+ axacc_train.plot(iters, train_accuracies, 'k-', alpha=0.7, label='Accuracy') # Simple line
589
+ axacc_test.plot(iters, test_accuracies, 'k-', alpha=0.7, label='Accuracy')
590
+
591
+ axacc_train.set_title('Train Accuracy')
592
+ axacc_test.set_title('Test Accuracy')
593
+ axacc_train.legend(loc='lower right')
594
+ axacc_test.legend(loc='lower right')
595
+ axacc_train.set_xlim([0, args.training_iterations])
596
+ axacc_test.set_xlim([0, args.training_iterations])
597
+ if args.dataset=='cifar10':
598
+ axacc_train.set_ylim([0.75, 1])
599
+ axacc_test.set_ylim([0.75, 1])
600
+
601
+
602
+
603
+ figacc.tight_layout()
604
+ figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
605
+ plt.close(figacc)
606
+
607
+ figloss = plt.figure(figsize=(10, 5))
608
+ axloss = figloss.add_subplot(111)
609
+ axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
610
+ axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
611
+ axloss.legend(loc='upper right')
612
+ axloss.set_xlim([0, args.training_iterations])
613
+ axloss.set_ylim(bottom=0)
614
+
615
+ figloss.tight_layout()
616
+ figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
617
+ plt.close(figloss)
618
+
619
+ # Conditional Visualization (Only for CTM/LSTM)
620
+ if args.model in ['ctm', 'lstm']:
621
+ try: # For safety
622
+ inputs_viz, targets_viz = next(iter(testloader)) # Get a fresh batch
623
+ inputs_viz = inputs_viz.to(device)
624
+ targets_viz = targets_viz.to(device)
625
+
626
+ pbar.set_description('Tracking: Processing test data for viz')
627
+ predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
628
+
629
+ att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
630
+ attention_tracking_viz = attention_tracking_viz.reshape(
631
+ attention_tracking_viz.shape[0],
632
+ attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
633
+
634
+ pbar.set_description('Tracking: Neural dynamics plot')
635
+ plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
636
+
637
+ imgi = 0 # Visualize the first image in the batch
638
+ img_to_gif = np.moveaxis(np.clip(inputs_viz[imgi].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)
639
+
640
+ pbar.set_description('Tracking: Producing attention gif')
641
+ make_classification_gif(img_to_gif,
642
+ targets_viz[imgi].item(),
643
+ predictions_viz[imgi].detach().cpu().numpy(),
644
+ certainties_viz[imgi].detach().cpu().numpy(),
645
+ post_activations_viz[:,imgi],
646
+ attention_tracking_viz[:,imgi],
647
+ class_labels,
648
+ f'{args.log_dir}/{imgi}_attention.gif',
649
+ )
650
+ del predictions_viz, certainties_viz, pre_activations_viz, post_activations_viz, attention_tracking_viz
651
+ except Exception as e:
652
+ print(f"Visualization failed for model {args.model}: {e}")
653
+
654
+
655
+
656
+ gc.collect()
657
+ if torch.cuda.is_available():
658
+ torch.cuda.empty_cache()
659
+ model.train() # Switch back to train mode
660
+
661
+
662
+ # Save model checkpoint (conditional metrics)
663
+ if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
664
+ pbar.set_description('Saving model checkpoint...')
665
+ checkpoint_data = {
666
+ 'model_state_dict': model.state_dict(),
667
+ 'optimizer_state_dict': optimizer.state_dict(),
668
+ 'scheduler_state_dict': scheduler.state_dict(),
669
+ 'scaler_state_dict': scaler.state_dict(),
670
+ 'iteration': bi,
671
+ # Always save these
672
+ 'train_losses': train_losses,
673
+ 'test_losses': test_losses,
674
+ 'train_accuracies': train_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
675
+ 'test_accuracies': test_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
676
+ 'iters': iters,
677
+ 'args': args, # Save args used for this run
678
+ # RNG states
679
+ 'torch_rng_state': torch.get_rng_state(),
680
+ 'numpy_rng_state': np.random.get_state(),
681
+ 'random_rng_state': random.getstate(),
682
+ }
683
+ # Conditionally add metrics specific to CTM/LSTM
684
+ if args.model in ['ctm', 'lstm']:
685
+ checkpoint_data['train_accuracies_most_certain'] = train_accuracies_most_certain
686
+ checkpoint_data['test_accuracies_most_certain'] = test_accuracies_most_certain
687
+
688
+ torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
689
+
690
+ pbar.update(1)
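
The per-tick and "most certain" accuracy bookkeeping in the tracking code above relies on two indexing idioms: broadcasting targets against a `(N, T)` prediction array, and picking one tick per sample with paired index arrays. The following is a minimal NumPy sketch of those two idioms only. The function name `per_tick_and_most_certain_accuracy` is hypothetical, and the certainty input is assumed to be a plain `(B, T)` array of scores, which simplifies whatever `image_classification_loss` actually consumes.

```python
import numpy as np


def per_tick_and_most_certain_accuracy(logits, certainty, targets):
    """Hypothetical helper illustrating the indexing used in the tracking code.

    logits:    (B, C, T) class scores at every internal tick.
    certainty: (B, T) certainty score per tick (an assumed, simplified layout).
    targets:   (B,) integer class labels.
    """
    # Predicted class at every tick: argmax over the class axis -> (B, T).
    preds = logits.argmax(axis=1)

    # Broadcast targets to (B, 1) and average over the batch -> (T,).
    per_tick_acc = (preds == targets[:, np.newaxis]).mean(axis=0)

    # Tick at which each sample is most certain -> (B,).
    most_certain_tick = certainty.argmax(axis=1)

    # Paired integer indexing selects one tick per sample -> (B,).
    chosen = preds[np.arange(preds.shape[0]), most_certain_tick]
    most_certain_acc = (chosen == targets).mean()

    return per_tick_acc, most_certain_acc, most_certain_tick


# Tiny example: B=2 samples, C=3 classes, T=2 ticks.
logits = np.array([
    [[0.1, 0.9], [0.8, 0.05], [0.1, 0.05]],  # sample 0: class 1 at tick 0, class 0 at tick 1
    [[0.2, 0.1], [0.3, 0.2], [0.5, 0.7]],    # sample 1: class 2 at both ticks
])
certainty = np.array([[0.2, 0.9], [0.6, 0.4]])
targets = np.array([0, 2])

per_tick_acc, most_certain_acc, most_certain_tick = per_tick_and_most_certain_accuracy(
    logits, certainty, targets
)
print(per_tick_acc, most_certain_acc, most_certain_tick)
```

In the training script the same selection is done on torch tensors with `torch.arange(predictions.size(0), device=predictions.device)` as the row index, so the index tensor lives on the same device as the predictions.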
tasks/image_classification/train_distributed.py ADDED
@@ -0,0 +1,799 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+ import time
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import seaborn as sns
9
+ sns.set_style('darkgrid')
10
+ import torch
11
+ if torch.cuda.is_available():
12
+ # Trade some float32 matmul precision for faster matmuls (e.g. TF32 on supported GPUs)
13
+ torch.set_float32_matmul_precision('high')
14
+ import torch.nn as nn
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from utils.samplers import FastRandomDistributedSampler
19
+ from tqdm.auto import tqdm
20
+
21
+ from tasks.image_classification.train import get_dataset # Use shared get_dataset
22
+
23
+ # Model Imports
24
+ from models.ctm import ContinuousThoughtMachine
25
+ from models.lstm import LSTMBaseline
26
+ from models.ff import FFBaseline
27
+
28
+ # Plotting/Utils Imports
29
+ from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
30
+ from utils.housekeeping import set_seed, zip_python_code
31
+ from utils.losses import image_classification_loss # For CTM, LSTM
32
+ from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
33
+
34
+ import torchvision
35
+ torchvision.disable_beta_transforms_warning()
36
+
37
+ import warnings
38
+ warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
39
+ warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
40
+ warnings.filterwarnings("ignore", message="UserWarning: Metadata Warning, tag 274 had too many entries: 4, expected 1")
41
+ warnings.filterwarnings(
42
+ "ignore",
43
+ "Corrupt EXIF data",
44
+ UserWarning,
45
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
46
+ )
47
+ warnings.filterwarnings(
48
+ "ignore",
49
+ "UserWarning: Metadata Warning",
50
+ UserWarning,
51
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
52
+ )
53
+ warnings.filterwarnings(
54
+ "ignore",
55
+ "UserWarning: Truncated File Read",
56
+ UserWarning,
57
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
58
+ )
59
+
60
+
61
+ def parse_args():
62
+ parser = argparse.ArgumentParser()
63
+
64
+ # Model Selection
65
+ parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
66
+
67
+ # Model Architecture
68
+ # Common
69
+ parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
70
+ parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
71
+ parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone featuriser.')
72
+ # CTM / LSTM specific
73
+ parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
74
+ parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
75
+ parser.add_argument('--iterations', type=int, default=50, help='Number of internal ticks (CTM, LSTM).')
76
+ parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
77
+ choices=['none',
78
+ 'learnable-fourier',
79
+ 'multi-learnable-fourier',
80
+ 'custom-rotational'])
81
+ # CTM specific
82
+ parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
83
+ parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).')
84
+ parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).')
85
+ parser.add_argument('--neuron_select_type', type=str, default='first-last', help='Protocol for selecting neuron subset (CTM only).')
86
+ parser.add_argument('--n_random_pairing_self', type=int, default=256, help='Number of neurons paired self-to-self for synch (CTM only).')
87
+ parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMs (CTM only).')
88
+ parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
89
+ parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory if using deep memory (CTM only).')
90
+ parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
91
+ parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
92
+ # LSTM specific
93
+ parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
94
+
95
+ # Training
96
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training (per GPU).')
97
+ parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing (per GPU).')
98
+ parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
99
+ parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
100
+ parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
101
+ parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
102
+ parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
103
+ parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
104
+ parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
105
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
106
+ parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
107
+ parser.add_argument('--gradient_clipping', type=float, default=-1, help='Maximum gradient norm for clipping (-1 to disable).')
108
+ parser.add_argument('--num_workers_train', type=int, default=1, help='Num workers training.')
109
+ parser.add_argument('--use_custom_sampler', action=argparse.BooleanOptionalAction, default=False, help='Use custom fast sampler to avoid reshuffling.')
110
+ parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
111
+
112
+ # Housekeeping
113
+ parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
114
+ parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
115
+ parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
116
+ parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
117
+ parser.add_argument('--seed', type=int, default=412, help='Random seed.')
118
+ parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
119
+ parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
120
+ parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.')
121
+ parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading?')
122
+
123
+ # Tracking
124
+ parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
125
+ parser.add_argument('--n_test_batches', type=int, default=20, help='Number of minibatches used to approximate metrics. Set to -1 for full evaluation.')
126
+ parser.add_argument('--plot_indices', type=int, default=[0], nargs='+', help='Which indices in test data to plot?') # Defaulted to 0
127
+
128
+ # Precision
129
+ parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
130
+ args = parser.parse_args()
131
+ return args
132
+
133
+ # --- DDP Setup Functions ---
134
+ def setup_ddp():
135
+ if 'RANK' not in os.environ:
136
+ # Basic setup for non-distributed run
137
+ os.environ['RANK'] = '0'
138
+ os.environ['WORLD_SIZE'] = '1'
139
+ os.environ['MASTER_ADDR'] = 'localhost'
140
+ os.environ['MASTER_PORT'] = '12355' # Ensure this port is free
141
+ os.environ['LOCAL_RANK'] = '0'
142
+ print("Running in non-distributed mode (simulated DDP setup).")
143
+ # Manually initialise the process group for a single-process or CPU-only run
144
+ if not torch.cuda.is_available() or int(os.environ['WORLD_SIZE']) == 1:
145
+ dist.init_process_group(backend='gloo') # Gloo backend for CPU
146
+ print("Initialized process group with Gloo backend for single/CPU process.")
147
+ rank = int(os.environ['RANK'])
148
+ world_size = int(os.environ['WORLD_SIZE'])
149
+ local_rank = int(os.environ['LOCAL_RANK'])
150
+ return rank, world_size, local_rank
151
+
152
+
153
+ # Standard DDP setup
154
+ dist.init_process_group(backend='nccl') # 'nccl' for NVIDIA GPUs
155
+ rank = int(os.environ['RANK'])
156
+ world_size = int(os.environ['WORLD_SIZE'])
157
+ local_rank = int(os.environ['LOCAL_RANK'])
158
+ if torch.cuda.is_available():
159
+ torch.cuda.set_device(local_rank)
160
+ print(f"Rank {rank} setup on GPU {local_rank}")
161
+ else:
162
+ print(f"Rank {rank} setup on CPU (GPU not available or requested)")
163
+ return rank, world_size, local_rank
164
+
165
+ def cleanup_ddp():
166
+ if dist.is_initialized():
167
+ dist.destroy_process_group()
168
+ print("DDP cleanup complete.")
169
+
170
+ def is_main_process(rank):
171
+ return rank == 0
172
+ # --- End DDP Setup ---
173
+
174
+
175
+ if __name__=='__main__':
176
+
177
+ args = parse_args()
178
+
179
+ rank, world_size, local_rank = setup_ddp()
180
+
181
+ set_seed(args.seed + rank, False) # Add rank for different seeds per process
182
+
183
+ # Rank 0 handles directory creation and initial logging
184
+ if is_main_process(rank):
185
+ if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
186
+ zip_python_code(f'{args.log_dir}/repo_state.zip')
187
+ with open(f'{args.log_dir}/args.txt', 'w') as f:
188
+ print(args, file=f)
189
+ if world_size > 1: dist.barrier() # Sync after rank 0 setup
190
+
191
+
192
+ assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
193
+
194
+ # Data Loading
195
+ train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)
196
+
197
+ # Setup Samplers
198
+ # This custom sampler is useful with large batch sizes on CIFAR; otherwise the dataset is reshuffled tediously often
199
+ train_sampler = (FastRandomDistributedSampler(train_data, num_replicas=world_size, rank=rank, seed=args.seed, epoch_steps=int(10e10))
200
+ if args.use_custom_sampler else
201
+ DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed))
202
+ test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False, seed=args.seed) # No shuffle needed for test; consistent
203
+
204
+ # Setup DataLoaders
205
+ trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler,
206
+ num_workers=args.num_workers_train, pin_memory=True, drop_last=True) # drop_last=True often used in DDP
207
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, sampler=test_sampler,
208
+ num_workers=1, pin_memory=True, drop_last=False)
209
+
210
+
211
+ prediction_reshaper = [-1] # Task specific
212
+ args.out_dims = len(class_labels)
213
+
214
+ # Setup Device
215
+ if torch.cuda.is_available():
216
+ device = torch.device(f'cuda:{local_rank}')
217
+ else:
218
+ device = torch.device('cpu')
219
+ if world_size > 1:
220
+ warnings.warn("Running DDP on CPU is not recommended.")
221
+ if is_main_process(rank):
222
+ print(f'Main process (Rank {rank}): Using device {device}. World size: {world_size}. Model: {args.model}')
223
+
224
+ # --- Model Definition (Conditional) ---
225
+ model_base = None # Base model before DDP wrapping
226
+ if args.model == 'ctm':
227
+ model_base = ContinuousThoughtMachine(
228
+ iterations=args.iterations,
229
+ d_model=args.d_model,
230
+ d_input=args.d_input,
231
+ heads=args.heads,
232
+ n_synch_out=args.n_synch_out,
233
+ n_synch_action=args.n_synch_action,
234
+ synapse_depth=args.synapse_depth,
235
+ memory_length=args.memory_length,
236
+ deep_nlms=args.deep_memory,
237
+ memory_hidden_dims=args.memory_hidden_dims,
238
+ do_layernorm_nlm=args.do_normalisation,
239
+ backbone_type=args.backbone_type,
240
+ positional_embedding_type=args.positional_embedding_type,
241
+ out_dims=args.out_dims,
242
+ prediction_reshaper=prediction_reshaper,
243
+ dropout=args.dropout,
244
+ dropout_nlm=args.dropout_nlm,
245
+ neuron_select_type=args.neuron_select_type,
246
+ n_random_pairing_self=args.n_random_pairing_self,
247
+ ).to(device)
248
+ elif args.model == 'lstm':
249
+ model_base = LSTMBaseline(
250
+ num_layers=args.num_layers,
251
+ iterations=args.iterations,
252
+ d_model=args.d_model,
253
+ d_input=args.d_input,
254
+ heads=args.heads,
255
+ backbone_type=args.backbone_type,
256
+ positional_embedding_type=args.positional_embedding_type,
257
+ out_dims=args.out_dims,
258
+ prediction_reshaper=prediction_reshaper,
259
+ dropout=args.dropout,
260
+ start_type=args.start_type,
261
+ ).to(device)
262
+ elif args.model == 'ff':
263
+ model_base = FFBaseline(
264
+ d_model=args.d_model,
265
+ backbone_type=args.backbone_type,
266
+ out_dims=args.out_dims,
267
+ dropout=args.dropout,
268
+ ).to(device)
269
+ else:
270
+ raise ValueError(f"Unknown model type: {args.model}")
271
+
272
+ # Initialize lazy modules if any
273
+ try:
274
+ pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
275
+ model_base(pseudo_inputs)
276
+ except Exception as e:
277
+ print(f"Warning: Pseudo forward pass failed: {e}")
278
+
279
+ # Wrap model with DDP
280
+ if device.type == 'cuda' and world_size > 1:
281
+ model = DDP(model_base, device_ids=[local_rank], output_device=local_rank)
282
+ elif device.type == 'cpu' and world_size > 1:
283
+ model = DDP(model_base) # No device_ids for CPU
284
+ else: # Single process run
285
+ model = model_base # No DDP wrapping needed
286
+
287
+ if is_main_process(rank):
288
+ # Access underlying model for param count
289
+ param_count = sum(p.numel() for p in model.module.parameters() if p.requires_grad) if world_size > 1 else sum(p.numel() for p in model.parameters() if p.requires_grad)
290
+ print(f'Total trainable params: {param_count}')
291
+ # --- End Model Definition ---
292
+
293
+
294
+ # Optimizer and scheduler
295
+ # Use model.parameters() directly, DDP handles it
296
+ decay_params = []
297
+ no_decay_params = []
298
+ no_decay_names = []
299
+ for name, param in model.named_parameters():
300
+ if not param.requires_grad:
301
+ continue # Skip parameters that don't require gradients
302
+ if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
303
+ no_decay_params.append(param)
304
+ no_decay_names.append(name)
305
+ else:
306
+ decay_params.append(param)
307
+ if len(no_decay_names) and is_main_process(rank):
308
+ print(f'WARNING, excluding: {no_decay_names}')
309
+
310
+ # Optimizer and scheduler (Common setup)
311
+ if len(no_decay_names) and args.weight_decay!=0:
312
+ optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
313
+ {'params': no_decay_params, 'weight_decay':0}],
314
+ lr=args.lr,
315
+ eps=1e-8 if not args.use_amp else 1e-6)
316
+ else:
317
+ optimizer = torch.optim.AdamW(model.parameters(),
318
+ lr=args.lr,
319
+ eps=1e-8 if not args.use_amp else 1e-6,
320
+ weight_decay=args.weight_decay)
321
+
322
+ warmup_schedule = warmup(args.warmup_steps)
323
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
324
+ if args.use_scheduler:
325
+ if args.scheduler_type == 'multistep':
326
+ scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
327
+ elif args.scheduler_type == 'cosine':
328
+ scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
329
+ else:
330
+ raise NotImplementedError
331
+
332
+
333
+ # Metrics tracking (on Rank 0)
334
+ start_iter = 0
335
+ train_losses = []
336
+ test_losses = []
337
+ train_accuracies = [] # Placeholder for potential detailed accuracy
338
+ test_accuracies = [] # Placeholder for potential detailed accuracy
339
+ # Conditional metrics
340
+ train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None # Scalar accuracy list
341
+ test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None # Scalar accuracy list
342
+ train_accuracies_standard = [] if args.model == 'ff' else None # Standard accuracy list for FF
343
+ test_accuracies_standard = [] if args.model == 'ff' else None # Standard accuracy list for FF
344
+ iters = []
345
+
346
+ scaler = torch.amp.GradScaler("cuda" if device.type == 'cuda' else "cpu", enabled=args.use_amp)
347
+ # Reloading Logic
348
+ if args.reload:
349
+ map_location = device # Load directly onto the process's device
350
+ chkpt_path = f'{args.log_dir}/checkpoint.pt'
351
+ if os.path.isfile(chkpt_path):
352
+ print(f'Rank {rank}: Reloading from: {chkpt_path}')
353
+ checkpoint = torch.load(chkpt_path, map_location=map_location, weights_only=False)
354
+
355
+ # Determine underlying model based on whether DDP wrapping occurred
356
+ model_to_load = model.module if isinstance(model, DDP) else model
357
+
358
+ # Handle potential 'module.' prefix in saved state_dict
359
+ state_dict = checkpoint['model_state_dict']
360
+ has_module_prefix = all(k.startswith('module.') for k in state_dict)
361
+ is_wrapped = isinstance(model, DDP)
362
+
363
+ if has_module_prefix and not is_wrapped:
364
+ # Saved with DDP, loading into non-DDP model -> remove prefix
365
+ state_dict = {k.partition('module.')[2]: v for k,v in state_dict.items()}
366
+ elif not has_module_prefix and is_wrapped:
367
+ load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
368
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
369
+ state_dict = None # Prevent loading again
370
+
371
+ if state_dict is not None:
372
+ load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
373
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
374
+
375
+
376
+ if not args.reload_model_only:
377
+ print(f'Rank {rank}: Reloading optimizer, scheduler, scaler, iteration.')
378
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
379
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
380
+ scaler_state_dict = checkpoint['scaler_state_dict']
381
+ if scaler.is_enabled():
382
+ print("Loading non-empty GradScaler state dict.")
383
+ try:
384
+ scaler.load_state_dict(scaler_state_dict)
385
+ except Exception as e:
386
+ print(f"Error loading GradScaler state dict: {e}")
387
+ print("Continuing with a fresh GradScaler state.")
388
+
389
+ start_iter = checkpoint['iteration']
390
+ # Only rank 0 loads metric history
391
+ if is_main_process(rank) and not args.ignore_metrics_when_reloading:
392
+ print(f'Rank {rank}: Reloading metrics history.')
393
+ iters = checkpoint['iters']
394
+ train_losses = checkpoint['train_losses']
395
+ test_losses = checkpoint['test_losses']
396
+ train_accuracies = checkpoint['train_accuracies']
397
+ test_accuracies = checkpoint['test_accuracies']
398
+ if args.model in ['ctm', 'lstm']:
399
+ train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
400
+ test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
401
+ elif args.model == 'ff':
402
+ train_accuracies_standard = checkpoint['train_accuracies_standard']
403
+ test_accuracies_standard = checkpoint['test_accuracies_standard']
404
+ elif is_main_process(rank) and args.ignore_metrics_when_reloading:
405
+ print(f'Rank {rank}: Ignoring metrics history upon reload.')
406
+
407
+ else:
408
+ print(f'Rank {rank}: Only reloading model weights!')
409
+
410
+ # Load RNG states
411
+ if is_main_process(rank) and 'torch_rng_state' in checkpoint and not args.reload_model_only:
412
+ print(f'Rank {rank}: Loading RNG states (may need DDP adaptation for full reproducibility).')
413
+ torch.set_rng_state(checkpoint['torch_rng_state'].cpu()) # Load CPU state
414
+ # Add CUDA state loading if needed, ensuring correct device handling
415
+ np.random.set_state(checkpoint['numpy_rng_state'])
416
+ random.setstate(checkpoint['random_rng_state'])
417
+
418
+ del checkpoint
419
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
420
+ print(f"Rank {rank}: Reload finished, starting from iteration {start_iter}")
421
+ else:
422
+ print(f"Rank {rank}: Checkpoint not found at {chkpt_path}, starting from scratch.")
423
+ if world_size > 1: dist.barrier() # Sync after loading
424
+
425
+
426
+ # Conditional Compilation
427
+ if args.do_compile:
428
+ if is_main_process(rank): print('Compiling model components...')
429
+ # Compile on the underlying model if wrapped
430
+ model_to_compile = model.module if isinstance(model, DDP) else model
431
+ if hasattr(model_to_compile, 'backbone'):
432
+ model_to_compile.backbone = torch.compile(model_to_compile.backbone, mode='reduce-overhead', fullgraph=True)
433
+ if args.model == 'ctm':
434
+ if hasattr(model_to_compile, 'synapses'):
435
+ model_to_compile.synapses = torch.compile(model_to_compile.synapses, mode='reduce-overhead', fullgraph=True)
436
+ if world_size > 1: dist.barrier() # Sync after compilation
437
+ if is_main_process(rank): print('Compilation finished.')
438
+
439
+
440
+ # --- Training Loop ---
441
+ model.train() # Ensure model is in train mode
442
+ pbar = tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True, disable=not is_main_process(rank))
443
+
444
+ iterator = iter(trainloader)
445
+
446
+ for bi in range(start_iter, args.training_iterations):
447
+
448
+ # Set sampler epoch (important for shuffling in DistributedSampler)
449
+ if not args.use_custom_sampler and hasattr(train_sampler, 'set_epoch'):
450
+ train_sampler.set_epoch(bi)
451
+
452
+ current_lr = optimizer.param_groups[-1]['lr']
453
+
454
+ time_start_data = time.time()
455
+ try:
456
+ inputs, targets = next(iterator)
457
+ except StopIteration:
458
+ # Reset iterator - set_epoch handles shuffling if needed
459
+ iterator = iter(trainloader)
460
+ inputs, targets = next(iterator)
461
+
462
+
463
+ inputs = inputs.to(device, non_blocking=True)
464
+ targets = targets.to(device, non_blocking=True)
465
+ time_end_data = time.time()
466
+
467
+ loss = None
468
+ # Model-specific forward and loss calculation
469
+ time_start_forward = time.time()
470
+ with torch.autocast(device_type="cuda" if device.type == 'cuda' else "cpu", dtype=torch.float16, enabled=args.use_amp):
471
+ if args.do_compile:
472
+ torch.compiler.cudagraph_mark_step_begin()
473
+
474
+ if args.model == 'ctm':
475
+ predictions, certainties, synchronisation = model(inputs)
476
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
477
+ elif args.model == 'lstm':
478
+ predictions, certainties, synchronisation = model(inputs)
479
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
480
+ elif args.model == 'ff':
481
+ predictions = model(inputs) # FF returns only predictions
482
+ loss = nn.CrossEntropyLoss()(predictions, targets)
483
+ where_most_certain = None # Not applicable for FF standard loss
484
+ time_end_forward = time.time()
485
+ time_start_backward = time.time()
486
+
487
+ scaler.scale(loss).backward() # DDP handles gradient synchronization
488
+ time_end_backward = time.time()
489
+
490
+ if args.gradient_clipping!=-1:
491
+ scaler.unscale_(optimizer)
492
+ # Clip gradients across all parameters controlled by the optimizer
493
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
494
+
495
+ scaler.step(optimizer)
496
+ scaler.update()
497
+ optimizer.zero_grad(set_to_none=True)
498
+ scheduler.step()
499
+
500
+ # --- Aggregation and Logging (Rank 0) ---
501
+ # Aggregate loss for logging
502
+ loss_log = loss.detach() # Use detached loss for aggregation
503
+ if world_size > 1: dist.all_reduce(loss_log, op=dist.ReduceOp.AVG)
504
+
505
+ if is_main_process(rank):
506
+ # Calculate accuracy locally on rank 0 for description (approximate)
507
+ # Note: This uses rank 0's batch, not aggregated accuracy
508
+ accuracy_local = 0.0
509
+ if args.model in ['ctm', 'lstm']:
510
+ accuracy_local = (predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain] == targets).float().mean().item()
511
+ where_certain_tensor = where_most_certain.float() # Use rank 0's tensor for stats
512
+ pbar_desc = f'Timing; d={(time_end_data-time_start_data):0.3f}, f={(time_end_forward-time_start_forward):0.3f}, b={(time_end_backward-time_start_backward):0.3f}. Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_local:.3f} LR={current_lr:.6f} WhereCert(loc)={where_certain_tensor.mean().item():.2f}'
513
+ elif args.model == 'ff':
514
+ accuracy_local = (predictions.argmax(1) == targets).float().mean().item()
515
+ pbar_desc = f'Timing; d={(time_end_data-time_start_data):0.3f}, f={(time_end_forward-time_start_forward):0.3f}, b={(time_end_backward-time_start_backward):0.3f}. Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_local:.3f} LR={current_lr:.6f}'
516
+
517
+ pbar.set_description(f'{args.model.upper()} {pbar_desc}')
518
+ # --- End Aggregation and Logging ---
519
+
520
+
521
+ # --- Evaluation and Plotting (Rank 0 + Aggregation) ---
522
+ if bi % args.track_every == 0 and (bi != 0 or args.reload_model_only):
523
+
524
+ model.eval()
525
+ with torch.inference_mode():
526
+
527
+
528
+ # --- Distributed Evaluation ---
529
+ iters.append(bi)
530
+
531
+ # TRAIN METRICS
532
+ total_train_loss = torch.tensor(0.0, device=device)
533
+ total_train_correct_certain = torch.tensor(0.0, device=device) # CTM/LSTM
534
+ total_train_correct_standard = torch.tensor(0.0, device=device) # FF
535
+ total_train_samples = torch.tensor(0.0, device=device)
536
+
537
+ # Use a sampler for evaluation to ensure non-overlapping data if needed
538
+ train_eval_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=False)
539
+ train_eval_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, sampler=train_eval_sampler, num_workers=1, pin_memory=True)
540
+
541
+ pbar_inner_desc = 'Eval Train (Rank 0)' if is_main_process(rank) else None
542
+ with tqdm(total=len(train_eval_loader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
543
+ for inferi, (inputs, targets) in enumerate(train_eval_loader):
544
+ inputs = inputs.to(device, non_blocking=True)
545
+ targets = targets.to(device, non_blocking=True)
546
+
547
+ loss_eval = None
548
+ if args.model == 'ctm':
549
+ predictions, certainties, _ = model(inputs)
550
+ loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
551
+ preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
552
+ total_train_correct_certain += (preds_eval == targets).sum()
553
+ elif args.model == 'lstm':
554
+ predictions, certainties, _ = model(inputs)
555
+ loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
556
+ preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
557
+ total_train_correct_certain += (preds_eval == targets).sum()
558
+ elif args.model == 'ff':
559
+ predictions = model(inputs)
560
+ loss_eval = nn.CrossEntropyLoss()(predictions, targets)
561
+ preds_eval = predictions.argmax(1)
562
+ total_train_correct_standard += (preds_eval == targets).sum()
563
+
564
+ total_train_loss += loss_eval * inputs.size(0)
565
+ total_train_samples += inputs.size(0)
566
+
567
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
568
+ pbar_inner.update(1)
569
+
570
+ # Aggregate Train Metrics
571
+ if world_size > 1:
572
+ dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
573
+ dist.all_reduce(total_train_correct_certain, op=dist.ReduceOp.SUM)
574
+ dist.all_reduce(total_train_correct_standard, op=dist.ReduceOp.SUM)
575
+ dist.all_reduce(total_train_samples, op=dist.ReduceOp.SUM)
576
+
577
+ # Calculate final Train metrics on Rank 0
578
+ if is_main_process(rank) and total_train_samples > 0:
579
+ avg_train_loss = total_train_loss.item() / total_train_samples.item()
580
+ train_losses.append(avg_train_loss)
581
+ if args.model in ['ctm', 'lstm']:
582
+ avg_train_acc_certain = total_train_correct_certain.item() / total_train_samples.item()
583
+ train_accuracies_most_certain.append(avg_train_acc_certain)
584
+ elif args.model == 'ff':
585
+ avg_train_acc_standard = total_train_correct_standard.item() / total_train_samples.item()
586
+ train_accuracies_standard.append(avg_train_acc_standard)
587
+ print(f"Iter {bi} Train Metrics (Agg): Loss={avg_train_loss:.4f}")
588
+
589
+ # TEST METRICS
590
+ total_test_loss = torch.tensor(0.0, device=device)
591
+ total_test_correct_certain = torch.tensor(0.0, device=device) # CTM/LSTM
592
+ total_test_correct_standard = torch.tensor(0.0, device=device) # FF
593
+ total_test_samples = torch.tensor(0.0, device=device)
594
+
595
+ pbar_inner_desc = 'Eval Test (Rank 0)' if is_main_process(rank) else None
596
+ with tqdm(total=len(testloader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
597
+ for inferi, (inputs, targets) in enumerate(testloader): # Testloader already uses sampler
598
+ inputs = inputs.to(device, non_blocking=True)
599
+ targets = targets.to(device, non_blocking=True)
600
+
601
+ loss_eval = None
602
+ if args.model == 'ctm':
603
+ predictions, certainties, _ = model(inputs)
604
+ loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
605
+ preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
606
+ total_test_correct_certain += (preds_eval == targets).sum()
607
+ elif args.model == 'lstm':
608
+ predictions, certainties, _ = model(inputs)
609
+ loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
610
+ preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
611
+ total_test_correct_certain += (preds_eval == targets).sum()
612
+ elif args.model == 'ff':
613
+ predictions = model(inputs)
614
+ loss_eval = nn.CrossEntropyLoss()(predictions, targets)
615
+ preds_eval = predictions.argmax(1)
616
+ total_test_correct_standard += (preds_eval == targets).sum()
617
+
618
+ total_test_loss += loss_eval * inputs.size(0)
619
+ total_test_samples += inputs.size(0)
620
+
621
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
622
+ pbar_inner.update(1)
623
+
624
+ # Aggregate Test Metrics
625
+ if world_size > 1:
626
+ dist.all_reduce(total_test_loss, op=dist.ReduceOp.SUM)
627
+ dist.all_reduce(total_test_correct_certain, op=dist.ReduceOp.SUM)
628
+ dist.all_reduce(total_test_correct_standard, op=dist.ReduceOp.SUM)
629
+ dist.all_reduce(total_test_samples, op=dist.ReduceOp.SUM)
630
+
631
+ # Calculate and Plot final Test metrics on Rank 0
632
+ if is_main_process(rank) and total_test_samples > 0:
633
+ avg_test_loss = total_test_loss.item() / total_test_samples.item()
634
+ test_losses.append(avg_test_loss)
635
+ acc_label = ''
636
+ acc_val = 0.0
637
+ if args.model in ['ctm', 'lstm']:
638
+ avg_test_acc_certain = total_test_correct_certain.item() / total_test_samples.item()
639
+ test_accuracies_most_certain.append(avg_test_acc_certain)
640
+ acc_label = f'Most certain ({avg_test_acc_certain:.3f})'
641
+ acc_val = avg_test_acc_certain
642
+ elif args.model == 'ff':
643
+ avg_test_acc_standard = total_test_correct_standard.item() / total_test_samples.item()
644
+ test_accuracies_standard.append(avg_test_acc_standard)
645
+ acc_label = f'Standard Acc ({avg_test_acc_standard:.3f})'
646
+ acc_val = avg_test_acc_standard
647
+ print(f"Iter {bi} Test Metrics (Agg): Loss={avg_test_loss:.4f}, Acc={acc_val:.4f}\n")
648
+
649
+
650
+ # --- Plotting ---
651
+ figacc = plt.figure(figsize=(10, 10))
652
+ axacc_train = figacc.add_subplot(211)
653
+ axacc_test = figacc.add_subplot(212)
654
+
655
+ if args.model in ['ctm', 'lstm']:
656
+ axacc_train.plot(iters, train_accuracies_most_certain, 'k-', alpha=0.9, label=f'Most certain ({train_accuracies_most_certain[-1]:.3f})')
657
+ axacc_test.plot(iters, test_accuracies_most_certain, 'k-', alpha=0.9, label=acc_label)
658
+ elif args.model == 'ff':
659
+ axacc_train.plot(iters, train_accuracies_standard, 'k-', alpha=0.9, label=f'Standard Acc ({train_accuracies_standard[-1]:.3f})')
660
+ axacc_test.plot(iters, test_accuracies_standard, 'k-', alpha=0.9, label=acc_label)
661
+
662
+ axacc_train.set_title('Train Accuracy (Aggregated)')
663
+ axacc_test.set_title('Test Accuracy (Aggregated)')
664
+ axacc_train.legend(loc='lower right')
665
+ axacc_test.legend(loc='lower right')
666
+ axacc_train.set_xlim([0, args.training_iterations])
667
+ axacc_test.set_xlim([0, args.training_iterations])
668
+
669
+ # Keep dataset specific ylim adjustments if needed
670
+ if args.dataset == 'imagenet':
671
+ # For easy comparison when training
672
+ train_ylim_set = False
673
+ if args.model in ['ctm', 'lstm'] and len(train_accuracies_most_certain)>0 and np.any(np.array(train_accuracies_most_certain)>0.4): train_ylim_set=True; axacc_train.set_ylim([0.4, 1])
674
+ if args.model == 'ff' and len(train_accuracies_standard)>0 and np.any(np.array(train_accuracies_standard)>0.4): train_ylim_set=True; axacc_train.set_ylim([0.4, 1])
675
+
676
+ test_ylim_set = False
677
+ if args.model in ['ctm', 'lstm'] and len(test_accuracies_most_certain)>0 and np.any(np.array(test_accuracies_most_certain)>0.3): test_ylim_set=True; axacc_test.set_ylim([0.3, 0.8])
678
+ if args.model == 'ff' and len(test_accuracies_standard)>0 and np.any(np.array(test_accuracies_standard)>0.3): test_ylim_set=True; axacc_test.set_ylim([0.3, 0.8])
679
+
680
+
681
+ figacc.tight_layout()
682
+ figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
683
+ plt.close(figacc)
684
+
685
+ # Loss Plot
686
+ figloss = plt.figure(figsize=(10, 5))
687
+ axloss = figloss.add_subplot(111)
688
+ axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train (Aggregated): {train_losses[-1]:.4f}')
689
+ axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test (Aggregated): {test_losses[-1]:.4f}')
690
+ axloss.legend(loc='upper right')
691
+ axloss.set_xlabel("Iteration")
692
+ axloss.set_ylabel("Loss")
693
+ axloss.set_xlim([0, args.training_iterations])
694
+ axloss.set_ylim(bottom=0)
695
+ figloss.tight_layout()
696
+ figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
697
+ plt.close(figloss)
698
+ # --- End Plotting ---
699
+
700
+ # Visualization on Rank 0
701
+ if is_main_process(rank) and args.model in ['ctm', 'lstm']:
702
+ try:
703
+ model_module = model.module if isinstance(model, DDP) else model # Get underlying model
704
+ # Simplified viz: use first batch from testloader
705
+ inputs_viz, targets_viz = next(iter(testloader))
706
+ inputs_viz = inputs_viz.to(device)
707
+ targets_viz = targets_viz.to(device)
708
+
709
+ pbar.set_description('Tracking (Rank 0): Viz Fwd Pass')
710
+ predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model_module(inputs_viz, track=True)
711
+
712
+ att_shape = (model_module.kv_features.shape[2], model_module.kv_features.shape[3])
713
+ attention_tracking_viz = attention_tracking_viz.reshape(
714
+ attention_tracking_viz.shape[0],
715
+ attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
716
+
717
+
718
+ pbar.set_description('Tracking (Rank 0): Dynamics Plot')
719
+ plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
720
+
721
+ # Plot specific indices from test_data directly
722
+ pbar.set_description('Tracking (Rank 0): GIF Generation')
723
+ for plot_idx in args.plot_indices:
724
+ try:
725
+ if plot_idx < len(test_data):
726
+ inputs_plot, target_plot = test_data.__getitem__(plot_idx)
727
+ inputs_plot = inputs_plot.unsqueeze(0).to(device)
728
+
729
+ preds_plot, certs_plot, _, _, posts_plot, atts_plot = model_module(inputs_plot, track=True)
730
+ atts_plot = atts_plot.reshape(atts_plot.shape[0], atts_plot.shape[1], -1, att_shape[0], att_shape[1])
731
+
732
+
733
+ img_gif = np.moveaxis(np.clip(inputs_plot[0].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)
734
+
735
+ make_classification_gif(img_gif, target_plot, preds_plot[0].detach().cpu().numpy(), certs_plot[0].detach().cpu().numpy(),
736
+ posts_plot[:,0], atts_plot[:,0] if atts_plot is not None else None, class_labels,
737
+ f'{args.log_dir}/idx{plot_idx}_attention.gif')
738
+ else:
739
+ print(f"Warning: Plot index {plot_idx} out of range for test dataset size {len(test_data)}.")
740
+ except Exception as e_gif:
741
+ print(f"Rank 0 GIF generation failed for index {plot_idx}: {e_gif}")
742
+
743
+ except Exception as e_viz:
744
+ print(f"Rank 0 visualization failed: {e_viz}")
745
+
746
+
747
+
748
+ if world_size > 1: dist.barrier() # Sync after evaluation block
749
+ model.train() # Set back to train mode
750
+ # --- End Evaluation Block ---
751
+
752
+
753
+ # --- Checkpointing (Rank 0) ---
754
+ if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter and is_main_process(rank):
755
+ pbar.set_description('Rank 0: Saving checkpoint...')
756
+ save_path = f'{args.log_dir}/checkpoint.pt'
757
+ # Access underlying model state dict if DDP is used
758
+ model_state_to_save = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
759
+
760
+ save_dict = {
761
+ 'model_state_dict': model_state_to_save,
762
+ 'optimizer_state_dict': optimizer.state_dict(),
763
+ 'scheduler_state_dict': scheduler.state_dict(),
764
+ 'scaler_state_dict':scaler.state_dict(),
765
+ 'iteration': bi,
766
+ 'train_losses': train_losses,
767
+ 'test_losses': test_losses,
768
+ 'iters': iters,
769
+ 'args': args,
770
+ 'torch_rng_state': torch.get_rng_state(), # CPU state
771
+ 'numpy_rng_state': np.random.get_state(),
772
+ 'random_rng_state': random.getstate(),
773
+ # Include conditional metrics
774
+ 'train_accuracies': train_accuracies, # Placeholder
775
+ 'test_accuracies': test_accuracies, # Placeholder
776
+ }
777
+ if args.model in ['ctm', 'lstm']:
778
+ save_dict['train_accuracies_most_certain'] = train_accuracies_most_certain
779
+ save_dict['test_accuracies_most_certain'] = test_accuracies_most_certain
780
+ elif args.model == 'ff':
781
+ save_dict['train_accuracies_standard'] = train_accuracies_standard
782
+ save_dict['test_accuracies_standard'] = test_accuracies_standard
783
+
784
+ torch.save(save_dict , save_path)
785
+ pbar.set_description(f"Rank 0: Checkpoint saved to {save_path}")
786
+ # --- End Checkpointing ---
787
+
788
+
789
+ if world_size > 1: dist.barrier() # Sync before next iteration
790
+
791
+ # Update pbar on Rank 0
792
+ if is_main_process(rank):
793
+ pbar.update(1)
794
+ # --- End Training Loop ---
795
+
796
+ if is_main_process(rank):
797
+ pbar.close()
798
+
799
+ cleanup_ddp() # Cleanup DDP resources
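Note on the aggregated metrics above: each rank accumulates `loss * batch_size`, correct counts and sample counts, the totals are summed across ranks with `dist.all_reduce`, and only then divided. A minimal single-process sketch of that arithmetic follows, with nested lists standing in for ranks. The helper name `aggregate_weighted` and the numbers are illustrative assumptions, not part of the repository.

```python
def aggregate_weighted(per_rank_batches):
    """Sample-weighted averages over all ranks.

    per_rank_batches: one list per rank, each holding
    (mean_batch_loss, n_correct, batch_size) tuples.
    Returns (avg_loss, accuracy), or None if no samples were seen.
    """
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for batches in per_rank_batches:  # the outer sum plays the role of all_reduce(SUM)
        for mean_loss, n_correct, batch_size in batches:
            total_loss += mean_loss * batch_size  # undo the per-batch mean
            total_correct += n_correct
            total_samples += batch_size
    if total_samples == 0:
        return None
    return total_loss / total_samples, total_correct / total_samples


# Two hypothetical ranks with uneven batch sizes.
ranks = [
    [(0.5, 3, 4), (1.0, 1, 2)],
    [(2.0, 0, 2)],
]
avg_loss, avg_acc = aggregate_weighted(ranks)
print(avg_loss, avg_acc)
```

Weighting by batch size keeps a short final batch from being over-counted, which a plain mean of per-batch losses would do.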
tasks/mazes/README.md ADDED
@@ -0,0 +1,16 @@
1
+ # Mazes
2
+
3
+ This folder contains code for training and analysing 2D maze-solving experiments.
4
+
5
+
6
+ ## Training
7
+ To run the maze training that we used for the paper, run the following command from the parent directory:
8
+ ```
9
+ python -m tasks.mazes.train --d_model 2048 --d_input 512 --synapse_depth 4 --heads 8 --n_synch_out 64 --n_synch_action 32 --neuron_select_type first-last --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 32 --dropout 0.1 --no-do_normalisation --positional_embedding_type none --backbone_type resnet34-2 --batch_size 64 --batch_size_test 64 --lr 1e-4 --training_iterations 1000001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs/mazes/d=2048--i=512--h=8--ns=64-32--iters=75x25--h=32--drop=0.1--pos=none--back=34-2--seed=42 --dataset mazes-medium --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50
10
+ ```
11
+
12
+ ## Small training run
13
+ We also provide a 'mazes-small' dataset (see [here](https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk)) for fast iteration and testing ideas. The following command can train a CTM locally without a GPU in 12-24 hours:
14
+ ```
15
+ python -m tasks.mazes.train --dataset mazes-small --maze_route_length 50 --cirriculum_lookahead 5 --model ctm --d_model 1024 --d_input 256 --backbone_type resnet18-1 --synapse_depth 8 --heads 4 --n_synch_out 128 --n_synch_action 128 --neuron_select_type random-pairing --memory_length 25 --iterations 50 --training_iterations 100001 --lr 1e-4 --batch_size 64 --batch_size_test 32 --n_test_batches 50 --log_dir logs/mazes-small-tester --track_every 2000
16
+ ```
tasks/mazes/analysis/README.md ADDED
@@ -0,0 +1,10 @@
1
+ # Analysis
2
+
3
+ This folder contains analysis code for 2D maze experiments.
4
+
5
+ To run the maze analysis, run the following command from the parent directory:
6
+ ```
7
+ python -m tasks.mazes.analysis.run --actions viz viz --checkpoint checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt
8
+ ```
9
+
10
+ Download the checkpoint from https://drive.google.com/file/d/1vGiMaQCxzKVT68SipxDCW0W5n5jjEQnC/view?usp=drive_link and extract it to the appropriate directory: `checkpoints/mazes/...`. Otherwise, use your own checkpoint after training.
tasks/mazes/analysis/run.py ADDED
@@ -0,0 +1,407 @@
1
+ import torch
2
+ import numpy as np
3
+ np.seterr(divide='ignore', invalid='warn') # Keep specific numpy error settings
4
+ import matplotlib as mpl
5
+ mpl.use('Agg') # Use Agg backend for matplotlib (important to set before importing pyplot)
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ sns.set_style('darkgrid') # Keep seaborn style
9
+ import os
10
+ import argparse
11
+ import cv2
12
+ import imageio # Used for saving GIFs in viz
13
+
14
+ # Local imports
15
+ from data.custom_datasets import MazeImageFolder
16
+ from models.ctm import ContinuousThoughtMachine
17
+ from tasks.mazes.plotting import draw_path
18
+ from tasks.image_classification.plotting import save_frames_to_mp4
19
+
20
+ def has_solved_checker(x_maze, route, valid_only=True, fault_tolerance=1, exclusions=()):
21
+ """Checks if a route solves a maze."""
22
+ maze = np.copy(x_maze)
23
+ H, W, _ = maze.shape
24
+ start_coords = np.argwhere((maze == [1, 0, 0]).all(axis=2))
25
+ end_coords = np.argwhere((maze == [0, 1, 0]).all(axis=2))
26
+
27
+ if len(start_coords) == 0:
28
+ return False, (-1, -1), 0 # Cannot start
29
+
30
+ current_pos = tuple(start_coords[0])
31
+ target_pos = tuple(end_coords[0]) if len(end_coords) > 0 else None
32
+
33
+ mistakes_made = 0
34
+ final_pos = current_pos
35
+ path_taken_len = 0
36
+
37
+ for step in route:
38
+ if mistakes_made > fault_tolerance:
39
+ break
40
+
41
+ next_pos_candidate = list(current_pos) # Use a list for mutable coordinate calculation
42
+ if step == 0: next_pos_candidate[0] -= 1
43
+ elif step == 1: next_pos_candidate[0] += 1
44
+ elif step == 2: next_pos_candidate[1] -= 1
45
+ elif step == 3: next_pos_candidate[1] += 1
46
+ elif step == 4: pass # Stay in place
47
+ else: continue # Invalid step action
48
+ next_pos = tuple(next_pos_candidate)
49
+
50
+
51
+ is_invalid_step = False
52
+ # Check bounds first, then maze content if in bounds
53
+ if not (0 <= next_pos[0] < H and 0 <= next_pos[1] < W):
54
+ is_invalid_step = True
55
+ elif np.all(maze[next_pos] == [0, 0, 0]): # Wall
56
+ is_invalid_step = True
57
+
58
+ if is_invalid_step:
59
+ mistakes_made += 1
60
+ if valid_only:
61
+ continue
62
+
63
+ current_pos = next_pos
64
+ path_taken_len += 1
65
+
66
+ if target_pos and current_pos == target_pos:
67
+ if mistakes_made <= fault_tolerance:
68
+ return True, current_pos, path_taken_len
69
+
70
+ if mistakes_made <= fault_tolerance:
71
+ # Assuming exclusions is a list of tuples (as populated in the 'gen' action)
72
+ if current_pos not in exclusions:
73
+ final_pos = current_pos
74
+
75
+ if target_pos and final_pos == target_pos and mistakes_made <= fault_tolerance: # Added mistakes_made check here
76
+ return True, final_pos, path_taken_len
77
+ return False, final_pos, path_taken_len
78
+
79
+
80
+ def parse_args():
81
+ """Parses command-line arguments for maze analysis."""
82
+ parser = argparse.ArgumentParser(description="Analyze Continuous Thought Machine on Maze Tasks")
83
+ parser.add_argument('--actions', type=str, nargs='+', default=['gen'], help="Actions: 'viz', 'gen'")
84
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help="GPU device index or -1 for CPU")
85
+ parser.add_argument('--checkpoint', type=str, default='checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt', help="Path to CTM checkpoint")
86
+ parser.add_argument('--output_dir', type=str, default='tasks/mazes/analysis/outputs', help="Directory for analysis outputs")
87
+ parser.add_argument('--dataset_for_viz', type=str, default='large', help="Dataset for 'viz' action")
88
+ parser.add_argument('--dataset_for_gen', type=str, default='extralarge', help="Dataset for 'gen' action")
89
+ parser.add_argument('--batch_size_test', type=int, default=32, help="Batch size for loading test data for 'viz'")
90
+ parser.add_argument('--max_reapplications', type=int, default=20, help="Maximum number of times the model is re-applied per maze when testing generalisation to extra-large mazes")
91
+ parser.add_argument('--legacy_scaling', action=argparse.BooleanOptionalAction, default=True, help='Legacy checkpoints scale between 0 and 1, new ones can scale -1 to 1.')
92
+ return parser.parse_args()
93
+
94
+ def _load_ctm_model(checkpoint_path, device):
95
+ """Loads the ContinuousThoughtMachine model from a checkpoint."""
96
+ print(f"Loading checkpoint: {checkpoint_path}")
97
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
98
+ model_args = checkpoint['args']
99
+
100
+ # Handle legacy arguments for model_args
101
+ if not hasattr(model_args, 'backbone_type') and hasattr(model_args, 'resnet_type'):
102
+ model_args.backbone_type = f'{model_args.resnet_type}-{getattr(model_args, "resnet_feature_scales", [4])[-1]}'
103
+
104
+ # Ensure prediction_reshaper is derived correctly
105
+ # Assuming out_dims exists and is used for this
106
+ prediction_reshaper = [model_args.out_dims // 5, 5] if hasattr(model_args, 'out_dims') else None
107
+
108
+
109
+ if not hasattr(model_args, 'neuron_select_type'):
110
+ model_args.neuron_select_type = 'first-last'
111
+ if not hasattr(model_args, 'n_random_pairing_self'):
112
+ model_args.n_random_pairing_self = 0
113
+
114
+ print("Instantiating CTM model...")
115
+ model = ContinuousThoughtMachine(
116
+ iterations=model_args.iterations,
117
+ d_model=model_args.d_model,
118
+ d_input=model_args.d_input,
119
+ heads=model_args.heads,
120
+ n_synch_out=model_args.n_synch_out,
121
+ n_synch_action=model_args.n_synch_action,
122
+ synapse_depth=model_args.synapse_depth,
123
+ memory_length=model_args.memory_length,
124
+ deep_nlms=model_args.deep_memory, # Mapping from model_args.deep_memory
125
+ memory_hidden_dims=model_args.memory_hidden_dims,
126
+ do_layernorm_nlm=model_args.do_normalisation, # Mapping from model_args.do_normalisation
127
+ backbone_type=model_args.backbone_type,
128
+ positional_embedding_type=model_args.positional_embedding_type,
129
+ out_dims=model_args.out_dims,
130
+ prediction_reshaper=prediction_reshaper,
131
+ dropout=0, # Explicitly setting dropout to 0 as in original
132
+ neuron_select_type=model_args.neuron_select_type,
133
+ n_random_pairing_self=model_args.n_random_pairing_self,
134
+ ).to(device)
135
+
136
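+ load_result = model.load_state_dict(checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint['model_state_dict'], strict=False) # Training scripts in this commit save weights under 'model_state_dict'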
137
+ print(f"Loaded state_dict. Missing keys: {load_result.missing_keys}, Unexpected keys: {load_result.unexpected_keys}")
138
+ model.eval()
139
+ return model
140
+
141
+ # --- Main Execution Block ---
142
+ if __name__=='__main__':
143
+ args = parse_args()
144
+
145
+ if args.device[0] != -1 and torch.cuda.is_available():
146
+ device = f'cuda:{args.device[0]}'
147
+ else:
148
+ device = 'cpu'
149
+ print(f"Using device: {device}")
150
+
151
+ palette = sns.color_palette("husl", 8)
152
+ cmap = plt.get_cmap('gist_rainbow')
153
+
154
+ # --- Generalisation Action ('gen') ---
155
+ if 'gen' in args.actions:
156
+ model = _load_ctm_model(args.checkpoint, device)
157
+
158
+ print(f"\n--- Running Generalisation Analysis ('gen'): {args.dataset_for_gen} ---")
159
+ target_dataset_name = f'{args.dataset_for_gen}'
160
+ data_root = f'data/mazes/{target_dataset_name}/test'
161
+ max_target_route_len = 50 # Specific to 'gen' action
162
+
163
+ test_data = MazeImageFolder(
164
+ root=data_root, which_set='test',
165
+ maze_route_length=max_target_route_len,
166
+ expand_range=not args.legacy_scaling, # Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
167
+ trunc=True
168
+ )
169
+ # Load a single large batch for 'gen'
170
+ testloader = torch.utils.data.DataLoader(
171
+ test_data, batch_size=min(len(test_data), 2000),
172
+ shuffle=False, num_workers=1
173
+ )
174
+ inputs, targets = next(iter(testloader))
175
+
176
+ actual_lengths = (targets != 4).sum(dim=-1)
177
+ sorted_indices = torch.argsort(actual_lengths, descending=True)
178
+ inputs, targets, actual_lengths = inputs[sorted_indices], targets[sorted_indices], actual_lengths[sorted_indices]
179
+
180
+ test_how_many = min(1000, len(inputs))
181
+ print(f"Processing {test_how_many} mazes sorted by length...")
182
+
183
+ results = {}
184
+ fault_tolerance = 2 # Specific to 'gen' analysis
185
+ output_gen_dir = os.path.join(args.output_dir, 'gen', args.dataset_for_gen)
186
+ os.makedirs(output_gen_dir, exist_ok=True)
187
+
188
+ for n_tested in range(test_how_many):
189
+ maze_actual_length = actual_lengths[n_tested].item()
190
+ maze_idx_display = n_tested + 1
191
+ print(f"Testing maze {maze_idx_display}/{test_how_many} (Len: {maze_actual_length})...")
192
+
193
+ initial_input_maze = inputs[n_tested:n_tested+1].clone().to(device)
194
+ maze_output_dir = os.path.join(output_gen_dir, f"maze_{maze_idx_display}")
195
+
196
+ re_applications = 0
197
+ has_solved = False
198
+ current_input_maze = initial_input_maze
199
+ exclusions = []
200
+ long_frames = []
201
+ ongoing_solution_img = None
202
+
203
+ while not has_solved and re_applications < args.max_reapplications:
204
+ re_applications += 1
205
+ with torch.no_grad():
206
+ predictions, certainties, _, _, _, attention_tracking = model(current_input_maze, track=True)
207
+
208
+ h_feat, w_feat = model.kv_features.shape[-2:]
209
+ attention_tracking = attention_tracking.reshape(attention_tracking.shape[0], -1, h_feat, w_feat)
210
+
211
+ n_steps_viz = predictions.shape[-1] # Use a different name to avoid conflict if n_steps is used elsewhere
212
+ step_linspace = np.linspace(0, 1, n_steps_viz)
213
+ current_maze_np = current_input_maze[0].permute(1,2,0).detach().cpu().numpy()
214
+
215
+ for stepi in range(n_steps_viz):
216
+ pred_route = predictions[0, :, stepi].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
217
+ frame = draw_path(current_maze_np, pred_route)
218
+ if attention_tracking is not None and stepi < attention_tracking.shape[0]:
219
+ try:
220
+ attn = attention_tracking[stepi].mean(0)
221
+ attn_resized = cv2.resize(attn, (current_maze_np.shape[1], current_maze_np.shape[0]), interpolation=cv2.INTER_LINEAR)
222
+ if attn_resized.max() > attn_resized.min():
223
+ attn_norm = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min())
224
+ attn_norm[attn_norm < np.percentile(attn_norm, 80)] = 0.0
225
+ frame = np.clip((np.copy(frame)*(1-attn_norm[:,:,np.newaxis])*1 + (attn_norm[:,:,np.newaxis]*0.8 * np.reshape(np.array(cmap(step_linspace[stepi]))[:3], (1, 1, 3)))), 0, 1)
226
+ except Exception: # Keep broad except for visualization robustness
227
+ pass
228
+ frame_resized = cv2.resize(frame, (int(current_maze_np.shape[1]*4), int(current_maze_np.shape[0]*4)), interpolation=cv2.INTER_NEAREST) # cv2.resize takes (width, height): shape[1] is width, shape[0] is height
229
+ long_frames.append((np.clip(frame_resized, 0, 1) * 255).astype(np.uint8))
230
+
231
+ where_most_certain = certainties[0, 1].argmax().item()
232
+ chosen_pred_route = predictions[0, :, where_most_certain].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
233
+ current_start_loc_list = np.argwhere((current_maze_np == [1, 0, 0]).all(axis=2)).tolist()
234
+
235
+ # Ensure current_start_loc_list is not empty before trying to access its elements
236
+ if not current_start_loc_list:
237
+ print(f"Warning: Could not find start location in maze {maze_idx_display} during reapplication {re_applications}. Stopping reapplication.")
238
+ break # Cannot proceed without a start location
239
+
240
+ solved_now, final_pos, _ = has_solved_checker(current_maze_np, chosen_pred_route, True, fault_tolerance, exclusions)
241
+
242
+ path_img = draw_path(current_maze_np, chosen_pred_route, cmap=cmap, valid_only=True)
243
+ if ongoing_solution_img is None:
244
+ ongoing_solution_img = path_img
245
+ else:
246
+ mask = (np.any(ongoing_solution_img!=path_img, -1))&(~np.all(path_img==[1,1,1], -1))&(~np.all(ongoing_solution_img==[1,0,0], -1))
247
+ ongoing_solution_img[mask] = path_img[mask]
248
+
249
+ if solved_now:
250
+ has_solved = True
251
+ break
252
+
253
+ if tuple(current_start_loc_list[0]) == final_pos:
254
+ exclusions.append(tuple(current_start_loc_list[0]))
255
+
256
+ next_input = current_input_maze.clone()
257
+ old_start_idx = tuple(current_start_loc_list[0])
258
+ next_input[0, :, old_start_idx[0], old_start_idx[1]] = 1.0 # Reset old start to path
259
+
260
+ if 0 <= final_pos[0] < next_input.shape[2] and 0 <= final_pos[1] < next_input.shape[3]:
261
+ next_input[0, :, final_pos[0], final_pos[1]] = torch.tensor([1,0,0], device=device, dtype=next_input.dtype) # New start
262
+ else:
263
+ print(f"Warning: final_pos {final_pos} out of bounds for maze {maze_idx_display}. Stopping reapplication.")
264
+ break
265
+ current_input_maze = next_input
266
+
267
+ if has_solved:
268
+ print(f'Solved maze of length {maze_actual_length}! Saving...')
269
+ os.makedirs(maze_output_dir, exist_ok=True)
270
+ if ongoing_solution_img is not None:
271
+ cv2.imwrite(os.path.join(maze_output_dir, 'ongoing_solution.png'), (ongoing_solution_img * 255).astype(np.uint8)[:,:,::-1])
272
+ if long_frames:
273
+ save_frames_to_mp4([fm[:,:,::-1] for fm in long_frames], os.path.join(maze_output_dir, f'combined_process.mp4'), fps=45, gop_size=10, preset='veryslow', crf=20)
274
+ else:
275
+ print(f'Failed maze of length {maze_actual_length} after {re_applications} reapplications. Not saving visuals for this maze.')
276
+
277
+ if maze_actual_length not in results: results[maze_actual_length] = []
278
+ results[maze_actual_length].append((has_solved, re_applications))
279
+
280
+ fig_success, ax_success = plt.subplots()
281
+ fig_reapp, ax_reapp = plt.subplots()
282
+ sorted_lengths = sorted(results.keys())
283
+ if sorted_lengths:
284
+ success_rates = [np.mean([r[0] for r in results[l]]) * 100 for l in sorted_lengths]
285
+ reapps_mean = [np.mean([r[1] for r in results[l] if r[0]]) if any(r[0] for r in results[l]) else np.nan for l in sorted_lengths]
286
+ ax_success.plot(sorted_lengths, success_rates, linestyle='-', color=palette[0])
287
+ ax_reapp.plot(sorted_lengths, reapps_mean, linestyle='-', color=palette[5])
288
+ ax_success.set_xlabel('Route Length'); ax_success.set_ylabel('Success (%)')
289
+ ax_reapp.set_xlabel('Route Length'); ax_reapp.set_ylabel('Re-applications (Avg on Success)')
290
+ fig_success.tight_layout(pad=0.1); fig_reapp.tight_layout(pad=0.1)
291
+ fig_success.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-success_rate.png'), dpi=200)
292
+ fig_success.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-success_rate.pdf'), dpi=200)
293
+ fig_reapp.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-re-applications.png'), dpi=200)
294
+ fig_reapp.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-re-applications.pdf'), dpi=200)
295
+ plt.close(fig_success); plt.close(fig_reapp)
296
+ np.savez(os.path.join(output_gen_dir, f'{args.dataset_for_gen}_results.npz'), results=results)
297
+
298
+ print("\n--- Generalisation Analysis ('gen') Complete ---")
299
+
300
+ # --- Visualization Action ('viz') ---
301
+ if 'viz' in args.actions:
302
+ model = _load_ctm_model(args.checkpoint, device)
303
+
304
+ print(f"\n--- Running Visualization ('viz'): {args.dataset_for_viz} ---")
305
+ output_viz_dir = os.path.join(args.output_dir, 'viz')
306
+ os.makedirs(output_viz_dir, exist_ok=True)
307
+
308
+ target_dataset_name = f'{args.dataset_for_viz}'
309
+ data_root = f'data/mazes/{target_dataset_name}/test'
310
+ test_data = MazeImageFolder(
311
+ root=data_root, which_set='test',
312
+ maze_route_length=100, # Max route length for viz data
313
+ expand_range=not args.legacy_scaling, # Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
314
+ trunc=True
315
+ )
316
+ testloader = torch.utils.data.DataLoader(
317
+ test_data, batch_size=args.batch_size_test,
318
+ shuffle=False, num_workers=1
319
+ )
320
+
321
+ all_inputs, all_targets, all_lengths = [], [], []
322
+ for b_in, b_tgt in testloader:
323
+ all_inputs.append(b_in)
324
+ all_targets.append(b_tgt)
325
+ all_lengths.append((b_tgt != 4).sum(dim=-1))
326
+
327
+ if not all_inputs:
328
+ print("Error: No data in visualization loader. Exiting 'viz' action.")
329
+ exit()
330
+
331
+ all_inputs, all_targets, all_lengths = torch.cat(all_inputs), torch.cat(all_targets), torch.cat(all_lengths)
332
+
333
+ num_viz_mazes = 10
334
+ num_viz_mazes = min(num_viz_mazes, len(all_lengths))
335
+
336
+ if num_viz_mazes == 0:
337
+ print("Error: No mazes found to visualize. Exiting 'viz' action.")
338
+ exit()
339
+
340
+ top_indices = torch.argsort(all_lengths, descending=True)[:num_viz_mazes]
341
+ inputs_viz, targets_viz = all_inputs[top_indices].to(device), all_targets[top_indices]
342
+
343
+ print(f"Visualizing {len(inputs_viz)} longest mazes...")
344
+
345
+ with torch.no_grad():
346
+ predictions, _, _, _, _, attention_tracking = model(inputs_viz, track=True)
347
+
348
+ # Reshape attention: (Steps, Batch, Heads, H_feat, W_feat) assuming model.kv_features has H_feat, W_feat
349
+ # The original reshape was slightly different, this tries to match the likely intended dimensions for per-step, per-batch item attention
350
+ if attention_tracking is not None and hasattr(model, 'kv_features') and model.kv_features is not None:
351
+ attention_tracking = attention_tracking.reshape(
352
+ attention_tracking.shape[0], # Iterations/Steps
353
+ inputs_viz.size(0), # Batch size (num_viz_mazes)
354
+ -1, # Heads (inferred)
355
+ model.kv_features.shape[-2], # H_feat
356
+ model.kv_features.shape[-1] # W_feat
357
+ )
358
+ else:
359
+ attention_tracking = None # Ensure it's None if it can't be reshaped
360
+ print("Warning: Could not reshape attention_tracking. Visualizations may not include attention overlays.")
361
+
362
+
363
+ for maze_i in range(inputs_viz.size(0)):
364
+ maze_idx_display = maze_i + 1
365
+ maze_output_dir = os.path.join(output_viz_dir, f"maze_{maze_idx_display}")
366
+ os.makedirs(maze_output_dir, exist_ok=True)
367
+
368
+ current_input_np_original = inputs_viz[maze_i].permute(1,2,0).detach().cpu().numpy()
369
+ # Apply scaling for visualization based on legacy_scaling: Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
370
+ current_input_np_display = (current_input_np_original + 1) / 2 if not args.legacy_scaling else current_input_np_original
371
+
372
+ current_target_route = targets_viz[maze_i].detach().cpu().numpy()
373
+ print(f"Generating viz for maze {maze_idx_display}...")
374
+
375
+ try:
376
+ solution_maze_img = draw_path(current_input_np_display, current_target_route, gt=True)
377
+ cv2.imwrite(os.path.join(maze_output_dir, 'solution_ground_truth.png'), (solution_maze_img * 255).astype(np.uint8)[:,:,::-1])
378
+ except Exception: # Keep broad except for visualization robustness
379
+ print(f"Could not save ground truth solution for maze {maze_idx_display}")
380
+ pass
381
+
382
+ frames = []
383
+ n_steps_viz = predictions.shape[-1] # Use a different name
384
+ step_linspace = np.linspace(0, 1, n_steps_viz)
385
+
386
+ for stepi in range(n_steps_viz):
387
+ pred_route = predictions[maze_i, :, stepi].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
388
+ frame = draw_path(current_input_np_display, pred_route)
389
+
390
+ if attention_tracking is not None and stepi < attention_tracking.shape[0] and maze_i < attention_tracking.shape[1]:
391
+
392
+ # Attention for current step (stepi) and current maze in batch (maze_i), average over heads
393
+ attn = attention_tracking[stepi, maze_i].mean(0)
394
+ attn_resized = cv2.resize(attn, (current_input_np_display.shape[1], current_input_np_display.shape[0]), interpolation=cv2.INTER_LINEAR)
395
+ if attn_resized.max() > attn_resized.min():
396
+ attn_norm = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min())
397
+ attn_norm[attn_norm < np.percentile(attn_norm, 80)] = 0.0
398
+ frame = np.clip((np.copy(frame)*(1-attn_norm[:,:,np.newaxis])*0.9 + (attn_norm[:,:,np.newaxis]*1.2 * np.reshape(np.array(cmap(step_linspace[stepi]))[:3], (1, 1, 3)))), 0, 1)
399
+
400
+
401
+ frame_resized = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_NEAREST)
402
+ frames.append((np.clip(frame_resized, 0, 1) * 255).astype(np.uint8))
403
+
404
+ if frames:
405
+ imageio.mimsave(os.path.join(maze_output_dir, 'attention_overlay.gif'), frames, fps=15, loop=0)
406
+
407
+ print("\n--- Visualization Action ('viz') Complete ---")
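The per-step overlay in the visualization loop above can be illustrated in isolation. The following is a minimal sketch, not the repository's code: `overlay_attention` and `keep_percentile` are hypothetical names, and it assumes the attention map has already been resized to the frame's height and width (the original does this with `cv2.resize`). It min-max normalises the map, zeroes everything below the 80th percentile, and blends a single colour into the frame using the same 0.9 and 1.2 weights as the loop.

```python
import numpy as np

def overlay_attention(frame, attn, colour, keep_percentile=80):
    """Blend a single-colour attention highlight onto an RGB frame in [0, 1].

    Hypothetical helper: assumes `attn` has the same height/width as `frame`.
    """
    lo, hi = attn.min(), attn.max()
    if hi <= lo:
        return frame.copy()  # flat attention map: nothing to highlight
    attn_norm = (attn - lo) / (hi - lo)  # new array, input is not modified
    # Keep only the strongest ~20% of attention values.
    attn_norm[attn_norm < np.percentile(attn_norm, keep_percentile)] = 0.0
    weight = attn_norm[:, :, np.newaxis]
    colour = np.reshape(np.asarray(colour, dtype=float)[:3], (1, 1, 3))
    # Darken the frame where attention is high, then add the colour on top.
    return np.clip(frame * (1 - weight) * 0.9 + weight * 1.2 * colour, 0, 1)

frame = np.full((4, 4, 3), 0.5)
attn = np.arange(16, dtype=float).reshape(4, 4)
out = overlay_attention(frame, attn, colour=(0.0, 0.0, 1.0))
print(out[0, 0], out[3, 3])
```

Cells below the percentile cut are only dimmed by the 0.9 factor, while the most-attended cell takes on the overlay colour.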
tasks/mazes/plotting.py ADDED
@@ -0,0 +1,214 @@
+
+ import numpy as np
+ import cv2
+ import torch
+ import os
+ import matplotlib.pyplot as plt
+ import imageio
+
+ from tqdm.auto import tqdm
+
+ def find_center_of_mass(array_2d):
+     """
+     Computes the intensity-weighted center of mass using np.average and np.mgrid.
+
+     Args:
+         array_2d: A 2D numpy array of values between 0 and 1.
+
+     Returns:
+         A tuple (y, x), i.e. (row, column), with the coordinates of the center of mass
+         rounded to 4 decimals, or (nan, nan) if the array sums to zero.
+     """
+     total_mass = np.sum(array_2d)
+     if total_mass == 0:
+         return (np.nan, np.nan)
+
+     y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
+     x_center = np.average(x_coords, weights=array_2d)
+     y_center = np.average(y_coords, weights=array_2d)
+     return (round(y_center, 4), round(x_center, 4))
+
+ def draw_path(x, route, valid_only=False, gt=False, cmap=None):
+     """
+     Draws a path on a maze image based on a given route.
+
+     Args:
+         x: A numpy array of shape (H, W, 3) representing the maze image. The start cell is [1, 0, 0] (red), the end cell is [0, 1, 0] (green), and walls are [0, 0, 0].
+         route: A sequence of integers representing the route, where 0 is up, 1 is down, 2 is left, 3 is right, and 4 is "do nothing".
+         valid_only: A boolean indicating whether to only draw valid steps (i.e., steps that don't go into walls).
+         gt: Not used inside this function; callers pass gt=True for ground-truth routes.
+         cmap: Optional matplotlib colormap for the path. Defaults to 'winter', or 'summer' if valid_only is set.
+
+     Returns:
+         A copy of the maze image with the path blended in, coloured by step order using the colormap.
+     """
+     x = np.copy(x)
+     start = np.argwhere((x == [1, 0, 0]).all(axis=2))
+     end = np.argwhere((x == [0, 1, 0]).all(axis=2))
+     if cmap is None:
+         cmap = plt.get_cmap('winter') if not valid_only else plt.get_cmap('summer')
+
+     # Initialize the current position
+     current_pos = start[0]
+
+     # Draw the path
+     colors = cmap(np.linspace(0, 1, len(route)))
+     si = 0
+     for step in route:
+         new_pos = current_pos
+         if step == 0: # Up
+             new_pos = (current_pos[0] - 1, current_pos[1])
+         elif step == 1: # Down
+             new_pos = (current_pos[0] + 1, current_pos[1])
+         elif step == 2: # Left
+             new_pos = (current_pos[0], current_pos[1] - 1)
+         elif step == 3: # Right
+             new_pos = (current_pos[0], current_pos[1] + 1)
+         elif step == 4: # Do nothing
+             pass
+         else:
+             raise ValueError("Invalid step: {}".format(step))
+
+         # Check if the new position is valid
+         if valid_only:
+             try:
+                 if np.all(x[new_pos] == [0,0,0]): # Check if it's a wall
+                     continue # Skip this step if it's invalid
+             except IndexError:
+                 continue # Skip this step if it's out of bounds
+
+         # Draw the step
+         if new_pos[0] >= 0 and new_pos[0] < x.shape[0] and new_pos[1] >= 0 and new_pos[1] < x.shape[1]:
+             if not ((x[new_pos] == [1,0,0]).all() or (x[new_pos] == [0,1,0]).all()):
+                 colour = colors[si][:3]
+                 si += 1
+                 x[new_pos] = x[new_pos]*0.5 + colour*0.5
+
+         # Update the current position
+         current_pos = new_pos
+         # cv2.imwrite('maze2.png', x[:,:,::-1]*255)
+
+     return x
90
+
91
+ def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location, verbose=True):
92
+ """
93
+ Expect inputs, predictions, targets as numpy arrays
94
+ """
95
+ route_steps = []
96
+ route_colours = []
97
+ solution_maze = draw_path(np.moveaxis(inputs, 0, -1), targets)
98
+
99
+ n_heads = attention_tracking.shape[1]
100
+ mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
101
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
102
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
103
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
104
+ ['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
105
+ ['head_8', 'head_9', 'head_10', 'head_11', 'head_12', 'head_13', 'head_14', 'head_15'],
106
+ ]
107
+ if n_heads == 8:
108
+ mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
109
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
110
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
111
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
112
+ ['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
113
+ ]
114
+ elif n_heads == 4:
115
+ mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
116
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
117
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
118
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
119
+ ['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'],
120
+ ['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'],
121
+ ]
122
+
123
+ img_aspect = 1
124
+ figscale = 1
125
+ aspect_ratio = (len(mosaic[0]) * figscale, len(mosaic) * figscale * img_aspect) # W, H
126
+
127
+ route_steps = [np.unravel_index(np.argmax((inputs == np.reshape(np.array([1, 0, 0]), (3, 1, 1))).all(0)), inputs.shape[1:])] # Starting point
128
+ frames = []
129
+ cmap = plt.get_cmap('gist_rainbow')
130
+ cmap_viridis = plt.get_cmap('viridis')
131
+ step_linspace = np.linspace(0, 1, predictions.shape[-1]) # For sampling colours
132
+ with tqdm(total=predictions.shape[-1], initial=0, leave=True, position=1, dynamic_ncols=True) as pbar:
133
+ if verbose: pbar.set_description('Processing frames for maze plotting')
134
+ for stepi in np.arange(0, predictions.shape[-1], 1):
135
+ fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
136
+ for ax in axes.values():
137
+ ax.axis('off')
138
+ guess_maze = draw_path(np.moveaxis(inputs, 0, -1), predictions.argmax(1)[:,stepi], cmap=cmap)
139
+ attention_now = attention_tracking[stepi]
140
+ for hi in range(min((attention_tracking.shape[1], 16))):
141
+ ax = axes[f'head_{hi}']
142
+ attn = attention_tracking[stepi, hi]
143
+ attn = (attn - attn.min())/(np.ptp(attn))
144
+ ax.imshow(attn, cmap=cmap_viridis)
145
+ # Upsample attention just for visualisation
146
+ aggregated_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy()
147
+
148
+ # Get approximate center of mass
149
+ com_attn = np.copy(aggregated_attention)
150
+ com_attn[com_attn < np.percentile(com_attn, 96)] = 0.0
151
+ aggregated_attention[aggregated_attention < np.percentile(aggregated_attention, 80)] = 0.0
152
+ route_steps.append(find_center_of_mass(com_attn))
153
+
154
+
155
+ colour = list(cmap(step_linspace[stepi]))
156
+ route_colours.append(colour)
157
+
158
+ mapped_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy()
159
+ mapped_attention = (mapped_attention - mapped_attention.min())/np.ptp(mapped_attention)
160
+ # np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.5) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.3, 0, 1)
161
+ overlay_img = np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.6) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.1, 0, 1)#np.clip((np.copy(guess_maze)*(1-aggregated_attention[:,:,np.newaxis])*0.7 + (aggregated_attention[:,:,np.newaxis]*3 * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1)
162
+ axes['overlay'].imshow(overlay_img)
163
+
164
+ y_coords, x_coords = zip(*route_steps)
165
+ y_coords = inputs.shape[-1] - np.array(list(y_coords))-1
166
+
167
+
168
+ axes['route'].imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower')
169
+ # ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
170
+ arrow_scale = 2
171
+ for i in range(len(route_steps)-1):
172
+ dx = x_coords[i+1] - x_coords[i]
173
+ dy = y_coords[i+1] - y_coords[i]
174
+ axes['route'].arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
175
+
176
+ fig.tight_layout(pad=0.1) # Adjust spacing
177
+
178
+ # Render the plot to a numpy array
179
+ canvas = fig.canvas
180
+ canvas.draw()
181
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
182
+ image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
183
+
184
+ frames.append(image_numpy) # Add to list for GIF
185
+
186
+ # fig.savefig(f'{save_location}/frame.png', dpi=200)
187
+
188
+ plt.close(fig)
189
+
190
+ # # frame = np.clip((np.copy(guess_maze)*0.5 + (aggregated_attention[:,:,np.newaxis] * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1)
191
+ # frame = torch.nn.functional.interpolate(torch.from_numpy(frame).permute(2,0,1).unsqueeze(0), 256)[0].permute(1,2,0).detach().cpu().numpy()
192
+ # frames.append((frame*255).astype(np.uint8))
193
+ pbar.update(1)
194
+
195
+
196
+ y_coords, x_coords = zip(*route_steps)
197
+ y_coords = inputs.shape[-1] - np.array(list(y_coords))-1
198
+
199
+ fig = plt.figure(figsize=(5,5))
200
+ ax = fig.add_subplot(111)
201
+
202
+ ax.imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower')
203
+ # ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
204
+ arrow_scale = 2
205
+ for i in range(len(route_steps)-1):
206
+ dx = x_coords[i+1] - x_coords[i]
207
+ dy = y_coords[i+1] - y_coords[i]
208
+ plt.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
209
+
210
+ ax.axis('off')
211
+ fig.tight_layout(pad=0)
212
+ fig.savefig(f'{save_location}/route_approximation.png', dpi=200)
213
+ imageio.mimsave(f'{save_location}/prediction.gif', frames, fps=15, loop=100)
214
+ plt.close(fig)
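The route approximation in `make_maze_gif` relies on `find_center_of_mass` returning coordinates in (row, column) order, which are then flipped vertically for plotting with `origin='lower'`. The following is a minimal, self-contained sketch of that idea on a toy array. `center_of_mass_rc` is a hypothetical stand-in for the function above (without the rounding), and the 4x5 map is made-up data.

```python
import numpy as np

def center_of_mass_rc(weights):
    """Return the (row, col) intensity-weighted centroid of a 2D array.

    Hypothetical stand-in for find_center_of_mass; returns (nan, nan)
    when the array sums to zero.
    """
    weights = np.asarray(weights, dtype=float)
    if weights.sum() == 0:
        return (np.nan, np.nan)
    rows, cols = np.mgrid[:weights.shape[0], :weights.shape[1]]
    return (float(np.average(rows, weights=weights)),
            float(np.average(cols, weights=weights)))

# Toy 4x5 attention map with all of its mass in two cells of row 1.
attn = np.zeros((4, 5))
attn[1, 2] = 1.0
attn[1, 4] = 1.0

row, col = center_of_mass_rc(attn)

# Flip the row index so the point lands correctly on an image that is
# drawn with origin='lower' (row 0 at the bottom).
y_plot = attn.shape[0] - row - 1
print(row, col, y_plot)
```

Unpacking the result as `(row, col)` rather than `(x, y)` matters here: swapping the two would mirror the plotted route across the diagonal.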
tasks/mazes/scripts/train_ctm.sh ADDED
@@ -0,0 +1,35 @@
+ python -m tasks.mazes.train \
+     --model ctm \
+     --log_dir logs/mazes/ctm/d=2048--i=512--heads=16--sd=8--nlm=32--synch=64-32-h=32-first-last--iters=75x25--backbone=34-2 \
+     --neuron_select_type first-last \
+     --dataset mazes-large \
+     --synapse_depth 8 \
+     --heads 16 \
+     --iterations 75 \
+     --memory_length 25 \
+     --d_model 2048 \
+     --d_input 512 \
+     --backbone_type resnet34-2 \
+     --n_synch_out 64 \
+     --n_synch_action 32 \
+     --memory_hidden_dims 32 \
+     --deep_memory \
+     --weight_decay 0.000 \
+     --batch_size 64 \
+     --batch_size_test 128 \
+     --n_test_batches 20 \
+     --gradient_clipping -1 \
+     --use_scheduler \
+     --scheduler_type cosine \
+     --warmup_steps 10000 \
+     --training_iterations 1000001 \
+     --no-do_normalisation \
+     --track_every 1000 \
+     --lr 1e-4 \
+     --no-reload \
+     --dropout 0.1 \
+     --positional_embedding_type none \
+     --maze_route_length 100 \
+     --cirriculum_lookahead 5 \
+     --device 0 \
+     --no-expand_range
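The `--no-reload`, `--no-do_normalisation` and `--no-expand_range` flags in the script above work because `tasks/mazes/train.py` declares those options with `argparse.BooleanOptionalAction` (Python 3.9+), which registers a negated `--no-<name>` form alongside each `--<name>` flag. A minimal sketch of that behaviour, using a cut-down parser with only three of the options (`build_parser` is a hypothetical helper, not part of the repository):

```python
import argparse

def build_parser():
    """Cut-down parser with three of the boolean options from train.py."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument('--expand_range', action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False)
    return parser

parser = build_parser()

# No flags: every option keeps its declared default.
defaults = parser.parse_args([])

# The negated forms used in the training script force the values to False.
overridden = parser.parse_args(['--no-expand_range', '--no-reload', '--no-do_normalisation'])

print(defaults.expand_range, overridden.expand_range)
```

Since `--expand_range` defaults to `True`, `--no-expand_range` is the only one of the three that changes behaviour here; the other two restate their defaults, which makes the script explicit about the configuration it expects.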
tasks/mazes/train.py ADDED
@@ -0,0 +1,704 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import seaborn as sns
8
+ sns.set_style('darkgrid')
9
+ import torch
10
+ if torch.cuda.is_available():
11
+ # For faster float32 matmuls (at reduced internal precision)
12
+ torch.set_float32_matmul_precision('high')
13
+ from tqdm.auto import tqdm
14
+
15
+ from data.custom_datasets import MazeImageFolder
16
+ from models.ctm import ContinuousThoughtMachine
17
+ from models.lstm import LSTMBaseline
18
+ from models.ff import FFBaseline
19
+ from tasks.mazes.plotting import make_maze_gif
20
+ from tasks.image_classification.plotting import plot_neural_dynamics
21
+ from utils.housekeeping import set_seed, zip_python_code
22
+ from utils.losses import maze_loss
23
+ from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
24
+
25
+ import torchvision
26
+ torchvision.disable_beta_transforms_warning()
27
+
28
+ import warnings
29
+ warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
30
+ warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
31
+ warnings.filterwarnings(
32
+ "ignore",
33
+ "Corrupt EXIF data",
34
+ UserWarning,
35
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
36
+ )
37
+ warnings.filterwarnings(
38
+ "ignore",
39
+ "UserWarning: Metadata Warning",
40
+ UserWarning,
41
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
42
+ )
43
+ warnings.filterwarnings(
44
+ "ignore",
45
+ "UserWarning: Truncated File Read",
46
+ UserWarning,
47
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
48
+ )
49
+
50
+
51
+ def parse_args():
52
+ parser = argparse.ArgumentParser()
53
+
54
+ # Model Selection
55
+ parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
56
+
57
+ # Model Architecture
58
+ # Common across all or most
59
+ parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
60
+ parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
61
+ parser.add_argument('--backbone_type', type=str, default='resnet34-2', help='Type of backbone featuriser.') # Default changed from original script
62
+ # CTM / LSTM specific
63
+ parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
64
+ parser.add_argument('--heads', type=int, default=8, help='Number of attention heads (CTM, LSTM).') # Default changed
65
+ parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
66
+ parser.add_argument('--positional_embedding_type', type=str, default='none',
67
+ help='Type of positional embedding (CTM, LSTM).', choices=['none',
68
+ 'learnable-fourier',
69
+ 'multi-learnable-fourier',
70
+ 'custom-rotational'])
71
+
72
+ # CTM specific
73
+ parser.add_argument('--synapse_depth', type=int, default=8, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).') # Default changed
74
+ parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).') # Default changed
75
+ parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).') # Default changed
76
+ parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
77
+ parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
78
+ parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMs (CTM only).')
79
+ parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True,
80
+ help='Use deep memory (CTM only).')
81
+ parser.add_argument('--memory_hidden_dims', type=int, default=32, help='Hidden dimensions of the memory if using deep memory (CTM only).') # Default changed
82
+ parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
83
+ parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
84
+ # LSTM specific
85
+ parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).') # Added LSTM arg
86
+
87
+ # Task Specific Args (Common to all models for this task)
88
+ parser.add_argument('--maze_route_length', type=int, default=100, help='Length to truncate targets.')
89
+ parser.add_argument('--cirriculum_lookahead', type=int, default=5, help='How far to look ahead for curriculum.')
90
+
91
+
92
+ # Training
93
+ parser.add_argument('--expand_range', action=argparse.BooleanOptionalAction, default=True, help='Mazes between 0 and 1 = False. Between -1 and 1 = True. Legacy checkpoints use 0 and 1.')
94
+ parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training.') # Default changed
95
+ parser.add_argument('--batch_size_test', type=int, default=64, help='Batch size for testing.') # Default changed
96
+ parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate for the model.') # Default changed
97
+ parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
98
+ parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
99
+ parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
100
+ parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
101
+ parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
102
+ parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
103
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
104
+ parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
105
+ parser.add_argument('--num_workers_train', type=int, default=0, help='Num workers training.') # Renamed from num_workers, kept default
106
+ parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
107
+ parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
108
+
109
+ # Logging and Saving
110
+ parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
111
+ parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large', 'mazes-small'])
112
+ parser.add_argument('--data_root', type=str, default='data/mazes', help='Data root.')
113
+
114
+ parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
115
+ parser.add_argument('--seed', type=int, default=412, help='Random seed.')
116
+ parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
117
+ parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
118
+ parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
119
+ parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading (for debugging)?') # Added back
120
+
121
+ # Tracking
122
+ parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
123
+ parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval') # Default changed
124
+
125
+ # Device
126
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
127
+ parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
128
+
129
+
130
+ args = parser.parse_args()
131
+ return args
132
+
133
+
134
+ if __name__=='__main__':
135
+
136
+ # Housekeeping
137
+ args = parse_args()
138
+
139
+ set_seed(args.seed, False)
140
+ if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
141
+
142
+ assert args.dataset in ['mazes-medium', 'mazes-large', 'mazes-small']
143
+
144
+
145
+
146
+ prediction_reshaper = [args.maze_route_length, 5] # Problem specific
147
+ args.out_dims = args.maze_route_length * 5 # Output dimension before reshaping
148
+
149
+ # For total reproducibility
150
+ zip_python_code(f'{args.log_dir}/repo_state.zip')
151
+ with open(f'{args.log_dir}/args.txt', 'w') as f:
152
+ print(args, file=f)
153
+
154
+ # Configure device string (support MPS on macOS)
155
+ if args.device[0] != -1:
156
+ device = f'cuda:{args.device[0]}'
157
+ elif torch.backends.mps.is_available():
158
+ device = 'mps'
159
+ else:
160
+ device = 'cpu'
161
+ print(f'Running model {args.model} on {device}')
162
+
163
+
164
+ # Build model conditionally
165
+ model = None
166
+ if args.model == 'ctm':
167
+ model = ContinuousThoughtMachine(
168
+ iterations=args.iterations,
169
+ d_model=args.d_model,
170
+ d_input=args.d_input,
171
+ heads=args.heads,
172
+ n_synch_out=args.n_synch_out,
173
+ n_synch_action=args.n_synch_action,
174
+ synapse_depth=args.synapse_depth,
175
+ memory_length=args.memory_length,
176
+ deep_nlms=args.deep_memory,
177
+ memory_hidden_dims=args.memory_hidden_dims,
178
+ do_layernorm_nlm=args.do_normalisation,
179
+ backbone_type=args.backbone_type,
180
+ positional_embedding_type=args.positional_embedding_type,
181
+ out_dims=args.out_dims,
182
+ prediction_reshaper=prediction_reshaper,
183
+ dropout=args.dropout,
184
+ dropout_nlm=args.dropout_nlm,
185
+ neuron_select_type=args.neuron_select_type,
186
+ n_random_pairing_self=args.n_random_pairing_self,
187
+ ).to(device)
188
+ elif args.model == 'lstm':
189
+ model = LSTMBaseline(
190
+ num_layers=args.num_layers,
191
+ iterations=args.iterations,
192
+ d_model=args.d_model,
193
+ d_input=args.d_input,
194
+ heads=args.heads,
195
+ backbone_type=args.backbone_type,
196
+ positional_embedding_type=args.positional_embedding_type,
197
+ out_dims=args.out_dims,
198
+ prediction_reshaper=prediction_reshaper,
199
+ dropout=args.dropout,
200
+ ).to(device)
201
+ elif args.model == 'ff':
202
+ model = FFBaseline(
203
+ d_model=args.d_model,
204
+ backbone_type=args.backbone_type,
205
+ out_dims=args.out_dims,
206
+ dropout=args.dropout,
207
+ ).to(device)
208
+ else:
209
+ raise ValueError(f"Unknown model type: {args.model}")
210
+
211
+ try:
212
+ # Determine pseudo input shape based on dataset
213
+ h_w = 39 if args.dataset in ['mazes-small', 'mazes-medium'] else 99 # Example dimensions
214
+ pseudo_inputs = torch.zeros((1, 3, h_w, h_w), device=device).float()
215
+ model(pseudo_inputs)
216
+ except Exception as e:
217
+ print(f"Warning: Pseudo forward pass failed: {e}")
218
+
219
+ print(f'Total params: {sum(p.numel() for p in model.parameters())}')
220
+
221
+ # Data
222
+ dataset_mean = [0,0,0] # For plotting later
223
+ dataset_std = [1,1,1]
224
+
225
+ which_maze = args.dataset.split('-')[-1]
226
+ data_root = f'{args.data_root}/{which_maze}'
227
+
228
+ train_data = MazeImageFolder(root=f'{data_root}/train/', which_set='train', maze_route_length=args.maze_route_length, expand_range=args.expand_range)
229
+ test_data = MazeImageFolder(root=f'{data_root}/test/', which_set='test', maze_route_length=args.maze_route_length, expand_range=args.expand_range)
230
+
231
+ num_workers_test = 1 # Defaulting to 1, can be changed
232
+ trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train, drop_last=True)
233
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)
234
+
235
+ # For lazy modules so that we can get param count
236
+
237
+
238
+ model.train()
239
+
240
+ # Optimizer and scheduler
241
+ decay_params = []
242
+ no_decay_params = []
243
+ no_decay_names = []
244
+ for name, param in model.named_parameters():
245
+ if not param.requires_grad:
246
+ continue # Skip parameters that don't require gradients
247
+ if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
248
+ no_decay_params.append(param)
249
+ no_decay_names.append(name)
250
+ else:
251
+ decay_params.append(param)
252
+ if len(no_decay_names):
253
+ print(f'WARNING, excluding: {no_decay_names}')
254
+
255
+ # Optimizer and scheduler (Common setup)
256
+ if len(no_decay_names) and args.weight_decay!=0:
257
+ optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
258
+ {'params': no_decay_params, 'weight_decay':0}],
259
+ lr=args.lr,
260
+ eps=1e-8 if not args.use_amp else 1e-6)
261
+ else:
262
+ optimizer = torch.optim.AdamW(model.parameters(),
263
+ lr=args.lr,
264
+ eps=1e-8 if not args.use_amp else 1e-6,
265
+ weight_decay=args.weight_decay)
266
+
267
+ warmup_schedule = warmup(args.warmup_steps)
268
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
269
+ if args.use_scheduler:
270
+ if args.scheduler_type == 'multistep':
271
+ scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
272
+ elif args.scheduler_type == 'cosine':
273
+ scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
274
+ else:
275
+ raise NotImplementedError
276
+
277
+
278
+ # Metrics tracking
279
+ start_iter = 0
280
+ train_losses = []
281
+ test_losses = []
282
+ train_accuracies = [] # Per tick/step accuracy list
283
+ test_accuracies = []
284
+ train_accuracies_most_certain = [] # Accuracy, fine-grained
285
+ test_accuracies_most_certain = []
286
+ train_accuracies_most_certain_permaze = [] # Full maze accuracy
287
+ test_accuracies_most_certain_permaze = []
288
+ iters = []
289
+
290
+ scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
291
+ if args.reload:
292
+ checkpoint_path = f'{args.log_dir}/checkpoint.pt'
293
+ if os.path.isfile(checkpoint_path):
294
+ print(f'Reloading from: {checkpoint_path}')
295
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
296
+ if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
297
+ load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
298
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
299
+
300
+ if not args.reload_model_only:
301
+ print('Reloading optimizer etc.')
302
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
303
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
304
+ scaler.load_state_dict(checkpoint['scaler_state_dict']) # Load scaler state
305
+ start_iter = checkpoint['iteration']
306
+
307
+ if not args.ignore_metrics_when_reloading:
308
+ train_losses = checkpoint['train_losses']
309
+ test_losses = checkpoint['test_losses']
310
+ train_accuracies = checkpoint['train_accuracies']
311
+ test_accuracies = checkpoint['test_accuracies']
312
+ iters = checkpoint['iters']
313
+ train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
314
+ test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
315
+ train_accuracies_most_certain_permaze = checkpoint['train_accuracies_most_certain_permaze']
316
+ test_accuracies_most_certain_permaze = checkpoint['test_accuracies_most_certain_permaze']
317
+ else:
318
+ print("Ignoring metrics history upon reload.")
319
+
320
+ else:
321
+ print('Only reloading model!')
322
+
323
+ if 'torch_rng_state' in checkpoint:
324
+ # Reset seeds
325
+ torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
326
+ np.random.set_state(checkpoint['numpy_rng_state'])
327
+ random.setstate(checkpoint['random_rng_state'])
328
+
329
+ del checkpoint
330
+ import gc
331
+ gc.collect()
332
+ if torch.cuda.is_available():
333
+ torch.cuda.empty_cache()
334
+
335
+ if args.do_compile:
336
+ print('Compiling...')
337
+ if hasattr(model, 'backbone'):
338
+ model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
339
+ # Compile synapses only for CTM
340
+ if args.model == 'ctm':
341
+ model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
342
+
343
+ # Training
344
+ iterator = iter(trainloader)
345
+ with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
346
+ for bi in range(start_iter, args.training_iterations):
347
+ current_lr = optimizer.param_groups[-1]['lr']
348
+
349
+ try:
350
+ inputs, targets = next(iterator)
351
+ except StopIteration:
352
+ iterator = iter(trainloader)
353
+ inputs, targets = next(iterator)
354
+
355
+ inputs = inputs.to(device)
356
+ targets = targets.to(device) # Shape (B, SeqLength)
357
+
358
+ # All for nice metric printing:
359
+ loss = None
360
+ accuracy_finegrained = None # Per-step accuracy at chosen tick
361
+ where_most_certain_val = -1.0 # Default value
362
+ where_most_certain_std = 0.0
363
+ where_most_certain_min = -1
364
+ where_most_certain_max = -1
365
+ upto_where_mean = -1.0
366
+ upto_where_std = 0.0
367
+ upto_where_min = -1
368
+ upto_where_max = -1
369
+
370
+
371
+ # Model-specific forward, reshape, and loss calculation
372
+ with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
373
+ if args.do_compile: # CUDAGraph marking applied if compiling any model
374
+ torch.compiler.cudagraph_mark_step_begin()
375
+
376
+ if args.model == 'ctm':
377
+ # CTM output: (B, SeqLength*5, Ticks), Certainties: (B, Ticks)
378
+ predictions_raw, certainties, synchronisation = model(inputs)
379
+ # Reshape predictions: (B, SeqLength, 5, Ticks)
380
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))
381
+ loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=True)
382
+ # Accuracy uses predictions[B, S, C, T] indexed at where_most_certain[B] -> gives (B, S, C) -> argmax(2) -> (B,S)
383
+ accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()
384
+
385
+ elif args.model == 'lstm':
386
+ # LSTM output: (B, SeqLength*5, Ticks), Certainties: (B, Ticks)
387
+ predictions_raw, certainties, synchronisation = model(inputs)
388
+ # Reshape predictions: (B, SeqLength, 5, Ticks)
389
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))
390
+ loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False)
391
+ # where_most_certain should be -1 (last tick) here. Accuracy calc follows same logic.
392
+ accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()
393
+
394
+ elif args.model == 'ff':
395
+ # Assume FF output: (B, SeqLength*5)
396
+ predictions_raw = model(inputs)
397
+ # Reshape predictions: (B, SeqLength, 5)
398
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5)
399
+ # FF has no certainties, pass None. maze_loss must handle this.
400
+ # Unsqueeze predictions for compatibility with maze loss calculation
401
+ loss, where_most_certain, upto_where = maze_loss(predictions.unsqueeze(-1), None, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False)
402
+ # where_most_certain should be -1 here. Accuracy uses 3D prediction tensor.
403
+ accuracy_finegrained = (predictions.argmax(2) == targets).float().mean().item()
404
+
405
+
406
+ # Extract stats from loss outputs if they are tensors
407
+ if torch.is_tensor(where_most_certain):
408
+ where_most_certain_val = where_most_certain.float().mean().item()
409
+ where_most_certain_std = where_most_certain.float().std().item()
410
+ where_most_certain_min = where_most_certain.min().item()
411
+ where_most_certain_max = where_most_certain.max().item()
412
+ elif isinstance(where_most_certain, int): # Handle case where it might return -1 directly
413
+ where_most_certain_val = float(where_most_certain)
414
+ where_most_certain_min = where_most_certain
415
+ where_most_certain_max = where_most_certain
416
+
417
+ if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0: # Check if it's a list/array
418
+ upto_where_mean = np.mean(upto_where)
419
+ upto_where_std = np.std(upto_where)
420
+ upto_where_min = np.min(upto_where)
421
+ upto_where_max = np.max(upto_where)
422
+
423
+
424
+ scaler.scale(loss).backward()
425
+
426
+ if args.gradient_clipping!=-1:
427
+ scaler.unscale_(optimizer)
428
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
429
+
430
+ scaler.step(optimizer)
431
+ scaler.update()
432
+ optimizer.zero_grad(set_to_none=True)
433
+ scheduler.step()
434
+
435
+ # Conditional Tqdm Description
436
+ pbar_desc = f'Loss={loss.item():0.3f}. Acc(step)={accuracy_finegrained:0.3f}. LR={current_lr:0.6f}.'
437
+ if args.model in ['ctm', 'lstm'] or torch.is_tensor(where_most_certain): # Show stats if available
438
+ pbar_desc += f' Where_certain={where_most_certain_val:0.2f}+-{where_most_certain_std:0.2f} ({where_most_certain_min:d}<->{where_most_certain_max:d}).'
439
+ if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
440
+ pbar_desc += f' Path pred stats: {upto_where_mean:0.2f}+-{upto_where_std:0.2f} ({upto_where_min:d} --> {upto_where_max:d})'
441
+
442
+ pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
443
+
444
+
445
+ # Metrics tracking and plotting
446
+ if bi%args.track_every==0 and (bi != 0 or args.reload_model_only):
447
+ model.eval() # Use eval mode for consistency during tracking
448
+ with torch.inference_mode(): # Use inference mode for tracking
449
+
450
+
451
+
452
+
453
+ # --- Quantitative Metrics ---
454
+ iters.append(bi)
455
+ # Re-initialize metric lists for this evaluation step
456
+ current_train_losses_eval = []
457
+ current_test_losses_eval = []
458
+ current_train_accuracies_eval = []
459
+ current_test_accuracies_eval = []
460
+ current_train_accuracies_most_certain_eval = []
461
+ current_test_accuracies_most_certain_eval = []
462
+ current_train_accuracies_most_certain_permaze_eval = []
463
+ current_test_accuracies_most_certain_permaze_eval = []
464
+
465
+ # TRAIN METRICS
466
+ pbar.set_description('Tracking: Computing TRAIN metrics')
467
+ loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test) # Use consistent num_workers
468
+ all_targets_list = []
469
+ all_predictions_list = [] # Per step/tick predictions argmax (N, S, T) or (N, S)
470
+ all_predictions_most_certain_list = [] # Predictions at chosen step/tick argmax (N, S)
471
+ all_losses = []
472
+
473
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
474
+ for inferi, (inputs, targets) in enumerate(loader):
475
+ inputs = inputs.to(device)
476
+ targets = targets.to(device)
477
+ all_targets_list.append(targets.detach().cpu().numpy()) # N x S
478
+
479
+ # Model-specific forward, reshape, loss for evaluation
480
+ if args.model == 'ctm':
481
+ predictions_raw, certainties, _ = model(inputs)
482
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
483
+ loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
484
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,C,T -> argmax class -> B,S,T
485
+ pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S
486
+ all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
487
+
488
+ elif args.model == 'lstm':
489
+ predictions_raw, certainties, _ = model(inputs)
490
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
491
+ loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
492
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,C,T -> argmax class -> B,S,T
493
+ pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S (at last tick)
494
+ all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
495
+
496
+ elif args.model == 'ff':
497
+ predictions_raw = model(inputs) # B, S*C
498
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
499
+ loss, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
500
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S
501
+ all_predictions_most_certain_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S (same as above for FF)
502
+
503
+
504
+ all_losses.append(loss.item())
505
+
506
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1 : break
507
+ pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
508
+ pbar_inner.update(1)
509
+
510
+ all_targets = np.concatenate(all_targets_list) # N, S
511
+ all_predictions = np.concatenate(all_predictions_list) # N, S, T or N, S
512
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list) # N, S
513
+
514
+ train_losses.append(np.mean(all_losses))
515
+ # Calculate per step/tick accuracy averaged over batches
516
+ if args.model in ['ctm', 'lstm']:
517
+ # all_predictions shape (N, S, T), all_targets shape (N, S) -> compare targets to each tick prediction
518
+ train_accuracies.append(np.mean(all_predictions == all_targets[:,:,np.newaxis], axis=0)) # Mean over N -> (S, T)
519
+ else: # FF
520
+ # all_predictions shape (N, S), all_targets shape (N, S)
521
+ train_accuracies.append(np.mean(all_predictions == all_targets, axis=0)) # Mean over N -> (S,)
522
+
523
+ # Calculate accuracy at chosen step/tick ("most certain") averaged over all steps and batches
524
+ train_accuracies_most_certain.append((all_targets == all_predictions_most_certain).mean()) # Scalar
525
+ # Calculate full maze accuracy at chosen step/tick averaged over batches
526
+ train_accuracies_most_certain_permaze.append((all_targets == all_predictions_most_certain).reshape(all_targets.shape[0], -1).all(-1).mean()) # Scalar
527
+
528
+
529
+ # TEST METRICS
530
+ pbar.set_description('Tracking: Computing TEST metrics')
531
+ loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
532
+ all_targets_list = []
533
+ all_predictions_list = []
534
+ all_predictions_most_certain_list = []
535
+ all_losses = []
536
+
537
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
538
+ for inferi, (inputs, targets) in enumerate(loader):
539
+ inputs = inputs.to(device)
540
+ targets = targets.to(device)
541
+ all_targets_list.append(targets.detach().cpu().numpy())
542
+
543
+ # Model-specific forward, reshape, loss for evaluation
544
+ if args.model == 'ctm':
545
+ predictions_raw, certainties, _ = model(inputs)
546
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
547
+ loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
548
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,T
549
+ pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S
550
+ all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
551
+
552
+ elif args.model == 'lstm':
553
+ predictions_raw, certainties, _ = model(inputs)
554
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
555
+ loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
556
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,T
557
+ pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S (at last tick)
558
+ all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
559
+
560
+ elif args.model == 'ff':
561
+ predictions_raw = model(inputs) # B, S*C
562
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
563
+ loss, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
564
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S
565
+ all_predictions_most_certain_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S (same as above for FF)
566
+
567
+
568
+ all_losses.append(loss.item())
569
+
570
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
571
+ pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
572
+ pbar_inner.update(1)
573
+
574
+ all_targets = np.concatenate(all_targets_list)
575
+ all_predictions = np.concatenate(all_predictions_list)
576
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
577
+
578
+ test_losses.append(np.mean(all_losses))
579
+ # Calculate per step/tick accuracy
580
+ if args.model in ['ctm', 'lstm']:
581
+ test_accuracies.append(np.mean(all_predictions == all_targets[:,:,np.newaxis], axis=0)) # -> (S, T)
582
+ else: # FF
583
+ test_accuracies.append(np.mean(all_predictions == all_targets, axis=0)) # -> (S,)
584
+
585
+ # Calculate "most certain" accuracy
586
+ test_accuracies_most_certain.append((all_targets == all_predictions_most_certain).mean()) # Scalar
587
+ # Calculate full maze accuracy
588
+ test_accuracies_most_certain_permaze.append((all_targets == all_predictions_most_certain).reshape(all_targets.shape[0], -1).all(-1).mean()) # Scalar
589
+
590
+
591
+ # --- Plotting ---
592
+ # Accuracy Plot (Handling different dimensions)
593
+ figacc = plt.figure(figsize=(10, 10))
594
+ axacc_train = figacc.add_subplot(211)
595
+ axacc_test = figacc.add_subplot(212)
596
+ cm = sns.color_palette("viridis", as_cmap=True)
597
+
598
+ # Plot per step/tick accuracy
599
+ # train_accuracies is List[(S, T)] or List[(S,)]
600
+ # Average over S (and T, when present) to get one scalar per evaluation for plotting
601
+ train_acc_plot = [np.mean(acc_s) for acc_s in train_accuracies] # List[Scalar]: mean over steps (and ticks, if any)
602
+ test_acc_plot = [np.mean(acc_s) for acc_s in test_accuracies] # List[Scalar]: mean over steps (and ticks, if any)
603
+
604
+ axacc_train.plot(iters, train_acc_plot, 'g-', alpha=0.5, label='Avg Step Acc')
605
+ axacc_test.plot(iters, test_acc_plot, 'g-', alpha=0.5, label='Avg Step Acc')
606
+
607
+
608
+ # Plot most certain accuracy
609
+ axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most Certain (Avg Step)')
610
+ axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most Certain (Avg Step)')
611
+ # Plot full maze accuracy
612
+ axacc_train.plot(iters, train_accuracies_most_certain_permaze, 'r-', alpha=0.6, label='Full Maze')
613
+ axacc_test.plot(iters, test_accuracies_most_certain_permaze, 'r-', alpha=0.6, label='Full Maze')
614
+
615
+ axacc_train.set_title('Train Accuracy')
616
+ axacc_test.set_title('Test Accuracy')
617
+ axacc_train.legend(loc='lower right')
618
+ axacc_test.legend(loc='lower right')
619
+ axacc_train.set_xlim([0, args.training_iterations])
620
+ axacc_test.set_xlim([0, args.training_iterations])
621
+ axacc_train.set_ylim([0, 1]) # Set Ylim for accuracy
622
+ axacc_test.set_ylim([0, 1])
623
+
624
+ figacc.tight_layout()
625
+ figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
626
+ plt.close(figacc)
627
+
628
+ # Loss Plot
629
+ figloss = plt.figure(figsize=(10, 5))
630
+ axloss = figloss.add_subplot(111)
631
+ axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
632
+ axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
633
+ axloss.legend(loc='upper right')
634
+ axloss.set_xlim([0, args.training_iterations])
635
+ axloss.set_ylim(bottom=0)
636
+
637
+ figloss.tight_layout()
638
+ figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
639
+ plt.close(figloss)
640
+
641
+ # --- Visualization Section (Conditional) ---
642
+ if args.model in ['ctm', 'lstm']:
643
+ # try:
644
+ inputs_viz, targets_viz = next(iter(testloader))
645
+ inputs_viz = inputs_viz.to(device)
646
+ targets_viz = targets_viz.to(device)
647
+ # Find longest path in batch for potentially better visualization
648
+ longest_index = (targets_viz!=4).sum(-1).argmax() # Action 4 assumed padding/end
649
+
650
+ # Track internal states
651
+ predictions_viz_raw, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
652
+
653
+ # Reshape predictions (assuming raw is B, D, T)
654
+ predictions_viz = predictions_viz_raw.reshape(predictions_viz_raw.size(0), -1, 5, predictions_viz_raw.size(-1)) # B, S, C, T
655
+
656
+ att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
657
+ attention_tracking_viz = attention_tracking_viz.reshape(
658
+ attention_tracking_viz.shape[0],
659
+ attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
660
+
661
+ # Plot dynamics (common plotting function)
662
+ plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
663
+
664
+ # Create maze GIF (task-specific plotting)
665
+ make_maze_gif((inputs_viz[longest_index].detach().cpu().numpy()+1)/2,
666
+ predictions_viz[longest_index].detach().cpu().numpy(), # Pass reshaped B,S,C,T -> S,C,T
667
+ targets_viz[longest_index].detach().cpu().numpy(), # S
668
+ attention_tracking_viz[:, longest_index], # Pass T, (H), H, W
669
+ args.log_dir)
670
+ # except Exception as e:
671
+ # print(f"Visualization failed for model {args.model}: {e}")
672
+ # --- End Visualization ---
673
+
674
+ model.train() # Switch back to train mode
675
+
676
+
677
+ # Save model checkpoint
678
+ if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
679
+ pbar.set_description('Saving model checkpoint...')
680
+ checkpoint_data = {
681
+ 'model_state_dict': model.state_dict(),
682
+ 'optimizer_state_dict': optimizer.state_dict(),
683
+ 'scheduler_state_dict': scheduler.state_dict(),
684
+ 'scaler_state_dict': scaler.state_dict(), # Save scaler state
685
+ 'iteration': bi,
686
+ # Save all tracked metrics
687
+ 'train_losses': train_losses,
688
+ 'test_losses': test_losses,
689
+ 'train_accuracies': train_accuracies, # List of (S, T) or (S,) arrays
690
+ 'test_accuracies': test_accuracies, # List of (S, T) or (S,) arrays
691
+ 'train_accuracies_most_certain': train_accuracies_most_certain, # List of scalars
692
+ 'test_accuracies_most_certain': test_accuracies_most_certain, # List of scalars
693
+ 'train_accuracies_most_certain_permaze': train_accuracies_most_certain_permaze, # List of scalars
694
+ 'test_accuracies_most_certain_permaze': test_accuracies_most_certain_permaze, # List of scalars
695
+ 'iters': iters,
696
+ 'args': args, # Save args used for this run
697
+ # RNG states
698
+ 'torch_rng_state': torch.get_rng_state(),
699
+ 'numpy_rng_state': np.random.get_state(),
700
+ 'random_rng_state': random.getstate(),
701
+ }
702
+ torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
703
+
704
+ pbar.update(1)
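
The tracking code above reports two accuracies at the chosen tick: the fraction of individual route steps predicted correctly, and the fraction of mazes whose whole route is correct. The following is a minimal plain-Python sketch of those two quantities, meant only to illustrate the indexing; the function names `select_at_tick`, `step_accuracy`, and `full_maze_accuracy` are hypothetical and do not appear in the repository, which uses tensor indexing instead of nested lists.

```python
# Illustrative sketch only: nested lists stand in for the (N, S, T) arrays used in train.py.
# All helper names here are assumptions made for this example.

def select_at_tick(predictions, ticks):
    """Pick, for each sample n, the class predicted at tick ticks[n] for every step.

    predictions[n][s][t] is the predicted class id (already argmax-ed over classes).
    ticks[n] may be negative, e.g. -1 for "last tick" as in the LSTM branch.
    """
    return [[step[t] for step in sample] for sample, t in zip(predictions, ticks)]


def step_accuracy(targets, chosen):
    """Fraction of individual steps that match the target, over all samples."""
    total = sum(len(row) for row in targets)
    correct = sum(
        p == t
        for pred_row, target_row in zip(chosen, targets)
        for p, t in zip(pred_row, target_row)
    )
    return correct / total


def full_maze_accuracy(targets, chosen):
    """Fraction of samples in which every step matches the target."""
    solved = [
        all(p == t for p, t in zip(pred_row, target_row))
        for pred_row, target_row in zip(chosen, targets)
    ]
    return sum(solved) / len(solved)


# Two mazes, three steps each, two ticks per step.
targets = [[0, 1, 4], [2, 2, 4]]
predictions = [
    [[0, 0], [3, 1], [4, 4]],  # sample 0: fully correct at tick 1
    [[2, 2], [2, 0], [4, 4]],  # sample 1: step 1 is wrong at the last tick
]
ticks = [1, -1]  # per-sample chosen tick

chosen = select_at_tick(predictions, ticks)
step_acc = step_accuracy(targets, chosen)
maze_acc = full_maze_accuracy(targets, chosen)
print(chosen, step_acc, maze_acc)
```

Full-maze accuracy can never exceed step accuracy, because a single wrong step marks the whole maze as unsolved; this is why the 'Full Maze' curve in `accuracies.png` sits at or below the 'Most Certain (Avg Step)' curve.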
tasks/mazes/train_distributed.py ADDED
@@ -0,0 +1,782 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+ import gc
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import seaborn as sns
9
+ sns.set_style('darkgrid')
10
+ import torch
11
+ if torch.cuda.is_available():
12
+ # Trade some float32 matmul precision for speed (e.g. TF32 on supported GPUs)
13
+ torch.set_float32_matmul_precision('high')
14
+ import torch.distributed as dist
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+ from torch.utils.data.distributed import DistributedSampler
17
+ from utils.samplers import FastRandomDistributedSampler
18
+ from tqdm.auto import tqdm
19
+
20
+ # Data/Task Specific Imports
21
+ from data.custom_datasets import MazeImageFolder
22
+
23
+ # Model Imports
24
+ from models.ctm import ContinuousThoughtMachine
25
+ from models.lstm import LSTMBaseline
26
+ from models.ff import FFBaseline
27
+
28
+ # Plotting/Utils Imports
29
+ from tasks.mazes.plotting import make_maze_gif
30
+ from tasks.image_classification.plotting import plot_neural_dynamics
31
+ from utils.housekeeping import set_seed, zip_python_code
32
+ from utils.losses import maze_loss
33
+ from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
34
+
35
+ import torchvision
36
+ torchvision.disable_beta_transforms_warning()
37
+
38
+ import warnings
39
+ warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
40
+ warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
41
+ warnings.filterwarnings(
42
+ "ignore",
43
+ "Corrupt EXIF data",
44
+ UserWarning,
45
+ r"^PIL\.TiffImagePlugin$"
46
+ )
47
+ warnings.filterwarnings(
48
+ "ignore",
49
+ "UserWarning: Metadata Warning",
50
+ UserWarning,
51
+ r"^PIL\.TiffImagePlugin$"
52
+ )
53
+ warnings.filterwarnings(
54
+ "ignore",
55
+ "UserWarning: Truncated File Read",
56
+ UserWarning,
57
+ r"^PIL\.TiffImagePlugin$"
58
+ )
59
+
60
+
61
+ def parse_args():
62
+ parser = argparse.ArgumentParser()
63
+
64
+ # Model Selection
65
+ parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
66
+
67
+ # Model Architecture
68
+ parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
69
+ parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
70
+ parser.add_argument('--backbone_type', type=str, default='resnet34-2', help='Type of backbone featuriser.')
71
+ # CTM / LSTM specific
72
+ parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
73
+ parser.add_argument('--heads', type=int, default=8, help='Number of attention heads (CTM, LSTM).')
74
+ parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
75
+ parser.add_argument('--positional_embedding_type', type=str, default='none',
76
+ help='Type of positional embedding (CTM, LSTM).', choices=['none',
77
+ 'learnable-fourier',
78
+ 'multi-learnable-fourier',
79
+ 'custom-rotational'])
80
+ # CTM specific
81
+ parser.add_argument('--synapse_depth', type=int, default=8, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
82
+ parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).')
83
+ parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).')
84
+ parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
85
+ parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
86
+ parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
87
+ parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
88
+ parser.add_argument('--memory_hidden_dims', type=int, default=32, help='Hidden dimensions of the memory if using deep memory (CTM only).')
89
+ parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
90
+ parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
91
+ # LSTM specific
92
+ parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
93
+
94
+ # Task Specific Args
95
+ parser.add_argument('--maze_route_length', type=int, default=100, help='Length to truncate targets.')
96
+ parser.add_argument('--cirriculum_lookahead', type=int, default=5, help='How far to look ahead for the curriculum.')
97
+
98
+ # Training
99
+ parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training (per GPU).')
100
+ parser.add_argument('--batch_size_test', type=int, default=64, help='Batch size for testing (per GPU).')
101
+ parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate for the model.')
102
+ parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
103
+ parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
104
+ parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
105
+ parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
106
+ parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
107
+ parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
108
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
109
+ parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
110
+ parser.add_argument('--num_workers_train', type=int, default=0, help='Num workers training.')
111
+ parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
112
+ parser.add_argument('--use_custom_sampler', action=argparse.BooleanOptionalAction, default=False, help='Use custom fast sampler to avoid reshuffling.')
113
+ parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
114
+
115
+ # Logging and Saving
116
+ parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
117
+ parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large'])
118
+ parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
119
+ parser.add_argument('--seed', type=int, default=412, help='Random seed.')
120
+ parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
121
+ parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?') # Default False based on user edit
122
+ parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=False, help='Should use strict reload for model weights.')
123
+ parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading (for debugging)?')
124
+
125
+ # Tracking
126
+ parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
127
+ parser.add_argument('--n_test_batches', type=int, default=2, help='How many minibatches to approx metrics. Set to -1 for full eval')
128
+
129
+ # Precision
130
+ parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
131
+
132
+ args = parser.parse_args()
133
+ return args
134
+
135
+ # --- DDP Setup Functions ---
136
+ def setup_ddp():
137
+ if 'RANK' not in os.environ:
138
+ os.environ['RANK'] = '0'
139
+ os.environ['WORLD_SIZE'] = '1'
140
+ os.environ['MASTER_ADDR'] = 'localhost'
141
+ os.environ['MASTER_PORT'] = '12356' # Different port from image classification
142
+ os.environ['LOCAL_RANK'] = '0'
143
+ print("Running in non-distributed mode (simulated DDP setup).")
144
+ if not torch.cuda.is_available() or int(os.environ['WORLD_SIZE']) == 1:
145
+ dist.init_process_group(backend='gloo')
146
+ print("Initialized process group with Gloo backend for single/CPU process.")
147
+ rank = int(os.environ['RANK'])
148
+ world_size = int(os.environ['WORLD_SIZE'])
149
+ local_rank = int(os.environ['LOCAL_RANK'])
150
+ return rank, world_size, local_rank
151
+
152
+ dist.init_process_group(backend='nccl')
153
+ rank = int(os.environ['RANK'])
154
+ world_size = int(os.environ['WORLD_SIZE'])
155
+ local_rank = int(os.environ['LOCAL_RANK'])
156
+ if torch.cuda.is_available():
157
+ torch.cuda.set_device(local_rank)
158
+ print(f"Rank {rank} setup on GPU {local_rank}")
159
+ else:
160
+ print(f"Rank {rank} setup on CPU")
161
+ return rank, world_size, local_rank
162
+
163
+ def cleanup_ddp():
164
+ if dist.is_initialized():
165
+ dist.destroy_process_group()
166
+ print("DDP cleanup complete.")
167
+
168
+ def is_main_process(rank):
169
+ return rank == 0
170
+ # --- End DDP Setup ---
171
+
172
+
173
+ if __name__=='__main__':
174
+
175
+ args = parse_args()
176
+
177
+ rank, world_size, local_rank = setup_ddp()
178
+
179
+ set_seed(args.seed + rank, False)
180
+
181
+ # Rank 0 handles directory creation and initial logging
182
+ if is_main_process(rank):
183
+ if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
184
+ zip_python_code(f'{args.log_dir}/repo_state.zip')
185
+ with open(f'{args.log_dir}/args.txt', 'w') as f:
186
+ print(args, file=f)
187
+ if world_size > 1: dist.barrier()
188
+
189
+
190
+ assert args.dataset in ['mazes-medium', 'mazes-large']
191
+
192
+ # Setup Device
193
+ if torch.cuda.is_available():
194
+ device = torch.device(f'cuda:{local_rank}')
195
+ else:
196
+ device = torch.device('cpu')
197
+ if world_size > 1: warnings.warn("Running DDP on CPU is not recommended.")
198
+
199
+ if is_main_process(rank):
200
+ print(f'Main process (Rank {rank}): Using device {device}. World size: {world_size}. Model: {args.model}')
201
+
202
+
203
+ prediction_reshaper = [args.maze_route_length, 5]
204
+ args.out_dims = args.maze_route_length * 5
205
+
+    # --- Model Definition (Conditional) ---
+    model_base = None # Base model before DDP wrapping
+    if args.model == 'ctm':
+        model_base = ContinuousThoughtMachine(
+            iterations=args.iterations,
+            d_model=args.d_model,
+            d_input=args.d_input,
+            heads=args.heads,
+            n_synch_out=args.n_synch_out,
+            n_synch_action=args.n_synch_action,
+            synapse_depth=args.synapse_depth,
+            memory_length=args.memory_length,
+            deep_nlms=args.deep_memory,
+            memory_hidden_dims=args.memory_hidden_dims,
+            do_layernorm_nlm=args.do_normalisation,
+            backbone_type=args.backbone_type,
+            positional_embedding_type=args.positional_embedding_type,
+            out_dims=args.out_dims,
+            prediction_reshaper=prediction_reshaper,
+            dropout=args.dropout,
+            dropout_nlm=args.dropout_nlm,
+            neuron_select_type=args.neuron_select_type,
+            n_random_pairing_self=args.n_random_pairing_self,
+        ).to(device)
+    elif args.model == 'lstm':
+        model_base = LSTMBaseline(
+            num_layers=args.num_layers,
+            iterations=args.iterations,
+            d_model=args.d_model,
+            d_input=args.d_input,
+            heads=args.heads,
+            backbone_type=args.backbone_type,
+            positional_embedding_type=args.positional_embedding_type,
+            out_dims=args.out_dims,
+            prediction_reshaper=prediction_reshaper,
+            dropout=args.dropout,
+        ).to(device)
+    elif args.model == 'ff':
+        model_base = FFBaseline(
+            d_model=args.d_model,
+            backbone_type=args.backbone_type,
+            out_dims=args.out_dims,
+            dropout=args.dropout,
+        ).to(device)
+    else:
+        raise ValueError(f"Unknown model type: {args.model}")
+
+    # Use pseudo-input *before* DDP wrapping
+    try:
+        # Determine pseudo input shape based on dataset
+        h_w = 39 if args.dataset in ['mazes-small', 'mazes-medium'] else 99 # Example dimensions
+        pseudo_inputs = torch.zeros((1, 3, h_w, h_w), device=device).float()
+        model_base(pseudo_inputs)
+    except Exception as e:
+        print(f"Warning: Pseudo forward pass failed: {e}")
+
+    if is_main_process(rank):
+        print(f'Total params: {sum(p.numel() for p in model_base.parameters() if p.requires_grad)}')
+
+    # Wrap model with DDP
+    if device.type == 'cuda' and world_size > 1:
+        model = DDP(model_base, device_ids=[local_rank], output_device=local_rank)
+    elif device.type == 'cpu' and world_size > 1:
+        model = DDP(model_base)
+    else:
+        model = model_base
+    # --- End Model Definition ---
+
+
+    # Data Loading (After model setup to allow pseudo pass first)
+    dataset_mean = [0,0,0]
+    dataset_std = [1,1,1]
+    which_maze = args.dataset.split('-')[-1]
+    data_root = f'data/mazes/{which_maze}'
+
+    train_data = MazeImageFolder(root=f'{data_root}/train/', which_set='train', maze_route_length=args.maze_route_length)
+    test_data = MazeImageFolder(root=f'{data_root}/test/', which_set='test', maze_route_length=args.maze_route_length)
+
+    train_sampler = (FastRandomDistributedSampler(train_data, num_replicas=world_size, rank=rank, seed=args.seed, epoch_steps=int(10e10))
+                     if args.use_custom_sampler else
+                     DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed))
+    test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False, seed=args.seed)
+
+    num_workers_test = 1
+    trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler,
+                                              num_workers=args.num_workers_train, pin_memory=True, drop_last=True)
+    testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, sampler=test_sampler,
+                                             num_workers=num_workers_test, pin_memory=True, drop_last=False)
+
+
+    # Optimizer and scheduler
+    decay_params = []
+    no_decay_params = []
+    no_decay_names = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue # Skip parameters that don't require gradients
+        if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
+            no_decay_params.append(param)
+            no_decay_names.append(name)
+        else:
+            decay_params.append(param)
+    if len(no_decay_names) and is_main_process(rank):
+        print(f'WARNING, excluding: {no_decay_names}')
+
+    # Optimizer and scheduler (Common setup)
+    if len(no_decay_names) and args.weight_decay!=0:
+        optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
+                                       {'params': no_decay_params, 'weight_decay':0}],
+                                      lr=args.lr,
+                                      eps=1e-8 if not args.use_amp else 1e-6)
+    else:
+        optimizer = torch.optim.AdamW(model.parameters(),
+                                      lr=args.lr,
+                                      eps=1e-8 if not args.use_amp else 1e-6,
+                                      weight_decay=args.weight_decay)
+
+    warmup_schedule = warmup(args.warmup_steps)
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
+    if args.use_scheduler:
+        if args.scheduler_type == 'multistep':
+            scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
+        elif args.scheduler_type == 'cosine':
+            scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
+        else:
+            raise NotImplementedError
+
+
+    # Metrics tracking (Rank 0 stores history)
+    start_iter = 0
+    iters = []
+    train_losses, test_losses = [], []
+    train_accuracies, test_accuracies = [], [] # Avg Step Acc (scalar list)
+    train_accuracies_most_certain, test_accuracies_most_certain = [], [] # Avg Step Acc @ Certain tick (scalar list)
+    train_accuracies_most_certain_permaze, test_accuracies_most_certain_permaze = [], [] # Full Maze Acc @ Certain tick (scalar list)
+
+
+    scaler = torch.amp.GradScaler("cuda" if device.type == 'cuda' else "cpu", enabled=args.use_amp)
+
+    # Reloading Logic
+    if args.reload:
+        map_location = device
+        chkpt_path = f'{args.log_dir}/checkpoint.pt'
+        if os.path.isfile(chkpt_path):
+            print(f'Rank {rank}: Reloading from: {chkpt_path}')
+            if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
+
+            checkpoint = torch.load(chkpt_path, map_location=map_location, weights_only=False)
+
+            model_to_load = model.module if isinstance(model, DDP) else model
+            state_dict = checkpoint['model_state_dict']
+            has_module_prefix = all(k.startswith('module.') for k in state_dict)
+            is_wrapped = isinstance(model, DDP)
+
+            if has_module_prefix: # model_to_load is always the unwrapped module, so strip the prefix whether or not DDP is in use
+                state_dict = {k.partition('module.')[2]: v for k,v in state_dict.items()}
+            elif not has_module_prefix and is_wrapped:
+                load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
+                print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
+                state_dict = None # Prevent loading again
+
+            if state_dict is not None:
+                load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
+                print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
+
+
+
+            if not args.reload_model_only:
+                print(f'Rank {rank}: Reloading optimizer, scheduler, scaler, iteration.')
+                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+                scaler.load_state_dict(checkpoint['scaler_state_dict'])
+                start_iter = checkpoint['iteration']
+
+                if is_main_process(rank) and not args.ignore_metrics_when_reloading:
+                    print(f'Rank {rank}: Reloading metrics history.')
+                    iters = checkpoint['iters']
+                    train_losses = checkpoint['train_losses']
+                    test_losses = checkpoint['test_losses']
+                    train_accuracies = checkpoint['train_accuracies'] # Reloading simplified avg step acc list
+                    test_accuracies = checkpoint['test_accuracies']
+                    train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
+                    test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
+                    train_accuracies_most_certain_permaze = checkpoint['train_accuracies_most_certain_permaze']
+                    test_accuracies_most_certain_permaze = checkpoint['test_accuracies_most_certain_permaze']
+                elif is_main_process(rank) and args.ignore_metrics_when_reloading:
+                    print(f'Rank {rank}: Ignoring metrics history upon reload.')
+            else:
+                print(f'Rank {rank}: Only reloading model weights!')
+
+            if is_main_process(rank) and 'torch_rng_state' in checkpoint and not args.reload_model_only:
+                print(f'Rank {rank}: Loading RNG states.')
+                torch.set_rng_state(checkpoint['torch_rng_state'].cpu())
+                np.random.set_state(checkpoint['numpy_rng_state'])
+                random.setstate(checkpoint['random_rng_state'])
+
+            del checkpoint
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            print(f"Rank {rank}: Reload finished, starting from iteration {start_iter}")
+        else:
+            print(f"Rank {rank}: Checkpoint not found at {chkpt_path}, starting from scratch.")
+
+
+    if world_size > 1: dist.barrier()
+
+
+    # Conditional Compilation
+    if args.do_compile:
+        if is_main_process(rank): print('Compiling model components...')
+        model_to_compile = model.module if isinstance(model, DDP) else model
+        if hasattr(model_to_compile, 'backbone'):
+            model_to_compile.backbone = torch.compile(model_to_compile.backbone, mode='reduce-overhead', fullgraph=True)
+        if args.model == 'ctm':
+            model_to_compile.synapses = torch.compile(model_to_compile.synapses, mode='reduce-overhead', fullgraph=True)
+        if world_size > 1: dist.barrier()
+        if is_main_process(rank): print('Compilation finished.')
+
+
+    # --- Training Loop ---
+    model.train()
+    pbar = tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True, disable=not is_main_process(rank))
+
+    iterator = iter(trainloader)
+
+    for bi in range(start_iter, args.training_iterations):
+
+        # --- Evaluation and Plotting (Rank 0 + Aggregation) ---
+        if bi % args.track_every == 0 and (bi != 0 or args.reload_model_only):
+            model.eval()
+            with torch.inference_mode():
+
+                # --- Distributed Evaluation ---
+                if is_main_process(rank): iters.append(bi) # Track iterations on rank 0
+
+                # Initialize accumulators on device
+                total_train_loss = torch.tensor(0.0, device=device)
+                total_train_correct_certain = torch.tensor(0.0, device=device) # Sum correct steps @ certain tick
+                total_train_mazes_solved = torch.tensor(0.0, device=device) # Sum solved mazes @ certain tick
+                total_train_steps = torch.tensor(0.0, device=device) # Total steps evaluated (B * S)
+                total_train_mazes = torch.tensor(0.0, device=device) # Total mazes evaluated (B)
+
+                # TRAIN METRICS
+                train_eval_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=False)
+                train_eval_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, sampler=train_eval_sampler, num_workers=num_workers_test, pin_memory=True)
+
+                pbar_inner_desc = 'Eval Train (Rank 0)' if is_main_process(rank) else None
+                with tqdm(total=len(train_eval_loader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
+                    for inferi, (inputs, targets) in enumerate(train_eval_loader):
+                        inputs = inputs.to(device, non_blocking=True)
+                        targets = targets.to(device, non_blocking=True) # B, S
+                        batch_size = inputs.size(0)
+                        seq_len = targets.size(1)
+
+                        loss_eval = None
+                        pred_at_certain = None # Shape B, S
+                        if args.model == 'ctm':
+                            predictions_raw, certainties, _ = model(inputs)
+                            predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1)) # B,S,C,T
+                            loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
+                            pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
+                        elif args.model == 'lstm':
+                            predictions_raw, certainties, _ = model(inputs)
+                            predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1)) # B,S,C,T
+                            loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
+                            pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
+                        elif args.model == 'ff':
+                            predictions_raw = model(inputs) # B, S*C
+                            predictions = predictions_raw.reshape(batch_size, -1, 5) # B,S,C
+                            loss_eval, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
+                            pred_at_certain = predictions.argmax(2)
+
+                        # Accumulate metrics
+                        total_train_loss += loss_eval * batch_size # Sum losses
+                        correct_steps = (pred_at_certain == targets) # B, S boolean
+                        total_train_correct_certain += correct_steps.sum() # Sum correct steps across batch
+                        total_train_mazes_solved += correct_steps.all(dim=-1).sum() # Sum mazes where all steps are correct
+                        total_train_steps += batch_size * seq_len
+                        total_train_mazes += batch_size
+
+                        if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
+                        pbar_inner.update(1)
+
+                # Aggregate Train Metrics
+                if world_size > 1:
+                    dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(total_train_correct_certain, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(total_train_mazes_solved, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(total_train_steps, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(total_train_mazes, op=dist.ReduceOp.SUM)
+
+                # Calculate final Train metrics on Rank 0
+                if is_main_process(rank) and total_train_mazes > 0:
+                    avg_train_loss = total_train_loss.item() / total_train_mazes.item() # Avg loss per maze/sample
+                    avg_train_acc_step = total_train_correct_certain.item() / total_train_steps.item() # Avg correct step %
+                    avg_train_acc_maze = total_train_mazes_solved.item() / total_train_mazes.item() # Avg full maze solved %
+                    train_losses.append(avg_train_loss)
+                    train_accuracies_most_certain.append(avg_train_acc_step)
+                    train_accuracies_most_certain_permaze.append(avg_train_acc_maze)
+                    # train_accuracies list remains unused/placeholder for this simplified metric structure
+                    print(f"Iter {bi} Train Metrics (Agg): Loss={avg_train_loss:.4f}, StepAcc={avg_train_acc_step:.4f}, MazeAcc={avg_train_acc_maze:.4f}")
+
+                # TEST METRICS
+                total_test_loss = torch.tensor(0.0, device=device)
+                total_test_correct_certain = torch.tensor(0.0, device=device)
+                total_test_mazes_solved = torch.tensor(0.0, device=device)
+                total_test_steps = torch.tensor(0.0, device=device)
+                total_test_mazes = torch.tensor(0.0, device=device)
+
+                pbar_inner_desc = 'Eval Test (Rank 0)' if is_main_process(rank) else None
+                with tqdm(total=len(testloader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
+                    for inferi, (inputs, targets) in enumerate(testloader):
+                        inputs = inputs.to(device, non_blocking=True)
+                        targets = targets.to(device, non_blocking=True)
+                        batch_size = inputs.size(0)
+                        seq_len = targets.size(1)
+
+                        loss_eval = None
+                        pred_at_certain = None
+                        if args.model == 'ctm':
+                            predictions_raw, certainties, _ = model(inputs)
+                            predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1))
+                            loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
+                            pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
+                        elif args.model == 'lstm':
+                            predictions_raw, certainties, _ = model(inputs)
+                            predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1))
+                            loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False)
+                            pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
+                        elif args.model == 'ff':
+                            predictions_raw = model(inputs)
+                            predictions = predictions_raw.reshape(batch_size, -1, 5)
+                            loss_eval, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False)
+                            pred_at_certain = predictions.argmax(2)
+
+                        total_test_loss += loss_eval * batch_size
+                        correct_steps = (pred_at_certain == targets)
+                        total_test_correct_certain += correct_steps.sum()
+                        total_test_mazes_solved += correct_steps.all(dim=-1).sum()
+                        total_test_steps += batch_size * seq_len
+                        total_test_mazes += batch_size
+
+                        if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
+                        pbar_inner.update(1)
+
+                # Aggregate Test Metrics
+                if world_size > 1:
+                    dist.all_reduce(total_test_loss, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(total_test_correct_certain, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(total_test_mazes_solved, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(total_test_steps, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(total_test_mazes, op=dist.ReduceOp.SUM)
+
+                # Calculate and Plot final Test metrics on Rank 0
+                if is_main_process(rank) and total_test_mazes > 0:
+                    avg_test_loss = total_test_loss.item() / total_test_mazes.item()
+                    avg_test_acc_step = total_test_correct_certain.item() / total_test_steps.item()
+                    avg_test_acc_maze = total_test_mazes_solved.item() / total_test_mazes.item()
+                    test_losses.append(avg_test_loss)
+                    test_accuracies_most_certain.append(avg_test_acc_step)
+                    test_accuracies_most_certain_permaze.append(avg_test_acc_maze)
+                    print(f"Iter {bi} Test Metrics (Agg): Loss={avg_test_loss:.4f}, StepAcc={avg_test_acc_step:.4f}, MazeAcc={avg_test_acc_maze:.4f}\n")
+
+                    # --- Plotting ---
+                    figacc = plt.figure(figsize=(10, 10))
+                    axacc_train = figacc.add_subplot(211)
+                    axacc_test = figacc.add_subplot(212)
+
+                    # Plot Avg Step Accuracy
+                    axacc_train.plot(iters, train_accuracies_most_certain, 'k-', alpha=0.7, label=f'Avg Step Acc ({train_accuracies_most_certain[-1]:.3f})')
+                    axacc_test.plot(iters, test_accuracies_most_certain, 'k-', alpha=0.7, label=f'Avg Step Acc ({test_accuracies_most_certain[-1]:.3f})')
+                    # Plot Full Maze Accuracy
+                    axacc_train.plot(iters, train_accuracies_most_certain_permaze, 'r-', alpha=0.6, label=f'Full Maze Acc ({train_accuracies_most_certain_permaze[-1]:.3f})')
+                    axacc_test.plot(iters, test_accuracies_most_certain_permaze, 'r-', alpha=0.6, label=f'Full Maze Acc ({test_accuracies_most_certain_permaze[-1]:.3f})')
+
+                    axacc_train.set_title('Train Accuracy (Aggregated)')
+                    axacc_test.set_title('Test Accuracy (Aggregated)')
+                    axacc_train.legend(loc='lower right')
+                    axacc_test.legend(loc='lower right')
+                    axacc_train.set_xlim([0, args.training_iterations])
+                    axacc_test.set_xlim([0, args.training_iterations])
+                    axacc_train.set_ylim([0, 1])
+                    axacc_test.set_ylim([0, 1])
+
+                    figacc.tight_layout()
+                    figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
+                    plt.close(figacc)
+
+                    # Loss Plot
+                    figloss = plt.figure(figsize=(10, 5))
+                    axloss = figloss.add_subplot(111)
+                    axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train (Agg): {train_losses[-1]:.4f}')
+                    axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test (Agg): {test_losses[-1]:.4f}')
+                    axloss.legend(loc='upper right')
+                    axloss.set_xlabel("Iteration")
+                    axloss.set_ylabel("Loss")
+                    axloss.set_xlim([0, args.training_iterations])
+                    axloss.set_ylim(bottom=0)
+                    figloss.tight_layout()
+                    figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
+                    plt.close(figloss)
+                    # --- End Plotting ---
+
+
+                # --- Visualization (Rank 0, Conditional) ---
+                if is_main_process(rank) and args.model in ['ctm', 'lstm']:
+                    # try:
+                    model_module = model.module if isinstance(model, DDP) else model
+                    # Use a consistent batch for viz if possible, or just next batch
+                    inputs_viz, targets_viz = next(iter(testloader))
+                    inputs_viz = inputs_viz.to(device)
+                    targets_viz = targets_viz.to(device)
+                    longest_index = (targets_viz!=4).sum(-1).argmax() # 4 assumed padding
+
+                    pbar.set_description('Tracking (Rank 0): Viz Fwd Pass')
+                    predictions_viz_raw, _, _, _, post_activations_viz, attention_tracking_viz = model_module(inputs_viz, track=True)
+                    predictions_viz = predictions_viz_raw.reshape(predictions_viz_raw.size(0), -1, 5, predictions_viz_raw.size(-1))
+
+                    # Use the unwrapped module: `model.module` only exists when the model is DDP-wrapped
+                    att_shape = (model_module.kv_features.shape[2], model_module.kv_features.shape[3])
+                    attention_tracking_viz = attention_tracking_viz.reshape(
+                        attention_tracking_viz.shape[0],
+                        attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
+
+                    pbar.set_description('Tracking (Rank 0): Dynamics Plot')
+                    plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
+
+                    pbar.set_description('Tracking (Rank 0): Maze GIF')
+                    if attention_tracking_viz is not None:
+                        make_maze_gif((inputs_viz[longest_index].detach().cpu().numpy()+1)/2,
+                                      predictions_viz[longest_index].detach().cpu().numpy(),
+                                      targets_viz[longest_index].detach().cpu().numpy(),
+                                      attention_tracking_viz[:, longest_index],
+                                      args.log_dir)
+                    # else:
+                    #     print("Skipping maze GIF due to attention shape issue.")
+
+                    # except Exception as e_viz:
+                    #     print(f"Rank 0 visualization failed: {e_viz}")
+                # --- End Visualization ---
+
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            if world_size > 1: dist.barrier()
+            model.train()
+            # --- End Evaluation Block ---
+
+
+
+
+        if hasattr(train_sampler, 'set_epoch'): # Check if sampler has set_epoch
+            train_sampler.set_epoch(bi)
+
+        current_lr = optimizer.param_groups[-1]['lr']
+
+        try:
+            inputs, targets = next(iterator)
+        except StopIteration:
+            iterator = iter(trainloader)
+            inputs, targets = next(iterator)
+
+        inputs = inputs.to(device, non_blocking=True)
+        targets = targets.to(device, non_blocking=True)
+
+        # Defaults for logging
+        loss = torch.tensor(0.0, device=device) # Need loss defined for logging scope
+        accuracy_finegrained = 0.0
+        where_most_certain_val = -1.0
+        where_most_certain_std = 0.0
+        where_most_certain_min = -1
+        where_most_certain_max = -1
+        upto_where_mean = -1.0
+        upto_where_std = 0.0
+        upto_where_min = -1
+        upto_where_max = -1
+
+        with torch.autocast(device_type="cuda" if device.type == 'cuda' else "cpu", dtype=torch.float16, enabled=args.use_amp):
+            if args.do_compile: torch.compiler.cudagraph_mark_step_begin()
+
+            if args.model == 'ctm':
+                predictions_raw, certainties, _ = model(inputs)
+                predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
+                loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=True)
+                with torch.no_grad(): # Calculate local accuracy for logging
+                    accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=device), :, where_most_certain] == targets).float().mean().item()
+            elif args.model == 'lstm':
+                predictions_raw, certainties, _ = model(inputs)
+                predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
+                loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False) # where = -1
+                with torch.no_grad():
+                    accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=device), :, where_most_certain] == targets).float().mean().item()
+            elif args.model == 'ff':
+                predictions_raw = model(inputs) # B, S*C
+                predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
+                loss, where_most_certain, upto_where = maze_loss(predictions.unsqueeze(-1), None, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False) # where = -1
+                with torch.no_grad():
+                    accuracy_finegrained = (predictions.argmax(2) == targets).float().mean().item()
+
+            # Extract stats from loss outputs
+            if torch.is_tensor(where_most_certain):
+                where_most_certain_val = where_most_certain.float().mean().item()
+                where_most_certain_std = where_most_certain.float().std().item()
+                where_most_certain_min = where_most_certain.min().item()
+                where_most_certain_max = where_most_certain.max().item()
+            elif isinstance(where_most_certain, int):
+                where_most_certain_val = float(where_most_certain); where_most_certain_min = where_most_certain; where_most_certain_max = where_most_certain
+            if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
+                upto_where_mean = np.mean(upto_where); upto_where_std = np.std(upto_where); upto_where_min = np.min(upto_where); upto_where_max = np.max(upto_where)
+
+        # Backprop / Step
+        scaler.scale(loss).backward()
+        if args.gradient_clipping!=-1:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad(set_to_none=True)
+        scheduler.step()
+
+        # --- Aggregation and Logging (Rank 0) ---
+        loss_log = loss.detach()
+        if world_size > 1: dist.all_reduce(loss_log, op=dist.ReduceOp.AVG)
+
+        if is_main_process(rank):
+            pbar_desc = f'Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_finegrained:.3f} LR={current_lr:.6f}'
+            if args.model in ['ctm', 'lstm'] or torch.is_tensor(where_most_certain):
+                pbar_desc += f' Cert={where_most_certain_val:.2f}'#+-{where_most_certain_std:.2f}' # Removed std for brevity
+            if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
+                pbar_desc += f' Path={upto_where_mean:.1f}'#+-{upto_where_std:.1f}'
+            pbar.set_description(f'{args.model.upper()} {pbar_desc}')
+        # --- End Aggregation and Logging ---
+
+
+
+
+
+        # --- Checkpointing (Rank 0) ---
+        if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter and is_main_process(rank):
+            pbar.set_description('Rank 0: Saving checkpoint...')
+            save_path = f'{args.log_dir}/checkpoint.pt'
+            model_state_to_save = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
+
+            checkpoint_data = {
+                'model_state_dict': model_state_to_save,
+                'optimizer_state_dict': optimizer.state_dict(),
+                'scheduler_state_dict': scheduler.state_dict(),
+                'scaler_state_dict': scaler.state_dict(),
+                'iteration': bi,
+                'train_losses': train_losses,
+                'test_losses': test_losses,
+                'train_accuracies': train_accuracies, # Saving simplified scalar list
+                'test_accuracies': test_accuracies, # Saving simplified scalar list
+                'train_accuracies_most_certain': train_accuracies_most_certain,
+                'test_accuracies_most_certain': test_accuracies_most_certain,
+                'train_accuracies_most_certain_permaze': train_accuracies_most_certain_permaze,
+                'test_accuracies_most_certain_permaze': test_accuracies_most_certain_permaze,
+                'iters': iters,
+                'args': args,
+                'torch_rng_state': torch.get_rng_state(),
+                'numpy_rng_state': np.random.get_state(),
+                'random_rng_state': random.getstate(),
+            }
+            torch.save(checkpoint_data, save_path)
+        # --- End Checkpointing ---
+
+
+        if world_size > 1: dist.barrier()
+
+        if is_main_process(rank):
+            pbar.update(1)
+    # --- End Training Loop ---
+
+    if is_main_process(rank):
+        pbar.close()
+
+    cleanup_ddp()
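
Note on the evaluation block in the script above: each rank accumulates sums (loss times batch size, correct steps, solved mazes) together with the matching counts, sums them across ranks with `dist.all_reduce(..., op=dist.ReduceOp.SUM)`, and divides only once at the end. The sketch below illustrates why that ordering matters. It is a framework-free illustration, not code from the repository: `RankStats`, `all_reduce_sum`, `aggregate`, and the toy numbers are hypothetical, and the all-reduce is simulated with a plain Python sum instead of `torch.distributed`.

```python
from dataclasses import dataclass


@dataclass
class RankStats:
    """Per-rank accumulators, mirroring the totals kept in the evaluation loop."""
    loss_sum: float      # sum over batches of (batch-mean loss * batch size)
    correct_steps: int   # correctly predicted route steps
    mazes_solved: int    # mazes where every step is correct
    steps: int           # total steps evaluated (B * S)
    mazes: int           # total mazes evaluated (B)


def all_reduce_sum(values):
    """Stand-in for an all-reduce with SUM: every rank would see this total."""
    return sum(values)


def aggregate(per_rank):
    """Sum numerators and denominators across ranks, then divide once."""
    loss_sum = all_reduce_sum(r.loss_sum for r in per_rank)
    correct = all_reduce_sum(r.correct_steps for r in per_rank)
    solved = all_reduce_sum(r.mazes_solved for r in per_rank)
    steps = all_reduce_sum(r.steps for r in per_rank)
    mazes = all_reduce_sum(r.mazes for r in per_rank)
    return {
        "loss": loss_sum / mazes,
        "step_acc": correct / steps,
        "maze_acc": solved / mazes,
    }


# Two hypothetical ranks that evaluated different numbers of mazes.
ranks = [
    RankStats(loss_sum=8.0, correct_steps=90, mazes_solved=3, steps=100, mazes=4),
    RankStats(loss_sum=1.0, correct_steps=10, mazes_solved=0, steps=50, mazes=2),
]
metrics = aggregate(ranks)

# Averaging per-rank accuracies instead gives every rank equal weight,
# regardless of how many samples it actually evaluated.
naive_step_acc = sum(r.correct_steps / r.steps for r in ranks) / len(ranks)

print(metrics, naive_step_acc)
```

When ranks see unequal sample counts (for example because the last batch is not dropped, or because evaluation stops early), the naive mean of per-rank accuracies differs from the sample-weighted value that the sum-then-divide scheme produces.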
tasks/parity/README.md ADDED
@@ -0,0 +1,16 @@
+# Parity
+
+## Training
+To reproduce the parity training runs used in the paper, run the bash scripts from the root of the repository. For example, to train the 75-iteration, 25-memory-length CTM, run:
+
+```
+bash tasks/parity/scripts/train_ctm_75_25.sh
+```
+
+
+## Analysis
+To run the analysis, first make sure the checkpoints are saved in the log directory (specified by the `log_dir` argument). You can obtain the checkpoints either by running the training code or by downloading them from [this link](https://drive.google.com/file/d/1itUS5_i9AyUo_7awllTx8X0PXYw9fnaG/view?usp=drive_link).
+
+```
+python -m tasks.parity.analysis.run --log_dir <PATH_TO_LOG_DIR>
+```
tasks/parity/analysis/make_blog_gifs.py ADDED
@@ -0,0 +1,263 @@
+
+import torch
+import os
+import math
+import imageio
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.patches import FancyArrowPatch
+from scipy.special import softmax
+import matplotlib.cm as cm
+from data.custom_datasets import ParityDataset
+import umap
+from tqdm import tqdm
+
+
+from models.utils import reshape_predictions
+from tasks.parity.utils import reshape_inputs
+from tasks.parity.analysis.run import build_model_from_checkpoint_path
+
+from tasks.image_classification.plotting import save_frames_to_mp4
+
+
+def make_parity_gif(
+    predictions,
+    targets,
+    post_activations,
+    attention_weights,
+    inputs_to_model,
+    save_path,
+    umap_positions,
+    umap_point_scaler=1.0,
+):
+    batch_index = 0
+    figscale = 0.32
+    n_steps, n_heads, seqLen = attention_weights.shape[:3]
+    grid_side = int(np.sqrt(seqLen))
+    frames = []
+
+    inputs_this_batch = inputs_to_model[:, batch_index]
+    preds_this_batch = predictions[batch_index]
+    targets_this_batch = targets[batch_index]
+    post_act_this_batch = post_activations[:, batch_index]
+
+    # build a flexible mosaic
+    mosaic = [
+        [f"att_0", f"in_0", "probs", "probs", "target", "target"],
+        [f"att_1", f"in_1", "probs", "probs", "target", "target"],
+    ]
+    for h in range(2, n_heads):
+        mosaic.append(
+            [f"att_{h}", f"in_{h}", "umap", "umap",
+             "umap", "umap"]
+        )
+
+    for t in range(n_steps):
+        rows = len(mosaic)
+        cell_size = figscale * 4
+        fig_h = rows * cell_size
+
+        fig, ax = plt.subplot_mosaic(
+            mosaic,
+            figsize=(6 * cell_size, fig_h),
+            constrained_layout=False,
+            gridspec_kw={'wspace': 0.05, 'hspace': 0.05}, # small gaps
+        )
+        # restore a little margin
+        fig.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02)
+
+        # probabilities heatmap
+        logits_t = preds_this_batch[:, :, t]
+        probs_t = softmax(logits_t, axis=1)[:, 0].reshape(grid_side, grid_side)
+        ax["probs"].imshow(probs_t, cmap="gray", vmin=0, vmax=1)
+        ax["probs"].axis("off")
+
+        # target overlay
+        ax["target"].imshow(
+            targets_this_batch.reshape(grid_side, grid_side),
+            cmap="gray_r", vmin=0, vmax=1
+        )
+        ax["target"].axis("off")
+        ax["target"].grid(which="minor", color="black", linestyle="-", linewidth=0.5)
+
+        z = post_act_this_batch[t]
+        low, high = np.percentile(z, 5), np.percentile(z, 95)
+        z_norm = np.clip((z - low) / (high - low), 0, 1)
+        point_sizes = (np.abs(z_norm - 0.5) * 100 + 5) * umap_point_scaler
+        cmap = plt.get_cmap("Spectral")
+        ax["umap"].scatter(
+            umap_positions[:, 0],
+            umap_positions[:, 1],
+            s=point_sizes,
+            c=cmap(z_norm),
+            alpha=0.8
+        )
+        ax["umap"].axis("off")
+
+
98
+ # normalize attention
99
+ att_t = attention_weights[t, :, :]
100
+ a_min, a_max = att_t.min(), att_t.max()
101
+ if not np.isclose(a_min, a_max):
102
+ att_t = (att_t - a_min) / (a_max - a_min + 1e-8)
103
+ else:
104
+ att_t = np.zeros_like(att_t)
105
+
106
+ # input image for arrows
107
+ img_t = inputs_this_batch[t].transpose(1, 2, 0)
108
+
109
+ if t == 0:
110
+ route_history = [[] for _ in range(n_heads)]
111
+
112
+ img_h, img_w = img_t.shape[:2]
113
+ cell_h = img_h // grid_side
114
+ cell_w = img_w // grid_side
115
+
116
+ for h in range(n_heads):
117
+ head_map = att_t[h].reshape(grid_side, grid_side)
118
+ ax[f"att_{h}"].imshow(head_map, cmap="viridis", vmin=0, vmax=1)
119
+ ax[f"att_{h}"].axis("off")
120
+ ax[f"in_{h}"].imshow(img_t, cmap="gray", vmin=0, vmax=1)
121
+ ax[f"in_{h}"].axis("off")
122
+
123
+ # track argmax center
124
+ flat_idx = np.argmax(head_map)
125
+ gy, gx = divmod(flat_idx, grid_side)
126
+ cx = int((gx + 0.5) * cell_w)
127
+ cy = int((gy + 0.5) * cell_h)
128
+ route_history[h].append((cx, cy))
129
+
130
+ cmap_steps = plt.colormaps.get_cmap("Spectral")
131
+ colors = [cmap_steps(i / max(n_steps - 1, 1)) for i in range(n_steps)]
132
+ for i in range(len(route_history[h]) - 1):
133
+ x0, y0 = route_history[h][i]
134
+ x1, y1 = route_history[h][i + 1]
135
+ color = colors[i]
136
+ is_last = (i == len(route_history[h]) - 2)
137
+ style = '->' if is_last else '-'
138
+ lw = 2.0 if is_last else 1.6
139
+ alpha = 1.0 if is_last else 0.9
140
+ scale = 10 if is_last else 1
141
+
142
+ # draw arrow
143
+ arr = FancyArrowPatch(
144
+ (x0, y0), (x1, y1),
145
+ arrowstyle=style,
146
+ linewidth=lw,
147
+ mutation_scale=scale,
148
+ alpha=alpha,
149
+ facecolor=color,
150
+ edgecolor=color,
151
+ shrinkA=0, shrinkB=0,
152
+ capstyle='round', joinstyle='round',
153
+ zorder=3 if is_last else 2,
154
+ clip_on=False,
155
+ )
156
+ ax[f"in_{h}"].add_patch(arr)
157
+
158
+ ax[f"in_{h}"].scatter(
159
+ x1, y1,
160
+ marker='x',
161
+ s=40,
162
+ color=color,
163
+ linewidths=lw,
164
+ zorder=4
165
+ )
166
+
167
+ canvas = fig.canvas
168
+ canvas.draw()
169
+ frame = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
170
+ w, h = canvas.get_width_height()
171
+ frames.append(frame.reshape(h, w, 4)[..., :3])
172
+ plt.close(fig)
173
+
174
+ # save gif
175
+ imageio.mimsave(f"{save_path}/activation.gif", frames, fps=15, loop=0)
176
+
177
+ # save mp4
178
+ save_frames_to_mp4(
179
+ [fm[:, :, ::-1] for fm in frames], # RGB→BGR
180
+ f"{save_path}/activation.mp4",
181
+ fps=15,
182
+ gop_size=1,
183
+ preset="slow"
184
+ )
185
+
186
+ def run_umap(model, testloader):
187
+ all_post_activations = []
188
+ point_counts = 150
189
+ sampled = 0
190
+ with tqdm(total=point_counts, desc="Collecting UMAP data") as pbar:
191
+ for inputs, _ in testloader:
192
+ for i in range(inputs.size(0)):
193
+ if sampled >= point_counts:
194
+ break
195
+ input_i = inputs[i].unsqueeze(0).to(device)
196
+ _, _, _, _, post_activations, _ = model(input_i, track=True)
197
+ all_post_activations.append(post_activations)
198
+ sampled += 1
199
+ pbar.update(1)
200
+ if sampled >= point_counts:
201
+ break
202
+
203
+ stacked = np.stack(all_post_activations, 1)
204
+ umap_features = stacked.reshape(-1, stacked.shape[-1])
205
+ reducer = umap.UMAP(
206
+ n_components=2,
207
+ n_neighbors=20,
208
+ min_dist=1,
209
+ spread=1,
210
+ metric='cosine',
211
+ local_connectivity=1
212
+ )
213
+ positions = reducer.fit_transform(umap_features.T)
214
+ return positions
215
+
216
+
217
+ def run_model_and_make_gif(checkpoint_path, save_path, device):
218
+
219
+ parity_sequence_length = 64
220
+ iterations = 75
221
+
222
+ test_data = ParityDataset(sequence_length=parity_sequence_length, length=10000)
223
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=True, num_workers=0, drop_last=False)
224
+
225
+
226
+ model, _ = build_model_from_checkpoint_path(checkpoint_path, "ctm", device=device)
227
+
228
+ input = torch.randint(0, 2, (64,), dtype=torch.float32, device=device) * 2 - 1
229
+ input = input.unsqueeze(0)
230
+
231
+ target = torch.cumsum((input == -1).to(torch.long), dim=1) % 2
232
+ target = target.unsqueeze(0)
233
+
234
+ positions = run_umap(model, testloader)
235
+
236
+ model.eval()
237
+ with torch.inference_mode():
238
+ predictions, _, _, _, post_activations, attention = model(input, track=True)
239
+ predictions = reshape_predictions(predictions, prediction_reshaper=[parity_sequence_length, 2])
240
+ input_images = reshape_inputs(input, iterations, grid_size=int(math.sqrt(parity_sequence_length)))
241
+
242
+ make_parity_gif(
243
+ predictions=predictions.detach().cpu().numpy(),
244
+ targets=target.detach().cpu().numpy(),
245
+ post_activations=post_activations,
246
+ attention_weights=attention.squeeze(1).squeeze(2),
247
+ inputs_to_model=input_images,
248
+ save_path=save_path,
249
+ umap_positions=positions,
250
+ umap_point_scaler=1.0,
251
+ )
252
+
253
+
254
+
255
+ if __name__ == "__main__":
256
+
257
+ CHECKPOINT_PATH = "checkpoints/parity/run1/ctm_75_25/checkpoint_200000.pt"
258
+ SAVE_PATH = "tasks/parity/analysis/outputs/blog_gifs/"
259
+ os.makedirs(SAVE_PATH, exist_ok=True)
260
+
261
+ device = "cuda" if torch.cuda.is_available() else "cpu"
262
+
263
+ run_model_and_make_gif(CHECKPOINT_PATH, SAVE_PATH, device)
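The script above builds its parity target with `torch.cumsum((input == -1).to(torch.long), dim=1) % 2`: at each position, the target is the parity of the number of `-1` entries seen so far. A minimal NumPy sketch of that same rule follows. The function name `cumulative_parity_targets` is hypothetical and not part of the repository, and NumPy is used here in place of PyTorch purely to keep the illustration dependency-light.

```python
import numpy as np


def cumulative_parity_targets(sequence):
    """Parity (0 = even, 1 = odd) of the count of -1 entries up to each position.

    Hypothetical helper for illustration; it mirrors the cumsum-mod-2 rule
    used in the script, but on a 1-D NumPy array instead of a torch tensor.
    """
    sequence = np.asarray(sequence)
    # (sequence == -1) marks the negative entries; cumsum counts them so far.
    return np.cumsum(sequence == -1) % 2


example = np.array([1, -1, -1, 1, -1])
targets = cumulative_parity_targets(example)
```

For an all `+1` sequence the targets are all zero, and for an all `-1` sequence they alternate between 1 and 0, which matches the hand-crafted "all even" and "all odd" test cases in `tasks/parity/analysis/run.py`.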
tasks/parity/analysis/run.py ADDED
@@ -0,0 +1,269 @@
1
+ import torch
2
+ import numpy as np
3
+ import argparse
4
+ import multiprocessing
5
+ from tqdm import tqdm
6
+ import math
7
+ import os
8
+ import csv
9
+ from utils.housekeeping import set_seed
10
+ from data.custom_datasets import ParityDataset
11
+ from tasks.parity.utils import prepare_model, reshape_attention_weights, reshape_inputs, get_where_most_certain
12
+ from tasks.parity.plotting import plot_attention_trajectory, plot_input, plot_target, plot_probabilities, plot_prediction, plot_accuracy_training, create_attentions_heatmap_gif, create_accuracies_heatmap_gif, create_stacked_gif, plot_training_curve_all_runs, plot_accuracy_thinking_time, make_parity_gif, plot_lstm_last_and_certain_accuracy
13
+ from models.utils import compute_normalized_entropy, reshape_predictions, get_latest_checkpoint_file, get_checkpoint_files, load_checkpoint, get_model_args_from_checkpoint, get_all_log_dirs
14
+ from tasks.image_classification.plotting import plot_neural_dynamics
15
+
16
+ import seaborn as sns
17
+ sns.set_palette("hls")
18
+ sns.set_style('darkgrid')
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser(description='Parity Analysis')
22
+ parser.add_argument('--log_dir', type=str, default='checkpoints/parity', help='Directory to save logs.')
23
+ parser.add_argument('--batch_size_test', type=int, default=128, help='Batch size for testing.')
24
+ parser.add_argument('--scale_training_curve', type=float, default=0.6, help='Scaling factor for plots.')
25
+ parser.add_argument('--scale_heatmap', type=float, default=0.4, help='Scaling factor for heatmap plots.')
26
+ parser.add_argument('--scale_training_index_accuracy', type=float, default=0.4, help='Scaling factor for training index accuracy plots.')
27
+ parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility.')
28
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
29
+ parser.add_argument('--model_type', type=str, choices=['ctm', 'lstm'], default='ctm', help='Type of model to analyze (ctm or lstm).')
30
+ return parser.parse_args()
31
+
32
+ def calculate_corrects(predictions, targets):
33
+ predicted_labels = predictions.argmax(2)
34
+ accuracy = (predicted_labels == targets.unsqueeze(-1))
35
+ return accuracy.detach().cpu().numpy()
36
+
37
+ def get_corrects_per_element_at_most_certain_time(predictions, certainty, targets):
38
+ where_most_certain = get_where_most_certain(certainty)
39
+ corrects = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device),:,where_most_certain] == targets).float()
40
+ return corrects.detach().cpu().numpy()
41
+
42
+ def calculate_entropy_average_over_batch(normalized_entropy_per_elements):
43
+ normalized_entropy_per_elements_avg_batch = normalized_entropy_per_elements.mean(axis=1)
44
+ return normalized_entropy_per_elements_avg_batch
45
+
46
+ def calculate_thinking_time_average_over_batch(normalized_entropy_per_elements):
47
+ first_occurrence = calculate_thinking_time(normalized_entropy_per_elements)
48
+ average_thinking_time = np.mean(first_occurrence, axis=0)
49
+ return average_thinking_time
50
+
51
+ def calculate_thinking_time(normalized_entropy_per_elements, finish_type="min", entropy_threshold=0.1):
52
+ if finish_type == "min":
53
+ min_entropy_time = np.argmin(normalized_entropy_per_elements, axis=0)
54
+ return min_entropy_time
55
+ elif finish_type == "threshold":
56
+ T, B, S = normalized_entropy_per_elements.shape
57
+ below_threshold = normalized_entropy_per_elements < entropy_threshold
58
+ first_occurrence = np.argmax(below_threshold, axis=0)
59
+ no_true = ~np.any(below_threshold, axis=0)
60
+ first_occurrence[no_true] = T
61
+ return first_occurrence
62
+
63
+ def test_handcrafted_examples(model, args, run_model_spefic_save_dir, device):
64
+ test_cases = []
65
+ all_even_input = torch.full((args.parity_sequence_length,), 1.0, dtype=torch.float32, device=device)
66
+ all_even_target = torch.zeros_like(all_even_input, dtype=torch.long)
67
+ test_cases.append((all_even_input, all_even_target))
68
+
69
+ all_odd_input = torch.full((args.parity_sequence_length,), -1.0, dtype=torch.float32, device=device)
70
+ all_odd_target = torch.cumsum((all_odd_input == -1).to(torch.long), dim=0) % 2
71
+ test_cases.append((all_odd_input, all_odd_target))
72
+
73
+ random_input = torch.randint(0, 2, (args.parity_sequence_length,), dtype=torch.float32, device=device) * 2 - 1
74
+ random_target = torch.cumsum((random_input == -1).to(torch.long), dim=0) % 2
75
+ test_cases.append((random_input, random_target))
76
+
77
+ for i, (inputs, targets) in enumerate(test_cases):
78
+ inputs = inputs.unsqueeze(0)
79
+ targets = targets.unsqueeze(0)
80
+ filename = f"eval_handcrafted_{i}"
81
+ extend_inference_time = False
82
+ handcraft_dir = f"{run_model_spefic_save_dir}/handcrafted_examples/{i}"
83
+ os.makedirs(handcraft_dir, exist_ok=True)
84
+
85
+ model.eval()
86
+ with torch.inference_mode():
87
+ if extend_inference_time:
88
+ model.iterations = model.iterations * 2
89
+ predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
90
+ predictions = reshape_predictions(predictions, prediction_reshaper=[args.parity_sequence_length, 2])
91
+ input_images = reshape_inputs(inputs, args.iterations, grid_size=int(math.sqrt(args.parity_sequence_length)))
92
+
93
+ plot_neural_dynamics(post_activations, 100, handcraft_dir, axis_snap=False)
94
+
95
+ process = multiprocessing.Process(
96
+ target=make_parity_gif,
97
+ args=(
98
+ predictions.detach().cpu().numpy(),
99
+ certainties.detach().cpu().numpy(),
100
+ targets.detach().cpu().numpy(),
101
+ pre_activations,
102
+ post_activations,
103
+ reshape_attention_weights(attention),
104
+ input_images,
105
+ f"{handcraft_dir}/eval_output_val_{0}_iter_{0}.gif",
106
+ ))
107
+ process.start()
108
+
109
+
110
+ input_images = input_images.squeeze(1).squeeze(1)
111
+ attention = attention.squeeze(1)
112
+
113
+ for h in range(args.heads):
114
+ plot_attention_trajectory(attention[:, h, :, :], certainties, input_images, handcraft_dir, filename + f"_head_{h}", args)
115
+
116
+ plot_attention_trajectory(attention.mean(1), certainties, input_images, handcraft_dir, filename, args)
117
+ plot_input(input_images, handcraft_dir, filename)
118
+ plot_target(targets, handcraft_dir, filename, args)
119
+ plot_probabilities(predictions, certainties, handcraft_dir, filename, args)
120
+ plot_prediction(predictions, certainties,handcraft_dir, filename, args)
121
+
122
+ if extend_inference_time:
123
+ model.iterations = model.iterations // 2
124
+ model.train()
125
+ pass
126
+
127
+ def build_model_from_checkpoint_path(checkpoint_path, model_type, device="cpu"):
128
+ checkpoint = load_checkpoint(checkpoint_path, device)
129
+ model_args = get_model_args_from_checkpoint(checkpoint)
130
+ model = prepare_model([model_args.parity_sequence_length, 2], model_args, device)
131
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
132
+ return model, model_args
133
+
134
+ def analyze_trained_model(run_model_spefic_save_dir, args, device):
135
+ with torch.no_grad():
136
+
137
+ latest_checkpoint_path = get_latest_checkpoint_file(args.log_dir)
138
+ model, model_args = build_model_from_checkpoint_path(latest_checkpoint_path, args.model_type, device=device)
139
+ model.eval()
140
+ model_args.log_dir = args.log_dir
141
+ test_data = ParityDataset(sequence_length=model_args.parity_sequence_length, length=10000)
142
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=0, drop_last=False)
143
+
144
+ corrects, corrects_at_most_certain_times, entropys, attentions = [], [], [], []
145
+
146
+ for inputs, targets in testloader:
147
+ inputs = inputs.to(device)
148
+ targets = targets.to(device)
149
+ predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
150
+ predictions = reshape_predictions(predictions, prediction_reshaper=[model_args.parity_sequence_length, 2])
151
+ corrects_batch = calculate_corrects(predictions, targets)
152
+ corrects_at_most_certain_time_batch = get_corrects_per_element_at_most_certain_time(predictions, certainties, targets)
153
+ corrects.append(corrects_batch)
154
+ corrects_at_most_certain_times.append(corrects_at_most_certain_time_batch)
155
+ attentions.append(attention)
156
+
157
+ test_handcrafted_examples(model, model_args, run_model_spefic_save_dir, device)
158
+
159
+ overall_mean_accuracy = np.mean(np.vstack(corrects_at_most_certain_times))
160
+ overall_std_accuracy = np.std(np.mean(np.vstack(corrects_at_most_certain_times), axis=1))
161
+
162
+ return overall_mean_accuracy, overall_std_accuracy, model_args.iterations
163
+
164
+ def analyze_training(run_model_spefic_save_dir, args, device):
165
+ checkpoint_files = get_checkpoint_files(args.log_dir)
166
+ all_accuracies = []
167
+ all_accuracies_at_most_certain_time = []
168
+ all_average_thinking_times = []
169
+ all_std_thinking_times = []
170
+ all_attentions = []
171
+ for checkpoint_path in checkpoint_files:
172
+ model, model_args = build_model_from_checkpoint_path(checkpoint_path, args.model_type, device=device)
173
+ model_args.log_dir = run_model_spefic_save_dir
174
+ test_data = ParityDataset(sequence_length=model_args.parity_sequence_length, length=1000)
175
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=0, drop_last=False)
176
+ corrects = []
177
+ corrects_at_most_certain_times = []
178
+ thinking_times = []
179
+ attentions = []
180
+
181
+ for inputs, targets in testloader:
182
+ inputs = inputs.to(device)
183
+ targets = targets.to(device)
184
+ predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
185
+ predictions = reshape_predictions(predictions, prediction_reshaper=[model_args.parity_sequence_length, 2])
186
+ attention = reshape_attention_weights(attention)
187
+
188
+ corrects_batch = calculate_corrects(predictions, targets)
189
+ corrects_at_most_certain_time_batch = get_corrects_per_element_at_most_certain_time(predictions, certainties, targets)
190
+ entropy_per_element = compute_normalized_entropy(predictions.permute(0,3,1,2), reduction='none').detach().cpu().numpy()
191
+ thinking_times_batch = np.argmin(entropy_per_element, axis=1)
192
+
193
+ corrects.append(corrects_batch)
194
+ corrects_at_most_certain_times.append(corrects_at_most_certain_time_batch)
195
+ thinking_times.append(thinking_times_batch)
196
+ attentions.append(attention)
197
+
198
+ checkpoint_average_accuracies = np.mean(np.concatenate(corrects, axis=0), axis=0).transpose(1,0)
199
+ all_accuracies.append(checkpoint_average_accuracies)
200
+
201
+ stacked_corrects_at_most_certain_times = np.vstack(corrects_at_most_certain_times)
202
+ checkpoint_average_accuracy_at_most_certain_time = np.mean(stacked_corrects_at_most_certain_times, axis=0)
203
+ all_accuracies_at_most_certain_time.append(checkpoint_average_accuracy_at_most_certain_time)
204
+
205
+ checkpoint_thinking_times = np.concatenate(thinking_times, axis=0)
206
+ checkpoint_average_thinking_time = np.mean(checkpoint_thinking_times, axis=0)
207
+ checkpoint_std_thinking_time = np.std(checkpoint_thinking_times, axis=0)
208
+ all_average_thinking_times.append(checkpoint_average_thinking_time)
209
+ all_std_thinking_times.append(checkpoint_std_thinking_time)
210
+
211
+ checkpoint_average_attentions = np.mean(np.concatenate(attentions, axis=1), axis=1)
212
+ all_attentions.append(checkpoint_average_attentions)
213
+
214
+ plot_accuracy_training(all_accuracies_at_most_certain_time, args.scale_training_index_accuracy, run_model_spefic_save_dir, args=model_args)
215
+ create_attentions_heatmap_gif(all_attentions, args.scale_heatmap, run_model_spefic_save_dir, model_args)
216
+ create_accuracies_heatmap_gif(np.array(all_accuracies), all_average_thinking_times, all_std_thinking_times, args.scale_heatmap, run_model_spefic_save_dir, model_args)
217
+ create_stacked_gif(run_model_spefic_save_dir)
218
+
219
+ def get_accuracy_and_loss_from_checkpoint(checkpoint):
220
+ training_iteration = checkpoint.get('training_iteration', 0)
221
+ train_losses = checkpoint.get('train_losses', [])
222
+ test_losses = checkpoint.get('test_losses', [])
223
+ train_accuracies = checkpoint.get('train_accuracies_most_certain', [])
224
+ test_accuracies = checkpoint.get('test_accuracies_most_certain', [])
225
+ return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies
226
+
227
+ if __name__ == "__main__":
228
+
229
+ args = parse_args()
230
+
231
+ device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
232
+
233
+ set_seed(args.seed)
234
+
235
+ save_dir = "tasks/parity/analysis/outputs"
236
+ os.makedirs(save_dir, exist_ok=True)
237
+
238
+ accuracy_csv_file_path = os.path.join(save_dir, "accuracy.csv")
239
+ if os.path.exists(accuracy_csv_file_path):
240
+ os.remove(accuracy_csv_file_path)
241
+
242
+ all_runs_log_dirs = get_all_log_dirs(args.log_dir)
243
+
244
+ plot_training_curve_all_runs(all_runs_log_dirs, save_dir, args.scale_training_curve, device, x_max=200_000)
245
+ plot_lstm_last_and_certain_accuracy(all_folders=all_runs_log_dirs, save_path=f"{save_dir}/lstm_final_vs_certain_accuracy.png", scale=args.scale_training_curve)
246
+
247
+ progress_bar = tqdm(all_runs_log_dirs, desc="Analyzing Runs", dynamic_ncols=True)
248
+ for folder in progress_bar:
249
+
250
+ run, model_name = folder.strip("/").split("/")[-2:]
251
+
252
+ run_model_spefic_save_dir = f"{save_dir}/{model_name}/{run}"
253
+ os.makedirs(run_model_spefic_save_dir, exist_ok=True)
254
+
255
+ args.log_dir = folder
256
+ progress_bar.set_description(f"Analyzing Trained Model at {folder}")
257
+
258
+ accuracy_mean, accuracy_std, num_iterations = analyze_trained_model(run_model_spefic_save_dir, args, device)
259
+
260
+ with open(accuracy_csv_file_path, mode='a', newline='') as file:
261
+ writer = csv.writer(file)
262
+ if file.tell() == 0:
263
+ writer.writerow(["Run", "Overall Mean Accuracy", "Overall Std Accuracy", "Num Iterations"])
264
+ writer.writerow([folder, accuracy_mean, accuracy_std, num_iterations])
265
+
266
+ progress_bar.set_description(f"Analyzing Training at {folder}")
267
+ analyze_training(run_model_spefic_save_dir, args, device)
268
+
269
+ plot_accuracy_thinking_time(accuracy_csv_file_path, scale=args.scale_training_curve, output_dir=save_dir)
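The `"threshold"` branch of `calculate_thinking_time` in the file above finds, for every sequence element, the first internal step at which normalized entropy drops below a threshold, and falls back to `T` when that never happens. The sketch below restates that logic in isolation, assuming an entropy array of shape `(T, B, S)` as in the original. The name `first_step_below_threshold` and the sample values are hypothetical.

```python
import numpy as np


def first_step_below_threshold(entropy, threshold=0.1):
    """First step index where entropy < threshold, per (batch, element).

    Assumes `entropy` has shape (T, B, S). Elements that never cross the
    threshold are assigned T, i.e. one past the last valid step index.
    """
    entropy = np.asarray(entropy)
    n_steps = entropy.shape[0]
    below = entropy < threshold
    # argmax over a boolean array returns the index of the first True,
    # but also returns 0 when there is no True at all.
    first = np.argmax(below, axis=0)
    # Disambiguate the "no True" case by overwriting it with n_steps.
    first[~below.any(axis=0)] = n_steps
    return first


# Hypothetical entropies: T=4 steps, B=1 sample, S=3 sequence elements.
entropy = np.array([
    [[0.90, 0.90, 0.05]],
    [[0.50, 0.80, 0.90]],
    [[0.05, 0.70, 0.90]],
    [[0.01, 0.60, 0.90]],
])
thinking_times = first_step_below_threshold(entropy)
```

The explicit overwrite matters because `np.argmax` alone cannot distinguish "crossed at step 0" from "never crossed"; without it, elements that stay uncertain would look like the fastest ones.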
tasks/parity/plotting.py ADDED
@@ -0,0 +1,897 @@
1
+ import os
2
+ import seaborn as sns
3
+ import numpy as np
4
+ import pandas as pd
5
+ from collections import defaultdict
6
+ from matplotlib.lines import Line2D
7
+ import matplotlib as mpl
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib.patheffects as path_effects
10
+ from matplotlib.ticker import FuncFormatter
11
+ from scipy.special import softmax
12
+ import imageio.v2 as imageio
13
+ from PIL import Image
14
+ import math
15
+ import re
16
+ from tqdm import tqdm
17
+ sns.set_style('darkgrid')
18
+ mpl.use('Agg')
19
+
20
+ from tasks.parity.utils import get_where_most_certain, parse_folder_name
21
+ from models.utils import get_latest_checkpoint_file, load_checkpoint, get_model_args_from_checkpoint, get_accuracy_and_loss_from_checkpoint
22
+ from tasks.image_classification.plotting import save_frames_to_mp4
23
+
24
+ def make_parity_gif(predictions, certainties, targets, pre_activations, post_activations, attention_weights, inputs_to_model, filename):
25
+
26
+ # Config
27
+ batch_index = 0
28
+ n_neurons_to_visualise = 16
29
+ figscale = 0.28
30
+ n_steps = len(pre_activations)
31
+ frames = []
32
+ heatmap_cmap = sns.color_palette("viridis", as_cmap=True)
33
+
34
+ these_pre_acts = pre_activations[:, batch_index, :] # Shape: (T, H)
35
+ these_post_acts = post_activations[:, batch_index, :] # Shape: (T, H)
36
+ these_inputs = inputs_to_model[:, batch_index, :, :, :] # Shape: (T, C, H, W)
37
+ these_predictions = predictions[batch_index, :, :, :] # Shape: (d, C, T)
38
+ these_certainties = certainties[batch_index, :, :] # Shape: (C, T)
39
+ these_attention_weights = attention_weights[:, batch_index, :, :]
40
+
41
+ # Create mosaic layout
42
+ mosaic = [['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'target', 'target'] for _ in range(2)] + \
43
+ [['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'target', 'target'] for _ in range(2)] + \
44
+ [['certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty']] + \
45
+ [[f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}'] for ti in range(n_neurons_to_visualise)]
46
+
47
+ for stepi in tqdm(range(n_steps), desc="Processing steps", unit="step"):
48
+ fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale))
49
+
50
+ # Plot predictions
51
+ d = these_predictions.shape[0]
52
+ grid_side = int(np.sqrt(d))
53
+ logits = these_predictions[:, :, stepi]
54
+
55
+ probs = softmax(logits, axis=1)
56
+ probs_grid = probs[:, 0].reshape(grid_side, grid_side)
57
+ axes_gif["probs"].imshow(probs_grid, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
58
+ axes_gif["probs"].axis('off')
59
+ axes_gif["probs"].set_title('Probabilities')
60
+
61
+ # Create and show attention heatmap
62
+ this_input_gate = these_attention_weights[stepi]
63
+ gate_min, gate_max = np.nanmin(this_input_gate), np.nanmax(this_input_gate)
64
+ if not np.isclose(gate_min, gate_max):
65
+ normalized_gate = (this_input_gate - gate_min) / (gate_max - gate_min + 1e-8)
66
+ else:
67
+ normalized_gate = np.zeros_like(this_input_gate)
68
+ attention_weights_heatmap = heatmap_cmap(normalized_gate)[:,:,:3]
69
+
70
+ # Show heatmaps
71
+ axes_gif['attention'].imshow(attention_weights_heatmap, vmin=0, vmax=1)
72
+ axes_gif['attention'].axis('off')
73
+ axes_gif['attention'].set_title('Attention')
74
+
75
+
76
+ # Plot target
77
+ target_grid = targets[batch_index].reshape(grid_side, grid_side)
78
+ axes_gif["target"].imshow(target_grid, cmap='viridis_r', interpolation='nearest', vmin=0, vmax=1)
79
+ axes_gif["target"].axis('off')
80
+ axes_gif["target"].set_title('Target')
81
+
82
+ # Add certainty plot
83
+ axes_gif['certainty'].plot(np.arange(n_steps), these_certainties[1], 'k-', linewidth=2)
84
+ axes_gif['certainty'].set_xlim([0, n_steps-1])
85
+ axes_gif['certainty'].axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
86
+ axes_gif['certainty'].set_xticklabels([])
87
+ axes_gif['certainty'].set_yticklabels([])
88
+ axes_gif['certainty'].grid(False)
89
+
90
+ # Plot neuron traces
91
+ for neuroni in range(n_neurons_to_visualise):
92
+ ax = axes_gif[f'trace_{neuroni}']
93
+
94
+ pre_activation = these_pre_acts[:, neuroni]
95
+ post_activation = these_post_acts[:, neuroni]
96
+
97
+ ax_pre = ax.twinx()
98
+
99
+ pre_min, pre_max = np.min(pre_activation), np.max(pre_activation)
100
+ post_min, post_max = np.min(post_activation), np.max(post_activation)
101
+
102
+ ax_pre.plot(np.arange(n_steps), pre_activation,
103
+ color='grey',
104
+ linestyle='--',
105
+ linewidth=1,
106
+ alpha=0.4,
107
+ label='Pre-activation')
108
+
109
+ color = 'blue' if neuroni % 2 else 'red'
110
+ ax.plot(np.arange(n_steps), post_activation,
111
+ color=color,
112
+ linestyle='-',
113
+ linewidth=2,
114
+ alpha=1.0,
115
+ label='Post-activation')
116
+
117
+ ax.set_xlim([0, n_steps-1])
118
+ ax_pre.set_xlim([0, n_steps-1])
119
+
120
+ if pre_min != pre_max:
121
+ ax_pre.set_ylim([pre_min, pre_max])
122
+ if post_min != post_max:
123
+ ax.set_ylim([post_min, post_max])
124
+
125
+ ax.axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
126
+
127
+ ax.set_xticklabels([])
128
+ ax.set_yticklabels([])
129
+ ax.grid(False)
130
+
131
+ ax_pre.set_xticklabels([])
132
+ ax_pre.set_yticklabels([])
133
+ ax_pre.grid(False)
134
+
135
+ # Show input image
136
+ this_image = these_inputs[stepi].transpose(1, 2, 0)
137
+ axes_gif['img_data'].imshow(this_image, cmap='viridis', vmin=0, vmax=1)
138
+ axes_gif['img_data'].grid(False)
139
+ axes_gif['img_data'].set_xticks([])
140
+ axes_gif['img_data'].set_yticks([])
141
+ axes_gif['img_data'].set_title('Input')
142
+
143
+ # Save frames
144
+ fig_gif.tight_layout(pad=0.1)
145
+ if stepi == 0:
146
+ fig_gif.savefig(filename.split('.gif')[0]+'_frame0.png', dpi=100)
147
+ if stepi == 1:
148
+ fig_gif.savefig(filename.split('.gif')[0]+'_frame1.png', dpi=100)
149
+ if stepi == n_steps-1:
150
+ fig_gif.savefig(filename.split('.gif')[0]+'_frame-1.png', dpi=100)
151
+
152
+ # Convert to frame
153
+ canvas = fig_gif.canvas
154
+ canvas.draw()
155
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
156
+ image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3]
157
+ frames.append(image_numpy)
158
+ plt.close(fig_gif)
159
+
160
+ imageio.mimsave(filename, frames, fps=15, loop=100)
161
+
162
+ pass
163
+
164
+
165
+ def plot_attention_trajectory(attention, certainties, input_images, save_dir, filename, args):
166
+ where_most_certain = get_where_most_certain(certainties)
167
+ grid_size = int(math.sqrt(args.parity_sequence_length))
168
+ trajectory = [np.unravel_index(np.argmax(attention[t]), (grid_size, grid_size)) for t in range(args.iterations)]
169
+ x_coords, y_coords = zip(*trajectory)
170
+
171
+ plt.figure(figsize=(5, 5))
172
+ plt.imshow(input_images[0], cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
173
+
174
+ ax = plt.gca()
175
+ nrows, ncols = input_images[0].shape
176
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
177
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
178
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
179
+ ax.tick_params(which="minor", size=0)
180
+ ax.set_axisbelow(False)
181
+ plt.xticks([])
182
+ plt.yticks([])
183
+
184
+ cmap = plt.get_cmap("plasma")
185
+ norm_time = np.linspace(0, 1, len(trajectory))
186
+
187
+ for i in range(len(trajectory) - 1):
188
+ x1, y1 = x_coords[i], y_coords[i]
189
+ x2, y2 = x_coords[i + 1], y_coords[i + 1]
190
+ color = cmap(norm_time[i])
191
+ line, = plt.plot([y1, y2], [x1, x2], color=color, linewidth=6, alpha=0.5, zorder=4)
192
+ line.set_path_effects([
193
+ path_effects.Stroke(linewidth=8, foreground='white'),
194
+ path_effects.Normal()
195
+ ])
196
+
197
+ for i, (x, y) in enumerate(trajectory):
198
+ plt.scatter(y, x, color=cmap(norm_time[i]), s=100, edgecolor='white', linewidth=1.5, zorder=5)
199
+
200
+ most_certain_point = trajectory[where_most_certain]
201
+
202
+ plt.plot(most_certain_point[1], most_certain_point[0],
203
+ marker='x', markersize=18, markeredgewidth=5,
204
+ color='white', linestyle='', zorder=6)
205
+ plt.plot(most_certain_point[1], most_certain_point[0],
206
+ marker='x', markersize=15, markeredgewidth=3,
207
+ color=cmap(norm_time[where_most_certain]), linestyle='', zorder=7)
208
+
209
+ plt.tight_layout()
210
+ plt.savefig(f"{save_dir}/{filename}_traj.png", dpi=300, bbox_inches='tight', pad_inches=0)
211
+ plt.savefig(f"{save_dir}/{filename}_traj.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
212
+ plt.show()
213
+ plt.close()
214
+
215
+ def plot_input(input_images, save_dir, filename):
216
+
217
+ plt.figure(figsize=(5, 5))
218
+ plt.imshow(input_images[0], cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
219
+
220
+ ax = plt.gca()
221
+ nrows, ncols = input_images[0].shape
222
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
223
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
224
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
225
+ ax.tick_params(which="minor", size=0)
226
+ ax.set_axisbelow(False)
227
+ plt.xticks([])
228
+ plt.yticks([])
229
+
230
+ plt.tight_layout()
231
+ plt.savefig(f"{save_dir}/{filename}_input.png", dpi=300, bbox_inches='tight', pad_inches=0)
232
+ plt.savefig(f"{save_dir}/{filename}_input.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
233
+ plt.show()
234
+ plt.close()
235
+
236
+ def plot_target(targets, save_dir, filename, args):
237
+ grid_size = int(math.sqrt(args.parity_sequence_length))
238
+ targets_grid = targets[0].reshape(grid_size, grid_size).detach().cpu().numpy()
239
+ plt.figure(figsize=(5, 5))
240
+ plt.imshow(targets_grid, cmap="gray_r", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
241
+ ax = plt.gca()
242
+ nrows, ncols = targets_grid.shape
243
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
244
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
245
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
246
+ ax.tick_params(which="minor", size=0)
247
+ ax.set_axisbelow(False)
248
+ plt.xticks([])
249
+ plt.yticks([])
250
+ plt.tight_layout()
251
+ plt.savefig(f"{save_dir}/{filename}_target.png", dpi=300, bbox_inches='tight', pad_inches=0)
252
+ plt.savefig(f"{save_dir}/{filename}_target.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
253
+ plt.show()
254
+ plt.close()
255
+
256
+ def plot_probabilities(predictions, certainties, save_dir, filename, args):
257
+ grid_size = int(math.sqrt(args.parity_sequence_length))
258
+ where_most_certain = get_where_most_certain(certainties)
259
+ predictions_most_certain = predictions[0, :, :, where_most_certain].detach().cpu().numpy()
260
+ probs = softmax(predictions_most_certain, axis=1)
261
+ probs_grid = probs[:, 0].reshape(grid_size, grid_size)
262
+ plt.figure(figsize=(5, 5))
263
+ plt.imshow(probs_grid, cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
264
+ ax = plt.gca()
265
+ nrows, ncols = probs_grid.shape
266
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
267
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
268
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
269
+ ax.tick_params(which="minor", size=0)
270
+ ax.set_axisbelow(False)
271
+ plt.xticks([])
272
+ plt.yticks([])
273
+ plt.tight_layout()
274
+ plt.savefig(f"{save_dir}/{filename}_probs.png", dpi=300, bbox_inches='tight', pad_inches=0)
275
+ plt.savefig(f"{save_dir}/{filename}_probs.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
276
+ plt.show()
277
+ plt.close()
278
+
279
+ def plot_prediction(predictions, certainties, save_dir, filename, args):
280
+ grid_size = int(math.sqrt(args.parity_sequence_length))
281
+ where_most_certain = get_where_most_certain(certainties)
282
+ predictions_most_certain = predictions[0, :, :, where_most_certain].detach().cpu().numpy()
283
+ class_grid = np.argmax(predictions_most_certain, axis=1).reshape(grid_size, grid_size)
284
+
285
+ plt.figure(figsize=(5, 5))
286
+ plt.imshow(class_grid, cmap="gray_r", origin="upper", vmin=0, vmax=1, interpolation='nearest')
287
+
288
+ ax = plt.gca()
289
+ nrows, ncols = class_grid.shape
290
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
291
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
292
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
293
+ ax.tick_params(which="minor", size=0)
294
+ ax.set_axisbelow(False)
295
+ plt.xticks([])
296
+ plt.yticks([])
297
+
298
+ plt.tight_layout()
299
+ plt.savefig(f"{save_dir}/{filename}_prediction.png", dpi=300, bbox_inches='tight', pad_inches=0)
300
+ plt.savefig(f"{save_dir}/{filename}_prediction.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
301
+ plt.show()
302
+ plt.close()
303
+
304
+ def plot_accuracy_heatmap(overall_accuracies_avg, average_thinking_time, std_thinking_time, scale, save_path, args):
305
+ fig, ax = plt.subplots(figsize=(scale*10, scale*5))
306
+ im = ax.imshow(overall_accuracies_avg.T * 100, aspect='auto', cmap="viridis", origin='lower', extent=[0, args.iterations-1, 0, args.parity_sequence_length-1], vmin=50, vmax=100)
307
+ cbar = fig.colorbar(im, ax=ax, format="%.1f")
308
+ cbar.set_label("Accuracy (%)")
309
+ ax.errorbar(average_thinking_time, np.arange(args.parity_sequence_length), xerr=std_thinking_time, fmt='ko', markersize=2, capsize=2, elinewidth=1, label="Min. Entropy")
310
+ ax.set_xlabel("Time Step")
311
+ ax.set_ylabel("Sequence Index")
312
+ ax.set_xlim(0, args.iterations-1)
313
+ ax.set_ylim(0, args.parity_sequence_length-1)
314
+ ax.grid(False)
315
+ ax.legend(loc="upper left")
316
+ fig.tight_layout(pad=0.1)
317
+ fig.savefig(save_path, dpi=300, bbox_inches="tight")
318
+ fig.savefig(save_path.replace(".png", ".pdf"), format='pdf', bbox_inches="tight")
319
+ plt.close(fig)
320
+
321
+ def plot_attention_heatmap(overall_attentions_avg, scale, save_path, vmin=None, vmax=None):
322
+ overall_attentions_avg = overall_attentions_avg.reshape(overall_attentions_avg.shape[0], -1)
323
+ fig, ax = plt.subplots(figsize=(scale*10, scale*5))
324
+ im = ax.imshow(overall_attentions_avg.T, aspect='auto', cmap="viridis", origin='lower', extent=[0, overall_attentions_avg.shape[0]-1, 0, overall_attentions_avg.shape[1]-1], vmin=vmin, vmax=vmax)
325
+ cbar = fig.colorbar(im, ax=ax, format=FuncFormatter(lambda x, _: f"{x:05.2f}"))
326
+ cbar.set_label("Attention Weight")
327
+ ax.set_xlabel("Time Step")
328
+ ax.set_ylabel("Sequence Index")
329
+ ax.set_xlim(0, overall_attentions_avg.shape[0]-1)
330
+ ax.set_ylim(0, overall_attentions_avg.shape[1]-1)
331
+ ax.grid(False)
332
+ fig.tight_layout(pad=0.1)
333
+ fig.savefig(save_path, dpi=300, bbox_inches="tight")
334
+ fig.savefig(save_path.replace(".png", ".pdf"), format='pdf', bbox_inches="tight")
335
+ plt.close(fig)
336
+
337
+ def create_accuracies_heatmap_gif(all_accuracies, all_average_thinking_times, all_std_thinking_times, scale, save_dir, args):
338
+ heatmap_components_dir = os.path.join(save_dir, "accuracy_heatmaps")
339
+ os.makedirs(heatmap_components_dir, exist_ok=True)
340
+
341
+ image_paths = []
342
+
343
+ for i, (accuracies, avg_thinking_time, std_thinking_time) in enumerate(zip(all_accuracies, all_average_thinking_times, all_std_thinking_times)):
344
+ save_path = os.path.join(heatmap_components_dir, f"frame_{i:04d}.png")
345
+ plot_accuracy_heatmap(accuracies, avg_thinking_time, std_thinking_time, scale, save_path, args)
346
+ image_paths.append(save_path)
347
+
348
+ gif_path = os.path.join(save_dir, "accuracy_heatmap.gif")
349
+ with imageio.get_writer(gif_path, mode='I', duration=0.3) as writer:
350
+ for image_path in image_paths:
351
+ image = imageio.imread(image_path)
352
+ writer.append_data(image)
353
+
354
+ def create_attentions_heatmap_gif(all_attentions, scale, save_path, args):
355
+ heatmap_components_dir = os.path.join(args.log_dir, "attention_heatmaps")
356
+ os.makedirs(heatmap_components_dir, exist_ok=True)
357
+
358
+ global_min = min(attentions.min() for attentions in all_attentions)
359
+ global_max = max(attentions.max() for attentions in all_attentions)
360
+
361
+ image_paths = []
362
+
363
+ for i, attentions in enumerate(all_attentions):
364
+ save_path_component = os.path.join(heatmap_components_dir, f"frame_{i:04d}.png")
365
+ plot_attention_heatmap(attentions, scale, save_path_component, vmin=global_min, vmax=global_max)
366
+ image_paths.append(save_path_component)
367
+
368
+ gif_path = os.path.join(save_path, "attention_heatmap.gif")
369
+ with imageio.get_writer(gif_path, mode='I', duration=0.3) as writer:
370
+ for image_path in image_paths:
371
+ image = imageio.imread(image_path)
372
+ writer.append_data(image)
373
+
374
+ def create_stacked_gif(save_path, y_shift=200):
375
+ accuracy_gif_path = os.path.join(save_path, "accuracy_heatmap.gif")
376
+ attention_gif_path = os.path.join(save_path, "attention_heatmap.gif")
377
+ stacked_gif_path = os.path.join(save_path, "stacked_heatmap.gif")
378
+
379
+ accuracy_reader = imageio.get_reader(accuracy_gif_path)
380
+ attention_reader = imageio.get_reader(attention_gif_path)
381
+
382
+ accuracy_frames = [Image.fromarray(frame) for frame in accuracy_reader]
383
+ attention_frames = [Image.fromarray(frame) for frame in attention_reader]
384
+
385
+ assert len(accuracy_frames) == len(attention_frames), "Mismatch in frame counts between accuracy and attention GIFs"
386
+
387
+ stacked_frames = []
388
+ for acc_frame, att_frame in zip(accuracy_frames, attention_frames):
389
+ acc_width, acc_height = acc_frame.size
390
+ att_width, att_height = att_frame.size
391
+
392
+ # Create base canvas
393
+ stacked_height = acc_height + att_height - y_shift
394
+ stacked_width = max(acc_width, att_width)
395
+
396
+ stacked_frame = Image.new("RGB", (stacked_width, stacked_height), color=(255, 255, 255))
397
+
398
+ # Paste the attention frame at the top, then the accuracy frame below it, overlapping by y_shift pixels
399
+ stacked_frame.paste(att_frame, (0, 0)) # Paste at top
400
+ stacked_frame.paste(acc_frame, (0, att_height - y_shift)) # Shift accuracy up by overlap
401
+
402
+ stacked_frames.append(stacked_frame)
403
+
404
+ stacked_frames[0].save(
405
+ stacked_gif_path,
406
+ save_all=True,
407
+ append_images=stacked_frames[1:],
408
+ duration=300,
409
+ loop=0
410
+ )
411
+
412
+ save_frames_to_mp4(
413
+ [np.array(fm)[:, :, ::-1] for fm in stacked_frames],
414
+ os.path.splitext(stacked_gif_path)[0] + ".mp4",
415
+ fps=15,
416
+ gop_size=1,
417
+ preset="slow"
418
+ )
419
+
420
+
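Review note: the canvas arithmetic in `create_stacked_gif` is easier to check in isolation. The sketch below reproduces only the geometry (canvas size and paste offsets) without touching Pillow. The helper name `stacked_layout` and the range check on `y_shift` are assumptions added for illustration and are not part of the committed code.

```python
def stacked_layout(acc_size, att_size, y_shift=200):
    """Return (canvas_size, att_offset, acc_offset) for two (width, height) frames.

    Hypothetical helper: mirrors the arithmetic in create_stacked_gif only.
    """
    acc_width, acc_height = acc_size
    att_width, att_height = att_size

    # Assumption: an overlap larger than the attention frame makes no sense,
    # so reject it here. The committed code performs no such check.
    if not 0 <= y_shift <= att_height:
        raise ValueError("y_shift must lie between 0 and the attention frame height")

    # The canvas is as wide as the wider frame and as tall as both frames
    # stacked, minus the overlap.
    canvas = (max(acc_width, att_width), acc_height + att_height - y_shift)

    # Attention frame goes at the top; accuracy frame starts y_shift pixels
    # before the attention frame ends.
    att_offset = (0, 0)
    acc_offset = (0, att_height - y_shift)
    return canvas, att_offset, acc_offset


canvas, att_offset, acc_offset = stacked_layout((3000, 1500), (3000, 1500), y_shift=200)
print(canvas, att_offset, acc_offset)
```

Because the accuracy frame is pasted second, its pixels overwrite the bottom `y_shift` rows of the attention frame. The overlap is therefore only safe when that region is empty margin.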
421
+ def plot_accuracy_training(all_accuracies, scale, run_model_spefic_save_dir, args):
422
+ scale = 0.5  # NOTE: overrides the caller-supplied scale, so the figure is always 5 x 2.5 inches
423
+ seq_indices = range(args.parity_sequence_length)
424
+ fig, ax = plt.subplots(figsize=(scale*10, scale*5))
425
+ cmap = plt.get_cmap("viridis")
426
+
427
+ for i, acc in enumerate(all_accuracies):
428
+ color = cmap(i / max(len(all_accuracies) - 1, 1))  # guard against a single checkpoint (division by zero)
429
+ ax.plot(seq_indices, acc*100, color=color, alpha=0.7, linewidth=1)
430
+
431
+ num_checkpoints = 5
432
+ checkpoint_percentages = np.linspace(0, 100, num_checkpoints)
433
+
434
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=100))
435
+ sm.set_array([])
436
+ cbar = fig.colorbar(sm, ax=ax)
437
+ cbar.set_label("Training Progress (%)")
438
+ cbar.set_ticks(checkpoint_percentages)
439
+ cbar.set_ticklabels([f"{int(p)}%" for p in checkpoint_percentages])
440
+
441
+ ax.set_xlabel("Sequence Index")
442
+ ax.set_ylabel("Accuracy (%)")
443
+ ax.set_xticks([0, 16, 32, 48, 63])  # tick positions are hard-coded for a 64-element sequence
444
+ ax.grid(True, alpha=0.5)
445
+ ax.set_xlim(0, args.parity_sequence_length - 1)
446
+
447
+ fig.tight_layout(pad=0.1)
448
+ fig.savefig(f"{run_model_spefic_save_dir}/accuracy_vs_seq_element.png", dpi=300, bbox_inches="tight")
449
+ fig.savefig(f"{run_model_spefic_save_dir}/accuracy_vs_seq_element.pdf", format='pdf', bbox_inches="tight")
450
+ plt.close(fig)
451
+
452
+
453
+ def plot_loss_all_runs(training_data, evaluate_every, save_path="train_loss_comparison_parity.png", step=1, scale=1.0, x_max=None):
454
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
455
+
456
+ grouped = defaultdict(list)
457
+ label_map = {}
458
+ linestyle_map = {}
459
+ iters_map = {}
460
+ model_map = {}
461
+
462
+ for folder, data in training_data.items():
463
+ label, model_type, iters = parse_folder_name(folder)
464
+ if iters is None:
465
+ continue
466
+
467
+ key = f"{model_type}_{iters}"
468
+ grouped[key].append(data["train_losses"])
469
+ label_map[key] = f"{model_type}, {iters} Iters."
470
+ linestyle_map[key] = "--" if model_type == "LSTM" else "-"
471
+ iters_map[key] = iters
472
+ model_map[key] = model_type
473
+
474
+ unique_iters = sorted(set(iters_map.values()))
475
+ base_colors = sns.color_palette("hls", n_colors=len(unique_iters))
476
+ color_lookup = {iters: base_colors[i] for i, iters in enumerate(unique_iters)}
477
+
478
+ legend_entries = []
479
+ global_max_x = 0
480
+ for key in sorted(grouped.keys(), key=lambda k: (iters_map[k], model_map[k])):
481
+ runs = grouped[key]
482
+ if not runs:
483
+ continue
484
+
485
+ iters = iters_map[key]
486
+ color = color_lookup[iters]
487
+ linestyle = linestyle_map[key]
488
+
489
+ min_len = min(len(r) for r in runs)
490
+ trimmed = np.array([r[:min_len] for r in runs])[:, ::step]
491
+
492
+ mean = np.mean(trimmed, axis=0)
493
+ std = np.std(trimmed, axis=0)
494
+ x = np.arange(len(mean)) * step * evaluate_every
495
+ group_max_x = len(mean) * step * evaluate_every
496
+ global_max_x = max(global_max_x, group_max_x)
497
+
498
+ line, = ax.plot(x, mean, color=color, linestyle=linestyle, label=label_map[key])
499
+ ax.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
500
+
501
+ legend_entries.append((line, label_map[key]))
502
+
503
+ ax.set_xlabel("Training Iterations")
504
+ ax.set_ylabel("Loss")
505
+ ax.grid(True, alpha=0.5)
506
+
507
+ style_legend = [
508
+ Line2D([0], [0], color='black', linestyle='-', label='CTM'),
509
+ Line2D([0], [0], color='black', linestyle='--', label='LSTM')
510
+ ]
511
+ color_legend = [
512
+ Line2D([0], [0], color=color_lookup[it], linestyle='-', label=f"{it} Iters.")
513
+ for it in unique_iters
514
+ ]
515
+
516
+ if not x_max:
517
+ x_max = global_max_x
518
+
519
+ ax.set_xlim([0, x_max])
520
+ ax.set_ylim(bottom=0)
521
+ ax.set_xticks(np.arange(0, x_max + 1, 50000))
522
+ ax.legend(handles=color_legend + style_legend, loc="upper left")
523
+ fig.tight_layout(pad=0.1)
524
+ fig.savefig(save_path, dpi=300)
525
+ fig.savefig(os.path.splitext(save_path)[0] + ".pdf", format='pdf')
526
+ plt.close(fig)
527
+
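Review note: `plot_loss_all_runs` aggregates runs of unequal length by trimming every run to the shortest one, subsampling by `step`, and then taking a per-step mean and standard deviation. A minimal standard-library sketch of that aggregation follows. The function name `aggregate_runs` and the sample numbers are made up for illustration; `pstdev` is used because `np.std` defaults to the population standard deviation.

```python
from statistics import fmean, pstdev


def aggregate_runs(runs, evaluate_every, step=1):
    """Trim runs to a common length, subsample, and return (x, mean, std).

    Hypothetical helper that mirrors the NumPy logic in plot_loss_all_runs.
    """
    min_len = min(len(r) for r in runs)

    # r[:min_len:step] is equivalent to trimming first and then taking every
    # step-th element, as the NumPy slicing does.
    trimmed = [r[:min_len:step] for r in runs]

    # Transpose so that each column holds one time step across all runs.
    columns = list(zip(*trimmed))
    mean = [fmean(col) for col in columns]
    std = [pstdev(col) for col in columns]  # population std, like np.std

    # Convert logged indices back to training iterations.
    x = [i * step * evaluate_every for i in range(len(mean))]
    return x, mean, std


# Illustrative values only: two runs, one of which stopped early.
runs = [[1.0, 0.8, 0.6, 0.5], [1.2, 0.6, 0.4]]
x, mean, std = aggregate_runs(runs, evaluate_every=1000)
print(x, mean, std)
```

One consequence of trimming is that a single short run truncates the curve for its whole group, so the plotted x range reflects the shortest run rather than the longest.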
528
+ def plot_accuracy_all_runs(training_data, evaluate_every, save_path="test_accuracy_comparison_parity.png", step=1, scale=1.0, smooth=False, x_max=None):
529
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
530
+
531
+ grouped = defaultdict(list)
532
+ label_map = {}
533
+ linestyle_map = {}
534
+ iters_map = {}
535
+ model_map = {}
536
+
537
+ for folder, data in training_data.items():
538
+ label, model_type, iters = parse_folder_name(folder)
539
+ if iters is None:
540
+ continue
541
+
542
+ key = f"{model_type}_{iters}"
543
+ grouped[key].append(data["test_accuracies"])
544
+ label_map[key] = f"{model_type}, {iters} Iters."
545
+ linestyle_map[key] = "--" if model_type == "LSTM" else "-"
546
+ iters_map[key] = iters
547
+ model_map[key] = model_type
548
+
549
+ unique_iters = sorted(set(iters_map.values()))
550
+ base_colors = sns.color_palette("hls", n_colors=len(unique_iters))
551
+ color_lookup = {iters: base_colors[i] for i, iters in enumerate(unique_iters)}
552
+
553
+ legend_entries = []
554
+ global_max_x = 0
555
+
556
+ for key in sorted(grouped.keys(), key=lambda k: (iters_map[k], model_map[k])):
557
+ runs = grouped[key]
558
+ if not runs:
559
+ continue
560
+
561
+ iters = iters_map[key]
562
+ model = model_map[key]
563
+ color = color_lookup[iters]
564
+ linestyle = linestyle_map[key]
565
+
566
+ min_len = min(len(r) for r in runs)
567
+ trimmed = np.array([r[:min_len] for r in runs])[:, ::step]
568
+
569
+ mean = np.mean(trimmed, axis=0) * 100
570
+ std = np.std(trimmed, axis=0) * 100
571
+
572
+ if smooth:
573
+ window_size = max(1, int(0.05 * len(mean)))
574
+ if window_size % 2 == 0:
575
+ window_size += 1
576
+ kernel = np.ones(window_size) / window_size
577
+
578
+ smoothed_mean = np.convolve(mean, kernel, mode='same')
579
+ smoothed_std = np.convolve(std, kernel, mode='same')
580
+
581
+ valid_start = window_size // 2
582
+ valid_end = len(mean) - window_size // 2
583
+ valid_length = valid_end - valid_start
584
+
585
+ mean = smoothed_mean[valid_start:valid_end]
586
+ std = smoothed_std[valid_start:valid_end]
587
+ x = np.arange(valid_length) * step * evaluate_every
588
+ group_max_x = valid_length * step * evaluate_every
589
+ else:
590
+ x = np.arange(len(mean)) * step * evaluate_every
591
+ group_max_x = len(mean) * step * evaluate_every
592
+
593
+ global_max_x = max(global_max_x, group_max_x)
594
+
595
+ line, = ax.plot(x, mean, color=color, linestyle=linestyle, label=label_map[key])
596
+ ax.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
597
+ legend_entries.append((line, label_map[key]))
598
+
599
+ if smooth or x_max is None:
600
+ x_max = global_max_x
601
+
602
+ ax.set_xlim([0, x_max])
603
+ ax.set_ylim(top=100)
604
+ ax.set_xticks(np.arange(0, x_max + 1, 50000))
605
+ ax.set_xlabel("Training Iterations")
606
+ ax.set_ylabel("Accuracy (%)")
607
+ ax.grid(True, alpha=0.5)
608
+
609
+ style_legend = [
610
+ Line2D([0], [0], color='black', linestyle='-', label='CTM'),
611
+ Line2D([0], [0], color='black', linestyle='--', label='LSTM')
612
+ ]
613
+ color_legend = [
614
+ Line2D([0], [0], color=color_lookup[it], linestyle='-', label=f"{it} Iters.")
615
+ for it in unique_iters
616
+ ]
617
+ ax.legend(handles=color_legend + style_legend, loc="upper left")
618
+
619
+ fig.tight_layout(pad=0.1)
620
+ fig.savefig(save_path, dpi=300)
621
+ fig.savefig(os.path.splitext(save_path)[0] + ".pdf", format='pdf')
622
+ plt.close(fig)
623
+
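Review note: the `smooth=True` branch of `plot_accuracy_all_runs` uses a moving average whose window is 5% of the curve length, forced to be odd, and then keeps only the samples where the full window fits. For an odd box kernel, slicing a `mode='same'` convolution that way gives the same values as a centred moving average over the valid region. The pure-Python sketch below makes this explicit; the name `smooth_valid` is an assumption for illustration.

```python
def smooth_valid(values, fraction=0.05):
    """Centred moving average restricted to positions where the window fits.

    Hypothetical helper: equivalent to np.convolve(values, box, mode='same')
    followed by dropping window // 2 samples from each end.
    """
    n = len(values)

    # Window is a fraction of the series length, at least 1 and always odd.
    window = max(1, int(fraction * n))
    if window % 2 == 0:
        window += 1
    half = window // 2

    # Only indices with `half` neighbours on both sides are kept, so the
    # result has n - window + 1 elements.
    return [
        sum(values[i - half:i + half + 1]) / window
        for i in range(half, n - half)
    ]


series = [float(v) for v in range(40)]  # 5% of 40 is 2, bumped to a window of 3
smoothed = smooth_valid(series)
print(len(smoothed), smoothed[0], smoothed[-1])
```

Dropping the edges avoids the dip that zero padding would otherwise produce at the start and end of the curve, at the cost of a slightly shorter x range.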
624
+ def extract_run_name(folder, run_index=None):
625
+ # Try to extract from parent folder
626
+ parent = os.path.basename(os.path.dirname(folder))
627
+ match = re.search(r'run(\d+)', parent, re.IGNORECASE)
628
+ if match:
629
+ return f"Run {int(match.group(1))}"
630
+ # Try current folder name
631
+ basename = os.path.basename(folder)
632
+ match = re.search(r'run(\d+)', basename, re.IGNORECASE)
633
+ if match:
634
+ return f"Run {int(match.group(1))}"
635
+ # Fallback: use run index
636
+ if run_index is not None:
637
+ return f"Run {run_index + 1}"
638
+ raise ValueError(f"Could not extract run number from: {folder}")
639
+
640
+ def plot_loss_individual_runs(training_data, evaluate_every, save_dir, scale=1.0, x_max=None):
641
+
642
+ grouped = defaultdict(list)
643
+ label_map = {}
644
+ iters_map = {}
645
+ model_map = {}
646
+
647
+ base_colors = sns.color_palette("hls", n_colors=3)
648
+ color_lookup = {f"Run {i+1}": base_colors[i] for i in range(3)}
649
+
650
+ for i, (folder, data) in enumerate(training_data.items()):
651
+ checkpoint = load_checkpoint(get_latest_checkpoint_file(folder), device="cpu")
652
+ model_args = get_model_args_from_checkpoint(checkpoint)
653
+ label, model_type, iters = parse_folder_name(folder)
654
+ if iters is None:
655
+ continue
656
+
657
+ if model_type.lower() == "ctm":
658
+ memory_length = getattr(model_args, "memory_length", None)
659
+ if memory_length is None:
660
+ raise ValueError(f"CTM model missing memory_length in checkpoint args from: {folder}")
661
+ key = f"{model_type}_{iters}_{memory_length}".lower()
662
+ else:
663
+ key = f"{model_type}_{iters}".lower()
664
+
665
+ run_name = extract_run_name(folder, run_index=i)
666
+ grouped[key].append((run_name, data["train_losses"]))
667
+ label_map[key] = f"{model_type}, {iters} Iters."
668
+ iters_map[key] = iters
669
+ model_map[key] = model_type
670
+
671
+ for key, runs in grouped.items():
672
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
673
+ for run_name, losses in runs:
674
+ x = np.arange(len(losses)) * evaluate_every
675
+ color = color_lookup.get(run_name, 'gray')
676
+ ax.plot(x, losses, label=run_name, color=color, alpha=0.7)
677
+
678
+ ax.set_xlabel("Training Iterations")
679
+ ax.set_ylabel("Loss")
680
+ ax.set_ylim(bottom=-0.01)
681
+ ax.grid(True, alpha=0.5)
682
+ if x_max:
683
+ ax.set_xlim([0, x_max])
684
+ ax.set_xticks(np.arange(0, x_max + 1, 50000))
685
+ ax.legend()
686
+ fig.tight_layout(pad=0.1)
687
+
688
+ subdir = os.path.join(save_dir, key)
689
+ os.makedirs(subdir, exist_ok=True)
690
+ fname = os.path.join(subdir, f"individual_runs_loss_{key}.png")
691
+ fig.savefig(fname, dpi=300)
692
+ fig.savefig(os.path.splitext(fname)[0] + ".pdf", format="pdf")
693
+ plt.close(fig)
694
+
695
+ def plot_accuracy_individual_runs(training_data, evaluate_every, save_dir, scale=1.0, smooth=False, x_max=None):
696
+
697
+ grouped = defaultdict(list)
698
+ label_map = {}
699
+ iters_map = {}
700
+ model_map = {}
701
+
702
+ base_colors = sns.color_palette("hls", n_colors=3)
703
+ color_lookup = {f"Run {i+1}": base_colors[i] for i in range(3)}
704
+
705
+ for i, (folder, data) in enumerate(training_data.items()):
706
+ checkpoint = load_checkpoint(get_latest_checkpoint_file(folder), device="cpu")
707
+ model_args = get_model_args_from_checkpoint(checkpoint)
708
+ label, model_type, iters = parse_folder_name(folder)
709
+ if iters is None:
710
+ continue
711
+
712
+ if model_type.lower() == "ctm":
713
+ memory_length = getattr(model_args, "memory_length", None)
714
+ if memory_length is None:
715
+ raise ValueError(f"CTM model missing memory_length in checkpoint args from: {folder}")
716
+ key = f"{model_type}_{iters}_{memory_length}".lower()
717
+ else:
718
+ key = f"{model_type}_{iters}".lower()
719
+
720
+ run_name = extract_run_name(folder, run_index=i)
721
+ grouped[key].append((run_name, data["test_accuracies"]))
722
+ label_map[key] = f"{model_type}, {iters} Iters."
723
+ iters_map[key] = iters
724
+ model_map[key] = model_type
725
+
726
+ for key, runs in grouped.items():
727
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
728
+ for run_name, acc in runs:
729
+ acc = np.array(acc) * 100
730
+ if smooth:
731
+ window_size = max(1, int(0.05 * len(acc)))
732
+ if window_size % 2 == 0:
733
+ window_size += 1
734
+ kernel = np.ones(window_size) / window_size
735
+ acc = np.convolve(acc, kernel, mode="same")
736
+
737
+ x = np.arange(len(acc)) * evaluate_every
738
+ color = color_lookup.get(run_name, 'gray')
739
+ ax.plot(x, acc, label=run_name, color=color, alpha=0.7)
740
+
741
+ ax.set_xlabel("Training Iterations")
742
+ ax.set_ylabel("Accuracy (%)")
743
+ ax.set_ylim([50, 101])
744
+ ax.grid(True, alpha=0.5)
745
+ if x_max:
746
+ ax.set_xlim([0, x_max])
747
+ ax.set_xticks(np.arange(0, x_max + 1, 50000))
748
+ ax.legend()
749
+ fig.tight_layout(pad=0.1)
750
+
751
+ subdir = os.path.join(save_dir, key)
752
+ os.makedirs(subdir, exist_ok=True)
753
+ fname = os.path.join(subdir, f"individual_runs_accuracy_{key}.png")
754
+ fig.savefig(fname, dpi=300)
755
+ fig.savefig(os.path.splitext(fname)[0] + ".pdf", format="pdf")
756
+ plt.close(fig)
757
+
758
+ def plot_training_curve_all_runs(all_folders, save_dir, scale, device, smooth=False, x_max=None, plot_individual_runs=True):
759
+
760
+ all_folders = [folder for folder in all_folders if "certain" not in folder]
761
+
762
+ training_data = {}
763
+ evaluation_intervals = []
764
+ for folder in all_folders:
765
+ latest_checkpoint_path = get_latest_checkpoint_file(folder)
766
+ if latest_checkpoint_path:
767
+ checkpoint = load_checkpoint(latest_checkpoint_path, device=device)
768
+ model_args = get_model_args_from_checkpoint(checkpoint)
769
+ evaluation_intervals.append(model_args.track_every)
770
+
771
+ _, train_losses, test_losses, train_accuracies, test_accuracies = get_accuracy_and_loss_from_checkpoint(checkpoint, device=device)
772
+ training_data[folder] = {
773
+ "train_losses": train_losses,
774
+ "test_losses": test_losses,
775
+ "train_accuracies": train_accuracies,
776
+ "test_accuracies": test_accuracies
777
+ }
778
+ else:
779
+ print(f"No checkpoint found for {folder}")
780
+
781
+ assert len(evaluation_intervals) > 0, "No valid checkpoints found."
782
+ assert all(interval == evaluation_intervals[0] for interval in evaluation_intervals), "Evaluation intervals are not consistent across runs."
783
+
784
+ evaluate_every = evaluation_intervals[0]
785
+
786
+ if plot_individual_runs:
787
+ plot_loss_individual_runs(training_data, evaluate_every, save_dir=save_dir, scale=scale, x_max=x_max)
788
+ plot_accuracy_individual_runs(training_data, evaluate_every, save_dir=save_dir, scale=scale, smooth=smooth, x_max=x_max)
789
+
790
+ plot_loss_all_runs(training_data, evaluate_every, save_path=f"{save_dir}/loss_comparison.png", scale=scale, x_max=x_max)
791
+ plot_accuracy_all_runs(training_data, evaluate_every, save_path=f"{save_dir}/accuracy_comparison.png", scale=scale, smooth=smooth, x_max=x_max)
792
+
793
+ return training_data
794
+
795
+ def plot_accuracy_thinking_time(csv_path, scale, output_dir="analysis/cifar"):
796
+ if not os.path.exists(csv_path):
797
+ raise FileNotFoundError(f"CSV file not found: {csv_path}")
798
+
799
+ df = pd.read_csv(csv_path)
800
+ df["RunName"] = df["Run"].apply(lambda x: os.path.basename(os.path.dirname(x)))
801
+ df["Model"] = df["Run"].apply(lambda x: "CTM" if "ctm" in x.lower() else "LSTM")
802
+
803
+ grouped = df.groupby(["Model", "Num Iterations"])
804
+ summary = grouped.agg(
805
+ mean_accuracy=("Overall Mean Accuracy", "mean"),
806
+ std_accuracy=("Overall Std Accuracy", lambda x: np.sqrt(np.mean(x**2)))
807
+ ).reset_index()
808
+
809
+ summary["mean_accuracy"] *= 100
810
+ summary["std_accuracy"] *= 100
811
+
812
+ fig, ax = plt.subplots(figsize=(scale*5, scale*5))
813
+
814
+ for model in ("CTM", "LSTM"):
815
+ subset = summary[summary["Model"] == model].sort_values(by="Num Iterations")
816
+ linestyle = "-" if model == "CTM" else "--"
817
+ ax.errorbar(
818
+ subset["Num Iterations"],
819
+ subset["mean_accuracy"],
820
+ yerr=subset["std_accuracy"],
821
+ linestyle=linestyle,
822
+ color="black",
823
+ marker='.',
824
+ label=model,
825
+ capsize=3,
826
+ elinewidth=1,
827
+ errorevery=1
828
+ )
829
+
830
+ ax.set_xlabel("Internal Ticks")
831
+ ax.set_ylabel("Accuracy (%)")
832
+ custom_lines = [
833
+ Line2D([0], [0], color='black', linestyle='-', label='CTM'),
834
+ Line2D([0], [0], color='black', linestyle='--', label='LSTM')
835
+ ]
836
+ ax.legend(handles=custom_lines, loc="lower right")
837
+ ax.grid(True, alpha=0.5)
838
+
839
+ os.makedirs(output_dir, exist_ok=True)
840
+ output_path_png = os.path.join(output_dir, "accuracy_vs_thinking_time.png")
841
+ fig.tight_layout(pad=0.1)
842
+ fig.savefig(output_path_png, dpi=300)
843
+ fig.savefig(os.path.splitext(output_path_png)[0] + ".pdf", format='pdf')
844
+ plt.close(fig)
845
+
846
+
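Review note: in `plot_accuracy_thinking_time` the per-run standard deviations are combined with a root mean square, `sqrt(mean(std**2))`, rather than a plain average. That is the square root of the mean variance, with every run weighted equally. The sketch below shows the arithmetic on made-up numbers; `pooled_std` is a hypothetical name.

```python
import math


def pooled_std(stds):
    """Root-mean-square of per-run standard deviations (equal weights)."""
    return math.sqrt(sum(s * s for s in stds) / len(stds))


# Illustrative values only: two runs with std 3% and 4%.
pooled = pooled_std([0.03, 0.04])
plain_mean = (0.03 + 0.04) / 2
print(round(pooled, 6), round(plain_mean, 6))
```

The pooled value is never smaller than the plain mean of the standard deviations. It only reflects the spread inside each run, so differences between the runs' mean accuracies do not enter the error bars.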
847
+ def plot_lstm_last_and_certain_accuracy(all_folders, save_path="lstm_last_and_certain_accuracy.png", scale=1.0, step=1, x_max=None):
848
+
849
+ tags = ["lstm_10", "lstm_10_certain", "lstm_25", "lstm_25_certain"]
850
+ folders = [f for f in all_folders if any(tag in f.lower() for tag in tags)]
851
+
852
+ training_data, eval_intervals = {}, []
853
+ for f in folders:
854
+ cp = get_latest_checkpoint_file(f)
855
+ if not cp:
856
+ print(f"⚠️ No checkpoint in {f}")
857
+ continue
858
+ ckpt = load_checkpoint(cp, device="cpu")
859
+ args = get_model_args_from_checkpoint(ckpt)
860
+ eval_intervals.append(args.track_every)
861
+ _, _, _, _, acc = get_accuracy_and_loss_from_checkpoint(ckpt)
862
+ iters = "25" if "25" in f.lower() else "10"
863
+ label = "Certain" if "certain" in f.lower() else "Final"
864
+ training_data.setdefault((iters, label), []).append(acc)
865
+
866
+ assert training_data and all(i == eval_intervals[0] for i in eval_intervals), "Missing or inconsistent eval intervals."
867
+ evaluate_every = eval_intervals[0]
868
+
869
+ keys = sorted(training_data.keys())
870
+ colors = sns.color_palette("hls", n_colors=len(keys))
871
+ style_map = {key: ("--" if key[1] == "Certain" else "-") for key in keys}
872
+ color_map = {key: colors[i] for i, key in enumerate(keys)}
873
+
874
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
875
+ max_x = 0
876
+
877
+ for key in keys:
878
+ runs = training_data[key]
879
+ min_len = min(len(r) for r in runs)
880
+ trimmed = np.stack([r[:min_len] for r in runs], axis=0)[:, ::step]
881
+ mean, std = np.mean(trimmed, 0) * 100, np.std(trimmed, 0) * 100
882
+ x = np.arange(len(mean)) * step * evaluate_every
883
+ ax.plot(x, mean, color=color_map[key], linestyle=style_map[key],
884
+ label=f"{key[0]} Iters, {key[1]}", linewidth=2, alpha=0.7)
885
+ ax.fill_between(x, mean - std, mean + std, color=color_map[key], alpha=0.1)
886
+ max_x = max(max_x, x[-1])
887
+
888
+ ax.set_xlim([0, x_max or max_x])
889
+ ax.set_xticks(np.arange(0, (x_max or max_x) + 1, 50000))
890
+ ax.set_xlabel("Training Iterations")
891
+ ax.set_ylabel("Accuracy (%)")
892
+ ax.grid(True, alpha=0.5)
893
+ ax.legend(loc="lower right")
894
+ fig.tight_layout(pad=0.1)
895
+ fig.savefig(save_path, dpi=300)
896
+ fig.savefig(os.path.splitext(save_path)[0] + ".pdf", format="pdf")
897
+ plt.close(fig)