Spaces:
Paused
Paused
Commit ·
ed89628
1
Parent(s): 9bb3382
Deploy CTM Codebase bypass FUSE 503
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +5 -0
- Dockerfile +34 -0
- INSTRUCCIONES_DESPLIEGUE.md +34 -0
- LICENSE +201 -0
- README.md +22 -7
- app.py +945 -0
- app_v1_backup.py +464 -0
- data/custom_datasets.py +324 -0
- examples/01_mnist.ipynb +0 -0
- examples/02_inference.ipynb +0 -0
- examples/03_mazes.ipynb +0 -0
- examples/04_parity.ipynb +0 -0
- examples/05_huggingface.ipynb +0 -0
- models/README.md +7 -0
- models/constants.py +10 -0
- models/ctm.py +633 -0
- models/ctm_qamnist.py +208 -0
- models/ctm_rl.py +192 -0
- models/ctm_sort.py +126 -0
- models/ff.py +75 -0
- models/lstm.py +244 -0
- models/lstm_qamnist.py +184 -0
- models/lstm_rl.py +96 -0
- models/modules.py +692 -0
- models/resnet.py +374 -0
- models/utils.py +122 -0
- mount_azure.sh +44 -0
- requirements.txt +21 -0
- requirements_v1.txt +2 -0
- setup_hf_space.sh +37 -0
- tasks/image_classification/README.md +31 -0
- tasks/image_classification/analysis/README.md +7 -0
- tasks/image_classification/analysis/run_imagenet_analysis.py +972 -0
- tasks/image_classification/imagenet_classes.py +1007 -0
- tasks/image_classification/plotting.py +494 -0
- tasks/image_classification/scripts/train_cifar10.sh +286 -0
- tasks/image_classification/scripts/train_imagenet.sh +38 -0
- tasks/image_classification/train.py +690 -0
- tasks/image_classification/train_distributed.py +799 -0
- tasks/mazes/README.md +16 -0
- tasks/mazes/analysis/README.md +10 -0
- tasks/mazes/analysis/run.py +407 -0
- tasks/mazes/plotting.py +214 -0
- tasks/mazes/scripts/train_ctm.sh +35 -0
- tasks/mazes/train.py +704 -0
- tasks/mazes/train_distributed.py +782 -0
- tasks/parity/README.md +16 -0
- tasks/parity/analysis/make_blog_gifs.py +263 -0
- tasks/parity/analysis/run.py +269 -0
- tasks/parity/plotting.py +897 -0
.dockerignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
checkpoints/
|
| 2 |
+
data/
|
| 3 |
+
.git/
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.ipynb
|
Dockerfile
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /code
|
| 4 |
+
|
| 5 |
+
# 1. Install System Dependencies (SSHFS + Curl)
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
sshfs \
|
| 8 |
+
curl \
|
| 9 |
+
fuse \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
# 2. Install Cloudflared
|
| 13 |
+
RUN curl -L --output cloudflared.deb https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb && \
|
| 14 |
+
dpkg -i cloudflared.deb && \
|
| 15 |
+
rm cloudflared.deb
|
| 16 |
+
|
| 17 |
+
COPY ./requirements.txt /code/requirements.txt
|
| 18 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
| 19 |
+
|
| 20 |
+
# Create mount point
|
| 21 |
+
RUN mkdir -p /data/persistent && chmod 777 /data/persistent
|
| 22 |
+
|
| 23 |
+
RUN useradd -m -u 1000 user
|
| 24 |
+
USER user
|
| 25 |
+
ENV HOME=/home/user \
|
| 26 |
+
PATH=/home/user/.local/bin:$PATH
|
| 27 |
+
|
| 28 |
+
WORKDIR $HOME/app
|
| 29 |
+
|
| 30 |
+
COPY --chown=user . $HOME/app
|
| 31 |
+
RUN chmod +x mount_azure.sh
|
| 32 |
+
|
| 33 |
+
# El puerto 7860 es obligatorio en Hugging Face Spaces
|
| 34 |
+
CMD ["python", "app.py"]
|
INSTRUCCIONES_DESPLIEGUE.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 Instrucciones de Despliegue: Continuous Thought Machines (CTM)
|
| 2 |
+
|
| 3 |
+
Debido a restricciones de permisos en la terminal actual, el despliegue final requiere que ejecutes el "puente" que he construido desde tu entorno WSL (Ubuntu/Debian).
|
| 4 |
+
|
| 5 |
+
## 1. Requisitos Previos (Ya configurados)
|
| 6 |
+
* **Código**: Clonado en `c:\Users\elliot\Downloads\simulacion\ctm_sakana`
|
| 7 |
+
* **Docker**: Archivos `Dockerfile` y `.dockerignore` creados.
|
| 8 |
+
* **Dependencias**: `requirements.txt` parcheado (opencv-headless).
|
| 9 |
+
* **Script de Conexión**: `setup_hf_space.sh` creado con tu Token.
|
| 10 |
+
|
| 11 |
+
## 2. Pasos de Ejecución (En tu WSL)
|
| 12 |
+
|
| 13 |
+
Abre tu terminal WSL y ejecuta el siguiente bloque de comandos:
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
# 1. Navegar a la carpeta del proyecto (WSL monta C: en /mnt/c)
|
| 17 |
+
cd /mnt/c/Users/elliot/Downloads/simulacion/ctm_sakana
|
| 18 |
+
|
| 19 |
+
# 2. Dar permisos de ejecución al script
|
| 20 |
+
chmod +x setup_hf_space.sh
|
| 21 |
+
|
| 22 |
+
# 3. Ejecutar el script automatizado
|
| 23 |
+
./setup_hf_space.sh
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### 3. Durante la Ejecución
|
| 27 |
+
El script te pedirá el **Nombre del Space**.
|
| 28 |
+
* Basado en la captura, el usuario parece ser `Alex Herbert Vilca Puente` o `ROBOT-GANSTA`.
|
| 29 |
+
* Ingresa el nombre del Space en formato `USUARIO/NOMBRE_SPACE` si te lo pide el script, o solo el nombre si el script ya tiene el usuario hardcodeado (El script tiene `Elliotasdasdasfasas`, asegúrate de que coincida con tu Space real o edita el script).
|
| 30 |
+
|
| 31 |
+
## 4. Verificación
|
| 32 |
+
Una vez subido, ve a tu Space en Hugging Face. Verás que empieza a decir **"Building"**. Esto significa que Docker está instalando las dependencias que definimos.
|
| 33 |
+
|
| 34 |
+
> **Nota**: El proceso de build tardará unos minutos la primera vez.
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2020 Rémi Louf
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,11 +1,26 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk:
|
|
|
|
|
|
|
| 7 |
pinned: false
|
| 8 |
-
license: gemma
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: CTM Nervous System
|
| 3 |
+
emoji: 🧬
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.9.1
|
| 8 |
+
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# 🧬 CTM Nervous System
|
| 13 |
+
|
| 14 |
+
**Continuous Thought Machine for Hypergraph Maintenance**
|
| 15 |
+
|
| 16 |
+
Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
|
| 17 |
+
|
| 18 |
+
## Endpoints
|
| 19 |
+
|
| 20 |
+
| Endpoint | Function |
|
| 21 |
+
|----------|----------|
|
| 22 |
+
| `/sense_snn` | Process 72D SNN input |
|
| 23 |
+
| `/reason_hypergraph` | Reason about context, propose edges |
|
| 24 |
+
| `/validate_physics` | Validate against 5 physics losses |
|
| 25 |
+
| `/dream` | Offline consolidation (T=500+) |
|
| 26 |
+
| `/calibrate_stdp` | Suggest STDP weight adjustments |
|
app.py
ADDED
|
@@ -0,0 +1,945 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CTM Nervous System Server v2.0 - Full PyTorch Implementation
|
| 3 |
+
=============================================================
|
| 4 |
+
Continuous Thought Machine for ART-17 Hypergraph Coherence Generation
|
| 5 |
+
|
| 6 |
+
PURPOSE (from skills):
|
| 7 |
+
1. REGULACIÓN: Calibrar pesos STDP de las 16 dendritas
|
| 8 |
+
2. COHERENCIA: Generar hipergrafos deterministas
|
| 9 |
+
3. RAZONAMIENTO: Motor de inferencia activa (internal ticks)
|
| 10 |
+
4. SINCRONIZACIÓN: Representación via Neural Synchronization
|
| 11 |
+
|
| 12 |
+
TRAINING STRATEGY:
|
| 13 |
+
- Progressive online learning with use
|
| 14 |
+
- Integrates with Brain server (Qwen + VL-JEPA) for semantic grounding
|
| 15 |
+
- Automatic checkpoint saving
|
| 16 |
+
|
| 17 |
+
Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
|
| 18 |
+
Adapted for: ART-17 Dendrite Regulation & Hypergraph Generation
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import gradio as gr
|
| 22 |
+
import numpy as np
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
from typing import List, Dict, Any, Optional
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
from utils.bunker_client import BunkerClient
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ============================================================================
|
| 31 |
+
# PYTORCH IMPORTS WITH FALLBACK
|
| 32 |
+
# ============================================================================
|
| 33 |
+
try:
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn as nn
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
TORCH_AVAILABLE = True
|
| 38 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
print(f"🔧 PyTorch available. Device: {DEVICE}")
|
| 40 |
+
except ImportError:
|
| 41 |
+
TORCH_AVAILABLE = False
|
| 42 |
+
DEVICE = "cpu"
|
| 43 |
+
print("⚠️ PyTorch not available. Using simplified NumPy fallback.")
|
| 44 |
+
|
| 45 |
+
# ============================================================================
|
| 46 |
+
# FULL CTM IMPORT (with fallback to simplified)
|
| 47 |
+
# ============================================================================
|
| 48 |
+
if TORCH_AVAILABLE:
|
| 49 |
+
try:
|
| 50 |
+
from models.ctm import ContinuousThoughtMachine
|
| 51 |
+
from models.modules import SynapseUNET, SuperLinear
|
| 52 |
+
from utils.losses import image_classification_loss
|
| 53 |
+
CTM_FULL = True
|
| 54 |
+
print("✅ Full CTM model loaded from models/ctm.py")
|
| 55 |
+
except ImportError as e:
|
| 56 |
+
CTM_FULL = False
|
| 57 |
+
print(f"⚠️ Could not import full CTM: {e}. Using simplified.")
|
| 58 |
+
else:
|
| 59 |
+
CTM_FULL = False
|
| 60 |
+
|
| 61 |
+
# ============================================================================
|
| 62 |
+
# CONFIGURATION FOR ART-17 INTEGRATION (v3.0)
|
| 63 |
+
# ============================================================================
|
| 64 |
+
CONFIG = {
|
| 65 |
+
# CTM Architecture (matching ART-17)
|
| 66 |
+
"iterations": 50, # T internal ticks (max)
|
| 67 |
+
"d_model": 256, # Latent dimension
|
| 68 |
+
"d_input": 72, # Input from SNN (72D)
|
| 69 |
+
"memory_length": 16, # History length (16 dendrites)
|
| 70 |
+
"n_synch_out": 32, # Output sync neurons
|
| 71 |
+
"n_synch_action": 16, # Action sync neurons
|
| 72 |
+
"out_dims": 16, # Output: 16 dendrite adjustments
|
| 73 |
+
|
| 74 |
+
# v3.0 Improvements
|
| 75 |
+
"adaptive_halting": True, # Enable early stopping
|
| 76 |
+
"certainty_threshold": 0.85, # Halt if certainty > threshold
|
| 77 |
+
"sync_decay_alpha": 0.9, # S_new = α*S_old + (1-α)*S_current
|
| 78 |
+
"use_backbone": True, # Use Backbone72D transformation
|
| 79 |
+
|
| 80 |
+
# Training
|
| 81 |
+
"learning_rate": 1e-4,
|
| 82 |
+
"weight_decay": 1e-5,
|
| 83 |
+
"checkpoint_dir": "checkpoints",
|
| 84 |
+
"auto_save_every": 100, # Save every N forward passes
|
| 85 |
+
|
| 86 |
+
# Integration
|
| 87 |
+
"brain_server_url": "https://elliotasdasdasfasas-brain.hf.space",
|
| 88 |
+
|
| 89 |
+
# Physics validation
|
| 90 |
+
"physics_thresholds": {
|
| 91 |
+
"P_max": 1000.0,
|
| 92 |
+
"v_max": 100.0,
|
| 93 |
+
"T_dew": 15.0,
|
| 94 |
+
"T_amb": 25.0
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
# ============================================================================
|
| 99 |
+
# BACKBONE 72D (v3.0 - Transform input before CTM)
|
| 100 |
+
# ============================================================================
|
| 101 |
+
class Backbone72D(nn.Module if TORCH_AVAILABLE else object):
|
| 102 |
+
"""
|
| 103 |
+
Transform 72D SNN input to d_model dimensions.
|
| 104 |
+
Paper insight: Raw input needs proper embedding for CTM to work well.
|
| 105 |
+
"""
|
| 106 |
+
def __init__(self, d_input=72, d_model=256):
|
| 107 |
+
if not TORCH_AVAILABLE:
|
| 108 |
+
return
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.net = nn.Sequential(
|
| 111 |
+
nn.Linear(d_input, 128),
|
| 112 |
+
nn.LayerNorm(128),
|
| 113 |
+
nn.GELU(),
|
| 114 |
+
nn.Linear(128, d_model),
|
| 115 |
+
nn.LayerNorm(d_model)
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def forward(self, x):
|
| 119 |
+
# x: [B, 72]
|
| 120 |
+
return self.net(x) # [B, 256]
|
| 121 |
+
|
| 122 |
+
# ============================================================================
|
| 123 |
+
# FULL CTM WRAPPER FOR ART-17
|
| 124 |
+
# ============================================================================
|
| 125 |
+
class CTM_ART17:
|
| 126 |
+
"""
|
| 127 |
+
Full Continuous Thought Machine adapted for ART-17.
|
| 128 |
+
|
| 129 |
+
Key mechanisms from paper:
|
| 130 |
+
1. NLMs (Neuron-Level Models) - Each neuron processes its own history
|
| 131 |
+
2. Neural Synchronization - Representation is S = Z·Z^T
|
| 132 |
+
3. Adaptive Compute - Can halt early when confident
|
| 133 |
+
|
| 134 |
+
Purpose in ART-17:
|
| 135 |
+
- Regulate 16 dendrite STDP weights
|
| 136 |
+
- Generate coherent hypergraph edges
|
| 137 |
+
- Serve as "nervous system" for the whole system
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def __init__(self, config: dict):
|
| 141 |
+
self.config = config
|
| 142 |
+
self.forward_count = 0
|
| 143 |
+
self.training_samples = []
|
| 144 |
+
self.bunker = BunkerClient(buffer_dir=config.get("buffer_dir", "_ctm_buffer"))
|
| 145 |
+
|
| 146 |
+
if CTM_FULL and TORCH_AVAILABLE:
|
| 147 |
+
self._init_full_ctm()
|
| 148 |
+
else:
|
| 149 |
+
self._init_simplified_ctm()
|
| 150 |
+
|
| 151 |
+
def _init_full_ctm(self):
|
| 152 |
+
"""Initialize full PyTorch CTM model."""
|
| 153 |
+
self.model = ContinuousThoughtMachine(
|
| 154 |
+
iterations=self.config["iterations"],
|
| 155 |
+
d_model=self.config["d_model"],
|
| 156 |
+
d_input=self.config["d_input"],
|
| 157 |
+
heads=4,
|
| 158 |
+
n_synch_out=self.config["n_synch_out"],
|
| 159 |
+
n_synch_action=self.config["n_synch_action"],
|
| 160 |
+
synapse_depth=2,
|
| 161 |
+
memory_length=self.config["memory_length"],
|
| 162 |
+
deep_nlms=True,
|
| 163 |
+
memory_hidden_dims=32,
|
| 164 |
+
do_layernorm_nlm=False,
|
| 165 |
+
backbone_type='none',
|
| 166 |
+
positional_embedding_type='none',
|
| 167 |
+
out_dims=self.config["out_dims"],
|
| 168 |
+
prediction_reshaper=[self.config["out_dims"]],
|
| 169 |
+
dropout=0.1,
|
| 170 |
+
neuron_select_type='random-pairing'
|
| 171 |
+
).to(DEVICE)
|
| 172 |
+
|
| 173 |
+
# Dummy forward to initialize lazy modules
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
dummy = torch.randn(1, self.config["d_input"], device=DEVICE)
|
| 176 |
+
dummy = dummy.unsqueeze(-1).unsqueeze(-1) # [1, 72, 1, 1]
|
| 177 |
+
try:
|
| 178 |
+
_ = self.model(dummy)
|
| 179 |
+
except Exception as e:
|
| 180 |
+
print(f"⚠️ Lazy init failed: {e}")
|
| 181 |
+
|
| 182 |
+
self.model.eval()
|
| 183 |
+
self.optimizer = torch.optim.AdamW(
|
| 184 |
+
self.model.parameters(),
|
| 185 |
+
lr=self.config["learning_rate"],
|
| 186 |
+
weight_decay=self.config["weight_decay"]
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
self.is_full = True
|
| 190 |
+
param_count = sum(p.numel() for p in self.model.parameters())
|
| 191 |
+
print(f"✅ Full CTM initialized: {param_count:,} parameters")
|
| 192 |
+
|
| 193 |
+
# Try to load existing checkpoint
|
| 194 |
+
self._load_checkpoint()
|
| 195 |
+
|
| 196 |
+
def _init_simplified_ctm(self):
|
| 197 |
+
"""Initialize simplified NumPy CTM (fallback)."""
|
| 198 |
+
self.d_model = self.config["d_model"]
|
| 199 |
+
self.memory_length = self.config["memory_length"]
|
| 200 |
+
self.n_ticks = self.config["iterations"]
|
| 201 |
+
|
| 202 |
+
# State traces
|
| 203 |
+
self.state_trace = np.zeros((self.d_model, self.memory_length))
|
| 204 |
+
self.activated_state = np.random.randn(self.d_model) * 0.1
|
| 205 |
+
|
| 206 |
+
# NLM weights (simplified: 16 groups for 16 dendrites)
|
| 207 |
+
self.nlm_weights = np.random.randn(16, self.memory_length) * 0.1
|
| 208 |
+
|
| 209 |
+
self.is_full = False
|
| 210 |
+
print("✅ Simplified CTM initialized (NumPy fallback)")
|
| 211 |
+
|
| 212 |
+
def forward(self, input_72d: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
|
| 213 |
+
"""
|
| 214 |
+
Process input through CTM.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
input_72d: 72D input from SNN
|
| 218 |
+
n_ticks: Override number of internal ticks
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
Dict with predictions, certainty, sync matrix
|
| 222 |
+
"""
|
| 223 |
+
n_ticks = n_ticks or self.config["iterations"]
|
| 224 |
+
self.forward_count += 1
|
| 225 |
+
|
| 226 |
+
if self.is_full:
|
| 227 |
+
return self._forward_full(input_72d, n_ticks)
|
| 228 |
+
else:
|
| 229 |
+
return self._forward_simplified(input_72d, n_ticks)
|
| 230 |
+
|
| 231 |
+
def _forward_full(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
|
| 232 |
+
"""Forward pass with full PyTorch CTM."""
|
| 233 |
+
# Prepare tensor
|
| 234 |
+
x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
|
| 235 |
+
if len(x.shape) == 1:
|
| 236 |
+
x = x.unsqueeze(0) # Add batch dim
|
| 237 |
+
x = x.unsqueeze(-1).unsqueeze(-1) # [B, 72, 1, 1]
|
| 238 |
+
|
| 239 |
+
with torch.no_grad():
|
| 240 |
+
predictions, certainties, sync_out = self.model(x)
|
| 241 |
+
|
| 242 |
+
# Extract results
|
| 243 |
+
final_pred = predictions[:, :, -1].cpu().numpy()[0] # Last tick [16]
|
| 244 |
+
final_cert = certainties[:, 1, -1].cpu().numpy()[0] # 1-entropy
|
| 245 |
+
|
| 246 |
+
# Find tick with highest certainty
|
| 247 |
+
best_tick_idx = certainties[:, 1, :].argmax(dim=-1)[0].item()
|
| 248 |
+
best_pred = predictions[:, :, best_tick_idx].cpu().numpy()[0]
|
| 249 |
+
|
| 250 |
+
# Sync matrix for hypergraph edge proposals
|
| 251 |
+
sync_matrix = sync_out.cpu().numpy()[0] if sync_out is not None else None
|
| 252 |
+
|
| 253 |
+
return {
|
| 254 |
+
"predictions": final_pred.tolist(),
|
| 255 |
+
"best_predictions": best_pred.tolist(),
|
| 256 |
+
"certainty": float(final_cert),
|
| 257 |
+
"best_tick": int(best_tick_idx),
|
| 258 |
+
"ticks_used": n_ticks,
|
| 259 |
+
"sync_matrix": sync_matrix.tolist() if sync_matrix is not None else None,
|
| 260 |
+
"model": "ContinuousThoughtMachine (Full PyTorch)"
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
def _forward_simplified(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
|
| 264 |
+
"""
|
| 265 |
+
Forward pass with simplified NumPy CTM (v3.0).
|
| 266 |
+
|
| 267 |
+
v3.0 Features:
|
| 268 |
+
1. Backbone transformation (72D -> 256D)
|
| 269 |
+
2. Sync Decay (S = α*S_prev + (1-α)*S_current)
|
| 270 |
+
3. Adaptive Halting (stop if certainty > threshold)
|
| 271 |
+
"""
|
| 272 |
+
# v3.0: Backbone transformation (simple linear projection)
|
| 273 |
+
if self.config.get("use_backbone", True):
|
| 274 |
+
# Learned transformation: 72D -> 256D
|
| 275 |
+
input_256 = np.zeros(self.d_model)
|
| 276 |
+
# Simple linear projection + normalization (simulates Backbone72D)
|
| 277 |
+
projected = np.tanh(input_72d[:72] * np.random.randn(72) * 0.1) if len(input_72d) >= 72 else input_72d
|
| 278 |
+
input_256[:min(len(projected), self.d_model)] = projected[:min(len(projected), self.d_model)]
|
| 279 |
+
else:
|
| 280 |
+
input_256 = np.zeros(self.d_model)
|
| 281 |
+
input_256[:min(len(input_72d), self.d_model)] = input_72d[:self.d_model]
|
| 282 |
+
|
| 283 |
+
# v3.0: Sync Decay initialization
|
| 284 |
+
alpha = self.config.get("sync_decay_alpha", 0.9)
|
| 285 |
+
sync_matrix_prev = np.zeros((self.d_model, self.d_model))
|
| 286 |
+
|
| 287 |
+
# v3.0: Adaptive halting config
|
| 288 |
+
adaptive_halting = self.config.get("adaptive_halting", True)
|
| 289 |
+
certainty_threshold = self.config.get("certainty_threshold", 0.85)
|
| 290 |
+
|
| 291 |
+
certainties = []
|
| 292 |
+
all_predictions = []
|
| 293 |
+
ticks_actually_used = 0
|
| 294 |
+
|
| 295 |
+
for t in range(n_ticks):
|
| 296 |
+
ticks_actually_used = t + 1
|
| 297 |
+
|
| 298 |
+
# Synapse update (simplified global mixing)
|
| 299 |
+
combined = np.concatenate([self.activated_state, input_256[:self.d_model//2]])
|
| 300 |
+
pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
|
| 301 |
+
|
| 302 |
+
# Update trace (memory)
|
| 303 |
+
self.state_trace = np.roll(self.state_trace, -1, axis=1)
|
| 304 |
+
self.state_trace[:, -1] = pre_activation
|
| 305 |
+
|
| 306 |
+
# NLM processing (simplified: 16 groups for 16 dendrites)
|
| 307 |
+
post_activation = np.zeros(self.d_model)
|
| 308 |
+
group_size = self.d_model // 16
|
| 309 |
+
for g in range(16):
|
| 310 |
+
start = g * group_size
|
| 311 |
+
end = start + group_size
|
| 312 |
+
group_trace = self.state_trace[start:end, :]
|
| 313 |
+
group_output = np.mean(group_trace @ self.nlm_weights[g])
|
| 314 |
+
post_activation[start:end] = np.tanh(group_output)
|
| 315 |
+
|
| 316 |
+
self.activated_state = post_activation
|
| 317 |
+
|
| 318 |
+
# v3.0: Sync Decay - S = α*S_prev + (1-α)*Z·Z^T
|
| 319 |
+
z_norm = self.activated_state / (np.linalg.norm(self.activated_state) + 1e-8)
|
| 320 |
+
sync_current = np.outer(z_norm, z_norm)
|
| 321 |
+
sync_matrix = alpha * sync_matrix_prev + (1 - alpha) * sync_current
|
| 322 |
+
sync_matrix_prev = sync_matrix
|
| 323 |
+
|
| 324 |
+
# Store predictions at this tick
|
| 325 |
+
all_predictions.append(self.activated_state[:16].copy())
|
| 326 |
+
|
| 327 |
+
# Compute certainty
|
| 328 |
+
probs = np.abs(self.activated_state) / (np.sum(np.abs(self.activated_state)) + 1e-8)
|
| 329 |
+
probs = np.clip(probs, 1e-10, 1.0)
|
| 330 |
+
entropy = -np.sum(probs * np.log(probs))
|
| 331 |
+
max_entropy = np.log(len(probs))
|
| 332 |
+
certainty = float(1.0 - entropy / (max_entropy + 1e-8))
|
| 333 |
+
certainties.append(certainty)
|
| 334 |
+
|
| 335 |
+
# v3.0: Adaptive Halting - stop early if confident enough
|
| 336 |
+
if adaptive_halting and certainty > certainty_threshold:
|
| 337 |
+
break
|
| 338 |
+
|
| 339 |
+
# Best tick selection
|
| 340 |
+
best_tick_idx = int(np.argmax(certainties))
|
| 341 |
+
best_predictions = all_predictions[best_tick_idx].tolist()
|
| 342 |
+
|
| 343 |
+
return {
|
| 344 |
+
"predictions": self.activated_state[:16].tolist(),
|
| 345 |
+
"best_predictions": best_predictions,
|
| 346 |
+
"certainty": certainties[-1],
|
| 347 |
+
"best_tick": best_tick_idx,
|
| 348 |
+
"ticks_used": ticks_actually_used, # v3.0: Actual ticks, may be < n_ticks
|
| 349 |
+
"max_ticks": n_ticks,
|
| 350 |
+
"halted_early": ticks_actually_used < n_ticks, # v3.0: Flag
|
| 351 |
+
"sync_matrix": sync_matrix[:16, :16].tolist(),
|
| 352 |
+
"model": "SimplifiedCTM v3.0 (NumPy + AdaptiveHalt + SyncDecay)"
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
def train_step(self, input_72d: np.ndarray, target_16d: np.ndarray,
|
| 356 |
+
physics_loss: float = 0.0) -> Dict:
|
| 357 |
+
"""
|
| 358 |
+
Online training step.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
input_72d: Input from SNN
|
| 362 |
+
target_16d: Target dendrite adjustments (ground truth)
|
| 363 |
+
physics_loss: Current physics loss for weighting
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
Dict with loss and gradient info
|
| 367 |
+
"""
|
| 368 |
+
if not self.is_full or not TORCH_AVAILABLE:
|
| 369 |
+
return {"status": "skip", "reason": "Training requires full PyTorch CTM"}
|
| 370 |
+
|
| 371 |
+
self.model.train()
|
| 372 |
+
|
| 373 |
+
# Prepare tensors
|
| 374 |
+
x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
|
| 375 |
+
x = x.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # [1, 72, 1, 1]
|
| 376 |
+
y = torch.tensor(target_16d, dtype=torch.float32, device=DEVICE).unsqueeze(0)
|
| 377 |
+
|
| 378 |
+
# Forward
|
| 379 |
+
predictions, certainties, _ = self.model(x)
|
| 380 |
+
|
| 381 |
+
# Loss: dendrite_regulation_loss
|
| 382 |
+
# predictions: [B, 16, T], y: [B, 16]
|
| 383 |
+
y_exp = y.unsqueeze(-1).expand(-1, -1, predictions.size(-1)) # [B, 16, T]
|
| 384 |
+
mse_per_tick = F.mse_loss(predictions, y_exp, reduction='none').mean(dim=1) # [B, T]
|
| 385 |
+
|
| 386 |
+
# Select best tick (min loss) and most certain tick
|
| 387 |
+
loss_min_idx = mse_per_tick.argmin(dim=1) # [B]
|
| 388 |
+
loss_cert_idx = certainties[:, 1, :].argmax(dim=1) # [B]
|
| 389 |
+
|
| 390 |
+
batch_idx = torch.arange(predictions.size(0), device=DEVICE)
|
| 391 |
+
loss_min = mse_per_tick[batch_idx, loss_min_idx].mean()
|
| 392 |
+
loss_cert = mse_per_tick[batch_idx, loss_cert_idx].mean()
|
| 393 |
+
|
| 394 |
+
# Combined loss with physics penalty
|
| 395 |
+
mse_loss = (loss_min + loss_cert) / 2
|
| 396 |
+
physics_penalty = physics_loss * 0.1
|
| 397 |
+
total_loss = mse_loss + physics_penalty
|
| 398 |
+
|
| 399 |
+
# Backward
|
| 400 |
+
self.optimizer.zero_grad()
|
| 401 |
+
total_loss.backward()
|
| 402 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 403 |
+
self.optimizer.step()
|
| 404 |
+
|
| 405 |
+
self.model.eval()
|
| 406 |
+
|
| 407 |
+
# Auto-save checkpoint
|
| 408 |
+
if self.forward_count % self.config["auto_save_every"] == 0:
|
| 409 |
+
self._save_checkpoint()
|
| 410 |
+
|
| 411 |
+
return {
|
| 412 |
+
"status": "trained",
|
| 413 |
+
"loss": float(total_loss.item()),
|
| 414 |
+
"mse_loss": float(mse_loss.item()),
|
| 415 |
+
"physics_penalty": float(physics_penalty),
|
| 416 |
+
"best_tick": int(loss_cert_idx[0].item())
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
def _save_checkpoint(self):
|
| 420 |
+
"""Save model checkpoint."""
|
| 421 |
+
if not self.is_full:
|
| 422 |
+
return
|
| 423 |
+
|
| 424 |
+
os.makedirs(self.config["checkpoint_dir"], exist_ok=True)
|
| 425 |
+
path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
|
| 426 |
+
|
| 427 |
+
torch.save({
|
| 428 |
+
"model_state_dict": self.model.state_dict(),
|
| 429 |
+
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 430 |
+
"forward_count": self.forward_count,
|
| 431 |
+
"timestamp": datetime.now().isoformat()
|
| 432 |
+
}, path)
|
| 433 |
+
print(f"💾 Checkpoint saved: {path}")
|
| 434 |
+
|
| 435 |
+
# Upload to Bunker (Async/Fail-Safe)
|
| 436 |
+
self.bunker.save_file(path, remote_folder="ctm_backups")
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def _load_checkpoint(self):
|
| 440 |
+
"""Load model checkpoint if exists."""
|
| 441 |
+
path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
|
| 442 |
+
if os.path.exists(path):
|
| 443 |
+
try:
|
| 444 |
+
checkpoint = torch.load(path, map_location=DEVICE)
|
| 445 |
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
| 446 |
+
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
| 447 |
+
self.forward_count = checkpoint.get("forward_count", 0)
|
| 448 |
+
print(f"✅ Checkpoint loaded: {path}")
|
| 449 |
+
except Exception as e:
|
| 450 |
+
print(f"⚠️ Could not load checkpoint: {e}")
|
| 451 |
+
|
| 452 |
+
# ============================================================================
|
| 453 |
+
# GLOBAL CTM INSTANCE
|
| 454 |
+
# ============================================================================
|
| 455 |
+
ctm = CTM_ART17(CONFIG)
|
| 456 |
+
|
| 457 |
+
# ============================================================================
|
| 458 |
+
# PHYSICS VALIDATION (from SNN Omega-21)
|
| 459 |
+
# ============================================================================
|
| 460 |
+
def validate_physics(trajectory: List[float], params: Dict) -> Dict:
|
| 461 |
+
"""Validate against 5 physics losses from SNN Omega-21."""
|
| 462 |
+
trajectory = np.array(trajectory)
|
| 463 |
+
|
| 464 |
+
# L_energy: Energy conservation
|
| 465 |
+
energy = np.sum(trajectory ** 2)
|
| 466 |
+
P_max = params.get("P_max", CONFIG["physics_thresholds"]["P_max"])
|
| 467 |
+
L_energy = float(max(0, energy - P_max) ** 2)
|
| 468 |
+
|
| 469 |
+
# L_thermo: Thermodynamics (dew point check)
|
| 470 |
+
T_dew = params.get("T_dew", CONFIG["physics_thresholds"]["T_dew"])
|
| 471 |
+
T_amb = params.get("T_amb", CONFIG["physics_thresholds"]["T_amb"])
|
| 472 |
+
L_thermo = float(max(0, T_dew - T_amb) ** 2)
|
| 473 |
+
|
| 474 |
+
# L_causal: Causality (velocity limit)
|
| 475 |
+
velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
|
| 476 |
+
v_max = params.get("v_max", CONFIG["physics_thresholds"]["v_max"])
|
| 477 |
+
L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))
|
| 478 |
+
|
| 479 |
+
# L_conserv: Flux conservation
|
| 480 |
+
flux_in = params.get("flux_in", 1.0)
|
| 481 |
+
flux_out = params.get("flux_out", 1.0)
|
| 482 |
+
L_conserv = float((flux_in - flux_out) ** 2)
|
| 483 |
+
|
| 484 |
+
# L_entropy: 2nd Law (entropy must increase)
|
| 485 |
+
entropy_change = params.get("entropy_change", 0.1)
|
| 486 |
+
L_entropy = float(max(0, -entropy_change) ** 2)
|
| 487 |
+
|
| 488 |
+
# Total physics loss
|
| 489 |
+
L_total = L_energy + L_thermo + L_causal + L_conserv + L_entropy
|
| 490 |
+
|
| 491 |
+
return {
|
| 492 |
+
"valid": L_total < 0.01,
|
| 493 |
+
"L_energy": L_energy,
|
| 494 |
+
"L_thermo": L_thermo,
|
| 495 |
+
"L_causal": L_causal,
|
| 496 |
+
"L_conserv": L_conserv,
|
| 497 |
+
"L_entropy": L_entropy,
|
| 498 |
+
"L_total": L_total
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
# ============================================================================
|
| 502 |
+
# ENDPOINT FUNCTIONS
|
| 503 |
+
# ============================================================================
|
| 504 |
+
|
| 505 |
+
def sense_snn(snn_json: str) -> str:
|
| 506 |
+
"""
|
| 507 |
+
/sense_snn - Process 72D SNN input through CTM
|
| 508 |
+
|
| 509 |
+
Input: JSON with dendrite values or 72D vector
|
| 510 |
+
Output: Coherent features, certainty, sync matrix
|
| 511 |
+
"""
|
| 512 |
+
try:
|
| 513 |
+
data = json.loads(snn_json)
|
| 514 |
+
|
| 515 |
+
# Extract 72D vector
|
| 516 |
+
if "vector_72d" in data:
|
| 517 |
+
input_vec = np.array(data["vector_72d"])
|
| 518 |
+
elif "dendrites" in data:
|
| 519 |
+
input_vec = np.array(list(data["dendrites"].values()))
|
| 520 |
+
else:
|
| 521 |
+
input_vec = np.random.randn(72)
|
| 522 |
+
|
| 523 |
+
# Pad to 72D if needed
|
| 524 |
+
if len(input_vec) < 72:
|
| 525 |
+
input_vec = np.pad(input_vec, (0, 72 - len(input_vec)))
|
| 526 |
+
|
| 527 |
+
# Process through CTM
|
| 528 |
+
n_ticks = data.get("ticks", 25)
|
| 529 |
+
result = ctm.forward(input_vec[:72], n_ticks)
|
| 530 |
+
|
| 531 |
+
# Detect anomalies (low certainty)
|
| 532 |
+
anomalies = []
|
| 533 |
+
if result["certainty"] < 0.5:
|
| 534 |
+
anomalies.append("Low overall certainty - consider retraining")
|
| 535 |
+
|
| 536 |
+
return json.dumps({
|
| 537 |
+
"status": "success",
|
| 538 |
+
"coherent_features": result["predictions"],
|
| 539 |
+
"certainty": result["certainty"],
|
| 540 |
+
"best_tick": result["best_tick"],
|
| 541 |
+
"anomalies": anomalies,
|
| 542 |
+
"ticks_used": result["ticks_used"],
|
| 543 |
+
"model": result["model"]
|
| 544 |
+
}, indent=2)
|
| 545 |
+
except Exception as e:
|
| 546 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def reason_hypergraph(context_json: str) -> str:
|
| 550 |
+
"""
|
| 551 |
+
/reason_hypergraph - Reason about hypergraph context, propose edges
|
| 552 |
+
|
| 553 |
+
Uses CTM synchronization matrix to find strongly correlated node pairs.
|
| 554 |
+
These become proposed hyperedges.
|
| 555 |
+
"""
|
| 556 |
+
try:
|
| 557 |
+
data = json.loads(context_json)
|
| 558 |
+
|
| 559 |
+
node_features = np.array(data.get("node_features", [[0]*16]*8))
|
| 560 |
+
existing_edges = data.get("existing_edges", [])
|
| 561 |
+
n_ticks = data.get("ticks", 50)
|
| 562 |
+
|
| 563 |
+
# Flatten node features for CTM input and pad to 72D
|
| 564 |
+
flattened = node_features.flatten()
|
| 565 |
+
input_vec = np.zeros(72)
|
| 566 |
+
input_vec[:min(len(flattened), 72)] = flattened[:min(len(flattened), 72)]
|
| 567 |
+
|
| 568 |
+
# Process through CTM with more ticks for reasoning
|
| 569 |
+
result = ctm.forward(input_vec, n_ticks)
|
| 570 |
+
|
| 571 |
+
# Extract proposed edges from sync matrix (S_ij > 0.7)
|
| 572 |
+
proposed_edges = []
|
| 573 |
+
if result["sync_matrix"] is not None:
|
| 574 |
+
sync = np.array(result["sync_matrix"])
|
| 575 |
+
# Ensure sync is 2D
|
| 576 |
+
if len(sync.shape) == 1:
|
| 577 |
+
# 1D array - skip edge extraction
|
| 578 |
+
pass
|
| 579 |
+
elif len(sync.shape) >= 2:
|
| 580 |
+
n_nodes = min(len(node_features), sync.shape[0])
|
| 581 |
+
|
| 582 |
+
for i in range(n_nodes):
|
| 583 |
+
for j in range(i+1, n_nodes):
|
| 584 |
+
if j < sync.shape[1]: # Check bounds
|
| 585 |
+
sync_ij = sync[i, j]
|
| 586 |
+
if sync_ij > 0.7: # Threshold for edge proposal
|
| 587 |
+
edge_exists = any(
|
| 588 |
+
(e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
|
| 589 |
+
for e in existing_edges
|
| 590 |
+
)
|
| 591 |
+
if not edge_exists:
|
| 592 |
+
proposed_edges.append([i, j, float(sync_ij)])
|
| 593 |
+
|
| 594 |
+
return json.dumps({
|
| 595 |
+
"status": "success",
|
| 596 |
+
"proposed_edges": proposed_edges,
|
| 597 |
+
"certainty": result["certainty"],
|
| 598 |
+
"best_tick": result["best_tick"],
|
| 599 |
+
"ticks_used": result["ticks_used"],
|
| 600 |
+
"model": result["model"]
|
| 601 |
+
}, indent=2)
|
| 602 |
+
except Exception as e:
|
| 603 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def validate_physics_endpoint(physics_json: str) -> str:
|
| 607 |
+
"""
|
| 608 |
+
/validate_physics - Validate trajectory against 5 physics losses
|
| 609 |
+
"""
|
| 610 |
+
try:
|
| 611 |
+
data = json.loads(physics_json)
|
| 612 |
+
trajectory = data.get("trajectory", [0.0])
|
| 613 |
+
params = data.get("physics_params", {})
|
| 614 |
+
|
| 615 |
+
result = validate_physics(trajectory, params)
|
| 616 |
+
result["status"] = "success"
|
| 617 |
+
|
| 618 |
+
return json.dumps(result, indent=2)
|
| 619 |
+
except Exception as e:
|
| 620 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def dream_endpoint(dream_json: str) -> str:
|
| 624 |
+
"""
|
| 625 |
+
/dream - Offline consolidation with many ticks
|
| 626 |
+
|
| 627 |
+
Discovers patterns, proposes new edges, identifies edges to prune.
|
| 628 |
+
"""
|
| 629 |
+
try:
|
| 630 |
+
data = json.loads(dream_json)
|
| 631 |
+
|
| 632 |
+
snapshot = data.get("hypergraph_snapshot", {})
|
| 633 |
+
n_ticks = min(data.get("ticks", 100), 100) # Cap at 100 for CPU
|
| 634 |
+
|
| 635 |
+
# Extract features from snapshot
|
| 636 |
+
nodes = snapshot.get("nodes", [])
|
| 637 |
+
if nodes:
|
| 638 |
+
input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()[:72]
|
| 639 |
+
else:
|
| 640 |
+
input_vec = np.random.randn(72)
|
| 641 |
+
|
| 642 |
+
# Dream: run CTM with many ticks
|
| 643 |
+
result = ctm.forward(input_vec, n_ticks)
|
| 644 |
+
|
| 645 |
+
# Analyze sync for patterns
|
| 646 |
+
new_edges = []
|
| 647 |
+
pruned_edges = []
|
| 648 |
+
|
| 649 |
+
if result["sync_matrix"] is not None:
|
| 650 |
+
sync = np.array(result["sync_matrix"])
|
| 651 |
+
n = min(len(nodes), sync.shape[0]) if nodes else 16
|
| 652 |
+
|
| 653 |
+
for i in range(n):
|
| 654 |
+
for j in range(i+1, n):
|
| 655 |
+
if sync[i, j] > 0.85:
|
| 656 |
+
new_edges.append([i, j, float(sync[i, j])])
|
| 657 |
+
elif sync[i, j] < 0.1:
|
| 658 |
+
pruned_edges.append([i, j])
|
| 659 |
+
|
| 660 |
+
return json.dumps({
|
| 661 |
+
"status": "success",
|
| 662 |
+
"discovered_patterns": len(new_edges),
|
| 663 |
+
"new_edges": new_edges[:10],
|
| 664 |
+
"pruned_edges": pruned_edges[:10],
|
| 665 |
+
"consolidation_certainty": result["certainty"],
|
| 666 |
+
"ticks_used": result["ticks_used"],
|
| 667 |
+
"model": result["model"]
|
| 668 |
+
}, indent=2)
|
| 669 |
+
except Exception as e:
|
| 670 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def calibrate_stdp_endpoint(stdp_json: str) -> str:
|
| 674 |
+
"""
|
| 675 |
+
/calibrate_stdp - Suggest STDP weight adjustments
|
| 676 |
+
|
| 677 |
+
This is the CORE regulatory function:
|
| 678 |
+
- Receives current 16 dendrite weights
|
| 679 |
+
- Processes through CTM to get sync patterns
|
| 680 |
+
- Returns suggested weight adjustments
|
| 681 |
+
"""
|
| 682 |
+
try:
|
| 683 |
+
data = json.loads(stdp_json)
|
| 684 |
+
|
| 685 |
+
current_weights = np.array(data.get("current_weights", [1.0]*16))
|
| 686 |
+
node_features = np.array(data.get("node_features", [[0]*16]*4))
|
| 687 |
+
|
| 688 |
+
# Flatten features for CTM input
|
| 689 |
+
input_vec = node_features.flatten()[:72]
|
| 690 |
+
|
| 691 |
+
# Process through CTM
|
| 692 |
+
result = ctm.forward(input_vec, n_ticks=25)
|
| 693 |
+
|
| 694 |
+
# Use predictions as weight adjustments
|
| 695 |
+
predictions = np.array(result["best_predictions"])
|
| 696 |
+
|
| 697 |
+
# Scale based on certainty
|
| 698 |
+
confidence = result["certainty"]
|
| 699 |
+
weight_changes = (predictions - 0.5) * confidence * 0.1
|
| 700 |
+
|
| 701 |
+
new_weights = current_weights + weight_changes
|
| 702 |
+
|
| 703 |
+
return json.dumps({
|
| 704 |
+
"status": "success",
|
| 705 |
+
"suggested_weights": new_weights.tolist(),
|
| 706 |
+
"weight_changes": weight_changes.tolist(),
|
| 707 |
+
"confidence": confidence,
|
| 708 |
+
"best_tick": result["best_tick"],
|
| 709 |
+
"model": result["model"]
|
| 710 |
+
}, indent=2)
|
| 711 |
+
except Exception as e:
|
| 712 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
def regulate_endpoint(regulate_json: str) -> str:
|
| 716 |
+
"""
|
| 717 |
+
/regulate - Full feedback loop for ART-17 regulation (NEW)
|
| 718 |
+
|
| 719 |
+
Combines all signals to provide comprehensive regulation:
|
| 720 |
+
- Dendrite state
|
| 721 |
+
- Latent representation
|
| 722 |
+
- Physics loss
|
| 723 |
+
- Anomaly score
|
| 724 |
+
|
| 725 |
+
Returns action recommendation with confidence.
|
| 726 |
+
"""
|
| 727 |
+
try:
|
| 728 |
+
data = json.loads(regulate_json)
|
| 729 |
+
|
| 730 |
+
# Inputs from local system
|
| 731 |
+
dendrites = np.array(data.get("dendrites", [0.0]*16))
|
| 732 |
+
latent_256 = np.array(data.get("latent_256", [0.0]*256))
|
| 733 |
+
physics_loss = data.get("physics_loss", 0.0)
|
| 734 |
+
anomaly_score = data.get("anomaly_score", 0.0)
|
| 735 |
+
|
| 736 |
+
# Combine into 72D input
|
| 737 |
+
input_72 = np.concatenate([
|
| 738 |
+
dendrites, # 16D
|
| 739 |
+
latent_256[:56] # 56D from latent
|
| 740 |
+
])
|
| 741 |
+
|
| 742 |
+
# Process through CTM
|
| 743 |
+
result = ctm.forward(input_72, n_ticks=50)
|
| 744 |
+
|
| 745 |
+
# Compute regulation signals
|
| 746 |
+
predictions = np.array(result["best_predictions"])
|
| 747 |
+
certainty = result["certainty"]
|
| 748 |
+
|
| 749 |
+
# Urgency based on physics and anomaly
|
| 750 |
+
urgency = min(1.0, physics_loss + anomaly_score)
|
| 751 |
+
regulation_strength = urgency * certainty
|
| 752 |
+
|
| 753 |
+
# Weight adjustments
|
| 754 |
+
dendrite_deltas = predictions * regulation_strength * 0.05
|
| 755 |
+
|
| 756 |
+
# Determine if intervention needed
|
| 757 |
+
needs_intervention = urgency > 0.5 or certainty < 0.3
|
| 758 |
+
|
| 759 |
+
return json.dumps({
|
| 760 |
+
"status": "success",
|
| 761 |
+
"dendrite_deltas": dendrite_deltas.tolist(),
|
| 762 |
+
"regulation_strength": float(regulation_strength),
|
| 763 |
+
"confidence": certainty,
|
| 764 |
+
"urgency": float(urgency),
|
| 765 |
+
"needs_intervention": needs_intervention,
|
| 766 |
+
"recommended_action": "ADJUST" if needs_intervention else "MAINTAIN",
|
| 767 |
+
"best_tick": result["best_tick"],
|
| 768 |
+
"model": result["model"]
|
| 769 |
+
}, indent=2)
|
| 770 |
+
except Exception as e:
|
| 771 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def train_online_endpoint(train_json: str) -> str:
|
| 775 |
+
"""
|
| 776 |
+
/train_online - Progressive online training (NEW)
|
| 777 |
+
|
| 778 |
+
Allows the local system to train the CTM with experience.
|
| 779 |
+
Sends input-output pairs and receives training feedback.
|
| 780 |
+
"""
|
| 781 |
+
try:
|
| 782 |
+
data = json.loads(train_json)
|
| 783 |
+
|
| 784 |
+
input_72d = np.array(data.get("input_72d", [0.0]*72))
|
| 785 |
+
target_16d = np.array(data.get("target_16d", [0.0]*16))
|
| 786 |
+
physics_loss = data.get("physics_loss", 0.0)
|
| 787 |
+
|
| 788 |
+
# Perform training step
|
| 789 |
+
result = ctm.train_step(input_72d, target_16d, physics_loss)
|
| 790 |
+
|
| 791 |
+
return json.dumps({
|
| 792 |
+
"status": result["status"],
|
| 793 |
+
"loss": result.get("loss"),
|
| 794 |
+
"mse_loss": result.get("mse_loss"),
|
| 795 |
+
"physics_penalty": result.get("physics_penalty"),
|
| 796 |
+
"best_tick": result.get("best_tick"),
|
| 797 |
+
"forward_count": ctm.forward_count,
|
| 798 |
+
"message": "Training step completed" if result["status"] == "trained" else result.get("reason")
|
| 799 |
+
}, indent=2)
|
| 800 |
+
except Exception as e:
|
| 801 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
def health_check() -> str:
|
| 805 |
+
"""Health check with model info."""
|
| 806 |
+
return json.dumps({
|
| 807 |
+
"status": "healthy",
|
| 808 |
+
"model": f"CTM Nervous System v2.0 ({'Full PyTorch' if ctm.is_full else 'NumPy Fallback'})",
|
| 809 |
+
"device": DEVICE,
|
| 810 |
+
"d_model": CONFIG["d_model"],
|
| 811 |
+
"iterations": CONFIG["iterations"],
|
| 812 |
+
"memory_length": CONFIG["memory_length"],
|
| 813 |
+
"forward_count": ctm.forward_count,
|
| 814 |
+
"endpoints": [
|
| 815 |
+
"/sense_snn",
|
| 816 |
+
"/reason_hypergraph",
|
| 817 |
+
"/validate_physics",
|
| 818 |
+
"/dream",
|
| 819 |
+
"/calibrate_stdp",
|
| 820 |
+
"/regulate", # NEW
|
| 821 |
+
"/train_online" # NEW
|
| 822 |
+
]
|
| 823 |
+
}, indent=2)
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
# ============================================================================
|
| 827 |
+
# GRADIO INTERFACE
|
| 828 |
+
# ============================================================================
|
| 829 |
+
|
| 830 |
+
with gr.Blocks(title="CTM Nervous System v2.0", theme=gr.themes.Soft()) as demo:
|
| 831 |
+
gr.Markdown("""
|
| 832 |
+
# 🧬 CTM Nervous System v2.0
|
| 833 |
+
**Continuous Thought Machine for ART-17 Hypergraph Coherence**
|
| 834 |
+
|
| 835 |
+
Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
|
| 836 |
+
|
| 837 |
+
---
|
| 838 |
+
|
| 839 |
+
## Key Innovations
|
| 840 |
+
- **NLMs (Neuron-Level Models)**: Each neuron processes its own history
|
| 841 |
+
- **Neural Synchronization**: Representation via S = Z·Z^T
|
| 842 |
+
- **Adaptive Compute**: Halts when confident
|
| 843 |
+
- **Online Training**: Progressive learning with use
|
| 844 |
+
|
| 845 |
+
---
|
| 846 |
+
""")
|
| 847 |
+
|
| 848 |
+
with gr.Tabs():
|
| 849 |
+
with gr.Tab("🔌 /sense_snn"):
|
| 850 |
+
gr.Markdown("Process 72D SNN input through CTM")
|
| 851 |
+
snn_input = gr.Textbox(
|
| 852 |
+
label="SNN JSON Input",
|
| 853 |
+
value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}, "ticks": 25}',
|
| 854 |
+
lines=5
|
| 855 |
+
)
|
| 856 |
+
snn_output = gr.Textbox(label="Output", lines=10)
|
| 857 |
+
snn_btn = gr.Button("Process", variant="primary")
|
| 858 |
+
snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
|
| 859 |
+
|
| 860 |
+
with gr.Tab("🧠 /reason_hypergraph"):
|
| 861 |
+
gr.Markdown("Reason about hypergraph context, propose edges")
|
| 862 |
+
reason_input = gr.Textbox(
|
| 863 |
+
label="Context JSON",
|
| 864 |
+
value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
|
| 865 |
+
lines=5
|
| 866 |
+
)
|
| 867 |
+
reason_output = gr.Textbox(label="Output", lines=10)
|
| 868 |
+
reason_btn = gr.Button("Reason", variant="primary")
|
| 869 |
+
reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
|
| 870 |
+
|
| 871 |
+
with gr.Tab("⚡ /validate_physics"):
|
| 872 |
+
gr.Markdown("Validate trajectory against 5 physics losses")
|
| 873 |
+
physics_input = gr.Textbox(
|
| 874 |
+
label="Physics JSON",
|
| 875 |
+
value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
|
| 876 |
+
lines=5
|
| 877 |
+
)
|
| 878 |
+
physics_output = gr.Textbox(label="Output", lines=10)
|
| 879 |
+
physics_btn = gr.Button("Validate", variant="primary")
|
| 880 |
+
physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
|
| 881 |
+
|
| 882 |
+
with gr.Tab("💤 /dream"):
|
| 883 |
+
gr.Markdown("Offline consolidation - discover patterns")
|
| 884 |
+
dream_input = gr.Textbox(
|
| 885 |
+
label="Dream JSON",
|
| 886 |
+
value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
|
| 887 |
+
lines=5
|
| 888 |
+
)
|
| 889 |
+
dream_output = gr.Textbox(label="Output", lines=10)
|
| 890 |
+
dream_btn = gr.Button("Dream", variant="primary")
|
| 891 |
+
dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
|
| 892 |
+
|
| 893 |
+
with gr.Tab("🔧 /calibrate_stdp"):
|
| 894 |
+
gr.Markdown("Calibrate STDP weights (Core regulatory function)")
|
| 895 |
+
stdp_input = gr.Textbox(
|
| 896 |
+
label="STDP JSON",
|
| 897 |
+
value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
|
| 898 |
+
lines=5
|
| 899 |
+
)
|
| 900 |
+
stdp_output = gr.Textbox(label="Output", lines=10)
|
| 901 |
+
stdp_btn = gr.Button("Calibrate", variant="primary")
|
| 902 |
+
stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
|
| 903 |
+
|
| 904 |
+
with gr.Tab("🎯 /regulate [NEW]"):
|
| 905 |
+
gr.Markdown("Full feedback loop for ART-17 regulation")
|
| 906 |
+
regulate_input = gr.Textbox(
|
| 907 |
+
label="Regulate JSON",
|
| 908 |
+
value='{"dendrites": [0.5]*16, "latent_256": [0.1]*256, "physics_loss": 0.01, "anomaly_score": 0.05}',
|
| 909 |
+
lines=5
|
| 910 |
+
)
|
| 911 |
+
regulate_output = gr.Textbox(label="Output", lines=10)
|
| 912 |
+
regulate_btn = gr.Button("Regulate", variant="primary")
|
| 913 |
+
regulate_btn.click(regulate_endpoint, inputs=regulate_input, outputs=regulate_output, api_name="regulate")
|
| 914 |
+
|
| 915 |
+
with gr.Tab("📚 /train_online [NEW]"):
|
| 916 |
+
gr.Markdown("Progressive online training with experience")
|
| 917 |
+
train_input = gr.Textbox(
|
| 918 |
+
label="Training JSON",
|
| 919 |
+
value='{"input_72d": [0.1]*72, "target_16d": [0.5]*16, "physics_loss": 0.01}',
|
| 920 |
+
lines=5
|
| 921 |
+
)
|
| 922 |
+
train_output = gr.Textbox(label="Output", lines=10)
|
| 923 |
+
train_btn = gr.Button("Train Step", variant="primary")
|
| 924 |
+
train_btn.click(train_online_endpoint, inputs=train_input, outputs=train_output, api_name="train_online")
|
| 925 |
+
|
| 926 |
+
with gr.Tab("❤️ Health"):
|
| 927 |
+
health_output = gr.Textbox(label="Health Status", lines=15)
|
| 928 |
+
health_btn = gr.Button("Check Health", variant="secondary")
|
| 929 |
+
health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
|
| 930 |
+
|
| 931 |
+
gr.Markdown("""
|
| 932 |
+
---
|
| 933 |
+
**Architecture**: CTM as Nervous System → Hypergraph as Coherent Thought
|
| 934 |
+
|
| 935 |
+
**Integration**: Local ART-17 ↔ CTM (regulation) ↔ Brain Server (semantics)
|
| 936 |
+
|
| 937 |
+
**Training**: Progressive online learning + Physics-Informed Loss
|
| 938 |
+
""")
|
| 939 |
+
|
| 940 |
+
if __name__ == "__main__":
|
| 941 |
+
demo.launch(
|
| 942 |
+
server_name="0.0.0.0",
|
| 943 |
+
server_port=7860,
|
| 944 |
+
show_error=True
|
| 945 |
+
)
|
app_v1_backup.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CTM Nervous System Server - Continuous Thought Machine for Hypergraph Maintenance
|
| 3 |
+
===================================================================================
|
| 4 |
+
Implementation of the definitive proposal: CTM as Nervous System for ART-17 Hypergraph.
|
| 5 |
+
|
| 6 |
+
Endpoints:
|
| 7 |
+
- /sense_snn: Process 72D SNN input with NLM-style processing
|
| 8 |
+
- /reason_hypergraph: Reason about hypergraph context, propose edges
|
| 9 |
+
- /validate_physics: Validate proposals against 5 physics losses
|
| 10 |
+
- /dream: Offline consolidation with T=500+ ticks
|
| 11 |
+
- /calibrate_stdp: Suggest STDP weight adjustments from sync matrix
|
| 12 |
+
- /health: Health check endpoint
|
| 13 |
+
|
| 14 |
+
Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import gradio as gr
|
| 18 |
+
import numpy as np
|
| 19 |
+
import json
|
| 20 |
+
from typing import List, Dict, Any, Optional
|
| 21 |
+
import os
|
| 22 |
+
|
| 23 |
+
# ============================================================================
|
| 24 |
+
# SIMPLIFIED CTM SIMULATION (CPU-only for Hugging Face free tier)
|
| 25 |
+
# ============================================================================
|
| 26 |
+
|
| 27 |
+
class SimplifiedCTM:
|
| 28 |
+
"""
|
| 29 |
+
Simplified CTM for CPU-only environment.
|
| 30 |
+
Simulates the key mechanisms without full PyTorch model.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, d_model: int = 256, memory_length: int = 16, n_ticks: int = 50):
|
| 34 |
+
self.d_model = d_model
|
| 35 |
+
self.memory_length = memory_length
|
| 36 |
+
self.n_ticks = n_ticks
|
| 37 |
+
|
| 38 |
+
# Initialize state
|
| 39 |
+
self.state_trace = np.zeros((d_model, memory_length))
|
| 40 |
+
self.activated_state = np.random.randn(d_model) * 0.1
|
| 41 |
+
|
| 42 |
+
# NLM weights (simplified: one weight matrix per "neuron group")
|
| 43 |
+
self.nlm_weights = np.random.randn(16, memory_length) * 0.1 # 16 groups for 16 dendrites
|
| 44 |
+
|
| 45 |
+
def compute_sync_matrix(self, z: np.ndarray) -> np.ndarray:
|
| 46 |
+
"""S^t = Z · Z^T (normalized)"""
|
| 47 |
+
z_norm = z / (np.linalg.norm(z) + 1e-8)
|
| 48 |
+
S = np.outer(z_norm, z_norm)
|
| 49 |
+
return S
|
| 50 |
+
|
| 51 |
+
def compute_certainty(self, predictions: np.ndarray) -> float:
|
| 52 |
+
"""Certainty = 1 - normalized entropy"""
|
| 53 |
+
probs = np.abs(predictions) / (np.sum(np.abs(predictions)) + 1e-8)
|
| 54 |
+
probs = np.clip(probs, 1e-10, 1.0)
|
| 55 |
+
entropy = -np.sum(probs * np.log(probs))
|
| 56 |
+
max_entropy = np.log(len(probs))
|
| 57 |
+
normalized_entropy = entropy / (max_entropy + 1e-8)
|
| 58 |
+
return float(1.0 - normalized_entropy)
|
| 59 |
+
|
| 60 |
+
def process_ticks(self, input_features: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
|
| 61 |
+
"""Run T internal ticks and return sync matrix + certainty"""
|
| 62 |
+
n_ticks = n_ticks or self.n_ticks
|
| 63 |
+
|
| 64 |
+
# Ensure input is right size
|
| 65 |
+
if len(input_features) < self.d_model:
|
| 66 |
+
input_features = np.pad(input_features, (0, self.d_model - len(input_features)))
|
| 67 |
+
else:
|
| 68 |
+
input_features = input_features[:self.d_model]
|
| 69 |
+
|
| 70 |
+
certainties = []
|
| 71 |
+
sync_matrices = []
|
| 72 |
+
|
| 73 |
+
for t in range(n_ticks):
|
| 74 |
+
# Simulate synapse update
|
| 75 |
+
combined = np.concatenate([self.activated_state, input_features[:self.d_model//2]])
|
| 76 |
+
pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
|
| 77 |
+
|
| 78 |
+
# Update trace
|
| 79 |
+
self.state_trace = np.roll(self.state_trace, -1, axis=1)
|
| 80 |
+
self.state_trace[:, -1] = pre_activation
|
| 81 |
+
|
| 82 |
+
# Simulate NLM (simplified)
|
| 83 |
+
post_activation = np.zeros(self.d_model)
|
| 84 |
+
group_size = self.d_model // 16
|
| 85 |
+
for g in range(16):
|
| 86 |
+
start = g * group_size
|
| 87 |
+
end = start + group_size
|
| 88 |
+
group_trace = self.state_trace[start:end, :]
|
| 89 |
+
group_output = np.mean(group_trace @ self.nlm_weights[g])
|
| 90 |
+
post_activation[start:end] = np.tanh(group_output)
|
| 91 |
+
|
| 92 |
+
self.activated_state = post_activation
|
| 93 |
+
|
| 94 |
+
# Compute sync and certainty
|
| 95 |
+
sync = self.compute_sync_matrix(self.activated_state)
|
| 96 |
+
cert = self.compute_certainty(self.activated_state)
|
| 97 |
+
|
| 98 |
+
sync_matrices.append(sync)
|
| 99 |
+
certainties.append(cert)
|
| 100 |
+
|
| 101 |
+
# Find best ticks (min-loss proxy: max certainty)
|
| 102 |
+
best_tick = int(np.argmax(certainties))
|
| 103 |
+
|
| 104 |
+
return {
|
| 105 |
+
"final_sync_matrix": sync_matrices[-1].tolist(),
|
| 106 |
+
"best_sync_matrix": sync_matrices[best_tick].tolist(),
|
| 107 |
+
"certainties": certainties,
|
| 108 |
+
"final_certainty": float(certainties[-1]),
|
| 109 |
+
"max_certainty": float(max(certainties)),
|
| 110 |
+
"best_tick": best_tick,
|
| 111 |
+
"ticks_used": n_ticks
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
# Global CTM instance
|
| 115 |
+
ctm = SimplifiedCTM(d_model=256, memory_length=16, n_ticks=50)
|
| 116 |
+
|
| 117 |
+
# ============================================================================
|
| 118 |
+
# PHYSICS VALIDATION (from SNN Omega-21)
|
| 119 |
+
# ============================================================================
|
| 120 |
+
|
| 121 |
+
def validate_physics(trajectory: List[float], params: Dict) -> Dict:
|
| 122 |
+
"""Validate against 5 physics losses from SNN Omega-21"""
|
| 123 |
+
trajectory = np.array(trajectory)
|
| 124 |
+
|
| 125 |
+
# L_energy: Energy conservation
|
| 126 |
+
energy = np.sum(trajectory ** 2)
|
| 127 |
+
P_max = params.get("P_max", 1000.0)
|
| 128 |
+
L_energy = float(max(0, energy - P_max) ** 2)
|
| 129 |
+
|
| 130 |
+
# L_thermo: Thermodynamics (dew point check)
|
| 131 |
+
T_dew = params.get("T_dew", 15.0)
|
| 132 |
+
T_amb = params.get("T_amb", 25.0)
|
| 133 |
+
L_thermo = float(max(0, T_dew - T_amb) ** 2)
|
| 134 |
+
|
| 135 |
+
# L_causal: Causality (velocity limit)
|
| 136 |
+
velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
|
| 137 |
+
v_max = params.get("v_max", 100.0)
|
| 138 |
+
L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))
|
| 139 |
+
|
| 140 |
+
# L_conserv: Flux conservation
|
| 141 |
+
flux_in = params.get("flux_in", 1.0)
|
| 142 |
+
flux_out = params.get("flux_out", 1.0)
|
| 143 |
+
L_conserv = float((flux_in - flux_out) ** 2)
|
| 144 |
+
|
| 145 |
+
# L_entropy: 2nd Law (entropy must increase)
|
| 146 |
+
entropy_change = params.get("entropy_change", 0.1)
|
| 147 |
+
L_entropy = float(max(0, -entropy_change) ** 2)
|
| 148 |
+
|
| 149 |
+
# Total physics loss
|
| 150 |
+
L_total = L_energy + L_thermo + L_causal + L_conserv + L_entropy
|
| 151 |
+
|
| 152 |
+
return {
|
| 153 |
+
"valid": L_total < 0.01,
|
| 154 |
+
"L_energy": L_energy,
|
| 155 |
+
"L_thermo": L_thermo,
|
| 156 |
+
"L_causal": L_causal,
|
| 157 |
+
"L_conserv": L_conserv,
|
| 158 |
+
"L_entropy": L_entropy,
|
| 159 |
+
"L_total": L_total
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
# ============================================================================
|
| 163 |
+
# ENDPOINT FUNCTIONS
|
| 164 |
+
# ============================================================================
|
| 165 |
+
|
| 166 |
+
def sense_snn(snn_json: str) -> str:
|
| 167 |
+
"""
|
| 168 |
+
/sense_snn - Process 72D SNN input
|
| 169 |
+
Input: JSON with dendrite values
|
| 170 |
+
Output: Coherent features + anomalies
|
| 171 |
+
"""
|
| 172 |
+
try:
|
| 173 |
+
data = json.loads(snn_json)
|
| 174 |
+
|
| 175 |
+
# Extract 72D vector (or create from dendrites)
|
| 176 |
+
if "vector_72d" in data:
|
| 177 |
+
input_vec = np.array(data["vector_72d"])
|
| 178 |
+
elif "dendrites" in data:
|
| 179 |
+
dendrite_values = list(data["dendrites"].values())
|
| 180 |
+
input_vec = np.array(dendrite_values)
|
| 181 |
+
else:
|
| 182 |
+
input_vec = np.random.randn(72) # Fallback
|
| 183 |
+
|
| 184 |
+
# Pad to 256D
|
| 185 |
+
input_256 = np.zeros(256)
|
| 186 |
+
input_256[:min(len(input_vec), 256)] = input_vec[:min(len(input_vec), 256)]
|
| 187 |
+
|
| 188 |
+
# Process through CTM
|
| 189 |
+
result = ctm.process_ticks(input_256, n_ticks=25)
|
| 190 |
+
|
| 191 |
+
# Detect anomalies (low certainty regions)
|
| 192 |
+
anomalies = []
|
| 193 |
+
if result["final_certainty"] < 0.5:
|
| 194 |
+
anomalies.append("Low overall certainty")
|
| 195 |
+
|
| 196 |
+
return json.dumps({
|
| 197 |
+
"status": "success",
|
| 198 |
+
"coherent_features": result["final_sync_matrix"][:16][:16], # 16x16 subset
|
| 199 |
+
"certainty": result["final_certainty"],
|
| 200 |
+
"anomalies": anomalies,
|
| 201 |
+
"ticks_used": result["ticks_used"]
|
| 202 |
+
}, indent=2)
|
| 203 |
+
except Exception as e:
|
| 204 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 205 |
+
|
| 206 |
+
def reason_hypergraph(context_json: str) -> str:
|
| 207 |
+
"""
|
| 208 |
+
/reason_hypergraph - Reason about hypergraph, propose edges
|
| 209 |
+
Input: Node features + existing edges
|
| 210 |
+
Output: Proposed new edges + certainty
|
| 211 |
+
"""
|
| 212 |
+
try:
|
| 213 |
+
data = json.loads(context_json)
|
| 214 |
+
|
| 215 |
+
node_features = np.array(data.get("node_features", [[0]*16]*8))
|
| 216 |
+
existing_edges = data.get("existing_edges", [])
|
| 217 |
+
n_ticks = data.get("ticks", 50)
|
| 218 |
+
|
| 219 |
+
# Flatten node features for CTM input
|
| 220 |
+
input_vec = node_features.flatten()
|
| 221 |
+
|
| 222 |
+
# Process through CTM with more ticks for reasoning
|
| 223 |
+
result = ctm.process_ticks(input_vec, n_ticks=n_ticks)
|
| 224 |
+
|
| 225 |
+
# Extract proposed edges from sync matrix (S_ij > 0.8)
|
| 226 |
+
sync = np.array(result["best_sync_matrix"])
|
| 227 |
+
proposed_edges = []
|
| 228 |
+
|
| 229 |
+
n_nodes = min(len(node_features), sync.shape[0])
|
| 230 |
+
for i in range(n_nodes):
|
| 231 |
+
for j in range(i+1, n_nodes):
|
| 232 |
+
sync_ij = sync[i, j]
|
| 233 |
+
if sync_ij > 0.8:
|
| 234 |
+
# Check if edge already exists
|
| 235 |
+
edge_exists = any(
|
| 236 |
+
(e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
|
| 237 |
+
for e in existing_edges
|
| 238 |
+
)
|
| 239 |
+
if not edge_exists:
|
| 240 |
+
proposed_edges.append([i, j, float(sync_ij)])
|
| 241 |
+
|
| 242 |
+
return json.dumps({
|
| 243 |
+
"status": "success",
|
| 244 |
+
"proposed_edges": proposed_edges,
|
| 245 |
+
"certainty": result["max_certainty"],
|
| 246 |
+
"best_tick": result["best_tick"],
|
| 247 |
+
"ticks_used": result["ticks_used"]
|
| 248 |
+
}, indent=2)
|
| 249 |
+
except Exception as e:
|
| 250 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 251 |
+
|
| 252 |
+
def validate_physics_endpoint(physics_json: str) -> str:
|
| 253 |
+
"""
|
| 254 |
+
/validate_physics - Validate trajectory against 5 physics losses
|
| 255 |
+
"""
|
| 256 |
+
try:
|
| 257 |
+
data = json.loads(physics_json)
|
| 258 |
+
trajectory = data.get("trajectory", [0.0])
|
| 259 |
+
params = data.get("physics_params", {})
|
| 260 |
+
|
| 261 |
+
result = validate_physics(trajectory, params)
|
| 262 |
+
result["status"] = "success"
|
| 263 |
+
|
| 264 |
+
return json.dumps(result, indent=2)
|
| 265 |
+
except Exception as e:
|
| 266 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 267 |
+
|
| 268 |
+
def dream_endpoint(dream_json: str) -> str:
|
| 269 |
+
"""
|
| 270 |
+
/dream - Offline consolidation with T=500+ ticks
|
| 271 |
+
Input: Hypergraph snapshot
|
| 272 |
+
Output: Discovered patterns + new edges
|
| 273 |
+
"""
|
| 274 |
+
try:
|
| 275 |
+
data = json.loads(dream_json)
|
| 276 |
+
|
| 277 |
+
snapshot = data.get("hypergraph_snapshot", {})
|
| 278 |
+
n_ticks = data.get("ticks", 500)
|
| 279 |
+
|
| 280 |
+
# Extract features from snapshot
|
| 281 |
+
nodes = snapshot.get("nodes", [])
|
| 282 |
+
if nodes:
|
| 283 |
+
input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()
|
| 284 |
+
else:
|
| 285 |
+
input_vec = np.random.randn(256) # Random dream if no nodes
|
| 286 |
+
|
| 287 |
+
# Dream: run CTM with many ticks and no external input after initial
|
| 288 |
+
result = ctm.process_ticks(input_vec, n_ticks=min(n_ticks, 100)) # Cap at 100 for CPU
|
| 289 |
+
|
| 290 |
+
# Analyze sync evolution to find patterns
|
| 291 |
+
sync = np.array(result["final_sync_matrix"])
|
| 292 |
+
|
| 293 |
+
# Find strong sync pairs (new edges)
|
| 294 |
+
new_edges = []
|
| 295 |
+
n = min(len(nodes), sync.shape[0]) if nodes else 16
|
| 296 |
+
for i in range(n):
|
| 297 |
+
for j in range(i+1, n):
|
| 298 |
+
if sync[i, j] > 0.85:
|
| 299 |
+
new_edges.append([i, j, float(sync[i, j])])
|
| 300 |
+
|
| 301 |
+
# Find weak sync pairs (edges to prune)
|
| 302 |
+
pruned_edges = []
|
| 303 |
+
for i in range(n):
|
| 304 |
+
for j in range(i+1, n):
|
| 305 |
+
if sync[i, j] < 0.1:
|
| 306 |
+
pruned_edges.append([i, j])
|
| 307 |
+
|
| 308 |
+
return json.dumps({
|
| 309 |
+
"status": "success",
|
| 310 |
+
"discovered_patterns": len(new_edges),
|
| 311 |
+
"new_edges": new_edges[:10], # Top 10
|
| 312 |
+
"pruned_edges": pruned_edges[:10], # Top 10
|
| 313 |
+
"consolidation_certainty": result["max_certainty"],
|
| 314 |
+
"ticks_used": result["ticks_used"]
|
| 315 |
+
}, indent=2)
|
| 316 |
+
except Exception as e:
|
| 317 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 318 |
+
|
| 319 |
+
def calibrate_stdp_endpoint(stdp_json: str) -> str:
|
| 320 |
+
"""
|
| 321 |
+
/calibrate_stdp - Suggest STDP weight adjustments from sync
|
| 322 |
+
"""
|
| 323 |
+
try:
|
| 324 |
+
data = json.loads(stdp_json)
|
| 325 |
+
|
| 326 |
+
current_weights = np.array(data.get("current_weights", [1.0]*16))
|
| 327 |
+
node_features = np.array(data.get("node_features", [[0]*16]*8))
|
| 328 |
+
|
| 329 |
+
# Process to get sync matrix
|
| 330 |
+
input_vec = node_features.flatten()
|
| 331 |
+
result = ctm.process_ticks(input_vec, n_ticks=25)
|
| 332 |
+
|
| 333 |
+
sync = np.array(result["final_sync_matrix"])
|
| 334 |
+
|
| 335 |
+
# Suggest weight adjustments based on sync patterns
|
| 336 |
+
# Uses diagonal of sync (self-similarity) to scale weights
|
| 337 |
+
suggested = np.zeros(16)
|
| 338 |
+
for i in range(16):
|
| 339 |
+
# Average sync of neuron i with others
|
| 340 |
+
avg_sync = np.mean(sync[i, :])
|
| 341 |
+
# Scale current weight by sync
|
| 342 |
+
suggested[i] = current_weights[i] * (0.5 + avg_sync)
|
| 343 |
+
|
| 344 |
+
return json.dumps({
|
| 345 |
+
"status": "success",
|
| 346 |
+
"suggested_weights": suggested.tolist(),
|
| 347 |
+
"weight_changes": (suggested - current_weights).tolist(),
|
| 348 |
+
"confidence": result["final_certainty"]
|
| 349 |
+
}, indent=2)
|
| 350 |
+
except Exception as e:
|
| 351 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 352 |
+
|
| 353 |
+
def health_check() -> str:
|
| 354 |
+
"""Health check for the CTM server"""
|
| 355 |
+
return json.dumps({
|
| 356 |
+
"status": "healthy",
|
| 357 |
+
"model": "CTM Nervous System v1.0",
|
| 358 |
+
"d_model": ctm.d_model,
|
| 359 |
+
"memory_length": ctm.memory_length,
|
| 360 |
+
"default_ticks": ctm.n_ticks,
|
| 361 |
+
"endpoints": [
|
| 362 |
+
"/sense_snn",
|
| 363 |
+
"/reason_hypergraph",
|
| 364 |
+
"/validate_physics",
|
| 365 |
+
"/dream",
|
| 366 |
+
"/calibrate_stdp"
|
| 367 |
+
]
|
| 368 |
+
}, indent=2)
|
| 369 |
+
|
| 370 |
+
# ============================================================================
|
| 371 |
+
# GRADIO INTERFACE
|
| 372 |
+
# ============================================================================
|
| 373 |
+
|
| 374 |
+
with gr.Blocks(title="CTM Nervous System") as demo:
|
| 375 |
+
gr.Markdown("""
|
| 376 |
+
# 🧬 CTM Nervous System
|
| 377 |
+
**Continuous Thought Machine for Hypergraph Maintenance**
|
| 378 |
+
|
| 379 |
+
Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
|
| 380 |
+
|
| 381 |
+
---
|
| 382 |
+
|
| 383 |
+
## Endpoints
|
| 384 |
+
- **/sense_snn**: Process 72D SNN input
|
| 385 |
+
- **/reason_hypergraph**: Reason about context, propose edges
|
| 386 |
+
- **/validate_physics**: Validate against 5 physics losses
|
| 387 |
+
- **/dream**: Offline consolidation (T=500+)
|
| 388 |
+
- **/calibrate_stdp**: Suggest STDP weight adjustments
|
| 389 |
+
""")
|
| 390 |
+
|
| 391 |
+
with gr.Tabs():
|
| 392 |
+
with gr.Tab("🔌 /sense_snn"):
|
| 393 |
+
gr.Markdown("Process 72D SNN input vector")
|
| 394 |
+
snn_input = gr.Textbox(
|
| 395 |
+
label="SNN JSON Input",
|
| 396 |
+
value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}}',
|
| 397 |
+
lines=5
|
| 398 |
+
)
|
| 399 |
+
snn_output = gr.Textbox(label="Output", lines=10)
|
| 400 |
+
snn_btn = gr.Button("Process", variant="primary")
|
| 401 |
+
snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
|
| 402 |
+
|
| 403 |
+
with gr.Tab("🧠 /reason_hypergraph"):
|
| 404 |
+
gr.Markdown("Reason about hypergraph context")
|
| 405 |
+
reason_input = gr.Textbox(
|
| 406 |
+
label="Context JSON",
|
| 407 |
+
value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
|
| 408 |
+
lines=5
|
| 409 |
+
)
|
| 410 |
+
reason_output = gr.Textbox(label="Output", lines=10)
|
| 411 |
+
reason_btn = gr.Button("Reason", variant="primary")
|
| 412 |
+
reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
|
| 413 |
+
|
| 414 |
+
with gr.Tab("⚡ /validate_physics"):
|
| 415 |
+
gr.Markdown("Validate against 5 physics losses")
|
| 416 |
+
physics_input = gr.Textbox(
|
| 417 |
+
label="Physics JSON",
|
| 418 |
+
value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
|
| 419 |
+
lines=5
|
| 420 |
+
)
|
| 421 |
+
physics_output = gr.Textbox(label="Output", lines=10)
|
| 422 |
+
physics_btn = gr.Button("Validate", variant="primary")
|
| 423 |
+
physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
|
| 424 |
+
|
| 425 |
+
with gr.Tab("💤 /dream"):
|
| 426 |
+
gr.Markdown("Offline consolidation")
|
| 427 |
+
dream_input = gr.Textbox(
|
| 428 |
+
label="Dream JSON",
|
| 429 |
+
value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
|
| 430 |
+
lines=5
|
| 431 |
+
)
|
| 432 |
+
dream_output = gr.Textbox(label="Output", lines=10)
|
| 433 |
+
dream_btn = gr.Button("Dream", variant="primary")
|
| 434 |
+
dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
|
| 435 |
+
|
| 436 |
+
with gr.Tab("🔧 /calibrate_stdp"):
|
| 437 |
+
gr.Markdown("Calibrate STDP weights")
|
| 438 |
+
stdp_input = gr.Textbox(
|
| 439 |
+
label="STDP JSON",
|
| 440 |
+
value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
|
| 441 |
+
lines=5
|
| 442 |
+
)
|
| 443 |
+
stdp_output = gr.Textbox(label="Output", lines=10)
|
| 444 |
+
stdp_btn = gr.Button("Calibrate", variant="primary")
|
| 445 |
+
stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
|
| 446 |
+
|
| 447 |
+
with gr.Tab("❤️ Health"):
|
| 448 |
+
health_output = gr.Textbox(label="Health Status", lines=10)
|
| 449 |
+
health_btn = gr.Button("Check Health", variant="secondary")
|
| 450 |
+
health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
|
| 451 |
+
|
| 452 |
+
gr.Markdown("""
|
| 453 |
+
---
|
| 454 |
+
**Architecture**: CTM as Nervous System → Hypergraph as Thought
|
| 455 |
+
|
| 456 |
+
**Training**: Min-Loss + Max-Certainty + Physics Regularization
|
| 457 |
+
""")
|
| 458 |
+
|
| 459 |
+
if __name__ == "__main__":
|
| 460 |
+
demo.launch(
|
| 461 |
+
server_name="0.0.0.0",
|
| 462 |
+
server_port=7860,
|
| 463 |
+
show_error=True
|
| 464 |
+
)
|
data/custom_datasets.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.datasets import ImageFolder
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
|
| 10 |
+
class SortDataset(Dataset):
|
| 11 |
+
def __init__(self, N):
|
| 12 |
+
self.N = N
|
| 13 |
+
def __len__(self):
|
| 14 |
+
return 10000000
|
| 15 |
+
def __getitem__(self, idx):
|
| 16 |
+
data = torch.zeros(self.N).normal_()
|
| 17 |
+
ordering = torch.argsort(data)
|
| 18 |
+
inputs = data
|
| 19 |
+
return (inputs), (ordering)
|
| 20 |
+
|
| 21 |
+
class QAMNISTDataset(Dataset):
|
| 22 |
+
"""A QAMNIST dataset that includes plus and minus operations on MNIST digits."""
|
| 23 |
+
def __init__(self, base_dataset, num_images, num_images_delta, num_repeats_per_input, num_operations, num_operations_delta):
|
| 24 |
+
self.base_dataset = base_dataset
|
| 25 |
+
|
| 26 |
+
self.num_images = num_images
|
| 27 |
+
self.num_images_delta = num_images_delta
|
| 28 |
+
self.num_images_range = self._calculate_num_images_range()
|
| 29 |
+
|
| 30 |
+
self.operators = ["+", "-"]
|
| 31 |
+
self.num_operations = num_operations
|
| 32 |
+
self.num_operations_delta = num_operations_delta
|
| 33 |
+
self.num_operations_range = self._calculate_num_operations_range()
|
| 34 |
+
|
| 35 |
+
self.num_repeats_per_input = num_repeats_per_input
|
| 36 |
+
|
| 37 |
+
self.current_num_digits = num_images
|
| 38 |
+
self.current_num_operations = num_operations
|
| 39 |
+
|
| 40 |
+
self.modulo_base = 10
|
| 41 |
+
|
| 42 |
+
self.output_range = [0, 9]
|
| 43 |
+
|
| 44 |
+
def _calculate_num_images_range(self):
|
| 45 |
+
min_val = self.num_images - self.num_images_delta
|
| 46 |
+
max_val = self.num_images + self.num_images_delta
|
| 47 |
+
assert min_val >= 1, f"Minimum number of images must be at least 1, got {min_val}"
|
| 48 |
+
return [min_val, max_val]
|
| 49 |
+
|
| 50 |
+
def _calculate_num_operations_range(self):
|
| 51 |
+
min_val = self.num_operations - self.num_operations_delta
|
| 52 |
+
max_val = self.num_operations + self.num_operations_delta
|
| 53 |
+
assert min_val >= 1, f"Minimum number of operations must be at least 1, got {min_val}"
|
| 54 |
+
return [min_val, max_val]
|
| 55 |
+
|
| 56 |
+
def set_num_digits(self, num_digits):
|
| 57 |
+
self.current_num_digits = num_digits
|
| 58 |
+
|
| 59 |
+
def set_num_operations(self, num_operations):
|
| 60 |
+
self.current_num_operations = num_operations
|
| 61 |
+
|
| 62 |
+
def _get_target_and_question(self, targets):
|
| 63 |
+
question = []
|
| 64 |
+
equations = []
|
| 65 |
+
num_digits = self.current_num_digits
|
| 66 |
+
num_operations = self.current_num_operations
|
| 67 |
+
|
| 68 |
+
# Select the initial digit
|
| 69 |
+
selection_idx = np.random.randint(num_digits)
|
| 70 |
+
first_digit = targets[selection_idx]
|
| 71 |
+
question.extend([selection_idx] * self.num_repeats_per_input)
|
| 72 |
+
# Set current_value to the initial digit (mod is applied in each operation)
|
| 73 |
+
current_value = first_digit % self.modulo_base
|
| 74 |
+
|
| 75 |
+
# For each operation, build an equation line
|
| 76 |
+
for _ in range(num_operations):
|
| 77 |
+
# Choose the operator ('+' or '-')
|
| 78 |
+
operator_idx = np.random.randint(len(self.operators))
|
| 79 |
+
operator = self.operators[operator_idx]
|
| 80 |
+
encoded_operator = -(operator_idx + 1) # -1 for '+', -2 for '-'
|
| 81 |
+
question.extend([encoded_operator] * self.num_repeats_per_input)
|
| 82 |
+
|
| 83 |
+
# Choose the next digit
|
| 84 |
+
selection_idx = np.random.randint(num_digits)
|
| 85 |
+
digit = targets[selection_idx]
|
| 86 |
+
question.extend([selection_idx] * self.num_repeats_per_input)
|
| 87 |
+
|
| 88 |
+
# Compute the new value with immediate modulo reduction
|
| 89 |
+
if operator == '+':
|
| 90 |
+
new_value = (current_value + digit) % self.modulo_base
|
| 91 |
+
else: # operator is '-'
|
| 92 |
+
new_value = (current_value - digit) % self.modulo_base
|
| 93 |
+
|
| 94 |
+
# Build the equation string for this step
|
| 95 |
+
equations.append(f"({current_value} {operator} {digit}) mod {self.modulo_base} = {new_value}")
|
| 96 |
+
# Update current value for the next operation
|
| 97 |
+
current_value = new_value
|
| 98 |
+
|
| 99 |
+
target = current_value
|
| 100 |
+
question_readable = "\n".join(equations)
|
| 101 |
+
return target, question, question_readable
|
| 102 |
+
|
| 103 |
+
def __len__(self):
|
| 104 |
+
return len(self.base_dataset)
|
| 105 |
+
|
| 106 |
+
def __getitem__(self, idx):
|
| 107 |
+
images, targets = [],[]
|
| 108 |
+
for _ in range(self.current_num_digits):
|
| 109 |
+
image, target = self.base_dataset[np.random.randint(self.__len__())]
|
| 110 |
+
images.append(image)
|
| 111 |
+
targets.append(target)
|
| 112 |
+
|
| 113 |
+
observations = torch.repeat_interleave(torch.stack(images, 0), repeats=self.num_repeats_per_input, dim=0)
|
| 114 |
+
target, question, question_readable = self._get_target_and_question(targets)
|
| 115 |
+
return observations, question, question_readable, target
|
| 116 |
+
|
| 117 |
+
class ImageNet(Dataset):
|
| 118 |
+
def __init__(self, which_split, transform):
|
| 119 |
+
"""
|
| 120 |
+
Most simple form of the custom dataset structure.
|
| 121 |
+
Args:
|
| 122 |
+
base_dataset (Dataset): The base dataset to sample from.
|
| 123 |
+
N (int): The number of images to construct into an observable sequence.
|
| 124 |
+
R (int): number of repeats
|
| 125 |
+
operators (list): list of operators from which to sample
|
| 126 |
+
action to take on observations (str): can be 'global' to compute operator over full observations, or 'select_K', where K=integer.
|
| 127 |
+
"""
|
| 128 |
+
dataset = load_dataset('imagenet-1k', split=which_split, trust_remote_code=True)
|
| 129 |
+
|
| 130 |
+
self.transform = transform
|
| 131 |
+
self.base_dataset = dataset
|
| 132 |
+
|
| 133 |
+
def __len__(self):
|
| 134 |
+
return len(self.base_dataset)
|
| 135 |
+
|
| 136 |
+
def __getitem__(self, idx):
|
| 137 |
+
data_item = self.base_dataset[idx]
|
| 138 |
+
image = self.transform(data_item['image'].convert('RGB'))
|
| 139 |
+
target = data_item['label']
|
| 140 |
+
return image, target
|
| 141 |
+
|
| 142 |
+
class MazeImageFolder(ImageFolder):
|
| 143 |
+
"""
|
| 144 |
+
A custom dataset class that extends the ImageFolder class.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
root (string): Root directory path.
|
| 148 |
+
transform (callable, optional): A function/transform that takes in
|
| 149 |
+
a sample and returns a transformed version.
|
| 150 |
+
E.g, ``transforms.RandomCrop`` for images.
|
| 151 |
+
target_transform (callable, optional): A function/transform that takes
|
| 152 |
+
in the target and transforms it.
|
| 153 |
+
loader (callable, optional): A function to load an image given its path.
|
| 154 |
+
is_valid_file (callable, optional): A function that takes path of an Image file
|
| 155 |
+
and check if the file is a valid file (used to check of corrupt files)
|
| 156 |
+
|
| 157 |
+
Attributes:
|
| 158 |
+
classes (list): List of the class names.
|
| 159 |
+
class_to_idx (dict): Dict with items (class_name, class_index).
|
| 160 |
+
imgs (list): List of (image path, class_index) tuples
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def __init__(self, root, transform=None, target_transform=None,
|
| 164 |
+
loader=Image.open,
|
| 165 |
+
is_valid_file=None,
|
| 166 |
+
which_set='train',
|
| 167 |
+
augment_p=0.5,
|
| 168 |
+
maze_route_length=10,
|
| 169 |
+
trunc=False,
|
| 170 |
+
expand_range=True):
|
| 171 |
+
super(MazeImageFolder, self).__init__(root, transform, target_transform, loader, is_valid_file)
|
| 172 |
+
self.which_set = which_set
|
| 173 |
+
self.augment_p = augment_p
|
| 174 |
+
self.maze_route_length = maze_route_length
|
| 175 |
+
self.all_paths = {}
|
| 176 |
+
self.trunc = trunc
|
| 177 |
+
self.expand_range = expand_range
|
| 178 |
+
|
| 179 |
+
self._preload()
|
| 180 |
+
print('Solving all mazes...')
|
| 181 |
+
for index in range(len(self.preloaded_samples)):
|
| 182 |
+
path = self.get_solution(self.preloaded_samples[index])
|
| 183 |
+
self.all_paths[index] = path
|
| 184 |
+
|
| 185 |
+
def _preload(self):
|
| 186 |
+
preloaded_samples = []
|
| 187 |
+
with tqdm(total=self.__len__(), initial=0, leave=True, position=0, dynamic_ncols=True) as pbar:
|
| 188 |
+
|
| 189 |
+
for index in range(self.__len__()):
|
| 190 |
+
pbar.set_description('Loading mazes')
|
| 191 |
+
path, target = self.samples[index]
|
| 192 |
+
sample = self.loader(path)
|
| 193 |
+
sample = np.array(sample).astype(np.float32)/255
|
| 194 |
+
preloaded_samples.append(sample)
|
| 195 |
+
pbar.update(1)
|
| 196 |
+
if self.trunc and index == 999: break
|
| 197 |
+
self.preloaded_samples = preloaded_samples
|
| 198 |
+
|
| 199 |
+
def __len__(self):
|
| 200 |
+
if hasattr(self, 'preloaded_samples') and self.preloaded_samples is not None:
|
| 201 |
+
return len(self.preloaded_samples)
|
| 202 |
+
else:
|
| 203 |
+
return super().__len__()
|
| 204 |
+
|
| 205 |
+
def get_solution(self, x):
|
| 206 |
+
x = np.copy(x)
|
| 207 |
+
# Find start (red) and end (green) pixel coordinates
|
| 208 |
+
start_coords = np.argwhere((x == [1, 0, 0]).all(axis=2))
|
| 209 |
+
end_coords = np.argwhere((x == [0, 1, 0]).all(axis=2))
|
| 210 |
+
|
| 211 |
+
if len(start_coords) == 0 or len(end_coords) == 0:
|
| 212 |
+
print("Start or end point not found.")
|
| 213 |
+
return None
|
| 214 |
+
|
| 215 |
+
start_y, start_x = start_coords[0]
|
| 216 |
+
end_y, end_x = end_coords[0]
|
| 217 |
+
|
| 218 |
+
current_y, current_x = start_y, start_x
|
| 219 |
+
path = [4] * self.maze_route_length
|
| 220 |
+
|
| 221 |
+
pi = 0
|
| 222 |
+
while (current_y, current_x) != (end_y, end_x):
|
| 223 |
+
next_y, next_x = -1, -1 # Initialize to invalid coordinates
|
| 224 |
+
direction = -1 # Initialize to an invalid direction
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# Check Up
|
| 228 |
+
if current_y > 0 and ((x[current_y - 1, current_x] == [0, 0, 1]).all() or (x[current_y - 1, current_x] == [0, 1, 0]).all()):
|
| 229 |
+
next_y, next_x = current_y - 1, current_x
|
| 230 |
+
direction = 0
|
| 231 |
+
|
| 232 |
+
# Check Down
|
| 233 |
+
elif current_y < x.shape[0] - 1 and ((x[current_y + 1, current_x] == [0, 0, 1]).all() or (x[current_y + 1, current_x] == [0, 1, 0]).all()):
|
| 234 |
+
next_y, next_x = current_y + 1, current_x
|
| 235 |
+
direction = 1
|
| 236 |
+
|
| 237 |
+
# Check Left
|
| 238 |
+
elif current_x > 0 and ((x[current_y, current_x - 1] == [0, 0, 1]).all() or (x[current_y, current_x - 1] == [0, 1, 0]).all()):
|
| 239 |
+
next_y, next_x = current_y, current_x - 1
|
| 240 |
+
direction = 2
|
| 241 |
+
|
| 242 |
+
# Check Right
|
| 243 |
+
elif current_x < x.shape[1] - 1 and ((x[current_y, current_x + 1] == [0, 0, 1]).all() or (x[current_y, current_x + 1] == [0, 1, 0]).all()):
|
| 244 |
+
next_y, next_x = current_y, current_x + 1
|
| 245 |
+
direction = 3
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
path[pi] = direction
|
| 249 |
+
pi += 1
|
| 250 |
+
|
| 251 |
+
x[current_y, current_x] = [255,255,255] # mark the current as white to avoid going in circles
|
| 252 |
+
current_y, current_x = next_y, next_x
|
| 253 |
+
if pi == len(path):
|
| 254 |
+
break
|
| 255 |
+
|
| 256 |
+
return np.array(path)
|
| 257 |
+
|
| 258 |
+
def __getitem__(self, index):
|
| 259 |
+
"""
|
| 260 |
+
Args:
|
| 261 |
+
index (int): Index
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
tuple: (sample, target) where target is class_index of the target class.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
sample = np.copy(self.preloaded_samples[index])
|
| 268 |
+
|
| 269 |
+
path = np.copy(self.all_paths[index])
|
| 270 |
+
|
| 271 |
+
if self.which_set == 'train':
|
| 272 |
+
# Randomly rotate -90 or +90 degrees
|
| 273 |
+
if random.random() < self.augment_p:
|
| 274 |
+
which_rot = random.choice([-1, 1])
|
| 275 |
+
sample = np.rot90(sample, k=which_rot, axes=(0, 1))
|
| 276 |
+
for pi in range(len(path)):
|
| 277 |
+
if path[pi] == 0: path[pi] = 3 if which_rot == -1 else 2
|
| 278 |
+
elif path[pi] == 1: path[pi] = 2 if which_rot == -1 else 3
|
| 279 |
+
elif path[pi] == 2: path[pi] = 0 if which_rot == -1 else 1
|
| 280 |
+
elif path[pi] == 3: path[pi] = 1 if which_rot == -1 else 0
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# Random horizontal flip
|
| 284 |
+
if random.random() < self.augment_p:
|
| 285 |
+
sample = np.fliplr(sample)
|
| 286 |
+
for pi in range(len(path)):
|
| 287 |
+
if path[pi] == 2: path[pi] = 3
|
| 288 |
+
elif path[pi] == 3: path[pi] = 2
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# Random vertical flip
|
| 292 |
+
if random.random() < self.augment_p:
|
| 293 |
+
sample = np.flipud(sample)
|
| 294 |
+
for pi in range(len(path)):
|
| 295 |
+
if path[pi] == 0: path[pi] = 1
|
| 296 |
+
elif path[pi] == 1: path[pi] = 0
|
| 297 |
+
|
| 298 |
+
sample = torch.from_numpy(np.copy(sample)).permute(2,0,1)
|
| 299 |
+
|
| 300 |
+
blue_mask = (sample[0] == 0) & (sample[1] == 0) & (sample[2] == 1)
|
| 301 |
+
|
| 302 |
+
sample[:, blue_mask] = 1
|
| 303 |
+
target = path
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
if not self.expand_range:
|
| 307 |
+
return sample, target
|
| 308 |
+
return (sample*2)-1, (target)
|
| 309 |
+
|
| 310 |
+
class ParityDataset(Dataset):
|
| 311 |
+
def __init__(self, sequence_length=64, length=100000):
|
| 312 |
+
self.sequence_length = sequence_length
|
| 313 |
+
self.length = length
|
| 314 |
+
|
| 315 |
+
def __len__(self):
|
| 316 |
+
return self.length
|
| 317 |
+
|
| 318 |
+
def __getitem__(self, idx):
|
| 319 |
+
vector = 2 * torch.randint(0, 2, (self.sequence_length,)) - 1
|
| 320 |
+
vector = vector.float()
|
| 321 |
+
negatives = (vector == -1).to(torch.long)
|
| 322 |
+
cumsum = torch.cumsum(negatives, dim=0)
|
| 323 |
+
target = (cumsum % 2 != 0).to(torch.long)
|
| 324 |
+
return vector, target
|
examples/01_mnist.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/02_inference.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/03_mazes.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/04_parity.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/05_huggingface.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/README.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Continuous Thought Machines
|
| 2 |
+
## Models
|
| 3 |
+
|
| 4 |
+
This folder contains all model-related code.
|
| 5 |
+
|
| 6 |
+
Some notes for clarity:
|
| 7 |
+
1. The resnet structure we used (see resnet.py) has a few minor changes that enable constraining the receptive field of the features yielded. We do this because we want the CTM (or baseline methods) to learn a process whereby they gather information. Neural networks that use SGD will find the [path of least resistence](https://era.ed.ac.uk/handle/1842/39606), even if that path doesn't result in actually intelligent behaviour. Constraining the receptive field helps to prevent this, a bit.
|
models/constants.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
VALID_NEURON_SELECT_TYPES = ['first-last', 'random', 'random-pairing']
|
| 2 |
+
|
| 3 |
+
VALID_BACKBONE_TYPES = [
|
| 4 |
+
f'resnet{depth}-{i}' for depth in [18, 34, 50, 101, 152] for i in range(1, 5)
|
| 5 |
+
] + ['shallow-wide', 'parity_backbone']
|
| 6 |
+
|
| 7 |
+
VALID_POSITIONAL_EMBEDDING_TYPES = [
|
| 8 |
+
'learnable-fourier', 'multi-learnable-fourier',
|
| 9 |
+
'custom-rotational', 'custom-rotational-1d'
|
| 10 |
+
]
|
models/ctm.py
ADDED
|
@@ -0,0 +1,633 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import math
|
| 5 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
| 6 |
+
|
| 7 |
+
from models.modules import ParityBackbone, SynapseUNET, Squeeze, SuperLinear, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
|
| 8 |
+
from models.resnet import prepare_resnet_backbone
|
| 9 |
+
from models.utils import compute_normalized_entropy
|
| 10 |
+
|
| 11 |
+
from models.constants import (
|
| 12 |
+
VALID_NEURON_SELECT_TYPES,
|
| 13 |
+
VALID_BACKBONE_TYPES,
|
| 14 |
+
VALID_POSITIONAL_EMBEDDING_TYPES
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
|
| 18 |
+
"""
|
| 19 |
+
Continuous Thought Machine (CTM).
|
| 20 |
+
|
| 21 |
+
Technical report: https://arxiv.org/abs/2505.05522
|
| 22 |
+
|
| 23 |
+
Interactive Website: https://pub.sakana.ai/ctm/
|
| 24 |
+
|
| 25 |
+
Blog: https://sakana.ai/ctm/
|
| 26 |
+
|
| 27 |
+
Thought takes time and reasoning is a process.
|
| 28 |
+
|
| 29 |
+
The CTM consists of three main ideas:
|
| 30 |
+
1. The use of internal recurrence, enabling a dimension over which a concept analogous to thought can occur.
|
| 31 |
+
1. Neuron-level models, that compute post-activations by applying private (i.e., on a per-neuron basis) MLP
|
| 32 |
+
models to a history of incoming pre-activations.
|
| 33 |
+
2. Synchronisation as representation, where the neural activity over time is tracked and used to compute how
|
| 34 |
+
pairs of neurons synchronise with one another over time. This measure of synchronisation is the representation
|
| 35 |
+
with which the CTM takes action and makes predictions.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
iterations (int): Number of internal 'thought' ticks (T, in paper).
|
| 40 |
+
d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
|
| 41 |
+
NOTE: Note that this is NOT the representation used for action or prediction, but rather that which
|
| 42 |
+
is fully internal to the model and not directly connected to data.
|
| 43 |
+
d_input (int): Dimensionality of projected attention outputs or direct input features.
|
| 44 |
+
heads (int): Number of attention heads.
|
| 45 |
+
n_synch_out (int): Number of neurons used for output synchronisation (D_out, in paper).
|
| 46 |
+
n_synch_action (int): Number of neurons used for action/attention synchronisation (D_action, in paper).
|
| 47 |
+
synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP).
|
| 48 |
+
memory_length (int): History length for Neuron-Level Models (M, in paper).
|
| 49 |
+
deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear.
|
| 50 |
+
NOTE: we almost always use deep NLMs, but a linear NLM is faster.
|
| 51 |
+
memory_hidden_dims (int): Hidden dimension size for deep NLMs.
|
| 52 |
+
do_layernorm_nlm (bool): Apply LayerNorm within NLMs.
|
| 53 |
+
NOTE: we never set this to true in the paper. If you set this to true you will get strange behaviour,
|
| 54 |
+
but you can potentially encourage more periodic behaviour in the dynamics. Untested; be careful.
|
| 55 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 56 |
+
positional_embedding_type (str): Type of positional embedding for backbone features.
|
| 57 |
+
out_dims (int): Output dimension size.
|
| 58 |
+
NOTE: projected from synchronisation!
|
| 59 |
+
prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
|
| 60 |
+
NOTE: this is used to compute certainty and is needed when applying softmax for probabilities
|
| 61 |
+
dropout (float): Dropout rate.
|
| 62 |
+
neuron_select_type (str): Neuron selection strategy ('first-last', 'random', 'random-pairing').
|
| 63 |
+
NOTE: some of this is legacy from our experimentation, but all three strategies are valid and useful.
|
| 64 |
+
We dilineate exactly which strategies we use per experiment in the paper.
|
| 65 |
+
- first-last: build a 'dense' sync matrix for output from the first D_out neurons and action from the
|
| 66 |
+
last D_action neurons. Flatten this matrix into the synchronisation representation.
|
| 67 |
+
This approach shares relationships for neurons and bottlenecks the gradients through them.
|
| 68 |
+
NOTE: the synchronisation size will be (D_out/action * (D_out/action + 1))/2
|
| 69 |
+
- random: randomly select D_out neurons for the 'i' side pairings, and also D_out for the 'j' side pairings,
|
| 70 |
+
also pairing those accross densely, resulting in a bottleneck roughly 2x as wide.
|
| 71 |
+
NOTE: the synchronisation size will be (D_out/action * (D_out/action + 1))/2
|
| 72 |
+
- random-pairing (DEFAULT!): randomly select D_out neurons and pair these with another D_out neurons.
|
| 73 |
+
This results in much less bottlenecking and is the most up-to-date variant.
|
| 74 |
+
NOTE: the synchronisation size will be D_out in this case; better control.
|
| 75 |
+
n_random_pairing_self (int): Number of neurons to select for self-to-self synch when random-pairing is used.
|
| 76 |
+
NOTE: when using random-pairing, i-to-i (self) synchronisation is rare, meaning that 'recovering a
|
| 77 |
+
snapshot representation' (see paper) is difficult. This alleviates that.
|
| 78 |
+
NOTE: works fine when set to 0.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def __init__(self,
|
| 82 |
+
iterations,
|
| 83 |
+
d_model,
|
| 84 |
+
d_input,
|
| 85 |
+
heads,
|
| 86 |
+
n_synch_out,
|
| 87 |
+
n_synch_action,
|
| 88 |
+
synapse_depth,
|
| 89 |
+
memory_length,
|
| 90 |
+
deep_nlms,
|
| 91 |
+
memory_hidden_dims,
|
| 92 |
+
do_layernorm_nlm,
|
| 93 |
+
backbone_type,
|
| 94 |
+
positional_embedding_type,
|
| 95 |
+
out_dims,
|
| 96 |
+
prediction_reshaper=[-1],
|
| 97 |
+
dropout=0,
|
| 98 |
+
dropout_nlm=None,
|
| 99 |
+
neuron_select_type='random-pairing',
|
| 100 |
+
n_random_pairing_self=0,
|
| 101 |
+
):
|
| 102 |
+
super(ContinuousThoughtMachine, self).__init__()
|
| 103 |
+
|
| 104 |
+
# --- Core Parameters ---
|
| 105 |
+
self.iterations = iterations
|
| 106 |
+
self.d_model = d_model
|
| 107 |
+
self.d_input = d_input
|
| 108 |
+
self.memory_length = memory_length
|
| 109 |
+
self.prediction_reshaper = prediction_reshaper
|
| 110 |
+
self.n_synch_out = n_synch_out
|
| 111 |
+
self.n_synch_action = n_synch_action
|
| 112 |
+
self.backbone_type = backbone_type
|
| 113 |
+
self.out_dims = out_dims
|
| 114 |
+
self.positional_embedding_type = positional_embedding_type
|
| 115 |
+
self.neuron_select_type = neuron_select_type
|
| 116 |
+
self.memory_length = memory_length
|
| 117 |
+
dropout_nlm = dropout if dropout_nlm is None else dropout_nlm
|
| 118 |
+
|
| 119 |
+
# --- Assertions ---
|
| 120 |
+
self.verify_args()
|
| 121 |
+
|
| 122 |
+
# --- Input Processing ---
|
| 123 |
+
d_backbone = self.get_d_backbone()
|
| 124 |
+
self.set_initial_rgb()
|
| 125 |
+
self.set_backbone()
|
| 126 |
+
self.positional_embedding = self.get_positional_embedding(d_backbone)
|
| 127 |
+
self.kv_proj = nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input)) if heads else None
|
| 128 |
+
self.q_proj = nn.LazyLinear(self.d_input) if heads else None
|
| 129 |
+
self.attention = nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True) if heads else None
|
| 130 |
+
|
| 131 |
+
# --- Core CTM Modules ---
|
| 132 |
+
self.synapses = self.get_synapses(synapse_depth, d_model, dropout)
|
| 133 |
+
self.trace_processor = self.get_neuron_level_models(deep_nlms, do_layernorm_nlm, memory_length, memory_hidden_dims, d_model, dropout_nlm)
|
| 134 |
+
|
| 135 |
+
# --- Start States ---
|
| 136 |
+
self.register_parameter('start_activated_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model)))))
|
| 137 |
+
self.register_parameter('start_trace', nn.Parameter(torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length)))))
|
| 138 |
+
|
| 139 |
+
# --- Synchronisation ---
|
| 140 |
+
self.neuron_select_type_out, self.neuron_select_type_action = self.get_neuron_select_type()
|
| 141 |
+
self.synch_representation_size_action = self.calculate_synch_representation_size(self.n_synch_action)
|
| 142 |
+
self.synch_representation_size_out = self.calculate_synch_representation_size(self.n_synch_out)
|
| 143 |
+
|
| 144 |
+
for synch_type, size in (('action', self.synch_representation_size_action), ('out', self.synch_representation_size_out)):
|
| 145 |
+
print(f"Synch representation size {synch_type}: {size}")
|
| 146 |
+
if self.synch_representation_size_action: # if not zero
|
| 147 |
+
self.set_synchronisation_parameters('action', self.n_synch_action, n_random_pairing_self)
|
| 148 |
+
self.set_synchronisation_parameters('out', self.n_synch_out, n_random_pairing_self)
|
| 149 |
+
|
| 150 |
+
# --- Output Procesing ---
|
| 151 |
+
self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))
|
| 152 |
+
|
| 153 |
+
@classmethod
|
| 154 |
+
def _from_pretrained(
|
| 155 |
+
cls,
|
| 156 |
+
*,
|
| 157 |
+
model_id: str,
|
| 158 |
+
revision=None,
|
| 159 |
+
cache_dir=None,
|
| 160 |
+
force_download=False,
|
| 161 |
+
proxies=None,
|
| 162 |
+
resume_download=None,
|
| 163 |
+
local_files_only=False,
|
| 164 |
+
token=None,
|
| 165 |
+
map_location="cpu",
|
| 166 |
+
strict=False,
|
| 167 |
+
**model_kwargs,
|
| 168 |
+
):
|
| 169 |
+
"""Override to handle lazy weights initialization."""
|
| 170 |
+
model = cls(**model_kwargs).to(map_location)
|
| 171 |
+
|
| 172 |
+
# The CTM contains Lazy modules, so we must run a dummy forward pass to initialize them
|
| 173 |
+
if "imagenet" in model_id:
|
| 174 |
+
dummy_input = torch.randn(1, 3, 224, 224, device=map_location)
|
| 175 |
+
elif "maze-large" in model_id:
|
| 176 |
+
dummy_input = torch.randn(1, 3, 99, 99, device=map_location)
|
| 177 |
+
else:
|
| 178 |
+
raise NotImplementedError
|
| 179 |
+
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
_ = model(dummy_input)
|
| 182 |
+
|
| 183 |
+
model_file = hf_hub_download(
|
| 184 |
+
repo_id=model_id,
|
| 185 |
+
filename="model.safetensors",
|
| 186 |
+
revision=revision,
|
| 187 |
+
cache_dir=cache_dir,
|
| 188 |
+
force_download=force_download,
|
| 189 |
+
proxies=proxies,
|
| 190 |
+
resume_download=resume_download,
|
| 191 |
+
token=token,
|
| 192 |
+
local_files_only=local_files_only,
|
| 193 |
+
)
|
| 194 |
+
from safetensors.torch import load_model as load_model_as_safetensor
|
| 195 |
+
load_model_as_safetensor(model, model_file, strict=strict, device=map_location)
|
| 196 |
+
|
| 197 |
+
model.eval()
|
| 198 |
+
return model
|
| 199 |
+
|
| 200 |
+
# --- Core CTM Methods ---
|
| 201 |
+
|
| 202 |
+
def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, r, synch_type):
|
| 203 |
+
"""
|
| 204 |
+
Computes synchronisation to be used as a vector representation.
|
| 205 |
+
|
| 206 |
+
A neuron has what we call a 'trace', which is a history (time series) that changes with internal
|
| 207 |
+
recurrence. i.e., it gets longer with every internal tick. There are pre-activation traces
|
| 208 |
+
that are used in the NLMs and post-activation traces that, in theory, are used in this method.
|
| 209 |
+
|
| 210 |
+
We define sychronisation between neuron i and j as the dot product between their respective
|
| 211 |
+
time series. Since there can be many internal ticks, this process can be quite compute heavy as it
|
| 212 |
+
involves many dot products that repeat computation at each step.
|
| 213 |
+
|
| 214 |
+
Therefore, in practice, we update the synchronisation based on the current post-activations,
|
| 215 |
+
which we call the 'activated state' here. This is possible because the inputs to synchronisation
|
| 216 |
+
are only updated recurrently at each step, meaning that there is a linear recurrence we can
|
| 217 |
+
leverage.
|
| 218 |
+
|
| 219 |
+
See Appendix TODO of the Technical Report (TODO:LINK) for the maths that enables this method.
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
if synch_type == 'action': # Get action parameters
|
| 223 |
+
n_synch = self.n_synch_action
|
| 224 |
+
neuron_indices_left = self.action_neuron_indices_left
|
| 225 |
+
neuron_indices_right = self.action_neuron_indices_right
|
| 226 |
+
elif synch_type == 'out': # Get input parameters
|
| 227 |
+
n_synch = self.n_synch_out
|
| 228 |
+
neuron_indices_left = self.out_neuron_indices_left
|
| 229 |
+
neuron_indices_right = self.out_neuron_indices_right
|
| 230 |
+
|
| 231 |
+
if self.neuron_select_type in ('first-last', 'random'):
|
| 232 |
+
# For first-last and random, we compute the pairwise sync between all selected neurons
|
| 233 |
+
if self.neuron_select_type == 'first-last':
|
| 234 |
+
if synch_type == 'action': # Use last n_synch neurons for action
|
| 235 |
+
selected_left = selected_right = activated_state[:, -n_synch:]
|
| 236 |
+
elif synch_type == 'out': # Use first n_synch neurons for out
|
| 237 |
+
selected_left = selected_right = activated_state[:, :n_synch]
|
| 238 |
+
else: # Use the randomly selected neurons
|
| 239 |
+
selected_left = activated_state[:, neuron_indices_left]
|
| 240 |
+
selected_right = activated_state[:, neuron_indices_right]
|
| 241 |
+
|
| 242 |
+
# Compute outer product of selected neurons
|
| 243 |
+
outer = selected_left.unsqueeze(2) * selected_right.unsqueeze(1)
|
| 244 |
+
# Resulting matrix is symmetric, so we only need the upper triangle
|
| 245 |
+
i, j = torch.triu_indices(n_synch, n_synch)
|
| 246 |
+
pairwise_product = outer[:, i, j]
|
| 247 |
+
|
| 248 |
+
elif self.neuron_select_type == 'random-pairing':
|
| 249 |
+
# For random-pairing, we compute the sync between specific pairs of neurons
|
| 250 |
+
left = activated_state[:, neuron_indices_left]
|
| 251 |
+
right = activated_state[:, neuron_indices_right]
|
| 252 |
+
pairwise_product = left * right
|
| 253 |
+
else:
|
| 254 |
+
raise ValueError("Invalid neuron selection type")
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# Compute synchronisation recurrently
|
| 259 |
+
if decay_alpha is None or decay_beta is None:
|
| 260 |
+
decay_alpha = pairwise_product
|
| 261 |
+
decay_beta = torch.ones_like(pairwise_product)
|
| 262 |
+
else:
|
| 263 |
+
decay_alpha = r * decay_alpha + pairwise_product
|
| 264 |
+
decay_beta = r * decay_beta + 1
|
| 265 |
+
|
| 266 |
+
synchronisation = decay_alpha / (torch.sqrt(decay_beta))
|
| 267 |
+
return synchronisation, decay_alpha, decay_beta
|
| 268 |
+
|
| 269 |
+
def compute_features(self, x):
|
| 270 |
+
"""
|
| 271 |
+
Compute the key-value features from the input data using the backbone.
|
| 272 |
+
"""
|
| 273 |
+
initial_rgb = self.initial_rgb(x)
|
| 274 |
+
self.kv_features = self.backbone(initial_rgb)
|
| 275 |
+
pos_emb = self.positional_embedding(self.kv_features)
|
| 276 |
+
combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2)
|
| 277 |
+
kv = self.kv_proj(combined_features)
|
| 278 |
+
return kv
|
| 279 |
+
|
| 280 |
+
def compute_certainty(self, current_prediction):
|
| 281 |
+
"""
|
| 282 |
+
Compute the certainty of the current prediction.
|
| 283 |
+
|
| 284 |
+
We define certainty as being 1-normalised entropy.
|
| 285 |
+
|
| 286 |
+
For legacy reasons we stack that in a 2D vector as this can be used for optimisation later.
|
| 287 |
+
"""
|
| 288 |
+
B = current_prediction.size(0)
|
| 289 |
+
reshaped_pred = current_prediction.reshape([B] + self.prediction_reshaper)
|
| 290 |
+
ne = compute_normalized_entropy(reshaped_pred)
|
| 291 |
+
current_certainty = torch.stack((ne, 1-ne), -1)
|
| 292 |
+
return current_certainty
|
| 293 |
+
|
| 294 |
+
# --- Setup Methods ---
|
| 295 |
+
|
| 296 |
+
def set_initial_rgb(self):
|
| 297 |
+
"""
|
| 298 |
+
This is largely to accommodate training on grayscale images and is legacy, but it
|
| 299 |
+
doesn't hurt the model in any way that we can tell.
|
| 300 |
+
"""
|
| 301 |
+
if 'resnet' in self.backbone_type:
|
| 302 |
+
self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
|
| 303 |
+
else:
|
| 304 |
+
self.initial_rgb = nn.Identity()
|
| 305 |
+
|
| 306 |
+
def get_d_backbone(self):
|
| 307 |
+
"""
|
| 308 |
+
Get the dimensionality of the backbone output, to be used for positional embedding setup.
|
| 309 |
+
|
| 310 |
+
This is a little bit complicated for resnets, but the logic should be easy enough to read below.
|
| 311 |
+
"""
|
| 312 |
+
if self.backbone_type == 'shallow-wide':
|
| 313 |
+
return 2048
|
| 314 |
+
elif self.backbone_type == 'parity_backbone':
|
| 315 |
+
return self.d_input
|
| 316 |
+
elif 'resnet' in self.backbone_type:
|
| 317 |
+
if '18' in self.backbone_type or '34' in self.backbone_type:
|
| 318 |
+
if self.backbone_type.split('-')[1]=='1': return 64
|
| 319 |
+
elif self.backbone_type.split('-')[1]=='2': return 128
|
| 320 |
+
elif self.backbone_type.split('-')[1]=='3': return 256
|
| 321 |
+
elif self.backbone_type.split('-')[1]=='4': return 512
|
| 322 |
+
else:
|
| 323 |
+
raise NotImplementedError
|
| 324 |
+
else:
|
| 325 |
+
if self.backbone_type.split('-')[1]=='1': return 256
|
| 326 |
+
elif self.backbone_type.split('-')[1]=='2': return 512
|
| 327 |
+
elif self.backbone_type.split('-')[1]=='3': return 1024
|
| 328 |
+
elif self.backbone_type.split('-')[1]=='4': return 2048
|
| 329 |
+
else:
|
| 330 |
+
raise NotImplementedError
|
| 331 |
+
elif self.backbone_type == 'none':
|
| 332 |
+
return None
|
| 333 |
+
else:
|
| 334 |
+
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
|
| 335 |
+
|
| 336 |
+
def set_backbone(self):
|
| 337 |
+
"""
|
| 338 |
+
Set the backbone module based on the specified type.
|
| 339 |
+
"""
|
| 340 |
+
if self.backbone_type == 'shallow-wide':
|
| 341 |
+
self.backbone = ShallowWide()
|
| 342 |
+
elif self.backbone_type == 'parity_backbone':
|
| 343 |
+
d_backbone = self.get_d_backbone()
|
| 344 |
+
self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone)
|
| 345 |
+
elif 'resnet' in self.backbone_type:
|
| 346 |
+
self.backbone = prepare_resnet_backbone(self.backbone_type)
|
| 347 |
+
elif self.backbone_type == 'none':
|
| 348 |
+
self.backbone = nn.Identity()
|
| 349 |
+
else:
|
| 350 |
+
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
|
| 351 |
+
|
| 352 |
+
def get_positional_embedding(self, d_backbone):
|
| 353 |
+
"""
|
| 354 |
+
Get the positional embedding module.
|
| 355 |
+
|
| 356 |
+
For Imagenet and mazes we used NO positional embedding, and largely don't think
|
| 357 |
+
that it is necessary as the CTM can build up its own internal world model when
|
| 358 |
+
observing.
|
| 359 |
+
|
| 360 |
+
LearnableFourierPositionalEncoding:
|
| 361 |
+
Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
|
| 362 |
+
Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
|
| 363 |
+
Provides positional information for 2D feature maps.
|
| 364 |
+
|
| 365 |
+
(MultiLearnableFourierPositionalEncoding uses multiple feature scales)
|
| 366 |
+
|
| 367 |
+
CustomRotationalEmbedding:
|
| 368 |
+
Simple sinusoidal embedding to encourage interpretability
|
| 369 |
+
"""
|
| 370 |
+
if self.positional_embedding_type == 'learnable-fourier':
|
| 371 |
+
return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5)
|
| 372 |
+
elif self.positional_embedding_type == 'multi-learnable-fourier':
|
| 373 |
+
return MultiLearnableFourierPositionalEncoding(d_backbone)
|
| 374 |
+
elif self.positional_embedding_type == 'custom-rotational':
|
| 375 |
+
return CustomRotationalEmbedding(d_backbone)
|
| 376 |
+
elif self.positional_embedding_type == 'custom-rotational-1d':
|
| 377 |
+
return CustomRotationalEmbedding1D(d_backbone)
|
| 378 |
+
elif self.positional_embedding_type == 'none':
|
| 379 |
+
return lambda x: 0 # Default no-op
|
| 380 |
+
else:
|
| 381 |
+
raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}")
|
| 382 |
+
|
| 383 |
+
def get_neuron_level_models(self, deep_nlms, do_layernorm_nlm, memory_length, memory_hidden_dims, d_model, dropout):
|
| 384 |
+
"""
|
| 385 |
+
Neuron level models are one of the core innovations of the CTM. They apply separate MLPs/linears to
|
| 386 |
+
each neuron.
|
| 387 |
+
NOTE: the name 'SuperLinear' is largely legacy, but its purpose is to apply separate linear layers
|
| 388 |
+
per neuron. It is sort of a 'grouped linear' function, where the group size is equal to 1.
|
| 389 |
+
One could make the group size bigger and use fewer parameters, but that is future work.
|
| 390 |
+
|
| 391 |
+
NOTE: We used GLU() nonlinearities because they worked well in practice.
|
| 392 |
+
"""
|
| 393 |
+
if deep_nlms:
|
| 394 |
+
return nn.Sequential(
|
| 395 |
+
nn.Sequential(
|
| 396 |
+
SuperLinear(in_dims=memory_length, out_dims=2 * memory_hidden_dims, N=d_model,
|
| 397 |
+
do_norm=do_layernorm_nlm, dropout=dropout),
|
| 398 |
+
nn.GLU(),
|
| 399 |
+
SuperLinear(in_dims=memory_hidden_dims, out_dims=2, N=d_model,
|
| 400 |
+
do_norm=do_layernorm_nlm, dropout=dropout),
|
| 401 |
+
nn.GLU(),
|
| 402 |
+
Squeeze(-1)
|
| 403 |
+
)
|
| 404 |
+
)
|
| 405 |
+
else:
|
| 406 |
+
return nn.Sequential(
|
| 407 |
+
nn.Sequential(
|
| 408 |
+
SuperLinear(in_dims=memory_length, out_dims=2, N=d_model,
|
| 409 |
+
do_norm=do_layernorm_nlm, dropout=dropout),
|
| 410 |
+
nn.GLU(),
|
| 411 |
+
Squeeze(-1)
|
| 412 |
+
)
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
def get_synapses(self, synapse_depth, d_model, dropout):
|
| 416 |
+
"""
|
| 417 |
+
The synapse model is the recurrent model in the CTM. It's purpose is to share information
|
| 418 |
+
across neurons. If using depth of 1, this is just a simple single layer with nonlinearity and layernomr.
|
| 419 |
+
For deeper synapse models we use a U-NET structure with many skip connections. In practice this performs
|
| 420 |
+
better as it enables multi-level information mixing.
|
| 421 |
+
|
| 422 |
+
The intuition with having a deep UNET model for synapses is that the action of synaptic connections is
|
| 423 |
+
not necessarily a linear one, and that approximate a synapose 'update' step in the brain is non trivial.
|
| 424 |
+
Hence, we set it up so that the CTM can learn some complex internal rule instead of trying to approximate
|
| 425 |
+
it ourselves.
|
| 426 |
+
"""
|
| 427 |
+
if synapse_depth == 1:
|
| 428 |
+
return nn.Sequential(
|
| 429 |
+
nn.Dropout(dropout),
|
| 430 |
+
nn.LazyLinear(d_model * 2),
|
| 431 |
+
nn.GLU(),
|
| 432 |
+
nn.LayerNorm(d_model)
|
| 433 |
+
)
|
| 434 |
+
else:
|
| 435 |
+
return SynapseUNET(d_model, synapse_depth, 16, dropout) # hard-coded minimum width of 16; future work TODO.
|
| 436 |
+
|
| 437 |
+
def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
|
| 438 |
+
"""
|
| 439 |
+
1. Set the buffers for selecting neurons so that these indices are saved into the model state_dict.
|
| 440 |
+
2. Set the parameters for learnable exponential decay when computing synchronisation between all
|
| 441 |
+
neurons.
|
| 442 |
+
"""
|
| 443 |
+
assert synch_type in ('out', 'action'), f"Invalid synch_type: {synch_type}"
|
| 444 |
+
left, right = self.initialize_left_right_neurons(synch_type, self.d_model, n_synch, n_random_pairing_self)
|
| 445 |
+
synch_representation_size = self.synch_representation_size_action if synch_type == 'action' else self.synch_representation_size_out
|
| 446 |
+
self.register_buffer(f'{synch_type}_neuron_indices_left', left)
|
| 447 |
+
self.register_buffer(f'{synch_type}_neuron_indices_right', right)
|
| 448 |
+
self.register_parameter(f'decay_params_{synch_type}', nn.Parameter(torch.zeros(synch_representation_size), requires_grad=True))
|
| 449 |
+
|
| 450 |
+
def initialize_left_right_neurons(self, synch_type, d_model, n_synch, n_random_pairing_self=0):
|
| 451 |
+
"""
|
| 452 |
+
Initialize the left and right neuron indices based on the neuron selection type.
|
| 453 |
+
This complexity is owing to legacy experiments, but we retain that these types of
|
| 454 |
+
neuron selections are interesting to experiment with.
|
| 455 |
+
"""
|
| 456 |
+
if self.neuron_select_type=='first-last':
|
| 457 |
+
if synch_type == 'out':
|
| 458 |
+
neuron_indices_left = neuron_indices_right = torch.arange(0, n_synch)
|
| 459 |
+
elif synch_type == 'action':
|
| 460 |
+
neuron_indices_left = neuron_indices_right = torch.arange(d_model-n_synch, d_model)
|
| 461 |
+
|
| 462 |
+
elif self.neuron_select_type=='random':
|
| 463 |
+
neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
|
| 464 |
+
neuron_indices_right = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
|
| 465 |
+
|
| 466 |
+
elif self.neuron_select_type=='random-pairing':
|
| 467 |
+
assert n_synch > n_random_pairing_self, f"Need at least {n_random_pairing_self} pairs for {self.neuron_select_type}"
|
| 468 |
+
neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
|
| 469 |
+
neuron_indices_right = torch.concatenate((neuron_indices_left[:n_random_pairing_self], torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch-n_random_pairing_self))))
|
| 470 |
+
|
| 471 |
+
device = self.start_activated_state.device
|
| 472 |
+
return neuron_indices_left.to(device), neuron_indices_right.to(device)
|
| 473 |
+
|
| 474 |
+
def get_neuron_select_type(self):
|
| 475 |
+
"""
|
| 476 |
+
Another helper method to accomodate our legacy neuron selection types.
|
| 477 |
+
TODO: additional experimentation and possible removal of 'first-last' and 'random'
|
| 478 |
+
"""
|
| 479 |
+
print(f"Using neuron select type: {self.neuron_select_type}")
|
| 480 |
+
if self.neuron_select_type == 'first-last':
|
| 481 |
+
neuron_select_type_out, neuron_select_type_action = 'first', 'last'
|
| 482 |
+
elif self.neuron_select_type in ('random', 'random-pairing'):
|
| 483 |
+
neuron_select_type_out = neuron_select_type_action = self.neuron_select_type
|
| 484 |
+
else:
|
| 485 |
+
raise ValueError(f"Invalid neuron selection type: {self.neuron_select_type}")
|
| 486 |
+
return neuron_select_type_out, neuron_select_type_action
|
| 487 |
+
|
| 488 |
+
# --- Utilty Methods ---
|
| 489 |
+
|
| 490 |
+
def verify_args(self):
|
| 491 |
+
"""
|
| 492 |
+
Verify the validity of the input arguments to ensure consistent behaviour.
|
| 493 |
+
Specifically when selecting neurons for sychronisation using 'first-last' or 'random',
|
| 494 |
+
one needs the right number of neurons
|
| 495 |
+
"""
|
| 496 |
+
assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
|
| 497 |
+
f"Invalid neuron selection type: {self.neuron_select_type}"
|
| 498 |
+
|
| 499 |
+
assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \
|
| 500 |
+
f"Invalid backbone_type: {self.backbone_type}"
|
| 501 |
+
|
| 502 |
+
assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \
|
| 503 |
+
f"Invalid positional_embedding_type: {self.positional_embedding_type}"
|
| 504 |
+
|
| 505 |
+
if self.neuron_select_type == 'first-last':
|
| 506 |
+
assert self.d_model >= (self.n_synch_out + self.n_synch_action), \
|
| 507 |
+
"d_model must be >= n_synch_out + n_synch_action for neuron subsets"
|
| 508 |
+
|
| 509 |
+
if self.backbone_type=='none' and self.positional_embedding_type!='none':
|
| 510 |
+
raise AssertionError("There should be no positional embedding if there is no backbone.")
|
| 511 |
+
|
| 512 |
+
def calculate_synch_representation_size(self, n_synch):
|
| 513 |
+
"""
|
| 514 |
+
Calculate the size of the synchronisation representation based on neuron selection type.
|
| 515 |
+
"""
|
| 516 |
+
if self.neuron_select_type == 'random-pairing':
|
| 517 |
+
synch_representation_size = n_synch
|
| 518 |
+
elif self.neuron_select_type in ('first-last', 'random'):
|
| 519 |
+
synch_representation_size = (n_synch * (n_synch + 1)) // 2
|
| 520 |
+
else:
|
| 521 |
+
raise ValueError(f"Invalid neuron selection type: {self.neuron_select_type}")
|
| 522 |
+
return synch_representation_size
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def forward(self, x, track=False, adaptive_halt=True, certainty_threshold=0.85):
|
| 528 |
+
"""
|
| 529 |
+
Forward pass through CTM.
|
| 530 |
+
|
| 531 |
+
v3.0 Features:
|
| 532 |
+
- Adaptive Halting: Stop early if certainty > threshold
|
| 533 |
+
- Returns actual ticks used (may be < self.iterations)
|
| 534 |
+
|
| 535 |
+
Args:
|
| 536 |
+
x: Input tensor
|
| 537 |
+
track: Whether to track intermediate states
|
| 538 |
+
adaptive_halt: Enable early stopping (v3.0)
|
| 539 |
+
certainty_threshold: Halt if certainty exceeds this (v3.0)
|
| 540 |
+
"""
|
| 541 |
+
B = x.size(0)
|
| 542 |
+
device = x.device
|
| 543 |
+
|
| 544 |
+
# --- Tracking Initialization ---
|
| 545 |
+
pre_activations_tracking = []
|
| 546 |
+
post_activations_tracking = []
|
| 547 |
+
synch_out_tracking = []
|
| 548 |
+
synch_action_tracking = []
|
| 549 |
+
attention_tracking = []
|
| 550 |
+
|
| 551 |
+
# --- Featurise Input Data ---
|
| 552 |
+
kv = self.compute_features(x)
|
| 553 |
+
|
| 554 |
+
# --- Initialise Recurrent State ---
|
| 555 |
+
state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
|
| 556 |
+
activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
|
| 557 |
+
|
| 558 |
+
# --- Prepare Storage for Outputs per Iteration ---
|
| 559 |
+
predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=torch.float32)
|
| 560 |
+
certainties = torch.empty(B, 2, self.iterations, device=device, dtype=torch.float32)
|
| 561 |
+
|
| 562 |
+
# --- Initialise Recurrent Synch Values ---
|
| 563 |
+
decay_alpha_action, decay_beta_action = None, None
|
| 564 |
+
self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15) # Fix from github user: kuviki
|
| 565 |
+
self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
|
| 566 |
+
r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
|
| 567 |
+
|
| 568 |
+
_, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
|
| 569 |
+
# Compute learned weighting for synchronisation
|
| 570 |
+
|
| 571 |
+
# v3.0: Track actual ticks used
|
| 572 |
+
ticks_used = self.iterations
|
| 573 |
+
halted_early = False
|
| 574 |
+
|
| 575 |
+
# --- Recurrent Loop ---
|
| 576 |
+
for stepi in range(self.iterations):
|
| 577 |
+
|
| 578 |
+
# --- Calculate Synchronisation for Input Data Interaction ---
|
| 579 |
+
synchronisation_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
|
| 580 |
+
|
| 581 |
+
# --- Interact with Data via Attention ---
|
| 582 |
+
q = self.q_proj(synchronisation_action).unsqueeze(1)
|
| 583 |
+
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
|
| 584 |
+
attn_out = attn_out.squeeze(1)
|
| 585 |
+
pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
|
| 586 |
+
|
| 587 |
+
# --- Apply Synapses ---
|
| 588 |
+
state = self.synapses(pre_synapse_input)
|
| 589 |
+
# The 'state_trace' is the history of incoming pre-activations
|
| 590 |
+
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
|
| 591 |
+
|
| 592 |
+
# --- Apply Neuron-Level Models ---
|
| 593 |
+
activated_state = self.trace_processor(state_trace)
|
| 594 |
+
# One would also keep an 'activated_state_trace' as the history of outgoing post-activations
|
| 595 |
+
# BUT, this is unnecessary because the synchronisation calculation is fully linear and can be
|
| 596 |
+
# done using only the currect activated state (see compute_synchronisation method for explanation)
|
| 597 |
+
|
| 598 |
+
# --- Calculate Synchronisation for Output Predictions ---
|
| 599 |
+
synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
|
| 600 |
+
|
| 601 |
+
# --- Get Predictions and Certainties ---
|
| 602 |
+
current_prediction = self.output_projector(synchronisation_out)
|
| 603 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 604 |
+
|
| 605 |
+
predictions[..., stepi] = current_prediction
|
| 606 |
+
certainties[..., stepi] = current_certainty
|
| 607 |
+
|
| 608 |
+
# --- Tracking ---
|
| 609 |
+
if track:
|
| 610 |
+
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
|
| 611 |
+
post_activations_tracking.append(activated_state.detach().cpu().numpy())
|
| 612 |
+
attention_tracking.append(attn_weights.detach().cpu().numpy())
|
| 613 |
+
synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
|
| 614 |
+
synch_action_tracking.append(synchronisation_action.detach().cpu().numpy())
|
| 615 |
+
|
| 616 |
+
# --- v3.0: Adaptive Halting ---
|
| 617 |
+
if adaptive_halt and not self.training:
|
| 618 |
+
# Check if all samples in batch are confident enough
|
| 619 |
+
batch_certainty = current_certainty[:, 1].mean().item() # 1-entropy
|
| 620 |
+
if batch_certainty > certainty_threshold:
|
| 621 |
+
ticks_used = stepi + 1
|
| 622 |
+
halted_early = True
|
| 623 |
+
# Fill remaining predictions with current value
|
| 624 |
+
for remaining in range(stepi + 1, self.iterations):
|
| 625 |
+
predictions[..., remaining] = current_prediction
|
| 626 |
+
certainties[..., remaining] = current_certainty
|
| 627 |
+
break
|
| 628 |
+
|
| 629 |
+
# --- Return Values ---
|
| 630 |
+
if track:
|
| 631 |
+
return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
|
| 632 |
+
return predictions, certainties, synchronisation_out
|
| 633 |
+
|
models/ctm_qamnist.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from models.ctm import ContinuousThoughtMachine
|
| 4 |
+
from models.modules import MNISTBackbone, QAMNISTIndexEmbeddings, QAMNISTOperatorEmbeddings
|
| 5 |
+
|
| 6 |
+
class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine):
|
| 7 |
+
def __init__(self,
|
| 8 |
+
iterations,
|
| 9 |
+
d_model,
|
| 10 |
+
d_input,
|
| 11 |
+
heads,
|
| 12 |
+
n_synch_out,
|
| 13 |
+
n_synch_action,
|
| 14 |
+
synapse_depth,
|
| 15 |
+
memory_length,
|
| 16 |
+
deep_nlms,
|
| 17 |
+
memory_hidden_dims,
|
| 18 |
+
do_layernorm_nlm,
|
| 19 |
+
out_dims,
|
| 20 |
+
iterations_per_digit,
|
| 21 |
+
iterations_per_question_part,
|
| 22 |
+
iterations_for_answering,
|
| 23 |
+
prediction_reshaper=[-1],
|
| 24 |
+
dropout=0,
|
| 25 |
+
neuron_select_type='first-last',
|
| 26 |
+
n_random_pairing_self=256
|
| 27 |
+
):
|
| 28 |
+
super().__init__(
|
| 29 |
+
iterations=iterations,
|
| 30 |
+
d_model=d_model,
|
| 31 |
+
d_input=d_input,
|
| 32 |
+
heads=heads,
|
| 33 |
+
n_synch_out=n_synch_out,
|
| 34 |
+
n_synch_action=n_synch_action,
|
| 35 |
+
synapse_depth=synapse_depth,
|
| 36 |
+
memory_length=memory_length,
|
| 37 |
+
deep_nlms=deep_nlms,
|
| 38 |
+
memory_hidden_dims=memory_hidden_dims,
|
| 39 |
+
do_layernorm_nlm=do_layernorm_nlm,
|
| 40 |
+
out_dims=out_dims,
|
| 41 |
+
prediction_reshaper=prediction_reshaper,
|
| 42 |
+
dropout=dropout,
|
| 43 |
+
neuron_select_type=neuron_select_type,
|
| 44 |
+
n_random_pairing_self=n_random_pairing_self,
|
| 45 |
+
backbone_type='none',
|
| 46 |
+
positional_embedding_type='none',
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# --- Core Parameters ---
|
| 50 |
+
self.iterations_per_digit = iterations_per_digit
|
| 51 |
+
self.iterations_per_question_part = iterations_per_question_part
|
| 52 |
+
self.iterations_for_answering = iterations_for_answering
|
| 53 |
+
|
| 54 |
+
# --- Setup Methods ---
|
| 55 |
+
|
| 56 |
+
def set_initial_rgb(self):
|
| 57 |
+
"""Set the initial RGB values for the backbone."""
|
| 58 |
+
return None
|
| 59 |
+
|
| 60 |
+
def get_d_backbone(self):
|
| 61 |
+
"""Get the dimensionality of the backbone output."""
|
| 62 |
+
return self.d_input
|
| 63 |
+
|
| 64 |
+
def set_backbone(self):
|
| 65 |
+
"""Set the backbone module based on the specified type."""
|
| 66 |
+
self.backbone_digit = MNISTBackbone(self.d_input)
|
| 67 |
+
self.index_backbone = QAMNISTIndexEmbeddings(50, self.d_input)
|
| 68 |
+
self.operator_backbone = QAMNISTOperatorEmbeddings(2, self.d_input)
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
# --- Utilty Methods ---
|
| 72 |
+
|
| 73 |
+
def determine_step_type(self, total_iterations_for_digits, total_iterations_for_question, stepi: int):
|
| 74 |
+
"""Determine whether the current step is for digits, questions, or answers."""
|
| 75 |
+
is_digit_step = stepi < total_iterations_for_digits
|
| 76 |
+
is_question_step = total_iterations_for_digits <= stepi < total_iterations_for_digits + total_iterations_for_question
|
| 77 |
+
is_answer_step = stepi >= total_iterations_for_digits + total_iterations_for_question
|
| 78 |
+
return is_digit_step, is_question_step, is_answer_step
|
| 79 |
+
|
| 80 |
+
def determine_index_operator_step_type(self, total_iterations_for_digits, stepi: int):
|
| 81 |
+
"""Determine whether the current step is for index or operator."""
|
| 82 |
+
step_within_questions = stepi - total_iterations_for_digits
|
| 83 |
+
if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
|
| 84 |
+
is_index_step = True
|
| 85 |
+
is_operator_step = False
|
| 86 |
+
else:
|
| 87 |
+
is_index_step = False
|
| 88 |
+
is_operator_step = True
|
| 89 |
+
return is_index_step, is_operator_step
|
| 90 |
+
|
| 91 |
+
def get_kv_for_step(self, total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input=None, prev_kv=None):
|
| 92 |
+
"""Get the key-value for the current step."""
|
| 93 |
+
is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
|
| 94 |
+
|
| 95 |
+
if is_digit_step:
|
| 96 |
+
current_input = x[:, stepi]
|
| 97 |
+
if prev_input is not None and torch.equal(current_input, prev_input):
|
| 98 |
+
return prev_kv, prev_input
|
| 99 |
+
kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
|
| 100 |
+
|
| 101 |
+
elif is_question_step:
|
| 102 |
+
offset = stepi - total_iterations_for_digits
|
| 103 |
+
current_input = z[:, offset]
|
| 104 |
+
if prev_input is not None and torch.equal(current_input, prev_input):
|
| 105 |
+
return prev_kv, prev_input
|
| 106 |
+
is_index_step, is_operator_step = self.determine_index_operator_step_type(total_iterations_for_digits, stepi)
|
| 107 |
+
if is_index_step:
|
| 108 |
+
kv = self.index_backbone(current_input)
|
| 109 |
+
elif is_operator_step:
|
| 110 |
+
kv = self.operator_backbone(current_input)
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError("Invalid step type for question processing.")
|
| 113 |
+
|
| 114 |
+
elif is_answer_step:
|
| 115 |
+
current_input = None
|
| 116 |
+
kv = torch.zeros((x.size(0), self.d_input), device=x.device)
|
| 117 |
+
|
| 118 |
+
else:
|
| 119 |
+
raise ValueError("Invalid step type.")
|
| 120 |
+
|
| 121 |
+
return kv, current_input
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def forward(self, x, z, track=False):
|
| 127 |
+
B = x.size(0)
|
| 128 |
+
device = x.device
|
| 129 |
+
|
| 130 |
+
# --- Tracking Initialization ---
|
| 131 |
+
pre_activations_tracking = []
|
| 132 |
+
post_activations_tracking = []
|
| 133 |
+
attention_tracking = []
|
| 134 |
+
embedding_tracking = []
|
| 135 |
+
|
| 136 |
+
total_iterations_for_digits = x.size(1)
|
| 137 |
+
total_iterations_for_question = z.size(1)
|
| 138 |
+
total_iterations = total_iterations_for_digits + total_iterations_for_question + self.iterations_for_answering
|
| 139 |
+
|
| 140 |
+
# --- Initialise Recurrent State ---
|
| 141 |
+
state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
|
| 142 |
+
activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
|
| 143 |
+
|
| 144 |
+
# --- Storage for outputs per iteration ---
|
| 145 |
+
predictions = torch.empty(B, self.out_dims, total_iterations, device=device, dtype=x.dtype)
|
| 146 |
+
certainties = torch.empty(B, 2, total_iterations, device=device, dtype=x.dtype)
|
| 147 |
+
|
| 148 |
+
# --- Initialise Recurrent Synch Values ---
|
| 149 |
+
decay_alpha_action, decay_beta_action = None, None
|
| 150 |
+
self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15) # Fix from github user: kuviki
|
| 151 |
+
self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
|
| 152 |
+
r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
|
| 153 |
+
|
| 154 |
+
_, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
|
| 155 |
+
|
| 156 |
+
prev_input = None
|
| 157 |
+
prev_kv = None
|
| 158 |
+
|
| 159 |
+
# --- Recurrent Loop ---
|
| 160 |
+
for stepi in range(total_iterations):
|
| 161 |
+
is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
|
| 162 |
+
|
| 163 |
+
kv, prev_input = self.get_kv_for_step(total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input, prev_kv)
|
| 164 |
+
prev_kv = kv
|
| 165 |
+
|
| 166 |
+
synchronization_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
|
| 167 |
+
|
| 168 |
+
# --- Interact with Data via Attention ---
|
| 169 |
+
attn_weights = None
|
| 170 |
+
if is_digit_step:
|
| 171 |
+
q = self.q_proj(synchronization_action).unsqueeze(1)
|
| 172 |
+
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
|
| 173 |
+
attn_out = attn_out.squeeze(1)
|
| 174 |
+
pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
|
| 175 |
+
else:
|
| 176 |
+
kv = kv.squeeze(1)
|
| 177 |
+
pre_synapse_input = torch.concatenate((kv, activated_state), dim=-1)
|
| 178 |
+
|
| 179 |
+
# --- Apply Synapses ---
|
| 180 |
+
state = self.synapses(pre_synapse_input)
|
| 181 |
+
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
|
| 182 |
+
|
| 183 |
+
# --- Apply NLMs ---
|
| 184 |
+
activated_state = self.trace_processor(state_trace)
|
| 185 |
+
|
| 186 |
+
# --- Calculate Synchronisation for Output Predictions ---
|
| 187 |
+
synchronization_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
|
| 188 |
+
|
| 189 |
+
# --- Get Predictions and Certainties ---
|
| 190 |
+
current_prediction = self.output_projector(synchronization_out)
|
| 191 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 192 |
+
|
| 193 |
+
predictions[..., stepi] = current_prediction
|
| 194 |
+
certainties[..., stepi] = current_certainty
|
| 195 |
+
|
| 196 |
+
# --- Tracking ---
|
| 197 |
+
if track:
|
| 198 |
+
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
|
| 199 |
+
post_activations_tracking.append(activated_state.detach().cpu().numpy())
|
| 200 |
+
if attn_weights is not None:
|
| 201 |
+
attention_tracking.append(attn_weights.detach().cpu().numpy())
|
| 202 |
+
if is_question_step:
|
| 203 |
+
embedding_tracking.append(kv.detach().cpu().numpy())
|
| 204 |
+
|
| 205 |
+
# --- Return Values ---
|
| 206 |
+
if track:
|
| 207 |
+
return predictions, certainties, synchronization_out, np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
|
| 208 |
+
return predictions, certainties, synchronization_out
|
models/ctm_rl.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import math
|
| 5 |
+
from models.ctm import ContinuousThoughtMachine
|
| 6 |
+
from models.modules import MiniGridBackbone, ClassicControlBackbone, SynapseUNET
|
| 7 |
+
from models.utils import compute_decay
|
| 8 |
+
from models.constants import VALID_NEURON_SELECT_TYPES
|
| 9 |
+
|
| 10 |
+
class ContinuousThoughtMachineRL(ContinuousThoughtMachine):
|
| 11 |
+
def __init__(self,
|
| 12 |
+
iterations,
|
| 13 |
+
d_model,
|
| 14 |
+
d_input,
|
| 15 |
+
n_synch_out,
|
| 16 |
+
synapse_depth,
|
| 17 |
+
memory_length,
|
| 18 |
+
deep_nlms,
|
| 19 |
+
memory_hidden_dims,
|
| 20 |
+
do_layernorm_nlm,
|
| 21 |
+
backbone_type,
|
| 22 |
+
prediction_reshaper=[-1],
|
| 23 |
+
dropout=0,
|
| 24 |
+
neuron_select_type='first-last',
|
| 25 |
+
):
|
| 26 |
+
super().__init__(
|
| 27 |
+
iterations=iterations,
|
| 28 |
+
d_model=d_model,
|
| 29 |
+
d_input=d_input,
|
| 30 |
+
heads=0, # Set heads to 0 will return None
|
| 31 |
+
n_synch_out=n_synch_out,
|
| 32 |
+
n_synch_action=0,
|
| 33 |
+
synapse_depth=synapse_depth,
|
| 34 |
+
memory_length=memory_length,
|
| 35 |
+
deep_nlms=deep_nlms,
|
| 36 |
+
memory_hidden_dims=memory_hidden_dims,
|
| 37 |
+
do_layernorm_nlm=do_layernorm_nlm,
|
| 38 |
+
out_dims=0,
|
| 39 |
+
prediction_reshaper=prediction_reshaper,
|
| 40 |
+
dropout=dropout,
|
| 41 |
+
neuron_select_type=neuron_select_type,
|
| 42 |
+
backbone_type=backbone_type,
|
| 43 |
+
n_random_pairing_self=0,
|
| 44 |
+
positional_embedding_type='none',
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# --- Use a minimal CTM w/out input (action) synch ---
|
| 48 |
+
self.neuron_select_type_action = None
|
| 49 |
+
self.synch_representation_size_action = None
|
| 50 |
+
|
| 51 |
+
# --- Start dynamics with a learned activated state trace ---
|
| 52 |
+
self.register_parameter('start_activated_trace', nn.Parameter(torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length))), requires_grad=True))
|
| 53 |
+
self.start_activated_state = None
|
| 54 |
+
|
| 55 |
+
self.register_buffer('diagonal_mask_out', torch.triu(torch.ones(self.n_synch_out, self.n_synch_out, dtype=torch.bool)))
|
| 56 |
+
|
| 57 |
+
self.attention = None # Should already be None because super(... heads=0... )
|
| 58 |
+
self.q_proj = None # Should already be None because super(... heads=0... )
|
| 59 |
+
self.kv_proj = None # Should already be None because super(... heads=0... )
|
| 60 |
+
self.output_projector = None
|
| 61 |
+
|
| 62 |
+
# --- Core CTM Methods ---
|
| 63 |
+
|
| 64 |
+
def compute_synchronisation(self, activated_state_trace):
|
| 65 |
+
"""Compute the synchronisation between neurons."""
|
| 66 |
+
assert self.neuron_select_type == "first-last", "only fisrst-last neuron selection is supported here"
|
| 67 |
+
# For RL tasks we track a sliding window of activations from which we compute synchronisation
|
| 68 |
+
S = activated_state_trace.permute(0, 2, 1)
|
| 69 |
+
diagonal_mask = self.diagonal_mask_out.to(S.device)
|
| 70 |
+
decay = compute_decay(S.size(1), self.decay_params_out, clamp_lims=(0, 4))
|
| 71 |
+
synchronisation = ((decay.unsqueeze(0) *(S[:,:,-self.n_synch_out:].unsqueeze(-1) * S[:,:,-self.n_synch_out:].unsqueeze(-2))[:,:,diagonal_mask]).sum(1))/torch.sqrt(decay.unsqueeze(0).sum(1,))
|
| 72 |
+
return synchronisation
|
| 73 |
+
|
| 74 |
+
# --- Setup Methods ---
|
| 75 |
+
|
| 76 |
+
def set_initial_rgb(self):
|
| 77 |
+
"""Set the initial RGB values for the backbone."""
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
def get_d_backbone(self):
|
| 81 |
+
"""Get the dimensionality of the backbone output."""
|
| 82 |
+
return self.d_input
|
| 83 |
+
|
| 84 |
+
def set_backbone(self):
|
| 85 |
+
"""Set the backbone module based on the specified type."""
|
| 86 |
+
if self.backbone_type == 'navigation-backbone':
|
| 87 |
+
self.backbone = MiniGridBackbone(self.d_input)
|
| 88 |
+
elif self.backbone_type == 'classic-control-backbone':
|
| 89 |
+
self.backbone = ClassicControlBackbone(self.d_input)
|
| 90 |
+
else:
|
| 91 |
+
raise NotImplemented('The only backbone supported for RL are for navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')
|
| 92 |
+
pass
|
| 93 |
+
|
| 94 |
+
def get_positional_embedding(self, d_backbone):
|
| 95 |
+
"""Get the positional embedding module."""
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_synapses(self, synapse_depth, d_model, dropout):
|
| 100 |
+
"""
|
| 101 |
+
Get the synapse module.
|
| 102 |
+
|
| 103 |
+
We found in our early experimentation that a single Linear, GLU and LayerNorm block performed worse than two blocks.
|
| 104 |
+
For that reason we set the default synapse depth to two blocks.
|
| 105 |
+
|
| 106 |
+
TODO: This is legacy and needs further experimentation to iron out.
|
| 107 |
+
"""
|
| 108 |
+
if synapse_depth == 1:
|
| 109 |
+
return nn.Sequential(
|
| 110 |
+
nn.Dropout(dropout),
|
| 111 |
+
nn.LazyLinear(d_model*2),
|
| 112 |
+
nn.GLU(),
|
| 113 |
+
nn.LayerNorm(d_model),
|
| 114 |
+
nn.LazyLinear(d_model*2),
|
| 115 |
+
nn.GLU(),
|
| 116 |
+
nn.LayerNorm(d_model)
|
| 117 |
+
)
|
| 118 |
+
else:
|
| 119 |
+
return SynapseUNET(d_model, synapse_depth, 16, dropout)
|
| 120 |
+
|
| 121 |
+
def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
|
| 122 |
+
"""Set the parameters for the synchronisation of neurons."""
|
| 123 |
+
if synch_type == 'action':
|
| 124 |
+
pass
|
| 125 |
+
elif synch_type == 'out':
|
| 126 |
+
left, right = self.initialize_left_right_neurons("out", self.d_model, n_synch, n_random_pairing_self)
|
| 127 |
+
self.register_buffer(f'out_neuron_indices_left', left)
|
| 128 |
+
self.register_buffer(f'out_neuron_indices_right', right)
|
| 129 |
+
self.register_parameter(f'decay_params_out', nn.Parameter(torch.zeros(self.synch_representation_size_out), requires_grad=True))
|
| 130 |
+
pass
|
| 131 |
+
else:
|
| 132 |
+
raise ValueError(f"Invalid synch_type: {synch_type}")
|
| 133 |
+
|
| 134 |
+
# --- Utilty Methods ---
|
| 135 |
+
|
| 136 |
+
def verify_args(self):
|
| 137 |
+
"""Verify the validity of the input arguments."""
|
| 138 |
+
assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
|
| 139 |
+
f"Invalid neuron selection type: {self.neuron_select_type}"
|
| 140 |
+
assert self.neuron_select_type != 'random-pairing', \
|
| 141 |
+
f"Random pairing is not supported for RL."
|
| 142 |
+
assert self.backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
|
| 143 |
+
f"Invalid backbone_type: {self.backbone_type}"
|
| 144 |
+
assert self.d_model >= (self.n_synch_out), \
|
| 145 |
+
"d_model must be >= n_synch_out for neuron subsets"
|
| 146 |
+
pass
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def forward(self, x, hidden_states, track=False):
|
| 152 |
+
|
| 153 |
+
# --- Tracking Initialization ---
|
| 154 |
+
pre_activations_tracking = []
|
| 155 |
+
post_activations_tracking = []
|
| 156 |
+
|
| 157 |
+
# --- Featurise Input Data ---
|
| 158 |
+
features = self.backbone(x)
|
| 159 |
+
|
| 160 |
+
# --- Get Recurrent State ---
|
| 161 |
+
state_trace, activated_state_trace = hidden_states
|
| 162 |
+
|
| 163 |
+
# --- Recurrent Loop ---
|
| 164 |
+
for stepi in range(self.iterations):
|
| 165 |
+
|
| 166 |
+
pre_synapse_input = torch.concatenate((features.reshape(x.size(0), -1), activated_state_trace[:,:,-1]), -1)
|
| 167 |
+
|
| 168 |
+
# --- Apply Synapses ---
|
| 169 |
+
state = self.synapses(pre_synapse_input)
|
| 170 |
+
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
|
| 171 |
+
|
| 172 |
+
# --- Apply NLMs ---
|
| 173 |
+
activated_state = self.trace_processor(state_trace)
|
| 174 |
+
activated_state_trace = torch.concatenate((activated_state_trace[:,:,1:], activated_state.unsqueeze(-1)), -1)
|
| 175 |
+
|
| 176 |
+
# --- Tracking ---
|
| 177 |
+
if track:
|
| 178 |
+
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
|
| 179 |
+
post_activations_tracking.append(activated_state.detach().cpu().numpy())
|
| 180 |
+
|
| 181 |
+
hidden_states = (
|
| 182 |
+
state_trace,
|
| 183 |
+
activated_state_trace,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# --- Calculate Output Synchronisation ---
|
| 187 |
+
synchronisation_out = self.compute_synchronisation(activated_state_trace)
|
| 188 |
+
|
| 189 |
+
# --- Return Values ---
|
| 190 |
+
if track:
|
| 191 |
+
return synchronisation_out, hidden_states, np.array(pre_activations_tracking), np.array(post_activations_tracking)
|
| 192 |
+
return synchronisation_out, hidden_states
|
models/ctm_sort.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from models.ctm import ContinuousThoughtMachine
|
| 4 |
+
|
| 5 |
+
class ContinuousThoughtMachineSORT(ContinuousThoughtMachine):
|
| 6 |
+
"""
|
| 7 |
+
Slight adaption of the CTM to work with the sort task.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
def __init__(self,
|
| 11 |
+
iterations,
|
| 12 |
+
d_model,
|
| 13 |
+
d_input,
|
| 14 |
+
heads,
|
| 15 |
+
n_synch_out,
|
| 16 |
+
n_synch_action,
|
| 17 |
+
synapse_depth,
|
| 18 |
+
memory_length,
|
| 19 |
+
deep_nlms,
|
| 20 |
+
memory_hidden_dims,
|
| 21 |
+
do_layernorm_nlm,
|
| 22 |
+
backbone_type,
|
| 23 |
+
positional_embedding_type,
|
| 24 |
+
out_dims,
|
| 25 |
+
prediction_reshaper=[-1],
|
| 26 |
+
dropout=0,
|
| 27 |
+
dropout_nlm=None,
|
| 28 |
+
neuron_select_type='random-pairing',
|
| 29 |
+
n_random_pairing_self=0,
|
| 30 |
+
):
|
| 31 |
+
super().__init__(
|
| 32 |
+
iterations=iterations,
|
| 33 |
+
d_model=d_model,
|
| 34 |
+
d_input=d_input,
|
| 35 |
+
heads=0,
|
| 36 |
+
n_synch_out=n_synch_out,
|
| 37 |
+
n_synch_action=0,
|
| 38 |
+
synapse_depth=synapse_depth,
|
| 39 |
+
memory_length=memory_length,
|
| 40 |
+
deep_nlms=deep_nlms,
|
| 41 |
+
memory_hidden_dims=memory_hidden_dims,
|
| 42 |
+
do_layernorm_nlm=do_layernorm_nlm,
|
| 43 |
+
backbone_type='none',
|
| 44 |
+
positional_embedding_type='none',
|
| 45 |
+
out_dims=out_dims,
|
| 46 |
+
prediction_reshaper=prediction_reshaper,
|
| 47 |
+
dropout=dropout,
|
| 48 |
+
dropout_nlm=dropout_nlm,
|
| 49 |
+
neuron_select_type=neuron_select_type,
|
| 50 |
+
n_random_pairing_self=n_random_pairing_self,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# --- Use a minimal CTM w/out input (action) synch ---
|
| 54 |
+
self.neuron_select_type_action = None
|
| 55 |
+
self.synch_representation_size_action = None
|
| 56 |
+
|
| 57 |
+
self.attention = None # Should already be None because super(... heads=0... )
|
| 58 |
+
self.q_proj = None # Should already be None because super(... heads=0... )
|
| 59 |
+
self.kv_proj = None # Should already be None because super(... heads=0... )
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def forward(self, x, track=False):
|
| 65 |
+
B = x.size(0)
|
| 66 |
+
device = x.device
|
| 67 |
+
|
| 68 |
+
# --- Tracking Initialization ---
|
| 69 |
+
pre_activations_tracking = []
|
| 70 |
+
post_activations_tracking = []
|
| 71 |
+
synch_out_tracking = []
|
| 72 |
+
attention_tracking = []
|
| 73 |
+
|
| 74 |
+
# --- For SORT: no need to featurise data ---
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# --- Initialise Recurrent State ---
|
| 78 |
+
state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
|
| 79 |
+
activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
|
| 80 |
+
|
| 81 |
+
# --- Prepare Storage for Outputs per Iteration ---
|
| 82 |
+
predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
|
| 83 |
+
certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
|
| 84 |
+
|
| 85 |
+
# --- Initialise Recurrent Synch Values ---
|
| 86 |
+
r_out = torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
|
| 87 |
+
_, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
|
| 88 |
+
# Compute learned weighting for synchronisation
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# --- Recurrent Loop ---
|
| 92 |
+
for stepi in range(self.iterations):
|
| 93 |
+
|
| 94 |
+
pre_synapse_input = torch.concatenate((x, activated_state), dim=-1)
|
| 95 |
+
|
| 96 |
+
# --- Apply Synapses ---
|
| 97 |
+
state = self.synapses(pre_synapse_input)
|
| 98 |
+
# The 'state_trace' is the history of incoming pre-activations
|
| 99 |
+
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
|
| 100 |
+
|
| 101 |
+
# --- Apply Neuron-Level Models ---
|
| 102 |
+
activated_state = self.trace_processor(state_trace)
|
| 103 |
+
# One would also keep an 'activated_state_trace' as the history of outgoing post-activations
|
| 104 |
+
# BUT, this is unnecessary because the synchronisation calculation is fully linear and can be
|
| 105 |
+
# done using only the currect activated state (see compute_synchronisation method for explanation)
|
| 106 |
+
|
| 107 |
+
# --- Calculate Synchronisation for Output Predictions ---
|
| 108 |
+
synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
|
| 109 |
+
|
| 110 |
+
# --- Get Predictions and Certainties ---
|
| 111 |
+
current_prediction = self.output_projector(synchronisation_out)
|
| 112 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 113 |
+
|
| 114 |
+
predictions[..., stepi] = current_prediction
|
| 115 |
+
certainties[..., stepi] = current_certainty
|
| 116 |
+
|
| 117 |
+
# --- Tracking ---
|
| 118 |
+
if track:
|
| 119 |
+
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
|
| 120 |
+
post_activations_tracking.append(activated_state.detach().cpu().numpy())
|
| 121 |
+
synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
|
| 122 |
+
|
| 123 |
+
# --- Return Values ---
|
| 124 |
+
if track:
|
| 125 |
+
return predictions, certainties, np.array(synch_out_tracking), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
|
| 126 |
+
return predictions, certainties, synchronisation_out
|
models/ff.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
# Local imports (Assuming these contain necessary custom modules)
|
| 4 |
+
from models.modules import *
|
| 5 |
+
from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class FFBaseline(nn.Module):
|
| 9 |
+
"""
|
| 10 |
+
LSTM Baseline.
|
| 11 |
+
|
| 12 |
+
Wrapper that lets us use the same backbone as the CTM and LSTM baselines, with a
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
d_model (int): workaround that projects final layer to this space so that parameter-matching is plausible.
|
| 17 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 18 |
+
out_dims (int): Dimensionality of the final output projection.
|
| 19 |
+
dropout (float): dropout in last layer
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self,
|
| 23 |
+
d_model,
|
| 24 |
+
backbone_type,
|
| 25 |
+
out_dims,
|
| 26 |
+
dropout=0,
|
| 27 |
+
):
|
| 28 |
+
super(FFBaseline, self).__init__()
|
| 29 |
+
|
| 30 |
+
# --- Core Parameters ---
|
| 31 |
+
self.d_model = d_model
|
| 32 |
+
self.backbone_type = backbone_type
|
| 33 |
+
self.out_dims = out_dims
|
| 34 |
+
|
| 35 |
+
# --- Input Assertions ---
|
| 36 |
+
assert backbone_type in ['resnet18-1', 'resnet18-2', 'resnet18-3', 'resnet18-4',
|
| 37 |
+
'resnet34-1', 'resnet34-2', 'resnet34-3', 'resnet34-4',
|
| 38 |
+
'resnet50-1', 'resnet50-2', 'resnet50-3', 'resnet50-4',
|
| 39 |
+
'resnet101-1', 'resnet101-2', 'resnet101-3', 'resnet101-4',
|
| 40 |
+
'resnet152-1', 'resnet152-2', 'resnet152-3', 'resnet152-4',
|
| 41 |
+
'none', 'shallow-wide', 'parity_backbone'], f"Invalid backbone_type: {backbone_type}"
|
| 42 |
+
|
| 43 |
+
# --- Backbone / Feature Extraction ---
|
| 44 |
+
self.initial_rgb = Identity() # Placeholder, potentially replaced if using ResNet
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
|
| 48 |
+
resnet_family = resnet18 # Default
|
| 49 |
+
if '34' in self.backbone_type: resnet_family = resnet34
|
| 50 |
+
if '50' in self.backbone_type: resnet_family = resnet50
|
| 51 |
+
if '101' in self.backbone_type: resnet_family = resnet101
|
| 52 |
+
if '152' in self.backbone_type: resnet_family = resnet152
|
| 53 |
+
|
| 54 |
+
# Determine which ResNet blocks to keep
|
| 55 |
+
block_num_str = self.backbone_type.split('-')[-1]
|
| 56 |
+
hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]
|
| 57 |
+
|
| 58 |
+
self.backbone = resnet_family(
|
| 59 |
+
3, # initial_rgb handles input channels now
|
| 60 |
+
hyper_blocks_to_keep,
|
| 61 |
+
stride=2,
|
| 62 |
+
pretrained=False,
|
| 63 |
+
progress=True,
|
| 64 |
+
device="cpu", # Initialise on CPU, move later via .to(device)
|
| 65 |
+
do_initial_max_pool=True,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# At this point we will have a 4D tensor of features: [B, C, H, W]
|
| 70 |
+
# The following lets us scale up the resnet with d_model until it matches the CTM
|
| 71 |
+
self.output_projector = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), Squeeze(-1), Squeeze(-1), nn.LazyLinear(d_model), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_model, out_dims))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
return self.output_projector((self.backbone(self.initial_rgb(x))))
|
models/lstm.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
from models.modules import ParityBackbone, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
|
| 7 |
+
from models.resnet import prepare_resnet_backbone
|
| 8 |
+
from models.utils import compute_normalized_entropy
|
| 9 |
+
|
| 10 |
+
from models.constants import (
|
| 11 |
+
VALID_BACKBONE_TYPES,
|
| 12 |
+
VALID_POSITIONAL_EMBEDDING_TYPES
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
class LSTMBaseline(nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
LSTM Baseline
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
iterations (int): Number of internal 'thought' steps (T, in paper).
|
| 21 |
+
d_model (int): Core dimensionality of the latent space.
|
| 22 |
+
d_input (int): Dimensionality of projected attention outputs or direct input features.
|
| 23 |
+
heads (int): Number of attention heads.
|
| 24 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 25 |
+
positional_embedding_type (str): Type of positional embedding for backbone features.
|
| 26 |
+
out_dims (int): Dimensionality of the final output projection.
|
| 27 |
+
prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
|
| 28 |
+
dropout (float): Dropout rate.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self,
|
| 32 |
+
iterations,
|
| 33 |
+
d_model,
|
| 34 |
+
d_input,
|
| 35 |
+
heads,
|
| 36 |
+
backbone_type,
|
| 37 |
+
num_layers,
|
| 38 |
+
positional_embedding_type,
|
| 39 |
+
out_dims,
|
| 40 |
+
prediction_reshaper=[-1],
|
| 41 |
+
dropout=0,
|
| 42 |
+
):
|
| 43 |
+
super(LSTMBaseline, self).__init__()
|
| 44 |
+
|
| 45 |
+
# --- Core Parameters ---
|
| 46 |
+
self.iterations = iterations
|
| 47 |
+
self.d_model = d_model
|
| 48 |
+
self.d_input = d_input
|
| 49 |
+
self.prediction_reshaper = prediction_reshaper
|
| 50 |
+
self.backbone_type = backbone_type
|
| 51 |
+
self.positional_embedding_type = positional_embedding_type
|
| 52 |
+
self.out_dims = out_dims
|
| 53 |
+
|
| 54 |
+
# --- Assertions ---
|
| 55 |
+
self.verify_args()
|
| 56 |
+
|
| 57 |
+
# --- Input Processing ---
|
| 58 |
+
d_backbone = self.get_d_backbone()
|
| 59 |
+
|
| 60 |
+
self.set_initial_rgb()
|
| 61 |
+
self.set_backbone()
|
| 62 |
+
self.positional_embedding = self.get_positional_embedding(d_backbone)
|
| 63 |
+
self.kv_proj = self.get_kv_proj()
|
| 64 |
+
self.lstm = nn.LSTM(d_input, d_model, num_layers, batch_first=True, dropout=dropout)
|
| 65 |
+
self.q_proj = self.get_q_proj()
|
| 66 |
+
self.attention = self.get_attention(heads, dropout)
|
| 67 |
+
self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
|
| 68 |
+
|
| 69 |
+
# --- Start States ---
|
| 70 |
+
self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 71 |
+
self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# --- Core LSTM Methods ---
|
| 76 |
+
|
| 77 |
+
def compute_features(self, x):
|
| 78 |
+
"""Applies backbone and positional embedding to input."""
|
| 79 |
+
x = self.initial_rgb(x)
|
| 80 |
+
self.kv_features = self.backbone(x)
|
| 81 |
+
pos_emb = self.positional_embedding(self.kv_features)
|
| 82 |
+
combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2)
|
| 83 |
+
kv = self.kv_proj(combined_features)
|
| 84 |
+
return kv
|
| 85 |
+
|
| 86 |
+
def compute_certainty(self, current_prediction):
|
| 87 |
+
"""Compute the certainty of the current prediction."""
|
| 88 |
+
B = current_prediction.size(0)
|
| 89 |
+
reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper)
|
| 90 |
+
ne = compute_normalized_entropy(reshaped_pred)
|
| 91 |
+
current_certainty = torch.stack((ne, 1-ne), -1)
|
| 92 |
+
return current_certainty
|
| 93 |
+
|
| 94 |
+
# --- Setup Methods ---
|
| 95 |
+
|
| 96 |
+
def set_initial_rgb(self):
|
| 97 |
+
"""Set the initial RGB processing module based on the backbone type."""
|
| 98 |
+
if 'resnet' in self.backbone_type:
|
| 99 |
+
self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
|
| 100 |
+
else:
|
| 101 |
+
self.initial_rgb = nn.Identity()
|
| 102 |
+
|
| 103 |
+
def get_d_backbone(self):
|
| 104 |
+
"""
|
| 105 |
+
Get the dimensionality of the backbone output, to be used for positional embedding setup.
|
| 106 |
+
|
| 107 |
+
This is a little bit complicated for resnets, but the logic should be easy enough to read below.
|
| 108 |
+
"""
|
| 109 |
+
if self.backbone_type == 'shallow-wide':
|
| 110 |
+
return 2048
|
| 111 |
+
elif self.backbone_type == 'parity_backbone':
|
| 112 |
+
return self.d_input
|
| 113 |
+
elif 'resnet' in self.backbone_type:
|
| 114 |
+
if '18' in self.backbone_type or '34' in self.backbone_type:
|
| 115 |
+
if self.backbone_type.split('-')[1]=='1': return 64
|
| 116 |
+
elif self.backbone_type.split('-')[1]=='2': return 128
|
| 117 |
+
elif self.backbone_type.split('-')[1]=='3': return 256
|
| 118 |
+
elif self.backbone_type.split('-')[1]=='4': return 512
|
| 119 |
+
else:
|
| 120 |
+
raise NotImplementedError
|
| 121 |
+
else:
|
| 122 |
+
if self.backbone_type.split('-')[1]=='1': return 256
|
| 123 |
+
elif self.backbone_type.split('-')[1]=='2': return 512
|
| 124 |
+
elif self.backbone_type.split('-')[1]=='3': return 1024
|
| 125 |
+
elif self.backbone_type.split('-')[1]=='4': return 2048
|
| 126 |
+
else:
|
| 127 |
+
raise NotImplementedError
|
| 128 |
+
elif self.backbone_type == 'none':
|
| 129 |
+
return None
|
| 130 |
+
else:
|
| 131 |
+
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
|
| 132 |
+
|
| 133 |
+
def set_backbone(self):
|
| 134 |
+
"""Set the backbone module based on the specified type."""
|
| 135 |
+
if self.backbone_type == 'shallow-wide':
|
| 136 |
+
self.backbone = ShallowWide()
|
| 137 |
+
elif self.backbone_type == 'parity_backbone':
|
| 138 |
+
d_backbone = self.get_d_backbone()
|
| 139 |
+
self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone)
|
| 140 |
+
elif 'resnet' in self.backbone_type:
|
| 141 |
+
self.backbone = prepare_resnet_backbone(self.backbone_type)
|
| 142 |
+
elif self.backbone_type == 'none':
|
| 143 |
+
self.backbone = nn.Identity()
|
| 144 |
+
else:
|
| 145 |
+
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
|
| 146 |
+
|
| 147 |
+
def get_positional_embedding(self, d_backbone):
|
| 148 |
+
"""Get the positional embedding module."""
|
| 149 |
+
if self.positional_embedding_type == 'learnable-fourier':
|
| 150 |
+
return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5)
|
| 151 |
+
elif self.positional_embedding_type == 'multi-learnable-fourier':
|
| 152 |
+
return MultiLearnableFourierPositionalEncoding(d_backbone)
|
| 153 |
+
elif self.positional_embedding_type == 'custom-rotational':
|
| 154 |
+
return CustomRotationalEmbedding(d_backbone)
|
| 155 |
+
elif self.positional_embedding_type == 'custom-rotational-1d':
|
| 156 |
+
return CustomRotationalEmbedding1D(d_backbone)
|
| 157 |
+
elif self.positional_embedding_type == 'none':
|
| 158 |
+
return lambda x: 0 # Default no-op
|
| 159 |
+
else:
|
| 160 |
+
raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}")
|
| 161 |
+
|
| 162 |
+
def get_attention(self, heads, dropout):
|
| 163 |
+
"""Get the attention module."""
|
| 164 |
+
return nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True)
|
| 165 |
+
|
| 166 |
+
def get_kv_proj(self):
|
| 167 |
+
"""Get the key-value projection module."""
|
| 168 |
+
return nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input))
|
| 169 |
+
|
| 170 |
+
def get_q_proj(self):
|
| 171 |
+
"""Get the query projection module."""
|
| 172 |
+
return nn.LazyLinear(self.d_input)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def verify_args(self):
|
| 176 |
+
"""Verify the validity of the input arguments."""
|
| 177 |
+
|
| 178 |
+
assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \
|
| 179 |
+
f"Invalid backbone_type: {self.backbone_type}"
|
| 180 |
+
|
| 181 |
+
assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \
|
| 182 |
+
f"Invalid positional_embedding_type: {self.positional_embedding_type}"
|
| 183 |
+
|
| 184 |
+
if self.backbone_type=='none' and self.positional_embedding_type!='none':
|
| 185 |
+
raise AssertionError("There should be no positional embedding if there is no backbone.")
|
| 186 |
+
|
| 187 |
+
pass
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def forward(self, x, track=False):
|
| 193 |
+
"""
|
| 194 |
+
Forward pass - Reverted to structure closer to user's working version.
|
| 195 |
+
Executes T=iterations steps.
|
| 196 |
+
"""
|
| 197 |
+
B = x.size(0)
|
| 198 |
+
device = x.device
|
| 199 |
+
|
| 200 |
+
# --- Tracking Initialization ---
|
| 201 |
+
activations_tracking = []
|
| 202 |
+
attention_tracking = []
|
| 203 |
+
|
| 204 |
+
# --- Featurise Input Data ---
|
| 205 |
+
kv = self.compute_features(x)
|
| 206 |
+
|
| 207 |
+
# --- Initialise Recurrent State ---
|
| 208 |
+
hn = torch.repeat_interleave(self.start_hidden_state.unsqueeze(1), x.size(0), 1)
|
| 209 |
+
cn = torch.repeat_interleave(self.start_cell_state.unsqueeze(1), x.size(0), 1)
|
| 210 |
+
state_trace = [hn[-1]]
|
| 211 |
+
|
| 212 |
+
# --- Prepare Storage for Outputs per Iteration ---
|
| 213 |
+
predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
|
| 214 |
+
certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
|
| 215 |
+
|
| 216 |
+
# --- Recurrent Loop ---
|
| 217 |
+
for stepi in range(self.iterations):
|
| 218 |
+
|
| 219 |
+
# --- Interact with Data via Attention ---
|
| 220 |
+
q = self.q_proj(hn[-1].unsqueeze(1))
|
| 221 |
+
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
|
| 222 |
+
lstm_input = attn_out
|
| 223 |
+
|
| 224 |
+
# --- Apply LSTM ---
|
| 225 |
+
hidden_state, (hn,cn) = self.lstm(lstm_input, (hn, cn))
|
| 226 |
+
hidden_state = hidden_state.squeeze(1)
|
| 227 |
+
state_trace.append(hidden_state)
|
| 228 |
+
|
| 229 |
+
# --- Get Predictions and Certainties ---
|
| 230 |
+
current_prediction = self.output_projector(hidden_state)
|
| 231 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 232 |
+
|
| 233 |
+
predictions[..., stepi] = current_prediction
|
| 234 |
+
certainties[..., stepi] = current_certainty
|
| 235 |
+
|
| 236 |
+
# --- Tracking ---
|
| 237 |
+
if track:
|
| 238 |
+
activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
|
| 239 |
+
attention_tracking.append(attn_weights.detach().cpu().numpy())
|
| 240 |
+
|
| 241 |
+
# --- Return Values ---
|
| 242 |
+
if track:
|
| 243 |
+
return predictions, certainties, None, np.zeros_like(activations_tracking), np.array(activations_tracking), np.array(attention_tracking)
|
| 244 |
+
return predictions, certainties, None
|
models/lstm_qamnist.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F # Used for GLU if not in modules
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
# Local imports (Assuming these contain necessary custom modules)
|
| 8 |
+
from models.modules import *
|
| 9 |
+
from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here
|
| 10 |
+
|
| 11 |
+
class LSTMBaseline(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
LSTM Baseline
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
iterations (int): Number of internal 'thought' steps (T, in paper).
|
| 17 |
+
d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
|
| 18 |
+
d_input (int): Dimensionality of projected attention outputs or direct input features.
|
| 19 |
+
heads (int): Number of attention heads.
|
| 20 |
+
n_synch_out (int): Number of neurons used for output synchronisation (No, in paper).
|
| 21 |
+
n_synch_action (int): Number of neurons used for action/attention synchronisation (Ni, in paper).
|
| 22 |
+
synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP).
|
| 23 |
+
memory_length (int): History length for Neuron-Level Models (M, in paper).
|
| 24 |
+
deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear.
|
| 25 |
+
memory_hidden_dims (int): Hidden dimension size for deep NLMs.
|
| 26 |
+
do_layernorm_nlm (bool): Apply LayerNorm within NLMs.
|
| 27 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 28 |
+
positional_embedding_type (str): Type of positional embedding for backbone features.
|
| 29 |
+
out_dims (int): Dimensionality of the final output projection.
|
| 30 |
+
prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
|
| 31 |
+
dropout (float): Dropout rate.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self,
|
| 35 |
+
iterations,
|
| 36 |
+
d_model,
|
| 37 |
+
d_input,
|
| 38 |
+
heads,
|
| 39 |
+
out_dims,
|
| 40 |
+
iterations_per_digit,
|
| 41 |
+
iterations_per_question_part,
|
| 42 |
+
iterations_for_answering,
|
| 43 |
+
prediction_reshaper=[-1],
|
| 44 |
+
dropout=0,
|
| 45 |
+
):
|
| 46 |
+
super(LSTMBaseline, self).__init__()
|
| 47 |
+
|
| 48 |
+
# --- Core Parameters ---
|
| 49 |
+
self.iterations = iterations
|
| 50 |
+
self.d_model = d_model
|
| 51 |
+
self.prediction_reshaper = prediction_reshaper
|
| 52 |
+
self.out_dims = out_dims
|
| 53 |
+
self.d_input = d_input
|
| 54 |
+
self.backbone_type = 'qamnist_backbone'
|
| 55 |
+
self.iterations_per_digit = iterations_per_digit
|
| 56 |
+
self.iterations_per_question_part = iterations_per_question_part
|
| 57 |
+
self.total_iterations_for_answering = iterations_for_answering
|
| 58 |
+
|
| 59 |
+
# --- Backbone / Feature Extraction ---
|
| 60 |
+
self.backbone_digit = MNISTBackbone(d_input)
|
| 61 |
+
self.index_backbone = QAMNISTIndexEmbeddings(50, d_input)
|
| 62 |
+
self.operator_backbone = QAMNISTOperatorEmbeddings(2, d_input)
|
| 63 |
+
|
| 64 |
+
# --- Core CTM Modules ---
|
| 65 |
+
self.lstm_cell = nn.LSTMCell(d_input, d_model)
|
| 66 |
+
self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 67 |
+
self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 68 |
+
|
| 69 |
+
# Attention
|
| 70 |
+
self.q_proj = nn.LazyLinear(d_input)
|
| 71 |
+
self.kv_proj = nn.Sequential(nn.LazyLinear(d_input), nn.LayerNorm(d_input))
|
| 72 |
+
self.attention = nn.MultiheadAttention(d_input, heads, dropout, batch_first=True)
|
| 73 |
+
|
| 74 |
+
# Output Projection
|
| 75 |
+
self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
|
| 76 |
+
|
| 77 |
+
def compute_certainty(self, current_prediction):
|
| 78 |
+
"""Compute the certainty of the current prediction."""
|
| 79 |
+
B = current_prediction.size(0)
|
| 80 |
+
reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper)
|
| 81 |
+
ne = compute_normalized_entropy(reshaped_pred)
|
| 82 |
+
current_certainty = torch.stack((ne, 1-ne), -1)
|
| 83 |
+
return current_certainty
|
| 84 |
+
|
| 85 |
+
def get_kv_for_step(self, stepi, x, z, thought_steps, prev_input=None, prev_kv=None):
|
| 86 |
+
is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
|
| 87 |
+
|
| 88 |
+
if is_digit_step:
|
| 89 |
+
current_input = x[:, stepi]
|
| 90 |
+
if prev_input is not None and torch.equal(current_input, prev_input):
|
| 91 |
+
return prev_kv, prev_input
|
| 92 |
+
kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
|
| 93 |
+
|
| 94 |
+
elif is_question_step:
|
| 95 |
+
offset = stepi - thought_steps.total_iterations_for_digits
|
| 96 |
+
current_input = z[:, offset].squeeze(0)
|
| 97 |
+
if prev_input is not None and torch.equal(current_input, prev_input):
|
| 98 |
+
return prev_kv, prev_input
|
| 99 |
+
is_index_step, is_operator_step = thought_steps.determine_answer_step_type(stepi)
|
| 100 |
+
if is_index_step:
|
| 101 |
+
kv = self.kv_proj(self.index_backbone(current_input))
|
| 102 |
+
elif is_operator_step:
|
| 103 |
+
kv = self.kv_proj(self.operator_backbone(current_input))
|
| 104 |
+
else:
|
| 105 |
+
raise ValueError("Invalid step type for question processing.")
|
| 106 |
+
|
| 107 |
+
elif is_answer_step:
|
| 108 |
+
current_input = None
|
| 109 |
+
kv = torch.zeros((x.size(0), self.d_input), device=x.device)
|
| 110 |
+
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError("Invalid step type.")
|
| 113 |
+
|
| 114 |
+
return kv, current_input
|
| 115 |
+
|
| 116 |
+
def forward(self, x, z, track=False):
|
| 117 |
+
"""
|
| 118 |
+
Forward pass - Reverted to structure closer to user's working version.
|
| 119 |
+
Executes T=iterations steps.
|
| 120 |
+
"""
|
| 121 |
+
B = x.size(0) # Batch size
|
| 122 |
+
|
| 123 |
+
# --- Tracking Initialization ---
|
| 124 |
+
activations_tracking = []
|
| 125 |
+
attention_tracking = [] # Note: reshaping this correctly requires knowing num_heads
|
| 126 |
+
embedding_tracking = []
|
| 127 |
+
|
| 128 |
+
thought_steps = ThoughtSteps(self.iterations_per_digit, self.iterations_per_question_part, self.total_iterations_for_answering, x.size(1), z.size(1))
|
| 129 |
+
|
| 130 |
+
# --- Step 2: Initialise Recurrent State ---
|
| 131 |
+
hidden_state = torch.repeat_interleave(self.start_hidden_state.unsqueeze(0), x.size(0), 0)
|
| 132 |
+
cell_state = torch.repeat_interleave(self.start_cell_state.unsqueeze(0), x.size(0), 0)
|
| 133 |
+
|
| 134 |
+
state_trace = [hidden_state]
|
| 135 |
+
|
| 136 |
+
device = hidden_state.device
|
| 137 |
+
|
| 138 |
+
# Storage for outputs per iteration
|
| 139 |
+
predictions = torch.empty(B, self.out_dims, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
|
| 140 |
+
certainties = torch.empty(B, 2, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
|
| 141 |
+
|
| 142 |
+
prev_input = None
|
| 143 |
+
prev_kv = None
|
| 144 |
+
|
| 145 |
+
# --- Recurrent Loop (T=iterations steps) ---
|
| 146 |
+
for stepi in range(thought_steps.total_iterations):
|
| 147 |
+
|
| 148 |
+
is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
|
| 149 |
+
kv, prev_input = self.get_kv_for_step(stepi, x, z, thought_steps, prev_input, prev_kv)
|
| 150 |
+
prev_kv = kv
|
| 151 |
+
|
| 152 |
+
# --- Interact with Data via Attention ---
|
| 153 |
+
attn_weights = None
|
| 154 |
+
if is_digit_step:
|
| 155 |
+
q = self.q_proj(hidden_state).unsqueeze(1)
|
| 156 |
+
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
|
| 157 |
+
lstm_input = attn_out.squeeze(1)
|
| 158 |
+
else:
|
| 159 |
+
lstm_input = kv
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state))
|
| 164 |
+
state_trace.append(hidden_state)
|
| 165 |
+
|
| 166 |
+
# --- Get Predictions and Certainties ---
|
| 167 |
+
current_prediction = self.output_projector(hidden_state)
|
| 168 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 169 |
+
|
| 170 |
+
predictions[..., stepi] = current_prediction
|
| 171 |
+
certainties[..., stepi] = current_certainty
|
| 172 |
+
|
| 173 |
+
# --- Tracking ---
|
| 174 |
+
if track:
|
| 175 |
+
activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
|
| 176 |
+
if attn_weights is not None:
|
| 177 |
+
attention_tracking.append(attn_weights.detach().cpu().numpy())
|
| 178 |
+
if is_question_step:
|
| 179 |
+
embedding_tracking.append(kv.detach().cpu().numpy())
|
| 180 |
+
|
| 181 |
+
# --- Return Values ---
|
| 182 |
+
if track:
|
| 183 |
+
return predictions, certainties, None, np.array(activations_tracking), np.array(activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
|
| 184 |
+
return predictions, certainties, None
|
models/lstm_rl.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F # Used for GLU if not in modules
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
# Local imports (Assuming these contain necessary custom modules)
|
| 8 |
+
from models.modules import *
|
| 9 |
+
from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LSTMBaseline(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
LSTM Baseline
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
iterations (int): Number of internal 'thought' steps (T, in paper).
|
| 19 |
+
d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
|
| 20 |
+
d_input (int): Dimensionality of projected attention outputs or direct input features.
|
| 21 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self,
|
| 25 |
+
iterations,
|
| 26 |
+
d_model,
|
| 27 |
+
d_input,
|
| 28 |
+
backbone_type,
|
| 29 |
+
):
|
| 30 |
+
super(LSTMBaseline, self).__init__()
|
| 31 |
+
|
| 32 |
+
# --- Core Parameters ---
|
| 33 |
+
self.iterations = iterations
|
| 34 |
+
self.d_model = d_model
|
| 35 |
+
self.backbone_type = backbone_type
|
| 36 |
+
|
| 37 |
+
# --- Input Assertions ---
|
| 38 |
+
assert backbone_type in ('navigation-backbone', 'classic-control-backbone'), f"Invalid backbone_type: {backbone_type}"
|
| 39 |
+
|
| 40 |
+
# --- Backbone / Feature Extraction ---
|
| 41 |
+
if self.backbone_type == 'navigation-backbone':
|
| 42 |
+
grid_size = 7
|
| 43 |
+
self.backbone = MiniGridBackbone(d_input=d_input, grid_size=grid_size)
|
| 44 |
+
lstm_cell_input_dim = grid_size * grid_size * d_input
|
| 45 |
+
|
| 46 |
+
elif self.backbone_type == 'classic-control-backbone':
|
| 47 |
+
self.backbone = ClassicControlBackbone(d_input=d_input)
|
| 48 |
+
lstm_cell_input_dim = d_input
|
| 49 |
+
|
| 50 |
+
else:
|
| 51 |
+
raise NotImplemented('The only backbone supported for RL are for navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')
|
| 52 |
+
|
| 53 |
+
# --- Core LSTM Modules ---
|
| 54 |
+
self.lstm_cell = nn.LSTMCell(lstm_cell_input_dim, d_model)
|
| 55 |
+
self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 56 |
+
self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 57 |
+
|
| 58 |
+
def compute_features(self, x):
|
| 59 |
+
"""Applies backbone and positional embedding to input."""
|
| 60 |
+
return self.backbone(x)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def forward(self, x, hidden_states, track=False):
|
| 64 |
+
"""
|
| 65 |
+
Forward pass - Reverted to structure closer to user's working version.
|
| 66 |
+
Executes T=iterations steps.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
# --- Tracking Initialization ---
|
| 70 |
+
activations_tracking = []
|
| 71 |
+
|
| 72 |
+
# --- Featurise Input Data ---
|
| 73 |
+
features = self.compute_features(x)
|
| 74 |
+
|
| 75 |
+
hidden_state = hidden_states[0]
|
| 76 |
+
cell_state = hidden_states[1]
|
| 77 |
+
|
| 78 |
+
# --- Recurrent Loop ---
|
| 79 |
+
for stepi in range(self.iterations):
|
| 80 |
+
|
| 81 |
+
lstm_input = features.reshape(x.size(0), -1)
|
| 82 |
+
hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state))
|
| 83 |
+
|
| 84 |
+
# --- Tracking ---
|
| 85 |
+
if track:
|
| 86 |
+
activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
|
| 87 |
+
|
| 88 |
+
hidden_states = (
|
| 89 |
+
hidden_state,
|
| 90 |
+
cell_state
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# --- Return Values ---
|
| 94 |
+
if track:
|
| 95 |
+
return hidden_state, hidden_states, np.array(activations_tracking), np.array(activations_tracking)
|
| 96 |
+
return hidden_state, hidden_states
|
models/modules.py
ADDED
|
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F # Used for GLU
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
# Assuming 'add_coord_dim' is defined in models.utils
|
| 8 |
+
from models.utils import add_coord_dim
|
| 9 |
+
|
| 10 |
+
# --- Basic Utility Modules ---
|
| 11 |
+
|
| 12 |
+
class Identity(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
Identity Module.
|
| 15 |
+
|
| 16 |
+
Returns the input tensor unchanged. Useful as a placeholder or a no-op layer
|
| 17 |
+
in nn.Sequential containers or conditional network parts.
|
| 18 |
+
"""
|
| 19 |
+
def __init__(self):
|
| 20 |
+
super().__init__()
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
return x
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Squeeze(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
Squeeze Module.
|
| 29 |
+
|
| 30 |
+
Removes a specified dimension of size 1 from the input tensor.
|
| 31 |
+
Useful for incorporating tensor dimension squeezing within nn.Sequential.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
dim (int): The dimension to squeeze.
|
| 35 |
+
"""
|
| 36 |
+
def __init__(self, dim):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.dim = dim
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
return x.squeeze(self.dim)
|
| 42 |
+
|
| 43 |
+
# --- Core CTM Component Modules ---
|
| 44 |
+
|
| 45 |
+
class SynapseUNET(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
UNET-style architecture for the Synapse Model (f_theta1 in the paper).
|
| 48 |
+
|
| 49 |
+
This module implements the connections between neurons in the CTM's latent
|
| 50 |
+
space. It processes the combined input (previous post-activation state z^t
|
| 51 |
+
and attention output o^t) to produce the pre-activations (a^t) for the
|
| 52 |
+
next internal tick (Eq. 1 in the paper).
|
| 53 |
+
|
| 54 |
+
While a simpler Linear or MLP layer can be used, the paper notes
|
| 55 |
+
that this U-Net structure empirically performed better, suggesting benefit
|
| 56 |
+
from more flexible synaptic connections[cite: 79, 80]. This implementation
|
| 57 |
+
uses `depth` points in linspace and creates `depth-1` down/up blocks.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
in_dims (int): Number of input dimensions (d_model + d_input).
|
| 61 |
+
out_dims (int): Number of output dimensions (d_model).
|
| 62 |
+
depth (int): Determines structure size; creates `depth-1` down/up blocks.
|
| 63 |
+
minimum_width (int): Smallest channel width at the U-Net bottleneck.
|
| 64 |
+
dropout (float): Dropout rate applied within down/up projections.
|
| 65 |
+
"""
|
| 66 |
+
def __init__(self,
|
| 67 |
+
out_dims,
|
| 68 |
+
depth,
|
| 69 |
+
minimum_width=16,
|
| 70 |
+
dropout=0.0):
|
| 71 |
+
super().__init__()
|
| 72 |
+
self.width_out = out_dims
|
| 73 |
+
self.n_deep = depth # Store depth just for reference if needed
|
| 74 |
+
|
| 75 |
+
# Define UNET structure based on depth
|
| 76 |
+
# Creates `depth` width values, leading to `depth-1` blocks
|
| 77 |
+
widths = np.linspace(out_dims, minimum_width, depth)
|
| 78 |
+
|
| 79 |
+
# Initial projection layer
|
| 80 |
+
self.first_projection = nn.Sequential(
|
| 81 |
+
nn.LazyLinear(int(widths[0])), # Project to the first width
|
| 82 |
+
nn.LayerNorm(int(widths[0])),
|
| 83 |
+
nn.SiLU()
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Downward path (encoding layers)
|
| 87 |
+
self.down_projections = nn.ModuleList()
|
| 88 |
+
self.up_projections = nn.ModuleList()
|
| 89 |
+
self.skip_lns = nn.ModuleList()
|
| 90 |
+
num_blocks = len(widths) - 1 # Number of down/up blocks created
|
| 91 |
+
|
| 92 |
+
for i in range(num_blocks):
|
| 93 |
+
# Down block: widths[i] -> widths[i+1]
|
| 94 |
+
self.down_projections.append(nn.Sequential(
|
| 95 |
+
nn.Dropout(dropout),
|
| 96 |
+
nn.Linear(int(widths[i]), int(widths[i+1])),
|
| 97 |
+
nn.LayerNorm(int(widths[i+1])),
|
| 98 |
+
nn.SiLU()
|
| 99 |
+
))
|
| 100 |
+
# Up block: widths[i+1] -> widths[i]
|
| 101 |
+
# Note: Up blocks are added in order matching down blocks conceptually,
|
| 102 |
+
# but applied in reverse order in the forward pass.
|
| 103 |
+
self.up_projections.append(nn.Sequential(
|
| 104 |
+
nn.Dropout(dropout),
|
| 105 |
+
nn.Linear(int(widths[i+1]), int(widths[i])),
|
| 106 |
+
nn.LayerNorm(int(widths[i])),
|
| 107 |
+
nn.SiLU()
|
| 108 |
+
))
|
| 109 |
+
# Skip connection LayerNorm operates on width[i]
|
| 110 |
+
self.skip_lns.append(nn.LayerNorm(int(widths[i])))
|
| 111 |
+
|
| 112 |
+
def forward(self, x):
|
| 113 |
+
# Initial projection
|
| 114 |
+
out_first = self.first_projection(x)
|
| 115 |
+
|
| 116 |
+
# Downward path, storing outputs for skip connections
|
| 117 |
+
outs_down = [out_first]
|
| 118 |
+
for layer in self.down_projections:
|
| 119 |
+
outs_down.append(layer(outs_down[-1]))
|
| 120 |
+
# outs_down contains [level_0, level_1, ..., level_depth-1=bottleneck] outputs
|
| 121 |
+
|
| 122 |
+
# Upward path, starting from the bottleneck output
|
| 123 |
+
outs_up = outs_down[-1] # Bottleneck activation
|
| 124 |
+
num_blocks = len(self.up_projections) # Should be depth - 1
|
| 125 |
+
|
| 126 |
+
for i in range(num_blocks):
|
| 127 |
+
# Apply up projection in reverse order relative to down blocks
|
| 128 |
+
# up_projection[num_blocks - 1 - i] processes deeper features first
|
| 129 |
+
up_layer_idx = num_blocks - 1 - i
|
| 130 |
+
out_up = self.up_projections[up_layer_idx](outs_up)
|
| 131 |
+
|
| 132 |
+
# Get corresponding skip connection from downward path
|
| 133 |
+
# skip_connection index = num_blocks - 1 - i (same as up_layer_idx)
|
| 134 |
+
# This matches the output width of the up_projection[up_layer_idx]
|
| 135 |
+
skip_idx = up_layer_idx
|
| 136 |
+
skip_connection = outs_down[skip_idx]
|
| 137 |
+
|
| 138 |
+
# Add skip connection and apply LayerNorm corresponding to this level
|
| 139 |
+
# skip_lns index also corresponds to the level = skip_idx
|
| 140 |
+
outs_up = self.skip_lns[skip_idx](out_up + skip_connection)
|
| 141 |
+
|
| 142 |
+
# The final output after all up-projections
|
| 143 |
+
return outs_up
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class SuperLinear(nn.Module):
|
| 147 |
+
"""
|
| 148 |
+
SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM.
|
| 149 |
+
|
| 150 |
+
This layer is the core component enabling Neuron-Level Models (NLMs),
|
| 151 |
+
referred to as g_theta_d in the paper (Eq. 3). It applies N independent
|
| 152 |
+
linear transformations (or small MLPs when used sequentially) to corresponding
|
| 153 |
+
slices of the input tensor along a specified dimension (typically the neuron
|
| 154 |
+
or feature dimension).
|
| 155 |
+
|
| 156 |
+
How it works for NLMs:
|
| 157 |
+
- The input `x` is expected to be the pre-activation history for each neuron,
|
| 158 |
+
shaped (batch_size, n_neurons=N, history_length=in_dims).
|
| 159 |
+
- This layer holds unique weights (`w1`) and biases (`b1`) for *each* of the `N` neurons.
|
| 160 |
+
`w1` has shape (in_dims, out_dims, N), `b1` has shape (1, N, out_dims).
|
| 161 |
+
- `torch.einsum('bni,iog->bno', x, self.w1)` performs N independent matrix
|
| 162 |
+
multiplications in parallel (mapping from dim `i` to `o` for each neuron `n`):
|
| 163 |
+
- For each neuron `n` (from 0 to N-1):
|
| 164 |
+
- It takes the neuron's history `x[:, n, :]` (shape B, in_dims).
|
| 165 |
+
- Multiplies it by the neuron's unique weight matrix `self.w1[:, :, n]` (shape in_dims, out_dims).
|
| 166 |
+
- Resulting in `out[:, n, :]` (shape B, out_dims).
|
| 167 |
+
- The unique bias `self.b1[:, n, :]` is added.
|
| 168 |
+
- The result is squeezed on the last dim (if out_dims=1) and scaled by `T`.
|
| 169 |
+
|
| 170 |
+
This allows each neuron `d` to process its temporal history `A_d^t` using
|
| 171 |
+
its private parameters `theta_d` to produce the post-activation `z_d^{t+1}`,
|
| 172 |
+
enabling the fine-grained temporal dynamics central to the CTM[cite: 7, 30, 85].
|
| 173 |
+
It's typically used within the `trace_processor` module of the main CTM class.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
in_dims (int): Input dimension (typically `memory_length`).
|
| 177 |
+
out_dims (int): Output dimension per neuron.
|
| 178 |
+
N (int): Number of independent linear models (typically `d_model`).
|
| 179 |
+
T (float): Initial value for learnable temperature/scaling factor applied to output.
|
| 180 |
+
do_norm (bool): Apply Layer Normalization to the input history before linear transform.
|
| 181 |
+
dropout (float): Dropout rate applied to the input.
|
| 182 |
+
"""
|
| 183 |
+
def __init__(self,
|
| 184 |
+
in_dims,
|
| 185 |
+
out_dims,
|
| 186 |
+
N,
|
| 187 |
+
T=1.0,
|
| 188 |
+
do_norm=False,
|
| 189 |
+
dropout=0):
|
| 190 |
+
super().__init__()
|
| 191 |
+
# N is the number of neurons (d_model), in_dims is the history length (memory_length)
|
| 192 |
+
self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()
|
| 193 |
+
self.in_dims = in_dims # Corresponds to memory_length
|
| 194 |
+
# LayerNorm applied across the history dimension for each neuron independently
|
| 195 |
+
self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True) if do_norm else Identity()
|
| 196 |
+
self.do_norm = do_norm
|
| 197 |
+
|
| 198 |
+
# Initialize weights and biases
|
| 199 |
+
# w1 shape: (memory_length, out_dims, d_model)
|
| 200 |
+
self.register_parameter('w1', nn.Parameter(
|
| 201 |
+
torch.empty((in_dims, out_dims, N)).uniform_(
|
| 202 |
+
-1/math.sqrt(in_dims + out_dims),
|
| 203 |
+
1/math.sqrt(in_dims + out_dims)
|
| 204 |
+
), requires_grad=True)
|
| 205 |
+
)
|
| 206 |
+
# b1 shape: (1, d_model, out_dims)
|
| 207 |
+
self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))
|
| 208 |
+
# Learnable temperature/scaler T
|
| 209 |
+
self.register_parameter('T', nn.Parameter(torch.Tensor([T])))
|
| 210 |
+
|
| 211 |
+
def forward(self, x):
|
| 212 |
+
"""
|
| 213 |
+
Args:
|
| 214 |
+
x (torch.Tensor): Input tensor, expected shape (B, N, in_dims)
|
| 215 |
+
where B=batch, N=d_model, in_dims=memory_length.
|
| 216 |
+
Returns:
|
| 217 |
+
torch.Tensor: Output tensor, shape (B, N) after squeeze(-1).
|
| 218 |
+
"""
|
| 219 |
+
# Input shape: (B, D, M) where D=d_model=N neurons in CTM, M=history/memory length
|
| 220 |
+
out = self.dropout(x)
|
| 221 |
+
# LayerNorm across the memory_length dimension (dim=-1)
|
| 222 |
+
out = self.layernorm(out) # Shape remains (B, N, M)
|
| 223 |
+
|
| 224 |
+
# Apply N independent linear models using einsum
|
| 225 |
+
# einsum('BDM,MHD->BDH', ...)
|
| 226 |
+
# x: (B=batch size, D=N neurons, one NLM per each of these, M=history/memory length)
|
| 227 |
+
# w1: (M, H=hidden dims if using MLP, otherwise output, D=N neurons, parallel)
|
| 228 |
+
# b1: (1, D=N neurons, H)
|
| 229 |
+
# einsum result: (B, D, H)
|
| 230 |
+
# Applying bias requires matching shapes, b1 is broadcasted.
|
| 231 |
+
out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1
|
| 232 |
+
|
| 233 |
+
# Squeeze the output dimension (assumed to be 1 usually) and scale by T
|
| 234 |
+
# This matches the original code's structure exactly.
|
| 235 |
+
out = out.squeeze(-1) / self.T
|
| 236 |
+
return out
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# --- Backbone Modules ---
|
| 240 |
+
|
| 241 |
+
class ParityBackbone(nn.Module):
|
| 242 |
+
def __init__(self, n_embeddings, d_embedding):
|
| 243 |
+
super(ParityBackbone, self).__init__()
|
| 244 |
+
self.embedding = nn.Embedding(n_embeddings, d_embedding)
|
| 245 |
+
|
| 246 |
+
def forward(self, x):
|
| 247 |
+
"""
|
| 248 |
+
Maps -1 (negative parity) to 0 and 1 (positive) to 1
|
| 249 |
+
"""
|
| 250 |
+
x = (x == 1).long()
|
| 251 |
+
return self.embedding(x.long()).transpose(1, 2) # Transpose for compatibility with other backbones
|
| 252 |
+
|
| 253 |
+
class QAMNISTOperatorEmbeddings(nn.Module):
|
| 254 |
+
def __init__(self, num_operator_types, d_projection):
|
| 255 |
+
super(QAMNISTOperatorEmbeddings, self).__init__()
|
| 256 |
+
self.embedding = nn.Embedding(num_operator_types, d_projection)
|
| 257 |
+
|
| 258 |
+
def forward(self, x):
|
| 259 |
+
# -1 for plus and -2 for minus
|
| 260 |
+
return self.embedding(-x - 1)
|
| 261 |
+
|
| 262 |
+
class QAMNISTIndexEmbeddings(torch.nn.Module):
|
| 263 |
+
def __init__(self, max_seq_length, embedding_dim):
|
| 264 |
+
super().__init__()
|
| 265 |
+
self.max_seq_length = max_seq_length
|
| 266 |
+
self.embedding_dim = embedding_dim
|
| 267 |
+
|
| 268 |
+
embedding = torch.zeros(max_seq_length, embedding_dim)
|
| 269 |
+
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
|
| 270 |
+
div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
|
| 271 |
+
|
| 272 |
+
embedding[:, 0::2] = torch.sin(position * div_term)
|
| 273 |
+
embedding[:, 1::2] = torch.cos(position * div_term)
|
| 274 |
+
|
| 275 |
+
self.register_buffer('embedding', embedding)
|
| 276 |
+
|
| 277 |
+
def forward(self, x):
|
| 278 |
+
return self.embedding[x]
|
| 279 |
+
|
| 280 |
+
class ThoughtSteps:
|
| 281 |
+
"""
|
| 282 |
+
Helper class for managing "thought steps" in the ctm_qamnist pipeline.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
iterations_per_digit (int): Number of iterations for each digit.
|
| 286 |
+
iterations_per_question_part (int): Number of iterations for each question part.
|
| 287 |
+
total_iterations_for_answering (int): Total number of iterations for answering.
|
| 288 |
+
total_iterations_for_digits (int): Total number of iterations for digits.
|
| 289 |
+
total_iterations_for_question (int): Total number of iterations for question.
|
| 290 |
+
"""
|
| 291 |
+
def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question):
|
| 292 |
+
self.iterations_per_digit = iterations_per_digit
|
| 293 |
+
self.iterations_per_question_part = iterations_per_question_part
|
| 294 |
+
self.total_iterations_for_digits = total_iterations_for_digits
|
| 295 |
+
self.total_iterations_for_question = total_iterations_for_question
|
| 296 |
+
self.total_iterations_for_answering = total_iterations_for_answering
|
| 297 |
+
self.total_iterations = self.total_iterations_for_digits + self.total_iterations_for_question + self.total_iterations_for_answering
|
| 298 |
+
|
| 299 |
+
def determine_step_type(self, stepi: int):
|
| 300 |
+
is_digit_step = stepi < self.total_iterations_for_digits
|
| 301 |
+
is_question_step = self.total_iterations_for_digits <= stepi < self.total_iterations_for_digits + self.total_iterations_for_question
|
| 302 |
+
is_answer_step = stepi >= self.total_iterations_for_digits + self.total_iterations_for_question
|
| 303 |
+
return is_digit_step, is_question_step, is_answer_step
|
| 304 |
+
|
| 305 |
+
def determine_answer_step_type(self, stepi: int):
|
| 306 |
+
step_within_questions = stepi - self.total_iterations_for_digits
|
| 307 |
+
if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
|
| 308 |
+
is_index_step = True
|
| 309 |
+
is_operator_step = False
|
| 310 |
+
else:
|
| 311 |
+
is_index_step = False
|
| 312 |
+
is_operator_step = True
|
| 313 |
+
return is_index_step, is_operator_step
|
| 314 |
+
|
| 315 |
+
class MNISTBackbone(nn.Module):
|
| 316 |
+
"""
|
| 317 |
+
Simple backbone for MNIST feature extraction.
|
| 318 |
+
"""
|
| 319 |
+
def __init__(self, d_input):
|
| 320 |
+
super(MNISTBackbone, self).__init__()
|
| 321 |
+
self.layers = nn.Sequential(
|
| 322 |
+
nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
|
| 323 |
+
nn.BatchNorm2d(d_input),
|
| 324 |
+
nn.ReLU(),
|
| 325 |
+
nn.MaxPool2d(2, 2),
|
| 326 |
+
nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
|
| 327 |
+
nn.BatchNorm2d(d_input),
|
| 328 |
+
nn.ReLU(),
|
| 329 |
+
nn.MaxPool2d(2, 2),
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
def forward(self, x):
|
| 333 |
+
return self.layers(x)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class MiniGridBackbone(nn.Module):
|
| 337 |
+
def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8):
|
| 338 |
+
super().__init__()
|
| 339 |
+
self.object_embedding = nn.Embedding(num_objects, embedding_dim)
|
| 340 |
+
self.color_embedding = nn.Embedding(num_colors, embedding_dim)
|
| 341 |
+
self.state_embedding = nn.Embedding(num_states, embedding_dim)
|
| 342 |
+
|
| 343 |
+
self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim)
|
| 344 |
+
|
| 345 |
+
self.project_to_d_projection = nn.Sequential(
|
| 346 |
+
nn.Linear(embedding_dim * 4, d_input * 2),
|
| 347 |
+
nn.GLU(),
|
| 348 |
+
nn.LayerNorm(d_input),
|
| 349 |
+
nn.Linear(d_input, d_input * 2),
|
| 350 |
+
nn.GLU(),
|
| 351 |
+
nn.LayerNorm(d_input)
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
def forward(self, x):
|
| 355 |
+
x = x.long()
|
| 356 |
+
B, H, W, C = x.size()
|
| 357 |
+
|
| 358 |
+
object_idx = x[:,:,:, 0]
|
| 359 |
+
color_idx = x[:,:,:, 1]
|
| 360 |
+
state_idx = x[:,:,:, 2]
|
| 361 |
+
|
| 362 |
+
obj_embed = self.object_embedding(object_idx)
|
| 363 |
+
color_embed = self.color_embedding(color_idx)
|
| 364 |
+
state_embed = self.state_embedding(state_idx)
|
| 365 |
+
|
| 366 |
+
pos_idx = torch.arange(H * W, device=x.device).view(1, H, W).expand(B, -1, -1)
|
| 367 |
+
pos_embed = self.position_embedding(pos_idx)
|
| 368 |
+
|
| 369 |
+
out = self.project_to_d_projection(torch.cat([obj_embed, color_embed, state_embed, pos_embed], dim=-1))
|
| 370 |
+
return out
|
| 371 |
+
|
| 372 |
+
class ClassicControlBackbone(nn.Module):
|
| 373 |
+
def __init__(self, d_input):
|
| 374 |
+
super().__init__()
|
| 375 |
+
self.input_projector = nn.Sequential(
|
| 376 |
+
nn.Flatten(),
|
| 377 |
+
nn.LazyLinear(d_input * 2),
|
| 378 |
+
nn.GLU(),
|
| 379 |
+
nn.LayerNorm(d_input),
|
| 380 |
+
nn.LazyLinear(d_input * 2),
|
| 381 |
+
nn.GLU(),
|
| 382 |
+
nn.LayerNorm(d_input)
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
def forward(self, x):
|
| 386 |
+
return self.input_projector(x)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class ShallowWide(nn.Module):
|
| 390 |
+
"""
|
| 391 |
+
Simple, wide, shallow convolutional backbone for image feature extraction.
|
| 392 |
+
|
| 393 |
+
Alternative to ResNet, uses grouped convolutions and GLU activations.
|
| 394 |
+
Fixed structure, useful for specific experiments.
|
| 395 |
+
"""
|
| 396 |
+
def __init__(self):
|
| 397 |
+
super(ShallowWide, self).__init__()
|
| 398 |
+
# LazyConv2d infers input channels
|
| 399 |
+
self.layers = nn.Sequential(
|
| 400 |
+
nn.LazyConv2d(4096, kernel_size=3, stride=2, padding=1), # Output channels = 4096
|
| 401 |
+
nn.GLU(dim=1), # Halves channels to 2048
|
| 402 |
+
nn.BatchNorm2d(2048),
|
| 403 |
+
# Grouped convolution maintains width but processes groups independently
|
| 404 |
+
nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1, groups=32),
|
| 405 |
+
nn.GLU(dim=1), # Halves channels to 2048
|
| 406 |
+
nn.BatchNorm2d(2048)
|
| 407 |
+
)
|
| 408 |
+
def forward(self, x):
|
| 409 |
+
return self.layers(x)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class PretrainedResNetWrapper(nn.Module):
|
| 413 |
+
"""
|
| 414 |
+
Wrapper to use standard pre-trained ResNet models from torchvision.
|
| 415 |
+
|
| 416 |
+
Loads a specified ResNet architecture pre-trained on ImageNet, removes the
|
| 417 |
+
final classification layer (fc), average pooling, and optionally later layers
|
| 418 |
+
(e.g., layer4), allowing it to be used as a feature extractor backbone.
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
resnet_type (str): Name of the ResNet model (e.g., 'resnet18', 'resnet50').
|
| 422 |
+
fine_tune (bool): If False, freezes the weights of the pre-trained backbone.
|
| 423 |
+
"""
|
| 424 |
+
def __init__(self, resnet_type, fine_tune=True):
|
| 425 |
+
super(PretrainedResNetWrapper, self).__init__()
|
| 426 |
+
self.resnet_type = resnet_type
|
| 427 |
+
self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)
|
| 428 |
+
|
| 429 |
+
if not fine_tune:
|
| 430 |
+
for param in self.backbone.parameters():
|
| 431 |
+
param.requires_grad = False
|
| 432 |
+
|
| 433 |
+
# Remove final layers to use as feature extractor
|
| 434 |
+
self.backbone.avgpool = Identity()
|
| 435 |
+
self.backbone.fc = Identity()
|
| 436 |
+
# Keep layer4 by default, user can modify instance if needed
|
| 437 |
+
# self.backbone.layer4 = Identity()
|
| 438 |
+
|
| 439 |
+
def forward(self, x):
|
| 440 |
+
# Get features from the modified ResNet
|
| 441 |
+
out = self.backbone(x)
|
| 442 |
+
|
| 443 |
+
# Reshape output to (B, C, H, W) - This is heuristic based on original comment.
|
| 444 |
+
# User might need to adjust this based on which layers are kept/removed.
|
| 445 |
+
# Infer C based on ResNet type (example values)
|
| 446 |
+
nc = 256 if ('18' in self.resnet_type or '34' in self.resnet_type) else 512 if '50' in self.resnet_type else 1024 if '101' in self.resnet_type else 2048 # Approx for layer3/4 output channel numbers
|
| 447 |
+
# Infer H, W assuming output is flattened C * H * W
|
| 448 |
+
num_features = out.shape[-1]
|
| 449 |
+
# This calculation assumes nc is correct and feature map is square
|
| 450 |
+
wh_squared = num_features / nc
|
| 451 |
+
if wh_squared < 0 or not float(wh_squared).is_integer():
|
| 452 |
+
print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}")
|
| 453 |
+
# Return potentially flattened features if reshape fails
|
| 454 |
+
return out
|
| 455 |
+
wh = int(np.sqrt(wh_squared))
|
| 456 |
+
|
| 457 |
+
return out.reshape(x.size(0), nc, wh, wh)
|
| 458 |
+
|
| 459 |
+
# --- Positional Encoding Modules ---
|
| 460 |
+
|
| 461 |
+
class LearnableFourierPositionalEncoding(nn.Module):
|
| 462 |
+
"""
|
| 463 |
+
Learnable Fourier Feature Positional Encoding.
|
| 464 |
+
|
| 465 |
+
Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
|
| 466 |
+
Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
|
| 467 |
+
Provides positional information for 2D feature maps.
|
| 468 |
+
|
| 469 |
+
Args:
|
| 470 |
+
d_model (int): The output dimension of the positional encoding (D).
|
| 471 |
+
G (int): Positional groups (default 1).
|
| 472 |
+
M (int): Dimensionality of input coordinates (default 2 for H, W).
|
| 473 |
+
F_dim (int): Dimension of the Fourier features.
|
| 474 |
+
H_dim (int): Hidden dimension of the MLP.
|
| 475 |
+
gamma (float): Initialization scale for the Fourier projection weights (Wr).
|
| 476 |
+
"""
|
| 477 |
+
def __init__(self, d_model,
|
| 478 |
+
G=1, M=2,
|
| 479 |
+
F_dim=256,
|
| 480 |
+
H_dim=128,
|
| 481 |
+
gamma=1/2.5,
|
| 482 |
+
):
|
| 483 |
+
super().__init__()
|
| 484 |
+
self.G = G
|
| 485 |
+
self.M = M
|
| 486 |
+
self.F_dim = F_dim
|
| 487 |
+
self.H_dim = H_dim
|
| 488 |
+
self.D = d_model
|
| 489 |
+
self.gamma = gamma
|
| 490 |
+
|
| 491 |
+
self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
|
| 492 |
+
self.mlp = nn.Sequential(
|
| 493 |
+
nn.Linear(self.F_dim, self.H_dim, bias=True),
|
| 494 |
+
nn.GLU(), # Halves H_dim
|
| 495 |
+
nn.Linear(self.H_dim // 2, self.D // self.G),
|
| 496 |
+
nn.LayerNorm(self.D // self.G)
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
self.init_weights()
|
| 500 |
+
|
| 501 |
+
def init_weights(self):
|
| 502 |
+
nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
|
| 503 |
+
|
| 504 |
+
def forward(self, x):
|
| 505 |
+
"""
|
| 506 |
+
Computes positional encodings for the input feature map x.
|
| 507 |
+
|
| 508 |
+
Args:
|
| 509 |
+
x (torch.Tensor): Input feature map, shape (B, C, H, W).
|
| 510 |
+
|
| 511 |
+
Returns:
|
| 512 |
+
torch.Tensor: Positional encoding tensor, shape (B, D, H, W).
|
| 513 |
+
"""
|
| 514 |
+
B, C, H, W = x.shape
|
| 515 |
+
# Creates coordinates based on (H, W) and repeats for batch B.
|
| 516 |
+
# Takes x[:,0] assuming channel dim isn't needed for coords.
|
| 517 |
+
x_coord = add_coord_dim(x[:,0]) # Expects (B, H, W) -> (B, H, W, 2)
|
| 518 |
+
|
| 519 |
+
# Compute Fourier features
|
| 520 |
+
projected = self.Wr(x_coord) # (B, H, W, F_dim // 2)
|
| 521 |
+
cosines = torch.cos(projected)
|
| 522 |
+
sines = torch.sin(projected)
|
| 523 |
+
F = (1.0 / math.sqrt(self.F_dim)) * torch.cat([cosines, sines], dim=-1) # (B, H, W, F_dim)
|
| 524 |
+
|
| 525 |
+
# Project features through MLP
|
| 526 |
+
Y = self.mlp(F) # (B, H, W, D // G)
|
| 527 |
+
|
| 528 |
+
# Reshape to (B, D, H, W)
|
| 529 |
+
PEx = Y.permute(0, 3, 1, 2) # Assuming G=1
|
| 530 |
+
return PEx
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
class MultiLearnableFourierPositionalEncoding(nn.Module):
|
| 534 |
+
"""
|
| 535 |
+
Combines multiple LearnableFourierPositionalEncoding modules with different
|
| 536 |
+
initialization scales (gamma) via a learnable weighted sum.
|
| 537 |
+
|
| 538 |
+
Allows the model to learn an optimal combination of positional frequencies.
|
| 539 |
+
|
| 540 |
+
Args:
|
| 541 |
+
d_model (int): Output dimension of the encoding.
|
| 542 |
+
G, M, F_dim, H_dim: Parameters passed to underlying LearnableFourierPositionalEncoding.
|
| 543 |
+
gamma_range (list[float]): Min and max gamma values for the linspace.
|
| 544 |
+
N (int): Number of parallel embedding modules to create.
|
| 545 |
+
"""
|
| 546 |
+
def __init__(self, d_model,
|
| 547 |
+
G=1, M=2,
|
| 548 |
+
F_dim=256,
|
| 549 |
+
H_dim=128,
|
| 550 |
+
gamma_range=[1.0, 0.1], # Default range
|
| 551 |
+
N=10,
|
| 552 |
+
):
|
| 553 |
+
super().__init__()
|
| 554 |
+
self.embedders = nn.ModuleList()
|
| 555 |
+
for gamma in np.linspace(gamma_range[0], gamma_range[1], N):
|
| 556 |
+
self.embedders.append(LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma))
|
| 557 |
+
|
| 558 |
+
# Renamed parameter from 'combination' to 'combination_weights' for clarity only in comments
|
| 559 |
+
# Actual registered name remains 'combination' as in original code
|
| 560 |
+
self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True))
|
| 561 |
+
self.N = N
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def forward(self, x):
|
| 565 |
+
"""
|
| 566 |
+
Computes combined positional encoding.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
x (torch.Tensor): Input feature map, shape (B, C, H, W).
|
| 570 |
+
|
| 571 |
+
Returns:
|
| 572 |
+
torch.Tensor: Combined positional encoding tensor, shape (B, D, H, W).
|
| 573 |
+
"""
|
| 574 |
+
# Compute embeddings from all modules and stack: (N, B, D, H, W)
|
| 575 |
+
pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0)
|
| 576 |
+
|
| 577 |
+
# Compute combination weights using softmax
|
| 578 |
+
# Use registered parameter name 'combination'
|
| 579 |
+
# Reshape weights for broadcasting: (N,) -> (N, 1, 1, 1, 1)
|
| 580 |
+
weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1)
|
| 581 |
+
|
| 582 |
+
# Compute weighted sum over the N dimension
|
| 583 |
+
combined_emb = (pos_embs * weights).sum(0) # (B, D, H, W)
|
| 584 |
+
return combined_emb
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
class CustomRotationalEmbedding(nn.Module):
|
| 588 |
+
"""
|
| 589 |
+
Custom Rotational Positional Embedding.
|
| 590 |
+
|
| 591 |
+
Generates 2D positional embeddings based on rotating a fixed start vector.
|
| 592 |
+
The rotation angle for each grid position is determined primarily by its
|
| 593 |
+
horizontal position (width dimension). The resulting rotated vectors are
|
| 594 |
+
concatenated and projected.
|
| 595 |
+
|
| 596 |
+
Note: The current implementation derives angles only from the width dimension (`x.size(-1)`).
|
| 597 |
+
|
| 598 |
+
Args:
|
| 599 |
+
d_model (int): Dimensionality of the output embeddings.
|
| 600 |
+
"""
|
| 601 |
+
def __init__(self, d_model):
|
| 602 |
+
super(CustomRotationalEmbedding, self).__init__()
|
| 603 |
+
# Learnable 2D start vector
|
| 604 |
+
self.register_parameter('start_vector', nn.Parameter(torch.Tensor([0, 1]), requires_grad=True))
|
| 605 |
+
# Projects the 4D concatenated rotated vectors to d_model
|
| 606 |
+
# Input size 4 comes from concatenating two 2D rotated vectors
|
| 607 |
+
self.projection = nn.Sequential(nn.Linear(4, d_model))
|
| 608 |
+
|
| 609 |
+
def forward(self, x):
|
| 610 |
+
"""
|
| 611 |
+
Computes rotational positional embeddings based on input width.
|
| 612 |
+
|
| 613 |
+
Args:
|
| 614 |
+
x (torch.Tensor): Input tensor (used for shape and device),
|
| 615 |
+
shape (batch_size, channels, height, width).
|
| 616 |
+
Returns:
|
| 617 |
+
Output tensor containing positional embeddings,
|
| 618 |
+
shape (1, d_model, height, width) - Batch dim is 1 as PE is same for all.
|
| 619 |
+
"""
|
| 620 |
+
B, C, H, W = x.shape
|
| 621 |
+
device = x.device
|
| 622 |
+
|
| 623 |
+
# --- Generate rotations based only on Width ---
|
| 624 |
+
# Angles derived from width dimension
|
| 625 |
+
theta_rad = torch.deg2rad(torch.linspace(0, 180, W, device=device)) # Angle per column
|
| 626 |
+
cos_theta = torch.cos(theta_rad)
|
| 627 |
+
sin_theta = torch.sin(theta_rad)
|
| 628 |
+
|
| 629 |
+
# Create rotation matrices: Shape (W, 2, 2)
|
| 630 |
+
# Use unsqueeze(1) to allow stacking along dim 1
|
| 631 |
+
rotation_matrices = torch.stack([
|
| 632 |
+
torch.stack([cos_theta, -sin_theta], dim=-1), # Shape (W, 2)
|
| 633 |
+
torch.stack([sin_theta, cos_theta], dim=-1) # Shape (W, 2)
|
| 634 |
+
], dim=1) # Stacks along dim 1 -> Shape (W, 2, 2)
|
| 635 |
+
|
| 636 |
+
# Rotate the start vector by column angle: Shape (W, 2)
|
| 637 |
+
rotated_vectors = torch.einsum('wij,j->wi', rotation_matrices, self.start_vector)
|
| 638 |
+
|
| 639 |
+
# --- Create Grid Key ---
|
| 640 |
+
# Original code uses repeats based on rotated_vectors.shape[0] (which is W) for both dimensions.
|
| 641 |
+
# This creates a (W, W, 4) key tensor.
|
| 642 |
+
key = torch.cat((
|
| 643 |
+
torch.repeat_interleave(rotated_vectors.unsqueeze(1), W, dim=1), # (W, 1, 2) -> (W, W, 2)
|
| 644 |
+
torch.repeat_interleave(rotated_vectors.unsqueeze(0), W, dim=0) # (1, W, 2) -> (W, W, 2)
|
| 645 |
+
), dim=-1) # Shape (W, W, 4)
|
| 646 |
+
|
| 647 |
+
# Project the 4D key vector to d_model: Shape (W, W, d_model)
|
| 648 |
+
pe_grid = self.projection(key)
|
| 649 |
+
|
| 650 |
+
# Reshape to (1, d_model, W, W) and then select/resize to target H, W?
|
| 651 |
+
# Original code permutes to (d_model, W, W) and unsqueezes to (1, d_model, W, W)
|
| 652 |
+
pe = pe_grid.permute(2, 0, 1).unsqueeze(0)
|
| 653 |
+
|
| 654 |
+
# If H != W, this needs adjustment. Assuming H=W or cropping/padding happens later.
|
| 655 |
+
# Let's return the (1, d_model, W, W) tensor as generated by the original logic.
|
| 656 |
+
# If H != W, downstream code must handle the mismatch or this PE needs modification.
|
| 657 |
+
if H != W:
|
| 658 |
+
# Simple interpolation/cropping could be added, but sticking to original logic:
|
| 659 |
+
# Option 1: Interpolate
|
| 660 |
+
# pe = F.interpolate(pe, size=(H, W), mode='bilinear', align_corners=False)
|
| 661 |
+
# Option 2: Crop/Pad (e.g., crop if W > W_target, pad if W < W_target)
|
| 662 |
+
# Sticking to original: return shape (1, d_model, W, W)
|
| 663 |
+
pass
|
| 664 |
+
|
| 665 |
+
return pe
|
| 666 |
+
|
| 667 |
+
class CustomRotationalEmbedding1D(nn.Module):
|
| 668 |
+
def __init__(self, d_model):
|
| 669 |
+
super(CustomRotationalEmbedding1D, self).__init__()
|
| 670 |
+
self.projection = nn.Linear(2, d_model)
|
| 671 |
+
|
| 672 |
+
def forward(self, x):
|
| 673 |
+
start_vector = torch.tensor([0., 1.], device=x.device, dtype=torch.float)
|
| 674 |
+
theta_rad = torch.deg2rad(torch.linspace(0, 180, x.size(2), device=x.device))
|
| 675 |
+
cos_theta = torch.cos(theta_rad)
|
| 676 |
+
sin_theta = torch.sin(theta_rad)
|
| 677 |
+
cos_theta = cos_theta.unsqueeze(1) # Shape: (height, 1)
|
| 678 |
+
sin_theta = sin_theta.unsqueeze(1) # Shape: (height, 1)
|
| 679 |
+
|
| 680 |
+
# Create rotation matrices
|
| 681 |
+
rotation_matrices = torch.stack([
|
| 682 |
+
torch.cat([cos_theta, -sin_theta], dim=1),
|
| 683 |
+
torch.cat([sin_theta, cos_theta], dim=1)
|
| 684 |
+
], dim=1) # Shape: (height, 2, 2)
|
| 685 |
+
|
| 686 |
+
# Rotate the start vector
|
| 687 |
+
rotated_vectors = torch.einsum('bij,j->bi', rotation_matrices, start_vector)
|
| 688 |
+
|
| 689 |
+
pe = self.projection(rotated_vectors)
|
| 690 |
+
pe = torch.repeat_interleave(pe.unsqueeze(0), x.size(0), 0)
|
| 691 |
+
return pe.transpose(1, 2) # Transpose for compatibility with other backbones
|
| 692 |
+
|
models/resnet.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import os
|
| 4 |
+
from models.modules import Identity
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"ResNet",
|
| 8 |
+
"resnet18",
|
| 9 |
+
"resnet34",
|
| 10 |
+
"resnet50",
|
| 11 |
+
"resnet101",
|
| 12 |
+
"resnet152",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
| 17 |
+
"""3x3 convolution with padding"""
|
| 18 |
+
return nn.Conv2d(
|
| 19 |
+
in_planes,
|
| 20 |
+
out_planes,
|
| 21 |
+
kernel_size=3,
|
| 22 |
+
stride=stride,
|
| 23 |
+
padding=dilation,
|
| 24 |
+
groups=groups,
|
| 25 |
+
bias=False,
|
| 26 |
+
dilation=dilation,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 31 |
+
"""1x1 convolution"""
|
| 32 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class BasicBlock(nn.Module):
|
| 36 |
+
expansion = 1
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
inplanes,
|
| 41 |
+
planes,
|
| 42 |
+
stride=1,
|
| 43 |
+
downsample=None,
|
| 44 |
+
groups=1,
|
| 45 |
+
base_width=64,
|
| 46 |
+
dilation=1,
|
| 47 |
+
norm_layer=None,
|
| 48 |
+
):
|
| 49 |
+
super(BasicBlock, self).__init__()
|
| 50 |
+
if norm_layer is None:
|
| 51 |
+
norm_layer = nn.BatchNorm2d
|
| 52 |
+
if groups != 1 or base_width != 64:
|
| 53 |
+
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
|
| 54 |
+
if dilation > 1:
|
| 55 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
| 56 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
| 57 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 58 |
+
self.bn1 = norm_layer(planes)
|
| 59 |
+
self.relu = nn.ReLU(inplace=True)
|
| 60 |
+
self.conv2 = conv3x3(planes, planes)
|
| 61 |
+
self.bn2 = norm_layer(planes)
|
| 62 |
+
self.downsample = downsample
|
| 63 |
+
self.stride = stride
|
| 64 |
+
|
| 65 |
+
def forward(self, x):
|
| 66 |
+
identity = x
|
| 67 |
+
|
| 68 |
+
out = self.conv1(x)
|
| 69 |
+
out = self.bn1(out)
|
| 70 |
+
out = self.relu(out)
|
| 71 |
+
|
| 72 |
+
out = self.conv2(out)
|
| 73 |
+
out = self.bn2(out)
|
| 74 |
+
|
| 75 |
+
if self.downsample is not None:
|
| 76 |
+
identity = self.downsample(x)
|
| 77 |
+
|
| 78 |
+
out += identity
|
| 79 |
+
|
| 80 |
+
out = self.relu(out)
|
| 81 |
+
return out
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class Bottleneck(nn.Module):
|
| 85 |
+
expansion = 4
|
| 86 |
+
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
inplanes,
|
| 90 |
+
planes,
|
| 91 |
+
stride=1,
|
| 92 |
+
downsample=None,
|
| 93 |
+
groups=1,
|
| 94 |
+
base_width=64,
|
| 95 |
+
dilation=1,
|
| 96 |
+
norm_layer=None,
|
| 97 |
+
):
|
| 98 |
+
super(Bottleneck, self).__init__()
|
| 99 |
+
if norm_layer is None:
|
| 100 |
+
norm_layer = nn.BatchNorm2d
|
| 101 |
+
width = int(planes * (base_width / 64.0)) * groups
|
| 102 |
+
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
| 103 |
+
self.conv1 = conv1x1(inplanes, width)
|
| 104 |
+
self.bn1 = norm_layer(width)
|
| 105 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
| 106 |
+
self.bn2 = norm_layer(width)
|
| 107 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
| 108 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
| 109 |
+
self.relu = nn.ReLU(inplace=True)
|
| 110 |
+
self.downsample = downsample
|
| 111 |
+
self.stride = stride
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
identity = x
|
| 115 |
+
|
| 116 |
+
out = self.conv1(x)
|
| 117 |
+
out = self.bn1(out)
|
| 118 |
+
out = self.relu(out)
|
| 119 |
+
|
| 120 |
+
out = self.conv2(out)
|
| 121 |
+
out = self.bn2(out)
|
| 122 |
+
out = self.relu(out)
|
| 123 |
+
|
| 124 |
+
out = self.conv3(out)
|
| 125 |
+
out = self.bn3(out)
|
| 126 |
+
|
| 127 |
+
if self.downsample is not None:
|
| 128 |
+
identity = self.downsample(x)
|
| 129 |
+
|
| 130 |
+
out += identity
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# activation = None
|
| 134 |
+
# activation = out.detach().cpu().numpy()
|
| 135 |
+
out = self.relu(out)
|
| 136 |
+
# return out, activation
|
| 137 |
+
|
| 138 |
+
return out
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class ResNet(nn.Module):
|
| 142 |
+
def __init__(
|
| 143 |
+
self,
|
| 144 |
+
in_channels,
|
| 145 |
+
feature_scales,
|
| 146 |
+
stride,
|
| 147 |
+
block,
|
| 148 |
+
layers,
|
| 149 |
+
num_classes=10,
|
| 150 |
+
zero_init_residual=False,
|
| 151 |
+
groups=1,
|
| 152 |
+
width_per_group=64,
|
| 153 |
+
replace_stride_with_dilation=None,
|
| 154 |
+
norm_layer=None,
|
| 155 |
+
do_initial_max_pool=True,
|
| 156 |
+
):
|
| 157 |
+
super(ResNet, self).__init__()
|
| 158 |
+
if norm_layer is None:
|
| 159 |
+
norm_layer = nn.BatchNorm2d
|
| 160 |
+
self._norm_layer = norm_layer
|
| 161 |
+
|
| 162 |
+
self.inplanes = 64
|
| 163 |
+
self.dilation = 1
|
| 164 |
+
if replace_stride_with_dilation is None:
|
| 165 |
+
# each element in the tuple indicates if we should replace
|
| 166 |
+
# the 2x2 stride with a dilated convolution instead
|
| 167 |
+
replace_stride_with_dilation = [False, False, False]
|
| 168 |
+
if len(replace_stride_with_dilation) != 3:
|
| 169 |
+
raise ValueError(
|
| 170 |
+
"replace_stride_with_dilation should be None "
|
| 171 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)
|
| 172 |
+
)
|
| 173 |
+
self.groups = groups
|
| 174 |
+
self.base_width = width_per_group
|
| 175 |
+
|
| 176 |
+
# NOTE: Important!
|
| 177 |
+
# This has changed from a kernel size of 7 (padding=3) to a kernel of 3 (padding=1)
|
| 178 |
+
# The reason for this was to limit the receptive field to constrain models to
|
| 179 |
+
# "Looking around" to gather information.
|
| 180 |
+
|
| 181 |
+
self.conv1 = nn.Conv2d(
|
| 182 |
+
in_channels, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
|
| 183 |
+
) if in_channels in [1, 3] else nn.LazyConv2d(
|
| 184 |
+
self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
|
| 185 |
+
)
|
| 186 |
+
# END
|
| 187 |
+
|
| 188 |
+
self.bn1 = norm_layer(self.inplanes)
|
| 189 |
+
self.relu = nn.ReLU(inplace=True)
|
| 190 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if do_initial_max_pool else Identity()
|
| 191 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 192 |
+
self.feature_scales = feature_scales
|
| 193 |
+
if 2 in feature_scales:
|
| 194 |
+
self.layer2 = self._make_layer(
|
| 195 |
+
block, 128, layers[1], stride=stride, dilate=replace_stride_with_dilation[0]
|
| 196 |
+
)
|
| 197 |
+
if 3 in feature_scales:
|
| 198 |
+
self.layer3 = self._make_layer(
|
| 199 |
+
block, 256, layers[2], stride=stride, dilate=replace_stride_with_dilation[1]
|
| 200 |
+
)
|
| 201 |
+
if 4 in feature_scales:
|
| 202 |
+
self.layer4 = self._make_layer(
|
| 203 |
+
block, 512, layers[3], stride=stride, dilate=replace_stride_with_dilation[2]
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# NOTE: Commented this out as it is not used anymore for this work, kept it for reference
|
| 207 |
+
# self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 208 |
+
# self.fc = nn.Linear(512 * block.expansion, num_classes)
|
| 209 |
+
|
| 210 |
+
# for m in self.modules():
|
| 211 |
+
# if isinstance(m, nn.Conv2d):
|
| 212 |
+
# nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 213 |
+
# elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
| 214 |
+
# nn.init.constant_(m.weight, 1)
|
| 215 |
+
# nn.init.constant_(m.bias, 0)
|
| 216 |
+
|
| 217 |
+
# Zero-initialize the last BN in each residual branch,
|
| 218 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
| 219 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
| 220 |
+
if zero_init_residual:
|
| 221 |
+
for m in self.modules():
|
| 222 |
+
if isinstance(m, Bottleneck):
|
| 223 |
+
nn.init.constant_(m.bn3.weight, 0)
|
| 224 |
+
elif isinstance(m, BasicBlock):
|
| 225 |
+
nn.init.constant_(m.bn2.weight, 0)
|
| 226 |
+
|
| 227 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
| 228 |
+
norm_layer = self._norm_layer
|
| 229 |
+
downsample = None
|
| 230 |
+
previous_dilation = self.dilation
|
| 231 |
+
if dilate:
|
| 232 |
+
self.dilation *= stride
|
| 233 |
+
stride = 1
|
| 234 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 235 |
+
downsample = nn.Sequential(
|
| 236 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
| 237 |
+
norm_layer(planes * block.expansion),
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
layers = []
|
| 241 |
+
layers.append(
|
| 242 |
+
block(
|
| 243 |
+
self.inplanes,
|
| 244 |
+
planes,
|
| 245 |
+
stride,
|
| 246 |
+
downsample,
|
| 247 |
+
self.groups,
|
| 248 |
+
self.base_width,
|
| 249 |
+
previous_dilation,
|
| 250 |
+
norm_layer,
|
| 251 |
+
)
|
| 252 |
+
)
|
| 253 |
+
self.inplanes = planes * block.expansion
|
| 254 |
+
for _ in range(1, blocks):
|
| 255 |
+
layers.append(
|
| 256 |
+
block(
|
| 257 |
+
self.inplanes,
|
| 258 |
+
planes,
|
| 259 |
+
groups=self.groups,
|
| 260 |
+
base_width=self.base_width,
|
| 261 |
+
dilation=self.dilation,
|
| 262 |
+
norm_layer=norm_layer,
|
| 263 |
+
)
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
return nn.Sequential(*layers)
|
| 267 |
+
|
| 268 |
+
def forward(self, x):
|
| 269 |
+
activations = []
|
| 270 |
+
x = self.conv1(x)
|
| 271 |
+
x = self.bn1(x)
|
| 272 |
+
x = self.relu(x)
|
| 273 |
+
x = self.maxpool(x)
|
| 274 |
+
# if return_activations: activations.append(torch.clone(x))
|
| 275 |
+
x = self.layer1(x)
|
| 276 |
+
|
| 277 |
+
if 2 in self.feature_scales:
|
| 278 |
+
x = self.layer2(x)
|
| 279 |
+
if 3 in self.feature_scales:
|
| 280 |
+
x = self.layer3(x)
|
| 281 |
+
if 4 in self.feature_scales:
|
| 282 |
+
x = self.layer4(x)
|
| 283 |
+
return x
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _resnet(in_channels, feature_scales, stride, arch, block, layers, pretrained, progress, device, do_initial_max_pool, **kwargs):
|
| 287 |
+
model = ResNet(in_channels, feature_scales, stride, block, layers, do_initial_max_pool=do_initial_max_pool, **kwargs)
|
| 288 |
+
if pretrained:
|
| 289 |
+
assert in_channels==3
|
| 290 |
+
script_dir = os.path.dirname(__file__)
|
| 291 |
+
state_dict = torch.load(
|
| 292 |
+
script_dir + '/state_dicts/' + arch + ".pt", map_location=device
|
| 293 |
+
)
|
| 294 |
+
model.load_state_dict(state_dict, strict=False)
|
| 295 |
+
return model
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def resnet18(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 299 |
+
"""Constructs a ResNet-18 model.
|
| 300 |
+
Args:
|
| 301 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 302 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 303 |
+
"""
|
| 304 |
+
return _resnet(in_channels,
|
| 305 |
+
feature_scales, stride, "resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def resnet34(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 310 |
+
"""Constructs a ResNet-34 model.
|
| 311 |
+
Args:
|
| 312 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 313 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 314 |
+
"""
|
| 315 |
+
return _resnet(in_channels,
|
| 316 |
+
feature_scales, stride, "resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def resnet50(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 321 |
+
"""Constructs a ResNet-50 model.
|
| 322 |
+
Args:
|
| 323 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 324 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 325 |
+
"""
|
| 326 |
+
return _resnet(in_channels,
|
| 327 |
+
feature_scales, stride, "resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def resnet101(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 332 |
+
"""Constructs a ResNet-50 model.
|
| 333 |
+
Args:
|
| 334 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 335 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 336 |
+
"""
|
| 337 |
+
return _resnet(in_channels,
|
| 338 |
+
feature_scales, stride, "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def resnet152(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 343 |
+
"""Constructs a ResNet-50 model.
|
| 344 |
+
Args:
|
| 345 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 346 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 347 |
+
"""
|
| 348 |
+
return _resnet(in_channels,
|
| 349 |
+
feature_scales, stride, "resnet152", Bottleneck, [3, 4, 36, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
def prepare_resnet_backbone(backbone_type):
|
| 353 |
+
|
| 354 |
+
resnet_family = resnet18 # Default
|
| 355 |
+
if '34' in backbone_type: resnet_family = resnet34
|
| 356 |
+
if '50' in backbone_type: resnet_family = resnet50
|
| 357 |
+
if '101' in backbone_type: resnet_family = resnet101
|
| 358 |
+
if '152' in backbone_type: resnet_family = resnet152
|
| 359 |
+
|
| 360 |
+
# Determine which ResNet blocks to keep
|
| 361 |
+
block_num_str = backbone_type.split('-')[-1]
|
| 362 |
+
hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]
|
| 363 |
+
|
| 364 |
+
backbone = resnet_family(
|
| 365 |
+
3,
|
| 366 |
+
hyper_blocks_to_keep,
|
| 367 |
+
stride=2,
|
| 368 |
+
pretrained=False,
|
| 369 |
+
progress=True,
|
| 370 |
+
device="cpu",
|
| 371 |
+
do_initial_max_pool=True,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
return backbone
|
models/utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import re
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
def compute_decay(T, params, clamp_lims=(0, 15)):
|
| 7 |
+
"""
|
| 8 |
+
This function computes exponential decays for learnable synchronisation
|
| 9 |
+
interactions between pairs of neurons.
|
| 10 |
+
"""
|
| 11 |
+
assert len(clamp_lims), 'Clamp lims should be length 2'
|
| 12 |
+
assert type(clamp_lims) == tuple, 'Clamp lims should be tuple'
|
| 13 |
+
|
| 14 |
+
indices = torch.arange(T-1, -1, -1, device=params.device).reshape(T, 1).expand(T, params.shape[0])
|
| 15 |
+
out = torch.exp(-indices * torch.clamp(params, clamp_lims[0], clamp_lims[1]).unsqueeze(0))
|
| 16 |
+
return out
|
| 17 |
+
|
| 18 |
+
def add_coord_dim(x, scaled=True):
|
| 19 |
+
"""
|
| 20 |
+
Adds a final dimension to the tensor representing 2D coordinates.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
tensor: A PyTorch tensor of shape (B, D, H, W).
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
A PyTorch tensor of shape (B, D, H, W, 2) with the last dimension
|
| 27 |
+
representing the 2D coordinates within the HW dimensions.
|
| 28 |
+
"""
|
| 29 |
+
B, H, W = x.shape
|
| 30 |
+
# Create coordinate grids
|
| 31 |
+
x_coords = torch.arange(W, device=x.device, dtype=x.dtype).repeat(H, 1) # Shape (H, W)
|
| 32 |
+
y_coords = torch.arange(H, device=x.device, dtype=x.dtype).unsqueeze(-1).repeat(1, W) # Shape (H, W)
|
| 33 |
+
if scaled:
|
| 34 |
+
x_coords /= (W-1)
|
| 35 |
+
y_coords /= (H-1)
|
| 36 |
+
# Stack coordinates and expand dimensions
|
| 37 |
+
coords = torch.stack((x_coords, y_coords), dim=-1) # Shape (H, W, 2)
|
| 38 |
+
coords = coords.unsqueeze(0) # Shape (1, 1, H, W, 2)
|
| 39 |
+
coords = coords.repeat(B, 1, 1, 1) # Shape (B, D, H, W, 2)
|
| 40 |
+
return coords
|
| 41 |
+
|
| 42 |
+
def compute_normalized_entropy(logits, reduction='mean'):
|
| 43 |
+
"""
|
| 44 |
+
Calculates the normalized entropy of a PyTorch tensor of logits along the
|
| 45 |
+
final dimension.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
logits: A PyTorch tensor of logits.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
A PyTorch tensor containing the normalized entropy values.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
# Apply softmax to get probabilities
|
| 55 |
+
preds = F.softmax(logits, dim=-1)
|
| 56 |
+
|
| 57 |
+
# Calculate the log probabilities
|
| 58 |
+
log_preds = torch.log_softmax(logits, dim=-1)
|
| 59 |
+
|
| 60 |
+
# Calculate the entropy
|
| 61 |
+
entropy = -torch.sum(preds * log_preds, dim=-1)
|
| 62 |
+
|
| 63 |
+
# Calculate the maximum possible entropy
|
| 64 |
+
num_classes = preds.shape[-1]
|
| 65 |
+
max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32))
|
| 66 |
+
|
| 67 |
+
# Normalize the entropy
|
| 68 |
+
normalized_entropy = entropy / max_entropy
|
| 69 |
+
if len(logits.shape)>2 and reduction == 'mean':
|
| 70 |
+
normalized_entropy = normalized_entropy.flatten(1).mean(-1)
|
| 71 |
+
|
| 72 |
+
return normalized_entropy
|
| 73 |
+
|
| 74 |
+
def reshape_predictions(predictions, prediction_reshaper):
|
| 75 |
+
B, T = predictions.size(0), predictions.size(-1)
|
| 76 |
+
new_shape = [B] + prediction_reshaper + [T]
|
| 77 |
+
rehaped_predictions = predictions.reshape(new_shape)
|
| 78 |
+
return rehaped_predictions
|
| 79 |
+
|
| 80 |
+
def get_all_log_dirs(root_dir):
|
| 81 |
+
folders = []
|
| 82 |
+
for dirpath, dirnames, filenames in os.walk(root_dir):
|
| 83 |
+
if any(f.endswith(".pt") for f in filenames):
|
| 84 |
+
folders.append(dirpath)
|
| 85 |
+
return folders
|
| 86 |
+
|
| 87 |
+
def get_latest_checkpoint(log_dir):
|
| 88 |
+
files = [f for f in os.listdir(log_dir) if re.match(r'checkpoint_\d+\.pt', f)]
|
| 89 |
+
return os.path.join(log_dir, max(files, key=lambda f: int(re.search(r'\d+', f).group()))) if files else None
|
| 90 |
+
|
| 91 |
+
def get_latest_checkpoint_file(filepath, limit=300000):
|
| 92 |
+
checkpoint_files = get_checkpoint_files(filepath)
|
| 93 |
+
checkpoint_files = [
|
| 94 |
+
f for f in checkpoint_files if int(re.search(r'checkpoint_(\d+)\.pt', f).group(1)) <= limit
|
| 95 |
+
]
|
| 96 |
+
if not checkpoint_files:
|
| 97 |
+
return None
|
| 98 |
+
return checkpoint_files[-1]
|
| 99 |
+
|
| 100 |
+
def get_checkpoint_files(filepath):
|
| 101 |
+
regex = r'checkpoint_(\d+)\.pt'
|
| 102 |
+
files = [f for f in os.listdir(filepath) if re.match(regex, f)]
|
| 103 |
+
files = sorted(files, key=lambda f: int(re.search(regex, f).group(1)))
|
| 104 |
+
return [os.path.join(filepath, f) for f in files]
|
| 105 |
+
|
| 106 |
+
def load_checkpoint(checkpoint_path, device):
|
| 107 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 108 |
+
return checkpoint
|
| 109 |
+
|
| 110 |
+
def get_model_args_from_checkpoint(checkpoint):
|
| 111 |
+
if "args" in checkpoint:
|
| 112 |
+
return(checkpoint["args"])
|
| 113 |
+
else:
|
| 114 |
+
raise ValueError("Checkpoint does not contain saved args.")
|
| 115 |
+
|
| 116 |
+
def get_accuracy_and_loss_from_checkpoint(checkpoint, device="cpu"):
|
| 117 |
+
training_iteration = checkpoint.get('training_iteration', 0)
|
| 118 |
+
train_losses = checkpoint.get('train_losses', [])
|
| 119 |
+
test_losses = checkpoint.get('test_losses', [])
|
| 120 |
+
train_accuracies = checkpoint.get('train_accuracies_most_certain', [])
|
| 121 |
+
test_accuracies = checkpoint.get('test_accuracies_most_certain', [])
|
| 122 |
+
return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies
|
mount_azure.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
# Configuración
|
| 5 |
+
REMOTE_HOST="ssh.my-robot.dev"
|
| 6 |
+
REMOTE_USER="azureuser"
|
| 7 |
+
REMOTE_PATH="/mnt/lightrag"
|
| 8 |
+
LOCAL_MOUNT="/data/persistent"
|
| 9 |
+
|
| 10 |
+
# 1. Configurar SSH Key
|
| 11 |
+
if [ -z "$SSH_KEY" ]; then
|
| 12 |
+
echo "❌ SSH_KEY no definida. Configúrala en Settings -> Secrets"
|
| 13 |
+
exit 1
|
| 14 |
+
fi
|
| 15 |
+
|
| 16 |
+
mkdir -p ~/.ssh
|
| 17 |
+
echo "$SSH_KEY" > ~/.ssh/id_rsa
|
| 18 |
+
chmod 600 ~/.ssh/id_rsa
|
| 19 |
+
|
| 20 |
+
# 2. Configurar Cloudflare ProxyCommand
|
| 21 |
+
echo "Host $REMOTE_HOST
|
| 22 |
+
User $REMOTE_USER
|
| 23 |
+
IdentityFile ~/.ssh/id_rsa
|
| 24 |
+
ProxyCommand /usr/bin/cloudflared access ssh --hostname %h
|
| 25 |
+
StrictHostKeyChecking no
|
| 26 |
+
" > ~/.ssh/config
|
| 27 |
+
|
| 28 |
+
# 3. Crear punto de montaje
|
| 29 |
+
if [ ! -d "$LOCAL_MOUNT" ]; then
|
| 30 |
+
mkdir -p $LOCAL_MOUNT
|
| 31 |
+
chmod 777 $LOCAL_MOUNT || true
|
| 32 |
+
fi
|
| 33 |
+
|
| 34 |
+
# 4. Montar SSHFS
|
| 35 |
+
echo "🔌 Conectando al Bunker ($REMOTE_HOST)..."
|
| 36 |
+
sshfs $REMOTE_HOST:$REMOTE_PATH $LOCAL_MOUNT \
|
| 37 |
+
-o allow_other,reconnect,ServerAliveInterval=15,ServerAliveCountMax=3,idmap=user
|
| 38 |
+
|
| 39 |
+
if mountpoint -q $LOCAL_MOUNT; then
|
| 40 |
+
echo "✅ Disco Azure Montado en $LOCAL_MOUNT"
|
| 41 |
+
ls -la $LOCAL_MOUNT || echo "No se puede listar, pero está montado"
|
| 42 |
+
else
|
| 43 |
+
echo "❌ Error al montar"
|
| 44 |
+
fi
|
requirements.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Requirements for CTM Nervous System v2.0 (Full PyTorch)
|
| 2 |
+
# =========================================================
|
| 3 |
+
# Upgraded to use full ContinuousThoughtMachine implementation
|
| 4 |
+
|
| 5 |
+
# Core Framework
|
| 6 |
+
gradio>=5.0.0
|
| 7 |
+
numpy
|
| 8 |
+
|
| 9 |
+
# PyTorch (CPU-only for HuggingFace free tier)
|
| 10 |
+
# Use torch-cpu to minimize memory footprint
|
| 11 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
| 12 |
+
torch>=2.0.0
|
| 13 |
+
|
| 14 |
+
# For checkpoint saving
|
| 15 |
+
safetensors
|
| 16 |
+
|
| 17 |
+
# For potential model weights download
|
| 18 |
+
huggingface_hub
|
| 19 |
+
|
| 20 |
+
# HTTP client for Brain server integration
|
| 21 |
+
requests
|
requirements_v1.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=5.0.0
|
| 2 |
+
numpy
|
setup_hf_space.sh
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Configuración de Usuario
|
| 4 |
+
git config --global user.email "ksj6ftj8f8@gmail.com"
|
| 5 |
+
git config --global user.name "Elliotasdasdasfasas"
|
| 6 |
+
git config --global credential.helper store
|
| 7 |
+
|
| 8 |
+
# Inicializar Git si no existe
|
| 9 |
+
if [ ! -d ".git" ]; then
|
| 10 |
+
git init
|
| 11 |
+
git branch -M main
|
| 12 |
+
fi
|
| 13 |
+
|
| 14 |
+
# Definir Token y Usuario
|
| 15 |
+
HF_TOKEN="${HF_TOKEN}"
|
| 16 |
+
HF_USER="ROBOT-GANSTA"
|
| 17 |
+
SPACE_NAME="robot-gansta-redneural"
|
| 18 |
+
|
| 19 |
+
# Configurar Remote con Token explícito (desde variable de entorno)
|
| 20 |
+
REMOTE_URL="https://$HF_USER:$HF_TOKEN@huggingface.co/spaces/$HF_USER/$SPACE_NAME"
|
| 21 |
+
|
| 22 |
+
echo "🔗 Conectando a: https://huggingface.co/spaces/$HF_USER/$SPACE_NAME"
|
| 23 |
+
|
| 24 |
+
# Eliminar remote anterior si existe
|
| 25 |
+
git remote remove space 2>/dev/null
|
| 26 |
+
|
| 27 |
+
# Añadir remote y hacer pull/push
|
| 28 |
+
git remote add space "$REMOTE_URL"
|
| 29 |
+
|
| 30 |
+
# Añadir archivos
|
| 31 |
+
git add .
|
| 32 |
+
git commit -m "Deploy RedNeural Docker Environment with Fixes"
|
| 33 |
+
|
| 34 |
+
# Forzar push inicial
|
| 35 |
+
git push --force space main
|
| 36 |
+
|
| 37 |
+
echo "✅ Despliegue completado."
|
tasks/image_classification/README.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image classification
|
| 2 |
+
|
| 3 |
+
This folder contains code for training and analysing imagenet and cifar related experiments.
|
| 4 |
+
|
| 5 |
+
## Accessing and loading imagenet
|
| 6 |
+
|
| 7 |
+
We use the [ILSRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset in our paper.
|
| 8 |
+
|
| 9 |
+
To get this to work for you, you will need to do the following:
|
| 10 |
+
1. Login to huggingface (make an account) to agree to TCs of this dataset,
|
| 11 |
+
2. Make a new access token.
|
| 12 |
+
3. Install huggingface_hub on the target machine with ```pip install huggingface_hub```
|
| 13 |
+
4. Run ```huggingface-cli login``` and use your token. This will authenticate you on the backend and allow the code to run.
|
| 14 |
+
5. Simply run an imagenet experiment. It will auto download and do all that magic.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Training
|
| 18 |
+
There are two training files: `train.py` and `train_distributed.py`. The training code uses mixed precision. For the settings in the paper, the following command was used for distributed training:
|
| 19 |
+
|
| 20 |
+
```
|
| 21 |
+
torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
You can run the same setup on a single GPU with:
|
| 25 |
+
```
|
| 26 |
+
python -m tasks.image_classification.train --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp --device 0
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Checkpoint
|
| 30 |
+
|
| 31 |
+
The checkpoint for the model used in the paper can be found [here](https://drive.google.com/file/d/1Lr_3RZU9X9SS8lBhAhECBiSZDKfKhDkJ/view?usp=drive_link).
|
tasks/image_classification/analysis/README.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Analysis
|
| 2 |
+
|
| 3 |
+
This folder contains the analysis code for the image classifcation experiments. Running the following from the base directory will generate figures, gifs and mp4 files:
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
python -m tasks.image_classification.analysis.run_imagenet_analysis
|
| 7 |
+
```
|
tasks/image_classification/analysis/run_imagenet_analysis.py
ADDED
|
@@ -0,0 +1,972 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- Core Libraries ---
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import argparse
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
import torch.nn.functional as F # Used for interpolate
|
| 8 |
+
|
| 9 |
+
# --- Plotting & Visualization ---
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import matplotlib as mpl
|
| 12 |
+
mpl.use('Agg')
|
| 13 |
+
import seaborn as sns
|
| 14 |
+
sns.set_style('darkgrid')
|
| 15 |
+
from matplotlib import patheffects
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
import imageio
|
| 18 |
+
import cv2
|
| 19 |
+
from scipy.special import softmax
|
| 20 |
+
from tasks.image_classification.plotting import save_frames_to_mp4
|
| 21 |
+
|
| 22 |
+
# --- Data Handling & Model ---
|
| 23 |
+
from torchvision import transforms
|
| 24 |
+
from torchvision import datasets # Only used for CIFAR100 in debug mode
|
| 25 |
+
from scipy import ndimage # Used in find_island_centers
|
| 26 |
+
from data.custom_datasets import ImageNet
|
| 27 |
+
from models.ctm import ContinuousThoughtMachine
|
| 28 |
+
from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
|
| 29 |
+
from tasks.image_classification.plotting import plot_neural_dynamics
|
| 30 |
+
|
| 31 |
+
# --- Global Settings ---
|
| 32 |
+
np.seterr(divide='ignore')
|
| 33 |
+
mpl.use('Agg')
|
| 34 |
+
sns.set_style('darkgrid')
|
| 35 |
+
|
| 36 |
+
# --- Helper Functions ---
|
| 37 |
+
|
| 38 |
+
def find_island_centers(array_2d, threshold):
|
| 39 |
+
"""
|
| 40 |
+
Finds the center of mass of each island (connected component > threshold)
|
| 41 |
+
in a 2D array, weighted by the array's values.
|
| 42 |
+
Returns list of (y, x) centers and list of areas.
|
| 43 |
+
"""
|
| 44 |
+
binary_image = array_2d > threshold
|
| 45 |
+
labeled_image, num_labels = ndimage.label(binary_image)
|
| 46 |
+
centers = []
|
| 47 |
+
areas = []
|
| 48 |
+
# Calculate center of mass for each labeled island (label 0 is background)
|
| 49 |
+
for i in range(1, num_labels + 1):
|
| 50 |
+
island_mask = (labeled_image == i)
|
| 51 |
+
total_mass = np.sum(array_2d[island_mask])
|
| 52 |
+
if total_mass > 0:
|
| 53 |
+
# Get coordinates for this island
|
| 54 |
+
y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
|
| 55 |
+
# Calculate weighted average for center
|
| 56 |
+
x_center = np.average(x_coords[island_mask], weights=array_2d[island_mask])
|
| 57 |
+
y_center = np.average(y_coords[island_mask], weights=array_2d[island_mask])
|
| 58 |
+
centers.append((round(y_center, 4), round(x_center, 4)))
|
| 59 |
+
areas.append(np.sum(island_mask)) # Area is the count of pixels in the island
|
| 60 |
+
return centers, areas
|
| 61 |
+
|
| 62 |
+
def parse_args():
|
| 63 |
+
"""Parses command-line arguments."""
|
| 64 |
+
# Note: Original had two ArgumentParser instances, using the second one.
|
| 65 |
+
parser = argparse.ArgumentParser(description="Visualize Continuous Thought Machine Attention")
|
| 66 |
+
parser.add_argument('--actions', type=str, nargs='+', default=['videos'], choices=['plots', 'videos', 'demo'], help="Actions to take. Plots=results plots; videos=gifs/mp4s to watch attention; demo: last frame of internal ticks")
|
| 67 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help="GPU device index or -1 for CPU")
|
| 68 |
+
|
| 69 |
+
parser.add_argument('--checkpoint', type=str, default='checkpoints/imagenet/ctm_clean.pt', help="Path to ATM checkpoint")
|
| 70 |
+
parser.add_argument('--output_dir', type=str, default='tasks/image_classification/analysis/outputs/imagenet_viz', help="Directory for visualization outputs")
|
| 71 |
+
parser.add_argument('--debug', action=argparse.BooleanOptionalAction, default=True, help='Debug mode: use CIFAR100 instead of ImageNet for debugging.')
|
| 72 |
+
parser.add_argument('--plot_every', type=int, default=10, help="How often to plot.")
|
| 73 |
+
|
| 74 |
+
parser.add_argument('--inference_iterations', type=int, default=50, help="Iterations to use during inference.")
|
| 75 |
+
parser.add_argument('--data_indices', type=int, nargs='+', default=[], help="Use specific indices in validation data for demos, otherwise random.")
|
| 76 |
+
parser.add_argument('--N_to_viz', type=int, default=5, help="When not supplying data_indices.")
|
| 77 |
+
|
| 78 |
+
return parser.parse_args()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# --- Main Execution Block ---
|
| 82 |
+
if __name__=='__main__':
|
| 83 |
+
|
| 84 |
+
# --- Setup ---
|
| 85 |
+
args = parse_args()
|
| 86 |
+
if args.device[0] != -1 and torch.cuda.is_available():
|
| 87 |
+
device = f'cuda:{args.device[0]}'
|
| 88 |
+
else:
|
| 89 |
+
device = 'cpu'
|
| 90 |
+
print(f"Using device: {device}")
|
| 91 |
+
|
| 92 |
+
# --- Load Checkpoint & Model ---
|
| 93 |
+
print(f"Loading checkpoint: {args.checkpoint}")
|
| 94 |
+
checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False) # removed weights_only=False
|
| 95 |
+
model_args = checkpoint['args']
|
| 96 |
+
|
| 97 |
+
# Handle legacy arguments from checkpoint if necessary
|
| 98 |
+
if not hasattr(model_args, 'backbone_type') and hasattr(model_args, 'resnet_type'):
|
| 99 |
+
model_args.backbone_type = f'{model_args.resnet_type}-{getattr(model_args, "resnet_feature_scales", [4])[-1]}'
|
| 100 |
+
if not hasattr(model_args, 'neuron_select_type'):
|
| 101 |
+
model_args.neuron_select_type = 'first-last'
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Instantiate Model based on checkpoint args
|
| 105 |
+
print("Instantiating CTM model...")
|
| 106 |
+
model = ContinuousThoughtMachine(
|
| 107 |
+
iterations=model_args.iterations,
|
| 108 |
+
d_model=model_args.d_model,
|
| 109 |
+
d_input=model_args.d_input,
|
| 110 |
+
heads=model_args.heads,
|
| 111 |
+
n_synch_out=model_args.n_synch_out,
|
| 112 |
+
n_synch_action=model_args.n_synch_action,
|
| 113 |
+
synapse_depth=model_args.synapse_depth,
|
| 114 |
+
memory_length=model_args.memory_length,
|
| 115 |
+
deep_nlms=model_args.deep_memory,
|
| 116 |
+
memory_hidden_dims=model_args.memory_hidden_dims,
|
| 117 |
+
do_layernorm_nlm=model_args.do_normalisation,
|
| 118 |
+
backbone_type=model_args.backbone_type,
|
| 119 |
+
positional_embedding_type=model_args.positional_embedding_type,
|
| 120 |
+
out_dims=model_args.out_dims,
|
| 121 |
+
prediction_reshaper=[-1], # Kept fixed value from original code
|
| 122 |
+
dropout=0, # No dropout for eval
|
| 123 |
+
neuron_select_type=model_args.neuron_select_type,
|
| 124 |
+
n_random_pairing_self=model_args.n_random_pairing_self,
|
| 125 |
+
).to(device)
|
| 126 |
+
|
| 127 |
+
# Load weights into model
|
| 128 |
+
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
| 129 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 130 |
+
model.eval() # Set model to evaluation mode
|
| 131 |
+
|
| 132 |
+
# --- Prepare Dataset ---
|
| 133 |
+
if args.debug:
|
| 134 |
+
print("Debug mode: Using CIFAR100")
|
| 135 |
+
# CIFAR100 specific normalization constants
|
| 136 |
+
dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
|
| 137 |
+
dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
|
| 138 |
+
img_size = 256 # Resize CIFAR images for consistency
|
| 139 |
+
transform = transforms.Compose([
|
| 140 |
+
transforms.Resize(img_size),
|
| 141 |
+
transforms.ToTensor(),
|
| 142 |
+
transforms.Normalize(mean=dataset_mean, std=dataset_std), # Normalize
|
| 143 |
+
])
|
| 144 |
+
validation_dataset = datasets.CIFAR100('data/', train=False, transform=transform, download=True)
|
| 145 |
+
validation_dataset_centercrop = datasets.CIFAR100('data/', train=True, transform=transform, download=True)
|
| 146 |
+
else:
|
| 147 |
+
print("Using ImageNet")
|
| 148 |
+
# ImageNet specific normalization constants
|
| 149 |
+
dataset_mean = [0.485, 0.456, 0.406]
|
| 150 |
+
dataset_std = [0.229, 0.224, 0.225]
|
| 151 |
+
img_size = 256 # Resize ImageNet images
|
| 152 |
+
# Note: Original comment mentioned no CenterCrop, this transform reflects that.
|
| 153 |
+
transform = transforms.Compose([
|
| 154 |
+
transforms.Resize(img_size),
|
| 155 |
+
transforms.ToTensor(),
|
| 156 |
+
transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
|
| 157 |
+
])
|
| 158 |
+
validation_dataset = ImageNet(which_split='validation', transform=transform)
|
| 159 |
+
validation_dataset_centercrop = ImageNet(which_split='train', transform=transforms.Compose([
|
| 160 |
+
transforms.Resize(img_size),
|
| 161 |
+
transforms.RandomCrop(img_size),
|
| 162 |
+
transforms.ToTensor(),
|
| 163 |
+
transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
|
| 164 |
+
]))
|
| 165 |
+
class_labels = list(IMAGENET2012_CLASSES.values()) # Load actual class names
|
| 166 |
+
|
| 167 |
+
os.makedirs(f'{args.output_dir}', exist_ok=True)
|
| 168 |
+
|
| 169 |
+
interp_mode = 'nearest'
|
| 170 |
+
cmap_calib = sns.color_palette('viridis', as_cmap=True)
|
| 171 |
+
loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False)
|
| 172 |
+
loader_crop = torch.utils.data.DataLoader(validation_dataset_centercrop, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
|
| 173 |
+
|
| 174 |
+
model.eval()
|
| 175 |
+
|
| 176 |
+
figscale = 0.85
|
| 177 |
+
topk = 5
|
| 178 |
+
mean_certainties_correct, mean_certainties_incorrect = [],[]
|
| 179 |
+
tracked_certainties = []
|
| 180 |
+
tracked_targets = []
|
| 181 |
+
tracked_predictions = []
|
| 182 |
+
|
| 183 |
+
if model.iterations != args.inference_iterations:
|
| 184 |
+
print('WARNING: you are setting inference iterations to a value not used during training!')
|
| 185 |
+
|
| 186 |
+
model.iterations = args.inference_iterations
|
| 187 |
+
|
| 188 |
+
if 'plots' in args.actions:
|
| 189 |
+
|
| 190 |
+
with torch.inference_mode(): # Disable gradient calculations
|
| 191 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=0, dynamic_ncols=True) as pbar:
|
| 192 |
+
imgi = 0
|
| 193 |
+
for bi, (inputs, targets) in enumerate(loader):
|
| 194 |
+
inputs = inputs.to(device)
|
| 195 |
+
targets = targets.to(device)
|
| 196 |
+
if bi==0:
|
| 197 |
+
dynamics_inputs, _ = next(iter(loader_crop)) # Use this because of batching
|
| 198 |
+
_, _, _, _, post_activations_viz, _ = model(inputs, track=True)
|
| 199 |
+
plot_neural_dynamics(post_activations_viz, 15*10, args.output_dir, axis_snap=True, N_per_row=15)
|
| 200 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 201 |
+
|
| 202 |
+
tracked_predictions.append(predictions.detach().cpu().numpy())
|
| 203 |
+
tracked_targets.append(targets.detach().cpu().numpy())
|
| 204 |
+
tracked_certainties.append(certainties.detach().cpu().numpy())
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
pbar.set_description(f'Processing base image of size {inputs.shape}')
|
| 210 |
+
pbar.update(1)
|
| 211 |
+
if ((bi % args.plot_every == 0) or bi == len(loader)-1) and bi!=0: #
|
| 212 |
+
|
| 213 |
+
concatenated_certainties = np.concatenate(tracked_certainties, axis=0)
|
| 214 |
+
concatenated_targets = np.concatenate(tracked_targets, axis=0)
|
| 215 |
+
concatenated_predictions = np.concatenate(tracked_predictions, axis=0)
|
| 216 |
+
concatenated_predictions_argsorted = np.argsort(concatenated_predictions, 1)[:,::-1]
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
for topk in [1, 5]:
|
| 221 |
+
concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
|
| 222 |
+
|
| 223 |
+
accs_instant, accs_avg, accs_certain = [], [], []
|
| 224 |
+
accs_avg_logits, accs_weighted_logits = [],[]
|
| 225 |
+
with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
|
| 226 |
+
pbarinner.set_description('Acc types')
|
| 227 |
+
for stepi in np.arange(concatenated_predictions.shape[-1]):
|
| 228 |
+
pred_avg = softmax(concatenated_predictions, 1)[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
|
| 229 |
+
pred_instant = concatenated_predictions_argsorted_topk[:,:,stepi]
|
| 230 |
+
pred_certain = concatenated_predictions_argsorted_topk[np.arange(concatenated_predictions.shape[0]),:, concatenated_certainties[:,1,:stepi+1].argmax(1)]
|
| 231 |
+
pred_avg_logits = concatenated_predictions[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
|
| 232 |
+
pred_weighted_logits = (concatenated_predictions[:,:,:stepi+1] * concatenated_certainties[:,1:,:stepi+1]).sum(-1).argsort(1)[:, -topk:]
|
| 233 |
+
pbarinner.update(1)
|
| 234 |
+
accs_instant.append(np.any(pred_instant==concatenated_targets[...,np.newaxis], -1).mean())
|
| 235 |
+
accs_avg.append(np.any(pred_avg==concatenated_targets[...,np.newaxis], -1).mean())
|
| 236 |
+
accs_avg_logits.append(np.any(pred_avg==concatenated_targets[...,np.newaxis], -1).mean())
|
| 237 |
+
accs_weighted_logits.append(np.any(pred_weighted_logits==concatenated_targets[...,np.newaxis], -1).mean())
|
| 238 |
+
accs_certain.append(np.any(pred_avg_logits==concatenated_targets[...,np.newaxis], -1).mean())
|
| 239 |
+
fig = plt.figure(figsize=(10*figscale, 4*figscale))
|
| 240 |
+
ax = fig.add_subplot(111)
|
| 241 |
+
cp = sns.color_palette("bright")
|
| 242 |
+
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_instant), linestyle='-', color=cp[0], label='Instant')
|
| 243 |
+
# ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg), linestyle='--', color=cp[1], label='Based on average probability up to this step')
|
| 244 |
+
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_certain), linestyle=':', color=cp[2], label='Most certain')
|
| 245 |
+
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg_logits), linestyle='-.', color=cp[3], label='Average logits')
|
| 246 |
+
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_weighted_logits), linestyle='--', color=cp[4], label='Logits weighted by certainty')
|
| 247 |
+
ax.set_xlim([0, concatenated_predictions.shape[-1]+1])
|
| 248 |
+
ax.set_ylim([75, 92])
|
| 249 |
+
ax.set_xlabel('Internal ticks')
|
| 250 |
+
ax.set_ylabel(f'Top-k={topk} accuracy')
|
| 251 |
+
ax.legend(loc='lower right')
|
| 252 |
+
fig.tight_layout(pad=0.1)
|
| 253 |
+
fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.png', dpi=200)
|
| 254 |
+
fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.pdf', dpi=200)
|
| 255 |
+
plt.close(fig)
|
| 256 |
+
print(f'k={topk}. Accuracy most certain at last internal tick={100*np.array(accs_certain)[-1]:0.4f}') # Using certainty based approach
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
indices_over_80 = []
|
| 260 |
+
classes_80 = {}
|
| 261 |
+
corrects_80 = {}
|
| 262 |
+
|
| 263 |
+
topk = 5
|
| 264 |
+
concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
|
| 265 |
+
for certainty_threshold in [0.5, 0.8, 0.9]:
|
| 266 |
+
# certainty_threshold = 0.6
|
| 267 |
+
percentage_corrects = []
|
| 268 |
+
percentage_incorrects = []
|
| 269 |
+
with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
|
| 270 |
+
pbarinner.set_description(f'Certainty threshold={certainty_threshold}')
|
| 271 |
+
for stepi in np.arange(concatenated_predictions.shape[-1]):
|
| 272 |
+
certainty_here = concatenated_certainties[:,1,stepi]
|
| 273 |
+
certainty_mask = certainty_here>=certainty_threshold
|
| 274 |
+
predictions_here = concatenated_predictions_argsorted_topk[:,:,stepi]
|
| 275 |
+
is_correct_here = np.any(predictions_here==concatenated_targets[...,np.newaxis], axis=-1)
|
| 276 |
+
percentage_corrects.append(is_correct_here[certainty_mask].sum()/predictions_here.shape[0])
|
| 277 |
+
percentage_incorrects.append((~is_correct_here)[certainty_mask].sum()/predictions_here.shape[0])
|
| 278 |
+
|
| 279 |
+
if certainty_threshold==0.8:
|
| 280 |
+
indices_certain = np.where(certainty_mask)[0]
|
| 281 |
+
for index in indices_certain:
|
| 282 |
+
if index not in indices_over_80:
|
| 283 |
+
indices_over_80.append(index)
|
| 284 |
+
if concatenated_targets[index] not in classes_80:
|
| 285 |
+
classes_80[concatenated_targets[index]] = [stepi]
|
| 286 |
+
corrects_80[concatenated_targets[index]] = [is_correct_here[index]]
|
| 287 |
+
else:
|
| 288 |
+
classes_80[concatenated_targets[index]] = classes_80[concatenated_targets[index]]+[stepi]
|
| 289 |
+
corrects_80[concatenated_targets[index]] = corrects_80[concatenated_targets[index]]+[is_correct_here[index]]
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
pbarinner.update(1)
|
| 293 |
+
fig = plt.figure(figsize=(6.5*figscale, 4*figscale))
|
| 294 |
+
ax = fig.add_subplot(111)
|
| 295 |
+
ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
|
| 296 |
+
percentage_corrects,
|
| 297 |
+
color='forestgreen',
|
| 298 |
+
hatch='OO',
|
| 299 |
+
width=0.9,
|
| 300 |
+
label='Positive',
|
| 301 |
+
alpha=0.9,
|
| 302 |
+
linewidth=1.0*figscale)
|
| 303 |
+
|
| 304 |
+
ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
|
| 305 |
+
percentage_incorrects,
|
| 306 |
+
bottom=percentage_corrects,
|
| 307 |
+
color='crimson',
|
| 308 |
+
hatch='xx',
|
| 309 |
+
width=0.9,
|
| 310 |
+
label='Negative',
|
| 311 |
+
alpha=0.9,
|
| 312 |
+
linewidth=1.0*figscale)
|
| 313 |
+
ax.set_xlim(-1, concatenated_predictions.shape[-1]+1)
|
| 314 |
+
ax.set_xlabel('Internal tick')
|
| 315 |
+
ax.set_ylabel('% of data')
|
| 316 |
+
ax.legend(loc='lower right')
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
fig.tight_layout(pad=0.1)
|
| 320 |
+
fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.png', dpi=200)
|
| 321 |
+
fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.pdf', dpi=200)
|
| 322 |
+
plt.close(fig)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class_list = list(classes_80.keys())
|
| 326 |
+
mean_steps = [np.mean(classes_80[cls]) for cls in class_list]
|
| 327 |
+
std_steps = [np.std(classes_80[cls]) for cls in class_list]
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
# Following code plots the class distribution over internal ticks
|
| 331 |
+
indices_to_show = np.arange(1000)
|
| 332 |
+
|
| 333 |
+
colours = cmap_diverse = plt.get_cmap('rainbow')(np.linspace(0, 1, 1000))
|
| 334 |
+
# np.random.shuffle(colours)
|
| 335 |
+
bottom = np.zeros(concatenated_predictions.shape[-1])
|
| 336 |
+
|
| 337 |
+
fig = plt.figure(figsize=(7*figscale, 4*figscale))
|
| 338 |
+
ax = fig.add_subplot(111)
|
| 339 |
+
for iii, idx in enumerate(indices_to_show):
|
| 340 |
+
if idx in classes_80:
|
| 341 |
+
steps = classes_80[idx]
|
| 342 |
+
colour = colours[iii]
|
| 343 |
+
vs, cts = np.unique(steps, return_counts=True)
|
| 344 |
+
|
| 345 |
+
bar = np.zeros(concatenated_predictions.shape[-1])
|
| 346 |
+
bar[vs] = cts
|
| 347 |
+
ax.bar(np.arange(concatenated_predictions.shape[-1])+1, bar, bottom=bottom, color=colour, width=1, edgecolor='none')
|
| 348 |
+
bottom += bar
|
| 349 |
+
ax.set_xlabel('Internal ticks')
|
| 350 |
+
ax.set_ylabel('Counts over 0.8 certainty')
|
| 351 |
+
fig.tight_layout(pad=0.1)
|
| 352 |
+
fig.savefig(f'{args.output_dir}/class_counts.png', dpi=200)
|
| 353 |
+
fig.savefig(f'{args.output_dir}/class_counts.pdf', dpi=200)
|
| 354 |
+
plt.close(fig)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# The following code plots calibration
|
| 361 |
+
probability_space = np.linspace(0, 1, 10)
|
| 362 |
+
fig = plt.figure(figsize=(6*figscale, 4*figscale))
|
| 363 |
+
ax = fig.add_subplot(111)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
color_linspace = np.linspace(0, 1, concatenated_predictions.shape[-1])
|
| 367 |
+
with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
|
| 368 |
+
pbarinner.set_description(f'Calibration')
|
| 369 |
+
for stepi in np.arange(concatenated_predictions.shape[-1]):
|
| 370 |
+
color = cmap_calib(color_linspace[stepi])
|
| 371 |
+
pred = concatenated_predictions[:,:,stepi].argmax(1)
|
| 372 |
+
is_correct = pred == concatenated_targets # BxT
|
| 373 |
+
probabilities = softmax(concatenated_predictions[:,:,:stepi+1], axis=1)[np.arange(concatenated_predictions.shape[0]),pred].mean(-1)#softmax(concatenated_predictions[:,:,stepi], axis=1).max(1)
|
| 374 |
+
probability_space = np.linspace(0, 1, 10)
|
| 375 |
+
accuracies_per_bin = []
|
| 376 |
+
bin_centers = []
|
| 377 |
+
for pi in range(len(probability_space)-1):
|
| 378 |
+
bin_low = probability_space[pi]
|
| 379 |
+
bin_high = probability_space[pi+1]
|
| 380 |
+
mask = ((probabilities >=bin_low) & (probabilities < bin_high)) if pi !=len(probability_space)-2 else ((probabilities >=bin_low) & (probabilities <= bin_high))
|
| 381 |
+
accuracies_per_bin.append(is_correct[mask].mean())
|
| 382 |
+
bin_centers.append(probabilities[mask].mean())
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
if stepi==concatenated_predictions.shape[-1]-1:
|
| 386 |
+
ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color='#4050f7', alpha=1, label='After all ticks')
|
| 387 |
+
else: ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color=color, alpha=0.65)
|
| 388 |
+
pbarinner.update(1)
|
| 389 |
+
ax.plot(probability_space, np.linspace(0, 1, len(probability_space)), 'k--')
|
| 390 |
+
|
| 391 |
+
ax.legend(loc='upper left')
|
| 392 |
+
ax.set_xlim([-0.01, 1.01])
|
| 393 |
+
ax.set_ylim([-0.01, 1.01])
|
| 394 |
+
|
| 395 |
+
sm = plt.cm.ScalarMappable(cmap=cmap_calib, norm=plt.Normalize(vmin=0, vmax=concatenated_predictions.shape[-1] - 1))
|
| 396 |
+
sm.set_array([]) # Empty array for colormap
|
| 397 |
+
cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.02)
|
| 398 |
+
cbar.set_label('Internal ticks')
|
| 399 |
+
|
| 400 |
+
ax.set_xlabel('Mean predicted probabilities')
|
| 401 |
+
ax.set_ylabel('Ratio of positives')
|
| 402 |
+
fig.tight_layout(pad=0.1)
|
| 403 |
+
fig.savefig(f'{args.output_dir}/imagenet_calibration.png', dpi=200)
|
| 404 |
+
fig.savefig(f'{args.output_dir}/imagenet_calibration.pdf', dpi=200)
|
| 405 |
+
plt.close(fig)
|
| 406 |
+
if 'videos' in args.actions:
|
| 407 |
+
if not args.data_indices: # If list is empty
|
| 408 |
+
n_samples = len(validation_dataset)
|
| 409 |
+
num_to_sample = min(args.N_to_viz, n_samples)
|
| 410 |
+
replace = n_samples < num_to_sample
|
| 411 |
+
data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
|
| 412 |
+
print(f"Selected random indices: {data_indices}")
|
| 413 |
+
else:
|
| 414 |
+
data_indices = args.data_indices
|
| 415 |
+
print(f"Using specified indices: {data_indices}")
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
for di in data_indices:
|
| 419 |
+
print(f'\nBuilding viz for dataset index {di}.')
|
| 420 |
+
|
| 421 |
+
# --- Get Data & Run Inference ---
|
| 422 |
+
# inputs_norm is already normalized by the transform
|
| 423 |
+
inputs, ground_truth_target = validation_dataset.__getitem__(int(di))
|
| 424 |
+
|
| 425 |
+
# Add batch dimension and send to device
|
| 426 |
+
inputs = inputs.to(device).unsqueeze(0)
|
| 427 |
+
|
| 428 |
+
# Run model inference
|
| 429 |
+
predictions, certainties, synchronisation, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
|
| 430 |
+
# predictions: (B, Classes, Steps), attention_tracking: (Steps*B*Heads, SeqLen)
|
| 431 |
+
n_steps = predictions.size(-1)
|
| 432 |
+
|
| 433 |
+
# --- Reshape Attention ---
|
| 434 |
+
# Infer feature map size from model internals (assuming B=1)
|
| 435 |
+
h_feat, w_feat = model.kv_features.shape[-2:]
|
| 436 |
+
|
| 437 |
+
n_heads = attention_tracking.shape[2]
|
| 438 |
+
# Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
|
| 439 |
+
attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
|
| 440 |
+
|
| 441 |
+
# --- Setup for Plotting ---
|
| 442 |
+
step_linspace = np.linspace(0, 1, n_steps) # For step colors
|
| 443 |
+
# Define color maps
|
| 444 |
+
cmap_spectral = sns.color_palette("Spectral", as_cmap=True)
|
| 445 |
+
cmap_attention = sns.color_palette('viridis', as_cmap=True)
|
| 446 |
+
|
| 447 |
+
# Create output directory for this index
|
| 448 |
+
index_output_dir = os.path.join(args.output_dir, str(di))
|
| 449 |
+
os.makedirs(index_output_dir, exist_ok=True)
|
| 450 |
+
|
| 451 |
+
frames = [] # Store frames for GIF
|
| 452 |
+
head_routes = {h: [] for h in range(n_heads)} # Store (y,x) path points per head
|
| 453 |
+
head_routes[-1] = []
|
| 454 |
+
route_colours_step = [] # Store colors for each step's path segments
|
| 455 |
+
|
| 456 |
+
# --- Loop Through Each Step ---
|
| 457 |
+
for step_i in range(n_steps):
|
| 458 |
+
|
| 459 |
+
# --- Prepare Image for Display ---
|
| 460 |
+
# Denormalize the input tensor for visualization
|
| 461 |
+
data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
|
| 462 |
+
mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
|
| 463 |
+
std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
|
| 464 |
+
data_img_denorm = data_img_tensor * std_tensor + mean_tensor
|
| 465 |
+
# Permute to (H, W, C) and convert to numpy, clip to [0, 1]
|
| 466 |
+
data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
|
| 467 |
+
data_img_np = np.clip(data_img_np, 0, 1)
|
| 468 |
+
img_h, img_w = data_img_np.shape[:2]
|
| 469 |
+
|
| 470 |
+
# --- Process Attention & Certainty ---
|
| 471 |
+
# Average attention over last few steps (from original code)
|
| 472 |
+
start_step = max(0, step_i - 5)
|
| 473 |
+
attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
|
| 474 |
+
# Get certainties up to current step
|
| 475 |
+
certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
|
| 476 |
+
|
| 477 |
+
# --- Calculate Attention Paths (using bilinear interp) ---
|
| 478 |
+
# Interpolate attention to image size using bilinear for center finding
|
| 479 |
+
attention_interp_bilinear = F.interpolate(
|
| 480 |
+
torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
|
| 481 |
+
size=(img_h, img_w),
|
| 482 |
+
mode=interp_mode,
|
| 483 |
+
# align_corners=False
|
| 484 |
+
).squeeze(0) # Remove batch dim -> (Heads, H, W)
|
| 485 |
+
|
| 486 |
+
# Normalize each head's map to [0, 1]
|
| 487 |
+
# Deal with mean
|
| 488 |
+
attn_mean = attention_interp_bilinear.mean(0)
|
| 489 |
+
attn_mean_min = attn_mean.min()
|
| 490 |
+
attn_mean_max = attn_mean.max()
|
| 491 |
+
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
|
| 492 |
+
centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
|
| 493 |
+
|
| 494 |
+
if centers: # If islands found
|
| 495 |
+
largest_island_idx = np.argmax(areas)
|
| 496 |
+
current_center = centers[largest_island_idx] # (y, x)
|
| 497 |
+
head_routes[-1].append(current_center)
|
| 498 |
+
elif head_routes[-1]: # If no center now, repeat last known center if history exists
|
| 499 |
+
head_routes[-1].append(head_routes[-1][-1])
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
attn_min = attention_interp_bilinear.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
|
| 503 |
+
attn_max = attention_interp_bilinear.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
|
| 504 |
+
attention_interp_bilinear = (attention_interp_bilinear - attn_min) / (attn_max - attn_min + 1e-6)
|
| 505 |
+
|
| 506 |
+
# Store step color
|
| 507 |
+
current_colour = list(cmap_spectral(step_linspace[step_i]))
|
| 508 |
+
route_colours_step.append(current_colour)
|
| 509 |
+
|
| 510 |
+
# Find island center for each head
|
| 511 |
+
for head_i in range(n_heads):
|
| 512 |
+
attn_head_np = attention_interp_bilinear[head_i].detach().cpu().numpy()
|
| 513 |
+
# Keep threshold=0.7 based on original call
|
| 514 |
+
centers, areas = find_island_centers(attn_head_np, threshold=0.7)
|
| 515 |
+
|
| 516 |
+
if centers: # If islands found
|
| 517 |
+
largest_island_idx = np.argmax(areas)
|
| 518 |
+
current_center = centers[largest_island_idx] # (y, x)
|
| 519 |
+
head_routes[head_i].append(current_center)
|
| 520 |
+
elif head_routes[head_i]: # If no center now, repeat last known center if history exists
|
| 521 |
+
head_routes[head_i].append(head_routes[head_i][-1])
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# --- Plotting Setup ---
|
| 526 |
+
mosaic = [['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
|
| 527 |
+
['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
|
| 528 |
+
['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
|
| 529 |
+
['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
|
| 530 |
+
['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay', 'head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],
|
| 531 |
+
['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay','head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],
|
| 532 |
+
['head_8', 'head_8_overlay', 'head_9', 'head_9_overlay','head_10', 'head_10_overlay', 'head_11', 'head_11_overlay'],
|
| 533 |
+
['head_12', 'head_12_overlay', 'head_13', 'head_13_overlay','head_14', 'head_14_overlay', 'head_15', 'head_15_overlay'],
|
| 534 |
+
['probabilities', 'probabilities','probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty'],
|
| 535 |
+
]
|
| 536 |
+
|
| 537 |
+
img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
|
| 538 |
+
aspect_ratio = (8 * figscale, 9 * figscale * img_aspect) # W, H
|
| 539 |
+
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 540 |
+
|
| 541 |
+
for ax in axes.values():
|
| 542 |
+
ax.axis('off')
|
| 543 |
+
|
| 544 |
+
# --- Plot Certainty ---
|
| 545 |
+
ax_cert = axes['certainty']
|
| 546 |
+
ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
|
| 547 |
+
# Add background color based on prediction correctness at each step
|
| 548 |
+
for ii in range(len(certainties_now)):
|
| 549 |
+
is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
|
| 550 |
+
facecolor = 'limegreen' if is_correct else 'orchid'
|
| 551 |
+
ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
|
| 552 |
+
# Mark the last point
|
| 553 |
+
ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
|
| 554 |
+
ax_cert.axis('off')
|
| 555 |
+
ax_cert.set_ylim([0.05, 1.05])
|
| 556 |
+
ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
|
| 557 |
+
|
| 558 |
+
# --- Plot Probabilities ---
|
| 559 |
+
ax_prob = axes['probabilities']
|
| 560 |
+
# Get probabilities for the current step
|
| 561 |
+
ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
|
| 562 |
+
k = 15 # Top k predictions
|
| 563 |
+
topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
|
| 564 |
+
topk_indices = topk_indices.numpy()
|
| 565 |
+
topk_probs = topk_probs.numpy()
|
| 566 |
+
|
| 567 |
+
top_classes = np.array(class_labels)[topk_indices]
|
| 568 |
+
true_class_idx = ground_truth_target # Ground truth index
|
| 569 |
+
|
| 570 |
+
# Determine bar colors (green if correct, blue otherwise - consistent with original)
|
| 571 |
+
colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
|
| 572 |
+
|
| 573 |
+
# Plot horizontal bars (inverted range for top-down display)
|
| 574 |
+
ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
|
| 575 |
+
ax_prob.set_xlim([0, 1])
|
| 576 |
+
ax_prob.axis('off')
|
| 577 |
+
|
| 578 |
+
# Add text labels for top classes
|
| 579 |
+
for i, name_idx in enumerate(topk_indices):
|
| 580 |
+
name = class_labels[name_idx] # Get name from index
|
| 581 |
+
is_correct = name_idx == true_class_idx
|
| 582 |
+
fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
|
| 583 |
+
text_str = f'{name[:40]}' # Truncate long names
|
| 584 |
+
# Position text on the left side of the horizontal bars
|
| 585 |
+
ax_prob.text(
|
| 586 |
+
0.01, # Small offset from left edge
|
| 587 |
+
k - 1 - i, # Y-position corresponding to the bar
|
| 588 |
+
text_str,
|
| 589 |
+
#transform=ax_prob.transAxes, # Use data coordinates for Y
|
| 590 |
+
verticalalignment='center',
|
| 591 |
+
horizontalalignment='left',
|
| 592 |
+
fontsize=8,
|
| 593 |
+
color=fg_color,
|
| 594 |
+
alpha=0.9, # Slightly more visible than 0.5
|
| 595 |
+
path_effects=[
|
| 596 |
+
patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
|
| 597 |
+
patheffects.Normal()
|
| 598 |
+
])
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
# --- Plot Attention Heads & Overlays (using nearest interp) ---
|
| 602 |
+
# Re-interpolate attention using nearest neighbor for visual plotting
|
| 603 |
+
attention_interp_plot = F.interpolate(
|
| 604 |
+
torch.from_numpy(attention_now).unsqueeze(0).float(),
|
| 605 |
+
size=(img_h, img_w),
|
| 606 |
+
mode=interp_mode, # 'nearest'
|
| 607 |
+
).squeeze(0)
|
| 608 |
+
|
| 609 |
+
attn_mean = attention_interp_plot.mean(0)
|
| 610 |
+
attn_mean_min = attn_mean.min()
|
| 611 |
+
attn_mean_max = attn_mean.max()
|
| 612 |
+
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
# Normalize each head's map to [0, 1]
|
| 616 |
+
attn_min_plot = attention_interp_plot.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
|
| 617 |
+
attn_max_plot = attention_interp_plot.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
|
| 618 |
+
attention_interp_plot = (attention_interp_plot - attn_min_plot) / (attn_max_plot - attn_min_plot + 1e-6)
|
| 619 |
+
attention_interp_plot_np = attention_interp_plot.detach().cpu().numpy()
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
for head_i in list(range(n_heads)) + [-1]:
|
| 627 |
+
axname = f'head_{head_i}' if head_i != -1 else 'head_mean'
|
| 628 |
+
if axname not in axes: continue # Skip if mosaic doesn't have this head
|
| 629 |
+
|
| 630 |
+
ax = axes[axname]
|
| 631 |
+
ax_overlay = axes[f'{axname}_overlay']
|
| 632 |
+
|
| 633 |
+
# Plot attention heatmap
|
| 634 |
+
this_attn = attention_interp_plot_np[head_i] if head_i != -1 else attn_mean
|
| 635 |
+
img_to_plot = cmap_attention(this_attn)
|
| 636 |
+
ax.imshow(img_to_plot)
|
| 637 |
+
ax.axis('off')
|
| 638 |
+
|
| 639 |
+
# Plot overlay: image + paths
|
| 640 |
+
these_route_steps = head_routes[head_i]
|
| 641 |
+
arrow_scale = 1.5 if head_i != -1 else 3
|
| 642 |
+
|
| 643 |
+
if these_route_steps: # Only plot if path exists
|
| 644 |
+
# Separate y and x coordinates
|
| 645 |
+
y_coords, x_coords = zip(*these_route_steps)
|
| 646 |
+
y_coords = np.array(y_coords)
|
| 647 |
+
x_coords = np.array(x_coords)
|
| 648 |
+
|
| 649 |
+
# Flip y-coordinates for correct plotting (imshow origin is top-left)
|
| 650 |
+
# NOTE: Original flip seemed complex, simplifying to standard flip
|
| 651 |
+
y_coords_flipped = img_h - 1 - y_coords
|
| 652 |
+
|
| 653 |
+
# Show original image flipped vertically to match coordinate system
|
| 654 |
+
ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
|
| 655 |
+
|
| 656 |
+
# Draw arrows for path segments
|
| 657 |
+
# Arrow size scaling from original
|
| 658 |
+
for i in range(len(these_route_steps) - 1):
|
| 659 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 660 |
+
dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
|
| 661 |
+
|
| 662 |
+
# Draw white background arrow (thicker)
|
| 663 |
+
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
|
| 664 |
+
linewidth=1.6 * arrow_scale * 1.3,
|
| 665 |
+
head_width=1.9 * arrow_scale * 1.3,
|
| 666 |
+
head_length=1.4 * arrow_scale * 1.45,
|
| 667 |
+
fc='white', ec='white', length_includes_head=True, alpha=1)
|
| 668 |
+
# Draw colored foreground arrow
|
| 669 |
+
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
|
| 670 |
+
linewidth=1.6 * arrow_scale,
|
| 671 |
+
head_width=1.9 * arrow_scale,
|
| 672 |
+
head_length=1.4 * arrow_scale,
|
| 673 |
+
fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
|
| 674 |
+
length_includes_head=True)
|
| 675 |
+
|
| 676 |
+
else: # If no path yet, just show the image
|
| 677 |
+
ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
# Set limits and turn off axes for overlay
|
| 681 |
+
ax_overlay.set_xlim([0, img_w - 1])
|
| 682 |
+
ax_overlay.set_ylim([0, img_h - 1])
|
| 683 |
+
ax_overlay.axis('off')
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
# --- Finalize and Save Frame ---
|
| 687 |
+
fig.tight_layout(pad=0.1) # Adjust spacing
|
| 688 |
+
|
| 689 |
+
# Render the plot to a numpy array
|
| 690 |
+
canvas = fig.canvas
|
| 691 |
+
canvas.draw()
|
| 692 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 693 |
+
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
|
| 694 |
+
|
| 695 |
+
frames.append(image_numpy) # Add to list for GIF
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
plt.close(fig) # Close figure to free memory
|
| 700 |
+
|
| 701 |
+
# --- Save GIF ---
|
| 702 |
+
gif_path = os.path.join(index_output_dir, f'{str(di)}_viz.gif')
|
| 703 |
+
print(f"Saving GIF to {gif_path}...")
|
| 704 |
+
imageio.mimsave(gif_path, frames, fps=15, loop=0) # loop=0 means infinite loop
|
| 705 |
+
save_frames_to_mp4([fm[:,:,::-1] for fm in frames], os.path.join(index_output_dir, f'{str(di)}_viz.mp4'), fps=15, gop_size=1, preset='veryslow')
|
| 706 |
+
if 'demo' in args.actions:
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
# --- Select Data Indices ---
|
| 711 |
+
if not args.data_indices: # If list is empty
|
| 712 |
+
n_samples = len(validation_dataset)
|
| 713 |
+
num_to_sample = min(args.N_to_viz, n_samples)
|
| 714 |
+
replace = n_samples < num_to_sample
|
| 715 |
+
data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
|
| 716 |
+
print(f"Selected random indices: {data_indices}")
|
| 717 |
+
else:
|
| 718 |
+
data_indices = args.data_indices
|
| 719 |
+
print(f"Using specified indices: {data_indices}")
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
for di in data_indices:
|
| 723 |
+
|
| 724 |
+
index_output_dir = os.path.join(args.output_dir, str(di))
|
| 725 |
+
os.makedirs(index_output_dir, exist_ok=True)
|
| 726 |
+
|
| 727 |
+
print(f'\nBuilding viz for dataset index {di}.')
|
| 728 |
+
|
| 729 |
+
inputs, ground_truth_target = validation_dataset.__getitem__(int(di))
|
| 730 |
+
|
| 731 |
+
# Add batch dimension and send to device
|
| 732 |
+
inputs = inputs.to(device).unsqueeze(0)
|
| 733 |
+
predictions, certainties, synchronisations_over_time, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
|
| 734 |
+
|
| 735 |
+
# --- Reshape Attention ---
|
| 736 |
+
# Infer feature map size from model internals (assuming B=1)
|
| 737 |
+
h_feat, w_feat = model.kv_features.shape[-2:]
|
| 738 |
+
n_steps = predictions.size(-1)
|
| 739 |
+
n_heads = attention_tracking.shape[2]
|
| 740 |
+
# Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
|
| 741 |
+
attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
|
| 742 |
+
|
| 743 |
+
# --- Setup for Plotting ---
|
| 744 |
+
step_linspace = np.linspace(0, 1, n_steps) # For step colors
|
| 745 |
+
# Define color maps
|
| 746 |
+
cmap_steps = sns.color_palette("Spectral", as_cmap=True)
|
| 747 |
+
cmap_attention = sns.color_palette('viridis', as_cmap=True)
|
| 748 |
+
|
| 749 |
+
# Create output directory for this index
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
frames = [] # Store frames for GIF
|
| 753 |
+
head_routes = [] # Store (y,x) path points per head
|
| 754 |
+
route_colours_step = [] # Store colors for each step's path segments
|
| 755 |
+
|
| 756 |
+
# --- Loop Through Each Step ---
|
| 757 |
+
for step_i in range(n_steps):
|
| 758 |
+
|
| 759 |
+
# Store step color
|
| 760 |
+
current_colour = list(cmap_steps(step_linspace[step_i]))
|
| 761 |
+
route_colours_step.append(current_colour)
|
| 762 |
+
|
| 763 |
+
# --- Prepare Image for Display ---
|
| 764 |
+
# Denormalize the input tensor for visualization
|
| 765 |
+
data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
|
| 766 |
+
mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
|
| 767 |
+
std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
|
| 768 |
+
data_img_denorm = data_img_tensor * std_tensor + mean_tensor
|
| 769 |
+
# Permute to (H, W, C) and convert to numpy, clip to [0, 1]
|
| 770 |
+
data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
|
| 771 |
+
data_img_np = np.clip(data_img_np, 0, 1)
|
| 772 |
+
img_h, img_w = data_img_np.shape[:2]
|
| 773 |
+
|
| 774 |
+
# --- Process Attention & Certainty ---
|
| 775 |
+
# Average attention over last few steps (from original code)
|
| 776 |
+
start_step = max(0, step_i - 5)
|
| 777 |
+
attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
|
| 778 |
+
# Get certainties up to current step
|
| 779 |
+
certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
|
| 780 |
+
|
| 781 |
+
# --- Calculate Attention Paths (using bilinear interp) ---
|
| 782 |
+
# Interpolate attention to image size using bilinear for center finding
|
| 783 |
+
attention_interp_bilinear = F.interpolate(
|
| 784 |
+
torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
|
| 785 |
+
size=(img_h, img_w),
|
| 786 |
+
mode=interp_mode,
|
| 787 |
+
).squeeze(0) # Remove batch dim -> (Heads, H, W)
|
| 788 |
+
|
| 789 |
+
attn_mean = attention_interp_bilinear.mean(0)
|
| 790 |
+
attn_mean_min = attn_mean.min()
|
| 791 |
+
attn_mean_max = attn_mean.max()
|
| 792 |
+
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
|
| 793 |
+
centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
|
| 794 |
+
|
| 795 |
+
if centers: # If islands found
|
| 796 |
+
largest_island_idx = np.argmax(areas)
|
| 797 |
+
current_center = centers[largest_island_idx] # (y, x)
|
| 798 |
+
head_routes.append(current_center)
|
| 799 |
+
elif head_routes: # If no center now, repeat last known center if history exists
|
| 800 |
+
head_routes.append(head_routes[-1])
|
| 801 |
+
|
| 802 |
+
# --- Plotting Setup ---
|
| 803 |
+
# if n_heads != 8: print(f"Warning: Plotting layout assumes 8 heads, found {n_heads}. Layout may be incorrect.")
|
| 804 |
+
mosaic = [['head_0', 'head_1', 'head_2', 'head_3', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
|
| 805 |
+
['head_4', 'head_5', 'head_6', 'head_7', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
|
| 806 |
+
['head_8', 'head_9', 'head_10', 'head_11', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
|
| 807 |
+
['head_12', 'head_13', 'head_14', 'head_15', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
|
| 808 |
+
['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
|
| 809 |
+
['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
|
| 810 |
+
]
|
| 811 |
+
|
| 812 |
+
img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
|
| 813 |
+
aspect_ratio = (12 * figscale, 6 * figscale * img_aspect) # W, H
|
| 814 |
+
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 815 |
+
for ax in axes.values():
|
| 816 |
+
ax.axis('off')
|
| 817 |
+
|
| 818 |
+
# --- Plot Certainty ---
|
| 819 |
+
ax_cert = axes['certainty']
|
| 820 |
+
ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
|
| 821 |
+
# Add background color based on prediction correctness at each step
|
| 822 |
+
for ii in range(len(certainties_now)):
|
| 823 |
+
is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
|
| 824 |
+
facecolor = 'limegreen' if is_correct else 'orchid'
|
| 825 |
+
ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
|
| 826 |
+
# Mark the last point
|
| 827 |
+
ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
|
| 828 |
+
ax_cert.axis('off')
|
| 829 |
+
ax_cert.set_ylim([0.05, 1.05])
|
| 830 |
+
ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
|
| 831 |
+
|
| 832 |
+
# --- Plot Probabilities ---
|
| 833 |
+
ax_prob = axes['probabilities']
|
| 834 |
+
# Get probabilities for the current step
|
| 835 |
+
ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
|
| 836 |
+
k = 15 # Top k predictions
|
| 837 |
+
topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
|
| 838 |
+
topk_indices = topk_indices.numpy()
|
| 839 |
+
topk_probs = topk_probs.numpy()
|
| 840 |
+
|
| 841 |
+
top_classes = np.array(class_labels)[topk_indices]
|
| 842 |
+
true_class_idx = ground_truth_target # Ground truth index
|
| 843 |
+
|
| 844 |
+
# Determine bar colors (green if correct, blue otherwise - consistent with original)
|
| 845 |
+
colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
|
| 846 |
+
|
| 847 |
+
# Plot horizontal bars (inverted range for top-down display)
|
| 848 |
+
ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
|
| 849 |
+
ax_prob.set_xlim([0, 1])
|
| 850 |
+
ax_prob.axis('off')
|
| 851 |
+
|
| 852 |
+
# Add text labels for top classes
|
| 853 |
+
for i, name_idx in enumerate(topk_indices):
|
| 854 |
+
name = class_labels[name_idx] # Get name from index
|
| 855 |
+
is_correct = name_idx == true_class_idx
|
| 856 |
+
fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
|
| 857 |
+
text_str = f'{name[:40]}' # Truncate long names
|
| 858 |
+
# Position text on the left side of the horizontal bars
|
| 859 |
+
ax_prob.text(
|
| 860 |
+
0.01, # Small offset from left edge
|
| 861 |
+
k - 1 - i, # Y-position corresponding to the bar
|
| 862 |
+
text_str,
|
| 863 |
+
#transform=ax_prob.transAxes, # Use data coordinates for Y
|
| 864 |
+
verticalalignment='center',
|
| 865 |
+
horizontalalignment='left',
|
| 866 |
+
fontsize=8,
|
| 867 |
+
color=fg_color,
|
| 868 |
+
alpha=0.7, # Slightly more visible than 0.5
|
| 869 |
+
path_effects=[
|
| 870 |
+
patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
|
| 871 |
+
patheffects.Normal()
|
| 872 |
+
])
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
# --- Plot Attention Heads & Overlays (using nearest interp) ---
|
| 876 |
+
# Re-interpolate attention using nearest neighbor for visual plotting
|
| 877 |
+
attention_interp_plot = F.interpolate(
|
| 878 |
+
torch.from_numpy(attention_now).unsqueeze(0).float(),
|
| 879 |
+
size=(img_h, img_w),
|
| 880 |
+
mode=interp_mode # 'nearest'
|
| 881 |
+
).squeeze(0)
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
attn_mean = attention_interp_plot.mean(0)
|
| 885 |
+
attn_mean_min = attn_mean.min()
|
| 886 |
+
attn_mean_max = attn_mean.max()
|
| 887 |
+
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
img_to_plot = cmap_attention(attn_mean)
|
| 891 |
+
axes['head_mean'].imshow(img_to_plot)
|
| 892 |
+
axes['head_mean'].axis('off')
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
these_route_steps = head_routes
|
| 896 |
+
ax_overlay = axes['overlay']
|
| 897 |
+
|
| 898 |
+
if these_route_steps: # Only plot if path exists
|
| 899 |
+
# Separate y and x coordinates
|
| 900 |
+
y_coords, x_coords = zip(*these_route_steps)
|
| 901 |
+
y_coords = np.array(y_coords)
|
| 902 |
+
x_coords = np.array(x_coords)
|
| 903 |
+
|
| 904 |
+
# Flip y-coordinates for correct plotting (imshow origin is top-left)
|
| 905 |
+
# NOTE: Original flip seemed complex, simplifying to standard flip
|
| 906 |
+
y_coords_flipped = img_h - 1 - y_coords
|
| 907 |
+
|
| 908 |
+
# Show original image flipped vertically to match coordinate system
|
| 909 |
+
ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
|
| 910 |
+
|
| 911 |
+
# Draw arrows for path segments
|
| 912 |
+
arrow_scale = 2 # Arrow size scaling from original
|
| 913 |
+
for i in range(len(these_route_steps) - 1):
|
| 914 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 915 |
+
dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
|
| 916 |
+
|
| 917 |
+
# Draw white background arrow (thicker)
|
| 918 |
+
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
|
| 919 |
+
linewidth=1.6 * arrow_scale * 1.3,
|
| 920 |
+
head_width=1.9 * arrow_scale * 1.3,
|
| 921 |
+
head_length=1.4 * arrow_scale * 1.45,
|
| 922 |
+
fc='white', ec='white', length_includes_head=True, alpha=1)
|
| 923 |
+
# Draw colored foreground arrow
|
| 924 |
+
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
|
| 925 |
+
linewidth=1.6 * arrow_scale,
|
| 926 |
+
head_width=1.9 * arrow_scale,
|
| 927 |
+
head_length=1.4 * arrow_scale,
|
| 928 |
+
fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
|
| 929 |
+
length_includes_head=True)
|
| 930 |
+
# Set limits and turn off axes for overlay
|
| 931 |
+
ax_overlay.set_xlim([0, img_w - 1])
|
| 932 |
+
ax_overlay.set_ylim([0, img_h - 1])
|
| 933 |
+
ax_overlay.axis('off')
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
for head_i in range(n_heads):
|
| 937 |
+
if f'head_{head_i}' not in axes: continue # Skip if mosaic doesn't have this head
|
| 938 |
+
|
| 939 |
+
ax = axes[f'head_{head_i}']
|
| 940 |
+
|
| 941 |
+
# Plot attention heatmap
|
| 942 |
+
attn_up_to_now = attention_tracking[:step_i + 1, head_i].mean(0)
|
| 943 |
+
attn_up_to_now = (attn_up_to_now - attn_up_to_now.min())/(attn_up_to_now.max() - attn_up_to_now.min())
|
| 944 |
+
img_to_plot = cmap_attention(attn_up_to_now)
|
| 945 |
+
ax.imshow(img_to_plot)
|
| 946 |
+
ax.axis('off')
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
# --- Finalize and Save Frame ---
|
| 954 |
+
fig.tight_layout(pad=0.1) # Adjust spacing
|
| 955 |
+
|
| 956 |
+
# Render the plot to a numpy array
|
| 957 |
+
canvas = fig.canvas
|
| 958 |
+
canvas.draw()
|
| 959 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 960 |
+
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
|
| 961 |
+
|
| 962 |
+
frames.append(image_numpy) # Add to list for GIF
|
| 963 |
+
|
| 964 |
+
# Save individual frame if requested
|
| 965 |
+
if step_i==model.iterations-1:
|
| 966 |
+
fig.savefig(os.path.join(index_output_dir, f'frame_{step_i}.png'), dpi=200)
|
| 967 |
+
|
| 968 |
+
plt.close(fig) # Close figure to free memory
|
| 969 |
+
outfilename = os.path.join(index_output_dir, f'{di}_demo.mp4')
|
| 970 |
+
save_frames_to_mp4([fm[:,:,::-1] for fm in frames], outfilename, fps=15, gop_size=1, preset='veryslow')
|
| 971 |
+
|
| 972 |
+
|
tasks/image_classification/imagenet_classes.py
ADDED
|
@@ -0,0 +1,1007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
IMAGENET2012_CLASSES = OrderedDict(
|
| 5 |
+
{
|
| 6 |
+
"n01440764": "tench, Tinca tinca",
|
| 7 |
+
"n01443537": "goldfish, Carassius auratus",
|
| 8 |
+
"n01484850": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
|
| 9 |
+
"n01491361": "tiger shark, Galeocerdo cuvieri",
|
| 10 |
+
"n01494475": "hammerhead, hammerhead shark",
|
| 11 |
+
"n01496331": "electric ray, crampfish, numbfish, torpedo",
|
| 12 |
+
"n01498041": "stingray",
|
| 13 |
+
"n01514668": "cock",
|
| 14 |
+
"n01514859": "hen",
|
| 15 |
+
"n01518878": "ostrich, Struthio camelus",
|
| 16 |
+
"n01530575": "brambling, Fringilla montifringilla",
|
| 17 |
+
"n01531178": "goldfinch, Carduelis carduelis",
|
| 18 |
+
"n01532829": "house finch, linnet, Carpodacus mexicanus",
|
| 19 |
+
"n01534433": "junco, snowbird",
|
| 20 |
+
"n01537544": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
|
| 21 |
+
"n01558993": "robin, American robin, Turdus migratorius",
|
| 22 |
+
"n01560419": "bulbul",
|
| 23 |
+
"n01580077": "jay",
|
| 24 |
+
"n01582220": "magpie",
|
| 25 |
+
"n01592084": "chickadee",
|
| 26 |
+
"n01601694": "water ouzel, dipper",
|
| 27 |
+
"n01608432": "kite",
|
| 28 |
+
"n01614925": "bald eagle, American eagle, Haliaeetus leucocephalus",
|
| 29 |
+
"n01616318": "vulture",
|
| 30 |
+
"n01622779": "great grey owl, great gray owl, Strix nebulosa",
|
| 31 |
+
"n01629819": "European fire salamander, Salamandra salamandra",
|
| 32 |
+
"n01630670": "common newt, Triturus vulgaris",
|
| 33 |
+
"n01631663": "eft",
|
| 34 |
+
"n01632458": "spotted salamander, Ambystoma maculatum",
|
| 35 |
+
"n01632777": "axolotl, mud puppy, Ambystoma mexicanum",
|
| 36 |
+
"n01641577": "bullfrog, Rana catesbeiana",
|
| 37 |
+
"n01644373": "tree frog, tree-frog",
|
| 38 |
+
"n01644900": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
|
| 39 |
+
"n01664065": "loggerhead, loggerhead turtle, Caretta caretta",
|
| 40 |
+
"n01665541": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
|
| 41 |
+
"n01667114": "mud turtle",
|
| 42 |
+
"n01667778": "terrapin",
|
| 43 |
+
"n01669191": "box turtle, box tortoise",
|
| 44 |
+
"n01675722": "banded gecko",
|
| 45 |
+
"n01677366": "common iguana, iguana, Iguana iguana",
|
| 46 |
+
"n01682714": "American chameleon, anole, Anolis carolinensis",
|
| 47 |
+
"n01685808": "whiptail, whiptail lizard",
|
| 48 |
+
"n01687978": "agama",
|
| 49 |
+
"n01688243": "frilled lizard, Chlamydosaurus kingi",
|
| 50 |
+
"n01689811": "alligator lizard",
|
| 51 |
+
"n01692333": "Gila monster, Heloderma suspectum",
|
| 52 |
+
"n01693334": "green lizard, Lacerta viridis",
|
| 53 |
+
"n01694178": "African chameleon, Chamaeleo chamaeleon",
|
| 54 |
+
"n01695060": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
|
| 55 |
+
"n01697457": "African crocodile, Nile crocodile, Crocodylus niloticus",
|
| 56 |
+
"n01698640": "American alligator, Alligator mississipiensis",
|
| 57 |
+
"n01704323": "triceratops",
|
| 58 |
+
"n01728572": "thunder snake, worm snake, Carphophis amoenus",
|
| 59 |
+
"n01728920": "ringneck snake, ring-necked snake, ring snake",
|
| 60 |
+
"n01729322": "hognose snake, puff adder, sand viper",
|
| 61 |
+
"n01729977": "green snake, grass snake",
|
| 62 |
+
"n01734418": "king snake, kingsnake",
|
| 63 |
+
"n01735189": "garter snake, grass snake",
|
| 64 |
+
"n01737021": "water snake",
|
| 65 |
+
"n01739381": "vine snake",
|
| 66 |
+
"n01740131": "night snake, Hypsiglena torquata",
|
| 67 |
+
"n01742172": "boa constrictor, Constrictor constrictor",
|
| 68 |
+
"n01744401": "rock python, rock snake, Python sebae",
|
| 69 |
+
"n01748264": "Indian cobra, Naja naja",
|
| 70 |
+
"n01749939": "green mamba",
|
| 71 |
+
"n01751748": "sea snake",
|
| 72 |
+
"n01753488": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
|
| 73 |
+
"n01755581": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
|
| 74 |
+
"n01756291": "sidewinder, horned rattlesnake, Crotalus cerastes",
|
| 75 |
+
"n01768244": "trilobite",
|
| 76 |
+
"n01770081": "harvestman, daddy longlegs, Phalangium opilio",
|
| 77 |
+
"n01770393": "scorpion",
|
| 78 |
+
"n01773157": "black and gold garden spider, Argiope aurantia",
|
| 79 |
+
"n01773549": "barn spider, Araneus cavaticus",
|
| 80 |
+
"n01773797": "garden spider, Aranea diademata",
|
| 81 |
+
"n01774384": "black widow, Latrodectus mactans",
|
| 82 |
+
"n01774750": "tarantula",
|
| 83 |
+
"n01775062": "wolf spider, hunting spider",
|
| 84 |
+
"n01776313": "tick",
|
| 85 |
+
"n01784675": "centipede",
|
| 86 |
+
"n01795545": "black grouse",
|
| 87 |
+
"n01796340": "ptarmigan",
|
| 88 |
+
"n01797886": "ruffed grouse, partridge, Bonasa umbellus",
|
| 89 |
+
"n01798484": "prairie chicken, prairie grouse, prairie fowl",
|
| 90 |
+
"n01806143": "peacock",
|
| 91 |
+
"n01806567": "quail",
|
| 92 |
+
"n01807496": "partridge",
|
| 93 |
+
"n01817953": "African grey, African gray, Psittacus erithacus",
|
| 94 |
+
"n01818515": "macaw",
|
| 95 |
+
"n01819313": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
|
| 96 |
+
"n01820546": "lorikeet",
|
| 97 |
+
"n01824575": "coucal",
|
| 98 |
+
"n01828970": "bee eater",
|
| 99 |
+
"n01829413": "hornbill",
|
| 100 |
+
"n01833805": "hummingbird",
|
| 101 |
+
"n01843065": "jacamar",
|
| 102 |
+
"n01843383": "toucan",
|
| 103 |
+
"n01847000": "drake",
|
| 104 |
+
"n01855032": "red-breasted merganser, Mergus serrator",
|
| 105 |
+
"n01855672": "goose",
|
| 106 |
+
"n01860187": "black swan, Cygnus atratus",
|
| 107 |
+
"n01871265": "tusker",
|
| 108 |
+
"n01872401": "echidna, spiny anteater, anteater",
|
| 109 |
+
"n01873310": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
|
| 110 |
+
"n01877812": "wallaby, brush kangaroo",
|
| 111 |
+
"n01882714": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
|
| 112 |
+
"n01883070": "wombat",
|
| 113 |
+
"n01910747": "jellyfish",
|
| 114 |
+
"n01914609": "sea anemone, anemone",
|
| 115 |
+
"n01917289": "brain coral",
|
| 116 |
+
"n01924916": "flatworm, platyhelminth",
|
| 117 |
+
"n01930112": "nematode, nematode worm, roundworm",
|
| 118 |
+
"n01943899": "conch",
|
| 119 |
+
"n01944390": "snail",
|
| 120 |
+
"n01945685": "slug",
|
| 121 |
+
"n01950731": "sea slug, nudibranch",
|
| 122 |
+
"n01955084": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
|
| 123 |
+
"n01968897": "chambered nautilus, pearly nautilus, nautilus",
|
| 124 |
+
"n01978287": "Dungeness crab, Cancer magister",
|
| 125 |
+
"n01978455": "rock crab, Cancer irroratus",
|
| 126 |
+
"n01980166": "fiddler crab",
|
| 127 |
+
"n01981276": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
|
| 128 |
+
"n01983481": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
|
| 129 |
+
"n01984695": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
|
| 130 |
+
"n01985128": "crayfish, crawfish, crawdad, crawdaddy",
|
| 131 |
+
"n01986214": "hermit crab",
|
| 132 |
+
"n01990800": "isopod",
|
| 133 |
+
"n02002556": "white stork, Ciconia ciconia",
|
| 134 |
+
"n02002724": "black stork, Ciconia nigra",
|
| 135 |
+
"n02006656": "spoonbill",
|
| 136 |
+
"n02007558": "flamingo",
|
| 137 |
+
"n02009229": "little blue heron, Egretta caerulea",
|
| 138 |
+
"n02009912": "American egret, great white heron, Egretta albus",
|
| 139 |
+
"n02011460": "bittern",
|
| 140 |
+
"n02012849": "crane",
|
| 141 |
+
"n02013706": "limpkin, Aramus pictus",
|
| 142 |
+
"n02017213": "European gallinule, Porphyrio porphyrio",
|
| 143 |
+
"n02018207": "American coot, marsh hen, mud hen, water hen, Fulica americana",
|
| 144 |
+
"n02018795": "bustard",
|
| 145 |
+
"n02025239": "ruddy turnstone, Arenaria interpres",
|
| 146 |
+
"n02027492": "red-backed sandpiper, dunlin, Erolia alpina",
|
| 147 |
+
"n02028035": "redshank, Tringa totanus",
|
| 148 |
+
"n02033041": "dowitcher",
|
| 149 |
+
"n02037110": "oystercatcher, oyster catcher",
|
| 150 |
+
"n02051845": "pelican",
|
| 151 |
+
"n02056570": "king penguin, Aptenodytes patagonica",
|
| 152 |
+
"n02058221": "albatross, mollymawk",
|
| 153 |
+
"n02066245": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
|
| 154 |
+
"n02071294": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
|
| 155 |
+
"n02074367": "dugong, Dugong dugon",
|
| 156 |
+
"n02077923": "sea lion",
|
| 157 |
+
"n02085620": "Chihuahua",
|
| 158 |
+
"n02085782": "Japanese spaniel",
|
| 159 |
+
"n02085936": "Maltese dog, Maltese terrier, Maltese",
|
| 160 |
+
"n02086079": "Pekinese, Pekingese, Peke",
|
| 161 |
+
"n02086240": "Shih-Tzu",
|
| 162 |
+
"n02086646": "Blenheim spaniel",
|
| 163 |
+
"n02086910": "papillon",
|
| 164 |
+
"n02087046": "toy terrier",
|
| 165 |
+
"n02087394": "Rhodesian ridgeback",
|
| 166 |
+
"n02088094": "Afghan hound, Afghan",
|
| 167 |
+
"n02088238": "basset, basset hound",
|
| 168 |
+
"n02088364": "beagle",
|
| 169 |
+
"n02088466": "bloodhound, sleuthhound",
|
| 170 |
+
"n02088632": "bluetick",
|
| 171 |
+
"n02089078": "black-and-tan coonhound",
|
| 172 |
+
"n02089867": "Walker hound, Walker foxhound",
|
| 173 |
+
"n02089973": "English foxhound",
|
| 174 |
+
"n02090379": "redbone",
|
| 175 |
+
"n02090622": "borzoi, Russian wolfhound",
|
| 176 |
+
"n02090721": "Irish wolfhound",
|
| 177 |
+
"n02091032": "Italian greyhound",
|
| 178 |
+
"n02091134": "whippet",
|
| 179 |
+
"n02091244": "Ibizan hound, Ibizan Podenco",
|
| 180 |
+
"n02091467": "Norwegian elkhound, elkhound",
|
| 181 |
+
"n02091635": "otterhound, otter hound",
|
| 182 |
+
"n02091831": "Saluki, gazelle hound",
|
| 183 |
+
"n02092002": "Scottish deerhound, deerhound",
|
| 184 |
+
"n02092339": "Weimaraner",
|
| 185 |
+
"n02093256": "Staffordshire bullterrier, Staffordshire bull terrier",
|
| 186 |
+
"n02093428": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
|
| 187 |
+
"n02093647": "Bedlington terrier",
|
| 188 |
+
"n02093754": "Border terrier",
|
| 189 |
+
"n02093859": "Kerry blue terrier",
|
| 190 |
+
"n02093991": "Irish terrier",
|
| 191 |
+
"n02094114": "Norfolk terrier",
|
| 192 |
+
"n02094258": "Norwich terrier",
|
| 193 |
+
"n02094433": "Yorkshire terrier",
|
| 194 |
+
"n02095314": "wire-haired fox terrier",
|
| 195 |
+
"n02095570": "Lakeland terrier",
|
| 196 |
+
"n02095889": "Sealyham terrier, Sealyham",
|
| 197 |
+
"n02096051": "Airedale, Airedale terrier",
|
| 198 |
+
"n02096177": "cairn, cairn terrier",
|
| 199 |
+
"n02096294": "Australian terrier",
|
| 200 |
+
"n02096437": "Dandie Dinmont, Dandie Dinmont terrier",
|
| 201 |
+
"n02096585": "Boston bull, Boston terrier",
|
| 202 |
+
"n02097047": "miniature schnauzer",
|
| 203 |
+
"n02097130": "giant schnauzer",
|
| 204 |
+
"n02097209": "standard schnauzer",
|
| 205 |
+
"n02097298": "Scotch terrier, Scottish terrier, Scottie",
|
| 206 |
+
"n02097474": "Tibetan terrier, chrysanthemum dog",
|
| 207 |
+
"n02097658": "silky terrier, Sydney silky",
|
| 208 |
+
"n02098105": "soft-coated wheaten terrier",
|
| 209 |
+
"n02098286": "West Highland white terrier",
|
| 210 |
+
"n02098413": "Lhasa, Lhasa apso",
|
| 211 |
+
"n02099267": "flat-coated retriever",
|
| 212 |
+
"n02099429": "curly-coated retriever",
|
| 213 |
+
"n02099601": "golden retriever",
|
| 214 |
+
"n02099712": "Labrador retriever",
|
| 215 |
+
"n02099849": "Chesapeake Bay retriever",
|
| 216 |
+
"n02100236": "German short-haired pointer",
|
| 217 |
+
"n02100583": "vizsla, Hungarian pointer",
|
| 218 |
+
"n02100735": "English setter",
|
| 219 |
+
"n02100877": "Irish setter, red setter",
|
| 220 |
+
"n02101006": "Gordon setter",
|
| 221 |
+
"n02101388": "Brittany spaniel",
|
| 222 |
+
"n02101556": "clumber, clumber spaniel",
|
| 223 |
+
"n02102040": "English springer, English springer spaniel",
|
| 224 |
+
"n02102177": "Welsh springer spaniel",
|
| 225 |
+
"n02102318": "cocker spaniel, English cocker spaniel, cocker",
|
| 226 |
+
"n02102480": "Sussex spaniel",
|
| 227 |
+
"n02102973": "Irish water spaniel",
|
| 228 |
+
"n02104029": "kuvasz",
|
| 229 |
+
"n02104365": "schipperke",
|
| 230 |
+
"n02105056": "groenendael",
|
| 231 |
+
"n02105162": "malinois",
|
| 232 |
+
"n02105251": "briard",
|
| 233 |
+
"n02105412": "kelpie",
|
| 234 |
+
"n02105505": "komondor",
|
| 235 |
+
"n02105641": "Old English sheepdog, bobtail",
|
| 236 |
+
"n02105855": "Shetland sheepdog, Shetland sheep dog, Shetland",
|
| 237 |
+
"n02106030": "collie",
|
| 238 |
+
"n02106166": "Border collie",
|
| 239 |
+
"n02106382": "Bouvier des Flandres, Bouviers des Flandres",
|
| 240 |
+
"n02106550": "Rottweiler",
|
| 241 |
+
"n02106662": "German shepherd, German shepherd dog, German police dog, alsatian",
|
| 242 |
+
"n02107142": "Doberman, Doberman pinscher",
|
| 243 |
+
"n02107312": "miniature pinscher",
|
| 244 |
+
"n02107574": "Greater Swiss Mountain dog",
|
| 245 |
+
"n02107683": "Bernese mountain dog",
|
| 246 |
+
"n02107908": "Appenzeller",
|
| 247 |
+
"n02108000": "EntleBucher",
|
| 248 |
+
"n02108089": "boxer",
|
| 249 |
+
"n02108422": "bull mastiff",
|
| 250 |
+
"n02108551": "Tibetan mastiff",
|
| 251 |
+
"n02108915": "French bulldog",
|
| 252 |
+
"n02109047": "Great Dane",
|
| 253 |
+
"n02109525": "Saint Bernard, St Bernard",
|
| 254 |
+
"n02109961": "Eskimo dog, husky",
|
| 255 |
+
"n02110063": "malamute, malemute, Alaskan malamute",
|
| 256 |
+
"n02110185": "Siberian husky",
|
| 257 |
+
"n02110341": "dalmatian, coach dog, carriage dog",
|
| 258 |
+
"n02110627": "affenpinscher, monkey pinscher, monkey dog",
|
| 259 |
+
"n02110806": "basenji",
|
| 260 |
+
"n02110958": "pug, pug-dog",
|
| 261 |
+
"n02111129": "Leonberg",
|
| 262 |
+
"n02111277": "Newfoundland, Newfoundland dog",
|
| 263 |
+
"n02111500": "Great Pyrenees",
|
| 264 |
+
"n02111889": "Samoyed, Samoyede",
|
| 265 |
+
"n02112018": "Pomeranian",
|
| 266 |
+
"n02112137": "chow, chow chow",
|
| 267 |
+
"n02112350": "keeshond",
|
| 268 |
+
"n02112706": "Brabancon griffon",
|
| 269 |
+
"n02113023": "Pembroke, Pembroke Welsh corgi",
|
| 270 |
+
"n02113186": "Cardigan, Cardigan Welsh corgi",
|
| 271 |
+
"n02113624": "toy poodle",
|
| 272 |
+
"n02113712": "miniature poodle",
|
| 273 |
+
"n02113799": "standard poodle",
|
| 274 |
+
"n02113978": "Mexican hairless",
|
| 275 |
+
"n02114367": "timber wolf, grey wolf, gray wolf, Canis lupus",
|
| 276 |
+
"n02114548": "white wolf, Arctic wolf, Canis lupus tundrarum",
|
| 277 |
+
"n02114712": "red wolf, maned wolf, Canis rufus, Canis niger",
|
| 278 |
+
"n02114855": "coyote, prairie wolf, brush wolf, Canis latrans",
|
| 279 |
+
"n02115641": "dingo, warrigal, warragal, Canis dingo",
|
| 280 |
+
"n02115913": "dhole, Cuon alpinus",
|
| 281 |
+
"n02116738": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
|
| 282 |
+
"n02117135": "hyena, hyaena",
|
| 283 |
+
"n02119022": "red fox, Vulpes vulpes",
|
| 284 |
+
"n02119789": "kit fox, Vulpes macrotis",
|
| 285 |
+
"n02120079": "Arctic fox, white fox, Alopex lagopus",
|
| 286 |
+
"n02120505": "grey fox, gray fox, Urocyon cinereoargenteus",
|
| 287 |
+
"n02123045": "tabby, tabby cat",
|
| 288 |
+
"n02123159": "tiger cat",
|
| 289 |
+
"n02123394": "Persian cat",
|
| 290 |
+
"n02123597": "Siamese cat, Siamese",
|
| 291 |
+
"n02124075": "Egyptian cat",
|
| 292 |
+
"n02125311": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
|
| 293 |
+
"n02127052": "lynx, catamount",
|
| 294 |
+
"n02128385": "leopard, Panthera pardus",
|
| 295 |
+
"n02128757": "snow leopard, ounce, Panthera uncia",
|
| 296 |
+
"n02128925": "jaguar, panther, Panthera onca, Felis onca",
|
| 297 |
+
"n02129165": "lion, king of beasts, Panthera leo",
|
| 298 |
+
"n02129604": "tiger, Panthera tigris",
|
| 299 |
+
"n02130308": "cheetah, chetah, Acinonyx jubatus",
|
| 300 |
+
"n02132136": "brown bear, bruin, Ursus arctos",
|
| 301 |
+
"n02133161": "American black bear, black bear, Ursus americanus, Euarctos americanus",
|
| 302 |
+
"n02134084": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
|
| 303 |
+
"n02134418": "sloth bear, Melursus ursinus, Ursus ursinus",
|
| 304 |
+
"n02137549": "mongoose",
|
| 305 |
+
"n02138441": "meerkat, mierkat",
|
| 306 |
+
"n02165105": "tiger beetle",
|
| 307 |
+
"n02165456": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
|
| 308 |
+
"n02167151": "ground beetle, carabid beetle",
|
| 309 |
+
"n02168699": "long-horned beetle, longicorn, longicorn beetle",
|
| 310 |
+
"n02169497": "leaf beetle, chrysomelid",
|
| 311 |
+
"n02172182": "dung beetle",
|
| 312 |
+
"n02174001": "rhinoceros beetle",
|
| 313 |
+
"n02177972": "weevil",
|
| 314 |
+
"n02190166": "fly",
|
| 315 |
+
"n02206856": "bee",
|
| 316 |
+
"n02219486": "ant, emmet, pismire",
|
| 317 |
+
"n02226429": "grasshopper, hopper",
|
| 318 |
+
"n02229544": "cricket",
|
| 319 |
+
"n02231487": "walking stick, walkingstick, stick insect",
|
| 320 |
+
"n02233338": "cockroach, roach",
|
| 321 |
+
"n02236044": "mantis, mantid",
|
| 322 |
+
"n02256656": "cicada, cicala",
|
| 323 |
+
"n02259212": "leafhopper",
|
| 324 |
+
"n02264363": "lacewing, lacewing fly",
|
| 325 |
+
"n02268443": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
|
| 326 |
+
"n02268853": "damselfly",
|
| 327 |
+
"n02276258": "admiral",
|
| 328 |
+
"n02277742": "ringlet, ringlet butterfly",
|
| 329 |
+
"n02279972": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
|
| 330 |
+
"n02280649": "cabbage butterfly",
|
| 331 |
+
"n02281406": "sulphur butterfly, sulfur butterfly",
|
| 332 |
+
"n02281787": "lycaenid, lycaenid butterfly",
|
| 333 |
+
"n02317335": "starfish, sea star",
|
| 334 |
+
"n02319095": "sea urchin",
|
| 335 |
+
"n02321529": "sea cucumber, holothurian",
|
| 336 |
+
"n02325366": "wood rabbit, cottontail, cottontail rabbit",
|
| 337 |
+
"n02326432": "hare",
|
| 338 |
+
"n02328150": "Angora, Angora rabbit",
|
| 339 |
+
"n02342885": "hamster",
|
| 340 |
+
"n02346627": "porcupine, hedgehog",
|
| 341 |
+
"n02356798": "fox squirrel, eastern fox squirrel, Sciurus niger",
|
| 342 |
+
"n02361337": "marmot",
|
| 343 |
+
"n02363005": "beaver",
|
| 344 |
+
"n02364673": "guinea pig, Cavia cobaya",
|
| 345 |
+
"n02389026": "sorrel",
|
| 346 |
+
"n02391049": "zebra",
|
| 347 |
+
"n02395406": "hog, pig, grunter, squealer, Sus scrofa",
|
| 348 |
+
"n02396427": "wild boar, boar, Sus scrofa",
|
| 349 |
+
"n02397096": "warthog",
|
| 350 |
+
"n02398521": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
|
| 351 |
+
"n02403003": "ox",
|
| 352 |
+
"n02408429": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
|
| 353 |
+
"n02410509": "bison",
|
| 354 |
+
"n02412080": "ram, tup",
|
| 355 |
+
"n02415577": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
|
| 356 |
+
"n02417914": "ibex, Capra ibex",
|
| 357 |
+
"n02422106": "hartebeest",
|
| 358 |
+
"n02422699": "impala, Aepyceros melampus",
|
| 359 |
+
"n02423022": "gazelle",
|
| 360 |
+
"n02437312": "Arabian camel, dromedary, Camelus dromedarius",
|
| 361 |
+
"n02437616": "llama",
|
| 362 |
+
"n02441942": "weasel",
|
| 363 |
+
"n02442845": "mink",
|
| 364 |
+
"n02443114": "polecat, fitch, foulmart, foumart, Mustela putorius",
|
| 365 |
+
"n02443484": "black-footed ferret, ferret, Mustela nigripes",
|
| 366 |
+
"n02444819": "otter",
|
| 367 |
+
"n02445715": "skunk, polecat, wood pussy",
|
| 368 |
+
"n02447366": "badger",
|
| 369 |
+
"n02454379": "armadillo",
|
| 370 |
+
"n02457408": "three-toed sloth, ai, Bradypus tridactylus",
|
| 371 |
+
"n02480495": "orangutan, orang, orangutang, Pongo pygmaeus",
|
| 372 |
+
"n02480855": "gorilla, Gorilla gorilla",
|
| 373 |
+
"n02481823": "chimpanzee, chimp, Pan troglodytes",
|
| 374 |
+
"n02483362": "gibbon, Hylobates lar",
|
| 375 |
+
"n02483708": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
|
| 376 |
+
"n02484975": "guenon, guenon monkey",
|
| 377 |
+
"n02486261": "patas, hussar monkey, Erythrocebus patas",
|
| 378 |
+
"n02486410": "baboon",
|
| 379 |
+
"n02487347": "macaque",
|
| 380 |
+
"n02488291": "langur",
|
| 381 |
+
"n02488702": "colobus, colobus monkey",
|
| 382 |
+
"n02489166": "proboscis monkey, Nasalis larvatus",
|
| 383 |
+
"n02490219": "marmoset",
|
| 384 |
+
"n02492035": "capuchin, ringtail, Cebus capucinus",
|
| 385 |
+
"n02492660": "howler monkey, howler",
|
| 386 |
+
"n02493509": "titi, titi monkey",
|
| 387 |
+
"n02493793": "spider monkey, Ateles geoffroyi",
|
| 388 |
+
"n02494079": "squirrel monkey, Saimiri sciureus",
|
| 389 |
+
"n02497673": "Madagascar cat, ring-tailed lemur, Lemur catta",
|
| 390 |
+
"n02500267": "indri, indris, Indri indri, Indri brevicaudatus",
|
| 391 |
+
"n02504013": "Indian elephant, Elephas maximus",
|
| 392 |
+
"n02504458": "African elephant, Loxodonta africana",
|
| 393 |
+
"n02509815": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
|
| 394 |
+
"n02510455": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
|
| 395 |
+
"n02514041": "barracouta, snoek",
|
| 396 |
+
"n02526121": "eel",
|
| 397 |
+
"n02536864": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
|
| 398 |
+
"n02606052": "rock beauty, Holocanthus tricolor",
|
| 399 |
+
"n02607072": "anemone fish",
|
| 400 |
+
"n02640242": "sturgeon",
|
| 401 |
+
"n02641379": "gar, garfish, garpike, billfish, Lepisosteus osseus",
|
| 402 |
+
"n02643566": "lionfish",
|
| 403 |
+
"n02655020": "puffer, pufferfish, blowfish, globefish",
|
| 404 |
+
"n02666196": "abacus",
|
| 405 |
+
"n02667093": "abaya",
|
| 406 |
+
"n02669723": "academic gown, academic robe, judge's robe",
|
| 407 |
+
"n02672831": "accordion, piano accordion, squeeze box",
|
| 408 |
+
"n02676566": "acoustic guitar",
|
| 409 |
+
"n02687172": "aircraft carrier, carrier, flattop, attack aircraft carrier",
|
| 410 |
+
"n02690373": "airliner",
|
| 411 |
+
"n02692877": "airship, dirigible",
|
| 412 |
+
"n02699494": "altar",
|
| 413 |
+
"n02701002": "ambulance",
|
| 414 |
+
"n02704792": "amphibian, amphibious vehicle",
|
| 415 |
+
"n02708093": "analog clock",
|
| 416 |
+
"n02727426": "apiary, bee house",
|
| 417 |
+
"n02730930": "apron",
|
| 418 |
+
"n02747177": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
|
| 419 |
+
"n02749479": "assault rifle, assault gun",
|
| 420 |
+
"n02769748": "backpack, back pack, knapsack, packsack, rucksack, haversack",
|
| 421 |
+
"n02776631": "bakery, bakeshop, bakehouse",
|
| 422 |
+
"n02777292": "balance beam, beam",
|
| 423 |
+
"n02782093": "balloon",
|
| 424 |
+
"n02783161": "ballpoint, ballpoint pen, ballpen, Biro",
|
| 425 |
+
"n02786058": "Band Aid",
|
| 426 |
+
"n02787622": "banjo",
|
| 427 |
+
"n02788148": "bannister, banister, balustrade, balusters, handrail",
|
| 428 |
+
"n02790996": "barbell",
|
| 429 |
+
"n02791124": "barber chair",
|
| 430 |
+
"n02791270": "barbershop",
|
| 431 |
+
"n02793495": "barn",
|
| 432 |
+
"n02794156": "barometer",
|
| 433 |
+
"n02795169": "barrel, cask",
|
| 434 |
+
"n02797295": "barrow, garden cart, lawn cart, wheelbarrow",
|
| 435 |
+
"n02799071": "baseball",
|
| 436 |
+
"n02802426": "basketball",
|
| 437 |
+
"n02804414": "bassinet",
|
| 438 |
+
"n02804610": "bassoon",
|
| 439 |
+
"n02807133": "bathing cap, swimming cap",
|
| 440 |
+
"n02808304": "bath towel",
|
| 441 |
+
"n02808440": "bathtub, bathing tub, bath, tub",
|
| 442 |
+
"n02814533": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
|
| 443 |
+
"n02814860": "beacon, lighthouse, beacon light, pharos",
|
| 444 |
+
"n02815834": "beaker",
|
| 445 |
+
"n02817516": "bearskin, busby, shako",
|
| 446 |
+
"n02823428": "beer bottle",
|
| 447 |
+
"n02823750": "beer glass",
|
| 448 |
+
"n02825657": "bell cote, bell cot",
|
| 449 |
+
"n02834397": "bib",
|
| 450 |
+
"n02835271": "bicycle-built-for-two, tandem bicycle, tandem",
|
| 451 |
+
"n02837789": "bikini, two-piece",
|
| 452 |
+
"n02840245": "binder, ring-binder",
|
| 453 |
+
"n02841315": "binoculars, field glasses, opera glasses",
|
| 454 |
+
"n02843684": "birdhouse",
|
| 455 |
+
"n02859443": "boathouse",
|
| 456 |
+
"n02860847": "bobsled, bobsleigh, bob",
|
| 457 |
+
"n02865351": "bolo tie, bolo, bola tie, bola",
|
| 458 |
+
"n02869837": "bonnet, poke bonnet",
|
| 459 |
+
"n02870880": "bookcase",
|
| 460 |
+
"n02871525": "bookshop, bookstore, bookstall",
|
| 461 |
+
"n02877765": "bottlecap",
|
| 462 |
+
"n02879718": "bow",
|
| 463 |
+
"n02883205": "bow tie, bow-tie, bowtie",
|
| 464 |
+
"n02892201": "brass, memorial tablet, plaque",
|
| 465 |
+
"n02892767": "brassiere, bra, bandeau",
|
| 466 |
+
"n02894605": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
|
| 467 |
+
"n02895154": "breastplate, aegis, egis",
|
| 468 |
+
"n02906734": "broom",
|
| 469 |
+
"n02909870": "bucket, pail",
|
| 470 |
+
"n02910353": "buckle",
|
| 471 |
+
"n02916936": "bulletproof vest",
|
| 472 |
+
"n02917067": "bullet train, bullet",
|
| 473 |
+
"n02927161": "butcher shop, meat market",
|
| 474 |
+
"n02930766": "cab, hack, taxi, taxicab",
|
| 475 |
+
"n02939185": "caldron, cauldron",
|
| 476 |
+
"n02948072": "candle, taper, wax light",
|
| 477 |
+
"n02950826": "cannon",
|
| 478 |
+
"n02951358": "canoe",
|
| 479 |
+
"n02951585": "can opener, tin opener",
|
| 480 |
+
"n02963159": "cardigan",
|
| 481 |
+
"n02965783": "car mirror",
|
| 482 |
+
"n02966193": "carousel, carrousel, merry-go-round, roundabout, whirligig",
|
| 483 |
+
"n02966687": "carpenter's kit, tool kit",
|
| 484 |
+
"n02971356": "carton",
|
| 485 |
+
"n02974003": "car wheel",
|
| 486 |
+
"n02977058": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
|
| 487 |
+
"n02978881": "cassette",
|
| 488 |
+
"n02979186": "cassette player",
|
| 489 |
+
"n02980441": "castle",
|
| 490 |
+
"n02981792": "catamaran",
|
| 491 |
+
"n02988304": "CD player",
|
| 492 |
+
"n02992211": "cello, violoncello",
|
| 493 |
+
"n02992529": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
|
| 494 |
+
"n02999410": "chain",
|
| 495 |
+
"n03000134": "chainlink fence",
|
| 496 |
+
"n03000247": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
|
| 497 |
+
"n03000684": "chain saw, chainsaw",
|
| 498 |
+
"n03014705": "chest",
|
| 499 |
+
"n03016953": "chiffonier, commode",
|
| 500 |
+
"n03017168": "chime, bell, gong",
|
| 501 |
+
"n03018349": "china cabinet, china closet",
|
| 502 |
+
"n03026506": "Christmas stocking",
|
| 503 |
+
"n03028079": "church, church building",
|
| 504 |
+
"n03032252": "cinema, movie theater, movie theatre, movie house, picture palace",
|
| 505 |
+
"n03041632": "cleaver, meat cleaver, chopper",
|
| 506 |
+
"n03042490": "cliff dwelling",
|
| 507 |
+
"n03045698": "cloak",
|
| 508 |
+
"n03047690": "clog, geta, patten, sabot",
|
| 509 |
+
"n03062245": "cocktail shaker",
|
| 510 |
+
"n03063599": "coffee mug",
|
| 511 |
+
"n03063689": "coffeepot",
|
| 512 |
+
"n03065424": "coil, spiral, volute, whorl, helix",
|
| 513 |
+
"n03075370": "combination lock",
|
| 514 |
+
"n03085013": "computer keyboard, keypad",
|
| 515 |
+
"n03089624": "confectionery, confectionary, candy store",
|
| 516 |
+
"n03095699": "container ship, containership, container vessel",
|
| 517 |
+
"n03100240": "convertible",
|
| 518 |
+
"n03109150": "corkscrew, bottle screw",
|
| 519 |
+
"n03110669": "cornet, horn, trumpet, trump",
|
| 520 |
+
"n03124043": "cowboy boot",
|
| 521 |
+
"n03124170": "cowboy hat, ten-gallon hat",
|
| 522 |
+
"n03125729": "cradle",
|
| 523 |
+
"n03126707": "crane2",
|
| 524 |
+
"n03127747": "crash helmet",
|
| 525 |
+
"n03127925": "crate",
|
| 526 |
+
"n03131574": "crib, cot",
|
| 527 |
+
"n03133878": "Crock Pot",
|
| 528 |
+
"n03134739": "croquet ball",
|
| 529 |
+
"n03141823": "crutch",
|
| 530 |
+
"n03146219": "cuirass",
|
| 531 |
+
"n03160309": "dam, dike, dyke",
|
| 532 |
+
"n03179701": "desk",
|
| 533 |
+
"n03180011": "desktop computer",
|
| 534 |
+
"n03187595": "dial telephone, dial phone",
|
| 535 |
+
"n03188531": "diaper, nappy, napkin",
|
| 536 |
+
"n03196217": "digital clock",
|
| 537 |
+
"n03197337": "digital watch",
|
| 538 |
+
"n03201208": "dining table, board",
|
| 539 |
+
"n03207743": "dishrag, dishcloth",
|
| 540 |
+
"n03207941": "dishwasher, dish washer, dishwashing machine",
|
| 541 |
+
"n03208938": "disk brake, disc brake",
|
| 542 |
+
"n03216828": "dock, dockage, docking facility",
|
| 543 |
+
"n03218198": "dogsled, dog sled, dog sleigh",
|
| 544 |
+
"n03220513": "dome",
|
| 545 |
+
"n03223299": "doormat, welcome mat",
|
| 546 |
+
"n03240683": "drilling platform, offshore rig",
|
| 547 |
+
"n03249569": "drum, membranophone, tympan",
|
| 548 |
+
"n03250847": "drumstick",
|
| 549 |
+
"n03255030": "dumbbell",
|
| 550 |
+
"n03259280": "Dutch oven",
|
| 551 |
+
"n03271574": "electric fan, blower",
|
| 552 |
+
"n03272010": "electric guitar",
|
| 553 |
+
"n03272562": "electric locomotive",
|
| 554 |
+
"n03290653": "entertainment center",
|
| 555 |
+
"n03291819": "envelope",
|
| 556 |
+
"n03297495": "espresso maker",
|
| 557 |
+
"n03314780": "face powder",
|
| 558 |
+
"n03325584": "feather boa, boa",
|
| 559 |
+
"n03337140": "file, file cabinet, filing cabinet",
|
| 560 |
+
"n03344393": "fireboat",
|
| 561 |
+
"n03345487": "fire engine, fire truck",
|
| 562 |
+
"n03347037": "fire screen, fireguard",
|
| 563 |
+
"n03355925": "flagpole, flagstaff",
|
| 564 |
+
"n03372029": "flute, transverse flute",
|
| 565 |
+
"n03376595": "folding chair",
|
| 566 |
+
"n03379051": "football helmet",
|
| 567 |
+
"n03384352": "forklift",
|
| 568 |
+
"n03388043": "fountain",
|
| 569 |
+
"n03388183": "fountain pen",
|
| 570 |
+
"n03388549": "four-poster",
|
| 571 |
+
"n03393912": "freight car",
|
| 572 |
+
"n03394916": "French horn, horn",
|
| 573 |
+
"n03400231": "frying pan, frypan, skillet",
|
| 574 |
+
"n03404251": "fur coat",
|
| 575 |
+
"n03417042": "garbage truck, dustcart",
|
| 576 |
+
"n03424325": "gasmask, respirator, gas helmet",
|
| 577 |
+
"n03425413": "gas pump, gasoline pump, petrol pump, island dispenser",
|
| 578 |
+
"n03443371": "goblet",
|
| 579 |
+
"n03444034": "go-kart",
|
| 580 |
+
"n03445777": "golf ball",
|
| 581 |
+
"n03445924": "golfcart, golf cart",
|
| 582 |
+
"n03447447": "gondola",
|
| 583 |
+
"n03447721": "gong, tam-tam",
|
| 584 |
+
"n03450230": "gown",
|
| 585 |
+
"n03452741": "grand piano, grand",
|
| 586 |
+
"n03457902": "greenhouse, nursery, glasshouse",
|
| 587 |
+
"n03459775": "grille, radiator grille",
|
| 588 |
+
"n03461385": "grocery store, grocery, food market, market",
|
| 589 |
+
"n03467068": "guillotine",
|
| 590 |
+
"n03476684": "hair slide",
|
| 591 |
+
"n03476991": "hair spray",
|
| 592 |
+
"n03478589": "half track",
|
| 593 |
+
"n03481172": "hammer",
|
| 594 |
+
"n03482405": "hamper",
|
| 595 |
+
"n03483316": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
|
| 596 |
+
"n03485407": "hand-held computer, hand-held microcomputer",
|
| 597 |
+
"n03485794": "handkerchief, hankie, hanky, hankey",
|
| 598 |
+
"n03492542": "hard disc, hard disk, fixed disk",
|
| 599 |
+
"n03494278": "harmonica, mouth organ, harp, mouth harp",
|
| 600 |
+
"n03495258": "harp",
|
| 601 |
+
"n03496892": "harvester, reaper",
|
| 602 |
+
"n03498962": "hatchet",
|
| 603 |
+
"n03527444": "holster",
|
| 604 |
+
"n03529860": "home theater, home theatre",
|
| 605 |
+
"n03530642": "honeycomb",
|
| 606 |
+
"n03532672": "hook, claw",
|
| 607 |
+
"n03534580": "hoopskirt, crinoline",
|
| 608 |
+
"n03535780": "horizontal bar, high bar",
|
| 609 |
+
"n03538406": "horse cart, horse-cart",
|
| 610 |
+
"n03544143": "hourglass",
|
| 611 |
+
"n03584254": "iPod",
|
| 612 |
+
"n03584829": "iron, smoothing iron",
|
| 613 |
+
"n03590841": "jack-o'-lantern",
|
| 614 |
+
"n03594734": "jean, blue jean, denim",
|
| 615 |
+
"n03594945": "jeep, landrover",
|
| 616 |
+
"n03595614": "jersey, T-shirt, tee shirt",
|
| 617 |
+
"n03598930": "jigsaw puzzle",
|
| 618 |
+
"n03599486": "jinrikisha, ricksha, rickshaw",
|
| 619 |
+
"n03602883": "joystick",
|
| 620 |
+
"n03617480": "kimono",
|
| 621 |
+
"n03623198": "knee pad",
|
| 622 |
+
"n03627232": "knot",
|
| 623 |
+
"n03630383": "lab coat, laboratory coat",
|
| 624 |
+
"n03633091": "ladle",
|
| 625 |
+
"n03637318": "lampshade, lamp shade",
|
| 626 |
+
"n03642806": "laptop, laptop computer",
|
| 627 |
+
"n03649909": "lawn mower, mower",
|
| 628 |
+
"n03657121": "lens cap, lens cover",
|
| 629 |
+
"n03658185": "letter opener, paper knife, paperknife",
|
| 630 |
+
"n03661043": "library",
|
| 631 |
+
"n03662601": "lifeboat",
|
| 632 |
+
"n03666591": "lighter, light, igniter, ignitor",
|
| 633 |
+
"n03670208": "limousine, limo",
|
| 634 |
+
"n03673027": "liner, ocean liner",
|
| 635 |
+
"n03676483": "lipstick, lip rouge",
|
| 636 |
+
"n03680355": "Loafer",
|
| 637 |
+
"n03690938": "lotion",
|
| 638 |
+
"n03691459": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
|
| 639 |
+
"n03692522": "loupe, jeweler's loupe",
|
| 640 |
+
"n03697007": "lumbermill, sawmill",
|
| 641 |
+
"n03706229": "magnetic compass",
|
| 642 |
+
"n03709823": "mailbag, postbag",
|
| 643 |
+
"n03710193": "mailbox, letter box",
|
| 644 |
+
"n03710637": "maillot",
|
| 645 |
+
"n03710721": "maillot, tank suit",
|
| 646 |
+
"n03717622": "manhole cover",
|
| 647 |
+
"n03720891": "maraca",
|
| 648 |
+
"n03721384": "marimba, xylophone",
|
| 649 |
+
"n03724870": "mask",
|
| 650 |
+
"n03729826": "matchstick",
|
| 651 |
+
"n03733131": "maypole",
|
| 652 |
+
"n03733281": "maze, labyrinth",
|
| 653 |
+
"n03733805": "measuring cup",
|
| 654 |
+
"n03742115": "medicine chest, medicine cabinet",
|
| 655 |
+
"n03743016": "megalith, megalithic structure",
|
| 656 |
+
"n03759954": "microphone, mike",
|
| 657 |
+
"n03761084": "microwave, microwave oven",
|
| 658 |
+
"n03763968": "military uniform",
|
| 659 |
+
"n03764736": "milk can",
|
| 660 |
+
"n03769881": "minibus",
|
| 661 |
+
"n03770439": "miniskirt, mini",
|
| 662 |
+
"n03770679": "minivan",
|
| 663 |
+
"n03773504": "missile",
|
| 664 |
+
"n03775071": "mitten",
|
| 665 |
+
"n03775546": "mixing bowl",
|
| 666 |
+
"n03776460": "mobile home, manufactured home",
|
| 667 |
+
"n03777568": "Model T",
|
| 668 |
+
"n03777754": "modem",
|
| 669 |
+
"n03781244": "monastery",
|
| 670 |
+
"n03782006": "monitor",
|
| 671 |
+
"n03785016": "moped",
|
| 672 |
+
"n03786901": "mortar",
|
| 673 |
+
"n03787032": "mortarboard",
|
| 674 |
+
"n03788195": "mosque",
|
| 675 |
+
"n03788365": "mosquito net",
|
| 676 |
+
"n03791053": "motor scooter, scooter",
|
| 677 |
+
"n03792782": "mountain bike, all-terrain bike, off-roader",
|
| 678 |
+
"n03792972": "mountain tent",
|
| 679 |
+
"n03793489": "mouse, computer mouse",
|
| 680 |
+
"n03794056": "mousetrap",
|
| 681 |
+
"n03796401": "moving van",
|
| 682 |
+
"n03803284": "muzzle",
|
| 683 |
+
"n03804744": "nail",
|
| 684 |
+
"n03814639": "neck brace",
|
| 685 |
+
"n03814906": "necklace",
|
| 686 |
+
"n03825788": "nipple",
|
| 687 |
+
"n03832673": "notebook, notebook computer",
|
| 688 |
+
"n03837869": "obelisk",
|
| 689 |
+
"n03838899": "oboe, hautboy, hautbois",
|
| 690 |
+
"n03840681": "ocarina, sweet potato",
|
| 691 |
+
"n03841143": "odometer, hodometer, mileometer, milometer",
|
| 692 |
+
"n03843555": "oil filter",
|
| 693 |
+
"n03854065": "organ, pipe organ",
|
| 694 |
+
"n03857828": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
|
| 695 |
+
"n03866082": "overskirt",
|
| 696 |
+
"n03868242": "oxcart",
|
| 697 |
+
"n03868863": "oxygen mask",
|
| 698 |
+
"n03871628": "packet",
|
| 699 |
+
"n03873416": "paddle, boat paddle",
|
| 700 |
+
"n03874293": "paddlewheel, paddle wheel",
|
| 701 |
+
"n03874599": "padlock",
|
| 702 |
+
"n03876231": "paintbrush",
|
| 703 |
+
"n03877472": "pajama, pyjama, pj's, jammies",
|
| 704 |
+
"n03877845": "palace",
|
| 705 |
+
"n03884397": "panpipe, pandean pipe, syrinx",
|
| 706 |
+
"n03887697": "paper towel",
|
| 707 |
+
"n03888257": "parachute, chute",
|
| 708 |
+
"n03888605": "parallel bars, bars",
|
| 709 |
+
"n03891251": "park bench",
|
| 710 |
+
"n03891332": "parking meter",
|
| 711 |
+
"n03895866": "passenger car, coach, carriage",
|
| 712 |
+
"n03899768": "patio, terrace",
|
| 713 |
+
"n03902125": "pay-phone, pay-station",
|
| 714 |
+
"n03903868": "pedestal, plinth, footstall",
|
| 715 |
+
"n03908618": "pencil box, pencil case",
|
| 716 |
+
"n03908714": "pencil sharpener",
|
| 717 |
+
"n03916031": "perfume, essence",
|
| 718 |
+
"n03920288": "Petri dish",
|
| 719 |
+
"n03924679": "photocopier",
|
| 720 |
+
"n03929660": "pick, plectrum, plectron",
|
| 721 |
+
"n03929855": "pickelhaube",
|
| 722 |
+
"n03930313": "picket fence, paling",
|
| 723 |
+
"n03930630": "pickup, pickup truck",
|
| 724 |
+
"n03933933": "pier",
|
| 725 |
+
"n03935335": "piggy bank, penny bank",
|
| 726 |
+
"n03937543": "pill bottle",
|
| 727 |
+
"n03938244": "pillow",
|
| 728 |
+
"n03942813": "ping-pong ball",
|
| 729 |
+
"n03944341": "pinwheel",
|
| 730 |
+
"n03947888": "pirate, pirate ship",
|
| 731 |
+
"n03950228": "pitcher, ewer",
|
| 732 |
+
"n03954731": "plane, carpenter's plane, woodworking plane",
|
| 733 |
+
"n03956157": "planetarium",
|
| 734 |
+
"n03958227": "plastic bag",
|
| 735 |
+
"n03961711": "plate rack",
|
| 736 |
+
"n03967562": "plow, plough",
|
| 737 |
+
"n03970156": "plunger, plumber's helper",
|
| 738 |
+
"n03976467": "Polaroid camera, Polaroid Land camera",
|
| 739 |
+
"n03976657": "pole",
|
| 740 |
+
"n03977966": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
|
| 741 |
+
"n03980874": "poncho",
|
| 742 |
+
"n03982430": "pool table, billiard table, snooker table",
|
| 743 |
+
"n03983396": "pop bottle, soda bottle",
|
| 744 |
+
"n03991062": "pot, flowerpot",
|
| 745 |
+
"n03992509": "potter's wheel",
|
| 746 |
+
"n03995372": "power drill",
|
| 747 |
+
"n03998194": "prayer rug, prayer mat",
|
| 748 |
+
"n04004767": "printer",
|
| 749 |
+
"n04005630": "prison, prison house",
|
| 750 |
+
"n04008634": "projectile, missile",
|
| 751 |
+
"n04009552": "projector",
|
| 752 |
+
"n04019541": "puck, hockey puck",
|
| 753 |
+
"n04023962": "punching bag, punch bag, punching ball, punchball",
|
| 754 |
+
"n04026417": "purse",
|
| 755 |
+
"n04033901": "quill, quill pen",
|
| 756 |
+
"n04033995": "quilt, comforter, comfort, puff",
|
| 757 |
+
"n04037443": "racer, race car, racing car",
|
| 758 |
+
"n04039381": "racket, racquet",
|
| 759 |
+
"n04040759": "radiator",
|
| 760 |
+
"n04041544": "radio, wireless",
|
| 761 |
+
"n04044716": "radio telescope, radio reflector",
|
| 762 |
+
"n04049303": "rain barrel",
|
| 763 |
+
"n04065272": "recreational vehicle, RV, R.V.",
|
| 764 |
+
"n04067472": "reel",
|
| 765 |
+
"n04069434": "reflex camera",
|
| 766 |
+
"n04070727": "refrigerator, icebox",
|
| 767 |
+
"n04074963": "remote control, remote",
|
| 768 |
+
"n04081281": "restaurant, eating house, eating place, eatery",
|
| 769 |
+
"n04086273": "revolver, six-gun, six-shooter",
|
| 770 |
+
"n04090263": "rifle",
|
| 771 |
+
"n04099969": "rocking chair, rocker",
|
| 772 |
+
"n04111531": "rotisserie",
|
| 773 |
+
"n04116512": "rubber eraser, rubber, pencil eraser",
|
| 774 |
+
"n04118538": "rugby ball",
|
| 775 |
+
"n04118776": "rule, ruler",
|
| 776 |
+
"n04120489": "running shoe",
|
| 777 |
+
"n04125021": "safe",
|
| 778 |
+
"n04127249": "safety pin",
|
| 779 |
+
"n04131690": "saltshaker, salt shaker",
|
| 780 |
+
"n04133789": "sandal",
|
| 781 |
+
"n04136333": "sarong",
|
| 782 |
+
"n04141076": "sax, saxophone",
|
| 783 |
+
"n04141327": "scabbard",
|
| 784 |
+
"n04141975": "scale, weighing machine",
|
| 785 |
+
"n04146614": "school bus",
|
| 786 |
+
"n04147183": "schooner",
|
| 787 |
+
"n04149813": "scoreboard",
|
| 788 |
+
"n04152593": "screen, CRT screen",
|
| 789 |
+
"n04153751": "screw",
|
| 790 |
+
"n04154565": "screwdriver",
|
| 791 |
+
"n04162706": "seat belt, seatbelt",
|
| 792 |
+
"n04179913": "sewing machine",
|
| 793 |
+
"n04192698": "shield, buckler",
|
| 794 |
+
"n04200800": "shoe shop, shoe-shop, shoe store",
|
| 795 |
+
"n04201297": "shoji",
|
| 796 |
+
"n04204238": "shopping basket",
|
| 797 |
+
"n04204347": "shopping cart",
|
| 798 |
+
"n04208210": "shovel",
|
| 799 |
+
"n04209133": "shower cap",
|
| 800 |
+
"n04209239": "shower curtain",
|
| 801 |
+
"n04228054": "ski",
|
| 802 |
+
"n04229816": "ski mask",
|
| 803 |
+
"n04235860": "sleeping bag",
|
| 804 |
+
"n04238763": "slide rule, slipstick",
|
| 805 |
+
"n04239074": "sliding door",
|
| 806 |
+
"n04243546": "slot, one-armed bandit",
|
| 807 |
+
"n04251144": "snorkel",
|
| 808 |
+
"n04252077": "snowmobile",
|
| 809 |
+
"n04252225": "snowplow, snowplough",
|
| 810 |
+
"n04254120": "soap dispenser",
|
| 811 |
+
"n04254680": "soccer ball",
|
| 812 |
+
"n04254777": "sock",
|
| 813 |
+
"n04258138": "solar dish, solar collector, solar furnace",
|
| 814 |
+
"n04259630": "sombrero",
|
| 815 |
+
"n04263257": "soup bowl",
|
| 816 |
+
"n04264628": "space bar",
|
| 817 |
+
"n04265275": "space heater",
|
| 818 |
+
"n04266014": "space shuttle",
|
| 819 |
+
"n04270147": "spatula",
|
| 820 |
+
"n04273569": "speedboat",
|
| 821 |
+
"n04275548": "spider web, spider's web",
|
| 822 |
+
"n04277352": "spindle",
|
| 823 |
+
"n04285008": "sports car, sport car",
|
| 824 |
+
"n04286575": "spotlight, spot",
|
| 825 |
+
"n04296562": "stage",
|
| 826 |
+
"n04310018": "steam locomotive",
|
| 827 |
+
"n04311004": "steel arch bridge",
|
| 828 |
+
"n04311174": "steel drum",
|
| 829 |
+
"n04317175": "stethoscope",
|
| 830 |
+
"n04325704": "stole",
|
| 831 |
+
"n04326547": "stone wall",
|
| 832 |
+
"n04328186": "stopwatch, stop watch",
|
| 833 |
+
"n04330267": "stove",
|
| 834 |
+
"n04332243": "strainer",
|
| 835 |
+
"n04335435": "streetcar, tram, tramcar, trolley, trolley car",
|
| 836 |
+
"n04336792": "stretcher",
|
| 837 |
+
"n04344873": "studio couch, day bed",
|
| 838 |
+
"n04346328": "stupa, tope",
|
| 839 |
+
"n04347754": "submarine, pigboat, sub, U-boat",
|
| 840 |
+
"n04350905": "suit, suit of clothes",
|
| 841 |
+
"n04355338": "sundial",
|
| 842 |
+
"n04355933": "sunglass",
|
| 843 |
+
"n04356056": "sunglasses, dark glasses, shades",
|
| 844 |
+
"n04357314": "sunscreen, sunblock, sun blocker",
|
| 845 |
+
"n04366367": "suspension bridge",
|
| 846 |
+
"n04367480": "swab, swob, mop",
|
| 847 |
+
"n04370456": "sweatshirt",
|
| 848 |
+
"n04371430": "swimming trunks, bathing trunks",
|
| 849 |
+
"n04371774": "swing",
|
| 850 |
+
"n04372370": "switch, electric switch, electrical switch",
|
| 851 |
+
"n04376876": "syringe",
|
| 852 |
+
"n04380533": "table lamp",
|
| 853 |
+
"n04389033": "tank, army tank, armored combat vehicle, armoured combat vehicle",
|
| 854 |
+
"n04392985": "tape player",
|
| 855 |
+
"n04398044": "teapot",
|
| 856 |
+
"n04399382": "teddy, teddy bear",
|
| 857 |
+
"n04404412": "television, television system",
|
| 858 |
+
"n04409515": "tennis ball",
|
| 859 |
+
"n04417672": "thatch, thatched roof",
|
| 860 |
+
"n04418357": "theater curtain, theatre curtain",
|
| 861 |
+
"n04423845": "thimble",
|
| 862 |
+
"n04428191": "thresher, thrasher, threshing machine",
|
| 863 |
+
"n04429376": "throne",
|
| 864 |
+
"n04435653": "tile roof",
|
| 865 |
+
"n04442312": "toaster",
|
| 866 |
+
"n04443257": "tobacco shop, tobacconist shop, tobacconist",
|
| 867 |
+
"n04447861": "toilet seat",
|
| 868 |
+
"n04456115": "torch",
|
| 869 |
+
"n04458633": "totem pole",
|
| 870 |
+
"n04461696": "tow truck, tow car, wrecker",
|
| 871 |
+
"n04462240": "toyshop",
|
| 872 |
+
"n04465501": "tractor",
|
| 873 |
+
"n04467665": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
|
| 874 |
+
"n04476259": "tray",
|
| 875 |
+
"n04479046": "trench coat",
|
| 876 |
+
"n04482393": "tricycle, trike, velocipede",
|
| 877 |
+
"n04483307": "trimaran",
|
| 878 |
+
"n04485082": "tripod",
|
| 879 |
+
"n04486054": "triumphal arch",
|
| 880 |
+
"n04487081": "trolleybus, trolley coach, trackless trolley",
|
| 881 |
+
"n04487394": "trombone",
|
| 882 |
+
"n04493381": "tub, vat",
|
| 883 |
+
"n04501370": "turnstile",
|
| 884 |
+
"n04505470": "typewriter keyboard",
|
| 885 |
+
"n04507155": "umbrella",
|
| 886 |
+
"n04509417": "unicycle, monocycle",
|
| 887 |
+
"n04515003": "upright, upright piano",
|
| 888 |
+
"n04517823": "vacuum, vacuum cleaner",
|
| 889 |
+
"n04522168": "vase",
|
| 890 |
+
"n04523525": "vault",
|
| 891 |
+
"n04525038": "velvet",
|
| 892 |
+
"n04525305": "vending machine",
|
| 893 |
+
"n04532106": "vestment",
|
| 894 |
+
"n04532670": "viaduct",
|
| 895 |
+
"n04536866": "violin, fiddle",
|
| 896 |
+
"n04540053": "volleyball",
|
| 897 |
+
"n04542943": "waffle iron",
|
| 898 |
+
"n04548280": "wall clock",
|
| 899 |
+
"n04548362": "wallet, billfold, notecase, pocketbook",
|
| 900 |
+
"n04550184": "wardrobe, closet, press",
|
| 901 |
+
"n04552348": "warplane, military plane",
|
| 902 |
+
"n04553703": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
|
| 903 |
+
"n04554684": "washer, automatic washer, washing machine",
|
| 904 |
+
"n04557648": "water bottle",
|
| 905 |
+
"n04560804": "water jug",
|
| 906 |
+
"n04562935": "water tower",
|
| 907 |
+
"n04579145": "whiskey jug",
|
| 908 |
+
"n04579432": "whistle",
|
| 909 |
+
"n04584207": "wig",
|
| 910 |
+
"n04589890": "window screen",
|
| 911 |
+
"n04590129": "window shade",
|
| 912 |
+
"n04591157": "Windsor tie",
|
| 913 |
+
"n04591713": "wine bottle",
|
| 914 |
+
"n04592741": "wing",
|
| 915 |
+
"n04596742": "wok",
|
| 916 |
+
"n04597913": "wooden spoon",
|
| 917 |
+
"n04599235": "wool, woolen, woollen",
|
| 918 |
+
"n04604644": "worm fence, snake fence, snake-rail fence, Virginia fence",
|
| 919 |
+
"n04606251": "wreck",
|
| 920 |
+
"n04612504": "yawl",
|
| 921 |
+
"n04613696": "yurt",
|
| 922 |
+
"n06359193": "web site, website, internet site, site",
|
| 923 |
+
"n06596364": "comic book",
|
| 924 |
+
"n06785654": "crossword puzzle, crossword",
|
| 925 |
+
"n06794110": "street sign",
|
| 926 |
+
"n06874185": "traffic light, traffic signal, stoplight",
|
| 927 |
+
"n07248320": "book jacket, dust cover, dust jacket, dust wrapper",
|
| 928 |
+
"n07565083": "menu",
|
| 929 |
+
"n07579787": "plate",
|
| 930 |
+
"n07583066": "guacamole",
|
| 931 |
+
"n07584110": "consomme",
|
| 932 |
+
"n07590611": "hot pot, hotpot",
|
| 933 |
+
"n07613480": "trifle",
|
| 934 |
+
"n07614500": "ice cream, icecream",
|
| 935 |
+
"n07615774": "ice lolly, lolly, lollipop, popsicle",
|
| 936 |
+
"n07684084": "French loaf",
|
| 937 |
+
"n07693725": "bagel, beigel",
|
| 938 |
+
"n07695742": "pretzel",
|
| 939 |
+
"n07697313": "cheeseburger",
|
| 940 |
+
"n07697537": "hotdog, hot dog, red hot",
|
| 941 |
+
"n07711569": "mashed potato",
|
| 942 |
+
"n07714571": "head cabbage",
|
| 943 |
+
"n07714990": "broccoli",
|
| 944 |
+
"n07715103": "cauliflower",
|
| 945 |
+
"n07716358": "zucchini, courgette",
|
| 946 |
+
"n07716906": "spaghetti squash",
|
| 947 |
+
"n07717410": "acorn squash",
|
| 948 |
+
"n07717556": "butternut squash",
|
| 949 |
+
"n07718472": "cucumber, cuke",
|
| 950 |
+
"n07718747": "artichoke, globe artichoke",
|
| 951 |
+
"n07720875": "bell pepper",
|
| 952 |
+
"n07730033": "cardoon",
|
| 953 |
+
"n07734744": "mushroom",
|
| 954 |
+
"n07742313": "Granny Smith",
|
| 955 |
+
"n07745940": "strawberry",
|
| 956 |
+
"n07747607": "orange",
|
| 957 |
+
"n07749582": "lemon",
|
| 958 |
+
"n07753113": "fig",
|
| 959 |
+
"n07753275": "pineapple, ananas",
|
| 960 |
+
"n07753592": "banana",
|
| 961 |
+
"n07754684": "jackfruit, jak, jack",
|
| 962 |
+
"n07760859": "custard apple",
|
| 963 |
+
"n07768694": "pomegranate",
|
| 964 |
+
"n07802026": "hay",
|
| 965 |
+
"n07831146": "carbonara",
|
| 966 |
+
"n07836838": "chocolate sauce, chocolate syrup",
|
| 967 |
+
"n07860988": "dough",
|
| 968 |
+
"n07871810": "meat loaf, meatloaf",
|
| 969 |
+
"n07873807": "pizza, pizza pie",
|
| 970 |
+
"n07875152": "potpie",
|
| 971 |
+
"n07880968": "burrito",
|
| 972 |
+
"n07892512": "red wine",
|
| 973 |
+
"n07920052": "espresso",
|
| 974 |
+
"n07930864": "cup",
|
| 975 |
+
"n07932039": "eggnog",
|
| 976 |
+
"n09193705": "alp",
|
| 977 |
+
"n09229709": "bubble",
|
| 978 |
+
"n09246464": "cliff, drop, drop-off",
|
| 979 |
+
"n09256479": "coral reef",
|
| 980 |
+
"n09288635": "geyser",
|
| 981 |
+
"n09332890": "lakeside, lakeshore",
|
| 982 |
+
"n09399592": "promontory, headland, head, foreland",
|
| 983 |
+
"n09421951": "sandbar, sand bar",
|
| 984 |
+
"n09428293": "seashore, coast, seacoast, sea-coast",
|
| 985 |
+
"n09468604": "valley, vale",
|
| 986 |
+
"n09472597": "volcano",
|
| 987 |
+
"n09835506": "ballplayer, baseball player",
|
| 988 |
+
"n10148035": "groom, bridegroom",
|
| 989 |
+
"n10565667": "scuba diver",
|
| 990 |
+
"n11879895": "rapeseed",
|
| 991 |
+
"n11939491": "daisy",
|
| 992 |
+
"n12057211": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
|
| 993 |
+
"n12144580": "corn",
|
| 994 |
+
"n12267677": "acorn",
|
| 995 |
+
"n12620546": "hip, rose hip, rosehip",
|
| 996 |
+
"n12768682": "buckeye, horse chestnut, conker",
|
| 997 |
+
"n12985857": "coral fungus",
|
| 998 |
+
"n12998815": "agaric",
|
| 999 |
+
"n13037406": "gyromitra",
|
| 1000 |
+
"n13040303": "stinkhorn, carrion fungus",
|
| 1001 |
+
"n13044778": "earthstar",
|
| 1002 |
+
"n13052670": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
|
| 1003 |
+
"n13054560": "bolete",
|
| 1004 |
+
"n13133613": "ear, spike, capitulum",
|
| 1005 |
+
"n15075141": "toilet tissue, toilet paper, bathroom tissue",
|
| 1006 |
+
}
|
| 1007 |
+
)
|
tasks/image_classification/plotting.py
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import imageio
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import matplotlib as mpl
|
| 9 |
+
from matplotlib import patheffects
|
| 10 |
+
mpl.use('Agg')
|
| 11 |
+
import seaborn as sns
|
| 12 |
+
import numpy as np
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
+
sns.set_style('darkgrid')
|
| 15 |
+
|
| 16 |
+
from tqdm.auto import tqdm
|
| 17 |
+
from scipy import ndimage
|
| 18 |
+
import umap
|
| 19 |
+
from scipy.special import softmax
|
| 20 |
+
|
| 21 |
+
import subprocess as sp
|
| 22 |
+
import cv2 # Still potentially useful for color conversion checks if needed
|
| 23 |
+
import os
|
| 24 |
+
|
| 25 |
+
def save_frames_to_mp4(frames, output_filename, fps=15.0, gop_size=None, crf=23, preset='medium', pix_fmt='yuv420p'):
|
| 26 |
+
"""
|
| 27 |
+
Saves a list of NumPy array frames to an MP4 video file using FFmpeg via subprocess.
|
| 28 |
+
|
| 29 |
+
Includes fix for odd frame dimensions by padding to the nearest even number using -vf pad.
|
| 30 |
+
|
| 31 |
+
Requires FFmpeg to be installed and available in the system PATH.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
frames (list): A list of NumPy arrays representing the video frames.
|
| 35 |
+
Expected format: uint8, (height, width, 3) for BGR color
|
| 36 |
+
or (height, width) for grayscale. Should be consistent.
|
| 37 |
+
output_filename (str): The path and name for the output MP4 file.
|
| 38 |
+
fps (float, optional): Frames per second for the output video. Defaults to 15.0.
|
| 39 |
+
gop_size (int, optional): Group of Pictures (GOP) size. This determines the
|
| 40 |
+
maximum interval between keyframes. Lower values
|
| 41 |
+
mean more frequent keyframes (better seeking, larger file).
|
| 42 |
+
Defaults to int(fps) (approx 1 keyframe per second).
|
| 43 |
+
crf (int, optional): Constant Rate Factor for H.264 encoding. Lower values mean
|
| 44 |
+
better quality and larger files. Typical range: 18-28.
|
| 45 |
+
Defaults to 23.
|
| 46 |
+
preset (str, optional): FFmpeg encoding speed preset. Affects encoding time
|
| 47 |
+
and compression efficiency. Options include 'ultrafast',
|
| 48 |
+
'superfast', 'veryfast', 'faster', 'fast', 'medium',
|
| 49 |
+
'slow', 'slower', 'veryslow'. Defaults to 'medium'.
|
| 50 |
+
"""
|
| 51 |
+
if not frames:
|
| 52 |
+
print("Error: The 'frames' list is empty. No video to save.")
|
| 53 |
+
return
|
| 54 |
+
|
| 55 |
+
# --- Determine Parameters from First Frame ---
|
| 56 |
+
try:
|
| 57 |
+
first_frame = frames[0]
|
| 58 |
+
print(first_frame.shape)
|
| 59 |
+
if not isinstance(first_frame, np.ndarray):
|
| 60 |
+
print(f"Error: Frame 0 is not a NumPy array (type: {type(first_frame)}).")
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
frame_height, frame_width = first_frame.shape[:2]
|
| 64 |
+
frame_size_str = f"{frame_width}x{frame_height}"
|
| 65 |
+
|
| 66 |
+
# Determine input pixel format based on first frame's shape
|
| 67 |
+
if len(first_frame.shape) == 3 and first_frame.shape[2] == 3:
|
| 68 |
+
input_pixel_format = 'bgr24' # Assume OpenCV's default BGR uint8
|
| 69 |
+
expected_dims = 3
|
| 70 |
+
print(f"Info: Detected color frames (shape: {first_frame.shape}). Expecting BGR input.")
|
| 71 |
+
elif len(first_frame.shape) == 2:
|
| 72 |
+
input_pixel_format = 'gray'
|
| 73 |
+
expected_dims = 2
|
| 74 |
+
print(f"Info: Detected grayscale frames (shape: {first_frame.shape}).")
|
| 75 |
+
else:
|
| 76 |
+
print(f"Error: Unsupported frame shape {first_frame.shape}. Must be (h, w) or (h, w, 3).")
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
if first_frame.dtype != np.uint8:
|
| 80 |
+
print(f"Warning: First frame dtype is {first_frame.dtype}. Will attempt conversion to uint8.")
|
| 81 |
+
|
| 82 |
+
except IndexError:
|
| 83 |
+
print("Error: Could not access the first frame to determine dimensions.")
|
| 84 |
+
return
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Error processing first frame: {e}")
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
# --- Set GOP size default if not provided ---
|
| 90 |
+
if gop_size is None:
|
| 91 |
+
gop_size = int(fps)
|
| 92 |
+
print(f"Info: GOP size not specified, defaulting to {gop_size} (approx 1 keyframe/sec).")
|
| 93 |
+
|
| 94 |
+
# --- Construct FFmpeg Command ---
|
| 95 |
+
# ADDED -vf pad filter to ensure even dimensions for libx264/yuv420p
|
| 96 |
+
# It calculates the nearest even dimensions >= original dimensions
|
| 97 |
+
# Example: 1600x1351 -> 1600x1352
|
| 98 |
+
pad_filter = "pad=ceil(iw/2)*2:ceil(ih/2)*2"
|
| 99 |
+
|
| 100 |
+
command = [
|
| 101 |
+
'ffmpeg',
|
| 102 |
+
'-y',
|
| 103 |
+
'-f', 'rawvideo',
|
| 104 |
+
'-vcodec', 'rawvideo',
|
| 105 |
+
'-pix_fmt', input_pixel_format,
|
| 106 |
+
'-s', frame_size_str,
|
| 107 |
+
'-r', str(float(fps)),
|
| 108 |
+
'-i', '-',
|
| 109 |
+
'-vf', pad_filter, # <--- ADDED VIDEO FILTER HERE
|
| 110 |
+
'-c:v', 'libx264',
|
| 111 |
+
'-pix_fmt', pix_fmt,
|
| 112 |
+
'-preset', preset,
|
| 113 |
+
'-crf', str(crf),
|
| 114 |
+
'-g', str(gop_size),
|
| 115 |
+
'-movflags', '+faststart',
|
| 116 |
+
output_filename
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
print(f"\n--- Starting FFmpeg ---")
|
| 120 |
+
print(f"Output File: {output_filename}")
|
| 121 |
+
print(f"Parameters: FPS={fps}, Size={frame_size_str}, GOP={gop_size}, CRF={crf}, Preset={preset}")
|
| 122 |
+
print(f"Applying Filter: -vf {pad_filter} (Ensures even dimensions)")
|
| 123 |
+
# print(f"FFmpeg Command: {' '.join(command)}") # Uncomment for debugging
|
| 124 |
+
|
| 125 |
+
# --- Execute FFmpeg via Subprocess ---
|
| 126 |
+
try:
|
| 127 |
+
process = sp.Popen(command, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE)
|
| 128 |
+
|
| 129 |
+
print(f"\nWriting {len(frames)} frames to FFmpeg...")
|
| 130 |
+
progress_interval = max(1, len(frames) // 10) # Print progress roughly 10 times
|
| 131 |
+
|
| 132 |
+
for i, frame in enumerate(frames):
|
| 133 |
+
# Basic validation and conversion for each frame
|
| 134 |
+
if not isinstance(frame, np.ndarray):
|
| 135 |
+
print(f"Warning: Frame {i} is not a numpy array (type: {type(frame)}). Skipping.")
|
| 136 |
+
continue
|
| 137 |
+
if frame.shape[0] != frame_height or frame.shape[1] != frame_width:
|
| 138 |
+
print(f"Warning: Frame {i} has different dimensions {frame.shape[:2]}! Expected ({frame_height},{frame_width}). Skipping.")
|
| 139 |
+
continue
|
| 140 |
+
|
| 141 |
+
current_dims = len(frame.shape)
|
| 142 |
+
if current_dims != expected_dims:
|
| 143 |
+
print(f"Warning: Frame {i} has inconsistent dimensions ({current_dims}D vs expected {expected_dims}D). Skipping.")
|
| 144 |
+
continue
|
| 145 |
+
if expected_dims == 3 and frame.shape[2] != 3:
|
| 146 |
+
print(f"Warning: Frame {i} is color but doesn't have 3 channels ({frame.shape}). Skipping.")
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
if frame.dtype != np.uint8:
|
| 150 |
+
try:
|
| 151 |
+
frame = np.clip(frame, 0, 255).astype(np.uint8)
|
| 152 |
+
except Exception as clip_err:
|
| 153 |
+
print(f"Error clipping/converting frame {i} dtype: {clip_err}. Skipping.")
|
| 154 |
+
continue
|
| 155 |
+
|
| 156 |
+
# Write frame bytes to FFmpeg's stdin
|
| 157 |
+
try:
|
| 158 |
+
process.stdin.write(frame.tobytes())
|
| 159 |
+
except (OSError, BrokenPipeError) as pipe_err:
|
| 160 |
+
print(f"\nError writing frame {i} to FFmpeg stdin: {pipe_err}")
|
| 161 |
+
print("FFmpeg process likely terminated prematurely. Check FFmpeg errors below.")
|
| 162 |
+
try:
|
| 163 |
+
# Immediately try to read stderr if pipe breaks
|
| 164 |
+
stderr_output_on_error = process.stderr.read()
|
| 165 |
+
if stderr_output_on_error:
|
| 166 |
+
print("\n--- FFmpeg stderr output on error ---")
|
| 167 |
+
print(stderr_output_on_error.decode(errors='ignore'))
|
| 168 |
+
print("--- End FFmpeg stderr ---")
|
| 169 |
+
except Exception as read_err:
|
| 170 |
+
print(f"(Could not read stderr after pipe error: {read_err})")
|
| 171 |
+
return
|
| 172 |
+
except Exception as write_err:
|
| 173 |
+
print(f"Unexpected error writing frame {i}: {write_err}. Skipping.")
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
if (i + 1) % progress_interval == 0 or (i + 1) == len(frames):
|
| 177 |
+
print(f" Processed frame {i + 1}/{len(frames)}")
|
| 178 |
+
|
| 179 |
+
print("\nFinished writing frames. Closing FFmpeg stdin and waiting for completion...")
|
| 180 |
+
process.stdin.close()
|
| 181 |
+
stdout, stderr = process.communicate()
|
| 182 |
+
return_code = process.wait()
|
| 183 |
+
|
| 184 |
+
print("\n--- FFmpeg Final Status ---")
|
| 185 |
+
if return_code == 0:
|
| 186 |
+
print(f"FFmpeg process completed successfully.")
|
| 187 |
+
print(f"Video saved as: {output_filename}")
|
| 188 |
+
else:
|
| 189 |
+
print(f"FFmpeg process failed with return code {return_code}.")
|
| 190 |
+
print("--- FFmpeg Standard Error Output: ---")
|
| 191 |
+
print(stderr.decode(errors='replace')) # Print stderr captured by communicate()
|
| 192 |
+
print("--- End FFmpeg Output ---")
|
| 193 |
+
print("Review the FFmpeg error message above for details (e.g., dimension errors, parameter issues).")
|
| 194 |
+
|
| 195 |
+
except FileNotFoundError:
|
| 196 |
+
print("\n--- FATAL ERROR ---")
|
| 197 |
+
print("Error: 'ffmpeg' command not found.")
|
| 198 |
+
print("Please ensure FFmpeg is installed and its directory is included in your system's PATH environment variable.")
|
| 199 |
+
print("Download from: https://ffmpeg.org/")
|
| 200 |
+
print("-------------------")
|
| 201 |
+
except Exception as e:
|
| 202 |
+
print(f"\nAn unexpected error occurred during FFmpeg execution: {e}")
|
| 203 |
+
|
| 204 |
+
def find_island_centers(array_2d, threshold):
|
| 205 |
+
"""
|
| 206 |
+
Finds the center of mass of each island (connected component) in a 2D array.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
array_2d: A 2D numpy array of values.
|
| 210 |
+
threshold: The threshold to binarize the array.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
A list of tuples (y, x) representing the center of mass of each island.
|
| 214 |
+
"""
|
| 215 |
+
binary_image = array_2d > threshold
|
| 216 |
+
labeled_image, num_labels = ndimage.label(binary_image)
|
| 217 |
+
centers = []
|
| 218 |
+
areas = [] # Store the area of each island
|
| 219 |
+
for i in range(1, num_labels + 1):
|
| 220 |
+
island = (labeled_image == i)
|
| 221 |
+
total_mass = np.sum(array_2d[island])
|
| 222 |
+
if total_mass > 0:
|
| 223 |
+
y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
|
| 224 |
+
x_center = np.average(x_coords[island], weights=array_2d[island])
|
| 225 |
+
y_center = np.average(y_coords[island], weights=array_2d[island])
|
| 226 |
+
centers.append((round(y_center, 4), round(x_center, 4)))
|
| 227 |
+
areas.append(np.sum(island)) # Calculate area of the island
|
| 228 |
+
return centers, areas
|
| 229 |
+
|
| 230 |
+
def plot_neural_dynamics(post_activations_history, N_to_plot, save_location, axis_snap=False, N_per_row=5, which_neurons_mid=None, mid_colours=None, use_most_active_neurons=False):
|
| 231 |
+
assert N_to_plot%N_per_row==0, f'For nice visualisation, N_to_plot={N_to_plot} must be a multiple of N_per_row={N_per_row}'
|
| 232 |
+
assert post_activations_history.shape[-1] >= N_to_plot
|
| 233 |
+
figscale = 2
|
| 234 |
+
aspect_ratio = 3
|
| 235 |
+
mosaic = np.array([[f'{i}'] for i in range(N_to_plot)]).flatten().reshape(-1, N_per_row)
|
| 236 |
+
fig_synch, axes_synch = plt.subplot_mosaic(mosaic=mosaic, figsize=(figscale*mosaic.shape[1]*aspect_ratio*0.2, figscale*mosaic.shape[0]*0.2))
|
| 237 |
+
fig_mid, axes_mid = plt.subplot_mosaic(mosaic=mosaic, figsize=(figscale*mosaic.shape[1]*aspect_ratio*0.2, figscale*mosaic.shape[0]*0.2), dpi=200)
|
| 238 |
+
|
| 239 |
+
palette = sns.color_palette("husl", 8)
|
| 240 |
+
|
| 241 |
+
which_neurons_synch = np.arange(N_to_plot)
|
| 242 |
+
# which_neurons_mid = np.arange(N_to_plot, N_to_plot*2) if post_activations_history.shape[-1] >= 2*N_to_plot else np.random.choice(np.arange(post_activations_history.shape[-1]), size=N_to_plot, replace=True)
|
| 243 |
+
random_indices = np.random.choice(np.arange(post_activations_history.shape[-1]), size=N_to_plot, replace=post_activations_history.shape[-1] < N_to_plot)
|
| 244 |
+
if use_most_active_neurons:
|
| 245 |
+
metric = np.abs(np.fft.rfft(post_activations_history, axis=0))[3:].mean(0).std(0)
|
| 246 |
+
random_indices = np.argsort(metric)[-N_to_plot:]
|
| 247 |
+
np.random.shuffle(random_indices)
|
| 248 |
+
which_neurons_mid = which_neurons_mid if which_neurons_mid is not None else random_indices
|
| 249 |
+
|
| 250 |
+
if mid_colours is None:
|
| 251 |
+
mid_colours = [palette[np.random.randint(0, 8)] for ndx in range(N_to_plot)]
|
| 252 |
+
with tqdm(total=N_to_plot, initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 253 |
+
pbar_inner.set_description('Plotting neural dynamics')
|
| 254 |
+
for ndx in range(N_to_plot):
|
| 255 |
+
|
| 256 |
+
ax_s = axes_synch[f'{ndx}']
|
| 257 |
+
ax_m = axes_mid[f'{ndx}']
|
| 258 |
+
|
| 259 |
+
traces_s = post_activations_history[:,:,which_neurons_synch[ndx]].T
|
| 260 |
+
traces_m = post_activations_history[:,:,which_neurons_mid[ndx]].T
|
| 261 |
+
c_s = palette[np.random.randint(0, 8)]
|
| 262 |
+
c_m = mid_colours[ndx]
|
| 263 |
+
|
| 264 |
+
for traces_s_here, traces_m_here in zip(traces_s, traces_m):
|
| 265 |
+
ax_s.plot(np.arange(len(traces_s_here)), traces_s_here, linestyle='-', color=c_s, alpha=0.05, linewidth=0.6)
|
| 266 |
+
ax_m.plot(np.arange(len(traces_m_here)), traces_m_here, linestyle='-', color=c_m, alpha=0.05, linewidth=0.6)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color='white', alpha=1, linewidth=2.5)
|
| 270 |
+
ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color=c_s, alpha=1, linewidth=1.3)
|
| 271 |
+
ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color='black', alpha=1, linewidth=0.3)
|
| 272 |
+
ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color='white', alpha=1, linewidth=2.5)
|
| 273 |
+
ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color=c_m, alpha=1, linewidth=1.3)
|
| 274 |
+
ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color='black', alpha=1, linewidth=0.3)
|
| 275 |
+
if axis_snap and np.all(np.isfinite(traces_s[0])):
|
| 276 |
+
ax_s.set_ylim([np.min(traces_s[0])-np.ptp(traces_s[0])*0.05, np.max(traces_s[0])+np.ptp(traces_s[0])*0.05])
|
| 277 |
+
ax_m.set_ylim([np.min(traces_m[0])-np.ptp(traces_m[0])*0.05, np.max(traces_m[0])+np.ptp(traces_m[0])*0.05])
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
ax_s.grid(False)
|
| 281 |
+
ax_m.grid(False)
|
| 282 |
+
ax_s.set_xlim([0, len(traces_s[0])-1])
|
| 283 |
+
ax_m.set_xlim([0, len(traces_m[0])-1])
|
| 284 |
+
|
| 285 |
+
ax_s.set_xticklabels([])
|
| 286 |
+
ax_s.set_yticklabels([])
|
| 287 |
+
|
| 288 |
+
ax_m.set_xticklabels([])
|
| 289 |
+
ax_m.set_yticklabels([])
|
| 290 |
+
pbar_inner.update(1)
|
| 291 |
+
fig_synch.tight_layout(pad=0.05)
|
| 292 |
+
fig_mid.tight_layout(pad=0.05)
|
| 293 |
+
if save_location is not None:
|
| 294 |
+
fig_synch.savefig(f'{save_location}/neural_dynamics_synch.pdf', dpi=200)
|
| 295 |
+
fig_synch.savefig(f'{save_location}/neural_dynamics_synch.png', dpi=200)
|
| 296 |
+
fig_mid.savefig(f'{save_location}/neural_dynamics_other.pdf', dpi=200)
|
| 297 |
+
fig_mid.savefig(f'{save_location}/neural_dynamics_other.png', dpi=200)
|
| 298 |
+
plt.close(fig_synch)
|
| 299 |
+
plt.close(fig_mid)
|
| 300 |
+
return fig_synch, fig_mid, which_neurons_mid, mid_colours
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def make_classification_gif(image, target, predictions, certainties, post_activations, attention_tracking, class_labels, save_location):
|
| 305 |
+
cmap_viridis = sns.color_palette('viridis', as_cmap=True)
|
| 306 |
+
cmap_spectral = sns.color_palette("Spectral", as_cmap=True)
|
| 307 |
+
figscale = 2
|
| 308 |
+
with tqdm(total=post_activations.shape[0]+1, initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 309 |
+
pbar_inner.set_description('Computing UMAP')
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
low = np.percentile(post_activations, 1, axis=0, keepdims=True)
|
| 313 |
+
high = np.percentile(post_activations, 99, axis=0, keepdims=True)
|
| 314 |
+
post_activations_normed = np.clip((post_activations - low)/(high - low), 0, 1)
|
| 315 |
+
metric = 'cosine'
|
| 316 |
+
reducer = umap.UMAP(n_components=2,
|
| 317 |
+
n_neighbors=100,
|
| 318 |
+
min_dist=3,
|
| 319 |
+
spread=3.0,
|
| 320 |
+
metric=metric,
|
| 321 |
+
random_state=None,
|
| 322 |
+
# low_memory=True,
|
| 323 |
+
) if post_activations.shape[-1] > 2048 else umap.UMAP(n_components=2,
|
| 324 |
+
n_neighbors=20,
|
| 325 |
+
min_dist=1,
|
| 326 |
+
spread=1.0,
|
| 327 |
+
metric=metric,
|
| 328 |
+
random_state=None,
|
| 329 |
+
# low_memory=True,
|
| 330 |
+
)
|
| 331 |
+
positions = reducer.fit_transform(post_activations_normed.T)
|
| 332 |
+
|
| 333 |
+
x_umap = positions[:, 0]
|
| 334 |
+
y_umap = positions[:, 1]
|
| 335 |
+
|
| 336 |
+
pbar_inner.update(1)
|
| 337 |
+
pbar_inner.set_description('Iterating through to build frames')
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
frames = []
|
| 342 |
+
route_steps = {}
|
| 343 |
+
route_colours = []
|
| 344 |
+
|
| 345 |
+
n_steps = len(post_activations)
|
| 346 |
+
n_heads = attention_tracking.shape[1]
|
| 347 |
+
step_linspace = np.linspace(0, 1, n_steps)
|
| 348 |
+
|
| 349 |
+
for stepi in np.arange(0, n_steps, 1):
|
| 350 |
+
pbar_inner.set_description('Making frames for gif')
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
attention_now = attention_tracking[max(0, stepi-5):stepi+1].mean(0) # Make it smooth for pretty
|
| 354 |
+
# attention_now[:,0,0] = 0 # Corners can be weird looking
|
| 355 |
+
# attention_now[:,0,-1] = 0
|
| 356 |
+
# attention_now[:,-1,0] = 0
|
| 357 |
+
# attention_now[:,-1,-1] = 0
|
| 358 |
+
# attention_now = (attention_tracking[:stepi+1, 0] * decay).sum(0)/(decay.sum(0))
|
| 359 |
+
certainties_now = certainties[1, :stepi+1]
|
| 360 |
+
attention_interp = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), image.shape[:2], mode='bilinear')[0]
|
| 361 |
+
attention_interp = (attention_interp.flatten(1) - attention_interp.flatten(1).min(-1, keepdim=True)[0])/(attention_interp.flatten(1).max(-1, keepdim=True)[0] - attention_interp.flatten(1).min(-1, keepdim=True)[0])
|
| 362 |
+
attention_interp = attention_interp.reshape(n_heads, image.shape[0], image.shape[1])
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
colour = list(cmap_spectral(step_linspace[stepi]))
|
| 366 |
+
route_colours.append(colour)
|
| 367 |
+
for headi in range(min(8, n_heads)):
|
| 368 |
+
com_attn = np.copy(attention_interp[headi])
|
| 369 |
+
com_attn[com_attn < np.percentile(com_attn, 97)] = 0.0
|
| 370 |
+
if headi not in route_steps:
|
| 371 |
+
A = attention_interp[headi].detach().cpu().numpy()
|
| 372 |
+
centres, areas = find_island_centers(A, threshold=0.7)
|
| 373 |
+
route_steps[headi] = [centres[np.argmax(areas)]]
|
| 374 |
+
else:
|
| 375 |
+
A = attention_interp[headi].detach().cpu().numpy()
|
| 376 |
+
centres, areas = find_island_centers(A, threshold=0.7)
|
| 377 |
+
route_steps[headi] = route_steps[headi] + [centres[np.argmax(areas)]]
|
| 378 |
+
|
| 379 |
+
mosaic = [['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay'],
|
| 380 |
+
['head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],
|
| 381 |
+
['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay'],
|
| 382 |
+
['head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],
|
| 383 |
+
['probabilities', 'probabilities','certainty', 'certainty'],
|
| 384 |
+
['umap', 'umap', 'umap', 'umap'],
|
| 385 |
+
['umap', 'umap', 'umap', 'umap'],
|
| 386 |
+
['umap', 'umap', 'umap', 'umap'],
|
| 387 |
+
|
| 388 |
+
]
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
img_aspect = image.shape[0]/image.shape[1]
|
| 392 |
+
# print(img_aspect)
|
| 393 |
+
aspect_ratio = (4*figscale, 8*figscale*img_aspect)
|
| 394 |
+
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 395 |
+
for ax in axes.values():
|
| 396 |
+
ax.axis('off')
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
axes['certainty'].plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1, label='1-(normalised entropy)')
|
| 400 |
+
for ii, (x, y) in enumerate(zip(np.arange(len(certainties_now)), certainties_now)):
|
| 401 |
+
is_correct = predictions[:, ii].argmax(-1)==target
|
| 402 |
+
if is_correct: axes['certainty'].axvspan(ii, ii + 1, facecolor='limegreen', edgecolor=None, lw=0, alpha=0.3)
|
| 403 |
+
else:
|
| 404 |
+
axes['certainty'].axvspan(ii, ii + 1, facecolor='orchid', edgecolor=None, lw=0, alpha=0.3)
|
| 405 |
+
axes['certainty'].plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
|
| 406 |
+
axes['certainty'].axis('off')
|
| 407 |
+
axes['certainty'].set_ylim([-0.05, 1.05])
|
| 408 |
+
axes['certainty'].set_xlim([0, certainties.shape[-1]+1])
|
| 409 |
+
|
| 410 |
+
ps = torch.softmax(torch.from_numpy(predictions[:, stepi]), -1)
|
| 411 |
+
k = 15 if len(class_labels) > 15 else len(class_labels)
|
| 412 |
+
topk = torch.topk (ps, k, dim = 0, largest=True).indices.detach().cpu().numpy()
|
| 413 |
+
top_classes = np.array(class_labels)[topk]
|
| 414 |
+
true_class = target
|
| 415 |
+
colours = [('b' if ci != true_class else 'g') for ci in topk]
|
| 416 |
+
bar_heights = ps[topk].detach().cpu().numpy()
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
axes['probabilities'].bar(np.arange(len(bar_heights))[::-1], bar_heights, color=np.array(colours), alpha=1)
|
| 420 |
+
axes['probabilities'].set_ylim([0, 1])
|
| 421 |
+
axes['probabilities'].axis('off')
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
for i, (name) in enumerate(top_classes):
|
| 425 |
+
prob = ps[i]
|
| 426 |
+
is_correct = name==class_labels[true_class]
|
| 427 |
+
fg_color = 'darkgreen' if is_correct else 'crimson'
|
| 428 |
+
text_str = f'{name[:40]}'
|
| 429 |
+
axes['probabilities'].text(
|
| 430 |
+
0.05,
|
| 431 |
+
0.95 - i * 0.055, # Adjust vertical position for each line
|
| 432 |
+
text_str,
|
| 433 |
+
transform=axes['probabilities'].transAxes,
|
| 434 |
+
verticalalignment='top',
|
| 435 |
+
fontsize=8, # Increased font size
|
| 436 |
+
color=fg_color,
|
| 437 |
+
alpha=0.5,
|
| 438 |
+
path_effects=[
|
| 439 |
+
patheffects.Stroke(linewidth=3, foreground='aliceblue'),
|
| 440 |
+
patheffects.Normal()
|
| 441 |
+
])
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
attention_now = attention_tracking[max(0, stepi-5):stepi+1].mean(0) # Make it smooth for pretty
|
| 446 |
+
# attention_now = (attention_tracking[:stepi+1, 0] * decay).sum(0)/(decay.sum(0))
|
| 447 |
+
certainties_now = certainties[1, :stepi+1]
|
| 448 |
+
attention_interp = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), image.shape[:2], mode='nearest')[0]
|
| 449 |
+
attention_interp = (attention_interp.flatten(1) - attention_interp.flatten(1).min(-1, keepdim=True)[0])/(attention_interp.flatten(1).max(-1, keepdim=True)[0] - attention_interp.flatten(1).min(-1, keepdim=True)[0])
|
| 450 |
+
attention_interp = attention_interp.reshape(n_heads, image.shape[0], image.shape[1])
|
| 451 |
+
|
| 452 |
+
for hi in range(min(8, n_heads)):
|
| 453 |
+
ax = axes[f'head_{hi}']
|
| 454 |
+
img_to_plot = cmap_viridis(attention_interp[hi].detach().cpu().numpy())
|
| 455 |
+
ax.imshow(img_to_plot)
|
| 456 |
+
|
| 457 |
+
ax_overlay = axes[f'head_{hi}_overlay']
|
| 458 |
+
|
| 459 |
+
these_route_steps = route_steps[hi]
|
| 460 |
+
y_coords, x_coords = zip(*these_route_steps)
|
| 461 |
+
y_coords = image.shape[-2] - np.array(list(y_coords))-1
|
| 462 |
+
|
| 463 |
+
ax_overlay.imshow(np.flip(image, axis=0), origin='lower')
|
| 464 |
+
# ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
|
| 465 |
+
arrow_scale = 1.5 if image.shape[0] > 32 else 0.8
|
| 466 |
+
for i in range(len(these_route_steps)-1):
|
| 467 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 468 |
+
dy = y_coords[i+1] - y_coords[i]
|
| 469 |
+
|
| 470 |
+
ax_overlay.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=1.6*arrow_scale*1.3, head_width=1.9*arrow_scale*1.3, head_length=1.4*arrow_scale*1.45, fc='white', ec='white', length_includes_head = True, alpha=1)
|
| 471 |
+
ax_overlay.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=1.6*arrow_scale, head_width=1.9*arrow_scale, head_length=1.4*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
|
| 472 |
+
|
| 473 |
+
ax_overlay.set_xlim([0,image.shape[1]-1])
|
| 474 |
+
ax_overlay.set_ylim([0,image.shape[0]-1])
|
| 475 |
+
ax_overlay.axis('off')
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
z = post_activations_normed[stepi]
|
| 479 |
+
|
| 480 |
+
axes['umap'].scatter(x_umap, y_umap, s=30, c=cmap_spectral(z))
|
| 481 |
+
|
| 482 |
+
fig.tight_layout(pad=0.1)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
canvas = fig.canvas
|
| 487 |
+
canvas.draw()
|
| 488 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 489 |
+
image_numpy = (image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3])
|
| 490 |
+
frames.append(image_numpy)
|
| 491 |
+
plt.close(fig)
|
| 492 |
+
pbar_inner.update(1)
|
| 493 |
+
pbar_inner.set_description('Saving gif')
|
| 494 |
+
imageio.mimsave(save_location, frames, fps=15, loop=100)
|
tasks/image_classification/scripts/train_cifar10.sh
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python -m tasks.image_classification.train \
|
| 2 |
+
--log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \
|
| 3 |
+
--model ctm
|
| 4 |
+
--dataset cifar10 \
|
| 5 |
+
--d_model 256 \
|
| 6 |
+
--d_input 64 \
|
| 7 |
+
--synapse_depth 5 \
|
| 8 |
+
--heads 16 \
|
| 9 |
+
--n_synch_out 256 \
|
| 10 |
+
--n_synch_action 512 \
|
| 11 |
+
--n_random_pairing_self 0 \
|
| 12 |
+
--neuron_select_type random-pairing \
|
| 13 |
+
--iterations 50 \
|
| 14 |
+
--memory_length 15 \
|
| 15 |
+
--deep_memory \
|
| 16 |
+
--memory_hidden_dims 64 \
|
| 17 |
+
--dropout 0.0 \
|
| 18 |
+
--dropout_nlm 0 \
|
| 19 |
+
--no-do_normalisation \
|
| 20 |
+
--positional_embedding_type none \
|
| 21 |
+
--backbone_type resnet18-1 \
|
| 22 |
+
--training_iterations 600001 \
|
| 23 |
+
--warmup_steps 1000 \
|
| 24 |
+
--use_scheduler \
|
| 25 |
+
--scheduler_type cosine \
|
| 26 |
+
--weight_decay 0.0001 \
|
| 27 |
+
--save_every 1000 \
|
| 28 |
+
--track_every 2000 \
|
| 29 |
+
--n_test_batches 50 \
|
| 30 |
+
--num_workers_train 8 \
|
| 31 |
+
--batch_size 512 \
|
| 32 |
+
--batch_size_test 512 \
|
| 33 |
+
--lr 1e-4 \
|
| 34 |
+
--device 0 \
|
| 35 |
+
--seed 1
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
python -m tasks.image_classification.train \
|
| 39 |
+
--log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \
|
| 40 |
+
--model ctm
|
| 41 |
+
--dataset cifar10 \
|
| 42 |
+
--d_model 256 \
|
| 43 |
+
--d_input 64 \
|
| 44 |
+
--synapse_depth 5 \
|
| 45 |
+
--heads 16 \
|
| 46 |
+
--n_synch_out 256 \
|
| 47 |
+
--n_synch_action 512 \
|
| 48 |
+
--n_random_pairing_self 0 \
|
| 49 |
+
--neuron_select_type random-pairing \
|
| 50 |
+
--iterations 50 \
|
| 51 |
+
--memory_length 15 \
|
| 52 |
+
--deep_memory \
|
| 53 |
+
--memory_hidden_dims 64 \
|
| 54 |
+
--dropout 0.0 \
|
| 55 |
+
--dropout_nlm 0 \
|
| 56 |
+
--no-do_normalisation \
|
| 57 |
+
--positional_embedding_type none \
|
| 58 |
+
--backbone_type resnet18-1 \
|
| 59 |
+
--training_iterations 600001 \
|
| 60 |
+
--warmup_steps 1000 \
|
| 61 |
+
--use_scheduler \
|
| 62 |
+
--scheduler_type cosine \
|
| 63 |
+
--weight_decay 0.0001 \
|
| 64 |
+
--save_every 1000 \
|
| 65 |
+
--track_every 2000 \
|
| 66 |
+
--n_test_batches 50 \
|
| 67 |
+
--num_workers_train 8 \
|
| 68 |
+
--batch_size 512 \
|
| 69 |
+
--batch_size_test 512 \
|
| 70 |
+
--lr 1e-4 \
|
| 71 |
+
--device 0 \
|
| 72 |
+
--seed 2
|
| 73 |
+
|
| 74 |
+
python -m tasks.image_classification.train \
|
| 75 |
+
--log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \
|
| 76 |
+
--model ctm
|
| 77 |
+
--dataset cifar10 \
|
| 78 |
+
--d_model 256 \
|
| 79 |
+
--d_input 64 \
|
| 80 |
+
--synapse_depth 5 \
|
| 81 |
+
--heads 16 \
|
| 82 |
+
--n_synch_out 256 \
|
| 83 |
+
--n_synch_action 512 \
|
| 84 |
+
--n_random_pairing_self 0 \
|
| 85 |
+
--neuron_select_type random-pairing \
|
| 86 |
+
--iterations 50 \
|
| 87 |
+
--memory_length 15 \
|
| 88 |
+
--deep_memory \
|
| 89 |
+
--memory_hidden_dims 64 \
|
| 90 |
+
--dropout 0.0 \
|
| 91 |
+
--dropout_nlm 0 \
|
| 92 |
+
--no-do_normalisation \
|
| 93 |
+
--positional_embedding_type none \
|
| 94 |
+
--backbone_type resnet18-1 \
|
| 95 |
+
--training_iterations 600001 \
|
| 96 |
+
--warmup_steps 1000 \
|
| 97 |
+
--use_scheduler \
|
| 98 |
+
--scheduler_type cosine \
|
| 99 |
+
--weight_decay 0.0001 \
|
| 100 |
+
--save_every 1000 \
|
| 101 |
+
--track_every 2000 \
|
| 102 |
+
--n_test_batches 50 \
|
| 103 |
+
--num_workers_train 8 \
|
| 104 |
+
--batch_size 512 \
|
| 105 |
+
--batch_size_test 512 \
|
| 106 |
+
--lr 1e-4 \
|
| 107 |
+
--device 0 \
|
| 108 |
+
--seed 42
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
python -m tasks.image_classification.train \
|
| 116 |
+
--log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \
|
| 117 |
+
--dataset cifar10 \
|
| 118 |
+
--model lstm \
|
| 119 |
+
--num_layers 2 \
|
| 120 |
+
--d_model 256 \
|
| 121 |
+
--d_input 64 \
|
| 122 |
+
--heads 16 \
|
| 123 |
+
--iterations 50 \
|
| 124 |
+
--dropout 0.0 \
|
| 125 |
+
--positional_embedding_type none \
|
| 126 |
+
--backbone_type resnet18-1 \
|
| 127 |
+
--training_iterations 600001 \
|
| 128 |
+
--warmup_steps 2000 \
|
| 129 |
+
--use_scheduler \
|
| 130 |
+
--scheduler_type cosine \
|
| 131 |
+
--weight_decay 0.0001 \
|
| 132 |
+
--save_every 1000 \
|
| 133 |
+
--track_every 2000 \
|
| 134 |
+
--n_test_batches 50 \
|
| 135 |
+
--reload \
|
| 136 |
+
--num_workers_train 8 \
|
| 137 |
+
--batch_size 512 \
|
| 138 |
+
--batch_size_test 512 \
|
| 139 |
+
--lr 1e-4 \
|
| 140 |
+
--device 0 \
|
| 141 |
+
--seed 1 \
|
| 142 |
+
--no-reload
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
python -m tasks.image_classification.train \
|
| 146 |
+
--log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \
|
| 147 |
+
--dataset cifar10 \
|
| 148 |
+
--model lstm \
|
| 149 |
+
--num_layers 2 \
|
| 150 |
+
--d_model 256 \
|
| 151 |
+
--d_input 64 \
|
| 152 |
+
--heads 16 \
|
| 153 |
+
--iterations 50 \
|
| 154 |
+
--dropout 0.0 \
|
| 155 |
+
--positional_embedding_type none \
|
| 156 |
+
--backbone_type resnet18-1 \
|
| 157 |
+
--training_iterations 600001 \
|
| 158 |
+
--warmup_steps 2000 \
|
| 159 |
+
--use_scheduler \
|
| 160 |
+
--scheduler_type cosine \
|
| 161 |
+
--weight_decay 0.0001 \
|
| 162 |
+
--save_every 1000 \
|
| 163 |
+
--track_every 2000 \
|
| 164 |
+
--n_test_batches 50 \
|
| 165 |
+
--reload \
|
| 166 |
+
--num_workers_train 8 \
|
| 167 |
+
--batch_size 512 \
|
| 168 |
+
--batch_size_test 512 \
|
| 169 |
+
--lr 1e-4 \
|
| 170 |
+
--device 0 \
|
| 171 |
+
--seed 2 \
|
| 172 |
+
--no-reload
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
python -m tasks.image_classification.train \
|
| 176 |
+
--log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \
|
| 177 |
+
--dataset cifar10 \
|
| 178 |
+
--model lstm \
|
| 179 |
+
--num_layers 2 \
|
| 180 |
+
--d_model 256 \
|
| 181 |
+
--d_input 64 \
|
| 182 |
+
--heads 16 \
|
| 183 |
+
--iterations 50 \
|
| 184 |
+
--dropout 0.0 \
|
| 185 |
+
--positional_embedding_type none \
|
| 186 |
+
--backbone_type resnet18-1 \
|
| 187 |
+
--training_iterations 600001 \
|
| 188 |
+
--warmup_steps 2000 \
|
| 189 |
+
--use_scheduler \
|
| 190 |
+
--scheduler_type cosine \
|
| 191 |
+
--weight_decay 0.0001 \
|
| 192 |
+
--save_every 1000 \
|
| 193 |
+
--track_every 2000 \
|
| 194 |
+
--n_test_batches 50 \
|
| 195 |
+
--reload \
|
| 196 |
+
--num_workers_train 8 \
|
| 197 |
+
--batch_size 512 \
|
| 198 |
+
--batch_size_test 512 \
|
| 199 |
+
--lr 1e-4 \
|
| 200 |
+
--device 0 \
|
| 201 |
+
--seed 42 \
|
| 202 |
+
--no-reload
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
python -m tasks.image_classification.train \
|
| 209 |
+
--log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=1 \
|
| 210 |
+
--dataset cifar10 \
|
| 211 |
+
--model ff \
|
| 212 |
+
--d_model 256 \
|
| 213 |
+
--memory_hidden_dims 64 \
|
| 214 |
+
--dropout 0.0 \
|
| 215 |
+
--dropout_nlm 0 \
|
| 216 |
+
--backbone_type resnet18-1 \
|
| 217 |
+
--training_iterations 600001 \
|
| 218 |
+
--warmup_steps 1000 \
|
| 219 |
+
--use_scheduler \
|
| 220 |
+
--scheduler_type cosine \
|
| 221 |
+
--weight_decay 0.0001 \
|
| 222 |
+
--save_every 1000 \
|
| 223 |
+
--track_every 2000 \
|
| 224 |
+
--n_test_batches 50 \
|
| 225 |
+
--num_workers_train 8 \
|
| 226 |
+
--batch_size 512 \
|
| 227 |
+
--batch_size_test 512 \
|
| 228 |
+
--lr 1e-4 \
|
| 229 |
+
--device 0 \
|
| 230 |
+
--seed 1
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
python -m tasks.image_classification.train \
|
| 234 |
+
--log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=2 \
|
| 235 |
+
--dataset cifar10 \
|
| 236 |
+
--model ff \
|
| 237 |
+
--d_model 256 \
|
| 238 |
+
--memory_hidden_dims 64 \
|
| 239 |
+
--dropout 0.0 \
|
| 240 |
+
--dropout_nlm 0 \
|
| 241 |
+
--backbone_type resnet18-1 \
|
| 242 |
+
--training_iterations 600001 \
|
| 243 |
+
--warmup_steps 1000 \
|
| 244 |
+
--use_scheduler \
|
| 245 |
+
--scheduler_type cosine \
|
| 246 |
+
--weight_decay 0.0001 \
|
| 247 |
+
--save_every 1000 \
|
| 248 |
+
--track_every 2000 \
|
| 249 |
+
--n_test_batches 50 \
|
| 250 |
+
--num_workers_train 8 \
|
| 251 |
+
--batch_size 512 \
|
| 252 |
+
--batch_size_test 512 \
|
| 253 |
+
--lr 1e-4 \
|
| 254 |
+
--device 0 \
|
| 255 |
+
--seed 2
|
| 256 |
+
|
| 257 |
+
python -m tasks.image_classification.train \
|
| 258 |
+
--log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=42 \
|
| 259 |
+
--dataset cifar10 \
|
| 260 |
+
--model ff \
|
| 261 |
+
--d_model 256 \
|
| 262 |
+
--memory_hidden_dims 64 \
|
| 263 |
+
--dropout 0.0 \
|
| 264 |
+
--dropout_nlm 0 \
|
| 265 |
+
--backbone_type resnet18-1 \
|
| 266 |
+
--training_iterations 600001 \
|
| 267 |
+
--warmup_steps 1000 \
|
| 268 |
+
--use_scheduler \
|
| 269 |
+
--scheduler_type cosine \
|
| 270 |
+
--weight_decay 0.0001 \
|
| 271 |
+
--save_every 1000 \
|
| 272 |
+
--track_every 2000 \
|
| 273 |
+
--n_test_batches 50 \
|
| 274 |
+
--num_workers_train 8 \
|
| 275 |
+
--batch_size 512 \
|
| 276 |
+
--batch_size_test 512 \
|
| 277 |
+
--lr 1e-4 \
|
| 278 |
+
--device 0 \
|
| 279 |
+
--seed 42
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
|
tasks/image_classification/scripts/train_imagenet.sh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed \
|
| 2 |
+
--log_dir logs/imagenet/d=4096--i=1024--heads=16--sd=8--nlm=64--synch=8192-2048-32-h=64-random-pairing--iters=50x25--backbone=152x4 \
|
| 3 |
+
--model ctm \
|
| 4 |
+
--dataset imagenet \
|
| 5 |
+
--d_model 4096 \
|
| 6 |
+
--d_input 1024 \
|
| 7 |
+
--synapse_depth 8 \
|
| 8 |
+
--heads 16 \
|
| 9 |
+
--n_synch_out 8196 \
|
| 10 |
+
--n_synch_action 2048 \
|
| 11 |
+
--n_random_pairing_self 32 \
|
| 12 |
+
--neuron_select_type random-pairing \
|
| 13 |
+
--iterations 50 \
|
| 14 |
+
--memory_length 25 \
|
| 15 |
+
--deep_memory \
|
| 16 |
+
--memory_hidden_dims 64 \
|
| 17 |
+
--dropout 0.2 \
|
| 18 |
+
--dropout_nlm 0 \
|
| 19 |
+
--no-do_normalisation \
|
| 20 |
+
--positional_embedding_type none \
|
| 21 |
+
--backbone_type resnet152-4 \
|
| 22 |
+
--batch_size 64 \
|
| 23 |
+
--batch_size_test 64 \
|
| 24 |
+
--n_test_batches 200 \
|
| 25 |
+
--lr 5e-4 \
|
| 26 |
+
--gradient_clipping 20 \
|
| 27 |
+
--training_iterations 500001 \
|
| 28 |
+
--save_every 1000 \
|
| 29 |
+
--track_every 5000 \
|
| 30 |
+
--warmup_steps 10000 \
|
| 31 |
+
--use_scheduler \
|
| 32 |
+
--scheduler_type cosine \
|
| 33 |
+
--weight_decay 0.0 \
|
| 34 |
+
--seed 1 \
|
| 35 |
+
--use_amp \
|
| 36 |
+
--reload \
|
| 37 |
+
--num_workers_train 8 \
|
| 38 |
+
--use_custom_sampler
|
tasks/image_classification/train.py
ADDED
|
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
sns.set_style('darkgrid')
|
| 9 |
+
import torch
|
| 10 |
+
if torch.cuda.is_available():
|
| 11 |
+
# For faster
|
| 12 |
+
torch.set_float32_matmul_precision('high')
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from tqdm.auto import tqdm
|
| 15 |
+
|
| 16 |
+
from data.custom_datasets import ImageNet
|
| 17 |
+
from torchvision import datasets
|
| 18 |
+
from torchvision import transforms
|
| 19 |
+
from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
|
| 20 |
+
from models.ctm import ContinuousThoughtMachine
|
| 21 |
+
from models.lstm import LSTMBaseline
|
| 22 |
+
from models.ff import FFBaseline
|
| 23 |
+
from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
|
| 24 |
+
from utils.housekeeping import set_seed, zip_python_code
|
| 25 |
+
from utils.losses import image_classification_loss # Used by CTM, LSTM
|
| 26 |
+
from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
|
| 27 |
+
|
| 28 |
+
from autoclip.torch import QuantileClip
|
| 29 |
+
|
| 30 |
+
import gc
|
| 31 |
+
import torchvision
|
| 32 |
+
torchvision.disable_beta_transforms_warning()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
import warnings
|
| 36 |
+
warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
|
| 37 |
+
warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
|
| 38 |
+
warnings.filterwarnings(
|
| 39 |
+
"ignore",
|
| 40 |
+
"Corrupt EXIF data",
|
| 41 |
+
UserWarning,
|
| 42 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 43 |
+
)
|
| 44 |
+
warnings.filterwarnings(
|
| 45 |
+
"ignore",
|
| 46 |
+
"UserWarning: Metadata Warning",
|
| 47 |
+
UserWarning,
|
| 48 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 49 |
+
)
|
| 50 |
+
warnings.filterwarnings(
|
| 51 |
+
"ignore",
|
| 52 |
+
"UserWarning: Truncated File Read",
|
| 53 |
+
UserWarning,
|
| 54 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def parse_args():
|
| 59 |
+
parser = argparse.ArgumentParser()
|
| 60 |
+
|
| 61 |
+
# Model Selection
|
| 62 |
+
parser.add_argument('--model', type=str, default='ctm', choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
|
| 63 |
+
|
| 64 |
+
# Model Architecture
|
| 65 |
+
# Common
|
| 66 |
+
parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
|
| 67 |
+
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
|
| 68 |
+
parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone featureiser.')
|
| 69 |
+
# CTM / LSTM specific
|
| 70 |
+
parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
|
| 71 |
+
parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
|
| 72 |
+
parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
|
| 73 |
+
parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
|
| 74 |
+
choices=['none',
|
| 75 |
+
'learnable-fourier',
|
| 76 |
+
'multi-learnable-fourier',
|
| 77 |
+
'custom-rotational'])
|
| 78 |
+
# CTM specific
|
| 79 |
+
parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
|
| 80 |
+
parser.add_argument('--n_synch_out', type=int, default=512, help='Number of neurons to use for output synch (CTM only).')
|
| 81 |
+
parser.add_argument('--n_synch_action', type=int, default=512, help='Number of neurons to use for observation/action synch (CTM only).')
|
| 82 |
+
parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
|
| 83 |
+
parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
|
| 84 |
+
parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
|
| 85 |
+
parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
|
| 86 |
+
parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory if using deep memory (CTM only).')
|
| 87 |
+
parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
|
| 88 |
+
parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
|
| 89 |
+
# LSTM specific
|
| 90 |
+
parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
|
| 91 |
+
|
| 92 |
+
# Training
|
| 93 |
+
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training.')
|
| 94 |
+
parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing.')
|
| 95 |
+
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
|
| 96 |
+
parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
|
| 97 |
+
parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
|
| 98 |
+
parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
|
| 99 |
+
parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
|
| 100 |
+
parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
|
| 101 |
+
parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
|
| 102 |
+
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
|
| 103 |
+
parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
|
| 104 |
+
parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
|
| 105 |
+
parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components (backbone, synapses if CTM).')
|
| 106 |
+
parser.add_argument('--num_workers_train', type=int, default=1, help='Num workers training.')
|
| 107 |
+
|
| 108 |
+
# Housekeeping
|
| 109 |
+
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 110 |
+
parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
|
| 111 |
+
parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
|
| 112 |
+
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
| 113 |
+
parser.add_argument('--seed', type=int, default=412, help='Random seed.')
|
| 114 |
+
parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
|
| 115 |
+
parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
|
| 116 |
+
parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
|
| 117 |
+
parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
|
| 118 |
+
parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
|
| 119 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
|
| 120 |
+
parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
args = parser.parse_args()
|
| 124 |
+
return args
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def get_dataset(dataset, root):
|
| 128 |
+
if dataset=='imagenet':
|
| 129 |
+
dataset_mean = [0.485, 0.456, 0.406]
|
| 130 |
+
dataset_std = [0.229, 0.224, 0.225]
|
| 131 |
+
|
| 132 |
+
normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
|
| 133 |
+
train_transform = transforms.Compose([
|
| 134 |
+
transforms.RandomResizedCrop(224),
|
| 135 |
+
transforms.RandomHorizontalFlip(),
|
| 136 |
+
transforms.ToTensor(),
|
| 137 |
+
normalize])
|
| 138 |
+
test_transform = transforms.Compose([
|
| 139 |
+
transforms.Resize(256),
|
| 140 |
+
transforms.CenterCrop(224),
|
| 141 |
+
transforms.ToTensor(),
|
| 142 |
+
normalize])
|
| 143 |
+
|
| 144 |
+
class_labels = list(IMAGENET2012_CLASSES.values())
|
| 145 |
+
|
| 146 |
+
train_data = ImageNet(which_split='train', transform=train_transform)
|
| 147 |
+
test_data = ImageNet(which_split='validation', transform=test_transform)
|
| 148 |
+
elif dataset=='cifar10':
|
| 149 |
+
dataset_mean = [0.49139968, 0.48215827, 0.44653124]
|
| 150 |
+
dataset_std = [0.24703233, 0.24348505, 0.26158768]
|
| 151 |
+
normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
|
| 152 |
+
train_transform = transforms.Compose(
|
| 153 |
+
[transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
|
| 154 |
+
transforms.ToTensor(),
|
| 155 |
+
normalize,
|
| 156 |
+
])
|
| 157 |
+
|
| 158 |
+
test_transform = transforms.Compose(
|
| 159 |
+
[transforms.ToTensor(),
|
| 160 |
+
normalize,
|
| 161 |
+
])
|
| 162 |
+
train_data = datasets.CIFAR10(root, train=True, transform=train_transform, download=True)
|
| 163 |
+
test_data = datasets.CIFAR10(root, train=False, transform=test_transform, download=True)
|
| 164 |
+
class_labels = ['air', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
|
| 165 |
+
elif dataset=='cifar100':
|
| 166 |
+
dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
|
| 167 |
+
dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
|
| 168 |
+
normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
|
| 169 |
+
|
| 170 |
+
train_transform = transforms.Compose(
|
| 171 |
+
[transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
|
| 172 |
+
transforms.ToTensor(),
|
| 173 |
+
normalize,
|
| 174 |
+
])
|
| 175 |
+
test_transform = transforms.Compose(
|
| 176 |
+
[transforms.ToTensor(),
|
| 177 |
+
normalize,
|
| 178 |
+
])
|
| 179 |
+
train_data = datasets.CIFAR100(root, train=True, transform=train_transform, download=True)
|
| 180 |
+
test_data = datasets.CIFAR100(root, train=False, transform=test_transform, download=True)
|
| 181 |
+
idx_order = np.argsort(np.array(list(train_data.class_to_idx.values())))
|
| 182 |
+
class_labels = list(np.array(list(train_data.class_to_idx.keys()))[idx_order])
|
| 183 |
+
else:
|
| 184 |
+
raise NotImplementedError
|
| 185 |
+
|
| 186 |
+
return train_data, test_data, class_labels, dataset_mean, dataset_std
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
if __name__=='__main__':
|
| 191 |
+
|
| 192 |
+
# Hosuekeeping
|
| 193 |
+
args = parse_args()
|
| 194 |
+
|
| 195 |
+
set_seed(args.seed, False)
|
| 196 |
+
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 197 |
+
|
| 198 |
+
assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
|
| 199 |
+
|
| 200 |
+
# Data
|
| 201 |
+
train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)
|
| 202 |
+
|
| 203 |
+
num_workers_test = 1 # Defaulting to 1, change if needed
|
| 204 |
+
trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train)
|
| 205 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)
|
| 206 |
+
|
| 207 |
+
prediction_reshaper = [-1] # Problem specific
|
| 208 |
+
args.out_dims = len(class_labels)
|
| 209 |
+
|
| 210 |
+
# For total reproducibility
|
| 211 |
+
zip_python_code(f'{args.log_dir}/repo_state.zip')
|
| 212 |
+
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 213 |
+
print(args, file=f)
|
| 214 |
+
|
| 215 |
+
# Configure device string (support MPS on macOS)
|
| 216 |
+
if args.device[0] != -1:
|
| 217 |
+
device = f'cuda:{args.device[0]}'
|
| 218 |
+
elif torch.backends.mps.is_available():
|
| 219 |
+
device = 'mps'
|
| 220 |
+
else:
|
| 221 |
+
device = 'cpu'
|
| 222 |
+
print(f'Running model {args.model} on {device}')
|
| 223 |
+
|
| 224 |
+
# Build model conditionally
|
| 225 |
+
model = None
|
| 226 |
+
if args.model == 'ctm':
|
| 227 |
+
model = ContinuousThoughtMachine(
|
| 228 |
+
iterations=args.iterations,
|
| 229 |
+
d_model=args.d_model,
|
| 230 |
+
d_input=args.d_input,
|
| 231 |
+
heads=args.heads,
|
| 232 |
+
n_synch_out=args.n_synch_out,
|
| 233 |
+
n_synch_action=args.n_synch_action,
|
| 234 |
+
synapse_depth=args.synapse_depth,
|
| 235 |
+
memory_length=args.memory_length,
|
| 236 |
+
deep_nlms=args.deep_memory,
|
| 237 |
+
memory_hidden_dims=args.memory_hidden_dims,
|
| 238 |
+
do_layernorm_nlm=args.do_normalisation,
|
| 239 |
+
backbone_type=args.backbone_type,
|
| 240 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 241 |
+
out_dims=args.out_dims,
|
| 242 |
+
prediction_reshaper=prediction_reshaper,
|
| 243 |
+
dropout=args.dropout,
|
| 244 |
+
dropout_nlm=args.dropout_nlm,
|
| 245 |
+
neuron_select_type=args.neuron_select_type,
|
| 246 |
+
n_random_pairing_self=args.n_random_pairing_self,
|
| 247 |
+
).to(device)
|
| 248 |
+
elif args.model == 'lstm':
|
| 249 |
+
model = LSTMBaseline(
|
| 250 |
+
num_layers=args.num_layers,
|
| 251 |
+
iterations=args.iterations,
|
| 252 |
+
d_model=args.d_model,
|
| 253 |
+
d_input=args.d_input,
|
| 254 |
+
heads=args.heads,
|
| 255 |
+
backbone_type=args.backbone_type,
|
| 256 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 257 |
+
out_dims=args.out_dims,
|
| 258 |
+
prediction_reshaper=prediction_reshaper,
|
| 259 |
+
dropout=args.dropout,
|
| 260 |
+
).to(device)
|
| 261 |
+
elif args.model == 'ff':
|
| 262 |
+
model = FFBaseline(
|
| 263 |
+
d_model=args.d_model,
|
| 264 |
+
backbone_type=args.backbone_type,
|
| 265 |
+
out_dims=args.out_dims,
|
| 266 |
+
dropout=args.dropout,
|
| 267 |
+
).to(device)
|
| 268 |
+
else:
|
| 269 |
+
raise ValueError(f"Unknown model type: {args.model}")
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# For lazy modules so that we can get param count
|
| 273 |
+
pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
|
| 274 |
+
model(pseudo_inputs)
|
| 275 |
+
|
| 276 |
+
model.train()
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
print(f'Total params: {sum(p.numel() for p in model.parameters())}')
|
| 280 |
+
decay_params = []
|
| 281 |
+
no_decay_params = []
|
| 282 |
+
no_decay_names = []
|
| 283 |
+
for name, param in model.named_parameters():
|
| 284 |
+
if not param.requires_grad:
|
| 285 |
+
continue # Skip parameters that don't require gradients
|
| 286 |
+
if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
|
| 287 |
+
no_decay_params.append(param)
|
| 288 |
+
no_decay_names.append(name)
|
| 289 |
+
else:
|
| 290 |
+
decay_params.append(param)
|
| 291 |
+
if len(no_decay_names):
|
| 292 |
+
print(f'WARNING, excluding: {no_decay_names}')
|
| 293 |
+
|
| 294 |
+
# Optimizer and scheduler (Common setup)
|
| 295 |
+
if len(no_decay_names) and args.weight_decay!=0:
|
| 296 |
+
optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
|
| 297 |
+
{'params': no_decay_params, 'weight_decay':0}],
|
| 298 |
+
lr=args.lr,
|
| 299 |
+
eps=1e-8 if not args.use_amp else 1e-6)
|
| 300 |
+
else:
|
| 301 |
+
optimizer = torch.optim.AdamW(model.parameters(),
|
| 302 |
+
lr=args.lr,
|
| 303 |
+
eps=1e-8 if not args.use_amp else 1e-6,
|
| 304 |
+
weight_decay=args.weight_decay)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
warmup_schedule = warmup(args.warmup_steps)
|
| 308 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
|
| 309 |
+
if args.use_scheduler:
|
| 310 |
+
if args.scheduler_type == 'multistep':
|
| 311 |
+
scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
|
| 312 |
+
elif args.scheduler_type == 'cosine':
|
| 313 |
+
scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
|
| 314 |
+
else:
|
| 315 |
+
raise NotImplementedError
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# Metrics tracking
|
| 319 |
+
start_iter = 0
|
| 320 |
+
train_losses = []
|
| 321 |
+
test_losses = []
|
| 322 |
+
train_accuracies = []
|
| 323 |
+
test_accuracies = []
|
| 324 |
+
iters = []
|
| 325 |
+
# Conditional metrics for CTM/LSTM
|
| 326 |
+
train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
|
| 327 |
+
test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
|
| 328 |
+
|
| 329 |
+
scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
|
| 330 |
+
|
| 331 |
+
# Reloading logic
|
| 332 |
+
if args.reload:
|
| 333 |
+
checkpoint_path = f'{args.log_dir}/checkpoint.pt'
|
| 334 |
+
if os.path.isfile(checkpoint_path):
|
| 335 |
+
print(f'Reloading from: {checkpoint_path}')
|
| 336 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 337 |
+
if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
|
| 338 |
+
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
|
| 339 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 340 |
+
|
| 341 |
+
if not args.reload_model_only:
|
| 342 |
+
print('Reloading optimizer etc.')
|
| 343 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 344 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 345 |
+
scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
| 346 |
+
start_iter = checkpoint['iteration']
|
| 347 |
+
# Load common metrics
|
| 348 |
+
train_losses = checkpoint['train_losses']
|
| 349 |
+
test_losses = checkpoint['test_losses']
|
| 350 |
+
train_accuracies = checkpoint['train_accuracies']
|
| 351 |
+
test_accuracies = checkpoint['test_accuracies']
|
| 352 |
+
iters = checkpoint['iters']
|
| 353 |
+
|
| 354 |
+
# Load conditional metrics if they exist in checkpoint and are expected for current model
|
| 355 |
+
if args.model in ['ctm', 'lstm']:
|
| 356 |
+
train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
|
| 357 |
+
test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
|
| 358 |
+
|
| 359 |
+
else:
|
| 360 |
+
print('Only reloading model!')
|
| 361 |
+
|
| 362 |
+
if 'torch_rng_state' in checkpoint:
|
| 363 |
+
# Reset seeds
|
| 364 |
+
torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
|
| 365 |
+
np.random.set_state(checkpoint['numpy_rng_state'])
|
| 366 |
+
random.setstate(checkpoint['random_rng_state'])
|
| 367 |
+
|
| 368 |
+
del checkpoint
|
| 369 |
+
gc.collect()
|
| 370 |
+
if torch.cuda.is_available():
|
| 371 |
+
torch.cuda.empty_cache()
|
| 372 |
+
|
| 373 |
+
# Conditional Compilation
|
| 374 |
+
if args.do_compile:
|
| 375 |
+
print('Compiling...')
|
| 376 |
+
if hasattr(model, 'backbone'):
|
| 377 |
+
model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
|
| 378 |
+
|
| 379 |
+
# Compile synapses only for CTM
|
| 380 |
+
if args.model == 'ctm':
|
| 381 |
+
model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
|
| 382 |
+
|
| 383 |
+
# Training
|
| 384 |
+
iterator = iter(trainloader)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
|
| 388 |
+
for bi in range(start_iter, args.training_iterations):
|
| 389 |
+
current_lr = optimizer.param_groups[-1]['lr']
|
| 390 |
+
|
| 391 |
+
try:
|
| 392 |
+
inputs, targets = next(iterator)
|
| 393 |
+
except StopIteration:
|
| 394 |
+
iterator = iter(trainloader)
|
| 395 |
+
inputs, targets = next(iterator)
|
| 396 |
+
|
| 397 |
+
inputs = inputs.to(device)
|
| 398 |
+
targets = targets.to(device)
|
| 399 |
+
|
| 400 |
+
loss = None
|
| 401 |
+
accuracy = None
|
| 402 |
+
# Model-specific forward and loss calculation
|
| 403 |
+
with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
|
| 404 |
+
if args.do_compile: # CUDAGraph marking for clean compile
|
| 405 |
+
torch.compiler.cudagraph_mark_step_begin()
|
| 406 |
+
|
| 407 |
+
if args.model == 'ctm':
|
| 408 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 409 |
+
loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 410 |
+
accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
|
| 411 |
+
pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
|
| 412 |
+
|
| 413 |
+
elif args.model == 'lstm':
|
| 414 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 415 |
+
loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 416 |
+
# LSTM where_most_certain will just be -1 because use_most_certain is False owing to stability issues with LSTM training
|
| 417 |
+
accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
|
| 418 |
+
pbar_desc = f'LSTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
|
| 419 |
+
|
| 420 |
+
elif args.model == 'ff':
|
| 421 |
+
predictions = model(inputs)
|
| 422 |
+
loss = nn.CrossEntropyLoss()(predictions, targets)
|
| 423 |
+
accuracy = (predictions.argmax(1) == targets).float().mean().item()
|
| 424 |
+
pbar_desc = f'FF Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
|
| 425 |
+
|
| 426 |
+
scaler.scale(loss).backward()
|
| 427 |
+
|
| 428 |
+
if args.gradient_clipping!=-1:
|
| 429 |
+
scaler.unscale_(optimizer)
|
| 430 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
|
| 431 |
+
|
| 432 |
+
scaler.step(optimizer)
|
| 433 |
+
scaler.update()
|
| 434 |
+
optimizer.zero_grad(set_to_none=True)
|
| 435 |
+
scheduler.step()
|
| 436 |
+
|
| 437 |
+
pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# Metrics tracking and plotting (conditional logic needed)
|
| 441 |
+
if (bi % args.track_every == 0 or bi == args.warmup_steps) and (bi != 0 or args.reload_model_only):
|
| 442 |
+
|
| 443 |
+
iters.append(bi)
|
| 444 |
+
current_train_losses = []
|
| 445 |
+
current_test_losses = []
|
| 446 |
+
current_train_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
|
| 447 |
+
current_test_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
|
| 448 |
+
current_train_accuracies_most_certain = [] # Only for CTM/LSTM
|
| 449 |
+
current_test_accuracies_most_certain = [] # Only for CTM/LSTM
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# Reset BN stats using train mode
|
| 453 |
+
pbar.set_description('Resetting BN')
|
| 454 |
+
model.train()
|
| 455 |
+
for module in model.modules():
|
| 456 |
+
if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
|
| 457 |
+
module.reset_running_stats()
|
| 458 |
+
|
| 459 |
+
pbar.set_description('Tracking: Computing TRAIN metrics')
|
| 460 |
+
with torch.no_grad(): # Should use inference_mode? CTM/LSTM scripts used no_grad
|
| 461 |
+
loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
|
| 462 |
+
all_targets_list = []
|
| 463 |
+
all_predictions_list = [] # List to store raw predictions (B, C, T) or (B, C)
|
| 464 |
+
all_predictions_most_certain_list = [] # Only for CTM/LSTM
|
| 465 |
+
all_losses = []
|
| 466 |
+
|
| 467 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 468 |
+
for inferi, (inputs, targets) in enumerate(loader):
|
| 469 |
+
inputs = inputs.to(device)
|
| 470 |
+
targets = targets.to(device)
|
| 471 |
+
all_targets_list.append(targets.detach().cpu().numpy())
|
| 472 |
+
|
| 473 |
+
# Model-specific forward and loss for evaluation
|
| 474 |
+
if args.model == 'ctm':
|
| 475 |
+
these_predictions, certainties, _ = model(inputs)
|
| 476 |
+
loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
|
| 477 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
|
| 478 |
+
all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
|
| 479 |
+
|
| 480 |
+
elif args.model == 'lstm':
|
| 481 |
+
these_predictions, certainties, _ = model(inputs)
|
| 482 |
+
loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
|
| 483 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
|
| 484 |
+
all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
|
| 485 |
+
|
| 486 |
+
elif args.model == 'ff':
|
| 487 |
+
these_predictions = model(inputs)
|
| 488 |
+
loss = nn.CrossEntropyLoss()(these_predictions, targets)
|
| 489 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B,)
|
| 490 |
+
|
| 491 |
+
all_losses.append(loss.item())
|
| 492 |
+
|
| 493 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1 : break # Check condition >= N-1
|
| 494 |
+
pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
|
| 495 |
+
pbar_inner.update(1)
|
| 496 |
+
|
| 497 |
+
all_targets = np.concatenate(all_targets_list)
|
| 498 |
+
all_predictions = np.concatenate(all_predictions_list) # Shape (N, T) or (N,)
|
| 499 |
+
train_losses.append(np.mean(all_losses))
|
| 500 |
+
|
| 501 |
+
if args.model in ['ctm', 'lstm']:
|
| 502 |
+
# Accuracies per tick for CTM/LSTM
|
| 503 |
+
current_train_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0) # Mean over batch dim -> Shape (T,)
|
| 504 |
+
train_accuracies.append(current_train_accuracies)
|
| 505 |
+
# Most certain accuracy
|
| 506 |
+
all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
|
| 507 |
+
current_train_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
|
| 508 |
+
train_accuracies_most_certain.append(current_train_accuracies_most_certain)
|
| 509 |
+
else: # FF
|
| 510 |
+
current_train_accuracies = (all_targets == all_predictions).mean() # Shape scalar
|
| 511 |
+
train_accuracies.append(current_train_accuracies)
|
| 512 |
+
|
| 513 |
+
del these_predictions
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
# Switch to eval mode for test metrics (fixed BN stats)
|
| 517 |
+
model.eval()
|
| 518 |
+
pbar.set_description('Tracking: Computing TEST metrics')
|
| 519 |
+
with torch.inference_mode(): # Use inference_mode for test eval
|
| 520 |
+
loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
|
| 521 |
+
all_targets_list = []
|
| 522 |
+
all_predictions_list = []
|
| 523 |
+
all_predictions_most_certain_list = [] # Only for CTM/LSTM
|
| 524 |
+
all_losses = []
|
| 525 |
+
|
| 526 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 527 |
+
for inferi, (inputs, targets) in enumerate(loader):
|
| 528 |
+
inputs = inputs.to(device)
|
| 529 |
+
targets = targets.to(device)
|
| 530 |
+
all_targets_list.append(targets.detach().cpu().numpy())
|
| 531 |
+
|
| 532 |
+
# Model-specific forward and loss for evaluation
|
| 533 |
+
if args.model == 'ctm':
|
| 534 |
+
these_predictions, certainties, _ = model(inputs)
|
| 535 |
+
loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
|
| 536 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
|
| 537 |
+
all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
|
| 538 |
+
|
| 539 |
+
elif args.model == 'lstm':
|
| 540 |
+
these_predictions, certainties, _ = model(inputs)
|
| 541 |
+
loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
|
| 542 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
|
| 543 |
+
all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
|
| 544 |
+
|
| 545 |
+
elif args.model == 'ff':
|
| 546 |
+
these_predictions = model(inputs)
|
| 547 |
+
loss = nn.CrossEntropyLoss()(these_predictions, targets)
|
| 548 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
|
| 549 |
+
|
| 550 |
+
all_losses.append(loss.item())
|
| 551 |
+
|
| 552 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 553 |
+
pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
|
| 554 |
+
pbar_inner.update(1)
|
| 555 |
+
|
| 556 |
+
all_targets = np.concatenate(all_targets_list)
|
| 557 |
+
all_predictions = np.concatenate(all_predictions_list)
|
| 558 |
+
test_losses.append(np.mean(all_losses))
|
| 559 |
+
|
| 560 |
+
if args.model in ['ctm', 'lstm']:
|
| 561 |
+
current_test_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0)
|
| 562 |
+
test_accuracies.append(current_test_accuracies)
|
| 563 |
+
all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
|
| 564 |
+
current_test_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
|
| 565 |
+
test_accuracies_most_certain.append(current_test_accuracies_most_certain)
|
| 566 |
+
else: # FF
|
| 567 |
+
current_test_accuracies = (all_targets == all_predictions).mean()
|
| 568 |
+
test_accuracies.append(current_test_accuracies)
|
| 569 |
+
|
| 570 |
+
# Plotting (conditional)
|
| 571 |
+
figacc = plt.figure(figsize=(10, 10))
|
| 572 |
+
axacc_train = figacc.add_subplot(211)
|
| 573 |
+
axacc_test = figacc.add_subplot(212)
|
| 574 |
+
cm = sns.color_palette("viridis", as_cmap=True)
|
| 575 |
+
|
| 576 |
+
if args.model in ['ctm', 'lstm']:
|
| 577 |
+
# Plot per-tick accuracy for CTM/LSTM
|
| 578 |
+
train_acc_arr = np.array(train_accuracies) # Shape (N_iters, T)
|
| 579 |
+
test_acc_arr = np.array(test_accuracies) # Shape (N_iters, T)
|
| 580 |
+
num_ticks = train_acc_arr.shape[1]
|
| 581 |
+
for ti in range(num_ticks):
|
| 582 |
+
axacc_train.plot(iters, train_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
|
| 583 |
+
axacc_test.plot(iters, test_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
|
| 584 |
+
# Plot most certain accuracy
|
| 585 |
+
axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
|
| 586 |
+
axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
|
| 587 |
+
else: # FF
|
| 588 |
+
axacc_train.plot(iters, train_accuracies, 'k-', alpha=0.7, label='Accuracy') # Simple line
|
| 589 |
+
axacc_test.plot(iters, test_accuracies, 'k-', alpha=0.7, label='Accuracy')
|
| 590 |
+
|
| 591 |
+
axacc_train.set_title('Train Accuracy')
|
| 592 |
+
axacc_test.set_title('Test Accuracy')
|
| 593 |
+
axacc_train.legend(loc='lower right')
|
| 594 |
+
axacc_test.legend(loc='lower right')
|
| 595 |
+
axacc_train.set_xlim([0, args.training_iterations])
|
| 596 |
+
axacc_test.set_xlim([0, args.training_iterations])
|
| 597 |
+
if args.dataset=='cifar10':
|
| 598 |
+
axacc_train.set_ylim([0.75, 1])
|
| 599 |
+
axacc_test.set_ylim([0.75, 1])
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
figacc.tight_layout()
|
| 604 |
+
figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
|
| 605 |
+
plt.close(figacc)
|
| 606 |
+
|
| 607 |
+
figloss = plt.figure(figsize=(10, 5))
|
| 608 |
+
axloss = figloss.add_subplot(111)
|
| 609 |
+
axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
|
| 610 |
+
axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
|
| 611 |
+
axloss.legend(loc='upper right')
|
| 612 |
+
axloss.set_xlim([0, args.training_iterations])
|
| 613 |
+
axloss.set_ylim(bottom=0)
|
| 614 |
+
|
| 615 |
+
figloss.tight_layout()
|
| 616 |
+
figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
|
| 617 |
+
plt.close(figloss)
|
| 618 |
+
|
| 619 |
+
# Conditional Visualization (Only for CTM/LSTM)
|
| 620 |
+
if args.model in ['ctm', 'lstm']:
|
| 621 |
+
try: # For safety
|
| 622 |
+
inputs_viz, targets_viz = next(iter(testloader)) # Get a fresh batch
|
| 623 |
+
inputs_viz = inputs_viz.to(device)
|
| 624 |
+
targets_viz = targets_viz.to(device)
|
| 625 |
+
|
| 626 |
+
pbar.set_description('Tracking: Processing test data for viz')
|
| 627 |
+
predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
|
| 628 |
+
|
| 629 |
+
att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
|
| 630 |
+
attention_tracking_viz = attention_tracking_viz.reshape(
|
| 631 |
+
attention_tracking_viz.shape[0],
|
| 632 |
+
attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
|
| 633 |
+
|
| 634 |
+
pbar.set_description('Tracking: Neural dynamics plot')
|
| 635 |
+
plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
|
| 636 |
+
|
| 637 |
+
imgi = 0 # Visualize the first image in the batch
|
| 638 |
+
img_to_gif = np.moveaxis(np.clip(inputs_viz[imgi].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)
|
| 639 |
+
|
| 640 |
+
pbar.set_description('Tracking: Producing attention gif')
|
| 641 |
+
make_classification_gif(img_to_gif,
|
| 642 |
+
targets_viz[imgi].item(),
|
| 643 |
+
predictions_viz[imgi].detach().cpu().numpy(),
|
| 644 |
+
certainties_viz[imgi].detach().cpu().numpy(),
|
| 645 |
+
post_activations_viz[:,imgi],
|
| 646 |
+
attention_tracking_viz[:,imgi],
|
| 647 |
+
class_labels,
|
| 648 |
+
f'{args.log_dir}/{imgi}_attention.gif',
|
| 649 |
+
)
|
| 650 |
+
del predictions_viz, certainties_viz, pre_activations_viz, post_activations_viz, attention_tracking_viz
|
| 651 |
+
except Exception as e:
|
| 652 |
+
print(f"Visualization failed for model {args.model}: {e}")
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
gc.collect()
|
| 657 |
+
if torch.cuda.is_available():
|
| 658 |
+
torch.cuda.empty_cache()
|
| 659 |
+
model.train() # Switch back to train mode
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
# Save model checkpoint (conditional metrics)
|
| 663 |
+
if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
|
| 664 |
+
pbar.set_description('Saving model checkpoint...')
|
| 665 |
+
checkpoint_data = {
|
| 666 |
+
'model_state_dict': model.state_dict(),
|
| 667 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 668 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 669 |
+
'scaler_state_dict': scaler.state_dict(),
|
| 670 |
+
'iteration': bi,
|
| 671 |
+
# Always save these
|
| 672 |
+
'train_losses': train_losses,
|
| 673 |
+
'test_losses': test_losses,
|
| 674 |
+
'train_accuracies': train_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
|
| 675 |
+
'test_accuracies': test_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
|
| 676 |
+
'iters': iters,
|
| 677 |
+
'args': args, # Save args used for this run
|
| 678 |
+
# RNG states
|
| 679 |
+
'torch_rng_state': torch.get_rng_state(),
|
| 680 |
+
'numpy_rng_state': np.random.get_state(),
|
| 681 |
+
'random_rng_state': random.getstate(),
|
| 682 |
+
}
|
| 683 |
+
# Conditionally add metrics specific to CTM/LSTM
|
| 684 |
+
if args.model in ['ctm', 'lstm']:
|
| 685 |
+
checkpoint_data['train_accuracies_most_certain'] = train_accuracies_most_certain
|
| 686 |
+
checkpoint_data['test_accuracies_most_certain'] = test_accuracies_most_certain
|
| 687 |
+
|
| 688 |
+
torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
|
| 689 |
+
|
| 690 |
+
pbar.update(1)
|
tasks/image_classification/train_distributed.py
ADDED
|
@@ -0,0 +1,799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
sns.set_style('darkgrid')
|
| 10 |
+
import torch
|
| 11 |
+
if torch.cuda.is_available():
|
| 12 |
+
# For faster
|
| 13 |
+
torch.set_float32_matmul_precision('high')
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
from utils.samplers import FastRandomDistributedSampler
|
| 19 |
+
from tqdm.auto import tqdm
|
| 20 |
+
|
| 21 |
+
from tasks.image_classification.train import get_dataset # Use shared get_dataset
|
| 22 |
+
|
| 23 |
+
# Model Imports
|
| 24 |
+
from models.ctm import ContinuousThoughtMachine
|
| 25 |
+
from models.lstm import LSTMBaseline
|
| 26 |
+
from models.ff import FFBaseline
|
| 27 |
+
|
| 28 |
+
# Plotting/Utils Imports
|
| 29 |
+
from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
|
| 30 |
+
from utils.housekeeping import set_seed, zip_python_code
|
| 31 |
+
from utils.losses import image_classification_loss # For CTM, LSTM
|
| 32 |
+
from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
|
| 33 |
+
|
| 34 |
+
import torchvision
|
| 35 |
+
torchvision.disable_beta_transforms_warning()
|
| 36 |
+
|
| 37 |
+
import warnings
|
| 38 |
+
warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
|
| 39 |
+
warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
|
| 40 |
+
warnings.filterwarnings("ignore", message="UserWarning: Metadata Warning, tag 274 had too many entries: 4, expected 1")
|
| 41 |
+
warnings.filterwarnings(
|
| 42 |
+
"ignore",
|
| 43 |
+
"Corrupt EXIF data",
|
| 44 |
+
UserWarning,
|
| 45 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 46 |
+
)
|
| 47 |
+
warnings.filterwarnings(
|
| 48 |
+
"ignore",
|
| 49 |
+
"UserWarning: Metadata Warning",
|
| 50 |
+
UserWarning,
|
| 51 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 52 |
+
)
|
| 53 |
+
warnings.filterwarnings(
|
| 54 |
+
"ignore",
|
| 55 |
+
"UserWarning: Truncated File Read",
|
| 56 |
+
UserWarning,
|
| 57 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def parse_args():
|
| 62 |
+
parser = argparse.ArgumentParser()
|
| 63 |
+
|
| 64 |
+
# Model Selection
|
| 65 |
+
parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
|
| 66 |
+
|
| 67 |
+
# Model Architecture
|
| 68 |
+
# Common
|
| 69 |
+
parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
|
| 70 |
+
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
|
| 71 |
+
parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone featureiser.')
|
| 72 |
+
# CTM / LSTM specific
|
| 73 |
+
parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
|
| 74 |
+
parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
|
| 75 |
+
parser.add_argument('--iterations', type=int, default=50, help='Number of internal ticks (CTM, LSTM).')
|
| 76 |
+
parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
|
| 77 |
+
choices=['none',
|
| 78 |
+
'learnable-fourier',
|
| 79 |
+
'multi-learnable-fourier',
|
| 80 |
+
'custom-rotational'])
|
| 81 |
+
# CTM specific
|
| 82 |
+
parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
|
| 83 |
+
parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).')
|
| 84 |
+
parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).')
|
| 85 |
+
parser.add_argument('--neuron_select_type', type=str, default='first-last', help='Protocol for selecting neuron subset (CTM only).')
|
| 86 |
+
parser.add_argument('--n_random_pairing_self', type=int, default=256, help='Number of neurons paired self-to-self for synch (CTM only).')
|
| 87 |
+
parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
|
| 88 |
+
parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
|
| 89 |
+
parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory if using deep memory (CTM only).')
|
| 90 |
+
parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
|
| 91 |
+
parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
|
| 92 |
+
# LSTM specific
|
| 93 |
+
parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
|
| 94 |
+
|
| 95 |
+
# Training
|
| 96 |
+
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training (per GPU).')
|
| 97 |
+
parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing (per GPU).')
|
| 98 |
+
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
|
| 99 |
+
parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
|
| 100 |
+
parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
|
| 101 |
+
parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
|
| 102 |
+
parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
|
| 103 |
+
parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
|
| 104 |
+
parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
|
| 105 |
+
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
|
| 106 |
+
parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
|
| 107 |
+
parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
|
| 108 |
+
parser.add_argument('--num_workers_train', type=int, default=1, help='Num workers training.')
|
| 109 |
+
parser.add_argument('--use_custom_sampler', action=argparse.BooleanOptionalAction, default=False, help='Use custom fast sampler to avoid reshuffling.')
|
| 110 |
+
parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
|
| 111 |
+
|
| 112 |
+
# Housekeeping
|
| 113 |
+
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 114 |
+
parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
|
| 115 |
+
parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
|
| 116 |
+
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
| 117 |
+
parser.add_argument('--seed', type=int, default=412, help='Random seed.')
|
| 118 |
+
parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
|
| 119 |
+
parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
|
| 120 |
+
parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.')
|
| 121 |
+
parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading?')
|
| 122 |
+
|
| 123 |
+
# Tracking
|
| 124 |
+
parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
|
| 125 |
+
parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
|
| 126 |
+
parser.add_argument('--plot_indices', type=int, default=[0], nargs='+', help='Which indices in test data to plot?') # Defaulted to 0
|
| 127 |
+
|
| 128 |
+
# Precision
|
| 129 |
+
parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
|
| 130 |
+
args = parser.parse_args()
|
| 131 |
+
return args
|
| 132 |
+
|
| 133 |
+
# --- DDP Setup Functions ---
|
| 134 |
+
def setup_ddp():
|
| 135 |
+
if 'RANK' not in os.environ:
|
| 136 |
+
# Basic setup for non-distributed run
|
| 137 |
+
os.environ['RANK'] = '0'
|
| 138 |
+
os.environ['WORLD_SIZE'] = '1'
|
| 139 |
+
os.environ['MASTER_ADDR'] = 'localhost'
|
| 140 |
+
os.environ['MASTER_PORT'] = '12355' # Ensure this port is free
|
| 141 |
+
os.environ['LOCAL_RANK'] = '0'
|
| 142 |
+
print("Running in non-distributed mode (simulated DDP setup).")
|
| 143 |
+
# Need to manually init if only 1 process desired for non-GPU testing
|
| 144 |
+
if not torch.cuda.is_available() or int(os.environ['WORLD_SIZE']) == 1:
|
| 145 |
+
dist.init_process_group(backend='gloo') # Gloo backend for CPU
|
| 146 |
+
print("Initialized process group with Gloo backend for single/CPU process.")
|
| 147 |
+
rank = int(os.environ['RANK'])
|
| 148 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 149 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 150 |
+
return rank, world_size, local_rank
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Standard DDP setup
|
| 154 |
+
dist.init_process_group(backend='nccl') # 'nccl' for NVIDIA GPUs
|
| 155 |
+
rank = int(os.environ['RANK'])
|
| 156 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 157 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 158 |
+
if torch.cuda.is_available():
|
| 159 |
+
torch.cuda.set_device(local_rank)
|
| 160 |
+
print(f"Rank {rank} setup on GPU {local_rank}")
|
| 161 |
+
else:
|
| 162 |
+
print(f"Rank {rank} setup on CPU (GPU not available or requested)")
|
| 163 |
+
return rank, world_size, local_rank
|
| 164 |
+
|
| 165 |
+
def cleanup_ddp():
|
| 166 |
+
if dist.is_initialized():
|
| 167 |
+
dist.destroy_process_group()
|
| 168 |
+
print("DDP cleanup complete.")
|
| 169 |
+
|
| 170 |
+
def is_main_process(rank):
|
| 171 |
+
return rank == 0
|
| 172 |
+
# --- End DDP Setup ---
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
if __name__=='__main__':
|
| 176 |
+
|
| 177 |
+
args = parse_args()
|
| 178 |
+
|
| 179 |
+
rank, world_size, local_rank = setup_ddp()
|
| 180 |
+
|
| 181 |
+
set_seed(args.seed + rank, False) # Add rank for different seeds per process
|
| 182 |
+
|
| 183 |
+
# Rank 0 handles directory creation and initial logging
|
| 184 |
+
if is_main_process(rank):
|
| 185 |
+
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 186 |
+
zip_python_code(f'{args.log_dir}/repo_state.zip')
|
| 187 |
+
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 188 |
+
print(args, file=f)
|
| 189 |
+
if world_size > 1: dist.barrier() # Sync after rank 0 setup
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
|
| 193 |
+
|
| 194 |
+
# Data Loading
|
| 195 |
+
train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)
|
| 196 |
+
|
| 197 |
+
# Setup Samplers
|
| 198 |
+
# This custom sampler is useful when using large batch sizes for Cifar. Otherwise the reshuffle happens tediously often
|
| 199 |
+
train_sampler = (FastRandomDistributedSampler(train_data, num_replicas=world_size, rank=rank, seed=args.seed, epoch_steps=int(10e10))
|
| 200 |
+
if args.use_custom_sampler else
|
| 201 |
+
DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed))
|
| 202 |
+
test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False, seed=args.seed) # No shuffle needed for test; consistent
|
| 203 |
+
|
| 204 |
+
# Setup DataLoaders
|
| 205 |
+
trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler,
|
| 206 |
+
num_workers=args.num_workers_train, pin_memory=True, drop_last=True) # drop_last=True often used in DDP
|
| 207 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, sampler=test_sampler,
|
| 208 |
+
num_workers=1, pin_memory=True, drop_last=False)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
prediction_reshaper = [-1] # Task specific
|
| 212 |
+
args.out_dims = len(class_labels)
|
| 213 |
+
|
| 214 |
+
# Setup Device
|
| 215 |
+
if torch.cuda.is_available():
|
| 216 |
+
device = torch.device(f'cuda:{local_rank}')
|
| 217 |
+
else:
|
| 218 |
+
device = torch.device('cpu')
|
| 219 |
+
if world_size > 1:
|
| 220 |
+
warnings.warn("Running DDP on CPU is not recommended.")
|
| 221 |
+
if is_main_process(rank):
|
| 222 |
+
print(f'Main process (Rank {rank}): Using device {device}. World size: {world_size}. Model: {args.model}')
|
| 223 |
+
|
| 224 |
+
# --- Model Definition (Conditional) ---
|
| 225 |
+
model_base = None # Base model before DDP wrapping
|
| 226 |
+
if args.model == 'ctm':
|
| 227 |
+
model_base = ContinuousThoughtMachine(
|
| 228 |
+
iterations=args.iterations,
|
| 229 |
+
d_model=args.d_model,
|
| 230 |
+
d_input=args.d_input,
|
| 231 |
+
heads=args.heads,
|
| 232 |
+
n_synch_out=args.n_synch_out,
|
| 233 |
+
n_synch_action=args.n_synch_action,
|
| 234 |
+
synapse_depth=args.synapse_depth,
|
| 235 |
+
memory_length=args.memory_length,
|
| 236 |
+
deep_nlms=args.deep_memory,
|
| 237 |
+
memory_hidden_dims=args.memory_hidden_dims,
|
| 238 |
+
do_layernorm_nlm=args.do_normalisation,
|
| 239 |
+
backbone_type=args.backbone_type,
|
| 240 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 241 |
+
out_dims=args.out_dims,
|
| 242 |
+
prediction_reshaper=prediction_reshaper,
|
| 243 |
+
dropout=args.dropout,
|
| 244 |
+
dropout_nlm=args.dropout_nlm,
|
| 245 |
+
neuron_select_type=args.neuron_select_type,
|
| 246 |
+
n_random_pairing_self=args.n_random_pairing_self,
|
| 247 |
+
).to(device)
|
| 248 |
+
elif args.model == 'lstm':
|
| 249 |
+
model_base = LSTMBaseline(
|
| 250 |
+
num_layers=args.num_layers,
|
| 251 |
+
iterations=args.iterations,
|
| 252 |
+
d_model=args.d_model,
|
| 253 |
+
d_input=args.d_input,
|
| 254 |
+
heads=args.heads,
|
| 255 |
+
backbone_type=args.backbone_type,
|
| 256 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 257 |
+
out_dims=args.out_dims,
|
| 258 |
+
prediction_reshaper=prediction_reshaper,
|
| 259 |
+
dropout=args.dropout,
|
| 260 |
+
start_type=args.start_type,
|
| 261 |
+
).to(device)
|
| 262 |
+
elif args.model == 'ff':
|
| 263 |
+
model_base = FFBaseline(
|
| 264 |
+
d_model=args.d_model,
|
| 265 |
+
backbone_type=args.backbone_type,
|
| 266 |
+
out_dims=args.out_dims,
|
| 267 |
+
dropout=args.dropout,
|
| 268 |
+
).to(device)
|
| 269 |
+
else:
|
| 270 |
+
raise ValueError(f"Unknown model type: {args.model}")
|
| 271 |
+
|
| 272 |
+
# Initialize lazy modules if any
|
| 273 |
+
try:
|
| 274 |
+
pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
|
| 275 |
+
model_base(pseudo_inputs)
|
| 276 |
+
except Exception as e:
|
| 277 |
+
print(f"Warning: Pseudo forward pass failed: {e}")
|
| 278 |
+
|
| 279 |
+
# Wrap model with DDP
|
| 280 |
+
if device.type == 'cuda' and world_size > 1:
|
| 281 |
+
model = DDP(model_base, device_ids=[local_rank], output_device=local_rank)
|
| 282 |
+
elif device.type == 'cpu' and world_size > 1:
|
| 283 |
+
model = DDP(model_base) # No device_ids for CPU
|
| 284 |
+
else: # Single process run
|
| 285 |
+
model = model_base # No DDP wrapping needed
|
| 286 |
+
|
| 287 |
+
if is_main_process(rank):
|
| 288 |
+
# Access underlying model for param count
|
| 289 |
+
param_count = sum(p.numel() for p in model.module.parameters() if p.requires_grad) if world_size > 1 else sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 290 |
+
print(f'Total trainable params: {param_count}')
|
| 291 |
+
# --- End Model Definition ---
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# Optimizer and scheduler
|
| 295 |
+
# Use model.parameters() directly, DDP handles it
|
| 296 |
+
decay_params = []
|
| 297 |
+
no_decay_params = []
|
| 298 |
+
no_decay_names = []
|
| 299 |
+
for name, param in model.named_parameters():
|
| 300 |
+
if not param.requires_grad:
|
| 301 |
+
continue # Skip parameters that don't require gradients
|
| 302 |
+
if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
|
| 303 |
+
no_decay_params.append(param)
|
| 304 |
+
no_decay_names.append(name)
|
| 305 |
+
else:
|
| 306 |
+
decay_params.append(param)
|
| 307 |
+
if len(no_decay_names) and is_main_process(rank):
|
| 308 |
+
print(f'WARNING, excluding: {no_decay_names}')
|
| 309 |
+
|
| 310 |
+
# Optimizer and scheduler (Common setup)
|
| 311 |
+
if len(no_decay_names) and args.weight_decay!=0:
|
| 312 |
+
optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
|
| 313 |
+
{'params': no_decay_params, 'weight_decay':0}],
|
| 314 |
+
lr=args.lr,
|
| 315 |
+
eps=1e-8 if not args.use_amp else 1e-6)
|
| 316 |
+
else:
|
| 317 |
+
optimizer = torch.optim.AdamW(model.parameters(),
|
| 318 |
+
lr=args.lr,
|
| 319 |
+
eps=1e-8 if not args.use_amp else 1e-6,
|
| 320 |
+
weight_decay=args.weight_decay)
|
| 321 |
+
|
| 322 |
+
warmup_schedule = warmup(args.warmup_steps)
|
| 323 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
|
| 324 |
+
if args.use_scheduler:
|
| 325 |
+
if args.scheduler_type == 'multistep':
|
| 326 |
+
scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
|
| 327 |
+
elif args.scheduler_type == 'cosine':
|
| 328 |
+
scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
|
| 329 |
+
else:
|
| 330 |
+
raise NotImplementedError
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# Metrics tracking (on Rank 0)
|
| 334 |
+
start_iter = 0
|
| 335 |
+
train_losses = []
|
| 336 |
+
test_losses = []
|
| 337 |
+
train_accuracies = [] # Placeholder for potential detailed accuracy
|
| 338 |
+
test_accuracies = [] # Placeholder for potential detailed accuracy
|
| 339 |
+
# Conditional metrics
|
| 340 |
+
train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None # Scalar accuracy list
|
| 341 |
+
test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None # Scalar accuracy list
|
| 342 |
+
train_accuracies_standard = [] if args.model == 'ff' else None # Standard accuracy list for FF
|
| 343 |
+
test_accuracies_standard = [] if args.model == 'ff' else None # Standard accuracy list for FF
|
| 344 |
+
iters = []
|
| 345 |
+
|
| 346 |
+
scaler = torch.amp.GradScaler("cuda" if device.type == 'cuda' else "cpu", enabled=args.use_amp)
|
| 347 |
+
# Reloading Logic
|
| 348 |
+
if args.reload:
|
| 349 |
+
map_location = device # Load directly onto the process's device
|
| 350 |
+
chkpt_path = f'{args.log_dir}/checkpoint.pt'
|
| 351 |
+
if os.path.isfile(chkpt_path):
|
| 352 |
+
print(f'Rank {rank}: Reloading from: {chkpt_path}')
|
| 353 |
+
checkpoint = torch.load(chkpt_path, map_location=map_location, weights_only=False)
|
| 354 |
+
|
| 355 |
+
# Determine underlying model based on whether DDP wrapping occurred
|
| 356 |
+
model_to_load = model.module if isinstance(model, DDP) else model
|
| 357 |
+
|
| 358 |
+
# Handle potential 'module.' prefix in saved state_dict
|
| 359 |
+
state_dict = checkpoint['model_state_dict']
|
| 360 |
+
has_module_prefix = all(k.startswith('module.') for k in state_dict)
|
| 361 |
+
is_wrapped = isinstance(model, DDP)
|
| 362 |
+
|
| 363 |
+
if has_module_prefix and not is_wrapped:
|
| 364 |
+
# Saved with DDP, loading into non-DDP model -> remove prefix
|
| 365 |
+
state_dict = {k.partition('module.')[2]: v for k,v in state_dict.items()}
|
| 366 |
+
elif not has_module_prefix and is_wrapped:
|
| 367 |
+
load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
|
| 368 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 369 |
+
state_dict = None # Prevent loading again
|
| 370 |
+
|
| 371 |
+
if state_dict is not None:
|
| 372 |
+
load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
|
| 373 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
if not args.reload_model_only:
|
| 377 |
+
print(f'Rank {rank}: Reloading optimizer, scheduler, scaler, iteration.')
|
| 378 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 379 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 380 |
+
scaler_state_dict = checkpoint['scaler_state_dict']
|
| 381 |
+
if scaler.is_enabled():
|
| 382 |
+
print("Loading non-empty GradScaler state dict.")
|
| 383 |
+
try:
|
| 384 |
+
scaler.load_state_dict(scaler_state_dict)
|
| 385 |
+
except Exception as e:
|
| 386 |
+
print(f"Error loading GradScaler state dict: {e}")
|
| 387 |
+
print("Continuing with a fresh GradScaler state.")
|
| 388 |
+
|
| 389 |
+
start_iter = checkpoint['iteration']
|
| 390 |
+
# Only rank 0 loads metric history
|
| 391 |
+
if is_main_process(rank) and not args.ignore_metrics_when_reloading:
|
| 392 |
+
print(f'Rank {rank}: Reloading metrics history.')
|
| 393 |
+
iters = checkpoint['iters']
|
| 394 |
+
train_losses = checkpoint['train_losses']
|
| 395 |
+
test_losses = checkpoint['test_losses']
|
| 396 |
+
train_accuracies = checkpoint['train_accuracies']
|
| 397 |
+
test_accuracies = checkpoint['test_accuracies']
|
| 398 |
+
if args.model in ['ctm', 'lstm']:
|
| 399 |
+
train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
|
| 400 |
+
test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
|
| 401 |
+
elif args.model == 'ff':
|
| 402 |
+
train_accuracies_standard = checkpoint['train_accuracies_standard']
|
| 403 |
+
test_accuracies_standard = checkpoint['test_accuracies_standard']
|
| 404 |
+
elif is_main_process(rank) and args.ignore_metrics_when_reloading:
|
| 405 |
+
print(f'Rank {rank}: Ignoring metrics history upon reload.')
|
| 406 |
+
|
| 407 |
+
else:
|
| 408 |
+
print(f'Rank {rank}: Only reloading model weights!')
|
| 409 |
+
|
| 410 |
+
# Load RNG states
|
| 411 |
+
if is_main_process(rank) and 'torch_rng_state' in checkpoint and not args.reload_model_only:
|
| 412 |
+
print(f'Rank {rank}: Loading RNG states (may need DDP adaptation for full reproducibility).')
|
| 413 |
+
torch.set_rng_state(checkpoint['torch_rng_state'].cpu()) # Load CPU state
|
| 414 |
+
# Add CUDA state loading if needed, ensuring correct device handling
|
| 415 |
+
np.random.set_state(checkpoint['numpy_rng_state'])
|
| 416 |
+
random.setstate(checkpoint['random_rng_state'])
|
| 417 |
+
|
| 418 |
+
del checkpoint
|
| 419 |
+
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
| 420 |
+
print(f"Rank {rank}: Reload finished, starting from iteration {start_iter}")
|
| 421 |
+
else:
|
| 422 |
+
print(f"Rank {rank}: Checkpoint not found at {chkpt_path}, starting from scratch.")
|
| 423 |
+
if world_size > 1: dist.barrier() # Sync after loading
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
# Conditional Compilation
|
| 427 |
+
if args.do_compile:
|
| 428 |
+
if is_main_process(rank): print('Compiling model components...')
|
| 429 |
+
# Compile on the underlying model if wrapped
|
| 430 |
+
model_to_compile = model.module if isinstance(model, DDP) else model
|
| 431 |
+
if hasattr(model_to_compile, 'backbone'):
|
| 432 |
+
model_to_compile.backbone = torch.compile(model_to_compile.backbone, mode='reduce-overhead', fullgraph=True)
|
| 433 |
+
if args.model == 'ctm':
|
| 434 |
+
if hasattr(model_to_compile, 'synapses'):
|
| 435 |
+
model_to_compile.synapses = torch.compile(model_to_compile.synapses, mode='reduce-overhead', fullgraph=True)
|
| 436 |
+
if world_size > 1: dist.barrier() # Sync after compilation
|
| 437 |
+
if is_main_process(rank): print('Compilation finished.')
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# --- Training Loop ---
|
| 441 |
+
model.train() # Ensure model is in train mode
|
| 442 |
+
pbar = tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True, disable=not is_main_process(rank))
|
| 443 |
+
|
| 444 |
+
iterator = iter(trainloader)
|
| 445 |
+
|
| 446 |
+
for bi in range(start_iter, args.training_iterations):
|
| 447 |
+
|
| 448 |
+
# Set sampler epoch (important for shuffling in DistributedSampler)
|
| 449 |
+
if not args.use_custom_sampler and hasattr(train_sampler, 'set_epoch'):
|
| 450 |
+
train_sampler.set_epoch(bi)
|
| 451 |
+
|
| 452 |
+
current_lr = optimizer.param_groups[-1]['lr']
|
| 453 |
+
|
| 454 |
+
time_start_data = time.time()
|
| 455 |
+
try:
|
| 456 |
+
inputs, targets = next(iterator)
|
| 457 |
+
except StopIteration:
|
| 458 |
+
# Reset iterator - set_epoch handles shuffling if needed
|
| 459 |
+
iterator = iter(trainloader)
|
| 460 |
+
inputs, targets = next(iterator)
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 464 |
+
targets = targets.to(device, non_blocking=True)
|
| 465 |
+
time_end_data = time.time()
|
| 466 |
+
|
| 467 |
+
loss = None
|
| 468 |
+
# Model-specific forward and loss calculation
|
| 469 |
+
time_start_forward = time.time()
|
| 470 |
+
with torch.autocast(device_type="cuda" if device.type == 'cuda' else "cpu", dtype=torch.float16, enabled=args.use_amp):
|
| 471 |
+
if args.do_compile:
|
| 472 |
+
torch.compiler.cudagraph_mark_step_begin()
|
| 473 |
+
|
| 474 |
+
if args.model == 'ctm':
|
| 475 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 476 |
+
loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 477 |
+
elif args.model == 'lstm':
|
| 478 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 479 |
+
loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 480 |
+
elif args.model == 'ff':
|
| 481 |
+
predictions = model(inputs) # FF returns only predictions
|
| 482 |
+
loss = nn.CrossEntropyLoss()(predictions, targets)
|
| 483 |
+
where_most_certain = None # Not applicable for FF standard loss
|
| 484 |
+
time_end_forward = time.time()
|
| 485 |
+
time_start_backward = time.time()
|
| 486 |
+
|
| 487 |
+
scaler.scale(loss).backward() # DDP handles gradient synchronization
|
| 488 |
+
time_end_backward = time.time()
|
| 489 |
+
|
| 490 |
+
if args.gradient_clipping!=-1:
|
| 491 |
+
scaler.unscale_(optimizer)
|
| 492 |
+
# Clip gradients across all parameters controlled by the optimizer
|
| 493 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
|
| 494 |
+
|
| 495 |
+
scaler.step(optimizer)
|
| 496 |
+
scaler.update()
|
| 497 |
+
optimizer.zero_grad(set_to_none=True)
|
| 498 |
+
scheduler.step()
|
| 499 |
+
|
| 500 |
+
# --- Aggregation and Logging (Rank 0) ---
|
| 501 |
+
# Aggregate loss for logging
|
| 502 |
+
loss_log = loss.detach() # Use detached loss for aggregation
|
| 503 |
+
if world_size > 1: dist.all_reduce(loss_log, op=dist.ReduceOp.AVG)
|
| 504 |
+
|
| 505 |
+
if is_main_process(rank):
|
| 506 |
+
# Calculate accuracy locally on rank 0 for description (approximate)
|
| 507 |
+
# Note: This uses rank 0's batch, not aggregated accuracy
|
| 508 |
+
accuracy_local = 0.0
|
| 509 |
+
if args.model in ['ctm', 'lstm']:
|
| 510 |
+
accuracy_local = (predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain] == targets).float().mean().item()
|
| 511 |
+
where_certain_tensor = where_most_certain.float() # Use rank 0's tensor for stats
|
| 512 |
+
pbar_desc = f'Timing; d={(time_end_data-time_start_data):0.3f}, f={(time_end_forward-time_start_forward):0.3f}, b={(time_end_backward-time_start_backward):0.3f}. Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_local:.3f} LR={current_lr:.6f} WhereCert(loc)={where_certain_tensor.mean().item():.2f}'
|
| 513 |
+
elif args.model == 'ff':
|
| 514 |
+
accuracy_local = (predictions.argmax(1) == targets).float().mean().item()
|
| 515 |
+
pbar_desc = f'Timing; d={(time_end_data-time_start_data):0.3f}, f={(time_end_forward-time_start_forward):0.3f}, b={(time_end_backward-time_start_backward):0.3f}. Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_local:.3f} LR={current_lr:.6f}'
|
| 516 |
+
|
| 517 |
+
pbar.set_description(f'{args.model.upper()} {pbar_desc}')
|
| 518 |
+
# --- End Aggregation and Logging ---
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
# --- Evaluation and Plotting (Rank 0 + Aggregation) ---
|
| 522 |
+
if bi % args.track_every == 0 and (bi != 0 or args.reload_model_only):
|
| 523 |
+
|
| 524 |
+
model.eval()
|
| 525 |
+
with torch.inference_mode():
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# --- Distributed Evaluation ---
|
| 529 |
+
iters.append(bi)
|
| 530 |
+
|
| 531 |
+
# TRAIN METRICS
|
| 532 |
+
total_train_loss = torch.tensor(0.0, device=device)
|
| 533 |
+
total_train_correct_certain = torch.tensor(0.0, device=device) # CTM/LSTM
|
| 534 |
+
total_train_correct_standard = torch.tensor(0.0, device=device) # FF
|
| 535 |
+
total_train_samples = torch.tensor(0.0, device=device)
|
| 536 |
+
|
| 537 |
+
# Use a sampler for evaluation to ensure non-overlapping data if needed
|
| 538 |
+
train_eval_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=False)
|
| 539 |
+
train_eval_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, sampler=train_eval_sampler, num_workers=1, pin_memory=True)
|
| 540 |
+
|
| 541 |
+
pbar_inner_desc = 'Eval Train (Rank 0)' if is_main_process(rank) else None
|
| 542 |
+
with tqdm(total=len(train_eval_loader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
|
| 543 |
+
for inferi, (inputs, targets) in enumerate(train_eval_loader):
|
| 544 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 545 |
+
targets = targets.to(device, non_blocking=True)
|
| 546 |
+
|
| 547 |
+
loss_eval = None
|
| 548 |
+
if args.model == 'ctm':
|
| 549 |
+
predictions, certainties, _ = model(inputs)
|
| 550 |
+
loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 551 |
+
preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
|
| 552 |
+
total_train_correct_certain += (preds_eval == targets).sum()
|
| 553 |
+
elif args.model == 'lstm':
|
| 554 |
+
predictions, certainties, _ = model(inputs)
|
| 555 |
+
loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 556 |
+
preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
|
| 557 |
+
total_train_correct_certain += (preds_eval == targets).sum()
|
| 558 |
+
elif args.model == 'ff':
|
| 559 |
+
predictions = model(inputs)
|
| 560 |
+
loss_eval = nn.CrossEntropyLoss()(predictions, targets)
|
| 561 |
+
preds_eval = predictions.argmax(1)
|
| 562 |
+
total_train_correct_standard += (preds_eval == targets).sum()
|
| 563 |
+
|
| 564 |
+
total_train_loss += loss_eval * inputs.size(0)
|
| 565 |
+
total_train_samples += inputs.size(0)
|
| 566 |
+
|
| 567 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 568 |
+
pbar_inner.update(1)
|
| 569 |
+
|
| 570 |
+
# Aggregate Train Metrics
|
| 571 |
+
if world_size > 1:
|
| 572 |
+
dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
|
| 573 |
+
dist.all_reduce(total_train_correct_certain, op=dist.ReduceOp.SUM)
|
| 574 |
+
dist.all_reduce(total_train_correct_standard, op=dist.ReduceOp.SUM)
|
| 575 |
+
dist.all_reduce(total_train_samples, op=dist.ReduceOp.SUM)
|
| 576 |
+
|
| 577 |
+
# Calculate final Train metrics on Rank 0
|
| 578 |
+
if is_main_process(rank) and total_train_samples > 0:
|
| 579 |
+
avg_train_loss = total_train_loss.item() / total_train_samples.item()
|
| 580 |
+
train_losses.append(avg_train_loss)
|
| 581 |
+
if args.model in ['ctm', 'lstm']:
|
| 582 |
+
avg_train_acc_certain = total_train_correct_certain.item() / total_train_samples.item()
|
| 583 |
+
train_accuracies_most_certain.append(avg_train_acc_certain)
|
| 584 |
+
elif args.model == 'ff':
|
| 585 |
+
avg_train_acc_standard = total_train_correct_standard.item() / total_train_samples.item()
|
| 586 |
+
train_accuracies_standard.append(avg_train_acc_standard)
|
| 587 |
+
print(f"Iter {bi} Train Metrics (Agg): Loss={avg_train_loss:.4f}")
|
| 588 |
+
|
| 589 |
+
# TEST METRICS
|
| 590 |
+
total_test_loss = torch.tensor(0.0, device=device)
|
| 591 |
+
total_test_correct_certain = torch.tensor(0.0, device=device) # CTM/LSTM
|
| 592 |
+
total_test_correct_standard = torch.tensor(0.0, device=device) # FF
|
| 593 |
+
total_test_samples = torch.tensor(0.0, device=device)
|
| 594 |
+
|
| 595 |
+
pbar_inner_desc = 'Eval Test (Rank 0)' if is_main_process(rank) else None
|
| 596 |
+
with tqdm(total=len(testloader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
|
| 597 |
+
for inferi, (inputs, targets) in enumerate(testloader): # Testloader already uses sampler
|
| 598 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 599 |
+
targets = targets.to(device, non_blocking=True)
|
| 600 |
+
|
| 601 |
+
loss_eval = None
|
| 602 |
+
if args.model == 'ctm':
|
| 603 |
+
predictions, certainties, _ = model(inputs)
|
| 604 |
+
loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 605 |
+
preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
|
| 606 |
+
total_test_correct_certain += (preds_eval == targets).sum()
|
| 607 |
+
elif args.model == 'lstm':
|
| 608 |
+
predictions, certainties, _ = model(inputs)
|
| 609 |
+
loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 610 |
+
preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
|
| 611 |
+
total_test_correct_certain += (preds_eval == targets).sum()
|
| 612 |
+
elif args.model == 'ff':
|
| 613 |
+
predictions = model(inputs)
|
| 614 |
+
loss_eval = nn.CrossEntropyLoss()(predictions, targets)
|
| 615 |
+
preds_eval = predictions.argmax(1)
|
| 616 |
+
total_test_correct_standard += (preds_eval == targets).sum()
|
| 617 |
+
|
| 618 |
+
total_test_loss += loss_eval * inputs.size(0)
|
| 619 |
+
total_test_samples += inputs.size(0)
|
| 620 |
+
|
| 621 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 622 |
+
pbar_inner.update(1)
|
| 623 |
+
|
| 624 |
+
# Aggregate Test Metrics
|
| 625 |
+
if world_size > 1:
|
| 626 |
+
dist.all_reduce(total_test_loss, op=dist.ReduceOp.SUM)
|
| 627 |
+
dist.all_reduce(total_test_correct_certain, op=dist.ReduceOp.SUM)
|
| 628 |
+
dist.all_reduce(total_test_correct_standard, op=dist.ReduceOp.SUM)
|
| 629 |
+
dist.all_reduce(total_test_samples, op=dist.ReduceOp.SUM)
|
| 630 |
+
|
| 631 |
+
# Calculate and Plot final Test metrics on Rank 0
|
| 632 |
+
if is_main_process(rank) and total_test_samples > 0:
|
| 633 |
+
avg_test_loss = total_test_loss.item() / total_test_samples.item()
|
| 634 |
+
test_losses.append(avg_test_loss)
|
| 635 |
+
acc_label = ''
|
| 636 |
+
acc_val = 0.0
|
| 637 |
+
if args.model in ['ctm', 'lstm']:
|
| 638 |
+
avg_test_acc_certain = total_test_correct_certain.item() / total_test_samples.item()
|
| 639 |
+
test_accuracies_most_certain.append(avg_test_acc_certain)
|
| 640 |
+
acc_label = f'Most certain ({avg_test_acc_certain:.3f})'
|
| 641 |
+
acc_val = avg_test_acc_certain
|
| 642 |
+
elif args.model == 'ff':
|
| 643 |
+
avg_test_acc_standard = total_test_correct_standard.item() / total_test_samples.item()
|
| 644 |
+
test_accuracies_standard.append(avg_test_acc_standard)
|
| 645 |
+
acc_label = f'Standard Acc ({avg_test_acc_standard:.3f})'
|
| 646 |
+
acc_val = avg_test_acc_standard
|
| 647 |
+
print(f"Iter {bi} Test Metrics (Agg): Loss={avg_test_loss:.4f}, Acc={acc_val:.4f}\n")
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
# --- Plotting ---
|
| 651 |
+
figacc = plt.figure(figsize=(10, 10))
|
| 652 |
+
axacc_train = figacc.add_subplot(211)
|
| 653 |
+
axacc_test = figacc.add_subplot(212)
|
| 654 |
+
|
| 655 |
+
if args.model in ['ctm', 'lstm']:
|
| 656 |
+
axacc_train.plot(iters, train_accuracies_most_certain, 'k-', alpha=0.9, label=f'Most certain ({train_accuracies_most_certain[-1]:.3f})')
|
| 657 |
+
axacc_test.plot(iters, test_accuracies_most_certain, 'k-', alpha=0.9, label=acc_label)
|
| 658 |
+
elif args.model == 'ff':
|
| 659 |
+
axacc_train.plot(iters, train_accuracies_standard, 'k-', alpha=0.9, label=f'Standard Acc ({train_accuracies_standard[-1]:.3f})')
|
| 660 |
+
axacc_test.plot(iters, test_accuracies_standard, 'k-', alpha=0.9, label=acc_label)
|
| 661 |
+
|
| 662 |
+
axacc_train.set_title('Train Accuracy (Aggregated)')
|
| 663 |
+
axacc_test.set_title('Test Accuracy (Aggregated)')
|
| 664 |
+
axacc_train.legend(loc='lower right')
|
| 665 |
+
axacc_test.legend(loc='lower right')
|
| 666 |
+
axacc_train.set_xlim([0, args.training_iterations])
|
| 667 |
+
axacc_test.set_xlim([0, args.training_iterations])
|
| 668 |
+
|
| 669 |
+
# Keep dataset specific ylim adjustments if needed
|
| 670 |
+
if args.dataset == 'imagenet':
|
| 671 |
+
# For easy comparison when training
|
| 672 |
+
train_ylim_set = False
|
| 673 |
+
if args.model in ['ctm', 'lstm'] and len(train_accuracies_most_certain)>0 and np.any(np.array(train_accuracies_most_certain)>0.4): train_ylim_set=True; axacc_train.set_ylim([0.4, 1])
|
| 674 |
+
if args.model == 'ff' and len(train_accuracies_standard)>0 and np.any(np.array(train_accuracies_standard)>0.4): train_ylim_set=True; axacc_train.set_ylim([0.4, 1])
|
| 675 |
+
|
| 676 |
+
test_ylim_set = False
|
| 677 |
+
if args.model in ['ctm', 'lstm'] and len(test_accuracies_most_certain)>0 and np.any(np.array(test_accuracies_most_certain)>0.3): test_ylim_set=True; axacc_test.set_ylim([0.3, 0.8])
|
| 678 |
+
if args.model == 'ff' and len(test_accuracies_standard)>0 and np.any(np.array(test_accuracies_standard)>0.3): test_ylim_set=True; axacc_test.set_ylim([0.3, 0.8])
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
figacc.tight_layout()
|
| 682 |
+
figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
|
| 683 |
+
plt.close(figacc)
|
| 684 |
+
|
| 685 |
+
# Loss Plot
|
| 686 |
+
figloss = plt.figure(figsize=(10, 5))
|
| 687 |
+
axloss = figloss.add_subplot(111)
|
| 688 |
+
axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train (Aggregated): {train_losses[-1]:.4f}')
|
| 689 |
+
axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test (Aggregated): {test_losses[-1]:.4f}')
|
| 690 |
+
axloss.legend(loc='upper right')
|
| 691 |
+
axloss.set_xlabel("Iteration")
|
| 692 |
+
axloss.set_ylabel("Loss")
|
| 693 |
+
axloss.set_xlim([0, args.training_iterations])
|
| 694 |
+
axloss.set_ylim(bottom=0)
|
| 695 |
+
figloss.tight_layout()
|
| 696 |
+
figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
|
| 697 |
+
plt.close(figloss)
|
| 698 |
+
# --- End Plotting ---
|
| 699 |
+
|
| 700 |
+
# Visualization on Rank 0
|
| 701 |
+
if is_main_process(rank) and args.model in ['ctm', 'lstm']:
|
| 702 |
+
try:
|
| 703 |
+
model_module = model.module if isinstance(model, DDP) else model # Get underlying model
|
| 704 |
+
# Simplified viz: use first batch from testloader
|
| 705 |
+
inputs_viz, targets_viz = next(iter(testloader))
|
| 706 |
+
inputs_viz = inputs_viz.to(device)
|
| 707 |
+
targets_viz = targets_viz.to(device)
|
| 708 |
+
|
| 709 |
+
pbar.set_description('Tracking (Rank 0): Viz Fwd Pass')
|
| 710 |
+
predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model_module(inputs_viz, track=True)
|
| 711 |
+
|
| 712 |
+
att_shape = (model_module.kv_features.shape[2], model_module.kv_features.shape[3])
|
| 713 |
+
attention_tracking_viz = attention_tracking_viz.reshape(
|
| 714 |
+
attention_tracking_viz.shape[0],
|
| 715 |
+
attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
pbar.set_description('Tracking (Rank 0): Dynamics Plot')
|
| 719 |
+
plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
|
| 720 |
+
|
| 721 |
+
# Plot specific indices from test_data directly
|
| 722 |
+
pbar.set_description('Tracking (Rank 0): GIF Generation')
|
| 723 |
+
for plot_idx in args.plot_indices:
|
| 724 |
+
try:
|
| 725 |
+
if plot_idx < len(test_data):
|
| 726 |
+
inputs_plot, target_plot = test_data.__getitem__(plot_idx)
|
| 727 |
+
inputs_plot = inputs_plot.unsqueeze(0).to(device)
|
| 728 |
+
|
| 729 |
+
preds_plot, certs_plot, _, _, posts_plot, atts_plot = model_module(inputs_plot, track=True)
|
| 730 |
+
atts_plot = atts_plot.reshape(atts_plot.shape[0], atts_plot.shape[1], -1, att_shape[0], att_shape[1])
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
img_gif = np.moveaxis(np.clip(inputs_plot[0].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)
|
| 734 |
+
|
| 735 |
+
make_classification_gif(img_gif, target_plot, preds_plot[0].detach().cpu().numpy(), certs_plot[0].detach().cpu().numpy(),
|
| 736 |
+
posts_plot[:,0], atts_plot[:,0] if atts_plot is not None else None, class_labels,
|
| 737 |
+
f'{args.log_dir}/idx{plot_idx}_attention.gif')
|
| 738 |
+
else:
|
| 739 |
+
print(f"Warning: Plot index {plot_idx} out of range for test dataset size {len(test_data)}.")
|
| 740 |
+
except Exception as e_gif:
|
| 741 |
+
print(f"Rank 0 GIF generation failed for index {plot_idx}: {e_gif}")
|
| 742 |
+
|
| 743 |
+
except Exception as e_viz:
|
| 744 |
+
print(f"Rank 0 visualization failed: {e_viz}")
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
if world_size > 1: dist.barrier() # Sync after evaluation block
|
| 749 |
+
model.train() # Set back to train mode
|
| 750 |
+
# --- End Evaluation Block ---
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
# --- Checkpointing (Rank 0) ---
|
| 754 |
+
if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter and is_main_process(rank):
|
| 755 |
+
pbar.set_description('Rank 0: Saving checkpoint...')
|
| 756 |
+
save_path = f'{args.log_dir}/checkpoint.pt'
|
| 757 |
+
# Access underlying model state dict if DDP is used
|
| 758 |
+
model_state_to_save = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
|
| 759 |
+
|
| 760 |
+
save_dict = {
|
| 761 |
+
'model_state_dict': model_state_to_save,
|
| 762 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 763 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 764 |
+
'scaler_state_dict':scaler.state_dict(),
|
| 765 |
+
'iteration': bi,
|
| 766 |
+
'train_losses': train_losses,
|
| 767 |
+
'test_losses': test_losses,
|
| 768 |
+
'iters': iters,
|
| 769 |
+
'args': args,
|
| 770 |
+
'torch_rng_state': torch.get_rng_state(), # CPU state
|
| 771 |
+
'numpy_rng_state': np.random.get_state(),
|
| 772 |
+
'random_rng_state': random.getstate(),
|
| 773 |
+
# Include conditional metrics
|
| 774 |
+
'train_accuracies': train_accuracies, # Placeholder
|
| 775 |
+
'test_accuracies': test_accuracies, # Placeholder
|
| 776 |
+
}
|
| 777 |
+
if args.model in ['ctm', 'lstm']:
|
| 778 |
+
save_dict['train_accuracies_most_certain'] = train_accuracies_most_certain
|
| 779 |
+
save_dict['test_accuracies_most_certain'] = test_accuracies_most_certain
|
| 780 |
+
elif args.model == 'ff':
|
| 781 |
+
save_dict['train_accuracies_standard'] = train_accuracies_standard
|
| 782 |
+
save_dict['test_accuracies_standard'] = test_accuracies_standard
|
| 783 |
+
|
| 784 |
+
torch.save(save_dict , save_path)
|
| 785 |
+
pbar.set_description(f"Rank 0: Checkpoint saved to {save_path}")
|
| 786 |
+
# --- End Checkpointing ---
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
if world_size > 1: dist.barrier() # Sync before next iteration
|
| 790 |
+
|
| 791 |
+
# Update pbar on Rank 0
|
| 792 |
+
if is_main_process(rank):
|
| 793 |
+
pbar.update(1)
|
| 794 |
+
# --- End Training Loop ---
|
| 795 |
+
|
| 796 |
+
if is_main_process(rank):
|
| 797 |
+
pbar.close()
|
| 798 |
+
|
| 799 |
+
cleanup_ddp() # Cleanup DDP resources
|
tasks/mazes/README.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mazes
|
| 2 |
+
|
| 3 |
+
This folder contains code for training and analysing 2D maze solving experiments
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
## Training
|
| 7 |
+
To run the maze training that we used for the paper, run the following command from the parent directory:
|
| 8 |
+
```
|
| 9 |
+
python -m tasks.mazes.train --d_model 2048 --d_input 512 --synapse_depth 4 --heads 8 --n_synch_out 64 --n_synch_action 32 --neuron_select_type first-last --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 32 --dropout 0.1 --no-do_normalisation --positional_embedding_type none --backbone_type resnet34-2 --batch_size 64 --batch_size_test 64 --lr 1e-4 --training_iterations 1000001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs/mazes/d=2048--i=512--h=8--ns=64-32--iters=75x25--h=32--drop=0.1--pos=none--back=34-2--seed=42 --dataset mazes-medium --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
## Small training run
|
| 13 |
+
We also provide a 'mazes-small' dataset (see [here](https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk)) for fast iteration and testing ideas. The following command can train a CTM locally without a GPU in 12-24 hours:
|
| 14 |
+
```
|
| 15 |
+
python -m tasks.mazes.train --dataset mazes-small --maze_route_length 50 --cirriculum_lookahead 5 --model ctm --d_model 1024 --d_input 256 --backbone_type resnet18-1 --synapse_depth 8 --heads 4 --n_synch_out 128 --n_synch_action 128 --neuron_select_type random-pairing --memory_length 25 --iterations 50 --training_iterations 100001 --lr 1e-4 --batch_size 64 --batch_size_test 32 --n_test_batches 50 --log_dir logs/mazes-small-tester --track_every 2000
|
| 16 |
+
```
|
tasks/mazes/analysis/README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Analysis
|
| 2 |
+
|
| 3 |
+
This folder contains analysis code for 2D maze experiments. To build GIFs for imagenet run (from the base directory):
|
| 4 |
+
|
| 5 |
+
To run maze analysis run the following command from the parent directory:
|
| 6 |
+
```
|
| 7 |
+
python -m tasks.mazes.analysis.run --actions viz viz --checkpoint checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
You will need to download the checkpoint from here: https://drive.google.com/file/d/1vGiMaQCxzKVT68SipxDCW0W5n5jjEQnC/view?usp=drive_link . Extract this to the appropriate directory: `checkpoints/mazes/...` . Otherwise, use your own after training.
|
tasks/mazes/analysis/run.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
np.seterr(divide='ignore', invalid='warn') # Keep specific numpy error settings
|
| 4 |
+
import matplotlib as mpl
|
| 5 |
+
mpl.use('Agg') # Use Agg backend for matplotlib (important to set before importing pyplot)
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
sns.set_style('darkgrid') # Keep seaborn style
|
| 9 |
+
import os
|
| 10 |
+
import argparse
|
| 11 |
+
import cv2
|
| 12 |
+
import imageio # Used for saving GIFs in viz
|
| 13 |
+
|
| 14 |
+
# Local imports
|
| 15 |
+
from data.custom_datasets import MazeImageFolder
|
| 16 |
+
from models.ctm import ContinuousThoughtMachine
|
| 17 |
+
from tasks.mazes.plotting import draw_path #
|
| 18 |
+
from tasks.image_classification.plotting import save_frames_to_mp4
|
| 19 |
+
|
| 20 |
+
def has_solved_checker(x_maze, route, valid_only=True, fault_tolerance=1, exclusions=[]):
|
| 21 |
+
"""Checks if a route solves a maze."""
|
| 22 |
+
maze = np.copy(x_maze)
|
| 23 |
+
H, W, _ = maze.shape
|
| 24 |
+
start_coords = np.argwhere((maze == [1, 0, 0]).all(axis=2))
|
| 25 |
+
end_coords = np.argwhere((maze == [0, 1, 0]).all(axis=2))
|
| 26 |
+
|
| 27 |
+
if len(start_coords) == 0:
|
| 28 |
+
return False, (-1, -1), 0 # Cannot start
|
| 29 |
+
|
| 30 |
+
current_pos = tuple(start_coords[0])
|
| 31 |
+
target_pos = tuple(end_coords[0]) if len(end_coords) > 0 else None
|
| 32 |
+
|
| 33 |
+
mistakes_made = 0
|
| 34 |
+
final_pos = current_pos
|
| 35 |
+
path_taken_len = 0
|
| 36 |
+
|
| 37 |
+
for step in route:
|
| 38 |
+
if mistakes_made > fault_tolerance:
|
| 39 |
+
break
|
| 40 |
+
|
| 41 |
+
next_pos_candidate = list(current_pos) # Use a list for mutable coordinate calculation
|
| 42 |
+
if step == 0: next_pos_candidate[0] -= 1
|
| 43 |
+
elif step == 1: next_pos_candidate[0] += 1
|
| 44 |
+
elif step == 2: next_pos_candidate[1] -= 1
|
| 45 |
+
elif step == 3: next_pos_candidate[1] += 1
|
| 46 |
+
elif step == 4: pass # Stay in place
|
| 47 |
+
else: continue # Invalid step action
|
| 48 |
+
next_pos = tuple(next_pos_candidate)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
is_invalid_step = False
|
| 52 |
+
# Check bounds first, then maze content if in bounds
|
| 53 |
+
if not (0 <= next_pos[0] < H and 0 <= next_pos[1] < W):
|
| 54 |
+
is_invalid_step = True
|
| 55 |
+
elif np.all(maze[next_pos] == [0, 0, 0]): # Wall
|
| 56 |
+
is_invalid_step = True
|
| 57 |
+
|
| 58 |
+
if is_invalid_step:
|
| 59 |
+
mistakes_made += 1
|
| 60 |
+
if valid_only:
|
| 61 |
+
continue
|
| 62 |
+
|
| 63 |
+
current_pos = next_pos
|
| 64 |
+
path_taken_len += 1
|
| 65 |
+
|
| 66 |
+
if target_pos and current_pos == target_pos:
|
| 67 |
+
if mistakes_made <= fault_tolerance:
|
| 68 |
+
return True, current_pos, path_taken_len
|
| 69 |
+
|
| 70 |
+
if mistakes_made <= fault_tolerance:
|
| 71 |
+
# Assuming exclusions is a list of tuples (as populated in the 'gen' action)
|
| 72 |
+
if current_pos not in exclusions:
|
| 73 |
+
final_pos = current_pos
|
| 74 |
+
|
| 75 |
+
if target_pos and final_pos == target_pos and mistakes_made <= fault_tolerance: # Added mistakes_made check here
|
| 76 |
+
return True, final_pos, path_taken_len
|
| 77 |
+
return False, final_pos, path_taken_len
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def parse_args():
|
| 81 |
+
"""Parses command-line arguments for maze analysis."""
|
| 82 |
+
parser = argparse.ArgumentParser(description="Analyze Asynchronous Thought Machine on Maze Tasks")
|
| 83 |
+
parser.add_argument('--actions', type=str, nargs='+', default=['gen'], help="Actions: 'viz', 'gen'")
|
| 84 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help="GPU device index or -1 for CPU")
|
| 85 |
+
parser.add_argument('--checkpoint', type=str, default='checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt', help="Path to CTM checkpoint")
|
| 86 |
+
parser.add_argument('--output_dir', type=str, default='tasks/mazes/analysis/outputs', help="Directory for analysis outputs")
|
| 87 |
+
parser.add_argument('--dataset_for_viz', type=str, default='large', help="Dataset for 'viz' action")
|
| 88 |
+
parser.add_argument('--dataset_for_gen', type=str, default='extralarge', help="Dataset for 'gen' action")
|
| 89 |
+
parser.add_argument('--batch_size_test', type=int, default=32, help="Batch size for loading test data for 'viz'")
|
| 90 |
+
parser.add_argument('--max_reapplications', type=int, default=20, help="When testing generalisation to extra large mazes")
|
| 91 |
+
parser.add_argument('--legacy_scaling', action=argparse.BooleanOptionalAction, default=True, help='Legacy checkpoints scale between 0 and 1, new ones can scale -1 to 1.')
|
| 92 |
+
return parser.parse_args()
|
| 93 |
+
|
| 94 |
+
def _load_ctm_model(checkpoint_path, device):
|
| 95 |
+
"""Loads the ContinuousThoughtMachine model from a checkpoint."""
|
| 96 |
+
print(f"Loading checkpoint: {checkpoint_path}")
|
| 97 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 98 |
+
model_args = checkpoint['args']
|
| 99 |
+
|
| 100 |
+
# Handle legacy arguments for model_args
|
| 101 |
+
if not hasattr(model_args, 'backbone_type') and hasattr(model_args, 'resnet_type'):
|
| 102 |
+
model_args.backbone_type = f'{model_args.resnet_type}-{getattr(model_args, "resnet_feature_scales", [4])[-1]}'
|
| 103 |
+
|
| 104 |
+
# Ensure prediction_reshaper is derived correctly
|
| 105 |
+
# Assuming out_dims exists and is used for this
|
| 106 |
+
prediction_reshaper = [model_args.out_dims // 5, 5] if hasattr(model_args, 'out_dims') else None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if not hasattr(model_args, 'neuron_select_type'):
|
| 110 |
+
model_args.neuron_select_type = 'first-last'
|
| 111 |
+
if not hasattr(model_args, 'n_random_pairing_self'):
|
| 112 |
+
model_args.n_random_pairing_self = 0
|
| 113 |
+
|
| 114 |
+
print("Instantiating CTM model...")
|
| 115 |
+
model = ContinuousThoughtMachine(
|
| 116 |
+
iterations=model_args.iterations,
|
| 117 |
+
d_model=model_args.d_model,
|
| 118 |
+
d_input=model_args.d_input,
|
| 119 |
+
heads=model_args.heads,
|
| 120 |
+
n_synch_out=model_args.n_synch_out,
|
| 121 |
+
n_synch_action=model_args.n_synch_action,
|
| 122 |
+
synapse_depth=model_args.synapse_depth,
|
| 123 |
+
memory_length=model_args.memory_length,
|
| 124 |
+
deep_nlms=model_args.deep_memory, # Mapping from model_args.deep_memory
|
| 125 |
+
memory_hidden_dims=model_args.memory_hidden_dims,
|
| 126 |
+
do_layernorm_nlm=model_args.do_normalisation, # Mapping from model_args.do_normalisation
|
| 127 |
+
backbone_type=model_args.backbone_type,
|
| 128 |
+
positional_embedding_type=model_args.positional_embedding_type,
|
| 129 |
+
out_dims=model_args.out_dims,
|
| 130 |
+
prediction_reshaper=prediction_reshaper,
|
| 131 |
+
dropout=0, # Explicitly setting dropout to 0 as in original
|
| 132 |
+
neuron_select_type=model_args.neuron_select_type,
|
| 133 |
+
n_random_pairing_self=model_args.n_random_pairing_self,
|
| 134 |
+
).to(device)
|
| 135 |
+
|
| 136 |
+
load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
|
| 137 |
+
print(f"Loaded state_dict. Missing keys: {load_result.missing_keys}, Unexpected keys: {load_result.unexpected_keys}")
|
| 138 |
+
model.eval()
|
| 139 |
+
return model
|
| 140 |
+
|
| 141 |
+
# --- Main Execution Block ---
|
| 142 |
+
if __name__=='__main__':
|
| 143 |
+
args = parse_args()
|
| 144 |
+
|
| 145 |
+
if args.device[0] != -1 and torch.cuda.is_available():
|
| 146 |
+
device = f'cuda:{args.device[0]}'
|
| 147 |
+
else:
|
| 148 |
+
device = 'cpu'
|
| 149 |
+
print(f"Using device: {device}")
|
| 150 |
+
|
| 151 |
+
palette = sns.color_palette("husl", 8)
|
| 152 |
+
cmap = plt.get_cmap('gist_rainbow')
|
| 153 |
+
|
| 154 |
+
# --- Generalisation Action ('gen') ---
|
| 155 |
+
if 'gen' in args.actions:
|
| 156 |
+
model = _load_ctm_model(args.checkpoint, device)
|
| 157 |
+
|
| 158 |
+
print(f"\n--- Running Generalisation Analysis ('gen'): {args.dataset_for_gen} ---")
|
| 159 |
+
target_dataset_name = f'{args.dataset_for_gen}'
|
| 160 |
+
data_root = f'data/mazes/{target_dataset_name}/test'
|
| 161 |
+
max_target_route_len = 50 # Specific to 'gen' action
|
| 162 |
+
|
| 163 |
+
test_data = MazeImageFolder(
|
| 164 |
+
root=data_root, which_set='test',
|
| 165 |
+
maze_route_length=max_target_route_len,
|
| 166 |
+
expand_range=not args.legacy_scaling, # Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
|
| 167 |
+
trunc=True
|
| 168 |
+
)
|
| 169 |
+
# Load a single large batch for 'gen'
|
| 170 |
+
testloader = torch.utils.data.DataLoader(
|
| 171 |
+
test_data, batch_size=min(len(test_data), 2000),
|
| 172 |
+
shuffle=False, num_workers=1
|
| 173 |
+
)
|
| 174 |
+
inputs, targets = next(iter(testloader))
|
| 175 |
+
|
| 176 |
+
actual_lengths = (targets != 4).sum(dim=-1)
|
| 177 |
+
sorted_indices = torch.argsort(actual_lengths, descending=True)
|
| 178 |
+
inputs, targets, actual_lengths = inputs[sorted_indices], targets[sorted_indices], actual_lengths[sorted_indices]
|
| 179 |
+
|
| 180 |
+
test_how_many = min(1000, len(inputs))
|
| 181 |
+
print(f"Processing {test_how_many} mazes sorted by length...")
|
| 182 |
+
|
| 183 |
+
results = {}
|
| 184 |
+
fault_tolerance = 2 # Specific to 'gen' analysis
|
| 185 |
+
output_gen_dir = os.path.join(args.output_dir, 'gen', args.dataset_for_gen)
|
| 186 |
+
os.makedirs(output_gen_dir, exist_ok=True)
|
| 187 |
+
|
| 188 |
+
for n_tested in range(test_how_many):
|
| 189 |
+
maze_actual_length = actual_lengths[n_tested].item()
|
| 190 |
+
maze_idx_display = n_tested + 1
|
| 191 |
+
print(f"Testing maze {maze_idx_display}/{test_how_many} (Len: {maze_actual_length})...")
|
| 192 |
+
|
| 193 |
+
initial_input_maze = inputs[n_tested:n_tested+1].clone().to(device)
|
| 194 |
+
maze_output_dir = os.path.join(output_gen_dir, f"maze_{maze_idx_display}")
|
| 195 |
+
|
| 196 |
+
re_applications = 0
|
| 197 |
+
has_solved = False
|
| 198 |
+
current_input_maze = initial_input_maze
|
| 199 |
+
exclusions = []
|
| 200 |
+
long_frames = []
|
| 201 |
+
ongoing_solution_img = None
|
| 202 |
+
|
| 203 |
+
while not has_solved and re_applications < args.max_reapplications:
|
| 204 |
+
re_applications += 1
|
| 205 |
+
with torch.no_grad():
|
| 206 |
+
predictions, certainties, _, _, _, attention_tracking = model(current_input_maze, track=True)
|
| 207 |
+
|
| 208 |
+
h_feat, w_feat = model.kv_features.shape[-2:]
|
| 209 |
+
attention_tracking = attention_tracking.reshape(attention_tracking.shape[0], -1, h_feat, w_feat)
|
| 210 |
+
|
| 211 |
+
n_steps_viz = predictions.shape[-1] # Use a different name to avoid conflict if n_steps is used elsewhere
|
| 212 |
+
step_linspace = np.linspace(0, 1, n_steps_viz)
|
| 213 |
+
current_maze_np = current_input_maze[0].permute(1,2,0).detach().cpu().numpy()
|
| 214 |
+
|
| 215 |
+
for stepi in range(n_steps_viz):
|
| 216 |
+
pred_route = predictions[0, :, stepi].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
|
| 217 |
+
frame = draw_path(current_maze_np, pred_route)
|
| 218 |
+
if attention_tracking is not None and stepi < attention_tracking.shape[0]:
|
| 219 |
+
try:
|
| 220 |
+
attn = attention_tracking[stepi].mean(0)
|
| 221 |
+
attn_resized = cv2.resize(attn, (current_maze_np.shape[1], current_maze_np.shape[0]), interpolation=cv2.INTER_LINEAR)
|
| 222 |
+
if attn_resized.max() > attn_resized.min():
|
| 223 |
+
attn_norm = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min())
|
| 224 |
+
attn_norm[attn_norm < np.percentile(attn_norm, 80)] = 0.0
|
| 225 |
+
frame = np.clip((np.copy(frame)*(1-attn_norm[:,:,np.newaxis])*1 + (attn_norm[:,:,np.newaxis]*0.8 * np.reshape(np.array(cmap(step_linspace[stepi]))[:3], (1, 1, 3)))), 0, 1)
|
| 226 |
+
except Exception: # Keep broad except for visualization robustness
|
| 227 |
+
pass
|
| 228 |
+
frame_resized = cv2.resize(frame, (int(current_maze_np.shape[1]*4), int(current_maze_np.shape[0]*4)), interpolation=cv2.INTER_NEAREST) # Corrected shape[1]*4 for height
|
| 229 |
+
long_frames.append((np.clip(frame_resized, 0, 1) * 255).astype(np.uint8))
|
| 230 |
+
|
| 231 |
+
where_most_certain = certainties[0, 1].argmax().item()
|
| 232 |
+
chosen_pred_route = predictions[0, :, where_most_certain].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
|
| 233 |
+
current_start_loc_list = np.argwhere((current_maze_np == [1, 0, 0]).all(axis=2)).tolist()
|
| 234 |
+
|
| 235 |
+
# Ensure current_start_loc_list is not empty before trying to access its elements
|
| 236 |
+
if not current_start_loc_list:
|
| 237 |
+
print(f"Warning: Could not find start location in maze {maze_idx_display} during reapplication {re_applications}. Stopping reapplication.")
|
| 238 |
+
break # Cannot proceed without a start location
|
| 239 |
+
|
| 240 |
+
solved_now, final_pos, _ = has_solved_checker(current_maze_np, chosen_pred_route, True, fault_tolerance, exclusions)
|
| 241 |
+
|
| 242 |
+
path_img = draw_path(current_maze_np, chosen_pred_route, cmap=cmap, valid_only=True)
|
| 243 |
+
if ongoing_solution_img is None:
|
| 244 |
+
ongoing_solution_img = path_img
|
| 245 |
+
else:
|
| 246 |
+
mask = (np.any(ongoing_solution_img!=path_img, -1))&(~np.all(path_img==[1,1,1], -1))&(~np.all(ongoing_solution_img==[1,0,0], -1))
|
| 247 |
+
ongoing_solution_img[mask] = path_img[mask]
|
| 248 |
+
|
| 249 |
+
if solved_now:
|
| 250 |
+
has_solved = True
|
| 251 |
+
break
|
| 252 |
+
|
| 253 |
+
if tuple(current_start_loc_list[0]) == final_pos:
|
| 254 |
+
exclusions.append(tuple(current_start_loc_list[0]))
|
| 255 |
+
|
| 256 |
+
next_input = current_input_maze.clone()
|
| 257 |
+
old_start_idx = tuple(current_start_loc_list[0])
|
| 258 |
+
next_input[0, :, old_start_idx[0], old_start_idx[1]] = 1.0 # Reset old start to path
|
| 259 |
+
|
| 260 |
+
if 0 <= final_pos[0] < next_input.shape[2] and 0 <= final_pos[1] < next_input.shape[3]:
|
| 261 |
+
next_input[0, :, final_pos[0], final_pos[1]] = torch.tensor([1,0,0], device=device, dtype=next_input.dtype) # New start
|
| 262 |
+
else:
|
| 263 |
+
print(f"Warning: final_pos {final_pos} out of bounds for maze {maze_idx_display}. Stopping reapplication.")
|
| 264 |
+
break
|
| 265 |
+
current_input_maze = next_input
|
| 266 |
+
|
| 267 |
+
if has_solved:
|
| 268 |
+
print(f'Solved maze of length {maze_actual_length}! Saving...')
|
| 269 |
+
os.makedirs(maze_output_dir, exist_ok=True)
|
| 270 |
+
if ongoing_solution_img is not None:
|
| 271 |
+
cv2.imwrite(os.path.join(maze_output_dir, 'ongoing_solution.png'), (ongoing_solution_img * 255).astype(np.uint8)[:,:,::-1])
|
| 272 |
+
if long_frames:
|
| 273 |
+
save_frames_to_mp4([fm[:,:,::-1] for fm in long_frames], os.path.join(maze_output_dir, f'combined_process.mp4'), fps=45, gop_size=10, preset='veryslow', crf=20)
|
| 274 |
+
else:
|
| 275 |
+
print(f'Failed maze of length {maze_actual_length} after {re_applications} reapplications. Not saving visuals for this maze.')
|
| 276 |
+
|
| 277 |
+
if maze_actual_length not in results: results[maze_actual_length] = []
|
| 278 |
+
results[maze_actual_length].append((has_solved, re_applications))
|
| 279 |
+
|
| 280 |
+
fig_success, ax_success = plt.subplots()
|
| 281 |
+
fig_reapp, ax_reapp = plt.subplots()
|
| 282 |
+
sorted_lengths = sorted(results.keys())
|
| 283 |
+
if sorted_lengths:
|
| 284 |
+
success_rates = [np.mean([r[0] for r in results[l]]) * 100 for l in sorted_lengths]
|
| 285 |
+
reapps_mean = [np.mean([r[1] for r in results[l] if r[0]]) if any(r[0] for r in results[l]) else np.nan for l in sorted_lengths]
|
| 286 |
+
ax_success.plot(sorted_lengths, success_rates, linestyle='-', color=palette[0])
|
| 287 |
+
ax_reapp.plot(sorted_lengths, reapps_mean, linestyle='-', color=palette[5])
|
| 288 |
+
ax_success.set_xlabel('Route Length'); ax_success.set_ylabel('Success (%)')
|
| 289 |
+
ax_reapp.set_xlabel('Route Length'); ax_reapp.set_ylabel('Re-applications (Avg on Success)')
|
| 290 |
+
fig_success.tight_layout(pad=0.1); fig_reapp.tight_layout(pad=0.1)
|
| 291 |
+
fig_success.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-success_rate.png'), dpi=200)
|
| 292 |
+
fig_success.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-success_rate.pdf'), dpi=200)
|
| 293 |
+
fig_reapp.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-re-applications.png'), dpi=200)
|
| 294 |
+
fig_reapp.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-re-applications.pdf'), dpi=200)
|
| 295 |
+
plt.close(fig_success); plt.close(fig_reapp)
|
| 296 |
+
np.savez(os.path.join(output_gen_dir, f'{args.dataset_for_gen}_results.npz'), results=results)
|
| 297 |
+
|
| 298 |
+
print("\n--- Generalisation Analysis ('gen') Complete ---")
|
| 299 |
+
|
| 300 |
+
# --- Visualization Action ('viz') ---
|
| 301 |
+
if 'viz' in args.actions:
|
| 302 |
+
model = _load_ctm_model(args.checkpoint, device)
|
| 303 |
+
|
| 304 |
+
print(f"\n--- Running Visualization ('viz'): {args.dataset_for_viz} ---")
|
| 305 |
+
output_viz_dir = os.path.join(args.output_dir, 'viz')
|
| 306 |
+
os.makedirs(output_viz_dir, exist_ok=True)
|
| 307 |
+
|
| 308 |
+
target_dataset_name = f'{args.dataset_for_viz}'
|
| 309 |
+
data_root = f'data/mazes/{target_dataset_name}/test'
|
| 310 |
+
test_data = MazeImageFolder(
|
| 311 |
+
root=data_root, which_set='test',
|
| 312 |
+
maze_route_length=100, # Max route length for viz data
|
| 313 |
+
expand_range=not args.legacy_scaling, # # Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
|
| 314 |
+
trunc=True
|
| 315 |
+
)
|
| 316 |
+
testloader = torch.utils.data.DataLoader(
|
| 317 |
+
test_data, batch_size=args.batch_size_test,
|
| 318 |
+
shuffle=False, num_workers=1
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
all_inputs, all_targets, all_lengths = [], [], []
|
| 322 |
+
for b_in, b_tgt in testloader:
|
| 323 |
+
all_inputs.append(b_in)
|
| 324 |
+
all_targets.append(b_tgt)
|
| 325 |
+
all_lengths.append((b_tgt != 4).sum(dim=-1))
|
| 326 |
+
|
| 327 |
+
if not all_inputs:
|
| 328 |
+
print("Error: No data in visualization loader. Exiting 'viz' action.")
|
| 329 |
+
exit()
|
| 330 |
+
|
| 331 |
+
all_inputs, all_targets, all_lengths = torch.cat(all_inputs), torch.cat(all_targets), torch.cat(all_lengths)
|
| 332 |
+
|
| 333 |
+
num_viz_mazes = 10
|
| 334 |
+
num_viz_mazes = min(num_viz_mazes, len(all_lengths))
|
| 335 |
+
|
| 336 |
+
if num_viz_mazes == 0:
|
| 337 |
+
print("Error: No mazes found to visualize. Exiting 'viz' action.")
|
| 338 |
+
exit()
|
| 339 |
+
|
| 340 |
+
top_indices = torch.argsort(all_lengths, descending=True)[:num_viz_mazes]
|
| 341 |
+
inputs_viz, targets_viz = all_inputs[top_indices].to(device), all_targets[top_indices]
|
| 342 |
+
|
| 343 |
+
print(f"Visualizing {len(inputs_viz)} longest mazes...")
|
| 344 |
+
|
| 345 |
+
with torch.no_grad():
|
| 346 |
+
predictions, _, _, _, _, attention_tracking = model(inputs_viz, track=True)
|
| 347 |
+
|
| 348 |
+
# Reshape attention: (Steps, Batch, Heads, H_feat, W_feat) assuming model.kv_features has H_feat, W_feat
|
| 349 |
+
# The original reshape was slightly different, this tries to match the likely intended dimensions for per-step, per-batch item attention
|
| 350 |
+
if attention_tracking is not None and hasattr(model, 'kv_features') and model.kv_features is not None:
|
| 351 |
+
attention_tracking = attention_tracking.reshape(
|
| 352 |
+
attention_tracking.shape[0], # Iterations/Steps
|
| 353 |
+
inputs_viz.size(0), # Batch size (num_viz_mazes)
|
| 354 |
+
-1, # Heads (inferred)
|
| 355 |
+
model.kv_features.shape[-2], # H_feat
|
| 356 |
+
model.kv_features.shape[-1] # W_feat
|
| 357 |
+
)
|
| 358 |
+
else:
|
| 359 |
+
attention_tracking = None # Ensure it's None if it can't be reshaped
|
| 360 |
+
print("Warning: Could not reshape attention_tracking. Visualizations may not include attention overlays.")
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
for maze_i in range(inputs_viz.size(0)):
|
| 364 |
+
maze_idx_display = maze_i + 1
|
| 365 |
+
maze_output_dir = os.path.join(output_viz_dir, f"maze_{maze_idx_display}")
|
| 366 |
+
os.makedirs(maze_output_dir, exist_ok=True)
|
| 367 |
+
|
| 368 |
+
current_input_np_original = inputs_viz[maze_i].permute(1,2,0).detach().cpu().numpy()
|
| 369 |
+
# Apply scaling for visualization based on legacy_scaling: Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
|
| 370 |
+
current_input_np_display = (current_input_np_original + 1) / 2 if not args.legacy_scaling else current_input_np_original
|
| 371 |
+
|
| 372 |
+
current_target_route = targets_viz[maze_i].detach().cpu().numpy()
|
| 373 |
+
print(f"Generating viz for maze {maze_idx_display}...")
|
| 374 |
+
|
| 375 |
+
try:
|
| 376 |
+
solution_maze_img = draw_path(current_input_np_display, current_target_route, gt=True)
|
| 377 |
+
cv2.imwrite(os.path.join(maze_output_dir, 'solution_ground_truth.png'), (solution_maze_img * 255).astype(np.uint8)[:,:,::-1])
|
| 378 |
+
except Exception: # Keep broad except for visualization robustness
|
| 379 |
+
print(f"Could not save ground truth solution for maze {maze_idx_display}")
|
| 380 |
+
pass
|
| 381 |
+
|
| 382 |
+
frames = []
|
| 383 |
+
n_steps_viz = predictions.shape[-1] # Use a different name
|
| 384 |
+
step_linspace = np.linspace(0, 1, n_steps_viz)
|
| 385 |
+
|
| 386 |
+
for stepi in range(n_steps_viz):
|
| 387 |
+
pred_route = predictions[maze_i, :, stepi].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
|
| 388 |
+
frame = draw_path(current_input_np_display, pred_route)
|
| 389 |
+
|
| 390 |
+
if attention_tracking is not None and stepi < attention_tracking.shape[0] and maze_i < attention_tracking.shape[1]:
|
| 391 |
+
|
| 392 |
+
# Attention for current step (stepi) and current maze in batch (maze_i), average over heads
|
| 393 |
+
attn = attention_tracking[stepi, maze_i].mean(0)
|
| 394 |
+
attn_resized = cv2.resize(attn, (current_input_np_display.shape[1], current_input_np_display.shape[0]), interpolation=cv2.INTER_LINEAR)
|
| 395 |
+
if attn_resized.max() > attn_resized.min():
|
| 396 |
+
attn_norm = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min())
|
| 397 |
+
attn_norm[attn_norm < np.percentile(attn_norm, 80)] = 0.0
|
| 398 |
+
frame = np.clip((np.copy(frame)*(1-attn_norm[:,:,np.newaxis])*0.9 + (attn_norm[:,:,np.newaxis]*1.2 * np.reshape(np.array(cmap(step_linspace[stepi]))[:3], (1, 1, 3)))), 0, 1)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
frame_resized = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_NEAREST)
|
| 402 |
+
frames.append((np.clip(frame_resized, 0, 1) * 255).astype(np.uint8))
|
| 403 |
+
|
| 404 |
+
if frames:
|
| 405 |
+
imageio.mimsave(os.path.join(maze_output_dir, 'attention_overlay.gif'), frames, fps=15, loop=0)
|
| 406 |
+
|
| 407 |
+
print("\n--- Visualization Action ('viz') Complete ---")
|
tasks/mazes/plotting.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import imageio
|
| 8 |
+
|
| 9 |
+
from tqdm.auto import tqdm
|
| 10 |
+
|
| 11 |
+
def find_center_of_mass(array_2d):
|
| 12 |
+
"""
|
| 13 |
+
Alternative implementation using np.average and meshgrid.
|
| 14 |
+
This version is generally faster and more concise.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
array_2d: A 2D numpy array of values between 0 and 1.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
A tuple (x, y) representing the coordinates of the center of mass.
|
| 21 |
+
"""
|
| 22 |
+
total_mass = np.sum(array_2d)
|
| 23 |
+
if total_mass == 0:
|
| 24 |
+
return (np.nan, np.nan)
|
| 25 |
+
|
| 26 |
+
y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
|
| 27 |
+
x_center = np.average(x_coords, weights=array_2d)
|
| 28 |
+
y_center = np.average(y_coords, weights=array_2d)
|
| 29 |
+
return (round(y_center, 4), round(x_center, 4))
|
| 30 |
+
|
| 31 |
+
def draw_path(x, route, valid_only=False, gt=False, cmap=None):
|
| 32 |
+
"""
|
| 33 |
+
Draws a path on a maze image based on a given route.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
maze: A numpy array representing the maze image.
|
| 37 |
+
route: A list of integers representing the route, where 0 is up, 1 is down, 2 is left, and 3 is right.
|
| 38 |
+
valid_only: A boolean indicating whether to only draw valid steps (i.e., steps that don't go into walls).
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
A numpy array representing the maze image with the path drawn in blue.
|
| 42 |
+
"""
|
| 43 |
+
x = np.copy(x)
|
| 44 |
+
start = np.argwhere((x == [1, 0, 0]).all(axis=2))
|
| 45 |
+
end = np.argwhere((x == [0, 1, 0]).all(axis=2))
|
| 46 |
+
if cmap is None:
|
| 47 |
+
cmap = plt.get_cmap('winter') if not valid_only else plt.get_cmap('summer')
|
| 48 |
+
|
| 49 |
+
# Initialize the current position
|
| 50 |
+
current_pos = start[0]
|
| 51 |
+
|
| 52 |
+
# Draw the path
|
| 53 |
+
colors = cmap(np.linspace(0, 1, len(route)))
|
| 54 |
+
si = 0
|
| 55 |
+
for step in route:
|
| 56 |
+
new_pos = current_pos
|
| 57 |
+
if step == 0: # Up
|
| 58 |
+
new_pos = (current_pos[0] - 1, current_pos[1])
|
| 59 |
+
elif step == 1: # Down
|
| 60 |
+
new_pos = (current_pos[0] + 1, current_pos[1])
|
| 61 |
+
elif step == 2: # Left
|
| 62 |
+
new_pos = (current_pos[0], current_pos[1] - 1)
|
| 63 |
+
elif step == 3: # Right
|
| 64 |
+
new_pos = (current_pos[0], current_pos[1] + 1)
|
| 65 |
+
elif step == 4: # Do nothing
|
| 66 |
+
pass
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError("Invalid step: {}".format(step))
|
| 69 |
+
|
| 70 |
+
# Check if the new position is valid
|
| 71 |
+
if valid_only:
|
| 72 |
+
try:
|
| 73 |
+
if np.all(x[new_pos] == [0,0,0]): # Check if it's a wall
|
| 74 |
+
continue # Skip this step if it's invalid
|
| 75 |
+
except IndexError:
|
| 76 |
+
continue # Skip this step if it's out of bounds
|
| 77 |
+
|
| 78 |
+
# Draw the step
|
| 79 |
+
if new_pos[0] >= 0 and new_pos[0] < x.shape[0] and new_pos[1] >= 0 and new_pos[1] < x.shape[1]:
|
| 80 |
+
if not ((x[new_pos] == [1,0,0]).all() or (x[new_pos] == [0,1,0]).all()):
|
| 81 |
+
colour = colors[si][:3]
|
| 82 |
+
si += 1
|
| 83 |
+
x[new_pos] = x[new_pos]*0.5 + colour*0.5
|
| 84 |
+
|
| 85 |
+
# Update the current position
|
| 86 |
+
current_pos = new_pos
|
| 87 |
+
# cv2.imwrite('maze2.png', x[:,:,::-1]*255)
|
| 88 |
+
|
| 89 |
+
return x
|
| 90 |
+
|
| 91 |
+
def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location, verbose=True):
|
| 92 |
+
"""
|
| 93 |
+
Expect inputs, predictions, targets as numpy arrays
|
| 94 |
+
"""
|
| 95 |
+
route_steps = []
|
| 96 |
+
route_colours = []
|
| 97 |
+
solution_maze = draw_path(np.moveaxis(inputs, 0, -1), targets)
|
| 98 |
+
|
| 99 |
+
n_heads = attention_tracking.shape[1]
|
| 100 |
+
mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 101 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 102 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 103 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 104 |
+
['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
|
| 105 |
+
['head_8', 'head_9', 'head_10', 'head_11', 'head_12', 'head_13', 'head_14', 'head_15'],
|
| 106 |
+
]
|
| 107 |
+
if n_heads == 8:
|
| 108 |
+
mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 109 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 110 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 111 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 112 |
+
['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
|
| 113 |
+
]
|
| 114 |
+
elif n_heads == 4:
|
| 115 |
+
mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 116 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 117 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 118 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 119 |
+
['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'],
|
| 120 |
+
['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'],
|
| 121 |
+
]
|
| 122 |
+
|
| 123 |
+
img_aspect = 1
|
| 124 |
+
figscale = 1
|
| 125 |
+
aspect_ratio = (len(mosaic[0]) * figscale, len(mosaic) * figscale * img_aspect) # W, H
|
| 126 |
+
|
| 127 |
+
route_steps = [np.unravel_index(np.argmax((inputs == np.reshape(np.array([1, 0, 0]), (3, 1, 1))).all(0)), inputs.shape[1:])] # Starting point
|
| 128 |
+
frames = []
|
| 129 |
+
cmap = plt.get_cmap('gist_rainbow')
|
| 130 |
+
cmap_viridis = plt.get_cmap('viridis')
|
| 131 |
+
step_linspace = np.linspace(0, 1, predictions.shape[-1]) # For sampling colours
|
| 132 |
+
with tqdm(total=predictions.shape[-1], initial=0, leave=True, position=1, dynamic_ncols=True) as pbar:
|
| 133 |
+
if verbose: pbar.set_description('Processing frames for maze plotting')
|
| 134 |
+
for stepi in np.arange(0, predictions.shape[-1], 1):
|
| 135 |
+
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 136 |
+
for ax in axes.values():
|
| 137 |
+
ax.axis('off')
|
| 138 |
+
guess_maze = draw_path(np.moveaxis(inputs, 0, -1), predictions.argmax(1)[:,stepi], cmap=cmap)
|
| 139 |
+
attention_now = attention_tracking[stepi]
|
| 140 |
+
for hi in range(min((attention_tracking.shape[1], 16))):
|
| 141 |
+
ax = axes[f'head_{hi}']
|
| 142 |
+
attn = attention_tracking[stepi, hi]
|
| 143 |
+
attn = (attn - attn.min())/(np.ptp(attn))
|
| 144 |
+
ax.imshow(attn, cmap=cmap_viridis)
|
| 145 |
+
# Upsample attention just for visualisation
|
| 146 |
+
aggregated_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy()
|
| 147 |
+
|
| 148 |
+
# Get approximate center of mass
|
| 149 |
+
com_attn = np.copy(aggregated_attention)
|
| 150 |
+
com_attn[com_attn < np.percentile(com_attn, 96)] = 0.0
|
| 151 |
+
aggregated_attention[aggregated_attention < np.percentile(aggregated_attention, 80)] = 0.0
|
| 152 |
+
route_steps.append(find_center_of_mass(com_attn))
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
colour = list(cmap(step_linspace[stepi]))
|
| 156 |
+
route_colours.append(colour)
|
| 157 |
+
|
| 158 |
+
mapped_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy()
|
| 159 |
+
mapped_attention = (mapped_attention - mapped_attention.min())/np.ptp(mapped_attention)
|
| 160 |
+
# np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.5) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.3, 0, 1)
|
| 161 |
+
overlay_img = np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.6) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.1, 0, 1)#np.clip((np.copy(guess_maze)*(1-aggregated_attention[:,:,np.newaxis])*0.7 + (aggregated_attention[:,:,np.newaxis]*3 * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1)
|
| 162 |
+
axes['overlay'].imshow(overlay_img)
|
| 163 |
+
|
| 164 |
+
y_coords, x_coords = zip(*route_steps)
|
| 165 |
+
y_coords = inputs.shape[-1] - np.array(list(y_coords))-1
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
axes['route'].imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower')
|
| 169 |
+
# ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
|
| 170 |
+
arrow_scale = 2
|
| 171 |
+
for i in range(len(route_steps)-1):
|
| 172 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 173 |
+
dy = y_coords[i+1] - y_coords[i]
|
| 174 |
+
axes['route'].arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
|
| 175 |
+
|
| 176 |
+
fig.tight_layout(pad=0.1) # Adjust spacing
|
| 177 |
+
|
| 178 |
+
# Render the plot to a numpy array
|
| 179 |
+
canvas = fig.canvas
|
| 180 |
+
canvas.draw()
|
| 181 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 182 |
+
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
|
| 183 |
+
|
| 184 |
+
frames.append(image_numpy) # Add to list for GIF
|
| 185 |
+
|
| 186 |
+
# fig.savefig(f'{save_location}/frame.png', dpi=200)
|
| 187 |
+
|
| 188 |
+
plt.close(fig)
|
| 189 |
+
|
| 190 |
+
# # frame = np.clip((np.copy(guess_maze)*0.5 + (aggregated_attention[:,:,np.newaxis] * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1)
|
| 191 |
+
# frame = torch.nn.functional.interpolate(torch.from_numpy(frame).permute(2,0,1).unsqueeze(0), 256)[0].permute(1,2,0).detach().cpu().numpy()
|
| 192 |
+
# frames.append((frame*255).astype(np.uint8))
|
| 193 |
+
pbar.update(1)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
y_coords, x_coords = zip(*route_steps)
|
| 197 |
+
y_coords = inputs.shape[-1] - np.array(list(y_coords))-1
|
| 198 |
+
|
| 199 |
+
fig = plt.figure(figsize=(5,5))
|
| 200 |
+
ax = fig.add_subplot(111)
|
| 201 |
+
|
| 202 |
+
ax.imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower')
|
| 203 |
+
# ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
|
| 204 |
+
arrow_scale = 2
|
| 205 |
+
for i in range(len(route_steps)-1):
|
| 206 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 207 |
+
dy = y_coords[i+1] - y_coords[i]
|
| 208 |
+
plt.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
|
| 209 |
+
|
| 210 |
+
ax.axis('off')
|
| 211 |
+
fig.tight_layout(pad=0)
|
| 212 |
+
fig.savefig(f'{save_location}/route_approximation.png', dpi=200)
|
| 213 |
+
imageio.mimsave(f'{save_location}/prediction.gif', frames, fps=15, loop=100)
|
| 214 |
+
plt.close(fig)
|
tasks/mazes/scripts/train_ctm.sh
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python -m tasks.mazes.train \
|
| 2 |
+
--model ctm \
|
| 3 |
+
--log_dir logs/mazes/ctm/d=2048--i=512--heads=16--sd=8--nlm=32--synch=64-32-h=32-first-last--iters=75x25--backbone=34-2 \
|
| 4 |
+
--neuron_select_type first-last \
|
| 5 |
+
--dataset mazes-large \
|
| 6 |
+
--synapse_depth 8 \
|
| 7 |
+
--heads 16 \
|
| 8 |
+
--iterations 75 \
|
| 9 |
+
--memory_length 25 \
|
| 10 |
+
--d_model 2048 \
|
| 11 |
+
--d_input 512 \
|
| 12 |
+
--backbone_type resnet34-2 \
|
| 13 |
+
--n_synch_out 64 \
|
| 14 |
+
--n_synch_action 32 \
|
| 15 |
+
--memory_hidden_dims 32 \
|
| 16 |
+
--deep_memory \
|
| 17 |
+
--weight_decay 0.000 \
|
| 18 |
+
--batch_size 64 \
|
| 19 |
+
--batch_size_test 128 \
|
| 20 |
+
--n_test_batches 20 \
|
| 21 |
+
--gradient_clipping -1 \
|
| 22 |
+
--use_scheduler \
|
| 23 |
+
--scheduler_type cosine \
|
| 24 |
+
--warmup_steps 10000 \
|
| 25 |
+
--training_iterations 1000001 \
|
| 26 |
+
--no-do_normalisation \
|
| 27 |
+
--track_every 1000 \
|
| 28 |
+
--lr 1e-4 \
|
| 29 |
+
--no-reload \
|
| 30 |
+
--dropout 0.1 \
|
| 31 |
+
--positional_embedding_type none \
|
| 32 |
+
--maze_route_length 100 \
|
| 33 |
+
--cirriculum_lookahead 5 \
|
| 34 |
+
--device 0 \
|
| 35 |
+
--no-expand_range
|
tasks/mazes/train.py
ADDED
|
@@ -0,0 +1,704 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
sns.set_style('darkgrid')
|
| 9 |
+
import torch
|
| 10 |
+
if torch.cuda.is_available():
|
| 11 |
+
# For faster
|
| 12 |
+
torch.set_float32_matmul_precision('high')
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
+
|
| 15 |
+
from data.custom_datasets import MazeImageFolder
|
| 16 |
+
from models.ctm import ContinuousThoughtMachine
|
| 17 |
+
from models.lstm import LSTMBaseline
|
| 18 |
+
from models.ff import FFBaseline
|
| 19 |
+
from tasks.mazes.plotting import make_maze_gif
|
| 20 |
+
from tasks.image_classification.plotting import plot_neural_dynamics
|
| 21 |
+
from utils.housekeeping import set_seed, zip_python_code
|
| 22 |
+
from utils.losses import maze_loss
|
| 23 |
+
from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
|
| 24 |
+
|
| 25 |
+
import torchvision
|
| 26 |
+
torchvision.disable_beta_transforms_warning()
|
| 27 |
+
|
| 28 |
+
import warnings
|
| 29 |
+
warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
|
| 30 |
+
warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
|
| 31 |
+
warnings.filterwarnings(
|
| 32 |
+
"ignore",
|
| 33 |
+
"Corrupt EXIF data",
|
| 34 |
+
UserWarning,
|
| 35 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 36 |
+
)
|
| 37 |
+
warnings.filterwarnings(
|
| 38 |
+
"ignore",
|
| 39 |
+
"UserWarning: Metadata Warning",
|
| 40 |
+
UserWarning,
|
| 41 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 42 |
+
)
|
| 43 |
+
warnings.filterwarnings(
|
| 44 |
+
"ignore",
|
| 45 |
+
"UserWarning: Truncated File Read",
|
| 46 |
+
UserWarning,
|
| 47 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def parse_args():
|
| 52 |
+
parser = argparse.ArgumentParser()
|
| 53 |
+
|
| 54 |
+
# Model Selection
|
| 55 |
+
parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
|
| 56 |
+
|
| 57 |
+
# Model Architecture
|
| 58 |
+
# Common across all or most
|
| 59 |
+
parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
|
| 60 |
+
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
|
| 61 |
+
parser.add_argument('--backbone_type', type=str, default='resnet34-2', help='Type of backbone featureiser.') # Default changed from original script
|
| 62 |
+
# CTM / LSTM specific
|
| 63 |
+
parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
|
| 64 |
+
parser.add_argument('--heads', type=int, default=8, help='Number of attention heads (CTM, LSTM).') # Default changed
|
| 65 |
+
parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
|
| 66 |
+
parser.add_argument('--positional_embedding_type', type=str, default='none',
|
| 67 |
+
help='Type of positional embedding (CTM, LSTM).', choices=['none',
|
| 68 |
+
'learnable-fourier',
|
| 69 |
+
'multi-learnable-fourier',
|
| 70 |
+
'custom-rotational'])
|
| 71 |
+
|
| 72 |
+
# CTM specific
|
| 73 |
+
parser.add_argument('--synapse_depth', type=int, default=8, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).') # Default changed
|
| 74 |
+
parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).') # Default changed
|
| 75 |
+
parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).') # Default changed
|
| 76 |
+
parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
|
| 77 |
+
parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
|
| 78 |
+
parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
|
| 79 |
+
parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True,
|
| 80 |
+
help='Use deep memory (CTM only).')
|
| 81 |
+
parser.add_argument('--memory_hidden_dims', type=int, default=32, help='Hidden dimensions of the memory if using deep memory (CTM only).') # Default changed
|
| 82 |
+
parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
|
| 83 |
+
parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
|
| 84 |
+
# LSTM specific
|
| 85 |
+
parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).') # Added LSTM arg
|
| 86 |
+
|
| 87 |
+
# Task Specific Args (Common to all models for this task)
|
| 88 |
+
parser.add_argument('--maze_route_length', type=int, default=100, help='Length to truncate targets.')
|
| 89 |
+
parser.add_argument('--cirriculum_lookahead', type=int, default=5, help='How far to look ahead for cirriculum.')
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Training
|
| 93 |
+
parser.add_argument('--expand_range', action=argparse.BooleanOptionalAction, default=True, help='Mazes between 0 and 1 = False. Between -1 and 1 = True. Legacy checkpoints use 0 and 1.')
|
| 94 |
+
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training.') # Default changed
|
| 95 |
+
parser.add_argument('--batch_size_test', type=int, default=64, help='Batch size for testing.') # Default changed
|
| 96 |
+
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate for the model.') # Default changed
|
| 97 |
+
parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
|
| 98 |
+
parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
|
| 99 |
+
parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
|
| 100 |
+
parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
|
| 101 |
+
parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
|
| 102 |
+
parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
|
| 103 |
+
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
|
| 104 |
+
parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
|
| 105 |
+
parser.add_argument('--num_workers_train', type=int, default=0, help='Num workers training.') # Renamed from num_workers, kept default
|
| 106 |
+
parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
|
| 107 |
+
parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
|
| 108 |
+
|
| 109 |
+
# Logging and Saving
|
| 110 |
+
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 111 |
+
parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large', 'mazes-small'])
|
| 112 |
+
parser.add_argument('--data_root', type=str, default='data/mazes', help='Data root.')
|
| 113 |
+
|
| 114 |
+
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
| 115 |
+
parser.add_argument('--seed', type=int, default=412, help='Random seed.')
|
| 116 |
+
parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
|
| 117 |
+
parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
|
| 118 |
+
parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
|
| 119 |
+
parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading (for debugging)?') # Added back
|
| 120 |
+
|
| 121 |
+
# Tracking
|
| 122 |
+
parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
|
| 123 |
+
parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval') # Default changed
|
| 124 |
+
|
| 125 |
+
# Device
|
| 126 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
|
| 127 |
+
parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
args = parser.parse_args()
|
| 131 |
+
return args
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if __name__=='__main__':
|
| 135 |
+
|
| 136 |
+
# Hosuekeeping
|
| 137 |
+
args = parse_args()
|
| 138 |
+
|
| 139 |
+
set_seed(args.seed, False)
|
| 140 |
+
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 141 |
+
|
| 142 |
+
assert args.dataset in ['mazes-medium', 'mazes-large', 'mazes-small']
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
prediction_reshaper = [args.maze_route_length, 5] # Problem specific
|
| 147 |
+
args.out_dims = args.maze_route_length * 5 # Output dimension before reshaping
|
| 148 |
+
|
| 149 |
+
# For total reproducibility
|
| 150 |
+
zip_python_code(f'{args.log_dir}/repo_state.zip')
|
| 151 |
+
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 152 |
+
print(args, file=f)
|
| 153 |
+
|
| 154 |
+
# Configure device string (support MPS on macOS)
|
| 155 |
+
if args.device[0] != -1:
|
| 156 |
+
device = f'cuda:{args.device[0]}'
|
| 157 |
+
elif torch.backends.mps.is_available():
|
| 158 |
+
device = 'mps'
|
| 159 |
+
else:
|
| 160 |
+
device = 'cpu'
|
| 161 |
+
print(f'Running model {args.model} on {device}')
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# Build model conditionally
|
| 165 |
+
model = None
|
| 166 |
+
if args.model == 'ctm':
|
| 167 |
+
model = ContinuousThoughtMachine(
|
| 168 |
+
iterations=args.iterations,
|
| 169 |
+
d_model=args.d_model,
|
| 170 |
+
d_input=args.d_input,
|
| 171 |
+
heads=args.heads,
|
| 172 |
+
n_synch_out=args.n_synch_out,
|
| 173 |
+
n_synch_action=args.n_synch_action,
|
| 174 |
+
synapse_depth=args.synapse_depth,
|
| 175 |
+
memory_length=args.memory_length,
|
| 176 |
+
deep_nlms=args.deep_memory,
|
| 177 |
+
memory_hidden_dims=args.memory_hidden_dims,
|
| 178 |
+
do_layernorm_nlm=args.do_normalisation,
|
| 179 |
+
backbone_type=args.backbone_type,
|
| 180 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 181 |
+
out_dims=args.out_dims,
|
| 182 |
+
prediction_reshaper=prediction_reshaper,
|
| 183 |
+
dropout=args.dropout,
|
| 184 |
+
dropout_nlm=args.dropout_nlm,
|
| 185 |
+
neuron_select_type=args.neuron_select_type,
|
| 186 |
+
n_random_pairing_self=args.n_random_pairing_self,
|
| 187 |
+
).to(device)
|
| 188 |
+
elif args.model == 'lstm':
|
| 189 |
+
model = LSTMBaseline(
|
| 190 |
+
num_layers=args.num_layers,
|
| 191 |
+
iterations=args.iterations,
|
| 192 |
+
d_model=args.d_model,
|
| 193 |
+
d_input=args.d_input,
|
| 194 |
+
heads=args.heads,
|
| 195 |
+
backbone_type=args.backbone_type,
|
| 196 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 197 |
+
out_dims=args.out_dims,
|
| 198 |
+
prediction_reshaper=prediction_reshaper,
|
| 199 |
+
dropout=args.dropout,
|
| 200 |
+
).to(device)
|
| 201 |
+
elif args.model == 'ff':
|
| 202 |
+
model = FFBaseline(
|
| 203 |
+
d_model=args.d_model,
|
| 204 |
+
backbone_type=args.backbone_type,
|
| 205 |
+
out_dims=args.out_dims,
|
| 206 |
+
dropout=args.dropout,
|
| 207 |
+
).to(device)
|
| 208 |
+
else:
|
| 209 |
+
raise ValueError(f"Unknown model type: {args.model}")
|
| 210 |
+
|
| 211 |
+
try:
|
| 212 |
+
# Determine pseudo input shape based on dataset
|
| 213 |
+
h_w = 39 if args.dataset in ['mazes-small', 'mazes-medium'] else 99 # Example dimensions
|
| 214 |
+
pseudo_inputs = torch.zeros((1, 3, h_w, h_w), device=device).float()
|
| 215 |
+
model(pseudo_inputs)
|
| 216 |
+
except Exception as e:
|
| 217 |
+
print(f"Warning: Pseudo forward pass failed: {e}")
|
| 218 |
+
|
| 219 |
+
print(f'Total params: {sum(p.numel() for p in model.parameters())}')
|
| 220 |
+
|
| 221 |
+
# Data
|
| 222 |
+
dataset_mean = [0,0,0] # For plotting later
|
| 223 |
+
dataset_std = [1,1,1]
|
| 224 |
+
|
| 225 |
+
which_maze = args.dataset.split('-')[-1]
|
| 226 |
+
data_root = f'{args.data_root}/{which_maze}'
|
| 227 |
+
|
| 228 |
+
train_data = MazeImageFolder(root=f'{data_root}/train/', which_set='train', maze_route_length=args.maze_route_length, expand_range=args.expand_range)
|
| 229 |
+
test_data = MazeImageFolder(root=f'{data_root}/test/', which_set='test', maze_route_length=args.maze_route_length, expand_range=args.expand_range)
|
| 230 |
+
|
| 231 |
+
num_workers_test = 1 # Defaulting to 1, can be changed
|
| 232 |
+
trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train, drop_last=True)
|
| 233 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)
|
| 234 |
+
|
| 235 |
+
# For lazy modules so that we can get param count
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
model.train()
|
| 239 |
+
|
| 240 |
+
# Optimizer and scheduler
|
| 241 |
+
decay_params = []
|
| 242 |
+
no_decay_params = []
|
| 243 |
+
no_decay_names = []
|
| 244 |
+
for name, param in model.named_parameters():
|
| 245 |
+
if not param.requires_grad:
|
| 246 |
+
continue # Skip parameters that don't require gradients
|
| 247 |
+
if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
|
| 248 |
+
no_decay_params.append(param)
|
| 249 |
+
no_decay_names.append(name)
|
| 250 |
+
else:
|
| 251 |
+
decay_params.append(param)
|
| 252 |
+
if len(no_decay_names):
|
| 253 |
+
print(f'WARNING, excluding: {no_decay_names}')
|
| 254 |
+
|
| 255 |
+
# Optimizer and scheduler (Common setup)
|
| 256 |
+
if len(no_decay_names) and args.weight_decay!=0:
|
| 257 |
+
optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
|
| 258 |
+
{'params': no_decay_params, 'weight_decay':0}],
|
| 259 |
+
lr=args.lr,
|
| 260 |
+
eps=1e-8 if not args.use_amp else 1e-6)
|
| 261 |
+
else:
|
| 262 |
+
optimizer = torch.optim.AdamW(model.parameters(),
|
| 263 |
+
lr=args.lr,
|
| 264 |
+
eps=1e-8 if not args.use_amp else 1e-6,
|
| 265 |
+
weight_decay=args.weight_decay)
|
| 266 |
+
|
| 267 |
+
warmup_schedule = warmup(args.warmup_steps)
|
| 268 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
|
| 269 |
+
if args.use_scheduler:
|
| 270 |
+
if args.scheduler_type == 'multistep':
|
| 271 |
+
scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
|
| 272 |
+
elif args.scheduler_type == 'cosine':
|
| 273 |
+
scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
|
| 274 |
+
else:
|
| 275 |
+
raise NotImplementedError
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Metrics tracking
|
| 279 |
+
start_iter = 0
|
| 280 |
+
train_losses = []
|
| 281 |
+
test_losses = []
|
| 282 |
+
train_accuracies = [] # Per tick/step accuracy list
|
| 283 |
+
test_accuracies = []
|
| 284 |
+
train_accuracies_most_certain = [] # Accuracy, fine-grained
|
| 285 |
+
test_accuracies_most_certain = []
|
| 286 |
+
train_accuracies_most_certain_permaze = [] # Full maze accuracy
|
| 287 |
+
test_accuracies_most_certain_permaze = []
|
| 288 |
+
iters = []
|
| 289 |
+
|
| 290 |
+
scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
|
| 291 |
+
if args.reload:
|
| 292 |
+
checkpoint_path = f'{args.log_dir}/checkpoint.pt'
|
| 293 |
+
if os.path.isfile(checkpoint_path):
|
| 294 |
+
print(f'Reloading from: {checkpoint_path}')
|
| 295 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 296 |
+
if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
|
| 297 |
+
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
|
| 298 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 299 |
+
|
| 300 |
+
if not args.reload_model_only:
|
| 301 |
+
print('Reloading optimizer etc.')
|
| 302 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 303 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 304 |
+
scaler.load_state_dict(checkpoint['scaler_state_dict']) # Load scaler state
|
| 305 |
+
start_iter = checkpoint['iteration']
|
| 306 |
+
|
| 307 |
+
if not args.ignore_metrics_when_reloading:
|
| 308 |
+
train_losses = checkpoint['train_losses']
|
| 309 |
+
test_losses = checkpoint['test_losses']
|
| 310 |
+
train_accuracies = checkpoint['train_accuracies']
|
| 311 |
+
test_accuracies = checkpoint['test_accuracies']
|
| 312 |
+
iters = checkpoint['iters']
|
| 313 |
+
train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
|
| 314 |
+
test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
|
| 315 |
+
train_accuracies_most_certain_permaze = checkpoint['train_accuracies_most_certain_permaze']
|
| 316 |
+
test_accuracies_most_certain_permaze = checkpoint['test_accuracies_most_certain_permaze']
|
| 317 |
+
else:
|
| 318 |
+
print("Ignoring metrics history upon reload.")
|
| 319 |
+
|
| 320 |
+
else:
|
| 321 |
+
print('Only reloading model!')
|
| 322 |
+
|
| 323 |
+
if 'torch_rng_state' in checkpoint:
|
| 324 |
+
# Reset seeds
|
| 325 |
+
torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
|
| 326 |
+
np.random.set_state(checkpoint['numpy_rng_state'])
|
| 327 |
+
random.setstate(checkpoint['random_rng_state'])
|
| 328 |
+
|
| 329 |
+
del checkpoint
|
| 330 |
+
import gc
|
| 331 |
+
gc.collect()
|
| 332 |
+
if torch.cuda.is_available():
|
| 333 |
+
torch.cuda.empty_cache()
|
| 334 |
+
|
| 335 |
+
if args.do_compile:
|
| 336 |
+
print('Compiling...')
|
| 337 |
+
if hasattr(model, 'backbone'):
|
| 338 |
+
model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
|
| 339 |
+
# Compile synapses only for CTM
|
| 340 |
+
if args.model == 'ctm':
|
| 341 |
+
model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
|
| 342 |
+
|
| 343 |
+
# Training
|
| 344 |
+
iterator = iter(trainloader)
|
| 345 |
+
with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
|
| 346 |
+
for bi in range(start_iter, args.training_iterations):
|
| 347 |
+
current_lr = optimizer.param_groups[-1]['lr']
|
| 348 |
+
|
| 349 |
+
try:
|
| 350 |
+
inputs, targets = next(iterator)
|
| 351 |
+
except StopIteration:
|
| 352 |
+
iterator = iter(trainloader)
|
| 353 |
+
inputs, targets = next(iterator)
|
| 354 |
+
|
| 355 |
+
inputs = inputs.to(device)
|
| 356 |
+
targets = targets.to(device) # Shape (B, SeqLength)
|
| 357 |
+
|
| 358 |
+
# All for nice metric printing:
|
| 359 |
+
loss = None
|
| 360 |
+
accuracy_finegrained = None # Per-step accuracy at chosen tick
|
| 361 |
+
where_most_certain_val = -1.0 # Default value
|
| 362 |
+
where_most_certain_std = 0.0
|
| 363 |
+
where_most_certain_min = -1
|
| 364 |
+
where_most_certain_max = -1
|
| 365 |
+
upto_where_mean = -1.0
|
| 366 |
+
upto_where_std = 0.0
|
| 367 |
+
upto_where_min = -1
|
| 368 |
+
upto_where_max = -1
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
# Model-specific forward, reshape, and loss calculation
|
| 372 |
+
with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
|
| 373 |
+
if args.do_compile: # CUDAGraph marking applied if compiling any model
|
| 374 |
+
torch.compiler.cudagraph_mark_step_begin()
|
| 375 |
+
|
| 376 |
+
if args.model == 'ctm':
|
| 377 |
+
# CTM output: (B, SeqLength*5, Ticks), Certainties: (B, Ticks)
|
| 378 |
+
predictions_raw, certainties, synchronisation = model(inputs)
|
| 379 |
+
# Reshape predictions: (B, SeqLength, 5, Ticks)
|
| 380 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))
|
| 381 |
+
loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=True)
|
| 382 |
+
# Accuracy uses predictions[B, S, C, T] indexed at where_most_certain[B] -> gives (B, S, C) -> argmax(2) -> (B,S)
|
| 383 |
+
accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()
|
| 384 |
+
|
| 385 |
+
elif args.model == 'lstm':
|
| 386 |
+
# LSTM output: (B, SeqLength*5, Ticks), Certainties: (B, Ticks)
|
| 387 |
+
predictions_raw, certainties, synchronisation = model(inputs)
|
| 388 |
+
# Reshape predictions: (B, SeqLength, 5, Ticks)
|
| 389 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))
|
| 390 |
+
loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False)
|
| 391 |
+
# where_most_certain should be -1 (last tick) here. Accuracy calc follows same logic.
|
| 392 |
+
accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()
|
| 393 |
+
|
| 394 |
+
elif args.model == 'ff':
|
| 395 |
+
# Assume FF output: (B, SeqLength*5)
|
| 396 |
+
predictions_raw = model(inputs)
|
| 397 |
+
# Reshape predictions: (B, SeqLength, 5)
|
| 398 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5)
|
| 399 |
+
# FF has no certainties, pass None. maze_loss must handle this.
|
| 400 |
+
# Unsqueeze predictions for compatibility with maze loss calcluation
|
| 401 |
+
loss, where_most_certain, upto_where = maze_loss(predictions.unsqueeze(-1), None, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False)
|
| 402 |
+
# where_most_certain should be -1 here. Accuracy uses 3D prediction tensor.
|
| 403 |
+
accuracy_finegrained = (predictions.argmax(2) == targets).float().mean().item()
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
# Extract stats from loss outputs if they are tensors
|
| 407 |
+
if torch.is_tensor(where_most_certain):
|
| 408 |
+
where_most_certain_val = where_most_certain.float().mean().item()
|
| 409 |
+
where_most_certain_std = where_most_certain.float().std().item()
|
| 410 |
+
where_most_certain_min = where_most_certain.min().item()
|
| 411 |
+
where_most_certain_max = where_most_certain.max().item()
|
| 412 |
+
elif isinstance(where_most_certain, int): # Handle case where it might return -1 directly
|
| 413 |
+
where_most_certain_val = float(where_most_certain)
|
| 414 |
+
where_most_certain_min = where_most_certain
|
| 415 |
+
where_most_certain_max = where_most_certain
|
| 416 |
+
|
| 417 |
+
if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0: # Check if it's a list/array
|
| 418 |
+
upto_where_mean = np.mean(upto_where)
|
| 419 |
+
upto_where_std = np.std(upto_where)
|
| 420 |
+
upto_where_min = np.min(upto_where)
|
| 421 |
+
upto_where_max = np.max(upto_where)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
scaler.scale(loss).backward()
|
| 425 |
+
|
| 426 |
+
if args.gradient_clipping!=-1:
|
| 427 |
+
scaler.unscale_(optimizer)
|
| 428 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
|
| 429 |
+
|
| 430 |
+
scaler.step(optimizer)
|
| 431 |
+
scaler.update()
|
| 432 |
+
optimizer.zero_grad(set_to_none=True)
|
| 433 |
+
scheduler.step()
|
| 434 |
+
|
| 435 |
+
# Conditional Tqdm Description
|
| 436 |
+
pbar_desc = f'Loss={loss.item():0.3f}. Acc(step)={accuracy_finegrained:0.3f}. LR={current_lr:0.6f}.'
|
| 437 |
+
if args.model in ['ctm', 'lstm'] or torch.is_tensor(where_most_certain): # Show stats if available
|
| 438 |
+
pbar_desc += f' Where_certain={where_most_certain_val:0.2f}+-{where_most_certain_std:0.2f} ({where_most_certain_min:d}<->{where_most_certain_max:d}).'
|
| 439 |
+
if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
|
| 440 |
+
pbar_desc += f' Path pred stats: {upto_where_mean:0.2f}+-{upto_where_std:0.2f} ({upto_where_min:d} --> {upto_where_max:d})'
|
| 441 |
+
|
| 442 |
+
pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
# Metrics tracking and plotting
|
| 446 |
+
if bi%args.track_every==0 and (bi != 0 or args.reload_model_only):
|
| 447 |
+
model.eval() # Use eval mode for consistency during tracking
|
| 448 |
+
with torch.inference_mode(): # Use inference mode for tracking
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
# --- Quantitative Metrics ---
|
| 454 |
+
iters.append(bi)
|
| 455 |
+
# Re-initialize metric lists for this evaluation step
|
| 456 |
+
current_train_losses_eval = []
|
| 457 |
+
current_test_losses_eval = []
|
| 458 |
+
current_train_accuracies_eval = []
|
| 459 |
+
current_test_accuracies_eval = []
|
| 460 |
+
current_train_accuracies_most_certain_eval = []
|
| 461 |
+
current_test_accuracies_most_certain_eval = []
|
| 462 |
+
current_train_accuracies_most_certain_permaze_eval = []
|
| 463 |
+
current_test_accuracies_most_certain_permaze_eval = []
|
| 464 |
+
|
| 465 |
+
# TRAIN METRICS
|
| 466 |
+
pbar.set_description('Tracking: Computing TRAIN metrics')
|
| 467 |
+
loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test) # Use consistent num_workers
|
| 468 |
+
all_targets_list = []
|
| 469 |
+
all_predictions_list = [] # Per step/tick predictions argmax (N, S, T) or (N, S)
|
| 470 |
+
all_predictions_most_certain_list = [] # Predictions at chosen step/tick argmax (N, S)
|
| 471 |
+
all_losses = []
|
| 472 |
+
|
| 473 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 474 |
+
for inferi, (inputs, targets) in enumerate(loader):
|
| 475 |
+
inputs = inputs.to(device)
|
| 476 |
+
targets = targets.to(device)
|
| 477 |
+
all_targets_list.append(targets.detach().cpu().numpy()) # N x S
|
| 478 |
+
|
| 479 |
+
# Model-specific forward, reshape, loss for evaluation
|
| 480 |
+
if args.model == 'ctm':
|
| 481 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 482 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 483 |
+
loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
|
| 484 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,C,T -> argmax class -> B,S,T
|
| 485 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S
|
| 486 |
+
all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
|
| 487 |
+
|
| 488 |
+
elif args.model == 'lstm':
|
| 489 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 490 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 491 |
+
loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
|
| 492 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,C,T
|
| 493 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S (at last tick)
|
| 494 |
+
all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
|
| 495 |
+
|
| 496 |
+
elif args.model == 'ff':
|
| 497 |
+
predictions_raw = model(inputs) # B, S*C
|
| 498 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
|
| 499 |
+
loss, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
|
| 500 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S
|
| 501 |
+
all_predictions_most_certain_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S (same as above for FF)
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
all_losses.append(loss.item())
|
| 505 |
+
|
| 506 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1 : break
|
| 507 |
+
pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
|
| 508 |
+
pbar_inner.update(1)
|
| 509 |
+
|
| 510 |
+
all_targets = np.concatenate(all_targets_list) # N, S
|
| 511 |
+
all_predictions = np.concatenate(all_predictions_list) # N, S, T or N, S
|
| 512 |
+
all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list) # N, S
|
| 513 |
+
|
| 514 |
+
train_losses.append(np.mean(all_losses))
|
| 515 |
+
# Calculate per step/tick accuracy averaged over batches
|
| 516 |
+
if args.model in ['ctm', 'lstm']:
|
| 517 |
+
# all_predictions shape (N, S, T), all_targets shape (N, S) -> compare targets to each tick prediction
|
| 518 |
+
train_accuracies.append(np.mean(all_predictions == all_targets[:,:,np.newaxis], axis=0)) # Mean over N -> (S, T)
|
| 519 |
+
else: # FF
|
| 520 |
+
# all_predictions shape (N, S), all_targets shape (N, S)
|
| 521 |
+
train_accuracies.append(np.mean(all_predictions == all_targets, axis=0)) # Mean over N -> (S,)
|
| 522 |
+
|
| 523 |
+
# Calculate accuracy at chosen step/tick ("most certain") averaged over all steps and batches
|
| 524 |
+
train_accuracies_most_certain.append((all_targets == all_predictions_most_certain).mean()) # Scalar
|
| 525 |
+
# Calculate full maze accuracy at chosen step/tick averaged over batches
|
| 526 |
+
train_accuracies_most_certain_permaze.append((all_targets == all_predictions_most_certain).reshape(all_targets.shape[0], -1).all(-1).mean()) # Scalar
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
# TEST METRICS
|
| 530 |
+
pbar.set_description('Tracking: Computing TEST metrics')
|
| 531 |
+
loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
|
| 532 |
+
all_targets_list = []
|
| 533 |
+
all_predictions_list = []
|
| 534 |
+
all_predictions_most_certain_list = []
|
| 535 |
+
all_losses = []
|
| 536 |
+
|
| 537 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 538 |
+
for inferi, (inputs, targets) in enumerate(loader):
|
| 539 |
+
inputs = inputs.to(device)
|
| 540 |
+
targets = targets.to(device)
|
| 541 |
+
all_targets_list.append(targets.detach().cpu().numpy())
|
| 542 |
+
|
| 543 |
+
# Model-specific forward, reshape, loss for evaluation
|
| 544 |
+
if args.model == 'ctm':
|
| 545 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 546 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 547 |
+
loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
|
| 548 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,T
|
| 549 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S
|
| 550 |
+
all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
|
| 551 |
+
|
| 552 |
+
elif args.model == 'lstm':
|
| 553 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 554 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 555 |
+
loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
|
| 556 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,T
|
| 557 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S (at last tick)
|
| 558 |
+
all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
|
| 559 |
+
|
| 560 |
+
elif args.model == 'ff':
|
| 561 |
+
predictions_raw = model(inputs) # B, S*C
|
| 562 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
|
| 563 |
+
loss, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
|
| 564 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S
|
| 565 |
+
all_predictions_most_certain_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S (same as above for FF)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
all_losses.append(loss.item())
|
| 569 |
+
|
| 570 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 571 |
+
pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
|
| 572 |
+
pbar_inner.update(1)
|
| 573 |
+
|
| 574 |
+
all_targets = np.concatenate(all_targets_list)
|
| 575 |
+
all_predictions = np.concatenate(all_predictions_list)
|
| 576 |
+
all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
|
| 577 |
+
|
| 578 |
+
test_losses.append(np.mean(all_losses))
|
| 579 |
+
# Calculate per step/tick accuracy
|
| 580 |
+
if args.model in ['ctm', 'lstm']:
|
| 581 |
+
test_accuracies.append(np.mean(all_predictions == all_targets[:,:,np.newaxis], axis=0)) # -> (S, T)
|
| 582 |
+
else: # FF
|
| 583 |
+
test_accuracies.append(np.mean(all_predictions == all_targets, axis=0)) # -> (S,)
|
| 584 |
+
|
| 585 |
+
# Calculate "most certain" accuracy
|
| 586 |
+
test_accuracies_most_certain.append((all_targets == all_predictions_most_certain).mean()) # Scalar
|
| 587 |
+
# Calculate full maze accuracy
|
| 588 |
+
test_accuracies_most_certain_permaze.append((all_targets == all_predictions_most_certain).reshape(all_targets.shape[0], -1).all(-1).mean()) # Scalar
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
# --- Plotting ---
|
| 592 |
+
# Accuracy Plot (Handling different dimensions)
|
| 593 |
+
figacc = plt.figure(figsize=(10, 10))
|
| 594 |
+
axacc_train = figacc.add_subplot(211)
|
| 595 |
+
axacc_test = figacc.add_subplot(212)
|
| 596 |
+
cm = sns.color_palette("viridis", as_cmap=True)
|
| 597 |
+
|
| 598 |
+
# Plot per step/tick accuracy
|
| 599 |
+
# train_accuracies is List[(S, T)] or List[(S,)]
|
| 600 |
+
# We need to average over S dimension for plotting
|
| 601 |
+
train_acc_plot = [np.mean(acc_s) for acc_s in train_accuracies] # List[Scalar] or List[Scalar] after mean
|
| 602 |
+
test_acc_plot = [np.mean(acc_s) for acc_s in test_accuracies] # List[Scalar] or List[Scalar] after mean
|
| 603 |
+
|
| 604 |
+
axacc_train.plot(iters, train_acc_plot, 'g-', alpha=0.5, label='Avg Step Acc')
|
| 605 |
+
axacc_test.plot(iters, test_acc_plot, 'g-', alpha=0.5, label='Avg Step Acc')
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
# Plot most certain accuracy
|
| 609 |
+
axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most Certain (Avg Step)')
|
| 610 |
+
axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most Certain (Avg Step)')
|
| 611 |
+
# Plot full maze accuracy
|
| 612 |
+
axacc_train.plot(iters, train_accuracies_most_certain_permaze, 'r-', alpha=0.6, label='Full Maze')
|
| 613 |
+
axacc_test.plot(iters, test_accuracies_most_certain_permaze, 'r-', alpha=0.6, label='Full Maze')
|
| 614 |
+
|
| 615 |
+
axacc_train.set_title('Train Accuracy')
|
| 616 |
+
axacc_test.set_title('Test Accuracy')
|
| 617 |
+
axacc_train.legend(loc='lower right')
|
| 618 |
+
axacc_test.legend(loc='lower right')
|
| 619 |
+
axacc_train.set_xlim([0, args.training_iterations])
|
| 620 |
+
axacc_test.set_xlim([0, args.training_iterations])
|
| 621 |
+
axacc_train.set_ylim([0, 1]) # Set Ylim for accuracy
|
| 622 |
+
axacc_test.set_ylim([0, 1])
|
| 623 |
+
|
| 624 |
+
figacc.tight_layout()
|
| 625 |
+
figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
|
| 626 |
+
plt.close(figacc)
|
| 627 |
+
|
| 628 |
+
# Loss Plot
|
| 629 |
+
figloss = plt.figure(figsize=(10, 5))
|
| 630 |
+
axloss = figloss.add_subplot(111)
|
| 631 |
+
axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
|
| 632 |
+
axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
|
| 633 |
+
axloss.legend(loc='upper right')
|
| 634 |
+
axloss.set_xlim([0, args.training_iterations])
|
| 635 |
+
axloss.set_ylim(bottom=0)
|
| 636 |
+
|
| 637 |
+
figloss.tight_layout()
|
| 638 |
+
figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
|
| 639 |
+
plt.close(figloss)
|
| 640 |
+
|
| 641 |
+
# --- Visualization Section (Conditional) ---
|
| 642 |
+
if args.model in ['ctm', 'lstm']:
|
| 643 |
+
# try:
|
| 644 |
+
inputs_viz, targets_viz = next(iter(testloader))
|
| 645 |
+
inputs_viz = inputs_viz.to(device)
|
| 646 |
+
targets_viz = targets_viz.to(device)
|
| 647 |
+
# Find longest path in batch for potentially better visualization
|
| 648 |
+
longest_index = (targets_viz!=4).sum(-1).argmax() # Action 4 assumed padding/end
|
| 649 |
+
|
| 650 |
+
# Track internal states
|
| 651 |
+
predictions_viz_raw, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
|
| 652 |
+
|
| 653 |
+
# Reshape predictions (assuming raw is B, D, T)
|
| 654 |
+
predictions_viz = predictions_viz_raw.reshape(predictions_viz_raw.size(0), -1, 5, predictions_viz_raw.size(-1)) # B, S, C, T
|
| 655 |
+
|
| 656 |
+
att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
|
| 657 |
+
attention_tracking_viz = attention_tracking_viz.reshape(
|
| 658 |
+
attention_tracking_viz.shape[0],
|
| 659 |
+
attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
|
| 660 |
+
|
| 661 |
+
# Plot dynamics (common plotting function)
|
| 662 |
+
plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
|
| 663 |
+
|
| 664 |
+
# Create maze GIF (task-specific plotting)
|
| 665 |
+
make_maze_gif((inputs_viz[longest_index].detach().cpu().numpy()+1)/2,
|
| 666 |
+
predictions_viz[longest_index].detach().cpu().numpy(), # Pass reshaped B,S,C,T -> S,C,T
|
| 667 |
+
targets_viz[longest_index].detach().cpu().numpy(), # S
|
| 668 |
+
attention_tracking_viz[:, longest_index], # Pass T, (H), H, W
|
| 669 |
+
args.log_dir)
|
| 670 |
+
# except Exception as e:
|
| 671 |
+
# print(f"Visualization failed for model {args.model}: {e}")
|
| 672 |
+
# --- End Visualization ---
|
| 673 |
+
|
| 674 |
+
model.train() # Switch back to train mode
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
# Save model checkpoint
|
| 678 |
+
if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
|
| 679 |
+
pbar.set_description('Saving model checkpoint...')
|
| 680 |
+
checkpoint_data = {
|
| 681 |
+
'model_state_dict': model.state_dict(),
|
| 682 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 683 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 684 |
+
'scaler_state_dict': scaler.state_dict(), # Save scaler state
|
| 685 |
+
'iteration': bi,
|
| 686 |
+
# Save all tracked metrics
|
| 687 |
+
'train_losses': train_losses,
|
| 688 |
+
'test_losses': test_losses,
|
| 689 |
+
'train_accuracies': train_accuracies, # List of (S, T) or (S,) arrays
|
| 690 |
+
'test_accuracies': test_accuracies, # List of (S, T) or (S,) arrays
|
| 691 |
+
'train_accuracies_most_certain': train_accuracies_most_certain, # List of scalars
|
| 692 |
+
'test_accuracies_most_certain': test_accuracies_most_certain, # List of scalars
|
| 693 |
+
'train_accuracies_most_certain_permaze': train_accuracies_most_certain_permaze, # List of scalars
|
| 694 |
+
'test_accuracies_most_certain_permaze': test_accuracies_most_certain_permaze, # List of scalars
|
| 695 |
+
'iters': iters,
|
| 696 |
+
'args': args, # Save args used for this run
|
| 697 |
+
# RNG states
|
| 698 |
+
'torch_rng_state': torch.get_rng_state(),
|
| 699 |
+
'numpy_rng_state': np.random.get_state(),
|
| 700 |
+
'random_rng_state': random.getstate(),
|
| 701 |
+
}
|
| 702 |
+
torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
|
| 703 |
+
|
| 704 |
+
pbar.update(1)
|
tasks/mazes/train_distributed.py
ADDED
|
@@ -0,0 +1,782 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import gc
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
sns.set_style('darkgrid')
|
| 10 |
+
import torch
|
| 11 |
+
if torch.cuda.is_available():
|
| 12 |
+
# For faster
|
| 13 |
+
torch.set_float32_matmul_precision('high')
|
| 14 |
+
import torch.distributed as dist
|
| 15 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 16 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 17 |
+
from utils.samplers import FastRandomDistributedSampler
|
| 18 |
+
from tqdm.auto import tqdm
|
| 19 |
+
|
| 20 |
+
# Data/Task Specific Imports
|
| 21 |
+
from data.custom_datasets import MazeImageFolder
|
| 22 |
+
|
| 23 |
+
# Model Imports
|
| 24 |
+
from models.ctm import ContinuousThoughtMachine
|
| 25 |
+
from models.lstm import LSTMBaseline
|
| 26 |
+
from models.ff import FFBaseline
|
| 27 |
+
|
| 28 |
+
# Plotting/Utils Imports
|
| 29 |
+
from tasks.mazes.plotting import make_maze_gif
|
| 30 |
+
from tasks.image_classification.plotting import plot_neural_dynamics
|
| 31 |
+
from utils.housekeeping import set_seed, zip_python_code
|
| 32 |
+
from utils.losses import maze_loss
|
| 33 |
+
from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
|
| 34 |
+
|
| 35 |
+
import torchvision
|
| 36 |
+
torchvision.disable_beta_transforms_warning()
|
| 37 |
+
|
| 38 |
+
import warnings
|
| 39 |
+
warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
|
| 40 |
+
warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
|
| 41 |
+
warnings.filterwarnings(
|
| 42 |
+
"ignore",
|
| 43 |
+
"Corrupt EXIF data",
|
| 44 |
+
UserWarning,
|
| 45 |
+
r"^PIL\.TiffImagePlugin$"
|
| 46 |
+
)
|
| 47 |
+
warnings.filterwarnings(
|
| 48 |
+
"ignore",
|
| 49 |
+
"UserWarning: Metadata Warning",
|
| 50 |
+
UserWarning,
|
| 51 |
+
r"^PIL\.TiffImagePlugin$"
|
| 52 |
+
)
|
| 53 |
+
warnings.filterwarnings(
|
| 54 |
+
"ignore",
|
| 55 |
+
"UserWarning: Truncated File Read",
|
| 56 |
+
UserWarning,
|
| 57 |
+
r"^PIL\.TiffImagePlugin$"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def parse_args():
|
| 62 |
+
parser = argparse.ArgumentParser()
|
| 63 |
+
|
| 64 |
+
# Model Selection
|
| 65 |
+
parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
|
| 66 |
+
|
| 67 |
+
# Model Architecture
|
| 68 |
+
parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
|
| 69 |
+
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
|
| 70 |
+
parser.add_argument('--backbone_type', type=str, default='resnet34-2', help='Type of backbone featureiser.')
|
| 71 |
+
# CTM / LSTM specific
|
| 72 |
+
parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
|
| 73 |
+
parser.add_argument('--heads', type=int, default=8, help='Number of attention heads (CTM, LSTM).')
|
| 74 |
+
parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
|
| 75 |
+
parser.add_argument('--positional_embedding_type', type=str, default='none',
|
| 76 |
+
help='Type of positional embedding (CTM, LSTM).', choices=['none',
|
| 77 |
+
'learnable-fourier',
|
| 78 |
+
'multi-learnable-fourier',
|
| 79 |
+
'custom-rotational'])
|
| 80 |
+
# CTM specific
|
| 81 |
+
parser.add_argument('--synapse_depth', type=int, default=8, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
|
| 82 |
+
parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).')
|
| 83 |
+
parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).')
|
| 84 |
+
parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
|
| 85 |
+
parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
|
| 86 |
+
parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
|
| 87 |
+
parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
|
| 88 |
+
parser.add_argument('--memory_hidden_dims', type=int, default=32, help='Hidden dimensions of the memory if using deep memory (CTM only).')
|
| 89 |
+
parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
|
| 90 |
+
parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
|
| 91 |
+
# LSTM specific
|
| 92 |
+
parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
|
| 93 |
+
|
| 94 |
+
# Task Specific Args
|
| 95 |
+
parser.add_argument('--maze_route_length', type=int, default=100, help='Length to truncate targets.')
|
| 96 |
+
parser.add_argument('--cirriculum_lookahead', type=int, default=5, help='How far to look ahead for cirriculum.')
|
| 97 |
+
|
| 98 |
+
# Training
|
| 99 |
+
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training (per GPU).')
|
| 100 |
+
parser.add_argument('--batch_size_test', type=int, default=64, help='Batch size for testing (per GPU).')
|
| 101 |
+
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate for the model.')
|
| 102 |
+
parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
|
| 103 |
+
parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
|
| 104 |
+
parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
|
| 105 |
+
parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
|
| 106 |
+
parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
|
| 107 |
+
parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
|
| 108 |
+
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
|
| 109 |
+
parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
|
| 110 |
+
parser.add_argument('--num_workers_train', type=int, default=0, help='Num workers training.')
|
| 111 |
+
parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
|
| 112 |
+
parser.add_argument('--use_custom_sampler', action=argparse.BooleanOptionalAction, default=False, help='Use custom fast sampler to avoid reshuffling.')
|
| 113 |
+
parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
|
| 114 |
+
|
| 115 |
+
# Logging and Saving
|
| 116 |
+
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 117 |
+
parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large'])
|
| 118 |
+
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
| 119 |
+
parser.add_argument('--seed', type=int, default=412, help='Random seed.')
|
| 120 |
+
parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
|
| 121 |
+
parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?') # Default False based on user edit
|
| 122 |
+
parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=False, help='Should use strict reload for model weights.')
|
| 123 |
+
parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading (for debugging)?')
|
| 124 |
+
|
| 125 |
+
# Tracking
|
| 126 |
+
parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
|
| 127 |
+
parser.add_argument('--n_test_batches', type=int, default=2, help='How many minibatches to approx metrics. Set to -1 for full eval')
|
| 128 |
+
|
| 129 |
+
# Precision
|
| 130 |
+
parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
|
| 131 |
+
|
| 132 |
+
args = parser.parse_args()
|
| 133 |
+
return args
|
| 134 |
+
|
| 135 |
+
# --- DDP Setup Functions ---
|
| 136 |
+
def setup_ddp():
|
| 137 |
+
if 'RANK' not in os.environ:
|
| 138 |
+
os.environ['RANK'] = '0'
|
| 139 |
+
os.environ['WORLD_SIZE'] = '1'
|
| 140 |
+
os.environ['MASTER_ADDR'] = 'localhost'
|
| 141 |
+
os.environ['MASTER_PORT'] = '12356' # Different port from image classification
|
| 142 |
+
os.environ['LOCAL_RANK'] = '0'
|
| 143 |
+
print("Running in non-distributed mode (simulated DDP setup).")
|
| 144 |
+
if not torch.cuda.is_available() or int(os.environ['WORLD_SIZE']) == 1:
|
| 145 |
+
dist.init_process_group(backend='gloo')
|
| 146 |
+
print("Initialized process group with Gloo backend for single/CPU process.")
|
| 147 |
+
rank = int(os.environ['RANK'])
|
| 148 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 149 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 150 |
+
return rank, world_size, local_rank
|
| 151 |
+
|
| 152 |
+
dist.init_process_group(backend='nccl')
|
| 153 |
+
rank = int(os.environ['RANK'])
|
| 154 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 155 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 156 |
+
if torch.cuda.is_available():
|
| 157 |
+
torch.cuda.set_device(local_rank)
|
| 158 |
+
print(f"Rank {rank} setup on GPU {local_rank}")
|
| 159 |
+
else:
|
| 160 |
+
print(f"Rank {rank} setup on CPU")
|
| 161 |
+
return rank, world_size, local_rank
|
| 162 |
+
|
| 163 |
+
def cleanup_ddp():
|
| 164 |
+
if dist.is_initialized():
|
| 165 |
+
dist.destroy_process_group()
|
| 166 |
+
print("DDP cleanup complete.")
|
| 167 |
+
|
| 168 |
+
def is_main_process(rank):
|
| 169 |
+
return rank == 0
|
| 170 |
+
# --- End DDP Setup ---
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
if __name__=='__main__':
|
| 174 |
+
|
| 175 |
+
args = parse_args()
|
| 176 |
+
|
| 177 |
+
rank, world_size, local_rank = setup_ddp()
|
| 178 |
+
|
| 179 |
+
set_seed(args.seed + rank, False)
|
| 180 |
+
|
| 181 |
+
# Rank 0 handles directory creation and initial logging
|
| 182 |
+
if is_main_process(rank):
|
| 183 |
+
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 184 |
+
zip_python_code(f'{args.log_dir}/repo_state.zip')
|
| 185 |
+
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 186 |
+
print(args, file=f)
|
| 187 |
+
if world_size > 1: dist.barrier()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
assert args.dataset in ['mazes-medium', 'mazes-large']
|
| 191 |
+
|
| 192 |
+
# Setup Device
|
| 193 |
+
if torch.cuda.is_available():
|
| 194 |
+
device = torch.device(f'cuda:{local_rank}')
|
| 195 |
+
else:
|
| 196 |
+
device = torch.device('cpu')
|
| 197 |
+
if world_size > 1: warnings.warn("Running DDP on CPU is not recommended.")
|
| 198 |
+
|
| 199 |
+
if is_main_process(rank):
|
| 200 |
+
print(f'Main process (Rank {rank}): Using device {device}. World size: {world_size}. Model: {args.model}')
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
prediction_reshaper = [args.maze_route_length, 5]
|
| 204 |
+
args.out_dims = args.maze_route_length * 5
|
| 205 |
+
|
| 206 |
+
# --- Model Definition (Conditional) ---
|
| 207 |
+
model_base = None # Base model before DDP wrapping
|
| 208 |
+
if args.model == 'ctm':
|
| 209 |
+
model_base = ContinuousThoughtMachine(
|
| 210 |
+
iterations=args.iterations,
|
| 211 |
+
d_model=args.d_model,
|
| 212 |
+
d_input=args.d_input,
|
| 213 |
+
heads=args.heads,
|
| 214 |
+
n_synch_out=args.n_synch_out,
|
| 215 |
+
n_synch_action=args.n_synch_action,
|
| 216 |
+
synapse_depth=args.synapse_depth,
|
| 217 |
+
memory_length=args.memory_length,
|
| 218 |
+
deep_nlms=args.deep_memory,
|
| 219 |
+
memory_hidden_dims=args.memory_hidden_dims,
|
| 220 |
+
do_layernorm_nlm=args.do_normalisation,
|
| 221 |
+
backbone_type=args.backbone_type,
|
| 222 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 223 |
+
out_dims=args.out_dims,
|
| 224 |
+
prediction_reshaper=prediction_reshaper,
|
| 225 |
+
dropout=args.dropout,
|
| 226 |
+
dropout_nlm=args.dropout_nlm,
|
| 227 |
+
neuron_select_type=args.neuron_select_type,
|
| 228 |
+
n_random_pairing_self=args.n_random_pairing_self,
|
| 229 |
+
).to(device)
|
| 230 |
+
elif args.model == 'lstm':
|
| 231 |
+
model_base = LSTMBaseline(
|
| 232 |
+
num_layers=args.num_layers,
|
| 233 |
+
iterations=args.iterations,
|
| 234 |
+
d_model=args.d_model,
|
| 235 |
+
d_input=args.d_input,
|
| 236 |
+
heads=args.heads,
|
| 237 |
+
backbone_type=args.backbone_type,
|
| 238 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 239 |
+
out_dims=args.out_dims,
|
| 240 |
+
prediction_reshaper=prediction_reshaper,
|
| 241 |
+
dropout=args.dropout,
|
| 242 |
+
).to(device)
|
| 243 |
+
elif args.model == 'ff':
|
| 244 |
+
model_base = FFBaseline(
|
| 245 |
+
d_model=args.d_model,
|
| 246 |
+
backbone_type=args.backbone_type,
|
| 247 |
+
out_dims=args.out_dims,
|
| 248 |
+
dropout=args.dropout,
|
| 249 |
+
).to(device)
|
| 250 |
+
else:
|
| 251 |
+
raise ValueError(f"Unknown model type: {args.model}")
|
| 252 |
+
|
| 253 |
+
# Use pseudo-input *before* DDP wrapping
|
| 254 |
+
try:
|
| 255 |
+
# Determine pseudo input shape based on dataset
|
| 256 |
+
h_w = 39 if args.dataset in ['mazes-small', 'mazes-medium'] else 99 # Example dimensions
|
| 257 |
+
pseudo_inputs = torch.zeros((1, 3, h_w, h_w), device=device).float()
|
| 258 |
+
model_base(pseudo_inputs)
|
| 259 |
+
except Exception as e:
|
| 260 |
+
print(f"Warning: Pseudo forward pass failed: {e}")
|
| 261 |
+
|
| 262 |
+
if is_main_process(rank):
|
| 263 |
+
print(f'Total params: {sum(p.numel() for p in model_base.parameters() if p.requires_grad)}')
|
| 264 |
+
|
| 265 |
+
# Wrap model with DDP
|
| 266 |
+
if device.type == 'cuda' and world_size > 1:
|
| 267 |
+
model = DDP(model_base, device_ids=[local_rank], output_device=local_rank)
|
| 268 |
+
elif device.type == 'cpu' and world_size > 1:
|
| 269 |
+
model = DDP(model_base)
|
| 270 |
+
else:
|
| 271 |
+
model = model_base
|
| 272 |
+
# --- End Model Definition ---
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Data Loading (After model setup to allow pseudo pass first)
|
| 276 |
+
dataset_mean = [0,0,0]
|
| 277 |
+
dataset_std = [1,1,1]
|
| 278 |
+
which_maze = args.dataset.split('-')[-1]
|
| 279 |
+
data_root = f'data/mazes/{which_maze}'
|
| 280 |
+
|
| 281 |
+
train_data = MazeImageFolder(root=f'{data_root}/train/', which_set='train', maze_route_length=args.maze_route_length)
|
| 282 |
+
test_data = MazeImageFolder(root=f'{data_root}/test/', which_set='test', maze_route_length=args.maze_route_length)
|
| 283 |
+
|
| 284 |
+
train_sampler = (FastRandomDistributedSampler(train_data, num_replicas=world_size, rank=rank, seed=args.seed, epoch_steps=int(10e10))
|
| 285 |
+
if args.use_custom_sampler else
|
| 286 |
+
DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed))
|
| 287 |
+
test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False, seed=args.seed)
|
| 288 |
+
|
| 289 |
+
num_workers_test = 1
|
| 290 |
+
trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler,
|
| 291 |
+
num_workers=args.num_workers_train, pin_memory=True, drop_last=True)
|
| 292 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, sampler=test_sampler,
|
| 293 |
+
num_workers=num_workers_test, pin_memory=True, drop_last=False)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# Optimizer and scheduler
|
| 297 |
+
decay_params = []
|
| 298 |
+
no_decay_params = []
|
| 299 |
+
no_decay_names = []
|
| 300 |
+
for name, param in model.named_parameters():
|
| 301 |
+
if not param.requires_grad:
|
| 302 |
+
continue # Skip parameters that don't require gradients
|
| 303 |
+
if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
|
| 304 |
+
no_decay_params.append(param)
|
| 305 |
+
no_decay_names.append(name)
|
| 306 |
+
else:
|
| 307 |
+
decay_params.append(param)
|
| 308 |
+
if len(no_decay_names) and is_main_process(rank):
|
| 309 |
+
print(f'WARNING, excluding: {no_decay_names}')
|
| 310 |
+
|
| 311 |
+
# Optimizer and scheduler (Common setup)
|
| 312 |
+
if len(no_decay_names) and args.weight_decay!=0:
|
| 313 |
+
optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
|
| 314 |
+
{'params': no_decay_params, 'weight_decay':0}],
|
| 315 |
+
lr=args.lr,
|
| 316 |
+
eps=1e-8 if not args.use_amp else 1e-6)
|
| 317 |
+
else:
|
| 318 |
+
optimizer = torch.optim.AdamW(model.parameters(),
|
| 319 |
+
lr=args.lr,
|
| 320 |
+
eps=1e-8 if not args.use_amp else 1e-6,
|
| 321 |
+
weight_decay=args.weight_decay)
|
| 322 |
+
|
| 323 |
+
warmup_schedule = warmup(args.warmup_steps)
|
| 324 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
|
| 325 |
+
if args.use_scheduler:
|
| 326 |
+
if args.scheduler_type == 'multistep':
|
| 327 |
+
scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
|
| 328 |
+
elif args.scheduler_type == 'cosine':
|
| 329 |
+
scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
|
| 330 |
+
else:
|
| 331 |
+
raise NotImplementedError
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# Metrics tracking (Rank 0 stores history)
|
| 335 |
+
start_iter = 0
|
| 336 |
+
iters = []
|
| 337 |
+
train_losses, test_losses = [], []
|
| 338 |
+
train_accuracies, test_accuracies = [], [] # Avg Step Acc (scalar list)
|
| 339 |
+
train_accuracies_most_certain, test_accuracies_most_certain = [], [] # Avg Step Acc @ Certain tick (scalar list)
|
| 340 |
+
train_accuracies_most_certain_permaze, test_accuracies_most_certain_permaze = [], [] # Full Maze Acc @ Certain tick (scalar list)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
scaler = torch.amp.GradScaler("cuda" if device.type == 'cuda' else "cpu", enabled=args.use_amp)
|
| 344 |
+
|
| 345 |
+
# Reloading Logic
|
| 346 |
+
if args.reload:
|
| 347 |
+
map_location = device
|
| 348 |
+
chkpt_path = f'{args.log_dir}/checkpoint.pt'
|
| 349 |
+
if os.path.isfile(chkpt_path):
|
| 350 |
+
print(f'Rank {rank}: Reloading from: {chkpt_path}')
|
| 351 |
+
if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
|
| 352 |
+
|
| 353 |
+
checkpoint = torch.load(chkpt_path, map_location=map_location, weights_only=False)
|
| 354 |
+
|
| 355 |
+
model_to_load = model.module if isinstance(model, DDP) else model
|
| 356 |
+
state_dict = checkpoint['model_state_dict']
|
| 357 |
+
has_module_prefix = all(k.startswith('module.') for k in state_dict)
|
| 358 |
+
is_wrapped = isinstance(model, DDP)
|
| 359 |
+
|
| 360 |
+
if has_module_prefix and not is_wrapped:
|
| 361 |
+
state_dict = {k.partition('module.')[2]: v for k,v in state_dict.items()}
|
| 362 |
+
elif not has_module_prefix and is_wrapped:
|
| 363 |
+
load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
|
| 364 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 365 |
+
state_dict = None # Prevent loading again
|
| 366 |
+
|
| 367 |
+
if state_dict is not None:
|
| 368 |
+
load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
|
| 369 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
if not args.reload_model_only:
|
| 374 |
+
print(f'Rank {rank}: Reloading optimizer, scheduler, scaler, iteration.')
|
| 375 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 376 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 377 |
+
scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
| 378 |
+
start_iter = checkpoint['iteration']
|
| 379 |
+
|
| 380 |
+
if is_main_process(rank) and not args.ignore_metrics_when_reloading:
|
| 381 |
+
print(f'Rank {rank}: Reloading metrics history.')
|
| 382 |
+
iters = checkpoint['iters']
|
| 383 |
+
train_losses = checkpoint['train_losses']
|
| 384 |
+
test_losses = checkpoint['test_losses']
|
| 385 |
+
train_accuracies = checkpoint['train_accuracies'] # Reloading simplified avg step acc list
|
| 386 |
+
test_accuracies = checkpoint['test_accuracies']
|
| 387 |
+
train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
|
| 388 |
+
test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
|
| 389 |
+
train_accuracies_most_certain_permaze = checkpoint['train_accuracies_most_certain_permaze']
|
| 390 |
+
test_accuracies_most_certain_permaze = checkpoint['test_accuracies_most_certain_permaze']
|
| 391 |
+
elif is_main_process(rank) and args.ignore_metrics_when_reloading:
|
| 392 |
+
print(f'Rank {rank}: Ignoring metrics history upon reload.')
|
| 393 |
+
else:
|
| 394 |
+
print(f'Rank {rank}: Only reloading model weights!')
|
| 395 |
+
|
| 396 |
+
if is_main_process(rank) and 'torch_rng_state' in checkpoint and not args.reload_model_only:
|
| 397 |
+
print(f'Rank {rank}: Loading RNG states.')
|
| 398 |
+
torch.set_rng_state(checkpoint['torch_rng_state'].cpu())
|
| 399 |
+
np.random.set_state(checkpoint['numpy_rng_state'])
|
| 400 |
+
random.setstate(checkpoint['random_rng_state'])
|
| 401 |
+
|
| 402 |
+
del checkpoint
|
| 403 |
+
gc.collect()
|
| 404 |
+
if torch.cuda.is_available():
|
| 405 |
+
torch.cuda.empty_cache()
|
| 406 |
+
print(f"Rank {rank}: Reload finished, starting from iteration {start_iter}")
|
| 407 |
+
else:
|
| 408 |
+
print(f"Rank {rank}: Checkpoint not found at {chkpt_path}, starting from scratch.")
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
if world_size > 1: dist.barrier()
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# Conditional Compilation
|
| 415 |
+
if args.do_compile:
|
| 416 |
+
if is_main_process(rank): print('Compiling model components...')
|
| 417 |
+
model_to_compile = model.module if isinstance(model, DDP) else model
|
| 418 |
+
if hasattr(model_to_compile, 'backbone'):
|
| 419 |
+
model_to_compile.backbone = torch.compile(model_to_compile.backbone, mode='reduce-overhead', fullgraph=True)
|
| 420 |
+
if args.model == 'ctm':
|
| 421 |
+
model_to_compile.synapses = torch.compile(model_to_compile.synapses, mode='reduce-overhead', fullgraph=True)
|
| 422 |
+
if world_size > 1: dist.barrier()
|
| 423 |
+
if is_main_process(rank): print('Compilation finished.')
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
# --- Training Loop ---
|
| 427 |
+
model.train()
|
| 428 |
+
pbar = tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True, disable=not is_main_process(rank))
|
| 429 |
+
|
| 430 |
+
iterator = iter(trainloader)
|
| 431 |
+
|
| 432 |
+
for bi in range(start_iter, args.training_iterations):
|
| 433 |
+
|
| 434 |
+
# --- Evaluation and Plotting (Rank 0 + Aggregation) ---
|
| 435 |
+
if bi % args.track_every == 0 and (bi != 0 or args.reload_model_only):
|
| 436 |
+
model.eval()
|
| 437 |
+
with torch.inference_mode():
|
| 438 |
+
|
| 439 |
+
# --- Distributed Evaluation ---
|
| 440 |
+
if is_main_process(rank): iters.append(bi) # Track iterations on rank 0
|
| 441 |
+
|
| 442 |
+
# Initialize accumulators on device
|
| 443 |
+
total_train_loss = torch.tensor(0.0, device=device)
|
| 444 |
+
total_train_correct_certain = torch.tensor(0.0, device=device) # Sum correct steps @ certain tick
|
| 445 |
+
total_train_mazes_solved = torch.tensor(0.0, device=device) # Sum solved mazes @ certain tick
|
| 446 |
+
total_train_steps = torch.tensor(0.0, device=device) # Total steps evaluated (B * S)
|
| 447 |
+
total_train_mazes = torch.tensor(0.0, device=device) # Total mazes evaluated (B)
|
| 448 |
+
|
| 449 |
+
# TRAIN METRICS
|
| 450 |
+
train_eval_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=False)
|
| 451 |
+
train_eval_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, sampler=train_eval_sampler, num_workers=num_workers_test, pin_memory=True)
|
| 452 |
+
|
| 453 |
+
pbar_inner_desc = 'Eval Train (Rank 0)' if is_main_process(rank) else None
|
| 454 |
+
with tqdm(total=len(train_eval_loader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
|
| 455 |
+
for inferi, (inputs, targets) in enumerate(train_eval_loader):
|
| 456 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 457 |
+
targets = targets.to(device, non_blocking=True) # B, S
|
| 458 |
+
batch_size = inputs.size(0)
|
| 459 |
+
seq_len = targets.size(1)
|
| 460 |
+
|
| 461 |
+
loss_eval = None
|
| 462 |
+
pred_at_certain = None # Shape B, S
|
| 463 |
+
if args.model == 'ctm':
|
| 464 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 465 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 466 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
|
| 467 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
|
| 468 |
+
elif args.model == 'lstm':
|
| 469 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 470 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 471 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
|
| 472 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
|
| 473 |
+
elif args.model == 'ff':
|
| 474 |
+
predictions_raw = model(inputs) # B, S*C
|
| 475 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5) # B,S,C
|
| 476 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
|
| 477 |
+
pred_at_certain = predictions.argmax(2)
|
| 478 |
+
|
| 479 |
+
# Accumulate metrics
|
| 480 |
+
total_train_loss += loss_eval * batch_size # Sum losses
|
| 481 |
+
correct_steps = (pred_at_certain == targets) # B, S boolean
|
| 482 |
+
total_train_correct_certain += correct_steps.sum() # Sum correct steps across batch
|
| 483 |
+
total_train_mazes_solved += correct_steps.all(dim=-1).sum() # Sum mazes where all steps are correct
|
| 484 |
+
total_train_steps += batch_size * seq_len
|
| 485 |
+
total_train_mazes += batch_size
|
| 486 |
+
|
| 487 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 488 |
+
pbar_inner.update(1)
|
| 489 |
+
|
| 490 |
+
# Aggregate Train Metrics
|
| 491 |
+
if world_size > 1:
|
| 492 |
+
dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
|
| 493 |
+
dist.all_reduce(total_train_correct_certain, op=dist.ReduceOp.SUM)
|
| 494 |
+
dist.all_reduce(total_train_mazes_solved, op=dist.ReduceOp.SUM)
|
| 495 |
+
dist.all_reduce(total_train_steps, op=dist.ReduceOp.SUM)
|
| 496 |
+
dist.all_reduce(total_train_mazes, op=dist.ReduceOp.SUM)
|
| 497 |
+
|
| 498 |
+
# Calculate final Train metrics on Rank 0
|
| 499 |
+
if is_main_process(rank) and total_train_mazes > 0:
|
| 500 |
+
avg_train_loss = total_train_loss.item() / total_train_mazes.item() # Avg loss per maze/sample
|
| 501 |
+
avg_train_acc_step = total_train_correct_certain.item() / total_train_steps.item() # Avg correct step %
|
| 502 |
+
avg_train_acc_maze = total_train_mazes_solved.item() / total_train_mazes.item() # Avg full maze solved %
|
| 503 |
+
train_losses.append(avg_train_loss)
|
| 504 |
+
train_accuracies_most_certain.append(avg_train_acc_step)
|
| 505 |
+
train_accuracies_most_certain_permaze.append(avg_train_acc_maze)
|
| 506 |
+
# train_accuracies list remains unused/placeholder for this simplified metric structure
|
| 507 |
+
print(f"Iter {bi} Train Metrics (Agg): Loss={avg_train_loss:.4f}, StepAcc={avg_train_acc_step:.4f}, MazeAcc={avg_train_acc_maze:.4f}")
|
| 508 |
+
|
| 509 |
+
# TEST METRICS
|
| 510 |
+
total_test_loss = torch.tensor(0.0, device=device)
|
| 511 |
+
total_test_correct_certain = torch.tensor(0.0, device=device)
|
| 512 |
+
total_test_mazes_solved = torch.tensor(0.0, device=device)
|
| 513 |
+
total_test_steps = torch.tensor(0.0, device=device)
|
| 514 |
+
total_test_mazes = torch.tensor(0.0, device=device)
|
| 515 |
+
|
| 516 |
+
pbar_inner_desc = 'Eval Test (Rank 0)' if is_main_process(rank) else None
|
| 517 |
+
with tqdm(total=len(testloader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
|
| 518 |
+
for inferi, (inputs, targets) in enumerate(testloader):
|
| 519 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 520 |
+
targets = targets.to(device, non_blocking=True)
|
| 521 |
+
batch_size = inputs.size(0)
|
| 522 |
+
seq_len = targets.size(1)
|
| 523 |
+
|
| 524 |
+
loss_eval = None
|
| 525 |
+
pred_at_certain = None
|
| 526 |
+
if args.model == 'ctm':
|
| 527 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 528 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1))
|
| 529 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
|
| 530 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
|
| 531 |
+
elif args.model == 'lstm':
|
| 532 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 533 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1))
|
| 534 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False)
|
| 535 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
|
| 536 |
+
elif args.model == 'ff':
|
| 537 |
+
predictions_raw = model(inputs)
|
| 538 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5)
|
| 539 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False)
|
| 540 |
+
pred_at_certain = predictions.argmax(2)
|
| 541 |
+
|
| 542 |
+
total_test_loss += loss_eval * batch_size
|
| 543 |
+
correct_steps = (pred_at_certain == targets)
|
| 544 |
+
total_test_correct_certain += correct_steps.sum()
|
| 545 |
+
total_test_mazes_solved += correct_steps.all(dim=-1).sum()
|
| 546 |
+
total_test_steps += batch_size * seq_len
|
| 547 |
+
total_test_mazes += batch_size
|
| 548 |
+
|
| 549 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 550 |
+
pbar_inner.update(1)
|
| 551 |
+
|
| 552 |
+
# Aggregate Test Metrics
|
| 553 |
+
if world_size > 1:
|
| 554 |
+
dist.all_reduce(total_test_loss, op=dist.ReduceOp.SUM)
|
| 555 |
+
dist.all_reduce(total_test_correct_certain, op=dist.ReduceOp.SUM)
|
| 556 |
+
dist.all_reduce(total_test_mazes_solved, op=dist.ReduceOp.SUM)
|
| 557 |
+
dist.all_reduce(total_test_steps, op=dist.ReduceOp.SUM)
|
| 558 |
+
dist.all_reduce(total_test_mazes, op=dist.ReduceOp.SUM)
|
| 559 |
+
|
| 560 |
+
# Calculate and Plot final Test metrics on Rank 0
|
| 561 |
+
if is_main_process(rank) and total_test_mazes > 0:
|
| 562 |
+
avg_test_loss = total_test_loss.item() / total_test_mazes.item()
|
| 563 |
+
avg_test_acc_step = total_test_correct_certain.item() / total_test_steps.item()
|
| 564 |
+
avg_test_acc_maze = total_test_mazes_solved.item() / total_test_mazes.item()
|
| 565 |
+
test_losses.append(avg_test_loss)
|
| 566 |
+
test_accuracies_most_certain.append(avg_test_acc_step)
|
| 567 |
+
test_accuracies_most_certain_permaze.append(avg_test_acc_maze)
|
| 568 |
+
print(f"Iter {bi} Test Metrics (Agg): Loss={avg_test_loss:.4f}, StepAcc={avg_test_acc_step:.4f}, MazeAcc={avg_test_acc_maze:.4f}\n")
|
| 569 |
+
|
| 570 |
+
# --- Plotting ---
|
| 571 |
+
figacc = plt.figure(figsize=(10, 10))
|
| 572 |
+
axacc_train = figacc.add_subplot(211)
|
| 573 |
+
axacc_test = figacc.add_subplot(212)
|
| 574 |
+
|
| 575 |
+
# Plot Avg Step Accuracy
|
| 576 |
+
axacc_train.plot(iters, train_accuracies_most_certain, 'k-', alpha=0.7, label=f'Avg Step Acc ({train_accuracies_most_certain[-1]:.3f})')
|
| 577 |
+
axacc_test.plot(iters, test_accuracies_most_certain, 'k-', alpha=0.7, label=f'Avg Step Acc ({test_accuracies_most_certain[-1]:.3f})')
|
| 578 |
+
# Plot Full Maze Accuracy
|
| 579 |
+
axacc_train.plot(iters, train_accuracies_most_certain_permaze, 'r-', alpha=0.6, label=f'Full Maze Acc ({train_accuracies_most_certain_permaze[-1]:.3f})')
|
| 580 |
+
axacc_test.plot(iters, test_accuracies_most_certain_permaze, 'r-', alpha=0.6, label=f'Full Maze Acc ({test_accuracies_most_certain_permaze[-1]:.3f})')
|
| 581 |
+
|
| 582 |
+
axacc_train.set_title('Train Accuracy (Aggregated)')
|
| 583 |
+
axacc_test.set_title('Test Accuracy (Aggregated)')
|
| 584 |
+
axacc_train.legend(loc='lower right')
|
| 585 |
+
axacc_test.legend(loc='lower right')
|
| 586 |
+
axacc_train.set_xlim([0, args.training_iterations])
|
| 587 |
+
axacc_test.set_xlim([0, args.training_iterations])
|
| 588 |
+
axacc_train.set_ylim([0, 1])
|
| 589 |
+
axacc_test.set_ylim([0, 1])
|
| 590 |
+
|
| 591 |
+
figacc.tight_layout()
|
| 592 |
+
figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
|
| 593 |
+
plt.close(figacc)
|
| 594 |
+
|
| 595 |
+
# Loss Plot
|
| 596 |
+
figloss = plt.figure(figsize=(10, 5))
|
| 597 |
+
axloss = figloss.add_subplot(111)
|
| 598 |
+
axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train (Agg): {train_losses[-1]:.4f}')
|
| 599 |
+
axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test (Agg): {test_losses[-1]:.4f}')
|
| 600 |
+
axloss.legend(loc='upper right')
|
| 601 |
+
axloss.set_xlabel("Iteration")
|
| 602 |
+
axloss.set_ylabel("Loss")
|
| 603 |
+
axloss.set_xlim([0, args.training_iterations])
|
| 604 |
+
axloss.set_ylim(bottom=0)
|
| 605 |
+
figloss.tight_layout()
|
| 606 |
+
figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
|
| 607 |
+
plt.close(figloss)
|
| 608 |
+
# --- End Plotting ---
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# --- Visualization (Rank 0, Conditional) ---
|
| 612 |
+
if is_main_process(rank) and args.model in ['ctm', 'lstm']:
|
| 613 |
+
# try:
|
| 614 |
+
model_module = model.module if isinstance(model, DDP) else model
|
| 615 |
+
# Use a consistent batch for viz if possible, or just next batch
|
| 616 |
+
inputs_viz, targets_viz = next(iter(testloader))
|
| 617 |
+
inputs_viz = inputs_viz.to(device)
|
| 618 |
+
targets_viz = targets_viz.to(device)
|
| 619 |
+
longest_index = (targets_viz!=4).sum(-1).argmax() # 4 assumed padding
|
| 620 |
+
|
| 621 |
+
pbar.set_description('Tracking (Rank 0): Viz Fwd Pass')
|
| 622 |
+
predictions_viz_raw, _, _, _, post_activations_viz, attention_tracking_viz = model_module(inputs_viz, track=True)
|
| 623 |
+
predictions_viz = predictions_viz_raw.reshape(predictions_viz_raw.size(0), -1, 5, predictions_viz_raw.size(-1))
|
| 624 |
+
|
| 625 |
+
att_shape = (model.module.kv_features.shape[2], model.module.kv_features.shape[3])
|
| 626 |
+
attention_tracking_viz = attention_tracking_viz.reshape(
|
| 627 |
+
attention_tracking_viz.shape[0],
|
| 628 |
+
attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
|
| 629 |
+
|
| 630 |
+
pbar.set_description('Tracking (Rank 0): Dynamics Plot')
|
| 631 |
+
plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
|
| 632 |
+
|
| 633 |
+
pbar.set_description('Tracking (Rank 0): Maze GIF')
|
| 634 |
+
if attention_tracking_viz is not None:
|
| 635 |
+
make_maze_gif((inputs_viz[longest_index].detach().cpu().numpy()+1)/2,
|
| 636 |
+
predictions_viz[longest_index].detach().cpu().numpy(),
|
| 637 |
+
targets_viz[longest_index].detach().cpu().numpy(),
|
| 638 |
+
attention_tracking_viz[:, longest_index],
|
| 639 |
+
args.log_dir)
|
| 640 |
+
# else:
|
| 641 |
+
# print("Skipping maze GIF due to attention shape issue.")
|
| 642 |
+
|
| 643 |
+
# except Exception as e_viz:
|
| 644 |
+
# print(f"Rank 0 visualization failed: {e_viz}")
|
| 645 |
+
# --- End Visualization ---
|
| 646 |
+
|
| 647 |
+
gc.collect()
|
| 648 |
+
if torch.cuda.is_available():
|
| 649 |
+
torch.cuda.empty_cache()
|
| 650 |
+
if world_size > 1: dist.barrier()
|
| 651 |
+
model.train()
|
| 652 |
+
# --- End Evaluation Block ---
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
if hasattr(train_sampler, 'set_epoch'): # Check if sampler has set_epoch
|
| 658 |
+
train_sampler.set_epoch(bi)
|
| 659 |
+
|
| 660 |
+
current_lr = optimizer.param_groups[-1]['lr']
|
| 661 |
+
|
| 662 |
+
try:
|
| 663 |
+
inputs, targets = next(iterator)
|
| 664 |
+
except StopIteration:
|
| 665 |
+
iterator = iter(trainloader)
|
| 666 |
+
inputs, targets = next(iterator)
|
| 667 |
+
|
| 668 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 669 |
+
targets = targets.to(device, non_blocking=True)
|
| 670 |
+
|
| 671 |
+
# Defaults for logging
|
| 672 |
+
loss = torch.tensor(0.0, device=device) # Need loss defined for logging scope
|
| 673 |
+
accuracy_finegrained = 0.0
|
| 674 |
+
where_most_certain_val = -1.0
|
| 675 |
+
where_most_certain_std = 0.0
|
| 676 |
+
where_most_certain_min = -1
|
| 677 |
+
where_most_certain_max = -1
|
| 678 |
+
upto_where_mean = -1.0
|
| 679 |
+
upto_where_std = 0.0
|
| 680 |
+
upto_where_min = -1
|
| 681 |
+
upto_where_max = -1
|
| 682 |
+
|
| 683 |
+
with torch.autocast(device_type="cuda" if device.type == 'cuda' else "cpu", dtype=torch.float16, enabled=args.use_amp):
|
| 684 |
+
if args.do_compile: torch.compiler.cudagraph_mark_step_begin()
|
| 685 |
+
|
| 686 |
+
if args.model == 'ctm':
|
| 687 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 688 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 689 |
+
loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=True)
|
| 690 |
+
with torch.no_grad(): # Calculate local accuracy for logging
|
| 691 |
+
accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=device), :, where_most_certain] == targets).float().mean().item()
|
| 692 |
+
elif args.model == 'lstm':
|
| 693 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 694 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 695 |
+
loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False) # where = -1
|
| 696 |
+
with torch.no_grad():
|
| 697 |
+
accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=device), :, where_most_certain] == targets).float().mean().item()
|
| 698 |
+
elif args.model == 'ff':
|
| 699 |
+
predictions_raw = model(inputs) # B, S*C
|
| 700 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
|
| 701 |
+
loss, where_most_certain, upto_where = maze_loss(predictions.unsqueeze(-1), None, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False) # where = -1
|
| 702 |
+
with torch.no_grad():
|
| 703 |
+
accuracy_finegrained = (predictions.argmax(2) == targets).float().mean().item()
|
| 704 |
+
|
| 705 |
+
# Extract stats from loss outputs
|
| 706 |
+
if torch.is_tensor(where_most_certain):
|
| 707 |
+
where_most_certain_val = where_most_certain.float().mean().item()
|
| 708 |
+
where_most_certain_std = where_most_certain.float().std().item()
|
| 709 |
+
where_most_certain_min = where_most_certain.min().item()
|
| 710 |
+
where_most_certain_max = where_most_certain.max().item()
|
| 711 |
+
elif isinstance(where_most_certain, int):
|
| 712 |
+
where_most_certain_val = float(where_most_certain); where_most_certain_min = where_most_certain; where_most_certain_max = where_most_certain
|
| 713 |
+
if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
|
| 714 |
+
upto_where_mean = np.mean(upto_where); upto_where_std = np.std(upto_where); upto_where_min = np.min(upto_where); upto_where_max = np.max(upto_where)
|
| 715 |
+
|
| 716 |
+
# Backprop / Step
|
| 717 |
+
scaler.scale(loss).backward()
|
| 718 |
+
if args.gradient_clipping!=-1:
|
| 719 |
+
scaler.unscale_(optimizer)
|
| 720 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
|
| 721 |
+
scaler.step(optimizer)
|
| 722 |
+
scaler.update()
|
| 723 |
+
optimizer.zero_grad(set_to_none=True)
|
| 724 |
+
scheduler.step()
|
| 725 |
+
|
| 726 |
+
# --- Aggregation and Logging (Rank 0) ---
|
| 727 |
+
loss_log = loss.detach()
|
| 728 |
+
if world_size > 1: dist.all_reduce(loss_log, op=dist.ReduceOp.AVG)
|
| 729 |
+
|
| 730 |
+
if is_main_process(rank):
|
| 731 |
+
pbar_desc = f'Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_finegrained:.3f} LR={current_lr:.6f}'
|
| 732 |
+
if args.model in ['ctm', 'lstm'] or torch.is_tensor(where_most_certain):
|
| 733 |
+
pbar_desc += f' Cert={where_most_certain_val:.2f}'#+-{where_most_certain_std:.2f}' # Removed std for brevity
|
| 734 |
+
if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
|
| 735 |
+
pbar_desc += f' Path={upto_where_mean:.1f}'#+-{upto_where_std:.1f}'
|
| 736 |
+
pbar.set_description(f'{args.model.upper()} {pbar_desc}')
|
| 737 |
+
# --- End Aggregation and Logging ---
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
# --- Checkpointing (Rank 0) ---
|
| 744 |
+
if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter and is_main_process(rank):
|
| 745 |
+
pbar.set_description('Rank 0: Saving checkpoint...')
|
| 746 |
+
save_path = f'{args.log_dir}/checkpoint.pt'
|
| 747 |
+
model_state_to_save = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
|
| 748 |
+
|
| 749 |
+
checkpoint_data = {
|
| 750 |
+
'model_state_dict': model_state_to_save,
|
| 751 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 752 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 753 |
+
'scaler_state_dict': scaler.state_dict(),
|
| 754 |
+
'iteration': bi,
|
| 755 |
+
'train_losses': train_losses,
|
| 756 |
+
'test_losses': test_losses,
|
| 757 |
+
'train_accuracies': train_accuracies, # Saving simplified scalar list
|
| 758 |
+
'test_accuracies': test_accuracies, # Saving simplified scalar list
|
| 759 |
+
'train_accuracies_most_certain': train_accuracies_most_certain,
|
| 760 |
+
'test_accuracies_most_certain': test_accuracies_most_certain,
|
| 761 |
+
'train_accuracies_most_certain_permaze': train_accuracies_most_certain_permaze,
|
| 762 |
+
'test_accuracies_most_certain_permaze': test_accuracies_most_certain_permaze,
|
| 763 |
+
'iters': iters,
|
| 764 |
+
'args': args,
|
| 765 |
+
'torch_rng_state': torch.get_rng_state(),
|
| 766 |
+
'numpy_rng_state': np.random.get_state(),
|
| 767 |
+
'random_rng_state': random.getstate(),
|
| 768 |
+
}
|
| 769 |
+
torch.save(checkpoint_data, save_path)
|
| 770 |
+
# --- End Checkpointing ---
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
if world_size > 1: dist.barrier()
|
| 774 |
+
|
| 775 |
+
if is_main_process(rank):
|
| 776 |
+
pbar.update(1)
|
| 777 |
+
# --- End Training Loop ---
|
| 778 |
+
|
| 779 |
+
if is_main_process(rank):
|
| 780 |
+
pbar.close()
|
| 781 |
+
|
| 782 |
+
cleanup_ddp()
|
tasks/parity/README.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Parity
|
| 2 |
+
|
| 3 |
+
## Training
|
| 4 |
+
To run the parity training that we used for the paper, run bash scripts from the root level of the repository. For example, to train the 75-iteration, 25-memory-length CTM, run:
|
| 5 |
+
|
| 6 |
+
```
|
| 7 |
+
bash tasks/parity/scripts/train_ctm_75_25.sh
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
## Analysis
|
| 12 |
+
To run the analysis, first make sure the checkpoints are saved in the log directory (specified by the `log_dir` argument). The checkpoints can be obtained by either running the training code, or downloading them from [this link](https://drive.google.com/file/d/1itUS5_i9AyUo_7awllTx8X0PXYw9fnaG/view?usp=drive_link).
|
| 13 |
+
|
| 14 |
+
```
|
| 15 |
+
python -m tasks.parity.analysis.run --log_dir <PATH_TO_LOG_DIR>
|
| 16 |
+
```
|
tasks/parity/analysis/make_blog_gifs.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
import math
|
| 5 |
+
import imageio
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from matplotlib.patches import FancyArrowPatch
|
| 9 |
+
from scipy.special import softmax
|
| 10 |
+
import matplotlib.cm as cm
|
| 11 |
+
from data.custom_datasets import ParityDataset
|
| 12 |
+
import umap
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from models.utils import reshape_predictions
|
| 17 |
+
from tasks.parity.utils import reshape_inputs
|
| 18 |
+
from tasks.parity.analysis.run import build_model_from_checkpoint_path
|
| 19 |
+
|
| 20 |
+
from tasks.image_classification.plotting import save_frames_to_mp4
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def make_parity_gif(
|
| 24 |
+
predictions,
|
| 25 |
+
targets,
|
| 26 |
+
post_activations,
|
| 27 |
+
attention_weights,
|
| 28 |
+
inputs_to_model,
|
| 29 |
+
save_path,
|
| 30 |
+
umap_positions,
|
| 31 |
+
umap_point_scaler=1.0,
|
| 32 |
+
):
|
| 33 |
+
batch_index = 0
|
| 34 |
+
figscale = 0.32
|
| 35 |
+
n_steps, n_heads, seqLen = attention_weights.shape[:3]
|
| 36 |
+
grid_side = int(np.sqrt(seqLen))
|
| 37 |
+
frames = []
|
| 38 |
+
|
| 39 |
+
inputs_this_batch = inputs_to_model[:, batch_index]
|
| 40 |
+
preds_this_batch = predictions[batch_index]
|
| 41 |
+
targets_this_batch = targets[batch_index]
|
| 42 |
+
post_act_this_batch = post_activations[:, batch_index]
|
| 43 |
+
|
| 44 |
+
# build a flexible mosaic
|
| 45 |
+
mosaic = [
|
| 46 |
+
[f"att_0", f"in_0", "probs", "probs", "target", "target"],
|
| 47 |
+
[f"att_1", f"in_1", "probs", "probs", "target", "target"],
|
| 48 |
+
]
|
| 49 |
+
for h in range(2, n_heads):
|
| 50 |
+
mosaic.append(
|
| 51 |
+
[f"att_{h}", f"in_{h}", "umap", "umap",
|
| 52 |
+
"umap", "umap"]
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
for t in range(n_steps):
|
| 56 |
+
rows = len(mosaic)
|
| 57 |
+
cell_size = figscale * 4
|
| 58 |
+
fig_h = rows * cell_size
|
| 59 |
+
|
| 60 |
+
fig, ax = plt.subplot_mosaic(
|
| 61 |
+
mosaic,
|
| 62 |
+
figsize=(6 * cell_size, fig_h),
|
| 63 |
+
constrained_layout=False,
|
| 64 |
+
gridspec_kw={'wspace': 0.05, 'hspace': 0.05}, # small gaps
|
| 65 |
+
)
|
| 66 |
+
# restore a little margin
|
| 67 |
+
fig.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02)
|
| 68 |
+
|
| 69 |
+
# probabilities heatmap
|
| 70 |
+
logits_t = preds_this_batch[:, :, t]
|
| 71 |
+
probs_t = softmax(logits_t, axis=1)[:, 0].reshape(grid_side, grid_side)
|
| 72 |
+
ax["probs"].imshow(probs_t, cmap="gray", vmin=0, vmax=1)
|
| 73 |
+
ax["probs"].axis("off")
|
| 74 |
+
|
| 75 |
+
# target overlay
|
| 76 |
+
ax["target"].imshow(
|
| 77 |
+
targets_this_batch.reshape(grid_side, grid_side),
|
| 78 |
+
cmap="gray_r", vmin=0, vmax=1
|
| 79 |
+
)
|
| 80 |
+
ax["target"].axis("off")
|
| 81 |
+
ax["target"].grid(which="minor", color="black", linestyle="-", linewidth=0.5)
|
| 82 |
+
|
| 83 |
+
z = post_act_this_batch[t]
|
| 84 |
+
low, high = np.percentile(z, 5), np.percentile(z, 95)
|
| 85 |
+
z_norm = np.clip((z - low) / (high - low), 0, 1)
|
| 86 |
+
point_sizes = (np.abs(z_norm - 0.5) * 100 + 5) * umap_point_scaler
|
| 87 |
+
cmap = plt.get_cmap("Spectral")
|
| 88 |
+
ax["umap"].scatter(
|
| 89 |
+
umap_positions[:, 0],
|
| 90 |
+
umap_positions[:, 1],
|
| 91 |
+
s=point_sizes,
|
| 92 |
+
c=cmap(z_norm),
|
| 93 |
+
alpha=0.8
|
| 94 |
+
)
|
| 95 |
+
ax["umap"].axis("off")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# normalize attention
|
| 99 |
+
att_t = attention_weights[t, :, :]
|
| 100 |
+
a_min, a_max = att_t.min(), att_t.max()
|
| 101 |
+
if not np.isclose(a_min, a_max):
|
| 102 |
+
att_t = (att_t - a_min) / (a_max - a_min + 1e-8)
|
| 103 |
+
else:
|
| 104 |
+
att_t = np.zeros_like(att_t)
|
| 105 |
+
|
| 106 |
+
# input image for arrows
|
| 107 |
+
img_t = inputs_this_batch[t].transpose(1, 2, 0)
|
| 108 |
+
|
| 109 |
+
if t == 0:
|
| 110 |
+
route_history = [[] for _ in range(n_heads)]
|
| 111 |
+
|
| 112 |
+
img_h, img_w = img_t.shape[:2]
|
| 113 |
+
cell_h = img_h // grid_side
|
| 114 |
+
cell_w = img_w // grid_side
|
| 115 |
+
|
| 116 |
+
for h in range(n_heads):
|
| 117 |
+
head_map = att_t[h].reshape(grid_side, grid_side)
|
| 118 |
+
ax[f"att_{h}"].imshow(head_map, cmap="viridis", vmin=0, vmax=1)
|
| 119 |
+
ax[f"att_{h}"].axis("off")
|
| 120 |
+
ax[f"in_{h}"].imshow(img_t, cmap="gray", vmin=0, vmax=1)
|
| 121 |
+
ax[f"in_{h}"].axis("off")
|
| 122 |
+
|
| 123 |
+
# track argmax center
|
| 124 |
+
flat_idx = np.argmax(head_map)
|
| 125 |
+
gy, gx = divmod(flat_idx, grid_side)
|
| 126 |
+
cx = int((gx + 0.5) * cell_w)
|
| 127 |
+
cy = int((gy + 0.5) * cell_h)
|
| 128 |
+
route_history[h].append((cx, cy))
|
| 129 |
+
|
| 130 |
+
cmap_steps = plt.colormaps.get_cmap("Spectral")
|
| 131 |
+
colors = [cmap_steps(i / (n_steps - 1)) for i in range(n_steps)]
|
| 132 |
+
for i in range(len(route_history[h]) - 1):
|
| 133 |
+
x0, y0 = route_history[h][i]
|
| 134 |
+
x1, y1 = route_history[h][i + 1]
|
| 135 |
+
color = colors[i]
|
| 136 |
+
is_last = (i == len(route_history[h]) - 2)
|
| 137 |
+
style = '->' if is_last else '-'
|
| 138 |
+
lw = 2.0 if is_last else 1.6
|
| 139 |
+
alpha = 1.0 if is_last else 0.9
|
| 140 |
+
scale = 10 if is_last else 1
|
| 141 |
+
|
| 142 |
+
# draw arrow
|
| 143 |
+
arr = FancyArrowPatch(
|
| 144 |
+
(x0, y0), (x1, y1),
|
| 145 |
+
arrowstyle=style,
|
| 146 |
+
linewidth=lw,
|
| 147 |
+
mutation_scale=scale,
|
| 148 |
+
alpha=alpha,
|
| 149 |
+
facecolor=color,
|
| 150 |
+
edgecolor=color,
|
| 151 |
+
shrinkA=0, shrinkB=0,
|
| 152 |
+
capstyle='round', joinstyle='round',
|
| 153 |
+
zorder=3 if is_last else 2,
|
| 154 |
+
clip_on=False,
|
| 155 |
+
)
|
| 156 |
+
ax[f"in_{h}"].add_patch(arr)
|
| 157 |
+
|
| 158 |
+
ax[f"in_{h}"].scatter(
|
| 159 |
+
x1, y1,
|
| 160 |
+
marker='x',
|
| 161 |
+
s=40,
|
| 162 |
+
color=color,
|
| 163 |
+
linewidths=lw,
|
| 164 |
+
zorder=4
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
canvas = fig.canvas
|
| 168 |
+
canvas.draw()
|
| 169 |
+
frame = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
|
| 170 |
+
w, h = canvas.get_width_height()
|
| 171 |
+
frames.append(frame.reshape(h, w, 4)[..., :3])
|
| 172 |
+
plt.close(fig)
|
| 173 |
+
|
| 174 |
+
# save gif
|
| 175 |
+
imageio.mimsave(f"{save_path}/activation.gif", frames, fps=15, loop=0)
|
| 176 |
+
|
| 177 |
+
# save mp4
|
| 178 |
+
save_frames_to_mp4(
|
| 179 |
+
[fm[:, :, ::-1] for fm in frames], # RGB→BGR
|
| 180 |
+
f"{save_path}/activation.mp4",
|
| 181 |
+
fps=15,
|
| 182 |
+
gop_size=1,
|
| 183 |
+
preset="slow"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def run_umap(model, testloader):
|
| 187 |
+
all_post_activations = []
|
| 188 |
+
point_counts = 150
|
| 189 |
+
sampled = 0
|
| 190 |
+
with tqdm(total=point_counts, desc="Collecting UMAP data") as pbar:
|
| 191 |
+
for inputs, _ in testloader:
|
| 192 |
+
for i in range(inputs.size(0)):
|
| 193 |
+
if sampled >= point_counts:
|
| 194 |
+
break
|
| 195 |
+
input_i = inputs[i].unsqueeze(0).to(device)
|
| 196 |
+
_, _, _, _, post_activations, _ = model(input_i, track=True)
|
| 197 |
+
all_post_activations.append(post_activations)
|
| 198 |
+
sampled += 1
|
| 199 |
+
pbar.update(1)
|
| 200 |
+
if sampled >= point_counts:
|
| 201 |
+
break
|
| 202 |
+
|
| 203 |
+
stacked = np.stack(all_post_activations, 1)
|
| 204 |
+
umap_features = stacked.reshape(-1, stacked.shape[-1])
|
| 205 |
+
reducer = umap.UMAP(
|
| 206 |
+
n_components=2,
|
| 207 |
+
n_neighbors=20,
|
| 208 |
+
min_dist=1,
|
| 209 |
+
spread=1,
|
| 210 |
+
metric='cosine',
|
| 211 |
+
local_connectivity=1
|
| 212 |
+
)
|
| 213 |
+
positions = reducer.fit_transform(umap_features.T)
|
| 214 |
+
return positions
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def run_model_and_make_gif(checkpoint_path, save_path, device):
|
| 218 |
+
|
| 219 |
+
parity_sequence_length = 64
|
| 220 |
+
iterations = 75
|
| 221 |
+
|
| 222 |
+
test_data = ParityDataset(sequence_length=parity_sequence_length, length=10000)
|
| 223 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=True, num_workers=0, drop_last=False)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
model, _ = build_model_from_checkpoint_path(checkpoint_path, "ctm", device=device)
|
| 227 |
+
|
| 228 |
+
input = torch.randint(0, 2, (64,), dtype=torch.float32, device=device) * 2 - 1
|
| 229 |
+
input = input.unsqueeze(0)
|
| 230 |
+
|
| 231 |
+
target = torch.cumsum((input == -1).to(torch.long), dim=1) % 2
|
| 232 |
+
target = target.unsqueeze(0)
|
| 233 |
+
|
| 234 |
+
positions = run_umap(model, testloader)
|
| 235 |
+
|
| 236 |
+
model.eval()
|
| 237 |
+
with torch.inference_mode():
|
| 238 |
+
predictions, _, _, _, post_activations, attention = model(input, track=True)
|
| 239 |
+
predictons = reshape_predictions(predictions, prediction_reshaper=[parity_sequence_length, 2])
|
| 240 |
+
input_images = reshape_inputs(input, iterations, grid_size=int(math.sqrt(parity_sequence_length)))
|
| 241 |
+
|
| 242 |
+
make_parity_gif(
|
| 243 |
+
predictions=predictons.detach().cpu().numpy(),
|
| 244 |
+
targets=target.detach().cpu().numpy(),
|
| 245 |
+
post_activations=post_activations,
|
| 246 |
+
attention_weights=attention.squeeze(1).squeeze(2),
|
| 247 |
+
inputs_to_model=input_images,
|
| 248 |
+
save_path=save_path,
|
| 249 |
+
umap_positions=positions,
|
| 250 |
+
umap_point_scaler=1.0,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
if __name__ == "__main__":
|
| 256 |
+
|
| 257 |
+
CHECKPOINT_PATH = "checkpoints/parity/run1/ctm_75_25/checkpoint_200000.pt"
|
| 258 |
+
SAVE_PATH = f"tasks/parity/analysis/outputs/blog_gifs/"
|
| 259 |
+
os.makedirs(SAVE_PATH, exist_ok=True)
|
| 260 |
+
|
| 261 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 262 |
+
|
| 263 |
+
run_model_and_make_gif(CHECKPOINT_PATH, SAVE_PATH, device)
|
tasks/parity/analysis/run.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import argparse
|
| 4 |
+
import multiprocessing
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import csv
|
| 9 |
+
from utils.housekeeping import set_seed
|
| 10 |
+
from data.custom_datasets import ParityDataset
|
| 11 |
+
from tasks.parity.utils import prepare_model, reshape_attention_weights, reshape_inputs, get_where_most_certain
|
| 12 |
+
from tasks.parity.plotting import plot_attention_trajectory, plot_input, plot_target, plot_probabilities, plot_prediction, plot_accuracy_training, create_attentions_heatmap_gif, create_accuracies_heatmap_gif, create_stacked_gif, plot_training_curve_all_runs, plot_accuracy_thinking_time, make_parity_gif, plot_lstm_last_and_certain_accuracy
|
| 13 |
+
from models.utils import compute_normalized_entropy, reshape_predictions, get_latest_checkpoint_file, get_checkpoint_files, load_checkpoint, get_model_args_from_checkpoint, get_all_log_dirs
|
| 14 |
+
from tasks.image_classification.plotting import plot_neural_dynamics
|
| 15 |
+
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
sns.set_palette("hls")
|
| 18 |
+
sns.set_style('darkgrid')
|
| 19 |
+
|
| 20 |
+
def parse_args():
|
| 21 |
+
parser = argparse.ArgumentParser(description='Parity Analysis')
|
| 22 |
+
parser.add_argument('--log_dir', type=str, default='checkpoints/parity', help='Directory to save logs.')
|
| 23 |
+
parser.add_argument('--batch_size_test', type=int, default=128, help='batch size for testing')
|
| 24 |
+
parser.add_argument('--scale_training_curve', type=float, default=0.6, help='Scaling factor for plots.')
|
| 25 |
+
parser.add_argument('--scale_heatmap', type=float, default=0.4, help='Scaling factor for heatmap plots.')
|
| 26 |
+
parser.add_argument('--scale_training_index_accuracy', type=float, default=0.4, help='Scaling factor for training index accuracy plots.')
|
| 27 |
+
parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility.')
|
| 28 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
|
| 29 |
+
parser.add_argument('--model_type', type=str, choices=['ctm', 'lstm'], default='ctm', help='Type of model to analyze (ctm or lstm).')
|
| 30 |
+
return parser.parse_args()
|
| 31 |
+
|
| 32 |
+
def calculate_corrects(predictions, targets):
|
| 33 |
+
predicted_labels = predictions.argmax(2)
|
| 34 |
+
accuracy = (predicted_labels == targets.unsqueeze(-1))
|
| 35 |
+
return accuracy.detach().cpu().numpy()
|
| 36 |
+
|
| 37 |
+
def get_corrects_per_element_at_most_certain_time(predictions, certainty, targets):
|
| 38 |
+
where_most_certain = get_where_most_certain(certainty)
|
| 39 |
+
corrects = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device),:,where_most_certain] == targets).float()
|
| 40 |
+
return corrects.detach().cpu().numpy()
|
| 41 |
+
|
| 42 |
+
def calculate_entropy_average_over_batch(normalized_entropy_per_elements):
|
| 43 |
+
normalized_entropy_per_elements_avg_batch = normalized_entropy_per_elements.mean(axis=1)
|
| 44 |
+
return normalized_entropy_per_elements_avg_batch
|
| 45 |
+
|
| 46 |
+
def calculate_thinking_time_average_over_batch(normalized_entropy_per_elements):
|
| 47 |
+
first_occurrence = calculate_thinking_time(normalized_entropy_per_elements)
|
| 48 |
+
average_thinking_time = np.mean(first_occurrence, axis=0)
|
| 49 |
+
return average_thinking_time
|
| 50 |
+
|
| 51 |
+
def calculate_thinking_time(normalized_entropy_per_elements, finish_type="min", entropy_threshold=0.1):
|
| 52 |
+
if finish_type == "min":
|
| 53 |
+
min_entropy_time = np.argmin(normalized_entropy_per_elements, axis=0)
|
| 54 |
+
return min_entropy_time
|
| 55 |
+
elif finish_type == "threshold":
|
| 56 |
+
T, B, S = normalized_entropy_per_elements.shape
|
| 57 |
+
below_threshold = normalized_entropy_per_elements < entropy_threshold
|
| 58 |
+
first_occurrence = np.argmax(below_threshold, axis=0)
|
| 59 |
+
no_true = ~np.any(below_threshold, axis=0)
|
| 60 |
+
first_occurrence[no_true] = T
|
| 61 |
+
return first_occurrence
|
| 62 |
+
|
| 63 |
+
def test_handcrafted_examples(model, args, run_model_spefic_save_dir, device):
|
| 64 |
+
test_cases = []
|
| 65 |
+
all_even_input = torch.full((args.parity_sequence_length,), 1.0, dtype=torch.float32, device=device)
|
| 66 |
+
all_even_target = torch.zeros_like(all_even_input, dtype=torch.long)
|
| 67 |
+
test_cases.append((all_even_input, all_even_target))
|
| 68 |
+
|
| 69 |
+
all_odd_input = torch.full((args.parity_sequence_length,), -1.0, dtype=torch.float32, device=device)
|
| 70 |
+
all_odd_target = torch.cumsum((all_odd_input == -1).to(torch.long), dim=0) % 2
|
| 71 |
+
test_cases.append((all_odd_input, all_odd_target))
|
| 72 |
+
|
| 73 |
+
random_input = torch.randint(0, 2, (args.parity_sequence_length,), dtype=torch.float32, device=device) * 2 - 1
|
| 74 |
+
random_target = torch.cumsum((random_input == -1).to(torch.long), dim=0) % 2
|
| 75 |
+
test_cases.append((random_input, random_target))
|
| 76 |
+
|
| 77 |
+
for i, (inputs, targets) in enumerate(test_cases):
|
| 78 |
+
inputs = inputs.unsqueeze(0)
|
| 79 |
+
targets = targets.unsqueeze(0)
|
| 80 |
+
filename = f"eval_handcrafted_{i}"
|
| 81 |
+
extend_inference_time = False
|
| 82 |
+
handcraft_dir = f"{run_model_spefic_save_dir}/handcrafted_examples/{i}"
|
| 83 |
+
os.makedirs(handcraft_dir, exist_ok=True)
|
| 84 |
+
|
| 85 |
+
model.eval()
|
| 86 |
+
with torch.inference_mode():
|
| 87 |
+
if extend_inference_time:
|
| 88 |
+
model.iterations = model.iterations * 2
|
| 89 |
+
predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
|
| 90 |
+
predictions = reshape_predictions(predictions, prediction_reshaper=[args.parity_sequence_length, 2])
|
| 91 |
+
input_images = reshape_inputs(inputs, args.iterations, grid_size=int(math.sqrt(args.parity_sequence_length)))
|
| 92 |
+
|
| 93 |
+
plot_neural_dynamics(post_activations, 100, handcraft_dir, axis_snap=False)
|
| 94 |
+
|
| 95 |
+
process = multiprocessing.Process(
|
| 96 |
+
target=make_parity_gif,
|
| 97 |
+
args=(
|
| 98 |
+
predictions.detach().cpu().numpy(),
|
| 99 |
+
certainties.detach().cpu().numpy(),
|
| 100 |
+
targets.detach().cpu().numpy(),
|
| 101 |
+
pre_activations,
|
| 102 |
+
post_activations,
|
| 103 |
+
reshape_attention_weights(attention),
|
| 104 |
+
input_images,
|
| 105 |
+
f"{handcraft_dir}/eval_output_val_{0}_iter_{0}.gif",
|
| 106 |
+
))
|
| 107 |
+
process.start()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
input_images = input_images.squeeze(1).squeeze(1)
|
| 111 |
+
attention = attention.squeeze(1)
|
| 112 |
+
|
| 113 |
+
for h in range(args.heads):
|
| 114 |
+
plot_attention_trajectory(attention[:, h, :, :], certainties, input_images, handcraft_dir, filename + f"_head_{h}", args)
|
| 115 |
+
|
| 116 |
+
plot_attention_trajectory(attention.mean(1), certainties, input_images, handcraft_dir, filename, args)
|
| 117 |
+
plot_input(input_images, handcraft_dir, filename)
|
| 118 |
+
plot_target(targets, handcraft_dir, filename, args)
|
| 119 |
+
plot_probabilities(predictions, certainties, handcraft_dir, filename, args)
|
| 120 |
+
plot_prediction(predictions, certainties,handcraft_dir, filename, args)
|
| 121 |
+
|
| 122 |
+
if extend_inference_time:
|
| 123 |
+
model.iterations = model.iterations // 2
|
| 124 |
+
model.train()
|
| 125 |
+
pass
|
| 126 |
+
|
| 127 |
+
def build_model_from_checkpoint_path(checkpoint_path, model_type, device="cpu"):
|
| 128 |
+
checkpoint = load_checkpoint(checkpoint_path, device)
|
| 129 |
+
model_args = get_model_args_from_checkpoint(checkpoint)
|
| 130 |
+
model = prepare_model([model_args.parity_sequence_length, 2], model_args, device)
|
| 131 |
+
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
| 132 |
+
return model, model_args
|
| 133 |
+
|
| 134 |
+
def analyze_trained_model(run_model_spefic_save_dir, args, device):
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
|
| 137 |
+
latest_checkpoint_path = get_latest_checkpoint_file(args.log_dir)
|
| 138 |
+
model, model_args = build_model_from_checkpoint_path(latest_checkpoint_path, args.model_type, device=device)
|
| 139 |
+
model.eval()
|
| 140 |
+
model_args.log_dir = args.log_dir
|
| 141 |
+
test_data = ParityDataset(sequence_length=model_args.parity_sequence_length, length=10000)
|
| 142 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=0, drop_last=False)
|
| 143 |
+
|
| 144 |
+
corrects, corrects_at_most_certain_times, entropys, attentions = [], [], [], []
|
| 145 |
+
|
| 146 |
+
for inputs, targets in testloader:
|
| 147 |
+
inputs = inputs.to(device)
|
| 148 |
+
targets = targets.to(device)
|
| 149 |
+
predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
|
| 150 |
+
predictions = reshape_predictions(predictions, prediction_reshaper=[model_args.parity_sequence_length, 2])
|
| 151 |
+
corrects_batch = calculate_corrects(predictions, targets)
|
| 152 |
+
corrects_at_most_certain_time_batch = get_corrects_per_element_at_most_certain_time(predictions, certainties, targets)
|
| 153 |
+
corrects.append(corrects_batch)
|
| 154 |
+
corrects_at_most_certain_times.append(corrects_at_most_certain_time_batch)
|
| 155 |
+
attentions.append(attention)
|
| 156 |
+
|
| 157 |
+
test_handcrafted_examples(model, model_args, run_model_spefic_save_dir, device)
|
| 158 |
+
|
| 159 |
+
overall_mean_accuracy = np.mean(np.vstack(corrects_at_most_certain_times))
|
| 160 |
+
overall_std_accuracy = np.std(np.mean(np.vstack(corrects_at_most_certain_times), axis=1))
|
| 161 |
+
|
| 162 |
+
return overall_mean_accuracy, overall_std_accuracy, model_args.iterations
|
| 163 |
+
|
| 164 |
+
def analyze_training(run_model_spefic_save_dir, args, device):
|
| 165 |
+
checkpoint_files = get_checkpoint_files(args.log_dir)
|
| 166 |
+
all_accuracies = []
|
| 167 |
+
all_accuracies_at_most_certain_time = []
|
| 168 |
+
all_average_thinking_times = []
|
| 169 |
+
all_std_thinking_times = []
|
| 170 |
+
all_attentions = []
|
| 171 |
+
for checkpoint_path in checkpoint_files:
|
| 172 |
+
model, model_args = build_model_from_checkpoint_path(checkpoint_path, args.model_type, device=device)
|
| 173 |
+
model_args.log_dir = run_model_spefic_save_dir
|
| 174 |
+
test_data = ParityDataset(sequence_length=model_args.parity_sequence_length, length=1000)
|
| 175 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=0, drop_last=False)
|
| 176 |
+
corrects = []
|
| 177 |
+
corrects_at_most_certain_times = []
|
| 178 |
+
thinking_times = []
|
| 179 |
+
attentions = []
|
| 180 |
+
|
| 181 |
+
for inputs, targets in testloader:
|
| 182 |
+
inputs = inputs.to(device)
|
| 183 |
+
targets = targets.to(device)
|
| 184 |
+
predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
|
| 185 |
+
predictions = reshape_predictions(predictions, prediction_reshaper=[model_args.parity_sequence_length, 2])
|
| 186 |
+
attention = reshape_attention_weights(attention)
|
| 187 |
+
|
| 188 |
+
corrects_batch = calculate_corrects(predictions, targets)
|
| 189 |
+
corrects_at_most_certain_time_batch = get_corrects_per_element_at_most_certain_time(predictions, certainties, targets)
|
| 190 |
+
entropy_per_element = compute_normalized_entropy(predictions.permute(0,3,1,2), reduction='none').detach().cpu().numpy()
|
| 191 |
+
thinking_times_batch = np.argmin(entropy_per_element, axis=1)
|
| 192 |
+
|
| 193 |
+
corrects.append(corrects_batch)
|
| 194 |
+
corrects_at_most_certain_times.append(corrects_at_most_certain_time_batch)
|
| 195 |
+
thinking_times.append(thinking_times_batch)
|
| 196 |
+
attentions.append(attention)
|
| 197 |
+
|
| 198 |
+
checkpoint_average_accuracies = np.mean(np.concatenate(corrects, axis=0), axis=0).transpose(1,0)
|
| 199 |
+
all_accuracies.append(checkpoint_average_accuracies)
|
| 200 |
+
|
| 201 |
+
stacked_corrects_at_most_certain_times = np.vstack(corrects_at_most_certain_times)
|
| 202 |
+
checkpoint_average_accuracy_at_most_certain_time = np.mean(stacked_corrects_at_most_certain_times, axis=0)
|
| 203 |
+
all_accuracies_at_most_certain_time.append(checkpoint_average_accuracy_at_most_certain_time)
|
| 204 |
+
|
| 205 |
+
checkpoint_thinking_times = np.concatenate(thinking_times, axis=0)
|
| 206 |
+
checkpoint_average_thinking_time = np.mean(checkpoint_thinking_times, axis=0)
|
| 207 |
+
checkpoint_std_thinking_time = np.std(checkpoint_thinking_times, axis=0)
|
| 208 |
+
all_average_thinking_times.append(checkpoint_average_thinking_time)
|
| 209 |
+
all_std_thinking_times.append(checkpoint_std_thinking_time)
|
| 210 |
+
|
| 211 |
+
checkpoint_average_attentions = np.mean(np.concatenate(attentions, axis=1), axis=1)
|
| 212 |
+
all_attentions.append(checkpoint_average_attentions)
|
| 213 |
+
|
| 214 |
+
plot_accuracy_training(all_accuracies_at_most_certain_time, args.scale_training_index_accuracy, run_model_spefic_save_dir, args=model_args)
|
| 215 |
+
create_attentions_heatmap_gif(all_attentions, args.scale_heatmap, run_model_spefic_save_dir, model_args)
|
| 216 |
+
create_accuracies_heatmap_gif(np.array(all_accuracies), all_average_thinking_times, all_std_thinking_times, args.scale_heatmap, run_model_spefic_save_dir, model_args)
|
| 217 |
+
create_stacked_gif(run_model_spefic_save_dir)
|
| 218 |
+
|
| 219 |
+
def get_accuracy_and_loss_from_checkpoint(checkpoint):
|
| 220 |
+
training_iteration = checkpoint.get('training_iteration', 0)
|
| 221 |
+
train_losses = checkpoint.get('train_losses', [])
|
| 222 |
+
test_losses = checkpoint.get('test_losses', [])
|
| 223 |
+
train_accuracies = checkpoint.get('train_accuracies_most_certain', [])
|
| 224 |
+
test_accuracies = checkpoint.get('test_accuracies_most_certain', [])
|
| 225 |
+
return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__":
|
| 228 |
+
|
| 229 |
+
args = parse_args()
|
| 230 |
+
|
| 231 |
+
device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
|
| 232 |
+
|
| 233 |
+
set_seed(args.seed)
|
| 234 |
+
|
| 235 |
+
save_dir = "tasks/parity/analysis/outputs"
|
| 236 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 237 |
+
|
| 238 |
+
accuracy_csv_file_path = os.path.join(save_dir, "accuracy.csv")
|
| 239 |
+
if os.path.exists(accuracy_csv_file_path):
|
| 240 |
+
os.remove(accuracy_csv_file_path)
|
| 241 |
+
|
| 242 |
+
all_runs_log_dirs = get_all_log_dirs(args.log_dir)
|
| 243 |
+
|
| 244 |
+
plot_training_curve_all_runs(all_runs_log_dirs, save_dir, args.scale_training_curve, device, x_max=200_000)
|
| 245 |
+
plot_lstm_last_and_certain_accuracy(all_folders=all_runs_log_dirs, save_path=f"{save_dir}/lstm_final_vs_certain_accuracy.png", scale=args.scale_training_curve)
|
| 246 |
+
|
| 247 |
+
progress_bar = tqdm(all_runs_log_dirs, desc="Analyzing Runs", dynamic_ncols=True)
|
| 248 |
+
for folder in progress_bar:
|
| 249 |
+
|
| 250 |
+
run, model_name = folder.strip("/").split("/")[-2:]
|
| 251 |
+
|
| 252 |
+
run_model_spefic_save_dir = f"{save_dir}/{model_name}/{run}"
|
| 253 |
+
os.makedirs(run_model_spefic_save_dir, exist_ok=True)
|
| 254 |
+
|
| 255 |
+
args.log_dir = folder
|
| 256 |
+
progress_bar.set_description(f"Analyzing Trained Model at {folder}")
|
| 257 |
+
|
| 258 |
+
accuracy_mean, accuracy_std, num_iterations = analyze_trained_model(run_model_spefic_save_dir, args, device)
|
| 259 |
+
|
| 260 |
+
with open(accuracy_csv_file_path, mode='a', newline='') as file:
|
| 261 |
+
writer = csv.writer(file)
|
| 262 |
+
if file.tell() == 0:
|
| 263 |
+
writer.writerow(["Run", "Overall Mean Accuracy", "Overall Std Accuracy", "Num Iterations"])
|
| 264 |
+
writer.writerow([folder, accuracy_mean, accuracy_std, num_iterations])
|
| 265 |
+
|
| 266 |
+
progress_bar.set_description(f"Analyzing Training at {folder}")
|
| 267 |
+
analyze_training(run_model_spefic_save_dir, args, device)
|
| 268 |
+
|
| 269 |
+
plot_accuracy_thinking_time(accuracy_csv_file_path, scale=args.scale_training_curve, output_dir=save_dir)
|
tasks/parity/plotting.py
ADDED
|
@@ -0,0 +1,897 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from matplotlib.lines import Line2D
|
| 7 |
+
import matplotlib as mpl
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import matplotlib.patheffects as path_effects
|
| 10 |
+
from matplotlib.ticker import FuncFormatter
|
| 11 |
+
from scipy.special import softmax
|
| 12 |
+
import imageio.v2 as imageio
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import math
|
| 15 |
+
import re
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
sns.set_style('darkgrid')
|
| 18 |
+
mpl.use('Agg')
|
| 19 |
+
|
| 20 |
+
from tasks.parity.utils import get_where_most_certain, parse_folder_name
|
| 21 |
+
from models.utils import get_latest_checkpoint_file, load_checkpoint, get_model_args_from_checkpoint, get_accuracy_and_loss_from_checkpoint
|
| 22 |
+
from tasks.image_classification.plotting import save_frames_to_mp4
|
| 23 |
+
|
| 24 |
+
def make_parity_gif(predictions, certainties, targets, pre_activations, post_activations, attention_weights, inputs_to_model, filename):
|
| 25 |
+
|
| 26 |
+
# Config
|
| 27 |
+
batch_index = 0
|
| 28 |
+
n_neurons_to_visualise = 16
|
| 29 |
+
figscale = 0.28
|
| 30 |
+
n_steps = len(pre_activations)
|
| 31 |
+
frames = []
|
| 32 |
+
heatmap_cmap = sns.color_palette("viridis", as_cmap=True)
|
| 33 |
+
|
| 34 |
+
these_pre_acts = pre_activations[:, batch_index, :] # Shape: (T, H)
|
| 35 |
+
these_post_acts = post_activations[:, batch_index, :] # Shape: (T, H)
|
| 36 |
+
these_inputs = inputs_to_model[:, batch_index, :, :, :] # Shape: (T, C, H, W)
|
| 37 |
+
these_predictions = predictions[batch_index, :, :, :] # Shape: (d, C, T)
|
| 38 |
+
these_certainties = certainties[batch_index, :, :] # Shape: (C, T)
|
| 39 |
+
these_attention_weights = attention_weights[:, batch_index, :, :]
|
| 40 |
+
|
| 41 |
+
# Create mosaic layout
|
| 42 |
+
mosaic = [['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'target', 'target'] for _ in range(2)] + \
|
| 43 |
+
[['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'target', 'target'] for _ in range(2)] + \
|
| 44 |
+
[['certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty']] + \
|
| 45 |
+
[[f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}'] for ti in range(n_neurons_to_visualise)]
|
| 46 |
+
|
| 47 |
+
for stepi in tqdm(range(n_steps), desc="Processing steps", unit="step"):
|
| 48 |
+
fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale))
|
| 49 |
+
|
| 50 |
+
# Plot predictions
|
| 51 |
+
d = these_predictions.shape[0]
|
| 52 |
+
grid_side = int(np.sqrt(d))
|
| 53 |
+
logits = these_predictions[:, :, stepi]
|
| 54 |
+
|
| 55 |
+
probs = softmax(logits, axis=1)
|
| 56 |
+
probs_grid = probs[:, 0].reshape(grid_side, grid_side)
|
| 57 |
+
axes_gif["probs"].imshow(probs_grid, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
|
| 58 |
+
axes_gif["probs"].axis('off')
|
| 59 |
+
axes_gif["probs"].set_title('Probabilties')
|
| 60 |
+
|
| 61 |
+
# Create and show attention heatmap
|
| 62 |
+
this_input_gate = these_attention_weights[stepi]
|
| 63 |
+
gate_min, gate_max = np.nanmin(this_input_gate), np.nanmax(this_input_gate)
|
| 64 |
+
if not np.isclose(gate_min, gate_max):
|
| 65 |
+
normalized_gate = (this_input_gate - gate_min) / (gate_max - gate_min + 1e-8)
|
| 66 |
+
else:
|
| 67 |
+
normalized_gate = np.zeros_like(this_input_gate)
|
| 68 |
+
attention_weights_heatmap = heatmap_cmap(normalized_gate)[:,:,:3]
|
| 69 |
+
|
| 70 |
+
# Show heatmaps
|
| 71 |
+
axes_gif['attention'].imshow(attention_weights_heatmap, vmin=0, vmax=1)
|
| 72 |
+
axes_gif['attention'].axis('off')
|
| 73 |
+
axes_gif['attention'].set_title('Attention')
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Plot target
|
| 77 |
+
target_grid = targets[batch_index].reshape(grid_side, grid_side)
|
| 78 |
+
axes_gif["target"].imshow(target_grid, cmap='viridis_r', interpolation='nearest', vmin=0, vmax=1)
|
| 79 |
+
axes_gif["target"].axis('off')
|
| 80 |
+
axes_gif["target"].set_title('Target')
|
| 81 |
+
|
| 82 |
+
# Add certainty plot
|
| 83 |
+
axes_gif['certainty'].plot(np.arange(n_steps), these_certainties[1], 'k-', linewidth=2)
|
| 84 |
+
axes_gif['certainty'].set_xlim([0, n_steps-1])
|
| 85 |
+
axes_gif['certainty'].axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
|
| 86 |
+
axes_gif['certainty'].set_xticklabels([])
|
| 87 |
+
axes_gif['certainty'].set_yticklabels([])
|
| 88 |
+
axes_gif['certainty'].grid(False)
|
| 89 |
+
|
| 90 |
+
# Plot neuron traces
|
| 91 |
+
for neuroni in range(n_neurons_to_visualise):
|
| 92 |
+
ax = axes_gif[f'trace_{neuroni}']
|
| 93 |
+
|
| 94 |
+
pre_activation = these_pre_acts[:, neuroni]
|
| 95 |
+
post_activation = these_post_acts[:, neuroni]
|
| 96 |
+
|
| 97 |
+
ax_pre = ax.twinx()
|
| 98 |
+
|
| 99 |
+
pre_min, pre_max = np.min(pre_activation), np.max(pre_activation)
|
| 100 |
+
post_min, post_max = np.min(post_activation), np.max(post_activation)
|
| 101 |
+
|
| 102 |
+
ax_pre.plot(np.arange(n_steps), pre_activation,
|
| 103 |
+
color='grey',
|
| 104 |
+
linestyle='--',
|
| 105 |
+
linewidth=1,
|
| 106 |
+
alpha=0.4,
|
| 107 |
+
label='Pre-activation')
|
| 108 |
+
|
| 109 |
+
color = 'blue' if neuroni % 2 else 'red'
|
| 110 |
+
ax.plot(np.arange(n_steps), post_activation,
|
| 111 |
+
color=color,
|
| 112 |
+
linestyle='-',
|
| 113 |
+
linewidth=2,
|
| 114 |
+
alpha=1.0,
|
| 115 |
+
label='Post-activation')
|
| 116 |
+
|
| 117 |
+
ax.set_xlim([0, n_steps-1])
|
| 118 |
+
ax_pre.set_xlim([0, n_steps-1])
|
| 119 |
+
|
| 120 |
+
if pre_min != pre_max:
|
| 121 |
+
ax_pre.set_ylim([pre_min, pre_max])
|
| 122 |
+
if post_min != post_max:
|
| 123 |
+
ax.set_ylim([post_min, post_max])
|
| 124 |
+
|
| 125 |
+
ax.axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
|
| 126 |
+
|
| 127 |
+
ax.set_xticklabels([])
|
| 128 |
+
ax.set_yticklabels([])
|
| 129 |
+
ax.grid(False)
|
| 130 |
+
|
| 131 |
+
ax_pre.set_xticklabels([])
|
| 132 |
+
ax_pre.set_yticklabels([])
|
| 133 |
+
ax_pre.grid(False)
|
| 134 |
+
|
| 135 |
+
# Show input image
|
| 136 |
+
this_image = these_inputs[stepi].transpose(1, 2, 0)
|
| 137 |
+
axes_gif['img_data'].imshow(this_image, cmap='viridis', vmin=0, vmax=1)
|
| 138 |
+
axes_gif['img_data'].grid(False)
|
| 139 |
+
axes_gif['img_data'].set_xticks([])
|
| 140 |
+
axes_gif['img_data'].set_yticks([])
|
| 141 |
+
axes_gif['img_data'].set_title('Input')
|
| 142 |
+
|
| 143 |
+
# Save frames
|
| 144 |
+
fig_gif.tight_layout(pad=0.1)
|
| 145 |
+
if stepi == 0:
|
| 146 |
+
fig_gif.savefig(filename.split('.gif')[0]+'_frame0.png', dpi=100)
|
| 147 |
+
if stepi == 1:
|
| 148 |
+
fig_gif.savefig(filename.split('.gif')[0]+'_frame1.png', dpi=100)
|
| 149 |
+
if stepi == n_steps-1:
|
| 150 |
+
fig_gif.savefig(filename.split('.gif')[0]+'_frame-1.png', dpi=100)
|
| 151 |
+
|
| 152 |
+
# Convert to frame
|
| 153 |
+
canvas = fig_gif.canvas
|
| 154 |
+
canvas.draw()
|
| 155 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 156 |
+
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3]
|
| 157 |
+
frames.append(image_numpy)
|
| 158 |
+
plt.close(fig_gif)
|
| 159 |
+
|
| 160 |
+
imageio.mimsave(filename, frames, fps=15, loop=100)
|
| 161 |
+
|
| 162 |
+
pass
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def plot_attention_trajectory(attention, certainties, input_images, save_dir, filename, args):
|
| 166 |
+
where_most_certain = get_where_most_certain(certainties)
|
| 167 |
+
grid_size = int(math.sqrt(args.parity_sequence_length))
|
| 168 |
+
trajectory = [np.unravel_index(np.argmax(attention[t]), (grid_size, grid_size)) for t in range(args.iterations)]
|
| 169 |
+
x_coords, y_coords = zip(*trajectory)
|
| 170 |
+
|
| 171 |
+
plt.figure(figsize=(5, 5))
|
| 172 |
+
plt.imshow(input_images[0], cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
|
| 173 |
+
|
| 174 |
+
ax = plt.gca()
|
| 175 |
+
nrows, ncols = input_images[0].shape
|
| 176 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 177 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 178 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 179 |
+
ax.tick_params(which="minor", size=0)
|
| 180 |
+
ax.set_axisbelow(False)
|
| 181 |
+
plt.xticks([])
|
| 182 |
+
plt.yticks([])
|
| 183 |
+
|
| 184 |
+
cmap = plt.get_cmap("plasma")
|
| 185 |
+
norm_time = np.linspace(0, 1, len(trajectory))
|
| 186 |
+
|
| 187 |
+
for i in range(len(trajectory) - 1):
|
| 188 |
+
x1, y1 = x_coords[i], y_coords[i]
|
| 189 |
+
x2, y2 = x_coords[i + 1], y_coords[i + 1]
|
| 190 |
+
color = cmap(norm_time[i])
|
| 191 |
+
line, = plt.plot([y1, y2], [x1, x2], color=color, linewidth=6, alpha=0.5, zorder=4)
|
| 192 |
+
line.set_path_effects([
|
| 193 |
+
path_effects.Stroke(linewidth=8, foreground='white'),
|
| 194 |
+
path_effects.Normal()
|
| 195 |
+
])
|
| 196 |
+
|
| 197 |
+
for i, (x, y) in enumerate(trajectory):
|
| 198 |
+
plt.scatter(y, x, color=cmap(norm_time[i]), s=100, edgecolor='white', linewidth=1.5, zorder=5)
|
| 199 |
+
|
| 200 |
+
most_certain_point = trajectory[where_most_certain]
|
| 201 |
+
|
| 202 |
+
plt.plot(most_certain_point[1], most_certain_point[0],
|
| 203 |
+
marker='x', markersize=18, markeredgewidth=5,
|
| 204 |
+
color='white', linestyle='', zorder=6)
|
| 205 |
+
plt.plot(most_certain_point[1], most_certain_point[0],
|
| 206 |
+
marker='x', markersize=15, markeredgewidth=3,
|
| 207 |
+
color=cmap(norm_time[where_most_certain]), linestyle='', zorder=7)
|
| 208 |
+
|
| 209 |
+
plt.tight_layout()
|
| 210 |
+
plt.savefig(f"{save_dir}/{filename}_traj.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 211 |
+
plt.savefig(f"{save_dir}/{filename}_traj.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 212 |
+
plt.show()
|
| 213 |
+
plt.close()
|
| 214 |
+
|
| 215 |
+
def plot_input(input_images, save_dir, filename):
|
| 216 |
+
|
| 217 |
+
plt.figure(figsize=(5, 5))
|
| 218 |
+
plt.imshow(input_images[0], cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
|
| 219 |
+
|
| 220 |
+
ax = plt.gca()
|
| 221 |
+
nrows, ncols = input_images[0].shape
|
| 222 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 223 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 224 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 225 |
+
ax.tick_params(which="minor", size=0)
|
| 226 |
+
ax.set_axisbelow(False)
|
| 227 |
+
plt.xticks([])
|
| 228 |
+
plt.yticks([])
|
| 229 |
+
|
| 230 |
+
plt.tight_layout()
|
| 231 |
+
plt.savefig(f"{save_dir}/{filename}_input.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 232 |
+
plt.savefig(f"{save_dir}/{filename}_input.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 233 |
+
plt.show()
|
| 234 |
+
plt.close()
|
| 235 |
+
|
| 236 |
+
def plot_target(targets, save_dir, filename, args):
|
| 237 |
+
grid_size = int(math.sqrt(args.parity_sequence_length))
|
| 238 |
+
targets_grid = targets[0].reshape(grid_size, grid_size).detach().cpu().numpy()
|
| 239 |
+
plt.figure(figsize=(5, 5))
|
| 240 |
+
plt.imshow(targets_grid, cmap="gray_r", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
|
| 241 |
+
ax = plt.gca()
|
| 242 |
+
nrows, ncols = targets_grid.shape
|
| 243 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 244 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 245 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 246 |
+
ax.tick_params(which="minor", size=0)
|
| 247 |
+
ax.set_axisbelow(False)
|
| 248 |
+
plt.xticks([])
|
| 249 |
+
plt.yticks([])
|
| 250 |
+
plt.tight_layout()
|
| 251 |
+
plt.savefig(f"{save_dir}/{filename}_target.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 252 |
+
plt.savefig(f"{save_dir}/{filename}_target.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 253 |
+
plt.show()
|
| 254 |
+
plt.close()
|
| 255 |
+
|
| 256 |
+
def plot_probabilities(predictions, certainties, save_dir, filename, args):
|
| 257 |
+
grid_size = int(math.sqrt(args.parity_sequence_length))
|
| 258 |
+
where_most_certain = get_where_most_certain(certainties)
|
| 259 |
+
predictions_most_certain = predictions[0, :, :, where_most_certain].detach().cpu().numpy()
|
| 260 |
+
probs = softmax(predictions_most_certain, axis=1)
|
| 261 |
+
probs_grid = probs[:, 0].reshape(grid_size, grid_size)
|
| 262 |
+
plt.figure(figsize=(5, 5))
|
| 263 |
+
plt.imshow(probs_grid, cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
|
| 264 |
+
ax = plt.gca()
|
| 265 |
+
nrows, ncols = probs_grid.shape
|
| 266 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 267 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 268 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 269 |
+
ax.tick_params(which="minor", size=0)
|
| 270 |
+
ax.set_axisbelow(False)
|
| 271 |
+
plt.xticks([])
|
| 272 |
+
plt.yticks([])
|
| 273 |
+
plt.tight_layout()
|
| 274 |
+
plt.savefig(f"{save_dir}/{filename}_probs.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 275 |
+
plt.savefig(f"{save_dir}/{filename}_probs.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 276 |
+
plt.show()
|
| 277 |
+
plt.close()
|
| 278 |
+
|
| 279 |
+
def plot_prediction(predictions, certainties, save_dir, filename, args):
|
| 280 |
+
grid_size = int(math.sqrt(args.parity_sequence_length))
|
| 281 |
+
where_most_certain = get_where_most_certain(certainties)
|
| 282 |
+
predictions_most_certain = predictions[0, :, :, where_most_certain].detach().cpu().numpy()
|
| 283 |
+
class_grid = np.argmax(predictions_most_certain, axis=1).reshape(grid_size, grid_size)
|
| 284 |
+
|
| 285 |
+
plt.figure(figsize=(5, 5))
|
| 286 |
+
plt.imshow(class_grid, cmap="gray_r", origin="upper", vmin=0, vmax=1, interpolation='nearest')
|
| 287 |
+
|
| 288 |
+
ax = plt.gca()
|
| 289 |
+
nrows, ncols = class_grid.shape
|
| 290 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 291 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 292 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 293 |
+
ax.tick_params(which="minor", size=0)
|
| 294 |
+
ax.set_axisbelow(False)
|
| 295 |
+
plt.xticks([])
|
| 296 |
+
plt.yticks([])
|
| 297 |
+
|
| 298 |
+
plt.tight_layout()
|
| 299 |
+
plt.savefig(f"{save_dir}/{filename}_prediction.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 300 |
+
plt.savefig(f"{save_dir}/{filename}_prediction.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 301 |
+
plt.show()
|
| 302 |
+
plt.close()
|
| 303 |
+
|
| 304 |
+
def plot_accuracy_heatmap(overall_accuracies_avg, average_thinking_time, std_thinking_time, scale, save_path, args):
|
| 305 |
+
fig, ax = plt.subplots(figsize=(scale*10, scale*5))
|
| 306 |
+
im = ax.imshow(overall_accuracies_avg.T * 100, aspect='auto', cmap="viridis", origin='lower', extent=[0, args.iterations-1, 0, args.parity_sequence_length-1], vmin=50, vmax=100)
|
| 307 |
+
cbar = fig.colorbar(im, ax=ax, format="%.1f")
|
| 308 |
+
cbar.set_label("Accuracy (%)")
|
| 309 |
+
ax.errorbar(average_thinking_time, np.arange(args.parity_sequence_length), xerr=std_thinking_time, fmt='ko', markersize=2, capsize=2, elinewidth=1, label="Min. Entropy")
|
| 310 |
+
ax.set_xlabel("Time Step")
|
| 311 |
+
ax.set_ylabel("Sequence Index")
|
| 312 |
+
ax.set_xlim(0, args.iterations-1)
|
| 313 |
+
ax.set_ylim(0, args.parity_sequence_length-1)
|
| 314 |
+
ax.grid(False)
|
| 315 |
+
ax.legend(loc="upper left")
|
| 316 |
+
fig.tight_layout(pad=0.1)
|
| 317 |
+
fig.savefig(save_path, dpi=300, bbox_inches="tight")
|
| 318 |
+
fig.savefig(save_path.replace(".png", ".pdf"), format='pdf', bbox_inches="tight")
|
| 319 |
+
plt.close(fig)
|
| 320 |
+
|
| 321 |
+
def plot_attention_heatmap(overall_attentions_avg, scale, save_path, vmin=None, vmax=None):
|
| 322 |
+
overall_attentions_avg = overall_attentions_avg.reshape(overall_attentions_avg.shape[0], -1)
|
| 323 |
+
fig, ax = plt.subplots(figsize=(scale*10, scale*5))
|
| 324 |
+
im = ax.imshow(overall_attentions_avg.T, aspect='auto', cmap="viridis", origin='lower', extent=[0, overall_attentions_avg.shape[0]-1, 0, overall_attentions_avg.shape[1]-1], vmin=vmin, vmax=vmax)
|
| 325 |
+
cbar = fig.colorbar(im, ax=ax, format=FuncFormatter(lambda x, _: f"{x:05.2f}"))
|
| 326 |
+
cbar.set_label("Attention Weight")
|
| 327 |
+
ax.set_xlabel("Time Step")
|
| 328 |
+
ax.set_ylabel("Sequence Index")
|
| 329 |
+
ax.set_xlim(0, overall_attentions_avg.shape[0]-1)
|
| 330 |
+
ax.set_ylim(0, overall_attentions_avg.shape[1]-1)
|
| 331 |
+
ax.grid(False)
|
| 332 |
+
fig.tight_layout(pad=0.1)
|
| 333 |
+
fig.savefig(save_path, dpi=300, bbox_inches="tight")
|
| 334 |
+
fig.savefig(save_path.replace(".png", ".pdf"), format='pdf', bbox_inches="tight")
|
| 335 |
+
plt.close(fig)
|
| 336 |
+
|
| 337 |
+
def create_accuracies_heatmap_gif(all_accuracies, all_average_thinking_times, all_std_thinking_times, scale, save_dir, args):
|
| 338 |
+
heatmap_components_dir = os.path.join(save_dir, "accuracy_heatmaps")
|
| 339 |
+
os.makedirs(heatmap_components_dir, exist_ok=True)
|
| 340 |
+
|
| 341 |
+
image_paths = []
|
| 342 |
+
|
| 343 |
+
for i, (accuracies, avg_thinking_time, std_thinking_time) in enumerate(zip(all_accuracies, all_average_thinking_times, all_std_thinking_times)):
|
| 344 |
+
save_path = os.path.join(heatmap_components_dir, f"frame_{i:04d}.png")
|
| 345 |
+
plot_accuracy_heatmap(accuracies, avg_thinking_time, std_thinking_time, scale, save_path, args)
|
| 346 |
+
image_paths.append(save_path)
|
| 347 |
+
|
| 348 |
+
gif_path = os.path.join(save_dir, "accuracy_heatmap.gif")
|
| 349 |
+
with imageio.get_writer(gif_path, mode='I', duration=0.3) as writer:
|
| 350 |
+
for image_path in image_paths:
|
| 351 |
+
image = imageio.imread(image_path)
|
| 352 |
+
writer.append_data(image)
|
| 353 |
+
|
| 354 |
+
def create_attentions_heatmap_gif(all_attentions, scale, save_path, args):
|
| 355 |
+
heatmap_components_dir = os.path.join(args.log_dir, "attention_heatmaps")
|
| 356 |
+
os.makedirs(heatmap_components_dir, exist_ok=True)
|
| 357 |
+
|
| 358 |
+
global_min = min(attentions.min() for attentions in all_attentions)
|
| 359 |
+
global_max = max(attentions.max() for attentions in all_attentions)
|
| 360 |
+
|
| 361 |
+
image_paths = []
|
| 362 |
+
|
| 363 |
+
for i, attentions in enumerate(all_attentions):
|
| 364 |
+
save_path_component = os.path.join(heatmap_components_dir, f"frame_{i:04d}.png")
|
| 365 |
+
plot_attention_heatmap(attentions, scale, save_path_component, vmin=global_min, vmax=global_max)
|
| 366 |
+
image_paths.append(save_path_component)
|
| 367 |
+
|
| 368 |
+
gif_path = os.path.join(save_path, "attention_heatmap.gif")
|
| 369 |
+
with imageio.get_writer(gif_path, mode='I', duration=0.3) as writer:
|
| 370 |
+
for image_path in image_paths:
|
| 371 |
+
image = imageio.imread(image_path)
|
| 372 |
+
writer.append_data(image)
|
| 373 |
+
|
| 374 |
+
def create_stacked_gif(save_path, y_shift=200):
|
| 375 |
+
accuracy_gif_path = os.path.join(save_path, "accuracy_heatmap.gif")
|
| 376 |
+
attention_gif_path = os.path.join(save_path, "attention_heatmap.gif")
|
| 377 |
+
stacked_gif_path = os.path.join(save_path, "stacked_heatmap.gif")
|
| 378 |
+
|
| 379 |
+
accuracy_reader = imageio.get_reader(accuracy_gif_path)
|
| 380 |
+
attention_reader = imageio.get_reader(attention_gif_path)
|
| 381 |
+
|
| 382 |
+
accuracy_frames = [Image.fromarray(frame) for frame in accuracy_reader]
|
| 383 |
+
attention_frames = [Image.fromarray(frame) for frame in attention_reader]
|
| 384 |
+
|
| 385 |
+
assert len(accuracy_frames) == len(attention_frames), "Mismatch in frame counts between accuracy and attention GIFs"
|
| 386 |
+
|
| 387 |
+
stacked_frames = []
|
| 388 |
+
for acc_frame, att_frame in zip(accuracy_frames, attention_frames):
|
| 389 |
+
acc_width, acc_height = acc_frame.size
|
| 390 |
+
att_width, att_height = att_frame.size
|
| 391 |
+
|
| 392 |
+
# Create base canvas
|
| 393 |
+
stacked_height = acc_height + att_height - y_shift
|
| 394 |
+
stacked_width = max(acc_width, att_width)
|
| 395 |
+
|
| 396 |
+
stacked_frame = Image.new("RGB", (stacked_width, stacked_height), color=(255, 255, 255))
|
| 397 |
+
|
| 398 |
+
# Paste attention frame first, shifted up
|
| 399 |
+
stacked_frame.paste(att_frame, (0, 0)) # Paste at top
|
| 400 |
+
stacked_frame.paste(acc_frame, (0, att_height - y_shift)) # Shift accuracy up by overlap
|
| 401 |
+
|
| 402 |
+
stacked_frames.append(stacked_frame)
|
| 403 |
+
|
| 404 |
+
stacked_frames[0].save(
|
| 405 |
+
stacked_gif_path,
|
| 406 |
+
save_all=True,
|
| 407 |
+
append_images=stacked_frames[1:],
|
| 408 |
+
duration=300,
|
| 409 |
+
loop=0
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
save_frames_to_mp4(
|
| 413 |
+
[np.array(fm)[:, :, ::-1] for fm in stacked_frames],
|
| 414 |
+
f"{stacked_gif_path.replace('gif', 'mp4')}",
|
| 415 |
+
fps=15,
|
| 416 |
+
gop_size=1,
|
| 417 |
+
preset="slow"
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def plot_accuracy_training(all_accuracies, scale, run_model_spefic_save_dir, args):
|
| 422 |
+
scale=0.5
|
| 423 |
+
seq_indices = range(args.parity_sequence_length)
|
| 424 |
+
fig, ax = plt.subplots(figsize=(scale*10, scale*5))
|
| 425 |
+
cmap = plt.get_cmap("viridis")
|
| 426 |
+
|
| 427 |
+
for i, acc in enumerate(all_accuracies):
|
| 428 |
+
color = cmap(i / (len(all_accuracies) - 1))
|
| 429 |
+
ax.plot(seq_indices, acc*100, color=color, alpha=0.7, linewidth=1)
|
| 430 |
+
|
| 431 |
+
num_checkpoints = 5
|
| 432 |
+
checkpoint_percentages = np.linspace(0, 100, num_checkpoints)
|
| 433 |
+
|
| 434 |
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=100))
|
| 435 |
+
sm.set_array([])
|
| 436 |
+
cbar = fig.colorbar(sm, ax=ax)
|
| 437 |
+
cbar.set_label("Training Progress (%)")
|
| 438 |
+
cbar.set_ticks(checkpoint_percentages)
|
| 439 |
+
cbar.set_ticklabels([f"{int(p)}%" for p in checkpoint_percentages])
|
| 440 |
+
|
| 441 |
+
ax.set_xlabel("Sequence Index")
|
| 442 |
+
ax.set_ylabel("Accuracy (%)")
|
| 443 |
+
ax.set_xticks([0, 16 ,32, 48, 63])
|
| 444 |
+
ax.grid(True, alpha=0.5)
|
| 445 |
+
ax.set_xlim(0, args.parity_sequence_length - 1)
|
| 446 |
+
|
| 447 |
+
fig.tight_layout(pad=0.1)
|
| 448 |
+
fig.savefig(f"{run_model_spefic_save_dir}/accuracy_vs_seq_element.png", dpi=300, bbox_inches="tight")
|
| 449 |
+
fig.savefig(f"{run_model_spefic_save_dir}/accuracy_vs_seq_element.pdf", format='pdf', bbox_inches="tight")
|
| 450 |
+
plt.close(fig)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def plot_loss_all_runs(training_data, evaluate_every, save_path="train_loss_comparison_parity.png", step=1, scale=1.0, x_max=None):
|
| 454 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 455 |
+
|
| 456 |
+
grouped = defaultdict(list)
|
| 457 |
+
label_map = {}
|
| 458 |
+
linestyle_map = {}
|
| 459 |
+
iters_map = {}
|
| 460 |
+
model_map = {}
|
| 461 |
+
|
| 462 |
+
for folder, data in training_data.items():
|
| 463 |
+
label, model_type, iters = parse_folder_name(folder)
|
| 464 |
+
if iters is None:
|
| 465 |
+
continue
|
| 466 |
+
|
| 467 |
+
key = f"{model_type}_{iters}"
|
| 468 |
+
grouped[key].append(data["train_losses"])
|
| 469 |
+
label_map[key] = f"{model_type}, {iters} Iters."
|
| 470 |
+
linestyle_map[key] = "--" if model_type == "LSTM" else "-"
|
| 471 |
+
iters_map[key] = iters
|
| 472 |
+
model_map[key] = model_type
|
| 473 |
+
|
| 474 |
+
unique_iters = sorted(set(iters_map.values()))
|
| 475 |
+
base_colors = sns.color_palette("hls", n_colors=len(unique_iters))
|
| 476 |
+
color_lookup = {iters: base_colors[i] for i, iters in enumerate(unique_iters)}
|
| 477 |
+
|
| 478 |
+
legend_entries = []
|
| 479 |
+
global_max_x = 0
|
| 480 |
+
for key in sorted(grouped.keys(), key=lambda k: (iters_map[k], model_map[k])):
|
| 481 |
+
runs = grouped[key]
|
| 482 |
+
if not runs:
|
| 483 |
+
continue
|
| 484 |
+
|
| 485 |
+
iters = iters_map[key]
|
| 486 |
+
color = color_lookup[iters]
|
| 487 |
+
linestyle = linestyle_map[key]
|
| 488 |
+
|
| 489 |
+
min_len = min(len(r) for r in runs)
|
| 490 |
+
trimmed = np.array([r[:min_len] for r in runs])[:, ::step]
|
| 491 |
+
|
| 492 |
+
mean = np.mean(trimmed, axis=0)
|
| 493 |
+
std = np.std(trimmed, axis=0)
|
| 494 |
+
x = np.arange(len(mean)) * step * evaluate_every
|
| 495 |
+
group_max_x = len(mean) * step * evaluate_every
|
| 496 |
+
global_max_x = max(global_max_x, group_max_x)
|
| 497 |
+
|
| 498 |
+
line, = ax.plot(x, mean, color=color, linestyle=linestyle, label=label_map[key])
|
| 499 |
+
ax.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
|
| 500 |
+
|
| 501 |
+
legend_entries.append((line, label_map[key]))
|
| 502 |
+
|
| 503 |
+
ax.set_xlabel("Training Iterations")
|
| 504 |
+
ax.set_ylabel("Loss")
|
| 505 |
+
ax.grid(True, alpha=0.5)
|
| 506 |
+
|
| 507 |
+
style_legend = [
|
| 508 |
+
Line2D([0], [0], color='black', linestyle='-', label='CTM'),
|
| 509 |
+
Line2D([0], [0], color='black', linestyle='--', label='LSTM')
|
| 510 |
+
]
|
| 511 |
+
color_legend = [
|
| 512 |
+
Line2D([0], [0], color=color_lookup[it], linestyle='-', label=f"{it} Iters.")
|
| 513 |
+
for it in unique_iters
|
| 514 |
+
]
|
| 515 |
+
|
| 516 |
+
if not x_max:
|
| 517 |
+
x_max = global_max_x
|
| 518 |
+
|
| 519 |
+
ax.set_xlim([0, x_max])
|
| 520 |
+
ax.set_ylim(bottom=0)
|
| 521 |
+
ax.set_xticks(np.arange(0, x_max + 1, 50000))
|
| 522 |
+
ax.legend(handles=color_legend + style_legend, loc="upper left")
|
| 523 |
+
fig.tight_layout(pad=0.1)
|
| 524 |
+
fig.savefig(save_path, dpi=300)
|
| 525 |
+
fig.savefig(save_path.replace("png", "pdf"), format='pdf')
|
| 526 |
+
plt.close(fig)
|
| 527 |
+
|
| 528 |
+
def plot_accuracy_all_runs(training_data, evaluate_every, save_path="test_accuracy_comparison_parity.png", step=1, scale=1.0, smooth=False, x_max=None):
|
| 529 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 530 |
+
|
| 531 |
+
grouped = defaultdict(list)
|
| 532 |
+
label_map = {}
|
| 533 |
+
linestyle_map = {}
|
| 534 |
+
iters_map = {}
|
| 535 |
+
model_map = {}
|
| 536 |
+
|
| 537 |
+
for folder, data in training_data.items():
|
| 538 |
+
label, model_type, iters = parse_folder_name(folder)
|
| 539 |
+
if iters is None:
|
| 540 |
+
continue
|
| 541 |
+
|
| 542 |
+
key = f"{model_type}_{iters}"
|
| 543 |
+
grouped[key].append(data["test_accuracies"])
|
| 544 |
+
label_map[key] = f"{model_type}, {iters} Iters."
|
| 545 |
+
linestyle_map[key] = "--" if model_type == "LSTM" else "-"
|
| 546 |
+
iters_map[key] = iters
|
| 547 |
+
model_map[key] = model_type
|
| 548 |
+
|
| 549 |
+
unique_iters = sorted(set(iters_map.values()))
|
| 550 |
+
base_colors = sns.color_palette("hls", n_colors=len(unique_iters))
|
| 551 |
+
color_lookup = {iters: base_colors[i] for i, iters in enumerate(unique_iters)}
|
| 552 |
+
|
| 553 |
+
legend_entries = []
|
| 554 |
+
global_max_x = 0
|
| 555 |
+
|
| 556 |
+
for key in sorted(grouped.keys(), key=lambda k: (iters_map[k], model_map[k])):
|
| 557 |
+
runs = grouped[key]
|
| 558 |
+
if not runs:
|
| 559 |
+
continue
|
| 560 |
+
|
| 561 |
+
iters = iters_map[key]
|
| 562 |
+
model = model_map[key]
|
| 563 |
+
color = color_lookup[iters]
|
| 564 |
+
linestyle = linestyle_map[key]
|
| 565 |
+
|
| 566 |
+
min_len = min(len(r) for r in runs)
|
| 567 |
+
trimmed = np.array([r[:min_len] for r in runs])[:, ::step]
|
| 568 |
+
|
| 569 |
+
mean = np.mean(trimmed, axis=0) * 100
|
| 570 |
+
std = np.std(trimmed, axis=0) * 100
|
| 571 |
+
|
| 572 |
+
if smooth:
|
| 573 |
+
window_size = max(1, int(0.05 * len(mean)))
|
| 574 |
+
if window_size % 2 == 0:
|
| 575 |
+
window_size += 1
|
| 576 |
+
kernel = np.ones(window_size) / window_size
|
| 577 |
+
|
| 578 |
+
smoothed_mean = np.convolve(mean, kernel, mode='same')
|
| 579 |
+
smoothed_std = np.convolve(std, kernel, mode='same')
|
| 580 |
+
|
| 581 |
+
valid_start = window_size // 2
|
| 582 |
+
valid_end = len(mean) - window_size // 2
|
| 583 |
+
valid_length = valid_end - valid_start
|
| 584 |
+
|
| 585 |
+
mean = smoothed_mean[valid_start:valid_end]
|
| 586 |
+
std = smoothed_std[valid_start:valid_end]
|
| 587 |
+
x = np.arange(valid_length) * step * evaluate_every
|
| 588 |
+
group_max_x = valid_length * step * evaluate_every
|
| 589 |
+
else:
|
| 590 |
+
x = np.arange(len(mean)) * step * evaluate_every
|
| 591 |
+
group_max_x = len(mean) * step * evaluate_every
|
| 592 |
+
|
| 593 |
+
global_max_x = max(global_max_x, group_max_x)
|
| 594 |
+
|
| 595 |
+
line, = ax.plot(x, mean, color=color, linestyle=linestyle, label=label_map[key])
|
| 596 |
+
ax.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
|
| 597 |
+
legend_entries.append((line, label_map[key]))
|
| 598 |
+
|
| 599 |
+
if smooth or x_max is None:
|
| 600 |
+
x_max = global_max_x
|
| 601 |
+
|
| 602 |
+
ax.set_xlim([0, x_max])
|
| 603 |
+
ax.set_ylim(top=100)
|
| 604 |
+
ax.set_xticks(np.arange(0, x_max + 1, 50000))
|
| 605 |
+
ax.set_xlabel("Training Iterations")
|
| 606 |
+
ax.set_ylabel("Accuracy (%)")
|
| 607 |
+
ax.grid(True, alpha=0.5)
|
| 608 |
+
|
| 609 |
+
style_legend = [
|
| 610 |
+
Line2D([0], [0], color='black', linestyle='-', label='CTM'),
|
| 611 |
+
Line2D([0], [0], color='black', linestyle='--', label='LSTM')
|
| 612 |
+
]
|
| 613 |
+
color_legend = [
|
| 614 |
+
Line2D([0], [0], color=color_lookup[it], linestyle='-', label=f"{it} Iters.")
|
| 615 |
+
for it in unique_iters
|
| 616 |
+
]
|
| 617 |
+
ax.legend(handles=color_legend + style_legend, loc="upper left")
|
| 618 |
+
|
| 619 |
+
fig.tight_layout(pad=0.1)
|
| 620 |
+
fig.savefig(save_path, dpi=300)
|
| 621 |
+
fig.savefig(save_path.replace("png", "pdf"), format='pdf')
|
| 622 |
+
plt.close(fig)
|
| 623 |
+
|
| 624 |
+
def extract_run_name(folder, run_index=None):
|
| 625 |
+
# Try to extract from parent folder
|
| 626 |
+
parent = os.path.basename(os.path.dirname(folder))
|
| 627 |
+
match = re.search(r'run(\d+)', parent, re.IGNORECASE)
|
| 628 |
+
if match:
|
| 629 |
+
return f"Run {int(match.group(1))}"
|
| 630 |
+
# Try current folder name
|
| 631 |
+
basename = os.path.basename(folder)
|
| 632 |
+
match = re.search(r'run(\d+)', basename, re.IGNORECASE)
|
| 633 |
+
if match:
|
| 634 |
+
return f"Run {int(match.group(1))}"
|
| 635 |
+
# Fallback: use run index
|
| 636 |
+
if run_index is not None:
|
| 637 |
+
return f"Run {run_index + 1}"
|
| 638 |
+
raise ValueError(f"Could not extract run number from: {folder}")
|
| 639 |
+
|
| 640 |
+
def plot_loss_individual_runs(training_data, evaluate_every, save_dir, scale=1.0, x_max=None):
|
| 641 |
+
|
| 642 |
+
grouped = defaultdict(list)
|
| 643 |
+
label_map = {}
|
| 644 |
+
iters_map = {}
|
| 645 |
+
model_map = {}
|
| 646 |
+
|
| 647 |
+
base_colors = sns.color_palette("hls", n_colors=3)
|
| 648 |
+
color_lookup = {f"Run {i+1}": base_colors[i] for i in range(3)}
|
| 649 |
+
|
| 650 |
+
for i, (folder, data) in enumerate(training_data.items()):
|
| 651 |
+
checkpoint = load_checkpoint(get_latest_checkpoint_file(folder), device="cpu")
|
| 652 |
+
model_args = get_model_args_from_checkpoint(checkpoint)
|
| 653 |
+
label, model_type, iters = parse_folder_name(folder)
|
| 654 |
+
if iters is None:
|
| 655 |
+
continue
|
| 656 |
+
|
| 657 |
+
if model_type.lower() == "ctm":
|
| 658 |
+
memory_length = getattr(model_args, "memory_length", None)
|
| 659 |
+
if memory_length is None:
|
| 660 |
+
raise ValueError(f"CTM model missing memory_length in checkpoint args from: {folder}")
|
| 661 |
+
key = f"{model_type}_{iters}_{memory_length}".lower()
|
| 662 |
+
else:
|
| 663 |
+
key = f"{model_type}_{iters}".lower()
|
| 664 |
+
|
| 665 |
+
run_name = extract_run_name(folder, run_index=i)
|
| 666 |
+
grouped[key].append((run_name, data["train_losses"]))
|
| 667 |
+
label_map[key] = f"{model_type}, {iters} Iters."
|
| 668 |
+
iters_map[key] = iters
|
| 669 |
+
model_map[key] = model_type
|
| 670 |
+
|
| 671 |
+
for key, runs in grouped.items():
|
| 672 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 673 |
+
for run_name, losses in runs:
|
| 674 |
+
x = np.arange(len(losses)) * evaluate_every
|
| 675 |
+
color = color_lookup.get(run_name, 'gray')
|
| 676 |
+
ax.plot(x, losses, label=run_name, color=color, alpha=0.7)
|
| 677 |
+
|
| 678 |
+
ax.set_xlabel("Training Iterations")
|
| 679 |
+
ax.set_ylabel("Loss")
|
| 680 |
+
ax.set_ylim(bottom=-0.01)
|
| 681 |
+
ax.grid(True, alpha=0.5)
|
| 682 |
+
if x_max:
|
| 683 |
+
ax.set_xlim([0, x_max])
|
| 684 |
+
ax.set_xticks(np.arange(0, x_max + 1, 50000))
|
| 685 |
+
ax.legend()
|
| 686 |
+
fig.tight_layout(pad=0.1)
|
| 687 |
+
|
| 688 |
+
subdir = os.path.join(save_dir, key)
|
| 689 |
+
os.makedirs(subdir, exist_ok=True)
|
| 690 |
+
fname = os.path.join(subdir, f"individual_runs_loss_{key}.png")
|
| 691 |
+
fig.savefig(fname, dpi=300)
|
| 692 |
+
fig.savefig(fname.replace("png", "pdf"), format="pdf")
|
| 693 |
+
plt.close(fig)
|
| 694 |
+
|
| 695 |
+
def plot_accuracy_individual_runs(training_data, evaluate_every, save_dir, scale=1.0, smooth=False, x_max=None):
|
| 696 |
+
|
| 697 |
+
grouped = defaultdict(list)
|
| 698 |
+
label_map = {}
|
| 699 |
+
iters_map = {}
|
| 700 |
+
model_map = {}
|
| 701 |
+
|
| 702 |
+
base_colors = sns.color_palette("hls", n_colors=3)
|
| 703 |
+
color_lookup = {f"Run {i+1}": base_colors[i] for i in range(3)}
|
| 704 |
+
|
| 705 |
+
for i, (folder, data) in enumerate(training_data.items()):
|
| 706 |
+
checkpoint = load_checkpoint(get_latest_checkpoint_file(folder), device="cpu")
|
| 707 |
+
model_args = get_model_args_from_checkpoint(checkpoint)
|
| 708 |
+
label, model_type, iters = parse_folder_name(folder)
|
| 709 |
+
if iters is None:
|
| 710 |
+
continue
|
| 711 |
+
|
| 712 |
+
if model_type.lower() == "ctm":
|
| 713 |
+
memory_length = getattr(model_args, "memory_length", None)
|
| 714 |
+
if memory_length is None:
|
| 715 |
+
raise ValueError(f"CTM model missing memory_length in checkpoint args from: {folder}")
|
| 716 |
+
key = f"{model_type}_{iters}_{memory_length}".lower()
|
| 717 |
+
else:
|
| 718 |
+
key = f"{model_type}_{iters}".lower()
|
| 719 |
+
|
| 720 |
+
run_name = extract_run_name(folder, run_index=i)
|
| 721 |
+
grouped[key].append((run_name, data["test_accuracies"]))
|
| 722 |
+
label_map[key] = f"{model_type}, {iters} Iters."
|
| 723 |
+
iters_map[key] = iters
|
| 724 |
+
model_map[key] = model_type
|
| 725 |
+
|
| 726 |
+
for key, runs in grouped.items():
|
| 727 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 728 |
+
for run_name, acc in runs:
|
| 729 |
+
acc = np.array(acc) * 100
|
| 730 |
+
if smooth:
|
| 731 |
+
window_size = max(1, int(0.05 * len(acc)))
|
| 732 |
+
if window_size % 2 == 0:
|
| 733 |
+
window_size += 1
|
| 734 |
+
kernel = np.ones(window_size) / window_size
|
| 735 |
+
acc = np.convolve(acc, kernel, mode="same")
|
| 736 |
+
|
| 737 |
+
x = np.arange(len(acc)) * evaluate_every
|
| 738 |
+
color = color_lookup.get(run_name, 'gray')
|
| 739 |
+
ax.plot(x, acc, label=run_name, color=color, alpha=0.7)
|
| 740 |
+
|
| 741 |
+
ax.set_xlabel("Training Iterations")
|
| 742 |
+
ax.set_ylabel("Accuracy (%)")
|
| 743 |
+
ax.set_ylim([50, 101])
|
| 744 |
+
ax.grid(True, alpha=0.5)
|
| 745 |
+
if x_max:
|
| 746 |
+
ax.set_xlim([0, x_max])
|
| 747 |
+
ax.set_xticks(np.arange(0, x_max + 1, 50000))
|
| 748 |
+
ax.legend()
|
| 749 |
+
fig.tight_layout(pad=0.1)
|
| 750 |
+
|
| 751 |
+
subdir = os.path.join(save_dir, key)
|
| 752 |
+
os.makedirs(subdir, exist_ok=True)
|
| 753 |
+
fname = os.path.join(subdir, f"individual_runs_accuracy_{key}.png")
|
| 754 |
+
fig.savefig(fname, dpi=300)
|
| 755 |
+
fig.savefig(fname.replace("png", "pdf"), format="pdf")
|
| 756 |
+
plt.close(fig)
|
| 757 |
+
|
| 758 |
+
def plot_training_curve_all_runs(all_folders, save_dir, scale, device, smooth=False, x_max=None, plot_individual_runs=True):
|
| 759 |
+
|
| 760 |
+
all_folders = [folder for folder in all_folders if "certain" not in folder]
|
| 761 |
+
|
| 762 |
+
training_data = {}
|
| 763 |
+
evaluation_intervals = []
|
| 764 |
+
for folder in all_folders:
|
| 765 |
+
latest_checkpoint_path = get_latest_checkpoint_file(folder)
|
| 766 |
+
if latest_checkpoint_path:
|
| 767 |
+
checkpoint = load_checkpoint(latest_checkpoint_path, device=device)
|
| 768 |
+
model_args = get_model_args_from_checkpoint(checkpoint)
|
| 769 |
+
evaluation_intervals.append(model_args.track_every)
|
| 770 |
+
|
| 771 |
+
_, train_losses, test_losses, train_accuracies, test_accuracies = get_accuracy_and_loss_from_checkpoint(checkpoint, device=device)
|
| 772 |
+
training_data[folder] = {
|
| 773 |
+
"train_losses": train_losses,
|
| 774 |
+
"test_losses": test_losses,
|
| 775 |
+
"train_accuracies": train_accuracies,
|
| 776 |
+
"test_accuracies": test_accuracies
|
| 777 |
+
}
|
| 778 |
+
else:
|
| 779 |
+
print(f"No checkpoint found for {folder}")
|
| 780 |
+
|
| 781 |
+
assert len(evaluation_intervals) > 0, "No valid checkpoints found."
|
| 782 |
+
assert all(interval == evaluation_intervals[0] for interval in evaluation_intervals), "Evaluation intervals are not consistent across runs."
|
| 783 |
+
|
| 784 |
+
evaluate_every = evaluation_intervals[0]
|
| 785 |
+
|
| 786 |
+
if plot_individual_runs:
|
| 787 |
+
plot_loss_individual_runs(training_data, evaluate_every, save_dir=save_dir, scale=scale, x_max=x_max)
|
| 788 |
+
plot_accuracy_individual_runs(training_data, evaluate_every, save_dir=save_dir, scale=scale, smooth=smooth, x_max=x_max)
|
| 789 |
+
|
| 790 |
+
plot_loss_all_runs(training_data, evaluate_every, save_path=f"{save_dir}/loss_comparison.png", scale=scale, x_max=x_max)
|
| 791 |
+
plot_accuracy_all_runs(training_data, evaluate_every, save_path=f"{save_dir}/accuracy_comparison.png", scale=scale, smooth=smooth, x_max=x_max)
|
| 792 |
+
|
| 793 |
+
return training_data
|
| 794 |
+
|
| 795 |
+
def plot_accuracy_thinking_time(csv_path, scale, output_dir="analysis/cifar"):
|
| 796 |
+
if not os.path.exists(csv_path):
|
| 797 |
+
raise FileNotFoundError(f"CSV file not found: {csv_path}")
|
| 798 |
+
|
| 799 |
+
df = pd.read_csv(csv_path)
|
| 800 |
+
df["RunName"] = df["Run"].apply(lambda x: os.path.basename(os.path.dirname(x)))
|
| 801 |
+
df["Model"] = df["Run"].apply(lambda x: "CTM" if "ctm" in x.lower() else "LSTM")
|
| 802 |
+
|
| 803 |
+
grouped = df.groupby(["Model", "Num Iterations"])
|
| 804 |
+
summary = grouped.agg(
|
| 805 |
+
mean_accuracy=("Overall Mean Accuracy", "mean"),
|
| 806 |
+
std_accuracy=("Overall Std Accuracy", lambda x: np.sqrt(np.mean(x**2)))
|
| 807 |
+
).reset_index()
|
| 808 |
+
|
| 809 |
+
summary["mean_accuracy"] *= 100
|
| 810 |
+
summary["std_accuracy"] *= 100
|
| 811 |
+
|
| 812 |
+
fig, ax = plt.subplots(figsize=(scale*5, scale*5))
|
| 813 |
+
|
| 814 |
+
for model in ("CTM", "LSTM"):
|
| 815 |
+
subset = summary[summary["Model"] == model].sort_values(by="Num Iterations")
|
| 816 |
+
linestyle = "-" if model == "CTM" else "--"
|
| 817 |
+
ax.errorbar(
|
| 818 |
+
subset["Num Iterations"],
|
| 819 |
+
subset["mean_accuracy"],
|
| 820 |
+
yerr=subset["std_accuracy"],
|
| 821 |
+
linestyle=linestyle,
|
| 822 |
+
color="black",
|
| 823 |
+
marker='.',
|
| 824 |
+
label=model,
|
| 825 |
+
capsize=3,
|
| 826 |
+
elinewidth=1,
|
| 827 |
+
errorevery=1
|
| 828 |
+
)
|
| 829 |
+
|
| 830 |
+
ax.set_xlabel("Internal Ticks")
|
| 831 |
+
ax.set_ylabel("Accuracy (%)")
|
| 832 |
+
custom_lines = [
|
| 833 |
+
Line2D([0], [0], color='black', linestyle='-', label='CTM'),
|
| 834 |
+
Line2D([0], [0], color='black', linestyle='--', label='LSTM')
|
| 835 |
+
]
|
| 836 |
+
ax.legend(handles=custom_lines, loc="lower right")
|
| 837 |
+
ax.grid(True, alpha=0.5)
|
| 838 |
+
|
| 839 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 840 |
+
output_path_png = os.path.join(output_dir, "accuracy_vs_thinking_time.png")
|
| 841 |
+
fig.tight_layout(pad=0.1)
|
| 842 |
+
fig.savefig(output_path_png, dpi=300)
|
| 843 |
+
fig.savefig(output_path_png.replace("png", "pdf"), format='pdf')
|
| 844 |
+
plt.close(fig)
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
def plot_lstm_last_and_certain_accuracy(all_folders, save_path="lstm_last_and_certain_accuracy.png", scale=1.0, step=1, x_max=None):
|
| 848 |
+
|
| 849 |
+
tags = ["lstm_10", "lstm_10_certain", "lstm_25", "lstm_25_certain"]
|
| 850 |
+
folders = [f for f in all_folders if any(tag in f.lower() for tag in tags)]
|
| 851 |
+
|
| 852 |
+
training_data, eval_intervals = {}, []
|
| 853 |
+
for f in folders:
|
| 854 |
+
cp = get_latest_checkpoint_file(f)
|
| 855 |
+
if not cp:
|
| 856 |
+
print(f"⚠️ No checkpoint in {f}")
|
| 857 |
+
continue
|
| 858 |
+
ckpt = load_checkpoint(cp, device="cpu")
|
| 859 |
+
args = get_model_args_from_checkpoint(ckpt)
|
| 860 |
+
eval_intervals.append(args.track_every)
|
| 861 |
+
_, _, _, _, acc = get_accuracy_and_loss_from_checkpoint(ckpt)
|
| 862 |
+
iters = "25" if "25" in f.lower() else "10"
|
| 863 |
+
label = "Certain" if "certain" in f.lower() else "Final"
|
| 864 |
+
training_data.setdefault((iters, label), []).append(acc)
|
| 865 |
+
|
| 866 |
+
assert training_data and all(i == eval_intervals[0] for i in eval_intervals), "Missing or inconsistent eval intervals."
|
| 867 |
+
evaluate_every = eval_intervals[0]
|
| 868 |
+
|
| 869 |
+
keys = sorted(training_data.keys())
|
| 870 |
+
colors = sns.color_palette("hls", n_colors=len(keys))
|
| 871 |
+
style_map = {key: ("--" if key[1] == "Certain" else "-") for key in keys}
|
| 872 |
+
color_map = {key: colors[i] for i, key in enumerate(keys)}
|
| 873 |
+
|
| 874 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 875 |
+
max_x = 0
|
| 876 |
+
|
| 877 |
+
for key in keys:
|
| 878 |
+
runs = training_data[key]
|
| 879 |
+
min_len = min(len(r) for r in runs)
|
| 880 |
+
trimmed = np.stack([r[:min_len] for r in runs], axis=0)[:, ::step]
|
| 881 |
+
mean, std = np.mean(trimmed, 0) * 100, np.std(trimmed, 0) * 100
|
| 882 |
+
x = np.arange(len(mean)) * step * evaluate_every
|
| 883 |
+
ax.plot(x, mean, color=color_map[key], linestyle=style_map[key],
|
| 884 |
+
label=f"{key[0]} Iters, {key[1]}", linewidth=2, alpha=0.7)
|
| 885 |
+
ax.fill_between(x, mean - std, mean + std, color=color_map[key], alpha=0.1)
|
| 886 |
+
max_x = max(max_x, x[-1])
|
| 887 |
+
|
| 888 |
+
ax.set_xlim([0, x_max or max_x])
|
| 889 |
+
ax.set_xticks(np.arange(0, (x_max or max_x) + 1, 50000))
|
| 890 |
+
ax.set_xlabel("Training Iterations")
|
| 891 |
+
ax.set_ylabel("Accuracy (%)")
|
| 892 |
+
ax.grid(True, alpha=0.5)
|
| 893 |
+
ax.legend(loc="lower right")
|
| 894 |
+
fig.tight_layout(pad=0.1)
|
| 895 |
+
fig.savefig(save_path, dpi=300)
|
| 896 |
+
fig.savefig(save_path.replace("png", "pdf"), format="pdf")
|
| 897 |
+
plt.close(fig)
|