Upload folder using huggingface_hub
Browse files- .gitignore +169 -0
- .gitmodules +9 -0
- .vscode/launch.json +26 -0
- .vscode/settings.json +3 -0
- LICENSE +202 -0
- README.md +191 -3
- arc_eval.ipynb +252 -0
- assets/hrm.png +0 -0
- assets/npyjs.js +176 -0
- config/arch/hrm_v1.yaml +21 -0
- config/cfg_pretrain.yaml +31 -0
- dataset/build_arc_dataset.py +291 -0
- dataset/build_maze_dataset.py +142 -0
- dataset/build_sudoku_dataset.py +169 -0
- dataset/common.py +51 -0
- evaluate.py +68 -0
- models/common.py +32 -0
- models/hrm/hrm_act_v1.py +283 -0
- models/layers.py +158 -0
- models/losses.py +101 -0
- models/sparse_embedding.py +132 -0
- pretrain.py +453 -0
- puzzle_dataset.py +199 -0
- puzzle_visualizer.html +426 -0
- requirements.txt +11 -0
- utils/functions.py +19 -0
.gitignore
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# WandB
|
| 2 |
+
/wandb/
|
| 3 |
+
# checkpoints
|
| 4 |
+
/checkpoints/
|
| 5 |
+
# cache
|
| 6 |
+
/cache/
|
| 7 |
+
# data
|
| 8 |
+
/data/
|
| 9 |
+
|
| 10 |
+
# Byte-compiled / optimized / DLL files
|
| 11 |
+
__pycache__/
|
| 12 |
+
*.py[cod]
|
| 13 |
+
*$py.class
|
| 14 |
+
|
| 15 |
+
# C extensions
|
| 16 |
+
*.so
|
| 17 |
+
|
| 18 |
+
# Distribution / packaging
|
| 19 |
+
.Python
|
| 20 |
+
build/
|
| 21 |
+
develop-eggs/
|
| 22 |
+
dist/
|
| 23 |
+
downloads/
|
| 24 |
+
eggs/
|
| 25 |
+
.eggs/
|
| 26 |
+
lib/
|
| 27 |
+
lib64/
|
| 28 |
+
parts/
|
| 29 |
+
sdist/
|
| 30 |
+
var/
|
| 31 |
+
wheels/
|
| 32 |
+
share/python-wheels/
|
| 33 |
+
*.egg-info/
|
| 34 |
+
.installed.cfg
|
| 35 |
+
*.egg
|
| 36 |
+
MANIFEST
|
| 37 |
+
|
| 38 |
+
# PyInstaller
|
| 39 |
+
# Usually these files are written by a python script from a template
|
| 40 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 41 |
+
*.manifest
|
| 42 |
+
*.spec
|
| 43 |
+
|
| 44 |
+
# Installer logs
|
| 45 |
+
pip-log.txt
|
| 46 |
+
pip-delete-this-directory.txt
|
| 47 |
+
|
| 48 |
+
# Unit test / coverage reports
|
| 49 |
+
htmlcov/
|
| 50 |
+
.tox/
|
| 51 |
+
.nox/
|
| 52 |
+
.coverage
|
| 53 |
+
.coverage.*
|
| 54 |
+
.cache
|
| 55 |
+
nosetests.xml
|
| 56 |
+
coverage.xml
|
| 57 |
+
*.cover
|
| 58 |
+
*.py,cover
|
| 59 |
+
.hypothesis/
|
| 60 |
+
.pytest_cache/
|
| 61 |
+
cover/
|
| 62 |
+
|
| 63 |
+
# Translations
|
| 64 |
+
*.mo
|
| 65 |
+
*.pot
|
| 66 |
+
|
| 67 |
+
# Django stuff:
|
| 68 |
+
*.log
|
| 69 |
+
local_settings.py
|
| 70 |
+
db.sqlite3
|
| 71 |
+
db.sqlite3-journal
|
| 72 |
+
|
| 73 |
+
# Flask stuff:
|
| 74 |
+
instance/
|
| 75 |
+
.webassets-cache
|
| 76 |
+
|
| 77 |
+
# Scrapy stuff:
|
| 78 |
+
.scrapy
|
| 79 |
+
|
| 80 |
+
# Sphinx documentation
|
| 81 |
+
docs/_build/
|
| 82 |
+
|
| 83 |
+
# PyBuilder
|
| 84 |
+
.pybuilder/
|
| 85 |
+
target/
|
| 86 |
+
|
| 87 |
+
# Jupyter Notebook
|
| 88 |
+
.ipynb_checkpoints
|
| 89 |
+
|
| 90 |
+
# IPython
|
| 91 |
+
profile_default/
|
| 92 |
+
ipython_config.py
|
| 93 |
+
|
| 94 |
+
# pyenv
|
| 95 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 96 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 97 |
+
# .python-version
|
| 98 |
+
|
| 99 |
+
# pipenv
|
| 100 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 101 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 102 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 103 |
+
# install all needed dependencies.
|
| 104 |
+
#Pipfile.lock
|
| 105 |
+
|
| 106 |
+
# poetry
|
| 107 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 108 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 109 |
+
# commonly ignored for libraries.
|
| 110 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 111 |
+
#poetry.lock
|
| 112 |
+
|
| 113 |
+
# pdm
|
| 114 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 115 |
+
#pdm.lock
|
| 116 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 117 |
+
# in version control.
|
| 118 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 119 |
+
.pdm.toml
|
| 120 |
+
|
| 121 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 122 |
+
__pypackages__/
|
| 123 |
+
|
| 124 |
+
# Celery stuff
|
| 125 |
+
celerybeat-schedule
|
| 126 |
+
celerybeat.pid
|
| 127 |
+
|
| 128 |
+
# SageMath parsed files
|
| 129 |
+
*.sage.py
|
| 130 |
+
|
| 131 |
+
# Environments
|
| 132 |
+
.env
|
| 133 |
+
.venv
|
| 134 |
+
env/
|
| 135 |
+
venv/
|
| 136 |
+
ENV/
|
| 137 |
+
env.bak/
|
| 138 |
+
venv.bak/
|
| 139 |
+
|
| 140 |
+
# Spyder project settings
|
| 141 |
+
.spyderproject
|
| 142 |
+
.spyproject
|
| 143 |
+
|
| 144 |
+
# Rope project settings
|
| 145 |
+
.ropeproject
|
| 146 |
+
|
| 147 |
+
# mkdocs documentation
|
| 148 |
+
/site
|
| 149 |
+
|
| 150 |
+
# mypy
|
| 151 |
+
.mypy_cache/
|
| 152 |
+
.dmypy.json
|
| 153 |
+
dmypy.json
|
| 154 |
+
|
| 155 |
+
# Pyre type checker
|
| 156 |
+
.pyre/
|
| 157 |
+
|
| 158 |
+
# pytype static type analyzer
|
| 159 |
+
.pytype/
|
| 160 |
+
|
| 161 |
+
# Cython debug symbols
|
| 162 |
+
cython_debug/
|
| 163 |
+
|
| 164 |
+
# PyCharm
|
| 165 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 166 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 167 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 168 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 169 |
+
#.idea/
|
.gitmodules
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "dataset/raw-data/ConceptARC"]
|
| 2 |
+
path = dataset/raw-data/ConceptARC
|
| 3 |
+
url = git@github.com:victorvikram/ConceptARC.git
|
| 4 |
+
[submodule "dataset/raw-data/ARC-AGI"]
|
| 5 |
+
path = dataset/raw-data/ARC-AGI
|
| 6 |
+
url = git@github.com:fchollet/ARC-AGI.git
|
| 7 |
+
[submodule "dataset/raw-data/ARC-AGI-2"]
|
| 8 |
+
path = dataset/raw-data/ARC-AGI-2
|
| 9 |
+
url = git@github.com:arcprize/ARC-AGI-2.git
|
.vscode/launch.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
// Use IntelliSense to learn about possible attributes.
|
| 3 |
+
// Hover to view descriptions of existing attributes.
|
| 4 |
+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
| 5 |
+
"version": "0.2.0",
|
| 6 |
+
"configurations": [
|
| 7 |
+
{
|
| 8 |
+
"name": "Python Debugger: Current File",
|
| 9 |
+
"type": "debugpy",
|
| 10 |
+
"request": "launch",
|
| 11 |
+
"program": "${file}",
|
| 12 |
+
"console": "integratedTerminal"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"name": "Debug: Single GPU",
|
| 16 |
+
"type": "debugpy",
|
| 17 |
+
"request": "launch",
|
| 18 |
+
"program": "pretrain.py",
|
| 19 |
+
"args": [],
|
| 20 |
+
"env": {
|
| 21 |
+
"OMP_NUM_THREADS": "1",
|
| 22 |
+
"DISABLE_COMPILE": "true"
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
]
|
| 26 |
+
}
|
.vscode/settings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python.analysis.typeCheckingMode": "standard"
|
| 3 |
+
}
|
LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright [yyyy] [name of copyright owner]
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,3 +1,191 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hierarchical Reasoning Model
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
Reasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI.
|
| 6 |
+
Current large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency.
|
| 7 |
+
HRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes.
|
| 8 |
+
Furthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities.
|
| 9 |
+
These results underscore HRM’s potential as a transformative advancement toward universal computation and general-purpose reasoning systems.
|
| 10 |
+
|
| 11 |
+
## Quick Start Guide 🚀
|
| 12 |
+
|
| 13 |
+
### Prerequisites ⚙️
|
| 14 |
+
|
| 15 |
+
Ensure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. If not present, run the following commands:
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
# Install CUDA 12.6
|
| 19 |
+
CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run
|
| 20 |
+
|
| 21 |
+
wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
|
| 22 |
+
sudo sh cuda_installer.run --silent --toolkit --override
|
| 23 |
+
|
| 24 |
+
export CUDA_HOME=/usr/local/cuda-12.6
|
| 25 |
+
|
| 26 |
+
# Install PyTorch with CUDA 12.6
|
| 27 |
+
PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu126
|
| 28 |
+
|
| 29 |
+
pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
|
| 30 |
+
|
| 31 |
+
# Additional packages for building extensions
|
| 32 |
+
pip3 install packaging ninja wheel setuptools setuptools-scm
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
Then install FlashAttention. For Hopper GPUs, install FlashAttention 3
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
git clone git@github.com:Dao-AILab/flash-attention.git
|
| 39 |
+
cd flash-attention/hopper
|
| 40 |
+
python setup.py install
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
For Ampere or earlier GPUs, install FlashAttention 2
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
pip3 install flash-attn
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Install Python Dependencies 🐍
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
pip install -r requirements.txt
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## W&B Integration 📈
|
| 56 |
+
|
| 57 |
+
This project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
wandb login
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Run Experiments
|
| 64 |
+
|
| 65 |
+
### Quick Demo: Sudoku Solver 💻🗲
|
| 66 |
+
|
| 67 |
+
Train a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 🧩
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
# Download and build Sudoku dataset
|
| 71 |
+
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000
|
| 72 |
+
|
| 73 |
+
# Start training (single GPU, smaller batch size)
|
| 74 |
+
OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
Runtime: ~10 hours on a RTX 4070 laptop GPU
|
| 78 |
+
|
| 79 |
+
## Trained Checkpoints 🚧
|
| 80 |
+
|
| 81 |
+
- [ARC-AGI-2](https://huggingface.co/sapientinc/HRM-checkpoint-ARC-2)
|
| 82 |
+
- [Sudoku 9x9 Extreme (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-sudoku-extreme)
|
| 83 |
+
- [Maze 30x30 Hard (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-maze-30x30-hard)
|
| 84 |
+
|
| 85 |
+
To use the checkpoints, see Evaluation section below.
|
| 86 |
+
|
| 87 |
+
## Full-scale Experiments 🔵
|
| 88 |
+
|
| 89 |
+
Experiments below assume an 8-GPU setup.
|
| 90 |
+
|
| 91 |
+
### Dataset Preparation
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
# Initialize submodules
|
| 95 |
+
git submodule update --init --recursive
|
| 96 |
+
|
| 97 |
+
# ARC-1
|
| 98 |
+
python dataset/build_arc_dataset.py # ARC offical + ConceptARC, 960 examples
|
| 99 |
+
# ARC-2
|
| 100 |
+
python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000 # ARC-2 official, 1120 examples
|
| 101 |
+
|
| 102 |
+
# Sudoku-Extreme
|
| 103 |
+
python dataset/build_sudoku_dataset.py # Full version
|
| 104 |
+
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples
|
| 105 |
+
|
| 106 |
+
# Maze
|
| 107 |
+
python dataset/build_maze_dataset.py # 1000 examples
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
### Dataset Visualization
|
| 111 |
+
|
| 112 |
+
Explore the puzzles visually:
|
| 113 |
+
|
| 114 |
+
* Open `puzzle_visualizer.html` in your browser.
|
| 115 |
+
* Upload the generated dataset folder located in `data/...`.
|
| 116 |
+
|
| 117 |
+
## Launch experiments
|
| 118 |
+
|
| 119 |
+
### Small-sample (1K)
|
| 120 |
+
|
| 121 |
+
ARC-1:
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
*Runtime:* ~24 hours
|
| 128 |
+
|
| 129 |
+
ARC-2:
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
*Runtime:* ~24 hours (checkpoint after 8 hours is often sufficient)
|
| 136 |
+
|
| 137 |
+
Sudoku Extreme (1k):
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
*Runtime:* ~10 minutes
|
| 144 |
+
|
| 145 |
+
Maze 30x30 Hard (1k):
|
| 146 |
+
|
| 147 |
+
```bash
|
| 148 |
+
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
*Runtime:* ~1 hour
|
| 152 |
+
|
| 153 |
+
### Full Sudoku-Hard
|
| 154 |
+
|
| 155 |
+
```bash
|
| 156 |
+
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
*Runtime:* ~2 hours
|
| 160 |
+
|
| 161 |
+
## Evaluation
|
| 162 |
+
|
| 163 |
+
Evaluate your trained models:
|
| 164 |
+
|
| 165 |
+
* Check `eval/exact_accuracy` in W&B.
|
| 166 |
+
* For ARC-AGI, follow these additional steps:
|
| 167 |
+
|
| 168 |
+
```bash
|
| 169 |
+
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH>
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
* Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results.
|
| 173 |
+
|
| 174 |
+
## Notes
|
| 175 |
+
|
| 176 |
+
- Small-sample learning typically exhibits accuracy variance of around ±2 points.
|
| 177 |
+
- For Sudoku-Extreme (1,000-example dataset), late-stage overfitting may cause numerical instability during training and Q-learning. It is advisable to use early stopping once the training accuracy approaches 100%.
|
| 178 |
+
|
| 179 |
+
## Citation 📜
|
| 180 |
+
|
| 181 |
+
```bibtex
|
| 182 |
+
@misc{wang2025hierarchicalreasoningmodel,
|
| 183 |
+
title={Hierarchical Reasoning Model},
|
| 184 |
+
author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
|
| 185 |
+
year={2025},
|
| 186 |
+
eprint={2506.21734},
|
| 187 |
+
archivePrefix={arXiv},
|
| 188 |
+
primaryClass={cs.AI},
|
| 189 |
+
url={https://arxiv.org/abs/2506.21734},
|
| 190 |
+
}
|
| 191 |
+
```
|
arc_eval.ipynb
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import os\n",
|
| 10 |
+
"import json\n",
|
| 11 |
+
"from glob import glob\n",
|
| 12 |
+
"import hashlib\n",
|
| 13 |
+
"import matplotlib.pyplot as plt\n",
|
| 14 |
+
"import matplotlib.colors as mcolors\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"import torch\n",
|
| 17 |
+
"import torch.nn.functional as F\n",
|
| 18 |
+
"import numpy as np\n",
|
| 19 |
+
"from numba import njit\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"from dataset.common import inverse_dihedral_transform\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n",
|
| 25 |
+
"# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"PAD_PUZZLE_IDENTIFIER = 0\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"# Visualization\n",
|
| 33 |
+
"ARC_COLOR_MAP = mcolors.ListedColormap([\n",
|
| 34 |
+
" \"#000000\", # symbol_0: black\n",
|
| 35 |
+
" \"#0074D9\", # symbol_1: blue\n",
|
| 36 |
+
" \"#FF4136\", # symbol_2: red\n",
|
| 37 |
+
" \"#2ECC40\", # symbol_3: green\n",
|
| 38 |
+
" \"#FFDC00\", # symbol_4: yellow\n",
|
| 39 |
+
" \"#AAAAAA\", # symbol_5: grey\n",
|
| 40 |
+
" \"#F012BE\", # symbol_6: fuschia\n",
|
| 41 |
+
" \"#FF851B\", # symbol_7: orange\n",
|
| 42 |
+
" \"#7FDBFF\", # symbol_8: teal\n",
|
| 43 |
+
" \"#870C25\" # symbol_9: brown\n",
|
| 44 |
+
"])"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": null,
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"outputs": [],
|
| 52 |
+
"source": [
|
| 53 |
+
"def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n",
|
| 54 |
+
" # Load puzzle identifiers\n",
|
| 55 |
+
" with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n",
|
| 56 |
+
" identifier_map = json.load(f)\n",
|
| 57 |
+
" \n",
|
| 58 |
+
" # Load preds\n",
|
| 59 |
+
" all_preds = {}\n",
|
| 60 |
+
" for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n",
|
| 61 |
+
" preds = torch.load(filename)\n",
|
| 62 |
+
" for k, v in preds.items():\n",
|
| 63 |
+
" all_preds.setdefault(k, [])\n",
|
| 64 |
+
" all_preds[k].append(v)\n",
|
| 65 |
+
" \n",
|
| 66 |
+
" del preds\n",
|
| 67 |
+
"\n",
|
| 68 |
+
" all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n",
|
| 69 |
+
" \n",
|
| 70 |
+
" # Remove paddings\n",
|
| 71 |
+
" mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n",
|
| 72 |
+
" all_preds = {k: v[mask] for k, v in all_preds.items()}\n",
|
| 73 |
+
"\n",
|
| 74 |
+
" return identifier_map, all_preds\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"def inverse_aug(name: str, grid: np.ndarray):\n",
|
| 78 |
+
" if \"_\" not in name:\n",
|
| 79 |
+
" return grid\n",
|
| 80 |
+
"\n",
|
| 81 |
+
" trans_id, perm = name.split(\"_\")[-2:]\n",
|
| 82 |
+
" trans_id = int(trans_id[1:]) # Remove \"t\" letter\n",
|
| 83 |
+
" inv_perm = np.argsort(list(perm))\n",
|
| 84 |
+
" \n",
|
| 85 |
+
" return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"def grid_hash(grid: np.ndarray):\n",
|
| 89 |
+
" return hash((grid.tobytes(), grid.shape))\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"@njit\n",
|
| 93 |
+
"def crop(grid: np.ndarray):\n",
|
| 94 |
+
" # Find maximum-sized rectangle without any EOS token inside.\n",
|
| 95 |
+
" grid = grid.reshape(30, 30)\n",
|
| 96 |
+
"\n",
|
| 97 |
+
" max_area = 0\n",
|
| 98 |
+
" max_size = (0, 0)\n",
|
| 99 |
+
" nr, nc = grid.shape\n",
|
| 100 |
+
" \n",
|
| 101 |
+
" num_c = nc\n",
|
| 102 |
+
" for num_r in range(1, nr + 1):\n",
|
| 103 |
+
" # Scan for maximum c\n",
|
| 104 |
+
" for c in range(1, num_c + 1):\n",
|
| 105 |
+
" x = grid[num_r - 1, c - 1]\n",
|
| 106 |
+
" if (x < 2) | (x > 11):\n",
|
| 107 |
+
" num_c = c - 1\n",
|
| 108 |
+
" break\n",
|
| 109 |
+
" \n",
|
| 110 |
+
" area = num_r * num_c\n",
|
| 111 |
+
" if area > max_area:\n",
|
| 112 |
+
" max_area = area\n",
|
| 113 |
+
" max_size = (num_r, num_c)\n",
|
| 114 |
+
"\n",
|
| 115 |
+
" return grid[:max_size[0], :max_size[1]] - 2\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n",
|
| 119 |
+
" identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n",
|
| 120 |
+
" \n",
|
| 121 |
+
" global_hmap = {}\n",
|
| 122 |
+
" \n",
|
| 123 |
+
" # Get puzzles and corresponding answers\n",
|
| 124 |
+
" puzzle_labels = {}\n",
|
| 125 |
+
" for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n",
|
| 126 |
+
" name = identifier_map[identifier]\n",
|
| 127 |
+
" if \"_\" not in name: # Not-augmented\n",
|
| 128 |
+
" puzzle_labels.setdefault(name, {})\n",
|
| 129 |
+
" \n",
|
| 130 |
+
" input = crop(input.numpy())\n",
|
| 131 |
+
" label = crop(label.numpy())\n",
|
| 132 |
+
"\n",
|
| 133 |
+
" input_hash = grid_hash(input)\n",
|
| 134 |
+
" label_hash = grid_hash(label)\n",
|
| 135 |
+
"\n",
|
| 136 |
+
" global_hmap[input_hash] = input\n",
|
| 137 |
+
" global_hmap[label_hash] = label\n",
|
| 138 |
+
"\n",
|
| 139 |
+
" assert input_hash not in puzzle_labels[name]\n",
|
| 140 |
+
" puzzle_labels[name][input_hash] = label_hash\n",
|
| 141 |
+
" \n",
|
| 142 |
+
" print (\"Number of puzzles\", len(puzzle_labels))\n",
|
| 143 |
+
" \n",
|
| 144 |
+
" # Argmax prediction\n",
|
| 145 |
+
" preds = all_preds[\"logits\"].argmax(-1)\n",
|
| 146 |
+
"\n",
|
| 147 |
+
" # Collate\n",
|
| 148 |
+
" pred_answers = {}\n",
|
| 149 |
+
" for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n",
|
| 150 |
+
" name = identifier_map[identifier]\n",
|
| 151 |
+
" orig_name = name.split(\"_\")[0]\n",
|
| 152 |
+
" \n",
|
| 153 |
+
" input = input.numpy()\n",
|
| 154 |
+
" input_hash = grid_hash(inverse_aug(name, crop(input)))\n",
|
| 155 |
+
" assert input_hash in puzzle_labels[orig_name]\n",
|
| 156 |
+
" \n",
|
| 157 |
+
" pred = inverse_aug(name, crop(pred.numpy()))\n",
|
| 158 |
+
" pred_hash = grid_hash(pred)\n",
|
| 159 |
+
" global_hmap[pred_hash] = pred\n",
|
| 160 |
+
" \n",
|
| 161 |
+
" pred_answers.setdefault(orig_name, {})\n",
|
| 162 |
+
" pred_answers[orig_name].setdefault(input_hash, [])\n",
|
| 163 |
+
" pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n",
|
| 164 |
+
"\n",
|
| 165 |
+
" # test-1\n",
|
| 166 |
+
" if visualize:\n",
|
| 167 |
+
" num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n",
|
| 168 |
+
" fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n",
|
| 169 |
+
" \n",
|
| 170 |
+
" fig_id = 0\n",
|
| 171 |
+
" \n",
|
| 172 |
+
" correct = [0 for _ in range(len(Ks))]\n",
|
| 173 |
+
" for name, tests in puzzle_labels.items():\n",
|
| 174 |
+
" num_test_correct = [0 for _ in range(len(Ks))]\n",
|
| 175 |
+
" for input_hash, label_hash in tests.items():\n",
|
| 176 |
+
" p = pred_answers[name][input_hash]\n",
|
| 177 |
+
" p_map = {}\n",
|
| 178 |
+
" \n",
|
| 179 |
+
" for h, q in p:\n",
|
| 180 |
+
" p_map.setdefault(h, [0, 0])\n",
|
| 181 |
+
" p_map[h][0] += 1\n",
|
| 182 |
+
" p_map[h][1] += q\n",
|
| 183 |
+
" \n",
|
| 184 |
+
" for h, stats in p_map.items():\n",
|
| 185 |
+
" stats[1] /= stats[0]\n",
|
| 186 |
+
" \n",
|
| 187 |
+
" p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n",
|
| 188 |
+
"\n",
|
| 189 |
+
" # 2-vote\n",
|
| 190 |
+
" for i, k in enumerate(Ks):\n",
|
| 191 |
+
" ok = False\n",
|
| 192 |
+
" for h, stats in p_map[:k]:\n",
|
| 193 |
+
" ok |= h == label_hash\n",
|
| 194 |
+
" \n",
|
| 195 |
+
" num_test_correct[i] += ok\n",
|
| 196 |
+
"\n",
|
| 197 |
+
" if visualize:\n",
|
| 198 |
+
" # Show input and ground truth\n",
|
| 199 |
+
" axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n",
|
| 200 |
+
" axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n",
|
| 201 |
+
" axes[fig_id, 0].axis('off')\n",
|
| 202 |
+
" \n",
|
| 203 |
+
" axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n",
|
| 204 |
+
" axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n",
|
| 205 |
+
" axes[fig_id, 1].axis('off')\n",
|
| 206 |
+
" \n",
|
| 207 |
+
" trial_id = 2\n",
|
| 208 |
+
" for h, stats in p_map[:2]:\n",
|
| 209 |
+
" ans = global_hmap[h]\n",
|
| 210 |
+
" \n",
|
| 211 |
+
" axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n",
|
| 212 |
+
" axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n",
|
| 213 |
+
" axes[fig_id, trial_id].axis('off')\n",
|
| 214 |
+
" \n",
|
| 215 |
+
" trial_id += 1\n",
|
| 216 |
+
" \n",
|
| 217 |
+
" fig_id += 1\n",
|
| 218 |
+
" \n",
|
| 219 |
+
" # Total correctness\n",
|
| 220 |
+
" for i in range(len(Ks)):\n",
|
| 221 |
+
" correct[i] += num_test_correct[i] == len(tests)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
" for i, k in enumerate(Ks):\n",
|
| 224 |
+
" print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"\n",
|
| 227 |
+
"test(visualize=False)"
|
| 228 |
+
]
|
| 229 |
+
}
|
| 230 |
+
],
|
| 231 |
+
"metadata": {
|
| 232 |
+
"kernelspec": {
|
| 233 |
+
"display_name": "Python 3",
|
| 234 |
+
"language": "python",
|
| 235 |
+
"name": "python3"
|
| 236 |
+
},
|
| 237 |
+
"language_info": {
|
| 238 |
+
"codemirror_mode": {
|
| 239 |
+
"name": "ipython",
|
| 240 |
+
"version": 3
|
| 241 |
+
},
|
| 242 |
+
"file_extension": ".py",
|
| 243 |
+
"mimetype": "text/x-python",
|
| 244 |
+
"name": "python",
|
| 245 |
+
"nbconvert_exporter": "python",
|
| 246 |
+
"pygments_lexer": "ipython3",
|
| 247 |
+
"version": "3.12.10"
|
| 248 |
+
}
|
| 249 |
+
},
|
| 250 |
+
"nbformat": 4,
|
| 251 |
+
"nbformat_minor": 2
|
| 252 |
+
}
|
assets/hrm.png
ADDED
|
assets/npyjs.js
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class npyjs {
|
| 2 |
+
|
| 3 |
+
constructor(opts) {
|
| 4 |
+
if (opts && !('convertFloat16' in opts)) {
|
| 5 |
+
console.warn([
|
| 6 |
+
"npyjs constructor now accepts {convertFloat16?: boolean}.",
|
| 7 |
+
"For usage, go to https://github.com/jhuapl-boss/npyjs."
|
| 8 |
+
].join(" "));
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
this.convertFloat16 = opts?.convertFloat16 ?? true;
|
| 12 |
+
|
| 13 |
+
this.dtypes = {
|
| 14 |
+
"<u1": {
|
| 15 |
+
name: "uint8",
|
| 16 |
+
size: 8,
|
| 17 |
+
arrayConstructor: Uint8Array,
|
| 18 |
+
},
|
| 19 |
+
"|u1": {
|
| 20 |
+
name: "uint8",
|
| 21 |
+
size: 8,
|
| 22 |
+
arrayConstructor: Uint8Array,
|
| 23 |
+
},
|
| 24 |
+
"<u2": {
|
| 25 |
+
name: "uint16",
|
| 26 |
+
size: 16,
|
| 27 |
+
arrayConstructor: Uint16Array,
|
| 28 |
+
},
|
| 29 |
+
"|i1": {
|
| 30 |
+
name: "int8",
|
| 31 |
+
size: 8,
|
| 32 |
+
arrayConstructor: Int8Array,
|
| 33 |
+
},
|
| 34 |
+
"<i2": {
|
| 35 |
+
name: "int16",
|
| 36 |
+
size: 16,
|
| 37 |
+
arrayConstructor: Int16Array,
|
| 38 |
+
},
|
| 39 |
+
"<u4": {
|
| 40 |
+
name: "uint32",
|
| 41 |
+
size: 32,
|
| 42 |
+
arrayConstructor: Uint32Array,
|
| 43 |
+
},
|
| 44 |
+
"<i4": {
|
| 45 |
+
name: "int32",
|
| 46 |
+
size: 32,
|
| 47 |
+
arrayConstructor: Int32Array,
|
| 48 |
+
},
|
| 49 |
+
"<u8": {
|
| 50 |
+
name: "uint64",
|
| 51 |
+
size: 64,
|
| 52 |
+
arrayConstructor: BigUint64Array,
|
| 53 |
+
},
|
| 54 |
+
"<i8": {
|
| 55 |
+
name: "int64",
|
| 56 |
+
size: 64,
|
| 57 |
+
arrayConstructor: BigInt64Array,
|
| 58 |
+
},
|
| 59 |
+
"<f4": {
|
| 60 |
+
name: "float32",
|
| 61 |
+
size: 32,
|
| 62 |
+
arrayConstructor: Float32Array
|
| 63 |
+
},
|
| 64 |
+
"<f8": {
|
| 65 |
+
name: "float64",
|
| 66 |
+
size: 64,
|
| 67 |
+
arrayConstructor: Float64Array
|
| 68 |
+
},
|
| 69 |
+
"<f2": {
|
| 70 |
+
name: "float16",
|
| 71 |
+
size: 16,
|
| 72 |
+
arrayConstructor: Uint16Array,
|
| 73 |
+
converter: this.convertFloat16 ? this.float16ToFloat32Array : undefined
|
| 74 |
+
},
|
| 75 |
+
};
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
float16ToFloat32Array(float16Array) {
|
| 79 |
+
const length = float16Array.length;
|
| 80 |
+
const float32Array = new Float32Array(length);
|
| 81 |
+
|
| 82 |
+
for (let i = 0; i < length; i++) {
|
| 83 |
+
float32Array[i] = npyjs.float16ToFloat32(float16Array[i]);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
return float32Array;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
static float16ToFloat32(float16) {
|
| 90 |
+
// Extract the parts of the float16
|
| 91 |
+
const sign = (float16 >> 15) & 0x1;
|
| 92 |
+
const exponent = (float16 >> 10) & 0x1f;
|
| 93 |
+
const fraction = float16 & 0x3ff;
|
| 94 |
+
|
| 95 |
+
// Handle special cases
|
| 96 |
+
if (exponent === 0) {
|
| 97 |
+
if (fraction === 0) {
|
| 98 |
+
// Zero
|
| 99 |
+
return sign ? -0 : 0;
|
| 100 |
+
}
|
| 101 |
+
// Denormalized number
|
| 102 |
+
return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400);
|
| 103 |
+
} else if (exponent === 0x1f) {
|
| 104 |
+
if (fraction === 0) {
|
| 105 |
+
// Infinity
|
| 106 |
+
return sign ? -Infinity : Infinity;
|
| 107 |
+
}
|
| 108 |
+
// NaN
|
| 109 |
+
return NaN;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
// Normalized number
|
| 113 |
+
return (sign ? -1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400);
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
parse(arrayBufferContents) {
|
| 117 |
+
// const version = arrayBufferContents.slice(6, 8); // Uint8-encoded
|
| 118 |
+
const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0);
|
| 119 |
+
const offsetBytes = 10 + headerLength;
|
| 120 |
+
|
| 121 |
+
const hcontents = new TextDecoder("utf-8").decode(
|
| 122 |
+
new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength))
|
| 123 |
+
);
|
| 124 |
+
const header = JSON.parse(
|
| 125 |
+
hcontents
|
| 126 |
+
.toLowerCase() // True -> true
|
| 127 |
+
.replace(/'/g, '"')
|
| 128 |
+
.replace("(", "[")
|
| 129 |
+
.replace(/,*\),*/g, "]")
|
| 130 |
+
);
|
| 131 |
+
const shape = header.shape;
|
| 132 |
+
const dtype = this.dtypes[header.descr];
|
| 133 |
+
|
| 134 |
+
if (!dtype) {
|
| 135 |
+
console.error(`Unsupported dtype: ${header.descr}`);
|
| 136 |
+
return null;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
const nums = new dtype.arrayConstructor(
|
| 140 |
+
arrayBufferContents,
|
| 141 |
+
offsetBytes
|
| 142 |
+
);
|
| 143 |
+
|
| 144 |
+
// Convert float16 to float32 if converter exists
|
| 145 |
+
const data = dtype.converter ? dtype.converter.call(this, nums) : nums;
|
| 146 |
+
|
| 147 |
+
return {
|
| 148 |
+
dtype: dtype.name,
|
| 149 |
+
data: data,
|
| 150 |
+
shape,
|
| 151 |
+
fortranOrder: header.fortran_order
|
| 152 |
+
};
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
async load(filename, callback, fetchArgs) {
|
| 156 |
+
/*
|
| 157 |
+
Loads an array from a stream of bytes.
|
| 158 |
+
*/
|
| 159 |
+
fetchArgs = fetchArgs || {};
|
| 160 |
+
let arrayBuf;
|
| 161 |
+
// If filename is ArrayBuffer
|
| 162 |
+
if (filename instanceof ArrayBuffer) {
|
| 163 |
+
arrayBuf = filename;
|
| 164 |
+
}
|
| 165 |
+
// If filename is a file path
|
| 166 |
+
else {
|
| 167 |
+
const resp = await fetch(filename, { ...fetchArgs });
|
| 168 |
+
arrayBuf = await resp.arrayBuffer();
|
| 169 |
+
}
|
| 170 |
+
const result = this.parse(arrayBuf);
|
| 171 |
+
if (callback) {
|
| 172 |
+
return callback(result);
|
| 173 |
+
}
|
| 174 |
+
return result;
|
| 175 |
+
}
|
| 176 |
+
}
|
config/arch/hrm_v1.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1
|
| 2 |
+
loss:
|
| 3 |
+
name: losses@ACTLossHead
|
| 4 |
+
loss_type: stablemax_cross_entropy
|
| 5 |
+
|
| 6 |
+
halt_exploration_prob: 0.1
|
| 7 |
+
halt_max_steps: 16
|
| 8 |
+
|
| 9 |
+
H_cycles: 2
|
| 10 |
+
L_cycles: 2
|
| 11 |
+
|
| 12 |
+
H_layers: 4
|
| 13 |
+
L_layers: 4
|
| 14 |
+
|
| 15 |
+
hidden_size: 512
|
| 16 |
+
num_heads: 8 # min(2, hidden_size // 64)
|
| 17 |
+
expansion: 4
|
| 18 |
+
|
| 19 |
+
puzzle_emb_ndim: ${.hidden_size}
|
| 20 |
+
|
| 21 |
+
pos_encodings: rope
|
config/cfg_pretrain.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ARC training config
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- arch: hrm_v1
|
| 5 |
+
- _self_
|
| 6 |
+
|
| 7 |
+
hydra:
|
| 8 |
+
output_subdir: null
|
| 9 |
+
|
| 10 |
+
# Data path
|
| 11 |
+
data_path: data/arc-aug-1000
|
| 12 |
+
|
| 13 |
+
# Hyperparams - Training
|
| 14 |
+
global_batch_size: 768
|
| 15 |
+
|
| 16 |
+
epochs: 100000
|
| 17 |
+
eval_interval: 10000
|
| 18 |
+
checkpoint_every_eval: True
|
| 19 |
+
|
| 20 |
+
lr: 1e-4
|
| 21 |
+
lr_min_ratio: 1.0
|
| 22 |
+
lr_warmup_steps: 2000
|
| 23 |
+
|
| 24 |
+
# Standard hyperparameter settings for LM, as used in Llama
|
| 25 |
+
beta1: 0.9
|
| 26 |
+
beta2: 0.95
|
| 27 |
+
weight_decay: 0.1
|
| 28 |
+
puzzle_emb_weight_decay: 0.1
|
| 29 |
+
|
| 30 |
+
# Hyperparams - Puzzle embeddings training
|
| 31 |
+
puzzle_emb_lr: 1e-2
|
dataset/build_arc_dataset.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Dict
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import hashlib
|
| 7 |
+
import numpy as np
|
| 8 |
+
from glob import glob
|
| 9 |
+
|
| 10 |
+
from argdantic import ArgParser
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
|
| 13 |
+
from common import PuzzleDatasetMetadata, dihedral_transform
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
cli = ArgParser()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DataProcessConfig(BaseModel):
|
| 20 |
+
# ARC-1
|
| 21 |
+
dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"]
|
| 22 |
+
output_dir: str = "data/arc-aug-1000"
|
| 23 |
+
|
| 24 |
+
# ARC-2
|
| 25 |
+
# dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"]
|
| 26 |
+
# output_dir: str = "data/arc-2-aug-1000"
|
| 27 |
+
|
| 28 |
+
seed: int = 42
|
| 29 |
+
num_aug: int = 1000
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
ARCMaxGridSize = 30
|
| 33 |
+
ARCAugmentRetriesFactor = 5
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class ARCPuzzle:
|
| 38 |
+
id: str
|
| 39 |
+
|
| 40 |
+
examples: List[Tuple[np.ndarray, np.ndarray]]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def arc_grid_to_np(grid: List[List[int]]):
|
| 44 |
+
arr = np.array(grid)
|
| 45 |
+
|
| 46 |
+
# Shape check
|
| 47 |
+
assert arr.ndim == 2
|
| 48 |
+
assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
|
| 49 |
+
# Element check
|
| 50 |
+
assert np.all((arr >= 0) & (arr <= 9))
|
| 51 |
+
return arr.astype(np.uint8)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
|
| 55 |
+
# PAD: 0, <eos>: 1, digits: 2 ... 11
|
| 56 |
+
# Compute random top-left pad
|
| 57 |
+
if do_translation:
|
| 58 |
+
pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
|
| 59 |
+
pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
|
| 60 |
+
else:
|
| 61 |
+
pad_r = pad_c = 0
|
| 62 |
+
|
| 63 |
+
# Pad grid
|
| 64 |
+
result = []
|
| 65 |
+
for grid in [inp, out]:
|
| 66 |
+
nrow, ncol = grid.shape
|
| 67 |
+
grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)
|
| 68 |
+
|
| 69 |
+
# Add <eos>
|
| 70 |
+
eos_row, eos_col = pad_r + nrow, pad_c + ncol
|
| 71 |
+
if eos_row < ARCMaxGridSize:
|
| 72 |
+
grid[eos_row, pad_c:eos_col] = 1
|
| 73 |
+
if eos_col < ARCMaxGridSize:
|
| 74 |
+
grid[pad_r:eos_row, eos_col] = 1
|
| 75 |
+
|
| 76 |
+
result.append(grid.flatten())
|
| 77 |
+
|
| 78 |
+
return result
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def puzzle_hash(puzzle: dict):
|
| 82 |
+
# Hash the puzzle for checking equivalence
|
| 83 |
+
def _grid_hash(grid: np.ndarray):
|
| 84 |
+
buffer = [x.to_bytes(1) for x in grid.shape]
|
| 85 |
+
buffer.append(grid.tobytes())
|
| 86 |
+
|
| 87 |
+
return hashlib.sha256(b"".join(buffer)).hexdigest()
|
| 88 |
+
|
| 89 |
+
hashes = []
|
| 90 |
+
for example_type, example in puzzle.items():
|
| 91 |
+
for input, label in example.examples:
|
| 92 |
+
hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}")
|
| 93 |
+
|
| 94 |
+
hashes.sort()
|
| 95 |
+
return hashlib.sha256("|".join(hashes).encode()).hexdigest()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
|
| 99 |
+
# Remove "name"
|
| 100 |
+
name = puzzle.pop("name", default_name)
|
| 101 |
+
|
| 102 |
+
# Convert
|
| 103 |
+
dests = set(dest_mapping.values())
|
| 104 |
+
converted = {dest: ARCPuzzle(name, []) for dest in dests}
|
| 105 |
+
for example_type, examples in puzzle.items():
|
| 106 |
+
dest = dest_mapping[example_type]
|
| 107 |
+
converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])
|
| 108 |
+
|
| 109 |
+
group = [converted]
|
| 110 |
+
|
| 111 |
+
# Augment
|
| 112 |
+
if aug_count > 0:
|
| 113 |
+
hashes = {puzzle_hash(converted)}
|
| 114 |
+
|
| 115 |
+
for _trial in range(ARCAugmentRetriesFactor * aug_count):
|
| 116 |
+
# Augment plan
|
| 117 |
+
trans_id = np.random.randint(0, 8)
|
| 118 |
+
mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black)
|
| 119 |
+
|
| 120 |
+
aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}"
|
| 121 |
+
|
| 122 |
+
def _map_grid(grid: np.ndarray):
|
| 123 |
+
return dihedral_transform(mapping[grid], trans_id)
|
| 124 |
+
|
| 125 |
+
# Check duplicate
|
| 126 |
+
augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
|
| 127 |
+
h = puzzle_hash(augmented)
|
| 128 |
+
if h not in hashes:
|
| 129 |
+
hashes.add(h)
|
| 130 |
+
group.append(augmented)
|
| 131 |
+
|
| 132 |
+
if len(group) >= aug_count + 1:
|
| 133 |
+
break
|
| 134 |
+
|
| 135 |
+
if len(group) < aug_count + 1:
|
| 136 |
+
print (f"[Puzzle {name}] augmentation not full, only {len(group)}")
|
| 137 |
+
|
| 138 |
+
# Append
|
| 139 |
+
for dest in dests:
|
| 140 |
+
# Convert the examples
|
| 141 |
+
dest_split, dest_set = dest
|
| 142 |
+
|
| 143 |
+
results.setdefault(dest_split, {})
|
| 144 |
+
results[dest_split].setdefault(dest_set, [])
|
| 145 |
+
results[dest_split][dest_set].append([converted[dest] for converted in group])
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):
|
| 149 |
+
train_examples_dest = ("train", "all")
|
| 150 |
+
test_examples_map = {
|
| 151 |
+
"evaluation": [(1.0, ("test", "all"))],
|
| 152 |
+
"_default": [(1.0, ("train", "all"))]
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
total_puzzles = 0
|
| 156 |
+
for subdir in os.scandir(dataset_path):
|
| 157 |
+
if subdir.is_dir():
|
| 158 |
+
# Load all puzzles in this directory
|
| 159 |
+
puzzles = []
|
| 160 |
+
for filename in glob(os.path.join(subdir.path, "*.json")):
|
| 161 |
+
with open(filename, "r") as f:
|
| 162 |
+
puzzles.append((Path(filename).stem, json.load(f)))
|
| 163 |
+
|
| 164 |
+
# Shuffle puzzles
|
| 165 |
+
np.random.shuffle(puzzles)
|
| 166 |
+
|
| 167 |
+
# Assign by fraction
|
| 168 |
+
for idx, (default_name, puzzle) in enumerate(puzzles):
|
| 169 |
+
fraction = idx / len(puzzles)
|
| 170 |
+
test_examples_dest = None
|
| 171 |
+
for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]):
|
| 172 |
+
if fraction < f:
|
| 173 |
+
test_examples_dest = dest
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
assert test_examples_dest is not None
|
| 177 |
+
|
| 178 |
+
convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
|
| 179 |
+
total_puzzles += 1
|
| 180 |
+
|
| 181 |
+
print (f"[{dataset_path}] total puzzles: {total_puzzles}")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def convert_dataset(config: DataProcessConfig):
|
| 185 |
+
np.random.seed(config.seed)
|
| 186 |
+
|
| 187 |
+
# Read dataset
|
| 188 |
+
data = {}
|
| 189 |
+
for dataset_dir in config.dataset_dirs:
|
| 190 |
+
load_puzzles_arcagi(data, dataset_dir, config)
|
| 191 |
+
|
| 192 |
+
# Map global puzzle identifiers
|
| 193 |
+
num_identifiers = 1 # 0 is blank
|
| 194 |
+
identifier_map = {}
|
| 195 |
+
for split_name, split in data.items():
|
| 196 |
+
for subset_name, subset in split.items():
|
| 197 |
+
for group in subset:
|
| 198 |
+
for puzzle in group:
|
| 199 |
+
if puzzle.id not in identifier_map:
|
| 200 |
+
identifier_map[puzzle.id] = num_identifiers
|
| 201 |
+
num_identifiers += 1
|
| 202 |
+
|
| 203 |
+
print (f"Total puzzle IDs (including <blank>): {num_identifiers}")
|
| 204 |
+
|
| 205 |
+
# Save
|
| 206 |
+
for split_name, split in data.items():
|
| 207 |
+
os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)
|
| 208 |
+
|
| 209 |
+
# Translational augmentations
|
| 210 |
+
enable_translational_augment = split_name == "train"
|
| 211 |
+
|
| 212 |
+
# Statistics
|
| 213 |
+
total_examples = 0
|
| 214 |
+
total_puzzles = 0
|
| 215 |
+
total_groups = 0
|
| 216 |
+
|
| 217 |
+
for subset_name, subset in split.items():
|
| 218 |
+
# Construct subset
|
| 219 |
+
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
|
| 220 |
+
results["puzzle_indices"].append(0)
|
| 221 |
+
results["group_indices"].append(0)
|
| 222 |
+
|
| 223 |
+
example_id = 0
|
| 224 |
+
puzzle_id = 0
|
| 225 |
+
|
| 226 |
+
for group in subset:
|
| 227 |
+
for puzzle in group:
|
| 228 |
+
# Push puzzle
|
| 229 |
+
no_aug_id = np.random.randint(0, len(puzzle.examples))
|
| 230 |
+
for _idx_ex, (inp, out) in enumerate(puzzle.examples):
|
| 231 |
+
inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)
|
| 232 |
+
|
| 233 |
+
results["inputs"].append(inp)
|
| 234 |
+
results["labels"].append(out)
|
| 235 |
+
example_id += 1
|
| 236 |
+
|
| 237 |
+
total_examples += 1
|
| 238 |
+
|
| 239 |
+
results["puzzle_indices"].append(example_id)
|
| 240 |
+
results["puzzle_identifiers"].append(identifier_map[puzzle.id])
|
| 241 |
+
|
| 242 |
+
puzzle_id += 1
|
| 243 |
+
|
| 244 |
+
total_puzzles += 1
|
| 245 |
+
|
| 246 |
+
# Push group
|
| 247 |
+
results["group_indices"].append(puzzle_id)
|
| 248 |
+
total_groups += 1
|
| 249 |
+
|
| 250 |
+
for k, v in results.items():
|
| 251 |
+
if k in {"inputs", "labels"}:
|
| 252 |
+
v = np.stack(v, 0)
|
| 253 |
+
else:
|
| 254 |
+
v = np.array(v, dtype=np.int32)
|
| 255 |
+
|
| 256 |
+
np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)
|
| 257 |
+
|
| 258 |
+
# Metadata
|
| 259 |
+
metadata = PuzzleDatasetMetadata(
|
| 260 |
+
seq_len=ARCMaxGridSize * ARCMaxGridSize,
|
| 261 |
+
vocab_size=10 + 2, # PAD + EOS + "0" ... "9"
|
| 262 |
+
|
| 263 |
+
pad_id=0,
|
| 264 |
+
ignore_label_id=0,
|
| 265 |
+
|
| 266 |
+
blank_identifier_id=0,
|
| 267 |
+
num_puzzle_identifiers=num_identifiers,
|
| 268 |
+
|
| 269 |
+
total_groups=total_groups,
|
| 270 |
+
mean_puzzle_examples=total_examples / total_puzzles,
|
| 271 |
+
sets=list(split.keys())
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Save metadata as JSON.
|
| 275 |
+
with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
|
| 276 |
+
json.dump(metadata.model_dump(), f)
|
| 277 |
+
|
| 278 |
+
# Save IDs mapping
|
| 279 |
+
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
|
| 280 |
+
ids_mapping = {v: k for k, v in identifier_map.items()}
|
| 281 |
+
|
| 282 |
+
json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
@cli.command(singleton=True)
|
| 286 |
+
def main(config: DataProcessConfig):
|
| 287 |
+
convert_dataset(config)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
if __name__ == "__main__":
|
| 291 |
+
cli()
|
dataset/build_maze_dataset.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import csv
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from argdantic import ArgParser
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
+
|
| 13 |
+
from common import PuzzleDatasetMetadata, dihedral_transform
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
CHARSET = "# SGo"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
cli = ArgParser()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DataProcessConfig(BaseModel):
|
| 23 |
+
source_repo: str = "sapientinc/maze-30x30-hard-1k"
|
| 24 |
+
output_dir: str = "data/maze-30x30-hard-1k"
|
| 25 |
+
|
| 26 |
+
subsample_size: Optional[int] = None
|
| 27 |
+
aug: bool = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def convert_subset(set_name: str, config: DataProcessConfig):
|
| 31 |
+
# Read CSV
|
| 32 |
+
all_chars = set()
|
| 33 |
+
grid_size = None
|
| 34 |
+
inputs = []
|
| 35 |
+
labels = []
|
| 36 |
+
|
| 37 |
+
with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore
|
| 38 |
+
reader = csv.reader(csvfile)
|
| 39 |
+
next(reader) # Skip header
|
| 40 |
+
for source, q, a, rating in reader:
|
| 41 |
+
all_chars.update(q)
|
| 42 |
+
all_chars.update(a)
|
| 43 |
+
|
| 44 |
+
if grid_size is None:
|
| 45 |
+
n = int(len(q) ** 0.5)
|
| 46 |
+
grid_size = (n, n)
|
| 47 |
+
|
| 48 |
+
inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
|
| 49 |
+
labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))
|
| 50 |
+
|
| 51 |
+
# If subsample_size is specified for the training set,
|
| 52 |
+
# randomly sample the desired number of examples.
|
| 53 |
+
if set_name == "train" and config.subsample_size is not None:
|
| 54 |
+
total_samples = len(inputs)
|
| 55 |
+
if config.subsample_size < total_samples:
|
| 56 |
+
indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
|
| 57 |
+
inputs = [inputs[i] for i in indices]
|
| 58 |
+
labels = [labels[i] for i in indices]
|
| 59 |
+
|
| 60 |
+
# Generate dataset
|
| 61 |
+
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
|
| 62 |
+
puzzle_id = 0
|
| 63 |
+
example_id = 0
|
| 64 |
+
|
| 65 |
+
results["puzzle_indices"].append(0)
|
| 66 |
+
results["group_indices"].append(0)
|
| 67 |
+
|
| 68 |
+
for inp, out in zip(tqdm(inputs), labels):
|
| 69 |
+
# Dihedral transformations for augmentation
|
| 70 |
+
for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
|
| 71 |
+
results["inputs"].append(dihedral_transform(inp, aug_idx))
|
| 72 |
+
results["labels"].append(dihedral_transform(out, aug_idx))
|
| 73 |
+
example_id += 1
|
| 74 |
+
puzzle_id += 1
|
| 75 |
+
|
| 76 |
+
results["puzzle_indices"].append(example_id)
|
| 77 |
+
results["puzzle_identifiers"].append(0)
|
| 78 |
+
|
| 79 |
+
# Push group
|
| 80 |
+
results["group_indices"].append(puzzle_id)
|
| 81 |
+
|
| 82 |
+
# Char mappings
|
| 83 |
+
assert len(all_chars - set(CHARSET)) == 0
|
| 84 |
+
|
| 85 |
+
char2id = np.zeros(256, np.uint8)
|
| 86 |
+
char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1
|
| 87 |
+
|
| 88 |
+
# To Numpy
|
| 89 |
+
def _seq_to_numpy(seq):
|
| 90 |
+
arr = np.vstack([char2id[s.reshape(-1)] for s in seq])
|
| 91 |
+
|
| 92 |
+
return arr
|
| 93 |
+
|
| 94 |
+
results = {
|
| 95 |
+
"inputs": _seq_to_numpy(results["inputs"]),
|
| 96 |
+
"labels": _seq_to_numpy(results["labels"]),
|
| 97 |
+
|
| 98 |
+
"group_indices": np.array(results["group_indices"], dtype=np.int32),
|
| 99 |
+
"puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
|
| 100 |
+
"puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
# Metadata
|
| 104 |
+
metadata = PuzzleDatasetMetadata(
|
| 105 |
+
seq_len=int(math.prod(grid_size)), # type: ignore
|
| 106 |
+
vocab_size=len(CHARSET) + 1, # PAD + Charset
|
| 107 |
+
|
| 108 |
+
pad_id=0,
|
| 109 |
+
ignore_label_id=0,
|
| 110 |
+
|
| 111 |
+
blank_identifier_id=0,
|
| 112 |
+
num_puzzle_identifiers=1,
|
| 113 |
+
|
| 114 |
+
total_groups=len(results["group_indices"]) - 1,
|
| 115 |
+
mean_puzzle_examples=1,
|
| 116 |
+
sets=["all"]
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Save metadata as JSON.
|
| 120 |
+
save_dir = os.path.join(config.output_dir, set_name)
|
| 121 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 122 |
+
|
| 123 |
+
with open(os.path.join(save_dir, "dataset.json"), "w") as f:
|
| 124 |
+
json.dump(metadata.model_dump(), f)
|
| 125 |
+
|
| 126 |
+
# Save data
|
| 127 |
+
for k, v in results.items():
|
| 128 |
+
np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
|
| 129 |
+
|
| 130 |
+
# Save IDs mapping (for visualization only)
|
| 131 |
+
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
|
| 132 |
+
json.dump(["<blank>"], f)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@cli.command(singleton=True)
|
| 136 |
+
def preprocess_data(config: DataProcessConfig):
|
| 137 |
+
convert_subset("train", config)
|
| 138 |
+
convert_subset("test", config)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
|
| 142 |
+
cli()
|
dataset/build_sudoku_dataset.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import os
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from argdantic import ArgParser
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
|
| 12 |
+
from common import PuzzleDatasetMetadata
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
cli = ArgParser()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DataProcessConfig(BaseModel):
|
| 19 |
+
source_repo: str = "sapientinc/sudoku-extreme"
|
| 20 |
+
output_dir: str = "data/sudoku-extreme-full"
|
| 21 |
+
|
| 22 |
+
subsample_size: Optional[int] = None
|
| 23 |
+
min_difficulty: Optional[int] = None
|
| 24 |
+
num_aug: int = 0
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
|
| 28 |
+
# Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
|
| 29 |
+
digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
|
| 30 |
+
|
| 31 |
+
# Randomly decide whether to transpose.
|
| 32 |
+
transpose_flag = np.random.rand() < 0.5
|
| 33 |
+
|
| 34 |
+
# Generate a valid row permutation:
|
| 35 |
+
# - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
|
| 36 |
+
bands = np.random.permutation(3)
|
| 37 |
+
row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])
|
| 38 |
+
|
| 39 |
+
# Similarly for columns (stacks).
|
| 40 |
+
stacks = np.random.permutation(3)
|
| 41 |
+
col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])
|
| 42 |
+
|
| 43 |
+
# Build an 81->81 mapping. For each new cell at (i, j)
|
| 44 |
+
# (row index = i // 9, col index = i % 9),
|
| 45 |
+
# its value comes from old row = row_perm[i//9] and old col = col_perm[i%9].
|
| 46 |
+
mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])
|
| 47 |
+
|
| 48 |
+
def apply_transformation(x: np.ndarray) -> np.ndarray:
|
| 49 |
+
# Apply transpose flag
|
| 50 |
+
if transpose_flag:
|
| 51 |
+
x = x.T
|
| 52 |
+
# Apply the position mapping.
|
| 53 |
+
new_board = x.flatten()[mapping].reshape(9, 9).copy()
|
| 54 |
+
# Apply digit mapping
|
| 55 |
+
return digit_map[new_board]
|
| 56 |
+
|
| 57 |
+
return apply_transformation(board), apply_transformation(solution)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def convert_subset(set_name: str, config: DataProcessConfig):
|
| 61 |
+
# Read CSV
|
| 62 |
+
inputs = []
|
| 63 |
+
labels = []
|
| 64 |
+
|
| 65 |
+
with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
|
| 66 |
+
reader = csv.reader(csvfile)
|
| 67 |
+
next(reader) # Skip header
|
| 68 |
+
for source, q, a, rating in reader:
|
| 69 |
+
if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
|
| 70 |
+
assert len(q) == 81 and len(a) == 81
|
| 71 |
+
|
| 72 |
+
inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
|
| 73 |
+
labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
|
| 74 |
+
|
| 75 |
+
# If subsample_size is specified for the training set,
|
| 76 |
+
# randomly sample the desired number of examples.
|
| 77 |
+
if set_name == "train" and config.subsample_size is not None:
|
| 78 |
+
total_samples = len(inputs)
|
| 79 |
+
if config.subsample_size < total_samples:
|
| 80 |
+
indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
|
| 81 |
+
inputs = [inputs[i] for i in indices]
|
| 82 |
+
labels = [labels[i] for i in indices]
|
| 83 |
+
|
| 84 |
+
# Generate dataset
|
| 85 |
+
num_augments = config.num_aug if set_name == "train" else 0
|
| 86 |
+
|
| 87 |
+
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
|
| 88 |
+
puzzle_id = 0
|
| 89 |
+
example_id = 0
|
| 90 |
+
|
| 91 |
+
results["puzzle_indices"].append(0)
|
| 92 |
+
results["group_indices"].append(0)
|
| 93 |
+
|
| 94 |
+
for orig_inp, orig_out in zip(tqdm(inputs), labels):
|
| 95 |
+
for aug_idx in range(1 + num_augments):
|
| 96 |
+
# First index is not augmented
|
| 97 |
+
if aug_idx == 0:
|
| 98 |
+
inp, out = orig_inp, orig_out
|
| 99 |
+
else:
|
| 100 |
+
inp, out = shuffle_sudoku(orig_inp, orig_out)
|
| 101 |
+
|
| 102 |
+
# Push puzzle (only single example)
|
| 103 |
+
results["inputs"].append(inp)
|
| 104 |
+
results["labels"].append(out)
|
| 105 |
+
example_id += 1
|
| 106 |
+
puzzle_id += 1
|
| 107 |
+
|
| 108 |
+
results["puzzle_indices"].append(example_id)
|
| 109 |
+
results["puzzle_identifiers"].append(0)
|
| 110 |
+
|
| 111 |
+
# Push group
|
| 112 |
+
results["group_indices"].append(puzzle_id)
|
| 113 |
+
|
| 114 |
+
# To Numpy
|
| 115 |
+
def _seq_to_numpy(seq):
|
| 116 |
+
arr = np.concatenate(seq).reshape(len(seq), -1)
|
| 117 |
+
|
| 118 |
+
assert np.all((arr >= 0) & (arr <= 9))
|
| 119 |
+
return arr + 1
|
| 120 |
+
|
| 121 |
+
results = {
|
| 122 |
+
"inputs": _seq_to_numpy(results["inputs"]),
|
| 123 |
+
"labels": _seq_to_numpy(results["labels"]),
|
| 124 |
+
|
| 125 |
+
"group_indices": np.array(results["group_indices"], dtype=np.int32),
|
| 126 |
+
"puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
|
| 127 |
+
"puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
# Metadata
|
| 131 |
+
metadata = PuzzleDatasetMetadata(
|
| 132 |
+
seq_len=81,
|
| 133 |
+
vocab_size=10 + 1, # PAD + "0" ... "9"
|
| 134 |
+
|
| 135 |
+
pad_id=0,
|
| 136 |
+
ignore_label_id=0,
|
| 137 |
+
|
| 138 |
+
blank_identifier_id=0,
|
| 139 |
+
num_puzzle_identifiers=1,
|
| 140 |
+
|
| 141 |
+
total_groups=len(results["group_indices"]) - 1,
|
| 142 |
+
mean_puzzle_examples=1,
|
| 143 |
+
sets=["all"]
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# Save metadata as JSON.
|
| 147 |
+
save_dir = os.path.join(config.output_dir, set_name)
|
| 148 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 149 |
+
|
| 150 |
+
with open(os.path.join(save_dir, "dataset.json"), "w") as f:
|
| 151 |
+
json.dump(metadata.model_dump(), f)
|
| 152 |
+
|
| 153 |
+
# Save data
|
| 154 |
+
for k, v in results.items():
|
| 155 |
+
np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
|
| 156 |
+
|
| 157 |
+
# Save IDs mapping (for visualization only)
|
| 158 |
+
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
|
| 159 |
+
json.dump(["<blank>"], f)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@cli.command(singleton=True)
|
| 163 |
+
def preprocess_data(config: DataProcessConfig):
|
| 164 |
+
convert_subset("train", config)
|
| 165 |
+
convert_subset("test", config)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
if __name__ == "__main__":
|
| 169 |
+
cli()
|
dataset/common.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
|
| 3 |
+
import pydantic
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Global list mapping each dihedral transform id to its inverse.
|
| 8 |
+
# Index corresponds to the original tid, and the value is its inverse.
|
| 9 |
+
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PuzzleDatasetMetadata(pydantic.BaseModel):
|
| 13 |
+
pad_id: int
|
| 14 |
+
ignore_label_id: Optional[int]
|
| 15 |
+
blank_identifier_id: int
|
| 16 |
+
|
| 17 |
+
vocab_size: int
|
| 18 |
+
seq_len: int
|
| 19 |
+
num_puzzle_identifiers: int
|
| 20 |
+
|
| 21 |
+
total_groups: int
|
| 22 |
+
mean_puzzle_examples: float
|
| 23 |
+
|
| 24 |
+
sets: List[str]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
|
| 28 |
+
"""8 dihedral symmetries by rotate, flip and mirror"""
|
| 29 |
+
|
| 30 |
+
if tid == 0:
|
| 31 |
+
return arr # identity
|
| 32 |
+
elif tid == 1:
|
| 33 |
+
return np.rot90(arr, k=1)
|
| 34 |
+
elif tid == 2:
|
| 35 |
+
return np.rot90(arr, k=2)
|
| 36 |
+
elif tid == 3:
|
| 37 |
+
return np.rot90(arr, k=3)
|
| 38 |
+
elif tid == 4:
|
| 39 |
+
return np.fliplr(arr) # horizontal flip
|
| 40 |
+
elif tid == 5:
|
| 41 |
+
return np.flipud(arr) # vertical flip
|
| 42 |
+
elif tid == 6:
|
| 43 |
+
return arr.T # transpose (reflection along main diagonal)
|
| 44 |
+
elif tid == 7:
|
| 45 |
+
return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection
|
| 46 |
+
else:
|
| 47 |
+
return arr
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
|
| 51 |
+
return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
|
evaluate.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
import yaml
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
|
| 8 |
+
import pydantic
|
| 9 |
+
from omegaconf import OmegaConf
|
| 10 |
+
from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class EvalConfig(pydantic.BaseModel):
|
| 14 |
+
checkpoint: str
|
| 15 |
+
|
| 16 |
+
save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def launch():
|
| 20 |
+
eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore
|
| 21 |
+
|
| 22 |
+
RANK = 0
|
| 23 |
+
WORLD_SIZE = 1
|
| 24 |
+
# Initialize distributed training if in distributed environment (e.g. torchrun)
|
| 25 |
+
if "LOCAL_RANK" in os.environ:
|
| 26 |
+
# Initialize distributed, default device and dtype
|
| 27 |
+
dist.init_process_group(backend="nccl")
|
| 28 |
+
|
| 29 |
+
RANK = dist.get_rank()
|
| 30 |
+
WORLD_SIZE = dist.get_world_size()
|
| 31 |
+
|
| 32 |
+
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
| 33 |
+
|
| 34 |
+
with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f:
|
| 35 |
+
config = PretrainConfig(**yaml.safe_load(f))
|
| 36 |
+
|
| 37 |
+
config.eval_save_outputs = eval_cfg.save_outputs
|
| 38 |
+
config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)
|
| 39 |
+
|
| 40 |
+
# Dataloader
|
| 41 |
+
train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
|
| 42 |
+
eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
|
| 43 |
+
|
| 44 |
+
# Models
|
| 45 |
+
train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
|
| 46 |
+
# Try unwrap torch.compile
|
| 47 |
+
try:
|
| 48 |
+
train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
|
| 49 |
+
except:
|
| 50 |
+
train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)
|
| 51 |
+
|
| 52 |
+
train_state.step = 0
|
| 53 |
+
ckpt_filename = os.path.basename(eval_cfg.checkpoint)
|
| 54 |
+
if ckpt_filename.startswith("step_"):
|
| 55 |
+
train_state.step = int(ckpt_filename.removeprefix("step_"))
|
| 56 |
+
|
| 57 |
+
# Evaluate
|
| 58 |
+
print ("Starting evaluation")
|
| 59 |
+
|
| 60 |
+
train_state.model.eval()
|
| 61 |
+
metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
|
| 62 |
+
|
| 63 |
+
if metrics is not None:
|
| 64 |
+
print (metrics)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
launch()
|
models/common.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
|
| 8 |
+
# NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
|
| 9 |
+
# This function is a PyTorch version of jax truncated normal init (default init method in flax)
|
| 10 |
+
# https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
|
| 11 |
+
# https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
|
| 12 |
+
|
| 13 |
+
with torch.no_grad():
|
| 14 |
+
if std == 0:
|
| 15 |
+
tensor.zero_()
|
| 16 |
+
else:
|
| 17 |
+
sqrt2 = math.sqrt(2)
|
| 18 |
+
a = math.erf(lower / sqrt2)
|
| 19 |
+
b = math.erf(upper / sqrt2)
|
| 20 |
+
z = (b - a) / 2
|
| 21 |
+
|
| 22 |
+
c = (2 * math.pi) ** -0.5
|
| 23 |
+
pdf_u = c * math.exp(-0.5 * lower ** 2)
|
| 24 |
+
pdf_l = c * math.exp(-0.5 * upper ** 2)
|
| 25 |
+
comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
|
| 26 |
+
|
| 27 |
+
tensor.uniform_(a, b)
|
| 28 |
+
tensor.erfinv_()
|
| 29 |
+
tensor.mul_(sqrt2 * comp_std)
|
| 30 |
+
tensor.clip_(lower * comp_std, upper * comp_std)
|
| 31 |
+
|
| 32 |
+
return tensor
|
models/hrm/hrm_act_v1.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, Dict, Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
from models.common import trunc_normal_init_
|
| 11 |
+
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
| 12 |
+
from models.sparse_embedding import CastedSparseEmbedding
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class HierarchicalReasoningModel_ACTV1InnerCarry:
|
| 17 |
+
z_H: torch.Tensor
|
| 18 |
+
z_L: torch.Tensor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class HierarchicalReasoningModel_ACTV1Carry:
|
| 23 |
+
inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
|
| 24 |
+
|
| 25 |
+
steps: torch.Tensor
|
| 26 |
+
halted: torch.Tensor
|
| 27 |
+
|
| 28 |
+
current_data: Dict[str, torch.Tensor]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
|
| 32 |
+
batch_size: int
|
| 33 |
+
seq_len: int
|
| 34 |
+
puzzle_emb_ndim: int = 0
|
| 35 |
+
num_puzzle_identifiers: int
|
| 36 |
+
vocab_size: int
|
| 37 |
+
|
| 38 |
+
H_cycles: int
|
| 39 |
+
L_cycles: int
|
| 40 |
+
|
| 41 |
+
H_layers: int
|
| 42 |
+
L_layers: int
|
| 43 |
+
|
| 44 |
+
# Transformer config
|
| 45 |
+
hidden_size: int
|
| 46 |
+
expansion: float
|
| 47 |
+
num_heads: int
|
| 48 |
+
pos_encodings: str
|
| 49 |
+
|
| 50 |
+
rms_norm_eps: float = 1e-5
|
| 51 |
+
rope_theta: float = 10000.0
|
| 52 |
+
|
| 53 |
+
# Halting Q-learning config
|
| 54 |
+
halt_max_steps: int
|
| 55 |
+
halt_exploration_prob: float
|
| 56 |
+
|
| 57 |
+
forward_dtype: str = "bfloat16"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class HierarchicalReasoningModel_ACTV1Block(nn.Module):
|
| 61 |
+
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
|
| 62 |
+
super().__init__()
|
| 63 |
+
|
| 64 |
+
self.self_attn = Attention(
|
| 65 |
+
hidden_size=config.hidden_size,
|
| 66 |
+
head_dim=config.hidden_size // config.num_heads,
|
| 67 |
+
num_heads=config.num_heads,
|
| 68 |
+
num_key_value_heads=config.num_heads,
|
| 69 |
+
causal=False
|
| 70 |
+
)
|
| 71 |
+
self.mlp = SwiGLU(
|
| 72 |
+
hidden_size=config.hidden_size,
|
| 73 |
+
expansion=config.expansion,
|
| 74 |
+
)
|
| 75 |
+
self.norm_eps = config.rms_norm_eps
|
| 76 |
+
|
| 77 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
# Post Norm
|
| 79 |
+
# Self Attention
|
| 80 |
+
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
| 81 |
+
# Fully Connected
|
| 82 |
+
hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
|
| 83 |
+
return hidden_states
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
|
| 87 |
+
def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
|
| 88 |
+
super().__init__()
|
| 89 |
+
|
| 90 |
+
self.layers = torch.nn.ModuleList(layers)
|
| 91 |
+
|
| 92 |
+
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 93 |
+
# Input injection (add)
|
| 94 |
+
hidden_states = hidden_states + input_injection
|
| 95 |
+
# Layers
|
| 96 |
+
for layer in self.layers:
|
| 97 |
+
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
| 98 |
+
|
| 99 |
+
return hidden_states
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
|
| 103 |
+
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
|
| 104 |
+
super().__init__()
|
| 105 |
+
self.config = config
|
| 106 |
+
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
| 107 |
+
|
| 108 |
+
# I/O
|
| 109 |
+
self.embed_scale = math.sqrt(self.config.hidden_size)
|
| 110 |
+
embed_init_std = 1.0 / self.embed_scale
|
| 111 |
+
|
| 112 |
+
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 113 |
+
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
| 114 |
+
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
| 115 |
+
|
| 116 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
|
| 117 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 118 |
+
# Zero init puzzle embeddings
|
| 119 |
+
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
| 120 |
+
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
| 121 |
+
|
| 122 |
+
# LM Blocks
|
| 123 |
+
if self.config.pos_encodings == "rope":
|
| 124 |
+
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
| 125 |
+
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
| 126 |
+
base=self.config.rope_theta)
|
| 127 |
+
elif self.config.pos_encodings == "learned":
|
| 128 |
+
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 129 |
+
else:
|
| 130 |
+
raise NotImplementedError()
|
| 131 |
+
|
| 132 |
+
# Reasoning Layers
|
| 133 |
+
self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
|
| 134 |
+
self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
| 135 |
+
|
| 136 |
+
# Initial states
|
| 137 |
+
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 138 |
+
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 139 |
+
|
| 140 |
+
# Q head special init
|
| 141 |
+
# Init Q to (almost) zero for faster learning during bootstrapping
|
| 142 |
+
with torch.no_grad():
|
| 143 |
+
self.q_head.weight.zero_()
|
| 144 |
+
self.q_head.bias.fill_(-5) # type: ignore
|
| 145 |
+
|
| 146 |
+
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
| 147 |
+
# Token embedding
|
| 148 |
+
embedding = self.embed_tokens(input.to(torch.int32))
|
| 149 |
+
|
| 150 |
+
# Puzzle embeddings
|
| 151 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 152 |
+
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
| 153 |
+
|
| 154 |
+
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
| 155 |
+
if pad_count > 0:
|
| 156 |
+
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
| 157 |
+
|
| 158 |
+
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
| 159 |
+
|
| 160 |
+
# Position embeddings
|
| 161 |
+
if self.config.pos_encodings == "learned":
|
| 162 |
+
# scale by 1/sqrt(2) to maintain forward variance
|
| 163 |
+
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
| 164 |
+
|
| 165 |
+
# Scale
|
| 166 |
+
return self.embed_scale * embedding
|
| 167 |
+
|
| 168 |
+
def empty_carry(self, batch_size: int):
|
| 169 |
+
return HierarchicalReasoningModel_ACTV1InnerCarry(
|
| 170 |
+
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 171 |
+
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
|
| 175 |
+
return HierarchicalReasoningModel_ACTV1InnerCarry(
|
| 176 |
+
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
| 177 |
+
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 181 |
+
seq_info = dict(
|
| 182 |
+
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# Input encoding
|
| 186 |
+
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
| 187 |
+
|
| 188 |
+
# Forward iterations
|
| 189 |
+
with torch.no_grad():
|
| 190 |
+
z_H, z_L = carry.z_H, carry.z_L
|
| 191 |
+
|
| 192 |
+
for _H_step in range(self.config.H_cycles):
|
| 193 |
+
for _L_step in range(self.config.L_cycles):
|
| 194 |
+
if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
|
| 195 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 196 |
+
|
| 197 |
+
if not (_H_step == self.config.H_cycles - 1):
|
| 198 |
+
z_H = self.H_level(z_H, z_L, **seq_info)
|
| 199 |
+
|
| 200 |
+
assert not z_H.requires_grad and not z_L.requires_grad
|
| 201 |
+
|
| 202 |
+
# 1-step grad
|
| 203 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 204 |
+
z_H = self.H_level(z_H, z_L, **seq_info)
|
| 205 |
+
|
| 206 |
+
# LM Outputs
|
| 207 |
+
new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
|
| 208 |
+
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
|
| 209 |
+
|
| 210 |
+
# Q head
|
| 211 |
+
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
|
| 212 |
+
|
| 213 |
+
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class HierarchicalReasoningModel_ACTV1(nn.Module):
|
| 217 |
+
"""ACT wrapper."""
|
| 218 |
+
|
| 219 |
+
def __init__(self, config_dict: dict):
|
| 220 |
+
super().__init__()
|
| 221 |
+
self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
|
| 222 |
+
self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
|
| 223 |
+
|
| 224 |
+
@property
|
| 225 |
+
def puzzle_emb(self):
|
| 226 |
+
return self.inner.puzzle_emb
|
| 227 |
+
|
| 228 |
+
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
| 229 |
+
batch_size = batch["inputs"].shape[0]
|
| 230 |
+
|
| 231 |
+
return HierarchicalReasoningModel_ACTV1Carry(
|
| 232 |
+
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
|
| 233 |
+
|
| 234 |
+
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
| 235 |
+
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
| 236 |
+
|
| 237 |
+
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
| 241 |
+
# Update data, carry (removing halted sequences)
|
| 242 |
+
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
| 243 |
+
|
| 244 |
+
new_steps = torch.where(carry.halted, 0, carry.steps)
|
| 245 |
+
|
| 246 |
+
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
| 247 |
+
|
| 248 |
+
# Forward inner model
|
| 249 |
+
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
| 250 |
+
|
| 251 |
+
outputs = {
|
| 252 |
+
"logits": logits,
|
| 253 |
+
"q_halt_logits": q_halt_logits,
|
| 254 |
+
"q_continue_logits": q_continue_logits
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
with torch.no_grad():
|
| 258 |
+
# Step
|
| 259 |
+
new_steps = new_steps + 1
|
| 260 |
+
is_last_step = new_steps >= self.config.halt_max_steps
|
| 261 |
+
|
| 262 |
+
halted = is_last_step
|
| 263 |
+
|
| 264 |
+
# if training, and ACT is enabled
|
| 265 |
+
if self.training and (self.config.halt_max_steps > 1):
|
| 266 |
+
# Halt signal
|
| 267 |
+
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
|
| 268 |
+
halted = halted | (q_halt_logits > q_continue_logits)
|
| 269 |
+
|
| 270 |
+
# Exploration
|
| 271 |
+
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
| 272 |
+
|
| 273 |
+
halted = halted & (new_steps >= min_halt_steps)
|
| 274 |
+
|
| 275 |
+
# Compute target Q
|
| 276 |
+
# NOTE: No replay buffer and target networks for computing target Q-value.
|
| 277 |
+
# As batch_size is large, there're many parallel envs.
|
| 278 |
+
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
| 279 |
+
next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
|
| 280 |
+
|
| 281 |
+
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
| 282 |
+
|
| 283 |
+
return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
models/layers.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from flash_attn_interface import flash_attn_func # type: ignore[import]
|
| 9 |
+
except ImportError:
|
| 10 |
+
# Fallback to FlashAttention 2
|
| 11 |
+
from flash_attn import flash_attn_func # type: ignore[import]
|
| 12 |
+
|
| 13 |
+
from models.common import trunc_normal_init_
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
CosSin = Tuple[torch.Tensor, torch.Tensor]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _find_multiple(a, b):
|
| 20 |
+
return (-(a // -b)) * b
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def rotate_half(x: torch.Tensor):
|
| 24 |
+
"""Rotates half the hidden dims of the input."""
|
| 25 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 26 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 27 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
| 31 |
+
# q, k: [bs, seq_len, num_heads, head_dim]
|
| 32 |
+
# cos, sin: [seq_len, head_dim]
|
| 33 |
+
orig_dtype = q.dtype
|
| 34 |
+
q = q.to(cos.dtype)
|
| 35 |
+
k = k.to(cos.dtype)
|
| 36 |
+
|
| 37 |
+
q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
|
| 38 |
+
k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
|
| 39 |
+
|
| 40 |
+
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class CastedLinear(nn.Module):
|
| 44 |
+
def __init__(self,
|
| 45 |
+
in_features: int,
|
| 46 |
+
out_features: int,
|
| 47 |
+
bias: bool):
|
| 48 |
+
super().__init__()
|
| 49 |
+
# Truncated LeCun normal init
|
| 50 |
+
self.weight = nn.Parameter(
|
| 51 |
+
trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
|
| 52 |
+
)
|
| 53 |
+
self.bias = None
|
| 54 |
+
if bias:
|
| 55 |
+
# Zero init bias
|
| 56 |
+
self.bias = nn.Parameter(torch.zeros((out_features, )))
|
| 57 |
+
|
| 58 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 59 |
+
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class CastedEmbedding(nn.Module):
|
| 63 |
+
def __init__(self,
|
| 64 |
+
num_embeddings: int,
|
| 65 |
+
embedding_dim: int,
|
| 66 |
+
init_std: float,
|
| 67 |
+
cast_to: torch.dtype):
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.cast_to = cast_to
|
| 70 |
+
|
| 71 |
+
# Truncated LeCun normal init
|
| 72 |
+
self.embedding_weight = nn.Parameter(
|
| 73 |
+
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 77 |
+
return F.embedding(input, self.embedding_weight.to(self.cast_to))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class RotaryEmbedding(nn.Module):
|
| 81 |
+
def __init__(self, dim, max_position_embeddings, base, device=None):
|
| 82 |
+
super().__init__()
|
| 83 |
+
|
| 84 |
+
# RoPE
|
| 85 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
| 86 |
+
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
|
| 87 |
+
freqs = torch.outer(t, inv_freq)
|
| 88 |
+
|
| 89 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 90 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 91 |
+
self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
|
| 92 |
+
self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
|
| 93 |
+
|
| 94 |
+
def forward(self):
|
| 95 |
+
return self.cos_cached, self.sin_cached
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Attention(nn.Module):
|
| 99 |
+
def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
|
| 100 |
+
super().__init__()
|
| 101 |
+
|
| 102 |
+
self.hidden_size = hidden_size
|
| 103 |
+
self.head_dim = head_dim
|
| 104 |
+
self.output_size = head_dim * num_heads
|
| 105 |
+
self.num_heads = num_heads
|
| 106 |
+
self.num_key_value_heads = num_key_value_heads
|
| 107 |
+
self.causal = causal
|
| 108 |
+
|
| 109 |
+
self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
|
| 110 |
+
self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
|
| 111 |
+
|
| 112 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 113 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 114 |
+
|
| 115 |
+
# hidden_states: [bs, seq_len, num_heads, head_dim]
|
| 116 |
+
qkv = self.qkv_proj(hidden_states)
|
| 117 |
+
|
| 118 |
+
# Split head
|
| 119 |
+
qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
|
| 120 |
+
query = qkv[:, :, :self.num_heads]
|
| 121 |
+
key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
|
| 122 |
+
value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
|
| 123 |
+
|
| 124 |
+
# RoPE
|
| 125 |
+
if cos_sin is not None:
|
| 126 |
+
cos, sin = cos_sin
|
| 127 |
+
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
| 128 |
+
|
| 129 |
+
# flash attn
|
| 130 |
+
attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
|
| 131 |
+
if isinstance(attn_output, tuple): # fa2 and fa3 compatibility
|
| 132 |
+
attn_output = attn_output[0]
|
| 133 |
+
|
| 134 |
+
# attn_output: [batch_size, num_heads, seq_len, head_dim]
|
| 135 |
+
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
|
| 136 |
+
return self.o_proj(attn_output)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class SwiGLU(nn.Module):
|
| 140 |
+
def __init__(self, hidden_size: int, expansion: float):
|
| 141 |
+
super().__init__()
|
| 142 |
+
inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
|
| 143 |
+
|
| 144 |
+
self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
|
| 145 |
+
self.down_proj = CastedLinear(inter, hidden_size, bias=False)
|
| 146 |
+
|
| 147 |
+
def forward(self, x):
|
| 148 |
+
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
|
| 149 |
+
return self.down_proj(F.silu(gate) * up)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
|
| 153 |
+
input_dtype = hidden_states.dtype
|
| 154 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 155 |
+
|
| 156 |
+
variance = hidden_states.square().mean(-1, keepdim=True)
|
| 157 |
+
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
|
| 158 |
+
return hidden_states.to(input_dtype)
|
models/losses.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Tuple, Dict, Sequence, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
IGNORE_LABEL_ID = -100
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def s(x, epsilon=1e-30):
|
| 12 |
+
return torch.where(
|
| 13 |
+
x<0,
|
| 14 |
+
1/(1-x+ epsilon),
|
| 15 |
+
x + 1
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def log_stablemax(x, dim=-1):
|
| 20 |
+
s_x = s(x)
|
| 21 |
+
return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
|
| 25 |
+
logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
|
| 26 |
+
|
| 27 |
+
valid_mask = labels != ignore_index
|
| 28 |
+
transformed_labels = torch.where(valid_mask, labels, 0)
|
| 29 |
+
prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
|
| 30 |
+
|
| 31 |
+
return -torch.where(valid_mask, prediction_logprobs, 0)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
|
| 35 |
+
# Cast logits to f32
|
| 36 |
+
# Flatten logits
|
| 37 |
+
return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ACTLossHead(nn.Module):
|
| 41 |
+
def __init__(self, model: nn.Module, loss_type: str):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.model = model
|
| 44 |
+
self.loss_fn = globals()[loss_type]
|
| 45 |
+
|
| 46 |
+
def initial_carry(self, *args, **kwargs):
|
| 47 |
+
return self.model.initial_carry(*args, **kwargs) # type: ignore
|
| 48 |
+
|
| 49 |
+
def forward(
|
| 50 |
+
self,
|
| 51 |
+
return_keys: Sequence[str],
|
| 52 |
+
# Model args
|
| 53 |
+
**model_kwargs,
|
| 54 |
+
) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
|
| 55 |
+
# Model logits
|
| 56 |
+
# B x SeqLen x D
|
| 57 |
+
new_carry, outputs = self.model(**model_kwargs)
|
| 58 |
+
labels = new_carry.current_data["labels"]
|
| 59 |
+
|
| 60 |
+
# Correctness
|
| 61 |
+
with torch.no_grad():
|
| 62 |
+
mask = labels != IGNORE_LABEL_ID
|
| 63 |
+
loss_counts = mask.sum(-1)
|
| 64 |
+
loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
|
| 65 |
+
|
| 66 |
+
is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
|
| 67 |
+
seq_is_correct = is_correct.sum(-1) == loss_counts
|
| 68 |
+
|
| 69 |
+
# Metrics (halted)
|
| 70 |
+
valid_metrics = new_carry.halted & (loss_counts > 0)
|
| 71 |
+
metrics = {
|
| 72 |
+
"count": valid_metrics.sum(),
|
| 73 |
+
|
| 74 |
+
"accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
|
| 75 |
+
"exact_accuracy": (valid_metrics & seq_is_correct).sum(),
|
| 76 |
+
|
| 77 |
+
"q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
|
| 78 |
+
"steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
# Losses
|
| 82 |
+
# FIXME: Assuming the batch is always full
|
| 83 |
+
lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
|
| 84 |
+
q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
|
| 85 |
+
|
| 86 |
+
metrics.update({
|
| 87 |
+
"lm_loss": lm_loss.detach(),
|
| 88 |
+
"q_halt_loss": q_halt_loss.detach(),
|
| 89 |
+
})
|
| 90 |
+
|
| 91 |
+
# Q continue (bootstrapping target loss)
|
| 92 |
+
q_continue_loss = 0
|
| 93 |
+
if "target_q_continue" in outputs:
|
| 94 |
+
q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
|
| 95 |
+
|
| 96 |
+
metrics["q_continue_loss"] = q_continue_loss.detach()
|
| 97 |
+
|
| 98 |
+
# Filter outputs for return
|
| 99 |
+
detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
|
| 100 |
+
|
| 101 |
+
return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
|
models/sparse_embedding.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
from torch.optim.optimizer import Optimizer, ParamsT
|
| 7 |
+
|
| 8 |
+
from models.common import trunc_normal_init_
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CastedSparseEmbedding(nn.Module):
|
| 12 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.cast_to = cast_to
|
| 15 |
+
|
| 16 |
+
# Real Weights
|
| 17 |
+
# Truncated LeCun normal init
|
| 18 |
+
self.weights = nn.Buffer(
|
| 19 |
+
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Local weights and IDs
|
| 23 |
+
# Local embeddings, with gradient, not persistent
|
| 24 |
+
self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
|
| 25 |
+
# Local embedding IDs, not persistent
|
| 26 |
+
self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
|
| 27 |
+
|
| 28 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
| 29 |
+
if not self.training:
|
| 30 |
+
# Test mode, no gradient
|
| 31 |
+
return self.weights[inputs].to(self.cast_to)
|
| 32 |
+
|
| 33 |
+
# Training mode, fill puzzle embedding from weights
|
| 34 |
+
with torch.no_grad():
|
| 35 |
+
self.local_weights.copy_(self.weights[inputs])
|
| 36 |
+
self.local_ids.copy_(inputs)
|
| 37 |
+
|
| 38 |
+
return self.local_weights.to(self.cast_to)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
params: ParamsT,
|
| 45 |
+
|
| 46 |
+
world_size: int,
|
| 47 |
+
lr: Union[float, torch.Tensor] = 1e-3,
|
| 48 |
+
weight_decay: float = 1e-2,
|
| 49 |
+
):
|
| 50 |
+
if not 0.0 <= lr:
|
| 51 |
+
raise ValueError(f"Invalid learning rate: {lr}")
|
| 52 |
+
if not 0.0 <= weight_decay:
|
| 53 |
+
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
| 54 |
+
|
| 55 |
+
defaults = dict(
|
| 56 |
+
lr=lr,
|
| 57 |
+
weight_decay=weight_decay,
|
| 58 |
+
world_size=world_size
|
| 59 |
+
)
|
| 60 |
+
super().__init__(params, defaults)
|
| 61 |
+
|
| 62 |
+
@torch.no_grad
|
| 63 |
+
def step(self, closure=None): # type: ignore
|
| 64 |
+
for group in self.param_groups:
|
| 65 |
+
# Find the sparse embedding weights
|
| 66 |
+
local_weights_grad = None
|
| 67 |
+
local_ids = None
|
| 68 |
+
weights = None
|
| 69 |
+
|
| 70 |
+
assert len(group["params"]) == 3
|
| 71 |
+
for p in group["params"]:
|
| 72 |
+
if p.requires_grad:
|
| 73 |
+
local_weights_grad = p.grad
|
| 74 |
+
elif p.ndim == 1:
|
| 75 |
+
local_ids = p
|
| 76 |
+
elif p.ndim == 2:
|
| 77 |
+
weights = p
|
| 78 |
+
else:
|
| 79 |
+
assert False
|
| 80 |
+
|
| 81 |
+
assert local_weights_grad is not None
|
| 82 |
+
assert local_ids is not None
|
| 83 |
+
assert weights is not None
|
| 84 |
+
|
| 85 |
+
# Apply SignSGD
|
| 86 |
+
# Adam ≈ SignSGD if gradient is very sparse
|
| 87 |
+
_sparse_emb_signsgd_dist(
|
| 88 |
+
local_weights_grad,
|
| 89 |
+
local_ids,
|
| 90 |
+
weights,
|
| 91 |
+
|
| 92 |
+
lr=group["lr"],
|
| 93 |
+
weight_decay=group["weight_decay"],
|
| 94 |
+
world_size=group["world_size"]
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _sparse_emb_signsgd_dist(
|
| 99 |
+
local_weights_grad: torch.Tensor,
|
| 100 |
+
local_ids: torch.Tensor,
|
| 101 |
+
weights: torch.Tensor,
|
| 102 |
+
|
| 103 |
+
lr: float,
|
| 104 |
+
weight_decay: float,
|
| 105 |
+
world_size: int
|
| 106 |
+
) -> None:
|
| 107 |
+
N, D = local_weights_grad.shape
|
| 108 |
+
|
| 109 |
+
# All-gather
|
| 110 |
+
all_weights_grad = local_weights_grad
|
| 111 |
+
all_ids = local_ids
|
| 112 |
+
|
| 113 |
+
if world_size > 1:
|
| 114 |
+
all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
|
| 115 |
+
all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
|
| 116 |
+
|
| 117 |
+
dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
|
| 118 |
+
dist.all_gather_into_tensor(all_ids, local_ids)
|
| 119 |
+
|
| 120 |
+
# Unique
|
| 121 |
+
grad_ids, inv = all_ids.unique(return_inverse=True)
|
| 122 |
+
|
| 123 |
+
grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
|
| 124 |
+
grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
|
| 125 |
+
|
| 126 |
+
# SignSGD with decoupled weight decay
|
| 127 |
+
p = weights[grad_ids]
|
| 128 |
+
|
| 129 |
+
p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
|
| 130 |
+
|
| 131 |
+
# Write updated slices back
|
| 132 |
+
weights[grad_ids] = p
|
pretrain.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Any, Sequence, List
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import os
|
| 4 |
+
import math
|
| 5 |
+
import yaml
|
| 6 |
+
import shutil
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
|
| 13 |
+
import tqdm
|
| 14 |
+
import wandb
|
| 15 |
+
import coolname
|
| 16 |
+
import hydra
|
| 17 |
+
import pydantic
|
| 18 |
+
from omegaconf import DictConfig
|
| 19 |
+
from adam_atan2 import AdamATan2
|
| 20 |
+
|
| 21 |
+
from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
|
| 22 |
+
from utils.functions import load_model_class, get_model_source_path
|
| 23 |
+
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class LossConfig(pydantic.BaseModel):
|
| 27 |
+
model_config = pydantic.ConfigDict(extra='allow')
|
| 28 |
+
|
| 29 |
+
name: str
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ArchConfig(pydantic.BaseModel):
|
| 33 |
+
model_config = pydantic.ConfigDict(extra='allow')
|
| 34 |
+
|
| 35 |
+
name: str
|
| 36 |
+
loss: LossConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class PretrainConfig(pydantic.BaseModel):
|
| 40 |
+
# Config
|
| 41 |
+
arch: ArchConfig
|
| 42 |
+
# Data
|
| 43 |
+
data_path: str
|
| 44 |
+
|
| 45 |
+
# Hyperparams
|
| 46 |
+
global_batch_size: int
|
| 47 |
+
epochs: int
|
| 48 |
+
|
| 49 |
+
lr: float
|
| 50 |
+
lr_min_ratio: float
|
| 51 |
+
lr_warmup_steps: int
|
| 52 |
+
|
| 53 |
+
weight_decay: float
|
| 54 |
+
beta1: float
|
| 55 |
+
beta2: float
|
| 56 |
+
|
| 57 |
+
# Puzzle embedding
|
| 58 |
+
puzzle_emb_lr: float
|
| 59 |
+
puzzle_emb_weight_decay: float
|
| 60 |
+
|
| 61 |
+
# Names
|
| 62 |
+
project_name: Optional[str] = None
|
| 63 |
+
run_name: Optional[str] = None
|
| 64 |
+
checkpoint_path: Optional[str] = None
|
| 65 |
+
|
| 66 |
+
# Extras
|
| 67 |
+
seed: int = 0
|
| 68 |
+
checkpoint_every_eval: bool = False
|
| 69 |
+
eval_interval: Optional[int] = None
|
| 70 |
+
eval_save_outputs: List[str] = []
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
|
| 74 |
+
class TrainState:
|
| 75 |
+
model: nn.Module
|
| 76 |
+
optimizers: Sequence[torch.optim.Optimizer]
|
| 77 |
+
optimizer_lrs: Sequence[float]
|
| 78 |
+
carry: Any
|
| 79 |
+
|
| 80 |
+
step: int
|
| 81 |
+
total_steps: int
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
|
| 85 |
+
dataset = PuzzleDataset(PuzzleDatasetConfig(
|
| 86 |
+
seed=config.seed,
|
| 87 |
+
|
| 88 |
+
dataset_path=config.data_path,
|
| 89 |
+
|
| 90 |
+
rank=rank,
|
| 91 |
+
num_replicas=world_size,
|
| 92 |
+
|
| 93 |
+
**kwargs
|
| 94 |
+
), split=split)
|
| 95 |
+
dataloader = DataLoader(
|
| 96 |
+
dataset,
|
| 97 |
+
batch_size=None,
|
| 98 |
+
|
| 99 |
+
num_workers=1,
|
| 100 |
+
prefetch_factor=8,
|
| 101 |
+
|
| 102 |
+
pin_memory=True,
|
| 103 |
+
persistent_workers=True
|
| 104 |
+
)
|
| 105 |
+
return dataloader, dataset.metadata
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
|
| 109 |
+
model_cfg = dict(
|
| 110 |
+
**config.arch.__pydantic_extra__, # type: ignore
|
| 111 |
+
|
| 112 |
+
batch_size=config.global_batch_size // world_size,
|
| 113 |
+
|
| 114 |
+
vocab_size=train_metadata.vocab_size,
|
| 115 |
+
seq_len=train_metadata.seq_len,
|
| 116 |
+
num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
|
| 117 |
+
causal=False # Non-autoregressive
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Instantiate model with loss head
|
| 121 |
+
model_cls = load_model_class(config.arch.name)
|
| 122 |
+
loss_head_cls = load_model_class(config.arch.loss.name)
|
| 123 |
+
|
| 124 |
+
with torch.device("cuda"):
|
| 125 |
+
model: nn.Module = model_cls(model_cfg)
|
| 126 |
+
model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore
|
| 127 |
+
if "DISABLE_COMPILE" not in os.environ:
|
| 128 |
+
model = torch.compile(model, dynamic=False) # type: ignore
|
| 129 |
+
|
| 130 |
+
# Broadcast parameters from rank 0
|
| 131 |
+
if world_size > 1:
|
| 132 |
+
with torch.no_grad():
|
| 133 |
+
for param in list(model.parameters()) + list(model.buffers()):
|
| 134 |
+
dist.broadcast(param, src=0)
|
| 135 |
+
|
| 136 |
+
# Optimizers and lr
|
| 137 |
+
optimizers = [
|
| 138 |
+
CastedSparseEmbeddingSignSGD_Distributed(
|
| 139 |
+
model.model.puzzle_emb.buffers(), # type: ignore
|
| 140 |
+
|
| 141 |
+
lr=0, # Needs to be set by scheduler
|
| 142 |
+
weight_decay=config.puzzle_emb_weight_decay,
|
| 143 |
+
|
| 144 |
+
world_size=world_size
|
| 145 |
+
),
|
| 146 |
+
AdamATan2(
|
| 147 |
+
model.parameters(),
|
| 148 |
+
|
| 149 |
+
lr=0, # Needs to be set by scheduler
|
| 150 |
+
weight_decay=config.weight_decay,
|
| 151 |
+
betas=(config.beta1, config.beta2)
|
| 152 |
+
)
|
| 153 |
+
]
|
| 154 |
+
optimizer_lrs = [
|
| 155 |
+
config.puzzle_emb_lr,
|
| 156 |
+
config.lr
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
return model, optimizers, optimizer_lrs
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def cosine_schedule_with_warmup_lr_lambda(
|
| 163 |
+
current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
|
| 164 |
+
):
|
| 165 |
+
if current_step < num_warmup_steps:
|
| 166 |
+
return base_lr * float(current_step) / float(max(1, num_warmup_steps))
|
| 167 |
+
|
| 168 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
| 169 |
+
return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
|
| 173 |
+
# Estimated total training steps
|
| 174 |
+
total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)
|
| 175 |
+
|
| 176 |
+
# Model
|
| 177 |
+
model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)
|
| 178 |
+
|
| 179 |
+
return TrainState(
|
| 180 |
+
step=0,
|
| 181 |
+
total_steps=total_steps,
|
| 182 |
+
|
| 183 |
+
model=model,
|
| 184 |
+
optimizers=optimizers,
|
| 185 |
+
optimizer_lrs=optimizer_lrs,
|
| 186 |
+
carry=None
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def save_train_state(config: PretrainConfig, train_state: TrainState):
|
| 191 |
+
# FIXME: Only saved model.
|
| 192 |
+
if config.checkpoint_path is None:
|
| 193 |
+
return
|
| 194 |
+
|
| 195 |
+
os.makedirs(config.checkpoint_path, exist_ok=True)
|
| 196 |
+
torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
|
| 200 |
+
return cosine_schedule_with_warmup_lr_lambda(
|
| 201 |
+
current_step=train_state.step,
|
| 202 |
+
base_lr=base_lr,
|
| 203 |
+
num_warmup_steps=round(config.lr_warmup_steps),
|
| 204 |
+
num_training_steps=train_state.total_steps,
|
| 205 |
+
min_ratio=config.lr_min_ratio
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
|
| 210 |
+
train_state.step += 1
|
| 211 |
+
if train_state.step > train_state.total_steps: # At most train_total_steps
|
| 212 |
+
return
|
| 213 |
+
|
| 214 |
+
# To device
|
| 215 |
+
batch = {k: v.cuda() for k, v in batch.items()}
|
| 216 |
+
|
| 217 |
+
# Init carry if it is None
|
| 218 |
+
if train_state.carry is None:
|
| 219 |
+
with torch.device("cuda"):
|
| 220 |
+
train_state.carry = train_state.model.initial_carry(batch) # type: ignore
|
| 221 |
+
|
| 222 |
+
# Forward
|
| 223 |
+
train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
|
| 224 |
+
|
| 225 |
+
((1 / global_batch_size) * loss).backward()
|
| 226 |
+
|
| 227 |
+
# Allreduce
|
| 228 |
+
if world_size > 1:
|
| 229 |
+
for param in train_state.model.parameters():
|
| 230 |
+
if param.grad is not None:
|
| 231 |
+
dist.all_reduce(param.grad)
|
| 232 |
+
|
| 233 |
+
# Apply optimizer
|
| 234 |
+
lr_this_step = None
|
| 235 |
+
for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
|
| 236 |
+
lr_this_step = compute_lr(base_lr, config, train_state)
|
| 237 |
+
|
| 238 |
+
for param_group in optim.param_groups:
|
| 239 |
+
param_group['lr'] = lr_this_step
|
| 240 |
+
|
| 241 |
+
optim.step()
|
| 242 |
+
optim.zero_grad()
|
| 243 |
+
|
| 244 |
+
# Reduce metrics
|
| 245 |
+
if len(metrics):
|
| 246 |
+
assert not any(v.requires_grad for v in metrics.values())
|
| 247 |
+
|
| 248 |
+
metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
|
| 249 |
+
# Reduce and reconstruct
|
| 250 |
+
metric_values = torch.stack([metrics[k] for k in metric_keys])
|
| 251 |
+
if world_size > 1:
|
| 252 |
+
dist.reduce(metric_values, dst=0)
|
| 253 |
+
|
| 254 |
+
if rank == 0:
|
| 255 |
+
metric_values = metric_values.cpu().numpy()
|
| 256 |
+
reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}
|
| 257 |
+
|
| 258 |
+
# Postprocess
|
| 259 |
+
count = max(reduced_metrics["count"], 1) # Avoid NaNs
|
| 260 |
+
reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}
|
| 261 |
+
|
| 262 |
+
reduced_metrics["train/lr"] = lr_this_step
|
| 263 |
+
return reduced_metrics
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
|
| 267 |
+
with torch.inference_mode():
|
| 268 |
+
set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
|
| 269 |
+
|
| 270 |
+
all_preds = {}
|
| 271 |
+
|
| 272 |
+
metric_keys = []
|
| 273 |
+
metric_values = None
|
| 274 |
+
metric_global_batch_size = [0 for _ in range(len(set_ids))]
|
| 275 |
+
|
| 276 |
+
carry = None
|
| 277 |
+
for set_name, batch, global_batch_size in eval_loader:
|
| 278 |
+
# To device
|
| 279 |
+
batch = {k: v.cuda() for k, v in batch.items()}
|
| 280 |
+
with torch.device("cuda"):
|
| 281 |
+
carry = train_state.model.initial_carry(batch) # type: ignore
|
| 282 |
+
|
| 283 |
+
# Forward
|
| 284 |
+
while True:
|
| 285 |
+
carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)
|
| 286 |
+
|
| 287 |
+
if all_finish:
|
| 288 |
+
break
|
| 289 |
+
|
| 290 |
+
for collection in (batch, preds):
|
| 291 |
+
for k, v in collection.items():
|
| 292 |
+
if k in config.eval_save_outputs:
|
| 293 |
+
all_preds.setdefault(k, [])
|
| 294 |
+
all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory
|
| 295 |
+
|
| 296 |
+
del carry, preds, batch, all_finish
|
| 297 |
+
|
| 298 |
+
# Aggregate
|
| 299 |
+
set_id = set_ids[set_name]
|
| 300 |
+
|
| 301 |
+
if metric_values is None:
|
| 302 |
+
metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
|
| 303 |
+
metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda")
|
| 304 |
+
|
| 305 |
+
metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
|
| 306 |
+
metric_global_batch_size[set_id] += global_batch_size
|
| 307 |
+
|
| 308 |
+
if len(all_preds) and config.checkpoint_path is not None:
|
| 309 |
+
all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}
|
| 310 |
+
|
| 311 |
+
os.makedirs(config.checkpoint_path, exist_ok=True)
|
| 312 |
+
torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}"))
|
| 313 |
+
|
| 314 |
+
# Logging
|
| 315 |
+
# Reduce to rank 0
|
| 316 |
+
if metric_values is not None:
|
| 317 |
+
if world_size > 1:
|
| 318 |
+
dist.reduce(metric_values, dst=0)
|
| 319 |
+
|
| 320 |
+
if rank == 0:
|
| 321 |
+
reduced_metrics = metric_values.cpu().numpy()
|
| 322 |
+
reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}
|
| 323 |
+
for set_id, set_name in enumerate(set_ids)}
|
| 324 |
+
|
| 325 |
+
# Postprocess
|
| 326 |
+
for set_name, metrics in reduced_metrics.items():
|
| 327 |
+
count = metrics.pop("count")
|
| 328 |
+
reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()}
|
| 329 |
+
|
| 330 |
+
return reduced_metrics
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def save_code_and_config(config: PretrainConfig):
|
| 334 |
+
if config.checkpoint_path is None or wandb.run is None:
|
| 335 |
+
return
|
| 336 |
+
|
| 337 |
+
os.makedirs(config.checkpoint_path, exist_ok=True)
|
| 338 |
+
|
| 339 |
+
# Copy code
|
| 340 |
+
code_list = [
|
| 341 |
+
get_model_source_path(config.arch.name),
|
| 342 |
+
get_model_source_path(config.arch.loss.name)
|
| 343 |
+
]
|
| 344 |
+
for code_file in code_list:
|
| 345 |
+
if code_file is not None:
|
| 346 |
+
code_name = os.path.basename(code_file)
|
| 347 |
+
|
| 348 |
+
shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))
|
| 349 |
+
|
| 350 |
+
# Dump config as yaml
|
| 351 |
+
config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
|
| 352 |
+
with open(config_file, "wt") as f:
|
| 353 |
+
yaml.dump(config.model_dump(), f)
|
| 354 |
+
|
| 355 |
+
# Log code
|
| 356 |
+
wandb.run.log_code(config.checkpoint_path)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
|
| 360 |
+
objects = [None]
|
| 361 |
+
if rank == 0:
|
| 362 |
+
config = PretrainConfig(**hydra_config) # type: ignore
|
| 363 |
+
|
| 364 |
+
# Naming
|
| 365 |
+
if config.project_name is None:
|
| 366 |
+
config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch"
|
| 367 |
+
if config.run_name is None:
|
| 368 |
+
config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
|
| 369 |
+
if config.checkpoint_path is None:
|
| 370 |
+
config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)
|
| 371 |
+
|
| 372 |
+
objects = [config]
|
| 373 |
+
|
| 374 |
+
if world_size > 1:
|
| 375 |
+
dist.broadcast_object_list(objects, src=0)
|
| 376 |
+
|
| 377 |
+
return objects[0] # type: ignore
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
|
| 381 |
+
def launch(hydra_config: DictConfig):
|
| 382 |
+
RANK = 0
|
| 383 |
+
WORLD_SIZE = 1
|
| 384 |
+
|
| 385 |
+
# Initialize distributed training if in distributed environment (e.g. torchrun)
|
| 386 |
+
if "LOCAL_RANK" in os.environ:
|
| 387 |
+
# Initialize distributed, default device and dtype
|
| 388 |
+
dist.init_process_group(backend="nccl")
|
| 389 |
+
|
| 390 |
+
RANK = dist.get_rank()
|
| 391 |
+
WORLD_SIZE = dist.get_world_size()
|
| 392 |
+
|
| 393 |
+
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
| 394 |
+
|
| 395 |
+
# Load sync'ed config
|
| 396 |
+
config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)
|
| 397 |
+
|
| 398 |
+
# Seed RNGs to ensure consistency
|
| 399 |
+
torch.random.manual_seed(config.seed + RANK)
|
| 400 |
+
|
| 401 |
+
# Dataset
|
| 402 |
+
train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
|
| 403 |
+
total_iters = config.epochs // train_epochs_per_iter
|
| 404 |
+
|
| 405 |
+
assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."
|
| 406 |
+
|
| 407 |
+
train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
|
| 408 |
+
eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
|
| 409 |
+
|
| 410 |
+
# Train state
|
| 411 |
+
train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
|
| 412 |
+
|
| 413 |
+
# Progress bar and logger
|
| 414 |
+
progress_bar = None
|
| 415 |
+
if RANK == 0:
|
| 416 |
+
progress_bar = tqdm.tqdm(total=train_state.total_steps)
|
| 417 |
+
|
| 418 |
+
wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore
|
| 419 |
+
wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
|
| 420 |
+
save_code_and_config(config)
|
| 421 |
+
|
| 422 |
+
# Training Loop
|
| 423 |
+
for _iter_id in range(total_iters):
|
| 424 |
+
print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")
|
| 425 |
+
|
| 426 |
+
############ Train Iter
|
| 427 |
+
train_state.model.train()
|
| 428 |
+
for set_name, batch, global_batch_size in train_loader:
|
| 429 |
+
metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
|
| 430 |
+
|
| 431 |
+
if RANK == 0 and metrics is not None:
|
| 432 |
+
wandb.log(metrics, step=train_state.step)
|
| 433 |
+
progress_bar.update(train_state.step - progress_bar.n) # type: ignore
|
| 434 |
+
|
| 435 |
+
############ Evaluation
|
| 436 |
+
train_state.model.eval()
|
| 437 |
+
metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
|
| 438 |
+
|
| 439 |
+
if RANK == 0 and metrics is not None:
|
| 440 |
+
wandb.log(metrics, step=train_state.step)
|
| 441 |
+
|
| 442 |
+
############ Checkpointing
|
| 443 |
+
if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
|
| 444 |
+
save_train_state(config, train_state)
|
| 445 |
+
|
| 446 |
+
# finalize
|
| 447 |
+
if dist.is_initialized():
|
| 448 |
+
dist.destroy_process_group()
|
| 449 |
+
wandb.finish()
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
if __name__ == "__main__":
|
| 453 |
+
launch()
|
puzzle_dataset.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pydantic
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import IterableDataset, get_worker_info
|
| 9 |
+
|
| 10 |
+
from models.losses import IGNORE_LABEL_ID
|
| 11 |
+
from dataset.common import PuzzleDatasetMetadata
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
|
| 15 |
+
# Pack examples into a full batch
|
| 16 |
+
batch = []
|
| 17 |
+
batch_puzzle_indices = []
|
| 18 |
+
current_size = 0
|
| 19 |
+
|
| 20 |
+
while (start_index < group_order.size) and (current_size < global_batch_size):
|
| 21 |
+
# Pick a group and a puzzle from that group
|
| 22 |
+
group_id = group_order[start_index]
|
| 23 |
+
puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
|
| 24 |
+
start_index += 1
|
| 25 |
+
|
| 26 |
+
# Get range of the puzzle
|
| 27 |
+
puzzle_start = puzzle_indices[puzzle_id]
|
| 28 |
+
puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
|
| 29 |
+
|
| 30 |
+
append_size = min(puzzle_size, global_batch_size - current_size)
|
| 31 |
+
|
| 32 |
+
# Put into batch
|
| 33 |
+
batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
|
| 34 |
+
batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))
|
| 35 |
+
|
| 36 |
+
current_size += append_size
|
| 37 |
+
|
| 38 |
+
return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class PuzzleDatasetConfig(pydantic.BaseModel):
|
| 42 |
+
seed: int
|
| 43 |
+
dataset_path: str
|
| 44 |
+
global_batch_size: int
|
| 45 |
+
test_set_mode: bool
|
| 46 |
+
|
| 47 |
+
epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead.
|
| 48 |
+
|
| 49 |
+
rank: int
|
| 50 |
+
num_replicas: int
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class PuzzleDataset(IterableDataset):
|
| 54 |
+
def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.config = config
|
| 57 |
+
self.split = split
|
| 58 |
+
self.metadata = self._load_metadata()
|
| 59 |
+
|
| 60 |
+
# Checks
|
| 61 |
+
assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
|
| 62 |
+
self.local_batch_size = self.config.global_batch_size // self.config.num_replicas
|
| 63 |
+
|
| 64 |
+
# State
|
| 65 |
+
self._data = None
|
| 66 |
+
self._iters = 0
|
| 67 |
+
|
| 68 |
+
def _load_metadata(self) -> PuzzleDatasetMetadata:
|
| 69 |
+
with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f:
|
| 70 |
+
return PuzzleDatasetMetadata(**json.load(f))
|
| 71 |
+
|
| 72 |
+
def _lazy_load_dataset(self):
|
| 73 |
+
if self._data is not None:
|
| 74 |
+
return
|
| 75 |
+
|
| 76 |
+
field_mmap_modes = {
|
| 77 |
+
"inputs": "r",
|
| 78 |
+
"labels": "r",
|
| 79 |
+
|
| 80 |
+
# Keep indices in memory
|
| 81 |
+
"puzzle_identifiers": None,
|
| 82 |
+
"puzzle_indices": None,
|
| 83 |
+
"group_indices": None
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
# Load data
|
| 87 |
+
self._data = {}
|
| 88 |
+
for set_name in self.metadata.sets:
|
| 89 |
+
# Load subset
|
| 90 |
+
self._data[set_name] = {
|
| 91 |
+
field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
|
| 92 |
+
for field_name, mmap_mode in field_mmap_modes.items()
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
def _collate_batch(self, batch):
|
| 96 |
+
# Convert dtype
|
| 97 |
+
batch = {k: v.astype(np.int32) for k, v in batch.items()}
|
| 98 |
+
|
| 99 |
+
# Convert ignore label IDs
|
| 100 |
+
if self.metadata.ignore_label_id is not None:
|
| 101 |
+
batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID
|
| 102 |
+
|
| 103 |
+
# Pad
|
| 104 |
+
if batch["puzzle_identifiers"].size < self.local_batch_size:
|
| 105 |
+
pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
|
| 106 |
+
|
| 107 |
+
pad_values = {
|
| 108 |
+
"inputs": self.metadata.pad_id,
|
| 109 |
+
"labels": IGNORE_LABEL_ID,
|
| 110 |
+
|
| 111 |
+
"puzzle_identifiers": self.metadata.blank_identifier_id
|
| 112 |
+
}
|
| 113 |
+
batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}
|
| 114 |
+
|
| 115 |
+
# To tensor
|
| 116 |
+
return {k: torch.from_numpy(v) for k, v in batch.items()}
|
| 117 |
+
|
| 118 |
+
def _iter_test(self):
|
| 119 |
+
for set_name, dataset in self._data.items(): # type: ignore
|
| 120 |
+
total_examples = len(dataset["inputs"])
|
| 121 |
+
|
| 122 |
+
# Load examples one by one
|
| 123 |
+
start_index = 0
|
| 124 |
+
while start_index < total_examples:
|
| 125 |
+
# Compute indices
|
| 126 |
+
end_index = min(total_examples, start_index + self.config.global_batch_size)
|
| 127 |
+
|
| 128 |
+
local_start = start_index + self.config.rank * self.local_batch_size
|
| 129 |
+
local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)
|
| 130 |
+
|
| 131 |
+
# Get batch of examples, and also puzzle IDs
|
| 132 |
+
puzzle_indices = []
|
| 133 |
+
puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
|
| 134 |
+
for i in range(local_start, local_end):
|
| 135 |
+
while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
|
| 136 |
+
puzzle_index += 1
|
| 137 |
+
|
| 138 |
+
puzzle_indices.append(puzzle_index)
|
| 139 |
+
|
| 140 |
+
batch = self._collate_batch({
|
| 141 |
+
"inputs": dataset["inputs"][local_start: local_end],
|
| 142 |
+
"labels": dataset["labels"][local_start: local_end],
|
| 143 |
+
"puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
|
| 144 |
+
})
|
| 145 |
+
|
| 146 |
+
yield set_name, batch, end_index - start_index
|
| 147 |
+
|
| 148 |
+
# Advance to next batch
|
| 149 |
+
start_index += self.config.global_batch_size
|
| 150 |
+
|
| 151 |
+
def _iter_train(self):
|
| 152 |
+
for set_name, dataset in self._data.items(): # type: ignore
|
| 153 |
+
# Increase epoch count
|
| 154 |
+
self._iters += 1
|
| 155 |
+
|
| 156 |
+
# Randomly shuffle groups
|
| 157 |
+
rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))
|
| 158 |
+
|
| 159 |
+
group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
|
| 160 |
+
start_index = 0
|
| 161 |
+
|
| 162 |
+
while start_index < group_order.size:
|
| 163 |
+
start_index, batch_indices, batch_puzzle_indices = _sample_batch(
|
| 164 |
+
rng,
|
| 165 |
+
group_order=group_order,
|
| 166 |
+
puzzle_indices=dataset["puzzle_indices"],
|
| 167 |
+
group_indices=dataset["group_indices"],
|
| 168 |
+
start_index=start_index,
|
| 169 |
+
global_batch_size=self.config.global_batch_size,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# Select current rank and collate
|
| 173 |
+
global_effective_batch_size = batch_puzzle_indices.size # Global effective batch size, excluding pads
|
| 174 |
+
|
| 175 |
+
# Drop last batch
|
| 176 |
+
if global_effective_batch_size < self.config.global_batch_size:
|
| 177 |
+
break
|
| 178 |
+
|
| 179 |
+
batch_indices = batch_indices [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
|
| 180 |
+
batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
|
| 181 |
+
batch = self._collate_batch({
|
| 182 |
+
"inputs": dataset["inputs"][batch_indices],
|
| 183 |
+
"labels": dataset["labels"][batch_indices],
|
| 184 |
+
"puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
|
| 185 |
+
})
|
| 186 |
+
|
| 187 |
+
yield set_name, batch, global_effective_batch_size
|
| 188 |
+
|
| 189 |
+
def __iter__(self):
|
| 190 |
+
worker_info = get_worker_info()
|
| 191 |
+
assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."
|
| 192 |
+
|
| 193 |
+
self._lazy_load_dataset()
|
| 194 |
+
|
| 195 |
+
# Iterate using specified mode
|
| 196 |
+
if self.config.test_set_mode:
|
| 197 |
+
yield from self._iter_test()
|
| 198 |
+
else:
|
| 199 |
+
yield from self._iter_train()
|
puzzle_visualizer.html
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html>
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8" />
|
| 5 |
+
<title>ARC‐Converted Dataset Visualizer (Upload Local Folder)</title>
|
| 6 |
+
<style>
|
| 7 |
+
body {
|
| 8 |
+
font-family: sans-serif;
|
| 9 |
+
margin: 16px;
|
| 10 |
+
}
|
| 11 |
+
.selector-area {
|
| 12 |
+
margin-bottom: 1rem;
|
| 13 |
+
}
|
| 14 |
+
.grid-canvas {
|
| 15 |
+
margin: 4px;
|
| 16 |
+
border: 1px solid #ccc;
|
| 17 |
+
}
|
| 18 |
+
.example-container {
|
| 19 |
+
display: inline-block;
|
| 20 |
+
margin: 0 16px 16px 0;
|
| 21 |
+
vertical-align: top;
|
| 22 |
+
}
|
| 23 |
+
.puzzle-display {
|
| 24 |
+
margin-top: 1rem;
|
| 25 |
+
}
|
| 26 |
+
.puzzle-id {
|
| 27 |
+
font-weight: bold;
|
| 28 |
+
margin-bottom: 0.5rem;
|
| 29 |
+
}
|
| 30 |
+
#groupList, #puzzleList {
|
| 31 |
+
margin: 1rem 0;
|
| 32 |
+
}
|
| 33 |
+
.group-item, .puzzle-item {
|
| 34 |
+
cursor: pointer;
|
| 35 |
+
margin: 4px 8px 4px 0;
|
| 36 |
+
padding: 2px 6px;
|
| 37 |
+
border: 1px solid #aaa;
|
| 38 |
+
display: inline-block;
|
| 39 |
+
}
|
| 40 |
+
.group-item:hover, .puzzle-item:hover {
|
| 41 |
+
background: #eef;
|
| 42 |
+
}
|
| 43 |
+
</style>
|
| 44 |
+
</head>
|
| 45 |
+
<body>
|
| 46 |
+
<h1>ARC‐Converted Dataset Visualizer (Local Directory)</h1>
|
| 47 |
+
|
| 48 |
+
<div class="selector-area">
|
| 49 |
+
<!-- 1) Directory input with webkitdirectory, mozdirectory -->
|
| 50 |
+
<label>Upload ARC Folder:</label>
|
| 51 |
+
<input type="file" id="folderInput"
|
| 52 |
+
webkitdirectory mozdirectory multiple
|
| 53 |
+
onchange="onFolderSelected(event)" />
|
| 54 |
+
<br><br>
|
| 55 |
+
|
| 56 |
+
<!-- 2) We'll enable set/subset selection after user chooses a folder and data is validated -->
|
| 57 |
+
<label>Set:</label>
|
| 58 |
+
<select id="setSelect" disabled>
|
| 59 |
+
<option value="train">train</option>
|
| 60 |
+
<option value="test">test</option>
|
| 61 |
+
</select>
|
| 62 |
+
|
| 63 |
+
<label> Subset:</label>
|
| 64 |
+
<select id="subsetSelect" disabled>
|
| 65 |
+
<option value="all">all</option>
|
| 66 |
+
</select>
|
| 67 |
+
|
| 68 |
+
<button id="loadBtn" disabled>Load</button>
|
| 69 |
+
</div>
|
| 70 |
+
|
| 71 |
+
<div>
|
| 72 |
+
<div id="groupList"></div>
|
| 73 |
+
<div id="puzzleList"></div>
|
| 74 |
+
<div class="puzzle-display" id="puzzleView"></div>
|
| 75 |
+
</div>
|
| 76 |
+
|
| 77 |
+
<!--
|
| 78 |
+
3) Use local 'assets/npyjs.js' from your project folder instead of a CDN.
|
| 79 |
+
Make sure 'assets/npyjs.js' is the unbundled or UMD version that doesn't
|
| 80 |
+
contain "import" statements.
|
| 81 |
+
-->
|
| 82 |
+
<script src="assets/npyjs.js"></script>
|
| 83 |
+
|
| 84 |
+
<script>
|
| 85 |
+
/***************************************************************************
|
| 86 |
+
* Global Maps & Variables
|
| 87 |
+
***************************************************************************/
|
| 88 |
+
|
| 89 |
+
// Map from "train/all__inputs.npy" => File, etc.
|
| 90 |
+
let filesByPath = {};
|
| 91 |
+
|
| 92 |
+
// Once loaded, we store typed arrays for the chosen set/subset
|
| 93 |
+
let inputsArr, labelsArr;
|
| 94 |
+
let puzzleIndicesArr, groupIndicesArr, puzzleIdentifiersArr;
|
| 95 |
+
let identifiersJson;
|
| 96 |
+
|
| 97 |
+
// The shape of inputs is [N_examples, seqLen], so we discover seqLen & gridSize
|
| 98 |
+
let seqLen = 0;
|
| 99 |
+
let gridSize = 0;
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
/***************************************************************************
|
| 103 |
+
* 1) Handle folder selection: read all files, find identifiers.json,
|
| 104 |
+
* remove topmost folder from each file path, validate.
|
| 105 |
+
***************************************************************************/
|
| 106 |
+
function onFolderSelected(event) {
|
| 107 |
+
filesByPath = {};
|
| 108 |
+
const fileList = event.target.files;
|
| 109 |
+
if (!fileList || fileList.length === 0) {
|
| 110 |
+
alert("No files selected!");
|
| 111 |
+
return;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// We'll gather all webkitRelativePaths
|
| 115 |
+
const paths = [];
|
| 116 |
+
for (let i = 0; i < fileList.length; i++) {
|
| 117 |
+
// Typically "arc-aug-10/train/all__inputs.npy", etc.
|
| 118 |
+
const file = fileList[i];
|
| 119 |
+
const relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
|
| 120 |
+
paths.push(relPath);
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// 1. Check if we have "identifiers.json" somewhere.
|
| 124 |
+
const idPath = paths.find(p => p.endsWith("identifiers.json"));
|
| 125 |
+
if (!idPath) {
|
| 126 |
+
alert("Error: No 'identifiers.json' found in the uploaded folder.");
|
| 127 |
+
return;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
// 2. Derive the top-level directory from that file's path
|
| 131 |
+
// e.g. if idPath = "arc-aug-10/identifiers.json", topDir = "arc-aug-10"
|
| 132 |
+
// If there's no slash, topDir = "" => do nothing
|
| 133 |
+
let topDir = "";
|
| 134 |
+
const lastSlash = idPath.lastIndexOf("/");
|
| 135 |
+
if (lastSlash >= 0) {
|
| 136 |
+
topDir = idPath.substring(0, lastSlash);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
// 3. Rebuild filesByPath with the top folder removed.
|
| 140 |
+
// For example, if topDir = "arc-aug-10", then "arc-aug-10/train/all__inputs.npy"
|
| 141 |
+
// becomes "train/all__inputs.npy"
|
| 142 |
+
for (let i = 0; i < fileList.length; i++) {
|
| 143 |
+
const file = fileList[i];
|
| 144 |
+
let relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
|
| 145 |
+
// If relPath starts with "arc-aug-10/", remove that prefix
|
| 146 |
+
if (topDir && relPath.startsWith(topDir + "/")) {
|
| 147 |
+
relPath = relPath.substring(topDir.length + 1);
|
| 148 |
+
}
|
| 149 |
+
filesByPath[relPath] = file;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
// Enable set/subset selection and "Load"
|
| 153 |
+
document.getElementById("setSelect").disabled = false;
|
| 154 |
+
document.getElementById("subsetSelect").disabled = false;
|
| 155 |
+
document.getElementById("loadBtn").disabled = false;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
// When user clicks "Load," parse the .npy for the chosen set/subset
|
| 159 |
+
document.getElementById("loadBtn").addEventListener("click", async () => {
|
| 160 |
+
document.getElementById("groupList").innerHTML = "";
|
| 161 |
+
document.getElementById("puzzleList").innerHTML = "";
|
| 162 |
+
document.getElementById("puzzleView").innerHTML = "";
|
| 163 |
+
|
| 164 |
+
const setName = document.getElementById("setSelect").value; // e.g. "train"
|
| 165 |
+
const subsetName = document.getElementById("subsetSelect").value; // e.g. "all"
|
| 166 |
+
|
| 167 |
+
try {
|
| 168 |
+
await loadDataset(setName, subsetName);
|
| 169 |
+
buildGroupList(); // show groups
|
| 170 |
+
} catch (err) {
|
| 171 |
+
console.error(err);
|
| 172 |
+
alert("Error while loading dataset: " + err);
|
| 173 |
+
}
|
| 174 |
+
});
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
/***************************************************************************
|
| 178 |
+
* 2) Load .npy from local files using Npyjs + FileReader (ArrayBuffer)
|
| 179 |
+
***************************************************************************/
|
| 180 |
+
async function loadDataset(setName, subsetName) {
|
| 181 |
+
const prefix = `${setName}/${subsetName}__`;
|
| 182 |
+
// e.g. "train/all__inputs.npy"
|
| 183 |
+
const inputsPath = prefix + "inputs.npy";
|
| 184 |
+
const labelsPath = prefix + "labels.npy";
|
| 185 |
+
const pIdxPath = prefix + "puzzle_indices.npy";
|
| 186 |
+
const gIdxPath = prefix + "group_indices.npy";
|
| 187 |
+
const pIdsPath = prefix + "puzzle_identifiers.npy";
|
| 188 |
+
const identifiersPath = "identifiers.json";
|
| 189 |
+
|
| 190 |
+
// Check existence
|
| 191 |
+
const needed = [inputsPath, labelsPath, pIdxPath, gIdxPath, pIdsPath, identifiersPath];
|
| 192 |
+
for (const f of needed) {
|
| 193 |
+
if (!filesByPath[f]) {
|
| 194 |
+
throw new Error(`Missing file: ${f}`);
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// parseNpy => read from File -> ArrayBuffer -> Npyjs => typed array
|
| 199 |
+
const inputsNpy = await parseNpy(filesByPath[inputsPath]);
|
| 200 |
+
const labelsNpy = await parseNpy(filesByPath[labelsPath]);
|
| 201 |
+
const puzzleIndicesNpy= await parseNpy(filesByPath[pIdxPath]);
|
| 202 |
+
const groupIndicesNpy = await parseNpy(filesByPath[gIdxPath]);
|
| 203 |
+
const puzzleIdsNpy = await parseNpy(filesByPath[pIdsPath]);
|
| 204 |
+
|
| 205 |
+
inputsArr = inputsNpy.data;
|
| 206 |
+
labelsArr = labelsNpy.data;
|
| 207 |
+
puzzleIndicesArr = puzzleIndicesNpy.data;
|
| 208 |
+
groupIndicesArr = groupIndicesNpy.data;
|
| 209 |
+
puzzleIdentifiersArr = puzzleIdsNpy.data;
|
| 210 |
+
|
| 211 |
+
// shape e.g. [N_examples, seqLen]
|
| 212 |
+
seqLen = inputsNpy.shape[1];
|
| 213 |
+
gridSize = Math.sqrt(seqLen);
|
| 214 |
+
|
| 215 |
+
// read JSON
|
| 216 |
+
identifiersJson = await readJsonFile(filesByPath[identifiersPath]);
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
/***************************************************************************
|
| 220 |
+
* parseNpy => read a File as ArrayBuffer, parse with npyjs
|
| 221 |
+
***************************************************************************/
|
| 222 |
+
function parseNpy(file) {
|
| 223 |
+
return new Promise((resolve, reject) => {
|
| 224 |
+
const reader = new FileReader();
|
| 225 |
+
reader.onload = async () => {
|
| 226 |
+
try {
|
| 227 |
+
const arrayBuffer = reader.result;
|
| 228 |
+
const npy = new npyjs();
|
| 229 |
+
resolve(await npy.parse(arrayBuffer));
|
| 230 |
+
} catch (err) {
|
| 231 |
+
reject(err);
|
| 232 |
+
}
|
| 233 |
+
};
|
| 234 |
+
reader.onerror = err => reject(err);
|
| 235 |
+
reader.readAsArrayBuffer(file);
|
| 236 |
+
});
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
/***************************************************************************
|
| 240 |
+
* readJsonFile => read a local JSON file into object
|
| 241 |
+
***************************************************************************/
|
| 242 |
+
function readJsonFile(file) {
|
| 243 |
+
return new Promise((resolve, reject) => {
|
| 244 |
+
const reader = new FileReader();
|
| 245 |
+
reader.onload = () => {
|
| 246 |
+
try {
|
| 247 |
+
const obj = JSON.parse(reader.result);
|
| 248 |
+
resolve(obj);
|
| 249 |
+
} catch (err) {
|
| 250 |
+
reject(err);
|
| 251 |
+
}
|
| 252 |
+
};
|
| 253 |
+
reader.onerror = (err) => reject(err);
|
| 254 |
+
reader.readAsText(file);
|
| 255 |
+
});
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
/***************************************************************************
|
| 259 |
+
* 3) Build group list in UI
|
| 260 |
+
***************************************************************************/
|
| 261 |
+
function buildGroupList() {
|
| 262 |
+
document.getElementById("groupList").innerHTML = "<h3>Groups</h3>";
|
| 263 |
+
const groupListDiv = document.getElementById("groupList");
|
| 264 |
+
|
| 265 |
+
const nGroups = groupIndicesArr.length - 1;
|
| 266 |
+
for (let g = 0; g < nGroups; g++) {
|
| 267 |
+
const div = document.createElement("span");
|
| 268 |
+
div.className = "group-item";
|
| 269 |
+
div.textContent = `Group ${g}`;
|
| 270 |
+
div.onclick = () => onSelectGroup(g);
|
| 271 |
+
groupListDiv.appendChild(div);
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
/***************************************************************************
|
| 276 |
+
* onSelectGroup => show puzzles in that group
|
| 277 |
+
***************************************************************************/
|
| 278 |
+
function onSelectGroup(groupIndex) {
|
| 279 |
+
document.getElementById("puzzleList").innerHTML = "";
|
| 280 |
+
document.getElementById("puzzleView").innerHTML = "";
|
| 281 |
+
|
| 282 |
+
const puzzleListDiv = document.getElementById("puzzleList");
|
| 283 |
+
puzzleListDiv.innerHTML = `<h4>Puzzles in Group ${groupIndex}</h4>`;
|
| 284 |
+
|
| 285 |
+
const firstPuzzle = groupIndicesArr[groupIndex];
|
| 286 |
+
const lastPuzzle = groupIndicesArr[groupIndex + 1];
|
| 287 |
+
|
| 288 |
+
for (let p = firstPuzzle; p < lastPuzzle; p++) {
|
| 289 |
+
const puzzleIntId = puzzleIdentifiersArr[p];
|
| 290 |
+
const puzzleStrId = (puzzleIntId < identifiersJson.length)
|
| 291 |
+
? identifiersJson[puzzleIntId]
|
| 292 |
+
: "<unknown>";
|
| 293 |
+
|
| 294 |
+
const div = document.createElement("span");
|
| 295 |
+
div.className = "puzzle-item";
|
| 296 |
+
div.textContent = `Puzzle #${p} [ID=${puzzleIntId}: ${puzzleStrId}]`;
|
| 297 |
+
div.onclick = () => onSelectPuzzle(p);
|
| 298 |
+
puzzleListDiv.appendChild(div);
|
| 299 |
+
}
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
/***************************************************************************
|
| 303 |
+
* onSelectPuzzle => show each example
|
| 304 |
+
***************************************************************************/
|
| 305 |
+
function onSelectPuzzle(puzzleIndex) {
|
| 306 |
+
const puzzleView = document.getElementById("puzzleView");
|
| 307 |
+
puzzleView.innerHTML = "";
|
| 308 |
+
|
| 309 |
+
// puzzle ID
|
| 310 |
+
const puzzleIntId = puzzleIdentifiersArr[puzzleIndex];
|
| 311 |
+
const puzzleStrId = (puzzleIntId < identifiersJson.length)
|
| 312 |
+
? identifiersJson[puzzleIntId]
|
| 313 |
+
: "<unknown>";
|
| 314 |
+
|
| 315 |
+
const titleDiv = document.createElement("div");
|
| 316 |
+
titleDiv.className = "puzzle-id";
|
| 317 |
+
titleDiv.textContent = `Puzzle #${puzzleIndex} — ID: ${puzzleStrId}`;
|
| 318 |
+
puzzleView.appendChild(titleDiv);
|
| 319 |
+
|
| 320 |
+
// Examples are [puzzleIndicesArr[p], puzzleIndicesArr[p+1])
|
| 321 |
+
const firstExample = puzzleIndicesArr[puzzleIndex];
|
| 322 |
+
const lastExample = puzzleIndicesArr[puzzleIndex + 1];
|
| 323 |
+
|
| 324 |
+
for (let e = firstExample; e < lastExample; e++) {
|
| 325 |
+
const inputSeq = slice1D(inputsArr, e*seqLen, (e+1)*seqLen);
|
| 326 |
+
const outputSeq = slice1D(labelsArr, e*seqLen, (e+1)*seqLen);
|
| 327 |
+
|
| 328 |
+
const inputGrid = decodeGrid(inputSeq);
|
| 329 |
+
const outputGrid = decodeGrid(outputSeq);
|
| 330 |
+
|
| 331 |
+
const exDiv = document.createElement("div");
|
| 332 |
+
exDiv.className = "example-container";
|
| 333 |
+
exDiv.appendChild(document.createTextNode(`Example ${e}`));
|
| 334 |
+
exDiv.appendChild(document.createElement("br"));
|
| 335 |
+
|
| 336 |
+
exDiv.appendChild(renderGrid(inputGrid));
|
| 337 |
+
exDiv.appendChild(renderGrid(outputGrid));
|
| 338 |
+
|
| 339 |
+
puzzleView.appendChild(exDiv);
|
| 340 |
+
}
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
/***************************************************************************
|
| 344 |
+
* slice1D => typed array slicing
|
| 345 |
+
***************************************************************************/
|
| 346 |
+
function slice1D(arr, start, end) {
|
| 347 |
+
const result = new Uint32Array(end - start);
|
| 348 |
+
for (let i = start; i < end; i++) {
|
| 349 |
+
result[i - start] = Number(arr[i]);
|
| 350 |
+
}
|
| 351 |
+
return result;
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
/***************************************************************************
|
| 355 |
+
* decodeGrid => turn the flattened seq of length=gridSize^2 into 2D
|
| 356 |
+
***************************************************************************/
|
| 357 |
+
function decodeGrid(seq) {
|
| 358 |
+
const grid = [];
|
| 359 |
+
let idx = 0;
|
| 360 |
+
for (let r = 0; r < gridSize; r++) {
|
| 361 |
+
const row = [];
|
| 362 |
+
for (let c = 0; c < gridSize; c++) {
|
| 363 |
+
row.push(seq[idx]);
|
| 364 |
+
idx++;
|
| 365 |
+
}
|
| 366 |
+
grid.push(row);
|
| 367 |
+
}
|
| 368 |
+
return grid;
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
/***************************************************************************
|
| 372 |
+
* renderGrid => draws a 2D grid to <canvas>
|
| 373 |
+
***************************************************************************/
|
| 374 |
+
function renderGrid(grid2d) {
|
| 375 |
+
const rows = grid2d.length;
|
| 376 |
+
const cols = grid2d[0].length;
|
| 377 |
+
const scale = 10;
|
| 378 |
+
|
| 379 |
+
const canvas = document.createElement("canvas");
|
| 380 |
+
canvas.width = cols * scale;
|
| 381 |
+
canvas.height = rows * scale;
|
| 382 |
+
canvas.className = "grid-canvas";
|
| 383 |
+
const ctx = canvas.getContext("2d");
|
| 384 |
+
|
| 385 |
+
for (let r = 0; r < rows; r++) {
|
| 386 |
+
for (let c = 0; c < cols; c++) {
|
| 387 |
+
const val = grid2d[r][c];
|
| 388 |
+
ctx.fillStyle = indexToColor(val);
|
| 389 |
+
ctx.fillRect(c * scale, r * scale, scale, scale);
|
| 390 |
+
}
|
| 391 |
+
}
|
| 392 |
+
return canvas;
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
/***************************************************************************
|
| 396 |
+
* indexToColor => color palette:
|
| 397 |
+
* 0 => pad => white
|
| 398 |
+
* 1 => eos => light gray
|
| 399 |
+
* 2..11 => original color(0..9)
|
| 400 |
+
***************************************************************************/
|
| 401 |
+
function indexToColor(value) {
|
| 402 |
+
if (value === 0) return "#FFFFFF"; // pad => white
|
| 403 |
+
if (value === 1) return "#DDDDDD"; // eos => light gray
|
| 404 |
+
|
| 405 |
+
// shift by 2 => original color in [0..9]
|
| 406 |
+
const colorIdx = value - 2;
|
| 407 |
+
const palette = [
|
| 408 |
+
"#000000", // color0 => black
|
| 409 |
+
"#FF0000", // color1 => red
|
| 410 |
+
"#00FF00", // color2 => green
|
| 411 |
+
"#0000FF", // color3 => blue
|
| 412 |
+
"#FFFF00", // color4 => yellow
|
| 413 |
+
"#FFA500", // color5 => orange
|
| 414 |
+
"#800080", // color6 => purple
|
| 415 |
+
"#00FFFF", // color7 => cyan
|
| 416 |
+
"#FFC0CB", // color8 => pink
|
| 417 |
+
"#808080" // color9 => gray
|
| 418 |
+
];
|
| 419 |
+
if (colorIdx >= 0 && colorIdx < palette.length) {
|
| 420 |
+
return palette[colorIdx];
|
| 421 |
+
}
|
| 422 |
+
return "#FFFFFF"; // fallback
|
| 423 |
+
}
|
| 424 |
+
</script>
|
| 425 |
+
</body>
|
| 426 |
+
</html>
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
adam-atan2
|
| 3 |
+
einops
|
| 4 |
+
tqdm
|
| 5 |
+
coolname
|
| 6 |
+
pydantic
|
| 7 |
+
argdantic
|
| 8 |
+
wandb
|
| 9 |
+
omegaconf
|
| 10 |
+
hydra-core
|
| 11 |
+
huggingface_hub
|
utils/functions.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import inspect
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def load_model_class(identifier: str, prefix: str = "models."):
|
| 6 |
+
module_path, class_name = identifier.split('@')
|
| 7 |
+
|
| 8 |
+
# Import the module
|
| 9 |
+
module = importlib.import_module(prefix + module_path)
|
| 10 |
+
cls = getattr(module, class_name)
|
| 11 |
+
|
| 12 |
+
return cls
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_model_source_path(identifier: str, prefix: str = "models."):
|
| 16 |
+
module_path, class_name = identifier.split('@')
|
| 17 |
+
|
| 18 |
+
module = importlib.import_module(prefix + module_path)
|
| 19 |
+
return inspect.getsourcefile(module)
|